├── .gitignore ├── .gitmodules ├── README.md ├── data_v2 ├── gen_dpr_hc_jsonl.py ├── input_data │ └── DPR │ │ └── sampled_query │ │ ├── nq-test-sample-200.jsonl │ │ ├── pop-test-sample-200.jsonl │ │ ├── tqa-test-sample-200.jsonl │ │ └── webq-test-sample-200.jsonl ├── process_raw_to_input_format.py └── update_data │ └── DPR │ └── update_passage_processed │ ├── merged_file │ └── merged.jsonl │ ├── nq-test-gen-qwen-0.5b-chat-postprocessed.jsonl │ ├── nq-test-gen-qwen-1.8b-chat-postprocessed.jsonl │ ├── nq-test-gen-qwen-4b-chat-postprocessed.jsonl │ ├── pop-test-gen-qwen-0.5b-chat-postprocessed.jsonl │ ├── pop-test-gen-qwen-1.8b-chat-postprocessed.jsonl │ ├── pop-test-gen-qwen-4b-chat-postprocessed.jsonl │ ├── tqa-test-gen-qwen-0.5b-chat-postprocessed.jsonl │ ├── tqa-test-gen-qwen-1.8b-chat-postprocessed.jsonl │ ├── tqa-test-gen-qwen-4b-chat-postprocessed.jsonl │ ├── webq-test-gen-qwen-0.5b-chat-postprocessed.jsonl │ ├── webq-test-gen-qwen-1.8b-chat-postprocessed.jsonl │ └── webq-test-gen-qwen-4b-chat-postprocessed.jsonl ├── get_response.py ├── pipeline.png ├── requirements.txt └── src ├── evaluation ├── compute_avg_rank.py ├── compute_mis_em.py ├── compute_pearson.py ├── compute_self_bleu.py ├── compute_sig.py ├── draw_bleu.py ├── draw_change_index.py ├── draw_context.py ├── draw_filter_retrieval.py ├── draw_map.py ├── draw_percentage.py ├── draw_qa.py ├── draw_rank.py ├── eva_configs │ ├── bleu_eva_config.json │ ├── context_answer_config.json │ ├── percentage_eva_config.json │ ├── qa_eva_config.json │ ├── rank_eva_config.json │ └── retrieval_eva_config.json ├── eva_context_answer.py ├── eva_domi.py ├── eva_em_llm_judge.py ├── eva_generate_em.py ├── eva_percentage.py ├── eva_pipe.py ├── eva_rank.py ├── eva_retrieval_acc.py ├── eva_retrieval_trec.py ├── run_context_eva.sh ├── sum_QA.py ├── sum_bleu.py ├── sum_change_index.py ├── sum_context_tsv.py ├── sum_percentage_tsv.py ├── sum_retrieval.py └── trec_eval.py ├── filtering ├── compute_source.py ├── filter_configs │ ├── bleu_filter_config.json │ └── source_filter_config.json └── filter_for_loop.py ├── llm_zero_generate ├── eva_generate.py ├── get_response_llm.py ├── rag_configs │ ├── baichuan2-13b-chat-config-nq.json │ ├── baichuan2-13b-chat-config-pop.json │ ├── baichuan2-13b-chat-config-tqa.json │ ├── baichuan2-13b-chat-config-webq.json │ ├── chatglm3-6b-config-nq.json │ ├── chatglm3-6b-config-pop.json │ ├── chatglm3-6b-config-tqa.json │ ├── chatglm3-6b-config-webq.json │ ├── gpt-3.5-turbo-config-nq.json │ ├── gpt-3.5-turbo-config-pop.json │ ├── gpt-3.5-turbo-config-tqa.json │ ├── gpt-3.5-turbo-config-webq.json │ ├── llama2-13b-chat-config-nq.json │ ├── llama2-13b-chat-config-pop.json │ ├── llama2-13b-chat-config-tqa.json │ ├── llama2-13b-chat-config-webq.json │ ├── qwen-14b-chat-config-nq.json │ ├── qwen-14b-chat-config-pop.json │ ├── qwen-14b-chat-config-tqa.json │ └── qwen-14b-chat-config-webq.json ├── run_generate.sh ├── update_configs │ ├── rag_configs │ │ ├── baichuan2-13b-chat-config-nq.json │ │ ├── baichuan2-13b-chat-config-pop.json │ │ ├── baichuan2-13b-chat-config-tqa.json │ │ ├── baichuan2-13b-chat-config-webq.json │ │ ├── baichuan2-7b-chat-config-nq.json │ │ ├── baichuan2-7b-chat-config-pop.json │ │ ├── baichuan2-7b-chat-config-tqa.json │ │ ├── baichuan2-7b-chat-config-webq.json │ │ ├── gpt-3.5-turbo-config-nq.json │ │ ├── gpt-3.5-turbo-config-pop.json │ │ ├── gpt-3.5-turbo-config-tqa.json │ │ ├── gpt-3.5-turbo-config-webq.json │ │ ├── llama2-13b-chat-config-nq.json │ │ ├── llama2-13b-chat-config-pop.json │ │ ├── 
llama2-13b-chat-config-tqa.json │ │ ├── llama2-13b-chat-config-webq.json │ │ ├── llama2-7b-chat-config-nq.json │ │ ├── llama2-7b-chat-config-pop.json │ │ ├── llama2-7b-chat-config-tqa.json │ │ ├── llama2-7b-chat-config-webq.json │ │ ├── qwen-0.5b-chat-config-nq.json │ │ ├── qwen-0.5b-chat-config-pop.json │ │ ├── qwen-0.5b-chat-config-tqa.json │ │ ├── qwen-0.5b-chat-config-webq.json │ │ ├── qwen-1.8b-chat-config-nq.json │ │ ├── qwen-1.8b-chat-config-pop.json │ │ ├── qwen-1.8b-chat-config-tqa.json │ │ ├── qwen-1.8b-chat-config-webq.json │ │ ├── qwen-14b-chat-config-nq.json │ │ ├── qwen-14b-chat-config-pop.json │ │ ├── qwen-14b-chat-config-tqa.json │ │ ├── qwen-14b-chat-config-webq.json │ │ ├── qwen-4b-chat-config-nq.json │ │ ├── qwen-4b-chat-config-pop.json │ │ ├── qwen-4b-chat-config-tqa.json │ │ ├── qwen-4b-chat-config-webq.json │ │ ├── qwen-7b-chat-config-nq.json │ │ ├── qwen-7b-chat-config-pop.json │ │ ├── qwen-7b-chat-config-tqa.json │ │ ├── qwen-7b-chat-config-webq.json │ │ └── temp │ │ │ ├── baichuan2-7b-chat-config-nq.json │ │ │ ├── baichuan2-7b-chat-config-pop.json │ │ │ ├── baichuan2-7b-chat-config-tqa.json │ │ │ ├── baichuan2-7b-chat-config-webq.json │ │ │ ├── llama2-7b-chat-config-nq.json │ │ │ ├── llama2-7b-chat-config-pop.json │ │ │ ├── llama2-7b-chat-config-tqa.json │ │ │ └── llama2-7b-chat-config-webq.json │ └── zero-shot_configs │ │ ├── llama2-7b-chat-config-nq.json │ │ ├── llama2-7b-chat-config-pop.json │ │ ├── llama2-7b-chat-config-tqa.json │ │ ├── llama2-7b-chat-config-webq.json │ │ ├── qwen-0.5b-chat-config-nq.json │ │ ├── qwen-0.5b-chat-config-pop.json │ │ ├── qwen-0.5b-chat-config-tqa.json │ │ ├── qwen-0.5b-chat-config-webq.json │ │ ├── qwen-1.8b-chat-config-nq.json │ │ ├── qwen-1.8b-chat-config-pop.json │ │ ├── qwen-1.8b-chat-config-tqa.json │ │ ├── qwen-1.8b-chat-config-webq.json │ │ ├── qwen-4b-chat-config-nq.json │ │ ├── qwen-4b-chat-config-pop.json │ │ ├── qwen-4b-chat-config-tqa.json │ │ └── qwen-4b-chat-config-webq.json └── zero-shot_configs │ ├── baichuan2-13b-chat-config-nq.json │ ├── baichuan2-13b-chat-config-pop.json │ ├── baichuan2-13b-chat-config-trivia.json │ ├── baichuan2-13b-chat-config-wq.json │ ├── chatglm3-6b-config-nq.json │ ├── chatglm3-6b-config-pop.json │ ├── chatglm3-6b-config-trivia.json │ ├── chatglm3-6b-config-wq.json │ ├── gpt-3.5-turbo-config-nq.json │ ├── gpt-3.5-turbo-config-pop.json │ ├── gpt-3.5-turbo-config-trivia.json │ ├── gpt-3.5-turbo-config-wq.json │ ├── llama2-13b-chat-config-nq.json │ ├── llama2-13b-chat-config-pop.json │ ├── llama2-13b-chat-config-trivia.json │ ├── llama2-13b-chat-config-wq.json │ ├── qwen-14b-chat-config-nq.json │ ├── qwen-14b-chat-config-pop.json │ ├── qwen-14b-chat-config-trivia.json │ └── qwen-14b-chat-config-wq.json ├── misinfo ├── eva_mis.py ├── gen_misinfo_llm.py ├── mis_config │ ├── mis_config_qwen.json │ ├── nq_mis_config_answer_gpt.json │ ├── nq_mis_config_passage_baichuan.json │ ├── nq_mis_config_passage_chatglm.json │ ├── nq_mis_config_passage_gpt.json │ ├── nq_mis_config_passage_llama.json │ ├── nq_mis_config_passage_qwen.json │ ├── pop_mis_config_answer_gpt.json │ ├── pop_mis_config_passage_baichuan.json │ ├── pop_mis_config_passage_chatglm.json │ ├── pop_mis_config_passage_gpt.json │ ├── pop_mis_config_passage_llama.json │ ├── pop_mis_config_passage_qwen.json │ ├── tqa_mis_config_answer_gpt.json │ ├── tqa_mis_config_passage_baichuan.json │ ├── tqa_mis_config_passage_chatglm.json │ ├── tqa_mis_config_passage_gpt.json │ ├── tqa_mis_config_passage_llama.json │ ├── tqa_mis_config_passage_qwen.json │ 
├── webq_mis_config_answer_gpt.json │ ├── webq_mis_config_passage_baichuan.json │ ├── webq_mis_config_passage_chatglm.json │ ├── webq_mis_config_passage_gpt.json │ ├── webq_mis_config_passage_llama.json │ └── webq_mis_config_passage_qwen.json └── run_gen_misinfo.sh ├── post_process ├── delete_configs │ ├── delete_config.json │ └── delete_config_bm25.json ├── delete_doc_from_index.py ├── post_process.sh ├── process_configs │ └── template_config.json ├── process_llm_text.py └── run_index_doc_delete.sh ├── rerank_loop ├── monot5_support.py ├── rankgpt_prompter.py ├── rankgpt_support.py ├── rerank_configs │ ├── bge-config-nq.json │ ├── bge-config-pop.json │ ├── bge-config-tqa.json │ ├── bge-config-webq.json │ ├── monot5-config-nq.json │ ├── monot5-config-pop.json │ ├── monot5-config-tqa.json │ ├── monot5-config-webq.json │ ├── rankgpt-config-nq.json │ ├── rankgpt-config-pop.json │ ├── rankgpt-config-tqa.json │ ├── rankgpt-config-webq.json │ ├── upr-config-nq.json │ ├── upr-config-pop.json │ ├── upr-config-tqa.json │ └── upr-config-webq.json ├── rerank_for_loop.py ├── run_upr.sh └── templates │ ├── LLMreranker.json │ └── PRFreranker.json ├── retrieval_loop ├── elastic_bm25_search_with_metadata.py ├── embedding_index_incremental_corpus.py ├── evaluate_dpr_retrieval.py ├── faiss_search.py ├── index_configs │ ├── all-mpnet-config-psgs_w100.json │ ├── bge-base-config-psgs_w100.json │ ├── bge-large-config-psgs_w100.json │ ├── bm25-config-psgs_w100.json │ ├── contriever-config-psgs_w100.json │ ├── dpr-config-psgs_w100.json │ ├── llm-embedder-config-psgs_w100.json │ ├── retromae-config-psgs_w100.json │ └── tmp │ │ └── bge-config-psgs_w100.json ├── retrieve_configs │ ├── all-mpnet-config-nq.json │ ├── bge-base-config-nq.json │ ├── bge-large-config-nq.json │ ├── bm25-config-nq.json │ ├── contriever-config-nq.json │ ├── dpr-config-nq.json │ ├── llm-embedder-config-nq.json │ └── retromae-config-nq.json ├── retrieve_methods.py ├── run_index_builder.sh └── run_retrieval.sh ├── rewrite_configs.py ├── run_loop.sh ├── run_zero-shot.sh └── test_function └── test_configs ├── indexing_config.json ├── post_process_config.json ├── retrieval_config.json └── template_total_config.json /.gitignore: -------------------------------------------------------------------------------- 1 | # My Folder # 2 | ############# 3 | indexes/ 4 | input_data/ 5 | raw_data/ 6 | ret_output/ 7 | zero_gen_data/ 8 | incxt_gen_data/ 9 | post_processed_data/ 10 | test_output/ 11 | tmp/ 12 | loop_output/ 13 | misinfo_data/ 14 | elasticsearch-8.11.1/ 15 | flash-attention/ 16 | llama-2-13b-chat-hf/ 17 | llm_logs/ 18 | ret_model/ 19 | logs/ 20 | run_logs/ 21 | run_configs/ 22 | api_new/ 23 | # any folder name started with api-for-open-llm 24 | api-for-open-llm/ 25 | api-for-open-llm-2/ 26 | api-for-open-llm-chatglm/ 27 | api-for-open-llm-qwen 28 | LLM-Rob/ 29 | # Compiled source # 30 | ################### 31 | *.com 32 | *.class 33 | *.dll 34 | *.exe 35 | *.o 36 | *.so 37 | *.pyc 38 | *.sha512 39 | *.swp 40 | *.tsv 41 | # Packages # 42 | ############ 43 | # it's better to unpack these files and commit the raw source 44 | # git has its own built in compression methods 45 | *.7z 46 | *.dmg 47 | *.gz 48 | *.iso 49 | *.jar 50 | *.rar 51 | *.tar 52 | *.zip 53 | *.out 54 | 55 | # Logs and databases # 56 | ###################### 57 | *.log 58 | *.sql 59 | *.sqlite 60 | *.faiss 61 | *.pkl 62 | *.png 63 | !pipeline.png 64 | # OS generated files # 65 | ###################### 66 | .DS_Store 67 | .DS_Store? 
68 | ._* 69 | .Spotlight-V100 70 | .Trashes 71 | ehthumbs.db 72 | Thumbs.db 73 | clean_requirements.txt 74 | user_requirements.txt -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "externals/UPR"] 2 | path = externals/UPR 3 | url = git@github.com:VerdureChen/unsupervised-passage-reranking.git 4 | [submodule "externals/MonoT5"] 5 | path = externals/MonoT5 6 | url = git@github.com:VerdureChen/pygaggle.git 7 | [submodule "externals/popular-fonts"] 8 | path = externals/popular-fonts 9 | url = https://github.com/chengda/popular-fonts.git 10 | -------------------------------------------------------------------------------- /data_v2/gen_dpr_hc_jsonl.py: -------------------------------------------------------------------------------- 1 | #convert psgs_w100.tsv to jsonl format 2 | #add generated docs meanwhile 3 | 4 | import json 5 | from tqdm import tqdm 6 | import ast 7 | import os 8 | 9 | def convert_psgs_to_jsonl(psgs_file, output_file): 10 | with open(psgs_file, 'r', encoding='utf-8') as f, open(output_file, 'w', encoding='utf-8') as out: 11 | for i, line in tqdm(enumerate(f), desc='Converting psgs to jsonl', total=21015325): 12 | if i == 0: 13 | continue 14 | line = line.strip().split('\t') 15 | doc_id = line[0] 16 | text = line[1][1:-1] 17 | title = line[2] 18 | new_text = title + '\n' + text 19 | doc = {'id': doc_id, 'contents': new_text} 20 | # force_ascii = False 21 | out.write(json.dumps(doc, ensure_ascii=False) + '\n') 22 | 23 | 24 | if __name__ == '__main__': 25 | psgs_file = 'raw_data/DPR/psgs_w100.tsv' 26 | output_file = 'input_data/DPR/psgs_w100.jsonl' 27 | if not os.path.exists('input_data/DPR'): 28 | os.makedirs('input_data/DPR') 29 | convert_psgs_to_jsonl(psgs_file, output_file) 30 | -------------------------------------------------------------------------------- /data_v2/process_raw_to_input_format.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | 3 | #process raw data to input format 4 | raw_data_paths = [ 5 | 'raw_data/DPR/nq-test.jsonl', 6 | 'raw_data/DPR/tqa-test.jsonl', 7 | 'raw_data/DPR/webq-test.jsonl', 8 | 'raw_data/DPR/pop-test.jsonl', 9 | ] 10 | 11 | def process_raw_data(raw_data_path): 12 | output_path = raw_data_path.replace('raw_data', 'input_data') 13 | raw_data = datasets.load_dataset('json', data_files=raw_data_path)['train'] 14 | def reassign_id_with_idx(example, idx): 15 | example['id'] = str(idx) 16 | if 'query' in example: 17 | # keep the special tokens like è 18 | example['question'] = example['query'] 19 | if 'possible_answers' in example: 20 | example['answer'] = example['possible_answers'] 21 | return example 22 | raw_data = raw_data.map(reassign_id_with_idx, with_indices=True) 23 | # keep id, question, answer columns 24 | raw_data = raw_data.select_columns(['id', 'question', 'answer']) 25 | # keep the special tokens like è 26 | raw_data.to_json(output_path, force_ascii=False) 27 | 28 | 29 | if __name__ == '__main__': 30 | for raw_data_path in raw_data_paths: 31 | process_raw_data(raw_data_path) 32 | -------------------------------------------------------------------------------- /get_response.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import os 3 | openai.api_base = "xxx" 4 | openai.api_key = "xxx" 5 | # 6 | # os.environ["OPENAI_API_BASE"] = "http://124.16.138.144:7001/v1" 7 | # 
os.environ["OPENAI_API_KEY"] = "xxx" 8 | 9 | 10 | from langchain.chat_models import ChatOpenAI 11 | from langchain.schema import HumanMessage 12 | 13 | # chat = ChatOpenAI() 14 | 15 | 16 | questions = ["who starred in the movie summer of 42", 17 | "who was the first lady nominated member of the rajya sabha?", 18 | "how many episodes are there in dragon ball z?", 19 | "who designed the garden city of new earswick?", 20 | "who is the current director of the us mint?"] 21 | 22 | 23 | for question in questions: 24 | print(question) 25 | # print(chat([HumanMessage(content=f'Please tell me what you know from Wikipedia to answer the given question: {question}')])) 26 | completion = openai.ChatCompletion.create( 27 | model="llama", 28 | messages=[ 29 | {"role": "user", 30 | "content": f"Provide a background document in 100 words according to your knowledge from Wikipedia to answer the given question. Don't output any Chinese."\ 31 | f"\n\n Question:{question} \n\n Background Document:"}, 32 | ], 33 | stream=False, 34 | max_tokens=20, 35 | temperature=0.7, 36 | # stop=["Sure"], 37 | logprobs=True, 38 | top_logprobs=3, 39 | ) 40 | 41 | while len(completion.choices[0].message.content.strip().split(" ")) < 10: 42 | completion = openai.ChatCompletion.create( 43 | model="llama", 44 | messages=[ 45 | {"role": "user", 46 | "content": f"Provide a background document in 100 words according to your knowledge from Wikipedia to answer the given question. Don't output any Chinese."\ 47 | f"\n\n Question:{question} \n\n Background Document:"}, 48 | ], 49 | stream=False, 50 | max_tokens=20, 51 | temperature=0.7, 52 | # stop=["Sure"], 53 | logprobs=True, 54 | top_logprobs=3, 55 | ) 56 | 57 | print(completion) 58 | print('\n\n') 59 | 60 | # openai.Completion.create(prompt=f"Provide a background document in 100 words according to your knowledge to answer the given question." 
\ 61 | # f"\n\n Question:{question} \n\n Background Document:", 62 | # model="llama", 63 | # max_tokens=128, 64 | # temperature=0.7) 65 | # 66 | # print(completion.choices[0].text) 67 | # print('\n\n') 68 | -------------------------------------------------------------------------------- /pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VerdureChen/SOS-Retrieval-Loop/56a19847edfc30d1e6e59894d348e1c17ae15114/pipeline.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --find-links https://download.pytorch.org/whl/torch_stable.html 2 | accelerate==0.24.1 3 | aiohttp==3.8.6 4 | aiosignal==1.3.1 5 | anyio==3.7.1 6 | async-timeout==4.0.3 7 | attrs==23.1.0 8 | bitsandbytes==0.41.1 9 | certifi==2023.7.22 10 | charset-normalizer==3.3.2 11 | click==8.1.7 12 | cmake==3.27.7 13 | cpm-kernels==1.0.11 14 | dataclasses-json==0.6.1 15 | dill==0.3.7 16 | einops==0.7.0 17 | exceptiongroup==1.1.3 18 | fastapi==0.95.1 19 | filelock==3.13.1 20 | frozenlist==1.4.0 21 | fsspec==2023.10.0 22 | greenlet==3.0.1 23 | h11==0.14.0 24 | httptools==0.6.1 25 | idna==3.4 26 | Jinja2==3.1.2 27 | joblib==1.3.2 28 | jsonpatch==1.33 29 | jsonpointer==2.4 30 | jsonschema==4.19.2 31 | jsonschema-specifications==2023.7.1 32 | langchain==0.0.331 33 | langsmith==0.0.60 34 | lit==17.0.4 35 | loguru==0.7.2 36 | MarkupSafe==2.1.3 37 | marshmallow==3.20.1 38 | mpmath==1.3.0 39 | msgpack==1.0.7 40 | multidict==6.0.4 41 | multiprocess==0.70.15 42 | mypy-extensions==1.0.0 43 | networkx==3.2.1 44 | ninja==1.11.1.1 45 | nltk==3.8.1 46 | numpy==1.26.1 47 | nvidia-cublas-cu11==11.10.3.66 48 | nvidia-cublas-cu12==12.1.3.1 49 | nvidia-cuda-cupti-cu11==11.7.101 50 | nvidia-cuda-cupti-cu12==12.1.105 51 | nvidia-cuda-nvrtc-cu11==11.7.99 52 | nvidia-cuda-nvrtc-cu12==12.1.105 53 | nvidia-cuda-runtime-cu11==11.7.99 54 | nvidia-cuda-runtime-cu12==12.1.105 55 | nvidia-cudnn-cu11==8.5.0.96 56 | nvidia-cudnn-cu12==8.9.2.26 57 | nvidia-cufft-cu11==10.9.0.58 58 | nvidia-cufft-cu12==11.0.2.54 59 | nvidia-curand-cu11==10.2.10.91 60 | nvidia-curand-cu12==10.3.2.106 61 | nvidia-cusolver-cu11==11.4.0.1 62 | nvidia-cusolver-cu12==11.4.5.107 63 | nvidia-cusparse-cu11==11.7.4.91 64 | nvidia-cusparse-cu12==12.1.0.106 65 | nvidia-nccl-cu11==2.14.3 66 | nvidia-nccl-cu12==2.18.1 67 | nvidia-nvjitlink-cu12==12.3.52 68 | nvidia-nvtx-cu11==11.7.91 69 | nvidia-nvtx-cu12==12.1.105 70 | openai==0.28.1 71 | packaging==23.2 72 | pandas==2.1.2 73 | peft==0.6.0 74 | Pillow==10.1.0 75 | protobuf==4.25.0 76 | psutil==5.9.6 77 | pyarrow==14.0.0 78 | pydantic==1.10.13 79 | python-dateutil==2.8.2 80 | python-dotenv==1.0.0 81 | pytz==2023.3.post1 82 | PyYAML==6.0.1 83 | ray==2.8.0 84 | referencing==0.30.2 85 | regex==2023.10.3 86 | requests==2.31.0 87 | rpds-py==0.12.0 88 | safetensors==0.4.0 89 | scikit-learn==1.3.2 90 | scipy==1.11.3 91 | sentence-transformers==2.2.2 92 | sentencepiece==0.1.99 93 | six==1.16.0 94 | sniffio==1.3.0 95 | SQLAlchemy==2.0.23 96 | starlette==0.26.1 97 | sympy==1.12 98 | tenacity==8.2.3 99 | threadpoolctl==3.2.0 100 | tiktoken==0.5.1 101 | tokenizers==0.13.3 102 | torch==2.1.0+cu118 103 | torchaudio==2.1.0+cu118 104 | torchvision==0.16.0+cu118 105 | tqdm==4.66.1 106 | transformers==4.33.2 107 | transformers-stream-generator==0.0.4 108 | triton==2.1.0 109 | typing-inspect==0.9.0 110 | typing_extensions==4.4.0 111 | tzdata==2023.3 112 | 
urllib3==2.0.7
113 | uvicorn==0.24.0.post1
114 | uvloop==0.19.0
115 | vllm==0.2.0
116 | watchfiles==0.21.0
117 | websockets==12.0
118 | xformers==0.0.22.post7
119 | xxhash==3.4.1
120 | yarl==1.9.2
121 | coloredlogs==15.0.1
122 | datasets==2.16.1
123 | elasticsearch==8.12.0
124 | elastic-transport==8.12.0
125 | faiss-gpu==1.7.2
126 | fast-bleu==0.0.90
127 | flatbuffers==23.5.26
128 | huggingface-hub==0.20.3
129 | humanfriendly==10.0
130 | lightgbm==4.3.0
131 | nmslib==2.1.1
132 | onnxruntime==1.16.3
133 | pyarrow-hotfix==0.6
134 | pybind11==2.6.1
135 | pyjnius==1.6.1
136 | pyserini==0.22.1
--------------------------------------------------------------------------------
/src/evaluation/compute_mis_em.py:
--------------------------------------------------------------------------------
1 | # compute EM of generated answers against LLM-generated misinformation ("false") answers
2 | import os
3 | import json
4 | import re
5 | import datasets
6 | import string
7 | 
8 | 
9 | def evaluate(predictions, mis_answer_dict):
10 |     # evaluate the predictions with exact match
11 |     def _normalize_answer(s):
12 |         def remove_articles(text):
13 |             return re.sub(r"\b(a|an|the)\b", " ", text)
14 | 
15 |         def white_space_fix(text):
16 |             return " ".join(text.split())
17 | 
18 |         def remove_punc(text):
19 |             exclude = set(string.punctuation)
20 |             return "".join(ch for ch in text if ch not in exclude)
21 | 
22 |         def lower(text):
23 |             return text.lower()
24 | 
25 |         return white_space_fix(remove_articles(remove_punc(lower(s))))
26 | 
27 |     def exact_match_score(example, mis_answer_dict):
28 |         question = example['question']
29 |         ground_truths = mis_answer_dict[question]
30 |         assert type(ground_truths) == list, f'ground_truths is not a list, id:{example["id"]}, ground_truth:{ground_truths}'
31 |         prediction = example['response']
32 |         example['exact_match'] = 0
33 |         if not prediction:
34 |             print(f'no prediction for qid {example["qid"]}, {example["query"]}')
35 |             return example
36 |         for ground_truth in ground_truths:
37 |             if _normalize_answer(ground_truth) in _normalize_answer(prediction):
38 |                 example['exact_match'] = 1
39 |                 break
40 |         return example
41 | 
42 |     fn_args = {'mis_answer_dict': mis_answer_dict}
43 |     predictions = predictions.map(exact_match_score, fn_kwargs=fn_args)
44 |     return predictions
45 | 
46 | 
47 | def get_mis_answer(retrieval_file_path):
48 |     # NOTE: the argument is currently ignored; the merged misinformation file path is hard-coded below.
49 |     # mis_parent_path = os.path.dirname(os.path.dirname(os.path.dirname(retrieval_file_path)))
50 |     # mis_answer_path = os.path.join(mis_parent_path, 'misinfo_data', 'mis_passage_processed', 'merged_file', 'merged.jsonl')
51 |     mis_answer_dict = {}
52 |     mis_answer_path = '../../data_v2/misinfo_data/DPR/mis_passage_processed/merged_file/merged.jsonl'
53 |     with open(mis_answer_path, 'r') as f:
54 |         for line in f:
55 |             data = json.loads(line)
56 |             mis_answer_dict[data['question']] = data['false_answer']
57 |     assert len(mis_answer_dict) == 800
58 |     return mis_answer_dict
59 | 
60 | 
61 | def calculate_mis_em(retrieval_file, mis_answer_path):
62 |     mis_answer_dict = get_mis_answer(mis_answer_path)
63 |     dataset = datasets.load_dataset("json", data_files=retrieval_file)["train"]
64 |     print('evaluate QA EM...')
65 |     prediction = evaluate(dataset, mis_answer_dict)
66 |     EM = sum(prediction['exact_match']) / len(prediction['exact_match'])
67 |     return EM
68 | 
--------------------------------------------------------------------------------
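A minimal usage sketch for calculate_mis_em above; the loop-output path is illustrative, and the second argument is currently ignored in favour of the hard-coded merged.jsonl path:

from compute_mis_em import calculate_mis_em

# illustrative path -- point this at a generated-answer JSONL from your own loop output
em = calculate_mis_em('../../data_v2/loop_output/DPR/<run_name>/nq/nq-output.jsonl', None)
print(f'misinformation EM: {em:.4f}')
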
/src/evaluation/compute_pearson.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from scipy.stats import pearsonr
3 | import os
4 | from prettytable import PrettyTable
5 | 
6 | # Read a QA-results TSV into {model_name: [per-iteration values]}
7 | def read_tsv(file_path):
8 |     with open(file_path, 'r') as file:
9 |         lines = file.readlines()
10 | 
11 |     data = {}
12 |     for i in range(1, len(lines), 4):  # skip the header line; each model occupies a 4-line block
13 |         # print(lines[i])
14 |         # Generate Model: gpt-3.5-turbo
15 |         # print(lines[i].strip().split(':')[1])
16 |         model = lines[i].strip().split(':')[1].strip()
17 |         # collect the per-iteration values, skipping the leading label column
18 |         bm25_values = [float(x) for x in lines[i + 2].strip().split('\t')[1:]]
19 |         data[model] = bm25_values
20 |     return data
21 | 
22 | 
23 | # Compute the Pearson correlation between the two result sets, model by model
24 | def calculate_correlation(data1, data2):
25 |     # print(data1)
26 |     # print(data2)
27 | 
28 |     correlations = {}
29 |     for model in data1.keys():
30 |         correlation, _ = pearsonr(data1[model], data2[model])
31 |         correlations[model] = correlation
32 |     return correlations
33 | 
34 | 
35 | # Main routine
36 | def main():
37 |     data_names = ['nq', 'webq', 'pop', 'tqa']
38 |     method_names = [
39 |         'loop_output/DPR/mis_nq_webq_pop_tqa_loop_output_bm25_None_total_loop_10_20240129064151',
40 |         'loop_output/DPR/mis_nq_webq_pop_tqa_loop_output_contriever_None_total_loop_10_20240124142811',
41 |         'loop_output/DPR/mis_nq_webq_pop_tqa_loop_output_bge-base_None_total_loop_10_20240125140045',
42 |         'loop_output/DPR/mis_nq_webq_pop_tqa_loop_output_llm-embedder_None_total_loop_10_20240123121401'
43 |     ]
44 |     file_type = [
45 |         'right',
46 |         'mis'
47 |     ]
48 |     model_avg_dict = {}
49 |     # the data are assumed to already live in two .tsv files per (dataset, method, type)
50 |     for data_name in data_names:
51 |         for method_name in method_names:
52 |             for type in file_type:
53 |                 # e.g. ../../data_v2/loop_output/DPR/mis_nq_webq_pop_tqa_loop_output_llm-embedder_None_total_loop_10_20240123121401/nq/results/nq_QA.tsv
54 |                 file1_path = os.path.join('../../data_v2', method_name, data_name, 'results', f'{data_name}_QA.tsv')
55 |                 file2_path = os.path.join('../../data_v2', method_name, data_name, 'results', f'{data_name}_QA_llm_{type}.tsv')
56 |                 # print(file1_path)
57 |                 # print(file2_path)
58 |                 # skip missing files
59 |                 if not os.path.exists(file1_path) or not os.path.exists(file2_path):
60 |                     print(f"File {file1_path} or {file2_path} does not exist!")
61 |                     continue
62 |                 # read the data
63 |                 data1 = read_tsv(file1_path)
64 |                 data2 = read_tsv(file2_path)
65 | 
66 |                 # compute the correlations
67 |                 correlations = calculate_correlation(data1, data2)
68 | 
69 |                 # report the per-model correlations
70 |                 for model, corr in correlations.items():
71 |                     print(f"Model {model}: Pearson Correlation Coefficient = {corr:.3f}")
72 |                     if model not in model_avg_dict:
73 |                         model_avg_dict[model] = {}
74 |                     if data_name not in model_avg_dict[model]:
75 |                         model_avg_dict[model][data_name] = {}
76 |                     if type not in model_avg_dict[model][data_name]:
77 |                         model_avg_dict[model][data_name][type] = []
78 |                     model_avg_dict[model][data_name][type].append(corr)
79 | 
80 |     # print a pretty table like this:
81 |     # type = right
82 |     # +------------------+---------+---------+---------+---------+
83 |     # | Model            | nq      | webq    | pop     | tqa     |
84 |     # +------------------+---------+---------+---------+---------+
85 |     # | gpt-3.5-turbo    | 0.000   | 0.000   | 0.000   | 0.000   |
86 |     # +------------------+---------+---------+---------+---------+
87 |     # | contriever       | 0.000   | 0.000   | 0.000   | 0.000   |
88 |     # +------------------+---------+---------+---------+---------+
89 |     # | bge-base         | 0.000   | 0.000   | 0.000   | 0.000   |
90 |     # +------------------+---------+---------+---------+---------+
91 |     # | llm-embedder     | 0.000   | 0.000   | 0.000   | 0.000   |
92 |     # +------------------+---------+---------+---------+---------+
93 |     # | Average          | 0.000   | 0.000   | 0.000   | 0.000   |
94 |     # +------------------+---------+---------+---------+---------+
95 |     # type = mis
96 | 
97 |     def print_pretty_table(model_avg_dict, file_type):
98 |         for type in file_type:
99 |             print(f"type = {type}")
100 |             table = PrettyTable()
101 |             table.field_names = ["Model"] + data_names  # header row
102 | 
103 |             for model in model_avg_dict.keys():
104 |                 row = [model]  # start building the row
105 |                 for data_name in data_names:
106 |                     # average correlation for this dataset
107 |                     avg_corr = sum(model_avg_dict[model][data_name][type]) / len(model_avg_dict[model][data_name][type])
108 |                     row.append(f"{avg_corr:.3f}")
109 |                 table.add_row(row)  # append the row
110 | 
111 |             # append the average row (gpt-3.5-turbo excluded)
112 |             avg_row = ["Average"]
113 |             for data_name in data_names:
114 |                 all_corrs = [model_avg_dict[model][data_name][type] for model in model_avg_dict if model not in ["gpt-3.5-turbo"]]
115 |                 avg_corr = sum(sum(corrs) for corrs in all_corrs) / sum(len(corrs) for corrs in all_corrs)
116 |                 avg_row.append(f"{avg_corr:.3f}")
117 |             table.add_row(avg_row)
118 | 
119 |             print(table)  # print the table
120 | 
121 |     print_pretty_table(model_avg_dict, file_type)
122 | 
123 | 
124 | # run the main routine
125 | main()
126 | 
--------------------------------------------------------------------------------
/src/evaluation/compute_sig.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | from scipy.stats import ttest_ind
4 | import os
5 | 
6 | def calculate_acc_at_5(data):
7 |     acc_scores = []
8 |     for query_id, query_data in data.items():
9 |         top_5_contexts = query_data['contexts'][:5]  # take the top-5 results, matching the ACC@5 labels below
10 |         acc_scores.append(1 if any(ctx['has_answer'] for ctx in top_5_contexts) else 0)  # 1 if any context contains an answer
11 |     return acc_scores
12 | 
13 | 
14 | def compare_significance(acc1_scores, acc2_scores):
15 |     # compare the two score lists with an independent two-sample t-test
16 |     t_stat, p_value = ttest_ind(acc1_scores, acc2_scores)
17 |     return t_stat, p_value
18 | 
19 | 
20 | def read_json_file(file_path):
21 |     with open(file_path, 'r', encoding='utf-8') as file:
22 |         data = json.load(file)
23 |     return data
24 | 
25 | 
26 | def print_results(file1_mean, file2_mean, t_stat, p_value):
27 |     print(f"File 1 ACC@5 Average: {file1_mean}")
28 |     print(f"File 2 ACC@5 Average: {file2_mean}")
29 |     print(f"T-statistic: {t_stat}, P-value: {p_value}")
30 |     if p_value < 0.05:
31 |         print("There is a significant difference between the two files at the 5% significance level.")
32 |     else:
33 |         print("There is no significant difference between the two files at the 5% significance level.")
34 | 
35 |     if file1_mean > file2_mean:
36 |         print("File 1 has a higher ACC@5 average value.")
37 |     elif file1_mean < file2_mean:
38 |         print("File 2 has a higher ACC@5 average value.")
39 |     else:
40 |         print("Both files have the same ACC@5 average value.")
41 | 
42 | 
43 | def main(file1_path, file2_path):
44 |     # load both files
45 |     file1_data = read_json_file(file1_path)
46 |     file2_data = read_json_file(file2_path)
47 | 
48 |     # compute ACC@5 per query
49 |     acc1_scores = calculate_acc_at_5(file1_data)
50 |     acc2_scores = calculate_acc_at_5(file2_data)
51 | 
52 |     # average over queries
53 |     file1_mean = sum(acc1_scores) / len(acc1_scores)
54 |     file2_mean = sum(acc2_scores) / len(acc2_scores)
55 | 
56 |     # run the significance test
57 |     t_stat, p_value = compare_significance(acc1_scores, acc2_scores)
58 | 
59 |     # report
60 |     print_results(file1_mean, file2_mean, t_stat, p_value)
61 | 
62 | 
63 | if __name__ == "__main__":
64 |     parser = argparse.ArgumentParser(
65 |         description='Compare ACC@5 significance between two JSON files and determine which has the higher average.')
66 |     parser.add_argument('ref', type=str, help='Path to the first JSON file')
67 |     parser.add_argument('run', type=str, help='Path to the second JSON file')
68 | 
69 |     args = parser.parse_args()
70 | 
71 |     retrieve_model_names = [
72 |         'bm25',
73 |         'contriever',
74 |         'bge-base',
75 |         'llm-embedder',
76 |     ]
77 | 
78 |     rerank_model_names = [
79 |         'upr',
80 |         'monot5',
81 |         'bge'
82 |     ]
83 |     data_names = [
84 |         'nq',
85 |         'webq',
86 |         'pop',
87 |         'tqa'
88 |     ]
89 | 
90 |     # for all files in the two folders
91 |     for data_name in data_names:
92 |         for retrieve_model_name in retrieve_model_names:
93 |             ref_path = os.path.join(args.ref, f'{data_name}/{data_name}-test-{retrieve_model_name}')
94 |             run_path = os.path.join(args.run, f'{data_name}/{data_name}-test-{retrieve_model_name}')
95 |             print(f'Comparing {ref_path} and {run_path}')
96 |             main(ref_path, run_path)
97 |             for rerank_model_name in rerank_model_names:
98 |                 if retrieve_model_name not in ['llm-embedder', 'contriever']:
99 |                     ref_path = os.path.join(args.ref, f'{data_name}/{data_name}-{rerank_model_name}_rerank_based_on_{retrieve_model_name}.json')
100 |                     run_path = os.path.join(args.run, f'{data_name}/{data_name}-{rerank_model_name}_rerank_based_on_{retrieve_model_name}.json')
101 |                     print(f'Comparing {ref_path} and {run_path}')
102 |                     main(ref_path, run_path)
103 | 
--------------------------------------------------------------------------------
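For reference, the statistic behind compare_significance is scipy's independent two-sample t-test over per-query 0/1 ACC scores; a toy run with invented score lists:

from scipy.stats import ttest_ind

ref_scores = [1, 0, 1, 1, 0, 1, 1, 0, 1, 1]  # invented per-query ACC@5 of the reference run
run_scores = [1, 0, 0, 1, 0, 1, 0, 0, 1, 0]  # invented per-query ACC@5 of the compared run
t_stat, p_value = ttest_ind(ref_scores, run_scores)
print(t_stat, p_value)  # p < 0.05 would flag a significant difference
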
/src/evaluation/draw_bleu.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import seaborn as sns
3 | import matplotlib.pyplot as plt
4 | from matplotlib.font_manager import FontProperties
5 | 
6 | # path to a CJK font (Microsoft YaHei), needed to render the Chinese axis labels
7 | my_font = FontProperties(fname='/home/xiaoyang2020/chenxiaoyang_11/Rob_LLM/externals/popular-fonts/微软雅黑.ttf')
8 | datasets = ['nq', 'webq', 'pop', 'tqa']
9 | 
10 | for dataset in datasets:
11 |     # read the Self-BLEU TSV (adjust the path to your own file)
12 |     file_path = f'png_tsvs/bleu/{dataset}_bleu_3.tsv'
13 |     df = pd.read_csv(file_path, sep='\t')
14 |     # convert column names to strings
15 |     df.columns = df.columns.astype(str)
16 |     df = df.rename(columns={'0': 'Ori.'})
17 | 
18 |     # shorten the method names for the legend:
19 |     # bm25, contriever, bge-base, llm-embedder,
20 |     # bm25+upr, bm25+monot5, bm25+bge,
21 |     # bge-base+upr, bge-base+monot5, bge-base+bge
22 |     df['Method'] = df['Method'].replace('bm25', 'BM25')
23 |     df['Method'] = df['Method'].replace('contriever', 'Contri')
24 |     df['Method'] = df['Method'].replace('bge-base', 'BGE-B')
25 |     df['Method'] = df['Method'].replace('llm-embedder', 'LLM-E')
26 |     df['Method'] = df['Method'].replace('bm25+upr', 'BM25+U')
27 |     df['Method'] = df['Method'].replace('bm25+monot5', 'BM25+M')
28 |     df['Method'] = df['Method'].replace('bm25+bge', 'BM25+BR')
29 |     df['Method'] = df['Method'].replace('bge-base+upr', 'BGE-B+U')
30 |     df['Method'] = df['Method'].replace('bge-base+monot5', 'BGE-B+M')
31 |     df['Method'] = df['Method'].replace('bge-base+bge', 'BGE-B+BR')
32 | 
33 |     # reshape the DataFrame to long format
34 |     df_long = df.melt(id_vars=['Method'], var_name='loop', value_name='Self-BLEU')
35 | 
36 |     # keep the loop column as strings so the x-axis preserves the file order ('Ori.', '1', '2', ...)
37 |     df_long['loop'] = df_long['loop'].astype(str)
38 |     sns.set_theme(style="ticks", rc={'axes.formatter.limits': (-4, 5)}, font_scale=2.6)
39 |     # plot
40 |     plt.rcParams['font.weight'] = 'bold'
41 |     plt.rcParams['axes.labelweight'] = 'bold'
42 |     plt.rcParams['axes.titleweight'] = 'bold'
43 |     plt.figure(figsize=(7.5, 7.5))  # figure size
44 |     if dataset != 'pop':  # and dataset != 'tqa':
45 |         sns.lineplot(data=df_long, x='loop', y='Self-BLEU', hue='Method', palette="mako", legend=False)
46 |     else:
47 |         sns.lineplot(data=df_long, x='loop', y='Self-BLEU', hue='Method', palette="mako")
48 |     # plt.title('Self-BLEU Values per Method Over Iterations')  # optional title
49 |     sns.despine()
50 |     plt.xlabel('迭代', fontproperties=my_font, fontsize=30, fontweight='bold',)  # x-axis label ("Iteration", rendered in Chinese)
51 |     plt.ylabel('Self-BLEU')  # y-axis label
52 |     if dataset == 'pop':  # or dataset == 'tqa':
53 |         plt.legend(loc='upper right', bbox_to_anchor=(1.77, 0.93), fontsize=25)  # legend
54 | 
55 |     # plt.tight_layout()  # adjust the layout
56 |     plt.show()  # display the figure
57 |     plt.savefig(f'png_tsvs/bleu/{dataset}_bleu.png', bbox_inches='tight')  # save the figure to file
58 | 
--------------------------------------------------------------------------------
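The wide-to-long reshape that draw_bleu.py builds on, shown on a two-row toy frame (the Self-BLEU values are invented):

import pandas as pd

df = pd.DataFrame({'Method': ['BM25', 'Contri'],
                   'Ori.': [0.31, 0.28],
                   '1': [0.42, 0.40]})
# one row per (Method, loop) pair, which is the shape sns.lineplot expects
df_long = df.melt(id_vars=['Method'], var_name='loop', value_name='Self-BLEU')
print(df_long)
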
/src/evaluation/draw_context.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import seaborn as sns
3 | import matplotlib.pyplot as plt
4 | from matplotlib.font_manager import FontProperties
5 | 
6 | # path to a CJK font (Microsoft YaHei) for the Chinese axis labels
7 | my_font = FontProperties(fname='/home/xiaoyang2020/chenxiaoyang_11/Rob_LLM/externals/popular-fonts/微软雅黑.ttf', size=20)
8 | # raw data
9 | right_nums = ['right_0', 'right_1', 'right_2', 'right_3', 'right_4', 'right_5']
10 | loops = list(range(1, 11))
11 | 
12 | # build every (right, loop) combination
13 | right_loop_combinations = [(right, loop) for right in right_nums for loop in loops]
14 | right_values = [int(combo[0].split('_')[1]) for combo in right_loop_combinations]
15 | loop_values = [combo[1] for combo in right_loop_combinations]
16 | 
17 | # avg_query_counts data
18 | avg_query_counts = [
19 |     # EM 0 data
20 |     58.88, 74.36, 77.88, 80.9, 83.26, 85, 86.08, 87.54, 88.06, 88.88,
21 |     23.58, 10.12, 7.5, 5.72, 4.74, 3.44, 3.38, 2.44, 2.34, 1.98,
22 |     9.38, 4.58, 3.72, 2.78, 2.14, 1.92, 1.26, 1.44, 0.94, 0.72,
23 |     4.4, 2.98, 2.18, 1.64, 1.4, 1.1, 0.94, 0.64, 0.6, 0.74,
24 |     1.54, 2.74, 2.06, 1.5, 1.08, 1.26, 1.12, 1.14, 1.12, 1.52,
25 |     0.58, 1.36, 2.12, 2.88, 3.56, 3.1, 3.52, 3.24, 3.44, 3.34,
26 |     # EM 1 data
27 |     3.22, 2.44, 2.62, 2.7, 2.94, 2.7, 2.92, 2.86, 2.94, 3.02,
28 |     14.32, 4.08, 2.5, 1.78, 0.96, 1.16, 1.02, 0.86, 0.96, 0.62,
29 |     16.82, 6.12, 3.88, 3.02, 2.26, 1.78, 1.44, 1.66, 1.06, 1.28,
30 |     16.4, 11.92, 8.02, 5.06, 4.2, 3.5, 3.16, 2.26, 3, 1.96,
31 |     25.46, 22.76, 16.84, 14.1, 12.12, 11.64, 9.98, 9.86, 8.78, 8.78,
32 |     25.42, 56.54, 70.68, 77.92, 81.34, 83.4, 85.18, 86.06, 86.76, 87.16
33 | ]
34 | # add an EM column to distinguish the EM=0 and EM=1 blocks
35 | em_values = ['EM=0'] * (len(right_nums) * len(loops)) + ['EM=1'] * (len(right_nums) * len(loops))
36 | 
37 | # assemble the DataFrame
38 | df = pd.DataFrame({
39 |     'Right_Num': right_values * 2,
40 |     'Loop': loop_values * 2,
41 |     'Avg_Query_Count': avg_query_counts,
42 |     'EM': em_values
43 | })
44 | # custom colour palette
45 | palette = {"EM=1": "#6A3D9A", "EM=0": "#7F8FDD"}
46 | # keep only the loops we want to plot
47 | df = df[df['Loop'].isin([1, 2, 5, 10])]
48 | 
49 | # seaborn style
50 | sns.set(style="whitegrid", font_scale=1.5)
51 | 
52 | # one Figure with four Axes (subplots) in a single row, sharing the y-axis
53 | fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(18, 5), sharey=True)
54 | 
55 | # loop 10 goes in the last panel
56 | desired_loops = [1, 2, 5, 10]
57 | # retrieval acc@5 values (multiplied by 100 when drawn)
58 | retrieval_acc = [
59 |     0.6895, 0.616, 0.5975, 0.582, 0.569, 0.5615, 0.555, 0.548, 0.545, 0.5405
60 | ]
61 | 
62 | # custom legend handles and lines
63 | lines_for_legend = []
64 | # draw a bar chart in the subplot of each desired loop
65 | for i, loop in enumerate(desired_loops):
66 |     ax = axes[i]
67 |     loop_df = df[df['Loop'] == loop]
68 | 
69 |     # totals for EM=0 and EM=1
70 |     em0_total = loop_df[loop_df['EM'] == 'EM=0']['Avg_Query_Count'].sum()
71 |     em1_total = loop_df[loop_df['EM'] == 'EM=1']['Avg_Query_Count'].sum()
72 | 
73 |     # bars
74 |     sns.barplot(x='Right_Num', y='Avg_Query_Count', hue='EM', data=loop_df, ax=ax, palette=palette)
75 | 
76 |     # horizontal lines marking the EM=0 and EM=1 totals
77 |     em0_line = ax.axhline(em0_total, color="#7F8FDD", linestyle='--', lw=2)
78 |     em1_line = ax.axhline(em1_total, color="#6A3D9A", linestyle='--', lw=2)
79 | 
80 |     # annotate the lines
81 |     ax.text(0.9, em0_total - 8, f'{em0_total:.2f}', color="#7F8FDD", ha='right', transform=ax.get_yaxis_transform())
82 |     ax.text(0.93, em1_total + 4, f'{em1_total:.2f}', color="#6A3D9A", ha='right', transform=ax.get_yaxis_transform())
83 | 
84 |     # pick the acc@5 value for this iteration
85 |     if loop == 10:  # adjust the index to the list length
86 |         acc_value = retrieval_acc[-1]
87 |     else:
88 |         acc_value = retrieval_acc[loop - 1]
89 |     # line for retrieval acc@5 × 100
90 |     acc_line = ax.axhline(acc_value * 100, color='#39375B', linestyle=':', lw=2)
91 |     ax.text(0.9, acc_value * 100 + 4, f'{acc_value * 100:.2f}', color='#39375B', ha='right',
92 |             transform=ax.get_yaxis_transform())
93 | 
94 |     # ax.set_title(f'Iteration {loop}', pad=24)  # extra padding above the title
95 |     # ax.set_ylabel('Average Query Num' if i == 0 else '')  # extra padding for the y label
96 |     # ax.set_xlabel('Context Right Num', labelpad=10)  # extra padding for the x label
97 |     ax.set_title(f'迭代 {loop}', fontproperties=my_font, pad=24)  # title ("Iteration {loop}")
98 |     ax.set_ylabel('平均查询数量' if i == 0 else '', fontproperties=my_font)  # y label ("average query count")
99 |     ax.set_xlabel('正确上下文数量', labelpad=10, fontproperties=my_font)  # x label ("number of correct contexts")
100 |     if i < len(axes):
101 |         ax.legend().set_visible(False)
102 |     # on the last subplot, collect the lines for the shared legend
103 |     if i == len(desired_loops) - 1:
104 |         lines_for_legend.append(em0_line)
105 |         lines_for_legend.append(em1_line)
106 |         lines_for_legend.append(acc_line)
107 | # add one shared legend for the whole figure, including the custom lines
108 | handles, labels = axes[-1].get_legend_handles_labels()
109 | custom_labels = ["EM=0 Total", "EM=1 Total", "Acc@5"]
110 | handles.extend(lines_for_legend)  # append the custom line handles
111 | labels.extend(custom_labels)  # append the custom labels
112 | 
113 | fig.legend(handles, labels, loc='upper right', bbox_to_anchor=(1.13, 0.6), ncol=1)
114 | 
115 | # spacing between subplots and around the edges
116 | plt.tight_layout()
117 | 
118 | # leave room for the legend
119 | fig.subplots_adjust(bottom=0.15)
120 | 
121 | # display the figure
122 | plt.show()
123 | plt.savefig('png_tsvs/draw_context.png', dpi=300, bbox_inches='tight')
--------------------------------------------------------------------------------
/src/evaluation/draw_filter_retrieval.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import matplotlib.pyplot as plt
3 | 
4 | 
5 | # the TSV header is multi-level, so define the column names up front
6 | columns = [
7 |     'nq', 'no_filter_BM25', 'no_filter_BGE-Base', 'no_filter_Contriever', 'no_filter_LLM-Embedder',
8 |     'filter_bleu_BM25', 'filter_bleu_BGE-Base', 'filter_bleu_Contriever', 'filter_bleu_LLM-Embedder',
9 |     'filter_source_BM25', 'filter_source_BGE-Base', 'filter_source_Contriever', 'filter_source_LLM-Embedder'
10 | ]
11 | 
12 | # read the data, skipping the two header rows;
13 | # the TSV is tab-separated and assumed to sit in the current folder
14 | df = pd.read_csv('filter_ret.tsv', sep='\t', skiprows=2, names=columns)
15 | 
16 | # set up the plot
17 | plt.figure(figsize=(14, 8))
18 | plt.title('Accuracy over Loops')
19 | plt.xlabel('Loops')
20 | plt.ylabel('Accuracy @5')
21 | print(df)
22 | # one colour per method, so the same method always uses the same colour
23 | colors = {
24 |     'BM25': 'red',
25 |     'BGE-Base': 'blue',
26 |     'Contriever': 'green',
27 |     'LLM-Embedder': 'purple'
28 | }
29 | 
30 | # x-axis: the first column holds the loop index
31 | x = df['nq']
32 | 
33 | # draw the lines; labels are added separately below so each method appears only once in the legend
34 | for method, color in colors.items():
35 |     plt.plot(x, df[f'no_filter_{method}'], color=color)
36 |     plt.plot(x, df[f'filter_bleu_{method}'], color=color, linestyle='--')
37 |     plt.plot(x, df[f'filter_source_{method}'], color=color, linestyle='-.')
38 | 
39 | # trick to show only the method names in the legend (rather than every line):
40 | # plot an invisible, empty series that carries the label
41 | for method, color in colors.items():
42 |     plt.plot([], [], color=color, label=method)
43 | 
44 | # show the legend
45 | plt.legend()
46 | 
47 | # show the grid
48 | plt.grid(True)
49 | 
50 | 
51 | # save the figure
52 | plt.savefig('png_tsvs/filter_retrieval/filter_ret.png')
53 | 
--------------------------------------------------------------------------------
/src/evaluation/draw_map.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import matplotlib.pyplot as plt
3 | import seaborn as sns
4 | from collections import Counter
5 | from collections import defaultdict
6 | from itertools import cycle
7 | 
8 | # read the TSV file (replace with the actual path to your file)
9 | file_path = 'tsvs/nq_tsv.tsv'
10 | # read the id-list columns as strings
11 | df = pd.read_csv(file_path, sep='\t', dtype={'0->1': str, '1->0': str})
12 | def process_id_list(id_list):
13 |     if pd.isnull(id_list) or id_list == '':
14 |         return []
15 |     return list(map(int, id_list.split(',')))
16 | 
17 | # apply the conversion to the '0->1' and '1->0' columns
18 | df['0->1'] = df['0->1'].apply(process_id_list)
19 | df['1->0'] = df['1->0'].apply(process_id_list)
20 | 
21 | # draw the figure
22 | plt.figure(figsize=(10, 6))
23 | 
24 | # collect the unique IDs
25 | all_ids = set()
26 | for ids in df['0->1']:
27 |     all_ids.update(ids)
28 | for ids in df['1->0']:
29 |     all_ids.update(ids)
30 | 
31 | # for each ID, draw a line connecting its occurrences across loops
32 | for id in sorted(all_ids):
33 |     # positions of the current ID in the different loops
34 |     x_vals, y_vals = [], []
35 | 
36 |     # record the position whenever the ID occurs in a loop
37 |     for index, row in df.iterrows():
38 |         # to plot the 0->1 transitions instead, use the commented check below
39 |         # if id in row['0->1']:
40 |         #     x_vals.append(row['Loop Num'])
41 |         #     y_vals.append(id)
42 |         if id in row['1->0']:
43 |             x_vals.append(row['Loop Num'])
44 |             y_vals.append(id)
45 | 
46 |     # connect the occurrences of this ID across loops
47 |     if x_vals:  # make sure the list is not empty
48 |         plt.plot(x_vals, y_vals, marker='o', linestyle='-', label=f'ID {id}')
49 | 
50 | # axis titles and figure title
51 | plt.xlabel('Loop Number')
52 | plt.ylabel('ID')
53 | plt.title('Connection of the Same IDs Across Different Loops 1->0')
54 | 
55 | # with many IDs the legend gets crowded, so it is disabled;
56 | # uncomment the next line to show it anyway
57 | # plt.legend(loc='best')
58 | 
59 | plt.savefig('tsvs/nq_tsv_10.png')
--------------------------------------------------------------------------------
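draw_map.py expects a tab-separated file with a 'Loop Num' column and comma-separated query-id lists in the '0->1' and '1->0' columns; an invented example of the expected shape:

Loop Num    0->1       1->0
1           3,17,42    8
2           5          3,9
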
/src/evaluation/draw_percentage.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import seaborn as sns
3 | import matplotlib.pyplot as plt
4 | from matplotlib.font_manager import FontProperties
5 | 
6 | # path to a CJK font (Microsoft YaHei) for the Chinese axis labels
7 | my_font = FontProperties(fname='/home/xiaoyang2020/chenxiaoyang_11/Rob_LLM/externals/popular-fonts/微软雅黑.ttf')
8 | # seaborn theme
9 | sns.set_theme(style="ticks")
10 | # read the data (tab-separated TSV)
11 | df = pd.read_csv('/home/xiaoyang2020/chenxiaoyang_11/Rob_LLM/src/evaluation/png_tsvs/percentage/plot_webq_tsv.tsv', delimiter='\t')
12 | # rename gpt-3.5-turbo to ChatGPT, baichuan2-13b-chat to Baichuan2, qwen-14b-chat to Qwen,
13 | # chatglm3-6b to ChatGLM3, llama2-13b-chat to Llama2, and human to Human
14 | df['generate_model_name'] = df['generate_model_name'].replace(['gpt-3.5-turbo', 'baichuan2-13b-chat', 'qwen-14b-chat', 'chatglm3-6b', 'llama2-13b-chat', 'human'], ['ChatGPT', 'Baichuan2', 'Qwen', 'ChatGLM3', 'Llama2', 'Human'])
15 | df = df.melt(id_vars='generate_model_name', var_name='Iteration', value_name='Percentage')
16 | # scale all values to percentages with one decimal place
17 | df['Percentage'] = df['Percentage'].apply(lambda x: round(x*100, 1))
18 | # running bottom position of each stacked bar
19 | bottoms = [0] * len(df['Iteration'].unique())
20 | 
21 | 
22 | # colour palette (hex codes); earlier candidate palettes kept for reference:
23 | # colors = {
24 | #     'gpt-3.5-turbo': '#CAB2D5',
25 | #     'baichuan2-13b-chat': '#FEBE6F',
26 | #     'qwen-14b-chat': '#FC9A98',
27 | #     'chatglm3-6b': '#B2DF8A',
28 | #     'llama2-13b-chat': '#6D759F',
29 | #     'human': '#A6CEE3'
30 | # }
31 | # colors = {
32 | #     'gpt-3.5-turbo': '#6A3D9A',
33 | #     'baichuan2-13b-chat': '#9977B7',
34 | #     'qwen-14b-chat': '#9D7BB9',
35 | #     'chatglm3-6b': '#B599C8',
36 | #     'llama2-13b-chat': '#CAB2D5',
37 | #     'human': '#7F8FDD'
38 | # }
39 | # colors = {
40 | #     'GPT-3.5-Turbo': '#6A3D9A',
41 | #     'Baichuan2-13B-Chat': '#9977B7',
42 | #     'Qwen-14B-Chat': '#9D7BB9',
43 | #     'ChatGLM3-6B': '#B599C8',
44 | #     'Llama2-13B-Chat': '#CAB2D5',
45 | #     'Human': '#7F8FDD'
46 | # }
47 | 
48 | colors = {
49 |     'ChatGPT': '#6A3D9A',
50 |     'Baichuan2': '#9977B7',
51 |     'Qwen': '#9D7BB9',
52 |     'ChatGLM3': '#B599C8',
53 |     'Llama2': '#CAB2D5',
54 |     'Human': '#7F8FDD'
55 | }
56 | 
57 | # figure size
58 | plt.figure(figsize=(7.5, 7.5))
59 | 
60 | # one stacked bar segment per model
61 | for i, model_name in enumerate(df['generate_model_name'].unique()):
62 |     # rows for the current model
63 |     model_data = df[df['generate_model_name'] == model_name]
64 | 
65 |     # draw the segment
66 |     plt.bar(
67 |         model_data['Iteration'],   # x coordinates
68 |         model_data['Percentage'],  # segment heights
69 |         bottom=bottoms,            # stack on top of the previous segments
70 |         color=colors.get(model_name, '#808080'),
71 |         label=model_name,
72 |     )
73 | 
74 |     # update the base positions for the next group of segments
75 |     bottoms = [bottoms[j] + model_data['Percentage'].values[j] for j in range(len(bottoms))]
76 | 
77 | # annotate percentages beside the right-most bar
78 | last_iteration = df['Iteration'].unique()[-1]  # last iteration
79 | last_iteration_data = df[df['Iteration'] == last_iteration]  # its rows
80 | 
81 | # x position of the last bar
82 | last_bar_index = len(df['Iteration'].unique()) - 1  # index of the last bar
83 | 
84 | 
85 | # add the percentage text at the middle height of each segment of the right-most bar
86 | for i, (model_name, percentage) in enumerate(zip(last_iteration_data['generate_model_name'], last_iteration_data['Percentage'])):
87 |     # since bottoms accumulate, first find the base height of this segment
88 |     if i == 0:
89 |         # the first segment starts at 0
90 |         bottom_height = 0
91 |     else:
92 |         # later segments start at the previous base plus the previous height
93 |         bottom_height = bottom_height + last_iteration_data['Percentage'].values[i - 1]
94 |     # mid-point of the segment
95 |     middle_height = bottom_height + (percentage / 2.0)
96 |     print(f"{model_name} base: {bottom_height}")
97 |     print(f"{model_name} height: {percentage}")
98 |     print(f"{model_name} mid-point: {middle_height}")
99 | 
100 |     # # add the text to the right of the bar, at mid height
101 |     # plt.text(
102 |     #     x=last_bar_index + 0.4,  # slightly to the right of the bar
103 |     #     y=middle_height,         # vertical centre of the segment
104 |     #     s=f"{percentage}%",      # text to display
105 |     #     va='center',             # vertically centred
106 |     #     ha='left',               # left-aligned
107 |     # )
108 | 
109 | plt.rcParams['font.weight'] = 'bold'
110 | plt.rcParams['axes.labelweight'] = 'bold'
111 | plt.rcParams['axes.titleweight'] = 'bold'
112 | # plt.title('Percentage of Text', fontsize=20)  # title font size
113 | plt.xlabel('迭代', fontsize=30,  # x-axis label ("Iteration"), bold
114 |            fontweight='bold', fontproperties=my_font)
115 | plt.ylabel('百分比', fontsize=30, fontweight='bold', fontproperties=my_font)  # y-axis label ("Percentage")
116 | # plt.legend(title='bleu', fontsize=14,  # legend font size
117 | #            title_fontsize=14)  # legend title font size
118 | # plt.legend(loc='upper right', bbox_to_anchor=(1.62, 0.7), fontsize=25)
119 | 
120 | plt.xticks(fontsize=30, fontweight='bold')  # x tick font size
121 | plt.yticks(fontsize=30, fontweight='bold')  # y tick font size
122 | # figure title
123 | # plt.title('Percentage of Text ')
124 | 
125 | # x- and y-axis labels
126 | # plt.xlabel('Iteration')
127 | # plt.ylabel('Percentage')
128 | 
129 | # legend
130 | # plt.legend(title='bleu')
131 | 
132 | # display the figure
133 | plt.show()
134 | 
135 | # save
136 | plt.savefig('/home/xiaoyang2020/chenxiaoyang_11/Rob_LLM/src/evaluation/png_tsvs/percentage/chinese_plot_webq_tsv.png', dpi=300, bbox_inches='tight')
--------------------------------------------------------------------------------
/src/evaluation/eva_configs/bleu_eva_config.json:
--------------------------------------------------------------------------------
1 | {
2 |     "directory": "../../data_v2/loop_output/DPR/mis_nq_webq_pop_tqa_loop_output_bm25_None_total_loop_10_20240123051159/nq",
3 |     "task": "bleu",
4 |     "elasticsearch_url": "http://124.16.138.150:9978",
5 |     "index_name": "bm25_psgs_index"
6 | }
--------------------------------------------------------------------------------
/src/evaluation/eva_configs/context_answer_config.json:
--------------------------------------------------------------------------------
1 | {
2 |     "directory": "../../data_v2/loop_output/DPR/nq_webq_pop_tqa_loop_output_bm25_upr_total_loop_10_20231231125441/webq",
3 |     "task": "context_answer"
4 | }
--------------------------------------------------------------------------------
/src/evaluation/eva_configs/percentage_eva_config.json:
--------------------------------------------------------------------------------
1 | {
2 |     "directory": "../../data_v2/loop_output/DPR/nq_webq_pop_tqa_loop_output_bm25_upr_total_loop_10_20231231125441/nq",
3 |     "task": "percentage"
4 | }
--------------------------------------------------------------------------------
/src/evaluation/eva_configs/qa_eva_config.json:
--------------------------------------------------------------------------------
1 | {
2 |     "directory": "../../data_v2/loop_output/DPR/nq_webq_pop_tqa_loop_output_contriever_None_total_loop_10_20240113075935/pop",
3 |     "task": "QA"
4 | }
--------------------------------------------------------------------------------
/src/evaluation/eva_configs/rank_eva_config.json:
--------------------------------------------------------------------------------
1 | {
2 |     "directory": "../../data_v2/loop_output/DPR/nq_webq_pop_tqa_loop_output_llm-embedder_None_total_loop_10_20240116050642/tqa",
3 |     "task": "rank"
4 | }
--------------------------------------------------------------------------------
/src/evaluation/eva_configs/retrieval_eva_config.json:
--------------------------------------------------------------------------------
1 | {
2 |     "directory": "../../data_v2/loop_output/DPR/nq_webq_pop_tqa_loop_output_llm-embedder_None_total_loop_10_20240116050642/tqa",
3 |     "task": "retrieval"
4 | }
--------------------------------------------------------------------------------
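Each eva_configs JSON above pairs a results directory with a task name (plus Elasticsearch settings for the BLEU task). A minimal sketch of how such a config might be consumed — the loader below is an assumption for illustration, not the repo's actual entry point:

import json

with open('eva_configs/qa_eva_config.json') as f:
    cfg = json.load(f)
directory, task = cfg['directory'], cfg['task']  # task is one of: bleu, context_answer, percentage, QA, rank, retrieval
print(f'running {task} evaluation over {directory}')
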
/src/evaluation/eva_context_answer.py:
--------------------------------------------------------------------------------
1 | import datasets
2 | import json
3 | from eva_generate_em import evaluate as evaluate_em
4 | from datasets.utils.logging import disable_progress_bar
5 | from collections import defaultdict
6 | disable_progress_bar()
7 | 
8 | def read_result_file(file_name):
9 |     with open(file_name, "r", encoding='utf-8') as f:
10 |         question_dataset = json.load(f)
11 |     formatted_data = [
12 |         {
13 |             "id": qid,
14 |             "question": details["question"],
15 |             "answers": details["answers"],
16 |             "contexts": details["contexts"],
17 |         }
18 |         for qid, details in question_dataset.items()
19 |     ]
20 |     result_dataset = datasets.Dataset.from_list(formatted_data)
21 | 
22 |     return result_dataset
23 | 
24 | 
25 | def read_generated_file(file_name):
26 |     generated_dataset = datasets.load_dataset('json', data_files=file_name)['train']
27 |     prediction = evaluate_em(generated_dataset)  # format: for each sample, example['exact_match'] = 1 or 0
28 | 
29 |     return prediction
30 | 
31 | 
32 | def get_top_answer_number(example, cut_off=100):
33 |     count = 0
34 |     for i, context in enumerate(example['contexts']):
35 |         if context['has_answer'] == True:
36 |             count += 1
37 |         if i == cut_off - 1:
38 |             break
39 | 
40 |     example['top_answer_number'] = count
41 |     return example
42 | 
43 | 
44 | def get_top_answer_number_dataset(dataset, cut_off=100):
45 |     map_fn_kwargs = {'cut_off': cut_off}
46 |     dataset = dataset.map(get_top_answer_number, fn_kwargs=map_fn_kwargs, num_proc=4)
47 |     return dataset
48 | 
49 | 
50 | def compute_top_answer_number(result_file, generated_file, cut_off=5):
51 |     result_dataset = read_result_file(result_file)
52 |     generated_dataset = read_generated_file(generated_file)
53 |     result_dataset = get_top_answer_number_dataset(result_dataset, cut_off)
54 |     assert result_dataset['id'] == generated_dataset['id']
55 |     em = generated_dataset['exact_match']
56 |     top_answer_number = result_dataset['top_answer_number']
57 | 
58 |     top_answer_number_dict = defaultdict(list)
59 |     for i in range(len(em)):
60 |         top_answer_number_dict[f'{cut_off}_{em[i]}'].append(top_answer_number[i])
61 | 
62 |     # print(top_answer_number_dict)
63 |     return top_answer_number_dict
64 | 
--------------------------------------------------------------------------------
/src/evaluation/eva_domi.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VerdureChen/SOS-Retrieval-Loop/56a19847edfc30d1e6e59894d348e1c17ae15114/src/evaluation/eva_domi.py
--------------------------------------------------------------------------------
/src/evaluation/eva_generate_em.py:
--------------------------------------------------------------------------------
1 | import datasets
2 | import string
3 | import re
4 | import os
5 | 
6 | 
7 | def evaluate(predictions):
8 |     # evaluate the predictions with exact match
9 |     def _normalize_answer(s):
10 |         def remove_articles(text):
11 |             return re.sub(r"\b(a|an|the)\b", " ", text)
12 | 
13 |         def white_space_fix(text):
14 |             return " ".join(text.split())
15 | 
16 |         def remove_punc(text):
17 |             exclude = set(string.punctuation)
18 |             return "".join(ch for ch in text if ch not in exclude)
19 | 
20 |         def lower(text):
21 |             return text.lower()
22 | 
23 |         return white_space_fix(remove_articles(remove_punc(lower(s))))
24 | 
25 |     def exact_match_score(example):
26 |         ground_truths = example['answers']
27 |         assert type(ground_truths) == list, f'ground_truths is not a list, id:{example["id"]}, ground_truth:{ground_truths}'
28 |         prediction = example['response']
29 |         example['exact_match'] = 0
30 |         if not prediction:
31 |             print(f'no prediction for qid {example["qid"]}, {example["query"]}')
32 |             return example
33 |         for ground_truth in ground_truths:
34 |             if _normalize_answer(ground_truth) in _normalize_answer(prediction):
35 |                 example['exact_match'] = 1
36 |                 break
37 |         return example
38 | 
39 |     predictions = predictions.map(exact_match_score)
40 |     return predictions
41 | 
42 | 
43 | def main():
44 |     # load the predictions
45 |     res_file_path = '../../data_v2/zero_gen_data/DPR'
46 |     data_names = ['nq', 'webq', 'tqa', 'pop']
47 |     model_names = ['chatglm3-6b-chat', 'qwen-14b-chat', 'baichuan2-13b-chat', 'llama2-13b-chat', 'gpt-3.5-turbo']
48 |     for data_name in data_names:
49 |         for model_name in model_names:
50 |             res_path = f'{res_file_path}/{data_name}-test-gen-{model_name}.jsonl'
51 |             # check if the file exists
52 |             if not os.path.exists(res_path):
53 |                 continue
54 |             predictions = datasets.load_dataset('json', data_files=res_path)['train']
55 |             predictions = evaluate(predictions)
56 |             print('-' * 20)
57 |             print(f'{data_name}-{model_name}')
58 |             # compute the exact match score; predictions['exact_match'] is a plain list
59 |             print(sum(predictions['exact_match']) / len(predictions['exact_match']))
60 |             print('-' * 20)
61 | 
62 | 
63 | if __name__ == '__main__':
64 |     main()
65 | 
--------------------------------------------------------------------------------
/src/evaluation/eva_percentage.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | from collections import defaultdict
4 | 
5 | # extract the LLM name prefix from a docid with a regex
6 | def extract_llm_name(docid):
7 |     match = re.match(r'([a-zA-Z0-9.-]+)_', str(docid))
8 |     return match.group(1) if match else None
9 | 
10 | def calculate_llm_human_proportion(retrieval_file, topks, llm_names_preset=None):
11 |     with open(retrieval_file, 'r') as file:
12 |         data = json.load(file)
13 | 
14 |     # result dict with a counter for every LLM and for human text
15 |     results = defaultdict(lambda: defaultdict(int))
16 |     if llm_names_preset:
17 |         for topk in topks:
18 |             for name in llm_names_preset:
19 |                 results[topk][name] = 0
20 |             results[topk]['human'] = 0
21 | 
22 |     # iterate over the queries
23 |     for query_id, query_data in data.items():
24 |         contexts = query_data['contexts']
25 |         for topk in topks:
26 |             # documents within the current top-k
27 |             topk_contexts = contexts[:topk]
28 |             llm_names = defaultdict(int)
29 | 
30 |             # count LLM-generated vs human-written texts
31 |             for ctx in topk_contexts:
32 |                 llm_name = extract_llm_name(ctx['docid'])
33 |                 if llm_name:
34 |                     llm_names[llm_name] += 1
35 |                 else:
36 |                     # numeric ids are assumed to mark human-written text
37 |                     llm_names['human'] += 1
38 | 
39 |             # accumulate the counts
40 |             for name, count in llm_names.items():
41 |                 results[topk][name] += count
42 | 
43 |     # convert counts to proportions
44 |     proportions = defaultdict(dict)
45 |     # merge results[topk]['chatglm3-6b'] and results[topk]['chatglm3-6b-chat'] into results[topk]['chatglm3-6b']
46 |     for topk in topks:
47 |         if 'chatglm3-6b-chat' in results[topk]:
48 |             results[topk]['chatglm3-6b'] += results[topk]['chatglm3-6b-chat']
49 |             del results[topk]['chatglm3-6b-chat']
50 | 
51 |     for topk in topks:
52 |         total = sum(results[topk].values())
53 |         for name in results[topk]:
54 |             proportions[topk][f'{name}_proportion'] = results[topk][name] / total
55 |         print(f'Proportions for top-{topk}: {proportions[topk]}')
56 |         print(f'sum: {sum(proportions[topk].values())}')
57 |     # returns a dict like {5: {'human_proportion': 0.5, 'gpt-3.5-turbo_proportion': 0.5}}
58 | 
59 |     return proportions
60 | 
61 | # usage example
62 | # retrieval_file = 'path_to_your_file.json'  # replace with the actual file path
63 | # topks = [5, 20, 50]
64 | # proportions = calculate_llm_human_proportion(retrieval_file, topks)
65 | # print(proportions)
66 | 
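67 | # Expected input shape: the retrieval JSON maps each query id to its contexts.
68 | # An invented example (docids with a model-name prefix count as LLM text,
69 | # purely numeric docids count as human text):
70 | # {
71 | #   "0": {"contexts": [
72 | #     {"docid": "qwen-14b-chat_123"},
73 | #     {"docid": "20431"}
74 | #   ]}
75 | # }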
-------------------------------------------------------------------------------- /src/evaluation/eva_rank.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import json 3 | from eva_generate_em import evaluate as evaluate_em 4 | from datasets.utils.logging import disable_progress_bar 5 | disable_progress_bar() 6 | 7 | def read_result_file(file_name): 8 | with open(file_name, "r", encoding='utf-8') as f: 9 | question_dataset = json.load(f) 10 | formatted_data = [ 11 | { 12 | "id": qid, 13 | "question": details["question"], 14 | "answers": details["answers"], 15 | "contexts": details["contexts"], 16 | } 17 | for qid, details in question_dataset.items() 18 | ] 19 | result_dataset = datasets.Dataset.from_list(formatted_data) 20 | 21 | return result_dataset 22 | 23 | 24 | def read_generated_file(file_name): 25 | generated_dataset = datasets.load_dataset('json', data_files=file_name)['train'] 26 | prediction = evaluate_em(generated_dataset) # format: for each sample, example['exact_match'] = 1 or 0 27 | return prediction 28 | 29 | 30 | def get_first_answer_rank(example, cut_off=100): 31 | rank = cut_off 32 | for i, context in enumerate(example['contexts']): 33 | if context['has_answer'] == True: 34 | rank = min(rank, i+1) 35 | break 36 | 37 | example['first_answer_rank'] = rank 38 | return example 39 | 40 | 41 | def get_first_answer_rank_dataset(dataset, cut_off=100): 42 | map_fn_kwargs = {'cut_off': cut_off} 43 | dataset = dataset.map(get_first_answer_rank, fn_kwargs=map_fn_kwargs, num_proc=4) 44 | return dataset 45 | 46 | 47 | def get_first_human_answer_rank(example, cut_off=100): 48 | rank = cut_off 49 | for i, context in enumerate(example['contexts']): 50 | if context['has_answer'] == True and context['docid'].isdigit(): 51 | rank = min(rank, i+1) 52 | break 53 | 54 | example['first_human_answer_rank'] = rank 55 | return example 56 | 57 | 58 | def get_first_human_answer_rank_dataset(dataset, cut_off=100): 59 | map_fn_kwargs = {'cut_off': cut_off} 60 | dataset = dataset.map(get_first_human_answer_rank, fn_kwargs=map_fn_kwargs, num_proc=4) 61 | return dataset 62 | 63 | 64 | def compute_true_wrong_first_answer_rank(result_file, generated_file, cut_off=100): 65 | result_dataset = read_result_file(result_file) 66 | generated_dataset = read_generated_file(generated_file) 67 | result_dataset = get_first_answer_rank_dataset(result_dataset, cut_off) 68 | result_dataset = get_first_human_answer_rank_dataset(result_dataset, cut_off) 69 | 70 | assert result_dataset['id'] == generated_dataset['id'] 71 | em = generated_dataset['exact_match'] 72 | first_answer_rank = result_dataset['first_answer_rank'] 73 | first_human_answer_rank = result_dataset['first_human_answer_rank'] 74 | 75 | # compute avg answer rank and avg human answer rank, grouped by exact_match (0 or 1) 76 | avg_answer_rank = {} 77 | avg_human_answer_rank = {} 78 | for i in range(len(em)): 79 | if em[i] not in avg_answer_rank: 80 | avg_answer_rank[em[i]] = [] 81 | avg_human_answer_rank[em[i]] = [] 82 | avg_answer_rank[em[i]].append(first_answer_rank[i]) 83 | avg_human_answer_rank[em[i]].append(first_human_answer_rank[i]) 84 | 85 | 86 | # compute accuracy@5 87 | accuracy_at_5 = {} 88 | for i in avg_answer_rank: 89 | accuracy_at_5[i] = round(sum([1 for rank in avg_answer_rank[i] if rank <= 5]) / len(avg_answer_rank[i]),4) 90 | accuracy_at_3 = {} 91 | for i in avg_answer_rank: 92 | accuracy_at_3[i] = round(sum([1 for rank in avg_answer_rank[i] if rank <= 3]) / len(avg_answer_rank[i]),4) 
93 | accuracy_at_1 = {} 94 | for i in avg_answer_rank: 95 | accuracy_at_1[i] = round(sum([1 for rank in avg_answer_rank[i] if rank <= 1]) / len(avg_answer_rank[i]),4) 96 | 97 | for i in avg_answer_rank: 98 | avg_answer_rank[i] = round(sum(avg_answer_rank[i]) / len(avg_answer_rank[i]),4) 99 | for i in avg_human_answer_rank: 100 | avg_human_answer_rank[i] = round(sum(avg_human_answer_rank[i]) / len(avg_human_answer_rank[i]),4) 101 | return avg_answer_rank, avg_human_answer_rank, accuracy_at_5, accuracy_at_3, accuracy_at_1, em 102 | 103 | 104 | 105 | def compute_change_rank(result_files, generated_files, cut_off=100): 106 | em = [] 107 | 108 | for result_file, generated_file in zip(result_files, generated_files): 109 | _, _, _, _, _, em_i = compute_true_wrong_first_answer_rank(result_file, generated_file, cut_off) 110 | em.append(em_i) 111 | 112 | # compute em: 0->1, 1->0, 0->0, 1->1 113 | # em format:[[0, 1, 1, 0, 0, 1, 1, 0, 0, 1], [1, 0, 0, 1, 1, 0, 0, 1, 1, 0]], compute the item change from 0 to 1 of each item 114 | 115 | change_em = {} 116 | change_index = {} 117 | for i in range(len(em[0])): 118 | if em[0][i] == 0 and em[1][i] == 1: 119 | change_em['0->1'] = change_em.get('0->1', 0) + 1 120 | change_index['0->1'] = change_index.get('0->1', []) + [i] 121 | elif em[0][i] == 1 and em[1][i] == 0: 122 | change_em['1->0'] = change_em.get('1->0', 0) + 1 123 | change_index['1->0'] = change_index.get('1->0', []) + [i] 124 | elif em[0][i] == 0 and em[1][i] == 0: 125 | change_em['0->0'] = change_em.get('0->0', 0) + 1 126 | elif em[0][i] == 1 and em[1][i] == 1: 127 | change_em['1->1'] = change_em.get('1->1', 0) + 1 128 | else: 129 | print('error') 130 | # check dict key exist 131 | for key in ['0->1', '1->0', '0->0', '1->1']: 132 | if key not in change_em: 133 | change_em[key] = 0 134 | for key in ['0->1', '1->0']: 135 | if key not in change_index: 136 | change_index[key] = [] 137 | return change_em, change_index 138 | 139 | 140 | 141 | -------------------------------------------------------------------------------- /src/evaluation/eva_retrieval_trec.py: -------------------------------------------------------------------------------- 1 | from trec_eval import EvalFunction 2 | import argparse 3 | 4 | def evaluate_trec(qrel_path, run_path): 5 | EvalFunction.eval(['-c', '-l', '2', '-m', 'recall.100', qrel_path, run_path]) 6 | EvalFunction.eval(['-c', '-l', '2', '-M', '100', '-m', 'map', qrel_path, run_path]) 7 | EvalFunction.eval(['-c', '-l', '2', '-M', '100', '-m', 'recip_rank', qrel_path, run_path]) 8 | EvalFunction.eval(['-c', '-m', 'ndcg_cut.10', qrel_path, run_path]) -------------------------------------------------------------------------------- /src/evaluation/run_context_eva.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # if you are using "QA_llm_mis" or "QA_llm_right" task, you need to set the API_KEY and API_BASE 4 | API_KEY="none" 5 | API_BASE="none" 6 | QUERY_DATA_NAMES=(nq webq) 7 | #RESULT_NAMES=( 8 | # "nq_webq_pop_tqa_loop_output_bm25_None_total_loop_10_20231227041949" 9 | # "nq_webq_pop_tqa_loop_output_contriever_None_total_loop_10_20240113075935" 10 | # "nq_webq_pop_tqa_loop_output_bge-base_None_total_loop_10_20231229042900" 11 | # "nq_webq_pop_tqa_loop_output_llm-embedder_None_total_loop_10_20240116050642" 12 | # "nq_webq_pop_tqa_loop_output_bm25_upr_total_loop_10_20231231125441" 13 | # "nq_webq_pop_tqa_loop_output_bm25_monot5_total_loop_10_20240101125941" 14 | # 
"nq_webq_pop_tqa_loop_output_bm25_bge_total_loop_10_20240103144945" 15 | # "nq_webq_pop_tqa_loop_output_bge-base_upr_total_loop_10_20240106093905" 16 | # "nq_webq_pop_tqa_loop_output_bge-base_monot5_total_loop_10_20240108014726" 17 | # "nq_webq_pop_tqa_loop_output_bge-base_bge_total_loop_10_20240109090024" 18 | #) 19 | # shellcheck disable=SC2054 20 | #RESULT_NAMES=( 21 | # "mis_nq_webq_pop_tqa_loop_output_contriever_None_total_loop_10_20240124142811" 22 | # "mis_nq_webq_pop_tqa_loop_output_bm25_None_total_loop_10_20240129064151" 23 | # "mis_nq_webq_pop_tqa_loop_output_bge-base_None_total_loop_10_20240125140045" 24 | # "mis_nq_webq_pop_tqa_loop_output_llm-embedder_None_total_loop_10_20240123121401" 25 | #) 26 | 27 | #RESULT_NAMES=( 28 | # "filter_bleu_nq_webq_pop_tqa_loop_output_bge-base_None_total_loop_10_20240131140843" 29 | # "filter_bleu_nq_webq_pop_tqa_loop_output_bm25_None_total_loop_10_20240130134307" 30 | # "filter_bleu_nq_webq_pop_tqa_loop_output_contriever_None_total_loop_10_20240131141029" 31 | # "filter_bleu_nq_webq_pop_tqa_loop_output_llm-embedder_None_total_loop_10_20240131141119" 32 | #) 33 | 34 | #RESULT_NAMES=( 35 | # "filter_source_nq_webq_pop_tqa_loop_output_bge-base_None_total_loop_10_20240204104046" 36 | # "filter_source_nq_webq_pop_tqa_loop_output_bm25_None_total_loop_10_20240203141208" 37 | # "filter_source_nq_webq_pop_tqa_loop_output_contriever_None_total_loop_10_20240204091108" 38 | # "filter_source_nq_webq_pop_tqa_loop_output_llm-embedder_None_total_loop_10_20240204103944" 39 | #) 40 | RESULT_NAMES=( 41 | "update_nq_webq_loop_output_bm25_None_total_loop_10_20240328134015" 42 | ) 43 | 44 | #RESULT_NAMES=( "mis_passage_processed" ) 45 | RESULT_DIR="../../data_v2/loop_output/DPR" 46 | #RESULT_DIR="../../data_v2/misinfo_data/DPR" 47 | TASK="context_answer" 48 | #TASK="retrieval" 49 | #TASK="bleu" 50 | #TASK="QA" 51 | #TASK="percentage" 52 | #TASK="misQA" 53 | #TASK="QA_llm_mis" 54 | #TASK="QA_llm_right" 55 | #TASK="filter_bleu_retrieval" 56 | #TASK="filter_bleu_percentage" 57 | #TASK="filter_bleu_context_answer" 58 | #TASK="filter_source_retrieval" 59 | #TASK="filter_source_percentage" 60 | #TASK="filter_source_context_answer" 61 | 62 | for ((i=0;i<${#QUERY_DATA_NAMES[@]};i++)) 63 | do 64 | for ((j=0;j<${#RESULT_NAMES[@]};j++)) 65 | do 66 | QUERY_DATA_NAME=${QUERY_DATA_NAMES[i]} 67 | RESULT_NAME=${RESULT_NAMES[j]} 68 | RESULT_PATH="${RESULT_DIR}/${RESULT_NAME}/${QUERY_DATA_NAME}" 69 | echo "QUERY_DATA_NAME: ${QUERY_DATA_NAME}" 70 | echo "RESULT_NAME: ${RESULT_NAME}" 71 | echo "RESULT_PATH: ${RESULT_PATH}" 72 | python3 eva_pipe.py --config_file_path none --directory ${RESULT_PATH} --task $TASK --api_key $API_KEY --api_base $API_BASE 73 | done 74 | done 75 | -------------------------------------------------------------------------------- /src/evaluation/sum_QA.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from collections import defaultdict 4 | 5 | 6 | # Your provided path names and dataset names 7 | # path_names = [ 8 | # "mis_nq_webq_pop_tqa_loop_output_bm25_None_total_loop_10_20240129064151", 9 | # "mis_nq_webq_pop_tqa_loop_output_contriever_None_total_loop_10_20240124142811", 10 | # "mis_nq_webq_pop_tqa_loop_output_bge-base_None_total_loop_10_20240125140045", 11 | # "mis_nq_webq_pop_tqa_loop_output_llm-embedder_None_total_loop_10_20240123121401" 12 | # ] 13 | 14 | # path_names = [ 15 | # "filter_bleu_nq_webq_pop_tqa_loop_output_bm25_None_total_loop_10_20240130134307", 16 | # 
"filter_bleu_nq_webq_pop_tqa_loop_output_bge-base_None_total_loop_10_20240131140843", 17 | # "filter_bleu_nq_webq_pop_tqa_loop_output_contriever_None_total_loop_10_20240131141029", 18 | # "filter_bleu_nq_webq_pop_tqa_loop_output_llm-embedder_None_total_loop_10_20240131141119" 19 | # ] 20 | 21 | # path_names = [ 22 | # "filter_source_nq_webq_pop_tqa_loop_output_bm25_None_total_loop_10_20240203141208", 23 | # "filter_source_nq_webq_pop_tqa_loop_output_bge-base_None_total_loop_10_20240204104046", 24 | # "filter_source_nq_webq_pop_tqa_loop_output_contriever_None_total_loop_10_20240204091108", 25 | # "filter_source_nq_webq_pop_tqa_loop_output_llm-embedder_None_total_loop_10_20240204103944" 26 | # ] 27 | path_names = [ "low_nq_webq_pop_tqa_loop_output_bm25_None_total_loop_10_20240327135107"] 28 | 29 | dataset_names = ["nq", "webq", "pop", "tqa"] 30 | base_dir = "/home/xiaoyang2020/chenxiaoyang_11/Rob_LLM/data_v2/loop_output/DPR" 31 | 32 | # 存储最终结果 33 | final_results = defaultdict(dict) 34 | 35 | # 遍历每个数据集 36 | 37 | for dataset in dataset_names: 38 | for path_name in path_names: 39 | file_path = os.path.join(base_dir, path_name, dataset, "results", f"{dataset}_QA.tsv") 40 | print(f"file_path: {file_path}") 41 | # 读取并解析文件 42 | if os.path.exists(file_path): 43 | with open(file_path, 'r') as file: 44 | lines = file.readlines() 45 | methods_scores = defaultdict(dict) 46 | ref_num = None 47 | model_name = None 48 | loop_order = [] 49 | method_name = None 50 | 51 | for line in lines: 52 | if line.startswith("Ref Num:"): 53 | ref_num = line.strip().split(": ")[1] 54 | elif line.startswith("Generate Model:"): 55 | model_name = line.strip().split(": ")[1] 56 | elif line.startswith(dataset): 57 | continue 58 | elif line == "\n": 59 | continue 60 | else: 61 | print(f"model_name: {model_name}") 62 | print(f"line: {line}") 63 | # llm-embedder 0.635 0.58 0.64 0.62 0.615 0.61 0.605 0.605 0.585 0.59 64 | method_name = line.strip().split("\t")[0] 65 | method_scores = line.strip().split("\t")[1:] 66 | # print(method_name, method_scores) 67 | # change the 2nd item in method_scores to last 68 | index_changed = method_scores.pop(1) 69 | method_scores.append(index_changed) 70 | print(f"method_name: {method_name}, method_scores: {method_scores}\n\n\n") 71 | methods_scores[model_name][method_name] = method_scores 72 | 73 | # 保存结果 74 | # print(f"ref_num: {ref_num}, methods_scores: {methods_scores}") 75 | for model_name in methods_scores.keys(): 76 | for method_name in methods_scores[model_name].keys(): 77 | if dataset not in final_results.keys(): 78 | final_results[dataset] = defaultdict(dict) 79 | if model_name not in final_results[dataset].keys(): 80 | final_results[dataset][model_name] = defaultdict(dict) 81 | final_results[dataset][model_name][method_name] = methods_scores[model_name][method_name].copy() 82 | 83 | # 保存结果 84 | for dataset in final_results.keys(): 85 | print(f"dataset: {dataset}") 86 | for model_name in final_results[dataset].keys(): 87 | print(f"model_name: {model_name}") 88 | for method_name in final_results[dataset][model_name].keys(): 89 | print(f"method_name: {method_name}, scores: {final_results[dataset][model_name][method_name]}") 90 | print("\n\n") 91 | 92 | 93 | output_path = os.path.join("sum_tsvs", "sum_QA_filter_source.tsv") 94 | ''' 95 | 96 | output_format: 97 | EM ref=5 98 | GPT 99 | nq 100 | BM25 BM25+UPR BM25+MonoT5 BM25+BGE_reranker BGE-Base 101 | loop1 0.53 0.55 0.53 0.565 0.56 102 | loop2 0.55 0.53 0.535 0.54 0.565 103 | loop3 0.55 0.515 0.53 0.555 0.57 104 | loop4 0.525 0.515 0.53 
0.55 0.545 105 | loop5 0.515 0.52 0.535 0.545 0.54 106 | loop6 0.515 0.515 0.515 0.535 0.55 107 | 108 | ''' 109 | 110 | 111 | 112 | 113 | 114 | with open(output_path, 'w') as file: 115 | file.write("EM\tref=5\n") 116 | for dataset in dataset_names: 117 | for model_name in final_results[dataset].keys(): 118 | file.write(f"{model_name}\n") 119 | file.write(f'{dataset}\n') 120 | file.write(f"Method\t") 121 | for method_name in final_results[dataset][model_name].keys(): 122 | file.write(f"{method_name}\t") 123 | file.write("\n") 124 | for loop_order in range(1, 11): 125 | file.write(f"loop{loop_order}\t") 126 | for method_name in final_results[dataset][model_name].keys(): 127 | try: 128 | file.write(f"{final_results[dataset][model_name][method_name][loop_order - 1]}\t") 129 | except IndexError: 130 | print(f"dataset: {dataset}, model_name: {model_name}, method_name: {method_name}, loop_order: {loop_order}") 131 | file.write("\t") # pad the missing score so the columns stay aligned 132 | file.write("\n") 133 | file.write("\n") 134 | file.write("\n") 135 | 136 | 137 | 138 | 139 | -------------------------------------------------------------------------------- /src/evaluation/sum_bleu.py: -------------------------------------------------------------------------------- 1 | # summarize bleu tsvs 2 | 3 | import os 4 | import sys 5 | import json 6 | import csv 7 | 8 | 9 | path_names = [ 10 | "nq_webq_pop_tqa_loop_output_bm25_None_total_loop_10_20231227041949", 11 | "nq_webq_pop_tqa_loop_output_contriever_None_total_loop_10_20240113075935", 12 | "nq_webq_pop_tqa_loop_output_bge-base_None_total_loop_10_20231229042900", 13 | "nq_webq_pop_tqa_loop_output_llm-embedder_None_total_loop_10_20240116050642", 14 | "nq_webq_pop_tqa_loop_output_bm25_upr_total_loop_10_20231231125441", 15 | "nq_webq_pop_tqa_loop_output_bm25_monot5_total_loop_10_20240101125941", 16 | "nq_webq_pop_tqa_loop_output_bm25_bge_total_loop_10_20240103144945", 17 | "nq_webq_pop_tqa_loop_output_bge-base_upr_total_loop_10_20240106093905", 18 | "nq_webq_pop_tqa_loop_output_bge-base_monot5_total_loop_10_20240108014726", 19 | "nq_webq_pop_tqa_loop_output_bge-base_bge_total_loop_10_20240109090024" 20 | ] 21 | 22 | dataset_names = [ 23 | "nq", 24 | "webq", 25 | "pop", 26 | "tqa" 27 | ] 28 | 29 | topk = [ 30 | 5, 31 | ] 32 | 33 | ref = 5 34 | 35 | # file types to merge 36 | file_types = ["bleu_top5_ref5_bigram.tsv", "bleu_top5_ref5_trigram.tsv"] 37 | 38 | # output file 39 | output_file = "summary_bleu.tsv" 40 | 41 | # create or truncate the output file 42 | f = open(output_file, "w") 43 | f.close() 44 | 45 | # write the data 46 | with open(output_file, 'a') as outfile: 47 | 48 | total_path = "/home/xiaoyang2020/chenxiaoyang_11/Rob_LLM/data_v2/loop_output/DPR" 49 | # iterate over each file type (topk_ref) 50 | for file_type in file_types: 51 | # write topk and ref as the first row of a new section 52 | topk_ref = file_type.replace("bleu_", "").replace(".tsv", "") 53 | outfile.write(topk_ref + "\n") 54 | # write a dataset name row, each dataset name has 12 blank columns after it 55 | for dataset_name in dataset_names: 56 | outfile.write(dataset_name + "\t" * 13) 57 | outfile.write("\n") 58 | 59 | # write the header row 60 | headers = ['ref_num', "method", "generate_model_name"] + [str(i) for i in range(1, 11)] 61 | headers_with_space = "\t".join(headers) + "\t\t" 62 | outfile.write(headers_with_space * len(dataset_names) + "\n") 63 | 64 | 65 | 66 | 67 | # iterate over each dataset 68 | # initialize the data rows for each file type 69 | data_rows = [[] for _ in dataset_names] 70 | 71 | # read each dataset from every result folder 72 | for path in path_names: 73 | # iterate over each dataset 74 | for dataset_index, dataset in enumerate(dataset_names): 75 | # build the file path 76 | file_path = os.path.join(total_path, path, dataset, "results", file_type) 
77 | # check whether the file exists 78 | if os.path.exists(file_path): 79 | # read the file contents 80 | with open(file_path, 'r') as infile: 81 | # skip the header 82 | next(infile) 83 | reader = csv.reader(infile, delimiter='\t') 84 | rows = list(reader) 85 | if len(rows) > 3: 86 | for row in rows[3:]: 87 | # skip empty rows 88 | if len(row) == 0: 89 | continue 90 | data_rows[dataset_index].append("\t".join(row)) 91 | else: 92 | for row in rows: 93 | # skip empty rows 94 | if len(row) == 0: 95 | continue 96 | data_rows[dataset_index].append("\t".join(row)) 97 | else: 98 | # the file does not exist; just report it 99 | print("file not exist: ", file_path) 100 | # write the matching row of every dataset on one line, separated by \t\t 101 | for data_row in range(len(data_rows[0])): 102 | for dataset_index, dataset in enumerate(dataset_names): 103 | outfile.write(data_rows[dataset_index][data_row] + "\t\t") 104 | outfile.write("\n") 105 | -------------------------------------------------------------------------------- /src/evaluation/sum_context_tsv.py: -------------------------------------------------------------------------------- 1 | # summarize context-answer tsvs 2 | 3 | import os 4 | import sys 5 | import json 6 | import csv 7 | 8 | 9 | # path_names = [ 10 | # "nq_webq_pop_tqa_loop_output_bm25_None_total_loop_10_20231227041949", 11 | # "nq_webq_pop_tqa_loop_output_contriever_None_total_loop_10_20240113075935", 12 | # "nq_webq_pop_tqa_loop_output_bge-base_None_total_loop_10_20231229042900", 13 | # "nq_webq_pop_tqa_loop_output_llm-embedder_None_total_loop_10_20240116050642", 14 | # "nq_webq_pop_tqa_loop_output_bm25_upr_total_loop_10_20231231125441", 15 | # "nq_webq_pop_tqa_loop_output_bm25_monot5_total_loop_10_20240101125941", 16 | # "nq_webq_pop_tqa_loop_output_bm25_bge_total_loop_10_20240103144945", 17 | # "nq_webq_pop_tqa_loop_output_bge-base_upr_total_loop_10_20240106093905", 18 | # "nq_webq_pop_tqa_loop_output_bge-base_monot5_total_loop_10_20240108014726", 19 | # "nq_webq_pop_tqa_loop_output_bge-base_bge_total_loop_10_20240109090024" 20 | # ] 21 | 22 | # path_names = [ 23 | # "mis_nq_webq_pop_tqa_loop_output_bm25_None_total_loop_10_20240129064151", 24 | # "mis_nq_webq_pop_tqa_loop_output_contriever_None_total_loop_10_20240124142811", 25 | # "mis_nq_webq_pop_tqa_loop_output_bge-base_None_total_loop_10_20240125140045", 26 | # "mis_nq_webq_pop_tqa_loop_output_llm-embedder_None_total_loop_10_20240123121401" 27 | # ] 28 | 29 | path_names = [ 30 | "filter_bleu_nq_webq_pop_tqa_loop_output_bm25_None_total_loop_10_20240130134307", 31 | "filter_bleu_nq_webq_pop_tqa_loop_output_bge-base_None_total_loop_10_20240131140843", 32 | "filter_bleu_nq_webq_pop_tqa_loop_output_contriever_None_total_loop_10_20240131141029", 33 | "filter_bleu_nq_webq_pop_tqa_loop_output_llm-embedder_None_total_loop_10_20240131141119" 34 | ] 35 | # path_names = [ 36 | # "filter_source_nq_webq_pop_tqa_loop_output_bm25_None_total_loop_10_20240203141208", 37 | # "filter_source_nq_webq_pop_tqa_loop_output_bge-base_None_total_loop_10_20240204104046", 38 | # "filter_source_nq_webq_pop_tqa_loop_output_contriever_None_total_loop_10_20240204091108", 39 | # "filter_source_nq_webq_pop_tqa_loop_output_llm-embedder_None_total_loop_10_20240204103944" 40 | # ] 41 | 42 | dataset_names = [ 43 | "nq", 44 | "webq", 45 | "pop", 46 | "tqa" 47 | ] 48 | total_path = "/home/xiaoyang2020/chenxiaoyang_11/Rob_LLM/data_v2/loop_output/DPR" 49 | # file types to merge 50 | file_types = ["context_answer_ref5_cutoff5_em0.tsv", "context_answer_ref5_cutoff5_em1.tsv"] 51 | 52 | # header for the target file 53 | header = ["ref_num", "model", "method", "right_num", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"] 54 | 55 | # process each file type
56 | for file_type in file_types: 57 | # initialize a dict that stores the rows for each dataset 58 | data_to_write = {dataset_name: [] for dataset_name in dataset_names} 59 | 60 | # collect data for every path and dataset name 61 | for path_name in path_names: 62 | for dataset_name in dataset_names: 63 | # build the full file path 64 | file_path = os.path.join(total_path, path_name, dataset_name, "results", file_type) 65 | if os.path.exists(file_path): 66 | with open(file_path, 'r', newline='', encoding='utf-8') as tsvfile: 67 | reader = csv.reader(tsvfile, delimiter='\t') 68 | for row in reader: 69 | # skip empty rows and stray header rows 70 | if not row or row[:4] == header[:4]: 71 | continue 72 | # convert the row to a dict 73 | row_dict = dict(zip(header, row)) 74 | data_to_write[dataset_name].append(row_dict) 75 | 76 | # write the merged file 77 | output_file_name = f'sum_tsvs/filter_bleu_merged_{file_type}' 78 | with open(output_file_name, 'w', newline='', encoding='utf-8') as outfile: 79 | writer = csv.writer(outfile, delimiter='\t') 80 | 81 | # write the new header 82 | extended_header = [[f'{dataset_name}_ref5_cutoff5'] + header for dataset_name in dataset_names] 83 | extended_header = [item for sublist in extended_header for item in sublist] 84 | writer.writerow(extended_header) 85 | 86 | # write the merged data 87 | max_rows = max(len(data_to_write[dataset_name]) for dataset_name in dataset_names) 88 | for i in range(max_rows): 89 | row_to_write = [] 90 | for dataset_name in dataset_names: 91 | data_list = data_to_write[dataset_name] 92 | if i < len(data_list): 93 | row_to_write.extend(['']+list(data_list[i].values())) 94 | else: 95 | # pad with empty strings when a dataset runs out of rows 96 | row_to_write.extend([''] * (len(header) + 1)) # +1 matches the leading blank cell added above 97 | writer.writerow(row_to_write) -------------------------------------------------------------------------------- /src/evaluation/sum_percentage_tsv.py: -------------------------------------------------------------------------------- 1 | # summarize percentage tsvs 2 | 3 | import os 4 | import sys 5 | import json 6 | import csv 7 | 8 | 9 | # path_names = [ 10 | # "nq_webq_pop_tqa_loop_output_bm25_None_total_loop_10_20231227041949", 11 | # "nq_webq_pop_tqa_loop_output_contriever_None_total_loop_10_20240113075935", 12 | # "nq_webq_pop_tqa_loop_output_bge-base_None_total_loop_10_20231229042900", 13 | # "nq_webq_pop_tqa_loop_output_llm-embedder_None_total_loop_10_20240116050642", 14 | # "nq_webq_pop_tqa_loop_output_bm25_upr_total_loop_10_20231231125441", 15 | # "nq_webq_pop_tqa_loop_output_bm25_monot5_total_loop_10_20240101125941", 16 | # "nq_webq_pop_tqa_loop_output_bm25_bge_total_loop_10_20240103144945", 17 | # "nq_webq_pop_tqa_loop_output_bge-base_upr_total_loop_10_20240106093905", 18 | # "nq_webq_pop_tqa_loop_output_bge-base_monot5_total_loop_10_20240108014726", 19 | # "nq_webq_pop_tqa_loop_output_bge-base_bge_total_loop_10_20240109090024" 20 | # ] 21 | 22 | # path_names = [ 23 | # "mis_nq_webq_pop_tqa_loop_output_bm25_None_total_loop_10_20240129064151", 24 | # "mis_nq_webq_pop_tqa_loop_output_contriever_None_total_loop_10_20240124142811", 25 | # "mis_nq_webq_pop_tqa_loop_output_bge-base_None_total_loop_10_20240125140045", 26 | # "mis_nq_webq_pop_tqa_loop_output_llm-embedder_None_total_loop_10_20240123121401" 27 | # ] 28 | # path_names = [ 29 | # "filter_bleu_nq_webq_pop_tqa_loop_output_bm25_None_total_loop_10_20240130134307", 30 | # "filter_bleu_nq_webq_pop_tqa_loop_output_bge-base_None_total_loop_10_20240131140843", 31 | # "filter_bleu_nq_webq_pop_tqa_loop_output_contriever_None_total_loop_10_20240131141029", 32 | # "filter_bleu_nq_webq_pop_tqa_loop_output_llm-embedder_None_total_loop_10_20240131141119" 33 | # ] 34 | 
path_names = [ 35 | "filter_source_nq_webq_pop_tqa_loop_output_bm25_None_total_loop_10_20240203141208", 36 | "filter_source_nq_webq_pop_tqa_loop_output_bge-base_None_total_loop_10_20240204104046", 37 | "filter_source_nq_webq_pop_tqa_loop_output_contriever_None_total_loop_10_20240204091108", 38 | "filter_source_nq_webq_pop_tqa_loop_output_llm-embedder_None_total_loop_10_20240204103944" 39 | ] 40 | 41 | dataset_names = [ 42 | "nq", 43 | "webq", 44 | "pop", 45 | "tqa" 46 | ] 47 | 48 | topk = [ 49 | 5, 50 | 20, 51 | 50 52 | ] 53 | 54 | ref = 5 55 | 56 | # file types to merge 57 | file_types = ["percentage_top5_ref5.tsv", "percentage_top20_ref5.tsv", "percentage_top50_ref5.tsv"] 58 | 59 | # output file 60 | output_file = "sum_tsvs/filter_source_percentage_summary.tsv" 61 | 62 | # create or truncate the output file 63 | f = open(output_file, "w") 64 | f.close() 65 | 66 | # write the data 67 | with open(output_file, 'a') as outfile: 68 | 69 | total_path = "/home/xiaoyang2020/chenxiaoyang_11/Rob_LLM/data_v2/loop_output/DPR" 70 | # iterate over each file type (topk_ref) 71 | for file_type in file_types: 72 | # write topk and ref as the first row of a new section 73 | topk_ref = file_type.replace("percentage_", "").replace(".tsv", "") 74 | outfile.write(topk_ref + "\n") 75 | # write a dataset name row, each dataset name has 12 blank columns after it 76 | for dataset_name in dataset_names: 77 | outfile.write(dataset_name + "\t" * 13) 78 | outfile.write("\n") 79 | 80 | # write the header row 81 | headers = ["method", "generate_model_name"] + [str(i) for i in range(1, 11)] 82 | headers_with_space = "\t".join(headers) + "\t\t" 83 | outfile.write(headers_with_space * len(dataset_names) + "\n") 84 | 85 | 86 | 87 | 88 | # iterate over each dataset 89 | # initialize the data rows for each file type 90 | data_rows = [[] for _ in dataset_names] 91 | 92 | # read each dataset from every result folder 93 | for path in path_names: 94 | # iterate over each dataset 95 | for dataset_index, dataset in enumerate(dataset_names): 96 | # build the file path 97 | file_path = os.path.join(total_path, path, dataset, "results", file_type) 98 | # check whether the file exists 99 | if os.path.exists(file_path): 100 | # read the file contents 101 | with open(file_path, 'r') as infile: 102 | # skip the header 103 | next(infile) 104 | # count the remaining rows: if there are more than 6, only rows 7 onward are added to this dataset's list, otherwise all rows are added 105 | reader = csv.reader(infile, delimiter='\t') 106 | rows = list(reader) 107 | if len(rows) > 6: 108 | for row in rows[6:]: 109 | data_rows[dataset_index].append("\t".join(row)) 110 | else: 111 | for row in rows: 112 | data_rows[dataset_index].append("\t".join(row)) 113 | else: 114 | # the file does not exist; just report it 115 | print("file not exist: ", file_path) 116 | # write the matching row of every dataset on one line, separated by \t\t 117 | for data_row in range(len(data_rows[0])): 118 | for dataset_index, dataset in enumerate(dataset_names): 119 | outfile.write(data_rows[dataset_index][data_row] + "\t\t") 120 | outfile.write("\n") 121 | -------------------------------------------------------------------------------- /src/evaluation/trec_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import subprocess 4 | import sys 5 | import platform 6 | import pandas as pd 7 | import tempfile 8 | 9 | from pyserini.search import get_qrels_file 10 | from pyserini.util import download_evaluation_script 11 | 12 | 13 | class EvalFunction: 14 | @staticmethod 15 | def trunc(qrels, run): 16 | qrels = get_qrels_file(qrels) 17 | run = pd.read_csv(run, delim_whitespace=True, header=None) 18 | qrels = pd.read_csv(qrels, delim_whitespace=True, header=None) 19 | run[0] = run[0].astype(str) 20 | qrels[0] = qrels[0].astype(str) 21 | 22 | qrels = qrels[qrels[0].isin(run[0])] 23 | temp_file = 
tempfile.NamedTemporaryFile(delete=False).name 24 | qrels.to_csv(temp_file, sep='\t', header=None, index=None) 25 | return temp_file 26 | 27 | @staticmethod 28 | def eval(args, trunc=True): 29 | script_path = download_evaluation_script('trec_eval') 30 | cmd_prefix = ['java', '-jar', script_path] 31 | # args = sys.argv 32 | 33 | # Option to discard non-judged hits in run file 34 | judged_docs_only = '' 35 | judged_result = [] 36 | cutoffs = [] 37 | 38 | if '-remove-unjudged' in args: 39 | judged_docs_only = args.pop(args.index('-remove-unjudged')) 40 | 41 | if any([i.startswith('judged.') for i in args]): 42 | # Find what position the arg is in. 43 | idx = [i.startswith('judged.') for i in args].index(True) 44 | cutoffs = args.pop(idx) 45 | cutoffs = list(map(int, cutoffs[7:].split(','))) 46 | # Get rid of the '-m' before the 'judged.xxx' option 47 | args.pop(idx - 1) 48 | 49 | temp_file = '' 50 | 51 | if len(args) > 1: 52 | if trunc: 53 | args[-2] = EvalFunction.trunc(args[-2], args[-1]) 54 | print('Trunc', args[-2]) 55 | 56 | if not os.path.exists(args[-2]): 57 | args[-2] = get_qrels_file(args[-2]) 58 | if os.path.exists(args[-1]): 59 | # Convert run to trec if it's on msmarco 60 | with open(args[-1]) as f: 61 | first_line = f.readline() 62 | if 'Q0' not in first_line: 63 | temp_file = tempfile.NamedTemporaryFile(delete=False).name 64 | print('msmarco run detected. Converting to trec...') 65 | run = pd.read_csv(args[-1], delim_whitespace=True, header=None, 66 | names=['query_id', 'doc_id', 'rank']) 67 | run['score'] = 1 / run['rank'] 68 | run.insert(1, 'Q0', 'Q0') 69 | run['name'] = 'TEMPRUN' 70 | run.to_csv(temp_file, sep='\t', header=None, index=None) 71 | args[-1] = temp_file 72 | 73 | run = pd.read_csv(args[-1], delim_whitespace=True, header=None) 74 | qrels = pd.read_csv(args[-2], delim_whitespace=True, header=None) 75 | 76 | # cast doc_id column as string 77 | run[0] = run[0].astype(str) 78 | qrels[0] = qrels[0].astype(str) 79 | 80 | # Discard non-judged hits 81 | 82 | if judged_docs_only: 83 | if not temp_file: 84 | temp_file = tempfile.NamedTemporaryFile(delete=False).name 85 | judged_indexes = pd.merge(run[[0, 2]].reset_index(), qrels[[0, 2]], on=[0, 2])['index'] 86 | run = run.loc[judged_indexes] 87 | run.to_csv(temp_file, sep='\t', header=None, index=None) 88 | args[-1] = temp_file 89 | # Measure judged@cutoffs 90 | for cutoff in cutoffs: 91 | run_cutoff = run.groupby(0).head(cutoff) 92 | judged = len(pd.merge(run_cutoff[[0, 2]], qrels[[0, 2]], on=[0, 2])) / len(run_cutoff) 93 | metric_name = f'judged_{cutoff}' 94 | judged_result.append(f'{metric_name:22}\tall\t{judged:.4f}') 95 | cmd = cmd_prefix + args[1:] 96 | else: 97 | cmd = cmd_prefix 98 | 99 | print(f'Running command: {cmd}') 100 | shell = platform.system() == "Windows" 101 | process = subprocess.Popen(cmd, 102 | stdout=subprocess.PIPE, 103 | stderr=subprocess.PIPE, 104 | shell=shell) 105 | stdout, stderr = process.communicate() 106 | if stderr: 107 | print(stderr.decode("utf-8")) 108 | 109 | print('Results:') 110 | print(stdout.decode("utf-8").rstrip()) 111 | 112 | for judged in judged_result: 113 | print(judged) 114 | 115 | if temp_file: 116 | os.remove(temp_file) -------------------------------------------------------------------------------- /src/filtering/filter_configs/bleu_filter_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "input_file": 
"/home/xiaoyang2020/chenxiaoyang_11/Rob_LLM/data_v2/loop_output/DPR/nq_webq_pop_tqa_loop_output_llm-embedder_None_total_loop_10_20240116050642/nq/llm-embedder_nq_retrieval_loop_1", 3 | "output_file": "/home/xiaoyang2020/chenxiaoyang_11/Rob_LLM/src/filtering/llm-embedder_nq_retrieval_loop_1_filtered", 4 | "elasticsearch_url": "http://124.16.138.150:9978", 5 | "index_name": "bm25_psgs_index", 6 | "max_self_bleu": 0.4, 7 | "num_docs": 5 8 | } -------------------------------------------------------------------------------- /src/filtering/filter_configs/source_filter_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "input_file": "../../data_v2/loop_output/DPR/nq_webq_pop_tqa_loop_output_llm-embedder_None_total_loop_10_20240116050642/nq/llm-embedder_nq_retrieval_loop_10", 3 | "output_file": "llm-embedder_nq_retrieval_loop_10_filtered", 4 | "elasticsearch_url": "http://124.16.138.150:9978", 5 | "index_name": "bm25_psgs_index", 6 | "max_self_bleu": 0.4, 7 | "num_docs": 5, 8 | "task": "filter_source" 9 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/eva_generate.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import string 3 | import re 4 | import os 5 | 6 | 7 | def evaluate(predictions): 8 | # evaluate the predictions with exact match 9 | def _normalize_answer(s): 10 | def remove_articles(text): 11 | return re.sub(r"\b(a|an|the)\b", " ", text) 12 | 13 | def white_space_fix(text): 14 | return " ".join(text.split()) 15 | 16 | def remove_punc(text): 17 | exclude = set(string.punctuation) 18 | return "".join(ch for ch in text if ch not in exclude) 19 | 20 | def lower(text): 21 | return text.lower() 22 | 23 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 24 | 25 | def exact_match_score(example): 26 | ground_truths = example['answers'] 27 | assert type(ground_truths) == list, f'ground_truths is not a list, id:{example["id"]}, ground_truth:{ground_truths}' 28 | prediction = example['response'] 29 | example['exact_match'] = 0 30 | if not prediction: 31 | print(f'no prediction for qid {example["qid"]}, {example["query"]}') 32 | return example 33 | for ground_truth in ground_truths: 34 | if _normalize_answer(ground_truth) in _normalize_answer(prediction): 35 | example['exact_match'] = 1 36 | break 37 | return example 38 | 39 | predictions = predictions.map(exact_match_score) 40 | return predictions 41 | 42 | 43 | def main(): 44 | # load the predictions 45 | res_file_path = '../../data_v2/zero_gen_data/DPR/sampled_data' 46 | data_names = ['nq', 'webq', 'tqa', 'pop'] 47 | model_names = ['chatglm3-6b-chat', 'qwen-14b-chat', 'baichuan2-13b-chat', 'llama2-13b-chat', 'gpt-3.5-turbo'] 48 | for data_name in data_names: 49 | for model_name in model_names: 50 | res_path = f'{res_file_path}/{data_name}-test-gen-{model_name}.jsonl' 51 | # check if the file exists 52 | if not os.path.exists(res_path): 53 | continue 54 | predictions = datasets.load_dataset('json', data_files=res_path)['train'] 55 | predictions = evaluate(predictions) 56 | print('-' * 20) 57 | print(f'{data_name}-{model_name}') 58 | # compute the exact match score 59 | # 'list' object has no attribute 'sum' 60 | # predictions['exact_match'] is a list 61 | print(sum(predictions['exact_match']) / len(predictions['exact_match'])) 62 | print('-' * 20) 63 | 64 | 65 | if __name__ == '__main__': 66 | main() 67 | 68 | 
-------------------------------------------------------------------------------- /src/llm_zero_generate/rag_configs/baichuan2-13b-chat-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-13b-chat", 3 | "question_file_path": "../../data_v2/ret_output/DPR/nq-test-h10-bge-large.json", 4 | "output_file_path": "../../data_v2/incxt_gen_data/DPR/nq-test-h10-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8222/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/rag_configs/baichuan2-13b-chat-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-13b-chat", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/pop-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8222/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/rag_configs/baichuan2-13b-chat-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-13b-chat", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/tqa-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8222/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/rag_configs/baichuan2-13b-chat-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-13b-chat", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/webq-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8222/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/rag_configs/chatglm3-6b-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "chatglm3", 3 | "question_file_path": "../../data_v2/ret_output/DPR/nq-test-h10-bge-large.json", 4 | "output_file_path": "../../data_v2/incxt_gen_data/DPR/nq-test-h10-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8113/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/rag_configs/chatglm3-6b-config-pop.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "model_name": "chatglm3", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/pop-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8113/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/rag_configs/chatglm3-6b-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "chatglm3", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/tqa-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8113/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/rag_configs/chatglm3-6b-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "chatglm3", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/webq-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8113/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/rag_configs/gpt-3.5-turbo-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "gpt-3.5-turbo", 3 | "question_file_path": "../../data_v2/ret_output/DPR/nq-test-h10-bge-large.json", 4 | "output_file_path": "../../data_v2/incxt_gen_data/DPR/nq-test-h10-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/rag_configs/gpt-3.5-turbo-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "gpt-3.5-turbo", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/pop-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/rag_configs/gpt-3.5-turbo-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "gpt-3.5-turbo", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/tqa-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | 
"elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/rag_configs/gpt-3.5-turbo-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "gpt-3.5-turbo", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/webq-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/rag_configs/llama2-13b-chat-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama", 3 | "question_file_path": "../../data_v2/ret_output/DPR/nq-test-h10-bge-large.json", 4 | "output_file_path": "../../data_v2/incxt_gen_data/DPR/nq-test-h10-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8223/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/rag_configs/llama2-13b-chat-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/pop-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8223/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/rag_configs/llama2-13b-chat-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/tqa-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8223/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/rag_configs/llama2-13b-chat-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/webq-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8223/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- 
/src/llm_zero_generate/rag_configs/qwen-14b-chat-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen", 3 | "question_file_path": "../../data_v2/ret_output/DPR/nq-test-h10-bge-large.json", 4 | "output_file_path": "../../data_v2/incxt_gen_data/DPR/nq-test-h10-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/rag_configs/qwen-14b-chat-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/pop-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/rag_configs/qwen-14b-chat-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/tqa-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/rag_configs/qwen-14b-chat-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/webq-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/run_generate.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | MODEL_NAMES=(llama2-7b-chat) #chatglm3-6b qwen-14b-chat llama2-13b-chat baichuan2-13b-chat gpt-3.5-turbo 5 | GENERATE_BASE_AND_KEY=( 6 | "gpt-3.5-turbo xxx xxx" 7 | "chatglm3-6b xxx xxx" 8 | # "qwen-14b-chat xxx xxx" 9 | "llama2-7b-chat xxx xxx" 10 | "baichuan2-7b-chat xxx xxx" 11 | "llama2-13b-chat xxx xxx" 12 | "baichuan2-13b-chat xxx xxx" 13 | "qwen-0.5b-chat xxx xxx" 14 | "qwen-1.8b-chat xxx xxx" 15 | "qwen-4b-chat xxx xxx" 16 | "qwen-7b-chat xxx xxx" 17 | "qwen-14b-chat xxx xxx" 18 | ) 19 | 20 | DATA_NAMES=(tqa pop nq webq) 21 | CONTEXT_REF_NUM=0 22 | QUESTION_FILE_NAMES=( 23 | "-test-sample-200.jsonl" 24 | ) 25 | LOOP_CONFIG_PATH_NAME="../run_configs/test_zero_update_retrieval_config" 26 | 27 | TOTAL_LOG_DIR="../run_logs/test_zero_update_retrieval_log" 28 | 
QUESTION_FILE_PATH_TOTAL="../../data_v2/input_data/DPR/sampled_query" 29 | TOTAL_OUTPUT_DIR="../../data_v2/loop_output/DPR/test_zero_update_retrieval_result" 30 | mkdir -p "${TOTAL_LOG_DIR}" 31 | mkdir -p "${TOTAL_OUTPUT_DIR}" 32 | mkdir -p "${LOOP_CONFIG_PATH_NAME}" 33 | 34 | LOOP_NUM=1 35 | for ((i=0;i<${LOOP_NUM};i++)) 36 | do 37 | for MODEL_NAME in "${MODEL_NAMES[@]}" 38 | do 39 | # look up the "model base key" entry for this model 40 | for entry in "${GENERATE_BASE_AND_KEY[@]}"; do 41 | if [[ ${entry} == $MODEL_NAME* ]]; then 42 | # read the API URL and key 43 | read -ra ADDR <<< "$entry" 44 | API_BASE=${ADDR[1]} 45 | API_KEY=${ADDR[2]} 46 | break 47 | fi 48 | done 49 | for QUERY_DATA_NAME in "${DATA_NAMES[@]}" 50 | do 51 | for QUESTION_FILE_NAME in "${QUESTION_FILE_NAMES[@]}" 52 | do 53 | OUTPUT_DIR="${TOTAL_OUTPUT_DIR}/${QUERY_DATA_NAME}" 54 | mkdir -p "${OUTPUT_DIR}" 55 | QUESTION_FILE_PATH="${QUESTION_FILE_PATH_TOTAL}/${QUERY_DATA_NAME}${QUESTION_FILE_NAME}" 56 | echo "rewrite config file for ${MODEL_NAME} on ${QUERY_DATA_NAME}..." 57 | CONFIG_PATH="${LOOP_CONFIG_PATH_NAME}/${MODEL_NAME}_${QUERY_DATA_NAME}${QUESTION_FILE_NAME}_generate_context_ref_num_${CONTEXT_REF_NUM}.json" 58 | LOG_DIR="${TOTAL_LOG_DIR}/${MODEL_NAME}_${QUERY_DATA_NAME}${QUESTION_FILE_NAME}_generate_context_ref_num_${CONTEXT_REF_NUM}.log" 59 | GENERATE_OUTPUT_NAME="${OUTPUT_DIR}/${MODEL_NAME}_${QUERY_DATA_NAME}${QUESTION_FILE_NAME}_generate_context_ref_num_${CONTEXT_REF_NUM}.json" 60 | python ../rewrite_configs.py --method "${MODEL_NAME}" \ 61 | --data_name "${QUERY_DATA_NAME}" \ 62 | --stage "zero_update_generate" \ 63 | --output_dir "${CONFIG_PATH}" \ 64 | --overrides '{"question_file_path": "'"${QUESTION_FILE_PATH}"'", "output_file_path": "'"${GENERATE_OUTPUT_NAME}"'", "context_ref_num": "'"${CONTEXT_REF_NUM}"'", "api-base": "'"${API_BASE}"'", "api-key": "'"${API_KEY}"'"}' 65 | wait 66 | echo "Running generate for ${MODEL_NAME} on ${QUERY_DATA_NAME}..." 
67 | python get_response_llm.py --config_file_path "${CONFIG_PATH}" > "${LOG_DIR}" 2>&1 & 68 | wait 69 | done 70 | 71 | done 72 | done 73 | done -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/baichuan2-13b-chat-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-13b-chat", 3 | "question_file_path": "../../data_v2/ret_output/DPR/nq-test-h10-bge-large.json", 4 | "output_file_path": "../../data_v2/incxt_gen_data/DPR/nq-test-h10-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8222/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/baichuan2-13b-chat-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-13b-chat", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/pop-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8222/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/baichuan2-13b-chat-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-13b-chat", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/tqa-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8222/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/baichuan2-13b-chat-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-13b-chat", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/webq-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8222/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/baichuan2-7b-chat-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-7b-chat", 3 | "question_file_path": "../../data_v2/ret_output/DPR/nq-test-h10-bge-large.json", 4 | "output_file_path": "../../data_v2/incxt_gen_data/DPR/nq-test-h10-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8222/v1", 10 
| "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/baichuan2-7b-chat-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-7b-chat", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/pop-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8222/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/baichuan2-7b-chat-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-7b-chat", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/tqa-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8222/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/baichuan2-7b-chat-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-7b-chat", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/webq-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8222/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/gpt-3.5-turbo-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "gpt-3.5-turbo", 3 | "question_file_path": "../../data_v2/ret_output/DPR/nq-test-h10-bge-large.json", 4 | "output_file_path": "../../data_v2/incxt_gen_data/DPR/nq-test-h10-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/gpt-3.5-turbo-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "gpt-3.5-turbo", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/pop-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/gpt-3.5-turbo-config-tqa.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "model_name": "gpt-3.5-turbo", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/tqa-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/gpt-3.5-turbo-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "gpt-3.5-turbo", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/webq-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/llama2-13b-chat-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama2-13b", 3 | "question_file_path": "../../data_v2/ret_output/DPR/nq-test-h10-bge-large.json", 4 | "output_file_path": "../../data_v2/incxt_gen_data/DPR/nq-test-h10-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8223/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/llama2-13b-chat-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama2-13b", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/pop-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8223/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/llama2-13b-chat-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama2-13b", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/tqa-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8223/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/llama2-13b-chat-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama2-13b", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": 
"../../data_v2/zero_gen_data/DPR/webq-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8223/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/llama2-7b-chat-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama2-7b", 3 | "question_file_path": "../../data_v2/ret_output/DPR/nq-test-h10-bge-large.json", 4 | "output_file_path": "../../data_v2/incxt_gen_data/DPR/nq-test-h10-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8223/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/llama2-7b-chat-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama2-7b", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/pop-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8223/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/llama2-7b-chat-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama2-7b", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/tqa-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8223/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/llama2-7b-chat-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama2-7b", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/webq-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8223/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/qwen-0.5b-chat-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-0.5b", 3 | "question_file_path": "../../data_v2/input_data/DPR/nq-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/nq-test-gen-qwen2-0.5b-chat.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 
| "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/qwen-0.5b-chat-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-0.5b", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/pop-test-gen-qwen2-0.5b-chat.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/qwen-0.5b-chat-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-0.5b", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/tqa-test-gen-qwen2-0.5b-chat.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/qwen-0.5b-chat-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-0.5b", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/webq-test-gen-qwen2-0.5b-chat.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/qwen-1.8b-chat-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-1.8b", 3 | "question_file_path": "../../data_v2/input_data/DPR/nq-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/nq-test-gen-qwen2-1.8b-chat.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/qwen-1.8b-chat-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-1.8b", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/pop-test-gen-qwen2-1.8b-chat.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } 
-------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/qwen-1.8b-chat-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-1.8b", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/tqa-test-gen-qwen2-1.8b-chat.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/qwen-1.8b-chat-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-1.8b", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/webq-test-gen-qwen2-1.8b-chat.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/qwen-14b-chat-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-14b", 3 | "question_file_path": "../../data_v2/input_data/DPR/nq-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/nq-test-gen-qwen2-14b-chat.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/qwen-14b-chat-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-14b", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/pop-test-gen-qwen2-14b-chat.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/qwen-14b-chat-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-14b", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/tqa-test-gen-qwen2-14b-chat.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/qwen-14b-chat-config-webq.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-14b", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/webq-test-gen-qwen2-14b-chat.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/qwen-4b-chat-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-4b", 3 | "question_file_path": "../../data_v2/input_data/DPR/nq-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/nq-test-gen-qwen2-4b-chat.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/qwen-4b-chat-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-4b", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/pop-test-gen-qwen2-4b-chat.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/qwen-4b-chat-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-4b", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/tqa-test-gen-qwen2-4b-chat.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/qwen-4b-chat-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-4b", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/webq-test-gen-qwen2-4b-chat.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/qwen-7b-chat-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-7b", 3 | "question_file_path": "../../data_v2/input_data/DPR/nq-test.jsonl", 4 
| "output_file_path": "../../data_v2/update_zero_gen_data/DPR/nq-test-gen-qwen2-7b-chat.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/qwen-7b-chat-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-7b", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/pop-test-gen-qwen2-7b-chat.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/qwen-7b-chat-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-7b", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/tqa-test-gen-qwen2-7b-chat.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/qwen-7b-chat-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-7b", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/webq-test-gen-qwen2-7b-chat.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/temp/baichuan2-7b-chat-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-7b-chat", 3 | "question_file_path": "../../data_v2/ret_output/DPR/nq-test-h10-bge-large.json", 4 | "output_file_path": "../../data_v2/incxt_gen_data/DPR/nq-test-h10-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8222/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/temp/baichuan2-7b-chat-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-7b-chat", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/pop-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": 
"http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8222/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/temp/baichuan2-7b-chat-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-7b-chat", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/tqa-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8222/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/temp/baichuan2-7b-chat-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-7b-chat", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/webq-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8222/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/temp/llama2-7b-chat-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama2-7b", 3 | "question_file_path": "../../data_v2/ret_output/DPR/nq-test-h10-bge-large.json", 4 | "output_file_path": "../../data_v2/incxt_gen_data/DPR/nq-test-h10-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8223/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/temp/llama2-7b-chat-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama2-7b", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/pop-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8223/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/temp/llama2-7b-chat-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama2-7b", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/tqa-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8223/v1", 10 | 
"api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/rag_configs/temp/llama2-7b-chat-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama2-7b", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/webq-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": true, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 5, 9 | "api-base": "http://124.16.138.150:8223/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/zero-shot_configs/llama2-7b-chat-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama2-7b", 3 | "question_file_path": "../../data_v2/ret_output/DPR/nq-test-h10-bge-large.json", 4 | "output_file_path": "../../data_v2/incxt_gen_data/DPR/nq-test-h10-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": false, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 0, 9 | "api-base": "http://124.16.138.150:8223/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/zero-shot_configs/llama2-7b-chat-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama2-7b", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/pop-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": false, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 0, 9 | "api-base": "http://124.16.138.150:8223/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/zero-shot_configs/llama2-7b-chat-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama2-13b", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/tqa-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": false, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 0, 9 | "api-base": "http://124.16.138.150:8223/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/zero-shot_configs/llama2-7b-chat-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama2-7b", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/webq-test-gen-gpt-3.5-turbo.jsonl", 5 | "with_context": false, 6 | "elasticsearch_url": "http://124.16.138.142:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 0, 9 | "api-base": "http://124.16.138.150:8223/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- 
/src/llm_zero_generate/update_configs/zero-shot_configs/qwen-0.5b-chat-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-0.5b", 3 | "question_file_path": "../../data_v2/input_data/DPR/nq-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/nq-test-gen-qwen2-0.5b-chat.jsonl", 5 | "with_context": false, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 0, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/zero-shot_configs/qwen-0.5b-chat-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-0.5b", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/pop-test-gen-qwen2-0.5b-chat.jsonl", 5 | "with_context": false, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 0, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/zero-shot_configs/qwen-0.5b-chat-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-0.5b", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/tqa-test-gen-qwen2-0.5b-chat.jsonl", 5 | "with_context": false, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 0, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/zero-shot_configs/qwen-0.5b-chat-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-0.5b", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/webq-test-gen-qwen2-0.5b-chat.jsonl", 5 | "with_context": false, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 0, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/zero-shot_configs/qwen-1.8b-chat-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-1.8b", 3 | "question_file_path": "../../data_v2/input_data/DPR/nq-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/nq-test-gen-qwen2-1.8b-chat.jsonl", 5 | "with_context": false, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 0, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/zero-shot_configs/qwen-1.8b-chat-config-pop.json:
-------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-1.8b", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/pop-test-gen-qwen2-1.8b-chat.jsonl", 5 | "with_context": false, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 0, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/zero-shot_configs/qwen-1.8b-chat-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-1.8b", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/tqa-test-gen-qwen2-1.8b-chat.jsonl", 5 | "with_context": false, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 0, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/zero-shot_configs/qwen-1.8b-chat-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-1.8b", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/webq-test-gen-qwen2-1.8b-chat.jsonl", 5 | "with_context": false, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 0, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/zero-shot_configs/qwen-4b-chat-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-4b", 3 | "question_file_path": "../../data_v2/input_data/DPR/nq-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/nq-test-gen-qwen2-4b-chat.jsonl", 5 | "with_context": false, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 0, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/zero-shot_configs/qwen-4b-chat-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-4b", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/pop-test-gen-qwen2-4b-chat.jsonl", 5 | "with_context": false, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 0, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/zero-shot_configs/qwen-4b-chat-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-4b", 3 | "question_file_path": 
"../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/tqa-test-gen-qwen2-4b-chat.jsonl", 5 | "with_context": false, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 0, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/update_configs/zero-shot_configs/qwen-4b-chat-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen2-4b", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/update_zero_gen_data/DPR/webq-test-gen-qwen2-4b-chat.jsonl", 5 | "with_context": false, 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "index_name": "bm25_psgs_index", 8 | "context_ref_num": 0, 9 | "api-base": "http://124.16.138.150:8111/v1", 10 | "api-key": "xxx" 11 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/zero-shot_configs/baichuan2-13b-chat-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-13b-chat", 3 | "question_file_path": "../../data_v2/input_data/DPR/nq-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/nq-test-gen-baichuan2-13b-chat.jsonl" 5 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/zero-shot_configs/baichuan2-13b-chat-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-13b-chat", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/trash/pop-test-gen-baichuan2-13b-chat.jsonl" 5 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/zero-shot_configs/baichuan2-13b-chat-config-trivia.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-13b-chat", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/trash/tqa-test-gen-baichuan2-13b-chat.jsonl" 5 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/zero-shot_configs/baichuan2-13b-chat-config-wq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-13b-chat", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/webq-test-gen-baichuan2-13b-chat.jsonl" 5 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/zero-shot_configs/chatglm3-6b-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "chatglm3", 3 | "question_file_path": "../../data_v2/input_data/DPR/nq-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/nq-test-gen-chatglm3-6b-chat.jsonl" 5 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/zero-shot_configs/chatglm3-6b-config-pop.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "model_name": "chatglm3", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/pop-test-gen-chatglm3-6b-chat.jsonl" 5 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/zero-shot_configs/chatglm3-6b-config-trivia.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "chatglm3", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/tqa-test-gen-chatglm3-6b-chat.jsonl" 5 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/zero-shot_configs/chatglm3-6b-config-wq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "chatglm3", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/webq-test-gen-chatglm3-6b-chat.jsonl" 5 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/zero-shot_configs/gpt-3.5-turbo-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "gpt-3.5-turbo", 3 | "question_file_path": "../../data_v2/input_data/DPR/nq-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/nq-test-gen-gpt-3.5-turbo.jsonl" 5 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/zero-shot_configs/gpt-3.5-turbo-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "gpt-3.5-turbo", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/pop-test-gen-gpt-3.5-turbo.jsonl" 5 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/zero-shot_configs/gpt-3.5-turbo-config-trivia.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "gpt-3.5-turbo", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/tqa-test-gen-gpt-3.5-turbo.jsonl" 5 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/zero-shot_configs/gpt-3.5-turbo-config-wq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "gpt-3.5-turbo", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/webq-test-gen-gpt-3.5-turbo.jsonl" 5 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/zero-shot_configs/llama2-13b-chat-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama", 3 | "question_file_path": "../../data_v2/input_data/DPR/nq-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/nq-test-gen-llama2-13b-chat.jsonl" 5 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/zero-shot_configs/llama2-13b-chat-config-pop.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/trash/pop-test-gen-llama2-13b-chat.jsonl" 5 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/zero-shot_configs/llama2-13b-chat-config-trivia.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/trash/tqa-test-gen-llama2-13b-chat.jsonl" 5 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/zero-shot_configs/llama2-13b-chat-config-wq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/webq-test-gen-llama2-13b-chat.jsonl" 5 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/zero-shot_configs/qwen-14b-chat-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen", 3 | "question_file_path": "../../data_v2/input_data/DPR/nq-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/nq-test-gen-qwen-14b-chat.jsonl" 5 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/zero-shot_configs/qwen-14b-chat-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen", 3 | "question_file_path": "../../data_v2/input_data/DPR/pop-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/pop-test-gen-qwen-14b-chat.jsonl" 5 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/zero-shot_configs/qwen-14b-chat-config-trivia.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen", 3 | "question_file_path": "../../data_v2/input_data/DPR/tqa-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/tqa-test-gen-qwen-14b-chat.jsonl" 5 | } -------------------------------------------------------------------------------- /src/llm_zero_generate/zero-shot_configs/qwen-14b-chat-config-wq.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen", 3 | "question_file_path": "../../data_v2/input_data/DPR/webq-test.jsonl", 4 | "output_file_path": "../../data_v2/zero_gen_data/DPR/webq-test-gen-qwen-14b-chat.jsonl" 5 | } -------------------------------------------------------------------------------- /src/misinfo/eva_mis.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import string 3 | import re 4 | import os 5 | 6 | 7 | def evaluate(predictions): 8 | # evaluate the predictions with exact match 9 | def _normalize_answer(s): 10 | def remove_articles(text): 11 | return re.sub(r"\b(a|an|the)\b", " ", text) 12 | 13 | def white_space_fix(text): 14 | return " ".join(text.split()) 15 | 16 | def remove_punc(text): 17 | exclude = set(string.punctuation) 18 | return "".join(ch for ch in text 
if ch not in exclude) 19 | 20 | def lower(text): 21 | return text.lower() 22 | 23 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 24 | 25 | def exact_match_score(example): 26 | ground_truths = example['answers'] 27 | false_answers = example['false_answer'] 28 | assert type(ground_truths) == list, f'ground_truths is not a list, id:{example["id"]}, ground_truth:{ground_truths}' 29 | prediction = example['response'] 30 | example['exact_match_true'] = 0 31 | example['exact_match_false'] = 0 32 | if not prediction: 33 | print(f'no prediction for qid {example["qid"]}, {example["query"]}') 34 | return example 35 | for ground_truth in ground_truths: 36 | if _normalize_answer(ground_truth) in _normalize_answer(prediction): 37 | example['exact_match_true'] = 1 38 | break 39 | for false_answer in false_answers: 40 | if _normalize_answer(false_answer) in _normalize_answer(prediction): 41 | example['exact_match_false'] = 1 42 | break 43 | return example 44 | 45 | predictions = predictions.map(exact_match_score) 46 | return predictions 47 | 48 | 49 | def main(): 50 | # load the predictions 51 | res_file_path = '../../data_v2/zero_gen_data/DPR/sampled_data' 52 | data_names = ['nq', 'webq', 'tqa', 'pop'] 53 | model_names = ['chatglm3-6b-chat', 'qwen-14b-chat', 'baichuan2-13b-chat', 'llama2-13b-chat', 'gpt-3.5-turbo'] 54 | for data_name in data_names: 55 | for model_name in model_names: 56 | res_path = f'{res_file_path}/{data_name}-test-gen-{model_name}.jsonl' 57 | # check if the file exists 58 | if not os.path.exists(res_path): 59 | continue 60 | predictions = datasets.load_dataset('json', data_files=res_path)['train'] 61 | predictions = evaluate(predictions) 62 | print('-' * 20) 63 | print(f'{data_name}-{model_name}') 64 | # compute the exact match rates; evaluate() stores per-example flags 65 | # in 'exact_match_true' / 'exact_match_false' (no 'exact_match' column exists) 66 | print('EM (true):', sum(predictions['exact_match_true']) / len(predictions['exact_match_true'])) 67 | print('EM (false):', sum(predictions['exact_match_false']) / len(predictions['exact_match_false'])) 68 | print('-' * 20) 69 | 70 | 71 | if __name__ == '__main__': 72 | main() 73 | 74 | -------------------------------------------------------------------------------- /src/misinfo/mis_config/mis_config_qwen.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen", 3 | "question_file_path": "../../data_v2/input_data/DPR/sampled_query/nq-test-sample-10.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_answer/nq-test-gen-qwen.jsonl", 5 | "misinfo_type": "mis_answer", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/mis_config/nq_mis_config_answer_gpt.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "gpt-3.5-turbo", 3 | "question_file_path": "../../data_v2/input_data/DPR/sampled_query/nq-test-sample-200.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_answer/nq-test-gen-gpt-3.5-turbo-answer.jsonl", 5 | "misinfo_type": "mis_answer", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/mis_config/nq_mis_config_passage_baichuan.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-13b-chat", 3 | "question_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/nq-test-gen-gpt-3.5-turbo-passage.jsonl", 4 | "output_file_path":
"../../data_v2/misinfo_data/DPR/mis_passage/nq-test-gen-baichuan-passage.jsonl", 5 | "misinfo_type": "mis_passage", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/mis_config/nq_mis_config_passage_chatglm.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "chatglm3", 3 | "question_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/nq-test-gen-gpt-3.5-turbo-passage.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/nq-test-gen-chatglm-passage.jsonl", 5 | "misinfo_type": "mis_passage", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/mis_config/nq_mis_config_passage_gpt.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "gpt-3.5-turbo", 3 | "question_file_path": "../../data_v2/misinfo_data/DPR/mis_answer/nq-test-gen-gpt-3.5-turbo-answer.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/nq-test-gen-gpt-3.5-turbo-passage.jsonl", 5 | "misinfo_type": "mis_passage", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/mis_config/nq_mis_config_passage_llama.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama", 3 | "question_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/nq-test-gen-gpt-3.5-turbo-passage.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/nq-test-gen-llama-passage.jsonl", 5 | "misinfo_type": "mis_passage", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/mis_config/nq_mis_config_passage_qwen.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen", 3 | "question_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/nq-test-gen-gpt-3.5-turbo-passage.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/nq-test-gen-qwen-passage.jsonl", 5 | "misinfo_type": "mis_passage", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/mis_config/pop_mis_config_answer_gpt.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "gpt-3.5-turbo", 3 | "question_file_path": "../../data_v2/input_data/DPR/sampled_query/pop-test-sample-200.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_answer/pop-test-gen-gpt-3.5-turbo-answer.jsonl", 5 | "misinfo_type": "mis_answer", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/mis_config/pop_mis_config_passage_baichuan.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-13b-chat", 3 | "question_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/pop-test-gen-gpt-3.5-turbo-passage.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/pop-test-gen-baichuan-passage.jsonl", 5 | "misinfo_type": "mis_passage", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- 
/src/misinfo/mis_config/pop_mis_config_passage_chatglm.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "chatglm3", 3 | "question_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/pop-test-gen-gpt-3.5-turbo-passage.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/pop-test-gen-chatglm-passage.jsonl", 5 | "misinfo_type": "mis_passage", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/mis_config/pop_mis_config_passage_gpt.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "gpt-3.5-turbo", 3 | "question_file_path": "../../data_v2/misinfo_data/DPR/mis_answer/pop-test-gen-gpt-3.5-turbo-answer.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/pop-test-gen-gpt-3.5-turbo-passage.jsonl", 5 | "misinfo_type": "mis_passage", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/mis_config/pop_mis_config_passage_llama.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama", 3 | "question_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/pop-test-gen-gpt-3.5-turbo-passage.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/pop-test-gen-llama-passage.jsonl", 5 | "misinfo_type": "mis_passage", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/mis_config/pop_mis_config_passage_qwen.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen", 3 | "question_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/pop-test-gen-gpt-3.5-turbo-passage.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/pop-test-gen-qwen-passage.jsonl", 5 | "misinfo_type": "mis_passage", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/mis_config/tqa_mis_config_answer_gpt.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "gpt-3.5-turbo", 3 | "question_file_path": "../../data_v2/input_data/DPR/sampled_query/tqa-test-sample-200.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_answer/tqa-test-gen-gpt-3.5-turbo-answer.jsonl", 5 | "misinfo_type": "mis_answer", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/mis_config/tqa_mis_config_passage_baichuan.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-13b-chat", 3 | "question_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/tqa-test-gen-gpt-3.5-turbo-passage.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/tqa-test-gen-baichuan-passage.jsonl", 5 | "misinfo_type": "mis_passage", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/mis_config/tqa_mis_config_passage_chatglm.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "chatglm3", 3 | "question_file_path": 
"../../data_v2/misinfo_data/DPR/mis_passage/tqa-test-gen-gpt-3.5-turbo-passage.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/tqa-test-gen-chatglm-passage.jsonl", 5 | "misinfo_type": "mis_passage", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/mis_config/tqa_mis_config_passage_gpt.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "gpt-3.5-turbo", 3 | "question_file_path": "../../data_v2/misinfo_data/DPR/mis_answer/tqa-test-gen-gpt-3.5-turbo-answer.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/tqa-test-gen-gpt-3.5-turbo-passage.jsonl", 5 | "misinfo_type": "mis_passage", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/mis_config/tqa_mis_config_passage_llama.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama", 3 | "question_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/tqa-test-gen-gpt-3.5-turbo-passage.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/tqa-test-gen-llama-passage.jsonl", 5 | "misinfo_type": "mis_passage", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/mis_config/tqa_mis_config_passage_qwen.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen", 3 | "question_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/tqa-test-gen-gpt-3.5-turbo-passage.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/tqa-test-gen-qwen-passage.jsonl", 5 | "misinfo_type": "mis_passage", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/mis_config/webq_mis_config_answer_gpt.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "gpt-3.5-turbo", 3 | "question_file_path": "../../data_v2/input_data/DPR/sampled_query/webq-test-sample-200.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_answer/webq-test-gen-gpt-3.5-turbo-answer.jsonl", 5 | "misinfo_type": "mis_answer", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/mis_config/webq_mis_config_passage_baichuan.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "baichuan2-13b-chat", 3 | "question_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/webq-test-gen-gpt-3.5-turbo-passage.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/webq-test-gen-baichuan-passage.jsonl", 5 | "misinfo_type": "mis_passage", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/mis_config/webq_mis_config_passage_chatglm.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "chatglm3", 3 | "question_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/webq-test-gen-gpt-3.5-turbo-passage.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/webq-test-gen-chatglm-passage.jsonl", 5 | "misinfo_type": 
"mis_passage", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/mis_config/webq_mis_config_passage_gpt.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "gpt-3.5-turbo", 3 | "question_file_path": "../../data_v2/misinfo_data/DPR/mis_answer/webq-test-gen-gpt-3.5-turbo-answer.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/webq-test-gen-gpt-3.5-turbo-passage.jsonl", 5 | "misinfo_type": "mis_passage", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/mis_config/webq_mis_config_passage_llama.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Llama", 3 | "question_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/webq-test-gen-gpt-3.5-turbo-passage.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/webq-test-gen-llama-passage.jsonl", 5 | "misinfo_type": "mis_passage", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/mis_config/webq_mis_config_passage_qwen.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "Qwen", 3 | "question_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/webq-test-gen-gpt-3.5-turbo-passage.jsonl", 4 | "output_file_path": "../../data_v2/misinfo_data/DPR/mis_passage/webq-test-gen-qwen-passage.jsonl", 5 | "misinfo_type": "mis_passage", 6 | "api-base": "", 7 | "api-key": "" 8 | } -------------------------------------------------------------------------------- /src/misinfo/run_gen_misinfo.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | datasets=(nq pop webq tqa) 4 | models=(qwen chatglm llama baichuan) 5 | # first generate false answer using gpt 6 | # shellcheck disable=SC2068 7 | for dataset in ${datasets[@]} 8 | do 9 | 10 | python gen_misinfo_llm.py --config_file_path mis_config/${dataset}_mis_config_answer_gpt.json 11 | 12 | done 13 | 14 | wait 15 | 16 | #generate misinformation passages using models 17 | # shellcheck disable=SC2068 18 | for dataset in ${datasets[@]} 19 | do 20 | # shellcheck disable=SC2068 21 | # shellcheck disable=SC2034 22 | for model in ${models[@]} 23 | do 24 | python gen_misinfo_llm.py --config_file_path mis_config/${dataset}_mis_config_passage_${model}.json > logs/${dataset}_misinfo_passage_${model}.log 2>&1 & 25 | done 26 | wait 27 | done -------------------------------------------------------------------------------- /src/post_process/delete_configs/delete_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "id_files": ["../run_logs/mis_filter_bleu_nq_webq_pop_tqa_loop_log_contriever_None_total_loop_10_20240206164013/index_add_logs"], 3 | "model_name": "faiss", 4 | "index_path": "../../data_v2/indexes", 5 | "index_name": "contriever_faiss_index", 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "delete_log_path": "../run_logs/mis_filter_bleu_nq_webq_pop_tqa_loop_log_contriever_None_total_loop_10_20240206164013/index_add_logs" 8 | } -------------------------------------------------------------------------------- /src/post_process/delete_configs/delete_config_bm25.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "id_files": ["../run_logs/zero-shot_retrieval_log_low/index_add_logs"], 3 | "model_name": "bm25", 4 | "index_path": "../../data_v2/indexes", 5 | "index_name": "bm25_psgs_index", 6 | "elasticsearch_url": "http://124.16.138.150:9978", 7 | "delete_log_path": "../run_logs/zero-shot_retrieval_log_low/index_add_logs" 8 | } -------------------------------------------------------------------------------- /src/post_process/delete_doc_from_index.py: -------------------------------------------------------------------------------- 1 | # delete docs from an index according to their doc ids 2 | 3 | import sys 4 | import os 5 | import json 6 | import argparse 7 | import time 8 | 9 | sys.path.append('../retrieval_loop') 10 | 11 | from elastic_bm25_search_with_metadata import ElasticSearchBM25Retriever 12 | from faiss_search import Batch_FAISS 13 | 14 | 15 | def read_ids_from_file(file_path): 16 | ids = [] 17 | with open(file_path) as f: 18 | for line in f: 19 | ids.append(line.strip()) 20 | return ids 21 | 22 | 23 | 24 | 25 | def get_args(): 26 | parser = argparse.ArgumentParser() 27 | # a list of docid files 28 | parser.add_argument('--id_files', type=str, nargs='+', default=[]) 29 | # model name 30 | parser.add_argument('--model_name', type=str, default='bm25') 31 | # index path 32 | parser.add_argument('--index_path', type=str, default='') 33 | # index name 34 | parser.add_argument('--index_name', type=str, default='') 35 | # elasticsearch url 36 | parser.add_argument('--elasticsearch_url', type=str, default='http://localhost:9200') 37 | # config file path 38 | parser.add_argument('--config_file_path', type=str, default='../config/delete_doc_from_index.json') 39 | # delete log file path 40 | parser.add_argument('--delete_log_path', type=str, default='../logs') 41 | 42 | args = parser.parse_args() 43 | # read the JSON configuration file, e.g. 44 | # { 45 | # "id_files": ["../data/ids.txt"], 46 | # "model_name": "bm25", 47 | # "index_path": "", 48 | # "index_name": "index", 49 | # "elasticsearch_url": "http://localhost:9200" 50 | # } 51 | 52 | config = read_config_from_json(args.config_file_path) 53 | 54 | # override command-line arguments with values from the config file 55 | args = override_args_by_config(args, config) 56 | 57 | print(f'args: {args}') 58 | 59 | return args 60 | 61 | 62 | # read a JSON configuration file into a dict 63 | def read_config_from_json(json_file_path): 64 | try: 65 | with open(json_file_path, 'r') as json_file: 66 | args_dict = json.load(json_file) 67 | return args_dict 68 | except FileNotFoundError: 69 | print(f"Configuration file {json_file_path} not found.") 70 | return {} 71 | 72 | 73 | # override command-line arguments with config values 74 | def override_args_by_config(args, config): 75 | for key, value in config.items(): 76 | if hasattr(args, key): 77 | setattr(args, key, value) 78 | return args 79 | 80 | 81 | def delete_docs_from_index(id_files, model_name, index_path, index_name, elasticsearch_url, delete_log_path): 82 | delete_log_file = os.path.join(delete_log_path, f'delete_log_{model_name}_{index_name}_{time.strftime("%Y%m%d-%H%M%S")}.log') 83 | index_p = os.path.join(index_path, index_name) 84 | if model_name == 'bm25': 85 | index = ElasticSearchBM25Retriever.create(elasticsearch_url, index_name) 86 | index_size_before = index.get_document_count() 87 | elif model_name == 'faiss': 88 | index = Batch_FAISS.load_local(index_p, None) 89 | index_size_before = len(index.docstore._dict) 90 | 91 | else: 92 | raise ValueError(f'Invalid model name: {model_name}') 93 | 94 | with open(delete_log_file, 'w') as f: 95 | 
f.write(f'index size before: {index_size_before}\n') 96 | # whether id_files are directories or files 97 | delete_id_files = [] 98 | for id_file in id_files: 99 | if os.path.isdir(id_file): 100 | for file in os.listdir(id_file): 101 | if file.startswith(index_name): 102 | delete_id_files.append(os.path.join(id_file, file)) 103 | else: 104 | delete_id_files.append(id_file) 105 | 106 | for id_file in delete_id_files: 107 | if model_name not in id_file: 108 | continue 109 | ids = read_ids_from_file(id_file) 110 | 111 | if model_name == 'bm25': 112 | index.delete_documents_by_id(ids) 113 | index_size_after = index.get_document_count() 114 | elif model_name == 'faiss': 115 | try: 116 | index.delete(ids) 117 | except Exception as e: 118 | print(e) 119 | pass 120 | index_size_after = len(index.docstore._dict) 121 | else: 122 | raise ValueError(f'Invalid model name: {model_name}') 123 | # write ids to delete log file 124 | f.write(f'delete ids from {id_file}: {ids}\n') 125 | print(f'index size before: {index_size_before}, delete ids from {id_file}, index size after: {index_size_after}') 126 | f.write(f'index size after: {index_size_after}\n') 127 | if model_name == 'faiss': 128 | index.save_local(index_p) 129 | 130 | 131 | if __name__ == '__main__': 132 | args = get_args() 133 | delete_docs_from_index(args.id_files, args.model_name, args.index_path, args.index_name, args.elasticsearch_url, args.delete_log_path) 134 | 135 | 136 | -------------------------------------------------------------------------------- /src/post_process/post_process.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | MODEL_NAMES=(qwen-0.5b-chat qwen-1.8b-chat qwen-4b-chat) 4 | QUERY_DATA_NAMES=(nq pop tqa webq) 5 | LOOP_NUM=0 6 | LOOP_CONFIG_PATH_NAME="../run_configs/update_configs" 7 | FROM_METHOD="update" 8 | 9 | TOTAL_LOG_DIR="../run_logs/update_log" 10 | TOTAL_OUTPUT_DIR="../../data_v2/update_data/DPR/update_passage_processed" 11 | mkdir -p "${TOTAL_LOG_DIR}" 12 | mkdir -p "${TOTAL_OUTPUT_DIR}" 13 | mkdir -p "${LOOP_CONFIG_PATH_NAME}" 14 | 15 | INPUT_FILE_PATH="../../data_v2/loop_output/DPR/zero_update_retrieval_result" 16 | 17 | for MODEL_NAME in "${MODEL_NAMES[@]}" 18 | do 19 | for QUERY_DATA_NAME in "${QUERY_DATA_NAMES[@]}" 20 | do 21 | # OUTPUT_DIR="${TOTAL_OUTPUT_DIR}/${QUERY_DATA_NAME}" 22 | # mkdir -p "${OUTPUT_DIR}" 23 | echo "rewrite config file for ${MODEL_NAME} on ${QUERY_DATA_NAME}..." 24 | INPUT_FILE_NAME="${INPUT_FILE_PATH}/${QUERY_DATA_NAME}/${MODEL_NAME}_${QUERY_DATA_NAME}-test-sample-200.jsonl_generate_context_ref_num_0.json" 25 | CONFIG_PATH="${LOOP_CONFIG_PATH_NAME}/${MODEL_NAME}_${QUERY_DATA_NAME}_postprocess_loop_${LOOP_NUM}.json" 26 | LOG_DIR="${TOTAL_LOG_DIR}/${MODEL_NAME}_${QUERY_DATA_NAME}_postprocess_loop_${LOOP_NUM}.log" 27 | POSTPROCESS_OUTPUT_NAME="${TOTAL_OUTPUT_DIR}/${QUERY_DATA_NAME}-test-gen-${MODEL_NAME}-postprocessed.jsonl" 28 | python ../rewrite_configs.py --method "${MODEL_NAME}" \ 29 | --data_name "${QUERY_DATA_NAME}" \ 30 | --loop "${LOOP_NUM}" \ 31 | --stage "post_process" \ 32 | --output_dir "${CONFIG_PATH}" \ 33 | --overrides '{"loop_num": "'"${LOOP_NUM}"'", "gen_model_name": "'"${MODEL_NAME}"'", "input_file": "'"${INPUT_FILE_NAME}"'", "output_dir": "'"${POSTPROCESS_OUTPUT_NAME}"'", "from_method": "'"${FROM_METHOD}"'"}' 34 | wait 35 | echo "Running postprocess for ${MODEL_NAME} on ${QUERY_DATA_NAME}..." 
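# (added note) each postprocess job below runs in the background; its stdout/stderr is captured in ${LOG_DIR}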
36 | python process_llm_text.py --config_file_path "$CONFIG_PATH" > "$LOG_DIR" 2>&1 & 37 | 38 | done 39 | done 40 | -------------------------------------------------------------------------------- /src/post_process/process_configs/template_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "loop_num": 0, 3 | "gen_model_name": "llama2", 4 | "query_set_name": "nq", 5 | "output_dir": "../../data_v2/post_processed_data/DPR/nq-test-gen-llama2-13b-chat.jsonl", 6 | "input_file": "../../data_v2/zero_gen_data/DPR/nq-test-gen-llama2-13b-chat.jsonl", 7 | "from_method": "zero-shot" 8 | } -------------------------------------------------------------------------------- /src/post_process/run_index_doc_delete.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Delete a document from an index 4 | 5 | python delete_doc_from_index.py --config_file_path delete_configs/delete_config_bm25.json 6 | -------------------------------------------------------------------------------- /src/rerank_loop/rankgpt_prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | A dedicated helper to manage templates and prompt building. 3 | """ 4 | 5 | import json 6 | import os.path as osp 7 | from typing import Union 8 | 9 | 10 | class Prompter(object): 11 | __slots__ = ("template", "_verbose") 12 | 13 | def __init__(self, template_name: str = "", verbose: bool = False): 14 | self._verbose = verbose 15 | if not template_name: 16 | # Enforce the default here, so the constructor can be called with '' and will not break. 17 | template_name = "LLMreranker" 18 | file_name = osp.join("templates", f"{template_name}.json") 19 | if not osp.exists(file_name): 20 | raise ValueError(f"Can't read {file_name}") 21 | with open(file_name) as fp: 22 | self.template = json.load(fp) 23 | if self._verbose: 24 | print( 25 | f"Using prompt template {template_name}: {self.template['description']}" 26 | ) 27 | 28 | def generate_prompt( 29 | self, 30 | instruction: str, 31 | input: Union[None, str] = None, 32 | label: Union[None, str] = None, 33 | num: Union[None, str] = None, 34 | feedback_documents: Union[None, str] = None, 35 | ) -> str: 36 | # returns the full prompt from instruction and optional input 37 | # if a label (=response, =output) is provided, it's also appended. 
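# (added note) template selection: when pseudo-relevance feedback documents are supplied,
# format the PRF-style prompt; with a candidate list, fall back to the plain ranking
# prompt (with or without {num}); with no input at all, use prompt_no_input.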
38 | if feedback_documents: 39 | res = self.template["prompt_input"].format( 40 | num=num, 41 | instruction=instruction, 42 | input=input, 43 | feedback_documents=feedback_documents, 44 | ) 45 | elif input and num: 46 | res = self.template["prompt_input"].format( 47 | num=num, instruction=instruction, input=input 48 | ) 49 | elif input: 50 | res = self.template["prompt_input"].format( 51 | instruction=instruction, input=input 52 | ) 53 | else: 54 | res = self.template["prompt_no_input"].format( 55 | instruction=instruction 56 | ) 57 | if label: 58 | res = f"{res}{label}" 59 | if self._verbose: 60 | print(res) 61 | return res 62 | 63 | def get_response(self, output: str) -> str: 64 | return output.split(self.template["response_split"])[1].strip() 65 | -------------------------------------------------------------------------------- /src/rerank_loop/rerank_configs/bge-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "bge", 3 | "num_workers": 2, 4 | "log_interval": 1, 5 | "topk_passages": 100, 6 | "hf_model_name": "../../ret_model/bge-reranker-large", 7 | "use_gpu": true, 8 | "report_topk_accuracies": [1, 5, 20, 100], 9 | "merge_shards_and_save": true, 10 | "special_suffix": "nq-bge-test-plm-BGE_reranker_large-topk-100-bm25", 11 | "index_name": "bm25_psgs_index", 12 | "elasticsearch_url": "http://124.16.138.142:9978", 13 | "retriever_topk_passages_path": "../../data_v2/ret_output/DPR/nq-test-bm25-async.json", 14 | "reranker_output_dir": "../../data_v2/ret_output/DPR" 15 | } -------------------------------------------------------------------------------- /src/rerank_loop/rerank_configs/bge-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "bge", 3 | "num_workers": 2, 4 | "log_interval": 1, 5 | "topk_passages": 100, 6 | "hf_model_name": "../../ret_model/bge-reranker-large", 7 | "use_gpu": true, 8 | "report_topk_accuracies": [1, 5, 20, 100], 9 | "merge_shards_and_save": true, 10 | "special_suffix": "pop-bge-test-plm-BGE_reranker_large-topk-100-bm25", 11 | "index_name": "bm25_psgs_index", 12 | "elasticsearch_url": "http://124.16.138.142:9978", 13 | "retriever_topk_passages_path": "../../data_v2/ret_output/DPR/pop-test-bm25-async.json", 14 | "reranker_output_dir": "../../data_v2/ret_output/DPR" 15 | } -------------------------------------------------------------------------------- /src/rerank_loop/rerank_configs/bge-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "bge", 3 | "num_workers": 2, 4 | "log_interval": 1, 5 | "topk_passages": 100, 6 | "hf_model_name": "../../ret_model/bge-reranker-large", 7 | "use_gpu": true, 8 | "report_topk_accuracies": [1, 5, 20, 100], 9 | "merge_shards_and_save": true, 10 | "special_suffix": "tqa-bge-test-plm-BGE_reranker_large-topk-100-bm25", 11 | "index_name": "bm25_psgs_index", 12 | "elasticsearch_url": "http://124.16.138.142:9978", 13 | "retriever_topk_passages_path": "../../data_v2/ret_output/DPR/tqa-test-bm25-async.json", 14 | "reranker_output_dir": "../../data_v2/ret_output/DPR" 15 | } -------------------------------------------------------------------------------- /src/rerank_loop/rerank_configs/bge-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "bge", 3 | "num_workers": 2, 4 | "log_interval": 1, 5 | "topk_passages": 100, 6 | "hf_model_name": "../../ret_model/bge-reranker-large", 7 | 
"use_gpu": true, 8 | "report_topk_accuracies": [1, 5, 20, 100], 9 | "merge_shards_and_save": true, 10 | "special_suffix": "webq-bge-test-plm-BGE_reranker_large-topk-100-bm25", 11 | "index_name": "bm25_psgs_index", 12 | "elasticsearch_url": "http://124.16.138.142:9978", 13 | "retriever_topk_passages_path": "../../data_v2/ret_output/DPR/webq-test-bm25-async.json", 14 | "reranker_output_dir": "../../data_v2/ret_output/DPR" 15 | } -------------------------------------------------------------------------------- /src/rerank_loop/rerank_configs/monot5-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "monot5", 3 | "num_workers": 2, 4 | "log_interval": 1, 5 | "topk_passages": 100, 6 | "hf_model_name": "../../ret_model/monot5-3b-msmarco-10k", 7 | "use_gpu": true, 8 | "use_bf16": true, 9 | "report_topk_accuracies": [1, 5, 20, 100], 10 | "merge_shards_and_save": true, 11 | "special_suffix": "nq-monot5-test-plm-MonoT5_3B-topk-100-bm25", 12 | "index_name": "bm25_psgs_index", 13 | "elasticsearch_url": "http://124.16.138.142:9978", 14 | "retriever_topk_passages_path": "../../data_v2/ret_output/DPR/nq-test-bm25-async.json", 15 | "reranker_output_dir": "../../data_v2/ret_output/DPR" 16 | } -------------------------------------------------------------------------------- /src/rerank_loop/rerank_configs/monot5-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "monot5", 3 | "num_workers": 2, 4 | "log_interval": 1, 5 | "topk_passages": 100, 6 | "hf_model_name": "../../ret_model/monot5-3b-msmarco-10k", 7 | "use_gpu": true, 8 | "use_bf16": true, 9 | "report_topk_accuracies": [1, 5, 20, 100], 10 | "merge_shards_and_save": true, 11 | "special_suffix": "pop-monot5-test-plm-MonoT5_3B-topk-100-bm25", 12 | "index_name": "bm25_psgs_index", 13 | "elasticsearch_url": "http://124.16.138.142:9978", 14 | "retriever_topk_passages_path": "../../data_v2/ret_output/DPR/pop-test-bm25-async.json", 15 | "reranker_output_dir": "../../data_v2/ret_output/DPR" 16 | } -------------------------------------------------------------------------------- /src/rerank_loop/rerank_configs/monot5-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "monot5", 3 | "num_workers": 2, 4 | "log_interval": 1, 5 | "topk_passages": 100, 6 | "hf_model_name": "../../ret_model/monot5-3b-msmarco-10k", 7 | "use_gpu": true, 8 | "use_bf16": true, 9 | "report_topk_accuracies": [1, 5, 20, 100], 10 | "merge_shards_and_save": true, 11 | "special_suffix": "tqa-monot5-test-plm-MonoT5_3B-topk-100-bm25", 12 | "index_name": "bm25_psgs_index", 13 | "elasticsearch_url": "http://124.16.138.142:9978", 14 | "retriever_topk_passages_path": "../../data_v2/ret_output/DPR/tqa-test-bm25-async.json", 15 | "reranker_output_dir": "../../data_v2/ret_output/DPR" 16 | } -------------------------------------------------------------------------------- /src/rerank_loop/rerank_configs/monot5-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "monot5", 3 | "num_workers": 2, 4 | "log_interval": 1, 5 | "topk_passages": 100, 6 | "hf_model_name": "../../ret_model/monot5-3b-msmarco-10k", 7 | "use_gpu": true, 8 | "use_bf16": true, 9 | "report_topk_accuracies": [1, 5, 20, 100], 10 | "merge_shards_and_save": true, 11 | "special_suffix": "webq-monot5-test-plm-MonoT5_3B-topk-100-bm25", 12 | "index_name": "bm25_psgs_index", 13 | 
"elasticsearch_url": "http://124.16.138.142:9978", 14 | "retriever_topk_passages_path": "../../data_v2/ret_output/DPR/webq-test-bm25-async.json", 15 | "reranker_output_dir": "../../data_v2/ret_output/DPR" 16 | } -------------------------------------------------------------------------------- /src/rerank_loop/rerank_configs/rankgpt-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "rankgpt", 3 | "num_workers": 2, 4 | "log_interval": 1, 5 | "topk_passages": 100, 6 | "rankgpt_llm_model_name": "gpt-3.5-turbo", 7 | "rankgpt_prompter_name": "LLMreranker", 8 | "rankgpt_rank_end": 30, 9 | "rankgpt_window_size": 20, 10 | "rankgpt_step_size": 10, 11 | "report_topk_accuracies": [1, 5, 20, 100], 12 | "merge_shards_and_save": true, 13 | "special_suffix": "nq-rankgpt-test-plm-gpt_3.5_turbo-topk-100-h10-bge-large", 14 | "index_name": "bm25_psgs_index", 15 | "elasticsearch_url": "http://124.16.138.142:9978", 16 | "retriever_topk_passages_path": "../../data_v2/ret_output/DPR/nq-test-h10-bge-large.json", 17 | "reranker_output_dir": "../../data_v2/ret_output/DPR", 18 | "rankgpt_api_key": "", 19 | "rankgpt_api_base": "" 20 | } -------------------------------------------------------------------------------- /src/rerank_loop/rerank_configs/rankgpt-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "rankgpt", 3 | "num_workers": 2, 4 | "log_interval": 1, 5 | "topk_passages": 100, 6 | "rankgpt_llm_model_name": "gpt-3.5-turbo", 7 | "rankgpt_prompter_name": "LLMreranker", 8 | "rankgpt_rank_end": 30, 9 | "rankgpt_window_size": 20, 10 | "rankgpt_step_size": 10, 11 | "report_topk_accuracies": [1, 5, 20, 100], 12 | "merge_shards_and_save": true, 13 | "special_suffix": "pop-rankgpt-test-plm-gpt_3.5_turbo-topk-100-h10-bge-large", 14 | "index_name": "bm25_psgs_index", 15 | "elasticsearch_url": "http://124.16.138.142:9978", 16 | "retriever_topk_passages_path": "../../data_v2/ret_output/DPR/pop-test-h10-bge-large.json", 17 | "reranker_output_dir": "../../data_v2/ret_output/DPR", 18 | "rankgpt_api_key": "", 19 | "rankgpt_api_base": "" 20 | } -------------------------------------------------------------------------------- /src/rerank_loop/rerank_configs/rankgpt-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "rankgpt", 3 | "num_workers": 2, 4 | "log_interval": 1, 5 | "topk_passages": 100, 6 | "rankgpt_llm_model_name": "gpt-3.5-turbo", 7 | "rankgpt_prompter_name": "LLMreranker", 8 | "rankgpt_rank_end": 30, 9 | "rankgpt_window_size": 20, 10 | "rankgpt_step_size": 10, 11 | "report_topk_accuracies": [1, 5, 20, 100], 12 | "merge_shards_and_save": true, 13 | "special_suffix": "tqa-rankgpt-test-plm-gpt_3.5_turbo-topk-100-h10-bge-large", 14 | "index_name": "bm25_psgs_index", 15 | "elasticsearch_url": "http://124.16.138.142:9978", 16 | "retriever_topk_passages_path": "../../data_v2/ret_output/DPR/tqa-test-h10-bge-large.json", 17 | "reranker_output_dir": "../../data_v2/ret_output/DPR", 18 | "rankgpt_api_key": "", 19 | "rankgpt_api_base": "" 20 | } -------------------------------------------------------------------------------- /src/rerank_loop/rerank_configs/rankgpt-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "method": "rankgpt", 3 | "num_workers": 2, 4 | "log_interval": 1, 5 | "topk_passages": 100, 6 | "rankgpt_llm_model_name": "gpt-3.5-turbo", 7 | 
"rankgpt_prompter_name": "LLMreranker", 8 | "rankgpt_rank_end": 30, 9 | "rankgpt_window_size": 20, 10 | "rankgpt_step_size": 10, 11 | "report_topk_accuracies": [1, 5, 20, 100], 12 | "merge_shards_and_save": true, 13 | "special_suffix": "webq-rankgpt-test-plm-gpt_3.5_turbo-topk-100-h10-bge-large", 14 | "index_name": "bm25_psgs_index", 15 | "elasticsearch_url": "http://124.16.138.142:9978", 16 | "retriever_topk_passages_path": "../../data_v2/ret_output/DPR/webq-test-h10-bge-large.json", 17 | "reranker_output_dir": "../../data_v2/ret_output/DPR", 18 | "rankgpt_api_key": "", 19 | "rankgpt_api_base": "" 20 | } -------------------------------------------------------------------------------- /src/rerank_loop/rerank_configs/upr-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_workers": 2, 3 | "log_interval": 1, 4 | "topk_passages": 100, 5 | "shard_size": 32, 6 | "hf_model_name": "../../ret_model/T0_3B", 7 | "use_gpu": true, 8 | "use_bf16": true, 9 | "report_topk_accuracies": [1, 5, 20, 100], 10 | "merge_shards_and_save": true, 11 | "special_suffix": "nq-upr-test-plm-T0_3B-topk-100-bm25", 12 | "index_name": "bm25_psgs_index", 13 | "elasticsearch_url": "http://124.16.138.142:9978", 14 | "retriever_topk_passages_path": "../../data_v2/ret_output/DPR/nq-test-bm25-async.json", 15 | "reranker_output_dir": "../../data_v2/ret_output/DPR" 16 | } -------------------------------------------------------------------------------- /src/rerank_loop/rerank_configs/upr-config-pop.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_workers": 2, 3 | "log_interval": 1, 4 | "topk_passages": 100, 5 | "shard_size": 32, 6 | "hf_model_name": "../../ret_model/T0_3B", 7 | "use_gpu": true, 8 | "use_bf16": true, 9 | "report_topk_accuracies": [1, 5, 20, 100], 10 | "merge_shards_and_save": true, 11 | "special_suffix": "pop-upr-test-plm-T0_3B-topk-100-bm25", 12 | "index_name": "bm25_psgs_index", 13 | "elasticsearch_url": "http://124.16.138.142:9978", 14 | "retriever_topk_passages_path": "../../data_v2/ret_output/DPR/pop-test-bm25-async.json", 15 | "reranker_output_dir": "../../data_v2/ret_output/DPR" 16 | } -------------------------------------------------------------------------------- /src/rerank_loop/rerank_configs/upr-config-tqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_workers": 2, 3 | "log_interval": 1, 4 | "topk_passages": 100, 5 | "shard_size": 32, 6 | "hf_model_name": "../../ret_model/T0_3B", 7 | "use_gpu": true, 8 | "use_bf16": true, 9 | "report_topk_accuracies": [1, 5, 20, 100], 10 | "merge_shards_and_save": true, 11 | "special_suffix": "tqa-upr-test-plm-T0_3B-topk-100-bm25", 12 | "index_name": "bm25_psgs_index", 13 | "elasticsearch_url": "http://124.16.138.142:9978", 14 | "retriever_topk_passages_path": "../../data_v2/ret_output/DPR/tqa-test-bm25-async.json", 15 | "reranker_output_dir": "../../data_v2/ret_output/DPR" 16 | } -------------------------------------------------------------------------------- /src/rerank_loop/rerank_configs/upr-config-webq.json: -------------------------------------------------------------------------------- 1 | { 2 | "num_workers": 2, 3 | "log_interval": 1, 4 | "topk_passages": 100, 5 | "shard_size": 32, 6 | "hf_model_name": "../../ret_model/T0_3B", 7 | "use_gpu": true, 8 | "use_bf16": true, 9 | "report_topk_accuracies": [1, 5, 20, 100], 10 | "merge_shards_and_save": true, 11 | "special_suffix": 
"webq-upr-test-plm-T0_3B-topk-100-bm25", 12 | "index_name": "bm25_psgs_index", 13 | "elasticsearch_url": "http://124.16.138.142:9978", 14 | "retriever_topk_passages_path": "../../data_v2/ret_output/DPR/webq-test-bm25-async.json", 15 | "reranker_output_dir": "../../data_v2/ret_output/DPR" 16 | } -------------------------------------------------------------------------------- /src/rerank_loop/run_upr.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # 指定GPU设备 4 | export CUDA_VISIBLE_DEVICES=0 5 | 6 | # 定义世界大小(分布式训练的进程数) 7 | export WORLD_SIZE=1 8 | 9 | RERANK_MODEL_NAMES=(bge monot5) #monot5 upr bge rankgpt 10 | QUERY_DATA_NAMES=(nq) #nq webq tqa pop 11 | RETRIEVAL_MODEL_NAMES=(bm25 all-mpnet bge-base bge-large contriever dpr retromae llm-embedder) 12 | PORT=13423 13 | LOOP_CONFIG_PATH_NAME="../run_configs/original_retrieval_config" 14 | 15 | TOTAL_LOG_DIR="../run_logs/original_retrieval_log" 16 | TOTAL_OUTPUT_DIR="../../data_v2/loop_output/DPR/original_retrieval_result" 17 | mkdir -p "${TOTAL_LOG_DIR}" 18 | mkdir -p "${TOTAL_OUTPUT_DIR}" 19 | 20 | 21 | for RERANK_MODEL_NAME in "${RERANK_MODEL_NAMES[@]}"; do 22 | for QUERY_DATA_NAME in "${QUERY_DATA_NAMES[@]}"; do 23 | for RETRIEVAL_MODEL_NAME in "${RETRIEVAL_MODEL_NAMES[@]}"; do 24 | RETRIEVAL_OUTPUT_NAME="${QUERY_DATA_NAME}-test-${RETRIEVAL_MODEL_NAME}" 25 | RETRIEVAL_OUTPUT_PATH="${TOTAL_OUTPUT_DIR}/${QUERY_DATA_NAME}/${RETRIEVAL_OUTPUT_NAME}" 26 | OUTPUT_DIR="${TOTAL_OUTPUT_DIR}/${QUERY_DATA_NAME}" 27 | mkdir -p "${OUTPUT_DIR}" 28 | CONFIG_PATH="${LOOP_CONFIG_PATH_NAME}/${QUERY_DATA_NAME}_${RETRIEVAL_MODEL_NAME}_${RERANK_MODEL_NAME}_rerank.json" 29 | LOG_DIR="${TOTAL_LOG_DIR}/${QUERY_DATA_NAME}_${RETRIEVAL_MODEL_NAME}_${RERANK_MODEL_NAME}_rerank.log" 30 | RERANK_OUTPUT_NAME="${QUERY_DATA_NAME}-${RERANK_MODEL_NAME}_rerank_based_on_${RETRIEVAL_MODEL_NAME}" 31 | python ../rewrite_configs.py --method "${RERANK_MODEL_NAME}" \ 32 | --data_name "${QUERY_DATA_NAME}" \ 33 | --stage "rerank" \ 34 | --output_dir "${CONFIG_PATH}" \ 35 | --overrides '{"retriever_topk_passages_path": "'"${RETRIEVAL_OUTPUT_PATH}"'", "special_suffix": "'"${RERANK_OUTPUT_NAME}"'", "reranker_output_dir": "'"${OUTPUT_DIR}"'"}' 36 | 37 | wait 38 | echo "Running rerank for ${RERANK_MODEL_NAME} on ${QUERY_DATA_NAME}..." 39 | # 运行分布式Python脚本 40 | torchrun --nproc_per_node ${WORLD_SIZE} \ 41 | --nnodes 1 \ 42 | --node_rank 0 \ 43 | --master_addr localhost \ 44 | --master_port $PORT \ 45 | rerank_for_loop.py \ 46 | --config "$CONFIG_PATH" > "$LOG_DIR" 2>&1 & 47 | PORT=$((PORT+1)) 48 | # 等待所有进程结束 49 | wait 50 | 51 | done 52 | done 53 | done 54 | 55 | -------------------------------------------------------------------------------- /src/rerank_loop/templates/LLMreranker.json: -------------------------------------------------------------------------------- 1 | { 2 | "description": "Template used by LLMreranker.", 3 | "prompt_input": "Rank the {num} passages based on their relevance to the search query. The passages will be listed in descending order using identifiers, and the most relevant passages should be listed first, and the output format should be [] > [] > etc\n\n### Query:\n{instruction}\n\n### Candidates:\n{input}\n\n### Response:\n", 4 | "prompt_no_input": "Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n", 5 | "response_split": "### Response:" 6 | } -------------------------------------------------------------------------------- /src/rerank_loop/templates/PRFreranker.json: -------------------------------------------------------------------------------- 1 | { 2 | "description": "Template used by PRFreranker.", 3 | "prompt_input": "Given the current search query and the first few documents from the input candidates acting as pseudo-relevant feedback, please rank the {num} candidates based on their relevance. The pseudo-relevant feedback passages may contain useful information about the query. The candidates should be listed in descending order of relevance using their identifiers, with the most relevant listed first. The output format should be [] > [] > etc, just output the identifiers of the candidates not the text.\n\n### Query:\n{instruction}\n\n### Candidates:\n{input}\n\n### Pseudo-Relevant Feedback:\n{feedback_documents}\n\n### Response:\n", 4 | "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n", 5 | "response_split": "### Response:" 6 | } -------------------------------------------------------------------------------- /src/retrieval_loop/index_configs/all-mpnet-config-psgs_w100.json: -------------------------------------------------------------------------------- 1 | { 2 | "new_text_file": "../../data_v2/input_data/DPR/psgs_w100.jsonl", 3 | "retrieval_model": "all-mpnet-base-v2", 4 | "index_name": "all-mpnet_faiss_index_h100", 5 | "index_path": "../../data_v2/indexes", 6 | "index_add_path": "../../data_v2/indexes", 7 | "page_content_column": "contents", 8 | "index_exists": false, 9 | "normalize_embeddings": true, 10 | "query_files": ["../../data_v2/input_data/DPR/nq-test-h10.jsonl"], 11 | "query_page_content_column": "question", 12 | "output_files": ["../../data_v2/ret_output/DPR/nq-test-h10-all-mpnet"], 13 | "elasticsearch_url": "http://124.16.138.142:9978" 14 | } -------------------------------------------------------------------------------- /src/retrieval_loop/index_configs/bge-base-config-psgs_w100.json: -------------------------------------------------------------------------------- 1 | { 2 | "new_text_file": "../../data_v2/input_data/DPR/psgs_w100.jsonl", 3 | "retrieval_model": "bge-base", 4 | "index_name": "bge-base_faiss_index", 5 | "index_path": "../../data_v2/indexes", 6 | "index_add_path": "../../data_v2/indexes", 7 | "page_content_column": "contents", 8 | "index_exists": false, 9 | "normalize_embeddings": true, 10 | "query_files": ["../../data_v2/input_data/DPR/nq-test-h10.jsonl"], 11 | "query_page_content_column": "question", 12 | "output_files": ["../../data_v2/ret_output/DPR/nq-test-h10-bge-base"], 13 | "elasticsearch_url": "http://124.16.138.142:9978" 14 | } -------------------------------------------------------------------------------- /src/retrieval_loop/index_configs/bge-large-config-psgs_w100.json: -------------------------------------------------------------------------------- 1 | { 2 | "new_text_file": "../../data_v2/input_data/DPR/psgs_w100.jsonl", 3 | "retrieval_model": "bge-large", 4 | "index_name": "bge-large_faiss_index", 5 | "index_path": "../../data_v2/indexes", 6 | "index_add_path": "../../data_v2/indexes", 7 | "page_content_column": "contents", 8 | "index_exists": false, 9 | "normalize_embeddings": true, 10 | 
"query_files": ["../../data_v2/input_data/DPR/nq-test-h10.jsonl"], 11 | "query_page_content_column": "question", 12 | "output_files": ["../../data_v2/ret_output/DPR/nq-test-h10-bge-large"], 13 | "elasticsearch_url": "http://124.16.138.142:9978" 14 | } -------------------------------------------------------------------------------- /src/retrieval_loop/index_configs/bm25-config-psgs_w100.json: -------------------------------------------------------------------------------- 1 | { 2 | "new_text_file": "../../data_v2/input_data/DPR/psgs_w100.jsonl", 3 | "retrieval_model": "BM25", 4 | "index_name": "bm25_psgs_index", 5 | "index_path": "../../data_v2/indexes", 6 | "index_add_path": "../../data_v2/indexes", 7 | "page_content_column": "contents", 8 | "index_exists": false, 9 | "normalize_embeddings": false, 10 | "query_files": ["../../data_v2/input_data/DPR/nq-test-h10.jsonl"], 11 | "query_page_content_column": "question", 12 | "output_files": ["../../data_v2/ret_output/DPR/nq-test-h10-bm25"], 13 | "elasticsearch_url": "http://124.16.138.142:9978" 14 | } -------------------------------------------------------------------------------- /src/retrieval_loop/index_configs/contriever-config-psgs_w100.json: -------------------------------------------------------------------------------- 1 | { 2 | "new_text_file": "../../data_v2/input_data/DPR/psgs_w100.jsonl", 3 | "retrieval_model": "Contriever", 4 | "index_name": "contriever_faiss_index", 5 | "index_path": "../../data_v2/indexes", 6 | "index_add_path": "../../data_v2/indexes", 7 | "page_content_column": "contents", 8 | "index_exists": false, 9 | "normalize_embeddings": false, 10 | "query_files": ["../../data_v2/input_data/DPR/nq-test-h10.jsonl"], 11 | "query_page_content_column": "question", 12 | "output_files": ["../../data_v2/ret_output/DPR/nq-test-h10-contriever"], 13 | "elasticsearch_url": "http://124.16.138.142:9978" 14 | } -------------------------------------------------------------------------------- /src/retrieval_loop/index_configs/dpr-config-psgs_w100.json: -------------------------------------------------------------------------------- 1 | { 2 | "new_text_file": "../../data_v2/input_data/DPR/psgs_w100.jsonl", 3 | "retrieval_model": "DPR", 4 | "index_name": "DPR_faiss_index", 5 | "index_path": "../../data_v2/indexes", 6 | "index_add_path": "../../data_v2/indexes", 7 | "page_content_column": "contents", 8 | "index_exists": false, 9 | "normalize_embeddings": false, 10 | "query_files": ["../../data_v2/input_data/DPR/nq-test-h10.jsonl"], 11 | "query_page_content_column": "question", 12 | "output_files": ["../../data_v2/ret_output/DPR/nq-test-h10-dpr"], 13 | "elasticsearch_url": "http://124.16.138.142:9978" 14 | } -------------------------------------------------------------------------------- /src/retrieval_loop/index_configs/llm-embedder-config-psgs_w100.json: -------------------------------------------------------------------------------- 1 | { 2 | "new_text_file": "../../data_v2/input_data/DPR/psgs_w100.jsonl", 3 | "retrieval_model": "llm-embedder", 4 | "index_name": "llm-embedder_faiss_index", 5 | "index_path": "../../data_v2/indexes", 6 | "index_add_path": "../../data_v2/indexes", 7 | "page_content_column": "contents", 8 | "index_exists": false, 9 | "normalize_embeddings": false, 10 | "query_files": ["../../data_v2/input_data/DPR/nq-test-h10.jsonl"], 11 | "query_page_content_column": "question", 12 | "output_files": ["../../data_v2/ret_output/DPR/nq-test-h10-llm-embedder"], 13 | "elasticsearch_url": "http://124.16.138.142:9978" 14 | } 
-------------------------------------------------------------------------------- /src/retrieval_loop/index_configs/retromae-config-psgs_w100.json: -------------------------------------------------------------------------------- 1 | { 2 | "new_text_file": "../../data_v2/input_data/DPR/psgs_w100.jsonl", 3 | "retrieval_model": "retromae", 4 | "index_name": "retromae_faiss_index", 5 | "index_path": "../../data_v2/indexes", 6 | "index_add_path": "../../data_v2/indexes", 7 | "page_content_column": "contents", 8 | "index_exists": false, 9 | "normalize_embeddings": false, 10 | "query_files": ["../../data_v2/input_data/DPR/nq-test-h10.jsonl"], 11 | "query_page_content_column": "question", 12 | "output_files": ["../../data_v2/ret_output/DPR/nq-test-h10-retromae"], 13 | "elasticsearch_url": "http://124.16.138.142:9978" 14 | } -------------------------------------------------------------------------------- /src/retrieval_loop/index_configs/tmp/bge-config-psgs_w100.json: -------------------------------------------------------------------------------- 1 | { 2 | "new_text_file": "../../data_v2/input_data/DPR/psgs_w100_h100.jsonl", 3 | "retrieval_model": "Contriever", 4 | "index_name": "contriever_faiss_index_h100", 5 | "index_path": "../../data_v2/indexes", 6 | "page_content_column": "contents", 7 | "index_exists": false, 8 | "normalize_embeddings": false, 9 | "query_file": "../../data_v2/zero_gen_data/DPR/nq-test-h10.jsonl", 10 | "query_page_content_column": "question", 11 | "output_file": "../../data_v2/ret_output/DPR/nq-test-h10-contriever" 12 | } -------------------------------------------------------------------------------- /src/retrieval_loop/retrieve_configs/all-mpnet-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "retrieval_model": "all-mpnet-base-v2", 3 | "index_name": "all-mpnet_faiss_index_h100", 4 | "index_path": "../../data_v2/indexes", 5 | "normalize_embeddings": true, 6 | "query_files": ["../../data_v2/input_data/DPR/sampled_query/nq-test-sample-200.jsonl", 7 | "../../data_v2/input_data/DPR/sampled_query/webq-test-sample-200.jsonl", 8 | "../../data_v2/input_data/DPR/sampled_query/pop-test-sample-200.jsonl", 9 | "../../data_v2/input_data/DPR/sampled_query/tqa-test-sample-200.jsonl"], 10 | "query_page_content_column": "question", 11 | "output_files": ["../../data_v2/loop_output/DPR/original_retrieval_result/nq/nq-test-all-mpnet", 12 | "../../data_v2/loop_output/DPR/original_retrieval_result/webq/webq-test-all-mpnet", 13 | "../../data_v2/loop_output/DPR/original_retrieval_result/pop/pop-test-all-mpnet", 14 | "../../data_v2/loop_output/DPR/original_retrieval_result/tqa/tqa-test-all-mpnet"], 15 | "elasticsearch_url": "http://124.16.138.140:9978" 16 | } -------------------------------------------------------------------------------- /src/retrieval_loop/retrieve_configs/bge-base-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "retrieval_model": "bge-base", 3 | "index_name": "bge-base_faiss_index", 4 | "index_path": "../../data_v2/indexes", 5 | "normalize_embeddings": true, 6 | "query_files": ["../../data_v2/input_data/DPR/sampled_query/nq-test-sample-200.jsonl", 7 | "../../data_v2/input_data/DPR/sampled_query/webq-test-sample-200.jsonl", 8 | "../../data_v2/input_data/DPR/sampled_query/pop-test-sample-200.jsonl", 9 | "../../data_v2/input_data/DPR/sampled_query/tqa-test-sample-200.jsonl"], 10 | "query_page_content_column": "question", 11 | "output_files": 
["../../data_v2/loop_output/DPR/original_retrieval_result/nq/nq-test-bge-base", 12 | "../../data_v2/loop_output/DPR/original_retrieval_result/webq/webq-test-bge-base", 13 | "../../data_v2/loop_output/DPR/original_retrieval_result/pop/pop-test-bge-base", 14 | "../../data_v2/loop_output/DPR/original_retrieval_result/tqa/tqa-test-bge-base"], 15 | "elasticsearch_url": "http://124.16.138.142:9978" 16 | } -------------------------------------------------------------------------------- /src/retrieval_loop/retrieve_configs/bge-large-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "retrieval_model": "bge-large", 3 | "index_name": "bge-large_faiss_index", 4 | "index_path": "../../data_v2/indexes", 5 | "normalize_embeddings": true, 6 | "query_files": ["../../data_v2/input_data/DPR/sampled_query/nq-test-sample-200.jsonl", 7 | "../../data_v2/input_data/DPR/sampled_query/webq-test-sample-200.jsonl", 8 | "../../data_v2/input_data/DPR/sampled_query/pop-test-sample-200.jsonl", 9 | "../../data_v2/input_data/DPR/sampled_query/tqa-test-sample-200.jsonl"], 10 | "query_page_content_column": "question", 11 | "output_files": ["../../data_v2/loop_output/DPR/original_retrieval_result/nq/nq-test-bge-large", 12 | "../../data_v2/loop_output/DPR/original_retrieval_result/webq/webq-test-bge-large", 13 | "../../data_v2/loop_output/DPR/original_retrieval_result/pop/pop-test-bge-large", 14 | "../../data_v2/loop_output/DPR/original_retrieval_result/tqa/tqa-test-bge-large" 15 | ] 16 | } -------------------------------------------------------------------------------- /src/retrieval_loop/retrieve_configs/bm25-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "retrieval_model": "BM25", 3 | "index_name": "bm25_psgs_index", 4 | "index_path": "../../data_v2/indexes", 5 | "page_content_column": "contents", 6 | "normalize_embeddings": false, 7 | "query_files": ["../../data_v2/input_data/DPR/sampled_query/nq-test-sample-200.jsonl", 8 | "../../data_v2/input_data/DPR/sampled_query/webq-test-sample-200.jsonl", 9 | "../../data_v2/input_data/DPR/sampled_query/pop-test-sample-200.jsonl", 10 | "../../data_v2/input_data/DPR/sampled_query/tqa-test-sample-200.jsonl"], 11 | "query_page_content_column": "question", 12 | "output_files": ["../../data_v2/loop_output/DPR/original_retrieval_result/nq/nq-test-bm25", 13 | "../../data_v2/loop_output/DPR/original_retrieval_result/webq/webq-test-bm25", 14 | "../../data_v2/loop_output/DPR/original_retrieval_result/pop/pop-test-bm25", 15 | "../../data_v2/loop_output/DPR/original_retrieval_result/tqa/tqa-test-bm25"], 16 | "elasticsearch_url": "http://124.16.138.142:9978" 17 | } -------------------------------------------------------------------------------- /src/retrieval_loop/retrieve_configs/contriever-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "retrieval_model": "Contriever", 3 | "index_name": "contriever_faiss_index", 4 | "index_path": "../../data_v2/indexes", 5 | "normalize_embeddings": false, 6 | "query_files": ["../../data_v2/input_data/DPR/sampled_query/nq-test-sample-200.jsonl"], 7 | "query_page_content_column": "question", 8 | "output_files": ["../../data_v2/loop_output/DPR/nq-test-contriever"], 9 | "elasticsearch_url": "http://124.16.138.147:9978" 10 | } -------------------------------------------------------------------------------- /src/retrieval_loop/retrieve_configs/dpr-config-nq.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "retrieval_model": "DPR", 3 | "index_name": "DPR_faiss_index", 4 | "index_path": "../../data_v2/indexes", 5 | "normalize_embeddings": false, 6 | "query_files": ["../../data_v2/input_data/DPR/sampled_query/nq-test-sample-200.jsonl", 7 | "../../data_v2/input_data/DPR/sampled_query/webq-test-sample-200.jsonl", 8 | "../../data_v2/input_data/DPR/sampled_query/pop-test-sample-200.jsonl", 9 | "../../data_v2/input_data/DPR/sampled_query/tqa-test-sample-200.jsonl"], 10 | "query_page_content_column": "question", 11 | "output_files": ["../../data_v2/loop_output/DPR/original_retrieval_result/nq/nq-test-dpr", 12 | "../../data_v2/loop_output/DPR/original_retrieval_result/webq/webq-test-dpr", 13 | "../../data_v2/loop_output/DPR/original_retrieval_result/pop/pop-test-dpr", 14 | "../../data_v2/loop_output/DPR/original_retrieval_result/tqa/tqa-test-dpr"], 15 | "elasticsearch_url": "http://124.16.138.142:9978" 16 | } -------------------------------------------------------------------------------- /src/retrieval_loop/retrieve_configs/llm-embedder-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "retrieval_model": "llm-embedder", 3 | "index_name": "llm-embedder_faiss_index", 4 | "index_path": "../../data_v2/indexes", 5 | "normalize_embeddings": false, 6 | "query_files": ["../../data_v2/input_data/DPR/sampled_query/nq-test-sample-200.jsonl", 7 | "../../data_v2/input_data/DPR/sampled_query/webq-test-sample-200.jsonl", 8 | "../../data_v2/input_data/DPR/sampled_query/pop-test-sample-200.jsonl", 9 | "../../data_v2/input_data/DPR/sampled_query/tqa-test-sample-200.jsonl"], 10 | "query_page_content_column": "question", 11 | "output_files": ["../../data_v2/loop_output/DPR/original_retrieval_result/nq/nq-test-llm-embedder", 12 | "../../data_v2/loop_output/DPR/original_retrieval_result/webq/webq-test-llm-embedder", 13 | "../../data_v2/loop_output/DPR/original_retrieval_result/pop/pop-test-llm-embedder", 14 | "../../data_v2/loop_output/DPR/original_retrieval_result/tqa/tqa-test-llm-embedder"], 15 | "elasticsearch_url": "http://124.16.138.142:9978" 16 | } -------------------------------------------------------------------------------- /src/retrieval_loop/retrieve_configs/retromae-config-nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "retrieval_model": "retromae", 3 | "index_name": "retromae_faiss_index", 4 | "index_path": "../../data_v2/indexes", 5 | "normalize_embeddings": false, 6 | "query_files": ["../../data_v2/input_data/DPR/sampled_query/nq-test-sample-200.jsonl", 7 | "../../data_v2/input_data/DPR/sampled_query/webq-test-sample-200.jsonl", 8 | "../../data_v2/input_data/DPR/sampled_query/pop-test-sample-200.jsonl", 9 | "../../data_v2/input_data/DPR/sampled_query/tqa-test-sample-200.jsonl"], 10 | "query_page_content_column": "question", 11 | "output_files": ["../../data_v2/loop_output/DPR/original_retrieval_result/nq/nq-test-retromae", 12 | "../../data_v2/loop_output/DPR/original_retrieval_result/webq/webq-test-retromae", 13 | "../../data_v2/loop_output/DPR/original_retrieval_result/pop/pop-test-retromae", 14 | "../../data_v2/loop_output/DPR/original_retrieval_result/tqa/tqa-test-retromae"], 15 | "elasticsearch_url": "http://124.16.138.142:9978" 16 | } -------------------------------------------------------------------------------- /src/retrieval_loop/run_index_builder.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | MODEL_NAMES=(bm25) # dpr contriever) # retromae all-mpnet bge-large llm-embedder bm25 contriever dpr) 4 | DATA_NAMES=(psgs_w100) 5 | 6 | for MODEL_NAME in "${MODEL_NAMES[@]}" 7 | do 8 | for DATA_NAME in "${DATA_NAMES[@]}" 9 | do 10 | echo "Running index builder for ${MODEL_NAME} on ${DATA_NAME}..." 11 | CONFIG_PATH="index_configs/${MODEL_NAME}-config-${DATA_NAME}.json" 12 | LOG_DIR="logs/${MODEL_NAME}_${DATA_NAME}_indexing.log" 13 | 14 | python embedding_index_incremental_corpus.py --config_file_path "$CONFIG_PATH" > "$LOG_DIR" 2>&1 & 15 | done 16 | done 17 | -------------------------------------------------------------------------------- /src/retrieval_loop/run_retrieval.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | MODEL_NAMES=(contriever) #(bm25 dpr bge-base contriever retromae bge-large all-mpnet llm-embedder) 5 | DATA_NAMES=(nq) 6 | 7 | for MODEL_NAME in "${MODEL_NAMES[@]}" 8 | do 9 | for DATA_NAME in "${DATA_NAMES[@]}" 10 | do 11 | echo "Running retrieval for ${MODEL_NAME}..." 12 | CONFIG_PATH="retrieve_configs/${MODEL_NAME}-config-${DATA_NAME}.json" 13 | LOG_DIR="../run_logs/test/${MODEL_NAME}_retrieval.log" 14 | 15 | python retrieve_methods.py --config_file_path "$CONFIG_PATH" > "$LOG_DIR" 2>&1 & 16 | done 17 | wait 18 | done 19 | -------------------------------------------------------------------------------- /src/rewrite_configs.py: -------------------------------------------------------------------------------- 1 | # rewrite config files before running a loop 2 | # read retrieval, rerank, generate, post_process and indexing configs 3 | 4 | import os 5 | import sys 6 | import argparse 7 | import json 8 | 9 | 10 | # read config files 11 | 12 | def read_config(template_path): 13 | with open(template_path, 'r') as f: 14 | config = json.load(f) 15 | return config 16 | 17 | 18 | def get_args(): 19 | 20 | parser = argparse.ArgumentParser() 21 | 22 | group = parser.add_argument_group(title='argument-parser') 23 | 24 | group.add_argument('--stage', type=str, required=True, help='stage name: retrieval, rerank, generate, indexing') 25 | 26 | group.add_argument('--output_dir', type=str, required=True, help='config output dir') 27 | 28 | group.add_argument('--method', type=str, required=True, help='method name') 29 | 30 | group.add_argument('--data_name', type=str, required=True, help='data name') 31 | 32 | group.add_argument('--loop', type=int, default=0, help='loop number') 33 | 34 | group.add_argument('--total_config', type=str, default=None, help='total config file') 35 | 36 | # Accept a JSON string from the command line to override config settings 37 | group.add_argument('--overrides', type=str, help='JSON string to override config values') 38 | 39 | 40 | args = parser.parse_args() 41 | 42 | if args.overrides: 43 | print(f"Overriding config with: {args.overrides}") 44 | # {"index_name": "bm25_test_index", "index_exists": false} 45 | try: 46 | args.overrides = json.loads(args.overrides) 47 | except json.JSONDecodeError: 48 | print("Overrides must be a valid JSON string") 49 | sys.exit(1) 50 | 51 | return args 52 | 53 | def get_template_path(stage_name, method_name, data_name): 54 | if stage_name == 'retrieval': 55 | template_path = f'retrieve_configs/{method_name}-config-{data_name}.json' 56 | elif stage_name == 'rerank': 57 | template_path = f'rerank_configs/{method_name}-config-{data_name}.json' 58 | elif 
stage_name == 'generate': 59 | template_path = f'rag_configs/{method_name}-config-{data_name}.json' 60 | elif stage_name == 'zero_update_generate': 61 | template_path = f'update_configs/zero-shot_configs/{method_name}-config-{data_name}.json' 62 | elif stage_name == 'update_generate': 63 | template_path = f'update_configs/rag_configs/{method_name}-config-{data_name}.json' 64 | elif stage_name == 'indexing': 65 | template_path = f'index_configs/{method_name}-config-{data_name}.json' 66 | elif stage_name == 'post_process': 67 | template_path = f'process_configs/template_config.json' 68 | elif stage_name == 'filter_bleu': 69 | template_path = f'filter_configs/bleu_filter_config.json' 70 | elif stage_name == 'filter_source': 71 | template_path = f'filter_configs/source_filter_config.json' 72 | else: 73 | raise ValueError('stage name error') 74 | return template_path 75 | 76 | 77 | def rewrite_config(stage_name, method_name, data_name, loop, output_dir, total_config, overrides): 78 | template_path = get_template_path(stage_name, method_name, data_name) 79 | config = read_config(template_path) 80 | if total_config is not None: 81 | running_config = read_config(total_config)[stage_name] 82 | 83 | # for any key in config_template, if it is in running_config, use the value in running_config 84 | # otherwise, use the value in config_template 85 | for key in config: 86 | if key in running_config: 87 | config[key] = running_config[key] 88 | 89 | 90 | # if overrides is not None: 91 | # Apply overrides from the command line 92 | if overrides: 93 | for key, value in overrides.items(): 94 | config[key] = value 95 | 96 | with open(output_dir, 'w') as f: 97 | json.dump(config, f, indent=4) 98 | 99 | print(f'loop {loop} config file is saved in {output_dir}') 100 | print(f'stage: {stage_name}, method: {method_name}, data: {data_name}, loop: {loop}') 101 | print(f'config: {config}') 102 | 103 | 104 | def main(): 105 | args = get_args() 106 | rewrite_config(args.stage, args.method, args.data_name, args.loop, args.output_dir, args.total_config, args.overrides) 107 | 108 | 109 | if __name__ == '__main__': 110 | main() 111 | 112 | -------------------------------------------------------------------------------- /src/test_function/test_configs/indexing_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "new_text_file": "../../data_v2/input_data/DPR/psgs_w100_h100.jsonl", 3 | "retrieval_model": "BM25", 4 | "index_name": "bm25_test_index", 5 | "index_path": "../../data_v2/indexes", 6 | "page_content_column": "contents", 7 | "index_exists": false, 8 | "normalize_embeddings": false, 9 | "query_files": ["../../data_v2/input_data/DPR/nq-test-h10.jsonl"], 10 | "query_page_content_column": "question", 11 | "output_files": ["../../data_v2/test_output/DPR/nq-test-h10-bm25"] 12 | } -------------------------------------------------------------------------------- /src/test_function/test_configs/post_process_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "loop_num": 0, 3 | "gen_model_name": "llama2", 4 | "query_set_name": "nq", 5 | "output_dir": "../../data_v2/test_output/DPR/nq-test-gen-llama2-13b-chat.jsonl", 6 | "input_file": "../../data_v2/zero_gen_data/DPR/nq-test-gen-llama2-13b-chat.jsonl" 7 | } -------------------------------------------------------------------------------- /src/test_function/test_configs/retrieval_config.json: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/VerdureChen/SOS-Retrieval-Loop/56a19847edfc30d1e6e59894d348e1c17ae15114/src/test_function/test_configs/retrieval_config.json -------------------------------------------------------------------------------- /src/test_function/test_configs/template_total_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "indexing": { 3 | "new_text_file": "../../data_v2/input_data/DPR/psgs_w100_h100.jsonl", 4 | "page_content_column": "contents", 5 | "normalize_embeddings": true, 6 | "query_files": ["../../data_v2/input_data/DPR/nq-test-h10.jsonl"], 7 | "query_page_content_column": "question", 8 | "output_files": ["../../data_v2/ret_output/DPR/nq-test-h10-test"] 9 | }, 10 | "retrieval":{ 11 | "normalize_embeddings": true, 12 | "query_page_content_column": "question" 13 | }, 14 | "rerank": { 15 | "num_workers": 2, 16 | "log_interval": 1, 17 | "topk_passages": 100, 18 | "use_gpu": true, 19 | "report_topk_accuracies": [1, 5, 20, 100], 20 | "merge_shards_and_save": true, 21 | "elasticsearch_url": "http://124.16.138.142:9978" 22 | }, 23 | "generate": { 24 | "with_context": true, 25 | "elasticsearch_url": "http://124.16.138.142:9978", 26 | "context_ref_num": 5 27 | }, 28 | "post_process": { 29 | "query_set_name": "nq" 30 | }, 31 | "filter_bleu": { 32 | "max_self_bleu": 0.4, 33 | "num_docs": 5 34 | }, 35 | "filter_source": { 36 | "max_self_bleu": 0.4, 37 | "num_docs": 5 38 | }, 39 | "update_generate": { 40 | "with_context": true, 41 | "elasticsearch_url": "http://124.16.138.142:9978", 42 | "context_ref_num": 5 43 | } 44 | } --------------------------------------------------------------------------------
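(Added note) template_total_config.json above is what gets passed to rewrite_configs.py via --total_config. A worked sketch of the merge order that rewrite_config() implements — stage template first, then the matching section of the total config restricted to keys the template already declares, then --overrides last; the example values below are illustrative only:

# Sketch of the config-merge order shown in src/rewrite_configs.py:
# stage template -> total_config[stage] (template keys only) -> --overrides.
def merge_stage_config(template, total_config, stage, overrides):
    config = dict(template)
    running = total_config.get(stage, {})
    for key in config:                 # total config fills only keys the template declares
        if key in running:
            config[key] = running[key]
    config.update(overrides or {})     # command-line overrides win unconditionally
    return config

template = {"index_name": "bm25_psgs_index", "index_exists": False}
total = {"indexing": {"index_exists": True, "not_in_template": 1}}
overrides = {"index_name": "bm25_test_index"}
print(merge_stage_config(template, total, "indexing", overrides))
# -> {'index_name': 'bm25_test_index', 'index_exists': True}

Note the asymmetry: keys in the total config that the stage template does not declare are silently dropped, while --overrides can introduce brand-new keys into the written config.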