├── dapr
│   ├── __init__.py
│   ├── exps
│   │   ├── __init__.py
│   │   ├── coref
│   │   │   ├── __init__.py
│   │   │   ├── args
│   │   │   │   ├── __init__.py
│   │   │   │   ├── spladev2.py
│   │   │   │   ├── colbertv2.py
│   │   │   │   ├── dragon_plus.py
│   │   │   │   └── base.py
│   │   │   ├── spladev2.py
│   │   │   ├── dragon_plus.py
│   │   │   ├── colbertv2.py
│   │   │   └── shared_pipeline.py
│   │   ├── keyphrases
│   │   │   ├── __init__.py
│   │   │   ├── args
│   │   │   │   ├── __init__.py
│   │   │   │   ├── spladev2.py
│   │   │   │   ├── colbertv2.py
│   │   │   │   ├── dragon_plus.py
│   │   │   │   └── base.py
│   │   │   ├── spladev2.py
│   │   │   ├── dragon_plus.py
│   │   │   ├── colbertv2.py
│   │   │   └── shared_pipeline.py
│   │   ├── passage_only
│   │   │   ├── __init__.py
│   │   │   ├── args
│   │   │   │   ├── __init__.py
│   │   │   │   ├── bm25.py
│   │   │   │   ├── retromae.py
│   │   │   │   ├── spladev2.py
│   │   │   │   ├── colbertv2.py
│   │   │   │   ├── base.py
│   │   │   │   ├── dragon_plus.py
│   │   │   │   └── jinav2.py
│   │   │   ├── bm25.py
│   │   │   ├── jinav2.py
│   │   │   ├── retromae.py
│   │   │   ├── spladev2.py
│   │   │   ├── dragon_plus.py
│   │   │   └── colbertv2.py
│   │   ├── bm25_doc_retrieval
│   │   │   ├── __init__.py
│   │   │   ├── args.py
│   │   │   └── pipeline.py
│   │   ├── bm25_doc_passage_fusion
│   │   │   ├── __init__.py
│   │   │   ├── args
│   │   │   │   ├── __init__.py
│   │   │   │   ├── bm25.py
│   │   │   │   ├── retromae.py
│   │   │   │   ├── spladev2.py
│   │   │   │   ├── colbertv2.py
│   │   │   │   ├── dragon_plus.py
│   │   │   │   └── base.py
│   │   │   ├── bm25.py
│   │   │   ├── retromae.py
│   │   │   ├── spladev2.py
│   │   │   ├── colbertv2.py
│   │   │   ├── dragon_plus.py
│   │   │   └── shared_pipeline.py
│   │   ├── bm25_doc_passage_hierarchy
│   │   │   ├── __init__.py
│   │   │   ├── args
│   │   │   │   ├── __init__.py
│   │   │   │   ├── colbertv2.py
│   │   │   │   ├── spladev2.py
│   │   │   │   ├── dragon_plus.py
│   │   │   │   └── base.py
│   │   │   ├── spladev2.py
│   │   │   ├── dragon_plus.py
│   │   │   ├── colbertv2.py
│   │   │   └── shared_pipeline.py
│   │   ├── doc_retrieval_with_titles
│   │   │   ├── __init__.py
│   │   │   ├── args
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base.py
│   │   │   │   ├── spladev2.py
│   │   │   │   ├── dragon_plus.py
│   │   │   │   └── colbertv2.py
│   │   │   ├── spladev2.py
│   │   │   ├── dragon_plus.py
│   │   │   ├── colbertv2.py
│   │   │   └── shared_pipeline.py
│   │   ├── jinav2_doc_passage_fusion
│   │   │   ├── __init__.py
│   │   │   ├── args
│   │   │   │   ├── __init__.py
│   │   │   │   ├── spladev2.py
│   │   │   │   ├── colbertv2.py
│   │   │   │   ├── dragon_plus.py
│   │   │   │   └── base.py
│   │   │   ├── spladev2.py
│   │   │   ├── colbertv2.py
│   │   │   ├── dragon_plus.py
│   │   │   └── shared_pipeline.py
│   │   └── prepending_titles
│   │       ├── args
│   │       │   ├── __init__.py
│   │       │   ├── base.py
│   │       │   ├── bm25.py
│   │       │   ├── colbertv2.py
│   │       │   ├── spladev2.py
│   │       │   └── dragon_plus.py
│   │       ├── bm25.py
│   │       ├── spladev2.py
│   │       ├── dragon_plus.py
│   │       └── colbertv2.py
│   ├── annotators
│   │   ├── __init__.py
│   │   ├── base.py
│   │   └── pke.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── conditionalqa.py
│   │   ├── tagged_conditionalqa.py
│   │   ├── base.py
│   │   └── genomics.py
│   ├── retrievers
│   │   ├── __init__.py
│   │   ├── sparse.py
│   │   ├── late_interaction.py
│   │   ├── dense.py
│   │   └── bm25.py
│   ├── dm.py
│   ├── fusion.py
│   └── dataloader.py
├── imgs
│   └── motivative-example.png
├── scripts
│   └── dgx2
│       └── exps
│           ├── jinav2_doc_passage_fusion
│           │   ├── dev
│           │   │   ├── aio.sh
│           │   │   ├── spladev2.sh
│           │   │   ├── colbertv2.sh
│           │   │   └── dragon_plus.sh
│           │   ├── spladev2.sh
│           │   ├── colbertv2.sh
│           │   └── dragon_plus.sh
│           ├── passage_only
│           │   ├── retromae.sh
│           │   ├── dev
│           │   │   ├── spladev2.sh
│           │   │   ├── colbertv2.sh
│           │   │   └── dragon_plus.sh
│           │   ├── bm25.sh
│           │   ├── jinav2.sh
│           │   ├── spladev2.sh
│           │   ├── colbertv2.sh
│           │   └── dragon_plus.sh
│           ├── prepending_titles
│           │   ├── bm25.sh
│           │   ├── spladev2.sh
│           │   ├── dragon_plus.sh
│           │   └── colbertv2.sh
│           ├── bm25_doc_passage_hierarchy
│           │   ├── dev
│           │   │   ├── spladev2.sh
│           │   │   ├── colbertv2.sh
│           │   │   └── dragon_plus.sh
│           │   ├── spladev2.sh
│           │   ├── colbertv2.sh
│           │   └── dragon_plus.sh
│           ├── README.md
│           ├── bm25_doc_passage_fusion
│           │   ├── retromae.sh
│           │   ├── dev
│           │   │   ├── colbertv2.sh
│           │   │   ├── spladev2.sh
│           │   │   └── dragon_plus.sh
│           │   ├── bm25.sh
│           │   ├── colbertv2.sh
│           │   ├── dragon_plus.sh
│           │   └── spladev2.sh
│           ├── bm25_doc_retrieval
│           │   └── run.sh
│           ├── coref
│           │   ├── spladev2.sh
│           │   ├── colbertv2.sh
│           │   └── dragon_plus.sh
│           ├── doc_retrieval_with_titles
│           │   ├── spladev2.sh
│           │   ├── dragon_plus.sh
│           │   └── colbertv2.sh
│           └── keyphrases
│               ├── colbertv2.sh
│               ├── spladev2.sh
│               └── dragon_plus.sh
├── NOTICE.txt
├── requirements.txt
├── setup.py
├── .gitignore
├── environment.yml
├── README.md
└── LICENSE

/dapr/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dapr/exps/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dapr/annotators/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dapr/datasets/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dapr/exps/coref/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dapr/retrievers/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dapr/exps/coref/args/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dapr/exps/keyphrases/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dapr/exps/keyphrases/args/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dapr/exps/passage_only/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dapr/exps/bm25_doc_retrieval/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dapr/exps/passage_only/args/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dapr/exps/bm25_doc_passage_fusion/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dapr/exps/bm25_doc_passage_hierarchy/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dapr/exps/doc_retrieval_with_titles/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dapr/exps/jinav2_doc_passage_fusion/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dapr/exps/prepending_titles/args/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dapr/exps/bm25_doc_passage_fusion/args/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dapr/exps/bm25_doc_passage_hierarchy/args/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dapr/exps/doc_retrieval_with_titles/args/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dapr/exps/jinav2_doc_passage_fusion/args/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/imgs/motivative-example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UKPLab/acl2024-dapr/HEAD/imgs/motivative-example.png
--------------------------------------------------------------------------------
/scripts/dgx2/exps/jinav2_doc_passage_fusion/dev/aio.sh:
--------------------------------------------------------------------------------
bash scripts/dgx2/exps/jinav2_doc_passage_fusion/dev/dragon_plus.sh
bash scripts/dgx2/exps/jinav2_doc_passage_fusion/dev/spladev2.sh
bash scripts/dgx2/exps/jinav2_doc_passage_fusion/dev/colbertv2.sh
--------------------------------------------------------------------------------
/NOTICE.txt:
--------------------------------------------------------------------------------
-------------------------------------------------------------------------------
Copyright 2023
Ubiquitous Knowledge Processing (UKP) Lab
Technische Universität Darmstadt
-------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/dapr/dm.py:
--------------------------------------------------------------------------------
from __future__ import annotations
from enum import Enum


class RetrievalLevel(str, Enum):
    paragraph = "paragraph"
    document = "document"


class ParagraphSeparator(str, Enum):
    blank = "blank"
    newline = "newline"

    @property
    def string(self) -> str:
        return {self.blank: " ", self.newline: "\n"}[self]
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
ColBERT @ git+https://github.com/stanford-futuredata/ColBERT.git@83658dcab6b48ec089c186508355b55efb6e2ba8
ujson==5.7.0
datasets==2.16.1
more-itertools==9.1.0
matplotlib==3.7.4
pytrec-eval==0.5
pke @ git+https://github.com/boudinfl/pke.git@ebd6e5754b4156a61a4ec6c4c283e821d11a36be
ir-datasets==0.5.4
pyserini==0.21.0
transformers==4.37.2
clddp==0.0.8
--------------------------------------------------------------------------------
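A quick usage sketch for the two enums defined in /dapr/dm.py above (an illustrative snippet, not a file in the repo): both are str-Enums, so members compare equal to plain strings, and ParagraphSeparator.string maps a member to the actual separator character.

    from dapr.dm import ParagraphSeparator, RetrievalLevel

    # str-Enum members behave like plain strings.
    assert RetrievalLevel.document == "document"

    # ParagraphSeparator.string yields " " for blank and "\n" for newline.
    paragraphs = ["First paragraph.", "Second paragraph."]
    doc_text = ParagraphSeparator.newline.string.join(paragraphs)
    assert doc_text == "First paragraph.\nSecond paragraph."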
/dapr/exps/bm25_doc_passage_fusion/bm25.py:
--------------------------------------------------------------------------------
from dapr.exps.bm25_doc_passage_fusion.args.bm25 import (
    BM25DocBM25PassageFusionArguments,
)
from dapr.exps.bm25_doc_passage_fusion.shared_pipeline import (
    run_bm25_doc_passage_fusion,
)

if __name__ == "__main__":
    run_bm25_doc_passage_fusion(
        arguments_class=BM25DocBM25PassageFusionArguments,
        passage_retriever_name="bm25",
    )
--------------------------------------------------------------------------------
/dapr/exps/bm25_doc_passage_fusion/retromae.py:
--------------------------------------------------------------------------------
from dapr.exps.bm25_doc_passage_fusion.args.retromae import (
    RetroMAEBM25DocPassageFusionArguments,
)
from dapr.exps.bm25_doc_passage_fusion.shared_pipeline import (
    run_bm25_doc_passage_fusion,
)

if __name__ == "__main__":
    run_bm25_doc_passage_fusion(
        arguments_class=RetroMAEBM25DocPassageFusionArguments,
        passage_retriever_name="retromae",
    )
--------------------------------------------------------------------------------
/dapr/exps/bm25_doc_passage_fusion/spladev2.py:
--------------------------------------------------------------------------------
from dapr.exps.bm25_doc_passage_fusion.args.spladev2 import (
    SPLADEv2BM25DocPassageFusionArguments,
)
from dapr.exps.bm25_doc_passage_fusion.shared_pipeline import (
    run_bm25_doc_passage_fusion,
)

if __name__ == "__main__":
    run_bm25_doc_passage_fusion(
        arguments_class=SPLADEv2BM25DocPassageFusionArguments,
        passage_retriever_name="spladev2",
    )
--------------------------------------------------------------------------------
/dapr/exps/bm25_doc_passage_fusion/colbertv2.py:
--------------------------------------------------------------------------------
from dapr.exps.bm25_doc_passage_fusion.args.colbertv2 import (
    CoLBERTv2BM25DocPassageFusionArguments,
)
from dapr.exps.bm25_doc_passage_fusion.shared_pipeline import (
    run_bm25_doc_passage_fusion,
)

if __name__ == "__main__":
    run_bm25_doc_passage_fusion(
        arguments_class=CoLBERTv2BM25DocPassageFusionArguments,
        passage_retriever_name="colbertv2",
    )
--------------------------------------------------------------------------------
/dapr/exps/bm25_doc_passage_fusion/dragon_plus.py:
--------------------------------------------------------------------------------
from dapr.exps.bm25_doc_passage_fusion.args.dragon_plus import (
    DRAGONPlusBM25DocPassageFusionArguments,
)
from dapr.exps.bm25_doc_passage_fusion.shared_pipeline import (
    run_bm25_doc_passage_fusion,
)

if __name__ == "__main__":
    run_bm25_doc_passage_fusion(
        arguments_class=DRAGONPlusBM25DocPassageFusionArguments,
        passage_retriever_name="dragon_plus",
    )
--------------------------------------------------------------------------------
/dapr/exps/jinav2_doc_passage_fusion/spladev2.py:
--------------------------------------------------------------------------------
from dapr.exps.jinav2_doc_passage_fusion.args.spladev2 import (
    SPLADEv2JinaV2DocPassageFusionArguments,
)
from dapr.exps.jinav2_doc_passage_fusion.shared_pipeline import (
    run_jina_doc_passage_fusion,
)

if __name__ == "__main__":
    run_jina_doc_passage_fusion(
        arguments_class=SPLADEv2JinaV2DocPassageFusionArguments,
        passage_retriever_name="spladev2",
    )
--------------------------------------------------------------------------------
/dapr/exps/jinav2_doc_passage_fusion/colbertv2.py:
--------------------------------------------------------------------------------
from dapr.exps.jinav2_doc_passage_fusion.args.colbertv2 import (
    CoLBERTv2JinaV2DocPassageFusionArguments,
)
from dapr.exps.jinav2_doc_passage_fusion.shared_pipeline import (
    run_jina_doc_passage_fusion,
)

if __name__ == "__main__":
    run_jina_doc_passage_fusion(
        arguments_class=CoLBERTv2JinaV2DocPassageFusionArguments,
        passage_retriever_name="colbertv2",
    )
--------------------------------------------------------------------------------
/dapr/exps/jinav2_doc_passage_fusion/dragon_plus.py:
--------------------------------------------------------------------------------
from dapr.exps.jinav2_doc_passage_fusion.args.dragon_plus import (
    DRAGONPlusJinaV2DocPassageFusionArguments,
)
from dapr.exps.jinav2_doc_passage_fusion.shared_pipeline import (
    run_jina_doc_passage_fusion,
)

if __name__ == "__main__":
    run_jina_doc_passage_fusion(
        arguments_class=DRAGONPlusJinaV2DocPassageFusionArguments,
        passage_retriever_name="dragon_plus",
    )
--------------------------------------------------------------------------------
/dapr/exps/coref/args/spladev2.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.coref.args.base import CorefArguments


@dataclass
class SPLADEvCorefArguments(CorefArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/coref/spladev2", self.run_name)


if __name__ == "__main__":
    print(parse_cli(SPLADEvCorefArguments).output_dir)  # For creating the logging path
--------------------------------------------------------------------------------
/dapr/exps/coref/args/colbertv2.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.coref.args.base import CorefArguments


@dataclass
class ColBERTv2CorefArguments(CorefArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/coref/colbertv2", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(ColBERTv2CorefArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/dapr/exps/coref/args/dragon_plus.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.coref.args.base import CorefArguments


@dataclass
class DRAGONPlusCorefArguments(CorefArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/coref/dragon_plus", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(DRAGONPlusCorefArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/scripts/dgx2/exps/passage_only/retromae.sh:
--------------------------------------------------------------------------------
export CUDA_VISIBLE_DEVICES="0,1,2,3"

dataset="ConditionalQA"
export DATA_DIR="data"
export DATASET_PATH="$DATA_DIR/$dataset"
export CLI_ARGS="
--data_dir=$DATASET_PATH
"
export OUTPUT_DIR=$(python -m dapr.exps.passage_only.args.retromae $CLI_ARGS)
mkdir -p $OUTPUT_DIR
export LOG_PATH="$OUTPUT_DIR/logging.log"
echo "Logging file path: $LOG_PATH"
nohup torchrun --nproc_per_node=2 --master_port=29501 -m dapr.exps.passage_only.retromae $CLI_ARGS > $LOG_PATH &
--------------------------------------------------------------------------------
/dapr/exps/keyphrases/args/spladev2.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.keyphrases.args.base import KeyphrasesArguments


@dataclass
class SPLADEvKeyphrasesArguments(KeyphrasesArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/keyphrases/spladev2", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(SPLADEvKeyphrasesArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/dapr/exps/passage_only/args/bm25.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.passage_only.args.base import PassageOnlyArguments


@dataclass
class BM25PassageOnlyArguments(PassageOnlyArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/passage_only/bm25", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(BM25PassageOnlyArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/dapr/exps/coref/args/base.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import Optional
from clddp.args.base import AutoRunNameArgumentsMixIn, DumpableArgumentsMixIn
from clddp.dm import Split


@dataclass
class CorefArguments(AutoRunNameArgumentsMixIn, DumpableArgumentsMixIn):
    data_dir: Optional[str] = None
    split: Split = Split.test
    topk: int = 1000
    per_device_eval_batch_size: int = 32
    fp16: bool = True

    def __post_init__(self) -> None:
        super().__post_init__()
--------------------------------------------------------------------------------
/dapr/exps/keyphrases/args/colbertv2.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.keyphrases.args.base import KeyphrasesArguments


@dataclass
class ColBERTv2KeyphrasesArguments(KeyphrasesArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/keyphrases/colbertv2", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(ColBERTv2KeyphrasesArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/dapr/exps/keyphrases/args/dragon_plus.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.keyphrases.args.base import KeyphrasesArguments


@dataclass
class DRAGONPlusKeyphrasesArguments(KeyphrasesArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/keyphrases/dragon_plus", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(DRAGONPlusKeyphrasesArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/dapr/exps/passage_only/args/retromae.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.passage_only.args.base import PassageOnlyArguments


@dataclass
class RetroMAEPassageOnlyArguments(PassageOnlyArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/passage_only/retromae", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(RetroMAEPassageOnlyArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/dapr/exps/passage_only/args/spladev2.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.passage_only.args.base import PassageOnlyArguments


@dataclass
class SPLADEv2PassageOnlyArguments(PassageOnlyArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/passage_only/spladev2", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(SPLADEv2PassageOnlyArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/dapr/exps/passage_only/args/colbertv2.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.passage_only.args.base import PassageOnlyArguments


@dataclass
class ColBERTv2PassageOnlyArguments(PassageOnlyArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/passage_only/colbertv2", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(ColBERTv2PassageOnlyArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/scripts/dgx2/exps/passage_only/dev/spladev2.sh:
--------------------------------------------------------------------------------
export NCCL_DEBUG="INFO"
export CUDA_VISIBLE_DEVICES="0,1,2,3"

export DATA_DIR="data"
export DATASET_PATH="$DATA_DIR/MSMARCO"
export CLI_ARGS="
--data_dir=$DATASET_PATH
--split=dev
"
export OUTPUT_DIR=$(python -m dapr.exps.passage_only.args.spladev2 $CLI_ARGS)
mkdir -p $OUTPUT_DIR
export LOG_PATH="$OUTPUT_DIR/logging.log"
echo "Logging file path: $LOG_PATH"
setsid nohup torchrun --nproc_per_node=2 --master_port=29501 -m dapr.exps.passage_only.spladev2 $CLI_ARGS > $LOG_PATH &
--------------------------------------------------------------------------------
/dapr/exps/passage_only/args/base.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import Optional
from clddp.args.base import AutoRunNameArgumentsMixIn, DumpableArgumentsMixIn
from clddp.dm import Split


@dataclass
class PassageOnlyArguments(AutoRunNameArgumentsMixIn, DumpableArgumentsMixIn):
    data_dir: Optional[str] = None
    split: Split = Split.test
    topk: int = 1000
    per_device_eval_batch_size: int = 32
    fp16: bool = True

    def __post_init__(self) -> None:
        super().__post_init__()
--------------------------------------------------------------------------------
/dapr/exps/passage_only/args/dragon_plus.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.passage_only.args.base import PassageOnlyArguments


@dataclass
class DRAGONPlusPassageOnlyArguments(PassageOnlyArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/passage_only/dragon_plus", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(DRAGONPlusPassageOnlyArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/scripts/dgx2/exps/passage_only/dev/colbertv2.sh:
--------------------------------------------------------------------------------
export NCCL_DEBUG="INFO"
export CUDA_VISIBLE_DEVICES="0,1,2,3"

export DATA_DIR="data"
export DATASET_PATH="$DATA_DIR/MSMARCO"
export CLI_ARGS="
--data_dir=$DATASET_PATH
--split=dev
"
export OUTPUT_DIR=$(python -m dapr.exps.passage_only.args.colbertv2 $CLI_ARGS)
mkdir -p $OUTPUT_DIR
export LOG_PATH="$OUTPUT_DIR/logging.log"
echo "Logging file path: $LOG_PATH"
setsid nohup torchrun --nproc_per_node=2 --master_port=29501 -m dapr.exps.passage_only.colbertv2 $CLI_ARGS > $LOG_PATH &
--------------------------------------------------------------------------------
/scripts/dgx2/exps/passage_only/bm25.sh:
--------------------------------------------------------------------------------
datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" )
for dataset in ${datasets[@]}
do
    export DATA_DIR="data"
    export DATASET_PATH="$DATA_DIR/$dataset"
    export CLI_ARGS="
    --data_dir=$DATASET_PATH
    "
    export OUTPUT_DIR=$(python -m dapr.exps.passage_only.args.bm25 $CLI_ARGS)
    mkdir -p $OUTPUT_DIR
    export LOG_PATH="$OUTPUT_DIR/logging.log"
    echo "Logging file path: $LOG_PATH"
    python -m dapr.exps.passage_only.bm25 $CLI_ARGS > $LOG_PATH
done
--------------------------------------------------------------------------------
/dapr/exps/prepending_titles/args/base.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import Optional
from clddp.args.base import AutoRunNameArgumentsMixIn, DumpableArgumentsMixIn
from clddp.dm import Split


@dataclass
class PrependingTitlesArguments(AutoRunNameArgumentsMixIn, DumpableArgumentsMixIn):
    data_dir: Optional[str] = None
    split: Split = Split.test
    topk: int = 1000
    per_device_eval_batch_size: int = 32
    fp16: bool = True

    def __post_init__(self) -> None:
        super().__post_init__()
--------------------------------------------------------------------------------
/dapr/exps/prepending_titles/args/bm25.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.prepending_titles.args.base import PrependingTitlesArguments


@dataclass
class BM25PrependingTitlesArguments(PrependingTitlesArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/prepending_titles/bm25", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(BM25PrependingTitlesArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/scripts/dgx2/exps/prepending_titles/bm25.sh:
--------------------------------------------------------------------------------
datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" )
for dataset in ${datasets[@]}
do
    export DATA_DIR="data"
    export DATASET_PATH="$DATA_DIR/$dataset"
    export CLI_ARGS="
    --data_dir=$DATASET_PATH
    "
    export OUTPUT_DIR=$(python -m dapr.exps.prepending_titles.args.bm25 $CLI_ARGS)
    mkdir -p $OUTPUT_DIR
    export LOG_PATH="$OUTPUT_DIR/logging.log"
    echo "Logging file path: $LOG_PATH"
    python -m dapr.exps.prepending_titles.bm25 $CLI_ARGS > $LOG_PATH
done
--------------------------------------------------------------------------------
/dapr/exps/coref/spladev2.py:
--------------------------------------------------------------------------------
import logging
from clddp.utils import initialize_ddp, is_device_zero, parse_cli, set_logger_format
from dapr.retrievers.sparse import SPLADEv2
from dapr.exps.coref.shared_pipeline import run_coref
from dapr.exps.coref.args.spladev2 import SPLADEvCorefArguments

if __name__ == "__main__":
    set_logger_format(logging.INFO if is_device_zero() else logging.WARNING)
    initialize_ddp()
    args = parse_cli(SPLADEvCorefArguments)
    retriever = SPLADEv2()
    run_coref(args=args, retriever=retriever, retriever_name="spladev2")
--------------------------------------------------------------------------------
/scripts/dgx2/exps/bm25_doc_passage_hierarchy/dev/spladev2.sh:
--------------------------------------------------------------------------------
export CUDA_VISIBLE_DEVICES="0,1,2,3"
dataset="MSMARCO"
export DATA_DIR="data"
export DATASET_PATH="$DATA_DIR/$dataset"
export CLI_ARGS="
--data_dir=$DATASET_PATH
--split=dev
"
export OUTPUT_DIR=$(python -m dapr.exps.bm25_doc_passage_hierarchy.args.spladev2 $CLI_ARGS)
mkdir -p $OUTPUT_DIR
export LOG_PATH="$OUTPUT_DIR/logging.log"
echo "Logging file path: $LOG_PATH"
setsid nohup torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.bm25_doc_passage_hierarchy.spladev2 $CLI_ARGS > $LOG_PATH &
--------------------------------------------------------------------------------
/dapr/exps/prepending_titles/args/colbertv2.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.prepending_titles.args.base import PrependingTitlesArguments


@dataclass
class ColBERTv2PrependingTitlesArguments(PrependingTitlesArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/prepending_titles/colbertv2", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(ColBERTv2PrependingTitlesArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/dapr/exps/prepending_titles/args/spladev2.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.prepending_titles.args.base import PrependingTitlesArguments


@dataclass
class SPLADEv2lusPrependingTitlesArguments(PrependingTitlesArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/prepending_titles/spladev2", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(SPLADEv2lusPrependingTitlesArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/scripts/dgx2/exps/bm25_doc_passage_hierarchy/dev/colbertv2.sh:
--------------------------------------------------------------------------------
export CUDA_VISIBLE_DEVICES="0,1,2,3"
dataset="MSMARCO"
export DATA_DIR="data"
export DATASET_PATH="$DATA_DIR/$dataset"
export CLI_ARGS="
--data_dir=$DATASET_PATH
--split=dev
"
export OUTPUT_DIR=$(python -m dapr.exps.bm25_doc_passage_hierarchy.args.colbertv2 $CLI_ARGS)
mkdir -p $OUTPUT_DIR
export LOG_PATH="$OUTPUT_DIR/logging.log"
echo "Logging file path: $LOG_PATH"
setsid nohup torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.bm25_doc_passage_hierarchy.colbertv2 $CLI_ARGS > $LOG_PATH &
--------------------------------------------------------------------------------
/dapr/exps/prepending_titles/args/dragon_plus.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.prepending_titles.args.base import PrependingTitlesArguments


@dataclass
class DRAGONPlusPrependingTitlesArguments(PrependingTitlesArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/prepending_titles/dragon_plus", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(DRAGONPlusPrependingTitlesArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/scripts/dgx2/exps/bm25_doc_passage_hierarchy/dev/dragon_plus.sh:
--------------------------------------------------------------------------------
export CUDA_VISIBLE_DEVICES="0,1,2,3"
dataset="MSMARCO"
export DATA_DIR="data"
export DATASET_PATH="$DATA_DIR/$dataset"
export CLI_ARGS="
--data_dir=$DATASET_PATH
--split=dev
"
export OUTPUT_DIR=$(python -m dapr.exps.bm25_doc_passage_hierarchy.args.dragon_plus $CLI_ARGS)
mkdir -p $OUTPUT_DIR
export LOG_PATH="$OUTPUT_DIR/logging.log"
echo "Logging file path: $LOG_PATH"
setsid nohup torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.bm25_doc_passage_hierarchy.dragon_plus $CLI_ARGS > $LOG_PATH &
--------------------------------------------------------------------------------
/dapr/exps/coref/dragon_plus.py:
--------------------------------------------------------------------------------
import logging
from clddp.utils import initialize_ddp, is_device_zero, parse_cli, set_logger_format
from dapr.retrievers.dense import DRAGONPlus
from dapr.exps.coref.shared_pipeline import run_coref
from dapr.exps.coref.args.dragon_plus import DRAGONPlusCorefArguments

if __name__ == "__main__":
    set_logger_format(logging.INFO if is_device_zero() else logging.WARNING)
    initialize_ddp()
    args = parse_cli(DRAGONPlusCorefArguments)
    retriever = DRAGONPlus()
    run_coref(args=args, retriever=retriever, retriever_name="dragon_plus")
--------------------------------------------------------------------------------
/dapr/exps/doc_retrieval_with_titles/args/base.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import Optional
from clddp.args.base import AutoRunNameArgumentsMixIn, DumpableArgumentsMixIn
from clddp.dm import Split


@dataclass
class DocRetrievalWithTitlesArguments(
    AutoRunNameArgumentsMixIn, DumpableArgumentsMixIn
):
    data_dir: Optional[str] = None
    split: Split = Split.test
    topk: int = 1000
    per_device_eval_batch_size: int = 32
    fp16: bool = True

    def __post_init__(self) -> None:
        super().__post_init__()
--------------------------------------------------------------------------------
/dapr/exps/coref/colbertv2.py:
--------------------------------------------------------------------------------
import logging
from clddp.utils import initialize_ddp, is_device_zero, parse_cli, set_logger_format
from dapr.retrievers.late_interaction import ColBERTv2
from dapr.exps.coref.shared_pipeline import run_coref
from dapr.exps.coref.args.colbertv2 import ColBERTv2CorefArguments

if __name__ == "__main__":
    set_logger_format(logging.INFO if is_device_zero() else logging.WARNING)
    initialize_ddp()
    args = parse_cli(ColBERTv2CorefArguments)
    retriever = ColBERTv2()
    run_coref(args=args, retriever=retriever, retriever_name="colbertv2")
--------------------------------------------------------------------------------
/dapr/exps/bm25_doc_passage_fusion/args/bm25.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.bm25_doc_passage_fusion.args.base import (
    BM25DocPassageFusionArguments,
)


@dataclass
class BM25DocBM25PassageFusionArguments(BM25DocPassageFusionArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/bm25_doc_passage_fusion/bm25", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(BM25DocBM25PassageFusionArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/dapr/exps/keyphrases/spladev2.py:
--------------------------------------------------------------------------------
import logging
from clddp.utils import initialize_ddp, is_device_zero, parse_cli, set_logger_format
from dapr.retrievers.sparse import SPLADEv2
from dapr.exps.keyphrases.shared_pipeline import run_keyphrases
from dapr.exps.keyphrases.args.spladev2 import SPLADEvKeyphrasesArguments

if __name__ == "__main__":
    set_logger_format(logging.INFO if is_device_zero() else logging.WARNING)
    initialize_ddp()
    args = parse_cli(SPLADEvKeyphrasesArguments)
    retriever = SPLADEv2()
    run_keyphrases(args=args, retriever=retriever, retriever_name="spladev2")
--------------------------------------------------------------------------------
/dapr/exps/bm25_doc_passage_fusion/args/retromae.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.bm25_doc_passage_fusion.args.base import (
    BM25DocPassageFusionArguments,
)


@dataclass
class RetroMAEBM25DocPassageFusionArguments(BM25DocPassageFusionArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/bm25_doc_passage_fusion/retromae", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(RetroMAEBM25DocPassageFusionArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/dapr/exps/bm25_doc_passage_fusion/args/spladev2.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.bm25_doc_passage_fusion.args.base import (
    BM25DocPassageFusionArguments,
)


@dataclass
class SPLADEv2BM25DocPassageFusionArguments(BM25DocPassageFusionArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/bm25_doc_passage_fusion/spladev2", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(SPLADEv2BM25DocPassageFusionArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/dapr/exps/keyphrases/dragon_plus.py:
--------------------------------------------------------------------------------
import logging
from clddp.utils import initialize_ddp, is_device_zero, parse_cli, set_logger_format
from dapr.retrievers.dense import DRAGONPlus
from dapr.exps.keyphrases.shared_pipeline import run_keyphrases
from dapr.exps.keyphrases.args.dragon_plus import DRAGONPlusKeyphrasesArguments

if __name__ == "__main__":
    set_logger_format(logging.INFO if is_device_zero() else logging.WARNING)
    initialize_ddp()
    args = parse_cli(DRAGONPlusKeyphrasesArguments)
    retriever = DRAGONPlus()
    run_keyphrases(args=args, retriever=retriever, retriever_name="dragon_plus")
--------------------------------------------------------------------------------
/scripts/dgx2/exps/README.md:
--------------------------------------------------------------------------------
# Experiment scripts

This folder contains the experiment scripts for the DAPR paper. With the data downloaded/built into the `./data` folder, most scripts can be run directly.
> For `coref` and `keyphrases`, the corresponding data are needed:
> - https://public.ukp.informatik.tu-darmstadt.de/kwang/dapr/data/coref
> - https://public.ukp.informatik.tu-darmstadt.de/kwang/dapr/data/keyphrases

The hybrid-search experiments (i.e. `bm25_doc_passage_fusion` and `bm25_doc_passage_hierarchy`) depend on the `passage_only` results, so please run the corresponding `passage_only` experiment beforehand.
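For example, one possible order for the DRAGON+ fusion numbers (an illustrative sketch; both scripts live in this folder, and the fusion run expects the `ranking_results.txt` produced by the first run):

```bash
# 1) Passage-level retrieval; writes ranking_results.txt under exps/passage_only/dragon_plus/...
bash scripts/dgx2/exps/passage_only/dragon_plus.sh
# 2) Fusion run, which combines the BM25 document scores with the passage results from step 1.
bash scripts/dgx2/exps/bm25_doc_passage_fusion/dragon_plus.sh
```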


--------------------------------------------------------------------------------
/dapr/exps/bm25_doc_passage_fusion/args/colbertv2.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.bm25_doc_passage_fusion.args.base import (
    BM25DocPassageFusionArguments,
)


@dataclass
class CoLBERTv2BM25DocPassageFusionArguments(BM25DocPassageFusionArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/bm25_doc_passage_fusion/colbertv2", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(CoLBERTv2BM25DocPassageFusionArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/dapr/exps/keyphrases/colbertv2.py:
--------------------------------------------------------------------------------
import logging
from clddp.utils import initialize_ddp, is_device_zero, parse_cli, set_logger_format
from dapr.retrievers.late_interaction import ColBERTv2
from dapr.exps.keyphrases.shared_pipeline import run_keyphrases
from dapr.exps.keyphrases.args.colbertv2 import ColBERTv2KeyphrasesArguments

if __name__ == "__main__":
    set_logger_format(logging.INFO if is_device_zero() else logging.WARNING)
    initialize_ddp()
    args = parse_cli(ColBERTv2KeyphrasesArguments)
    retriever = ColBERTv2()
    run_keyphrases(args=args, retriever=retriever, retriever_name="colbertv2")
--------------------------------------------------------------------------------
/dapr/retrievers/sparse.py:
--------------------------------------------------------------------------------
from clddp.retriever import Retriever, RetrieverConfig, Pooling, SimilarityFunction
from clddp.dm import Separator


class SPLADEv2(Retriever):
    def __init__(self) -> None:
        config = RetrieverConfig(
            query_model_name_or_path="naver/splade-cocondenser-ensembledistil",
            shared_encoder=True,
            sep=Separator.blank,
            pooling=Pooling.splade,
            similarity_function=SimilarityFunction.dot_product,
            query_max_length=512,
            passage_max_length=512,
        )
        super().__init__(config)
--------------------------------------------------------------------------------
/dapr/exps/bm25_doc_passage_fusion/args/dragon_plus.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.bm25_doc_passage_fusion.args.base import (
    BM25DocPassageFusionArguments,
)


@dataclass
class DRAGONPlusBM25DocPassageFusionArguments(BM25DocPassageFusionArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/bm25_doc_passage_fusion/dragon_plus", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(DRAGONPlusBM25DocPassageFusionArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/dapr/exps/jinav2_doc_passage_fusion/args/spladev2.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.jinav2_doc_passage_fusion.args.base import (
    JinaV2DocPassageFusionArguments,
)


@dataclass
class SPLADEv2JinaV2DocPassageFusionArguments(JinaV2DocPassageFusionArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/jinav2_doc_passage_fusion/spladev2", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(SPLADEv2JinaV2DocPassageFusionArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/scripts/dgx2/exps/bm25_doc_passage_fusion/retromae.sh:
--------------------------------------------------------------------------------
dataset="ConditionalQA"
export DATA_DIR="data"
export DATASET_PATH="$DATA_DIR/$dataset"
export CLI_ARGS="
--data_dir=$DATASET_PATH
--passage_results="$(ls exps/passage_only/retromae/data_dir_data/ConditionalQA/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)"
"
export OUTPUT_DIR=$(python -m dapr.exps.bm25_doc_passage_fusion.args.retromae $CLI_ARGS)
mkdir -p $OUTPUT_DIR
export LOG_PATH="$OUTPUT_DIR/logging.log"
echo "Logging file path: $LOG_PATH"
nohup python -m dapr.exps.bm25_doc_passage_fusion.retromae $CLI_ARGS > $LOG_PATH &
--------------------------------------------------------------------------------
/dapr/exps/doc_retrieval_with_titles/args/spladev2.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.doc_retrieval_with_titles.args.base import (
    DocRetrievalWithTitlesArguments,
)


@dataclass
class SPLADEv2lusDocRetrievalWithTitlesArguments(DocRetrievalWithTitlesArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/doc_retrieval_with_titles/spladev2", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(SPLADEv2lusDocRetrievalWithTitlesArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/dapr/exps/jinav2_doc_passage_fusion/args/colbertv2.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.jinav2_doc_passage_fusion.args.base import (
    JinaV2DocPassageFusionArguments,
)


@dataclass
class CoLBERTv2JinaV2DocPassageFusionArguments(JinaV2DocPassageFusionArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/jinav2_doc_passage_fusion/colbertv2", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(CoLBERTv2JinaV2DocPassageFusionArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/dapr/exps/bm25_doc_passage_hierarchy/args/colbertv2.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.bm25_doc_passage_hierarchy.args.base import (
    BM25DocPassageHierarchyArguments,
)


@dataclass
class CoLBERTv2BM25DocPassageHierarchyArguments(BM25DocPassageHierarchyArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/bm25_doc_passage_hierarchy/colbertv2", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(CoLBERTv2BM25DocPassageHierarchyArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/dapr/exps/bm25_doc_passage_hierarchy/args/spladev2.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.bm25_doc_passage_hierarchy.args.base import (
    BM25DocPassageHierarchyArguments,
)


@dataclass
class SPLADEv2BM25DocPassageHierarchyArguments(BM25DocPassageHierarchyArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/bm25_doc_passage_hierarchy/spladev2", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(SPLADEv2BM25DocPassageHierarchyArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/dapr/exps/doc_retrieval_with_titles/args/dragon_plus.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.doc_retrieval_with_titles.args.base import (
    DocRetrievalWithTitlesArguments,
)


@dataclass
class DRAGONPlusDocRetrievalWithTitlesArguments(DocRetrievalWithTitlesArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/doc_retrieval_with_titles/dragon_plus", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(DRAGONPlusDocRetrievalWithTitlesArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/dapr/exps/jinav2_doc_passage_fusion/args/dragon_plus.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.jinav2_doc_passage_fusion.args.base import (
    JinaV2DocPassageFusionArguments,
)


@dataclass
class DRAGONPlusJinaV2DocPassageFusionArguments(JinaV2DocPassageFusionArguments):
    def build_output_dir(self) -> str:
        return os.path.join("exps/jinav2_doc_passage_fusion/dragon_plus", self.run_name)


if __name__ == "__main__":
    print(
        parse_cli(DRAGONPlusJinaV2DocPassageFusionArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/scripts/dgx2/exps/bm25_doc_retrieval/run.sh:
--------------------------------------------------------------------------------
# Iterate over the datasets directly; the original indexed loop referenced an
# undefined passage_results_paths array whose value was never used.
datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" )
for dataset in ${datasets[@]}
do
    export DATA_DIR="data"
    export DATASET_PATH="$DATA_DIR/$dataset"
    export CLI_ARGS="
    --data_dir=$DATASET_PATH
    "
    export OUTPUT_DIR=$(python -m dapr.exps.bm25_doc_retrieval.args $CLI_ARGS)
    mkdir -p $OUTPUT_DIR
    export LOG_PATH="$OUTPUT_DIR/logging.log"
    echo "Logging file path: $LOG_PATH"
    python -m dapr.exps.bm25_doc_retrieval.pipeline $CLI_ARGS > $LOG_PATH
done
--------------------------------------------------------------------------------
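The fusion `dev` scripts below locate the finished `passage_only` run by globbing for its `ranking_results.txt` inside the timestamped run directory. A minimal standalone sketch of that lookup pattern (hypothetical snippet, mirroring what the scripts do inline):

    # Pick the first matching ranking_results.txt among the run directories:
    passage_results="$(ls exps/passage_only/colbertv2/data_dir_data/MSMARCO/split_dev/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt | head -1)"
    echo "Using passage results: $passage_results"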
/scripts/dgx2/exps/bm25_doc_passage_fusion/dev/colbertv2.sh:
--------------------------------------------------------------------------------
dataset="MSMARCO"
export DATA_DIR="data"
export DATASET_PATH="$DATA_DIR/$dataset"
export CLI_ARGS="
--data_dir=$DATASET_PATH
--passage_results="$(ls exps/passage_only/colbertv2/data_dir_data/MSMARCO/split_dev/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)"
--split=dev
"
export OUTPUT_DIR=$(python -m dapr.exps.bm25_doc_passage_fusion.args.colbertv2 $CLI_ARGS)
mkdir -p $OUTPUT_DIR
export LOG_PATH="$OUTPUT_DIR/logging.log"
echo "Logging file path: $LOG_PATH"
setsid nohup python -m dapr.exps.bm25_doc_passage_fusion.colbertv2 $CLI_ARGS > $LOG_PATH &
--------------------------------------------------------------------------------
/scripts/dgx2/exps/bm25_doc_passage_fusion/dev/spladev2.sh:
--------------------------------------------------------------------------------
dataset="MSMARCO"
export DATA_DIR="data"
export DATASET_PATH="$DATA_DIR/$dataset"
export CLI_ARGS="
--data_dir=$DATASET_PATH
--passage_results="$(ls exps/passage_only/spladev2/data_dir_data/MSMARCO/split_dev/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)"
--split=dev
"
export OUTPUT_DIR=$(python -m dapr.exps.bm25_doc_passage_fusion.args.spladev2 $CLI_ARGS)
mkdir -p $OUTPUT_DIR
export LOG_PATH="$OUTPUT_DIR/logging.log"
echo "Logging file path: $LOG_PATH"
setsid nohup python -m dapr.exps.bm25_doc_passage_fusion.spladev2 $CLI_ARGS > $LOG_PATH &
--------------------------------------------------------------------------------
/scripts/dgx2/exps/bm25_doc_passage_fusion/dev/dragon_plus.sh:
--------------------------------------------------------------------------------
dataset="MSMARCO"
export DATA_DIR="data"
export DATASET_PATH="$DATA_DIR/$dataset"
export CLI_ARGS="
--data_dir=$DATASET_PATH
--passage_results="$(ls exps/passage_only/dragon_plus/data_dir_data/MSMARCO/split_dev/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)"
--split=dev
"
export OUTPUT_DIR=$(python -m dapr.exps.bm25_doc_passage_fusion.args.dragon_plus $CLI_ARGS)
mkdir -p $OUTPUT_DIR
export LOG_PATH="$OUTPUT_DIR/logging.log"
echo "Logging file path: $LOG_PATH"
setsid nohup python -m dapr.exps.bm25_doc_passage_fusion.dragon_plus $CLI_ARGS > $LOG_PATH &
--------------------------------------------------------------------------------
/scripts/dgx2/exps/passage_only/jinav2.sh:
--------------------------------------------------------------------------------
export NCCL_DEBUG="INFO"
export CUDA_VISIBLE_DEVICES="0,1,2,3"


datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" )
for dataset in ${datasets[@]}
do
    export DATA_DIR="data"
    export DATASET_PATH="$DATA_DIR/$dataset"
    export CLI_ARGS="
    --data_dir=$DATASET_PATH
    "
    export OUTPUT_DIR=$(python -m dapr.exps.passage_only.args.jinav2 $CLI_ARGS)
    mkdir -p $OUTPUT_DIR
    export LOG_PATH="$OUTPUT_DIR/logging.log"
    echo "Logging file path: $LOG_PATH"
    torchrun --nproc_per_node=4 --master_port=29511 -m dapr.exps.passage_only.jinav2 $CLI_ARGS > $LOG_PATH
done
--------------------------------------------------------------------------------
/scripts/dgx2/exps/passage_only/spladev2.sh:
--------------------------------------------------------------------------------
export NCCL_DEBUG="INFO"
export CUDA_VISIBLE_DEVICES="0,1,2,3"

datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" )
for dataset in ${datasets[@]}
do
    export DATA_DIR="data"
    export DATASET_PATH="$DATA_DIR/$dataset"
    export CLI_ARGS="
    --data_dir=$DATASET_PATH
    "
    export OUTPUT_DIR=$(python -m dapr.exps.passage_only.args.spladev2 $CLI_ARGS)
    mkdir -p $OUTPUT_DIR
    export LOG_PATH="$OUTPUT_DIR/logging.log"
    echo "Logging file path: $LOG_PATH"
    torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.passage_only.spladev2 $CLI_ARGS > $LOG_PATH
done
--------------------------------------------------------------------------------
/dapr/exps/bm25_doc_passage_hierarchy/args/dragon_plus.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
import os
from clddp.utils import parse_cli
from dapr.exps.bm25_doc_passage_hierarchy.args.base import (
    BM25DocPassageHierarchyArguments,
)


@dataclass
class DRAGONPlusBM25DocPassageHierarchyArguments(BM25DocPassageHierarchyArguments):
    def build_output_dir(self) -> str:
        return os.path.join(
            "exps/bm25_doc_passage_hierarchy/dragon_plus", self.run_name
        )


if __name__ == "__main__":
    print(
        parse_cli(DRAGONPlusBM25DocPassageHierarchyArguments).output_dir
    )  # For creating the logging path
--------------------------------------------------------------------------------
/scripts/dgx2/exps/passage_only/colbertv2.sh:
--------------------------------------------------------------------------------
export NCCL_DEBUG="INFO"
export CUDA_VISIBLE_DEVICES="0,1,2,3"

datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" )
for dataset in ${datasets[@]}
do
    export DATA_DIR="data"
    export DATASET_PATH="$DATA_DIR/$dataset"
    export CLI_ARGS="
    --data_dir=$DATASET_PATH
    "
    export OUTPUT_DIR=$(python -m dapr.exps.passage_only.args.colbertv2 $CLI_ARGS)
    mkdir -p $OUTPUT_DIR
    export LOG_PATH="$OUTPUT_DIR/logging.log"
    echo "Logging file path: $LOG_PATH"
    torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.passage_only.colbertv2 $CLI_ARGS > $LOG_PATH
done
--------------------------------------------------------------------------------
/scripts/dgx2/exps/coref/spladev2.sh:
--------------------------------------------------------------------------------
export NCCL_DEBUG="INFO"
export CUDA_VISIBLE_DEVICES="0,1,2,3"

datasets=( "coref/ConditionalQA" "coref/MSMARCO" "coref/NaturalQuestions" "coref/Genomics" "coref/MIRACL" )
for dataset in ${datasets[@]}
do
    export DATA_DIR="data"
    export DATASET_PATH="$DATA_DIR/$dataset"
    export CLI_ARGS="
    --data_dir=$DATASET_PATH
    "
    export OUTPUT_DIR=$(python -m dapr.exps.coref.args.spladev2 $CLI_ARGS)
    mkdir -p $OUTPUT_DIR
    export LOG_PATH="$OUTPUT_DIR/logging.log"
    echo "Logging file path: $LOG_PATH"
    torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.coref.spladev2 $CLI_ARGS > $LOG_PATH
done
--------------------------------------------------------------------------------
/scripts/dgx2/exps/passage_only/dragon_plus.sh:
--------------------------------------------------------------------------------
export NCCL_DEBUG="INFO"
export CUDA_VISIBLE_DEVICES="0,1,2,3"

datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" )
for dataset in ${datasets[@]}
do
    export DATA_DIR="data"
    export DATASET_PATH="$DATA_DIR/$dataset"
    export CLI_ARGS="
    --data_dir=$DATASET_PATH
    "
    export OUTPUT_DIR=$(python -m dapr.exps.passage_only.args.dragon_plus $CLI_ARGS)
    mkdir -p $OUTPUT_DIR
    export LOG_PATH="$OUTPUT_DIR/logging.log"
    echo "Logging file path: $LOG_PATH"
    torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.passage_only.dragon_plus $CLI_ARGS > $LOG_PATH
done
--------------------------------------------------------------------------------
/scripts/dgx2/exps/coref/colbertv2.sh:
--------------------------------------------------------------------------------
export NCCL_DEBUG="INFO"
export CUDA_VISIBLE_DEVICES="0,1,2,3"

datasets=( "coref/ConditionalQA" "coref/MSMARCO" "coref/NaturalQuestions" "coref/Genomics" "coref/MIRACL" )
for dataset in ${datasets[@]}
do
    export DATA_DIR="data"
    export DATASET_PATH="$DATA_DIR/$dataset"
    export CLI_ARGS="
    --data_dir=$DATASET_PATH
    "
    export OUTPUT_DIR=$(python -m dapr.exps.coref.args.colbertv2 $CLI_ARGS)
    mkdir -p $OUTPUT_DIR
    export LOG_PATH="$OUTPUT_DIR/logging.log"
    echo "Logging file path: $LOG_PATH"
    torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.coref.colbertv2 $CLI_ARGS > $LOG_PATH
done
--------------------------------------------------------------------------------
/scripts/dgx2/exps/prepending_titles/spladev2.sh:
--------------------------------------------------------------------------------
export NCCL_DEBUG="INFO"
export CUDA_VISIBLE_DEVICES="0,1,2,3"

datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" )
for dataset in ${datasets[@]}
do
    export DATA_DIR="data"
    export DATASET_PATH="$DATA_DIR/$dataset"
    export CLI_ARGS="
    --data_dir=$DATASET_PATH
    "
    export OUTPUT_DIR=$(python -m dapr.exps.prepending_titles.args.spladev2 $CLI_ARGS)
    mkdir -p $OUTPUT_DIR
    export LOG_PATH="$OUTPUT_DIR/logging.log"
    echo "Logging file path: $LOG_PATH"
    torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.prepending_titles.spladev2 $CLI_ARGS > $LOG_PATH
done
--------------------------------------------------------------------------------
/dapr/exps/keyphrases/args/base.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import Optional, Set
from clddp.args.base import AutoRunNameArgumentsMixIn, DumpableArgumentsMixIn
from clddp.dm import Split


@dataclass
class KeyphrasesArguments(AutoRunNameArgumentsMixIn, DumpableArgumentsMixIn):
    data_dir: Optional[str] = None
    keyphrases_path: Optional[str] = None
    split: Split = Split.test
    topk: int = 1000
    per_device_eval_batch_size: int = 32
    fp16: bool = True

    def __post_init__(self) -> None:
        super().__post_init__()

    @property
    def escaped_args(self) -> Set[str]:
        return {"keyphrases_path"}
--------------------------------------------------------------------------------
/scripts/dgx2/exps/coref/dragon_plus.sh:
--------------------------------------------------------------------------------
export NCCL_DEBUG="INFO"
export CUDA_VISIBLE_DEVICES="0,1,2,3"

datasets=( "coref/ConditionalQA" "coref/MSMARCO" "coref/NaturalQuestions" "coref/Genomics" "coref/MIRACL" )
for dataset in ${datasets[@]}
do
    export DATA_DIR="data"
    export DATASET_PATH="$DATA_DIR/$dataset"
    export CLI_ARGS="
    --data_dir=$DATASET_PATH
    "
    export OUTPUT_DIR=$(python -m dapr.exps.coref.args.dragon_plus $CLI_ARGS)
    mkdir -p $OUTPUT_DIR
    export LOG_PATH="$OUTPUT_DIR/logging.log"
    echo "Logging file path: $LOG_PATH"
    torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.coref.dragon_plus $CLI_ARGS > $LOG_PATH
done
--------------------------------------------------------------------------------
/scripts/dgx2/exps/passage_only/dev/dragon_plus.sh:
--------------------------------------------------------------------------------
export NCCL_DEBUG="INFO"
export CUDA_VISIBLE_DEVICES="0,1,2,3"

datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" )
for dataset in ${datasets[@]}
do
    export DATA_DIR="data"
    export DATASET_PATH="$DATA_DIR/$dataset"
    export CLI_ARGS="
    --data_dir=$DATASET_PATH
    "
    export OUTPUT_DIR=$(python -m dapr.exps.passage_only.args.dragon_plus $CLI_ARGS)
    mkdir -p $OUTPUT_DIR
    export LOG_PATH="$OUTPUT_DIR/logging.log"
    echo "Logging file path: $LOG_PATH"
    torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.passage_only.dragon_plus $CLI_ARGS > $LOG_PATH
done
--------------------------------------------------------------------------------
/scripts/dgx2/exps/prepending_titles/dragon_plus.sh:
--------------------------------------------------------------------------------
export NCCL_DEBUG="INFO"
export CUDA_VISIBLE_DEVICES="0,1,2,3"

datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" )
for dataset in ${datasets[@]}
do
    export DATA_DIR="data"
    export DATASET_PATH="$DATA_DIR/$dataset"
    export CLI_ARGS="
    --data_dir=$DATASET_PATH
    "
    export OUTPUT_DIR=$(python -m dapr.exps.prepending_titles.args.dragon_plus $CLI_ARGS)
    mkdir -p $OUTPUT_DIR
    export LOG_PATH="$OUTPUT_DIR/logging.log"
    echo "Logging file path: $LOG_PATH"
    torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.prepending_titles.dragon_plus $CLI_ARGS > $LOG_PATH
done
--------------------------------------------------------------------------------
/scripts/dgx2/exps/jinav2_doc_passage_fusion/dev/spladev2.sh:
--------------------------------------------------------------------------------
export CUDA_VISIBLE_DEVICES="0,1,2,3"
dataset="MSMARCO"
export DATA_DIR="data"
export DATASET_PATH="$DATA_DIR/$dataset"
export CLI_ARGS="
--data_dir=$DATASET_PATH
--passage_results="$(ls exps/passage_only/spladev2/data_dir_data/MSMARCO/split_dev/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)"
--split=dev
"
export OUTPUT_DIR=$(python -m dapr.exps.jinav2_doc_passage_fusion.args.spladev2 $CLI_ARGS)
mkdir -p $OUTPUT_DIR
export LOG_PATH="$OUTPUT_DIR/logging.log"
echo "Logging file path: $LOG_PATH"
torchrun --nproc_per_node=4 --master_port=29511 -m dapr.exps.jinav2_doc_passage_fusion.spladev2 $CLI_ARGS
--------------------------------------------------------------------------------
/scripts/dgx2/exps/prepending_titles/colbertv2.sh:
--------------------------------------------------------------------------------
export NCCL_DEBUG="INFO"
export CUDA_VISIBLE_DEVICES="0,1,2,3"

datasets=( "ConditionalQA" "MSMARCO"
"NaturalQuestions" "Genomics" "MIRACL" ) 5 | for dataset in ${datasets[@]} 6 | do 7 | export DATA_DIR="data" 8 | export DATASET_PATH="$DATA_DIR/$dataset" 9 | export CLI_ARGS=" 10 | --data_dir=$DATASET_PATH 11 | " 12 | export OUTPUT_DIR=$(python -m dapr.exps.prepending_titles.args.colbertv2 $CLI_ARGS) 13 | mkdir -p $OUTPUT_DIR 14 | export LOG_PATH="$OUTPUT_DIR/logging.log" 15 | echo "Logging file path: $LOG_PATH" 16 | torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.prepending_titles.colbertv2 $CLI_ARGS > $LOG_PATH 17 | done 18 | 19 | -------------------------------------------------------------------------------- /scripts/dgx2/exps/jinav2_doc_passage_fusion/dev/colbertv2.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES="0,1,2,3" 2 | dataset="MSMARCO" 3 | export DATA_DIR="data" 4 | export DATASET_PATH="$DATA_DIR/$dataset" 5 | export CLI_ARGS=" 6 | --data_dir=$DATASET_PATH 7 | --passage_results="$(ls exps/passage_only/colbertv2/data_dir_data/MSMARCO/split_dev/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 8 | --split=dev 9 | " 10 | export OUTPUT_DIR=$(python -m dapr.exps.jinav2_doc_passage_fusion.args.colbertv2 $CLI_ARGS) 11 | mkdir -p $OUTPUT_DIR 12 | export LOG_PATH="$OUTPUT_DIR/logging.log" 13 | echo "Logging file path: $LOG_PATH" 14 | torchrun --nproc_per_node=4 --master_port=29511 -m dapr.exps.jinav2_doc_passage_fusion.colbertv2 $CLI_ARGS -------------------------------------------------------------------------------- /dapr/exps/passage_only/args/jinav2.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import os 3 | from typing import List 4 | from clddp.utils import parse_cli 5 | from dapr.exps.passage_only.args.base import PassageOnlyArguments 6 | 7 | 8 | @dataclass 9 | class JinaV2PassageOnlyArguments(PassageOnlyArguments): 10 | max_length: int = 512 11 | 12 | def get_arguments_from(self) -> List[type]: 13 | return [PassageOnlyArguments, JinaV2PassageOnlyArguments] 14 | 15 | def build_output_dir(self) -> str: 16 | return os.path.join("exps/passage_only/jinav2", self.run_name) 17 | 18 | 19 | if __name__ == "__main__": 20 | print( 21 | parse_cli(JinaV2PassageOnlyArguments).output_dir 22 | ) # For creating the logging path 23 | -------------------------------------------------------------------------------- /scripts/dgx2/exps/doc_retrieval_with_titles/spladev2.sh: -------------------------------------------------------------------------------- 1 | export NCCL_DEBUG="INFO" 2 | export CUDA_VISIBLE_DEVICES="0,1,2,3" 3 | 4 | datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" ) 5 | for dataset in ${datasets[@]} 6 | do 7 | export DATA_DIR="data" 8 | export DATASET_PATH="$DATA_DIR/$dataset" 9 | export CLI_ARGS=" 10 | --data_dir=$DATASET_PATH 11 | " 12 | export OUTPUT_DIR=$(python -m dapr.exps.doc_retrieval_with_titles.args.spladev2 $CLI_ARGS) 13 | mkdir -p $OUTPUT_DIR 14 | export LOG_PATH="$OUTPUT_DIR/logging.log" 15 | echo "Logging file path: $LOG_PATH" 16 | torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.doc_retrieval_with_titles.spladev2 $CLI_ARGS > $LOG_PATH 17 | done -------------------------------------------------------------------------------- /scripts/dgx2/exps/jinav2_doc_passage_fusion/dev/dragon_plus.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES="0,1,2,3" 2 | dataset="MSMARCO" 3 | export 
DATA_DIR="data" 4 | export DATASET_PATH="$DATA_DIR/$dataset" 5 | export CLI_ARGS=" 6 | --data_dir=$DATASET_PATH 7 | --passage_results="$(ls exps/passage_only/dragon_plus/data_dir_data/MSMARCO/split_dev/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 8 | --split=dev 9 | " 10 | export OUTPUT_DIR=$(python -m dapr.exps.jinav2_doc_passage_fusion.args.dragon_plus $CLI_ARGS) 11 | mkdir -p $OUTPUT_DIR 12 | export LOG_PATH="$OUTPUT_DIR/logging.log" 13 | echo "Logging file path: $LOG_PATH" 14 | torchrun --nproc_per_node=4 --master_port=29511 -m dapr.exps.jinav2_doc_passage_fusion.dragon_plus $CLI_ARGS -------------------------------------------------------------------------------- /dapr/retrievers/late_interaction.py: -------------------------------------------------------------------------------- 1 | from clddp.retriever import Retriever, RetrieverConfig, Pooling, SimilarityFunction 2 | from clddp.dm import Separator 3 | 4 | 5 | class ColBERTv2(Retriever): 6 | def __init__( 7 | self, query_max_length: int = 150, passage_max_length: int = 512 8 | ) -> None: 9 | config = RetrieverConfig( 10 | query_model_name_or_path="colbert-ir/colbertv2.0", 11 | shared_encoder=True, 12 | sep=Separator.blank, 13 | pooling=Pooling.no_pooling, 14 | similarity_function=SimilarityFunction.maxsim, 15 | query_max_length=query_max_length, 16 | passage_max_length=passage_max_length, 17 | ) 18 | super().__init__(config) 19 | -------------------------------------------------------------------------------- /scripts/dgx2/exps/bm25_doc_passage_hierarchy/spladev2.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES="0,1,2,3" 2 | 3 | datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" ) 4 | for dataset in ${datasets[@]} 5 | do 6 | export DATA_DIR="data" 7 | export DATASET_PATH="$DATA_DIR/$dataset" 8 | export CLI_ARGS=" 9 | --data_dir=$DATASET_PATH 10 | --report_passage_weight=0.8 11 | " 12 | export OUTPUT_DIR=$(python -m dapr.exps.bm25_doc_passage_hierarchy.args.spladev2 $CLI_ARGS) 13 | mkdir -p $OUTPUT_DIR 14 | export LOG_PATH="$OUTPUT_DIR/logging.log" 15 | echo "Logging file path: $LOG_PATH" 16 | torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.bm25_doc_passage_hierarchy.spladev2 $CLI_ARGS > $LOG_PATH 17 | done -------------------------------------------------------------------------------- /scripts/dgx2/exps/doc_retrieval_with_titles/dragon_plus.sh: -------------------------------------------------------------------------------- 1 | export NCCL_DEBUG="INFO" 2 | export CUDA_VISIBLE_DEVICES="0,1,2,3" 3 | 4 | datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" ) 5 | for dataset in ${datasets[@]} 6 | do 7 | export DATA_DIR="data" 8 | export DATASET_PATH="$DATA_DIR/$dataset" 9 | export CLI_ARGS=" 10 | --data_dir=$DATASET_PATH 11 | " 12 | export OUTPUT_DIR=$(python -m dapr.exps.doc_retrieval_with_titles.args.dragon_plus $CLI_ARGS) 13 | mkdir -p $OUTPUT_DIR 14 | export LOG_PATH="$OUTPUT_DIR/logging.log" 15 | echo "Logging file path: $LOG_PATH" 16 | torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.doc_retrieval_with_titles.dragon_plus $CLI_ARGS > $LOG_PATH 17 | done -------------------------------------------------------------------------------- /scripts/dgx2/exps/bm25_doc_passage_hierarchy/colbertv2.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES="0,1,2,3" 2 | 3 | datasets=( "ConditionalQA" "MSMARCO" 
"NaturalQuestions" "Genomics" "MIRACL" ) 4 | for dataset in ${datasets[@]} 5 | do 6 | export DATA_DIR="data" 7 | export DATASET_PATH="$DATA_DIR/$dataset" 8 | export CLI_ARGS=" 9 | --data_dir=$DATASET_PATH 10 | --report_passage_weight=0.9 11 | " 12 | export OUTPUT_DIR=$(python -m dapr.exps.bm25_doc_passage_hierarchy.args.colbertv2 $CLI_ARGS) 13 | mkdir -p $OUTPUT_DIR 14 | export LOG_PATH="$OUTPUT_DIR/logging.log" 15 | echo "Logging file path: $LOG_PATH" 16 | torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.bm25_doc_passage_hierarchy.colbertv2 $CLI_ARGS > $LOG_PATH 17 | done -------------------------------------------------------------------------------- /scripts/dgx2/exps/doc_retrieval_with_titles/colbertv2.sh: -------------------------------------------------------------------------------- 1 | export NCCL_DEBUG="INFO" 2 | export CUDA_VISIBLE_DEVICES="0,1,2,3" 3 | 4 | datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" ) 5 | for dataset in ${datasets[@]} 6 | do 7 | export DATA_DIR="data" 8 | export DATASET_PATH="$DATA_DIR/$dataset" 9 | export CLI_ARGS=" 10 | --data_dir=$DATASET_PATH 11 | " 12 | export OUTPUT_DIR=$(python -m dapr.exps.doc_retrieval_with_titles.args.colbertv2 $CLI_ARGS) 13 | mkdir -p $OUTPUT_DIR 14 | export LOG_PATH="$OUTPUT_DIR/logging.log" 15 | echo "Logging file path: $LOG_PATH" 16 | torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.doc_retrieval_with_titles.colbertv2 $CLI_ARGS > $LOG_PATH 17 | done 18 | 19 | -------------------------------------------------------------------------------- /scripts/dgx2/exps/bm25_doc_passage_hierarchy/dragon_plus.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES="0,1,2,3" 2 | 3 | datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" ) 4 | for dataset in ${datasets[@]} 5 | do 6 | export DATA_DIR="data" 7 | export DATASET_PATH="$DATA_DIR/$dataset" 8 | export CLI_ARGS=" 9 | --data_dir=$DATASET_PATH 10 | --report_passage_weight=0.8 11 | " 12 | export OUTPUT_DIR=$(python -m dapr.exps.bm25_doc_passage_hierarchy.args.dragon_plus $CLI_ARGS) 13 | mkdir -p $OUTPUT_DIR 14 | export LOG_PATH="$OUTPUT_DIR/logging.log" 15 | echo "Logging file path: $LOG_PATH" 16 | torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.bm25_doc_passage_hierarchy.dragon_plus $CLI_ARGS > $LOG_PATH 17 | done -------------------------------------------------------------------------------- /dapr/exps/doc_retrieval_with_titles/spladev2.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from clddp.utils import initialize_ddp, is_device_zero, parse_cli, set_logger_format 3 | from dapr.retrievers.sparse import SPLADEv2 4 | from dapr.exps.doc_retrieval_with_titles.shared_pipeline import ( 5 | run_doc_retrieval_with_titles, 6 | ) 7 | from dapr.exps.doc_retrieval_with_titles.args.spladev2 import ( 8 | SPLADEv2lusDocRetrievalWithTitlesArguments, 9 | ) 10 | 11 | if __name__ == "__main__": 12 | set_logger_format(logging.INFO if is_device_zero() else logging.WARNING) 13 | initialize_ddp() 14 | args = parse_cli(SPLADEv2lusDocRetrievalWithTitlesArguments) 15 | retriever = SPLADEv2() 16 | run_doc_retrieval_with_titles( 17 | args=args, retriever=retriever, retriever_name="spladev2" 18 | ) 19 | -------------------------------------------------------------------------------- /dapr/exps/doc_retrieval_with_titles/dragon_plus.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | from clddp.utils import initialize_ddp, is_device_zero, parse_cli, set_logger_format 3 | from dapr.retrievers.dense import DRAGONPlus 4 | from dapr.exps.doc_retrieval_with_titles.shared_pipeline import ( 5 | run_doc_retrieval_with_titles, 6 | ) 7 | from dapr.exps.doc_retrieval_with_titles.args.dragon_plus import ( 8 | DRAGONPlusDocRetrievalWithTitlesArguments, 9 | ) 10 | 11 | if __name__ == "__main__": 12 | set_logger_format(logging.INFO if is_device_zero() else logging.WARNING) 13 | initialize_ddp() 14 | args = parse_cli(DRAGONPlusDocRetrievalWithTitlesArguments) 15 | retriever = DRAGONPlus() 16 | run_doc_retrieval_with_titles( 17 | args=args, retriever=retriever, retriever_name="dragon_plus" 18 | ) 19 | -------------------------------------------------------------------------------- /scripts/dgx2/exps/keyphrases/colbertv2.sh: -------------------------------------------------------------------------------- 1 | export NCCL_DEBUG="INFO" 2 | export CUDA_VISIBLE_DEVICES="0,1,2,3" 3 | 4 | datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" ) 5 | for dataset in ${datasets[@]} 6 | do 7 | export DATA_DIR="data" 8 | export DATASET_PATH="$DATA_DIR/$dataset" 9 | export KEYPHRASES_PATH="$DATA_DIR/keyphrases/$dataset/did2dsum.jsonl" 10 | export CLI_ARGS=" 11 | --data_dir=$DATASET_PATH 12 | --keyphrases_path=$KEYPHRASES_PATH 13 | " 14 | export OUTPUT_DIR=$(python -m dapr.exps.keyphrases.args.colbertv2 $CLI_ARGS) 15 | mkdir -p $OUTPUT_DIR 16 | export LOG_PATH="$OUTPUT_DIR/logging.log" 17 | echo "Logging file path: $LOG_PATH" 18 | torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.keyphrases.colbertv2 $CLI_ARGS > $LOG_PATH 19 | done -------------------------------------------------------------------------------- /scripts/dgx2/exps/keyphrases/spladev2.sh: -------------------------------------------------------------------------------- 1 | export NCCL_DEBUG="INFO" 2 | export CUDA_VISIBLE_DEVICES="0,1,2,3" 3 | 4 | datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" ) 5 | for dataset in ${datasets[@]} 6 | do 7 | export DATA_DIR="data" 8 | export DATASET_PATH="$DATA_DIR/$dataset" 9 | export KEYPHRASES_PATH="$DATA_DIR/keyphrases/$dataset/did2dsum.jsonl" 10 | export CLI_ARGS=" 11 | --data_dir=$DATASET_PATH 12 | --keyphrases_path=$KEYPHRASES_PATH 13 | " 14 | export OUTPUT_DIR=$(python -m dapr.exps.keyphrases.args.spladev2 $CLI_ARGS) 15 | mkdir -p $OUTPUT_DIR 16 | export LOG_PATH="$OUTPUT_DIR/logging.log" 17 | echo "Logging file path: $LOG_PATH" 18 | torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.keyphrases.spladev2 $CLI_ARGS > $LOG_PATH 19 | done -------------------------------------------------------------------------------- /scripts/dgx2/exps/keyphrases/dragon_plus.sh: -------------------------------------------------------------------------------- 1 | export NCCL_DEBUG="INFO" 2 | export CUDA_VISIBLE_DEVICES="0,1,2,3" 3 | 4 | datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" ) 5 | for dataset in ${datasets[@]} 6 | do 7 | export DATA_DIR="data" 8 | export DATASET_PATH="$DATA_DIR/$dataset" 9 | export KEYPHRASES_PATH="$DATA_DIR/keyphrases/$dataset/did2dsum.jsonl" 10 | export CLI_ARGS=" 11 | --data_dir=$DATASET_PATH 12 | --keyphrases_path=$KEYPHRASES_PATH 13 | " 14 | export OUTPUT_DIR=$(python -m dapr.exps.keyphrases.args.dragon_plus $CLI_ARGS) 15 | mkdir -p $OUTPUT_DIR 16 | export LOG_PATH="$OUTPUT_DIR/logging.log" 
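# The dapr.exps.keyphrases.args.dragon_plus invocation above only prints the auto-generated
# run directory for these CLI arguments, so OUTPUT_DIR names a unique per-run folder that
# later receives logging.log, metrics.json, and ranking_results.txt.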
17 | echo "Logging file path: $LOG_PATH" 18 | torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.keyphrases.dragon_plus $CLI_ARGS > $LOG_PATH 19 | done -------------------------------------------------------------------------------- /dapr/exps/bm25_doc_passage_hierarchy/args/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Set 3 | from clddp.args.base import AutoRunNameArgumentsMixIn, DumpableArgumentsMixIn 4 | from clddp.dm import Split 5 | from clddp.evaluation import RetrievalMetric 6 | 7 | 8 | @dataclass 9 | class BM25DocPassageHierarchyArguments( 10 | AutoRunNameArgumentsMixIn, DumpableArgumentsMixIn 11 | ): 12 | data_dir: Optional[str] = None 13 | split: Split = Split.test 14 | topk: int = 1000 15 | per_device_eval_batch_size: int = 32 16 | fp16: bool = True 17 | report_metric: str = RetrievalMetric.ndcg_string.at(10) 18 | report_passage_weight: Optional[float] = None 19 | 20 | def __post_init__(self) -> None: 21 | super().__post_init__() 22 | 23 | @property 24 | def escaped_args(self) -> Set[str]: 25 | return {"passage_results"} 26 | -------------------------------------------------------------------------------- /dapr/exps/bm25_doc_passage_fusion/args/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Set 3 | from clddp.args.base import AutoRunNameArgumentsMixIn, DumpableArgumentsMixIn 4 | from clddp.dm import Split 5 | from clddp.evaluation import RetrievalMetric 6 | 7 | 8 | @dataclass 9 | class BM25DocPassageFusionArguments(AutoRunNameArgumentsMixIn, DumpableArgumentsMixIn): 10 | data_dir: Optional[str] = None 11 | passage_results: Optional[str] = None 12 | split: Split = Split.test 13 | topk: int = 1000 14 | per_device_eval_batch_size: int = 32 15 | fp16: bool = True 16 | report_metric: str = RetrievalMetric.ndcg_string.at(10) 17 | report_passage_weight: Optional[float] = None 18 | 19 | def __post_init__(self) -> None: 20 | super().__post_init__() 21 | 22 | @property 23 | def escaped_args(self) -> Set[str]: 24 | return {"passage_results"} 25 | -------------------------------------------------------------------------------- /dapr/exps/doc_retrieval_with_titles/colbertv2.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from clddp.utils import initialize_ddp, is_device_zero, parse_cli, set_logger_format 3 | from dapr.retrievers.late_interaction import ColBERTv2 4 | from dapr.exps.doc_retrieval_with_titles.shared_pipeline import ( 5 | run_doc_retrieval_with_titles, 6 | ) 7 | from dapr.exps.doc_retrieval_with_titles.args.colbertv2 import ( 8 | ColBERTv2DocRetrievalWithTitlesArguments, 9 | ) 10 | 11 | if __name__ == "__main__": 12 | set_logger_format(logging.INFO if is_device_zero() else logging.WARNING) 13 | initialize_ddp() 14 | args = parse_cli(ColBERTv2DocRetrievalWithTitlesArguments) 15 | retriever = ColBERTv2( 16 | query_max_length=args.query_max_length, 17 | passage_max_length=args.passage_max_length, 18 | ) # titles are extra short 19 | run_doc_retrieval_with_titles( 20 | args=args, retriever=retriever, retriever_name="colbertv2" 21 | ) 22 | -------------------------------------------------------------------------------- /dapr/exps/doc_retrieval_with_titles/args/colbertv2.py: -------------------------------------------------------------------------------- 1 | from dataclasses import 
dataclass 2 | import os 3 | from typing import List 4 | from clddp.utils import parse_cli 5 | from dapr.exps.doc_retrieval_with_titles.args.base import ( 6 | DocRetrievalWithTitlesArguments, 7 | ) 8 | 9 | 10 | @dataclass 11 | class ColBERTv2DocRetrievalWithTitlesArguments(DocRetrievalWithTitlesArguments): 12 | query_max_length: int = 150 13 | passage_max_length: int = 512 14 | 15 | def build_output_dir(self) -> str: 16 | return os.path.join("exps/doc_retrieval_with_titles/colbertv2", self.run_name) 17 | 18 | def get_arguments_from(self) -> List[type]: 19 | return [ 20 | DocRetrievalWithTitlesArguments, 21 | ColBERTv2DocRetrievalWithTitlesArguments, 22 | ] 23 | 24 | 25 | if __name__ == "__main__": 26 | print( 27 | parse_cli(ColBERTv2DocRetrievalWithTitlesArguments).output_dir 28 | ) # For creating the logging path 29 | -------------------------------------------------------------------------------- /dapr/exps/jinav2_doc_passage_fusion/args/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Set 3 | from clddp.args.base import AutoRunNameArgumentsMixIn, DumpableArgumentsMixIn 4 | from clddp.dm import Split 5 | from clddp.evaluation import RetrievalMetric 6 | 7 | 8 | @dataclass 9 | class JinaV2DocPassageFusionArguments( 10 | AutoRunNameArgumentsMixIn, DumpableArgumentsMixIn 11 | ): 12 | data_dir: Optional[str] = None 13 | passage_results: Optional[str] = None 14 | split: Split = Split.test 15 | topk: int = 1000 16 | per_device_doc_batch_size: int = 2 17 | per_device_passage_batch_size: int = 32 18 | fp16: bool = True 19 | report_metric: str = RetrievalMetric.ndcg_string.at(10) 20 | report_passage_weight: Optional[float] = None 21 | 22 | def __post_init__(self) -> None: 23 | super().__post_init__() 24 | 25 | @property 26 | def escaped_args(self) -> Set[str]: 27 | return {"passage_results"} 28 | -------------------------------------------------------------------------------- /dapr/exps/bm25_doc_retrieval/args.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import os 3 | from typing import Optional, Set 4 | from clddp.args.base import AutoRunNameArgumentsMixIn, DumpableArgumentsMixIn 5 | from clddp.dm import Split 6 | from clddp.evaluation import RetrievalMetric 7 | from clddp.utils import parse_cli 8 | 9 | 10 | @dataclass 11 | class BM25DocRetrievalArguments(AutoRunNameArgumentsMixIn, DumpableArgumentsMixIn): 12 | data_dir: Optional[str] = None 13 | split: Split = Split.test 14 | topk: int = 1000 15 | per_device_eval_batch_size: int = 32 16 | fp16: bool = True 17 | report_metric: str = RetrievalMetric.ndcg_string.at(10) 18 | report_passage_weight: Optional[float] = None 19 | 20 | def __post_init__(self) -> None: 21 | super().__post_init__() 22 | 23 | def build_output_dir(self) -> str: 24 | return os.path.join("exps/bm25_doc_retrieval", self.run_name) 25 | 26 | 27 | if __name__ == "__main__": 28 | print( 29 | parse_cli(BM25DocRetrievalArguments).output_dir 30 | ) # For creating the logging path 31 | -------------------------------------------------------------------------------- /dapr/exps/bm25_doc_passage_hierarchy/spladev2.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import logging 3 | from dapr.exps.bm25_doc_passage_hierarchy.args.spladev2 import ( 4 | SPLADEv2BM25DocPassageHierarchyArguments, 5 | ) 6 | from 
dapr.exps.bm25_doc_passage_hierarchy.shared_pipeline import ( 7 | run_bm25_doc_passage_hierarchy, 8 | ) 9 | from dapr.retrievers.sparse import SPLADEv2 10 | from clddp.utils import is_device_zero, parse_cli, initialize_ddp, set_logger_format 11 | from clddp.search import search 12 | 13 | if __name__ == "__main__": 14 | set_logger_format(logging.INFO if is_device_zero() else logging.WARNING) 15 | initialize_ddp() 16 | args = parse_cli(SPLADEv2BM25DocPassageHierarchyArguments) 17 | if is_device_zero(): 18 | args.dump_arguments() 19 | retriever = SPLADEv2() 20 | run_bm25_doc_passage_hierarchy( 21 | args=args, 22 | passage_retriever_name="spladev2", 23 | scoped_search_function=partial( 24 | search, 25 | retriever=retriever, 26 | per_device_eval_batch_size=args.per_device_eval_batch_size, 27 | fp16=args.fp16, 28 | ), 29 | ) 30 | -------------------------------------------------------------------------------- /dapr/exps/bm25_doc_passage_hierarchy/dragon_plus.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import logging 3 | from dapr.exps.bm25_doc_passage_hierarchy.args.dragon_plus import ( 4 | DRAGONPlusBM25DocPassageHierarchyArguments, 5 | ) 6 | from dapr.exps.bm25_doc_passage_hierarchy.shared_pipeline import ( 7 | run_bm25_doc_passage_hierarchy, 8 | ) 9 | from dapr.retrievers.dense import DRAGONPlus 10 | from clddp.utils import is_device_zero, parse_cli, initialize_ddp, set_logger_format 11 | from clddp.search import search 12 | 13 | if __name__ == "__main__": 14 | set_logger_format(logging.INFO if is_device_zero() else logging.WARNING) 15 | initialize_ddp() 16 | args = parse_cli(DRAGONPlusBM25DocPassageHierarchyArguments) 17 | if is_device_zero(): 18 | args.dump_arguments() 19 | retriever = DRAGONPlus() 20 | run_bm25_doc_passage_hierarchy( 21 | args=args, 22 | passage_retriever_name="dragon_plus", 23 | scoped_search_function=partial( 24 | search, 25 | retriever=retriever, 26 | per_device_eval_batch_size=args.per_device_eval_batch_size, 27 | fp16=args.fp16, 28 | ), 29 | ) 30 | -------------------------------------------------------------------------------- /dapr/exps/bm25_doc_passage_hierarchy/colbertv2.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import logging 3 | from dapr.exps.bm25_doc_passage_hierarchy.args.colbertv2 import ( 4 | CoLBERTv2BM25DocPassageHierarchyArguments, 5 | ) 6 | from dapr.exps.bm25_doc_passage_hierarchy.shared_pipeline import ( 7 | run_bm25_doc_passage_hierarchy, 8 | ) 9 | from clddp.utils import initialize_ddp, is_device_zero, parse_cli, set_logger_format 10 | from clddp.search import search 11 | from dapr.retrievers.late_interaction import ColBERTv2 12 | 13 | if __name__ == "__main__": 14 | set_logger_format(logging.INFO if is_device_zero() else logging.WARNING) 15 | initialize_ddp() 16 | args = parse_cli(CoLBERTv2BM25DocPassageHierarchyArguments) 17 | if is_device_zero(): 18 | args.dump_arguments() 19 | retriever = ColBERTv2() 20 | run_bm25_doc_passage_hierarchy( 21 | args=args, 22 | passage_retriever_name="colbertv2", 23 | scoped_search_function=partial( 24 | search, 25 | retriever=retriever, 26 | per_device_eval_batch_size=args.per_device_eval_batch_size, 27 | fp16=args.fp16, 28 | ), 29 | ) 30 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 
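# Runtime dependencies are pinned below to exact versions (plus two git dependencies),
# presumably to keep the benchmark experiments reproducible.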
2 | 3 | 4 | with open("README.md", "r", encoding="utf-8") as fh: 5 | readme = fh.read() 6 | 7 | setup( 8 | name="dapr", 9 | version="0.0.0", 10 | author="Kexin Wang", 11 | author_email="kexin.wang.2049@gmail.com", 12 | description="A benchmark on Document-Aware Passage Retrieval (DAPR).", 13 | long_description=readme, 14 | long_description_content_type="text/markdown", 15 | url="https://github.com/kwang2049/dapr", 16 | project_urls={ 17 | "Bug Tracker": "https://github.com/kwang2049/dapr/issues", 18 | }, 19 | packages=find_packages(), 20 | classifiers=[ 21 | "Programming Language :: Python :: 3", 22 | "License :: OSI Approved :: Apache Software License", 23 | "Operating System :: OS Independent", 24 | ], 25 | python_requires=">=3.8", 26 | install_requires=[ 27 | "colbert @ git+https://github.com/stanford-futuredata/ColBERT.git@21b460a606bed606e8a7fa105ada36b18e8084ec", 28 | "ujson==5.7.0", 29 | "datasets==2.16.1", 30 | "more-itertools==9.1.0", 31 | "matplotlib==3.7.4", 32 | "pytrec-eval==0.5", 33 | "transformers==4.37.2", 34 | "pke @ git+https://github.com/boudinfl/pke.git", 35 | "ir-datasets==0.5.4", 36 | "pyserini==0.21.0", 37 | "clddp==0.0.8", 38 | ], 39 | ) 40 | -------------------------------------------------------------------------------- /scripts/dgx2/exps/bm25_doc_passage_fusion/bm25.sh: -------------------------------------------------------------------------------- 1 | datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" ) 2 | passage_results_paths=( 3 | "$(ls exps/passage_only/bm25/data_dir_data/ConditionalQA/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 4 | "$(ls exps/passage_only/bm25/data_dir_data/MSMARCO/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 5 | "$(ls exps/passage_only/bm25/data_dir_data/NaturalQuestions/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 6 | "$(ls exps/passage_only/bm25/data_dir_data/Genomics/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 7 | "$(ls exps/passage_only/bm25/data_dir_data/MIRACL/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 8 | ) 9 | for i in {0..4} 10 | do 11 | dataset=${datasets[$i]} 12 | passage_results_path=${passage_results_paths[$i]} 13 | export DATA_DIR="data" 14 | export DATASET_PATH="$DATA_DIR/$dataset" 15 | export CLI_ARGS=" 16 | --data_dir=$DATASET_PATH 17 | --passage_results=$passage_results_path 18 | --report_passage_weight=0.7 19 | " 20 | export OUTPUT_DIR=$(python -m dapr.exps.bm25_doc_passage_fusion.args.bm25 $CLI_ARGS) 21 | mkdir -p $OUTPUT_DIR 22 | export LOG_PATH="$OUTPUT_DIR/logging.log" 23 | echo "Logging file path: $LOG_PATH" 24 | python -m dapr.exps.bm25_doc_passage_fusion.bm25 $CLI_ARGS > $LOG_PATH 25 | done -------------------------------------------------------------------------------- /scripts/dgx2/exps/bm25_doc_passage_fusion/colbertv2.sh: -------------------------------------------------------------------------------- 1 | datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" ) 2 | passage_results_paths=( 3 | "$(ls exps/passage_only/colbertv2/data_dir_data/ConditionalQA/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 4 | "$(ls exps/passage_only/colbertv2/data_dir_data/MSMARCO/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 5 | "$(ls
exps/passage_only/colbertv2/data_dir_data/NaturalQuestions/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 6 | "$(ls exps/passage_only/colbertv2/data_dir_data/Genomics/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 7 | "$(ls exps/passage_only/colbertv2/data_dir_data/MIRACL/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 8 | ) 9 | for i in {0..4} 10 | do 11 | dataset=${datasets[$i]} 12 | passage_results_path=${passage_results_paths[$i]} 13 | export DATA_DIR="data" 14 | export DATASET_PATH="$DATA_DIR/$dataset" 15 | export CLI_ARGS=" 16 | --data_dir=$DATASET_PATH 17 | --passage_results=$passage_results_path 18 | --report_passage_weight=0.7 19 | " 20 | export OUTPUT_DIR=$(python -m dapr.exps.bm25_doc_passage_fusion.args.colbertv2 $CLI_ARGS) 21 | mkdir -p $OUTPUT_DIR 22 | export LOG_PATH="$OUTPUT_DIR/logging.log" 23 | echo "Logging file path: $LOG_PATH" 24 | python -m dapr.exps.bm25_doc_passage_fusion.colbertv2 $CLI_ARGS > $LOG_PATH 25 | done -------------------------------------------------------------------------------- /scripts/dgx2/exps/bm25_doc_passage_fusion/dragon_plus.sh: -------------------------------------------------------------------------------- 1 | datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" ) 2 | passage_results_paths=( 3 | "$(ls exps/passage_only/dragon_plus/data_dir_data/ConditionalQA/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 4 | "$(ls exps/passage_only/dragon_plus/data_dir_data/MSMARCO/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 5 | "$(ls exps/passage_only/dragon_plus/data_dir_data/NaturalQuestions/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 6 | "$(ls exps/passage_only/dragon_plus/data_dir_data/Genomics/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 7 | "$(ls exps/passage_only/dragon_plus/data_dir_data/MIRACL/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 8 | ) 9 | for i in {0..4} 10 | do 11 | dataset=${datasets[$i]} 12 | passage_results_path=${passage_results_paths[$i]} 13 | export DATA_DIR="data" 14 | export DATASET_PATH="$DATA_DIR/$dataset" 15 | export CLI_ARGS=" 16 | --data_dir=$DATASET_PATH 17 | --passage_results=$passage_results_path 18 | --report_passage_weight=0.7 19 | " 20 | export OUTPUT_DIR=$(python -m dapr.exps.bm25_doc_passage_fusion.args.dragon_plus $CLI_ARGS) 21 | mkdir -p $OUTPUT_DIR 22 | export LOG_PATH="$OUTPUT_DIR/logging.log" 23 | echo "Logging file path: $LOG_PATH" 24 | python -m dapr.exps.bm25_doc_passage_fusion.dragon_plus $CLI_ARGS > $LOG_PATH 25 | done -------------------------------------------------------------------------------- /scripts/dgx2/exps/bm25_doc_passage_fusion/spladev2.sh: -------------------------------------------------------------------------------- 1 | datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" ) 2 | passage_results_paths=( 3 | "$(ls exps/passage_only/spladev2/data_dir_data/ConditionalQA/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 4 | "$(ls exps/passage_only/spladev2/data_dir_data/MSMARCO/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 5 | "$(ls 
exps/passage_only/spladev2/data_dir_data/NaturalQuestions/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 6 | "$(ls exps/passage_only/spladev2/data_dir_data/Genomics/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 7 | "$(ls exps/passage_only/spladev2/data_dir_data/MIRACL/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 8 | ) 9 | for i in {0..4} 10 | do 11 | dataset=${datasets[$i]} 12 | passage_results_path=${passage_results_paths[$i]} 13 | export DATA_DIR="data" 14 | export DATASET_PATH="$DATA_DIR/$dataset" 15 | export CLI_ARGS=" 16 | --data_dir=$DATASET_PATH 17 | --passage_results=$passage_results_path 18 | --report_passage_weight=0.7 19 | " 20 | export OUTPUT_DIR=$(python -m dapr.exps.bm25_doc_passage_fusion.args.spladev2 $CLI_ARGS) 21 | mkdir -p $OUTPUT_DIR 22 | export LOG_PATH="$OUTPUT_DIR/logging.log" 23 | echo "Logging file path: $LOG_PATH" 24 | python -m dapr.exps.bm25_doc_passage_fusion.spladev2 $CLI_ARGS > $LOG_PATH 25 | done -------------------------------------------------------------------------------- /scripts/dgx2/exps/jinav2_doc_passage_fusion/spladev2.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES="0,1,2,3" 2 | datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" ) 3 | passage_results_paths=( 4 | "$(ls exps/passage_only/spladev2/data_dir_data/ConditionalQA/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 5 | "$(ls exps/passage_only/spladev2/data_dir_data/MSMARCO/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 6 | "$(ls exps/passage_only/spladev2/data_dir_data/NaturalQuestions/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 7 | "$(ls exps/passage_only/spladev2/data_dir_data/Genomics/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 8 | "$(ls exps/passage_only/spladev2/data_dir_data/MIRACL/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 9 | ) 10 | for i in {0..4} 11 | do 12 | dataset=${datasets[$i]} 13 | passage_results_path=${passage_results_paths[$i]} 14 | export DATA_DIR="data" 15 | export DATASET_PATH="$DATA_DIR/$dataset" 16 | export CLI_ARGS=" 17 | --data_dir=$DATASET_PATH 18 | --passage_results=$passage_results_path 19 | --report_passage_weight=0.7 20 | " 21 | export OUTPUT_DIR=$(python -m dapr.exps.jinav2_doc_passage_fusion.args.spladev2 $CLI_ARGS) 22 | mkdir -p $OUTPUT_DIR 23 | export LOG_PATH="$OUTPUT_DIR/logging.log" 24 | echo "Logging file path: $LOG_PATH" 25 | torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.jinav2_doc_passage_fusion.spladev2 $CLI_ARGS 26 | done -------------------------------------------------------------------------------- /scripts/dgx2/exps/jinav2_doc_passage_fusion/colbertv2.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES="0,1,2,3" 2 | datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" ) 3 | passage_results_paths=( 4 | "$(ls exps/passage_only/colbertv2/data_dir_data/ConditionalQA/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 5 | "$(ls
exps/passage_only/colbertv2/data_dir_data/MSMARCO/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 6 | "$(ls exps/passage_only/colbertv2/data_dir_data/NaturalQuestions/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 7 | "$(ls exps/passage_only/colbertv2/data_dir_data/Genomics/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 8 | "$(ls exps/passage_only/colbertv2/data_dir_data/MIRACL/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 9 | ) 10 | for i in {0..4} 11 | do 12 | dataset=${datasets[$i]} 13 | passage_results_path=${passage_results_paths[$i]} 14 | export DATA_DIR="data" 15 | export DATASET_PATH="$DATA_DIR/$dataset" 16 | export CLI_ARGS=" 17 | --data_dir=$DATASET_PATH 18 | --passage_results=$passage_results_path 19 | --report_passage_weight=0.8 20 | " 21 | export OUTPUT_DIR=$(python -m dapr.exps.jinav2_doc_passage_fusion.args.colbertv2 $CLI_ARGS) 22 | mkdir -p $OUTPUT_DIR 23 | export LOG_PATH="$OUTPUT_DIR/logging.log" 24 | echo "Logging file path: $LOG_PATH" 25 | torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.jinav2_doc_passage_fusion.colbertv2 $CLI_ARGS 26 | done -------------------------------------------------------------------------------- /scripts/dgx2/exps/jinav2_doc_passage_fusion/dragon_plus.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES="0,1,2,3" 2 | datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" ) 3 | passage_results_paths=( 4 | "$(ls exps/passage_only/dragon_plus/data_dir_data/ConditionalQA/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 5 | "$(ls exps/passage_only/dragon_plus/data_dir_data/MSMARCO/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 6 | "$(ls exps/passage_only/dragon_plus/data_dir_data/NaturalQuestions/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 7 | "$(ls exps/passage_only/dragon_plus/data_dir_data/Genomics/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 8 | "$(ls exps/passage_only/dragon_plus/data_dir_data/MIRACL/split_test/topk_1000/per_device_eval_batch_size_32/fp16_True/*/ranking_results.txt|head -1)" 9 | ) 10 | for i in {0..4} 11 | do 12 | dataset=${datasets[$i]} 13 | passage_results_path=${passage_results_paths[$i]} 14 | export DATA_DIR="data" 15 | export DATASET_PATH="$DATA_DIR/$dataset" 16 | export CLI_ARGS=" 17 | --data_dir=$DATASET_PATH 18 | --passage_results=$passage_results_path 19 | --report_passage_weight=0.7 20 | " 21 | export OUTPUT_DIR=$(python -m dapr.exps.jinav2_doc_passage_fusion.args.dragon_plus $CLI_ARGS) 22 | mkdir -p $OUTPUT_DIR 23 | export LOG_PATH="$OUTPUT_DIR/logging.log" 24 | echo "Logging file path: $LOG_PATH" 25 | torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.jinav2_doc_passage_fusion.dragon_plus $CLI_ARGS 26 | done -------------------------------------------------------------------------------- /dapr/exps/coref/shared_pipeline.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from clddp.dm import LabeledQuery, RetrievedPassageIDList 5 | from clddp.retriever import Retriever 6 | from clddp.utils import is_device_zero 7 | from dapr.exps.coref.args.base import 
CorefArguments 8 | from clddp.search import search 9 | from dapr.dataloader import DAPRDataConfig, DAPRDataLoader 10 | from clddp.evaluation import RetrievalEvaluator 11 | 12 | 13 | def run_coref(args: CorefArguments, retriever: Retriever, retriever_name: str) -> None: 14 | # Actually the same as passage_only; the difference lies in the data. 15 | if is_device_zero(): 16 | args.dump_arguments() 17 | dataset = DAPRDataLoader(DAPRDataConfig(args.data_dir)).load_data( 18 | progress_bar=is_device_zero() 19 | ) 20 | labeled_queries = dataset.get_labeled_queries(args.split) 21 | queries = LabeledQuery.get_unique_queries(labeled_queries) 22 | retrieved = search( 23 | retriever=retriever, 24 | collection_iter=dataset.collection_iter, 25 | collection_size=dataset.collection_size, 26 | queries=queries, 27 | topk=args.topk, 28 | per_device_eval_batch_size=args.per_device_eval_batch_size, 29 | fp16=args.fp16, 30 | ) 31 | if is_device_zero(): 32 | evaluator = RetrievalEvaluator(eval_dataset=dataset, split=args.split) 33 | report = evaluator(retrieved) 34 | freport = os.path.join(args.output_dir, "metrics.json") 35 | with open(freport, "w") as f: 36 | json.dump(report, f, indent=4) 37 | logging.info(f"Saved evaluation metrics to {freport}.") 38 | franked = os.path.join(args.output_dir, "ranking_results.txt") 39 | RetrievedPassageIDList.dump_trec_csv( 40 | retrieval_results=retrieved, fpath=franked, system=retriever_name 41 | ) 42 | logging.info(f"Saved ranking results to {franked}.") 43 | -------------------------------------------------------------------------------- /dapr/exps/passage_only/bm25.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from clddp.dm import LabeledQuery, RetrievedPassageIDList 5 | from clddp.utils import is_device_zero, parse_cli, set_logger_format 6 | from dapr.retrievers.bm25 import BM25 7 | from dapr.exps.passage_only.args.bm25 import BM25PassageOnlyArguments 8 | from dapr.dataloader import DAPRDataConfig, DAPRDataLoader 9 | from clddp.evaluation import RetrievalEvaluator 10 | 11 | if __name__ == "__main__": 12 | set_logger_format(logging.INFO if is_device_zero() else logging.WARNING) 13 | args = parse_cli(BM25PassageOnlyArguments) 14 | if is_device_zero(): 15 | args.dump_arguments() 16 | dataset = DAPRDataLoader(DAPRDataConfig(args.data_dir)).load_data( 17 | progress_bar=is_device_zero() 18 | ) 19 | retriever = BM25() 20 | labeled_queries = dataset.get_labeled_queries(args.split) 21 | queries = LabeledQuery.get_unique_queries(labeled_queries) 22 | index_path = os.path.join(args.output_dir, "index") 23 | retriever.index( 24 | collection_iter=dataset.collection_iter, 25 | collection_size=dataset.collection_size, 26 | output_dir=index_path, 27 | ) 28 | retrieved = retriever.search( 29 | queries=queries, 30 | index_path=index_path, 31 | topk=args.topk, 32 | batch_size=args.per_device_eval_batch_size, 33 | ) 34 | evaluator = RetrievalEvaluator(eval_dataset=dataset, split=args.split) 35 | report = evaluator(retrieved) 36 | freport = os.path.join(args.output_dir, "metrics.json") 37 | with open(freport, "w") as f: 38 | json.dump(report, f, indent=4) 39 | logging.info(f"Saved evaluation metrics to {freport}.") 40 | franked = os.path.join(args.output_dir, "ranking_results.txt") 41 | RetrievedPassageIDList.dump_trec_csv( 42 | retrieval_results=retrieved, fpath=franked, system="bm25" 43 | ) 44 | logging.info(f"Saved ranking results to {franked}.") 45 |
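The fusion and hierarchy scripts above all pass a --report_passage_weight between 0.7 and 0.9, which controls how passage-level and document-level rankings are mixed. The following is a minimal sketch of what such a weighted fusion can look like, assuming a convex combination of min-max-normalized scores; the helper names fuse_scores and min_max are illustrative and are not the actual API of dapr/fusion.py.

from typing import Dict

def min_max(scores: Dict[str, float]) -> Dict[str, float]:
    # Normalize one query's scores into [0, 1] so that document-retriever
    # scores and passage-retriever scores become comparable before mixing.
    lo, hi = min(scores.values()), max(scores.values())
    if hi == lo:
        return {key: 1.0 for key in scores}
    return {key: (value - lo) / (hi - lo) for key, value in scores.items()}

def fuse_scores(
    passage_scores: Dict[str, float],  # passage_id -> passage-retriever score
    doc_scores: Dict[str, float],  # doc_id -> document-retriever score (e.g. BM25)
    passage_to_doc: Dict[str, str],  # document each candidate passage belongs to
    passage_weight: float,  # e.g. --report_passage_weight=0.7
) -> Dict[str, float]:
    # Convex combination: passages from highly ranked documents get boosted,
    # with passage_weight trading off passage vs. document evidence.
    normalized_passages = min_max(passage_scores)
    normalized_docs = min_max(doc_scores)
    return {
        pid: passage_weight * score
        + (1.0 - passage_weight) * normalized_docs.get(passage_to_doc[pid], 0.0)
        for pid, score in normalized_passages.items()
    }

if __name__ == "__main__":
    fused = fuse_scores(
        passage_scores={"d1#p0": 12.3, "d2#p4": 10.1},
        doc_scores={"d1": 3.2, "d2": 7.9},
        passage_to_doc={"d1#p0": "d1", "d2#p4": "d2"},
        passage_weight=0.7,
    )
    print(sorted(fused.items(), key=lambda item: item[1], reverse=True))

With passage_weight=1.0 such a fusion reduces to the plain passage-only ranking, so the sub-1.0 values in the scripts encode how much document-level context is mixed into the reported results.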
-------------------------------------------------------------------------------- /dapr/exps/passage_only/jinav2.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from clddp.dm import LabeledQuery, RetrievedPassageIDList 5 | from clddp.utils import initialize_ddp, is_device_zero, parse_cli, set_logger_format 6 | from dapr.retrievers.dense import JinaV2 7 | from dapr.exps.passage_only.args.jinav2 import JinaV2PassageOnlyArguments 8 | from clddp.search import search 9 | from dapr.dataloader import DAPRDataConfig, DAPRDataLoader 10 | from clddp.evaluation import RetrievalEvaluator 11 | 12 | if __name__ == "__main__": 13 | set_logger_format(logging.INFO if is_device_zero() else logging.WARNING) 14 | initialize_ddp() 15 | args = parse_cli(JinaV2PassageOnlyArguments) 16 | if is_device_zero(): 17 | args.dump_arguments() 18 | dataset = DAPRDataLoader(DAPRDataConfig(args.data_dir)).load_data( 19 | progress_bar=is_device_zero() 20 | ) 21 | retriever = JinaV2(args.max_length) 22 | labeled_queries = dataset.get_labeled_queries(args.split) 23 | queries = LabeledQuery.get_unique_queries(labeled_queries) 24 | retrieved = search( 25 | retriever=retriever, 26 | collection_iter=dataset.collection_iter, 27 | collection_size=dataset.collection_size, 28 | queries=queries, 29 | topk=args.topk, 30 | per_device_eval_batch_size=args.per_device_eval_batch_size, 31 | fp16=args.fp16, 32 | ) 33 | evaluator = RetrievalEvaluator(eval_dataset=dataset, split=args.split) 34 | report = evaluator(retrieved) 35 | freport = os.path.join(args.output_dir, "metrics.json") 36 | with open(freport, "w") as f: 37 | json.dump(report, f, indent=4) 38 | logging.info(f"Saved evaluation metrics to {freport}.") 39 | franked = os.path.join(args.output_dir, "ranking_results.txt") 40 | RetrievedPassageIDList.dump_trec_csv( 41 | retrieval_results=retrieved, fpath=franked, system="jinav2" 42 | ) 43 | logging.info(f"Saved ranking results to {franked}.") 44 | -------------------------------------------------------------------------------- /dapr/exps/passage_only/retromae.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from clddp.dm import LabeledQuery, RetrievedPassageIDList 5 | from clddp.utils import initialize_ddp, is_device_zero, parse_cli, set_logger_format 6 | from dapr.retrievers.dense import RetroMAE 7 | from dapr.exps.passage_only.args.retromae import RetroMAEPassageOnlyArguments 8 | from clddp.search import search 9 | from dapr.dataloader import DAPRDataConfig, DAPRDataLoader 10 | from clddp.evaluation import RetrievalEvaluator 11 | 12 | if __name__ == "__main__": 13 | set_logger_format(logging.INFO if is_device_zero() else logging.WARNING) 14 | initialize_ddp() 15 | args = parse_cli(RetroMAEPassageOnlyArguments) 16 | if is_device_zero(): 17 | args.dump_arguments() 18 | dataset = DAPRDataLoader(DAPRDataConfig(args.data_dir)).load_data( 19 | progress_bar=is_device_zero() 20 | ) 21 | retriever = RetroMAE() 22 | labeled_queries = dataset.get_labeled_queries(args.split) 23 | queries = LabeledQuery.get_unique_queries(labeled_queries) 24 | retrieved = search( 25 | retriever=retriever, 26 | collection_iter=dataset.collection_iter, 27 | collection_size=dataset.collection_size, 28 | queries=queries, 29 | topk=args.topk, 30 | per_device_eval_batch_size=args.per_device_eval_batch_size, 31 | fp16=args.fp16, 32 | ) 33 | evaluator = RetrievalEvaluator(eval_dataset=dataset, 
split=args.split) 34 | report = evaluator(retrieved) 35 | freport = os.path.join(args.output_dir, "metrics.json") 36 | with open(freport, "w") as f: 37 | json.dump(report, f, indent=4) 38 | logging.info(f"Saved evaluation metrics to {freport}.") 39 | franked = os.path.join(args.output_dir, "ranking_results.txt") 40 | RetrievedPassageIDList.dump_trec_csv( 41 | retrieval_results=retrieved, fpath=franked, system="retromae" 42 | ) 43 | logging.info(f"Saved ranking results to {franked}.") 44 | -------------------------------------------------------------------------------- /dapr/exps/passage_only/spladev2.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from clddp.dm import LabeledQuery, RetrievedPassageIDList 5 | from clddp.utils import initialize_ddp, is_device_zero, parse_cli, set_logger_format 6 | from dapr.retrievers.sparse import SPLADEv2 7 | from dapr.exps.passage_only.args.spladev2 import SPLADEv2PassageOnlyArguments 8 | from clddp.search import search 9 | from dapr.dataloader import DAPRDataConfig, DAPRDataLoader 10 | from clddp.evaluation import RetrievalEvaluator 11 | 12 | if __name__ == "__main__": 13 | set_logger_format(logging.INFO if is_device_zero() else logging.WARNING) 14 | initialize_ddp() 15 | args = parse_cli(SPLADEv2PassageOnlyArguments) 16 | if is_device_zero(): 17 | args.dump_arguments() 18 | dataset = DAPRDataLoader(DAPRDataConfig(args.data_dir)).load_data( 19 | progress_bar=is_device_zero() 20 | ) 21 | retriever = SPLADEv2() 22 | labeled_queries = dataset.get_labeled_queries(args.split) 23 | queries = LabeledQuery.get_unique_queries(labeled_queries) 24 | retrieved = search( 25 | retriever=retriever, 26 | collection_iter=dataset.collection_iter, 27 | collection_size=dataset.collection_size, 28 | queries=queries, 29 | topk=args.topk, 30 | per_device_eval_batch_size=args.per_device_eval_batch_size, 31 | fp16=args.fp16, 32 | ) 33 | evaluator = RetrievalEvaluator(eval_dataset=dataset, split=args.split) 34 | report = evaluator(retrieved) 35 | freport = os.path.join(args.output_dir, "metrics.json") 36 | with open(freport, "w") as f: 37 | json.dump(report, f, indent=4) 38 | logging.info(f"Saved evaluation metrics to {freport}.") 39 | franked = os.path.join(args.output_dir, "ranking_results.txt") 40 | RetrievedPassageIDList.dump_trec_csv( 41 | retrieval_results=retrieved, fpath=franked, system="spladev2" 42 | ) 43 | logging.info(f"Saved ranking results to {franked}.") 44 | -------------------------------------------------------------------------------- /dapr/exps/passage_only/dragon_plus.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from clddp.dm import LabeledQuery, RetrievedPassageIDList 5 | from clddp.utils import initialize_ddp, is_device_zero, parse_cli, set_logger_format 6 | from dapr.retrievers.dense import DRAGONPlus 7 | from dapr.exps.passage_only.args.dragon_plus import DRAGONPlusPassageOnlyArguments 8 | from clddp.search import search 9 | from dapr.dataloader import DAPRDataConfig, DAPRDataLoader 10 | from clddp.evaluation import RetrievalEvaluator 11 | 12 | if __name__ == "__main__": 13 | set_logger_format(logging.INFO if is_device_zero() else logging.WARNING) 14 | initialize_ddp() 15 | args = parse_cli(DRAGONPlusPassageOnlyArguments) 16 | if is_device_zero(): 17 | args.dump_arguments() 18 | dataset = DAPRDataLoader(DAPRDataConfig(args.data_dir)).load_data( 19 | 
progress_bar=is_device_zero() 20 | ) 21 | retriever = DRAGONPlus() 22 | labeled_queries = dataset.get_labeled_queries(args.split) 23 | queries = LabeledQuery.get_unique_queries(labeled_queries) 24 | retrieved = search( 25 | retriever=retriever, 26 | collection_iter=dataset.collection_iter, 27 | collection_size=dataset.collection_size, 28 | queries=queries, 29 | topk=args.topk, 30 | per_device_eval_batch_size=args.per_device_eval_batch_size, 31 | fp16=args.fp16, 32 | ) 33 | evaluator = RetrievalEvaluator(eval_dataset=dataset, split=args.split) 34 | report = evaluator(retrieved) 35 | freport = os.path.join(args.output_dir, "metrics.json") 36 | with open(freport, "w") as f: 37 | json.dump(report, f, indent=4) 38 | logging.info(f"Saved evaluation metrics to {freport}.") 39 | franked = os.path.join(args.output_dir, "ranking_results.txt") 40 | RetrievedPassageIDList.dump_trec_csv( 41 | retrieval_results=retrieved, fpath=franked, system="dragon_plus" 42 | ) 43 | logging.info(f"Saved ranking results to {franked}.") 44 | -------------------------------------------------------------------------------- /dapr/exps/passage_only/colbertv2.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from clddp.dm import LabeledQuery, RetrievedPassageIDList 5 | from clddp.utils import initialize_ddp, is_device_zero, parse_cli, set_logger_format 6 | from clddp.search import search 7 | from dapr.retrievers.late_interaction import ColBERTv2 8 | from dapr.exps.passage_only.args.colbertv2 import ColBERTv2PassageOnlyArguments 9 | from dapr.dataloader import DAPRDataConfig, DAPRDataLoader 10 | from clddp.evaluation import RetrievalEvaluator 11 | 12 | if __name__ == "__main__": 13 | set_logger_format(logging.INFO if is_device_zero() else logging.WARNING) 14 | initialize_ddp() 15 | args = parse_cli(ColBERTv2PassageOnlyArguments) 16 | if is_device_zero(): 17 | args.dump_arguments() 18 | dataset = DAPRDataLoader(DAPRDataConfig(args.data_dir)).load_data( 19 | progress_bar=is_device_zero() 20 | ) 21 | retriever = ColBERTv2() 22 | labeled_queries = dataset.get_labeled_queries(args.split) 23 | queries = LabeledQuery.get_unique_queries(labeled_queries) 24 | retrieved = search( 25 | retriever=retriever, 26 | collection_iter=dataset.collection_iter, 27 | collection_size=dataset.collection_size, 28 | queries=queries, 29 | topk=args.topk, 30 | per_device_eval_batch_size=args.per_device_eval_batch_size, 31 | fp16=args.fp16, 32 | ) 33 | evaluator = RetrievalEvaluator(eval_dataset=dataset, split=args.split) 34 | report = evaluator(retrieved) 35 | freport = os.path.join(args.output_dir, "metrics.json") 36 | with open(freport, "w") as f: 37 | json.dump(report, f, indent=4) 38 | logging.info(f"Saved evaluation metrics to {freport}.") 39 | franked = os.path.join(args.output_dir, "ranking_results.txt") 40 | RetrievedPassageIDList.dump_trec_csv( 41 | retrieval_results=retrieved, fpath=franked, system="colbertv2" 42 | ) 43 | logging.info(f"Saved ranking results to {franked}.") 44 | -------------------------------------------------------------------------------- /dapr/exps/prepending_titles/bm25.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from clddp.dm import LabeledQuery, RetrievedPassageIDList 5 | from clddp.utils import is_device_zero, parse_cli, set_logger_format 6 | from dapr.retrievers.bm25 import BM25 7 | from dapr.exps.prepending_titles.args.bm25 import 
BM25PrependingTitlesArguments 8 | from dapr.dataloader import DAPRDataConfig, DAPRDataLoader 9 | from clddp.evaluation import RetrievalEvaluator 10 | 11 | if __name__ == "__main__": 12 | set_logger_format(logging.INFO if is_device_zero() else logging.WARNING) 13 | args = parse_cli(BM25PrependingTitlesArguments) 14 | if is_device_zero(): 15 | args.dump_arguments() 16 | dataset = DAPRDataLoader( 17 | DAPRDataConfig(data_name_or_path=args.data_dir, titled=True) 18 | ).load_data(progress_bar=is_device_zero()) 19 | retriever = BM25() 20 | labeled_queries = dataset.get_labeled_queries(args.split) 21 | queries = LabeledQuery.get_unique_queries(labeled_queries) 22 | index_path = os.path.join(args.output_dir, "index") 23 | retriever.index( 24 | collection_iter=dataset.collection_iter, 25 | collection_size=dataset.collection_size, 26 | output_dir=index_path, 27 | ) 28 | retrieved = retriever.search( 29 | queries=queries, 30 | index_path=index_path, 31 | topk=args.topk, 32 | batch_size=args.per_device_eval_batch_size, 33 | ) 34 | evaluator = RetrievalEvaluator(eval_dataset=dataset, split=args.split) 35 | report = evaluator(retrieved) 36 | freport = os.path.join(args.output_dir, "metrics.json") 37 | with open(freport, "w") as f: 38 | json.dump(report, f, indent=4) 39 | logging.info(f"Saved evaluation metrics to {freport}.") 40 | franked = os.path.join(args.output_dir, "ranking_results.txt") 41 | RetrievedPassageIDList.dump_trec_csv( 42 | retrieval_results=retrieved, fpath=franked, system="bm25" 43 | ) 44 | logging.info(f"Saved ranking results to {franked}.") 45 | -------------------------------------------------------------------------------- /dapr/exps/prepending_titles/spladev2.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from clddp.dm import LabeledQuery, RetrievedPassageIDList 5 | from clddp.utils import initialize_ddp, is_device_zero, parse_cli, set_logger_format 6 | from dapr.retrievers.sparse import SPLADEv2 7 | from dapr.exps.prepending_titles.args.spladev2 import ( 8 | SPLADEv2lusPrependingTitlesArguments, 9 | ) 10 | from clddp.search import search 11 | from dapr.dataloader import DAPRDataConfig, DAPRDataLoader 12 | from clddp.evaluation import RetrievalEvaluator 13 | 14 | if __name__ == "__main__": 15 | set_logger_format(logging.INFO if is_device_zero() else logging.WARNING) 16 | initialize_ddp() 17 | args = parse_cli(SPLADEv2lusPrependingTitlesArguments) 18 | if is_device_zero(): 19 | args.dump_arguments() 20 | dataset = DAPRDataLoader( 21 | DAPRDataConfig(data_name_or_path=args.data_dir, titled=True) 22 | ).load_data(progress_bar=is_device_zero()) 23 | retriever = SPLADEv2() 24 | labeled_queries = dataset.get_labeled_queries(args.split) 25 | queries = LabeledQuery.get_unique_queries(labeled_queries) 26 | retrieved = search( 27 | retriever=retriever, 28 | collection_iter=dataset.collection_iter, 29 | collection_size=dataset.collection_size, 30 | queries=queries, 31 | topk=args.topk, 32 | per_device_eval_batch_size=args.per_device_eval_batch_size, 33 | fp16=args.fp16, 34 | ) 35 | evaluator = RetrievalEvaluator(eval_dataset=dataset, split=args.split) 36 | report = evaluator(retrieved) 37 | freport = os.path.join(args.output_dir, "metrics.json") 38 | with open(freport, "w") as f: 39 | json.dump(report, f, indent=4) 40 | logging.info(f"Saved evaluation metrics to {freport}.") 41 | franked = os.path.join(args.output_dir, "ranking_results.txt") 42 | RetrievedPassageIDList.dump_trec_csv( 43 |
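# Persist the ranking in TREC run-file format; "system" fills the run-tag column.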
retrieval_results=retrieved, fpath=franked, system="spladev2" 44 | ) 45 | logging.info(f"Saved ranking results to {franked}.") 46 | -------------------------------------------------------------------------------- /dapr/exps/prepending_titles/dragon_plus.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from clddp.dm import LabeledQuery, RetrievedPassageIDList 5 | from clddp.utils import initialize_ddp, is_device_zero, parse_cli, set_logger_format 6 | from dapr.retrievers.dense import DRAGONPlus 7 | from dapr.exps.prepending_titles.args.dragon_plus import ( 8 | DRAGONPlusPrependingTitlesArguments, 9 | ) 10 | from clddp.search import search 11 | from dapr.dataloader import DAPRDataConfig, DAPRDataLoader 12 | from clddp.evaluation import RetrievalEvaluator 13 | 14 | if __name__ == "__main__": 15 | set_logger_format(logging.INFO if is_device_zero() else logging.WARNING) 16 | initialize_ddp() 17 | args = parse_cli(DRAGONPlusPrependingTitlesArguments) 18 | if is_device_zero(): 19 | args.dump_arguments() 20 | dataset = DAPRDataLoader( 21 | DAPRDataConfig(data_name_or_path=args.data_dir, titled=True) 22 | ).load_data(progress_bar=is_device_zero()) 23 | retriever = DRAGONPlus() 24 | labeled_queries = dataset.get_labeled_queries(args.split) 25 | queries = LabeledQuery.get_unique_queries(labeled_queries) 26 | retrieved = search( 27 | retriever=retriever, 28 | collection_iter=dataset.collection_iter, 29 | collection_size=dataset.collection_size, 30 | queries=queries, 31 | topk=args.topk, 32 | per_device_eval_batch_size=args.per_device_eval_batch_size, 33 | fp16=args.fp16, 34 | ) 35 | evaluator = RetrievalEvaluator(eval_dataset=dataset, split=args.split) 36 | report = evaluator(retrieved) 37 | freport = os.path.join(args.output_dir, "metrics.json") 38 | with open(freport, "w") as f: 39 | json.dump(report, f, indent=4) 40 | logging.info(f"Saved evaluation metrics to {freport}.") 41 | franked = os.path.join(args.output_dir, "ranking_results.txt") 42 | RetrievedPassageIDList.dump_trec_csv( 43 | retrieval_results=retrieved, fpath=franked, system="dragon_plus" 44 | ) 45 | logging.info(f"Saved ranking results to {franked}.") 46 | -------------------------------------------------------------------------------- /dapr/exps/prepending_titles/colbertv2.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from clddp.dm import LabeledQuery, RetrievedPassageIDList 5 | from clddp.utils import initialize_ddp, is_device_zero, parse_cli, set_logger_format 6 | from clddp.search import search 7 | from dapr.retrievers.late_interaction import ColBERTv2 8 | from dapr.exps.prepending_titles.args.colbertv2 import ( 9 | ColBERTv2PrependingTitlesArguments, 10 | ) 11 | from dapr.dataloader import DAPRDataConfig, DAPRDataLoader 12 | from clddp.evaluation import RetrievalEvaluator 13 | 14 | if __name__ == "__main__": 15 | set_logger_format(logging.INFO if is_device_zero() else logging.WARNING) 16 | initialize_ddp() 17 | args = parse_cli(ColBERTv2PrependingTitlesArguments) 18 | if is_device_zero(): 19 | args.dump_arguments() 20 | dataset = DAPRDataLoader( 21 | DAPRDataConfig(data_name_or_path=args.data_dir, titled=True) 22 | ).load_data(progress_bar=is_device_zero()) 23 | retriever = ColBERTv2() 24 | labeled_queries = dataset.get_labeled_queries(args.split) 25 | queries = LabeledQuery.get_unique_queries(labeled_queries) 26 | retrieved = search( 27 | 
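# Exhaustive top-k search; the title prepending itself comes from loading the passages with titled=True above.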
retriever=retriever, 28 | collection_iter=dataset.collection_iter, 29 | collection_size=dataset.collection_size, 30 | queries=queries, 31 | topk=args.topk, 32 | per_device_eval_batch_size=args.per_device_eval_batch_size, 33 | fp16=args.fp16, 34 | ) 35 | evaluator = RetrievalEvaluator(eval_dataset=dataset, split=args.split) 36 | report = evaluator(retrieved) 37 | freport = os.path.join(args.output_dir, "metrics.json") 38 | with open(freport, "w") as f: 39 | json.dump(report, f, indent=4) 40 | logging.info(f"Saved evaluation metrics to {freport}.") 41 | franked = os.path.join(args.output_dir, "ranking_results.txt") 42 | RetrievedPassageIDList.dump_trec_csv( 43 | retrieval_results=retrieved, fpath=franked, system="colbertv2" 44 | ) 45 | logging.info(f"Saved ranking results to {franked}.") 46 | -------------------------------------------------------------------------------- /dapr/fusion.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | from clddp.dm import RetrievedPassageIDList, ScoredPassageID 3 | 4 | 5 | SOFT_ZERO = 1e-4 6 | 7 | 8 | def normalize_min_max(scores: Dict[str, float]) -> Dict[str, float]: 9 | if len(scores) == 0: 10 | return {} 11 | score_min = min(scores.values()) 12 | score_max = max(scores.values()) 13 | divisor = max(score_max - score_min, SOFT_ZERO) 14 | normalized = {k: (v - score_min) / divisor for k, v in scores.items()} 15 | return normalized 16 | 17 | 18 | def M2C2( 19 | scores1: Dict[str, float], 20 | scores2: Dict[str, float], 21 | weight2: float, 22 | ) -> Dict[str, float]: 23 | scores1 = normalize_min_max(scores=scores1) 24 | scores2 = normalize_min_max(scores=scores2) 25 | ids_union = set(scores1) | set(scores2) 26 | default_score = 0 27 | scores = { 28 | k: (1 - weight2) * scores1.get(k, default_score) 29 | + weight2 * scores2.get(k, default_score) 30 | for k in ids_union 31 | } 32 | return scores 33 | 34 | 35 | def doc_passage_fusion_with_M2C2( 36 | doc_results: List[RetrievedPassageIDList], 37 | passage_results: List[RetrievedPassageIDList], 38 | pid2did: Dict[str, str], 39 | passage_weight: float, 40 | ) -> List[RetrievedPassageIDList]: 41 | fused_lists = [] 42 | for doc_result, passage_result in zip(doc_results, passage_results): 43 | assert doc_result.query_id == passage_result.query_id 44 | did2score = ScoredPassageID.build_pid2score(doc_result.scored_passage_ids) 45 | pid2score = ScoredPassageID.build_pid2score(passage_result.scored_passage_ids) 46 | doc_pid2score = { 47 | pid: did2score[pid2did[pid]] 48 | for pid in pid2score.keys() 49 | if pid2did[pid] in did2score 50 | } 51 | fused = M2C2(scores1=doc_pid2score, scores2=pid2score, weight2=passage_weight) 52 | spids = [ 53 | ScoredPassageID(passage_id=pid, score=score) for pid, score in fused.items() 54 | ] 55 | fused_lists.append( 56 | RetrievedPassageIDList( 57 | query_id=doc_result.query_id, scored_passage_ids=spids 58 | ) 59 | ) 60 | return fused_lists 61 | -------------------------------------------------------------------------------- /dapr/retrievers/dense.py: -------------------------------------------------------------------------------- 1 | from clddp.retriever import Retriever, RetrieverConfig, Pooling, SimilarityFunction 2 | from clddp.dm import Separator 3 | from transformers import PreTrainedModel, AutoModel 4 | 5 | 6 | class RetroMAE(Retriever): 7 | def __init__(self) -> None: 8 | config = RetrieverConfig( 9 | query_model_name_or_path="Shitao/RetroMAE_BEIR", 10 | shared_encoder=True, 11 | sep=Separator.bert_sep, 12 | 
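# Shared BERT-style bi-encoder with CLS pooling and dot-product scoring, as configured for the Shitao/RetroMAE_BEIR checkpoint.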
pooling=Pooling.cls, 13 | similarity_function=SimilarityFunction.dot_product, 14 | query_max_length=512, 15 | passage_max_length=512, 16 | ) 17 | super().__init__(config) 18 | 19 | 20 | class DRAGONPlus(Retriever): 21 | def __init__(self) -> None: 22 | config = RetrieverConfig( 23 | query_model_name_or_path="facebook/dragon-plus-query-encoder", 24 | passage_model_name_or_path="facebook/dragon-plus-context-encoder", 25 | shared_encoder=False, 26 | sep=Separator.blank, 27 | pooling=Pooling.cls, 28 | similarity_function=SimilarityFunction.dot_product, 29 | query_max_length=512, 30 | passage_max_length=512, 31 | ) 32 | super().__init__(config) 33 | 34 | 35 | class JinaV2(Retriever): 36 | def __init__(self, max_length: int = 8192) -> None: 37 | config = RetrieverConfig( 38 | query_model_name_or_path="jinaai/jina-embeddings-v2-base-en", 39 | shared_encoder=True, 40 | sep=Separator.blank, 41 | pooling=Pooling.mean, 42 | similarity_function=SimilarityFunction.cos_sim, 43 | # max_length=512, # Use this for paragraph-level retrieval 44 | query_max_length=max_length, 45 | passage_max_length=max_length, 46 | ) 47 | super().__init__(config) 48 | 49 | @staticmethod 50 | def load_checkpoint( 51 | model_name_or_path: str, config: RetrieverConfig 52 | ) -> PreTrainedModel: 53 | return AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True) 54 | 55 | # def encode( 56 | # self, encoder: PreTrainedModel, texts: List[str], batch_size: int 57 | # ) -> torch.Tensor: 58 | # return encoder.encode( 59 | # texts, 60 | # batch_size=batch_size, 61 | # convert_to_tensor=True, 62 | # convert_to_numpy=False, 63 | # show_progress_bar=False, 64 | # ) 65 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | **/ouptut 132 | output 133 | *.gz 134 | *.zip 135 | *.pkl 136 | outputs 137 | wandb 138 | ckpts 139 | *.out 140 | *.out.* 141 | 142 | data 143 | results 144 | /exps 145 | .DS_Store 146 | 147 | /load.py 148 | /evaluate.py -------------------------------------------------------------------------------- /dapr/datasets/conditionalqa.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from dataclasses import replace 3 | import re 4 | from typing import List, Optional 5 | from dapr.datasets.dm import ( 6 | Document, 7 | LabeledQuery, 8 | LoadedData, 9 | ) 10 | from dapr.utils import Separator, set_logger_format 11 | from dapr.datasets.tagged_conditionalqa import TaggedConditionalQA 12 | import tqdm 13 | 14 | 15 | class ConditionalQA(TaggedConditionalQA): 16 | """The cleaned version of ConditionalQA where the HTML tags have been removed. 
Used in the DAPR experiments.""" 17 | 18 | HTML_TAG_PATTERN = re.compile("<.*?>") 19 | 20 | def __init__( 21 | self, 22 | resource_path: str = "https://raw.githubusercontent.com/haitian-sun/ConditionalQA/master/v1_0", 23 | nheldout: Optional[int] = None, 24 | cache_root_dir: str = "data", 25 | chunk_separator: Separator = Separator.empty, 26 | tokenizer: str = "roberta-base", 27 | nprocs: int = 10, 28 | ) -> None: 29 | super().__init__( 30 | resource_path, nheldout, cache_root_dir, chunk_separator, tokenizer, nprocs 31 | ) 32 | 33 | def clean_document(self, doc: Document) -> Document: 34 | cloned_doc = replace(doc) 35 | chunks = [replace(chk) for chk in cloned_doc.chunks] 36 | for chk in chunks: 37 | chk.text = re.sub(self.HTML_TAG_PATTERN, "", chk.text).strip() 38 | cloned_doc.chunks = chunks 39 | return cloned_doc 40 | 41 | def clean_labeled_queries( 42 | self, labeled_queries: List[LabeledQuery] 43 | ) -> List[LabeledQuery]: 44 | cleaned_lqs = [] 45 | for lq in tqdm.tqdm(labeled_queries, desc="Cleaning labeled queries"): 46 | for jchk in lq.judged_chunks: 47 | jchk.chunk.belonging_doc = self.clean_document(jchk.chunk.belonging_doc) 48 | cleaned_lqs.append(lq) 49 | return cleaned_lqs 50 | 51 | def _load_data(self, nheldout: Optional[int]) -> LoadedData: 52 | data = super()._load_data(nheldout) 53 | docs = data.corpus_iter_fn() 54 | cleaned_docs = [ 55 | self.clean_document(doc) for doc in tqdm.tqdm(docs, desc="Cleaning corpus") 56 | ] 57 | return LoadedData( 58 | corpus_iter_fn=lambda: iter(cleaned_docs), 59 | labeled_queries_train=self.clean_labeled_queries( 60 | data.labeled_queries_train 61 | ), 62 | labeled_queries_dev=self.clean_labeled_queries(data.labeled_queries_dev), 63 | labeled_queries_test=self.clean_labeled_queries(data.labeled_queries_test), 64 | ) 65 | 66 | 67 | if __name__ == "__main__": 68 | from dapr.utils import set_logger_format 69 | 70 | set_logger_format() 71 | dataset = ConditionalQA() 72 | -------------------------------------------------------------------------------- /dapr/exps/bm25_doc_retrieval/pipeline.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from clddp.dm import LabeledQuery, RetrievedPassageIDList 5 | from clddp.evaluation import RetrievalEvaluator 6 | from clddp.utils import is_device_zero, parse_cli, set_logger_format 7 | from dapr.exps.bm25_doc_retrieval.args import BM25DocRetrievalArguments 8 | from dapr.dataloader import DAPRDataConfig, DAPRDataLoader, RetrievalLevel 9 | from dapr.retrievers.bm25 import BM25 10 | 11 | 12 | def run_bm25_doc_retrieval() -> None: 13 | set_logger_format(logging.INFO if is_device_zero() else logging.WARNING) 14 | args = parse_cli(BM25DocRetrievalArguments) 15 | args.dump_arguments() 16 | 17 | # Doing BM25 document retrieval: 18 | fdocs_ranking = os.path.join(args.output_dir, "doc_ranking_results.txt") 19 | doc_dataset = DAPRDataLoader( 20 | DAPRDataConfig( 21 | data_name_or_path=args.data_dir, retrieval_level=RetrievalLevel.document 22 | ) 23 | ).load_data(True) 24 | labeled_queries = doc_dataset.get_labeled_queries(args.split) 25 | queries = LabeledQuery.get_unique_queries(labeled_queries) 26 | retriever = BM25() 27 | index_path = os.path.join(args.output_dir, "index") 28 | if not (os.path.exists(index_path) and len(os.listdir(index_path))): 29 | retriever.index( 30 | collection_iter=doc_dataset.collection_iter, 31 | collection_size=doc_dataset.collection_size, 32 | output_dir=index_path, 33 | ) 34 | else: 35 | 
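# Reuse the existing Lucene index instead of re-indexing the whole document collection.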
logging.info(f"Found existing index {index_path}") 36 | retrieved_docs = retriever.search( 37 | queries=queries, 38 | index_path=index_path, 39 | topk=args.topk, 40 | batch_size=args.per_device_eval_batch_size, 41 | ) 42 | RetrievedPassageIDList.dump_trec_csv( 43 | retrieval_results=retrieved_docs, fpath=fdocs_ranking, system="bm25" 44 | ) 45 | logging.info(f"Saved BM25 document ranking results to {fdocs_ranking}.") 46 | evaluator = RetrievalEvaluator(eval_dataset=doc_dataset, split=args.split) 47 | eval_scores = evaluator._pytrec_eval( 48 | trec_scores=RetrievedPassageIDList.build_trec_scores(retrieved_docs), 49 | qrels=evaluator.qrels, 50 | ) 51 | report_metrics = evaluator._build_report(eval_scores) 52 | fdetails = os.path.join(args.output_dir, "q2d_details.jsonl") 53 | with open(fdetails, "w") as f: 54 | for qid, row in eval_scores.items(): 55 | row = dict(eval_scores[qid]) 56 | row["query_id"] = qid 57 | f.write(json.dumps(row) + "\n") 58 | freport = os.path.join(args.output_dir, "q2d_metrics.json") 59 | with open(freport, "w") as f: 60 | json.dump(report_metrics, f, indent=4) 61 | logging.info(f"Saved evaluation metrics to {freport}.") 62 | 63 | 64 | if __name__ == "__main__": 65 | run_bm25_doc_retrieval() 66 | -------------------------------------------------------------------------------- /dapr/annotators/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import os 3 | from typing import Dict, Iterable, TypedDict 4 | 5 | from dapr.datasets.base import LoadedData 6 | from dapr.datasets.dm import Document 7 | from dapr.utils import tqdm_ropen 8 | import ujson 9 | 10 | 11 | class DocID2DocSummaryJson(TypedDict): 12 | doc_id: str 13 | doc_summary: str 14 | 15 | 16 | class BaseAnnotator(ABC): 17 | def annotate(self, data: LoadedData, cache_root_dir: str) -> None: 18 | """Annotate the doc_summary fields inplace.""" 19 | data.meta_data["corpus_identifier"] = "/".join( 20 | [ 21 | data.meta_data["corpus_identifier"], 22 | self.__class__.__name__, 23 | ] 24 | ) 25 | cache_fpath = os.path.join( 26 | cache_root_dir, 27 | data.meta_data["corpus_identifier"], 28 | "did2dsum.jsonl", 29 | ) 30 | if os.path.exists(cache_fpath): 31 | did2dsum: Dict[str, str] = {} 32 | for line in tqdm_ropen( 33 | fpath=cache_fpath, desc="Loading document summaries" 34 | ): 35 | line_dict: DocID2DocSummaryJson = ujson.loads(line) 36 | did2dsum[line_dict["doc_id"]] = line_dict["doc_summary"] 37 | else: 38 | os.makedirs(os.path.dirname(cache_fpath), exist_ok=True) 39 | try: 40 | did2dsum = self._annotate(data) 41 | with open(cache_fpath, "w") as f: 42 | for did, dsum in did2dsum.items(): 43 | line_dict = DocID2DocSummaryJson(doc_id=did, doc_summary=dsum) 44 | line = ujson.dumps(line_dict) + "\n" 45 | f.write(line) 46 | except Exception as e: 47 | if os.path.exists(cache_fpath): 48 | os.remove(cache_fpath) 49 | raise e 50 | 51 | for lqs in [ 52 | data.labeled_queries_train, 53 | data.labeled_queries_dev, 54 | data.labeled_queries_test, 55 | ]: 56 | if lqs is None: 57 | continue 58 | for lq in lqs: 59 | for jchk in lq.judged_chunks: 60 | doc_id = jchk.chunk.belonging_doc.doc_id 61 | jchk.chunk.doc_summary = did2dsum[doc_id] 62 | 63 | corpus_iter_fn = data.corpus_iter_fn 64 | 65 | def new_corpus_iter_fn() -> Iterable[Document]: 66 | for doc in corpus_iter_fn(): 67 | for chk in doc.chunks: 68 | chk.doc_summary = did2dsum[doc.doc_id] 69 | yield doc 70 | 71 | data.corpus_iter_fn = new_corpus_iter_fn 72 | 73 | @abstractmethod 74 | def 
_annotate(self, data: LoadedData) -> Dict[str, str]: 75 | """Annotate the doc_summary fields. Return a mapping from `doc_id` to `doc_summary`.""" 76 | pass 77 | -------------------------------------------------------------------------------- /dapr/exps/doc_retrieval_with_titles/shared_pipeline.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from typing import Iterable 5 | from clddp.dm import LabeledQuery, Passage, RetrievedPassageIDList 6 | from clddp.retriever import Retriever 7 | from clddp.utils import is_device_zero 8 | from dapr.dataloader import DAPRDataConfig 9 | from dapr.exps.doc_retrieval_with_titles.args.base import ( 10 | DocRetrievalWithTitlesArguments, 11 | ) 12 | from clddp.search import search 13 | from dapr.dataloader import DAPRDataConfig, DAPRDataLoader, RetrievalLevel 14 | from clddp.evaluation import RetrievalEvaluator 15 | 16 | 17 | def title_only_collection_iter(collection_iter: Iterable[Passage]) -> Iterable[Passage]: 18 | for doc in collection_iter: 19 | yield Passage(passage_id=doc.passage_id, text="", title=doc.title) 20 | 21 | 22 | def run_doc_retrieval_with_titles( 23 | args: DocRetrievalWithTitlesArguments, retriever: Retriever, retriever_name: str 24 | ) -> None: 25 | # Actually the same pipeline as passage_only; the difference lies in the data 26 | if is_device_zero(): 27 | args.dump_arguments() 28 | dataset = DAPRDataLoader( 29 | DAPRDataConfig( 30 | data_name_or_path=args.data_dir, 31 | titled=True, 32 | retrieval_level=RetrievalLevel.document, 33 | ) 34 | ).load_data(progress_bar=is_device_zero()) 35 | labeled_queries = dataset.get_labeled_queries(args.split) 36 | queries = LabeledQuery.get_unique_queries(labeled_queries) 37 | retrieved = search( 38 | retriever=retriever, 39 | collection_iter=title_only_collection_iter(dataset.collection_iter), 40 | collection_size=dataset.collection_size, 41 | queries=queries, 42 | topk=args.topk, 43 | per_device_eval_batch_size=args.per_device_eval_batch_size, 44 | fp16=args.fp16, 45 | ) 46 | if is_device_zero(): 47 | evaluator = RetrievalEvaluator(eval_dataset=dataset, split=args.split) 48 | eval_scores = evaluator._pytrec_eval( 49 | trec_scores=RetrievedPassageIDList.build_trec_scores(retrieved), 50 | qrels=evaluator.qrels, 51 | ) 52 | report_metrics = evaluator._build_report(eval_scores) 53 | fdetails = os.path.join(args.output_dir, "q2d_details.jsonl") 54 | with open(fdetails, "w") as f: 55 | for qid, row in eval_scores.items(): 56 | row = dict(eval_scores[qid]) 57 | row["query_id"] = qid 58 | f.write(json.dumps(row) + "\n") 59 | freport = os.path.join(args.output_dir, "q2d_metrics.json") 60 | with open(freport, "w") as f: 61 | json.dump(report_metrics, f, indent=4) 62 | logging.info(f"Saved evaluation metrics to {freport}.") 63 | franked = os.path.join(args.output_dir, "doc_ranking_results.txt") 64 | RetrievedPassageIDList.dump_trec_csv( 65 | retrieval_results=retrieved, fpath=franked, system=retriever_name 66 | ) 67 | logging.info(f"Saved ranking results to {franked}.") 68 | -------------------------------------------------------------------------------- /dapr/exps/keyphrases/shared_pipeline.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import json 3 | import logging 4 | import os 5 | from typing import Dict, Iterable, Optional, TypedDict 6 | from clddp.dm import LabeledQuery, Passage, RetrievalDataset, RetrievedPassageIDList 7 | from
clddp.retriever import Retriever 8 | from clddp.utils import is_device_zero 9 | from dapr.datasets.dm import LoadedData 10 | from ...dataloader import DAPRDataConfig 11 | from dapr.exps.keyphrases.args.base import KeyphrasesArguments 12 | from clddp.search import search 13 | from dapr.dataloader import DAPRDataConfig, DAPRDataLoader 14 | from clddp.evaluation import RetrievalEvaluator 15 | import ujson 16 | 17 | 18 | class KeyphrasesRow(TypedDict): 19 | doc_id: str 20 | doc_summary: str 21 | 22 | 23 | @dataclass 24 | class KeyphrasesDAPRDataConfig(DAPRDataConfig): 25 | keyphrases_path: Optional[str] = None 26 | 27 | 28 | class KeyphrasesDAPRDataLoader(DAPRDataLoader): 29 | def __init__(self, config: KeyphrasesDAPRDataConfig) -> None: 30 | self.config = config 31 | self.did2kps: Dict[str, str] = {} 32 | self.pid2did: Dict[str, str] = {} 33 | 34 | def collection_iter_fn(self, data: LoadedData) -> Iterable[Passage]: 35 | for psg in super().collection_iter_fn(data): 36 | kps = self.did2kps[self.pid2did[psg.passage_id]] 37 | yield Passage(passage_id=psg.passage_id, text=psg.text, title=kps) 38 | 39 | def load_data(self, progress_bar: bool) -> RetrievalDataset: 40 | logging.info("Loading keyphrases") 41 | with open(self.config.keyphrases_path) as f: 42 | for line in f: 43 | line_dict: KeyphrasesRow = ujson.loads(line) 44 | self.did2kps[line_dict["doc_id"]] = line_dict["doc_summary"] 45 | 46 | self.pid2did = self.get_pid2did(progress_bar) 47 | return super().load_data(progress_bar) 48 | 49 | 50 | def run_keyphrases( 51 | args: KeyphrasesArguments, retriever: Retriever, retriever_name: str 52 | ) -> None: 53 | # Actually the same pipeline as passage_only; the difference lies in the data 54 | if is_device_zero(): 55 | args.dump_arguments() 56 | dataset = KeyphrasesDAPRDataLoader( 57 | KeyphrasesDAPRDataConfig( 58 | data_name_or_path=args.data_dir, 59 | keyphrases_path=args.keyphrases_path, 60 | titled=True, # The title field actually carries the keyphrases here 61 | ) 62 | ).load_data(progress_bar=is_device_zero()) 63 | labeled_queries = dataset.get_labeled_queries(args.split) 64 | queries = LabeledQuery.get_unique_queries(labeled_queries) 65 | retrieved = search( 66 | retriever=retriever, 67 | collection_iter=dataset.collection_iter, 68 | collection_size=dataset.collection_size, 69 | queries=queries, 70 | topk=args.topk, 71 | per_device_eval_batch_size=args.per_device_eval_batch_size, 72 | fp16=args.fp16, 73 | ) 74 | if is_device_zero(): 75 | evaluator = RetrievalEvaluator(eval_dataset=dataset, split=args.split) 76 | report = evaluator(retrieved) 77 | freport = os.path.join(args.output_dir, "metrics.json") 78 | with open(freport, "w") as f: 79 | json.dump(report, f, indent=4) 80 | logging.info(f"Saved evaluation metrics to {freport}.") 81 | franked = os.path.join(args.output_dir, "ranking_results.txt") 82 | RetrievedPassageIDList.dump_trec_csv( 83 | retrieval_results=retrieved, fpath=franked, system=retriever_name 84 | ) 85 | logging.info(f"Saved ranking results to {franked}.") 86 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: dapr 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2023.12.12=h06a4308_0 8 | - dbus=1.13.18=hb2f20db_0 9 | - expat=2.5.0=h6a678d5_0 10 | - freetype=2.12.1=h4a9f257_0 11 | - glib=2.69.1=he621ea3_2 12 | - ld_impl_linux-64=2.38=h1181459_1 13 | - libffi=3.4.4=h6a678d5_0 14 | -
libgcc-ng=11.2.0=h1234567_1 15 | - libgomp=11.2.0=h1234567_1 16 | - libpng=1.6.39=h5eee18b_0 17 | - libstdcxx-ng=11.2.0=h1234567_1 18 | - libxcb=1.15=h7f8727e_0 19 | - ncurses=6.4=h6a678d5_0 20 | - openjdk=11.0.13=h87a67e3_0 21 | - openssl=3.0.12=h7f8727e_0 22 | - pcre=8.45=h295c915_0 23 | - pip=23.3.1=py38h06a4308_0 24 | - python=3.8.18=h955ad1f_0 25 | - readline=8.2=h5eee18b_0 26 | - setuptools=68.2.2=py38h06a4308_0 27 | - sqlite=3.41.2=h5eee18b_0 28 | - tk=8.6.12=h1ccaba5_0 29 | - wheel=0.41.2=py38h06a4308_0 30 | - xz=5.4.5=h5eee18b_0 31 | - zlib=1.2.13=h5eee18b_0 32 | - pip: 33 | - accelerate==0.26.1 34 | - aiohttp==3.9.3 35 | - aiosignal==1.3.1 36 | - annotated-types==0.6.0 37 | - appdirs==1.4.4 38 | - async-timeout==4.0.3 39 | - attrs==23.2.0 40 | - beautifulsoup4==4.12.3 41 | - blis==0.7.11 42 | - catalogue==2.0.10 43 | - cbor==1.0.0 44 | - certifi==2023.11.17 45 | - charset-normalizer==3.3.2 46 | - clddp==0.0.8 47 | - click==8.1.7 48 | - cloudpathlib==0.16.0 49 | - colbert==0.2.0 50 | - coloredlogs==15.0.1 51 | - confection==0.1.4 52 | - contourpy==1.1.1 53 | - cycler==0.12.1 54 | - cymem==2.0.8 55 | - cython==3.0.8 56 | - datasets==2.16.1 57 | - dill==0.3.7 58 | - docker-pycreds==0.4.0 59 | - faiss-cpu==1.7.4 60 | - filelock==3.13.1 61 | - flatbuffers==23.5.26 62 | - fonttools==4.47.2 63 | - frozenlist==1.4.1 64 | - fsspec==2023.10.0 65 | - future==0.18.3 66 | - gitdb==4.0.11 67 | - gitpython==3.1.41 68 | - huggingface-hub==0.20.3 69 | - humanfriendly==10.0 70 | - idna==3.6 71 | - ijson==3.2.3 72 | - ir-datasets==0.5.4 73 | - jinja2==3.1.3 74 | - joblib==1.3.2 75 | - kiwisolver==1.4.5 76 | - langcodes==3.3.0 77 | - lightgbm==4.3.0 78 | - lxml==5.1.0 79 | - lz4==4.3.3 80 | - markupsafe==2.1.4 81 | - matplotlib==3.7.4 82 | - more-itertools==9.1.0 83 | - mpmath==1.3.0 84 | - multidict==6.0.4 85 | - multiprocess==0.70.15 86 | - murmurhash==1.0.10 87 | - networkx==3.1 88 | - nltk==3.8.1 89 | - nmslib==2.1.1 90 | - numpy==1.24.4 91 | - nvidia-cublas-cu12==12.1.3.1 92 | - nvidia-cuda-cupti-cu12==12.1.105 93 | - nvidia-cuda-nvrtc-cu12==12.1.105 94 | - nvidia-cuda-runtime-cu12==12.1.105 95 | - nvidia-cudnn-cu12==8.9.2.26 96 | - nvidia-cufft-cu12==11.0.2.54 97 | - nvidia-curand-cu12==10.3.2.106 98 | - nvidia-cusolver-cu12==11.4.5.107 99 | - nvidia-cusparse-cu12==12.1.0.106 100 | - nvidia-nccl-cu12==2.19.3 101 | - nvidia-nvjitlink-cu12==12.3.101 102 | - nvidia-nvtx-cu12==12.1.105 103 | - onnxruntime==1.16.3 104 | - packaging==23.2 105 | - pandas==2.0.3 106 | - pillow==10.2.0 107 | - pke==2.0.0 108 | - preshed==3.0.9 109 | - protobuf==4.25.2 110 | - psutil==5.9.8 111 | - pyarrow==15.0.0 112 | - pyarrow-hotfix==0.6 113 | - pyautocorpus==0.1.12 114 | - pybind11==2.6.1 115 | - pydantic==2.6.0 116 | - pydantic-core==2.16.1 117 | - pyjnius==1.6.1 118 | - pyparsing==3.1.1 119 | - pyserini==0.21.0 120 | - python-dateutil==2.8.2 121 | - pytrec-eval==0.5 122 | - pytz==2023.4 123 | - pyyaml==6.0.1 124 | - regex==2023.12.25 125 | - requests==2.31.0 126 | - safetensors==0.4.2 127 | - scikit-learn==1.3.2 128 | - scipy==1.10.1 129 | - sentence-transformers==2.3.1 130 | - sentencepiece==0.1.99 131 | - sentry-sdk==1.40.0 132 | - setproctitle==1.3.3 133 | - six==1.16.0 134 | - smart-open==6.4.0 135 | - smmap==5.0.1 136 | - soupsieve==2.5 137 | - spacy==3.7.2 138 | - spacy-legacy==3.0.12 139 | - spacy-loggers==1.0.5 140 | - srsly==2.4.8 141 | - sympy==1.12 142 | - thinc==8.2.2 143 | - threadpoolctl==3.2.0 144 | - tokenizers==0.15.1 145 | - torch==2.2.0 146 | - tqdm==4.66.1 147 | - transformers==4.37.2 148 | - 
trec-car-tools==2.6 149 | - triton==2.2.0 150 | - typer==0.9.0 151 | - typing-extensions==4.9.0 152 | - tzdata==2023.4 153 | - ujson==5.7.0 154 | - unidecode==1.3.8 155 | - unlzw3==0.2.2 156 | - urllib3==2.2.0 157 | - wandb==0.16.2 158 | - warc3-wet==0.2.3 159 | - warc3-wet-clueweb09==0.2.5 160 | - wasabi==1.1.2 161 | - weasel==0.3.4 162 | - xxhash==3.4.1 163 | - yarl==1.9.4 164 | - zlib-state==0.1.6 165 | -------------------------------------------------------------------------------- /dapr/annotators/pke.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from enum import Enum 3 | from functools import partial 4 | from typing import Dict, Iterable, List, Tuple 5 | from dapr.datasets.base import LoadedData 6 | from dapr.datasets.dm import Document 7 | from dapr.utils import Multiprocesser 8 | from pke.base import LoadFile 9 | from pke.unsupervised import TopicRank 10 | from dapr.annotators.base import BaseAnnotator 11 | 12 | MAX_LENGTH = 65536 # <= 10**6 required by spaCy; 65536 already takes around 12s 13 | 14 | 15 | class KeyphraseApproach(str, Enum): 16 | topic_rank = "topic_rank" 17 | 18 | def __call__(self, text: str, n: int) -> List[str]: 19 | model = {KeyphraseApproach.topic_rank: TopicRank}[self]() 20 | keyphrases = {KeyphraseApproach.topic_rank: KeyphraseApproach._topic_rank}[ 21 | self 22 | ](model, text, n) 23 | return keyphrases 24 | 25 | @staticmethod 26 | def _topic_rank(ka: LoadFile, text: str, n: int) -> List[str]: 27 | assert isinstance(ka, TopicRank) 28 | ka.load_document(input=text[:MAX_LENGTH], language="en") 29 | ka.candidate_selection() 30 | ka.candidate_weighting() 31 | keyphrases: List[Tuple[str, float]] = ka.get_n_best(n=n) 32 | keyphrases_sorted = sorted( 33 | keyphrases, key=lambda kp_and_score: kp_and_score[1], reverse=True 34 | ) 35 | kps = list(map(lambda kp_and_score: kp_and_score[0], keyphrases_sorted)) 36 | return kps 37 | 38 | 39 | class PKEAnnotator(BaseAnnotator): 40 | def __init__( 41 | self, 42 | top_k_words: int, 43 | keyphrase_approach: KeyphraseApproach, 44 | nprocs: int, 45 | ) -> None: 46 | self.top_k_words = top_k_words 47 | self.keyphrase_approach = keyphrase_approach 48 | self.nprocs = nprocs 49 | 50 | def _run(self, docs: Iterable[Document], ndocs: int) -> List[List[str]]: 51 | multiprocessor = Multiprocesser(self.nprocs) 52 | texts = map( 53 | lambda doc: "\n".join([chk.text for chk in doc.chunks])[:MAX_LENGTH], docs 54 | ) 55 | results: List[List[str]] = multiprocessor.run( 56 | texts, 57 | func=partial(self.keyphrase_approach, n=self.top_k_words), 58 | desc=f"Running {self.keyphrase_approach}", 59 | total=ndocs, 60 | chunk_size=500, # https://stackoverflow.com/questions/64515797/there-appear-to-be-6-leaked-semaphore-objects-to-clean-up-at-shutdown-warnings-w#comment126544553_65130215 61 | ) 62 | return results 63 | 64 | def extract(self, data: LoadedData) -> Dict[str, List[str]]: 65 | assert data.corpus_iter_fn is not None 66 | 67 | kps_corpus = self._run( 68 | docs=data.corpus_iter_fn(), ndocs=data.meta_data["ndocs"] 69 | ) 70 | doc_id2kps = { 71 | doc.doc_id: kps for doc, kps in zip(data.corpus_iter_fn(), kps_corpus) 72 | } 73 | 74 | leftover: List[Document] = [] 75 | for lqs in [ 76 | data.labeled_queries_dev, 77 | data.labeled_queries_test, 78 | data.labeled_queries_train, 79 | ]: 80 | if lqs is None: 81 | continue 82 | for lq in lqs: 83 | for jchk in lq.judged_chunks: 84 | doc = jchk.chunk.belonging_doc 85 | if doc.doc_id not in doc_id2kps: 86 |
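# Judged documents may be missing from the corpus iterator; collect them for a second keyphrase pass.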
leftover.append(doc) 87 | 88 | kps_leftover = self._run(docs=leftover, ndocs=len(leftover)) 89 | for doc, kps in zip(leftover, kps_leftover): 90 | doc_id2kps[doc.doc_id] = kps 91 | 92 | return doc_id2kps 93 | 94 | def _annotate(self, data: LoadedData) -> Dict[str, str]: 95 | assert data.corpus_iter_fn is not None 96 | extract_fn = self.extract 97 | doc_id2kps = extract_fn(data) 98 | did2dsum: Dict[str, str] = {} 99 | for doc in data.corpus_iter_fn(): 100 | doc_summary = "; ".join(doc_id2kps[doc.doc_id]) 101 | did2dsum[doc.doc_id] = doc_summary 102 | 103 | for lqs in [ 104 | data.labeled_queries_train, 105 | data.labeled_queries_dev, 106 | data.labeled_queries_test, 107 | ]: 108 | if lqs is None: 109 | continue 110 | for lq in lqs: 111 | for jchk in lq.judged_chunks: 112 | doc = jchk.chunk.belonging_doc 113 | doc_summary = "; ".join(doc_id2kps[doc.doc_id]) 114 | did2dsum[doc.doc_id] = doc_summary 115 | return did2dsum 116 | 117 | 118 | if __name__ == "__main__": 119 | import sys 120 | from dapr.datasets.dm import LoadedData 121 | from dapr.utils import set_logger_format 122 | import logging 123 | 124 | set_logger_format() 125 | data_dir = sys.argv[1] 126 | logging.info(f"Loading from {data_dir}") 127 | data = LoadedData.from_dump(data_dir) 128 | pke_summarizer = PKEAnnotator(10, KeyphraseApproach.topic_rank, 32) 129 | pke_summarizer.annotate(data=data, cache_root_dir="pke") 130 | -------------------------------------------------------------------------------- /dapr/exps/bm25_doc_passage_fusion/shared_pipeline.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from typing import Type 5 | from clddp.dm import LabeledQuery, RetrievedPassageIDList 6 | from clddp.utils import is_device_zero, parse_cli, set_logger_format 7 | from dapr.retrievers.bm25 import BM25 8 | from dapr.exps.bm25_doc_passage_fusion.args.base import ( 9 | BM25DocPassageFusionArguments, 10 | ) 11 | from dapr.dataloader import DAPRDataConfig, DAPRDataLoader, RetrievalLevel 12 | from clddp.evaluation import RetrievalEvaluator 13 | from dapr.fusion import doc_passage_fusion_with_M2C2 14 | import numpy as np 15 | import tqdm 16 | from matplotlib import pyplot as plt 17 | 18 | 19 | def run_bm25_doc_passage_fusion( 20 | arguments_class: Type[BM25DocPassageFusionArguments], passage_retriever_name: str 21 | ) -> None: 22 | set_logger_format(logging.INFO if is_device_zero() else logging.WARNING) 23 | args = parse_cli(arguments_class) 24 | args.dump_arguments() 25 | 26 | # Doing BM25 document retrieval: 27 | fdocs_ranking = os.path.join(args.output_dir, "doc_ranking_results.txt") 28 | doc_dataset = DAPRDataLoader( 29 | DAPRDataConfig( 30 | data_name_or_path=args.data_dir, retrieval_level=RetrievalLevel.document 31 | ) 32 | ).load_data(True) 33 | labeled_queries = doc_dataset.get_labeled_queries(args.split) 34 | queries = LabeledQuery.get_unique_queries(labeled_queries) 35 | retriever = BM25() 36 | index_path = os.path.join(args.output_dir, "index") 37 | if not (os.path.exists(index_path) and len(os.listdir(index_path))): 38 | retriever.index( 39 | collection_iter=doc_dataset.collection_iter, 40 | collection_size=doc_dataset.collection_size, 41 | output_dir=index_path, 42 | ) 43 | else: 44 | logging.info(f"Found existing index {index_path}") 45 | retrieved_docs = retriever.search( 46 | queries=queries, 47 | index_path=index_path, 48 | topk=args.topk, 49 | batch_size=args.per_device_eval_batch_size, 50 | ) 51 | RetrievedPassageIDList.dump_trec_csv( 52 
| retrieval_results=retrieved_docs, fpath=fdocs_ranking, system="bm25" 53 | ) 54 | logging.info(f"Saved BM25 document ranking results to {fdocs_ranking}.") 55 | 56 | # Loading passage ranking results and doing fusion: 57 | retrieved_passages = RetrievedPassageIDList.from_trec_csv(args.passage_results) 58 | paragraph_dataloader = DAPRDataLoader( 59 | DAPRDataConfig( 60 | data_name_or_path=args.data_dir, retrieval_level=RetrievalLevel.paragraph 61 | ) 62 | ) 63 | pid2did = paragraph_dataloader.get_pid2did(is_device_zero()) 64 | passage_weights = [round(weight, 1) for weight in np.arange(0, 1.1, 0.1).tolist()] 65 | paragraph_dataset = paragraph_dataloader.load_data(is_device_zero()) 66 | pweight2metrics = {} 67 | for passage_weight in tqdm.tqdm(passage_weights, desc="Doing fusion"): 68 | fused = doc_passage_fusion_with_M2C2( 69 | doc_results=retrieved_docs, 70 | passage_results=retrieved_passages, 71 | pid2did=pid2did, 72 | passage_weight=passage_weight, 73 | ) 74 | evaluator = RetrievalEvaluator(eval_dataset=paragraph_dataset, split=args.split) 75 | report = evaluator(fused) 76 | pweight2metrics[passage_weight] = report 77 | freport = os.path.join( 78 | args.output_dir, f"metrics-pweight_{passage_weight}.json" 79 | ) 80 | with open(freport, "w") as f: 81 | json.dump(report, f, indent=4) 82 | logging.info(f"Saved evaluation metrics to {freport}.") 83 | franked = os.path.join( 84 | args.output_dir, f"ranking_results-pweight_{passage_weight}.txt" 85 | ) 86 | RetrievedPassageIDList.dump_trec_csv( 87 | retrieval_results=fused, fpath=franked, system=passage_retriever_name 88 | ) 89 | logging.info(f"Saved ranking results to {franked}.") 90 | 91 | # Save again the metrics to be reported: 92 | if args.report_passage_weight: 93 | report_metrics = pweight2metrics[args.report_passage_weight] 94 | freport = os.path.join(args.output_dir, "metrics.json") 95 | with open(freport, "w") as f: 96 | json.dump(report_metrics, f, indent=4) 97 | logging.info(f"Saved evaluation metrics to {freport}.") 98 | 99 | # Plot the curve: 100 | pweight2main_metric = { 101 | pweight: pweight2metrics[pweight][args.report_metric] 102 | for pweight in passage_weights 103 | } 104 | fcurve_data = os.path.join( 105 | args.output_dir, f"passage_weight-vs-{args.report_metric}.json" 106 | ) 107 | with open(fcurve_data, "w") as f: 108 | json.dump(pweight2main_metric, f, indent=4) 109 | logging.info(f"Saved the curve data to {fcurve_data}") 110 | main_metrics = [pweight2main_metric[pweight] for pweight in passage_weights] 111 | fcurve = os.path.join( 112 | args.output_dir, f"passage_weight-vs-{args.report_metric}.pdf" 113 | ) 114 | plt.plot(passage_weights, main_metrics) 115 | plt.grid(linestyle="dashed") 116 | plt.xlabel("passage weight") 117 | plt.ylabel(f"{args.report_metric}") 118 | plt.savefig(fcurve, bbox_inches="tight") 119 | -------------------------------------------------------------------------------- /dapr/retrievers/bm25.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from dataclasses import dataclass 3 | import json 4 | import os 5 | import shutil 6 | from typing import Callable, Dict, Iterable, List, Type 7 | import tqdm 8 | from clddp.dm import Passage, Query, RetrievedPassageIDList, ScoredPassageID 9 | import logging 10 | 11 | 12 | @dataclass 13 | class PyseriniHit: 14 | docid: str 15 | score: float 16 | 17 | @classmethod 18 | def from_pyserini(cls: Type[PyseriniHit], hit: PyseriniHit) -> PyseriniHit: 19 | """This conversion step makes the
memory releasable (otherwise leaked by Java).""" 20 | return PyseriniHit(docid=hit.docid, score=hit.score) 21 | 22 | 23 | @dataclass 24 | class PyseriniCollectionRow: 25 | id: str 26 | title: str 27 | contents: str 28 | 29 | def to_json(self) -> Dict[str, str]: 30 | return {"id": self.id, "title": self.title, "contents": self.contents} 31 | 32 | 33 | class BM25: 34 | CORPUS_FOLDER = "corpus" 35 | 36 | def __init__(self) -> None: 37 | os.environ["_JAVA_OPTIONS"] = ( 38 | "-Xmx5g" # Otherwise it would cause a huge memory leak! 39 | ) 40 | 41 | def index( 42 | self, 43 | collection_iter: Iterable[Passage], 44 | collection_size: int, 45 | output_dir: str, 46 | nthreads: int = 12, 47 | keep_converted_corpus: bool = False, 48 | ) -> None: 49 | logging.info("Run indexing.") 50 | import pyserini.index.lucene 51 | from jnius import autoclass 52 | 53 | # Converting into the required format for indexing: 54 | corpus_path = os.path.join(output_dir, self.CORPUS_FOLDER) 55 | os.makedirs(corpus_path, exist_ok=True) 56 | with open(os.path.join(corpus_path, "texts.jsonl"), "w") as f: 57 | for psg in tqdm.tqdm( 58 | collection_iter, 59 | total=collection_size, 60 | desc="Converting to pyserini format", 61 | ): 62 | title = psg.title if psg.title else "" 63 | json_line = PyseriniCollectionRow( 64 | id=psg.passage_id, title=title, contents=psg.text 65 | ).to_json() 66 | f.write(json.dumps(json_line) + "\n") 67 | 68 | # Run actual indexing: 69 | args = [ 70 | "-collection", 71 | "JsonCollection", 72 | "-generator", 73 | "DefaultLuceneDocumentGenerator", 74 | "-threads", 75 | str(nthreads), 76 | "-input", 77 | corpus_path, 78 | "-index", 79 | output_dir, 80 | # "-storeRaw", 81 | "-storePositions", 82 | "-storeDocvectors", 83 | "-fields", 84 | "title", 85 | ] 86 | JIndexCollection = autoclass("io.anserini.index.IndexCollection") 87 | index_fn: Callable[ 88 | [ 89 | List[str], 90 | ], 91 | None, 92 | ] = getattr(JIndexCollection, "main") 93 | try: 94 | index_fn(args) 95 | except Exception as e: 96 | shutil.rmtree(output_dir) 97 | raise e 98 | logging.info(f"Done indexing.
Index path: {output_dir}") 99 | if not keep_converted_corpus: 100 | logging.info(f"Removing the converted corpus {corpus_path}") 101 | shutil.rmtree(corpus_path) 102 | 103 | def search( 104 | self, 105 | queries: List[Query], 106 | index_path: str, 107 | topk: int, 108 | batch_size: int, 109 | contents_weight: float = 1.0, 110 | title_weight: float = 1.0, 111 | nthreads: int = 12, 112 | ) -> List[RetrievedPassageIDList]: 113 | from pyserini.search import SimpleSearcher 114 | 115 | # Actual search: 116 | fields = { 117 | "contents": contents_weight, 118 | "title": title_weight, 119 | } 120 | searcher = SimpleSearcher(index_path) 121 | searcher.set_bm25() 122 | qid2hits_all: Dict[str, List[PyseriniHit]] = {} 123 | for b in tqdm.trange(0, len(queries), batch_size, desc="Query batch"): 124 | e = b + batch_size 125 | qid2hits = searcher.batch_search( 126 | queries=list(map(lambda query: query.text, queries[b:e])), 127 | qids=list(map(lambda query: query.query_id, queries[b:e])), 128 | k=topk, 129 | threads=nthreads, 130 | fields=fields, 131 | ) 132 | qid2hits_converted = { 133 | qid: list(map(PyseriniHit.from_pyserini, hits)) 134 | for qid, hits in qid2hits.items() 135 | } 136 | qid2hits_all.update(qid2hits_converted) 137 | searcher.close() 138 | 139 | # Convert to the clddp format: 140 | retrieved = [] 141 | for query in queries: 142 | hits = qid2hits_all[query.query_id] 143 | spids = [ 144 | ScoredPassageID(passage_id=hit.docid, score=hit.score) for hit in hits 145 | ] 146 | retrieved.append( 147 | RetrievedPassageIDList( 148 | query_id=query.query_id, scored_passage_ids=spids 149 | ) 150 | ) 151 | return retrieved 152 | -------------------------------------------------------------------------------- /dapr/exps/jinav2_doc_passage_fusion/shared_pipeline.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from typing import Type 5 | from clddp.dm import LabeledQuery, RetrievedPassageIDList 6 | from clddp.utils import initialize_ddp, is_device_zero, parse_cli, set_logger_format 7 | from dapr.retrievers.dense import JinaV2 8 | from dapr.exps.jinav2_doc_passage_fusion.args.base import ( 9 | JinaV2DocPassageFusionArguments, 10 | ) 11 | from dapr.dataloader import DAPRDataConfig, DAPRDataLoader, RetrievalLevel 12 | from clddp.evaluation import RetrievalEvaluator 13 | from clddp.search import search 14 | from dapr.fusion import doc_passage_fusion_with_M2C2 15 | import numpy as np 16 | import tqdm 17 | from matplotlib import pyplot as plt 18 | 19 | 20 | def run_jina_doc_passage_fusion( 21 | arguments_class: Type[JinaV2DocPassageFusionArguments], passage_retriever_name: str 22 | ) -> None: 23 | set_logger_format(logging.INFO if is_device_zero() else logging.WARNING) 24 | initialize_ddp() 25 | args = parse_cli(arguments_class) 26 | if is_device_zero(): 27 | args.dump_arguments() 28 | 29 | # Doing JinaV2 document retrieval: 30 | fdocs_ranking = os.path.join(args.output_dir, "doc_ranking_results.txt") 31 | doc_dataset = DAPRDataLoader( 32 | DAPRDataConfig( 33 | data_name_or_path=args.data_dir, retrieval_level=RetrievalLevel.document 34 | ) 35 | ).load_data(is_device_zero()) 36 | labeled_queries = doc_dataset.get_labeled_queries(args.split) 37 | queries = LabeledQuery.get_unique_queries(labeled_queries) 38 | retriever = JinaV2() 39 | retrieved_docs = search( 40 | retriever=retriever, 41 | collection_iter=doc_dataset.collection_iter, 42 | collection_size=doc_dataset.collection_size, 43 | queries=queries, 44 |
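# Document-level retrieval: JinaV2 defaults to an 8192-token max length (see dapr/retrievers/dense.py), intended to cover whole documents.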
topk=args.topk, 45 | per_device_eval_batch_size=args.per_device_doc_batch_size, 46 | fp16=args.fp16, 47 | ) 48 | if is_device_zero(): 49 | RetrievedPassageIDList.dump_trec_csv( 50 | retrieval_results=retrieved_docs, fpath=fdocs_ranking, system="jinav2" 51 | ) 52 | logging.info(f"Saved JinaV2 document ranking results to {fdocs_ranking}.") 53 | 54 | # Loading passage ranking results and doing fusion: 55 | if is_device_zero(): 56 | retrieved_passages = RetrievedPassageIDList.from_trec_csv(args.passage_results) 57 | paragraph_dataloader = DAPRDataLoader( 58 | DAPRDataConfig( 59 | data_name_or_path=args.data_dir, 60 | retrieval_level=RetrievalLevel.paragraph, 61 | ) 62 | ) 63 | pid2did = paragraph_dataloader.get_pid2did(is_device_zero()) 64 | passage_weights = [ 65 | round(weight, 1) for weight in np.arange(0, 1.1, 0.1).tolist() 66 | ] 67 | paragraph_dataset = paragraph_dataloader.load_data(is_device_zero()) 68 | pweight2metrics = {} 69 | for passage_weight in tqdm.tqdm(passage_weights, desc="Doing fusion"): 70 | fused = doc_passage_fusion_with_M2C2( 71 | doc_results=retrieved_docs, 72 | passage_results=retrieved_passages, 73 | pid2did=pid2did, 74 | passage_weight=passage_weight, 75 | ) 76 | evaluator = RetrievalEvaluator( 77 | eval_dataset=paragraph_dataset, split=args.split 78 | ) 79 | report = evaluator(fused) 80 | pweight2metrics[passage_weight] = report 81 | freport = os.path.join( 82 | args.output_dir, f"metrics-pweight_{passage_weight}.json" 83 | ) 84 | with open(freport, "w") as f: 85 | json.dump(report, f, indent=4) 86 | logging.info(f"Saved evaluation metrics to {freport}.") 87 | franked = os.path.join( 88 | args.output_dir, f"ranking_results-pweight_{passage_weight}.txt" 89 | ) 90 | RetrievedPassageIDList.dump_trec_csv( 91 | retrieval_results=fused, fpath=franked, system=passage_retriever_name 92 | ) 93 | logging.info(f"Saved ranking results to {franked}.") 94 | 95 | # Save again the metrics to be reported: 96 | if args.report_passage_weight: 97 | report_metrics = pweight2metrics[args.report_passage_weight] 98 | freport = os.path.join(args.output_dir, "metrics.json") 99 | with open(freport, "w") as f: 100 | json.dump(report_metrics, f, indent=4) 101 | logging.info(f"Saved evaluation metrics to {freport}.") 102 | 103 | # Plot the curve: 104 | pweight2main_metric = { 105 | pweight: pweight2metrics[pweight][args.report_metric] 106 | for pweight in passage_weights 107 | } 108 | fcurve_data = os.path.join( 109 | args.output_dir, f"passage_weight-vs-{args.report_metric}.json" 110 | ) 111 | with open(fcurve_data, "w") as f: 112 | json.dump(pweight2main_metric, f, indent=4) 113 | logging.info(f"Saved the curve data to {fcurve_data}") 114 | main_metrics = [pweight2main_metric[pweight] for pweight in passage_weights] 115 | fcurve = os.path.join( 116 | args.output_dir, f"passage_weight-vs-{args.report_metric}.pdf" 117 | ) 118 | plt.plot(passage_weights, main_metrics) 119 | plt.grid(linestyle="dashed") 120 | plt.xlabel("passage weight") 121 | plt.ylabel(f"{args.report_metric}") 122 | plt.savefig(fcurve, bbox_inches="tight") 123 | -------------------------------------------------------------------------------- /dapr/exps/bm25_doc_passage_hierarchy/shared_pipeline.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from typing import Iterable, List, Protocol, Set 5 | from clddp.dm import LabeledQuery, Passage, Query, RetrievedPassageIDList 6 | from clddp.utils import is_device_zero 7 | from dapr.retrievers.bm25
import BM25 8 | from dapr.exps.bm25_doc_passage_hierarchy.args.base import ( 9 | BM25DocPassageHierarchyArguments, 10 | ) 11 | from dapr.dataloader import DAPRDataConfig, DAPRDataLoader, RetrievalLevel 12 | from clddp.evaluation import RetrievalEvaluator 13 | from dapr.fusion import doc_passage_fusion_with_M2C2 14 | import numpy as np 15 | import tqdm 16 | from matplotlib import pyplot as plt 17 | from torch import distributed as dist 18 | 19 | 20 | class ScopedSearchFunction(Protocol): 21 | def __call__( 22 | self, 23 | collection_iter: Iterable[Passage], 24 | collection_size: int, 25 | queries: List[Query], 26 | topk: int, 27 | passage_scopes: List[Set[str]], # For each query, which pids are allowed 28 | ) -> List[RetrievedPassageIDList]: 29 | pass 30 | 31 | 32 | def run_bm25_doc_passage_hierarchy( 33 | args: BM25DocPassageHierarchyArguments, 34 | passage_retriever_name: str, 35 | scoped_search_function: ScopedSearchFunction, 36 | ) -> None: 37 | # Doing BM25 document retrieval: 38 | fdocs_ranking = os.path.join(args.output_dir, "doc_ranking_results.txt") 39 | if is_device_zero(): 40 | doc_dataset = DAPRDataLoader( 41 | DAPRDataConfig( 42 | data_name_or_path=args.data_dir, retrieval_level=RetrievalLevel.document 43 | ) 44 | ).load_data(True) 45 | labeled_queries = doc_dataset.get_labeled_queries(args.split) 46 | queries = LabeledQuery.get_unique_queries(labeled_queries) 47 | retriever = BM25() 48 | index_path = os.path.join(args.output_dir, "index") 49 | if not (os.path.exists(index_path) and len(os.listdir(index_path))): 50 | retriever.index( 51 | collection_iter=doc_dataset.collection_iter, 52 | collection_size=doc_dataset.collection_size, 53 | output_dir=index_path, 54 | ) 55 | else: 56 | logging.info(f"Found existing index {index_path}") 57 | retrieved_docs = retriever.search( 58 | queries=queries, 59 | index_path=index_path, 60 | topk=args.topk, 61 | batch_size=args.per_device_eval_batch_size, 62 | ) 63 | RetrievedPassageIDList.dump_trec_csv( 64 | retrieval_results=retrieved_docs, fpath=fdocs_ranking, system="bm25" 65 | ) 66 | logging.info(f"Saved BM25 document ranking results to {fdocs_ranking}.") 67 | if dist.is_initialized(): 68 | dist.barrier() 69 | retrieved_docs = RetrievedPassageIDList.from_trec_csv(fdocs_ranking) 70 | 71 | # Search over the paragraphs which belong to the retrieved docs: 72 | paragraph_dataloader = DAPRDataLoader( 73 | DAPRDataConfig( 74 | data_name_or_path=args.data_dir, retrieval_level=RetrievalLevel.paragraph 75 | ) 76 | ) 77 | paragraph_dataset = paragraph_dataloader.load_data(is_device_zero()) 78 | pid2did = paragraph_dataloader.get_pid2did(is_device_zero()) 79 | did2pids = paragraph_dataloader.get_did2pids(is_device_zero()) 80 | labeled_queries = paragraph_dataset.get_labeled_queries(args.split) 81 | queries = LabeledQuery.get_unique_queries(labeled_queries) 82 | passage_scopes = [ 83 | {pid for sdoc in qdoc.scored_passage_ids for pid in did2pids[sdoc.passage_id]} 84 | for qdoc in retrieved_docs 85 | ] 86 | retrieved_passages = scoped_search_function( 87 | collection_iter=paragraph_dataset.collection_iter, 88 | collection_size=paragraph_dataset.collection_size, 89 | queries=queries, 90 | topk=args.topk, 91 | passage_scopes=passage_scopes, 92 | ) 93 | 94 | # Doing fusion: 95 | if is_device_zero(): 96 | passage_weights = [ 97 | round(weight, 1) for weight in np.arange(0, 1.1, 0.1).tolist() 98 | ] 99 | pweight2metrics = {} 100 | for passage_weight in tqdm.tqdm(passage_weights, desc="Doing fusion"): 101 | fused = doc_passage_fusion_with_M2C2( 102 |
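# M2C2 fusion (see dapr/fusion.py): both score sets are min-max normalized and then
# combined as (1 - w) * doc_score + w * passage_score with w = passage_weight; e.g.
# normalized scores 0.8 (doc) and 0.4 (passage) fuse to 0.6 at w = 0.5.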
doc_results=retrieved_docs, 103 | passage_results=retrieved_passages, 104 | pid2did=pid2did, 105 | passage_weight=passage_weight, 106 | ) 107 | evaluator = RetrievalEvaluator( 108 | eval_dataset=paragraph_dataset, split=args.split 109 | ) 110 | report = evaluator(fused) 111 | pweight2metrics[passage_weight] = report 112 | freport = os.path.join( 113 | args.output_dir, f"metrics-pweight_{passage_weight}.json" 114 | ) 115 | with open(freport, "w") as f: 116 | json.dump(report, f, indent=4) 117 | logging.info(f"Saved evaluation metrics to {freport}.") 118 | franked = os.path.join( 119 | args.output_dir, f"ranking_results-pweight_{passage_weight}.txt" 120 | ) 121 | RetrievedPassageIDList.dump_trec_csv( 122 | retrieval_results=fused, fpath=franked, system=passage_retriever_name 123 | ) 124 | logging.info(f"Saved ranking results to {franked}.") 125 | 126 | # Save the metrics for the passage weight to be reported: 127 | if args.report_passage_weight is not None: 128 | report_metrics = pweight2metrics[args.report_passage_weight] 129 | freport = os.path.join(args.output_dir, "metrics.json") 130 | with open(freport, "w") as f: 131 | json.dump(report_metrics, f, indent=4) 132 | logging.info(f"Saved evaluation metrics to {freport}.") 133 | 134 | # Plot the curve: 135 | pweight2main_metric = { 136 | pweight: pweight2metrics[pweight][args.report_metric] 137 | for pweight in passage_weights 138 | } 139 | fcurve_data = os.path.join( 140 | args.output_dir, f"passage_weight-vs-{args.report_metric}.json" 141 | ) 142 | with open(fcurve_data, "w") as f: 143 | json.dump(pweight2main_metric, f, indent=4) 144 | logging.info(f"Saved the curve data to {fcurve_data}.") 145 | main_metrics = [pweight2main_metric[pweight] for pweight in passage_weights] 146 | fcurve = os.path.join( 147 | args.output_dir, f"passage_weight-vs-{args.report_metric}.pdf" 148 | ) 149 | plt.plot(passage_weights, main_metrics) 150 | plt.grid(linestyle="dashed") 151 | plt.xlabel("passage weight") 152 | plt.ylabel(args.report_metric) 153 | plt.savefig(fcurve, bbox_inches="tight") 154 | -------------------------------------------------------------------------------- /dapr/dataloader.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from functools import partial 3 | from typing import Dict, Iterable, List, Optional 4 | from clddp.dm import RetrievalDataset, Passage, JudgedPassage, Query, LabeledQuery 5 | from dapr.datasets.dm import LoadedData 6 | from dapr.datasets.dm import LabeledQuery as DAPRLabeledQuery 7 | from dapr.dm import RetrievalLevel, ParagraphSeparator 8 | import tqdm 9 | 10 | 11 | @dataclass 12 | class DAPRDataConfig: 13 | data_name_or_path: str 14 | titled: bool = False 15 | retrieval_level: RetrievalLevel = RetrievalLevel.paragraph 16 | paragraph_separator: ParagraphSeparator = ParagraphSeparator.blank 17 | 18 | 19 | class DAPRDataLoader: 20 | def __init__(self, config: DAPRDataConfig) -> None: 21 | self.config = config 22 | 23 | def collection_iter_fn( 24 | self, 25 | data: LoadedData, 26 | ) -> Iterable[Passage]: 27 | titled = self.config.titled 28 | retrieval_level = self.config.retrieval_level 29 | paragraph_separator = self.config.paragraph_separator 30 | for doc in data.corpus_iter_fn(): 31 | doc.set_default_candidates() 32 | title = doc.title if titled else None 33 | if retrieval_level is RetrievalLevel.document: 34 | yield Passage( 35 | passage_id=doc.doc_id, 36 | text=paragraph_separator.string.join( 37 | chk.text for chk in doc.chunks 38 | ), 39 | title=title, 40 | )
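# Paragraph level: yield only the candidate chunks (doc.candidate_chunk_ids) as separate passages: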
41 | else: 42 | for chunk in doc.chunks: 43 | if chunk.chunk_id in doc.candidate_chunk_ids: 44 | yield Passage( 45 | passage_id=chunk.chunk_id, 46 | text=chunk.text, 47 | title=chunk.belonging_doc.title if titled else None, 48 | ) 49 | 50 | def build_labeled_queries( 51 | self, 52 | labeled_queries: Optional[List[DAPRLabeledQuery]], 53 | ) -> Optional[List[LabeledQuery]]: 54 | titled = self.config.titled 55 | retrieval_level = self.config.retrieval_level 56 | paragraph_separator = self.config.paragraph_separator 57 | if labeled_queries is None: 58 | return None 59 | lqs = [] 60 | for lq in labeled_queries: 61 | pid2positive: Dict[str, JudgedPassage] = {} 62 | pid2negative: Dict[str, JudgedPassage] = {} 63 | for jchk in lq.judged_chunks: 64 | query = Query(query_id=jchk.query.query_id, text=jchk.query.text) 65 | chunk = jchk.chunk 66 | doc = chunk.belonging_doc 67 | title = doc.title if titled else None 68 | if retrieval_level is RetrievalLevel.document: 69 | pid = doc.doc_id 70 | passage = Passage( 71 | passage_id=pid, 72 | text=paragraph_separator.string.join( 73 | chk.text for chk in doc.chunks 74 | ), 75 | title=title, 76 | ) 77 | else: 78 | pid = chunk.chunk_id 79 | passage = Passage( 80 | passage_id=pid, 81 | text=chunk.text, 82 | title=chunk.belonging_doc.title if titled else None, 83 | ) 84 | jpsg = JudgedPassage( 85 | query=query, passage=passage, judgement=jchk.judgement 86 | ) 87 | if jchk.judgement: 88 | # For document-level annotation, keep the highest judgement on paragraphs: 89 | if ( 90 | pid not in pid2positive 91 | or pid2positive[pid].judgement < jpsg.judgement 92 | ): 93 | pid2positive[pid] = jpsg 94 | else: 95 | pid2negative[pid] = jpsg 96 | lqs.append( 97 | LabeledQuery( 98 | query=query, 99 | positives=list(pid2positive.values()), 100 | negatives=list(pid2negative.values()), 101 | ) 102 | ) 103 | return lqs 104 | 105 | def load_data(self, progress_bar: bool) -> RetrievalDataset: 106 | data = LoadedData.from_dump(self.config.data_name_or_path, pbar=progress_bar) 107 | assert data.meta_data is not None 108 | retrieval_level = self.config.retrieval_level 109 | if retrieval_level is RetrievalLevel.document: 110 | collection_size = data.meta_data["ndocs"] 111 | else: 112 | collection_size = ( 113 | data.meta_data["nchunks"] 114 | if data.meta_data.get("nchunks_candidates") is None 115 | else data.meta_data["nchunks_candidates"] 116 | ) 117 | dataset = RetrievalDataset( 118 | collection_iter_fn=partial(self.collection_iter_fn, data=data), 119 | collection_size=collection_size, 120 | train_labeled_queries=self.build_labeled_queries( 121 | labeled_queries=data.labeled_queries_train 122 | ), 123 | dev_labeled_queries=self.build_labeled_queries( 124 | labeled_queries=data.labeled_queries_dev 125 | ), 126 | test_labeled_queries=self.build_labeled_queries( 127 | labeled_queries=data.labeled_queries_test 128 | ), 129 | ) 130 | return dataset 131 | 132 | def get_pid2did(self, progress_bar: bool) -> Dict[str, str]: 133 | data = LoadedData.from_dump(self.config.data_name_or_path, pbar=progress_bar) 134 | pid2did = {} 135 | assert data.corpus_iter_fn is not None 136 | for doc in tqdm.tqdm( 137 | data.corpus_iter_fn(), 138 | total=data.meta_data["ndocs"], 139 | desc="Building pid2did", 140 | disable=not progress_bar, 141 | ): 142 | for chunk in doc.chunks: 143 | pid2did[chunk.chunk_id] = doc.doc_id 144 | return pid2did 145 | 146 | def get_did2pids(self, progress_bar: bool) -> Dict[str, List[str]]: 147 | data = LoadedData.from_dump(self.config.data_name_or_path, pbar=progress_bar) 148 | did2pids:
Dict[str, List[str]] = {} 149 | assert data.corpus_iter_fn is not None 150 | for doc in tqdm.tqdm( 151 | data.corpus_iter_fn(), 152 | total=data.meta_data["ndocs"], 153 | desc="Building did2pids", 154 | disable=not progress_bar, 155 | ): 156 | for chunk in doc.chunks: 157 | did2pids.setdefault(doc.doc_id, []) 158 | did2pids[doc.doc_id].append(chunk.chunk_id) 159 | return did2pids 160 | -------------------------------------------------------------------------------- /dapr/datasets/tagged_conditionalqa.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from collections import defaultdict 3 | from dataclasses import dataclass 4 | import os 5 | from typing import Dict, List, Optional, Union 6 | from dapr.datasets.base import BaseDataset 7 | from dapr.datasets.dm import ( 8 | Chunk, 9 | Document, 10 | JudgedChunk, 11 | LabeledQuery, 12 | LoadedData, 13 | Query, 14 | ) 15 | from dapr.utils import randomly_split_by_number, set_logger_format 16 | import datasets 17 | import ujson 18 | 19 | 20 | @dataclass 21 | class DocumentRecord: 22 | doc_id: str 23 | url: str 24 | title: str 25 | contents: List[str] 26 | 27 | @staticmethod 28 | def build_url2drecord(drecords: List[DocumentRecord]) -> Dict[str, DocumentRecord]: 29 | return {drecord.url: drecord for drecord in drecords} 30 | 31 | 32 | @dataclass 33 | class JudgedDocumentRecord: 34 | document_record: DocumentRecord 35 | qid: str 36 | query: str 37 | evidence_positions: List[int] 38 | 39 | 40 | class TaggedConditionalQA(BaseDataset): 41 | """The original version of ConditionalQA, which contains HTML tags.""" 42 | 43 | fdocuments_v1: Optional[str] = None 44 | ftrain_v1: Optional[str] = None 45 | fdev_v1: Optional[str] = None 46 | 47 | def _download(self, resource_path: str) -> None: 48 | """`resource_path` is something like https://raw.githubusercontent.com/haitian-sun/ConditionalQA/master/v1_0""" 49 | if os.path.exists(resource_path): 50 | self.fdocuments_v1 = os.path.join(resource_path, "documents.json") 51 | self.ftrain_v1 = os.path.join(resource_path, "train.json") 52 | self.fdev_v1 = os.path.join(resource_path, "dev.json") 53 | else: 54 | dm = datasets.DownloadManager() 55 | self.fdocuments_v1 = dm.download( 56 | os.path.join(resource_path, "documents.json") 57 | ) 58 | self.ftrain_v1 = dm.download(os.path.join(resource_path, "train.json")) 59 | self.fdev_v1 = dm.download(os.path.join(resource_path, "dev.json")) 60 | 61 | def _load_drecords(self) -> List[DocumentRecord]: 62 | with open(self.fdocuments_v1) as f: 63 | data = ujson.load(f) 64 | drecords = [] 65 | for i, doc in enumerate(data): 66 | drecords.append( 67 | DocumentRecord( 68 | doc_id=str(i), 69 | url=doc["url"], 70 | title=doc["title"], 71 | contents=doc["contents"], 72 | ) 73 | ) 74 | return drecords 75 | 76 | def _load_jrecords( 77 | self, fpath: str, url2drecord: Dict[str, DocumentRecord] 78 | ) -> List[JudgedDocumentRecord]: 79 | with open(fpath) as f: 80 | data = ujson.load(f) 81 | jrecords = [] 82 | for example in data: 83 | if len(example["evidences"]) == 0: 84 | continue 85 | 86 | drecord = url2drecord[example["url"]] 87 | query = " ".join([example["scenario"], example["question"]]) 88 | jrecord = JudgedDocumentRecord( 89 | document_record=drecord, 90 | qid=example["id"], 91 | query=query, 92 | evidence_positions=[ 93 | drecord.contents.index(evidence) 94 | for evidence in example["evidences"] 95 | ], 96 | ) 97 | jrecords.append(jrecord) 98 | 99 | return jrecords 100 | 101 | def _build_chunks( 102 |
self, record: Union[JudgedDocumentRecord, DocumentRecord] 103 | ) -> Union[List[JudgedChunk], List[Chunk]]: 104 | query: Optional[Query] = None 105 | drecord: Optional[DocumentRecord] = None 106 | if type(record) is JudgedDocumentRecord: 107 | query = Query(query_id=record.qid, text=record.query) 108 | texts = [evidence for evidence in record.document_record.contents] 109 | marked = [ 110 | pos in record.evidence_positions 111 | for pos, _ in enumerate(record.document_record.contents) 112 | ] 113 | drecord = record.document_record 114 | else: 115 | assert type(record) is DocumentRecord 116 | texts = [evidence for evidence in record.contents] 117 | marked = [False] * len(texts) 118 | drecord = record 119 | 120 | document = Document(doc_id=drecord.doc_id, chunks=[], title=drecord.title) 121 | judged_chunks = [] 122 | for text, positive in zip(texts, marked): 123 | chunk_id = Chunk.build_chunk_id( 124 | doc_id=drecord.doc_id, position=len(document.chunks) 125 | ) 126 | chunk = Chunk( 127 | chunk_id=chunk_id, text=text, doc_summary=None, belonging_doc=document 128 | ) 129 | document.chunks.append(chunk) 130 | if positive and query is not None: 131 | judged_chunks.append(JudgedChunk(query=query, chunk=chunk, judgement=1)) 132 | document.set_default_candidates() 133 | 134 | if query is not None: 135 | return judged_chunks 136 | else: 137 | return document.chunks 138 | 139 | def _build_labeled_queries( 140 | self, jrecords: List[JudgedDocumentRecord] 141 | ) -> List[LabeledQuery]: 142 | qid2jchunks: Dict[str, List[JudgedChunk]] = defaultdict(list) 143 | jchunks_list: List[List[JudgedChunk]] = list(map(self._build_chunks, jrecords)) 144 | for jchunks in jchunks_list: 145 | for jchunk in jchunks:  # gather jchunks by qid 146 | qid2jchunks[jchunk.query.query_id].append(jchunk) 147 | labeled_queries = [ 148 | LabeledQuery(query=jchunks[0].query, judged_chunks=jchunks) 149 | for jchunks in qid2jchunks.values() 150 | ] 151 | return labeled_queries 152 | 153 | def _build_corpus(self, drecords: List[DocumentRecord]) -> List[Document]: 154 | corpus = [] 155 | chunks_list: List[List[Chunk]] = list(map(self._build_chunks, drecords)) 156 | for chunks in chunks_list: 157 | doc = chunks[0].belonging_doc 158 | corpus.append(doc) 159 | return corpus 160 | 161 | def _load_data(self, nheldout: Optional[int]) -> LoadedData: 162 | drecords = self._load_drecords() 163 | url2drecord = DocumentRecord.build_url2drecord(drecords) 164 | jrecords_train = self._load_jrecords( 165 | fpath=self.ftrain_v1, url2drecord=url2drecord 166 | ) 167 | jrecords_dev = self._load_jrecords(fpath=self.fdev_v1, url2drecord=url2drecord) 168 | labeled_queries_train_and_dev = self._build_labeled_queries(jrecords_train) 169 | labeled_queries_test = self._build_labeled_queries(jrecords_dev) 170 | if nheldout is None: 171 | nheldout = len(labeled_queries_test) 172 | labeled_queries_dev, labeled_queries_train = randomly_split_by_number( 173 | data=labeled_queries_train_and_dev, number=nheldout 174 | ) 175 | corpus = self._build_corpus(drecords) 176 | return LoadedData( 177 | corpus_iter_fn=lambda: iter(corpus), 178 | labeled_queries_train=labeled_queries_train, 179 | labeled_queries_dev=labeled_queries_dev, 180 | labeled_queries_test=labeled_queries_test, 181 | ) 182 | 183 | 184 | if __name__ == "__main__": 185 | set_logger_format() 186 | dataset = TaggedConditionalQA( 187 | resource_path="https://raw.githubusercontent.com/haitian-sun/ConditionalQA/master/v1_0", 188 | nheldout=None, 189 | ) 190 |
-------------------------------------------------------------------------------- /dapr/datasets/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import inspect 3 | import json 4 | import os 5 | from random import Random 6 | import shutil 7 | from typing import Any, Dict, Iterable, List, Optional 8 | from dapr.utils import Separator, md5 9 | from dapr.datasets.dm import Document, LabeledQuery, LoadedData 10 | import numpy as np 11 | from transformers import AutoTokenizer 12 | import logging 13 | 14 | 15 | class BaseDataset(ABC): 16 | def __init__( 17 | self, 18 | resource_path: str, 19 | nheldout: Optional[int] = None, 20 | cache_root_dir: str = "data", 21 | chunk_separator: Separator = Separator.empty, 22 | tokenizer: str = "roberta-base", 23 | nprocs: int = 10, 24 | ) -> None: 25 | self.kwargs = dict(inspect.getargvalues(inspect.currentframe()).locals) 26 | self.kwargs.pop("self") 27 | self.kwargs.pop("nprocs") 28 | self.logger = logging.getLogger(__name__) 29 | self.resource_path = resource_path 30 | self.chunk_separator = chunk_separator 31 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) 32 | self.nprocs = nprocs 33 | cache_dir = os.path.join(cache_root_dir, self.name) 34 | if os.path.exists(cache_dir) and os.listdir(cache_dir): 35 | loaded_data = LoadedData.from_dump(cache_dir) 36 | else: 37 | self._download(resource_path) 38 | loaded_data = self.load_data(nheldout=nheldout, cache_dir=cache_dir) 39 | self.loaded_data = loaded_data 40 | 41 | @property 42 | def name(self) -> str: 43 | return self.__class__.__name__ 44 | 45 | def estimate_chunk_length( 46 | self, corpus: Iterable[Document], n: int = 1000, candidates_only: bool = False 47 | ) -> float: 48 | if candidates_only: 49 | chunks = [ 50 | chk 51 | for doc in corpus 52 | for chk in doc.chunks 53 | if chk.chunk_id in doc.candidate_chunk_ids 54 | ] 55 | else: 56 | chunks = [chk for doc in corpus for chk in doc.chunks] 57 | random_state = Random(42) 58 | sampled = random_state.sample(chunks, k=min(n, len(chunks))) 59 | lengths = [len(self.tokenizer.tokenize(chunk.text)) for chunk in sampled] 60 | return float(np.mean(lengths)) 61 | 62 | def stats(self, loaded_data: LoadedData) -> Dict[str, Any]: 63 | assert loaded_data.corpus_iter_fn is not None 64 | corpus_size = sum(1 for _ in loaded_data.corpus_iter_fn()) 65 | assert loaded_data.labeled_queries_test is not None 66 | stats = { 67 | "name": self.name, 68 | "ndocs": corpus_size, 69 | "nchunks": Document.nchunks_in_corpus(loaded_data.corpus_iter_fn()), 70 | "nchunks_candidates": Document.nchunks_in_corpus( 71 | loaded_data.corpus_iter_fn(), candidates_only=True 72 | ), 73 | "nchunks_percentiles": Document.nchunks_percentiles( 74 | corpus=loaded_data.corpus_iter_fn() 75 | ), 76 | "nchunks_candidates_percentiles": Document.nchunks_percentiles( 77 | corpus=loaded_data.corpus_iter_fn(), candidates_only=True 78 | ), 79 | "avg_chunk_length": self.estimate_chunk_length( 80 | loaded_data.corpus_iter_fn() 81 | ), 82 | "avg_candidate_chunk_length": self.estimate_chunk_length( 83 | loaded_data.corpus_iter_fn(), candidates_only=True 84 | ), 85 | } 86 | for split, lqs in [ 87 | ("train", loaded_data.labeled_queries_train), 88 | ("dev", loaded_data.labeled_queries_dev), 89 | ("test", loaded_data.labeled_queries_test), 90 | ]: 91 | stats[f"nqueries_{split}"] = LabeledQuery.nqueries(lqs) if lqs else None 92 | stats[f"jpq_{split}"] = LabeledQuery.njudged_per_query(lqs) if lqs else None 93 | stats = { 94 | 
k: round(v, 2) if type(v) in [int, float] else v for k, v in stats.items() 95 | } 96 | self.logger.info("\n" + json.dumps(stats, indent=4)) 97 | return stats 98 | 99 | def check_heldout_in_corpus(self, loaded_data: LoadedData) -> None: 100 | """Assert that the chunk ids in heldout data are all included in the corpus.""" 101 | cids_corpus = set( 102 | [chk.chunk_id for doc in loaded_data.corpus_iter_fn() for chk in doc.chunks] 103 | ) 104 | cids_heldout = set() 105 | assert loaded_data.labeled_queries_test is not None 106 | lqs: List[LabeledQuery] 107 | for lqs in [ 108 | loaded_data.labeled_queries_dev, 109 | loaded_data.labeled_queries_test, 110 | ]: 111 | if lqs is None: 112 | continue 113 | for lq in lqs: 114 | for jc in lq.judged_chunks: 115 | cids_heldout.add(jc.chunk.chunk_id) 116 | 117 | cids_left = cids_heldout - cids_corpus 118 | assert len(cids_left) == 0, f"Left: {cids_left}" 119 | 120 | @abstractmethod 121 | def _download(self, resource_path: str) -> None: 122 | pass 123 | 124 | def load_data(self, nheldout: Optional[int], cache_dir: str) -> LoadedData: 125 | loaded_data = self._load_data(nheldout) 126 | loaded_data.meta_data = {} 127 | loaded_data.meta_data["chunk_separator"] = self.chunk_separator 128 | loaded_data.meta_data["corpus_identifier"] = ( 129 | f"{self.name}_{md5(map(lambda doc: str(doc.to_json()), loaded_data.corpus_iter_fn()))}" 130 | ) 131 | stats = self.stats(loaded_data) 132 | loaded_data.meta_data.update(stats) 133 | self.check_heldout_in_corpus(loaded_data) 134 | try: 135 | loaded_data.dump(cache_dir) 136 | except Exception as e: 137 | if os.path.exists(cache_dir): 138 | shutil.rmtree(cache_dir) 139 | raise e 140 | del loaded_data 141 | return LoadedData.from_dump(cache_dir) 142 | 143 | @abstractmethod 144 | def _load_data(self, nheldout: Optional[int]) -> LoadedData: 145 | pass 146 | 147 | 148 | # Stats: 149 | # { 150 | #     "name": "Genomics", 151 | #     "#docs": 162259, 152 | #     "#chks": 4035411, 153 | #     "#chks percentiles": { 154 | #         "5": 5.0, 155 | #         "25": 19.0, 156 | #         "50": 25.0, 157 | #         "75": 32.0, 158 | #         "95": 41.0 159 | #     }, 160 | #     "#train queries": null, 161 | #     "#Judged per query (train)": null, 162 | #     "#dev queries": null, 163 | #     "#Judged per query (dev)": null, 164 | #     "#test queries": 62, 165 | #     "#Judged per query (test)": 225.52 166 | # } 167 | # { 168 | #     "name": "MSMARCO", 169 | #     "#docs": 3201821, 170 | #     "#chks": 11799171, 171 | #     "percentiles": { 172 | #         "5": 1.0, 173 | #         "25": 1.0, 174 | #         "50": 2.0, 175 | #         "75": 4.0, 176 | #         "95": 11.0 177 | #     }, 178 | #     "#train queries": 170025, 179 | #     "#Judged per query (train)": 1.09, 180 | #     "#dev queries": 25908, 181 | #     "#Judged per query (dev)": 1.09, 182 | #     "#test queries": 25908, 183 | #     "#Judged per query (test)": 1.1 184 | # } 185 | # { 186 | #     "name": "NaturalQuestions", 187 | #     "#docs": 108626, 188 | #     "#chks": 593737, 189 | #     "#chks percentiles": { 190 | #         "5": 1.0, 191 | #         "25": 2.0, 192 | #         "50": 3.0, 193 | #         "75": 7.0, 194 | #         "95": 18.0 195 | #     }, 196 | #     "#train queries": 93275, 197 | #     "#Judged per query (train)": 1.11, 198 | #     "#dev queries": 3610, 199 | #     "#Judged per query (dev)": 1.09, 200 | #     "#test queries": 3610, 201 | #     "#Judged per query (test)": 1.27 202 | # } 203 | -------------------------------------------------------------------------------- /dapr/datasets/genomics.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import re 3 | from typing import List, Optional, Set, Tuple,
Union 4 | from dapr.datasets.base import BaseDataset, LoadedData 5 | from dapr.datasets.dm import Chunk, Document, JudgedChunk, LabeledQuery, Query 6 | from dapr.utils import Separator 7 | import ir_datasets 8 | from ir_datasets.datasets import highwire 9 | import tqdm 10 | 11 | 12 | @dataclass 13 | class QRelRecord: 14 | query: Query 15 | qrel: highwire.HighwireQrel 16 | 17 | 18 | class Genomics(BaseDataset): 19 | def __init__( 20 | self, 21 | resource_path: str = "", 22 | nheldout: Optional[int] = None, 23 | cache_root_dir: str = "data", 24 | chunk_separator: Separator = Separator.empty, 25 | tokenizer: str = "roberta-base", 26 | nprocs: int = 10, 27 | ) -> None: 28 | super().__init__( 29 | resource_path, nheldout, cache_root_dir, chunk_separator, tokenizer, nprocs 30 | ) 31 | 32 | def _download(self, resource_path: str) -> None: 33 | pass 34 | 35 | def _keep_single_return(self, text: str) -> str: 36 | """For example, 37 | ``` 38 | "\\r\\n\\r\\n\\r\\nJohn Snow and Modern-Day Environmental Epidemiology\\r\\n\\r\\n\\r\\n\\r\\nDale P. Sandler\\r\\n" 39 | ``` 40 | into 41 | ``` 42 | "\\r\\nJohn Snow and Modern-Day Environmental Epidemiology\\r\\nDale P. Sandler\\r\\n" 43 | ``` 44 | """ 45 | return re.sub("(\r|\n)+", "\r\n", text.strip()) + "\r\n" 46 | 47 | def _build_chunks( 48 | self, hw_doc_qrel_record: Tuple[highwire.HighwireDoc, Optional[QRelRecord]] 49 | ) -> Union[List[JudgedChunk], List[Chunk]]: 50 | hw_doc, qrel_record = hw_doc_qrel_record 51 | passages = [self._keep_single_return(span.text) for span in hw_doc.spans] 52 | marked = [False] * len(passages) 53 | if qrel_record is not None: 54 | assert qrel_record.qrel.relevance > 0 55 | marked = [ 56 | ( 57 | True 58 | if (span.start, span.length) 59 | == (qrel_record.qrel.start, qrel_record.qrel.length) 60 | else False 61 | ) 62 | for span in hw_doc.spans 63 | ] 64 | judged_chunks = [] 65 | document = Document(doc_id=hw_doc.doc_id, chunks=[], title=str(hw_doc.title)) 66 | for passage, mark in zip(passages, marked): 67 | chunk_id = Chunk.build_chunk_id( 68 | doc_id=document.doc_id, position=len(document.chunks) 69 | ) 70 | chunk = Chunk( 71 | chunk_id=chunk_id, 72 | text=passage, 73 | doc_summary=None, 74 | belonging_doc=document, 75 | ) 76 | document.chunks.append(chunk) 77 | if mark: 78 | judged_chunks.append( 79 | JudgedChunk( 80 | query=qrel_record.query, 81 | chunk=chunk, 82 | judgement=qrel_record.qrel.relevance, 83 | ) 84 | ) 85 | document.set_default_candidates() 86 | 87 | if qrel_record is None: 88 | return document.chunks 89 | else: 90 | return judged_chunks 91 | 92 | def _build_document(self, hw_doc: highwire.HighwireDoc) -> Document: 93 | chunks: List[Chunk] = self._build_chunks((hw_doc, None)) 94 | return chunks[0].belonging_doc 95 | 96 | def _build_corpus(self) -> List[Document]: 97 | dataset: ir_datasets.datasets.base.Dataset = ir_datasets.load( 98 | "highwire/trec-genomics-2007" 99 | ) # 2006 and 2007 share the same corpus 100 | hw_docs: ir_datasets.datasets.base._BetaPythonApiDocs = getattr(dataset, "docs") 101 | corpus = [ 102 | self._build_document(hw_doc) 103 | for hw_doc in tqdm.tqdm(hw_docs, desc="Building corpus") 104 | ] 105 | return corpus 106 | 107 | def _build_labeled_queries(self, year: str) -> List[LabeledQuery]: 108 | assert year in ["2006", "2007"] 109 | dataset: ir_datasets.datasets.base.Dataset = ir_datasets.load( 110 | f"highwire/trec-genomics-{year}" 111 | ) 112 | hw_docs: ir_datasets.datasets.base._BetaPythonApiDocs = getattr(dataset, "docs") 113 | hw_queries: 
ir_datasets.datasets.base._BetaPythonApiQueries = getattr( 114 | dataset, "queries" 115 | ) 116 | hw_qrels: ir_datasets.datasets.base._BetaPythonApiQrels = getattr( 117 | dataset, "qrels" 118 | ) 119 | hw_qrel: highwire.HighwireQrel 120 | labeled_queries = [] 121 | for hw_qrel in tqdm.tqdm( 122 | hw_qrels, desc=f"Building labeled queries (year: {year})" 123 | ): 124 | if hw_qrel.relevance == 0: 125 | continue 126 | 127 | hw_query: ir_datasets.formats.base.GenericQuery = hw_queries.lookup( 128 | hw_qrel.query_id 129 | ) 130 | qrel_record = QRelRecord( 131 | query=Query(query_id=hw_qrel.query_id, text=hw_query.text), qrel=hw_qrel 132 | ) 133 | hw_doc: highwire.HighwireDoc = hw_docs.lookup(hw_qrel.doc_id) 134 | jchunks = self._build_chunks((hw_doc, qrel_record)) 135 | labeled_queries.append( 136 | LabeledQuery(query=qrel_record.query, judged_chunks=jchunks) 137 | ) 138 | return labeled_queries 139 | 140 | def _remove_first_paragraphs( 141 | self, corpus: List[Document], labeled_queries_test: List[LabeledQuery] 142 | ): 143 | """Remove the first paragraph in each Genomics document. All the first paragraphs contain the title.""" 144 | queries_before = LabeledQuery.get_unique_queries(labeled_queries_test) 145 | 146 | # Remove first paragraphs for the corpus: 147 | cids_to_remove: Set[str] = set() 148 | for doc in corpus: 149 | chunk = doc.chunks.pop(0) 150 | cids_to_remove.add(chunk.chunk_id) 151 | 152 | # Remove empty docs: 153 | for i in range(len(corpus) - 1, -1, -1): 154 | if len(corpus[i].chunks) == 0: 155 | corpus.pop(i) 156 | 157 | # Remove the corresponding judged chunks: 158 | for lq in labeled_queries_test: 159 | jchks = lq.judged_chunks 160 | for i in range(len(jchks) - 1, -1, -1): 161 | if jchks[i].chunk.chunk_id in cids_to_remove: 162 | jchks.pop(i) 163 | 164 | # Remove labeled queries with empty judged chunks: 165 | for i in range(len(labeled_queries_test) - 1, -1, -1): 166 | if len(labeled_queries_test[i].judged_chunks) == 0: 167 | labeled_queries_test.pop(i) 168 | 169 | assert len(LabeledQuery.get_unique_queries(labeled_queries_test)) == len( 170 | queries_before 171 | ), "#queries has changed after removing the first paragraphs. This should not happen to Genomics."
172 | 173 | def _load_data(self, nheldout: Optional[int]) -> LoadedData: 174 | corpus = self._build_corpus() 175 | labeled_queries_2006 = self._build_labeled_queries("2006") 176 | labeled_queries_2007 = self._build_labeled_queries("2007") 177 | labeled_queries_test = [] 178 | labeled_queries_test.extend(labeled_queries_2006) 179 | labeled_queries_test.extend(labeled_queries_2007) 180 | self._remove_first_paragraphs( 181 | corpus=corpus, labeled_queries_test=labeled_queries_test 182 | ) 183 | 184 | return LoadedData( 185 | corpus_iter_fn=lambda: iter(corpus), 186 | labeled_queries_test=labeled_queries_test, 187 | ) 188 | 189 | 190 | if __name__ == "__main__": 191 | from dapr.utils import set_logger_format 192 | 193 | set_logger_format() 194 | genomics = Genomics() 195 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Document-Aware Passage Retrieval (DAPR) 2 | [![arXiv](https://img.shields.io/badge/arXiv-2305.13915-b31b1b.svg)](https://arxiv.org/abs/2305.13915) 3 | [![DAPR on HuggingFace datasets](https://img.shields.io/badge/%F0%9F%A4%97%20Datasets-UKPLab/dapr-yellow?style=flat)](https://huggingface.co/datasets/UKPLab/dapr) 4 | 5 | DAPR is a benchmark for document-aware passage retrieval: given a (large) collection of documents, the task is to return the relevant passages within these documents for a given query. 6 | 7 | A key focus of DAPR is forcing/encouraging retrieval systems to utilize the document-level context that surrounds the relevant passages. An example is shown below: 8 | 9 | ![Motivating example](imgs/motivative-example.png) 10 | 11 | > In this example, the query asks for a musician or a group who has ever played at a certain venue. However, the gold relevant passage mentions only the referring noun phrase, "the venue", but not its actual name, "the Half Moon, Putney". The model thus needs to explore the context from the passage's belonging document, which in this case means doing coreference resolution. 12 | 13 | ## Installation 14 | Python>=3.8 is required. Run the installation command below: 15 | ```bash 16 | pip install git+https://github.com/kwang2049/dapr.git 17 | ``` 18 | For the optional usage of BM25, please install a JDK (`openjdk`>=11). One can install it via conda: 19 | ```bash 20 | conda install openjdk=11 21 | ``` 22 | 23 | 24 | ## Usage 25 | ### Building/loading data 26 | ```python 27 | from dapr.datasets.conditionalqa import ConditionalQA 28 | from dapr.datasets.nq import NaturalQuestions 29 | from dapr.datasets.genomics import Genomics 30 | from dapr.datasets.miracl import MIRACL 31 | from dapr.datasets.msmarco import MSMARCO 32 | from dapr.datasets.dm import LoadedData 33 | 34 | # Build the data on the fly: (this will save the data to ./data/ConditionalQA) 35 | data = ConditionalQA().loaded_data # Also the same for NaturalQuestions, etc.
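# Note: building from scratch can take a lot of memory/time for the large datasets (see the Pre-Built Data section below); loading a pre-built dump instead is cheap: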
36 | # data = LoadedData.from_dump("data/ConditionalQA") # Load the pre-built data (please download it from https://public.ukp.informatik.tu-darmstadt.de/kwang/dapr/data/ConditionalQA) 37 | 38 | # Iterate over the corpus: 39 | for doc in data.corpus_iter_fn(): 40 | doc.doc_id 41 | doc.title 42 | for chunk in doc.chunks: 43 | chunk.chunk_id 44 | chunk.text 45 | chunk.belonging_doc.doc_id 46 | 47 | # Iterate over the labeled queries (of the test split): 48 | for labeled_query in data.labeled_queries_test: 49 | labeled_query.query.query_id 50 | labeled_query.query.text 51 | for judged_chunk in labeled_query.judged_chunks: 52 | judged_chunk.chunk.chunk_id 53 | judged_chunk.chunk.text 54 | judged_chunk.chunk.belonging_doc.doc_id 55 | ``` 56 | 57 | ### Evaluation 58 | ```python 59 | from typing import Dict 60 | from dapr.retrievers.dense import DRAGONPlus 61 | from dapr.datasets.conditionalqa import ConditionalQA 62 | from clddp.dm import Query, Passage 63 | import torch 64 | import pytrec_eval 65 | import numpy as np 66 | 67 | # Load data: 68 | data = ConditionalQA().loaded_data 69 | 70 | # Encode queries and passages: 71 | retriever = DRAGONPlus() 72 | retriever.eval() 73 | queries = [ 74 | Query(query_id=labeled_query.query.query_id, text=labeled_query.query.text) 75 | for labeled_query in data.labeled_queries_test 76 | ] 77 | passages = [ 78 | Passage(passage_id=chunk.chunk_id, text=chunk.text) 79 | for doc in data.corpus_iter_fn() 80 | for chunk in doc.chunks 81 | ] 82 | query_embeddings = retriever.encode_queries(queries) 83 | with torch.no_grad(): # Takes around a minute on a V100 GPU 84 | passage_embeddings, passage_mask = retriever.encode_passages(passages) 85 | 86 | # Calculate the similarities and keep top-K: 87 | similarity_scores = torch.matmul( 88 | query_embeddings, passage_embeddings.t() 89 | ) # (query_num, passage_num) 90 | topk = torch.topk(similarity_scores, k=10) 91 | topk_values: torch.Tensor = topk[0] 92 | topk_indices: torch.LongTensor = topk[1] 93 | topk_value_lists = topk_values.tolist() 94 | topk_index_lists = topk_indices.tolist() 95 | 96 | # Run evaluation with pytrec_eval: 97 | retrieval_scores: Dict[str, Dict[str, float]] = {} 98 | for query_i, (values, indices) in enumerate(zip(topk_value_lists, topk_index_lists)): 99 | query_id = queries[query_i].query_id 100 | retrieval_scores.setdefault(query_id, {}) 101 | for value, passage_i in zip(values, indices): 102 | passage_id = passages[passage_i].passage_id 103 | retrieval_scores[query_id][passage_id] = value 104 | qrels: Dict[str, Dict[str, int]] = { 105 | labeled_query.query.query_id: { 106 | judged_chunk.chunk.chunk_id: judged_chunk.judgement 107 | for judged_chunk in labeled_query.judged_chunks 108 | } 109 | for labeled_query in data.labeled_queries_test 110 | } 111 | evaluator = pytrec_eval.RelevanceEvaluator( 112 | query_relevance=qrels, measures=["ndcg_cut_10"] 113 | ) 114 | query_performances: Dict[str, Dict[str, float]] = evaluator.evaluate(retrieval_scores) 115 | ndcg = np.mean([score["ndcg_cut_10"] for score in query_performances.values()]) 116 | print(ndcg) # 0.21796083196880855 117 | ``` 118 | 119 | ### Reproducing experiment results 120 | All the experiment scripts are available at [scripts/dgx2/exps](scripts/dgx2/exps). 
For example, one can evaluate the DRAGON+ retriever in a passage-only manner like this: 121 | ```bash 122 | # scripts/dgx2/exps/passage_only/dragon_plus.sh 123 | export NCCL_DEBUG="INFO" 124 | export CUDA_VISIBLE_DEVICES="0,1,2,3" 125 | 126 | datasets=( "ConditionalQA" ) 127 | # datasets=( "ConditionalQA" "MSMARCO" "NaturalQuestions" "Genomics" "MIRACL" ) 128 | for dataset in ${datasets[@]} 129 | do 130 | export DATA_DIR="data" 131 | export DATASET_PATH="$DATA_DIR/$dataset" 132 | export CLI_ARGS=" 133 | --data_dir=$DATASET_PATH 134 | " 135 | export OUTPUT_DIR=$(python -m dapr.exps.passage_only.args.dragon_plus $CLI_ARGS) 136 | mkdir -p $OUTPUT_DIR 137 | export LOG_PATH="$OUTPUT_DIR/logging.log" 138 | echo "Logging file path: $LOG_PATH" 139 | torchrun --nproc_per_node=4 --master_port=29501 -m dapr.exps.passage_only.dragon_plus $CLI_ARGS > $LOG_PATH 140 | done 141 | ``` 142 | 143 | ## Pre-Built Data 144 | The building processes above require relatively large memory for the large datasets. The loading part after the data building is cheap, though (the collections are loaded on the fly via Python generators). The budgets are listed below (with 12 processes): 145 | | Dataset | Memory | Time | 146 | | -------- | ------- | ------- | 147 | | NaturalQuestions | 25.6GB | 39min | 148 | | Genomics | 18.3GB | 25min | 149 | | MSMARCO | 102.9GB | 3h | 150 | | MIRACL | 69.7GB | 1h30min | 151 | | ConditionalQA | <1GB | <1min | 152 | 153 | To bypass this, one can also download the pre-built data: 154 | ```bash 155 | mkdir data 156 | wget -r -np -nH --cut-dirs=3 https://public.ukp.informatik.tu-darmstadt.de/kwang/dapr/data/NaturalQuestions/ -P ./data 157 | wget -r -np -nH --cut-dirs=3 https://public.ukp.informatik.tu-darmstadt.de/kwang/dapr/data/MSMARCO/ -P ./data 158 | wget -r -np -nH --cut-dirs=3 https://public.ukp.informatik.tu-darmstadt.de/kwang/dapr/data/Genomics/ -P ./data 159 | wget -r -np -nH --cut-dirs=3 https://public.ukp.informatik.tu-darmstadt.de/kwang/dapr/data/MIRACL/ -P ./data 160 | wget -r -np -nH --cut-dirs=3 https://public.ukp.informatik.tu-darmstadt.de/kwang/dapr/data/ConditionalQA/ -P ./data 161 | ``` 162 | The data are also available as HuggingFace datasets: https://huggingface.co/datasets/kwang2049/dapr. 163 | 164 | ## Citation 165 | If you use the code/data, feel free to cite our publication [DAPR: A Benchmark on Document-Aware Passage Retrieval](https://arxiv.org/abs/2305.13915): 166 | ```bibtex 167 | @article{wang2023dapr, 168 |     title = "DAPR: A Benchmark on Document-Aware Passage Retrieval", 169 |     author = "Kexin Wang and Nils Reimers and Iryna Gurevych", 170 |     journal= "arXiv preprint arXiv:2305.13915", 171 |     year = "2023", 172 |     url = "https://arxiv.org/abs/2305.13915", 173 | } 174 | ``` 175 | 176 | Contact person and main contributor: [Kexin Wang](https://kwang2049.github.io/), kexin.wang.2049@gmail.com 177 | 178 | [https://www.ukp.tu-darmstadt.de/](https://www.ukp.tu-darmstadt.de/) 179 | 180 | [https://www.tu-darmstadt.de/](https://www.tu-darmstadt.de/) 181 | 182 | Don't hesitate to send us an e-mail or report an issue if something is broken (and it shouldn't be) or if you have further questions. 183 | 184 | > This repository contains experimental software and is published for the sole purpose of giving additional background details on the respective publication. 185 | 186 | ## Updates 187 | - May 16, 2024 188 |     - Accepted by ACL 2024 Main 189 | - Feb. 06, 2024 190 |     - Renamed the folder from v3 to data on the fileserver.
191 |     - Uploaded other data like coreference-resolution results, extracted keyphrases, and the experiment results. 192 |     - Created the HF datasets. 193 |     - Refactored the experiment code, aligning with the new paper version. 194 | - Nov. 16, 2023 195 |     - New version of data uploaded to https://public.ukp.informatik.tu-darmstadt.de/kwang/dapr/data 196 |     - Replaced COLIEE with ConditionalQA 197 |     - ConditionalQA has two sub-versions here: (1) ConditionalQA, the original dataset; (2) CleanedConditionalQA, whose HTML tags are removed. 198 |     - The MSMARCO dataset now segments the documents by keeping the labeled paragraphs while leaving the leftover parts as the other paragraphs. 199 |         - For example, given the original unsegmented document text "11122222334444566" and if the labeled paragraphs are "22222" and "4444", then the segmentations will be ["111", "22222", "33", "4444", "566"]. 200 |         - We only do retrieval over the labeled paragraphs, which are specified by the attribute "candidate_chunk_ids" of each document object. 201 |     - We now pin the ColBERT package to a specific version, as the latest one has some unknown issues. 202 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship.
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------