├── utils
    ├── __init__.py
    ├── config.py
    └── common.py
├── generator
    └── __init__.py
├── preprocess
    ├── __init__.py
    ├── preprocess_qa.sh
    └── merge_ec_file.py
├── qa_baseline
    ├── __init__.py
    ├── train_qa_baseline.sh
    └── config.py
├── qa_evidence_chain
    ├── __init__.py
    ├── test_qa_evidence_chain_retrieved.sh
    └── train_qa_evidence_chain_retrieved.sh
├── evidence_chain
    ├── extraction
    │   ├── __init__.py
    │   └── extraction_kws_ec.sh
    ├── fine-tune
    │   ├── __init__.py
    │   ├── run_evidence_train.sh
    │   ├── run_evidence_large_eval.sh
    │   └── run_evidence_train_from_pretrain.sh
    ├── pretrain
    │   ├── __init__.py
    │   ├── run_evidence_pretrain.sh
    │   ├── run_evidence_large_eval.sh
    │   └── run_evidence_base_train.sh
    ├── ranking_model
    │   ├── __init__.py
    │   ├── run_evidence_large_eval.sh
    │   ├── run_evidence_large_train.sh
    │   └── model.py
    ├── pretrain_data_process
    │   ├── __init__.py
    │   ├── path_extraction
    │   │   └── __init__.py
    │   └── negative_search_by_bm25
    │   │   ├── __init__.py
    │   │   ├── create_db.py
    │   │   ├── elastic_search.py
    │   │   └── search_doc_from_db.py
    ├── .DS_Store
    └── evaluate_ranked_evidence_chain.py
├── retriever
    ├── retrieval
    │   ├── data
    │   │   └── __init__.py
    │   ├── drqa
    │   │   ├── __init__.py
    │   │   ├── retriever
    │   │   │   ├── __init__.py
    │   │   │   ├── doc_db.py
    │   │   │   ├── BM25_doc_ranker.py
    │   │   │   ├── tfidf_doc_ranker.py
    │   │   │   └── utils.py
    │   │   └── drqa_tokenizers
    │   │   │   ├── __init__.py
    │   │   │   ├── simple_tokenizer.py
    │   │   │   ├── spacy_tokenizer.py
    │   │   │   ├── regexp_tokenizer.py
    │   │   │   ├── tokenizer.py
    │   │   │   └── corenlp_tokenizer.py
    │   ├── models
    │   │   └── __init__.py
    │   ├── utils
    │   │   └── __init__.py
    │   ├── attention
    │   │   ├── __pycache__
    │   │   │   ├── AFT.cpython-38.pyc
    │   │   │   ├── BAM.cpython-38.pyc
    │   │   │   ├── PSA.cpython-38.pyc
    │   │   │   ├── SGE.cpython-38.pyc
    │   │   │   ├── ViP.cpython-38.pyc
    │   │   │   ├── CBAM.cpython-38.pyc
    │   │   │   ├── DANet.cpython-38.pyc
    │   │   │   ├── EMSA.cpython-38.pyc
    │   │   │   ├── CoAtNet.cpython-38.pyc
    │   │   │   ├── __init__.cpython-38.pyc
    │   │   │   ├── SEAttention.cpython-38.pyc
    │   │   │   ├── SKAttention.cpython-38.pyc
    │   │   │   ├── A2Atttention.cpython-38.pyc
    │   │   │   ├── ECAAttention.cpython-38.pyc
    │   │   │   ├── MUSEAttention.cpython-38.pyc
    │   │   │   ├── SelfAttention.cpython-38.pyc
    │   │   │   ├── OutlookAttention.cpython-38.pyc
    │   │   │   ├── ShuffleAttention.cpython-38.pyc
    │   │   │   ├── ExternalAttention.cpython-38.pyc
    │   │   │   └── SimplifiedSelfAttention.cpython-38.pyc
    │   │   ├── ExternalAttention.py
    │   │   ├── ECAAttention.py
    │   │   ├── SEAttention.py
    │   │   ├── SGE.py
    │   │   ├── DANet.py
    │   │   ├── ViP.py
    │   │   ├── AFT.py
    │   │   ├── A2Atttention.py
    │   │   ├── PSA.py
    │   │   ├── OutlookAttention.py
    │   │   ├── SKAttention.py
    │   │   ├── CBAM.py
    │   │   ├── ShuffleAttention.py
    │   │   ├── CoAtNet.py
    │   │   ├── SimplifiedSelfAttention.py
    │   │   ├── SelfAttention.py
    │   │   ├── BAM.py
    │   │   ├── MUSEAttention.py
    │   │   ├── EMSA.py
    │   │   └── HaloAttention.py
    │   ├── __init__.py
    │   ├── ___tr_dataset.py
    │   ├── tfidf_retriever.py
    │   ├── criterions.py
    │   └── config.py
    ├── encode_corpus.py
    └── eval_ottqa_retrieval.py
└── readme.md
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/generator/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/preprocess/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/qa_baseline/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/qa_evidence_chain/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/evidence_chain/extraction/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/evidence_chain/fine-tune/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/evidence_chain/pretrain/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/retriever/retrieval/data/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/retriever/retrieval/drqa/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/retriever/retrieval/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/retriever/retrieval/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/evidence_chain/ranking_model/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/evidence_chain/pretrain_data_process/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/evidence_chain/pretrain_data_process/path_extraction/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/evidence_chain/pretrain_data_process/negative_search_by_bm25/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/evidence_chain/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhongwanjun/CARP/HEAD/evidence_chain/.DS_Store
--------------------------------------------------------------------------------
/retriever/retrieval/attention/__pycache__/AFT.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhongwanjun/CARP/HEAD/retriever/retrieval/attention/__pycache__/AFT.cpython-38.pyc
--------------------------------------------------------------------------------
/retriever/retrieval/attention/__pycache__/BAM.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhongwanjun/CARP/HEAD/retriever/retrieval/attention/__pycache__/BAM.cpython-38.pyc -------------------------------------------------------------------------------- /retriever/retrieval/attention/__pycache__/PSA.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhongwanjun/CARP/HEAD/retriever/retrieval/attention/__pycache__/PSA.cpython-38.pyc -------------------------------------------------------------------------------- /retriever/retrieval/attention/__pycache__/SGE.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhongwanjun/CARP/HEAD/retriever/retrieval/attention/__pycache__/SGE.cpython-38.pyc -------------------------------------------------------------------------------- /retriever/retrieval/attention/__pycache__/ViP.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhongwanjun/CARP/HEAD/retriever/retrieval/attention/__pycache__/ViP.cpython-38.pyc -------------------------------------------------------------------------------- /retriever/retrieval/attention/__pycache__/CBAM.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhongwanjun/CARP/HEAD/retriever/retrieval/attention/__pycache__/CBAM.cpython-38.pyc -------------------------------------------------------------------------------- /retriever/retrieval/attention/__pycache__/DANet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhongwanjun/CARP/HEAD/retriever/retrieval/attention/__pycache__/DANet.cpython-38.pyc -------------------------------------------------------------------------------- /retriever/retrieval/attention/__pycache__/EMSA.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhongwanjun/CARP/HEAD/retriever/retrieval/attention/__pycache__/EMSA.cpython-38.pyc -------------------------------------------------------------------------------- /retriever/retrieval/attention/__pycache__/CoAtNet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhongwanjun/CARP/HEAD/retriever/retrieval/attention/__pycache__/CoAtNet.cpython-38.pyc -------------------------------------------------------------------------------- /retriever/retrieval/attention/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhongwanjun/CARP/HEAD/retriever/retrieval/attention/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /retriever/retrieval/attention/__pycache__/SEAttention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhongwanjun/CARP/HEAD/retriever/retrieval/attention/__pycache__/SEAttention.cpython-38.pyc -------------------------------------------------------------------------------- /retriever/retrieval/attention/__pycache__/SKAttention.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhongwanjun/CARP/HEAD/retriever/retrieval/attention/__pycache__/SKAttention.cpython-38.pyc -------------------------------------------------------------------------------- /retriever/retrieval/attention/__pycache__/A2Atttention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhongwanjun/CARP/HEAD/retriever/retrieval/attention/__pycache__/A2Atttention.cpython-38.pyc -------------------------------------------------------------------------------- /retriever/retrieval/attention/__pycache__/ECAAttention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhongwanjun/CARP/HEAD/retriever/retrieval/attention/__pycache__/ECAAttention.cpython-38.pyc -------------------------------------------------------------------------------- /retriever/retrieval/attention/__pycache__/MUSEAttention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhongwanjun/CARP/HEAD/retriever/retrieval/attention/__pycache__/MUSEAttention.cpython-38.pyc -------------------------------------------------------------------------------- /retriever/retrieval/attention/__pycache__/SelfAttention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhongwanjun/CARP/HEAD/retriever/retrieval/attention/__pycache__/SelfAttention.cpython-38.pyc -------------------------------------------------------------------------------- /retriever/retrieval/attention/__pycache__/OutlookAttention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhongwanjun/CARP/HEAD/retriever/retrieval/attention/__pycache__/OutlookAttention.cpython-38.pyc -------------------------------------------------------------------------------- /retriever/retrieval/attention/__pycache__/ShuffleAttention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhongwanjun/CARP/HEAD/retriever/retrieval/attention/__pycache__/ShuffleAttention.cpython-38.pyc -------------------------------------------------------------------------------- /retriever/retrieval/attention/__pycache__/ExternalAttention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhongwanjun/CARP/HEAD/retriever/retrieval/attention/__pycache__/ExternalAttention.cpython-38.pyc -------------------------------------------------------------------------------- /retriever/retrieval/attention/__pycache__/SimplifiedSelfAttention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhongwanjun/CARP/HEAD/retriever/retrieval/attention/__pycache__/SimplifiedSelfAttention.cpython-38.pyc -------------------------------------------------------------------------------- /retriever/retrieval/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | #!/usr/bin/env python 7 | # Copyright 2017-present, Facebook, Inc. 
8 | # All rights reserved. 9 | # 10 | # This source code is licensed under the license found in the 11 | # LICENSE file in the root directory of this source tree. 12 | 13 | from . import data 14 | from . import models 15 | from . import utils 16 | -------------------------------------------------------------------------------- /retriever/retrieval/drqa/retriever/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | 10 | def get_class(name): 11 | if name == 'tfidf': 12 | return TfidfDocRanker 13 | if name == 'bm25': 14 | return BM25DocRanker 15 | if name == 'sqlite': 16 | return DocDB 17 | raise RuntimeError('Invalid retriever class: %s' % name) 18 | 19 | 20 | from .doc_db import DocDB 21 | from .tfidf_doc_ranker import TfidfDocRanker 22 | from .BM25_doc_ranker import BM25DocRanker -------------------------------------------------------------------------------- /preprocess/preprocess_qa.sh: -------------------------------------------------------------------------------- 1 | export CONCAT_TBS=15 2 | export TABLE_CORPUS=table_corpus_metagptdoc 3 | export MODEL_PATH=./ODQA/retrieval_results/shared_roberta_threecat_basic_mean_one_query 4 | #python ../preprocessing/qa_preprocess.py \ 5 | # --split dev \ 6 | # --reprocess \ 7 | # --add_link \ 8 | # --topk_tbs ${CONCAT_TBS} \ 9 | # --retrieval_results_file ${MODEL_PATH}/dev_output_k100_${TABLE_CORPUS}.json \ 10 | # --qa_save_path ${MODEL_PATH}/dev_preprocessed_${TABLE_CORPUS}_k100cat${CONCAT_TBS}.json \ 11 | # 2>&1 |tee ${MODEL_PATH}/run_logs/${TABLE_CORPUS}/preprocess_qa_dev_k100cat${CONCAT_TBS}.log; 12 | python ../preprocessing/qa_preprocess.py \ 13 | --split train \ 14 | --reprocess \ 15 | --add_link \ 16 | --topk_tbs ${CONCAT_TBS} \ 17 | --retrieval_results_file ${MODEL_PATH}/train_output_k100_${TABLE_CORPUS}.json \ 18 | --qa_save_path ${MODEL_PATH}/train_preprocessed_${TABLE_CORPUS}_k100cat${CONCAT_TBS}.json \ 19 | 2>&1 |tee ${MODEL_PATH}/run_logs/${TABLE_CORPUS}/preprocess_qa_train_k100cat${CONCAT_TBS}.log; -------------------------------------------------------------------------------- /evidence_chain/pretrain/run_evidence_pretrain.sh: -------------------------------------------------------------------------------- 1 | export RUN_ID=5 2 | export BASIC_PATH=./ODQA 3 | export DATA_PATH=${BASIC_PATH}/data/evidence_chain_data/bart_output_for_pretraining/pre-training 4 | export TRAIN_DATA_PATH=evidence_pretrain_train_all_esnegs.jsonl 5 | export DEV_DATA_PATH=evidence_pretrain_dev_all_esnegs.jsonl 6 | export MODEL_PATH=${BASIC_PATH}/model/evidence_chain/roberta-base-pretrain-1neg-weighted-esneg 7 | python run_classifier.py \ 8 | --model_type roberta \ 9 | --model_name_or_path roberta-base \ 10 | --task_name evidence_chain \ 11 | --do_train \ 12 | --do_eval \ 13 | --eval_all_checkpoints \ 14 | --data_dir ${DATA_PATH} \ 15 | --output_dir ${MODEL_PATH} \ 16 | --train_file ${TRAIN_DATA_PATH} \ 17 | --dev_file ${DEV_DATA_PATH} \ 18 | --max_seq_length 512 \ 19 | --per_gpu_train_batch_size 32 \ 20 | --per_gpu_eval_batch_size 64 \ 21 | --learning_rate 3e-5 \ 22 | --evaluate_during_training \ 23 | --overwrite_cache \ 24 | --save_steps 6000 \ 25 | --num_train_epochs 3 \ 26 | 2>&1 | tee log_pretrain_3weighted.log 27 | 
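28 | # Note: this script pre-trains the RoBERTa-base evidence-chain classifier with
29 | # run_classifier.py on the BART-generated pre-training data; BASIC_PATH (./ODQA)
30 | # and the *_esnegs.jsonl file names are placeholders that should be pointed at the
31 | # local data/model layout before running.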
-------------------------------------------------------------------------------- /evidence_chain/extraction/extraction_kws_ec.sh: -------------------------------------------------------------------------------- 1 | #extract keywords for retrieved or ground-truth evidence block 2 | python extract_evidence_chain.py --split train/dev --extract_keywords --kw_extract_type ground-truth/retrieved 3 | #extract ground-truth evidence chain 4 | python extract_evidence_chain.py --split train/dev --extract_evidence_chain 5 | #extract ground-truth evidence chain and save evidence chain for training bart 6 | python extract_evidence_chain.py --split train/dev --extract_evidence_chain --save_bart_training_data 7 | #extract candidate evidence chain 8 | python extract_evidence_chain.py --split train/dev --extract_candidate_evidence_chain 9 | #evaluate ranked evidence chain 10 | python evaluate_ranked_evidence_chain.py 11 | 12 | #extract pretrain data 13 | cd path_extraction 14 | #generate inference data for bart 15 | python parse_table_psg_link.py bart_inference_data 16 | #generate templated pretrain data 17 | #generate fake training and dev evidence chains 18 | python parse_table_psg_link.py fake_pretrain_data 19 | 20 | -------------------------------------------------------------------------------- /evidence_chain/fine-tune/run_evidence_train.sh: -------------------------------------------------------------------------------- 1 | export RUN_ID=5 2 | export BASIC_PATH=./ODQA 3 | export DATA_PATH=${BASIC_PATH}/data/evidence_chain_data/ground-truth-based/ground-truth-evidence-chain 4 | export TRAIN_DATA_PATH=train_gt-ec.jsonl 5 | export DEV_DATA_PATH=dev_gt-ec.jsonl 6 | export MODEL_PATH=${BASIC_PATH}/model/evidence_chain/roberta-large 7 | export PRETRAIN_MODEL_PATH=${BASIC_PATH}/model/evidence_chain/roberta-large-new/checkpoint-best 8 | CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python run_classifier.py \ 9 | --model_type roberta \ 10 | --tokenizer_name roberta-large \ 11 | --config_name roberta-large \ 12 | --model_name_or_path roberta-large \ 13 | --task_name evidence_chain \ 14 | --overwrite_cache \ 15 | --do_train \ 16 | --do_eval \ 17 | --eval_all_checkpoints \ 18 | --data_dir ${DATA_PATH} \ 19 | --output_dir ${MODEL_PATH} \ 20 | --train_file ${TRAIN_DATA_PATH} \ 21 | --dev_file ${DEV_DATA_PATH} \ 22 | --max_seq_length 512 \ 23 | --per_gpu_train_batch_size 16 \ 24 | --per_gpu_eval_batch_size 16 \ 25 | --learning_rate 1e-5 \ 26 | --num_train_epochs 10 27 | -------------------------------------------------------------------------------- /qa_evidence_chain/test_qa_evidence_chain_retrieved.sh: -------------------------------------------------------------------------------- 1 | export RUN_ID=1 2 | export BASIC_DIR=./ODQA 3 | export MODEL_NAME=allenai/longformer-base-4096 4 | export TOKENIZERS_PARALLELISM=false 5 | export TRAIN_DATA_PATH=train_ranked_evidence_chain_for_qa_weighted.json 6 | export DEV_DATA_PATH=dev_ranked_evidence_chain_for_qa_weighted.json 7 | export MODEL_DIR=${MODEL_CHECKPOINT} 8 | export PREFIX=retrieved_blink_ecmask 9 | python train_final_qa_ori.py \ 10 | --do_eval \ 11 | --model_type longformer \ 12 | --evaluate_during_training \ 13 | --data_dir ${BASIC_DIR}/data/qa_with_evidence_chain \ 14 | --output_dir ${BASIC_DIR}/model/qa_model/${MODEL_DIR} \ 15 | --train_file ${TRAIN_DATA_PATH} \ 16 | --dev_file ${DEV_DATA_PATH} \ 17 | --per_gpu_train_batch_size 2 \ 18 | --gradient_accumulation_steps 1 \ 19 | --learning_rate 1e-5 \ 20 | --num_train_epochs 10.0 \ 21 | --max_seq_length 4096 \ 22 | 
--doc_stride 3072 \ 23 | --threads 8 \ 24 | --topk_tbs 15 \ 25 | --model_name_or_path ${MODEL_NAME} \ 26 | --prefix ${PREFIX} \ 27 | --save_cache \ 28 | --overwrite_cache \ 29 | -------------------------------------------------------------------------------- /evidence_chain/pretrain/run_evidence_large_eval.sh: -------------------------------------------------------------------------------- 1 | export RUN_ID=5 2 | export BASIC_PATH=./ODQA 3 | export DATA_PATH=${BASIC_PATH}/data/evidence_chain_data/bart_output_for_pretraining/add_negatives/pre-training/ 4 | export TRAIN_DATA_PATH=evidence_pretrain_train_shortest_esnegs.jsonl 5 | export DEV_DATA_PATH=evidence_pretrain_dev_shortest_esnegs.jsonl 6 | export MODEL_PATH=${BASIC_PATH}/model/evidence_chain/roberta-base-pretrain-1neg-weighted-esneg 7 | python run_classifier.py \ 8 | --model_type roberta \ 9 | --model_name_or_path roberta-base \ 10 | --task_name evidence_chain \ 11 | --do_predict \ 12 | --eval_all_checkpoints \ 13 | --data_dir ${DATA_PATH} \ 14 | --output_dir ${MODEL_PATH} \ 15 | --train_file ${TRAIN_DATA_PATH} \ 16 | --dev_file ${DEV_DATA_PATH} \ 17 | --max_seq_length 512 \ 18 | --per_gpu_train_batch_size 32 \ 19 | --per_gpu_eval_batch_size 64 \ 20 | --learning_rate 3e-5 \ 21 | --evaluate_during_training \ 22 | --overwrite_cache \ 23 | --save_steps 6000 \ 24 | --num_train_epochs 3 \ 25 | --pred_model_dir ${MODEL_PATH}/checkpoint-best \ 26 | --test_file ${DEV_DATA_PATH} \ 27 | --test_result_dir ${MODEL_PATH}/eval_results.txt \ 28 | -------------------------------------------------------------------------------- /qa_evidence_chain/train_qa_evidence_chain_retrieved.sh: -------------------------------------------------------------------------------- 1 | export RUN_ID=1 2 | export BASIC_DIR=./ODQA 3 | export MODEL_NAME=allenai/longformer-base-4096 4 | export TOKENIZERS_PARALLELISM=false 5 | export TRAIN_DATA_PATH=train_ranked_evidence_chain_for_qa_weighted.json 6 | export DEV_DATA_PATH=dev_ranked_evidence_chain_for_qa_weighted.json 7 | export MODEL_DIR=longformer_base_ecmask_weighted_1e-5 8 | export PREFIX=retrieved_blink_ecmask_ec_top1 9 | python train_final_qa_ori.py \ 10 | --do_train \ 11 | --do_eval \ 12 | --model_type longformer \ 13 | --evaluate_during_training \ 14 | --data_dir ${BASIC_DIR}/data/qa_with_evidence_chain \ 15 | --output_dir ${BASIC_DIR}/model/qa_model/${MODEL_DIR} \ 16 | --train_file ${TRAIN_DATA_PATH} \ 17 | --dev_file ${DEV_DATA_PATH} \ 18 | --per_gpu_train_batch_size 2 \ 19 | --gradient_accumulation_steps 1 \ 20 | --learning_rate 1e-5 \ 21 | --num_train_epochs 10.0 \ 22 | --max_seq_length 4096 \ 23 | --doc_stride 3072 \ 24 | --threads 8 \ 25 | --topk_tbs 15 \ 26 | --model_name_or_path ${MODEL_NAME} \ 27 | --save_cache \ 28 | --prefix ${PREFIX} \ 29 | -------------------------------------------------------------------------------- /evidence_chain/pretrain_data_process/negative_search_by_bm25/create_db.py: -------------------------------------------------------------------------------- 1 | from elastic_search_wanjun import MyElastic 2 | import os 3 | import shutil 4 | import argparse 5 | if __name__ =="__main__": 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--pretrain', action='store_true') 8 | parser.add_argument('--finetune', action='store_true') 9 | 10 | args = parser.parse_args() 11 | ES=MyElastic() 12 | ES.delete_one() 13 | ES.create() 14 | # file_path = '/home/t-wzhong/v-wanzho/ODQA/data/data_wikitable/all_passages.json' 15 | if args.pretrain: 16 | file_path = 
'/home/t-wzhong/table-odqa/Data/evidence_chain/pre-training/evidence_output_pretrain_shortest.json' 17 | res = ES.bulk_insert_all_chains_pretrain(file_path) 18 | if args.finetune: 19 | basic_dir = '/home/t-wzhong/v-wanzho/ODQA/data/preprocessed_data/evidence_chain/ground-truth-based/ground-truth-evidence-chain/' 20 | file_paths = [os.path.join(basic_dir,file) for file in ['train_gt-ec-weighted.json','dev_gt-ec-weighted.json']] 21 | res = ES.bulk_insert_all_chains_finetune(file_paths) 22 | 23 | # res = ES.bulk_insert_all_doc(basic_path,file_path_list) 24 | # print(res) 25 | # print(len(res['hits']['hits'])) 26 | 27 | -------------------------------------------------------------------------------- /evidence_chain/ranking_model/run_evidence_large_eval.sh: -------------------------------------------------------------------------------- 1 | export RUN_ID=5 2 | export BASIC_PATH=./ODQA 3 | export DATA_PATH=${BASIC_PATH}/data/preprocessed_data/evidence_chain/ground-truth-based 4 | export TRAIN_DATA_PATH=train_preprocessed_normalized_gtmodify_evichain_nx.json 5 | export DEV_DATA_PATH=dev_preprocessed_normalized_gtmodify_evichain_nx.json 6 | export TEST_DATA_PATH=train_preprocessed_normalized_gtmodify_candidate_evichain_addnoun.json 7 | export MODEL_PATH=${BASIC_PATH}/model/evidence_chain/roberta-base-ecpretrain 8 | export TEST_DATA_PATH=dev_preprocessed_normalized_gtmodify_candidate_evichain_addnoun.json 9 | #export TEST_DATA_PATH=new_chain_blink_dev_evidence_chain_ranking.json 10 | python run_classifier.py \ 11 | --model_type roberta \ 12 | --model_name_or_path roberta-base \ 13 | --task_name evidence_chain \ 14 | --do_predict \ 15 | --eval_all_checkpoints \ 16 | --data_dir ${DATA_PATH} \ 17 | --output_dir ${MODEL_PATH} \ 18 | --train_file ${TRAIN_DATA_PATH} \ 19 | --dev_file ${DEV_DATA_PATH} \ 20 | --max_seq_length 512 \ 21 | --per_gpu_eval_batch_size 40 \ 22 | --overwrite_cache \ 23 | --test_file ${DATA_PATH}/${TEST_DATA_PATH} \ 24 | --pred_model_dir ${MODEL_PATH}/checkpoint-best \ 25 | --test_result_dir ${DATA_PATH}/dev_ecpretrain_ranker_scores.json -------------------------------------------------------------------------------- /qa_baseline/train_qa_baseline.sh: -------------------------------------------------------------------------------- 1 | export RUN_ID=1 2 | export BASIC_DIR=./ODQA 3 | export MODEL_NAME=allenai/longformer-base-4096 4 | export TOKENIZERS_PARALLELISM=false 5 | export TRAIN_DATA_PATH=${BASIC_DIR}/data/preprocessed_data/qa/train_intable_p1_t360.pkl 6 | export DEV_DATA_PATH=${BASIC_DIR}/data/preprocessed_data/qa/dev_intable_p1_t360.pkl 7 | export MODEL_DIR=qa_model_longformer_normalized_gtmodify_rankdoc_top15_newretr 8 | CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python train_final_qa.py \ 9 | --do_train \ 10 | --do_eval \ 11 | --model_type longformer \ 12 | --evaluate_during_training \ 13 | --data_dir ${BASIC_DIR}/data/preprocessed_data/qa \ 14 | --output_dir ${BASIC_DIR}/model/${MODEL_DIR} \ 15 | --train_file train_preprocessed_normalized_gtmodify_newretr.json \ 16 | --dev_file dev_preprocessed_normalized_gtmodify_newretr.json\ 17 | --per_gpu_train_batch_size 2 \ 18 | --learning_rate 3e-5 \ 19 | --num_train_epochs 10.0 \ 20 | --max_seq_length 4096 \ 21 | --doc_stride 1024 \ 22 | --threads 8 \ 23 | --topk_tbs 15 \ 24 | --model_name_or_path ${MODEL_NAME} \ 25 | --repreprocess \ 26 | --overwrite_cache \ 27 | --prefix gtmodify_rankdoc_normalized_top15_newretr \ 28 | 2>&1 | tee ${BASIC_DIR}/qa_log/longformer-base-${MODEL_DIR}.log 29 | 
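30 | # Note: this baseline fine-tunes allenai/longformer-base-4096 as the QA reader via
31 | # train_final_qa.py over the top-15 retrieved table blocks; BASIC_DIR (./ODQA) and the
32 | # *_newretr.json files are placeholders for the locally preprocessed retrieval output.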
-------------------------------------------------------------------------------- /evidence_chain/ranking_model/run_evidence_large_train.sh: -------------------------------------------------------------------------------- 1 | export RUN_ID=5 2 | export BASIC_PATH=./ODQA 3 | export DATA_PATH=${BASIC_PATH}/data/preprocessed_data/evidence_chain/ground-truth-based 4 | export TRAIN_DATA_PATH=train_preprocessed_normalized_gtmodify_evichain_addnoun.json 5 | export DEV_DATA_PATH=dev_preprocessed_normalized_gtmodify_evichain_addnoun.json 6 | export MODEL_PATH=${BASIC_PATH}/model/evidence_chain/roberta-base-ranking-pretrain 7 | export PRETRAIN_MODEL_PATH=${BASIC_PATH}/model/evidence_chain/roberta-large-pretrain-ec-ranker/checkpoint-best 8 | python run_classifier.py \ 9 | --model_type roberta \ 10 | --tokenizer_name roberta-base \ 11 | --config_name roberta-base \ 12 | --model_name_or_path roberta-base \ 13 | --task_name evidence_chain \ 14 | --do_train \ 15 | --do_eval \ 16 | --eval_all_checkpoints \ 17 | --data_dir ${DATA_PATH} \ 18 | --output_dir ${MODEL_PATH} \ 19 | --train_file ${TRAIN_DATA_PATH} \ 20 | --dev_file ${DEV_DATA_PATH} \ 21 | --max_seq_length 512 \ 22 | --per_gpu_train_batch_size 16 \ 23 | --per_gpu_eval_batch_size 16 \ 24 | --learning_rate 1e-5 \ 25 | --num_train_epochs 10 \ 26 | --load_save_pretrain \ 27 | --pretrain_model_dir ${PRETRAIN_MODEL_PATH} \ 28 | 29 | -------------------------------------------------------------------------------- /evidence_chain/pretrain/run_evidence_base_train.sh: -------------------------------------------------------------------------------- 1 | export RUN_ID=5 2 | export BASIC_PATH=/wanjun/ODQA 3 | export BASIC_PATH=/home/t-wzhong/v-wanzho/ODQA 4 | export DATA_PATH=${BASIC_PATH}/data/preprocessed_data/evidence_chain/pre-training/fake_question_pretraining 5 | export TRAIN_DATA_PATH=fake_question_pretraining_train.jsonl 6 | export DEV_DATA_PATH=fake_question_pretraining_dev.jsonl 7 | export MODEL_PATH=${BASIC_PATH}/model/evidence_chain/roberta-base-pretrain-fake 8 | CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python run_classifier.py \ 9 | --model_type roberta \ 10 | --model_name_or_path roberta-base \ 11 | --task_name evidence_chain \ 12 | --overwrite_cache \ 13 | --do_train \ 14 | --do_eval \ 15 | --eval_all_checkpoints \ 16 | --data_dir ${DATA_PATH} \ 17 | --output_dir ${MODEL_PATH} \ 18 | --train_file ${TRAIN_DATA_PATH} \ 19 | --dev_file ${DEV_DATA_PATH} \ 20 | --max_seq_length 512 \ 21 | --per_gpu_train_batch_size 16 \ 22 | --per_gpu_eval_batch_size 16 \ 23 | --learning_rate 1e-5 \ 24 | --num_train_epochs 10 25 | #--test_file /home/dutang/FEVER/arranged_data/bert_data/Evidence/eval_file/train_all_sentence_evidence_title_all.tsv \ 26 | #--pred_model_dir /home/dutang/FEVER/arranged_models/xlnet_torch_models/evidence_large_title/checkpoint-best \ 27 | #--test_result_dir /home/dutang/FEVER/arranged_result/xlnet/evidence/train_evidence_xlnet_large_title_score.tsv 28 | -------------------------------------------------------------------------------- /evidence_chain/fine-tune/run_evidence_large_eval.sh: -------------------------------------------------------------------------------- 1 | export RUN_ID=5 2 | export BASIC_PATH=./ODQA 3 | #export DATA_PATH=${BASIC_PATH}/data/preprocessed_data/evidence_chain/ground-truth-based/candidate_chain 4 | export DATA_PATH=./ODQA/data/evidence_chain_data/retrieval-based/candidate_chain 5 | export TRAIN_DATA_PATH=train_preprocessed_normalized_gtmodify_evichain_nx.json 6 | export 
DEV_DATA_PATH=dev_preprocessed_normalized_gtmodify_evichain_nx.json 7 | export TEST_DATA_PATH=dev_gttb_candidate_ec.jsonl 8 | export MODEL_PATH=${BASIC_PATH}/model/evidence_chain/roberta-large-weighted-multineg 9 | for part in 2 3 10 | do 11 | export TEST_DATA_PATH=train_evidence_chain_weighted_ranking_${part}.json 12 | CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python run_classifier.py \ 13 | --model_type roberta \ 14 | --model_name_or_path roberta-large \ 15 | --task_name evidence_chain \ 16 | --do_predict \ 17 | --eval_all_checkpoints \ 18 | --data_dir ${DATA_PATH} \ 19 | --output_dir ${MODEL_PATH} \ 20 | --train_file ${TRAIN_DATA_PATH} \ 21 | --dev_file ${DEV_DATA_PATH} \ 22 | --max_seq_length 512 \ 23 | --per_gpu_eval_batch_size 20 \ 24 | --overwrite_cache \ 25 | --test_file ${DATA_PATH}/${TEST_DATA_PATH} \ 26 | --pred_model_dir ${MODEL_PATH}/checkpoint-best \ 27 | --test_result_dir ${DATA_PATH}/../scored_chains/weighted-large-5neg/train_evidence_chain_weighted_scores_${part}.json 28 | sleep 30 29 | done 30 | -------------------------------------------------------------------------------- /retriever/retrieval/drqa/drqa_tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | 10 | DEFAULTS = { 11 | 'corenlp_classpath': os.getenv('CLASSPATH') 12 | } 13 | 14 | 15 | def set_default(key, value): 16 | global DEFAULTS 17 | DEFAULTS[key] = value 18 | 19 | 20 | from .corenlp_tokenizer import CoreNLPTokenizer 21 | from .regexp_tokenizer import RegexpTokenizer 22 | from .simple_tokenizer import SimpleTokenizer 23 | 24 | # Spacy is optional 25 | try: 26 | from .spacy_tokenizer import SpacyTokenizer 27 | except ImportError: 28 | pass 29 | 30 | 31 | def get_class(name): 32 | if name == 'spacy': 33 | return SpacyTokenizer 34 | if name == 'corenlp': 35 | return CoreNLPTokenizer 36 | if name == 'regexp': 37 | return RegexpTokenizer 38 | if name == 'simple': 39 | return SimpleTokenizer 40 | 41 | raise RuntimeError('Invalid tokenizer: %s' % name) 42 | 43 | 44 | def get_annotators_for_args(args): 45 | annotators = set() 46 | if args.use_pos: 47 | annotators.add('pos') 48 | if args.use_lemma: 49 | annotators.add('lemma') 50 | if args.use_ner: 51 | annotators.add('ner') 52 | return annotators 53 | 54 | 55 | def get_annotators_for_model(model): 56 | return get_annotators_for_args(model.args) 57 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/ExternalAttention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | 8 | class ExternalAttention(nn.Module): 9 | 10 | def __init__(self, d_model,S=64): 11 | super().__init__() 12 | self.mk=nn.Linear(d_model,S,bias=False) 13 | self.mv=nn.Linear(S,d_model,bias=False) 14 | self.softmax=nn.Softmax(dim=1) 15 | self.init_weights() 16 | 17 | 18 | def init_weights(self): 19 | for m in self.modules(): 20 | if isinstance(m, nn.Conv2d): 21 | init.kaiming_normal_(m.weight, mode='fan_out') 22 | if m.bias is not None: 23 | init.constant_(m.bias, 0) 24 | elif isinstance(m, nn.BatchNorm2d): 25 | init.constant_(m.weight, 1) 26 | init.constant_(m.bias, 0) 27 | elif isinstance(m, nn.Linear): 
28 | init.normal_(m.weight, std=0.001) 29 | if m.bias is not None: 30 | init.constant_(m.bias, 0) 31 | 32 | def forward(self, queries): 33 | attn=self.mk(queries) #bs,n,S 34 | attn=self.softmax(attn) #bs,n,S 35 | attn=attn/torch.sum(attn,dim=2,keepdim=True) #bs,n,S 36 | out=self.mv(attn) #bs,n,d_model 37 | 38 | return out 39 | 40 | 41 | if __name__ == '__main__': 42 | input=torch.randn(50,49,512) 43 | ea = ExternalAttention(d_model=512,S=8) 44 | output=ea(input) 45 | print(output.shape) 46 | 47 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/ECAAttention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | from collections import OrderedDict 6 | 7 | 8 | 9 | class ECAAttention(nn.Module): 10 | 11 | def __init__(self, kernel_size=3): 12 | super().__init__() 13 | self.gap=nn.AdaptiveAvgPool2d(1) 14 | self.conv=nn.Conv1d(1,1,kernel_size=kernel_size,padding=(kernel_size-1)//2) 15 | self.sigmoid=nn.Sigmoid() 16 | 17 | def init_weights(self): 18 | for m in self.modules(): 19 | if isinstance(m, nn.Conv2d): 20 | init.kaiming_normal_(m.weight, mode='fan_out') 21 | if m.bias is not None: 22 | init.constant_(m.bias, 0) 23 | elif isinstance(m, nn.BatchNorm2d): 24 | init.constant_(m.weight, 1) 25 | init.constant_(m.bias, 0) 26 | elif isinstance(m, nn.Linear): 27 | init.normal_(m.weight, std=0.001) 28 | if m.bias is not None: 29 | init.constant_(m.bias, 0) 30 | 31 | def forward(self, x): 32 | y=self.gap(x) #bs,c,1,1 33 | y=y.squeeze(-1).permute(0,2,1) #bs,1,c 34 | y=self.conv(y) #bs,1,c 35 | y=self.sigmoid(y) #bs,1,c 36 | y=y.permute(0,2,1).unsqueeze(-1) #bs,c,1,1 37 | return x*y.expand_as(x) 38 | 39 | 40 | 41 | 42 | 43 | 44 | if __name__ == '__main__': 45 | input=torch.randn(50,512,7,7) 46 | eca = ECAAttention(kernel_size=3) 47 | output=eca(input) 48 | print(output.shape) 49 | 50 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/SEAttention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | 8 | class SEAttention(nn.Module): 9 | 10 | def __init__(self, channel=512,reduction=16): 11 | super().__init__() 12 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 13 | self.fc = nn.Sequential( 14 | nn.Linear(channel, channel // reduction, bias=False), 15 | nn.ReLU(inplace=True), 16 | nn.Linear(channel // reduction, channel, bias=False), 17 | nn.Sigmoid() 18 | ) 19 | 20 | 21 | def init_weights(self): 22 | for m in self.modules(): 23 | if isinstance(m, nn.Conv2d): 24 | init.kaiming_normal_(m.weight, mode='fan_out') 25 | if m.bias is not None: 26 | init.constant_(m.bias, 0) 27 | elif isinstance(m, nn.BatchNorm2d): 28 | init.constant_(m.weight, 1) 29 | init.constant_(m.bias, 0) 30 | elif isinstance(m, nn.Linear): 31 | init.normal_(m.weight, std=0.001) 32 | if m.bias is not None: 33 | init.constant_(m.bias, 0) 34 | 35 | def forward(self, x): 36 | b, c, _, _ = x.size() 37 | y = self.avg_pool(x).view(b, c) 38 | y = self.fc(y).view(b, c, 1, 1) 39 | return x * y.expand_as(x) 40 | 41 | 42 | if __name__ == '__main__': 43 | input=torch.randn(50,512,7,7) 44 | se = SEAttention(channel=512,reduction=8) 45 | output=se(input) 46 | print(output.shape) 47 | 48 | -------------------------------------------------------------------------------- 
/retriever/retrieval/drqa/retriever/doc_db.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Documents, in a sqlite database.""" 8 | 9 | import sqlite3 10 | from . import utils 11 | 12 | 13 | class DocDB(object): 14 | """Sqlite backed document storage. 15 | 16 | Implements get_doc_text(doc_id). 17 | """ 18 | 19 | def __init__(self, db_path): 20 | self.path = db_path 21 | self.connection = sqlite3.connect(self.path, check_same_thread=False) 22 | 23 | def __enter__(self): 24 | return self 25 | 26 | def __exit__(self, *args): 27 | self.close() 28 | 29 | def path(self): 30 | """Return the path to the file that backs this database.""" 31 | return self.path 32 | 33 | def close(self): 34 | """Close the connection to the database.""" 35 | self.connection.close() 36 | 37 | def get_doc_ids(self): 38 | """Fetch all ids of docs stored in the db.""" 39 | cursor = self.connection.cursor() 40 | cursor.execute("SELECT id FROM documents") 41 | results = [r[0] for r in cursor.fetchall()] 42 | cursor.close() 43 | return results 44 | 45 | def get_doc_text(self, doc_id): 46 | """Fetch the raw text of the doc for 'doc_id'.""" 47 | cursor = self.connection.cursor() 48 | cursor.execute( 49 | "SELECT text FROM documents WHERE id = ?", 50 | #(utils.normalize(doc_id),) 51 | (doc_id, ) 52 | ) 53 | result = cursor.fetchone() 54 | cursor.close() 55 | return result if result is None else result[0] 56 | -------------------------------------------------------------------------------- /evidence_chain/fine-tune/run_evidence_train_from_pretrain.sh: -------------------------------------------------------------------------------- 1 | export RUN_ID=5 2 | export BASIC_PATH=./ODQA 3 | export DATA_PATH=${BASIC_PATH}/evidence_chain_data/ground-truth-based/ground-truth-evidence-chain 4 | export TRAIN_DATA_PATH=train_gt-ec-weighted.json 5 | export DEV_DATA_PATH=dev_gt-ec-weighted.json 6 | export PRETRAIN_MODEL_PATH=${BASIC_PATH}/model/evidence_chain/roberta-large-new/checkpoint-best 7 | export STEP=${step} 8 | export PREFIX=southes-nonpretrain 9 | export MODEL_PATH=${BASIC_PATH}/model/evidence_chain/ft_pretrained/roberta-base-${PREFIX} 10 | export PRETRAIN_MODEL_PATH=${BASIC_PATH}/model/evidence_chain/roberta-base-pretrain-3neg-all-innerneg/checkpoint-${STEP} 11 | 12 | 13 | python run_classifier.py \ 14 | --model_type roberta \ 15 | --tokenizer_name roberta-base \ 16 | --config_name roberta-base \ 17 | --model_name_or_path roberta-base \ 18 | --task_name evidence_chain \ 19 | --do_train \ 20 | --do_eval \ 21 | --eval_all_checkpoints \ 22 | --data_dir ${DATA_PATH} \ 23 | --output_dir ${MODEL_PATH} \ 24 | --train_file ${TRAIN_DATA_PATH} \ 25 | --dev_file ${DEV_DATA_PATH} \ 26 | --max_seq_length 512 \ 27 | --per_gpu_train_batch_size 16 \ 28 | --per_gpu_eval_batch_size 16 \ 29 | --learning_rate 1e-5 \ 30 | --num_train_epochs 1 \ 31 | 32 | export TEST_DATA_PATH=dev_gttb_candidate_ec_weighted.json 33 | 34 | python run_classifier.py \ 35 | --model_type roberta \ 36 | --model_name_or_path roberta-base \ 37 | --task_name evidence_chain \ 38 | --do_predict \ 39 | --eval_all_checkpoints \ 40 | --data_dir ${DATA_PATH} \ 41 | --output_dir ${MODEL_PATH} \ 42 | --max_seq_length 512 \ 43 | --per_gpu_eval_batch_size 40 \ 44 | --test_file 
${DATA_PATH}/../candidate_chain/${TEST_DATA_PATH} \ 45 | --pred_model_dir ${MODEL_PATH}/checkpoint-best \ 46 | --test_result_dir ${DATA_PATH}/../scored_chain/pretrained/dev_roberta_base_scored_ec_addtab_weighted_${PREFIX}.json 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /retriever/retrieval/drqa/drqa_tokenizers/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Basic tokenizer that splits text into alpha-numeric tokens and 8 | non-whitespace tokens. 9 | """ 10 | 11 | import regex 12 | import logging 13 | from .tokenizer import Tokens, Tokenizer 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class SimpleTokenizer(Tokenizer): 19 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 20 | NON_WS = r'[^\p{Z}\p{C}]' 21 | 22 | def __init__(self, **kwargs): 23 | """ 24 | Args: 25 | annotators: None or empty set (only tokenizes). 26 | """ 27 | self._regexp = regex.compile( 28 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 29 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 30 | ) 31 | if len(kwargs.get('annotators', {})) > 0: 32 | logger.warning('%s only tokenizes! Skipping annotators: %s' % 33 | (type(self).__name__, kwargs.get('annotators'))) 34 | self.annotators = set() 35 | 36 | def tokenize(self, text): 37 | data = [] 38 | matches = [m for m in self._regexp.finditer(text)] 39 | for i in range(len(matches)): 40 | # Get text 41 | token = matches[i].group() 42 | 43 | # Get whitespace 44 | span = matches[i].span() 45 | start_ws = span[0] 46 | if i + 1 < len(matches): 47 | end_ws = matches[i + 1].span()[0] 48 | else: 49 | end_ws = span[1] 50 | 51 | # Format data 52 | data.append(( 53 | token, 54 | text[start_ws: end_ws], 55 | span, 56 | )) 57 | return Tokens(data, self.annotators) 58 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/SGE.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | 8 | class SpatialGroupEnhance(nn.Module): 9 | 10 | def __init__(self, groups): 11 | super().__init__() 12 | self.groups=groups 13 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 14 | self.weight=nn.Parameter(torch.zeros(1,groups,1,1)) 15 | self.bias=nn.Parameter(torch.zeros(1,groups,1,1)) 16 | self.sig=nn.Sigmoid() 17 | self.init_weights() 18 | 19 | 20 | def init_weights(self): 21 | for m in self.modules(): 22 | if isinstance(m, nn.Conv2d): 23 | init.kaiming_normal_(m.weight, mode='fan_out') 24 | if m.bias is not None: 25 | init.constant_(m.bias, 0) 26 | elif isinstance(m, nn.BatchNorm2d): 27 | init.constant_(m.weight, 1) 28 | init.constant_(m.bias, 0) 29 | elif isinstance(m, nn.Linear): 30 | init.normal_(m.weight, std=0.001) 31 | if m.bias is not None: 32 | init.constant_(m.bias, 0) 33 | 34 | def forward(self, x): 35 | b, c, h,w=x.shape 36 | x=x.view(b*self.groups,-1,h,w) #bs*g,dim//g,h,w 37 | xn=x*self.avg_pool(x) #bs*g,dim//g,h,w 38 | xn=xn.sum(dim=1,keepdim=True) #bs*g,1,h,w 39 | t=xn.view(b*self.groups,-1) #bs*g,h*w 40 | 41 | t=t-t.mean(dim=1,keepdim=True) #bs*g,h*w 42 | std=t.std(dim=1,keepdim=True)+1e-5 43 | t=t/std #bs*g,h*w 44 | t=t.view(b,self.groups,h,w) 
#bs,g,h*w 45 | 46 | t=t*self.weight+self.bias #bs,g,h*w 47 | t=t.view(b*self.groups,1,h,w) #bs*g,1,h*w 48 | x=x*self.sig(t) 49 | x=x.view(b,c,h,w) 50 | 51 | return x 52 | 53 | 54 | if __name__ == '__main__': 55 | input=torch.randn(50,512,7,7) 56 | sge = SpatialGroupEnhance(groups=8) 57 | output=sge(input) 58 | print(output.shape) 59 | 60 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/DANet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | from .SelfAttention import ScaledDotProductAttention 6 | from .SimplifiedSelfAttention import SimplifiedScaledDotProductAttention 7 | 8 | class PositionAttentionModule(nn.Module): 9 | 10 | def __init__(self,d_model=512,kernel_size=3,H=7,W=7): 11 | super().__init__() 12 | self.cnn=nn.Conv2d(d_model,d_model,kernel_size=kernel_size,padding=(kernel_size-1)//2) 13 | self.pa=ScaledDotProductAttention(d_model,d_k=d_model,d_v=d_model,h=1) 14 | 15 | def forward(self,x): 16 | bs,c,h,w=x.shape 17 | y=self.cnn(x) 18 | y=y.view(bs,c,-1).permute(0,2,1) #bs,h*w,c 19 | y=self.pa(y,y,y) #bs,h*w,c 20 | return y 21 | 22 | 23 | class ChannelAttentionModule(nn.Module): 24 | 25 | def __init__(self,d_model=512,kernel_size=3,H=7,W=7): 26 | super().__init__() 27 | self.cnn=nn.Conv2d(d_model,d_model,kernel_size=kernel_size,padding=(kernel_size-1)//2) 28 | self.pa=SimplifiedScaledDotProductAttention(H*W,h=1) 29 | 30 | def forward(self,x): 31 | bs,c,h,w=x.shape 32 | y=self.cnn(x) 33 | y=y.view(bs,c,-1) #bs,c,h*w 34 | y=self.pa(y,y,y) #bs,c,h*w 35 | return y 36 | 37 | 38 | 39 | 40 | class DAModule(nn.Module): 41 | 42 | def __init__(self,d_model=512,kernel_size=3,H=7,W=7): 43 | super().__init__() 44 | self.position_attention_module=PositionAttentionModule(d_model=512,kernel_size=3,H=7,W=7) 45 | self.channel_attention_module=ChannelAttentionModule(d_model=512,kernel_size=3,H=7,W=7) 46 | 47 | def forward(self,input): 48 | bs,c,h,w=input.shape 49 | p_out=self.position_attention_module(input) 50 | c_out=self.channel_attention_module(input) 51 | p_out=p_out.permute(0,2,1).view(bs,c,h,w) 52 | c_out=c_out.view(bs,c,h,w) 53 | return p_out+c_out 54 | 55 | 56 | if __name__ == '__main__': 57 | input=torch.randn(50,512,7,7) 58 | danet=DAModule(d_model=512,kernel_size=3,H=7,W=7) 59 | print(danet(input).shape) 60 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/ViP.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class MLP(nn.Module): 6 | def __init__(self,in_features,hidden_features,out_features,act_layer=nn.GELU,drop=0.1): 7 | super().__init__() 8 | self.fc1=nn.Linear(in_features,hidden_features) 9 | self.act=act_layer() 10 | self.fc2=nn.Linear(hidden_features,out_features) 11 | self.drop=nn.Dropout(drop) 12 | 13 | def forward(self, x) : 14 | return self.drop(self.fc2(self.drop(self.act(self.fc1(x))))) 15 | 16 | class WeightedPermuteMLP(nn.Module): 17 | def __init__(self,dim,seg_dim=8, qkv_bias=False, proj_drop=0.): 18 | super().__init__() 19 | self.seg_dim=seg_dim 20 | 21 | self.mlp_c=nn.Linear(dim,dim,bias=qkv_bias) 22 | self.mlp_h=nn.Linear(dim,dim,bias=qkv_bias) 23 | self.mlp_w=nn.Linear(dim,dim,bias=qkv_bias) 24 | 25 | self.reweighting=MLP(dim,dim//4,dim*3) 26 | 27 | self.proj=nn.Linear(dim,dim) 28 | self.proj_drop=nn.Dropout(proj_drop) 29 | 
30 | def forward(self,x) : 31 | B,H,W,C=x.shape 32 | 33 | c_embed=self.mlp_c(x) 34 | 35 | S=C//self.seg_dim 36 | h_embed=x.reshape(B,H,W,self.seg_dim,S).permute(0,3,2,1,4).reshape(B,self.seg_dim,W,H*S) 37 | h_embed=self.mlp_h(h_embed).reshape(B,self.seg_dim,W,H,S).permute(0,3,2,1,4).reshape(B,H,W,C) 38 | 39 | w_embed=x.reshape(B,H,W,self.seg_dim,S).permute(0,3,1,2,4).reshape(B,self.seg_dim,H,W*S) 40 | w_embed=self.mlp_w(w_embed).reshape(B,self.seg_dim,H,W,S).permute(0,2,3,1,4).reshape(B,H,W,C) 41 | 42 | weight=(c_embed+h_embed+w_embed).permute(0,3,1,2).flatten(2).mean(2) 43 | weight=self.reweighting(weight).reshape(B,C,3).permute(2,0,1).softmax(0).unsqueeze(2).unsqueeze(2) 44 | 45 | x=h_embed*weight[0]+w_embed*weight[1]+h_embed*weight[2] 46 | 47 | x=self.proj_drop(self.proj(x)) 48 | 49 | return x 50 | 51 | 52 | 53 | if __name__ == '__main__': 54 | input=torch.randn(64,8,8,512) 55 | seg_dim=8 56 | vip=WeightedPermuteMLP(512,seg_dim) 57 | out=vip(input) 58 | print(out.shape) 59 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/AFT.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | 8 | class AFT_FULL(nn.Module): 9 | 10 | def __init__(self, d_model,n=49,simple=False): 11 | 12 | super(AFT_FULL, self).__init__() 13 | self.fc_q = nn.Linear(d_model, d_model) 14 | self.fc_k = nn.Linear(d_model, d_model) 15 | self.fc_v = nn.Linear(d_model,d_model) 16 | if(simple): 17 | self.position_biases=torch.zeros((n,n)) 18 | else: 19 | self.position_biases=nn.Parameter(torch.ones((n,n))) 20 | self.d_model = d_model 21 | self.n=n 22 | self.sigmoid=nn.Sigmoid() 23 | 24 | self.init_weights() 25 | 26 | 27 | def init_weights(self): 28 | for m in self.modules(): 29 | if isinstance(m, nn.Conv2d): 30 | init.kaiming_normal_(m.weight, mode='fan_out') 31 | if m.bias is not None: 32 | init.constant_(m.bias, 0) 33 | elif isinstance(m, nn.BatchNorm2d): 34 | init.constant_(m.weight, 1) 35 | init.constant_(m.bias, 0) 36 | elif isinstance(m, nn.Linear): 37 | init.normal_(m.weight, std=0.001) 38 | if m.bias is not None: 39 | init.constant_(m.bias, 0) 40 | 41 | def forward(self, input): 42 | 43 | bs, n,dim = input.shape 44 | 45 | q = self.fc_q(input) #bs,n,dim 46 | k = self.fc_k(input).view(1,bs,n,dim) #1,bs,n,dim 47 | v = self.fc_v(input).view(1,bs,n,dim) #1,bs,n,dim 48 | 49 | numerator=torch.sum(torch.exp(k+self.position_biases.view(n,1,-1,1))*v,dim=2) #n,bs,dim 50 | denominator=torch.sum(torch.exp(k+self.position_biases.view(n,1,-1,1)),dim=2) #n,bs,dim 51 | 52 | out=(numerator/denominator) #n,bs,dim 53 | out=self.sigmoid(q)*(out.permute(1,0,2)) #bs,n,dim 54 | 55 | return out 56 | 57 | 58 | if __name__ == '__main__': 59 | input=torch.randn(50,49,512) 60 | aft_full = AFT_FULL(d_model=512, n=49) 61 | output=aft_full(input) 62 | print(output.shape) 63 | 64 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/A2Atttention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | from torch.nn import functional as F 6 | 7 | 8 | 9 | class DoubleAttention(nn.Module): 10 | 11 | def __init__(self, in_channels,c_m,c_n,reconstruct = True): 12 | super().__init__() 13 | self.in_channels=in_channels 14 | self.reconstruct = reconstruct 15 | self.c_m=c_m 16 | 
self.c_n=c_n 17 | self.convA=nn.Conv2d(in_channels,c_m,1) 18 | self.convB=nn.Conv2d(in_channels,c_n,1) 19 | self.convV=nn.Conv2d(in_channels,c_n,1) 20 | if self.reconstruct: 21 | self.conv_reconstruct = nn.Conv2d(c_m, in_channels, kernel_size = 1) 22 | self.init_weights() 23 | 24 | 25 | def init_weights(self): 26 | for m in self.modules(): 27 | if isinstance(m, nn.Conv2d): 28 | init.kaiming_normal_(m.weight, mode='fan_out') 29 | if m.bias is not None: 30 | init.constant_(m.bias, 0) 31 | elif isinstance(m, nn.BatchNorm2d): 32 | init.constant_(m.weight, 1) 33 | init.constant_(m.bias, 0) 34 | elif isinstance(m, nn.Linear): 35 | init.normal_(m.weight, std=0.001) 36 | if m.bias is not None: 37 | init.constant_(m.bias, 0) 38 | 39 | def forward(self, x): 40 | b, c, h,w=x.shape 41 | assert c==self.in_channels 42 | A=self.convA(x) #b,c_m,h,w 43 | B=self.convB(x) #b,c_n,h,w 44 | V=self.convV(x) #b,c_n,h,w 45 | tmpA=A.view(b,self.c_m,-1) 46 | attention_maps=F.softmax(B.view(b,self.c_n,-1)) 47 | attention_vectors=F.softmax(V.view(b,self.c_n,-1)) 48 | # step 1: feature gating 49 | global_descriptors=torch.bmm(tmpA,attention_maps.permute(0,2,1)) #b.c_m,c_n 50 | # step 2: feature distribution 51 | tmpZ = global_descriptors.matmul(attention_vectors) #b,c_m,h*w 52 | tmpZ=tmpZ.view(b,self.c_m,h,w) #b,c_m,h,w 53 | if self.reconstruct: 54 | tmpZ=self.conv_reconstruct(tmpZ) 55 | 56 | return tmpZ 57 | 58 | 59 | if __name__ == '__main__': 60 | input=torch.randn(50,512,7,7) 61 | a2 = DoubleAttention(512,128,128,True) 62 | output=a2(input) 63 | print(output.shape) 64 | 65 | -------------------------------------------------------------------------------- /retriever/retrieval/drqa/drqa_tokenizers/spacy_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Tokenizer that is backed by spaCy (spacy.io). 8 | 9 | Requires spaCy package and the spaCy english model. 10 | """ 11 | 12 | import spacy 13 | import copy 14 | from .tokenizer import Tokens, Tokenizer 15 | 16 | 17 | class SpacyTokenizer(Tokenizer): 18 | 19 | def __init__(self, **kwargs): 20 | """ 21 | Args: 22 | annotators: set that can include pos, lemma, and ner. 23 | model: spaCy model to use (either path, or keyword like 'en'). 24 | """ 25 | model = kwargs.get('model', 'en') 26 | self.annotators = copy.deepcopy(kwargs.get('annotators', set())) 27 | nlp_kwargs = {'parser': False} 28 | if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]): 29 | nlp_kwargs['tagger'] = False 30 | if 'ner' not in self.annotators: 31 | nlp_kwargs['entity'] = False 32 | self.nlp = spacy.load(model, **nlp_kwargs) 33 | 34 | def tokenize(self, text): 35 | # We don't treat new lines as tokens. 
36 | clean_text = text.replace('\n', ' ') 37 | tokens = self.nlp.tokenizer(clean_text) 38 | if any([p in self.annotators for p in ['lemma', 'pos', 'ner']]): 39 | self.nlp.tagger(tokens) 40 | if 'ner' in self.annotators: 41 | self.nlp.entity(tokens) 42 | 43 | data = [] 44 | for i in range(len(tokens)): 45 | # Get whitespace 46 | start_ws = tokens[i].idx 47 | if i + 1 < len(tokens): 48 | end_ws = tokens[i + 1].idx 49 | else: 50 | end_ws = tokens[i].idx + len(tokens[i].text) 51 | 52 | data.append(( 53 | tokens[i].text, 54 | text[start_ws: end_ws], 55 | (tokens[i].idx, tokens[i].idx + len(tokens[i].text)), 56 | tokens[i].tag_, 57 | tokens[i].lemma_, 58 | tokens[i].ent_type_, 59 | )) 60 | 61 | # Set special option for non-entity tag: '' vs 'O' in spaCy 62 | return Tokens(data, self.annotators, opts={'non_ent': ''}) 63 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/PSA.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | 8 | class PSA(nn.Module): 9 | 10 | def __init__(self, channel=512,reduction=4,S=4): 11 | super().__init__() 12 | self.S=S 13 | 14 | self.convs=[] 15 | for i in range(S): 16 | self.convs.append(nn.Conv2d(channel//S,channel//S,kernel_size=2*(i+1)+1,padding=i+1)) 17 | 18 | self.se_blocks=[] 19 | for i in range(S): 20 | self.se_blocks.append(nn.Sequential( 21 | nn.AdaptiveAvgPool2d(1), 22 | nn.Conv2d(channel//S, channel // (S*reduction),kernel_size=1, bias=False), 23 | nn.ReLU(inplace=True), 24 | nn.Conv2d(channel // (S*reduction), channel//S,kernel_size=1, bias=False), 25 | nn.Sigmoid() 26 | )) 27 | 28 | self.softmax=nn.Softmax(dim=1) 29 | 30 | 31 | def init_weights(self): 32 | for m in self.modules(): 33 | if isinstance(m, nn.Conv2d): 34 | init.kaiming_normal_(m.weight, mode='fan_out') 35 | if m.bias is not None: 36 | init.constant_(m.bias, 0) 37 | elif isinstance(m, nn.BatchNorm2d): 38 | init.constant_(m.weight, 1) 39 | init.constant_(m.bias, 0) 40 | elif isinstance(m, nn.Linear): 41 | init.normal_(m.weight, std=0.001) 42 | if m.bias is not None: 43 | init.constant_(m.bias, 0) 44 | 45 | def forward(self, x): 46 | b, c, h, w = x.size() 47 | 48 | #Step1:SPC module 49 | SPC_out=x.view(b,self.S,c//self.S,h,w) #bs,s,ci,h,w 50 | for idx,conv in enumerate(self.convs): 51 | SPC_out[:,idx,:,:,:]=conv(SPC_out[:,idx,:,:,:]) 52 | 53 | #Step2:SE weight 54 | SE_out=torch.zeros_like(SPC_out) 55 | for idx,se in enumerate(self.se_blocks): 56 | SE_out[:,idx,:,:,:]=se(SPC_out[:,idx,:,:,:]) 57 | 58 | #Step3:Softmax 59 | softmax_out=self.softmax(SE_out) 60 | 61 | #Step4:SPA 62 | PSA_out=SPC_out*softmax_out 63 | PSA_out=PSA_out.view(b,-1,h,w) 64 | 65 | return PSA_out 66 | 67 | 68 | if __name__ == '__main__': 69 | input=torch.randn(50,512,7,7) 70 | psa = PSA(channel=512,reduction=8) 71 | output=psa(input) 72 | print(output.shape) 73 | 74 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/OutlookAttention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | import math 6 | from torch.nn import functional as F 7 | 8 | class OutlookAttention(nn.Module): 9 | 10 | def __init__(self,dim,num_heads=1,kernel_size=3,padding=1,stride=1,qkv_bias=False, 11 | attn_drop=0.1): 12 | super().__init__() 13 | self.dim=dim 14 | 
14 |         self.num_heads=num_heads
15 |         self.head_dim=dim//num_heads
16 |         self.kernel_size=kernel_size
17 |         self.padding=padding
18 |         self.stride=stride
19 |         self.scale=self.head_dim**(-0.5)
20 | 
21 |         self.v_pj=nn.Linear(dim,dim,bias=qkv_bias)
22 |         self.attn=nn.Linear(dim,kernel_size**4*num_heads)
23 | 
24 |         self.attn_drop=nn.Dropout(attn_drop)
25 |         self.proj=nn.Linear(dim,dim)
26 |         self.proj_drop=nn.Dropout(attn_drop)
27 | 
28 |         self.unfold=nn.Unfold(kernel_size,padding=padding,stride=stride) # im2col-style sliding-window extraction ("manual convolution"); keyword args, since nn.Unfold's second positional argument is dilation
29 |         self.pool=nn.AvgPool2d(kernel_size=stride,stride=stride,ceil_mode=True)
30 | 
31 |     def forward(self, x) :
32 |         B,H,W,C=x.shape
33 | 
34 |         # project the input to the value feature v
35 |         v=self.v_pj(x).permute(0,3,1,2) #B,C,H,W
36 |         h,w=math.ceil(H/self.stride),math.ceil(W/self.stride)
37 |         v=self.unfold(v).reshape(B,self.num_heads,self.head_dim,self.kernel_size*self.kernel_size,h*w).permute(0,1,4,3,2) #B,num_head,H*W,kxk,head_dim
38 | 
39 |         # generate the attention map
40 |         attn=self.pool(x.permute(0,3,1,2)).permute(0,2,3,1) #B,H,W,C
41 |         attn=self.attn(attn).reshape(B,h*w,self.num_heads,self.kernel_size*self.kernel_size \
42 |                                      ,self.kernel_size*self.kernel_size).permute(0,2,1,3,4) #B,num_head,H*W,kxk,kxk
43 |         attn=self.scale*attn
44 |         attn=attn.softmax(-1)
45 |         attn=self.attn_drop(attn)
46 | 
47 |         # gather the weighted features
48 |         out=(attn @ v).permute(0,1,4,3,2).reshape(B,C*self.kernel_size*self.kernel_size,h*w) #B,dimxkxk,H*W
49 |         out=F.fold(out,output_size=(H,W),kernel_size=self.kernel_size,
50 |                    padding=self.padding,stride=self.stride) #B,C,H,W
51 |         out=self.proj(out.permute(0,2,3,1)) #B,H,W,C
52 |         out=self.proj_drop(out)
53 | 
54 |         return out
55 | 
56 | 
57 | 
58 | if __name__ == '__main__':
59 |     input=torch.randn(50,28,28,512)
60 |     outlook = OutlookAttention(dim=512)
61 |     output=outlook(input)
62 |     print(output.shape)
63 | 
64 | 
65 | 
66 | 
67 | 
--------------------------------------------------------------------------------
/retriever/retrieval/attention/SKAttention.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from torch.nn import init
5 | from collections import OrderedDict
6 | 
7 | 
8 | 
9 | class SKAttention(nn.Module):
10 | 
11 |     def __init__(self, channel=512,kernels=[1,3,5,7],reduction=16,group=1,L=32):
12 |         super().__init__()
13 |         self.d=max(L,channel//reduction)
14 |         self.convs=nn.ModuleList([])
15 |         for k in kernels:
16 |             self.convs.append(
17 |                 nn.Sequential(OrderedDict([
18 |                     ('conv',nn.Conv2d(channel,channel,kernel_size=k,padding=k//2,groups=group)),
19 |                     ('bn',nn.BatchNorm2d(channel)),
20 |                     ('relu',nn.ReLU())
21 |                 ]))
22 |             )
23 |         self.fc=nn.Linear(channel,self.d)
24 |         self.fcs=nn.ModuleList([])
25 |         for i in range(len(kernels)):
26 |             self.fcs.append(nn.Linear(self.d,channel))
27 |         self.softmax=nn.Softmax(dim=1)
28 | 
29 | 
30 |     def init_weights(self):
31 |         for m in self.modules():
32 |             if isinstance(m, nn.Conv2d):
33 |                 init.kaiming_normal_(m.weight, mode='fan_out')
34 |                 if m.bias is not None:
35 |                     init.constant_(m.bias, 0)
36 |             elif isinstance(m, nn.BatchNorm2d):
37 |                 init.constant_(m.weight, 1)
38 |                 init.constant_(m.bias, 0)
39 |             elif isinstance(m, nn.Linear):
40 |                 init.normal_(m.weight, std=0.001)
41 |                 if m.bias is not None:
42 |                     init.constant_(m.bias, 0)
43 | 
44 |     def forward(self, x):
45 |         bs, c, _, _ = x.size()
46 |         conv_outs=[]
47 |         ### split
48 |         for conv in self.convs:
49 |             conv_outs.append(conv(x))
50 |         feats=torch.stack(conv_outs,0)#k,bs,channel,h,w
51 | 
52 |         ### fuse
53 |         U=sum(conv_outs) #bs,c,h,w
54 | 
55 |         ### reduction channel
56 | 
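# Squeeze step of SK attention: global average pooling turns the fused map U into a
# per-channel descriptor S (bs, c); the shared fc compresses it to Z (bs, d) with
# d = max(L, channel // reduction), before the per-branch fcs expand it back to
# per-kernel channel logits that are softmax-normalized across branches below.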
S=U.mean(-1).mean(-1) #bs,c 57 | Z=self.fc(S) #bs,d 58 | 59 | ### calculate attention weight 60 | weights=[] 61 | for fc in self.fcs: 62 | weight=fc(Z) 63 | weights.append(weight.view(bs,c,1,1)) #bs,channel 64 | attention_weughts=torch.stack(weights,0)#k,bs,channel,1,1 65 | attention_weughts=self.softmax(attention_weughts)#k,bs,channel,1,1 66 | 67 | ### fuse 68 | V=(attention_weughts*feats).sum(0) 69 | return V 70 | 71 | 72 | 73 | 74 | 75 | 76 | if __name__ == '__main__': 77 | input=torch.randn(50,512,7,7) 78 | se = SKAttention(channel=512,reduction=8) 79 | output=se(input) 80 | print(output.shape) 81 | 82 | -------------------------------------------------------------------------------- /preprocess/merge_ec_file.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from tqdm import tqdm 4 | import copy 5 | from scipy.special import softmax 6 | project_dir = './ODQA' 7 | basic_dir = f'{project_dir}/data/evidence_chain_data/retrieval-based/scored_chains/' 8 | #weighted 9 | # gt_ec_file = f'{project_dir}/data/preprocessed_data/evidence_chain/ground-truth-based/scored_chains/train_gttb_evidence_chain_weighted_scores.json' 10 | gt_ec_file = f'{project_dir}/data/evidence_chain_data/ground-truth-based/scored_chains/dev_roberta_large_scored_ec_addtab_weighted_multineg.json' 11 | #original 12 | 13 | file_names = ['dev_evidence_chain_weighted_scores.json'] 14 | all_data = [] 15 | gt_ec_data = json.load(open(gt_ec_file,'r',encoding='utf8')) 16 | error_count = 0 17 | all_cnt = 0 18 | 19 | for fid,file in enumerate(file_names): 20 | start = 0 #dev 21 | # start = int(len(gt_ec_data) / 4) * fid 22 | end = len(gt_ec_data)#dev 23 | # end = int(len(gt_ec_data) / 4) * (fid + 1) if fid != (3) else len(gt_ec_data) 24 | data = json.load(open(os.path.join(basic_dir,file),'r',encoding='utf8')) 25 | output_data = [] 26 | for idx, item in tqdm(enumerate(data)): 27 | assert (item['question_id'] == gt_ec_data[start + idx]['question_id']) 28 | output_data.append(copy.deepcopy(item)) 29 | output_data[-1]['positive_table_blocks'] = copy.deepcopy(gt_ec_data[start + idx]['positive_table_blocks']) 30 | tmp_retrived_blocks = copy.deepcopy(item['retrieved_tbs'][:15]) 31 | for tbid,block in enumerate(tmp_retrived_blocks): 32 | # assert(block['table_id']==gt_ec_data[start+idx]['retrieved_tbs'][tbid]['table_id'] and block['row_id']==gt_ec_data[start+idx]['retrieved_tbs'][tbid]['row_id']) 33 | all_cnt+=1 34 | if block['candidate_evidence_chains']: 35 | ranked_ec = sorted(block['candidate_evidence_chains'], key=lambda k: softmax(k['score'])[1], reverse=True) 36 | if len(ranked_ec)>=3: 37 | selected = copy.deepcopy(ranked_ec[:3]) 38 | else: 39 | selected = copy.deepcopy(ranked_ec) 40 | del tmp_retrived_blocks[tbid]['candidate_evidence_chains'] 41 | tmp_retrived_blocks[tbid]['candidate_evidence_chains'] = selected 42 | else: 43 | error_count+=1 44 | output_data[-1]['retrieved_tbs'] = tmp_retrived_blocks 45 | assert(len(output_data[-1]['retrieved_tbs'])==15) 46 | 47 | all_data.extend(output_data) 48 | print('error rate: {}'.format(error_count/all_cnt)) 49 | with open(os.path.join(basic_dir,f'{project_dir}/data/qa_with_evidence_chain/dev_ranked_evidence_chain_for_qa_weighted.json'),'w',encoding='utf8') as outf: 50 | json.dump(all_data,outf) 51 | 52 | 53 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/CBAM.py: -------------------------------------------------------------------------------- 1 | import 
numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | 8 | class ChannelAttention(nn.Module): 9 | def __init__(self,channel,reduction=16): 10 | super().__init__() 11 | self.maxpool=nn.AdaptiveMaxPool2d(1) 12 | self.avgpool=nn.AdaptiveAvgPool2d(1) 13 | self.se=nn.Sequential( 14 | nn.Conv2d(channel,channel//reduction,1,bias=False), 15 | nn.ReLU(), 16 | nn.Conv2d(channel//reduction,channel,1,bias=False) 17 | ) 18 | self.sigmoid=nn.Sigmoid() 19 | 20 | def forward(self, x) : 21 | max_result=self.maxpool(x) 22 | avg_result=self.avgpool(x) 23 | max_out=self.se(max_result) 24 | avg_out=self.se(avg_result) 25 | output=self.sigmoid(max_out+avg_out) 26 | return output 27 | 28 | class SpatialAttention(nn.Module): 29 | def __init__(self,kernel_size=7): 30 | super().__init__() 31 | self.conv=nn.Conv2d(2,1,kernel_size=kernel_size,padding=kernel_size//2) 32 | self.sigmoid=nn.Sigmoid() 33 | 34 | def forward(self, x) : 35 | max_result,_=torch.max(x,dim=1,keepdim=True) 36 | avg_result=torch.mean(x,dim=1,keepdim=True) 37 | result=torch.cat([max_result,avg_result],1) 38 | output=self.conv(result) 39 | output=self.sigmoid(output) 40 | return output 41 | 42 | 43 | 44 | class CBAMBlock(nn.Module): 45 | 46 | def __init__(self, channel=512,reduction=16,kernel_size=49): 47 | super().__init__() 48 | self.ca=ChannelAttention(channel=channel,reduction=reduction) 49 | self.sa=SpatialAttention(kernel_size=kernel_size) 50 | 51 | 52 | def init_weights(self): 53 | for m in self.modules(): 54 | if isinstance(m, nn.Conv2d): 55 | init.kaiming_normal_(m.weight, mode='fan_out') 56 | if m.bias is not None: 57 | init.constant_(m.bias, 0) 58 | elif isinstance(m, nn.BatchNorm2d): 59 | init.constant_(m.weight, 1) 60 | init.constant_(m.bias, 0) 61 | elif isinstance(m, nn.Linear): 62 | init.normal_(m.weight, std=0.001) 63 | if m.bias is not None: 64 | init.constant_(m.bias, 0) 65 | 66 | def forward(self, x): 67 | b, c, _, _ = x.size() 68 | residual=x 69 | out=x*self.ca(x) 70 | out=out*self.sa(out) 71 | return out+residual 72 | 73 | 74 | if __name__ == '__main__': 75 | input=torch.randn(50,512,7,7) 76 | kernel_size=input.shape[2] 77 | cbam = CBAMBlock(channel=512,reduction=16,kernel_size=kernel_size) 78 | output=cbam(input) 79 | print(output.shape) 80 | 81 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/ShuffleAttention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | from torch.nn.parameter import Parameter 6 | 7 | 8 | class ShuffleAttention(nn.Module): 9 | 10 | def __init__(self, channel=512,reduction=16,G=8): 11 | super().__init__() 12 | self.G=G 13 | self.channel=channel 14 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 15 | self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G)) 16 | self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1)) 17 | self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1)) 18 | self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1)) 19 | self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1)) 20 | self.sigmoid=nn.Sigmoid() 21 | 22 | 23 | def init_weights(self): 24 | for m in self.modules(): 25 | if isinstance(m, nn.Conv2d): 26 | init.kaiming_normal_(m.weight, mode='fan_out') 27 | if m.bias is not None: 28 | init.constant_(m.bias, 0) 29 | elif isinstance(m, nn.BatchNorm2d): 30 | init.constant_(m.weight, 1) 31 | 
                    init.constant_(m.bias, 0)
32 |             elif isinstance(m, nn.Linear):
33 |                 init.normal_(m.weight, std=0.001)
34 |                 if m.bias is not None:
35 |                     init.constant_(m.bias, 0)
36 | 
37 | 
38 |     @staticmethod
39 |     def channel_shuffle(x, groups):
40 |         b, c, h, w = x.shape
41 |         x = x.reshape(b, groups, -1, h, w)
42 |         x = x.permute(0, 2, 1, 3, 4)
43 | 
44 |         # flatten
45 |         x = x.reshape(b, -1, h, w)
46 | 
47 |         return x
48 | 
49 |     def forward(self, x):
50 |         b, c, h, w = x.size()
51 |         #group into subfeatures
52 |         x=x.view(b*self.G,-1,h,w) #bs*G,c//G,h,w
53 | 
54 |         #channel_split
55 |         x_0,x_1=x.chunk(2,dim=1) #bs*G,c//(2*G),h,w
56 | 
57 |         #channel attention
58 |         x_channel=self.avg_pool(x_0) #bs*G,c//(2*G),1,1
59 |         x_channel=self.cweight*x_channel+self.cbias #bs*G,c//(2*G),1,1
60 |         x_channel=x_0*self.sigmoid(x_channel)
61 | 
62 |         #spatial attention
63 |         x_spatial=self.gn(x_1) #bs*G,c//(2*G),h,w
64 |         x_spatial=self.sweight*x_spatial+self.sbias #bs*G,c//(2*G),h,w
65 |         x_spatial=x_1*self.sigmoid(x_spatial) #bs*G,c//(2*G),h,w
66 | 
67 |         # concatenate along channel axis
68 |         out=torch.cat([x_channel,x_spatial],dim=1) #bs*G,c//G,h,w
69 |         out=out.contiguous().view(b,-1,h,w)
70 | 
71 |         # channel shuffle
72 |         out = self.channel_shuffle(out, 2)
73 |         return out
74 | 
75 | 
76 | if __name__ == '__main__':
77 |     input=torch.randn(50,512,7,7)
78 |     se = ShuffleAttention(channel=512,G=8)
79 |     output=se(input)
80 |     print(output.shape)
81 | 
82 | 
--------------------------------------------------------------------------------
/retriever/retrieval/attention/CoAtNet.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | import sys
4 | from math import sqrt
5 | sys.path.append('.')
6 | from conv.MBConv import MBConvBlock
7 | from attention.SelfAttention import ScaledDotProductAttention
8 | 
9 | class CoAtNet(nn.Module):
10 |     def __init__(self,in_ch,image_size,out_chs=[64,96,192,384,768]):
11 |         super().__init__()
12 |         self.out_chs=out_chs
13 |         self.maxpool2d=nn.MaxPool2d(kernel_size=2,stride=2)
14 |         self.maxpool1d = nn.MaxPool1d(kernel_size=2, stride=2)
15 | 
16 |         self.s0=nn.Sequential(
17 |             nn.Conv2d(in_ch,in_ch,kernel_size=3,padding=1),
18 |             nn.ReLU(),
19 |             nn.Conv2d(in_ch,in_ch,kernel_size=3,padding=1)
20 |         )
21 |         self.mlp0=nn.Sequential(
22 |             nn.Conv2d(in_ch,out_chs[0],kernel_size=1),
23 |             nn.ReLU(),
24 |             nn.Conv2d(out_chs[0],out_chs[0],kernel_size=1)
25 |         )
26 | 
27 |         self.s1=MBConvBlock(ksize=3,input_filters=out_chs[0],output_filters=out_chs[0],image_size=image_size//2)
28 |         self.mlp1=nn.Sequential(
29 |             nn.Conv2d(out_chs[0],out_chs[1],kernel_size=1),
30 |             nn.ReLU(),
31 |             nn.Conv2d(out_chs[1],out_chs[1],kernel_size=1)
32 |         )
33 | 
34 |         self.s2=MBConvBlock(ksize=3,input_filters=out_chs[1],output_filters=out_chs[1],image_size=image_size//4)
35 |         self.mlp2=nn.Sequential(
36 |             nn.Conv2d(out_chs[1],out_chs[2],kernel_size=1),
37 |             nn.ReLU(),
38 |             nn.Conv2d(out_chs[2],out_chs[2],kernel_size=1)
39 |         )
40 | 
41 |         self.s3=ScaledDotProductAttention(out_chs[2],out_chs[2]//8,out_chs[2]//8,8)
42 |         self.mlp3=nn.Sequential(
43 |             nn.Linear(out_chs[2],out_chs[3]),
44 |             nn.ReLU(),
45 |             nn.Linear(out_chs[3],out_chs[3])
46 |         )
47 | 
48 |         self.s4=ScaledDotProductAttention(out_chs[3],out_chs[3]//8,out_chs[3]//8,8)
49 |         self.mlp4=nn.Sequential(
50 |             nn.Linear(out_chs[3],out_chs[4]),
51 |             nn.ReLU(),
52 |             nn.Linear(out_chs[4],out_chs[4])
53 |         )
54 | 
55 | 
56 |     def forward(self, x) :
57 |         B,C,H,W=x.shape
58 |         #stage0
59 |         y=self.mlp0(self.s0(x))
60 |         y=self.maxpool2d(y)
61 |         #stage1
62 | 
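# CoAtNet's C-C-T-T layout: stages 1-2 (s1, s2) are MBConv blocks operating on feature
# maps, while stages 3-4 (s3, s4) flatten the map into tokens and apply scaled
# dot-product self-attention; each stage is followed by 2x pooling and an MLP that
# expands the channels to the next entry of out_chs.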
y=self.mlp1(self.s1(y)) 63 | y=self.maxpool2d(y) 64 | #stage2 65 | y=self.mlp2(self.s2(y)) 66 | y=self.maxpool2d(y) 67 | #stage3 68 | y=y.reshape(B,self.out_chs[2],-1).permute(0,2,1) #B,N,C 69 | y=self.mlp3(self.s3(y,y,y)) 70 | y=self.maxpool1d(y.permute(0,2,1)).permute(0,2,1) 71 | #stage4 72 | y=self.mlp4(self.s4(y,y,y)) 73 | y=self.maxpool1d(y.permute(0,2,1)) 74 | N=y.shape[-1] 75 | y=y.reshape(B,self.out_chs[4],int(sqrt(N)),int(sqrt(N))) 76 | 77 | return y 78 | 79 | if __name__ == '__main__': 80 | x=torch.randn(1,3,224,224) 81 | coatnet=CoAtNet(3,224) 82 | y=coatnet(x) 83 | print(y.shape) 84 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/SimplifiedSelfAttention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | 8 | class SimplifiedScaledDotProductAttention(nn.Module): 9 | ''' 10 | Scaled dot-product attention 11 | ''' 12 | 13 | def __init__(self, d_model, h,dropout=.1): 14 | ''' 15 | :param d_model: Output dimensionality of the model 16 | :param d_k: Dimensionality of queries and keys 17 | :param d_v: Dimensionality of values 18 | :param h: Number of heads 19 | ''' 20 | super(SimplifiedScaledDotProductAttention, self).__init__() 21 | 22 | self.d_model = d_model 23 | self.d_k = d_model//h 24 | self.d_v = d_model//h 25 | self.h = h 26 | 27 | self.fc_o = nn.Linear(h * self.d_v, d_model) 28 | self.dropout=nn.Dropout(dropout) 29 | 30 | 31 | 32 | self.init_weights() 33 | 34 | 35 | def init_weights(self): 36 | for m in self.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | init.kaiming_normal_(m.weight, mode='fan_out') 39 | if m.bias is not None: 40 | init.constant_(m.bias, 0) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | init.constant_(m.weight, 1) 43 | init.constant_(m.bias, 0) 44 | elif isinstance(m, nn.Linear): 45 | init.normal_(m.weight, std=0.001) 46 | if m.bias is not None: 47 | init.constant_(m.bias, 0) 48 | 49 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 50 | ''' 51 | Computes 52 | :param queries: Queries (b_s, nq, d_model) 53 | :param keys: Keys (b_s, nk, d_model) 54 | :param values: Values (b_s, nk, d_model) 55 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking. 56 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk). 
57 | :return: 58 | ''' 59 | b_s, nq = queries.shape[:2] 60 | nk = keys.shape[1] 61 | 62 | q = queries.view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) 63 | k = keys.view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 64 | v = values.view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 65 | 66 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) 67 | if attention_weights is not None: 68 | att = att * attention_weights 69 | if attention_mask is not None: 70 | att = att.masked_fill(attention_mask, -np.inf) 71 | att = torch.softmax(att, -1) 72 | att=self.dropout(att) 73 | 74 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) 75 | out = self.fc_o(out) # (b_s, nq, d_model) 76 | return out 77 | 78 | 79 | if __name__ == '__main__': 80 | input=torch.randn(50,49,512) 81 | ssa = SimplifiedScaledDotProductAttention(d_model=512, h=8) 82 | output=ssa(input,input,input) 83 | print(output.shape) 84 | 85 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/SelfAttention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | 8 | class ScaledDotProductAttention(nn.Module): 9 | ''' 10 | Scaled dot-product attention 11 | ''' 12 | 13 | def __init__(self, d_model, d_k, d_v, h,dropout=.1): 14 | ''' 15 | :param d_model: Output dimensionality of the model 16 | :param d_k: Dimensionality of queries and keys 17 | :param d_v: Dimensionality of values 18 | :param h: Number of heads 19 | ''' 20 | super(ScaledDotProductAttention, self).__init__() 21 | self.fc_q = nn.Linear(d_model, h * d_k) 22 | self.fc_k = nn.Linear(d_model, h * d_k) 23 | self.fc_v = nn.Linear(d_model, h * d_v) 24 | self.fc_o = nn.Linear(h * d_v, d_model) 25 | self.dropout=nn.Dropout(dropout) 26 | 27 | self.d_model = d_model 28 | self.d_k = d_k 29 | self.d_v = d_v 30 | self.h = h 31 | 32 | self.init_weights() 33 | 34 | 35 | def init_weights(self): 36 | for m in self.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | init.kaiming_normal_(m.weight, mode='fan_out') 39 | if m.bias is not None: 40 | init.constant_(m.bias, 0) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | init.constant_(m.weight, 1) 43 | init.constant_(m.bias, 0) 44 | elif isinstance(m, nn.Linear): 45 | init.normal_(m.weight, std=0.001) 46 | if m.bias is not None: 47 | init.constant_(m.bias, 0) 48 | 49 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 50 | ''' 51 | Computes 52 | :param queries: Queries (b_s, nq, d_model) 53 | :param keys: Keys (b_s, nk, d_model) 54 | :param values: Values (b_s, nk, d_model) 55 | :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking. 56 | :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk). 
57 | :return: 58 | ''' 59 | b_s, nq = queries.shape[:2] 60 | nk = keys.shape[1] 61 | 62 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) 63 | k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 64 | v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 65 | 66 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) 67 | if attention_weights is not None: 68 | att = att * attention_weights 69 | if attention_mask is not None: 70 | # att = att.masked_fill(attention_mask, -np.inf) 71 | att = att + (1-attention_mask)*(-1e6) 72 | att = torch.softmax(att, -1) 73 | att=self.dropout(att) 74 | 75 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) 76 | out = self.fc_o(out) # (b_s, nq, d_model) 77 | return out 78 | 79 | 80 | if __name__ == '__main__': 81 | input=torch.randn(50,49,512) 82 | sa = ScaledDotProductAttention(d_model=512, d_k=512, d_v=512, h=8) 83 | output=sa(input,input,input) 84 | print(output.shape) 85 | 86 | -------------------------------------------------------------------------------- /retriever/retrieval/___tr_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import json 3 | import pickle 4 | import random 5 | import torch 6 | 7 | from .data_utils import collate_tokens 8 | 9 | 10 | class TRDataset(Dataset): 11 | 12 | def __init__(self, 13 | tokenizer, 14 | data_path, 15 | max_q_len, 16 | max_q_sp_len, 17 | max_c_len, 18 | train=False, 19 | ): 20 | super().__init__() 21 | self.tokenizer = tokenizer 22 | self.max_q_len = max_q_len 23 | self.max_c_len = max_c_len 24 | self.max_q_sp_len = max_q_sp_len 25 | self.train = train 26 | print(f"Loading data from {data_path}") 27 | self.data = [json.loads(line) for line in open(data_path).readlines()] 28 | with open(data_path, 'rb') as f: 29 | self.data = pickle.load(f) 30 | print(f"Total sample count {len(self.data)}") 31 | 32 | def encode_tb(self, passages, table, max_len): 33 | return self.tokenizer(table=table, queries=passages, max_length=max_len, return_tensors="pt") 34 | 35 | def __getitem__(self, index): 36 | sample = self.data[index] 37 | question = sample['question'] 38 | if question.endswith("?"): 39 | question = question[:-1] 40 | 41 | table = sample['table'] 42 | passages = ' '.join(sample['passages']) 43 | tb_codes = self.encode_tb(passages, table, self.max_c_len) 44 | 45 | q_codes = self.tokenizer.encode_plus(question, max_length=self.max_q_len, return_tensors="pt") 46 | label = torch.tensor(sample['label']) 47 | return { 48 | "q_codes": q_codes, 49 | "tb_codes": tb_codes, 50 | "label": label, 51 | } 52 | 53 | def __len__(self): 54 | return len(self.data) 55 | 56 | 57 | def tb_collate(samples, pad_id=0): 58 | if len(samples) == 0: 59 | return {} 60 | 61 | batch = { 62 | 'q_input_ids': collate_tokens([s["q_codes"]["input_ids"].view(-1) for s in samples], 0), 63 | 'q_mask': collate_tokens([s["q_codes"]["attention_mask"].view(-1) for s in samples], 0), 64 | 65 | 'tb_input_ids': collate_tokens([s["tb_codes"]["input_ids"].view(-1) for s in samples], 0), 66 | 'tb_mask':collate_tokens([s["tb_codes"]["attention_mask"].view(-1) for s in samples], 0), 67 | 68 | } 69 | 70 | # if "token_type_ids" in samples[0]["q_codes"]: 71 | # batch.update({ 72 | # 'q_type_ids': collate_tokens([s["q_codes"]["token_type_ids"].view(-1) for s in samples], 0), 73 | # 
'c1_type_ids': collate_tokens([s["start_para_codes"]["token_type_ids"] for s in samples], 0), 74 | # 'c2_type_ids': collate_tokens([s["bridge_para_codes"]["token_type_ids"] for s in samples], 0), 75 | # "q_sp_type_ids": collate_tokens([s["q_sp_codes"]["token_type_ids"].view(-1) for s in samples], 0), 76 | # 'neg1_type_ids': collate_tokens([s["neg_codes_1"]["token_type_ids"] for s in samples], 0), 77 | # 'neg2_type_ids': collate_tokens([s["neg_codes_2"]["token_type_ids"] for s in samples], 0), 78 | # }) 79 | # 80 | # if "sent_ids" in samples[0]["start_para_codes"]: 81 | # batch["c1_sent_target"] = collate_tokens([s["start_para_codes"]["sent_ids"] for s in samples], -1) 82 | # batch["c1_sent_offsets"] = collate_tokens([s["start_para_codes"]["sent_offsets"] for s in samples], 0), 83 | 84 | return batch 85 | -------------------------------------------------------------------------------- /retriever/retrieval/attention/BAM.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | class Flatten(nn.Module): 7 | def forward(self,x): 8 | return x.view(x.shape[0],-1) 9 | 10 | class ChannelAttention(nn.Module): 11 | def __init__(self,channel,reduction=16,num_layers=3): 12 | super().__init__() 13 | self.avgpool=nn.AdaptiveAvgPool2d(1) 14 | gate_channels=[channel] 15 | gate_channels+=[channel//reduction]*num_layers 16 | gate_channels+=[channel] 17 | 18 | 19 | self.ca=nn.Sequential() 20 | self.ca.add_module('flatten',Flatten()) 21 | for i in range(len(gate_channels)-2): 22 | self.ca.add_module('fc%d'%i,nn.Linear(gate_channels[i],gate_channels[i+1])) 23 | self.ca.add_module('bn%d'%i,nn.BatchNorm1d(gate_channels[i+1])) 24 | self.ca.add_module('relu%d'%i,nn.ReLU()) 25 | self.ca.add_module('last_fc',nn.Linear(gate_channels[-2],gate_channels[-1])) 26 | 27 | 28 | def forward(self, x) : 29 | res=self.avgpool(x) 30 | res=self.ca(res) 31 | res=res.unsqueeze(-1).unsqueeze(-1).expand_as(x) 32 | return res 33 | 34 | class SpatialAttention(nn.Module): 35 | def __init__(self,channel,reduction=16,num_layers=3,dia_val=2): 36 | super().__init__() 37 | self.sa=nn.Sequential() 38 | self.sa.add_module('conv_reduce1',nn.Conv2d(kernel_size=1,in_channels=channel,out_channels=channel//reduction)) 39 | self.sa.add_module('bn_reduce1',nn.BatchNorm2d(channel//reduction)) 40 | self.sa.add_module('relu_reduce1',nn.ReLU()) 41 | for i in range(num_layers): 42 | self.sa.add_module('conv_%d'%i,nn.Conv2d(kernel_size=3,in_channels=channel//reduction,out_channels=channel//reduction,padding=1,dilation=dia_val)) 43 | self.sa.add_module('bn_%d'%i,nn.BatchNorm2d(channel//reduction)) 44 | self.sa.add_module('relu_%d'%i,nn.ReLU()) 45 | self.sa.add_module('last_conv',nn.Conv2d(channel//reduction,1,kernel_size=1)) 46 | 47 | def forward(self, x) : 48 | res=self.sa(x) 49 | res=res.expand_as(x) 50 | return res 51 | 52 | 53 | 54 | 55 | class BAMBlock(nn.Module): 56 | 57 | def __init__(self, channel=512,reduction=16,dia_val=2): 58 | super().__init__() 59 | self.ca=ChannelAttention(channel=channel,reduction=reduction) 60 | self.sa=SpatialAttention(channel=channel,reduction=reduction,dia_val=dia_val) 61 | self.sigmoid=nn.Sigmoid() 62 | 63 | 64 | def init_weights(self): 65 | for m in self.modules(): 66 | if isinstance(m, nn.Conv2d): 67 | init.kaiming_normal_(m.weight, mode='fan_out') 68 | if m.bias is not None: 69 | init.constant_(m.bias, 0) 70 | elif isinstance(m, nn.BatchNorm2d): 71 | init.constant_(m.weight, 1) 72 | 
init.constant_(m.bias, 0) 73 | elif isinstance(m, nn.Linear): 74 | init.normal_(m.weight, std=0.001) 75 | if m.bias is not None: 76 | init.constant_(m.bias, 0) 77 | 78 | def forward(self, x): 79 | b, c, _, _ = x.size() 80 | sa_out=self.sa(x) 81 | ca_out=self.ca(x) 82 | weight=self.sigmoid(sa_out+ca_out) 83 | out=(1+weight)*x 84 | return out 85 | 86 | 87 | if __name__ == '__main__': 88 | input=torch.randn(50,512,7,7) 89 | bam = BAMBlock(channel=512,reduction=16,dia_val=2) 90 | output=bam(input) 91 | print(output.shape) 92 | 93 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from ast import parse 3 | def common_args(): 4 | parser = argparse.ArgumentParser() 5 | 6 | # task 7 | parser.add_argument("--train_file", type=str, 8 | default="../data/nq-with-neg-train.txt") 9 | parser.add_argument("--predict_file", type=str, 10 | default="../data/nq-with-neg-dev.txt") 11 | parser.add_argument("--num_workers", default=30, type=int) 12 | parser.add_argument("--do_train", default=False, 13 | action='store_true', help="Whether to run training.") 14 | parser.add_argument("--do_predict", default=False, 15 | action='store_true', help="Whether to run eval on the dev set.") 16 | parser.add_argument("--basic_data_path", 17 | default='/home/t-wzhong/v-wanzho/ODQA/data/',type=str) 18 | # model 19 | parser.add_argument("--model_name", 20 | default="bert-base-uncased", type=str) 21 | parser.add_argument("--init_checkpoint", type=str, 22 | help="Initial checkpoint (usually from a pre-trained BERT model).", 23 | default="") 24 | parser.add_argument("--max_c_len", default=420, type=int, 25 | help="The maximum total input sequence length after WordPiece tokenization. Sequences " 26 | "longer than this will be truncated, and sequences shorter than this will be padded.") 27 | parser.add_argument("--max_q_len", default=70, type=int, 28 | help="The maximum number of tokens for the question. Questions longer than this will " 29 | "be truncated to this length.") 30 | parser.add_argument("--max_p_len", default=350, type=int, 31 | help="The maximum number of tokens for the question. Questions longer than this will " 32 | "be truncated to this length.") 33 | parser.add_argument("--cell_trim_length", default=20, type=int, 34 | help="The maximum number of tokens for each cell. Cell longer than this will " 35 | "be truncated to this length.") 36 | parser.add_argument('--fp16', action='store_true') 37 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 38 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 
39 | "See details at https://nvidia.github.io/apex/amp.html") 40 | parser.add_argument("--no_cuda", default=False, action='store_true', 41 | help="Whether not to use CUDA when available") 42 | parser.add_argument("--local_rank", type=int, default=-1, 43 | help="local_rank for distributed training on gpus") 44 | parser.add_argument("--max_q_sp_len", default=50, type=int) 45 | parser.add_argument("--sent-level", action="store_true") 46 | parser.add_argument("--rnn-retriever", action="store_true") 47 | parser.add_argument("--predict_batch_size", default=512, 48 | type=int, help="Total batch size for predictions.") 49 | 50 | # multi vector scheme 51 | parser.add_argument("--multi-vector", type=int, default=1) 52 | parser.add_argument("--scheme", type=str, help="how to get the multivector, layerwise or tokenwise", default="none") 53 | parser.add_argument("--no_proj", action="store_true") 54 | parser.add_argument("--shared_encoder", action="store_true") 55 | 56 | # momentum 57 | parser.add_argument("--momentum", action="store_true") 58 | parser.add_argument("--init-retriever", type=str, default="") 59 | parser.add_argument("--k", type=int, default=38400, help="memory bank size") 60 | parser.add_argument("--m", type=float, default=0.999, help="momentum") 61 | 62 | 63 | # NQ multihop trial 64 | parser.add_argument("--nq-multi", action="store_true", help="train the NQ retrieval model to recover from error cases") 65 | 66 | return parser 67 | 68 | args = common_args() -------------------------------------------------------------------------------- /retriever/retrieval/attention/MUSEAttention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | 8 | class Depth_Pointwise_Conv1d(nn.Module): 9 | def __init__(self,in_ch,out_ch,k): 10 | super().__init__() 11 | if(k==1): 12 | self.depth_conv=nn.Identity() 13 | else: 14 | self.depth_conv=nn.Conv1d( 15 | in_channels=in_ch, 16 | out_channels=in_ch, 17 | kernel_size=k, 18 | groups=in_ch, 19 | padding=k//2 20 | ) 21 | self.pointwise_conv=nn.Conv1d( 22 | in_channels=in_ch, 23 | out_channels=out_ch, 24 | kernel_size=1, 25 | groups=1 26 | ) 27 | def forward(self,x): 28 | out=self.pointwise_conv(self.depth_conv(x)) 29 | return out 30 | 31 | 32 | 33 | class MUSEAttention(nn.Module): 34 | 35 | def __init__(self, d_model, d_k, d_v, h,dropout=.1): 36 | 37 | 38 | super(MUSEAttention, self).__init__() 39 | self.fc_q = nn.Linear(d_model, h * d_k) 40 | self.fc_k = nn.Linear(d_model, h * d_k) 41 | self.fc_v = nn.Linear(d_model, h * d_v) 42 | self.fc_o = nn.Linear(h * d_v, d_model) 43 | self.dropout=nn.Dropout(dropout) 44 | 45 | self.conv1=Depth_Pointwise_Conv1d(h * d_v, d_model,1) 46 | self.conv3=Depth_Pointwise_Conv1d(h * d_v, d_model,3) 47 | self.conv5=Depth_Pointwise_Conv1d(h * d_v, d_model,5) 48 | self.dy_paras=nn.Parameter(torch.ones(3)) 49 | self.softmax=nn.Softmax(-1) 50 | 51 | self.d_model = d_model 52 | self.d_k = d_k 53 | self.d_v = d_v 54 | self.h = h 55 | 56 | self.init_weights() 57 | 58 | 59 | def init_weights(self): 60 | for m in self.modules(): 61 | if isinstance(m, nn.Conv2d): 62 | init.kaiming_normal_(m.weight, mode='fan_out') 63 | if m.bias is not None: 64 | init.constant_(m.bias, 0) 65 | elif isinstance(m, nn.BatchNorm2d): 66 | init.constant_(m.weight, 1) 67 | init.constant_(m.bias, 0) 68 | elif isinstance(m, nn.Linear): 69 | init.normal_(m.weight, std=0.001) 70 | if m.bias is not None: 71 | init.constant_(m.bias, 0) 
72 | 
73 |     def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
74 | 
75 |         #Self Attention
76 |         b_s, nq = queries.shape[:2]
77 |         nk = keys.shape[1]
78 | 
79 |         q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k)
80 |         k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk)
81 |         v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v)
82 | 
83 |         att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk)
84 |         if attention_weights is not None:
85 |             att = att * attention_weights
86 |         if attention_mask is not None:
87 |             att = att.masked_fill(attention_mask, -np.inf)
88 |         att = torch.softmax(att, -1)
89 |         att=self.dropout(att)
90 | 
91 |         out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v)
92 |         out = self.fc_o(out) # (b_s, nq, d_model)
93 | 
94 |         v2=v.permute(0,1,3,2).contiguous().view(b_s,-1,nk) #bs,dim,n
95 |         branch_weights=self.softmax(self.dy_paras) # normalize the learnable branch weights without re-wrapping the Parameter inside forward
96 |         out2=branch_weights[0]*self.conv1(v2)+branch_weights[1]*self.conv3(v2)+branch_weights[2]*self.conv5(v2)
97 |         out2=out2.permute(0,2,1) #bs.n.dim
98 | 
99 |         out=out+out2
100 |         return out
101 | 
102 | 
103 | if __name__ == '__main__':
104 |     input=torch.randn(50,49,512)
105 |     sa = MUSEAttention(d_model=512, d_k=512, d_v=512, h=8)
106 |     output=sa(input,input,input)
107 |     print(output.shape)
108 | 
109 | 
--------------------------------------------------------------------------------
/retriever/retrieval/attention/EMSA.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from torch.nn import init
5 | 
6 | 
7 | 
8 | class EMSA(nn.Module):
9 | 
10 |     def __init__(self, d_model, d_k, d_v, h,dropout=.1,H=7,W=7,ratio=3,apply_transform=True):
11 | 
12 |         super(EMSA, self).__init__()
13 |         self.H=H
14 |         self.W=W
15 |         self.fc_q = nn.Linear(d_model, h * d_k)
16 |         self.fc_k = nn.Linear(d_model, h * d_k)
17 |         self.fc_v = nn.Linear(d_model, h * d_v)
18 |         self.fc_o = nn.Linear(h * d_v, d_model)
19 |         self.dropout=nn.Dropout(dropout)
20 | 
21 |         self.ratio=ratio
22 |         if(self.ratio>1):
23 |             self.sr=nn.Sequential()
24 |             self.sr_conv=nn.Conv2d(d_model,d_model,kernel_size=ratio+1,stride=ratio,padding=ratio//2,groups=d_model)
25 |             self.sr_ln=nn.LayerNorm(d_model)
26 | 
27 |         self.apply_transform=apply_transform and h>1
28 |         if(self.apply_transform):
29 |             self.transform=nn.Sequential()
30 |             self.transform.add_module('conv',nn.Conv2d(h,h,kernel_size=1,stride=1))
31 |             self.transform.add_module('softmax',nn.Softmax(-1))
32 |             self.transform.add_module('in',nn.InstanceNorm2d(h))
33 | 
34 |         self.d_model = d_model
35 |         self.d_k = d_k
36 |         self.d_v = d_v
37 |         self.h = h
38 | 
39 |         self.init_weights()
40 | 
41 | 
42 |     def init_weights(self):
43 |         for m in self.modules():
44 |             if isinstance(m, nn.Conv2d):
45 |                 init.kaiming_normal_(m.weight, mode='fan_out')
46 |                 if m.bias is not None:
47 |                     init.constant_(m.bias, 0)
48 |             elif isinstance(m, nn.BatchNorm2d):
49 |                 init.constant_(m.weight, 1)
50 |                 init.constant_(m.bias, 0)
51 |             elif isinstance(m, nn.Linear):
52 |                 init.normal_(m.weight, std=0.001)
53 |                 if m.bias is not None:
54 |                     init.constant_(m.bias, 0)
55 | 
56 |     def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
57 | 
58 |         b_s, nq ,c = queries.shape
59 |         nk = keys.shape[1]
60 | 
61 |         q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)
# (b_s, h, nq, d_k) 62 | 63 | if(self.ratio>1): 64 | x=queries.permute(0,2,1).view(b_s,c,self.H,self.W) #bs,c,H,W 65 | x=self.sr_conv(x) #bs,c,h,w 66 | x=x.contiguous().view(b_s,c,-1).permute(0,2,1) #bs,n',c 67 | x=self.sr_ln(x) 68 | k = self.fc_k(x).view(b_s, -1, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, n') 69 | v = self.fc_v(x).view(b_s, -1, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, n', d_v) 70 | else: 71 | k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 72 | v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 73 | 74 | if(self.apply_transform): 75 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, n') 76 | att = self.transform(att) # (b_s, h, nq, n') 77 | else: 78 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, n') 79 | att = torch.softmax(att, -1) # (b_s, h, nq, n') 80 | 81 | 82 | if attention_weights is not None: 83 | att = att * attention_weights 84 | if attention_mask is not None: 85 | att = att.masked_fill(attention_mask, -np.inf) 86 | 87 | att=self.dropout(att) 88 | 89 | out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) 90 | out = self.fc_o(out) # (b_s, nq, d_model) 91 | return out 92 | 93 | 94 | if __name__ == '__main__': 95 | input=torch.randn(50,64,512) 96 | emsa = EMSA(d_model=512, d_k=512, d_v=512, h=8,H=8,W=8,ratio=2,apply_transform=True) 97 | output=emsa(input,input,input) 98 | print(output.shape) 99 | 100 | -------------------------------------------------------------------------------- /retriever/retrieval/drqa/drqa_tokenizers/regexp_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Regex based tokenizer that emulates the Stanford/NLTK PTB tokenizers. 8 | 9 | However it is purely in Python, supports robust untokenization, unicode, 10 | and requires minimal dependencies. 11 | """ 12 | 13 | import regex 14 | import logging 15 | from .tokenizer import Tokens, Tokenizer 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class RegexpTokenizer(Tokenizer): 21 | DIGIT = r'\p{Nd}+([:\.\,]\p{Nd}+)*' 22 | TITLE = (r'(dr|esq|hon|jr|mr|mrs|ms|prof|rev|sr|st|rt|messrs|mmes|msgr)' 23 | r'\.(?=\p{Z})') 24 | ABBRV = r'([\p{L}]\.){2,}(?=\p{Z}|$)' 25 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]++' 26 | HYPHEN = r'{A}([-\u058A\u2010\u2011]{A})+'.format(A=ALPHA_NUM) 27 | NEGATION = r"((?!n't)[\p{L}\p{N}\p{M}])++(?=n't)|n't" 28 | CONTRACTION1 = r"can(?=not\b)" 29 | CONTRACTION2 = r"'([tsdm]|re|ll|ve)\b" 30 | START_DQUOTE = r'(?<=[\p{Z}\(\[{<]|^)(``|["\u0093\u201C\u00AB])(?!\p{Z})' 31 | START_SQUOTE = r'(?<=[\p{Z}\(\[{<]|^)[\'\u0091\u2018\u201B\u2039](?!\p{Z})' 32 | END_DQUOTE = r'(?%s)|(?P