├── utils ├── __init__.py ├── config.py └── common.py ├── generator └── __init__.py ├── preprocess ├── __init__.py ├── preprocess_qa.sh └── merge_ec_file.py ├── qa_baseline ├── __init__.py ├── train_qa_baseline.sh └── config.py ├── qa_evidence_chain ├── __init__.py ├── test_qa_evidence_chain_retrieved.sh └── train_qa_evidence_chain_retrieved.sh ├── evidence_chain ├── extraction │ ├── __init__.py │ └── extraction_kws_ec.sh ├── fine-tune │ ├── __init__.py │ ├── run_evidence_train.sh │ ├── run_evidence_large_eval.sh │ └── run_evidence_train_from_pretrain.sh ├── pretrain │ ├── __init__.py │ ├── run_evidence_pretrain.sh │ ├── run_evidence_large_eval.sh │ └── run_evidence_base_train.sh ├── ranking_model │ ├── __init__.py │ ├── run_evidence_large_eval.sh │ ├── run_evidence_large_train.sh │ └── model.py ├── pretrain_data_process │ ├── __init__.py │ ├── path_extraction │ │ └── __init__.py │ └── negative_search_by_bm25 │ │ ├── __init__.py │ │ ├── create_db.py │ │ ├── elastic_search.py │ │ └── search_doc_from_db.py ├── .DS_Store └── evaluate_ranked_evidence_chain.py ├── retriever ├── retrieval │ ├── data │ │ └── __init__.py │ ├── drqa │ │ ├── __init__.py │ │ ├── retriever │ │ │ ├── __init__.py │ │ │ ├── doc_db.py │ │ │ ├── BM25_doc_ranker.py │ │ │ ├── tfidf_doc_ranker.py │ │ │ └── utils.py │ │ └── drqa_tokenizers │ │ │ ├── __init__.py │ │ │ ├── simple_tokenizer.py │ │ │ ├── spacy_tokenizer.py │ │ │ ├── regexp_tokenizer.py │ │ │ ├── tokenizer.py │ │ │ └── corenlp_tokenizer.py │ ├── models │ │ └── __init__.py │ ├── utils │ │ └── __init__.py │ ├── attention │ │ ├── __pycache__ │ │ │ ├── AFT.cpython-38.pyc │ │ │ ├── BAM.cpython-38.pyc │ │ │ ├── PSA.cpython-38.pyc │ │ │ ├── SGE.cpython-38.pyc │ │ │ ├── ViP.cpython-38.pyc │ │ │ ├── CBAM.cpython-38.pyc │ │ │ ├── DANet.cpython-38.pyc │ │ │ ├── EMSA.cpython-38.pyc │ │ │ ├── CoAtNet.cpython-38.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── SEAttention.cpython-38.pyc │ │ │ ├── SKAttention.cpython-38.pyc │ │ │ ├── A2Atttention.cpython-38.pyc │ │ │ ├── ECAAttention.cpython-38.pyc │ │ │ ├── MUSEAttention.cpython-38.pyc │ │ │ ├── SelfAttention.cpython-38.pyc │ │ │ ├── OutlookAttention.cpython-38.pyc │ │ │ ├── ShuffleAttention.cpython-38.pyc │ │ │ ├── ExternalAttention.cpython-38.pyc │ │ │ └── SimplifiedSelfAttention.cpython-38.pyc │ │ ├── ExternalAttention.py │ │ ├── ECAAttention.py │ │ ├── SEAttention.py │ │ ├── SGE.py │ │ ├── DANet.py │ │ ├── ViP.py │ │ ├── AFT.py │ │ ├── A2Atttention.py │ │ ├── PSA.py │ │ ├── OutlookAttention.py │ │ ├── SKAttention.py │ │ ├── CBAM.py │ │ ├── ShuffleAttention.py │ │ ├── CoAtNet.py │ │ ├── SimplifiedSelfAttention.py │ │ ├── SelfAttention.py │ │ ├── BAM.py │ │ ├── MUSEAttention.py │ │ ├── EMSA.py │ │ └── HaloAttention.py │ ├── __init__.py │ ├── ___tr_dataset.py │ ├── tfidf_retriever.py │ ├── criterions.py │ └── config.py ├── encode_corpus.py └── eval_ottqa_retrieval.py └── readme.md /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /generator/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /qa_baseline/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /qa_evidence_chain/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /evidence_chain/extraction/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /evidence_chain/fine-tune/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /evidence_chain/pretrain/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /retriever/retrieval/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /retriever/retrieval/drqa/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /retriever/retrieval/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /retriever/retrieval/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /evidence_chain/ranking_model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /evidence_chain/pretrain_data_process/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /evidence_chain/pretrain_data_process/path_extraction/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /evidence_chain/pretrain_data_process/negative_search_by_bm25/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /retriever/retrieval/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | #!/usr/bin/env python 7 | # Copyright 2017-present, Facebook, Inc.
8 | # All rights reserved. 9 | # 10 | # This source code is licensed under the license found in the 11 | # LICENSE file in the root directory of this source tree. 12 | 13 | from . import data 14 | from . import models 15 | from . import utils 16 | -------------------------------------------------------------------------------- /retriever/retrieval/drqa/retriever/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | 10 | def get_class(name): 11 | if name == 'tfidf': 12 | return TfidfDocRanker 13 | if name == 'bm25': 14 | return BM25DocRanker 15 | if name == 'sqlite': 16 | return DocDB 17 | raise RuntimeError('Invalid retriever class: %s' % name) 18 | 19 | 20 | from .doc_db import DocDB 21 | from .tfidf_doc_ranker import TfidfDocRanker 22 | from .BM25_doc_ranker import BM25DocRanker -------------------------------------------------------------------------------- /preprocess/preprocess_qa.sh: -------------------------------------------------------------------------------- 1 | export CONCAT_TBS=15 2 | export TABLE_CORPUS=table_corpus_metagptdoc 3 | export MODEL_PATH=./ODQA/retrieval_results/shared_roberta_threecat_basic_mean_one_query 4 | #python ../preprocessing/qa_preprocess.py \ 5 | # --split dev \ 6 | # --reprocess \ 7 | # --add_link \ 8 | # --topk_tbs ${CONCAT_TBS} \ 9 | # --retrieval_results_file ${MODEL_PATH}/dev_output_k100_${TABLE_CORPUS}.json \ 10 | # --qa_save_path ${MODEL_PATH}/dev_preprocessed_${TABLE_CORPUS}_k100cat${CONCAT_TBS}.json \ 11 | # 2>&1 |tee ${MODEL_PATH}/run_logs/${TABLE_CORPUS}/preprocess_qa_dev_k100cat${CONCAT_TBS}.log; 12 | python ../preprocessing/qa_preprocess.py \ 13 | --split train \ 14 | --reprocess \ 15 | --add_link \ 16 | --topk_tbs ${CONCAT_TBS} \ 17 | --retrieval_results_file ${MODEL_PATH}/train_output_k100_${TABLE_CORPUS}.json \ 18 | --qa_save_path ${MODEL_PATH}/train_preprocessed_${TABLE_CORPUS}_k100cat${CONCAT_TBS}.json \ 19 | 2>&1 |tee ${MODEL_PATH}/run_logs/${TABLE_CORPUS}/preprocess_qa_train_k100cat${CONCAT_TBS}.log; -------------------------------------------------------------------------------- /evidence_chain/pretrain/run_evidence_pretrain.sh: -------------------------------------------------------------------------------- 1 | export RUN_ID=5 2 | export BASIC_PATH=./ODQA 3 | export DATA_PATH=${BASIC_PATH}/data/evidence_chain_data/bart_output_for_pretraining/pre-training 4 | export TRAIN_DATA_PATH=evidence_pretrain_train_all_esnegs.jsonl 5 | export DEV_DATA_PATH=evidence_pretrain_dev_all_esnegs.jsonl 6 | export MODEL_PATH=${BASIC_PATH}/model/evidence_chain/roberta-base-pretrain-1neg-weighted-esneg 7 | python run_classifier.py \ 8 | --model_type roberta \ 9 | --model_name_or_path roberta-base \ 10 | --task_name evidence_chain \ 11 | --do_train \ 12 | --do_eval \ 13 | --eval_all_checkpoints \ 14 | --data_dir ${DATA_PATH} \ 15 | --output_dir ${MODEL_PATH} \ 16 | --train_file ${TRAIN_DATA_PATH} \ 17 | --dev_file ${DEV_DATA_PATH} \ 18 | --max_seq_length 512 \ 19 | --per_gpu_train_batch_size 32 \ 20 | --per_gpu_eval_batch_size 64 \ 21 | --learning_rate 3e-5 \ 22 | --evaluate_during_training \ 23 | --overwrite_cache \ 24 | --save_steps 6000 \ 25 | --num_train_epochs 3 \ 26 | 2>&1 | tee log_pretrain_3weighted.log 27 | 
-------------------------------------------------------------------------------- /evidence_chain/extraction/extraction_kws_ec.sh: -------------------------------------------------------------------------------- 1 | #extract keywords for retrieved or ground-truth evidence block 2 | python extract_evidence_chain.py --split train/dev --extract_keywords --kw_extract_type ground-truth/retrieved 3 | #extract ground-truth evidence chain 4 | python extract_evidence_chain.py --split train/dev --extract_evidence_chain 5 | #extract ground-truth evidence chain and save evidence chain for training bart 6 | python extract_evidence_chain.py --split train/dev --extract_evidence_chain --save_bart_training_data 7 | #extract candidate evidence chain 8 | python extract_evidence_chain.py --split train/dev --extract_candidate_evidence_chain 9 | #evaluate ranked evidence chain 10 | python evaluate_ranked_evidence_chain.py 11 | 12 | #extract pretrain data 13 | cd path_extraction 14 | #generate inference data for bart 15 | python parse_table_psg_link.py bart_inference_data 16 | #generate templated pretrain data 17 | #generate fake training and dev evidence chains 18 | python parse_table_psg_link.py fake_pretrain_data 19 | 20 | -------------------------------------------------------------------------------- /evidence_chain/fine-tune/run_evidence_train.sh: -------------------------------------------------------------------------------- 1 | export RUN_ID=5 2 | export BASIC_PATH=./ODQA 3 | export DATA_PATH=${BASIC_PATH}/data/evidence_chain_data/ground-truth-based/ground-truth-evidence-chain 4 | export TRAIN_DATA_PATH=train_gt-ec.jsonl 5 | export DEV_DATA_PATH=dev_gt-ec.jsonl 6 | export MODEL_PATH=${BASIC_PATH}/model/evidence_chain/roberta-large 7 | export PRETRAIN_MODEL_PATH=${BASIC_PATH}/model/evidence_chain/roberta-large-new/checkpoint-best 8 | CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python run_classifier.py \ 9 | --model_type roberta \ 10 | --tokenizer_name roberta-large \ 11 | --config_name roberta-large \ 12 | --model_name_or_path roberta-large \ 13 | --task_name evidence_chain \ 14 | --overwrite_cache \ 15 | --do_train \ 16 | --do_eval \ 17 | --eval_all_checkpoints \ 18 | --data_dir ${DATA_PATH} \ 19 | --output_dir ${MODEL_PATH} \ 20 | --train_file ${TRAIN_DATA_PATH} \ 21 | --dev_file ${DEV_DATA_PATH} \ 22 | --max_seq_length 512 \ 23 | --per_gpu_train_batch_size 16 \ 24 | --per_gpu_eval_batch_size 16 \ 25 | --learning_rate 1e-5 \ 26 | --num_train_epochs 10 27 | -------------------------------------------------------------------------------- /qa_evidence_chain/test_qa_evidence_chain_retrieved.sh: -------------------------------------------------------------------------------- 1 | export RUN_ID=1 2 | export BASIC_DIR=./ODQA 3 | export MODEL_NAME=allenai/longformer-base-4096 4 | export TOKENIZERS_PARALLELISM=false 5 | export TRAIN_DATA_PATH=train_ranked_evidence_chain_for_qa_weighted.json 6 | export DEV_DATA_PATH=dev_ranked_evidence_chain_for_qa_weighted.json 7 | export MODEL_DIR=${MODEL_CHECKPOINT} 8 | export PREFIX=retrieved_blink_ecmask 9 | python train_final_qa_ori.py \ 10 | --do_eval \ 11 | --model_type longformer \ 12 | --evaluate_during_training \ 13 | --data_dir ${BASIC_DIR}/data/qa_with_evidence_chain \ 14 | --output_dir ${BASIC_DIR}/model/qa_model/${MODEL_DIR} \ 15 | --train_file ${TRAIN_DATA_PATH} \ 16 | --dev_file ${DEV_DATA_PATH} \ 17 | --per_gpu_train_batch_size 2 \ 18 | --gradient_accumulation_steps 1 \ 19 | --learning_rate 1e-5 \ 20 | --num_train_epochs 10.0 \ 21 | --max_seq_length 4096 \ 22 | 
--doc_stride 3072 \ 23 | --threads 8 \ 24 | --topk_tbs 15 \ 25 | --model_name_or_path ${MODEL_NAME} \ 26 | --prefix ${PREFIX} \ 27 | --save_cache \ 28 | --overwrite_cache \ 29 | -------------------------------------------------------------------------------- /evidence_chain/pretrain/run_evidence_large_eval.sh: -------------------------------------------------------------------------------- 1 | export RUN_ID=5 2 | export BASIC_PATH=./ODQA 3 | export DATA_PATH=${BASIC_PATH}/data/evidence_chain_data/bart_output_for_pretraining/add_negatives/pre-training/ 4 | export TRAIN_DATA_PATH=evidence_pretrain_train_shortest_esnegs.jsonl 5 | export DEV_DATA_PATH=evidence_pretrain_dev_shortest_esnegs.jsonl 6 | export MODEL_PATH=${BASIC_PATH}/model/evidence_chain/roberta-base-pretrain-1neg-weighted-esneg 7 | python run_classifier.py \ 8 | --model_type roberta \ 9 | --model_name_or_path roberta-base \ 10 | --task_name evidence_chain \ 11 | --do_predict \ 12 | --eval_all_checkpoints \ 13 | --data_dir ${DATA_PATH} \ 14 | --output_dir ${MODEL_PATH} \ 15 | --train_file ${TRAIN_DATA_PATH} \ 16 | --dev_file ${DEV_DATA_PATH} \ 17 | --max_seq_length 512 \ 18 | --per_gpu_train_batch_size 32 \ 19 | --per_gpu_eval_batch_size 64 \ 20 | --learning_rate 3e-5 \ 21 | --evaluate_during_training \ 22 | --overwrite_cache \ 23 | --save_steps 6000 \ 24 | --num_train_epochs 3 \ 25 | --pred_model_dir ${MODEL_PATH}/checkpoint-best \ 26 | --test_file ${DEV_DATA_PATH} \ 27 | --test_result_dir ${MODEL_PATH}/eval_results.txt \ 28 | -------------------------------------------------------------------------------- /qa_evidence_chain/train_qa_evidence_chain_retrieved.sh: -------------------------------------------------------------------------------- 1 | export RUN_ID=1 2 | export BASIC_DIR=./ODQA 3 | export MODEL_NAME=allenai/longformer-base-4096 4 | export TOKENIZERS_PARALLELISM=false 5 | export TRAIN_DATA_PATH=train_ranked_evidence_chain_for_qa_weighted.json 6 | export DEV_DATA_PATH=dev_ranked_evidence_chain_for_qa_weighted.json 7 | export MODEL_DIR=longformer_base_ecmask_weighted_1e-5 8 | export PREFIX=retrieved_blink_ecmask_ec_top1 9 | python train_final_qa_ori.py \ 10 | --do_train \ 11 | --do_eval \ 12 | --model_type longformer \ 13 | --evaluate_during_training \ 14 | --data_dir ${BASIC_DIR}/data/qa_with_evidence_chain \ 15 | --output_dir ${BASIC_DIR}/model/qa_model/${MODEL_DIR} \ 16 | --train_file ${TRAIN_DATA_PATH} \ 17 | --dev_file ${DEV_DATA_PATH} \ 18 | --per_gpu_train_batch_size 2 \ 19 | --gradient_accumulation_steps 1 \ 20 | --learning_rate 1e-5 \ 21 | --num_train_epochs 10.0 \ 22 | --max_seq_length 4096 \ 23 | --doc_stride 3072 \ 24 | --threads 8 \ 25 | --topk_tbs 15 \ 26 | --model_name_or_path ${MODEL_NAME} \ 27 | --save_cache \ 28 | --prefix ${PREFIX} \ 29 | -------------------------------------------------------------------------------- /evidence_chain/pretrain_data_process/negative_search_by_bm25/create_db.py: -------------------------------------------------------------------------------- 1 | from elastic_search_wanjun import MyElastic 2 | import os 3 | import shutil 4 | import argparse 5 | if __name__ =="__main__": 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--pretrain', action='store_true') 8 | parser.add_argument('--finetune', action='store_true') 9 | 10 | args = parser.parse_args() 11 | ES=MyElastic() 12 | ES.delete_one() 13 | ES.create() 14 | # file_path = '/home/t-wzhong/v-wanzho/ODQA/data/data_wikitable/all_passages.json' 15 | if args.pretrain: 16 | file_path = 
'/home/t-wzhong/table-odqa/Data/evidence_chain/pre-training/evidence_output_pretrain_shortest.json' 17 | res = ES.bulk_insert_all_chains_pretrain(file_path) 18 | if args.finetune: 19 | basic_dir = '/home/t-wzhong/v-wanzho/ODQA/data/preprocessed_data/evidence_chain/ground-truth-based/ground-truth-evidence-chain/' 20 | file_paths = [os.path.join(basic_dir,file) for file in ['train_gt-ec-weighted.json','dev_gt-ec-weighted.json']] 21 | res = ES.bulk_insert_all_chains_finetune(file_paths) 22 | 23 | # res = ES.bulk_insert_all_doc(basic_path,file_path_list) 24 | # print(res) 25 | # print(len(res['hits']['hits'])) 26 | 27 | -------------------------------------------------------------------------------- /evidence_chain/ranking_model/run_evidence_large_eval.sh: -------------------------------------------------------------------------------- 1 | export RUN_ID=5 2 | export BASIC_PATH=./ODQA 3 | export DATA_PATH=${BASIC_PATH}/data/preprocessed_data/evidence_chain/ground-truth-based 4 | export TRAIN_DATA_PATH=train_preprocessed_normalized_gtmodify_evichain_nx.json 5 | export DEV_DATA_PATH=dev_preprocessed_normalized_gtmodify_evichain_nx.json 6 | export TEST_DATA_PATH=train_preprocessed_normalized_gtmodify_candidate_evichain_addnoun.json 7 | export MODEL_PATH=${BASIC_PATH}/model/evidence_chain/roberta-base-ecpretrain 8 | export TEST_DATA_PATH=dev_preprocessed_normalized_gtmodify_candidate_evichain_addnoun.json 9 | #export TEST_DATA_PATH=new_chain_blink_dev_evidence_chain_ranking.json 10 | python run_classifier.py \ 11 | --model_type roberta \ 12 | --model_name_or_path roberta-base \ 13 | --task_name evidence_chain \ 14 | --do_predict \ 15 | --eval_all_checkpoints \ 16 | --data_dir ${DATA_PATH} \ 17 | --output_dir ${MODEL_PATH} \ 18 | --train_file ${TRAIN_DATA_PATH} \ 19 | --dev_file ${DEV_DATA_PATH} \ 20 | --max_seq_length 512 \ 21 | --per_gpu_eval_batch_size 40 \ 22 | --overwrite_cache \ 23 | --test_file ${DATA_PATH}/${TEST_DATA_PATH} \ 24 | --pred_model_dir ${MODEL_PATH}/checkpoint-best \ 25 | --test_result_dir ${DATA_PATH}/dev_ecpretrain_ranker_scores.json -------------------------------------------------------------------------------- /qa_baseline/train_qa_baseline.sh: -------------------------------------------------------------------------------- 1 | export RUN_ID=1 2 | export BASIC_DIR=./ODQA 3 | export MODEL_NAME=allenai/longformer-base-4096 4 | export TOKENIZERS_PARALLELISM=false 5 | export TRAIN_DATA_PATH=${BASIC_DIR}/data/preprocessed_data/qa/train_intable_p1_t360.pkl 6 | export DEV_DATA_PATH=${BASIC_DIR}/data/preprocessed_data/qa/dev_intable_p1_t360.pkl 7 | export MODEL_DIR=qa_model_longformer_normalized_gtmodify_rankdoc_top15_newretr 8 | CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python train_final_qa.py \ 9 | --do_train \ 10 | --do_eval \ 11 | --model_type longformer \ 12 | --evaluate_during_training \ 13 | --data_dir ${BASIC_DIR}/data/preprocessed_data/qa \ 14 | --output_dir ${BASIC_DIR}/model/${MODEL_DIR} \ 15 | --train_file train_preprocessed_normalized_gtmodify_newretr.json \ 16 | --dev_file dev_preprocessed_normalized_gtmodify_newretr.json\ 17 | --per_gpu_train_batch_size 2 \ 18 | --learning_rate 3e-5 \ 19 | --num_train_epochs 10.0 \ 20 | --max_seq_length 4096 \ 21 | --doc_stride 1024 \ 22 | --threads 8 \ 23 | --topk_tbs 15 \ 24 | --model_name_or_path ${MODEL_NAME} \ 25 | --repreprocess \ 26 | --overwrite_cache \ 27 | --prefix gtmodify_rankdoc_normalized_top15_newretr \ 28 | 2>&1 | tee ${BASIC_DIR}/qa_log/longformer-base-${MODEL_DIR}.log 29 | 
-------------------------------------------------------------------------------- /evidence_chain/ranking_model/run_evidence_large_train.sh: -------------------------------------------------------------------------------- 1 | export RUN_ID=5 2 | export BASIC_PATH=./ODQA 3 | export DATA_PATH=${BASIC_PATH}/data/preprocessed_data/evidence_chain/ground-truth-based 4 | export TRAIN_DATA_PATH=train_preprocessed_normalized_gtmodify_evichain_addnoun.json 5 | export DEV_DATA_PATH=dev_preprocessed_normalized_gtmodify_evichain_addnoun.json 6 | export MODEL_PATH=${BASIC_PATH}/model/evidence_chain/roberta-base-ranking-pretrain 7 | export PRETRAIN_MODEL_PATH=${BASIC_PATH}/model/evidence_chain/roberta-large-pretrain-ec-ranker/checkpoint-best 8 | python run_classifier.py \ 9 | --model_type roberta \ 10 | --tokenizer_name roberta-base \ 11 | --config_name roberta-base \ 12 | --model_name_or_path roberta-base \ 13 | --task_name evidence_chain \ 14 | --do_train \ 15 | --do_eval \ 16 | --eval_all_checkpoints \ 17 | --data_dir ${DATA_PATH} \ 18 | --output_dir ${MODEL_PATH} \ 19 | --train_file ${TRAIN_DATA_PATH} \ 20 | --dev_file ${DEV_DATA_PATH} \ 21 | --max_seq_length 512 \ 22 | --per_gpu_train_batch_size 16 \ 23 | --per_gpu_eval_batch_size 16 \ 24 | --learning_rate 1e-5 \ 25 | --num_train_epochs 10 \ 26 | --load_save_pretrain \ 27 | --pretrain_model_dir ${PRETRAIN_MODEL_PATH} \ 28 | 29 | -------------------------------------------------------------------------------- /evidence_chain/pretrain/run_evidence_base_train.sh: -------------------------------------------------------------------------------- 1 | export RUN_ID=5 2 | export BASIC_PATH=/wanjun/ODQA 3 | export BASIC_PATH=/home/t-wzhong/v-wanzho/ODQA 4 | export DATA_PATH=${BASIC_PATH}/data/preprocessed_data/evidence_chain/pre-training/fake_question_pretraining 5 | export TRAIN_DATA_PATH=fake_question_pretraining_train.jsonl 6 | export DEV_DATA_PATH=fake_question_pretraining_dev.jsonl 7 | export MODEL_PATH=${BASIC_PATH}/model/evidence_chain/roberta-base-pretrain-fake 8 | CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python run_classifier.py \ 9 | --model_type roberta \ 10 | --model_name_or_path roberta-base \ 11 | --task_name evidence_chain \ 12 | --overwrite_cache \ 13 | --do_train \ 14 | --do_eval \ 15 | --eval_all_checkpoints \ 16 | --data_dir ${DATA_PATH} \ 17 | --output_dir ${MODEL_PATH} \ 18 | --train_file ${TRAIN_DATA_PATH} \ 19 | --dev_file ${DEV_DATA_PATH} \ 20 | --max_seq_length 512 \ 21 | --per_gpu_train_batch_size 16 \ 22 | --per_gpu_eval_batch_size 16 \ 23 | --learning_rate 1e-5 \ 24 | --num_train_epochs 10 25 | #--test_file /home/dutang/FEVER/arranged_data/bert_data/Evidence/eval_file/train_all_sentence_evidence_title_all.tsv \ 26 | #--pred_model_dir /home/dutang/FEVER/arranged_models/xlnet_torch_models/evidence_large_title/checkpoint-best \ 27 | #--test_result_dir /home/dutang/FEVER/arranged_result/xlnet/evidence/train_evidence_xlnet_large_title_score.tsv 28 | -------------------------------------------------------------------------------- /evidence_chain/fine-tune/run_evidence_large_eval.sh: -------------------------------------------------------------------------------- 1 | export RUN_ID=5 2 | export BASIC_PATH=./ODQA 3 | #export DATA_PATH=${BASIC_PATH}/data/preprocessed_data/evidence_chain/ground-truth-based/candidate_chain 4 | export DATA_PATH=./ODQA/data/evidence_chain_data/retrieval-based/candidate_chain 5 | export TRAIN_DATA_PATH=train_preprocessed_normalized_gtmodify_evichain_nx.json 6 | export 
DEV_DATA_PATH=dev_preprocessed_normalized_gtmodify_evichain_nx.json 7 | export TEST_DATA_PATH=dev_gttb_candidate_ec.jsonl 8 | export MODEL_PATH=${BASIC_PATH}/model/evidence_chain/roberta-large-weighted-multineg 9 | for part in 2 3 10 | do 11 | export TEST_DATA_PATH=train_evidence_chain_weighted_ranking_${part}.json 12 | CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python run_classifier.py \ 13 | --model_type roberta \ 14 | --model_name_or_path roberta-large \ 15 | --task_name evidence_chain \ 16 | --do_predict \ 17 | --eval_all_checkpoints \ 18 | --data_dir ${DATA_PATH} \ 19 | --output_dir ${MODEL_PATH} \ 20 | --train_file ${TRAIN_DATA_PATH} \ 21 | --dev_file ${DEV_DATA_PATH} \ 22 | --max_seq_length 512 \ 23 | --per_gpu_eval_batch_size 20 \ 24 | --overwrite_cache \ 25 | --test_file ${DATA_PATH}/${TEST_DATA_PATH} \ 26 | --pred_model_dir ${MODEL_PATH}/checkpoint-best \ 27 | --test_result_dir ${DATA_PATH}/../scored_chains/weighted-large-5neg/train_evidence_chain_weighted_scores_${part}.json 28 | sleep 30 29 | done 30 | -------------------------------------------------------------------------------- /retriever/retrieval/drqa/drqa_tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | 10 | DEFAULTS = { 11 | 'corenlp_classpath': os.getenv('CLASSPATH') 12 | } 13 | 14 | 15 | def set_default(key, value): 16 | global DEFAULTS 17 | DEFAULTS[key] = value 18 | 19 | 20 | from .corenlp_tokenizer import CoreNLPTokenizer 21 | from .regexp_tokenizer import RegexpTokenizer 22 | from .simple_tokenizer import SimpleTokenizer 23 | 24 | # Spacy is optional 25 | try: 26 | from .spacy_tokenizer import SpacyTokenizer 27 | except ImportError: 28 | pass 29 | 30 | 31 | def get_class(name): 32 | if name == 'spacy': 33 | return SpacyTokenizer 34 | if name == 'corenlp': 35 | return CoreNLPTokenizer 36 | if name == 'regexp': 37 | return RegexpTokenizer 38 | if name == 'simple': 39 | return SimpleTokenizer 40 | 41 | raise RuntimeError('Invalid tokenizer: %s' % name) 42 | 43 | 44 | def get_annotators_for_args(args): 45 | annotators = set() 46 | if args.use_pos: 47 | annotators.add('pos') 48 | if args.use_lemma: 49 | annotators.add('lemma') 50 | if args.use_ner: 51 | annotators.add('ner') 52 | return annotators 53 | 54 | 55 | def get_annotators_for_model(model): 56 | return get_annotators_for_args(model.args) 57 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/ExternalAttention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | 8 | class ExternalAttention(nn.Module): 9 | 10 | def __init__(self, d_model,S=64): 11 | super().__init__() 12 | self.mk=nn.Linear(d_model,S,bias=False) 13 | self.mv=nn.Linear(S,d_model,bias=False) 14 | self.softmax=nn.Softmax(dim=1) 15 | self.init_weights() 16 | 17 | 18 | def init_weights(self): 19 | for m in self.modules(): 20 | if isinstance(m, nn.Conv2d): 21 | init.kaiming_normal_(m.weight, mode='fan_out') 22 | if m.bias is not None: 23 | init.constant_(m.bias, 0) 24 | elif isinstance(m, nn.BatchNorm2d): 25 | init.constant_(m.weight, 1) 26 | init.constant_(m.bias, 0) 27 | elif isinstance(m, nn.Linear): 
28 | init.normal_(m.weight, std=0.001) 29 | if m.bias is not None: 30 | init.constant_(m.bias, 0) 31 | 32 | def forward(self, queries): 33 | attn=self.mk(queries) #bs,n,S 34 | attn=self.softmax(attn) #bs,n,S 35 | attn=attn/torch.sum(attn,dim=2,keepdim=True) #bs,n,S 36 | out=self.mv(attn) #bs,n,d_model 37 | 38 | return out 39 | 40 | 41 | if __name__ == '__main__': 42 | input=torch.randn(50,49,512) 43 | ea = ExternalAttention(d_model=512,S=8) 44 | output=ea(input) 45 | print(output.shape) 46 | 47 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/ECAAttention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | from collections import OrderedDict 6 | 7 | 8 | 9 | class ECAAttention(nn.Module): 10 | 11 | def __init__(self, kernel_size=3): 12 | super().__init__() 13 | self.gap=nn.AdaptiveAvgPool2d(1) 14 | self.conv=nn.Conv1d(1,1,kernel_size=kernel_size,padding=(kernel_size-1)//2) 15 | self.sigmoid=nn.Sigmoid() 16 | 17 | def init_weights(self): 18 | for m in self.modules(): 19 | if isinstance(m, nn.Conv2d): 20 | init.kaiming_normal_(m.weight, mode='fan_out') 21 | if m.bias is not None: 22 | init.constant_(m.bias, 0) 23 | elif isinstance(m, nn.BatchNorm2d): 24 | init.constant_(m.weight, 1) 25 | init.constant_(m.bias, 0) 26 | elif isinstance(m, nn.Linear): 27 | init.normal_(m.weight, std=0.001) 28 | if m.bias is not None: 29 | init.constant_(m.bias, 0) 30 | 31 | def forward(self, x): 32 | y=self.gap(x) #bs,c,1,1 33 | y=y.squeeze(-1).permute(0,2,1) #bs,1,c 34 | y=self.conv(y) #bs,1,c 35 | y=self.sigmoid(y) #bs,1,c 36 | y=y.permute(0,2,1).unsqueeze(-1) #bs,c,1,1 37 | return x*y.expand_as(x) 38 | 39 | 40 | 41 | 42 | 43 | 44 | if __name__ == '__main__': 45 | input=torch.randn(50,512,7,7) 46 | eca = ECAAttention(kernel_size=3) 47 | output=eca(input) 48 | print(output.shape) 49 | 50 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/SEAttention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | 8 | class SEAttention(nn.Module): 9 | 10 | def __init__(self, channel=512,reduction=16): 11 | super().__init__() 12 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 13 | self.fc = nn.Sequential( 14 | nn.Linear(channel, channel // reduction, bias=False), 15 | nn.ReLU(inplace=True), 16 | nn.Linear(channel // reduction, channel, bias=False), 17 | nn.Sigmoid() 18 | ) 19 | 20 | 21 | def init_weights(self): 22 | for m in self.modules(): 23 | if isinstance(m, nn.Conv2d): 24 | init.kaiming_normal_(m.weight, mode='fan_out') 25 | if m.bias is not None: 26 | init.constant_(m.bias, 0) 27 | elif isinstance(m, nn.BatchNorm2d): 28 | init.constant_(m.weight, 1) 29 | init.constant_(m.bias, 0) 30 | elif isinstance(m, nn.Linear): 31 | init.normal_(m.weight, std=0.001) 32 | if m.bias is not None: 33 | init.constant_(m.bias, 0) 34 | 35 | def forward(self, x): 36 | b, c, _, _ = x.size() 37 | y = self.avg_pool(x).view(b, c) 38 | y = self.fc(y).view(b, c, 1, 1) 39 | return x * y.expand_as(x) 40 | 41 | 42 | if __name__ == '__main__': 43 | input=torch.randn(50,512,7,7) 44 | se = SEAttention(channel=512,reduction=8) 45 | output=se(input) 46 | print(output.shape) 47 | 48 | -------------------------------------------------------------------------------- 
/retriever/retrieval/drqa/retriever/doc_db.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Documents, in a sqlite database.""" 8 | 9 | import sqlite3 10 | from . import utils 11 | 12 | 13 | class DocDB(object): 14 | """Sqlite backed document storage. 15 | 16 | Implements get_doc_text(doc_id). 17 | """ 18 | 19 | def __init__(self, db_path): 20 | self.path = db_path 21 | self.connection = sqlite3.connect(self.path, check_same_thread=False) 22 | 23 | def __enter__(self): 24 | return self 25 | 26 | def __exit__(self, *args): 27 | self.close() 28 | 29 | def path(self): 30 | """Return the path to the file that backs this database.""" 31 | return self.path 32 | 33 | def close(self): 34 | """Close the connection to the database.""" 35 | self.connection.close() 36 | 37 | def get_doc_ids(self): 38 | """Fetch all ids of docs stored in the db.""" 39 | cursor = self.connection.cursor() 40 | cursor.execute("SELECT id FROM documents") 41 | results = [r[0] for r in cursor.fetchall()] 42 | cursor.close() 43 | return results 44 | 45 | def get_doc_text(self, doc_id): 46 | """Fetch the raw text of the doc for 'doc_id'.""" 47 | cursor = self.connection.cursor() 48 | cursor.execute( 49 | "SELECT text FROM documents WHERE id = ?", 50 | #(utils.normalize(doc_id),) 51 | (doc_id, ) 52 | ) 53 | result = cursor.fetchone() 54 | cursor.close() 55 | return result if result is None else result[0] 56 | -------------------------------------------------------------------------------- /evidence_chain/fine-tune/run_evidence_train_from_pretrain.sh: -------------------------------------------------------------------------------- 1 | export RUN_ID=5 2 | export BASIC_PATH=./ODQA 3 | export DATA_PATH=${BASIC_PATH}/evidence_chain_data/ground-truth-based/ground-truth-evidence-chain 4 | export TRAIN_DATA_PATH=train_gt-ec-weighted.json 5 | export DEV_DATA_PATH=dev_gt-ec-weighted.json 6 | export PRETRAIN_MODEL_PATH=${BASIC_PATH}/model/evidence_chain/roberta-large-new/checkpoint-best 7 | export STEP=${step} 8 | export PREFIX=southes-nonpretrain 9 | export MODEL_PATH=${BASIC_PATH}/model/evidence_chain/ft_pretrained/roberta-base-${PREFIX} 10 | export PRETRAIN_MODEL_PATH=${BASIC_PATH}/model/evidence_chain/roberta-base-pretrain-3neg-all-innerneg/checkpoint-${STEP} 11 | 12 | 13 | python run_classifier.py \ 14 | --model_type roberta \ 15 | --tokenizer_name roberta-base \ 16 | --config_name roberta-base \ 17 | --model_name_or_path roberta-base \ 18 | --task_name evidence_chain \ 19 | --do_train \ 20 | --do_eval \ 21 | --eval_all_checkpoints \ 22 | --data_dir ${DATA_PATH} \ 23 | --output_dir ${MODEL_PATH} \ 24 | --train_file ${TRAIN_DATA_PATH} \ 25 | --dev_file ${DEV_DATA_PATH} \ 26 | --max_seq_length 512 \ 27 | --per_gpu_train_batch_size 16 \ 28 | --per_gpu_eval_batch_size 16 \ 29 | --learning_rate 1e-5 \ 30 | --num_train_epochs 1 \ 31 | 32 | export TEST_DATA_PATH=dev_gttb_candidate_ec_weighted.json 33 | 34 | python run_classifier.py \ 35 | --model_type roberta \ 36 | --model_name_or_path roberta-base \ 37 | --task_name evidence_chain \ 38 | --do_predict \ 39 | --eval_all_checkpoints \ 40 | --data_dir ${DATA_PATH} \ 41 | --output_dir ${MODEL_PATH} \ 42 | --max_seq_length 512 \ 43 | --per_gpu_eval_batch_size 40 \ 44 | --test_file 
${DATA_PATH}/../candidate_chain/${TEST_DATA_PATH} \ 45 | --pred_model_dir ${MODEL_PATH}/checkpoint-best \ 46 | --test_result_dir ${DATA_PATH}/../scored_chain/pretrained/dev_roberta_base_scored_ec_addtab_weighted_${PREFIX}.json 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /retriever/retrieval/drqa/drqa_tokenizers/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Basic tokenizer that splits text into alpha-numeric tokens and 8 | non-whitespace tokens. 9 | """ 10 | 11 | import regex 12 | import logging 13 | from .tokenizer import Tokens, Tokenizer 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class SimpleTokenizer(Tokenizer): 19 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 20 | NON_WS = r'[^\p{Z}\p{C}]' 21 | 22 | def __init__(self, **kwargs): 23 | """ 24 | Args: 25 | annotators: None or empty set (only tokenizes). 26 | """ 27 | self._regexp = regex.compile( 28 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 29 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 30 | ) 31 | if len(kwargs.get('annotators', {})) > 0: 32 | logger.warning('%s only tokenizes! Skipping annotators: %s' % 33 | (type(self).__name__, kwargs.get('annotators'))) 34 | self.annotators = set() 35 | 36 | def tokenize(self, text): 37 | data = [] 38 | matches = [m for m in self._regexp.finditer(text)] 39 | for i in range(len(matches)): 40 | # Get text 41 | token = matches[i].group() 42 | 43 | # Get whitespace 44 | span = matches[i].span() 45 | start_ws = span[0] 46 | if i + 1 < len(matches): 47 | end_ws = matches[i + 1].span()[0] 48 | else: 49 | end_ws = span[1] 50 | 51 | # Format data 52 | data.append(( 53 | token, 54 | text[start_ws: end_ws], 55 | span, 56 | )) 57 | return Tokens(data, self.annotators) 58 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/SGE.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | 8 | class SpatialGroupEnhance(nn.Module): 9 | 10 | def __init__(self, groups): 11 | super().__init__() 12 | self.groups=groups 13 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 14 | self.weight=nn.Parameter(torch.zeros(1,groups,1,1)) 15 | self.bias=nn.Parameter(torch.zeros(1,groups,1,1)) 16 | self.sig=nn.Sigmoid() 17 | self.init_weights() 18 | 19 | 20 | def init_weights(self): 21 | for m in self.modules(): 22 | if isinstance(m, nn.Conv2d): 23 | init.kaiming_normal_(m.weight, mode='fan_out') 24 | if m.bias is not None: 25 | init.constant_(m.bias, 0) 26 | elif isinstance(m, nn.BatchNorm2d): 27 | init.constant_(m.weight, 1) 28 | init.constant_(m.bias, 0) 29 | elif isinstance(m, nn.Linear): 30 | init.normal_(m.weight, std=0.001) 31 | if m.bias is not None: 32 | init.constant_(m.bias, 0) 33 | 34 | def forward(self, x): 35 | b, c, h,w=x.shape 36 | x=x.view(b*self.groups,-1,h,w) #bs*g,dim//g,h,w 37 | xn=x*self.avg_pool(x) #bs*g,dim//g,h,w 38 | xn=xn.sum(dim=1,keepdim=True) #bs*g,1,h,w 39 | t=xn.view(b*self.groups,-1) #bs*g,h*w 40 | 41 | t=t-t.mean(dim=1,keepdim=True) #bs*g,h*w 42 | std=t.std(dim=1,keepdim=True)+1e-5 43 | t=t/std #bs*g,h*w 44 | t=t.view(b,self.groups,h,w) 
#bs,g,h*w 45 | 46 | t=t*self.weight+self.bias #bs,g,h*w 47 | t=t.view(b*self.groups,1,h,w) #bs*g,1,h*w 48 | x=x*self.sig(t) 49 | x=x.view(b,c,h,w) 50 | 51 | return x 52 | 53 | 54 | if __name__ == '__main__': 55 | input=torch.randn(50,512,7,7) 56 | sge = SpatialGroupEnhance(groups=8) 57 | output=sge(input) 58 | print(output.shape) 59 | 60 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/DANet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | from .SelfAttention import ScaledDotProductAttention 6 | from .SimplifiedSelfAttention import SimplifiedScaledDotProductAttention 7 | 8 | class PositionAttentionModule(nn.Module): 9 | 10 | def __init__(self,d_model=512,kernel_size=3,H=7,W=7): 11 | super().__init__() 12 | self.cnn=nn.Conv2d(d_model,d_model,kernel_size=kernel_size,padding=(kernel_size-1)//2) 13 | self.pa=ScaledDotProductAttention(d_model,d_k=d_model,d_v=d_model,h=1) 14 | 15 | def forward(self,x): 16 | bs,c,h,w=x.shape 17 | y=self.cnn(x) 18 | y=y.view(bs,c,-1).permute(0,2,1) #bs,h*w,c 19 | y=self.pa(y,y,y) #bs,h*w,c 20 | return y 21 | 22 | 23 | class ChannelAttentionModule(nn.Module): 24 | 25 | def __init__(self,d_model=512,kernel_size=3,H=7,W=7): 26 | super().__init__() 27 | self.cnn=nn.Conv2d(d_model,d_model,kernel_size=kernel_size,padding=(kernel_size-1)//2) 28 | self.pa=SimplifiedScaledDotProductAttention(H*W,h=1) 29 | 30 | def forward(self,x): 31 | bs,c,h,w=x.shape 32 | y=self.cnn(x) 33 | y=y.view(bs,c,-1) #bs,c,h*w 34 | y=self.pa(y,y,y) #bs,c,h*w 35 | return y 36 | 37 | 38 | 39 | 40 | class DAModule(nn.Module): 41 | 42 | def __init__(self,d_model=512,kernel_size=3,H=7,W=7): 43 | super().__init__() 44 | self.position_attention_module=PositionAttentionModule(d_model=512,kernel_size=3,H=7,W=7) 45 | self.channel_attention_module=ChannelAttentionModule(d_model=512,kernel_size=3,H=7,W=7) 46 | 47 | def forward(self,input): 48 | bs,c,h,w=input.shape 49 | p_out=self.position_attention_module(input) 50 | c_out=self.channel_attention_module(input) 51 | p_out=p_out.permute(0,2,1).view(bs,c,h,w) 52 | c_out=c_out.view(bs,c,h,w) 53 | return p_out+c_out 54 | 55 | 56 | if __name__ == '__main__': 57 | input=torch.randn(50,512,7,7) 58 | danet=DAModule(d_model=512,kernel_size=3,H=7,W=7) 59 | print(danet(input).shape) 60 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/ViP.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class MLP(nn.Module): 6 | def __init__(self,in_features,hidden_features,out_features,act_layer=nn.GELU,drop=0.1): 7 | super().__init__() 8 | self.fc1=nn.Linear(in_features,hidden_features) 9 | self.act=act_layer() 10 | self.fc2=nn.Linear(hidden_features,out_features) 11 | self.drop=nn.Dropout(drop) 12 | 13 | def forward(self, x) : 14 | return self.drop(self.fc2(self.drop(self.act(self.fc1(x))))) 15 | 16 | class WeightedPermuteMLP(nn.Module): 17 | def __init__(self,dim,seg_dim=8, qkv_bias=False, proj_drop=0.): 18 | super().__init__() 19 | self.seg_dim=seg_dim 20 | 21 | self.mlp_c=nn.Linear(dim,dim,bias=qkv_bias) 22 | self.mlp_h=nn.Linear(dim,dim,bias=qkv_bias) 23 | self.mlp_w=nn.Linear(dim,dim,bias=qkv_bias) 24 | 25 | self.reweighting=MLP(dim,dim//4,dim*3) 26 | 27 | self.proj=nn.Linear(dim,dim) 28 | self.proj_drop=nn.Dropout(proj_drop) 29 | 
30 | def forward(self,x) : 31 | B,H,W,C=x.shape 32 | 33 | c_embed=self.mlp_c(x) 34 | 35 | S=C//self.seg_dim 36 | h_embed=x.reshape(B,H,W,self.seg_dim,S).permute(0,3,2,1,4).reshape(B,self.seg_dim,W,H*S) 37 | h_embed=self.mlp_h(h_embed).reshape(B,self.seg_dim,W,H,S).permute(0,3,2,1,4).reshape(B,H,W,C) 38 | 39 | w_embed=x.reshape(B,H,W,self.seg_dim,S).permute(0,3,1,2,4).reshape(B,self.seg_dim,H,W*S) 40 | w_embed=self.mlp_w(w_embed).reshape(B,self.seg_dim,H,W,S).permute(0,2,3,1,4).reshape(B,H,W,C) 41 | 42 | weight=(c_embed+h_embed+w_embed).permute(0,3,1,2).flatten(2).mean(2) 43 | weight=self.reweighting(weight).reshape(B,C,3).permute(2,0,1).softmax(0).unsqueeze(2).unsqueeze(2) 44 | 45 | x=h_embed*weight[0]+w_embed*weight[1]+h_embed*weight[2] 46 | 47 | x=self.proj_drop(self.proj(x)) 48 | 49 | return x 50 | 51 | 52 | 53 | if __name__ == '__main__': 54 | input=torch.randn(64,8,8,512) 55 | seg_dim=8 56 | vip=WeightedPermuteMLP(512,seg_dim) 57 | out=vip(input) 58 | print(out.shape) 59 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/AFT.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | 8 | class AFT_FULL(nn.Module): 9 | 10 | def __init__(self, d_model,n=49,simple=False): 11 | 12 | super(AFT_FULL, self).__init__() 13 | self.fc_q = nn.Linear(d_model, d_model) 14 | self.fc_k = nn.Linear(d_model, d_model) 15 | self.fc_v = nn.Linear(d_model,d_model) 16 | if(simple): 17 | self.position_biases=torch.zeros((n,n)) 18 | else: 19 | self.position_biases=nn.Parameter(torch.ones((n,n))) 20 | self.d_model = d_model 21 | self.n=n 22 | self.sigmoid=nn.Sigmoid() 23 | 24 | self.init_weights() 25 | 26 | 27 | def init_weights(self): 28 | for m in self.modules(): 29 | if isinstance(m, nn.Conv2d): 30 | init.kaiming_normal_(m.weight, mode='fan_out') 31 | if m.bias is not None: 32 | init.constant_(m.bias, 0) 33 | elif isinstance(m, nn.BatchNorm2d): 34 | init.constant_(m.weight, 1) 35 | init.constant_(m.bias, 0) 36 | elif isinstance(m, nn.Linear): 37 | init.normal_(m.weight, std=0.001) 38 | if m.bias is not None: 39 | init.constant_(m.bias, 0) 40 | 41 | def forward(self, input): 42 | 43 | bs, n,dim = input.shape 44 | 45 | q = self.fc_q(input) #bs,n,dim 46 | k = self.fc_k(input).view(1,bs,n,dim) #1,bs,n,dim 47 | v = self.fc_v(input).view(1,bs,n,dim) #1,bs,n,dim 48 | 49 | numerator=torch.sum(torch.exp(k+self.position_biases.view(n,1,-1,1))*v,dim=2) #n,bs,dim 50 | denominator=torch.sum(torch.exp(k+self.position_biases.view(n,1,-1,1)),dim=2) #n,bs,dim 51 | 52 | out=(numerator/denominator) #n,bs,dim 53 | out=self.sigmoid(q)*(out.permute(1,0,2)) #bs,n,dim 54 | 55 | return out 56 | 57 | 58 | if __name__ == '__main__': 59 | input=torch.randn(50,49,512) 60 | aft_full = AFT_FULL(d_model=512, n=49) 61 | output=aft_full(input) 62 | print(output.shape) 63 | 64 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/A2Atttention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | from torch.nn import functional as F 6 | 7 | 8 | 9 | class DoubleAttention(nn.Module): 10 | 11 | def __init__(self, in_channels,c_m,c_n,reconstruct = True): 12 | super().__init__() 13 | self.in_channels=in_channels 14 | self.reconstruct = reconstruct 15 | self.c_m=c_m 16 | 
self.c_n=c_n 17 | self.convA=nn.Conv2d(in_channels,c_m,1) 18 | self.convB=nn.Conv2d(in_channels,c_n,1) 19 | self.convV=nn.Conv2d(in_channels,c_n,1) 20 | if self.reconstruct: 21 | self.conv_reconstruct = nn.Conv2d(c_m, in_channels, kernel_size = 1) 22 | self.init_weights() 23 | 24 | 25 | def init_weights(self): 26 | for m in self.modules(): 27 | if isinstance(m, nn.Conv2d): 28 | init.kaiming_normal_(m.weight, mode='fan_out') 29 | if m.bias is not None: 30 | init.constant_(m.bias, 0) 31 | elif isinstance(m, nn.BatchNorm2d): 32 | init.constant_(m.weight, 1) 33 | init.constant_(m.bias, 0) 34 | elif isinstance(m, nn.Linear): 35 | init.normal_(m.weight, std=0.001) 36 | if m.bias is not None: 37 | init.constant_(m.bias, 0) 38 | 39 | def forward(self, x): 40 | b, c, h,w=x.shape 41 | assert c==self.in_channels 42 | A=self.convA(x) #b,c_m,h,w 43 | B=self.convB(x) #b,c_n,h,w 44 | V=self.convV(x) #b,c_n,h,w 45 | tmpA=A.view(b,self.c_m,-1) 46 | attention_maps=F.softmax(B.view(b,self.c_n,-1)) 47 | attention_vectors=F.softmax(V.view(b,self.c_n,-1)) 48 | # step 1: feature gating 49 | global_descriptors=torch.bmm(tmpA,attention_maps.permute(0,2,1)) #b.c_m,c_n 50 | # step 2: feature distribution 51 | tmpZ = global_descriptors.matmul(attention_vectors) #b,c_m,h*w 52 | tmpZ=tmpZ.view(b,self.c_m,h,w) #b,c_m,h,w 53 | if self.reconstruct: 54 | tmpZ=self.conv_reconstruct(tmpZ) 55 | 56 | return tmpZ 57 | 58 | 59 | if __name__ == '__main__': 60 | input=torch.randn(50,512,7,7) 61 | a2 = DoubleAttention(512,128,128,True) 62 | output=a2(input) 63 | print(output.shape) 64 | 65 | -------------------------------------------------------------------------------- /retriever/retrieval/drqa/drqa_tokenizers/spacy_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Tokenizer that is backed by spaCy (spacy.io). 8 | 9 | Requires spaCy package and the spaCy english model. 10 | """ 11 | 12 | import spacy 13 | import copy 14 | from .tokenizer import Tokens, Tokenizer 15 | 16 | 17 | class SpacyTokenizer(Tokenizer): 18 | 19 | def __init__(self, **kwargs): 20 | """ 21 | Args: 22 | annotators: set that can include pos, lemma, and ner. 23 | model: spaCy model to use (either path, or keyword like 'en'). 24 | """ 25 | model = kwargs.get('model', 'en') 26 | self.annotators = copy.deepcopy(kwargs.get('annotators', set())) 27 | nlp_kwargs = {'parser': False} 28 | if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]): 29 | nlp_kwargs['tagger'] = False 30 | if 'ner' not in self.annotators: 31 | nlp_kwargs['entity'] = False 32 | self.nlp = spacy.load(model, **nlp_kwargs) 33 | 34 | def tokenize(self, text): 35 | # We don't treat new lines as tokens. 
36 | clean_text = text.replace('\n', ' ') 37 | tokens = self.nlp.tokenizer(clean_text) 38 | if any([p in self.annotators for p in ['lemma', 'pos', 'ner']]): 39 | self.nlp.tagger(tokens) 40 | if 'ner' in self.annotators: 41 | self.nlp.entity(tokens) 42 | 43 | data = [] 44 | for i in range(len(tokens)): 45 | # Get whitespace 46 | start_ws = tokens[i].idx 47 | if i + 1 < len(tokens): 48 | end_ws = tokens[i + 1].idx 49 | else: 50 | end_ws = tokens[i].idx + len(tokens[i].text) 51 | 52 | data.append(( 53 | tokens[i].text, 54 | text[start_ws: end_ws], 55 | (tokens[i].idx, tokens[i].idx + len(tokens[i].text)), 56 | tokens[i].tag_, 57 | tokens[i].lemma_, 58 | tokens[i].ent_type_, 59 | )) 60 | 61 | # Set special option for non-entity tag: '' vs 'O' in spaCy 62 | return Tokens(data, self.annotators, opts={'non_ent': ''}) 63 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/PSA.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | 8 | class PSA(nn.Module): 9 | 10 | def __init__(self, channel=512,reduction=4,S=4): 11 | super().__init__() 12 | self.S=S 13 | 14 | self.convs=[] 15 | for i in range(S): 16 | self.convs.append(nn.Conv2d(channel//S,channel//S,kernel_size=2*(i+1)+1,padding=i+1)) 17 | 18 | self.se_blocks=[] 19 | for i in range(S): 20 | self.se_blocks.append(nn.Sequential( 21 | nn.AdaptiveAvgPool2d(1), 22 | nn.Conv2d(channel//S, channel // (S*reduction),kernel_size=1, bias=False), 23 | nn.ReLU(inplace=True), 24 | nn.Conv2d(channel // (S*reduction), channel//S,kernel_size=1, bias=False), 25 | nn.Sigmoid() 26 | )) 27 | 28 | self.softmax=nn.Softmax(dim=1) 29 | 30 | 31 | def init_weights(self): 32 | for m in self.modules(): 33 | if isinstance(m, nn.Conv2d): 34 | init.kaiming_normal_(m.weight, mode='fan_out') 35 | if m.bias is not None: 36 | init.constant_(m.bias, 0) 37 | elif isinstance(m, nn.BatchNorm2d): 38 | init.constant_(m.weight, 1) 39 | init.constant_(m.bias, 0) 40 | elif isinstance(m, nn.Linear): 41 | init.normal_(m.weight, std=0.001) 42 | if m.bias is not None: 43 | init.constant_(m.bias, 0) 44 | 45 | def forward(self, x): 46 | b, c, h, w = x.size() 47 | 48 | #Step1:SPC module 49 | SPC_out=x.view(b,self.S,c//self.S,h,w) #bs,s,ci,h,w 50 | for idx,conv in enumerate(self.convs): 51 | SPC_out[:,idx,:,:,:]=conv(SPC_out[:,idx,:,:,:]) 52 | 53 | #Step2:SE weight 54 | SE_out=torch.zeros_like(SPC_out) 55 | for idx,se in enumerate(self.se_blocks): 56 | SE_out[:,idx,:,:,:]=se(SPC_out[:,idx,:,:,:]) 57 | 58 | #Step3:Softmax 59 | softmax_out=self.softmax(SE_out) 60 | 61 | #Step4:SPA 62 | PSA_out=SPC_out*softmax_out 63 | PSA_out=PSA_out.view(b,-1,h,w) 64 | 65 | return PSA_out 66 | 67 | 68 | if __name__ == '__main__': 69 | input=torch.randn(50,512,7,7) 70 | psa = PSA(channel=512,reduction=8) 71 | output=psa(input) 72 | print(output.shape) 73 | 74 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/OutlookAttention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | import math 6 | from torch.nn import functional as F 7 | 8 | class OutlookAttention(nn.Module): 9 | 10 | def __init__(self,dim,num_heads=1,kernel_size=3,padding=1,stride=1,qkv_bias=False, 11 | attn_drop=0.1): 12 | super().__init__() 13 | self.dim=dim 14 | 
self.num_heads=num_heads
15 |         self.head_dim=dim//num_heads
16 |         self.kernel_size=kernel_size
17 |         self.padding=padding
18 |         self.stride=stride
19 |         self.scale=self.head_dim**(-0.5)
20 | 
21 |         self.v_pj=nn.Linear(dim,dim,bias=qkv_bias)
22 |         self.attn=nn.Linear(dim,kernel_size**4*num_heads)
23 | 
24 |         self.attn_drop=nn.Dropout(attn_drop)
25 |         self.proj=nn.Linear(dim,dim)
26 |         self.proj_drop=nn.Dropout(attn_drop)
27 | 
28 |         self.unflod=nn.Unfold(kernel_size,padding,stride) # "manual" convolution via unfold
29 |         self.pool=nn.AvgPool2d(kernel_size=stride,stride=stride,ceil_mode=True)
30 | 
31 |     def forward(self, x):
32 |         B,H,W,C=x.shape
33 | 
34 |         # project to the new feature v
35 |         v=self.v_pj(x).permute(0,3,1,2) #B,C,H,W
36 |         h,w=math.ceil(H/self.stride),math.ceil(W/self.stride)
37 |         v=self.unflod(v).reshape(B,self.num_heads,self.head_dim,self.kernel_size*self.kernel_size,h*w).permute(0,1,4,3,2) #B,num_head,H*W,kxk,head_dim
38 | 
39 |         # generate the attention map
40 |         attn=self.pool(x.permute(0,3,1,2)).permute(0,2,3,1) #B,H,W,C
41 |         attn=self.attn(attn).reshape(B,h*w,self.num_heads,self.kernel_size*self.kernel_size \
42 |             ,self.kernel_size*self.kernel_size).permute(0,2,1,3,4) #B,num_head,H*W,kxk,kxk
43 |         attn=self.scale*attn
44 |         attn=attn.softmax(-1)
45 |         attn=self.attn_drop(attn)
46 | 
47 |         # gather the weighted features
48 |         out=(attn @ v).permute(0,1,4,3,2).reshape(B,C*self.kernel_size*self.kernel_size,h*w) #B,dimxkxk,H*W
49 |         out=F.fold(out,output_size=(H,W),kernel_size=self.kernel_size,
50 |                    padding=self.padding,stride=self.stride) #B,C,H,W
51 |         out=self.proj(out.permute(0,2,3,1)) #B,H,W,C
52 |         out=self.proj_drop(out)
53 | 
54 |         return out
55 | 
56 | 
57 | 
58 | if __name__ == '__main__':
59 |     input=torch.randn(50,28,28,512)
60 |     outlook = OutlookAttention(dim=512)
61 |     output=outlook(input)
62 |     print(output.shape)
63 | 
64 | 
65 | 
66 | 
67 | 
-------------------------------------------------------------------------------- /retriever/retrieval/attention/SKAttention.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from torch.nn import init
5 | from collections import OrderedDict
6 | 
7 | 
8 | 
9 | class SKAttention(nn.Module):
10 | 
11 |     def __init__(self, channel=512,kernels=[1,3,5,7],reduction=16,group=1,L=32):
12 |         super().__init__()
13 |         self.d=max(L,channel//reduction)
14 |         self.convs=nn.ModuleList([])
15 |         for k in kernels:
16 |             self.convs.append(
17 |                 nn.Sequential(OrderedDict([
18 |                     ('conv',nn.Conv2d(channel,channel,kernel_size=k,padding=k//2,groups=group)),
19 |                     ('bn',nn.BatchNorm2d(channel)),
20 |                     ('relu',nn.ReLU())
21 |                 ]))
22 |             )
23 |         self.fc=nn.Linear(channel,self.d)
24 |         self.fcs=nn.ModuleList([])
25 |         for i in range(len(kernels)):
26 |             self.fcs.append(nn.Linear(self.d,channel))
27 |         self.softmax=nn.Softmax(dim=1)
28 | 
29 | 
30 |     def init_weights(self):
31 |         for m in self.modules():
32 |             if isinstance(m, nn.Conv2d):
33 |                 init.kaiming_normal_(m.weight, mode='fan_out')
34 |                 if m.bias is not None:
35 |                     init.constant_(m.bias, 0)
36 |             elif isinstance(m, nn.BatchNorm2d):
37 |                 init.constant_(m.weight, 1)
38 |                 init.constant_(m.bias, 0)
39 |             elif isinstance(m, nn.Linear):
40 |                 init.normal_(m.weight, std=0.001)
41 |                 if m.bias is not None:
42 |                     init.constant_(m.bias, 0)
43 | 
44 |     def forward(self, x):
45 |         bs, c, _, _ = x.size()
46 |         conv_outs=[]
47 |         ### split
48 |         for conv in self.convs:
49 |             conv_outs.append(conv(x))
50 |         feats=torch.stack(conv_outs,0)#k,bs,channel,h,w
51 | 
52 |         ### fuse
53 |         U=sum(conv_outs) #bs,c,h,w
54 | 
55 |         ### reduction channel
56 |
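# NOTE: the two mean(-1) calls below are global average pooling over H and W -- the SKNet
# "squeeze" step that turns the fused map U into the per-channel descriptor S of shape (bs, c).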
S=U.mean(-1).mean(-1) #bs,c 57 | Z=self.fc(S) #bs,d 58 | 59 | ### calculate attention weight 60 | weights=[] 61 | for fc in self.fcs: 62 | weight=fc(Z) 63 | weights.append(weight.view(bs,c,1,1)) #bs,channel 64 | attention_weughts=torch.stack(weights,0)#k,bs,channel,1,1 65 | attention_weughts=self.softmax(attention_weughts)#k,bs,channel,1,1 66 | 67 | ### fuse 68 | V=(attention_weughts*feats).sum(0) 69 | return V 70 | 71 | 72 | 73 | 74 | 75 | 76 | if __name__ == '__main__': 77 | input=torch.randn(50,512,7,7) 78 | se = SKAttention(channel=512,reduction=8) 79 | output=se(input) 80 | print(output.shape) 81 | 82 | -------------------------------------------------------------------------------- /preprocess/merge_ec_file.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from tqdm import tqdm 4 | import copy 5 | from scipy.special import softmax 6 | project_dir = './ODQA' 7 | basic_dir = f'{project_dir}/data/evidence_chain_data/retrieval-based/scored_chains/' 8 | #weighted 9 | # gt_ec_file = f'{project_dir}/data/preprocessed_data/evidence_chain/ground-truth-based/scored_chains/train_gttb_evidence_chain_weighted_scores.json' 10 | gt_ec_file = f'{project_dir}/data/evidence_chain_data/ground-truth-based/scored_chains/dev_roberta_large_scored_ec_addtab_weighted_multineg.json' 11 | #original 12 | 13 | file_names = ['dev_evidence_chain_weighted_scores.json'] 14 | all_data = [] 15 | gt_ec_data = json.load(open(gt_ec_file,'r',encoding='utf8')) 16 | error_count = 0 17 | all_cnt = 0 18 | 19 | for fid,file in enumerate(file_names): 20 | start = 0 #dev 21 | # start = int(len(gt_ec_data) / 4) * fid 22 | end = len(gt_ec_data)#dev 23 | # end = int(len(gt_ec_data) / 4) * (fid + 1) if fid != (3) else len(gt_ec_data) 24 | data = json.load(open(os.path.join(basic_dir,file),'r',encoding='utf8')) 25 | output_data = [] 26 | for idx, item in tqdm(enumerate(data)): 27 | assert (item['question_id'] == gt_ec_data[start + idx]['question_id']) 28 | output_data.append(copy.deepcopy(item)) 29 | output_data[-1]['positive_table_blocks'] = copy.deepcopy(gt_ec_data[start + idx]['positive_table_blocks']) 30 | tmp_retrived_blocks = copy.deepcopy(item['retrieved_tbs'][:15]) 31 | for tbid,block in enumerate(tmp_retrived_blocks): 32 | # assert(block['table_id']==gt_ec_data[start+idx]['retrieved_tbs'][tbid]['table_id'] and block['row_id']==gt_ec_data[start+idx]['retrieved_tbs'][tbid]['row_id']) 33 | all_cnt+=1 34 | if block['candidate_evidence_chains']: 35 | ranked_ec = sorted(block['candidate_evidence_chains'], key=lambda k: softmax(k['score'])[1], reverse=True) 36 | if len(ranked_ec)>=3: 37 | selected = copy.deepcopy(ranked_ec[:3]) 38 | else: 39 | selected = copy.deepcopy(ranked_ec) 40 | del tmp_retrived_blocks[tbid]['candidate_evidence_chains'] 41 | tmp_retrived_blocks[tbid]['candidate_evidence_chains'] = selected 42 | else: 43 | error_count+=1 44 | output_data[-1]['retrieved_tbs'] = tmp_retrived_blocks 45 | assert(len(output_data[-1]['retrieved_tbs'])==15) 46 | 47 | all_data.extend(output_data) 48 | print('error rate: {}'.format(error_count/all_cnt)) 49 | with open(os.path.join(basic_dir,f'{project_dir}/data/qa_with_evidence_chain/dev_ranked_evidence_chain_for_qa_weighted.json'),'w',encoding='utf8') as outf: 50 | json.dump(all_data,outf) 51 | 52 | 53 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/CBAM.py: -------------------------------------------------------------------------------- 1 | import 
numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | 8 | class ChannelAttention(nn.Module): 9 | def __init__(self,channel,reduction=16): 10 | super().__init__() 11 | self.maxpool=nn.AdaptiveMaxPool2d(1) 12 | self.avgpool=nn.AdaptiveAvgPool2d(1) 13 | self.se=nn.Sequential( 14 | nn.Conv2d(channel,channel//reduction,1,bias=False), 15 | nn.ReLU(), 16 | nn.Conv2d(channel//reduction,channel,1,bias=False) 17 | ) 18 | self.sigmoid=nn.Sigmoid() 19 | 20 | def forward(self, x) : 21 | max_result=self.maxpool(x) 22 | avg_result=self.avgpool(x) 23 | max_out=self.se(max_result) 24 | avg_out=self.se(avg_result) 25 | output=self.sigmoid(max_out+avg_out) 26 | return output 27 | 28 | class SpatialAttention(nn.Module): 29 | def __init__(self,kernel_size=7): 30 | super().__init__() 31 | self.conv=nn.Conv2d(2,1,kernel_size=kernel_size,padding=kernel_size//2) 32 | self.sigmoid=nn.Sigmoid() 33 | 34 | def forward(self, x) : 35 | max_result,_=torch.max(x,dim=1,keepdim=True) 36 | avg_result=torch.mean(x,dim=1,keepdim=True) 37 | result=torch.cat([max_result,avg_result],1) 38 | output=self.conv(result) 39 | output=self.sigmoid(output) 40 | return output 41 | 42 | 43 | 44 | class CBAMBlock(nn.Module): 45 | 46 | def __init__(self, channel=512,reduction=16,kernel_size=49): 47 | super().__init__() 48 | self.ca=ChannelAttention(channel=channel,reduction=reduction) 49 | self.sa=SpatialAttention(kernel_size=kernel_size) 50 | 51 | 52 | def init_weights(self): 53 | for m in self.modules(): 54 | if isinstance(m, nn.Conv2d): 55 | init.kaiming_normal_(m.weight, mode='fan_out') 56 | if m.bias is not None: 57 | init.constant_(m.bias, 0) 58 | elif isinstance(m, nn.BatchNorm2d): 59 | init.constant_(m.weight, 1) 60 | init.constant_(m.bias, 0) 61 | elif isinstance(m, nn.Linear): 62 | init.normal_(m.weight, std=0.001) 63 | if m.bias is not None: 64 | init.constant_(m.bias, 0) 65 | 66 | def forward(self, x): 67 | b, c, _, _ = x.size() 68 | residual=x 69 | out=x*self.ca(x) 70 | out=out*self.sa(out) 71 | return out+residual 72 | 73 | 74 | if __name__ == '__main__': 75 | input=torch.randn(50,512,7,7) 76 | kernel_size=input.shape[2] 77 | cbam = CBAMBlock(channel=512,reduction=16,kernel_size=kernel_size) 78 | output=cbam(input) 79 | print(output.shape) 80 | 81 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/ShuffleAttention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | from torch.nn.parameter import Parameter 6 | 7 | 8 | class ShuffleAttention(nn.Module): 9 | 10 | def __init__(self, channel=512,reduction=16,G=8): 11 | super().__init__() 12 | self.G=G 13 | self.channel=channel 14 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 15 | self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G)) 16 | self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1)) 17 | self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1)) 18 | self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1)) 19 | self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1)) 20 | self.sigmoid=nn.Sigmoid() 21 | 22 | 23 | def init_weights(self): 24 | for m in self.modules(): 25 | if isinstance(m, nn.Conv2d): 26 | init.kaiming_normal_(m.weight, mode='fan_out') 27 | if m.bias is not None: 28 | init.constant_(m.bias, 0) 29 | elif isinstance(m, nn.BatchNorm2d): 30 | init.constant_(m.weight, 1) 31 | 
init.constant_(m.bias, 0) 32 | elif isinstance(m, nn.Linear): 33 | init.normal_(m.weight, std=0.001) 34 | if m.bias is not None: 35 | init.constant_(m.bias, 0) 36 | 37 | 38 | @staticmethod 39 | def channel_shuffle(x, groups): 40 | b, c, h, w = x.shape 41 | x = x.reshape(b, groups, -1, h, w) 42 | x = x.permute(0, 2, 1, 3, 4) 43 | 44 | # flatten 45 | x = x.reshape(b, -1, h, w) 46 | 47 | return x 48 | 49 | def forward(self, x): 50 | b, c, h, w = x.size() 51 | #group into subfeatures 52 | x=x.view(b*self.G,-1,h,w) #bs*G,c//G,h,w 53 | 54 | #channel_split 55 | x_0,x_1=x.chunk(2,dim=1) #bs*G,c//(2*G),h,w 56 | 57 | #channel attention 58 | x_channel=self.avg_pool(x_0) #bs*G,c//(2*G),1,1 59 | x_channel=self.cweight*x_channel+self.cweight #bs*G,c//(2*G),1,1 60 | x_channel=x_0*self.sigmoid(x_channel) 61 | 62 | #spatial attention 63 | x_spatial=self.gn(x_1) #bs*G,c//(2*G),h,w 64 | x_spatial=self.sweight*x_spatial+self.sbias #bs*G,c//(2*G),h,w 65 | x_spatial=x_1*self.sigmoid(x_spatial) #bs*G,c//(2*G),h,w 66 | 67 | # concatenate along channel axis 68 | out=torch.cat([x_channel,x_spatial],dim=1) #bs*G,c//G,h,w 69 | out=out.contiguous().view(b,-1,h,w) 70 | 71 | # channel shuffle 72 | out = self.channel_shuffle(out, 2) 73 | return out 74 | 75 | 76 | if __name__ == '__main__': 77 | input=torch.randn(50,512,7,7) 78 | se = ShuffleAttention(channel=512,G=8) 79 | output=se(input) 80 | print(output.shape) 81 | 82 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/CoAtNet.py: -------------------------------------------------------------------------------- 1 | from torch import nn, sqrt 2 | import torch 3 | import sys 4 | from math import sqrt 5 | sys.path.append('.') 6 | from conv.MBConv import MBConvBlock 7 | from attention.SelfAttention import ScaledDotProductAttention 8 | 9 | class CoAtNet(nn.Module): 10 | def __init__(self,in_ch,image_size,out_chs=[64,96,192,384,768]): 11 | super().__init__() 12 | self.out_chs=out_chs 13 | self.maxpool2d=nn.MaxPool2d(kernel_size=2,stride=2) 14 | self.maxpool1d = nn.MaxPool1d(kernel_size=2, stride=2) 15 | 16 | self.s0=nn.Sequential( 17 | nn.Conv2d(in_ch,in_ch,kernel_size=3,padding=1), 18 | nn.ReLU(), 19 | nn.Conv2d(in_ch,in_ch,kernel_size=3,padding=1) 20 | ) 21 | self.mlp0=nn.Sequential( 22 | nn.Conv2d(in_ch,out_chs[0],kernel_size=1), 23 | nn.ReLU(), 24 | nn.Conv2d(out_chs[0],out_chs[0],kernel_size=1) 25 | ) 26 | 27 | self.s1=MBConvBlock(ksize=3,input_filters=out_chs[0],output_filters=out_chs[0],image_size=image_size//2) 28 | self.mlp1=nn.Sequential( 29 | nn.Conv2d(out_chs[0],out_chs[1],kernel_size=1), 30 | nn.ReLU(), 31 | nn.Conv2d(out_chs[1],out_chs[1],kernel_size=1) 32 | ) 33 | 34 | self.s2=MBConvBlock(ksize=3,input_filters=out_chs[1],output_filters=out_chs[1],image_size=image_size//4) 35 | self.mlp2=nn.Sequential( 36 | nn.Conv2d(out_chs[1],out_chs[2],kernel_size=1), 37 | nn.ReLU(), 38 | nn.Conv2d(out_chs[2],out_chs[2],kernel_size=1) 39 | ) 40 | 41 | self.s3=ScaledDotProductAttention(out_chs[2],out_chs[2]//8,out_chs[2]//8,8) 42 | self.mlp3=nn.Sequential( 43 | nn.Linear(out_chs[2],out_chs[3]), 44 | nn.ReLU(), 45 | nn.Linear(out_chs[3],out_chs[3]) 46 | ) 47 | 48 | self.s4=ScaledDotProductAttention(out_chs[3],out_chs[3]//8,out_chs[3]//8,8) 49 | self.mlp4=nn.Sequential( 50 | nn.Linear(out_chs[3],out_chs[4]), 51 | nn.ReLU(), 52 | nn.Linear(out_chs[4],out_chs[4]) 53 | ) 54 | 55 | 56 | def forward(self, x) : 57 | B,C,H,W=x.shape 58 | #stage0 59 | y=self.mlp0(self.s0(x)) 60 | y=self.maxpool2d(y) 61 | #stage1 62 | 
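# NOTE: stage 0 above is a plain convolutional stem; stages 1-2 below use MBConv blocks, each
# followed by a pointwise-conv MLP and 2x max-pooling, while stages 3-4 flatten the map to
# (B, N, C) and switch to scaled dot-product self-attention, i.e. the
# conv-conv-transformer-transformer layout of CoAtNet.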
y=self.mlp1(self.s1(y)) 63 | y=self.maxpool2d(y) 64 | #stage2 65 | y=self.mlp2(self.s2(y)) 66 | y=self.maxpool2d(y) 67 | #stage3 68 | y=y.reshape(B,self.out_chs[2],-1).permute(0,2,1) #B,N,C 69 | y=self.mlp3(self.s3(y,y,y)) 70 | y=self.maxpool1d(y.permute(0,2,1)).permute(0,2,1) 71 | #stage4 72 | y=self.mlp4(self.s4(y,y,y)) 73 | y=self.maxpool1d(y.permute(0,2,1)) 74 | N=y.shape[-1] 75 | y=y.reshape(B,self.out_chs[4],int(sqrt(N)),int(sqrt(N))) 76 | 77 | return y 78 | 79 | if __name__ == '__main__': 80 | x=torch.randn(1,3,224,224) 81 | coatnet=CoAtNet(3,224) 82 | y=coatnet(x) 83 | print(y.shape) 84 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/SimplifiedSelfAttention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | 8 | class SimplifiedScaledDotProductAttention(nn.Module): 9 | ''' 10 | Scaled dot-product attention 11 | ''' 12 | 13 | def __init__(self, d_model, h,dropout=.1): 14 | ''' 15 | :param d_model: Output dimensionality of the model 16 | :param d_k: Dimensionality of queries and keys 17 | :param d_v: Dimensionality of values 18 | :param h: Number of heads 19 | ''' 20 | super(SimplifiedScaledDotProductAttention, self).__init__() 21 | 22 | self.d_model = d_model 23 | self.d_k = d_model//h 24 | self.d_v = d_model//h 25 | self.h = h 26 | 27 | self.fc_o = nn.Linear(h * self.d_v, d_model) 28 | self.dropout=nn.Dropout(dropout) 29 | 30 | 31 | 32 | self.init_weights() 33 | 34 | 35 | def init_weights(self): 36 | for m in self.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | init.kaiming_normal_(m.weight, mode='fan_out') 39 | if m.bias is not None: 40 | init.constant_(m.bias, 0) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | init.constant_(m.weight, 1) 43 | init.constant_(m.bias, 0) 44 | elif isinstance(m, nn.Linear): 45 | init.normal_(m.weight, std=0.001) 46 | if m.bias is not None: 47 | init.constant_(m.bias, 0) 48 | 49 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 50 | ''' 51 | Computes 52 | :param queries: Queries (b_s, nq, d_model) 53 | :param keys: Keys (b_s, nk, d_model) 54 | :param values: Values (b_s, nk, d_model) 55 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking. 56 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk). 
57 | :return: 58 | ''' 59 | b_s, nq = queries.shape[:2] 60 | nk = keys.shape[1] 61 | 62 | q = queries.view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) 63 | k = keys.view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 64 | v = values.view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 65 | 66 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) 67 | if attention_weights is not None: 68 | att = att * attention_weights 69 | if attention_mask is not None: 70 | att = att.masked_fill(attention_mask, -np.inf) 71 | att = torch.softmax(att, -1) 72 | att=self.dropout(att) 73 | 74 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) 75 | out = self.fc_o(out) # (b_s, nq, d_model) 76 | return out 77 | 78 | 79 | if __name__ == '__main__': 80 | input=torch.randn(50,49,512) 81 | ssa = SimplifiedScaledDotProductAttention(d_model=512, h=8) 82 | output=ssa(input,input,input) 83 | print(output.shape) 84 | 85 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/SelfAttention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | 8 | class ScaledDotProductAttention(nn.Module): 9 | ''' 10 | Scaled dot-product attention 11 | ''' 12 | 13 | def __init__(self, d_model, d_k, d_v, h,dropout=.1): 14 | ''' 15 | :param d_model: Output dimensionality of the model 16 | :param d_k: Dimensionality of queries and keys 17 | :param d_v: Dimensionality of values 18 | :param h: Number of heads 19 | ''' 20 | super(ScaledDotProductAttention, self).__init__() 21 | self.fc_q = nn.Linear(d_model, h * d_k) 22 | self.fc_k = nn.Linear(d_model, h * d_k) 23 | self.fc_v = nn.Linear(d_model, h * d_v) 24 | self.fc_o = nn.Linear(h * d_v, d_model) 25 | self.dropout=nn.Dropout(dropout) 26 | 27 | self.d_model = d_model 28 | self.d_k = d_k 29 | self.d_v = d_v 30 | self.h = h 31 | 32 | self.init_weights() 33 | 34 | 35 | def init_weights(self): 36 | for m in self.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | init.kaiming_normal_(m.weight, mode='fan_out') 39 | if m.bias is not None: 40 | init.constant_(m.bias, 0) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | init.constant_(m.weight, 1) 43 | init.constant_(m.bias, 0) 44 | elif isinstance(m, nn.Linear): 45 | init.normal_(m.weight, std=0.001) 46 | if m.bias is not None: 47 | init.constant_(m.bias, 0) 48 | 49 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 50 | ''' 51 | Computes 52 | :param queries: Queries (b_s, nq, d_model) 53 | :param keys: Keys (b_s, nk, d_model) 54 | :param values: Values (b_s, nk, d_model) 55 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking. 56 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk). 
57 | :return: 58 | ''' 59 | b_s, nq = queries.shape[:2] 60 | nk = keys.shape[1] 61 | 62 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) 63 | k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 64 | v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 65 | 66 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) 67 | if attention_weights is not None: 68 | att = att * attention_weights 69 | if attention_mask is not None: 70 | # att = att.masked_fill(attention_mask, -np.inf) 71 | att = att + (1-attention_mask)*(-1e6) 72 | att = torch.softmax(att, -1) 73 | att=self.dropout(att) 74 | 75 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) 76 | out = self.fc_o(out) # (b_s, nq, d_model) 77 | return out 78 | 79 | 80 | if __name__ == '__main__': 81 | input=torch.randn(50,49,512) 82 | sa = ScaledDotProductAttention(d_model=512, d_k=512, d_v=512, h=8) 83 | output=sa(input,input,input) 84 | print(output.shape) 85 | 86 | -------------------------------------------------------------------------------- /retriever/retrieval/___tr_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import json 3 | import pickle 4 | import random 5 | import torch 6 | 7 | from .data_utils import collate_tokens 8 | 9 | 10 | class TRDataset(Dataset): 11 | 12 | def __init__(self, 13 | tokenizer, 14 | data_path, 15 | max_q_len, 16 | max_q_sp_len, 17 | max_c_len, 18 | train=False, 19 | ): 20 | super().__init__() 21 | self.tokenizer = tokenizer 22 | self.max_q_len = max_q_len 23 | self.max_c_len = max_c_len 24 | self.max_q_sp_len = max_q_sp_len 25 | self.train = train 26 | print(f"Loading data from {data_path}") 27 | self.data = [json.loads(line) for line in open(data_path).readlines()] 28 | with open(data_path, 'rb') as f: 29 | self.data = pickle.load(f) 30 | print(f"Total sample count {len(self.data)}") 31 | 32 | def encode_tb(self, passages, table, max_len): 33 | return self.tokenizer(table=table, queries=passages, max_length=max_len, return_tensors="pt") 34 | 35 | def __getitem__(self, index): 36 | sample = self.data[index] 37 | question = sample['question'] 38 | if question.endswith("?"): 39 | question = question[:-1] 40 | 41 | table = sample['table'] 42 | passages = ' '.join(sample['passages']) 43 | tb_codes = self.encode_tb(passages, table, self.max_c_len) 44 | 45 | q_codes = self.tokenizer.encode_plus(question, max_length=self.max_q_len, return_tensors="pt") 46 | label = torch.tensor(sample['label']) 47 | return { 48 | "q_codes": q_codes, 49 | "tb_codes": tb_codes, 50 | "label": label, 51 | } 52 | 53 | def __len__(self): 54 | return len(self.data) 55 | 56 | 57 | def tb_collate(samples, pad_id=0): 58 | if len(samples) == 0: 59 | return {} 60 | 61 | batch = { 62 | 'q_input_ids': collate_tokens([s["q_codes"]["input_ids"].view(-1) for s in samples], 0), 63 | 'q_mask': collate_tokens([s["q_codes"]["attention_mask"].view(-1) for s in samples], 0), 64 | 65 | 'tb_input_ids': collate_tokens([s["tb_codes"]["input_ids"].view(-1) for s in samples], 0), 66 | 'tb_mask':collate_tokens([s["tb_codes"]["attention_mask"].view(-1) for s in samples], 0), 67 | 68 | } 69 | 70 | # if "token_type_ids" in samples[0]["q_codes"]: 71 | # batch.update({ 72 | # 'q_type_ids': collate_tokens([s["q_codes"]["token_type_ids"].view(-1) for s in samples], 0), 73 | # 
'c1_type_ids': collate_tokens([s["start_para_codes"]["token_type_ids"] for s in samples], 0), 74 | # 'c2_type_ids': collate_tokens([s["bridge_para_codes"]["token_type_ids"] for s in samples], 0), 75 | # "q_sp_type_ids": collate_tokens([s["q_sp_codes"]["token_type_ids"].view(-1) for s in samples], 0), 76 | # 'neg1_type_ids': collate_tokens([s["neg_codes_1"]["token_type_ids"] for s in samples], 0), 77 | # 'neg2_type_ids': collate_tokens([s["neg_codes_2"]["token_type_ids"] for s in samples], 0), 78 | # }) 79 | # 80 | # if "sent_ids" in samples[0]["start_para_codes"]: 81 | # batch["c1_sent_target"] = collate_tokens([s["start_para_codes"]["sent_ids"] for s in samples], -1) 82 | # batch["c1_sent_offsets"] = collate_tokens([s["start_para_codes"]["sent_offsets"] for s in samples], 0), 83 | 84 | return batch 85 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/BAM.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | class Flatten(nn.Module): 7 | def forward(self,x): 8 | return x.view(x.shape[0],-1) 9 | 10 | class ChannelAttention(nn.Module): 11 | def __init__(self,channel,reduction=16,num_layers=3): 12 | super().__init__() 13 | self.avgpool=nn.AdaptiveAvgPool2d(1) 14 | gate_channels=[channel] 15 | gate_channels+=[channel//reduction]*num_layers 16 | gate_channels+=[channel] 17 | 18 | 19 | self.ca=nn.Sequential() 20 | self.ca.add_module('flatten',Flatten()) 21 | for i in range(len(gate_channels)-2): 22 | self.ca.add_module('fc%d'%i,nn.Linear(gate_channels[i],gate_channels[i+1])) 23 | self.ca.add_module('bn%d'%i,nn.BatchNorm1d(gate_channels[i+1])) 24 | self.ca.add_module('relu%d'%i,nn.ReLU()) 25 | self.ca.add_module('last_fc',nn.Linear(gate_channels[-2],gate_channels[-1])) 26 | 27 | 28 | def forward(self, x) : 29 | res=self.avgpool(x) 30 | res=self.ca(res) 31 | res=res.unsqueeze(-1).unsqueeze(-1).expand_as(x) 32 | return res 33 | 34 | class SpatialAttention(nn.Module): 35 | def __init__(self,channel,reduction=16,num_layers=3,dia_val=2): 36 | super().__init__() 37 | self.sa=nn.Sequential() 38 | self.sa.add_module('conv_reduce1',nn.Conv2d(kernel_size=1,in_channels=channel,out_channels=channel//reduction)) 39 | self.sa.add_module('bn_reduce1',nn.BatchNorm2d(channel//reduction)) 40 | self.sa.add_module('relu_reduce1',nn.ReLU()) 41 | for i in range(num_layers): 42 | self.sa.add_module('conv_%d'%i,nn.Conv2d(kernel_size=3,in_channels=channel//reduction,out_channels=channel//reduction,padding=1,dilation=dia_val)) 43 | self.sa.add_module('bn_%d'%i,nn.BatchNorm2d(channel//reduction)) 44 | self.sa.add_module('relu_%d'%i,nn.ReLU()) 45 | self.sa.add_module('last_conv',nn.Conv2d(channel//reduction,1,kernel_size=1)) 46 | 47 | def forward(self, x) : 48 | res=self.sa(x) 49 | res=res.expand_as(x) 50 | return res 51 | 52 | 53 | 54 | 55 | class BAMBlock(nn.Module): 56 | 57 | def __init__(self, channel=512,reduction=16,dia_val=2): 58 | super().__init__() 59 | self.ca=ChannelAttention(channel=channel,reduction=reduction) 60 | self.sa=SpatialAttention(channel=channel,reduction=reduction,dia_val=dia_val) 61 | self.sigmoid=nn.Sigmoid() 62 | 63 | 64 | def init_weights(self): 65 | for m in self.modules(): 66 | if isinstance(m, nn.Conv2d): 67 | init.kaiming_normal_(m.weight, mode='fan_out') 68 | if m.bias is not None: 69 | init.constant_(m.bias, 0) 70 | elif isinstance(m, nn.BatchNorm2d): 71 | init.constant_(m.weight, 1) 72 | 
init.constant_(m.bias, 0) 73 | elif isinstance(m, nn.Linear): 74 | init.normal_(m.weight, std=0.001) 75 | if m.bias is not None: 76 | init.constant_(m.bias, 0) 77 | 78 | def forward(self, x): 79 | b, c, _, _ = x.size() 80 | sa_out=self.sa(x) 81 | ca_out=self.ca(x) 82 | weight=self.sigmoid(sa_out+ca_out) 83 | out=(1+weight)*x 84 | return out 85 | 86 | 87 | if __name__ == '__main__': 88 | input=torch.randn(50,512,7,7) 89 | bam = BAMBlock(channel=512,reduction=16,dia_val=2) 90 | output=bam(input) 91 | print(output.shape) 92 | 93 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from ast import parse 3 | def common_args(): 4 | parser = argparse.ArgumentParser() 5 | 6 | # task 7 | parser.add_argument("--train_file", type=str, 8 | default="../data/nq-with-neg-train.txt") 9 | parser.add_argument("--predict_file", type=str, 10 | default="../data/nq-with-neg-dev.txt") 11 | parser.add_argument("--num_workers", default=30, type=int) 12 | parser.add_argument("--do_train", default=False, 13 | action='store_true', help="Whether to run training.") 14 | parser.add_argument("--do_predict", default=False, 15 | action='store_true', help="Whether to run eval on the dev set.") 16 | parser.add_argument("--basic_data_path", 17 | default='/home/t-wzhong/v-wanzho/ODQA/data/',type=str) 18 | # model 19 | parser.add_argument("--model_name", 20 | default="bert-base-uncased", type=str) 21 | parser.add_argument("--init_checkpoint", type=str, 22 | help="Initial checkpoint (usually from a pre-trained BERT model).", 23 | default="") 24 | parser.add_argument("--max_c_len", default=420, type=int, 25 | help="The maximum total input sequence length after WordPiece tokenization. Sequences " 26 | "longer than this will be truncated, and sequences shorter than this will be padded.") 27 | parser.add_argument("--max_q_len", default=70, type=int, 28 | help="The maximum number of tokens for the question. Questions longer than this will " 29 | "be truncated to this length.") 30 | parser.add_argument("--max_p_len", default=350, type=int, 31 | help="The maximum number of tokens for the question. Questions longer than this will " 32 | "be truncated to this length.") 33 | parser.add_argument("--cell_trim_length", default=20, type=int, 34 | help="The maximum number of tokens for each cell. Cell longer than this will " 35 | "be truncated to this length.") 36 | parser.add_argument('--fp16', action='store_true') 37 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 38 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 
39 | "See details at https://nvidia.github.io/apex/amp.html") 40 | parser.add_argument("--no_cuda", default=False, action='store_true', 41 | help="Whether not to use CUDA when available") 42 | parser.add_argument("--local_rank", type=int, default=-1, 43 | help="local_rank for distributed training on gpus") 44 | parser.add_argument("--max_q_sp_len", default=50, type=int) 45 | parser.add_argument("--sent-level", action="store_true") 46 | parser.add_argument("--rnn-retriever", action="store_true") 47 | parser.add_argument("--predict_batch_size", default=512, 48 | type=int, help="Total batch size for predictions.") 49 | 50 | # multi vector scheme 51 | parser.add_argument("--multi-vector", type=int, default=1) 52 | parser.add_argument("--scheme", type=str, help="how to get the multivector, layerwise or tokenwise", default="none") 53 | parser.add_argument("--no_proj", action="store_true") 54 | parser.add_argument("--shared_encoder", action="store_true") 55 | 56 | # momentum 57 | parser.add_argument("--momentum", action="store_true") 58 | parser.add_argument("--init-retriever", type=str, default="") 59 | parser.add_argument("--k", type=int, default=38400, help="memory bank size") 60 | parser.add_argument("--m", type=float, default=0.999, help="momentum") 61 | 62 | 63 | # NQ multihop trial 64 | parser.add_argument("--nq-multi", action="store_true", help="train the NQ retrieval model to recover from error cases") 65 | 66 | return parser 67 | 68 | args = common_args() -------------------------------------------------------------------------------- /retriever/retrieval/attention/MUSEAttention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | 8 | class Depth_Pointwise_Conv1d(nn.Module): 9 | def __init__(self,in_ch,out_ch,k): 10 | super().__init__() 11 | if(k==1): 12 | self.depth_conv=nn.Identity() 13 | else: 14 | self.depth_conv=nn.Conv1d( 15 | in_channels=in_ch, 16 | out_channels=in_ch, 17 | kernel_size=k, 18 | groups=in_ch, 19 | padding=k//2 20 | ) 21 | self.pointwise_conv=nn.Conv1d( 22 | in_channels=in_ch, 23 | out_channels=out_ch, 24 | kernel_size=1, 25 | groups=1 26 | ) 27 | def forward(self,x): 28 | out=self.pointwise_conv(self.depth_conv(x)) 29 | return out 30 | 31 | 32 | 33 | class MUSEAttention(nn.Module): 34 | 35 | def __init__(self, d_model, d_k, d_v, h,dropout=.1): 36 | 37 | 38 | super(MUSEAttention, self).__init__() 39 | self.fc_q = nn.Linear(d_model, h * d_k) 40 | self.fc_k = nn.Linear(d_model, h * d_k) 41 | self.fc_v = nn.Linear(d_model, h * d_v) 42 | self.fc_o = nn.Linear(h * d_v, d_model) 43 | self.dropout=nn.Dropout(dropout) 44 | 45 | self.conv1=Depth_Pointwise_Conv1d(h * d_v, d_model,1) 46 | self.conv3=Depth_Pointwise_Conv1d(h * d_v, d_model,3) 47 | self.conv5=Depth_Pointwise_Conv1d(h * d_v, d_model,5) 48 | self.dy_paras=nn.Parameter(torch.ones(3)) 49 | self.softmax=nn.Softmax(-1) 50 | 51 | self.d_model = d_model 52 | self.d_k = d_k 53 | self.d_v = d_v 54 | self.h = h 55 | 56 | self.init_weights() 57 | 58 | 59 | def init_weights(self): 60 | for m in self.modules(): 61 | if isinstance(m, nn.Conv2d): 62 | init.kaiming_normal_(m.weight, mode='fan_out') 63 | if m.bias is not None: 64 | init.constant_(m.bias, 0) 65 | elif isinstance(m, nn.BatchNorm2d): 66 | init.constant_(m.weight, 1) 67 | init.constant_(m.bias, 0) 68 | elif isinstance(m, nn.Linear): 69 | init.normal_(m.weight, std=0.001) 70 | if m.bias is not None: 71 | init.constant_(m.bias, 0) 
72 | 73 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 74 | 75 | #Self Attention 76 | b_s, nq = queries.shape[:2] 77 | nk = keys.shape[1] 78 | 79 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) 80 | k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 81 | v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 82 | 83 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) 84 | if attention_weights is not None: 85 | att = att * attention_weights 86 | if attention_mask is not None: 87 | att = att.masked_fill(attention_mask, -np.inf) 88 | att = torch.softmax(att, -1) 89 | att=self.dropout(att) 90 | 91 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) 92 | out = self.fc_o(out) # (b_s, nq, d_model) 93 | 94 | v2=v.permute(0,1,3,2).contiguous().view(b_s,-1,nk) #bs,dim,n 95 | self.dy_paras=nn.Parameter(self.softmax(self.dy_paras)) 96 | out2=self.dy_paras[0]*self.conv1(v2)+self.dy_paras[1]*self.conv3(v2)+self.dy_paras[2]*self.conv5(v2) 97 | out2=out2.permute(0,2,1) #bs.n.dim 98 | 99 | out=out+out2 100 | return out 101 | 102 | 103 | if __name__ == '__main__': 104 | input=torch.randn(50,49,512) 105 | sa = MUSEAttention(d_model=512, d_k=512, d_v=512, h=8) 106 | output=sa(input,input,input) 107 | print(output.shape) 108 | 109 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/EMSA.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | 8 | class EMSA(nn.Module): 9 | 10 | def __init__(self, d_model, d_k, d_v, h,dropout=.1,H=7,W=7,ratio=3,apply_transform=True): 11 | 12 | super(EMSA, self).__init__() 13 | self.H=H 14 | self.W=W 15 | self.fc_q = nn.Linear(d_model, h * d_k) 16 | self.fc_k = nn.Linear(d_model, h * d_k) 17 | self.fc_v = nn.Linear(d_model, h * d_v) 18 | self.fc_o = nn.Linear(h * d_v, d_model) 19 | self.dropout=nn.Dropout(dropout) 20 | 21 | self.ratio=ratio 22 | if(self.ratio>1): 23 | self.sr=nn.Sequential() 24 | self.sr_conv=nn.Conv2d(d_model,d_model,kernel_size=ratio+1,stride=ratio,padding=ratio//2,groups=d_model) 25 | self.sr_ln=nn.LayerNorm(d_model) 26 | 27 | self.apply_transform=apply_transform and h>1 28 | if(self.apply_transform): 29 | self.transform=nn.Sequential() 30 | self.transform.add_module('conv',nn.Conv2d(h,h,kernel_size=1,stride=1)) 31 | self.transform.add_module('softmax',nn.Softmax(-1)) 32 | self.transform.add_module('in',nn.InstanceNorm2d(h)) 33 | 34 | self.d_model = d_model 35 | self.d_k = d_k 36 | self.d_v = d_v 37 | self.h = h 38 | 39 | self.init_weights() 40 | 41 | 42 | def init_weights(self): 43 | for m in self.modules(): 44 | if isinstance(m, nn.Conv2d): 45 | init.kaiming_normal_(m.weight, mode='fan_out') 46 | if m.bias is not None: 47 | init.constant_(m.bias, 0) 48 | elif isinstance(m, nn.BatchNorm2d): 49 | init.constant_(m.weight, 1) 50 | init.constant_(m.bias, 0) 51 | elif isinstance(m, nn.Linear): 52 | init.normal_(m.weight, std=0.001) 53 | if m.bias is not None: 54 | init.constant_(m.bias, 0) 55 | 56 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 57 | 58 | b_s, nq ,c = queries.shape 59 | nk = keys.shape[1] 60 | 61 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) 
# (b_s, h, nq, d_k) 62 | 63 | if(self.ratio>1): 64 | x=queries.permute(0,2,1).view(b_s,c,self.H,self.W) #bs,c,H,W 65 | x=self.sr_conv(x) #bs,c,h,w 66 | x=x.contiguous().view(b_s,c,-1).permute(0,2,1) #bs,n',c 67 | x=self.sr_ln(x) 68 | k = self.fc_k(x).view(b_s, -1, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, n') 69 | v = self.fc_v(x).view(b_s, -1, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, n', d_v) 70 | else: 71 | k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 72 | v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 73 | 74 | if(self.apply_transform): 75 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, n') 76 | att = self.transform(att) # (b_s, h, nq, n') 77 | else: 78 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, n') 79 | att = torch.softmax(att, -1) # (b_s, h, nq, n') 80 | 81 | 82 | if attention_weights is not None: 83 | att = att * attention_weights 84 | if attention_mask is not None: 85 | att = att.masked_fill(attention_mask, -np.inf) 86 | 87 | att=self.dropout(att) 88 | 89 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) 90 | out = self.fc_o(out) # (b_s, nq, d_model) 91 | return out 92 | 93 | 94 | if __name__ == '__main__': 95 | input=torch.randn(50,64,512) 96 | emsa = EMSA(d_model=512, d_k=512, d_v=512, h=8,H=8,W=8,ratio=2,apply_transform=True) 97 | output=emsa(input,input,input) 98 | print(output.shape) 99 | 100 | -------------------------------------------------------------------------------- /retriever/retrieval/drqa/drqa_tokenizers/regexp_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Regex based tokenizer that emulates the Stanford/NLTK PTB tokenizers. 8 | 9 | However it is purely in Python, supports robust untokenization, unicode, 10 | and requires minimal dependencies. 
11 | """ 12 | 13 | import regex 14 | import logging 15 | from .tokenizer import Tokens, Tokenizer 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class RegexpTokenizer(Tokenizer): 21 | DIGIT = r'\p{Nd}+([:\.\,]\p{Nd}+)*' 22 | TITLE = (r'(dr|esq|hon|jr|mr|mrs|ms|prof|rev|sr|st|rt|messrs|mmes|msgr)' 23 | r'\.(?=\p{Z})') 24 | ABBRV = r'([\p{L}]\.){2,}(?=\p{Z}|$)' 25 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]++' 26 | HYPHEN = r'{A}([-\u058A\u2010\u2011]{A})+'.format(A=ALPHA_NUM) 27 | NEGATION = r"((?!n't)[\p{L}\p{N}\p{M}])++(?=n't)|n't" 28 | CONTRACTION1 = r"can(?=not\b)" 29 | CONTRACTION2 = r"'([tsdm]|re|ll|ve)\b" 30 | START_DQUOTE = r'(?<=[\p{Z}\(\[{<]|^)(``|["\u0093\u201C\u00AB])(?!\p{Z})' 31 | START_SQUOTE = r'(?<=[\p{Z}\(\[{<]|^)[\'\u0091\u2018\u201B\u2039](?!\p{Z})' 32 | END_DQUOTE = r'(?%s)|(?P%s)|(?P<abbr>%s)|(?P<neg>%s)|(?P<hyph>%s)|' 47 | '(?P<contr1>%s)|(?P<alphanum>%s)|(?P<contr2>%s)|(?P<sdquote>%s)|' 48 | '(?P<edquote>%s)|(?P<ssquote>%s)|(?P<esquote>%s)|(?P<dash>%s)|' 49 | '(?<ellipses>%s)|(?P<punct>%s)|(?P<nonws>%s)' % 50 | (self.DIGIT, self.TITLE, self.ABBRV, self.NEGATION, self.HYPHEN, 51 | self.CONTRACTION1, self.ALPHA_NUM, self.CONTRACTION2, 52 | self.START_DQUOTE, self.END_DQUOTE, self.START_SQUOTE, 53 | self.END_SQUOTE, self.DASH, self.ELLIPSES, self.PUNCT, 54 | self.NON_WS), 55 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 56 | ) 57 | if len(kwargs.get('annotators', {})) > 0: 58 | logger.warning('%s only tokenizes! Skipping annotators: %s' % 59 | (type(self).__name__, kwargs.get('annotators'))) 60 | self.annotators = set() 61 | self.substitutions = kwargs.get('substitutions', True) 62 | 63 | def tokenize(self, text): 64 | data = [] 65 | matches = [m for m in self._regexp.finditer(text)] 66 | for i in range(len(matches)): 67 | # Get text 68 | token = matches[i].group() 69 | 70 | # Make normalizations for special token types 71 | if self.substitutions: 72 | groups = matches[i].groupdict() 73 | if groups['sdquote']: 74 | token = "``" 75 | elif groups['edquote']: 76 | token = "''" 77 | elif groups['ssquote']: 78 | token = "`" 79 | elif groups['esquote']: 80 | token = "'" 81 | elif groups['dash']: 82 | token = '--' 83 | elif groups['ellipses']: 84 | token = '...' 85 | 86 | # Get whitespace 87 | span = matches[i].span() 88 | start_ws = span[0] 89 | if i + 1 < len(matches): 90 | end_ws = matches[i + 1].span()[0] 91 | else: 92 | end_ws = span[1] 93 | 94 | # Format data 95 | data.append(( 96 | token, 97 | text[start_ws: end_ws], 98 | span, 99 | )) 100 | return Tokens(data, self.annotators) 101 | -------------------------------------------------------------------------------- /retriever/retrieval/drqa/retriever/BM25_doc_ranker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Rank documents with TF-IDF scores""" 8 | 9 | import logging 10 | import numpy as np 11 | import scipy.sparse as sp 12 | 13 | from multiprocessing.pool import ThreadPool 14 | from functools import partial 15 | import os 16 | import sys 17 | 18 | current_path = os.path.dirname(os.path.abspath(__file__)) 19 | 20 | sys.path.append(os.path.dirname(current_path)) 21 | 22 | from . 
import utils 23 | import drqa_tokenizers 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | class BM25DocRanker(object): 29 | """Loads a pre-weighted inverted index of token/document terms. 30 | Scores new queries by taking sparse dot products. 31 | """ 32 | 33 | def __init__(self, tfidf_path, strict=True): 34 | """ 35 | Args: 36 | tfidf_path: path to saved model file 37 | strict: fail on empty queries or continue (and return empty result) 38 | """ 39 | # Load from disk 40 | tfidf_path = tfidf_path 41 | logger.info('Loading %s' % tfidf_path) 42 | matrix, metadata = utils.load_sparse_csr(tfidf_path) 43 | self.doc_mat = matrix 44 | self.ngrams = metadata['ngram'] 45 | self.hash_size = metadata['hash_size'] 46 | self.tokenizer = drqa_tokenizers.get_class(metadata['tokenizer'])() 47 | self.doc_freqs = metadata['doc_freqs'].squeeze() 48 | self.doc_dict = metadata['doc_dict'] 49 | self.num_docs = len(self.doc_dict[0]) 50 | self.strict = strict 51 | 52 | def get_doc_index(self, doc_id): 53 | """Convert doc_id --> doc_index""" 54 | return self.doc_dict[0][doc_id] 55 | 56 | def get_doc_id(self, doc_index): 57 | """Convert doc_index --> doc_id""" 58 | return self.doc_dict[1][doc_index] 59 | 60 | def closest_docs(self, query, k=1): 61 | """Closest docs by dot product between query and documents 62 | in tfidf weighted word vector space. 63 | """ 64 | spvec = self.text2spvec(query) 65 | res = spvec * self.doc_mat 66 | 67 | if len(res.data) <= k: 68 | o_sort = np.argsort(-res.data) 69 | else: 70 | o = np.argpartition(-res.data, k)[0:k] 71 | o_sort = o[np.argsort(-res.data[o])] 72 | 73 | doc_scores = res.data[o_sort] 74 | doc_ids = [self.get_doc_id(i) for i in res.indices[o_sort]] 75 | return doc_ids, doc_scores 76 | 77 | def batch_closest_docs(self, queries, k=1, num_workers=None): 78 | """Process a batch of closest_docs requests multithreaded. 79 | Note: we can use plain threads here as scipy is outside of the GIL. 80 | """ 81 | with ThreadPool(num_workers) as threads: 82 | closest_docs = partial(self.closest_docs, k=k) 83 | results = threads.map(closest_docs, queries) 84 | return results 85 | 86 | def parse(self, query): 87 | """Parse the query into tokens (either ngrams or tokens).""" 88 | tokens = self.tokenizer.tokenize(query) 89 | return tokens.ngrams(n=self.ngrams, uncased=True, 90 | filter_fn=utils.filter_ngram) 91 | 92 | def text2spvec(self, query): 93 | """Create a sparse tfidf-weighted word vector from query. 
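        Note: unlike the formula kept below from the TF-IDF ranker, this BM25 variant uses a
        binary term frequency (see the `(wids_counts > 0)` line in the "Count TF" step), so each
        distinct query term contributes only its IDF weight.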
94 | 95 | tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5)) 96 | """ 97 | # Get hashed ngrams 98 | words = self.parse(utils.normalize(query)) 99 | wids = [utils.hash(w, self.hash_size) for w in words] 100 | 101 | if len(wids) == 0: 102 | if self.strict: 103 | raise RuntimeError('No valid word in: %s' % query) 104 | else: 105 | logger.warning('No valid word in: %s' % query) 106 | return sp.csr_matrix((1, self.hash_size)) 107 | 108 | # Count TF 109 | wids_unique, wids_counts = np.unique(wids, return_counts=True) 110 | tfs = (wids_counts > 0).astype(int) 111 | 112 | # Count IDF 113 | Ns = self.doc_freqs[wids_unique] 114 | idfs = np.log((self.num_docs - Ns + 0.5) / (Ns + 0.5)) 115 | idfs[idfs < 0] = 0 116 | 117 | # TF-IDF 118 | data = np.multiply(tfs, idfs) 119 | 120 | # One row, sparse csr matrix 121 | indptr = np.array([0, len(wids_unique)]) 122 | spvec = sp.csr_matrix( 123 | (data, wids_unique, indptr), shape=(1, self.hash_size) 124 | ) 125 | 126 | return spvec 127 | -------------------------------------------------------------------------------- /retriever/retrieval/drqa/retriever/tfidf_doc_ranker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Rank documents with TF-IDF scores""" 8 | 9 | import logging 10 | import numpy as np 11 | import scipy.sparse as sp 12 | 13 | from multiprocessing.pool import ThreadPool 14 | from functools import partial 15 | import os 16 | import sys 17 | 18 | current_path = os.path.dirname(os.path.abspath(__file__)) 19 | 20 | sys.path.append(os.path.dirname(current_path)) 21 | 22 | from . import utils 23 | import drqa_tokenizers 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | class TfidfDocRanker(object): 29 | """Loads a pre-weighted inverted index of token/document terms. 30 | Scores new queries by taking sparse dot products. 31 | """ 32 | 33 | def __init__(self, tfidf_path=None, strict=True): 34 | """ 35 | Args: 36 | tfidf_path: path to saved model file 37 | strict: fail on empty queries or continue (and return empty result) 38 | """ 39 | # Load from disk 40 | tfidf_path = tfidf_path or DEFAULTS['tfidf_path'] 41 | logger.info('Loading %s' % tfidf_path) 42 | matrix, metadata = utils.load_sparse_csr(tfidf_path) 43 | self.doc_mat = matrix 44 | self.ngrams = metadata['ngram'] 45 | self.hash_size = metadata['hash_size'] 46 | self.tokenizer = drqa_tokenizers.get_class(metadata['tokenizer'])() 47 | self.doc_freqs = metadata['doc_freqs'].squeeze() 48 | self.doc_dict = metadata['doc_dict'] 49 | self.num_docs = len(self.doc_dict[0]) 50 | self.strict = strict 51 | 52 | def get_doc_index(self, doc_id): 53 | """Convert doc_id --> doc_index""" 54 | return self.doc_dict[0][doc_id] 55 | 56 | def get_doc_id(self, doc_index): 57 | """Convert doc_index --> doc_id""" 58 | return self.doc_dict[1][doc_index] 59 | 60 | def closest_docs(self, query, k=1): 61 | """Closest docs by dot product between query and documents 62 | in tfidf weighted word vector space. 
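        Returns a pair (doc_ids, doc_scores) for the top-k matches, sorted by decreasing score.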
63 | """ 64 | spvec = self.text2spvec(query) 65 | res = spvec * self.doc_mat 66 | 67 | if len(res.data) <= k: 68 | o_sort = np.argsort(-res.data) 69 | else: 70 | o = np.argpartition(-res.data, k)[0:k] 71 | o_sort = o[np.argsort(-res.data[o])] 72 | 73 | doc_scores = res.data[o_sort] 74 | doc_ids = [self.get_doc_id(i) for i in res.indices[o_sort]] 75 | return doc_ids, doc_scores 76 | 77 | def batch_closest_docs(self, queries, k=1, num_workers=None): 78 | """Process a batch of closest_docs requests multithreaded. 79 | Note: we can use plain threads here as scipy is outside of the GIL. 80 | """ 81 | with ThreadPool(num_workers) as threads: 82 | closest_docs = partial(self.closest_docs, k=k) 83 | results = threads.map(closest_docs, queries) 84 | return results 85 | 86 | def parse(self, query): 87 | """Parse the query into tokens (either ngrams or tokens).""" 88 | tokens = self.tokenizer.tokenize(query) 89 | return tokens.ngrams(n=self.ngrams, uncased=True, 90 | filter_fn=utils.filter_ngram) 91 | 92 | def text2spvec(self, query): 93 | """Create a sparse tfidf-weighted word vector from query. 94 | 95 | tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5)) 96 | """ 97 | # Get hashed ngrams 98 | words = self.parse(utils.normalize(query)) 99 | wids = [utils.hash(w, self.hash_size) for w in words] 100 | 101 | if len(wids) == 0: 102 | if self.strict: 103 | raise RuntimeError('No valid word in: %s' % query) 104 | else: 105 | logger.warning('No valid word in: %s' % query) 106 | return sp.csr_matrix((1, self.hash_size)) 107 | 108 | # Count TF 109 | wids_unique, wids_counts = np.unique(wids, return_counts=True) 110 | tfs = np.log1p(wids_counts) 111 | 112 | # Count IDF 113 | Ns = self.doc_freqs[wids_unique] 114 | idfs = np.log((self.num_docs - Ns + 0.5) / (Ns + 0.5)) 115 | idfs[idfs < 0] = 0 116 | 117 | # TF-IDF 118 | data = np.multiply(tfs, idfs) 119 | 120 | # One row, sparse csr matrix 121 | indptr = np.array([0, len(wids_unique)]) 122 | spvec = sp.csr_matrix( 123 | (data, wids_unique, indptr), shape=(1, self.hash_size) 124 | ) 125 | 126 | return spvec 127 | -------------------------------------------------------------------------------- /retriever/retrieval/drqa/retriever/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Various retriever utilities.""" 8 | 9 | import regex 10 | import unicodedata 11 | import numpy as np 12 | import scipy.sparse as sp 13 | from sklearn.utils import murmurhash3_32 14 | 15 | 16 | # ------------------------------------------------------------------------------ 17 | # Sparse matrix saving/loading helpers. 
18 | # ------------------------------------------------------------------------------ 19 | 20 | 21 | def save_sparse_csr(filename, matrix, metadata=None): 22 | data = { 23 | 'data': matrix.data, 24 | 'indices': matrix.indices, 25 | 'indptr': matrix.indptr, 26 | 'shape': matrix.shape, 27 | 'metadata': metadata, 28 | } 29 | np.savez(filename, **data) 30 | 31 | 32 | def load_sparse_csr(filename): 33 | loader = np.load(filename, allow_pickle=True) 34 | matrix = sp.csr_matrix((loader['data'], loader['indices'], 35 | loader['indptr']), shape=loader['shape']) 36 | return matrix, loader['metadata'].item(0) if 'metadata' in loader else None 37 | 38 | 39 | # ------------------------------------------------------------------------------ 40 | # Token hashing. 41 | # ------------------------------------------------------------------------------ 42 | 43 | 44 | def hash(token, num_buckets): 45 | """Unsigned 32 bit murmurhash for feature hashing.""" 46 | return murmurhash3_32(token, positive=True) % num_buckets 47 | 48 | 49 | # ------------------------------------------------------------------------------ 50 | # Text cleaning. 51 | # ------------------------------------------------------------------------------ 52 | 53 | 54 | STOPWORDS = { 55 | 'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your', 56 | 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', 57 | 'her', 'hers', 'herself', 'it', 'its', 'itself', 'they', 'them', 'their', 58 | 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', 59 | 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 60 | 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 61 | 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', 62 | 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through', 63 | 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down', 64 | 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', 65 | 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 66 | 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor', 67 | 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can', 68 | 'will', 'just', 'don', 'should', 'now', 'd', 'll', 'm', 'o', 're', 've', 69 | 'y', 'ain', 'aren', 'couldn', 'didn', 'doesn', 'hadn', 'hasn', 'haven', 70 | 'isn', 'ma', 'mightn', 'mustn', 'needn', 'shan', 'shouldn', 'wasn', 'weren', 71 | 'won', 'wouldn', "'ll", "'re", "'ve", "n't", "'s", "'d", "'m", "''", "``" 72 | } 73 | 74 | 75 | def normalize(text): 76 | """Resolve different type of unicode encodings.""" 77 | return unicodedata.normalize('NFD', text) 78 | 79 | 80 | def filter_word(text): 81 | """Take out english stopwords, punctuation, and compound endings.""" 82 | text = normalize(text) 83 | if regex.match(r'^\p{P}+$', text): 84 | return True 85 | if text.lower() in STOPWORDS: 86 | return True 87 | return False 88 | 89 | 90 | def filter_ngram(gram, mode='any'): 91 | """Decide whether to keep or discard an n-gram. 
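    For example, with mode='any' the bigram ['the', 'cat'] is discarded (returns True) because
    'the' is a stopword, while ['black', 'cat'] is kept (returns False).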
92 | 93 | Args: 94 | gram: list of tokens (length N) 95 | mode: Option to throw out ngram if 96 | 'any': any single token passes filter_word 97 | 'all': all tokens pass filter_word 98 | 'ends': book-ended by filterable tokens 99 | """ 100 | filtered = [filter_word(w) for w in gram] 101 | if mode == 'any': 102 | return any(filtered) 103 | elif mode == 'all': 104 | return all(filtered) 105 | elif mode == 'ends': 106 | return filtered[0] or filtered[-1] 107 | else: 108 | raise ValueError('Invalid mode: %s' % mode) 109 | 110 | def get_field(d, field_list): 111 | """get the subfield associated to a list of elastic fields 112 | E.g. ['file', 'filename'] to d['file']['filename'] 113 | """ 114 | if isinstance(field_list, str): 115 | return d[field_list] 116 | else: 117 | idx = d.copy() 118 | for field in field_list: 119 | idx = idx[field] 120 | return idx 121 | -------------------------------------------------------------------------------- /retriever/retrieval/drqa/drqa_tokenizers/tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Base tokenizer/tokens classes and utilities.""" 8 | 9 | import copy 10 | 11 | 12 | class Tokens(object): 13 | """A class to represent a list of tokenized text.""" 14 | TEXT = 0 15 | TEXT_WS = 1 16 | SPAN = 2 17 | POS = 3 18 | LEMMA = 4 19 | NER = 5 20 | 21 | def __init__(self, data, annotators, opts=None): 22 | self.data = data 23 | self.annotators = annotators 24 | self.opts = opts or {} 25 | 26 | def __len__(self): 27 | """The number of tokens.""" 28 | return len(self.data) 29 | 30 | def slice(self, i=None, j=None): 31 | """Return a view of the list of tokens from [i, j).""" 32 | new_tokens = copy.copy(self) 33 | new_tokens.data = self.data[i: j] 34 | return new_tokens 35 | 36 | def untokenize(self): 37 | """Returns the original text (with whitespace reinserted).""" 38 | return ''.join([t[self.TEXT_WS] for t in self.data]).strip() 39 | 40 | def words(self, uncased=False): 41 | """Returns a list of the text of each token 42 | 43 | Args: 44 | uncased: lower cases text 45 | """ 46 | if uncased: 47 | return [t[self.TEXT].lower() for t in self.data] 48 | else: 49 | return [t[self.TEXT] for t in self.data] 50 | 51 | def offsets(self): 52 | """Returns a list of [start, end) character offsets of each token.""" 53 | return [t[self.SPAN] for t in self.data] 54 | 55 | def pos(self): 56 | """Returns a list of part-of-speech tags of each token. 57 | Returns None if this annotation was not included. 58 | """ 59 | if 'pos' not in self.annotators: 60 | return None 61 | return [t[self.POS] for t in self.data] 62 | 63 | def lemmas(self): 64 | """Returns a list of the lemmatized text of each token. 65 | Returns None if this annotation was not included. 66 | """ 67 | if 'lemma' not in self.annotators: 68 | return None 69 | return [t[self.LEMMA] for t in self.data] 70 | 71 | def entities(self): 72 | """Returns a list of named-entity-recognition tags of each token. 73 | Returns None if this annotation was not included. 74 | """ 75 | if 'ner' not in self.annotators: 76 | return None 77 | return [t[self.NER] for t in self.data] 78 | 79 | def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True): 80 | """Returns a list of all ngrams from length 1 to n. 
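        For example, for tokens ['a', 'b', 'c'] and n=2 this returns
        ['a', 'a b', 'b', 'b c', 'c'] when as_strings is True.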
81 | 82 | Args: 83 | n: upper limit of ngram length 84 | uncased: lower cases text 85 | filter_fn: user function that takes in an ngram list and returns 86 | True or False to keep or not keep the ngram 87 | as_string: return the ngram as a string vs list 88 | """ 89 | def _skip(gram): 90 | if not filter_fn: 91 | return False 92 | return filter_fn(gram) 93 | 94 | words = self.words(uncased) 95 | ngrams = [(s, e + 1) 96 | for s in range(len(words)) 97 | for e in range(s, min(s + n, len(words))) 98 | if not _skip(words[s:e + 1])] 99 | 100 | # Concatenate into strings 101 | if as_strings: 102 | ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams] 103 | 104 | return ngrams 105 | 106 | def entity_groups(self): 107 | """Group consecutive entity tokens with the same NER tag.""" 108 | entities = self.entities() 109 | if not entities: 110 | return None 111 | non_ent = self.opts.get('non_ent', 'O') 112 | groups = [] 113 | idx = 0 114 | while idx < len(entities): 115 | ner_tag = entities[idx] 116 | # Check for entity tag 117 | if ner_tag != non_ent: 118 | # Chomp the sequence 119 | start = idx 120 | while (idx < len(entities) and entities[idx] == ner_tag): 121 | idx += 1 122 | groups.append((self.slice(start, idx).untokenize(), ner_tag)) 123 | else: 124 | idx += 1 125 | return groups 126 | 127 | 128 | class Tokenizer(object): 129 | """Base tokenizer class. 130 | Tokenizers implement tokenize, which should return a Tokens class. 131 | """ 132 | def tokenize(self, text): 133 | raise NotImplementedError 134 | 135 | def shutdown(self): 136 | pass 137 | 138 | def __del__(self): 139 | self.shutdown() 140 | -------------------------------------------------------------------------------- /retriever/retrieval/drqa/drqa_tokenizers/corenlp_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Simple wrapper around the Stanford CoreNLP pipeline. 8 | 9 | Serves commands to a java subprocess running the jar. Requires java 8. 10 | """ 11 | 12 | import copy 13 | import json 14 | import pexpect 15 | 16 | from .tokenizer import Tokens, Tokenizer 17 | from . import DEFAULTS 18 | 19 | 20 | class CoreNLPTokenizer(Tokenizer): 21 | 22 | def __init__(self, **kwargs): 23 | """ 24 | Args: 25 | annotators: set that can include pos, lemma, and ner. 
26 | classpath: Path to the corenlp directory of jars 27 | mem: Java heap memory 28 | """ 29 | self.classpath = (kwargs.get('classpath') or 30 | DEFAULTS['corenlp_classpath']) 31 | self.annotators = copy.deepcopy(kwargs.get('annotators', set())) 32 | self.mem = kwargs.get('mem', '2g') 33 | self._launch() 34 | 35 | def _launch(self): 36 | """Start the CoreNLP jar with pexpect.""" 37 | annotators = ['tokenize', 'ssplit'] 38 | if 'ner' in self.annotators: 39 | annotators.extend(['pos', 'lemma', 'ner']) 40 | elif 'lemma' in self.annotators: 41 | annotators.extend(['pos', 'lemma']) 42 | elif 'pos' in self.annotators: 43 | annotators.extend(['pos']) 44 | annotators = ','.join(annotators) 45 | options = ','.join(['untokenizable=noneDelete', 46 | 'invertible=true']) 47 | cmd = ['java', '-mx' + self.mem, '-cp', '"%s"' % self.classpath, 48 | 'edu.stanford.nlp.pipeline.StanfordCoreNLP', '-annotators', 49 | annotators, '-tokenize.options', options, 50 | '-outputFormat', 'json', '-prettyPrint', 'false'] 51 | 52 | # We use pexpect to keep the subprocess alive and feed it commands. 53 | # Because we don't want to get hit by the max terminal buffer size, 54 | # we turn off canonical input processing to have unlimited bytes. 55 | self.corenlp = pexpect.spawn('/bin/bash', maxread=100000, timeout=60) 56 | self.corenlp.setecho(False) 57 | self.corenlp.sendline('stty -icanon') 58 | self.corenlp.sendline(' '.join(cmd)) 59 | self.corenlp.delaybeforesend = 0 60 | self.corenlp.delayafterread = 0 61 | self.corenlp.expect_exact('NLP>', searchwindowsize=100) 62 | 63 | @staticmethod 64 | def _convert(token): 65 | if token == '-LRB-': 66 | return '(' 67 | if token == '-RRB-': 68 | return ')' 69 | if token == '-LSB-': 70 | return '[' 71 | if token == '-RSB-': 72 | return ']' 73 | if token == '-LCB-': 74 | return '{' 75 | if token == '-RCB-': 76 | return '}' 77 | return token 78 | 79 | def tokenize(self, text): 80 | # Since we're feeding text to the commandline, we're waiting on seeing 81 | # the NLP> prompt. Hacky! 82 | if 'NLP>' in text: 83 | raise RuntimeError('Bad token (NLP>) in text!') 84 | 85 | # Sending q will cause the process to quit -- manually override 86 | if text.lower().strip() == 'q': 87 | token = text.strip() 88 | index = text.index(token) 89 | data = [(token, text[index:], (index, index + 1), 'NN', 'q', 'O')] 90 | return Tokens(data, self.annotators) 91 | 92 | # Minor cleanup before tokenizing. 
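        # The text is piped to the CoreNLP prompt through pexpect, so embedded
        # newlines are replaced below; a raw newline would terminate the input early.
        # Replacing '\n' with a space keeps the string length unchanged, so the
        # character offsets CoreNLP returns still index into the original `text`,
        # which is why `text` (not `clean_text`) is sliced when reconstructing each
        # token's trailing whitespace further down.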
93 | clean_text = text.replace('\n', ' ') 94 | 95 | self.corenlp.sendline(clean_text.encode('utf-8')) 96 | self.corenlp.expect_exact('NLP>', searchwindowsize=100) 97 | 98 | # Skip to start of output (may have been stderr logging messages) 99 | output = self.corenlp.before 100 | start = output.find(b'{"sentences":') 101 | output = json.loads(output[start:].decode('utf-8')) 102 | 103 | data = [] 104 | tokens = [t for s in output['sentences'] for t in s['tokens']] 105 | for i in range(len(tokens)): 106 | # Get whitespace 107 | start_ws = tokens[i]['characterOffsetBegin'] 108 | if i + 1 < len(tokens): 109 | end_ws = tokens[i + 1]['characterOffsetBegin'] 110 | else: 111 | end_ws = tokens[i]['characterOffsetEnd'] 112 | 113 | data.append(( 114 | self._convert(tokens[i]['word']), 115 | text[start_ws: end_ws], 116 | (tokens[i]['characterOffsetBegin'], 117 | tokens[i]['characterOffsetEnd']), 118 | tokens[i].get('pos', None), 119 | tokens[i].get('lemma', None), 120 | tokens[i].get('ner', None) 121 | )) 122 | return Tokens(data, self.annotators) 123 | -------------------------------------------------------------------------------- /retriever/encode_corpus.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | import collections 4 | import logging 5 | import json 6 | import os 7 | import random 8 | from tqdm import tqdm 9 | import numpy as np 10 | import torch 11 | from transformers import AutoConfig, AutoTokenizer 12 | from torch.utils.data import DataLoader 13 | from functools import partial 14 | 15 | from retrieval.data.encode_datasets import EmDataset, em_collate_bert 16 | from retrieval.data.encode_datasets import EmDataset, EmDatasetFilter, EmDatasetMeta 17 | from retrieval.models.retriever import CtxEncoder, RobertaCtxEncoder 18 | from retrieval.models.retriever import SingleRetriever, SingleEncoder, RobertaSingleEncoder 19 | from retrieval.config import encode_args 20 | from retrieval.utils.utils import move_to_cuda, load_saved 21 | 22 | logger = logging.getLogger(__name__) 23 | def main(): 24 | args = encode_args() 25 | if args.fp16: 26 | import apex 27 | apex.amp.register_half_function(torch, 'einsum') 28 | 29 | if args.local_rank == -1 or args.no_cuda: 30 | device = torch.device( 31 | "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 32 | n_gpu = torch.cuda.device_count() 33 | else: 34 | device = torch.device("cuda", args.local_rank) 35 | n_gpu = 1 36 | torch.distributed.init_process_group(backend='nccl') 37 | 38 | if not args.predict_file: 39 | raise ValueError( 40 | "If `do_predict` is True, then `predict_file` must be specified.") 41 | 42 | # select encoing model 43 | bert_config = AutoConfig.from_pretrained(args.model_name) 44 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 45 | collate_fc = partial(em_collate_bert, pad_id=tokenizer.pad_token_id) 46 | if "roberta" in args.model_name: 47 | model = RobertaSingleEncoder(bert_config, args) 48 | logger.info("Model Using RobertaSingleEncoder...") 49 | else: 50 | model = SingleEncoder(bert_config, args) 51 | logger.info("Model Using SingleEncoder...") 52 | if args.add_special_tokens: 53 | # special_tokens_dict = {'additional_special_tokens': ["[HEADER]", "[PASSAGE]","[SEP]", "[TB]","[DATA]","[TITLE]","[SECTITLE]"]} 54 | special_tokens_dict = {'additional_special_tokens': ["[HEADER]", "[PASSAGE]", "[TB]","[DATA]","[TITLE]","[SECTITLE]"]} 55 | num_added_toks = tokenizer.add_special_tokens(special_tokens_dict) 56 | 
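        # Resizing is required after add_special_tokens: the new marker tokens
        # ([HEADER], [PASSAGE], [TB], [DATA], [TITLE], [SECTITLE]) receive ids beyond
        # the original vocabulary, so the encoder's embedding matrix must grow to match.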
model.resize_token_embeddings(len(tokenizer)) 57 | 58 | # select dataset 59 | if args.tfidf_filter and args.encode_table: 60 | eval_dataset = EmDatasetFilter(tokenizer, args.predict_file, args.tfidf_result_file, args.encode_table, args) 61 | logger.info("Dataset Using EmDatasetFilter...") 62 | elif args.metadata: 63 | eval_dataset = EmDatasetMeta(tokenizer, args.predict_file, args.encode_table, args) 64 | logger.info("Dataset Using EmDatasetMeta...") 65 | else: 66 | eval_dataset = EmDataset(tokenizer, args.predict_file, args.encode_table, args) 67 | logger.info("Dataset Using EmDataset...") 68 | eval_dataset.processing_data() 69 | eval_dataloader = DataLoader(eval_dataset, 70 | batch_size=args.predict_batch_size, 71 | collate_fn=collate_fc, 72 | pin_memory=True, 73 | num_workers=args.num_workers) 74 | 75 | assert args.init_checkpoint != "" 76 | model = load_saved(model, args.init_checkpoint, exact=False) 77 | model.to(device) 78 | 79 | if args.fp16: 80 | try: 81 | from apex import amp 82 | except ImportError: 83 | raise ImportError( 84 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 85 | model = amp.initialize(model, opt_level=args.fp16_opt_level) 86 | 87 | if args.local_rank != -1: 88 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 89 | output_device=args.local_rank) 90 | elif n_gpu > 1: 91 | model = torch.nn.DataParallel(model) 92 | # import pdb; pdb.set_trace() 93 | 94 | embeds = predict(model, eval_dataloader) 95 | logger.info(embeds.size()) 96 | 97 | if not os.path.exists(os.path.dirname(args.embed_save_path)): 98 | os.makedirs(os.path.dirname(args.embed_save_path)) 99 | logger.info("making dir :{}".format(os.path.dirname(args.embed_save_path))) 100 | logger.info("saving to :{}".format(args.embed_save_path)) 101 | np.save(args.embed_save_path, embeds.cpu().numpy()) 102 | 103 | 104 | def predict(model, eval_dataloader): 105 | if type(model) == list: 106 | model = [m.eval() for m in model] 107 | else: 108 | model.eval() 109 | 110 | embed_array = [] 111 | # import pdb; pdb.set_trace() 112 | # logger.info("start from 379200") 113 | for idx, batch in enumerate(tqdm(eval_dataloader)): 114 | batch = move_to_cuda(batch) 115 | with torch.no_grad(): 116 | try: 117 | results = model(batch) 118 | except Exception as e: 119 | logger.info(e) 120 | # logger.info("Error Batch: {}, instance: {}".format(idx, idx*1600)) 121 | continue 122 | embed = results['embed'].cpu() 123 | embed_array.append(embed) 124 | 125 | ## linear combination tuning on dev data 126 | embed_array = torch.cat(embed_array, dim=0) 127 | 128 | # model.train() 129 | return embed_array 130 | 131 | 132 | if __name__ == "__main__": 133 | main() 134 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/HaloAttention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange, repeat 6 | 7 | # relative positional embedding 8 | 9 | def to(x): 10 | return {'device': x.device, 'dtype': x.dtype} 11 | 12 | def pair(x): 13 | return (x, x) if not isinstance(x, tuple) else x 14 | 15 | def expand_dim(t, dim, k): 16 | t = t.unsqueeze(dim = dim) 17 | expand_shape = [-1] * len(t.shape) 18 | expand_shape[dim] = k 19 | return t.expand(*expand_shape) 20 | 21 | def rel_to_abs(x): 22 | b, l, m = x.shape 23 | r = (m + 1) // 2 24 | 25 | col_pad = torch.zeros((b, l, 1), 
**to(x)) 26 | x = torch.cat((x, col_pad), dim = 2) 27 | flat_x = rearrange(x, 'b l c -> b (l c)') 28 | flat_pad = torch.zeros((b, m - l), **to(x)) 29 | flat_x_padded = torch.cat((flat_x, flat_pad), dim = 1) 30 | final_x = flat_x_padded.reshape(b, l + 1, m) 31 | final_x = final_x[:, :l, -r:] 32 | return final_x 33 | 34 | def relative_logits_1d(q, rel_k): 35 | b, h, w, _ = q.shape 36 | r = (rel_k.shape[0] + 1) // 2 37 | 38 | logits = einsum('b x y d, r d -> b x y r', q, rel_k) 39 | logits = rearrange(logits, 'b x y r -> (b x) y r') 40 | logits = rel_to_abs(logits) 41 | 42 | logits = logits.reshape(b, h, w, r) 43 | logits = expand_dim(logits, dim = 2, k = r) 44 | return logits 45 | 46 | class RelPosEmb(nn.Module): 47 | def __init__( 48 | self, 49 | block_size, 50 | rel_size, 51 | dim_head 52 | ): 53 | super().__init__() 54 | height = width = rel_size 55 | scale = dim_head ** -0.5 56 | 57 | self.block_size = block_size 58 | self.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale) 59 | self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale) 60 | 61 | def forward(self, q): 62 | block = self.block_size 63 | 64 | q = rearrange(q, 'b (x y) c -> b x y c', x = block) 65 | rel_logits_w = relative_logits_1d(q, self.rel_width) 66 | rel_logits_w = rearrange(rel_logits_w, 'b x i y j-> b (x y) (i j)') 67 | 68 | q = rearrange(q, 'b x y d -> b y x d') 69 | rel_logits_h = relative_logits_1d(q, self.rel_height) 70 | rel_logits_h = rearrange(rel_logits_h, 'b x i y j -> b (y x) (j i)') 71 | return rel_logits_w + rel_logits_h 72 | 73 | # classes 74 | 75 | class HaloAttention(nn.Module): 76 | def __init__( 77 | self, 78 | *, 79 | dim, 80 | block_size, 81 | halo_size, 82 | dim_head = 64, 83 | heads = 8 84 | ): 85 | super().__init__() 86 | assert halo_size > 0, 'halo size must be greater than 0' 87 | 88 | self.dim = dim 89 | self.heads = heads 90 | self.scale = dim_head ** -0.5 91 | 92 | self.block_size = block_size 93 | self.halo_size = halo_size 94 | 95 | inner_dim = dim_head * heads 96 | 97 | self.rel_pos_emb = RelPosEmb( 98 | block_size = block_size, 99 | rel_size = block_size + (halo_size * 2), 100 | dim_head = dim_head 101 | ) 102 | 103 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 104 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) 105 | self.to_out = nn.Linear(inner_dim, dim) 106 | 107 | def forward(self, x): 108 | b, c, h, w, block, halo, heads, device = *x.shape, self.block_size, self.halo_size, self.heads, x.device 109 | assert h % block == 0 and w % block == 0, 'fmap dimensions must be divisible by the block size' 110 | assert c == self.dim, f'channels for input ({c}) does not equal to the correct dimension ({self.dim})' 111 | 112 | # get block neighborhoods, and prepare a halo-ed version (blocks with padding) for deriving key values 113 | 114 | q_inp = rearrange(x, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1 = block, p2 = block) 115 | 116 | kv_inp = F.unfold(x, kernel_size = block + halo * 2, stride = block, padding = halo) 117 | kv_inp = rearrange(kv_inp, 'b (c j) i -> (b i) j c', c = c) 118 | 119 | # derive queries, keys, values 120 | 121 | q = self.to_q(q_inp) 122 | k, v = self.to_kv(kv_inp).chunk(2, dim = -1) 123 | 124 | # split heads 125 | 126 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = heads), (q, k, v)) 127 | 128 | # scale 129 | 130 | q *= self.scale 131 | 132 | # attention 133 | 134 | sim = einsum('b i d, b j d -> b i j', q, k) 135 | 136 | # add relative positional bias 137 | 138 | sim += self.rel_pos_emb(q) 
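        # sim has shape ((batch * num_blocks * heads), block**2, (block + 2*halo)**2):
        # each query position inside a block attends over its haloed key window, and
        # rel_pos_emb(q) adds a learned bias for every query/key relative offset.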
139 | 140 | # mask out padding (in the paper, they claim to not need masks, but what about padding?) 141 | 142 | mask = torch.ones(1, 1, h, w, device = device) 143 | mask = F.unfold(mask, kernel_size = block + (halo * 2), stride = block, padding = halo) 144 | mask = repeat(mask, '() j i -> (b i h) () j', b = b, h = heads) 145 | mask = mask.bool() 146 | 147 | max_neg_value = -torch.finfo(sim.dtype).max 148 | sim.masked_fill_(mask, max_neg_value) 149 | 150 | # attention 151 | 152 | attn = sim.softmax(dim = -1) 153 | 154 | # aggregate 155 | 156 | out = einsum('b i j, b j d -> b i d', attn, v) 157 | 158 | # merge and combine heads 159 | 160 | out = rearrange(out, '(b h) n d -> b n (h d)', h = heads) 161 | out = self.to_out(out) 162 | 163 | # merge blocks back to original feature map 164 | 165 | out = rearrange(out, '(b h w) (p1 p2) c -> b c (h p1) (w p2)', b = b, h = (h // block), w = (w // block), p1 = block, p2 = block) 166 | return out 167 | 168 | if __name__ == '__main__': 169 | input=torch.randn(1,512,8,8) 170 | halo = HaloAttention(dim=512, 171 | block_size=2, 172 | halo_size=1,) 173 | output=halo(input) 174 | print(output.shape) -------------------------------------------------------------------------------- /evidence_chain/evaluate_ranked_evidence_chain.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | from scipy.special import softmax 4 | import sys 5 | 6 | def convert_chain_to_string(chain): 7 | chain_rep = [] 8 | for node in chain: 9 | # print(node) 10 | if node['origin']['where'] != 'question': 11 | if node['origin']['where'] == 'table': 12 | prefix = ' ' 13 | elif node['origin']['where'] == 'passage': 14 | prefix = ' {} : '.format( 15 | node['origin']['index'].replace('/wiki/', '').replace('_', ' ').split('/')[0]) 16 | chain_rep.append(prefix + node['content']) 17 | # print(' [TO] '.join(chain_rep)) 18 | # input() 19 | return ' ; '.join(chain_rep) 20 | def find_union_chain(ranked_ec,topt=2): 21 | nodes = set() 22 | all_nodes = [] 23 | for ec in ranked_ec[:topt]: 24 | # print(ec['path']) 25 | add = [node for node in ec['path'][1:] if node['content'] not in nodes] 26 | all_nodes.extend(add) 27 | nodes.update(set([node['content'] for node in ec['path'][1:]])) 28 | return all_nodes 29 | from fuzzywuzzy import fuzz 30 | def calculate_score(data): 31 | topk = {1:0,2:0,3:0,5:0,10:0,20:0} 32 | split = [] 33 | all_selected_chain = [] 34 | hit,all, length,union_hit,union_length=0,0,0,0,0 35 | for idx,d in tqdm(enumerate(data), desc='Calculating score'): 36 | positive_table_block = d['positive_table_blocks'] 37 | 38 | for bid,block in enumerate(positive_table_block): 39 | best_score = -999 40 | selected_chain = '' 41 | # for chain in block['candidate_evidence_chains']: 42 | # # positive_chain = convert_chain_to_string(chain['path']) 43 | # score = softmax(chain['score'])#fuzz.partial_ratio(d['question']+' '+orig_answer,positive_chain) 44 | # if score[1]>best_score: 45 | # best_score = score[1] 46 | # selected_chain = chain 47 | # for cid in range(len(block['candidate_evidence_chains'])): 48 | # block['candidate_evidence_chains'][cid]['score'] = (softmax(block['candidate_evidence_chains'][cid]['score']) + softmax(data2[idx]['positive_table_blocks'][bid]['candidate_evidence_chains'][cid]['score']))/2 49 | ranked_ec = sorted(block['candidate_evidence_chains'],key=lambda k: k['score'][1],reverse=True) 50 | unied_ec = find_union_chain(ranked_ec,topt=2) 51 | # ranked_ec = sorted(block['candidate_evidence_chains'], 
key=lambda k: k['score'][0], reverse=True) 52 | # print(ranked_ec) 53 | # input() 54 | for i in range(min(len(ranked_ec),20)): 55 | # print(ranked_ec[i]['path'][-1]) 56 | if ranked_ec[i]['path']: 57 | if any([node['is_end_node'] for node in ranked_ec[i]['path']]): 58 | for j in topk.keys(): 59 | if i<j: 60 | topk[j]+=1 61 | break 62 | if unied_ec: 63 | union_length+=len(unied_ec) 64 | if any([node['is_end_node'] for node in unied_ec]): 65 | union_hit+=1 66 | length += len(ranked_ec[0]['path']) if ranked_ec else 0 67 | # reranked_ec = [] 68 | # for i in range(min(len(ranked_ec), 3)): 69 | # # print(ranked_ec[i]['path'][-1]) 70 | # if ranked_ec[i]['path']: 71 | # chain_rep = convert_chain_to_string(ranked_ec[i]['path']) 72 | # question_rep = d['question'] 73 | # score = fuzz.token_set_ratio(question_rep.lower(),chain_rep.lower()) 74 | # ranked_ec[i]['fuzz_score'] = score 75 | # reranked_ec.append(ranked_ec[i]) 76 | # else: 77 | # ranked_ec[i]['fuzz_score'] = 0 78 | # reranked_ec.append(ranked_ec[i]) 79 | # reranked_ec = sorted(reranked_ec,key=lambda k:k['fuzz_score'],reverse=True) 80 | # if reranked_ec:#[0]['path']: 81 | # if reranked_ec[0]['path'][-1]['is_end_node']: 82 | # hit+=1 83 | all+=1 84 | print('Top 1 {}/{}={}'.format(topk[1],all,topk[1]/all)) 85 | print('Top 2 {}/{}={}'.format(topk[2], all, topk[2] / all)) 86 | print('Top 3 {}/{}={}'.format(topk[3], all, topk[3] / all)) 87 | print('Top 5 {}/{}={}'.format(topk[5], all, topk[5] / all)) 88 | print('Top 10 {}/{}={}'.format(topk[10], all, topk[10] / all)) 89 | print('Top 2 Union {}/{}={}'.format(union_hit, all, union_hit/ all)) 90 | print('Average length {}/{}={}'.format(length, all, length/ all)) 91 | print('Average Union length {}/{}={}'.format(union_length, all, union_length / all)) 92 | basic_dir = '/home/t-wzhong/v-wanzho/ODQA/data' 93 | 94 | 95 | # data_path = sys.argv[1] 96 | # data_path='/home/t-wzhong/v-wanzho/ODQA/data/preprocessed_data/evidence_chain/ground-truth-based/dev_preprocessed_normalized_gtmodify_candidate_evichain_nx_scores_addnoun.json' 97 | # data_path_2='/home/t-wzhong/v-wanzho/ODQA/data/preprocessed_data/evidence_chain/ground-truth-based/dev_preprocessed_normalized_gtmodify_candidate_evichain_nx_scores_addnoun.json' 98 | # data_path = f'{basic_dir}/preprocessed_data/evidence_chain/dev_preprocessed_normalized_gtmodify_candidate_evichain_nx_scores_addnoun_roberta_tb.json' 99 | # data_path_2 = f'{basic_dir}/preprocessed_data/evidence_chain/dev_preprocessed_normalized_gtmodify_candidate_evichain_nx_scores_addnoun_roberta_tb.json' 100 | # /home/t-wzhong/v-wanzho/ODQA/data/preprocessed_data/evidence_chain/dev_preprocessed_normalized_gtmodify_candidate_evichain_nx_scores_addnoun_roberta_tb.json 101 | # file = sys.argv[1] 102 | data_path = '/home/t-wzhong/v-wanzho/ODQA/data/preprocessed_data/evidence_chain/ground-truth-based/candidate_chain/../scored_chain/dev_roberta_base_scored_ec.json' 103 | print(f"Loading data from {data_path}") 104 | with open(data_path, 'r') as f: 105 | data = json.load(f)#[:100] 106 | # with open(data_path_2, 'r') as f: 107 | # data_2 = json.load(f)#[:100] 108 | # self.data = self.data[:300] 109 | calculate_score(data) 110 | # print(f"Total sample count {len(self.data)}") 111 | -------------------------------------------------------------------------------- /qa_baseline/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved.
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import argparse 7 | from ast import parse 8 | from typing import NamedTuple 9 | from torch.nn import parallel 10 | 11 | import logging 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class ClusterConfig(NamedTuple): 17 | dist_backend: str 18 | dist_url: str 19 | 20 | 21 | def qa_args(): 22 | parser = argparse.ArgumentParser() 23 | # Required parameters 24 | parser.add_argument("--model_type", default='bert', type=str) 25 | parser.add_argument("--model_name_or_path", default="bert-base-uncased", type=str, help="Path to pre-trained model or shortcut name selected in the list: ") 26 | parser.add_argument("--output_dir", default='qa_model', type=str, help="The output directory where the model checkpoints and predictions will be written.",) 27 | parser.add_argument("--train_file", default=None, type=str, help="The input training file. If a data dir is specified, will look for the file there If no data dir or train/predict files are specified, will run with tensorflow_datasets.",) 28 | parser.add_argument("--dev_file", default=None, type=str, help="The input development file. If a data dir is specified, will look for the file there If no data dir or train/predict files are specified, will run with tensorflow_datasets.",) 29 | parser.add_argument("--resource_dir", type=str, default='data/', help="Number of updates steps to accumulate before performing a backward/update pass.",) 30 | parser.add_argument("--data_dir", type=str, default='data/', help="Number of updates steps to accumulate before performing a backward/update pass.") 31 | parser.add_argument("--predict_file", default=None, type=str, help="The input evaluation file. If a data dir is specified, will look for the file there If no data dir or train/predict files are specified, will run with tensorflow_datasets.",) 32 | parser.add_argument("--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name") 33 | parser.add_argument("--tokenizer_name", default="", type=str, help="Pretrained tokenizer name or path if not the same as model_name",) 34 | parser.add_argument("--cache_dir", default="/tmp/", type=str, help="Where do you want to store the pre-trained models downloaded from s3",) 35 | parser.add_argument("--version_2_with_negative", action="store_true", help="If true, the SQuAD examples contain some that do not have an answer.",) 36 | parser.add_argument("--null_score_diff_threshold", type=float, default=0.0, help="If null_score - best_non_null is greater than the threshold predict null.",) 37 | parser.add_argument("--max_seq_length", default=384, type=int, help="The maximum total input sequence length after WordPiece tokenization. Sequences longer than this will be truncated, and sequences shorter than this will be padded.",) 38 | parser.add_argument("--doc_stride", default=128, type=int, help="When splitting up a long document into chunks, how much stride to take between chunks.",) 39 | parser.add_argument("--max_query_length", default=64, type=int, help="The maximum number of tokens for the question. 
Questions longer than this will be truncated to this length.",) 40 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.") 41 | parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.") 42 | parser.add_argument("--repreprocess", action="store_true", help="Whether to re-prepare the qa_evidence_chain data.") 43 | parser.add_argument("--do_predict", action="store_true", help="Whether to run eval on the dev set.") 44 | parser.add_argument("--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.") 45 | 46 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.") 47 | parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.") 48 | parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") 49 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.",) 50 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 51 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 52 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 53 | parser.add_argument("--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.") 54 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") 55 | parser.add_argument("--n_best_size", default=20, type=int, help="The total number of n-best predictions to generate in the nbest_predictions.json output file.",) 56 | parser.add_argument("--max_answer_length", default=30, type=int, help="The maximum length of an answer that can be generated. This is needed because the start and end predictions are not conditioned on one another.") 57 | parser.add_argument("--verbose_logging", action="store_true", help="If true, all of the warnings related to data processing will be printed. A number of warnings are expected for a normal SQuAD evaluation.",) 58 | parser.add_argument("--lang_id", default=0, type=int, help="language id of input for language-specific xlm models (see tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)",) 59 | parser.add_argument("--request_path", type=str, default='request_tok', help="Request directory.") 60 | parser.add_argument("--logging_steps", type=int, default=100, help="Log every X updates steps.") 61 | parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.") 62 | parser.add_argument("--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets") 63 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 64 | 65 | parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") 66 | parser.add_argument("--fp16", action="store_true", help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",) 67 | parser.add_argument("--evaluate_during_training", action="store_true", help="Whether to evaluate during training",) 68 | parser.add_argument("--fp16_opt_level", type=str, default="O1", help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. 
See details at https://nvidia.github.io/apex/amp.html",) 69 | parser.add_argument("--topk_tbs", type=int, default=10, help="multiple threads for converting example to features") 70 | parser.add_argument("--prefix", type=str, default='', help="prefix for saving cached file") 71 | parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features") 72 | args = parser.parse_args() 73 | return args -------------------------------------------------------------------------------- /retriever/retrieval/tfidf_retriever.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Interactive mode for the tfidf DrQA retriever module.""" 8 | 9 | import argparse 10 | import code 11 | import prettytable 12 | import logging 13 | from drqa import retriever 14 | import json 15 | import sys 16 | from utils.utils import whitelist, is_year 17 | import copy 18 | from multiprocessing import Pool, cpu_count 19 | from functools import partial 20 | from tqdm import tqdm 21 | 22 | logger = logging.getLogger() 23 | logger.setLevel(logging.INFO) 24 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p') 25 | console = logging.StreamHandler() 26 | console.setFormatter(fmt) 27 | logger.addHandler(console) 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--model', type=str, required=True) 31 | parser.add_argument('--option', type=str, default='tfidf') 32 | parser.add_argument('--format', type=str, required=True) 33 | parser.add_argument('--split', type=str, default='dev') 34 | parser.add_argument('--debug', action='store_true', default=False) 35 | parser.add_argument("--basic_data_path", 36 | default='/home/t-wzhong/v-wanzho/ODQA/data',type=str) 37 | args = parser.parse_args() 38 | 39 | logger.info('Initializing ranker...') 40 | ranker = retriever.get_class(args.option)(tfidf_path=f'{args.basic_data_path}/preprocessed_data/{args.model}') 41 | 42 | if args.format == 'table_construction': 43 | with open(f'{args.basic_data_path}/all_passages.json') as f: 44 | cache = set(json.load(f).keys()) 45 | logger.info('Finished loading the passage keys') 46 | new_cache = {} 47 | 48 | def table_linker(kv): 49 | k, v = kv 50 | assert isinstance(k, str) and isinstance(v, dict) 51 | 52 | new_table = copy.deepcopy(v) 53 | new_table['data'] = [] 54 | new_table['header'] = [(_, []) for _ in v['header']] 55 | 56 | for row in v['data']: 57 | new_row = [] 58 | for cell in row: 59 | if not whitelist(cell) or is_year(cell): 60 | new_row.append((cell, [])) 61 | continue 62 | guessing = '/wiki/{}'.format(cell.replace(' ', '_')) 63 | if guessing in cache: 64 | new_row.append((cell, [guessing])) 65 | continue 66 | if cell in new_cache: 67 | new_row.append((cell, new_cache[cell])) 68 | continue 69 | 70 | try: 71 | doc_name, doc_scores = ranker.closest_docs(cell, 1) 72 | assert isinstance(doc_name, list) 73 | new_row.append((cell, doc_name)) 74 | new_cache[cell] = doc_name 75 | except Exception: 76 | new_row.append((cell, [])) 77 | 78 | assert len(new_row) == len(v['header']) 79 | new_table['data'].append(new_row) 80 | assert len(new_table['data']) == len(v['data']) 81 | return k, new_table 82 | 83 | if __name__ == '__main__': 84 | 85 | data_path = args.basic_data_path 86 | ottqa_data_path = 
f'{data_path}/data_ottqa' 87 | 88 | 89 | if args.format == 'question_table': 90 | with open(f'{data_path}/data_ottqa/{args.split}.json', 'r') as f: 91 | data = json.load(f) 92 | for k in [100]: 93 | retr_tables = {} 94 | outf = open(f'{data_path}/preprocessed_data/retrieval/{args.split}_tfidf_title_sectitle_header_top{k}.json', 95 | 'w', encoding='utf8') 96 | succ = 0 97 | for i, d in enumerate(data): 98 | groundtruth_doc = d['table_id'] 99 | qid = d['question_id'] 100 | query = d['question'] 101 | doc_names, doc_scores = ranker.closest_docs(query, k) 102 | retr_tables[qid] = {'doc_ids':doc_names,'doc_scores':doc_scores.tolist()} 103 | if groundtruth_doc in doc_names: 104 | succ += 1 105 | sys.stdout.write('finished {}/{}; HITS@{} = {} \r'.format(i + 1, len(data), k, succ / (i + 1))) 106 | json.dump(retr_tables,outf,indent=4) 107 | print('Saving the top {} retrieved tables in path {}'.format(k,f'{data_path}/preprocessed_data/retrieval/{args.split}_tfidf_title_sectitle_header_top{k}.json')) 108 | print('finished {}/{}; HITS@{} = {} \r'.format(i + 1, len(data), k, succ / (i + 1))) 109 | 110 | elif args.format == 'cell_text': 111 | with open(f'{data_path}/traindev_tables.json') as f: 112 | traindev_tables = json.load(f) 113 | with open(f'{data_path}/train_dev_test_table_ids.json') as f: 114 | tables_ids = set(json.load(f)['dev']) 115 | with open(f'{data_path}/link_generator/row_passage_query.json', 'r') as f: 116 | mapping = json.load(f) 117 | 118 | succ, prec_total, recall_total = 0, 0, 0 119 | for k, table in traindev_tables.items(): 120 | if k not in tables_ids: 121 | continue 122 | 123 | for j, row in enumerate(table['data']): 124 | row_id = k + '_{}'.format(j) 125 | queries = mapping.get(row_id, []) 126 | gt_docs = [] 127 | for cell in row: 128 | gt_docs.extend(cell[1]) 129 | doc_names = [] 130 | for query in queries: 131 | try: 132 | doc_name, doc_scores = ranker.closest_docs(query, 1) 133 | doc_names.extend(doc_name) 134 | except Exception: 135 | pass 136 | 137 | succ += len(set(gt_docs) & set(doc_names)) 138 | prec_total += len(queries) 139 | recall_total += len(gt_docs) 140 | 141 | if len(queries) == 0 and len(gt_docs) > 0: 142 | pass 143 | 144 | recall = succ / (recall_total + 0.01) 145 | precision = succ / (prec_total + 0.01) 146 | f1 = 2 * recall * precision / (precision + recall + 0.01) 147 | sys.stdout.write('F1@{} = {} \r'.format(1, f1)) 148 | 149 | print('F1@{} = {}'.format(1, f1)) 150 | 151 | elif args.format == 'table_construction': 152 | with open(f'{data_path}/all_plain_tables.json') as f: 153 | tables = json.load(f) 154 | logger.info('Finished loading the plain tables') 155 | 156 | n_threads = 64 157 | results = [] 158 | with Pool(n_threads) as p: 159 | results = list( 160 | tqdm( 161 | p.imap(table_linker, tables.items(), chunksize=16), 162 | total=len(tables), 163 | desc="process tables", 164 | ) 165 | ) 166 | 167 | linked_tables = dict(results) 168 | with open(f'{data_path}/all_constructed_tables.json', 'w') as f: 169 | json.dump(linked_tables, f) 170 | 171 | else: 172 | raise NotImplementedError() 173 | 174 | -------------------------------------------------------------------------------- /evidence_chain/pretrain_data_process/negative_search_by_bm25/elastic_search.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import unicodedata 5 | 6 | import inflect 7 | from elasticsearch import Elasticsearch 8 | from elasticsearch import helpers 9 | from tqdm import tqdm 10 | import sys 11 | 
sys.path.append('../') 12 | # from utils_preprocess import args 13 | # from utils.rule_pattern import RulePattern 14 | 15 | inflect = inflect.engine() 16 | # RP = RulePattern() 17 | 18 | def convert_chain_to_string(chain): 19 | chain_rep = [] 20 | for node in chain[1:]: 21 | # print(node) 22 | if node['origin']['where'] == 'table': 23 | prefix = '[TAB] ' 24 | elif node['origin']['where'] == 'passage': 25 | prefix = '[PASSAGE] {} : '.format(node['origin']['index'].replace('/wiki/', '').replace('_', ' ').split('/')[0]) 26 | else: 27 | prefix = '[QUESTION] ' 28 | chain_rep.append(prefix+node['content']) 29 | # print(' [TO] '.join(chain_rep)) 30 | # input() 31 | return ' ; '.join(chain_rep) 32 | def check_contain_upper(self, password): 33 | pattern = re.compile('[A-Z]+') 34 | match = pattern.findall(password) 35 | if match: 36 | return True 37 | else: 38 | return False 39 | 40 | 41 | class SearchQuery(): 42 | @classmethod 43 | def claim2text(cls, claim, type='title_text'): 44 | search_body = { 45 | "query": { 46 | "match": { 47 | type: claim 48 | } 49 | } 50 | } 51 | return search_body 52 | 53 | @classmethod 54 | def claim2text_title(cls, claim): 55 | # score in both text and title 56 | search_body = { 57 | "query": { 58 | "multi_match": { 59 | "query": claim, 60 | "fields": ['text', 'title'], 61 | "fuzziness": "AUTO" 62 | } 63 | } 64 | } 65 | return search_body 66 | 67 | @classmethod 68 | def kws2title(cls, multi_claim): 69 | search_body = { 70 | "query": { 71 | "bool": { 72 | "should": [ 73 | 74 | ] 75 | } 76 | }} 77 | for claim in multi_claim: 78 | tiny_body = { 79 | "match_phrase": { 80 | "title": { 81 | 'query': claim, 82 | "slop": 2 83 | } 84 | 85 | # "slop": 5 86 | } 87 | } 88 | search_body['query']['bool']['should'].append(tiny_body) 89 | return search_body 90 | 91 | 92 | class MyElastic(): 93 | def __init__(self, index_name='evidence_chain'): 94 | self.es = Elasticsearch([{'host': '127.0.0.1', 'port': 9200}]) 95 | self.index_name = index_name 96 | body = { 97 | "properties": { 98 | "text": { 99 | "type": "text", 100 | "analyzer":"analyzed" 101 | } 102 | } 103 | } 104 | 105 | if not self.es.indices.exists(index=self.index_name,request_timeout=60): 106 | self.es.indices.create(self.index_name,request_timeout=60) 107 | self.es.indices.put_mapping(index=self.index_name, doc_type='evidence_chain', 108 | body=body, include_type_name=True) 109 | # self.es.indices.put_mapping(index=self.index_name, doc_type='wiki_sentence', 110 | # body=body, include_type_name=True) 111 | 112 | def search(self, search_body): 113 | ret = self.es.search(index=self.index_name, body=search_body, size=10) 114 | return ret 115 | 116 | def bulk_insert_all_chains_finetune(self, file_paths): 117 | data = [] 118 | for file_path in file_paths: 119 | with open(file_path,'r',encoding='utf8') as inf: 120 | data.extend(json.load(inf)) 121 | all_chains_strs = [] 122 | all_chains = [] 123 | for item in data: 124 | for table_block in item['positive_table_blocks']: 125 | ecs = table_block['evidence_chain']['positive'] 126 | ec_strs = [convert_chain_to_string(chain) for chain in ecs] 127 | all_chains_strs.extend(ec_strs) 128 | all_chains.extend(ecs) 129 | cnt = 0 130 | actions = [] 131 | for id,chain in tqdm(enumerate(all_chains_strs)): 132 | input_body = { 133 | "_index": self.index_name, 134 | "_type":"evidence_chain", 135 | "_id":id, 136 | "_source":{ 137 | 'id':id, 138 | 'text':chain, 139 | 'chain':all_chains[id] 140 | } 141 | } 142 | cnt+=1 143 | actions.append(input_body) 144 | if len(actions) != 0: 145 | 
print(helpers.bulk(self.es, actions,request_timeout=60)) 146 | 147 | def bulk_insert_all_chains_pretrain(self, file_path): 148 | with open(file_path,'r',encoding='utf8') as inf: 149 | all_chains = [json.loads(line.strip())['ec'] for line in inf.readlines()] 150 | cnt = 0 151 | actions = [] 152 | for id,chain in tqdm(enumerate(all_chains)): 153 | input_body = { 154 | "_index": self.index_name, 155 | "_type":"evidence_chain", 156 | "_id":id, 157 | "_source":{ 158 | 'id':id, 159 | 'text':chain, 160 | } 161 | } 162 | cnt+=1 163 | actions.append(input_body) 164 | if len(actions) != 0: 165 | print(helpers.bulk(self.es, actions,request_timeout=60)) 166 | 167 | 168 | 169 | 170 | def delete_one(self): 171 | res = self.es.indices.delete(index='evidence_chain',request_timeout=60) 172 | print(res) 173 | 174 | def create(self): 175 | 176 | body = { 177 | "properties": { 178 | "text": { 179 | "type": "text", 180 | # "analyzer": "analyzed" 181 | } 182 | } 183 | } 184 | 185 | if not self.es.indices.exists(index=self.index_name): 186 | self.es.indices.create(self.index_name) 187 | self.es.indices.put_mapping(index=self.index_name, doc_type='evidence_chain', 188 | body=body, include_type_name=True) 189 | 190 | def clear_cache(self, index_name='wiki_search'): 191 | res = self.es.indices.clear_cache(index=index_name) 192 | print(res) 193 | 194 | def delete_index(self, index_name="wiki_search"): 195 | query = {'query': {"match_all": {}}} 196 | res = self.es.delete_by_query(index=index_name, body=query) 197 | print(res) 198 | 199 | def search_by_chain(self,query): 200 | search_body = SearchQuery.claim2text(query,'text') 201 | ret = self.search(search_body) 202 | return ret 203 | 204 | 205 | 206 | if __name__ == "__main__": 207 | ES = MyElastic() 208 | # ES.delete_one() 209 | res = ES.search_by_id('Soul_Food_-LRB-film-RRB-0') 210 | print(res) 211 | -------------------------------------------------------------------------------- /retriever/retrieval/criterions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | import torch 7 | from torch.nn import CrossEntropyLoss 8 | import torch.nn.functional as F 9 | 10 | def loss_single(model, batch): 11 | outputs = model(batch) 12 | q = outputs['q'] 13 | c = outputs['c'] 14 | neg_c = outputs['neg_c'] 15 | product_in_batch = torch.mm(q, c.t()) 16 | product_neg = (q * neg_c).sum(-1).unsqueeze(1) 17 | product = torch.cat([product_in_batch, product_neg], dim=-1) 18 | 19 | target = torch.arange(product.size(0)).to(product.device) 20 | loss = F.cross_entropy(product, target) 21 | return loss 22 | 23 | 24 | 25 | def mhop_loss(model, batch, args): 26 | 27 | outputs = model(batch) 28 | loss_fct = CrossEntropyLoss(ignore_index=-1) 29 | 30 | all_ctx = torch.cat([outputs['c1'], outputs['c2']], dim=0) 31 | neg_ctx = torch.cat([outputs["neg_1"].unsqueeze(1), outputs["neg_2"].unsqueeze(1)], dim=1) # B x 2 x M x h 32 | 33 | scores_1_hop = torch.mm(outputs["q"], all_ctx.t()) 34 | neg_scores_1 = torch.bmm(outputs["q"].unsqueeze(1), neg_ctx.transpose(1,2)).squeeze(1) 35 | scores_2_hop = torch.mm(outputs["q_sp1"], all_ctx.t()) 36 | neg_scores_2 = torch.bmm(outputs["q_sp1"].unsqueeze(1), neg_ctx.transpose(1,2)).squeeze(1) 37 | 38 | # mask the 1st hop 39 | bsize = outputs["q"].size(0) 40 | scores_1_mask = torch.cat([torch.zeros(bsize, bsize), torch.eye(bsize)], dim=1).to(outputs["q"].device) 41 | scores_1_hop = scores_1_hop.float().masked_fill(scores_1_mask.bool(), float('-inf')).type_as(scores_1_hop) 42 | scores_1_hop = torch.cat([scores_1_hop, neg_scores_1], dim=1) 43 | scores_2_hop = torch.cat([scores_2_hop, neg_scores_2], dim=1) 44 | 45 | target_1_hop = torch.arange(outputs["q"].size(0)).to(outputs["q"].device) 46 | target_2_hop = torch.arange(outputs["q"].size(0)).to(outputs["q"].device) + outputs["q"].size(0) 47 | 48 | retrieve_loss = loss_fct(scores_1_hop, target_1_hop) + loss_fct(scores_2_hop, target_2_hop) 49 | 50 | return retrieve_loss 51 | 52 | def mhop_eval(outputs, args): 53 | all_ctx = torch.cat([outputs['c1'], outputs['c2']], dim=0) 54 | neg_ctx = torch.cat([outputs["neg_1"].unsqueeze(1), outputs["neg_2"].unsqueeze(1)], dim=1) 55 | 56 | 57 | scores_1_hop = torch.mm(outputs["q"], all_ctx.t()) 58 | neg_scores_1 = torch.bmm(outputs["q"].unsqueeze(1), neg_ctx.transpose(1,2)).squeeze(1) 59 | scores_2_hop = torch.mm(outputs["q_sp1"], all_ctx.t()) 60 | neg_scores_2 = torch.bmm(outputs["q_sp1"].unsqueeze(1), neg_ctx.transpose(1,2)).squeeze(1) 61 | 62 | 63 | bsize = outputs["q"].size(0) 64 | scores_1_mask = torch.cat([torch.zeros(bsize, bsize), torch.eye(bsize)], dim=1).to(outputs["q"].device) 65 | scores_1_hop = scores_1_hop.float().masked_fill(scores_1_mask.bool(), float('-inf')).type_as(scores_1_hop) 66 | scores_1_hop = torch.cat([scores_1_hop, neg_scores_1], dim=1) 67 | scores_2_hop = torch.cat([scores_2_hop, neg_scores_2], dim=1) 68 | target_1_hop = torch.arange(outputs["q"].size(0)).to(outputs["q"].device) 69 | target_2_hop = torch.arange(outputs["q"].size(0)).to(outputs["q"].device) + outputs["q"].size(0) 70 | 71 | ranked_1_hop = scores_1_hop.argsort(dim=1, descending=True) 72 | ranked_2_hop = scores_2_hop.argsort(dim=1, descending=True) 73 | idx2ranked_1 = ranked_1_hop.argsort(dim=1) 74 | idx2ranked_2 = ranked_2_hop.argsort(dim=1) 75 | rrs_1, rrs_2 = [], [] 76 | for t, idx2ranked in zip(target_1_hop, idx2ranked_1): 77 | rrs_1.append(1 / (idx2ranked[t].item() + 1)) 78 | for t, idx2ranked in zip(target_2_hop, idx2ranked_2): 79 | rrs_2.append(1 / (idx2ranked[t].item() + 1)) 80 | 81 | return {"rrs_1": rrs_1, "rrs_2": rrs_2} 82 | 83 | 84 | def unified_loss(model, 
batch, args): 85 | 86 | outputs = model(batch) 87 | all_ctx = torch.cat([outputs['c1'], outputs['c2']], dim=0) 88 | neg_ctx = torch.cat([outputs["neg_1"].unsqueeze(1), outputs["neg_2"].unsqueeze(1)], dim=1) 89 | scores_1_hop = torch.mm(outputs["q"], all_ctx.t()) 90 | neg_scores_1 = torch.bmm(outputs["q"].unsqueeze(1), neg_ctx.transpose(1,2)).squeeze(1) 91 | scores_2_hop = torch.mm(outputs["q_sp1"], all_ctx.t()) 92 | neg_scores_2 = torch.bmm(outputs["q_sp1"].unsqueeze(1), neg_ctx.transpose(1,2)).squeeze(1) 93 | 94 | # mask for 1st hop 95 | bsize = outputs["q"].size(0) 96 | scores_1_mask = torch.cat([torch.zeros(bsize, bsize), torch.eye(bsize)], dim=1).to(outputs["q"].device) 97 | scores_1_hop = scores_1_hop.float().masked_fill(scores_1_mask.bool(), float('-inf')).type_as(scores_1_hop) 98 | scores_1_hop = torch.cat([scores_1_hop, neg_scores_1], dim=1) 99 | scores_2_hop = torch.cat([scores_2_hop, neg_scores_2], dim=1) 100 | 101 | stop_loss = F.cross_entropy(outputs["stop_logits"], batch["stop_targets"].view(-1), reduction="sum") 102 | 103 | target_1_hop = torch.arange(outputs["q"].size(0)).to(outputs["q"].device) 104 | target_2_hop = torch.arange(outputs["q"].size(0)).to(outputs["q"].device) + outputs["q"].size(0) 105 | 106 | 107 | retrieve_loss = F.cross_entropy(scores_1_hop, target_1_hop, reduction="sum") + (F.cross_entropy(scores_2_hop, target_2_hop, reduction="none") * batch["stop_targets"].view(-1)).sum() 108 | 109 | return retrieve_loss + stop_loss 110 | 111 | def unified_eval(outputs, batch): 112 | all_ctx = torch.cat([outputs['c1'], outputs['c2']], dim=0) 113 | neg_ctx = torch.cat([outputs["neg_1"].unsqueeze(1), outputs["neg_2"].unsqueeze(1)], dim=1) 114 | scores_1_hop = torch.mm(outputs["q"], all_ctx.t()) 115 | scores_2_hop = torch.mm(outputs["q_sp1"], all_ctx.t()) 116 | neg_scores_1 = torch.bmm(outputs["q"].unsqueeze(1), neg_ctx.transpose(1,2)).squeeze(1) 117 | neg_scores_2 = torch.bmm(outputs["q_sp1"].unsqueeze(1), neg_ctx.transpose(1,2)).squeeze(1) 118 | bsize = outputs["q"].size(0) 119 | scores_1_mask = torch.cat([torch.zeros(bsize, bsize), torch.eye(bsize)], dim=1).to(outputs["q"].device) 120 | scores_1_hop = scores_1_hop.float().masked_fill(scores_1_mask.bool(), float('-inf')).type_as(scores_1_hop) 121 | scores_1_hop = torch.cat([scores_1_hop, neg_scores_1], dim=1) 122 | scores_2_hop = torch.cat([scores_2_hop, neg_scores_2], dim=1) 123 | target_1_hop = torch.arange(outputs["q"].size(0)).to(outputs["q"].device) 124 | target_2_hop = torch.arange(outputs["q"].size(0)).to(outputs["q"].device) + outputs["q"].size(0) 125 | 126 | # stop accuracy 127 | stop_pred = outputs["stop_logits"].argmax(dim=1) 128 | stop_targets = batch["stop_targets"].view(-1) 129 | stop_acc = (stop_pred == stop_targets).float().tolist() 130 | 131 | ranked_1_hop = scores_1_hop.argsort(dim=1, descending=True) 132 | ranked_2_hop = scores_2_hop.argsort(dim=1, descending=True) 133 | idx2ranked_1 = ranked_1_hop.argsort(dim=1) 134 | idx2ranked_2 = ranked_2_hop.argsort(dim=1) 135 | 136 | rrs_1_mhop, rrs_2_mhop, rrs_nq = [], [], [] 137 | for t1, idx2ranked1, t2, idx2ranked2, stop in zip(target_1_hop, idx2ranked_1, target_2_hop, idx2ranked_2, stop_targets): 138 | if stop: # 139 | rrs_1_mhop.append(1 / (idx2ranked1[t1].item() + 1)) 140 | rrs_2_mhop.append(1 / (idx2ranked2[t2].item() + 1)) 141 | else: 142 | rrs_nq.append(1 / (idx2ranked1[t1].item() + 1)) 143 | 144 | return { 145 | "stop_acc": stop_acc, 146 | "rrs_1_mhop": rrs_1_mhop, 147 | "rrs_2_mhop": rrs_2_mhop, 148 | "rrs_nq": rrs_nq 149 | } 150 | 
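All of the losses above share one contrastive scoring step: each question embedding is compared against every positive context in the batch (in-batch negatives) plus one mined hard negative, and the gold context sits on the diagonal of the resulting score matrix. The snippet below is a minimal, self-contained sketch of that step with random tensors standing in for the encoder outputs; the tensor names mirror loss_single, but the values are illustrative only and not taken from the repository.

import torch
import torch.nn.functional as F

batch_size, hidden = 4, 8
q = torch.randn(batch_size, hidden)        # question embeddings (stand-in for outputs['q'])
c = torch.randn(batch_size, hidden)        # positive context embeddings (outputs['c'])
neg_c = torch.randn(batch_size, hidden)    # one hard negative per question (outputs['neg_c'])

product_in_batch = torch.mm(q, c.t())            # B x B scores: other rows act as negatives
product_neg = (q * neg_c).sum(-1).unsqueeze(1)   # B x 1 scores against the hard negatives
product = torch.cat([product_in_batch, product_neg], dim=-1)

target = torch.arange(product.size(0))           # gold context for question i is column i
loss = F.cross_entropy(product, target)
print(loss.item())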
-------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import pickle 3 | import json 4 | from tqdm import tqdm 5 | import numpy as np 6 | from scipy.linalg import norm 7 | def get_current_time_str(): 8 | return str(datetime.now().strftime('%Y_%m_%d_%H:%M:%S')) 9 | def load_json(filename): 10 | with open(filename,'r',encoding='utf8') as f: 11 | return json.load(f) 12 | 13 | def find_similar_sentence_tfidf(sentence1,corpus,origin,cv,topk=5): 14 | if corpus: 15 | text_corpus = [sentence1] + corpus#[doc['passage'] for doc in corpus] 16 | vecs = cv.fit_transform(text_corpus).toarray() 17 | que_vec = vecs[0] 18 | scores = [] 19 | for idx, doc_vec in enumerate(vecs[1:]): 20 | score = np.dot(que_vec, doc_vec) / (norm(que_vec) * norm(doc_vec)) 21 | scores.append(score) 22 | scored_corpus = [(doc,score,ori) for doc,score,ori in zip(corpus,scores,origin)] 23 | results = sorted(scored_corpus,key=lambda k:k[1],reverse=True) 24 | topk = min(len(results),topk) 25 | return [res[2] for res in results[:topk]] 26 | # for res in results: 27 | # if res[0] not in sentence1: 28 | # # print(sentence1,' ; ',res[0]) 29 | # return res[0] 30 | # return results[0][0] 31 | else: 32 | return None 33 | def load_jsonl(filename): 34 | d_list = [] 35 | with open(filename, encoding='utf-8', mode='r') as in_f: 36 | print("Load Jsonl:", filename) 37 | for line in tqdm(in_f): 38 | item = json.loads(line.strip()) 39 | d_list.append(item) 40 | return d_list 41 | 42 | def load_pickle(filename): 43 | with open(filename,'rb') as f: 44 | return pickle.load(f) 45 | 46 | def save_json(results,filename): 47 | with open(filename,'w',encoding='utf8') as inf: 48 | json.dump(results,inf) 49 | 50 | def convert_tb_to_string_metadata(table, passages, meta_data, cut='passage', max_length=400): 51 | header = table.columns.tolist() 52 | value = table.values.tolist() 53 | # table_str = ' [TITLE] ' + meta_data['title'] + ' [SECTITLE] ' + meta_data['section_title'] + \ 54 | # ' [HEADER] ' + ' [SEP] '.join(header) + ' [DATA] '+' [SEP] '.join(value[0]) 55 | table_str = ' [TAB] ' + ' [TITLE] ' + meta_data['title']+' [SECTITLE] ' + meta_data['section_title'] + ' [DATA] '+\ 56 | ' ; '.join(['{} is {}'.format(h,c) for h,c in zip(header,value[0])]) 57 | passage_str = ' [PASSAGE] ' + ' [SEP] '.join(passages) 58 | 59 | return '{} {}'.format(table_str, passage_str) 60 | def convert_tb_to_string_metadata_old(table, passages, meta_data, cut='passage', max_length=400): 61 | header = table.columns.tolist() 62 | value = table.values.tolist() 63 | table_str = ' [TITLE] ' + meta_data['title'] + ' [SECTITLE] ' + meta_data['section_title'] + \ 64 | ' [HEADER] ' + ' [SEP] '.join(header) + ' [DATA] '+' [SEP] '.join(value[0]) 65 | passage_str = ' [PASSAGE] ' + ' [SEP] '.join(passages) 66 | return '{} {}'.format(table_str, passage_str) 67 | 68 | def convert_tb_to_string(table, passages, cut='passage', max_length=460, topk_block=15): 69 | header = table.columns.tolist() 70 | value = table.values.tolist() 71 | # table_str = '[HEADER] ' + ' [SEP] '.join(header) + ' [DATA] ' + ' [SEP] '.join(value[0]) 72 | table_str = '[HEADER] ' + ' '.join(['{} is {}'.format(h, c) for h, c in zip(header, value[0])]) 73 | # table_str = ' [TITLE] ' + meta_data['title'] + ' [SECTITLE] ' + meta_data['section_title'] + \ 74 | # ' [HEADER] ' + ' [SEP] '.join(header) + ' [DATA] ' + ' [SEP] '.join(value[0]) 75 | passage_str = ' [PASSAGE] ' + ' [SEP] 
'.join(passages) 76 | if cut == 'passage': 77 | table_length = min(max_length, len(table_str.split(' '))) 78 | doc_length = 0 if table_length >= max_length else max_length - table_length 79 | else: 80 | doc_length = min(max_length, len(passage_str.split(' '))) 81 | table_length = 0 if doc_length >= max_length else max_length - doc_length 82 | 83 | # table_str = ' '.join(table_str.split(' ')[:table_length]) 84 | # passage_str = ' '.join(passage_str.split(' ')[:doc_length]) 85 | return '{} {}'.format(table_str, passage_str) 86 | 87 | def convert_table_to_string(table, meta_data=None, max_length=90): 88 | header = table.columns.tolist() 89 | value = table.values.tolist() 90 | table_str = '[HEADER] ' + ' '.join(['{} is {}'.format(h, c) for h, c in zip(header, value[0])]) 91 | # table_str = '[HEADER] ' + ' [SEP] '.join(header) + ' [DATA] ' + ' [SEP] '.join(value[0]) 92 | if meta_data: 93 | table_str = '[TAB] [TITLE] ' + meta_data['title'] + ' [SECTITLE] ' + meta_data['section_title'] + ' ' + table_str 94 | # table_str = ' [TITLE] ' + meta_data['title'] + ' [SECTITLE] ' + meta_data['section_title'] + \ 95 | # ' [HEADER] ' + ' [SEP] '.join(header) + ' [DATA] ' + ' [SEP] '.join(value[0]) 96 | return table_str 97 | 98 | # def convert_tb_to_string_metadata(table, passages, meta_data, cut='passage', max_length=400): 99 | # header = table.columns.tolist() 100 | # value = table.values.tolist() 101 | # table_str = ' [TITLE] ' + meta_data['title'] + ' [SECTITLE] ' + meta_data['section_title'] + \ 102 | # ' [HEADER] ' + ' [SEP] '.join(header) + ' [DATA] '+' [SEP] '.join(value[0]) 103 | # passage_str = ' [PASSAGE] ' + ' [SEP] '.join(passages) 104 | # return '{} {}'.format(table_str, passage_str) 105 | # 106 | # def convert_tb_to_string(table, passages, cut='passage', max_length=460, topk_block=15): 107 | # header = table.columns.tolist() 108 | # value = table.values.tolist() 109 | # table_str = '[HEADER] ' + ' [SEP] '.join(header) + ' [DATA] ' + ' [SEP] '.join(value[0]) 110 | # # table_str = ' [TITLE] ' + meta_data['title'] + ' [SECTITLE] ' + meta_data['section_title'] + \ 111 | # # ' [HEADER] ' + ' [SEP] '.join(header) + ' [DATA] ' + ' [SEP] '.join(value[0]) 112 | # passage_str = ' [PASSAGE] ' + ' [SEP] '.join(passages) 113 | # if cut == 'passage': 114 | # table_length = min(max_length, len(table_str.split(' '))) 115 | # doc_length = 0 if table_length >= max_length else max_length - table_length 116 | # else: 117 | # doc_length = min(max_length, len(passage_str.split(' '))) 118 | # table_length = 0 if doc_length >= max_length else max_length - doc_length 119 | # 120 | # # table_str = ' '.join(table_str.split(' ')[:table_length]) 121 | # # passage_str = ' '.join(passage_str.split(' ')[:doc_length]) 122 | # return '{} {}'.format(table_str, passage_str) 123 | # 124 | # def convert_table_to_string(table, meta_data=None, max_length=90): 125 | # header = table.columns.tolist() 126 | # value = table.values.tolist() 127 | # table_str = '[HEADER] ' + ' [SEP] '.join(header) + ' [DATA] ' + ' [SEP] '.join(value[0]) 128 | # if meta_data: 129 | # table_str = ' [TITLE] ' + meta_data['title'] + ' [SECTITLE] ' + meta_data['section_title'] + table_str 130 | # # table_str = ' [TITLE] ' + meta_data['title'] + ' [SECTITLE] ' + meta_data['section_title'] + \ 131 | # # ' [HEADER] ' + ' [SEP] '.join(header) + ' [DATA] ' + ' [SEP] '.join(value[0]) 132 | # return table_str 133 | 134 | 135 | 136 | def get_passages(js, psg_mode, neg=False): 137 | prefix = "neg_" if neg else "" 138 | if psg_mode=='ori': 139 | psg = 
js[prefix+"passages"] 140 | elif psg_mode=="s_sent": 141 | psg = js[prefix+'s_sent'] if len(js[prefix+'s_sent']) > 0 else js[prefix+"passages"] 142 | elif psg_mode=="s_psg": 143 | psg = js[prefix+'s_psg'] if len(js[prefix+'s_psg']) > 0 else js[prefix+"passages"] 144 | else: 145 | psg = [] 146 | return psg 147 | 148 | -------------------------------------------------------------------------------- /evidence_chain/ranking_model/model.py: -------------------------------------------------------------------------------- 1 | from transformers import RobertaForSequenceClassification 2 | import torch 3 | from torch import nn 4 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 5 | from torch.nn import CrossEntropyLoss 6 | class RobertaClassificationHead(nn.Module): 7 | """Head for sentence-level classification tasks.""" 8 | 9 | def __init__(self, config): 10 | super().__init__() 11 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 12 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 13 | self.out_proj = nn.Linear(config.hidden_size, config.num_labels) 14 | 15 | def forward(self, features, **kwargs): 16 | x = features[:, 0, :] # take <s> token (equiv. to [CLS]) 17 | x = self.dropout(x) 18 | x = self.dense(x) 19 | x = torch.tanh(x) 20 | x = self.dropout(x) 21 | x = self.out_proj(x) 22 | return x 23 | class RobertaForSequenceClassificationEC(RobertaForSequenceClassification): 24 | def __init__(self, config): 25 | super().__init__(config) 26 | self.num_labels = config.num_labels 27 | config.num_labels = 1 28 | self.classifier = RobertaClassificationHead(config) 29 | self.init_weights() 30 | 31 | def forward( 32 | self, 33 | a_input_ids=None, 34 | a_attention_mask=None, 35 | a_token_type_ids=None, 36 | pos_b_input_ids=None, 37 | pos_b_attention_mask=None, 38 | pos_b_token_type_ids=None, 39 | neg_b_input_ids=None, 40 | neg_b_attention_mask=None, 41 | neg_b_token_type_ids=None, 42 | position_ids=None, 43 | head_mask=None, 44 | inputs_embeds=None, 45 | labels=None, 46 | output_attentions=None, 47 | output_hidden_states=None, 48 | return_dict=None, 49 | ): 50 | r""" 51 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 52 | Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., 53 | config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), 54 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
55 | """ 56 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 57 | 58 | a_outputs = self.roberta( 59 | a_input_ids, 60 | attention_mask=a_attention_mask, 61 | token_type_ids=a_token_type_ids, 62 | position_ids=position_ids, 63 | ) 64 | a_sequence_output = a_outputs[0][:, 0, :] 65 | pos_b_outputs = self.roberta( 66 | pos_b_input_ids, 67 | attention_mask=pos_b_attention_mask, 68 | token_type_ids=pos_b_token_type_ids, 69 | ) 70 | pos_b_sequence_output = pos_b_outputs[0][:, 0, :] 71 | neg_b_outputs = self.roberta( 72 | neg_b_input_ids, 73 | attention_mask=neg_b_attention_mask, 74 | token_type_ids=neg_b_token_type_ids, 75 | ) 76 | neg_b_sequence_output = neg_b_outputs[0][:, 0, :] 77 | # print(a_sequence_output.size(),pos_b_sequence_output.size()) 78 | product_in_batch = torch.mm(a_sequence_output, pos_b_sequence_output.t()) 79 | product_neg = (a_sequence_output * neg_b_sequence_output).sum(-1).unsqueeze(1) 80 | product = torch.cat([product_in_batch, product_neg], dim=-1) 81 | 82 | # return {'inbatch_qc_scores': inbatch_qc_scores, 'neg_qc_score': neg_qc_score} 83 | # logits = self.classifier(sequence_output) 84 | target = torch.arange(product.size(0)).to(product.device) 85 | loss_fct = CrossEntropyLoss() 86 | loss = loss_fct(product, target) 87 | # if not return_dict: 88 | output = [a_sequence_output,pos_b_sequence_output,neg_b_sequence_output]#(logits,) + a_sequence_outputs[2:] 89 | return ((loss,output)) if loss is not None else output 90 | 91 | def batched_index_select(input, dim, index): 92 | views = [input.shape[0]] + [1 if i != dim else -1 for i in range(1, len(input.shape))] 93 | expanse = list(input.shape) 94 | expanse[0] = -1 95 | expanse[dim] = -1 96 | index = index.view(views).expand(expanse) 97 | return torch.gather(input, dim, index) 98 | 99 | class RobertaForSequenceClassificationECMASK(RobertaForSequenceClassification): 100 | def __init__(self, config): 101 | super().__init__(config) 102 | self.num_labels = config.num_labels 103 | config.num_labels = 1 104 | self.classifier = RobertaClassificationHead(config) 105 | self.init_weights() 106 | 107 | def forward( 108 | self, 109 | a_input_ids=None, 110 | a_attention_mask=None, 111 | a_token_type_ids=None, 112 | pos_b_input_ids=None, 113 | pos_b_attention_mask=None, 114 | pos_b_token_type_ids=None, 115 | neg_b_input_ids=None, 116 | neg_b_attention_mask=None, 117 | neg_b_token_type_ids=None, 118 | mask_idx=None, 119 | position_ids=None, 120 | head_mask=None, 121 | inputs_embeds=None, 122 | labels=None, 123 | output_attentions=None, 124 | output_hidden_states=None, 125 | return_dict=None, 126 | ): 127 | r""" 128 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 129 | Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., 130 | config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), 131 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
132 | """ 133 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 134 | 135 | a_outputs = self.roberta( 136 | a_input_ids, 137 | attention_mask=a_attention_mask, 138 | token_type_ids=a_token_type_ids, 139 | position_ids=position_ids, 140 | ) 141 | a_sequence_output = a_outputs[0][:, 0, :] 142 | # a_sequence_output = batched_index_select(a_sequence_output, 1, mask_idx).squeeze(1) 143 | # print(a_sequence_output.size()) 144 | # print(mask_idx,a_sequence_output) 145 | pos_b_outputs = self.roberta( 146 | pos_b_input_ids, 147 | attention_mask=pos_b_attention_mask, 148 | token_type_ids=pos_b_token_type_ids, 149 | ) 150 | pos_b_sequence_output = pos_b_outputs[0][:, 0, :] 151 | neg_b_outputs = self.roberta( 152 | neg_b_input_ids, 153 | attention_mask=neg_b_attention_mask, 154 | token_type_ids=neg_b_token_type_ids, 155 | ) 156 | neg_b_sequence_output = neg_b_outputs[0][:, 0, :] 157 | # print(a_sequence_output.size(),pos_b_sequence_output.size()) 158 | product_in_batch = torch.mm(a_sequence_output, pos_b_sequence_output.t()) 159 | product_neg = (a_sequence_output * neg_b_sequence_output).sum(-1).unsqueeze(1) 160 | product = torch.cat([product_in_batch, product_neg], dim=-1) 161 | 162 | # return {'inbatch_qc_scores': inbatch_qc_scores, 'neg_qc_score': neg_qc_score} 163 | # logits = self.classifier(sequence_output) 164 | target = torch.arange(product.size(0)).to(product.device) 165 | loss_fct = CrossEntropyLoss() 166 | loss = loss_fct(product, target) 167 | # if not return_dict: 168 | output = [a_sequence_output,pos_b_sequence_output,neg_b_sequence_output]#(logits,) + a_sequence_outputs[2:] 169 | return ((loss,output)) if loss is not None else output 170 | -------------------------------------------------------------------------------- /evidence_chain/pretrain_data_process/negative_search_by_bm25/search_doc_from_db.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | from elastic_search import MyElastic, SearchQuery 4 | # from utils.config import args 5 | # from utils.rule_pattern import RulePattern 6 | # from utils import common 7 | import sys 8 | import argparse 9 | from multiprocessing import Pool, cpu_count 10 | from functools import partial 11 | import json 12 | sys.path.append('../../') 13 | from qa.utils_qa import read_jsonl 14 | # spacy.download('en_core_web_lg') 15 | ES = MyElastic() 16 | 17 | def convert_chain_to_string(chain): 18 | chain_rep = [] 19 | for node in chain[1:]: 20 | # print(node) 21 | if node['origin']['where'] == 'table': 22 | prefix = '[TAB] ' 23 | elif node['origin']['where'] == 'passage': 24 | prefix = '[PASSAGE] {} : '.format(node['origin']['index'].replace('/wiki/', '').replace('_', ' ').split('/')[0]) 25 | else: 26 | prefix = '[QUESTION] ' 27 | chain_rep.append(prefix+node['content']) 28 | # print(' [TO] '.join(chain_rep)) 29 | # input() 30 | return ' ; '.join(chain_rep) 31 | def get_result(query, res): 32 | ranked = res['hits']['hits'] 33 | end = min(len(ranked), 5) 34 | negs = [item['_source']['text'] for item in ranked[:end] if query != item['_source']['text']] 35 | return negs 36 | 37 | def search_from_db_single(data): 38 | try: 39 | res = ES.search_by_chain(data['ec']) 40 | neg_chain = get_result(data['ec'],res) 41 | # data['neg_chains'] = neg_chain 42 | return {'question':data['output'],'chain':data['ec'],'neg_chains':neg_chain,'tb':data['tb']} 43 | except Exception as e: 44 | # print(query) 45 | print(e) 46 | return None 47 | 48 | 49 | def 
search_from_db_pretrain(all_data,outf): 50 | def get_result(query,res): 51 | ranked = res['hits']['hits'] 52 | end = min(len(ranked),5) 53 | negs = [item['_source']['text'] for item in ranked[:end] if query!=item['_source']['text']] 54 | return negs 55 | error_cnt,all_cnt=0,0 56 | for did, data in tqdm(enumerate(all_data)): 57 | all_cnt += 1 58 | try: 59 | res = ES.search_by_chain(data['ec']) 60 | neg_chain = get_result(data['ec'],res) 61 | data['neg_chains'] = neg_chain 62 | if did<=5: 63 | print(data['ec']) 64 | print(neg_chain) 65 | print('---------------------') 66 | outf.write(json.dumps(data)+'\n') 67 | except Exception as e: 68 | # print(query) 69 | print(e) 70 | error_cnt += 1 71 | 72 | 73 | print(error_cnt,all_cnt,error_cnt/max(all_cnt,1)) 74 | return all_data 75 | 76 | def get_result_finetune(query,res): 77 | ranked = res['hits']['hits'] 78 | end = min(len(ranked),5) 79 | negs = [item['_source']['chain'] for item in ranked[:end] if query!=item['_source']['text']] 80 | return negs 81 | def search_from_db_ft_single(data): 82 | for tbid, tb in enumerate(data['positive_table_blocks']): 83 | data['positive_table_blocks'][tbid]['evidence_chain']['es_negative']=[] 84 | for ecid, ec in enumerate(tb['evidence_chain']['positive']): 85 | ec_rep = convert_chain_to_string(ec) 86 | data['positive_table_blocks'][tbid]['evidence_chain']['es_negative'].append([]) 87 | try: 88 | res = ES.search_by_chain(ec_rep) 89 | neg_chain = get_result_finetune(ec_rep, res) 90 | data['positive_table_blocks'][tbid]['evidence_chain']['es_negative'][ecid] = neg_chain 91 | except Exception as e: 92 | print(e) 93 | data['positive_table_blocks'][tbid]['evidence_chain']['es_negative'][ecid] = None 94 | return data 95 | 96 | def search_from_db_finetune(all_data): 97 | def get_result(query,res): 98 | ranked = res['hits']['hits'] 99 | end = min(len(ranked),5) 100 | negs = [item['_source']['chain'] for item in ranked[:end] if query!=item['_source']['text']] 101 | return negs 102 | error_cnt,all_cnt=0,0 103 | for did, data in tqdm(enumerate(all_data)): 104 | 105 | for tbid,tb in enumerate(data['positive_table_blocks']): 106 | for ecid, ec in enumerate(tb['evidence_chain']['positive']): 107 | ec_rep = convert_chain_to_string(ec) 108 | all_cnt += 1 109 | try: 110 | res = ES.search_by_chain(convert_chain_to_string(ec)) 111 | neg_chain = get_result(ec_rep, res) 112 | all_data[did]['positive_table_blocks'][tbid]['evidence_chain']['es_negative'] = neg_chain 113 | except Exception as e: 114 | print(e) 115 | error_cnt +=1 116 | 117 | print(error_cnt,all_cnt,error_cnt/max(all_cnt,1)) 118 | return all_data 119 | if __name__ == "__main__": 120 | parser = argparse.ArgumentParser() 121 | parser.add_argument('--pretrain', action='store_true') 122 | parser.add_argument('--finetune', action='store_true') 123 | args = parser.parse_args() 124 | if args.pretrain: 125 | basic_dir = './ODQA/data/evidence_chain_data/bart_output_for_pretraining/pre-training/evidence_output_pretrain_shortest.json' 126 | file_path = os.path.join(basic_dir,'') 127 | inf = open(basic_dir,'r') 128 | data = [] 129 | cnt = 0 130 | for line in inf: 131 | cnt += 1 132 | if cnt <= 1500000: 133 | continue 134 | data.append(json.loads(line.strip())) 135 | 136 | # data = [json.loads(line.strip()) for line in open(basic_dir,'r').readlines()] 137 | n_threads = 20 138 | with open('./ODQA/data/evidence_chain_data/bart_output_for_pretraining/add_negatives/pre-training/evidence_output_pretrain_shortest_esnegs-2.json', 'w') as outf: 139 | # results = 
search_from_db_pretrain(data,outf) 140 | running_function = search_from_db_single 141 | with Pool(n_threads) as p: 142 | func_ = partial(running_function) 143 | all_results = list(tqdm(p.imap(func_, data, chunksize=16), total=len(data), 144 | desc="find negatives", )) 145 | for result in all_results: 146 | 147 | if result: 148 | outf.write(json.dumps(result) + '\n') 149 | if args.finetune: 150 | basic_dir = './ODQA/data/evidence_chain_data/ground-truth-based/ground-truth-evidence-chain/' 151 | file_paths = [os.path.join(basic_dir, file) for file in 152 | ['train_gt-ec-weighted.json', 'dev_gt-ec-weighted.json']] 153 | n_threads = 20 154 | for file_path in file_paths: 155 | with open(file_path, 'r', encoding='utf8') as inf: 156 | data=json.load(inf) 157 | # data = search_from_db_finetune(data) 158 | # outf = open(file_path.replace('.json','-esneg.json'),'w',encoding='utf8') 159 | # print('Saving output to {}'.format(file_path.replace('.json','esneg.json'))) 160 | # del data 161 | running_function = search_from_db_ft_single 162 | with Pool(n_threads) as p: 163 | func_ = partial(running_function) 164 | all_results = list(tqdm(p.imap(func_, data, chunksize=16), total=len(data), 165 | desc="find negatives", )) 166 | print([item['positive_table_blocks'] for item in all_results[:2]]) 167 | outf = open(file_path.replace('.json', '-esneg.json'), 'w', encoding='utf8') 168 | print('Saving output to {}'.format(file_path.replace('.json','-esneg.json'))) 169 | json.dump(all_results,outf) 170 | # for result in all_results: 171 | # if result: 172 | # outf.write(json.dumps(result) + '\n') 173 | 174 | -------------------------------------------------------------------------------- /retriever/retrieval/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import argparse 7 | from ast import parse 8 | from typing import NamedTuple 9 | 10 | class ClusterConfig(NamedTuple): 11 | dist_backend: str 12 | dist_url: str 13 | 14 | def common_args(): 15 | parser = argparse.ArgumentParser() 16 | 17 | # task 18 | parser.add_argument("--train_file", type=str, 19 | default="../data/nq-with-neg-train.txt") 20 | parser.add_argument("--predict_file", type=str, 21 | default="../data/nq-with-neg-dev.txt") 22 | parser.add_argument("--num_workers", default=30, type=int) 23 | parser.add_argument("--do_train", default=False, 24 | action='store_true', help="Whether to run training.") 25 | parser.add_argument("--do_predict", default=False, 26 | action='store_true', help="Whether to run eval on the dev set.") 27 | parser.add_argument("--basic_data_path", 28 | default='/home/t-wzhong/v-wanzho/ODQA/data/',type=str) 29 | # model 30 | parser.add_argument("--model_name", 31 | default="bert-base-uncased", type=str) 32 | parser.add_argument("--init_checkpoint", type=str, 33 | help="Initial checkpoint (usually from a pre-trained BERT model).", 34 | default="") 35 | parser.add_argument("--max_c_len", default=512, type=int, 36 | help="The maximum total input sequence length after WordPiece tokenization. Sequences " 37 | "longer than this will be truncated, and sequences shorter than this will be padded.") 38 | parser.add_argument("--max_q_len", default=70, type=int, 39 | help="The maximum number of tokens for the question. 
Questions longer than this will " 40 | "be truncated to this length.") 41 | parser.add_argument("--max_p_len", default=360, type=int, 42 | help="The maximum number of tokens for the passage. Passages longer than this will " 43 | "be truncated to this length.") 44 | parser.add_argument("--psg_mode", type=str, help="ways to use linked psg.", default="ori", 45 | choices=["ori", "s_sent", "s_psg"]) 46 | parser.add_argument("--cell_trim_length", default=20, type=int, 47 | help="The maximum number of tokens for each cell. Cells longer than this will " 48 | "be truncated to this length.") 49 | parser.add_argument('--fp16', action='store_true') 50 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 51 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 52 | "See details at https://nvidia.github.io/apex/amp.html") 53 | parser.add_argument("--no_cuda", default=False, action='store_true', 54 | help="Whether not to use CUDA when available") 55 | parser.add_argument("--local_rank", type=int, default=-1, 56 | help="local_rank for distributed training on gpus") 57 | parser.add_argument("--max_q_sp_len", default=50, type=int) 58 | parser.add_argument("--sent_level", action="store_true") 59 | parser.add_argument("--rnn_retriever", action="store_true") 60 | parser.add_argument("--predict_batch_size", default=512, 61 | type=int, help="Total batch size for predictions.") 62 | 63 | # multi vector scheme 64 | parser.add_argument("--multi_vector", type=int, default=1) 65 | parser.add_argument("--scheme", type=str, help="how to get the multivector, layerwise or tokenwise", default="none") 66 | parser.add_argument("--metadata", action="store_true", help="whether to add meta data (set the flag to enable)") 67 | parser.add_argument("--add_special_tokens", action="store_true", help="whether to add special tokens (set the flag to enable)") 68 | 69 | # NQ multihop trial 70 | parser.add_argument("--nq_multi", action="store_true", help="train the NQ retrieval model to recover from error cases") 71 | 72 | return parser 73 | 74 | def train_args(): 75 | parser = common_args() 76 | # optimization 77 | parser.add_argument('--prefix', type=str, default="eval") 78 | parser.add_argument("--weight_decay", default=0.0, type=float, 79 | help="Weight decay if we apply some.") 80 | parser.add_argument("--temperature", default=1, type=float) 81 | parser.add_argument("--output_dir", default="./logs", type=str, 82 | help="The output directory where the model checkpoints will be written.") 83 | parser.add_argument("--save_tensor_path", default="", type=str, 84 | help="The output directory where the training tensors will be stored.") 85 | parser.add_argument("--data_augmentation", action="store_true") 86 | parser.add_argument('--augment_file', type=str, default="") 87 | parser.add_argument("--overwrite_cache", action="store_true", help="Overwrite the cached training sets") 88 | parser.add_argument("--train_batch_size", default=64, 89 | type=int, help="Total batch size for training.") 90 | parser.add_argument("--per_gpu_train_batch_size", default=4,type=int, help="per-gpu batch size for training.") 91 | parser.add_argument("--learning_rate", default=2e-5, 92 | type=float, help="The initial learning rate for Adam.") 93 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 94 | help="Epsilon for Adam optimizer.") 95 | parser.add_argument("--num_train_epochs", default=50, type=float, 96 | help="Total number of training epochs to perform.") 97 | parser.add_argument("--save_checkpoints_steps", 
default=10000, type=int, 98 | help="How often to save the model checkpoint.") 99 | parser.add_argument("--iterations_per_loop", default=1000, type=int, 100 | help="How many steps to make in each estimator call.") 101 | parser.add_argument("--accumulate_gradients", type=int, default=1, 102 | help="Number of steps to accumulate gradient on (divide the batch_size and accumulate)") 103 | parser.add_argument('--seed', type=int, default=1997, 104 | help="random seed for initialization") 105 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 106 | help="Number of update steps to accumulate before performing a backward/update pass.") 107 | parser.add_argument('--eval_period', type=int, default=-1) 108 | parser.add_argument("--max_grad_norm", default=2.0, type=float, help="Max gradient norm.") 109 | parser.add_argument("--stop_drop", default=0, type=float) 110 | parser.add_argument("--use_adam", action="store_true") 111 | parser.add_argument("--warmup_ratio", default=0, type=float, help="Linear warmup over warmup_steps.") 112 | 113 | parser.add_argument("--train_link", action="store_true") 114 | parser.add_argument("--train_block", action="store_true") 115 | 116 | return parser.parse_args() 117 | 118 | def encode_args(): 119 | parser = common_args() 120 | parser.add_argument('--top_k', type=int, default=20,help='searching from topk retrieved documents') 121 | parser.add_argument('--tfidf_filter', action="store_true") 122 | parser.add_argument('--tfidf_result_file',type=str,default=None) 123 | parser.add_argument('--embed_save_path', type=str, default="") 124 | parser.add_argument('--tmp_data_save_path', type=str, default="") 125 | parser.add_argument('--encode_table', action="store_true", 126 | help="activate when encoding tables with passages, otherwise plain texts like questions") 127 | parser.add_argument('--is_query_embed', action="store_true") 128 | parser.add_argument('--no_passages', action="store_true") 129 | parser.add_argument('--save_links', action="store_true", help='save psg links to id2doc ') 130 | args = parser.parse_args() 131 | return args 132 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # CARP 2 | This repository provides the code, data, and models for training, evaluation, and inference of CARP, 3 | the chain-centric reasoning and pre-training framework for table-and-text open-domain question answering proposed in the paper [Reasoning over Hybrid Chain for Table-and-Text Open Domain Question Answering](https://arxiv.org/pdf/2201.05880.pdf). 4 | 5 | # Preprocessing 6 | 7 | ## Requirements 8 | 9 | ``` 10 | nltk 11 | fuzzywuzzy 12 | sklearn 13 | ``` 14 | 15 | ## Obtain Data 16 | 17 | Download OTT-QA and wiki tables: 18 | 19 | ```shell 20 | git clone https://github.com/wenhuchen/OTT-QA.git 21 | cd OTT-QA/data 22 | wget https://opendomainhybridqa.s3-us-west-2.amazonaws.com/all_plain_tables.json 23 | wget https://opendomainhybridqa.s3-us-west-2.amazonaws.com/all_passages.json 24 | cd ../.. 
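# Optional sanity check: confirm both downloads parse as JSON and print the number
# of tables / passages they contain (the exact counts depend on the OTT-QA release).
python -c "import json; print(len(json.load(open('OTT-QA/data/all_plain_tables.json'))))"
python -c "import json; print(len(json.load(open('OTT-QA/data/all_passages.json'))))"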
25 | cp -r OTT-QA/data data_wikitable 26 | cp -r OTT-QA/released_data data_ottqa 27 | ``` 28 | # Retrieval 29 | 30 | ## Data Preprocess 31 | 32 | #### Preprocess training data for retrieval 33 | 34 | ``` 35 | python retriever_preprocess.py \ 36 | --split train \ 37 | --nega intable_contra 38 | python retriever_preprocess.py \ 39 | --split dev \ 40 | --nega intable_contra 41 | ``` 42 | 43 | If you want to use the passages linked by BLINK, first download them from [all_constructed_blink_tables.json](https://github.com/zhongwanjun/CARP/releases/tag/blink-linked-table) and move the json file to `./data_wikitable`. After that, use the following commands to preprocess. 44 | 45 | ``` 46 | python retriever_preprocess.py \ 47 | --split train \ 48 | --nega intable_contra \ 49 | --replace_link_passages \ 50 | --aug_blink 51 | python retriever_preprocess.py \ 52 | --split dev \ 53 | --nega intable_contra \ 54 | --replace_link_passages \ 55 | --aug_blink 56 | ``` 57 | 58 | #### Build retrieval corpus 59 | 60 | ``` 61 | python corpus_preprocess.py --split table_corpus_blink 62 | ``` 63 | 64 | This script creates the corpus data used for inference. 65 | 66 | #### Train retriever 67 | 68 | ``` 69 | RUN_ID=0 70 | BASIC_PATH=. 71 | DATA_PATH=${BASIC_PATH}/preprocessed_data/retrieval 72 | TRAIN_DATA_PATH=${BASIC_PATH}/preprocessed_data/retrieval/train_intable_contra_blink_row.pkl 73 | DEV_DATA_PATH=${BASIC_PATH}/preprocessed_data/retrieval/dev_intable_contra_blink_row.pkl 74 | MODEL_PATH=${BASIC_PATH}/models/otter 75 | TABLE_CORPUS=table_corpus_blink 76 | mkdir ${MODEL_PATH} 77 | 78 | cd retriever/ 79 | python train_1hop_tb_retrieval.py \ 80 | --do_train \ 81 | --prefix ${RUN_ID} \ 82 | --predict_batch_size 800 \ 83 | --model_name roberta-base \ 84 | --shared_encoder \ 85 | --train_batch_size 64 \ 86 | --fp16 \ 87 | --max_c_len 512 \ 88 | --max_q_len 70 \ 89 | --metadata \ 90 | --num_train_epochs 20 \ 91 | --accumulate_gradients 1 \ 92 | --gradient_accumulation_steps 1 \ 93 | --warmup_ratio 0.1 \ 94 | --train_file ${TRAIN_DATA_PATH} \ 95 | --predict_file ${DEV_DATA_PATH} \ 96 | --output_dir ${MODEL_PATH} 97 | ``` 98 | 99 | #### Inference 100 | 101 | ##### Step 1: Encode table corpus and dev questions 102 | 103 | ``` 104 | python encode_corpus.py \ 105 | --do_predict \ 106 | --predict_batch_size 100 \ 107 | --model_name roberta-base \ 108 | --metadata \ 109 | --fp16 \ 110 | --max_c_len 512 \ 111 | --predict_file ${BASIC_PATH}/data_ottqa/dev.json \ 112 | --init_checkpoint ${MODEL_PATH}/checkpoint_best.pt \ 113 | --embed_save_path ${MODEL_PATH}/indexed_embeddings/question_dev 114 | ``` 115 | 116 | Encode the table-text block corpus; this takes about 3 hours. 117 | 118 | ``` 119 | python encode_corpus.py \ 120 | --do_predict \ 121 | --encode_table \ 122 | --metadata \ 123 | --predict_batch_size 1600 \ 124 | --model_name roberta-base \ 125 | --fp16 \ 126 | --max_c_len 512 \ 127 | --predict_file ${DATA_PATH}/${TABLE_CORPUS}.pkl \ 128 | --init_checkpoint ${MODEL_PATH}/checkpoint_best.pt \ 129 | --embed_save_path ${MODEL_PATH}/indexed_embeddings/${TABLE_CORPUS} 130 | ``` 131 | 132 | ##### Step 2: Build index and search with FAISS 133 | 134 | The reported results are table recalls. 
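For reference, the search step is a plain inner-product lookup over the saved `.npy` embeddings; a minimal sketch of what `eval_ottqa_retrieval.py` does internally (the file names below are placeholders for the paths passed on the command line):

```
import faiss
import numpy as np

queries = np.load("question_dev.npy").astype("float32")        # one row per question
corpus = np.load("table_corpus_blink.npy").astype("float32")   # one row per table-text block

index = faiss.IndexFlatIP(corpus.shape[1])      # exact inner-product search, no compression
index.add(corpus)
scores, block_ids = index.search(queries, 100)  # top-100 block ids per question, cf. --beam_size
```

The full evaluation, including recall computation, is run with the command below.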
135 | 136 | ``` 137 | python eval_ottqa_retrieval.py \ 138 | --raw_data_path ${BASIC_PATH}/data_ottqa/dev.json \ 139 | --eval_only_ans \ 140 | --query_embeddings_path ${MODEL_PATH}/indexed_embeddings/question_dev.npy \ 141 | --corpus_embeddings_path ${MODEL_PATH}/indexed_embeddings/${TABLE_CORPUS}.npy \ 142 | --id2doc_path ${MODEL_PATH}/indexed_embeddings/${TABLE_CORPUS}/id2doc.json \ 143 | --output_save_path ${MODEL_PATH}/indexed_embeddings/dev_output_k100_${TABLE_CORPUS}.json \ 144 | --beam_size 100 145 | ``` 146 | 147 | # Evidence Chain 148 | ## Data Preprocess 149 | ```angular2html 150 | cd preprocess/ 151 | export CONCAT_TBS=15 152 | export TABLE_CORPUS=table_corpus_metagptdoc 153 | export MODEL_PATH=./ODQA/data/retrieval_results 154 | python ../preprocessing/qa_preprocess.py \ 155 | --split dev \ 156 | --reprocess \ 157 | --add_link \ 158 | --topk_tbs ${CONCAT_TBS} \ 159 | --retrieval_results_file ${MODEL_PATH}/dev_output_k100_${TABLE_CORPUS}.json \ 160 | --qa_save_path ${MODEL_PATH}/dev_preprocessed_${TABLE_CORPUS}_k100cat${CONCAT_TBS}.json \ 161 | 2>&1 |tee ${MODEL_PATH}/run_logs/${TABLE_CORPUS}/preprocess_qa_dev_k100cat${CONCAT_TBS}.log; 162 | ``` 163 | ## Extraction Model Training 164 | ### Extract ground-truth evidence chain 165 | ```angular2html 166 | cd evidence_chain/extraction 167 | 1. first extract keywords for ground-truth/retrieved table blocks for train/dev set 168 | python extract_evidence_chain.py --split train/dev --extract_keywords --kw_extract_type ground-truth/retrieved 169 | 2. extract ground-truth evidence chain 170 | python extract_evidence_chain.py --split train/dev --extract_evidence_chain 171 | or extract ground-truth evidence chain & generate data for training bart generator 172 | python extract_evidence_chain.py --split train/dev --extract_evidence_chain --save_bart_training_data 173 | ``` 174 | ### Extract candidate evidence chain 175 | ```angular2html 176 | python extract_evidence_chain.py --split train/dev --extract_candidate_evidence_chain 177 | ``` 178 | ### Training Extraction Model 179 | ```angular2html 180 | cd evidence_chain/fine-tune 181 | bash run_evidence_train.sh 182 | ``` 183 | ### Evaluate ranked evidence chain by their score 184 | ```angular2html 185 | python evaluate_ranked_evidence_chain.py 186 | ``` 187 | ## Extraction Model Pre-training 188 | ### Data Preprocess 189 | ```angular2html 190 | cd evidence_chain/pretrain_data_process 191 | generate inference data for bart 192 | python parse_table_psg_link.py bart_inference_data 193 | (generate templated fake pre-train data 194 | python parse_table_psg_link.py fake_pretrain_data) 195 | ``` 196 | ### BART-based Generator 197 | ```angular2html 198 | 199 | ``` 200 | ### Pre-training 201 | ```angular2html 202 | cd evidence_chain/pretrain 203 | bash run_evidence_pretrain.sh 204 | ``` 205 | # QA 206 | ## Baseline 207 | ### Data Preprocess 208 | ```angular2html 209 | cd preprocess/ 210 | export CONCAT_TBS=15 211 | export TABLE_CORPUS=table_corpus_metagptdoc 212 | export MODEL_PATH=./ODQA/data/retrieval_results 213 | python qa_preprocess.py \ 214 | --split dev \ 215 | --reprocess \ 216 | --add_link \ 217 | --topk_tbs ${CONCAT_TBS} \ 218 | --retrieval_results_file ${MODEL_PATH}/dev_output_k100_${TABLE_CORPUS}.json \ 219 | --qa_save_path ${MODEL_PATH}/dev_preprocessed_${TABLE_CORPUS}_k100cat${CONCAT_TBS}.json \ 220 | 2>&1 |tee ${MODEL_PATH}/run_logs/${TABLE_CORPUS}/preprocess_qa_dev_k100cat${CONCAT_TBS}.log; 221 | ``` 222 | ### Baseline QA Model Training 223 | ``` 224 | cd qa_baseline/ 225 | bash 
train_qa_baseline.sh 226 | ``` 227 | ### CARP QA Model Training 228 | ```angular2html 229 | #merge ground-truth and retrieved evidence chains for question answering 230 | cd preprocess 231 | python merge_ec_file.py 232 | cd qa_evidence_chain/ 233 | bash train_qa_evidence_chain_retrieved.sh 234 | ``` 235 | ### CARP QA Model Testing 236 | ```angular2html 237 | Eval the model with checkpoint: 238 | cd qa_evidence_chain/ 239 | bash test_qa_evidence_chain_retrieved.sh 240 | ``` 241 | # Data Information 242 | | File Type | File Name | File Location | 243 | | ---- | ---- | ---- | 244 | | Source Corpus | all_passages.json (and) all_plain_tables.json | source_corpus/OTT-QA/ 245 | | Wikipedia tables and passages | all_tables.json | source_corpus/Wikipedia-table-passages 246 | | Retrieval Results | train/dev/test_output_k100_table_corpus_metagptdoc.json | retrieval_results/ 247 | | QA data with extracted evidence chain | train/dev_ranked_evidence_chain_for_qa_weighted.json / test_evidence_chain_weighted_scores.json | qa_with_evidence_chain 248 | 249 | [//]: # "| evidence chain pretrain/train/valid/test data | (for-pretraining) bart_output_for_pretraining / (for training) ground-truth-based / (for testing) retrieval_based | evidence_chain_data/ " 250 | 251 | # Citation 252 | If you find this resource useful, please cite the paper introducing CARP: 253 | 254 | ``` 255 | @article{zhong2022reasoning, 256 | title={Reasoning over Hybrid Chain for Table-and-Text Open Domain QA}, 257 | author={Zhong, Wanjun and Huang, Junjie and Liu, Qian and Zhou, Ming and Wang, Jiahai and Yin, Jian and Duan, Nan}, 258 | journal={arXiv preprint arXiv:2201.05880}, 259 | year={2022} 260 | } 261 | ``` 262 | -------------------------------------------------------------------------------- /retriever/eval_ottqa_retrieval.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('../') 4 | import argparse 5 | import collections 6 | import json 7 | import logging 8 | import os 9 | import time 10 | 11 | import faiss 12 | import numpy as np 13 | import torch 14 | from tqdm import tqdm 15 | from transformers import AutoConfig, AutoTokenizer 16 | from multiprocessing import Pool, cpu_count 17 | from functools import partial 18 | 19 | # from retrieval.models.mhop_retriever import RobertaRetriever 20 | from retrieval.utils.basic_tokenizer import SimpleTokenizer 21 | from retrieval.utils.utils import (load_saved, move_to_cuda, para_has_answer, found_table) 22 | 23 | logger = logging.getLogger() 24 | logger.setLevel(logging.INFO) 25 | if (logger.hasHandlers()): 26 | logger.handlers.clear() 27 | console = logging.StreamHandler() 28 | logger.addHandler(console) 29 | 30 | 31 | def convert_hnsw_query(query_vectors): 32 | aux_dim = np.zeros(len(query_vectors), dtype='float32') 33 | query_nhsw_vectors = np.hstack((query_vectors, aux_dim.reshape(-1, 1))) 34 | return query_nhsw_vectors 35 | 36 | 37 | if __name__ == '__main__': 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('--raw_data_path', type=str, default=None) 40 | parser.add_argument('--query_embeddings_path', type=str, default=None) 41 | parser.add_argument('--corpus_embeddings_path', type=str, default=None) 42 | parser.add_argument('--id2doc_path', type=str, default=None) 43 | parser.add_argument('--faiss_save_path', type=str, default="data/ottqa_index/ottqa_index_tapas") 44 | parser.add_argument("--output_save_path", type=str, default="") 45 | 46 | # parser.add_argument('--topk', type=int, default=5, 
help="topk paths") 47 | parser.add_argument('--num_workers', type=int, default=20) 48 | parser.add_argument('--beam_size', type=int, default=100) 49 | parser.add_argument('--gpu', action="store_true") 50 | parser.add_argument('--save_index', action="store_true") 51 | parser.add_argument('--eval_only_ans', action="store_true") 52 | parser.add_argument('--eval_table_id', action="store_true") 53 | parser.add_argument('--hnsw', action="store_true") 54 | args = parser.parse_args() 55 | 56 | logger.info("Loading data...") 57 | ds_items = json.load(open(args.raw_data_path, 'r', encoding='utf8')) 58 | # filter 59 | # if args.eval_only_ans: 60 | # ds_items = [_ for _ in ds_items if _["answer"][0] not in ["yes", "no"]] 61 | 62 | simple_tokenizer = SimpleTokenizer() 63 | 64 | logger.info("Building index...") 65 | logger.info("Loading question embeddings from {}".format(args.query_embeddings_path)) 66 | questions = [_["question"][:-1] if _["question"].endswith("?") else _["question"] for _ in ds_items] 67 | metrics = [] 68 | # metrics_eval_table_id = [] 69 | retrieval_outputs = [] 70 | query_embeddings = np.load(args.query_embeddings_path).astype('float32') 71 | 72 | 73 | logger.info("Loading corpus embeddings from {}".format(args.corpus_embeddings_path)) 74 | d = 768 75 | xb = np.load(args.corpus_embeddings_path).astype('float32') 76 | logger.info("corpus size: {}".format(xb.shape)) 77 | 78 | if args.hnsw: 79 | if os.path.exists(args.faiss_save_path): 80 | # index = faiss.read_index("index/ottqa_index_hnsw.index") 81 | index = faiss.read_index(args.faiss_save_path) 82 | else: 83 | index = faiss.IndexHNSWFlat(d + 1, 512) 84 | index.hnsw.efSearch = 128 85 | index.hnsw.efConstruction = 200 86 | phi = 0 87 | for i, vector in enumerate(xb): 88 | norms = (vector ** 2).sum() 89 | phi = max(phi, norms) 90 | logger.info('HNSWF DotProduct -> L2 space phi={}'.format(phi)) 91 | 92 | data = xb 93 | buffer_size = 50000 94 | n = len(data) 95 | logger.info(n) 96 | for i in tqdm(range(0, n, buffer_size)): 97 | vectors = [np.reshape(t, (1, -1)) for t in data[i:i + buffer_size]] 98 | norms = [(doc_vector ** 2).sum() for doc_vector in vectors] 99 | aux_dims = [np.sqrt(phi - norm) for norm in norms] 100 | hnsw_vectors = [np.hstack((doc_vector, aux_dims[idx].reshape(-1, 1))) for idx, doc_vector in 101 | enumerate(vectors)] 102 | hnsw_vectors = np.concatenate(hnsw_vectors, axis=0) 103 | index.add(hnsw_vectors) 104 | else: 105 | if os.path.exists(args.faiss_save_path): 106 | index = faiss.read_index(args.faiss_save_path) 107 | else: 108 | index = faiss.IndexFlatIP(d) 109 | index.add(xb) 110 | if args.gpu: 111 | res = faiss.StandardGpuResources() 112 | index = faiss.index_cpu_to_all_gpus(index) 113 | logger.info("Finish Building Index with IndexFlatIP") 114 | # if not os.path.exists(args.faiss_save_path): 115 | # faiss.write_index(index, args.faiss_save_path) 116 | 117 | logger.info(f"Loading corpus...") 118 | id2doc = json.load(open(args.id2doc_path)) 119 | if isinstance(id2doc["0"], list): 120 | id2doc = {k: {"title": v[0], "text": v[1]} for k, v in id2doc.items()} 121 | # title2text = {v[0]:v[1] for v in id2doc.values()} 122 | logger.info(f"Corpus size {len(id2doc)}") 123 | 124 | def searching(idx): 125 | # for idx, q_embeds_numpy in enumerate(tqdm(query_embeddings, desc='Processing: ')): 126 | q_embeds_numpy = np.expand_dims(query_embeddings[idx], 0) 127 | if args.hnsw: 128 | q_embeds_numpy = convert_hnsw_query(q_embeds_numpy) 129 | D, I = index.search(q_embeds_numpy, args.beam_size) 130 | 131 | b_idx = 0 132 | metric_i 
= {} 133 | output_i = {} 134 | topk_tbs = [] 135 | for _, tb_id in enumerate(I[b_idx]): 136 | tb = id2doc[str(tb_id)] 137 | topk_tbs.append(tb) 138 | if args.eval_only_ans: 139 | gold_answers = ds_items[idx]["answer-text"] 140 | metric_i = { 141 | "question": ds_items[idx]["question"], 142 | 'table_recall': int(found_table(ds_items[idx]['table_id'], topk_tbs)), 143 | "ans_recall": int(para_has_answer(gold_answers, topk_tbs, simple_tokenizer)), 144 | "type": ds_items[idx].get("type", "single") 145 | } 146 | if args.output_save_path != "": 147 | output_i = {"question_id": ds_items[idx]["question_id"], 148 | "question": ds_items[idx]["question"], 149 | "top_{}".format(args.beam_size): topk_tbs, } 150 | if 'answer-text' in ds_items[idx]: 151 | output_i['answer-text'] = ds_items[idx]['answer-text'] 152 | if 'table_id' in ds_items[idx]: 153 | output_i['table_id'] = ds_items[idx]['table_id'] 154 | return metric_i, output_i 155 | 156 | n_threads = 24 157 | with Pool(n_threads) as p: 158 | # func_ = partial(searching) 159 | results = list( 160 | tqdm(p.imap(searching, range(len(query_embeddings)), chunksize=16), total=len(query_embeddings), desc="Searching: ", )) 161 | metrics, retrieval_outputs = [item[0] for item in results], [item[1] for item in results] 162 | 163 | if args.output_save_path != "": 164 | with open(args.output_save_path, "w") as out: 165 | for l in retrieval_outputs: 166 | out.write(json.dumps(l) + "\n") 167 | logger.info("Saving outputs to {}".format(args.output_save_path)) 168 | 169 | 170 | def get_recall(answers, preds, n=5): 171 | truth_table = [] 172 | for idx, ans in enumerate(answers): 173 | truth_table.append(any([ans == inst for inst in preds[idx][:n]])) 174 | return sum(truth_table) / len(truth_table), sum(truth_table), len(truth_table) 175 | 176 | if 'test' not in args.raw_data_path: 177 | table_id_gold = [inst['table_id'] for inst in retrieval_outputs] 178 | table_id_preds = [[item['table_id'] for item in inst['top_{}'.format(args.beam_size)]] for inst in retrieval_outputs] 179 | r, t, a = get_recall(table_id_gold, table_id_preds, 1) 180 | logger.info("Table Recall @1: {}, {}/{}".format(r, t, a)) 181 | r, t, a = get_recall(table_id_gold, table_id_preds, 5) 182 | logger.info("Table Recall @5: {}, {}/{}".format(r, t, a)) 183 | r, t, a = get_recall(table_id_gold, table_id_preds, 10) 184 | logger.info("Table Recall @10: {}, {}/{}".format(r, t, a)) 185 | r, t, a = get_recall(table_id_gold, table_id_preds, 15) 186 | logger.info("Table Recall @15: {}, {}/{}".format(r, t, a)) 187 | r, t, a = get_recall(table_id_gold, table_id_preds, 20) 188 | logger.info("Table Recall @20: {}, {}/{}".format(r, t, a)) 189 | r, t, a = get_recall(table_id_gold, table_id_preds, 50) 190 | logger.info("Table Recall @50: {}, {}/{}".format(r, t, a)) 191 | r, t, a = get_recall(table_id_gold, table_id_preds, 100) 192 | logger.info("Table Recall @100: {}, {}/{}".format(r, t, a)) 193 | 194 | if args.eval_only_ans: 195 | logger.info(f"Evaluating {len(metrics)} samples...") 196 | type2items = collections.defaultdict(list) 197 | for item in metrics: 198 | type2items[item["type"]].append(item) 199 | logger.info(f'Ans Recall: {np.mean([m["ans_recall"] for m in metrics])}') 200 | logger.info(f'Table Recall: {np.mean([m["table_recall"] for m in metrics])}') 201 | for t in type2items.keys(): 202 | logger.info(f"{t} Questions num: {len(type2items[t])}") 203 | logger.info(f'Ans Recall: {np.mean([m["ans_recall"] for m in type2items[t]])}') 204 | 205 | 206 | --------------------------------------------------------------------------------
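A note on the pairwise ranking objective implemented in `evidence_chain/ranking_model/model.py` above: the question (with its table block) and each candidate evidence chain are encoded separately, and the model is trained with in-batch negatives plus one mined hard negative per example. A minimal self-contained sketch of that loss, with illustrative tensor names rather than the repository's own API:

```
import torch
from torch import nn

def in_batch_negative_loss(q, pos, neg):
    # q, pos, neg: [batch, hidden] CLS embeddings of the query, its positive
    # evidence chain, and one mined (e.g. BM25-retrieved) negative chain.
    scores_pos = torch.mm(q, pos.t())                     # [batch, batch]: other positives act as negatives
    scores_neg = (q * neg).sum(-1, keepdim=True)          # [batch, 1]: explicit hard negative
    scores = torch.cat([scores_pos, scores_neg], dim=-1)  # [batch, batch + 1]
    target = torch.arange(q.size(0), device=q.device)     # the i-th query should rank its own positive first
    return nn.functional.cross_entropy(scores, target)
```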