├── generation
│   ├── models
│   │   ├── .gitkeep
│   │   └── __init__.py
│   ├── cwq_evaluate.py
│   ├── virtuoso_setup.py
│   └── webqsp_evaluate_offcial.py
├── relation_retrieval
│   └── bi-encoder
│       ├── __init__.py
│       ├── biencoder.py
│       ├── faiss_indexer.py
│       └── run_bi_encoder.py
├── lib
│   └── virtodbc.so
├── figures
│   └── overview.png
├── ontology
│   └── README.md
├── data
│   ├── common_data
│   │   ├── README.md
│   │   └── facc1
│   │       └── README.md
│   └── README.md
├── entity_retrieval
│   ├── entity_disamb_predictlog.txt
│   ├── BERT_NER
│   │   ├── api.py
│   │   └── bert.py
│   ├── aqqu_util.py
│   ├── entity_disamb_predictrun_entity_disamb_CWQ.sh
│   ├── bert_ranker.py
│   ├── bert_entity_linker.py
│   └── surface_index_memory.py
├── scripts
│   ├── README.md
│   ├── run_bi_encoder_WebQSP.sh
│   ├── run_bi_encoder_CWQ.sh
│   ├── run_entity_disamb.sh
│   ├── t5_base_CWQ.sh
│   ├── t5_base_WebQSP.sh
│   ├── concat_retrieval_CWQ.sh
│   ├── concat_retrieval_WebQSP.sh
│   ├── concat_gold_CWQ.sh
│   ├── concat_gold_WebQSP.sh
│   ├── run_cross_encoder_CWQ_question_relation.sh
│   ├── without_relation_CWQ.sh
│   ├── without_entity_CWQ.sh
│   ├── GMT_KBQA_CWQ.sh
│   ├── without_entity_WebQSP.sh
│   ├── without_relation_WebQSP.sh
│   ├── GMT_KBQA_WebQSP.sh
│   └── run_cross_encoder_WebQSP_question_relation.sh
├── environment.yml
├── components
│   ├── dataset_utils.py
│   ├── utils.py
│   └── expr_parser.py
├── LICENSE
├── .gitignore
├── generation_command.txt
├── inputDataset
│   └── gen_mtl_dataset.py
├── config.py
├── detect_and_link_entity.py
├── ablation_exps.py
└── run_relation_data_process.py

/generation/models/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/generation/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/relation_retrieval/bi-encoder/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/lib/virtodbc.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HXX97/GMT-KBQA/HEAD/lib/virtodbc.so

--------------------------------------------------------------------------------
/figures/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HXX97/GMT-KBQA/HEAD/figures/overview.png

--------------------------------------------------------------------------------
/ontology/README.md:
--------------------------------------------------------------------------------
Files under this folder originate from [GrailQA](https://github.com/dki-lab/GrailQA)

--------------------------------------------------------------------------------
/data/common_data/README.md:
--------------------------------------------------------------------------------
# source
- relation_freq.json: originates from https://raw.githubusercontent.com/salesforce/rng-kbqa/main/GrailQA/misc/relation_freq.json

--------------------------------------------------------------------------------
/entity_retrieval/entity_disamb_predictlog.txt:
--------------------------------------------------------------------------------
06/11/2022 16:27:56 - WARNING - __main__ - Process rank: -1, device: cuda, n_gpu: 1, distributed training: False

--------------------------------------------------------------------------------
/scripts/README.md:
--------------------------------------------------------------------------------
- run_bi_encoder_WebQSP.sh: train the WebQSP bi-encoder
- run_bi_encoder_CWQ.sh: train the CWQ bi-encoder
- run_cross_encoder_WebQSP_question_relation.sh: train/evaluate/predict with the WebQSP cross-encoder
- run_cross_encoder_CWQ_question_relation.sh: train/evaluate/predict with the CWQ cross-encoder

--------------------------------------------------------------------------------
/data/common_data/facc1/README.md:
--------------------------------------------------------------------------------
Please download the mention information (including processed FACC1 mentions and all entity aliases in Freebase) from https://1drv.ms/u/s!AuJiG47gLqTznjl7VbnOESK6qPW2?e=HDy2Ye.

---
## Content demo:

**entity_list_file_freebase_complete_all_mention**

[mid] [surface form] [popularity]

m.0142cyx away 5


**entity_linker/data/surface_map_file_freebase_complete_all_mention**

[surface form] [popularity] [mid]

ps 1.598454413869431e-05 m.0f469q

--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: gmt
channels:
  - pytorch
  - conda-forge
  - nvidia
  - anaconda
  - defaults
dependencies:
  - requests=2.26.0=pyhd3eb1b0_0
  - tqdm=4.62.1=pyhd3eb1b0_1
  - nltk=3.6.2=pyhd3eb1b0_0
  - pandas=1.3.2=py37h8c16a72_0
  - pytorch=1.9.0=py3.7_cuda11.1_cudnn8.0.5_0
  - scikit-learn=0.24.2=py37ha9443f7_0
  - glob2=0.7=pyhd3eb1b0_0
  - networkx=2.6.2=pyhd3eb1b0_0
  - pyodbc=4.0.31=py37h295c915_0
  - pip
  - pip:
    - sparqlwrapper==1.8.5
    - urllib3==1.25.11
    - transformers==4.18.0
    - numpy==1.21.5
    - faiss-gpu==1.7.2

--------------------------------------------------------------------------------
/entity_retrieval/BERT_NER/api.py:
--------------------------------------------------------------------------------
"""
Copyright 2021, Ohio State University (Yu Gu)
Yu Gu
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""


from flask import Flask, request, jsonify
from flask_cors import CORS

from bert import Ner

app = Flask(__name__)
CORS(app)

model = Ner("out_!x")

@app.route("/predict", methods=['POST'])
def predict():
    text = request.json["text"]
    try:
        out = model.predict(text)
        return jsonify({"result": out})
    except Exception as e:
        print(e)
        return jsonify({"result": "Model Failed"})

if __name__ == "__main__":
    app.run('0.0.0.0', port=8000)

--------------------------------------------------------------------------------
/scripts/run_bi_encoder_WebQSP.sh:
--------------------------------------------------------------------------------
exp_id=${1:-none}

dataset='WebQSP'
exp_prefix="data/${dataset}/relation_retrieval/bi-encoder/saved_models/${exp_id}/"
log_dir="data/${dataset}/relation_retrieval/bi-encoder/saved_models/${exp_id}/"

if [ -d ${exp_prefix} ]; then
    echo "${exp_prefix} already exists"
else
    mkdir ${exp_prefix}
fi
if [ -d ${log_dir} ]; then
    echo "${log_dir} already exists"
else
    mkdir ${log_dir}
fi
python relation_retrieval/bi-encoder/run_bi_encoder.py \
    --dataset_type WebQSP \
    --model_save_path ${exp_prefix} \
    --max_len 60 \
    --batch_size 4 \
    --epochs 3 \
    --log_dir ${log_dir} \
    --cache_dir bert-base-uncased

--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
Data Introduction
==============================


------------
├── common_data: data about Freebase entities and relations
├── CWQ
│   ├── entity_retrieval: data of retrieved entities
│   ├── relation_retrieval: data of retrieved relations
│   ├── generation: data for logical form generation
│   ├── origin: the original ComplexWebQuestions dataset from https://www.dropbox.com/sh/7pkwkrfnwqhsnpo/AACuu4v3YNkhirzBOeeaHYala
│   └── sexpr: S-expressions translated from the SPARQL queries
├── WebQSP
│   ├── entity_retrieval: data of retrieved entities
│   ├── relation_retrieval: data of retrieved relations
│   ├── generation: data for logical form generation
│   ├── origin: the original WebQSP dataset from https://www.microsoft.com/en-us/download/details.aspx?id=52763
│   └── sexpr: S-expressions translated from the SPARQL queries

--------

--------------------------------------------------------------------------------
/scripts/run_bi_encoder_CWQ.sh:
--------------------------------------------------------------------------------
exp_id=${1:-none}

dataset='CWQ'
exp_prefix="data/${dataset}/relation_retrieval/bi-encoder/saved_models/${exp_id}/"
log_dir="data/${dataset}/relation_retrieval/bi-encoder/saved_models/${exp_id}/"

if [ -d ${exp_prefix} ]; then
    echo "${exp_prefix} already exists"
else
    mkdir ${exp_prefix}
fi
if [ -d ${log_dir} ]; then
    echo "${log_dir} already exists"
else
    mkdir ${log_dir}
fi
python relation_retrieval/bi-encoder/run_bi_encoder.py \
    --add_special_tokens \
    --dataset_type CWQ \
    --model_save_path ${exp_prefix} \
    --max_len 32 \
    --batch_size 4 \
    --epochs 1 \
    --log_dir ${log_dir} \
    --cache_dir bert-base-uncased

--------------------------------------------------------------------------------
/components/dataset_utils.py:
--------------------------------------------------------------------------------
"""
Copyright (c) 2021, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""


from torch.utils.data import Dataset

class ListDataset(Dataset):
    """A thin Dataset wrapper around an in-memory list of examples."""
    def __init__(self, examples):
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return self.examples[i]

    def __iter__(self):
        return iter(self.examples)

class LFCandidate:
    """A logical-form candidate: the S-expression, its normalized form, and optional metrics (exact match, F1, edit distance)."""
    def __init__(self, s_expr, normed_expr, ex=None, f1=None, edist=None):
        self.s_expr = s_expr
        self.normed_expr = normed_expr
        self.ex = ex
        self.f1 = f1
        self.edist = edist

    def __str__(self):
        return '{}\n\t->{}\n'.format(self.s_expr, self.normed_expr)

    def __repr__(self):
        return self.__str__()

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
BSD 3-Clause License

Copyright (c) 2022, HXX97
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

--------------------------------------------------------------------------------
/scripts/run_entity_disamb.sh:
--------------------------------------------------------------------------------
#!/bin/bash


dataset=${1:-"CWQ"}
ACTION=${2:-none}

# trailing slash added so that "${exp_prefix}..." concatenations below resolve to paths inside the directory
exp_prefix="entity_retrieval/entity_disamb_${dataset}/"
DATA_DIR="data/${dataset}/entity_retrieval/candidate_entities"

if [ "$ACTION" = "train" ]; then

    if [ -d ${exp_prefix} ]; then
        echo "${exp_prefix} already exists"
    else
        mkdir ${exp_prefix}
    fi

    cp scripts/run_entity_disamb.sh "${exp_prefix}run_entity_disamb.sh"
    git rev-parse HEAD > "${exp_prefix}commitid.log"

    # --overwrite_cache \
    python -u run_entity_disamb.py \
        --dataset ${dataset} \
        --model_type bert \
        --model_name_or_path bert-base-uncased \
        --do_lower_case \
        --do_train \
        --do_eval \
        --disable_tqdm \
        --train_file $DATA_DIR/${dataset}_train_entities_facc1_unranked.json \
        --predict_file $DATA_DIR/${dataset}_dev_entities_facc1_unranked.json \
        --learning_rate 1e-5 \
        --evaluate_during_training \
        --num_train_epochs 2 \
        --overwrite_output_dir \
        --max_seq_length 96 \
        --logging_steps 200 \
        --eval_steps 1000 \
        --save_steps 2000 \
        --warmup_ratio 0.1 \
        --output_dir "${exp_prefix}output" \
        --per_gpu_train_batch_size 8 \
        --per_gpu_eval_batch_size 16 | tee "${exp_prefix}log.txt"

elif [ "$ACTION" = "predict" ]; then

    model="${exp_prefix}output"
    split=${3:-test}

    python -u run_entity_disamb.py \
        --dataset ${dataset} \
        --model_type bert \
        --model_name_or_path ${model} \
        --do_lower_case \
        --do_eval \
        --do_predict \
        --predict_file $DATA_DIR/${dataset}_${split}_entities_facc1_unranked.json \
        --overwrite_output_dir \
        --max_seq_length 96 \
        --output_dir $DATA_DIR/disamb_results/${dataset}_${split} \
        --per_gpu_eval_batch_size 64
else
    echo "train or eval or predict"
fi

--------------------------------------------------------------------------------
/entity_retrieval/aqqu_util.py:
--------------------------------------------------------------------------------
"""
Copyright 2015, University of Freiburg.

Elmar Haussmann
"""
import re
from nltk import word_tokenize


def normalize_entity_name(name):
    name = name.lower()
    # name = name.replace('!', '')
    # name = name.replace('.', '')
    # name = name.replace(',', '')
    # name = name.replace('-', '')
    # name = name.replace('_', '')
    # name = name.replace(' ', '')
    # name = name.replace('\'', '')
    # name = name.replace('"', '')
    # name = name.replace('\\', '')


    # the following is only for freebase_complete_all_mention
    name = ' '.join(word_tokenize(name))
    # word_tokenize from nltk will change the left " to ``, which is pretty weird. Fix it here
    name = name.replace('``', '"').replace("''", '"')

    return name


def read_abbreviations(abbreviations_file):
    '''
    Return a set of abbreviations.
    :param abbreviations_file:
    :return:
    '''
    abbreviations = set()
    with open(abbreviations_file, 'r') as f:
        for line in f:
            # the file is opened in text mode, so the stripped line is already a str
            abbreviations.add(line.strip().lower())
    return abbreviations


def remove_abbreviations_from_entity_name(entity_name,
                                          abbreviations):
    tokens = entity_name.lower().split(' ')
    non_abbr_tokens = [t for t in tokens if t not in abbreviations]
    return ' '.join(non_abbr_tokens)


def remove_prefixes_from_name(name):
    if name.startswith('the'):
        name = name[3:]
    return name


def remove_suffixes_from_name(name):
    if '#' in name or '(' in name:
        name = remove_number_suffix(name)
        name = remove_bracket_suffix(name)
    return name


def remove_number_suffix(name):
    res = re.match(r'.*( #[0-9]+)$', name)
    if res:
        name = name[:res.start(1)]
        return name
    else:
        return name


def remove_bracket_suffix(name):
    res = re.match(r'.*( \([^\(\)]+\))$', name)
    if res:
        name = name[:res.start(1)]
        return name
    else:
        return name

--------------------------------------------------------------------------------
/scripts/t5_base_CWQ.sh:
--------------------------------------------------------------------------------
1 | ACTION=${1:-none} 2 | exp_id=${2:-none} 3 | do_debug=${3:-False} 4 | 5 | dataset='CWQ' 6 | exp_prefix="exps/${dataset}_${exp_id}/" 7 | 8 | 9 | if [ "$ACTION" = "train" ]; then 10 | 11 | if [ -d ${exp_prefix} ]; then 12 | echo "${exp_prefix} already exists" 13 | else 14 | mkdir -p ${exp_prefix} 15 | fi 16 | # --do_eval \ 17 | python run_multitask_generator_final.py \ 18 | --do_train \ 19 | --do_predict \ 20 | --do_debug ${do_debug} \ 21 | --max_tgt_len 190 \ 22 | --max_src_len 256 \ 23 | --epochs 15 \ 24 | --lr 5e-5 \ 25 | --eval_beams 50 \ 26 | --iters_to_accumulate 1 \ 27 | --pretrained_model_path t5-base \ 28 | --output_dir ${exp_prefix} \ 29 | --model_save_dir "${exp_prefix}model_saved" \ 30 | --overwrite_output_dir \ 31 | --normalize_relations \ 32 | --sample_size 10 \ 33 | --model T5_generation \ 34 | --dataset_type CWQ \ 35 | --train_batch_size 8 \ 36 | --eval_batch_size 4 \ 37 | --test_batch_size 4 | tee "${exp_prefix}log.txt" 38 | elif [ "$ACTION" = "eval" -o "$ACTION" = "predict" ]; then
39 | split=${4:-test} 40 | beam_size=${5:-50} 41 | test_batch_size=${6:-4} 42 | echo "Predicting ${split} with beam_size: ${beam_size} and batch_size: ${test_batch_size}" 43 | python run_multitask_generator_final.py \ 44 | --do_predict \ 45 | --do_debug ${do_debug} \ 46 | --predict_split ${split} \ 47 | --epochs 15 \ 48 | --lr 5e-5 \ 49 | --max_tgt_len 190 \ 50 | --max_src_len 256 \ 51 | --iters_to_accumulate 1 \ 52 | --eval_beams ${beam_size} \ 53 | --pretrained_model_path t5-base \ 54 | --output_dir ${exp_prefix} \ 55 | --model_save_dir "${exp_prefix}model_saved" \ 56 | --normalize_relations \ 57 | --sample_size 10 \ 58 | --model T5_generation \ 59 | --overwrite_output_dir \ 60 | --dataset_type CWQ \ 61 | --train_batch_size 8 \ 62 | --eval_batch_size 4 \ 63 | --test_batch_size ${test_batch_size} 64 | else 65 | echo "train or eval" 66 | fi -------------------------------------------------------------------------------- /scripts/t5_base_WebQSP.sh: -------------------------------------------------------------------------------- 1 | ACTION=${1:-none} 2 | exp_id=${2:-none} 3 | do_debug=${3:-False} 4 | 5 | dataset='WebQSP' 6 | exp_prefix="exps/${dataset}_${exp_id}/" 7 | 8 | 9 | if [ "$ACTION" = "train" ]; then 10 | 11 | if [ -d ${exp_prefix} ]; then 12 | echo "${exp_prefix} already exists" 13 | else 14 | mkdir -p ${exp_prefix} 15 | fi 16 | # --do_eval \ 17 | python run_multitask_generator_final.py \ 18 | --do_train \ 19 | --do_predict \ 20 | --do_debug ${do_debug} \ 21 | --max_tgt_len 110 \ 22 | --max_src_len 256 \ 23 | --epochs 20 \ 24 | --lr 5e-5 \ 25 | --eval_beams 50 \ 26 | --iters_to_accumulate 1 \ 27 | --pretrained_model_path t5-base \ 28 | --output_dir ${exp_prefix} \ 29 | --model_save_dir "${exp_prefix}model_saved" \ 30 | --overwrite_output_dir \ 31 | --normalize_relations \ 32 | --sample_size 10 \ 33 | --model T5_generation \ 34 | --dataset_type WebQSP \ 35 | --train_batch_size 2 \ 36 | --eval_batch_size 4 \ 37 | --test_batch_size 2 | tee "${exp_prefix}log.txt" 38 | elif [ "$ACTION" = "eval" -o "$ACTION" = "predict" ]; then 39 | split=${4:-test} 40 | beam_size=${5:-50} 41 | test_batch_size=${6:-2} 42 | echo "Predicting ${split} with beam_size: ${beam_size} and batch_size: ${test_batch_size}" 43 | python run_multitask_generator_final.py \ 44 | --do_predict \ 45 | --do_debug ${do_debug} \ 46 | --predict_split ${split} \ 47 | --epochs 20 \ 48 | --lr 5e-5 \ 49 | --max_tgt_len 110 \ 50 | --max_src_len 256 \ 51 | --iters_to_accumulate 1 \ 52 | --eval_beams ${beam_size} \ 53 | --pretrained_model_path t5-base \ 54 | --output_dir ${exp_prefix} \ 55 | --model_save_dir "${exp_prefix}model_saved" \ 56 | --normalize_relations \ 57 | --sample_size 10 \ 58 | --model T5_generation \ 59 | --overwrite_output_dir \ 60 | --dataset_type WebQSP \ 61 | --train_batch_size 2 \ 62 | --eval_batch_size 4 \ 63 | --test_batch_size ${test_batch_size} 64 | else 65 | echo "train or eval" 66 | fi -------------------------------------------------------------------------------- /scripts/concat_retrieval_CWQ.sh: -------------------------------------------------------------------------------- 1 | ACTION=${1:-none} 2 | exp_id=${2:-none} 3 | do_debug=${3:-False} 4 | 5 | dataset='CWQ' 6 | exp_prefix="exps/${dataset}_${exp_id}/" 7 | 8 | 9 | if [ "$ACTION" = "train" ]; then 10 | 11 | if [ -d ${exp_prefix} ]; then 12 | echo "${exp_prefix} already exists" 13 | else 14 | mkdir -p ${exp_prefix} 15 | fi 16 | # --do_eval \ 17 | python run_multitask_generator_final.py \ 18 | --do_train \ 19 | --do_predict \ 20 | --do_debug 
${do_debug} \ 21 | --max_tgt_len 190 \ 22 | --max_src_len 256 \ 23 | --epochs 15 \ 24 | --lr 5e-5 \ 25 | --eval_beams 50 \ 26 | --iters_to_accumulate 1 \ 27 | --pretrained_model_path t5-base \ 28 | --output_dir ${exp_prefix} \ 29 | --model_save_dir "${exp_prefix}model_saved" \ 30 | --overwrite_output_dir \ 31 | --normalize_relations \ 32 | --sample_size 10 \ 33 | --model T5_generation_concat \ 34 | --dataset_type CWQ \ 35 | --train_batch_size 8 \ 36 | --eval_batch_size 4 \ 37 | --test_batch_size 4 | tee "${exp_prefix}log.txt" 38 | elif [ "$ACTION" = "eval" -o "$ACTION" = "predict" ]; then 39 | split=${4:-test} 40 | beam_size=${5:-50} 41 | test_batch_size=${6:-4} 42 | echo "Predicting ${split} with beam_size: ${beam_size} and batch_size: ${test_batch_size}" 43 | python run_multitask_generator_final.py \ 44 | --do_predict \ 45 | --do_debug ${do_debug} \ 46 | --predict_split ${split} \ 47 | --epochs 15 \ 48 | --lr 5e-5 \ 49 | --max_tgt_len 190 \ 50 | --max_src_len 256 \ 51 | --iters_to_accumulate 1 \ 52 | --eval_beams ${beam_size} \ 53 | --pretrained_model_path t5-base \ 54 | --output_dir ${exp_prefix} \ 55 | --model_save_dir "${exp_prefix}model_saved" \ 56 | --normalize_relations \ 57 | --sample_size 10 \ 58 | --model T5_generation_concat \ 59 | --overwrite_output_dir \ 60 | --dataset_type CWQ \ 61 | --train_batch_size 8 \ 62 | --eval_batch_size 4 \ 63 | --test_batch_size ${test_batch_size} 64 | else 65 | echo "train or eval" 66 | fi -------------------------------------------------------------------------------- /scripts/concat_retrieval_WebQSP.sh: -------------------------------------------------------------------------------- 1 | ACTION=${1:-none} 2 | exp_id=${2:-none} 3 | do_debug=${3:-False} 4 | 5 | dataset='WebQSP' 6 | exp_prefix="exps/${dataset}_${exp_id}/" 7 | 8 | 9 | if [ "$ACTION" = "train" ]; then 10 | 11 | if [ -d ${exp_prefix} ]; then 12 | echo "${exp_prefix} already exists" 13 | else 14 | mkdir ${exp_prefix} 15 | fi 16 | # --do_eval \ 17 | python run_multitask_generator_final.py \ 18 | --do_train \ 19 | --do_debug ${do_debug} \ 20 | --do_predict \ 21 | --max_tgt_len 110 \ 22 | --max_src_len 256 \ 23 | --epochs 20 \ 24 | --lr 5e-5 \ 25 | --eval_beams 50 \ 26 | --iters_to_accumulate 1 \ 27 | --pretrained_model_path t5-base \ 28 | --output_dir ${exp_prefix} \ 29 | --model_save_dir "${exp_prefix}model_saved" \ 30 | --overwrite_output_dir \ 31 | --normalize_relations \ 32 | --sample_size 10 \ 33 | --model T5_generation_concat \ 34 | --dataset_type WebQSP \ 35 | --train_batch_size 2 \ 36 | --eval_batch_size 4 \ 37 | --test_batch_size 2 | tee "${exp_prefix}log.txt" 38 | elif [ "$ACTION" = "eval" -o "$ACTION" = "predict" ]; then 39 | split=${4:-test} 40 | beam_size=${5:-10} 41 | test_batch_size=${6:-8} 42 | echo "Predicting ${split} with beam_size: ${beam_size} and batch_size: ${test_batch_size}" 43 | python run_multitask_generator_final.py \ 44 | --do_predict \ 45 | --do_debug ${do_debug} \ 46 | --predict_split ${split} \ 47 | --epochs 20 \ 48 | --lr 5e-5 \ 49 | --max_tgt_len 110 \ 50 | --max_src_len 256 \ 51 | --iters_to_accumulate 1 \ 52 | --eval_beams ${beam_size} \ 53 | --pretrained_model_path t5-base \ 54 | --output_dir ${exp_prefix} \ 55 | --model_save_dir "${exp_prefix}model_saved" \ 56 | --normalize_relations \ 57 | --sample_size 10 \ 58 | --model T5_generation_concat \ 59 | --overwrite_output_dir \ 60 | --dataset_type WebQSP \ 61 | --train_batch_size 2 \ 62 | --eval_batch_size 4 \ 63 | --test_batch_size ${test_batch_size} 64 | else 65 | echo "train or eval" 66 | fi 
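Note on invoking the generation scripts in scripts/ (t5_base_*, concat_retrieval_*, and the concat_gold_*, without_* and GMT_KBQA_* variants that follow): they all share the same calling convention, taking the positional arguments ACTION, exp_id and do_debug, while the eval/predict branch additionally reads split, beam_size and test_batch_size. A minimal usage sketch, assuming the repository root as working directory and a made-up experiment id my_exp:

    # train, then predict on the test split with beam size 50 and test batch size 4
    bash scripts/t5_base_CWQ.sh train my_exp False
    bash scripts/t5_base_CWQ.sh predict my_exp False test 50 4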
-------------------------------------------------------------------------------- /scripts/concat_gold_CWQ.sh: -------------------------------------------------------------------------------- 1 | ACTION=${1:-none} 2 | exp_id=${2:-none} 3 | do_debug=${3:-False} 4 | 5 | dataset='CWQ' 6 | exp_prefix="exps/${dataset}_${exp_id}/" 7 | 8 | 9 | if [ "$ACTION" = "train" ]; then 10 | 11 | if [ -d ${exp_prefix} ]; then 12 | echo "${exp_prefix} already exists" 13 | else 14 | mkdir -p ${exp_prefix} 15 | fi 16 | # --do_eval \ 17 | python run_multitask_generator_final.py \ 18 | --do_train \ 19 | --do_predict \ 20 | --do_debug ${do_debug} \ 21 | --max_tgt_len 190 \ 22 | --max_src_len 256 \ 23 | --epochs 15 \ 24 | --lr 5e-5 \ 25 | --eval_beams 50 \ 26 | --iters_to_accumulate 1 \ 27 | --pretrained_model_path t5-base \ 28 | --output_dir ${exp_prefix} \ 29 | --model_save_dir "${exp_prefix}model_saved" \ 30 | --overwrite_output_dir \ 31 | --normalize_relations \ 32 | --sample_size 10 \ 33 | --model T5_generation_concat \ 34 | --dataset_type CWQ \ 35 | --concat_golden \ 36 | --train_batch_size 8 \ 37 | --eval_batch_size 4 \ 38 | --test_batch_size 4 | tee "${exp_prefix}log.txt" 39 | elif [ "$ACTION" = "eval" -o "$ACTION" = "predict" ]; then 40 | split=${4:-test} 41 | beam_size=${5:-50} 42 | test_batch_size=${6:-4} 43 | echo "Predicting ${split} with beam_size: ${beam_size} and batch_size: ${test_batch_size}" 44 | python run_multitask_generator_final.py \ 45 | --do_predict \ 46 | --do_debug ${do_debug} \ 47 | --predict_split ${split} \ 48 | --epochs 15 \ 49 | --lr 5e-5 \ 50 | --max_tgt_len 190 \ 51 | --max_src_len 256 \ 52 | --iters_to_accumulate 1 \ 53 | --eval_beams ${beam_size} \ 54 | --pretrained_model_path t5-base \ 55 | --output_dir ${exp_prefix} \ 56 | --model_save_dir "${exp_prefix}model_saved" \ 57 | --normalize_relations \ 58 | --sample_size 10 \ 59 | --model T5_generation_concat \ 60 | --overwrite_output_dir \ 61 | --dataset_type CWQ \ 62 | --concat_golden \ 63 | --train_batch_size 8 \ 64 | --eval_batch_size 4 \ 65 | --test_batch_size ${test_batch_size} 66 | else 67 | echo "train or eval" 68 | fi -------------------------------------------------------------------------------- /scripts/concat_gold_WebQSP.sh: -------------------------------------------------------------------------------- 1 | ACTION=${1:-none} 2 | exp_id=${2:-none} 3 | do_debug=${3:-False} 4 | 5 | dataset='WebQSP' 6 | exp_prefix="exps/${dataset}_${exp_id}/" 7 | 8 | 9 | if [ "$ACTION" = "train" ]; then 10 | 11 | if [ -d ${exp_prefix} ]; then 12 | echo "${exp_prefix} already exists" 13 | else 14 | mkdir ${exp_prefix} 15 | fi 16 | # --do_eval \ 17 | python run_multitask_generator_final.py \ 18 | --do_train \ 19 | --do_predict \ 20 | --do_debug ${do_debug} \ 21 | --max_tgt_len 110 \ 22 | --max_src_len 256 \ 23 | --epochs 20 \ 24 | --lr 5e-5 \ 25 | --eval_beams 50 \ 26 | --iters_to_accumulate 1 \ 27 | --pretrained_model_path t5-base \ 28 | --output_dir ${exp_prefix} \ 29 | --model_save_dir "${exp_prefix}model_saved" \ 30 | --overwrite_output_dir \ 31 | --normalize_relations \ 32 | --sample_size 10 \ 33 | --model T5_generation_concat \ 34 | --dataset_type WebQSP \ 35 | --concat_golden \ 36 | --train_batch_size 2 \ 37 | --eval_batch_size 4 \ 38 | --test_batch_size 2 | tee "${exp_prefix}log.txt" 39 | elif [ "$ACTION" = "eval" -o "$ACTION" = "predict" ]; then 40 | split=${4:-test} 41 | beam_size=${5:-50} 42 | test_batch_size=${6:-2} 43 | echo "Predicting ${split} with beam_size: ${beam_size} and batch_size: ${test_batch_size}" 44 | python 
run_multitask_generator_final.py \ 45 | --do_predict \ 46 | --do_debug ${do_debug} \ 47 | --predict_split ${split} \ 48 | --epochs 20 \ 49 | --lr 5e-5 \ 50 | --max_tgt_len 110 \ 51 | --max_src_len 256 \ 52 | --iters_to_accumulate 1 \ 53 | --eval_beams ${beam_size} \ 54 | --pretrained_model_path t5-base \ 55 | --output_dir ${exp_prefix} \ 56 | --model_save_dir "${exp_prefix}model_saved" \ 57 | --normalize_relations \ 58 | --sample_size 10 \ 59 | --model T5_generation_concat \ 60 | --overwrite_output_dir \ 61 | --dataset_type WebQSP \ 62 | --concat_golden \ 63 | --train_batch_size 2 \ 64 | --eval_batch_size 4 \ 65 | --test_batch_size ${test_batch_size} 66 | else 67 | echo "train or eval" 68 | fi -------------------------------------------------------------------------------- /scripts/run_cross_encoder_CWQ_question_relation.sh: -------------------------------------------------------------------------------- 1 | ACTION=${1:-none} 2 | exp_id=${2:-none} 3 | 4 | dataset='CWQ' 5 | exp_prefix="data/${dataset}/relation_retrieval/cross-encoder/saved_models/${exp_id}/" 6 | log_dir="data/${dataset}/relation_retrieval/cross-encoder/saved_models/${exp_id}/" 7 | 8 | 9 | if [ "$ACTION" = "train" ]; then 10 | 11 | if [ -d ${exp_prefix} ]; then 12 | echo "${exp_prefix} already exists" 13 | else 14 | mkdir ${exp_prefix} 15 | fi 16 | if [ -d ${log_dir} ]; then 17 | echo "${log_dir} already exists" 18 | else 19 | mkdir ${log_dir} 20 | fi 21 | python relation_retrieval/cross-encoder/cross_encoder_main.py \ 22 | --do_train \ 23 | --max_len 50 \ 24 | --batch_size 128 \ 25 | --epochs 6 \ 26 | --log_dir ${log_dir} \ 27 | --dataset_type CWQ \ 28 | --model_save_path ${exp_prefix} \ 29 | --output_dir ${exp_prefix} \ 30 | --cache_dir bert-base-uncased \ 31 | 32 | elif [ "$ACTION" = "eval" ]; then 33 | split=${3:-test} 34 | model_name=${4:-none} 35 | echo "Evaluating ${split}" 36 | python relation_retrieval/cross-encoder/cross_encoder_main.py \ 37 | --do_eval \ 38 | --predict_split ${split} \ 39 | --max_len 50 \ 40 | --batch_size 128 \ 41 | --epochs 6 \ 42 | --log_dir ${log_dir} \ 43 | --dataset_type CWQ \ 44 | --model_save_path "${exp_prefix}${model_name}" \ 45 | --output_dir "${exp_prefix}${model_name}_${split}/" \ 46 | --cache_dir bert-base-uncased \ 47 | 48 | elif [ "$ACTION" = "predict" ]; then 49 | split=${3:-test} 50 | model_name=${4:-none} 51 | if [ -d "${exp_prefix}${model_name}_${split}/" ]; then 52 | echo "${exp_prefix}${model_name}_${split}/ already exists" 53 | else 54 | mkdir "${exp_prefix}${model_name}_${split}/" 55 | fi 56 | echo "Predicting ${split}" 57 | python relation_retrieval/cross-encoder/cross_encoder_main.py \ 58 | --do_predict \ 59 | --predict_split ${split} \ 60 | --max_len 50 \ 61 | --batch_size 128 \ 62 | --epochs 6 \ 63 | --log_dir ${log_dir} \ 64 | --dataset_type CWQ \ 65 | --model_save_path "${exp_prefix}${model_name}" \ 66 | --output_dir "${exp_prefix}${model_name}_${split}/" \ 67 | --cache_dir bert-base-uncased \ 68 | 69 | else 70 | echo "train or eval or predict" 71 | fi -------------------------------------------------------------------------------- /entity_retrieval/entity_disamb_predictrun_entity_disamb_CWQ.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_DIR="data/CWQ/entity_retrieval/candidate_entities" 4 | dataset=${1:-"CWQ"} 5 | ACTION=${2:-none} 6 | if [ "$ACTION" = "train" ]; then 7 | 8 | exp_prefix="entity_retrieval/entity_disamb_${dataset}" 9 | 10 | if [ -d ${exp_prefix} ]; then 11 | echo "${exp_prefix} already 
exists" 12 | else 13 | mkdir ${exp_prefix} 14 | fi 15 | 16 | cp scripts/run_entity_disamb_CWQ.sh "${exp_prefix}run_entity_disamb_CWQ.sh" 17 | git rev-parse HEAD > "${exp_prefix}commitid.log" 18 | 19 | # --overwrite_cache \ 20 | python -u run_entity_disamb.py \ 21 | --dataset ${dataset} \ 22 | --model_type bert \ 23 | --model_name_or_path bert-base-uncased \ 24 | --do_lower_case \ 25 | --do_train \ 26 | --do_eval \ 27 | --disable_tqdm \ 28 | --train_file $DATA_DIR/${dataset}_train_entities_facc1_unranked.json \ 29 | --predict_file $DATA_DIR/${dataset}_dev_entities_facc1_unranked.json \ 30 | --learning_rate 1e-5 \ 31 | --evaluate_during_training \ 32 | --num_train_epochs 2 \ 33 | --overwrite_output_dir \ 34 | --max_seq_length 96 \ 35 | --logging_steps 200 \ 36 | --eval_steps 1000 \ 37 | --save_steps 2000 \ 38 | --warmup_ratio 0.1 \ 39 | --output_dir "${exp_prefix}output" \ 40 | --per_gpu_train_batch_size 8 \ 41 | --per_gpu_eval_batch_size 16 | tee "${exp_prefix}log.txt" 42 | 43 | elif [ "$ACTION" = "eval" ]; then 44 | 45 | exp_prefix="entity_retrieval/entity_disamb_${dataset}" 46 | model=${exp_prefix}/output 47 | split=${3:-test} 48 | 49 | python -u run_disamb.py \ 50 | --dataset ${dataset} \ 51 | --model_type bert \ 52 | --model_name_or_path ${model} \ 53 | --do_lower_case \ 54 | --do_eval \ 55 | --predict_file $DATA_DIR/${dataset}_${split}_entities_facc1_unranked.json \ 56 | --overwrite_output_dir \ 57 | --max_seq_length 96 \ 58 | --output_dir $DATA_DIR/disamb_results/${dataset}_${split} \ 59 | --per_gpu_eval_batch_size 64 60 | 61 | elif [ "$ACTION" = "predict" ]; then 62 | 63 | exp_prefix="entity_retrieval/entity_disamb_${dataset}" 64 | model=${exp_prefix}/output 65 | split=${3:-test} 66 | 67 | python -u run_entity_disamb.py \ 68 | --dataset ${dataset} \ 69 | --model_type bert \ 70 | --model_name_or_path ${model} \ 71 | --do_lower_case \ 72 | --do_eval \ 73 | --do_predict \ 74 | --predict_file $DATA_DIR/${dataset}_${split}_entities_facc1_unranked.json \ 75 | --overwrite_output_dir \ 76 | --max_seq_length 96 \ 77 | --output_dir $DATA_DIR/disamb_results/${dataset}_${split} \ 78 | --per_gpu_eval_batch_size 64 79 | # copy the entity disambugation file to misc/ directory 80 | # cp results/disamb/${dataset}_${split}/predictions.json misc/${dataset}_${split}_entity_linking.json 81 | else 82 | echo "train or eval or predict" 83 | fi -------------------------------------------------------------------------------- /scripts/without_relation_CWQ.sh: -------------------------------------------------------------------------------- 1 | ACTION=${1:-none} 2 | exp_id=${2:-none} 3 | do_debug=${3:-False} 4 | 5 | dataset='CWQ' 6 | exp_prefix="exps/${dataset}_${exp_id}/" 7 | 8 | 9 | if [ "$ACTION" = "train" ]; then 10 | 11 | if [ -d ${exp_prefix} ]; then 12 | echo "${exp_prefix} already exists" 13 | else 14 | mkdir -p ${exp_prefix} 15 | fi 16 | # --do_eval \ 17 | python run_multitask_generator_final.py \ 18 | --do_train \ 19 | --do_predict \ 20 | --do_debug ${do_debug} \ 21 | --max_tgt_len 190 \ 22 | --max_src_len 256 \ 23 | --epochs 15 \ 24 | --lr 5e-5 \ 25 | --eval_beams 50 \ 26 | --iters_to_accumulate 1 \ 27 | --pretrained_model_path t5-base \ 28 | --output_dir ${exp_prefix} \ 29 | --model_save_dir "${exp_prefix}model_saved" \ 30 | --overwrite_output_dir \ 31 | --normalize_relations \ 32 | --sample_size 10 \ 33 | --add_prefix \ 34 | --model T5_Multitask_Entity_Concat \ 35 | --dataset_type CWQ \ 36 | --warmup_epochs 5 \ 37 | --train_batch_size 8 \ 38 | --eval_batch_size 4 \ 39 | --test_batch_size 4 | tee 
"${exp_prefix}log.txt" 40 | elif [ "$ACTION" = "eval" -o "$ACTION" = "predict" ]; then 41 | split=${4:-test} 42 | beam_size=${5:-50} 43 | test_batch_size=${6:-4} 44 | echo "Predicting ${split} with beam_size: ${beam_size} and batch_size: ${test_batch_size}" 45 | python run_multitask_generator_final.py \ 46 | --do_predict \ 47 | --do_debug ${do_debug} \ 48 | --predict_split ${split} \ 49 | --epochs 15 \ 50 | --lr 5e-5 \ 51 | --max_tgt_len 190 \ 52 | --max_src_len 256 \ 53 | --iters_to_accumulate 1 \ 54 | --eval_beams ${beam_size} \ 55 | --pretrained_model_path t5-base \ 56 | --output_dir ${exp_prefix} \ 57 | --model_save_dir "${exp_prefix}model_saved" \ 58 | --normalize_relations \ 59 | --sample_size 10 \ 60 | --add_prefix \ 61 | --model T5_Multitask_Entity_Concat \ 62 | --overwrite_output_dir \ 63 | --dataset_type CWQ \ 64 | --train_batch_size 8 \ 65 | --eval_batch_size 4 \ 66 | --test_batch_size ${test_batch_size} 67 | else 68 | echo "train or eval" 69 | fi -------------------------------------------------------------------------------- /scripts/without_entity_CWQ.sh: -------------------------------------------------------------------------------- 1 | ACTION=${1:-none} 2 | exp_id=${2:-none} 3 | do_debug=${3:-False} 4 | 5 | dataset='CWQ' 6 | exp_prefix="exps/${dataset}_${exp_id}/" 7 | 8 | 9 | if [ "$ACTION" = "train" ]; then 10 | 11 | if [ -d ${exp_prefix} ]; then 12 | echo "${exp_prefix} already exists" 13 | else 14 | mkdir -p ${exp_prefix} 15 | fi 16 | # --do_eval \ 17 | python run_multitask_generator_final.py \ 18 | --do_train \ 19 | --do_predict \ 20 | --do_debug ${do_debug} \ 21 | --max_tgt_len 190 \ 22 | --max_src_len 256 \ 23 | --epochs 15 \ 24 | --lr 5e-5 \ 25 | --eval_beams 50 \ 26 | --iters_to_accumulate 1 \ 27 | --pretrained_model_path t5-base \ 28 | --output_dir ${exp_prefix} \ 29 | --model_save_dir "${exp_prefix}model_saved" \ 30 | --overwrite_output_dir \ 31 | --normalize_relations \ 32 | --sample_size 10 \ 33 | --add_prefix \ 34 | --model T5_Multitask_Relation_Concat \ 35 | --dataset_type CWQ \ 36 | --warmup_epochs 5 \ 37 | --train_batch_size 8 \ 38 | --eval_batch_size 4 \ 39 | --test_batch_size 4 | tee "${exp_prefix}log.txt" 40 | elif [ "$ACTION" = "eval" -o "$ACTION" = "predict" ]; then 41 | split=${4:-test} 42 | beam_size=${5:-50} 43 | test_batch_size=${6:-4} 44 | echo "Predicting ${split} with beam_size: ${beam_size} and batch_size: ${test_batch_size}" 45 | python run_multitask_generator_final.py \ 46 | --do_predict \ 47 | --do_debug ${do_debug} \ 48 | --predict_split ${split} \ 49 | --epochs 15 \ 50 | --lr 5e-5 \ 51 | --max_tgt_len 190 \ 52 | --max_src_len 256 \ 53 | --iters_to_accumulate 1 \ 54 | --eval_beams ${beam_size} \ 55 | --pretrained_model_path t5-base \ 56 | --output_dir ${exp_prefix} \ 57 | --model_save_dir "${exp_prefix}model_saved" \ 58 | --normalize_relations \ 59 | --sample_size 10 \ 60 | --add_prefix \ 61 | --model T5_Multitask_Relation_Concat \ 62 | --overwrite_output_dir \ 63 | --dataset_type CWQ \ 64 | --train_batch_size 8 \ 65 | --eval_batch_size 4 \ 66 | --test_batch_size ${test_batch_size} 67 | else 68 | echo "train or eval" 69 | fi -------------------------------------------------------------------------------- /scripts/GMT_KBQA_CWQ.sh: -------------------------------------------------------------------------------- 1 | ACTION=${1:-none} 2 | exp_id=${2:-none} 3 | do_debug=${3:-False} 4 | 5 | dataset='CWQ' 6 | exp_prefix="exps/${dataset}_${exp_id}/" 7 | 8 | 9 | if [ "$ACTION" = "train" ]; then 10 | 11 | if [ -d ${exp_prefix} ]; then 12 | echo 
"${exp_prefix} already exists" 13 | else 14 | mkdir -p ${exp_prefix} 15 | fi 16 | # --do_eval \ 17 | python run_multitask_generator_final.py \ 18 | --do_train \ 19 | --do_debug ${do_debug} \ 20 | --do_predict \ 21 | --max_tgt_len 190 \ 22 | --max_src_len 256 \ 23 | --epochs 15 \ 24 | --lr 5e-5 \ 25 | --eval_beams 50 \ 26 | --iters_to_accumulate 1 \ 27 | --pretrained_model_path t5-base \ 28 | --output_dir ${exp_prefix} \ 29 | --model_save_dir "${exp_prefix}model_saved" \ 30 | --overwrite_output_dir \ 31 | --normalize_relations \ 32 | --sample_size 10 \ 33 | --add_prefix \ 34 | --model T5_MultiTask_Relation_Entity_Concat \ 35 | --dataset_type CWQ \ 36 | --warmup_epochs 5 \ 37 | --train_batch_size 8 \ 38 | --eval_batch_size 4 \ 39 | --test_batch_size 4 | tee "${exp_prefix}log.txt" 40 | elif [ "$ACTION" = "eval" -o "$ACTION" = "predict" ]; then 41 | split=${4:-test} 42 | beam_size=${5:-50} 43 | test_batch_size=${6:-4} 44 | echo "Predicting ${split} with beam_size: ${beam_size} and batch_size: ${test_batch_size}" 45 | python run_multitask_generator_final.py \ 46 | --do_predict \ 47 | --do_debug ${do_debug} \ 48 | --predict_split ${split} \ 49 | --epochs 15 \ 50 | --lr 5e-5 \ 51 | --max_tgt_len 190 \ 52 | --max_src_len 256 \ 53 | --iters_to_accumulate 1 \ 54 | --eval_beams ${beam_size} \ 55 | --pretrained_model_path t5-base \ 56 | --output_dir ${exp_prefix} \ 57 | --model_save_dir "${exp_prefix}model_saved" \ 58 | --normalize_relations \ 59 | --sample_size 10 \ 60 | --add_prefix \ 61 | --model T5_MultiTask_Relation_Entity_Concat \ 62 | --overwrite_output_dir \ 63 | --dataset_type CWQ \ 64 | --train_batch_size 8 \ 65 | --eval_batch_size 4 \ 66 | --test_batch_size ${test_batch_size} 67 | else 68 | echo "train or eval" 69 | fi -------------------------------------------------------------------------------- /scripts/without_entity_WebQSP.sh: -------------------------------------------------------------------------------- 1 | ACTION=${1:-none} 2 | exp_id=${2:-none} 3 | do_debug=${3:-False} 4 | 5 | dataset='WebQSP' 6 | exp_prefix="exps/${dataset}_${exp_id}/" 7 | 8 | 9 | if [ "$ACTION" = "train" ]; then 10 | 11 | if [ -d ${exp_prefix} ]; then 12 | echo "${exp_prefix} already exists" 13 | else 14 | mkdir -p ${exp_prefix} 15 | fi 16 | # --do_eval \ 17 | python run_multitask_generator_final.py \ 18 | --do_train \ 19 | --do_predict \ 20 | --do_debug ${do_debug} \ 21 | --max_tgt_len 110 \ 22 | --max_src_len 256 \ 23 | --epochs 20 \ 24 | --lr 5e-5 \ 25 | --eval_beams 50 \ 26 | --iters_to_accumulate 1 \ 27 | --pretrained_model_path t5-base \ 28 | --output_dir ${exp_prefix} \ 29 | --model_save_dir "${exp_prefix}model_saved" \ 30 | --overwrite_output_dir \ 31 | --normalize_relations \ 32 | --add_prefix \ 33 | --sample_size 10 \ 34 | --model T5_Multitask_Relation_Concat \ 35 | --dataset_type WebQSP \ 36 | --warmup_epochs 5 \ 37 | --train_batch_size 2 \ 38 | --eval_batch_size 4 \ 39 | --test_batch_size 2 | tee "${exp_prefix}log.txt" 40 | elif [ "$ACTION" = "eval" -o "$ACTION" = "predict" ]; then 41 | split=${4:-test} 42 | beam_size=${5:-50} 43 | test_batch_size=${6:-2} 44 | echo "Predicting ${split} with beam_size: ${beam_size} and batch_size: ${test_batch_size}" 45 | python run_multitask_generator_final.py \ 46 | --do_predict \ 47 | --do_debug ${do_debug} \ 48 | --predict_split ${split} \ 49 | --epochs 20 \ 50 | --lr 5e-5 \ 51 | --max_tgt_len 110 \ 52 | --max_src_len 256 \ 53 | --iters_to_accumulate 1 \ 54 | --eval_beams ${beam_size} \ 55 | --pretrained_model_path t5-base \ 56 | --output_dir ${exp_prefix} \ 57 | 
--model_save_dir "${exp_prefix}model_saved" \ 58 | --normalize_relations \ 59 | --sample_size 10 \ 60 | --add_prefix \ 61 | --model T5_Multitask_Relation_Concat \ 62 | --overwrite_output_dir \ 63 | --dataset_type WebQSP \ 64 | --train_batch_size 2 \ 65 | --eval_batch_size 4 \ 66 | --test_batch_size ${test_batch_size} 67 | else 68 | echo "train or eval" 69 | fi -------------------------------------------------------------------------------- /scripts/without_relation_WebQSP.sh: -------------------------------------------------------------------------------- 1 | ACTION=${1:-none} 2 | exp_id=${2:-none} 3 | do_debug=${3:-False} 4 | 5 | dataset='WebQSP' 6 | exp_prefix="exps/${dataset}_${exp_id}/" 7 | 8 | 9 | if [ "$ACTION" = "train" ]; then 10 | 11 | if [ -d ${exp_prefix} ]; then 12 | echo "${exp_prefix} already exists" 13 | else 14 | mkdir -p ${exp_prefix} 15 | fi 16 | # --do_eval \ 17 | python run_multitask_generator_final.py \ 18 | --do_train \ 19 | --do_debug ${do_debug} \ 20 | --do_predict \ 21 | --max_tgt_len 110 \ 22 | --max_src_len 256 \ 23 | --epochs 20 \ 24 | --lr 5e-5 \ 25 | --eval_beams 50 \ 26 | --iters_to_accumulate 1 \ 27 | --pretrained_model_path t5-base \ 28 | --output_dir ${exp_prefix} \ 29 | --model_save_dir "${exp_prefix}model_saved" \ 30 | --overwrite_output_dir \ 31 | --normalize_relations \ 32 | --sample_size 10 \ 33 | --add_prefix \ 34 | --model T5_Multitask_Entity_Concat \ 35 | --dataset_type WebQSP \ 36 | --warmup_epochs 5 \ 37 | --train_batch_size 2 \ 38 | --eval_batch_size 4 \ 39 | --test_batch_size 2 | tee "${exp_prefix}log.txt" 40 | elif [ "$ACTION" = "eval" -o "$ACTION" = "predict" ]; then 41 | split=${4:-test} 42 | beam_size=${5:-50} 43 | test_batch_size=${6:-2} 44 | echo "Predicting ${split} with beam_size: ${beam_size} and batch_size: ${test_batch_size}" 45 | python run_multitask_generator_final.py \ 46 | --do_predict \ 47 | --do_debug ${do_debug} \ 48 | --predict_split ${split} \ 49 | --epochs 20 \ 50 | --lr 5e-5 \ 51 | --max_tgt_len 110 \ 52 | --max_src_len 256 \ 53 | --iters_to_accumulate 1 \ 54 | --eval_beams ${beam_size} \ 55 | --pretrained_model_path t5-base \ 56 | --output_dir ${exp_prefix} \ 57 | --model_save_dir "${exp_prefix}model_saved" \ 58 | --normalize_relations \ 59 | --sample_size 10 \ 60 | --add_prefix \ 61 | --model T5_Multitask_Entity_Concat \ 62 | --overwrite_output_dir \ 63 | --dataset_type WebQSP \ 64 | --train_batch_size 2 \ 65 | --eval_batch_size 4 \ 66 | --test_batch_size ${test_batch_size} 67 | else 68 | echo "train or eval" 69 | fi -------------------------------------------------------------------------------- /scripts/GMT_KBQA_WebQSP.sh: -------------------------------------------------------------------------------- 1 | ACTION=${1:-none} 2 | exp_id=${2:-none} 3 | do_debug=${3:-False} 4 | 5 | dataset='WebQSP' 6 | exp_prefix="exps/${dataset}_${exp_id}/" 7 | 8 | 9 | if [ "$ACTION" = "train" ]; then 10 | 11 | if [ -d ${exp_prefix} ]; then 12 | echo "${exp_prefix} already exists" 13 | else 14 | mkdir -p ${exp_prefix} 15 | fi 16 | # --do_eval \ 17 | python run_multitask_generator_final.py \ 18 | --do_train \ 19 | --do_predict \ 20 | --do_debug ${do_debug} \ 21 | --max_tgt_len 110 \ 22 | --max_src_len 256 \ 23 | --epochs 20 \ 24 | --lr 5e-5 \ 25 | --eval_beams 50 \ 26 | --iters_to_accumulate 1 \ 27 | --pretrained_model_path t5-base \ 28 | --output_dir ${exp_prefix} \ 29 | --model_save_dir "${exp_prefix}model_saved" \ 30 | --overwrite_output_dir \ 31 | --normalize_relations \ 32 | --sample_size 10 \ 33 | --add_prefix \ 34 | --model 
T5_MultiTask_Relation_Entity_Concat \ 35 | --dataset_type WebQSP \ 36 | --warmup_epochs 5 \ 37 | --train_batch_size 2 \ 38 | --eval_batch_size 4 \ 39 | --test_batch_size 2 | tee "${exp_prefix}log.txt" 40 | elif [ "$ACTION" = "eval" -o "$ACTION" = "predict" ]; then 41 | split=${4:-test} 42 | beam_size=${5:-50} 43 | test_batch_size=${6:-2} 44 | echo "Predicting ${split} with beam_size: ${beam_size} and batch_size: ${test_batch_size}" 45 | python run_multitask_generator_final.py \ 46 | --do_predict \ 47 | --do_debug ${do_debug} \ 48 | --predict_split ${split} \ 49 | --epochs 20 \ 50 | --lr 5e-5 \ 51 | --max_tgt_len 110 \ 52 | --max_src_len 256 \ 53 | --iters_to_accumulate 1 \ 54 | --eval_beams ${beam_size} \ 55 | --pretrained_model_path t5-base \ 56 | --output_dir ${exp_prefix} \ 57 | --model_save_dir "${exp_prefix}model_saved" \ 58 | --normalize_relations \ 59 | --sample_size 10 \ 60 | --add_prefix \ 61 | --model T5_MultiTask_Relation_Entity_Concat \ 62 | --overwrite_output_dir \ 63 | --dataset_type WebQSP \ 64 | --train_batch_size 2 \ 65 | --eval_batch_size 4 \ 66 | --test_batch_size ${test_batch_size} 67 | else 68 | echo "train or eval" 69 | fi -------------------------------------------------------------------------------- /scripts/run_cross_encoder_WebQSP_question_relation.sh: -------------------------------------------------------------------------------- 1 | ACTION=${1:-none} 2 | exp_id=${2:-none} 3 | 4 | dataset='WebQSP' 5 | exp_prefix="data/WebQSP/relation_retrieval/cross-encoder/saved_models/${exp_id}/" 6 | log_dir="data/WebQSP/relation_retrieval/cross-encoder/saved_models/${exp_id}/" 7 | 8 | 9 | if [ "$ACTION" = "train" ]; then 10 | 11 | if [ -d ${exp_prefix} ]; then 12 | echo "${exp_prefix} already exists" 13 | else 14 | mkdir ${exp_prefix} 15 | fi 16 | if [ -d ${log_dir} ]; then 17 | echo "${log_dir} already exists" 18 | else 19 | mkdir ${log_dir} 20 | fi 21 | python relation_retrieval/cross-encoder/cross_encoder_main.py \ 22 | --do_train \ 23 | --max_len 34 \ 24 | --batch_size 128 \ 25 | --epochs 3 \ 26 | --log_dir ${log_dir} \ 27 | --dataset_type WebQSP \ 28 | --model_save_path ${exp_prefix} \ 29 | --output_dir ${exp_prefix} \ 30 | --cache_dir bert-base-uncased \ 31 | 32 | elif [ "$ACTION" = "eval" ]; then 33 | split=${3:-test} 34 | model_name=${4:-none} 35 | if [ -d "${exp_prefix}${model_name}_${split}/" ]; then 36 | echo "${exp_prefix}${model_name}_${split}/ already exists" 37 | else 38 | mkdir "${exp_prefix}${model_name}_${split}/" 39 | fi 40 | echo "Evaluating ${split}" 41 | python relation_retrieval/cross-encoder/cross_encoder_main.py \ 42 | --do_eval \ 43 | --predict_split ${split} \ 44 | --max_len 34 \ 45 | --batch_size 128 \ 46 | --epochs 3 \ 47 | --log_dir ${log_dir} \ 48 | --dataset_type WebQSP \ 49 | --model_save_path "${exp_prefix}${model_name}" \ 50 | --output_dir "${exp_prefix}${model_name}_${split}" \ 51 | --cache_dir bert-base-uncased \ 52 | 53 | elif [ "$ACTION" = "predict" ]; then 54 | split=${3:-test} 55 | model_name=${4:-none} 56 | if [ -d "${exp_prefix}${model_name}_${split}/" ]; then 57 | echo "${exp_prefix}${model_name}_${split}/ already exists" 58 | else 59 | mkdir "${exp_prefix}${model_name}_${split}/" 60 | fi 61 | echo "Predicting ${split}" 62 | python relation_retrieval/cross-encoder/cross_encoder_main.py \ 63 | --do_predict \ 64 | --predict_split ${split} \ 65 | --max_len 34 \ 66 | --batch_size 128 \ 67 | --epochs 3 \ 68 | --log_dir ${log_dir} \ 69 | --dataset_type WebQSP \ 70 | --model_save_path "${exp_prefix}${model_name}" \ 71 | --output_dir 
"${exp_prefix}${model_name}_${split}" \ 72 | --cache_dir bert-base-uncased \ 73 | 74 | else 75 | echo "train or eval or predict" 76 | fi -------------------------------------------------------------------------------- /entity_retrieval/bert_ranker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2021, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | 9 | from logging import log 10 | import torch 11 | from torch import nn 12 | from torch.nn import CrossEntropyLoss 13 | from transformers import( 14 | BertPreTrainedModel, 15 | BertModel, 16 | ) 17 | 18 | def get_inf_mask(bool_mask): 19 | return (~bool_mask) * -100000.0 20 | 21 | class BertForCandidateRanking(BertPreTrainedModel): 22 | """Use Bert to rank candidates, inheritated from BertPreTrainedModel""" 23 | def __init__(self, config): 24 | super().__init__(config) 25 | 26 | self.bert = BertModel(config) 27 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 28 | self.classifier = nn.Linear(config.hidden_size, 1) 29 | 30 | self.init_weights() 31 | 32 | # for training return loss, [batch_size * num_sample] 33 | # for testing, batch size have to be 1 34 | def forward( 35 | self, 36 | input_ids=None, 37 | attention_mask=None, 38 | token_type_ids=None, 39 | position_ids=None, 40 | head_mask=None, 41 | inputs_embeds=None, 42 | sample_mask=None, 43 | labels=None, 44 | output_attentions=None, 45 | output_hidden_states=None, 46 | return_dict=None, 47 | ): 48 | assert return_dict is None 49 | # return_dict = return_dict if return_dict is not None else self.config.use_return_dict 50 | return_dict = False 51 | 52 | # for training, input is batch_size * sample_size * L 53 | # for testing, it is batch_size * L 54 | if labels is not None: 55 | # input_ids.size: #[batch_size, sample_size, max_length] 56 | batch_size = input_ids.size(0) 57 | sample_size = input_ids.size(1) 58 | # flatten first two dim 59 | input_ids = input_ids.view((batch_size * sample_size,-1)) 60 | token_type_ids = token_type_ids.view((batch_size * sample_size,-1)) 61 | attention_mask = attention_mask.view((batch_size * sample_size,-1)) 62 | 63 | outputs = self.bert( 64 | input_ids, 65 | attention_mask=attention_mask, 66 | token_type_ids=token_type_ids, 67 | position_ids=position_ids, 68 | head_mask=head_mask, 69 | inputs_embeds=inputs_embeds, 70 | output_attentions=output_attentions, 71 | output_hidden_states=output_hidden_states, 72 | return_dict=return_dict, 73 | ) 74 | 75 | pooled_output = outputs[1] 76 | 77 | pooled_output = self.dropout(pooled_output) 78 | logits = self.classifier(pooled_output) 79 | 80 | loss = None 81 | if labels is not None: 82 | # reshape logits 83 | logits = logits.view((batch_size, sample_size)) 84 | logits = logits + get_inf_mask(sample_mask) 85 | # apply infmask 86 | loss_fct = CrossEntropyLoss() 87 | loss = loss_fct(logits, labels.view(-1)) 88 | else: 89 | logits = logits.squeeze(1) 90 | 91 | if not return_dict: 92 | output = (logits,) + outputs[2:] 93 | return ((loss,) + output) if loss is not None else output 94 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | # *.so 8 | 9 | # 
Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | pip-wheel-metadata/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 91 | #Pipfile.lock 92 | 93 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 94 | __pypackages__/ 95 | 96 | # Celery stuff 97 | celerybeat-schedule 98 | celerybeat.pid 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # Environments 104 | .env 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | 131 | # VSCode settings 132 | .vscode/ 133 | 134 | 135 | 136 | 137 | 138 | # data 139 | data/common_data/bak 140 | data/common_data/facc1/entity_list_file_freebase_complete_all_mention 141 | data/common_data/facc1/freebase_complete_all_mention_mid_vocab 142 | data/common_data/facc1/freebase_complete_all_mention_surface_index 143 | data/common_data/facc1/surface_map_file_freebase_complete_all_mention 144 | data/CWQ 145 | data/WebQSP 146 | data/CWQ_uploaded 147 | data/WebQSP_uploaded 148 | data/bak 149 | 150 | 151 | 152 | 153 | 154 | # BERT NER model for entity mention detection 155 | entity_retrieval/BERT_NER/trained_ner_model/ 156 | entity_retrieval/entity_disamb_CWQ/ 157 | entity_retrieval/entity_disamb_CWQ_prev/ 158 | entity_retrieval/entity_disamb_WebQSP/ 159 | 160 | 161 | # huggingface cache 162 | hfcache/ 163 | 164 | 165 | # feature cache 166 | feature_cache 167 | 168 | # experiment models and logs 169 | exps/ 170 | exps_final/ 171 | 172 | WebQSP/ 173 | bak/ 174 | relation_retrieval/bi-encoder/bak 175 | relation_retrieval/cross-encoder/bak 176 | relation_retrieval/bak 177 | generation/bak 178 | generation/models/bak 179 | entity_retrieval/bak 180 | components/bak 181 | notes/ 182 | scripts/backup 
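Note on the candidate ranking in entity_retrieval/bert_ranker.py above: during training the model flattens the batch_size x sample_size candidates before BERT, un-flattens the resulting logits, and adds get_inf_mask(sample_mask) so that padded candidates receive a large negative score before the softmax cross-entropy. A minimal self-contained sketch of that masking-and-loss pattern, with toy tensors rather than the repo's data:

import torch
from torch.nn import CrossEntropyLoss

def get_inf_mask(bool_mask):
    # True (real candidate) -> 0.0, False (padding) -> -100000.0
    return (~bool_mask) * -100000.0

logits = torch.randn(2, 4)                              # (batch_size, sample_size) candidate scores
sample_mask = torch.tensor([[True, True, True, False],  # last candidate of row 0 is padding
                            [True, True, False, False]])
labels = torch.tensor([0, 1])                           # gold candidate index per example

masked_logits = logits + get_inf_mask(sample_mask)      # padding gets ~zero softmax weight
loss = CrossEntropyLoss()(masked_logits, labels)
print(masked_logits, loss.item())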
-------------------------------------------------------------------------------- /generation/cwq_evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | """ 4 | @File : cwq_evaluate.py 5 | @Time : 2022/01/10 16:14:58 6 | @Author : Xixin Hu 7 | @Version : 1.0 8 | @Contact : xixinhu97@foxmail.com 9 | @Desc : None 10 | """ 11 | 12 | # here put the import lib 13 | import argparse 14 | from executor import sparql_executor 15 | from components.utils import dump_json, load_json 16 | from tqdm import tqdm 17 | import os 18 | 19 | 20 | def cwq_evaluate_valid_results(args): 21 | """Compute P, R and F1 for CWQ""" 22 | pred_data = load_json(args.pred_file) 23 | # origin dataset 24 | dataset_data = load_json(f'data/CWQ/origin/ComplexWebQuestions_{args.split}.json') 25 | 26 | dataset_dict = {x["ID"]:x for x in dataset_data} 27 | 28 | p_list = [] 29 | r_list = [] 30 | f_list = [] 31 | p_dict = {} 32 | r_dict = {} 33 | f_dict = {} 34 | acc_num = 0 35 | 36 | pred_dict = {} 37 | acc_qid_list = [] # Pred Answer ACC 38 | for pred in pred_data: 39 | qid = pred['qid'] 40 | pred_answer = set(pred['answer']) 41 | pred_dict[qid]=pred_answer 42 | 43 | for qid,example in tqdm(dataset_dict.items()): 44 | 45 | gt_sparql = example['sparql'] 46 | if 'answer' in example: 47 | gt_answer = set(example['answer']) 48 | else: 49 | gt_answer = set(sparql_executor.execute_query(gt_sparql)) 50 | 51 | # for dev split 52 | # gt_answer = set([item["answer_id"] for item in example["answers"]]) 53 | 54 | pred_answer = set(pred_dict.get(qid,{})) 55 | 56 | # assert len(pred_answer)>0 and len(gt_answer)>0 57 | if pred_answer == gt_answer: 58 | acc_num+=1 59 | acc_qid_list.append(qid) 60 | 61 | if len(pred_answer)== 0: 62 | if len(gt_answer)==0: 63 | p=1 64 | r=1 65 | f=1 66 | else: 67 | p=0 68 | r=0 69 | f=0 70 | elif len(gt_answer)==0: 71 | p=0 72 | r=0 73 | f=0 74 | else: 75 | p = len(pred_answer & gt_answer)/ len(pred_answer) 76 | r = len(pred_answer & gt_answer)/ len(gt_answer) 77 | f = 2*(p*r)/(p+r) if p+r>0 else 0 78 | 79 | p_list.append(p) 80 | r_list.append(r) 81 | f_list.append(f) 82 | p_dict[qid] = p 83 | r_dict[qid] = r 84 | f_dict[qid] = f 85 | 86 | p_average = sum(p_list)/len(p_list) 87 | r_average = sum(r_list)/len(r_list) 88 | f_average = sum(f_list)/len(f_list) 89 | 90 | res = f'Total: {len(p_list)}, ACC:{acc_num/len(p_list)}, AVGP: {p_average}, AVGR: {r_average}, AVGF: {f_average}' 91 | print(res) 92 | dirname = os.path.dirname(args.pred_file) 93 | filename = os.path.basename(args.pred_file) 94 | with open (os.path.join(dirname,f'{filename}_final_eval_results.txt'),'w') as f: 95 | f.write(res) 96 | f.flush() 97 | 98 | # Write answer acc result to prediction file 99 | for pred in pred_data: 100 | qid = pred['qid'] 101 | if qid in acc_qid_list: 102 | pred['answer_acc'] = True 103 | else: 104 | pred['answer_acc'] = False 105 | pred['precision'] = p_dict[qid] if qid in p_dict else None 106 | pred['recall'] = r_dict[qid] if qid in r_dict else None 107 | pred['f1'] = f_dict[qid] if qid in f_dict else None 108 | 109 | dump_json(pred_data, os.path.join(dirname, f'{filename}_new.json'), indent=4) 110 | 111 | 112 | if __name__ == "__main__": 113 | parser = argparse.ArgumentParser() 114 | parser.add_argument( 115 | "--split", 116 | type=str, 117 | required=True, 118 | help="split to operate on, can be `test`, `dev` and `train`", 119 | ) 120 | parser.add_argument( 121 | "--pred_file", type=str, default=None, help="prediction results 
file" 122 | ) 123 | 124 | args = parser.parse_args() 125 | 126 | cwq_evaluate_valid_results(args) 127 | 128 | 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- /relation_retrieval/bi-encoder/biencoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.cuda.amp import autocast 7 | from tqdm import tqdm 8 | from transformers import AutoModel 9 | 10 | 11 | class BiEncoderModule(torch.nn.Module): 12 | def __init__(self, device, bert_model="bert-base-uncased", tokenizer=None, freeze_bert=False): 13 | super(BiEncoderModule, self).__init__() 14 | self.question_bert_layer = AutoModel.from_pretrained(bert_model) 15 | self.relation_bert_layer = AutoModel.from_pretrained(bert_model) 16 | self.device = device 17 | if tokenizer: 18 | self.question_bert_layer.resize_token_embeddings(len(tokenizer)) 19 | self.relation_bert_layer.resize_token_embeddings(len(tokenizer)) 20 | # Freeze bert layers and only train the classification layer weights 21 | if freeze_bert: 22 | for p in self.question_bert_layer.parameters(): 23 | p.requires_grad = False 24 | for p in self.relation_bert_layer.parameters(): 25 | p.requires_grad = False 26 | 27 | @autocast() 28 | def forward( 29 | self, 30 | question_input_ids, 31 | question_attn_masks, 32 | question_token_type_ids, 33 | relations_input_ids, 34 | relations_attn_masks, 35 | relations_token_type_ids, 36 | golden_id 37 | ): 38 | embedding_question = self.question_bert_layer(question_input_ids, question_attn_masks, question_token_type_ids).pooler_output 39 | 40 | 41 | embedding_relations = [] 42 | # bert only accept (batch_size, maxlen) size of input, while embedding_relations is with the size (batch_size, sample_size, maxlen) 43 | for i in range(0, relations_input_ids.shape[1]): 44 | relation_input_id = relations_input_ids[:,i,:] 45 | # print('relation_input_id: {}'.format(relation_input_id.shape)) # (batch_size, maxlen) 46 | relations_attn_mask = relations_attn_masks[:,i,:] 47 | relations_token_type_id = relations_token_type_ids[:,i,:] 48 | embedding_relation = self.relation_bert_layer(relation_input_id, relations_attn_mask, relations_token_type_id).pooler_output 49 | embedding_relations.append(embedding_relation) 50 | 51 | embedding_relations = torch.stack(embedding_relations, dim=1) 52 | 53 | embedding_question = embedding_question.unsqueeze(1) 54 | # print('embedding_question: {}'.format(embedding_question.shape)) # (batch_size, 1, 768) 55 | # print('embedding_relations: {}'.format(embedding_relations.shape)) # (batch_size, sample_size, 768) 56 | 57 | scores = torch.bmm(embedding_question, torch.transpose(embedding_relations, 1, 2)).squeeze(1) 58 | # print('scores: {}'.format(scores.shape)) # (batch_size, sample_size) 59 | loss = self.calculate_loss(scores, golden_id) 60 | 61 | return scores.to(self.device), loss 62 | 63 | @autocast() 64 | def calculate_loss(self, scores, golden_id): 65 | """ 66 | scores: (batch_size, sample_size) 67 | golden_id: (batch_size) 68 | loss = -scores[golden_id] + log \sum_{i=1}^B exp(scores[i]) 69 | """ 70 | assert len(golden_id.shape) == 1, print(golden_id.shape) 71 | assert golden_id.shape[0] == scores.shape[0], print('golden_id: {}, scores: {}'.format(golden_id.shape, scores.shape)) 72 | 73 | loss_fct = nn.CrossEntropyLoss() 74 | loss = loss_fct(scores.to(self.device), golden_id.to(self.device)) / scores.shape[0] 75 | 76 
| return loss 77 | 78 | def encode_question(self, question_token_ids, question_attn_masks, question_token_type_ids): 79 | """ 80 | question_token_ids: (batch_size, maxlen) 81 | """ 82 | question_representation = self.question_bert_layer(question_token_ids, question_attn_masks, question_token_type_ids).pooler_output # (batch_size, 768) 83 | return question_representation 84 | 85 | def encode_relation(self, relation_input_id, relations_attn_mask, relations_token_type_id): 86 | """ 87 | relation_input_id: (batch_size, maxlen) 88 | """ 89 | relation_representation = self.relation_bert_layer(relation_input_id, relations_attn_mask, relations_token_type_id).pooler_output # (batch_size, 768) 90 | return relation_representation -------------------------------------------------------------------------------- /generation_command.txt: -------------------------------------------------------------------------------- 1 | cheatsheet for main experiments 2 | CWQ: 3 | 1. T5-base 4 | train: CUDA_VISIBLE_DEVICES=2 sh scripts/t5_base_CWQ.sh train t5_base False 5 | evaluation: python3 eval_topk_prediction_final.py --split test --pred_file exps/CWQ_t5_base/beam_50_test_4_top_k_predictions.json --test_batch_size 4 --dataset CWQ --beam_size 50 6 | 7 | 2. T5-base Concatenating Retrieval 8 | train & predict: CUDA_VISIBLE_DEVICES=2 sh scripts/concat_retrieval_CWQ.sh train concat_retrieval False 9 | predict: CUDA_VISIBLE_DEVICES=3 sh scripts/concat_retrieval_CWQ.sh predict concat_retrieval False test 50 4 10 | evaluation: python3 eval_topk_prediction_final.py --split test --pred_file exps/CWQ_concat_retrieval/beam_50_test_4_top_k_predictions.json --test_batch_size 4 --dataset CWQ --beam_size 50 11 | 12 | 3. T5-base Concatenating Oracle 13 | train & predict: CUDA_VISIBLE_DEVICES=2 sh scripts/concat_gold_CWQ.sh train concat_goldEnt_goldRel False 14 | evaluation: python3 eval_topk_prediction_final.py --split test --pred_file exps/CWQ_concat_goldEnt_goldRel/beam_50_test_4_top_k_predictions.json --test_batch_size 4 --dataset CWQ --beam_size 50 15 | 16 | 4. GMT-KBQA Without entity 17 | train & predict: CUDA_VISIBLE_DEVICES=1 sh scripts/without_entity_CWQ.sh train without_entity False 18 | predict: CUDA_VISIBLE_DEVICES=0 sh scripts/without_entity_CWQ.sh predict without_entity False test 50 4 19 | evaluation: python3 eval_topk_prediction_final.py --split test --pred_file exps/CWQ_without_entity/beam_50_test_4_top_k_predictions.json --test_batch_size 4 --dataset CWQ --beam_size 50 20 | 21 | 5. GMT-KBQA without relation 22 | train & predict: CUDA_VISIBLE_DEVICES=0 sh scripts/without_relation_CWQ.sh train without_relation False 23 | evaluation: python3 eval_topk_prediction_final.py --split test --pred_file exps/CWQ_without_relation/beam_50_test_4_top_k_predictions.json --test_batch_size 4 --dataset CWQ --beam_size 50 24 | 25 | 6. 
GMT-KBQA 26 | train & predict: CUDA_VISIBLE_DEVICES=1 sh scripts/GMT_KBQA_CWQ.sh train GMT_KBQA False 27 | predict: CUDA_VISIBLE_DEVICES=1 sh scripts/GMT_KBQA_CWQ.sh predict GMT_KBQA False test 50 4 28 | evaluation: python3 eval_topk_prediction_final.py --split test --pred_file exps/CWQ_GMT_KBQA/beam_50_test_4_top_k_predictions.json --test_batch_size 4 --dataset CWQ --beam_size 50 29 | 30 | 6.1 different inference beam size, taking beam size=10 as example 31 | predict: CUDA_VISIBLE_DEVICES=1 sh scripts/GMT_KBQA_CWQ.sh predict GMT_KBQA False test 10 4 32 | evaluation: python3 eval_topk_prediction_final.py --split test --pred_file exps/CWQ_GMT_KBQA/beam_10_test_4_top_k_predictions.json --test_batch_size 4 --dataset CWQ --beam_size 10 33 | 34 | WebQSP: 35 | 36 | 1. T5-base 37 | train & predict: CUDA_VISIBLE_DEVICES=1 sh scripts/t5_base_WebQSP.sh train t5_base False 38 | evaluation: python3 eval_topk_prediction_final.py --split test --pred_file exps/WebQSP_t5_base/beam_50_test_2_top_k_predictions.json --test_batch_size 2 --dataset WebQSP --beam_size 50 39 | 40 | 41 | 2. T5-base Concatenating Retrieval 42 | train & predict: CUDA_VISIBLE_DEVICES=0 sh scripts/concat_retrieval_WebQSP.sh train concat_retrieval False 43 | evaluation: python3 eval_topk_prediction_final.py --split test --pred_file exps/WebQSP_concat_retrieval/beam_50_test_2_top_k_predictions.json --test_batch_size 2 --dataset WebQSP --beam_size 50 44 | 45 | 3. T5-base Concatenating Oracle 46 | train & predict: CUDA_VISIBLE_DEVICES=4 sh scripts/concat_gold_WebQSP.sh train concat_goldEnt_goldRel False 47 | evaluation: python3 eval_topk_prediction_final.py --split test --pred_file exps/WebQSP_concat_goldEnt_goldRel/beam_50_test_2_top_k_predictions.json --test_batch_size 2 --dataset WebQSP --beam_size 50 48 | 49 | 4. GMT-KBQA Without entity 50 | train & predict: CUDA_VISIBLE_DEVICES=3 sh scripts/without_entity_WebQSP.sh train without_entity False 51 | predict: CUDA_VISIBLE_DEVICES=2 sh scripts/without_entity_WebQSP.sh predict without_entity False test 50 2 52 | evaluation: python3 eval_topk_prediction_final.py --split test --pred_file exps/WebQSP_without_entity/beam_50_test_2_top_k_predictions.json --test_batch_size 2 --dataset WebQSP --beam_size 50 53 | 54 | 5. GMT-KBQA Without relation 55 | train & predict: CUDA_VISIBLE_DEVICES=3 sh scripts/without_relation_WebQSP.sh train without_relation False 56 | evaluation: python3 eval_topk_prediction_final.py --split test --pred_file exps/WebQSP_without_relation/beam_50_test_2_top_k_predictions.json --test_batch_size 2 --dataset WebQSP --beam_size 50 57 | 58 | 6. 
GMT-KBQA 59 | train & predict: CUDA_VISIBLE_DEVICES=1 sh scripts/GMT_KBQA_WebQSP.sh train GMT_KBQA False 60 | predict: CUDA_VISIBLE_DEVICES=1 sh scripts/GMT_KBQA_WebQSP.sh predict GMT_KBQA False test 50 2 61 | evaluation: python3 eval_topk_prediction_final.py --split test --pred_file exps/WebQSP_GMT_KBQA/beam_50_test_2_top_k_predictions.json --test_batch_size 2 --dataset WebQSP --beam_size 50 62 | 63 | 6.1 different inference beam size, taking beam size=10 as example 64 | predict: CUDA_VISIBLE_DEVICES=1 sh scripts/GMT_KBQA_WebQSP.sh predict GMT_KBQA False test 10 2 65 | evaluation: python3 eval_topk_prediction_final.py --split test --pred_file exps/WebQSP_GMT_KBQA/beam_10_test_2_top_k_predictions.json --test_batch_size 2 --dataset WebQSP --beam_size 10 66 | -------------------------------------------------------------------------------- /generation/virtuoso_setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2021, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | 9 | #!/usr/bin/env python3 10 | 11 | # This script provides a convenient wrapper for the Virtuoso SPARQL server. 12 | # Adapted from Sempre (https://github.com/percyliang/sempre) 13 | 14 | import os 15 | import sys 16 | import subprocess 17 | import argparse 18 | 19 | virtuosoPath = "/export/home/virtuoso-opensource" 20 | if not os.path.exists(virtuosoPath): 21 | print(f"{virtuosoPath} does not exist") 22 | sys.exit(1) 23 | 24 | # Virtuoso has two services: the server (isql) and SPARQL endpoint 25 | def isqlPort(port): return 10000 + port 26 | def httpPort(port): return port 27 | 28 | def run(command): 29 | print(f"RUNNING: {command}") 30 | res = subprocess.run(command, shell=True, stdout=subprocess.PIPE) 31 | return res.stdout 32 | 33 | def start(dbPath, port): 34 | 35 | # if not os.path.exists(dbPath): 36 | # os.mkdir(dbPath) 37 | 38 | # Recommended: 70% of RAM, each buffer is 8K 39 | # Use a fraction of the free RAM. The result may vary across runs. 40 | # memFree = parseInt(`cat /proc/meminfo | grep MemFree | awk '{print $2}'`) # KB 41 | # Use a fraction of the total RAM. The result is the same across runs. 
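# (editorial sketch) the buffer arithmetic below, worked through for a machine
# with 64 GB of RAM (MemTotal ~= 67,108,864 KB):
#   memFree         = 67,108,864 // 4       -> 16,777,216 KB budgeted for buffers
#   numberOfBuffers = 16,777,216 * 0.70 / 8 -> ~1,468,006 pages of 8 KB each
#   maxDirtyBuffers = numberOfBuffers / 2   -> ~734,003
# note that despite its name, memFree ends up holding a quarter of MemTotal
# (the MemFree variant is commented out above), and both buffer values are
# written to virtuoso.ini as floats; wrapping them in int() would be safer.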
42 | memFree = int(run("cat /proc/meminfo | grep MemTotal | awk '{print $2}'")) # KB 43 | memFree = int(memFree/4) 44 | print(memFree) 45 | numberOfBuffers = memFree * 0.70 / 8 46 | maxDirtyBuffers = numberOfBuffers / 2 47 | print(f"{memFree} KB free, using {numberOfBuffers} buffers, {maxDirtyBuffers} dirty buffers") 48 | 49 | # Configuration options: 50 | # http://docs.openlinksw.com/virtuoso/dbadm.html 51 | # http://virtuoso.openlinksw.com/dataspace/doc/dav/wiki/Main/VirtConfigScale 52 | config = ( 53 | f"[Database]\n" 54 | f"DatabaseFile = {dbPath}/virtuoso.db\n" 55 | f"ErrorLogFile = {dbPath}/virtuoso.log\n" 56 | f"LockFile = {dbPath}/virtuoso.lck\n" 57 | f"TransactionFile = {dbPath}/virtuoso.trx\n" 58 | f"xa_persistent_file = {dbPath}/virtuoso.pxa\n" 59 | f"ErrorLogLevel = 7\n" 60 | f"FileExtend = 200\n" 61 | f"MaxCheckpointRemap = 2000\n" 62 | f"Striping = 0\n" 63 | f"TempStorage = TempDatabase\n" 64 | f"\n" 65 | f"[TempDatabase]\n" 66 | f"DatabaseFile = {dbPath}/virtuoso-temp.db\n" 67 | f"TransactionFile = {dbPath}/virtuoso-temp.trx\n" 68 | f"MaxCheckpointRemap = 2000\n" 69 | f"Striping = 0\n" 70 | f"\n" 71 | f"[Parameters]\n" 72 | f"ServerPort = {isqlPort(port)}\n" 73 | f"LiteMode = 0\n" 74 | f"DisableUnixSocket = 1\n" 75 | f"DisableTcpSocket = 0\n" 76 | f"ServerThreads = 100 ; increased from 20\n" 77 | f"CheckpointInterval = 60\n" 78 | f"O_DIRECT = 1 ; increased from 0\n" 79 | f"CaseMode = 2\n" 80 | f"MaxStaticCursorRows = 100000\n" 81 | f"CheckpointAuditTrail = 0\n" 82 | f"AllowOSCalls = 0\n" 83 | f"SchedulerInterval = 10\n" 84 | f"DirsAllowed = .\n" 85 | f"ThreadCleanupInterval = 0\n" 86 | f"ThreadThreshold = 10\n" 87 | f"ResourcesCleanupInterval = 0\n" 88 | f"FreeTextBatchSize = 100000\n" 89 | # f"SingleCPU = 0\n" 90 | f"PrefixResultNames = 0\n" 91 | f"RdfFreeTextRulesSize = 100\n" 92 | f"IndexTreeMaps = 256\n" 93 | f"MaxMemPoolSize = 200000000\n" 94 | f"PrefixResultNames = 0\n" 95 | f"MacSpotlight = 0\n" 96 | f"IndexTreeMaps = 64\n" 97 | f"NumberOfBuffers = {numberOfBuffers}\n" 98 | f"MaxDirtyBuffers = {maxDirtyBuffers}\n" 99 | f"\n" 100 | f"[SPARQL]\n" 101 | f"ResultSetMaxRows = 50000\n" 102 | f"MaxQueryCostEstimationTime = 600 ; in seconds (increased)\n" 103 | f"MaxQueryExecutionTime = 180; in seconds (increased)\n" 104 | f"\n" 105 | f"[HTTPServer]\n" 106 | f"ServerPort = {httpPort(port)}\n" 107 | f"Charset = UTF-8\n" 108 | f"ServerThreads = 15 ; increased from unknown\n" 109 | ) 110 | 111 | configPath = f"{dbPath}/virtuoso.ini" 112 | print(config) 113 | print() 114 | print(configPath) 115 | print(f"==== Starting Virtuoso server for {dbPath} on port {port}...") 116 | with open(configPath, 'w') as f: 117 | f.write(config) 118 | run(f"{virtuosoPath}/bin/virtuoso-t +configfile {configPath} +wait") 119 | 120 | def stop(port): 121 | run(f"echo 'shutdown;' | {virtuosoPath}/bin/isql localhost:{isqlPort(port)}") 122 | 123 | def status(port): 124 | run(f"echo 'status();' | {virtuosoPath}/bin/isql localhost:{isqlPort(port)}") 125 | 126 | ############################################################ 127 | # Main 128 | 129 | if __name__ == "__main__": 130 | parser = argparse.ArgumentParser(description="manage Virtuoso services") 131 | parser.add_argument("action", type=str, help="start or stop") 132 | parser.add_argument("port", type=int, help="port for the SPARQL HTTP endpoint") 133 | parser.add_argument("-d", "--db-path", type=str, help="path to the db directory") 134 | 135 | args = parser.parse_args() 136 | if args.action == "start": 137 | if not args.db_path: 138 | print("please 
specify path to the db directory with -d") 139 | sys.exit() 140 | 141 | if not os.path.isdir(args.db_path): 142 | print("the path specified does not exist") 143 | sys.exit() 144 | 145 | start(args.db_path, args.port) 146 | elif args.action == "stop": 147 | stop(args.port) 148 | elif args.action == "status": 149 | status(args.port) 150 | else: 151 | print(f"invalid action: {args.action}") 152 | sys.exit() 153 | -------------------------------------------------------------------------------- /relation_retrieval/bi-encoder/faiss_indexer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | """ 8 | FAISS-based index components. Originated from 9 | https://github.com/facebookresearch/DPR/blob/master/dpr/indexer/faiss_indexers.py 10 | """ 11 | 12 | import os 13 | import logging 14 | import pickle 15 | 16 | import faiss 17 | import numpy as np 18 | 19 | logger = logging.getLogger() 20 | 21 | 22 | class DenseIndexer(object): 23 | def __init__(self, buffer_size: int = 50000): 24 | self.buffer_size = buffer_size 25 | self.index_id_to_db_id = [] 26 | self.index = None 27 | 28 | def index_data(self, data: np.array): 29 | raise NotImplementedError 30 | 31 | def search_knn(self, query_vectors: np.array, top_docs: int): 32 | raise NotImplementedError 33 | 34 | def serialize(self, index_file: str): 35 | logger.info("Serializing index to %s", index_file) 36 | faiss.write_index(self.index, index_file) 37 | 38 | def deserialize_from(self, index_file: str): 39 | logger.info("Loading index from %s", index_file) 40 | self.index = faiss.read_index(index_file) 41 | logger.info( 42 | "Loaded index of type %s and size %d", type(self.index), self.index.ntotal 43 | ) 44 | 45 | 46 | # DenseFlatIndexer does exact search 47 | class DenseFlatIndexer(DenseIndexer): 48 | def __init__(self, vector_sz: int = 1, buffer_size: int = 50000): 49 | super(DenseFlatIndexer, self).__init__(buffer_size=buffer_size) 50 | self.index = faiss.IndexFlatIP(vector_sz) 51 | 52 | def index_data(self, data: np.array): 53 | n = len(data) 54 | # indexing in batches is beneficial for many faiss index types 55 | logger.info("Indexing data, this may take a while.") 56 | cnt = 0 57 | for i in range(0, n, self.buffer_size): 58 | vectors = [np.reshape(t, (1, -1)) for t in data[i : i + self.buffer_size]] 59 | vectors = np.concatenate(vectors, axis=0) 60 | self.index.add(vectors) 61 | cnt += self.buffer_size 62 | 63 | logger.info("Total data indexed %d", n) 64 | 65 | def search_knn(self, query_vectors, top_k): 66 | scores, indexes = self.index.search(query_vectors, top_k) 67 | return scores, indexes 68 | 69 | 70 | # DenseHNSWFlatIndexer does approximate search 71 | class DenseHNSWFlatIndexer(DenseIndexer): 72 | """ 73 | Efficient index for retrieval.
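A minimal usage sketch (editorial; `embeddings` and `query_vecs` stand in for
float32 numpy arrays produced by your own encoder):

    indexer = DenseHNSWFlatIndexer(vector_sz=768)
    indexer.index_data(embeddings)                  # shape (N, 768)
    scores, idxs = indexer.search_knn(query_vecs, top_k=10)
    indexer.serialize("relation.index")             # persist to disk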
Note: default settings are for high accuracy but also high RAM usage 74 | """ 75 | 76 | def __init__( 77 | self, 78 | vector_sz: int, 79 | buffer_size: int = 50000, 80 | store_n: int = 128, 81 | ef_search: int = 256, 82 | ef_construction: int = 200, 83 | ): 84 | super(DenseHNSWFlatIndexer, self).__init__(buffer_size=buffer_size) 85 | 86 | # IndexHNSWFlat supports L2 similarity only 87 | # so we have to apply DOT -> L2 similarity space conversion with the help of an extra dimension 88 | index = faiss.IndexHNSWFlat(vector_sz + 1, store_n) 89 | index.hnsw.efSearch = ef_search 90 | index.hnsw.efConstruction = ef_construction 91 | self.index = index 92 | self.phi = 0 93 | 94 | def index_data(self, data: np.array): 95 | n = len(data) 96 | 97 | # max norm is required before putting all vectors in the index to convert inner product similarity to L2 98 | if self.phi > 0: 99 | raise RuntimeError( 100 | "DPR HNSWF index needs to index all data at once," 101 | "results will be unpredictable otherwise." 102 | ) 103 | phi = 0 104 | for i, item in enumerate(data): 105 | doc_vector = item 106 | norms = (doc_vector ** 2).sum() 107 | phi = max(phi, norms) 108 | logger.info("HNSWF DotProduct -> L2 space phi={}".format(phi)) 109 | self.phi = 0 110 | 111 | # indexing in batches is beneficial for many faiss index types 112 | logger.info("Indexing data, this may take a while.") 113 | cnt = 0 114 | for i in range(0, n, self.buffer_size): 115 | vectors = [np.reshape(t, (1, -1)) for t in data[i : i + self.buffer_size]] 116 | 117 | norms = [(doc_vector ** 2).sum() for doc_vector in vectors] 118 | aux_dims = [np.sqrt(phi - norm) for norm in norms] 119 | hnsw_vectors = [ 120 | np.hstack((doc_vector, aux_dims[i].reshape(-1, 1))) 121 | for i, doc_vector in enumerate(vectors) 122 | ] 123 | hnsw_vectors = np.concatenate(hnsw_vectors, axis=0) 124 | 125 | self.index.add(hnsw_vectors) 126 | cnt += self.buffer_size 127 | logger.info("Indexed data %d" % cnt) 128 | 129 | logger.info("Total data indexed %d" % n) 130 | 131 | def search_knn(self, query_vectors, top_k): 132 | aux_dim = np.zeros(len(query_vectors), dtype="float32") 133 | query_nhsw_vectors = np.hstack((query_vectors, aux_dim.reshape(-1, 1))) 134 | logger.info("query_hnsw_vectors %s", query_nhsw_vectors.shape) 135 | scores, indexes = self.index.search(query_nhsw_vectors, top_k) 136 | return scores, indexes 137 | 138 | def deserialize_from(self, file: str): 139 | super(DenseHNSWFlatIndexer, self).deserialize_from(file) 140 | # to trigger warning on subsequent indexing 141 | self.phi = 1 142 | -------------------------------------------------------------------------------- /components/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2021, salesforce.com, inc. 3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | 9 | import pickle 10 | import json 11 | import os 12 | import shutil 13 | import re 14 | from typing import List 15 | from executor.sparql_executor import get_label_with_odbc 16 | 17 | 18 | def dump_to_bin(obj, fname): 19 | with open(fname, "wb") as f: 20 | pickle.dump(obj, f) 21 | 22 | 23 | def load_bin(fname): 24 | with open(fname, "rb") as f: 25 | return pickle.load(f) 26 | 27 | 28 | def load_json(fname, mode="r", encoding="utf8"): 29 | if "b" in mode: 30 | encoding = None 31 | with open(fname, mode=mode, encoding=encoding) as f: 32 | return json.load(f) 33 | 34 | 35 | def dump_json(obj, fname, indent=4, mode='w', encoding="utf8", ensure_ascii=False): 36 | if "b" in mode: 37 | encoding = None 38 | with open(fname, mode=mode, encoding=encoding) as f: 39 | return json.dump(obj, f, indent=indent, ensure_ascii=ensure_ascii) 40 | 41 | 42 | def mkdir_f(prefix): 43 | if os.path.exists(prefix): 44 | shutil.rmtree(prefix) 45 | os.makedirs(prefix) 46 | 47 | 48 | def mkdir_p(prefix): 49 | if not os.path.exists(prefix): 50 | os.makedirs(prefix) 51 | 52 | 53 | illegal_xml_re = re.compile(u'[\x00-\x08\x0b-\x1f\x7f-\x84\x86-\x9f\ud800-\udfff\ufdd0-\ufddf\ufffe-\uffff]') 54 | def clean_str(s: str) -> str: 55 | """remove illegal unicode characters""" 56 | return illegal_xml_re.sub('',s) 57 | 58 | 59 | 60 | def tokenize_s_expr(expr): 61 | expr = expr.replace('(', ' ( ') 62 | expr = expr.replace(')', ' ) ') 63 | toks = expr.split(' ') 64 | toks = [x for x in toks if len(x)] 65 | return toks 66 | 67 | def extract_mentioned_entities_from_sexpr(expr:str) -> List[str]: 68 | expr = expr.replace('(', ' ( ') 69 | expr = expr.replace(')', ' ) ') 70 | toks = expr.split(' ') 71 | toks = [x for x in toks if len(x)] 72 | entity_tokens = [] 73 | for t in toks: 74 | # normalize entity 75 | if t.startswith('m.') or t.startswith('g.'): 76 | entity_tokens.append(t) 77 | return entity_tokens 78 | 79 | def extract_mentioned_entities_from_sparql(sparql:str) -> List[str]: 80 | """extract entity from sparql""" 81 | sparql = sparql.replace('(',' ( ').replace(')',' ) ') 82 | toks = sparql.split(' ') 83 | toks = [x.replace('\t.','') for x in toks if len(x)] 84 | entity_tokens = [] 85 | for t in toks: 86 | if t.startswith('ns:m.') or t.startswith('ns:g.'): 87 | entity_tokens.append(t[3:]) 88 | 89 | entity_tokens = list(set(entity_tokens)) 90 | return entity_tokens 91 | 92 | def extract_mentioned_relations_from_sparql(sparql:str): 93 | """extract relation from sparql""" 94 | sparql = sparql.replace('(',' ( ').replace(')',' ) ') 95 | toks = sparql.split(' ') 96 | toks = [x for x in toks if len(x)] 97 | relation_tokens = [] 98 | for t in toks: 99 | if (re.match("ns:[a-zA-Z_0-9]*\.[a-zA-Z_0-9]*\.[a-zA-Z_0-9]*",t.strip()) 100 | or re.match("ns:[a-zA-Z_0-9]*\.[a-zA-Z_0-9]*\.[a-zA-Z_0-9]*\.[a-zA-Z_0-9]*",t.strip())): 101 | relation_tokens.append(t[3:]) 102 | 103 | relation_tokens = list(set(relation_tokens)) 104 | return relation_tokens 105 | 106 | 107 | def extract_mentioned_relations_from_sexpr(sexpr:str)->List[str]: 108 | sexpr = sexpr.replace('(',' ( ').replace(')',' ) ') 109 | toks = sexpr.split(' ') 110 | toks = [x for x in toks if len(x)] 111 | relation_tokens = [] 112 | 113 | for t in toks: 114 | if (re.match("[a-zA-Z_]*\.[a-zA-Z_]*\.[a-zA-Z_]*",t.strip()) 115 | or re.match("[a-zA-Z_]*\.[a-zA-Z_]*\.[a-zA-Z_]*\.[a-zA-Z_]*",t.strip())): 116 |
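# (editorial) the two patterns above match 3-part and 4-part dotted schema
# items such as "people.person.nationality"; entity ids like "m.0d3k14"
# contain digits and a single dot, so they can never match here.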
relation_tokens.append(t) 117 | relation_tokens = list(set(relation_tokens)) 118 | return relation_tokens 119 | 120 | def vanilla_sexpr_linearization_method(expr, entity_label_map={}, relation_label_map={}, linear_origin_map={}): 121 | """ 122 | textualize a logical form, replacing mids with labels 123 | 124 | Returns: 125 | (str): normalized s_expr 126 | """ 127 | expr = expr.replace("(", " ( ") # add space for parentheses 128 | expr = expr.replace(")", " ) ") 129 | toks = expr.split(" ") # split by space 130 | toks = [x for x in toks if len(x)] 131 | 132 | norm_toks = [] 133 | for t in toks: 134 | 135 | # original token 136 | origin_t = t 137 | 138 | if t.startswith("m.") or t.startswith("g."): # replace entity with its name 139 | if t in entity_label_map: 140 | t = entity_label_map[t] 141 | else: 142 | # name = get_label(t) 143 | name = get_label_with_odbc(t) 144 | if name is not None: 145 | entity_label_map[t] = name 146 | t = name 147 | t = '[ '+t+' ]' 148 | elif "XMLSchema" in t: # remove xml type 149 | format_pos = t.find("^^") 150 | t = t[:format_pos] 151 | elif t == "ge": # replace ge/gt/le/lt 152 | t = "GREATER EQUAL" 153 | elif t == "gt": 154 | t = "GREATER THAN" 155 | elif t == "le": 156 | t = "LESS EQUAL" 157 | elif t == "lt": 158 | t = "LESS THAN" 159 | else: 160 | t = t.replace("_", " ") # replace "_" with " " 161 | t = t.replace(".", " , ") # replace "." with " , " 162 | 163 | if "." in origin_t: # relation 164 | t = "[ "+t+" ]" 165 | relation_label_map[origin_t]=t 166 | 167 | norm_toks.append(t) 168 | linear_origin_map[t] = origin_t # for reverse transduction 169 | 170 | return " ".join(norm_toks) 171 | 172 | def _textualize_relation(r): 173 | """return a relation string with '_' and '.' replaced""" 174 | if "_" in r: # replace "_" with " " 175 | r = r.replace("_", " ") 176 | if "." in r: # replace "."
with " , " 177 | r = r.replace(".", " , ") 178 | return r -------------------------------------------------------------------------------- /generation/webqsp_evaluate_offcial.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on offcial evaluation script from https://www.microsoft.com/en-us/download/details.aspx?id=52763 3 | """ 4 | import json 5 | import os 6 | import argparse 7 | 8 | def dump_json(obj, fname, indent=4, mode='w' ,encoding="utf8", ensure_ascii=False): 9 | if "b" in mode: 10 | encoding = None 11 | with open(fname, "w", encoding=encoding) as f: 12 | return json.dump(obj, f, indent=indent, ensure_ascii=ensure_ascii) 13 | 14 | def load_json(fname, mode="r", encoding="utf8"): 15 | if "b" in mode: 16 | encoding = None 17 | with open(fname, mode=mode, encoding=encoding) as f: 18 | return json.load(f) 19 | 20 | def webqsp_evaluate_valid_results(args): 21 | if args.split == 'dev': 22 | res = main(args.pred_file, f'data/WebQSP/origin/WebQSP.pdev.json') 23 | else: 24 | res = main(args.pred_file, f'data/WebQSP/origin/WebQSP.{args.split}.json') 25 | dirname = os.path.dirname(args.pred_file) 26 | filename = os.path.basename(args.pred_file) 27 | with open (os.path.join(dirname,f'{filename}_final_eval_results_official.txt'),'w') as f: 28 | f.write(res) 29 | f.flush() 30 | 31 | def FindInList(entry,elist): 32 | for item in elist: 33 | if entry == item: 34 | return True 35 | return False 36 | 37 | def CalculatePRF1(goldAnswerList, predAnswerList): 38 | if len(goldAnswerList) == 0: 39 | if len(predAnswerList) == 0: 40 | return [1.0, 1.0, 1.0] # consider it 'correct' when there is no labeled answer, and also no predicted answer 41 | else: 42 | return [0.0, 1.0, 0.0] # precision=0 and recall=1 when there is no labeled answer, but has some predicted answer(s) 43 | elif len(predAnswerList)==0: 44 | return [1.0, 0.0, 0.0] # precision=1 and recall=0 when there is labeled answer(s), but no predicted answer 45 | else: 46 | glist =[x["AnswerArgument"] for x in goldAnswerList] 47 | plist =predAnswerList 48 | 49 | tp = 1e-40 # numerical trick 50 | fp = 0.0 51 | fn = 0.0 52 | 53 | for gentry in glist: 54 | if FindInList(gentry,plist): 55 | tp += 1 56 | else: 57 | fn += 1 58 | for pentry in plist: 59 | if not FindInList(pentry,glist): 60 | fp += 1 61 | 62 | 63 | precision = tp/(tp + fp) 64 | recall = tp/(tp + fn) 65 | 66 | f1 = (2*precision*recall)/(precision+recall) 67 | return [precision, recall, f1] 68 | 69 | 70 | def main(pred_data, dataset_data): 71 | 72 | goldData = load_json(dataset_data) 73 | predAnswers = load_json(pred_data) 74 | 75 | PredAnswersById = {} 76 | 77 | for item in predAnswers: 78 | PredAnswersById[item["QuestionId"]] = item["Answers"] 79 | 80 | total = 0.0 81 | f1sum = 0.0 82 | recSum = 0.0 83 | precSum = 0.0 84 | numCorrect = 0 85 | prediction_res = [] 86 | if "Questions" in goldData: 87 | goldData = goldData["Questions"] 88 | for entry in goldData: 89 | 90 | skip = True 91 | for pidx in range(0,len(entry["Parses"])): 92 | np = entry["Parses"][pidx] 93 | if np["AnnotatorComment"]["QuestionQuality"] == "Good" and np["AnnotatorComment"]["ParseQuality"] == "Complete": 94 | skip = False 95 | 96 | if(len(entry["Parses"])==0 or skip): 97 | continue 98 | 99 | total += 1 100 | 101 | id = entry["QuestionId"] 102 | 103 | if id not in PredAnswersById: 104 | print("The problem " + id + " is not in the prediction set") 105 | print("Continue to evaluate the other entries") 106 | continue 107 | 108 | if len(entry["Parses"]) == 0: 109 | 
print("Empty parses in the gold set. Breaking!!") 110 | break 111 | 112 | predAnswers = PredAnswersById[id] 113 | 114 | bestf1 = -9999 115 | bestf1Rec = -9999 116 | bestf1Prec = -9999 117 | for pidx in range(0,len(entry["Parses"])): 118 | pidxAnswers = entry["Parses"][pidx]["Answers"] 119 | prec,rec,f1 = CalculatePRF1(pidxAnswers,predAnswers) 120 | if f1 > bestf1: 121 | bestf1 = f1 122 | bestf1Rec = rec 123 | bestf1Prec = prec 124 | 125 | f1sum += bestf1 126 | recSum += bestf1Rec 127 | precSum += bestf1Prec 128 | 129 | pred = {} 130 | pred['qid'] = id 131 | pred['precision'] = bestf1Prec 132 | pred['recall'] = bestf1Rec 133 | pred['f1'] = bestf1 134 | prediction_res.append(pred) 135 | 136 | if bestf1 == 1.0: 137 | numCorrect += 1 138 | 139 | print("Number of questions:", int(total)) 140 | print("Average precision over questions: %.3f" % (precSum / total)) 141 | print("Average recall over questions: %.3f" % (recSum / total)) 142 | print("Average f1 over questions (accuracy): %.3f" % (f1sum / total)) 143 | print("F1 of average recall and average precision: %.3f" % (2 * (recSum / total) * (precSum / total) / (recSum / total + precSum / total))) 144 | print("True accuracy (ratio of questions answered exactly correctly): %.3f" % (numCorrect / total)) 145 | res = f'Number of questions:{int(total)}\n, Average precision over questions: {(precSum / total)}\n, Average recall over questions: {(recSum / total)}\n, Average f1 over questions (accuracy): {(f1sum / total)}\n, F1 of average recall and average precision: {(2 * (recSum / total) * (precSum / total) / (recSum / total + precSum / total))}\n, True accuracy (ratio of questions answered exactly correctly): {(numCorrect / total)}' 146 | dirname = os.path.dirname(pred_data) 147 | filename = os.path.basename(pred_data) 148 | dump_json(prediction_res, os.path.join(dirname, f'{filename}_new.json')) 149 | return res 150 | 151 | if __name__ == "__main__": 152 | parser = argparse.ArgumentParser() 153 | parser.add_argument( 154 | "--split", 155 | type=str, 156 | required=True, 157 | help="split to operate on, can be `test`, `dev` and `train`", 158 | ) 159 | parser.add_argument( 160 | "--pred_file", type=str, default=None, help="prediction results file" 161 | ) 162 | args = parser.parse_args() 163 | 164 | webqsp_evaluate_valid_results(args) -------------------------------------------------------------------------------- /entity_retrieval/BERT_NER/bert.py: -------------------------------------------------------------------------------- 1 | """BERT NER Inference. 
2 | 3 | Copyright 2021, Ohio State University (Yu Gu) 4 | Yu Gu 5 | Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | """ 7 | 8 | from __future__ import absolute_import, division, print_function 9 | 10 | import json 11 | import os 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from nltk import word_tokenize 16 | # from pytorch_transformers import (BertConfig, BertForTokenClassification, 17 | # BertTokenizer) 18 | from transformers import (BertConfig, BertForTokenClassification, BertTokenizer) 19 | 20 | class BertNer(BertForTokenClassification): 21 | 22 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, valid_ids=None, device=None): 23 | sequence_output = self.bert(input_ids, token_type_ids, attention_mask, head_mask=None)[0] 24 | batch_size, max_len, feat_dim = sequence_output.shape 25 | valid_output = torch.zeros(batch_size, max_len, feat_dim, dtype=torch.float32, 26 | device=device if torch.cuda.is_available() else 'cpu') 27 | for i in range(batch_size): 28 | jj = -1 29 | for j in range(max_len): 30 | if valid_ids[i][j].item() == 1: # valid position for tagging 31 | jj += 1 32 | valid_output[i][jj] = sequence_output[i][j] # use the output of the valid position 33 | sequence_output = self.dropout(valid_output) 34 | logits = self.classifier(sequence_output) 35 | return logits 36 | 37 | 38 | class Ner: 39 | """ 40 | Ner model using transformers 41 | """ 42 | 43 | def __init__(self, model_dir: str, device="cuda:0"): 44 | self.model, self.tokenizer, self.model_config = self.load_model(model_dir) 45 | self.label_map = self.model_config["label_map"] 46 | self.max_seq_length = self.model_config["max_seq_length"] 47 | self.label_map = {int(k): v for k, v in self.label_map.items()} 48 | # self.device = "cuda" if torch.cuda.is_available() else "cpu" 49 | self.device = device 50 | self.model = self.model.to(self.device) 51 | self.model.eval() 52 | 53 | def load_model(self, model_dir: str, model_config: str = "model_config.json"): 54 | model_config = os.path.join(model_dir, model_config) 55 | model_config = json.load(open(model_config)) 56 | model = BertNer.from_pretrained(model_dir) 57 | tokenizer = BertTokenizer.from_pretrained(model_dir, do_lower_case=model_config["do_lower"]) 58 | return model, tokenizer, model_config 59 | 60 | def tokenize(self, text: str): 61 | """ tokenize input 62 | 63 | return: tokens:List[str], valid_positions:List[int] 64 | """ 65 | words = word_tokenize(text) # nltk.word_tokenize 66 | tokens = [] 67 | valid_positions = [] 68 | for i, word in enumerate(words): 69 | token = self.tokenizer.tokenize(word) # BertTokenizer tokenize 70 | tokens.extend(token) 71 | for i in range(len(token)): 72 | if i == 0: 73 | valid_positions.append(1) # 1 for start position of a token 74 | else: 75 | valid_positions.append(0) 76 | return tokens, valid_positions 77 | 78 | def preprocess(self, text: str): 79 | """ preprocess text with tokenization, special token addition and conversion from tokens to ids 80 | 81 | @param text: text to preprocess 82 | @return: input_ids, input_mask, segment_ids, valid_positions 83 | 84 | """ 85 | tokens, valid_positions = self.tokenize(text) 86 | ## insert "[CLS]" 87 | tokens.insert(0, "[CLS]") 88 | valid_positions.insert(0, 1) 89 | ## insert "[SEP]" 90 | tokens.append("[SEP]") 91 | valid_positions.append(1) 92 | segment_ids = [] 93 | for i in range(len(tokens)): 94 | segment_ids.append(0) # 0 for first segment, 1 for second segment 95 | input_ids = self.tokenizer.convert_tokens_to_ids(tokens) # bert
tokenizer convert tokens to ids 96 | input_mask = [1] * len(input_ids) # input mask 97 | while len(input_ids) < self.max_seq_length: 98 | """padding""" 99 | input_ids.append(0) 100 | input_mask.append(0) 101 | segment_ids.append(0) 102 | valid_positions.append(0) 103 | return input_ids, input_mask, segment_ids, valid_positions 104 | 105 | def predict(self, text: str): 106 | input_ids, input_mask, segment_ids, valid_ids = self.preprocess(text) 107 | input_ids = torch.tensor([input_ids], dtype=torch.long, device=self.device) 108 | input_mask = torch.tensor([input_mask], dtype=torch.long, device=self.device) 109 | segment_ids = torch.tensor([segment_ids], dtype=torch.long, device=self.device) 110 | valid_ids = torch.tensor([valid_ids], dtype=torch.long, device=self.device) 111 | with torch.no_grad(): 112 | """inference""" 113 | logits = self.model(input_ids, segment_ids, input_mask, valid_ids, device=self.device) 114 | logits = F.softmax(logits, dim=2) 115 | logits_label = torch.argmax(logits, dim=2) # argmax 116 | logits_label = logits_label.detach().cpu().numpy().tolist()[0] 117 | 118 | # confidence 119 | logits_confidence = [values[label].item() for values, label in zip(logits[0], logits_label)] 120 | 121 | logits = [] 122 | pos = 0 123 | for index, mask in enumerate(valid_ids[0]): 124 | if index == 0: 125 | continue 126 | if mask == 1: # valid position 127 | logits.append((logits_label[index - pos], logits_confidence[index - pos])) 128 | else: 129 | pos += 1 130 | logits.pop() 131 | 132 | labels = [(self.label_map[label], confidence) for label, confidence in logits] 133 | words = word_tokenize(text) # nltk.tokenize 134 | if len(labels) != len(words): 135 | print(text) 136 | print(words) 137 | print(labels) 138 | assert len(labels) == len(words) 139 | output = [{"word": word, "tag": label, "confidence": confidence} for word, (label, confidence) in 140 | zip(words, labels)] 141 | return output 142 | -------------------------------------------------------------------------------- /entity_retrieval/bert_entity_linker.py: -------------------------------------------------------------------------------- 1 | """ 2 | An approach to identify entities in a query. Uses a custom index for entity information. 3 | 4 | Copyright 2015, University of Freiburg. 
5 | 6 | Elmar Haussmann 7 | """ 8 | import sys 9 | import os 10 | # sys.path.append("..") 11 | # print(os.getcwd()) 12 | # sys.path.append(os.getcwd()) 13 | # print(f'cwd:{os.getcwd()}') 14 | sys.path.append("..") 15 | # print(f'sys.path:{sys.path}') 16 | 17 | import json 18 | from tqdm import tqdm 19 | from pathlib import Path 20 | # from entity_linking.google_kg_api import get_entity_from_surface 21 | from entity_retrieval.BERT_NER.bert import Ner 22 | from entity_retrieval import surface_index_memory 23 | from entity_retrieval.aqqu_entity_linker import IdentifiedEntity 24 | from entity_retrieval.aqqu_util import normalize_entity_name, remove_prefixes_from_name, remove_suffixes_from_name 25 | 26 | 27 | path = str(Path(__file__).parent.absolute()) 28 | 29 | 30 | class BertEntityLinker: 31 | """ 32 | Identify entities in a query, using bert ner model 33 | """ 34 | 35 | def __init__(self, surface_index, 36 | # Better name it max_entities_per_surface 37 | max_entities_per_tokens=4, 38 | model_path="/BERT_NER/out_base_gq/", 39 | device="cuda:0" 40 | ): 41 | self.surface_index = surface_index 42 | self._model = Ner(path + model_path, device) 43 | 44 | def get_mentions(self, question: str): 45 | question = question.lower() 46 | output = self._model.predict(question) 47 | mentions = [] 48 | current_mention = [] 49 | for i, token in enumerate(output): 50 | if token['tag'][0] == 'B': 51 | current_mention.append(token['word']) 52 | elif token['tag'][0] == 'I': 53 | current_mention.append(token['word']) 54 | else: 55 | if len(current_mention) > 0: 56 | mentions.append(' '.join(current_mention)) 57 | current_mention = [] 58 | if i == len(output) - 1 and len(current_mention) > 0: 59 | mentions.append(' '.join(current_mention)) 60 | 61 | for i, mention in enumerate(mentions): 62 | # word_tokenize from nltk will change the left " to ``, which is pretty weird. Fix it here 63 | mentions[i] = mention.replace('``', '"').replace("''", '"') 64 | 65 | return mentions 66 | 67 | def _text_matches_main_name(self, entity, text): 68 | 69 | """ 70 | Check if the entity name is a perfect match on the text. 
71 | :param entity: 72 | :param text: 73 | :return: 74 | """ 75 | text = normalize_entity_name(text) 76 | text = remove_prefixes_from_name(text) 77 | name = remove_suffixes_from_name(entity.name) 78 | name = normalize_entity_name(name) 79 | name = remove_prefixes_from_name(name) 80 | if name == text: 81 | return True 82 | return False 83 | 84 | def get_entities(self, utterance): 85 | entities = {} 86 | identified_entities = self.identify_entities(utterance) 87 | 88 | for entity in identified_entities: 89 | entities[entity.entity.id] = entity.entity.name 90 | 91 | return entities 92 | 93 | def identify_entities(self, utterance, min_surface_score=0.3): 94 | """ 95 | identify entities from utterance 96 | 97 | @param utterance: text to identify entities 98 | @param min_surface_score: min score for surface form 99 | @return: a sorted list of IdentifiedEntity(surface_form, entity_name, mid, mid.score, surface_form.score, is_perfect_match) 100 | """ 101 | mentions = self.get_mentions(utterance) 102 | identified_entities = [] 103 | mids = set() 104 | for mention in mentions: 105 | # use facc1 106 | entities = self.surface_index.get_entities_for_surface(mention) 107 | 108 | # use google kg api 109 | # entities = get_entity_from_surface(mention) 110 | # if len(entities) == 0: 111 | # entities = get_entity_from_surface(mention) 112 | 113 | # empty entities, strip or add 'the' 114 | if len(entities) == 0 and len(mention) > 3 and mention.split()[0] == 'the': 115 | # strip 'the' 116 | mention = mention[3:].strip() 117 | entities = self.surface_index.get_entities_for_surface(mention) 118 | elif len(entities) == 0 and f'the {mention}' in utterance: 119 | # add 'the' 120 | mention = f'the {mention}' 121 | entities = self.surface_index.get_entities_for_surface(mention) 122 | 123 | # no entities retrieved, continue 124 | if len(entities) == 0: 125 | continue 126 | 127 | entities = sorted(entities, key=lambda x:x[1], reverse=True) 128 | for i, (e, surface_score) in enumerate(entities): 129 | if e.id in mids: 130 | continue 131 | # Ignore entities with low surface score. But if even the top 1 entity is lower than the threshold, 132 | # we keep it 133 | if surface_score < min_surface_score and i > 0: 134 | continue 135 | perfect_match = False 136 | # Check if the main name of the entity exactly matches the text. 137 | # I only use the label as surface, so the perfect match is always True 138 | if self._text_matches_main_name(e, mention): 139 | perfect_match = True 140 | ie = IdentifiedEntity(mention, 141 | e.name, e, e.score, surface_score, 142 | perfect_match) 143 | # self.boost_entity_score(ie) 144 | identified_entities.append(ie) 145 | mids.add(e.id) 146 | 147 | identified_entities = sorted(identified_entities, key=lambda x: x.surface_score, reverse=True) 148 | 149 | return identified_entities 150 | 151 | 152 | if __name__ == '__main__': 153 | # an example of how to use our BERT entity linker 154 | 155 | surface_index = surface_index_memory.EntitySurfaceIndexMemory( 156 | "data/entity_list_file_freebase_complete_all_mention", 157 | "data/surface_map_file_freebase_complete_all_mention", 158 | "data/freebase_complete_all_mention") 159 | entity_linker = BertEntityLinker(surface_index, model_path="/BERT_NER/trained_ner_model/", device="cuda:1") 160 | identified_entities = entity_linker.identify_entities("safety and tolerance of intermittent intravenous and oral zidovudine therapy in human immunodeficiency virus-infected pediatric patients. pediatric zidovudine phase i study group.
is a medical trial for what?") 161 | for ie in identified_entities: 162 | print(f'mention:{ie.mention}\nname:{ie.name}\nmid:{ie.entity.id}') 163 | 164 | -------------------------------------------------------------------------------- /inputDataset/gen_mtl_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | 4 | IGONORED_DOMAIN_LIST = ['type', 'common', 'kg', 'dataworld'] 5 | 6 | 7 | def _tokenize_relation(r): 8 | return r.replace('.', ' ').replace('_', ' ').split() 9 | 10 | class MTLGenerationExample: 11 | """ 12 | Multi Task Generation Example 13 | """ 14 | def __init__(self, dict_data): 15 | """ Initialize from dict data""" 16 | self.ID = dict_data['ID'] 17 | self.question = dict_data['question'] 18 | self.comp_type = dict_data['comp_type'] 19 | self.sparql = dict_data['sparql'] 20 | self.sexpr = dict_data['sexpr'] 21 | self.normed_sexpr = dict_data['normed_sexpr'] 22 | self.gold_entity_map = dict_data['gold_entity_map'] 23 | self.gold_relation_map = dict_data['gold_relation_map'] 24 | self.gold_type_map = dict_data['gold_type_map'] 25 | self.cand_relation_list = dict_data['cand_relation_list'] 26 | self.answer = dict_data['answer'] 27 | self.cand_entity_list = dict_data['cand_entity_list'] 28 | self.disambiguated_cand_entity = dict_data['disambiguated_cand_entity'] 29 | 30 | 31 | def __str__(self) -> str: 32 | return f'{self.question}\n\t->{self.normed_sexpr}' 33 | 34 | def __repr__(self) -> str: 35 | return self.__str__() 36 | 37 | 38 | class MTLGenDataset(Dataset): 39 | """Dataset for MTLGeneration""" 40 | 41 | def __init__( 42 | self, 43 | examples, 44 | tokenizer, 45 | do_lower=True, 46 | normalize_relations=False, 47 | max_src_len=256, 48 | max_tgt_len=196, 49 | add_prefix=False, 50 | ): 51 | # super().__init__() 52 | self.examples = examples 53 | self.tokenizer = tokenizer 54 | self.do_lower = do_lower 55 | self.normalize_relations = normalize_relations 56 | self.max_src_len = max_src_len 57 | self.max_tgt_len = max_tgt_len 58 | self.add_prefix = add_prefix 59 | self.REL_TOKEN = ' [REL] ' 60 | self.ENT_TOKEN = ' [ENT] ' 61 | self.LITERAL_TOKEN = ' [LIT] ' 62 | self.SEPERATOR = ' | ' 63 | 64 | def __len__(self): 65 | return len(self.examples) 66 | 67 | def __getitem__(self, index): 68 | example = self.examples[index] 69 | 70 | ID = example.ID 71 | question = example.question 72 | normed_sexpr = example.normed_sexpr 73 | 74 | candidate_relations = [x[0] for x in example.cand_relation_list] 75 | gold_relation_set = set(example.gold_relation_map.keys()) 76 | 77 | relation_labels = [(rel in gold_relation_set) for rel in candidate_relations] 78 | relation_clf_pairs_labels = torch.LongTensor(relation_labels) 79 | 80 | # entity id identifies diffrent entities 81 | gold_entities_ids_set = set([item.lower() for item in example.gold_entity_map.keys()]) 82 | 83 | entity_labels = [(ent['id'] in gold_entities_ids_set) for ent in example.cand_entity_list] 84 | entity_clf_pairs_labels = torch.LongTensor(entity_labels) 85 | 86 | input_src = question 87 | 88 | if self.do_lower: 89 | input_src = input_src.lower() 90 | normed_sexpr = normed_sexpr.lower() 91 | 92 | gen_src = input_src 93 | if self.add_prefix: 94 | gen_src = 'Translate to S-Expression: ' + input_src 95 | if self.do_lower: 96 | gen_src = gen_src.lower() 97 | tokenized_src = self.tokenizer( 98 | gen_src, 99 | max_length=self.max_src_len, 100 | truncation=True, 101 | return_tensors='pt', 102 | ).data['input_ids'].squeeze(0) 103 | 104 | # 
Concatenate candidate entities & relations 105 | gen_src_concatenated = input_src 106 | if self.add_prefix: 107 | gen_src_concatenated = 'Translate to S-Expression: ' + gen_src_concatenated 108 | for rel in example.cand_relation_list: 109 | logits = float(rel[1]) 110 | if logits > 0.0: 111 | if self.normalize_relations: 112 | gen_src_concatenated += self.REL_TOKEN + _textualize_relation(rel[0]) 113 | else: 114 | gen_src_concatenated += self.REL_TOKEN + rel[0] 115 | gen_src_concatenated += self.SEPERATOR 116 | 117 | for ent in example.disambiguated_cand_entity: 118 | gen_src_concatenated += self.ENT_TOKEN + ent['label'] 119 | gen_src_concatenated += self.SEPERATOR 120 | 121 | if self.do_lower: 122 | gen_src_concatenated = gen_src_concatenated.lower() 123 | 124 | tokenized_src_concatenated = self.tokenizer( 125 | gen_src_concatenated, 126 | max_length=self.max_src_len, 127 | truncation=True, 128 | return_tensors='pt', 129 | ).data['input_ids'].squeeze(0) 130 | 131 | # concatenate golden entities/relations 132 | gen_src_golden_concatenated = input_src 133 | if self.add_prefix: 134 | gen_src_golden_concatenated = 'Translate to S-Expression: ' + gen_src_golden_concatenated 135 | for rel in example.gold_relation_map: 136 | if self.normalize_relations: 137 | gen_src_golden_concatenated += self.REL_TOKEN + _textualize_relation(rel) 138 | else: 139 | gen_src_golden_concatenated += self.REL_TOKEN + rel 140 | gen_src_golden_concatenated += self.SEPERATOR 141 | for mid in example.gold_entity_map: 142 | gen_src_golden_concatenated += self.ENT_TOKEN + example.gold_entity_map[mid] # concat label 143 | gen_src_golden_concatenated += self.SEPERATOR 144 | 145 | if self.do_lower: 146 | gen_src_golden_concatenated = gen_src_golden_concatenated.lower() 147 | 148 | tokenized_src_golden_concatenated = self.tokenizer( 149 | gen_src_golden_concatenated, 150 | max_length=self.max_src_len, 151 | truncation=True, 152 | return_tensors='pt', 153 | ).data['input_ids'].squeeze(0) 154 | 155 | 156 | with self.tokenizer.as_target_tokenizer(): 157 | tokenized_tgt = self.tokenizer( 158 | normed_sexpr, 159 | max_length=self.max_tgt_len, 160 | truncation=True, 161 | return_tensors='pt', 162 | ).data['input_ids'].squeeze(0) 163 | 164 | tokenized_relation_clf_pairs = [] 165 | 166 | for cand_rel in candidate_relations: 167 | if self.normalize_relations: 168 | cand_rel = _textualize_relation(cand_rel) 169 | 170 | rel_src = input_src 171 | if self.add_prefix: 172 | rel_src = 'Relation Classification: ' + rel_src 173 | 174 | if self.do_lower: 175 | rel_src = rel_src.lower() 176 | cand_rel = cand_rel.lower() 177 | 178 | tokenized_relation_pair = self.tokenizer( 179 | rel_src, 180 | cand_rel, 181 | max_length=self.max_src_len, 182 | truncation=True, 183 | return_tensors='pt', 184 | ).data['input_ids'].squeeze(0) 185 | 186 | tokenized_relation_clf_pairs.append(tokenized_relation_pair) 187 | 188 | tokenized_entity_clf_pairs = [] 189 | question_tokens = question.split(' ') 190 | 191 | for cand_ent in example.cand_entity_list: 192 | label = cand_ent['label'] 193 | def key_func(r): 194 | r_tokens = _tokenize_relation(r) 195 | overlapping_val = len(set(question_tokens) & set(r_tokens)) 196 | return( 197 | -overlapping_val 198 | ) 199 | 200 | one_hop_relations = cand_ent['1hop_relations'] 201 | one_hop_relations = [x for x in one_hop_relations if x.split('.')[0] not in IGONORED_DOMAIN_LIST] 202 | one_hop_relations = sorted(one_hop_relations, key=lambda x: key_func(x)) 203 | 204 | ent_info = label 205 | for rel in one_hop_relations[:3]: 206 | 
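# (editorial) ent_info ends up as "<label> | <rel_1> | <rel_2> | <rel_3>",
# e.g. a hypothetical "justin bieber | people.person.profession | ...";
# the three one-hop relations kept are those sharing the most tokens with
# the question, since key_func sorts by descending token overlap.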
if self.normalize_relations: 207 | ent_info += (self.SEPERATOR + _textualize_relation(rel)) 208 | else: 209 | ent_info += (self.SEPERATOR + rel) 210 | 211 | ent_src = input_src 212 | if self.add_prefix: 213 | ent_src = 'Entity Classification: ' + input_src 214 | 215 | if self.do_lower: 216 | ent_info = ent_info.lower() 217 | ent_src = ent_src.lower() 218 | 219 | tokenized_entity_pair = self.tokenizer( 220 | ent_src, 221 | ent_info, 222 | max_length=self.max_src_len, 223 | truncation=True, 224 | return_tensors='pt' 225 | ).data['input_ids'].squeeze(0) 226 | 227 | tokenized_entity_clf_pairs.append(tokenized_entity_pair) 228 | 229 | return ( 230 | tokenized_src, 231 | tokenized_tgt, 232 | tokenized_relation_clf_pairs, 233 | relation_clf_pairs_labels, 234 | [input_src], 235 | candidate_relations, 236 | tokenized_entity_clf_pairs, 237 | entity_clf_pairs_labels, 238 | example.cand_entity_list, 239 | tokenized_src_concatenated, 240 | tokenized_src_golden_concatenated 241 | ) 242 | 243 | 244 | 245 | def _textualize_relation(r): 246 | """return a relation string with '_' and '.' replaced""" 247 | if "_" in r: # replace "_" with " " 248 | r = r.replace("_", " ") 249 | if "." in r: # replace "." with " , " 250 | r = r.replace(".", " , ") 251 | return r 252 | 253 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified on the basis of [RNG-KBQA](https://github.com/salesforce/rng-kbqa). 3 | The original license information is as follows: 4 | Copyright (c) 2021, salesforce.com, inc. 5 | All rights reserved. 6 | SPDX-License-Identifier: BSD-3-Clause 7 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 8 | """ 9 | 10 | 11 | import logging 12 | import random 13 | 14 | import torch 15 | import numpy as np 16 | 17 | from transformers import ( 18 | AutoTokenizer, 19 | AutoConfig, 20 | ) 21 | 22 | 23 | from entity_retrieval.bert_ranker import BertForCandidateRanking 24 | 25 | 26 | MODEL_TYPE_DICT = { 27 | 'bert': BertForCandidateRanking, 28 | } 29 | 30 | ELQ_SERVICE_URL = "http://210.28.134.34:5688/entity_linking" 31 | FREEBASE_SPARQL_WRAPPER_URL = "http://210.28.134.34:8890/sparql" 32 | FREEBASE_ODBC_PORT = "1111" 33 | 34 | def set_seed(args): 35 | random.seed(args.seed) 36 | np.random.seed(args.seed) 37 | torch.manual_seed(args.seed) 38 | if args.n_gpu > 0: 39 | torch.cuda.manual_seed_all(args.seed) 40 | 41 | 42 | def to_list(tensor): 43 | return tensor.detach().cpu().tolist() 44 | 45 | def register_args(parser): 46 | # Required parameters 47 | parser.add_argument( 48 | "--dataset", 49 | default=None, 50 | type=str, 51 | required=True, 52 | help="dataset to operate on", 53 | ) 54 | parser.add_argument( 55 | "--model_type", 56 | default=None, 57 | type=str, 58 | required=True, 59 | help="Model type", 60 | ) 61 | parser.add_argument( 62 | "--model_name_or_path", 63 | default=None, 64 | type=str, 65 | required=True, 66 | help="Path to pretrained model or model identifier from huggingface.co/models", 67 | ) 68 | parser.add_argument( 69 | "--output_dir", 70 | default=None, 71 | type=str, 72 | required=True, 73 | help="The output directory where the model checkpoints and predictions will be written.", 74 | ) 75 | 76 | # Other parameters 77 | parser.add_argument( 78 | "--data_dir", 79 | default=None, 80 | type=str, 81 | help="The input data dir. Should contain the .json files for the task." 
82 | + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.", 83 | ) 84 | parser.add_argument( 85 | "--train_file", 86 | default=None, 87 | type=str, 88 | help="The input training file. If a data dir is specified, will look for the file there" 89 | + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.", 90 | ) 91 | parser.add_argument( 92 | "--predict_file", 93 | default=None, 94 | type=str, 95 | help="The input evaluation file. If a data dir is specified, will look for the file there" 96 | + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.", 97 | ) 98 | parser.add_argument( 99 | "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name" 100 | ) 101 | parser.add_argument( 102 | "--tokenizer_name", 103 | default="", 104 | type=str, 105 | help="Pretrained tokenizer name or path if not the same as model_name", 106 | ) 107 | parser.add_argument( 108 | "--cache_dir", 109 | default=None, 110 | type=str, 111 | help="Where do you want to store the pre-trained models downloaded from s3", 112 | ) 113 | 114 | parser.add_argument( 115 | "--max_seq_length", 116 | default=96, 117 | type=int, 118 | help="The maximum total input sequence length after WordPiece tokenization. Sequences " 119 | "longer than this will be truncated, and sequences shorter than this will be padded.", 120 | ) 121 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.") 122 | parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.") 123 | parser.add_argument("--do_predict", action="store_true", help="Whether to do prediction.") 124 | parser.add_argument( 125 | "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step." 126 | ) 127 | parser.add_argument( 128 | "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model." 129 | ) 130 | 131 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.") 132 | parser.add_argument( 133 | "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation." 134 | ) 135 | parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") 136 | parser.add_argument( 137 | "--gradient_accumulation_steps", 138 | type=int, 139 | default=1, 140 | help="Number of updates steps to accumulate before performing a backward/update pass.", 141 | ) 142 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 143 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 144 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 145 | parser.add_argument( 146 | "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform." 147 | ) 148 | parser.add_argument( 149 | "--max_steps", 150 | default=-1, 151 | type=int, 152 | help="If > 0: set total number of training steps to perform. 
Override num_train_epochs.", 153 | ) 154 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") 155 | parser.add_argument("--warmup_ratio", default=0.0, type=float, help="Linear warmup over warmup ratio.") 156 | parser.add_argument( 157 | "--verbose_logging", 158 | action="store_true", 159 | help="If true, all of the warnings related to data processing will be printed. " 160 | "A number of warnings are expected for a normal SQuAD evaluation.", 161 | ) 162 | 163 | parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.") 164 | parser.add_argument("--eval_steps", type=int, default=500, help="Eval every X updates steps.") 165 | parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.") 166 | parser.add_argument( 167 | "--eval_all_checkpoints", 168 | action="store_true", 169 | help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number", 170 | ) 171 | parser.add_argument( 172 | "--disable_tqdm", action="store_true", help="Disable tqdm bar" 173 | ) 174 | parser.add_argument("--num_contrast_sample", type=int, default=20, help="number of samples in a batch.") 175 | parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available") 176 | parser.add_argument( 177 | "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory" 178 | ) 179 | parser.add_argument( 180 | "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets" 181 | ) 182 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 183 | 184 | parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") 185 | parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.") 186 | parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.") 187 | 188 | parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features") 189 | 190 | # train curriculum 191 | parser.add_argument("--training_curriculum", default="random",type=str, choices=["random", "bootstrap", "mixbootstrap"]) 192 | parser.add_argument("--bootstrapping_start", default=None, type=int, help="when to start bootstrapping sampling") 193 | parser.add_argument("--bootstrapping_ticks", default=None, type=str, help="when to update scores for bootstrapping in addition to the startpoint") 194 | 195 | # textualizing choices 196 | parser.add_argument("--linear_method", default="vanilla",type=str, choices=["vanilla", "naive_text", "reduct_text"]) 197 | 198 | # logger 199 | parser.add_argument("--logger",default=None, help="logger") 200 | 201 | def validate_args(args): 202 | # validate before loading data 203 | if args.training_curriculum == "random": 204 | args.bootstrapping_update_epochs = [] 205 | else: 206 | assert args.bootstrapping_start is not None 207 | assert args.bootstrapping_start > 0 208 | 209 | if args.bootstrapping_ticks is None: 210 | bootstrapping_update_epochs = [args.bootstrapping_start] 211 | else: 212 | additional_update_epochs = [int(x) for x in args.bootstrapping_ticks.split(',')] 213 | bootstrapping_update_epochs = [args.bootstrapping_start] + additional_update_epochs 214 | args.bootstrapping_update_epochs = bootstrapping_update_epochs 215 | 216 | def 
load_untrained_model(args): 217 | args.model_type = args.model_type.lower() 218 | config = AutoConfig.from_pretrained( 219 | args.config_name if args.config_name else args.model_name_or_path, 220 | cache_dir=args.cache_dir if args.cache_dir else None, 221 | ) 222 | tokenizer = AutoTokenizer.from_pretrained( 223 | args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, 224 | do_lower_case=args.do_lower_case, 225 | cache_dir=args.cache_dir if args.cache_dir else None, 226 | ) 227 | model_class = MODEL_TYPE_DICT[args.model_type] 228 | model = model_class.from_pretrained( 229 | args.model_name_or_path, 230 | from_tf=bool(".ckpt" in args.model_name_or_path), 231 | config=config, 232 | cache_dir=args.cache_dir if args.cache_dir else None, 233 | ) 234 | 235 | return config, tokenizer, model 236 | 237 | def get_model_class(args): 238 | return MODEL_TYPE_DICT[args.model_type] -------------------------------------------------------------------------------- /components/expr_parser.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2021, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | 9 | from components.utils import * 10 | 11 | def tokenize_s_expr(expr): 12 | expr = expr.replace('(', ' ( ') 13 | expr = expr.replace(')', ' ) ') 14 | toks = expr.split(' ') 15 | toks = [x for x in toks if len(x)] 16 | return toks 17 | 18 | def extract_entities(expr): 19 | toks = tokenize_s_expr(expr) 20 | return [x for x in toks if x.startswith('m.')] 21 | 22 | 23 | def extract_relations(expr): 24 | toks = tokenize_s_expr(expr) 25 | return [x for x in toks if ('.' in x) and (not x.startswith('m.')) and (not '^^' in x)] 26 | 27 | class ASTNode: 28 | UNARY = 'unary' 29 | BINARY = 'binary' 30 | def __init__(self, construction, val, data_type, fields): 31 | self.construction = construction 32 | self.val = val 33 | # unary or binary 34 | self.data_type = data_type 35 | self.fields = fields 36 | 37 | # determined after construction 38 | self.depth = -1 39 | self.level = -1 40 | 41 | def assign_depth_and_level(self, level=0): 42 | self.level = level 43 | if self.fields: 44 | max_depth = max([x.assign_depth_and_level(level + 1) for x in self.fields]) 45 | self.depth = max_depth + 1 46 | else: 47 | self.depth = 0 48 | return self.depth 49 | 50 | @classmethod 51 | def build(cls, tok, data_type, fields): 52 | if tok == 'AND': 53 | return AndNode(data_type, fields) 54 | elif tok == 'R': 55 | return RNode(data_type, fields) 56 | elif tok == 'COUNT': 57 | return CountNode(data_type, fields) 58 | elif tok == 'JOIN': 59 | return JoinNode(data_type, fields) 60 | elif tok in ['le', 'lt', 'ge', 'gt']: 61 | return CompNode(tok, data_type, fields) 62 | elif tok in ['ARGMIN', 'ARGMAX']: 63 | return ArgNode(tok, data_type, fields) 64 | elif tok.startswith('m.'): 65 | return EntityNode(tok, data_type, fields) 66 | elif '^^http://www.w3.org/2001/XMLSchema' in tok: 67 | return ValNode(tok, data_type, fields) 68 | else: 69 | return SchemaNode(tok, data_type, fields) 70 | 71 | def logical_form(self): 72 | if self.depth == 0: 73 | return self.val 74 | else: 75 | fields_str = [x.logical_form() for x in self.fields] 76 | return ' '.join(['(', self.val] + fields_str + [')']) 77 | 78 | # nothing special. 
just fits legacy code input style
79 |     def compact_logical_form(self):
80 |         lf = self.logical_form()
81 |         return lf.replace('( ', '(').replace(' )', ')')
82 | 
83 |     def skeleton_form(self):
84 |         if self.depth == 0:
85 |             return self.construction
86 |         else:
87 |             fields_str = [x.skeleton_form() for x in self.fields]
88 |             return ' '.join(['(', self.construction] + fields_str + [')'])
89 | 
90 |     def logical_form_with_type(self):
91 |         if self.depth == 0:
92 |             return '{}[{}]'.format(self.val, self.data_type)
93 |         else:
94 |             fields_str = [x.logical_form_with_type() for x in self.fields]
95 |             return ' '.join(['(', '{}[{}]'.format(self.val, self.data_type)] + fields_str + [')'])
96 | 
97 |     def __str__(self):
98 |         return self.logical_form()
99 | 
100 |     def __repr__(self):
101 |         return self.logical_form()
102 | 
103 |     def textual_form_core(self):
104 |         raise NotImplementedError('Textual form not implemented for abstract ast node')
105 | 
106 |     def textual_form(self):
107 |         core_form = self.textual_form_core()
108 |         if self.depth == 0 or self.level == 0:
109 |             return core_form
110 |         else:
111 |             return '( ' + core_form + ' )'
112 | 
113 | class AndNode(ASTNode):
114 |     def __init__(self, data_type, fields):
115 |         super().__init__('AND', 'AND', data_type, fields)
116 | 
117 |     def textual_form_core(self):
118 |         # "the xxx that ..." when the first field is a leaf
119 |         if self.fields[0].depth == 0:
120 |             return ' '.join([self.fields[0].textual_form(), 'that', self.fields[1].textual_form()])
121 |         # "xxx and xxx" otherwise
122 |         else:
123 |             return ' '.join([self.fields[0].textual_form(), 'and', self.fields[1].textual_form()])
124 | 
125 | class RNode(ASTNode):
126 |     def __init__(self, data_type, fields):
127 |         super().__init__('R', 'R', data_type, fields)
128 | 
129 |     def textual_form(self):
130 |         # R may only wrap a bare (leaf) relation; only the surface relation is kept
131 |         assert self.depth == 1
132 |         return self.textual_form_core()
133 | 
134 |     def textual_form_core(self):
135 |         return self.fields[0].textual_form() + ' by'
136 | 
137 | class CountNode(ASTNode):
138 |     def __init__(self, data_type, fields):
139 |         super().__init__('COUNT', 'COUNT', data_type, fields)
140 | 
141 |     def textual_form_core(self):
142 |         return 'how many ' + self.fields[0].textual_form()
143 | 
144 | class JoinNode(ASTNode):
145 |     def __init__(self, data_type, fields):
146 |         super().__init__('JOIN', 'JOIN', data_type, fields)
147 | 
148 |     def textual_form_core(self):
149 |         return ' '.join([self.fields[0].textual_form(), self.fields[1].textual_form()])
150 | 
151 | # ARGMIN / ARGMAX
152 | class ArgNode(ASTNode):
153 |     def __init__(self, val, data_type, fields):
154 |         super().__init__('ARG', val, data_type, fields)
155 | 
156 |     def textual_form_core(self):
157 |         prompt = 'with most' if self.val == 'ARGMAX' else 'with least'
158 |         return ' '.join([self.fields[0].textual_form(), prompt, self.fields[1].textual_form()])
159 | 
160 | # lt le gt ge
161 | class CompNode(ASTNode):
162 |     PROMPT_DICT = {
163 |         'gt': 'greater than',
164 |         'ge': 'greater equal',
165 |         'lt': 'less than',
166 |         'le': 'less equal',
167 |     }
168 | 
169 |     def __init__(self, val, data_type, fields):
170 |         super().__init__('COMP', val, data_type, fields)
171 | 
172 |     def textual_form_core(self):
173 |         prompt = CompNode.PROMPT_DICT[self.val]
174 |         return ' '.join([self.fields[0].textual_form(), prompt, self.fields[1].textual_form()])
175 | 
176 | class EntityNode(ASTNode):
177 |     def __init__(self, val, data_type, fields):
178 |         super().__init__('ENTITY', val, data_type, fields)
179 | 
180 |     def textual_form_core(self):
181 |         return self.val
182 | 
183 | class SchemaNode(ASTNode):
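    """Leaf node for a class or relation token; a note added for clarity: its textual form is just the raw token."""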
184 |     def __init__(self, val, data_type, fields):
185 |         super().__init__('SCHEMA', val, data_type, fields)
186 | 
187 |     def textual_form_core(self):
188 |         return self.val
189 | 
190 | class ValNode(ASTNode):
191 |     def __init__(self, val, data_type, fields):
192 |         super().__init__('VAL', val, data_type, fields)
193 | 
194 |     def textual_form_core(self):
195 |         return self.val
196 | 
197 | def _consume_a_node(tokens, cursor, data_type):
198 |     is_root = cursor == 0
199 |     cur_tok = tokens[cursor]
200 |     cursor += 1
201 |     if cur_tok == '(':
202 |         node, cursor = _consume_a_node(tokens, cursor, data_type)
203 |         assert tokens[cursor] == ')'
204 |         cursor += 1
205 |     elif cur_tok == 'AND':
206 |         # left, right, all unary
207 |         left, cursor = _consume_a_node(tokens, cursor, ASTNode.UNARY)
208 |         right, cursor = _consume_a_node(tokens, cursor, ASTNode.UNARY)
209 |         node = ASTNode.build(cur_tok, data_type, [left, right])
210 |     elif cur_tok == 'JOIN':
211 |         # if cur is unary, right is unary; otherwise right is binary
212 |         left, cursor = _consume_a_node(tokens, cursor, ASTNode.BINARY)
213 |         right, cursor = _consume_a_node(tokens, cursor, data_type)
214 |         node = ASTNode.build(cur_tok, data_type, [left, right])
215 |     elif cur_tok == 'ARGMIN' or cur_tok == 'ARGMAX':
216 |         left, cursor = _consume_a_node(tokens, cursor, ASTNode.UNARY)
217 |         right, cursor = _consume_a_node(tokens, cursor, ASTNode.BINARY)
218 |         node = ASTNode.build(cur_tok, data_type, [left, right])
219 |     elif cur_tok == 'le' or cur_tok == 'lt' or cur_tok == 'ge' or cur_tok == 'gt':
220 |         left, cursor = _consume_a_node(tokens, cursor, ASTNode.UNARY)
221 |         right, cursor = _consume_a_node(tokens, cursor, ASTNode.UNARY)
222 |         node = ASTNode.build(cur_tok, data_type, [left, right])
223 |     elif cur_tok == 'R':
224 |         child, cursor = _consume_a_node(tokens, cursor, ASTNode.BINARY)
225 |         node = ASTNode.build(cur_tok, data_type, [child])
226 |     elif cur_tok == 'COUNT':
227 |         child, cursor = _consume_a_node(tokens, cursor, ASTNode.UNARY)
228 |         node = ASTNode.build(cur_tok, data_type, [child])
229 |     else:
230 |         # leaf symbol:
231 |         # class / relation
232 |         # value
233 |         # entity
234 |         node = ASTNode.build(cur_tok, ASTNode.UNARY, [])
235 |     if is_root:
236 |         node.assign_depth_and_level()
237 | 
238 |     return node, cursor
239 | 
240 | # top level: AND, ARG, COUNT; cannot be JOIN, LE, R
241 | def parse_s_expr(expr):
242 |     tokens = tokenize_s_expr(expr)
243 |     # assert tokens[0] == '(' and tokens[-1] == ')'
244 |     # tokens = tokens[1:-1]
245 |     ast, cursor = _consume_a_node(tokens, 0, 'unary')
246 |     assert cursor == len(tokens)
247 |     assert ' '.join(tokens) == ast.logical_form()
248 |     return ast
249 | 
250 | def textualize_s_expr(expr):
251 |     ast = parse_s_expr(expr)
252 |     return ast.textual_form()
253 |     # print(ast.logical_form_with_type())
254 | 
255 | def simplify_textual_form(expr):
256 |     toks = expr.split(' ')
257 | 
258 |     norm_toks = []
259 |     for t in toks:
260 |         # normalize entity
261 |         if t.startswith('m.'):
262 |             pass
263 |         elif 'XMLSchema' in t:
264 |             pass
265 |         elif '.' in t:
266 |             meta_relations = t.split('.')
267 |             t = meta_relations[-1]
268 |             if '.' in t:
269 |                 t = t.replace('.', ' , ')
270 |             if '_' in t:  # replace "_" with " ", mirroring _textualize_relation
271 |                 t = t.replace('_', ' ')
272 |         # normalize type
273 |         norm_toks.append(t)
274 |     return ' '.join(norm_toks)
275 | 
276 | def test_text_converter():
277 |     # dataset = load_json('outputs/grailqa_v1.0_dev.json')
278 |     dataset = load_json('outputs/grailqa_v1.0_train.json')
279 |     import random
280 |     random.seed(123)
281 |     random.shuffle(dataset)
282 | 
283 |     templates = set()
284 |     for i, data in enumerate(dataset[:100]):
285 |         s_expr = data['s_expression']
286 |         question = data['question']
287 |         ast = parse_s_expr(data['s_expression'])
288 |         skeleton = ast.skeleton_form()
289 |         templates.add(skeleton)
290 |         textual_expr = ast.textual_form()
291 | 
292 |         simplified_expr = simplify_textual_form(textual_expr)
293 |         # if ('AND' in skeleton):
294 |         #     sim_skeleton = skeleton.replace('AND SCHEMA', '')
295 |         #     if 'AND' in sim_skeleton:
296 |         #         print('------------------------------')
297 |         #         print(question)
298 |         #         print(s_expr)
299 |         # if ('COMP' in skeleton):
300 |         print(f'----------------{i}--------------')
301 |         print(question)
302 |         print(s_expr)
303 |         print(textual_expr)
304 |         print(simplified_expr)
305 | 
306 |     # print(len(templates))
307 |     # for t in templates:
308 |     #     if 'COMP' in t:
309 |     #         print(t)
310 |     #     # if 'ARG' in t:
311 |     #     #     print(t)
312 | 
313 | 
-------------------------------------------------------------------------------- /relation_retrieval/bi-encoder/run_bi_encoder.py: --------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import os
4 | import numpy as np
5 | import pandas as pd
6 | from torch.utils.data import DataLoader, Dataset
7 | from torch.cuda.amp import GradScaler
8 | from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup
9 | from tqdm import tqdm
10 | from sklearn.metrics import accuracy_score
11 | import copy
12 | import argparse
13 | 
14 | from biencoder import BiEncoderModule
15 | BLANK_TOKEN = '[BLANK]'
16 | 
17 | 
18 | def _parse_args():
19 |     parser = argparse.ArgumentParser()
20 |     parser.add_argument('--add_special_tokens', default=False, action='store_true', help='True when masking entity mentions')
21 |     parser.add_argument('--dataset_type', default="CWQ", type=str, help="CWQ | WebQSP")
22 |     parser.add_argument('--model_save_path', default='data/', type=str)
23 |     parser.add_argument('--max_len', default=32, type=int, help="32 for CWQ, 80 for WebQSP with richRelation, 28 for LC")
24 |     parser.add_argument('--batch_size', default=4, type=int, help="4 for CWQ")
25 |     parser.add_argument('--epochs', default=1, type=int, help="1 for CWQ, 3 for WebQSP")
26 |     parser.add_argument('--log_dir', default='log/', type=str)
27 |     parser.add_argument('--cache_dir', default='bert-base-uncased', type=str)
28 |     args = parser.parse_args()
29 |     return args
30 | 
31 | 
32 | def data_process(dataset_type):
33 |     if dataset_type == "CWQ":
34 |         train_df = pd.read_csv('data/CWQ/relation_retrieval/bi-encoder/CWQ.train.sampled.tsv', delimiter='\t', dtype={"id": int, "question": str, "relation": str, 'label': int})
35 |         dev_df = pd.read_csv('data/CWQ/relation_retrieval/bi-encoder/CWQ.dev.sampled.tsv', delimiter='\t', dtype={"id": int, "question": str, "relation": str, 'label': int})
36 |     else:
37 |         # WebQSP has no dev split here; use the model saved in the last epoch
38 |         train_df = pd.read_csv('data/WebQSP/relation_retrieval/bi-encoder/WebQSP.train.sampled.tsv', delimiter='\t', dtype={"id": int, "question": str, "relation": str, 'label': int})
39 |         dev_df = None
40 | 
41 |     return train_df, dev_df
42 | 
43 | 
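# --- illustrative note (a sketch, not in the original file) -----------------
# CustomDataset below assumes the sampled TSVs pack one question per block of
# `sample_size` consecutive rows (100 by default), with exactly one row
# labeled 1 per block. A hypothetical quick sanity check of that layout:
#
#   df = pd.read_csv('data/CWQ/relation_retrieval/bi-encoder/CWQ.train.sampled.tsv',
#                    delimiter='\t')
#   assert (df.groupby(df.index // 100)['label'].sum() == 1).all()
# -----------------------------------------------------------------------------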
def set_seed(seed): 44 | torch.manual_seed(seed) 45 | torch.cuda.manual_seed_all(seed) 46 | torch.backends.cudnn.deterministic = True 47 | torch.backends.cudnn.benchmark = False 48 | np.random.seed(seed) 49 | random.seed(seed) 50 | os.environ['PYTHONHASHSEED'] = str(seed) 51 | 52 | 53 | def evaluate(model, device, dataloader): 54 | model.eval() 55 | 56 | mean_loss = 0 57 | count = 0 58 | golden_truth = [] 59 | preds = [] 60 | 61 | with torch.no_grad(): 62 | for question_token_ids, question_attn_masks, question_token_type_ids, relations_token_ids, relations_attn_masks, relations_token_type_ids, golden_id in tqdm(dataloader): 63 | scores, loss = model( 64 | question_token_ids.to(device), 65 | question_attn_masks.to(device), 66 | question_token_type_ids.to(device), 67 | relations_token_ids.to(device), 68 | relations_attn_masks.to(device), 69 | relations_token_type_ids.to(device), 70 | golden_id.to(device) 71 | ) 72 | mean_loss += loss 73 | count += 1 74 | pred_id = torch.argmax(scores, dim=1) 75 | # print('pred_id: {}'.format(pred_id.shape)) 76 | # print('golden_id: {}'.format(golden_id.shape)) 77 | preds += pred_id.tolist() 78 | golden_truth += golden_id.tolist() 79 | 80 | accuracy = accuracy_score(golden_truth, preds) 81 | 82 | return mean_loss / count, accuracy 83 | 84 | 85 | class CustomDataset(Dataset): 86 | def __init__(self, data, maxlen, tokenizer=None, bert_model='bert-base-uncased', sample_size=100): 87 | self.data = data 88 | self.sample_size = sample_size 89 | self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(bert_model) 90 | self.maxlen = maxlen 91 | 92 | def __len__(self): 93 | return int(len(self.data) / self.sample_size) 94 | 95 | def __getitem__(self, index): 96 | start = self.sample_size * index 97 | end = min(self.sample_size*(index+1), len(self.data)) 98 | question = str(self.data.loc[start, 'question']) 99 | relations = [str(self.data.loc[i, 'relation']) for i in range(start, end)] 100 | golden_id = [i-start for i in range(start, end) if self.data.loc[i, 'label'] == 1] 101 | assert len(golden_id) == 1, print(start, end) 102 | 103 | encoded_question = self.tokenizer( 104 | question, 105 | padding='max_length', 106 | truncation=True, 107 | max_length=self.maxlen, 108 | return_tensors='pt' 109 | ) 110 | encoded_relations = [self.tokenizer( 111 | relation, 112 | padding='max_length', 113 | truncation=True, 114 | max_length=self.maxlen, 115 | return_tensors='pt' 116 | ) for relation in relations] 117 | 118 | question_token_ids = encoded_question['input_ids'].squeeze(0) # tensor of token ids 119 | question_attn_masks = encoded_question['attention_mask'].squeeze(0) # binary tensor with "0" for padded values and "1" for the other values 120 | question_token_type_ids = encoded_question['token_type_ids'].squeeze(0) # binary tensor with "0" for the 1st sentence tokens & "1" for the 2nd sentence tokens 121 | 122 | relations_token_ids = torch.cat([encoded_relation['input_ids'] for encoded_relation in encoded_relations], 0) 123 | relations_attn_masks = torch.cat([encoded_relation['attention_mask'] for encoded_relation in encoded_relations], 0) 124 | relations_token_type_ids = torch.cat([encoded_relation['token_type_ids'] for encoded_relation in encoded_relations], 0) 125 | 126 | return question_token_ids, question_attn_masks, question_token_type_ids, relations_token_ids, relations_attn_masks, relations_token_type_ids, golden_id[0] 127 | 128 | 129 | def train_bert(model, opti, lr, lr_scheduler, train_loader, val_loader, epochs, iters_to_accumulate, 
device, log_path, model_save_path, dataset_type):
130 |     nb_iterations = len(train_loader)
131 |     print_every = max(1, nb_iterations // 5)  # guard against modulo-by-zero on very small loaders
132 |     # open the log lazily so log_w is always defined, even without a log path
133 |     log_w = open(log_path, 'w') if log_path else None
134 |     scaler = GradScaler()
135 |     best_loss = np.Inf
136 |     best_epoch = 1
137 | 
138 |     for ep in range(epochs):
139 |         model.train()
140 |         running_loss = 0.0
141 | 
142 |         for it, (question_token_ids, question_attn_masks, question_token_type_ids, relations_token_ids, relations_attn_masks, relations_token_type_ids, golden_id) in enumerate(tqdm(train_loader)):
143 |             scores, loss = model(
144 |                 question_token_ids.to(device),
145 |                 question_attn_masks.to(device),
146 |                 question_token_type_ids.to(device),
147 |                 relations_token_ids.to(device),
148 |                 relations_attn_masks.to(device),
149 |                 relations_token_type_ids.to(device),
150 |                 golden_id.to(device)
151 |             )
152 |             loss = loss / iters_to_accumulate
153 |             scaler.scale(loss).backward()
154 | 
155 |             if (it + 1) % iters_to_accumulate == 0:
156 |                 scaler.step(opti)
157 |                 # Updates the scale for the next iteration.
158 |                 scaler.update()
159 |                 # Adjust the learning rate based on the number of iterations.
160 |                 lr_scheduler.step()
161 |                 # Clear gradients
162 |                 opti.zero_grad()
163 | 
164 |             running_loss += loss.item()
165 |             if (it + 1) % print_every == 0:  # Print training loss information
166 |                 print()
167 |                 print("Iteration {}/{} of epoch {} complete. Loss : {} "
168 |                       .format(it+1, nb_iterations, ep+1, running_loss / print_every))
169 | 
170 |                 running_loss = 0.0
171 | 
172 |         if val_loader:
173 |             val_loss, accuracy = evaluate(model, device, val_loader)
174 |             print("Epoch {} complete! Validation Loss : {}".format(ep+1, val_loss))
175 |             print("Accuracy on dev data: {}\n".format(accuracy))
176 |             if log_w:
177 |                 log_w.write("Epoch {} complete! Validation Loss : {}\n".format(ep+1, val_loss))
178 |                 log_w.write("Accuracy on dev data: {}\n".format(accuracy))
179 |         # Recording validation loss, while still saving a model for every epoch
180 |         model_copy = copy.deepcopy(model)
181 |         if val_loader and val_loss < best_loss:  # only meaningful when a dev set exists
182 |             print("Best validation loss improved from {} to {}".format(best_loss, val_loss))
183 |             print()
184 |             best_loss = val_loss
185 |             best_epoch = ep+1
186 | 
187 |         model_path = os.path.join(model_save_path, '{}_ep_{}.pt'.format(dataset_type, ep+1))
188 |         torch.save(model_copy.state_dict(), model_path)
189 |         print("The model has been saved in {}".format(model_path))
190 | 
191 |     if log_w:
192 |         log_w.close()
193 |     print('Best epoch is: {}, with validation loss: {}'.format(best_epoch, best_loss))
194 |     del loss
195 |     torch.cuda.empty_cache()
196 | 
197 | 
198 | def main(args):
199 |     bert_model = args.cache_dir
200 |     freeze_bert = False
201 |     maxlen = args.max_len
202 |     bs = args.batch_size
203 |     iters_to_accumulate = 2  # gradient accumulation adds gradients over an effective batch of size: bs * iters_to_accumulate. 
If set to "1", you get the usual batch size 204 | lr = 2e-5 # learning rate 205 | epochs = args.epochs 206 | log_path = os.path.join(args.log_dir, 'log.txt') 207 | 208 | if args.add_special_tokens: 209 | print('add special tokens') 210 | tokenizer = AutoTokenizer.from_pretrained(bert_model) 211 | special_tokens_dict = {'additional_special_tokens': [BLANK_TOKEN]} 212 | tokenizer.add_special_tokens(special_tokens_dict) 213 | else: 214 | tokenizer = AutoTokenizer.from_pretrained(bert_model) 215 | 216 | set_seed(1) 217 | print("Reading training data...") 218 | train_df, dev_df = data_process(args.dataset_type) 219 | print(train_df.shape) 220 | train_set = CustomDataset(train_df, maxlen, tokenizer=tokenizer, bert_model=bert_model) 221 | train_loader = DataLoader(train_set, batch_size=bs, num_workers=2) 222 | if dev_df is not None: 223 | print("Reading validation data...") 224 | print(dev_df.shape) 225 | val_set = CustomDataset(dev_df, maxlen, tokenizer=tokenizer, bert_model=bert_model) 226 | val_loader = DataLoader(val_set, batch_size=bs, num_workers=2) 227 | else: 228 | val_loader = None 229 | 230 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 231 | model = BiEncoderModule(device, bert_model=bert_model, tokenizer=tokenizer, freeze_bert=freeze_bert) 232 | model.to(device) 233 | 234 | opti = AdamW(model.parameters(), lr=lr, weight_decay=1e-2) 235 | num_warmup_steps = 0 # The number of steps for the warmup phase. 236 | num_training_steps = epochs * len(train_loader) # The total number of training steps 237 | t_total = (len(train_loader) // iters_to_accumulate) * epochs # Necessary to take into account Gradient accumulation 238 | lr_scheduler = get_linear_schedule_with_warmup(optimizer=opti, num_warmup_steps=num_warmup_steps, num_training_steps=t_total) 239 | 240 | train_bert(model, opti, lr, lr_scheduler, train_loader, val_loader, epochs, iters_to_accumulate, device, log_path, args.model_save_path, args.dataset_type) 241 | 242 | 243 | if __name__=='__main__': 244 | args = _parse_args() 245 | print(args) 246 | main(args) -------------------------------------------------------------------------------- /entity_retrieval/surface_index_memory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Provides access to entities via IDs (MIDs) and surface forms (aliases). 3 | 4 | Each entity is assigned an ID equivalent to the byte offset in the entity list 5 | file. A hashmap stores a mapping from MID to this offset. Additionally, 6 | another hashmap stores a mapping from surface form to this offset, along with 7 | a score. 8 | Matched entities with additional info (scores, other aliases) are then read 9 | from the list file using the found offset. This avoids keeping all entities 10 | with unneeded info in RAM. 11 | 12 | Note: this can be improved in terms of required RAM. 13 | 14 | Copyright 2015, University of Freiburg. 15 | 16 | Elmar Haussmann 17 | """ 18 | import mmap 19 | import logging 20 | import os 21 | import array 22 | import marshal 23 | 24 | # import globals 25 | # from common.globals_args import fn_cwq_file 26 | # from common.hand_files import write_set 27 | from entity_retrieval import aqqu_entity_linker 28 | from entity_retrieval.aqqu_util import normalize_entity_name 29 | import collections 30 | 31 | logger = logging.getLogger(__name__) 32 | logging.basicConfig(level="INFO") 33 | 34 | class EntitySurfaceIndexMemory(object): 35 | """A memory based index for finding entities. 
36 |     Remember to delete the old _mid_vocab and _surface_index files if the underlying data files are updated (or choose a different prefix).
37 |     """
38 | 
39 |     def __init__(self, entity_list_file, surface_map_file, entity_index_prefix):
40 |         self.entity_list_file = entity_list_file
41 |         self.surface_map_file = surface_map_file
42 | 
43 |         # mid_vocabulary: {mid: offset}
44 |         self.mid_vocabulary = self._get_entity_vocabulary(entity_index_prefix)
45 |         # surface_index: {normalized surface form: [offset1, score1, offset2, score2, ...]}
46 |         self.surface_index = self._get_surface_index(entity_index_prefix)
47 | 
48 |         self.entities_mm_f = open(entity_list_file, 'r')
49 |         self.entities_mm = mmap.mmap(self.entities_mm_f.fileno(), 0, access=mmap.ACCESS_READ)
50 |         logger.info("Done initializing surface index.")
51 | 
52 |     def _get_entity_vocabulary(self, index_prefix):
53 |         """Return vocabulary by building a new one or reading an existing one.
54 | 
55 |         :param index_prefix:
56 |         :return:
57 |         """
58 |         vocab_file = index_prefix + "_mid_vocab"
59 |         if os.path.isfile(vocab_file):
60 |             """
61 |             Mid vocabulary file already exists
62 |             """
63 |             logger.info("Loading entity vocabulary from disk.")
64 |             vocabulary = marshal.load(open(vocab_file, 'rb'))
65 |         else:
66 |             """
67 |             Mid vocabulary does not exist, build it.
68 |             """
69 |             vocabulary = self._build_entity_vocabulary()
70 |             logger.info("Writing entity vocabulary to disk.")
71 |             marshal.dump(vocabulary, open(vocab_file, 'wb'))
72 |         return vocabulary
73 | 
74 |     def _get_surface_index(self, index_prefix):
75 |         """Return surface index by building a new one or reading an existing one.
76 | 
77 |         :param index_prefix:
78 |         :return:
79 |         """
80 |         surface_index_file = index_prefix + "_surface_index"
81 |         if os.path.isfile(surface_index_file):
82 |             logger.info("Loading surfaces from disk.")
83 |             surface_index = marshal.load(open(surface_index_file, 'rb'))
84 |         else:
85 |             surface_index = self._build_surface_index()
86 |             logger.info("Writing entity surfaces to disk.")
87 |             marshal.dump(surface_index, open(surface_index_file, 'wb'))
88 |         return surface_index
89 | 
90 |     def _build_surface_index(self):
91 |         """Build the surface index.
92 | 
93 |         Reads from the surface map on disk and creates a map from
94 |         surface_form -> offset, score ....
95 | 
96 |         :return:
97 |         """
98 |         n_lines = 0
99 |         surface_index = dict()
100 |         num_not_found = 0
101 |         with open(self.surface_map_file, 'r', encoding="utf-8") as f:
102 |             for line in f:
103 |                 n_lines += 1
104 |                 if n_lines % 10000 == 0:  # matches the 10k-line granularity of the progress message below
105 |                     logger.info('Building surface-forms (%s/5996)' % (n_lines//10000))
106 |                 try:
107 |                     cols = line.rstrip().split('\t')
108 |                     surface_form = cols[0]  # surface form
109 |                     # surface_form = normalize_entity_name(surface_form)
110 |                     surface_form = normalize_entity_name(surface_form)  # normalized entity name
111 |                     score = float(cols[1])  # popularity score
112 |                     mid = cols[2]  # mid
113 |                     entity_id = self.mid_vocabulary[mid]  # offset
114 |                     if surface_form not in surface_index:
115 |                         surface_form_entries = array.array('d')  # double (8-byte float)
116 |                         surface_index[surface_form] = surface_form_entries  # {surface_form: [entity_id, score, ...]}
117 |                     surface_index[surface_form].append(entity_id)  # entity_id
118 |                     surface_index[surface_form].append(score)  # score
119 |                 except KeyError:
120 |                     num_not_found += 1
121 |                     if num_not_found < 100:
122 |                         logger.warn("Mid %s appears in surface map but "
123 |                                     "not in entity list." % cols[2])
124 |                     elif num_not_found == 100:
125 |                         logger.warn("Suppressing further warnings about "
126 |                                     "unfound mids.")
127 |                 if n_lines % 5000000 == 0:
128 |                     logger.info('Stored %s surface-forms.' 
% n_lines) 129 | logger.warn("%s entity appearances in surface map w/o mapping to " 130 | "entity list" % num_not_found) 131 | return surface_index 132 | 133 | def _build_entity_vocabulary(self): 134 | """Create mapping from MID to offset/ID. 135 | 136 | :return: 137 | """ 138 | logger.info("Building entity mid vocabulary.") 139 | mid_vocab = dict() # {mid:offset} 140 | num_lines = 0 141 | # Remember the offset for each entity. 142 | with open(self.entity_list_file, 'r',encoding="utf-8") as f: 143 | # m=mmap.mmap(fileno, length[, flags[, prot[, access[, offset]]]]) 144 | mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) # create a mmap object 145 | offset = mm.tell() # return the pointer 146 | line = mm.readline() # readline 147 | while line: 148 | num_lines += 1 149 | if num_lines % 5000000 == 0: 150 | logger.info('Read %s lines' % num_lines) 151 | cols = line.decode().strip().split('\t') 152 | mid = cols[0] # mid 153 | mid_vocab[mid] = offset # offset 154 | offset = mm.tell() 155 | line = mm.readline() 156 | return mid_vocab 157 | 158 | def get_entity_for_mid(self, mid): 159 | """Returns the entity object for the MID or None if the MID is unknown. 160 | 161 | :param mid: 162 | :return: 163 | """ 164 | try: 165 | offset = self.mid_vocabulary[mid] 166 | entity = self._read_entity_from_offset(int(offset)) 167 | return entity 168 | except KeyError: 169 | logger.warn("Unknown entity mid: '%s'." % mid) 170 | return None 171 | 172 | def get_entities_for_surface(self, surface): 173 | """Return all entities for the surface form. 174 | :param surface: 175 | :return: 176 | """ 177 | # I think we are going to make the mentions in our dataset case sensitive 178 | # surface = normalize_entity_name(surface) 179 | surface = normalize_entity_name(surface) 180 | try: 181 | # Only when read from an existing surface_index, bytestr is a byte string. If it's just created 182 | # in this call, then bytestr is an array 183 | bytestr = self.surface_index[surface] # bytestr 184 | if isinstance(bytestr, array.array): 185 | ids_array = bytestr 186 | else: 187 | ids_array = array.array('d') 188 | ids_array.frombytes(bytestr) # [offset1,surface_score1,...] 189 | result = [] 190 | i = 0 191 | while i < len(ids_array) - 1: 192 | offset = ids_array[i] 193 | surface_score = ids_array[i + 1] 194 | entity = self._read_entity_from_offset(int(offset)) 195 | # Check if the main name of the entity exactly matches the text. 196 | result.append((entity, surface_score)) 197 | i += 2 198 | return result 199 | except KeyError: 200 | return [] 201 | 202 | @staticmethod 203 | def _string_to_entity(line): 204 | """Instantiate entity from string representation. 205 | 206 | :param line: 207 | :return: 208 | """ 209 | line = line.decode('utf-8') 210 | cols = line.strip().split('\t') 211 | mid = cols[0] 212 | name = cols[1] 213 | score = int(cols[2]) 214 | aliases = cols[3:] 215 | return aqqu_entity_linker.KBEntity(name, mid, score, aliases) 216 | 217 | def _read_entity_from_offset(self, offset): 218 | """Read entity string representation from offset. 
219 | 220 | :param offset: 221 | :return: 222 | """ 223 | self.entities_mm.seek(offset) 224 | l = self.entities_mm.readline() 225 | return self._string_to_entity(l) 226 | 227 | # get second element of a list 228 | def get_indexrange_entity_el_pro_one_mention(self, mention, top_k=10): 229 | tuple_list = self.get_entities_for_surface(mention) 230 | if not tuple_list: 231 | return collections.OrderedDict() 232 | entities_dict = dict() 233 | for entity, surface_score in tuple_list: 234 | entities_dict[entity.id] = surface_score 235 | entities_tuple_list = sorted(entities_dict.items(), key=lambda d:d[1], reverse=True) 236 | result_entities_dict = collections.OrderedDict() 237 | for i, (entity_id, surface_score) in enumerate(entities_tuple_list): 238 | i += 1 239 | result_entities_dict[entity_id] = surface_score 240 | if i >= top_k: 241 | break 242 | return result_entities_dict 243 | 244 | if __name__ == '__main__': 245 | # get_aqqu_mids(fn_cwq_file.entity_list_file,fn_cwq_file.surface_map_file,fn_cwq_file.aqqu_entity_contained) 246 | # def get_aqqu_mids(entity_file, surface_file, aqqu_entityall_file): 247 | # mids = set() 248 | # with open(entity_file, 'r', encoding="utf-8") as f: 249 | # mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) 250 | # line = mm.readline() 251 | # while line: 252 | # cols = line.decode().strip().split('\t') 253 | # mid = cols[0] 254 | # mids.add(mid) 255 | # line = mm.readline() 256 | # with open(surface_file, 'r', encoding="utf-8") as f: 257 | # for line in f: 258 | # cols = line.rstrip().split('\t') 259 | # mid = cols[2] 260 | # mids.add(mid) 261 | # write_set(mids, aqqu_entityall_file) 262 | # main() 263 | # logging.basicConfig( 264 | # format='%(asctime)s : %(levelname)s : %(module)s : %(message)s', level=logging.INFO) 265 | # for entity, surface_score in ( 266 | # entity_linking_aqqu_index.get_entities_for_surface("taylor lautner")): # Albert Einstein 267 | # print(entity.id, surface_score) 268 | # for entity, surface_score in (entity_linking_aqqu_index.get_entities_for_surface('Agusan del Sur')): 269 | # print(entity.id, surface_score) 270 | # mention_to_entities('Agusan del Sur', top_k=10) 271 | # print(mention_to_entities('2010 Formula One World Championship', top_k=10)) 272 | # print(mention_to_entities('Theresa Russo', top_k=10)) 273 | # tuple_list.sort(key=takeSecond) 274 | # print('**************************') 275 | # for entity,surface_score in tuple_list: 276 | # print(entity.id, surface_score) 277 | pass 278 | -------------------------------------------------------------------------------- /detect_and_link_entity.py: -------------------------------------------------------------------------------- 1 | # from typing import final 2 | from tqdm import tqdm 3 | import json 4 | import argparse 5 | from executor.sparql_executor import get_freebase_mid_from_wikiID, get_label, get_label_with_odbc 6 | from entity_retrieval.aqqu_entity_linker import IdentifiedEntity 7 | from entity_retrieval import surface_index_memory 8 | from entity_retrieval.bert_entity_linker import BertEntityLinker 9 | from components.utils import dump_json, load_json, clean_str 10 | import requests 11 | from nltk.tokenize import word_tokenize 12 | import os 13 | from config import ELQ_SERVICE_URL 14 | 15 | """ 16 | This file performs candidate entity linking for CWQ and WebQSP, 17 | using BERT_NER+FACC1 or ELQ. 
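
Example invocations (a sketch; --split is one of train/dev/test, matching the
parser defined below):

    python detect_and_link_entity.py --dataset CWQ --split test --linker FACC1
    python detect_and_link_entity.py --dataset WebQSP --split train --linker ELQ
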
18 | """ 19 | 20 | def _parse_args(): 21 | parser = argparse.ArgumentParser() 22 | 23 | parser.add_argument('--dataset', default='CWQ', help='dataset to perform entity linking, should be CWQ or WebQSP') 24 | parser.add_argument('--split', required=True, help='split to operate on') # the split file: ['dev','test','train'] 25 | parser.add_argument('--linker', default='FACC1', help='linker, should be FACC1 or ELQ') 26 | 27 | parser.add_argument('--server_ip',default=None,required=False, help='server ip for debugger to attach') 28 | parser.add_argument('--server_port',default=None,required=False, help='server port for debugger to attach') 29 | 30 | return parser.parse_args() 31 | 32 | 33 | def to_output_data_format(identified_entity): 34 | """Transform an identified entity to a dict""" 35 | data = {} 36 | data['label'] = identified_entity.name 37 | data['mention'] = identified_entity.mention 38 | data['pop_score'] = identified_entity.score 39 | data['surface_score'] = identified_entity.surface_score 40 | data['id'] = identified_entity.entity.id 41 | data['aliases'] = identified_entity.entity.aliases 42 | data['perfect_match'] = identified_entity.perfect_match 43 | return data 44 | 45 | 46 | def get_all_entity_candidates(linker, utterance): 47 | """get all the entity candidates given an utterance 48 | 49 | @param linker: entity linker 50 | @param utterance: natural language utterance 51 | @return: a list of all candidate entities 52 | """ 53 | mentions = linker.get_mentions(utterance) # get all the mentions detected by ner model 54 | identified_entities = [] 55 | mids = set() 56 | all_entities = [] 57 | for mention in mentions: 58 | results_per_mention = [] 59 | # use facc1 60 | entities = linker.surface_index.get_entities_for_surface(mention) 61 | # use google kg api 62 | if len(entities) == 0 and len(mention) > 3 and mention.split()[0] == 'the': 63 | mention = mention[3:].strip() 64 | entities = linker.surface_index.get_entities_for_surface(mention) 65 | 66 | elif len(entities) == 0 and f'the {mention}' in utterance: 67 | mention = f'the {mention}' 68 | entities = linker.surface_index.get_entities_for_surface(mention) 69 | 70 | if len(entities) == 0: 71 | continue 72 | 73 | entities = sorted(entities, key=lambda x:x[1], reverse=True) 74 | for i, (e, surface_score) in enumerate(entities): 75 | if e.id in mids: 76 | continue 77 | # Ignore entities with low surface score. But if even the top 1 entity is lower than the threshold, 78 | # we keep it 79 | perfect_match = False 80 | # Check if the main name of the entity exactly matches the text. 81 | # I only use the label as surface, so the perfect match is always True 82 | if linker._text_matches_main_name(e, mention): 83 | perfect_match = True 84 | ie = IdentifiedEntity(mention, 85 | e.name, 86 | e, 87 | e.score, 88 | surface_score, 89 | perfect_match) 90 | # self.boost_entity_score(ie) 91 | # identified_entities.append(ie) 92 | mids.add(e.id) 93 | results_per_mention.append(to_output_data_format(ie)) 94 | results_per_mention.sort(key=lambda x: x['surface_score'], reverse=True) 95 | all_entities.append(results_per_mention) 96 | 97 | return all_entities 98 | 99 | 100 | def dump_entity_linking_results_for_CWQ(split,keep=10): 101 | 102 | # 1. 
build and load entity linking surface index
103 |     # surface_index_memory.EntitySurfaceIndexMemory(entity_list_file, surface_map_file, output_prefix)
104 |     surface_index = surface_index_memory.EntitySurfaceIndexMemory(
105 |         "data/common_data/facc1/entity_list_file_freebase_complete_all_mention",
106 |         "data/common_data/facc1/surface_map_file_freebase_complete_all_mention",
107 |         "data/common_data/facc1/freebase_complete_all_mention")
108 | 
109 |     # 2. load BERTEntityLinker
110 |     entity_linker = BertEntityLinker(surface_index, model_path="/BERT_NER/trained_ner_model/", device="cuda:0")
111 |     # sanity check
112 |     sanity_checking = get_all_entity_candidates(entity_linker, "the music video stronger was directed by whom")
113 |     print('RUNNING Sanity Checking on utterance')
114 |     print('\t', "the music video stronger was directed by whom")
115 |     print('Checking result', sanity_checking[0][:2])
116 |     print('Checking result should successfully link stronger to some nodes in Freebase (MIDs)')
117 |     print('If the checking result does not look good, please check whether the linker has been set up successfully')
118 | 
119 |     # 3. Load dataset split file
120 |     # datafile = f'data/origin/ComplexWebQuestions_{split}.json'
121 |     datafile = f'data/CWQ/sexpr/CWQ.{split}.expr.json'
122 |     data = load_json(datafile, encoding='utf8')
123 |     print(len(data))
124 | 
125 |     # 4. do entity linking
126 |     el_results = {}
127 |     for ex in tqdm(data, total=len(data)):
128 |         question = ex['question']
129 |         question = clean_str(question)
130 |         qid = ex['ID']
131 |         all_candidates = get_all_entity_candidates(entity_linker, question)
132 |         all_candidates = [x[:keep] for x in all_candidates]
133 |         for instance in all_candidates:
134 |             for x in instance:
135 |                 x['label'] = get_label_with_odbc(x['id'])
136 |         el_results[qid] = all_candidates
137 | 
138 |     # 5. dump the entity linking results
139 |     cand_entity_dir = 'data/CWQ/entity_retrieval/candidate_entities'
140 |     with open(f'{cand_entity_dir}/CWQ_{split}_entities_facc1_unranked.json', encoding='utf8', mode='w') as f:
141 |         json.dump(el_results, f, indent=4)
142 | 
143 | 
144 | def dump_entity_linking_results_for_WebQSP(split, keep=10):
145 | 
146 |     # 1. build and load entity linking surface index
147 |     # surface_index_memory.EntitySurfaceIndexMemory(entity_list_file, surface_map_file, output_prefix)
148 |     surface_index = surface_index_memory.EntitySurfaceIndexMemory(
149 |         "data/common_data/facc1/entity_list_file_freebase_complete_all_mention",
150 |         "data/common_data/facc1/surface_map_file_freebase_complete_all_mention",
151 |         "data/common_data/facc1/freebase_complete_all_mention")
152 | 
153 |     # 2. load BERTEntityLinker
154 |     entity_linker = BertEntityLinker(surface_index, model_path="/BERT_NER/trained_ner_model/", device="cuda:0")
155 |     # sanity check
156 |     sanity_checking = get_all_entity_candidates(entity_linker, "the music video stronger was directed by whom")
157 |     print('RUNNING Sanity Checking on utterance')
158 |     print('\t', "the music video stronger was directed by whom")
159 |     print('Checking result', sanity_checking[0][:2])
160 |     print('Checking result should successfully link stronger to some nodes in Freebase (MIDs)')
161 |     print('If the checking result does not look good, please check whether the linker has been set up successfully')
162 | 
163 |     # 3. Load dataset split file
164 |     datafile = f'data/WebQSP/origin/WebQSP.{split}.json'
165 |     data = load_json(datafile)['Questions']
166 |     print(len(data))
167 | 
168 |     # 4. 
do entity linking 169 | el_results = {} 170 | for ex in tqdm(data, total=len(data)): 171 | question = ex['RawQuestion'] 172 | question = clean_str(question) 173 | qid = ex['QuestionId'] 174 | all_candidates = get_all_entity_candidates(entity_linker, question) 175 | all_candidates = [x[:keep] for x in all_candidates] 176 | for instance in all_candidates: 177 | for x in instance: 178 | x['label'] = get_label_with_odbc(x['id']) 179 | el_results[qid]=all_candidates 180 | 181 | # 5. dump the entity linking results 182 | with open(f'data/WebQSP/entity_retrieval/candidate_entities/WebQSP_{split}_entities_facc1_unranked.json',encoding='utf8',mode='w') as f: 183 | json.dump(el_results, f, indent=4) 184 | 185 | 186 | def get_entity_linking_from_elq(question:str): 187 | res = requests.post( 188 | url=ELQ_SERVICE_URL 189 | , data=json.dumps({'question':question}) 190 | ) 191 | 192 | 193 | cand_ent_list = [] 194 | if res.text: 195 | try: 196 | el_res = json.loads(res.text) 197 | except Exception: 198 | el_res = None 199 | detection_res = el_res.get('detection_res',None) if el_res else None 200 | 201 | if detection_res: 202 | detect = detection_res[0] 203 | mention_num = len(detect['dbpedia_ids']) 204 | 205 | for i in range(mention_num): 206 | cand_num = len(detect['dbpedia_ids'][i]) 207 | for j in range(cand_num): 208 | wiki_id = detect['dbpedia_ids'][i][j] 209 | fb_mid = get_freebase_mid_from_wikiID(wiki_id) 210 | if fb_mid=='': # empty id 211 | continue 212 | label = detect['pred_tuples_string'][i][j][0] 213 | mention = detect['pred_tuples_string'][i][j][1] 214 | # mention_start = detect['pred_triples'][i][1] 215 | # mention_end = detect['pred_triples'][i][2] 216 | # mention = " ".join(question_tokens[max(0,mention_start):min(len(question_tokens),mention_end)]) 217 | score = detect['scores'][i][j] 218 | 219 | el_data = {} 220 | el_data['id']=fb_mid 221 | el_data['label']=label 222 | el_data['mention']=mention 223 | el_data['score']=score 224 | el_data['perfect_match']= (label==mention or label.lower()==mention.lower()) 225 | 226 | 227 | cand_ent_list.append(el_data) 228 | 229 | cand_ent_list.sort(key=lambda x:x['score'],reverse=True) 230 | 231 | return cand_ent_list 232 | 233 | 234 | def dump_entity_linking_results_from_elq_for_CWQ(split, keep=10): 235 | datafile = f'data/CWQ/sexpr/CWQ.{split}.expr.json' 236 | data = load_json(datafile,encoding='utf8') 237 | print(len(data)) 238 | 239 | # Check if ELQ service is running 240 | res = requests.post( 241 | url=ELQ_SERVICE_URL 242 | , data=json.dumps({'question':"what religions are practiced in the country that has the national anthem Afghan National Anthem"}) 243 | ) 244 | try: 245 | res.raise_for_status() 246 | except requests.exceptions.HTTPError as e: 247 | # not 200 248 | print(f"Error: " + str(e)) 249 | return 250 | 251 | # 1. do entity linking by elq 252 | el_results = {} 253 | for ex in tqdm(data, total=len(data), desc='Detecting Entities from ELQ for CWQ'): 254 | question = ex['question'] 255 | question = clean_str(question) 256 | qid = ex['ID'] 257 | all_candidates = get_entity_linking_from_elq(question) 258 | el_results[qid]=all_candidates 259 | 260 | # 2. 
dump the entity linking results 261 | cand_entity_dir = 'data/CWQ/entity_retrieval/candidate_entities' 262 | if not os.path.exists(cand_entity_dir): 263 | os.makedirs(cand_entity_dir) 264 | with open(f'{cand_entity_dir}/CWQ_{split}_cand_entities_elq.json',encoding='utf8',mode='w') as f: 265 | json.dump(el_results, f, indent=4) 266 | 267 | 268 | def dump_entity_linking_results_from_elq_for_WebQSP(split, keep=10): 269 | datafile = f'data/WebQSP/origin/WebQSP.{split}.json' 270 | data = load_json(datafile)['Questions'] 271 | print(len(data)) 272 | 273 | # Check if ELQ service is running 274 | res = requests.post( 275 | url=ELQ_SERVICE_URL 276 | , data=json.dumps({'question':"what religions are practiced in the country that has the national anthem Afghan National Anthem"}) 277 | ) 278 | try: 279 | res.raise_for_status() 280 | except requests.exceptions.HTTPError as e: 281 | # not 200 282 | print(f"Error: " + str(e)) 283 | return 284 | 285 | # 1. do entity linking 286 | el_results = {} 287 | for ex in tqdm(data, total=len(data),desc="Detecting Entities from ELQ for WebQSP"): 288 | question = ex['RawQuestion'] 289 | question = clean_str(question) 290 | qid = ex['QuestionId'] 291 | all_candidates = get_entity_linking_from_elq(question) 292 | el_results[qid]=all_candidates 293 | 294 | # 2. dump the entity linking results 295 | cand_entity_dir = 'data/WebQSP/entity_retrieval/candidate_entities' 296 | if not os.path.exists(cand_entity_dir): 297 | os.makedirs(cand_entity_dir) 298 | with open(f'{cand_entity_dir}/WebQSP_{split}_cand_entities_elq.json',encoding='utf8',mode='w') as f: 299 | json.dump(el_results, f, indent=4) 300 | 301 | 302 | if __name__=='__main__': 303 | args = _parse_args() 304 | 305 | if args.server_ip and args.server_port: 306 | import ptvsd 307 | print('Waiting for debugger to attach...') 308 | ptvsd.enable_attach(address=(args.server_ip,args.server_port),redirect_output=True) 309 | ptvsd.wait_for_attach() 310 | 311 | if args.dataset.lower() == 'cwq': 312 | if args.linker.lower() == 'elq': 313 | dump_entity_linking_results_from_elq_for_CWQ(args.split) 314 | else: 315 | dump_entity_linking_results_for_CWQ(args.split) 316 | elif args.dataset.lower() == 'webqsp': 317 | if args.linker.lower() == 'elq': 318 | dump_entity_linking_results_from_elq_for_WebQSP(args.split) 319 | else: 320 | dump_entity_linking_results_for_WebQSP(args.split) 321 | 322 | 323 | -------------------------------------------------------------------------------- /ablation_exps.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ablation Experiments 3 | """ 4 | import json 5 | import os 6 | import re 7 | from tqdm import tqdm 8 | import argparse 9 | 10 | def _parse_args(): 11 | parser = argparse.ArgumentParser() 12 | 13 | parser.add_argument('action',type=str,help='Action to operate') 14 | parser.add_argument('--dataset', default='CWQ', help='CWQ or WebQSP') 15 | parser.add_argument('--eval_beam_size', default=50, type=int) 16 | parser.add_argument('--model_type', default='full', type=str, help='full or base') 17 | 18 | return parser.parse_args() 19 | 20 | """ 21 | Evaluation functions 22 | """ 23 | def evaluate_question_with_unseen_relation_entity(dataset='CWQ', model_type='full'): 24 | if dataset.lower() == 'cwq': 25 | if model_type.lower() == 'full': 26 | predictions_path = 'exps/CWQ_GMT_KBQA/beam_50_test_4_top_k_predictions.json_gen_sexpr_results.json_new.json' 27 | elif model_type.lower() == 'base': 28 | predictions_path = 
'exps/CWQ_t5_base/beam_50_test_4_top_k_predictions.json_gen_sexpr_results.json_new.json' 29 | else: 30 | return 31 | predictions = load_json(predictions_path) 32 | unseen_qids = load_json('data/CWQ/generation/ablation/test_unseen_entity_or_relation_qids.json') 33 | elif dataset.lower() == 'webqsp': 34 | if model_type.lower() == 'full': 35 | predictions_path = 'exps/WebQSP_GMT_KBQA/beam_50_test_2_top_k_predictions.json_gen_sexpr_results_official_format.json_new.json' 36 | elif model_type.lower() == 'base': 37 | predictions_path = 'exps/WebQSP_t5_base/beam_50_test_2_top_k_predictions.json_gen_sexpr_results_official_format.json_new.json' 38 | else: 39 | return 40 | predictions = load_json(predictions_path) 41 | unseen_qids = load_json('data/WebQSP/generation/ablation/test_unseen_entity_or_relation_qids.json') 42 | else: 43 | unseen_qids = None 44 | return 45 | prediction_map = {pred["qid"]: pred for pred in predictions} 46 | 47 | p_list = [] 48 | r_list = [] 49 | f_list = [] 50 | acc_num = 0 51 | 52 | for qid in tqdm(unseen_qids, total=len(unseen_qids), desc='Evaluating QA performance on question with unseen KB relation/entity'): 53 | if qid not in prediction_map: 54 | p = 0.0 55 | r = 0.0 56 | f = 0.0 57 | else: 58 | p = prediction_map[qid]["precision"] 59 | r = prediction_map[qid]["recall"] 60 | f = prediction_map[qid]["f1"] 61 | if f == 1.0: 62 | acc_num += 1 63 | 64 | p_list.append(p) 65 | r_list.append(r) 66 | f_list.append(f) 67 | 68 | p_average = sum(p_list)/len(p_list) 69 | r_average = sum(r_list)/len(r_list) 70 | f_average = sum(f_list)/len(f_list) 71 | 72 | res = f'Total: {len(p_list)}, ACC:{acc_num/len(p_list)}, AVGP: {p_average}, AVGR: {r_average}, AVGF: {f_average}' 73 | 74 | print(res) 75 | 76 | dirname = os.path.dirname(predictions_path) 77 | with open (os.path.join(dirname,'eval_results_unseen_relation_or_entity.txt'),'w') as f: 78 | f.write(res) 79 | f.flush() 80 | 81 | def entity_relation_linking_evaluation(dataset='CWQ', beam_size=50): 82 | """ 83 | Evaluation of entity linking and relation linking 84 | 85 | Evaluation of entity linking 86 | - Before multi-task: disamb_entities/merged_{dataset}_test_linking_results.json 87 | - After multi-task: {dataset}_test_{test_batch_size}_beam_{beam_size}_candidate_entity_map.json 88 | 89 | Evaluation of relation linking: 90 | - Before multi-task: merged/{dataset}_test.json, in "cand_relation_list" with prediciton logits > 0.0 91 | - After multi-task: 92 | prediction results: beam_{beam_size}_test_{batch_size}_top_k_predictions.json_gen_sexpr_results.json + beam_{beam_size}_test_{batch_size}_top_k_predictions.json_gen_failed_results.json 93 | relation with prediction logit > 0.5 94 | """ 95 | if dataset.lower() == 'cwq': 96 | dirname = 'exps/CWQ_GMT_KBQA' 97 | gen_failed_predictions = load_json(os.path.join(dirname, f'beam_{beam_size}_test_4_top_k_predictions.json_gen_failed_results.json')) 98 | gen_succeed_predictions = load_json(os.path.join(dirname, f'beam_{beam_size}_test_4_top_k_predictions.json_gen_sexpr_results.json')) 99 | predictions = gen_failed_predictions + gen_succeed_predictions 100 | 101 | dataset_content = load_json('data/CWQ/generation/merged/CWQ_test.json') 102 | label_maps = load_json('data/CWQ/generation/label_maps/CWQ_test_label_maps.json') 103 | after_entity_linking_res = load_json(os.path.join(dirname, f'CWQ_test_4_beam_{beam_size}_candidate_entity_map.json')) 104 | before_entity_linking_res = load_json('data/CWQ/entity_retrieval/disamb_entities/CWQ_merged_test_disamb_entities.json') 105 | 106 | elif 
dataset.lower() == 'webqsp':
107 |         dirname = 'exps/WebQSP_GMT_KBQA'
108 |         gen_failed_predictions = load_json(os.path.join(dirname, f'beam_{beam_size}_test_2_top_k_predictions.json_gen_failed_results.json'))
109 |         gen_succeed_predictions = load_json(os.path.join(dirname, f'beam_{beam_size}_test_2_top_k_predictions.json_gen_sexpr_results.json'))
110 |         predictions = gen_failed_predictions + gen_succeed_predictions
111 | 
112 |         dataset_content = load_json('data/WebQSP/generation/merged/WebQSP_test.json')
113 |         label_maps = load_json('data/WebQSP/generation/label_maps/WebQSP_test_label_maps.json')
114 |         after_entity_linking_res = load_json(os.path.join(dirname, f'WebQSP_test_2_beam_{beam_size}_candidate_entity_map.json'))
115 |         before_entity_linking_res = load_json('data/WebQSP/entity_retrieval/disamb_entities/WebQSP_merged_test_disamb_entities.json')
116 | 
117 |     else:
118 |         return
119 | 
120 |     assert len(predictions) == len(label_maps), print(len(predictions), len(label_maps))
121 |     assert len(predictions) == len(dataset_content), print(len(predictions), len(dataset_content))
122 |     golden_entities = []
123 |     golden_relations = []
124 |     after_entity_predictions = []
125 |     after_relation_predictions = []
126 |     before_entity_predictions = []
127 |     before_relation_predictions = []
128 | 
129 |     predictions_map = {pred["qid"]: pred for pred in predictions}
130 |     dataset_content = {example["ID"]: example for example in dataset_content}
131 | 
132 |     for qid in tqdm(label_maps, total=len(label_maps)):
133 |         assert qid in predictions_map, print(qid)
134 |         golden_entities.append(list(label_maps[qid]["entity_label_map"].keys()))
135 |         golden_relations.append(list(label_maps[qid]["rel_label_map"].keys()))
136 |         after_relation_pred_indexes = [idx for (idx, score) in enumerate(predictions_map[qid]["pred"]["pred_relation_clf_labels"]) if float(score) > 0.5]
137 |         before_relation_predictions.append([item[0] for item in dataset_content[qid]["cand_relation_list"] if float(item[1]) > 0.0])
138 |         after_relation_predictions.append([dataset_content[qid]["cand_relation_list"][idx][0] for idx in after_relation_pred_indexes])
139 |         if qid not in before_entity_linking_res:
140 |             before_entity_predictions.append([])
141 |         else:
142 |             before_entity_predictions.append([item['id'] for item in before_entity_linking_res[qid]])
143 |         if qid not in after_entity_linking_res:
144 |             after_entity_predictions.append([])
145 |         else:
146 |             after_entity_predictions.append([item['id'] for item in after_entity_linking_res[qid].values()])
147 | 
148 |     after_relation_linking_res = general_PRF1(after_relation_predictions, golden_relations)
149 |     after_entity_linking_res = general_PRF1(after_entity_predictions, golden_entities)
150 |     before_relation_linking_res = general_PRF1(before_relation_predictions, golden_relations)
151 |     before_entity_linking_res = general_PRF1(before_entity_predictions, golden_entities)
152 | 
153 |     with open(os.path.join(dirname, f'beam_{beam_size}_entity_relation_linking_evaluation.txt'), 'w') as f:
154 |         f.write(f'After multi-task, Relation linking: {after_relation_linking_res}\n')
155 |         f.write(f'After multi-task, Entity linking: {after_entity_linking_res}\n')
156 |         f.write(f'Before multi-task, Relation linking: {before_relation_linking_res}\n')
157 |         f.write(f'Before multi-task, Entity linking: {before_entity_linking_res}\n')
158 |     print(f'After multi-task, Relation linking: {after_relation_linking_res}\n')
159 |     print(f'After multi-task, Entity linking: {after_entity_linking_res}\n')
160 |     print(f'Before multi-task, 
"""
Data preparation
"""
def get_train_unique_relations_entities(dataset="CWQ"):
    """
    Collect the unique golden relations and entities of the train split,
    using files under the label_maps/ folder.
    """
    if dataset.lower() == 'cwq':
        if os.path.exists('data/CWQ/generation/ablation/train_unique_relation_entity.json'):
            return
    elif dataset.lower() == 'webqsp':
        if os.path.exists('data/WebQSP/generation/ablation/train_unique_relation_entity.json'):
            return
    else:
        return

    if dataset.lower() == 'cwq':
        train_label_maps = load_json('data/CWQ/generation/label_maps/CWQ_train_label_maps.json')
    elif dataset.lower() == 'webqsp':
        train_label_maps = load_json('data/WebQSP/generation/label_maps/WebQSP_train_label_maps.json')

    unique_relations = set()
    unique_entities = set()

    for qid in tqdm(train_label_maps, total=len(train_label_maps)):
        data = train_label_maps[qid]
        rels = data["rel_label_map"].keys()
        for rel in rels:
            unique_relations.add(rel)

        entities = data["entity_label_map"].keys()
        for ent in entities:
            unique_entities.add(ent)

    if dataset.lower() == 'cwq':
        dump_json({
            'entities': list(unique_entities),
            'relations': list(unique_relations)
        }, 'data/CWQ/generation/ablation/train_unique_relation_entity.json')
    elif dataset.lower() == 'webqsp':
        dump_json({
            'entities': list(unique_entities),
            'relations': list(unique_relations)
        }, 'data/WebQSP/generation/ablation/train_unique_relation_entity.json')

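# train_unique_relation_entity.json written above has the following shape (the
# mid and relation name are illustrative placeholders):
#
#   {
#       "entities": ["m.0f469q", "..."],
#       "relations": ["people.person.nationality", "..."]
#   }
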
def get_test_unseen_questions(dataset='CWQ'):
    if dataset.lower() == 'cwq':
        if os.path.exists('data/CWQ/generation/ablation/test_unseen_entity_or_relation_qids.json'):
            return
    elif dataset.lower() == 'webqsp':
        if os.path.exists('data/WebQSP/generation/ablation/test_unseen_entity_or_relation_qids.json'):
            return
    else:
        return
    if dataset.lower() == 'cwq':
        test_label_maps = load_json('data/CWQ/generation/label_maps/CWQ_test_label_maps.json')
        train_relation_entity_list = load_json('data/CWQ/generation/ablation/train_unique_relation_entity.json')
    elif dataset.lower() == 'webqsp':
        test_label_maps = load_json('data/WebQSP/generation/label_maps/WebQSP_test_label_maps.json')
        train_relation_entity_list = load_json('data/WebQSP/generation/ablation/train_unique_relation_entity.json')

    entities_list = train_relation_entity_list["entities"]
    relations_list = train_relation_entity_list["relations"]
    unseen_qids = set()

    for qid in tqdm(test_label_maps, total=len(test_label_maps)):
        data = test_label_maps[qid]
        relations = data["rel_label_map"].keys()
        for rel in relations:
            if rel not in relations_list:
                unseen_qids.add(qid)

        entities = data["entity_label_map"].keys()
        for ent in entities:
            if ent not in entities_list:
                unseen_qids.add(qid)

    unseen_qids = list(unseen_qids)
    if dataset.lower() == 'cwq':
        dump_json(unseen_qids, 'data/CWQ/generation/ablation/test_unseen_entity_or_relation_qids.json')
    elif dataset.lower() == 'webqsp':
        dump_json(unseen_qids, 'data/WebQSP/generation/ablation/test_unseen_entity_or_relation_qids.json')


if __name__ == '__main__':
    args = _parse_args()
    action = args.action

    if action.lower() == 'linking_evaluation':
        entity_relation_linking_evaluation(
            args.dataset,
            args.eval_beam_size
        )
    elif action.lower() == 'unseen_evaluation':
        # Data preparation
        get_train_unique_relations_entities(args.dataset)
        get_test_unseen_questions(args.dataset)
        evaluate_question_with_unseen_relation_entity(args.dataset, args.model_type)
--------------------------------------------------------------------------------
/run_relation_data_process.py:
--------------------------------------------------------------------------------
import argparse
import os
import random
import csv
from executor.sparql_executor import get_2hop_relations_with_odbc_wo_filter
from tqdm import tqdm

from components.utils import (
    extract_mentioned_relations_from_sparql,
    load_json,
    dump_json,
    _textualize_relation
)

BLANK_TOKEN = '[BLANK]'

def _parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('action', type=str, help='action to perform')
    parser.add_argument('--dataset', default='CWQ', help='dataset to process, should be CWQ or WebQSP')
    parser.add_argument('--split', default='test', help="split to operate on; ['train', 'dev', 'test']")

    return parser.parse_args()

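# Typical invocation (a sketch; the dispatch over `action` happens in the
# __main__ block at the bottom of this file):
#
#   python run_relation_data_process.py sample_data --dataset CWQ --split train
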
34 | """ 35 | dataset_with_sexpr = load_json(src_path) 36 | merged_data = [] 37 | for example in tqdm(dataset_with_sexpr, total=len(dataset_with_sexpr), desc=f'Extracting golden relations'): 38 | sparql = example['sparql'] 39 | gold_relations = extract_mentioned_relations_from_sparql(sparql) 40 | gold_rel_label_map = {} 41 | for rel in gold_relations: 42 | linear_rel = _textualize_relation(rel) 43 | gold_rel_label_map[rel] = linear_rel 44 | example['gold_relation_map'] = gold_rel_label_map 45 | merged_data.append(example) 46 | 47 | print(f'Wrinting merged data to {tgt_path}...') 48 | dump_json(merged_data,tgt_path,indent=4) 49 | print('Writing finished') 50 | 51 | def extract_golden_relations_webqsp(src_path, tgt_path): 52 | origin_dataset = load_json(src_path) 53 | if "Questions" in origin_dataset: 54 | origin_dataset = origin_dataset["Questions"] 55 | merged_data = [] 56 | for example in tqdm(origin_dataset, total=len(origin_dataset), desc=f'Extracting golden relations'): 57 | gold_rel_label_map = {} 58 | for parse in example['Parses']: 59 | sparql = parse['Sparql'] 60 | gold_relations = extract_mentioned_relations_from_sparql(sparql) 61 | 62 | for rel in gold_relations: 63 | linear_rel = _textualize_relation(rel) 64 | gold_rel_label_map[rel] = linear_rel # no duplicate keys in a dictionary 65 | 66 | example['gold_relation_map'] = gold_rel_label_map 67 | merged_data.append(example) 68 | 69 | print(f'Wrinting merged data to {tgt_path}...') 70 | dump_json(merged_data,tgt_path,indent=4) 71 | print('Writing finished') 72 | 73 | 74 | 75 | 76 | 77 | def sample_data_mask_entity_mention( 78 | golden_file, 79 | entity_linking_file, 80 | all_relations_file, 81 | output_path, 82 | sample_size=100 83 | ): 84 | """ 85 | This method mask entity mentions in question accroding to entity linking results 86 | """ 87 | print('output_path: {}'.format(output_path)) 88 | golden_NLQ_relations = dict() 89 | all_relations = load_json(all_relations_file) 90 | entity_linking_res = load_json(entity_linking_file) 91 | items = load_json(golden_file) 92 | for item in items: 93 | # mask entity mention in question 94 | question = item["question"].lower() 95 | qid = item["ID"] 96 | el_result = entity_linking_res[qid] if qid in entity_linking_res else {} 97 | el_result = {example['id']: example for example in el_result} 98 | for eid in el_result: 99 | mention = el_result[eid]["mention"] 100 | question = question.replace(mention, BLANK_TOKEN) 101 | golden_NLQ_relations[question] = list(item["gold_relation_map"].keys()) 102 | 103 | samples = [] 104 | for question in tqdm(golden_NLQ_relations, total=len(golden_NLQ_relations), desc="Sampling data for Bi-encoder"): 105 | relations = golden_NLQ_relations[question] 106 | diff_rels = list(set(all_relations) - set(relations)) 107 | 108 | negative_rels = random.sample(diff_rels, (sample_size-1) * len(relations)) 109 | # Make sure each batch contains 1 golden relation 110 | for idx in range(len(relations)): 111 | sample = [] 112 | sample.append([question, relations[idx], '1']) 113 | for n_rel in negative_rels[idx * (sample_size-1): (idx+1) * (sample_size-1)]: 114 | sample.append([question, n_rel, '0']) 115 | random.shuffle(sample) 116 | samples.extend(sample) 117 | 118 | with open(output_path, 'w') as f: 119 | header = ['id', 'question', 'relation', 'label'] 120 | writer = csv.writer(f, delimiter='\t') 121 | writer.writerow(header) 122 | idx = 0 123 | for line in samples: 124 | writer.writerow([str(idx)] + line) 125 | idx += 1 126 | 127 | def sample_data_rich_relation( 128 | 
def sample_data_rich_relation(
    golden_file,
    relations_file,
    relation_rich_map_path,
    output_path,
    sample_size=100
):
    """
    Unlike sample_data_mask_entity_mention(), which creates the training/test dataset from
    plain relation names, this function creates it from enriched relations, i.e., relation|label|domain|range
    """
    golden_rich_relations_map = dict()
    relation_rich_map = load_json(relation_rich_map_path)
    all_relations = load_json(relations_file)
    all_rich_relations = list(set(map(
        lambda item: relation_rich_map[item] if item in relation_rich_map else item,
        all_relations
    )))
    examples = load_json(golden_file)

    for example in examples:
        golden_rich_relations_map[example["QuestionId"]] = list(set(map(
            lambda item: relation_rich_map[item] if item in relation_rich_map else item, list(example["gold_relation_map"].keys())
        )))

    qid_example_map = {item["QuestionId"]: item for item in examples}

    samples = []
    for qid in tqdm(golden_rich_relations_map, total=len(golden_rich_relations_map)):
        question = qid_example_map[qid]["ProcessedQuestion"].lower()
        golden_rich_relations = golden_rich_relations_map[qid]
        diff_rich_relations = list(set(all_rich_relations) - set(golden_rich_relations))

        negative_rich = random.sample(diff_rich_relations, (sample_size - 1) * len(golden_rich_relations))
        # Make sure each batch contains 1 golden relation
        for idx in range(len(golden_rich_relations)):
            sample = []
            sample.append([question, golden_rich_relations[idx], '1'])
            for n_lab in negative_rich[idx * (sample_size - 1): (idx + 1) * (sample_size - 1)]:
                sample.append([question, n_lab, '0'])
            assert len(sample) == sample_size
            random.shuffle(sample)
            samples.extend(sample)

    with open(output_path, 'w') as f:
        header = ['id', 'question', 'relation', 'label']
        writer = csv.writer(f, delimiter='\t')
        writer.writerow(header)
        for idx, line in enumerate(samples):
            writer.writerow([str(idx)] + line)


def make_partial_train_dev(train_split_path):
    # Hold out the last 200 shuffled training questions as a partial dev (pdev) set.
    random.seed(17)
    data = load_json(train_split_path)["Questions"]
    random.shuffle(data)
    ptrain = data[:-200]
    pdev = data[-200:]
    print(len(ptrain))
    print(len(pdev))
    dump_json(ptrain, 'data/WebQSP/origin/WebQSP.ptrain.json', indent=4)
    dump_json(pdev, 'data/WebQSP/origin/WebQSP.pdev.json', indent=4)


def validate_data_sequence(
    data_path_1,
    data_path_2,
):
    """Check that two split files list the same question ids in the same order."""
    data_1 = load_json(data_path_1)
    data_2 = load_json(data_path_2)
    if "Questions" in data_1:
        data_1 = data_1["Questions"]
    if "Questions" in data_2:
        data_2 = data_2["Questions"]
    data_1_qids = [example["QuestionId"] for example in data_1]
    data_2_qids = [example["QuestionId"] for example in data_2]
    print(data_1_qids == data_2_qids)

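# An enriched relation concatenates the relation name with its label, domain and
# range, separated by '|'. A hypothetical fb_relation_rich_map.json entry:
#
#   "people.person.nationality": "people.person.nationality|nationality|people.person|location.country"
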
def sample_data(dataset, split):
    if dataset.lower() == 'cwq':
        if not os.path.exists('data/CWQ/relation_retrieval/bi-encoder/CWQ.{}.goldenRelation.json'.format(split)):
            for sp in ['train', 'dev', 'test']:
                extract_golden_relations_cwq(
                    'data/CWQ/sexpr/CWQ.{}.expr.json'.format(sp),
                    'data/CWQ/relation_retrieval/bi-encoder/CWQ.{}.goldenRelation.json'.format(sp)
                )
        if split != 'test':
            sample_data_mask_entity_mention(
                'data/CWQ/relation_retrieval/bi-encoder/CWQ.{}.goldenRelation.json'.format(split),
                'data/CWQ/entity_retrieval/disamb_entities/CWQ_merged_{}_disamb_entities.json'.format(split),
                'data/common_data/freebase_relations_filtered.json',
                'data/CWQ/relation_retrieval/bi-encoder/CWQ.{}.sampled.tsv'.format(split)
            )
    elif dataset.lower() == 'webqsp':
        if not os.path.exists('data/WebQSP/origin/WebQSP.pdev.json'):
            print('Dividing ptrain and pdev')
            make_partial_train_dev('data/WebQSP/origin/WebQSP.train.json')
        if not os.path.exists('data/WebQSP/relation_retrieval/bi-encoder/WebQSP.{}.goldenRelation.json'.format(split)):
            for sp in ['train', 'ptrain', 'pdev', 'test']:
                print('extract golden relations')
                extract_golden_relations_webqsp(
                    'data/WebQSP/origin/WebQSP.{}.json'.format(sp),
                    'data/WebQSP/relation_retrieval/bi-encoder/WebQSP.{}.goldenRelation.json'.format(sp)
                )
                validate_data_sequence(
                    f'data/WebQSP/origin/WebQSP.{sp}.json',
                    f'data/WebQSP/relation_retrieval/bi-encoder/WebQSP.{sp}.goldenRelation.json',
                )

        if split != 'test':
            sample_data_rich_relation(
                'data/WebQSP/relation_retrieval/bi-encoder/WebQSP.{}.goldenRelation.json'.format(split),
                'data/common_data/freebase_relations_filtered.json',
                'data/common_data/fb_relation_rich_map.json',
                'data/WebQSP/relation_retrieval/bi-encoder/WebQSP.{}.sampled.tsv'.format(split)
            )

    prepare_2hop_relations(dataset)


def get_unique_entity_ids(
    train_split_path,
    dev_split_path,
    test_split_path,
    output_path
):
    if train_split_path is not None:
        train_data = load_json(train_split_path)
    else:
        train_data = None

    if dev_split_path is not None:
        dev_data = load_json(dev_split_path)
    else:
        dev_data = None

    if test_split_path is not None:
        test_data = load_json(test_split_path)
    else:
        test_data = None

    unique_entity_ids = set()
    for data in [train_data, dev_data, test_data]:
        if data is None:
            continue
        for example in data:
            for entity_id in example["freebase_ids"]:
                unique_entity_ids.add(entity_id)
    dump_json(list(unique_entity_ids), output_path)


def query_2hop_relations(entity_ids_path, output_path):
    entity_ids = load_json(entity_ids_path)
    res = dict()
    for eid in tqdm(entity_ids, total=len(entity_ids), desc="querying 2-hop relations"):
        in_relations, out_relations, _ = get_2hop_relations_with_odbc_wo_filter(eid)
        relations = list(set(in_relations) | set(out_relations))
        res[eid] = relations
    dump_json(res, output_path)

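# Shape of the entities_2hop_relations.json file written above (mid and relation
# names are illustrative placeholders):
#
#   {
#       "m.0f469q": ["people.person.nationality", "people.person.place_of_birth", "..."]
#   }
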
def construct_question_2hop_relations(
    train_split_path,
    dev_split_path,
    test_split_path,
    entities_2hop_relations_path
):
    if train_split_path is not None:
        train_data = load_json(train_split_path)
    else:
        train_data = None

    if dev_split_path is not None:
        dev_data = load_json(dev_split_path)
    else:
        dev_data = None

    if test_split_path is not None:
        test_data = load_json(test_split_path)
    else:
        test_data = None

    entity_relation_map = load_json(entities_2hop_relations_path)
    for item in [(train_data, train_split_path), (dev_data, dev_split_path), (test_data, test_split_path)]:
        data, path = item
        if data is None:
            continue
        enhanced_linking_results = []
        for example in tqdm(data, total=len(data)):
            two_hop_relations = []
            for entity_id in example["freebase_ids"]:
                if entity_id in entity_relation_map:
                    two_hop_relations.extend(entity_relation_map[entity_id])
            example["two_hop_relations"] = list(set(two_hop_relations))  # remove duplicates
            enhanced_linking_results.append(example)
        print('split: {}, length: {}'.format(path, len(enhanced_linking_results)))
        dump_json(enhanced_linking_results, path[:-5] + '_two_hop_relations.json')  # replace the '.json' suffix


def prepare_2hop_relations(
    dataset
):
    if dataset.lower() == 'webqsp':
        if not os.path.exists('data/WebQSP/relation_retrieval/cross-encoder/rng_kbqa_linking_results/unique_entity_ids.json'):
            print('Collecting unique entity ids')
            get_unique_entity_ids(
                'data/WebQSP/relation_retrieval/cross-encoder/rng_kbqa_linking_results/webqsp_train_rng_el.json',
                None,
                'data/WebQSP/relation_retrieval/cross-encoder/rng_kbqa_linking_results/webqsp_test_rng_el.json',
                'data/WebQSP/relation_retrieval/cross-encoder/rng_kbqa_linking_results/unique_entity_ids.json'
            )
        if not os.path.exists('data/WebQSP/relation_retrieval/cross-encoder/rng_kbqa_linking_results/entities_2hop_relations.json'):
            print('Querying 2-hop relations')
            query_2hop_relations(
                'data/WebQSP/relation_retrieval/cross-encoder/rng_kbqa_linking_results/unique_entity_ids.json',
                'data/WebQSP/relation_retrieval/cross-encoder/rng_kbqa_linking_results/entities_2hop_relations.json'
            )
        if not os.path.exists('data/WebQSP/relation_retrieval/cross-encoder/rng_kbqa_linking_results/webqsp_test_rng_el_two_hop_relations.json'):
            print('Adding two-hop relations to original linking results')
            construct_question_2hop_relations(
                'data/WebQSP/relation_retrieval/cross-encoder/rng_kbqa_linking_results/webqsp_train_rng_el.json',
                None,
                'data/WebQSP/relation_retrieval/cross-encoder/rng_kbqa_linking_results/webqsp_test_rng_el.json',
                'data/WebQSP/relation_retrieval/cross-encoder/rng_kbqa_linking_results/entities_2hop_relations.json'
            )


if __name__ == '__main__':
    args = _parse_args()
    action = args.action

    if action.lower() == 'sample_data':
        sample_data(dataset=args.dataset, split=args.split)
--------------------------------------------------------------------------------