├── tools ├── simnet │ ├── train │ │ ├── tf │ │ │ ├── nets │ │ │ │ ├── __init__.py │ │ │ │ └── knrm.py │ │ │ ├── layers │ │ │ │ └── __init__.py │ │ │ ├── losses │ │ │ │ ├── __init__.py │ │ │ │ └── simnet_loss.py │ │ │ ├── tools │ │ │ │ └── __init__.py │ │ │ ├── utils │ │ │ │ └── __init__.py │ │ │ ├── data │ │ │ │ ├── convert_test_pairwise_data │ │ │ │ ├── convert_test_pointwise_data │ │ │ │ ├── convert_train_pairwise_data │ │ │ │ ├── convert_train_pointwise_data │ │ │ │ ├── test_pairwise_data │ │ │ │ ├── test_pointwise_data │ │ │ │ └── train_pairwise_data │ │ │ ├── run_infer.sh │ │ │ ├── run_train.sh │ │ │ └── examples │ │ │ │ ├── knrm-pointwise.json │ │ │ │ ├── bow-pointwise.json │ │ │ │ ├── lstm-pointwise.json │ │ │ │ ├── bow-pairwise.json │ │ │ │ ├── knrm-pairwise.json │ │ │ │ ├── lstm-pairwise.json │ │ │ │ ├── cnn-pointwise.json │ │ │ │ ├── mvlstm-pointwise.json │ │ │ │ ├── cnn-pairwise.json │ │ │ │ ├── mvlstm-pairwise.json │ │ │ │ ├── pyramid-pointwise.json │ │ │ │ ├── mmdnn-pointwise.json │ │ │ │ ├── pyramid-pairwise.json │ │ │ │ └── mmdnn-pairwise.json │ │ └── paddle │ │ │ ├── nets │ │ │ ├── __init__.py │ │ │ ├── bow.py │ │ │ ├── cnn.py │ │ │ └── gru.py │ │ │ ├── util │ │ │ └── __init__.py │ │ │ ├── layers │ │ │ └── __init__.py │ │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── log_loss.py │ │ │ ├── softmax_cross_entropy_loss.py │ │ │ └── hinge_loss.py │ │ │ ├── optimizers │ │ │ ├── __init__.py │ │ │ └── paddle_optimizers.py │ │ │ ├── data │ │ │ ├── test_pairwise_data │ │ │ ├── test_pointwise_data │ │ │ ├── train_pointwise_data │ │ │ └── train_pairwise_data │ │ │ ├── run_infer.sh │ │ │ ├── run_train.sh │ │ │ └── examples │ │ │ ├── bow_pointwise.json │ │ │ ├── gru_pointwise.json │ │ │ ├── lstm_pointwise.json │ │ │ ├── cnn_pointwise.json │ │ │ ├── bow_pairwise.json │ │ │ ├── gru_pairwise.json │ │ │ ├── lstm_pairwise.json │ │ │ ├── cnn_pairwise.json │ │ │ └── mmdnn-pointwise.json │ ├── preprocess │ │ ├── run_preprocess.sh │ │ └── preprocess.conf │ ├── train.sh │ ├── predict.sh │ └── README.md ├── ltr │ └── xgboost │ │ ├── xgboost │ │ ├── feature.test │ │ ├── train.sh │ │ ├── test.sh │ │ └── train_parameter.conf ├── anyq_deps.sh └── solr │ ├── README.md │ ├── solr_api.py │ ├── anyq_solr.sh │ ├── sample_docs │ └── solr_deply.sh ├── docs ├── images │ └── AnyQ-Framework.png └── semantic_retrieval_tutorial.md ├── include ├── server │ ├── http_server.h │ ├── anyq_postprocessor.h │ ├── anyq_preprocessor.h │ ├── request_postprocess_interface.h │ ├── request_preprocess_interface.h │ ├── session_data_factory.h │ └── solr_accessor.h ├── matching │ ├── lexical │ │ ├── cosine_sim.h │ │ ├── jaccard_sim.h │ │ ├── wordseg_proc.h │ │ ├── bm25_sim.h │ │ ├── edit_distance_sim.h │ │ └── contain_sim.h │ ├── semantic │ │ ├── simnet_paddle_sim.h │ │ └── simnet_tf_sim.h │ └── matching_interface.h ├── retrieval │ ├── term │ │ ├── equal_solr_q_builder.h │ │ ├── date_compare_solr_q_builder.h │ │ ├── synonym_solr_q_builder.h │ │ ├── contain_solr_q_builder.h │ │ ├── boost_solr_q_builder.h │ │ ├── solr_q_interface.h │ │ └── term_retrieval.h │ ├── manual │ │ └── manual_retrieval.h │ ├── retrieval_strategy.h │ ├── semantic │ │ └── semantic_retrieval.h │ └── retrieval_interface.h ├── analysis │ ├── method_wordseg.h │ ├── method_query_intervene.h │ ├── analysis_strategy.h │ ├── method_simnet_emb.h │ └── method_interface.h ├── common │ ├── http_client.h │ ├── plugin_factory.h │ └── paddle_thread_resource.h ├── strategy │ └── anyq_strategy.h └── dict │ ├── dict_interface.h │ ├── dual_dict_wrapper.h │ └── dict_manager.h ├── 
proto └── http_service.proto ├── AUTHORS ├── cmake ├── external │ ├── jsoncpp.cmake │ ├── jdk.cmake │ ├── curl.cmake │ ├── eigen.cmake │ ├── openssl.cmake │ ├── boost.cmake │ ├── bazel.cmake │ ├── gtest.cmake │ ├── lac.cmake │ ├── xgboost.cmake │ ├── brpc.cmake │ ├── leveldb.cmake │ ├── glog.cmake │ ├── gflags.cmake │ ├── protobuf.cmake │ ├── paddle.cmake │ ├── tensorflow.cmake │ └── zlib.cmake └── proto_build.cmake ├── README.EN.md ├── demo ├── run_server.cpp ├── feature_dump.cpp └── annoy_index_build.cpp └── src ├── analysis ├── method_interface.cpp ├── method_query_intervene.cpp └── method_wordseg.cpp ├── common └── plugin_factory.cpp ├── rank └── predictor │ ├── predict_select_model.cpp │ └── predict_linear_model.cpp ├── dict ├── wordseg_adapter.cpp ├── tf_model_adapter.cpp └── dict_adapter.cpp ├── matching └── lexical │ ├── jaccard_sim.cpp │ ├── cosine_sim.cpp │ └── contain_sim.cpp ├── retrieval ├── term │ ├── equal_solr_q_builder.cpp │ ├── date_compare_solr_q_builder.cpp │ └── synonym_solr_q_builder.cpp └── manual │ └── manual_retrieval.cpp └── server ├── anyq_postprocessor.cpp ├── anyq_preprocessor.cpp └── session_data_factory.cpp /tools/simnet/train/tf/nets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/simnet/train/paddle/nets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/simnet/train/paddle/util/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/simnet/train/tf/layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/simnet/train/tf/losses/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/simnet/train/tf/tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/simnet/train/tf/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/simnet/train/paddle/layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/simnet/train/paddle/losses/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/simnet/train/paddle/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/simnet/train/tf/data/convert_test_pairwise_data: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/simnet/preprocess/run_preprocess.sh: 
-------------------------------------------------------------------------------- 1 | python preprocess.py preprocess.conf 2 | -------------------------------------------------------------------------------- /tools/ltr/xgboost/xgboost: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataXujing/AnyQ/master/tools/ltr/xgboost/xgboost -------------------------------------------------------------------------------- /tools/ltr/xgboost/feature.test: -------------------------------------------------------------------------------- 1 | 0 1:0.400000 2:0.753433 3:0.625919 2 | 0 1:0.400000 2:0.839659 3:1.049439 3 | -------------------------------------------------------------------------------- /docs/images/AnyQ-Framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataXujing/AnyQ/master/docs/images/AnyQ-Framework.png -------------------------------------------------------------------------------- /include/server/http_server.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataXujing/AnyQ/master/include/server/http_server.h -------------------------------------------------------------------------------- /tools/simnet/train/tf/data/convert_test_pointwise_data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataXujing/AnyQ/master/tools/simnet/train/tf/data/convert_test_pointwise_data -------------------------------------------------------------------------------- /tools/simnet/train/tf/data/convert_train_pairwise_data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataXujing/AnyQ/master/tools/simnet/train/tf/data/convert_train_pairwise_data -------------------------------------------------------------------------------- /tools/simnet/train/tf/data/convert_train_pointwise_data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataXujing/AnyQ/master/tools/simnet/train/tf/data/convert_train_pointwise_data -------------------------------------------------------------------------------- /tools/simnet/train/tf/data/test_pairwise_data: -------------------------------------------------------------------------------- 1 | 0 1 1 1 1 1 1 1 1 1 1 1 2 | 0 0 1 1 1 1 1 2 2 2 2 2 3 | 0 1 1 1 1 1 1 1 1 1 1 1 4 | 0 0 1 1 1 1 1 2 2 2 2 2 5 | 0 1 1 1 1 1 1 1 1 1 1 1 6 | 0 0 1 1 1 1 1 2 2 2 2 2 7 | -------------------------------------------------------------------------------- /tools/simnet/train/paddle/data/test_pairwise_data: -------------------------------------------------------------------------------- 1 | 0 1 1 1 1 1 1 1 1 1 1 1 2 | 0 0 1 1 1 1 1 2 2 2 2 2 3 | 0 1 1 1 1 1 1 1 1 1 1 1 4 | 0 0 1 1 1 1 1 2 2 2 2 2 5 | 0 1 1 1 1 1 1 1 1 1 1 1 6 | 0 0 1 1 1 1 1 2 2 2 2 2 7 | 8 | -------------------------------------------------------------------------------- /tools/ltr/xgboost/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #Binary Classification demo: https://github.com/dmlc/xgboost/tree/master/demo/binary_classification 4 | 5 | #XGBoost Parameters: https://github.com/dmlc/xgboost/blob/master/doc/parameter.md 6 | 7 | ./xgboost train_parameter.conf 8 | -------------------------------------------------------------------------------- 
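The feature.test rows above (and the feature.dat training file that train_parameter.conf expects) are in the LibSVM text format the xgboost binary consumes: a numeric label followed by 1-based index:value feature pairs. A minimal sketch of producing and parsing such rows; which similarity score each column holds is an illustrative assumption, not something the repo documents:

```python
# Sketch of the LibSVM row format used by feature.dat / feature.test.
# The column-to-feature mapping (e.g. 1=cosine, 2=BM25, 3=edit distance)
# is an assumption for illustration only.

def to_libsvm(label, features):
    """Format one sample as '<label> 1:<f1> 2:<f2> ...'."""
    parts = ["%d" % label]
    parts.extend("%d:%f" % (i + 1, v) for i, v in enumerate(features))
    return " ".join(parts)

def parse_libsvm(line):
    """Parse a LibSVM row back into (label, {index: value})."""
    fields = line.split()
    label = int(fields[0])
    feats = {int(k): float(v) for k, v in (f.split(":", 1) for f in fields[1:])}
    return label, feats

if __name__ == "__main__":
    print(to_libsvm(0, [0.400000, 0.753433, 0.625919]))
    # -> 0 1:0.400000 2:0.753433 3:0.625919
    print(parse_libsvm("0 1:0.400000 2:0.753433 3:0.625919"))
```

Any feature dumper only needs to emit lines in this shape for the xgboost train/predict runs configured here to work.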
/tools/simnet/train/tf/run_infer.sh: -------------------------------------------------------------------------------- 1 | set -e # set -o errexit 2 | set -u # set -o nounset 3 | set -o pipefail 4 | 5 | in_task_type='predict' 6 | in_task_conf='./examples/cnn-pointwise.json' 7 | python tf_simnet.py \ 8 | --task $in_task_type \ 9 | --task_conf $in_task_conf 10 | 11 | -------------------------------------------------------------------------------- /tools/simnet/train/paddle/run_infer.sh: -------------------------------------------------------------------------------- 1 | set -e # set -o errexit 2 | set -u # set -o nounset 3 | set -o pipefail 4 | 5 | in_task_type='predict' 6 | in_conf_file_path='examples/cnn_pointwise.json' 7 | python paddle_simnet.py \ 8 | --task_type $in_task_type \ 9 | --conf_file_path $in_conf_file_path 10 | 11 | -------------------------------------------------------------------------------- /tools/simnet/train/paddle/run_train.sh: -------------------------------------------------------------------------------- 1 | set -e # set -o errexit 2 | set -u # set -o nounset 3 | set -o pipefail 4 | 5 | in_task_type='train' 6 | in_conf_file_path='examples/gru_pairwise.json' 7 | python paddle_simnet.py \ 8 | --task_type $in_task_type \ 9 | --conf_file_path $in_conf_file_path 10 | 11 | -------------------------------------------------------------------------------- /tools/simnet/preprocess/preprocess.conf: -------------------------------------------------------------------------------- 1 | [GLOBAL] 2 | flow = gendict,convertid,partition,write 3 | model_type = pointwise 4 | src_data = ./sample_data/sample_test.txt 5 | output_dir = ../output 6 | partition_ratio = 9:1 7 | src_data_seg_sep = ' ' 8 | 9 | [FEATURE.CUSTOM] 10 | name = f_custom 11 | seg_grain = custom 12 | -------------------------------------------------------------------------------- /tools/ltr/xgboost/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #Get Predictions 4 | #After training, we can use the output model to get the prediction of the test data: 5 | 6 | #../../xgboost mushroom.conf task=pred model_in=0002.model 7 | #For binary classification, the output predictions are probability confidence scores in [0,1], corresponding to the probability that the label is positive. 8 | 9 | ./xgboost train_parameter.conf task=pred model_in=0010.model 10 | -------------------------------------------------------------------------------- tools/anyq_deps.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # Download the sample configuration 4 | if [ ! -d example ]; then 5 | wget --no-check-certificate http://anyq.bj.bcebos.com/example.tar.gz 6 | tar xzf example.tar.gz 7 | fi 8 | # Download the customized Solr build 9 | # Based on solr-4.10.3, with Baidu's open-source lexical analyzer added as a Chinese word-segmentation plugin 10 | # You can use this build, or a Solr build you have customized yourself 11 | # Requires python2.7 (with json) and JDK 1.8 or above 12 | if [ !
-d solr-4.10.3-anyq ]; then 13 | wget --no-check-certificate http://anyq.bj.bcebos.com/solr-4.10.3-anyq.tar.gz 14 | tar xzf solr-4.10.3-anyq.tar.gz 15 | fi 16 | -------------------------------------------------------------------------------- /tools/simnet/train.sh: -------------------------------------------------------------------------------- 1 | set -e # set -o errexit 2 | set -u # set -o nounset 3 | set -o pipefail 4 | 5 | usage() 6 | { 7 | echo "Usage:" 8 | echo "sh `basename $0` [paddle/tensorflow]" 9 | } 10 | 11 | platform=$1 12 | 13 | if [ "paddle" = "$platform" ]; then 14 | cd train/paddle/ 15 | sh run_train.sh 16 | cd ../../ 17 | elif [ "tensorflow" = "$platform" ]; then 18 | cd train/tf/ 19 | sh run_train.sh 20 | cd ../../ 21 | else 22 | echo "illegal platform" 23 | usage 24 | exit 1 25 | fi 26 | -------------------------------------------------------------------------------- /tools/simnet/predict.sh: -------------------------------------------------------------------------------- 1 | set -e # set -o errexit 2 | set -u # set -o nounset 3 | set -o pipefail 4 | 5 | usage() 6 | { 7 | echo "Usage:" 8 | echo "sh `basename $0` [paddle/tensorflow]" 9 | } 10 | 11 | platform=$1 12 | 13 | if [ "paddle" = "$platform" ]; then 14 | cd train/paddle/ 15 | sh run_infer.sh 16 | cd ../../ 17 | elif [ "tensorflow" = "$platform" ]; then 18 | cd train/tf/ 19 | sh run_infer.sh 20 | cd ../../ 21 | else 22 | echo "illegal platform" 23 | usage 24 | exit 1 25 | fi 26 | -------------------------------------------------------------------------------- /proto/http_service.proto: -------------------------------------------------------------------------------- 1 | package anyq; 2 | 3 | option cc_generic_services = true; 4 | 5 | message HttpRequest { 6 | }; 7 | 8 | message HttpResponse { 9 | }; 10 | 11 | service HttpService { 12 | rpc anyq(HttpRequest) returns (HttpResponse); 13 | rpc solr_insert(HttpRequest) returns (HttpResponse); 14 | rpc solr_update(HttpRequest) returns (HttpResponse); 15 | rpc solr_delete(HttpRequest) returns (HttpResponse); 16 | rpc solr_clear(HttpRequest) returns (HttpResponse); 17 | }; 18 | -------------------------------------------------------------------------------- /tools/simnet/train/tf/run_train.sh: -------------------------------------------------------------------------------- 1 | set -e # set -o errexit 2 | set -u # set -o nounset 3 | set -o pipefail 4 | 5 | echo "convert train data" 6 | python ./tools/tf_record_writer.py pointwise ./data/train_pointwise_data ./data/convert_train_pointwise_data 0 32 7 | echo "convert test data" 8 | python ./tools/tf_record_writer.py pointwise ./data/test_pointwise_data ./data/convert_test_pointwise_data 0 32 9 | echo "convert data finish" 10 | 11 | in_task_type='train' 12 | in_task_conf='./examples/lstm-pointwise.json' 13 | python tf_simnet.py \ 14 | --task $in_task_type \ 15 | --task_conf $in_task_conf 16 | 17 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | # Contributors should be added to this file in the following format: 2 | # Name or Organization 3 | 4 | Baidu.com, Inc. 
5 | 6 | # Initial version authors: 7 | Pang Chao 8 | Sun Yu 9 | Tang Jiji 10 | Wu Yuchuan 11 | Yin Weichong 12 | Zeng Gang 13 | Zhang Han 14 | Zhang Xiyuan 15 | 16 | # Partial list of contributors: 17 | Ding Xinzhe 18 | Huang Wenzhi 19 | Zhu Pengfei 20 | -------------------------------------------------------------------------------- /tools/simnet/train/paddle/data/test_pointwise_data: -------------------------------------------------------------------------------- 1 | 1 1 1 1 1 1 1 1 1 1 1 1 1 2 | 1 1 1 1 1 1 2 2 2 2 2 2 0 3 | 1 1 1 1 1 1 1 1 1 1 1 1 1 4 | 1 1 1 1 1 1 2 2 2 2 2 2 0 5 | 1 1 1 1 1 1 1 1 1 1 1 1 1 6 | 1 1 1 1 1 1 2 2 2 2 2 2 0 7 | 1 1 1 1 1 1 1 1 1 1 1 1 1 8 | 1 1 1 1 1 1 2 2 2 2 2 2 0 9 | 1 1 1 1 1 1 1 1 1 1 1 1 1 10 | 1 1 1 1 1 1 2 2 2 2 2 2 0 11 | 1 1 1 1 1 1 1 1 1 1 1 1 1 12 | 1 1 1 1 1 1 2 2 2 2 2 2 0 13 | 1 1 1 1 1 1 1 1 1 1 1 1 1 14 | 1 1 1 1 1 1 2 2 2 2 2 2 0 15 | 1 1 1 1 1 1 1 1 1 1 1 1 1 16 | 1 1 1 1 1 1 2 2 2 2 2 2 0 17 | 1 1 1 1 1 1 1 1 1 1 1 1 1 18 | 1 1 1 1 1 1 2 2 2 2 2 2 0 19 | 1 1 1 1 1 1 1 1 1 1 1 1 1 20 | 1 1 1 1 1 1 2 2 2 2 2 2 0 21 | 1 1 1 1 1 1 1 1 1 1 1 1 1 22 | 1 1 1 1 1 1 2 2 2 2 2 2 0 23 | 24 | -------------------------------------------------------------------------------- /tools/simnet/train/paddle/data/train_pointwise_data: -------------------------------------------------------------------------------- 1 | 1 1 1 1 1 1 1 1 1 1 1 1 1 2 | 1 1 1 1 1 1 2 2 2 2 2 2 0 3 | 1 1 1 1 1 1 1 1 1 1 1 1 1 4 | 1 1 1 1 1 1 2 2 2 2 2 2 0 5 | 1 1 1 1 1 1 1 1 1 1 1 1 1 6 | 1 1 1 1 1 1 2 2 2 2 2 2 0 7 | 1 1 1 1 1 1 1 1 1 1 1 1 1 8 | 1 1 1 1 1 1 2 2 2 2 2 2 0 9 | 1 1 1 1 1 1 1 1 1 1 1 1 1 10 | 1 1 1 1 1 1 2 2 2 2 2 2 0 11 | 1 1 1 1 1 1 1 1 1 1 1 1 1 12 | 1 1 1 1 1 1 2 2 2 2 2 2 0 13 | 1 1 1 1 1 1 1 1 1 1 1 1 1 14 | 1 1 1 1 1 1 2 2 2 2 2 2 0 15 | 1 1 1 1 1 1 1 1 1 1 1 1 1 16 | 1 1 1 1 1 1 2 2 2 2 2 2 0 17 | 1 1 1 1 1 1 1 1 1 1 1 1 1 18 | 1 1 1 1 1 1 2 2 2 2 2 2 0 19 | 1 1 1 1 1 1 1 1 1 1 1 1 1 20 | 1 1 1 1 1 1 2 2 2 2 2 2 0 21 | 1 1 1 1 1 1 1 1 1 1 1 1 1 22 | 1 1 1 1 1 1 2 2 2 2 2 2 0 23 | 24 | -------------------------------------------------------------------------------- cmake/external/jsoncpp.cmake: -------------------------------------------------------------------------------- 1 | INCLUDE(ExternalProject) 2 | 3 | SET(JSONCPP_SOURCES_DIR ${THIRD_PARTY_PATH}/jsoncpp) 4 | SET(JSONCPP_INSTALL_DIR ${THIRD_PARTY_PATH}/install/jsoncpp) 5 | SET(JSONCPP_TARGET_VERSION "1.8.4") 6 | 7 | ExternalProject_Add(extern_jsoncpp 8 | GIT_REPOSITORY https://github.com/open-source-parsers/jsoncpp.git 9 | GIT_TAG 1.8.4 10 | PREFIX ${JSONCPP_SOURCES_DIR} 11 | CONFIGURE_COMMAND cd <SOURCE_DIR> && ${CMAKE_COMMAND} -DCMAKE_INSTALL_LIBDIR=lib -DCMAKE_INSTALL_PREFIX=${THIRD_PARTY_PATH} CMakeLists.txt 12 | BUILD_COMMAND cd <SOURCE_DIR> && make 13 | INSTALL_COMMAND cd <SOURCE_DIR> && make install 14 | UPDATE_COMMAND "" 15 | ) 16 | 17 | LIST(APPEND external_project_dependencies jsoncpp) 18 | -------------------------------------------------------------------------------- /tools/simnet/train/paddle/examples/bow_pointwise.json: -------------------------------------------------------------------------------- 1 | { 2 | "net": { 3 | "module_name": "bow", 4 | "class_name": "BOW", 5 | "emb_dim": 128, 6 | "bow_dim": 128 7 | }, 8 | "loss": { 9 | "module_name": "softmax_cross_entropy_loss", 10 | "class_name": "SoftmaxCrossEntropyLoss" 11 | }, 12 | "optimizer": { 13 | "class_name": "SGDOptimizer", 14 | "learning_rate" : 0.001 15 | }, 16 | "dict_size": 3, 17 | "task_mode": "pointwise", 18 | "train_file_path": "data/train_pointwise_data", 19 |
"test_file_path": "data/test_pointwise_data", 20 | "result_file_path": "result_bow_pointwise", 21 | "epoch_num": 10, 22 | "model_path": "models/bow_pointwise", 23 | "use_epoch": 0, 24 | "batch_size": 64, 25 | "num_threads": 4 26 | } 27 | -------------------------------------------------------------------------------- /cmake/external/jdk.cmake: -------------------------------------------------------------------------------- 1 | INCLUDE(ExternalProject) 2 | 3 | SET(JDK_SOURCES_DIR ${THIRD_PARTY_PATH}/jdk) 4 | SET(JDK_INSTALL_DIR ${THIRD_PARTY_PATH}/) 5 | SET(JDK_DOWNLOAD_DIR "${JDK_SOURCES_DIR}/src/") 6 | SET(JDK_URL "http://anyq.bj.bcebos.com/jdk-8u171-linux-x64.tar.gz") 7 | ExternalProject_Add( 8 | extern_jdk 9 | ${EXTERNAL_PROJECT_LOG_ARGS} 10 | DOWNLOAD_DIR ${JDK_DOWNLOAD_DIR} 11 | DOWNLOAD_COMMAND wget --no-check-certificate ${JDK_URL} -c 12 | DOWNLOAD_NO_PROGRESS 1 13 | PREFIX ${JDK_SOURCES_DIR} 14 | BUILD_COMMAND cd ${JDK_DOWNLOAD_DIR}/ && tar -zxvf jdk-8u171-linux-x64.tar.gz 15 | UPDATE_COMMAND "" 16 | CONFIGURE_COMMAND "" 17 | INSTALL_COMMAND cd ${JDK_DOWNLOAD_DIR}/ && cp -rf jdk1.8.0_171/ ${JDK_INSTALL_DIR}/jdk-1.8 18 | BUILD_IN_SOURCE 1 19 | ) 20 | 21 | -------------------------------------------------------------------------------- /cmake/proto_build.cmake: -------------------------------------------------------------------------------- 1 | INCLUDE(ExternalProject) 2 | 3 | ExternalProject_Add( 4 | proto_build 5 | ${EXTERNAL_PROJECT_LOG_ARGS} 6 | DEPENDS extern_protobuf 7 | DOWNLOAD_DIR "" 8 | DOWNLOAD_COMMAND "" 9 | DOWNLOAD_NO_PROGRESS 1 10 | PREFIX "" 11 | BUILD_COMMAND "" 12 | UPDATE_COMMAND "" 13 | CONFIGURE_COMMAND "" 14 | INSTALL_COMMAND "" 15 | BUILD_IN_SOURCE 1 16 | 17 | ) 18 | 19 | add_custom_command(TARGET proto_build POST_BUILD 20 | COMMAND ${PROTOC_BIN} --cpp_out=${PROTO_PATH} --proto_path=${PROTO_PATH} ${PROTO_PATH}/*.proto 21 | COMMAND mkdir -p ${PROTO_INC} ${PROTO_SRC} 22 | COMMAND mv ${PROTO_PATH}/*.h ${PROTO_INC} 23 | COMMAND mv ${PROTO_PATH}/*.cc ${PROTO_SRC} 24 | ) 25 | -------------------------------------------------------------------------------- /tools/simnet/train/paddle/examples/gru_pointwise.json: -------------------------------------------------------------------------------- 1 | { 2 | "net": { 3 | "module_name": "gru", 4 | "class_name": "GRU", 5 | "emb_dim": 128, 6 | "gru_dim": 128, 7 | "hidden_dim": 128 8 | }, 9 | "loss": { 10 | "module_name": "softmax_cross_entropy_loss", 11 | "class_name": "SoftmaxCrossEntropyLoss" 12 | }, 13 | "optimizer": { 14 | "class_name": "SGDOptimizer", 15 | "learning_rate" : 0.001 16 | }, 17 | "dict_size": 3, 18 | "task_mode": "pointwise", 19 | "train_file_path": "data/train_pointwise_data", 20 | "test_file_path": "data/test_pointwise_data", 21 | "result_file_path": "result_gru_pointwise", 22 | "epoch_num": 10, 23 | "model_path": "models/gru_pointwise", 24 | "use_epoch": 0, 25 | "batch_size": 64, 26 | "num_threads": 4 27 | } 28 | -------------------------------------------------------------------------------- /tools/simnet/train/paddle/examples/lstm_pointwise.json: -------------------------------------------------------------------------------- 1 | { 2 | "net": { 3 | "module_name": "lstm", 4 | "class_name": "LSTM", 5 | "emb_dim": 128, 6 | "lstm_dim": 128, 7 | "hidden_dim": 128 8 | }, 9 | "loss": { 10 | "module_name": "softmax_cross_entropy_loss", 11 | "class_name": "SoftmaxCrossEntropyLoss" 12 | }, 13 | "optimizer": { 14 | "class_name": "SGDOptimizer", 15 | "learning_rate" : 0.001 16 | }, 17 | "dict_size": 3, 18 | 
"task_mode": "pointwise", 19 | "train_file_path": "data/train_pointwise_data", 20 | "test_file_path": "data/test_pointwise_data", 21 | "result_file_path": "result_lstm_pointwise", 22 | "epoch_num": 10, 23 | "model_path": "models/lstm_pointwise", 24 | "use_epoch": 0, 25 | "batch_size": 64, 26 | "num_threads": 4 27 | } 28 | -------------------------------------------------------------------------------- /tools/simnet/train/paddle/examples/cnn_pointwise.json: -------------------------------------------------------------------------------- 1 | { 2 | "net": { 3 | "module_name": "cnn", 4 | "class_name": "CNN", 5 | "emb_dim": 128, 6 | "filter_size": 3, 7 | "num_filters": 256, 8 | "hidden_dim": 128 9 | }, 10 | "loss": { 11 | "module_name": "softmax_cross_entropy_loss", 12 | "class_name": "SoftmaxCrossEntropyLoss" 13 | }, 14 | "optimizer": { 15 | "class_name": "SGDOptimizer", 16 | "learning_rate" : 0.001 17 | }, 18 | "dict_size": 3, 19 | "task_mode": "pointwise", 20 | "train_file_path": "data/train_pointwise_data", 21 | "test_file_path": "data/test_pointwise_data", 22 | "result_file_path": "result_cnn_pointwise", 23 | "epoch_num": 10, 24 | "model_path": "models/cnn_pointwise", 25 | "use_epoch": 0, 26 | "batch_size": 64, 27 | "num_threads": 4 28 | } 29 | -------------------------------------------------------------------------------- /cmake/external/curl.cmake: -------------------------------------------------------------------------------- 1 | include(ExternalProject) 2 | 3 | SET(CURL_PROJECT "extern_curl") 4 | SET(CURL_URL "https://curl.haxx.se/download/curl-7.60.0.tar.gz") 5 | SET(CURL_SOURCES_DIR ${THIRD_PARTY_PATH}/curl) 6 | SET(CURL_DOWNLOAD_DIR "${CURL_SOURCES_DIR}/src/") 7 | 8 | ExternalProject_Add( 9 | ${CURL_PROJECT} 10 | ${EXTERNAL_PROJECT_LOG_ARGS} 11 | DOWNLOAD_DIR ${CURL_DOWNLOAD_DIR} 12 | DOWNLOAD_COMMAND wget --no-check-certificate ${CURL_URL} -c && tar -zxvf curl-7.60.0.tar.gz 13 | DOWNLOAD_NO_PROGRESS 1 14 | PREFIX ${CURL_SOURCES_DIR} 15 | CONFIGURE_COMMAND cd ${CURL_DOWNLOAD_DIR}/curl-7.60.0 && ./configure --prefix=${THIRD_PARTY_PATH} --without-ssl 16 | BUILD_COMMAND cd ${CURL_DOWNLOAD_DIR}/curl-7.60.0 && make 17 | INSTALL_COMMAND cd ${CURL_DOWNLOAD_DIR}/curl-7.60.0 && make install 18 | UPDATE_COMMAND "" 19 | ) 20 | -------------------------------------------------------------------------------- /cmake/external/eigen.cmake: -------------------------------------------------------------------------------- 1 | INCLUDE(ExternalProject) 2 | 3 | SET(EIGEN_SOURCES_DIR ${THIRD_PARTY_PATH}/eigen) 4 | SET(EIGEN_INSTALL_DIR ${THIRD_PARTY_PATH}/) 5 | set(EIGEN_DOWNLOAD_DIR ${EIGEN_SOURCES_DIR}/src/) 6 | ExternalProject_Add( 7 | extern_eigen 8 | ${EXTERNAL_PROJECT_LOG_ARGS} 9 | DOWNLOAD_DIR ${EIGEN_DOWNLOAD_DIR} 10 | DOWNLOAD_COMMAND git clone https://github.com/PX4/eigen.git && cd eigen && git checkout 3.3.4 11 | DOWNLOAD_NO_PROGRESS 1 12 | PREFIX ${EIGEN_SOURCES_DIR} 13 | BUILD_COMMAND "" 14 | UPDATE_COMMAND "" 15 | CONFIGURE_COMMAND "" 16 | INSTALL_COMMAND "" 17 | BUILD_IN_SOURCE 1 18 | ) 19 | 20 | add_custom_command(TARGET extern_eigen POST_BUILD 21 | COMMAND mkdir -p third_party/lib/ 22 | COMMAND mkdir -p third_party/include/ 23 | COMMAND cp -rf ${EIGEN_DOWNLOAD_DIR}/eigen/ ${EIGEN_INSTALL_DIR}/include/ 24 | ) 25 | -------------------------------------------------------------------------------- /tools/simnet/train/paddle/examples/bow_pairwise.json: -------------------------------------------------------------------------------- 1 | { 2 | "net": { 3 | "module_name": "bow", 4 | 
"class_name": "BOW", 5 | "emb_dim": 128, 6 | "bow_dim": 128, 7 | "hidden_dim": 128 8 | }, 9 | "loss": { 10 | "module_name": "hinge_loss", 11 | "class_name": "HingeLoss", 12 | "margin": 0.1 13 | }, 14 | "optimizer": { 15 | "class_name": "AdamOptimizer", 16 | "learning_rate": 0.001, 17 | "beta1": 0.9, 18 | "beta2": 0.999, 19 | "epsilon": 1e-08 20 | }, 21 | "dict_size": 3, 22 | "task_mode": "pairwise", 23 | "train_file_path": "data/train_pairwise_data", 24 | "test_file_path": "data/test_pairwise_data", 25 | "result_file_path": "result_bow_pairwise", 26 | "epoch_num": 10, 27 | "model_path": "models/bow_pairwise", 28 | "use_epoch": 0, 29 | "batch_size": 64, 30 | "num_threads": 4 31 | } 32 | -------------------------------------------------------------------------------- /tools/simnet/train/paddle/examples/gru_pairwise.json: -------------------------------------------------------------------------------- 1 | { 2 | "net": { 3 | "module_name": "gru", 4 | "class_name": "GRU", 5 | "emb_dim": 128, 6 | "gru_dim": 128, 7 | "hidden_dim": 128 8 | }, 9 | "loss": { 10 | "module_name": "hinge_loss", 11 | "class_name": "HingeLoss", 12 | "margin": 0.1 13 | }, 14 | "optimizer": { 15 | "class_name": "AdamOptimizer", 16 | "learning_rate": 0.001, 17 | "beta1": 0.9, 18 | "beta2": 0.999, 19 | "epsilon": 1e-08 20 | }, 21 | "dict_size": 3, 22 | "task_mode": "pairwise", 23 | "train_file_path": "data/train_pairwise_data", 24 | "test_file_path": "data/test_pairwise_data", 25 | "result_file_path": "result_gru_pairwise", 26 | "epoch_num": 10, 27 | "model_path": "models/gru_pairwise", 28 | "use_epoch": 0, 29 | "batch_size": 64, 30 | "num_threads": 4 31 | } 32 | -------------------------------------------------------------------------------- /tools/simnet/train/tf/data/test_pointwise_data: -------------------------------------------------------------------------------- 1 | 1 1 1 1 1 1 1 1 1 1 1 1 1 2 | 1 1 1 1 1 1 2 2 2 2 2 2 0 3 | 1 1 1 1 1 1 1 1 1 1 1 1 1 4 | 1 1 1 1 1 1 2 2 2 2 2 2 0 5 | 1 1 1 1 1 1 1 1 1 1 1 1 1 6 | 1 1 1 1 1 1 2 2 2 2 2 2 0 7 | 1 1 1 1 1 1 1 1 1 1 1 1 1 8 | 1 1 1 1 1 1 2 2 2 2 2 2 0 9 | 1 1 1 1 1 1 1 1 1 1 1 1 1 10 | 1 1 1 1 1 1 2 2 2 2 2 2 0 11 | 1 1 1 1 1 1 1 1 1 1 1 1 1 12 | 1 1 1 1 1 1 2 2 2 2 2 2 0 13 | 1 1 1 1 1 1 1 1 1 1 1 1 1 14 | 1 1 1 1 1 1 2 2 2 2 2 2 0 15 | 1 1 1 1 1 1 1 1 1 1 1 1 1 16 | 1 1 1 1 1 1 2 2 2 2 2 2 0 17 | 1 1 1 1 1 1 1 1 1 1 1 1 1 18 | 1 1 1 1 1 1 2 2 2 2 2 2 0 19 | 1 1 1 1 1 1 1 1 1 1 1 1 1 20 | 1 1 1 1 1 1 2 2 2 2 2 2 0 21 | 1 1 1 1 1 1 1 1 1 1 1 1 1 22 | 1 1 1 1 1 1 2 2 2 2 2 2 0 23 | 1 1 1 1 1 1 1 1 1 1 1 1 1 24 | 1 1 1 1 1 1 2 2 2 2 2 2 0 25 | 1 1 1 1 1 1 1 1 1 1 1 1 1 26 | 1 1 1 1 1 1 2 2 2 2 2 2 0 27 | 1 1 1 1 1 1 1 1 1 1 1 1 1 28 | 1 1 1 1 1 1 2 2 2 2 2 2 0 29 | 1 1 1 1 1 1 1 1 1 1 1 1 1 30 | 1 1 1 1 1 1 2 2 2 2 2 2 0 31 | -------------------------------------------------------------------------------- /tools/simnet/train/paddle/examples/lstm_pairwise.json: -------------------------------------------------------------------------------- 1 | { 2 | "net": { 3 | "module_name": "lstm", 4 | "class_name": "LSTM", 5 | "emb_dim": 128, 6 | "lstm_dim": 128, 7 | "hidden_dim": 128 8 | }, 9 | "loss": { 10 | "module_name": "hinge_loss", 11 | "class_name": "HingeLoss", 12 | "margin": 0.1 13 | }, 14 | "optimizer": { 15 | "class_name": "AdamOptimizer", 16 | "learning_rate": 0.001, 17 | "beta1": 0.9, 18 | "beta2": 0.999, 19 | "epsilon": 1e-08 20 | }, 21 | "dict_size": 3, 22 | "task_mode": "pairwise", 23 | "train_file_path": "data/train_pairwise_data", 24 | "test_file_path": 
"data/test_pairwise_data", 25 | "result_file_path": "result_lstm_pairwise", 26 | "epoch_num": 10, 27 | "model_path": "models/lstm_pairwise", 28 | "use_epoch": 0, 29 | "batch_size": 64, 30 | "num_threads": 4 31 | } 32 | -------------------------------------------------------------------------------- /cmake/external/openssl.cmake: -------------------------------------------------------------------------------- 1 | INCLUDE(ExternalProject) 2 | 3 | SET(OPENSSL_SOURCES_DIR ${THIRD_PARTY_PATH}/openssl) 4 | SET(OPENSSL_INSTALL_DIR ${THIRD_PARTY_PATH}/) 5 | set(OPENSSL_DOWNLOAD_DIR "${OPENSSL_SOURCES_DIR}/src/") 6 | ExternalProject_Add( 7 | extern_openssl 8 | ${EXTERNAL_PROJECT_LOG_ARGS} 9 | DOWNLOAD_DIR ${OPENSSL_DOWNLOAD_DIR} 10 | DOWNLOAD_COMMAND git clone https://github.com/openssl/openssl.git && cd openssl && git checkout OpenSSL_1_1_0 11 | DOWNLOAD_NO_PROGRESS 1 12 | PREFIX ${OPENSSL_SOURCES_DIR} 13 | BUILD_COMMAND "" 14 | UPDATE_COMMAND "" 15 | CONFIGURE_COMMAND "" 16 | INSTALL_COMMAND cd ${OPENSSL_DOWNLOAD_DIR}/openssl/ && pwd 17 | && ./config --prefix=${OPENSSL_INSTALL_DIR} --libdir=lib && pwd 18 | && make -j32 19 | && make install 20 | BUILD_IN_SOURCE 1 21 | ) 22 | 23 | -------------------------------------------------------------------------------- /tools/ltr/xgboost/train_parameter.conf: -------------------------------------------------------------------------------- 1 | # General Parameters, see comment for each definition 2 | # can be gbtree or gblinear 3 | booster = gbtree 4 | # choose logistic regression loss function for binary classification 5 | objective = reg:logistic 6 | 7 | # Tree Booster Parameters 8 | # step size shrinkage 9 | eta = 0.6 10 | # minimum loss reduction required to make a further partition 11 | gamma = 1.0 12 | # minimum sum of instance weight(hessian) needed in a child 13 | min_child_weight = 1 14 | # maximum depth of a tree 15 | max_depth = 20 16 | 17 | # Task Parameters 18 | # the number of round to do boosting 19 | num_round = 10 20 | # 0 means do not save any model except the final round model 21 | save_period = 1 22 | # The path of training data 23 | data = "feature.dat" 24 | # The path of validation data, used to monitor training process, here [test] sets name of the validation set 25 | eval[test] = "feature.test" 26 | # The path of test data 27 | test:data = "feature.test" 28 | -------------------------------------------------------------------------------- /tools/simnet/train/paddle/examples/cnn_pairwise.json: -------------------------------------------------------------------------------- 1 | { 2 | "net": { 3 | "module_name": "cnn", 4 | "class_name": "CNN", 5 | "emb_dim": 128, 6 | "filter_size": 3, 7 | "num_filters": 256, 8 | "hidden_dim": 128 9 | }, 10 | "loss": { 11 | "module_name": "hinge_loss", 12 | "class_name": "HingeLoss", 13 | "margin": 0.1 14 | }, 15 | "optimizer": { 16 | "class_name": "AdamOptimizer", 17 | "learning_rate": 0.001, 18 | "beta1": 0.9, 19 | "beta2": 0.999, 20 | "epsilon": 1e-08 21 | }, 22 | "dict_size": 3, 23 | "task_mode": "pairwise", 24 | "train_file_path": "data/train_pairwise_data", 25 | "test_file_path": "data/test_pairwise_data", 26 | "result_file_path": "result_cnn_pairwise", 27 | "epoch_num": 10, 28 | "model_path": "models/cnn_pairwise", 29 | "use_epoch": 0, 30 | "batch_size": 64, 31 | "num_threads": 4 32 | } 33 | -------------------------------------------------------------------------------- /cmake/external/boost.cmake: -------------------------------------------------------------------------------- 1 | 
include(ExternalProject) 2 | 3 | SET(BOOST_PROJECT "extern_boost") 4 | SET(BOOST_VER "1.41.0") 5 | SET(BOOST_URL "https://jaist.dl.sourceforge.net/project/boost/boost/1.41.0/boost_1_41_0.tar.gz") 6 | SET(BOOST_SOURCES_DIR ${THIRD_PARTY_PATH}/boost) 7 | SET(BOOST_DOWNLOAD_DIR "${BOOST_SOURCES_DIR}/src/") 8 | 9 | ExternalProject_Add( 10 | ${BOOST_PROJECT} 11 | ${EXTERNAL_PROJECT_LOG_ARGS} 12 | DOWNLOAD_DIR ${BOOST_DOWNLOAD_DIR} 13 | DOWNLOAD_COMMAND wget --no-check-certificate ${BOOST_URL} -c && tar -zxvf boost_1_41_0.tar.gz 14 | DOWNLOAD_NO_PROGRESS 1 15 | PREFIX ${BOOST_SOURCES_DIR} 16 | CONFIGURE_COMMAND "" 17 | BUILD_COMMAND "" 18 | INSTALL_COMMAND "" 19 | UPDATE_COMMAND "" 20 | ) 21 | 22 | add_custom_command(TARGET extern_boost POST_BUILD 23 | COMMAND mkdir -p ${THIRD_PARTY_PATH}/include/ 24 | COMMAND cp -r ${BOOST_DOWNLOAD_DIR}/boost_1_41_0/boost ${THIRD_PARTY_PATH}/include/ 25 | ) 26 | -------------------------------------------------------------------------------- tools/solr/README.md: -------------------------------------------------------------------------------- 1 | ## AnyQ Solr one-click start 2 | 3 | sh solr/anyq_solr.sh solr/sample_faq 4 | 5 | Requirements: 6 | - JDK 1.8 or above, python2.7 7 | - the AnyQ customized solr-4.10.3 build 8 | 9 | ### solr_deply.sh interface 10 | ↓↓**Start the Solr service**↓↓ 11 | ``` 12 | sh solr_deply.sh start solr_home solr_port 13 | ``` 14 | 15 | ↓↓**Stop the Solr service**↓↓ 16 | ``` 17 | sh solr_deply.sh stop solr_home solr_port 18 | ``` 19 | 20 | ### solr_tools.py interface 21 | ↓↓**Add an engine**↓↓ 22 | ``` 23 | add_engine(host, enginename, port=8983, shard=1, replica=1, maxshardpernode=5, conf='myconf') 24 | ``` 25 | 26 | ↓↓**Delete an engine**↓↓ 27 | ``` 28 | delete_engine(host, enginename, port=8983) 29 | ``` 30 | 31 | ↓↓**Set the engine schema**↓↓ 32 | ``` 33 | set_engine_schema(host, enginename, schema_config, port=8983) 34 | schema_config can be a JSON file path or a JSON list 35 | ``` 36 | 37 | ↓↓**Upload documents**↓↓ 38 | ``` 39 | upload_documents(host, enginename, port=8983, documents="", num_thread=1) 40 | ``` 41 | 42 | ↓↓**Clear all documents**↓↓ 43 | ``` 44 | clear_documents(host, enginename, port=8983) 45 | ``` 46 | 47 | ### solr_tools.py command line 48 | ``` 49 | Show the command-line usage: 50 | python solr_tools.py -help 51 | ``` 52 | -------------------------------------------------------------------------------- cmake/external/bazel.cmake: -------------------------------------------------------------------------------- 1 | INCLUDE(ExternalProject) 2 | 3 | SET(BAZEL_SOURCES_DIR ${THIRD_PARTY_PATH}/bazel) 4 | SET(BAZEL_INSTALL_DIR ${THIRD_PARTY_PATH}/) 5 | SET(BAZEL_DOWNLOAD_DIR "${BAZEL_SOURCES_DIR}/src/") 6 | SET(BAZEL_URL "https://github.com/bazelbuild/bazel/releases/download/0.10.0/bazel-0.10.0-dist.zip") 7 | ExternalProject_Add( 8 | extern_bazel 9 | ${EXTERNAL_PROJECT_LOG_ARGS} 10 | DEPENDS extern_jdk 11 | DOWNLOAD_DIR ${BAZEL_DOWNLOAD_DIR} 12 | DOWNLOAD_COMMAND wget --no-check-certificate ${BAZEL_URL} -c -O bazel-0.10.0-dist.zip && mkdir -p bazel_build 13 | && unzip bazel-0.10.0-dist.zip -d bazel_build 14 | DOWNLOAD_NO_PROGRESS 1 15 | PREFIX ${BAZEL_SOURCES_DIR} 16 | BUILD_COMMAND export JAVA_HOME=${THIRD_PARTY_PATH}/jdk-1.8 && cd ${BAZEL_DOWNLOAD_DIR}/bazel_build && bash compile.sh 17 | UPDATE_COMMAND "" 18 | CONFIGURE_COMMAND "" 19 | INSTALL_COMMAND mkdir -p ${BAZEL_INSTALL_DIR}/bin 20 | && cp ${BAZEL_DOWNLOAD_DIR}/bazel_build/output/bazel ${BAZEL_INSTALL_DIR}/bin 21 | BUILD_IN_SOURCE 1 22 | ) 23 | 24 | -------------------------------------------------------------------------------- tools/simnet/README.md: -------------------------------------------------------------------------------- 1 |
anyq/tools/simnet 2 | 3 | Overview: 4 | 1) Provides training and prediction pipelines for semantic similarity models; 5 | 2) Supports model training on both the paddle and tensorflow frameworks; 6 | 3) Provides neural network configurations for pointwise, pairwise, and other approaches; 7 | 8 | Usage: 9 | 1) Quick start: specify the platform to run on, as follows: 10 | * train with paddle: sh train.sh paddle 11 | * predict with paddle: sh predict.sh paddle 12 | * train with tensorflow: sh train.sh tensorflow 13 | * predict with tensorflow: sh predict.sh tensorflow 14 | 15 | 2) Custom training: 16 | * paddle training directory: anyq/tools/simnet/train/paddle 17 | * run_train.sh contains the training flow: 18 | set the training config in_conf_file_path; for the training parameters see anyq/tools/simnet/train/paddle/examples/cnn_pointwise.json 19 | * run_infer.sh contains the prediction flow 20 | 21 | * tensorflow training directory: anyq/tools/simnet/train/tf 22 | * run_train.sh contains the training flow: 23 | set the training config in_task_conf; for the training parameters see anyq/tools/simnet/train/tf/examples/cnn-pointwise.json 24 | * run_infer.sh contains the prediction flow 25 | 26 | Notes: 27 | install the paddle and tensorflow packages locally before use; 28 | 29 | Questions: 30 | any issues and bug reports are welcome 31 | -------------------------------------------------------------------------------- cmake/external/gtest.cmake: -------------------------------------------------------------------------------- 1 | INCLUDE(ExternalProject) 2 | 3 | SET(GTEST_SOURCES_DIR ${THIRD_PARTY_PATH}/gtest) 4 | SET(GTEST_INSTALL_DIR ${THIRD_PARTY_PATH}/) 5 | 6 | ExternalProject_Add( 7 | extern_gtest 8 | ${EXTERNAL_PROJECT_LOG_ARGS} 9 | DEPENDS ${GTEST_DEPENDS} 10 | GIT_REPOSITORY "https://github.com/google/googletest.git" 11 | GIT_TAG "release-1.8.0" 12 | PREFIX ${GTEST_SOURCES_DIR} 13 | UPDATE_COMMAND "" 14 | CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} 15 | -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} 16 | -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS} 17 | -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} 18 | -DCMAKE_INSTALL_PREFIX=${GTEST_INSTALL_DIR} 19 | -DCMAKE_POSITION_INDEPENDENT_CODE=ON 20 | -DBUILD_GMOCK=ON 21 | -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} 22 | ${EXTERNAL_OPTIONAL_ARGS} 23 | CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${GTEST_INSTALL_DIR} 24 | -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON 25 | -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} 26 | ) 27 | 28 | -------------------------------------------------------------------------------- cmake/external/lac.cmake: -------------------------------------------------------------------------------- 1 | INCLUDE(ExternalProject) 2 | 3 | SET(LAC_SOURCES_DIR ${THIRD_PARTY_PATH}/lac) 4 | SET(LAC_INSTALL_DIR ${THIRD_PARTY_PATH}/) 5 | SET(LAC_DOWNLOAD_DIR ${LAC_SOURCES_DIR}/src/) 6 | ExternalProject_Add( 7 | extern_lac 8 | ${EXTERNAL_PROJECT_LOG_ARGS} 9 | DEPENDS extern_paddle 10 | DOWNLOAD_DIR ${LAC_DOWNLOAD_DIR} 11 | DOWNLOAD_COMMAND git clone https://github.com/baidu/lac.git && cd lac && git checkout v1.0.0 12 | DOWNLOAD_NO_PROGRESS 1 13 | PREFIX ${LAC_SOURCES_DIR} 14 | BUILD_COMMAND cd ${LAC_DOWNLOAD_DIR}/lac && cmake -DPADDLE_ROOT=${THIRD_PARTY_PATH}/install/paddle/fluid_install_dir/ ./ 15 | UPDATE_COMMAND "" 16 | CONFIGURE_COMMAND "" 17 | INSTALL_COMMAND cd ${LAC_DOWNLOAD_DIR}/lac && make -j8 && make install 18 | BUILD_IN_SOURCE 1 19 | ) 20 | 21 | add_custom_command(TARGET extern_lac POST_BUILD 22 | COMMAND mkdir -p ${LAC_INSTALL_DIR}/include/ 23 | COMMAND mkdir -p ${LAC_INSTALL_DIR}/lib/ 24 | COMMAND cp -r ${LAC_DOWNLOAD_DIR}/lac/output/include/* ${LAC_INSTALL_DIR}/include/ 25 | COMMAND cp -r ${LAC_DOWNLOAD_DIR}/lac/output/lib/* ${LAC_INSTALL_DIR}/lib/ 26 | ) 27 | -------------------------------------------------------------------------------- tools/simnet/train/paddle/losses/log_loss.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | # Copyright (c) 2018 Baidu, Inc.
All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import layers.paddle_layers as layers 18 | 19 | 20 | class LogLoss(object): 21 | """ 22 | Log Loss Calculate 23 | """ 24 | def __init__(self, conf_dict): 25 | """ 26 | initialize 27 | """ 28 | pass 29 | 30 | def compute(self, pos, neg): 31 | """ 32 | compute loss 33 | """ 34 | sigmoid = layers.SigmoidLayer() 35 | reduce_mean = layers.ReduceMeanLayer() 36 | loss = reduce_mean.ops(sigmoid.ops(neg - pos)) 37 | return loss 38 | -------------------------------------------------------------------------------- cmake/external/xgboost.cmake: -------------------------------------------------------------------------------- 1 | INCLUDE(ExternalProject) 2 | 3 | SET(XGBOOST_SOURCES_DIR ${THIRD_PARTY_PATH}/xgboost) 4 | SET(XGBOOST_INSTALL_DIR ${XGBOOST_SOURCES_DIR}/src/xgboost) 5 | 6 | ExternalProject_Add( 7 | extern_xgboost 8 | ${EXTERNAL_PROJECT_LOG_ARGS} 9 | DOWNLOAD_DIR ${XGBOOST_SOURCES_DIR}/src/ 10 | DOWNLOAD_COMMAND git clone --recursive https://github.com/dmlc/xgboost.git 11 | DOWNLOAD_NO_PROGRESS 1 12 | PREFIX ${XGBOOST_SOURCES_DIR} 13 | BUILD_COMMAND "" 14 | UPDATE_COMMAND "" 15 | CONFIGURE_COMMAND "" 16 | INSTALL_COMMAND cd ${XGBOOST_INSTALL_DIR} && make -j4 17 | BUILD_IN_SOURCE 1 18 | ) 19 | 20 | add_custom_command(TARGET extern_xgboost POST_BUILD 21 | COMMAND mkdir -p third_party/lib/ 22 | COMMAND mkdir -p third_party/include/ 23 | COMMAND cp -r ${XGBOOST_INSTALL_DIR}/include/* third_party/include/ 24 | COMMAND cp -r ${XGBOOST_INSTALL_DIR}/lib/* third_party/lib 25 | COMMAND cp ${XGBOOST_INSTALL_DIR}/rabit/lib/librabit.a ${XGBOOST_INSTALL_DIR}/dmlc-core/libdmlc.a third_party/lib/ 26 | COMMAND cp -r ${XGBOOST_INSTALL_DIR}/src third_party/ 27 | COMMAND cp -r ${XGBOOST_INSTALL_DIR}/dmlc-core/include/* ${XGBOOST_INSTALL_DIR}/rabit/include/* third_party/include/ 28 | ) 29 | -------------------------------------------------------------------------------- /tools/simnet/train/paddle/examples/mmdnn-pointwise.json: -------------------------------------------------------------------------------- 1 | { 2 | "net": { 3 | "module_name": "mm_dnn", 4 | "class_name": "MMDNN", 5 | "embedding_dim": 128, 6 | "num_filters": 256, 7 | "lstm_dim": 128, 8 | "hidden_size": 128, 9 | "window_size_left": 3, 10 | "window_size_right": 3, 11 | "dpool_size_left": 2, 12 | "dpool_size_right": 2 13 | }, 14 | "loss": { 15 | "module_name": "softmax_cross_entropy_loss", 16 | "class_name": "SoftmaxCrossEntropyLoss" 17 | }, 18 | "optimizer": { 19 | "class_name": "AdamOptimizer", 20 | "learning_rate": 0.001, 21 | "beta1": 0.9, 22 | "beta2": 0.999, 23 | "epsilon": 1e-08 24 | }, 25 | "use_cuda": 1, 26 | "dict_size": 3, 27 | "max_len_left": 32, 28 | "max_len_right": 32, 29 | "n_class": 2, 30 | "task_mode": "pointwise", 31 | "match_mask" : 1, 32 | "train_file_path": "data/train_pointwise_data", 33 | "test_file_path": "data/test_pointwise_data", 34 | "result_file_path": "result_mm_dnn_pointwise", 35 | "epoch_num": 1, 36 | "model_path": "models/mm_dnn_pointwise", 37 | "use_epoch": 0, 38 | "batch_size": 64, 39 | "num_threads": 6 40 | } 41 | -------------------------------------------------------------------------------- README.EN.md: -------------------------------------------------------------------------------- 1 | # AnyQ 2 | 3 | AnyQ (ANswer Your Questions) is a configurable and pluggable FAQ-based question answering framework. SimNet, a semantic matching framework developed by Baidu-NLP, also ships with AnyQ. 4 | 5 | In our FAQ-based QA framework, which is designed to be configurable and pluggable, all processes and functions are plugins. Developers can easily design their own processes and add them to the framework, so they can quickly build a QA system for their own application. 6 | 7 | SimNet, first designed in 2013 by Baidu-NLP, is a flexible semantic matching framework that is widely used in many applications at Baidu. SimNet provides the neural network structures BOW, CNN, RNN, and MM-DNN. Meanwhile, we have implemented more state-of-the-art structures such as MatchPyramid, MV-LSTM, and K-NRM. SimNet has a unified interface, implemented with PaddleFluid and TensorFlow. Models trained using SimNet can easily be added into the AnyQ framework, augmenting its semantic matching ability. 8 | 9 | The overall framework of AnyQ is as follows: 10 |
11 | ![AnyQ-Framework](./docs/images/AnyQ-Framework.png) 12 |
13 | 14 | ## Acknowledgments & Statements 15 | 16 | This work is supported by the National Key R&D Program of China (No. **2018YFB1004300**). 17 | -------------------------------------------------------------------------------- demo/run_server.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include <glog/logging.h> 16 | #include "server/http_server.h" 17 | #include "common/utils.h" 18 | #include "common/plugin_header.h" 19 | 20 | int main(int argc, char* argv[]) { 21 | google::InitGoogleLogging(argv[0]); 22 | FLAGS_stderrthreshold = google::INFO; 23 | anyq::HttpServer server; 24 | std::string anyq_brpc_conf = "./example/conf/anyq_brpc.conf"; 25 | if (server.init(anyq_brpc_conf) != 0) { 26 | FATAL_LOG("server init failed"); 27 | return -1; 28 | } 29 | 30 | if (server.always_run() != 0) { 31 | FATAL_LOG("server run failed"); 32 | return -1; 33 | } 34 | return 0; 35 | } 36 | -------------------------------------------------------------------------------- cmake/external/brpc.cmake: -------------------------------------------------------------------------------- 1 | INCLUDE(ExternalProject) 2 | 3 | SET(BRPC_SOURCES_DIR ${THIRD_PARTY_PATH}/brpc) 4 | SET(BRPC_INSTALL_DIR ${THIRD_PARTY_PATH}/) 5 | SET(BRPC_DOWNLOAD_DIR ${BRPC_SOURCES_DIR}/src/) 6 | #set(ENV{PATH} ${THIRD_PARTY_PATH}/bin:$ENV{PATH} ) 7 | ExternalProject_Add( 8 | extern_brpc 9 | ${EXTERNAL_PROJECT_LOG_ARGS} 10 | DEPENDS extern_protobuf extern_glog extern_gflags extern_leveldb extern_openssl 11 | DOWNLOAD_DIR ${BRPC_DOWNLOAD_DIR} 12 | DOWNLOAD_COMMAND git clone https://github.com/brpc/brpc.git && cd brpc && git checkout v0.9.0 13 | DOWNLOAD_NO_PROGRESS 1 14 | PREFIX ${BRPC_SOURCES_DIR} 15 | BUILD_COMMAND "" 16 | UPDATE_COMMAND "" 17 | CONFIGURE_COMMAND "" 18 | INSTALL_COMMAND cd ${BRPC_DOWNLOAD_DIR}/brpc/ && export PATH=${THIRD_PARTY_PATH}/bin:$ENV{PATH} 19 | && bash config_brpc.sh --headers=${BRPC_INSTALL_DIR}/include --libs=${BRPC_INSTALL_DIR}/lib --with-glog 20 | && make 21 | BUILD_IN_SOURCE 1 22 | 23 | ) 24 | 25 | add_custom_command(TARGET extern_brpc POST_BUILD 26 | COMMAND mkdir -p third_party/lib/ 27 | COMMAND mkdir -p third_party/include/ 28 | COMMAND cp -r ${BRPC_DOWNLOAD_DIR}/brpc/output/* ${BRPC_INSTALL_DIR}/ 29 | ) 30 | -------------------------------------------------------------------------------- cmake/external/leveldb.cmake: -------------------------------------------------------------------------------- 1 | INCLUDE(ExternalProject) 2 | 3 | SET(LEVELDB_SOURCES_DIR ${THIRD_PARTY_PATH}/leveldb) 4 | SET(LEVELDB_INSTALL_DIR ${THIRD_PARTY_PATH}/) 5 | set(LEVELDB_DOWNLOAD_DIR ${LEVELDB_SOURCES_DIR}/src/) 6 | ExternalProject_Add( 7 | extern_leveldb 8 | ${EXTERNAL_PROJECT_LOG_ARGS} 9 | DOWNLOAD_DIR ${LEVELDB_DOWNLOAD_DIR} 10 | DOWNLOAD_COMMAND git clone https://github.com/google/leveldb.git && cd leveldb && git checkout v1.20 11 |
DOWNLOAD_NO_PROGRESS 1 12 | PREFIX ${LEVELDB_SOURCES_DIR} 13 | BUILD_COMMAND "" 14 | UPDATE_COMMAND "" 15 | CONFIGURE_COMMAND "" 16 | INSTALL_COMMAND cd ${LEVELDB_DOWNLOAD_DIR}/leveldb/ 17 | && ./build_detect_platform build_config.mk ./ 18 | && make 19 | BUILD_IN_SOURCE 1 20 | ) 21 | 22 | add_custom_command(TARGET extern_leveldb POST_BUILD 23 | COMMAND mkdir -p third_party/lib/ 24 | COMMAND mkdir -p third_party/include/ 25 | COMMAND cp -r ${LEVELDB_DOWNLOAD_DIR}/leveldb/include/* ${LEVELDB_INSTALL_DIR}/include/ 26 | COMMAND cp -r ${LEVELDB_DOWNLOAD_DIR}/leveldb/out-static/lib* ${LEVELDB_INSTALL_DIR}/lib/ 27 | COMMAND cp -r ${LEVELDB_DOWNLOAD_DIR}/leveldb/out-shared/lib* ${LEVELDB_INSTALL_DIR}/lib/ 28 | ) 29 | -------------------------------------------------------------------------------- src/analysis/method_interface.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include "analysis/method_interface.h" 16 | 17 | namespace anyq { 18 | 19 | int AnalysisMethodInterface::single_process(AnalysisItem& analysis_item) { 20 | return 0; 21 | } 22 | 23 | int AnalysisMethodInterface::method_process(AnalysisResult& analysis_result) { 24 | std::vector<AnalysisItem>& analysis = analysis_result.analysis; 25 | for (size_t j = 0; j < analysis.size(); j++) { 26 | AnalysisItem& analysis_item = analysis[j]; 27 | if (single_process(analysis_item) != 0) { 28 | FATAL_LOG("single_process err"); 29 | return -1; 30 | } 31 | } 32 | return 0; 33 | } 34 | 35 | } // namespace anyq 36 | -------------------------------------------------------------------------------- tools/simnet/train/paddle/losses/softmax_cross_entropy_loss.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | # Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License.
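# Note: the paddle loss classes share one interface: __init__(conf_dict) plus
# a compute(...) method (compute(pos, neg) for pairwise losses such as LogLoss
# and HingeLoss, compute(input, label) for this pointwise loss), which is what
# lets the trainer swap losses via the "loss" block of the JSON configs.
# Hedged usage sketch (assumed tensor shapes, for illustration only):
#   loss = SoftmaxCrossEntropyLoss(conf_dict).compute(softmax_output, label)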
16 | 17 | import layers.paddle_layers as layers 18 | import paddle.fluid as fluid 19 | 20 | 21 | class SoftmaxCrossEntropyLoss(object): 22 | """ 23 | Softmax with Cross Entropy Loss Calculate 24 | """ 25 | def __init__(self, conf_dict): 26 | """ 27 | initialize 28 | """ 29 | pass 30 | 31 | def compute(self, input, label): 32 | """ 33 | compute loss 34 | """ 35 | reduce_mean = layers.ReduceMeanLayer() 36 | cost = fluid.layers.cross_entropy(input=input, label=label) 37 | avg_cost = reduce_mean.ops(cost) 38 | return avg_cost 39 | -------------------------------------------------------------------------------- cmake/external/glog.cmake: -------------------------------------------------------------------------------- 1 | include(ExternalProject) 2 | set(GLOG_PROJECT "extern_glog") 3 | set(GLOG_VER "v0.3.5") 4 | set(GLOG_SOURCES_DIR ${THIRD_PARTY_PATH}/glog) 5 | set(GLOG_DOWNLOAD_DIR "${GLOG_SOURCES_DIR}/src/") 6 | set(GLOG_INSTALL_DIR ${THIRD_PARTY_PATH}) 7 | 8 | ExternalProject_Add( 9 | ${GLOG_PROJECT} 10 | ${EXTERNAL_PROJECT_LOG_ARGS} 11 | GIT_REPOSITORY "https://github.com/google/glog.git" 12 | GIT_TAG "v0.3.5" 13 | PREFIX ${GLOG_SOURCES_DIR} 14 | UPDATE_COMMAND "" 15 | CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} 16 | -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} 17 | -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS} 18 | -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} 19 | -DCMAKE_PREFIX_PATH=${GLOG_SOURCES_DIR} 20 | -DCMAKE_INSTALL_PREFIX=${GLOG_INSTALL_DIR} 21 | -DCMAKE_POSITION_INDEPENDENT_CODE=ON 22 | -DWITH_GFLAGS=OFF 23 | ${EXTERNAL_OPTIONAL_ARGS} 24 | CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${GLOG_INSTALL_DIR} 25 | -DBUILD_SHARED_LIBS:BOOL=ON 26 | -DBUILD_STATIC_LIBS:BOOL=ON 27 | -DWITH_GFLAGS:BOOL=OFF 28 | -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON 29 | ) 30 | -------------------------------------------------------------------------------- /tools/simnet/train/tf/examples/knrm-pointwise.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_data":{ 3 | "train_file":"data/convert_train_pointwise_data", 4 | "data_size":400, 5 | "left_slots":[["left",32]], 6 | "right_slots":[["right",32]] 7 | }, 8 | 9 | "model":{ 10 | "net_py":"./nets/knrm", 11 | "net_class":"KNRM", 12 | "vocabulary_size":3, 13 | "embedding_dim":128, 14 | "kernel_num":3, 15 | "lamb":0.5, 16 | "loss_py":"./losses/simnet_loss", 17 | "loss_class":"SoftmaxWithLoss" 18 | }, 19 | 20 | "global":{ 21 | "training_mode":"pointwise", 22 | "n_class":2, 23 | "max_len_left":32, 24 | "max_len_right":32 25 | }, 26 | 27 | "setting":{ 28 | "batch_size":64, 29 | "num_epochs":1, 30 | "thread_num":6, 31 | "print_iter":100, 32 | "model_path":"model/pointwise", 33 | "model_prefix":"knrm", 34 | "learning_rate":0.001, 35 | "shuffle":1 36 | }, 37 | 38 | "test_data":{ 39 | "test_file":"data/convert_test_pointwise_data", 40 | "test_model_file":"model/pointwise/knrm.epoch1", 41 | "test_result":"result_knrm_pointwise" 42 | }, 43 | 44 | "freeze":{ 45 | "save_path": "model/pointwise/knrm.epoch1", 46 | "freeze_path": "tf.graph" 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /tools/simnet/train/tf/examples/bow-pointwise.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_data":{ 3 | "train_file": "data/convert_train_pointwise_data", 4 | "data_size": 62, 5 | "left_slots": [["left", 32]], 6 | "right_slots": [["right", 32]] 7 | }, 8 | 9 | "model":{ 10 | "net_py": "./nets/bow", 11 | "net_class": "BOW", 12 | "vocabulary_size": 3,
13 | "embedding_dim": 128, 14 | "bow_size": 128, 15 | "hidden_size": 128, 16 | "loss_py": "./losses/simnet_loss", 17 | "loss_class": "SoftmaxWithLoss" 18 | }, 19 | 20 | "global": { 21 | "training_mode": "pointwise", 22 | "n_class": 2, 23 | "max_len_left": 32, 24 | "max_len_right": 32 25 | }, 26 | 27 | "setting": { 28 | "batch_size": 8, 29 | "num_epochs": 1, 30 | "thread_num": 6, 31 | "print_iter": 100, 32 | "model_path": "model/pointwise", 33 | "model_prefix": "bow", 34 | "learning_rate": 0.001, 35 | "shuffle": 1 36 | }, 37 | 38 | "test_data": { 39 | "test_file": "data/convert_test_point_data", 40 | "test_model_file": "model/pointwise/bow.epoch1", 41 | "test_result": "result_bow_pointwise" 42 | }, 43 | 44 | "freeze":{ 45 | "save_path": "model/pointwise/bow.epoch1", 46 | "freeze_path": "tf.graph" 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/common/plugin_factory.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include "common/plugin_factory.h" 16 | 17 | namespace anyq{ 18 | 19 | // 注册组件回调函数 20 | int PluginFactory::register_plugin(std::string plugin_type, PluginCreateFunc create_func) { 21 | _plugin_map[plugin_type] = create_func; 22 | return 0; 23 | } 24 | 25 | // 根据组件类型生成一个组件实例, 自己创建的实例自己销毁,工厂不负责 26 | void* PluginFactory::create_plugin(std::string plugin_type) { 27 | if (_plugin_map.count(plugin_type) < 1) { 28 | FATAL_LOG("create plugin[%s] failed.", plugin_type.c_str()); 29 | return NULL; 30 | } 31 | return _plugin_map[plugin_type](); 32 | } 33 | 34 | PluginFactory& PluginFactory::instance() { 35 | static PluginFactory factory_ins; 36 | return factory_ins; 37 | } 38 | 39 | } 40 | -------------------------------------------------------------------------------- /tools/simnet/train/tf/examples/lstm-pointwise.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_data":{ 3 | "train_file": "data/convert_train_pointwise_data", 4 | "data_size": 400, 5 | "left_slots" : [["left", 32]], 6 | "right_slots" : [["right", 32]] 7 | }, 8 | 9 | "model":{ 10 | "net_py": "./nets/lstm", 11 | "net_class": "LSTM", 12 | "vocabulary_size": 3, 13 | "embedding_dim": 128, 14 | "rnn_hidden_size": 128, 15 | "hidden_size": 128, 16 | "loss_py": "./losses/simnet_loss", 17 | "loss_class": "SoftmaxWithLoss" 18 | }, 19 | 20 | "global":{ 21 | "training_mode": "pointwise", 22 | "n_class": 2, 23 | "max_len_left": 32, 24 | "max_len_right": 32 25 | }, 26 | 27 | "setting":{ 28 | "batch_size": 8, 29 | "num_epochs": 10, 30 | "thread_num": 6, 31 | "print_iter": 100, 32 | "model_path": "model/pointwise", 33 | "model_prefix": "lstm", 34 | "learning_rate": 0.001, 35 | "shuffle": 1 36 | }, 37 | 38 | "test_data":{ 39 | "test_file": "data/convert_test_pointwise_data", 40 | "test_model_file": "model/pointwise/lstm.epoch1", 41 | 
"test_result": "result_lstm_pointwise" 42 | }, 43 | 44 | "freeze":{ 45 | "save_path": "model/pointwise/lstm.epoch1", 46 | "freeze_path": "tf.graph" 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /tools/solr/solr_api.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | # Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import sys 18 | import solr_tools 19 | 20 | if sys.argv[1] == "add_engine": 21 | solr_tools.add_engine(sys.argv[2], sys.argv[3], sys.argv[4], 22 | shard=1, replica=1, maxshardpernode=5, conf='myconf') 23 | elif sys.argv[1] == "set_schema": 24 | solr_tools.set_engine_schema(sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5]) 25 | elif sys.argv[1] == "delete_engine": 26 | solr_tools.delete_engine(sys.argv[2], sys.argv[3], sys.argv[4]) 27 | elif sys.argv[1] == "upload_doc": 28 | solr_tools.upload_documents(sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5], num_thread=1) 29 | elif sys.argv[1] == "clear_doc": 30 | solr_tools.clear_documents(sys.argv[2], sys.argv[3], sys.argv[4]) 31 | 32 | -------------------------------------------------------------------------------- /tools/simnet/train/tf/examples/bow-pairwise.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_data":{ 3 | "train_file": "data/convert_train_pairwise_data", 4 | "data_size": 62, 5 | "left_slots": [["left", 32]], 6 | "right_slots": [["right", 32]] 7 | }, 8 | 9 | "model":{ 10 | "net_py": "./nets/bow", 11 | "net_class": "BOW", 12 | "vocabulary_size": 3, 13 | "embedding_dim": 128, 14 | "bow_size": 128, 15 | "hidden_size": 128, 16 | "loss_py": "./losses/simnet_loss", 17 | "loss_class": "PairwiseHingeLoss", 18 | "margin": 0.1 19 | }, 20 | 21 | "global": { 22 | "training_mode": "pairwise", 23 | "n_class": 2, 24 | "max_len_left": 32, 25 | "max_len_right": 32 26 | }, 27 | 28 | "setting": { 29 | "batch_size": 8, 30 | "num_epochs": 10, 31 | "thread_num": 6, 32 | "print_iter": 100, 33 | "model_path": "model/pairwise", 34 | "model_prefix": "bow", 35 | "learning_rate": 0.001, 36 | "shuffle": 1 37 | }, 38 | 39 | "test_data": { 40 | "test_file": "data/convert_test_pair_data", 41 | "test_model_file": "model/pairwise/bow.epoch1", 42 | "test_result": "result_bow_pairwise" 43 | }, 44 | 45 | "freeze":{ 46 | "save_path": "model/pairwise/bow.epoch1", 47 | "freeze_path": "tf.graph" 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /tools/simnet/train/tf/examples/knrm-pairwise.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_data": { 3 | "train_file": "data/convert_train_pair_data", 4 | "data_size": 62, 5 | "left_slots": [["left", 32]], 6 | "right_slots": [["right", 32]] 7 | }, 8 | 9 | "model": { 10 | "net_py": "./nets/knrm", 11 | "net_class": "KNRM", 12 | "vocabulary_size": 3, 
13 | "embedding_dim": 128, 14 | "kernel_num": 3, 15 | "lamb": 0.5, 16 | "loss_py": "./losses/simnet_loss", 17 | "loss_class": "PairwiseLogLoss", 18 | "margin": 0.1 19 | }, 20 | 21 | "global": { 22 | "training_mode": "pairwise", 23 | "n_class": 2, 24 | "max_len_left": 32, 25 | "max_len_right": 32 26 | }, 27 | 28 | "setting": { 29 | "batch_size": 64, 30 | "num_epochs": 1, 31 | "thread_num": 6, 32 | "print_iter": 100, 33 | "model_path": "model/pairwise", 34 | "model_prefix": "knrm", 35 | "learning_rate": 0.001, 36 | "shuffle": 1 37 | }, 38 | 39 | "test_data": { 40 | "test_file": "data/convert_test_pair_data", 41 | "test_model_file": "model/pairwise/knrm.epoch1", 42 | "test_result": "result_knrm_pairwise" 43 | }, 44 | 45 | "freeze":{ 46 | "save_path": "model/pairwise/knrm.epoch1", 47 | "freeze_path": "tf.graph" 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /tools/simnet/train/tf/examples/lstm-pairwise.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_data":{ 3 | "train_file": "data/convert_train_pairwise_data", 4 | "data_size": 62, 5 | "left_slots" : [["left", 32]], 6 | "right_slots" : [["right", 32]] 7 | }, 8 | 9 | "model":{ 10 | "net_py": "./nets/lstm", 11 | "net_class": "LSTM", 12 | "vocabulary_size": 3, 13 | "embedding_dim": 128, 14 | "rnn_hidden_size": 128, 15 | "hidden_size": 128, 16 | "loss_py": "./losses/simnet_loss", 17 | "loss_class": "PairwiseHingeLoss", 18 | "margin": 0.1 19 | }, 20 | 21 | "global":{ 22 | "training_mode": "pairwise", 23 | "n_class": 2, 24 | "max_len_left": 32, 25 | "max_len_right": 32 26 | }, 27 | 28 | "setting":{ 29 | "batch_size": 8, 30 | "num_epochs": 10, 31 | "thread_num": 6, 32 | "print_iter": 100, 33 | "model_path": "model/pairwise", 34 | "model_prefix": "lstm", 35 | "learning_rate": 0.001, 36 | "shuffle": 1 37 | }, 38 | 39 | "test_data":{ 40 | "test_file": "data/convert_test_pair_data", 41 | "test_model_file": "model/pairwise/lstm.epoch1", 42 | "test_result": "result_lstm_pairwise" 43 | }, 44 | 45 | "freeze":{ 46 | "save_path": "model/pairwise/lstm.epoch1", 47 | "freeze_path": "tf.graph" 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /include/server/anyq_postprocessor.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #ifndef BAIDU_NLP_ANYQ_ANYQ_POSTPROCESSOR_H 16 | #define BAIDU_NLP_ANYQ_ANYQ_POSTPROCESSOR_H 17 | 18 | #include "server/request_postprocess_interface.h" 19 | 20 | namespace anyq { 21 | 22 | class AnyqPostprocessor : public ReqPostprocInterface { 23 | public: 24 | AnyqPostprocessor() {}; 25 | virtual ~AnyqPostprocessor() override {}; 26 | virtual int init(const ServerConfig& config) override; 27 | virtual int destroy() override; 28 | virtual int process(ANYQResult& anyq_result, 29 | Json::Value& parameters, 30 | std::string& output) override; 31 | 32 | private: 33 | DISALLOW_COPY_AND_ASSIGN(AnyqPostprocessor); 34 | }; 35 | } // namespace anyq 36 | 37 | #endif // BAIDU_NLP_ANYQ_ANYQ_POSTPROCESSOR_H 38 | -------------------------------------------------------------------------------- /include/server/anyq_preprocessor.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef BAIDU_NLP_ANYQ_ANYQ_PREPROCESSOR_H 16 | #define BAIDU_NLP_ANYQ_ANYQ_PREPROCESSOR_H 17 | 18 | #include "server/request_preprocess_interface.h" 19 | 20 | namespace anyq { 21 | 22 | class AnyqPreprocessor : public ReqPreprocInterface { 23 | public: 24 | AnyqPreprocessor() {}; 25 | virtual ~AnyqPreprocessor() override {}; 26 | virtual int init(const ReqPreprocPluginConfig& config) override; 27 | virtual int destroy() override; 28 | virtual int process(brpc::Controller*, 29 | Json::Value& parameters, 30 | std::string& input) override; 31 | 32 | private: 33 | DISALLOW_COPY_AND_ASSIGN(AnyqPreprocessor); 34 | }; 35 | 36 | } // namespace anyq 37 | 38 | #endif // BAIDU_NLP_ANYQ_ANYQ_PREPROCESSOR_H 39 | -------------------------------------------------------------------------------- /tools/simnet/train/tf/examples/cnn-pointwise.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_data":{ 3 | "train_file": "data/convert_train_pointwise_data", 4 | "data_size": 400, 5 | "left_slots" : [["left",32]], 6 | "right_slots" : [["right",32]] 7 | }, 8 | 9 | "model":{ 10 | "net_py": "./nets/mlpcnn", 11 | "net_class": "MLPCnn", 12 | "vocabulary_size": 3, 13 | "embedding_dim": 128, 14 | "num_filters": 256, 15 | "hidden_size": 128, 16 | "window_size": 3, 17 | "loss_py": "./losses/simnet_loss", 18 | "loss_class": "SoftmaxWithLoss" 19 | }, 20 | 21 | "global":{ 22 | "training_mode": "pointwise", 23 | "n_class": 2, 24 | "max_len_left": 32, 25 | "max_len_right": 32 26 | }, 27 | 28 | "setting":{ 29 | "batch_size": 8, 30 | "num_epochs": 10, 31 | "thread_num": 6, 32 | "print_iter": 100, 33 | "model_path": "model/pointwise", 34 | "model_prefix": "cnn", 35 | "learning_rate": 0.001, 36 | "shuffle": 1 37 | }, 38 | 39 | "test_data":{ 40 | "test_file": "data/convert_test_pointwise_data", 41 | "test_model_file": "model/pointwise/cnn.epoch1", 42 | "test_result": "result_cnn_pointwise" 43 | }, 44 | 45 | 
"freeze":{ 46 | "save_path": "model/pointwise/cnn.epoch1", 47 | "freeze_path": "tf.graph" 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /include/matching/lexical/cosine_sim.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef BAIDU_NLP_ANYQ_COSINE_SIM_H 16 | #define BAIDU_NLP_ANYQ_COSINE_SIM_H 17 | 18 | #include "matching/matching_interface.h" 19 | 20 | namespace anyq { 21 | //该特征表示query与候选在字面上的余弦相似度 22 | class CosineSimilarity : public MatchingInterface { 23 | public: 24 | CosineSimilarity(); 25 | virtual ~CosineSimilarity() override; 26 | virtual int init(DualDictWrapper* dict, const MatchingConfig& matching_config) override; 27 | virtual int destroy() override; 28 | virtual int compute_similarity(const AnalysisResult& analysis_res, 29 | RankResult& candidates) override; 30 | private: 31 | DISALLOW_COPY_AND_ASSIGN(CosineSimilarity); 32 | }; 33 | 34 | } // namespace anyq 35 | 36 | #endif // BAIDU_NLP_ANYQ_COSINE_SIMILARITY_H 37 | -------------------------------------------------------------------------------- /cmake/external/gflags.cmake: -------------------------------------------------------------------------------- 1 | INCLUDE(ExternalProject) 2 | 3 | SET(GFLAGS_SOURCES_DIR ${THIRD_PARTY_PATH}/gflags) 4 | SET(GFLAGS_INSTALL_DIR ${THIRD_PARTY_PATH}/) 5 | 6 | ExternalProject_Add( 7 | extern_gflags 8 | ${EXTERNAL_PROJECT_LOG_ARGS} 9 | GIT_REPOSITORY "https://github.com/gflags/gflags.git" 10 | GIT_TAG "v2.2.1" 11 | PREFIX ${GFLAGS_SOURCES_DIR} 12 | UPDATE_COMMAND "" 13 | CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} 14 | -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} 15 | -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS} 16 | -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} 17 | -DCMAKE_PREFIX_PATH=${GFLAGS_SOURCES_DIR} 18 | -DCMAKE_INSTALL_PREFIX=${GFLAGS_INSTALL_DIR} 19 | -DCMAKE_POSITION_INDEPENDENT_CODE=ON 20 | -DBUILD_SHARED_LIBS=ON 21 | -DBUILD_STATIC_LIBS=ON 22 | -DBUILD_TESTING=OFF 23 | -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} 24 | ${EXTERNAL_OPTIONAL_ARGS} 25 | CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${GFLAGS_INSTALL_DIR} 26 | -DBUILD_SHARED_LIBS:BOOL=ON 27 | -DBUILD_STATIC_LIBS:BOOL=ON 28 | -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON 29 | -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} 30 | ) 31 | 32 | -------------------------------------------------------------------------------- /tools/simnet/train/tf/examples/mvlstm-pointwise.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_data":{ 3 | "train_file": "data/convert_train_point_data", 4 | "data_size": 400, 5 | "left_slots" : [["left",32]], 6 | "right_slots" : [["right",32]] 7 | }, 8 | 9 | "model":{ 10 | "net_py": "./nets/mvlstm", 11 | "net_class": "MVLSTM", 12 | "vocabulary_size": 3, 13 | "embedding_dim": 128, 14 | "hidden_size": 128, 15 | 
"k_max_num": 32, 16 | "match_mask" : 1, 17 | "loss_py": "./losses/simnet_loss", 18 | "loss_class": "SoftmaxWithLoss" 19 | }, 20 | 21 | "global":{ 22 | "training_mode": "pointwise", 23 | "n_class": 2, 24 | "max_len_left": 32, 25 | "max_len_right": 32 26 | }, 27 | 28 | "setting":{ 29 | "batch_size": 64, 30 | "num_epochs": 1, 31 | "thread_num": 6, 32 | "print_iter": 100, 33 | "model_path": "model/pointwise", 34 | "model_prefix": "mvlstm", 35 | "learning_rate": 0.001, 36 | "shuffle": 1 37 | }, 38 | 39 | "test_data":{ 40 | "test_file": "data/convert_test_pointwise_data", 41 | "test_model_file": "model/pointwise/mvlstm.epoch1", 42 | "test_result": "result_mvlstm_pointwise" 43 | }, 44 | 45 | "freeze":{ 46 | "save_path": "model/pointwise/mvlstm.epoch1", 47 | "freeze_path": "tf.graph" 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /include/retrieval/term/equal_solr_q_builder.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef BAIDU_NLP_ANYQ_EQUAL_SOLR_Q_BUILDER_H 16 | #define BAIDU_NLP_ANYQ_EQUAL_SOLR_Q_BUILDER_H 17 | 18 | #include "retrieval/term/solr_q_interface.h" 19 | 20 | namespace anyq { 21 | // term相同的solr表达式构造插件 22 | class EqualSolrQBuilder : public SolrQInterface 23 | { 24 | public: 25 | EqualSolrQBuilder() {}; 26 | virtual ~EqualSolrQBuilder() override {}; 27 | int init(DualDictWrapper* dict, const SolrQConfig& solr_q_config) override; 28 | virtual int make_q(const AnalysisResult& analysis_res, 29 | int analysis_idx, 30 | std::string& q) override; 31 | 32 | private: 33 | DISALLOW_COPY_AND_ASSIGN(EqualSolrQBuilder); 34 | }; 35 | 36 | } // namespace anyq 37 | 38 | #endif // BAIDU_NLP_ANYQ_EQUAL_SOLR_Q_BUILDER_H 39 | -------------------------------------------------------------------------------- /tools/simnet/train/tf/examples/cnn-pairwise.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_data":{ 3 | "train_file": "data/convert_train_pairwise_data", 4 | "data_size": 62, 5 | "left_slots" : [["left",32]], 6 | "right_slots" : [["right",32]] 7 | }, 8 | 9 | "model":{ 10 | "net_py": "./nets/mlpcnn", 11 | "net_class": "MLPCnn", 12 | "vocabulary_size": 3, 13 | "embedding_dim": 128, 14 | "num_filters": 256, 15 | "hidden_size": 128, 16 | "window_size": 3, 17 | "loss_py": "./losses/simnet_loss", 18 | "loss_class": "PairwiseHingeLoss", 19 | "margin": 0.1 20 | }, 21 | 22 | "global":{ 23 | "training_mode": "pairwise", 24 | "n_class": 2, 25 | "max_len_left": 32, 26 | "max_len_right": 32 27 | }, 28 | 29 | "setting":{ 30 | "batch_size": 8, 31 | "num_epochs": 10, 32 | "thread_num": 6, 33 | "print_iter": 100, 34 | "model_path": "model/pairwise", 35 | "model_prefix": "cnn", 36 | "learning_rate": 0.001, 37 | "shuffle": 1 38 | }, 39 | 40 | "test_data":{ 41 | "test_file": "data/convert_test_pair_data", 42 | 
"test_model_file": "model/pairwise/cnn.epoch1", 43 | "test_result": "result_cnn_pairwise" 44 | }, 45 | 46 | "freeze":{ 47 | "save_path": "model/pairwise/cnn.epoch1", 48 | "freeze_path": "tf.graph" 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /tools/simnet/train/tf/examples/mvlstm-pairwise.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_data":{ 3 | "train_file": "data/convert_train_pair_data", 4 | "data_size": 62, 5 | "left_slots" : [["left",32]], 6 | "right_slots" : [["right",32]] 7 | }, 8 | 9 | "model":{ 10 | "net_py": "./nets/mvlstm", 11 | "net_class": "MVLSTM", 12 | "vocabulary_size": 3, 13 | "embedding_dim": 128, 14 | "hidden_size": 128, 15 | "k_max_num": 32, 16 | "match_mask" : 1, 17 | "loss_py": "./losses/simnet_loss", 18 | "loss_class": "PairwiseLogLoss", 19 | "margin": 0.1 20 | }, 21 | 22 | "global":{ 23 | "training_mode": "pairwise", 24 | "n_class": 2, 25 | "max_len_left": 32, 26 | "max_len_right": 32 27 | }, 28 | 29 | "setting":{ 30 | "batch_size": 64, 31 | "num_epochs": 1, 32 | "thread_num": 6, 33 | "print_iter": 100, 34 | "model_path": "model/pairwise", 35 | "model_prefix": "mvlstm", 36 | "learning_rate": 0.001, 37 | "shuffle": 1 38 | }, 39 | 40 | "test_data":{ 41 | "test_file": "data/convert_test_pair_data", 42 | "test_model_file": "model/pairwise/mvlstm.epoch1", 43 | "test_result": "result_mvlstm_pairwise" 44 | }, 45 | 46 | "freeze":{ 47 | "save_path": "model/pairwise/mvlstm.epoch1", 48 | "freeze_path": "tf.graph" 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /include/analysis/method_wordseg.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef BAIDU_NLP_ANYQ_METHOD_WORDSEG_H 16 | #define BAIDU_NLP_ANYQ_METHOD_WORDSEG_H 17 | 18 | #include "analysis/method_interface.h" 19 | 20 | namespace anyq { 21 | 22 | class AnalysisWordseg: public AnalysisMethodInterface { 23 | public: 24 | AnalysisWordseg(); 25 | virtual ~AnalysisWordseg() override; 26 | virtual int init(DualDictWrapper* dict, const AnalysisMethodConfig& analysis_method) override; 27 | virtual int destroy() override; 28 | virtual int single_process(AnalysisItem& analysis_item) override; 29 | 30 | private: 31 | WordsegPack* _p_wordseg_pack; 32 | void* _lexer_buff; 33 | tag_t _basic_tokens[MAX_TERM_COUNT]; 34 | DISALLOW_COPY_AND_ASSIGN(AnalysisWordseg); 35 | }; 36 | 37 | } // namespace anyq 38 | 39 | #endif // BAIDU_NLP_ANYQ_METHOD_WORDSEG_H 40 | -------------------------------------------------------------------------------- /include/matching/lexical/jaccard_sim.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef BAIDU_NLP_ANYQ_JACCARD_SIM_H 16 | #define BAIDU_NLP_ANYQ_JACCARD_SIM_H 17 | 18 | #include "matching/matching_interface.h" 19 | #include "common/utils.h" 20 | 21 | namespace anyq { 22 | // This feature is the string-level Jaccard similarity between the query and a candidate 23 | class JaccardSimilarity : public MatchingInterface { 24 | public: 25 | JaccardSimilarity(); 26 | virtual ~JaccardSimilarity() override; 27 | virtual int init(DualDictWrapper* dict, const MatchingConfig& matching_config) override; 28 | virtual int destroy() override; 29 | virtual int compute_similarity(const AnalysisResult& analysis_res, 30 | RankResult& candidates) override; 31 | 32 | private: 33 | DISALLOW_COPY_AND_ASSIGN(JaccardSimilarity); 34 | }; 35 | 36 | } // namespace anyq 37 | 38 | #endif // BAIDU_NLP_ANYQ_JACCARD_SIM_H 39 | -------------------------------------------------------------------------------- /include/analysis/method_query_intervene.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef BAIDU_NLP_ANYQ_ANALYSIS_METHOD_QUERY_INTERVENE_H 16 | #define BAIDU_NLP_ANYQ_ANALYSIS_METHOD_QUERY_INTERVENE_H 17 | 18 | #include "analysis/method_interface.h" 19 | 20 | namespace anyq { 21 | 22 | class AnalysisQueryIntervene: public AnalysisMethodInterface { 23 | public: 24 | AnalysisQueryIntervene(); 25 | virtual ~AnalysisQueryIntervene() override; 26 | virtual int init(DualDictWrapper* dict, const AnalysisMethodConfig& analysis_method) override; 27 | virtual int destroy() override; 28 | virtual int single_process(AnalysisItem& analysis_item) override; 29 | private: 30 | // reloadable dictionary 31 | DualDictWrapper* _dict; 32 | DISALLOW_COPY_AND_ASSIGN(AnalysisQueryIntervene); 33 | }; 34 | 35 | } // namespace anyq 36 | #endif //BAIDU_NLP_ANYQ_ANALYSIS_METHOD_QUERY_INTERVENE_H 37 | -------------------------------------------------------------------------------- /include/retrieval/manual/manual_retrieval.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef BAIDU_NLP_ANYQ_MANUAL_RETRIEVAL_H 16 | #define BAIDU_NLP_ANYQ_MANUAL_RETRIEVAL_H 17 | 18 | #include "retrieval/retrieval_interface.h" 19 | 20 | namespace anyq { 21 | // Manual-intervention retrieval plugin 22 | class ManualRetrievalPlugin : public RetrievalPluginInterface{ 23 | public: 24 | ManualRetrievalPlugin(){}; 25 | virtual ~ManualRetrievalPlugin() override {}; 26 | virtual int init(DictMap* dict_map, const RetrievalPluginConfig& plugin_config) override; 27 | virtual int destroy() override; 28 | virtual int retrieval(const AnalysisResult& analysis_res, 29 | RetrievalResult& retrieval_res) override; 30 | private: 31 | DualDictWrapper *_p_dual_dict_wrapper; 32 | DISALLOW_COPY_AND_ASSIGN(ManualRetrievalPlugin); 33 | }; 34 | 35 | } // namespace anyq 36 | #endif // BAIDU_NLP_ANYQ_MANUAL_RETRIEVAL_H 37 | -------------------------------------------------------------------------------- /tools/solr/anyq_solr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -lt 1 ];then 4 | echo "need faq file" 5 | exit 2 6 | fi 7 | SCRIPT_PATH=$(dirname $0) 8 | FAQ_FILE=$1 9 | # faq 10 | # file -> json 11 | # file -> schema 12 | mkdir -p faq 13 | SCHEMA_FILE=`pwd`/faq/schema_format 14 | FAQ_JSON=`pwd`/faq/faq_json 15 | python ${SCRIPT_PATH}/make_json.py $FAQ_FILE $SCHEMA_FILE $FAQ_JSON 16 | if [ $? -ne 0 ];then 17 | echo "faq-file conversion error" 18 | exit 1 19 | else 20 | echo "faq-file conversion done" 21 | fi 22 | 23 | # set solr dir 24 | SOLR_HOME=./solr-4.10.3-anyq 25 | SOLR_SERVER=${SCRIPT_PATH}/solr_deply.sh 26 | SOLR_API=${SCRIPT_PATH}/solr_api.py 27 | ENGINE_HOST=localhost 28 | ENGINE_NAME=collection1 29 | SOLR_PORT=8900 30 | 31 | #set empty schema 32 | SOLR_EMP_CONF=$SOLR_HOME/example/solr_config_set/common 33 | SOLR_CONF=$SOLR_HOME/example/solr/collection1/conf 34 | cp $SOLR_EMP_CONF/* $SOLR_CONF/ 35 | 36 | #set paddle environment variable 37 | export MKL_NUM_THREADS=1 38 | export OMP_NUM_THREADS=1 39 | 40 | #start 41 | /bin/bash ${SOLR_SERVER} start ${SOLR_HOME} ${SOLR_PORT} 42 | 43 | # set schema 44 | python ${SOLR_API} set_schema ${ENGINE_HOST} ${ENGINE_NAME} ${SCHEMA_FILE} ${SOLR_PORT} 45 | 46 | # clear docs 47 | python ${SOLR_API} clear_doc ${ENGINE_HOST} ${ENGINE_NAME} ${SOLR_PORT} 48 | 49 | # upload docs 50 | python ${SOLR_API} upload_doc ${ENGINE_HOST} ${ENGINE_NAME} ${SOLR_PORT} ${FAQ_JSON} 51 | 52 | if [ $? -ne 0 ];then 53 | echo "upload file error" 54 | exit 1 55 | else 56 | echo "upload file success" 57 | fi 58 | -------------------------------------------------------------------------------- /include/retrieval/term/date_compare_solr_q_builder.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef BAIDU_NLP_ANYQ_DATE_COMPARE_SOLR_Q_BUILDER_H 16 | #define BAIDU_NLP_ANYQ_DATE_COMPARE_SOLR_Q_BUILDER_H 17 | 18 | #include "retrieval/term/solr_q_interface.h" 19 | 20 | namespace anyq { 21 | // Solr q-expression builder plugin for date-range constraints 22 | class DateCompareSolrQBuilder : public SolrQInterface{ 23 | public: 24 | DateCompareSolrQBuilder() {}; 25 | virtual ~DateCompareSolrQBuilder() override {}; 26 | virtual int init(DualDictWrapper* dict, const SolrQConfig& solr_q_config) override; 27 | virtual int make_q(const AnalysisResult& analysis_res, 28 | int analysis_idx, 29 | std::string& q) override; 30 | 31 | private: 32 | std::string _compare_type; 33 | DISALLOW_COPY_AND_ASSIGN(DateCompareSolrQBuilder); 34 | }; 35 | 36 | } // namespace anyq 37 | 38 | #endif // BAIDU_NLP_ANYQ_DATE_COMPARE_SOLR_Q_BUILDER_H 39 | -------------------------------------------------------------------------------- /include/retrieval/retrieval_strategy.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef BAIDU_NLP_ANYQ_RETRIEVAL_STRATEGY_H 16 | #define BAIDU_NLP_ANYQ_RETRIEVAL_STRATEGY_H 17 | 18 | #include "retrieval/retrieval_interface.h" 19 | #include "dict/dict_manager.h" 20 | 21 | namespace anyq { 22 | // Retrieval strategy class, the entry point of the retrieval module 23 | class RetrievalStrategy{ 24 | public: 25 | RetrievalStrategy(); 26 | ~RetrievalStrategy(); 27 | int init(DictMap* dict_map, const std::string& retrieval_conf); 28 | int destroy(); 29 | // remove duplicate queries from the retrieval results 30 | int rm_duplicate_query(RetrievalResult& retrieval_result); 31 | // overall recall strategy pipeline 32 | int run_strategy(const AnalysisResult& analysis_result, RetrievalResult& retrieval_result); 33 | 34 | private: 35 | // holds the retrieval plugins (element type reconstructed; angle brackets were lost in extraction) 36 | std::vector<RetrievalPluginInterface*> _retrieval_plugins; 37 | DISALLOW_COPY_AND_ASSIGN(RetrievalStrategy); 38 | }; 39 | 40 | } // namespace anyq 41 | 42 | #endif // BAIDU_NLP_ANYQ_RETRIEVAL_STRATEGY_H 43 | -------------------------------------------------------------------------------- /include/retrieval/term/synonym_solr_q_builder.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef BAIDU_NLP_ANYQ_SYNONYM_SOLR_Q_BUILDER_H 16 | #define BAIDU_NLP_ANYQ_SYNONYM_SOLR_Q_BUILDER_H 17 | 18 | #include "retrieval/term/solr_q_interface.h" 19 | 20 | namespace anyq { 21 | // Solr q-expression builder plugin for term synonyms 22 | class SynonymSolrQBuilder : public SolrQInterface { 23 | public: 24 | SynonymSolrQBuilder() {}; 25 | virtual ~SynonymSolrQBuilder() override {}; 26 | virtual int init(DualDictWrapper* dict, const SolrQConfig& solr_q_config) override; 27 | int term_synonym(const std::string& term, std::string& synonym_terms); 28 | virtual int make_q(const AnalysisResult& analysis_res, 29 | int analysis_idx, 30 | std::string& q) override; 31 | 32 | private: 33 | DualDictWrapper *_p_dual_dict_wrapper; 34 | DISALLOW_COPY_AND_ASSIGN(SynonymSolrQBuilder); 35 | }; 36 | 37 | } // namespace anyq 38 | #endif // BAIDU_NLP_ANYQ_SYNONYM_SOLR_Q_BUILDER_H 39 | -------------------------------------------------------------------------------- /include/matching/lexical/wordseg_proc.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
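// Editor's note: for context, a hypothetical rank-config entry (protobuf text format)
// showing how a matching/feature plugin such as the WordsegProcessor declared below might
// be registered; the field names are assumptions and must be checked against anyq.proto
// and the shipped example configs:
//
//   matching_config {
//       name: "candidate_wordseg"
//       type: "WordsegProcessor"
//       using_dict_name: "wordseg_dict"  # global word-segmentation dictionary
//       output_num: 0                    # segmentation only, no feature value emitted
//   }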
14 | 15 | #ifndef BAIDU_NLP_ANYQ_WORDSEG_PROC_H 16 | #define BAIDU_NLP_ANYQ_WORDSEG_PROC_H 17 | 18 | #include "matching/matching_interface.h" 19 | 20 | namespace anyq { 21 | // This feature plugin only performs word segmentation on the candidates and emits no feature value; output_num should be set to 0 22 | class WordsegProcessor : public MatchingInterface { 23 | public: 24 | WordsegProcessor(); 25 | virtual ~WordsegProcessor() override; 26 | virtual int init(DualDictWrapper* dict, const MatchingConfig& matching_config) override; 27 | virtual int destroy() override; 28 | virtual int compute_similarity(const AnalysisResult& analysis_res, 29 | RankResult& candidates) override; 30 | private: 31 | // pointer to the global word-segmentation dictionary 32 | WordsegPack* _p_wordseg_pack; 33 | // per-thread word-segmentation resources 34 | void* _lexer_buff; 35 | tag_t _basic_tokens[MAX_TERM_COUNT]; 36 | DISALLOW_COPY_AND_ASSIGN(WordsegProcessor); 37 | }; 38 | 39 | } // namespace anyq 40 | #endif // BAIDU_NLP_ANYQ_WORDSEG_PROC_H 41 | -------------------------------------------------------------------------------- /tools/simnet/train/tf/examples/pyramid-pointwise.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_data":{ 3 | "train_file": "data/convert_train_pointwise_data", 4 | "data_size": 400, 5 | "left_slots" : [["left",32]], 6 | "right_slots" : [["right",32]] 7 | }, 8 | 9 | "model":{ 10 | "net_py": "./nets/matchpyramid", 11 | "net_class": "MatchPyramid", 12 | "vocabulary_size": 3, 13 | "embedding_dim": 128, 14 | "num_filters": 256, 15 | "hidden_size": 128, 16 | "window_size_left": 3, 17 | "window_size_right": 3, 18 | "dpool_size_left": 2, 19 | "dpool_size_right": 2, 20 | "match_mask" : 1, 21 | "loss_py": "./losses/simnet_loss", 22 | "loss_class": "SoftmaxWithLoss" 23 | }, 24 | 25 | "global":{ 26 | "training_mode": "pointwise", 27 | "n_class": 2, 28 | "max_len_left": 32, 29 | "max_len_right": 32 30 | }, 31 | 32 | "setting":{ 33 | "batch_size": 64, 34 | "num_epochs": 1, 35 | "thread_num": 6, 36 | "print_iter": 100, 37 | "model_path": "model/pointwise", 38 | "model_prefix": "pyramid", 39 | "learning_rate": 0.001, 40 | "shuffle": 1 41 | }, 42 | 43 | "test_data":{ 44 | "test_file": "data/convert_test_pointwise_data", 45 | "test_model_file": "model/pointwise/pyramid.epoch1", 46 | "test_result": "result_pyramid_pointwise" 47 | }, 48 | 49 | "freeze":{ 50 | "save_path": "model/pointwise/pyramid.epoch1", 51 | "freeze_path": "tf.graph" 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /tools/simnet/train/tf/examples/mmdnn-pointwise.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_data":{ 3 | "train_file": "data/convert_train_pointwise_data", 4 | "data_size": 400, 5 | "left_slots" : [["left",32]], 6 | "right_slots" : [["right",32]] 7 | }, 8 | 9 | "model":{ 10 | "net_py": "./nets/mm_dnn", 11 | "net_class": "MMDNN", 12 | "vocabulary_size": 3, 13 | "embedding_dim": 128, 14 | "num_filters": 256, 15 | "lstm_dim": 128, 16 | "hidden_size": 128, 17 | "window_size_left": 3, 18 | "window_size_right": 3, 19 | "dpool_size_left": 2, 20 | "dpool_size_right": 2, 21 | "match_mask" : 1, 22 | "loss_py": "./losses/simnet_loss", 23 | "loss_class": "SoftmaxWithLoss" 24 | }, 25 | 26 | "global":{ 27 | "training_mode": "pointwise", 28 | "n_class": 2, 29 | "max_len_left": 32, 30 | "max_len_right": 32 31 | }, 32 | 33 | "setting":{ 34 | "batch_size": 64, 35 | "num_epochs": 1, 36 | "thread_num": 6, 37 | "print_iter": 100, 38 | "model_path": "model/pointwise", 39 | "model_prefix": "mmdnn", 40 | "learning_rate": 0.001, 41 | "shuffle": 1 42 | },
43 | 44 | "test_data":{ 45 | "test_file": "data/convert_test_pointwise_data", 46 | "test_model_file": "model/pointwise/mmdnn.epoch1", 47 | "test_result": "result_mmdnn_pointwise" 48 | }, 49 | 50 | "freeze":{ 51 | "save_path": "model/pointwise/mmdnn.epoch1", 52 | "freeze_path": "tf.graph" 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /include/matching/lexical/bm25_sim.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef BAIDU_NLP_ANYQ_BM25_SIM_H 16 | #define BAIDU_NLP_ANYQ_BM25_SIM_H 17 | 18 | #include "matching/matching_interface.h" 19 | #include "common/utils.h" 20 | 21 | namespace anyq { 22 | // 该特征表示query与候选的BM25相似度 23 | class BM25Similarity : public MatchingInterface { 24 | public: 25 | BM25Similarity(); 26 | virtual ~BM25Similarity() override; 27 | virtual int init(DualDictWrapper* dict, const MatchingConfig& matching_config) override; 28 | virtual int destroy() override; 29 | virtual int compute_similarity(const AnalysisResult& analysis_res, 30 | RankResult& candidates) override; 31 | 32 | private: 33 | DualDictWrapper *_p_dual_dict_wrapper; 34 | static const float _s_bm25_k1; 35 | static const float _s_bm25_k2; 36 | static const float _s_bm25_b; 37 | DISALLOW_COPY_AND_ASSIGN(BM25Similarity); 38 | }; 39 | 40 | } // namespace anyq 41 | 42 | #endif // BAIDU_NLP_ANYQ_BM25_SIM_H 43 | -------------------------------------------------------------------------------- /include/matching/lexical/edit_distance_sim.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #ifndef BAIDU_NLP_ANYQ_EDIT_DISTANCE_SIM_H 16 | #define BAIDU_NLP_ANYQ_EDIT_DISTANCE_SIM_H 17 | 18 | #include "matching/matching_interface.h" 19 | 20 | namespace anyq { 21 | // This feature is the edit-distance similarity between the query and a candidate 22 | class EditDistanceSimilarity : public MatchingInterface { 23 | public: 24 | EditDistanceSimilarity(); 25 | virtual ~EditDistanceSimilarity() override; 26 | virtual int init(DualDictWrapper* dict, const MatchingConfig& matching_config) override; 27 | virtual int destroy() override; 28 | virtual int compute_similarity(const AnalysisResult& analysis_res, 29 | RankResult& candidates) override; 30 | int compute_edit_distance(const std::vector<std::string>& seq1, 31 | const std::vector<std::string>& seq2); 32 | private: 33 | DISALLOW_COPY_AND_ASSIGN(EditDistanceSimilarity); 34 | }; 35 | 36 | } // namespace anyq 37 | 38 | #endif // BAIDU_NLP_ANYQ_EDIT_DISTANCE_SIM_H 39 | -------------------------------------------------------------------------------- /cmake/external/protobuf.cmake: -------------------------------------------------------------------------------- 1 | INCLUDE(ExternalProject) 2 | SET(OPTIONAL_ARGS 3 | "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" 4 | "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}" 5 | "-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}" 6 | "-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}" 7 | "-Dprotobuf_WITH_ZLIB=ON" 8 | "-DZLIB_ROOT:FILEPATH=${ZLIB_ROOT}" 9 | ${EXTERNAL_OPTIONAL_ARGS}) 10 | 11 | SET(OPTIONAL_CACHE_ARGS "-DZLIB_ROOT:STRING=${ZLIB_ROOT}") 12 | SET(PROTOBUF_REPO "https://github.com/google/protobuf.git") 13 | SET(PROTOBUF_TAG "v3.1.0") 14 | IF(USE_TENSORFLOW) 15 | SET(PROTOBUF_TAG "v3.5.0") 16 | ENDIF() 17 | SET(PROTOBUF_SOURCES_DIR ${THIRD_PARTY_PATH}/protobuf) 18 | SET(PROTOBUF_INSTALL_DIR ${THIRD_PARTY_PATH}) 19 | 20 | ExternalProject_Add( 21 | extern_protobuf 22 | ${EXTERNAL_PROJECT_LOG_ARGS} 23 | DEPENDS extern_zlib 24 | GIT_REPOSITORY ${PROTOBUF_REPO} 25 | GIT_TAG ${PROTOBUF_TAG} 26 | PREFIX ${PROTOBUF_SOURCES_DIR} 27 | CONFIGURE_COMMAND cd <SOURCE_DIR> && ${CMAKE_COMMAND} -DCMAKE_SKIP_RPATH=ON 28 | -Dprotobuf_BUILD_TESTS=OFF 29 | -DCMAKE_POSITION_INDEPENDENT_CODE=ON 30 | -DCMAKE_INSTALL_PREFIX=${PROTOBUF_INSTALL_DIR} 31 | -DCMAKE_INSTALL_LIBDIR=lib ./cmake 32 | BUILD_COMMAND cd <SOURCE_DIR> && make -j8 && make install 33 | UPDATE_COMMAND "" 34 | INSTALL_COMMAND "" 35 | ) 36 | 37 | add_custom_command(TARGET extern_protobuf POST_BUILD 38 | COMMAND cp ${PROTOBUF_INSTALL_DIR}/bin/protoc ${PROTOBUF_INSTALL_DIR}/lib 39 | ) 40 | -------------------------------------------------------------------------------- /tools/simnet/train/tf/examples/pyramid-pairwise.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_data":{ 3 | "train_file": "data/convert_train_pairwise_data", 4 | "data_size": 62, 5 | "left_slots" : [["left",32]], 6 | "right_slots" : [["right",32]] 7 | }, 8 | 9 | "model":{ 10 | "net_py": "./nets/matchpyramid", 11 | "net_class": "MatchPyramid", 12 | "vocabulary_size": 3, 13 | "embedding_dim": 128, 14 | "num_filters": 256, 15 | "hidden_size": 128, 16 | "window_size_left": 3, 17 | "window_size_right": 3, 18 | "dpool_size_left": 2, 19 | "dpool_size_right": 2, 20 | "match_mask" : 1, 21 | "loss_py": "./losses/simnet_loss", 22 | "loss_class": "PairwiseLogLoss", 23 | "margin": 0.1 24 | }, 25 | 26 | "global":{ 27 | "training_mode": "pairwise", 28 | "n_class": 2, 29 | "max_len_left": 32, 30 | "max_len_right": 32 31 | }, 32 | 33 | "setting":{ 34 | "batch_size": 64, 35 | "num_epochs": 1, 36 | "thread_num": 6, 37 | "print_iter": 100, 38 | "model_path": "model/pairwise", 39 |
"model_prefix": "pyramid", 40 | "learning_rate": 0.001, 41 | "shuffle": 1 42 | }, 43 | 44 | "test_data":{ 45 | "test_file": "data/convert_test_pair_data", 46 | "test_model_file": "model/pairwise/pyramid.epoch1", 47 | "test_result": "result_pyramid_pairwise" 48 | }, 49 | 50 | "freeze":{ 51 | "save_path": "model/pairwise/pyramid.epoch1", 52 | "freeze_path": "tf.graph" 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /include/matching/lexical/contain_sim.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef BAIDU_NLP_ANYQ_CONTAIN_SIM_H 16 | #define BAIDU_NLP_ANYQ_CONTAIN_SIM_H 17 | 18 | #include "matching/matching_interface.h" 19 | #include "common/utils.h" 20 | 21 | namespace anyq { 22 | 23 | //该特征表示query与候选之间是否存在包含关系 24 | class ContainSimilarity : public MatchingInterface { 25 | public: 26 | ContainSimilarity(); 27 | virtual ~ContainSimilarity() override; 28 | virtual int init(DualDictWrapper* dict, const MatchingConfig& matching_config) override; 29 | virtual int destroy(); 30 | // 判断analysis中query是否与召回query存在包含的关系 31 | virtual bool contain(const AnalysisItem& analysis_item, const RankItem& rank_item); 32 | // 特征计算 33 | virtual int compute_similarity(const AnalysisResult& analysis_res, 34 | RankResult& candidates) override; 35 | private: 36 | DISALLOW_COPY_AND_ASSIGN(ContainSimilarity); 37 | }; 38 | 39 | } // namespace anyq 40 | 41 | #endif // BAIDU_NLP_ANYQ_CONTAIN_SIMILARITY_H 42 | -------------------------------------------------------------------------------- /tools/simnet/train/tf/examples/mmdnn-pairwise.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_data":{ 3 | "train_file": "data/convert_train_pair_data", 4 | "data_size": 62, 5 | "left_slots" : [["left",32]], 6 | "right_slots" : [["right",32]] 7 | }, 8 | 9 | "model":{ 10 | "net_py": "./nets/mm_dnn", 11 | "net_class": "MMDNN", 12 | "vocabulary_size": 3, 13 | "embedding_dim": 128, 14 | "num_filters": 128, 15 | "lstm_dim": 128, 16 | "hidden_size": 128, 17 | "window_size_left": 3, 18 | "window_size_right": 3, 19 | "dpool_size_left": 2, 20 | "dpool_size_right": 2, 21 | "match_mask" : 1, 22 | "loss_py": "./losses/simnet_loss", 23 | "loss_class": "PairwiseLogLoss", 24 | "margin": 0.1 25 | }, 26 | 27 | "global":{ 28 | "training_mode": "pairwise", 29 | "n_class": 2, 30 | "max_len_left": 32, 31 | "max_len_right": 32 32 | }, 33 | 34 | "setting":{ 35 | "batch_size": 64, 36 | "num_epochs": 1, 37 | "thread_num": 6, 38 | "print_iter": 100, 39 | "model_path": "model/pairwise", 40 | "model_prefix": "mmdnn", 41 | "learning_rate": 0.001, 42 | "shuffle": 1 43 | }, 44 | 45 | "test_data":{ 46 | "test_file": "data/convert_test_pair_data", 47 | "test_model_file": "model/pairwise/mmdnn.epoch1", 48 | "test_result": "result_mmdnn_pairwise" 49 | }, 50 | 
51 | "freeze":{ 52 | "save_path": "model/pairwise/mmdnn.epoch1", 53 | "freeze_path": "tf.graph" 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /cmake/external/paddle.cmake: -------------------------------------------------------------------------------- 1 | INCLUDE(ExternalProject) 2 | 3 | SET(PADDLE_SOURCES_DIR ${THIRD_PARTY_PATH}/paddle) 4 | SET(PADDLE_INSTALL_DIR ${THIRD_PARTY_PATH}/install/paddle) 5 | 6 | ExternalProject_Add( 7 | extern_paddle 8 | ${EXTERNAL_PROJECT_LOG_ARGS} 9 | GIT_REPOSITORY "https://github.com/PaddlePaddle/Paddle.git" 10 | GIT_TAG "release/0.14.0" 11 | PREFIX ${PADDLE_SOURCES_DIR} 12 | CONFIGURE_COMMAND mkdir -p ${PADDLE_INSTALL_DIR} && cd ${PADDLE_INSTALL_DIR} && ${CMAKE_COMMAND} -DCMAKE_INSTALL_PREFIX=${PADDLE_INSTALL_DIR} 13 | -DCMAKE_BUILD_TYPE=Release -DWITH_PYTHON=OFF -DWITH_MKL=ON -DWITH_MKLDNN=OFF -DWITH_GPU=OFF -DWITH_FLUID_ONLY=ON 14 | BUILD_COMMAND cd ${PADDLE_INSTALL_DIR} && make -j16 15 | INSTALL_COMMAND cd ${PADDLE_INSTALL_DIR} && make inference_lib_dist 16 | UPDATE_COMMAND "" 17 | ) 18 | 19 | add_custom_command(TARGET extern_paddle POST_BUILD 20 | COMMAND mkdir -p third_party/include/paddle/ third_party/lib 21 | COMMAND cp -rf ${PADDLE_INSTALL_DIR}/fluid_install_dir/paddle/fluid third_party/include/paddle 22 | COMMAND cp -rf ${PADDLE_INSTALL_DIR}/fluid_install_dir/paddle/fluid/inference/lib* third_party/lib 23 | COMMAND cp -rf ${PADDLE_INSTALL_DIR}/fluid_install_dir/third_party/install/mklml/include/* ${THIRD_PARTY_PATH}/include/ 24 | COMMAND cp -rf ${PADDLE_INSTALL_DIR}/fluid_install_dir/third_party/install/mklml/lib/* ${THIRD_PARTY_PATH}/lib/ 25 | COMMAND cp -rf ${PADDLE_INSTALL_DIR}/fluid_install_dir/third_party/boost ${PADDLE_INSTALL_DIR}/fluid_install_dir/third_party/install/boost_1_41_0 26 | ) 27 | -------------------------------------------------------------------------------- /include/retrieval/term/contain_solr_q_builder.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #ifndef BAIDU_NLP_ANYQ_CONTAIN_SOLR_Q_BUILDER_H 16 | #define BAIDU_NLP_ANYQ_CONTAIN_SOLR_Q_BUILDER_H 17 | 18 | #include "retrieval/term/solr_q_interface.h" 19 | 20 | namespace anyq 21 | { 22 | // Solr q-expression builder plugin for blacklist/whitelist terms 23 | class ContainSolrQBuilder : public SolrQInterface 24 | { 25 | public: 26 | ContainSolrQBuilder() {}; 27 | virtual ~ContainSolrQBuilder() override {}; 28 | virtual int init(DualDictWrapper* dict, const SolrQConfig& solr_q_config) override; 29 | int term_contain(const std::string& term, 30 | std::string& contain_terms, 31 | std::string& exclude_terms); 32 | virtual int make_q(const AnalysisResult& analysis_res, 33 | int analysis_idx, 34 | std::string& q) override; 35 | 36 | private: 37 | DualDictWrapper *_p_dual_dict_wrapper; 38 | DISALLOW_COPY_AND_ASSIGN(ContainSolrQBuilder); 39 | }; 40 | 41 | } // namespace anyq 42 | #endif // BAIDU_NLP_ANYQ_CONTAIN_SOLR_Q_BUILDER_H 43 | -------------------------------------------------------------------------------- /demo/feature_dump.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include <glog/logging.h> 16 | #include "common/plugin_header.h" 17 | #include "dict/dict_manager.h" 18 | #include "strategy/anyq_strategy.h" 19 | 20 | int main(int argc, char* argv[]){ 21 | google::InitGoogleLogging(argv[0]); 22 | FLAGS_stderrthreshold = google::INFO; 23 | if (argc != 5) { 24 | FATAL_LOG("Usage: ./output/bin/feature_dump_tool anyq_dict_dir " 25 | "anyq_conf_dir query_file feature_file"); 26 | return -1; 27 | } 28 | 29 | anyq::DictManager dm; 30 | if (dm.load_dict(argv[1]) != 0) { 31 | FATAL_LOG("load dict error"); 32 | return -1; 33 | } 34 | anyq::AnyqStrategy anyq_strategy; 35 | if (anyq_strategy.create_resource(dm, argv[2]) != 0) { 36 | FATAL_LOG("create resource error"); 37 | return -1; 38 | } 39 | if (anyq_strategy.dump_feature(argv[3], argv[4]) != 0) { 40 | FATAL_LOG("feature dump failed!"); 41 | return -1; 42 | } 43 | 44 | return 0; 45 | } 46 | -------------------------------------------------------------------------------- /include/common/http_client.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
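// Editor's note: a minimal usage sketch for the client declared below; the URLs are
// placeholders and error handling is omitted:
//
//   anyq::HttpClient client;
//   std::string body;
//   // GET: fetch a Solr select result into body
//   if (client.curl_get("http://localhost:8900/solr/collection1/select?q=*:*",
//                       &body) == 0) {
//       // body now holds the raw response payload
//   }
//   // POST: form parameters go in a key/value map
//   std::map<std::string, std::string> params;
//   params["wt"] = "json";
//   client.curl_post("http://localhost:8900/solr/collection1/select", params, &body);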
14 | 15 | #ifndef BAIDU_NLP_ANYQ_HTTP_CLIENT_H 16 | #define BAIDU_NLP_ANYQ_HTTP_CLIENT_H 17 | 18 | #include <curl/curl.h> 19 | #include <map> 20 | #include <string> 21 | #include "common/common_define.h" 22 | 23 | namespace anyq { 24 | 25 | class HttpClient { 26 | public: 27 | HttpClient(); 28 | ~HttpClient(); 29 | // HTTP GET request 30 | int curl_get(const char* url, std::string* buffer); 31 | // HTTP POST request 32 | int curl_post(const char* url, const std::map<std::string, std::string>& para_map, 33 | std::string* buffer); 34 | int curl_post(const char* url, const std::string& para_str, std::string* buffer, 35 | const std::string& header_str); 36 | // callback that collects the response data once the request finishes 37 | static int str_write_callback(char* data, size_t size, size_t nmemb, std::string* buffer); 38 | 39 | private: 40 | // curl handle 41 | CURL* _p_curl; 42 | DISALLOW_COPY_AND_ASSIGN(HttpClient); 43 | }; 44 | } // namespace anyq 45 | 46 | #endif // BAIDU_NLP_ANYQ_HTTP_CLIENT_H 47 | -------------------------------------------------------------------------------- /tools/solr/sample_docs: -------------------------------------------------------------------------------- 1 | question answer 2 | 需要使用什么账号登录? 您需要拥有一个百度账号,用来登录百度云,可以点击此处注册百度账户。如您以前拥有百度推广账户,同样可以登录百度云。 3 | 注册百度账户时收不到验证码怎么办? 由于欠费停机、存储信息已满、信号网络延迟等原因没有及时收到验证码,这时请检查您的手机及话费余额,保证手机可正常接收短信后,请尝试重新获取验证码。 4 | AI服务支持推广账号使用么? 支持推广账户使用。 5 | 为什么登录到百度云还要填写手机号、邮箱等信息? 如果您是初次使用百度云,我们需要收集一些您的几个核心信息,用于做开发者认证,这些信息也会作为您使用产品过程中,我们与您取得联系的重要联系方式。如您之前已经是百度云用户、百度开发者中心用户,此步骤将会自动省略。 6 | 我以前是百度开发者中心用户,还需要进行开发者认证么? 不需要。我们会自动同步您的开发者信息,但是为保证后续使用中可以及时联系到您,可能会提示您重新补充最新的开发者信息(手机号等)。 7 | 目前都开放了哪些服务? 目前百度语音、文字识别、人脸识别、自然语言处理、图像审核、知识图谱,这六项技术您可以直接在控制台中使用,也可以通过在百度AI开放平台官网,提交商务合作需求。 8 | 每个服务的请求配额都是免费的么? 目前我们为每个账户下的每项API服务,都设置了固定的免费请求配额,便于您体验服务及应用调试。 9 | 每个服务的请求配额有限制么? 目前在同一账号下,每个API服务都有免费的请求配额,您可以在对应服务的控制台中查看。付费服务不限请求数量,即用即扣。 10 | 我可以创建多少个应用? 每项服务目前可创建最多100个应用,需要注意的是:每项服务下的所有应用,将会共享您该项服务的请求配额。 11 | 目前这些服务免费吗?能够保证QPS吗? 目前百度AI开放平台的绝大多数的基础技术能力都是免费的(包括语音识别、语音合成、语音唤醒等等),每天都有既定的配额,如果不够可以填写申请。 语音技术方向的申请方法:请登录控制台,点击百度语音,选择应用列表,选择对应应用,查看详情,点击申请提高配额,一般会在2个工作日内完成审核,审核通过后,将可无限调用。 其他方向申请方式:工单免费申请配额和QPS,官方根据您的应用场景和需求审核后评估调整配额的额度。 还有一些技术服务,已经陆续推出付费商用方案(比如文字识别方向等),为您提供更多维度的支持,您可根据自己的需求定制化自由调用,全方位保障您的产品需求。 12 | 我是百度云的老用户,可以使用百度云的AK/SK么? 目前文字识别、人脸识别、自然语言处理、图像审核、知识图谱在后台都可以使用百度云AK/SK调用,请求限额相同。非常抱歉的是语音服务暂不支持百度云AK/SK调用,我们会尽快完善,给您带来的不便深感抱歉。 13 | 支持哪些语言的服务端SDK? 目前支持各项服务的Java、PHP、Python、C#、Node.js版本服务端SDK,我们会尽快陆续推出更多语言支持,请您持续关注。 14 | 我有一些定制化需求,如何与你们取得联系? 您可以通过以下两种方式与我们联系: 15 | 如果我正在做一个比较大型的落地项目,需要更多配额如何接洽? 目前百度AI开放平台大部分产品是免费的,如果您合理化接入应用,有一定的合理化应用场景,可以通过【工单】或者官方右侧【合作咨询】,说明您的使用场景和预期的配额量级,进行申请,我们评估后会尽快满足您的需求:给您免费提高配额,让您充分试用我们的产品;或是有专业负责人与您商务对接,为您的调用保驾护航。其他付费服务-如文字识别方向,已经可以自助付费使用,充值付费后,调用量不再受限。 16 | 目前除了免费部分,是否支持付费,计费价目表是怎样的? 目前大部分开放服务是免费的,付费技术服务在技术介绍页最下方以及您的控制台中,都会有相应的免费配额、付费计价的介绍。 17 | 除了免费的百度AI技术服务,付费如何充值? 我们将统一使用百度云的账户计费,您只需在百度云中充值即可,依据不同付费方案,将会从您的账户余额中扣费。 18 | -------------------------------------------------------------------------------- /include/analysis/analysis_strategy.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef BAIDU_NLP_ANYQ_ANALYSIS_STRATEGY_H 16 | #define BAIDU_NLP_ANYQ_ANALYSIS_STRATEGY_H 17 | 18 | #include "analysis/method_interface.h" 19 | #include "anyq.pb.h" 20 | #include "dict/dict_manager.h" 21 | 22 | namespace anyq { 23 | // Analyzes the input query 24 | class AnalysisStrategy { 25 | public: 26 | AnalysisStrategy(); 27 | ~AnalysisStrategy(); 28 | 29 | // initialize with dictionaries and configuration 30 | int init(DictMap* dict_map, const std::string& analysis_conf); 31 | // run the analysis plugins 32 | int run_strategy(const std::string& analysis_input_str, AnalysisResult& analysis_result); 33 | 34 | int destroy(); 35 | private: 36 | // convert the input format 37 | int json_parser(const std::string& analysis_input_str, AnalysisResult& analysis_result); 38 | 39 | // plugin list (element type reconstructed; angle brackets were lost in extraction) 40 | std::vector<AnalysisMethodInterface*> _method_list; 41 | 42 | DictMap* _dict_map; 43 | DISALLOW_COPY_AND_ASSIGN(AnalysisStrategy); 44 | }; 45 | 46 | } // namespace anyq 47 | 48 | #endif //BAIDU_NLP_ANYQ_ANALYSIS_STRATEGY_H 49 | -------------------------------------------------------------------------------- /include/strategy/anyq_strategy.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef BAIDU_NLP_ANYQ_ANYQ_STRATEGY_H 16 | #define BAIDU_NLP_ANYQ_ANYQ_STRATEGY_H 17 | 18 | #include "dict/dict_manager.h" 19 | #include "analysis/analysis_strategy.h" 20 | #include "retrieval/retrieval_strategy.h" 21 | #include "rank/rank_strategy.h" 22 | #include "common/utils.h" 23 | 24 | namespace anyq{ 25 | // Overall AnyQ strategy class 26 | class AnyqStrategy { 27 | public: 28 | AnyqStrategy(); 29 | ~AnyqStrategy(); 30 | // create thread resources 31 | int create_resource(DictManager& dm, const std::string& conf_path); 32 | // release thread resources 33 | int release_resource(); 34 | // entry point for running the anyq strategy 35 | int run_strategy(const std::string& analysis_input, ANYQResult& result); 36 | // dump features in libsvm format 37 | int dump_feature(const std::string& input_file, const std::string& out_file); 38 | private: 39 | AnalysisStrategy _analysis; 40 | RetrievalStrategy _retrieval; 41 | RankStrategy _rank; 42 | DISALLOW_COPY_AND_ASSIGN(AnyqStrategy); 43 | }; 44 | 45 | } 46 | 47 | #endif //BAIDU_NLP_ANYQ_ANYQ_STRATEGY_H 48 | -------------------------------------------------------------------------------- /include/server/request_postprocess_interface.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | 
15 | #ifndef BAIDU_NLP_ANYQ_REQUEST_POSTPROCESS_INTERFACE_H
16 | #define BAIDU_NLP_ANYQ_REQUEST_POSTPROCESS_INTERFACE_H
17 | 
18 | #include "json/json.h"
19 | #include "common/utils.h"
20 | #include "anyq.pb.h"
21 | 
22 | namespace anyq {
23 | 
24 | class ReqPostprocInterface
25 | {
26 | public:
27 |     ReqPostprocInterface() {}
28 |     virtual ~ReqPostprocInterface() {}
29 |     int init_base(const std::string& plugin_name) {
30 |         _plugin_name = plugin_name;
31 |         return 0;
32 |     }
33 |     virtual int init(const ServerConfig&) = 0;
34 |     virtual int destroy() = 0;
35 |     const std::string& plugin_name() {
36 |         return _plugin_name;
37 |     }
38 |     virtual int process(ANYQResult& anyq_result,
39 |             Json::Value& parameters,
40 |             std::string& output) = 0;
41 | 
42 | private:
43 |     std::string _plugin_name;
44 |     DISALLOW_COPY_AND_ASSIGN(ReqPostprocInterface);
45 | };
46 | 
47 | } // namespace anyq
48 | 
49 | #endif // BAIDU_NLP_ANYQ_REQUEST_POSTPROCESS_INTERFACE_H
50 | 
--------------------------------------------------------------------------------
/tools/simnet/train/paddle/losses/hinge_loss.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | 
3 | # Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | 
17 | import layers.paddle_layers as layers
18 | 
19 | 
20 | class HingeLoss(object):
21 |     """
22 |     Hinge loss calculation class
23 |     """
24 |     def __init__(self, conf_dict):
25 |         """
26 |         initialize
27 |         """
28 |         self.margin = conf_dict["loss"]["margin"]
29 | 
30 |     def compute(self, pos, neg):
31 |         """
32 |         compute loss: mean(max(0, neg - pos + margin))
33 |         """
34 |         elementwise_max = layers.ElementwiseMaxLayer()
35 |         elementwise_add = layers.ElementwiseAddLayer()
36 |         elementwise_sub = layers.ElementwiseSubLayer()
37 |         constant = layers.ConstantLayer()
38 |         reduce_mean = layers.ReduceMeanLayer()
39 |         loss = reduce_mean.ops(
40 |             elementwise_max.ops(
41 |                 constant.ops(neg, neg.shape, "float32", 0.0),
42 |                 elementwise_add.ops(
43 |                     elementwise_sub.ops(neg, pos),
44 |                     constant.ops(neg, neg.shape, "float32", self.margin))))
45 |         return loss
46 | 
--------------------------------------------------------------------------------
/include/server/request_preprocess_interface.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | 
15 | #ifndef BAIDU_NLP_ANYQ_REQUEST_PREPROCESS_INTERFACE_H
16 | #define BAIDU_NLP_ANYQ_REQUEST_PREPROCESS_INTERFACE_H
17 | 
18 | #include "json/json.h"
19 | #include "brpc/server.h"
20 | #include "common/utils.h"
21 | #include "anyq.pb.h"
22 | 
23 | namespace anyq {
24 | 
25 | class ReqPreprocInterface {
26 | public:
27 |     ReqPreprocInterface() {}
28 |     virtual ~ReqPreprocInterface() {}
29 |     int init_base(const std::string& plugin_name) {
30 |         _plugin_name = plugin_name;
31 |         return 0;
32 |     }
33 |     virtual int init(const ReqPreprocPluginConfig& config) = 0;
34 |     virtual int destroy() = 0;
35 |     const std::string& plugin_name() {
36 |         return _plugin_name;
37 |     }
38 |     virtual int process(brpc::Controller* cntl,
39 |             Json::Value& parameters,
40 |             std::string& str_input) = 0;
41 | 
42 | private:
43 |     std::string _plugin_name;
44 |     DISALLOW_COPY_AND_ASSIGN(ReqPreprocInterface);
45 | };
46 | 
47 | } // namespace anyq
48 | 
49 | #endif // BAIDU_NLP_ANYQ_REQUEST_PREPROCESS_INTERFACE_H
50 | 
--------------------------------------------------------------------------------
/docs/semantic_retrieval_tutorial.md:
--------------------------------------------------------------------------------
1 | # Adding a Semantic Index
2 | 
3 | ## Building the semantic index
4 | 1. Convert the ingestion file faq_file (UTF-8 encoded) to JSON:
5 | ```
6 | cp -rp ../tools/solr ./solr_script
7 | mkdir -p faq
8 | python solr_script/make_json.py solr_script/sample_docs faq/schema_format faq/faq_json
9 | ```
10 | 
11 | 2. Prefix each JSON line with an index id:
12 | 
13 | ```
14 | awk -F "\t" '{print ++ind"\t"$0}' faq/faq_json > faq/faq_json.index
15 | ```
16 | 
17 | 3. In the AnyQ dictionary config dict.conf, add the semantic representation model plugin:
18 | 
19 | ```
20 | dict_config{
21 |     name: "fluid_simnet"
22 |     type: "PaddleSimAdapter"
23 |     path: "./simnet"
24 | }
25 | ```
26 | 
27 | 4. In analysis.conf, add the query semantic representation plugin:
28 | 
29 | ```
30 | analysis_method {
31 |     name: "method_simnet_emb"
32 |     type: "AnalysisSimNetEmb"
33 |     using_dict_name: "fluid_simnet"
34 |     dim: 128
35 |     query_feed_name: "left"
36 |     cand_feed_name: "right"
37 |     embedding_fetch_name: "tanh.tmp"
38 | }
39 | ```
40 | 
41 | 5. Build the semantic index:
42 | 
43 | ```
44 | ./annoy_index_build_tool example/conf/ example/conf/analysis.conf faq/faq_json.index 128 10 semantic.annoy 1>std 2>err
45 | ```
46 | 
47 | ## Using the semantic index
48 | 
49 | 1. Put the FAQ file with index ids and the semantic index into the AnyQ config directory:
50 | 
51 | ```
52 | cp faq/faq_json.index semantic.annoy example/conf
53 | ```
54 | 
55 | 2. Configure loading of the FAQ file in dict.conf:
56 | 
57 | ```
58 | dict_config {
59 |     name: "annoy_knowledge_dict"
60 |     type: "String2RetrievalItemAdapter"
61 |     path: "./faq_json.index"
62 | }
63 | ```
64 | 
65 | 3. Configure the semantic retrieval plugin in retrieval.conf:
66 | 
67 | ```
68 | retrieval_plugin {
69 |     name : "semantic_recall"
70 |     type : "SemanticRetrievalPlugin"
71 |     vector_size : 128
72 |     search_k : 10000
73 |     index_path : "./example/conf/semantic.annoy"
74 |     using_dict_name: "annoy_knowledge_dict"
75 |     num_result : 10
76 | }
77 | ```
78 | 
--------------------------------------------------------------------------------
/src/rank/predictor/predict_select_model.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | 
15 | #include "rank/predictor/predictor_interface.h"
16 | 
17 | namespace anyq {
18 | 
19 | int PredictSelectModel::init(DualDictWrapper* dict,
20 |         const std::vector<std::string>& feature_names,
21 |         const RankPredict& predict_config) {
22 |     // index of the feature used for selection
23 |     _select_index = predict_config.select_config().select_index();
24 |     return 0;
25 | }
26 | 
27 | int PredictSelectModel::destroy() {
28 |     return 0;
29 | }
30 | 
31 | int PredictSelectModel::predict(RankResult& candidates) {
32 |     for (size_t i = 0; i < candidates.size(); i++) {
33 |         // invalid candidate, skip it
34 |         if (candidates[i].abandoned) {
35 |             candidates[i].ltr_score = 0.0f;
36 |             continue;
37 |         }
38 |         float score = 0.0f;
39 |         // use the feature at index _select_index as the score
40 |         if (candidates[i].features.size() > _select_index) {
41 |             score = candidates[i].features[_select_index];
42 |         }
43 |         candidates[i].ltr_score = score;
44 |     }
45 |     return 0;
46 | }
47 | 
48 | } // namespace anyq
49 | 
--------------------------------------------------------------------------------
/cmake/external/tensorflow.cmake:
--------------------------------------------------------------------------------
1 | INCLUDE(ExternalProject)
2 | 
3 | SET(TENSORFLOW_SOURCES_DIR ${THIRD_PARTY_PATH}/tensorflow)
4 | SET(TENSORFLOW_INSTALL_DIR ${THIRD_PARTY_PATH}/)
5 | SET(TENSORFLOW_DOWNLOAD_DIR "${TENSORFLOW_SOURCES_DIR}/src/")
6 | SET(TENSORFLOW_URL "https://codeload.github.com/tensorflow/tensorflow/zip/v1.8.0")
7 | ExternalProject_Add(
8 |     extern_tensorflow
9 |     ${EXTERNAL_PROJECT_LOG_ARGS}
10 |     DEPENDS extern_bazel
11 |     DOWNLOAD_DIR ${TENSORFLOW_DOWNLOAD_DIR}
12 |     DOWNLOAD_COMMAND wget --no-check-certificate ${TENSORFLOW_URL} -c -O tensorflow-1.8.0.zip
13 |         && unzip tensorflow-1.8.0.zip
14 |     DOWNLOAD_NO_PROGRESS 1
15 |     PREFIX ${TENSORFLOW_SOURCES_DIR}
16 |     BUILD_COMMAND export JAVA_HOME=${THIRD_PARTY_PATH}/jdk-1.8
17 |         && cd ${TENSORFLOW_DOWNLOAD_DIR}/tensorflow-1.8.0
18 |         && ${THIRD_PARTY_PATH}/bin/bazel build //tensorflow:libtensorflow_cc.so
19 |     UPDATE_COMMAND ""
20 |     CONFIGURE_COMMAND ""
21 |     INSTALL_COMMAND ""
22 |     BUILD_IN_SOURCE 1
23 | )
24 | 
25 | add_custom_command(TARGET extern_tensorflow POST_BUILD
26 |     COMMAND mkdir -p ${TENSORFLOW_INSTALL_DIR}/include/tf
27 |     COMMAND mkdir -p ${TENSORFLOW_INSTALL_DIR}/lib
28 |     COMMAND cp -rf ${TENSORFLOW_DOWNLOAD_DIR}/tensorflow-1.8.0/bazel-bin/tensorflow/lib*.so ${TENSORFLOW_INSTALL_DIR}/lib
29 |     COMMAND cp -rf ${TENSORFLOW_DOWNLOAD_DIR}/tensorflow-1.8.0/bazel-genfiles/* ${TENSORFLOW_INSTALL_DIR}/include/tf
30 |     COMMAND cp -rf ${TENSORFLOW_DOWNLOAD_DIR}/tensorflow-1.8.0/tensorflow ${TENSORFLOW_INSTALL_DIR}/include/tf
31 |     COMMAND cp -rf ${TENSORFLOW_DOWNLOAD_DIR}/tensorflow-1.8.0/third_party ${TENSORFLOW_INSTALL_DIR}/include/tf
32 | )
33 | 
--------------------------------------------------------------------------------
/src/dict/wordseg_adapter.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | 
15 | #include "dict/dict_adapter.h"
16 | #include "common/utils.h"
17 | 
18 | namespace anyq {
19 | 
20 | WordsegAdapter::WordsegAdapter() {
21 | }
22 | 
23 | WordsegAdapter::~WordsegAdapter() {
24 | }
25 | 
26 | int WordsegAdapter::load(const std::string& path, const DictConfig& config) {
27 |     WordsegPack *p_wordseg_pack = new WordsegPack();
28 |     p_wordseg_pack->lexer_dict = NULL;
29 |     p_wordseg_pack->lexer_dict = lac_create(path.c_str());
30 |     if (p_wordseg_pack->lexer_dict == NULL)
31 |     {
32 |         FATAL_LOG("wordseg dict load error");
33 |         return -1;
34 |     }
35 | 
36 |     TRACE_LOG("wordseg dict load success.");
37 |     set_dict((void*)p_wordseg_pack);
38 |     return 0;
39 | }
40 | 
41 | int WordsegAdapter::release() {
42 |     void* dict = get_dict();
43 |     if (dict != NULL) {
44 |         WordsegPack* p_wordseg_pack = static_cast<WordsegPack*>(dict);
45 |         lac_destroy(p_wordseg_pack->lexer_dict);
46 |         p_wordseg_pack->lexer_dict = NULL;
47 |         delete p_wordseg_pack;
48 |         set_dict(NULL);
49 |     }
50 |     return 0;
51 | }
52 | 
53 | } // namespace anyq
54 | 
--------------------------------------------------------------------------------
/include/analysis/method_simnet_emb.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
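// Usage sketch (assumed from the interface below; the dict wrapper and config
// variables are placeholders, not part of the original source):
//
//     AnalysisSimNetEmb emb;
//     emb.init(simnet_dict_wrapper, method_conf);  // dict wraps the Paddle model pack
//     AnalysisItem item;
//     item.query = "how do i register";
//     emb.single_process(item);                    // computes the query embedding
//     emb.destroy();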
14 | 
15 | #ifndef BAIDU_NLP_ANYQ_METHOD_SIMNET_EMB_H
16 | #define BAIDU_NLP_ANYQ_METHOD_SIMNET_EMB_H
17 | 
18 | #ifndef USE_TENSORFLOW
19 | #include "analysis/method_interface.h"
20 | #include "common/paddle_thread_resource.h"
21 | #include "anyq.pb.h"
22 | 
23 | namespace anyq {
24 | // Extracts the query representation vector from the SimNet network
25 | class AnalysisSimNetEmb : public AnalysisMethodInterface {
26 | public:
27 |     AnalysisSimNetEmb();
28 |     virtual ~AnalysisSimNetEmb() override;
29 |     virtual int init(DualDictWrapper* dict, const AnalysisMethodConfig& analysis_method) override;
30 |     virtual int destroy() override;
31 |     // reuse method_process from the interface; implement our own single_process
32 |     virtual int single_process(AnalysisItem& analysis_item) override;
33 | 
34 | private:
35 |     PaddlePack* _p_paddle_pack;
36 |     PaddleThreadResource* _paddle_resource;
37 |     size_t _query_feed_index;
38 |     size_t _cand_feed_index;
39 |     size_t _embedding_fetch_index;
40 |     size_t _dim;
41 |     DISALLOW_COPY_AND_ASSIGN(AnalysisSimNetEmb);
42 | };
43 | 
44 | } // namespace anyq
45 | #endif
46 | 
47 | #endif // BAIDU_NLP_ANYQ_METHOD_SIMNET_EMB_H
48 | 
--------------------------------------------------------------------------------
/include/retrieval/term/boost_solr_q_builder.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | 
15 | #ifndef BAIDU_NLP_ANYQ_BOOST_SOLR_Q_BUILDER_H
16 | #define BAIDU_NLP_ANYQ_BOOST_SOLR_Q_BUILDER_H
17 | 
18 | #include "retrieval/term/solr_q_interface.h"
19 | 
20 | namespace anyq {
21 | // Solr query expression builder plugin with per-term weights
22 | class BoostSolrQBuilder : public SolrQInterface {
23 | public:
24 |     BoostSolrQBuilder() {}
25 |     virtual ~BoostSolrQBuilder() override {}
26 |     virtual int init(DualDictWrapper* dict, const SolrQConfig& solr_q_config) override;
27 |     int term_weight(const std::string& term,
28 |             std::string& high_freq_token_q,
29 |             std::string& low_freq_token_q_with_stopword,
30 |             std::string& low_freq_token_q_without_stopword);
31 |     virtual int make_q(const AnalysisResult& analysis_res,
32 |             int analysis_idx,
33 |             std::string& q) override;
34 | 
35 | private:
36 |     void solr_wrapper(const std::string& term, std::string& solr_q, float weight);
37 |     DualDictWrapper *_p_dual_dict_wrapper;
38 |     DISALLOW_COPY_AND_ASSIGN(BoostSolrQBuilder);
39 | };
40 | 
41 | } // namespace anyq
42 | #endif // BAIDU_NLP_ANYQ_BOOST_SOLR_Q_BUILDER_H
43 | 
--------------------------------------------------------------------------------
/src/matching/lexical/jaccard_sim.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | 
15 | #include "matching/lexical/jaccard_sim.h"
16 | 
17 | namespace anyq {
18 | 
19 | JaccardSimilarity::JaccardSimilarity() {
20 | }
21 | 
22 | JaccardSimilarity::~JaccardSimilarity() {
23 |     destroy();
24 | }
25 | 
26 | int JaccardSimilarity::init(DualDictWrapper* dict, const MatchingConfig& matching_config) {
27 |     return 0;
28 | }
29 | 
30 | int JaccardSimilarity::destroy() {
31 |     return 0;
32 | }
33 | 
34 | int JaccardSimilarity::compute_similarity(const AnalysisResult& analysis_res,
35 |         RankResult& candidates) {
36 |     if (analysis_res.analysis.size() < 1) {
37 |         return -1;
38 |     }
39 | 
40 |     for (size_t i = 0; i < candidates.size(); i++) {
41 |         // invalid candidate, skip it
42 |         if (candidates[i].abandoned) {
43 |             continue;
44 |         }
45 |         // string Jaccard similarity between the query and the candidate
46 |         float jaccard_sim = jaccard_similarity(analysis_res.analysis[0].query,
47 |                 candidates[i].match_info.text);
48 |         DEBUG_LOG("jaccard = %f", jaccard_sim);
49 |         candidates[i].features.push_back(jaccard_sim);
50 |     }
51 | 
52 |     return 0;
53 | }
54 | 
55 | } // namespace anyq
56 | 
--------------------------------------------------------------------------------
/include/matching/semantic/simnet_paddle_sim.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | 
15 | #ifndef BAIDU_NLP_ANYQ_SIMNET_PADDLE_SIM_H
16 | #define BAIDU_NLP_ANYQ_SIMNET_PADDLE_SIM_H
17 | 
18 | #ifndef USE_TENSORFLOW
19 | #include "matching/matching_interface.h"
20 | #include "common/paddle_thread_resource.h"
21 | 
22 | namespace anyq {
23 | // Paddle-based semantic similarity feature
24 | class PaddleSimilarity : public MatchingInterface {
25 | public:
26 |     PaddleSimilarity();
27 |     virtual ~PaddleSimilarity() override;
28 |     virtual int init(DualDictWrapper* dict, const MatchingConfig& matching_config) override;
29 |     virtual int destroy() override;
30 |     virtual int compute_similarity(const AnalysisResult& analysis_res,
31 |             RankResult& candidates) override;
32 | 
33 | private:
34 |     // paddle dict pointer
35 |     PaddlePack* _p_paddle_pack;
36 |     // paddle thread resource
37 |     PaddleThreadResource* _paddle_resource;
38 |     // index of "query" in the feeds
39 |     size_t _query_feed_index;
40 |     size_t _cand_feed_index;
41 |     // index of "score" in the fetches
42 |     size_t _score_fetch_index;
43 |     DISALLOW_COPY_AND_ASSIGN(PaddleSimilarity);
44 | };
45 | 
46 | } // namespace anyq
47 | #endif
48 | 
49 | #endif // BAIDU_NLP_ANYQ_SIMNET_PADDLE_SIM_H
50 | 
51 | 
--------------------------------------------------------------------------------
/include/retrieval/semantic/semantic_retrieval.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | 
15 | #ifndef BAIDU_NLP_ANYQ_SEMANTIC_RETRIEVAL_H
16 | #define BAIDU_NLP_ANYQ_SEMANTIC_RETRIEVAL_H
17 | 
18 | #include "annoylib.h"
19 | #include "kissrandom.h"
20 | #include "retrieval/retrieval_interface.h"
21 | #include "dict/dict_adapter.h"
22 | 
23 | namespace anyq {
24 | // Semantic index retrieval plugin
25 | class SemanticRetrievalPlugin : public RetrievalPluginInterface {
26 | public:
27 |     SemanticRetrievalPlugin() {}
28 |     virtual ~SemanticRetrievalPlugin() override {}
29 |     virtual int init(DictMap* dict_map, const RetrievalPluginConfig& plugin_config) override;
30 |     virtual int destroy() override;
31 |     virtual int retrieval(const AnalysisResult& analysis_res,
32 |             RetrievalResult& retrieval_res) override;
33 | 
34 | private:
35 |     // path of the annoy index file
36 |     std::string _index_path;
37 |     // dimension of the query semantic vector
38 |     uint32_t _vector_size;
39 |     uint32_t _search_k;
40 |     // annoy semantic index (template arguments assumed; they were lost in extraction)
41 |     AnnoyIndex<int, float, Angular, Kiss32Random>* _annoy_index;
42 |     // knowledge base mapping ids to candidates
43 |     String2RetrievalItemAdapter* _knowledge_dict;
44 |     DISALLOW_COPY_AND_ASSIGN(SemanticRetrievalPlugin);
45 | };
46 | 
47 | } // namespace anyq
48 | #endif // BAIDU_NLP_ANYQ_SEMANTIC_RETRIEVAL_H
49 | 
--------------------------------------------------------------------------------
/include/matching/semantic/simnet_tf_sim.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
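// Usage sketch (assumed from the interface below; tf_dict_wrapper and
// matching_conf are placeholders): a matching plugin is initialized once per
// thread and then scores each batch of retrieval candidates.
//
//     TFSimilarity sim;
//     sim.init(tf_dict_wrapper, matching_conf);          // dict holds the TF graph
//     sim.compute_similarity(analysis_res, candidates);  // appends one score feature
//     sim.destroy();                                     // per candidate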
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | 
15 | #ifndef BAIDU_NLP_ANYQ_SIMNET_TF_SIM_H
16 | #define BAIDU_NLP_ANYQ_SIMNET_TF_SIM_H
17 | 
18 | #ifdef USE_TENSORFLOW
19 | #include "matching/matching_interface.h"
20 | 
21 | namespace anyq {
22 | // TensorFlow-based semantic similarity feature
23 | class TFSimilarity : public MatchingInterface {
24 | public:
25 |     TFSimilarity();
26 |     virtual ~TFSimilarity() override;
27 |     virtual int init(DualDictWrapper* dict, const MatchingConfig& matching_config) override;
28 |     virtual int destroy() override;
29 |     virtual int compute_similarity(const AnalysisResult& analysis_res,
30 |             RankResult& candidates) override;
31 |     // convert to TensorFlow inputs (template arguments below assumed; lost in extraction)
32 |     int trans_to_tf_input(
33 |             std::vector<std::pair<std::string, tensorflow::Tensor> >& input,
34 |             std::string input_name,
35 |             std::vector<int>& vec);
36 | private:
37 |     tensorflow::Session* _p_session;
38 |     TFPack* _p_tf_pack;
39 |     std::string _left_name;
40 |     std::string _right_name;
41 |     std::string _output_tensor_name;
42 |     int _pad_id;
43 |     int _sen_len;
44 |     DISALLOW_COPY_AND_ASSIGN(TFSimilarity);
45 | };
46 | 
47 | } // namespace anyq
48 | #endif
49 | 
50 | #endif // BAIDU_NLP_ANYQ_SIMNET_TF_SIM_H
51 | 
--------------------------------------------------------------------------------
/tools/simnet/train/paddle/optimizers/paddle_optimizers.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | 
3 | # Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | 
17 | import paddle.fluid as fluid
18 | 
19 | 
20 | class SGDOptimizer(object):
21 |     """
22 |     SGD
23 |     """
24 |     def __init__(self, conf_dict):
25 |         """
26 |         initialize
27 |         """
28 |         self.learning_rate = conf_dict["optimizer"]["learning_rate"]
29 | 
30 |     def ops(self, loss):
31 |         """
32 |         SGD optimizer operation
33 |         """
34 |         sgd = fluid.optimizer.SGDOptimizer(self.learning_rate)
35 |         sgd.minimize(loss)
36 | 
37 | 
38 | class AdamOptimizer(object):
39 |     """
40 |     Adam
41 |     """
42 |     def __init__(self, conf_dict):
43 |         """
44 |         initialize
45 |         """
46 |         self.learning_rate = conf_dict["optimizer"]["learning_rate"]
47 |         self.beta1 = conf_dict["optimizer"]["beta1"]
48 |         self.beta2 = conf_dict["optimizer"]["beta2"]
49 |         self.epsilon = conf_dict["optimizer"]["epsilon"]
50 | 
51 |     def ops(self, loss):
52 |         """
53 |         Adam optimizer operation
54 |         """
55 |         adam = fluid.optimizer.AdamOptimizer(
56 |             self.learning_rate, beta1=self.beta1, beta2=self.beta2, epsilon=self.epsilon)
57 |         adam.minimize(loss)
58 | 
--------------------------------------------------------------------------------
/include/dict/dict_interface.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | 
15 | #ifndef BAIDU_NLP_ANYQ_DICT_INTERFACE_H
16 | #define BAIDU_NLP_ANYQ_DICT_INTERFACE_H
17 | 
18 | #include <string>
19 | 
20 | #include "anyq.pb.h"
21 | #include "common/common_define.h"
22 | 
23 | namespace anyq
24 | {
25 | 
26 | // Dictionary interface
27 | class DictInterface {
28 | public:
29 |     // reload is unsupported by default
30 |     DictInterface() {
31 |         _support_reload = false;
32 |         _dict = NULL;
33 |     }
34 | 
35 |     // load the dictionary from the given path and config
36 |     virtual int load(const std::string& path, const DictConfig& config) = 0;
37 | 
38 |     virtual int release() = 0;
39 | 
40 |     // get the dictionary
41 |     void* get_dict() {
42 |         return _dict;
43 |     }
44 | 
45 |     // whether reload is supported
46 |     bool support_reload() {
47 |         return _support_reload;
48 |     }
49 | 
50 |     virtual ~DictInterface() {
51 |     }
52 | 
53 | protected:
54 |     // set whether reload is supported
55 |     void set_support_reload(const bool& support_reload) {
56 |         _support_reload = support_reload;
57 |     }
58 | 
59 |     void set_dict(void* dict) {
60 |         _dict = dict;
61 |     }
62 | 
63 | private:
64 |     bool _support_reload;
65 |     void* _dict;
66 |     DISALLOW_COPY_AND_ASSIGN(DictInterface);
67 | };
68 | 
69 | } // namespace anyq
70 | #endif // BAIDU_NLP_ANYQ_DICT_INTERFACE_H
71 | 
--------------------------------------------------------------------------------
/include/dict/dual_dict_wrapper.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | 
15 | #ifndef BAIDU_NLP_ANYQ_DUAL_DICT_WRAPPER_H
16 | #define BAIDU_NLP_ANYQ_DUAL_DICT_WRAPPER_H
17 | 
18 | #include <mutex>
19 | #include <string>
20 | 
21 | #include "dict/dict_interface.h"
22 | 
23 | namespace anyq
24 | {
25 | // Wraps DictInterface; implements the reload mechanism via a dual dictionary
26 | class DualDictWrapper {
27 | public:
28 |     // keeps the dict path and config
29 |     DualDictWrapper(const std::string& conf_path, const DictConfig& config);
30 | 
31 |     int reload();
32 | 
33 |     int release();
34 | 
35 |     DictInterface* get();
36 | 
37 |     void* get_dict();
38 | 
39 |     ~DualDictWrapper();
40 | 
41 |     bool is_reload_able() {
42 |         return _reload_able;
43 |     }
44 | 
45 |     std::string get_dict_name() {
46 |         return _dict_name;
47 |     }
48 | 
49 | public:
50 |     std::string _dict_name;
51 |     std::string _dict_path;
52 |     DictConfig _config;
53 |     bool _reload_able;
54 | 
55 |     // the two dictionary copies
56 |     DictInterface* _dual_dict[2];
57 | 
58 |     // dictionary identifier; a change in it triggers a reload
59 |     std::string _last_identifier;
60 | 
61 |     // mutex
62 |     std::mutex _mutex;
63 | 
64 |     // index of the dictionary currently in use
65 |     uint8_t _cur_dict;
66 | };
67 | 
68 | } // namespace anyq
69 | #endif // BAIDU_NLP_ANYQ_DUAL_DICT_WRAPPER_H
70 | 
--------------------------------------------------------------------------------
/src/analysis/method_query_intervene.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
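// The intervention dictionary used below is a reloadable hashmap_str2str: when
// an incoming query matches a key exactly, it is replaced by the mapped value.
// A dict entry therefore pairs a raw query with its rewrite, conceptually
// (illustrative content only; the tab-separated format is an assumption):
//
//     how do i sign up	how to register an account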
14 | 15 | #include "analysis/method_query_intervene.h" 16 | #include "common/utils.h" 17 | 18 | namespace anyq{ 19 | 20 | AnalysisQueryIntervene::AnalysisQueryIntervene(){ 21 | } 22 | 23 | AnalysisQueryIntervene::~AnalysisQueryIntervene(){ 24 | 25 | } 26 | int AnalysisQueryIntervene::init(DualDictWrapper* dict, const AnalysisMethodConfig& analysis_method){ 27 | _dict = dict; 28 | TRACE_LOG("analysis_method_query_intervene init"); 29 | set_method_name(analysis_method.name()); 30 | return 0; 31 | 32 | }; 33 | 34 | int AnalysisQueryIntervene::destroy(){ 35 | TRACE_LOG("destroy analysis_query_intervene"); 36 | return 0; 37 | 38 | }; 39 | 40 | int AnalysisQueryIntervene::single_process(AnalysisItem& analysis_item) { 41 | TRACE_LOG("method_process analysis_method_query_intervene, query is: %s", analysis_item.query.c_str()); 42 | // 每次调用,获取最新词典 43 | hashmap_str2str* tmp_dict = (hashmap_str2str*)(_dict->get_dict()); 44 | if (tmp_dict->count(analysis_item.query) >= 1) { 45 | analysis_item.query = (*tmp_dict)[analysis_item.query]; 46 | } 47 | TRACE_LOG("method_process analysis_method_query_intervene, query is: %s", analysis_item.query.c_str()); 48 | return 0; 49 | 50 | }; 51 | 52 | } // namespace anyq 53 | -------------------------------------------------------------------------------- /include/analysis/method_interface.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef BAIDU_NLP_ANYQ_ANALYSIS_METHOD_INTERFACE_H 16 | #define BAIDU_NLP_ANYQ_ANALYSIS_METHOD_INTERFACE_H 17 | 18 | #include 19 | #include 20 | 21 | #include "common/common_define.h" 22 | #include "anyq.pb.h" 23 | #include "dict/dict_adapter.h" 24 | #include "dict/dual_dict_wrapper.h" 25 | 26 | namespace anyq { 27 | 28 | class AnalysisMethodInterface { 29 | public: 30 | AnalysisMethodInterface() { 31 | }; 32 | virtual ~AnalysisMethodInterface() { 33 | }; 34 | 35 | virtual int init(DualDictWrapper* dict, const AnalysisMethodConfig& analysis_method) = 0; 36 | virtual int destroy() = 0; 37 | 38 | // 处理query(可能包含多个,如改写或扩展会增加query) 39 | virtual int method_process(AnalysisResult& analysis_result); 40 | 41 | // 处理单个query 42 | virtual int single_process(AnalysisItem& analysis_item); 43 | 44 | std::string get_method_name() { 45 | return _method_name; 46 | } 47 | 48 | protected: 49 | void set_method_name(const std::string& method_name) { 50 | _method_name = method_name; 51 | } 52 | 53 | private: 54 | std::string _method_name; 55 | DISALLOW_COPY_AND_ASSIGN(AnalysisMethodInterface); 56 | }; 57 | 58 | } // namespace anyq 59 | 60 | #endif //BAIDU_NLP_ANYQ_ANALYSIS_METHOD_INTERFACE_H 61 | -------------------------------------------------------------------------------- /src/matching/lexical/cosine_sim.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | 
15 | #include "matching/lexical/cosine_sim.h"
16 | 
17 | namespace anyq {
18 | 
19 | CosineSimilarity::CosineSimilarity() {
20 | }
21 | 
22 | CosineSimilarity::~CosineSimilarity() {
23 |     destroy();
24 | }
25 | 
26 | int CosineSimilarity::init(DualDictWrapper* dict, const MatchingConfig& matching_config) {
27 |     return 0;
28 | }
29 | 
30 | int CosineSimilarity::destroy() {
31 |     return 0;
32 | }
33 | 
34 | int CosineSimilarity::compute_similarity(const AnalysisResult& analysis_res, RankResult& candidates) {
35 |     if (analysis_res.analysis.size() < 1) {
36 |         return -1;
37 |     }
38 | 
39 |     // the argument is const and computing the similarity needs to sort, so make a copy
40 |     auto tmp_analysis_tokens = analysis_res.analysis[0].tokens_basic;  // element type elided in the dump; auto preserves the copy
41 |     for (size_t i = 0; i < candidates.size(); i++) {
42 |         // invalid candidate, skip it
43 |         if (candidates[i].abandoned) {
44 |             continue;
45 |         }
46 |         auto tmp_matching_tokens = candidates[i].match_info.tokens_basic;
47 |         // cosine similarity
48 |         float cos_sim = cosine_similarity(tmp_analysis_tokens, tmp_matching_tokens);
49 |         DEBUG_LOG("cos_sim %f", cos_sim);
50 |         candidates[i].features.push_back(cos_sim);
51 |     }
52 | 
53 |     return 0;
54 | }
55 | 
56 | } // namespace anyq
57 | 
--------------------------------------------------------------------------------
/include/dict/dict_manager.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
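// Usage sketch (assumed from the interface below; the conf path is a
// placeholder): the manager loads every configured dictionary, hands out the
// map, and keeps a background reload thread running until released.
//
//     DictManager dm;
//     dm.load_dict("./example/conf");    // reads the dict config under this path
//     DictMap* dicts = dm.get_dict();    // name -> DualDictWrapper*
//     // ... hand dicts to strategies and plugins ...
//     dm.release_dict();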
14 | 
15 | #ifndef BAIDU_NLP_ANYQ_DICT_MANAGER_H
16 | #define BAIDU_NLP_ANYQ_DICT_MANAGER_H
17 | #include <mutex>
18 | #include <string>
19 | #include <thread>
20 | #include <unordered_map>
21 | 
22 | #include "anyq.pb.h"
23 | #include "dict/dict_interface.h"
24 | #include "dict/dual_dict_wrapper.h"
25 | #include "common/utils.h"
26 | 
27 | namespace anyq
28 | {
29 | typedef std::string DictName;
30 | typedef std::unordered_map<DictName, DualDictWrapper*> DictMap;
31 | 
32 | // Dictionary manager: loads/reloads dictionaries and hands them out
33 | class DictManager {
34 | public:
35 |     DictManager();
36 | 
37 |     // load all dictionaries from the given config path
38 |     int load_dict(const std::string conf_path);
39 | 
40 |     // release the dictionaries
41 |     int release_dict();
42 | 
43 |     // get all dictionaries
44 |     DictMap* get_dict();
45 | 
46 |     ~DictManager();
47 |     // reload loop
48 |     void reload_func();
49 | 
50 |     bool is_dm_released();
51 | 
52 | private:
53 |     // dictionary map
54 |     DictMap* _all_dict;
55 | 
56 |     // reload thread
57 |     std::thread _reload_thread;
58 | 
59 |     // whether any dictionary needs reloading
60 |     bool _need_reload;
61 | 
62 |     // whether the dictionaries have been released
63 |     bool _dm_released;
64 | 
65 |     std::mutex _mutex;
66 | 
67 |     DISALLOW_COPY_AND_ASSIGN(DictManager);
68 | };
69 | 
70 | } // namespace anyq
71 | 
72 | 
73 | #endif // BAIDU_NLP_ANYQ_DICT_MANAGER_H
74 | 
--------------------------------------------------------------------------------
/include/retrieval/retrieval_interface.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | 
15 | #ifndef BAIDU_NLP_ANYQ_RETRIEVAL_INTERFACE_H
16 | #define BAIDU_NLP_ANYQ_RETRIEVAL_INTERFACE_H
17 | 
18 | #include "anyq.pb.h"
19 | #include "common/utils.h"
20 | #include "dict/dict_interface.h"
21 | #include "dict/dict_manager.h"
22 | 
23 | namespace anyq {
24 | 
25 | // Retrieval plugin interface
26 | class RetrievalPluginInterface {
27 | public:
28 |     RetrievalPluginInterface() {}
29 |     virtual ~RetrievalPluginInterface() {}
30 |     // base-class initialization
31 |     int init_base(const std::string& plugin_name, const int& num_result) {
32 |         _plugin_name = plugin_name;
33 |         _num_result = num_result;
34 |         return 0;
35 |     }
36 | 
37 |     const std::string& plugin_name() {
38 |         return _plugin_name;
39 |     }
40 | 
41 |     uint32_t get_num_result() {
42 |         return _num_result;
43 |     }
44 |     // initialize the plugin's thread resources
45 |     virtual int init(DictMap* dict_map, const RetrievalPluginConfig& plugin_config) = 0;
46 |     // destroy the plugin's thread resources
47 |     virtual int destroy() = 0;
48 |     // retrieve candidates based on the query analysis result
49 |     virtual int retrieval(const AnalysisResult& analysis_res, RetrievalResult& retrieval_res) = 0;
50 | 
51 | private:
52 |     std::string _plugin_name;
53 |     // number of candidates this plugin recalls
54 |     uint32_t _num_result;
55 |     DISALLOW_COPY_AND_ASSIGN(RetrievalPluginInterface);
56 | };
57 | 
58 | } // namespace anyq
59 | 
60 | #endif // BAIDU_NLP_ANYQ_RETRIEVAL_INTERFACE_H
61 | 
--------------------------------------------------------------------------------
/src/retrieval/term/equal_solr_q_builder.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | 
15 | #include "retrieval/term/equal_solr_q_builder.h"
16 | 
17 | namespace anyq {
18 | 
19 | int EqualSolrQBuilder::init(DualDictWrapper* dict, const SolrQConfig& solr_q_config) {
20 |     init_base(solr_q_config.name(), solr_q_config.solr_field(), solr_q_config.source_name());
21 |     return 0;
22 | }
23 | 
24 | int EqualSolrQBuilder::make_q(const AnalysisResult& analysis_res,
25 |         int analysis_idx,
26 |         std::string& q) {
27 |     q = "";
28 |     const std::string& source_name = get_source_name();
29 |     if (source_name == "question") {
30 |         q += get_solr_field();
31 |         q.append(":");
32 |         q += escape(analysis_res.analysis[analysis_idx].query);
33 |     } else {
34 |         // take the value from the info map
35 |         std::map<std::string, std::string>::const_iterator it;
36 |         it = analysis_res.info.find(source_name);
37 |         std::string field_value = "";
38 |         if (it != analysis_res.info.end()) {
39 |             field_value = it->second;  // escaped once, below
40 |         } else {
41 |             FATAL_LOG("search field[%s] does not exist in analysis info_map", source_name.c_str());
42 |             return -1;
43 |         }
44 |         q += get_solr_field();
45 |         q.append(":");
46 |         q += escape(field_value);
47 |     }
48 |     DEBUG_LOG("equal solr_fetch_q=%s", q.c_str());
49 | 
50 |     return 0;
51 | }
52 | 
53 | } // namespace anyq
54 | 
--------------------------------------------------------------------------------
/include/common/plugin_factory.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
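// Usage sketch (MySimPlugin is a hypothetical plugin type): plugins register a
// creation callback at static-initialization time through REGISTER_PLUGIN and
// are later instantiated by type name.
//
//     REGISTER_PLUGIN(MySimPlugin);  // in the plugin's .cpp file
//
//     void* p = PLUGIN_FACTORY.create_plugin("MySimPlugin");
//     MySimPlugin* plugin = static_cast<MySimPlugin*>(p);
//     // ... use the plugin; the caller must delete it, the factory will not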
14 | 
15 | #ifndef BAIDU_NLP_ANYQ_PLUGIN_FACTORY_H
16 | #define BAIDU_NLP_ANYQ_PLUGIN_FACTORY_H
17 | 
18 | #include "common/utils.h"
19 | 
20 | namespace anyq
21 | {
22 | // callback function pointer that creates a plugin
23 | typedef void* (*PluginCreateFunc)();
24 | // map from plugin type to its creation callback
25 | typedef std::unordered_map<std::string, PluginCreateFunc> PluginMap;
26 | 
27 | class PluginFactory {
28 | public:
29 |     // register a plugin creation callback
30 |     int register_plugin(std::string plugin_type, PluginCreateFunc create_func);
31 |     // create a plugin instance by type; the caller owns and must destroy it, not the factory
32 |     void* create_plugin(std::string plugin_type);
33 | 
34 |     static PluginFactory& instance();
35 | 
36 | private:
37 |     PluginMap _plugin_map;
38 | };
39 | 
40 | #define PLUGIN_FACTORY anyq::PluginFactory::instance()
41 | 
42 | class Register {
43 | public:
44 |     Register(std::string plugin_type, PluginCreateFunc func) {
45 |         PLUGIN_FACTORY.register_plugin(plugin_type, func);
46 |     }
47 | };
48 | 
49 | #define REGISTER_PLUGIN(plugin_type) \
50 | namespace anyq { \
51 | class plugin_type##Register { \
52 | public: \
53 |     static void* newInstance() { \
54 |         return new plugin_type; \
55 |     } \
56 | private: \
57 |     static const Register reg; \
58 | };\
59 | const Register plugin_type##Register::reg(#plugin_type, \
60 |         plugin_type##Register::newInstance); \
61 | }
62 | 
63 | } // namespace anyq
64 | 
65 | #endif // BAIDU_NLP_ANYQ_PLUGIN_FACTORY_H
66 | 
--------------------------------------------------------------------------------
/src/server/anyq_postprocessor.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
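// For reference, process() below serializes the ANYQResult items into a JSON
// array of the following shape (field names taken from the code; values are
// illustrative):
//
//     [{"question": "...", "answer": "...", "confidence": 0.92,
//       "qa_id": ..., "json_info": "..."}]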
14 | 15 | #include "server/anyq_postprocessor.h" 16 | #include "common/utils.h" 17 | 18 | namespace anyq { 19 | 20 | int AnyqPostprocessor::init(const ServerConfig& config) { 21 | if (!config.postproc_plugin().has_type()) { 22 | FATAL_LOG("ReqPostprocPluginConfig.type unset!"); 23 | return -1; 24 | } 25 | if (!config.postproc_plugin().has_name()) { 26 | FATAL_LOG("ReqPostprocPluginConfig.name unset!"); 27 | return -1; 28 | } 29 | init_base(config.postproc_plugin().name()); 30 | return 0; 31 | } 32 | 33 | int AnyqPostprocessor::destroy() { 34 | return 0; 35 | } 36 | 37 | int AnyqPostprocessor::process(ANYQResult& any_result, 38 | Json::Value& parameters, 39 | std::string& output) { 40 | // do nothing 41 | Json::Value json_anyq_res = Json::Value(Json::arrayValue); 42 | for (size_t i = 0; i < any_result.items.size(); i++) { 43 | Json::Value result_item; 44 | result_item["question"] = any_result.items[i].query; 45 | result_item["answer"] = any_result.items[i].answer; 46 | result_item["confidence"] = any_result.items[i].confidence; 47 | result_item["qa_id"] = any_result.items[i].qa_id; 48 | result_item["json_info"] = any_result.items[i].json_info; 49 | json_anyq_res.append(result_item); 50 | } 51 | output = json_dumps(json_anyq_res); 52 | return 0; 53 | } 54 | 55 | } // namespace anyq 56 | -------------------------------------------------------------------------------- /src/dict/tf_model_adapter.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 
15 | #ifdef USE_TENSORFLOW
16 | #include "dict/dict_adapter.h"
17 | #include "common/utils.h"
18 | 
19 | namespace anyq {
20 | 
21 | TFModelAdapter::TFModelAdapter() {}
22 | 
23 | TFModelAdapter::~TFModelAdapter() {}
24 | 
25 | int TFModelAdapter::load(const std::string& path, const DictConfig& config) {
26 |     TFPack* p_tf_pack = new TFPack();
27 |     std::string dict_path = path + "/term2id";
28 |     int ret = hash_load(dict_path.c_str(), p_tf_pack->term2id);
29 |     if (ret != 0) {
30 |         FATAL_LOG("term2id dict load error");
31 |         return -1;
32 |     }
33 |     // after TF training, the graph and the weights are merged into a single file
34 |     std::string model_path = path + "/tf.graph";
35 |     tensorflow::Status status_load = tensorflow::ReadBinaryProto(
36 |             tensorflow::Env::Default(),
37 |             model_path,
38 |             &p_tf_pack->graphdef);
39 |     if (!status_load.ok()) {
40 |         FATAL_LOG("load tensorflow model error");
41 |         return -1;
42 |     }
43 |     TRACE_LOG("tf model load success");
44 |     set_dict((void*)p_tf_pack);
45 |     return 0;
46 | }
47 | 
48 | int TFModelAdapter::release() {
49 |     void* dict = get_dict();
50 |     if (dict != NULL) {
51 |         TFPack* p_tf_pack = static_cast<TFPack*>(dict);
52 |         delete p_tf_pack;
53 |         set_dict(NULL);
54 |     }
55 |     return 0;
56 | }
57 | 
58 | } // namespace anyq
59 | 
60 | #endif
61 | 
--------------------------------------------------------------------------------
/tools/simnet/train/tf/data/train_pairwise_data:
--------------------------------------------------------------------------------
1 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
2 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
3 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
4 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
5 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
6 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
7 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
8 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
9 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
10 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
11 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
12 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
13 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
14 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
15 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
16 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
17 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
18 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
19 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
20 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
21 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
22 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
23 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
24 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
25 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
26 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
27 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
28 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
29 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
30 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
31 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
32 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
33 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
34 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
35 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
36 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
37 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
38 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
39 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
40 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
41 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
42 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
43 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
44 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
45 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
46 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
47 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
48 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
49 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
50 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
51 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
52 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
53 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
54 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
55 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
56 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
57 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
58 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
59 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
60 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
61 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
62 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
63 | 
--------------------------------------------------------------------------------
/tools/simnet/train/paddle/data/train_pairwise_data:
--------------------------------------------------------------------------------
1 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
2 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
3 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
4 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
5 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
6 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
7 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
8 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
9 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
10 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
11 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
12 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
13 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
14 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
15 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
16 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
17 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
18 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
19 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
20 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
21 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
22 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
23 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
24 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
25 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
26 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
27 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
28 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
29 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
30 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
31 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
32 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
33 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
34 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
35 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
36 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
37 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
38 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
39 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
40 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
41 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
42 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
43 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
44 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
45 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
46 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
47 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
48 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
49 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
50 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
51 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
52 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
53 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
54 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
55 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
56 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
57 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
58 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
59 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
60 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
61 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
62 | 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2
63 | 
--------------------------------------------------------------------------------
/include/retrieval/term/solr_q_interface.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
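// Illustrative make_q() output (assumed; compare EqualSolrQBuilder and
// DateCompareSolrQBuilder in src/retrieval/term): each builder contributes one
// Solr clause for its configured field, for example
//
//     question:<escaped query>    (equal builder)
//     date:[2018-01-01 TO *]      (date-compare builder, compare_type "before")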
14 | 
15 | #ifndef BAIDU_NLP_ANYQ_SOLR_Q_INTERFACE_H
16 | #define BAIDU_NLP_ANYQ_SOLR_Q_INTERFACE_H
17 | 
18 | #include "anyq.pb.h"
19 | #include "common/utils.h"
20 | #include "dict/dict_interface.h"
21 | #include "dict/dual_dict_wrapper.h"
22 | 
23 | namespace anyq {
24 | // Solr query expression builder plugin
25 | class SolrQInterface {
26 | public:
27 |     SolrQInterface() {}
28 |     virtual ~SolrQInterface() {}
29 |     int init_base(const std::string& plugin_name,
30 |             const std::string& solr_field,
31 |             const std::string& source_name) {
32 |         _plugin_name = plugin_name;
33 |         _solr_field = solr_field;
34 |         _source_name = source_name;
35 |         return 0;
36 |     }
37 |     // thread initialization
38 |     virtual int init(DualDictWrapper* dict, const SolrQConfig& solr_q_config) = 0;
39 |     // build the Solr query expression
40 |     virtual int make_q(const AnalysisResult& analysis_res, int analysis_idx, std::string& q) = 0;
41 | 
42 |     const std::string& plugin_name() {
43 |         return _plugin_name;
44 |     }
45 |     const std::string& get_solr_field() {
46 |         return _solr_field;
47 |     }
48 |     const std::string& get_source_name() {
49 |         return _source_name;
50 |     }
51 | 
52 | private:
53 |     std::string _plugin_name;
54 |     // Solr search field
55 |     std::string _solr_field;
56 |     // which part of the analysis result is used for retrieval
57 |     std::string _source_name;
58 |     DISALLOW_COPY_AND_ASSIGN(SolrQInterface);
59 | };
60 | 
61 | } // namespace anyq
62 | 
63 | #endif // BAIDU_NLP_ANYQ_SOLR_Q_INTERFACE_H
64 | 
--------------------------------------------------------------------------------
/include/common/paddle_thread_resource.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved.
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
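// Usage sketch (assumed from the interface below; the feed/fetch indices
// depend on the model configuration):
//
//     PaddleThreadResource res;
//     res.init(p_paddle_pack);              // per-thread copy of the program
//     res.set_feed(0, token_ids);           // e.g. the "left"/query input slot
//     res.run();
//     const float* out = res.get_fetch(0);  // e.g. the embedding or score output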
14 | 
15 | #ifndef BAIDU_NLP_ANYQ_PADDLE_THREAD_RESOURCE_H
16 | #define BAIDU_NLP_ANYQ_PADDLE_THREAD_RESOURCE_H
17 | 
18 | #ifndef USE_TENSORFLOW
19 | #include "dict/dict_adapter.h"
20 | 
21 | namespace anyq {
22 | 
23 | class PaddleThreadResource {
24 | public:
25 |     PaddleThreadResource() {
26 |         _p_paddle_pack = NULL;
27 |     }
28 |     ~PaddleThreadResource() {}
29 |     int init(PaddlePack* p_paddle_pack);
30 |     int destroy() {
31 |         return 0;
32 |     }
33 |     int run();
34 |     int set_feed(const size_t& index, const std::vector<int64_t>& ids);  // element type assumed
35 |     const float* get_fetch(const size_t& index);
36 | 
37 |     const std::vector<std::string>& get_feed_target_names() {
38 |         return _feed_target_names;
39 |     }
40 | 
41 |     const std::vector<std::string>& get_fetch_target_names() {
42 |         return _fetch_target_names;
43 |     }
44 | 
45 | private:
46 |     PaddlePack* _p_paddle_pack;
47 |     std::unique_ptr<paddle::framework::ProgramDesc> _copy_program;  // type assumed; lost in extraction
48 |     std::vector<std::string> _feed_target_names;
49 |     std::vector<std::string> _fetch_target_names;
50 |     std::string _feed_holder_name;
51 |     std::string _fetch_holder_name;
52 |     std::vector<paddle::framework::LoDTensor> _feeds;  // type assumed
53 |     std::vector<paddle::framework::LoDTensor> _fetchs;  // type assumed
54 |     std::map<std::string, const paddle::framework::LoDTensor*> _feed_targets;  // type assumed
55 |     std::map<std::string, paddle::framework::LoDTensor*> _fetch_targets;  // type assumed
56 |     DISALLOW_COPY_AND_ASSIGN(PaddleThreadResource);
57 | };
58 | 
59 | } // namespace anyq
60 | #endif
61 | 
62 | #endif
63 | 
--------------------------------------------------------------------------------
/tools/solr/solr_deply.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | #set -e # set -o errexit
4 | set -u # set -o nounset
5 | set -o pipefail
6 | 
7 | readonly SOLR_STOP_KEY="secret_key"
8 | readonly START="start"
9 | readonly STOP="stop"
10 | 
11 | function help() {
12 |     echo "usage: sh solr_deply.sh start solr_home solr_port"
13 |     echo "       sh solr_deply.sh stop solr_home solr_port"
14 | }
15 | 
16 | # start solr service
17 | function solr_start() {
18 |     if [[ $# -ne 2 ]]; then
19 |         echo "usage: solr_start solr_home solr_port"
20 |         return 1
21 |     fi
22 |     local solr_home=$1
23 |     local solr_port=$2
24 |     # start solr
25 |     nohup java \
26 |         -DSTOP.PORT=$((solr_port+1)) \
27 |         -DSTOP.KEY=$SOLR_STOP_KEY \
28 |         -Djetty.port=$solr_port \
29 |         -Dsolr.solr.home=${solr_home}/example/solr/ \
30 |         -Djetty.home=${solr_home}/example/ \
31 |         -jar ${solr_home}/example/start.jar &
32 | 
33 |     # check solr service
34 |     sleep 20s # time needed for solr to start; may need to be longer
35 |     curl "http://localhost:$solr_port/solr"
36 |     if [[ $? -ne 0 ]]; then
37 |         echo "solr[$solr_port] start failed!"
38 |         res=`solr_stop $solr_home $solr_port`
39 |         return 1
40 |     fi
41 |     echo "solr[$solr_port] start success!"
42 |     return 0
43 | }
44 | 
45 | # stop solr service
46 | function solr_stop() {
47 |     if [[ $# -ne 2 ]]; then
48 |         echo "usage: solr_stop solr_home solr_port"
49 |         return 1
50 |     fi
51 |     local solr_home=$1
52 |     local solr_port=$2
53 |     java \
54 |         -DSTOP.PORT=$((solr_port+1)) \
55 |         -DSTOP.KEY=$SOLR_STOP_KEY \
56 |         -jar ${solr_home}/example/start.jar --stop
57 |     if [[ $? -ne 0 ]]; then
58 |         echo "solr[$solr_port] stop failed!"
59 |         return 1
60 |     fi
61 |     echo "solr[$solr_port] stop success!"
62 | 63 | return 0 64 | } 65 | 66 | # main 67 | function main() { 68 | if [[ $# -ne 3 ]]; then 69 | help 70 | return 1 71 | fi 72 | 73 | if [[ $1 == ${START} ]]; then 74 | solr_start $2 $3 75 | elif [[ $1 == ${STOP} ]]; then 76 | solr_stop $2 $3 77 | else 78 | help 79 | return 1 80 | fi 81 | return 0 82 | } 83 | 84 | main "$@" 85 | -------------------------------------------------------------------------------- /include/retrieval/term/term_retrieval.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef BAIDU_NLP_ANYQ_TERM_RETRIEVAL_H 16 | #define BAIDU_NLP_ANYQ_TERM_RETRIEVAL_H 17 | 18 | #include "retrieval/retrieval_interface.h" 19 | #include "common/http_client.h" 20 | #include "retrieval/term/solr_q_interface.h" 21 | 22 | namespace anyq { 23 | // solr-based term retrieval plugin 24 | class TermRetrievalPlugin : public RetrievalPluginInterface 25 | { 26 | public: 27 | TermRetrievalPlugin() {} 28 | virtual ~TermRetrievalPlugin() override {} 29 | virtual int init(DictMap* dict_map, const RetrievalPluginConfig& plugin_config) override; 30 | virtual int destroy() override; 31 | // build the q parameter needed to query solr 32 | int make_fetch_q(const AnalysisResult& analysis_result, int analysis_idx, std::string& q); 33 | // parse the response structure returned by solr 34 | int solr_result_parse(const char* solr_result, RetrievalResult& retrieval_result); 35 | // send the request to solr 36 | int solr_request(const char* q, std::string* buffer); 37 | virtual int retrieval(const AnalysisResult& analysis_res, 38 | RetrievalResult& retrieval_res) override; 39 | 40 | private: 41 | // http client for solr requests 42 | HttpClient _http_client; 43 | // solr server host 44 | std::string _search_host; 45 | // solr server port 46 | int32_t _search_port; 47 | // collection name 48 | std::string _engine_name; 49 | // fields requested from solr (fl parameter) 50 | std::string _solr_request_fl; 51 | // solr_q builders 52 | std::vector<SolrQInterface*> _solr_q_builder; 53 | // fields returned by solr 54 | std::string _solr_result_fl; 55 | DISALLOW_COPY_AND_ASSIGN(TermRetrievalPlugin); 56 | }; 57 | 58 | } // namespace anyq 59 | #endif // BAIDU_NLP_ANYQ_TERM_RETRIEVAL_H 60 | -------------------------------------------------------------------------------- /src/retrieval/term/date_compare_solr_q_builder.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include "retrieval/term/date_compare_solr_q_builder.h" 16 | 17 | namespace anyq { 18 | 19 | int DateCompareSolrQBuilder::init(DualDictWrapper* dict, const SolrQConfig& solr_q_config) { 20 | init_base(solr_q_config.name(), solr_q_config.solr_field(), solr_q_config.source_name()); 21 | _compare_type = solr_q_config.compare_type(); 22 | return 0; 23 | } 24 | 25 | int DateCompareSolrQBuilder::make_q(const AnalysisResult& analysis_res, int analysis_idx, std::string& q) { 26 | q = ""; 27 | std::map<std::string, std::string>::const_iterator it; 28 | const std::string& source_name = get_source_name(); 29 | it = analysis_res.info.find(source_name); 30 | // build the date-range comparison value 31 | std::string field_value = ""; 32 | if (it != analysis_res.info.end()) { 33 | if (_compare_type == "before") { 34 | field_value.append("["); 35 | field_value.append(it->second); 36 | field_value.append(" TO *]"); 37 | } else if (_compare_type == "after") { 38 | field_value.append("[* TO "); 39 | field_value.append(it->second); 40 | field_value.append("]"); 41 | } else { 42 | FATAL_LOG("compare type[%s] is invalid", _compare_type.c_str()); 43 | return -1; 44 | } 45 | 46 | } else { 47 | FATAL_LOG("search field[%s] does not exist in analysis info_map", source_name.c_str()); 48 | return -1; 49 | } 50 | q += get_solr_field(); 51 | q.append(":"); 52 | q += field_value; 53 | 54 | DEBUG_LOG("date compare solr_fetch_q=%s", q.c_str()); 55 | return 0; 56 | } 57 | 58 | } // namespace anyq 59 | -------------------------------------------------------------------------------- /include/server/session_data_factory.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef BAIDU_NLP_ANYQ_SESSION_DATA_FACTORY_H 16 | #define BAIDU_NLP_ANYQ_SESSION_DATA_FACTORY_H 17 | 18 | #include <string> 19 | #include "brpc/data_factory.h" 20 | #include "strategy/anyq_strategy.h" 21 | #include "server/solr_accessor.h" 22 | 23 | namespace anyq { 24 | 25 | // per-thread (session-local) data 26 | class SessionData { 27 | public: 28 | SessionData(); 29 | ~SessionData(); 30 | // initialize anyq_strategy from the global dicts and config 31 | int init(DictManager* ptr_dict_manager, 32 | const std::string& anyq_conf_file, 33 | const std::string& solr_clear_passwd); 34 | AnyqStrategy* get_anyq() { 35 | return &_anyq_strategy; 36 | } 37 | SolrAccessor* get_solr_accessor() { 38 | return _use_solr ?
&_solr_accessor : NULL; 39 | } 40 | 41 | private: 42 | AnyqStrategy _anyq_strategy; 43 | SolrAccessor _solr_accessor; 44 | bool _use_solr; 45 | DISALLOW_COPY_AND_ASSIGN(SessionData); 46 | }; 47 | 48 | class SessionDataFactory : public brpc::DataFactory 49 | { 50 | public: 51 | SessionDataFactory(); 52 | SessionDataFactory(DictManager* ptr_dict_manager, 53 | const std::string& anyq_conf_file, 54 | const std::string& solr_clear_passwd); 55 | ~SessionDataFactory(); 56 | 57 | void* CreateData() const; 58 | void DestroyData(void* session_data) const; 59 | 60 | private: 61 | // dicts are process-level data passed in by pointer; only one copy lives in memory 62 | DictManager* _dict_manager; 63 | std::string _anyq_conf_file; 64 | std::string _solr_clear_passwd; 65 | DISALLOW_COPY_AND_ASSIGN(SessionDataFactory); 66 | }; 67 | 68 | } // namespace anyq 69 | 70 | #endif // BAIDU_NLP_ANYQ_SESSION_DATA_FACTORY_H 71 | -------------------------------------------------------------------------------- /src/matching/lexical/contain_sim.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include "matching/lexical/contain_sim.h" 16 | 17 | namespace anyq { 18 | 19 | ContainSimilarity::ContainSimilarity() { 20 | } 21 | 22 | ContainSimilarity::~ContainSimilarity() { 23 | destroy(); 24 | } 25 | 26 | int ContainSimilarity::init(DualDictWrapper* dict, const MatchingConfig& matching_config) { 27 | return 0; 28 | } 29 | 30 | int ContainSimilarity::destroy() { 31 | return 0; 32 | } 33 | 34 | // check whether the analysis query and the recalled query contain each other; queries shorter than 4 bytes are ignored 35 | bool ContainSimilarity::contain(const AnalysisItem& analysis_item, const RankItem& rank_item) { 36 | if (analysis_item.query.length() < 4) { 37 | return false; 38 | } 39 | if (rank_item.match_info.text.length() < 4) { 40 | return false; 41 | } 42 | if (analysis_item.query.find(rank_item.match_info.text) != std::string::npos) { 43 | return true; 44 | } 45 | if (rank_item.match_info.text.find(analysis_item.query) != std::string::npos) { 46 | return true; 47 | } 48 | return false; 49 | } 50 | 51 | int ContainSimilarity::compute_similarity(const AnalysisResult& analysis_res, RankResult& candidates) { 52 | if (analysis_res.analysis.size() < 1) { 53 | return -1; 54 | } 55 | for (size_t i = 0; i < candidates.size(); i++) { 56 | // skip abandoned candidates 57 | if (candidates[i].abandoned) { 58 | continue; 59 | } 60 | bool fea_value = contain(analysis_res.analysis[0], candidates[i]); 61 | DEBUG_LOG("contain %d", fea_value); 62 | // feature is 1 if one query contains the other, 0 otherwise 63 | candidates[i].features.push_back(fea_value); 64 | } 65 | 66 | return 0; 67 | } 68 | 69 | } // namespace anyq 70 | -------------------------------------------------------------------------------- /cmake/external/zlib.cmake: -------------------------------------------------------------------------------- 1 | INCLUDE(ExternalProject) 2 | 3 | SET(ZLIB_SOURCES_DIR ${THIRD_PARTY_PATH}/zlib) 4 | SET(ZLIB_INSTALL_DIR
${THIRD_PARTY_PATH}/install/zlib) 5 | SET(ZLIB_ROOT ${ZLIB_INSTALL_DIR} CACHE FILEPATH "zlib root directory." FORCE) 6 | SET(ZLIB_INCLUDE_DIR "${ZLIB_INSTALL_DIR}/include" CACHE PATH "zlib include directory." FORCE) 7 | SET(ZLIB_LIBRARIES "${ZLIB_INSTALL_DIR}/lib/libz.a" CACHE FILEPATH "zlib library." FORCE) 8 | 9 | INCLUDE_DIRECTORIES(${ZLIB_INCLUDE_DIR}) 10 | INCLUDE_DIRECTORIES(${THIRD_PARTY_PATH}/install) 11 | 12 | ExternalProject_Add( 13 | extern_zlib 14 | ${EXTERNAL_PROJECT_LOG_ARGS} 15 | GIT_REPOSITORY "https://github.com/madler/zlib.git" 16 | GIT_TAG "v1.2.8" 17 | PREFIX ${ZLIB_SOURCES_DIR} 18 | UPDATE_COMMAND "" 19 | CMAKE_ARGS -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} 20 | -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} 21 | -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} 22 | -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS} 23 | -DCMAKE_INSTALL_PREFIX=${ZLIB_INSTALL_DIR} 24 | -DBUILD_SHARED_LIBS=OFF 25 | -DCMAKE_POSITION_INDEPENDENT_CODE=ON 26 | -DCMAKE_MACOSX_RPATH=ON 27 | -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} 28 | ${EXTERNAL_OPTIONAL_ARGS} 29 | CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${ZLIB_INSTALL_DIR} 30 | -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON 31 | -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} 32 | ) 33 | 34 | ADD_LIBRARY(zlib STATIC IMPORTED GLOBAL) 35 | SET_PROPERTY(TARGET zlib PROPERTY IMPORTED_LOCATION ${ZLIB_LIBRARIES}) 36 | ADD_DEPENDENCIES(zlib extern_zlib) 37 | add_custom_command(TARGET extern_zlib POST_BUILD 38 | COMMAND mkdir -p third_party/lib/ 39 | COMMAND mkdir -p third_party/include/ 40 | COMMAND ${CMAKE_COMMAND} -E copy ${ZLIB_LIBRARIES} third_party/lib/ 41 | COMMAND ${CMAKE_COMMAND} -E copy_directory ${ZLIB_INCLUDE_DIR} third_party/include/ 42 | ) 43 | 44 | LIST(APPEND external_project_dependencies zlib) 45 | 46 | IF(WITH_C_API) 47 | INSTALL(DIRECTORY ${ZLIB_INCLUDE_DIR} DESTINATION third_party/zlib) 48 | INSTALL(FILES ${ZLIB_LIBRARIES} DESTINATION third_party/zlib/lib) 49 | ENDIF() 50 | -------------------------------------------------------------------------------- /src/server/anyq_preprocessor.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
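// Added annotation (not part of the original source): process() below reshapes
// the HTTP request JSON into the analysis input consumed downstream. For an
// assumed request {"question": "how to refund", "uid": "42"} it produces roughly:
//
//   {
//       "question": "how to refund",
//       "analysis_item": {"question": "how to refund", "type": 0},
//       "info": {"uid": "42"}
//   }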
14 | 15 | #include "server/anyq_preprocessor.h" 16 | #include "common/utils.h" 17 | 18 | namespace anyq { 19 | 20 | int AnyqPreprocessor::init(const ReqPreprocPluginConfig& config) { 21 | if (!config.has_type()) { 22 | FATAL_LOG("ReqPreprocPluginConfig.type unset!"); 23 | return -1; 24 | } 25 | if (!config.has_name()) { 26 | FATAL_LOG("ReqPreprocPluginConfig.name unset!"); 27 | return -1; 28 | } 29 | init_base(config.name()); 30 | return 0; 31 | } 32 | 33 | int AnyqPreprocessor::destroy() { 34 | return 0; 35 | } 36 | 37 | int AnyqPreprocessor::process(brpc::Controller* cntl, 38 | Json::Value& parameters, 39 | std::string& str_anyq_input) { 40 | if (!parameters.isMember("question")) { 41 | FATAL_LOG("Query field is required."); 42 | } 43 | Json::Value json_analysis_input; 44 | Json::Value json_analysis_info; 45 | Json::Value::Members mem = parameters.getMemberNames(); 46 | for (Json::Value::Members::iterator it = mem.begin(); it != mem.end(); ++it) { 47 | if (*it == "question") { 48 | json_analysis_input["question"] = parameters["question"].asString(); 49 | std::string debug_str = parameters["question"].asString(); 50 | Json::Value analysis_item; 51 | analysis_item["question"] = parameters["question"].asString(); 52 | analysis_item["type"] = 0; 53 | json_analysis_input["analysis_item"] = analysis_item; 54 | } else { 55 | json_analysis_info[*it] = parameters[*it]; 56 | } 57 | } 58 | json_analysis_input["info"] = json_analysis_info; 59 | str_anyq_input = json_dumps(json_analysis_input); 60 | return 0; 61 | } 62 | 63 | } // namespace anyq 64 | -------------------------------------------------------------------------------- /include/matching/matching_interface.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #ifndef BAIDU_NLP_ANYQ_MATCHING_INTERFACE_H 16 | #define BAIDU_NLP_ANYQ_MATCHING_INTERFACE_H 17 | 18 | #include <string> 19 | #include <vector> 20 | #include "anyq.pb.h" 21 | #include "common/common_define.h" 22 | #include "dict/dict_adapter.h" 23 | #include "dict/dual_dict_wrapper.h" 24 | #include "common/utils.h" 25 | 26 | namespace anyq { 27 | 28 | // matching feature plugin interface; subclasses must implement init, destroy and compute_similarity 29 | class MatchingInterface { 30 | public: 31 | MatchingInterface() {} 32 | virtual ~MatchingInterface() {} 33 | // per-thread resource initialization 34 | virtual int init(DualDictWrapper* dict, const MatchingConfig& matching_config) = 0; 35 | // release per-thread resources 36 | virtual int destroy() = 0; 37 | virtual int compute_similarity(const AnalysisResult& analysis_res, RankResult& candidates) = 0; 38 | // base-class initialization 39 | int init_base(const std::string& feature_name, int output_num, bool rough) { 40 | _feature_name = feature_name; 41 | _output_num = output_num; 42 | _rough = rough; 43 | return 0; 44 | } 45 | 46 | std::string feature_name() { 47 | return _feature_name; 48 | } 49 | 50 | int get_output_num() { 51 | return _output_num; 52 | } 53 | bool is_rough() { 54 | return _rough; 55 | } 56 | 57 | protected: 58 | void set_output_num(int output_num) { 59 | _output_num = output_num; 60 | } 61 | 62 | private: 63 | std::string _feature_name; 64 | // number of feature values; one matching plugin may emit several; 65 | // when output_num is 0 the plugin emits no features and only processes the candidate queries, e.g. segmentation or filtering. 66 | int _output_num; 67 | // whether this feature is used for rough ranking 68 | bool _rough; 69 | DISALLOW_COPY_AND_ASSIGN(MatchingInterface); 70 | }; 71 | 72 | } // namespace anyq 73 | 74 | #endif // BAIDU_NLP_ANYQ_MATCHING_INTERFACE_H 75 | -------------------------------------------------------------------------------- /src/rank/predictor/predict_linear_model.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
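// Worked example (added annotation, not in the original source): with
// feature_names = {"cosine_sim", "bm25_sim"} and a weight dict holding
// cosine_sim=0.5 and bm25_sim=0.2, a candidate with features {0.8, 1.0}
// scores ltr_score = 0.8 * 0.5 + 1.0 * 0.2 = 0.6. Names and values are made up.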
14 | 15 | #include "rank/predictor/predictor_interface.h" 16 | 17 | namespace anyq { 18 | 19 | int PredictLinearModel::init(DualDictWrapper* dict, 20 | const std::vector& feature_names, 21 | const RankPredict& predict_config) { 22 | _feature_weight = (hashmap_str2float*)dict->get_dict(); 23 | int fea_name_cnt = feature_names.size(); 24 | // 初始化权值向量,如果某特征名不在字典里,则权值为0.0 25 | _weights.resize(fea_name_cnt, 0.0f); 26 | for (int i = 0; i < fea_name_cnt; i++) { 27 | if (_feature_weight->count(feature_names[i]) > 0) { 28 | _weights[i] = (*_feature_weight)[feature_names[i]]; 29 | } else { 30 | WARNING_LOG("invalid feature name %s", feature_names[i].c_str()); 31 | _weights[i] = 0.0f; 32 | } 33 | } 34 | return 0; 35 | } 36 | 37 | int PredictLinearModel::destroy() { 38 | return 0; 39 | } 40 | 41 | int PredictLinearModel::predict(RankResult& candidates) { 42 | int weight_cnt = _weights.size(); 43 | for (size_t i = 0; i < candidates.size(); i++) { 44 | // 无效候选,跳过 45 | if (candidates[i].abandoned) { 46 | candidates[i].ltr_score = 0.0f; 47 | continue; 48 | } 49 | float score = 0.0f; 50 | int feature_size = candidates[i].features.size(); 51 | if (feature_size != weight_cnt) { 52 | FATAL_LOG("features size=%d; weights size=%d;", feature_size, weight_cnt); 53 | return -1; 54 | } 55 | //线性加权计算得分 56 | for (int j = 0; j < weight_cnt; j++) { 57 | score += candidates[i].features[j] * _weights[j]; 58 | } 59 | candidates[i].ltr_score = score; 60 | } 61 | return 0; 62 | } 63 | 64 | } // namespace anyq 65 | -------------------------------------------------------------------------------- /src/retrieval/term/synonym_solr_q_builder.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #include "retrieval/term/synonym_solr_q_builder.h" 16 | 17 | namespace anyq { 18 | 19 | int SynonymSolrQBuilder::init(DualDictWrapper* dict, const SolrQConfig& solr_q_config) { 20 | init_base(solr_q_config.name(), solr_q_config.solr_field(), solr_q_config.source_name()); 21 | _p_dual_dict_wrapper = dict; 22 | return 0; 23 | } 24 | 25 | int SynonymSolrQBuilder::term_synonym(const std::string& term, 26 | std::string& synonym_terms) { 27 | hashmap_str2str* black_white_list = (hashmap_str2str*)_p_dual_dict_wrapper->get_dict(); 28 | if (black_white_list->count(term) > 0) { 29 | std::string synonym_str = (*black_white_list)[term]; 30 | std::vector synonym_list; 31 | split_string(synonym_str, synonym_list, "|"); 32 | for (size_t i = 0; i < synonym_list.size(); ++i) { 33 | synonym_terms += escape(synonym_list[i]); 34 | synonym_terms.append(" "); 35 | } 36 | } 37 | return 0; 38 | } 39 | 40 | int SynonymSolrQBuilder::make_q(const AnalysisResult& analysis_res, int analysis_idx, std::string& q) { 41 | q = ""; 42 | std::string synonym_terms = ""; 43 | const std::string& source_name = get_source_name(); 44 | if (source_name == "basic_token") { 45 | for (uint32_t j = 0; j < analysis_res.analysis[analysis_idx].tokens_basic.size(); j++) { 46 | term_synonym(analysis_res.analysis[analysis_idx].tokens_basic[j].buffer, 47 | synonym_terms); 48 | } 49 | } 50 | if (synonym_terms != "") { 51 | q.append("+"); 52 | q += source_name; 53 | q.append(":("); 54 | q += synonym_terms; 55 | q.append(")"); 56 | } 57 | 58 | return 0; 59 | } 60 | 61 | } // namespace anyq 62 | -------------------------------------------------------------------------------- /tools/simnet/train/tf/nets/knrm.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | # Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | import logging 18 | 19 | import layers.tf_layers as layers 20 | 21 | 22 | class KNRM(object): 23 | """ 24 | K-NRM matching network 25 | """ 26 | def __init__(self, config): 27 | self.vocab_size = int(config['vocabulary_size']) 28 | self.emb_size = int(config['embedding_dim']) 29 | self.kernel_num = int(config['kernel_num']) 30 | self.left_name, self.seq_len1 = config['left_slots'][0] 31 | self.right_name, self.seq_len2 = config['right_slots'][0] 32 | self.lamb = float(config['lamb']) 33 | self.task_mode = config['training_mode'] 34 | self.emb_layer = layers.EmbeddingLayer(self.vocab_size, self.emb_size) 35 | self.sim_mat_layer = layers.SimilarityMatrixLayer() 36 | self.kernel_pool_layer = layers.KernelPoolingLayer(self.kernel_num, self.lamb) 37 | self.tanh_layer = layers.TanhLayer() 38 | if self.task_mode == "pointwise": 39 | self.n_class = int(config['n_class']) 40 | self.fc_layer = layers.FCLayer(self.kernel_num, self.n_class) 41 | elif self.task_mode == "pairwise": 42 | self.fc_layer = layers.FCLayer(self.kernel_num, 1) 43 | else: 44 | logging.error("training mode not supported") 45 | 46 | def predict(self, left_slots, right_slots): 47 | """ 48 | predict graph of this net 49 | """ 50 | left = left_slots[self.left_name] 51 | right = right_slots[self.right_name] 52 | left_emb = self.emb_layer.ops(left) 53 | right_emb = self.emb_layer.ops(right) 54 | sim_mat = self.sim_mat_layer.ops(left_emb, right_emb) 55 | feats = self.kernel_pool_layer.ops(sim_mat) 56 | pred = self.fc_layer.ops(feats) 57 | return pred 58 | -------------------------------------------------------------------------------- /tools/simnet/train/tf/losses/simnet_loss.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | # Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License.
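# Added annotation (not part of the original source): a quick numeric check of
# PairwiseHingeLoss below. With margin = 0.1, score_pos = 0.8, score_neg = 0.5,
# the loss is max(0, 0.5 + 0.1 - 0.8) = 0, i.e. the pair is already separated
# by the margin; with score_neg = 0.75 it would be 0.05. Values are made up.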
16 | 17 | import tensorflow as tf 18 | 19 | 20 | class PairwiseHingeLoss(object): 21 | """ 22 | a layer class: pairwise hinge loss 23 | """ 24 | def __init__(self, config): 25 | """ 26 | init function 27 | """ 28 | self.margin = float(config["margin"]) 29 | 30 | def ops(self, score_pos, score_neg): 31 | """ 32 | operation 33 | """ 34 | return tf.reduce_mean(tf.maximum(0., score_neg + 35 | self.margin - score_pos)) 36 | 37 | 38 | class PairwiseLogLoss(object): 39 | """ 40 | a layer class: pairwise log loss 41 | """ 42 | def __init__(self, config=None): 43 | """ 44 | init function 45 | """ 46 | pass 47 | 48 | def ops(self, score_pos, score_neg): 49 | """ 50 | operation 51 | """ 52 | return tf.reduce_mean(tf.nn.sigmoid(score_neg - score_pos)) 53 | 54 | 55 | class SoftmaxWithLoss(object): 56 | """ 57 | a layer class: softmax loss 58 | """ 59 | def __init__(self): 60 | """ 61 | init function 62 | """ 63 | pass 64 | 65 | def ops(self, pred, label): 66 | """ 67 | operation 68 | """ 69 | return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, 70 | labels=label)) 71 | -------------------------------------------------------------------------------- /include/server/solr_accessor.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
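// Illustrative sketch (added annotation, not in the original source): pushing
// one FAQ document into solr through the accessor; the conf path and field
// names are made up.
//
//   anyq::SolrAccessor accessor;
//   accessor.init("example/conf/solr.conf", clear_passwd);
//   Json::Value doc;
//   doc["id"] = "1001";
//   doc["question"] = "how do I reset my password";
//   doc["answer"] = "open the account page and ...";
//   std::string result;
//   accessor.insert_doc(doc, result);  // result carries solr's JSON reply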
14 | 15 | #ifndef BAIDU_NLP_ANYQ_SOLR_ACCESSOR_H 16 | #define BAIDU_NLP_ANYQ_SOLR_ACCESSOR_H 17 | 18 | #include <json/json.h> 19 | #include "common/common_define.h" 20 | #include "common/http_client.h" 21 | #include "anyq.pb.h" 22 | 23 | namespace anyq { 24 | 25 | // data interface for driving solr from anyq; through it anyq exposes one unified external service covering QA retrieval and FAQ insert/update/delete operations 26 | class SolrAccessor { 27 | public: 28 | SolrAccessor(); 29 | ~SolrAccessor(); 30 | 31 | int init(const std::string& conf_path, const std::string& solr_clear_passwd); 32 | int insert_doc(const Json::Value& param, std::string& result); 33 | int update_doc(const Json::Value& param, std::string& result); 34 | int delete_doc(const Json::Value& param, std::string& result); 35 | int clear_doc(const std::string& passwd, std::string& result); 36 | 37 | private: 38 | // solr configuration and request state 39 | HttpClient _http_client; 40 | char _solr_url[URL_LENGTH]; 41 | std::string _buffer; 42 | std::string _solr_clear_passwd; 43 | 44 | static const std::string pack_error_msg(const std::string& msg); 45 | const std::string pack_str(std::string param_str); 46 | static int parse_request_result(const std::string& buffer, std::string& result); 47 | 48 | static int single_insert_param(const Json::Value&, std::string&, std::string&); 49 | static int single_update_param(const Json::Value&, std::string&, std::string&); 50 | static int single_delete_param(const Json::Value&, std::string&, std::string&); 51 | // batch-level solr operations 52 | int batch_param(int (*single_pf)(const Json::Value&, std::string&, std::string&), 53 | const Json::Value& param, 54 | std::string& batch_param_str, 55 | const std::string& sep, 56 | std::string& result); 57 | 58 | DISALLOW_COPY_AND_ASSIGN(SolrAccessor); 59 | }; 60 | 61 | } // namespace anyq 62 | 63 | #endif // BAIDU_NLP_ANYQ_SOLR_ACCESSOR_H 64 | -------------------------------------------------------------------------------- /src/retrieval/manual/manual_retrieval.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
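// Added annotation (not part of the original source): the intervention dict is
// a reloadable key-value table mapping an exact query string to a forced
// answer, e.g. a line pairing "how to refund" with "open the order page and
// click refund" (content and separator format are assumed). When a query hits
// the dict, retrieval() below answers it directly and raises anyq_end so every
// other retrieval and ranking step is skipped.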
14 | 15 | #include "retrieval/manual/manual_retrieval.h" 16 | 17 | namespace anyq { 18 | 19 | int ManualRetrievalPlugin::init(DictMap* dict_map, const RetrievalPluginConfig& plugin_config) { 20 | // 基类初始化 21 | init_base(plugin_config.name(), plugin_config.num_result()); 22 | // 设置人工干预字典 23 | if (dict_map == NULL) { 24 | FATAL_LOG("dict_map is null"); 25 | return -1; 26 | } 27 | if (!plugin_config.has_using_dict_name()) { 28 | FATAL_LOG("RetrievalPluginConfig.%s unset!", "using_dict_name"); 29 | return -1; 30 | } 31 | std::string q2a_dict_name = plugin_config.using_dict_name(); 32 | if (dict_map->count(q2a_dict_name) < 1) { 33 | FATAL_LOG("using dict %s that does not exist", q2a_dict_name.c_str()); 34 | return -1; 35 | } 36 | _p_dual_dict_wrapper = (*dict_map)[q2a_dict_name]; 37 | return 0; 38 | } 39 | 40 | int ManualRetrievalPlugin::destroy() { 41 | return 0; 42 | } 43 | 44 | // 召回 45 | int ManualRetrievalPlugin::retrieval(const AnalysisResult& analysis_result, RetrievalResult& retrieval_res) { 46 | // 干预词典支持reload,检索时动态获取词典 47 | hashmap_str2str* q2a_dict = (hashmap_str2str*)(_p_dual_dict_wrapper->get_dict()); 48 | for (uint32_t i = 0; i < analysis_result.analysis.size(); i++) { 49 | if (q2a_dict->count(analysis_result.analysis[i].query) == 0) { 50 | continue; 51 | } 52 | TextInfo t_answer; 53 | t_answer.text = (*q2a_dict)[analysis_result.analysis[i].query]; 54 | // 人工干预的检索结果优先级最高,query如果命中干预字典,将anyq_end设置为true,跳过其他的检索和rank 55 | retrieval_res.anyq_end = true; 56 | RetrievalItem retrieval_item; 57 | retrieval_item.query.text = analysis_result.analysis[i].query; 58 | retrieval_item.answer.push_back(t_answer); 59 | retrieval_res.items.push_back(retrieval_item); 60 | } 61 | 62 | return 0; 63 | } 64 | 65 | } // namespace anyq 66 | -------------------------------------------------------------------------------- /tools/simnet/train/paddle/nets/bow.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | # Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | import layers.paddle_layers as layers 18 | 19 | 20 | class BOW(object): 21 | """ 22 | BOW 23 | """ 24 | def __init__(self, conf_dict): 25 | """ 26 | initialize 27 | """ 28 | self.dict_size = conf_dict["dict_size"] 29 | self.task_mode = conf_dict["task_mode"] 30 | self.emb_dim = conf_dict["net"]["emb_dim"] 31 | self.bow_dim = conf_dict["net"]["bow_dim"] 32 | 33 | def predict(self, left, right): 34 | """ 35 | Forward network 36 | """ 37 | # embedding layer 38 | emb_layer = layers.EmbeddingLayer(self.dict_size, self.emb_dim, "emb") 39 | left_emb = emb_layer.ops(left) 40 | right_emb = emb_layer.ops(right) 41 | # Presentation context 42 | pool_layer = layers.SequencePoolLayer("sum") 43 | left_pool = pool_layer.ops(left_emb) 44 | right_pool = pool_layer.ops(right_emb) 45 | softsign_layer = layers.SoftsignLayer() 46 | left_soft = softsign_layer.ops(left_pool) 47 | right_soft = softsign_layer.ops(right_pool) 48 | # matching layer 49 | if self.task_mode == "pairwise": 50 | bow_layer = layers.FCLayer(self.bow_dim, "relu", "fc") 51 | left_bow = bow_layer.ops(left_soft) 52 | right_bow = bow_layer.ops(right_soft) 53 | cos_sim_layer = layers.CosSimLayer() 54 | pred = cos_sim_layer.ops(left_bow, right_bow) 55 | return left_bow, pred 56 | else: 57 | concat_layer = layers.ConcatLayer(1) 58 | concat = concat_layer.ops([left_soft, right_soft]) 59 | bow_layer = layers.FCLayer(self.bow_dim, "relu", "fc") 60 | concat_fc = bow_layer.ops(concat) 61 | softmax_layer = layers.FCLayer(2, "softmax", "cos_sim") 62 | pred = softmax_layer.ops(concat_fc) 63 | return left_soft, pred 64 | -------------------------------------------------------------------------------- /src/analysis/method_wordseg.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
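// Added annotation (not part of the original source): the lac lexer dict inside
// WordsegPack is shared process-wide and read-only, while the _lexer_buff
// created in init() is the per-thread scratch buffer that lac requires, so each
// AnalysisWordseg instance is expected to stay on a single worker thread.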
14 | 15 | #include "analysis/method_wordseg.h" 16 | #include 17 | 18 | namespace anyq{ 19 | 20 | AnalysisWordseg::AnalysisWordseg(){ 21 | } 22 | 23 | AnalysisWordseg::~AnalysisWordseg(){ 24 | destroy(); 25 | } 26 | 27 | int AnalysisWordseg::init(DualDictWrapper* dict, const AnalysisMethodConfig& analysis_method) 28 | { 29 | _p_wordseg_pack = (WordsegPack*)dict->get_dict(); 30 | _lexer_buff = NULL; 31 | _lexer_buff = lac_buff_create(_p_wordseg_pack->lexer_dict); 32 | 33 | if (_lexer_buff == NULL) { 34 | FATAL_LOG("error init lexer_buff = thread"); 35 | return -1; 36 | } 37 | 38 | set_method_name(analysis_method.name()); 39 | TRACE_LOG("init wordseg success"); 40 | return 0; 41 | } 42 | 43 | int AnalysisWordseg::destroy(){ 44 | if (_lexer_buff != NULL) { 45 | lac_buff_destroy(_p_wordseg_pack->lexer_dict, _lexer_buff); 46 | _lexer_buff = NULL; 47 | } 48 | return 0; 49 | }; 50 | 51 | int AnalysisWordseg::single_process(AnalysisItem& analysis_item) { 52 | const char* c_query = analysis_item.query.c_str(); 53 | int basic_tk_num = -1; 54 | try { 55 | basic_tk_num = lac_tagging(_p_wordseg_pack->lexer_dict, 56 | _lexer_buff, 57 | c_query, 58 | _basic_tokens, 59 | MAX_TERM_COUNT); 60 | } catch (std::exception& e) { 61 | FATAL_LOG("wordseg segment error."); 62 | return -1; 63 | } 64 | 65 | if (basic_tk_num < 0) { 66 | FATAL_LOG("wordseg segment error."); 67 | return -1; 68 | } 69 | 70 | int ret = array_tokens_conduct(_basic_tokens, 71 | basic_tk_num, 72 | analysis_item.tokens_basic, 73 | analysis_item.query); 74 | if (ret != 0) { 75 | FATAL_LOG("analysis segment token convert error."); 76 | return -1; 77 | } 78 | 79 | return 0; 80 | }; 81 | 82 | } // namespace anyq 83 | -------------------------------------------------------------------------------- /src/dict/dict_adapter.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #include "dict/dict_adapter.h" 16 | #include 17 | 18 | namespace anyq { 19 | 20 | String2RetrievalItemAdapter::String2RetrievalItemAdapter(){ 21 | hashmap_str2retrieval_item* tmp_dict = new hashmap_str2retrieval_item(); 22 | if (tmp_dict == NULL) { 23 | FATAL_LOG("new str2retrieval_item dict error"); 24 | } 25 | DEBUG_LOG("new str2retrieval_item dict suecess"); 26 | set_dict((void*)tmp_dict); 27 | } 28 | 29 | String2RetrievalItemAdapter::~String2RetrievalItemAdapter(){ 30 | void* dict = get_dict(); 31 | if (dict != NULL) { 32 | delete static_cast(dict); 33 | } 34 | } 35 | 36 | int String2RetrievalItemAdapter::load(const std::string& path, const DictConfig& config) { 37 | DEBUG_LOG("%s", path.c_str()); 38 | hashmap_str2retrieval_item* tmp_dict = static_cast(get_dict()); 39 | int ret = str2retrieval_item_load(path.c_str(), (*tmp_dict)); 40 | if (ret != 0) { 41 | FATAL_LOG("load string2retrieval_item dict %s error", path.c_str()); 42 | return -1; 43 | } 44 | return 0; 45 | } 46 | 47 | // 查找 48 | int String2RetrievalItemAdapter::get(const std::string& key, RetrievalItem& retrieval_item) { 49 | hashmap_str2retrieval_item* tmp_dict = static_cast(get_dict()); 50 | hashmap_str2retrieval_item::const_iterator it = tmp_dict->find(key); 51 | if (it != tmp_dict->end()) { 52 | retrieval_item = it->second; 53 | } else { 54 | WARNING_LOG("String2RetrievalItemAdapter key[%s] not exist!", key.c_str()); 55 | return -1; 56 | } 57 | return 0; 58 | } 59 | 60 | int String2RetrievalItemAdapter::release() { 61 | void* dict = get_dict(); 62 | if (dict != NULL) { 63 | hashmap_str2retrieval_item* tmp_dict = static_cast(dict); 64 | tmp_dict->clear(); 65 | } 66 | return 0; 67 | } 68 | 69 | } 70 | -------------------------------------------------------------------------------- /tools/simnet/train/paddle/nets/cnn.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | # Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | import layers.paddle_layers as layers 18 | 19 | 20 | class CNN(object): 21 | """ 22 | CNN 23 | """ 24 | def __init__(self, conf_dict): 25 | """ 26 | initialize 27 | """ 28 | self.dict_size = conf_dict["dict_size"] 29 | self.task_mode = conf_dict["task_mode"] 30 | self.emb_dim = conf_dict["net"]["emb_dim"] 31 | self.filter_size = conf_dict["net"]["filter_size"] 32 | self.num_filters = conf_dict["net"]["num_filters"] 33 | self.hidden_dim = conf_dict["net"]["hidden_dim"] 34 | 35 | def predict(self, left, right): 36 | """ 37 | Forward network 38 | """ 39 | # embedding layer 40 | emb_layer = layers.EmbeddingLayer(self.dict_size, self.emb_dim, "emb") 41 | left_emb = emb_layer.ops(left) 42 | right_emb = emb_layer.ops(right) 43 | # Presentation context 44 | cnn_layer = layers.SequenceConvPoolLayer( 45 | self.filter_size, self.num_filters, "conv") 46 | left_cnn = cnn_layer.ops(left_emb) 47 | right_cnn = cnn_layer.ops(right_emb) 48 | # matching layer 49 | if self.task_mode == "pairwise": 50 | relu_layer = layers.FCLayer(self.hidden_dim, "relu", "relu") 51 | left_relu = relu_layer.ops(left_cnn) 52 | right_relu = relu_layer.ops(right_cnn) 53 | cos_sim_layer = layers.CosSimLayer() 54 | pred = cos_sim_layer.ops(left_relu, right_relu) 55 | return left_relu, pred 56 | else: 57 | concat_layer = layers.ConcatLayer(1) 58 | concat = concat_layer.ops([left_cnn, right_cnn]) 59 | relu_layer = layers.FCLayer(self.hidden_dim, "relu", "relu") 60 | concat_fc = relu_layer.ops(concat) 61 | softmax_layer = layers.FCLayer(2, "softmax", "cos_sim") 62 | pred = softmax_layer.ops(concat_fc) 63 | return left_cnn, pred 64 | -------------------------------------------------------------------------------- /demo/annoy_index_build.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
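// Added annotation (not part of the original source): the query file read below
// is expected to hold one item per line as "<item_id>\t<query_text>", e.g.
// "0\thow to refund" (made-up row); the id becomes the annoy item id and the
// query is embedded by the analysis strategy before being added to the index.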
14 | 15 | #include "dict/dict_manager.h" 16 | #include "analysis/analysis_strategy.h" 17 | #include 18 | #include 19 | #include "retrieval/semantic/kissrandom.h" 20 | #include "retrieval/semantic/annoylib.h" 21 | #include "common/plugin_header.h" 22 | 23 | int main(int argc, char* argv[]){ 24 | google::InitGoogleLogging(argv[0]); 25 | FLAGS_stderrthreshold = google::INFO; 26 | if (argc != 7) { 27 | FATAL_LOG("Usage: annoy_index_build_tool anyq_dict_dir analysis_conf_path query_file vector_dim " 28 | "num_trees index_sava_file"); 29 | FATAL_LOG("Example: ./output/bin/annoy_index_build_tool example/conf/ example/conf/analysis.conf " 30 | "example/conf/annoy_query_to_build_tree.dat 128 200 example/conf/annoy_index.tree"); 31 | return -1; 32 | } 33 | anyq::DictManager dm; 34 | if (dm.load_dict(argv[1]) != 0) { 35 | FATAL_LOG("load dict error"); 36 | return -1; 37 | } 38 | anyq::DictMap* global_dict = dm.get_dict(); 39 | anyq::AnalysisStrategy analysis_strategy; 40 | analysis_strategy.init(global_dict, argv[2]); 41 | 42 | std::fstream fs(argv[3], std::fstream::in); 43 | if (!fs.is_open()) { 44 | FATAL_LOG("open query file error"); 45 | return -1; 46 | } 47 | 48 | AnnoyIndex annoy_index = 49 | AnnoyIndex(atoi(argv[4])); 50 | std::string line; 51 | std::vector fields; 52 | while (getline(fs, line)) { 53 | anyq::AnalysisResult analysis_result; 54 | fields.clear(); 55 | anyq::split_string(line, fields, "\t"); 56 | analysis_strategy.run_strategy(fields[1], analysis_result); 57 | annoy_index.add_item(atoi(fields[0].c_str()), &analysis_result.analysis[0].query_emb[0]); 58 | } 59 | annoy_index.build(atoi(argv[5])); 60 | annoy_index.save(argv[6]); 61 | 62 | return 0; 63 | } 64 | -------------------------------------------------------------------------------- /src/server/session_data_factory.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #include "server/session_data_factory.h" 16 | 17 | namespace anyq { 18 | 19 | SessionData::SessionData() : _use_solr(true) { 20 | } 21 | 22 | SessionData::~SessionData() { 23 | } 24 | 25 | int SessionData::init(DictManager* ptr_dict_manager, 26 | const std::string& anyq_conf_file, 27 | const std::string& solr_clear_passwd) { 28 | if (ptr_dict_manager == NULL) { 29 | FATAL_LOG("ptr_dict_manager is NULL"); 30 | return -1; 31 | } 32 | if (_anyq_strategy.create_resource(*ptr_dict_manager, anyq_conf_file) != 0) { 33 | FATAL_LOG("anyq create resource failed, anyq_conf_file=%s", anyq_conf_file.c_str()); 34 | return -1; 35 | } 36 | if (_solr_accessor.init(anyq_conf_file, solr_clear_passwd) != 0) { 37 | WARNING_LOG("solr accessor init not success!"); 38 | _use_solr = false; 39 | } 40 | return 0; 41 | } 42 | 43 | SessionDataFactory::SessionDataFactory() { 44 | } 45 | 46 | SessionDataFactory::~SessionDataFactory() { 47 | } 48 | 49 | SessionDataFactory::SessionDataFactory(DictManager* ptr_dict_manager, 50 | const std::string& anyq_conf_file, 51 | const std::string& solr_clear_passwd) { 52 | _dict_manager = ptr_dict_manager; 53 | _anyq_conf_file = anyq_conf_file; 54 | _solr_clear_passwd = solr_clear_passwd; 55 | } 56 | 57 | void* SessionDataFactory::CreateData() const { 58 | if (_dict_manager == NULL) { 59 | FATAL_LOG("_dict_manager is NULL"); 60 | return NULL; 61 | } 62 | SessionData* sd = new SessionData(); 63 | if (sd->init(_dict_manager, _anyq_conf_file, _solr_clear_passwd) != 0) { 64 | FATAL_LOG("session data init failed."); 65 | return NULL; 66 | } 67 | DEBUG_LOG("session data init success!!!"); 68 | 69 | return static_cast(sd); 70 | } 71 | 72 | void SessionDataFactory::DestroyData(void* session_data) const { 73 | delete static_cast(session_data); 74 | } 75 | 76 | } 77 | -------------------------------------------------------------------------------- /tools/simnet/train/paddle/nets/gru.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | # Copyright (c) 2018 Baidu, Inc. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | import layers.paddle_layers as layers 18 | 19 | 20 | class GRU(object): 21 | """ 22 | GRU 23 | """ 24 | def __init__(self, conf_dict): 25 | """ 26 | initialize 27 | """ 28 | self.dict_size = conf_dict["dict_size"] 29 | self.task_mode = conf_dict["task_mode"] 30 | self.emb_dim = conf_dict["net"]["emb_dim"] 31 | self.gru_dim = conf_dict["net"]["gru_dim"] 32 | self.hidden_dim = conf_dict["net"]["hidden_dim"] 33 | 34 | def predict(self, left, right): 35 | """ 36 | Forward network 37 | """ 38 | # embedding layer 39 | emb_layer = layers.EmbeddingLayer(self.dict_size, self.emb_dim, "emb") 40 | left_emb = emb_layer.ops(left) 41 | right_emb = emb_layer.ops(right) 42 | # Presentation context 43 | gru_layer = layers.DynamicGRULayer(self.gru_dim, "gru") 44 | left_gru = gru_layer.ops(left_emb) 45 | right_gru = gru_layer.ops(right_emb) 46 | last_layer = layers.SequenceLastStepLayer() 47 | left_last = last_layer.ops(left_gru) 48 | right_last = last_layer.ops(right_gru) 49 | # matching layer 50 | if self.task_mode == "pairwise": 51 | relu_layer = layers.FCLayer(self.hidden_dim, "relu", "relu") 52 | left_relu = relu_layer.ops(left_last) 53 | right_relu = relu_layer.ops(right_last) 54 | cos_sim_layer = layers.CosSimLayer() 55 | pred = cos_sim_layer.ops(left_relu, right_relu) 56 | return left_relu, pred 57 | else: 58 | concat_layer = layers.ConcatLayer(1) 59 | concat = concat_layer.ops([left_last, right_last]) 60 | relu_layer = layers.FCLayer(self.hidden_dim, "relu", "relu") 61 | concat_fc = relu_layer.ops(concat) 62 | softmax_layer = layers.FCLayer(2, "softmax", "cos_sim") 63 | pred = softmax_layer.ops(concat_fc) 64 | return left_last, pred 65 | --------------------------------------------------------------------------------