├── .gitignore
├── .llm_server_address.jsonnet
├── .retriever_address.jsonnet
├── LICENSE
├── README.md
├── base_configs
    ├── ircot_codex_2wikimultihopqa.jsonnet
    ├── ircot_codex_hotpotqa.jsonnet
    ├── ircot_codex_iirc.jsonnet
    ├── ircot_codex_musique.jsonnet
    ├── ircot_flan_t5_base_2wikimultihopqa.jsonnet
    ├── ircot_flan_t5_base_hotpotqa.jsonnet
    ├── ircot_flan_t5_base_iirc.jsonnet
    ├── ircot_flan_t5_base_musique.jsonnet
    ├── ircot_flan_t5_large_2wikimultihopqa.jsonnet
    ├── ircot_flan_t5_large_hotpotqa.jsonnet
    ├── ircot_flan_t5_large_iirc.jsonnet
    ├── ircot_flan_t5_large_musique.jsonnet
    ├── ircot_flan_t5_xl_2wikimultihopqa.jsonnet
    ├── ircot_flan_t5_xl_hotpotqa.jsonnet
    ├── ircot_flan_t5_xl_iirc.jsonnet
    ├── ircot_flan_t5_xl_musique.jsonnet
    ├── ircot_flan_t5_xxl_2wikimultihopqa.jsonnet
    ├── ircot_flan_t5_xxl_hotpotqa.jsonnet
    ├── ircot_flan_t5_xxl_iirc.jsonnet
    ├── ircot_flan_t5_xxl_musique.jsonnet
    ├── ircot_qa_codex_2wikimultihopqa.jsonnet
    ├── ircot_qa_codex_hotpotqa.jsonnet
    ├── ircot_qa_codex_iirc.jsonnet
    ├── ircot_qa_codex_musique.jsonnet
    ├── ircot_qa_flan_t5_base_2wikimultihopqa.jsonnet
    ├── ircot_qa_flan_t5_base_hotpotqa.jsonnet
    ├── ircot_qa_flan_t5_base_iirc.jsonnet
    ├── ircot_qa_flan_t5_base_musique.jsonnet
    ├── ircot_qa_flan_t5_large_2wikimultihopqa.jsonnet
    ├── ircot_qa_flan_t5_large_hotpotqa.jsonnet
    ├── ircot_qa_flan_t5_large_iirc.jsonnet
    ├── ircot_qa_flan_t5_large_musique.jsonnet
    ├── ircot_qa_flan_t5_xl_2wikimultihopqa.jsonnet
    ├── ircot_qa_flan_t5_xl_hotpotqa.jsonnet
    ├── ircot_qa_flan_t5_xl_iirc.jsonnet
    ├── ircot_qa_flan_t5_xl_musique.jsonnet
    ├── ircot_qa_flan_t5_xxl_2wikimultihopqa.jsonnet
    ├── ircot_qa_flan_t5_xxl_hotpotqa.jsonnet
    ├── ircot_qa_flan_t5_xxl_iirc.jsonnet
    ├── ircot_qa_flan_t5_xxl_musique.jsonnet
    ├── nor_qa_codex_2wikimultihopqa.jsonnet
    ├── nor_qa_codex_hotpotqa.jsonnet
    ├── nor_qa_codex_iirc.jsonnet
    ├── nor_qa_codex_musique.jsonnet
    ├── nor_qa_flan_t5_base_2wikimultihopqa.jsonnet
    ├── nor_qa_flan_t5_base_hotpotqa.jsonnet
    ├── nor_qa_flan_t5_base_iirc.jsonnet
    ├── nor_qa_flan_t5_base_musique.jsonnet
    ├── nor_qa_flan_t5_large_2wikimultihopqa.jsonnet
    ├── nor_qa_flan_t5_large_hotpotqa.jsonnet
    ├── nor_qa_flan_t5_large_iirc.jsonnet
    ├── nor_qa_flan_t5_large_musique.jsonnet
    ├── nor_qa_flan_t5_xl_2wikimultihopqa.jsonnet
    ├── nor_qa_flan_t5_xl_hotpotqa.jsonnet
    ├── nor_qa_flan_t5_xl_iirc.jsonnet
    ├── nor_qa_flan_t5_xl_musique.jsonnet
    ├── nor_qa_flan_t5_xxl_2wikimultihopqa.jsonnet
    ├── nor_qa_flan_t5_xxl_hotpotqa.jsonnet
    ├── nor_qa_flan_t5_xxl_iirc.jsonnet
    ├── nor_qa_flan_t5_xxl_musique.jsonnet
    ├── oner_2wikimultihopqa.jsonnet
    ├── oner_hotpotqa.jsonnet
    ├── oner_iirc.jsonnet
    ├── oner_musique.jsonnet
    ├── oner_qa_codex_2wikimultihopqa.jsonnet
    ├── oner_qa_codex_hotpotqa.jsonnet
    ├── oner_qa_codex_iirc.jsonnet
    ├── oner_qa_codex_musique.jsonnet
    ├── oner_qa_flan_t5_base_2wikimultihopqa.jsonnet
    ├── oner_qa_flan_t5_base_hotpotqa.jsonnet
    ├── oner_qa_flan_t5_base_iirc.jsonnet
    ├── oner_qa_flan_t5_base_musique.jsonnet
    ├── oner_qa_flan_t5_large_2wikimultihopqa.jsonnet
    ├── oner_qa_flan_t5_large_hotpotqa.jsonnet
    ├── oner_qa_flan_t5_large_iirc.jsonnet
    ├── oner_qa_flan_t5_large_musique.jsonnet
    ├── oner_qa_flan_t5_xl_2wikimultihopqa.jsonnet
    ├── oner_qa_flan_t5_xl_hotpotqa.jsonnet
    ├── oner_qa_flan_t5_xl_iirc.jsonnet
    ├── oner_qa_flan_t5_xl_musique.jsonnet
    ├── oner_qa_flan_t5_xxl_2wikimultihopqa.jsonnet
    ├── oner_qa_flan_t5_xxl_hotpotqa.jsonnet
    ├── oner_qa_flan_t5_xxl_iirc.jsonnet
    └── oner_qa_flan_t5_xxl_musique.jsonnet
├── commaqa
    ├── configs
    │   ├── README.md
    │   ├── __init__.py
    │   ├── dataset_build_config.py
    │   ├── entities_config.py
    │   ├── predicate_config.py
    │   ├── predicate_language_config.py
    │   ├── step_config.py
    │   ├── theory_config.py
    │   └── utils.py
    ├── dataset
    │   ├── __init__.py
    │   ├── build_dataset.py
    │   ├── build_submodel_datasets.py
    │   ├── generate_decomposition_predictions.py
    │   ├── generate_decompositions_from_chains.py
    │   └── utils.py
    ├── datasets_utils
    │   ├── build_letter_cat_dataset.py
    │   ├── build_reverse_dataset.py
    │   ├── convert_gsm8k_to_drop.py
    │   ├── drop_eval.py
    │   └── subselect_drop_dataset.py
    ├── execution
    │   ├── __init__.py
    │   ├── constants.py
    │   ├── kblookup.py
    │   ├── llm_qa_model.py
    │   ├── math_model.py
    │   ├── model_executer.py
    │   ├── operation_executer.py
    │   └── utils.py
    ├── inference
    │   ├── __init__.py
    │   ├── configurable_inference.py
    │   ├── constants.py
    │   ├── data_instances.py
    │   ├── dataset_readers.py
    │   ├── ircot.py
    │   ├── model_search.py
    │   ├── participant_execution.py
    │   ├── participant_execution_routed.py
    │   ├── participant_qa.py
    │   ├── prompt_reader.py
    │   └── utils.py
    └── models
    │   ├── generator.py
    │   ├── gpt3generator.py
    │   └── llm_client_generator.py
├── download
    ├── official_eval.sh
    ├── predictions.sh
    ├── processed_data.sh
    └── raw_data.sh
├── evaluate.py
├── ircot.jpg
├── lib.py
├── llm_server
    ├── .gitignore
    ├── Dockerfile
    ├── __init__.py
    ├── client.py
    ├── constants.py
    └── serve.py
├── metrics
    ├── __init__.py
    ├── answer_support_recall.py
    ├── drop_answer_em_f1.py
    ├── drop_eval.py
    ├── metric.py
    ├── squad_answer_em_f1.py
    └── support_em_f1.py
├── predict.py
├── processing_scripts
    ├── process_2wikimultihopqa.py
    ├── process_hotpotqa.py
    ├── process_iirc.py
    ├── process_musique.py
    └── subsample_dataset_and_remap_paras.py
├── prompt_generator
    ├── attach_data_annotations.py
    ├── common.py
    ├── data_annotations
    │   ├── 2wikimultihopqa.jsonnet
    │   ├── hotpotqa.jsonnet
    │   ├── iirc.jsonnet
    │   └── musique.jsonnet
    └── generate_prompts.py
├── prompts
    ├── 2wikimultihopqa
    │   ├── gold_with_1_distractors_context_cot_qa_codex.txt
    │   ├── gold_with_1_distractors_context_cot_qa_flan_t5.txt
    │   ├── gold_with_1_distractors_context_direct_qa_codex.txt
    │   ├── gold_with_1_distractors_context_direct_qa_flan_t5.txt
    │   ├── gold_with_2_distractors_context_cot_qa_codex.txt
    │   ├── gold_with_2_distractors_context_cot_qa_flan_t5.txt
    │   ├── gold_with_2_distractors_context_direct_qa_codex.txt
    │   ├── gold_with_2_distractors_context_direct_qa_flan_t5.txt
    │   ├── gold_with_3_distractors_context_cot_qa_codex.txt
    │   ├── gold_with_3_distractors_context_cot_qa_flan_t5.txt
    │   ├── gold_with_3_distractors_context_direct_qa_codex.txt
    │   ├── gold_with_3_distractors_context_direct_qa_flan_t5.txt
    │   ├── no_context_cot_qa_codex.txt
    │   ├── no_context_cot_qa_flan_t5.txt
    │   ├── no_context_direct_qa_codex.txt
    │   └── no_context_direct_qa_flan_t5.txt
    ├── hotpotqa
    │   ├── gold_with_1_distractors_context_cot_qa_codex.txt
    │   ├── gold_with_1_distractors_context_cot_qa_flan_t5.txt
    │   ├── gold_with_1_distractors_context_direct_qa_codex.txt
    │   ├── gold_with_1_distractors_context_direct_qa_flan_t5.txt
    │   ├── gold_with_2_distractors_context_cot_qa_codex.txt
    │   ├── gold_with_2_distractors_context_cot_qa_flan_t5.txt
    │   ├── gold_with_2_distractors_context_direct_qa_codex.txt
    │   ├── gold_with_2_distractors_context_direct_qa_flan_t5.txt
    │   ├── gold_with_3_distractors_context_cot_qa_codex.txt
    │   ├── gold_with_3_distractors_context_cot_qa_flan_t5.txt
    │   ├── gold_with_3_distractors_context_direct_qa_codex.txt
    │   ├── gold_with_3_distractors_context_direct_qa_flan_t5.txt
    │   ├── no_context_cot_qa_codex.txt
    │   ├── no_context_cot_qa_flan_t5.txt
    │   ├── no_context_direct_qa_codex.txt
    │   └── no_context_direct_qa_flan_t5.txt
    ├── iirc
    │   ├── gold_with_1_distractors_context_cot_qa_codex.txt
    │   ├── gold_with_1_distractors_context_cot_qa_flan_t5.txt
    │   ├── gold_with_1_distractors_context_direct_qa_codex.txt
    │   ├── gold_with_1_distractors_context_direct_qa_flan_t5.txt
    │   ├── gold_with_2_distractors_context_cot_qa_codex.txt
    │   ├── gold_with_2_distractors_context_cot_qa_flan_t5.txt
    │   ├── gold_with_2_distractors_context_direct_qa_codex.txt
    │   ├── gold_with_2_distractors_context_direct_qa_flan_t5.txt
    │   ├── gold_with_3_distractors_context_cot_qa_codex.txt
    │   ├── gold_with_3_distractors_context_cot_qa_flan_t5.txt
    │   ├── gold_with_3_distractors_context_direct_qa_codex.txt
    │   ├── gold_with_3_distractors_context_direct_qa_flan_t5.txt
    │   ├── no_context_cot_qa_codex.txt
    │   ├── no_context_cot_qa_flan_t5.txt
    │   ├── no_context_direct_qa_codex.txt
    │   ├── no_context_direct_qa_flan_t5.txt
    │   ├── no_context_open_llm_retrieval_codex.txt
    │   └── no_context_open_llm_retrieval_flan_t5.txt
    └── musique
    │   ├── gold_with_1_distractors_context_cot_qa_codex.txt
    │   ├── gold_with_1_distractors_context_cot_qa_flan_t5.txt
    │   ├── gold_with_1_distractors_context_direct_qa_codex.txt
    │   ├── gold_with_1_distractors_context_direct_qa_flan_t5.txt
    │   ├── gold_with_2_distractors_context_cot_qa_codex.txt
    │   ├── gold_with_2_distractors_context_cot_qa_flan_t5.txt
    │   ├── gold_with_2_distractors_context_direct_qa_codex.txt
    │   ├── gold_with_2_distractors_context_direct_qa_flan_t5.txt
    │   ├── gold_with_3_distractors_context_cot_qa_codex.txt
    │   ├── gold_with_3_distractors_context_cot_qa_flan_t5.txt
    │   ├── gold_with_3_distractors_context_direct_qa_codex.txt
    │   ├── gold_with_3_distractors_context_direct_qa_flan_t5.txt
    │   ├── no_context_cot_qa_codex.txt
    │   ├── no_context_cot_qa_flan_t5.txt
    │   ├── no_context_direct_qa_codex.txt
    │   └── no_context_direct_qa_flan_t5.txt
├── pyproject.toml
├── reproduce.sh
├── requirements.txt
├── retriever_server
    ├── build_index.py
    ├── elasticsearch_retriever.py
    ├── elasticsearch_server.py
    ├── interactive_query.py
    ├── requirements.txt
    ├── serve.py
    └── unified_retriever.py
├── run.py
└── runner.py

/.gitignore:
--------------------------------------------------------------------------------
*.pyc
.DS_Store
raw_data/
processed_data/
predictions/
instantiated_configs/
official_evaluation/
.retriever_address.json
.llm_server_address.json
.retriever_address.jsonnet
.llm_server_address.jsonnet
.history
.temp
--------------------------------------------------------------------------------
/.llm_server_address.jsonnet:
--------------------------------------------------------------------------------
{
    "host": "http://localhost",
    "port": 8010,
}
--------------------------------------------------------------------------------
/.retriever_address.jsonnet:
--------------------------------------------------------------------------------
{
    "host": "http://localhost",
    "port": 8000,
}
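The two dot-files above tell every pipeline config where the retriever and LLM servers listen. A minimal sketch of reading one of them from Python, assuming the jsonnet pip package (imported as _jsonnet); the read_address helper is ours, not the repo's:

import json

import _jsonnet  # pip package "jsonnet"

def read_address(path: str) -> dict:
    # Jsonnet tolerates the trailing commas in these files, which plain
    # json.loads would reject; evaluate_file returns a JSON string.
    return json.loads(_jsonnet.evaluate_file(path))

address = read_address(".llm_server_address.jsonnet")
print(address["host"], address["port"])  # http://localhost 8010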
--------------------------------------------------------------------------------
/base_configs/nor_qa_codex_2wikimultihopqa.jsonnet:
--------------------------------------------------------------------------------
# Set dataset:
local dataset = "2wikimultihopqa";
local retrieval_corpus_name = dataset;
local add_pinned_paras = if dataset == "iirc" then true else false;
local valid_qids = {
    "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"],
    "2wikimultihopqa": ["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"],
    "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"],
    "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"],
}[dataset];
local prompt_reader_args = {
    "filter_by_key_values": {
        "qid": valid_qids
    },
    "order_by_key": "qid",
    "estimated_generation_length": 300,
    "shuffle": false,
    "model_length_limit": 8000,
};

# (Potentially) Hyper-parameters:
# null means it's unused.
local llm_retrieval_count = null;
local llm_map_count = null;
local bm25_retrieval_count = null;
local rc_context_type_ = "no"; # Choices: no, gold, gold_with_n_distractors
local distractor_count = null; # Choices: 1, 2, 3
local rc_context_type = (
    if rc_context_type_ == "gold_with_n_distractors"
    then "gold_with_" + distractor_count + "_distractors" else rc_context_type_
);
local rc_qa_type = "cot"; # Choices: direct, cot

{
    "start_state": "generate_question",
    "end_state": "[EOQ]",
    "models": {
        "generate_question": {
            "name": "copy_question",
            "next_model": "generate_answer",
            "eoq_after_n_calls": 1,
            "end_state": "[EOQ]",
        },
        "generate_answer": {
            "name": "llmqa",
            "next_model": if std.endsWith(rc_qa_type, "cot") then "extract_answer" else null,
            "prompt_file": "prompts/"+dataset+"/"+rc_context_type+"_context_"+rc_qa_type+"_qa_codex.txt",
            "prompt_reader_args": prompt_reader_args,
            "end_state": "[EOQ]",
            "gen_model": "gpt3",
            "engine": "code-davinci-002",
            "retry_after_n_seconds": 50,
            "add_context": add_pinned_paras,
        },
        "extract_answer": {
            "name": "answer_extractor",
            "query_source": "last_answer",
            "regex": ".* answer is:? (.*)\\.?",
            "match_all_on_failure": true,
            "remove_last_fullstop": true,
        }
    },
    "reader": {
        "name": "multi_para_rc",
        "add_paras": false,
        "add_gold_paras": false,
        "add_pinned_paras": add_pinned_paras,
    },
    "prediction_type": "answer"
}
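The config above picks its few-shot prompt out of the prompts/ tree by combining three knobs: the dataset, the context type, and the QA style. A sketch of that path logic in Python (the function name is ours); with the defaults above it resolves to prompts/2wikimultihopqa/no_context_cot_qa_codex.txt:

def prompt_path(dataset, rc_context_type_, distractor_count, rc_qa_type, model="codex"):
    # Mirrors the jsonnet locals: expand "gold_with_n_distractors" into a
    # concrete distractor count, then assemble the file name.
    if rc_context_type_ == "gold_with_n_distractors":
        rc_context_type = "gold_with_%d_distractors" % distractor_count
    else:
        rc_context_type = rc_context_type_
    return "prompts/%s/%s_context_%s_qa_%s.txt" % (dataset, rc_context_type, rc_qa_type, model)

assert prompt_path("2wikimultihopqa", "no", None, "cot") == \
    "prompts/2wikimultihopqa/no_context_cot_qa_codex.txt"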
--------------------------------------------------------------------------------
/base_configs/nor_qa_codex_hotpotqa.jsonnet:
--------------------------------------------------------------------------------
# Set dataset:
local dataset = "hotpotqa";
local retrieval_corpus_name = dataset;
local add_pinned_paras = if dataset == "iirc" then true else false;
local valid_qids = {
    "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"],
    "2wikimultihopqa": ["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"],
    "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"],
    "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"],
}[dataset];
local prompt_reader_args = {
    "filter_by_key_values": {
        "qid": valid_qids
    },
    "order_by_key": "qid",
    "estimated_generation_length": 300,
    "shuffle": false,
    "model_length_limit": 8000,
};

# (Potentially) Hyper-parameters:
# null means it's unused.
local llm_retrieval_count = null;
local llm_map_count = null;
local bm25_retrieval_count = null;
local rc_context_type_ = "no"; # Choices: no, gold, gold_with_n_distractors
local distractor_count = null; # Choices: 1, 2, 3
local rc_context_type = (
    if rc_context_type_ == "gold_with_n_distractors"
    then "gold_with_" + distractor_count + "_distractors" else rc_context_type_
);
local rc_qa_type = "cot"; # Choices: direct, cot

{
    "start_state": "generate_question",
    "end_state": "[EOQ]",
    "models": {
        "generate_question": {
            "name": "copy_question",
            "next_model": "generate_answer",
            "eoq_after_n_calls": 1,
            "end_state": "[EOQ]",
        },
        "generate_answer": {
            "name": "llmqa",
            "next_model": if std.endsWith(rc_qa_type, "cot") then "extract_answer" else null,
            "prompt_file": "prompts/"+dataset+"/"+rc_context_type+"_context_"+rc_qa_type+"_qa_codex.txt",
            "prompt_reader_args": prompt_reader_args,
            "end_state": "[EOQ]",
            "gen_model": "gpt3",
            "engine": "code-davinci-002",
            "retry_after_n_seconds": 50,
            "add_context": add_pinned_paras,
        },
        "extract_answer": {
            "name": "answer_extractor",
            "query_source": "last_answer",
            "regex": ".* answer is:? (.*)\\.?",
            "match_all_on_failure": true,
            "remove_last_fullstop": true,
        }
    },
    "reader": {
        "name": "multi_para_rc",
        "add_paras": false,
        "add_gold_paras": false,
        "add_pinned_paras": add_pinned_paras,
    },
    "prediction_type": "answer"
}
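The extract_answer step is shared by every config in this family: it pulls the final span out of a chain-of-thought generation. An illustrative re-implementation of what its three settings mean (the function is ours; the repo's own extractor lives under commaqa/inference/):

import re

def extract_answer(generation: str) -> str:
    match = re.match(".* answer is:? (.*)\\.?", generation)
    answer = match.group(1) if match else generation  # match_all_on_failure
    if answer.endswith("."):                          # remove_last_fullstop
        answer = answer[:-1]
    return answer

print(extract_answer("So the answer is: Northern Lights."))  # Northern Lights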
--------------------------------------------------------------------------------
/base_configs/nor_qa_codex_iirc.jsonnet:
--------------------------------------------------------------------------------
# Set dataset:
local dataset = "iirc";
local retrieval_corpus_name = dataset;
local add_pinned_paras = if dataset == "iirc" then true else false;
local valid_qids = {
    "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"],
    "2wikimultihopqa": ["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"],
    "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"],
    "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"],
}[dataset];
local prompt_reader_args = {
    "filter_by_key_values": {
        "qid": valid_qids
    },
    "order_by_key": "qid",
    "estimated_generation_length": 300,
    "shuffle": false,
    "model_length_limit": 8000,
};

# (Potentially) Hyper-parameters:
# null means it's unused.
local llm_retrieval_count = null;
local llm_map_count = null;
local bm25_retrieval_count = null;
local rc_context_type_ = "no"; # Choices: no, gold, gold_with_n_distractors
local distractor_count = null; # Choices: 1, 2, 3
local rc_context_type = (
    if rc_context_type_ == "gold_with_n_distractors"
    then "gold_with_" + distractor_count + "_distractors" else rc_context_type_
);
local rc_qa_type = "cot"; # Choices: direct, cot

{
    "start_state": "generate_question",
    "end_state": "[EOQ]",
    "models": {
        "generate_question": {
            "name": "copy_question",
            "next_model": "generate_answer",
            "eoq_after_n_calls": 1,
            "end_state": "[EOQ]",
        },
        "generate_answer": {
            "name": "llmqa",
            "next_model": if std.endsWith(rc_qa_type, "cot") then "extract_answer" else null,
            "prompt_file": "prompts/"+dataset+"/"+rc_context_type+"_context_"+rc_qa_type+"_qa_codex.txt",
            "prompt_reader_args": prompt_reader_args,
            "end_state": "[EOQ]",
            "gen_model": "gpt3",
            "engine": "code-davinci-002",
            "retry_after_n_seconds": 50,
            "add_context": add_pinned_paras,
        },
        "extract_answer": {
            "name": "answer_extractor",
            "query_source": "last_answer",
            "regex": ".* answer is:? (.*)\\.?",
            "match_all_on_failure": true,
            "remove_last_fullstop": true,
        }
    },
    "reader": {
        "name": "multi_para_rc",
        "add_paras": false,
        "add_gold_paras": false,
        "add_pinned_paras": add_pinned_paras,
    },
    "prediction_type": "answer"
}
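Each of these configs describes the same three-step state machine: copy_question hands the question to llmqa, whose chain-of-thought output goes to answer_extractor (which has no next_model, so the run ends). An illustrative driver under those assumptions; the repo's real runner is commaqa/inference/configurable_inference.py and every name below is ours:

def run_pipeline(config: dict, question: str, participants: dict) -> str:
    # Start at "start_state" and follow each model's "next_model" pointer;
    # a missing (or null) "next_model" ends the run.
    state, payload = config["start_state"], question
    while state is not None:
        model_config = config["models"][state]
        payload = participants[model_config["name"]](payload, model_config)
        state = model_config.get("next_model")
    return payload

# participants would map "copy_question", "llmqa", and "answer_extractor" to
# callables; for the config above the chain is
# copy_question -> llmqa -> answer_extractor.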
--------------------------------------------------------------------------------
/base_configs/nor_qa_codex_musique.jsonnet:
--------------------------------------------------------------------------------
# Set dataset:
local dataset = "musique";
local retrieval_corpus_name = dataset;
local add_pinned_paras = if dataset == "iirc" then true else false;
local valid_qids = {
    "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"],
    "2wikimultihopqa": ["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"],
    "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"],
    "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"],
}[dataset];
local prompt_reader_args = {
    "filter_by_key_values": {
        "qid": valid_qids
    },
    "order_by_key": "qid",
    "estimated_generation_length": 300,
    "shuffle": false,
    "model_length_limit": 8000,
};

# (Potentially) Hyper-parameters:
# null means it's unused.
local llm_retrieval_count = null;
local llm_map_count = null;
local bm25_retrieval_count = null;
local rc_context_type_ = "no"; # Choices: no, gold, gold_with_n_distractors
local distractor_count = null; # Choices: 1, 2, 3
local rc_context_type = (
    if rc_context_type_ == "gold_with_n_distractors"
    then "gold_with_" + distractor_count + "_distractors" else rc_context_type_
);
local rc_qa_type = "cot"; # Choices: direct, cot

{
    "start_state": "generate_question",
    "end_state": "[EOQ]",
    "models": {
        "generate_question": {
            "name": "copy_question",
            "next_model": "generate_answer",
            "eoq_after_n_calls": 1,
            "end_state": "[EOQ]",
        },
        "generate_answer": {
            "name": "llmqa",
            "next_model": if std.endsWith(rc_qa_type, "cot") then "extract_answer" else null,
            "prompt_file": "prompts/"+dataset+"/"+rc_context_type+"_context_"+rc_qa_type+"_qa_codex.txt",
            "prompt_reader_args": prompt_reader_args,
            "end_state": "[EOQ]",
            "gen_model": "gpt3",
            "engine": "code-davinci-002",
            "retry_after_n_seconds": 50,
            "add_context": add_pinned_paras,
        },
        "extract_answer": {
            "name": "answer_extractor",
            "query_source": "last_answer",
            "regex": ".* answer is:? (.*)\\.?",
            "match_all_on_failure": true,
            "remove_last_fullstop": true,
        }
    },
    "reader": {
        "name": "multi_para_rc",
        "add_paras": false,
        "add_gold_paras": false,
        "add_pinned_paras": add_pinned_paras,
    },
    "prediction_type": "answer"
}
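All four nor_qa_codex configs send the prompt to code-davinci-002 with "retry_after_n_seconds": 50. A hedged sketch of what that combination suggests, written against the legacy (pre-1.0) openai client; the repo's real generator is commaqa/models/gpt3generator.py:

import time

import openai  # legacy (<1.0) completions interface

def complete_with_retry(prompt: str, retry_after_n_seconds: int = 50) -> str:
    while True:
        try:
            response = openai.Completion.create(
                engine="code-davinci-002", prompt=prompt, max_tokens=300
            )
            return response["choices"][0]["text"]
        except openai.error.RateLimitError:
            # Codex was heavily rate-limited; back off and retry.
            time.sleep(retry_after_n_seconds)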
--------------------------------------------------------------------------------
/base_configs/nor_qa_flan_t5_base_2wikimultihopqa.jsonnet:
--------------------------------------------------------------------------------
# Set dataset:
local dataset = "2wikimultihopqa";
local retrieval_corpus_name = dataset;
local add_pinned_paras = if dataset == "iirc" then true else false;
local valid_qids = {
    "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"],
    "2wikimultihopqa": ["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"],
    "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"],
    "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"],
}[dataset];
local prompt_reader_args = {
    "filter_by_key_values": {
        "qid": valid_qids
    },
    "order_by_key": "qid",
    "estimated_generation_length": 0, # don't drop in reading phase.
    "shuffle": false,
    "model_length_limit": 1000000, # don't drop in reading phase.
    "tokenizer_model_name": "google/flan-t5-base",
};

# (Potentially) Hyper-parameters:
# null means it's unused.
local llm_retrieval_count = null;
local llm_map_count = null;
local bm25_retrieval_count = null;
local rc_context_type_ = "no"; # Choices: no, gold, gold_with_n_distractors
local distractor_count = null; # Choices: 1, 2, 3
local rc_context_type = (
    if rc_context_type_ == "gold_with_n_distractors"
    then "gold_with_" + distractor_count + "_distractors" else rc_context_type_
);
local rc_qa_type = "direct"; # Choices: direct, cot
local qa_question_prefix = (
    if std.endsWith(rc_qa_type, "cot")
    then "Answer the following question by reasoning step-by-step.\n"
    else "Answer the following question.\n"
);

{
    "start_state": "generate_question",
    "end_state": "[EOQ]",
    "models": {
        "generate_question": {
            "name": "copy_question",
            "next_model": "generate_answer",
            "eoq_after_n_calls": 1,
            "end_state": "[EOQ]",
        },
        "generate_answer": {
            "name": "llmqa",
            "next_model": if std.endsWith(rc_qa_type, "cot") then "extract_answer" else null,
            "prompt_file": "prompts/"+dataset+"/"+rc_context_type+"_context_"+rc_qa_type+"_qa_flan_t5.txt",
            "question_prefix": qa_question_prefix,
            "prompt_reader_args": prompt_reader_args,
            "end_state": "[EOQ]",
            "gen_model": "llm_api",
            "model_name": "google/flan-t5-base",
            "model_tokens_limit": 6000,
            "max_length": 200,
            "add_context": add_pinned_paras,
        },
        "extract_answer": {
            "name": "answer_extractor",
            "query_source": "last_answer",
            "regex": ".* answer is:? (.*)\\.?",
            "match_all_on_failure": true,
            "remove_last_fullstop": true,
        }
    },
    "reader": {
        "name": "multi_para_rc",
        "add_paras": false,
        "add_gold_paras": false,
        "add_pinned_paras": add_pinned_paras,
    },
    "prediction_type": "answer"
}
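Unlike the codex configs, the flan-t5 ones prepend an instruction to the question. A sketch of the qa_question_prefix conditional above (the function name is ours); with rc_qa_type set to "direct" here, the plain instruction is used:

def qa_question_prefix(rc_qa_type: str) -> str:
    # cot-style configs get the step-by-step instruction, direct ones don't.
    if rc_qa_type.endswith("cot"):
        return "Answer the following question by reasoning step-by-step.\n"
    return "Answer the following question.\n"

assert qa_question_prefix("direct") == "Answer the following question.\n"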
--------------------------------------------------------------------------------
/base_configs/nor_qa_flan_t5_base_hotpotqa.jsonnet:
--------------------------------------------------------------------------------
# Set dataset:
local dataset = "hotpotqa";
local retrieval_corpus_name = dataset;
local add_pinned_paras = if dataset == "iirc" then true else false;
local valid_qids = {
    "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"],
    "2wikimultihopqa": ["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"],
    "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"],
    "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"],
}[dataset];
local prompt_reader_args = {
    "filter_by_key_values": {
        "qid": valid_qids
    },
    "order_by_key": "qid",
    "estimated_generation_length": 0, # don't drop in reading phase.
    "shuffle": false,
    "model_length_limit": 1000000, # don't drop in reading phase.
    "tokenizer_model_name": "google/flan-t5-base",
};

# (Potentially) Hyper-parameters:
# null means it's unused.
local llm_retrieval_count = null;
local llm_map_count = null;
local bm25_retrieval_count = null;
local rc_context_type_ = "no"; # Choices: no, gold, gold_with_n_distractors
local distractor_count = null; # Choices: 1, 2, 3
local rc_context_type = (
    if rc_context_type_ == "gold_with_n_distractors"
    then "gold_with_" + distractor_count + "_distractors" else rc_context_type_
);
local rc_qa_type = "direct"; # Choices: direct, cot
local qa_question_prefix = (
    if std.endsWith(rc_qa_type, "cot")
    then "Answer the following question by reasoning step-by-step.\n"
    else "Answer the following question.\n"
);

{
    "start_state": "generate_question",
    "end_state": "[EOQ]",
    "models": {
        "generate_question": {
            "name": "copy_question",
            "next_model": "generate_answer",
            "eoq_after_n_calls": 1,
            "end_state": "[EOQ]",
        },
        "generate_answer": {
            "name": "llmqa",
            "next_model": if std.endsWith(rc_qa_type, "cot") then "extract_answer" else null,
            "prompt_file": "prompts/"+dataset+"/"+rc_context_type+"_context_"+rc_qa_type+"_qa_flan_t5.txt",
            "question_prefix": qa_question_prefix,
            "prompt_reader_args": prompt_reader_args,
            "end_state": "[EOQ]",
            "gen_model": "llm_api",
            "model_name": "google/flan-t5-base",
            "model_tokens_limit": 6000,
            "max_length": 200,
            "add_context": add_pinned_paras,
        },
        "extract_answer": {
            "name": "answer_extractor",
            "query_source": "last_answer",
            "regex": ".* answer is:? (.*)\\.?",
            "match_all_on_failure": true,
            "remove_last_fullstop": true,
        }
    },
    "reader": {
        "name": "multi_para_rc",
        "add_paras": false,
        "add_gold_paras": false,
        "add_pinned_paras": add_pinned_paras,
    },
    "prediction_type": "answer"
}
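With "gen_model": "llm_api", generation goes over HTTP to the server named in .llm_server_address.jsonnet (served by llm_server/serve.py, with a client in commaqa/models/llm_client_generator.py). A hedged sketch of such a call; the /generate route and payload keys are illustrative assumptions, not the server's documented API:

import json

import _jsonnet
import requests

address = json.loads(_jsonnet.evaluate_file(".llm_server_address.jsonnet"))
response = requests.post(
    "%s:%s/generate" % (address["host"], address["port"]),
    json={
        "prompt": "Answer the following question.\nQuestion: ...",
        "model_name": "google/flan-t5-base",
        "max_length": 200,
    },
)
print(response.json())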
--------------------------------------------------------------------------------
/base_configs/nor_qa_flan_t5_base_iirc.jsonnet:
--------------------------------------------------------------------------------
# Set dataset:
local dataset = "iirc";
local retrieval_corpus_name = dataset;
local add_pinned_paras = if dataset == "iirc" then true else false;
local valid_qids = {
    "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"],
    "2wikimultihopqa": ["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"],
    "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"],
    "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"],
}[dataset];
local prompt_reader_args = {
    "filter_by_key_values": {
        "qid": valid_qids
    },
    "order_by_key": "qid",
    "estimated_generation_length": 0, # don't drop in reading phase.
    "shuffle": false,
    "model_length_limit": 1000000, # don't drop in reading phase.
    "tokenizer_model_name": "google/flan-t5-base",
};

# (Potentially) Hyper-parameters:
# null means it's unused.
local llm_retrieval_count = null;
local llm_map_count = null;
local bm25_retrieval_count = null;
local rc_context_type_ = "no"; # Choices: no, gold, gold_with_n_distractors
local distractor_count = null; # Choices: 1, 2, 3
local rc_context_type = (
    if rc_context_type_ == "gold_with_n_distractors"
    then "gold_with_" + distractor_count + "_distractors" else rc_context_type_
);
local rc_qa_type = "direct"; # Choices: direct, cot
local qa_question_prefix = (
    if std.endsWith(rc_qa_type, "cot")
    then "Answer the following question by reasoning step-by-step.\n"
    else "Answer the following question.\n"
);

{
    "start_state": "generate_question",
    "end_state": "[EOQ]",
    "models": {
        "generate_question": {
            "name": "copy_question",
            "next_model": "generate_answer",
            "eoq_after_n_calls": 1,
            "end_state": "[EOQ]",
        },
        "generate_answer": {
            "name": "llmqa",
            "next_model": if std.endsWith(rc_qa_type, "cot") then "extract_answer" else null,
            "prompt_file": "prompts/"+dataset+"/"+rc_context_type+"_context_"+rc_qa_type+"_qa_flan_t5.txt",
            "question_prefix": qa_question_prefix,
            "prompt_reader_args": prompt_reader_args,
            "end_state": "[EOQ]",
            "gen_model": "llm_api",
            "model_name": "google/flan-t5-base",
            "model_tokens_limit": 6000,
            "max_length": 200,
            "add_context": add_pinned_paras,
        },
        "extract_answer": {
            "name": "answer_extractor",
            "query_source": "last_answer",
            "regex": ".* answer is:? (.*)\\.?",
            "match_all_on_failure": true,
            "remove_last_fullstop": true,
        }
    },
    "reader": {
        "name": "multi_para_rc",
        "add_paras": false,
        "add_gold_paras": false,
        "add_pinned_paras": add_pinned_paras,
    },
    "prediction_type": "answer"
}
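prompt_reader_args is the same in every config of this family: keep only the demonstrations whose qid appears in valid_qids, then order them deterministically. An illustrative reduction of those two settings (the real logic lives in commaqa/inference/prompt_reader.py):

def select_demos(demos: list, valid_qids: list) -> list:
    kept = [demo for demo in demos if demo["qid"] in valid_qids]  # filter_by_key_values
    return sorted(kept, key=lambda demo: demo["qid"])             # order_by_key, shuffle false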
--------------------------------------------------------------------------------
/base_configs/nor_qa_flan_t5_base_musique.jsonnet:
--------------------------------------------------------------------------------
# Set dataset:
local dataset = "musique";
local retrieval_corpus_name = dataset;
local add_pinned_paras = if dataset == "iirc" then true else false;
local valid_qids = {
    "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"],
    "2wikimultihopqa": ["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"],
    "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"],
    "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"],
}[dataset];
local prompt_reader_args = {
    "filter_by_key_values": {
        "qid": valid_qids
    },
    "order_by_key": "qid",
    "estimated_generation_length": 0, # don't drop in reading phase.
    "shuffle": false,
    "model_length_limit": 1000000, # don't drop in reading phase.
    "tokenizer_model_name": "google/flan-t5-base",
};

# (Potentially) Hyper-parameters:
# null means it's unused.
local llm_retrieval_count = null;
local llm_map_count = null;
local bm25_retrieval_count = null;
local rc_context_type_ = "no"; # Choices: no, gold, gold_with_n_distractors
local distractor_count = null; # Choices: 1, 2, 3
local rc_context_type = (
    if rc_context_type_ == "gold_with_n_distractors"
    then "gold_with_" + distractor_count + "_distractors" else rc_context_type_
);
local rc_qa_type = "direct"; # Choices: direct, cot
local qa_question_prefix = (
    if std.endsWith(rc_qa_type, "cot")
    then "Answer the following question by reasoning step-by-step.\n"
    else "Answer the following question.\n"
);

{
    "start_state": "generate_question",
    "end_state": "[EOQ]",
    "models": {
        "generate_question": {
            "name": "copy_question",
            "next_model": "generate_answer",
            "eoq_after_n_calls": 1,
            "end_state": "[EOQ]",
        },
        "generate_answer": {
            "name": "llmqa",
            "next_model": if std.endsWith(rc_qa_type, "cot") then "extract_answer" else null,
            "prompt_file": "prompts/"+dataset+"/"+rc_context_type+"_context_"+rc_qa_type+"_qa_flan_t5.txt",
            "question_prefix": qa_question_prefix,
            "prompt_reader_args": prompt_reader_args,
            "end_state": "[EOQ]",
            "gen_model": "llm_api",
            "model_name": "google/flan-t5-base",
            "model_tokens_limit": 6000,
            "max_length": 200,
            "add_context": add_pinned_paras,
        },
        "extract_answer": {
            "name": "answer_extractor",
            "query_source": "last_answer",
            "regex": ".* answer is:? (.*)\\.?",
            "match_all_on_failure": true,
            "remove_last_fullstop": true,
        }
    },
    "reader": {
        "name": "multi_para_rc",
        "add_paras": false,
        "add_gold_paras": false,
        "add_pinned_paras": add_pinned_paras,
    },
    "prediction_type": "answer"
}
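The two prompt_reader_args styles budget length differently. The codex configs reserve 300 generated tokens against an 8000-token limit while reading the prompt, whereas the flan-t5 configs disable reading-phase dropping (length 0, limit 1000000, per the comments) and bound the prompt later via "model_tokens_limit". A sketch of the check those fields imply; the function name and exact formula are our reading of them, not the repo's code:

def demo_fits(prompt_tokens: int, estimated_generation_length: int,
              model_length_limit: int) -> bool:
    # Codex configs (300, 8000): demonstrations get dropped until this holds.
    # Flan-t5 configs (0, 1000000): it always holds, so nothing is dropped.
    return prompt_tokens + estimated_generation_length <= model_length_limit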
--------------------------------------------------------------------------------
/base_configs/nor_qa_flan_t5_large_2wikimultihopqa.jsonnet:
--------------------------------------------------------------------------------
# Set dataset:
local dataset = "2wikimultihopqa";
local retrieval_corpus_name = dataset;
local add_pinned_paras = if dataset == "iirc" then true else false;
local valid_qids = {
    "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"],
    "2wikimultihopqa": ["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"],
    "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"],
    "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"],
}[dataset];
local prompt_reader_args = {
    "filter_by_key_values": {
        "qid": valid_qids
    },
    "order_by_key": "qid",
    "estimated_generation_length": 0, # don't drop in reading phase.
    "shuffle": false,
    "model_length_limit": 1000000, # don't drop in reading phase.
    "tokenizer_model_name": "google/flan-t5-large",
};

# (Potentially) Hyper-parameters:
# null means it's unused.
local llm_retrieval_count = null;
local llm_map_count = null;
local bm25_retrieval_count = null;
local rc_context_type_ = "no"; # Choices: no, gold, gold_with_n_distractors
local distractor_count = null; # Choices: 1, 2, 3
local rc_context_type = (
    if rc_context_type_ == "gold_with_n_distractors"
    then "gold_with_" + distractor_count + "_distractors" else rc_context_type_
);
local rc_qa_type = "direct"; # Choices: direct, cot
local qa_question_prefix = (
    if std.endsWith(rc_qa_type, "cot")
    then "Answer the following question by reasoning step-by-step.\n"
    else "Answer the following question.\n"
);

{
    "start_state": "generate_question",
    "end_state": "[EOQ]",
    "models": {
        "generate_question": {
            "name": "copy_question",
            "next_model": "generate_answer",
            "eoq_after_n_calls": 1,
            "end_state": "[EOQ]",
        },
        "generate_answer": {
            "name": "llmqa",
            "next_model": if std.endsWith(rc_qa_type, "cot") then "extract_answer" else null,
            "prompt_file": "prompts/"+dataset+"/"+rc_context_type+"_context_"+rc_qa_type+"_qa_flan_t5.txt",
            "question_prefix": qa_question_prefix,
            "prompt_reader_args": prompt_reader_args,
            "end_state": "[EOQ]",
            "gen_model": "llm_api",
            "model_name": "google/flan-t5-large",
            "model_tokens_limit": 6000,
            "max_length": 200,
            "add_context": add_pinned_paras,
        },
        "extract_answer": {
            "name": "answer_extractor",
            "query_source": "last_answer",
            "regex": ".* answer is:? (.*)\\.?",
            "match_all_on_failure": true,
            "remove_last_fullstop": true,
        }
    },
    "reader": {
        "name": "multi_para_rc",
        "add_paras": false,
        "add_gold_paras": false,
        "add_pinned_paras": add_pinned_paras,
    },
    "prediction_type": "answer"
}
--------------------------------------------------------------------------------
/base_configs/nor_qa_flan_t5_large_hotpotqa.jsonnet:
--------------------------------------------------------------------------------
# Set dataset:
local dataset = "hotpotqa";
local retrieval_corpus_name = dataset;
local add_pinned_paras = if dataset == "iirc" then true else false;
local valid_qids = {
    "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"],
    "2wikimultihopqa": ["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"],
    "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"],
    "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"],
}[dataset];
local prompt_reader_args = {
    "filter_by_key_values": {
        "qid": valid_qids
    },
    "order_by_key": "qid",
    "estimated_generation_length": 0, # don't drop in reading phase.
    "shuffle": false,
    "model_length_limit": 1000000, # don't drop in reading phase.
    "tokenizer_model_name": "google/flan-t5-large",
};

# (Potentially) Hyper-parameters:
# null means it's unused.
local llm_retrieval_count = null;
local llm_map_count = null;
local bm25_retrieval_count = null;
local rc_context_type_ = "no"; # Choices: no, gold, gold_with_n_distractors
local distractor_count = null; # Choices: 1, 2, 3
local rc_context_type = (
    if rc_context_type_ == "gold_with_n_distractors"
    then "gold_with_" + distractor_count + "_distractors" else rc_context_type_
);
local rc_qa_type = "direct"; # Choices: direct, cot
local qa_question_prefix = (
    if std.endsWith(rc_qa_type, "cot")
    then "Answer the following question by reasoning step-by-step.\n"
    else "Answer the following question.\n"
);

{
    "start_state": "generate_question",
    "end_state": "[EOQ]",
    "models": {
        "generate_question": {
            "name": "copy_question",
            "next_model": "generate_answer",
            "eoq_after_n_calls": 1,
            "end_state": "[EOQ]",
        },
        "generate_answer": {
            "name": "llmqa",
            "next_model": if std.endsWith(rc_qa_type, "cot") then "extract_answer" else null,
            "prompt_file": "prompts/"+dataset+"/"+rc_context_type+"_context_"+rc_qa_type+"_qa_flan_t5.txt",
            "question_prefix": qa_question_prefix,
            "prompt_reader_args": prompt_reader_args,
            "end_state": "[EOQ]",
            "gen_model": "llm_api",
            "model_name": "google/flan-t5-large",
            "model_tokens_limit": 6000,
            "max_length": 200,
            "add_context": add_pinned_paras,
        },
        "extract_answer": {
            "name": "answer_extractor",
            "query_source": "last_answer",
            "regex": ".* answer is:? (.*)\\.?",
            "match_all_on_failure": true,
            "remove_last_fullstop": true,
        }
    },
    "reader": {
        "name": "multi_para_rc",
        "add_paras": false,
        "add_gold_paras": false,
        "add_pinned_paras": add_pinned_paras,
    },
    "prediction_type": "answer"
}
24 | local llm_retrieval_count = null; 25 | local llm_map_count = null; 26 | local bm25_retrieval_count = null; 27 | local rc_context_type_ = "no"; # Choices: no, gold, gold_with_n_distractors 28 | local distractor_count = null; # Choices: 1, 2, 3 29 | local rc_context_type = ( 30 | if rc_context_type_ == "gold_with_n_distractors" 31 | then "gold_with_" + distractor_count + "_distractors" else rc_context_type_ 32 | ); 33 | local rc_qa_type = "direct"; # Choices: direct, cot 34 | local qa_question_prefix = ( 35 | if std.endsWith(rc_context_type, "cot") 36 | then "Answer the following question by reasoning step-by-step.\n" 37 | else "Answer the following question.\n" 38 | ); 39 | 40 | { 41 | "start_state": "generate_question", 42 | "end_state": "[EOQ]", 43 | "models": { 44 | "generate_question": { 45 | "name": "copy_question", 46 | "next_model": "generate_answer", 47 | "eoq_after_n_calls": 1, 48 | "end_state": "[EOQ]", 49 | }, 50 | "generate_answer": { 51 | "name": "llmqa", 52 | "next_model": if std.endsWith(rc_qa_type, "cot") then "extract_answer" else null, 53 | "prompt_file": "prompts/"+dataset+"/"+rc_context_type+"_context_"+rc_qa_type+"_qa_flan_t5.txt", 54 | "question_prefix": qa_question_prefix, 55 | "prompt_reader_args": prompt_reader_args, 56 | "end_state": "[EOQ]", 57 | "gen_model": "llm_api", 58 | "model_name": "google/flan-t5-large", 59 | "model_tokens_limit": 6000, 60 | "max_length": 200, 61 | "add_context": add_pinned_paras, 62 | }, 63 | "extract_answer": { 64 | "name": "answer_extractor", 65 | "query_source": "last_answer", 66 | "regex": ".* answer is:? (.*)\\.?", 67 | "match_all_on_failure": true, 68 | "remove_last_fullstop": true, 69 | } 70 | }, 71 | "reader": { 72 | "name": "multi_para_rc", 73 | "add_paras": false, 74 | "add_gold_paras": false, 75 | "add_pinned_paras": add_pinned_paras, 76 | }, 77 | "prediction_type": "answer" 78 | } -------------------------------------------------------------------------------- /base_configs/nor_qa_flan_t5_large_musique.jsonnet: -------------------------------------------------------------------------------- 1 | # Set dataset: 2 | local dataset = "musique"; 3 | local retrieval_corpus_name = dataset; 4 | local add_pinned_paras = if dataset == "iirc" then true else false; 5 | local valid_qids = { 6 | "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"], 7 | "2wikimultihopqa": 
["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"], 8 | "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"], 9 | "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"], 10 | }[dataset]; 11 | local prompt_reader_args = { 12 | "filter_by_key_values": { 13 | "qid": valid_qids 14 | }, 15 | "order_by_key": "qid", 16 | "estimated_generation_length": 0, # don't drop in reading phase. 17 | "shuffle": false, 18 | "model_length_limit": 1000000, # don't drop in reading phase. 19 | "tokenizer_model_name": "google/flan-t5-large", 20 | }; 21 | 22 | # (Potentially) Hyper-parameters: 23 | # null means it's unused. 24 | local llm_retrieval_count = null; 25 | local llm_map_count = null; 26 | local bm25_retrieval_count = null; 27 | local rc_context_type_ = "no"; # Choices: no, gold, gold_with_n_distractors 28 | local distractor_count = null; # Choices: 1, 2, 3 29 | local rc_context_type = ( 30 | if rc_context_type_ == "gold_with_n_distractors" 31 | then "gold_with_" + distractor_count + "_distractors" else rc_context_type_ 32 | ); 33 | local rc_qa_type = "direct"; # Choices: direct, cot 34 | local qa_question_prefix = ( 35 | if std.endsWith(rc_context_type, "cot") 36 | then "Answer the following question by reasoning step-by-step.\n" 37 | else "Answer the following question.\n" 38 | ); 39 | 40 | { 41 | "start_state": "generate_question", 42 | "end_state": "[EOQ]", 43 | "models": { 44 | "generate_question": { 45 | "name": "copy_question", 46 | "next_model": "generate_answer", 47 | "eoq_after_n_calls": 1, 48 | "end_state": "[EOQ]", 49 | }, 50 | "generate_answer": { 51 | "name": "llmqa", 52 | "next_model": if std.endsWith(rc_qa_type, "cot") then "extract_answer" else null, 53 | "prompt_file": "prompts/"+dataset+"/"+rc_context_type+"_context_"+rc_qa_type+"_qa_flan_t5.txt", 54 | "question_prefix": qa_question_prefix, 55 | "prompt_reader_args": prompt_reader_args, 56 | "end_state": "[EOQ]", 57 | "gen_model": "llm_api", 58 | "model_name": "google/flan-t5-large", 59 | "model_tokens_limit": 6000, 60 | "max_length": 200, 61 | "add_context": add_pinned_paras, 62 | }, 63 | "extract_answer": { 64 | "name": "answer_extractor", 65 | "query_source": "last_answer", 66 | "regex": ".* answer is:? 
(.*)\\.?", 67 | "match_all_on_failure": true, 68 | "remove_last_fullstop": true, 69 | } 70 | }, 71 | "reader": { 72 | "name": "multi_para_rc", 73 | "add_paras": false, 74 | "add_gold_paras": false, 75 | "add_pinned_paras": add_pinned_paras, 76 | }, 77 | "prediction_type": "answer" 78 | } -------------------------------------------------------------------------------- /base_configs/nor_qa_flan_t5_xl_2wikimultihopqa.jsonnet: -------------------------------------------------------------------------------- 1 | # Set dataset: 2 | local dataset = "2wikimultihopqa"; 3 | local retrieval_corpus_name = dataset; 4 | local add_pinned_paras = if dataset == "iirc" then true else false; 5 | local valid_qids = { 6 | "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"], 7 | "2wikimultihopqa": ["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"], 8 | "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"], 9 | "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"], 10 | }[dataset]; 11 | local prompt_reader_args = { 12 | "filter_by_key_values": { 13 | "qid": valid_qids 14 | }, 15 | "order_by_key": "qid", 16 | "estimated_generation_length": 0, # don't drop in reading phase. 17 | "shuffle": false, 18 | "model_length_limit": 1000000, # don't drop in reading phase. 19 | "tokenizer_model_name": "google/flan-t5-xl", 20 | }; 21 | 22 | # (Potentially) Hyper-parameters: 23 | # null means it's unused. 
24 | local llm_retrieval_count = null; 25 | local llm_map_count = null; 26 | local bm25_retrieval_count = null; 27 | local rc_context_type_ = "no"; # Choices: no, gold, gold_with_n_distractors 28 | local distractor_count = null; # Choices: 1, 2, 3 29 | local rc_context_type = ( 30 | if rc_context_type_ == "gold_with_n_distractors" 31 | then "gold_with_" + distractor_count + "_distractors" else rc_context_type_ 32 | ); 33 | local rc_qa_type = "direct"; # Choices: direct, cot 34 | local qa_question_prefix = ( 35 | if std.endsWith(rc_context_type, "cot") 36 | then "Answer the following question by reasoning step-by-step.\n" 37 | else "Answer the following question.\n" 38 | ); 39 | 40 | { 41 | "start_state": "generate_question", 42 | "end_state": "[EOQ]", 43 | "models": { 44 | "generate_question": { 45 | "name": "copy_question", 46 | "next_model": "generate_answer", 47 | "eoq_after_n_calls": 1, 48 | "end_state": "[EOQ]", 49 | }, 50 | "generate_answer": { 51 | "name": "llmqa", 52 | "next_model": if std.endsWith(rc_qa_type, "cot") then "extract_answer" else null, 53 | "prompt_file": "prompts/"+dataset+"/"+rc_context_type+"_context_"+rc_qa_type+"_qa_flan_t5.txt", 54 | "question_prefix": qa_question_prefix, 55 | "prompt_reader_args": prompt_reader_args, 56 | "end_state": "[EOQ]", 57 | "gen_model": "llm_api", 58 | "model_name": "google/flan-t5-xl", 59 | "model_tokens_limit": 6000, 60 | "max_length": 200, 61 | "add_context": add_pinned_paras, 62 | }, 63 | "extract_answer": { 64 | "name": "answer_extractor", 65 | "query_source": "last_answer", 66 | "regex": ".* answer is:? (.*)\\.?", 67 | "match_all_on_failure": true, 68 | "remove_last_fullstop": true, 69 | } 70 | }, 71 | "reader": { 72 | "name": "multi_para_rc", 73 | "add_paras": false, 74 | "add_gold_paras": false, 75 | "add_pinned_paras": add_pinned_paras, 76 | }, 77 | "prediction_type": "answer" 78 | } -------------------------------------------------------------------------------- /base_configs/nor_qa_flan_t5_xl_hotpotqa.jsonnet: -------------------------------------------------------------------------------- 1 | # Set dataset: 2 | local dataset = "hotpotqa"; 3 | local retrieval_corpus_name = dataset; 4 | local add_pinned_paras = if dataset == "iirc" then true else false; 5 | local valid_qids = { 6 | "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"], 7 | "2wikimultihopqa": 
["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"], 8 | "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"], 9 | "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"], 10 | }[dataset]; 11 | local prompt_reader_args = { 12 | "filter_by_key_values": { 13 | "qid": valid_qids 14 | }, 15 | "order_by_key": "qid", 16 | "estimated_generation_length": 0, # don't drop in reading phase. 17 | "shuffle": false, 18 | "model_length_limit": 1000000, # don't drop in reading phase. 19 | "tokenizer_model_name": "google/flan-t5-xl", 20 | }; 21 | 22 | # (Potentially) Hyper-parameters: 23 | # null means it's unused. 24 | local llm_retrieval_count = null; 25 | local llm_map_count = null; 26 | local bm25_retrieval_count = null; 27 | local rc_context_type_ = "no"; # Choices: no, gold, gold_with_n_distractors 28 | local distractor_count = null; # Choices: 1, 2, 3 29 | local rc_context_type = ( 30 | if rc_context_type_ == "gold_with_n_distractors" 31 | then "gold_with_" + distractor_count + "_distractors" else rc_context_type_ 32 | ); 33 | local rc_qa_type = "direct"; # Choices: direct, cot 34 | local qa_question_prefix = ( 35 | if std.endsWith(rc_context_type, "cot") 36 | then "Answer the following question by reasoning step-by-step.\n" 37 | else "Answer the following question.\n" 38 | ); 39 | 40 | { 41 | "start_state": "generate_question", 42 | "end_state": "[EOQ]", 43 | "models": { 44 | "generate_question": { 45 | "name": "copy_question", 46 | "next_model": "generate_answer", 47 | "eoq_after_n_calls": 1, 48 | "end_state": "[EOQ]", 49 | }, 50 | "generate_answer": { 51 | "name": "llmqa", 52 | "next_model": if std.endsWith(rc_qa_type, "cot") then "extract_answer" else null, 53 | "prompt_file": "prompts/"+dataset+"/"+rc_context_type+"_context_"+rc_qa_type+"_qa_flan_t5.txt", 54 | "question_prefix": qa_question_prefix, 55 | "prompt_reader_args": prompt_reader_args, 56 | "end_state": "[EOQ]", 57 | "gen_model": "llm_api", 58 | "model_name": "google/flan-t5-xl", 59 | "model_tokens_limit": 6000, 60 | "max_length": 200, 61 | "add_context": add_pinned_paras, 62 | }, 63 | "extract_answer": { 64 | "name": "answer_extractor", 65 | "query_source": "last_answer", 66 | "regex": ".* answer is:? 
(.*)\\.?", 67 | "match_all_on_failure": true, 68 | "remove_last_fullstop": true, 69 | } 70 | }, 71 | "reader": { 72 | "name": "multi_para_rc", 73 | "add_paras": false, 74 | "add_gold_paras": false, 75 | "add_pinned_paras": add_pinned_paras, 76 | }, 77 | "prediction_type": "answer" 78 | } -------------------------------------------------------------------------------- /base_configs/nor_qa_flan_t5_xl_iirc.jsonnet: -------------------------------------------------------------------------------- 1 | # Set dataset: 2 | local dataset = "iirc"; 3 | local retrieval_corpus_name = dataset; 4 | local add_pinned_paras = if dataset == "iirc" then true else false; 5 | local valid_qids = { 6 | "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"], 7 | "2wikimultihopqa": ["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"], 8 | "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"], 9 | "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"], 10 | }[dataset]; 11 | local prompt_reader_args = { 12 | "filter_by_key_values": { 13 | "qid": valid_qids 14 | }, 15 | "order_by_key": "qid", 16 | "estimated_generation_length": 0, # don't drop in reading phase. 17 | "shuffle": false, 18 | "model_length_limit": 1000000, # don't drop in reading phase. 19 | "tokenizer_model_name": "google/flan-t5-xl", 20 | }; 21 | 22 | # (Potentially) Hyper-parameters: 23 | # null means it's unused. 
24 | local llm_retrieval_count = null; 25 | local llm_map_count = null; 26 | local bm25_retrieval_count = null; 27 | local rc_context_type_ = "no"; # Choices: no, gold, gold_with_n_distractors 28 | local distractor_count = null; # Choices: 1, 2, 3 29 | local rc_context_type = ( 30 | if rc_context_type_ == "gold_with_n_distractors" 31 | then "gold_with_" + distractor_count + "_distractors" else rc_context_type_ 32 | ); 33 | local rc_qa_type = "direct"; # Choices: direct, cot 34 | local qa_question_prefix = ( 35 | if std.endsWith(rc_context_type, "cot") 36 | then "Answer the following question by reasoning step-by-step.\n" 37 | else "Answer the following question.\n" 38 | ); 39 | 40 | { 41 | "start_state": "generate_question", 42 | "end_state": "[EOQ]", 43 | "models": { 44 | "generate_question": { 45 | "name": "copy_question", 46 | "next_model": "generate_answer", 47 | "eoq_after_n_calls": 1, 48 | "end_state": "[EOQ]", 49 | }, 50 | "generate_answer": { 51 | "name": "llmqa", 52 | "next_model": if std.endsWith(rc_qa_type, "cot") then "extract_answer" else null, 53 | "prompt_file": "prompts/"+dataset+"/"+rc_context_type+"_context_"+rc_qa_type+"_qa_flan_t5.txt", 54 | "question_prefix": qa_question_prefix, 55 | "prompt_reader_args": prompt_reader_args, 56 | "end_state": "[EOQ]", 57 | "gen_model": "llm_api", 58 | "model_name": "google/flan-t5-xl", 59 | "model_tokens_limit": 6000, 60 | "max_length": 200, 61 | "add_context": add_pinned_paras, 62 | }, 63 | "extract_answer": { 64 | "name": "answer_extractor", 65 | "query_source": "last_answer", 66 | "regex": ".* answer is:? (.*)\\.?", 67 | "match_all_on_failure": true, 68 | "remove_last_fullstop": true, 69 | } 70 | }, 71 | "reader": { 72 | "name": "multi_para_rc", 73 | "add_paras": false, 74 | "add_gold_paras": false, 75 | "add_pinned_paras": add_pinned_paras, 76 | }, 77 | "prediction_type": "answer" 78 | } -------------------------------------------------------------------------------- /base_configs/nor_qa_flan_t5_xl_musique.jsonnet: -------------------------------------------------------------------------------- 1 | # Set dataset: 2 | local dataset = "musique"; 3 | local retrieval_corpus_name = dataset; 4 | local add_pinned_paras = if dataset == "iirc" then true else false; 5 | local valid_qids = { 6 | "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"], 7 | "2wikimultihopqa": 
["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"], 8 | "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"], 9 | "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"], 10 | }[dataset]; 11 | local prompt_reader_args = { 12 | "filter_by_key_values": { 13 | "qid": valid_qids 14 | }, 15 | "order_by_key": "qid", 16 | "estimated_generation_length": 0, # don't drop in reading phase. 17 | "shuffle": false, 18 | "model_length_limit": 1000000, # don't drop in reading phase. 19 | "tokenizer_model_name": "google/flan-t5-xl", 20 | }; 21 | 22 | # (Potentially) Hyper-parameters: 23 | # null means it's unused. 24 | local llm_retrieval_count = null; 25 | local llm_map_count = null; 26 | local bm25_retrieval_count = null; 27 | local rc_context_type_ = "no"; # Choices: no, gold, gold_with_n_distractors 28 | local distractor_count = null; # Choices: 1, 2, 3 29 | local rc_context_type = ( 30 | if rc_context_type_ == "gold_with_n_distractors" 31 | then "gold_with_" + distractor_count + "_distractors" else rc_context_type_ 32 | ); 33 | local rc_qa_type = "direct"; # Choices: direct, cot 34 | local qa_question_prefix = ( 35 | if std.endsWith(rc_context_type, "cot") 36 | then "Answer the following question by reasoning step-by-step.\n" 37 | else "Answer the following question.\n" 38 | ); 39 | 40 | { 41 | "start_state": "generate_question", 42 | "end_state": "[EOQ]", 43 | "models": { 44 | "generate_question": { 45 | "name": "copy_question", 46 | "next_model": "generate_answer", 47 | "eoq_after_n_calls": 1, 48 | "end_state": "[EOQ]", 49 | }, 50 | "generate_answer": { 51 | "name": "llmqa", 52 | "next_model": if std.endsWith(rc_qa_type, "cot") then "extract_answer" else null, 53 | "prompt_file": "prompts/"+dataset+"/"+rc_context_type+"_context_"+rc_qa_type+"_qa_flan_t5.txt", 54 | "question_prefix": qa_question_prefix, 55 | "prompt_reader_args": prompt_reader_args, 56 | "end_state": "[EOQ]", 57 | "gen_model": "llm_api", 58 | "model_name": "google/flan-t5-xl", 59 | "model_tokens_limit": 6000, 60 | "max_length": 200, 61 | "add_context": add_pinned_paras, 62 | }, 63 | "extract_answer": { 64 | "name": "answer_extractor", 65 | "query_source": "last_answer", 66 | "regex": ".* answer is:? 
(.*)\\.?", 67 | "match_all_on_failure": true, 68 | "remove_last_fullstop": true, 69 | } 70 | }, 71 | "reader": { 72 | "name": "multi_para_rc", 73 | "add_paras": false, 74 | "add_gold_paras": false, 75 | "add_pinned_paras": add_pinned_paras, 76 | }, 77 | "prediction_type": "answer" 78 | } -------------------------------------------------------------------------------- /base_configs/nor_qa_flan_t5_xxl_2wikimultihopqa.jsonnet: -------------------------------------------------------------------------------- 1 | # Set dataset: 2 | local dataset = "2wikimultihopqa"; 3 | local retrieval_corpus_name = dataset; 4 | local add_pinned_paras = if dataset == "iirc" then true else false; 5 | local valid_qids = { 6 | "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"], 7 | "2wikimultihopqa": ["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"], 8 | "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"], 9 | "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"], 10 | }[dataset]; 11 | local prompt_reader_args = { 12 | "filter_by_key_values": { 13 | "qid": valid_qids 14 | }, 15 | "order_by_key": "qid", 16 | "estimated_generation_length": 0, # don't drop in reading phase. 17 | "shuffle": false, 18 | "model_length_limit": 1000000, # don't drop in reading phase. 19 | "tokenizer_model_name": "google/flan-t5-xxl", 20 | }; 21 | 22 | # (Potentially) Hyper-parameters: 23 | # null means it's unused. 
24 | local llm_retrieval_count = null; 25 | local llm_map_count = null; 26 | local bm25_retrieval_count = null; 27 | local rc_context_type_ = "no"; # Choices: no, gold, gold_with_n_distractors 28 | local distractor_count = null; # Choices: 1, 2, 3 29 | local rc_context_type = ( 30 | if rc_context_type_ == "gold_with_n_distractors" 31 | then "gold_with_" + distractor_count + "_distractors" else rc_context_type_ 32 | ); 33 | local rc_qa_type = "direct"; # Choices: direct, cot 34 | local qa_question_prefix = ( 35 | if std.endsWith(rc_context_type, "cot") 36 | then "Answer the following question by reasoning step-by-step.\n" 37 | else "Answer the following question.\n" 38 | ); 39 | 40 | { 41 | "start_state": "generate_question", 42 | "end_state": "[EOQ]", 43 | "models": { 44 | "generate_question": { 45 | "name": "copy_question", 46 | "next_model": "generate_answer", 47 | "eoq_after_n_calls": 1, 48 | "end_state": "[EOQ]", 49 | }, 50 | "generate_answer": { 51 | "name": "llmqa", 52 | "next_model": if std.endsWith(rc_qa_type, "cot") then "extract_answer" else null, 53 | "prompt_file": "prompts/"+dataset+"/"+rc_context_type+"_context_"+rc_qa_type+"_qa_flan_t5.txt", 54 | "question_prefix": qa_question_prefix, 55 | "prompt_reader_args": prompt_reader_args, 56 | "end_state": "[EOQ]", 57 | "gen_model": "llm_api", 58 | "model_name": "google/flan-t5-xxl", 59 | "model_tokens_limit": 6000, 60 | "max_length": 200, 61 | "add_context": add_pinned_paras, 62 | }, 63 | "extract_answer": { 64 | "name": "answer_extractor", 65 | "query_source": "last_answer", 66 | "regex": ".* answer is:? (.*)\\.?", 67 | "match_all_on_failure": true, 68 | "remove_last_fullstop": true, 69 | } 70 | }, 71 | "reader": { 72 | "name": "multi_para_rc", 73 | "add_paras": false, 74 | "add_gold_paras": false, 75 | "add_pinned_paras": add_pinned_paras, 76 | }, 77 | "prediction_type": "answer" 78 | } -------------------------------------------------------------------------------- /base_configs/nor_qa_flan_t5_xxl_hotpotqa.jsonnet: -------------------------------------------------------------------------------- 1 | # Set dataset: 2 | local dataset = "hotpotqa"; 3 | local retrieval_corpus_name = dataset; 4 | local add_pinned_paras = if dataset == "iirc" then true else false; 5 | local valid_qids = { 6 | "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"], 7 | "2wikimultihopqa": 
["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"], 8 | "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"], 9 | "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"], 10 | }[dataset]; 11 | local prompt_reader_args = { 12 | "filter_by_key_values": { 13 | "qid": valid_qids 14 | }, 15 | "order_by_key": "qid", 16 | "estimated_generation_length": 0, # don't drop in reading phase. 17 | "shuffle": false, 18 | "model_length_limit": 1000000, # don't drop in reading phase. 19 | "tokenizer_model_name": "google/flan-t5-xxl", 20 | }; 21 | 22 | # (Potentially) Hyper-parameters: 23 | # null means it's unused. 24 | local llm_retrieval_count = null; 25 | local llm_map_count = null; 26 | local bm25_retrieval_count = null; 27 | local rc_context_type_ = "no"; # Choices: no, gold, gold_with_n_distractors 28 | local distractor_count = null; # Choices: 1, 2, 3 29 | local rc_context_type = ( 30 | if rc_context_type_ == "gold_with_n_distractors" 31 | then "gold_with_" + distractor_count + "_distractors" else rc_context_type_ 32 | ); 33 | local rc_qa_type = "direct"; # Choices: direct, cot 34 | local qa_question_prefix = ( 35 | if std.endsWith(rc_context_type, "cot") 36 | then "Answer the following question by reasoning step-by-step.\n" 37 | else "Answer the following question.\n" 38 | ); 39 | 40 | { 41 | "start_state": "generate_question", 42 | "end_state": "[EOQ]", 43 | "models": { 44 | "generate_question": { 45 | "name": "copy_question", 46 | "next_model": "generate_answer", 47 | "eoq_after_n_calls": 1, 48 | "end_state": "[EOQ]", 49 | }, 50 | "generate_answer": { 51 | "name": "llmqa", 52 | "next_model": if std.endsWith(rc_qa_type, "cot") then "extract_answer" else null, 53 | "prompt_file": "prompts/"+dataset+"/"+rc_context_type+"_context_"+rc_qa_type+"_qa_flan_t5.txt", 54 | "question_prefix": qa_question_prefix, 55 | "prompt_reader_args": prompt_reader_args, 56 | "end_state": "[EOQ]", 57 | "gen_model": "llm_api", 58 | "model_name": "google/flan-t5-xxl", 59 | "model_tokens_limit": 6000, 60 | "max_length": 200, 61 | "add_context": add_pinned_paras, 62 | }, 63 | "extract_answer": { 64 | "name": "answer_extractor", 65 | "query_source": "last_answer", 66 | "regex": ".* answer is:? 
(.*)\\.?", 67 | "match_all_on_failure": true, 68 | "remove_last_fullstop": true, 69 | } 70 | }, 71 | "reader": { 72 | "name": "multi_para_rc", 73 | "add_paras": false, 74 | "add_gold_paras": false, 75 | "add_pinned_paras": add_pinned_paras, 76 | }, 77 | "prediction_type": "answer" 78 | } -------------------------------------------------------------------------------- /base_configs/nor_qa_flan_t5_xxl_iirc.jsonnet: -------------------------------------------------------------------------------- 1 | # Set dataset: 2 | local dataset = "iirc"; 3 | local retrieval_corpus_name = dataset; 4 | local add_pinned_paras = if dataset == "iirc" then true else false; 5 | local valid_qids = { 6 | "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"], 7 | "2wikimultihopqa": ["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"], 8 | "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"], 9 | "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"], 10 | }[dataset]; 11 | local prompt_reader_args = { 12 | "filter_by_key_values": { 13 | "qid": valid_qids 14 | }, 15 | "order_by_key": "qid", 16 | "estimated_generation_length": 0, # don't drop in reading phase. 17 | "shuffle": false, 18 | "model_length_limit": 1000000, # don't drop in reading phase. 19 | "tokenizer_model_name": "google/flan-t5-xxl", 20 | }; 21 | 22 | # (Potentially) Hyper-parameters: 23 | # null means it's unused. 
24 | local llm_retrieval_count = null; 25 | local llm_map_count = null; 26 | local bm25_retrieval_count = null; 27 | local rc_context_type_ = "no"; # Choices: no, gold, gold_with_n_distractors 28 | local distractor_count = null; # Choices: 1, 2, 3 29 | local rc_context_type = ( 30 | if rc_context_type_ == "gold_with_n_distractors" 31 | then "gold_with_" + distractor_count + "_distractors" else rc_context_type_ 32 | ); 33 | local rc_qa_type = "direct"; # Choices: direct, cot 34 | local qa_question_prefix = ( 35 | if std.endsWith(rc_context_type, "cot") 36 | then "Answer the following question by reasoning step-by-step.\n" 37 | else "Answer the following question.\n" 38 | ); 39 | 40 | { 41 | "start_state": "generate_question", 42 | "end_state": "[EOQ]", 43 | "models": { 44 | "generate_question": { 45 | "name": "copy_question", 46 | "next_model": "generate_answer", 47 | "eoq_after_n_calls": 1, 48 | "end_state": "[EOQ]", 49 | }, 50 | "generate_answer": { 51 | "name": "llmqa", 52 | "next_model": if std.endsWith(rc_qa_type, "cot") then "extract_answer" else null, 53 | "prompt_file": "prompts/"+dataset+"/"+rc_context_type+"_context_"+rc_qa_type+"_qa_flan_t5.txt", 54 | "question_prefix": qa_question_prefix, 55 | "prompt_reader_args": prompt_reader_args, 56 | "end_state": "[EOQ]", 57 | "gen_model": "llm_api", 58 | "model_name": "google/flan-t5-xxl", 59 | "model_tokens_limit": 6000, 60 | "max_length": 200, 61 | "add_context": add_pinned_paras, 62 | }, 63 | "extract_answer": { 64 | "name": "answer_extractor", 65 | "query_source": "last_answer", 66 | "regex": ".* answer is:? (.*)\\.?", 67 | "match_all_on_failure": true, 68 | "remove_last_fullstop": true, 69 | } 70 | }, 71 | "reader": { 72 | "name": "multi_para_rc", 73 | "add_paras": false, 74 | "add_gold_paras": false, 75 | "add_pinned_paras": add_pinned_paras, 76 | }, 77 | "prediction_type": "answer" 78 | } -------------------------------------------------------------------------------- /base_configs/nor_qa_flan_t5_xxl_musique.jsonnet: -------------------------------------------------------------------------------- 1 | # Set dataset: 2 | local dataset = "musique"; 3 | local retrieval_corpus_name = dataset; 4 | local add_pinned_paras = if dataset == "iirc" then true else false; 5 | local valid_qids = { 6 | "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"], 7 | "2wikimultihopqa": 
["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"], 8 | "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"], 9 | "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"], 10 | }[dataset]; 11 | local prompt_reader_args = { 12 | "filter_by_key_values": { 13 | "qid": valid_qids 14 | }, 15 | "order_by_key": "qid", 16 | "estimated_generation_length": 0, # don't drop in reading phase. 17 | "shuffle": false, 18 | "model_length_limit": 1000000, # don't drop in reading phase. 19 | "tokenizer_model_name": "google/flan-t5-xxl", 20 | }; 21 | 22 | # (Potentially) Hyper-parameters: 23 | # null means it's unused. 24 | local llm_retrieval_count = null; 25 | local llm_map_count = null; 26 | local bm25_retrieval_count = null; 27 | local rc_context_type_ = "no"; # Choices: no, gold, gold_with_n_distractors 28 | local distractor_count = null; # Choices: 1, 2, 3 29 | local rc_context_type = ( 30 | if rc_context_type_ == "gold_with_n_distractors" 31 | then "gold_with_" + distractor_count + "_distractors" else rc_context_type_ 32 | ); 33 | local rc_qa_type = "direct"; # Choices: direct, cot 34 | local qa_question_prefix = ( 35 | if std.endsWith(rc_context_type, "cot") 36 | then "Answer the following question by reasoning step-by-step.\n" 37 | else "Answer the following question.\n" 38 | ); 39 | 40 | { 41 | "start_state": "generate_question", 42 | "end_state": "[EOQ]", 43 | "models": { 44 | "generate_question": { 45 | "name": "copy_question", 46 | "next_model": "generate_answer", 47 | "eoq_after_n_calls": 1, 48 | "end_state": "[EOQ]", 49 | }, 50 | "generate_answer": { 51 | "name": "llmqa", 52 | "next_model": if std.endsWith(rc_qa_type, "cot") then "extract_answer" else null, 53 | "prompt_file": "prompts/"+dataset+"/"+rc_context_type+"_context_"+rc_qa_type+"_qa_flan_t5.txt", 54 | "question_prefix": qa_question_prefix, 55 | "prompt_reader_args": prompt_reader_args, 56 | "end_state": "[EOQ]", 57 | "gen_model": "llm_api", 58 | "model_name": "google/flan-t5-xxl", 59 | "model_tokens_limit": 6000, 60 | "max_length": 200, 61 | "add_context": add_pinned_paras, 62 | }, 63 | "extract_answer": { 64 | "name": "answer_extractor", 65 | "query_source": "last_answer", 66 | "regex": ".* answer is:? 
(.*)\\.?", 67 | "match_all_on_failure": true, 68 | "remove_last_fullstop": true, 69 | } 70 | }, 71 | "reader": { 72 | "name": "multi_para_rc", 73 | "add_paras": false, 74 | "add_gold_paras": false, 75 | "add_pinned_paras": add_pinned_paras, 76 | }, 77 | "prediction_type": "answer" 78 | } -------------------------------------------------------------------------------- /base_configs/oner_2wikimultihopqa.jsonnet: -------------------------------------------------------------------------------- 1 | # Set dataset: 2 | local dataset = "2wikimultihopqa"; 3 | local retrieval_corpus_name = dataset; 4 | local add_pinned_paras = if dataset == "iirc" then true else false; 5 | local valid_qids = { 6 | "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"], 7 | "2wikimultihopqa": ["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"], 8 | "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"], 9 | "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"], 10 | }[dataset]; 11 | local prompt_reader_args = { 12 | "filter_by_key_values": { 13 | "qid": valid_qids 14 | }, 15 | "order_by_key": "qid", 16 | "estimated_generation_length": 300, 17 | "shuffle": false, 18 | "model_length_limit": 8000, 19 | }; 20 | 21 | # (Potentially) Hyper-parameters: 22 | # null means it's unused. 
23 | local llm_retrieval_count = null; 24 | local llm_map_count = null; 25 | local bm25_retrieval_count = 5; 26 | local rc_context_type_ = null; # Choices: no, gold, gold_with_n_distractors 27 | local distractor_count = null; # Choices: 1, 2, 3 28 | local rc_context_type = ( 29 | if rc_context_type_ == "gold_with_n_distractors" 30 | then "gold_with_" + distractor_count + "_distractors" else rc_context_type_ 31 | ); 32 | local rc_qa_type = null; # Choices: direct, cot 33 | 34 | { 35 | "start_state": "retrieve_and_reset_paragraphs", 36 | "end_state": "[EOQ]", 37 | "models": { 38 | "retrieve_and_reset_paragraphs": { 39 | "name": "retrieve_and_reset_paragraphs", 40 | "retrieval_type": "bm25", 41 | "retriever_host": std.extVar("RETRIEVER_HOST"), 42 | "retriever_port": std.extVar("RETRIEVER_PORT"), 43 | "retrieval_count": bm25_retrieval_count, 44 | "global_max_num_paras": 15, 45 | "query_source": "original_question", 46 | "source_corpus_name": retrieval_corpus_name, 47 | "document_type": "title_paragraph_text", 48 | "return_pids": true, 49 | "end_state": "[EOQ]", 50 | }, 51 | }, 52 | "reader": { 53 | "name": "multi_para_rc", 54 | "add_paras": false, 55 | "add_gold_paras": false, 56 | "add_pinned_paras": add_pinned_paras, 57 | }, 58 | "prediction_type": "pids" 59 | } -------------------------------------------------------------------------------- /base_configs/oner_hotpotqa.jsonnet: -------------------------------------------------------------------------------- 1 | # Set dataset: 2 | local dataset = "hotpotqa"; 3 | local retrieval_corpus_name = dataset; 4 | local add_pinned_paras = if dataset == "iirc" then true else false; 5 | local valid_qids = { 6 | "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"], 7 | "2wikimultihopqa": ["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"], 8 | "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"], 9 | "iirc": 
["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"], 10 | }[dataset]; 11 | local prompt_reader_args = { 12 | "filter_by_key_values": { 13 | "qid": valid_qids 14 | }, 15 | "order_by_key": "qid", 16 | "estimated_generation_length": 300, 17 | "shuffle": false, 18 | "model_length_limit": 8000, 19 | }; 20 | 21 | # (Potentially) Hyper-parameters: 22 | # null means it's unused. 23 | local llm_retrieval_count = null; 24 | local llm_map_count = null; 25 | local bm25_retrieval_count = 5; 26 | local rc_context_type_ = null; # Choices: no, gold, gold_with_n_distractors 27 | local distractor_count = null; # Choices: 1, 2, 3 28 | local rc_context_type = ( 29 | if rc_context_type_ == "gold_with_n_distractors" 30 | then "gold_with_" + distractor_count + "_distractors" else rc_context_type_ 31 | ); 32 | local rc_qa_type = null; # Choices: direct, cot 33 | 34 | { 35 | "start_state": "retrieve_and_reset_paragraphs", 36 | "end_state": "[EOQ]", 37 | "models": { 38 | "retrieve_and_reset_paragraphs": { 39 | "name": "retrieve_and_reset_paragraphs", 40 | "retrieval_type": "bm25", 41 | "retriever_host": std.extVar("RETRIEVER_HOST"), 42 | "retriever_port": std.extVar("RETRIEVER_PORT"), 43 | "retrieval_count": bm25_retrieval_count, 44 | "global_max_num_paras": 15, 45 | "query_source": "original_question", 46 | "source_corpus_name": retrieval_corpus_name, 47 | "document_type": "title_paragraph_text", 48 | "return_pids": true, 49 | "end_state": "[EOQ]", 50 | }, 51 | }, 52 | "reader": { 53 | "name": "multi_para_rc", 54 | "add_paras": false, 55 | "add_gold_paras": false, 56 | "add_pinned_paras": add_pinned_paras, 57 | }, 58 | "prediction_type": "pids" 59 | } -------------------------------------------------------------------------------- /base_configs/oner_musique.jsonnet: -------------------------------------------------------------------------------- 1 | # Set dataset: 2 | local dataset = "musique"; 3 | local retrieval_corpus_name = dataset; 4 | local add_pinned_paras = if dataset == "iirc" then true else false; 5 | local valid_qids = { 6 | "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"], 7 | "2wikimultihopqa": ["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"], 8 | "musique": 
["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"], 9 | "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"], 10 | }[dataset]; 11 | local prompt_reader_args = { 12 | "filter_by_key_values": { 13 | "qid": valid_qids 14 | }, 15 | "order_by_key": "qid", 16 | "estimated_generation_length": 300, 17 | "shuffle": false, 18 | "model_length_limit": 8000, 19 | }; 20 | 21 | # (Potentially) Hyper-parameters: 22 | # null means it's unused. 23 | local llm_retrieval_count = null; 24 | local llm_map_count = null; 25 | local bm25_retrieval_count = 5; 26 | local rc_context_type_ = null; # Choices: no, gold, gold_with_n_distractors 27 | local distractor_count = null; # Choices: 1, 2, 3 28 | local rc_context_type = ( 29 | if rc_context_type_ == "gold_with_n_distractors" 30 | then "gold_with_" + distractor_count + "_distractors" else rc_context_type_ 31 | ); 32 | local rc_qa_type = null; # Choices: direct, cot 33 | 34 | { 35 | "start_state": "retrieve_and_reset_paragraphs", 36 | "end_state": "[EOQ]", 37 | "models": { 38 | "retrieve_and_reset_paragraphs": { 39 | "name": "retrieve_and_reset_paragraphs", 40 | "retrieval_type": "bm25", 41 | "retriever_host": std.extVar("RETRIEVER_HOST"), 42 | "retriever_port": std.extVar("RETRIEVER_PORT"), 43 | "retrieval_count": bm25_retrieval_count, 44 | "global_max_num_paras": 15, 45 | "query_source": "original_question", 46 | "source_corpus_name": retrieval_corpus_name, 47 | "document_type": "title_paragraph_text", 48 | "return_pids": true, 49 | "end_state": "[EOQ]", 50 | }, 51 | }, 52 | "reader": { 53 | "name": "multi_para_rc", 54 | "add_paras": false, 55 | "add_gold_paras": false, 56 | "add_pinned_paras": add_pinned_paras, 57 | }, 58 | "prediction_type": "pids" 59 | } -------------------------------------------------------------------------------- /base_configs/oner_qa_codex_2wikimultihopqa.jsonnet: -------------------------------------------------------------------------------- 1 | # Set dataset: 2 | local dataset = "2wikimultihopqa"; 3 | local retrieval_corpus_name = dataset; 4 | local add_pinned_paras = if dataset == "iirc" then true else false; 5 | local valid_qids = { 6 | "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"], 7 | "2wikimultihopqa": 
["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"], 8 | "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"], 9 | "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"], 10 | }[dataset]; 11 | local prompt_reader_args = { 12 | "filter_by_key_values": { 13 | "qid": valid_qids 14 | }, 15 | "order_by_key": "qid", 16 | "estimated_generation_length": 300, 17 | "shuffle": false, 18 | "model_length_limit": 8000, 19 | }; 20 | 21 | # (Potentially) Hyper-parameters: 22 | # null means it's unused. 23 | local llm_retrieval_count = null; 24 | local llm_map_count = null; 25 | local bm25_retrieval_count = 5; 26 | local rc_context_type_ = "gold_with_n_distractors"; # Choices: no, gold, gold_with_n_distractors 27 | local distractor_count = "2"; # Choices: 1, 2, 3 28 | local rc_context_type = ( 29 | if rc_context_type_ == "gold_with_n_distractors" 30 | then "gold_with_" + distractor_count + "_distractors" else rc_context_type_ 31 | ); 32 | local rc_qa_type = "cot"; # Choices: direct, cot 33 | 34 | { 35 | "start_state": "generate_titles", 36 | "end_state": "[EOQ]", 37 | "models": { 38 | "generate_titles": { 39 | "name": "retrieve_and_reset_paragraphs", 40 | "next_model": "generate_main_question", 41 | "retrieval_type": "bm25", 42 | "retriever_host": std.extVar("RETRIEVER_HOST"), 43 | "retriever_port": std.extVar("RETRIEVER_PORT"), 44 | "retrieval_count": bm25_retrieval_count, 45 | "global_max_num_paras": 15, 46 | "query_source": "original_question", 47 | "source_corpus_name": retrieval_corpus_name, 48 | "document_type": "title_paragraph_text", 49 | "end_state": "[EOQ]", 50 | }, 51 | 52 | "generate_main_question": { 53 | "name": "copy_question", 54 | "next_model": "answer_main_question", 55 | "eoq_after_n_calls": 1, 56 | "end_state": "[EOQ]", 57 | }, 58 | "answer_main_question": { 59 | "name": "llmqa", 60 | "next_model": if std.endsWith(rc_qa_type, "cot") then "extract_answer" else null, 61 | "prompt_file": "prompts/"+dataset+"/"+rc_context_type+"_context_"+rc_qa_type+"_qa_codex.txt", 62 | "prompt_reader_args": prompt_reader_args, 63 | "end_state": "[EOQ]", 64 | "gen_model": "gpt3", 65 | "engine": "code-davinci-002", 66 | "retry_after_n_seconds": 50, 67 | "add_context": true, 68 | }, 69 | "extract_answer": { 70 | "name": "answer_extractor", 71 | "query_source": "last_answer", 72 | "regex": ".* answer 
is:? (.*)\\.?", 73 | "match_all_on_failure": true, 74 | "remove_last_fullstop": true, 75 | } 76 | }, 77 | "reader": { 78 | "name": "multi_para_rc", 79 | "add_paras": false, 80 | "add_gold_paras": false, 81 | "add_pinned_paras": add_pinned_paras, 82 | }, 83 | "prediction_type": "answer" 84 | } -------------------------------------------------------------------------------- /base_configs/oner_qa_codex_hotpotqa.jsonnet: -------------------------------------------------------------------------------- 1 | # Set dataset: 2 | local dataset = "hotpotqa"; 3 | local retrieval_corpus_name = dataset; 4 | local add_pinned_paras = if dataset == "iirc" then true else false; 5 | local valid_qids = { 6 | "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"], 7 | "2wikimultihopqa": ["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"], 8 | "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"], 9 | "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"], 10 | }[dataset]; 11 | local prompt_reader_args = { 12 | "filter_by_key_values": { 13 | "qid": valid_qids 14 | }, 15 | "order_by_key": "qid", 16 | "estimated_generation_length": 300, 17 | "shuffle": false, 18 | "model_length_limit": 8000, 19 | }; 20 | 21 | # (Potentially) Hyper-parameters: 22 | # null means it's unused. 
23 | local llm_retrieval_count = null; 24 | local llm_map_count = null; 25 | local bm25_retrieval_count = 5; 26 | local rc_context_type_ = "gold_with_n_distractors"; # Choices: no, gold, gold_with_n_distractors 27 | local distractor_count = "2"; # Choices: 1, 2, 3 28 | local rc_context_type = ( 29 | if rc_context_type_ == "gold_with_n_distractors" 30 | then "gold_with_" + distractor_count + "_distractors" else rc_context_type_ 31 | ); 32 | local rc_qa_type = "cot"; # Choices: direct, cot 33 | 34 | { 35 | "start_state": "generate_titles", 36 | "end_state": "[EOQ]", 37 | "models": { 38 | "generate_titles": { 39 | "name": "retrieve_and_reset_paragraphs", 40 | "next_model": "generate_main_question", 41 | "retrieval_type": "bm25", 42 | "retriever_host": std.extVar("RETRIEVER_HOST"), 43 | "retriever_port": std.extVar("RETRIEVER_PORT"), 44 | "retrieval_count": bm25_retrieval_count, 45 | "global_max_num_paras": 15, 46 | "query_source": "original_question", 47 | "source_corpus_name": retrieval_corpus_name, 48 | "document_type": "title_paragraph_text", 49 | "end_state": "[EOQ]", 50 | }, 51 | 52 | "generate_main_question": { 53 | "name": "copy_question", 54 | "next_model": "answer_main_question", 55 | "eoq_after_n_calls": 1, 56 | "end_state": "[EOQ]", 57 | }, 58 | "answer_main_question": { 59 | "name": "llmqa", 60 | "next_model": if std.endsWith(rc_qa_type, "cot") then "extract_answer" else null, 61 | "prompt_file": "prompts/"+dataset+"/"+rc_context_type+"_context_"+rc_qa_type+"_qa_codex.txt", 62 | "prompt_reader_args": prompt_reader_args, 63 | "end_state": "[EOQ]", 64 | "gen_model": "gpt3", 65 | "engine": "code-davinci-002", 66 | "retry_after_n_seconds": 50, 67 | "add_context": true, 68 | }, 69 | "extract_answer": { 70 | "name": "answer_extractor", 71 | "query_source": "last_answer", 72 | "regex": ".* answer is:? 
(.*)\\.?", 73 | "match_all_on_failure": true, 74 | "remove_last_fullstop": true, 75 | } 76 | }, 77 | "reader": { 78 | "name": "multi_para_rc", 79 | "add_paras": false, 80 | "add_gold_paras": false, 81 | "add_pinned_paras": add_pinned_paras, 82 | }, 83 | "prediction_type": "answer" 84 | } -------------------------------------------------------------------------------- /base_configs/oner_qa_codex_musique.jsonnet: -------------------------------------------------------------------------------- 1 | # Set dataset: 2 | local dataset = "musique"; 3 | local retrieval_corpus_name = dataset; 4 | local add_pinned_paras = if dataset == "iirc" then true else false; 5 | local valid_qids = { 6 | "hotpotqa": ["5ab92dba554299131ca422a2","5a7bbc50554299042af8f7d0","5add363c5542990dbb2f7dc8","5a835abe5542996488c2e426","5ae0185b55429942ec259c1b","5a790e7855429970f5fffe3d","5a754ab35542993748c89819","5a89c14f5542993b751ca98a","5abb14bd5542992ccd8e7f07","5a89d58755429946c8d6e9d9","5a88f9d55542995153361218","5a90620755429933b8a20508","5a77acab5542992a6e59df76","5abfb3435542990832d3a1c1","5a8f44ab5542992414482a25","5adfad0c554299603e41835a","5a7fc53555429969796c1b55","5a8ed9f355429917b4a5bddd","5ac2ada5554299657fa2900d","5a758ea55542992db9473680"], 7 | "2wikimultihopqa": ["5811079c0bdc11eba7f7acde48001122","97954d9408b011ebbd84ac1f6bf848b6","35bf3490096d11ebbdafac1f6bf848b6","c6805b2908a911ebbd80ac1f6bf848b6","5897ec7a086c11ebbd61ac1f6bf848b6","e5150a5a0bda11eba7f7acde48001122","a5995da508ab11ebbd82ac1f6bf848b6","cdbb82ec0baf11ebab90acde48001122","f44939100bda11eba7f7acde48001122","4724c54e08e011ebbda1ac1f6bf848b6","f86b4a28091711ebbdaeac1f6bf848b6","13cda43c09b311ebbdb0ac1f6bf848b6","228546780bdd11eba7f7acde48001122","c6f63bfb089e11ebbd78ac1f6bf848b6","1ceeab380baf11ebab90acde48001122","8727d1280bdc11eba7f7acde48001122","f1ccdfee094011ebbdaeac1f6bf848b6","79a863dc0bdc11eba7f7acde48001122","028eaef60bdb11eba7f7acde48001122","af8c6722088b11ebbd6fac1f6bf848b6"], 8 | "musique": ["2hop__323282_79175","2hop__292995_8796","2hop__439265_539716","4hop3__703974_789671_24078_24137","2hop__154225_727337","2hop__861128_15822","3hop1__858730_386977_851569","2hop__642271_608104","2hop__387702_20661","2hop__131516_53573","2hop__496817_701819","2hop__804754_52230","3hop1__61746_67065_43617","3hop1__753524_742157_573834","2hop__427213_79175","3hop1__443556_763924_573834","2hop__782642_52667","2hop__102217_58400","2hop__195347_20661","4hop3__463724_100414_35260_54090"], 9 | "iirc": ["q_10236","q_3268","q_8776","q_9499","q_389","q_8350","q_3283","q_3208","q_1672","q_9433","q_8173","q_8981","q_10227","q_2466","q_8736","q_9591","q_10344","q_10270","q_9518","q_3290"], 10 | }[dataset]; 11 | local prompt_reader_args = { 12 | "filter_by_key_values": { 13 | "qid": valid_qids 14 | }, 15 | "order_by_key": "qid", 16 | "estimated_generation_length": 300, 17 | "shuffle": false, 18 | "model_length_limit": 8000, 19 | }; 20 | 21 | # (Potentially) Hyper-parameters: 22 | # null means it's unused. 
23 | local llm_retrieval_count = null; 24 | local llm_map_count = null; 25 | local bm25_retrieval_count = 5; 26 | local rc_context_type_ = "gold_with_n_distractors"; # Choices: no, gold, gold_with_n_distractors 27 | local distractor_count = "2"; # Choices: 1, 2, 3 28 | local rc_context_type = ( 29 | if rc_context_type_ == "gold_with_n_distractors" 30 | then "gold_with_" + distractor_count + "_distractors" else rc_context_type_ 31 | ); 32 | local rc_qa_type = "cot"; # Choices: direct, cot 33 | 34 | { 35 | "start_state": "generate_titles", 36 | "end_state": "[EOQ]", 37 | "models": { 38 | "generate_titles": { 39 | "name": "retrieve_and_reset_paragraphs", 40 | "next_model": "generate_main_question", 41 | "retrieval_type": "bm25", 42 | "retriever_host": std.extVar("RETRIEVER_HOST"), 43 | "retriever_port": std.extVar("RETRIEVER_PORT"), 44 | "retrieval_count": bm25_retrieval_count, 45 | "global_max_num_paras": 15, 46 | "query_source": "original_question", 47 | "source_corpus_name": retrieval_corpus_name, 48 | "document_type": "title_paragraph_text", 49 | "end_state": "[EOQ]", 50 | }, 51 | 52 | "generate_main_question": { 53 | "name": "copy_question", 54 | "next_model": "answer_main_question", 55 | "eoq_after_n_calls": 1, 56 | "end_state": "[EOQ]", 57 | }, 58 | "answer_main_question": { 59 | "name": "llmqa", 60 | "next_model": if std.endsWith(rc_qa_type, "cot") then "extract_answer" else null, 61 | "prompt_file": "prompts/"+dataset+"/"+rc_context_type+"_context_"+rc_qa_type+"_qa_codex.txt", 62 | "prompt_reader_args": prompt_reader_args, 63 | "end_state": "[EOQ]", 64 | "gen_model": "gpt3", 65 | "engine": "code-davinci-002", 66 | "retry_after_n_seconds": 50, 67 | "add_context": true, 68 | }, 69 | "extract_answer": { 70 | "name": "answer_extractor", 71 | "query_source": "last_answer", 72 | "regex": ".* answer is:? 
(.*)\\.?", 73 | "match_all_on_failure": true, 74 | "remove_last_fullstop": true, 75 | } 76 | }, 77 | "reader": { 78 | "name": "multi_para_rc", 79 | "add_paras": false, 80 | "add_gold_paras": false, 81 | "add_pinned_paras": add_pinned_paras, 82 | }, 83 | "prediction_type": "answer" 84 | } -------------------------------------------------------------------------------- /commaqa/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StonyBrookNLP/ircot/3c1820f698eea5eeddb4fba3c56b64c961e063e4/commaqa/configs/__init__.py -------------------------------------------------------------------------------- /commaqa/configs/dataset_build_config.py: -------------------------------------------------------------------------------- 1 | from commaqa.configs.entities_config import EntitiesConfig 2 | from commaqa.configs.predicate_config import PredicateConfig 3 | from commaqa.configs.predicate_language_config import PredicateLanguageConfig 4 | from commaqa.configs.theory_config import TheoryConfig 5 | 6 | 7 | class DatasetBuildConfig: 8 | def __init__(self, input_json): 9 | self.version = input_json["version"] 10 | self.entities = EntitiesConfig(input_json["entities"]) 11 | self.predicates = [PredicateConfig(x) for x in input_json["predicates"].items()] 12 | self.theories = [TheoryConfig(x) for x in input_json["theories"]] 13 | self.pred_lang_config = PredicateLanguageConfig(input_json["predicate_language"]) 14 | -------------------------------------------------------------------------------- /commaqa/configs/entities_config.py: -------------------------------------------------------------------------------- 1 | import random 2 | from math import ceil 3 | from typing import Dict, List 4 | 5 | 6 | class EntitiesConfig: 7 | def __init__(self, entities_json: Dict[str, List[str]]): 8 | self.entity_type_map = entities_json 9 | 10 | def subsample(self, num_ents): 11 | new_ent_map = {} 12 | for etype, elist in self.entity_type_map.items(): 13 | # if fraction passed, sample ratio 14 | if num_ents <= 1: 15 | new_ent_map[etype] = random.sample(elist, ceil(len(elist) * num_ents)) 16 | else: 17 | new_ent_map[etype] = random.sample(elist, num_ents) 18 | 19 | return EntitiesConfig(new_ent_map) 20 | 21 | def __getitem__(self, item: str): 22 | return self.entity_type_map[item] 23 | -------------------------------------------------------------------------------- /commaqa/configs/predicate_language_config.py: -------------------------------------------------------------------------------- 1 | from commaqa.configs.step_config import StepConfig 2 | from commaqa.dataset.utils import get_predicate_args 3 | 4 | 5 | class ModelQuestionConfig: 6 | def __init__(self, config_json): 7 | self.steps = [StepConfig(x) for x in config_json["steps"]] if "steps" in config_json else [] 8 | self.questions = config_json.get("questions") 9 | self.init = config_json["init"] 10 | self.model = config_json["model"] 11 | self.predicate = config_json["predicate"] 12 | 13 | def to_json(self): 14 | return { 15 | "steps": [x.to_json() for x in self.steps], 16 | "questions": self.questions, 17 | "init": self.init, 18 | "model": self.model, 19 | "predicate": self.predicate, 20 | } 21 | 22 | 23 | class PredicateLanguageConfig: 24 | def __init__(self, pred_lang_config): 25 | # import json 26 | # print(json.dumps(pred_lang_config, indent=2)) 27 | self.predicate_config = {} 28 | self.model_config = {} 29 | for predicate, config in pred_lang_config.items(): 30 | 
config["predicate"] = predicate 31 | question_config = ModelQuestionConfig(config) 32 | self.predicate_config[predicate] = question_config 33 | model = config["model"] 34 | if model not in self.model_config: 35 | self.model_config[model] = [] 36 | self.model_config[model].append(question_config) 37 | 38 | def model_config_as_json(self): 39 | return {model: [config.to_json() for config in configs] for model, configs in self.model_config.items()} 40 | 41 | def find_model(self, question_predicate): 42 | matching_configs = self.find_valid_configs(question_predicate) 43 | if len(matching_configs) == 0: 44 | return None 45 | matching_models = {x.model for x in matching_configs} 46 | if len(matching_models) != 1: 47 | raise ValueError( 48 | "Unexpected number of matching models: {} for {}. " 49 | "Expected one model".format(matching_models, question_predicate) 50 | ) 51 | return matching_models.pop() 52 | 53 | def find_valid_configs(self, question_predicate): 54 | qpred, qargs = get_predicate_args(question_predicate) 55 | matching_configs = [] 56 | for key, config in self.predicate_config.items(): 57 | config_qpred, config_qargs = get_predicate_args(key) 58 | if config_qpred == qpred: 59 | assert len(qargs) == len(config_qargs), "{} {}\n{}".format(qargs, config_qargs, question_predicate) 60 | mismatch = False 61 | for qarg, cqarg in zip(qargs, config_qargs): 62 | if (cqarg == "?") ^ (qarg == "?"): 63 | mismatch = True 64 | if not mismatch: 65 | matching_configs.append(config) 66 | return matching_configs 67 | -------------------------------------------------------------------------------- /commaqa/configs/step_config.py: -------------------------------------------------------------------------------- 1 | class StepConfig: 2 | def __init__(self, step_json): 3 | self.operation = step_json["operation"] 4 | self.question = step_json["question"] 5 | self.answer = step_json["answer"] 6 | 7 | def to_json(self): 8 | return self.__dict__ 9 | -------------------------------------------------------------------------------- /commaqa/configs/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from copy import deepcopy 3 | from typing import List, Dict 4 | 5 | from commaqa.configs.predicate_language_config import PredicateLanguageConfig 6 | from commaqa.configs.step_config import StepConfig 7 | from commaqa.dataset.utils import is_question_var, nonempty_answer 8 | from commaqa.execution.operation_executer import OperationExecuter 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def execute_steps( 14 | steps: List[StepConfig], 15 | input_assignments: Dict[str, str], 16 | executer: OperationExecuter, 17 | pred_lang_config: PredicateLanguageConfig = None, 18 | input_model: str = None, 19 | ): 20 | curr_assignment = deepcopy(input_assignments) 21 | if "facts_used" not in curr_assignment: 22 | curr_assignment["facts_used"] = [] 23 | 24 | for step in steps: 25 | if input_model is None: 26 | model = pred_lang_config.find_model(step.question) 27 | if model is None: 28 | raise ValueError("No model found for {}".format(step.question)) 29 | else: 30 | model = input_model 31 | 32 | new_question = step.question 33 | for k, v in curr_assignment.items(): 34 | # only replace question variables($1, $2). 
Answer variables (#1, #2) are handled by the executer. 35 | if is_question_var(k): 36 | new_question = new_question.replace(k, v) 37 | answers, curr_facts = executer.execute_operation( 38 | operation=step.operation, model=model, question=new_question, assignments=curr_assignment 39 | ) 40 | if answers is None: 41 | # execution failed 42 | return None 43 | elif nonempty_answer(answers): 44 | curr_assignment[step.answer] = answers 45 | curr_assignment["facts_used"].extend(curr_facts) 46 | else: 47 | logger.debug( 48 | "Stopped Execution. Empty answer: {}\n" 49 | "Question: {}\n Step: {}\n Assignment: {}".format( 50 | answers, new_question, step.to_json(), curr_assignment 51 | ) 52 | ) 53 | return {} 54 | return curr_assignment 55 | -------------------------------------------------------------------------------- /commaqa/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StonyBrookNLP/ircot/3c1820f698eea5eeddb4fba3c56b64c961e063e4/commaqa/dataset/__init__.py -------------------------------------------------------------------------------- /commaqa/dataset/generate_decomposition_predictions.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from copy import deepcopy 4 | from math import ceil 5 | from random import shuffle 6 | 7 | from commaqa.configs.predicate_language_config import ModelQuestionConfig 8 | from commaqa.dataset.utils import nonempty_answer 9 | from commaqa.execution.operation_executer import OperationExecuter 10 | from commaqa.execution.utils import build_models 11 | 12 | 13 | def parse_arguments(): 14 | arg_parser = argparse.ArgumentParser(description="Solve a ReModeL dataset using composition") 15 | arg_parser.add_argument("--input_json", type=str, required=True, help="Input JSON dataset files") 16 | arg_parser.add_argument("--pred_json", type=str, required=False, help="Output predictions") 17 | arg_parser.add_argument("--decomp_json", type=str, required=False, help="Output decompositions") 18 | arg_parser.add_argument( 19 | "--max_examples", 20 | type=float, 21 | required=False, 22 | default=1.0, 23 | help="Maximum number of examples to use. 
" "If set to <=1.0, use as fraction.", 24 | ) 25 | return arg_parser.parse_args() 26 | 27 | 28 | def build_chain(prev_chain, operation, model, question): 29 | return prev_chain + " QS: ({}) [{}] {}".format(operation, model, question) 30 | 31 | 32 | if __name__ == "__main__": 33 | args = parse_arguments() 34 | with open(args.input_json, "r") as input_fp: 35 | input_json = json.load(input_fp) 36 | 37 | pred_json = {} 38 | decomp_json = [] 39 | for input_item in input_json: 40 | kb = input_item["kb"] 41 | model_configurations = {} 42 | for model_name, configs in input_item["pred_lang_config"].items(): 43 | model_configurations[model_name] = [ModelQuestionConfig(config) for config in configs] 44 | model_lib = build_models(model_configurations, kb) 45 | 46 | executor = OperationExecuter(model_lib) 47 | for qa_pair in input_item["qa_pairs"]: 48 | qid = qa_pair["id"] 49 | # use oracle decomposition 50 | curr_assignment = {} 51 | last_answer = "" 52 | train_seqs = [] 53 | prev_chain = " QC: " + qa_pair["question"] 54 | for idx, step in enumerate(qa_pair["decomposition"]): 55 | train_seq = build_chain( 56 | prev_chain=prev_chain, operation=step["op"], model=step["m"], question=step["q"] 57 | ) 58 | train_seqs.append(train_seq) 59 | answers, facts_used = executor.execute_operation( 60 | operation=step["op"], model=step["m"], question=step["q"], assignments=curr_assignment 61 | ) 62 | last_answer = answers 63 | if not nonempty_answer(answers): 64 | print("no answer!") 65 | print(step, curr_assignment, kb) 66 | break 67 | prev_chain = train_seq.replace(" QS: ", " QI: ") + " A: " + json.dumps(answers) 68 | curr_assignment["#" + str(idx + 1)] = answers 69 | train_seqs.append(prev_chain + " QS: [EOQ]") 70 | decomp = deepcopy(qa_pair) 71 | decomp["train_seqs"] = train_seqs 72 | decomp_json.append(decomp) 73 | if isinstance(last_answer, list): 74 | pred_json[qid] = last_answer 75 | else: 76 | pred_json[qid] = str(last_answer) 77 | 78 | if args.pred_json: 79 | with open(args.pred_json, "w") as output_fp: 80 | json.dump(pred_json, output_fp, indent=2) 81 | if args.decomp_json: 82 | # sample examples here as they will be ungrouped 83 | if args.max_examples < 1.0: 84 | shuffle(decomp_json) 85 | decomp_json = decomp_json[: ceil(len(decomp_json) * args.max_examples)] 86 | elif args.max_examples > 1.0: 87 | shuffle(decomp_json) 88 | decomp_json = decomp_json[: args.max_examples] 89 | 90 | with open(args.decomp_json, "w") as output_fp: 91 | for decomp in decomp_json: 92 | output_fp.write(json.dumps(decomp) + "\n") 93 | -------------------------------------------------------------------------------- /commaqa/dataset/utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import re 3 | 4 | pred_match = re.compile("(.*)\((.*)\)$") 5 | 6 | 7 | def get_answer_indices(question_str): 8 | return [int(m.group(1)) for m in re.finditer("#(\d)", question_str)] 9 | 10 | 11 | def get_question_indices(question_str): 12 | return [int(m.group(1)) for m in re.finditer("\$(\d)", question_str)] 13 | 14 | 15 | def is_question_var(var_name): 16 | return var_name.startswith("$") 17 | 18 | 19 | def get_predicate_args(predicate_str): 20 | mat = pred_match.match(predicate_str) 21 | if mat is None: 22 | return None, None 23 | predicate = mat.group(1) 24 | pred_args = mat.group(2).split(", ") if " | " not in mat.group(2) else mat.group(2).split(" | ") 25 | return predicate, pred_args 26 | 27 | 28 | def flatten_list(input_list): 29 | output_list = [] 30 | for item in input_list: 31 | 
if isinstance(item, list): 32 | output_list.extend(flatten_list(item)) 33 | else: 34 | output_list.append(item) 35 | return output_list 36 | 37 | 38 | def align_assignments(target_predicate, source_predicate, source_assignments): 39 | """ 40 | Returns a (map from target_predicate arg name to the assignment in source_assignments), 41 | (map from target_predicate arg name to the source predicate arg) 42 | """ 43 | target_pred, target_args = get_predicate_args(target_predicate) 44 | source_pred, source_args = get_predicate_args(source_predicate) 45 | if target_pred != source_pred: 46 | raise ValueError( 47 | "Source predicate: {} does not match target predicate: {}".format(source_predicate, target_predicate) 48 | ) 49 | if len(target_args) != len(source_args): 50 | raise ValueError( 51 | "Number of target arguments: {} don't match source arguments: {}".format(target_args, source_args) 52 | ) 53 | target_assignment = {} 54 | target_assignment_map = {} 55 | for target_arg, source_arg in zip(target_args, source_args): 56 | if source_arg == "?": 57 | if target_arg != "?": 58 | raise ValueError( 59 | "Source ({}) and Target ({}) predicates have mismatch" 60 | " on '?'".format(source_predicate, target_predicate) 61 | ) 62 | continue 63 | if source_arg not in source_assignments: 64 | raise ValueError("No assignment for {} in input assignments: {}".format(source_arg, source_assignments)) 65 | target_assignment[target_arg] = source_assignments[source_arg] 66 | target_assignment_map[target_arg] = source_arg 67 | return target_assignment, target_assignment_map 68 | 69 | 70 | def dict_product(dicts): 71 | return (dict(zip(dicts, x)) for x in itertools.product(*dicts.values())) 72 | 73 | 74 | def nonempty_answer(answer): 75 | if isinstance(answer, list) and len(answer) == 0: 76 | return False 77 | if isinstance(answer, str) and answer == "": 78 | return False 79 | return True 80 | 81 | 82 | NOANSWER = None 83 | 84 | 85 | def valid_answer(answer): 86 | return answer is not None 87 | -------------------------------------------------------------------------------- /commaqa/datasets_utils/build_reverse_dataset.py: -------------------------------------------------------------------------------- 1 | # Reverse the sequence "card, stamp, book, water, glasses". 2 | 3 | import json 4 | import math 5 | import random 6 | import string 7 | 8 | # TODO make configurable 9 | reverse_letters = True 10 | num_examples = 100 11 | length_range = range(4, 5) 12 | 13 | 14 | def main(): 15 | with open("configs/reverse_datasets/wordlist.txt", "r") as f: 16 | wordlist = list(map(str.strip, f)) 17 | for list_length in length_range: 18 | if reverse_letters: 19 | input_arr = string.ascii_lowercase 20 | else: 21 | input_arr = wordlist 22 | # number of permutations 23 | permutation_count = math.perm(len(input_arr), list_length) 24 | # select num_examples permutation indexes 25 | permutation_idxs = random.sample(range(permutation_count), num_examples) 26 | # create num_examples permutations 27 | sublists = (get_permutation(i, input_arr, list_length) for i in permutation_idxs) 28 | qa_pairs = list() 29 | for sublist in sublists: 30 | if reverse_letters: 31 | question_word = "".join(sublist) 32 | question = "Reverse the letters in the word {}".format(question_word) 33 | answer = "".join(reversed(sublist)) 34 | else: 35 | comma_separated_sublist = ", ".join(sublist) 36 | question = f'Reverse the sequence "{comma_separated_sublist}".' 
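# Worked example (word mode, using the sequence from the comment at the top of this file):
# question: Reverse the sequence "card, stamp, book, water, glasses".
# answer: "glasses, water, book, stamp, card"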
37 | answer = ", ".join(reversed(sublist)) 38 | qa_pairs.append((question, answer)) 39 | drop = { 40 | "reverseqa": { 41 | "passage": "", 42 | "qa_pairs": [ 43 | { 44 | "question": question, 45 | "answer": { 46 | "number": "", 47 | "date": {"day": "", "month": "", "year": ""}, 48 | "spans": [answer], 49 | }, 50 | "query_id": str(i), 51 | "validated_answers": [], 52 | } 53 | for i, (question, answer) in enumerate(qa_pairs) 54 | ], 55 | } 56 | } 57 | filename = "reverse_{}_{}_{}.json".format(list_length, num_examples, "L" if reverse_letters else "W") 58 | with open(filename, "w") as f: 59 | json.dump(drop, f, indent=2) 60 | 61 | 62 | def get_permutation(i, lst, length): 63 | permutation = list() 64 | for _ in range(length): 65 | i, idx = divmod(i, len(lst)) 66 | permutation.append(lst[idx]) 67 | # remove to prevent duplicates 68 | lst = lst[:idx] + lst[idx + 1 :] 69 | return permutation 70 | 71 | 72 | if __name__ == "__main__": 73 | main() 74 | -------------------------------------------------------------------------------- /commaqa/datasets_utils/convert_gsm8k_to_drop.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | import locale 5 | 6 | locale.setlocale(locale.LC_ALL, "en_US.UTF-8") 7 | 8 | 9 | examples = [] 10 | with open(sys.argv[1], "r") as input_fp: 11 | for line in input_fp: 12 | input_json = json.loads(line.strip()) 13 | question = input_json["question"] 14 | rationale = input_json["answer"] 15 | answer = rationale.split("####")[-1].strip() 16 | try: 17 | answer = locale.atoi(answer) 18 | except ValueError: 19 | try: 20 | answer = locale.atof(answer) 21 | except ValueError: 22 | print("Can not parse: " + answer) 23 | examples.append((question, answer, rationale)) 24 | 25 | qa_pairs = [] 26 | for eg_idx, eg in enumerate(examples): 27 | drop_answer = {"number": eg[1], "date": {"day": "", "month": "", "year": ""}, "spans": []} 28 | qa_pairs.append({"question": eg[0], "answer": drop_answer, "query_id": str(eg_idx), "rationale": eg[2]}) 29 | 30 | output_json = {"1": {"passage": "", "qa_pairs": qa_pairs}} 31 | with open(sys.argv[2], "w") as output_fp: 32 | output_fp.write(json.dumps(output_json, indent=2)) 33 | -------------------------------------------------------------------------------- /commaqa/datasets_utils/subselect_drop_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import random 4 | 5 | 6 | def parse_arguments(): 7 | arg_parser = argparse.ArgumentParser(description="Convert HotPotQA dataset into SQUAD format") 8 | arg_parser.add_argument("--input", type=str, required=True, help="Input JSON file") 9 | arg_parser.add_argument("--output", type=str, required=True, help="Output JSON file") 10 | arg_parser.add_argument( 11 | "--atype", 12 | type=str, 13 | action="append", 14 | choices=["SPAN", "SPANS", "DATE", "NUMBER"], 15 | help="Select questions with answers of this type.", 16 | ) 17 | arg_parser.add_argument( 18 | "--nth", 19 | type=int, 20 | required=False, 21 | action="append", 22 | help="Select questions from the nth (of k) paras." 
" Can provide multiple values.", 23 | ) 24 | arg_parser.add_argument( 25 | "--k", type=int, required=False, default=1, help="Select questions from the nth (of k) paras" 26 | ) 27 | arg_parser.add_argument( 28 | "--qprob", type=float, required=False, default=1.0, help="Select questions with this probability" 29 | ) 30 | arg_parser.add_argument("--maxq", type=int, required=False, default=-1, help="Select upto these many questions") 31 | arg_parser.add_argument("--ids", type=str, required=False, help="File of question ids to select") 32 | arg_parser.add_argument("--incl_para", type=str, required=False, help="File of para ids that must be included") 33 | arg_parser.add_argument("--excl_para", type=str, required=False, help="File of para ids that must be excluded") 34 | return arg_parser.parse_args() 35 | 36 | 37 | def get_answer_types(answer_json): 38 | atypes = [] 39 | 40 | if len(answer_json["spans"]) > 1: 41 | atypes.append("SPANS") 42 | elif len(answer_json["spans"]) == 1: 43 | atypes.append("SPAN") 44 | elif answer_json["number"] != "": 45 | atypes.append("NUMBER") 46 | else: 47 | date_json = answer_json["date"] 48 | if not (date_json["day"] or date_json["month"] or date_json["year"]): 49 | print("Number, Span or Date not set in {}".format(answer_json)) 50 | else: 51 | atypes.append("DATE") 52 | return atypes 53 | 54 | 55 | if __name__ == "__main__": 56 | args = parse_arguments() 57 | with open(args.input, "r") as input_fp: 58 | input_json = json.load(input_fp) 59 | ids = None 60 | incl_para = None 61 | excl_para = None 62 | if args.ids: 63 | with open(args.ids, "r") as input_fp: 64 | lines = input_fp.read() 65 | ids = set(lines.split("\n")) 66 | if args.incl_para: 67 | with open(args.incl_para, "r") as input_fp: 68 | lines = input_fp.read() 69 | incl_para = set(lines.split("\n")) 70 | if args.excl_para: 71 | with open(args.excl_para, "r") as input_fp: 72 | lines = input_fp.read() 73 | excl_para = set(lines.split("\n")) 74 | para_counter = 0 75 | num_questions = 0 76 | output_json = {} 77 | 78 | for paraid, item in input_json.items(): 79 | para_counter += 1 80 | 81 | if args.nth is not None and para_counter % args.k not in args.nth: 82 | # if these question should be included, dont skip 83 | if not (incl_para is not None and paraid in incl_para): 84 | continue 85 | if excl_para is not None and paraid in excl_para: 86 | continue 87 | para = item["passage"] 88 | out_pairs = [] 89 | for qa_pair in item["qa_pairs"]: 90 | if num_questions >= args.maxq > 0: 91 | break 92 | question = qa_pair["question"] 93 | qid = qa_pair["query_id"] 94 | if ids is not None: 95 | if qid not in ids: 96 | continue 97 | 98 | a_accept = True 99 | 100 | if args.atype is not None: 101 | a_accept = False 102 | atypes = get_answer_types(qa_pair["answer"]) 103 | if len(atypes) == 0: 104 | # no matching answer type 105 | print(qid, question) 106 | for atype in atypes: 107 | if atype in args.atype: 108 | a_accept = True 109 | break 110 | 111 | if a_accept and random.random() < args.qprob: 112 | out_pairs.append(qa_pair) 113 | num_questions += 1 114 | 115 | if len(out_pairs): 116 | item["qa_pairs"] = out_pairs 117 | output_json[paraid] = item 118 | 119 | with open(args.output, "w") as output_fp: 120 | json.dump(output_json, output_fp, indent=2) 121 | print("Num output questions: {}".format(num_questions)) 122 | -------------------------------------------------------------------------------- /commaqa/execution/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/StonyBrookNLP/ircot/3c1820f698eea5eeddb4fba3c56b64c961e063e4/commaqa/execution/__init__.py -------------------------------------------------------------------------------- /commaqa/execution/constants.py: -------------------------------------------------------------------------------- 1 | MATH_MODEL = "math_special" 2 | KBLOOKUP_MODEL = "kblookup" 3 | -------------------------------------------------------------------------------- /commaqa/execution/kblookup.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from commaqa.dataset.utils import get_predicate_args 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | class KBLookup: 9 | def __init__(self, kb): 10 | self.kb = kb 11 | 12 | def ask_question(self, question_predicate, context=None): 13 | if context: 14 | raise ValueError("Input context passed to KBLookup which does not use context!" + "\n{}".format(context)) 15 | return self.ask_question_predicate(question_predicate) 16 | 17 | def ask_question_predicate(self, question_predicate): 18 | predicate, pred_args = get_predicate_args(question_predicate) 19 | answers = [] 20 | facts_used = [] 21 | for fact in self.kb[predicate]: 22 | fact_pred, fact_args = get_predicate_args(fact) 23 | if len(pred_args) != len(fact_args): 24 | raise ValueError("Mismatch in specification args {} and fact args {}".format(pred_args, fact_args)) 25 | mismatch = False 26 | answer = "" 27 | for p, f in zip(pred_args, fact_args): 28 | # KB fact arg doesn't match the predicate arg 29 | if p != "?" and p != f and p != "_": 30 | mismatch = True 31 | # predicate arg is a query, populate answer with fact arg 32 | elif p == "?": 33 | answer = f 34 | # if all args matched, add answer 35 | if not mismatch: 36 | answers.append(answer) 37 | facts_used.append(fact) 38 | if len(answers) == 0: 39 | logger.debug("No matching facts for {}. Facts:\n{}".format(question_predicate, self.kb[predicate])) 40 | 41 | # If it's a boolean query, use the number of answers 42 | if "?" not in pred_args: 43 | if len(answers) == 0: 44 | return "no", facts_used 45 | else: 46 | return "yes", facts_used 47 | else: 48 | return answers, facts_used 49 | -------------------------------------------------------------------------------- /commaqa/execution/llm_qa_model.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | 4 | from commaqa.inference.prompt_reader import read_prompt 5 | from commaqa.models.gpt3generator import GPT3Generator 6 | from commaqa.models.llm_client_generator import LLMClientGenerator 7 | 8 | 9 | class LLMQAModel: 10 | def __init__( 11 | self, prompt_file="", prompt_reader_args=None, regex_extract=None, add_context=True, gen_model="gpt3", **kwargs 12 | ): 13 | if prompt_file: 14 | prompt_reader_args = prompt_reader_args or {} 15 | prompt_reader_args["file_path"] = prompt_file 16 | self.prompt = read_prompt(**prompt_reader_args) 17 | else: 18 | self.prompt = None 19 | if gen_model == "gpt3": 20 | self.generator = GPT3Generator(**kwargs) 21 | elif gen_model == "llm_api": 22 | self.generator = LLMClientGenerator(**kwargs) 23 | else: 24 | raise ValueError("Unknown gen_model: " + gen_model) 25 | 26 | self.num_calls = 0 27 | self.regex_extract = regex_extract 28 | self.add_context = add_context 29 | 30 | def ask_question(self, input_question, context, context_suffix=""): 31 | question_prompt = self.prompt + "\n" # this extra "\n" makes the prompt/context delimiter "\n\n\n"; drop it to get "\n\n". 
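# Clarifying note: the lines below assemble the final prompt as
#   <few-shot prompt> + "\n" [+ "\n\n" + <context> + <context_suffix>] + "\n\nQ: " + <input_question> + "\nA:"
# so the test question mirrors the "Q: ... A:" layout of the in-context examples in the prompt file.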
32 | if context and self.add_context: 33 | # TODO Hack!! Needs a better fix 34 | m = re.match(" *PARA_([0-9]+) (.*)", input_question) 35 | if m: 36 | assert isinstance(context, list) 37 | context = context[int(m.group(1))] 38 | input_question = m.group(2) 39 | elif isinstance(context, list): 40 | context = "\n\n".join(context) 41 | if context: 42 | question_prompt += "\n\n" + context + context_suffix 43 | 44 | question_prompt += "\n\nQ: " + input_question + "\nA:" 45 | # print(": ... %s" % question_prompt[-500:]) 46 | output_text_scores = self.generator.generate_text_sequence(question_prompt) 47 | 48 | self.num_calls += 1 49 | if len(output_text_scores) > 1: 50 | print("Can not handle more than one answer for QA model yet" + "\n" + str(output_text_scores)) 51 | output_text_scores = [output_text_scores[0]] 52 | 53 | # only answer string 54 | answer_str = output_text_scores[0][0].strip() 55 | if self.regex_extract: 56 | m = re.match(self.regex_extract, answer_str) 57 | if m: 58 | answer_str = m.group(1).strip() 59 | else: 60 | # No match 61 | print("Did not find a match for input regex: {} in {}".format(self.regex_extract, answer_str)) 62 | return "", [] 63 | try: 64 | json_answer = json.loads(answer_str) 65 | return json_answer, [] 66 | except ValueError: 67 | # Not a valid json ignore 68 | pass 69 | return answer_str, [] 70 | -------------------------------------------------------------------------------- /commaqa/execution/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from commaqa.execution.constants import MATH_MODEL 4 | from commaqa.execution.kblookup import KBLookup 5 | from commaqa.execution.math_model import MathModel 6 | from commaqa.execution.model_executer import ModelExecutor 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def build_models(pred_lang_config, complete_kb, ignore_input_mismatch=False): 12 | model_library = {} 13 | kblookup = KBLookup(kb=complete_kb) 14 | for model_name, configs in pred_lang_config.items(): 15 | if model_name == MATH_MODEL: 16 | model = MathModel( 17 | predicate_language=configs, 18 | model_name=model_name, 19 | kblookup=kblookup, 20 | ignore_input_mismatch=ignore_input_mismatch, 21 | ) 22 | else: 23 | model = ModelExecutor( 24 | predicate_language=configs, 25 | model_name=model_name, 26 | kblookup=kblookup, 27 | ignore_input_mismatch=ignore_input_mismatch, 28 | ) 29 | model_library[model_name] = model 30 | return model_library 31 | -------------------------------------------------------------------------------- /commaqa/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StonyBrookNLP/ircot/3c1820f698eea5eeddb4fba3c56b64c961e063e4/commaqa/inference/__init__.py -------------------------------------------------------------------------------- /commaqa/inference/constants.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from commaqa.inference.dataset_readers import DatasetReader, MultiParaRCReader 4 | from commaqa.inference.participant_qa import LLMQAParticipantModel 5 | from commaqa.inference.ircot import ( 6 | AnswerExtractor, 7 | CopyQuestionParticipant, 8 | RetrieveAndResetParagraphsParticipant, 9 | StepByStepCOTGenParticipant, 10 | StepByStepLLMTitleGenParticipant, 11 | StepByStepExitControllerParticipant, 12 | ) 13 | 14 | MODEL_NAME_CLASS = { 15 | "answer_extractor": AnswerExtractor, 16 | "copy_question": 
CopyQuestionParticipant, 17 | "llmqa": LLMQAParticipantModel, 18 | "retrieve_and_reset_paragraphs": RetrieveAndResetParagraphsParticipant, 19 | "step_by_step_cot_gen": StepByStepCOTGenParticipant, 20 | "step_by_step_llm_title_gen": StepByStepLLMTitleGenParticipant, 21 | "step_by_step_exit_controller": StepByStepExitControllerParticipant, 22 | } 23 | 24 | READER_NAME_CLASS: Dict[str, DatasetReader] = { 25 | "multi_para_rc": MultiParaRCReader, 26 | } 27 | 28 | PREDICTION_TYPES = {"answer", "titles", "pids"} 29 | -------------------------------------------------------------------------------- /commaqa/inference/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Dict 3 | 4 | from nltk import word_tokenize 5 | from nltk.corpus import stopwords 6 | from nltk.stem.porter import PorterStemmer 7 | 8 | stemmer = PorterStemmer() 9 | 10 | stop_words_set = set(stopwords.words("english")) 11 | 12 | QUESTION_MARKER = " Q: " 13 | COMPQ_MARKER = " QC: " 14 | SIMPQ_MARKER = " QS: " 15 | INTERQ_MARKER = " QI: " 16 | ANSWER_MARKER = " A: " 17 | EOQ_MARKER = "[EOQ]" 18 | LIST_JOINER = " + " 19 | BLANK = "__" 20 | WH_WORDS = set(["who", "what", "where", "how", "why", "when", "which"]) 21 | 22 | 23 | def get_sequence_representation( 24 | origq: str, 25 | question_seq: List[str], 26 | answer_seq: List[str], 27 | compq_marker: str = COMPQ_MARKER, 28 | interq_marker: str = INTERQ_MARKER, 29 | answer_marker: str = ANSWER_MARKER, 30 | simpq_marker: str = SIMPQ_MARKER, 31 | ): 32 | ret_seq = compq_marker + origq 33 | if len(question_seq) != len(answer_seq): 34 | raise ValueError( 35 | "Number of generated questions and answers should match before" 36 | "question generation. Qs: {} As: {}".format(question_seq, answer_seq) 37 | ) 38 | 39 | for aidx in range(len(answer_seq)): 40 | ret_seq += interq_marker 41 | ret_seq += question_seq[aidx] 42 | ret_seq += answer_marker + answer_seq[aidx] 43 | ret_seq += simpq_marker 44 | return ret_seq 45 | 46 | 47 | def tokenize_str(input_str): 48 | return word_tokenize(input_str) 49 | 50 | 51 | def stem_tokens(token_arr): 52 | return [stemmer.stem(token) for token in token_arr] 53 | 54 | 55 | def filter_stop_tokens(token_arr): 56 | return [token for token in token_arr if token not in stop_words_set] 57 | 58 | 59 | def stem_filter_tokenization(input_str): 60 | return stem_tokens(filter_stop_tokens(tokenize_str(input_str.lower()))) 61 | 62 | 63 | # functions borrowed from AllenNLP to parse JSONNET with env vars 64 | def get_environment_variables() -> Dict[str, str]: 65 | """ 66 | Wraps `os.environ` to filter out non-encodable values. 67 | """ 68 | return {key: value for key, value in os.environ.items() if _is_encodable(value)} 69 | 70 | 71 | def _is_encodable(value: str) -> bool: 72 | """ 73 | We need to filter out environment variables that can't 74 | be unicode-encoded to avoid a "surrogates not allowed" 75 | error in jsonnet. 76 | """ 77 | # Idiomatically you'd like to not check the != b"" 78 | # but mypy doesn't like that. 
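# Worked example: the empty string passes the first check, while a lone surrogate such as
# "\ud800" encodes to b"" under errors="ignore", fails both checks, and is therefore dropped
# from the environment passed to jsonnet.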
79 | return (value == "") or (value.encode("utf-8", "ignore") != b"") 80 | -------------------------------------------------------------------------------- /commaqa/models/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoConfig, AutoTokenizer, AutoModelWithLMHead 3 | from transformers.generation_utils import SampleEncoderDecoderOutput 4 | import logging 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class LMGenerator: 10 | def __init__(self, model_path, device=None, generation_args={}, encoder_args={}, decoder_args={}): 11 | if device is None: 12 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | else: 14 | self.device = device 15 | 16 | self.config = AutoConfig.from_pretrained(model_path) 17 | self.tokenizer = AutoTokenizer.from_pretrained(model_path) 18 | self.model = AutoModelWithLMHead.from_pretrained(model_path, config=self.config).to(self.device) 19 | self.generation_args = generation_args 20 | # always generate output with scores 21 | self.generation_args["output_scores"] = True 22 | self.generation_args["return_dict_in_generate"] = True 23 | self.encoder_args = encoder_args 24 | self.decoder_args = decoder_args 25 | 26 | def generate_text_sequence(self, input_text): 27 | """ 28 | :param input_text: 29 | :return: returns a sequence of tuples (string, score) where lower score is better 30 | """ 31 | encoded_prompt = self.tokenizer.encode(input_text, **self.encoder_args) 32 | 33 | encoded_prompt = encoded_prompt.to(self.device) 34 | generated_dict = self.model.generate(input_ids=encoded_prompt, **self.generation_args) 35 | 36 | generated_seqs = generated_dict.sequences 37 | if isinstance(generated_dict, SampleEncoderDecoderOutput): 38 | logger.warning("No scores generated when sampled sequences") 39 | generated_scores = [0] * len(generated_seqs) 40 | else: 41 | generated_scores = generated_dict.sequences_scores.tolist() 42 | if len(generated_seqs.shape) > 2: 43 | generated_seqs.squeeze_() 44 | 45 | output_seq_score = [] 46 | 47 | for generated_sequence_idx, generated_seq in enumerate(generated_seqs): 48 | generated_output = generated_seq.tolist() 49 | text = self.tokenizer.decode(generated_output, **self.decoder_args) 50 | # flip the negative logit so that sequence with lowest scores is best 51 | output_seq_score.append((text, -generated_scores[generated_sequence_idx])) 52 | 53 | # Ensure sorted output 54 | return sorted(output_seq_score, key=lambda x: x[1]) 55 | -------------------------------------------------------------------------------- /download/official_eval.sh: -------------------------------------------------------------------------------- 1 | git clone https://github.com/hotpotqa/hotpot official_evaluation/hotpotqa 2 | cd official_evaluation/hotpotqa ; git checkout 3635853403a8735609ee997664e1528f4480762a 3 | cd ../.. 4 | 5 | git clone https://github.com/Alab-NII/2wikimultihop official_evaluation/2wikimultihopqa 6 | cd official_evaluation/2wikimultihopqa ; git checkout 6bdd033bd51aae2d36ba939688c651b5c54ec28a 7 | cd ../.. 8 | 9 | git clone https://github.com/stonybrooknlp/musique official_evaluation/musique 10 | cd official_evaluation/musique ; git checkout 24cc5b297acc2abfc5fb3d0becb6ef7b73d03717 11 | cd ../.. 
12 | -------------------------------------------------------------------------------- /download/processed_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | # If gdown doesn't work, you can download files from mentioned URLs manually 7 | # and put them at appropriate locations. 8 | pip install gdown 9 | 10 | mkdir -p .temp/ 11 | 12 | # URL: https://drive.google.com/file/d/1t2BjJtsejSIUZI54PKObMFG6_wMMG3bC/view?usp=sharing 13 | gdown "1t2BjJtsejSIUZI54PKObMFG6_wMMG3bC&confirm=t" -O .temp/processed_data.zip 14 | unzip -o .temp/processed_data.zip -x "*.DS_Store" 15 | 16 | rm -rf .temp/ 17 | 18 | # The resulting processed_data/ directory should look like: 19 | # ├── 2wikimultihopqa 20 | # │   ├── annotated_only_train.jsonl 21 | # │   ├── dev.jsonl 22 | # │   ├── dev_subsampled.jsonl 23 | # │   ├── test_subsampled.jsonl 24 | # │   └── train.jsonl 25 | # ├── hotpotqa 26 | # │   ├── annotated_only_train.jsonl 27 | # │   ├── dev.jsonl 28 | # │   ├── dev_subsampled.jsonl 29 | # │   ├── test_subsampled.jsonl 30 | # │   └── train.jsonl 31 | # ├── iirc 32 | # │   ├── annotated_only_train.jsonl 33 | # │   ├── dev.jsonl 34 | # │   ├── dev_subsampled.jsonl 35 | # │   ├── test_subsampled.jsonl 36 | # │   └── train.jsonl 37 | # └── musique 38 | # ├── annotated_only_train.jsonl 39 | # ├── dev.jsonl 40 | # ├── dev_subsampled.jsonl 41 | # ├── test_subsampled.jsonl 42 | # └── train.jsonl 43 | -------------------------------------------------------------------------------- /download/raw_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # If gdown doesn't work, you can download files from mentioned URLs manually 4 | # and put them at appropriate locations. 
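# Example of the manual fallback: the MuSiQue archive fetched below (Google Drive file id
# 1tGdADlNjWFaHLeZZGShh2IRcpO6Lv24h) can be downloaded in a browser and saved to
# .temp/musique_v1.0.zip before rerunning this script.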
5 | pip install gdown 6 | 7 | mkdir -p .temp/ 8 | mkdir -p raw_data 9 | 10 | echo "\n\nDownloading raw hotpotqa data\n" 11 | mkdir -p raw_data/hotpotqa 12 | wget http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_train_v1.1.json -O raw_data/hotpotqa/hotpot_train_v1.1.json 13 | wget http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_distractor_v1.json -O raw_data/hotpotqa/hotpot_dev_distractor_v1.json 14 | 15 | echo "\n\nDownloading raw 2wikimultihopqa data\n" 16 | mkdir -p raw_data/2wikimultihopqa 17 | wget https://www.dropbox.com/s/7ep3h8unu2njfxv/data_ids.zip?dl=0 -O .temp/2wikimultihopqa.zip 18 | unzip -jo .temp/2wikimultihopqa.zip -d raw_data/2wikimultihopqa -x "*.DS_Store" 19 | rm data_ids.zip* 20 | 21 | echo "\n\nDownloading raw musique data\n" 22 | mkdir -p raw_data/musique 23 | # URL: https://drive.google.com/file/d/1tGdADlNjWFaHLeZZGShh2IRcpO6Lv24h/view?usp=sharing 24 | gdown "1tGdADlNjWFaHLeZZGShh2IRcpO6Lv24h&confirm=t" -O .temp/musique_v1.0.zip 25 | unzip -jo .temp/musique_v1.0.zip -d raw_data/musique -x "*.DS_Store" 26 | 27 | echo "\n\nDownloading raw iirc data\n" 28 | mkdir -p raw_data/iirc 29 | wget https://iirc-dataset.s3.us-west-2.amazonaws.com/iirc_train_dev.tgz -O .temp/iirc_train_dev.tgz 30 | tar -xzvf .temp/iirc_train_dev.tgz -C .temp/ 31 | mv .temp/iirc_train_dev/train.json raw_data/iirc/train.json 32 | mv .temp/iirc_train_dev/dev.json raw_data/iirc/dev.json 33 | 34 | echo "\n\nDownloading iirc wikipedia corpus (this will take 2-3 mins)\n" 35 | wget https://iirc-dataset.s3.us-west-2.amazonaws.com/context_articles.tar.gz -O .temp/context_articles.tar.gz 36 | tar -xzvf .temp/context_articles.tar.gz -C raw_data/iirc 37 | 38 | echo "\n\nDownloading hotpotqa wikipedia corpus (this will take ~5 mins)\n" 39 | wget https://nlp.stanford.edu/projects/hotpotqa/enwiki-20171001-pages-meta-current-withlinks-abstracts.tar.bz2 -O .temp/wikpedia-paragraphs.tar.bz2 40 | tar -xvf .temp/wikpedia-paragraphs.tar.bz2 -C raw_data/hotpotqa 41 | mv raw_data/hotpotqa/enwiki-20171001-pages-meta-current-withlinks-abstracts raw_data/hotpotqa/wikpedia-paragraphs 42 | 43 | rm -rf .temp/ 44 | 45 | # The resulting raw_data/ directory should look like: 46 | # ── 2wikimultihopqa 47 | # │   ├── dev.json 48 | # │   ├── id_aliases.json 49 | # │   ├── test.json 50 | # │   └── train.json 51 | # ├── hotpotqa 52 | # │   ├── dev_random_20_single_hop_annotations.txt 53 | # │   ├── wikpedia-paragraphs/ 54 | # │   ├── ├── ... 
55 | # │   ├── hotpot_dev_distractor_v1.json 56 | # │   └── train_random_20_single_hop_annotations.txt 57 | # ├── iirc 58 | # │   ├── context_articles.json 59 | # │   ├── dev.json 60 | # │   └── train.json 61 | # └── musique 62 | #    ├── dev_test_singlehop_questions_v1.0.json 63 | #    ├── musique_ans_v1.0_dev.jsonl 64 | #    ├── musique_ans_v1.0_test.jsonl 65 | #    ├── musique_ans_v1.0_train.jsonl 66 | #    ├── musique_full_v1.0_dev.jsonl 67 | #    ├── musique_full_v1.0_test.jsonl 68 | #    └── musique_full_v1.0_train.jsonl 69 | -------------------------------------------------------------------------------- /ircot.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StonyBrookNLP/ircot/3c1820f698eea5eeddb4fba3c56b64c961e063e4/ircot.jpg -------------------------------------------------------------------------------- /llm_server/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .DS_Store 132 | .history 133 | .temp/ -------------------------------------------------------------------------------- /llm_server/Dockerfile: -------------------------------------------------------------------------------- 1 | # https://github.com/allenai/docker-images 2 | # https://github.com/allenai/docker-images/pkgs/container/cuda/24038895?tag=11.2-ubuntu20.04-v0.0.15 3 | FROM ghcr.io/allenai/cuda:11.2-ubuntu20.04-v0.0.15 4 | 5 | RUN apt-get update \ 6 | && DEBIAN_FRONTEND=noninteractive \ 7 | apt-get install --no-install-recommends --assume-yes \ 8 | protobuf-compiler 9 | 10 | # Install transformers 11 | RUN conda install pytorch=1.12.0 cudatoolkit=11.3 -c pytorch # needed for cuda11.3 12 | # This is 4.26.0.dev0 (needed the latest version for a flan-related fix) 13 | RUN pip install git+https://github.com/huggingface/transformers.git@8637316e5e94ba0a2493e5df7846f2f23f46eaef 14 | RUN pip install accelerate==0.15.0 15 | RUN pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda115 16 | RUN pip install sentencepiece 17 | RUN pip install protobuf==3.20.1 # needed to avoid an error. 18 | 19 | # Skipping deepspeed for now as it's not getting installed correctly. 20 | # RUN DS_BUILD_OPS=1 pip install git+https://github.com/microsoft/DeepSpeed.git@d9b788d773ce97281ee63064cc99993cb82397e2 21 | 22 | RUN pip install fastapi 23 | RUN pip install "uvicorn[standard]" 24 | 25 | COPY serve_models /run/serve_models/ 26 | COPY constants.py /run/constants.py 27 | 28 | # To run the server directly: 29 | ENTRYPOINT ["uvicorn", "serve_models.main:app", "--host", "0.0.0.0", "--port", "8000", "--app-dir", "/run/"] 30 | 31 | # To run bash: 32 | # ENTRYPOINT ["bash", "-l"] 33 | -------------------------------------------------------------------------------- /llm_server/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StonyBrookNLP/ircot/3c1820f698eea5eeddb4fba3c56b64c961e063e4/llm_server/__init__.py -------------------------------------------------------------------------------- /llm_server/client.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | 4 | 5 | def main(): 6 | 7 | # This is just an example. Change as necessary. 8 | host = "http://aristo-cirrascale-13.reviz.ai2.in" 9 | port = 49171 10 | 11 | params = {"prompt": "Hello, I am unconscious and"} # see other arguments in serve_models/main:generate 12 | response = requests.get(host + ":" + str(port) + "/generate", params=params) 13 | result = response.json() 14 | 15 | result.get("message", "") 16 | result.get("generated_text", "") 17 | result.get("model_name", "") # To ensure the response comes from the right model. 
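# Illustrative sketch of the response shape implied by the keys above (not a guaranteed
# schema; the actual fields are defined in serve_models/main:generate):
#   {"message": "...", "generated_text": "...", "model_name": "..."}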
18 | 19 | print(json.dumps(result, indent=4)) 20 | 21 | 22 | if __name__ == "__main__": 23 | main() 24 | -------------------------------------------------------------------------------- /llm_server/constants.py: -------------------------------------------------------------------------------- 1 | TRANSFORMERS_CACHE = "~/.cache/huggingface/transformers" 2 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StonyBrookNLP/ircot/3c1820f698eea5eeddb4fba3c56b64c961e063e4/metrics/__init__.py -------------------------------------------------------------------------------- /metrics/answer_support_recall.py: -------------------------------------------------------------------------------- 1 | """ 2 | Answer support recall as a measure of retrieval performance. 3 | """ 4 | from typing import Any, Dict, List 5 | import re 6 | 7 | from metrics.metric import Metric 8 | from metrics.squad_answer_em_f1 import normalize_answer 9 | 10 | 11 | class AnswerSupportRecallMetric(Metric): 12 | """ 13 | AnswerSupportRecall: Recall of the presence of the answer(s) in the retrieved paragraphs. 14 | """ 15 | 16 | def __init__(self) -> None: 17 | self._total_count = 0 18 | self._total_num_retrieved_paras = 0 19 | self._total_answer_support_recall = 0 20 | 21 | def __call__(self, predicted_paragraph_texts: List[str], gold_answers: List[str]): 22 | 23 | answer_covered_count = 0 24 | for gold_answer in gold_answers: 25 | for predicted_paragraph_text in predicted_paragraph_texts: 26 | 27 | def lower_clean_ws(e): 28 | return re.sub(" +", " ", e.lower().strip()) 29 | 30 | condition_1 = lower_clean_ws(gold_answer) in lower_clean_ws(predicted_paragraph_text) 31 | condition_2 = normalize_answer(gold_answer) in normalize_answer(predicted_paragraph_text) 32 | if condition_1 or condition_2: 33 | answer_covered_count += 1 34 | break 35 | 36 | answer_support_recall = answer_covered_count / len(gold_answers) 37 | self._total_answer_support_recall += answer_support_recall 38 | self._total_num_retrieved_paras += len(predicted_paragraph_texts) 39 | self._total_count += 1 40 | 41 | def get_metric(self, reset: bool = False) -> Dict[str, Any]: 42 | """ 43 | Returns 44 | ------- 45 | Average answer support recall, average number of retrieved paragraphs, and the instance count.
46 | """ 47 | 48 | avg_answer_support_recall = ( 49 | self._total_answer_support_recall / self._total_count if self._total_count > 0 else 0 50 | ) 51 | avg_retrieved_paras = self._total_num_retrieved_paras / self._total_count if self._total_count > 0 else 0 52 | 53 | avg_answer_support_recall = round(avg_answer_support_recall, 3) 54 | avg_retrieved_paras = round(avg_retrieved_paras, 3) 55 | 56 | if reset: 57 | self.reset() 58 | 59 | return { 60 | "answer_support_recall": avg_answer_support_recall, 61 | "avg_predicted_paras": avg_retrieved_paras, 62 | "count": self._total_count, 63 | } 64 | 65 | def reset(self): 66 | self._total_count = 0 67 | self._total_num_retrieved_paras = 0 68 | self._total_answer_support_recall = 0 69 | -------------------------------------------------------------------------------- /metrics/drop_answer_em_f1.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | import ftfy 4 | from metrics.metric import Metric 5 | from metrics.drop_eval import ( 6 | get_metrics as drop_em_and_f1, 7 | ) 8 | from metrics.squad_answer_em_f1 import metric_max_over_ground_truths 9 | 10 | 11 | class DropAnswerEmAndF1(Metric): 12 | """ 13 | This :class:`Metric` takes the best span string computed by a model, along with the answer 14 | strings labeled in the data, and computes exact match and F1 score using the official DROP 15 | evaluator (which has special handling for numbers and for questions with multiple answer spans, 16 | among other things). 17 | """ 18 | 19 | def __init__(self) -> None: 20 | self._total_em = 0.0 21 | self._total_f1 = 0.0 22 | self._total_prec = 0.0 23 | self._total_recall = 0.0 24 | self._count = 0 25 | 26 | def __call__( 27 | self, 28 | predicted_answer_list: List[str], 29 | list_of_ground_truth_answer_list: List[List[str]], 30 | ): 31 | assert isinstance(predicted_answer_list, (list, tuple)) 32 | assert isinstance(list_of_ground_truth_answer_list, (list, tuple)) 33 | 34 | if not predicted_answer_list: 35 | predicted_answer_list = [""] 36 | 37 | assert isinstance(predicted_answer_list[0], str) 38 | assert isinstance(list_of_ground_truth_answer_list[0], (list, tuple)) 39 | assert isinstance(list_of_ground_truth_answer_list[0][0], str) 40 | 41 | predicted_answer_list = [ftfy.fix_text(e) for e in predicted_answer_list] 42 | list_of_ground_truth_answer_list = [ 43 | [ftfy.fix_text(e) for e in ground_truth_answer_list] 44 | for ground_truth_answer_list in list_of_ground_truth_answer_list 45 | ] 46 | 47 | exact_match, f1_score, prec_score, recall_score = metric_max_over_ground_truths( 48 | drop_em_and_f1, predicted_answer_list, list_of_ground_truth_answer_list 49 | ) 50 | 51 | # Converting to int here, since we want to count the number of exact matches. 52 | self._total_em += int(exact_match) 53 | self._total_f1 += f1_score 54 | self._total_prec += prec_score 55 | self._total_recall += recall_score 56 | self._count += 1 57 | 58 | def get_metric(self, reset: bool = False) -> Dict[str, Any]: 59 | """ 60 | Returns 61 | ------- 62 | A dict with the average exact match, F1, precision, and recall as computed by the official 63 | DROP script over all inputs, along with the instance count.
64 | """ 65 | exact_match = self._total_em / self._count if self._count > 0 else 0 66 | f1_score = self._total_f1 / self._count if self._count > 0 else 0 67 | prec_score = self._total_prec / self._count if self._count > 0 else 0 68 | recall_score = self._total_recall / self._count if self._count > 0 else 0 69 | if reset: 70 | self.reset() 71 | return { 72 | "em": round(exact_match, 3), 73 | "f1": round(f1_score, 3), 74 | "precision": round(prec_score, 3), 75 | "recall": round(recall_score, 3), 76 | "count": self._count, 77 | } 78 | 79 | def reset(self): 80 | self._total_em = 0.0 81 | self._total_f1 = 0.0 82 | self._total_prec = 0.0 83 | self._total_recall = 0.0 84 | self._count = 0 85 | -------------------------------------------------------------------------------- /metrics/metric.py: -------------------------------------------------------------------------------- 1 | """ 2 | An abstract class representing a metric which can be accumulated. 3 | """ 4 | from typing import Any, Dict 5 | 6 | 7 | class Metric: 8 | """ 9 | An abstract class representing a metric which can be accumulated. 10 | """ 11 | 12 | def __call__(self, predictions: Any, gold_labels: Any): 13 | raise NotImplementedError 14 | 15 | def get_metric(self, reset: bool) -> Dict[str, Any]: 16 | """ 17 | Compute and return the metric. Optionally also call `self.reset`. 18 | """ 19 | raise NotImplementedError 20 | 21 | def reset(self) -> None: 22 | """ 23 | Reset any accumulators or internal state. 24 | """ 25 | raise NotImplementedError 26 | -------------------------------------------------------------------------------- /metrics/squad_answer_em_f1.py: -------------------------------------------------------------------------------- 1 | """ 2 | Answer metric -- mostly taken directly from squad_tools of allennlp. 
3 | """ 4 | import re 5 | import string 6 | import collections 7 | from typing import Any, Dict, List 8 | import ftfy 9 | 10 | from metrics.metric import Metric 11 | 12 | 13 | def normalize_answer(s): 14 | """Lower text and remove punctuation, articles and extra whitespace.""" 15 | 16 | def remove_articles(text): 17 | regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) 18 | return re.sub(regex, " ", text) 19 | 20 | def white_space_fix(text): 21 | return " ".join(text.split()) 22 | 23 | def remove_punc(text): 24 | exclude = set(string.punctuation) 25 | return "".join(ch for ch in text if ch not in exclude) 26 | 27 | def lower(text): 28 | return text.lower() 29 | 30 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 31 | 32 | 33 | def get_tokens(s): 34 | if not s: 35 | return [] 36 | return normalize_answer(s).split() 37 | 38 | 39 | def compute_exact(a_gold, a_pred): 40 | return int(normalize_answer(a_gold) == normalize_answer(a_pred)) 41 | 42 | 43 | def compute_f1(a_gold, a_pred): 44 | gold_toks = get_tokens(a_gold) 45 | pred_toks = get_tokens(a_pred) 46 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 47 | num_same = sum(common.values()) 48 | if len(gold_toks) == 0 or len(pred_toks) == 0: 49 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 50 | return int(gold_toks == pred_toks) 51 | if num_same == 0: 52 | return 0 53 | precision = 1.0 * num_same / len(pred_toks) 54 | recall = 1.0 * num_same / len(gold_toks) 55 | f1 = (2 * precision * recall) / (precision + recall) 56 | return f1 57 | 58 | 59 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 60 | scores_for_ground_truths = [] 61 | for ground_truth in ground_truths: 62 | score = metric_fn(prediction, ground_truth) 63 | scores_for_ground_truths.append(score) 64 | return max(scores_for_ground_truths) 65 | 66 | 67 | class SquadAnswerEmF1Metric(Metric): 68 | def __init__(self) -> None: 69 | self._total_em = 0.0 70 | self._total_f1 = 0.0 71 | self._count = 0 72 | 73 | def __call__( 74 | self, 75 | predicted_answer: str, 76 | ground_truth_answers: List[str], 77 | ): 78 | 79 | predicted_answer = ftfy.fix_text(predicted_answer) 80 | ground_truth_answers = [ftfy.fix_text(e) for e in ground_truth_answers] 81 | 82 | assert isinstance(predicted_answer, str) 83 | assert isinstance(ground_truth_answers, (tuple, list)) 84 | 85 | exact_scores = metric_max_over_ground_truths(compute_exact, predicted_answer, ground_truth_answers) 86 | f1_scores = metric_max_over_ground_truths(compute_f1, predicted_answer, ground_truth_answers) 87 | 88 | self._total_em += int(exact_scores) 89 | self._total_f1 += f1_scores 90 | self._count += 1 91 | 92 | def get_metric(self, reset: bool = False) -> Dict[str, Any]: 93 | exact_match = self._total_em / self._count if self._count > 0 else 0 94 | f1_score = self._total_f1 / self._count if self._count > 0 else 0 95 | if reset: 96 | self.reset() 97 | return {"em": round(exact_match, 3), "f1": round(f1_score, 3), "count": self._count} 98 | 99 | def reset(self): 100 | self._total_em = 0.0 101 | self._total_f1 = 0.0 102 | self._count = 0 103 | -------------------------------------------------------------------------------- /processing_scripts/process_2wikimultihopqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from lib import read_json, write_jsonl 4 | 5 | 6 | def main(): 7 | 8 | set_names = ["train", "dev"] 9 | 10 | input_directory = os.path.join("raw_data", "2wikimultihopqa") 11 |
output_directory = os.path.join("processed_data", "2wikimultihopqa") 12 | os.makedirs(output_directory, exist_ok=True) 13 | 14 | for set_name in set_names: 15 | print(f"Processing {set_name}") 16 | 17 | processed_instances = [] 18 | 19 | input_filepath = os.path.join(input_directory, f"{set_name}.json") 20 | output_filepath = os.path.join(output_directory, f"{set_name}.jsonl") 21 | 22 | raw_instances = read_json(input_filepath) 23 | 24 | for raw_instance in raw_instances: 25 | 26 | question_id = raw_instance["_id"] 27 | question_text = raw_instance["question"] 28 | raw_contexts = raw_instance["context"] 29 | 30 | supporting_titles = list(set([e[0] for e in raw_instance["supporting_facts"]])) 31 | 32 | evidences = raw_instance["evidences"] 33 | reasoning_steps = [" ".join(evidence) for evidence in evidences] 34 | 35 | processed_contexts = [] 36 | for index, raw_context in enumerate(raw_contexts): 37 | title = raw_context[0] 38 | paragraph_text = " ".join(raw_context[1]).strip() 39 | is_supporting = title in supporting_titles 40 | processed_contexts.append( 41 | { 42 | "idx": index, 43 | "title": title.strip(), 44 | "paragraph_text": paragraph_text, 45 | "is_supporting": is_supporting, 46 | } 47 | ) 48 | 49 | answers_object = { 50 | "number": "", 51 | "date": {"day": "", "month": "", "year": ""}, 52 | "spans": [raw_instance["answer"]], 53 | } 54 | answers_objects = [answers_object] 55 | 56 | processed_instance = { 57 | "question_id": question_id, 58 | "question_text": question_text, 59 | "answers_objects": answers_objects, 60 | "contexts": processed_contexts, 61 | "reasoning_steps": reasoning_steps, 62 | } 63 | 64 | processed_instances.append(processed_instance) 65 | 66 | write_jsonl(processed_instances, output_filepath) 67 | 68 | 69 | if __name__ == "__main__": 70 | main() 71 | -------------------------------------------------------------------------------- /processing_scripts/process_hotpotqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from collections import Counter 4 | from typing import List, Dict 5 | 6 | from tqdm import tqdm 7 | from datasets import load_dataset 8 | 9 | 10 | def write_hotpotqa_instances_to_filepath(instances: List[Dict], full_filepath: str): 11 | 12 | max_num_tokens = 1000 # clip later. 
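    # A note on what "clip later" means (this mirrors the logic applied near the
    # end of this function): each paragraph is truncated to its first
    # max_num_tokens whitespace-separated tokens, i.e.
    #   " ".join(paragraph_text.split(" ")[:max_num_tokens])
    # These are plain str.split tokens, not model subword tokens.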
13 | 14 | hop_sizes = Counter() 15 | print(f"Writing in: {full_filepath}") 16 | with open(full_filepath, "w") as full_file: 17 | for raw_instance in tqdm(instances): 18 | 19 | # Generic RC Format 20 | processed_instance = {} 21 | processed_instance["dataset"] = "hotpotqa" 22 | processed_instance["question_id"] = raw_instance["id"] 23 | processed_instance["question_text"] = raw_instance["question"] 24 | processed_instance["level"] = raw_instance["level"] 25 | processed_instance["type"] = raw_instance["type"] 26 | 27 | answers_object = { 28 | "number": "", 29 | "date": {"day": "", "month": "", "year": ""}, 30 | "spans": [raw_instance["answer"]], 31 | } 32 | processed_instance["answers_objects"] = [answers_object] 33 | 34 | raw_context = raw_instance.pop("context") 35 | supporting_titles = raw_instance.pop("supporting_facts")["title"] 36 | 37 | title_to_paragraph = { 38 | title: "".join(text) for title, text in zip(raw_context["title"], raw_context["sentences"]) 39 | } 40 | paragraph_to_title = { 41 | "".join(text): title for title, text in zip(raw_context["title"], raw_context["sentences"]) 42 | } 43 | 44 | gold_paragraph_texts = [title_to_paragraph[title] for title in supporting_titles] 45 | gold_paragraph_texts = set(list(gold_paragraph_texts)) 46 | 47 | paragraph_texts = ["".join(paragraph) for paragraph in raw_context["sentences"]] 48 | paragraph_texts = list(set(paragraph_texts)) 49 | 50 | processed_instance["contexts"] = [ 51 | { 52 | "idx": index, 53 | "title": paragraph_to_title[paragraph_text].strip(), 54 | "paragraph_text": paragraph_text.strip(), 55 | "is_supporting": paragraph_text in gold_paragraph_texts, 56 | } 57 | for index, paragraph_text in enumerate(paragraph_texts) 58 | ] 59 | 60 | supporting_contexts = [context for context in processed_instance["contexts"] if context["is_supporting"]] 61 | hop_sizes[len(supporting_contexts)] += 1 62 | 63 | for context in processed_instance["contexts"]: 64 | context["paragraph_text"] = " ".join(context["paragraph_text"].split(" ")[:max_num_tokens]) 65 | 66 | full_file.write(json.dumps(processed_instance) + "\n") 67 | 68 | print(f"Hop-sizes: {str(hop_sizes)}") 69 | 70 | 71 | if __name__ == "__main__": 72 | 73 | dataset = load_dataset("hotpot_qa", "distractor") 74 | 75 | directory = os.path.join("processed_data", "hotpotqa") 76 | os.makedirs(directory, exist_ok=True) 77 | 78 | processed_full_filepath = os.path.join(directory, "train.jsonl") 79 | write_hotpotqa_instances_to_filepath(dataset["train"], processed_full_filepath) 80 | 81 | processed_full_filepath = os.path.join(directory, "dev.jsonl") 82 | write_hotpotqa_instances_to_filepath(dataset["validation"], processed_full_filepath) 83 | -------------------------------------------------------------------------------- /processing_scripts/process_musique.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from lib import read_jsonl, write_jsonl 4 | 5 | 6 | def main(): 7 | 8 | set_names = ["train", "dev"] 9 | input_directory = os.path.join("raw_data", "musique") 10 | output_directory = os.path.join("processed_data", "musique") 11 | os.makedirs(output_directory, exist_ok=True) 12 | 13 | for set_name in set_names: 14 | processed_instances = [] 15 | 16 | input_filepath = os.path.join(input_directory, f"musique_ans_v1.0_{set_name}.jsonl") 17 | output_filepath = os.path.join(output_directory, f"{set_name}.jsonl") 18 | 19 | raw_instances = read_jsonl(input_filepath) 20 | 21 | for raw_instance in raw_instances: 22 | 23 | answers_object = { 24 | 
"number": "", 25 | "date": {"day": "", "month": "", "year": ""}, 26 | "spans": [raw_instance["answer"]], 27 | } 28 | 29 | number_to_answer = {} 30 | sentences = [] 31 | for index, reasoning_step in enumerate(raw_instance["question_decomposition"]): 32 | number = index + 1 33 | question = reasoning_step["question"] 34 | for mentioned_number in range(1, 10): 35 | if f"#{mentioned_number}" in reasoning_step["question"]: 36 | if mentioned_number not in number_to_answer: 37 | print("WARNING: mentioned_number not present in number_to_answer.") 38 | else: 39 | question = question.replace(f"#{mentioned_number}", number_to_answer[mentioned_number]) 40 | answer = reasoning_step["answer"] 41 | number_to_answer[number] = answer 42 | sentence = " >>>> ".join([question.strip(), answer.strip()]) 43 | sentences.append(sentence) 44 | 45 | processed_instance = { 46 | "question_id": raw_instance["id"], 47 | "question_text": raw_instance["question"], 48 | "contexts": [ 49 | { 50 | "idx": index, 51 | "paragraph_text": paragraph["paragraph_text"].strip(), 52 | "title": paragraph["title"].strip(), 53 | "is_supporting": paragraph["is_supporting"], 54 | } 55 | for index, paragraph in enumerate(raw_instance["paragraphs"]) 56 | ], 57 | "answers_objects": [answers_object], 58 | "reasoning_steps": sentences, 59 | } 60 | processed_instances.append(processed_instance) 61 | 62 | write_jsonl(processed_instances, output_filepath) 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /processing_scripts/subsample_dataset_and_remap_paras.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import os 4 | 5 | from tqdm import tqdm 6 | from lib import read_jsonl, write_jsonl, find_matching_paragraph_text 7 | 8 | random.seed(13370) # Don't change this. 9 | 10 | 11 | def main(): 12 | 13 | parser = argparse.ArgumentParser(description="Save and sample data") 14 | parser.add_argument( 15 | "dataset_name", type=str, help="dataset name.", choices=("hotpotqa", "2wikimultihopqa", "musique", "iirc") 16 | ) 17 | parser.add_argument("set_name", type=str, help="set name.", choices=("dev", "test")) 18 | args = parser.parse_args() 19 | 20 | avoid_question_ids_file_path = None 21 | sample_size = 100 22 | if args.set_name == "test": 23 | avoid_question_ids_file_path = os.path.join("processed_data", args.dataset_name, "dev_subsampled.jsonl") 24 | sample_size = 500 25 | 26 | input_file_path = os.path.join("processed_data", args.dataset_name, "dev.jsonl") 27 | instances = read_jsonl(input_file_path) 28 | 29 | if avoid_question_ids_file_path: 30 | avoid_ids = set([avoid_instance["question_id"] for avoid_instance in read_jsonl(avoid_question_ids_file_path)]) 31 | instances = [instance for instance in instances if instance["question_id"] not in avoid_ids] 32 | 33 | instances = random.sample(instances, sample_size) 34 | 35 | for instance in tqdm(instances): 36 | for context in instance["contexts"]: 37 | 38 | if context in instance.get("pinned_contexts", []): 39 | # pinned contexts (iirc main) aren't in the associated wikipedia corpus. 
40 | continue 41 | 42 | retrieved_result = find_matching_paragraph_text(args.dataset_name, context["paragraph_text"]) 43 | 44 | if retrieved_result is None: 45 | continue 46 | 47 | context["title"] = retrieved_result["title"] 48 | context["paragraph_text"] = retrieved_result["paragraph_text"] 49 | 50 | output_file_path = os.path.join("processed_data", args.dataset_name, f"{args.set_name}_subsampled.jsonl") 51 | write_jsonl(instances, output_file_path) 52 | 53 | 54 | if __name__ == "__main__": 55 | main() 56 | -------------------------------------------------------------------------------- /prompt_generator/generate_prompts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from typing import List, Dict 4 | 5 | from prompt_generator.common import QAPromptGenerator, NoContextOpenRetrieverPromptGenerator 6 | 7 | 8 | def get_qa_prompt_generator_args_and_names(dataset_name: str) -> List[Dict]: 9 | max_paragraph_tokens = 250 # keep it fixed 250. 10 | prompt_generator_args_and_names = [] 11 | model_names = ["codex", "flan_t5"] 12 | for model_name in model_names: 13 | 14 | for qa_type in ("direct", "cot"): 15 | for context_type in ("no", "gold_with_distractors"): 16 | distractor_counts = (0,) if context_type == "no" else (1, 2, 3) 17 | for distractor_count in distractor_counts: 18 | 19 | if distractor_count == 0: 20 | assert context_type == "no" 21 | 22 | prompt_generator_args = { 23 | "qa_type": qa_type, 24 | "context_type": context_type, 25 | "distractor_count": distractor_count, 26 | "model_name": model_name, 27 | } 28 | if dataset_name == "iirc" and model_name == "flan_t5": 29 | prompt_generator_args["pinned_at_bottom"] = model_name == "flan_t5" 30 | 31 | context_type_ = f"gold_with_{distractor_count}_distractors" 32 | if not distractor_count: 33 | context_type_ = "no" 34 | 35 | prompt_name = f"{context_type_}_context_{qa_type}_qa_{model_name}.txt" 36 | prompt_generator_args_and_names.append( 37 | { 38 | "generator_args": prompt_generator_args, 39 | "name": prompt_name, 40 | "max_paragraph_tokens": max_paragraph_tokens, 41 | } 42 | ) 43 | 44 | return prompt_generator_args_and_names 45 | 46 | 47 | def get_no_context_open_retrieval_prompt_generator_args_and_names(dataset_name: str) -> List[Dict]: 48 | max_paragraph_tokens = 250 49 | prompt_generator_args_and_names = [] 50 | 51 | prompt_name = "no_context_open_llm_retrieval_codex.txt" 52 | prompt_generator_args_and_names.append( 53 | {"generator_args": {"model_name": "codex"}, "name": prompt_name, "max_paragraph_tokens": max_paragraph_tokens} 54 | ) 55 | 56 | prompt_name = "no_context_open_llm_retrieval_flan_t5.txt" 57 | prompt_generator_args_and_names.append( 58 | {"generator_args": {"model_name": "flan_t5"}, "name": prompt_name, "max_paragraph_tokens": max_paragraph_tokens} 59 | ) 60 | 61 | return prompt_generator_args_and_names 62 | 63 | 64 | def main(): 65 | 66 | parser = argparse.ArgumentParser(description="Generate prompts.") 67 | parser.add_argument( 68 | "dataset_name", type=str, help="dataset_name", choices={"hotpotqa", "2wikimultihopqa", "musique", "iirc"} 69 | ) 70 | args = parser.parse_args() 71 | 72 | input_file_path = os.path.join("processed_data", args.dataset_name, "annotated_only_train.jsonl") 73 | output_directory = os.path.join("prompts", args.dataset_name) 74 | 75 | task_names = ["qa"] 76 | if args.dataset_name == "iirc": 77 | task_names.append("no_context_open_retrieval") 78 | 79 | for task_name in task_names: 80 | 81 | if task_name == "qa": 82 | 
args_name_generator = get_qa_prompt_generator_args_and_names 83 | prompt_generator_cls = QAPromptGenerator 84 | elif task_name == "no_context_open_retrieval": 85 | args_name_generator = get_no_context_open_retrieval_prompt_generator_args_and_names 86 | prompt_generator_cls = NoContextOpenRetrieverPromptGenerator 87 | else: 88 | raise Exception(f"Invalid task_name {task_name}") 89 | 90 | for prompt_args_and_name in args_name_generator(args.dataset_name): 91 | 92 | generator_args = prompt_args_and_name["generator_args"] 93 | generator_args["input_file_path"] = input_file_path 94 | prompt_generator = prompt_generator_cls(**generator_args) 95 | 96 | output_file_name = prompt_args_and_name["name"] 97 | output_file_path = os.path.join(output_directory, output_file_name) 98 | 99 | prompt_args_and_name.pop("generator_args") 100 | prompt_args_and_name.pop("name") 101 | prompt_args_and_name.pop("max_paragraph_tokens") 102 | if prompt_args_and_name: 103 | raise Exception("Looks like prompt_args_and_name has extra unused args.") 104 | 105 | print(f"Writing in {output_file_path}") 106 | with open(output_file_path, "w") as file: 107 | file.write(prompt_generator.generate()) 108 | 109 | 110 | if __name__ == "__main__": 111 | main() 112 | -------------------------------------------------------------------------------- /prompts/2wikimultihopqa/no_context_direct_qa_codex.txt: -------------------------------------------------------------------------------- 1 | # METADATA: {"qid": "35bf3490096d11ebbdafac1f6bf848b6"} 2 | Q: Are both Kurram Garhi and Trojkrsti located in the same country? 3 | A: no 4 | 5 | 6 | # METADATA: {"qid": "e5150a5a0bda11eba7f7acde48001122"} 7 | Q: When did the director of film Laughter In Hell die? 8 | A: August 25, 1963 9 | 10 | 11 | # METADATA: {"qid": "f44939100bda11eba7f7acde48001122"} 12 | Q: What is the cause of death of Grand Duke Alexei Alexandrovich Of Russia's mother? 13 | A: tuberculosis 14 | 15 | 16 | # METADATA: {"qid": "af8c6722088b11ebbd6fac1f6bf848b6"} 17 | Q: Are the directors of films The Sun of the Sleepless and Nevada (1927 film) both from the same country? 18 | A: no 19 | 20 | 21 | # METADATA: {"qid": "8727d1280bdc11eba7f7acde48001122"} 22 | Q: When was the director of film P.S. Jerusalem born? 23 | A: December 23, 1970 24 | 25 | 26 | # METADATA: {"qid": "5811079c0bdc11eba7f7acde48001122"} 27 | Q: When did the director of film Hypocrite (Film) die? 28 | A: 19 June 2013 29 | 30 | 31 | # METADATA: {"qid": "79a863dc0bdc11eba7f7acde48001122"} 32 | Q: Where did the director of film Maddalena (1954 Film) die? 33 | A: Rome 34 | 35 | 36 | # METADATA: {"qid": "f86b4a28091711ebbdaeac1f6bf848b6"} 37 | Q: Which album was released earlier, What'S Inside or Cassandra'S Dream (Album)? 38 | A: What's Inside 39 | 40 | 41 | # METADATA: {"qid": "228546780bdd11eba7f7acde48001122"} 42 | Q: What is the date of birth of the director of film Best Friends (1982 Film)? 43 | A: July 21, 1926 44 | 45 | 46 | # METADATA: {"qid": "c6805b2908a911ebbd80ac1f6bf848b6"} 47 | Q: Who was born first out of Martin Hodge and Ivania Martinich? 48 | A: Martin Hodge 49 | 50 | 51 | # METADATA: {"qid": "97954d9408b011ebbd84ac1f6bf848b6"} 52 | Q: Do director of film Coolie No. 1 (1995 Film) and director of film The Sensational Trial have the same nationality? 53 | A: no 54 | 55 | 56 | # METADATA: {"qid": "028eaef60bdb11eba7f7acde48001122"} 57 | Q: When did the director of film The Boy And The Fog die? 
58 | A: September 4, 1986 59 | 60 | 61 | # METADATA: {"qid": "c6f63bfb089e11ebbd78ac1f6bf848b6"} 62 | Q: Which film has the director born first, Two Weeks With Pay or Chhailla Babu? 63 | A: Two Weeks With Pay 64 | 65 | 66 | # METADATA: {"qid": "cdbb82ec0baf11ebab90acde48001122"} 67 | Q: Who is Boraqchin (Wife Of Ögedei)'s father-in-law? 68 | A: Genghis Khan 69 | 70 | 71 | # METADATA: {"qid": "f1ccdfee094011ebbdaeac1f6bf848b6"} 72 | Q: Which album was released more recently, If I Have to Stand Alone or Answering Machine Music? 73 | A: Answering Machine Music 74 | 75 | 76 | # METADATA: {"qid": "1ceeab380baf11ebab90acde48001122"} 77 | Q: Who is the grandchild of Krishna Shah (Nepalese Royal)? 78 | A: Prithvipati Shah 79 | 80 | 81 | # METADATA: {"qid": "13cda43c09b311ebbdb0ac1f6bf848b6"} 82 | Q: Are both mountains, Serre Mourene and Monte Galbiga, located in the same country? 83 | A: no 84 | 85 | 86 | # METADATA: {"qid": "a5995da508ab11ebbd82ac1f6bf848b6"} 87 | Q: Which film has the director died later, The Gal Who Took the West or Twenty Plus Two? 88 | A: Twenty Plus Two 89 | 90 | 91 | # METADATA: {"qid": "4724c54e08e011ebbda1ac1f6bf848b6"} 92 | Q: Which film has the director died earlier, When The Mad Aunts Arrive or The Miracle Worker (1962 Film)? 93 | A: When The Mad Aunts Arrive 94 | 95 | 96 | # METADATA: {"qid": "5897ec7a086c11ebbd61ac1f6bf848b6"} 97 | Q: Which film came out first, The Night Of Tricks or The Genealogy? 98 | A: The Night Of Tricks -------------------------------------------------------------------------------- /prompts/2wikimultihopqa/no_context_direct_qa_flan_t5.txt: -------------------------------------------------------------------------------- 1 | # METADATA: {"qid": "35bf3490096d11ebbdafac1f6bf848b6"} 2 | Q: Answer the following question. 3 | Are both Kurram Garhi and Trojkrsti located in the same country? 4 | A: no 5 | 6 | 7 | # METADATA: {"qid": "e5150a5a0bda11eba7f7acde48001122"} 8 | Q: Answer the following question. 9 | When did the director of film Laughter In Hell die? 10 | A: August 25, 1963 11 | 12 | 13 | # METADATA: {"qid": "f44939100bda11eba7f7acde48001122"} 14 | Q: Answer the following question. 15 | What is the cause of death of Grand Duke Alexei Alexandrovich Of Russia's mother? 16 | A: tuberculosis 17 | 18 | 19 | # METADATA: {"qid": "af8c6722088b11ebbd6fac1f6bf848b6"} 20 | Q: Answer the following question. 21 | Are the directors of films The Sun of the Sleepless and Nevada (1927 film) both from the same country? 22 | A: no 23 | 24 | 25 | # METADATA: {"qid": "8727d1280bdc11eba7f7acde48001122"} 26 | Q: Answer the following question. 27 | When was the director of film P.S. Jerusalem born? 28 | A: December 23, 1970 29 | 30 | 31 | # METADATA: {"qid": "5811079c0bdc11eba7f7acde48001122"} 32 | Q: Answer the following question. 33 | When did the director of film Hypocrite (Film) die? 34 | A: 19 June 2013 35 | 36 | 37 | # METADATA: {"qid": "79a863dc0bdc11eba7f7acde48001122"} 38 | Q: Answer the following question. 39 | Where did the director of film Maddalena (1954 Film) die? 40 | A: Rome 41 | 42 | 43 | # METADATA: {"qid": "f86b4a28091711ebbdaeac1f6bf848b6"} 44 | Q: Answer the following question. 45 | Which album was released earlier, What'S Inside or Cassandra'S Dream (Album)? 46 | A: What's Inside 47 | 48 | 49 | # METADATA: {"qid": "228546780bdd11eba7f7acde48001122"} 50 | Q: Answer the following question. 51 | What is the date of birth of the director of film Best Friends (1982 Film)? 
52 | A: July 21, 1926 53 | 54 | 55 | # METADATA: {"qid": "c6805b2908a911ebbd80ac1f6bf848b6"} 56 | Q: Answer the following question. 57 | Who was born first out of Martin Hodge and Ivania Martinich? 58 | A: Martin Hodge 59 | 60 | 61 | # METADATA: {"qid": "97954d9408b011ebbd84ac1f6bf848b6"} 62 | Q: Answer the following question. 63 | Do director of film Coolie No. 1 (1995 Film) and director of film The Sensational Trial have the same nationality? 64 | A: no 65 | 66 | 67 | # METADATA: {"qid": "028eaef60bdb11eba7f7acde48001122"} 68 | Q: Answer the following question. 69 | When did the director of film The Boy And The Fog die? 70 | A: September 4, 1986 71 | 72 | 73 | # METADATA: {"qid": "c6f63bfb089e11ebbd78ac1f6bf848b6"} 74 | Q: Answer the following question. 75 | Which film has the director born first, Two Weeks With Pay or Chhailla Babu? 76 | A: Two Weeks With Pay 77 | 78 | 79 | # METADATA: {"qid": "cdbb82ec0baf11ebab90acde48001122"} 80 | Q: Answer the following question. 81 | Who is Boraqchin (Wife Of Ögedei)'s father-in-law? 82 | A: Genghis Khan 83 | 84 | 85 | # METADATA: {"qid": "f1ccdfee094011ebbdaeac1f6bf848b6"} 86 | Q: Answer the following question. 87 | Which album was released more recently, If I Have to Stand Alone or Answering Machine Music? 88 | A: Answering Machine Music 89 | 90 | 91 | # METADATA: {"qid": "1ceeab380baf11ebab90acde48001122"} 92 | Q: Answer the following question. 93 | Who is the grandchild of Krishna Shah (Nepalese Royal)? 94 | A: Prithvipati Shah 95 | 96 | 97 | # METADATA: {"qid": "13cda43c09b311ebbdb0ac1f6bf848b6"} 98 | Q: Answer the following question. 99 | Are both mountains, Serre Mourene and Monte Galbiga, located in the same country? 100 | A: no 101 | 102 | 103 | # METADATA: {"qid": "a5995da508ab11ebbd82ac1f6bf848b6"} 104 | Q: Answer the following question. 105 | Which film has the director died later, The Gal Who Took the West or Twenty Plus Two? 106 | A: Twenty Plus Two 107 | 108 | 109 | # METADATA: {"qid": "4724c54e08e011ebbda1ac1f6bf848b6"} 110 | Q: Answer the following question. 111 | Which film has the director died earlier, When The Mad Aunts Arrive or The Miracle Worker (1962 Film)? 112 | A: When The Mad Aunts Arrive 113 | 114 | 115 | # METADATA: {"qid": "5897ec7a086c11ebbd61ac1f6bf848b6"} 116 | Q: Answer the following question. 117 | Which film came out first, The Night Of Tricks or The Genealogy? 118 | A: The Night Of Tricks -------------------------------------------------------------------------------- /prompts/musique/no_context_direct_qa_codex.txt: -------------------------------------------------------------------------------- 1 | # METADATA: {"qid": "2hop__292995_8796"} 2 | Q: When was Neville A. Stanton's employer founded? 3 | A: 1862 4 | 5 | 6 | # METADATA: {"qid": "2hop__154225_727337"} 7 | Q: What is the headquarters for the organization who sets the standards for ISO 21500? 8 | A: Geneva 9 | 10 | 11 | # METADATA: {"qid": "2hop__642271_608104"} 12 | Q: What region of the state where Guy Shepherdson was born, contains SMA Negeri 68? 13 | A: Central Jakarta 14 | 15 | 16 | # METADATA: {"qid": "2hop__782642_52667"} 17 | Q: When was the first railway line constructed between Kotri and the city where Marie Adelaide Leprosy Centre is located? 18 | A: April 1858 19 | 20 | 21 | # METADATA: {"qid": "2hop__439265_539716"} 22 | Q: What county is Hebron located in, in the same province the Heritage Places Protection Act applies to? 
23 | A: Prince County 24 | 25 | 26 | # METADATA: {"qid": "2hop__323282_79175"} 27 | Q: When did the first large winter carnival take place in the city where CIMI-FM is licensed to broadcast? 28 | A: 1894 29 | 30 | 31 | # METADATA: {"qid": "2hop__427213_79175"} 32 | Q: When did the first large winter carnival happen in Olivier Robitaille's place of birth? 33 | A: 1894 34 | 35 | 36 | # METADATA: {"qid": "2hop__387702_20661"} 37 | Q: When did Britain withdraw from the country containing Hoora? 38 | A: 1971 39 | 40 | 41 | # METADATA: {"qid": "2hop__195347_20661"} 42 | Q: When did Britain withdraw from the country where the village of Wadyan is found? 43 | A: 1971 44 | 45 | 46 | # METADATA: {"qid": "2hop__861128_15822"} 47 | Q: What did the publisher of Banjo-Tooie rely primarily on for its support? 48 | A: first-party games 49 | 50 | 51 | # METADATA: {"qid": "2hop__496817_701819"} 52 | Q: What shares a border with Rivière-Verte in the province WRSU-FM broadcasts in? 53 | A: Edmundston 54 | 55 | 56 | # METADATA: {"qid": "2hop__804754_52230"} 57 | Q: When was the state of emergency declared in the country where the Senate is located? 58 | A: 20 October 1952 59 | 60 | 61 | # METADATA: {"qid": "2hop__102217_58400"} 62 | Q: Where is the crying stone found in the country in which Raphael Tuju holds citizenship? 63 | A: along the highway towards Kisumu 64 | 65 | 66 | # METADATA: {"qid": "2hop__131516_53573"} 67 | Q: Where does the Snake River start, in the state where Lima Mountain is located? 68 | A: southern Aitkin County 69 | 70 | 71 | # METADATA: {"qid": "3hop1__753524_742157_573834"} 72 | Q: What genre is the record label of the performer of So Long, See You Tomorrow associated with? 73 | A: jazz 74 | 75 | 76 | # METADATA: {"qid": "3hop1__858730_386977_851569"} 77 | Q: In which county was the birthplace of the Smoke in tha City performer? 78 | A: Los Angeles County 79 | 80 | 81 | # METADATA: {"qid": "3hop1__443556_763924_573834"} 82 | Q: What is the genre of the record label of the band that performed on the Crush Tour? 83 | A: jazz 84 | 85 | 86 | # METADATA: {"qid": "3hop1__61746_67065_43617"} 87 | Q: How long is the US border with the country that borders the state where Finding Dory takes place? 88 | A: 1,989 mi 89 | 90 | 91 | # METADATA: {"qid": "4hop3__703974_789671_24078_24137"} 92 | Q: What weekly publication in the Connecticut city with the most Zagat rated restaurants is issued by university of America-Lite: How Imperial Academia Dismantled Our Culture's author? 93 | A: Yale Herald 94 | 95 | 96 | # METADATA: {"qid": "4hop3__463724_100414_35260_54090"} 97 | Q: How many countries in Pacific National University's continent are recognized by the organization that mediated the truce ending the Iran-Iraq war? 98 | A: 53 -------------------------------------------------------------------------------- /prompts/musique/no_context_direct_qa_flan_t5.txt: -------------------------------------------------------------------------------- 1 | # METADATA: {"qid": "2hop__292995_8796"} 2 | Q: Answer the following question. 3 | When was Neville A. Stanton's employer founded? 4 | A: 1862 5 | 6 | 7 | # METADATA: {"qid": "2hop__154225_727337"} 8 | Q: Answer the following question. 9 | What is the headquarters for the organization who sets the standards for ISO 21500? 10 | A: Geneva 11 | 12 | 13 | # METADATA: {"qid": "2hop__642271_608104"} 14 | Q: Answer the following question. 15 | What region of the state where Guy Shepherdson was born, contains SMA Negeri 68? 
16 | A: Central Jakarta 17 | 18 | 19 | # METADATA: {"qid": "2hop__782642_52667"} 20 | Q: Answer the following question. 21 | When was the first railway line constructed between Kotri and the city where Marie Adelaide Leprosy Centre is located? 22 | A: April 1858 23 | 24 | 25 | # METADATA: {"qid": "2hop__439265_539716"} 26 | Q: Answer the following question. 27 | What county is Hebron located in, in the same province the Heritage Places Protection Act applies to? 28 | A: Prince County 29 | 30 | 31 | # METADATA: {"qid": "2hop__323282_79175"} 32 | Q: Answer the following question. 33 | When did the first large winter carnival take place in the city where CIMI-FM is licensed to broadcast? 34 | A: 1894 35 | 36 | 37 | # METADATA: {"qid": "2hop__427213_79175"} 38 | Q: Answer the following question. 39 | When did the first large winter carnival happen in Olivier Robitaille's place of birth? 40 | A: 1894 41 | 42 | 43 | # METADATA: {"qid": "2hop__387702_20661"} 44 | Q: Answer the following question. 45 | When did Britain withdraw from the country containing Hoora? 46 | A: 1971 47 | 48 | 49 | # METADATA: {"qid": "2hop__195347_20661"} 50 | Q: Answer the following question. 51 | When did Britain withdraw from the country where the village of Wadyan is found? 52 | A: 1971 53 | 54 | 55 | # METADATA: {"qid": "2hop__861128_15822"} 56 | Q: Answer the following question. 57 | What did the publisher of Banjo-Tooie rely primarily on for its support? 58 | A: first-party games 59 | 60 | 61 | # METADATA: {"qid": "2hop__496817_701819"} 62 | Q: Answer the following question. 63 | What shares a border with Rivière-Verte in the province WRSU-FM broadcasts in? 64 | A: Edmundston 65 | 66 | 67 | # METADATA: {"qid": "2hop__804754_52230"} 68 | Q: Answer the following question. 69 | When was the state of emergency declared in the country where the Senate is located? 70 | A: 20 October 1952 71 | 72 | 73 | # METADATA: {"qid": "2hop__102217_58400"} 74 | Q: Answer the following question. 75 | Where is the crying stone found in the country in which Raphael Tuju holds citizenship? 76 | A: along the highway towards Kisumu 77 | 78 | 79 | # METADATA: {"qid": "2hop__131516_53573"} 80 | Q: Answer the following question. 81 | Where does the Snake River start, in the state where Lima Mountain is located? 82 | A: southern Aitkin County 83 | 84 | 85 | # METADATA: {"qid": "3hop1__753524_742157_573834"} 86 | Q: Answer the following question. 87 | What genre is the record label of the performer of So Long, See You Tomorrow associated with? 88 | A: jazz 89 | 90 | 91 | # METADATA: {"qid": "3hop1__858730_386977_851569"} 92 | Q: Answer the following question. 93 | In which county was the birthplace of the Smoke in tha City performer? 94 | A: Los Angeles County 95 | 96 | 97 | # METADATA: {"qid": "3hop1__443556_763924_573834"} 98 | Q: Answer the following question. 99 | What is the genre of the record label of the band that performed on the Crush Tour? 100 | A: jazz 101 | 102 | 103 | # METADATA: {"qid": "3hop1__61746_67065_43617"} 104 | Q: Answer the following question. 105 | How long is the US border with the country that borders the state where Finding Dory takes place? 106 | A: 1,989 mi 107 | 108 | 109 | # METADATA: {"qid": "4hop3__703974_789671_24078_24137"} 110 | Q: Answer the following question. 111 | What weekly publication in the Connecticut city with the most Zagat rated restaurants is issued by university of America-Lite: How Imperial Academia Dismantled Our Culture's author? 
112 | A: Yale Herald 113 | 114 | 115 | # METADATA: {"qid": "4hop3__463724_100414_35260_54090"} 116 | Q: Answer the following question. 117 | How many countries in Pacific National University's continent are recognized by the organization that mediated the truce ending the Iran-Iraq war? 118 | A: 53 -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | force-exclude = ''' 4 | archive 5 | datasets 6 | solver_outputs 7 | ''' 8 | 9 | [tool.ruff] 10 | target-version = "py38" 11 | line-length = 120 12 | ignore-init-module-imports = true 13 | select = [ 14 | "F", # pyflakes (default) 15 | "E", # pycodestyle errors (default) 16 | ] 17 | ignore = [ 18 | "E501", # line too long, handled by black 19 | "B008", # do not perform function calls in argument defaults 20 | "C901", # too complex 21 | ] 22 | extend-exclude = ["archive", "datasets", "solver_outputs"] 23 | respect-gitignore = true 24 | show-source = false 25 | [tool.ruff.isort] 26 | lines-after-imports = 2 27 | [tool.ruff.pydocstyle] 28 | convention = "google" -------------------------------------------------------------------------------- /reproduce.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Expected command line argument values. 4 | valid_systems=("ircot" "ircot_qa" "oner" "oner_qa" "nor_qa") 5 | valid_models=("codex" "flan-t5-xxl" "flan-t5-xl" "flan-t5-large" "flan-t5-base" "none") 6 | valid_datasets=("hotpotqa" "2wikimultihopqa" "musique" "iirc") 7 | system_arg="$1" # captured here since "$1" inside a function refers to the function's own argument. 8 | # Function to check if an argument is valid 9 | check_argument() { 10 | local arg="$1" 11 | local position="$2" 12 | local valid_values=("${!3}") 13 | if ! [[ " ${valid_values[*]} " =~ " $arg " ]]; then 14 | echo "argument number $position is not valid. Please provide one of: ${valid_values[*]}" 15 | exit 1 16 | fi 17 | 18 | if [[ $position -eq 2 && $arg == "none" && $system_arg != "oner" ]]; then 19 | echo "The model argument can only be 'none' if the system argument is 'oner'." 20 | exit 1 21 | fi 22 | } 23 | 24 | # Check the number of arguments 25 | if [[ $# -ne 3 ]]; then 26 | echo "Error: Invalid number of arguments. Expected format: ./reproduce.sh SYSTEM MODEL DATASET" 27 | exit 1 28 | fi 29 | 30 | # Check the validity of arguments 31 | check_argument "$1" 1 valid_systems[*] 32 | check_argument "$2" 2 valid_models[*] 33 | check_argument "$3" 3 valid_datasets[*] 34 | 35 | echo ">>>> Instantiate experiment configs with different HPs and write them in files. <<<<" 36 | python runner.py $1 $2 $3 write --prompt_set 1 37 | python runner.py $1 $2 $3 write --prompt_set 2 38 | python runner.py $1 $2 $3 write --prompt_set 3 39 | ## if you make a change to base_configs, the above steps need to be rerun to 40 | ## regenerate instantiated experiment configs (with HPs populated) 41 | 42 | echo ">>>> Run experiments for different HPs on the dev set. <<<<" 43 | python runner.py $1 $2 $3 predict --prompt_set 1 44 | ## If prediction files already exist, it won't redo them. Pass --force if you want to redo. 45 | 46 | echo ">>>> Run evaluation for different HPs on the dev set. <<<<" 47 | python runner.py $1 $2 $3 evaluate --prompt_set 1 48 | ## This runs by default after prediction. This is mainly to show a standalone option.
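## As a concrete (illustrative) example: running "./reproduce.sh ircot_qa flan-t5-xxl hotpotqa"
## makes the command above expand to:
##   python runner.py ircot_qa flan-t5-xxl hotpotqa evaluate --prompt_set 1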
49 | 50 | echo ">>>> Show results for experiments with different HPs <<<<" 51 | python runner.py $1 $2 $3 summarize --prompt_set 1 52 | ## Not strictly necessary; it just shows you the results for the different HPs in a table. 53 | 54 | echo ">>>> Pick the best HP and save the config with that HP. <<<<" 55 | python runner.py $1 $2 $3 write --prompt_set 1 --best 56 | python runner.py $1 $2 $3 write --prompt_set 2 --best 57 | python runner.py $1 $2 $3 write --prompt_set 3 --best 58 | 59 | echo ">>>> Run the experiment with best HP on the test set <<<<" 60 | python runner.py $1 $2 $3 predict --prompt_set 1 --best --eval_test --official 61 | python runner.py $1 $2 $3 predict --prompt_set 2 --best --eval_test --official 62 | python runner.py $1 $2 $3 predict --prompt_set 3 --best --eval_test --official 63 | ## If prediction files already exist, it won't redo them. Pass --force if you want to redo. 64 | 65 | echo ">>>> Run evaluation for the best HP on the test set <<<<" 66 | python runner.py $1 $2 $3 evaluate --prompt_set 1 --best --eval_test --official 67 | python runner.py $1 $2 $3 evaluate --prompt_set 2 --best --eval_test --official 68 | python runner.py $1 $2 $3 evaluate --prompt_set 3 --best --eval_test --official 69 | ## This runs by default after prediction. This is mainly to show a standalone option. 70 | 71 | echo ">>>> Summarize best test results for individual prompts and the aggregate (mean +- std) over them <<<<" 72 | python runner.py $1 $2 $3 summarize --prompt_set 1 --best --eval_test --official 73 | python runner.py $1 $2 $3 summarize --prompt_set 2 --best --eval_test --official 74 | python runner.py $1 $2 $3 summarize --prompt_set 3 --best --eval_test --official 75 | python runner.py $1 $2 $3 summarize --prompt_set aggregate --best --eval_test --official 76 | ## The mean and std in the final command are what we reported in the paper. 77 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jsonnet 2 | torch>=1.7,!=1.12.0 3 | git+https://github.com/huggingface/transformers.git@8637316e5e94ba0a2493e5df7846f2f23f46eaef 4 | accelerate==0.15.0 5 | sentencepiece 6 | protobuf==3.19.0 7 | nltk 8 | scipy 9 | openai 10 | diskcache 11 | typing_extensions<4.6.0 12 | spacy==3.4.1 # Only for one of the experiments.
13 | rapidfuzz 14 | datasets 15 | pandas 16 | requests 17 | tqdm 18 | ftfy 19 | ujson 20 | fastapi 21 | uvicorn[standard] 22 | elasticsearch==7.9.1 23 | dill 24 | base58 25 | pygments 26 | beautifulsoup4 27 | blingfire 28 | wget 29 | black 30 | ruff 31 | -------------------------------------------------------------------------------- /retriever_server/elasticsearch_server.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import argparse 4 | import subprocess 5 | import _jsonnet 6 | import json 7 | 8 | 9 | def is_elasticsearch_running() -> bool: 10 | try: 11 | res = requests.get("http://localhost:9200/_cluster/health") 12 | if res.status_code == 200: 13 | if res.json()["number_of_nodes"] > 0: 14 | return True 15 | return False 16 | except Exception: 17 | return False 18 | 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser(description="Start/stop or check status of elasticsearch server.") 22 | parser.add_argument("command", type=str, help="start, stop or check status", choices=("start", "stop", "status")) 23 | args = parser.parse_args() 24 | 25 | es_pid_path = os.path.expanduser("~/.es_pid") 26 | elasticsearch_path = json.loads(_jsonnet.evaluate_file(".global_config.jsonnet"))["ELASTICSEARCH_PATH"] 27 | 28 | if args.command == "start": 29 | 30 | if os.path.exists(es_pid_path): 31 | exit("ES PID file already exists. Turn off ES first.") 32 | 33 | command = 'ES_JAVA_OPTS="-Xms26g -Xmx26g" ' # larger heap size needed for natcq 34 | command += f"{elasticsearch_path} --daemonize --silent --pidfile {es_pid_path}" 35 | subprocess.call(command, shell=True) 36 | 37 | elif args.command == "stop": 38 | 39 | if not os.path.exists(es_pid_path): 40 | exit("ES PID file not found. Recheck if it's running or turn it off manually.") 41 | 42 | command = f"pkill -F {es_pid_path}" 43 | subprocess.call(command, shell=True) 44 | 45 | if os.path.exists(es_pid_path): 46 | os.remove(es_pid_path) 47 | 48 | elif args.command == "status": 49 | 50 | if is_elasticsearch_running(): 51 | print("Elasticsearch is running.") 52 | else: 53 | print("Elasticsearch is NOT running.") 54 | 55 | if os.path.exists(es_pid_path): 56 | print("ES PID file does exist.") 57 | else: 58 | print("ES PID file does NOT exist.") 59 | 60 | 61 | if __name__ == "__main__": 62 | main() 63 | -------------------------------------------------------------------------------- /retriever_server/interactive_query.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | import argparse 4 | 5 | from pygments import highlight, lexers, formatters 6 | 7 | 8 | def main(): 9 | 10 | parser = argparse.ArgumentParser(description="Query retriever interactively.") 11 | parser.add_argument( 12 | "--retrieval_method", 13 | type=str, 14 | help="retrieval_method", 15 | choices={ 16 | "retrieve_from_elasticsearch", 17 | "retrieve_from_blink", 18 | "retrieve_from_blink_and_elasticsearch", 19 | "retrieve_from_dpr", 20 | "retrieve_from_contriever", 21 | }, 22 | required=True, 23 | ) 24 | parser.add_argument("--host", type=str, help="host", required=True) 25 | parser.add_argument("--port", type=int, help="port", required=True) # 443 is default for ngrok 26 | parser.add_argument("--max_hits_count", type=int, help="max_hits_count", default=5, required=False) 27 | args = parser.parse_args() 28 | 29 | while True: 30 | query_text = input("Enter Query: ") 31 | 32 | params = { 33 | # choices: "retrieve_from_elasticsearch",
"retrieve_from_blink", 34 | # "retrieve_from_blink_and_elasticsearch", "retrieve_from_dpr", 35 | # retrieve_from_contriever 36 | "retrieval_method": args.retrieval_method, 37 | #### 38 | "query_text": query_text, 39 | "max_hits_count": args.max_hits_count, 40 | } 41 | 42 | url = args.host.rstrip("/") + ":" + str(args.port) + "/retrieve" 43 | result = requests.post(url, json=params) 44 | 45 | if result.ok: 46 | 47 | result = result.json() 48 | retrieval = result["retrieval"] 49 | time_in_seconds = result["time_in_seconds"] 50 | retrieval_str = json.dumps(retrieval, indent=4) 51 | retrieval_str = highlight(retrieval_str.encode("utf-8"), lexers.JsonLexer(), formatters.TerminalFormatter()) 52 | 53 | print(f"Time taken in seconds: {time_in_seconds}") 54 | print(retrieval_str) 55 | 56 | else: 57 | print("Something went wrong!\n\n") 58 | 59 | 60 | if __name__ == "__main__": 61 | main() 62 | -------------------------------------------------------------------------------- /retriever_server/requirements.txt: -------------------------------------------------------------------------------- 1 | elasticsearch==7.9.1 # required this version 2 | tqdm 3 | dill 4 | base58 5 | fastapi 6 | jsonnet 7 | pygments 8 | uvicorn 9 | requests 10 | beautifulsoup4 11 | blingfire -------------------------------------------------------------------------------- /retriever_server/serve.py: -------------------------------------------------------------------------------- 1 | from time import perf_counter 2 | from fastapi import FastAPI, Request 3 | 4 | from unified_retriever import UnifiedRetriever 5 | 6 | retriever = UnifiedRetriever(host="http://localhost/", port=9200) 7 | 8 | app = FastAPI() 9 | 10 | 11 | @app.get("/") 12 | async def index(): 13 | return {"message": "Hello! This is a retriever server."} 14 | 15 | 16 | @app.post("/retrieve/") 17 | async def retrieve(arguments: Request): # see the corresponding method in unified_retriever.py 18 | arguments = await arguments.json() 19 | retrieval_method = arguments.pop("retrieval_method") 20 | assert retrieval_method in ("retrieve_from_elasticsearch") 21 | start_time = perf_counter() 22 | retrieval = getattr(retriever, retrieval_method)(**arguments) 23 | end_time = perf_counter() 24 | time_in_seconds = round(end_time - start_time, 1) 25 | return {"retrieval": retrieval, "time_in_seconds": time_in_seconds} 26 | -------------------------------------------------------------------------------- /retriever_server/unified_retriever.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | 3 | from elasticsearch_retriever import ElasticsearchRetriever 4 | 5 | 6 | class UnifiedRetriever: 7 | """ 8 | This class wrapper multiple different retrievers we experimented with. 9 | Since we settled with Elasticsearch, I've removed code related to other 10 | retrievers from here. Still keeping the wrapper for reproducibility. 
11 | """ 12 | 13 | def __init__( 14 | self, 15 | host: str = "http://localhost/", 16 | port: int = 9200, 17 | ): 18 | self._elasticsearch_retriever = ElasticsearchRetriever(host=host, port=port) 19 | 20 | def retrieve_from_elasticsearch( 21 | self, 22 | query_text: str, 23 | max_hits_count: int = 3, 24 | max_buffer_count: int = 100, 25 | document_type: str = "paragraph_text", 26 | allowed_titles: List[str] = None, 27 | allowed_paragraph_types: List[str] = None, 28 | paragraph_index: int = None, 29 | corpus_name: str = None, 30 | ) -> List[Dict]: 31 | 32 | assert document_type in ("title", "paragraph_text", "title_paragraph_text") 33 | 34 | if paragraph_index is not None: 35 | assert ( 36 | document_type == "paragraph_text" 37 | ), "paragraph_index is only a valid input when document_type is paragraph_text." 38 | 39 | if self._elasticsearch_retriever is None: 40 | raise Exception("Elasticsearch retriever not initialized.") 41 | 42 | if document_type in ("paragraph_text", "title_paragraph_text"): 43 | is_abstract = True if corpus_name == "hotpotqa" else None # Note "None" and not False 44 | query_title_field_too = document_type == "title_paragraph_text" 45 | paragraphs_results = self._elasticsearch_retriever.retrieve_paragraphs( 46 | query_text=query_text, 47 | is_abstract=is_abstract, 48 | max_hits_count=max_hits_count, 49 | allowed_titles=allowed_titles, 50 | allowed_paragraph_types=allowed_paragraph_types, 51 | paragraph_index=paragraph_index, 52 | corpus_name=corpus_name, 53 | query_title_field_too=query_title_field_too, 54 | max_buffer_count=max_buffer_count, 55 | ) 56 | 57 | elif document_type == "title": 58 | paragraphs_results = self._elasticsearch_retriever.retrieve_titles( 59 | query_text=query_text, max_hits_count=max_hits_count, corpus_name=corpus_name 60 | ) 61 | 62 | return paragraphs_results 63 | --------------------------------------------------------------------------------
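For quick reference, here is a minimal client sketch for the retriever server above. It is an illustrative example under stated assumptions, not part of the repository: the host and port (localhost:8000, a local uvicorn default) and the title/paragraph_text fields of each hit are assumptions based on interactive_query.py and the processed-data format; serve.py itself only accepts the retrieve_from_elasticsearch method.

```python
import requests


def retrieve(query_text: str, host: str = "http://localhost", port: int = 8000, max_hits_count: int = 3):
    # Payload mirrors interactive_query.py; serve.py only accepts this retrieval_method.
    params = {
        "retrieval_method": "retrieve_from_elasticsearch",
        "query_text": query_text,
        "max_hits_count": max_hits_count,
    }
    response = requests.post(f"{host}:{port}/retrieve", json=params)
    response.raise_for_status()
    result = response.json()  # serve.py returns {"retrieval": ..., "time_in_seconds": ...}
    return result["retrieval"], result["time_in_seconds"]


if __name__ == "__main__":
    hits, seconds = retrieve("Who directed the film Laughter In Hell?")
    print(f"Retrieved {len(hits)} paragraphs in {seconds}s")
    for hit in hits:
        # Field names here are assumed from the repo's processed context format.
        print(hit.get("title", ""), "::", hit.get("paragraph_text", "")[:80])
```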