├── .gitignore ├── .gitmodules ├── README.md └── code_qa_txt ├── Makefile.example ├── eval.sh ├── include ├── bfs_path_embed.h ├── config.h ├── dataset.h ├── dict.h ├── global.h ├── graph_inner_product.h ├── inet.h ├── knowledge_base.h ├── net_latent_y.h ├── net_multihop.h ├── node_select.h ├── util.h └── var_sample.h ├── init_run.sh ├── run.sh ├── src ├── lib │ ├── bfs_path_embed.cpp │ ├── config.cpp │ ├── dataset.cpp │ ├── dict.cpp │ ├── global.cpp │ ├── graph_inner_product.cpp │ ├── graph_inner_product.cu │ ├── inet.cpp │ ├── knowledge_base.cpp │ ├── net_latent_y.cpp │ ├── net_multihop.cpp │ ├── node_select.cpp │ ├── util.cpp │ └── var_sample.cpp └── main.cpp └── vis.sh /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | metaQA 3 | metaQA/ 4 | .vscode/ 5 | Makefile 6 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "graphnn"] 2 | path = graphnn 3 | url = https://github.com/Hanjun-Dai/graphnn 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Variational-Reasoning-Networks 2 | 3 | ## Setup 4 | 5 | #### get the code 6 | 7 | First clone the repo and the submodules: 8 | 9 | `git clone git@github.com:yuyuz/Variational-Reasoning-Networks --recursive` 10 | 11 | The code depends on the graphnn library, which can be found here: https://github.com/Hanjun-Dai/graphnn 12 | 13 | #### build the graphnn library 14 | 15 | From the project folder, build the graphnn library following the instructions at the link above. 16 | 17 | `cd Variational-Reasoning-Networks/graphnn` 18 | 19 | Create the make_common file from the provided example (next command), then edit it so that it correctly locates your CUDA 8, Intel MKL, and Intel TBB libraries.
20 | 21 | `cp make_common.example make_common` 22 | 23 | Then build the graphnn library: 24 | 25 | `make -j` 26 | 27 | #### build main source code 28 | 29 | To build the code for text-based question answering, do the following: 30 | 31 | `cd Variational-Reasoning-Networks/code_qa_txt` 32 | 33 | Create the Makefile from the provided example (next command), modifying it if necessary: 34 | 35 | `cp Makefile.example Makefile` 36 | 37 | Then build everything: 38 | 39 | `make -j` 40 | 41 | #### link the data folder 42 | 43 | First download the data from https://github.com/yuyuz/MetaQA 44 | 45 | Then link the root data folder into the top-level folder of the project. 46 | 47 | `cd Variational-Reasoning-Networks` 48 | 49 | `ln -s path/to/your/data metaQA` 50 | 51 | ## Play with the code 52 | 53 | Below we illustrate with text-based question answering. First navigate to the root code folder: 54 | 55 | `cd Variational-Reasoning-Networks/code_qa_txt` 56 | 57 | #### pretraining 58 | 59 | We first use 5% of the labeled data to train the posterior inference model, as well as the knowledge graph reasoning model. 60 | 61 | `./init_run.sh` 62 | 63 | You can edit the script to train with different datasets (vanilla or ntm) and different numbers of hops (1, 2 or 3) for reasoning. 64 | 65 | #### joint training 66 | 67 | Then we load the pretrained model dump, and use REINFORCE with variance reduction to jointly train the model. 68 | 69 | `./run.sh` 70 | 71 | Typically ~1000 pretraining iterations are enough; `run.sh` starts joint training from that point (`cur_iter=1000`). 72 | 73 | 74 | #### inspect the learned model 75 | 76 | The script `vis.sh` is used to visualize the learned inference logic, while `eval.sh` can be used to inspect the topic entity recognition model. In some settings (like ntm with 1-hop reasoning), the jointly learned model can further improve the entity recognition performance after pretraining with 5% data.
77 | 78 | 79 | 80 | ## Reference 81 | 82 | If you find the code or data is useful, please cite our work: 83 | 84 | ``` 85 | @inproceedings{zhang2018variational, 86 | title={Variational reasoning for question answering with knowledge graph}, 87 | author={Zhang, Yuyu and Dai, Hanjun and Kozareva, Zornitsa and Smola, Alexander J and Song, Le}, 88 | booktitle={Thirty-Second AAAI Conference on Artificial Intelligence}, 89 | year={2018} 90 | } 91 | ``` 92 | -------------------------------------------------------------------------------- /code_qa_txt/Makefile.example: -------------------------------------------------------------------------------- 1 | GNN_HOME=../graphnn 2 | include $(GNN_HOME)/make_common 3 | USE_GPU = 1 4 | # Set USE_GPU = 0 for a CPU-only build: the .cu sources are skipped and graphnn is linked from build_cpuonly/lib. 5 | include_dirs = $(CUDA_HOME)/include $(MKL_ROOT)/include $(GNN_HOME)/include ./include 6 | 7 | CXXFLAGS += $(addprefix -I,$(include_dirs)) -Wno-unused-local-typedef 8 | cpp_files = $(shell $(FIND) src/lib -name "*.cpp" -printf "%P\n") 9 | cxx_obj_files = $(subst .cpp,.o,$(cpp_files)) 10 | objs = $(addprefix $(obj_build_root)/cxx/,$(cxx_obj_files)) 11 | 12 | ifeq ($(USE_GPU), 1) 13 | CXXFLAGS += -DUSE_GPU 14 | NVCCFLAGS += -DUSE_GPU 15 | NVCCFLAGS += $(addprefix -I,$(include_dirs)) 16 | NVCCFLAGS += -std=c++11 --use_fast_math 17 | cu_files = $(shell $(FIND) src/lib -name "*.cu" -printf "%P\n") 18 | cu_obj_files = $(subst .cu,.o,$(cu_files)) 19 | objs += $(addprefix $(obj_build_root)/cuda/,$(cu_obj_files)) 20 | lib_dir = $(GNN_HOME)/build/lib 21 | else 22 | lib_dir = $(GNN_HOME)/build_cpuonly/lib 23 | endif 24 | 25 | gnn_lib = $(lib_dir)/libgnn.a 26 | # obj_build_root is defined after objs uses it; this works because objs is recursively expanded (=), i.e. evaluated only when referenced. 27 | obj_build_root = build/objs 28 | 29 | DEPS = $(objs:.o=.d) 30 | # target_dep (build/main.d) is the dependency file written by -MMD in the link rule; it is appended to DEPS after the rule. 31 | target = build/main 32 | target_dep = $(addsuffix .d,$(target)) 33 | 34 | .PRECIOUS: $(obj_build_root)/cuda/%.o $(obj_build_root)/cxx/%.o 35 | 36 | all: $(target) 37 | 38 | build/%: src/%.cpp $(gnn_lib) $(objs) 39 | $(dir_guard) 40 | $(CXX) $(CXXFLAGS) -MMD -o $@ $(filter %.cpp %.o, $^) -L$(lib_dir) -lgnn $(LDFLAGS) 41 | 42 |
DEPS += $(target_dep) 43 | # .cu objects: the first nvcc pass (-M) writes the .d dependency file, the second (-c) compiles. 44 | ifeq ($(USE_GPU), 1) 45 | $(obj_build_root)/cuda/%.o: src/lib/%.cu 46 | $(dir_guard) 47 | $(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -M $< -o ${@:.o=.d} -odir $(@D) 48 | $(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -c $< -o $@ 49 | endif 50 | # .cpp objects: -MMD writes a .d file next to each .o; those files are pulled in by '-include $(DEPS)' below. 51 | $(obj_build_root)/cxx/%.o: src/lib/%.cpp 52 | $(dir_guard) 53 | $(CXX) $(CXXFLAGS) -MMD -c -o $@ $(filter %.cpp, $^) 54 | 55 | clean: 56 | rm -rf build 57 | # '-include' (rather than 'include') silently skips .d files that do not exist yet, e.g. on a clean build. 58 | -include $(DEPS) 59 | -------------------------------------------------------------------------------- /code_qa_txt/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | make 3 | # Evaluation-only run (-test_only 1) starting at iteration $cur_iter; the model is presumably loaded from $save_dir. 4 | nhop_subg=3 5 | dataset=vanilla 6 | data_root=../metaQA 7 | #net_type=NetMultiHop 8 | net_type=NetLatentY 9 | 10 | result_root=$HOME/scratch/results/graph_mem/nhop-$nhop_subg/$dataset 11 | # NOTE(review): the settings below presumably need to match those of the run that produced the checkpoint. 12 | num_neg=10000 13 | max_bp_iter=1 14 | max_q_iter=3 15 | batch_size=128 16 | n_hidden=64 17 | n_embed=256 18 | margin=0.1 19 | learning_rate=0.01 20 | max_iter=4000000 21 | cur_iter=800 22 | w_scale=0.01 23 | loss_type=cross_entropy 24 | save_dir=$result_root/embed-$n_embed 25 | 26 | if [ !
-e $save_dir ]; 27 | then 28 | mkdir -p $save_dir 29 | fi 30 | # NOTE(review): config.h parses '-test_tpok', so the '-test_topk 5' flag below is ignored (test_tpok keeps its default) unless the parser also accepts that spelling. 31 | ./build/main \ 32 | -test_only 1 \ 33 | -test_topk 5 \ 34 | -num_neg $num_neg \ 35 | -loss_type $loss_type \ 36 | -data_root $data_root \ 37 | -dataset $dataset \ 38 | -n_hidden $n_hidden \ 39 | -nhop_subg $nhop_subg \ 40 | -lr $learning_rate \ 41 | -max_bp_iter $max_bp_iter \ 42 | -net_type $net_type \ 43 | -max_q_iter $max_q_iter \ 44 | -margin $margin \ 45 | -max_iter $max_iter \ 46 | -svdir $save_dir \ 47 | -embed $n_embed \ 48 | -batch_size $batch_size \ 49 | -m 0.9 \ 50 | -l2 0.00 \ 51 | -w_scale $w_scale \ 52 | -int_report 1 \ 53 | -int_test 1 \ 54 | -int_save 1000000 \ 55 | -cur_iter $cur_iter 56 | -------------------------------------------------------------------------------- /code_qa_txt/include/bfs_path_embed.h: -------------------------------------------------------------------------------- 1 | #ifndef BFS_PATH_EMBED_H 2 | #define BFS_PATH_EMBED_H 3 | 4 | #include "util/gnn_macros.h" 5 | #include "nn/factor.h" 6 | #include "nn/variable.h" 7 | #include "var_sample.h" 8 | 9 | namespace gnn 10 | { 11 | // Factor that expands each sample's set of KB nodes by one hop and embeds the newly reached nodes (implementation in src/lib/bfs_path_embed.cpp). 12 | template 13 | class BfsPathEmbed : public Factor 14 | { 15 | public: 16 | static std::string StrType() 17 | { 18 | return "BfsPathEmbed"; 19 | } 20 | // Outputs: [0] embeddings of the reached nodes, [1] the reached nodes, [2] number of reached nodes per sample. 21 | using OutType = std::tuple< std::shared_ptr< DTensorVar >, 22 | std::shared_ptr< VectorVar >, 23 | std::shared_ptr< VectorVar > >; 24 | 25 | OutType CreateOutVar() 26 | { 27 | auto o0 = std::make_shared< DTensorVar >( fmt::sprintf("%s:out_0", this->name) ); 28 | auto o1 = std::make_shared< VectorVar >( fmt::sprintf("%s:out_1", this->name) ); 29 | auto o2 = std::make_shared< VectorVar >( fmt::sprintf("%s:out_2", this->name) ); 30 | 31 | return std::make_tuple(o0, o1, o2); 32 | } 33 | 34 | BfsPathEmbed(std::string _name, PropErr _properr = PropErr::T); 35 | // Operands: [0] samples, [1] relation embeddings, [2] recurrent path weight, then either [3] start nodes (first hop, 4 operands) or [3] previous-hop embedding plus [4]/[5] previous-hop nodes/counts (6 operands). 36 | virtual void Forward(std::vector< std::shared_ptr >& operands, 37 | std::vector< std::shared_ptr >& outputs, 38 | Phase phase) override; 39 | 40 | virtual void
Backward(std::vector< std::shared_ptr >& operands, 41 | std::vector< bool >& isConst, 42 | std::vector< std::shared_ptr >& outputs) override; 43 | 44 | void GetOutInfo(std::vector& out_node_info, std::vector& out_num_node); 45 | 46 | void ConstructRelSp(std::vector& out_node_info, std::vector& out_num_node, size_t rel_nums); 47 | void ConstructNodeSp(std::vector& out_node_info, std::vector& out_num_node, size_t num_in); 48 | // Per-pass CSR matrices: reached node x relation type (rel) and reached node x input node (node); the cpu_* copies are filled first and then copied into rel_mat / node_mat. 49 | SpTensor cpu_rel_mat, cpu_node_mat; 50 | SpTensor rel_mat, node_mat; 51 | std::vector* ptr_in_info; 52 | std::vector* ptr_in_cnt; 53 | // For each reached node: the input positions that lead to it and the relation types of those edges. 54 | std::vector< std::set > src_positions; 55 | std::vector< std::set > rel_types; 56 | 57 | DTensor node_trans, node_grad; 58 | }; 59 | 60 | } 61 | 62 | #endif -------------------------------------------------------------------------------- /code_qa_txt/include/config.h: -------------------------------------------------------------------------------- 1 | #ifndef cfg_H 2 | #define cfg_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "util/fmt.h" 11 | #include "util/gnn_macros.h" 12 | typedef float Dtype; 13 | typedef gnn::CPU mode; 14 | // Global run configuration; every field is static and is set from the command line by LoadParams(). 15 | struct cfg 16 | { 17 | static int iter, max_bp_iter, max_q_iter, nhop_subg; 18 | static int num_neg; 19 | static int n_hidden; 20 | static int test_tpok; 21 | static unsigned batch_size, dev_id; 22 | static unsigned n_embed; 23 | static unsigned max_iter; 24 | static unsigned test_interval; 25 | static unsigned report_interval; 26 | static unsigned save_interval; 27 | static bool test_only; 28 | static bool vis_score; 29 | static Dtype lr; 30 | static Dtype p_pos; 31 | static Dtype l2_penalty; 32 | static Dtype momentum; 33 | static Dtype margin; 34 | static Dtype w_scale; 35 | static const char *save_dir, *data_root, *dataset, *loss_type, *net_type, *init_idx_file; 36 | // Parses '-flag value' pairs (unrecognized flags are skipped together with their value) and echoes the resulting settings to stderr. 37 | static void LoadParams(const int argc, const char** argv) 38 | { 39 | for (int i = 1; i < argc; i += 2) 40 | { 41 | if (strcmp(argv[i], "-test_tpok") == 0)
42 | test_tpok = atoi(argv[i + 1]); if (strcmp(argv[i], "-test_topk") == 0) test_tpok = atoi(argv[i + 1]); // also accept '-test_topk', the spelling eval.sh passes 43 | if (strcmp(argv[i], "-data_root") == 0) 44 | data_root = argv[i + 1]; 45 | if (strcmp(argv[i], "-loss_type") == 0) 46 | loss_type = argv[i + 1]; 47 | if (strcmp(argv[i], "-dataset") == 0) 48 | dataset = argv[i + 1]; 49 | if (strcmp(argv[i], "-net_type") == 0) 50 | net_type = argv[i + 1]; 51 | if (strcmp(argv[i], "-lr") == 0) 52 | lr = atof(argv[i + 1]); 53 | if (strcmp(argv[i], "-n_hidden") == 0) 54 | n_hidden = atoi(argv[i + 1]); 55 | if (strcmp(argv[i], "-max_bp_iter") == 0) 56 | max_bp_iter = atoi(argv[i + 1]); 57 | if (strcmp(argv[i], "-num_neg") == 0) 58 | num_neg = atoi(argv[i + 1]); 59 | if (strcmp(argv[i], "-max_q_iter") == 0) 60 | max_q_iter = atoi(argv[i + 1]); 61 | if (strcmp(argv[i], "-nhop_subg") == 0) 62 | nhop_subg = atoi(argv[i + 1]); 63 | if (strcmp(argv[i], "-dev_id") == 0) 64 | dev_id = atoi(argv[i + 1]); 65 | if (strcmp(argv[i], "-cur_iter") == 0) 66 | iter = atoi(argv[i + 1]); 67 | if (strcmp(argv[i], "-embed") == 0) 68 | n_embed = atoi(argv[i + 1]); 69 | if (strcmp(argv[i], "-max_iter") == 0) 70 | max_iter = atoi(argv[i + 1]); 71 | if (strcmp(argv[i], "-batch_size") == 0) 72 | batch_size = atoi(argv[i + 1]); 73 | if (strcmp(argv[i], "-int_test") == 0) 74 | test_interval = atoi(argv[i + 1]); 75 | if (strcmp(argv[i], "-int_report") == 0) 76 | report_interval = atoi(argv[i + 1]); 77 | if (strcmp(argv[i], "-int_save") == 0) 78 | save_interval = atoi(argv[i + 1]); 79 | if (strcmp(argv[i], "-l2") == 0) 80 | l2_penalty = atof(argv[i + 1]); 81 | if (strcmp(argv[i], "-margin") == 0) 82 | margin = atof(argv[i + 1]); 83 | if (strcmp(argv[i], "-w_scale") == 0) 84 | w_scale = atof(argv[i + 1]); 85 | if (strcmp(argv[i], "-m") == 0) 86 | momentum = atof(argv[i + 1]); 87 | if (strcmp(argv[i], "-svdir") == 0) 88 | save_dir = argv[i + 1]; 89 | if (strcmp(argv[i], "-init_idx_file") == 0) 90 | init_idx_file = argv[i + 1]; 91 | if (strcmp(argv[i], "-test_only") == 0) 92 | test_only = atoi(argv[i + 1]); 93 | if
(strcmp(argv[i], "-vis_score") == 0) 94 | vis_score = atoi(argv[i + 1]); 95 | } 96 | // vis_score / test_only require a non-zero -cur_iter (presumably the checkpoint to load); assert() is compiled out under NDEBUG. 97 | if (vis_score) 98 | { 99 | std::cerr << "vis score" << std::endl; 100 | assert(iter); 101 | } 102 | if (test_only) 103 | { 104 | std::cerr << "test only" << std::endl; 105 | std::cerr << "test_tpok = " << test_tpok << std::endl; 106 | assert(iter); 107 | } 108 | if (init_idx_file) 109 | { 110 | std::cerr << "init network" << std::endl; 111 | } 112 | std::cerr << "net_type = " << net_type << std::endl; 113 | std::cerr << "n_hidden = " << n_hidden << std::endl; 114 | std::cerr << "dev_id = " << dev_id << std::endl; 115 | std::cerr << "loss_type = " << loss_type << std::endl; 116 | std::cerr << "nhop_subg = " << nhop_subg << std::endl; 117 | std::cerr << "margin = " << margin << std::endl; 118 | std::cerr << "num_neg = " << num_neg << std::endl; 119 | std::cerr << "max_q_iter = " << max_q_iter << std::endl; 120 | std::cerr << "max_bp_iter = " << max_bp_iter << std::endl; 121 | std::cerr << "batch_size = " << batch_size << std::endl; 122 | std::cerr << "dataset = " << dataset << std::endl; 123 | std::cerr << "n_embed = " << n_embed << std::endl; 124 | std::cerr << "max_iter = " << max_iter << std::endl; 125 | std::cerr << "test_interval = " << test_interval << std::endl; 126 | std::cerr << "report_interval = " << report_interval << std::endl; 127 | std::cerr << "save_interval = " << save_interval << std::endl; 128 | std::cerr << "lr = " << lr << std::endl; 129 | std::cerr << "w_scale = " << w_scale << std::endl; 130 | std::cerr << "l2_penalty = " << l2_penalty << std::endl; 131 | std::cerr << "momentum = " << momentum << std::endl; 132 | std::cerr << "init iter = " << iter << std::endl; 133 | } 134 | }; 135 | 136 | #endif 137 | -------------------------------------------------------------------------------- /code_qa_txt/include/dataset.h: -------------------------------------------------------------------------------- 1 | #ifndef DATASET_H 2 | #define DATASET_H 3 | 4 | #include
"config.h" 5 | #include "knowledge_base.h" 6 | // One QA example: question word lists (q_word_list / q_side_word_list) plus the question and answer KB entities. 7 | struct Sample 8 | { 9 | int s_idx; 10 | std::vector< int > q_word_list, q_side_word_list; 11 | std::vector< Node* > q_entities, answer_entities; 12 | Sample(); 13 | }; 14 | // Holds the loaded samples and serves them as mini-batches (GetMiniBatch / GetSplitMiniBatch). 15 | class Dataset 16 | { 17 | public: 18 | 19 | Dataset(); 20 | 21 | void Load(const char* suffix); 22 | 23 | void SetupStream(bool randomized = false); 24 | 25 | bool GetMiniBatch(int batch_size, std::vector< Sample* >& mini_batch); 26 | bool GetSplitMiniBatch(int batch_size, std::vector< Sample* >& mini_batch); 27 | 28 | std::vector< Sample* > orig_samples; 29 | std::vector< Sample* > split_samples; 30 | 31 | private: 32 | 33 | bool GetData(int batch_size, std::vector< Sample* >& mini_batch, std::vector< Sample* >& samples); 34 | bool randomized; 35 | int cur_pos; 36 | std::vector idxes; 37 | }; 38 | 39 | #endif -------------------------------------------------------------------------------- /code_qa_txt/include/dict.h: -------------------------------------------------------------------------------- 1 | #ifndef DICT_H 2 | #define DICT_H 3 | 4 | #include "config.h" 5 | 6 | std::map GetRelations(); 7 | 8 | std::map GetVocab(); 9 | 10 | std::map GetSideWordDict(); 11 | 12 | #endif -------------------------------------------------------------------------------- /code_qa_txt/include/global.h: -------------------------------------------------------------------------------- 1 | #ifndef GLOBAL_H 2 | #define GLOBAL_H 3 | 4 | #include "config.h" 5 | #include "nn/factor_graph.h" 6 | #include "util/graph_struct.h" 7 | #include "nn/param_set.h" 8 | #include "knowledge_base.h" 9 | #include 10 | 11 | using namespace gnn; 12 | // Process-wide state shared across translation units; presumably defined in src/lib/global.cpp. 13 | extern KnowledgeBase kb; 14 | extern std::map side_word_dict; 15 | extern ParamSet model; 16 | extern FactorGraph fg; 17 | extern std::map word_dict; 18 | extern std::map relation_dict; 19 | 20 | #endif -------------------------------------------------------------------------------- /code_qa_txt/include/graph_inner_product.h:
-------------------------------------------------------------------------------- 1 | #ifndef GRAPH_INNER_PRODUCT_H 2 | #define GRAPH_INNER_PRODUCT_H 3 | 4 | #include "util/gnn_macros.h" 5 | #include "nn/factor.h" 6 | #include "nn/variable.h" 7 | #include "util/fmt.h" 8 | 9 | #ifdef USE_GPU 10 | #include 11 | #include 12 | #endif 13 | 14 | namespace gnn 15 | { 16 | // SetVal / BpError are each declared twice, presumably once per device (CPU / GPU) and defined in graph_inner_product.cpp / .cu -- confirm. 17 | template 18 | void SetVal(DTensor& src, int* entity_idx, int* sample_idx, DTensor& dst); 19 | template 20 | void SetVal(DTensor& src, int* entity_idx, int* sample_idx, DTensor& dst); 21 | 22 | template 23 | void BpError(DTensor& grad_out, int* entity_idx, int* sample_idx, DTensor& cur_grad); 24 | template 25 | void BpError(DTensor& grad_out, int* entity_idx, int* sample_idx, DTensor& cur_grad); 26 | 27 | template 28 | class GraphInnerProduct : public Factor 29 | { 30 | public: 31 | static std::string StrType() 32 | { 33 | return "GraphInnerProduct"; 34 | } 35 | 36 | using OutType = std::shared_ptr< DTensorVar >; 37 | 38 | OutType CreateOutVar() 39 | { 40 | auto out_name = fmt::sprintf("%s:out_0", this->name); 41 | return std::make_shared< DTensorVar >(out_name); 42 | } 43 | 44 | GraphInnerProduct(std::string _name, int _num_entities, PropErr _properr = PropErr::T); 45 | 46 | virtual void Forward(std::vector< std::shared_ptr >& operands, 47 | std::vector< std::shared_ptr >& outputs, 48 | Phase phase) override; 49 | 50 | virtual void Backward(std::vector< std::shared_ptr >& operands, 51 | std::vector< bool >& isConst, 52 | std::vector< std::shared_ptr >& outputs) override; 53 | 54 | int num_entities; 55 | DTensor tmp_out; 56 | #ifdef USE_GPU 57 | thrust::host_vector entity_idx, sample_idx; 58 | thrust::device_vector gpu_entity_idx, gpu_sample_idx; 59 | #else 60 | std::vector entity_idx, sample_idx; 61 | #endif 62 | // NOTE(review): ptr_entity / ptr_sample presumably point at the index buffers above (host or device copies) -- confirm in graph_inner_product.cpp. 63 | int* ptr_entity; 64 | int* ptr_sample; 65 | }; 66 | 67 | } 68 | 69 | #endif 70 | --------------------------------------------------------------------------------
/code_qa_txt/include/inet.h: -------------------------------------------------------------------------------- 1 | #ifndef INET_H 2 | #define INET_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include "config.h" 8 | #include "tensor/tensor.h" 9 | #include "nn/variable.h" 10 | 11 | using namespace gnn; 12 | // Abstract base for the QA networks: subclasses implement BuildNet() and may override BuildBatchGraph(). 13 | struct Sample; 14 | class INet 15 | { 16 | public: 17 | INet(); 18 | 19 | virtual void BuildNet() = 0; 20 | virtual void BuildBatchGraph(std::vector< Sample* >& mini_batch, Phase phase); 21 | 22 | SpTensor q_bow_input, ans_output; 23 | SpTensor m_q_bow_input, m_ans_output; 24 | 25 | DTensor y_idxes; 26 | 27 | std::map< std::string, void* > inputs; 28 | SpTensor entity_bow; 29 | SpTensor m_entity_bow; 30 | std::shared_ptr< DTensorVar > loss, hit_rate, pos_probs, pred, q_embed_query, q_y_bow_match; 31 | std::shared_ptr< DTensorVar > sampled_y_idx, hitk; 32 | }; 33 | 34 | #endif -------------------------------------------------------------------------------- /code_qa_txt/include/knowledge_base.h: -------------------------------------------------------------------------------- 1 | #ifndef KNOWLEDGE_BASE_H 2 | #define KNOWLEDGE_BASE_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | class Node 10 | { 11 | public: 12 | 13 | Node(std::string _name, int _idx); 14 | 15 | void AddNeighbor(int rel_type, Node* y, int kb_idx, bool isReverse); 16 | // adj_list holds (relation type, neighbor node) pairs; BfsPathEmbed reads edge.first as the relation and edge.second as the neighbor. 17 | std::string name; 18 | int idx; 19 | std::vector< int > word_idx_list; 20 | std::vector< std::pair< int, Node* > > adj_list; 21 | std::vector< std::pair< int, bool > > triplet_info; 22 | }; 23 | 24 | struct Sample; 25 | class SubGraph 26 | { 27 | public: 28 | SubGraph(Sample* _sample); 29 | 30 | std::vector< Node* > nodes; 31 | std::vector< std::tuple< int, int, int > > edges; 32 | Sample* sample; 33 | 34 | protected: 35 | void AddNode(Node* node); 36 | void AddEdge(Node* src, int r_type, Node* dst); 37 | std::map node_map; 38 | std::set< std::tuple > edge_set; 39 | void BFS(Node* root); 40 | }; 41 | 42 |
class KnowledgeBase 43 | { 44 | public: 45 | 46 | KnowledgeBase(); 47 | 48 | void ParseKnowledgeFile(); 49 | 50 | void ParseEntityInAnswers(const char* suffix); 51 | 52 | Node* GetOrAddNode(std::string name); 53 | 54 | Node* GetNodeOrDie(std::string name); 55 | 56 | int NodeIdx(std::string name); 57 | 58 | std::map node_dict; 59 | std::vector node_list; 60 | int n_knowledges; 61 | }; 62 | 63 | #endif -------------------------------------------------------------------------------- /code_qa_txt/include/net_latent_y.h: -------------------------------------------------------------------------------- 1 | #ifndef NET_LATENT_Y_H 2 | #define NET_LATENT_Y_H 3 | 4 | #include "inet.h" 5 | #include "tensor/tensor_all.h" 6 | #include "bfs_path_embed.h" 7 | #include "graph_inner_product.h" 8 | #include "node_select.h" 9 | #include "var_sample.h" 10 | 11 | #include 12 | #include 13 | 14 | using namespace gnn; 15 | // Network selected with net_type=NetLatentY in run.sh (joint training) and eval.sh. 16 | class NetLatentY : public INet 17 | { 18 | public: 19 | NetLatentY(); 20 | virtual void BuildNet() override; 21 | virtual void BuildBatchGraph(std::vector< Sample* >& mini_batch, Phase phase) override; 22 | 23 | std::shared_ptr< DTensorVar > GetCritic(std::shared_ptr< SpTensorVar > q_bow); 24 | 25 | std::shared_ptr< DTensorVar > GetMatchScores(std::shared_ptr< DTensorVar >& q_embed, 26 | std::shared_ptr& samples, 27 | std::shared_ptr< DTensorVar >& rel_embed, 28 | std::shared_ptr< DTensorVar >& w_recur, 29 | std::shared_ptr< VectorVar >& start_nodes); 30 | 31 | std::vector answer_dst_nodes; 32 | }; 33 | 34 | #endif -------------------------------------------------------------------------------- /code_qa_txt/include/net_multihop.h: -------------------------------------------------------------------------------- 1 | #ifndef NET_MULTI_H 2 | #define NET_MULTI_H 3 | 4 | #include "inet.h" 5 | #include "tensor/tensor_all.h" 6 | #include "bfs_path_embed.h" 7 | #include "graph_inner_product.h" 8 | #include "node_select.h" 9 | #include "var_sample.h" 10 | 11 | #include 12 |
#include 13 | 14 | using namespace gnn; 15 | // Network selected with net_type=NetMultiHop in init_run.sh (pretraining). 16 | class NetMultiHop : public INet 17 | { 18 | public: 19 | NetMultiHop(); 20 | virtual void BuildNet() override; 21 | virtual void BuildBatchGraph(std::vector< Sample* >& mini_batch, Phase phase) override; 22 | 23 | std::shared_ptr< DTensorVar > GetCritic(std::shared_ptr< SpTensorVar > q_bow); 24 | 25 | std::shared_ptr< DTensorVar > GetMatchScores(std::shared_ptr< DTensorVar >& q_embed, 26 | std::shared_ptr& samples, 27 | std::shared_ptr< DTensorVar >& rel_embed, 28 | std::shared_ptr< DTensorVar >& w_recur, 29 | std::shared_ptr< VectorVar >& start_nodes); 30 | 31 | std::vector answer_dst_nodes; 32 | }; 33 | 34 | #endif -------------------------------------------------------------------------------- /code_qa_txt/include/node_select.h: -------------------------------------------------------------------------------- 1 | #ifndef NODE_SELECT_H 2 | #define NODE_SELECT_H 3 | 4 | #include "util/gnn_macros.h" 5 | #include "nn/factor.h" 6 | #include "nn/variable.h" 7 | #include "util/fmt.h" 8 | #include "var_sample.h" 9 | 10 | class Node; 11 | 12 | namespace gnn 13 | { 14 | 15 | class NodeSelect : public Factor 16 | { 17 | public: 18 | static std::string StrType() 19 | { 20 | return "NodeSelect"; 21 | } 22 | 23 | using OutType = std::shared_ptr< VectorVar >; 24 | 25 | OutType CreateOutVar() 26 | { 27 | auto out_name = fmt::sprintf("%s:out_0", this->name); 28 | return std::make_shared< VectorVar >(out_name); 29 | } 30 | 31 | NodeSelect(std::string _name); 32 | // Only Forward is overridden; no Backward is declared for this factor. 33 | virtual void Forward(std::vector< std::shared_ptr >& operands, 34 | std::vector< std::shared_ptr >& outputs, 35 | Phase phase) override; 36 | 37 | private: 38 | void SetupNodes(const int len, const int* idxes, std::vector& dst); 39 | }; 40 | 41 | } 42 | 43 | #endif 44 | -------------------------------------------------------------------------------- /code_qa_txt/include/util.h: -------------------------------------------------------------------------------- 1 |
#ifndef UTIL_H 2 | #define UTIL_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | void str_split(const std::string &s, char delim, std::vector &result); 11 | 12 | void str_replace(std::string& str, std::string p, std::string q); 13 | // trim from start. Each char is passed to std::isspace as unsigned char (a negative char value is undefined behavior), and a lambda replaces std::ptr_fun/std::not1, which are deprecated and removed in C++17. 14 | inline std::string& ltrim(std::string &s) { 15 | s.erase(s.begin(), std::find_if(s.begin(), s.end(), 16 | [](unsigned char c) { return !std::isspace(c); })); 17 | return s; 18 | } 19 | 20 | // trim from end (same predicate as ltrim) 21 | inline std::string &rtrim(std::string &s) { 22 | s.erase(std::find_if(s.rbegin(), s.rend(), 23 | [](unsigned char c) { return !std::isspace(c); }).base(), s.end()); 24 | return s; 25 | } 26 | 27 | // trim from both ends 28 | inline std::string &trim(std::string &s) { 29 | return ltrim(rtrim(s)); 30 | } 31 | 32 | #endif -------------------------------------------------------------------------------- /code_qa_txt/include/var_sample.h: -------------------------------------------------------------------------------- 1 | #ifndef VAR_SAMPLE_H 2 | #define VAR_SAMPLE_H 3 | 4 | #include "nn/variable.h" 5 | #include "dataset.h" 6 | 7 | namespace gnn 8 | { 9 | 10 | class SampleVar : public Variable 11 | { 12 | public: 13 | SampleVar(std::string _name); 14 | 15 | virtual EleType GetEleType() override; 16 | 17 | virtual MatMode GetMode() override; 18 | 19 | virtual void SetRef(void* p) override; 20 | 21 | std::vector* samples; 22 | }; 23 | 24 | template 25 | class VectorVar : public Variable 26 | { 27 | public: 28 | VectorVar(std::string _name); 29 | 30 | virtual EleType GetEleType() override; 31 | 32 | virtual MatMode GetMode() override; 33 | 34 | virtual void SetRef(void* p) override; 35 | 36 | std::vector* vec; 37 | }; 38 | 39 | } 40 | 41 | #endif -------------------------------------------------------------------------------- /code_qa_txt/init_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | make 3 | 4 | nhop_subg=1 5 | init_pct=0.05 6 | dataset=vanilla 7 |
data_root=../metaQA 8 | net_type=NetMultiHop 9 | init_idx_file=$data_root/$nhop_subg-hop/init_index_${init_pct}_qa_train.txt 10 | # init_idx_file presumably lists the ${init_pct} fraction of training questions used for pretraining (5% per the README). 11 | result_root=$HOME/scratch/results/graph_mem/nhop-$nhop_subg/$dataset 12 | 13 | num_neg=10000 14 | max_bp_iter=1 15 | max_q_iter=3 16 | batch_size=128 17 | n_hidden=64 18 | n_embed=256 19 | margin=0.1 20 | learning_rate=0.01 21 | max_iter=4000000 22 | cur_iter=0 23 | w_scale=0.01 24 | loss_type=cross_entropy 25 | save_dir=$result_root/embed-$n_embed 26 | 27 | if [ ! -e $save_dir ]; 28 | then 29 | mkdir -p $save_dir 30 | fi 31 | # -int_save 100 saves every 100 iterations into $save_dir; run.sh (cur_iter=1000) presumably resumes from one of these dumps. 32 | ./build/main \ 33 | -init_idx_file $init_idx_file \ 34 | -num_neg $num_neg \ 35 | -loss_type $loss_type \ 36 | -data_root $data_root \ 37 | -dataset $dataset \ 38 | -n_hidden $n_hidden \ 39 | -nhop_subg $nhop_subg \ 40 | -lr $learning_rate \ 41 | -max_bp_iter $max_bp_iter \ 42 | -net_type $net_type \ 43 | -max_q_iter $max_q_iter \ 44 | -margin $margin \ 45 | -max_iter $max_iter \ 46 | -svdir $save_dir \ 47 | -embed $n_embed \ 48 | -batch_size $batch_size \ 49 | -m 0.9 \ 50 | -l2 0.00 \ 51 | -w_scale $w_scale \ 52 | -int_report 10 \ 53 | -int_test 10000 \ 54 | -int_save 100 \ 55 | -cur_iter $cur_iter \ 56 | 2>&1 | tee $save_dir/log-${net_type}.txt 57 | -------------------------------------------------------------------------------- /code_qa_txt/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | make 3 | # Joint training, starting from the pretrained dump at iteration $cur_iter (produced by init_run.sh, per the README). 4 | nhop_subg=1 5 | dataset=vanilla 6 | data_root=../metaQA 7 | net_type=NetLatentY 8 | 9 | result_root=$HOME/scratch/results/graph_mem/nhop-$nhop_subg/$dataset 10 | 11 | num_neg=10000 12 | max_bp_iter=1 13 | max_q_iter=3 14 | batch_size=128 15 | n_hidden=64 16 | n_embed=256 17 | margin=0.1 18 | learning_rate=0.001 19 | max_iter=4000000 20 | cur_iter=1000 21 | w_scale=0.01 22 | loss_type=cross_entropy 23 | save_dir=$result_root/embed-$n_embed 24 | 25 | if [ !
-e $save_dir ]; 26 | then 27 | mkdir -p $save_dir 28 | fi 29 | 30 | ./build/main \ 31 | -num_neg $num_neg \ 32 | -loss_type $loss_type \ 33 | -data_root $data_root \ 34 | -dataset $dataset \ 35 | -n_hidden $n_hidden \ 36 | -nhop_subg $nhop_subg \ 37 | -lr $learning_rate \ 38 | -max_bp_iter $max_bp_iter \ 39 | -net_type $net_type \ 40 | -max_q_iter $max_q_iter \ 41 | -margin $margin \ 42 | -max_iter $max_iter \ 43 | -svdir $save_dir \ 44 | -embed $n_embed \ 45 | -batch_size $batch_size \ 46 | -m 0.9 \ 47 | -l2 0.00 \ 48 | -w_scale $w_scale \ 49 | -int_report 1 \ 50 | -int_test 100 \ 51 | -int_save 100 \ 52 | -cur_iter $cur_iter \ 53 | 2>&1 | tee $save_dir/log-${net_type}.txt 54 | -------------------------------------------------------------------------------- /code_qa_txt/src/lib/bfs_path_embed.cpp: -------------------------------------------------------------------------------- 1 | #include "bfs_path_embed.h" 2 | #include "var_sample.h" 3 | #include "global.h" 4 | 5 | namespace gnn 6 | { 7 | // Constructor: allocates the input-frontier vectors that the first-hop (4-operand) Forward path fills. 8 | template 9 | BfsPathEmbed::BfsPathEmbed(std::string _name, PropErr _properr) 10 | : Factor(_name, _properr) 11 | { 12 | ptr_in_info = new std::vector(); 13 | ptr_in_cnt = new std::vector(); 14 | } 15 | // GetOutInfo: expands each sample's input nodes by one hop. Reached nodes are de-duplicated within a sample (not across samples), and for each one the input positions and relation types that reach it are recorded. 16 | template 17 | void BfsPathEmbed::GetOutInfo(std::vector& out_node_info, std::vector& out_num_node) 18 | { 19 | auto& in_node_info = *ptr_in_info; 20 | auto& in_num_node = *ptr_in_cnt; 21 | 22 | src_positions.clear(); 23 | rel_types.clear(); 24 | // in_node_info is laid out sample by sample: in_num_node[i] consecutive entries belong to sample i. 25 | int cur_in_pos = 0; 26 | for (size_t i = 0; i < in_num_node.size(); ++i) 27 | { 28 | std::map next_nodes; 29 | for (int j = 0; j < in_num_node[i]; ++j, ++cur_in_pos) 30 | { 31 | auto* src_node = in_node_info[cur_in_pos]; 32 | 33 | for (auto& edge : src_node->adj_list) 34 | { 35 | if (!next_nodes.count(edge.second->idx)) 36 | { 37 | next_nodes[edge.second->idx] = out_node_info.size(); 38 | out_node_info.push_back(edge.second); 39 | src_positions.push_back(std::set()); 40 | rel_types.push_back(std::set()); 41 | } 42 | auto
out_pos = next_nodes[edge.second->idx]; 43 | src_positions[out_pos].insert(cur_in_pos); 44 | rel_types[out_pos].insert(edge.first); 45 | } 46 | } 47 | out_num_node[i] = next_nodes.size(); 48 | } 49 | assert(cur_in_pos == (int)in_node_info.size()); 50 | } 51 | // ConstructRelSp: CSR 0/1 matrix of shape (#reached nodes, rel_nums) with a 1 at every relation type that reaches each node. 52 | template 53 | void BfsPathEmbed::ConstructRelSp(std::vector& out_node_info, std::vector& out_num_node, size_t rel_nums) 54 | { 55 | uint nnz_rel = 0; 56 | for (size_t i = 0; i < rel_types.size(); ++i) 57 | nnz_rel += rel_types[i].size(); 58 | cpu_rel_mat.Reshape({out_node_info.size(), rel_nums}); 59 | cpu_rel_mat.ResizeSp(nnz_rel, out_node_info.size() + 1); 60 | // The first pass above only counted non-zeros; this pass fills row_ptr / col_idx / val row by row. 61 | nnz_rel = 0; 62 | for (size_t i = 0; i < out_node_info.size(); ++i) 63 | { 64 | cpu_rel_mat.data->row_ptr[i] = nnz_rel; 65 | for (auto rel : rel_types[i]) 66 | { 67 | cpu_rel_mat.data->val[nnz_rel] = 1.0; 68 | cpu_rel_mat.data->col_idx[nnz_rel] = rel; 69 | nnz_rel += 1; 70 | } 71 | } 72 | cpu_rel_mat.data->row_ptr[out_node_info.size()] = nnz_rel; 73 | assert((int)nnz_rel == cpu_rel_mat.data->nnz); 74 | rel_mat.CopyFrom(cpu_rel_mat); 75 | } 76 | // ConstructNodeSp: CSR matrix of shape (#reached nodes, num_in); row i averages the k input nodes that reach node i (each weight is 1/k). 77 | template 78 | void BfsPathEmbed::ConstructNodeSp(std::vector& out_node_info, std::vector& out_num_node, size_t num_in) 79 | { 80 | uint nnz_node = 0; 81 | for (size_t i = 0; i < src_positions.size(); ++i) 82 | nnz_node += src_positions[i].size(); 83 | cpu_node_mat.Reshape({out_node_info.size(), num_in}); 84 | cpu_node_mat.ResizeSp(nnz_node, out_node_info.size() + 1); 85 | 86 | nnz_node = 0; 87 | for (size_t i = 0; i < out_node_info.size(); ++i) 88 | { 89 | cpu_node_mat.data->row_ptr[i] = nnz_node; 90 | for (auto pos : src_positions[i]) 91 | { 92 | cpu_node_mat.data->val[nnz_node] = 1.0 / src_positions[i].size(); 93 | cpu_node_mat.data->col_idx[nnz_node] = pos; 94 | nnz_node += 1; 95 | } 96 | } 97 | cpu_node_mat.data->row_ptr[out_node_info.size()] = nnz_node; 98 | assert((int)nnz_node == cpu_node_mat.data->nnz); 99 | node_mat.CopyFrom(cpu_node_mat); 100 | } 101 | // Forward: output = rel_mat * rel_embed; with 6 operands it additionally adds node_mat * (prev_embed * w_path_recur). 102 | template 103 | void
BfsPathEmbed::Forward(std::vector< std::shared_ptr >& operands, 104 | std::vector< std::shared_ptr >& outputs, 105 | Phase phase) 106 | { 107 | ASSERT(operands.size() == 6 || operands.size() == 4, "unexpected input size for " << StrType()); 108 | ASSERT(outputs.size() == 3, "unexpected output size for " << StrType()); 109 | // Operand layout: [0] samples, [1] relation embeddings, [2] recurrent path weight; then [3] start nodes (4 operands) or [3] previous-hop embedding with [4]/[5] its nodes/counts (6 operands). 110 | auto& samples = *(dynamic_cast< SampleVar* >(operands[0].get())->samples); 111 | auto& rel_embed = dynamic_cast*>(operands[1].get())->value; 112 | auto& w_path_recur = dynamic_cast*>(operands[2].get())->value; 113 | 114 | auto& output = dynamic_cast*>(outputs[0].get())->value; 115 | auto* var_node_info = dynamic_cast< VectorVar* >(outputs[1].get()); 116 | auto* var_num_node = dynamic_cast< VectorVar* >(outputs[2].get()); 117 | if (!var_node_info->vec) 118 | var_node_info->vec = new std::vector(); 119 | if (!var_num_node->vec) 120 | var_num_node->vec = new std::vector(); 121 | auto& out_node_info = *(var_node_info->vec); 122 | auto& out_num_node = *(var_num_node->vec); 123 | // First hop: every sample starts from its single start node. Later hops reuse the previous hop's output vectors; NOTE(review): that overwrites the pointers allocated in the constructor, which nothing visible here frees. 124 | if ((int)operands.size() == 4) 125 | { 126 | auto& st_nodes = *(dynamic_cast< VectorVar* >(operands[3].get())->vec); 127 | 128 | ptr_in_cnt->resize(samples.size()); 129 | ptr_in_info->clear(); 130 | for (size_t i = 0; i < samples.size(); ++i) 131 | { 132 | assert(st_nodes[i]); 133 | ptr_in_info->push_back(st_nodes[i]); 134 | (*ptr_in_cnt)[i] = 1; 135 | } 136 | } else { 137 | ptr_in_info = dynamic_cast< VectorVar* >(operands[4].get())->vec; 138 | ptr_in_cnt = dynamic_cast< VectorVar* >(operands[5].get())->vec; 139 | } 140 | 141 | out_node_info.clear(); 142 | out_num_node.resize(samples.size()); 143 | 144 | GetOutInfo(out_node_info, out_num_node); 145 | 146 | ConstructRelSp(out_node_info, out_num_node, rel_embed.rows()); 147 | output.MM(rel_mat, rel_embed, Trans::N, Trans::N, 1.0, 0.0); 148 | // Later hops add the averaged, linearly transformed embeddings of each node's predecessors. 149 | if ((int)operands.size() == 6) 150 | { 151 | auto& prev_embed = dynamic_cast*>(operands[3].get())->value; 152 | node_trans.MM(prev_embed, w_path_recur, Trans::N,
Trans::N, 1.0, 0.0); 153 | 154 | ConstructNodeSp(out_node_info, out_num_node, ptr_in_info->size()); 155 | output.MM(node_mat, node_trans, Trans::N, Trans::N, 1.0, 1.0); 156 | } 157 | } 158 | 159 | template 160 | void BfsPathEmbed::Backward(std::vector< std::shared_ptr >& operands, 161 | std::vector< bool >& isConst, 162 | std::vector< std::shared_ptr >& outputs) 163 | { 164 | ASSERT(operands.size() == 6 || operands.size() == 4, "unexpected input size for " << StrType()); 165 | ASSERT(outputs.size() == 3, "unexpected output size for " << StrType()); 166 | 167 | auto grad_out = dynamic_cast*>(outputs[0].get())->grad.Full(); 168 | 169 | auto rel_grad = dynamic_cast*>(operands[1].get())->grad.Full(); 170 | 171 | if ((int)operands.size() == 4) // first hop 172 | { 173 | rel_grad.MM(rel_mat, grad_out, Trans::T, Trans::N, 1.0, 1.0); 174 | } else { 175 | node_grad.MM(node_mat, grad_out, Trans::T, Trans::N, 1.0, 0.0); 176 | 177 | auto& prev_embed = dynamic_cast*>(operands[3].get())->value; 178 | auto prev_grad = dynamic_cast*>(operands[3].get())->grad.Full(); 179 | auto& w_path_recur = dynamic_cast*>(operands[2].get())->value; 180 | auto w_grad = dynamic_cast*>(operands[2].get())->grad.Full(); 181 | 182 | prev_grad.MM(node_grad, w_path_recur, Trans::N, Trans::T, 1.0, 1.0); 183 | w_grad.MM(prev_embed, node_grad, Trans::T, Trans::N, 1.0, 1.0); 184 | } 185 | } 186 | 187 | INSTANTIATE_CLASS(BfsPathEmbed) 188 | 189 | } 190 | -------------------------------------------------------------------------------- /code_qa_txt/src/lib/config.cpp: -------------------------------------------------------------------------------- 1 | #include "config.h" 2 | 3 | int cfg::iter = 0; 4 | int cfg::max_bp_iter = 1; 5 | int cfg::max_q_iter = 1; 6 | int cfg::nhop_subg = 1; 7 | int cfg::num_neg = 100; 8 | int cfg::n_hidden = 64; 9 | int cfg::test_tpok = 5; 10 | bool cfg::test_only = false; 11 | bool cfg::vis_score = false; 12 | unsigned cfg::n_embed = 64; 13 | unsigned cfg::max_iter = 0; 14 | unsigned 
cfg::dev_id = 0; 15 | unsigned cfg::batch_size = 32; 16 | unsigned cfg::test_interval = 10000; 17 | unsigned cfg::report_interval = 100; 18 | unsigned cfg::save_interval = 50000; 19 | Dtype cfg::lr = 0.0005; 20 | Dtype cfg::margin = 0.1; 21 | Dtype cfg::l2_penalty = 0; 22 | Dtype cfg::momentum = 0; 23 | Dtype cfg::w_scale = 0.01; 24 | Dtype cfg::p_pos = 0.5; 25 | const char* cfg::save_dir = "./saved"; 26 | const char* cfg::dataset = nullptr; 27 | const char* cfg::net_type = nullptr; 28 | const char* cfg::data_root = nullptr; 29 | const char* cfg::loss_type = nullptr; 30 | const char* cfg::init_idx_file = nullptr; 31 | -------------------------------------------------------------------------------- /code_qa_txt/src/lib/dataset.cpp: -------------------------------------------------------------------------------- 1 | #include "dataset.h" 2 | #include "util.h" 3 | #include 4 | #include 5 | #include 6 | #include "global.h" 7 | /* Sample: starts with empty word, entity and answer lists. */ 8 | Sample::Sample() 9 | { 10 | q_word_list.clear(); 11 | q_entities.clear(); 12 | answer_entities.clear(); 13 | q_side_word_list.clear(); 14 | } 15 | /* Dataset: starts with no samples. */ 16 | Dataset::Dataset() 17 | { 18 | orig_samples.clear(); 19 | split_samples.clear(); 20 | idxes.clear(); 21 | } 22 | /* Load: reads <data_root>/<nhop_subg>-hop/<dataset>/qa_<suffix>.txt, one "question<TAB>answer1|answer2|..." per line, with topic entities written as [name] inside the question. For suffix "train" with cfg::init_idx_file set, only the 0-based line numbers listed in that file are kept. Questions without a recognised KB entity are skipped. Each kept question is stored in orig_samples and also split into one single-answer Sample per answer in split_samples. */ 23 | void Dataset::Load(const char* suffix) 24 | { 25 | auto filename = fmt::sprintf("%s/%d-hop/%s/qa_%s.txt", 26 | cfg::data_root, cfg::nhop_subg, cfg::dataset, suffix); 27 | std::set train_idxes; 28 | if (!strcmp(suffix, "train") && cfg::init_idx_file) 29 | { 30 | std::ifstream fin(cfg::init_idx_file); 31 | int idx; 32 | while (fin >> idx) 33 | { 34 | train_idxes.insert(idx); 35 | } 36 | } 37 | std::ifstream fin(filename); 38 | if (!fin) std::cerr << "cannot open " << filename << std::endl; /* a bad path would otherwise silently produce an empty dataset */ 39 | std::string st; 40 | std::vector buf; 41 | int s_idx = 0, line_num = 0; 42 | std::map n_q; 43 | while (std::getline(fin, st)) 44 | { 45 | line_num++; 46 | if (train_idxes.size() && !train_idxes.count(line_num - 1)) 47 | continue; 48 | str_split(st, '\t', buf); 49 | assert(buf.size() == 2); 50 | 51 | auto* cur_sample = new Sample(); 52 | auto q = buf[0],
a = buf[1]; 53 | 54 | size_t pos = 0; 55 | while (pos < q.size()) 56 | { 57 | if (q[pos] == '[') 58 | { 59 | auto ed = pos + 1; 60 | while (ed < q.size() && q[ed] != ']') 61 | ed++; 62 | auto e_name = q.substr(pos + 1, ed - pos - 1); 63 | if (kb.node_dict.count(e_name)) 64 | cur_sample->q_entities.push_back(kb.node_dict[e_name]); 65 | pos = ed; 66 | } 67 | pos++; 68 | } 69 | if (cur_sample->q_entities.size() == 0) /* no recognised topic entity: drop the question, freeing the sample instead of leaking it */ 70 | { delete cur_sample; continue; } 71 | if (!n_q.count(cur_sample->q_entities.size())) 72 | n_q[cur_sample->q_entities.size()] = 0; 73 | n_q[cur_sample->q_entities.size()]++; 74 | assert(cur_sample->q_entities.size() <= 1); // at most have one entity in query 75 | 76 | cur_sample->s_idx = s_idx; 77 | s_idx++; 78 | str_replace(q, "[", ""); 79 | str_replace(q, "]", ""); 80 | 81 | str_split(q, ' ', buf); 82 | for (auto w : buf) 83 | { 84 | if (word_dict.count(w)) 85 | cur_sample->q_word_list.push_back(word_dict[w]); 86 | if (side_word_dict.count(w)) 87 | cur_sample->q_side_word_list.push_back(side_word_dict[w]); 88 | } 89 | 90 | str_split(a, '|', buf); 91 | for (auto e : buf) 92 | { 93 | cur_sample->answer_entities.push_back(kb.GetNodeOrDie(e)); 94 | } 95 | 96 | orig_samples.push_back(cur_sample); 97 | 98 | for (auto e : cur_sample->answer_entities) 99 | { 100 | auto* s = new Sample(); 101 | s->q_word_list = cur_sample->q_word_list; 102 | s->q_side_word_list = cur_sample->q_side_word_list; 103 | s->q_entities = cur_sample->q_entities; 104 | s->answer_entities.push_back(e); 105 | split_samples.push_back(s); 106 | } 107 | } 108 | std::cerr << suffix << " has " << orig_samples.size() << " samples" << " and split into " << split_samples.size() << std::endl; 109 | for (auto p : n_q) 110 | std::cerr << "n_q: " << p.first << " # samples: " << p.second << std::endl; 111 | } 112 | /* SetupStream: rewinds the read cursor; when randomized, shuffles both sample lists now (GetData reshuffles again whenever a pass runs out). */ 113 | void Dataset::SetupStream(bool randomized) 114 | { 115 | this->randomized = randomized; 116 | cur_pos = 0; 117 | 118 | if (randomized) 119 | { 120 | std::random_shuffle(orig_samples.begin(),
orig_samples.end()); 121 | std::random_shuffle(split_samples.begin(), split_samples.end()); 122 | } 123 | } 124 | /* GetData: copies the next batch_size samples into mini_batch and advances cur_pos; returns false when nothing is left. Without randomization the last batch of a pass may be short; with randomization a short tail is never returned - the list is reshuffled and reading restarts at 0. NOTE(review): cur_pos is shared by GetMiniBatch and GetSplitMiniBatch. */ 125 | bool Dataset::GetData(int batch_size, std::vector< Sample* >& mini_batch, std::vector< Sample* >& samples) 126 | { 127 | if (cur_pos + batch_size > (int)samples.size() && randomized) 128 | { 129 | std::random_shuffle(samples.begin(), samples.end()); 130 | cur_pos = 0; 131 | } 132 | if (cur_pos + batch_size > (int)samples.size()) 133 | batch_size = samples.size() - cur_pos; 134 | if (batch_size <= 0) 135 | return false; 136 | 137 | mini_batch.resize(batch_size); 138 | for (int i = cur_pos; i < cur_pos + batch_size; ++i) 139 | { 140 | mini_batch[i - cur_pos] = samples[i]; 141 | } 142 | 143 | cur_pos += batch_size; 144 | return true; 145 | } 146 | /* GetMiniBatch: next batch of original (possibly multi-answer) questions. */ 147 | bool Dataset::GetMiniBatch(int batch_size, std::vector< Sample* >& mini_batch) 148 | { 149 | return GetData(batch_size, mini_batch, orig_samples); 150 | } 151 | /* GetSplitMiniBatch: next batch of single-answer split samples. */ 152 | bool Dataset::GetSplitMiniBatch(int batch_size, std::vector< Sample* >& mini_batch) 153 | { 154 | return GetData(batch_size, mini_batch, split_samples); 155 | } 156 | -------------------------------------------------------------------------------- /code_qa_txt/src/lib/dict.cpp: -------------------------------------------------------------------------------- 1 | #include "dict.h" 2 | #include "util.h" 3 | #include 4 | /* GetRelations: maps the ten fixed relation names below to ids 0..9 in listed order and logs them to stderr. */ 5 | std::map GetRelations() 6 | { 7 | auto relations = {"directed_by", 8 | "has_genre", 9 | "has_imdb_rating", 10 | "has_imdb_votes", 11 | "has_plot", 12 | "has_tags", 13 | "in_language", 14 | "release_year", 15 | "starred_actors", 16 | "written_by"}; 17 | std::map result; 18 | result.clear(); 19 | int t = 0; 20 | for (auto& st : relations) 21 | { 22 | result[st] = t; 23 | t++; 24 | } 25 | 26 | std::cerr << "num relations: " << result.size() << std::endl; 27 | std::cerr << "["; 28 | for (auto i : relations) 29 | std::cerr << " " << i; 30 | std::cerr << " ]" << std::endl; 31 | 32 | return result; 33 | } 34 | 35 | std::map
GetVocab() 36 | { 37 | std::map word_dict; 38 | word_dict.clear(); 39 | /* GetVocab: word -> id over every space-separated token of the train/test/dev questions (brackets removed) and of the head and tail entity names in kb.txt; ids follow first occurrence. */ 40 | for (auto suffix : {"train", "test", "dev"}) 41 | { 42 | auto file = fmt::format("{0}/{1}-hop/{2}/qa_{3}.txt", 43 | cfg::data_root, cfg::nhop_subg, cfg::dataset, suffix); 44 | 45 | std::ifstream fin(file); 46 | 47 | std::string st; 48 | std::vector buf; 49 | 50 | while (std::getline(fin, st)) 51 | { 52 | str_split(st, '\t', buf); 53 | assert(buf.size() == 2); 54 | auto q = buf[0]; 55 | str_replace(q, "[", ""); 56 | str_replace(q, "]", ""); 57 | 58 | str_split(q, ' ', buf); 59 | for (auto w : buf) 60 | { 61 | if (word_dict.count(w) == 0) 62 | { 63 | int t = word_dict.size(); 64 | word_dict[w] = t; 65 | } 66 | } 67 | } 68 | } 69 | // words from kb 70 | auto kb_file = fmt::format("{0}/kb.txt", cfg::data_root); 71 | std::ifstream fin(kb_file); 72 | 73 | std::string st; 74 | std::vector buf, word_buf; 75 | while (std::getline(fin, st)) 76 | { 77 | str_split(st, '|', buf); 78 | assert(buf.size() == 3); 79 | 80 | str_split(buf[0], ' ', word_buf); 81 | for (auto& w : word_buf) 82 | { 83 | if (word_dict.count(w) == 0) 84 | { 85 | int t = word_dict.size(); 86 | word_dict[w] = t; 87 | } 88 | } 89 | 90 | str_split(buf[2], ' ', word_buf); 91 | for (auto& w : word_buf) 92 | { 93 | if (word_dict.count(w) == 0) 94 | { 95 | int t = word_dict.size(); 96 | word_dict[w] = t; 97 | } 98 | } 99 | } 100 | 101 | std::cerr << "size of vocab: " << word_dict.size() << std::endl; 102 | return word_dict; 103 | } 104 | /* GetSideWordDict: word -> id over the training-question tokens that lie outside [entity] spans, skipping the token "1"; ids follow first occurrence. */ 105 | std::map GetSideWordDict() 106 | { 107 | auto file = fmt::format("{0}/{1}-hop/{2}/qa_train.txt", 108 | cfg::data_root, cfg::nhop_subg, cfg::dataset); 109 | std::ifstream fin(file); 110 | 111 | std::string st; 112 | std::vector buf; 113 | 114 | std::map word_dict; 115 | word_dict.clear(); 116 | 117 | while (std::getline(fin, st)) 118 | { 119 | str_split(st, '\t', buf); 120 | assert(buf.size() == 2); 121 | auto q = buf[0]; 122 | str_split(q, ' ', buf); 123 | bool in_entity = false; 124 | for (auto w :
buf) 125 | { 126 | if (w == "1") 127 | continue; 128 | if (w[0] == '[') 129 | in_entity = true; 130 | if (in_entity) 131 | { 132 | if (!w.empty() && w[w.size() - 1] == ']') /* guard: for an empty token w.size() - 1 wraps and the index is out of range */ 133 | in_entity = false; 134 | } else { 135 | if (word_dict.count(w) == 0) 136 | { 137 | int t = word_dict.size(); 138 | word_dict[w] = t; 139 | } 140 | } 141 | } 142 | } 143 | 144 | std::cerr << "size of side_vocab: " << word_dict.size() << std::endl; 145 | return word_dict; 146 | } -------------------------------------------------------------------------------- /code_qa_txt/src/lib/global.cpp: -------------------------------------------------------------------------------- 1 | #include "global.h" 2 | /* Definitions of the global objects shared by the library sources: knowledge base, vocabularies, parameter set and factor graph. */ 3 | KnowledgeBase kb; 4 | std::map side_word_dict; 5 | ParamSet model; 6 | FactorGraph fg; 7 | std::map word_dict; 8 | std::map relation_dict; 9 | -------------------------------------------------------------------------------- /code_qa_txt/src/lib/graph_inner_product.cpp: -------------------------------------------------------------------------------- 1 | #include "graph_inner_product.h" 2 | #include "var_sample.h" 3 | #include "global.h" 4 | 5 | namespace gnn 6 | { 7 | /* CPU SetVal (scatter): dst[sample_idx[i]][entity_idx[i]] = src[i] for every element of src. */ 8 | template 9 | void SetVal(DTensor& src, int* entity_idx, int* sample_idx, DTensor& dst) 10 | { 11 | for (size_t i = 0; i < src.shape.Count(); ++i) 12 | { 13 | auto row = sample_idx[i], col = entity_idx[i]; 14 | dst.data->ptr[row * dst.cols() + col] = src.data->ptr[i]; 15 | } 16 | } 17 | /* CPU BpError (gather, inverse of SetVal): cur_grad[i] = grad_out[sample_idx[i]][entity_idx[i]]. */ 18 | template 19 | void BpError(DTensor& grad_out, int* entity_idx, int* sample_idx, DTensor& cur_grad) 20 | { 21 | for (size_t i = 0; i < cur_grad.shape.Count(); ++i) 22 | { 23 | auto row = sample_idx[i], col = entity_idx[i]; 24 | cur_grad.data->ptr[i] = grad_out.data->ptr[row * grad_out.cols() + col]; 25 | } 26 | } 27 | /* num_entities is the column count of the dense score output built in Forward. */ 28 | template 29 | GraphInnerProduct::GraphInnerProduct(std::string _name, int _num_entities, PropErr _properr) 30 | : Factor(_name, _properr), num_entities(_num_entities) 31 | { 32 | 33 | } 34 | 35 | template 36 | void
GraphInnerProduct::Forward(std::vector< std::shared_ptr >& operands, 37 | std::vector< std::shared_ptr >& outputs, 38 | Phase phase) 39 | { 40 | ASSERT(operands.size() == 4, "unexpected input size for " << StrType()); 41 | ASSERT(outputs.size() == 1, "unexpected output size for " << StrType()); 42 | /* Forward: operands are (q_embed, ans_embed, node info, per-sample node counts). Row i of q_embed is dotted with the sample_nodecnt[i] consecutive ans_embed rows of sample i's candidate nodes; the scores are then scattered into a zeroed output of shape (q_embed.rows() x num_entities) at column node->idx, so entities not reached keep score 0. */ 43 | auto& q_embed = dynamic_cast*>(operands[0].get())->value; 44 | auto& ans_embed = dynamic_cast*>(operands[1].get())->value; 45 | auto& node_info = *(dynamic_cast< VectorVar* >(operands[2].get())->vec); 46 | auto& sample_nodecnt = *(dynamic_cast< VectorVar* >(operands[3].get())->vec); 47 | 48 | entity_idx.resize(node_info.size()); 49 | sample_idx.resize(node_info.size()); 50 | 51 | tmp_out.Reshape({node_info.size(), (size_t)1}); 52 | size_t row_idx = 0; 53 | for (size_t i = 0; i < sample_nodecnt.size(); ++i) 54 | { 55 | auto row_cnt = sample_nodecnt[i]; 56 | auto cur_q = q_embed.GetRowRef(i, 1); 57 | auto cur_ans = ans_embed.GetRowRef(row_idx, row_cnt); 58 | auto cur_out = tmp_out.GetRowRef(row_idx, row_cnt); 59 | 60 | cur_out.MM(cur_ans, cur_q, Trans::N, Trans::T, 1.0, 0.0); 61 | 62 | for (size_t j = row_idx; j < row_idx + row_cnt; ++j) 63 | sample_idx[j] = i; 64 | row_idx += row_cnt; 65 | } 66 | assert(row_idx == node_info.size()); 67 | 68 | for (size_t i = 0; i < node_info.size(); ++i) 69 | entity_idx[i] = node_info[i]->idx; 70 | /* Hand the index arrays to SetVal/BpError: the host vectors in CPU mode, otherwise the copies held in gpu_entity_idx/gpu_sample_idx. */ 71 | #ifdef USE_GPU 72 | if (mode::type == MatMode::cpu) 73 | { 74 | ptr_entity = thrust::raw_pointer_cast(entity_idx.data()); 75 | ptr_sample = thrust::raw_pointer_cast(sample_idx.data()); 76 | } else { 77 | gpu_entity_idx = entity_idx; 78 | gpu_sample_idx = sample_idx; 79 | 80 | ptr_entity = thrust::raw_pointer_cast(gpu_entity_idx.data()); 81 | ptr_sample = thrust::raw_pointer_cast(gpu_sample_idx.data()); 82 | } 83 | #else 84 | ptr_entity = entity_idx.data(); 85 | ptr_sample = sample_idx.data(); 86 | #endif 87 | 88 | auto& output = dynamic_cast*>(outputs[0].get())->value; 89 | output.Reshape({q_embed.rows(), (size_t)num_entities}); 90 |
output.Zeros(); 91 | SetVal(tmp_out, ptr_entity, ptr_sample, output); 92 | } 93 | /* Backward: gathers d(output) at the cells Forward wrote back into tmp_out (reusing it as the per-candidate gradient buffer, which overwrites the forward scores) and accumulates the per-sample gradients into q_embed and ans_embed. NOTE(review): isConst is not consulted; both operand gradients are always written. */ 94 | template 95 | void GraphInnerProduct::Backward(std::vector< std::shared_ptr >& operands, 96 | std::vector< bool >& isConst, 97 | std::vector< std::shared_ptr >& outputs) 98 | { 99 | ASSERT(operands.size() == 4, "unexpected input size for " << StrType()); 100 | ASSERT(outputs.size() == 1, "unexpected output size for " << StrType()); 101 | auto& sample_nodecnt = *(dynamic_cast< VectorVar* >(operands[3].get())->vec); 102 | 103 | auto grad_out = dynamic_cast*>(outputs[0].get())->grad.Full(); 104 | BpError(grad_out, ptr_entity, ptr_sample, tmp_out); 105 | 106 | auto& q_embed = dynamic_cast*>(operands[0].get())->value; 107 | auto& ans_embed = dynamic_cast*>(operands[1].get())->value; 108 | 109 | auto q_grad = dynamic_cast*>(operands[0].get())->grad.Full(); 110 | auto ans_grad = dynamic_cast*>(operands[1].get())->grad.Full(); 111 | size_t row_idx = 0; 112 | for (size_t i = 0; i < sample_nodecnt.size(); ++i) 113 | { 114 | auto row_cnt = sample_nodecnt[i]; 115 | auto cur_q = q_embed.GetRowRef(i, 1); 116 | auto cur_ans = ans_embed.GetRowRef(row_idx, row_cnt); 117 | 118 | auto cur_grad = tmp_out.GetRowRef(row_idx, row_cnt); 119 | 120 | auto cur_q_grad = q_grad.GetRowRef(i, 1); 121 | auto cur_ans_grad = ans_grad.GetRowRef(row_idx, row_cnt); 122 | 123 | cur_ans_grad.MM(cur_grad, cur_q, Trans::N, Trans::N, 1.0, 1.0); 124 | cur_q_grad.MM(cur_grad, cur_ans, Trans::T, Trans::N, 1.0, 1.0); 125 | 126 | row_idx += row_cnt; 127 | } 128 | assert(row_idx == tmp_out.shape.Count()); 129 | } 130 | 131 | INSTANTIATE_CLASS(GraphInnerProduct) 132 | 133 | } 134 | -------------------------------------------------------------------------------- /code_qa_txt/src/lib/graph_inner_product.cu: -------------------------------------------------------------------------------- 1 | #include "graph_inner_product.h" 2 | #include "tensor/gpu_handle.h" 3 | #include "tensor/gpu_unary_functor.h" 4 | 5 | namespace gnn 6 | { 7 |
8 | template 9 | __global__ void SetValKernel(Dtype *dst, Dtype *src, int* entity_idx, int* sample_idx, int cols, int numElements) 10 | { 11 | int i = blockDim.x * blockIdx.x + threadIdx.x; 12 | /* one thread per element: scatter src[i] to row sample_idx[i], column entity_idx[i] of dst; surplus threads do nothing */ 13 | if (i < numElements) 14 | { 15 | dst[sample_idx[i] * cols + entity_idx[i]] = src[i]; 16 | } 17 | } 18 | /* GPU SetVal: launches SetValKernel over every element of src. The index arrays are dereferenced inside the kernel, so they must be device-accessible. */ 19 | template 20 | void SetVal(DTensor& src, int* entity_idx, int* sample_idx, DTensor& dst) 21 | { 22 | int numElements = (int)src.shape.Count(); 23 | if (numElements == 0) return; /* nothing to scatter; also avoids thread_num == 0, which would divide by zero below and is not a valid launch configuration */ 24 | int thread_num = c_uCudaThreadNum; if (numElements < thread_num) thread_num = numElements; 25 | int blocksPerGrid = (numElements + thread_num - 1) / thread_num; 26 | 27 | SetValKernel <<< blocksPerGrid, thread_num, 0, cudaStreamPerThread >>>(dst.data->ptr, src.data->ptr, entity_idx, sample_idx, dst.cols(), numElements); 28 | } 29 | 30 | template void SetVal(DTensor& src, int* entity_idx, int* sample_idx, DTensor& dst); 31 | template void SetVal(DTensor& src, int* entity_idx, int* sample_idx, DTensor& dst); 32 | /* BpErrorKernel: gather, the inverse of SetValKernel: dst[i] = src[sample_idx[i] * cols + entity_idx[i]]. */ 33 | template 34 | __global__ void BpErrorKernel(Dtype *dst, Dtype *src, int* entity_idx, int* sample_idx, int cols, int numElements) 35 | { 36 | int i = blockDim.x * blockIdx.x + threadIdx.x; 37 | 38 | if (i < numElements) 39 | { 40 | dst[i] = src[sample_idx[i] * cols + entity_idx[i]]; 41 | } 42 | } 43 | /* GPU BpError: launches BpErrorKernel over every element of cur_grad, with the same empty-input guard as SetVal. */ 44 | template 45 | void BpError(DTensor& grad_out, int* entity_idx, int* sample_idx, DTensor& cur_grad) 46 | { 47 | int numElements = (int)cur_grad.shape.Count(); if (numElements == 0) return; /* empty gather: skip the zero-thread launch and the division by zero */ 48 | int thread_num = c_uCudaThreadNum; 49 | if (numElements < thread_num) thread_num = numElements; 50 | int blocksPerGrid = (numElements + thread_num - 1) / thread_num; 51 | 52 | BpErrorKernel <<< blocksPerGrid, thread_num, 0, cudaStreamPerThread >>>(cur_grad.data->ptr, grad_out.data->ptr, entity_idx, sample_idx, grad_out.cols(), numElements); 53 | } 54 | 55 | template void BpError(DTensor& grad_out, int* entity_idx, int* sample_idx, DTensor& cur_grad); 56 | template void BpError(DTensor& grad_out, int* entity_idx, int* sample_idx, DTensor& cur_grad); 57 | 58 | }
-------------------------------------------------------------------------------- /code_qa_txt/src/lib/inet.cpp: -------------------------------------------------------------------------------- 1 | #include "inet.h" 2 | #include "knowledge_base.h" 3 | #include "dataset.h" 4 | #include "util/graph_struct.h" 5 | 6 | #include 7 | #include "global.h" 8 | /* INet: builds the constant sparse matrix entity_bow (one row per KB entity, one column per vocabulary word, 1.0 for each word id of the entity's name) and registers its copy m_entity_bow as the "entity_bow" input. Rows are sized by kb.node_dict but filled from kb.node_list[i], relying on the two having equal size. */ 9 | INet::INet() 10 | { 11 | inputs.clear(); 12 | entity_bow.Reshape({kb.node_dict.size(), word_dict.size()}); 13 | 14 | int nnz = 0; 15 | for (auto& p : kb.node_dict) 16 | nnz += p.second->word_idx_list.size(); 17 | entity_bow.ResizeSp(nnz, kb.node_dict.size() + 1); 18 | 19 | nnz = 0; 20 | for (size_t i = 0; i < kb.node_list.size(); ++i) 21 | { 22 | auto* node = kb.node_list[i]; 23 | entity_bow.data->row_ptr[i] = nnz; 24 | for (size_t j = 0; j < node->word_idx_list.size(); ++j) 25 | { 26 | entity_bow.data->col_idx[nnz] = node->word_idx_list[j]; 27 | entity_bow.data->val[nnz] = 1.0; 28 | nnz++; 29 | } 30 | } 31 | entity_bow.data->row_ptr[kb.node_list.size()] = nnz; 32 | assert(nnz == entity_bow.data->nnz); 33 | 34 | m_entity_bow.CopyFrom(entity_bow); 35 | inputs["entity_bow"] = &m_entity_bow; 36 | } 37 | /* Base-class BuildBatchGraph: its whole body is commented out, so it does nothing; NetLatentY::BuildBatchGraph builds its own batch inputs. */ 38 | void INet::BuildBatchGraph(std::vector< Sample* >& mini_batch, Phase phase) 39 | { 40 | /* 41 | q_bow_input.Reshape({mini_batch.size(), side_word_dict.size()}); 42 | q_entity_input.Reshape({mini_batch.size(), kb.node_dict.size()}); 43 | ans_output.Reshape({mini_batch.size(), kb.node_dict.size()}); 44 | int nnz_q_bow = 0, nnz_q_entity = 0, nnz_ans = 0; 45 | for (auto* s : mini_batch) 46 | { 47 | nnz_q_bow += s->q_side_word_list.size(); 48 | nnz_q_entity += s->q_entities.size(); 49 | nnz_ans += s->answer_entities.size(); 50 | } 51 | q_bow_input.ResizeSp(nnz_q_bow, mini_batch.size() + 1); 52 | q_entity_input.ResizeSp(nnz_q_entity, mini_batch.size() + 1); 53 | if (phase == Phase::TEST) 54 | ans_output.ResizeSp(nnz_ans, mini_batch.size() + 1); 55 | else 56 | ans_output.ResizeSp(mini_batch.size(), mini_batch.size() + 1);
57 | 58 | nnz_q_bow = 0; nnz_q_entity = 0; nnz_ans = 0; 59 | for (int i = 0; i < (int)mini_batch.size(); ++i) 60 | { 61 | auto* sample = mini_batch[i]; 62 | q_bow_input.data->row_ptr[i] = nnz_q_bow; 63 | q_entity_input.data->row_ptr[i] = nnz_q_entity; 64 | ans_output.data->row_ptr[i] = nnz_ans; 65 | 66 | int base_idx = nnz_q_bow; 67 | for (auto e : sample->q_side_word_list) 68 | { 69 | q_bow_input.data->val[nnz_q_bow] = 1.0; 70 | q_bow_input.data->col_idx[nnz_q_bow] = e; 71 | nnz_q_bow += 1; 72 | } 73 | std::sort(q_bow_input.data->col_idx + base_idx, q_bow_input.data->col_idx + nnz_q_bow); 74 | 75 | base_idx = nnz_q_entity; 76 | for (auto e : sample->q_entities) 77 | { 78 | q_entity_input.data->val[nnz_q_entity] = 1.0; 79 | q_entity_input.data->col_idx[nnz_q_entity] = e->idx; 80 | nnz_q_entity += 1; 81 | } 82 | std::sort(q_entity_input.data->col_idx + base_idx, q_entity_input.data->col_idx + nnz_q_entity); 83 | 84 | if (phase == Phase::TEST) 85 | { 86 | base_idx = nnz_ans; 87 | for (auto e : sample->answer_entities) 88 | { 89 | ans_output.data->val[nnz_ans] = 1.0; 90 | ans_output.data->col_idx[nnz_ans] = e->idx; 91 | nnz_ans += 1; 92 | } 93 | std::sort(ans_output.data->col_idx + base_idx, ans_output.data->col_idx + nnz_ans); 94 | } else { 95 | int e_idx = rand() % sample->answer_entities.size(); 96 | assert(e_idx == 0); 97 | ans_output.data->val[nnz_ans] = 1.0; 98 | ans_output.data->col_idx[nnz_ans] = sample->answer_entities[e_idx]->idx; 99 | nnz_ans += 1; 100 | } 101 | } 102 | q_bow_input.data->row_ptr[mini_batch.size()] = nnz_q_bow; 103 | q_entity_input.data->row_ptr[mini_batch.size()] = nnz_q_entity; 104 | ans_output.data->row_ptr[mini_batch.size()] = nnz_ans; 105 | assert(nnz_q_bow == q_bow_input.data->nnz); 106 | assert(nnz_q_entity == q_entity_input.data->nnz); 107 | assert(nnz_ans == ans_output.data->nnz); 108 | 109 | m_q_bow_input.CopyFrom(q_bow_input); 110 | m_q_entity_input.CopyFrom(q_entity_input); 111 | m_ans_output.CopyFrom(ans_output); */ 112 | } 113
| -------------------------------------------------------------------------------- /code_qa_txt/src/lib/knowledge_base.cpp: -------------------------------------------------------------------------------- 1 | #include "knowledge_base.h" 2 | #include "config.h" 3 | #include "util.h" 4 | #include "dataset.h" 5 | #include 6 | #include 7 | #include 8 | #include "global.h" 9 | /* Node: stores name and idx plus the sorted vocabulary ids of the words in its name; every word must already be in word_dict (an unknown word is printed, then asserted on). */ 10 | Node::Node(std::string _name, int _idx) : name(_name), idx(_idx) 11 | { 12 | adj_list.clear(); 13 | triplet_info.clear(); 14 | word_idx_list.clear(); 15 | 16 | std::vector buf; 17 | str_split(name, ' ', buf); 18 | 19 | for (auto& w : buf) 20 | { 21 | if (word_dict.count(w) == 0) 22 | std::cerr << w << std::endl; 23 | assert(word_dict.count(w)); 24 | word_idx_list.push_back(word_dict[w]); 25 | } 26 | std::sort(word_idx_list.begin(), word_idx_list.end()); 27 | } 28 | /* AddNeighbor: appends (rel_type, y) to adj_list and, at the same position, (kb_idx, isReverse) to triplet_info; the two vectors are parallel. */ 29 | void Node::AddNeighbor(int rel_type, Node* y, int kb_idx, bool isReverse) 30 | { 31 | adj_list.push_back(std::make_pair(rel_type, y)); 32 | triplet_info.push_back(std::make_pair(kb_idx, isReverse)); 33 | } 34 | /* SubGraph: union of the cfg::nhop_subg-hop BFS neighbourhoods of the sample's question entities. */ 35 | SubGraph::SubGraph(Sample* _sample) : sample(_sample) 36 | { 37 | nodes.clear(); 38 | edges.clear(); 39 | 40 | assert(sample->q_entities.size()); 41 | 42 | node_map.clear(); 43 | edge_set.clear(); 44 | for (auto* e : sample->q_entities) 45 | { 46 | BFS(e); 47 | } 48 | } 49 | /* AddNode: adds node once (keyed by name), recording its position in nodes. */ 50 | void SubGraph::AddNode(Node* node) 51 | { 52 | if (node_map.count(node->name)) 53 | return; 54 | node_map[node->name] = nodes.size(); 55 | nodes.push_back(node); 56 | } 57 | /* AddEdge: records the undirected edge (smaller position, r_type, larger position) once; both endpoints must already have been added. */ 58 | void SubGraph::AddEdge(Node* src, int r_type, Node* dst) 59 | { 60 | assert(node_map.count(src->name)); 61 | assert(node_map.count(dst->name)); 62 | 63 | int x = node_map[src->name], y = node_map[dst->name]; 64 | assert(x < (int)nodes.size()); 65 | assert(y < (int)nodes.size()); 66 | if (x > y){ 67 | int t = x; x = y; y = t; 68 | } 69 | auto p = std::make_tuple(x, r_type, y); 70 | if (!edge_set.count(p)) 71 | { 72 | edge_set.insert(p); 73 | edges.push_back(p); 74
| } 75 | } 76 | /* BFS: breadth-first expansion from root up to cfg::nhop_subg hops, adding every reached node and every traversed edge. */ 77 | void SubGraph::BFS(Node* root) 78 | { 79 | AddNode(root); 80 | std::queue< std::pair > q_node; 81 | while (!q_node.empty()) 82 | q_node.pop(); 83 | q_node.push( std::make_pair(root, 0) ); 84 | while (!q_node.empty()) 85 | { 86 | auto tt = q_node.front(); 87 | if (tt.second >= cfg::nhop_subg) 88 | break; 89 | q_node.pop(); 90 | for (auto p : tt.first->adj_list) 91 | { 92 | if (!node_map.count(p.second->name)) 93 | { 94 | AddNode(p.second); 95 | q_node.push(std::make_pair(p.second, tt.second + 1)); 96 | } 97 | AddEdge(tt.first, p.first, p.second); 98 | } 99 | } 100 | } 101 | /* KnowledgeBase: starts empty. */ 102 | KnowledgeBase::KnowledgeBase() 103 | { 104 | node_dict.clear(); 105 | node_list.clear(); 106 | } 107 | /* GetOrAddNode: returns the node called name, creating it with the next index (its position in node_list) if it does not exist yet. */ 108 | Node* KnowledgeBase::GetOrAddNode(std::string name) 109 | { 110 | if (node_dict.count(name) == 0) 111 | { 112 | assert(node_dict.size() == node_list.size()); 113 | int t = node_dict.size(); 114 | auto* node = new Node(name, t); 115 | node_dict[name] = node; 116 | node_list.push_back(node); 117 | } 118 | return node_dict[name]; 119 | } 120 | /* GetNodeOrDie: lookup that requires name to exist; this is only checked by assert, so with NDEBUG a missing name is inserted and nullptr returned. */ 121 | Node* KnowledgeBase::GetNodeOrDie(std::string name) 122 | { 123 | assert(node_dict.count(name)); 124 | return node_dict[name]; 125 | } 126 | /* NodeIdx: index of an existing node; same assert-only check as GetNodeOrDie. */ 127 | int KnowledgeBase::NodeIdx(std::string name) 128 | { 129 | assert(node_dict.count(name)); 130 | return node_dict[name]->idx; 131 | } 132 | /* ParseKnowledgeFile: reads <data_root>/kb.txt, one "head|relation|tail" triple per line. Each triple adds a forward edge head->tail with relation id r and a reverse edge tail->head with id r + #relations, both tagged with the 0-based line number. Afterwards each node's edges are ordered by neighbour idx. */ 133 | void KnowledgeBase::ParseKnowledgeFile() 134 | { 135 | auto kb_file = fmt::format("{0}/kb.txt", cfg::data_root); 136 | std::ifstream fin(kb_file); 137 | 138 | std::string st; 139 | std::vector buf; 140 | int n_lines = 0; 141 | while (std::getline(fin, st)) 142 | { 143 | str_split(st, '|', buf); 144 | assert(buf.size() == 3); 145 | 146 | auto* src = GetOrAddNode(buf[0]); 147 | assert(relation_dict.count(buf[1])); 148 | auto* dst = GetOrAddNode(buf[2]); 149 | 150 | src->AddNeighbor(relation_dict[buf[1]], dst, n_lines, false); 151 | dst->AddNeighbor(relation_dict.size() + relation_dict[buf[1]], src, n_lines, true); 152 | n_lines += 1;
153 | } 154 | std::cerr << n_lines << " knowledge triples loaded" << std::endl; 155 | n_knowledges = n_lines; 156 | std::cerr << "#entities in kb: " << node_dict.size() << std::endl; 157 | 158 | for (auto& p : node_dict) 159 | { 160 | auto* node = p.second; 161 | /* adj_list and triplet_info are parallel (see AddNeighbor): order a permutation by neighbour idx and apply it to both so triplet_info[k] keeps describing adj_list[k]; stable_sort keeps file order among edges to the same neighbour. */ std::vector<size_t> order(node->adj_list.size()); for (size_t k = 0; k < order.size(); ++k) order[k] = k; std::stable_sort(order.begin(), order.end(), 162 | [node](size_t a, size_t b){ 163 | return node->adj_list[a].second->idx < node->adj_list[b].second->idx; 164 | }); decltype(node->adj_list) sorted_adj; decltype(node->triplet_info) sorted_info; sorted_adj.reserve(order.size()); sorted_info.reserve(order.size()); for (size_t k : order) { sorted_adj.push_back(node->adj_list[k]); sorted_info.push_back(node->triplet_info[k]); } node->adj_list.swap(sorted_adj); node->triplet_info.swap(sorted_info); 165 | } 166 | } 167 | /* ParseEntityInAnswers: registers every answer entity listed in <data_root>/<nhop_subg>-hop/<dataset>/qa_<suffix>.txt as a KB node, creating any that do not exist yet. */ 168 | void KnowledgeBase::ParseEntityInAnswers(const char* suffix) 169 | { 170 | auto file = fmt::format("{0}/{1}-hop/{2}/qa_{3}.txt", 171 | cfg::data_root, cfg::nhop_subg, cfg::dataset, suffix); 172 | /* {}-style placeholders need fmt::format; fmt::sprintf is printf-style and would leave "{0}" etc. literally in the path */ 173 | std::ifstream fin(file); 174 | 175 | std::string st; 176 | std::vector buf; 177 | 178 | while (std::getline(fin, st)) 179 | { 180 | str_split(st, '\t', buf); 181 | assert(buf.size() == 2); 182 | st = buf[1]; 183 | str_split(st, '|', buf); 184 | 185 | for (auto e : buf) 186 | { 187 | GetOrAddNode(e); 188 | } 189 | } 190 | } 191 | -------------------------------------------------------------------------------- /code_qa_txt/src/lib/net_latent_y.cpp: -------------------------------------------------------------------------------- 1 | #include "net_latent_y.h" 2 | #include "knowledge_base.h" 3 | #include "dataset.h" 4 | #include "var_sample.h" 5 | #include "util/graph_struct.h" 6 | #include "nn/nn_all.h" 7 | #include 8 | #include 9 | #include "global.h" 10 | 11 | using namespace gnn; 12 | 13 | NetLatentY::NetLatentY() : INet() 14 | { 15 | inputs["q_bow"] = &m_q_bow_input; 16 | inputs["ans_output"] = &m_ans_output; 17 | inputs["a_dst_nodes"] = &answer_dst_nodes; 18 | } 19 | /* BuildBatchGraph: builds the sparse question bag-of-words and answer-indicator matrices for the batch, picks one random answer node per question into answer_dst_nodes, and registers the batch as the "sample_var" input. */ 20 | void NetLatentY::BuildBatchGraph(std::vector< Sample* >& mini_batch, Phase phase) 21 | { 22 | q_bow_input.Reshape({mini_batch.size(), word_dict.size()}); 23 | ans_output.Reshape({mini_batch.size(), kb.node_dict.size()}); 24 | int nnz_q_bow = 0, nnz_ans = 0; 25 | for (auto* s : mini_batch) 26 | { 27 | nnz_q_bow +=
s->q_word_list.size(); 28 | nnz_ans += s->answer_entities.size(); 29 | } 30 | q_bow_input.ResizeSp(nnz_q_bow, mini_batch.size() + 1); 31 | ans_output.ResizeSp(nnz_ans, mini_batch.size() + 1); 32 | /* second pass: fill one CSR row per sample (column ids sorted within the row) and pick the sample's destination answer node */ 33 | nnz_q_bow = 0; nnz_ans = 0; 34 | answer_dst_nodes.clear(); 35 | for (int i = 0; i < (int)mini_batch.size(); ++i) 36 | { 37 | auto* sample = mini_batch[i]; 38 | assert(sample->answer_entities.size()); /* the rand() % below divides by this size */ 39 | assert(sample->q_entities.size()); 40 | int idx = rand() % sample->answer_entities.size(); 41 | answer_dst_nodes.push_back(sample->answer_entities[idx]); 42 | 43 | q_bow_input.data->row_ptr[i] = nnz_q_bow; 44 | ans_output.data->row_ptr[i] = nnz_ans; 45 | 46 | int base_idx = nnz_q_bow; 47 | for (auto e : sample->q_word_list) 48 | { 49 | q_bow_input.data->val[nnz_q_bow] = 1.0; 50 | q_bow_input.data->col_idx[nnz_q_bow] = e; 51 | nnz_q_bow += 1; 52 | } 53 | std::sort(q_bow_input.data->col_idx + base_idx, q_bow_input.data->col_idx + nnz_q_bow); 54 | 55 | base_idx = nnz_ans; 56 | for (auto e : sample->answer_entities) 57 | { 58 | ans_output.data->val[nnz_ans] = 1.0; 59 | ans_output.data->col_idx[nnz_ans] = e->idx; 60 | nnz_ans += 1; 61 | } 62 | std::sort(ans_output.data->col_idx + base_idx, ans_output.data->col_idx + nnz_ans); 63 | } 64 | q_bow_input.data->row_ptr[mini_batch.size()] = nnz_q_bow; 65 | ans_output.data->row_ptr[mini_batch.size()] = nnz_ans; 66 | assert(nnz_q_bow == q_bow_input.data->nnz); 67 | assert(nnz_ans == ans_output.data->nnz); 68 | 69 | m_q_bow_input.CopyFrom(q_bow_input); 70 | m_ans_output.CopyFrom(ans_output); 71 | inputs["sample_var"] = &mini_batch; 72 | } 73 | /* GetMatchScores: chains max(1, cfg::nhop_subg) BfsPathEmbed factors starting from start_nodes (the first with 4 operands, each later one fed the previous hop's embedding, node info and node counts), then scores every reached node against q_embed with GraphInnerProduct, giving a (rows of q_embed) x kb.node_dict.size() score matrix. */ 74 | std::shared_ptr< DTensorVar > NetLatentY::GetMatchScores(std::shared_ptr< DTensorVar >& q_embed, 75 | std::shared_ptr& samples, 76 | std::shared_ptr< DTensorVar >& rel_embed, 77 | std::shared_ptr< DTensorVar >& w_recur, 78 | std::shared_ptr< VectorVar >& start_nodes) 79 | { 80 | std::vector< std::shared_ptr< Variable > > args = { samples, rel_embed, w_recur, start_nodes}; 81 | auto tp = af< BfsPathEmbed >(fg, args); 82 |
//std::get<0>(tp) = af< ReLU >(fg, {std::get<0>(tp)}); 83 | int h = 1; 84 | while (h < cfg::nhop_subg) 85 | { 86 | h++; 87 | args = {samples, rel_embed, w_recur, std::get<0>(tp), std::get<1>(tp), std::get<2>(tp)}; 88 | tp = af< BfsPathEmbed >(fg, args); 89 | //std::get<0>(tp) = af< ReLU >(fg, {std::get<0>(tp)}); 90 | } 91 | args = {q_embed, std::get<0>(tp), std::get<1>(tp), std::get<2>(tp)}; 92 | auto match_scores = af< GraphInnerProduct >(fg, args, kb.node_dict.size()); 93 | return match_scores; 94 | } 95 | /* GetCritic: value head on the question bag-of-words: MatMul(critic_w_embed) -> ReLU -> FullyConnected(w1) -> ReLU -> FullyConnected(w2) -> ReLU, one output column. The parameters are created, randomly initialised and added to fg on every call. NOTE(review): the final ReLU keeps the estimate non-negative - confirm that matches the reward range it is compared against. */ 96 | std::shared_ptr< DTensorVar > NetLatentY::GetCritic(std::shared_ptr< SpTensorVar > q_bow) 97 | { 98 | auto critic_w_embed = add_diff(model, "critic_w_embed", {word_dict.size(), (size_t)cfg::n_embed}); 99 | auto w1 = add_diff(model, "critic_w1", {(size_t)(cfg::n_embed + 1), (size_t)cfg::n_hidden}); 100 | auto w2 = add_diff(model, "critic_w2", {(size_t)(cfg::n_hidden + 1), (size_t)1}); 101 | 102 | critic_w_embed->value.SetRandN(0, cfg::w_scale); 103 | w1->value.SetRandN(0, cfg::w_scale); 104 | w2->value.SetRandN(0, cfg::w_scale); 105 | fg.AddParam(critic_w_embed); 106 | fg.AddParam(w1); 107 | fg.AddParam(w2); 108 | 109 | auto q_embed = af< MatMul >(fg, {q_bow, critic_w_embed}); 110 | q_embed = af< ReLU >(fg, {q_embed}); 111 | 112 | auto h1 = af< FullyConnected >(fg, {q_embed, w1}); 113 | h1 = af< ReLU >(fg, {h1}); 114 | 115 | auto h2 = af< FullyConnected >(fg, {h1, w2}); 116 | h2 = af< ReLU >(fg, {h2}); 117 | return h2; 118 | } 119 | 120 | void NetLatentY::BuildNet() 121 | { 122 | // inputs 123 | auto q_bow = add_const< SpTensorVar >(fg, "q_bow", true); 124 | auto entity_bow = add_const< SpTensorVar >(fg, "entity_bow", true); 125 | auto ans_output = add_const< SpTensorVar >(fg, "ans_output", true); 126 | auto samples = add_const< SampleVar >(fg, "sample_var", true); 127 | auto a_dst_nodes = add_const< VectorVar >(fg, "a_dst_nodes", true); 128 | 129 | // parameters 130 | auto w_entity_match = add_diff(model, "w_entity_match", {word_dict.size(),
(size_t)cfg::n_embed}); 131 | auto w_path_query = add_diff(model, "w_path_query", {word_dict.size(), (size_t)cfg::n_embed}); 132 | 133 | auto rel_embed = add_diff(model, "rel_embedding", {relation_dict.size() * (size_t)2, (size_t)cfg::n_embed}); 134 | auto w_a2y_recur = add_diff(model, "w_a2y_recur", {cfg::n_embed, cfg::n_embed}); 135 | auto w_y2a_recur = add_diff(model, "w_y2a_recur", {cfg::n_embed, cfg::n_embed}); 136 | 137 | auto moving_mean = add_nondiff(model, "moving_mean", {(size_t)1, (size_t)1}); 138 | auto moving_inv_std = add_nondiff(model, "moving_inv_std", {(size_t)1, (size_t)1}); 139 | 140 | moving_mean->value.Fill(0.0); 141 | moving_inv_std->value.Fill(1.0); 142 | fg.AddParam(moving_mean); 143 | fg.AddParam(moving_inv_std); 144 | 145 | w_entity_match->value.SetRandN(0, cfg::w_scale); 146 | w_path_query->value.SetRandN(0, cfg::w_scale); 147 | rel_embed->value.SetRandN(0, cfg::w_scale); 148 | w_a2y_recur->value.SetRandN(0, cfg::w_scale); 149 | w_y2a_recur->value.SetRandN(0, cfg::w_scale); 150 | fg.AddParam(w_entity_match); 151 | fg.AddParam(w_path_query); 152 | fg.AddParam(rel_embed); 153 | fg.AddParam(w_a2y_recur); 154 | fg.AddParam(w_y2a_recur); 155 | 156 | auto q_embed_entity = af< MatMul >(fg, {q_bow, w_entity_match}); 157 | q_embed_query = af< MatMul >(fg, {q_bow, w_path_query}); 158 | auto entity_bow_embed = af< MatMul >(fg, {entity_bow, w_entity_match}); 159 | q_y_bow_match = af< MatMul >(fg, {q_embed_entity, entity_bow_embed}, Trans::N, Trans::T); 160 | 161 | //================ q_y_given_qa =========================== 162 | auto y_q_path_match = GetMatchScores(q_embed_entity, samples, rel_embed, w_a2y_recur, a_dst_nodes); 163 | auto q_y_given_qa_scores = af< ElewiseAdd >(fg, {q_y_bow_match, y_q_path_match}); 164 | 165 | // should do sampling 166 | pos_probs = af< Softmax >(fg, {q_y_given_qa_scores}); 167 | 168 | sampled_y_idx = af< MultinomialSample >(fg, {pos_probs}, false); 169 | 170 | auto sampled_y_nodes = af< NodeSelect >(fg, 
{sampled_y_idx}); 171 | 172 | //================ p_a_given_qy =========================== 173 | auto a_q_path_math = GetMatchScores(q_embed_query, samples, rel_embed, w_y2a_recur, sampled_y_nodes); 174 | 175 | //================ -log p(y | q) - log p(a | y, q) =========================== 176 | auto ce_ans = af< CrossEntropy >(fg, {a_q_path_math, ans_output}, true); 177 | auto one_hot_sampled_y = af< OneHot >(fg, {sampled_y_idx}, kb.node_dict.size()); 178 | auto ce_y = af< CrossEntropy >(fg, {q_y_bow_match, one_hot_sampled_y}, true); 179 | 180 | //================ \abla log Q(y | q, a) * score =========================== 181 | auto ce_joint = af< ElewiseAdd >(fg, {ce_ans, ce_y}); 182 | auto baseline = GetCritic(q_bow); 183 | 184 | auto normed_signal = af< MovingNorm >(fg, {ce_joint, moving_mean, moving_inv_std}, 0.1, PropErr::N); 185 | auto learning_signal = af< ElewiseMinus >(fg, {normed_signal, baseline}, PropErr::N); 186 | 187 | std::vector< std::shared_ptr< Variable > > tmp = {sampled_y_idx, learning_signal}; 188 | auto onehot_joint = af< OneHot >(fg, tmp, kb.node_dict.size()); 189 | auto pos_neg_loss = af< CrossEntropy >(fg, {pos_probs, onehot_joint}, false); 190 | //auto pos_neg_loss = af< CrossEntropy >(fg, {q_y_given_qa_scores, onehot_joint}, true); 191 | 192 | //=========== baseline mse ================== 193 | auto square_error = af< SquareError >(fg, {baseline, normed_signal}); 194 | std::vector coeff = {1.0, -1.0, 1.0}; 195 | loss = af< ElewiseAdd >(fg, {ce_joint, pos_neg_loss, square_error}, coeff); 196 | loss = af< ReduceMean >(fg, {loss}); 197 | //================ inference =========================== 198 | auto argmax_y = af< ArgMax >(fg, {q_y_bow_match}); 199 | auto infer_y_nodes = af< NodeSelect >(fg, {argmax_y}); 200 | pred = GetMatchScores(q_embed_query, samples, rel_embed, w_y2a_recur, infer_y_nodes); 201 | 202 | hitk = af< HitAtK >(fg, {pred, ans_output}); 203 | auto real_hitk = af< TypeCast >(fg, {hitk}); 204 | hit_rate = af< ReduceMean >(fg, 
{real_hitk}); 205 | } 206 | -------------------------------------------------------------------------------- /code_qa_txt/src/lib/net_multihop.cpp: -------------------------------------------------------------------------------- 1 | #include "net_multihop.h" 2 | #include "knowledge_base.h" 3 | #include "dataset.h" 4 | #include "var_sample.h" 5 | #include "util/graph_struct.h" 6 | #include "nn/nn_all.h" 7 | #include 8 | #include 9 | #include "global.h" 10 | 11 | using namespace gnn; 12 | 13 | NetMultiHop::NetMultiHop() : INet() 14 | { 15 | inputs["q_bow"] = &m_q_bow_input; 16 | inputs["ans_output"] = &m_ans_output; 17 | inputs["a_dst_nodes"] = &answer_dst_nodes; 18 | } 19 | 20 | void NetMultiHop::BuildBatchGraph(std::vector< Sample* >& mini_batch, Phase phase) 21 | { 22 | q_bow_input.Reshape({mini_batch.size(), word_dict.size()}); 23 | y_idxes.Reshape({mini_batch.size(), (size_t)1}); 24 | ans_output.Reshape({mini_batch.size(), kb.node_dict.size()}); 25 | int nnz_q_bow = 0, nnz_ans = 0; 26 | for (auto* s : mini_batch) 27 | { 28 | nnz_q_bow += s->q_word_list.size(); 29 | nnz_ans += s->answer_entities.size(); 30 | } 31 | q_bow_input.ResizeSp(nnz_q_bow, mini_batch.size() + 1); 32 | ans_output.ResizeSp(nnz_ans, mini_batch.size() + 1); 33 | 34 | nnz_q_bow = 0; nnz_ans = 0; 35 | answer_dst_nodes.clear(); 36 | for (int i = 0; i < (int)mini_batch.size(); ++i) 37 | { 38 | auto* sample = mini_batch[i]; 39 | 40 | assert(sample->q_entities.size()); 41 | y_idxes.data->ptr[i] = sample->q_entities[0]->idx; 42 | int idx = rand() % sample->answer_entities.size(); 43 | answer_dst_nodes.push_back(sample->answer_entities[idx]); 44 | 45 | q_bow_input.data->row_ptr[i] = nnz_q_bow; 46 | ans_output.data->row_ptr[i] = nnz_ans; 47 | 48 | int base_idx = nnz_q_bow; 49 | for (auto e : sample->q_word_list) 50 | { 51 | q_bow_input.data->val[nnz_q_bow] = 1.0; 52 | q_bow_input.data->col_idx[nnz_q_bow] = e; 53 | nnz_q_bow += 1; 54 | } 55 | std::sort(q_bow_input.data->col_idx + base_idx, 
q_bow_input.data->col_idx + nnz_q_bow); 56 | 57 | base_idx = nnz_ans; 58 | for (auto e : sample->answer_entities) 59 | { 60 | ans_output.data->val[nnz_ans] = 1.0; 61 | ans_output.data->col_idx[nnz_ans] = e->idx; 62 | nnz_ans += 1; 63 | } 64 | std::sort(ans_output.data->col_idx + base_idx, ans_output.data->col_idx + nnz_ans); 65 | } 66 | q_bow_input.data->row_ptr[mini_batch.size()] = nnz_q_bow; 67 | ans_output.data->row_ptr[mini_batch.size()] = nnz_ans; 68 | assert(nnz_q_bow == q_bow_input.data->nnz); 69 | assert(nnz_ans == ans_output.data->nnz); 70 | 71 | m_q_bow_input.CopyFrom(q_bow_input); 72 | m_ans_output.CopyFrom(ans_output); 73 | inputs["sample_var"] = &mini_batch; 74 | inputs["y_idxes"] = &y_idxes; 75 | } 76 | 77 | std::shared_ptr< DTensorVar > NetMultiHop::GetMatchScores(std::shared_ptr< DTensorVar >& q_embed, 78 | std::shared_ptr& samples, 79 | std::shared_ptr< DTensorVar >& rel_embed, 80 | std::shared_ptr< DTensorVar >& w_recur, 81 | std::shared_ptr< VectorVar >& start_nodes) 82 | { 83 | std::vector< std::shared_ptr< Variable > > args = { samples, rel_embed, w_recur, start_nodes}; 84 | auto tp = af< BfsPathEmbed >(fg, args); 85 | //std::get<0>(tp) = af< ReLU >(fg, {std::get<0>(tp)}); 86 | int h = 1; 87 | while (h < cfg::nhop_subg) 88 | { 89 | h++; 90 | args = {samples, rel_embed, w_recur, std::get<0>(tp), std::get<1>(tp), std::get<2>(tp)}; 91 | tp = af< BfsPathEmbed >(fg, args); 92 | //std::get<0>(tp) = af< ReLU >(fg, {std::get<0>(tp)}); 93 | } 94 | args = {q_embed, std::get<0>(tp), std::get<1>(tp), std::get<2>(tp)}; 95 | auto match_scores = af< GraphInnerProduct >(fg, args, kb.node_dict.size()); 96 | return match_scores; 97 | } 98 | 99 | std::shared_ptr< DTensorVar > NetMultiHop::GetCritic(std::shared_ptr< SpTensorVar > q_bow) 100 | { 101 | auto critic_w_embed = add_diff(model, "critic_w_embed", {word_dict.size(), (size_t)cfg::n_embed}); 102 | auto w1 = add_diff(model, "critic_w1", {(size_t)(cfg::n_embed + 1), (size_t)cfg::n_hidden}); 103 | auto w2 = 
add_diff(model, "critic_w2", {(size_t)(cfg::n_hidden + 1), (size_t)1}); 104 | 105 | critic_w_embed->value.SetRandN(0, cfg::w_scale); 106 | w1->value.SetRandN(0, cfg::w_scale); 107 | w2->value.SetRandN(0, cfg::w_scale); 108 | fg.AddParam(critic_w_embed); 109 | fg.AddParam(w1); 110 | fg.AddParam(w2); 111 | 112 | auto q_embed = af< MatMul >(fg, {q_bow, critic_w_embed}); 113 | q_embed = af< ReLU >(fg, {q_embed}); 114 | 115 | auto h1 = af< FullyConnected >(fg, {q_embed, w1}); 116 | h1 = af< ReLU >(fg, {h1}); 117 | 118 | auto h2 = af< FullyConnected >(fg, {h1, w2}); 119 | h2 = af< ReLU >(fg, {h2}); 120 | return h2; 121 | } 122 | 123 | void NetMultiHop::BuildNet() 124 | { 125 | // inputs 126 | auto q_bow = add_const< SpTensorVar >(fg, "q_bow", true); 127 | auto entity_bow = add_const< SpTensorVar >(fg, "entity_bow", true); 128 | auto ans_output = add_const< SpTensorVar >(fg, "ans_output", true); 129 | auto samples = add_const< SampleVar >(fg, "sample_var", true); 130 | auto a_dst_nodes = add_const< VectorVar >(fg, "a_dst_nodes", true); 131 | auto true_y_idx = add_const< DTensorVar >(fg, "y_idxes", true); 132 | 133 | // parameters 134 | auto w_entity_match = add_diff(model, "w_entity_match", {word_dict.size(), (size_t)cfg::n_embed}); 135 | auto w_path_query = add_diff(model, "w_path_query", {word_dict.size(), (size_t)cfg::n_embed}); 136 | 137 | auto rel_embed = add_diff(model, "rel_embedding", {relation_dict.size() * (size_t)2, (size_t)cfg::n_embed}); 138 | auto w_a2y_recur = add_diff(model, "w_a2y_recur", {cfg::n_embed, cfg::n_embed}); 139 | auto w_y2a_recur = add_diff(model, "w_y2a_recur", {cfg::n_embed, cfg::n_embed}); 140 | 141 | auto moving_mean = add_nondiff(model, "moving_mean", {(size_t)1, (size_t)1}); 142 | auto moving_inv_std = add_nondiff(model, "moving_inv_std", {(size_t)1, (size_t)1}); 143 | 144 | moving_mean->value.Fill(0.0); 145 | moving_inv_std->value.Fill(1.0); 146 | fg.AddParam(moving_mean); 147 | fg.AddParam(moving_inv_std); 148 | 149 | 
w_entity_match->value.SetRandN(0, cfg::w_scale); 150 | w_path_query->value.SetRandN(0, cfg::w_scale); 151 | rel_embed->value.SetRandN(0, cfg::w_scale); 152 | w_a2y_recur->value.SetRandN(0, cfg::w_scale); 153 | w_y2a_recur->value.SetRandN(0, cfg::w_scale); 154 | fg.AddParam(w_entity_match); 155 | fg.AddParam(w_path_query); 156 | fg.AddParam(rel_embed); 157 | fg.AddParam(w_a2y_recur); 158 | fg.AddParam(w_y2a_recur); 159 | 160 | auto q_embed_entity = af< MatMul >(fg, {q_bow, w_entity_match}); 161 | auto q_embed_query = af< MatMul >(fg, {q_bow, w_path_query}); 162 | auto entity_bow_embed = af< MatMul >(fg, {entity_bow, w_entity_match}); 163 | auto q_y_bow_match = af< MatMul >(fg, {q_embed_entity, entity_bow_embed}, Trans::N, Trans::T); 164 | 165 | 166 | //================ q_y_given_qa =========================== 167 | auto y_q_path_match = GetMatchScores(q_embed_entity, samples, rel_embed, w_a2y_recur, a_dst_nodes); 168 | auto q_y_given_qa_scores = af< ElewiseAdd >(fg, {q_y_bow_match, y_q_path_match}); 169 | 170 | auto sampled_y_nodes = af< NodeSelect >(fg, {true_y_idx}); 171 | 172 | 173 | //================ p_a_given_qy =========================== 174 | auto a_q_path_math = GetMatchScores(q_embed_query, samples, rel_embed, w_y2a_recur, sampled_y_nodes); 175 | 176 | //================ -log p(y | q) - log p(a | y, q) =========================== 177 | auto ce_ans = af< CrossEntropy >(fg, {a_q_path_math, ans_output}, true); 178 | auto one_hot_sampled_y = af< OneHot >(fg, {true_y_idx}, kb.node_dict.size()); 179 | auto ce_y = af< CrossEntropy >(fg, {q_y_bow_match, one_hot_sampled_y}, true); 180 | 181 | //================ \abla log Q(y | q, a) * score =========================== 182 | auto ce_joint = af< ElewiseAdd >(fg, {ce_ans, ce_y}); 183 | 184 | auto truey_onehot = af< OneHot >(fg, {true_y_idx}, kb.node_dict.size()); 185 | auto pos_neg_loss = af< CrossEntropy >(fg, {q_y_given_qa_scores, truey_onehot}, true); 186 | auto baseline = GetCritic(q_bow); 187 | auto 
normed_signal = af< MovingNorm >(fg,{ce_joint, moving_mean, moving_inv_std}, 0.1, PropErr::N); 188 | auto learning_signal = af< ElewiseMinus >(fg, {normed_signal, baseline}, PropErr::N); 189 | 190 | //=========== baseline mse ================== 191 | // auto square_error = af< SquareError > (fg, {baseline, normed_signal}); 192 | // std::vector coeff = {1.0, 1.0, 1.0}; 193 | // loss = af< ElewiseAdd >(fg, {ce_joint, pos_neg_loss, square_error}, coeff); 194 | loss = af< ElewiseAdd >(fg, {ce_joint, pos_neg_loss}); 195 | loss = af< ReduceMean >(fg, {loss}); 196 | 197 | //================ inference =========================== 198 | auto argmax_y = true_y_idx; 199 | auto infer_y_nodes = af< NodeSelect >(fg, {argmax_y}); 200 | auto pred = GetMatchScores(q_embed_query, samples, rel_embed, w_y2a_recur, infer_y_nodes); 201 | 202 | auto hitk = af< HitAtK >(fg, {pred, ans_output}); 203 | auto real_hitk = af< TypeCast >(fg, {hitk}); 204 | hit_rate = af< ReduceMean >(fg, {real_hitk}); 205 | } 206 | -------------------------------------------------------------------------------- /code_qa_txt/src/lib/node_select.cpp: -------------------------------------------------------------------------------- 1 | #include "node_select.h" 2 | #include "global.h" 3 | 4 | namespace gnn 5 | { 6 | 7 | NodeSelect::NodeSelect(std::string _name) 8 | : Factor(_name, PropErr::N) 9 | { 10 | 11 | } 12 | 13 | void NodeSelect::SetupNodes(const int len, const int* idxes, std::vector& dst) 14 | { 15 | dst.resize(len); 16 | for (int i = 0; i < len; ++i) 17 | dst[i] = kb.node_list[idxes[i]]; 18 | } 19 | 20 | void NodeSelect::Forward(std::vector< std::shared_ptr >& operands, 21 | std::vector< std::shared_ptr >& outputs, 22 | Phase phase) 23 | { 24 | ASSERT(operands.size() == 1, "unexpected input size for " << StrType()); 25 | ASSERT(outputs.size() == 1, "unexpected output size for " << StrType()); 26 | 27 | auto& node_info = *(dynamic_cast< VectorVar* >(outputs[0].get())->vec); 28 | 29 | 
MAT_MODE_SWITCH(operands[0]->GetMode(), matMode, { 30 | auto& indexes = dynamic_cast*>(operands[0].get())->value; 31 | DTensor t_idxes; 32 | 33 | int* ptr = indexes.data->ptr; 34 | if (indexes.GetMatMode() == MatMode::gpu) 35 | { 36 | t_idxes.CopyFrom(indexes); 37 | ptr = t_idxes.data->ptr; 38 | } 39 | SetupNodes(indexes.shape.Count(), ptr, node_info); 40 | }); 41 | } 42 | 43 | } -------------------------------------------------------------------------------- /code_qa_txt/src/lib/util.cpp: -------------------------------------------------------------------------------- 1 | #include "util.h" 2 | 3 | void str_split(const std::string &s, char delim, std::vector &result) 4 | { 5 | std::stringstream ss; 6 | ss.str(s); 7 | std::string item; 8 | result.clear(); 9 | while (std::getline(ss, item, delim)) 10 | result.push_back(item); 11 | } 12 | 13 | void str_replace(std::string& str, std::string p, std::string q) 14 | { 15 | while (true) 16 | { 17 | auto idx = str.find(p); 18 | if (idx == std::string::npos) 19 | break; 20 | str.replace(idx, p.size(), q); 21 | } 22 | } -------------------------------------------------------------------------------- /code_qa_txt/src/lib/var_sample.cpp: -------------------------------------------------------------------------------- 1 | #include "var_sample.h" 2 | #include "dataset.h" 3 | 4 | namespace gnn 5 | { 6 | 7 | SampleVar::SampleVar(std::string _name) : Variable(_name), samples(nullptr) 8 | { 9 | 10 | } 11 | 12 | EleType SampleVar::GetEleType() 13 | { 14 | return EleType::UNKNOWN; 15 | } 16 | 17 | MatMode SampleVar::GetMode() 18 | { 19 | return MatMode::cpu; 20 | } 21 | 22 | void SampleVar::SetRef(void* p) 23 | { 24 | samples = static_cast*>(p); 25 | } 26 | 27 | template 28 | VectorVar::VectorVar(std::string _name) : Variable(_name) 29 | { 30 | vec = new std::vector(); 31 | } 32 | 33 | template 34 | EleType VectorVar::GetEleType() 35 | { 36 | return EleType::UNKNOWN; 37 | } 38 | 39 | template 40 | MatMode VectorVar::GetMode() 41 | { 42 
| return MatMode::cpu; 43 | } 44 | 45 | template 46 | void VectorVar::SetRef(void* p) 47 | { 48 | vec = static_cast*>(p); 49 | } 50 | 51 | template class VectorVar; 52 | template class VectorVar; 53 | } -------------------------------------------------------------------------------- /code_qa_txt/src/main.cpp: -------------------------------------------------------------------------------- 1 | #include "config.h" 2 | #include "nn/nn_all.h" 3 | #include "dataset.h" 4 | #include 5 | #include "dict.h" 6 | //#include "graph_struct.h" 7 | #include "knowledge_base.h" 8 | #include "global.h" 9 | #include "net_multihop.h" 10 | #include "net_latent_y.h" 11 | 12 | Dataset train_set, val_set, test_set; 13 | std::vector< Sample* > mini_batch; 14 | 15 | void EvalSet(std::string prefix, Dataset& dset, INet* net) 16 | { 17 | dset.SetupStream(false); 18 | Dtype loss_total = 0.0; 19 | while (dset.GetMiniBatch(cfg::batch_size, mini_batch)) 20 | { 21 | net->BuildBatchGraph(mini_batch, Phase::TEST); 22 | fg.FeedForward({net->hit_rate}, net->inputs, Phase::TEST); 23 | 24 | loss_total += mini_batch.size() * net->hit_rate->value.AsScalar(); 25 | } 26 | std::cerr << prefix << "@iter: " << cfg::iter; 27 | std::cerr << "\thit_rate@1: " << loss_total / dset.orig_samples.size(); 28 | std::cerr << std::endl; 29 | } 30 | 31 | std::vector idx_buf; 32 | void GetTopK(DTensor& prob, std::vector< std::vector< std::pair > >& idx_list) 33 | { 34 | idx_list.resize(prob.rows()); 35 | 36 | for (size_t i = 0; i < prob.rows(); ++i) 37 | { 38 | idx_list[i].clear(); 39 | Dtype* ptr = prob.data->ptr + i * prob.cols(); 40 | 41 | std::sort(idx_buf.begin(), idx_buf.end(), [&](const int& i, const int& j) { 42 | return ptr[i] > ptr[j]; 43 | }); 44 | for (int j = 0; j < cfg::test_tpok; ++j) 45 | idx_list[i].push_back(std::make_pair(idx_buf[j], ptr[idx_buf[j]])); 46 | } 47 | } 48 | 49 | void Print2File(FILE* fid, DTensor& prob, std::vector< std::vector< std::pair > >& idx_list) 50 | { 51 | for (size_t i = 0; i < 
prob.rows(); ++i) 52 | { 53 | for (int j = 0; j < cfg::test_tpok; ++j) 54 | { 55 | auto& p = idx_list[i][j]; 56 | 57 | auto* node = kb.node_list[p.first]; 58 | if (j) 59 | fprintf(fid, "|"); 60 | fprintf(fid, "("); 61 | for (size_t j = 0; j < node->name.size(); ++j) 62 | fprintf(fid, "%c", node->name[j]); 63 | fprintf(fid, ",%.6f)", p.second); 64 | } 65 | fprintf(fid, "\n"); 66 | } 67 | } 68 | 69 | void SavePred(std::string prefix, Dataset& dset, INet* net) 70 | { 71 | dset.SetupStream(false); 72 | idx_buf.resize(kb.node_dict.size()); 73 | for (size_t i = 0; i < idx_buf.size(); ++i) 74 | idx_buf[i] = i; 75 | 76 | FILE* fy = fopen(fmt::sprintf("%s/%s_ypred.txt", cfg::save_dir, prefix).c_str(), "w"); 77 | // FILE* fa = fopen(fmt::sprintf("%s/%s_apred.txt", cfg::save_dir, prefix).c_str(), "w"); 78 | 79 | while (dset.GetMiniBatch(cfg::batch_size, mini_batch)) 80 | { 81 | net->BuildBatchGraph(mini_batch, Phase::TEST); 82 | // fg.FeedForward({net->pos_probs, net->pred}, net->inputs, Phase::TEST); 83 | fg.FeedForward({net->q_y_bow_match}, net->inputs, Phase::TEST); 84 | 85 | auto& y_prob = net->q_y_bow_match->value; 86 | y_prob.Softmax(); 87 | std::vector< std::vector< std::pair > > idx_list; 88 | GetTopK(y_prob, idx_list); 89 | Print2File(fy, y_prob, idx_list); 90 | 91 | // auto& pred = net->pred->value; 92 | // pred.Softmax(); 93 | // GetTopK(pred, idx_list); 94 | // Print2File(fa, pred, idx_list); 95 | } 96 | fclose(fy); 97 | // fclose(fa); 98 | } 99 | 100 | void MainLoop(INet* inet) 101 | { 102 | //MomentumSGDOptimizer learner(&model, cfg::lr, cfg::momentum, cfg::l2_penalty); 103 | AdamOptimizer learner(&model, cfg::lr, cfg::l2_penalty); 104 | 105 | int max_iter = (long long)cfg::max_iter; 106 | int init_iter = cfg::iter; 107 | if (init_iter > 0) 108 | { 109 | std::cerr << fmt::sprintf("loading model for iter=%d", init_iter) << std::endl; 110 | model.Load(fmt::sprintf("%s/iter_%d.model", cfg::save_dir, init_iter)); 111 | } 112 | 113 | train_set.SetupStream(true); 114 
| 115 | for (; cfg::iter <= max_iter; ++cfg::iter) 116 | { 117 | if (cfg::iter != init_iter && cfg::iter % cfg::test_interval == 0) 118 | { 119 | //EvalSet("train", train_set, inet); 120 | EvalSet("dev", val_set, inet); 121 | EvalSet("test", test_set, inet); 122 | } 123 | if (cfg::iter % cfg::save_interval == 0 && cfg::iter != init_iter) 124 | { 125 | printf("saving model for iter=%d\n", cfg::iter); 126 | model.Save(fmt::sprintf("%s/iter_%d.model", cfg::save_dir, cfg::iter)); 127 | } 128 | 129 | assert(train_set.GetSplitMiniBatch(cfg::batch_size, mini_batch)); 130 | inet->BuildBatchGraph(mini_batch, Phase::TRAIN); 131 | 132 | fg.FeedForward({inet->loss}, inet->inputs, Phase::TRAIN); 133 | 134 | if (cfg::iter % cfg::report_interval == 0) 135 | { 136 | std::cerr << "iter: " << cfg::iter; 137 | std::cerr << "\tloss: " << inet->loss->value.AsScalar(); 138 | // for (auto t : inet->targets) 139 | // std::cerr << "\t" << t->name << ": " << dynamic_cast*>(t.get())->AsScalar(); 140 | std::cerr << std::endl; 141 | } 142 | fg.BackPropagate({inet->loss}); 143 | learner.Update(); 144 | } 145 | } 146 | 147 | std::vector< std::string> q_types; 148 | size_t total_ntypes; 149 | 150 | void LoadTestTypes() 151 | { 152 | auto file = fmt::format("{0}/{1}-hop/qa_test_qtype.txt", 153 | cfg::data_root, cfg::nhop_subg); 154 | std::ifstream fin(file); 155 | std::string st; 156 | q_types.clear(); 157 | total_ntypes = 0; 158 | std::set ss; 159 | while (fin >> st) 160 | { 161 | q_types.push_back(st); 162 | ss.insert(st); 163 | } 164 | total_ntypes = ss.size(); 165 | assert(q_types.size() == test_set.orig_samples.size()); 166 | std::cerr << "ntypes: " << total_ntypes << std::endl; 167 | } 168 | 169 | void OutputScores(INet* net) 170 | { 171 | FILE* fout = fopen(fmt::sprintf("%s/test_vis.txt", cfg::save_dir).c_str(), "w"); 172 | for (size_t i = 0; i < kb.node_list.size(); ++i) 173 | fprintf(fout, "%s\n", kb.node_list[i]->name.c_str()); 174 | std::set< std::string > selected; 175 | 176 | auto& 
rel_embed = model.params["rel_embedding"]->value; 177 | 178 | DTensor mat; 179 | size_t idx = 0; 180 | 181 | std::vector< std::string > rel_list(relation_dict.size()); 182 | for (auto& p : relation_dict) 183 | { 184 | rel_list[p.second] = p.first; 185 | } 186 | 187 | while (test_set.GetMiniBatch(cfg::batch_size, mini_batch)) 188 | { 189 | net->BuildBatchGraph(mini_batch, Phase::TEST); 190 | fg.FeedForward({net->q_embed_query, net->pred, net->hitk}, net->inputs, Phase::TEST); 191 | 192 | auto& q_embed = net->q_embed_query->value; 193 | mat.MM(q_embed, rel_embed, Trans::N, Trans::T, 1.0, 0.0); 194 | 195 | for (size_t j = 0; j < q_embed.rows(); ++j) 196 | { 197 | auto& tt = q_types[idx + j]; 198 | if (selected.count(tt)) 199 | continue; 200 | if (net->hitk->value.data->ptr[j] == 0) 201 | continue; 202 | selected.insert(tt); 203 | std::cerr << tt << std::endl; 204 | 205 | fprintf(fout, "%d\n", (int)idx + (int)j); 206 | 207 | auto* cur_pred = net->pred->value.data->ptr + j * net->pred->value.cols(); 208 | for (size_t k = 0; k < net->pred->value.cols(); ++k) 209 | { 210 | if (k) 211 | fprintf(fout, " "); 212 | fprintf(fout, "%.6f", cur_pred[k]); 213 | } 214 | fprintf(fout, "\n"); 215 | 216 | cur_pred = mat.data->ptr + j * mat.cols(); 217 | for (size_t k = 0; k < mat.cols(); ++k) 218 | { 219 | std::string tttt; 220 | if (k < rel_list.size()) 221 | tttt = rel_list[k]; 222 | else 223 | tttt = rel_list[k - rel_list.size()] + "-inv"; 224 | fprintf(fout, "%s %.6f\n", tttt.c_str(), cur_pred[k]); 225 | } 226 | } 227 | std::cerr << selected.size() << " " << total_ntypes << std::endl; 228 | if (selected.size() == total_ntypes) 229 | { 230 | std::cerr << "job done" << std::endl; 231 | break; 232 | } 233 | idx += q_embed.rows(); 234 | } 235 | std::cerr << idx << std::endl; 236 | fclose(fout); 237 | } 238 | 239 | int main(const int argc, const char** argv) 240 | { 241 | srand(time(NULL)); 242 | 243 | cfg::LoadParams(argc, argv); 244 | GpuHandle::Init(cfg::dev_id, 1); 245 | 246 | 
relation_dict = GetRelations(); 247 | word_dict = GetVocab(); 248 | side_word_dict = GetSideWordDict(); 249 | kb.ParseKnowledgeFile(); 250 | for (auto suffix : {"train", "test", "dev"}) 251 | { 252 | kb.ParseEntityInAnswers(suffix); 253 | } 254 | std::cerr << "# entites in total: " << kb.node_dict.size() << std::endl; 255 | 256 | train_set.Load("train"); 257 | val_set.Load("dev"); 258 | test_set.Load("test"); 259 | 260 | std::cerr << "building net..." << std::endl; 261 | INet* net = nullptr; 262 | if (!strcmp(cfg::net_type, "NetMultiHop")) 263 | { 264 | net = new NetMultiHop(); 265 | } 266 | else if (!strcmp(cfg::net_type, "NetLatentY")) 267 | { 268 | assert(cfg::init_idx_file == nullptr); 269 | net = new NetLatentY(); 270 | } 271 | else { 272 | std::cerr << "unknown net type: " << cfg::net_type << std::endl; 273 | return 0; 274 | } 275 | 276 | net->BuildNet(); 277 | std::cerr << "done" << std::endl; 278 | if (cfg::test_only) 279 | { 280 | model.Load(fmt::sprintf("%s/iter_%d.model", cfg::save_dir, cfg::iter)); 281 | SavePred("dev", val_set, net); 282 | SavePred("test", test_set, net); 283 | } else if (cfg::vis_score) 284 | { 285 | model.Load(fmt::sprintf("%s/iter_%d.model", cfg::save_dir, cfg::iter)); 286 | LoadTestTypes(); 287 | OutputScores(net); 288 | } else 289 | MainLoop(net); 290 | 291 | //GpuHandle::Destroy(); 292 | return 0; 293 | } 294 | -------------------------------------------------------------------------------- /code_qa_txt/vis.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | make 3 | 4 | nhop_subg=3 5 | dataset=ntm 6 | data_root=../metaQA 7 | #net_type=NetMultiHop 8 | net_type=NetLatentY 9 | 10 | result_root=$HOME/scratch/results/graph_mem/nhop-$nhop_subg/$dataset 11 | 12 | num_neg=10000 13 | max_bp_iter=1 14 | max_q_iter=3 15 | batch_size=128 16 | n_hidden=64 17 | n_embed=256 18 | margin=0.1 19 | learning_rate=0.01 20 | max_iter=4000000 21 | cur_iter=900 22 | w_scale=0.01 23 | 
loss_type=cross_entropy 24 | save_dir=$result_root/embed-$n_embed 25 | 26 | if [ ! -e $save_dir ]; 27 | then 28 | mkdir -p $save_dir 29 | fi 30 | 31 | ./build/main \ 32 | -num_neg $num_neg \ 33 | -vis_score 1 \ 34 | -loss_type $loss_type \ 35 | -data_root $data_root \ 36 | -dataset $dataset \ 37 | -n_hidden $n_hidden \ 38 | -nhop_subg $nhop_subg \ 39 | -lr $learning_rate \ 40 | -max_bp_iter $max_bp_iter \ 41 | -net_type $net_type \ 42 | -max_q_iter $max_q_iter \ 43 | -margin $margin \ 44 | -max_iter $max_iter \ 45 | -svdir $save_dir \ 46 | -embed $n_embed \ 47 | -batch_size $batch_size \ 48 | -m 0.9 \ 49 | -l2 0.00 \ 50 | -w_scale $w_scale \ 51 | -int_report 1 \ 52 | -int_test 1 \ 53 | -int_save 1000000 \ 54 | -cur_iter $cur_iter 55 | --------------------------------------------------------------------------------