├── .gitignore
├── conf
│   ├── datasets
│   │   ├── table_train.yaml
│   │   └── table_retrieval.yaml
│   ├── train
│   │   ├── biencoder_default.yaml
│   │   ├── biencoder_local.yaml
│   │   └── biencoder_nq.yaml
│   ├── encoder
│   │   └── hf_bert.yaml
│   ├── ctx_sources
│   │   └── table_sources.yaml
│   ├── gen_embs.yaml
│   ├── biencoder_train.yaml
│   └── dense_retrieval.yaml
├── scripts
│   ├── run_inference.sh
│   ├── curate_data.sh
│   └── tune_model.sh
├── LICENSE.md
├── setup.py
├── dpr
│   ├── utils
│   │   ├── conf_utils.py
│   │   ├── dist_utils.py
│   │   ├── model_utils.py
│   │   ├── tokenizers.py
│   │   └── data_utils.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── reader.py
│   │   └── biencoder.py
│   ├── options.py
│   ├── data
│   │   ├── table_data.py
│   │   ├── qa_validation.py
│   │   ├── retriever_data.py
│   │   ├── tables.py
│   │   └── reader_data.py
│   └── indexer
│       └── faiss_indexers.py
├── convert_data.py
├── generate_embeddings.py
├── README.md
├── process_table.py
└── dense_retrieval.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__/**
2 | checkpoint/**
3 | outputs/**

--------------------------------------------------------------------------------
/conf/datasets/table_train.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | nq_table_train:
4 |   _target_: dpr.data.biencoder_data.NqtJsonQADataset
5 |   file: "/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/train.converted"
6 |
7 | nq_table_dev:
8 |   _target_: dpr.data.biencoder_data.NqtJsonQADataset
9 |   file: "/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/dev.converted"

--------------------------------------------------------------------------------
/conf/datasets/table_retrieval.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | nq_table_test:
4 |   _target_: dpr.data.retriever_data.JsonlQASrcTable
5 |   file: "/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/test.jsonl"
6 |
7 | nq_table_dev:
8 |   _target_: dpr.data.retriever_data.JsonlQASrcTable
9 |   file: "/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/dev.jsonl"
10 |
11 | nq_table_train:
12 |   _target_: dpr.data.retriever_data.JsonlQASrcTable
13 |   file: "/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/train.jsonl"

--------------------------------------------------------------------------------
/conf/train/biencoder_default.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | batch_size: 2
4 | dev_batch_size: 1
5 | adam_eps: 1e-8
6 | adam_betas: (0.9, 0.999)
7 | max_grad_norm: 1.0
8 | log_batch_step: 100
9 | train_rolling_loss_step: 100
10 | weight_decay: 0.0
11 | learning_rate: 1e-5
12 |
13 | # Linear warmup over warmup_steps.
14 | warmup_steps: 100
15 |
16 | # Number of update steps to accumulate before performing a backward/update pass.
17 | gradient_accumulation_steps: 1
18 |
19 | # Total number of training epochs to perform.
20 | num_train_epochs: 40 21 | eval_per_epoch: 1 22 | hard_negatives: 1 23 | other_negatives: 0 24 | val_av_rank_hard_neg: 30 25 | val_av_rank_other_neg: 30 26 | val_av_rank_bsz: 128 27 | val_av_rank_max_qs: 10000 -------------------------------------------------------------------------------- /conf/train/biencoder_local.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | batch_size: 1 4 | dev_batch_size: 16 5 | adam_eps: 1e-8 6 | adam_betas: (0.9, 0.999) 7 | max_grad_norm: 2.0 8 | log_batch_step: 1 9 | train_rolling_loss_step: 100 10 | weight_decay: 0.0 11 | learning_rate: 2e-5 12 | 13 | # Linear warmup over warmup_steps. 14 | warmup_steps: 1237 15 | 16 | # Number of updates steps to accumulate before performing a backward/update pass. 17 | gradient_accumulation_steps: 1 18 | 19 | # Total number of training epochs to perform. 20 | num_train_epochs: 40 21 | eval_per_epoch: 1 22 | hard_negatives: 1 23 | other_negatives: 0 24 | val_av_rank_hard_neg: 30 25 | val_av_rank_other_neg: 30 26 | val_av_rank_bsz: 128 27 | val_av_rank_max_qs: 10000 28 | -------------------------------------------------------------------------------- /conf/train/biencoder_nq.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | batch_size: 8 4 | dev_batch_size: 16 5 | adam_eps: 1e-8 6 | adam_betas: (0.9, 0.999) 7 | max_grad_norm: 2.0 8 | log_batch_step: 10 9 | train_rolling_loss_step: 100 10 | weight_decay: 0.0 11 | learning_rate: 2e-5 12 | 13 | # Linear warmup over warmup_steps. 14 | warmup_steps: 1237 15 | 16 | # Number of updates steps to accumulate before performing a backward/update pass. 17 | gradient_accumulation_steps: 1 18 | 19 | # Total number of training epochs to perform. 20 | num_train_epochs: 30 21 | eval_per_epoch: 1 22 | hard_negatives: 1 23 | other_negatives: 0 24 | val_av_rank_hard_neg: 40 # 30 25 | val_av_rank_other_neg: 40 # 30 26 | val_av_rank_bsz: 128 27 | val_av_rank_max_qs: 10000 28 | -------------------------------------------------------------------------------- /conf/encoder/hf_bert.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | # model type. 
One of [hf_bert, pytext_bert, fairseq_roberta, hf_bert_mix, hf_bert_bias]
4 | encoder_model_type: hf_bert
5 |
6 | # HuggingFace's config name for model initialization
7 | pretrained_model_cfg: bert-base-uncased
8 |
9 | # Some encoders need to be initialized from a file
10 | pretrained_file:
11 |
12 | # Extra linear layer on top of standard bert/roberta encoder
13 | projection_dim: 0
14 |
15 | # Max length of the encoder input sequence
16 | sequence_length: 256
17 |
18 | dropout: 0.1
19 |
20 | # whether to fix (i.e., not update) the context/question encoder during training
21 | fix_ctx_encoder: False
22 | fix_q_encoder: False
23 |
24 | # if False, the model won't load pre-trained BERT weights
25 | pretrained: True
26 |
27 | # optimized parameters
28 | auxiliary_embeddings_only: False

--------------------------------------------------------------------------------
/scripts/run_inference.sh:
--------------------------------------------------------------------------------
1 | # Zero-shot inference for NQ-Table using the DPR checkpoint
2 |
3 | set -euo pipefail
4 |
5 | if [[ -z ${ROOT_DIR:-} ]]; then
6 |     echo "\$ROOT_DIR environment variable needs to be set"
7 |     exit 1
8 | fi
9 |
10 | declare -a split_list=("test")
11 |
12 | DATA="datasets"
13 | CTX="nq_table"
14 |
15 | MODEL=${ROOT_DIR}/"checkpoint/retriever/single-adv-hn/nq/bert-base-encoder.cp"
16 | EMBED=${ROOT_DIR}/${DATA}/${CTX}/"embed"
17 |
18 |
19 | cd ${ROOT_DIR}
20 |
21 | # [1] generate dense embeddings
22 | echo "[1] generating ["${CTX}"] embeddings"
23 | python generate_embeddings.py \
24 |     model_file=${MODEL} \
25 |     ctx_src=${CTX} \
26 |     out_file=${EMBED}
27 |
28 | # [2] run retrieval inference using the generated embeddings
29 | echo "[2] run ["${CTX}"] retrieval inference"
30 |
31 | for spt in ${split_list[@]}; do
32 |     RETRIEVED=${ROOT_DIR}/${DATA}/${CTX}/${spt}".retrieved"
33 |
34 |     python dense_retrieval.py \
35 |         model_file=${MODEL} \
36 |         ctx_datatsets=[${CTX}"_all"] \
37 |         encoded_ctx_files=[${EMBED}"_0"] \
38 |         qa_dataset=${CTX}"_"${spt} \
39 |         out_file=${RETRIEVED}
40 |
41 | done

--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Zora/Zhiruo Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |

--------------------------------------------------------------------------------
/scripts/curate_data.sh:
--------------------------------------------------------------------------------
1 | # Curate the training and validation samples for NQ-Table.
2 |
3 | set -euo pipefail
4 |
5 | if [[ -z ${ROOT_DIR:-} ]]; then
6 |     echo "\$ROOT_DIR environment variable needs to be set"
7 |     exit 1
8 | fi
9 |
10 | declare -a split_list=("dev" "train")
11 |
12 | DATA="datasets"
13 | CTX="nq_table"
14 |
15 | MODEL=${ROOT_DIR}/"checkpoint/retriever/single-adv-hn/nq/bert-base-encoder.cp"  # zero-shot DPR checkpoint, same as scripts/run_inference.sh
16 | EMBED=${ROOT_DIR}/${DATA}/${CTX}/"embed"
17 | TABLES=${ROOT_DIR}/${DATA}/${CTX}/"tables_proc.jsonl"
18 |
19 | cd ${ROOT_DIR}
20 |
21 | for spt in ${split_list[@]}; do
22 |
23 |     # [0] declare output files
24 |     RETRIEVED=${ROOT_DIR}/${DATA}/${CTX}/${spt}".retrieved"
25 |     CONVERTED=${ROOT_DIR}/${DATA}/${CTX}/${spt}".converted"
26 |     ANNOTATED=${ROOT_DIR}/${DATA}/${CTX}/${spt}".jsonl"
27 |
28 |     # [1] run retrieval inference using the generated embeddings
29 |     echo "[1] run ["${CTX}"] retrieval inference"
30 |     python dense_retrieval.py \
31 |         model_file=${MODEL} \
32 |         ctx_datatsets=[${CTX}"_all"] \
33 |         encoded_ctx_files=[${EMBED}"_0"] \
34 |         qa_dataset=${CTX}"_"${spt} \
35 |         out_file=${RETRIEVED}
36 |
37 |     # [2] convert to training format
38 |     echo "[2] ["${CTX}"] convert "${spt}" results"
39 |     python convert_data.py \
40 |         --tables_file=${TABLES} \
41 |         --retrieved_path=${RETRIEVED} \
42 |         --converted_path=${CONVERTED} \
43 |         --annotated_path=${ANNOTATED}
44 |
45 | done

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | #
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 7 | 8 | from setuptools import setup 9 | 10 | with open("README.md") as f: 11 | readme = f.read() 12 | 13 | setup( 14 | name="dpr", 15 | version="1.0.0", 16 | description="Facebook AI Research Open Domain Q&A Toolkit", 17 | url="https://github.com/facebookresearch/DPR/", 18 | classifiers=[ 19 | "Intended Audience :: Science/Research", 20 | "License :: OSI Approved :: MIT License", 21 | "Programming Language :: Python :: 3.6", 22 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 23 | ], 24 | long_description=readme, 25 | long_description_content_type="text/markdown", 26 | setup_requires=[ 27 | "setuptools>=18.0", 28 | ], 29 | install_requires=[ 30 | "faiss-cpu>=1.6.1", 31 | "filelock", 32 | "numpy", 33 | "regex", 34 | "torch>=1.5.0", 35 | "transformers>=4.3", 36 | "tqdm>=4.27", 37 | "wget", 38 | "spacy>=2.1.8", 39 | "hydra-core>=1.0.0", 40 | "omegaconf>=2.0.1", 41 | "jsonlines", 42 | "soundfile", 43 | "editdistance", 44 | ], 45 | ) 46 | -------------------------------------------------------------------------------- /dpr/utils/conf_utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import logging 3 | import os 4 | 5 | import hydra 6 | from omegaconf import DictConfig 7 | 8 | from dpr.data.biencoder_data import NqtJsonQADataset 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class BiencoderDatasetsCfg(object): 14 | def __init__(self, cfg: DictConfig): 15 | ds_cfg = cfg.datasets 16 | self.train_datasets_names = cfg.train_datasets 17 | logger.info("train_datasets: %s", self.train_datasets_names) 18 | self.train_datasets = _init_datasets(self.train_datasets_names, ds_cfg) 19 | self.dev_datasets_names = cfg.dev_datasets 20 | logger.info("dev_datasets: %s", self.dev_datasets_names) 21 | self.dev_datasets = _init_datasets(self.dev_datasets_names, ds_cfg) 22 | self.sampling_rates = cfg.train_sampling_rates 23 | 24 | 25 | def _init_datasets(datasets_names, ds_cfg: DictConfig): 26 | if isinstance(datasets_names, str): 27 | return [_init_dataset(datasets_names, ds_cfg)] 28 | elif datasets_names: 29 | return [_init_dataset(ds_name, ds_cfg) for ds_name in datasets_names] 30 | else: 31 | return [] 32 | 33 | 34 | def _init_dataset(name: str, ds_cfg: DictConfig): 35 | if os.path.exists(name): 36 | # use default biencoder json class 37 | return NqtJsonQADataset(name) 38 | elif glob.glob(name): 39 | files = glob.glob(name) 40 | return [_init_dataset(f, ds_cfg) for f in files] 41 | # try to find in cfg 42 | if name not in ds_cfg: 43 | raise RuntimeError("Can't find dataset location/config for: {}".format(name)) 44 | return hydra.utils.instantiate(ds_cfg[name]) 45 | -------------------------------------------------------------------------------- /conf/ctx_sources/table_sources.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | nq_table_raw: 4 | _target_: dpr.data.retriever_data.JsonlNQTablesCtxSrc 5 | file: "/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables.jsonl" 6 | id_prefix: 'nqt:' 7 | 8 | nq_table: 9 | _target_: dpr.data.retriever_data.JsonlNqtCtxSrc 10 | file: "/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables_proc.jsonl" 11 | id_prefix: 'nqt:' 12 | 13 | nq_table_all: 14 | _target_: dpr.data.retriever_data.JsonlNqtCtxSrc 15 | file: "/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables_all.jsonl" 16 | id_prefix: 'nqt:' 17 | 18 | 19 | nq_table_row: 20 | _target_: dpr.data.retriever_data.JsonlNqtCtxSrc 21 | file: 
"/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables_row.jsonl" 22 | id_prefix: 'nqt:' 23 | 24 | nq_table_column: 25 | _target_: dpr.data.retriever_data.JsonlNqtCtxSrc 26 | file: "/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables_column.jsonl" 27 | id_prefix: 'nqt:' 28 | 29 | nq_table_both: 30 | _target_: dpr.data.retriever_data.JsonlNqtCtxSrc 31 | file: "/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables_both.jsonl" 32 | id_prefix: 'nqt:' 33 | 34 | 35 | nq_table_dcell: 36 | _target_: dpr.data.retriever_data.JsonlNqtCtxSrc 37 | file: "/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables_dcell.jsonl" 38 | id_prefix: 'nqt:' 39 | 40 | nq_table_drow: 41 | _target_: dpr.data.retriever_data.JsonlNqtCtxSrc 42 | file: "/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables_drow.jsonl" 43 | id_prefix: 'nqt:' 44 | 45 | nq_table_dnone: 46 | _target_: dpr.data.retriever_data.JsonlNqtCtxSrc 47 | file: "/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables_dnone.jsonl" 48 | id_prefix: 'nqt:' -------------------------------------------------------------------------------- /conf/gen_embs.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - encoder: hf_bert 3 | - ctx_sources: table_sources 4 | 5 | # A trained bi-encoder checkpoint file to initialize the model 6 | model_file: 7 | 8 | # Name of the all-passages resource 9 | ctx_src: nq_table 10 | 11 | # which (ctx or query) encoder to be used for embedding generation 12 | encoder_type: ctx 13 | 14 | # output .tsv file path to write results to 15 | out_file: 16 | 17 | # Whether to lower case the input text. Set True for uncased models, False for the cased ones. 18 | do_lower_case: True 19 | 20 | # Number(0-based) of data shard to process 21 | shard_id: 0 22 | 23 | # Total amount of data shards 24 | num_shards: 1 25 | 26 | # Batch size for the passage encoder forward pass (works in DataParallel mode) 27 | batch_size: 32 28 | 29 | tables_as_passages: False 30 | 31 | # tokens which won't be slit by tokenizer 32 | special_tokens: 33 | 34 | tables_chunk_sz: 100 35 | 36 | # TODO 37 | tables_split_type: type1 38 | 39 | 40 | # TODO: move to a conf group 41 | # local_rank for distributed training on gpus 42 | local_rank: -1 43 | device: 44 | distributed_world_size: 45 | distributed_port: 46 | no_cuda: False 47 | n_gpu: 48 | fp16: False 49 | 50 | # For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 51 | # "See details at https://nvidia.github.io/apex/amp.html 52 | fp16_opt_level: O1 53 | 54 | 55 | # preprocessing details 56 | row_selection: 'none' # ['none', 'random', 'ngram'] 57 | max_cell_num: None # maximum number of cells to include in the sequence 58 | 59 | max_words: 120 60 | max_words_per_header: 12 # maximum number of words in the header cell 61 | max_words_per_cell: 8 # ~ in the content cells 62 | max_cell_num_per_row: 64 # default to no restrictions 63 | header_delimiter: '|' 64 | cell_delimiter: '|' 65 | row_delimiter: '.' 
66 |
67 | max_sequence_length: 256
68 |
69 | structure_option: 'global' # [global, rowcol, auxemb, biased]

--------------------------------------------------------------------------------
/scripts/tune_model.sh:
--------------------------------------------------------------------------------
1 | # Tuning the model using curated samples under different settings
2 |
3 | set -euo pipefail
4 |
5 | if [[ -z ${ROOT_DIR:-} ]]; then
6 |     echo "\$ROOT_DIR environment variable needs to be set"
7 |     exit 1
8 | fi
9 |
10 | DATA="datasets"
11 | CTX="nq_table"
12 | CONF="biencoder_nq" # ("biencoder_nq" "biencoder_local" "biencoder_default")
13 | opt="global" # ("global" "rowcol" "auxemb" "biased")
14 |
15 | TRAIN_DATA=${ROOT_DIR}/${DATA}/${CTX}/"train.converted"
16 | DEV_DATA=${ROOT_DIR}/${DATA}/${CTX}/"dev.converted"
17 |
18 | MODEL=${ROOT_DIR}/"checkpoint/retriever/single-adv-hn/nq/bert-base-encoder.cp"
19 |
20 | cd ${ROOT_DIR}
21 |
22 | # 'global' OR 'rowcol'
23 | python train_biencoder.py \
24 |     model_file=${MODEL} \
25 |     train=${CONF} \
26 |     train_datasets=[${TRAIN_DATA}] \
27 |     dev_datasets=[${DEV_DATA}] \
28 |     output_dir=${ROOT_DIR}/"checkpoint" \
29 |     encoder.encoder_model_type="hf_bert" \
30 |     checkpoint_file_name=${CTX}"_"${opt} \
31 |     ignore_checkpoint_offset=True \
32 |     ignore_checkpoint_lr=True \
33 |     structure_option=${opt}
34 |
35 |
36 | # # 'auxemb'
37 | # python train_biencoder.py \
38 | #     model_file=${MODEL} \
39 | #     train=${CONF} \
40 | #     train_datasets=[${TRAIN_DATA}] \
41 | #     dev_datasets=[${DEV_DATA}] \
42 | #     output_dir=${ROOT_DIR}/"checkpoint" \
43 | #     encoder.encoder_model_type="hf_bert_mix" \
44 | #     checkpoint_file_name=${CTX}"_"${opt} \
45 | #     ignore_checkpoint_optimizer=True \
46 | #     ignore_checkpoint_offset=True \
47 | #     ignore_checkpoint_lr=True \
48 | #     structure_option=${opt}
49 |
50 | # # 'biased'
51 | # python train_biencoder.py \
52 | #     model_file=${MODEL} \
53 | #     train=${CONF} \
54 | #     train_datasets=[${TRAIN_DATA}] \
55 | #     dev_datasets=[${DEV_DATA}] \
56 | #     output_dir=${ROOT_DIR}/"checkpoint" \
57 | #     encoder.encoder_model_type="hf_bert_bias" \
58 | #     checkpoint_file_name=${CTX}"_"${opt} \
59 | #     ignore_checkpoint_optimizer=True \
60 | #     ignore_checkpoint_offset=True \
61 | #     ignore_checkpoint_lr=True \
62 | #     structure_option=${opt}

--------------------------------------------------------------------------------
/conf/biencoder_train.yaml:
--------------------------------------------------------------------------------
1 |
2 | # configuration groups
3 | defaults:
4 |   - encoder: hf_bert
5 |   - train: biencoder_nq
6 |   - datasets: table_train
7 |
8 | train_datasets: [nq_table_train]
9 | dev_datasets: [nq_table_dev]
10 | output_dir:
11 | train_sampling_rates:
12 | loss_scale_factors:
13 |
14 | # Whether to lower case the input text. Set True for uncased models, False for the cased ones.
15 | do_lower_case: True
16 |
17 | val_av_rank_start_epoch: 50 # 30
18 | seed: 12345
19 | checkpoint_file_name:
20 |
21 | # A trained bi-encoder checkpoint file to initialize the model
22 | model_file:
23 |
24 | # TODO: move to a conf group
25 | # local_rank for distributed training on gpus
26 |
27 | # TODO: rename to distributed_rank
28 | local_rank: -1
29 | global_loss_buf_sz: 592000
30 | device:
31 | distributed_world_size:
32 | distributed_port:
33 | distributed_init_method:
34 |
35 | no_cuda: False
36 | n_gpu:
37 | fp16: False
38 |
39 | # For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3'].
40 | # See details at https://nvidia.github.io/apex/amp.html
41 | fp16_opt_level: O1
42 |
43 | # tokens which won't be split by the tokenizer
44 | special_tokens:
45 |
46 | ignore_checkpoint_offset: False
47 | ignore_checkpoint_optimizer: False
48 | ignore_checkpoint_lr: False
49 |
50 | # set to >1 to enable multiple query encoders
51 | multi_q_encoder: False
52 |
53 | # Set to True to reduce the memory footprint and lose a bit of the full train-data randomization if you train in DDP mode
54 | local_shards_dataloader: False
55 |
56 |
57 | # preprocessing details
58 | row_selection: 'none' # ['none', 'random', 'ngram']
59 | max_cell_num: None # maximum number of cells to include in the sequence
60 |
61 | max_words: 120
62 | max_words_per_header: 12 # maximum number of words in the header cell
63 | max_words_per_cell: 8 # ~ in the content cells
64 | max_cell_num_per_row: 64 # default to no restrictions
65 | header_delimiter: '|'
66 | cell_delimiter: '|'
67 | row_delimiter: '.'
68 |
69 | max_sequence_length: 256
70 |
71 | structure_option: 'global' # [global, rowcol, auxemb, biased]

--------------------------------------------------------------------------------
/conf/dense_retrieval.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - encoder: hf_bert # defines encoder initialization parameters
3 |   - datasets: table_retrieval # contains a list of all possible sources of queries for evaluation. The specific set is selected by the qa_dataset parameter
4 |   - ctx_sources: table_sources # contains a list of all possible passage sources. Specific passage sources are selected by the ctx_datatsets parameter
5 |
6 | indexers:
7 |   flat:
8 |     _target_: dpr.indexer.faiss_indexers.DenseFlatIndexer
9 |
10 |   hnsw:
11 |     _target_: dpr.indexer.faiss_indexers.DenseHNSWFlatIndexer
12 |
13 |   hnsw_sq:
14 |     _target_: dpr.indexer.faiss_indexers.DenseHNSWSQIndexer
15 |
16 | # the name of the queries dataset from the 'datasets' config group
17 | qa_dataset:
18 |
19 | # a list of names of the passages datasets from the 'ctx_sources' config group
20 | ctx_datatsets:
21 |
22 | # Glob paths to encoded passages (from the generate_embeddings.py tool)
23 | encoded_ctx_files:
24 |
25 | out_file:
26 | # "regex" or "string"
27 | match: string
28 | n_docs: 100
29 | validation_workers: 16
30 |
31 | # Batch size to generate query embeddings
32 | batch_size: 128
33 |
34 | # Whether to lower case the input text. Set True for uncased models, False for the cased ones.
35 | do_lower_case: True
36 |
37 | # The attribute name of encoder to use for queries. Options for the BiEncoder model: question_model, ctx_model
38 | # question_model is used if this param is empty
39 | encoder_path:
40 |
41 | # path to the FAISS index location - it is only needed if you want to serialize the faiss index to files or read from them
42 | # (instead of using encoded_ctx_files)
43 | # it should point to either a directory or a common index files prefix name
44 | # if there is no index at the specified location, the index will be created from encoded_ctx_files
45 | index_path:
46 |
47 | kilt_out_file:
48 |
49 | # A trained bi-encoder checkpoint file to initialize the model
50 | model_file:
51 |
52 | validate_as_tables: False
53 |
54 | # RPC settings
55 | rpc_retriever_cfg_file:
56 | rpc_index_id:
57 | use_l2_conversion: False
58 | use_rpc_meta: False
59 | rpc_meta_compressed: False
60 |
61 | indexer: flat
62 |
63 | # tokens which won't be split by the tokenizer
64 | special_tokens:
65 |
66 | # TODO: move to a conf group
67 | # local_rank for distributed training on gpus
68 | local_rank: -1
69 | global_loss_buf_sz: 150000
70 | device:
71 | distributed_world_size:
72 | distributed_port:
73 | no_cuda: False
74 | n_gpu:
75 | fp16: False
76 |
77 | # For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3'].
78 | # See details at https://nvidia.github.io/apex/amp.html
79 | fp16_opt_level: O1
80 |
81 |
82 | # preprocessing details
83 | row_selection: 'none' # ['none', 'random', 'ngram']
84 | max_cell_num: None # maximum number of cells to include in the sequence
85 |
86 | max_words: 120
87 | max_words_per_header: 12 # maximum number of words in the header cell
88 | max_words_per_cell: 8 # ~ in the content cells
89 | max_cell_num_per_row: 64 # default to no restrictions
90 | header_delimiter: '|'
91 | cell_delimiter: '|'
92 | row_delimiter: '.'
93 |
94 | max_sequence_length: 256
95 |
96 | structure_option: 'global' # [global, rowcol, auxemb, biased]

--------------------------------------------------------------------------------
/dpr/utils/dist_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """
9 | Utilities for distributed model training
10 | """
11 |
12 | import pickle
13 |
14 | import torch
15 | import torch.distributed as dist
16 |
17 |
18 | def get_rank():
19 |     return dist.get_rank()
20 |
21 |
22 | def get_world_size():
23 |     return dist.get_world_size()
24 |
25 |
26 | def get_default_group():
27 |     return dist.group.WORLD
28 |
29 |
30 | def all_reduce(tensor, group=None):
31 |     if group is None:
32 |         group = get_default_group()
33 |     return dist.all_reduce(tensor, group=group)
34 |
35 |
36 | def all_gather_list(data, group=None, max_size=16384):
37 |     """Gathers arbitrary data from all nodes into a list.
38 |     Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
39 |     data. Note that *data* must be picklable.
40 | Args: 41 | data (Any): data from the local worker to be gathered on other workers 42 | group (optional): group of the collective 43 | """ 44 | SIZE_STORAGE_BYTES = 4 # int32 to encode the payload size 45 | 46 | enc = pickle.dumps(data) 47 | enc_size = len(enc) 48 | 49 | if enc_size + SIZE_STORAGE_BYTES > max_size: 50 | raise ValueError( 51 | 'encoded data exceeds max_size, this can be fixed by increasing buffer size: {}'.format(enc_size)) 52 | 53 | rank = get_rank() 54 | world_size = get_world_size() 55 | buffer_size = max_size * world_size 56 | 57 | if not hasattr(all_gather_list, '_buffer') or \ 58 | all_gather_list._buffer.numel() < buffer_size: 59 | all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size) 60 | all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory() 61 | 62 | buffer = all_gather_list._buffer 63 | buffer.zero_() 64 | cpu_buffer = all_gather_list._cpu_buffer 65 | 66 | assert enc_size < 256 ** SIZE_STORAGE_BYTES, 'Encoded object size should be less than {} bytes'.format( 67 | 256 ** SIZE_STORAGE_BYTES) 68 | 69 | size_bytes = enc_size.to_bytes(SIZE_STORAGE_BYTES, byteorder='big') 70 | 71 | cpu_buffer[0:SIZE_STORAGE_BYTES] = torch.ByteTensor(list(size_bytes)) 72 | cpu_buffer[SIZE_STORAGE_BYTES: enc_size + SIZE_STORAGE_BYTES] = torch.ByteTensor(list(enc)) 73 | 74 | start = rank * max_size 75 | size = enc_size + SIZE_STORAGE_BYTES 76 | buffer[start: start + size].copy_(cpu_buffer[:size]) 77 | 78 | all_reduce(buffer, group=group) 79 | 80 | try: 81 | result = [] 82 | for i in range(world_size): 83 | out_buffer = buffer[i * max_size: (i + 1) * max_size] 84 | size = int.from_bytes(out_buffer[0:SIZE_STORAGE_BYTES], byteorder='big') 85 | if size > 0: 86 | result.append(pickle.loads(bytes(out_buffer[SIZE_STORAGE_BYTES: size + SIZE_STORAGE_BYTES].tolist()))) 87 | return result 88 | except pickle.UnpicklingError: 89 | raise Exception( 90 | 'Unable to unpickle data from other workers. all_gather_list requires all ' 91 | 'workers to enter the function together, so this error usually indicates ' 92 | 'that the workers have fallen out of sync somehow. Workers can fall out of ' 93 | 'sync if one of them runs out of memory, or if there are other conditions ' 94 | 'in your training script that can cause one worker to finish an epoch ' 95 | 'while other workers are still iterating over their portions of the data.' 96 | ) 97 | -------------------------------------------------------------------------------- /convert_data.py: -------------------------------------------------------------------------------- 1 | """Convert the table format in the retrieved ctxs from text to table dict. 
""" 2 | 3 | import json 4 | import argparse 5 | 6 | from typing import List 7 | 8 | import logging 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | 13 | def load_tables_dict(tables_file: str, key: str = 'id'): 14 | tables_dict = {} 15 | with open(tables_file, 'r') as fr: 16 | for line in fr: 17 | tdict = json.loads(line.strip()) 18 | table_id = tdict[key] 19 | tables_dict[table_id] = tdict 20 | return tables_dict 21 | 22 | 23 | def get_annotated_table_ids(anno_path: str) -> List[str]: 24 | gold_ids = [] 25 | 26 | with open(anno_path, 'r') as fr: 27 | for line in fr: 28 | sample = json.loads(line.strip()) 29 | gold_ids.append(sample['table']['tableId']) 30 | 31 | return gold_ids 32 | 33 | 34 | def convert_ctxs( 35 | tables_dict, 36 | retrieved_path: str, 37 | converted_path: str, 38 | gold_table_ids: List[str] 39 | ): 40 | with open(retrieved_path, 'r') as fr: 41 | dataset = json.load(fr) 42 | 43 | newset = [] 44 | for i, sample in enumerate(dataset): 45 | pos_ctxs, neg_ctxs = [], [] 46 | 47 | if gold_table_ids: 48 | gold_tid = gold_table_ids[i] 49 | gold_table = tables_dict[gold_tid] 50 | 51 | table_ctx = { 52 | 'id': gold_tid, 53 | 'title': gold_table['title'], 54 | 'score': 1.0, 55 | 'has_answer': True, 56 | 'table': gold_table, 57 | } 58 | pos_ctxs.append(table_ctx) 59 | 60 | for j, ctx in enumerate(sample['ctxs']): 61 | table_id = ctx['id'] 62 | if ':' in table_id: 63 | table_id = table_id[table_id.index(':')+1: ] 64 | ctab = tables_dict[table_id] 65 | assert ctab is not None 66 | 67 | table_ctx = { 68 | 'id': ctx['id'], 69 | 'title': ctx['title'], 70 | 'score': ctx['score'], 71 | 'has_answer': ctx['has_answer'], 72 | 'table': ctab, 73 | } 74 | 75 | if table_ctx['has_answer']: 76 | pos_ctxs.append(table_ctx) 77 | else: 78 | neg_ctxs.append(table_ctx) 79 | 80 | num_hard_negatives = min(5, len(neg_ctxs)) 81 | new_sample = { 82 | 'question': sample['question'], 83 | 'answers': sample['answers'], 84 | 'positive_ctxs': pos_ctxs if pos_ctxs else neg_ctxs[:2], 85 | 'negative_ctxs': neg_ctxs[num_hard_negatives: ], 86 | 'hard_negative_ctxs': neg_ctxs[: num_hard_negatives], 87 | } 88 | newset.append(new_sample) 89 | 90 | with open(converted_path, 'w') as fw: 91 | json.dump(newset, fw) 92 | 93 | 94 | 95 | if __name__ == "__main__": 96 | parser = argparse.ArgumentParser() 97 | 98 | parser.add_argument('--tables_file', type=str, required=True) 99 | parser.add_argument('--retrieved_path', type=str, required=True) 100 | parser.add_argument('--converted_path', type=str, required=True) 101 | parser.add_argument('--annotated_path', type=str, default=None) 102 | 103 | args = parser.parse_args() 104 | 105 | tables_dict = load_tables_dict(args.tables_file) 106 | logger.info(f"loaded {len(tables_dict)} tables") 107 | 108 | gold_table_ids = None 109 | if args.annotated_path: 110 | gold_table_ids = get_annotated_table_ids(args.annotated_path) 111 | 112 | convert_ctxs(tables_dict, args.retrieved_path, args.converted_path, gold_table_ids) -------------------------------------------------------------------------------- /dpr/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 |
8 | import importlib
9 |
10 | """
11 | 'Router'-like set of methods for component initialization with lazy imports
12 | """
13 |
14 |
15 | def init_hf_bert_biencoder(args, **kwargs):
16 |     if importlib.util.find_spec("transformers") is None:
17 |         raise RuntimeError("Please install transformers lib")
18 |     from .hf_models import get_bert_biencoder_components
19 |
20 |     return get_bert_biencoder_components(args, **kwargs)
21 |
22 |
23 | def init_hf_bert_mix_biencoder(args, **kwargs):
24 |     if importlib.util.find_spec("transformers") is None:
25 |         raise RuntimeError("Please install transformers lib")
26 |     from .hf_models import get_bert_mix_biencoder_components
27 |
28 |     return get_bert_mix_biencoder_components(args, **kwargs)
29 |
30 |
31 | def init_hf_bert_bias_biencoder(args, **kwargs):
32 |     if importlib.util.find_spec("transformers") is None:
33 |         raise RuntimeError("Please install transformers lib")
34 |     from .hf_models import get_bert_bias_biencoder_components
35 |
36 |     return get_bert_bias_biencoder_components(args, **kwargs)
37 |
38 |
39 | def init_hf_bert_reader(args, **kwargs):
40 |     if importlib.util.find_spec("transformers") is None:
41 |         raise RuntimeError("Please install transformers lib")
42 |     from .hf_models import get_bert_reader_components
43 |
44 |     return get_bert_reader_components(args, **kwargs)
45 |
46 |
47 | # def init_pytext_bert_biencoder(args, **kwargs):
48 | #     if importlib.util.find_spec("pytext") is None:
49 | #         raise RuntimeError("Please install pytext lib")
50 | #     from .pytext_models import get_bert_biencoder_components
51 |
52 | #     return get_bert_biencoder_components(args, **kwargs)
53 |
54 |
55 | # def init_fairseq_roberta_biencoder(args, **kwargs):
56 | #     if importlib.util.find_spec("fairseq") is None:
57 | #         raise RuntimeError("Please install fairseq lib")
58 | #     from .fairseq_models import get_roberta_biencoder_components
59 |
60 | #     return get_roberta_biencoder_components(args, **kwargs)
61 |
62 |
63 | def init_hf_bert_tenzorizer(args, **kwargs):
64 |     if importlib.util.find_spec("transformers") is None:
65 |         raise RuntimeError("Please install transformers lib")
66 |     from .hf_models import get_bert_tensorizer
67 |
68 |     return get_bert_tensorizer(args)
69 |
70 |
71 | def init_hf_roberta_tenzorizer(args, **kwargs):
72 |     if importlib.util.find_spec("transformers") is None:
73 |         raise RuntimeError("Please install transformers lib")
74 |     from .hf_models import get_roberta_tensorizer
75 |     return get_roberta_tensorizer(args.encoder.pretrained_model_cfg, args.do_lower_case, args.encoder.sequence_length)
76 |
77 |
78 | BIENCODER_INITIALIZERS = {
79 |     "hf_bert": init_hf_bert_biencoder,
80 |     # "pytext_bert": init_pytext_bert_biencoder,
81 |     # "fairseq_roberta": init_fairseq_roberta_biencoder,
82 |     "hf_bert_mix": init_hf_bert_mix_biencoder,
83 |     "hf_bert_bias": init_hf_bert_bias_biencoder,
84 | }
85 |
86 | READER_INITIALIZERS = {
87 |     "hf_bert": init_hf_bert_reader,
88 | }
89 |
90 | TENSORIZER_INITIALIZERS = {
91 |     "hf_bert": init_hf_bert_tenzorizer,
92 |     "hf_roberta": init_hf_roberta_tenzorizer,
93 |     "pytext_bert": init_hf_bert_tenzorizer,  # using HF's code as of now
94 |     "fairseq_roberta": init_hf_roberta_tenzorizer,  # using HF's code as of now
95 | }
96 |
97 |
98 | def init_comp(initializers_dict, type, args, **kwargs):
99 |     if type in initializers_dict:
100 |         return initializers_dict[type](args, **kwargs)
101 |     else:
102 |         raise RuntimeError("unsupported model type: {}".format(type))
103 |
104 |
105 | def init_biencoder_components(encoder_type: str, args, **kwargs):
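    # Route the configured encoder_model_type to its lazy initializer above,
    # e.g. "hf_bert" (global/rowcol), "hf_bert_mix" (auxemb), or "hf_bert_bias" (biased).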
106 | return init_comp(BIENCODER_INITIALIZERS, encoder_type, args, **kwargs) 107 | 108 | 109 | def init_reader_components(encoder_type: str, args, **kwargs): 110 | return init_comp(READER_INITIALIZERS, encoder_type, args, **kwargs) 111 | 112 | 113 | def init_tenzorizer(encoder_type: str, args, **kwargs): 114 | return init_comp(TENSORIZER_INITIALIZERS, encoder_type, args, **kwargs) 115 | -------------------------------------------------------------------------------- /dpr/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import collections 9 | import glob 10 | import logging 11 | import os 12 | from typing import List, Tuple 13 | 14 | import torch 15 | from torch import nn 16 | from torch.optim.lr_scheduler import LambdaLR 17 | from torch.serialization import default_restore_location 18 | 19 | logger = logging.getLogger() 20 | 21 | CheckpointState = collections.namedtuple( 22 | "CheckpointState", 23 | [ 24 | "model_dict", 25 | "optimizer_dict", 26 | "scheduler_dict", 27 | "offset", 28 | "epoch", 29 | "encoder_params", 30 | ], 31 | ) 32 | 33 | 34 | def setup_for_distributed_mode( 35 | model: nn.Module, 36 | optimizer: torch.optim.Optimizer, 37 | device: object, 38 | n_gpu: int = 1, 39 | local_rank: int = -1, 40 | fp16: bool = False, 41 | fp16_opt_level: str = "O1", 42 | ) -> Tuple[nn.Module, torch.optim.Optimizer]: 43 | model.to(device) 44 | if fp16: 45 | try: 46 | import apex 47 | from apex import amp 48 | 49 | apex.amp.register_half_function(torch, "einsum") 50 | except ImportError: 51 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 52 | 53 | model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level) 54 | 55 | if n_gpu > 1: 56 | model = torch.nn.DataParallel(model) 57 | 58 | if local_rank != -1: 59 | model = torch.nn.parallel.DistributedDataParallel( 60 | model, 61 | device_ids=[device if device else local_rank], 62 | output_device=local_rank, 63 | find_unused_parameters=True, 64 | ) 65 | return model, optimizer 66 | 67 | 68 | def move_to_cuda(sample): 69 | if len(sample) == 0: 70 | return {} 71 | 72 | def _move_to_cuda(maybe_tensor): 73 | if torch.is_tensor(maybe_tensor): 74 | return maybe_tensor.cuda() 75 | elif isinstance(maybe_tensor, dict): 76 | return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()} 77 | elif isinstance(maybe_tensor, list): 78 | return [_move_to_cuda(x) for x in maybe_tensor] 79 | elif isinstance(maybe_tensor, tuple): 80 | return [_move_to_cuda(x) for x in maybe_tensor] 81 | else: 82 | return maybe_tensor 83 | 84 | return _move_to_cuda(sample) 85 | 86 | 87 | def move_to_device(sample, device): 88 | if len(sample) == 0: 89 | return {} 90 | 91 | def _move_to_device(maybe_tensor, device): 92 | if torch.is_tensor(maybe_tensor): 93 | return maybe_tensor.to(device) 94 | elif isinstance(maybe_tensor, dict): 95 | return {key: _move_to_device(value, device) for key, value in maybe_tensor.items()} 96 | elif isinstance(maybe_tensor, list): 97 | return [_move_to_device(x, device) for x in maybe_tensor] 98 | elif isinstance(maybe_tensor, tuple): 99 | return [_move_to_device(x, device) for x in maybe_tensor] 100 | else: 101 | return maybe_tensor 102 | 103 | return _move_to_device(sample, 
device) 104 | 105 | 106 | def get_schedule_linear( 107 | optimizer, 108 | warmup_steps, 109 | total_training_steps, 110 | steps_shift=0, 111 | last_epoch=-1, 112 | ): 113 | 114 | """Create a schedule with a learning rate that decreases linearly after 115 | linearly increasing during a warmup period. 116 | """ 117 | 118 | def lr_lambda(current_step): 119 | current_step += steps_shift 120 | if current_step < warmup_steps: 121 | return float(current_step) / float(max(1, warmup_steps)) 122 | return max( 123 | 1e-7, 124 | float(total_training_steps - current_step) / float(max(1, total_training_steps - warmup_steps)), 125 | ) 126 | 127 | return LambdaLR(optimizer, lr_lambda, last_epoch) 128 | 129 | 130 | def init_weights(modules: List): 131 | for module in modules: 132 | if isinstance(module, (nn.Linear, nn.Embedding)): 133 | module.weight.data.normal_(mean=0.0, std=0.02) 134 | elif isinstance(module, nn.LayerNorm): 135 | module.bias.data.zero_() 136 | module.weight.data.fill_(1.0) 137 | if isinstance(module, nn.Linear) and module.bias is not None: 138 | module.bias.data.zero_() 139 | 140 | 141 | def get_model_obj(model: nn.Module): 142 | return model.module if hasattr(model, "module") else model 143 | 144 | 145 | def get_model_file(args, file_prefix) -> str: 146 | if args.model_file and os.path.exists(args.model_file): 147 | return args.model_file 148 | 149 | out_cp_files = glob.glob(os.path.join(args.output_dir, file_prefix + "*")) if args.output_dir else [] 150 | logger.info("Checkpoint files %s", out_cp_files) 151 | model_file = None 152 | 153 | if len(out_cp_files) > 0: 154 | model_file = max(out_cp_files, key=os.path.getctime) 155 | return model_file 156 | 157 | 158 | def load_states_from_checkpoint(model_file: str) -> CheckpointState: 159 | logger.info("Reading saved model from %s", model_file) 160 | state_dict = torch.load(model_file, map_location=lambda s, l: default_restore_location(s, "cpu")) 161 | logger.info("model_state_dict keys %s", state_dict.keys()) 162 | return CheckpointState(**state_dict) 163 | -------------------------------------------------------------------------------- /generate_embeddings.py: -------------------------------------------------------------------------------- 1 | """Generate dense embeddings. 
""" 2 | 3 | import os 4 | import math 5 | import hydra 6 | import pickle 7 | import pathlib 8 | import logging 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | from typing import List, Tuple 14 | from omegaconf import DictConfig, OmegaConf 15 | from dpr.options import set_cfg_params_from_state, setup_cfg_gpu, setup_logger 16 | from dpr.utils.data_utils import Tensorizer 17 | from dpr.utils.model_utils import ( 18 | setup_for_distributed_mode, 19 | get_model_obj, 20 | load_states_from_checkpoint, 21 | move_to_device, 22 | ) 23 | from dpr.models import init_biencoder_components 24 | from dpr.data.biencoder_data import BiEncoderTable 25 | from dpr.data.table_data import prepare_table_ctx_inputs_batch 26 | 27 | 28 | logger = logging.getLogger() 29 | setup_logger(logger) 30 | 31 | 32 | 33 | def get_table_ctx_vectors( 34 | cfg: DictConfig, 35 | ctx_rows: List[Tuple[object, BiEncoderTable]], 36 | model: nn.Module, 37 | tensorizer: Tensorizer, 38 | insert_title: bool = True, 39 | ): 40 | """Encode table with context encoder under global/rowcol/auxemb settings.""" 41 | n = len(ctx_rows) 42 | bsz = cfg.batch_size 43 | total = 0 44 | results = [] 45 | for j, batch_start in enumerate(range(0, n, bsz)): 46 | batch = ctx_rows[batch_start: batch_start + bsz] 47 | input_tensors = prepare_table_ctx_inputs_batch( 48 | batch, 49 | tensorizer.tokenizer, 50 | cfg.structure_option, 51 | insert_title, 52 | cfg.max_sequence_length, 53 | ) 54 | ctx_ids_batch = move_to_device(input_tensors['token_ids'], cfg.device) 55 | ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch).long(), cfg.device) 56 | 57 | if cfg.structure_option == "rowcol": 58 | ctx_attn_mask = move_to_device(input_tensors['attn_mask'], cfg.device) 59 | else: 60 | ctx_attn_mask = move_to_device(tensorizer.get_attn_mask(ctx_ids_batch), cfg.device) 61 | 62 | if cfg.structure_option == 'auxemb': 63 | ctx_row_batch = move_to_device(input_tensors['row_ids'], cfg.device) 64 | ctx_col_batch = move_to_device(input_tensors['column_ids'], cfg.device) 65 | with torch.no_grad(): 66 | _, out, _ = model( 67 | input_ids=ctx_ids_batch, 68 | token_type_ids=ctx_seg_batch, 69 | attention_mask=ctx_attn_mask, 70 | row_ids=ctx_row_batch, 71 | column_ids=ctx_col_batch, 72 | ) 73 | elif cfg.structure_option == "biased": 74 | ctx_attn_mask = move_to_device(tensorizer.get_attn_mask(ctx_ids_batch), cfg.device) 75 | ctx_bias_mask_id = move_to_device(input_tensors['row_ids'], cfg.device) 76 | ctx_col_batch = move_to_device(input_tensors['column_ids'], cfg.device) 77 | with torch.no_grad(): 78 | _, out, _ = model( 79 | input_ids=ctx_ids_batch, 80 | token_type_ids=ctx_seg_batch, 81 | attention_mask=ctx_attn_mask, 82 | row_ids=ctx_bias_mask_id, 83 | column_ids=ctx_col_batch, 84 | ) 85 | else: 86 | with torch.no_grad(): 87 | _, out, _ = model( 88 | input_ids=ctx_ids_batch, 89 | token_type_ids=ctx_seg_batch, 90 | attention_mask=ctx_attn_mask, 91 | ) 92 | out = out.cpu() 93 | 94 | ctx_ids= [r[0] for r in batch] 95 | extra_info = [] 96 | if len(batch[0]) > 3: extra_info = [r[3:] for r in batch] 97 | assert len(ctx_ids) == out.size(0) 98 | total += len(ctx_ids) 99 | 100 | if extra_info: 101 | results.extend([(ctx_ids[i], out[i].view(-1).numpy(), *extra_info[i]) for i in range(out.size(0))]) 102 | else: 103 | results.extend([(ctx_ids[i], out[i].view(-1).numpy()) for i in range(out.size(0))]) 104 | 105 | if total % 10 == 0: logger.info("Encoded passages %d", total) 106 | 107 | return results 108 | 109 | 110 | 111 | 112 | @hydra.main(config_path="conf", 
config_name="gen_embs") 113 | def main(cfg: DictConfig): 114 | 115 | assert cfg.model_file, "Please specify encoder checkpoint as model_file param" 116 | assert cfg.ctx_src, "Please specify passages source as ctx_src param" 117 | 118 | cfg = setup_cfg_gpu(cfg) 119 | 120 | saved_state = load_states_from_checkpoint(cfg.model_file) 121 | set_cfg_params_from_state(saved_state.encoder_params, cfg) 122 | 123 | logger.info("CFG:") 124 | logger.info("%s", OmegaConf.to_yaml(cfg)) 125 | 126 | tensorizer, encoder, _ = init_biencoder_components(cfg.encoder.encoder_model_type, cfg, inference_only=True) 127 | 128 | encoder = encoder.ctx_model if cfg.encoder_type == "ctx" else encoder.question_model 129 | 130 | encoder, _ = setup_for_distributed_mode( 131 | encoder, None, cfg.device, cfg.n_gpu, 132 | cfg.local_rank, cfg.fp16, cfg.fp16_opt_level, 133 | ) 134 | encoder.eval() 135 | 136 | # load weights from the model file 137 | model_to_load = get_model_obj(encoder) 138 | logger.info("Loading saved model state ...") 139 | logger.debug("saved model keys =%s", saved_state.model_dict.keys()) 140 | 141 | prefix_len = len("ctx_model.") 142 | ctx_state = { 143 | key[prefix_len:]: value for (key, value) in saved_state.model_dict.items() if key.startswith("ctx_model.") 144 | } 145 | model_to_load.load_state_dict(ctx_state, strict=False) 146 | 147 | # load from table data sources 148 | logger.info("reading data source: %s", cfg.ctx_src) 149 | 150 | ctx_src = hydra.utils.instantiate(cfg.ctx_sources[cfg.ctx_src]) 151 | all_passages_dict = {} 152 | ctx_src.load_data_to(all_passages_dict, cfg) 153 | all_passages = [(k, v) for k, v in all_passages_dict.items()] 154 | 155 | shard_size = math.ceil(len(all_passages) / cfg.num_shards) 156 | start_idx = cfg.shard_id * shard_size 157 | end_idx = start_idx + shard_size 158 | 159 | logger.info( 160 | "Producing encodings for passages range: %d to %d (out of total %d)", 161 | start_idx, end_idx, len(all_passages), 162 | ) 163 | shard_passages = all_passages[start_idx:end_idx] 164 | 165 | data = get_table_ctx_vectors(cfg, shard_passages, encoder, tensorizer, insert_title=True) 166 | 167 | file = cfg.out_file + "_" + str(cfg.shard_id) 168 | pathlib.Path(os.path.dirname(file)).mkdir(parents=True, exist_ok=True) 169 | logger.info("Writing results to %s" % file) 170 | with open(file, mode="wb") as f: 171 | pickle.dump(data, f) 172 | 173 | logger.info("Total passages processed %d. Written to %s", len(data), file) 174 | 175 | 176 | 177 | if __name__ == "__main__": 178 | main() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Open-Domain Table Retrieval for Natural Questions 2 | 3 | MIT License 4 | 5 | This repository involves the data and code for the paper: 6 | 7 | [Table Retrieval May Not Necessitate Table-specific Model Design](https://arxiv.org/pdf/2205.09843.pdf) 8 | 9 | 10 | ## Preliminaries 11 | To install the necessary libraries, run 12 | ``` 13 | pip install . 14 | ``` 15 | You'll also need the pre-trained DPR model checkpoint for (1) evaluating its zero-shot performance, and (2) initialize the model instance before start the fine-tuning. 
24 |
25 | Remember to set the ROOT_DIR before running the given scripts:
26 | ```
27 | export ROOT_DIR=`pwd`
28 | ```
29 |
30 |
31 | ## Zero-shot Retrieval Inference
32 | To perform zero-shot retrieval inference on the NQ-Table table retrieval dataset, we first generate the embeddings of all tables (`generate_embeddings.py`), and then encode the questions/queries on the fly and search for the most relevant tables (`dense_retrieval.py`).
33 |
34 | To automate the entire pipeline, just execute `scripts/run_inference.sh`. This will iterate over each dataset to generate context embeddings and run inference accordingly.
35 |
36 | If you want a more concrete walk-through of each module, we detail them as follows:
37 |
38 | **Generate Context Embeddings**
39 |
40 | Different table contexts, which are located in different files and may need to be loaded with different classes, are specified in the `conf/ctx_sources/table_sources.yaml` file. One can alter the `ctx_src` argument when calling the `generate_embeddings.py` script.
41 | By default, we use NQ-Table, which is denoted as `nq_table`.
42 |
43 | For example, to generate embeddings of NQ-Table, run:
44 | ```
45 | python generate_embeddings.py \
46 |     model_file=${your_model_checkpoint} \
47 |     ctx_src=nq_table \
48 |     out_file=${your_path_to_store_embeddings}
49 | ```
50 |
51 |
52 | **Retrieve Relevant Tables for Questions/Queries**
53 |
54 | In this step, we need to specify the file(s) containing questions so as to pair relevant tables with them using the generated embeddings.
55 | Likewise, these files are included in the `conf/datasets/table_retrieval.yaml` file. One can alter the `qa_dataset` argument to load different questions when calling `dense_retrieval.py`.
56 |
57 | The train/dev/test sets of NQ-Table are indicated by `nq_table_train`/`nq_table_dev`/`nq_table_test`.
58 |
59 | For example, to run retrieval inference on NQ-Table test questions, run:
60 | ```
61 | python dense_retrieval.py \
62 |     model_file=${your_model_checkpoint} \
63 |     ctx_datatsets=[nq_table] \
64 |     encoded_ctx_files=[${your_path_to_store_embeddings}"_0"] \
65 |     qa_dataset=nq_table_test \
66 |     out_file=${your_path_for_the_retrieval_result.json}
67 | ```
68 |
69 |
70 | ## Fine-tune with Model Variants
71 |
72 | **Settings**
73 |
74 | DPR neither has table-specific designs nor has been trained on tables.
75 | We further explore the benefit of (1) augmented fine-tuning, and (2) adding auxiliary structure-aware modules.
76 |
77 | The first and naive version of fine-tuning feeds models with serialized table content and applies **no** model modifications. We denote this as the `global` setting (since it applies global attention, as opposed to the structurally restricted variants).
78 |
79 | The other three fine-tuning settings adopt the major methods to incorporate table structures:
80 | 1. Adding auxiliary embeddings, specifically for row and column indices. We denote this as the `auxemb` setting.
81 | 2. Applying structure-aware attention, by enforcing tokens to be visible in-row or in-column. This is denoted as the `rowcol` setting.
82 | 3. Adding relation-based attention bias onto the global self-attention scores, denoted as `biased`.
83 |
84 | Across experiments, one can switch between these settings by specifying the `structure_option` argument; a sketch of the `auxemb` idea follows below.
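As a rough illustration of the `auxemb` idea, row/column index embeddings are simply added on top of the token embeddings before the transformer layers. This is a minimal sketch, not the code in `dpr/models`; the class and sizes below (`AuxEmbeddings`, `max_rows`, `max_cols`) are invented for the example:
```
import torch
import torch.nn as nn

class AuxEmbeddings(nn.Module):
    """Add row/column index embeddings on top of BERT-style token embeddings."""
    def __init__(self, hidden_size: int = 768, max_rows: int = 256, max_cols: int = 64):
        super().__init__()
        self.row_emb = nn.Embedding(max_rows, hidden_size)  # index 0 ~ "not in a table cell"
        self.col_emb = nn.Embedding(max_cols, hidden_size)

    def forward(self, token_emb, row_ids, column_ids):
        # token_emb: [batch, seq_len, hidden]; row_ids/column_ids: [batch, seq_len]
        return token_emb + self.row_emb(row_ids) + self.col_emb(column_ids)

# toy usage: a batch of 2 sequences of length 5
aux = AuxEmbeddings()
token_emb = torch.randn(2, 5, 768)
row_ids = torch.tensor([[0, 1, 1, 2, 2], [0, 1, 2, 1, 2]])
col_ids = torch.tensor([[0, 1, 2, 1, 2], [0, 1, 1, 2, 2]])
print(aux(token_emb, row_ids, col_ids).shape)  # torch.Size([2, 5, 768])
```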
85 | For `auxemb` and `biased`, which require extra parameters (hence a change in model architecture), alter the encoder type by additionally specifying
86 | ```
87 | encoder.encoder_model_type="hf_bert_mix" # or "hf_bert_bias"
88 | ```
89 |
90 |
91 | **Creating Training (and Validation) Dataset**
92 |
93 | To obtain the most effective training data, we follow the hard-negative selection strategy and leverage the retrieval results for sample curation. More concretely, for trainable datasets (NQ-Table and WebQueryTable), we first run zero-shot retrieval for the training and validation samples. Then, for each question and its 100 retrieved table contexts, we categorize each context as (1) positive, (2) negative, or (3) hard negative: it goes into (1) if it contains the answer text, and into (2)/(3) otherwise. If the context ranks among the top-20, it goes into (3); otherwise it is a rather easy negative context and goes into (2).
94 |
95 | To implement this, we also need to run `dense_retrieval.py` inference using the generated table context embeddings.
96 | Then, convert the retrieval result into training format using `convert_data.py`.
97 | ```
98 | python convert_data.py \
99 |     --tables_file=${raw_tables_path} \
100 |     --retrieved_path=${retrieved_result} \
101 |     --converted_path=${converted_training_data}
102 | ```
103 | Remember to do this for both your training and validation samples.
104 |
105 | One can also automate this process by running `scripts/curate_data.sh`.
106 |
107 |
108 | **Bi-Encoder Training**
109 |
110 | With the curated datasets, we can then start fine-tuning using `train_biencoder.py`. Viable training options are read from `conf/biencoder_train.yaml`.
111 |
112 | ```
113 | python train_biencoder.py \
114 |     train=biencoder_nq \
115 |     train_datasets=[${your_train_data_file}] \
116 |     dev_datasets=[${your_dev_data_file}] \
117 |     output_dir=${directory_for_checkpoints} \
118 |     checkpoint_file_name=${model_name}
119 | ```
120 | Alter the arguments in `conf/biencoder_train.yaml` to train in different experimental settings.
121 |
122 | Or simply, just run `scripts/tune_model.sh`.
123 |
124 |
125 | **Evaluate with Tuned Models**
126 |
127 | This works similarly to the zero-shot inference, but you need to specify the fine-tuned model file, as well as new names for the embedding and retrieval results. See `scripts/run_inference.sh` for more details.
128 |
129 |
130 |
131 | ## Ablation Study (NQ-Table)
132 | **Delimiter**
133 |
134 | Use `process_table.py` to create table contexts linearized using different delimiters.
135 | Alter the arguments `header_delimiter`, `cell_delimiter`, and `row_delimiter` to compare.
136 |
137 | **Structure Perturbation**
138 |
139 | Use `process_table.py` to create processed tables shuffled in different orientations (by row, by column) and to the designated extent (`prob` defaults to 0.5). This will create `datasets/nq_table/tables_row.jsonl` and `datasets/nq_table/tables_column.jsonl`.
140 |
141 |
142 | ## Citation
143 |
144 | ```
145 | @article{wang2022table,
146 |     title={Table Retrieval May Not Necessitate Table-specific Model Design},
147 |     author={Wang, Zhiruo and Jiang, Zhengbao and Nyberg, Eric and Neubig, Graham},
148 |     journal={arXiv preprint arXiv:2205.09843},
149 |     year={2022}
150 | }
151 | ```

--------------------------------------------------------------------------------
/dpr/options.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Command line arguments utils 10 | """ 11 | 12 | 13 | import logging 14 | import os 15 | import random 16 | import socket 17 | import subprocess 18 | from typing import Tuple 19 | 20 | import numpy as np 21 | import torch 22 | from omegaconf import DictConfig 23 | 24 | logger = logging.getLogger() 25 | 26 | # TODO: to be merged with conf_utils.py 27 | 28 | 29 | def set_cfg_params_from_state(state: dict, cfg: DictConfig): 30 | """ 31 | Overrides some of the encoder config parameters from a give state object 32 | """ 33 | if not state: 34 | return 35 | 36 | cfg.do_lower_case = state["do_lower_case"] 37 | 38 | if "encoder" in state: 39 | saved_encoder_params = state["encoder"] 40 | # TODO: try to understand why cfg.encoder = state["encoder"] doesn't work 41 | 42 | for k, v in saved_encoder_params.items(): 43 | 44 | # TODO: tmp fix 45 | if k == "q_wav2vec_model_cfg": 46 | k = "q_encoder_model_cfg" 47 | if k == "q_wav2vec_cp_file": 48 | k = "q_encoder_cp_file" 49 | if k == "q_wav2vec_cp_file": 50 | k = "q_encoder_cp_file" 51 | 52 | setattr(cfg.encoder, k, v) 53 | else: # 'old' checkpoints backward compatibility support 54 | pass 55 | # cfg.encoder.pretrained_model_cfg = state["pretrained_model_cfg"] 56 | # cfg.encoder.encoder_model_type = state["encoder_model_type"] 57 | # cfg.encoder.pretrained_file = state["pretrained_file"] 58 | # cfg.encoder.projection_dim = state["projection_dim"] 59 | # cfg.encoder.sequence_length = state["sequence_length"] 60 | 61 | 62 | def get_encoder_params_state_from_cfg(cfg: DictConfig): 63 | """ 64 | Selects the param values to be saved in a checkpoint, so that a trained model can be used for downstream 65 | tasks without the need to specify these parameter again 66 | :return: Dict of params to memorize in a checkpoint 67 | """ 68 | return { 69 | "do_lower_case": cfg.do_lower_case, 70 | "encoder": cfg.encoder, 71 | } 72 | 73 | 74 | def set_seed(args): 75 | seed = args.seed 76 | random.seed(seed) 77 | np.random.seed(seed) 78 | torch.manual_seed(seed) 79 | if args.n_gpu > 0: 80 | torch.cuda.manual_seed_all(seed) 81 | 82 | 83 | def setup_cfg_gpu(cfg): 84 | """ 85 | Setup params for CUDA, GPU & distributed training 86 | """ 87 | logger.info("CFG's local_rank=%s", cfg.local_rank) 88 | ws = os.environ.get("WORLD_SIZE") 89 | cfg.distributed_world_size = int(ws) if ws else 1 90 | logger.info("Env WORLD_SIZE=%s", ws) 91 | 92 | if cfg.distributed_port and cfg.distributed_port > 0: 93 | logger.info("distributed_port is specified, trying to init distributed mode from SLURM params ...") 94 | init_method, local_rank, world_size, device = _infer_slurm_init(cfg) 95 | 96 | logger.info( 97 | "Inferred params from SLURM: init_method=%s | local_rank=%s | world_size=%s", 98 | init_method, 99 | local_rank, 100 | world_size, 101 | ) 102 | 103 | cfg.local_rank = local_rank 104 | cfg.distributed_world_size = world_size 105 | cfg.n_gpu = 1 106 | 107 | torch.cuda.set_device(device) 108 | device = str(torch.device("cuda", device)) 109 | 110 | torch.distributed.init_process_group( 111 | backend="nccl", init_method=init_method, world_size=world_size, rank=local_rank 112 | ) 113 | 114 | elif cfg.local_rank == -1 or cfg.no_cuda: # single-node multi-gpu (or cpu) mode 115 | device = str(torch.device("cuda" if torch.cuda.is_available() and not cfg.no_cuda else "cpu")) 116 | cfg.n_gpu = torch.cuda.device_count() 117 | else: # 
distributed mode 118 | torch.cuda.set_device(cfg.local_rank) 119 | device = str(torch.device("cuda", cfg.local_rank)) 120 | torch.distributed.init_process_group(backend="nccl") 121 | cfg.n_gpu = 1 122 | 123 | cfg.device = device 124 | 125 | logger.info( 126 | "Initialized host %s as d.rank %d on device=%s, n_gpu=%d, world size=%d", 127 | socket.gethostname(), 128 | cfg.local_rank, 129 | cfg.device, 130 | cfg.n_gpu, 131 | cfg.distributed_world_size, 132 | ) 133 | logger.info("16-bits training: %s ", cfg.fp16) 134 | return cfg 135 | 136 | 137 | def _infer_slurm_init(cfg) -> Tuple[str, int, int, int]: 138 | 139 | node_list = os.environ.get("SLURM_STEP_NODELIST") 140 | if node_list is None: 141 | node_list = os.environ.get("SLURM_JOB_NODELIST") 142 | logger.info("SLURM_JOB_NODELIST: %s", node_list) 143 | 144 | if node_list is None: 145 | raise RuntimeError("Can't find SLURM node_list from env parameters") 146 | 147 | local_rank = None 148 | world_size = None 149 | distributed_init_method = None 150 | device_id = None 151 | try: 152 | hostnames = subprocess.check_output(["scontrol", "show", "hostnames", node_list]) 153 | distributed_init_method = "tcp://{host}:{port}".format( 154 | host=hostnames.split()[0].decode("utf-8"), 155 | port=cfg.distributed_port, 156 | ) 157 | nnodes = int(os.environ.get("SLURM_NNODES")) 158 | logger.info("SLURM_NNODES: %s", nnodes) 159 | ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE") 160 | if ntasks_per_node is not None: 161 | ntasks_per_node = int(ntasks_per_node) 162 | logger.info("SLURM_NTASKS_PER_NODE: %s", ntasks_per_node) 163 | else: 164 | ntasks = int(os.environ.get("SLURM_NTASKS")) 165 | logger.info("SLURM_NTASKS: %s", ntasks) 166 | assert ntasks % nnodes == 0 167 | ntasks_per_node = int(ntasks / nnodes) 168 | 169 | if ntasks_per_node == 1: 170 | gpus_per_node = torch.cuda.device_count() 171 | node_id = int(os.environ.get("SLURM_NODEID")) 172 | local_rank = node_id * gpus_per_node 173 | world_size = nnodes * gpus_per_node 174 | logger.info("node_id: %s", node_id) 175 | else: 176 | world_size = ntasks_per_node * nnodes 177 | proc_id = os.environ.get("SLURM_PROCID") 178 | local_id = os.environ.get("SLURM_LOCALID") 179 | logger.info("SLURM_PROCID %s", proc_id) 180 | logger.info("SLURM_LOCALID %s", local_id) 181 | local_rank = int(proc_id) 182 | device_id = int(local_id) 183 | 184 | except subprocess.CalledProcessError as e: # scontrol failed 185 | raise e 186 | except FileNotFoundError: # Slurm is not installed 187 | pass 188 | return distributed_init_method, local_rank, world_size, device_id 189 | 190 | 191 | def setup_logger(logger): 192 | logger.setLevel(logging.INFO) 193 | if logger.hasHandlers(): 194 | logger.handlers.clear() 195 | log_formatter = logging.Formatter("[%(thread)s] %(asctime)s [%(levelname)s] %(name)s: %(message)s") 196 | console = logging.StreamHandler() 197 | console.setFormatter(log_formatter) 198 | logger.addHandler(console) 199 | -------------------------------------------------------------------------------- /process_table.py: -------------------------------------------------------------------------------- 1 | """Preprocess the NQ-Table. 
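usage: python process_table.py {nq|perturb|delim}
src: jsonl file with one raw table per line, where each table's 'rows' hold lists of 'cells'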
2 | dst: {'id': str, 'title': str, 'cells': [{'text': str, 'row_idx': int, 'col_idx': int}, ...]}
3 | """
4 | 
5 | import sys
6 | import json
7 | import random
8 | 
9 | from typing import Dict
10 | 
11 | from dpr.data.biencoder_data import get_processed_table_dict
12 | 
13 | 
14 | # %% ablation: shuffling the content, by a given extent `prob`
15 | 
16 | def shuffle_nq_table_by_row(table: Dict, prob: float = 0.5):
17 |     num_rows = len(table['rows'])
18 |     for i in range(num_rows):
19 |         p = random.random()
20 |         if (p < prob):
21 |             random.shuffle(table['rows'][i]['cells'])
22 |     return table
23 | 
24 | 
25 | def shuffle_nq_table_by_column(table: Dict, prob: float = 0.5):
26 |     num_rows = len(table['rows'])
27 | 
28 |     column_sizes = [len(row['cells']) for row in table['rows']]
29 |     num_columns = min(column_sizes)
30 | 
31 |     for j in range(num_columns):
32 |         p = random.random()
33 |         if (p < prob):
34 |             # collect cells in the j-th column
35 |             cells = [row['cells'][j] for row in table['rows']]
36 |             random.shuffle(cells)
37 |             for i in range(num_rows):
38 |                 table['rows'][i]['cells'][j] = cells[i]
39 |     return table
40 | 
41 | 
42 | def preprocess_nq_tables_perturbed(
43 |     original_tables_path: str,
44 |     processed_tables_path: str,
45 |     mode: int,  # 0: by row, 1: by column, 2: both
46 |     prob: float = 0.5,
47 | ):
48 |     fr = open(original_tables_path, 'r')
49 |     fw = open(processed_tables_path, 'w')
50 | 
51 |     for idx, orig_line in enumerate(fr):
52 |         # if (idx > 10): break
53 |         orig_table = json.loads(orig_line.strip())
54 |         if (mode % 2) == 0:  # modes 0 and 2 shuffle rows
55 |             orig_table = shuffle_nq_table_by_row(orig_table, prob)
56 |         if mode >= 1:  # modes 1 and 2 shuffle columns
57 |             orig_table = shuffle_nq_table_by_column(orig_table, prob)
58 | 
59 |         proc_table = get_processed_table_dict(
60 |             orig_table,
61 |             row_selection='none',
62 |             max_cell_num=120,
63 |             max_words=120,
64 |             max_words_per_header=10,
65 |             max_words_per_cell=8,
66 |             max_cell_num_per_row=64,
67 |             header_delimiter='|',
68 |             cell_delimiter='|',
69 |             row_delimiter='.',
70 |             return_dict=True,
71 |         )
72 |         proc_line = json.dumps(proc_table)
73 |         fw.write(f"{proc_line}\n")
74 | 
75 |     fr.close()
76 |     fw.close()
77 | 
78 | 
79 | 
80 | # %% normal pre-process
81 | # ablation: header/cell and row delimiters
82 | 
83 | def preprocess_nq_tables(
84 |     original_tables_path: str,
85 |     processed_tables_path: str,
86 |     keep_all: bool = False,
87 |     cell_delim: bool = True,
88 |     row_delim: bool = True,
89 | ):
90 |     if cell_delim:
91 |         cell_delimiter = '|'
92 |         header_delimiter = '|'
93 |     else:
94 |         cell_delimiter = ''
95 |         header_delimiter = ''
96 | 
97 |     if row_delim:
98 |         row_delimiter = '.'
99 | else: 100 | row_delimiter = '' 101 | 102 | fr = open(original_tables_path, 'r') 103 | fw = open(processed_tables_path, 'w') 104 | 105 | for idx, orig_line in enumerate(fr): 106 | # if (idx > 10): break 107 | orig_table = json.loads(orig_line.strip()) 108 | if keep_all: 109 | proc_table = get_processed_table_dict( 110 | orig_table, 111 | row_selection='none', 112 | max_cell_num=100000, 113 | max_words=100000, 114 | max_words_per_header=100, 115 | max_words_per_cell=100, 116 | max_cell_num_per_row=1000, 117 | header_delimiter='', 118 | cell_delimiter='', 119 | row_delimiter='', 120 | return_dict=True, 121 | ) 122 | else: 123 | proc_table = get_processed_table_dict( 124 | orig_table, 125 | row_selection='random', 126 | max_cell_num=None, 127 | max_words=112, 128 | max_words_per_header=12, 129 | max_words_per_cell=8, 130 | max_cell_num_per_row=64, 131 | header_delimiter=header_delimiter, 132 | cell_delimiter=cell_delimiter, 133 | row_delimiter=row_delimiter, 134 | return_dict=True, 135 | ) 136 | proc_line = json.dumps(proc_table) 137 | fw.write(f"{proc_line}\n") 138 | 139 | fr.close() 140 | fw.close() 141 | 142 | 143 | 144 | # %% main 145 | 146 | if __name__ == "__main__": 147 | 148 | exp = sys.argv[1] 149 | 150 | if exp == "nq": # 1. nq-table 151 | preprocess_nq_tables( 152 | original_tables_path="/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables.jsonl", 153 | processed_tables_path="/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables_all.jsonl", 154 | keep_all=True, 155 | ) 156 | preprocess_nq_tables( 157 | original_tables_path="/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables.jsonl", 158 | processed_tables_path="/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables_proc_2.jsonl", 159 | keep_all=False, 160 | ) 161 | elif exp == 'perturb': # 2. ablation (perturb) 162 | preprocess_nq_tables_perturbed( 163 | original_tables_path="/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables.jsonl", 164 | processed_tables_path="/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables_row.jsonl", 165 | mode=0, 166 | prob=1.0, 167 | ) 168 | preprocess_nq_tables_perturbed( 169 | original_tables_path="/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables.jsonl", 170 | processed_tables_path="/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables_column.jsonl", 171 | mode=1, 172 | prob=1.0, 173 | ) 174 | preprocess_nq_tables_perturbed( 175 | original_tables_path="/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables.jsonl", 176 | processed_tables_path="/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables_both.jsonl", 177 | mode=2, 178 | prob=1.0, 179 | ) 180 | elif exp == 'delim': # 3. 
ablation (delimiter)
181 |         preprocess_nq_tables(
182 |             original_tables_path="/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables.jsonl",
183 |             processed_tables_path="/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables_dcell.jsonl",
184 |             keep_all=False,
185 |             cell_delim=True,
186 |             row_delim=False,
187 |         )
188 |         preprocess_nq_tables(
189 |             original_tables_path="/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables.jsonl",
190 |             processed_tables_path="/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables_drow.jsonl",
191 |             keep_all=False,
192 |             cell_delim=False,
193 |             row_delim=True,
194 |         )
195 |         preprocess_nq_tables(
196 |             original_tables_path="/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables.jsonl",
197 |             processed_tables_path="/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables_dnone.jsonl",
198 |             keep_all=False,
199 |             cell_delim=False,
200 |             row_delim=False,
201 |         )
202 | 
--------------------------------------------------------------------------------
/dpr/data/table_data.py:
--------------------------------------------------------------------------------
1 | """Table data processing functions. """
2 | 
3 | 
4 | import torch
5 | from typing import Dict, List, Tuple
6 | from .biencoder_data import BiEncoderTable
7 | 
8 | import logging
9 | logger = logging.getLogger(__name__)
10 | 
11 | import transformers
12 | if transformers.__version__.startswith("4"): from transformers import BertTokenizer
13 | else: from transformers.tokenization_bert import BertTokenizer
14 | 
15 | 
16 | def to_max_len(seq_list: List[int], pad_id: int, max_len: int):
17 |     if len(seq_list) < max_len:
18 |         seq_list.extend([pad_id for _ in range(max_len-len(seq_list))])
19 |     seq_list = seq_list[: max_len]
20 |     return seq_list
21 | 
22 | 
23 | 
24 | def create_global_attn_mask(max_seq_length: int) -> torch.LongTensor:
25 |     mask = torch.ones(max_seq_length, max_seq_length).long()
26 |     return mask
27 | 
28 | 
29 | def create_rowcol_attn_mask(
30 |     row_ids: List[int],
31 |     column_ids: List[int],
32 |     title_len: int,
33 | ) -> torch.LongTensor:
34 |     # tokens within the same row OR column are mutually visible
35 |     mask = (row_ids == row_ids.transpose(0, 1)) | (column_ids == column_ids.transpose(0, 1))
36 |     mask = mask.long()
37 |     # title tokens are globally visible
38 |     mask[: title_len, :] = 1
39 |     mask[:, : title_len] = 1
40 |     return mask
41 | 
42 | 
43 | 
44 | def prepare_table_ctx_inputs(
45 |     ctx: BiEncoderTable,
46 |     tokenizer: BertTokenizer,
47 |     structure_option: str = "global",
48 |     insert_title: bool = True,
49 |     max_seq_length: int = 256,
50 | ):
51 |     """Tokenize a single table.
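 
    Args:
        ctx: table context, either a BiEncoderTable or a dict with
            'title' and 'cells' entries.
        structure_option: 'global' (full attention), 'rowcol' (tokens
            attend within the same row or column, with title tokens
            globally visible), or 'biased' (row_ids carry relation-based
            bias ids; see create_biased_id below).
    Returns:
        dict with 'token_ids', 'attn_mask', 'row_ids', and 'column_ids'
        tensors, all sized by max_seq_length.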
""" 52 | token_ids = [] 53 | row_ids, col_ids = [], [] 54 | 55 | if insert_title: 56 | if hasattr(ctx, "title"): 57 | text = ctx.title.strip() 58 | else: 59 | text = ctx['title'].strip() 60 | title_token_ids = tokenizer.encode( 61 | text, 62 | add_special_tokens=True, 63 | max_length=max_seq_length, 64 | truncation=True, 65 | pad_to_max_length=False 66 | ) 67 | title_token_ids = title_token_ids[1: ] # remove [CLS] 68 | title_len = len(title_token_ids) 69 | token_ids.extend(title_token_ids) 70 | row_ids.extend([0 for _ in title_token_ids]) 71 | col_ids.extend([0 for _ in title_token_ids]) 72 | 73 | if hasattr(ctx, "cells"): 74 | cell_list = ctx.cells 75 | else: 76 | cell_list = ctx['cells'] 77 | for cell in cell_list: 78 | text = cell['text'].strip() 79 | cell_token_ids = tokenizer.encode( 80 | text, 81 | add_special_tokens=False, 82 | max_length=max_seq_length, 83 | truncation=True, 84 | pad_to_max_length=False 85 | ) 86 | token_ids.extend(cell_token_ids) 87 | row_ids.extend([cell['row_idx'] for _ in cell_token_ids]) 88 | col_ids.extend([cell['col_idx'] for _ in cell_token_ids]) 89 | 90 | assert len(token_ids) == len(row_ids) == len(col_ids) 91 | valid_len = min(len(token_ids), max_seq_length) 92 | 93 | token_ids = to_max_len(token_ids, tokenizer.pad_token_id, max_seq_length) 94 | token_ids[-1] = tokenizer.sep_token_id 95 | token_ids = torch.LongTensor(token_ids) 96 | row_ids = torch.LongTensor(to_max_len(row_ids, 0, max_seq_length)) 97 | col_ids = torch.LongTensor(to_max_len(col_ids, 0, max_seq_length)) # [max-len] 98 | 99 | if structure_option == "rowcol": 100 | attn_mask = create_rowcol_attn_mask(row_ids.unsqueeze(0), col_ids.unsqueeze(0), title_len) 101 | else: 102 | attn_mask = create_global_attn_mask(max_seq_length) 103 | 104 | # set pad positions to invisible 105 | attn_mask[valid_len: , :] = 0 106 | attn_mask[:, valid_len: ] = 0 107 | 108 | if structure_option == "biased": 109 | bias_mask_id = create_biased_id(row_ids, col_ids) 110 | return { 111 | 'token_ids': token_ids, 112 | 'attn_mask': attn_mask, 113 | 'row_ids': bias_mask_id, 114 | 'column_ids': col_ids, 115 | } 116 | 117 | return { 118 | 'token_ids': token_ids, 119 | 'attn_mask': attn_mask, 120 | 'row_ids': row_ids, 121 | 'column_ids': col_ids, 122 | } 123 | 124 | 125 | 126 | def prepare_table_ctx_inputs_batch( 127 | batch: List[Tuple[object, BiEncoderTable]], 128 | tokenizer: BertTokenizer, 129 | structure_option: str = "global", 130 | insert_title: bool = True, 131 | max_seq_length: int = 256, 132 | ) -> Dict[str, torch.Tensor]: 133 | token_ids_batch, attn_mask_batch = [], [] 134 | row_ids_batch, column_ids_batch = [], [] 135 | for ctx in batch: 136 | ctx_input_tensors = prepare_table_ctx_inputs( 137 | ctx[1], tokenizer, structure_option, 138 | insert_title, max_seq_length, 139 | ) 140 | token_ids_batch.append(ctx_input_tensors['token_ids']) 141 | attn_mask_batch.append(ctx_input_tensors['attn_mask']) 142 | row_ids_batch.append(ctx_input_tensors['row_ids']) 143 | column_ids_batch.append(ctx_input_tensors['column_ids']) 144 | 145 | token_ids_batch = torch.stack(token_ids_batch, dim=0) # [batch-size, max-seq-len] 146 | attn_mask_batch = torch.stack(attn_mask_batch, dim=0) # [batch-size, max-seq-len, max-seq-len] 147 | row_ids_batch = torch.stack(row_ids_batch, dim=0) # [batch-size, max-seq-len] 148 | column_ids_batch = torch.stack(column_ids_batch, dim=0) # [batch-size, max-seq-len] 149 | 150 | return { 151 | 'token_ids': token_ids_batch, 152 | 'attn_mask': attn_mask_batch, 153 | 'row_ids': row_ids_batch, 154 | 'column_ids': 
column_ids_batch, 155 | } 156 | 157 | # %% biased 158 | 159 | def create_biased_id(row_ids: torch.Tensor, column_ids: torch.Tensor) -> torch.Tensor: 160 | """Compute relation-based bias id. 161 | args: 162 | row_ids: 163 | column_ids: 164 | valid_len: 165 | ret: 166 | bias_id: 167 | 168 | notes: 169 | - title: row-id = 0, col-id = 0 170 | - header: row-id = 0, col-id = 1-indexed 171 | - cell: row-id = 1-indexed, col-id = 1-indexed 172 | """ 173 | n = row_ids.size()[0] 174 | bias_id = [] 175 | 176 | for i in range(n): 177 | i_bid = [] 178 | irow, icol = row_ids[i], column_ids[i] 179 | for j in range(n): 180 | jrow, jcol = row_ids[j], column_ids[j] 181 | 182 | if (irow == 0) and (icol == 0): # [f] sentence 183 | if (jrow == 0) and (jcol == 0): # [t] sentence 184 | ij_bid = 0 185 | elif (jrow == 0): # [t] header 186 | ij_bid = 1 187 | else: # [t] cell 188 | ij_bid = 2 189 | elif (irow == 0): # [f] header 190 | if (jrow == 0) and (jcol == 0): # [t] sentence 191 | ij_bid = 3 192 | elif (jrow == 0) and (icol == jcol): # [t] same header 193 | ij_bid = 4 194 | elif (jrow == 0): # [t] other header 195 | ij_bid = 5 196 | else: # [t] cell 197 | ij_bid = 6 198 | else: # [f] cell 199 | if (jrow == 0) and (jcol == 0): # [t] sentence 200 | ij_bid = 7 201 | elif (jrow == 0): # [t] column header 202 | ij_bid = 8 203 | elif (irow == jrow) and (icol == jcol): # [t] same cell 204 | ij_bid = 9 205 | elif (irow == jrow): # [t] same row 206 | ij_bid = 10 207 | elif (icol == jcol): # [t] same column 208 | ij_bid = 11 209 | else: 210 | ij_bid = 12 211 | 212 | i_bid.append(ij_bid) 213 | 214 | bias_id.append(i_bid) 215 | return torch.LongTensor(bias_id) 216 | -------------------------------------------------------------------------------- /dpr/utils/tokenizers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
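# --- Illustrative sketch (added; not part of the original DrQA code): a
# standalone approximation of how SimpleTokenizer (defined below) splits
# text, using the same ALPHA_NUM / NON_WS character classes. Assumes only
# the `regex` package; kept entirely in comments so the module is unchanged.
#
#   import regex
#   _pattern = regex.compile(r"[\p{L}\p{N}\p{M}]+|[^\p{Z}\p{C}]", flags=regex.UNICODE)
#   [m.group() for m in _pattern.finditer("What's NQ-Table?")]
#   # -> ['What', "'", 's', 'NQ', '-', 'Table', '?']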
7 | 8 | 9 | """ 10 | Most of the tokenizers code here is copied from DrQA codebase to avoid adding extra dependency 11 | """ 12 | 13 | import copy 14 | import logging 15 | 16 | import regex 17 | import spacy 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class Tokens(object): 23 | """A class to represent a list of tokenized text.""" 24 | 25 | TEXT = 0 26 | TEXT_WS = 1 27 | SPAN = 2 28 | POS = 3 29 | LEMMA = 4 30 | NER = 5 31 | 32 | def __init__(self, data, annotators, opts=None): 33 | self.data = data 34 | self.annotators = annotators 35 | self.opts = opts or {} 36 | 37 | def __len__(self): 38 | """The number of tokens.""" 39 | return len(self.data) 40 | 41 | def slice(self, i=None, j=None): 42 | """Return a view of the list of tokens from [i, j).""" 43 | new_tokens = copy.copy(self) 44 | new_tokens.data = self.data[i:j] 45 | return new_tokens 46 | 47 | def untokenize(self): 48 | """Returns the original text (with whitespace reinserted).""" 49 | return "".join([t[self.TEXT_WS] for t in self.data]).strip() 50 | 51 | def words(self, uncased=False): 52 | """Returns a list of the text of each token 53 | 54 | Args: 55 | uncased: lower cases text 56 | """ 57 | if uncased: 58 | return [t[self.TEXT].lower() for t in self.data] 59 | else: 60 | return [t[self.TEXT] for t in self.data] 61 | 62 | def offsets(self): 63 | """Returns a list of [start, end) character offsets of each token.""" 64 | return [t[self.SPAN] for t in self.data] 65 | 66 | def pos(self): 67 | """Returns a list of part-of-speech tags of each token. 68 | Returns None if this annotation was not included. 69 | """ 70 | if "pos" not in self.annotators: 71 | return None 72 | return [t[self.POS] for t in self.data] 73 | 74 | def lemmas(self): 75 | """Returns a list of the lemmatized text of each token. 76 | Returns None if this annotation was not included. 77 | """ 78 | if "lemma" not in self.annotators: 79 | return None 80 | return [t[self.LEMMA] for t in self.data] 81 | 82 | def entities(self): 83 | """Returns a list of named-entity-recognition tags of each token. 84 | Returns None if this annotation was not included. 85 | """ 86 | if "ner" not in self.annotators: 87 | return None 88 | return [t[self.NER] for t in self.data] 89 | 90 | def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True): 91 | """Returns a list of all ngrams from length 1 to n. 
92 | 93 | Args: 94 | n: upper limit of ngram length 95 | uncased: lower cases text 96 | filter_fn: user function that takes in an ngram list and returns 97 | True or False to keep or not keep the ngram 98 | as_string: return the ngram as a string vs list 99 | """ 100 | 101 | def _skip(gram): 102 | if not filter_fn: 103 | return False 104 | return filter_fn(gram) 105 | 106 | words = self.words(uncased) 107 | ngrams = [ 108 | (s, e + 1) 109 | for s in range(len(words)) 110 | for e in range(s, min(s + n, len(words))) 111 | if not _skip(words[s : e + 1]) 112 | ] 113 | 114 | # Concatenate into strings 115 | if as_strings: 116 | ngrams = ["{}".format(" ".join(words[s:e])) for (s, e) in ngrams] 117 | 118 | return ngrams 119 | 120 | def entity_groups(self): 121 | """Group consecutive entity tokens with the same NER tag.""" 122 | entities = self.entities() 123 | if not entities: 124 | return None 125 | non_ent = self.opts.get("non_ent", "O") 126 | groups = [] 127 | idx = 0 128 | while idx < len(entities): 129 | ner_tag = entities[idx] 130 | # Check for entity tag 131 | if ner_tag != non_ent: 132 | # Chomp the sequence 133 | start = idx 134 | while idx < len(entities) and entities[idx] == ner_tag: 135 | idx += 1 136 | groups.append((self.slice(start, idx).untokenize(), ner_tag)) 137 | else: 138 | idx += 1 139 | return groups 140 | 141 | 142 | class Tokenizer(object): 143 | """Base tokenizer class. 144 | Tokenizers implement tokenize, which should return a Tokens class. 145 | """ 146 | 147 | def tokenize(self, text): 148 | raise NotImplementedError 149 | 150 | def shutdown(self): 151 | pass 152 | 153 | def __del__(self): 154 | self.shutdown() 155 | 156 | 157 | class SimpleTokenizer(Tokenizer): 158 | ALPHA_NUM = r"[\p{L}\p{N}\p{M}]+" 159 | NON_WS = r"[^\p{Z}\p{C}]" 160 | 161 | def __init__(self, **kwargs): 162 | """ 163 | Args: 164 | annotators: None or empty set (only tokenizes). 165 | """ 166 | self._regexp = regex.compile( 167 | "(%s)|(%s)" % (self.ALPHA_NUM, self.NON_WS), 168 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE, 169 | ) 170 | if len(kwargs.get("annotators", {})) > 0: 171 | logger.warning( 172 | "%s only tokenizes! Skipping annotators: %s" % (type(self).__name__, kwargs.get("annotators")) 173 | ) 174 | self.annotators = set() 175 | 176 | def tokenize(self, text): 177 | data = [] 178 | matches = [m for m in self._regexp.finditer(text)] 179 | for i in range(len(matches)): 180 | # Get text 181 | token = matches[i].group() 182 | 183 | # Get whitespace 184 | span = matches[i].span() 185 | start_ws = span[0] 186 | if i + 1 < len(matches): 187 | end_ws = matches[i + 1].span()[0] 188 | else: 189 | end_ws = span[1] 190 | 191 | # Format data 192 | data.append( 193 | ( 194 | token, 195 | text[start_ws:end_ws], 196 | span, 197 | ) 198 | ) 199 | return Tokens(data, self.annotators) 200 | 201 | 202 | class SpacyTokenizer(Tokenizer): 203 | def __init__(self, **kwargs): 204 | """ 205 | Args: 206 | annotators: set that can include pos, lemma, and ner. 207 | model: spaCy model to use (either path, or keyword like 'en'). 208 | """ 209 | model = kwargs.get("model", "en_core_web_sm") # TODO: replace with en ? 
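        # Note: spacy.load(...) below is called with spaCy 1.x-style keyword
        # flags (parser/tagger/entity), as in the original DrQA code; newer
        # spaCy versions select pipeline components differently.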
210 | self.annotators = copy.deepcopy(kwargs.get("annotators", set())) 211 | nlp_kwargs = {"parser": False} 212 | if not any([p in self.annotators for p in ["lemma", "pos", "ner"]]): 213 | nlp_kwargs["tagger"] = False 214 | if "ner" not in self.annotators: 215 | nlp_kwargs["entity"] = False 216 | self.nlp = spacy.load(model, **nlp_kwargs) 217 | 218 | def tokenize(self, text): 219 | # We don't treat new lines as tokens. 220 | clean_text = text.replace("\n", " ") 221 | tokens = self.nlp.tokenizer(clean_text) 222 | if any([p in self.annotators for p in ["lemma", "pos", "ner"]]): 223 | self.nlp.tagger(tokens) 224 | if "ner" in self.annotators: 225 | self.nlp.entity(tokens) 226 | 227 | data = [] 228 | for i in range(len(tokens)): 229 | # Get whitespace 230 | start_ws = tokens[i].idx 231 | if i + 1 < len(tokens): 232 | end_ws = tokens[i + 1].idx 233 | else: 234 | end_ws = tokens[i].idx + len(tokens[i].text) 235 | 236 | data.append( 237 | ( 238 | tokens[i].text, 239 | text[start_ws:end_ws], 240 | (tokens[i].idx, tokens[i].idx + len(tokens[i].text)), 241 | tokens[i].tag_, 242 | tokens[i].lemma_, 243 | tokens[i].ent_type_, 244 | ) 245 | ) 246 | 247 | # Set special option for non-entity tag: '' vs 'O' in spaCy 248 | return Tokens(data, self.annotators, opts={"non_ent": ""}) 249 | -------------------------------------------------------------------------------- /dense_retrieval.py: -------------------------------------------------------------------------------- 1 | """Run retrieval inference with applied hard attention mask. """ 2 | 3 | import faiss # do not remove, include this to resolve import conflict 4 | import glob 5 | import hydra 6 | from omegaconf import DictConfig, OmegaConf 7 | from dpr.options import set_cfg_params_from_state, setup_cfg_gpu, setup_logger 8 | from dpr.utils.model_utils import ( 9 | setup_for_distributed_mode, 10 | get_model_obj, 11 | load_states_from_checkpoint, 12 | move_to_device, 13 | ) 14 | 15 | import logging 16 | logger = logging.getLogger() 17 | setup_logger(logger) 18 | 19 | from dpr.models import init_biencoder_components 20 | from retrieval_utils import ( 21 | DenseRPCRetriever, LocalFaissRetriever, KiltCsvCtxSrc, 22 | generate_question_vectors, 23 | validate_from_meta, save_results_from_meta, 24 | validate_tables, validate, save_results, 25 | ) 26 | from dpr.data.biencoder_data import BiEncoderPassage 27 | 28 | 29 | 30 | def get_all_tables_as_passages(ctx_sources, cfg): 31 | all_tables = {} 32 | for ctx_src in ctx_sources: 33 | ctx_src.load_data_to(all_tables, cfg) 34 | logger.info("Loaded ctx data: %d", len(all_tables)) 35 | 36 | if len(all_tables) == 0: 37 | raise RuntimeError("No passages data found. Please specify ctx_file param properly.") 38 | 39 | all_passages = {} 40 | for k, v in all_tables.items(): 41 | text = ' '.join([c['text'] for c in v.cells]) 42 | text = f"{v.title} {text}" 43 | all_passages[k] = BiEncoderPassage(text, v.title) 44 | 45 | if len(all_passages) == 0: 46 | raise RuntimeError("No passages data found. 
Please specify ctx_file param properly.") 47 | return all_passages 48 | 49 | 50 | 51 | @hydra.main(config_path="conf", config_name="dense_retrieval") 52 | def main(cfg: DictConfig): 53 | cfg = setup_cfg_gpu(cfg) 54 | saved_state = load_states_from_checkpoint(cfg.model_file) 55 | 56 | set_cfg_params_from_state(saved_state.encoder_params, cfg) 57 | 58 | logger.info("CFG (after gpu configuration):") 59 | logger.info("%s", OmegaConf.to_yaml(cfg)) 60 | 61 | tensorizer, encoder, _ = init_biencoder_components(cfg.encoder.encoder_model_type, cfg, inference_only=True) 62 | 63 | logger.info("Loading saved model state ...") 64 | encoder.load_state(saved_state, strict=False) 65 | 66 | encoder_path = cfg.encoder_path 67 | if encoder_path: 68 | logger.info("Selecting encoder: %s", encoder_path) 69 | encoder = getattr(encoder, encoder_path) 70 | else: 71 | logger.info("Selecting standard question encoder") 72 | encoder = encoder.question_model 73 | 74 | encoder, _ = setup_for_distributed_mode(encoder, None, cfg.device, cfg.n_gpu, cfg.local_rank, cfg.fp16) 75 | encoder.eval() 76 | 77 | model_to_load = get_model_obj(encoder) 78 | vector_size = model_to_load.get_out_size() 79 | logger.info("Encoder vector_size=%d", vector_size) 80 | 81 | # get questions & answers 82 | questions = [] 83 | questions_text = [] 84 | question_answers = [] 85 | 86 | if not cfg.qa_dataset: 87 | logger.warning("Please specify qa_dataset to use") 88 | return 89 | 90 | ds_key = cfg.qa_dataset 91 | logger.info("qa_dataset: %s", ds_key) 92 | 93 | qa_src = hydra.utils.instantiate(cfg.datasets[ds_key]) 94 | qa_src.load_data() 95 | 96 | total_queries = len(qa_src) 97 | for i in range(total_queries): 98 | qa_sample = qa_src[i] 99 | question, answers = qa_sample.query, qa_sample.answers 100 | questions.append(question) 101 | question_answers.append(answers) 102 | 103 | logger.info("questions len %d", len(questions)) 104 | logger.info("questions_text len %d", len(questions_text)) 105 | 106 | if cfg.rpc_retriever_cfg_file: 107 | index_buffer_sz = 1000 108 | retriever = DenseRPCRetriever( 109 | encoder, 110 | cfg.batch_size, 111 | tensorizer, 112 | cfg.rpc_retriever_cfg_file, 113 | vector_size, 114 | use_l2_conversion=cfg.use_l2_conversion, 115 | ) 116 | else: 117 | index = hydra.utils.instantiate(cfg.indexers[cfg.indexer]) 118 | logger.info("Local Index class %s ", type(index)) 119 | index_buffer_sz = index.buffer_size 120 | index.init_index(vector_size) 121 | retriever = LocalFaissRetriever(encoder, cfg.batch_size, tensorizer, index) 122 | 123 | logger.info("Using special token %s", qa_src.special_query_token) 124 | questions_tensor = retriever.generate_question_vectors(questions, query_token=qa_src.special_query_token) 125 | 126 | if qa_src.selector: 127 | logger.info("Using custom representation token selector") 128 | retriever.selector = qa_src.selector 129 | 130 | index_path = cfg.index_path 131 | if cfg.rpc_retriever_cfg_file and cfg.rpc_index_id: 132 | retriever.load_index(cfg.rpc_index_id) 133 | elif index_path and index.index_exists(index_path): 134 | logger.info("Index path: %s", index_path) 135 | retriever.index.deserialize(index_path) 136 | else: 137 | # send data for indexing 138 | id_prefixes = [] 139 | ctx_sources = [] 140 | for ctx_src in cfg.ctx_datatsets: 141 | ctx_src = hydra.utils.instantiate(cfg.ctx_sources[ctx_src]) 142 | id_prefixes.append(ctx_src.id_prefix) 143 | ctx_sources.append(ctx_src) 144 | logger.info("ctx_sources: %s", type(ctx_src)) 145 | 146 | logger.info("id_prefixes per dataset: %s", id_prefixes) 147 | 148 
| # index all passages 149 | ctx_files_patterns = cfg.encoded_ctx_files 150 | 151 | logger.info("ctx_files_patterns: %s", ctx_files_patterns) 152 | if ctx_files_patterns: 153 | assert len(ctx_files_patterns) == len(id_prefixes), "ctx len={} pref leb={}".format( 154 | len(ctx_files_patterns), len(id_prefixes) 155 | ) 156 | else: 157 | assert ( 158 | index_path or cfg.rpc_index_id 159 | ), "Either encoded_ctx_files or index_path pr rpc_index_id parameter should be set." 160 | 161 | input_paths = [] 162 | path_id_prefixes = [] 163 | for i, pattern in enumerate(ctx_files_patterns): 164 | pattern_files = glob.glob(pattern) 165 | pattern_id_prefix = id_prefixes[i] 166 | input_paths.extend(pattern_files) 167 | path_id_prefixes.extend([pattern_id_prefix] * len(pattern_files)) 168 | logger.info("Embeddings files id prefixes: %s", path_id_prefixes) 169 | logger.info("Reading all passages data from files: %s", input_paths) 170 | retriever.index_encoded_data(input_paths, index_buffer_sz, path_id_prefixes=path_id_prefixes) 171 | if index_path: 172 | retriever.index.serialize(index_path) 173 | 174 | # get top k results 175 | top_results_and_scores = retriever.get_top_docs(questions_tensor.numpy(), cfg.n_docs) 176 | 177 | if cfg.use_rpc_meta: 178 | questions_doc_hits = validate_from_meta( 179 | question_answers, 180 | top_results_and_scores, 181 | cfg.validation_workers, 182 | cfg.match, 183 | cfg.rpc_meta_compressed, 184 | ) 185 | if cfg.out_file: 186 | save_results_from_meta( 187 | questions, 188 | question_answers, 189 | top_results_and_scores, 190 | questions_doc_hits, 191 | cfg.out_file, 192 | cfg.rpc_meta_compressed, 193 | ) 194 | else: 195 | all_passages = get_all_tables_as_passages(ctx_sources, cfg) 196 | if cfg.validate_as_tables: 197 | 198 | questions_doc_hits = validate_tables( 199 | all_passages, 200 | question_answers, 201 | top_results_and_scores, 202 | cfg.validation_workers, 203 | cfg.match, 204 | ) 205 | 206 | else: 207 | questions_doc_hits = validate( 208 | all_passages, 209 | question_answers, 210 | top_results_and_scores, 211 | cfg.validation_workers, 212 | cfg.match, 213 | ) 214 | 215 | if cfg.out_file: 216 | save_results( 217 | all_passages, 218 | questions_text if questions_text else questions, 219 | question_answers, 220 | top_results_and_scores, 221 | questions_doc_hits, 222 | cfg.out_file, 223 | ) 224 | 225 | if cfg.kilt_out_file: 226 | kilt_ctx = next(iter([ctx for ctx in ctx_sources if isinstance(ctx, KiltCsvCtxSrc)]), None) 227 | if not kilt_ctx: 228 | raise RuntimeError("No Kilt compatible context file provided") 229 | assert hasattr(cfg, "kilt_out_file") 230 | kilt_ctx.convert_to_kilt(qa_src.kilt_gold_file, cfg.out_file, cfg.kilt_out_file) 231 | 232 | 233 | if __name__ == "__main__": 234 | main() 235 | -------------------------------------------------------------------------------- /dpr/indexer/faiss_indexers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
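# --- Hedged usage sketch (added; the names refer to classes defined below
# in this file). A typical flow for the flat index, assuming 768-d float32
# vectors; kept in comments so the module itself is unchanged:
#
#   indexer = DenseFlatIndexer(buffer_size=50000)
#   indexer.init_index(vector_sz=768)
#   indexer.index_data([(db_id, vec), ...])          # vec: np.float32[768]
#   ids_and_scores = indexer.search_knn(query_vectors, top_docs=100)
#   indexer.serialize("/path/to/index")              # writes index + id-mapping meta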
7 | 8 | """ 9 | FAISS-based index components for dense retriever 10 | """ 11 | 12 | import faiss 13 | import logging 14 | import numpy as np 15 | import os 16 | import pickle 17 | 18 | from typing import List, Tuple 19 | 20 | logger = logging.getLogger() 21 | 22 | 23 | class DenseIndexer(object): 24 | def __init__(self, buffer_size: int = 50000): 25 | self.buffer_size = buffer_size 26 | self.index_id_to_db_id = [] 27 | self.index = None 28 | 29 | def init_index(self, vector_sz: int): 30 | raise NotImplementedError 31 | 32 | def index_data(self, data: List[Tuple[object, np.array]]): 33 | raise NotImplementedError 34 | 35 | def get_index_name(self): 36 | raise NotImplementedError 37 | 38 | def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]: 39 | raise NotImplementedError 40 | 41 | def serialize(self, file: str): 42 | logger.info("Serializing index to %s", file) 43 | 44 | if os.path.isdir(file): 45 | index_file = os.path.join(file, "index.dpr") 46 | meta_file = os.path.join(file, "index_meta.dpr") 47 | else: 48 | index_file = file + ".index.dpr" 49 | meta_file = file + ".index_meta.dpr" 50 | 51 | faiss.write_index(self.index, index_file) 52 | with open(meta_file, mode="wb") as f: 53 | pickle.dump(self.index_id_to_db_id, f) 54 | 55 | def get_files(self, path: str): 56 | if os.path.isdir(path): 57 | index_file = os.path.join(path, "index.dpr") 58 | meta_file = os.path.join(path, "index_meta.dpr") 59 | else: 60 | index_file = path + ".{}.dpr".format(self.get_index_name()) 61 | meta_file = path + ".{}_meta.dpr".format(self.get_index_name()) 62 | return index_file, meta_file 63 | 64 | def index_exists(self, path: str): 65 | index_file, meta_file = self.get_files(path) 66 | return os.path.isfile(index_file) and os.path.isfile(meta_file) 67 | 68 | def deserialize(self, path: str): 69 | logger.info("Loading index from %s", path) 70 | index_file, meta_file = self.get_files(path) 71 | 72 | self.index = faiss.read_index(index_file) 73 | logger.info("Loaded index of type %s and size %d", type(self.index), self.index.ntotal) 74 | 75 | with open(meta_file, "rb") as reader: 76 | self.index_id_to_db_id = pickle.load(reader) 77 | assert ( 78 | len(self.index_id_to_db_id) == self.index.ntotal 79 | ), "Deserialized index_id_to_db_id should match faiss index size" 80 | 81 | def _update_id_mapping(self, db_ids: List) -> int: 82 | self.index_id_to_db_id.extend(db_ids) 83 | return len(self.index_id_to_db_id) 84 | 85 | 86 | class DenseFlatIndexer(DenseIndexer): 87 | def __init__(self, buffer_size: int = 50000): 88 | super(DenseFlatIndexer, self).__init__(buffer_size=buffer_size) 89 | 90 | def init_index(self, vector_sz: int): 91 | self.index = faiss.IndexFlatIP(vector_sz) 92 | 93 | def index_data(self, data: List[Tuple[object, np.array]]): 94 | n = len(data) 95 | # indexing in batches is beneficial for many faiss index types 96 | for i in range(0, n, self.buffer_size): 97 | db_ids = [t[0] for t in data[i : i + self.buffer_size]] 98 | vectors = [np.reshape(t[1], (1, -1)) for t in data[i : i + self.buffer_size]] 99 | vectors = np.concatenate(vectors, axis=0) 100 | total_data = self._update_id_mapping(db_ids) 101 | self.index.add(vectors) 102 | logger.info("data indexed %d", total_data) 103 | 104 | indexed_cnt = len(self.index_id_to_db_id) 105 | logger.info("Total data indexed %d", indexed_cnt) 106 | 107 | def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]: 108 | scores, indexes = self.index.search(query_vectors, 
top_docs) 109 | # convert to external ids 110 | db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes] 111 | result = [(db_ids[i], scores[i]) for i in range(len(db_ids))] 112 | return result 113 | 114 | def get_index_name(self): 115 | return "flat_index" 116 | 117 | 118 | class DenseHNSWFlatIndexer(DenseIndexer): 119 | """ 120 | Efficient index for retrieval. Note: default settings are for hugh accuracy but also high RAM usage 121 | """ 122 | 123 | def __init__( 124 | self, 125 | buffer_size: int = 1e9, 126 | store_n: int = 512, 127 | ef_search: int = 128, 128 | ef_construction: int = 200, 129 | ): 130 | super(DenseHNSWFlatIndexer, self).__init__(buffer_size=buffer_size) 131 | self.store_n = store_n 132 | self.ef_search = ef_search 133 | self.ef_construction = ef_construction 134 | self.phi = 0 135 | 136 | def init_index(self, vector_sz: int): 137 | # IndexHNSWFlat supports L2 similarity only 138 | # so we have to apply DOT -> L2 similairy space conversion with the help of an extra dimension 139 | index = faiss.IndexHNSWFlat(vector_sz + 1, self.store_n) 140 | index.hnsw.efSearch = self.ef_search 141 | index.hnsw.efConstruction = self.ef_construction 142 | self.index = index 143 | 144 | def index_data(self, data: List[Tuple[object, np.array]]): 145 | n = len(data) 146 | 147 | # max norm is required before putting all vectors in the index to convert inner product similarity to L2 148 | if self.phi > 0: 149 | raise RuntimeError( 150 | "DPR HNSWF index needs to index all data at once," "results will be unpredictable otherwise." 151 | ) 152 | phi = 0 153 | for i, item in enumerate(data): 154 | id, doc_vector = item[0:2] 155 | norms = (doc_vector ** 2).sum() 156 | phi = max(phi, norms) 157 | logger.info("HNSWF DotProduct -> L2 space phi={}".format(phi)) 158 | self.phi = phi 159 | 160 | # indexing in batches is beneficial for many faiss index types 161 | bs = int(self.buffer_size) 162 | for i in range(0, n, bs): 163 | db_ids = [t[0] for t in data[i : i + bs]] 164 | vectors = [np.reshape(t[1], (1, -1)) for t in data[i : i + bs]] 165 | 166 | norms = [(doc_vector ** 2).sum() for doc_vector in vectors] 167 | aux_dims = [np.sqrt(phi - norm) for norm in norms] 168 | hnsw_vectors = [np.hstack((doc_vector, aux_dims[i].reshape(-1, 1))) for i, doc_vector in enumerate(vectors)] 169 | hnsw_vectors = np.concatenate(hnsw_vectors, axis=0) 170 | self.train(hnsw_vectors) 171 | 172 | self._update_id_mapping(db_ids) 173 | self.index.add(hnsw_vectors) 174 | logger.info("data indexed %d", len(self.index_id_to_db_id)) 175 | indexed_cnt = len(self.index_id_to_db_id) 176 | logger.info("Total data indexed %d", indexed_cnt) 177 | 178 | def train(self, vectors: np.array): 179 | pass 180 | 181 | def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]: 182 | 183 | aux_dim = np.zeros(len(query_vectors), dtype="float32") 184 | query_nhsw_vectors = np.hstack((query_vectors, aux_dim.reshape(-1, 1))) 185 | logger.info("query_hnsw_vectors %s", query_nhsw_vectors.shape) 186 | scores, indexes = self.index.search(query_nhsw_vectors, top_docs) 187 | # convert to external ids 188 | db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes] 189 | result = [(db_ids[i], scores[i]) for i in range(len(db_ids))] 190 | return result 191 | 192 | def deserialize(self, file: str): 193 | super(DenseHNSWFlatIndexer, self).deserialize(file) 194 | # to trigger exception on subsequent indexing 195 | self.phi = 1 196 | 197 | def 
get_index_name(self): 198 | return "hnsw_index" 199 | 200 | 201 | class DenseHNSWSQIndexer(DenseHNSWFlatIndexer): 202 | """ 203 | Efficient index for retrieval. Note: default settings are for hugh accuracy but also high RAM usage 204 | """ 205 | 206 | def __init__( 207 | self, 208 | buffer_size: int = 1e10, 209 | store_n: int = 128, 210 | ef_search: int = 128, 211 | ef_construction: int = 200, 212 | ): 213 | super(DenseHNSWSQIndexer, self).__init__( 214 | buffer_size=buffer_size, 215 | store_n=store_n, 216 | ef_search=ef_search, 217 | ef_construction=ef_construction, 218 | ) 219 | 220 | def init_index(self, vector_sz: int): 221 | # IndexHNSWFlat supports L2 similarity only 222 | # so we have to apply DOT -> L2 similairy space conversion with the help of an extra dimension 223 | index = faiss.IndexHNSWSQ(vector_sz + 1, faiss.ScalarQuantizer.QT_8bit, self.store_n) 224 | index.hnsw.efSearch = self.ef_search 225 | index.hnsw.efConstruction = self.ef_construction 226 | self.index = index 227 | 228 | def train(self, vectors: np.array): 229 | self.index.train(vectors) 230 | 231 | def get_index_name(self): 232 | return "hnswsq_index" 233 | -------------------------------------------------------------------------------- /dpr/data/qa_validation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Set of utilities for Q&A results validation tasks - Retriver passage validation and Reader predicted answer validation 10 | """ 11 | 12 | import collections 13 | import logging 14 | import string 15 | import unicodedata 16 | import zlib 17 | from functools import partial 18 | from multiprocessing import Pool as ProcessPool 19 | from typing import Tuple, List, Dict 20 | 21 | import regex as re 22 | 23 | from dpr.data.retriever_data import TableChunk 24 | from dpr.utils.tokenizers import SimpleTokenizer 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | QAMatchStats = collections.namedtuple("QAMatchStats", ["top_k_hits", "questions_doc_hits"]) 29 | 30 | QATableMatchStats = collections.namedtuple( 31 | "QAMatchStats", ["top_k_chunk_hits", "top_k_table_hits", "questions_doc_hits"] 32 | ) 33 | 34 | 35 | def calculate_matches( 36 | all_docs: Dict[object, Tuple[str, str]], 37 | answers: List[List[str]], 38 | closest_docs: List[Tuple[List[object], List[float]]], 39 | workers_num: int, 40 | match_type: str, 41 | ) -> QAMatchStats: 42 | """ 43 | Evaluates answers presence in the set of documents. This function is supposed to be used with a large collection of 44 | documents and results. It internally forks multiple sub-processes for evaluation and then merges results 45 | :param all_docs: dictionary of the entire documents database. doc_id -> (doc_text, title) 46 | :param answers: list of answers's list. One list per question 47 | :param closest_docs: document ids of the top results along with their scores 48 | :param workers_num: amount of parallel threads to process data 49 | :param match_type: type of answer matching. Refer to has_answer code for available options 50 | :return: matching information tuple. 51 | top_k_hits - a list where the index is the amount of top documents retrieved and the value is the total amount of 52 | valid matches across an entire dataset. 
53 | questions_doc_hits - more detailed info with answer matches for every question and every retrieved document 54 | """ 55 | logger.info("all_docs size %d", len(all_docs)) 56 | global dpr_all_documents 57 | dpr_all_documents = all_docs 58 | logger.info("dpr_all_documents size %d", len(dpr_all_documents)) 59 | 60 | tok_opts = {} 61 | tokenizer = SimpleTokenizer(**tok_opts) 62 | 63 | processes = ProcessPool(processes=workers_num) 64 | logger.info("Matching answers in top docs...") 65 | get_score_partial = partial(check_answer, match_type=match_type, tokenizer=tokenizer) 66 | 67 | questions_answers_docs = zip(answers, closest_docs) 68 | scores = processes.map(get_score_partial, questions_answers_docs) 69 | 70 | logger.info("Per question validation results len=%d", len(scores)) 71 | 72 | n_docs = len(closest_docs[0][0]) 73 | top_k_hits = [0] * n_docs 74 | for question_hits in scores: 75 | best_hit = next((i for i, x in enumerate(question_hits) if x), None) 76 | if best_hit is not None: 77 | top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]] 78 | 79 | return QAMatchStats(top_k_hits, scores) 80 | 81 | 82 | def calculate_matches_from_meta( 83 | answers: List[List[str]], 84 | closest_docs: List[Tuple[List[object], List[float]]], 85 | workers_num: int, 86 | match_type: str, 87 | use_title: bool = False, 88 | meta_compressed: bool = False, 89 | ) -> QAMatchStats: 90 | 91 | tok_opts = {} 92 | tokenizer = SimpleTokenizer(**tok_opts) 93 | 94 | processes = ProcessPool(processes=workers_num) 95 | logger.info("Matching answers in top docs...") 96 | get_score_partial = partial( 97 | check_answer_from_meta, 98 | match_type=match_type, 99 | tokenizer=tokenizer, 100 | use_title=use_title, 101 | meta_compressed=meta_compressed, 102 | ) 103 | 104 | questions_answers_docs = zip(answers, closest_docs) 105 | scores = processes.map(get_score_partial, questions_answers_docs) 106 | 107 | logger.info("Per question validation results len=%d", len(scores)) 108 | 109 | n_docs = len(closest_docs[0][0]) 110 | top_k_hits = [0] * n_docs 111 | for question_hits in scores: 112 | best_hit = next((i for i, x in enumerate(question_hits) if x), None) 113 | if best_hit is not None: 114 | top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]] 115 | 116 | return QAMatchStats(top_k_hits, scores) 117 | 118 | 119 | def check_answer(questions_answers_docs, tokenizer, match_type) -> List[bool]: 120 | """Search through all the top docs to see if they have any of the answers.""" 121 | answers, (doc_ids, doc_scores) = questions_answers_docs 122 | 123 | global dpr_all_documents 124 | # print(f"excerpt of dpr document title: {list(dpr_all_documents.keys())[:10]}") 125 | hits = [] 126 | 127 | for i, doc_id in enumerate(doc_ids): 128 | doc = dpr_all_documents[doc_id] 129 | text = doc[0] 130 | 131 | answer_found = False 132 | if text is None: # cannot find the document for some reason 133 | logger.warning("no doc in db") 134 | hits.append(False) 135 | continue 136 | if match_type == "kilt": 137 | if has_answer_kilt(answers, text): 138 | answer_found = True 139 | elif has_answer(answers, text, tokenizer, match_type): 140 | answer_found = True 141 | hits.append(answer_found) 142 | return hits 143 | 144 | 145 | def check_answer_from_meta( 146 | questions_answers_docs, 147 | tokenizer, 148 | match_type, 149 | meta_body_idx: int = 1, 150 | meta_title_idx: int = 2, 151 | use_title: bool = False, 152 | meta_compressed: bool = False, 153 | ) -> List[bool]: 154 | """Search through all the top docs to see if they have any of the 
answers.""" 155 | answers, (docs_meta, doc_scores) = questions_answers_docs 156 | 157 | hits = [] 158 | 159 | for i, doc_meta in enumerate(docs_meta): 160 | 161 | text = doc_meta[meta_body_idx] 162 | title = doc_meta[meta_title_idx] if len(doc_meta) > meta_title_idx else "" 163 | if meta_compressed: 164 | text = zlib.decompress(text).decode() 165 | title = zlib.decompress(title).decode() 166 | 167 | if use_title: 168 | text = title + " . " + text 169 | answer_found = False 170 | if has_answer(answers, text, tokenizer, match_type): 171 | answer_found = True 172 | hits.append(answer_found) 173 | return hits 174 | 175 | 176 | def has_answer(answers, text, tokenizer, match_type) -> bool: 177 | """Check if a document contains an answer string. 178 | If `match_type` is string, token matching is done between the text and answer. 179 | If `match_type` is regex, we search the whole text with the regex. 180 | """ 181 | text = _normalize(text) 182 | 183 | if match_type == "string": 184 | # Answer is a list of possible strings 185 | text = tokenizer.tokenize(text).words(uncased=True) 186 | 187 | for single_answer in answers: 188 | single_answer = _normalize(single_answer) 189 | single_answer = tokenizer.tokenize(single_answer) 190 | single_answer = single_answer.words(uncased=True) 191 | 192 | for i in range(0, len(text) - len(single_answer) + 1): 193 | if single_answer == text[i : i + len(single_answer)]: 194 | return True 195 | 196 | elif match_type == "regex": 197 | # Answer is a regex 198 | for single_answer in answers: 199 | single_answer = _normalize(single_answer) 200 | if regex_match(text, single_answer): 201 | return True 202 | return False 203 | 204 | 205 | def regex_match(text, pattern): 206 | """Test if a regex pattern is contained within a text.""" 207 | try: 208 | pattern = re.compile(pattern, flags=re.IGNORECASE + re.UNICODE + re.MULTILINE) 209 | except BaseException: 210 | return False 211 | return pattern.search(text) is not None 212 | 213 | 214 | # function for the reader model answer validation 215 | def exact_match_score(prediction, ground_truth): 216 | return _normalize_answer(prediction) == _normalize_answer(ground_truth) 217 | 218 | 219 | def _normalize_answer(s): 220 | def remove_articles(text): 221 | return re.sub(r"\b(a|an|the)\b", " ", text) 222 | 223 | def white_space_fix(text): 224 | return " ".join(text.split()) 225 | 226 | def remove_punc(text): 227 | exclude = set(string.punctuation) 228 | return "".join(ch for ch in text if ch not in exclude) 229 | 230 | def lower(text): 231 | return text.lower() 232 | 233 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 234 | 235 | 236 | def _normalize(text): 237 | return unicodedata.normalize("NFD", text) 238 | 239 | 240 | def calculate_chunked_matches( 241 | all_docs: Dict[object, TableChunk], 242 | answers: List[List[str]], 243 | closest_docs: List[Tuple[List[object], List[float]]], 244 | workers_num: int, 245 | match_type: str, 246 | ) -> QATableMatchStats: 247 | global dpr_all_documents 248 | dpr_all_documents = all_docs 249 | 250 | global dpr_all_tables 251 | dpr_all_tables = {} 252 | 253 | for key, table_chunk in all_docs.items(): 254 | table_str, title, table_id = table_chunk 255 | table_chunks = dpr_all_tables.get(table_id, []) 256 | table_chunks.append((table_str, title)) 257 | dpr_all_tables[table_id] = table_chunks 258 | 259 | tok_opts = {} 260 | tokenizer = SimpleTokenizer(**tok_opts) 261 | 262 | processes = ProcessPool(processes=workers_num) 263 | 264 | logger.info("Matching answers in top docs...") 265 | 
get_score_partial = partial(check_chunked_docs_answer, match_type=match_type, tokenizer=tokenizer) 266 | questions_answers_docs = zip(answers, closest_docs) 267 | scores = processes.map(get_score_partial, questions_answers_docs) 268 | logger.info("Per question validation results len=%d", len(scores)) 269 | 270 | n_docs = len(closest_docs[0][0]) 271 | top_k_hits = [0] * n_docs 272 | top_k_orig_hits = [0] * n_docs 273 | for s in scores: 274 | question_hits, question_orig_doc_hits = s 275 | best_hit = next((i for i, x in enumerate(question_hits) if x), None) 276 | if best_hit is not None: 277 | top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]] 278 | 279 | best_hit = next((i for i, x in enumerate(question_orig_doc_hits) if x), None) 280 | if best_hit is not None: 281 | top_k_orig_hits[best_hit:] = [v + 1 for v in top_k_orig_hits[best_hit:]] 282 | 283 | return QATableMatchStats(top_k_hits, top_k_orig_hits, scores) 284 | 285 | 286 | # -------------------- KILT eval --------------------------------- 287 | 288 | 289 | def has_answer_kilt(answers, text) -> bool: 290 | text = normalize_kilt(text) 291 | for single_answer in answers: 292 | single_answer = normalize_kilt(single_answer) 293 | if single_answer in text: 294 | return True 295 | return False 296 | 297 | 298 | # answer normalization 299 | def normalize_kilt(s): 300 | """Lower text and remove punctuation, articles and extra whitespace.""" 301 | 302 | def remove_articles(text): 303 | return re.sub(r"\b(a|an|the)\b", " ", text) 304 | 305 | def white_space_fix(text): 306 | return " ".join(text.split()) 307 | 308 | def remove_punc(text): 309 | exclude = set(string.punctuation) 310 | return "".join(ch for ch in text if ch not in exclude) 311 | 312 | def lower(text): 313 | return text.lower() 314 | 315 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 316 | -------------------------------------------------------------------------------- /dpr/models/reader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
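# --- Added note on the batch geometry used throughout this file:
# N = questions per batch, M = passages per question, L = sequence length.
# Reader.forward takes [N, M, L] tensors, flattens them to [N*M, L] for the
# encoder, and returns start/end logits of shape [N, M, L] plus relevance
# logits of shape [N, M] at inference time (a single scalar loss in training).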
7 | 8 | """ 9 | The reader model code + its utilities (loss computation and input batch tensor generator) 10 | """ 11 | 12 | import collections 13 | import logging 14 | from typing import List 15 | 16 | import numpy as np 17 | import torch 18 | import torch.nn as nn 19 | from torch import Tensor as T 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from dpr.data.reader_data import ReaderSample, ReaderPassage 23 | from dpr.utils.model_utils import init_weights 24 | logger = logging.getLogger() 25 | 26 | ReaderBatch = collections.namedtuple( 27 | "ReaderBatch", ["input_ids", "start_positions", "end_positions", "answers_mask", "token_type_ids"] 28 | ) 29 | 30 | 31 | class Reader(nn.Module): 32 | def __init__(self, encoder: nn.Module, hidden_size): 33 | super(Reader, self).__init__() 34 | self.encoder = encoder 35 | self.qa_outputs = nn.Linear(hidden_size, 2) 36 | self.qa_classifier = nn.Linear(hidden_size, 1) 37 | init_weights([self.qa_outputs, self.qa_classifier]) 38 | 39 | def forward( 40 | self, 41 | input_ids: T, 42 | attention_mask: T, 43 | toke_type_ids: T, 44 | start_positions=None, 45 | end_positions=None, 46 | answer_mask=None, 47 | ): 48 | # notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length 49 | N, M, L = input_ids.size() 50 | start_logits, end_logits, relevance_logits = self._forward( 51 | input_ids.view(N * M, L), 52 | attention_mask.view(N * M, L), 53 | toke_type_ids.view(N * M, L), 54 | ) 55 | if self.training: 56 | return compute_loss( 57 | start_positions, end_positions, answer_mask, start_logits, end_logits, relevance_logits, N, M 58 | ) 59 | 60 | return start_logits.view(N, M, L), end_logits.view(N, M, L), relevance_logits.view(N, M) 61 | 62 | def _forward(self, input_ids, attention_mask, toke_type_ids: T): 63 | sequence_output, _pooled_output, _hidden_states = self.encoder(input_ids, toke_type_ids, attention_mask) 64 | logits = self.qa_outputs(sequence_output) 65 | start_logits, end_logits = logits.split(1, dim=-1) 66 | start_logits = start_logits.squeeze(-1) 67 | end_logits = end_logits.squeeze(-1) 68 | rank_logits = self.qa_classifier(sequence_output[:, 0, :]) 69 | return start_logits, end_logits, rank_logits 70 | 71 | 72 | def compute_loss(start_positions, end_positions, answer_mask, start_logits, end_logits, relevance_logits, N, M): 73 | start_positions = start_positions.view(N * M, -1) 74 | end_positions = end_positions.view(N * M, -1) 75 | answer_mask = answer_mask.view(N * M, -1) 76 | 77 | start_logits = start_logits.view(N * M, -1) 78 | end_logits = end_logits.view(N * M, -1) 79 | relevance_logits = relevance_logits.view(N * M) 80 | 81 | answer_mask = answer_mask.type(torch.FloatTensor).cuda() 82 | 83 | ignored_index = start_logits.size(1) 84 | start_positions.clamp_(0, ignored_index) 85 | end_positions.clamp_(0, ignored_index) 86 | loss_fct = CrossEntropyLoss(reduce=False, ignore_index=ignored_index) 87 | 88 | # compute switch loss 89 | relevance_logits = relevance_logits.view(N, M) 90 | switch_labels = torch.zeros(N, dtype=torch.long).cuda() 91 | switch_loss = torch.sum(loss_fct(relevance_logits, switch_labels)) 92 | 93 | # compute span loss 94 | start_losses = [ 95 | (loss_fct(start_logits, _start_positions) * _span_mask) 96 | for (_start_positions, _span_mask) in zip( 97 | torch.unbind(start_positions, dim=1), torch.unbind(answer_mask, dim=1) 98 | ) 99 | ] 100 | 101 | end_losses = [ 102 | (loss_fct(end_logits, _end_positions) * _span_mask) 103 | for (_end_positions, _span_mask) in 
zip(torch.unbind(end_positions, dim=1), torch.unbind(answer_mask, dim=1)) 104 | ] 105 | loss_tensor = torch.cat([t.unsqueeze(1) for t in start_losses], dim=1) + torch.cat( 106 | [t.unsqueeze(1) for t in end_losses], dim=1 107 | ) 108 | 109 | loss_tensor = loss_tensor.view(N, M, -1).max(dim=1)[0] 110 | span_loss = _calc_mml(loss_tensor) 111 | return span_loss + switch_loss 112 | 113 | 114 | def create_reader_input( 115 | pad_token_id: int, 116 | samples: List[ReaderSample], 117 | passages_per_question: int, 118 | max_length: int, 119 | max_n_answers: int, 120 | is_train: bool, 121 | shuffle: bool, 122 | sep_token_id: int, 123 | ) -> ReaderBatch: 124 | """ 125 | Creates a reader batch instance out of a list of ReaderSample-s 126 | :param pad_token_id: id of the padding token 127 | :param samples: list of samples to create the batch for 128 | :param passages_per_question: amount of passages for every question in a batch 129 | :param max_length: max model input sequence length 130 | :param max_n_answers: max num of answers per single question 131 | :param is_train: if the samples are for a train set 132 | :param shuffle: should passages selection be randomized 133 | :return: ReaderBatch instance 134 | """ 135 | input_ids = [] 136 | start_positions = [] 137 | end_positions = [] 138 | answers_masks = [] 139 | token_type_ids = [] 140 | empty_sequence = torch.Tensor().new_full((max_length,), pad_token_id, dtype=torch.long) 141 | 142 | for sample in samples: 143 | positive_ctxs = sample.positive_passages 144 | negative_ctxs = sample.negative_passages if is_train else sample.passages 145 | 146 | sample_tensors = _create_question_passages_tensors( 147 | positive_ctxs, 148 | negative_ctxs, 149 | passages_per_question, 150 | empty_sequence, 151 | max_n_answers, 152 | pad_token_id, 153 | sep_token_id, 154 | is_train, 155 | is_random=shuffle, 156 | ) 157 | if not sample_tensors: 158 | logger.debug("No valid passages combination for question=%s ", sample.question) 159 | continue 160 | sample_input_ids, starts_tensor, ends_tensor, answer_mask, sample_ttids = sample_tensors 161 | input_ids.append(sample_input_ids) 162 | token_type_ids.append(sample_ttids) 163 | if is_train: 164 | start_positions.append(starts_tensor) 165 | end_positions.append(ends_tensor) 166 | answers_masks.append(answer_mask) 167 | input_ids = torch.cat([ids.unsqueeze(0) for ids in input_ids], dim=0) 168 | token_type_ids = torch.cat([ids.unsqueeze(0) for ids in token_type_ids], dim=0) # .unsqueeze(0) 169 | 170 | if is_train: 171 | start_positions = torch.stack(start_positions, dim=0) 172 | end_positions = torch.stack(end_positions, dim=0) 173 | answers_masks = torch.stack(answers_masks, dim=0) 174 | 175 | return ReaderBatch(input_ids, start_positions, end_positions, answers_masks, token_type_ids) 176 | 177 | 178 | def _calc_mml(loss_tensor): 179 | marginal_likelihood = torch.sum(torch.exp(-loss_tensor - 1e10 * (loss_tensor == 0).float()), 1) 180 | return -torch.sum( 181 | torch.log(marginal_likelihood + torch.ones(loss_tensor.size(0)).cuda() * (marginal_likelihood == 0).float()) 182 | ) 183 | 184 | 185 | def _pad_to_len(seq: T, pad_id: int, max_len: int): 186 | s_len = seq.size(0) 187 | if s_len > max_len: 188 | return seq[0:max_len] 189 | return torch.cat([seq, torch.Tensor().new_full((max_len - s_len,), pad_id, dtype=torch.long)], dim=0) 190 | 191 | 192 | def _get_answer_spans(idx, positives: List[ReaderPassage], max_len: int): 193 | positive_a_spans = positives[idx].answers_spans 194 | return [span for span in positive_a_spans if 
(span[0] < max_len and span[1] < max_len)] 195 | 196 | 197 | def _get_positive_idx(positives: List[ReaderPassage], max_len: int, is_random: bool): 198 | # select just one positive 199 | positive_idx = np.random.choice(len(positives)) if is_random else 0 200 | 201 | if not _get_answer_spans(positive_idx, positives, max_len): 202 | # question may be too long, find the first positive with at least one valid span 203 | positive_idx = next((i for i in range(len(positives)) if _get_answer_spans(i, positives, max_len)), None) 204 | return positive_idx 205 | 206 | 207 | def _create_question_passages_tensors( 208 | positives: List[ReaderPassage], 209 | negatives: List[ReaderPassage], 210 | total_size: int, 211 | empty_ids: T, 212 | max_n_answers: int, 213 | pad_token_id: int, 214 | sep_token_id: int, 215 | is_train: bool, 216 | is_random: bool = True, 217 | first_segment_ttid: int = 0, 218 | ): 219 | max_len = empty_ids.size(0) 220 | if is_train: 221 | # select just one positive 222 | positive_idx = _get_positive_idx(positives, max_len, is_random) 223 | if positive_idx is None: 224 | return None 225 | 226 | positive_a_spans = _get_answer_spans(positive_idx, positives, max_len)[0:max_n_answers] 227 | 228 | answer_starts = [span[0] for span in positive_a_spans] 229 | answer_ends = [span[1] for span in positive_a_spans] 230 | 231 | assert all(s < max_len for s in answer_starts) 232 | assert all(e < max_len for e in answer_ends) 233 | 234 | positive_input_ids = _pad_to_len(positives[positive_idx].sequence_ids, pad_token_id, max_len) 235 | 236 | answer_starts_tensor = torch.zeros((total_size, max_n_answers)).long() 237 | answer_starts_tensor[0, 0 : len(answer_starts)] = torch.tensor(answer_starts) 238 | 239 | answer_ends_tensor = torch.zeros((total_size, max_n_answers)).long() 240 | answer_ends_tensor[0, 0 : len(answer_ends)] = torch.tensor(answer_ends) 241 | 242 | answer_mask = torch.zeros((total_size, max_n_answers), dtype=torch.long) 243 | answer_mask[0, 0 : len(answer_starts)] = torch.tensor([1 for _ in range(len(answer_starts))]) 244 | 245 | positives_selected = [positive_input_ids] 246 | 247 | else: 248 | positives_selected = [] 249 | answer_starts_tensor = None 250 | answer_ends_tensor = None 251 | answer_mask = None 252 | 253 | positives_num = len(positives_selected) 254 | negative_idxs = np.random.permutation(range(len(negatives))) if is_random else range(len(negatives) - positives_num) 255 | 256 | negative_idxs = negative_idxs[: total_size - positives_num] 257 | 258 | negatives_selected = [_pad_to_len(negatives[i].sequence_ids, pad_token_id, max_len) for i in negative_idxs] 259 | negatives_num = len(negatives_selected) 260 | 261 | input_ids = torch.stack([t for t in positives_selected + negatives_selected], dim=0) 262 | 263 | toke_type_ids = _create_token_type_ids(input_ids, sep_token_id, first_segment_ttid) 264 | 265 | if positives_num + negatives_num < total_size: 266 | empty_negatives = [empty_ids.clone().view(1, -1) for _ in range(total_size - (positives_num + negatives_num))] 267 | empty_token_type_ids = [ 268 | empty_ids.clone().view(1, -1) for _ in range(total_size - (positives_num + negatives_num)) 269 | ] 270 | 271 | input_ids = torch.cat([input_ids, *empty_negatives], dim=0) 272 | toke_type_ids = torch.cat([toke_type_ids, *empty_token_type_ids], dim=0) 273 | 274 | return input_ids, answer_starts_tensor, answer_ends_tensor, answer_mask, toke_type_ids 275 | 276 | 277 | def _create_token_type_ids(input_ids: torch.Tensor, sep_token_id: int, first_segment_ttid: int = 0): 278 | 279 | 
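# Editor's note (not part of the original file): _create_token_type_ids assumes
# exactly two [SEP] tokens per row -- everything up to and including the first
# [SEP] gets `first_segment_ttid`, the remainder gets the other segment id:
#
#     >>> ids = torch.tensor([[101, 2054, 102, 7592, 2088, 102]])  # [CLS] q [SEP] ctx [SEP]
#     >>> _create_token_type_ids(ids, sep_token_id=102).tolist()
#     [[0, 0, 0, 1, 1, 1]]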
token_type_ids = torch.full(input_ids.shape, fill_value=0) 280 | # return token_type_ids 281 | sep_tokens_indexes = torch.nonzero(input_ids == sep_token_id) 282 | bsz = input_ids.size(0) 283 | second_ttid = 0 if first_segment_ttid == 1 else 1 284 | 285 | for i in range(bsz): 286 | token_type_ids[i, 0 : sep_tokens_indexes[2 * i, 1] + 1] = first_segment_ttid 287 | token_type_ids[i, sep_tokens_indexes[2 * i, 1] + 1 :] = second_ttid 288 | return token_type_ids 289 | -------------------------------------------------------------------------------- /dpr/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Utilities for general purpose data processing 10 | """ 11 | import json 12 | import logging 13 | import pickle 14 | import random 15 | 16 | import itertools 17 | import math 18 | 19 | import hydra 20 | import jsonlines 21 | import torch 22 | from omegaconf import DictConfig 23 | from torch import Tensor as T 24 | from typing import List, Iterator, Callable, Tuple 25 | 26 | logger = logging.getLogger() 27 | 28 | 29 | def read_serialized_data_from_files(paths: List[str]) -> List: 30 | results = [] 31 | for i, path in enumerate(paths): 32 | with open(path, "rb") as reader: 33 | logger.info("Reading file %s", path) 34 | data = pickle.load(reader) 35 | results.extend(data) 36 | logger.info("Aggregated data size: {}".format(len(results))) 37 | logger.info("Total data size: {}".format(len(results))) 38 | return results 39 | 40 | 41 | def read_data_from_json_files(paths: List[str]) -> List: 42 | results = [] 43 | for i, path in enumerate(paths): 44 | with open(path, "r", encoding="utf-8") as f: 45 | logger.info("Reading file %s" % path) 46 | data = json.load(f) 47 | results.extend(data) 48 | logger.info("Aggregated data size: {}".format(len(results))) 49 | return results 50 | 51 | 52 | def read_data_from_jsonl_files(paths: List[str]) -> List: 53 | results = [] 54 | for i, path in enumerate(paths): 55 | logger.info("Reading file %s" % path) 56 | with jsonlines.open(path, mode="r") as jsonl_reader: 57 | data = [r for r in jsonl_reader] 58 | results.extend(data) 59 | logger.info("Aggregated data size: {}".format(len(results))) 60 | return results 61 | 62 | 63 | def normalize_question(question: str) -> str: 64 | question = question.replace("’", "'") 65 | return question 66 | 67 | 68 | class Tensorizer(object): 69 | """ 70 | Component for all text to model input data conversions and related utility methods 71 | """ 72 | 73 | # Note: title, if present, is supposed to be put before text (i.e. 
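# ---------------------------------------------------------------------------
# Editor's sketch (hypothetical, not part of the original file): Tensorizer's
# text_to_tensor, declared just below, is abstract; a HuggingFace-backed
# implementation (an assumption -- the repo's concrete subclass is defined
# elsewhere) would encode the optional title as the first segment, roughly:

import torch

def text_to_tensor_sketch(hf_tokenizer, text: str, title: str = None, max_len: int = 256):
    # "[CLS] title [SEP] body [SEP]" when a title is given, else "[CLS] body [SEP]"
    if title:
        ids = hf_tokenizer.encode(title, text_pair=text, add_special_tokens=True, max_length=max_len, truncation=True)
    else:
        ids = hf_tokenizer.encode(text, add_special_tokens=True, max_length=max_len, truncation=True)
    return torch.tensor(ids)
# ---------------------------------------------------------------------------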
optional title + document body) 74 | def text_to_tensor( 75 | self, 76 | text: str, 77 | title: str = None, 78 | add_special_tokens: bool = True, 79 | apply_max_len: bool = True, 80 | ): 81 | raise NotImplementedError 82 | 83 | def get_pair_separator_ids(self) -> T: 84 | raise NotImplementedError 85 | 86 | def get_pad_id(self) -> int: 87 | raise NotImplementedError 88 | 89 | def get_attn_mask(self, tokens_tensor: T): 90 | raise NotImplementedError 91 | 92 | def is_sub_word_id(self, token_id: int): 93 | raise NotImplementedError 94 | 95 | def to_string(self, token_ids, skip_special_tokens=True): 96 | raise NotImplementedError 97 | 98 | def set_pad_to_max(self, pad: bool): 99 | raise NotImplementedError 100 | 101 | def get_token_id(self, token: str) -> int: 102 | raise NotImplementedError 103 | 104 | 105 | class RepTokenSelector(object): 106 | def get_positions(self, input_ids: T, tenzorizer: Tensorizer): 107 | raise NotImplementedError 108 | 109 | 110 | class RepStaticPosTokenSelector(RepTokenSelector): 111 | def __init__(self, static_position: int = 0): 112 | self.static_position = static_position 113 | 114 | def get_positions(self, input_ids: T, tenzorizer: Tensorizer): 115 | return self.static_position 116 | 117 | 118 | class RepSpecificTokenSelector(RepTokenSelector): 119 | def __init__(self, token: str = "[CLS]"): 120 | self.token = token 121 | self.token_id = None 122 | 123 | def get_positions(self, input_ids: T, tenzorizer: Tensorizer): 124 | if not self.token_id: 125 | self.token_id = tenzorizer.get_token_id(self.token) 126 | token_indexes = (input_ids == self.token_id).nonzero() 127 | # check if all samples in input_ids has index presence and out a default value otherwise 128 | bsz = input_ids.size(0) 129 | if bsz == token_indexes.size(0): 130 | return token_indexes 131 | 132 | token_indexes_result = [] 133 | found_idx_cnt = 0 134 | for i in range(bsz): 135 | if found_idx_cnt < token_indexes.size(0) and token_indexes[found_idx_cnt][0] == i: 136 | # this samples has the special token 137 | token_indexes_result.append(token_indexes[found_idx_cnt]) 138 | found_idx_cnt += 1 139 | else: 140 | logger.warning("missing special token %s", input_ids[i]) 141 | 142 | token_indexes_result.append( 143 | torch.tensor([i, 0]).to(input_ids.device) 144 | ) # setting 0-th token, i.e. 
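# ---------------------------------------------------------------------------
# Editor's sketch (illustrative, not part of the original file):
# RepSpecificTokenSelector returns one (row, position) index pair per sample;
# samples that lack the special token fall back to position 0 (the CLS slot),
# per the loop above:

import torch

ids = torch.tensor([[7, 5, 9], [7, 7, 7]])   # token id 5 is the "special" one
hits = (ids == 5).nonzero()                  # tensor([[0, 1]]): row 1 has no hit
# after the fallback, the stacked result would be [[0, 1], [1, 0]]
# ---------------------------------------------------------------------------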
CLS for BERT as the special one 145 | token_indexes_result = torch.stack(token_indexes_result, dim=0) 146 | return token_indexes_result 147 | 148 | 149 | DEFAULT_SELECTOR = RepStaticPosTokenSelector() 150 | 151 | 152 | class Dataset(torch.utils.data.Dataset): 153 | def __init__( 154 | self, 155 | selector: DictConfig = None, 156 | special_token: str = None, 157 | shuffle_positives: bool = False, 158 | query_special_suffix: str = None, 159 | encoder_type: str = None, 160 | ): 161 | if selector: 162 | self.selector = hydra.utils.instantiate(selector) 163 | else: 164 | self.selector = DEFAULT_SELECTOR 165 | self.special_token = special_token 166 | self.encoder_type = encoder_type 167 | self.shuffle_positives = shuffle_positives 168 | self.query_special_suffix = query_special_suffix 169 | self.data = [] 170 | 171 | def load_data(self, start_pos: int = -1, end_pos: int = -1): 172 | raise NotImplementedError 173 | 174 | def calc_total_data_len(self): 175 | raise NotImplementedError 176 | 177 | def __len__(self): 178 | return len(self.data) 179 | 180 | def __getitem__(self, index): 181 | raise NotImplementedError 182 | 183 | def _process_query(self, query: str): 184 | # as of now, always normalize query 185 | query = normalize_question(query) 186 | if self.query_special_suffix and not query.endswith(self.query_special_suffix): 187 | query += self.query_special_suffix 188 | 189 | return query 190 | 191 | 192 | # TODO: to be fully replaced with LocalSharded{...}. Keeping it only for old results reproduction compatibility 193 | class ShardedDataIterator(object): 194 | """ 195 | General purpose data iterator to be used for Pytorch's DDP mode where every node should handle its own part of 196 | the data. 197 | Instead of cutting data shards by their min size, it sets the amount of iterations by the maximum shard size. 198 | It fills the extra sample by just taking first samples in a shard. 199 | It can also optionally enforce identical batch size for all iterations (might be useful for DP mode). 
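    Sharding arithmetic (illustrative example added by the editor): with
    total_size=10, num_shards=3 and batch_size=2, samples_per_shard =
    ceil(10 / 3) = 4, so shard 0 reads indices [0:4), shard 1 [4:8) and
    shard 2 [8:10), topping the short shard up from its own head when
    strict_batch_size is set; max_iterations = ceil(4 / 2) = 2.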
200 |     """
201 | 
202 |     def __init__(
203 |         self,
204 |         dataset: Dataset,
205 |         shard_id: int = 0,
206 |         num_shards: int = 1,
207 |         batch_size: int = 1,
208 |         shuffle=True,
209 |         shuffle_seed: int = 0,
210 |         offset: int = 0,
211 |         strict_batch_size: bool = False,
212 |     ):
213 | 
214 |         self.dataset = dataset
215 |         self.shard_id = shard_id
216 |         self.num_shards = num_shards
217 |         self.iteration = offset  # to track in-shard iteration status
218 |         self.shuffle = shuffle
219 |         self.batch_size = batch_size
220 |         self.shuffle_seed = shuffle_seed
221 |         self.strict_batch_size = strict_batch_size
222 |         self.shard_start_idx = -1
223 |         self.shard_end_idx = -1
224 |         self.max_iterations = 0
225 | 
226 |     def calculate_shards(self):
227 |         logger.info("Calculating shard positions")
228 |         shards_num = max(self.num_shards, 1)
229 |         shard_id = max(self.shard_id, 0)
230 | 
231 |         total_size = self.dataset.calc_total_data_len()
232 |         samples_per_shard = math.ceil(total_size / shards_num)
233 | 
234 |         self.shard_start_idx = shard_id * samples_per_shard
235 |         self.shard_end_idx = min(self.shard_start_idx + samples_per_shard, total_size)
236 | 
237 |         if self.strict_batch_size:
238 |             self.max_iterations = math.ceil(samples_per_shard / self.batch_size)
239 |         else:
240 |             self.max_iterations = int(samples_per_shard / self.batch_size)
241 | 
242 |         logger.info(
243 |             "samples_per_shard=%d, shard_start_idx=%d, shard_end_idx=%d, max_iterations=%d",
244 |             samples_per_shard,
245 |             self.shard_start_idx,
246 |             self.shard_end_idx,
247 |             self.max_iterations,
248 |         )
249 | 
250 |     def load_data(self):
251 |         self.calculate_shards()
252 |         self.dataset.load_data()
253 |         logger.info("Sharded dataset data %d", len(self.dataset))
254 | 
255 |     def total_data_len(self) -> int:
256 |         return len(self.dataset)
257 | 
258 |     def iterations_num(self) -> int:
259 |         return self.max_iterations - self.iteration
260 | 
261 |     def max_iterations_num(self) -> int:
262 |         return self.max_iterations
263 | 
264 |     def get_iteration(self) -> int:
265 |         return self.iteration
266 | 
267 |     def apply(self, visitor_func: Callable):
268 |         for sample in self.dataset:
269 |             visitor_func(sample)
270 | 
271 |     def get_shard_indices(self, epoch: int):
272 |         indices = list(range(len(self.dataset)))
273 |         if self.shuffle:
274 |             # to be able to resume, same shuffling should be used when starting from a failed/stopped iteration
275 |             epoch_rnd = random.Random(self.shuffle_seed + epoch)
276 |             epoch_rnd.shuffle(indices)
277 |         shard_indices = indices[self.shard_start_idx : self.shard_end_idx]
278 |         return shard_indices
279 | 
280 |     # TODO: merge with iterate_ds_sampled_data
281 |     def iterate_ds_data(self, epoch: int = 0) -> Iterator[List]:
282 |         # if resuming iteration somewhere in the middle of epoch, one needs to adjust max_iterations
283 |         max_iterations = self.max_iterations - self.iteration
284 |         shard_indices = self.get_shard_indices(epoch)
285 | 
286 |         for i in range(self.iteration * self.batch_size, len(shard_indices), self.batch_size):
287 |             items_idxs = shard_indices[i : i + self.batch_size]
288 |             if self.strict_batch_size and len(items_idxs) < self.batch_size:
289 |                 logger.debug("Extending batch to max size")
290 |                 items_idxs.extend(shard_indices[0 : self.batch_size - len(items_idxs)])
291 |             self.iteration += 1
292 |             items = [self.dataset[idx] for idx in items_idxs]
293 |             yield items
294 | 
295 |         # some shards may be done iterating while the others are at the last batch. Just return the first batch
296 |         while self.iteration < max_iterations:
297 |             logger.debug("Fulfilling non complete shard={}".format(self.shard_id))
298 |             self.iteration += 1
299 |             items_idxs = shard_indices[0 : self.batch_size]
300 |             items = [self.dataset[idx] for idx in items_idxs]
301 |             yield items
302 | 
303 |         logger.info("Finished iterating, iteration={}, shard={}".format(self.iteration, self.shard_id))
304 |         # reset the iteration status
305 |         self.iteration = 0
306 | 
307 |     def iterate_ds_sampled_data(self, num_iterations: int, epoch: int = 0) -> Iterator[List]:
308 |         self.iteration = 0
309 |         shard_indices = self.get_shard_indices(epoch)
310 |         cycle_it = itertools.cycle(shard_indices)
311 |         for i in range(num_iterations):
312 |             items_idxs = [next(cycle_it) for _ in range(self.batch_size)]
313 |             self.iteration += 1
314 |             items = [self.dataset[idx] for idx in items_idxs]
315 |             yield items
316 | 
317 |         logger.info("Finished iterating, iteration={}, shard={}".format(self.iteration, self.shard_id))
318 |         # TODO: reset the iteration status?
319 |         self.iteration = 0
320 | 
321 |     def get_dataset(self) -> Dataset:
322 |         return self.dataset
323 | 
324 | 
325 | class LocalShardedDataIterator(ShardedDataIterator):
326 |     # uses only one shard after the initial dataset load to reduce memory footprint
327 |     def load_data(self):
328 |         self.calculate_shards()
329 |         self.dataset.load_data(start_pos=self.shard_start_idx, end_pos=self.shard_end_idx)
330 |         logger.info("Sharded dataset data %d", len(self.dataset))
331 | 
332 |     def get_shard_indices(self, epoch: int):
333 |         indices = list(range(len(self.dataset)))
334 |         if self.shuffle:
335 |             # to be able to resume, same shuffling should be used when starting from a failed/stopped iteration
336 |             epoch_rnd = random.Random(self.shuffle_seed + epoch)
337 |             epoch_rnd.shuffle(indices)
338 |         shard_indices = indices
339 |         return shard_indices
340 | 
341 | 
342 | class MultiSetDataIterator(object):
343 |     """
344 |     Iterator over multiple data sources. Useful when all samples in a single batch should come from the same dataset.
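    Sampling-rates example (illustrative, added by the editor): if the two
    underlying shard iterators report 100 and 50 max iterations and
    sampling_rates=[0.5, 1.0], the epoch budget is int(100 * 0.5) +
    int(50 * 1.0) = 100 batches, and the per-batch source order is shuffled
    with a seed derived from shuffle_seed + epoch.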
345 | """ 346 | 347 | def __init__( 348 | self, 349 | datasets: List[ShardedDataIterator], 350 | shuffle_seed: int = 0, 351 | shuffle=True, 352 | sampling_rates: List = [], 353 | rank: int = 0, 354 | ): 355 | # randomized data loading to avoid file system congestion 356 | ds_list_copy = [ds for ds in datasets] 357 | rnd = random.Random(rank) 358 | rnd.shuffle(ds_list_copy) 359 | [ds.load_data() for ds in ds_list_copy] 360 | 361 | self.iterables = datasets 362 | data_lengths = [it.total_data_len() for it in datasets] 363 | self.total_data = sum(data_lengths) 364 | logger.info("rank=%d; Multi set data sizes %s", rank, data_lengths) 365 | logger.info("rank=%d; Multi set total data %s", rank, self.total_data) 366 | logger.info("rank=%d; Multi set sampling_rates %s", rank, sampling_rates) 367 | self.shuffle_seed = shuffle_seed 368 | self.shuffle = shuffle 369 | self.iteration = 0 370 | self.rank = rank 371 | 372 | if sampling_rates: 373 | self.max_its_pr_ds = [int(ds.max_iterations_num() * sampling_rates[i]) for i, ds in enumerate(datasets)] 374 | else: 375 | self.max_its_pr_ds = [ds.max_iterations_num() for ds in datasets] 376 | 377 | self.max_iterations = sum(self.max_its_pr_ds) 378 | logger.info("rank=%d; Multi set max_iterations per dataset %s", rank, self.max_its_pr_ds) 379 | logger.info("rank=%d; Multi set max_iterations %d", rank, self.max_iterations) 380 | 381 | def total_data_len(self) -> int: 382 | return self.total_data 383 | 384 | def get_max_iterations(self): 385 | return self.max_iterations 386 | 387 | def iterate_ds_data(self, epoch: int = 0) -> Iterator[Tuple[List, int]]: 388 | 389 | logger.info("rank=%d; Iteration start", self.rank) 390 | logger.info( 391 | "rank=%d; Multi set iteration: iteration ptr per set: %s", 392 | self.rank, 393 | [it.get_iteration() for it in self.iterables], 394 | ) 395 | 396 | data_src_indices = [] 397 | iterators = [] 398 | for source, src_its in enumerate(self.max_its_pr_ds): 399 | logger.info( 400 | "rank=%d; Multi set iteration: source %d, batches to be taken: %s", 401 | self.rank, 402 | source, 403 | src_its, 404 | ) 405 | data_src_indices.extend([source] * src_its) 406 | 407 | iterators.append(self.iterables[source].iterate_ds_sampled_data(src_its, epoch=epoch)) 408 | 409 | if self.shuffle: 410 | # to be able to resume, same shuffling should be used when starting from a failed/stopped iteration 411 | epoch_rnd = random.Random(self.shuffle_seed + epoch) 412 | epoch_rnd.shuffle(data_src_indices) 413 | 414 | logger.info("rank=%d; data_src_indices len=%d", self.rank, len(data_src_indices)) 415 | for i, source_idx in enumerate(data_src_indices): 416 | it = iterators[source_idx] 417 | next_item = next(it, None) 418 | if next_item is not None: 419 | self.iteration += 1 420 | yield (next_item, source_idx) 421 | else: 422 | logger.warning("rank=%d; Next item in the source %s is None", self.rank, source_idx) 423 | 424 | logger.info("rank=%d; last iteration %d", self.rank, self.iteration) 425 | 426 | logger.info( 427 | "rank=%d; Multi set iteration finished: iteration per set: %s", 428 | self.rank, 429 | [it.iteration for it in self.iterables], 430 | ) 431 | [next(it, None) for it in iterators] 432 | 433 | # TODO: clear iterators in some non-hacky way 434 | for it in self.iterables: 435 | it.iteration = 0 436 | logger.info( 437 | "rank=%d; Multi set iteration finished after next: iteration per set: %s", 438 | self.rank, 439 | [it.iteration for it in self.iterables], 440 | ) 441 | # reset the iteration status 442 | self.iteration = 0 443 | 444 | def 
get_iteration(self) -> int: 445 | return self.iteration 446 | 447 | def get_dataset(self, ds_id: int) -> Dataset: 448 | return self.iterables[ds_id].get_dataset() 449 | 450 | def get_datasets(self) -> List[Dataset]: 451 | return [it.get_dataset() for it in self.iterables] 452 | -------------------------------------------------------------------------------- /dpr/data/retriever_data.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import csv 4 | import json 5 | import logging 6 | import pickle 7 | from typing import Dict, List 8 | 9 | import hydra 10 | import jsonlines 11 | import torch 12 | from omegaconf import DictConfig 13 | 14 | from dpr.data.biencoder_data import ( 15 | BiEncoderPassage, 16 | normalize_passage, 17 | get_dpr_files, 18 | read_nq_tables_jsonl, 19 | split_tables_to_chunks, 20 | ) 21 | from dpr.data.biencoder_data import ( 22 | BiEncoderTable, 23 | get_nq_table_files, 24 | get_processed_table, 25 | get_processed_table_wiki, 26 | get_processed_table_wqt, 27 | ) 28 | 29 | from dpr.utils.data_utils import normalize_question 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | TableChunk = collections.namedtuple("TableChunk", ["text", "title", "table_id"]) 34 | 35 | 36 | class QASample: 37 | def __init__(self, query: str, id, answers: List[str]): 38 | self.query = query 39 | self.id = id 40 | self.answers = answers 41 | 42 | 43 | class RetrieverData(torch.utils.data.Dataset): 44 | def __init__(self, file: str): 45 | """ 46 | :param file: - real file name or the resource name as they are defined in download_data.py 47 | """ 48 | self.file = file 49 | self.data_files = [] 50 | 51 | def load_data(self): 52 | self.data_files = get_dpr_files(self.file) 53 | assert ( 54 | len(self.data_files) == 1 55 | ), "RetrieverData source currently works with single files only. 
Files specified: {}".format(self.data_files) 56 | self.file = self.data_files[0] 57 | 58 | 59 | class QASrc(RetrieverData): 60 | def __init__( 61 | self, 62 | file: str, 63 | selector: DictConfig = None, 64 | special_query_token: str = None, 65 | query_special_suffix: str = None, 66 | ): 67 | super().__init__(file) 68 | self.data = None 69 | self.selector = hydra.utils.instantiate(selector) if selector else None 70 | self.special_query_token = special_query_token 71 | self.query_special_suffix = query_special_suffix 72 | 73 | def __getitem__(self, index) -> QASample: 74 | return self.data[index] 75 | 76 | def __len__(self): 77 | return len(self.data) 78 | 79 | def _process_question(self, question: str): 80 | # as of now, always normalize query 81 | question = normalize_question(question) 82 | if self.query_special_suffix and not question.endswith(self.query_special_suffix): 83 | question += self.query_special_suffix 84 | return question 85 | 86 | 87 | class CsvQASrc(QASrc): 88 | def __init__( 89 | self, 90 | file: str, 91 | question_col: int = 0, 92 | answers_col: int = 1, 93 | id_col: int = -1, 94 | selector: DictConfig = None, 95 | special_query_token: str = None, 96 | query_special_suffix: str = None, 97 | data_range_start: int = -1, 98 | data_size: int = -1, 99 | ): 100 | super().__init__(file, selector, special_query_token, query_special_suffix) 101 | self.question_col = question_col 102 | self.answers_col = answers_col 103 | self.id_col = id_col 104 | self.data_range_start = data_range_start 105 | self.data_size = data_size 106 | 107 | def load_data(self): 108 | super().load_data() 109 | data = [] 110 | start = self.data_range_start 111 | # size = self.data_size 112 | samples_count = 0 113 | # TODO: optimize 114 | with open(self.file) as ifile: 115 | reader = csv.reader(ifile, delimiter="\t") 116 | for row in reader: 117 | question = row[self.question_col] 118 | answers = eval(row[self.answers_col]) 119 | id = None 120 | if self.id_col >= 0: 121 | id = row[self.id_col] 122 | samples_count += 1 123 | # if start !=-1 and samples_count<=start: 124 | # continue 125 | data.append(QASample(self._process_question(question), id, answers)) 126 | 127 | if start != -1: 128 | end = start + self.data_size if self.data_size != -1 else -1 129 | logger.info("Selecting dataset range [%s,%s]", start, end) 130 | self.data = data[start:end] if end != -1 else data[start:] 131 | else: 132 | self.data = data 133 | 134 | 135 | class JsonlQASrc(QASrc): 136 | def __init__( 137 | self, 138 | file: str, 139 | selector: DictConfig = None, 140 | question_attr: str = "question", 141 | answers_attr: str = "answers", 142 | id_attr: str = "id", 143 | special_query_token: str = None, 144 | query_special_suffix: str = None, 145 | ): 146 | super().__init__(file, selector, special_query_token, query_special_suffix) 147 | self.question_attr = question_attr 148 | self.answers_attr = answers_attr 149 | self.id_attr = id_attr 150 | 151 | def load_data(self): 152 | super().load_data() 153 | data = [] 154 | with jsonlines.open(self.file, mode="r") as jsonl_reader: 155 | for jline in jsonl_reader: 156 | question = jline[self.question_attr] 157 | answers = jline[self.answers_attr] if self.answers_attr in jline else [] 158 | id = None 159 | if self.id_attr in jline: 160 | id = jline[self.id_attr] 161 | data.append(QASample(self._process_question(question), id, answers)) 162 | self.data = data 163 | 164 | 165 | class KiltCsvQASrc(CsvQASrc): 166 | def __init__( 167 | self, 168 | file: str, 169 | kilt_gold_file: str, 170 | 
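# ---------------------------------------------------------------------------
# Editor's note (hypothetical sample, not part of the original file): with the
# default attribute names, JsonlQASrc above expects one JSON object per line:
#
#     {"id": "q-42", "question": "who wrote hamlet", "answers": ["William Shakespeare"]}
#
# "answers" defaults to an empty list when absent, and "id" is optional.
# ---------------------------------------------------------------------------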
question_col: int = 0, 171 | answers_col: int = 1, 172 | id_col: int = -1, 173 | selector: DictConfig = None, 174 | special_query_token: str = None, 175 | query_special_suffix: str = None, 176 | data_range_start: int = -1, 177 | data_size: int = -1, 178 | ): 179 | super().__init__( 180 | file, 181 | question_col, 182 | answers_col, 183 | id_col, 184 | selector, 185 | special_query_token, 186 | query_special_suffix, 187 | data_range_start, 188 | data_size, 189 | ) 190 | self.kilt_gold_file = kilt_gold_file 191 | 192 | 193 | class KiltJsonlQASrc(JsonlQASrc): 194 | def __init__( 195 | self, 196 | file: str, 197 | kilt_gold_file: str, 198 | question_attr: str = "input", 199 | answers_attr: str = "answer", 200 | id_attr: str = "id", 201 | selector: DictConfig = None, 202 | special_query_token: str = None, 203 | query_special_suffix: str = None, 204 | ): 205 | super().__init__( 206 | file, 207 | selector, 208 | question_attr, 209 | answers_attr, 210 | id_attr, 211 | special_query_token, 212 | query_special_suffix, 213 | ) 214 | self.kilt_gold_file = kilt_gold_file 215 | 216 | def load_data(self): 217 | super().load_data() 218 | data = [] 219 | with jsonlines.open(self.file, mode="r") as jsonl_reader: 220 | for jline in jsonl_reader: 221 | question = jline[self.question_attr] 222 | out = jline["output"] 223 | answers = [o["answer"] for o in out if "answer" in o] 224 | id = None 225 | if self.id_attr in jline: 226 | id = jline[self.id_attr] 227 | data.append(QASample(self._process_question(question), id, answers)) 228 | self.data = data 229 | 230 | 231 | class TTS_ASR_QASrc(QASrc): 232 | def __init__(self, file: str, trans_file: str): 233 | super().__init__(file) 234 | self.trans_file = trans_file 235 | 236 | def load_data(self): 237 | super().load_data() 238 | orig_data_dict = {} 239 | with open(self.file, "r") as ifile: 240 | reader = csv.reader(ifile, delimiter="\t") 241 | id = 0 242 | for row in reader: 243 | question = row[0] 244 | answers = eval(row[1]) 245 | orig_data_dict[id] = (question, answers) 246 | id += 1 247 | data = [] 248 | with open(self.trans_file, "r") as tfile: 249 | reader = csv.reader(tfile, delimiter="\t") 250 | for r in reader: 251 | row_str = r[0] 252 | idx = row_str.index("(None-") 253 | q_id = int(row_str[idx + len("(None-") : -1]) 254 | orig_data = orig_data_dict[q_id] 255 | answers = orig_data[1] 256 | q = row_str[:idx].strip().lower() 257 | data.append(QASample(q, idx, answers)) 258 | self.data = data 259 | 260 | 261 | class CsvCtxSrc(RetrieverData): 262 | def __init__( 263 | self, 264 | file: str, 265 | id_col: int = 0, 266 | text_col: int = 1, 267 | title_col: int = 2, 268 | id_prefix: str = None, 269 | normalize: bool = False, 270 | ): 271 | super().__init__(file) 272 | self.text_col = text_col 273 | self.title_col = title_col 274 | self.id_col = id_col 275 | self.id_prefix = id_prefix 276 | self.normalize = normalize 277 | 278 | def load_data_to(self, ctxs: Dict[object, BiEncoderPassage]): 279 | super().load_data() 280 | logger.info("Reading file %s", self.file) 281 | with open(self.file) as ifile: 282 | reader = csv.reader(ifile, delimiter="\t") 283 | for row in reader: 284 | # for row in ifile: 285 | # row = row.strip().split("\t") 286 | if row[self.id_col] == "id": 287 | continue 288 | if self.id_prefix: 289 | sample_id = self.id_prefix + str(row[self.id_col]) 290 | else: 291 | sample_id = row[self.id_col] 292 | passage = row[self.text_col].strip('"') 293 | if self.normalize: 294 | passage = normalize_passage(passage) 295 | ctxs[sample_id] = 
BiEncoderPassage(passage, row[self.title_col]) 296 | 297 | 298 | class KiltCsvCtxSrc(CsvCtxSrc): 299 | def __init__( 300 | self, 301 | file: str, 302 | mapping_file: str, 303 | id_col: int = 0, 304 | text_col: int = 1, 305 | title_col: int = 2, 306 | id_prefix: str = None, 307 | normalize: bool = False, 308 | ): 309 | super().__init__(file, id_col, text_col, title_col, id_prefix, normalize=normalize) 310 | self.mapping_file = mapping_file 311 | 312 | def convert_to_kilt(self, kilt_gold_file, dpr_output, kilt_out_file): 313 | logger.info("Converting to KILT format file: %s", dpr_output) 314 | 315 | with open(dpr_output, "rt") as fin: 316 | dpr_output = json.load(fin) 317 | 318 | with jsonlines.open(kilt_gold_file, "r") as reader: 319 | kilt_gold_file = list(reader) 320 | assert len(kilt_gold_file) == len(dpr_output) 321 | map_path = self.mapping_file 322 | with open(map_path, "rb") as fin: 323 | mapping = pickle.load(fin) 324 | 325 | with jsonlines.open(kilt_out_file, mode="w") as writer: 326 | for dpr_entry, kilt_gold_entry in zip(dpr_output, kilt_gold_file): 327 | # assert dpr_entry["question"] == kilt_gold_entry["input"] 328 | provenance = [] 329 | for ctx in dpr_entry["ctxs"]: 330 | wikipedia_id, end_paragraph_id = mapping[int(ctx["id"])] 331 | provenance.append( 332 | { 333 | "wikipedia_id": wikipedia_id, 334 | "end_paragraph_id": end_paragraph_id, 335 | } 336 | ) 337 | kilt_entry = { 338 | "id": kilt_gold_entry["id"], 339 | "input": kilt_gold_entry["input"], # dpr_entry["question"], 340 | "output": [{"provenance": provenance}], 341 | } 342 | writer.write(kilt_entry) 343 | 344 | logger.info("Saved KILT formatted results to: %s", kilt_out_file) 345 | 346 | 347 | class JsonlTablesCtxSrc(object): 348 | def __init__( 349 | self, 350 | file: str, 351 | tables_chunk_sz: int = 100, 352 | split_type: str = "type1", 353 | id_prefix: str = None, 354 | ): 355 | self.tables_chunk_sz = tables_chunk_sz 356 | self.split_type = split_type 357 | self.file = file 358 | self.id_prefix = id_prefix 359 | 360 | def load_data_to(self, ctxs: Dict): 361 | docs = {} 362 | logger.info("Parsing Tables data from: %s", self.file) 363 | tables_dict = read_nq_tables_jsonl(self.file) 364 | table_chunks = split_tables_to_chunks(tables_dict, self.tables_chunk_sz, split_type=self.split_type) 365 | for chunk in table_chunks: 366 | sample_id = self.id_prefix + str(chunk[0]) 367 | docs[sample_id] = TableChunk(chunk[1], chunk[2], chunk[3]) 368 | logger.info("Loaded %d tables chunks", len(docs)) 369 | ctxs.update(docs) 370 | 371 | 372 | # %% NQ-Table 373 | 374 | # [table_sources] 375 | class JsonlNQTablesCtxSrc(object): # need a long time to process 376 | """To load NQ tables from one jsonl file. """ 377 | 378 | def __init__(self, file: str, id_prefix: str = None,): 379 | self.file = file 380 | self.data_files = [] 381 | 382 | self.id_prefix = id_prefix 383 | 384 | def load_data(self): 385 | self.data_files = get_nq_table_files(self.file) 386 | assert (len(self.data_files) == 1), \ 387 | "JsonTablesCtxSrc works with single files only. Files specified: {}".format(self.data_files) 388 | self.file = self.data_files[0] 389 | 390 | def load_data_to(self, ctxs: Dict[object, BiEncoderTable], cfg: DictConfig): 391 | """Load dataset into the `ctxs` argument. 
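    Each jsonl line is parsed into a raw table dict, linearized with
    get_processed_table(...) under the size limits carried by `cfg`, and
    stored under `id_prefix + tableId`. Illustratively (hypothetical values
    added by the editor), `cfg` carries the linearization knobs used below:

        row_selection: first
        max_cell_num: 64
        max_words: 128
        max_words_per_header: 8
        max_words_per_cell: 8
        max_cell_num_per_row: 16
        header_delimiter: " | "
        cell_delimiter: " | "
        row_delimiter: " . "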
""" 392 | self.load_data() 393 | logger.info("Loading NQ-Table data from: %s", self.file) 394 | with open(self.file, 'r') as fr: 395 | dataset = [json.loads(line.strip()) for line in fr] 396 | 397 | for i, sample in enumerate(dataset): 398 | if self.id_prefix: 399 | sample_id = self.id_prefix + sample['tableId'] 400 | # sample_id = self.id_prefix + str(i) 401 | else: 402 | sample_id = sample['tableId'] 403 | # sample_id = str(i) 404 | 405 | processed_table = get_processed_table( 406 | table=sample, 407 | row_selection=cfg.row_selection, 408 | max_cell_num=cfg.max_cell_num, 409 | max_words=cfg.max_words, 410 | max_words_per_header=cfg.max_words_per_header, 411 | max_words_per_cell=cfg.max_words_per_cell, 412 | max_cell_num_per_row=cfg.max_cell_num_per_row, 413 | header_delimiter=cfg.header_delimiter, 414 | cell_delimiter=cfg.cell_delimiter, 415 | row_delimiter=cfg.row_delimiter, 416 | ) 417 | ctxs[sample_id] = processed_table 418 | 419 | logger.info("Loaded %d tables", len(ctxs)) 420 | 421 | 422 | # [table_sources] 423 | class JsonlNqtCtxSrc(object): # pre-processed version 424 | """class to load processed nq-tables from jsonl files. """ 425 | def __init__(self, file: str, id_prefix: str = 'nqt:'): 426 | self.file = file 427 | self.id_prefix = id_prefix 428 | 429 | self.data_files = [] 430 | 431 | def load_data(self): 432 | self.data_files = get_nq_table_files(self.file) 433 | assert (len(self.data_files) == 1), \ 434 | "JsonTablesCtxSrc works with single files only. Files specified: {}".format(self.data_files) 435 | self.file = self.data_files[0] 436 | 437 | def load_data_to(self, ctxs: Dict[object, BiEncoderTable], cfg: DictConfig): 438 | """Load data from self.file to `ctxs`. """ 439 | self.load_data() 440 | logger.info("Loading NQ-Table data from: %s", self.file) 441 | 442 | with open(self.file, 'r') as fr: 443 | dataset = [json.loads(line.strip()) for line in fr] 444 | # sample: {'id': str, 'title': str, 'cells': [{'text': str, 'row_idx': 0, 'col_idx': 0}, ..., {}]} 445 | for i, sample in enumerate(dataset): 446 | if self.id_prefix: sample_id = self.id_prefix + sample['id'] 447 | else: sample_id = sample['id'] 448 | ctxs[sample_id] = BiEncoderTable( 449 | cells=sample['cells'], title=sample['title'], 450 | ) 451 | logger.info("Loaded %d tables", len(ctxs)) 452 | 453 | 454 | # [table_retrieval] 455 | class JsonlQASrcTable(QASrc): 456 | def __init__( 457 | self, 458 | file: str, 459 | selector: DictConfig = None, 460 | question_attr: str = "originalText", 461 | answers_attr: str = "answers", 462 | id_attr: str = "id", 463 | special_query_token: str = None, 464 | query_special_suffix: str = None, 465 | ): 466 | super().__init__(file, selector, special_query_token, query_special_suffix) 467 | self.question_attr = question_attr 468 | self.answers_attr = answers_attr 469 | self.id_attr = id_attr 470 | 471 | def _load_data(self): 472 | self.data_files = get_nq_table_files(self.file) 473 | assert (len(self.data_files) == 1), \ 474 | "JsonTablesCtxSrc works with single files only. 
Files specified: {}".format(self.data_files) 475 | self.file = self.data_files[0] 476 | 477 | def load_data(self): 478 | self._load_data() 479 | data = [] 480 | with jsonlines.open(self.file, mode="r") as jsonl_reader: 481 | for jline in jsonl_reader: 482 | qa = jline['questions'][0] 483 | question = qa['originalText'] 484 | answers = qa['answer']['answerTexts'] 485 | id = qa['id'] 486 | data.append(QASample(self._process_question(question), id, answers)) 487 | self.data = data 488 | -------------------------------------------------------------------------------- /dpr/models/biencoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | BiEncoder component + loss function for 'all-in-batch' training 10 | """ 11 | 12 | import collections 13 | import logging 14 | import random 15 | from typing import Tuple, List 16 | 17 | import numpy as np 18 | import torch 19 | import torch.nn.functional as F 20 | from torch import Tensor as T 21 | from torch import nn 22 | 23 | from dpr.data.biencoder_data import BiEncoderSample 24 | from dpr.data.table_data import BiEncoderTable, prepare_table_ctx_inputs 25 | from dpr.utils.data_utils import Tensorizer 26 | from dpr.utils.model_utils import CheckpointState 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | BiEncoderBatch = collections.namedtuple( 31 | "BiENcoderInput", 32 | [ 33 | "question_ids", 34 | "question_segments", 35 | "context_ids", 36 | "ctx_segments", 37 | "is_positive", 38 | "hard_negatives", 39 | "encoder_type", 40 | ], 41 | ) 42 | # TODO: it is only used by _select_span_with_token. Move them to utils 43 | rnd = random.Random(0) 44 | 45 | 46 | BiEncoderBatchTable = collections.namedtuple( 47 | "BiEncoderInputTable", 48 | [ 49 | "question_ids", 50 | "question_segments", 51 | "context_ids", 52 | "ctx_segments", 53 | "ctx_attn_masks", 54 | "is_positive", 55 | "hard_negatives", 56 | "encoder_type", 57 | ], 58 | ) 59 | 60 | BiEncoderBatchTableAuxEmb = collections.namedtuple( 61 | "BiEncoderInputTable", 62 | [ 63 | "question_ids", 64 | "question_segments", 65 | "context_ids", 66 | "row_ids", 67 | "column_ids", 68 | "ctx_segments", 69 | "is_positive", 70 | "hard_negatives", 71 | "encoder_type", 72 | ], 73 | ) 74 | 75 | 76 | TableBiEncoderBatch = collections.namedtuple( 77 | "TableBiEncoderBatch", 78 | [ 79 | "question_ids", "question_segments", 80 | "context_ids", "ctx_segments", "ctx_attn_masks", 81 | "ctx_row_ids", "ctx_column_ids", 82 | "is_positive", "hard_negatives", "encoder_type", 83 | ], 84 | ) 85 | 86 | 87 | def dot_product_scores(q_vectors: T, ctx_vectors: T) -> T: 88 | """ 89 | calculates q->ctx scores for every row in ctx_vector 90 | :param q_vector: 91 | :param ctx_vector: 92 | :return: 93 | """ 94 | # q_vector: n1 x D, ctx_vectors: n2 x D, result n1 x n2 95 | r = torch.matmul(q_vectors, torch.transpose(ctx_vectors, 0, 1)) 96 | return r 97 | 98 | 99 | def cosine_scores(q_vector: T, ctx_vectors: T): 100 | # q_vector: n1 x D, ctx_vectors: n2 x D, result n1 x n2 101 | return F.cosine_similarity(q_vector, ctx_vectors, dim=1) 102 | 103 | 104 | class BiEncoder(nn.Module): 105 | """Bi-Encoder model component. 
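    In-batch negatives (illustrative example added by the editor):
    dot_product_scores(q, ctx) above yields an (n_questions x n_contexts)
    score matrix; with 2 questions and 4 contexts whose positives sit at
    columns 0 and 2,

        scores = torch.matmul(q, ctx.t())                                  # (2, 4)
        loss = F.nll_loss(F.log_softmax(scores, dim=1), torch.tensor([0, 2]))

    reproduces the NLL objective implemented by BiEncoderNllLoss further below.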
Encapsulates query/question and context/passage encoders.""" 106 | 107 | def __init__( 108 | self, 109 | question_model: nn.Module, 110 | ctx_model: nn.Module, 111 | fix_q_encoder: bool = False, 112 | fix_ctx_encoder: bool = False, 113 | ): 114 | super(BiEncoder, self).__init__() 115 | self.question_model = question_model 116 | self.ctx_model = ctx_model 117 | self.fix_q_encoder = fix_q_encoder 118 | self.fix_ctx_encoder = fix_ctx_encoder 119 | 120 | @staticmethod 121 | def get_representation( 122 | sub_model: nn.Module, 123 | ids: T, 124 | segments: T, 125 | attn_mask: T = None, 126 | row_ids: T = None, 127 | column_ids: T = None, 128 | fix_encoder: bool = False, 129 | representation_token_pos=0, 130 | ) -> Tuple[T, T, T]: 131 | sequence_output = None 132 | pooled_output = None 133 | hidden_states = None 134 | if ids is not None: 135 | if fix_encoder: 136 | with torch.no_grad(): 137 | sequence_output, pooled_output, hidden_states = sub_model( 138 | input_ids=ids, 139 | token_type_ids=segments, 140 | attention_mask=attn_mask, 141 | row_ids=row_ids, 142 | column_ids=column_ids, 143 | representation_token_pos=representation_token_pos, 144 | ) 145 | 146 | if sub_model.training: 147 | sequence_output.requires_grad_(requires_grad=True) 148 | pooled_output.requires_grad_(requires_grad=True) 149 | else: 150 | sequence_output, pooled_output, hidden_states = sub_model( 151 | input_ids=ids, 152 | token_type_ids=segments, 153 | attention_mask=attn_mask, 154 | row_ids=row_ids, 155 | column_ids=column_ids, 156 | representation_token_pos=representation_token_pos, 157 | ) 158 | 159 | return sequence_output, pooled_output, hidden_states 160 | 161 | def forward( 162 | self, 163 | question_ids: T, 164 | question_segments: T, 165 | question_attn_mask: T, 166 | context_ids: T, 167 | ctx_segments: T, 168 | ctx_attn_mask: T = None, 169 | ctx_row_ids: T = None, 170 | ctx_column_ids: T = None, 171 | encoder_type: str = None, 172 | representation_token_pos=0, 173 | ) -> Tuple[T, T]: 174 | q_encoder = self.question_model if encoder_type is None or encoder_type == "question" else self.ctx_model 175 | _q_seq, q_pooled_out, _q_hidden = self.get_representation( 176 | q_encoder, 177 | question_ids, 178 | question_segments, 179 | question_attn_mask, 180 | self.fix_q_encoder, 181 | representation_token_pos=representation_token_pos, 182 | ) 183 | 184 | ctx_encoder = self.ctx_model if encoder_type is None or encoder_type == "ctx" else self.question_model 185 | _ctx_seq, ctx_pooled_out, _ctx_hidden = self.get_representation( 186 | ctx_encoder, context_ids, ctx_segments, ctx_attn_mask, ctx_row_ids, ctx_column_ids, 187 | self.fix_ctx_encoder 188 | ) 189 | 190 | return q_pooled_out, ctx_pooled_out 191 | 192 | def create_biencoder_input( 193 | self, 194 | samples: List[BiEncoderSample], 195 | tensorizer: Tensorizer, 196 | insert_title: bool, 197 | num_hard_negatives: int = 0, 198 | num_other_negatives: int = 0, 199 | shuffle: bool = True, 200 | shuffle_positives: bool = False, 201 | hard_neg_fallback: bool = True, 202 | query_token: str = None, 203 | ) -> BiEncoderBatch: 204 | """ 205 | Creates a batch of the biencoder training tuple. 
206 | :param samples: list of BiEncoderSample-s to create the batch for 207 | :param tensorizer: components to create model input tensors from a text sequence 208 | :param insert_title: enables title insertion at the beginning of the context sequences 209 | :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools) 210 | :param num_other_negatives: amount of other negatives per question (taken from samples' pools) 211 | :param shuffle: shuffles negative passages pools 212 | :param shuffle_positives: shuffles positive passages pools 213 | :return: BiEncoderBatch tuple 214 | """ 215 | question_tensors = [] 216 | ctx_tensors = [] 217 | positive_ctx_indices = [] 218 | hard_neg_ctx_indices = [] 219 | 220 | for sample in samples: 221 | # ctx+ & [ctx-] composition 222 | # as of now, take the first(gold) ctx+ only 223 | 224 | if shuffle and shuffle_positives: 225 | positive_ctxs = sample.positive_passages 226 | positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))] 227 | else: 228 | positive_ctx = sample.positive_passages[0] 229 | 230 | neg_ctxs = sample.negative_passages 231 | hard_neg_ctxs = sample.hard_negative_passages 232 | question = sample.query 233 | # question = normalize_question(sample.query) 234 | 235 | if shuffle: 236 | random.shuffle(neg_ctxs) 237 | random.shuffle(hard_neg_ctxs) 238 | 239 | if hard_neg_fallback and len(hard_neg_ctxs) == 0: 240 | hard_neg_ctxs = neg_ctxs[0:num_hard_negatives] 241 | 242 | neg_ctxs = neg_ctxs[0:num_other_negatives] 243 | hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives] 244 | 245 | all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs 246 | hard_negatives_start_idx = 1 247 | hard_negatives_end_idx = 1 + len(hard_neg_ctxs) 248 | 249 | current_ctxs_len = len(ctx_tensors) 250 | 251 | sample_ctxs_tensors = [ 252 | tensorizer.text_to_tensor(ctx.text, title=ctx.title if (insert_title and ctx.title) else None) 253 | for ctx in all_ctxs 254 | ] 255 | 256 | ctx_tensors.extend(sample_ctxs_tensors) 257 | positive_ctx_indices.append(current_ctxs_len) 258 | hard_neg_ctx_indices.append( 259 | [ 260 | i 261 | for i in range( 262 | current_ctxs_len + hard_negatives_start_idx, 263 | current_ctxs_len + hard_negatives_end_idx, 264 | ) 265 | ] 266 | ) 267 | 268 | if query_token: 269 | # TODO: tmp workaround for EL, remove or revise 270 | if query_token == "[START_ENT]": 271 | query_span = _select_span_with_token(question, tensorizer, token_str=query_token) 272 | question_tensors.append(query_span) 273 | else: 274 | question_tensors.append(tensorizer.text_to_tensor(" ".join([query_token, question]))) 275 | else: 276 | question_tensors.append(tensorizer.text_to_tensor(question)) 277 | 278 | ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0) 279 | questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors], dim=0) 280 | 281 | ctx_segments = torch.zeros_like(ctxs_tensor) 282 | question_segments = torch.zeros_like(questions_tensor) 283 | 284 | return BiEncoderBatch( 285 | questions_tensor, 286 | question_segments, 287 | ctxs_tensor, 288 | ctx_segments, 289 | positive_ctx_indices, 290 | hard_neg_ctx_indices, 291 | "question", 292 | ) 293 | 294 | def load_state(self, saved_state: CheckpointState, strict: bool = True): 295 | # TODO: make a long term HF compatibility fix 296 | # if "question_model.embeddings.position_ids" in saved_state.model_dict: 297 | # del saved_state.model_dict["question_model.embeddings.position_ids"] 298 | # del saved_state.model_dict["ctx_model.embeddings.position_ids"] 299 | 
self.load_state_dict(saved_state.model_dict, strict=strict) 300 | 301 | def get_state_dict(self): 302 | return self.state_dict() 303 | 304 | def create_table_input( 305 | self, 306 | samples: List[BiEncoderTable], 307 | tensorizer: Tensorizer, 308 | structure_option: str = "global", 309 | insert_title: bool = True, 310 | num_hard_negatives: int = 0, 311 | num_other_negatives: int = 0, 312 | shuffle: bool = True, 313 | shuffle_positives: bool = False, 314 | hard_neg_fallback: bool = True, 315 | query_token: str = None, 316 | ) -> BiEncoderBatch: 317 | question_tensors, ctx_tensors = [], [] 318 | ctx_attn_masks, ctx_row_ids, ctx_column_ids = [], [], [] 319 | positive_ctx_indices, hard_neg_ctx_indices = [], [] 320 | 321 | for sample in samples: 322 | if shuffle and shuffle_positives: 323 | positive_ctxs = sample.positive_passages 324 | positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))] 325 | else: 326 | positive_ctx = sample.positive_passages[0] 327 | 328 | neg_ctxs = sample.negative_passages 329 | hard_neg_ctxs = sample.hard_negative_passages 330 | question = sample.query 331 | 332 | if shuffle: 333 | random.shuffle(neg_ctxs) 334 | random.shuffle(hard_neg_ctxs) 335 | 336 | if hard_neg_fallback and len(hard_neg_ctxs) == 0: 337 | hard_neg_ctxs = neg_ctxs[0:num_hard_negatives] 338 | 339 | neg_ctxs = neg_ctxs[0:num_other_negatives] 340 | hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives] 341 | 342 | all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs 343 | hard_negatives_start_idx = 1 344 | hard_negatives_end_idx = 1 + len(hard_neg_ctxs) 345 | 346 | current_ctxs_len = len(ctx_tensors) 347 | 348 | for ctx in all_ctxs: 349 | ctx_inputs = prepare_table_ctx_inputs( 350 | ctx['table'], 351 | tokenizer=tensorizer.tokenizer, 352 | structure_option=structure_option, 353 | insert_title=insert_title, 354 | ) 355 | ctx_tensors.append(ctx_inputs['token_ids']) 356 | ctx_attn_masks.append(ctx_inputs['attn_mask']) 357 | ctx_row_ids.append(ctx_inputs['row_ids']) 358 | ctx_column_ids.append(ctx_inputs['column_ids']) 359 | 360 | positive_ctx_indices.append(current_ctxs_len) 361 | hard_neg_ctx_indices.append( 362 | [ 363 | i 364 | for i in range( 365 | current_ctxs_len + hard_negatives_start_idx, 366 | current_ctxs_len + hard_negatives_end_idx, 367 | ) 368 | ] 369 | ) 370 | if query_token: 371 | # TODO: tmp workaround for EL, remove or revise 372 | if query_token == "[START_ENT]": 373 | query_span = _select_span_with_token(question, tensorizer, token_str=query_token) 374 | question_tensors.append(query_span) 375 | else: 376 | question_tensors.append(tensorizer.text_to_tensor(" ".join([query_token, question]))) 377 | else: 378 | question_tensors.append(tensorizer.text_to_tensor(question)) 379 | 380 | ctx_tensors_batch = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0) 381 | # ctx_row_ids_batch = torch.cat([ctx.view(1, -1) for ctx in ctx_row_ids], dim=0) 382 | ctx_row_ids_batch = torch.cat([ctx.unsqueeze(0) for ctx in ctx_row_ids], dim=0) 383 | ctx_column_ids_batch = torch.cat([ctx.view(1, -1) for ctx in ctx_column_ids], dim=0) 384 | ctx_attn_masks_batch = torch.cat([mask.unsqueeze(0) for mask in ctx_attn_masks], dim=0) 385 | 386 | question_tensors_batch = torch.cat([q.view(1, -1) for q in question_tensors], dim=0) 387 | 388 | positive_ctx_indices = torch.LongTensor(positive_ctx_indices) 389 | hard_neg_ctx_indices = torch.LongTensor(hard_neg_ctx_indices) 390 | 391 | ctx_segments_batch = torch.zeros_like(ctx_tensors_batch) 392 | question_segments_batch = 
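# ---------------------------------------------------------------------------
# Editor's note (assumed shapes, not part of the original file): each dict
# returned by prepare_table_ctx_inputs above is expected to carry aligned
# per-token tensors, roughly:
#
#     {"token_ids":  LongTensor[seq_len],   # linearized table tokens
#      "attn_mask":  LongTensor[seq_len],   # 1 for real tokens, 0 for padding
#      "row_ids":    LongTensor[seq_len],   # row index of each table token
#      "column_ids": LongTensor[seq_len]}   # column index of each table token
#
# so stacking them, as done just above, yields the (batch x seq_len) inputs.
# ---------------------------------------------------------------------------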
torch.zeros_like(question_tensors_batch) 393 | 394 | return TableBiEncoderBatch( 395 | question_ids=question_tensors_batch, 396 | question_segments=question_segments_batch, 397 | context_ids=ctx_tensors_batch, 398 | ctx_segments=ctx_segments_batch, 399 | ctx_attn_masks=ctx_attn_masks_batch, 400 | ctx_row_ids=ctx_row_ids_batch, 401 | ctx_column_ids=ctx_column_ids_batch, 402 | is_positive=positive_ctx_indices, 403 | hard_negatives=hard_neg_ctx_indices, 404 | encoder_type="question", 405 | ) 406 | 407 | 408 | 409 | class BiEncoderNllLoss(object): 410 | def calc( 411 | self, 412 | q_vectors: T, 413 | ctx_vectors: T, 414 | positive_idx_per_question: list, 415 | hard_negative_idx_per_question: list = None, 416 | loss_scale: float = None, 417 | ) -> Tuple[T, int]: 418 | """ 419 | Computes nll loss for the given lists of question and ctx vectors. 420 | Note that although hard_negative_idx_per_question in not currently in use, one can use it for the 421 | loss modifications. For example - weighted NLL with different factors for hard vs regular negatives. 422 | :return: a tuple of loss value and amount of correct predictions per batch 423 | """ 424 | scores = self.get_scores(q_vectors, ctx_vectors) 425 | 426 | if len(q_vectors.size()) > 1: 427 | q_num = q_vectors.size(0) 428 | scores = scores.view(q_num, -1) 429 | 430 | softmax_scores = F.log_softmax(scores, dim=1) 431 | 432 | # print(f"softmax_scores ({softmax_scores.shape}): {softmax_scores}") 433 | # print(f"positive_idx_per_question ({positive_idx_per_question.shape}): {positive_idx_per_question}") 434 | 435 | loss = F.nll_loss( 436 | softmax_scores, 437 | torch.tensor(positive_idx_per_question).to(softmax_scores.device), 438 | reduction="mean", 439 | ) 440 | 441 | max_score, max_idxs = torch.max(softmax_scores, 1) 442 | correct_predictions_count = (max_idxs == torch.tensor(positive_idx_per_question).to(max_idxs.device)).sum() 443 | 444 | if loss_scale: 445 | loss.mul_(loss_scale) 446 | 447 | # print(f"correct_predictions_count: {correct_predictions_count}") 448 | return loss, correct_predictions_count 449 | 450 | @staticmethod 451 | def get_scores(q_vector: T, ctx_vectors: T) -> T: 452 | f = BiEncoderNllLoss.get_similarity_function() 453 | return f(q_vector, ctx_vectors) 454 | 455 | @staticmethod 456 | def get_similarity_function(): 457 | return dot_product_scores 458 | 459 | 460 | def _select_span_with_token(text: str, tensorizer: Tensorizer, token_str: str = "[START_ENT]") -> T: 461 | id = tensorizer.get_token_id(token_str) 462 | query_tensor = tensorizer.text_to_tensor(text) 463 | 464 | if id not in query_tensor: 465 | query_tensor_full = tensorizer.text_to_tensor(text, apply_max_len=False) 466 | token_indexes = (query_tensor_full == id).nonzero() 467 | if token_indexes.size(0) > 0: 468 | start_pos = token_indexes[0, 0].item() 469 | # add some randomization to avoid overfitting to a specific token position 470 | 471 | left_shit = int(tensorizer.max_length / 2) 472 | rnd_shift = int((rnd.random() - 0.5) * left_shit / 2) 473 | left_shit += rnd_shift 474 | 475 | query_tensor = query_tensor_full[start_pos - left_shit :] 476 | cls_id = tensorizer.tokenizer.cls_token_id 477 | if query_tensor[0] != cls_id: 478 | query_tensor = torch.cat([torch.tensor([cls_id]), query_tensor], dim=0) 479 | 480 | from dpr.models.reader import _pad_to_len 481 | 482 | query_tensor = _pad_to_len(query_tensor, tensorizer.get_pad_id(), tensorizer.max_length) 483 | query_tensor[-1] = tensorizer.tokenizer.sep_token_id 484 | # logger.info('aligned query_tensor %s', 
query_tensor)
485 | 
486 |             assert id in query_tensor, "query_tensor={}".format(query_tensor)
487 |             return query_tensor
488 |         else:
489 |             raise RuntimeError("[START_ENT] token not found for Entity Linking sample query={}".format(text))
490 |     else:
491 |         return query_tensor
492 | 
--------------------------------------------------------------------------------
/dpr/data/tables.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import re
3 | import csv
4 | import json
5 | import logging
6 | import unicodedata
7 | 
8 | import jsonlines
9 | import spacy as spacy
10 | from typing import List, Dict
11 | 
12 | 
13 | logger = logging.getLogger()
14 | logger.setLevel(logging.INFO)
15 | if logger.hasHandlers():
16 |     logger.handlers.clear()
17 | log_formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
18 | console = logging.StreamHandler()
19 | console.setFormatter(log_formatter)
20 | logger.addHandler(console)
21 | 
22 | nlp = spacy.load("en_core_web_sm", disable=["parser", "tagger", "ner", "entity_ruler"])
23 | 
24 | 
25 | class Cell:
26 |     def __init__(self):
27 |         self.value_tokens: List[str] = []
28 |         self.type: str = ""
29 |         self.nested_tables: List[Table] = []
30 | 
31 |     def __str__(self):
32 |         return " ".join(self.value_tokens)
33 | 
34 |     def to_dpr_json(self, cell_idx: int):
35 |         r = {"col": cell_idx}
36 |         r["value"] = str(self)
37 |         return r
38 | 
39 | 
40 | class Row:
41 |     def __init__(self):
42 |         self.cells: List[Cell] = []
43 | 
44 |     def __str__(self):
45 |         return "| ".join([str(c) for c in self.cells])
46 | 
47 |     def visit(self, tokens_function, row_idx: int):
48 |         for i, c in enumerate(self.cells):
49 |             if c.value_tokens:
50 |                 tokens_function(c.value_tokens, row_idx, i)
51 | 
52 |     def to_dpr_json(self, row_idx: int):
53 |         r = {"row": row_idx}
54 |         r["columns"] = [c.to_dpr_json(i) for i, c in enumerate(self.cells)]
55 |         return r
56 | 
57 | 
58 | class Table(object):
59 |     def __init__(self, caption=""):
60 |         self.caption = caption
61 |         self.body: List[Row] = []
62 |         self.key = None
63 |         self.gold_match = False
64 | 
65 |     def __str__(self):
66 |         table_str = ": {}\n".format(self.caption)
67 |         table_str += " rows:\n"
68 |         for i, r in enumerate(self.body):
69 |             table_str += " row #{}: {}\n".format(i, str(r))
70 | 
71 |         return table_str
72 | 
73 |     def get_key(self) -> str:
74 |         if not self.key:
75 |             self.key = str(self)
76 |         return self.key
77 | 
78 |     def visit(self, tokens_function, include_caption: bool = False) -> bool:
79 |         if include_caption:
80 |             tokens_function(self.caption, -1, -1)
81 |         for i, r in enumerate(self.body):
82 |             r.visit(tokens_function, i)
83 | 
84 |     def to_dpr_json(self):
85 |         r = {
86 |             "caption": self.caption,
87 |             "rows": [r.to_dpr_json(i) for i, r in enumerate(self.body)],
88 |         }
89 |         if self.gold_match:
90 |             r["gold_match"] = 1
91 |         return r
92 | 
93 | 
94 | class NQTableParser(object):
95 |     def __init__(self, tokens, is_html_mask, title):
96 |         self.tokens = tokens
97 |         self.is_html_mask = is_html_mask
98 |         self.max_idx = len(self.tokens)
99 |         self.all_tables = []
100 | 
101 |         self.current_table: Table = None
102 |         self.tables_stack = collections.deque()
103 |         self.title = title
104 | 
105 |     def parse(self) -> List[Table]:
106 |         self.all_tables = []
107 |         self.tables_stack = collections.deque()
108 | 
109 |         for i in range(self.max_idx):
110 | 
111 |             t = self.tokens[i]
112 | 
113 |             if not self.is_html_mask[i]:
114 |                 # cell content
115 |                 self._on_content(t)
116 |                 continue
117 | 
118 |             if "<Table" in t:
119 |                 self._on_table_start()
120 |             elif "</Table" in t:
121 |                 self._on_table_end()
122 |             elif "<Tr" in t:
123 |                 self._onRowStart()
124 |             elif "</Tr" in t:
125 |                 self._onRowEnd()
126 |             elif t in ["<Td", "<Th"]:
127 |                 self._onCellStart()
128 |             elif t in ["</Td", "</Th"]:
129 |                 self._on_cell_end()
130 | 
131 |         return self.all_tables
132 | 
133 |     def _on_table_start(self):
134 |         caption = self.title
135 |         parent_table = self.current_table
136 |         if parent_table:
137 |             self.tables_stack.append(parent_table)
138 | 
139 |             caption = parent_table.caption
140 |             if parent_table.body and parent_table.body[-1].cells:
141 |                 current_cell = self.current_table.body[-1].cells[-1]
142 |                 caption += " | " + " ".join(current_cell.value_tokens)
143 | 
144 |         t = Table()
145 |         t.caption = caption
146 |         self.current_table = t
147 |         self.all_tables.append(t)
148 | 
149 |     def _on_table_end(self):
150 |         t = self.current_table
151 |         if t:
152 |             if self.tables_stack:  # t is a nested table
153 |                 self.current_table = self.tables_stack.pop()
154 |                 if self.current_table.body:
155 |                     current_cell = self.current_table.body[-1].cells[-1]
156 |                     current_cell.nested_tables.append(t)
157 |         else:
158 |             logger.error("table end without table object")
159 | 
160 |     def _onRowStart(self):
161 |         self.current_table.body.append(Row())
162 | 
163 |     def _onRowEnd(self):
164 |         pass
165 | 
166 |     def _onCellStart(self):
167 |         current_row = self.current_table.body[-1]
168 |         current_row.cells.append(Cell())
169 | 
170 |     def _on_cell_end(self):
171 |         pass
172 | 
173 |     def _on_content(self, token):
174 |         if self.current_table.body:
175 |             current_row = self.current_table.body[-1]
176 |             current_cell = current_row.cells[-1]
177 |             current_cell.value_tokens.append(token)
178 |         else:  # tokens outside of row/cells. Just append to the table caption.
179 |             self.current_table.caption += " " + token
180 | 
181 | 
182 | def read_nq_tables_jsonl(path: str, out_file: str = None) -> Dict[str, Table]:
183 |     tables_with_issues = 0
184 |     single_row_tables = 0
185 |     nested_tables = 0
186 |     regular_tables = 0
187 |     total_tables = 0
188 |     total_rows = 0
189 |     tables_dict = {}
190 | 
191 |     with jsonlines.open(path, mode="r") as jsonl_reader:
192 |         for jline in jsonl_reader:
193 |             tokens = jline["tokens"]
194 | 
195 |             if "( hide ) This section has multiple issues" in " ".join(tokens):
196 |                 tables_with_issues += 1
197 |                 continue
198 |             mask = jline["html_mask"]
199 |             # _page_url = jline["doc_url"]
200 |             title = jline["title"]
201 |             p = NQTableParser(tokens, mask, title)
202 |             tables = p.parse()
203 | 
204 |             nested_tables += len(tables[1:])
205 | 
206 |             for t in tables:
207 |                 total_tables += 1
208 | 
209 |                 # calc amount of non empty rows
210 |                 non_empty_rows = sum([1 for r in t.body if r.cells and any([True for c in r.cells if c.value_tokens])])
211 | 
212 |                 if non_empty_rows <= 1:
213 |                     single_row_tables += 1
214 |                 else:
215 |                     regular_tables += 1
216 |                     total_rows += len(t.body)
217 | 
218 |                     if t.get_key() not in tables_dict:
219 |                         tables_dict[t.get_key()] = t
220 | 
221 |             if len(tables_dict) % 1000 == 0:
222 |                 logger.info("tables_dict %d", len(tables_dict))
223 | 
224 |     logger.info("regular tables %d", regular_tables)
225 |     logger.info("tables_with_issues %d", tables_with_issues)
226 |     logger.info("single_row_tables %d", single_row_tables)
227 |     logger.info("nested_tables %d", nested_tables)
228 | 
229 |     if out_file:
230 |         convert_to_csv_for_lucene(tables_dict, out_file)
231 |     return tables_dict
232 | 
233 | 
234 | def get_table_string_for_answer_check(table: Table):  # this doesn't use caption
235 |     table_text = ""
236 |     for r in table.body:
237 |         table_text += " . ".join([" ".join(c.value_tokens) for c in r.cells])
238 |         table_text += " . 
" 239 | return table_text 240 | 241 | 242 | def convert_to_csv_for_lucene(tables_dict, out_file: str): 243 | id = 0 244 | with open(out_file, "w", newline="") as csvfile: 245 | writer = csv.writer(csvfile, delimiter="\t") 246 | for _, v in tables_dict.items(): 247 | id += 1 248 | # strip all 249 | table_text = get_table_string_for_answer_check(v) 250 | writer.writerow([id, table_text, v.caption]) 251 | logger.info("Saved to %s", out_file) 252 | 253 | 254 | def convert_jsonl_to_qas_tsv(path, out): 255 | results = [] 256 | with jsonlines.open(path, mode="r") as jsonl_reader: 257 | for jline in jsonl_reader: 258 | q = jline["question"] 259 | answers = [] 260 | if "short_answers" in jline: 261 | answers = jline["short_answers"] 262 | 263 | results.append((q, answers)) 264 | 265 | with open(out, "w", newline="") as csvfile: 266 | writer = csv.writer(csvfile, delimiter="\t") 267 | for r in results: 268 | writer.writerow([r[0], r[1]]) 269 | logger.info("Saved to %s", out) 270 | 271 | 272 | def tokenize(text): 273 | doc = nlp(text) 274 | return [token.text.lower() for token in doc] 275 | 276 | 277 | def normalize(text): 278 | """Resolve different type of unicode encodings.""" 279 | return unicodedata.normalize("NFD", text) 280 | 281 | 282 | def prepare_answers(answers) -> List[List[str]]: 283 | r = [] 284 | for single_answer in answers: 285 | single_answer = normalize(single_answer) 286 | single_answer = single_answer.lower().split(" ") # tokenize(single_answer) 287 | r.append(single_answer) 288 | return r 289 | 290 | 291 | def has_prepared_answer(prep_answers: List[List[str]], text: List[str]): 292 | """Check if a document contains an answer string.""" 293 | text = [normalize(token).lower() for token in text] 294 | 295 | for single_answer in prep_answers: 296 | for i in range(0, len(text) - len(single_answer) + 1): 297 | if single_answer == text[i : i + len(single_answer)]: 298 | return True 299 | return False 300 | 301 | 302 | def has_answer(answers, text, regMatxh=False): 303 | """Check if a document contains an answer string.""" 304 | 305 | text = normalize(text) 306 | 307 | if regMatxh: 308 | single_answer = normalize(answers[0]) 309 | if regex_match(text, single_answer): 310 | return True 311 | else: 312 | # Answer is a list of possible strings 313 | text = tokenize(text) 314 | 315 | for single_answer in answers: 316 | single_answer = normalize(single_answer) 317 | single_answer = tokenize(single_answer) 318 | 319 | for i in range(0, len(text) - len(single_answer) + 1): 320 | if single_answer == text[i : i + len(single_answer)]: 321 | return True 322 | return False 323 | 324 | 325 | def convert_search_res_to_dpr_and_eval( 326 | res_file, all_tables_file_jsonl, nq_table_file, out_file, gold_res_file: str = None 327 | ): 328 | db = {} 329 | id = 0 330 | tables_dict = read_nq_tables_jsonl(all_tables_file_jsonl) 331 | for _, v in tables_dict.items(): 332 | id += 1 333 | db[id] = v 334 | 335 | logger.info("db size %s", len(db)) 336 | total = 0 337 | dpr_results = {} 338 | import torch 339 | 340 | bm25_per_topk_hits = torch.tensor([0] * 100) 341 | qas = [] 342 | with open(res_file) as tsvfile: 343 | reader = csv.reader(tsvfile, delimiter="\t") 344 | # file format: id, text 345 | for row in reader: 346 | total += 1 347 | q = row[0] 348 | answers = eval(row[1]) 349 | prep_answers = prepare_answers(answers) 350 | qas.append((q, prep_answers)) 351 | question_hns = [] 352 | question_positives = [] 353 | answers_table_links = [] 354 | 355 | for k, bm25result in enumerate(row[2:]): 356 | score, id = 
bm25result.split(",") 357 | table = db[int(id)] 358 | answer_locations = [] 359 | 360 | def check_answer(tokens, row_idx: int, cell_idx: int): 361 | if has_prepared_answer(prep_answers, tokens): 362 | answer_locations.append((row_idx, cell_idx)) 363 | 364 | # get string representation to find answer 365 | if (len(question_positives) >= 10 and len(question_hns) >= 10) or (len(question_hns) >= 30): 366 | break 367 | 368 | # table_str = get_table_string_for_answer_check(table) 369 | table.visit(check_answer) 370 | has_answer = len(answer_locations) > 0 371 | 372 | if has_answer: 373 | question_positives.append(table) 374 | answers_table_links.append(answer_locations) 375 | else: 376 | question_hns.append(table) 377 | 378 | dpr_results[q] = (question_positives, question_hns, answers_table_links) 379 | if len(dpr_results) % 100 == 0: 380 | logger.info("dpr_results %s", len(dpr_results)) 381 | 382 | logger.info("dpr_results size %s", len(dpr_results)) 383 | logger.info("total %s", total) 384 | logger.info("bm25_per_topk_hits %s", bm25_per_topk_hits) 385 | 386 | if gold_res_file: 387 | logger.info("Processing gold_res_file") 388 | with open(gold_res_file) as cFile: 389 | csvReader = csv.reader(cFile, delimiter=",") 390 | for row in csvReader: 391 | q_id = int(row[0]) 392 | qas_tuple = qas[q_id] 393 | prep_answers = qas_tuple[1] 394 | question_gold_positive_match = None 395 | q = qas_tuple[0] 396 | answers_links = None 397 | for field in row[1:]: 398 | psg_id = int(field.split()[0]) 399 | table = db[psg_id] 400 | answer_locations = [] 401 | 402 | def check_answer(tokens, row_idx: int, cell_idx: int): 403 | if has_prepared_answer(prep_answers, tokens): 404 | answer_locations.append((row_idx, cell_idx)) 405 | 406 | table.visit(check_answer) 407 | has_answer = len(answer_locations) > 0 408 | if has_answer and question_gold_positive_match is None: 409 | question_gold_positive_match = table 410 | question_gold_positive_match.gold_match = True 411 | answers_links = answer_locations 412 | 413 | if question_gold_positive_match is None: 414 | logger.info("No gold match for q=%s, q_id=%s", q, q_id) 415 | else: # inject into ctx+ at the first position 416 | question_positives, hns, ans_links = dpr_results[q] 417 | question_positives.insert(0, question_gold_positive_match) 418 | ans_links.insert(0, answers_links) 419 | 420 | out_results = [] 421 | with jsonlines.open(nq_table_file, mode="r") as jsonl_reader: 422 | for jline in jsonl_reader: 423 | q = jline["question"] 424 | gold_positive_table = jline["contexts"][0] 425 | mask = gold_positive_table["html_mask"] 426 | # page_url = jline['doc_url'] 427 | title = jline["title"] 428 | p = NQTableParser(gold_positive_table["tokens"], mask, title) 429 | tables = p.parse() 430 | # select the one with the answer(s) 431 | prep_answers = prepare_answers(jline["short_answers"]) 432 | 433 | tables_with_answers = [] 434 | tables_answer_locations = [] 435 | 436 | for t in tables: 437 | answer_locations = [] 438 | 439 | def check_answer(tokens, row_idx: int, cell_idx: int): 440 | if has_prepared_answer(prep_answers, tokens): 441 | answer_locations.append((row_idx, cell_idx)) 442 | 443 | t.visit(check_answer) 444 | has_answer = len(answer_locations) > 0 445 | if has_answer: 446 | tables_with_answers.append(t) 447 | tables_answer_locations.append(answer_locations) 448 | 449 | if not tables_with_answers: 450 | logger.info("No answer in gold table(s) for q=%s", q) 451 | 452 | positive_ctxs, hard_neg_ctxs, answers_table_links = dpr_results[q] 453 | positive_ctxs = positive_ctxs 


def convert_long_ans_to_dpr(nq_table_file, out_file):
    out_results = []
    with jsonlines.open(nq_table_file, mode="r") as jsonl_reader:
        for jline in jsonl_reader:
            q = jline["question"]

            gold_positive_table = jline["contexts"]

            mask = gold_positive_table["la_ans_tokens_html_mask"]
            # page_url = jline['doc_url']
            title = jline["title"]

            p = NQTableParser(gold_positive_table["la_ans_tokens"], mask, title)
            tables = p.parse()
            # take the first parsed table as the positive context
            positive_ctxs = [tables[0].to_dpr_json()]

            out_results.append(
                {
                    "question": q,
                    "id": jline["example_id"],
                    "answers": [],
                    "positive_ctxs": positive_ctxs,
                    "hard_negative_ctxs": [],
                }
            )

    logger.info("out_results size %s", len(out_results))

    with jsonlines.open(out_file, mode="w") as writer:
        for r in out_results:
            writer.write(r)

    logger.info("Saved to %s", out_file)


def parse_qa_csv_file(location):
    res = []
    with open(location) as ifile:
        reader = csv.reader(ifile, delimiter="\t")
        for row in reader:
            question = row[0]
            answers = ast.literal_eval(row[1])  # safer than eval() for a TSV field
            res.append((question, answers))
    return res


def calc_questions_overlap(tables_file, regular_file, dev_file):
    tab_questions = set()

    with jsonlines.open(tables_file, mode="r") as jsonl_reader:
        logger.info("Reading file %s", tables_file)
        for jline in jsonl_reader:
            q = jline["question"]
            tab_questions.add(q)

    reg_questions = set()

    if regular_file.endswith(".csv"):
        qas = parse_qa_csv_file(regular_file)
        for qa in qas:
            reg_questions.add(qa[0])
    else:
        with open(regular_file, "r", encoding="utf-8") as f:
            logger.info("Reading file %s", regular_file)
            data = json.load(f)
            for item in data:
                q = item["question"]
                reg_questions.add(q)
    if dev_file:
        if dev_file.endswith(".csv"):
            qas = parse_qa_csv_file(dev_file)
            for qa in qas:
                reg_questions.add(qa[0])
        else:
            with open(dev_file, "r", encoding="utf-8") as f:
                logger.info("Reading file %s", dev_file)
                data = json.load(f)
                for item in data:
                    q = item["question"]
                    reg_questions.add(q)

    logger.info("tab_questions %d", len(tab_questions))
    logger.info("reg_questions %d", len(reg_questions))
    logger.info("overlap %d", len(tab_questions.intersection(reg_questions)))
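

# Illustrative sketch (not part of the original file): the QAS TSV layout that
# parse_qa_csv_file() above expects and convert_jsonl_to_qas_tsv() writes --
# a question, then a Python-style list literal of answer strings, tab-separated:
#
#   who founded tesla\t['Martin Eberhard', 'Marc Tarpenning']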
logger.info("overlap %d", len(tab_questions.intersection(reg_questions))) 568 | 569 | 570 | def convert_train_jsonl_to_ctxmatch(path: str, out_file: str): 571 | def get_table_string_for_ctx_match(table: dict): # this doesn't use caption 572 | table_text = table["caption"] + " . " 573 | for r in table["rows"]: 574 | table_text += " . ".join([c["value"] for c in r["columns"]]) 575 | table_text += " . " 576 | return table_text 577 | 578 | results = [] 579 | with jsonlines.open(path, mode="r") as jsonl_reader: 580 | for jline in jsonl_reader: 581 | if len(jline["positive_ctxs"]) == 0: 582 | continue 583 | ctx_pos = jline["positive_ctxs"][0] 584 | table_str = get_table_string_for_ctx_match(ctx_pos) 585 | q = jline["question"] 586 | results.append((q, table_str)) 587 | 588 | if len(results) % 1000 == 0: 589 | logger.info("results %d", len(results)) 590 | 591 | shards_sz = 3000 592 | shard = 0 593 | 594 | for s in range(0, len(results), shards_sz): 595 | chunk = results[s : s + shards_sz] 596 | shard_file = out_file + ".shard_{}".format(shard) 597 | with jsonlines.open(shard_file, mode="w") as writer: 598 | logger.info("Saving to %s", shard_file) 599 | for i, item in enumerate(chunk): 600 | writer.write({"id": s + i, "question": item[0], "context": item[1]}) 601 | shard += 1 602 | 603 | 604 | # TODO: tmp copy-paste fix to avoid circular dependency 605 | def regex_match(text, pattern): 606 | """Test if a regex pattern is contained within a text.""" 607 | try: 608 | pattern = re.compile(pattern, flags=re.IGNORECASE + re.UNICODE + re.MULTILINE) 609 | except BaseException: 610 | return False 611 | return pattern.search(text) is not None 612 | -------------------------------------------------------------------------------- /dpr/data/reader_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Set of utilities for the Reader model related data processing tasks 10 | """ 11 | 12 | import collections 13 | import glob 14 | import json 15 | import logging 16 | import math 17 | import multiprocessing 18 | import os 19 | import pickle 20 | from functools import partial 21 | from typing import Tuple, List, Dict, Iterable, Optional 22 | 23 | import torch 24 | from torch import Tensor as T 25 | from tqdm import tqdm 26 | 27 | from dpr.utils.data_utils import ( 28 | Tensorizer, 29 | read_serialized_data_from_files, 30 | read_data_from_json_files, 31 | Dataset as DprDataset, 32 | ) 33 | 34 | logger = logging.getLogger() 35 | 36 | 37 | class ReaderPassage(object): 38 | """ 39 | Container to collect and cache all Q&A passages related attributes before generating the reader input 40 | """ 41 | 42 | def __init__( 43 | self, 44 | id=None, 45 | text: str = None, 46 | title: str = None, 47 | score=None, 48 | has_answer: bool = None, 49 | ): 50 | self.id = id 51 | # string passage representations 52 | self.passage_text = text 53 | self.title = title 54 | self.score = score 55 | self.has_answer = has_answer 56 | self.passage_token_ids = None 57 | # offset of the actual passage (i.e. 


class ReaderSample(object):
    """
    Container to collect all Q&A passages data per single question
    """

    def __init__(
        self,
        question: str,
        answers: List,
        positive_passages: List[ReaderPassage] = None,
        negative_passages: List[ReaderPassage] = None,
        passages: List[ReaderPassage] = None,
    ):
        self.question = question
        self.answers = answers
        # avoid mutable default arguments: each sample gets its own lists
        self.positive_passages = positive_passages or []
        self.negative_passages = negative_passages or []
        self.passages = passages or []

    def on_serialize(self):
        for passage in self.passages + self.positive_passages + self.negative_passages:
            passage.on_serialize()

    def on_deserialize(self):
        for passage in self.passages + self.positive_passages + self.negative_passages:
            passage.on_deserialize()


class ExtractiveReaderDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        files: str,
        is_train: bool,
        gold_passages_src: str,
        tensorizer: Tensorizer,
        run_preprocessing: bool,
        num_workers: int,
    ):
        self.files = files
        self.data = []
        self.is_train = is_train
        self.gold_passages_src = gold_passages_src
        self.tensorizer = tensorizer
        self.run_preprocessing = run_preprocessing
        self.num_workers = num_workers

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

    def calc_total_data_len(self):
        if not self.data:
            self.load_data()
        return len(self.data)

    def load_data(self):
        if self.data:
            return

        data_files = glob.glob(self.files)
        logger.info("Data files: %s", data_files)
        if not data_files:
            raise RuntimeError("No Data files found")
        preprocessed_data_files = self._get_preprocessed_files(data_files)
        self.data = read_serialized_data_from_files(preprocessed_data_files)

    def _get_preprocessed_files(self, data_files: List):

        serialized_files = [file for file in data_files if file.endswith(".pkl")]
        if serialized_files:
            return serialized_files
        assert len(data_files) == 1, "Only 1 source file pre-processing is supported."

        # data may have been serialized and cached before, try to find ones from the same dir
        def _find_cached_files(path: str):
            dir_path, base_name = os.path.split(path)
            base_name = base_name.replace(".json", "")
            out_file_prefix = os.path.join(dir_path, base_name)
            out_file_pattern = out_file_prefix + "*.pkl"
            return glob.glob(out_file_pattern), out_file_prefix

        serialized_files, out_file_prefix = _find_cached_files(data_files[0])
        if serialized_files:
            logger.info("Found preprocessed files. %s", serialized_files)
            return serialized_files

        logger.info("Data are not preprocessed for reader training. Start pre-processing ...")

        # start pre-processing and save the results
        def _run_preprocessing(tensorizer: Tensorizer):
            # temporarily disable auto-padding to save disk space usage of serialized files
            tensorizer.set_pad_to_max(False)
            serialized_files = convert_retriever_results(
                self.is_train,
                data_files[0],
                out_file_prefix,
                self.gold_passages_src,
                self.tensorizer,
                num_workers=self.num_workers,
            )
            tensorizer.set_pad_to_max(True)
            return serialized_files

        if self.run_preprocessing:
            serialized_files = _run_preprocessing(self.tensorizer)
            # TODO: check if pytorch process group is initialized
            # torch.distributed.barrier()
        else:
            # torch.distributed.barrier()
            # _find_cached_files returns (files, prefix); only the file list is needed here
            serialized_files, _ = _find_cached_files(data_files[0])
        return serialized_files
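

# Illustrative sketch (not part of the original file): the typical lifecycle of
# the dataset above. `tensorizer` would come from the model initialization code;
# the file pattern is a placeholder.
#
#   ds = ExtractiveReaderDataset(files="retriever_results.*.json", is_train=True,
#                                gold_passages_src=None, tensorizer=tensorizer,
#                                run_preprocessing=True, num_workers=8)
#   ds.load_data()   # finds or creates the .pkl caches
#   sample = ds[0]   # a ReaderSample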


SpanPrediction = collections.namedtuple(
    "SpanPrediction",
    [
        "prediction_text",
        "span_score",
        "relevance_score",
        "passage_index",
        "passage_token_ids",
    ],
)

# configuration for the reader model passage selection
ReaderPreprocessingCfg = collections.namedtuple(
    "ReaderPreprocessingCfg",
    [
        "use_tailing_sep",
        "skip_no_positives",
        "include_gold_passage",
        "gold_page_only_positives",
        "max_positives",
        "max_negatives",
        "min_negatives",
        "max_retriever_passages",
    ],
)

DEFAULT_PREPROCESSING_CFG_TRAIN = ReaderPreprocessingCfg(
    use_tailing_sep=False,
    skip_no_positives=True,
    include_gold_passage=False,  # True - for speech Q&A
    gold_page_only_positives=True,
    max_positives=20,
    max_negatives=50,
    min_negatives=150,
    max_retriever_passages=200,
)

DEFAULT_EVAL_PASSAGES = 100
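

# Illustrative sketch (not part of the original file): overriding the selection
# policy via the namedtuple's standard _replace(), e.g. to keep fewer negatives
# for a quick experiment. _SMALL_CFG is a hypothetical name.
_SMALL_CFG = DEFAULT_PREPROCESSING_CFG_TRAIN._replace(max_negatives=10, min_negatives=20)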


def preprocess_retriever_data(
    samples: List[Dict],
    gold_info_file: Optional[str],
    tensorizer: Tensorizer,
    cfg: ReaderPreprocessingCfg = DEFAULT_PREPROCESSING_CFG_TRAIN,
    is_train_set: bool = True,
) -> Iterable[ReaderSample]:
    """
    Converts retriever results into reader training data.
    :param samples: samples from the retriever's json file results
    :param gold_info_file: optional path for the 'gold passages & questions' file. Required to get the best results for NQ
    :param tensorizer: Tensorizer object for text to model input tensors conversions
    :param cfg: ReaderPreprocessingCfg object with positive and negative passage selection parameters
    :param is_train_set: if the data should be processed as a train set
    :return: iterable of ReaderSample objects which can be consumed by the reader model
    """
    sep_tensor = tensorizer.get_pair_separator_ids()  # separator can be a multi token
    gold_passage_map, canonical_questions = _get_gold_ctx_dict(gold_info_file) if gold_info_file else ({}, {})

    no_positive_passages = 0
    positives_from_gold = 0

    def create_reader_sample_ids(sample: ReaderPassage, question: str):
        question_and_title = tensorizer.text_to_tensor(sample.title, title=question, add_special_tokens=True)
        if sample.passage_token_ids is None:
            sample.passage_token_ids = tensorizer.text_to_tensor(sample.passage_text, add_special_tokens=False)

        all_concatenated, shift = _concat_pair(
            question_and_title,
            sample.passage_token_ids,
            tailing_sep=sep_tensor if cfg.use_tailing_sep else None,
        )

        sample.sequence_ids = all_concatenated
        sample.passage_offset = shift
        assert shift > 1
        if sample.has_answer and is_train_set:
            sample.answers_spans = [(span[0] + shift, span[1] + shift) for span in sample.answers_spans]
        return sample

    for sample in samples:
        question = sample["question"]
        question_txt = sample["query_text"] if "query_text" in sample else question

        if canonical_questions and question_txt in canonical_questions:
            question_txt = canonical_questions[question_txt]

        positive_passages, negative_passages = _select_reader_passages(
            sample,
            question_txt,
            tensorizer,
            gold_passage_map,
            cfg.gold_page_only_positives,
            cfg.max_positives,
            cfg.max_negatives,
            cfg.min_negatives,
            cfg.max_retriever_passages,
            cfg.include_gold_passage,
            is_train_set,
        )
        # create concatenated sequence ids for each passage and adjust answer spans
        positive_passages = [create_reader_sample_ids(s, question) for s in positive_passages]
        negative_passages = [create_reader_sample_ids(s, question) for s in negative_passages]

        if is_train_set and len(positive_passages) == 0:
            no_positive_passages += 1
            if cfg.skip_no_positives:
                continue

        if next(iter(ctx for ctx in positive_passages if ctx.score == -1), None):
            positives_from_gold += 1

        if is_train_set:
            yield ReaderSample(
                question,
                sample["answers"],
                positive_passages=positive_passages,
                negative_passages=negative_passages,
            )
        else:
            yield ReaderSample(question, sample["answers"], passages=negative_passages)

    logger.info("no positive passages samples: %d", no_positive_passages)
    logger.info("positive passages from gold samples: %d", positives_from_gold)
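

# Illustrative sketch (not part of the original file): preprocess_retriever_data()
# is a generator, so results stream one ReaderSample at a time. `tensorizer` is
# assumed to exist; the file name is a placeholder.
#
#   with open("retriever_results.json", "r", encoding="utf-8") as f:
#       samples = json.load(f)
#   for reader_sample in preprocess_retriever_data(samples, None, tensorizer):
#       ...  # feed to the reader input batcher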


def convert_retriever_results(
    is_train_set: bool,
    input_file: str,
    out_file_prefix: str,
    gold_passages_file: str,
    tensorizer: Tensorizer,
    num_workers: int = 8,
) -> List[str]:
    """
    Converts the file with dense retriever (or any compatible file format) results into the reader input data and
    serializes them into a set of files.
    Conversion splits the input data into multiple chunks and processes them in parallel. Each chunk's results are
    stored in a separate file with the name out_file_prefix.{number}.pkl
    :param is_train_set: if the data should be processed for a train set (i.e. with answer span detection)
    :param input_file: path to a json file with the data to convert
    :param out_file_prefix: output path prefix.
    :param gold_passages_file: optional path for the 'gold passages & questions' file. Required to get the best results for NQ
    :param tensorizer: Tensorizer object for text to model input tensors conversions
    :param num_workers: the number of parallel processes for the conversion
    :return: names of the files with the serialized results
    """
    with open(input_file, "r", encoding="utf-8") as f:
        samples = json.load(f)
    logger.info("Loaded %d questions + retrieval results from %s", len(samples), input_file)
    workers = multiprocessing.Pool(num_workers)
    ds_size = len(samples)
    step = max(math.ceil(ds_size / num_workers), 1)
    chunks = [samples[i : i + step] for i in range(0, ds_size, step)]
    chunks = [(i, chunks[i]) for i in range(len(chunks))]

    logger.info("Split data into %d chunks", len(chunks))

    processed = 0
    _parse_batch = partial(
        _preprocess_reader_samples_chunk,
        out_file_prefix=out_file_prefix,
        gold_passages_file=gold_passages_file,
        tensorizer=tensorizer,
        is_train_set=is_train_set,
    )
    serialized_files = []
    for file_name in workers.map(_parse_batch, chunks):
        processed += 1
        serialized_files.append(file_name)
        logger.info("Chunks processed %d", processed)
        logger.info("Data saved to %s", file_name)
    workers.close()
    workers.join()
    logger.info("Preprocessed data stored in %s", serialized_files)
    return serialized_files


def get_best_spans(
    tensorizer: Tensorizer,
    start_logits: List,
    end_logits: List,
    ctx_ids: List,
    max_answer_length: int,
    passage_idx: int,
    relevance_score: float,
    top_spans: int = 1,
) -> List[SpanPrediction]:
    """
    Finds the best answer spans for the extractive Q&A model
    """
    scores = []
    for (i, s) in enumerate(start_logits):
        for (j, e) in enumerate(end_logits[i : i + max_answer_length]):
            scores.append(((i, i + j), s + e))

    scores = sorted(scores, key=lambda x: x[1], reverse=True)

    chosen_span_intervals = []
    best_spans = []

    for (start_index, end_index), score in scores:
        assert start_index <= end_index
        length = end_index - start_index + 1
        assert length <= max_answer_length

        # skip spans that overlap an already chosen (higher scoring) span
        if any(
            [
                start_index <= prev_start_index <= prev_end_index <= end_index
                or prev_start_index <= start_index <= end_index <= prev_end_index
                for (prev_start_index, prev_end_index) in chosen_span_intervals
            ]
        ):
            continue

        # extend bpe subtokens to full tokens
        start_index, end_index = _extend_span_to_full_words(tensorizer, ctx_ids, (start_index, end_index))

        predicted_answer = tensorizer.to_string(ctx_ids[start_index : end_index + 1])
        best_spans.append(SpanPrediction(predicted_answer, score, relevance_score, passage_idx, ctx_ids))
        chosen_span_intervals.append((start_index, end_index))

        if len(chosen_span_intervals) == top_spans:
            break
    return best_spans
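

# Illustrative sketch (not part of the original file): get_best_spans() scores
# every (start, end) pair within max_answer_length as start_logit + end_logit
# and keeps the top non-overlapping intervals. A worked example with made-up
# logits: start_logits=[2.0, 0.0], end_logits=[0.0, 3.0], max_answer_length=2
# yields candidates (0,0)=2.0, (0,1)=5.0, (1,1)=3.0, so the best span is (0, 1)
# with score 5.0. The tensorizer is only needed to map token ids back to text.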


def _select_reader_passages(
    sample: Dict,
    question: str,
    tensorizer: Tensorizer,
    gold_passage_map: Optional[Dict[str, ReaderPassage]],
    gold_page_only_positives: bool,
    max_positives: int,
    max1_negatives: int,
    max2_negatives: int,
    max_retriever_passages: int,
    include_gold_passage: bool,
    is_train_set: bool,
) -> Tuple[List[ReaderPassage], List[ReaderPassage]]:
    answers = sample["answers"]

    ctxs = [ReaderPassage(**ctx) for ctx in sample["ctxs"]][0:max_retriever_passages]
    answers_token_ids = [tensorizer.text_to_tensor(a, add_special_tokens=False) for a in answers]

    if is_train_set:
        positive_samples = list(filter(lambda ctx: ctx.has_answer, ctxs))
        negative_samples = list(filter(lambda ctx: not ctx.has_answer, ctxs))
    else:
        positive_samples = []
        negative_samples = ctxs

    positive_ctxs_from_gold_page = (
        list(
            filter(
                lambda ctx: _is_from_gold_wiki_page(gold_passage_map, ctx.title, question),
                positive_samples,
            )
        )
        if gold_page_only_positives and gold_passage_map
        else []
    )

    def find_answer_spans(ctx: ReaderPassage):
        if ctx.has_answer:
            if ctx.passage_token_ids is None:
                ctx.passage_token_ids = tensorizer.text_to_tensor(ctx.passage_text, add_special_tokens=False)

            answer_spans = [
                _find_answer_positions(ctx.passage_token_ids, answers_token_ids[i]) for i in range(len(answers))
            ]

            # flatten the spans list
            answer_spans = [item for sublist in answer_spans for item in sublist]
            answer_spans = list(filter(None, answer_spans))
            ctx.answers_spans = answer_spans

            if not answer_spans:
                logger.warning(
                    "No answer found in passage id=%s text=%s, answers=%s, question=%s",
                    ctx.id,
                    "",  # ctx.passage_text
                    answers,
                    question,
                )
            ctx.has_answer = bool(answer_spans)
        return ctx

    # check if any of the selected ctx+ has answer spans
    selected_positive_ctxs = list(
        filter(
            lambda ctx: ctx.has_answer,
            [find_answer_spans(ctx) for ctx in positive_ctxs_from_gold_page],
        )
    )

    if not selected_positive_ctxs:  # fallback to positive ctxs not from the gold pages
        selected_positive_ctxs = list(
            filter(
                lambda ctx: ctx.has_answer,
                [find_answer_spans(ctx) for ctx in positive_samples],
            )
        )[0:max_positives]

    # optionally include the gold passage itself if it is still not in the positives list
    if include_gold_passage and question in gold_passage_map:
        gold_passage = gold_passage_map[question]
        included_gold_passage = next(
            iter(ctx for ctx in selected_positive_ctxs if ctx.passage_text == gold_passage.passage_text),
            None,
        )
        if not included_gold_passage:
            gold_passage.has_answer = True
            gold_passage = find_answer_spans(gold_passage)
            if not gold_passage.has_answer:
                logger.warning("No answer found in gold passage: %s", gold_passage)
            else:
                selected_positive_ctxs.append(gold_passage)

    max_negatives = (
        min(max(10 * len(selected_positive_ctxs), max1_negatives), max2_negatives)
        if is_train_set
        else DEFAULT_EVAL_PASSAGES
    )
    negative_samples = negative_samples[0:max_negatives]
    return selected_positive_ctxs, negative_samples
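

# Illustrative sketch (not part of the original file): the training negatives
# budget above is min(max(10 * num_positives, max1_negatives), max2_negatives).
# The caller passes cfg.max_negatives as max1 and cfg.min_negatives as max2, so
# with the defaults (50 and 150): 1 positive gives min(max(10, 50), 150) = 50
# negatives, while 20 positives give min(max(200, 50), 150) = 150 negatives.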


def _find_answer_positions(ctx_ids: T, answer: T) -> List[Tuple[int, int]]:
    c_len = ctx_ids.size(0)
    a_len = answer.size(0)
    answer_occurrences = []
    for i in range(0, c_len - a_len + 1):
        if (answer == ctx_ids[i : i + a_len]).all():
            answer_occurrences.append((i, i + a_len - 1))
    return answer_occurrences


def _concat_pair(t1: T, t2: T, middle_sep: T = None, tailing_sep: T = None):
    # compare against None explicitly: `if tensor` is ambiguous for multi-element tensors
    middle = [middle_sep] if middle_sep is not None else []
    r = [t1] + middle + [t2] + ([tailing_sep] if tailing_sep is not None else [])
    # the returned shift counts separator tokens, since the separator can be multi-token
    return torch.cat(r, dim=0), t1.size(0) + (middle_sep.size(0) if middle_sep is not None else 0)


def _get_gold_ctx_dict(file: str) -> Tuple[Dict[str, ReaderPassage], Dict[str, str]]:
    gold_passage_infos = {}  # question|question_tokens -> ReaderPassage (with title and gold ctx)

    # The original NQ dataset has 2 forms of the same question - original and tokenized.
    # The tokenized form is not fully consistent with the original question when tokenized by some encoder tokenizers.
    # Specifically, this is the case for the BERT tokenizer.
    # Depending on which form was used for retriever training and results generation, it may be useful to convert
    # all questions to the canonical original representation.
    original_questions = {}  # question from tokens -> original question (NQ only)

    with open(file, "r", encoding="utf-8") as f:
        logger.info("Reading file %s", file)
        data = json.load(f)["data"]

    for sample in data:
        question = sample["question"]
        question_from_tokens = sample["question_tokens"] if "question_tokens" in sample else question
        original_questions[question_from_tokens] = question
        title = sample["title"].lower()
        context = sample["context"]  # Note: This one is cased
        rp = ReaderPassage(sample["example_id"], text=context, title=title)
        if question in gold_passage_infos:
            logger.info("Duplicate question %s", question)
            rp_exist = gold_passage_infos[question]
            logger.info(
                "Duplicate question gold info: title new =%s | old title=%s",
                title,
                rp_exist.title,
            )
            logger.info("Duplicate question gold info: new ctx =%s ", context)
            logger.info("Duplicate question gold info: old ctx =%s ", rp_exist.passage_text)
        gold_passage_infos[question] = rp
        gold_passage_infos[question_from_tokens] = rp
    return gold_passage_infos, original_questions


def _is_from_gold_wiki_page(gold_passage_map: Dict[str, ReaderPassage], passage_title: str, question: str):
    gold_info = gold_passage_map.get(question, None)
    if gold_info:
        return passage_title.lower() == gold_info.title.lower()
    return False


def _extend_span_to_full_words(tensorizer: Tensorizer, tokens: List[int], span: Tuple[int, int]) -> Tuple[int, int]:
    start_index, end_index = span
    max_len = len(tokens)
    while start_index > 0 and tensorizer.is_sub_word_id(tokens[start_index]):
        start_index -= 1

    while end_index < max_len - 1 and tensorizer.is_sub_word_id(tokens[end_index + 1]):
        end_index += 1

    return start_index, end_index
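

# Illustrative sketch (not part of the original file): _extend_span_to_full_words()
# widens a BPE span to word boundaries. If the ids corresponded to the pieces
# ["pre", "##dict", "##ion"] (strings shown here only for readability) and the
# model predicted the span (1, 1), the span grows left to the word start and
# right to the word end, returning (0, 2).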


def _preprocess_reader_samples_chunk(
    samples: List,
    out_file_prefix: str,
    gold_passages_file: str,
    tensorizer: Tensorizer,
    is_train_set: bool,
) -> str:
    chunk_id, samples = samples
    logger.info("Start processing chunk of %d samples", len(samples))
    iterator = preprocess_retriever_data(
        samples,
        gold_passages_file,
        tensorizer,
        is_train_set=is_train_set,
    )

    results = []

    iterator = tqdm(iterator)
    for r in iterator:
        r.on_serialize()
        results.append(r)

    out_file = out_file_prefix + "." + str(chunk_id) + ".pkl"
    with open(out_file, mode="wb") as f:
        logger.info("Serialize %d results to %s", len(results), out_file)
        pickle.dump(results, f)
    return out_file
--------------------------------------------------------------------------------