├── .gitignore
├── assets
│   └── example.jpeg
├── config
│   ├── filter
│   │   ├── llama7b
│   │   │   ├── squad.conf
│   │   │   ├── newsqa.conf
│   │   │   ├── searchqa.conf
│   │   │   ├── triviaqa.conf
│   │   │   ├── hotpotqa.conf
│   │   │   └── nq.conf
│   │   ├── llama70b
│   │   │   ├── hotpotqa.conf
│   │   │   ├── newsqa.conf
│   │   │   ├── squad.conf
│   │   │   ├── searchqa.conf
│   │   │   ├── nq.conf
│   │   │   └── triviaqa.conf
│   │   ├── mistral7b
│   │   │   ├── newsqa.conf
│   │   │   ├── squad.conf
│   │   │   ├── hotpotqa.conf
│   │   │   ├── searchqa.conf
│   │   │   ├── nq.conf
│   │   │   └── triviaqa.conf
│   │   └── mixtral8x7b
│   │       ├── newsqa.conf
│   │       ├── squad.conf
│   │       ├── hotpotqa.conf
│   │       ├── searchqa.conf
│   │       ├── nq.conf
│   │       └── triviaqa.conf
│   ├── custom
│   │   ├── filter.conf
│   │   ├── ob.conf
│   │   ├── add.conf
│   │   └── cb.conf
│   ├── ob
│   │   ├── llama70b
│   │   │   ├── searchqa.conf
│   │   │   ├── newsqa.conf
│   │   │   ├── squad.conf
│   │   │   ├── hotpotqa.conf
│   │   │   ├── triviaqa.conf
│   │   │   └── nq.conf
│   │   ├── llama7b
│   │   │   ├── searchqa.conf
│   │   │   ├── newsqa.conf
│   │   │   ├── squad.conf
│   │   │   ├── triviaqa.conf
│   │   │   ├── hotpotqa.conf
│   │   │   └── nq.conf
│   │   ├── mistral7b
│   │   │   ├── searchqa.conf
│   │   │   ├── squad.conf
│   │   │   ├── hotpotqa.conf
│   │   │   ├── newsqa.conf
│   │   │   ├── triviaqa.conf
│   │   │   └── nq.conf
│   │   └── mixtral8x7b
│   │       ├── searchqa.conf
│   │       ├── newsqa.conf
│   │       ├── squad.conf
│   │       ├── triviaqa.conf
│   │       ├── hotpotqa.conf
│   │       └── nq.conf
│   ├── add
│   │   ├── llama70b
│   │   │   ├── searchqa.conf
│   │   │   ├── newsqa.conf
│   │   │   ├── squad.conf
│   │   │   ├── hotpotqa.conf
│   │   │   ├── triviaqa.conf
│   │   │   └── nq.conf
│   │   ├── llama7b
│   │   │   ├── searchqa.conf
│   │   │   ├── newsqa.conf
│   │   │   ├── squad.conf
│   │   │   ├── triviaqa.conf
│   │   │   ├── hotpotqa.conf
│   │   │   └── nq.conf
│   │   ├── mistral7b
│   │   │   ├── searchqa.conf
│   │   │   ├── squad.conf
│   │   │   ├── hotpotqa.conf
│   │   │   ├── newsqa.conf
│   │   │   ├── triviaqa.conf
│   │   │   └── nq.conf
│   │   └── mixtral8x7b
│   │       ├── searchqa.conf
│   │       ├── newsqa.conf
│   │       ├── squad.conf
│   │       ├── triviaqa.conf
│   │       ├── hotpotqa.conf
│   │       └── nq.conf
│   ├── mask
│   │   ├── llama7b
│   │   │   ├── searchqa.conf
│   │   │   ├── newsqa.conf
│   │   │   ├── squad.conf
│   │   │   ├── hotpotqa.conf
│   │   │   ├── triviaqa.conf
│   │   │   └── nq.conf
│   │   ├── llama70b
│   │   │   ├── searchqa.conf
│   │   │   ├── squad.conf
│   │   │   ├── hotpotqa.conf
│   │   │   ├── newsqa.conf
│   │   │   ├── triviaqa.conf
│   │   │   └── nq.conf
│   │   ├── mistral7b
│   │   │   ├── searchqa.conf
│   │   │   ├── squad.conf
│   │   │   ├── newsqa.conf
│   │   │   ├── triviaqa.conf
│   │   │   ├── hotpotqa.conf
│   │   │   └── nq.conf
│   │   └── mixtral8x7b
│   │       ├── searchqa.conf
│   │       ├── newsqa.conf
│   │       ├── squad.conf
│   │       ├── hotpotqa.conf
│   │       ├── triviaqa.conf
│   │       └── nq.conf
│   ├── icl
│   │   ├── llama70b
│   │   │   ├── searchqa.conf
│   │   │   ├── newsqa.conf
│   │   │   ├── squad.conf
│   │   │   ├── triviaqa.conf
│   │   │   ├── hotpotqa.conf
│   │   │   └── nq.conf
│   │   ├── llama7b
│   │   │   ├── searchqa.conf
│   │   │   ├── newsqa.conf
│   │   │   ├── squad.conf
│   │   │   ├── triviaqa.conf
│   │   │   ├── hotpotqa.conf
│   │   │   └── nq.conf
│   │   ├── mistral7b
│   │   │   ├── searchqa.conf
│   │   │   ├── newsqa.conf
│   │   │   ├── squad.conf
│   │   │   ├── hotpotqa.conf
│   │   │   ├── triviaqa.conf
│   │   │   └── nq.conf
│   │   └── mixtral8x7b
│   │       ├── searchqa.conf
│   │       ├── newsqa.conf
│   │       ├── squad.conf
│   │       ├── triviaqa.conf
│   │       ├── hotpotqa.conf
│   │       └── nq.conf
│   ├── cb
│   │   ├── llama70b
│   │   │   ├── newsqa.conf
│   │   │   ├── squad.conf
│   │   │   ├── hotpotqa.conf
│   │   │   ├── searchqa.conf
│   │   │   ├── nq.conf
│   │   │   └── triviaqa.conf
│   │   ├── llama7b
│   │   │   ├── newsqa.conf
│   │   │   ├── squad.conf
│   │   │   ├── searchqa.conf
│   │   │   ├── hotpotqa.conf
│   │   │   ├── nq.conf
│   │   │   └── triviaqa.conf
│   │   ├── mistral7b
│   │   │   ├── newsqa.conf
│   │   │   ├── squad.conf
│   │   │   ├── hotpotqa.conf
│   │   │   ├── searchqa.conf
│   │   │   ├── triviaqa.conf
│   │   │   └── nq.conf
│   │   └── mixtral8x7b
│   │       ├── newsqa.conf
│   │       ├── squad.conf
│   │       ├── hotpotqa.conf
│   │       ├── searchqa.conf
│   │       ├── nq.conf
│   │       └── triviaqa.conf
│   └── freshqa
│       ├── llama70b.conf
│       ├── llama7b.conf
│       ├── mistral7b.conf
│       └── mixtral8x7b.conf
├── 3_run_ob_experiment.py
├── 1_gather_cb_answers.py
├── src
│   ├── validation.py
│   ├── evaluation
│   │   ├── icl.py
│   │   ├── true.py
│   │   ├── bem.py
│   │   ├── exact_match.py
│   │   ├── question_answering.py
│   │   └── prompt_helpers.py
│   ├── analysis.py
│   ├── closedbook_exp.py
│   ├── model_utils.py
│   ├── file_utils.py
│   └── openbook_exp.py
├── requirements.txt
├── 2_filter_out_no_conflict.py
├── 4_download_freshqa.py
├── 0_download_data.py
└── README.md

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
data/*
results/*


*.py[cod]
--------------------------------------------------------------------------------
/assets/example.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kortukov/realistic_knowledge_conflicts/HEAD/assets/example.jpeg
--------------------------------------------------------------------------------
/config/filter/llama7b/squad.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-7b-chat-hf"

dataset: "SQuAD"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
--------------------------------------------------------------------------------
/config/filter/llama70b/hotpotqa.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-70b-chat-hf"
dataset: "HotpotQA"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
--------------------------------------------------------------------------------
/config/filter/llama70b/newsqa.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-70b-chat-hf"

dataset: "NewsQA"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
--------------------------------------------------------------------------------
/config/filter/llama70b/squad.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-70b-chat-hf"

dataset: "SQuAD"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
--------------------------------------------------------------------------------
/config/filter/llama7b/newsqa.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-7b-chat-hf"

dataset: "NewsQA"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
--------------------------------------------------------------------------------
/config/filter/llama70b/searchqa.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-70b-chat-hf"

dataset: "SearchQA"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
--------------------------------------------------------------------------------
/config/filter/llama7b/searchqa.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-7b-chat-hf"

dataset: "SearchQA"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
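
The filter configs all share the same four keys (model_name, dataset, wrong_examples_path, conflict_examples_path). As a quick orientation, and not a file from the repository, here is a minimal sketch of reading one of these .conf files with the same yaml + munch pattern used by the numbered run scripts further down in this dump; the config path is only an example:

    import munch
    import yaml

    # The .conf files are plain YAML: parse one and wrap the dict for
    # attribute-style access, mirroring the loading code in the run scripts.
    with open("config/filter/llama7b/squad.conf", "r") as f:
        config = munch.munchify(yaml.safe_load(f))

    print(config.model_name)  # "meta-llama/Llama-2-7b-chat-hf"
    print(config.dataset)     # "SQuAD"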
--------------------------------------------------------------------------------
/config/filter/llama7b/triviaqa.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-7b-chat-hf"

dataset: "TriviaQA-web"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
--------------------------------------------------------------------------------
/config/filter/mistral7b/newsqa.conf:
--------------------------------------------------------------------------------
model_name: "mistralai/Mistral-7B-Instruct-v0.2"

dataset: "NewsQA"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
--------------------------------------------------------------------------------
/config/filter/mistral7b/squad.conf:
--------------------------------------------------------------------------------
model_name: "mistralai/Mistral-7B-Instruct-v0.2"

dataset: "SQuAD"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
--------------------------------------------------------------------------------
/config/filter/llama70b/nq.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-70b-chat-hf"

dataset: "NaturalQuestionsShort"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
--------------------------------------------------------------------------------
/config/filter/llama70b/triviaqa.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-70b-chat-hf"

dataset: "TriviaQA-web"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
--------------------------------------------------------------------------------
/config/filter/mistral7b/hotpotqa.conf:
--------------------------------------------------------------------------------
model_name: "mistralai/Mistral-7B-Instruct-v0.2"

dataset: "HotpotQA"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
--------------------------------------------------------------------------------
/config/filter/mistral7b/searchqa.conf:
--------------------------------------------------------------------------------
model_name: "mistralai/Mistral-7B-Instruct-v0.2"

dataset: "SearchQA"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
--------------------------------------------------------------------------------
/config/filter/mixtral8x7b/newsqa.conf:
--------------------------------------------------------------------------------
model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1"

dataset: "NewsQA"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
--------------------------------------------------------------------------------
/config/filter/mixtral8x7b/squad.conf:
--------------------------------------------------------------------------------
model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1"

dataset: "SQuAD"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
--------------------------------------------------------------------------------
/config/custom/filter.conf:
--------------------------------------------------------------------------------
model_name: "your_model_name"

dataset: "your_dataset_name"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
--------------------------------------------------------------------------------
/config/filter/mistral7b/nq.conf:
--------------------------------------------------------------------------------
model_name: "mistralai/Mistral-7B-Instruct-v0.2"

dataset: "NaturalQuestionsShort"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
--------------------------------------------------------------------------------
/config/filter/mistral7b/triviaqa.conf:
--------------------------------------------------------------------------------
model_name: "mistralai/Mistral-7B-Instruct-v0.2"

dataset: "TriviaQA-web"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
--------------------------------------------------------------------------------
/config/filter/mixtral8x7b/hotpotqa.conf:
--------------------------------------------------------------------------------
model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1"

dataset: "HotpotQA"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
--------------------------------------------------------------------------------
/config/filter/mixtral8x7b/searchqa.conf:
--------------------------------------------------------------------------------
model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1"

dataset: "SearchQA"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
--------------------------------------------------------------------------------
/config/filter/mixtral8x7b/nq.conf:
--------------------------------------------------------------------------------
model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1"

dataset: "NaturalQuestionsShort"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
--------------------------------------------------------------------------------
/config/filter/mixtral8x7b/triviaqa.conf:
--------------------------------------------------------------------------------
model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1"

dataset: "TriviaQA-web"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
--------------------------------------------------------------------------------
/config/filter/llama7b/hotpotqa.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-7b-chat-hf"

dataset: "HotpotQA"

wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
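
The per-model filter configs only differ in model_name and dataset, so a whole model can be swept with a small loop. A sketch (this assumes 2_filter_out_no_conflict.py accepts a --config argument in the same way as the other numbered scripts shown in this dump; verify against the script before relying on it):

    for ds in squad newsqa searchqa triviaqa hotpotqa nq; do
        python 2_filter_out_no_conflict.py --config config/filter/llama7b/$ds.conf
    done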
--------------------------------------------------------------------------------
/config/filter/llama7b/nq.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-7b-chat-hf"

dataset: "NaturalQuestionsShort"
dataset_path: "data/test/.parquet"
dataset_length: null


correct_examples_path: null
wrong_examples_path: "data//cb_wrong/.parquet"
conflict_examples_path: "data//conflict/.parquet"
--------------------------------------------------------------------------------
/config/ob/llama70b/searchqa.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-70b-chat-hf"
quantized: True
model_parallelism: True

custom_prompt: "Answer the question concisely using the context.Context: Question: Answer:"
metric_name: "BEM"
sameness_metric: "EM"

dataset: "SearchQA"
dataset_path: "data//conflict/.parquet"
dataset_length: null

results_dir: "data//ob"
output_path: "results//ob_.out"
--------------------------------------------------------------------------------
/config/ob/llama7b/searchqa.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-7b-chat-hf"
quantized: False
model_parallelism: True

custom_prompt: "Answer the question concisely using the context.Context: Question: Answer:"
metric_name: "BEM"
sameness_metric: "EM"

dataset: "SearchQA"
dataset_path: "data//conflict/.parquet"
dataset_length: null

results_dir: "data//ob"
output_path: "results//ob_.out"
--------------------------------------------------------------------------------
/config/ob/mistral7b/searchqa.conf:
--------------------------------------------------------------------------------
model_name: "mistralai/Mistral-7B-Instruct-v0.2"
quantized: False
model_parallelism: True

custom_prompt: "Answer the question concisely using the context.Context: Question: Answer:"
metric_name: "BEM"
sameness_metric: "EM"

dataset: "SearchQA"
dataset_path: "data//conflict/.parquet"
dataset_length: null

results_dir: "data//ob"
output_path: "results//ob_.out"
--------------------------------------------------------------------------------
/config/ob/mixtral8x7b/searchqa.conf:
--------------------------------------------------------------------------------
model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1"
quantized: True
model_parallelism: True

custom_prompt: "Answer the question concisely using the context.Context: Question: Answer:"
metric_name: "BEM"
sameness_metric: "EM"

dataset: "SearchQA"
dataset_path: "data//conflict/.parquet"
dataset_length: null

results_dir: "data//ob"
output_path: "results//ob_.out"
--------------------------------------------------------------------------------
/config/add/llama70b/searchqa.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-70b-chat-hf"
quantized: True
model_parallelism: True

custom_prompt: "Answer the question concisely using the context.Context: Unrelated text: Question: Answer:"
metric_name: "BEM"
sameness_metric: "EM"

dataset: "SearchQA"
dataset_path: "data//conflict/.parquet"
dataset_length: null

results_dir: "data//ob"
output_path: "results//add_.out"
--------------------------------------------------------------------------------
/config/add/llama7b/searchqa.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-7b-chat-hf"
quantized: False
model_parallelism: True

custom_prompt: "Answer the question concisely using the context.Context: Unrelated text: Question: Answer:"
metric_name: "BEM"
sameness_metric: "EM"

dataset: "SearchQA"
dataset_path: "data//conflict/.parquet"
dataset_length: null

results_dir: "data//ob"
output_path: "results//add_.out"
--------------------------------------------------------------------------------
/config/custom/ob.conf:
--------------------------------------------------------------------------------
model_name: "your_model_name"
quantized: False
model_parallelism: True

custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:"
metric_name: "BEM"
sameness_metric: "EM"

dataset: "your_dataset_name"
dataset_path: "data//conflict/.parquet"
dataset_length: null

results_dir: "data//ob"
output_path: "results//ob_.out"
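
config/custom/ob.conf above is the template to copy when running the open-book experiment on a model/dataset pair that has no prebuilt config. A hypothetical filled-in fragment (the model id and dataset name below are placeholders for illustration only; the remaining keys can stay as in the template):

    model_name: "my-org/my-instruct-model"   # hypothetical HF model id
    quantized: False
    dataset: "MyQADataset"                   # hypothetical dataset name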
--------------------------------------------------------------------------------
/config/add/mistral7b/searchqa.conf:
--------------------------------------------------------------------------------
model_name: "mistralai/Mistral-7B-Instruct-v0.2"
quantized: False
model_parallelism: True

custom_prompt: "Answer the question concisely using the context.Context: Unrelated text: Question: Answer:"
metric_name: "BEM"
sameness_metric: "EM"

dataset: "SearchQA"
dataset_path: "data//conflict/.parquet"
dataset_length: null

results_dir: "data//ob"
output_path: "results//add_.out"
--------------------------------------------------------------------------------
/config/add/mixtral8x7b/searchqa.conf:
--------------------------------------------------------------------------------
model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1"
quantized: True
model_parallelism: True

custom_prompt: "Answer the question concisely using the context.Context: Unrelated text: Question: Answer:"
metric_name: "BEM"
sameness_metric: "EM"

dataset: "SearchQA"
dataset_path: "data//conflict/.parquet"
dataset_length: null

results_dir: "data//ob"
output_path: "results//add_.out"
--------------------------------------------------------------------------------
/config/ob/llama70b/newsqa.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-70b-chat-hf"
quantized: True
model_parallelism: True

custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:"
metric_name: "BEM"
sameness_metric: "EM"

dataset: "NewsQA"
dataset_path: "data//conflict/.parquet"
dataset_length: null

results_dir: "data//ob"
output_path: "results//ob_.out"
--------------------------------------------------------------------------------
/config/ob/llama70b/squad.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-70b-chat-hf"
quantized: True
model_parallelism: True

custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:"
metric_name: "BEM"
sameness_metric: "EM"

dataset: "SQuAD"
dataset_path: "data//conflict/.parquet"
dataset_length: null

results_dir: "data//ob"
output_path: "results//ob_.out"
--------------------------------------------------------------------------------
/config/ob/llama7b/newsqa.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-7b-chat-hf"
quantized: False
model_parallelism: True

custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:"
metric_name: "BEM"
sameness_metric: "EM"

dataset: "NewsQA"
dataset_path: "data//conflict/.parquet"
dataset_length: null

results_dir: "data//ob"
output_path: "results//ob_.out"
--------------------------------------------------------------------------------
/config/ob/llama7b/squad.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-7b-chat-hf"
quantized: False
model_parallelism: True

custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:"
metric_name: "BEM"
sameness_metric: "EM"

dataset: "SQuAD"
dataset_path: "data//conflict/.parquet"
dataset_length: null

results_dir: "data//ob"
output_path: "results//ob_.out"
--------------------------------------------------------------------------------
/config/ob/llama7b/triviaqa.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-7b-chat-hf"
quantized: False
model_parallelism: True

custom_prompt: "Answer the question with as few words as possible through the information given in the context.Context: Question: Answer:"
metric_name: "BEM"
sameness_metric: "EM"

dataset: "TriviaQA-web"
dataset_path: "data//conflict/.parquet"
dataset_length: null

results_dir: "data//ob"
output_path: "results//ob_.out"
--------------------------------------------------------------------------------
/config/mask/llama7b/searchqa.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-7b-chat-hf"
quantized: False
model_parallelism: True

custom_prompt: "Answer the question concisely using the context.Context: Question: Answer:"
metric_name: "BEM"
sameness_metric: "EM"

dataset: "SearchQA"
dataset_path: "data//conflict/.parquet"
dataset_length: null

results_dir: "data//ob"
output_path: "results//mask_.out"

masking_strategy: "input_tokens_space/1"
--------------------------------------------------------------------------------
/config/ob/llama70b/hotpotqa.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-70b-chat-hf"
quantized: True
model_parallelism: True

custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:"
metric_name: "BEM"
sameness_metric: "EM"

dataset: "HotpotQA"
dataset_path: "data//conflict/.parquet"
dataset_length: null

results_dir: "data//ob"
output_path: "results//ob_.out"
--------------------------------------------------------------------------------
/config/ob/llama70b/triviaqa.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-70b-chat-hf"
quantized: True
model_parallelism: True

custom_prompt: "Answer the question with as few words as possible through the information given in the context.Context: Question: Answer:"
metric_name: "BEM"
sameness_metric: "EM"

dataset: "TriviaQA-web"
dataset_path: "data//conflict/.parquet"
dataset_length: null

results_dir: "data//ob"
output_path: "results//ob_.out"
--------------------------------------------------------------------------------
/config/ob/llama7b/hotpotqa.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-7b-chat-hf"
quantized: False
model_parallelism: True

custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:"
metric_name: "BEM"
sameness_metric: "EM"

dataset: "HotpotQA"
dataset_path: "data//conflict/.parquet"
dataset_length: null

results_dir: "data//ob"
output_path: "results//ob_.out"
--------------------------------------------------------------------------------
/config/ob/mistral7b/squad.conf:
--------------------------------------------------------------------------------
model_name: "mistralai/Mistral-7B-Instruct-v0.2"
quantized: False
model_parallelism: True

custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:"
metric_name: "BEM"
sameness_metric: "EM"

dataset: "SQuAD"
dataset_path: "data//conflict/.parquet"
dataset_length: null

results_dir: "data//ob"
output_path: "results//ob_.out"
--------------------------------------------------------------------------------
/3_run_ob_experiment.py:
--------------------------------------------------------------------------------
import argparse
import munch
import yaml

import src.openbook_exp as openbook_exp


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--config", type=str, required=True, help="Path to the config file.")

    args = parser.parse_args()

    with open(args.config, 'r') as file:
        config = yaml.safe_load(file)
        config = munch.munchify(config)

    openbook_exp.run_openbook_experiment(config)
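
Each numbered script is driven by one of the configs above through its --config flag; for example, the open-book experiment for Llama-2-7b on SQuAD could be launched as:

    python 3_run_ob_experiment.py --config config/ob/llama7b/squad.conf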
--------------------------------------------------------------------------------
/config/mask/llama70b/searchqa.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-70b-chat-hf"
quantized: True
model_parallelism: True

custom_prompt: "Answer the question concisely using the context.Context: Question: Answer:"
metric_name: "BEM"
sameness_metric: "EM"

dataset: "SearchQA"
dataset_path: "data//conflict/.parquet"
dataset_length: null

results_dir: "data//ob"
output_path: "results//mask_.out"

masking_strategy: "input_tokens_space/1"
--------------------------------------------------------------------------------
/config/ob/llama7b/nq.conf:
--------------------------------------------------------------------------------
model_name: "meta-llama/Llama-2-7b-chat-hf"
2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "NaturalQuestionsShort" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//ob_.out" 15 | -------------------------------------------------------------------------------- /config/ob/mistral7b/hotpotqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "HotpotQA" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//ob_.out" 15 | -------------------------------------------------------------------------------- /config/ob/mistral7b/newsqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "NewsQA" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//ob_.out" 15 | -------------------------------------------------------------------------------- /config/ob/mistral7b/triviaqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible through the information given in the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "TriviaQA-web" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//ob_.out" 15 | -------------------------------------------------------------------------------- /config/ob/mixtral8x7b/newsqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "NewsQA" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//ob_.out" 15 | -------------------------------------------------------------------------------- /config/ob/mixtral8x7b/squad.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly 
from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "SQuAD" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//ob_.out" 15 | -------------------------------------------------------------------------------- /config/ob/mixtral8x7b/triviaqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible through the information given in the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "TriviaQA-web" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//ob_.out" 15 | -------------------------------------------------------------------------------- /config/mask/mistral7b/searchqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question concisely using the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "SearchQA" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//mask_.out" 15 | 16 | masking_strategy: "input_tokens_space/1" 17 | -------------------------------------------------------------------------------- /config/mask/mixtral8x7b/searchqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question concisely using the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "SearchQA" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//mask_.out" 15 | 16 | masking_strategy: "input_tokens_space/1" 17 | -------------------------------------------------------------------------------- /config/ob/llama70b/nq.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "NaturalQuestionsShort" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//ob_.out" 15 | -------------------------------------------------------------------------------- /config/ob/mixtral8x7b/hotpotqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "HotpotQA" 10 | dataset_path: 
"data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//ob_.out" 15 | -------------------------------------------------------------------------------- /1_gather_cb_answers.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import munch 3 | import yaml 4 | 5 | import src.closedbook_exp as closedbook_exp 6 | 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser() 10 | 11 | parser.add_argument("--config", type=str, required=True, help="Path to the config file.") 12 | 13 | args = parser.parse_args() 14 | 15 | with open(args.config, 'r') as file: 16 | config = yaml.safe_load(file) 17 | config = munch.munchify(config) 18 | 19 | closedbook_exp.run_closed_book_experiment(config) 20 | 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /config/ob/mistral7b/nq.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "NaturalQuestionsShort" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//ob_.out" 15 | -------------------------------------------------------------------------------- /config/ob/mixtral8x7b/nq.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "NaturalQuestionsShort" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//ob_.out" 15 | -------------------------------------------------------------------------------- /config/custom/add.conf: -------------------------------------------------------------------------------- 1 | model_name: "your_model_name" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Unrelated text: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "your_dataset_name" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//add_.out" 15 | -------------------------------------------------------------------------------- /config/add/llama70b/newsqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Unrelated text: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "NewsQA" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | 
output_path: "results//add_.out" 15 | -------------------------------------------------------------------------------- /config/add/llama70b/squad.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Unrelated text: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "SQuAD" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//add_.out" 15 | -------------------------------------------------------------------------------- /config/add/llama7b/newsqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-7b-chat-hf" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Unrelated text: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "NewsQA" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//add_.out" 15 | -------------------------------------------------------------------------------- /config/add/llama7b/squad.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-7b-chat-hf" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Unrelated text: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "SQuAD" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//add_.out" 15 | -------------------------------------------------------------------------------- /config/add/llama7b/triviaqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-7b-chat-hf" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible through the information given in the context.Context: Unrelated text: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "TriviaQA-web" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//add_.out" 15 | -------------------------------------------------------------------------------- /config/add/llama70b/hotpotqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Unrelated text: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "HotpotQA" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//add_.out" 15 | -------------------------------------------------------------------------------- 
/config/add/llama70b/triviaqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible through the information given in the context.Context: Unrelated text: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "TriviaQA-web" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//add_.out" 15 | -------------------------------------------------------------------------------- /config/add/llama7b/hotpotqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-7b-chat-hf" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Unrelated text: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "HotpotQA" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//add_.out" 15 | -------------------------------------------------------------------------------- /config/add/mistral7b/squad.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Unrelated text: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "SQuAD" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//add_.out" 15 | -------------------------------------------------------------------------------- /config/add/llama7b/nq.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-7b-chat-hf" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Unrelated text: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "NaturalQuestionsShort" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//add_.out" 15 | -------------------------------------------------------------------------------- /config/add/mistral7b/hotpotqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Unrelated text: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "HotpotQA" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//add_.out" 15 | -------------------------------------------------------------------------------- /config/add/mistral7b/newsqa.conf: 
-------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Unrelated text: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "NewsQA" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//add_.out" 15 | -------------------------------------------------------------------------------- /config/add/mistral7b/triviaqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible through the information given in the context.Context: Unrelated text: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "TriviaQA-web" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//add_.out" 15 | -------------------------------------------------------------------------------- /config/add/mixtral8x7b/newsqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Unrelated text: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "NewsQA" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//add_.out" 15 | -------------------------------------------------------------------------------- /config/add/mixtral8x7b/squad.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Unrelated text: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "SQuAD" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//add_.out" 15 | -------------------------------------------------------------------------------- /config/add/mixtral8x7b/triviaqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible through the information given in the context.Context: Unrelated text: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "TriviaQA-web" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//add_.out" 15 | -------------------------------------------------------------------------------- /config/add/llama70b/nq.conf: -------------------------------------------------------------------------------- 1 | 
model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Unrelated text: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "NaturalQuestionsShort" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//add_.out" 15 | -------------------------------------------------------------------------------- /config/add/mixtral8x7b/hotpotqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Unrelated text: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "HotpotQA" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//add_.out" 15 | -------------------------------------------------------------------------------- /config/mask/llama70b/squad.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "SQuAD" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//mask_.out" 15 | 16 | masking_strategy: "input_tokens_space/1" 17 | -------------------------------------------------------------------------------- /config/mask/llama7b/newsqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-7b-chat-hf" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "NewsQA" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//mask_.out" 15 | 16 | masking_strategy: "input_tokens_space/1" 17 | -------------------------------------------------------------------------------- /config/mask/llama7b/squad.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-7b-chat-hf" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "SQuAD" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//mask_.out" 15 | 16 | masking_strategy: "input_tokens_space/1" 17 | -------------------------------------------------------------------------------- /config/mask/mistral7b/squad.conf: 
-------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "SQuAD" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//mask_.out" 15 | 16 | masking_strategy: "input_tokens_space/1" -------------------------------------------------------------------------------- /config/add/mistral7b/nq.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Unrelated text: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "NaturalQuestionsShort" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//add_.out" 15 | -------------------------------------------------------------------------------- /config/add/mixtral8x7b/nq.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Unrelated text: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "NaturalQuestionsShort" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//add_.out" 15 | -------------------------------------------------------------------------------- /config/mask/llama70b/hotpotqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "HotpotQA" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//mask_.out" 15 | 16 | masking_strategy: "input_tokens_space/1" 17 | -------------------------------------------------------------------------------- /config/mask/llama70b/newsqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "NewsQA" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//mask_.out" 15 | 16 | masking_strategy: "input_tokens_space/1" 17 | -------------------------------------------------------------------------------- 
/config/mask/llama70b/triviaqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible through the information given in the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "TriviaQA-web" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//mask_.out" 15 | 16 | masking_strategy: "input_tokens_space/1" 17 | -------------------------------------------------------------------------------- /config/mask/llama7b/hotpotqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-7b-chat-hf" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "HotpotQA" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//mask_.out" 15 | 16 | masking_strategy: "input_tokens_space/1" 17 | -------------------------------------------------------------------------------- /config/mask/llama7b/triviaqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-7b-chat-hf" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible through the information given in the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "TriviaQA-web" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//mask_.out" 15 | 16 | masking_strategy: "input_tokens_space/1" 17 | -------------------------------------------------------------------------------- /config/mask/mistral7b/newsqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "NewsQA" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//mask_.out" 15 | 16 | masking_strategy: "input_tokens_space/1" 17 | -------------------------------------------------------------------------------- /config/mask/mistral7b/triviaqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible through the information given in the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "TriviaQA-web" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//mask_.out" 15 | 16 | masking_strategy: "input_tokens_space/1" 17 | 
-------------------------------------------------------------------------------- /config/mask/mixtral8x7b/newsqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "NewsQA" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//mask_.out" 15 | 16 | masking_strategy: "input_tokens_space/1" 17 | -------------------------------------------------------------------------------- /config/mask/mixtral8x7b/squad.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "SQuAD" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//mask_.out" 15 | 16 | masking_strategy: "input_tokens_space/1" 17 | -------------------------------------------------------------------------------- /config/mask/llama70b/nq.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "NaturalQuestionsShort" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//mask_.out" 15 | 16 | masking_strategy: "input_tokens_space/1" 17 | -------------------------------------------------------------------------------- /config/mask/llama7b/nq.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-7b-chat-hf" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "NaturalQuestionsShort" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//mask_.out" 15 | 16 | masking_strategy: "input_tokens_space/1" 17 | -------------------------------------------------------------------------------- /config/mask/mistral7b/hotpotqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "HotpotQA" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | 
results_dir: "data//ob" 14 | output_path: "results//mask_.out" 15 | 16 | masking_strategy: "input_tokens_space/1" 17 | -------------------------------------------------------------------------------- /config/mask/mixtral8x7b/hotpotqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "HotpotQA" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//mask_.out" 15 | 16 | masking_strategy: "input_tokens_space/1" 17 | -------------------------------------------------------------------------------- /config/mask/mixtral8x7b/triviaqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible through the information given in the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "TriviaQA-web" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//mask_.out" 15 | 16 | masking_strategy: "input_tokens_space/1" 17 | -------------------------------------------------------------------------------- /config/icl/llama70b/searchqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question concisely using the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 5 11 | 12 | dataset: "SearchQA" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/icl/llama7b/searchqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-7b-chat-hf" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question concisely using the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 5 11 | 12 | dataset: "SearchQA" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/mask/mistral7b/nq.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "NaturalQuestionsShort" 10 | dataset_path: 
"data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//mask_.out" 15 | 16 | masking_strategy: "input_tokens_space/1" 17 | -------------------------------------------------------------------------------- /config/icl/mistral7b/searchqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question concisely using the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 8 11 | 12 | dataset: "SearchQA" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/mask/mixtral8x7b/nq.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | dataset: "NaturalQuestionsShort" 10 | dataset_path: "data//conflict/.parquet" 11 | dataset_length: null 12 | 13 | results_dir: "data//ob" 14 | output_path: "results//mask_.out" 15 | 16 | masking_strategy: "input_tokens_space/1" 17 | -------------------------------------------------------------------------------- /config/icl/mixtral8x7b/searchqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question concisely using the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 8 11 | 12 | dataset: "SearchQA" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/cb/llama70b/newsqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 | dataset: "NewsQA" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 22 | -------------------------------------------------------------------------------- /config/cb/llama70b/squad.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 
| dataset: "SQuAD" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 22 | -------------------------------------------------------------------------------- /config/cb/llama7b/newsqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-7b-chat-hf" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 | dataset: "NewsQA" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 22 | -------------------------------------------------------------------------------- /config/cb/llama7b/squad.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-7b-chat-hf" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 | dataset: "SQuAD" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 22 | -------------------------------------------------------------------------------- /config/cb/llama70b/hotpotqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 | dataset: "HotpotQA" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 22 | -------------------------------------------------------------------------------- /config/cb/llama70b/searchqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 | dataset: "SearchQA" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 22 | -------------------------------------------------------------------------------- /config/cb/llama7b/searchqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-7b-chat-hf" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | 
custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 | dataset: "SearchQA" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 22 | -------------------------------------------------------------------------------- /config/cb/mistral7b/newsqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 | dataset: "NewsQA" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 22 | -------------------------------------------------------------------------------- /config/cb/mistral7b/squad.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 | dataset: "SQuAD" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 22 | -------------------------------------------------------------------------------- /config/cb/llama70b/nq.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 | dataset: "NaturalQuestionsShort" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 22 | -------------------------------------------------------------------------------- /config/cb/llama70b/triviaqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 | dataset: "TriviaQA-web" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 22 | -------------------------------------------------------------------------------- 
/config/cb/llama7b/hotpotqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-7b-chat-hf" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 | dataset: "HotpotQA" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 22 | 23 | -------------------------------------------------------------------------------- /config/cb/llama7b/nq.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-7b-chat-hf" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 | dataset: "NaturalQuestionsShort" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 22 | -------------------------------------------------------------------------------- /config/cb/llama7b/triviaqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-7b-chat-hf" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 | dataset: "TriviaQA-web" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 22 | -------------------------------------------------------------------------------- /config/cb/mistral7b/hotpotqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 | dataset: "HotpotQA" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 22 | -------------------------------------------------------------------------------- /config/cb/mistral7b/searchqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 | dataset: "SearchQA" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: 
null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 22 | -------------------------------------------------------------------------------- /config/cb/mixtral8x7b/newsqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 | dataset: "NewsQA" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 22 | -------------------------------------------------------------------------------- /config/cb/mixtral8x7b/squad.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 | dataset: "SQuAD" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 22 | -------------------------------------------------------------------------------- /config/cb/mistral7b/triviaqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 | dataset: "TriviaQA-web" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 22 | -------------------------------------------------------------------------------- /config/cb/mixtral8x7b/hotpotqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 | dataset: "HotpotQA" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 22 | -------------------------------------------------------------------------------- /config/cb/mixtral8x7b/searchqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | 
icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 | dataset: "SearchQA" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 22 | -------------------------------------------------------------------------------- /config/icl/llama70b/newsqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 5 11 | 12 | dataset: "NewsQA" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/icl/llama70b/squad.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 5 11 | 12 | dataset: "SQuAD" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/icl/llama70b/triviaqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible through the information given in the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 5 11 | 12 | dataset: "TriviaQA-web" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/icl/llama7b/newsqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-7b-chat-hf" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 5 11 | 12 | dataset: "NewsQA" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/icl/llama7b/squad.conf: -------------------------------------------------------------------------------- 1 | model_name: 
"meta-llama/Llama-2-7b-chat-hf" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 5 11 | 12 | dataset: "SQuAD" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/icl/llama7b/triviaqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-7b-chat-hf" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible through the information given in the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 5 11 | 12 | dataset: "TriviaQA-web" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/cb/mistral7b/nq.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 | dataset: "NaturalQuestionsShort" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 22 | -------------------------------------------------------------------------------- /config/cb/mixtral8x7b/nq.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 | dataset: "NaturalQuestionsShort" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 22 | -------------------------------------------------------------------------------- /config/cb/mixtral8x7b/triviaqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/icl/.parquet" 11 | 12 | dataset: "TriviaQA-web" 13 | dataset_path: "data/test/.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: null 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: null 20 | 21 | output_path: "results//cb_.out" 
22 | -------------------------------------------------------------------------------- /config/icl/llama70b/hotpotqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 5 11 | 12 | dataset: "HotpotQA" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/icl/llama7b/hotpotqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-7b-chat-hf" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 5 11 | 12 | dataset: "HotpotQA" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/icl/mistral7b/newsqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 8 11 | 12 | dataset: "NewsQA" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/icl/mistral7b/squad.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 8 11 | 12 | dataset: "SQuAD" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/icl/llama70b/nq.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 5 11 
| 12 | dataset: "NaturalQuestionsShort" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/icl/llama7b/nq.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-7b-chat-hf" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 5 11 | 12 | dataset: "NaturalQuestionsShort" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/icl/mistral7b/hotpotqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 8 11 | 12 | dataset: "HotpotQA" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/icl/mistral7b/triviaqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible through the information given in the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 8 11 | 12 | dataset: "TriviaQA-web" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/icl/mixtral8x7b/newsqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 8 11 | 12 | dataset: "NewsQA" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/icl/mixtral8x7b/squad.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few 
words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 8 11 | 12 | dataset: "SQuAD" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/icl/mixtral8x7b/triviaqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible through the information given in the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 8 11 | 12 | dataset: "TriviaQA-web" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/icl/mixtral8x7b/hotpotqa.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 8 11 | 12 | dataset: "HotpotQA" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/icl/mistral7b/nq.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 8 11 | 12 | dataset: "NaturalQuestionsShort" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/icl/mixtral8x7b/nq.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question with as few words as possible by extracting information directly from the context.Context: Question: Answer:" 6 | metric_name: "BEM" 7 | sameness_metric: "EM" 8 | 9 | icl_demo_prompt: "Context: Question: Answer: " 10 | icl_n: 8 11 | 12 | dataset: "NaturalQuestionsShort" 13 | dataset_path: "data//conflict/.parquet" 14 | dataset_length: null 15 | 16 | results_dir: "data//icl" 17 | output_path: "results//icl_.out" 18 | -------------------------------------------------------------------------------- /config/custom/cb.conf: 
-------------------------------------------------------------------------------- 1 | model_name: "your_model_name" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/custom/icl_.parquet" 11 | 12 | dataset: "your_dataset_name" 13 | dataset_path: "data/custom/test_.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: "data//cb_correct/.parquet" 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: "data//cb_full/.parquet" 20 | 21 | output_path: "results//cb_.out" 22 | 23 | -------------------------------------------------------------------------------- /config/freshqa/llama70b.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-70b-chat-hf" 2 | quantized: True 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/freshqa/icl_freshqa.parquet" 11 | 12 | dataset: "freshqa" 13 | dataset_path: "data/freshqa/changing_freshqa.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: "data//cb_correct/.parquet" 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: "data//cb_full/.parquet" 20 | 21 | output_path: "results//.out" 22 | 23 | -------------------------------------------------------------------------------- /config/freshqa/llama7b.conf: -------------------------------------------------------------------------------- 1 | model_name: "meta-llama/Llama-2-7b-chat-hf" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/freshqa/icl_freshqa.parquet" 11 | 12 | dataset: "freshqa" 13 | dataset_path: "data/freshqa/changing_freshqa.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: "data//cb_correct/.parquet" 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: "data//cb_full/.parquet" 20 | 21 | output_path: "results//.out" 22 | 23 | -------------------------------------------------------------------------------- /config/freshqa/mistral7b.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mistral-7B-Instruct-v0.2" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 | icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/freshqa/icl_freshqa.parquet" 11 | 12 | dataset: "freshqa" 13 | dataset_path: "data/freshqa/changing_freshqa.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: "data//cb_correct/.parquet" 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: "data//cb_full/.parquet" 20 | 21 | output_path: "results//.out" 22 | 23 | -------------------------------------------------------------------------------- /config/freshqa/mixtral8x7b.conf: -------------------------------------------------------------------------------- 1 | model_name: "mistralai/Mixtral-8x7B-Instruct-v0.1" 2 | quantized: False 3 | model_parallelism: True 4 | 5 | custom_prompt: "Answer the question.Question: Answer:" 6 | metric_name: BEM 7 | 8 
| icl_demo_prompt: "Question: Answer: " 9 | icl_n: 10 10 | icl_dataset_path: "data/freshqa/icl_freshqa.parquet" 11 | 12 | dataset: "freshqa" 13 | dataset_path: "data/freshqa/changing_freshqa.parquet" 14 | dataset_length: null 15 | 16 | 17 | correct_examples_path: "data//cb_correct/.parquet" 18 | wrong_examples_path: "data//cb_wrong/.parquet" 19 | full_examples_path: "data//cb_full/.parquet" 20 | 21 | output_path: "results//.out" 22 | 23 | -------------------------------------------------------------------------------- /src/validation.py: -------------------------------------------------------------------------------- 1 | def ensure_string_fields(dataset, fields): 2 | """Ensure that all examples in the dataset have all the fields as strings 3 | 4 | If the field is not present, it is added with an empty string. 5 | 6 | Mutates the dataset. 7 | """ 8 | for example in dataset: 9 | for field in fields: 10 | value_to_write = str(example.get(field, "")) 11 | example[field] = value_to_write 12 | 13 | 14 | def assert_fields_exist(dataset, fields): 15 | """Assert that all examples in the dataset have all the fields. 16 | 17 | If field is a list then any of the fields should exist. 18 | """ 19 | for i, example in enumerate(dataset): 20 | for field in fields: 21 | if isinstance(field, list): 22 | assert any([f in example for f in field]), f"None of {field} is in example {i}: {example}" 23 | else: 24 | assert field in example, f"Field {field} not in example {i}: {example}" 25 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.28.0 2 | aiohttp==3.9.3 3 | aiosignal==1.3.1 4 | async-timeout==4.0.3 5 | attrs==23.2.0 6 | bitsandbytes==0.43.0 7 | certifi==2024.2.2 8 | charset-normalizer==3.3.2 9 | datasets==2.18.0 10 | dill==0.3.8 11 | filelock==3.13.3 12 | frozenlist==1.4.1 13 | fsspec==2024.2.0 14 | huggingface-hub==0.22.1 15 | idna==3.6 16 | Jinja2==3.1.3 17 | jusText==3.0.0 18 | lxml==5.2.1 19 | lxml_html_clean==0.1.1 20 | MarkupSafe==2.1.5 21 | mpmath==1.3.0 22 | multidict==6.0.5 23 | multiprocess==0.70.16 24 | munch==4.0.0 25 | networkx==3.2.1 26 | numpy==1.26.4 27 | nvidia-cublas-cu12==12.1.3.1 28 | nvidia-cuda-cupti-cu12==12.1.105 29 | nvidia-cuda-nvrtc-cu12==12.1.105 30 | nvidia-cuda-runtime-cu12==12.1.105 31 | nvidia-cudnn-cu12==8.9.2.26 32 | nvidia-cufft-cu12==11.0.2.54 33 | nvidia-curand-cu12==10.3.2.106 34 | nvidia-cusolver-cu12==11.4.5.107 35 | nvidia-cusparse-cu12==12.1.0.106 36 | nvidia-nccl-cu12==2.19.3 37 | nvidia-nvjitlink-cu12==12.4.99 38 | nvidia-nvtx-cu12==12.1.105 39 | packaging==24.0 40 | pandas==2.2.1 41 | protobuf==5.26.1 42 | psutil==5.9.8 43 | pyarrow==15.0.2 44 | pyarrow-hotfix==0.6 45 | python-dateutil==2.9.0.post0 46 | pytz==2024.1 47 | PyYAML==6.0.1 48 | regex==2023.12.25 49 | requests==2.31.0 50 | safetensors==0.4.2 51 | scipy==1.13.0 52 | sentencepiece==0.2.0 53 | six==1.16.0 54 | sympy==1.12 55 | tokenizers==0.15.2 56 | torch==2.2.2 57 | tqdm==4.66.2 58 | transformers==4.39.2 59 | triton==2.2.0 60 | typing_extensions==4.10.0 61 | tzdata==2024.1 62 | urllib3==2.2.1 63 | xxhash==3.4.1 64 | yarl==1.9.4 65 | -------------------------------------------------------------------------------- /src/evaluation/icl.py: -------------------------------------------------------------------------------- 1 | import src.file_utils as file_utils 2 | 3 | ICL_SEPARATOR = "" 4 | 5 | def format_icl_demonstration(icl_demo_prompt, 
icl_demonstration_example): 6 | question = icl_demonstration_example["question"] 7 | if not question.endswith("?"): 8 | question = question + "?" 9 | icl_demo_prompt = icl_demo_prompt.replace("", question) 10 | icl_demo_prompt = icl_demo_prompt.replace("", icl_demonstration_example["context"]) 11 | 12 | cb_answer = icl_demonstration_example.get("closedbook_answer", "") 13 | if "" in icl_demo_prompt and not cb_answer: 14 | print("WARNING: in prompt but no closedbook_answer in example") 15 | icl_demo_prompt = icl_demo_prompt.replace("", cb_answer) 16 | 17 | icl_demo_prompt = icl_demo_prompt.replace("", icl_demonstration_example["answers"][0]) 18 | return icl_demo_prompt 19 | 20 | 21 | def prepare_prompt_for_icl( 22 | original_prompt, icl_demo_prompt, icl_n, icl_dataset_path 23 | ): 24 | assert "" in original_prompt 25 | icl_dataset = file_utils.load_parquet_dataset(icl_dataset_path) 26 | icl_demonstration_examples = icl_dataset[:icl_n] 27 | 28 | icl_demo_strings = [ 29 | format_icl_demonstration(icl_demo_prompt, ex) 30 | for ex in icl_demonstration_examples 31 | ] 32 | full_icl_demonstration = ICL_SEPARATOR.join(icl_demo_strings) 33 | original_prompt = original_prompt.replace("", full_icl_demonstration) 34 | return original_prompt 35 | 36 | 37 | -------------------------------------------------------------------------------- /src/evaluation/true.py: -------------------------------------------------------------------------------- 1 | import transformers as tf 2 | 3 | 4 | class TrueNLIClassifier: 5 | """Entailment classifier from the paper 6 | 7 | "TRUE: Re-evaluating Factual Consistency Evaluation" 8 | 9 | 10 | """ 11 | MODEL_URL = "google/t5_xxl_true_nli_mixture" 12 | 13 | def __init__(self): 14 | model_args = {} 15 | model_args["device_map"] = "auto" 16 | 17 | self.tokenizer = tf.T5Tokenizer.from_pretrained(self.MODEL_URL) 18 | self.model = tf.T5ForConditionalGeneration.from_pretrained(self.MODEL_URL, **model_args) 19 | 20 | 21 | @staticmethod 22 | def format_example_for_autoais(context, question, answer): 23 | premise = context 24 | hypothesis = f"The answer to the question '{question}' is '{answer}'" 25 | return f"premise: {premise} hypothesis: {hypothesis}" 26 | 27 | 28 | def infer_entailment(self, context, question, answer): 29 | """Runs inference for assessing AIS between a premise and hypothesis. 30 | 31 | Args: 32 | context: The context text, used as the premise. 33 | question: The question being answered. 34 | answer: The candidate answer; together with the question it forms the hypothesis. 35 | 36 | Returns: 37 | A bool indicating whether the context entails the answer.
38 | """ 39 | input_text = self.format_example_for_autoais(context, question, answer) 40 | input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids.to(self.model.device) 41 | outputs = self.model.generate(input_ids) 42 | result = self.tokenizer.decode(outputs[0], skip_special_tokens=True) 43 | entailment = True if result == "1" else False 44 | return entailment 45 | -------------------------------------------------------------------------------- /2_filter_out_no_conflict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import munch 3 | from tqdm import tqdm 4 | import yaml 5 | 6 | import src.file_utils as file_utils 7 | import src.evaluation.true as true 8 | 9 | 10 | def divide_dataset_into_nli_true_false(data): 11 | print(f"Loading TRUE NLI classifier") 12 | nli_classifier = true.TrueNLIClassifier() 13 | print(f"Loaded NLI classifier on device {nli_classifier.model.device_map}") 14 | 15 | nli_true, nli_false = [], [] 16 | for example in tqdm(data): 17 | question = example["question"] 18 | context = example["context"] 19 | answer = example["closedbook_answer"] 20 | pred_entail = nli_classifier.infer_entailment(context, question, answer) 21 | example["nli_pred"] = pred_entail 22 | if pred_entail: 23 | nli_true.append(example) 24 | else: 25 | nli_false.append(example) 26 | return nli_true, nli_false 27 | 28 | 29 | def filter_examples_by_true(config): 30 | wrong_cb_path = config.wrong_examples_path.replace("", config.model_name).replace("", config.dataset) 31 | print(f"Loading wrong closed book examples from {wrong_cb_path}") 32 | wrong_cb_data = file_utils.load_parquet_dataset(wrong_cb_path) 33 | 34 | no_conflict_data, conflict_data = divide_dataset_into_nli_true_false(wrong_cb_data) 35 | 36 | conflict_path = config.conflict_examples_path.replace("", config.model_name).replace("", config.dataset) 37 | print(f"Saving examples with knowledge conflict to {conflict_path}") 38 | file_utils.save_dataset_to_parquet(conflict_data, conflict_path) 39 | 40 | 41 | if __name__ == "__main__": 42 | parser = argparse.ArgumentParser() 43 | 44 | parser.add_argument("--config", type=str, required=True, help="Path to the config file.") 45 | 46 | args = parser.parse_args() 47 | 48 | with open(args.config, 'r') as file: 49 | config = yaml.safe_load(file) 50 | config = munch.munchify(config) 51 | 52 | filter_examples_by_true(config) 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /src/analysis.py: -------------------------------------------------------------------------------- 1 | import scipy.stats as stats 2 | 3 | def binomial_hypothesis_test(m_1, n_1, m_0, n_0): 4 | """Bayesian hypothesis testing with binomial likelihood and beta prior. 5 | 6 | Args: 7 | m_1 (int): Number of successes in the first group. 8 | n_1 (int): Number of trials in the first group. 9 | m_0 (int): Number of successes in the second group. 10 | n_0 (int): Number of trials in the second group. 11 | 12 | Returns: 13 | float: p-value (probability of observing such or more extreme data given the null hypothesis). 14 | 15 | We model each group as a sequence of Bernoulli trials with unknown success probability. 16 | We treat the success probability as a random variable p with a beta(1, 1) (uniform) prior. 17 | 18 | The likelihood of the data given the success probability is binomial. 19 | The posterior distribution of the success probability is beta with parameters alpha = m + 1 and beta = n - m + 1. 
20 | 21 | We marginalize over the posterior distribution of the success probability to obtain the predictive distribution of the data. 22 | The predictive distribution is a beta-binomial distribution. 23 | 24 | The null hypothesis is that the success probability is the same in both groups. 25 | We calculate the probability of observing such or more extreme data given the null hypothesis. 26 | If p_1 > p_0 then we calculate the probability of observing m_1 or more successes in n_1 trials 27 | given m_0 successes in n_0 trials. 28 | 29 | """ 30 | if n_0 == 0 or n_1 == 0: 31 | print("Number of trials in each group must be greater than 0.") 32 | return -1 33 | 34 | # MAP estimates of p_0 and p_1 under the Beta(1, 1) prior (equal to the empirical rates) 35 | p_0 = m_0 / n_0 36 | p_1 = m_1 / n_1 37 | 38 | # Predictive distribution of the data given the null hypothesis 39 | pred_dist = stats.betabinom(n=n_1, a=m_0 + 1, b=n_0 - m_0 + 1) 40 | 41 | # Probability of observing such or more extreme data given the null hypothesis 42 | if p_1 > p_0: 43 | # -1 because cdf is P(X <= x) and we want P(X >= x) 44 | p_value = 1 - pred_dist.cdf(m_1 - 1) 45 | else: 46 | p_value = pred_dist.cdf(m_1) 47 | 48 | return p_value -------------------------------------------------------------------------------- /src/evaluation/bem.py: -------------------------------------------------------------------------------- 1 | import transformers as tf 2 | import torch 3 | 4 | class BemMetric: 5 | """BEM metric from the paper 6 | "Tomayto, Tomahto. Beyond Token-level Answer Equivalence for Question Answering Evaluation" 7 | """ 8 | 9 | MODEL_URL = "kortukov/answer-equivalence-bem" 10 | def __init__(self): 11 | self.tokenizer = tf.AutoTokenizer.from_pretrained(self.MODEL_URL) 12 | self.model = tf.AutoModelForSequenceClassification.from_pretrained(self.MODEL_URL) 13 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 14 | self.model.to(self.device) 15 | 16 | def tokenize_function(self, question, reference, candidate): 17 | text = f"[CLS] {candidate} [SEP]" 18 | text_pair = f"{reference} [SEP] {question} [SEP]" 19 | inputs = self.tokenizer( 20 | text=text, 21 | text_pair=text_pair, 22 | add_special_tokens=False, 23 | padding='max_length', 24 | truncation=True, 25 | return_tensors='pt' 26 | ) 27 | return inputs.to(self.device) 28 | 29 | 30 | def compute_bem_score(self, question, reference, candidate): 31 | inputs = self.tokenize_function(question, reference, candidate) 32 | out = self.model(**inputs) 33 | 34 | bem_score = torch.nn.functional.softmax(out.logits, dim=-1)[0,1].item() 35 | return bem_score 36 | 37 | def correct_by_bem(self, question, reference, candidate): 38 | return self.compute_bem_score(question, reference, candidate) > 0.5 39 | 40 | def correct_by_disjunction_bem(self, question, reference, candidate): 41 | """Checks if the candidate is equivalent to the reference or the reference is equivalent to the candidate. 42 | 43 | Answer equivalence is originally an asymmetric relation: the candidate is equivalent if it contains the same or better information. 44 | In our setting we empirically found that this disjunctive, symmetrized relation works better.
45 | """ 46 | return self.correct_by_bem(question, reference, candidate) or self.correct_by_bem(question, candidate, reference) 47 | 48 | def any_correct_by_bem(self, question, references, candidate): 49 | return any([self.correct_by_bem(question, reference, candidate) for reference in references]) 50 | 51 | def any_correct_by_disjunction_bem(self, question, references, candidate): 52 | return any([self.correct_by_disjunction_bem(question, reference, candidate) for reference in references]) 53 | -------------------------------------------------------------------------------- /4_download_freshqa.py: -------------------------------------------------------------------------------- 1 | import justext 2 | import os 3 | import pandas as pd 4 | import requests 5 | 6 | RAW_FRESHQA_URL = "https://docs.google.com/spreadsheet/ccc?key=1V6nIxVTI9tqZ-wfgK-uFuUPiGEa1Zmnz53OeGbaNtO0&output=csv" 7 | RAW_FRESHQA_PATH = "data/freshqa/raw_freshqa.csv" 8 | CHANGING_FRESHQA_PATH = "data/freshqa/changing_freshqa.parquet" 9 | ICL_FRESHQA_PATH = "data/freshqa/icl_freshqa.parquet" 10 | 11 | 12 | def save_csv(): 13 | print("Downloading FreshQA CSV") 14 | response = requests.get(RAW_FRESHQA_URL) 15 | assert response.status_code == 200, f"Failed to download FreshQA CSV, response {response.status_code}" 16 | 17 | print("Saving FreshQA CSV") 18 | os.makedirs(os.path.dirname(RAW_FRESHQA_PATH), exist_ok=True) 19 | with open(RAW_FRESHQA_PATH, "wb") as f: 20 | f.write(response.content) 21 | 22 | 23 | def get_text(urls, idx=[0]): 24 | 25 | # Is this neat or is this horrible? I like it. 26 | print(idx[0]) 27 | idx[0]+=1 28 | 29 | total_text = "" 30 | 31 | if not urls or pd.isna(urls): 32 | return total_text 33 | urls = urls.split('\n') 34 | 35 | for url in urls: 36 | try: 37 | response = requests.get(url, timeout=10) 38 | paragraphs = justext.justext(response.content, justext.get_stoplist("English")) 39 | url_text = "\n".join([p.text for p in paragraphs if not p.is_boilerplate]) 40 | total_text += url_text + "\n" 41 | except: 42 | print(f"Error with url: {url}") 43 | 44 | return total_text 45 | 46 | 47 | def process_freshqa(): 48 | print("Processing FreshQA data") 49 | 50 | print("Filling in answers") 51 | freshqa = pd.read_csv(RAW_FRESHQA_PATH, skiprows=[0,1]) 52 | answer_columns = [col for col in freshqa.columns if 'answer' in col] 53 | freshqa['answers'] = freshqa[answer_columns].apply(lambda row: [a for a in row if pd.notna(a)], axis=1) 54 | 55 | print("Filling in context from the web source") 56 | freshqa['context'] = freshqa.source.apply(get_text) 57 | 58 | no_empty_context = freshqa[freshqa.context.apply(lambda x: len(x) > 0)] 59 | 60 | icl_examples = no_empty_context.sample(10, random_state=42) 61 | 62 | full_freshqa = no_empty_context[~no_empty_context.index.isin(icl_examples.index)] 63 | freshqa_slow = full_freshqa[full_freshqa.fact_type == "slow-changing"] 64 | freshqa_fast = full_freshqa[full_freshqa.fact_type == "fast-changing"] 65 | 66 | freshqa_changing = pd.concat([freshqa_slow, freshqa_fast]) 67 | 68 | print("Saving data") 69 | icl_examples.to_parquet(ICL_FRESHQA_PATH) 70 | freshqa_changing.to_parquet(CHANGING_FRESHQA_PATH) 71 | 72 | if __name__ == "__main__": 73 | save_csv() 74 | 75 | process_freshqa() -------------------------------------------------------------------------------- /0_download_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datasets 3 | import os 4 | 5 | DATA_DIR = "data" 6 | 7 | 8 | def remove_tags_from_ctx(example): 9 
| context = example['context'] 10 | context = context.replace("[PAR]", "\n") 11 | 12 | context = context.replace("[DOC]", "") 13 | context = context.replace("[TLE]", "") 14 | context = context.replace("[SEP]", "") 15 | context = context.strip() 16 | example['context'] = context 17 | return example 18 | 19 | 20 | def download_custom_dataset(custom_dataset_name, data_dir): 21 | full_data = datasets.load_dataset(custom_dataset_name) 22 | 23 | # For ICL we shuffle the original data and only save 10 examples. 24 | shuffled_data = full_data.shuffle(seed=42) 25 | print("We use 10 random examples for ICL dataset.") 26 | icl_data = shuffled_data.select(range(10)) 27 | # For test data we use the remaining examples 28 | test_data = shuffled_data.select(range(10, len(shuffled_data))) 29 | 30 | icl_path = data_dir + f"/icl_{custom_dataset_name}.parquet" 31 | test_path = data_dir + f"/test_{custom_dataset_name}.parquet" 32 | 33 | icl_data.to_parquet(icl_path) 34 | test_data.to_parquet(test_path) 35 | 36 | 37 | def download_data(args): 38 | data_dir = f"{DATA_DIR}/{args.dataset_type}" 39 | os.makedirs(data_dir, exist_ok=True) 40 | 41 | if args.dataset_type == "custom": 42 | download_custom_dataset(args.custom_dataset_name, data_dir) 43 | return 44 | elif args.dataset_type == "test": 45 | # We use MrQA validation split as our test data 46 | dataset_split = "validation" 47 | else: 48 | # MrQA train split acts as ICL dataset in our experiments 49 | dataset_split = "train" 50 | 51 | print(f"Downloading {args.dataset_type} data.") 52 | full_data = datasets.load_dataset("mrqa", split=dataset_split) 53 | split_subset_names = list(set(full_data['subset'])) 54 | 55 | for subset in split_subset_names: 56 | print(f"Processing {subset} subset.") 57 | subset_data = full_data.filter(lambda ex: ex['subset'] == subset) 58 | 59 | if args.dataset_type == "icl": 60 | # For ICL we shuffle the original data and only save 10 examples. 61 | shuffled_data = subset_data.shuffle(seed=42) 62 | subset_data = shuffled_data.select(range(10)) 63 | 64 | subset_data = subset_data.map(remove_tags_from_ctx) 65 | subset_split_path = data_dir + f"/{subset}.parquet" 66 | 67 | subset_data.to_parquet(subset_split_path) 68 | 69 | 70 | if __name__ == "__main__": 71 | parser = argparse.ArgumentParser() 72 | 73 | parser.add_argument( 74 | "--dataset-type", type=str, required=True, 75 | choices=["test", "icl", "custom"], help="Which dataset to download." 76 | ) 77 | 78 | parser.add_argument( 79 | "--custom-dataset-name", type=str, default=None, 80 | help="Huggingface hub id of the custom dataset." 
81 | ) 82 | 83 | args = parser.parse_args() 84 | 85 | download_data(args) 86 | 87 | -------------------------------------------------------------------------------- /src/evaluation/exact_match.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import re 3 | import string 4 | 5 | import src.model_utils as model_utils 6 | 7 | 8 | def normalize_answer(answer_string): 9 | """Lower text and remove punctuation, articles and extra whitespace.""" 10 | def remove_articles(text): 11 | return re.sub(r"\b(a|an|the)\b", " ", text) 12 | 13 | def white_space_fix(text): 14 | return " ".join(text.split()) 15 | 16 | def remove_punc(text): 17 | exclude = set(string.punctuation) 18 | return "".join(ch for ch in text if ch not in exclude) 19 | 20 | def lower(text): 21 | return text.lower() 22 | 23 | return white_space_fix(remove_articles(remove_punc(lower(answer_string)))) 24 | 25 | def normalize_without_lowering(answer_string): 26 | """Remove punctuation, articles and extra whitespace.""" 27 | def remove_articles(text): 28 | return re.sub(r"\b(a|an|the)\b", " ", text) 29 | 30 | def white_space_fix(text): 31 | return " ".join(text.split()) 32 | 33 | def remove_punc(text): 34 | exclude = set(string.punctuation) 35 | return "".join(ch for ch in text if ch not in exclude) 36 | 37 | return white_space_fix(remove_articles(remove_punc(answer_string))) 38 | 39 | def text_has_answer(answers, text) -> bool: 40 | """Check if any of the answers is in the text.""" 41 | if isinstance(answers, str): 42 | answers = [answers] 43 | text = normalize_answer(text) 44 | for single_answer in answers: 45 | single_answer = normalize_answer(single_answer) 46 | if single_answer in text: 47 | return True 48 | return False 49 | 50 | 51 | def compute_exact_match(prediction, ground_truth): 52 | """Check if normalized prediction is same as normalized ground truth.""" 53 | return normalize_answer(prediction) == normalize_answer(ground_truth) 54 | 55 | 56 | def get_generated_answer(outputs, tokenizer, input_len): 57 | """Get the generated answer from the model outputs.""" 58 | 59 | generation_str = tokenizer.decode( 60 | outputs[0, input_len:].cpu(), skip_special_tokens=True 61 | ) 62 | answer = generation_str.split("\n")[0] 63 | return answer 64 | 65 | 66 | def check_if_model_output_is_correct_and_get_prediction( 67 | outputs, tokenizer, input_len, answers 68 | ): 69 | """Check if the model output is correct. 70 | 71 | Outputs is full model generation, including both the prompt and the 72 | generated text. Length of original input in tokens is given by input_len. 73 | """ 74 | prediction = get_generated_answer(outputs, tokenizer, input_len) 75 | is_correct = any( 76 | [compute_exact_match(prediction, answer) for answer in answers] 77 | ) 78 | return is_correct, prediction 79 | 80 | 81 | def compute_f1(a_gold, a_pred, tokenizer): 82 | """Computes the token-level F1 score of two utterances, a_gold and a_pred. 
83 | 84 | Code taken from SQuAD evaluation script: https://rajpurkar.github.io/SQuAD-explorer/ 85 | """ 86 | gold_toks = tokenizer.encode(a_gold)[1:] # Skip start of sequence token 87 | pred_toks = tokenizer.encode(a_pred)[1:] 88 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 89 | num_same = sum(common.values()) 90 | if len(gold_toks) == 0 or len(pred_toks) == 0: 91 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 92 | return int(gold_toks == pred_toks) 93 | if num_same == 0: 94 | return 0 95 | precision = 1.0 * num_same / len(pred_toks) 96 | recall = 1.0 * num_same / len(gold_toks) 97 | f1 = (2 * precision * recall) / (precision + recall) 98 | return f1 -------------------------------------------------------------------------------- /src/closedbook_exp.py: -------------------------------------------------------------------------------- 1 | import json 2 | import munch 3 | import os 4 | 5 | import src.evaluation.icl as icl 6 | import src.evaluation.question_answering as question_answering 7 | import src.validation as validation 8 | import src.file_utils as file_utils 9 | import src.model_utils as model_utils 10 | 11 | 12 | def run_closed_book_experiment(exp_config: munch.Munch): 13 | """Stage 1: Closed-book answer gathering. 14 | 15 | We run an LLM on a dataset closed-book to probe for its parametric knowledge. 16 | We also save the answers to later determine when a knowledge conflict between 17 | parametric and contextual knowledge occurs. 18 | """ 19 | custom_prompt = exp_config.custom_prompt 20 | 21 | logging_dict = {} 22 | def log_metric(metric_name, value): 23 | logging_dict[metric_name] = value 24 | 25 | if "metric_name" not in exp_config or not exp_config.metric_name: 26 | exp_config.metric_name = "EM" 27 | 28 | assert exp_config.metric_name in ["EM", "BEM"] 29 | 30 | model, tokenizer, _, device = model_utils.load_model_and_tokenizer( 31 | model_name=exp_config.model_name, 32 | model_parallelism=exp_config.model_parallelism, 33 | quantized=exp_config.quantized, 34 | ) 35 | 36 | file_utils.replace_placeholders_in_paths( 37 | exp_config, 38 | path_keys=["dataset_path", "correct_examples_path", "wrong_examples_path", "full_examples_path", "icl_dataset_path", "output_path"] 39 | ) 40 | 41 | dataset = file_utils.load_parquet_dataset(exp_config.dataset_path) 42 | 43 | if "dataset_length" in exp_config and exp_config.dataset_length: 44 | dataset = dataset[:exp_config.dataset_length] 45 | print(f"Using only {exp_config.dataset_length} examples from the dataset") 46 | 47 | validation.assert_fields_exist( 48 | dataset=dataset, 49 | fields=["context", "question", "answers"], 50 | ) 51 | validation.ensure_string_fields( 52 | dataset=dataset, 53 | fields=["context", "question", "contextual_answer"], 54 | ) 55 | 56 | if "" in custom_prompt: 57 | assert "icl_demo_prompt" in exp_config 58 | assert "icl_n" in exp_config 59 | assert "icl_dataset_path" in exp_config 60 | custom_prompt = icl.prepare_prompt_for_icl( 61 | custom_prompt, exp_config.icl_demo_prompt, exp_config.icl_n, exp_config.icl_dataset_path 62 | ) 63 | 64 | correct_ratio, correct_examples, wrong_examples, additional_results = question_answering.evaluate_closed_book( 65 | model, tokenizer, dataset, custom_prompt, device, exp_config.metric_name 66 | ) 67 | log_metric(exp_config.metric_name, correct_ratio) 68 | 69 | correct_pct = correct_ratio * 100 70 | 71 | log_metric("Full data", len(dataset)) 72 | log_metric("Closed-book correct", f"{len(correct_examples)} ({correct_pct:.2f}%)") 73 | 
log_metric("Closed-book wrong", f"{len(wrong_examples)} ({100 - correct_pct:.2f}%)") 74 | 75 | log_metric("Parametric answer in context", additional_results["cb_in_ctx_ratio"]) 76 | log_metric("Incorrect out of parametric in context", additional_results["incorrect_given_cb_in_ctx"]) 77 | 78 | print(f"{exp_config.metric_name}: {correct_pct:.2f}%") 79 | print(f"{len(correct_examples)} / {len(dataset)} correct examples") 80 | 81 | if exp_config.correct_examples_path: 82 | file_utils.save_dataset_to_parquet( 83 | correct_examples, exp_config.correct_examples_path 84 | ) 85 | 86 | if exp_config.wrong_examples_path: 87 | file_utils.save_dataset_to_parquet( 88 | wrong_examples, exp_config.wrong_examples_path 89 | ) 90 | 91 | if exp_config.full_examples_path: 92 | file_utils.save_dataset_to_parquet( 93 | dataset, exp_config.full_examples_path 94 | ) 95 | 96 | if exp_config.output_path: 97 | # Ensure parent dir exists 98 | os.makedirs(os.path.dirname(exp_config.output_path), exist_ok=True) 99 | 100 | # Pretty print the logging dict as json to the output path 101 | with open(exp_config.output_path, "w") as f: 102 | f.write(json.dumps(logging_dict, indent=4)) 103 | else: 104 | print("Output path, not specified") 105 | print(f"Results: ") 106 | print(json.dumps(logging_dict, indent=4)) 107 | 108 | 109 | -------------------------------------------------------------------------------- /src/model_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import time 4 | import torch 5 | import transformers as tf 6 | 7 | 8 | 9 | 10 | def load_tokenizer(model_name): 11 | """Load tokenizer from model_name""" 12 | if "llama" in model_name or "Llama" in model_name: 13 | return tf.LlamaTokenizer.from_pretrained(model_name) 14 | return tf.AutoTokenizer.from_pretrained(model_name) 15 | 16 | 17 | def load_model_and_tokenizer( 18 | model_name, model_parallelism=False, cache_dir=None, auth_token=None, quantized=False, 19 | ): 20 | """Load model and tokenizer from model_name""" 21 | device = "cuda" if torch.cuda.is_available() else "cpu" 22 | device_count = torch.cuda.device_count() 23 | 24 | 25 | config = tf.AutoConfig.from_pretrained(model_name) 26 | model_args = {} 27 | if cache_dir is not None: 28 | model_args["cache_dir"] = cache_dir 29 | if model_parallelism: 30 | model_args["device_map"] = "auto" 31 | model_args["low_cpu_mem_usage"] = True 32 | if hasattr(config, "torch_dtype") and config.torch_dtype is not None: 33 | model_args["torch_dtype"] = config.torch_dtype 34 | if auth_token is not None: 35 | model_args["use_auth_token"] = auth_token 36 | if quantized: 37 | quant_config = tf.BitsAndBytesConfig( 38 | load_in_4bit=True, 39 | bnb_4bit_use_double_quant=True, # Save 0.4 bits per parameter 40 | bnb_4bit_quant_type="nf4", # Recommended by https://huggingface.co/blog/4bit-transformers-bitsandbytes 41 | bnb_4bit_compute_dtype=torch.bfloat16, # bfloat 8b range, 7b precision, for large numbers (source https://huggingface.co/blog/hf-bitsandbytes-integration) 42 | ) 43 | model_args["quantization_config"] = quant_config 44 | 45 | if "gemma" in model_name: 46 | torch.backends.cuda.enable_mem_efficient_sdp(False) 47 | torch.backends.cuda.enable_flash_sdp(False) 48 | 49 | if "GPTQ" in model_name: 50 | # I use model_name = "TheBloke/Llama-2-7b-Chat-GPTQ" 51 | model = tf.AutoModelForCausalLM.from_pretrained( 52 | model_name, 53 | device_map="auto", 54 | trust_remote_code=True, 55 | revision="gptq-3bit-128g-actorder_True", 56 | ).to(device) 57 | 
tokenizer = tf.AutoTokenizer.from_pretrained(model_name, use_fast=True) 58 | else: 59 | model = tf.AutoModelForCausalLM.from_pretrained( 60 | model_name, **model_args 61 | ).eval() 62 | if not model_parallelism and not quantized: 63 | model = model.to(device) 64 | tokenizer = load_tokenizer(model_name) 65 | 66 | if device_count > 1 and not model_parallelism: 67 | model = torch.nn.DataParallel(model) 68 | 69 | if model_parallelism: 70 | print(f"Device map: {model.hf_device_map}") 71 | 72 | return model, tokenizer, config, device 73 | 74 | 75 | def generate_one_token(model, inputs): 76 | """Generate one token from the model.""" 77 | with torch.no_grad(): 78 | out = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask) 79 | next_token = out.logits[0, -1, :].argmax().item() 80 | return next_token 81 | 82 | 83 | def generate_one_token_with_activations(model, inputs, activations_to_collect): 84 | """Generate one token and return the token and the hidden states (activations)""" 85 | with torch.no_grad(): 86 | out = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, output_hidden_states=True) 87 | next_token = out.logits[0, -1, :].argmax().item() 88 | if activations_to_collect == "last": 89 | aggregated_activations = (h[:, -1, :] for h in out.hidden_states ) 90 | elif activations_to_collect == "mean": 91 | aggregated_activations = (h.mean(dim=1) for h in out.hidden_states) 92 | elif activations_to_collect == "max": 93 | aggregated_activations = (h.max(dim=1)[0] for h in out.hidden_states) 94 | else: 95 | raise ValueError(f"Unknown activations_to_collect parameter: {activations_to_collect}, should be one of 'last', 'mean', 'max'") 96 | stacked_activations = np.stack([h.cpu().detach().numpy() for h in aggregated_activations]) 97 | return next_token, stacked_activations 98 | 99 | 100 | def generate_answer( 101 | model, tokenizer, inputs, max_tokens_to_generate, device 102 | ): 103 | """Generate answer by one token at a time. 104 | 105 | Return the generated tokens. 
106 | """ 107 | 108 | for i in range(max_tokens_to_generate): 109 | next_token= generate_one_token(model, inputs) 110 | 111 | if next_token == tokenizer.eos_token_id: 112 | break 113 | 114 | # Add the generated token to the input for the next iteration 115 | inputs.input_ids = torch.cat( 116 | [inputs.input_ids, torch.tensor([[next_token]], device=device)], dim=-1 117 | ) 118 | # Add attention mask for the new token 119 | inputs.attention_mask = torch.cat( 120 | [inputs.attention_mask, torch.tensor([[1]], device=device)], dim=-1 121 | ) 122 | 123 | # Now input ids contains both the prompt and the generated tokens 124 | outputs = inputs.input_ids 125 | return outputs 126 | -------------------------------------------------------------------------------- /src/file_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import munch 4 | import os 5 | import pandas as pd 6 | import pickle 7 | import yaml 8 | 9 | logger = logging.getLogger() 10 | logger.setLevel(logging.INFO) 11 | log_formatter = logging.Formatter( 12 | "[%(thread)s] %(asctime)s [%(levelname)s] %(name)s: %(message)s" 13 | ) 14 | console = logging.StreamHandler() 15 | console.setFormatter(log_formatter) 16 | logger.addHandler(console) 17 | 18 | 19 | def print_args(args, output_dir=None, output_file=None): 20 | """Logs the args to the console and optionally to a file.""" 21 | assert output_dir is None or output_file is None 22 | 23 | logger.info(" **************** CONFIGURATION **************** ") 24 | for key, val in sorted(vars(args).items()): 25 | keystr = "{}".format(key) + (" " * (30 - len(key))) 26 | logger.info("%s --> %s", keystr, val) 27 | logger.info(" **************** CONFIGURATION **************** ") 28 | 29 | if output_dir is not None or output_file is not None: 30 | output_file = output_file or os.path.join(output_dir, "args.txt") 31 | with open(output_file, "w", encoding="utf-8") as file: 32 | for key, val in sorted(vars(args).items()): 33 | keystr = f"{key}" + (" " * (30 - len(key))) 34 | file.write(f"{keystr} {val}\n") 35 | 36 | 37 | def load_parquet_dataset(dataset_path: str): 38 | """Load a dataset from parquet to a list of dicts""" 39 | data = pd.read_parquet(dataset_path) 40 | return data.to_dict(orient="records") 41 | 42 | 43 | def save_dataset_to_parquet(dataset, dataset_path: str): 44 | """Save a list of dicts to a parquet file""" 45 | print(f"Saving dataset of length {len(dataset)} to path {dataset_path}") 46 | # Create parent dirs if needed 47 | os.makedirs(os.path.dirname(dataset_path), exist_ok=True) 48 | 49 | dataframe = pd.DataFrame(dataset) 50 | dataframe.to_parquet(dataset_path) 51 | 52 | 53 | def load_csv_dataset(dataset_path: str): 54 | """Load a dataset from csv to a list of dicts""" 55 | data = pd.read_csv(dataset_path) 56 | return data.to_dict(orient="records") 57 | 58 | 59 | def save_dataset_to_csv(dataset, dataset_path: str): 60 | """Save a list of dicts to a csv file""" 61 | print(f"Saving dataset of length {len(dataset)} to path {dataset_path}") 62 | # Create parent dirs if needed 63 | os.makedirs(os.path.dirname(dataset_path), exist_ok=True) 64 | 65 | dataframe = pd.DataFrame(dataset) 66 | dataframe.to_csv(dataset_path) 67 | 68 | 69 | def load_json_dataset(dataset_path: str): 70 | """Loads a dataset from a json file""" 71 | print("Loading dataset:", dataset_path) 72 | with open(dataset_path, encoding="utf-8") as file: 73 | return json.load(file) 74 | 75 | 76 | def load_several_json_datasets(dataset_paths: 
"list[str]"): 77 | """Returns a dict mapping dataset names to datasets""" 78 | datasets = {} 79 | for dataset_path in dataset_paths: 80 | dataset_name = os.path.basename(dataset_path).replace(".json", "") 81 | datasets[dataset_name] = load_json_dataset(dataset_path) 82 | return datasets 83 | 84 | 85 | def save_json_dataset(dataset, dataset_path: str): 86 | """Saves a dataset to a json file 87 | 88 | Dataset must be json-serializable. It must be a list of dicts, 89 | where each dict is a datapoint. 90 | """ 91 | print(f"Saving dataset of length {len(dataset)} to path {dataset_path}") 92 | # Create parent dirs if needed 93 | os.makedirs(os.path.dirname(dataset_path), exist_ok=True) 94 | with open(dataset_path, 'w', encoding="utf-8") as file: 95 | json.dump(dataset, file) 96 | 97 | 98 | def save_to_pickle(object_to_pickle, path): 99 | """Saves an object to a pickle file""" 100 | # Create parent dirs if needed 101 | os.makedirs(os.path.dirname(path), exist_ok=True) 102 | with open(path, "wb") as f: 103 | pickle.dump(object_to_pickle, f) 104 | 105 | 106 | def load_yaml_config_as_munch(config_path: str): 107 | """Load a yaml config file as a munch object. 108 | Munch is just a dict that allows you to access keys as attributes. 109 | """ 110 | with open(config_path, 'r', encoding="utf-8") as file: 111 | yaml_config = yaml.safe_load(file) 112 | munch_config = munch.munchify(yaml_config) 113 | return munch_config 114 | 115 | 116 | def dummy_logger(*args, **kwargs): 117 | """A dummy logger that does nothing""" 118 | pass 119 | 120 | def replace_placeholders_in_paths(exp_config: munch.Munch, path_keys: list): 121 | """Replace placeholders , and in the 122 | specified path_keys with actual values from exp_config. 123 | 124 | This is useful when we want to run the same experiment on different splits 125 | of the dataset, or with different models. 126 | 127 | Mutates exp_config inplace. 128 | """ 129 | 130 | if "dataset_split" in exp_config: 131 | for key in path_keys: 132 | if key not in exp_config: 133 | continue 134 | if exp_config[key] is None: 135 | continue 136 | 137 | exp_config[key] = exp_config[key].replace( 138 | "", exp_config.dataset_split 139 | ) 140 | 141 | 142 | if "model_name" in exp_config: 143 | for key in path_keys: 144 | if key not in exp_config: 145 | continue 146 | if exp_config[key] is None: 147 | continue 148 | 149 | exp_config[key] = exp_config[key].replace( 150 | "", exp_config.model_name 151 | ) 152 | 153 | if "dataset" in exp_config: 154 | for key in path_keys: 155 | if key not in exp_config: 156 | continue 157 | if exp_config[key] is None: 158 | continue 159 | 160 | exp_config[key] = exp_config[key].replace( 161 | "", exp_config.dataset 162 | ) -------------------------------------------------------------------------------- /src/evaluation/question_answering.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | 3 | import src.evaluation.bem as bem 4 | import src.evaluation.exact_match as exact_match 5 | import src.evaluation.prompt_helpers as prompt_helpers 6 | import src.model_utils as model_utils 7 | 8 | 9 | def get_example_is_correct_fun(metric_name): 10 | """Get a function that checks if the model prediction is correct. 11 | 12 | Its parameters are question, context, answers, prediction. 
13 | """ 14 | if metric_name == "EM": 15 | example_is_correct = lambda question, context, answers, prediction: any( 16 | [exact_match.compute_exact_match(prediction, answer) for answer in answers] 17 | ) 18 | elif metric_name == "BEM": 19 | bem_metric = bem.BemMetric() 20 | example_is_correct = lambda question, context, answers, prediction: bem_metric.any_correct_by_disjunction_bem( 21 | question, answers, prediction 22 | ) 23 | else: 24 | raise ValueError(f"Unknown metric name {metric_name}") 25 | return example_is_correct 26 | 27 | 28 | class ClosedBookExampleResult: 29 | """Result of one example closed-book evaluation""" 30 | def __init__(self, example, correct, generated_answer, cb_in_ctx): 31 | self.example = example 32 | self.correct = correct 33 | self.generated_answer = generated_answer 34 | self.cb_in_ctx = cb_in_ctx 35 | 36 | def __repr__(self): 37 | return f"ClosedBookExampleResult(correct={self.correct})" 38 | 39 | def evaluate_example_closed_book( 40 | model, tokenizer, example, custom_prompt, device, example_is_correct_fun 41 | ): 42 | """Evaluate a CausalLM model on a single example from the closed book 43 | dataset given a custom prompt. 44 | """ 45 | answers = example["answers"] 46 | 47 | inputs, input_text = prompt_helpers.get_input_ids_with_prompt( 48 | tokenizer, example, custom_prompt, device 49 | ) 50 | input_len = inputs.input_ids.shape[-1] 51 | 52 | outputs = model_utils.generate_answer( 53 | model, tokenizer, inputs, max_tokens_to_generate=10, device=device 54 | ) 55 | 56 | generated_answer = exact_match.get_generated_answer( 57 | outputs, tokenizer, input_len 58 | ) 59 | is_correct = example_is_correct_fun(example["question"], example["context"], answers, generated_answer) 60 | 61 | cb_in_ctx = prompt_helpers.tokenized_answer_found_in_model_inputs( 62 | answer=generated_answer, model_inputs=inputs, tokenizer=tokenizer, 63 | ) 64 | 65 | return ClosedBookExampleResult(example, is_correct, generated_answer, cb_in_ctx) 66 | 67 | 68 | def evaluate_closed_book( 69 | model, tokenizer, dataset, custom_prompt, device, metric_name 70 | ): 71 | """Evaluate a CausalLM model on the closed book dataset. 72 | Metric can be EM (Exact Match) or BEM (BERT Exact Match). 73 | 74 | Report the metric. 75 | Save the subset that is answered correctly. 76 | Also save the subset that is answered wrongly. 77 | For both of them, save the generated answer. 
78 | """ 79 | num_correct = 0 80 | total_cost = 0.0 81 | correct_examples = [] 82 | wrong_examples = [] 83 | cb_in_ctx_correct = [] 84 | cb_in_ctx_wrong = [] 85 | 86 | example_is_correct_fun = get_example_is_correct_fun(metric_name) 87 | 88 | model.eval() 89 | for idx, example in tqdm(enumerate(dataset)): 90 | ex_result = evaluate_example_closed_book( 91 | model, tokenizer, example, custom_prompt, device, example_is_correct_fun 92 | ) 93 | 94 | example["closedbook_answer"] = ex_result.generated_answer 95 | 96 | if ex_result.correct: 97 | num_correct += 1 98 | correct_examples.append(example) 99 | if ex_result.cb_in_ctx: 100 | cb_in_ctx_correct.append(example) 101 | else: 102 | wrong_examples.append(example) 103 | if ex_result.cb_in_ctx: 104 | cb_in_ctx_wrong.append(example) 105 | 106 | num_cb_in_ctx = len(cb_in_ctx_correct) + len(cb_in_ctx_wrong) 107 | 108 | additional_results = { 109 | "cb_in_ctx_ratio": num_cb_in_ctx / len(dataset), 110 | "incorrect_given_cb_in_ctx": len(cb_in_ctx_wrong) / num_cb_in_ctx, 111 | } 112 | 113 | 114 | correct_pct = num_correct / len(dataset) 115 | return correct_pct, correct_examples, wrong_examples, additional_results 116 | 117 | 118 | class OpenBookExampleResult: 119 | """Result of one example open-book evaluation (with context)""" 120 | def __init__(self, example, correct, generated_answer, same_as_closed_book, cb_in_ctx, input_len): 121 | self.example = example 122 | self.correct = correct 123 | self.generated_answer = generated_answer 124 | self.same_as_closed_book = same_as_closed_book 125 | self.cb_in_ctx = cb_in_ctx 126 | self.input_len = input_len 127 | 128 | def __repr__(self): 129 | return f"OpenBookExampleResult(correct={self.correct}, same_as_closed_book={self.same_as_closed_book}, cb_in_ctx={self.cb_in_ctx})" 130 | 131 | 132 | def evaluate_example_openbook( 133 | model, 134 | tokenizer, 135 | example, 136 | custom_prompt, 137 | device, 138 | example_is_correct_fun, 139 | example_is_same_fun, 140 | masking_strategy : str = None, 141 | ) -> OpenBookExampleResult: 142 | """Evaluate one example open-book (with context).""" 143 | 144 | 145 | if "answers" in example: 146 | ctx_answers = example["answers"] 147 | else: 148 | raise ValueError(f"Example must contain the fields 'answers'") 149 | 150 | closed_book_answers = [example["closedbook_answer"]] 151 | 152 | inputs, input_text = prompt_helpers.get_input_ids_with_prompt( 153 | tokenizer, example, custom_prompt, device 154 | ) 155 | input_len = inputs.input_ids.shape[-1] 156 | 157 | if masking_strategy: 158 | inputs, input_text = prompt_helpers.mask_cb_answer(tokenizer, inputs, example, masking_strategy) 159 | 160 | cb_in_ctx = prompt_helpers.tokenized_answer_found_in_model_inputs( 161 | answer=closed_book_answers[0], model_inputs=inputs, tokenizer=tokenizer, 162 | ) 163 | 164 | outputs = model_utils.generate_answer( 165 | model, tokenizer, inputs, max_tokens_to_generate=10, device=device 166 | ) 167 | 168 | generated_answer = exact_match.get_generated_answer( 169 | outputs, tokenizer, input_len 170 | ) 171 | ctx_is_correct = example_is_correct_fun(example["question"], example["context"], ctx_answers, generated_answer) 172 | 173 | same_as_closed_book = example_is_same_fun(example["question"], example["context"], closed_book_answers, generated_answer) 174 | 175 | correct = ctx_is_correct 176 | 177 | return OpenBookExampleResult(example, correct, generated_answer, same_as_closed_book, cb_in_ctx, input_len) 178 | 179 | 180 | def evaluate_openbook( 181 | model, 182 | tokenizer, 183 | dataset, 184 | 
custom_prompt, 185 | device, 186 | metric_name, 187 | sameness_metric, 188 | masking_strategy : str = None, 189 | ): 190 | """Evaluate a CausalLM model on an open-book question answering task (with contexts). 191 | 192 | Each example has a question, context, ground-truth answer and a closed-book answer (given by the model 193 | in previous experiments). 194 | 195 | For each example check if model predicts the Correct updated answer, or if wrong if 196 | it is same as closed-book or different. 197 | (Assumption is that we run on the Wrong-closed-book subset). 198 | """ 199 | incorrect_update = [] 200 | retain_parametric = [] 201 | correct_update = [] 202 | 203 | ctx_lengths = [] 204 | 205 | example_is_correct_fun = get_example_is_correct_fun(metric_name) 206 | example_is_same_fun = get_example_is_correct_fun(sameness_metric) 207 | 208 | cb_in_ctx_incorrect_update, cb_in_ctx_retain_parametric, cb_in_ctx_correct_update = 0, 0, 0 209 | 210 | model.eval() 211 | for example in tqdm(dataset): 212 | ex_result = evaluate_example_openbook( 213 | model, 214 | tokenizer, 215 | example, 216 | custom_prompt, 217 | device, 218 | example_is_correct_fun, 219 | example_is_same_fun, 220 | masking_strategy, 221 | ) 222 | example["openbook_answer"] = ex_result.generated_answer 223 | 224 | if ex_result.correct: 225 | correct_update.append(example) 226 | cb_in_ctx_correct_update += ex_result.cb_in_ctx 227 | else: 228 | if ex_result.same_as_closed_book: 229 | retain_parametric.append(example) 230 | cb_in_ctx_retain_parametric += ex_result.cb_in_ctx 231 | else: 232 | incorrect_update.append(example) 233 | cb_in_ctx_incorrect_update += ex_result.cb_in_ctx 234 | 235 | ctx_lengths.append(ex_result.input_len) 236 | 237 | 238 | results = { 239 | "incorrect_update": incorrect_update, 240 | "retain_parametric": retain_parametric, 241 | "correct_update": correct_update, 242 | "cb_in_ctx": { 243 | "incorrect_update": cb_in_ctx_incorrect_update, 244 | "retain_parametric": cb_in_ctx_retain_parametric, 245 | "correct_update": cb_in_ctx_correct_update, 246 | }, 247 | "cb_not_in_ctx": { 248 | "incorrect_update": len(incorrect_update) - cb_in_ctx_incorrect_update, 249 | "retain_parametric": len(retain_parametric) - cb_in_ctx_retain_parametric, 250 | "correct_update": len(correct_update) - cb_in_ctx_correct_update, 251 | }, 252 | "ctx_len_min": min(ctx_lengths), 253 | "ctx_len_avg": sum(ctx_lengths) / len(ctx_lengths), 254 | "ctx_len_max": max(ctx_lengths), 255 | } 256 | return results 257 | 258 | 259 | -------------------------------------------------------------------------------- /src/evaluation/prompt_helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import src.evaluation.exact_match as exact_match 3 | 4 | 5 | def replace_placeholders_with_data(example, custom_prompt): 6 | """Replace placeholders in the custom prompt with actual data from the example.""" 7 | # Prompt must contain placeholders for context and question 8 | # assert "" in custom_prompt, "Prompt must contain " 9 | 10 | context = example["context"] 11 | question = example["question"] 12 | closedbook_answer = example.get("closedbook_answer", "") 13 | answer = example.get("answers", "")[0] 14 | paraphrase_context = example.get("paraphrase_context", "") 15 | 16 | prompt = custom_prompt.replace("", context) 17 | if not question.endswith("?"): 18 | question = question + "?" 
19 | 20 | prompt = prompt.replace("", question) 21 | prompt = prompt.replace("", closedbook_answer) 22 | prompt = prompt.replace("", answer) 23 | prompt = prompt.replace("", paraphrase_context) 24 | 25 | prompt = prompt.replace("", "\n") 26 | prompt = prompt.replace("", "[INST]") 27 | prompt = prompt.replace("", "[/INST]") 28 | 29 | return prompt 30 | 31 | 32 | def find_first_subtensor_index(tensor, sub_tensor): 33 | """Find the first index of sub_tensor in tensor.""" 34 | for i in range(len(tensor) - len(sub_tensor) + 1): 35 | if (tensor[i : i + len(sub_tensor)] == sub_tensor).all(): 36 | return i 37 | return -1 38 | 39 | 40 | def tokenized_answer_found_in_model_inputs(answer, model_inputs, tokenizer): 41 | # We search for the answer as is, normalized and normalized without lowering 42 | all_answers = [ 43 | answer, 44 | exact_match.normalize_answer(answer), 45 | exact_match.normalize_without_lowering(answer), 46 | ] 47 | 48 | all_answers_tokenized = [ 49 | tokenizer( 50 | answer, 51 | return_tensors="pt", 52 | ).to(model_inputs.input_ids.device) 53 | for answer in all_answers 54 | ] 55 | all_answers_tensors = [ 56 | answer_tokenized.input_ids[0][1:] 57 | for answer_tokenized in all_answers_tokenized 58 | ] 59 | 60 | text_tensor = model_inputs.input_ids[0] 61 | for answer_tensor in all_answers_tensors: 62 | sub_i = find_first_subtensor_index(text_tensor, answer_tensor) 63 | if sub_i != -1: 64 | return True 65 | 66 | return False 67 | 68 | 69 | 70 | def get_masking_token(masking_strategy, tokenizer, example, device): 71 | if masking_strategy == "input_tokens_space": 72 | mask = torch.Tensor([259]).long().to(device) # 259 is the space token 73 | elif masking_strategy == "input_tokens_remove": 74 | mask = None 75 | elif masking_strategy == "input_tokens_unk": 76 | mask = torch.Tensor([tokenizer.unk_token_id]).long().to(device) 77 | elif masking_strategy == "input_tokens_entity": 78 | mask = torch.Tensor([7855]).long().to(device) # 7855 is the 'entity' token 79 | elif masking_strategy == "input_tokens_paraphrase_gpt": 80 | para_tokenized = tokenizer( 81 | exact_match.normalize_without_lowering(example["paraphrase_closedbook_answer_gpt"]), 82 | return_tensors="pt", 83 | ).to(device) 84 | mask = para_tokenized.input_ids[0][1:] 85 | else: 86 | raise ValueError(f"Unknown masking strategy: {masking_strategy}") 87 | return mask 88 | 89 | 90 | def mask_cb_answer_with_attention_mask(inputs, cb_inputs): 91 | prompt_inputs = inputs.input_ids[0] 92 | for i in range(len(prompt_inputs) - len(cb_inputs) + 1): 93 | if (prompt_inputs[i : i + len(cb_inputs)] == cb_inputs).all(): 94 | inputs.attention_mask[0][i : i + len(cb_inputs)] = 0 95 | return inputs 96 | 97 | 98 | def mask_input_tokens_of_cb_answer(tokenizer, inputs, cb_inputs, example, masking_strategy): 99 | device = inputs.input_ids.device 100 | 101 | cb_answer_found = True 102 | # Defense from infinite loop (in case of a bug) 103 | num_substitutions = 0 104 | max_substitutions = 50 105 | while cb_answer_found: 106 | prompt_inputs = inputs.input_ids[0] 107 | sub_i = find_first_subtensor_index(prompt_inputs, cb_inputs) 108 | if sub_i == -1: 109 | cb_answer_found = False 110 | else: 111 | # CB answer is found, masking it out using attention mask 112 | mask = get_masking_token(masking_strategy, tokenizer, example, device) 113 | 114 | unmasked_tensor = prompt_inputs 115 | # Replace all the tokens of CB answer with one token of the mask 116 | if mask is not None: 117 | new_tensor = torch.cat([unmasked_tensor[:sub_i], mask, unmasked_tensor[sub_i + 
len(cb_inputs):]]).to(device) 118 | else: 119 | new_tensor = torch.cat([unmasked_tensor[:sub_i], unmasked_tensor[sub_i + len(cb_inputs):]]).to(device) 120 | 121 | inputs.input_ids = new_tensor.unsqueeze(0) 122 | inputs.attention_mask = torch.ones_like(inputs.input_ids) 123 | 124 | num_substitutions += 1 125 | if num_substitutions > max_substitutions: 126 | break 127 | 128 | return inputs 129 | 130 | 131 | def mask_a_tensor(tokenizer, inputs, example, masking_strategy, tensor_to_mask): 132 | 133 | split_strategy = masking_strategy.split("/") 134 | if len(split_strategy) == 1: 135 | masking_strategy, cb_answer_length = split_strategy[0], None 136 | elif len(split_strategy) == 2: 137 | masking_strategy, cb_answer_length = split_strategy 138 | cb_answer_length = int(cb_answer_length) 139 | else: 140 | raise ValueError(f"Invalid masking strategy: {masking_strategy}") 141 | 142 | if cb_answer_length: 143 | tensor_to_mask = tensor_to_mask[:cb_answer_length] 144 | 145 | if "input_text" in masking_strategy: 146 | input_text = tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True) 147 | text_to_mask = tokenizer.decode(tensor_to_mask, skip_special_tokens=True) 148 | if masking_strategy == "input_text_space": 149 | input_text = input_text.replace(text_to_mask, " ") 150 | elif masking_strategy == "input_text_paraphrase_gpt": 151 | para_text = exact_match.normalize_without_lowering(example["paraphrase_closedbook_answer_gpt"]) 152 | input_text = input_text.replace(text_to_mask, para_text) 153 | else: 154 | raise ValueError(f"Unknown masking strategy: {masking_strategy}") 155 | inputs = tokenizer(input_text, return_tensors="pt").to(inputs.input_ids.device) 156 | 157 | elif masking_strategy == "attention_mask_full": 158 | inputs = mask_cb_answer_with_attention_mask(inputs, tensor_to_mask) 159 | elif "input_tokens" in masking_strategy: 160 | inputs = mask_input_tokens_of_cb_answer(tokenizer, inputs, tensor_to_mask, example, masking_strategy) 161 | else: 162 | raise ValueError(f"Unknown masking strategy: {masking_strategy}") 163 | 164 | input_text = tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True) 165 | 166 | return inputs, input_text 167 | 168 | 169 | def mask_cb_answer(tokenizer, inputs, example, masking_strategy): 170 | """Find the CB answer in the input ids and mask it out using the specified strategy. 171 | 172 | Returns the modified inputs and decoded input text. 173 | Looks for both the lowercased and unlowercased versions of the CB answer. 174 | """ 175 | lowered_cb = exact_match.normalize_answer(example["closedbook_answer"]) 176 | unlowered_cb = exact_match.normalize_without_lowering(example["closedbook_answer"]) 177 | unnormalized_cb = example["closedbook_answer"] 178 | 179 | for normalized_cb_answer in [lowered_cb, unlowered_cb, unnormalized_cb]: 180 | cb_tokenized = tokenizer( 181 | normalized_cb_answer, 182 | return_tensors="pt", 183 | ).to(inputs.input_ids.device) 184 | cb_inputs = cb_tokenized.input_ids[0][1:] 185 | 186 | inputs, input_text = mask_a_tensor(tokenizer, inputs, example, masking_strategy, cb_inputs) 187 | return inputs, input_text 188 | 189 | 190 | def get_input_ids_with_prompt(tokenizer, example, custom_prompt, device): 191 | """Create input ids from the example and the custom prompt. 192 | 193 | Replace and with the actual context and question. 
194 | Example is a dictionary that must contain the following keys: 195 | - context 196 | - question 197 | - closedbook_answer (optional) 198 | """ 199 | prompt_with_data = replace_placeholders_with_data(example, custom_prompt) 200 | 201 | inputs = tokenizer(prompt_with_data, return_tensors="pt") 202 | 203 | return inputs.to(device), prompt_with_data 204 | 205 | 206 | def get_input_ids_with_doc_positions(tokenizer, example, custom_prompt, device): 207 | prompt = replace_placeholders_with_data(example, custom_prompt) 208 | 209 | sep1 = "Context: " 210 | sep2 = " Question:" 211 | assert sep1 in prompt 212 | assert sep2 in prompt 213 | prefix, docs = prompt.split(sep1) 214 | prefix = prefix + sep1 215 | docs, suffix = docs.split(sep2) 216 | suffix = sep2 + suffix 217 | 218 | prefix_input_ids = tokenizer(prefix, return_tensors="pt").input_ids 219 | # Remove the first token, which is the token 220 | docs_input_ids = tokenizer(docs, return_tensors="pt").input_ids[:, 1:] 221 | suffix_input_ids = tokenizer(suffix, return_tensors="pt").input_ids[:, 1:] 222 | 223 | combined_input_ids = torch.cat([prefix_input_ids, docs_input_ids, suffix_input_ids], dim=1) 224 | docs_start = prefix_input_ids.shape[1] 225 | docs_end = docs_start + docs_input_ids.shape[1] 226 | 227 | assert all(combined_input_ids[0, docs_start:docs_end] == docs_input_ids[0]) 228 | new_prompt = tokenizer.decode(combined_input_ids[0], skip_special_tokens=True) 229 | 230 | return combined_input_ids.to(device), new_prompt, docs_start, docs_end -------------------------------------------------------------------------------- /src/openbook_exp.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import munch 4 | 5 | import src.evaluation.icl as icl 6 | import src.evaluation.question_answering as question_answering 7 | import src.analysis as analysis 8 | import src.validation as validation 9 | import src.file_utils as file_utils 10 | import src.model_utils as model_utils 11 | 12 | 13 | def do_logging_based_on_results(log_metric: callable, results: dict): 14 | incorrect_update = results["incorrect_update"] 15 | retain_parametric = results["retain_parametric"] 16 | correct_update = results["correct_update"] 17 | num_examples = len(incorrect_update) + len(retain_parametric) + len(correct_update) 18 | 19 | incorrect_update_pct = len(incorrect_update) / num_examples 20 | retain_parametric_pct = len(retain_parametric) / num_examples 21 | correct_update_pct = len(correct_update) / num_examples 22 | 23 | log_metric("Incorrect update", incorrect_update_pct) 24 | log_metric("Retain parametric", retain_parametric_pct) 25 | log_metric("Correct update", correct_update_pct) 26 | 27 | log_metric("Num Incorrect update", len(incorrect_update)) 28 | log_metric("Num Retain parametric", len(retain_parametric)) 29 | log_metric("Num Correct update", len(correct_update)) 30 | 31 | cb_in_ctx_incorrect_update_pct = results["cb_in_ctx"]["incorrect_update"] / (len(incorrect_update) or 1) 32 | cb_in_ctx_retain_parametric_pct = results["cb_in_ctx"]["retain_parametric"] / (len(retain_parametric) or 1) 33 | cb_in_ctx_correct_update_pct = results["cb_in_ctx"]["correct_update"] / (len(correct_update) or 1) 34 | 35 | log_metric("CB in Context Incorrect update", cb_in_ctx_incorrect_update_pct) 36 | log_metric("CB in Context Retain parametric", cb_in_ctx_retain_parametric_pct) 37 | log_metric("CB in Context Correct update", cb_in_ctx_correct_update_pct) 38 | 39 | num_cb_in_ctx = ( 40 | 
results["cb_in_ctx"]["incorrect_update"] 41 | + results["cb_in_ctx"]["retain_parametric"] 42 | + results["cb_in_ctx"]["correct_update"] 43 | ) 44 | p_cb_in_ctx = num_cb_in_ctx / num_examples 45 | log_metric("Overall CB in Context", p_cb_in_ctx) 46 | 47 | log_metric("Num CB in CTX Incorrect update", results["cb_in_ctx"]["incorrect_update"]) 48 | log_metric("Num CB in CTX Retain parametric", results["cb_in_ctx"]["retain_parametric"]) 49 | log_metric("Num CB in CTX Correct update", results["cb_in_ctx"]["correct_update"]) 50 | 51 | log_metric("Num CB in Context", num_cb_in_ctx) 52 | 53 | num_not_cb_in_ctx = num_examples - num_cb_in_ctx 54 | num_not_cb_in_ctx_alt = ( 55 | results["cb_not_in_ctx"]["incorrect_update"] 56 | + results["cb_not_in_ctx"]["retain_parametric"] 57 | + results["cb_not_in_ctx"]["correct_update"] 58 | ) 59 | assert num_not_cb_in_ctx == num_not_cb_in_ctx_alt, "Calculations don't match" 60 | 61 | 62 | p_not_cb_in_ctx = 1 - p_cb_in_ctx 63 | 64 | log_metric("Overall Not CB in Context", p_not_cb_in_ctx) 65 | 66 | log_metric("Num Not CB in CTX Incorrect update", results["cb_not_in_ctx"]["incorrect_update"]) 67 | log_metric("Num Not CB in CTX Retain parametric", results["cb_not_in_ctx"]["retain_parametric"]) 68 | log_metric("Num Not CB in CTX Correct update", results["cb_not_in_ctx"]["correct_update"]) 69 | 70 | log_metric("Num Not CB in Context", num_not_cb_in_ctx) 71 | 72 | 73 | p_iu_given_cb_in_ctx = results["cb_in_ctx"]["incorrect_update"] / (num_cb_in_ctx or 1) 74 | log_metric("P(incorrect_update | cb_in_ctx)", p_iu_given_cb_in_ctx) 75 | 76 | p_iu_given_not_cb_in_ctx = results["cb_not_in_ctx"]["incorrect_update"] / (num_not_cb_in_ctx or 1) 77 | log_metric("P(incorrect_update | not cb_in_ctx)", p_iu_given_not_cb_in_ctx) 78 | 79 | # Test the hypothesis that p(iu) is different in the two groups 80 | p_value_iu = analysis.binomial_hypothesis_test( 81 | m_1=results["cb_in_ctx"]["incorrect_update"], 82 | n_1=num_cb_in_ctx, 83 | m_0=results["cb_not_in_ctx"]["incorrect_update"], 84 | n_0=num_not_cb_in_ctx, 85 | ) 86 | log_metric("P-val IU", p_value_iu) 87 | 88 | p_rp_given_cb_in_ctx = results["cb_in_ctx"]["retain_parametric"] / (num_cb_in_ctx or 1) 89 | log_metric("P(retain_parametric | cb_in_ctx)", p_rp_given_cb_in_ctx) 90 | 91 | p_rp_given_not_cb_in_ctx = results["cb_not_in_ctx"]["retain_parametric"] / (num_not_cb_in_ctx or 1) 92 | log_metric("P(retain_parametric | not cb_in_ctx)", p_rp_given_not_cb_in_ctx) 93 | 94 | # Test the hypothesis that p(rp) is different in the two groups 95 | p_value_rp = analysis.binomial_hypothesis_test( 96 | m_1=results["cb_in_ctx"]["retain_parametric"], 97 | n_1=num_cb_in_ctx, 98 | m_0=results["cb_not_in_ctx"]["retain_parametric"], 99 | n_0=num_not_cb_in_ctx, 100 | ) 101 | log_metric("P-val RP", p_value_rp) 102 | 103 | p_cu_given_cb_in_ctx = results["cb_in_ctx"]["correct_update"] / (num_cb_in_ctx or 1) 104 | log_metric("P(correct_update | cb_in_ctx)", p_cu_given_cb_in_ctx) 105 | 106 | p_cu_given_not_cb_in_ctx = results["cb_not_in_ctx"]["correct_update"] / (num_not_cb_in_ctx or 1) 107 | log_metric("P(correct_update | not cb_in_ctx)", p_cu_given_not_cb_in_ctx) 108 | 109 | # Test the hypothesis that p(cu) is different in the two groups 110 | p_value_cu = analysis.binomial_hypothesis_test( 111 | m_1=results["cb_in_ctx"]["correct_update"], 112 | n_1=num_cb_in_ctx, 113 | m_0=results["cb_not_in_ctx"]["correct_update"], 114 | n_0=num_not_cb_in_ctx, 115 | ) 116 | log_metric("P-val CU", p_value_cu) 117 | 118 | 119 | # Memorization ratio = (rp / (rp + cu)) 120 | 
memorization_ratio = retain_parametric_pct / (retain_parametric_pct + correct_update_pct) 121 | log_metric("Memorization Ratio", memorization_ratio) 122 | 123 | mem_ratio_given_cb_in_ctx = p_rp_given_cb_in_ctx / ((p_rp_given_cb_in_ctx + p_cu_given_cb_in_ctx) or 1) 124 | log_metric("Memorization Ratio | CB in Context", mem_ratio_given_cb_in_ctx) 125 | 126 | mem_ratio_given_not_cb_in_ctx = p_rp_given_not_cb_in_ctx / ((p_rp_given_not_cb_in_ctx + p_cu_given_not_cb_in_ctx) or 1) 127 | log_metric("Memorization Ratio | Not CB in Context", mem_ratio_given_not_cb_in_ctx) 128 | 129 | 130 | log_metric(f"Ctx len min", results["ctx_len_min"]) 131 | log_metric(f"Ctx len avg", results["ctx_len_avg"]) 132 | log_metric(f"Ctx len max", results["ctx_len_max"]) 133 | 134 | print(f"Incorrect update, Percentage: {len(incorrect_update) / num_examples:.2%} ({len(incorrect_update)} / {num_examples}), CB in Context: {cb_in_ctx_incorrect_update_pct}") 135 | print (f"Retain parametric, Percentage: {len(retain_parametric) / num_examples:.2%} ({len(retain_parametric)} / {num_examples}), CB in Context: {cb_in_ctx_retain_parametric_pct}") 136 | print(f"Correct update, Percentage {len(correct_update) / num_examples:.2%} ({len(correct_update)} / {num_examples}), CB in Context: {cb_in_ctx_correct_update_pct}") 137 | 138 | 139 | def run_openbook_experiment(exp_config: munch.Munch): 140 | """Run open-book QA experiment on a dataset that contains "context", 141 | "question", "answers" and "closedbook_answer" fields. 142 | """ 143 | 144 | # Logging dict to store metrics 145 | logging_dict = {} 146 | def log_metric(metric_name, value): 147 | logging_dict[metric_name] = value 148 | 149 | custom_prompt = exp_config.custom_prompt 150 | 151 | # Example level correctness metric with which we define correctness of model answers. 
152 | if "metric_name" not in exp_config or not exp_config.metric_name: 153 | exp_config.metric_name = "EM" 154 | 155 | assert exp_config.metric_name in ["EM", "BEM"] 156 | 157 | # Metric that identifies if open-book answer is same as closed-book answer 158 | if "sameness_metric" not in exp_config or not exp_config.sameness_metric: 159 | exp_config.sameness_metric = exp_config.metric_name 160 | 161 | assert exp_config.sameness_metric in ["EM", "BEM"] 162 | 163 | file_utils.replace_placeholders_in_paths( 164 | exp_config, 165 | path_keys=["dataset_path", "results_dir", "activation_dir", "icl_dataset_path", "output_path"] 166 | ) 167 | 168 | model, tokenizer, _, device = model_utils.load_model_and_tokenizer( 169 | model_name=exp_config.model_name, 170 | model_parallelism=exp_config.model_parallelism, 171 | quantized=exp_config.quantized, 172 | ) 173 | 174 | dataset = file_utils.load_parquet_dataset(exp_config.dataset_path) 175 | 176 | if "dataset_length" in exp_config and exp_config.dataset_length: 177 | dataset = dataset[:exp_config.dataset_length] 178 | print(f"Using only {exp_config.dataset_length} examples from the dataset") 179 | 180 | 181 | validation.assert_fields_exist( 182 | dataset=dataset, 183 | fields=["context", "question", "answers", "closedbook_answer"], 184 | ) 185 | validation.ensure_string_fields( 186 | dataset=dataset, 187 | fields=["context", "question", "closedbook_answer"], 188 | ) 189 | 190 | 191 | if "masking_strategy" not in exp_config or not exp_config.masking_strategy: 192 | exp_config.masking_strategy = None 193 | 194 | if "" in custom_prompt: 195 | assert "icl_demo_prompt" in exp_config 196 | assert "icl_n" in exp_config 197 | assert "icl_dataset_path" in exp_config 198 | custom_prompt = icl.prepare_prompt_for_icl( 199 | custom_prompt, exp_config.icl_demo_prompt, exp_config.icl_n, exp_config.icl_dataset_path 200 | ) 201 | 202 | results = question_answering.evaluate_openbook( 203 | model, 204 | tokenizer, 205 | dataset, 206 | custom_prompt, 207 | device, 208 | exp_config.metric_name, 209 | exp_config.sameness_metric, 210 | exp_config.masking_strategy, 211 | ) 212 | 213 | do_logging_based_on_results(log_metric, results) 214 | 215 | incorrect_update = results["incorrect_update"] 216 | retain_parametric = results["retain_parametric"] 217 | correct_update = results["correct_update"] 218 | 219 | # Save examples in each category 220 | if exp_config.results_dir: 221 | file_utils.save_dataset_to_parquet( 222 | incorrect_update, exp_config.results_dir + "/incorrect_update.parquet" 223 | ) 224 | file_utils.save_dataset_to_parquet( 225 | retain_parametric, exp_config.results_dir + "/retain_parametric.parquet" 226 | ) 227 | file_utils.save_dataset_to_parquet( 228 | correct_update, exp_config.results_dir + "/correct_update.parquet" 229 | ) 230 | 231 | 232 | if exp_config.output_path: 233 | # Ensure parent dir exists 234 | os.makedirs(os.path.dirname(exp_config.output_path), exist_ok=True) 235 | 236 | # Pretty print the logging dict as json to the output path 237 | with open(exp_config.output_path, "w") as f: 238 | f.write(json.dumps(logging_dict, indent=4)) 239 | else: 240 | print("Output path, not specified") 241 | print(f"Results: ") 242 | print(json.dumps(logging_dict, indent=4)) 243 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Studying Large Language Model Behaviors Under Context-Memory Conflicts With Real Documents 2 | 3 | Official 
repository for the paper [**Studying Large Language Model Behaviors Under Context-Memory Conflicts With Real Documents**](https://openreview.net/forum?id=xm8zYRfrqE). 4 | 5 | We introduce a framework for studying context-memory knowledge conflicts in a realistic setup (see image below). 6 | 7 |

8 | *Experimental design* 9 |

10 | 11 | We update incorrect parametric knowledge (Stages 1 and 2) using real conflicting documents (Stage 3). 12 | This reflects how knowledge conflicts arise in practice. 13 | In this realistic scenario, we find that knowledge updates fail less often than previously reported. 14 | 15 | 16 | In cases where the models still fail to update their answers, we find a **parametric bias**: the incorrect parametric answer appearing in context makes the knowledge update likelier to fail (see below). 17 | This suggests that the factual parametric knowledge of LLMs can negatively influence their reading abilities and behaviors. 18 | 19 | ![Example](assets/example.jpeg) 20 | 21 | We include a protocol for evaluating the susceptibility of a RAG system to the **parametric bias** in this repository. 22 | 23 | 24 | 25 | ## Getting started 26 | 27 | ### Clone this repo 28 | ```bash 29 | git clone https://github.com/kortukov/realistic_knowledge_conflicts 30 | cd realistic_knowledge_conflicts 31 | ``` 32 | 33 | ### Install dependencies 34 | ```bash 35 | conda create -n realistic_kc python=3.10 36 | conda activate realistic_kc 37 | pip install -r requirements.txt 38 | ``` 39 | 40 | ## Reproducing experiments 41 | 42 |
43 |

### 0. Download data

44 | 45 | #### Test data 46 | We download the [MrQA](https://huggingface.co/datasets/mrqa) validation split and use it as test data: 47 | NQ, SQuAD, NewsQA, TriviaQA, SearchQA, HotpotQA. 48 | ``` 49 | python 0_download_data.py --dataset-type test 50 | ``` 51 | 52 | #### ICL data 53 | In Stage 1 of our experimental pipeline we run the models closed-book. 54 | To ensure the best possible closed-book performance, we use ICL demonstrations. 55 | For ICL we use the train split of each dataset. 56 | We shuffle the original data and only save 10 examples. 57 | ``` 58 | python 0_download_data.py --dataset-type icl 59 | ``` 60 | 61 |
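Once both downloads finish, the splits land as parquet files under `data/test/` and `data/icl/`. A quick sanity check might look like the sketch below; the exact file name follows the MrQA subset name, so `SQuAD.parquet` is an assumption and should be adjusted to the split you want to inspect.

```python
# Minimal sanity check of the downloaded data (the file names are assumptions).
import pandas as pd

test_data = pd.read_parquet("data/test/SQuAD.parquet")
icl_data = pd.read_parquet("data/icl/SQuAD.parquet")

print(f"{len(test_data)} test examples, {len(icl_data)} ICL demonstrations")
print(test_data[["question", "context", "answers"]].head(1))
```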
62 | 63 |
64 |

### 1. Creating knowledge conflict dataset

65 | 66 | #### Stage 1: Closed-book answer gathering 67 | We run the closed-book experiments using configs in [config/cb](https://github.com/kortukov/realistic_knowledge_conflicts/tree/main/config/cb). 68 | 69 | ``` 70 | python 1_gather_cb_answers.py --config config/cb/llama7b/hotpotqa.conf 71 | ``` 72 | 73 | #### Stage 2: Filtering out no-conflict examples 74 | ``` 75 | python 2_filter_out_no_conflict.py --config config/filter/llama7b/hotpotqa.conf 76 | ``` 77 | 78 |
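Under the hood, Stage 2 loads the wrong closed-book examples produced in Stage 1 and keeps only those whose closed-book answer is not entailed by the retrieved context, using the NLI classifier in `src/evaluation/true.py`. A minimal sketch of that check on a single example (the parquet path is a placeholder; use the `wrong_examples_path` from your own config):

```python
# Sketch of the Stage 2 conflict check: an example contains a knowledge conflict
# only if the closed-book answer is NOT entailed by the retrieved context.
# The path below is a placeholder -- point it at the wrong_examples_path from your config.
import src.evaluation.true as true
import src.file_utils as file_utils

examples = file_utils.load_parquet_dataset("data/path-to-your/cb_wrong_examples.parquet")
nli_classifier = true.TrueNLIClassifier()

example = examples[0]
entailed = nli_classifier.infer_entailment(
    example["context"], example["question"], example["closedbook_answer"]
)
print("knowledge conflict" if not entailed else "no conflict (would be filtered out)")
```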
79 | 80 |
81 |

### 2. Studying knowledge updating behaviors under knowledge conflict

82 | 83 | #### Section 4.2 Studying knowledge updating behaviors under realistic knowledge conflicts 84 | In this experiment, we run Stage 3 of the pipeline. 85 | We run the open-book experiments using configs in [config/ob](https://github.com/kortukov/realistic_knowledge_conflicts/tree/main/config/ob). 86 | By default, the results are saved into results/{model_name}/ob_{dataset}.out. 87 | 88 | 89 | ``` 90 | python 3_run_ob_experiment.py --config config/ob/llama7b/hotpotqa.conf 91 | ``` 92 | 93 | Results reported in Table 3 can be found under the keys "Retain parametric", "Correct update", and "Incorrect update" 94 | in the output file. 95 | 96 | #### Section 4.3.1 Studying the differences between example categories 97 | Results reported in Figure 2 can be found in the output file under the keys "Overall CB in Context", 98 | "CB in Context Retain parametric", "CB in Context Correct update", and "CB in Context Incorrect update". 99 | 100 | #### Section 4.3.2 Influence of the parametric answer in context on knowledge update failures 101 | Results reported in Table 4 can be found in the output file by taking the following difference: 102 | 103 | (1 - "P(correct_update | cb_in_ctx)") - (1 - "P(correct_update | not cb_in_ctx)") 104 | 105 | = "P(correct_update | not cb_in_ctx)" - "P(correct_update | cb_in_ctx)" 106 | 107 | The p-values are reported under the key "P-val CU". 108 | 109 |
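Since the `.out` files are written as plain JSON, these numbers can also be read programmatically. A minimal sketch (the path is an assumption following the default `results/{model_name}/ob_{dataset}.out` pattern; adjust it to your model and dataset):

```python
# Read an open-book results file and print the Table 3 / Table 4 quantities.
# The path is an assumption -- substitute your own model name and dataset.
import json

with open("results/meta-llama/Llama-2-7b-chat-hf/ob_HotpotQA.out") as f:
    results = json.load(f)

# Table 3: knowledge-updating behaviors
for key in ["Correct update", "Retain parametric", "Incorrect update"]:
    print(key, results[key])

# Table 4: drop in correct updates when the parametric answer appears in context
difference = results["P(correct_update | not cb_in_ctx)"] - results["P(correct_update | cb_in_ctx)"]
print("Difference:", difference, "p-value:", results["P-val CU"])
```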

### 3. Intervention experiments

#### Section 4.4.1 Masking reduces the likelihood of retaining parametric answer
We run the masking experiments using configs in [config/mask](https://github.com/kortukov/realistic_knowledge_conflicts/tree/main/config/mask).

The results are saved into results/{model_name}/mask_{dataset}.out.
```
python 3_run_ob_experiment.py --config config/mask/llama7b/hotpotqa.conf
```

#### Section 4.4.2 Adding the parametric answer to the context increases the likelihood of retaining it
We run the experiments that add the incorrect parametric answer to the context using configs in [config/add](https://github.com/kortukov/realistic_knowledge_conflicts/tree/main/config/add).

The results are saved into results/{model_name}/add_{dataset}.out.
```
python 3_run_ob_experiment.py --config config/add/llama7b/hotpotqa.conf
```
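
Conceptually, the two interventions edit the context around the model's incorrect closed-book answer. The sketch below is only an illustration of the idea; how the context is actually edited is defined in the repository code (src/openbook_exp.py and the prompt helpers), and the "[MASK]" token and the appended sentence are assumptions.

```python
# Conceptual illustration only, not the repository's implementation.
def mask_parametric_answer(context: str, parametric_answer: str) -> str:
    """Masking: remove surface forms of the incorrect parametric answer from the context."""
    return context.replace(parametric_answer, "[MASK]")

def add_parametric_answer(context: str, parametric_answer: str) -> str:
    """Adding: introduce the incorrect parametric answer into the context."""
    return f"{context} {parametric_answer}."

ctx = "The company moved its headquarters to Berlin in 2021."
print(mask_parametric_answer(ctx, "Berlin"))  # parametric answer no longer visible
print(add_parametric_answer(ctx, "Munich"))   # incorrect parametric answer now present
```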

### 4. ICL task adaptation

#### Appendix E Task adaptation using in-context learning
In this experiment, we test whether in-context demonstrations can minimize the influence of the discovered parametric bias.

We run the ICL experiments using configs in [config/icl](https://github.com/kortukov/realistic_knowledge_conflicts/tree/main/config/icl).

The results are saved into results/{model_name}/icl_{dataset}.out.
```
python 3_run_ob_experiment.py --config config/icl/llama7b/hotpotqa.conf
```
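
For intuition, the ICL demonstrations downloaded in step 0 are prepended to the open-book prompt, roughly as in the sketch below. The field names and the template are assumptions; the actual templates live in src/evaluation/prompt_helpers.py and may differ.

```python
# Rough sketch of ICL prompt construction; template and field names are assumptions.
def build_icl_prompt(demonstrations: list[dict], context: str, question: str) -> str:
    demo_block = "\n\n".join(
        f"Context: {d['context']}\nQuestion: {d['question']}\nAnswer: {d['answer']}"
        for d in demonstrations
    )
    return f"{demo_block}\n\nContext: {context}\nQuestion: {question}\nAnswer:"
```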

### 5. FreshQA experiment

#### Appendix G Parametric answer is likely to appear in real-world documents
In this experiment, we move closer to a realistic RAG knowledge updating scenario and check how often the incorrect parametric
answer of a model appears in real-world retrieved documents. To that end, we run models on the FreshQA dataset.
It contains questions whose answers change over time. Updated ground-truth answers are supplied together with web
documents containing them.

First, we download the FreshQA data for Feb 26, 2024 (as in the paper).

```
python 4_download_freshqa.py
```

Then we obtain the parametric (outdated) answers of the model by running the closed-book experiment.

We use configs in [config/freshqa](https://github.com/kortukov/realistic_knowledge_conflicts/tree/main/config/freshqa).
```
python 1_gather_cb_answers.py --config config/freshqa/llama7b.conf
```

The results are saved into results/{model_name}/add_{dataset}.out.
Values reported in Table 15 can be found under the keys "Parametric answer in context" and "Incorrect out of parametric in context".
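
The quantity of interest is how often the model's outdated parametric answer shows up verbatim in the retrieved web document. A minimal sketch of that check, with assumed column names ("cb_answer", "context") and an assumed path:

```python
# Illustration of the Appendix G check; column names and the path are assumptions.
import pandas as pd

df = pd.read_parquet("data/freshqa/cb_answers.parquet")  # closed-book answers with contexts (assumed path)
in_ctx = df.apply(
    lambda row: row["cb_answer"].strip().lower() in row["context"].lower(), axis=1
)
print(f"Parametric answer in context: {in_ctx.mean():.1%}")
```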

## Parametric bias susceptibility evaluation

You can use the code in this repository to test your own task-specific retrieval-augmented LLM for parametric bias susceptibility.

To achieve that, we make use of the Huggingface Hub.

### Data and model preparation

#### Prepare the dataset
First, you will need to [upload your dataset to the Huggingface hub](https://huggingface.co/docs/hub/en/datasets-adding) in the correct format.
To be compatible with our evaluation, it should have "question", "context", and "answers" fields.

Formulate your downstream task in the QA format and supply your retrieved documents in the "context" field.

#### Prepare the model
As with the data, choose a model from the hub or [upload your custom model to the Huggingface hub](https://huggingface.co/docs/hub/en/models-uploading).

#### Prepare the config file
In all config files in [config/custom](https://github.com/kortukov/realistic_knowledge_conflicts/tree/main/config/custom), replace the lines
```
model_name: "your_model_name"
```
and
```
dataset: "your_dataset_name"
```
with the hub identifiers of your own model and dataset.
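
As an illustration of the expected dataset format, here is a minimal sketch that pushes a toy QA dataset with the three required fields to the hub. The repository id is a placeholder, and representing "answers" as a list of acceptable answer strings is an assumption.

```python
# Toy example of the expected format; "your-org/your-qa-dataset" is a placeholder
# and the list-of-strings "answers" representation is an assumption.
from datasets import Dataset

examples = {
    "question": ["Who is the current CEO of ExampleCorp?"],
    "context": ["The retrieved document for this question goes here ..."],
    "answers": [["Jane Doe"]],
}
Dataset.from_dict(examples).push_to_hub("your-org/your-qa-dataset")
```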

### Evaluation protocol

The protocol is based on the intervention experiments in the paper.

#### Download the dataset
```
python 0_download_data.py --dataset-type custom --custom-dataset-name <your_dataset_name>
```

#### Gather closed-book answers of your model
```
python 1_gather_cb_answers.py --config config/custom/cb.conf
```

#### Filter out no-conflict examples
```
python 2_filter_out_no_conflict.py --config config/custom/filter.conf
```

#### Evaluate the model open-book on your task
```
python 3_run_ob_experiment.py --config config/custom/ob.conf
```

#### Introduce the incorrect parametric answer into the context
```
python 3_run_ob_experiment.py --config config/custom/add.conf
```

#### Interpret the results
To see how susceptible your model is to the parametric bias, we compare the results before and after adding the incorrect parametric answer to the context.
We compare the fields "Retain parametric", "Correct update", and "Incorrect update" in the files
results/{your_model_name}/ob_{your_dataset_name}.out and results/{your_model_name}/add_{your_dataset_name}.out.

If adding the incorrect answer to the context increases the prevalence of the "Retain parametric" class, your model is susceptible to the **parametric bias**.
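
A small sketch of this comparison. The numbers are placeholders to be read off results/{model}/ob_{dataset}.out and results/{model}/add_{dataset}.out; the output-file format itself is not assumed here.

```python
# Placeholder values for illustration; read the real numbers from the two output files.
ob = {"Retain parametric": 0.12, "Correct update": 0.80, "Incorrect update": 0.08}
add = {"Retain parametric": 0.25, "Correct update": 0.68, "Incorrect update": 0.07}

delta_retain = add["Retain parametric"] - ob["Retain parametric"]
if delta_retain > 0:
    print(f"'Retain parametric' rose by {delta_retain:.1%}: the model shows parametric bias.")
else:
    print("No increase in 'Retain parametric': no evidence of parametric bias.")
```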

## Citing

If you use this repository, consider citing our paper:

```bibtex
@inproceedings{kortukov2024studying,
  title={Studying Large Language Model Behaviors Under Context-Memory Conflicts With Real Documents},
  author={Evgenii Kortukov and Alexander Rubinstein and Elisa Nguyen and Seong Joon Oh},
  booktitle={First Conference on Language Modeling},
  year={2024},
  url={https://openreview.net/forum?id=xm8zYRfrqE}
}
```