├── C-3PO
│   ├── .gitignore
│   ├── __init__.py
│   ├── arguments.py
│   ├── config
│   │   ├── freshqa_server.yaml
│   │   ├── multihoprag_server.yaml
│   │   └── search_config.yaml
│   ├── construct_cache.py
│   ├── file_query.py
│   ├── llm_server.py
│   ├── log_utils.py
│   ├── main.py
│   ├── metrics.py
│   ├── prompt
│   │   ├── __init__.py
│   │   ├── decide_next_step.py
│   │   ├── decide_prompt.py
│   │   ├── decide_prompt_identity.py
│   │   ├── decide_prompt_identity_v2.py
│   │   ├── decision.py
│   │   ├── evaluation.py
│   │   ├── few_shots
│   │   │   ├── __init__.py
│   │   │   ├── answer
│   │   │   │   └── public.py
│   │   │   ├── decide_next_step
│   │   │   │   ├── 2WikiMultiHopQA.py
│   │   │   │   ├── Musique.py
│   │   │   │   ├── NaturalQuestions.py
│   │   │   │   ├── PopQA.py
│   │   │   │   ├── TriviaQA.py
│   │   │   │   ├── __init__.py
│   │   │   │   ├── hotpotqa.py
│   │   │   │   └── public.py
│   │   │   ├── decision
│   │   │   │   ├── 2WikiMultiHopQA.py
│   │   │   │   ├── Musique.py
│   │   │   │   ├── NaturalQuestions.py
│   │   │   │   ├── PopQA.py
│   │   │   │   ├── TriviaQA.py
│   │   │   │   ├── __init__.py
│   │   │   │   └── hotpotqa.py
│   │   │   ├── filter
│   │   │   │   ├── 2WikiMultiHopQA.py
│   │   │   │   ├── Musique.py
│   │   │   │   ├── NaturalQuestions.py
│   │   │   │   ├── PopQA.py
│   │   │   │   ├── TriviaQA.py
│   │   │   │   ├── __init__.py
│   │   │   │   ├── hotpotqa.py
│   │   │   │   └── template.py
│   │   │   └── planning
│   │   │       ├── 2WikiMultiHopQA.py
│   │   │       ├── ASQA.py
│   │   │       ├── Musique.py
│   │   │       ├── NaturalQuestions.py
│   │   │       ├── PopQA.py
│   │   │       ├── TriviaQA.py
│   │   │       ├── __init__.py
│   │   │       ├── hotpotqa.py
│   │   │       └── public.py
│   │   ├── filter.py
│   │   └── planning.py
│   ├── proxy.py
│   ├── retrieve
│   │   ├── retrieve_engine.py
│   │   ├── retriever.py
│   │   └── search_engine.py
│   ├── search
│   │   ├── LLM_planning_role.py
│   │   ├── LLM_query_role.py
│   │   ├── decide_next_step_role.py
│   │   ├── decide_prompt_role.py
│   │   ├── evaluation_role.py
│   │   ├── make_decision_role.py
│   │   ├── retrieve_filter_role.py
│   │   └── search.py
│   ├── solver.py
│   ├── tree
│   │   ├── node.py
│   │   └── tree.py
│   └── utils.py
├── LICENSE
├── README.md
├── deploy_servers
│   ├── llm_server
│   │   ├── base_sgl.sh
│   │   └── qwen_72b_serve.sh
│   └── retrieve_server
│       └── retrieve_code
│           ├── passage_retrieval.py
│           ├── src
│           │   ├── __init__.py
│           │   ├── beir_utils.py
│           │   ├── contriever.py
│           │   ├── data.py
│           │   ├── dist_utils.py
│           │   ├── evaluation.py
│           │   ├── finetuning_data.py
│           │   ├── inbatch.py
│           │   ├── index.py
│           │   ├── moco.py
│           │   ├── normalize_text.py
│           │   ├── options.py
│           │   ├── slurm.py
│           │   └── utils.py
│           └── wiki18
│               ├── query_serve.py
│               ├── start_wiki18.sh
│               ├── wiki18_config.yaml
│               └── wiki18_serve.py
├── images
│   └── C-3PO.png
├── inference
│   └── single_model.sh
├── instruct_sampling_scripts
│   ├── offline_base_instruct.sh
│   └── run_72b.sh
├── requirements.txt
├── retrieval_requirements.txt
└── train
    └── sft_scripts
        └── run_base_packing.sh

--------------------------------------------------------------------------------
/C-3PO/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | .idea/
161 |
162 | # custom .gitignore
163 | user.config
164 | saves/
165 | cache/
166 |
167 |
--------------------------------------------------------------------------------
/C-3PO/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chen-GX/C-3PO/e9a8021521a9eb006b7581b8c672dc3838cfc66c/C-3PO/__init__.py
--------------------------------------------------------------------------------
/C-3PO/config/freshqa_server.yaml:
--------------------------------------------------------------------------------
1 | checkpoint_dir: ./output_dir/models/Qwen2-1.5B/ppo_106_344_ckpt160
2 | model_base_name: ppo_106_344_ckpt160
3 | tp: 1
4 | port: 20000
5 | proxy_concurrency: 256
6 | temperature: 0.6
7 | top_k: -1
8 | top_p: 1.0
9 | max_tokens: 2048
10 | model_type: proxy
11 | gpt_proxy_concurrency: 8
12 | openai_api_key: None
13 | proxy_url: None
14 | proxy_model_name: None
15 | force_decision: False
16 | force_action: Planning # "Planning", "Retrieval", "No Retrieval"
17 |
18 | ## params of search
19 | n_decision_sample: 1
20 | n_generate_sample: 1
21 | n_plan_sample: 1
22 | n_answer_sample: 1
23 | max_iter: 20
24 | search_config: ./config/search_config.yaml
25 | few_shot: False
26 | few_num: 2
27 | dict_few_num: 1
28 | decide_prompt: "identity"
29 | max_depth: 13
30 | max_documents: 10
31 | answer_eval: "none"
32 | retriever_type: "search_engine" # options: search_engine, dense
33 | retrieve_server_url: http://10.32.25.199:35004/search
34 | musique_server_url: http://10.32.25.199:35002/search
35 | retrieve_top_k: 10
36 | max_query_length: 100
37 | search_engine_url: your url
38 | search_scene: your scene
39 | search_engine_cache: True
40 | search_engine_cache_file: ./cache_search
41 |
42 | # params of llm
43 | llm_server_url: http://10.32.4.13:10080/v1
44 | llm_api_key: EMPTY
45 | llm_name: qwen2-72b-instruct
46 | llm_server_type: online
47 | llm_query_few_shot: True
48 | wo_llm: False
49 | online_concurrency: 64
50 | answer_temperature: 0.
51 | plan_temperature: 0.3
52 | answer_top_p: 1.
53 | plan_top_p: 1.
54 |
55 |
56 | # other params
57 | only_eval_answer: False
58 | test: True
59 | backend: sglang
60 | use_planning_cache: False
61 | cache_dir: "none"
62 | data_path: ./freshqa/FreshQA_v12182024.jsonl
63 | dataname: 2WikiMultiHopQA # for few-shot
64 | output_dir: ./output_dir/run/ood/freshqa_12182024
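The server configs above and below are flat YAML; file_query.py (further down in this listing) loads one of them with yaml.load and wraps the resulting dict in a SimpleNamespace so every key becomes an attribute. A minimal sketch of that loading pattern, using safe_load, which behaves the same for a plain config like this:

    import yaml
    from types import SimpleNamespace

    # Each top-level key becomes an attribute: args.checkpoint_dir,
    # args.max_depth, args.llm_server_url, and so on.
    with open("./config/freshqa_server.yaml", "r") as f:
        args = SimpleNamespace(**yaml.safe_load(f))

    print(args.llm_name, args.retriever_type)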
--------------------------------------------------------------------------------
/C-3PO/config/multihoprag_server.yaml:
--------------------------------------------------------------------------------
1 | checkpoint_dir: ./ppo_106_344_ckpt160
2 | tp: 1
3 | port: 20000
4 | proxy_concurrency: 256
5 | temperature: 0.6
6 | top_k: -1
7 | top_p: 1.0
8 | max_tokens: 2048
9 | model_type: proxy
10 | gpt_proxy_concurrency: 8
11 | openai_api_key: None
12 | proxy_url: None
13 | proxy_model_name: None
14 | force_decision: False
15 | force_action: Planning # "Planning", "Retrieval", "No Retrieval"
16 |
17 | ## params of search
18 | n_decision_sample: 1
19 | n_generate_sample: 1
20 | n_plan_sample: 1
21 | n_answer_sample: 1
22 | max_iter: 20
23 | search_config: ./config/search_config.yaml
24 | few_shot: False
25 | few_num: 2
26 | dict_few_num: 1
27 | decide_prompt: "identity"
28 | max_depth: 13
29 | max_documents: 10
30 | answer_eval: "none"
31 | retriever_type: "search_engine" # options: search_engine, dense
32 | retrieve_server_url: http://10.32.25.199:35004/search
33 | musique_server_url: http://10.32.25.199:35002/search
34 | retrieve_top_k: 10
35 | max_query_length: 100
36 | search_engine_url: xxx
37 | search_scene: xx
38 | search_engine_cache: True
39 | search_engine_cache_file: ./cache_search
40 |
41 | # params of llm
42 | llm_server_url: http://10.32.4.13:10080/v1
43 | llm_api_key: EMPTY
44 | llm_name: qwen2-72b-instruct
45 | llm_server_type: online
46 | llm_query_few_shot: True
47 | wo_llm: False
48 | online_concurrency: 64
49 | answer_temperature: 0.
50 | plan_temperature: 0.3
51 | answer_top_p: 1.
52 | plan_top_p: 1.
53 |
54 |
55 | # other params
56 | only_eval_answer: False
57 | test: True
58 | backend: sglang
59 | use_planning_cache: False
60 | cache_dir: "none"
61 | data_path: ./MultiHopRAG/multihoprag_test.jsonl
62 | dataname: 2WikiMultiHopQA
63 | output_dir: ./run/ood/multihoprag
--------------------------------------------------------------------------------
/C-3PO/config/search_config.yaml:
--------------------------------------------------------------------------------
1 | role:
2 |   MAKE_DECISION:
3 |     name: "MAKE_DECISION"
4 |     actions:
5 |       PLANNING_ACTION: "Planning" # -> planning_lst
6 |       RETRIEVAL_ACTION: "Retrieval" # -> retrieval_lst
7 |       NO_RETRIEVAL_ACTION: "No Retrieval" # -> decide_prompt_lst
8 |
9 |   LLM_PLANNING:
10 |     name: "LLM_PLANNING"
11 |     actions:
12 |       LLM_PLANNING_ACTION: "llm_planning" # -> after_planning_lst
13 |
14 |   DECIDE_NEXT_STEP:
15 |     name: "DECIDE_NEXT_STEP"
16 |     actions:
17 |       RETRIEVAL_ACTION: "Retrieval" # -> retrieval_lst
18 |       LLM_ACTION: "LLM" # -> decide_prompt_lst
19 |
20 |   RETRIEVE_AND_FILTER:
21 |     name: "RETRIEVE_AND_FILTER"
22 |     actions:
23 |       RETRIEVE_AND_FILTER_ACTION: "retrieve_and_filter" # -> decide_next_step_lst
24 |       DIRECT_RETRIEVE_AND_FILTER_ACTION: "direct_retrieve_and_filter" # -> decide_prompt_lst
25 |
26 |
27 |   DECIDE_PROMPT:
28 |     name: "DECIDE_PROMPT"
29 |     actions:
30 |       DECIDE_PROMPT_ACTION: "decide_prompt" # -> decide_prompt_lst
31 |
32 |   QUERY_LLM:
33 |     name: "QUERY_LLM"
34 |     actions:
35 |       QUERY_LLM_ACTION: "query_llm" # -> query_llm_lst
36 |
37 |   # EXTRACT_ANSWER:
38 |   #   name: "EXTRACT_ANSWER"
39 |   #   actions:
40 |   #     QUERY_LLM_ACTION: "query_llm" # -> query_llm_lst
41 |
42 |
43 | few_shot_file_name:
44 |   - decision # MAKE_DECISION
45 |   - planning # LLM_PLANNING
46 |   - decide_next_step # DECIDE_NEXT_STEP
47 |   - filter # RETRIEVE_AND_FILTER
48 |   - extract_answer
49 |   - answer
50 |
51 | test_few_shot_file_name:
52 |   - planning
53 |   - extract_answer
54 |   - answer
55 |
56 | USE_PUBLIC:
57 |   planning:
58 |     - NaturalQuestions
59 |     - TriviaQA
60 |   decide_next_step:
61 |     - NaturalQuestions
62 |     - TriviaQA
63 |   answer:
64 |     - 2WikiMultiHopQA
65 |     - hotpotqa
66 |     - Musique
67 |     - NaturalQuestions
68 |     - PopQA
69 |     - TriviaQA
70 |
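search_config.yaml is effectively the transition table of the search loop: each role declares the action strings it may emit, and the inline comments record which candidate list each action feeds. construct_cache.py (next file) reads it with yaml.safe_load; a minimal sketch of walking the role/action map the same way:

    import yaml

    with open("./config/search_config.yaml", "r") as f:
        config = yaml.safe_load(f)

    # Prints e.g. MAKE_DECISION -> ['Planning', 'Retrieval', 'No Retrieval']
    for role, spec in config["role"].items():
        print(role, "->", list(spec["actions"].values()))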
--------------------------------------------------------------------------------
/C-3PO/construct_cache.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | os.chdir(sys.path[0])
3 | import os.path as osp
4 | import random
5 | import yaml
6 | import json
7 | import logging
8 | logger = logging.getLogger(__name__)
9 |
10 | from tqdm import tqdm
11 | from transformers import AutoTokenizer
12 | from pebble import ProcessPool, ThreadPool
13 |
14 | # from solver import Solver
15 | # from retrieve.retriever import BasicRAG
16 | # from search.search import Search
17 | from arguments import get_args, set_seed
18 | from utils import load_data, create_batches, load_agent, load_json
19 | from utils import few_shot_random_select, load_few_shot
20 | from prompt.planning import prompt_llm_planning, prompt_llm_planning_few_shot, instruct_for_each_dataset
21 |
22 |
23 | class LLM_AGENT():
24 |     def __init__(self, args, llm, sampling_params):
25 |         self.args = args
26 |         self.llm = llm
27 |         self.sampling_params = sampling_params
28 |
29 |     def generate(self, query_list):
30 |         index = query_list.get("index", 0)
31 |         messages = query_list["messages"]
32 |         sampling_params = query_list["sampling_params"]
33 |
34 |         outputs = self.llm.chat.completions.create(
35 |             model="default",
36 |             messages=messages,
37 |             n=sampling_params['n'],
38 |             temperature=sampling_params['temperature'],
39 |             top_p=sampling_params['top_p'],
40 |             max_tokens=sampling_params['max_new_tokens'],
41 |         )
42 |         return {"index": index, "outputs": outputs}
43 |
44 | if __name__=="__main__":
45 |     args = get_args(write_to_file=False)
46 |     set_seed(args.seed)
47 |
48 |     if args.test:
49 |         args.cache_dir = osp.join(args.cache_dir, args.dataname, "test.json")
50 |     else:
51 |         args.cache_dir = osp.join(args.cache_dir, args.dataname, "ppo_train.json")
52 |
53 |     args.max_tokens = 2048
54 |     args.temperature = 0.3
55 |     args.top_k = -1
56 |     args.top_p = 1.0
57 |     args.model_type = "proxy"
58 |     args.n_plan_sample = 5
59 |     args.filter = False
60 |     args.few_shot = True
61 |     args.focus_qid = ""
62 |
63 |     args.data_file_name = "ppo_train_6000.jsonl"
64 |
65 |     # tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_dir)
66 |
67 |     if osp.exists(args.cache_dir):
68 |         cache_data = load_json(args.cache_dir)
69 |     else:
70 |         os.makedirs(osp.dirname(args.cache_dir), exist_ok=True)
71 |         cache_data = {}
72 |
73 |     data = load_data(args)
74 |     logger.info(f"{args.dataname} data loaded, length: {len(data)}")
75 |
76 |     available_gpus = os.environ.get('CUDA_VISIBLE_DEVICES', "0").split(',')
77 |     print(f"available_gpus: {available_gpus}")
78 |     llm, sampling_params, server_process = load_agent(args, num_gpus=len(available_gpus))
79 |     llm_agent = LLM_AGENT(args, llm, sampling_params)
80 |
81 |     need_generate_data = []
82 |     for item in data:
83 |         key = f"{item['id']}_{item['question']}"
84 |         if key not in cache_data:
85 |             need_generate_data.append(item)
86 |
87 |     with open(args.search_config, 'r') as f:
88 |         config = yaml.safe_load(f)
89 |     few_shot_examples = load_few_shot(args, config=config)
90 |
91 |     batch_size = 1000
92 |
93 |     for i in tqdm(range(0, len(need_generate_data), batch_size)):
94 |         batch_data = need_generate_data[i:i + batch_size]
95 |         key_list = []
96 |         prompt_list = []
97 |
98 |         # Build the prompts for every item in the current batch
99 |         for item in batch_data:
100 |             key = f"{item['id']}_{item['question']}"
101 |             for _ in range(args.n_plan_sample):
102 |                 key_list.append(key)
103 |                 info = {
104 |                     "examples": few_shot_random_select(few_shot_examples, 'planning', num=args.few_num, dict_num=args.dict_few_num),
105 |                     "question": item['question'],
106 |                     "dataset_instructions": instruct_for_each_dataset.get(args.dataname, ""),
107 |                 }
108 |                 input_text = prompt_llm_planning_few_shot.format_map(info)
109 |                 message = [
110 |                     {"role": "system", "content": "You are a helpful assistant."},
111 |                     {"role": "user", "content": input_text},
112 |                 ]
113 |                 # input_template_text = tokenizer.apply_chat_template(
114 |                 #     message,
115 |                 #     tokenize=False,
116 |                 #     add_generation_prompt=True
117 |                 # )
118 |                 prompt_list.append(message)
119 |
120 |         # Process the current batch
121 |         sampling_params["n"] = 2
122 |         query_list = [{"index": i, "messages": p, "sampling_params": sampling_params} for i, p in enumerate(prompt_list)]
123 |         with ThreadPool(max_workers=args.proxy_concurrency) as pool:
124 |             future = pool.map(llm_agent.generate, query_list, timeout=180)
125 |             outputs = list(future.result())
126 |         # outputs = llm.generate(prompt_list, sampling_params)
127 |         # outputs = [outputs[i:i+sampling_params["n"]] for i in range(0, len(outputs), sampling_params["n"])]
128 |
129 |         # Save the results, de-duplicating generations per question
130 |         for key, output in zip(key_list, outputs):
131 |             if key not in cache_data:
132 |                 cache_data[key] = []
133 |             for o in output['outputs'].choices:
134 |                 if o.message.content not in cache_data[key]:
135 |                     cache_data[key].append(o.message.content)
136 |
137 |
138 |     print(f"cache_data length: {len(cache_data)}")
139 |     print(f"save at {args.cache_dir}")
140 |     print(f"per question generate {args.n_plan_sample * sampling_params['n']} samples")
141 |     with open(args.cache_dir, 'w') as f:
142 |         json.dump(cache_data, f, indent=2)
143 |
144 |     from sglang.utils import terminate_process
145 |     terminate_process(server_process)
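construct_cache.py drains the pebble map future with list(future.result()), which propagates the first TimeoutError and aborts the whole batch if any single request exceeds 180 seconds. A more defensive drain loop, following pebble's documented per-item pattern (a sketch only; generate is a hypothetical stand-in for LLM_AGENT.generate):

    from concurrent.futures import TimeoutError
    from pebble import ThreadPool

    def generate(item):  # hypothetical stand-in for LLM_AGENT.generate
        return {"index": item["index"], "outputs": None}

    items = [{"index": i} for i in range(4)]

    with ThreadPool(max_workers=8) as pool:
        future = pool.map(generate, items, timeout=180)
        results = future.result()  # iterator over per-item results
        while True:
            try:
                print(next(results))
            except StopIteration:
                break  # all items consumed
            except TimeoutError:
                print("a request exceeded 180s; skipping it")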
--------------------------------------------------------------------------------
/C-3PO/file_query.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | os.chdir(sys.path[0])
3 | import os.path as osp
4 | import random
5 | import yaml
6 | import time
7 |
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer
10 |
11 | from types import SimpleNamespace
12 |
13 | from solver import Solver
14 | from search.search import Search
15 | from utils import print_tree, load_jsonl
16 |
17 |
18 | if __name__=="__main__":
19 |     try:
20 |         # load yaml CUDA_VISIBLE_DEVICES=0
21 |         # with open("./config/math500_server.yaml", "r") as f:
22 |         with open("./config/freshqa_server.yaml", "r") as f:
23 |         # with open("./config/multihoprag_server.yaml", "r") as f:
24 |             args = SimpleNamespace(**yaml.load(f, Loader=yaml.FullLoader))
25 |
26 |         if "dashscope" in args.llm_server_url:
27 |             args.online_concurrency = 4
28 |
29 |         args.llm_server_url = args.llm_server_url.split(',')
30 |         args.timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())
31 |         solver = Solver(args)
32 |         tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_dir)
33 |         ## refine save path
34 |         epoch = 0
35 |         if args.force_decision:
36 |             args.output_dir = osp.join(args.output_dir, f"{args.model_base_name}_{args.max_depth}", args.force_action, f"{args.llm_name}", args.timestamp)
37 |         else:
38 |             args.output_dir = osp.join(args.output_dir, f"{args.model_base_name}_{args.max_depth}", f"{args.llm_name}_{args.retriever_type}", args.timestamp)
39 |         file_path = osp.join(args.output_dir, "collected_solutions")
40 |         os.makedirs(file_path, exist_ok=True)
41 |         epoch_file_path = osp.join(file_path, f'collected_solutions_{epoch}.jsonl')
42 |         solver.epoch_file_path = epoch_file_path
43 |         solver.final_file_path = osp.join(file_path, f"{epoch}_final_result.jsonl")
44 |
45 |         # load file (must happen before the cache check below uses `data`)
46 |         data = load_jsonl(args.data_path)
47 |
48 |         # check cache
49 |         if args.use_planning_cache:
50 |             solver.llm_server.check_cache(data)
51 |
52 |
53 |         trees = [Search(args=args, data_item=item, tree_tag=str(idx), tokenizer=tokenizer, server=False) for idx, item in enumerate(data)]
54 |         all_save_tree_info = solver.solve(trees)
55 |
56 |         solver.retrieve.retriever.save_cache()
57 |         # print_tree(all_save_tree_info)
58 |
59 |         # Save args to a file
60 |         with open(osp.join(args.output_dir, "args.yaml"), "w") as f:
61 |             yaml.dump(vars(args), f)
62 |
63 |
64 |         if args.model_type == "proxy" and args.backend == "sglang":
65 |             from sglang.utils import terminate_process
66 |             terminate_process(solver.proxy.server_process)
67 |
68 |     except Exception as e:
69 |         print(e)
70 |         if args.model_type == "proxy" and args.backend == "sglang":
71 |             from sglang.utils import terminate_process
72 |             terminate_process(solver.proxy.server_process)
73 |         raise e
--------------------------------------------------------------------------------
/C-3PO/log_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 | import os
4 | import json
5 | import warnings
6 | import os.path as osp
7 | warnings.filterwarnings("ignore", category=DeprecationWarning)
8 | logger = logging.getLogger(__name__)
9 |
10 | def config_logging(file_name, write_to_file=True):
11 |     fmt = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s'
12 |     if write_to_file:
13 |         file_handler = logging.FileHandler(file_name, mode='a', encoding="utf8")
14 |         # %(asctime)s - [%(filename)s:%(funcName)s:%(lineno)s:%(levelname)s] - %(message)s
15 |         formatter = logging.Formatter(fmt, datefmt="%Y/%m/%d %H:%M:%S")
16 |         file_handler.setFormatter(formatter)
17 |         file_handler.setLevel(logging.DEBUG)
18 |
19 |     console_handler = logging.StreamHandler(sys.stdout)
20 |     console_handler.setFormatter(logging.Formatter(fmt, datefmt="%Y/%m/%d %H:%M:%S"))
21 |     console_handler.setLevel(logging.INFO)
22 |
23 |     logging.basicConfig(
24 |         level=logging.INFO,
25 |         handlers=[file_handler, console_handler] if write_to_file else [console_handler],
26 |     )
27 |
28 |
29 | def log_params(FLAGS, write_to_file=True):
30 |     # Configure logging
31 |     if write_to_file:
32 |         os.makedirs(FLAGS.output_dir, exist_ok=True)
33 |
34 |     config_logging(osp.join(FLAGS.output_dir, 'logfile.log'), write_to_file)
35 |
36 |     if write_to_file:
37 |         for k, v in FLAGS.__dict__.items():
38 |             logger.info(k + ":" + str(v))
39 |
40 |         with open(osp.join(FLAGS.output_dir, 'commandline_args.json'), 'w') as f:
41 |             json.dump(FLAGS.__dict__, f, indent=2)
42 |
43 |
44 |
--------------------------------------------------------------------------------
/C-3PO/main.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | os.chdir(sys.path[0])
3 | import os.path as osp
4 | import random
5 | import logging
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 | from tqdm import tqdm
10 | from transformers import AutoTokenizer
11 |
12 | from solver import Solver
13 | from search.search import Search
14 | from arguments import get_args, set_seed
15 | from utils import load_data, create_batches
16 |
17 | if __name__=="__main__":
18 |     try:
19 |         args = get_args(write_to_file=True)
20 |         set_seed(args.seed)
21 |         data = load_data(args)
22 |         logger.info(f"{args.dataname} data loaded, length: {len(data)}")
23 |
24 |         file_path = osp.join(args.output_dir, "collected_solutions")
25 |         os.makedirs(file_path, exist_ok=True)
26 |
27 |         solver = Solver(args)
28 |         tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_dir)
29 |
30 |         # check cache
31 |         if args.use_planning_cache:
32 |             solver.llm_server.check_cache(data)
33 |
34 |         for epoch in range(args.num_epoch):
35 |             if not args.test:
36 |                 random.shuffle(data)
37 |             epoch_file_path = osp.join(file_path, f'collected_solutions_{epoch}.jsonl')
38 |             solver.epoch_file_path = epoch_file_path
39 |             solver.final_file_path = osp.join(file_path, f"{epoch}_final_result.jsonl")
40 |             logger.info(f"********** EPOCH {epoch} ***********")
41 |             batch_num = 0
42 |             for batch_data in tqdm(create_batches(data, args.batch_size), desc="batch data"):
43 |                 batch_num += 1
44 |                 logger.info(f"begin {batch_num}, {len(batch_data)}")
45 |                 sys.stdout.flush()
46 |                 trees = [Search(args=args, data_item=item, tree_tag=str(idx), tokenizer=tokenizer) for idx, item in enumerate(batch_data)]
47 |                 solver.solve(trees)
48 |
49 |         if args.model_type == "proxy" and args.backend == "sglang":
50 |             from sglang.utils import terminate_process
51 |             terminate_process(solver.proxy.server_process)
52 |     except Exception as e:
53 |         logger.exception(e)
54 |         if args.model_type == "proxy" and args.backend == "sglang":
55 |             from sglang.utils import
terminate_process 56 | terminate_process(solver.proxy.server_process) 57 | raise e -------------------------------------------------------------------------------- /C-3PO/metrics.py: -------------------------------------------------------------------------------- 1 | import re, json, string 2 | from tqdm import tqdm 3 | import numpy as np 4 | 5 | def normalize_answer(s): 6 | def remove_articles(text): 7 | return re.sub(r"\b(a|an|the)\b", " ", text) 8 | 9 | def white_space_fix(text): 10 | return " ".join(text.split()) 11 | 12 | def remove_punc(text): 13 | exclude = set(string.punctuation) 14 | return "".join(ch for ch in text if ch not in exclude) 15 | 16 | def lower(text): 17 | return text.lower() 18 | 19 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 20 | 21 | 22 | def exact_presence(answers, context): 23 | """Verify if any of the answers is present in the given context.""" 24 | 25 | answers = [normalize_answer(ans) for ans in answers] 26 | context = " ".join(context.split("#@")) 27 | context = normalize_answer(context) 28 | 29 | for ans in answers: 30 | if ans in context: 31 | return 1 32 | 33 | return 0 34 | 35 | def compute_str_em(item, _key): 36 | """Compute STR-EM metric (only for ASQA item) 37 | Args: 38 | data: requires field `qa_pairs/short_answers` and `output` 39 | Returns: 40 | STR-EM and STR-EM-HIT () 41 | """ 42 | 43 | if 'qa_pairs' not in item or item['qa_pairs'] is None: 44 | return 0, 0 45 | 46 | 47 | loc_acc = [] 48 | for qa_pair in item['qa_pairs']: 49 | loc_acc.append(exact_presence(qa_pair['answers'], item[_key])) 50 | 51 | acc = np.mean(loc_acc) 52 | hit = int(np.mean(loc_acc) == 1) 53 | 54 | return acc, hit 55 | 56 | 57 | def get_item_metrics(item, _key='response', is_asqa=False): 58 | 59 | if is_asqa: 60 | acc, _ = compute_str_em(item, _key) 61 | else: 62 | acc = exact_presence(item['answers'], item[_key]) 63 | 64 | return acc 65 | 66 | def directly_get_metrics(answers, response): 67 | acc = exact_presence(answers, response) 68 | return acc 69 | 70 | 71 | if __name__ == "__main__": 72 | item = { 73 | "answers": ["december – 7th"], 74 | # "response": "March 19, 1848", 75 | "response": "December #@ December 7th #@ December – 7th", # 'december – 7th' 76 | } 77 | print(get_item_metrics(item, is_asqa=False)) -------------------------------------------------------------------------------- /C-3PO/prompt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chen-GX/C-3PO/e9a8021521a9eb006b7581b8c672dc3838cfc66c/C-3PO/prompt/__init__.py -------------------------------------------------------------------------------- /C-3PO/prompt/decide_next_step.py: -------------------------------------------------------------------------------- 1 | decide_next_step_prefix = """You are an intelligent assistant tasked with determining the next appropriate action based on the provided existing documents, plan, and question. You have access to a large language model (LLM) for answering question and a retrieval system for gathering additional documents. Your objective is to decide whether to write a query for retrieving relevant documents or to generate a comprehensive answer using the LLM based on the existing documents and plan. 2 | 3 | Instructions: 4 | 1. **Evaluate Existing Documents**: Assess the existing documents to determine if it is sufficient to answer the question. 5 | 2. **Follow the Plan**: Understand the next steps outlined in the plan. 6 | 3. 
**Decision Categories:** 7 | - If the existing documents is insufficient and requires additional retrieval, respond with: 8 | [Retrieval] `YOUR QUERY HERE` 9 | - If the existing documents is adequate to answer the question, respond with: 10 | [LLM] 11 | 4. **Focus on Action**: Do not answer the question directly; concentrate on identifying the next appropriate action based on the existing documents, plan, and question. 12 | 5. **Output Format**: 13 | Thought: [Your analysis for current situation (need retrieval for additional informations or use LLM to answer)] 14 | Action: [Your decision based on the analysis (Retrieval or LLM)]""" 15 | 16 | decide_next_step_suffix = """Now, process the following question:\n\nExisting Documents: {existing_documents} 17 | 18 | Plan: {planning} 19 | 20 | Question: {question}\n""" 21 | 22 | decide_next_step_human_input = """Existing Documents: {existing_documents} 23 | 24 | Plan: {planning} 25 | 26 | Question: {question}\n""" 27 | 28 | few_shot_template = """Here are some examples: 29 | {examples}""" 30 | 31 | decide_next_step = decide_next_step_prefix + "\n\n" + decide_next_step_suffix 32 | decide_next_step_few_shot = decide_next_step_prefix + "\n\n" + few_shot_template + "\n\n" + decide_next_step_suffix 33 | 34 | 35 | # 单纯LLM 36 | decide_next_step_LLM_prefix = """You are an intelligent assistant assigned to analyze useful information from the existing documents and plan for responding to a question. 37 | 38 | Instructions: 39 | 1. **Evaluate Existing Documents**: Thoroughly review the provided documents to extract useful information relevant to the question. 40 | 2. **Decision Categories:** 41 | - After your analysis, you should respond with: 42 | [LLM] 43 | 3. **Focus on Action**: Focus solely on identifying and analyzing relevant information rather than answering the question directly. 44 | 4. **Output Format**: 45 | Thought: [Provide a detailed analysis outlining the useful information.] 46 | Action: [LLM]""" 47 | 48 | 49 | decide_next_step_LLM = decide_next_step_LLM_prefix + "\n\n" + decide_next_step_suffix 50 | decide_next_step_LLM_few_shot = decide_next_step_LLM_prefix + "\n\n" + few_shot_template + "\n\n" + decide_next_step_suffix 51 | -------------------------------------------------------------------------------- /C-3PO/prompt/decide_prompt.py: -------------------------------------------------------------------------------- 1 | 2 | DECIDE_PROMPT_QP = """You are tasked with selecting the most appropriate prompt from a list of candidates based on a given question and passages. Your role is to evaluate the suitability of each prompt in relation to the provided context. 3 | 4 | Instructions: 5 | 1. You will receive a question, a set of passages, and a list of candidate prompts. 6 | 2. Assess each prompt for its relevance and suitability in addressing the given question, taking into account the provided passages if they are available. 7 | 3. **Do not** attempt to answer the question or modify the prompts; your responsibility is solely to evaluate and select. 8 | 4. Output the index of the prompt that you believe is the best fit for the provided question and passages. The output should be in the form of a single integer. 9 | 10 | Context: 11 | - Question: {questions} 12 | - Passages: 13 | {passages} 14 | - Candidate Prompts: 15 | {prompts} 16 | 17 | Considerations: 18 | - Evaluate the prompts based on clarity, relevance, and potential effectiveness in guiding an LLM to generate a high-quality response to the question. 
19 | - Consider how well each prompt aligns with the main themes or concepts in the provided question and passages. 20 | 21 | Please provide your output below:""" 22 | 23 | 24 | DECIDE_PROMPT_Q = """Your task is to select the most appropriate prompt based on a given question. You need to evaluate the suitability of each prompt to effectively address the provided question. 25 | 26 | Instructions: 27 | 1. You will receive a question and a list of candidate prompts. 28 | 2. Assess each prompt for its relevance and effectiveness in addressing the given question. 29 | 3. **Do not** attempt to answer the question or modify the prompts; your responsibility is solely to evaluate and select. 30 | 4. Output the index of the prompt that you believe is the best fit for the provided question. The output should be in the form of a single integer enclosed in square brackets. 31 | 32 | Context: 33 | - Question: {questions} 34 | - Candidate Prompts: 35 | {prompts} 36 | 37 | Considerations: 38 | - Evaluate the prompts on the basis of clarity, relevance, and their potential effectiveness in guiding an LLM to generate a high-quality response to the question. 39 | 40 | Please provide your output below:""" 41 | -------------------------------------------------------------------------------- /C-3PO/prompt/decide_prompt_identity.py: -------------------------------------------------------------------------------- 1 | NO_DOCUMENT_PROMPT = """{dataset_instructions}\nBased on your knowledge, answer the question:\n{question}""" 2 | 3 | DOCUMENT_PROMPT = """Existing documents: {documents}\n\n{dataset_instructions}\nBased on your knowledge and the provided information, answer the question:\n{question}""" 4 | 5 | POPQA_DECISION = """Note that the question mainly asks about the object entity that holds a certain relationship with the given subject entity. There may be multiple correct answers. Make sure your response includes all correct answers and provides clear reasoning details followed by a concise conclusion.""" 6 | OTHERS_DECISION = """Note that the question may be compositional and require intermediate analysis to deduce the final answer. Make sure your response is grounded and provides clear reasoning details followed by a concise conclusion.""" 7 | 8 | instruct_for_each_dataset = { 9 | "PopQA": POPQA_DECISION, 10 | "2WikiMultiHopQA": OTHERS_DECISION, 11 | "NaturalQuestions": OTHERS_DECISION, 12 | "TriviaQA": OTHERS_DECISION, 13 | "hotpotqa": OTHERS_DECISION, 14 | "Musique": OTHERS_DECISION, 15 | } 16 | 17 | 18 | -------------------------------------------------------------------------------- /C-3PO/prompt/decide_prompt_identity_v2.py: -------------------------------------------------------------------------------- 1 | NO_DOCUMENT_PROMPT = """{dataset_instructions}\nBased on your knowledge, answer the question:\n{question}""" 2 | 3 | DOCUMENT_PROMPT_PREFIX = """You are a knowledgeable assistant. Please answer the following question by: 4 | 1. First reviewing the provided documents. Please extract the relevant information and ignore irrelevant information about the question. 5 | 2. Combining relevant information from the documents (if any) with your own knowledge to generate a response. 6 | 3. 
{dataset_instructions}""" 7 | 8 | few_shot_template = """Here are some examples: 9 | {examples}""" 10 | 11 | DOCUMENT_PROMPT_SUFFIX = """Now, process the following question: 12 | Existing documents: {documents}\n\nQuestion: {question}""" 13 | 14 | 15 | DOCUMENT_PROMPT = DOCUMENT_PROMPT_PREFIX + "\n\n" + DOCUMENT_PROMPT_SUFFIX 16 | DOCUMENT_PROMPT_FEW_SHOT = DOCUMENT_PROMPT_PREFIX + "\n\n" + few_shot_template + "\n\n" + DOCUMENT_PROMPT_SUFFIX 17 | 18 | POPQA_DECISION = """Note that the question mainly asks about the object entity that holds a certain relationship with the given subject entity. There may be multiple correct answers. Make sure your response includes all correct answers and provides clear reasoning details followed by a concise conclusion.""" 19 | OTHERS_DECISION = """Note that the question may be compositional and require intermediate analysis to deduce the final answer. Make sure your response is grounded and provides clear reasoning details followed by a concise conclusion.""" 20 | 21 | instruct_for_each_dataset = { 22 | "PopQA": POPQA_DECISION, 23 | "2WikiMultiHopQA": OTHERS_DECISION, 24 | "NaturalQuestions": OTHERS_DECISION, 25 | "TriviaQA": OTHERS_DECISION, 26 | "hotpotqa": OTHERS_DECISION, 27 | "Musique": OTHERS_DECISION, 28 | } 29 | 30 | -------------------------------------------------------------------------------- /C-3PO/prompt/decision.py: -------------------------------------------------------------------------------- 1 | prompt_decision_making_prefix = """You are an intelligent assistant tasked with evaluating whether a given question requires further information through retrieval or needs planning to arrive at an accurate answer. You will have access to a large language model (LLM) for planning or answering the question and a retrieval system to provide relevant information about the query. 2 | 3 | Instructions: 4 | 1. **Evaluate the Question**: Assess whether a precise answer can be provided based on the existing knowledge of LLM. Consider the specificity, complexity, and clarity of the question. 5 | 2. **Decision Categories:** 6 | - If the question is complex and requires a planning phase before retrieval, your response should be: 7 | [Planning] 8 | - If the question requests specific information that you believe the LLM does not possess or pertains to recent events or niche topics outside LLM's knowledge scope, format your response as follows: 9 | [Retrieval] `YOUR QUERY HERE` 10 | - If you think the LLM can answer the question without additional information, respond with: 11 | [No Retrieval] 12 | 3. **Focus on Assessment**: Avoid providing direct answers to the questions. Concentrate solely on determining the necessity for retrieval or planning.{dataset_instructions}""" # {dataset_instructions} 13 | 14 | prompt_decision_making_suffix_few_shot = """Now, process the following question:\n\nQuestion: {question}\n""" 15 | 16 | decision_human_input = "Question: {question}\n" 17 | 18 | few_shot_template = """Here are some examples: 19 | {examples}""" 20 | 21 | prompt_decision_making = prompt_decision_making_prefix + "\n\n" + prompt_decision_making_suffix_few_shot 22 | 23 | prompt_decision_making_few_shot = prompt_decision_making_prefix + "\n\n" + few_shot_template + "\n\n" + prompt_decision_making_suffix_few_shot 24 | 25 | POPQA_DECISION = """\n4. Keep in mind that the question mainly asks about the object entity that holds a certain relationship with the given subject entity. There may be multiple correct answers.""" 26 | OTHERS_DECISION = """\n4. 
Keep in mind that the question may be compositional.""" 27 | 28 | instruct_for_each_dataset = { 29 | "PopQA": POPQA_DECISION, 30 | "2WikiMultiHopQA": OTHERS_DECISION, 31 | "NaturalQuestions": OTHERS_DECISION, 32 | "TriviaQA": OTHERS_DECISION, 33 | "hotpotqa": OTHERS_DECISION, 34 | "Musique": OTHERS_DECISION, 35 | } 36 | 37 | 38 | 39 | 40 | if __name__ == "__main__": 41 | print(prompt_decision_making_few_shot) 42 | -------------------------------------------------------------------------------- /C-3PO/prompt/evaluation.py: -------------------------------------------------------------------------------- 1 | evaluation_prompt_prefix = """You are a precise answer validator. Your task is to compare the predicted answer with a set of acceptable correct answers and determine if the prediction matches any of them. 2 | 3 | Input format: 4 | Question: [The question text] 5 | Correct Answers: [Array or list of acceptable correct answers] 6 | Predicted Answer: [The answer to be evaluated] 7 | 8 | Rules: 9 | 1. Consider semantic equivalence, not just exact string matching 10 | 2. Ignore minor differences in formatting, spacing, or capitalization 11 | 3. For numerical answers, consider acceptable margin of error if applicable 12 | 4. For text answers, focus on the core meaning rather than exact wording 13 | 5. The predicted answer is considered correct if it matches ANY ONE of the provided correct answers 14 | 6. The matching can be exact or semantically equivalent to any of the correct answers 15 | 7. Return only "True" if the predicted answer is correct, or "False" if it is incorrect.""" 16 | 17 | few_shot_template = """Here are some examples: 18 | {examples}""" 19 | 20 | evaluation_prompt_suffix = """Now, process the following question: 21 | Question: {question} 22 | Correct Answer: {true_answer} 23 | Predicted Answer: {long_answer}\n""" 24 | 25 | 26 | evaluation_prompt = evaluation_prompt_prefix + "\n\n" + evaluation_prompt_suffix 27 | evaluation_prompt_few_shot = evaluation_prompt_prefix + "\n\n" + few_shot_template + "\n\n" + evaluation_prompt_suffix 28 | 29 | 30 | 31 | if __name__ == "__main__": 32 | query = evaluation_prompt.format( 33 | question="When was Philip, Count Of Egmont's father born?", 34 | true_answer="['18 November 1522']", 35 | long_answer="The documents provided do not contain information about Philip, Count of Egmont. However, they do provide information about Lamoral, Count of Egmont, who was born on November 18, 1522. Lamoral's father, John IV, is mentioned, but his birthdate is not provided. Therefore, it is not possible to determine when Philip, Count of Egmont's father was born based on the given documents. It should be noted that there seems to be a confusion in the question, as the documents only provide information about Lamoral, Count of Egmont, and not Philip, Count of Egmont.", 36 | ) 37 | print() 38 | 39 | "Philip, Count of Egmont's father was Lamoral, Count of Egmont. According to Document 1 and Document 3, Lamoral, Count of Egmont was born on November 18, 1522. Therefore, Philip, Count of Egmont's father was born on November 18, 1522." 
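A minimal sketch of how the validator prompt above might be wired to a judge model; call_llm is a hypothetical stand-in for whatever client serves the judge, and anything other than a literal "True" reply is counted as incorrect:

    def call_llm(prompt: str) -> str:  # hypothetical LLM client
        return "True"

    def judge(question, true_answers, prediction) -> bool:
        # Fill the template defined above and map the reply to a boolean.
        query = evaluation_prompt.format(
            question=question,
            true_answer=str(true_answers),
            long_answer=prediction,
        )
        return call_llm(query).strip() == "True"

    print(judge("When was Philip, Count Of Egmont's father born?",
                ["18 November 1522"], "He was born on November 18, 1522."))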
-------------------------------------------------------------------------------- /C-3PO/prompt/few_shots/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chen-GX/C-3PO/e9a8021521a9eb006b7581b8c672dc3838cfc66c/C-3PO/prompt/few_shots/__init__.py -------------------------------------------------------------------------------- /C-3PO/prompt/few_shots/decide_next_step/TriviaQA.py: -------------------------------------------------------------------------------- 1 | EXAMPLES = { 2 | "no_existing_documents": [ 3 | # first step 4 | 5 | 6 | 7 | 8 | ], 9 | 10 | "has_existing_documents": [ 11 | 12 | 13 | ], 14 | 15 | "LLM": [ 16 | 17 | 18 | ] 19 | } 20 | 21 | 22 | """Existing Documents: \n\nPlan: \n\nQuestion: \n\nThought: \nAction: [Retrieval]""", # Russell Rouse -------------------------------------------------------------------------------- /C-3PO/prompt/few_shots/decide_next_step/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chen-GX/C-3PO/e9a8021521a9eb006b7581b8c672dc3838cfc66c/C-3PO/prompt/few_shots/decide_next_step/__init__.py -------------------------------------------------------------------------------- /C-3PO/prompt/few_shots/decision/2WikiMultiHopQA.py: -------------------------------------------------------------------------------- 1 | # decision_prompt = { 2 | # "no_retrieve": [ 3 | # # [No Retrieval] 4 | # """Question: Are the directors of films Esterina and Fence Riders both from the same country?\n[No Retrieval]""", 5 | # """Question: Which award the performer of song Proud Mary earned?\n[No Retrieval]""", 6 | # """Question: Where was the director of film Lost In Siberia born?\n[No Retrieval]""", 7 | # ], 8 | # "retrieve": [ 9 | # # [Retrieval] 10 | # """Question: Are both bands, Chris Robinson Brotherhood and Outbreakband, from the same country?\n[Retrieval] Are Chris Robinson Brotherhood and Outbreakband from the same country?""", 11 | # """Question: Are both The Willowz and Deepfield from the same country?\n[Retrieval] Are both The Willowz and Deepfield from the same country?""", 12 | # # """Question: when was the first general election held in india?\n[Retrieval] when was the first general election held in india?""", 13 | # # """Question: Qbasic is the extension of which programming language?\n[Retrieval] Qbasic is the extension of which programming language?""", 14 | # ], 15 | # "planning": [ 16 | # # [Planning] 17 | # """Question: What nationality is the director of film The Caper Of The Golden Bulls?\n[Planning] To answer the question, we need to find information about the director of the film "The Caper of the Golden Bulls." Then we should determine which nationality is the director born using the retrieval.\nStep 1: Find the relevant documents that mention the film `The Caper of the Golden Bulls.`\nStep 2: Identify the director of the film from the retrieved documents.\nStep 3: Find the relevant information about `Which nationality is the director born?`.\nStep 4: Provide the answer based on the retrieved information.""", 18 | # """Question: When is Dominic Roco's father's birthday?\n[Planning] I don't have access to personal information about who is the father of Dominic Roco, including his birthday. 
However, we can retrieve information to answer the question about Dominic Roco's father's birthday.\nSTEP 1: We should find the relevant documents about `Dominic Roco` to identify his father's name.\nSTEP 2: Identify the father's name from the retrieved documents.\nSTEP 3: Find the relevant information about the father's birthday.\nSTEP 4: Provide the answer based on the retrieved information.""", 19 | # """Question: Why did the director of film The Notorious Landlady die?\n[Planning] To answer the question, we need to find information about `the director of film The Notorious Landlady`. Then, we should retrieve the cause of the director's death.\nSTEP 1: Retrieve the relevant documents about `the director of film The Notorious Landlady`.\nSTEP 2: Identify the director of the film from the retrieved documents.\nSTEP 3: Find the relevant information about `Why did the director die?`.\nSTEP 4: Finish the answer based on the retrieved information.""", 20 | # """Question: Where was the place of death of the director of film Magic Mirror (Film)?\n[Planning] We know that the film "Magic Mirror" was directed by Manoel de Oliveira. To answer the question, we need to find information about the place of death of Manoel de Oliveira.\nSTEP 1: Retrieve the relevant documents about `Where was the place of death of Manoel de Oliveira?`.\nSTEP 2: Find the place of death of Manoel de Oliveira from the retrieved documents and provide the answer.""", 21 | # ] 22 | # } 23 | 24 | 25 | 26 | EXAMPLES = { 27 | "no_retrieve": [ 28 | # [No Retrieval] 29 | """Question: Do both films The Falcon (Film) and Valentin The Good have the directors from the same country?\n[No Retrieval]""", 30 | """Question: Where did the director of film Ride The Man Down die?\n[No Retrieval]""", 31 | """Question: Which film has the director born later, The Countess Of Parma or Prem Bandhan?\n[No Retrieval]""", 32 | """Question: Did the movies Karılar Koğuşu and A Pizza In Jordbro, originate from the same country?\n[No Retrieval]""", 33 | """Question: What is the place of birth of the composer of song Gretchen Am Spinnrade?\n[No Retrieval]""", 34 | """Question: Were Mary Schiavo and Faisal Al-Dakhil from the same country?\n[No Retrieval]""", 35 | """Question: Are the movies The Market Of Vain Desire and Asokamala, from the same country?\n[No Retrieval]""", 36 | ], 37 | "retrieve": [ 38 | # [Retrieval] 39 | """Question: Who is the stepchild of Lysicles (5Th Century Bc)?\n[Retrieval] Who is the stepchild of Lysicles (5Th Century Bc)?""", 40 | """Question: Who is the father-in-law of Infanta Blanca Of Spain?\n[Retrieval] Who is the father-in-law of Infanta Blanca Of Spain?""", 41 | """Question: Which film was released first, Sweet And Twenty or Caravan Of Death (Film)?\n[Retrieval] Which film was released first, Sweet And Twenty or Caravan Of Death (Film)?""", 42 | """Question: What is the date of birth of E. C. Spykman's husband?\n[Retrieval] What is the date of birth of E. C. 
Spykman's husband?""", 43 | """Question: Why did the performer of song Someday My Day Will Come die?\n[Retrieval] Why did the performer of song Someday My Day Will Come die?""", 44 | """Question: Do director of film The Nines and director of film The Sea Wolf (1920 Film) share the same nationality?\n[Retrieval] Do director of film The Nines and director of film The Sea Wolf (1920 Film) share the same nationality?""", 45 | """Question: Are the directors of films Joe (2013 Film) and Boynton Beach Club both from the same country?\n[Retrieval] Are the directors of films Joe (2013 Film) and Boynton Beach Club both from the same country?""", 46 | """Question: Which film has the director born earlier, Raja Kumarudu or Into The Abyss (Film)?\n[Retrieval] Which film has the director born earlier, Raja Kumarudu or Into The Abyss (Film)?""", 47 | """Question: What is the place of birth of the director of film The Road To Denver?\n[Retrieval] What is the place of birth of the director of film The Road To Denver?""", 48 | """Question: What is the date of death of Charles-René D'Hozier's father?\n[Retrieval] What is the date of death of Charles-René D'Hozier's father?""", 49 | 50 | ], 51 | "planning": [ 52 | # [Planning] 53 | """Question: When was the director of film My Official Wife (1914 Film) born?\n[Planning]""", 54 | """Question: Which film whose director is younger, Men Without Law or Headlines (1925 Film)?\n[Planning]""", 55 | """Question: What is the place of birth of the director of film Martha, Meet Frank, Daniel And Laurence?\n[Planning]""", 56 | """Question: Where was the father of Mirjam Finkelstein born?\n[Planning]""", 57 | """Question: What is the date of death of Joan Of Dampierre's mother?\n[Planning]""", 58 | """Question: Which film whose director is younger, The Vein or Judgment Deferred?\n[Planning]""", 59 | """Question: What is the date of death of the director of film Little Man, What Now? (1933 Film)?\n[Planning]""", 60 | """Question: Which film was released first, Welcome To Home Gori or Good Sam?\n[Planning]""", 61 | """Question: Where did the director of film Happy Ghost Iv study?\n[Planning]""", 62 | """Question: What is the date of death of the director of film The Sporting Lover?\n[Planning]""", 63 | """Question: Who is Prince Moulay Rachid Of Morocco's paternal grandmother?\n[Planning]""", 64 | """Question: Why did the director of film Being Respectable die?\n[Planning]""", 65 | """Question: Are both Scatterwood Lake and Hazeltine Lake located in the same country?\n[Planning]""", 66 | """Question: When is the director of film Mickey One 's birthday?\n[Planning]""", 67 | """Question: Are both Liege/Cnrl Aerodrome and Deer Lake Airport located in the same country?\n[Planning]""", 68 | """Question: When did the director of film Mutthu Ondu Mutthu die?\n[Planning]""", 69 | """Question: Where did the director of film Two Tickets To Broadway graduate from?\n[Planning]""", 70 | """Question: Who is the paternal grandfather of John Iv, Count Of Soissons?\n[Planning]""", 71 | ] 72 | } -------------------------------------------------------------------------------- /C-3PO/prompt/few_shots/decision/NaturalQuestions.py: -------------------------------------------------------------------------------- 1 | EXAMPLES = { 2 | "no_retrieve": [ 3 | # [No Retrieval] 4 | """Question: who said a plague on both your houses\n[No Retrieval]""", 5 | """Question: when did the song mr. 
sandman come out\n[No Retrieval]""",
6 |         """Question: when did ringling brothers merge with barnum and bailey\n[No Retrieval]""",
7 |         """Question: what were the two opposing sides in china 's civil war\n[No Retrieval]""",
8 |         """Question: when did pawn stars first air on tv\n[No Retrieval]""",
9 |         """Question: who is the original singer of how sweet it is to be loved by you\n[No Retrieval]""",
10 |         """Question: where did they film transformers age of extinction\n[No Retrieval]""",
11 |         """Question: the portion of the uterine endometrium that is shed every month is the\n[No Retrieval]""",
12 |         """Question: once upon a time in mumbaai based on whose story\n[No Retrieval]""",
13 |
14 |     ],
15 |
16 |     "retrieve": [
17 |         # [Retrieval]
18 |         """Question: who died in the beginning of fast and furious\n[Retrieval] who died in the beginning of fast and furious""",
19 |         """Question: who appeared on saturday night live when adele was the musical guest in 2008\n[Retrieval] who appeared on saturday night live when adele was the musical guest in 2008""",
20 |         """Question: who has the most nba championships in the nba\n[Retrieval] who has the most nba championships in the nba""",
21 |         """Question: the american academician who taught the first sociology courses in the united states was\n[Retrieval] the american academician who taught the first sociology courses in the united states was""", # ['William Graham Sumner']
22 |         """Question: where does cape towns water supply come from\n[Retrieval] where does cape towns water supply come from""",
23 |         """Question: when does 47 meters down come out in uk\n[Retrieval] when does 47 meters down come out in uk""", # ['26 July 2017']
24 |         """Question: who sang i want to shake you down\n[Retrieval] who sang i want to shake you down""", # Gregory Abbott
25 |         """Question: when was sound captured for the first time\n[Retrieval] when was sound captured for the first time""", # filtering required
26 |         """Question: where does the last name barnes come from\n[Retrieval] where does the last name barnes come from""", # filtering required
27 |
28 |
29 |     ],
30 |
31 |     "planning": [
32 |         # [Planning]
33 |         """Question: what is your dragon from how to train your dragon\n[Planning]""",
34 |
35 |         # from other dataset
36 |         """Question: When was the director of film My Official Wife (1914 Film) born?\n[Planning]""",
37 |         """Question: Which film whose director is younger, Men Without Law or Headlines (1925 Film)?\n[Planning]""",
38 |         """Question: What is the place of birth of the director of film Martha, Meet Frank, Daniel And Laurence?\n[Planning]""",
39 |         """Question: Who was the lead singer of the manhattans?\n[Planning]""",
40 |         """Question: What is the name of the matchmaker in fiddler?\n[Planning]""",
41 |         """Question: Who was the composer of Two?\n[Planning]""",
42 |         """Question: What genre is Foxy Brown?\n[Planning]""", # Pete Townshend
43 |
44 |     ]
45 | }
--------------------------------------------------------------------------------
/C-3PO/prompt/few_shots/decision/PopQA.py:
--------------------------------------------------------------------------------
1 | EXAMPLES = {
2 |     "no_retrieve": [
3 |         # [No Retrieval]
4 |         """Question: Who was the producer of Roy?\n[No Retrieval]""",
5 |         """Question: What is the capital of Curaçao?\n[No Retrieval]""",
6 |         """Question: What is the capital of County Kilkenny?\n[No Retrieval]""",
7 |         """Question: Who was the screenwriter for These Three?\n[No Retrieval]""",
8 |         """Question: Who is the author of Fihi Ma Fihi?\n[No Retrieval]""",
9 |         """Question: In what country is Port Authority Trans-Hudson?\n[No Retrieval]""",
10 |         """Question: What is the religion of Rama?\n[No Retrieval]""",
11 |         """Question: What sport does Rugby Africa play?\n[No Retrieval]""",
12 |         """Question: Who is the father of Anil Kapoor?\n[No Retrieval]""",
13 |         """Question: Who was the director of Sugar?\n[No Retrieval]""",
14 |
15 |     ],
16 |
17 |     "retrieve": [
18 |         # [Retrieval]
19 |         """Question: Who was the director of Skull Heads?\n[Retrieval] Who was the director of Skull Heads?""",
20 |         """Question: What genre is 454 Big Block?\n[Retrieval] What genre is 454 Big Block?""",
21 |         """Question: Who was the screenwriter for The Sea Inside?\n[Retrieval] Who was the screenwriter for The Sea Inside?""",
22 |         """Question: Who was the screenwriter for Greed?\n[Retrieval] Who was the screenwriter for Greed?""", # ['William Graham Sumner']
23 |         """Question: In what city was Jack Kachkar born?\n[Retrieval] In what city was Jack Kachkar born?""",
24 |         """Question: What genre is World Trade?\n[Retrieval] What genre is World Trade?""", # ['26 July 2017']
25 |         """Question: What is Sarai the capital of?\n[Retrieval] What is Sarai the capital of?""", # Gregory Abbott
26 |         """Question: Who is the author of Voyage d'Egypte et de Nubie?\n[Retrieval] Voyage d'Egypte et de Nubie""", # filtering required
27 |         # """Question: Who was the producer of La Mission?\n[Retrieval] Who was the producer of La Mission?""", # filtering required
28 |         # """Question: What sport does Paulo Grilo play?\n[Retrieval]""", # ['association football', 'football', 'soccer']
29 |         """Question: In what city was Barbara Harris born?\n[Retrieval] Barbara Harris""", # ['Philadelphia', 'Philly', 'City of Brotherly Love', 'Cradle of Liberty', 'Philadelphia, Pennsylvania', 'City of Philadelphia', 'Philadelphia, PA']
30 |         """Question: What is Jacopo Melani's occupation?\n[Retrieval] Jacopo Melani""",
31 |         """Question: What is the religion of Juan Soldevilla y Romero?\n[Retrieval] What is the religion of Juan Soldevilla y Romero?""",
32 |         """Question: What is Laishevo the capital of?\n[Retrieval] What is Laishevo the capital of?""",
33 |
34 |
35 |     ],
36 |
37 |     "planning": [
38 |         # [Planning]
39 |         """Question: Who was the composer of Two?\n[Planning]""",
40 |         """Question: What genre is Foxy Brown?\n[Planning]""", # Pete Townshend
41 |         """Question: Who was the producer of The Piano?\n[Planning]""", # Robert Southey
42 |         # """Question: What is the religion of Francis?\n[Planning]""",
43 |         # """Question: Who is the author of The Search?\n[Planning]""",
44 |         # """Question: Who was the screenwriter for The Bench?\n[Planning]""", # cannot be retrieved
45 |         """Question: Who was the director of The Wedding?\n[Planning]""",
46 |     ]
47 | }
--------------------------------------------------------------------------------
/C-3PO/prompt/few_shots/decision/TriviaQA.py:
--------------------------------------------------------------------------------
1 | EXAMPLES = {
2 |     "no_retrieve": [
3 |         # [No Retrieval]
4 |         """Question: Who was the first head of the Church of England?\n[No Retrieval]""",
5 |         """Question: 'Roadhouse' is the ironic US Secret Service codename for what famous standard-setting luxury Art Deco hotel at 301 Park Avenue in Manhattan, New York City?\n[No Retrieval]""",
6 |         """Question: Meaning literally 'make like' what is the full Latin word from which 'fax' derives (as in fax machine)?\n[No Retrieval]""",
7 |         """Question: What does the Latin phrase ‘ab initio’ translate to in English?\n[No Retrieval]""",
8 |         """Question: In which country is the town of
Gorgonzola?\n[No Retrieval]""", 9 | """Question: Most of the Ozark Plateau is in which two US states?\n[No Retrieval]""", 10 | """Question: What was the name of the dog who accompanied the Three Men In A Boat?\n[No Retrieval]""", 11 | """Question: In poetry, a quatrain is a stanza or complete poem consisting of how many lines of verse?\n[No Retrieval]""", 12 | """Question: June 21, 1973, saw the US Supreme Court establish the Miller Test, which determines whether something is, or isnt, what?\n[No Retrieval]""", 13 | """Question: Who had a hit with 'What becomes of the broken hearted' in 1966 and again in 1974?\n[No Retrieval]""", 14 | 15 | ], 16 | 17 | "retrieve": [ 18 | # [Retrieval] 19 | """Question: Who is the Chief Constable of the Greater Manchester Police Force?\n[Retrieval] Who is the Chief Constable of the Greater Manchester Police Force?""", 20 | """Question: "You get nothing for a pair" was a Bruce Forsyth catchphrase in which programme?\n[Retrieval] "You get nothing for a pair" was a Bruce Forsyth catchphrase in which programme?""", 21 | """Question: Who wrote How to Cheat at Cooking published in 1971?\n[Retrieval] Who wrote How to Cheat at Cooking published in 1971?""", 22 | """Question: Glassed-eyed member of the 'Rat Pack'?\n[Retrieval] Glassed-eyed member of the 'Rat Pack'?""", 23 | """Question: South African rugby winger Bryan Habana was named after which famous English sportsman?\n[Retrieval] South African rugby winger Bryan Habana was named after which famous English sportsman?""", 24 | """Question: Who went ‘Beyond Breaking Point’ in a Sport Relief challenge in March?\n[Retrieval] Who went ‘Beyond Breaking Point’ in a Sport Relief challenge in March?""", 25 | """Question: What is the surname of the character played by Joanna Lumley in 'Absolutely Fabulous'?\n[Retrieval] What is the surname of the character played by Joanna Lumley in 'Absolutely Fabulous'?""", 26 | """Question: What is the longest running show staged at London's Royal Drury Lane theatre?\n[Retrieval] What is the longest running show staged at London's Royal Drury Lane theatre?""", 27 | """Question: Bushido, developed between the 9th and 20th centuries relates to which culture?\n[Retrieval] Bushido, developed between the 9th and 20th centuries relates to which culture?""", 28 | 29 | """Question: 'Bloody Mary' refers to which queen of England?\n[Retrieval] 'Bloody Mary' refers to which queen of England?""", # ['mary 1', 'Mary 1'] 30 | """Question: Which service station is on the M6 toll motorway?\n[Retrieval] Which service station is on the M6 toll motorway?""", # ['norton caines', 'Norton Caines'] 31 | 32 | """Question: In the USA it's the Oscars what is it in France?\n[Retrieval] In the USA it's the Oscars what is it in France?""", # ['caesars', 'caesar disambiguation', 'Caesar (disambiguation)', 'Casear', 'casear', 'Cesear', 'Caesare', 'Caesaros', 'Ceaser', 'caeser', 'caesaros', 'cæsar', 'Caesars', 'caesare', ...] 33 | 34 | """Question: Arnova, made by the French corporation Archos, is an 'entry level' brand of what?\n[Retrieval] Arnova, made by the French corporation Archos""", 35 | 36 | 37 | ], 38 | 39 | "planning": [ 40 | # [Planning] 41 | # """Question: On the London underground only one station contains a single vowel. Which station?\n[Planning]""", # ['Banking system', '🏦', 'Banking business', 'Banking industry', '⛻', 'Credit institutions', 'Bank', 'Money-lenders', 'Banking establishment', 'banker', 'Credit Institutions', 'Monetary intermediation', 'credit institution', 'Banks and Banking', ...] 
42 | # """Question: "Mr Tom Piperson" features in which of Beatrix Potter\'s stories?\n[Planning]""", 43 | # """Question: Zebra, Panda, Pelican, and Puffin are types of UK what?\n[Planning]""", # ['pedestrian road crossings', 'Pedestrian road crossings'] 44 | 45 | 46 | # """Question: Who had Wings like a shield of steel?\n[Planning]""", # ['batfink', 'The Short Circuit Case', 'Batfink: This Is Your Life', 'short circuit case', 'Pink Pearl of Persia', 'batfink this is your life', 'Batfink', 'pink pearl of persia'] 47 | # from other dataset 48 | """Question: When was the director of film My Official Wife (1914 Film) born?\n[Planning]""", 49 | """Question: Which film whose director is younger, Men Without Law or Headlines (1925 Film)?\n[Planning]""", 50 | """Question: What is the place of birth of the director of film Martha, Meet Frank, Daniel And Laurence?\n[Planning]""", 51 | """Question: Who was the lead singer of the manhattans?\n[Planning]""", 52 | """Question: What is the name of the matchmaker in fiddler?\n[Planning]""", 53 | """Question: what is your dragon from how to train your dragon\n[Planning]""", 54 | """Question: Who was the composer of Two?\n[Planning]""", 55 | """Question: What genre is Foxy Brown?\n[Planning]""", # Pete Townshend 56 | 57 | ] 58 | } 59 | -------------------------------------------------------------------------------- /C-3PO/prompt/few_shots/decision/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chen-GX/C-3PO/e9a8021521a9eb006b7581b8c672dc3838cfc66c/C-3PO/prompt/few_shots/decision/__init__.py -------------------------------------------------------------------------------- /C-3PO/prompt/few_shots/decision/hotpotqa.py: -------------------------------------------------------------------------------- 1 | EXAMPLES = { 2 | "no_retrieve": [ 3 | # [No Retrieval] 4 | # """Question: Ugo Falena wrote the libretto for the opera "L\'ultimo Lord which is by an Italian composer and pianist who is known for having completed Puccini\'s opera Turandot in what year?\n[No Retrieval]""", 5 | # """Question: Katie Green, is an English model from Chichester, West Sussex, she initiated a campaign against size zero models with which Liberal Democrat, a British politician, a member of the Liberal Democrats, he served as a member of parliament (MP) representing the constituency of Montgomeryshire in Wales from 1997 until he lost his seat at the 2010 general election?\n[No Retrieval]""", 6 | # """Question: What is the name of this major American airline fron 1924 until 2001 that was owned by the originally named Trans World Corporation?\n[No Retrieval]""", 7 | # """Question: Of these two chief executives of the National Football League, Milton P. 
"Milt" Woodard and Allen Davis, who was born first?\n[No Retrieval]""", 8 | # """Question: The museum for which Mike Massimino is the senior advisor of space programs for is located in which neighborhood on the West Side of Manhattan?\n[No Retrieval]""", 9 | # """Question: Are Monster Beverage and Gilead Sciences both American companies?\n[No Retrieval]""", 10 | 11 | 12 | ], 13 | "retrieve": [ 14 | # [Retrieval] 15 | # """Question: What is the name of this Indian singer, who appeared in Comedy Circus Ke Taansen and recorded a dance song Main Tera Boyfriend with Meet Bros and Arijit Singh?\n[Retrieval] What is the name of this Indian singer, who appeared in Comedy Circus Ke Taansen and recorded a dance song Main Tera Boyfriend with Meet Bros and Arijit Singh?""", 16 | # """Question: What nationality is novelist Zadie Smith's husband?\n[Retrieval] What nationality is novelist Zadie Smith's husband?""", 17 | # """Question: Rock & Roll is an EP by a singer born in which year ?\n[Retrieval] Rock & Roll is an EP by a singer born in which year ?""", 18 | # """Question: What art center can be found near the ghost town in Bullfrog Hills and on the same property as a museum?\n[Retrieval] What art center can be found near the ghost town in Bullfrog Hills and on the same property as a museum?""", 19 | # """Question: Where were the teams home games played who represented the University of Kentucky in the Southeastern Conference and who's coach was a retired American football coach and former player ?\n[Retrieval] Where were the teams home games played who represented the University of Kentucky in the Southeastern Conference and who's coach was a retired American football coach and former player ?""", 20 | # """Question: Reveal dealt with the departure of a former drummer best known for which band?\n[Retrieval] Reveal dealt with the departure of a former drummer best known for which band?""", 21 | ], 22 | 23 | "planning": [ 24 | # [Planning] 25 | """Question: Who wrote the obituary for the man who created the \"Watch Mr. 
Wizard\" television programming?\n[Planning]""", 26 | """Question: Where both Games Magazine and The General published by Games Publications?\n[Planning]""", 27 | """Question: Which of the Lake Poets become Poet Laureate for 30 years from 1913 through 1943?\n[Planning]""", # Robert Southey 28 | """Question: What ABC series is the actress who played Mary in the 2017 American drama film about a 7 year old who ecomes the subject of a custody battle between her uncle and grandmother in?\n[Planning]""", 29 | """Question: what was the American singer that sing "We\'re Getting Stronger" known for \n[Planning]""", 30 | """Question: Who was the seventh president of the oldest higher education institution in suburban Long Island?\n[Planning]""", 31 | ] 32 | } 33 | -------------------------------------------------------------------------------- /C-3PO/prompt/few_shots/filter/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chen-GX/C-3PO/e9a8021521a9eb006b7581b8c672dc3838cfc66c/C-3PO/prompt/few_shots/filter/__init__.py -------------------------------------------------------------------------------- /C-3PO/prompt/few_shots/filter/template.py: -------------------------------------------------------------------------------- 1 | EXAMPLES = { 2 | "retrieve": [ 3 | 4 | 5 | ], 6 | 7 | "planning": [ 8 | 9 | 10 | 11 | ] 12 | } 13 | 14 | """Question: \n\nDocuments: \n\nThought: \n\nAction: [1]""", 15 | 16 | 17 | EXAMPLES = { 18 | "retrieve": [ 19 | 20 | 21 | ], 22 | 23 | "planning": [ 24 | 25 | 26 | ] 27 | } 28 | 29 | 30 | 31 | 32 | 33 | 34 | """Current step's objectives: \n\nQuestion: \n\nDocuments: \n\nThought: \n\nAction: [1]""", -------------------------------------------------------------------------------- /C-3PO/prompt/few_shots/planning/2WikiMultiHopQA.py: -------------------------------------------------------------------------------- 1 | EXAMPLES = [ 2 | """Question: What nationality is the director of film The Caper Of The Golden Bulls?\nTo answer the question, we need to find information about the director of the film "The Caper of the Golden Bulls." Then we should determine which nationality is the director born using the retrieval.\nStep 1: Retrieve the relevant documents that mention the film `The Caper of the Golden Bulls.`\nStep 2: Identify the director of the film from the retrieved documents.\nStep 3: Retrieve the relevant information about `Which nationality is the director born?`.\nStep 4: Provide the answer based on the retrieved information.""", 3 | """Question: When is Dominic Roco's father's birthday?\nI don't have access to personal information about who is the father of Dominic Roco, including his birthday. However, we can retrieve information to answer the question about Dominic Roco's father's birthday.\nSTEP 1: We should retrieve the relevant documents about `Dominic Roco` to identify his father's name.\nSTEP 2: Retrieve the relevant information about the father's birthday.\nSTEP 3: Provide the answer based on the retrieved information.""", 4 | """Question: Why did the director of film The Notorious Landlady die?\nTo answer the question, we need to find information about `the director of film The Notorious Landlady`. 
Then, we should retrieve the cause of the director's death.\nSTEP 1: Retrieve the relevant documents about `the director of film The Notorious Landlady`.\nSTEP 2: Identify the director of the film from the retrieved documents.\nSTEP 3: Retrieve the relevant information about `Why did the director die?`.\nSTEP 4: Finish the answer based on the retrieved information.""", 5 | """Question: Where was the place of death of the director of film Magic Mirror (Film)?\nWe know that the film "Magic Mirror" was directed by Manoel de Oliveira. To answer the question, we need to find information about the place of death of Manoel de Oliveira.\nSTEP 1: Retrieve the relevant documents about `Where was the place of death of Manoel de Oliveira?`.\nSTEP 2: Retrieve the place of death of Manoel de Oliveira from the retrieved documents and provide the answer.""", 6 | 7 | """Question: When was the director of film My Official Wife (1914 Film) born?\nTo address the question, we'll break it down into the following core components:\nIdentify the Film: We need to confirm that "My Official Wife" is indeed the film in question, as there may be multiple films with similar titles.\nIdentify the Director: We need to determine who directed the 1914 film "My Official Wife."\nFind Birth Date: The final goal is to ascertain the birth date of the identified director.\nSTEP 1: Search for detailed information about the film "My Official Wife" (1914). This includes its cast, crew, and particularly the director.\nSTEP 2: From the retrieved information about the film, extract the name of the director.\nSTEP 3: Once the director is identified, we need to search for biographical information on that person to find their birth date.\nSTEP 4: Finalize the response by compiling the birth date and ensuring it is attributed to the correct director of "My Official Wife" (1914).""", 8 | 9 | """Question: Which film whose director is younger, Men Without Law or Headlines (1925 Film)?\nThe question asks us to compare two films: "Men Without Law" and "Headlines (1925 Film)." The key focus is to determine which film's director is younger. We need to know the directors of both films and the birth years of the directors to compare their ages effectively.\nSTEP 1: Conduct a search for "Men Without Law" to find out its director.\nSTEP 2: Search information about the director and identify the birth year of the director of "Men Without Law"\nSTEP 3: Conduct a separate search for "Headlines (1925 Film)" to find out its director.\nSTEP 4: Search information about the director and identify the birth year of the director of "Headlines (1925 Film)"\nSTEP 5: Compare the birth years of the two directors to determine which director is younger. Provide the answer based on the comparison.""", 10 | 11 | """Question: What is the place of birth of the director of film Martha, Meet Frank, Daniel And Laurence?\nIdentify Core Components: Film Title: "Martha, Meet Frank, Daniel And Laurence". Key Subject: Director of the film. Information Requested: Place of birth of the director.
We need to identify who the director of the film is and then determine the director's place of birth.\nSTEP 1: Retrieve the relevant documents that mention the film "Martha, Meet Frank, Daniel And Laurence".\nSTEP 2: Identify the director of the film from the retrieved documents.\nSTEP 3: Retrieve the relevant information about the place of birth of the director.\nSTEP 4: Provide a clear and comprehensive answer to the original question.""", 12 | 13 | """Question: Where was the father of Mirjam Finkelstein born?\nQuestion Analysis:\n- Core Components:\nSubject: The father of Mirjam Finkelstein\nAction: Determine the birthplace of this individual\n- Additional Information Needed:\nDetails regarding Mirjam Finkelstein's family, particularly her father's background, including his name and place of birth.\nSTEP 1: Conduct a search to gather background information about Mirjam Finkelstein. This may include finding details about her family, specifically focusing on her father.\nSTEP 2: Once sufficient information about Mirjam is gathered, specifically locate the father's name. Conduct a separate search to find detailed information about her father's birthplace.\nSTEP 3: Compile the gathered data into a comprehensive format, ensuring that the birthplace of Mirjam Finkelstein's father is highlighted as the answer to the question.""", 14 | 15 | """Question: What is the date of death of Joan Of Dampierre's mother?\nQuestion Analysis: 16 | - Key components of the question: 17 | We need to find the date of death. 18 | The person whose death date we're looking for is the mother of Joan of Dampierre. 19 | - What we need to find out through retrieval: 20 | The identity of Joan of Dampierre's mother 21 | The date of death of Joan of Dampierre's mother 22 | STEP 1: Use the retrieval system to gather information about Joan of Dampierre and look for specific mentions of her mother's name. 23 | STEP 2: Once the mother's identity is established, use the retrieval system to gather information about her. Focus on finding biographical details, particularly the date of her death. 24 | STEP 3: Summarize the retrieved information gathered about the date of death of Joan of Dampierre's mother, along with any relevant context or additional details""", 25 | 26 | """Question: Which film was released first, Welcome To Home Gori or Good Sam?\nQuestion Analysis: 27 | a. Core components: 28 | Film 1: "Welcome To Home Gori" 29 | Film 2: "Good Sam" 30 | Comparison: Release dates 31 | b. Additional information needed through retrieval: 32 | Release date of "Welcome To Home Gori" 33 | Release date of "Good Sam" 34 | STEP 1: Use retrieval system to search for "Welcome To Home Gori" as a film title, and find its release date from the retrieved documents. 35 | STEP 2: Use retrieval system to search for "Good Sam" as a film title, and identify the release date of "Good Sam" from the retrieved documents. 36 | STEP 3: Summarize the release dates of both films and state which film was released first. Provide the answer based on the retrieved information.""", 37 | 38 | """Question: Where did the director of film Happy Ghost Iv study?\nThe question is asking about the educational background of the director of the film Happy Ghost IV.
The primary piece of information needed is the director's education history, specifically where they studied.\nSTEP 1: Retrieve the name of the director who directed Happy Ghost IV.\nSTEP 2: Once the director is identified, focus on retrieving information related to their academic history.\nSTEP 3: Identify the educational institution where the director studied and provide the answer based on the retrieved information.""", 39 | 40 | 41 | ] 42 | -------------------------------------------------------------------------------- /C-3PO/prompt/few_shots/planning/ASQA.py: -------------------------------------------------------------------------------- 1 | EXAMPLES = [ 2 | """Question: Who was the lead singer of the manhattans?\nKey information we know: 3 | a) The Manhattans were a musical group 4 | b) We're looking for information about the lead singer, which may have changed over time. 5 | STEP 1: Retrieve information about The Manhattans. 6 | STEP 2: Check if there were multiple lead singers or changes over time. Identify who was designated as the lead singer.""", 7 | 8 | """Question: What is the name of the matchmaker in fiddler?\nThe question is about a character in Fiddler (potentially referring to multiple productions of "Fiddler on the Roof"). The character's role is a matchmaker. Consideration of both character names and actor names. 9 | STEP 1: Use the retrieval system to search for "Fiddler" in relation to matchmakers. 10 | STEP 2: Identify the character who is a matchmaker in the context of "Fiddler on the Roof." 11 | STEP 3: Provide the name of the character (and the actor's name, if needed) who is a matchmaker in the production of "Fiddler on the Roof".""", 12 | 13 | """Question: When was the imf and world bank created?\nThe IMF and World Bank are key players in global economic stability and development. We need to identify the specific dates when these institutions were established.\nSTEP 1: Retrieve the founding dates of the International Monetary Fund (IMF) and identify the year it was established.\nSTEP 2: Retrieve the founding dates of the World Bank and identify the year it was established.\nSTEP 3: Provide the specific years when the IMF and World Bank were created based on the retrieved information.""", 14 | 15 | """Question: Who landed the first quad in figure skating?\nIt's an ambiguous question, due to the fact that the term "quad" can refer to different types of jumps in figure skating. The first quad jump in figure skating was landed by Kurt Browning in 1988. However, it was not in competition. Therefore, we should retrieve the information about `who landed the first quad in figure skating?` to provide a comprehensive and accurate response.\nSTEP 1: Retrieve some information about `who landed the first quad in figure skating?`.\nSTEP 2: Analyze the retrieved information to determine other types of quad jumps in figure skating.\nSTEP 3: Provide all the answers about the first quad jumps in figure skating.""", 16 | 17 | """Question: When is the movie thor ragnarok coming out?\n'Thor: Ragnarok' was released on November 3, 2017, in the United States. However, if you are in a different country, the release date might have varied slightly. Therefore, it's an ambiguous question, due to lack of specific information about the place.
We need to retrieve more information about `when is the movie thor ragnarok coming out?` to provide a comprehensive and accurate response.\nSTEP 1: Retrieve more information about `when is the movie thor ragnarok coming out?`.\nSTEP 2: Analyze different release dates in different regions and provide a comprehensive answer.""", 18 | 19 | """Question: Who has the most points per game in nhl history?\nIt's an ambiguous question. Clarify what is meant by 'the most points per game' - whether it refers to regular season only or includes playoffs, whether it refers to a player or a team, and something else. Therefore, we should retrieve the information about `the most points per game in NHL history` to provide a comprehensive and accurate response.\nSTEP 1: Retrieve some information about `the most points per game in NHL history`.\nSTEP 2: Analyze the retrieved information to identify key information to eliminate ambiguity.\nSTEP 3: Combine all sub-questions to provide a comprehensive answer.""", 20 | 21 | ] 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /C-3PO/prompt/few_shots/planning/NaturalQuestions.py: -------------------------------------------------------------------------------- 1 | EXAMPLES = [ 2 | """Question: what is your dragon from how to train your dragon\nQuestion Analysis: The question is about a dragon from the "How to Train Your Dragon". It's asking about a specific dragon that belongs to or is associated with the person being asked. We first consider the main characters and their associated dragons. 3 | STEP 1: Retrieve the information about the "what is your dragon from how to train your dragon". 4 | STEP 2: Identify the dragon mentioned in the retrieved documents. 5 | STEP 3: Provide the answer based on the retrieved information.""", 6 | 7 | # """Question: Which British rock musician moved to Polydor Records in 1966\nThe question is about a character in Fiddler (potentially referring to multiple productions of "Fiddler on the Roof"). The character's role is a matchmaker. Consideration of both character names and actor names. 8 | # STEP 1: Use the retrieval system to search for "Fiddler" in relation to matchmakers. 9 | # STEP 2: Identify the character who is a matchmaker in the context of "Fiddler on the Roof." 10 | # STEP 3: Provide the name of the character (and the actor's name, if needed) who is a matchmaker in the production of "Fiddler on the Roof".""", 11 | 12 | # """Question: When was the imf and world bank created?\nThe IMF and World Bank are key players in global economic stability and development. We need to identify the specific dates when these institutions were established.\nSTEP 1: Retrieve the founding dates of the International Monetary Fund (IMF) and identify the year it was established.\nSTEP 2: Retrieve the founding dates of the World Bank and identify the year it was established.\nSTEP 3: Provide the specific years when the IMF and World Bank were created based on the retrieved information.""", 13 | 14 | # """Question: Who landed the first quad in figure skating?\nIt's an ambiguous question, due to the fact that the term "quad" can refer to different types of jumps in figure skating. The first quad jump in figure skating was landed by Kurt Browning in 1988. However, it was not in competition.
Therefore, we should retrieve the information about `who landed the first quad in figure skating?` to provide a comprehensive and accurate response.\nSTEP 1: Retrieve some information about `who landed the first quad in figure skating?`.\nSTEP 2: Analyze the retrieved information to determine other types of quad jumps in figure skating.\nSTEP 3: Provide all the answers about the first quad jumps in figure skating.""", 15 | 16 | # """Question: When is the movie thor ragnarok coming out?\n'Thor: Ragnarok' was released on November 3, 2017, in the United States. However, if you are in a different country, the release date might have varied slightly. Therefore, it's an ambiguous question, due to lack of specific information about the place. We need to retrieve more information about `when is the movie thor ragnarok coming out?` to provide a comprehensive and accurate response.\nSTEP 1: Retrieve more information about `when is the movie thor ragnarok coming out?`.\nSTEP 2: Analyze different release dates in different regions and provide a comprehensive answer.""", 17 | 18 | # """Question: Who has the most points per game in nhl history?\nIt's an ambiguous question. Clarify what is meant by 'the most points per game' - whether it refers to regular season only or includes playoffs, whether it refers to a player or a team, and something else. Therefore, we should retrieve the information about `the most points per game in NHL history` to provide a comprehensive and accurate response.\nSTEP 1: Retrieve some information about `the most points per game in NHL history`.\nSTEP 2: Analyze the retrieved information to identify key information to eliminate ambiguity.\nSTEP 3: Combine all sub-questions to provide a comprehensive answer.""", 19 | 20 | ] 21 | 22 | 23 | -------------------------------------------------------------------------------- /C-3PO/prompt/few_shots/planning/PopQA.py: -------------------------------------------------------------------------------- 1 | EXAMPLES = [ 2 | """Question: Who was the composer of Two?\nQuestion Analysis: 3 | Main Subject: Composer 4 | Title: Two 5 | Ambiguity: The term "Two" could refer to a film, a musical, or another medium involving music. 6 | The term "Two" is not specifically defined in the question, meaning that multiple works could exist under this title. The context of the composer is essential for pinpointing the correct individual if "Two" refers to a specific work. 7 | Additional Information Needed: 8 | Determining if "Two" refers to a film, a piece of music, or another form of artistic expression. 9 | STEP 1: Conduct a retrieval search `the film or music Two` to determine the context of the question. 10 | STEP 2: Identify the composer of the work titled "Two".""", 11 | 12 | """Question: What genre is Foxy Brown?\nQuestion Analysis: Foxy Brown refers to both a character (from the 1974 film directed by Jack Hill) and an individual (the rapper/singer). The term "genre" could pertain to various contexts, such as film genre, musical genre, or even genre as it applies to specific works related to the character or singer. 13 | Additional Information Needed: If the question pertains to the film, genre classifications such as action, blaxploitation, or drama may need exploration. If it refers to the rapper Foxy Brown, investigating hip-hop, rap, and sub-genres of music may be necessary.
Clarification on whether the question is regarding the film or the singer would be helpful, but this may require retrieval of contextual information or assumptions based on common usage of the name. 14 | STEP 1: Conduct a retrieval search to gather the information about Foxy Brown. Identify the context of the question (film or singer). 15 | STEP 2: Retrieve the information of the genre associated with the identified context (film or rapper/singer). 16 | STEP 3: Provide the genre of Foxy Brown based on the retrieved information.""", 17 | 18 | """Question: Who was the producer of The Piano?\nQuestion Analysis: "The Piano" likely refers to a well-known film or music piece. We need to identify the individual or entity responsible for its production. If there are multiple interpretations of "The Piano," include all relevant information for a comprehensive response. 19 | Step 1: Conduct a retrieval search for "film The Piano" to determine if the context is film-related. 20 | Step 2: Conduct a retrieval search for "music The Piano" to determine if the context is music-related. 21 | Step 3: Based on the retrieved information, provide the name of the producer. If both film and music contexts are identified, provide producers for both; otherwise, focus on the relevant context specified by the search.""", 22 | 23 | """Question: Who was the director of The Wedding?\nQuestion Analysis: The title "The Wedding" is a common name for films and other media, which might refer to different productions across years. We need to retrieve relevant information about the film or media work titled "The Wedding" to identify the director. If there are multiple works with the same title, we should summarize all relevant works for a comprehensive response. 24 | STEP 1: Conduct a retrieval search for "The Wedding" to determine the context of the question. 25 | STEP 2: Summarize all works titled "The Wedding" and identify the director for each work to provide a comprehensive response.""", 26 | 27 | # """Question: When is the movie thor ragnarok coming out?\n'Thor: Ragnarok' was released on November 3, 2017, in the United States. However, if you are in a different country, the release date might have varied slightly. Therefore, it's an ambiguous question, due to lack of specific information about the place. We need to retrieve more information about `when is the movie thor ragnarok coming out?` to provide a comprehensive and accurate response.\nSTEP 1: Retrieve more information about `when is the movie thor ragnarok coming out?`.\nSTEP 2: Analyze different release dates in different regions and provide a comprehensive answer.""", 28 | 29 | # """Question: Who has the most points per game in nhl history?\nIt's an ambiguous question. Clarify what is meant by 'the most points per game' - whether it refers to regular season only or includes playoffs, whether it refers to a player or a team, and something else.
Therefore, we should retrieve the information about `the most points per game in NHL history` to provide a comprehensive and accurate response.\nSTEP 1: Retrieve some information about `the most points per game in NHL history`.\nSTEP 2: Analyze the retrieved information to identify key information to eliminate ambiguity.\nSTEP 3: Combine all sub-questions to provide a comprehensive answer.""", 30 | 31 | ] 32 | 33 | 34 | -------------------------------------------------------------------------------- /C-3PO/prompt/few_shots/planning/TriviaQA.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chen-GX/C-3PO/e9a8021521a9eb006b7581b8c672dc3838cfc66c/C-3PO/prompt/few_shots/planning/TriviaQA.py -------------------------------------------------------------------------------- /C-3PO/prompt/few_shots/planning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chen-GX/C-3PO/e9a8021521a9eb006b7581b8c672dc3838cfc66c/C-3PO/prompt/few_shots/planning/__init__.py -------------------------------------------------------------------------------- /C-3PO/prompt/few_shots/planning/hotpotqa.py: -------------------------------------------------------------------------------- 1 | EXAMPLES = [ 2 | """Question: What nationality is the director of film The Caper Of The Golden Bulls?\nTo answer the question, we need to find information about the director of the film "The Caper of the Golden Bulls." Then we should determine which nationality is the director born using the retrieval.\nStep 1: Retrieve the relevant documents that mention the film `The Caper of the Golden Bulls.`\nStep 2: Identify the director of the film from the retrieved documents.\nStep 3: Retrieve the relevant information about `Which nationality is the director born?`.\nStep 4: Provide the answer based on the retrieved information.""", 3 | """Question: When is Dominic Roco's father's birthday?\nI don't have access to personal information about who is the father of Dominic Roco, including his birthday. However, we can retrieve information to answer the question about Dominic Roco's father's birthday.\nSTEP 1: We should retrieve the relevant documents about `Dominic Roco` to identify his father's name.\nSTEP 2: Retrieve the relevant information about the father's birthday.\nSTEP 3: Provide the answer based on the retrieved information.""", 4 | """Question: Why did the director of film The Notorious Landlady die?\nTo answer the question, we need to find information about `the director of film The Notorious Landlady`. Then, we should retrieve the cause of the director's death.\nSTEP 1: Retrieve the relevant documents about `the director of film The Notorious Landlady`.\nSTEP 2: Identify the director of the film from the retrieved documents.\nSTEP 3: Retrieve the relevant information about `Why did the director die?`.\nSTEP 4: Finish the answer based on the retrieved information.""", 5 | 6 | """Question: Where was the place of death of the director of film Magic Mirror (Film)?\nWe know that the film "Magic Mirror" was directed by Manoel de Oliveira.
To answer the question, we need to find information about the place of death of Manoel de Oliveira.\nSTEP 1: Retrieve the relevant documents about `Where was the place of death of Manoel de Oliveira?`.\nSTEP 2: Retrieve the place of death of Manoel de Oliveira from the retrieved documents and provide the answer.""", 7 | 8 | """Question: When was the director of film My Official Wife (1914 Film) born?\nTo address the question, we'll break it down into the following core components:\nIdentify the Film: We need to confirm that "My Official Wife" is indeed the film in question, as there may be multiple films with similar titles.\nIdentify the Director: We need to determine who directed the 1914 film "My Official Wife."\nFind Birth Date: The final goal is to ascertain the birth date of the identified director.\nSTEP 1: Search for detailed information about the film "My Official Wife" (1914). This includes its cast, crew, and particularly the director.\nSTEP 2: From the retrieved information about the film, extract the name of the director.\nSTEP 3: Once the director is identified, we need to search for biographical information on that person to find their birth date.\nSTEP 4: Finalize the response by compiling the birth date and ensuring it is attributed to the correct director of "My Official Wife" (1914).""", 9 | 10 | # hotpotqa 11 | """Question: Who wrote the obituary for the man who created the \"Watch Mr. Wizard\" television programming?\nQuestion Analysis: 12 | Identify the Creator: The core component of the question is to identify the individual who created the "Watch Mr. Wizard" television program. 13 | Obtain Obituary Details: Once the creator is identified, find and retrieve the obituary that was written for this individual. 14 | Determine the Author of the Obituary: Finally, determine who authored the obituary for the creator of the program. 15 | Step 1: Retrieve information about `who created the \"Watch Mr. Wizard\" television programming?`. 16 | Step 2: Identify the creator of the program from the retrieved documents. 17 | Step 3: Retrieve the information about who wrote the obituary for the creator. 18 | Step 4: Determine the author of the obituary.""", 19 | 20 | """Question: Where both Games Magazine and The General published by Games Publications?\nKey Information Needed: 21 | Verification if "Games Magazine" was published by "Games Publications." 22 | Verification if "The General" was published by "Games Publications." 23 | Step 1: Retrieve Information about `Is Games Magazine published by Games Publications?`, and identify the publisher of Games Magazine. 24 | Step 2: Retrieve Information about `Is The General published by Games Publications?`, and identify the publisher of The General. 25 | Step 3: Compare the publishers of both magazines to determine if they were published by Games Publications.""", 26 | 27 | # """Question: Where was the father of Mirjam Finkelstein born?\nQuestion Analysis:\n- Core Components:\nSubject: The father of Mirjam Finkelstein\nAction: Determine the birthplace of this individual\n- Additional Information Needed:\nDetails regarding Mirjam Finkelstein's family, particularly her father's background, including his name and place of birth.\nSTEP 1: Conduct a search to gather background information about Mirjam Finkelstein. This may include finding details about her family, specifically focusing on her father.\nSTEP 2: Once sufficient information about Mirjam is gathered, specifically locate the father's name.
Conduct a separate search to find detailed information about her father's birthplace.\nSTEP 3: Compile the gathered data into a comprehensive format, ensuring that the birthplace of Mirjam Finkelstein's father is highlighted as the answer to the question.""", 28 | 29 | # """Question: What is the date of death of Joan Of Dampierre's mother?\nQuestion Analysis: 30 | # - Key components of the question: 31 | # We need to find the date of death. 32 | # The person whose death date we're looking for is the mother of Joan of Dampierre. 33 | # - What we need to find out through retrieval: 34 | # The identity of Joan of Dampierre's mother 35 | # The date of death of Joan of Dampierre's mother 36 | # STEP 1: Use the retrieval system to gather information about Joan of Dampierre and look for specific mentions of her mother's name. 37 | # STEP 2: Once the mother's identity is established, use the retrieval system to gather information about her. Focus on finding biographical details, particularly the date of her death. 38 | # STEP 3: Summarize the retrieved information gathered about the date of death of Joan of Dampierre's mother, along with any relevant context or additional details""", 39 | 40 | # """Question: Which film was released first, Welcome To Home Gori or Good Sam?\nQuestion Analysis: 41 | # a. Core components: 42 | # Film 1: "Welcome To Home Gori" 43 | # Film 2: "Good Sam" 44 | # Comparison: Release dates 45 | # b. Additional information needed through retrieval: 46 | # Release date of "Welcome To Home Gori" 47 | # Release date of "Good Sam" 48 | # STEP 1: Use retrieval system to search for "Welcome To Home Gori" as a film title, and find its release date from the retrieved documents. 49 | # STEP 2: Use retrieval system to search for "Good Sam" as a film title, and identify the release date of "Good Sam" from the retrieved documents. 50 | # STEP 3: Summarize the release dates of both films and state which film was released first. Provide the answer based on the retrieved information.""", 51 | 52 | # """Question: Where did the director of film Happy Ghost Iv study?\nThe question is asking about the educational background of the director of the film Happy Ghost IV. The primary piece of information needed is the director's education history, specifically where they studied.\nSTEP 1: Retrieve the name of the director who directed Happy Ghost IV.\nSTEP 2: Once the director is identified, focus on retrieving information related to their academic history.\nSTEP 3: Identify the educational institution where the director studied and provide the answer based on the retrieved information.""", 53 | 54 | 55 | ] 56 | -------------------------------------------------------------------------------- /C-3PO/prompt/few_shots/planning/public.py: -------------------------------------------------------------------------------- 1 | EXAMPLES = [ 2 | """Question: When was the director of film My Official Wife (1914 Film) born?\nTo address the question, we'll break it down into the following core components:\nIdentify the Film: We need to confirm that "My Official Wife" is indeed the film in question, as there may be multiple films with similar titles.\nIdentify the Director: We need to determine who directed the 1914 film "My Official Wife."\nFind Birth Date: The final goal is to ascertain the birth date of the identified director.\nSTEP 1: Search for detailed information about the film "My Official Wife" (1914).
This includes its cast, crew, and particularly the director.\nSTEP 2: From the retrieved information about the film, extract the name of the director.\nSTEP 3: Once the director is identified, we need to search for biographical information on that person to find their birth date.\nSTEP 4: Finalize the response by compiling the birth date and ensuring it is attributed to the correct director of "My Official Wife" (1914).""", 3 | 4 | """Question: Which film whose director is younger, Men Without Law or Headlines (1925 Film)?\nThe question asks us to compare two films: "Men Without Law" and "Headlines (1925 Film)." The key focus is to determine which film's director is younger. We need to know the directors of both films and the birth years of the directors to compare their ages effectively.\nSTEP 1: Conduct a search for "Men Without Law" to find out its director.\nSTEP 2: Search information about the director and identify the birth year of the director of "Men Without Law"\nSTEP 3: Conduct a separate search for "Headlines (1925 Film)" to find out its director.\nSTEP 4: Search information about the director and identify the birth year of the director of "Headlines (1925 Film)"\nSTEP 5: Compare the birth years of the two directors to determine which director is younger. Provide the answer based on the comparison.""", 5 | 6 | """Question: What is the place of birth of the director of film Martha, Meet Frank, Daniel And Laurence?\nIdentify Core Components: Film Title: "Martha, Meet Frank, Daniel And Laurence". Key Subject: Director of the film. Information Requested: Place of birth of the director. We need to identify who the director of the film is and then determine the director's place of birth.\nSTEP 1: Retrieve the relevant documents that mention the film "Martha, Meet Frank, Daniel And Laurence".\nSTEP 2: Identify the director of the film from the retrieved documents.\nSTEP 3: Retrieve the relevant information about the place of birth of the director.\nSTEP 4: Provide a clear and comprehensive answer to the original question.""", 7 | 8 | """Question: When was the imf and world bank created?\nThe IMF and World Bank are key players in global economic stability and development. We need to identify the specific dates when these institutions were established.\nSTEP 1: Retrieve the founding dates of the International Monetary Fund (IMF) and identify the year it was established.\nSTEP 2: Retrieve the founding dates of the World Bank and identify the year it was established.\nSTEP 3: Provide the specific years when the IMF and World Bank were created based on the retrieved information.""", 9 | 10 | """Question: Who landed the first quad in figure skating?\nIt's an ambiguous question, due to the fact that the term "quad" can refer to different types of jumps in figure skating. The first quad jump in figure skating was landed by Kurt Browning in 1988. However, it was not in competition. Therefore, we should retrieve the information about `who landed the first quad in figure skating?` to provide a comprehensive and accurate response.\nSTEP 1: Retrieve some information about `who landed the first quad in figure skating?`.\nSTEP 2: Analyze the retrieved information to determine other types of quad jumps in figure skating.\nSTEP 3: Provide all the answers about the first quad jumps in figure skating.""", 11 | 12 | """Question: what is your dragon from how to train your dragon\nQuestion Analysis: The question is about a dragon from the "How to Train Your Dragon".
It's asking about a specific dragon that belongs to or is associated with the person being asked. We first consider the main characters and their associated dragons. 13 | STEP 1: Retrieve the information about the "what is your dragon from how to train your dragon". 14 | STEP 2: Identify the dragon mentioned in the retrieved documents. 15 | STEP 3: Provide the answer based on the retrieved information.""", 16 | 17 | """Question: Who was the producer of The Piano?\nQuestion Analysis: "The Piano" likely refers to a well-known film or music piece. We need to identify the individual or entity responsible for its production. If there are multiple interpretations of "The Piano," include all relevant information for a comprehensive response. 18 | Step 1: Conduct a retrieval search for "film The Piano" to determine if the context is film-related. 19 | Step 2: Conduct a retrieval search for "music The Piano" to determine if the context is music-related. 20 | Step 3: Based on the retrieved information, provide the name of the producer. If both film and music contexts are identified, provide producers for both; otherwise, focus on the relevant context specified by the search.""", 21 | 22 | """Question: Who was the director of The Wedding?\nQuestion Analysis: The title "The Wedding" is a common name for films and other media, which might refer to different productions across years. We need to retrieve relevant information about the film or media work titled "The Wedding" to identify the director. If there are multiple works with the same title, we should summarize all relevant works for a comprehensive response. 23 | STEP 1: Conduct a retrieval search for "The Wedding" to determine the context of the question. 24 | STEP 2: Summarize all works titled "The Wedding" and identify the director for each work to provide a comprehensive response.""", 25 | 26 | 27 | 28 | ] 29 | 30 | 31 | -------------------------------------------------------------------------------- /C-3PO/prompt/filter.py: -------------------------------------------------------------------------------- 1 | 2 | FILTER_PROMPT_PREFIX = """You are an intelligent assistant tasked with analyzing the retrieved documents based on a given question and the current step's objectives. Your role is to determine the relevance of each document in relation to the question and the specified objectives. 3 | 4 | Instructions: 5 | 1. **Analyze Relevance**: Evaluate whether each document aligns with the objectives of the current retrieval step and contains a direct answer to the question. 6 | 2. **Thought Process**: Provide a brief analysis for each document, considering both the answer content and the retrieval objectives. 7 | 3. **Filter Documents**: After your thought process, generate a list of document indices indicating which documents to retain. 8 | 4. 
**Output Format**: 9 | Thought: [Your analysis for each document] 10 | Action: [List of document indices to retain, separated by commas]""" 11 | 12 | FEW_SHOT_PROMPT = """Here are some examples: 13 | {examples}""" 14 | 15 | FILTER_PROMPT_SUFFIX = """Now, process the following question:\n\nCurrent step's objectives: {objective} 16 | 17 | Question: {question} 18 | 19 | Documents: 20 | {documents}""" 21 | 22 | FILTER_human_input = """Current step's objectives: {objective} 23 | 24 | Question: {question} 25 | 26 | Documents: 27 | {documents}""" 28 | 29 | FILTER_PROMPT = FILTER_PROMPT_PREFIX + "\n\n" + FILTER_PROMPT_SUFFIX 30 | FILTER_PROMPT_FEW_SHOT = FILTER_PROMPT_PREFIX + "\n\n" + FEW_SHOT_PROMPT + "\n\n" + FILTER_PROMPT_SUFFIX 31 | 32 | 33 | 34 | # Direct Retrieve 35 | 36 | FILTER_PROMPT_DIRECT_RETRIEVE_PREFIX = """You are an intelligent assistant tasked with analyzing the retrieved documents based on a given question. Your role is to determine the relevance of each document in relation to the question. 37 | 38 | Instructions: 39 | 1. **Analyze Relevance**: Evaluate whether each document provides helpful and relevant information or contains a direct answer to the question. 40 | 2. **Thought Process**: Provide a brief analysis for each document. 41 | 3. **Filter Documents**: After your thought process, generate a list of document indices indicating which documents to retain. 42 | 4. **Output Format**: 43 | Thought: [Your analysis for each document] 44 | Action: [List of document indices to retain, separated by commas]""" 45 | 46 | FILTER_PROMPT_DIRECT_RETRIEVE_SUFFIX = """Now, process the following question:\n\nQuestion: {question} 47 | 48 | Documents: 49 | {documents}""" 50 | 51 | FILTER_direct_human_input = """Question: {question} 52 | 53 | Documents: 54 | {documents}""" 55 | 56 | FILTER_PROMPT_DIRECT_RETRIEVE = FILTER_PROMPT_DIRECT_RETRIEVE_PREFIX + "\n\n" + FILTER_PROMPT_DIRECT_RETRIEVE_SUFFIX 57 | FILTER_PROMPT_DIRECT_RETRIEVE_FEW_SHOT = FILTER_PROMPT_DIRECT_RETRIEVE_PREFIX + "\n\n" + FEW_SHOT_PROMPT + "\n\n" + FILTER_PROMPT_DIRECT_RETRIEVE_SUFFIX -------------------------------------------------------------------------------- /C-3PO/prompt/planning.py: -------------------------------------------------------------------------------- 1 | prompt_planning_prefix = """You are an expert assistant tasked with analyzing the following question and formulating a detailed plan. You will utilize a retrieval system to gather relevant information in your planning. Your goal is to analyze the question and provide a structured sequence of actions to address it effectively. 2 | 3 | Instructions: 4 | 1. **Question Analysis**: Identify the core components of the question. Determine what key information we currently know and what additional information is needed through retrieval. 5 | 2. **Step By Step Planning**: Develop a detailed plan step by step. Focus on the planning process rather than providing direct answers. 6 | 3. 
**Focus on Planning**: Keep your response clear and structured, concentrating solely on the analysis and planning aspects.{dataset_instructions}""" 7 | 8 | 9 | prompt_planning_suffix_few_shot = """Now, process the following question: 10 | Question: {question}\n""" 11 | 12 | few_shot_template = """Here are some examples: 13 | {examples}""" 14 | 15 | prompt_llm_planning = prompt_planning_prefix + "\n\n" + prompt_planning_suffix_few_shot 16 | 17 | prompt_llm_planning_few_shot = prompt_planning_prefix + "\n\n" + few_shot_template + "\n\n" + prompt_planning_suffix_few_shot 18 | 19 | 20 | ASQA_DECISION = """\n4. Keep in mind that the question may be ambiguous and may have multiple correct answers. Ensure that your planning outlines are clear, especially for ambiguous questions.""" 21 | POPQA_DECISION = """\n4. Keep in mind that the question mainly asks about the object entity that holds a certain relationship with the given subject entity. There may be multiple correct answers. Ensure that your planning outlines are clear, especially for ambiguous questions.""" 22 | OTHERS_DECISION = """\n4. Keep in mind that the question may be compositional and require intermediate analysis to deduce the final answer. Ensure that your planning outlines are clear.""" 23 | 24 | instruct_for_each_dataset = { 25 | "ASQA": ASQA_DECISION, 26 | "PopQA": POPQA_DECISION, 27 | "2WikiMultiHopQA": OTHERS_DECISION, 28 | "NaturalQuestions": OTHERS_DECISION, 29 | "TriviaQA": OTHERS_DECISION, 30 | } -------------------------------------------------------------------------------- /C-3PO/proxy.py: -------------------------------------------------------------------------------- 1 | import os 2 | from time import sleep 3 | from openai import OpenAI 4 | from typing import List 5 | from pebble import ProcessPool, ThreadPool 6 | 7 | from utils import load_agent 8 | 9 | import logging 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | def load_proxy(args): 14 | available_gpus = os.environ.get('CUDA_VISIBLE_DEVICES', "0").split(',') 15 | logger.info(f"available_gpus: {available_gpus}, NUM_GPUS: {len(available_gpus)}") 16 | proxy = Proxy(args, num_gpus=len(available_gpus)) 17 | return proxy 18 | 19 | class Proxy(): 20 | 21 | def __init__( 22 | self, 23 | args, 24 | num_gpus: int = 1, 25 | ): 26 | self.backend = args.backend 27 | self.model_type = args.model_type 28 | 29 | if self.model_type == "proxy": 30 | self.llm, self.sampling_params, self.server_process = load_agent(args, num_gpus) 31 | self.concurrency = args.proxy_concurrency 32 | self.model_name = "default" 33 | 34 | elif self.model_type == "gpt": 35 | self.model_name = args.proxy_model_name 36 | self.temperature = args.temperature 37 | self.top_p = args.top_p 38 | self.max_tokens = args.max_tokens 39 | self.concurrency = args.gpt_proxy_concurrency 40 | self.llm = OpenAI( 41 | api_key=args.openai_api_key, 42 | base_url=args.proxy_url, 43 | ) 44 | self.sampling_params = { 45 | "temperature": args.temperature, 46 | "top_p": args.top_p, 47 | "top_k": args.top_k, 48 | "max_new_tokens": args.max_tokens, 49 | } 50 | 51 | else: 52 | raise NotImplementedError 53 | 54 | def sglang_generate(self, query_list): 55 | index = query_list.get("index", 0) 56 | messages = query_list["messages"] 57 | sampling_params = query_list["sampling_params"] 58 | outputs = self.llm.chat.completions.create( 59 | model=self.model_name, 60 | messages=messages, 61 | n=sampling_params['n'], 62 | temperature=sampling_params['temperature'], 63 | top_p=sampling_params['top_p'], 64 | 
max_tokens=sampling_params['max_new_tokens'], 65 | ) 66 | return {"index": index, "outputs": outputs} 67 | 68 | def generate(self, prompt_dict: List[dict], n_generate_sample: int): 69 | need_generate_prompt = [item for item in prompt_dict if item['need_generate']] 70 | if len(need_generate_prompt) == 0: 71 | return prompt_dict 72 | 73 | prompt = [item['template_text'] for item in need_generate_prompt] 74 | 75 | if self.model_type == "proxy" and self.backend == "vllm": 76 | self.sampling_params.n = n_generate_sample 77 | self.sampling_params.best_of = n_generate_sample 78 | outputs = self.llm.generate(prompt, self.sampling_params, use_tqdm=True) 79 | 80 | for item, output in zip(need_generate_prompt, outputs): 81 | response = [item.text for item in output.outputs] 82 | item['outputs'] = response 83 | return prompt_dict 84 | 85 | else: 86 | self.sampling_params["n"] = n_generate_sample 87 | query_list = [{"index": i, "messages": p, "sampling_params": self.sampling_params} for i, p in enumerate(prompt)] 88 | with ThreadPool(max_workers=self.concurrency) as pool: 89 | future = pool.map(self.sglang_generate, query_list) 90 | outputs = list(future.result()) 91 | sorted_outputs = sorted(outputs, key=lambda x: x["index"]) 92 | 93 | assert len(sorted_outputs) == len(need_generate_prompt) 94 | for item, output in zip(need_generate_prompt, sorted_outputs): 95 | response = [o.message.content for o in output['outputs'].choices] 96 | item['outputs'] = response 97 | return prompt_dict -------------------------------------------------------------------------------- /C-3PO/retrieve/retrieve_engine.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from typing import List 3 | 4 | import logging 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | timeout_duration = 360 9 | 10 | class Dense_Retrieve(): 11 | def __init__(self, server_url, top_k): 12 | self.server_url = server_url # TODO: could load-balance across multiple server URLs 13 | self.headers = {"Content-Type": "application/json"} 14 | self.top_k = top_k 15 | 16 | def retrieve(self, messages: List[str], retrieve_times: List[int]) -> List[List[str]]: 17 | n_docs = (max(retrieve_times) + 1) * self.top_k 18 | data = { 19 | "queries": messages, 20 | "n_docs": n_docs, 21 | } 22 | try: 23 | response = requests.post(self.server_url, json=data, headers=self.headers, timeout=timeout_duration) 24 | except requests.ConnectionError as e: 25 | logger.info(f"Retrieve Connection Error: {e}") 26 | raise RuntimeError("Retrieve Connection Error") from e 27 | except requests.Timeout as e: 28 | logger.info(f"Retrieve Timeout: {e}") 29 | raise RuntimeError("Retrieve Timeout") from e 30 | except requests.HTTPError as e: 31 | logger.info(f"Request Error: {e}") 32 | raise RuntimeError(f"HTTP Error: {e.response.status_code}") from e 33 | except requests.RequestException as e: 34 | logger.info(f"Request Error: {e}") 35 | raise RuntimeError("Request Error") from e 36 | else: 37 | selected_passages = [] 38 | for r_time, passages in zip(retrieve_times, response.json()): 39 | selected_passages.append(passages[r_time * self.top_k:self.top_k * (r_time + 1)]) 40 | return selected_passages -------------------------------------------------------------------------------- /C-3PO/retrieve/retriever.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | try: 4 | from retrieve.retrieve_engine import Dense_Retrieve 5 | from retrieve.search_engine import Search_Engine 6 | except ImportError: 7 | from
retrieve_engine import Dense_Retrieve 8 | from search_engine import Search_Engine 9 | 10 | class BasicRAG: 11 | def __init__(self, args): 12 | self.args = args 13 | self.retriever_type = args.retriever_type 14 | 15 | if self.retriever_type == "dense": 16 | self.retriever = Dense_Retrieve(self.args.retrieve_server_url, self.args.retrieve_top_k) 17 | elif self.retriever_type == "search_engine": 18 | search_engine_cache_file = self.args.search_engine_cache_file if self.args.search_engine_cache else None 19 | self.retriever = Search_Engine(self.args.search_engine_url, self.args.search_scene, self.args.retrieve_top_k, search_engine_cache_file) 20 | else: 21 | raise NotImplementedError 22 | 23 | 24 | def retrieve(self, queries: List[str], retrieve_times: List[int]) -> List[List[str]]: 25 | # length of each query less than args.max_query_length 26 | if self.retriever_type == "dense": 27 | passages = self.retriever.retrieve(messages=queries, retrieve_times=retrieve_times) 28 | return passages 29 | elif self.retriever_type == "search_engine": 30 | passages = self.retriever.retrieve(messages=queries, retrieve_times=retrieve_times) 31 | return passages 32 | else: 33 | raise NotImplementedError 34 | 35 | 36 | if __name__ == "__main__": 37 | import argparse 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('--retriever_type', type=str, default="dense", choices=["dense", "search_engine"]) 40 | parser.add_argument('--retrieve_server_url', type=str, default="http://10.32.18.155:35004/search") 41 | parser.add_argument('--retrieve_top_k', type=int, default=10) 42 | parser.add_argument('--max_query_length', type=int, default=100) 43 | 44 | args = parser.parse_args() 45 | rag = BasicRAG(args) 46 | passages = rag.retrieve(["What is the capital of China?"], [0]) 47 | print(passages) -------------------------------------------------------------------------------- /C-3PO/retrieve/search_engine.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import json 4 | import requests 5 | from time import sleep 6 | from typing import List 7 | from pebble import ThreadPool 8 | from tqdm import tqdm 9 | 10 | import logging 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | timeout_duration = 360 15 | 16 | class Search_Engine(): 17 | def __init__(self, server_url, search_scene, top_k, search_engine_cache_file=None): 18 | self.server_url = server_url # TODO: load balancing across multiple URLs could be supported 19 | self.search_scene = search_scene 20 | self.headers = { 21 | 'Content-Type': 'application/json', 22 | '__d_head_qto': '5000', 23 | '__d_head_app': 'Test', 24 | "Authorization": "Bearer {your_key}", 25 | } 26 | self.top_k = top_k 27 | self.search_engine_cache_file = search_engine_cache_file 28 | self.retry_times = 10 29 | if search_engine_cache_file is not None: 30 | self.search_engine_cache_file = osp.join(search_engine_cache_file, search_scene, "cache_search.json") 31 | if osp.exists(self.search_engine_cache_file): 32 | with open(self.search_engine_cache_file, "r") as f: 33 | self.cache = json.load(f) 34 | else: 35 | self.cache = {} 36 | 37 | def save_cache(self): 38 | if self.search_engine_cache_file is not None: 39 | os.makedirs(osp.dirname(self.search_engine_cache_file), exist_ok=True) 40 | new_dict = {} 41 | for key, value in self.cache.items(): 42 | if value['success']: 43 | new_dict[key] = value 44 | with open(self.search_engine_cache_file, "w") as f: 45 | json.dump(new_dict, f, indent=2) # persist only the successful responses collected above 46 | 47 | def single_query(self, payload): 48 | for _ in range(self.retry_times): 49 | 
response = requests.post(self.server_url, headers=self.headers, json=payload).json() 50 | sleep(0.5) 51 | if response['status'] == 0: 52 | return {json.dumps(payload): response} 53 | 54 | raise ValueError(f"{response['message']}") 55 | 56 | def retrieve(self, messages: List[str], retrieve_times: List[int]) -> List[List[str]]: 57 | query_list, key_list = [], [] 58 | for message, r_time in zip(messages, retrieve_times): 59 | n_docs = (r_time + 1) * self.top_k 60 | payload = { 61 | "rid": "", 62 | "scene": self.search_scene, 63 | "uq": message, 64 | "debug": False, 65 | "fields": [], 66 | "page": 1, 67 | "rows": n_docs, 68 | "customConfigInfo": { 69 | "readpage": False, 70 | }, 71 | } 72 | key = json.dumps(payload) 73 | key_list.append(key) 74 | if key not in self.cache: 75 | query_list.append(payload) 76 | 77 | with ThreadPool(max_workers=1) as pool: 78 | future = pool.map(self.single_query, query_list) 79 | outputs = list(tqdm( 80 | future.result(), 81 | total=len(query_list), 82 | desc="Searching queries" 83 | )) 84 | 85 | outputs = {key: value for item in outputs for key, value in item.items()} 86 | selected_passages = [] 87 | for r_time, key in zip(retrieve_times, key_list): 88 | if key in self.cache: 89 | response = self.cache[key] 90 | else: 91 | response = outputs[key] 92 | if response['success'] == True: 93 | self.cache[key] = response 94 | passages = [{"id": item['_id'], "text": item['snippet'], "title": item['title']} for item in response['data']['docs']] 95 | selected_passages.append(passages[r_time * self.top_k:self.top_k * (r_time + 1)]) 96 | return selected_passages 97 | 98 | if __name__ == "__main__": 99 | url = "" 100 | search_scene = "" 101 | top_k = 10 102 | search_engine_cache_file = "path to cache_search" 103 | search_engine = Search_Engine(url, search_scene, top_k, search_engine_cache_file) 104 | queries = ["who is the director of Mera Naam Joker"] 105 | retrieve_times = [0] 106 | passages = search_engine.retrieve(queries, retrieve_times) 107 | print(passages) 108 | search_engine.save_cache() -------------------------------------------------------------------------------- /C-3PO/search/LLM_planning_role.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import aiohttp 3 | 4 | from typing import List 5 | 6 | from tree.node import BaseNode, State 7 | 8 | from prompt.planning import prompt_llm_planning, prompt_llm_planning_few_shot, instruct_for_each_dataset 9 | 10 | from utils import few_shot_random_select 11 | 12 | class LLM_PLANNING_ROLE(): 13 | def __init__(self): 14 | pass 15 | 16 | def llm_planning_role(self, nodes: List[BaseNode]): 17 | if len(nodes) == 0: 18 | return 19 | self.llm_planning(nodes) 20 | 21 | async def async_generate(self, message_lst: List[str], model_url_lst: List[str], n_generate_sample: int): 22 | async with aiohttp.ClientSession() as session: 23 | tasks = [self.llm.async_query(msg, model_url, n_generate_sample) for msg, model_url in zip(message_lst, model_url_lst)] 24 | outputs = await asyncio.gather(*tasks) 25 | return outputs 26 | 27 | def llm_planning_prepare_input(self): 28 | nodes = self.planning_lst 29 | need_llm_input = [] 30 | if len(nodes) == 0: 31 | return need_llm_input 32 | 33 | for i, node in enumerate(nodes): 34 | info = { 35 | "examples": few_shot_random_select(self.few_shot_examples, 'planning', num=self.args.few_num, dict_num=self.args.dict_few_num), 36 | "question": self.tree.question, 37 | "dataset_instructions": instruct_for_each_dataset.get(self.args.dataname, ""), 38 | 
} 39 | input_text = prompt_llm_planning_few_shot.format_map(info) 40 | # model_url = self.args.llm_server_url[i % len(self.args.llm_server_url)] 41 | 42 | message = [ 43 | {"role": "system", "content": "You are a helpful assistant."}, 44 | {"role": "user", "content": input_text}, 45 | ] 46 | need_llm_input.append({ 47 | "id": self.get_tree_node_tag(node), 48 | "role": self.config['role']['LLM_PLANNING']['name'], 49 | "need_generate": True, 50 | "text": input_text, 51 | "message": message, 52 | "cache_key": f"{self.tree.question_id}_{self.tree.question}", 53 | "generate_config": { 54 | "n": self.n_plan_sample, 55 | "temperature": self.args.plan_temperature, 56 | "top_p": self.args.plan_top_p, 57 | } 58 | }) 59 | 60 | return need_llm_input 61 | 62 | def llm_planning_postprocess_output(self, outputs: dict): 63 | nodes = self.planning_lst 64 | if len(nodes) == 0: 65 | return 66 | 67 | # global deduplication: the planning step only occurs at the second layer, and the first-layer nodes are all [Planning] 68 | output_text = set() 69 | for node in nodes: 70 | tag = self.get_tree_node_tag(node) 71 | 72 | output = outputs[tag] 73 | for response in output['outputs']: 74 | if response in output_text: 75 | continue 76 | output_text.add(response) 77 | new_state = self.parse_planning(output['text'], response) 78 | new_node = BaseNode(node_id=node.node_id + f".{len(node.children)}", parent=node, state=new_state, depth=node.depth+1, role=self.config['role']['LLM_PLANNING']['name'], plan=node.plan, documents=node.documents, query_history=node.query_history) 79 | new_node.update_plan() 80 | node.add_child(new_node) 81 | 82 | def parse_planning(self, input_text: str, output_text: str) -> State: 83 | if len(output_text) == 0: 84 | return State(input_text=input_text, output_text=output_text, is_terminal=True, terminal_reason="Empty Output") 85 | else: 86 | return State(input_text=input_text, output_text=output_text, action=self.config['role']['LLM_PLANNING']['actions']['LLM_PLANNING_ACTION']) 87 | 88 | -------------------------------------------------------------------------------- /C-3PO/search/LLM_query_role.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from typing import List 4 | 5 | from tree.node import BaseNode, State 6 | 7 | class LLM_QUERY_ROLE(): 8 | def __init__(self): 9 | pass 10 | 11 | def llm_query_role(self, nodes: List[BaseNode]): 12 | if len(nodes) == 0: 13 | return 14 | self.llm_query(nodes) 15 | 16 | def llm_query_prepare_input(self): 17 | nodes = self.llm_query_lst 18 | need_llm_input = [] 19 | if len(nodes) == 0: 20 | return need_llm_input 21 | 22 | for node in nodes: 23 | need_llm_input.append({ 24 | "id": self.get_tree_node_tag(node), 25 | "role": self.config['role']['QUERY_LLM']['name'], 26 | "need_generate": True, 27 | "text": node.state.output_text, 28 | "message": node.state.action_input, 29 | "generate_config": node.state.thought, 30 | }) 31 | 32 | return need_llm_input 33 | 34 | def llm_query_postprocess_output(self, outputs: dict): 35 | nodes = self.llm_query_lst 36 | if len(nodes) == 0: 37 | return 38 | 39 | for node in nodes: 40 | tag = self.get_tree_node_tag(node) 41 | output = outputs[tag] 42 | output_text = set() 43 | for response in output['outputs']: 44 | if response in output_text: 45 | continue 46 | output_text.add(response) 47 | new_state = self.parse_query_LLM(output['text'], response) 48 | new_node = BaseNode(node_id=node.node_id + f".{len(node.children)}", parent=node, state=new_state, depth=node.depth+1, 
role=self.config['role']['QUERY_LLM']['name'], plan=node.plan, documents=node.documents, query_history=node.query_history) 49 | 50 | new_node.check_answer(qa_pairs=self.tree.qa_pairs, answers=self.tree.answers, is_asqa=self.tree.is_asqa) 51 | self.tree.add_solution_nodes(new_node.return_solution_state(node_id=new_node.node_id)) 52 | node.add_child(new_node) 53 | 54 | def parse_query_LLM(self, input_text:str, output_text: str) -> State: 55 | if len(output_text) == 0: 56 | return State(input_text=input_text, output_text=output_text, is_terminal=True, terminal_reason="Empty Output") 57 | else: 58 | return State(input_text=input_text, output_text=output_text, is_terminal=True, action=self.config['role']['QUERY_LLM']['actions']['QUERY_LLM_ACTION'], terminal_reason="End Reason") -------------------------------------------------------------------------------- /C-3PO/search/decide_prompt_role.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import aiohttp 3 | 4 | from typing import List 5 | 6 | from prompt.decide_prompt_identity import NO_DOCUMENT_PROMPT, DOCUMENT_PROMPT, instruct_for_each_dataset 7 | from prompt.decide_prompt_identity_v2 import DOCUMENT_PROMPT_FEW_SHOT 8 | 9 | from tree.node import BaseNode, State 10 | 11 | from utils import few_shot_random_select 12 | 13 | 14 | class DECIDE_PROMPT_ROLE(): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def decide_prompt_prepare_input(self): 19 | nodes = self.decide_prompt_lst 20 | if len(nodes) == 0: 21 | return [] 22 | 23 | if self.args.decide_prompt == "identity": 24 | return self.identity_prompt(nodes) 25 | else: 26 | raise ValueError(f"Invalid decide_prompt method: {self.args.decide_prompt}") 27 | 28 | 29 | def identity_prompt(self, nodes: List[BaseNode]): 30 | need_proxy_input = [] 31 | # input_text_lst, message_lst, model_url_lst = [], [], [] 32 | for i, node in enumerate(nodes): 33 | if node.document_is_empty(): 34 | infos = { 35 | "question": self.tree.question, 36 | "dataset_instructions": instruct_for_each_dataset.get(self.args.dataname, "") 37 | } 38 | input_text = NO_DOCUMENT_PROMPT.format_map(infos) 39 | else: 40 | if self.args.llm_query_few_shot: 41 | infos = { 42 | "documents": self.str_documents(node), 43 | "dataset_instructions": instruct_for_each_dataset.get(self.args.dataname, ""), 44 | "question": self.tree.question, 45 | "examples": few_shot_random_select(self.few_shot_examples, 'answer', num=self.args.few_num, dict_num=1), 46 | } 47 | input_text = DOCUMENT_PROMPT_FEW_SHOT.format_map(infos) 48 | else: 49 | infos = { 50 | "documents": self.str_documents(node), 51 | "dataset_instructions": instruct_for_each_dataset.get(self.args.dataname, ""), 52 | "question": self.tree.question, 53 | } 54 | input_text = DOCUMENT_PROMPT.format_map(infos) 55 | 56 | message =[ 57 | {"role": "system", "content": "You are a helpful assistant."}, 58 | {"role": "user", "content": input_text}, 59 | ] 60 | 61 | need_proxy_input.append({ 62 | "id": self.get_tree_node_tag(node), 63 | "role": self.config['role']['DECIDE_PROMPT']['name'], 64 | "need_generate": False, 65 | "text": input_text, 66 | "message": message, 67 | "generate_config": { 68 | "n": self.n_answer_sample, 69 | "temperature": self.args.answer_temperature, 70 | "top_p": self.args.answer_top_p, 71 | } 72 | }) 73 | return need_proxy_input 74 | 75 | def decide_prompt_postprocess_output(self, outputs: dict): 76 | nodes = self.decide_prompt_lst 77 | if len(nodes) == 0: 78 | return 79 | 80 | if self.args.decide_prompt in 
["identity", "identity_few_shot"]: 81 | self.identity_prompt_postprocess_output(outputs) 82 | else: 83 | raise ValueError(f"Invalid decide_prompt method: {self.args.decide_prompt}") 84 | 85 | 86 | def identity_prompt_postprocess_output(self, outputs: dict): 87 | nodes = self.decide_prompt_lst 88 | for node in nodes: 89 | tag = self.get_tree_node_tag(node) 90 | output = outputs[tag] 91 | new_state = State(output_text=output['text'], action=self.config['role']['DECIDE_PROMPT']['actions']['DECIDE_PROMPT_ACTION'], action_input=output['message'], thought=output['generate_config']) # 92 | new_node = BaseNode(node_id=node.node_id + f".{len(node.children)}", parent=node, state=new_state, depth=node.depth+1, role=self.config['role']['DECIDE_PROMPT']['name'], plan=node.plan, documents=node.documents, query_history=node.query_history) 93 | node.add_child(new_node) 94 | 95 | -------------------------------------------------------------------------------- /C-3PO/search/evaluation_role.py: -------------------------------------------------------------------------------- 1 | 2 | import yaml 3 | import asyncio 4 | 5 | from typing import List, Union 6 | 7 | from prompt.evaluation import evaluation_prompt, evaluation_prompt_few_shot 8 | 9 | from metrics import get_item_metrics 10 | 11 | from utils import few_shot_random_select 12 | 13 | from utils import load_few_shot 14 | 15 | 16 | class EVALUATION_ROLE(): 17 | def __init__(self, args, llm, sampling_params=None): 18 | if args.only_eval_answer: 19 | self.args = args 20 | self.llm = llm 21 | self.sampling_params = sampling_params 22 | self.invalid_message =[ 23 | {"role": "system", "content": "You are a helpful assistant."}, 24 | {"role": "user", "content": "hi"}, 25 | ] 26 | 27 | def evluation_prepare_input(self): 28 | need_proxy_input = [] 29 | for idx, solution in enumerate(self.save_tree_info['solution_nodes']): 30 | infos = { 31 | "question": self.save_tree_info['question'], 32 | "true_answer": str(self.save_tree_info['answers']), 33 | "long_answer": self.save_tree_info['tree'][solution['node_id']]['state']['output_text'], 34 | } 35 | input_text = evaluation_prompt.format_map(infos) 36 | message =[ 37 | {"role": "system", "content": "You are a helpful assistant."}, 38 | {"role": "user", "content": input_text}, 39 | ] 40 | need_proxy_input.append({ 41 | "id": self.get_extract_answer_tag(idx), 42 | "role": "evaluate_answer", 43 | "need_generate": True, 44 | "text": input_text, 45 | "message": message, 46 | "generate_config": { 47 | "n": self.n_answer_sample, 48 | "temperature": self.args.answer_temperature, 49 | "top_p": self.args.answer_top_p, 50 | } 51 | }) 52 | 53 | return need_proxy_input 54 | 55 | def evluation_postprocess_output(self, outputs: dict): 56 | for idx, solution in enumerate(self.save_tree_info['solution_nodes']): 57 | tag = self.get_extract_answer_tag(idx) 58 | output = outputs[tag] 59 | output_text = set() 60 | for response in output['outputs']: 61 | if response in output_text: 62 | continue 63 | 64 | answer_status = 1 if "true" in response.lower() else 0 65 | self.save_tree_info['tree'][solution['node_id']]['state']['eval_response'] = response 66 | self.save_tree_info['tree'][solution['node_id']]['state']['eval_status'] = answer_status 67 | 68 | for item in self.save_tree_info['solution_nodes']: 69 | item['eval_status'] = self.save_tree_info['tree'][item['node_id']]['state']['eval_status'] 70 | 71 | def evaluation_off_input(self, save_tree_info): 72 | # data_item 是 self.tree.save_tree() 73 | message_lst = [] 74 | assert 
len(save_tree_info['solution_nodes']) <= 1, "only support one solution node" 75 | for i, solution in enumerate(save_tree_info['solution_nodes']): 76 | infos = { 77 | "question": save_tree_info['question'], 78 | "true_answer": str(save_tree_info['answers']), 79 | "long_answer": save_tree_info['tree'][solution['node_id']]['state']['output_text'], 80 | } 81 | input_text = evaluation_prompt.format_map(infos) 82 | message = [ 83 | {"role": "system", "content": "You are a helpful assistant."}, 84 | {"role": "user", "content": input_text}, 85 | ] 86 | message_lst.append(message) 87 | 88 | return message_lst if len(message_lst) > 0 else [self.invalid_message] 89 | 90 | def evaluation_off_output(self, outputs, save_tree_info): 91 | for solution, output in zip(save_tree_info['solution_nodes'], outputs): 92 | if self.args.only_eval_answer and self.args.model_type == "proxy": 93 | if self.args.backend == "vllm": 94 | response = output.outputs[0].text 95 | elif self.args.backend == "sglang": 96 | response = output['outputs'].choices[0].message.content 97 | else: 98 | response = output.choices[0].message.content 99 | 100 | answer_status = 1 if "true" in response.lower() else 0 101 | save_tree_info['tree'][solution['node_id']]['state']['eval_response'] = response 102 | save_tree_info['tree'][solution['node_id']]['state']['eval_status'] = answer_status 103 | 104 | for item in save_tree_info['solution_nodes']: 105 | item['eval_status'] = save_tree_info['tree'][item['node_id']]['state']['eval_status'] 106 | 107 | return save_tree_info 108 | 109 | def evaluation_off_generate(self, query_list): 110 | index = query_list.get("index", 0) 111 | messages = query_list["messages"] 112 | sampling_params = query_list["sampling_params"] 113 | 114 | outputs = self.llm.chat.completions.create( 115 | model="default", 116 | messages=messages, 117 | n=sampling_params['n'], 118 | temperature=sampling_params['temperature'], 119 | top_p=sampling_params['top_p'], 120 | max_tokens=sampling_params['max_new_tokens'], 121 | ) 122 | return {"index": index, "outputs": outputs} 123 | 124 | 125 | 126 | -------------------------------------------------------------------------------- /C-3PO/search/make_decision_role.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from typing import List 4 | 5 | from prompt.decision import prompt_decision_making_few_shot, instruct_for_each_dataset, prompt_decision_making 6 | from tree.node import BaseNode, State 7 | from utils import few_shot_random_select 8 | 9 | class MAKING_DECISION_ROLE(): 10 | def __init__(self): 11 | self.legal_decision_action = [self.config['role']['MAKE_DECISION']['actions']['RETRIEVAL_ACTION'], self.config['role']['MAKE_DECISION']['actions']['NO_RETRIEVAL_ACTION'], self.config['role']['MAKE_DECISION']['actions']['PLANNING_ACTION']] 12 | 13 | def make_decision_role(self, nodes: List[BaseNode]): 14 | if len(nodes) == 0: 15 | return 16 | self.making_decision(nodes) 17 | 18 | def making_decision_prepare_input(self): # the input is guaranteed not to overflow the context window 19 | nodes = self.make_decision_lst 20 | need_proxy_input = [] 21 | if len(nodes) == 0: 22 | return need_proxy_input 23 | for node in nodes: 24 | if self.args.few_shot: 25 | info = { 26 | "examples": few_shot_random_select(self.few_shot_examples, 'decision', num=self.args.few_num, dict_num=self.args.dict_few_num), 27 | "question": self.tree.question, 28 | "dataset_instructions": instruct_for_each_dataset.get(self.args.dataname, ""), 29 | } 30 | input_text = prompt_decision_making_few_shot.format_map(info)
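# Note: with either prompt variant, the proxy is expected to reply with exactly one
# bracketed action, e.g. "[Planning]", "[Retrieval] <sub-query>", or "[No Retrieval]";
# parse_decision below treats any other format as a terminal state.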
31 | else: 32 | info = { 33 | "question": self.tree.question, 34 | "dataset_instructions": instruct_for_each_dataset.get(self.args.dataname, ""), 35 | } 36 | input_text = prompt_decision_making.format_map(info) 37 | 38 | if not self.args.force_decision: 39 | # input_template_text = self.organize_prompt(prompt=input_text) 40 | need_proxy_input.append({ 41 | "id": self.get_tree_node_tag(node), 42 | "role": self.config['role']['MAKE_DECISION']['name'], 43 | "need_generate": True, 44 | "text": input_text, 45 | "template_text": self.organize_message(prompt=input_text) if self.args.backend == "sglang" else self.organize_prompt(prompt=input_text) 46 | }) 47 | else: 48 | input_template_text = None 49 | if self.args.force_action == "Planning": 50 | outputs = ["[Planning]"] 51 | elif self.args.force_action == "Retrieval": 52 | outputs = [f"[Retrieval] {self.tree.question}"] 53 | elif self.args.force_action == "No Retrieval": 54 | outputs = ["[No Retrieval]"] 55 | else: 56 | raise ValueError(f"Invalid force action: {self.args.force_action}") 57 | need_proxy_input.append({ 58 | "id": self.get_tree_node_tag(node), 59 | "role": self.config['role']['MAKE_DECISION']['name'], 60 | "need_generate": False, 61 | "text": input_text, 62 | "template_text": input_template_text, 63 | "outputs": outputs 64 | }) 65 | 66 | return need_proxy_input 67 | 68 | def making_decision_postprocess_output(self, outputs: dict): 69 | nodes = self.make_decision_lst 70 | if len(nodes) == 0: 71 | return 72 | 73 | # parse output 74 | output_text = set() # deduplicate (local deduplication == global deduplication here) 75 | for node in nodes: # len(nodes) must be 1 76 | tag = self.get_tree_node_tag(node) 77 | 78 | output = outputs[tag] 79 | for response in output['outputs']: 80 | if response in output_text: 81 | continue 82 | output_text.add(response) 83 | new_state = self.parse_decision(output['text'], response) 84 | new_node = BaseNode(node_id=node.node_id + f".{len(node.children)}", parent=node, state=new_state, depth=node.depth+1, role=self.config['role']['MAKE_DECISION']['name'], plan=node.plan, documents=node.documents, query_history=node.query_history) 85 | node.add_child(new_node) 86 | 87 | def parse_decision(self, input_text: str, output_text: str) -> State: 88 | regex = r'\[(.*?)\](.*)' 89 | # state_list = [] 90 | # for text in texts: 91 | match = re.search(regex, output_text, re.DOTALL) 92 | if match: 93 | action = match.group(1).strip() # text inside the brackets 94 | content = match.group(2).strip() # text after the brackets 95 | content = " ".join(content.split()[:self.args.max_query_length]) 96 | 97 | state = State(input_text=input_text, output_text=output_text, action=action, action_input=content) 98 | if (action not in self.legal_decision_action) or (action==self.config['role']['MAKE_DECISION']['actions']['RETRIEVAL_ACTION'] and len(content)==0): 99 | state.is_terminal = True 100 | state.terminal_reason = "Invalid Action" 101 | 102 | else: 103 | state = State(input_text=input_text, output_text=output_text, is_terminal=True, terminal_reason="Invalid Format") 104 | 105 | # state_list.append(state) 106 | return state -------------------------------------------------------------------------------- /C-3PO/solver.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logger = logging.getLogger(__name__) 4 | import time 5 | import json 6 | 7 | from typing import List 8 | from tqdm import tqdm 9 | from pebble import ThreadPool 10 | 11 | from proxy import load_proxy 12 | from llm_server import LLM_SERVER 13 | 
from search.search import Search 14 | from retrieve.retriever import BasicRAG 15 | 16 | def write_solutions_to_file(all_solutions, file_path): 17 | if file_path is None: 18 | return 19 | with open(file_path, 'a') as writer: 20 | for solution_item in all_solutions: 21 | writer.write(json.dumps(solution_item) + '\n') 22 | writer.flush() 23 | 24 | class Solver(): 25 | def __init__(self, args): 26 | self.args = args 27 | self.proxy = load_proxy(args) 28 | self.llm_server = LLM_SERVER(args) if self.args.llm_server_type == "online" else LLM_SERVER(args, llm=self.proxy.llm, sampling_params=self.proxy.sampling_params) 29 | self.n_generate_sample = args.n_generate_sample 30 | 31 | self.retrieve = BasicRAG(args) 32 | 33 | self.epoch_file_path = None 34 | self.final_file_path = None 35 | 36 | def solve(self, searches: List[Search]): 37 | 38 | search_length = len(searches) 39 | begin_time = time.time() 40 | all_save_tree_info = [] 41 | node2documents = {} 42 | for i in tqdm(range(self.args.max_iter), desc="Step"): 43 | # if i in [3, 4, 6]: 44 | if i in [4, 6]: 45 | self.n_generate_sample = max(self.n_generate_sample - 1, 1) 46 | 47 | # prepare input 48 | all_proxy_input, all_llm_input = [], [] 49 | for search in searches: 50 | proxy_input, llm_input = search.prepare_input(node2documents) 51 | all_proxy_input.extend(proxy_input) 52 | all_llm_input.extend(llm_input) 53 | 54 | # generate 55 | # all_proxy_input = self.proxy.generate(all_proxy_input, self.n_generate_sample) 56 | # all_proxy_input = {item["id"]: item for item in all_proxy_input} 57 | 58 | # all_llm_input = self.llm_server.generate(all_llm_input) 59 | # all_llm_input = {item["id"]: item for item in all_llm_input} 60 | 61 | with ThreadPool(max_workers=1) as pool: 62 | # run llm_server in a worker thread 63 | llm_future = pool.schedule( 64 | self.llm_server.generate, 65 | args=(all_llm_input,) 66 | ) 67 | 68 | # proxy.generate (which involves Ray calls) runs on the main thread 69 | proxy_result = self.proxy.generate(all_proxy_input, self.n_generate_sample) 70 | # collect the llm_server results 71 | llm_result = llm_future.result() 72 | all_proxy_input = {item["id"]: item for item in proxy_result} 73 | all_llm_input = {item["id"]: item for item in llm_result} 74 | 75 | # postprocess output 76 | finish_searches = [] 77 | for search in searches: 78 | search.postprocess_output(all_proxy_input, all_llm_input) 79 | if search.search_finish(): 80 | finish_searches.append(search) 81 | 82 | for search in finish_searches: 83 | searches.remove(search) 84 | # [v for v in all_proxy_input.values() if '[Retrieval]' in v['outputs'][0]] # [v for v in all_proxy_input.values() if '[Planning]' in v['outputs'][0]] 85 | # do retrieve 86 | node2documents = self.batch_retrieve(searches) 87 | 88 | # save_solution 89 | all_extract_answer_input = [] 90 | if len(finish_searches) > 0: 91 | for search in finish_searches: 92 | extract_answer_input = search.save_solution_prepare_input() 93 | all_extract_answer_input.extend(extract_answer_input) 94 | 95 | all_extract_answer_input = self.llm_server.generate(all_extract_answer_input) 96 | all_extract_answer_input = {item["id"]: item for item in all_extract_answer_input} 97 | 98 | save_tree_info_list = [] 99 | for search in finish_searches: 100 | save_tree_info = search.save_solution_postprocess_output(all_extract_answer_input) 101 | save_tree_info_list.append(save_tree_info) 102 | 103 | write_solutions_to_file(save_tree_info_list, self.epoch_file_path) 104 | all_save_tree_info.extend(save_tree_info_list) 105 | 106 | logger.info(f"Epoch finished. 
Save all solutions") 107 | write_solutions_to_file(all_save_tree_info, self.final_file_path) 108 | logger.info(f"Epoch finished. Time: {time.time() - begin_time}. Len: {search_length}") 109 | return all_save_tree_info 110 | 111 | def batch_retrieve(self, searches, batch_size=4096): 112 | retrieve_input = {"node_id": [], "queries": [], "retrieve_times": []} 113 | for search in searches: 114 | node_id, queries, retrieve_times = search.do_retrieve_prepare_input() 115 | retrieve_input["node_id"].extend(node_id) 116 | retrieve_input["queries"].extend(queries) 117 | retrieve_input["retrieve_times"].extend(retrieve_times) 118 | 119 | node_id_list, query_list, documents_list = [], [], [] 120 | if len(retrieve_input["node_id"]) > 0: 121 | # process in batches 122 | total_queries = len(retrieve_input["queries"]) 123 | for i in range(0, total_queries, batch_size): # TODO: multi-threading with multiple URLs 124 | # get the queries and retrieve times of the current batch 125 | batch_node_id = retrieve_input["node_id"][i:i + batch_size] 126 | batch_queries = retrieve_input["queries"][i:i + batch_size] 127 | batch_retrieve_times = retrieve_input["retrieve_times"][i:i + batch_size] 128 | 129 | # run retrieval 130 | batch_documents = self.retrieve.retrieve(batch_queries, batch_retrieve_times) 131 | 132 | # append the current batch results to the overall document lists 133 | node_id_list.extend(batch_node_id) 134 | query_list.extend(batch_queries) 135 | documents_list.extend(batch_documents) 136 | 137 | # process the searches 138 | assert len(node_id_list) == len(documents_list), f"{len(node_id_list)} != {len(documents_list)}" 139 | node2documents = {node_id: {"query": query, "documents": documents} for node_id, query, documents in zip(node_id_list, query_list, documents_list)} 140 | else: 141 | node2documents = {} 142 | 143 | return node2documents 144 | 145 | -------------------------------------------------------------------------------- /C-3PO/tree/node.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | from typing import List, Dict 4 | from metrics import get_item_metrics 5 | 6 | class State(): 7 | """ 8 | This class represents the state of a node 9 | param output_text: newly generated text at this node 10 | param is_terminal: whether the search stops at this node 11 | 12 | """ 13 | def __init__( 14 | self, 15 | input_text="", 16 | output_text="", 17 | thought=None, 18 | action=None, 19 | action_input=None, 20 | observation=None, 21 | is_terminal=False, 22 | terminal_reason=None, 23 | retrieve_query=None, 24 | retrieved_all_documents=None, 25 | ): 26 | self.input_text = input_text 27 | self.output_text = output_text 28 | self.is_terminal = is_terminal 29 | self.terminal_reason = terminal_reason 30 | 31 | self.thought = thought 32 | self.action = action 33 | self.action_input = action_input 34 | self.observation = observation 35 | 36 | self.retrieve_query = retrieve_query 37 | self.retrieved_all_documents = retrieved_all_documents 38 | 39 | self.final_answer = None # answer keyword extracted by the LLM 40 | self.final_status = None 41 | 42 | self.key_answer = None 43 | self.key_status = None 44 | 45 | def to_dict(self): 46 | return self.__dict__ 47 | 48 | 49 | class BaseNode(): 50 | """ 51 | This class defines a node of the tree 52 | param parent: parent node 53 | param state: state of current node 54 | param P: prior probability 55 | param total_value: the expected probability of solving this problem 56 | """ 57 | def __init__(self, node_id=None, parent=None, state=None, P=None, depth=0, role=None, plan=None, documents=[], query_history={}): 58 | self.node_id = node_id 59 | self.parent = parent 60 | self.state = state 61 | 
self.depth = depth 62 | self.P = P 63 | self.role = role 64 | 65 | self.children = [] 66 | 67 | self.documents = copy.deepcopy(documents) # document {"title": str, "text": str, ...} 68 | self.plan = plan 69 | self.query_history = copy.deepcopy(query_history) 70 | 71 | def is_leaf(self) -> bool: 72 | return len(self.children) == 0 73 | 74 | def add_child(self, child): 75 | self.children.append(child) 76 | 77 | def extend_documents(self, documents: List[Dict]): 78 | for document in documents: 79 | if document not in self.documents: 80 | self.documents.append(copy.deepcopy(document)) 81 | 82 | def remove_noinfo_documents(self, query_text): 83 | new_documents = [doc for doc in self.documents if query_text != doc["text"]] 84 | self.documents = new_documents 85 | 86 | def update_plan(self): 87 | self.plan = self.state.output_text 88 | 89 | def document_is_empty(self) -> bool: 90 | return len(self.documents) == 0 91 | 92 | def check_answer(self, qa_pairs: List[Dict], answers: List[str], is_asqa=False): 93 | self.state.final_answer = self.state.output_text 94 | item = {"qa_pairs": qa_pairs, "answers": answers, "response": self.state.final_answer} 95 | self.state.final_status = get_item_metrics(item, is_asqa=is_asqa) 96 | 97 | def return_solution_state(self, node_id: str): 98 | return { 99 | "node_id": node_id, 100 | "final_status": self.state.final_status, 101 | "key_status": self.state.key_status, 102 | } 103 | 104 | def return_state(self): 105 | # return self.state 106 | infos = { 107 | "role": self.role, 108 | "plan": self.plan, 109 | "documents": self.documents, 110 | "query_history": self.query_history, 111 | "state": self.state.to_dict(), 112 | } 113 | return infos 114 | -------------------------------------------------------------------------------- /C-3PO/tree/tree.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from tree.node import BaseNode, State 4 | 5 | 6 | class Tree(): 7 | def __init__(self, args, data_item, root=None, server=False): 8 | self.args = args 9 | self.server = server 10 | if not server: 11 | self.question_id = data_item['id'] 12 | self.question = data_item['question'] 13 | self.qa_pairs = data_item['qa_pairs'] 14 | self.answers = data_item['answers'] # list 15 | self.is_asqa = True if self.args.dataname == "ASQA" else False 16 | assert isinstance(self.answers, list), "answers should be a list" 17 | else: 18 | self.question_id = None 19 | self.question = data_item 20 | self.qa_pairs = None 21 | self.answers = [""] 22 | self.is_asqa = False 23 | 24 | if root is None: 25 | self.root = BaseNode(node_id='0', parent=None, state=State(), P=1) 26 | else: 27 | self.root = root 28 | 29 | self.solution_nodes = [] 30 | 31 | def add_solution_nodes(self, node_id: str): 32 | self.solution_nodes.append(node_id) 33 | 34 | def save_tree(self): 35 | 36 | candidates = [self.root] 37 | states = {} 38 | while candidates: 39 | node = candidates.pop(0) 40 | states[node.node_id] = node.return_state() 41 | if not node.is_leaf(): 42 | candidates.extend(node.children) 43 | 44 | return_infos = { 45 | "question_id": self.question_id, 46 | "question": self.question, 47 | "qa_pairs": self.qa_pairs, 48 | "answers": self.answers, 49 | "tree": states, 50 | "solution_nodes": self.solution_nodes, 51 | } 52 | return return_infos 53 | 54 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # C-3PO: Compact Plug-and-Play Proxy Optimization to Achieve Human-like Retrieval-Augmented Generation 4 | 5 | 6 | 7 | 8 |
9 | 10 | This repo contains a proxy-centric alignment framework (C-3PO) that bridges the gap between retrievers and LLMs. Instead of modifying existing components in RAG, C-3PO introduces a multi-agent system within a lightweight proxy model to simulate human-like behaviors, optimizing the entire RAG pipeline while maintaining plug-and-play compatibility. 11 | 12 | 
13 | <img src="./images/C-3PO.png" alt="c-3po_framework">
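For a quick sense of the proxy's control protocol, the sketch below shows how a single proxy response is interpreted downstream. It is an illustrative snippet distilled from `search/make_decision_role.py`; the helper name `parse_proxy_decision` is ours for illustration, not an API of this repo:

```python
import re

def parse_proxy_decision(output_text: str):
    """Split a proxy reply such as '[Retrieval] who directed Mera Naam Joker'
    into (action, content); anything without a bracketed action is treated
    as an invalid (terminal) decision."""
    match = re.search(r'\[(.*?)\](.*)', output_text, re.DOTALL)
    if match is None:
        return None, None
    return match.group(1).strip(), match.group(2).strip()

# '[Planning]' -> ('Planning', ''); '[No Retrieval]' -> ('No Retrieval', '')
print(parse_proxy_decision("[Retrieval] who is the director of Mera Naam Joker"))
```

The same three actions ([Planning], [Retrieval], [No Retrieval]) drive both inference and the tree-structured rollout described below.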
15 | 16 | ## :boom: News 17 | - **[2025.05.01]** Our C-3PO is accepted by ICML 2025. 18 | - **[2025.03.02]** Release our code. 19 | - **[2025.02.12]** Release our [Demo](https://www.modelscope.cn/studios/Decaderan/C-3PO) on ModelScope. 20 | - **[2025.02.10]** Release our paper [C-3PO](https://arxiv.org/abs/2502.06205) on arXiv. 21 | 22 | 25 | 26 | 27 | ## :honeybee: Deploy LLM and Retrieval services 28 | 29 | ### Step1: Python Environment 30 | For C-3PO (also works for the LLM server) 31 | ```bash 32 | conda create -n c3po python=3.11 33 | conda activate c3po 34 | pip install -r requirements.txt 35 | ``` 36 | 37 | For Retrieval (dense model) 38 | ```bash 39 | conda create -n faiss python=3.11 40 | conda activate faiss 41 | pip install -r retrieval_requirements.txt 42 | ``` 43 | 44 | ### Step2: Download LLMs from Hugging Face 45 | Please download the following models from Hugging Face: 46 | ``` 47 | Qwen2-0.5B 48 | Qwen2-1.5B 49 | Qwen2-72B-Instruct 50 | contriever-msmarco 51 | ``` 52 | 53 | ### Step3: Download the Wikipedia 2018 dump 54 | Download preprocessed passage data of the Wikipedia 2018 dump. 55 | ```bash 56 | cd ./C-3PO/deploy_servers/retrieve_server/wiki18 57 | wget https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz 58 | ``` 59 | Then, download the embedded passages. We use Contriever-MSMARCO. 60 | ```bash 61 | cd ./C-3PO/deploy_servers/retrieve_server/wikipedia_embeddings 62 | wget https://dl.fbaipublicfiles.com/contriever/embeddings/contriever-msmarco/wikipedia_embeddings.tar 63 | ``` 64 | 65 | ### Step4: Deploy the LLM server 66 | 67 | ```bash 68 | cd ./deploy_servers/llm_server 69 | bash qwen_72b_serve.sh 70 | ``` 71 | 72 | ### Step5: Deploy the Retrieval server 73 | 74 | ```bash 75 | cd ./C-3PO/deploy_servers/retrieve_server/retrieve_code/wiki18 76 | bash start_wiki18.sh 77 | ``` 78 | 79 | ## :dart: Inference 80 | 81 | ### Download our released checkpoints (Optional) 82 | You can download our checkpoints of [C-3PO-1.5B](https://www.modelscope.cn/models/Decaderan/C-3PO-1.5B) or [C-3PO-0.5B](https://www.modelscope.cn/models/Decaderan/C-3PO-0.5B) on ModelScope. 83 | 84 | 85 | ### Scripts 86 | Our implementation supports two high-performance inference engines: SGLang and vLLM, allowing users to optimize for different deployment scenarios and hardware configurations. 87 | ```bash 88 | cd ./C-3PO/inference 89 | bash single_model.sh 90 | ``` 91 | 92 | ## Tree-structured Rollout for Seed Data (Supervised Warm-up) 93 | We collect seed data through tree-structured rollout using Qwen-2-72B-Instruct. 94 | 95 | ### Step1: Tree-structured rollout 96 | ```bash 97 | cd ./C-3PO/instruct_sampling_scripts 98 | bash run_72b.sh 99 | ``` 100 | 101 | ### Step2: Supervised fine-tuning 102 | We use LLaMA-Factory as our training framework for SFT. 103 | ```bash 104 | git clone https://github.com/hiyouga/LLaMA-Factory.git 105 | # We release our training hyper-parameters for easy reproduction. 106 | cd ./C-3PO/train/sft_scripts 107 | bash run_base_packing.sh 108 | ``` 109 | 110 | ## :heart: Acknowledgements 111 | 112 | This work is built upon several excellent open-source projects. 
We sincerely thank: 113 | 114 | - [Llama Factory](https://github.com/hiyouga/LLaMA-Factory) for providing the supervised fine-tuning framework 115 | - [vLLM](https://github.com/vllm-project/vllm) for the efficient inference engine with high throughput 116 | - [SGLang](https://github.com/sgl-project/sglang) for the efficient inference engine with high throughput 117 | - [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF) for the comprehensive reinforcement learning framework 118 | 119 | We express our gratitude to all these projects for their outstanding contributions to the open-source community. 120 | 121 | 122 | ## Citation 123 | If you find our work useful in your research, please consider citing our paper: 124 | ```bibtex 125 | @article{chen2025c, 126 | title={C-3PO: Compact Plug-and-Play Proxy Optimization to Achieve Human-like Retrieval-Augmented Generation}, 127 | author={Chen, Guoxin and Liao, Minpeng and Yu, Peiying and Wang, Dingmin and Qiao, Zile and Yang, Chao and Zhao, Xin and Fan, Kai}, 128 | journal={arXiv preprint arXiv:2502.06205}, 129 | year={2025} 130 | } 131 | ``` 132 | Your support by starring ⭐ this repository would be greatly appreciated! -------------------------------------------------------------------------------- /deploy_servers/llm_server/base_sgl.sh: -------------------------------------------------------------------------------- 1 | model_name_or_path=$1 2 | TP=$2 3 | 4 | export VLLM_USE_MODELSCOPE="False" 5 | 6 | python=path_to/c3po/bin/python 7 | 8 | # Automatically detect the number of GPUs 9 | GPU_NUM=$(nvidia-smi -L | wc -l) 10 | 11 | ${python} -m sglang.launch_server --model-path $model_name_or_path --host 0.0.0.0 --tp $TP --dp $((${GPU_NUM} / ${TP})) --port 10080 --mem-fraction-static 0.9 12 | 13 | -------------------------------------------------------------------------------- /deploy_servers/llm_server/qwen_72b_serve.sh: -------------------------------------------------------------------------------- 1 | 2 | mkdir -p ./logs 3 | 4 | hostname -I 5 | 6 | model_name_or_path=path_to_model_cache/Qwen2-72B-Instruct 7 | 8 | TP=4 9 | basename=$(basename $model_name_or_path) 10 | 11 | bash base_sgl.sh $model_name_or_path $TP 12 | -------------------------------------------------------------------------------- /deploy_servers/retrieve_server/retrieve_code/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chen-GX/C-3PO/e9a8021521a9eb006b7581b8c672dc3838cfc66c/deploy_servers/retrieve_server/retrieve_code/src/__init__.py -------------------------------------------------------------------------------- /deploy_servers/retrieve_server/retrieve_code/src/contriever.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import os 4 | import torch 5 | import transformers 6 | from transformers import BertModel, XLMRobertaModel 7 | 8 | from src import utils 9 | 10 | 11 | class Contriever(BertModel): 12 | def __init__(self, config, pooling="average", **kwargs): 13 | super().__init__(config, add_pooling_layer=False) 14 | if not hasattr(config, "pooling"): 15 | self.config.pooling = pooling 16 | 17 | def forward( 18 | self, 19 | input_ids=None, 20 | attention_mask=None, 21 | token_type_ids=None, 22 | position_ids=None, 23 | head_mask=None, 24 | inputs_embeds=None, 25 | encoder_hidden_states=None, 26 | encoder_attention_mask=None, 27 | output_attentions=None, 28 | output_hidden_states=None, 29 | normalize=False, 30 | ): 31 | 32 | model_output = super().forward( 33 | input_ids=input_ids, 34 | attention_mask=attention_mask, 35 | token_type_ids=token_type_ids, 36 | position_ids=position_ids, 37 | head_mask=head_mask, 38 | inputs_embeds=inputs_embeds, 39 | encoder_hidden_states=encoder_hidden_states, 40 | encoder_attention_mask=encoder_attention_mask, 41 | output_attentions=output_attentions, 42 | output_hidden_states=output_hidden_states, 43 | ) 44 | 45 | last_hidden = model_output["last_hidden_state"] 46 | last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0) 47 | 48 | if self.config.pooling == "average": 49 | emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 50 | elif self.config.pooling == "cls": 51 | emb = last_hidden[:, 0] 52 | 53 | if normalize: 54 | emb = torch.nn.functional.normalize(emb, dim=-1) 55 | return emb 56 | 57 | 58 | class XLMRetriever(XLMRobertaModel): 59 | def __init__(self, config, pooling="average", **kwargs): 60 | super().__init__(config, add_pooling_layer=False) 61 | if not hasattr(config, "pooling"): 62 | self.config.pooling = pooling 63 | 64 | def forward( 65 | self, 66 | input_ids=None, 67 | attention_mask=None, 68 | token_type_ids=None, 69 | position_ids=None, 70 | head_mask=None, 71 | inputs_embeds=None, 72 | encoder_hidden_states=None, 73 | encoder_attention_mask=None, 74 | output_attentions=None, 75 | output_hidden_states=None, 76 | normalize=False, 77 | ): 78 | 79 | model_output = super().forward( 80 | input_ids=input_ids, 81 | attention_mask=attention_mask, 82 | token_type_ids=token_type_ids, 83 | position_ids=position_ids, 84 | head_mask=head_mask, 85 | inputs_embeds=inputs_embeds, 86 | encoder_hidden_states=encoder_hidden_states, 87 | encoder_attention_mask=encoder_attention_mask, 88 | output_attentions=output_attentions, 89 | output_hidden_states=output_hidden_states, 90 | ) 91 | 92 | last_hidden = model_output["last_hidden_state"] 93 | last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0) 94 | if self.config.pooling == "average": 95 | emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 96 | elif self.config.pooling == "cls": 97 | emb = last_hidden[:, 0] 98 | if normalize: 99 | emb = torch.nn.functional.normalize(emb, dim=-1) 100 | return emb 101 | 102 | 103 | def load_retriever(model_path, pooling="average", random_init=False): 104 | # try: check if model exists locally 105 | path = os.path.join(model_path, "checkpoint.pth") 106 | if os.path.exists(path): 107 | pretrained_dict = torch.load(path, map_location="cpu") 108 | opt = pretrained_dict["opt"] 109 | if hasattr(opt, "retriever_model_id"): 110 | retriever_model_id = opt.retriever_model_id 111 | else: 112 | # retriever_model_id = "bert-base-uncased" 113 | retriever_model_id = "bert-base-multilingual-cased" 114 | 
tokenizer = utils.load_hf(transformers.AutoTokenizer, retriever_model_id) 115 | cfg = utils.load_hf(transformers.AutoConfig, retriever_model_id) 116 | if "xlm" in retriever_model_id: 117 | model_class = XLMRetriever 118 | else: 119 | model_class = Contriever 120 | retriever = model_class(cfg) 121 | pretrained_dict = pretrained_dict["model"] 122 | 123 | if any("encoder_q." in key for key in pretrained_dict.keys()): # test if model is defined with moco class 124 | pretrained_dict = {k.replace("encoder_q.", ""): v for k, v in pretrained_dict.items() if "encoder_q." in k} 125 | elif any("encoder." in key for key in pretrained_dict.keys()): # test if model is defined with inbatch class 126 | pretrained_dict = {k.replace("encoder.", ""): v for k, v in pretrained_dict.items() if "encoder." in k} 127 | retriever.load_state_dict(pretrained_dict, strict=False) 128 | else: 129 | retriever_model_id = model_path 130 | if "xlm" in retriever_model_id: 131 | model_class = XLMRetriever 132 | else: 133 | model_class = Contriever 134 | cfg = utils.load_hf(transformers.AutoConfig, model_path) 135 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_path) 136 | retriever = utils.load_hf(model_class, model_path) 137 | 138 | return retriever, tokenizer, retriever_model_id 139 | -------------------------------------------------------------------------------- /deploy_servers/retrieve_server/retrieve_code/src/dist_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | 7 | class Gather(torch.autograd.Function): 8 | @staticmethod 9 | def forward(ctx, x: torch.tensor): 10 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] 11 | dist.all_gather(output, x) 12 | return tuple(output) 13 | 14 | @staticmethod 15 | def backward(ctx, *grads): 16 | all_gradients = torch.stack(grads) 17 | dist.all_reduce(all_gradients) 18 | return all_gradients[dist.get_rank()] 19 | 20 | 21 | def gather(x: torch.tensor): 22 | if not dist.is_initialized(): 23 | return x 24 | x_gather = Gather.apply(x) 25 | x_gather = torch.cat(x_gather, dim=0) 26 | return x_gather 27 | 28 | 29 | @torch.no_grad() 30 | def gather_nograd(x: torch.tensor): 31 | if not dist.is_initialized(): 32 | return x 33 | x_gather = [torch.ones_like(x) for _ in range(dist.get_world_size())] 34 | dist.all_gather(x_gather, x, async_op=False) 35 | 36 | x_gather = torch.cat(x_gather, dim=0) 37 | return x_gather 38 | 39 | 40 | @torch.no_grad() 41 | def varsize_gather_nograd(x: torch.Tensor): 42 | """gather tensors of different sizes along the first dimension""" 43 | if not dist.is_initialized(): 44 | return x 45 | 46 | # determine max size 47 | size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int) 48 | allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())] 49 | dist.all_gather(allsizes, size) 50 | max_size = max([size.cpu().max() for size in allsizes]) 51 | 52 | padded = torch.empty(max_size, *x.shape[1:], dtype=x.dtype, device=x.device) 53 | padded[: x.shape[0]] = x 54 | output = [torch.zeros_like(padded) for _ in range(dist.get_world_size())] 55 | dist.all_gather(output, padded) 56 | 57 | output = [tensor[: allsizes[k]] for k, tensor in enumerate(output)] 58 | output = torch.cat(output, dim=0) 59 | 60 | return output 61 | 62 | 63 | @torch.no_grad() 64 | def get_varsize(x: torch.Tensor): 65 | """gather tensors of different sizes along the first 
dimension""" 66 | if not dist.is_initialized(): 67 | return [x.shape[0]] 68 | 69 | # determine max size 70 | size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int) 71 | allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())] 72 | dist.all_gather(allsizes, size) 73 | allsizes = torch.cat(allsizes) 74 | return allsizes 75 | 76 | 77 | def get_rank(): 78 | if not dist.is_available(): 79 | return 0 80 | if not dist.is_initialized(): 81 | return 0 82 | return dist.get_rank() 83 | 84 | 85 | def is_main(): 86 | return get_rank() == 0 87 | 88 | 89 | def get_world_size(): 90 | if not dist.is_initialized(): 91 | return 1 92 | else: 93 | return dist.get_world_size() 94 | 95 | 96 | def barrier(): 97 | if dist.is_initialized(): 98 | dist.barrier() 99 | 100 | 101 | def average_main(x): 102 | if not dist.is_initialized(): 103 | return x 104 | if dist.is_initialized() and dist.get_world_size() > 1: 105 | dist.reduce(x, 0, op=dist.ReduceOp.SUM) 106 | if is_main(): 107 | x = x / dist.get_world_size() 108 | return x 109 | 110 | 111 | def sum_main(x): 112 | if not dist.is_initialized(): 113 | return x 114 | if dist.is_initialized() and dist.get_world_size() > 1: 115 | dist.reduce(x, 0, op=dist.ReduceOp.SUM) 116 | return x 117 | 118 | 119 | def weighted_average(x, count): 120 | if not dist.is_initialized(): 121 | if isinstance(x, torch.Tensor): 122 | x = x.item() 123 | return x, count 124 | t_loss = torch.tensor([x * count]).cuda() 125 | t_total = torch.tensor([count]).cuda() 126 | t_loss = sum_main(t_loss) 127 | t_total = sum_main(t_total) 128 | return (t_loss / t_total).item(), t_total.item() 129 | -------------------------------------------------------------------------------- /deploy_servers/retrieve_server/retrieve_code/src/evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import collections 9 | import logging 10 | import regex 11 | import string 12 | import unicodedata 13 | from functools import partial 14 | from multiprocessing import Pool as ProcessPool 15 | from typing import Tuple, List, Dict 16 | import numpy as np 17 | 18 | """ 19 | Evaluation code from DPR: https://github.com/facebookresearch/DPR 20 | """ 21 | 22 | class SimpleTokenizer(object): 23 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 24 | NON_WS = r'[^\p{Z}\p{C}]' 25 | 26 | def __init__(self): 27 | """ 28 | Args: 29 | annotators: None or empty set (only tokenizes). 30 | """ 31 | self._regexp = regex.compile( 32 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 33 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 34 | ) 35 | 36 | def tokenize(self, text, uncased=False): 37 | matches = [m for m in self._regexp.finditer(text)] 38 | if uncased: 39 | tokens = [m.group().lower() for m in matches] 40 | else: 41 | tokens = [m.group() for m in matches] 42 | return tokens 43 | 44 | logger = logging.getLogger(__name__) 45 | 46 | QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits', 'questions_doc_hits']) 47 | 48 | def calculate_matches(data: List, workers_num: int): 49 | """ 50 | Evaluates answers presence in the set of documents. This function is supposed to be used with a large collection of 51 | documents and results. 
It internally forks multiple sub-processes for evaluation and then merges results 52 | :param all_docs: dictionary of the entire documents database. doc_id -> (doc_text, title) 53 | :param answers: list of answers's list. One list per question 54 | :param closest_docs: document ids of the top results along with their scores 55 | :param workers_num: amount of parallel threads to process data 56 | :param match_type: type of answer matching. Refer to has_answer code for available options 57 | :return: matching information tuple. 58 | top_k_hits - a list where the index is the amount of top documents retrieved and the value is the total amount of 59 | valid matches across an entire dataset. 60 | questions_doc_hits - more detailed info with answer matches for every question and every retrieved document 61 | """ 62 | 63 | logger.info('Matching answers in top docs...') 64 | 65 | tokenizer = SimpleTokenizer() 66 | get_score_partial = partial(check_answer, tokenizer=tokenizer) 67 | 68 | processes = ProcessPool(processes=workers_num) 69 | scores = processes.map(get_score_partial, data) 70 | 71 | logger.info('Per question validation results len=%d', len(scores)) 72 | 73 | n_docs = len(data[0]['ctxs']) 74 | top_k_hits = [0] * n_docs 75 | for question_hits in scores: 76 | best_hit = next((i for i, x in enumerate(question_hits) if x), None) 77 | if best_hit is not None: 78 | top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]] 79 | 80 | return QAMatchStats(top_k_hits, scores) 81 | 82 | def check_answer(example, tokenizer) -> List[bool]: 83 | """Search through all the top docs to see if they have any of the answers.""" 84 | answers = example['answers'] 85 | ctxs = example['ctxs'] 86 | 87 | hits = [] 88 | 89 | for i, doc in enumerate(ctxs): 90 | text = doc['text'] 91 | 92 | if text is None: # cannot find the document for some reason 93 | logger.warning("no doc in db") 94 | hits.append(False) 95 | continue 96 | 97 | hits.append(has_answer(answers, text, tokenizer)) 98 | 99 | return hits 100 | 101 | def has_answer(answers, text, tokenizer) -> bool: 102 | """Check if a document contains an answer string.""" 103 | text = _normalize(text) 104 | text = tokenizer.tokenize(text, uncased=True) 105 | 106 | for answer in answers: 107 | answer = _normalize(answer) 108 | answer = tokenizer.tokenize(answer, uncased=True) 109 | for i in range(0, len(text) - len(answer) + 1): 110 | if answer == text[i: i + len(answer)]: 111 | return True 112 | return False 113 | 114 | ################################################# 115 | ######## READER EVALUATION ######## 116 | ################################################# 117 | 118 | def _normalize(text): 119 | return unicodedata.normalize('NFD', text) 120 | 121 | #Normalization and score functions from SQuAD evaluation script https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/ 122 | def normalize_answer(s): 123 | def remove_articles(text): 124 | return regex.sub(r'\b(a|an|the)\b', ' ', text) 125 | 126 | def white_space_fix(text): 127 | return ' '.join(text.split()) 128 | 129 | def remove_punc(text): 130 | exclude = set(string.punctuation) 131 | return ''.join(ch for ch in text if ch not in exclude) 132 | 133 | def lower(text): 134 | return text.lower() 135 | 136 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 137 | 138 | def em(prediction, ground_truth): 139 | return normalize_answer(prediction) == normalize_answer(ground_truth) 140 | 141 | def f1(prediction, ground_truth): 142 | prediction_tokens = 
normalize_answer(prediction).split() 143 | ground_truth_tokens = normalize_answer(ground_truth).split() 144 | common = collections.Counter(prediction_tokens) & collections.Counter(ground_truth_tokens) 145 | num_same = sum(common.values()) 146 | if num_same == 0: 147 | return 0 148 | precision = 1.0 * num_same / len(prediction_tokens) 149 | recall = 1.0 * num_same / len(ground_truth_tokens) 150 | f1 = (2 * precision * recall) / (precision + recall) 151 | return f1 152 | 153 | def f1_score(prediction, ground_truths): 154 | return max([f1(prediction, gt) for gt in ground_truths]) 155 | 156 | def exact_match_score(prediction, ground_truths): 157 | return max([em(prediction, gt) for gt in ground_truths]) 158 | 159 | #################################################### 160 | ######## RETRIEVER EVALUATION ######## 161 | #################################################### 162 | 163 | def eval_batch(scores, inversions, avg_topk, idx_topk): 164 | for k, s in enumerate(scores): 165 | s = s.cpu().numpy() 166 | sorted_idx = np.argsort(-s) 167 | score(sorted_idx, inversions, avg_topk, idx_topk) 168 | 169 | def count_inversions(arr): 170 | inv_count = 0 171 | lenarr = len(arr) 172 | for i in range(lenarr): 173 | for j in range(i + 1, lenarr): 174 | if (arr[i] > arr[j]): 175 | inv_count += 1 176 | return inv_count 177 | 178 | def score(x, inversions, avg_topk, idx_topk): 179 | x = np.array(x) 180 | inversions.append(count_inversions(x)) 181 | for k in avg_topk: 182 | # ratio of passages in the predicted top-k that are 183 | # also in the topk given by gold score 184 | avg_pred_topk = (x[:k] < k).mean() 185 | avg_topk[k].append(avg_pred_topk) -------------------------------------------------------------------------------- /deploy_servers/retrieve_server/retrieve_code/src/finetuning_data.py: -------------------------------------------------------------------------------- 42 | if n_random_negatives > 0: 43 | random_negatives = random.sample(example["negative_ctxs"], n_random_negatives) 44 | negatives += random_negatives 45 | if n_hard_negatives > 0: 46 | hard_negatives = random.sample( 47 | example["hard_negative_ctxs"][self.negative_hard_min_idx :], n_hard_negatives 48 | ) 49 | negatives += hard_negatives 50 | else: 51 | gold = example["positive_ctxs"][0] 52 | nidx = 0 53 | if "negative_ctxs" in example: 54 | negatives = [example["negative_ctxs"][nidx]] 55 | else: 56 | negatives = [] 57 | 58 | gold = gold["title"] + " " + gold["text"] if "title" in gold and len(gold["title"]) > 0 else gold["text"] 59 | 60 | negatives = [ 61 | n["title"] + " " + n["text"] if ("title" in n and len(n["title"]) > 0) else n["text"] for n in negatives 62 | ] 63 | 64 | example = { 65 | "query": self.normalize_fn(question), 66 | "gold": self.normalize_fn(gold), 67 | "negatives": [self.normalize_fn(n) for n in negatives], 68 | } 69 | return example 70 | 71 | def _load_data(self, datapaths, global_rank, world_size, maxload): 72 | counter = 0 73 | self.data = [] 74 | for path in datapaths: 75 | path = str(path) 76 | if path.endswith(".jsonl"): 77 | file_data, counter = self._load_data_jsonl(path, global_rank, world_size, counter, maxload) 78 | elif path.endswith(".json"): 79 | file_data, counter = self._load_data_json(path, global_rank, world_size, counter, maxload) 80 | self.data.extend(file_data) 81 | if maxload is not None and maxload > 0 and counter >= maxload: 82 | break 83 | 84 | def _load_data_json(self, path, global_rank, world_size, counter, maxload=None): 85 | examples = [] 86 | with open(path, "r") as fin: 87 | data = json.load(fin) 88 | for example in data: 89 | counter += 1 90 | if global_rank > -1 and not counter % world_size == global_rank: 91 | continue 92 | examples.append(example) 93 | if maxload is not None and maxload > 0 and counter == maxload: 94 | break 95 | 96 | return examples, counter 97 | 98 | def _load_data_jsonl(self, path, 
global_rank, world_size, counter, maxload=None): 99 | examples = [] 100 | with open(path, "r") as fin: 101 | for line in fin: 102 | counter += 1 103 | if global_rank > -1 and not counter % world_size == global_rank: 104 | continue 105 | example = json.loads(line) 106 | examples.append(example) 107 | if maxload is not None and maxload > 0 and counter == maxload: 108 | break 109 | 110 | return examples, counter 111 | 112 | def sample_n_hard_negatives(self, ex): 113 | 114 | if "hard_negative_ctxs" in ex: 115 | n_hard_negatives = sum([random.random() < self.negative_hard_ratio for _ in range(self.negative_ctxs)]) 116 | n_hard_negatives = min(n_hard_negatives, len(ex["hard_negative_ctxs"][self.negative_hard_min_idx :])) 117 | else: 118 | n_hard_negatives = 0 119 | n_random_negatives = self.negative_ctxs - n_hard_negatives 120 | if "negative_ctxs" in ex: 121 | n_random_negatives = min(n_random_negatives, len(ex["negative_ctxs"])) 122 | else: 123 | n_random_negatives = 0 124 | return n_hard_negatives, n_random_negatives 125 | 126 | 127 | class Collator(object): 128 | def __init__(self, tokenizer, passage_maxlength=200): 129 | self.tokenizer = tokenizer 130 | self.passage_maxlength = passage_maxlength 131 | 132 | def __call__(self, batch): 133 | queries = [ex["query"] for ex in batch] 134 | golds = [ex["gold"] for ex in batch] 135 | negs = [item for ex in batch for item in ex["negatives"]] 136 | allpassages = golds + negs 137 | 138 | qout = self.tokenizer.batch_encode_plus( 139 | queries, 140 | max_length=self.passage_maxlength, 141 | truncation=True, 142 | padding=True, 143 | add_special_tokens=True, 144 | return_tensors="pt", 145 | ) 146 | kout = self.tokenizer.batch_encode_plus( 147 | allpassages, 148 | max_length=self.passage_maxlength, 149 | truncation=True, 150 | padding=True, 151 | add_special_tokens=True, 152 | return_tensors="pt", 153 | ) 154 | q_tokens, q_mask = qout["input_ids"], qout["attention_mask"].bool() 155 | k_tokens, k_mask = kout["input_ids"], kout["attention_mask"].bool() 156 | 157 | g_tokens, g_mask = k_tokens[: len(golds)], k_mask[: len(golds)] 158 | n_tokens, n_mask = k_tokens[len(golds) :], k_mask[len(golds) :] 159 | 160 | batch = { 161 | "q_tokens": q_tokens, 162 | "q_mask": q_mask, 163 | "k_tokens": k_tokens, 164 | "k_mask": k_mask, 165 | "g_tokens": g_tokens, 166 | "g_mask": g_mask, 167 | "n_tokens": n_tokens, 168 | "n_mask": n_mask, 169 | } 170 | 171 | return batch 172 | -------------------------------------------------------------------------------- /deploy_servers/retrieve_server/retrieve_code/src/inbatch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | import math 7 | import random 8 | import transformers 9 | import logging 10 | import torch.distributed as dist 11 | 12 | from src import contriever, dist_utils, utils 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class InBatch(nn.Module): 18 | def __init__(self, opt, retriever=None, tokenizer=None): 19 | super(InBatch, self).__init__() 20 | 21 | self.opt = opt 22 | self.norm_doc = opt.norm_doc 23 | self.norm_query = opt.norm_query 24 | self.label_smoothing = opt.label_smoothing 25 | if retriever is None or tokenizer is None: 26 | retriever, tokenizer = self._load_retriever( 27 | opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init 28 | ) 29 | self.tokenizer = tokenizer 30 | self.encoder = retriever 31 | 32 | def _load_retriever(self, model_id, pooling, random_init): 33 | cfg = utils.load_hf(transformers.AutoConfig, model_id) 34 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id) 35 | 36 | if "xlm" in model_id: 37 | model_class = contriever.XLMRetriever 38 | else: 39 | model_class = contriever.Contriever 40 | 41 | if random_init: 42 | retriever = model_class(cfg) 43 | else: 44 | retriever = utils.load_hf(model_class, model_id) 45 | 46 | if "bert-" in model_id: 47 | if tokenizer.bos_token_id is None: 48 | tokenizer.bos_token = "[CLS]" 49 | if tokenizer.eos_token_id is None: 50 | tokenizer.eos_token = "[SEP]" 51 | 52 | retriever.config.pooling = pooling 53 | 54 | return retriever, tokenizer 55 | 56 | def get_encoder(self): 57 | return self.encoder 58 | 59 | def forward(self, q_tokens, q_mask, k_tokens, k_mask, stats_prefix="", iter_stats={}, **kwargs): 60 | 61 | bsz = len(q_tokens) 62 | labels = torch.arange(0, bsz, dtype=torch.long, device=q_tokens.device) 63 | 64 | qemb = self.encoder(input_ids=q_tokens, attention_mask=q_mask, normalize=self.norm_query) 65 | kemb = self.encoder(input_ids=k_tokens, attention_mask=k_mask, normalize=self.norm_doc) 66 | 67 | gather_fn = dist_utils.gather 68 | 69 | gather_kemb = gather_fn(kemb) 70 | 71 | labels = labels + dist_utils.get_rank() * len(kemb) 72 | 73 | scores = torch.einsum("id, jd->ij", qemb / self.opt.temperature, gather_kemb) 74 | 75 | loss = torch.nn.functional.cross_entropy(scores, labels, label_smoothing=self.label_smoothing) 76 | 77 | # log stats 78 | if len(stats_prefix) > 0: 79 | stats_prefix = stats_prefix + "/" 80 | iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz) 81 | 82 | predicted_idx = torch.argmax(scores, dim=-1) 83 | accuracy = 100 * (predicted_idx == labels).float().mean() 84 | stdq = torch.std(qemb, dim=0).mean().item() 85 | stdk = torch.std(kemb, dim=0).mean().item() 86 | iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz) 87 | iter_stats[f"{stats_prefix}stdq"] = (stdq, bsz) 88 | iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz) 89 | 90 | return loss, iter_stats 91 | -------------------------------------------------------------------------------- /deploy_servers/retrieve_server/retrieve_code/src/index.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import os 8 | import pickle 9 | import time 10 | from typing import List, Tuple 11 | 12 | import faiss 13 | import numpy as np 14 | from tqdm import tqdm 15 | 16 | class Indexer(object): 17 | 18 | def __init__(self, vector_sz, index_type='FlatIP', nlist=200000, n_subquantizers=64, n_bits=8): 19 | # vector_sz: embedding dimension 20 | self.index_type = index_type 21 | if index_type == 'IVFPQ': 22 | quantizer = faiss.IndexFlatIP(vector_sz) # use inner product for the coarse quantizer 23 | self.index = faiss.IndexIVFPQ(quantizer, vector_sz, nlist, n_subquantizers, n_bits) 24 | elif index_type == 'PQ': 25 | self.index = faiss.IndexPQ(vector_sz, n_subquantizers, n_bits, faiss.METRIC_INNER_PRODUCT) 26 | elif index_type == "FlatIP": 27 | self.index = faiss.IndexFlatIP(vector_sz) # default to IndexFlatIP 28 | else: 29 | raise NotImplementedError(f"unknown index_type: {index_type}") 30 | 31 | print(f'index_info:\nvector_sz:{vector_sz}\nindex_type:{index_type}\nnlist:{nlist}\nn_subquantizers:{n_subquantizers}\nn_bits:{n_bits}') 32 | 33 | #self.index_id_to_db_id = np.empty((0), dtype=np.int64) 34 | self.index_id_to_db_id = [] 35 | 36 | def index_data(self, ids, embeddings): 37 | self._update_id_mapping(ids) 38 | # embeddings = embeddings.astype('float32') 39 | embeddings = np.array(embeddings, dtype=np.float32) 40 | if not self.index.is_trained: 41 | print(f'start training...') 42 | begin_time = time.time() 43 | self.index.train(embeddings) 44 | print(f'training time: {time.time() - begin_time}') 45 | self.index.add(embeddings) 46 | 47 | print(f'Total data indexed {len(self.index_id_to_db_id)}') 48 | 49 | def search_knn(self, query_vectors: np.array, top_docs: int, index_batch_size: int = 4096) -> List[Tuple[List[object], List[float]]]: 50 | # import pdb; pdb.set_trace() 51 | # query_vectors = query_vectors.astype('float32') 52 | query_vectors = np.array(query_vectors, dtype=np.float32) 53 | result = [] 54 | nbatch = (len(query_vectors)-1) // index_batch_size + 1 55 | for k in tqdm(range(nbatch), desc=f"searching bz: {index_batch_size}"): 56 | start_idx = k*index_batch_size 57 | end_idx = min((k+1)*index_batch_size, len(query_vectors)) 58 | q = query_vectors[start_idx: end_idx] 59 | # scores, indexes = self.index.search(n=q.shape[0], x=q, k=top_docs) 60 | scores, indexes = self.index.search(q, top_docs) 61 | # convert to external ids 62 | db_ids = [[str(self.index_id_to_db_id[i]) for i in query_top_idxs] for query_top_idxs in indexes] 63 | result.extend([(db_ids[i], scores[i]) for i in range(len(db_ids))]) 64 | return result 65 | 66 | def serialize(self, dir_path): 67 | index_file = os.path.join(dir_path, f'{self.index_type}_index.faiss') 68 | meta_file = os.path.join(dir_path, f'{self.index_type}_index_meta.faiss') 69 | print(f'Serializing index to {index_file}, meta data to {meta_file}') 70 | 71 | faiss.write_index(self.index, index_file) 72 | with open(meta_file, mode='wb') as f: 73 | pickle.dump(self.index_id_to_db_id, f) 74 | 75 | def deserialize_from(self, dir_path): 76 | index_file = os.path.join(dir_path, f'{self.index_type}_index.faiss') 77 | meta_file = os.path.join(dir_path, f'{self.index_type}_index_meta.faiss') 78 | print(f'Loading index from {index_file}, meta data from {meta_file}') 79 | 80 | self.index = faiss.read_index(index_file) 81 | print('Loaded index of type %s and size %d' % (type(self.index), self.index.ntotal)) 82 | 83 | with open(meta_file, "rb") as reader: 84 | self.index_id_to_db_id = pickle.load(reader) 85 | assert len( 86 | self.index_id_to_db_id) == self.index.ntotal, 'Deserialized index_id_to_db_id
should match faiss index size' 87 | 88 | def _update_id_mapping(self, db_ids: List): 89 | #new_ids = np.array(db_ids, dtype=np.int64) 90 | #self.index_id_to_db_id = np.concatenate((self.index_id_to_db_id, new_ids), axis=0) 91 | self.index_id_to_db_id.extend(db_ids) -------------------------------------------------------------------------------- /deploy_servers/retrieve_server/retrieve_code/src/moco.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import torch 4 | import torch.nn as nn 5 | import logging 6 | import copy 7 | import transformers 8 | 9 | from src import contriever, dist_utils, utils 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class MoCo(nn.Module): 15 | def __init__(self, opt): 16 | super(MoCo, self).__init__() 17 | 18 | self.queue_size = opt.queue_size 19 | self.momentum = opt.momentum 20 | self.temperature = opt.temperature 21 | self.label_smoothing = opt.label_smoothing 22 | self.norm_doc = opt.norm_doc 23 | self.norm_query = opt.norm_query 24 | self.moco_train_mode_encoder_k = opt.moco_train_mode_encoder_k # apply the encoder on keys in train mode 25 | 26 | retriever, tokenizer = self._load_retriever( 27 | opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init 28 | ) 29 | 30 | self.tokenizer = tokenizer 31 | self.encoder_q = retriever 32 | self.encoder_k = copy.deepcopy(retriever) 33 | 34 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 35 | param_k.data.copy_(param_q.data) 36 | param_k.requires_grad = False 37 | 38 | # create the queue 39 | self.register_buffer("queue", torch.randn(opt.projection_size, self.queue_size)) 40 | self.queue = nn.functional.normalize(self.queue, dim=0) 41 | 42 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 43 | 44 | def _load_retriever(self, model_id, pooling, random_init): 45 | cfg = utils.load_hf(transformers.AutoConfig, model_id) 46 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id) 47 | 48 | if "xlm" in model_id: 49 | model_class = contriever.XLMRetriever 50 | else: 51 | model_class = contriever.Contriever 52 | 53 | if random_init: 54 | retriever = model_class(cfg) 55 | else: 56 | retriever = utils.load_hf(model_class, model_id) 57 | 58 | if "bert-" in model_id: 59 | if tokenizer.bos_token_id is None: 60 | tokenizer.bos_token = "[CLS]" 61 | if tokenizer.eos_token_id is None: 62 | tokenizer.eos_token = "[SEP]" 63 | 64 | retriever.config.pooling = pooling 65 | 66 | return retriever, tokenizer 67 | 68 | def get_encoder(self, return_encoder_k=False): 69 | if return_encoder_k: 70 | return self.encoder_k 71 | else: 72 | return self.encoder_q 73 | 74 | def _momentum_update_key_encoder(self): 75 | """ 76 | Update of the key encoder 77 | """ 78 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 79 | param_k.data = param_k.data * self.momentum + param_q.data * (1.0 - self.momentum) 80 | 81 | @torch.no_grad() 82 | def _dequeue_and_enqueue(self, keys): 83 | # gather keys before updating queue 84 | keys = dist_utils.gather_nograd(keys.contiguous()) 85 | 86 | batch_size = keys.shape[0] 87 | 88 | ptr = int(self.queue_ptr) 89 | assert self.queue_size % batch_size == 0, f"{batch_size}, {self.queue_size}" # for simplicity 90 | 91 | # replace the keys at ptr (dequeue and enqueue) 92 | self.queue[:, ptr : ptr + batch_size] = keys.T 93 | ptr = (ptr + batch_size) % self.queue_size # move pointer 
94 | 95 | self.queue_ptr[0] = ptr 96 | 97 | def _compute_logits(self, q, k): 98 | l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1) 99 | l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()]) 100 | 101 | logits = torch.cat([l_pos, l_neg], dim=1) 102 | return logits 103 | 104 | def forward(self, q_tokens, q_mask, k_tokens, k_mask, stats_prefix="", iter_stats={}, **kwargs): 105 | bsz = q_tokens.size(0) 106 | 107 | q = self.encoder_q(input_ids=q_tokens, attention_mask=q_mask, normalize=self.norm_query) 108 | 109 | # compute key features 110 | with torch.no_grad(): # no gradient to keys 111 | self._momentum_update_key_encoder() # update the key encoder 112 | 113 | if not self.encoder_k.training and not self.moco_train_mode_encoder_k: 114 | self.encoder_k.eval() 115 | 116 | k = self.encoder_k(input_ids=k_tokens, attention_mask=k_mask, normalize=self.norm_doc) 117 | 118 | logits = self._compute_logits(q, k) / self.temperature 119 | 120 | # labels: positive key indicators 121 | labels = torch.zeros(bsz, dtype=torch.long).cuda() 122 | 123 | loss = torch.nn.functional.cross_entropy(logits, labels, label_smoothing=self.label_smoothing) 124 | 125 | self._dequeue_and_enqueue(k) 126 | 127 | # log stats 128 | if len(stats_prefix) > 0: 129 | stats_prefix = stats_prefix + "/" 130 | iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz) 131 | 132 | predicted_idx = torch.argmax(logits, dim=-1) 133 | accuracy = 100 * (predicted_idx == labels).float().mean() 134 | stdq = torch.std(q, dim=0).mean().item() 135 | stdk = torch.std(k, dim=0).mean().item() 136 | iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz) 137 | iter_stats[f"{stats_prefix}stdq"] = (stdq, bsz) 138 | iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz) 139 | 140 | return loss, iter_stats 141 | -------------------------------------------------------------------------------- /deploy_servers/retrieve_server/retrieve_code/src/normalize_text.py: -------------------------------------------------------------------------------- 1 | """ 2 | adapted from chemdataextractor.text.normalize 3 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 4 | Tools for normalizing text. 5 | https://github.com/mcs07/ChemDataExtractor 6 | :copyright: Copyright 2016 by Matt Swain. 7 | :license: MIT 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining 10 | a copy of this software and associated documentation files (the 11 | 'Software'), to deal in the Software without restriction, including 12 | without limitation the rights to use, copy, modify, merge, publish, 13 | distribute, sublicense, and/or sell copies of the Software, and to 14 | permit persons to whom the Software is furnished to do so, subject to 15 | the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be 18 | included in all copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, 21 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 22 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 23 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 24 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 25 | TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 26 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 27 | """ 28 | 29 | #: Control characters. 
30 | CONTROLS = { 31 | '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u000e', '\u000f', '\u0011', 32 | '\u0012', '\u0013', '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001a', '\u001b', 33 | } 34 | # There are further control characters, but they are instead replaced with a space by unicode normalization 35 | # '\u0009', '\u000a', '\u000b', '\u000c', '\u000d', '\u001c', '\u001d', '\u001e', '\u001f' 36 | 37 | 38 | #: Hyphen and dash characters. 39 | HYPHENS = { 40 | '-', # \u002d Hyphen-minus 41 | '‐', # \u2010 Hyphen 42 | '‑', # \u2011 Non-breaking hyphen 43 | '⁃', # \u2043 Hyphen bullet 44 | '‒', # \u2012 figure dash 45 | '–', # \u2013 en dash 46 | '—', # \u2014 em dash 47 | '―', # \u2015 horizontal bar 48 | } 49 | 50 | #: Minus characters. 51 | MINUSES = { 52 | '-', # \u002d Hyphen-minus 53 | '−', # \u2212 Minus 54 | '-', # \uff0d Full-width Hyphen-minus 55 | '⁻', # \u207b Superscript minus 56 | } 57 | 58 | #: Plus characters. 59 | PLUSES = { 60 | '+', # \u002b Plus 61 | '+', # \uff0b Full-width Plus 62 | '⁺', # \u207a Superscript plus 63 | } 64 | 65 | #: Slash characters. 66 | SLASHES = { 67 | '/', # \u002f Solidus 68 | '⁄', # \u2044 Fraction slash 69 | '∕', # \u2215 Division slash 70 | } 71 | 72 | #: Tilde characters. 73 | TILDES = { 74 | '~', # \u007e Tilde 75 | '˜', # \u02dc Small tilde 76 | '⁓', # \u2053 Swung dash 77 | '∼', # \u223c Tilde operator #in mbert vocab 78 | '∽', # \u223d Reversed tilde 79 | '∿', # \u223f Sine wave 80 | '〜', # \u301c Wave dash #in mbert vocab 81 | '~', # \uff5e Full-width tilde #in mbert vocab 82 | } 83 | 84 | #: Apostrophe characters. 85 | APOSTROPHES = { 86 | "'", # \u0027 87 | '’', # \u2019 88 | '՚', # \u055a 89 | 'Ꞌ', # \ua78b 90 | 'ꞌ', # \ua78c 91 | ''', # \uff07 92 | } 93 | 94 | #: Single quote characters. 95 | SINGLE_QUOTES = { 96 | "'", # \u0027 97 | '‘', # \u2018 98 | '’', # \u2019 99 | '‚', # \u201a 100 | '‛', # \u201b 101 | 102 | } 103 | 104 | #: Double quote characters. 105 | DOUBLE_QUOTES = { 106 | '"', # \u0022 107 | '“', # \u201c 108 | '”', # \u201d 109 | '„', # \u201e 110 | '‟', # \u201f 111 | } 112 | 113 | #: Accent characters. 114 | ACCENTS = { 115 | '`', # \u0060 116 | '´', # \u00b4 117 | } 118 | 119 | #: Prime characters. 120 | PRIMES = { 121 | '′', # \u2032 122 | '″', # \u2033 123 | '‴', # \u2034 124 | '‵', # \u2035 125 | '‶', # \u2036 126 | '‷', # \u2037 127 | '⁗', # \u2057 128 | } 129 | 130 | #: Quote characters, including apostrophes, single quotes, double quotes, accents and primes. 
QUOTES = APOSTROPHES | SINGLE_QUOTES | DOUBLE_QUOTES | ACCENTS | PRIMES 132 | 133 | def normalize(text): 134 | for control in CONTROLS: 135 | text = text.replace(control, '') 136 | text = text.replace('\u000b', ' ').replace('\u000c', ' ').replace(u'\u0085', ' ') 137 | 138 | for hyphen in HYPHENS | MINUSES: 139 | text = text.replace(hyphen, '-') 140 | text = text.replace('\u00ad', '') 141 | 142 | for double_quote in DOUBLE_QUOTES: 143 | text = text.replace(double_quote, '"') # \u0022 144 | for single_quote in (SINGLE_QUOTES | APOSTROPHES | ACCENTS): 145 | text = text.replace(single_quote, "'") # \u0027 146 | text = text.replace('′', "'") # \u2032 prime 147 | text = text.replace('‵', "'") # \u2035 reversed prime 148 | text = text.replace('″', "''") # \u2033 double prime 149 | text = text.replace('‶', "''") # \u2036 reversed double prime 150 | text = text.replace('‴', "'''") # \u2034 triple prime 151 | text = text.replace('‷', "'''") # \u2037 reversed triple prime 152 | text = text.replace('⁗', "''''") # \u2057 quadruple prime 153 | 154 | text = text.replace('…', '...').replace(' . . . ', ' ... ') # \u2026 155 | 156 | for slash in SLASHES: 157 | text = text.replace(slash, '/') 158 | 159 | #for tilde in TILDES: 160 | # text = text.replace(tilde, '~') 161 | 162 | return text 163 | -------------------------------------------------------------------------------- /deploy_servers/retrieve_server/retrieve_code/src/options.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import argparse 4 | import os 5 | 6 | 7 | class Options: 8 | def __init__(self): 9 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 10 | self.initialize() 11 | 12 | def initialize(self): 13 | # basic parameters 14 | self.parser.add_argument( 15 | "--output_dir", type=str, default="./checkpoint/my_experiments", help="models are saved here" 16 | ) 17 | self.parser.add_argument( 18 | "--train_data", 19 | nargs="+", 20 | default=[], 21 | help="Data used for training, passed as a list of directories split into tensor files.", 22 | ) 23 | self.parser.add_argument( 24 | "--eval_data", 25 | nargs="+", 26 | default=[], 27 | help="Data used for evaluation during finetuning; this option is not used during contrastive pre-training.", 28 | ) 29 | self.parser.add_argument( 30 | "--eval_datasets", nargs="+", default=[], help="List of datasets used for evaluation, in BEIR format" 31 | ) 32 | self.parser.add_argument( 33 | "--eval_datasets_dir", type=str, default="./", help="Directory where eval datasets are stored" 34 | ) 35 | self.parser.add_argument("--model_path", type=str, default="none", help="path for retraining") 36 | self.parser.add_argument("--continue_training", action="store_true") 37 | self.parser.add_argument("--num_workers", type=int, default=5) 38 | 39 | self.parser.add_argument("--chunk_length", type=int, default=256) 40 | self.parser.add_argument("--loading_mode", type=str, default="split") 41 | self.parser.add_argument("--lower_case", action="store_true", help="perform evaluation after lowercasing") 42 | self.parser.add_argument( 43 | "--sampling_coefficient", 44 | type=float, 45 | default=0.0, 46 | help="coefficient used for sampling between different datasets during training; \ 47 | by default sampling is uniform over datasets", 48 | ) 49 | self.parser.add_argument("--augmentation", type=str, default="none") 50 | self.parser.add_argument("--prob_augmentation", type=float, default=0.0) 51 | 52 | self.parser.add_argument("--dropout", type=float, default=0.1) 53 | self.parser.add_argument("--rho", type=float, default=0.05) 54 | 55 | self.parser.add_argument("--contrastive_mode", type=str, default="moco") 56 | self.parser.add_argument("--queue_size", type=int, default=65536) 57 | self.parser.add_argument("--temperature", type=float, default=1.0) 58 | self.parser.add_argument("--momentum", type=float, default=0.999) 59 | self.parser.add_argument("--moco_train_mode_encoder_k", action="store_true") 60 | self.parser.add_argument("--eval_normalize_text", action="store_true") 61 | self.parser.add_argument("--norm_query", action="store_true") 62 | self.parser.add_argument("--norm_doc", action="store_true") 63 | self.parser.add_argument("--projection_size", type=int, default=768) 64 | 65 | self.parser.add_argument("--ratio_min", type=float, default=0.1) 66 | self.parser.add_argument("--ratio_max", type=float, default=0.5) 67 | self.parser.add_argument("--score_function", type=str, default="dot") 68 | self.parser.add_argument("--retriever_model_id", type=str, default="bert-base-uncased") 69 | self.parser.add_argument("--pooling", type=str, default="average") 70 | self.parser.add_argument("--random_init", action="store_true", help="init model with random weights") 71 | 72 | # dataset parameters 73 | self.parser.add_argument("--per_gpu_batch_size", default=64, type=int, help="Batch size per GPU for training.") 74 | self.parser.add_argument( 75 | "--per_gpu_eval_batch_size", default=256, type=int, help="Batch size per GPU for evaluation." 76 | ) 77 | self.parser.add_argument("--total_steps", type=int, default=1000) 78 | self.parser.add_argument("--warmup_steps", type=int, default=-1) 79 | 80 | self.parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 81 | self.parser.add_argument("--main_port", type=int, default=10001, help="Master port (for multi-node SLURM jobs)") 82 | self.parser.add_argument("--seed", type=int, default=0, help="random seed for initialization") 83 | # training parameters 84 | self.parser.add_argument("--optim", type=str, default="adamw") 85 | self.parser.add_argument("--scheduler", type=str, default="linear") 86 | self.parser.add_argument("--lr", type=float, default=1e-4, help="learning rate") 87 | self.parser.add_argument( 88 | "--lr_min_ratio", 89 | type=float, 90 | default=0.0, 91 | help="minimum learning rate at the end of the optimization schedule as a ratio of the learning rate", 92 | ) 93 | self.parser.add_argument("--weight_decay", type=float, default=0.01, help="weight decay") 94 | self.parser.add_argument("--beta1", type=float, default=0.9, help="beta1") 95 | self.parser.add_argument("--beta2", type=float, default=0.98, help="beta2") 96 | self.parser.add_argument("--eps", type=float, default=1e-6, help="eps") 97 | self.parser.add_argument( 98 | "--log_freq", type=int, default=100, help="interval, in steps, between logging train stats" 99 | ) 100 | self.parser.add_argument( 101 | "--eval_freq", type=int, default=500, help="interval, in steps, between model evaluations during training" 102 | ) 103 | self.parser.add_argument("--save_freq", type=int, default=50000) 104 | self.parser.add_argument("--maxload", type=int, default=None) 105 | self.parser.add_argument("--label_smoothing", type=float, default=0.0) 106 | 107 | # finetuning options 108 | self.parser.add_argument("--negative_ctxs", type=int, default=1) 109 |
self.parser.add_argument("--negative_hard_min_idx", type=int, default=0) 110 | self.parser.add_argument("--negative_hard_ratio", type=float, default=0.0) 111 | 112 | def print_options(self, opt): 113 | message = "" 114 | for k, v in sorted(vars(opt).items()): 115 | comment = "" 116 | default = self.parser.get_default(k) 117 | if v != default: 118 | comment = f"\t[default: %s]" % str(default) 119 | message += f"{str(k):>40}: {str(v):<40}{comment}\n" 120 | print(message, flush=True) 121 | model_dir = os.path.join(opt.output_dir, "models") 122 | if not os.path.exists(model_dir): 123 | os.makedirs(os.path.join(opt.output_dir, "models")) 124 | file_name = os.path.join(opt.output_dir, "opt.txt") 125 | with open(file_name, "wt") as opt_file: 126 | opt_file.write(message) 127 | opt_file.write("\n") 128 | 129 | def parse(self): 130 | opt, _ = self.parser.parse_known_args() 131 | # opt = self.parser.parse_args() 132 | return opt 133 | -------------------------------------------------------------------------------- /deploy_servers/retrieve_server/retrieve_code/src/slurm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from logging import getLogger 8 | import os 9 | import sys 10 | import torch 11 | import socket 12 | import signal 13 | import subprocess 14 | 15 | 16 | logger = getLogger() 17 | 18 | def sig_handler(signum, frame): 19 | logger.warning("Signal handler called with signal " + str(signum)) 20 | prod_id = int(os.environ['SLURM_PROCID']) 21 | logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id)) 22 | if prod_id == 0: 23 | logger.warning("Requeuing job " + os.environ['SLURM_JOB_ID']) 24 | os.system('scontrol requeue ' + os.environ['SLURM_JOB_ID']) 25 | else: 26 | logger.warning("Not the main process, no need to requeue.") 27 | sys.exit(-1) 28 | 29 | 30 | def term_handler(signum, frame): 31 | logger.warning("Signal handler called with signal " + str(signum)) 32 | logger.warning("Bypassing SIGTERM.") 33 | 34 | 35 | def init_signal_handler(): 36 | """ 37 | Handle signals sent by SLURM for time limit / pre-emption. 38 | """ 39 | signal.signal(signal.SIGUSR1, sig_handler) 40 | signal.signal(signal.SIGTERM, term_handler) 41 | 42 | 43 | def init_distributed_mode(params): 44 | """ 45 | Handle single and multi-GPU / multi-node / SLURM jobs. 
46 | Initialize the following variables: 47 | - local_rank 48 | - global_rank 49 | - world_size 50 | """ 51 | is_slurm_job = 'SLURM_JOB_ID' in os.environ and not 'WORLD_SIZE' in os.environ 52 | has_local_rank = hasattr(params, 'local_rank') 53 | 54 | # SLURM job without torch.distributed.launch 55 | if is_slurm_job and has_local_rank: 56 | 57 | assert params.local_rank == -1 # on the cluster, this is handled by SLURM 58 | 59 | # local rank on the current node / global rank 60 | params.local_rank = int(os.environ['SLURM_LOCALID']) 61 | params.global_rank = int(os.environ['SLURM_PROCID']) 62 | params.world_size = int(os.environ['SLURM_NTASKS']) 63 | 64 | # define master address and master port 65 | hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']]) 66 | params.main_addr = hostnames.split()[0].decode('utf-8') 67 | assert 10001 <= params.main_port <= 20000 or params.world_size == 1 68 | 69 | # set environment variables for 'env://' 70 | os.environ['MASTER_ADDR'] = params.main_addr 71 | os.environ['MASTER_PORT'] = str(params.main_port) 72 | os.environ['WORLD_SIZE'] = str(params.world_size) 73 | os.environ['RANK'] = str(params.global_rank) 74 | is_distributed = True 75 | 76 | 77 | # multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch 78 | elif has_local_rank and params.local_rank != -1: 79 | 80 | assert params.main_port == -1 81 | 82 | # read environment variables 83 | params.global_rank = int(os.environ['RANK']) 84 | params.world_size = int(os.environ['WORLD_SIZE']) 85 | 86 | is_distributed = True 87 | 88 | # local job (single GPU) 89 | else: 90 | params.local_rank = 0 91 | params.global_rank = 0 92 | params.world_size = 1 93 | is_distributed = False 94 | 95 | # set GPU device 96 | torch.cuda.set_device(params.local_rank) 97 | 98 | # initialize multi-GPU 99 | if is_distributed: 100 | 101 | # http://pytorch.apachecn.org/en/0.3.0/distributed.html#environment-variable-initialization 102 | # 'env://' will read these environment variables: 103 | # MASTER_PORT - required; has to be a free port on machine with rank 0 104 | # MASTER_ADDR - required (except for rank 0); address of rank 0 node 105 | # WORLD_SIZE - required; can be set either here, or in a call to init function 106 | # RANK - required; can be set either here, or in a call to init function 107 | 108 | #print("Initializing PyTorch distributed ...") 109 | torch.distributed.init_process_group( 110 | init_method='env://', 111 | backend='nccl', 112 | #world_size=params.world_size, 113 | #rank=params.global_rank, 114 | ) -------------------------------------------------------------------------------- /deploy_servers/retrieve_server/retrieve_code/wiki18/query_serve.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | import time 4 | 5 | def use_flask(): # pass in the text to analyze, plus the token 6 | # url = "http://10.140.0.136:35000/search" 7 | # url = "http://10.32.25.199:35004/search" 8 | url = "http://10.32.25.199:35002/search" 9 | headers = {'Content-Type': 'application/json'} # set the request headers 10 | data = json.dumps({ 11 | # 'queries': ["The director of the film Swarg Narak"], # , "what's the weather today?" * 128 12 | # 'queries': ["Director of Detective Chinatown 2 birthplace"], 13 | 'queries': ["""the Brooklyn Dodgers and the Boston Braves on 11 August 11 1951"""], # Warren Beatty Who is the director of Bulworth? 14 | "n_docs": 10 15 | }) 16 | begin_time = time.time() 17 | response = requests.post(url, headers=headers, data=data) # send the POST request 18 | print(time.time() - begin_time) 19 | if response.status_code == 200: 20 | return response.status_code, response.json() # status code, plus the JSON response (here, a list of results) 21 | else: 22 | return response.status_code, response.raise_for_status() # raise an exception if the status code is not 200 23 | 24 | 25 | 26 | if __name__ == '__main__': 27 | data = use_flask() 28 | print(data) 29 | -------------------------------------------------------------------------------- /deploy_servers/retrieve_server/retrieve_code/wiki18/start_wiki18.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gunicorn_run=/opt/conda/envs/faiss/bin/gunicorn 4 | 5 | APP_MODULE="wiki18_serve:app" # import path pointing to your Flask app object 6 | WORKERS=1 # number of worker processes 7 | BIND="0.0.0.0:35004" # address and port for Gunicorn to listen on 8 | LOG_LEVEL="info" # log level (debug, info, warning, error, critical) 9 | 10 | # run the Gunicorn server 11 | $gunicorn_run -w $WORKERS -b $BIND --timeout 360 --log-level=$LOG_LEVEL $APP_MODULE 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /deploy_servers/retrieve_server/retrieve_code/wiki18/wiki18_config.yaml: -------------------------------------------------------------------------------- 1 | corpus_name: wiki18 2 | retrieve_model: ./model_cache/contriever-msmarco 3 | corpus: ./C-3PO/deploy_servers/retrieve_server/wiki18/psgs_w100.tsv 4 | corpus_embeddings: ./C-3PO/deploy_servers/retrieve_server/wikipedia_embeddings 5 | n_subquantizers: 32 6 | n_bits: 8 7 | nlist: 30000 8 | save_or_load_index: True 9 | indexing_batch_size: 5000000 # read all the data in at once 10 | validation_workers: 32 11 | per_gpu_batch_size: 512 12 | question_maxlength: 256 13 | projection_size: 768 # embedding dimension 14 | lowercase: True 15 | normalize_text: True 16 | no_fp16: False 17 | use_gpu: True 18 | index_type: FlatIP 19 | index_batch_size: 4096 -------------------------------------------------------------------------------- /deploy_servers/retrieve_server/retrieve_code/wiki18/wiki18_serve.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | os.chdir(sys.path[0]) 3 | sys.path.append("..") 4 | 5 | import yaml 6 | from types import SimpleNamespace 7 | 8 | from passage_retrieval import Retriever 9 | from flask import Flask, request, jsonify 10 | 11 | app = Flask(__name__) 12 | 13 | # convert the config dict into an args-style SimpleNamespace 14 | def dict_to_simplenamespace(d: dict) -> SimpleNamespace: 15 | namespace = SimpleNamespace() 16 | for key, value in d.items(): 17 | if isinstance(value, dict): 18 | value = dict_to_simplenamespace(value) 19 | setattr(namespace, key, value) 20 | return namespace 21 | 22 | 23 | # load the configuration 24 | with open("wiki18_config.yaml", 'r', encoding='utf-8') as fin: 25 | configs_dict = yaml.load(fin, Loader=yaml.FullLoader) 26 | 27 | configs = dict_to_simplenamespace(configs_dict) 28 | 29 | # preload the retriever model 30 | retriever = Retriever(configs) 31 | retriever.setup_retriever() 32 | print("Retriever setup finished.") 33 | 34 | @app.route('/search', methods=['POST']) 35 | def search(): 36 | # get the query parameters from the request 37 | args = request.get_json() 38 | queries = args.get('queries', []) 39 | top_n = args.get('n_docs', 10) 40 | retrieved_documents = retriever.search_document(queries, top_n) 41 | 42 | # return the response as JSON 43 | return jsonify(retrieved_documents) 44 | 45 | if __name__ == '__main__': 46 | # start the Flask server; host='0.0.0.0' makes it reachable from outside 47 | app.run(host='0.0.0.0', port=35004,
debug=False) 48 | -------------------------------------------------------------------------------- /images/C-3PO.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chen-GX/C-3PO/e9a8021521a9eb006b7581b8c672dc3838cfc66c/images/C-3PO.png -------------------------------------------------------------------------------- /inference/single_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p ./logs 4 | 5 | dataname_list=2WikiMultiHopQA,hotpotqa,Musique,NaturalQuestions,PopQA,TriviaQA 6 | # split on commas 7 | dataname_list=(${dataname_list//,/ }) 8 | 9 | 10 | # environment 11 | export VLLM_USE_MODELSCOPE="False" 12 | export TOKENIZERS_PARALLELISM=false 13 | export CUDA_VISIBLE_DEVICES=0 14 | 15 | debug_num=-1 16 | tp=1 17 | proxy_concurrency=256 18 | model_type=proxy 19 | retriever_type=dense 20 | retrieve_server_url=http://10.32.25.199:35004/search 21 | musique_server_url=http://10.32.25.199:35002/search 22 | 23 | llm_server_url=http://10.32.4.13:10080/v1,http://10.32.17.208:10080/v1 # your llm server 24 | 25 | llm_name=qwen2-72b-instruct 26 | llm_server_type=online 27 | 28 | max_depth=13 29 | test=True 30 | temperature=0 31 | online_concurrency=32 32 | backend=sglang 33 | output_dir=output_path 34 | use_planning_cache=True 35 | wo_llm=False 36 | llm_query_few_shot=True 37 | 38 | port=20010 39 | 40 | folders=( 41 | path_to_your_ckpt 42 | ) 43 | 44 | python=/opt/conda/envs/c3po/bin/python 45 | 46 | for dataname in "${dataname_list[@]}" 47 | do 48 | for folder in "${folders[@]}" 49 | do 50 | echo ${dataname} 51 | $python ../C-3PO/main.py \ 52 | --model_type ${model_type} \ 53 | --tp ${tp} \ 54 | --proxy_concurrency ${proxy_concurrency} \ 55 | --dataname ${dataname} \ 56 | --output_dir ${output_dir} \ 57 | --checkpoint_dir ${folder} \ 58 | --retriever_type ${retriever_type} \ 59 | --retrieve_server_url ${retrieve_server_url} \ 60 | --musique_server_url ${musique_server_url} \ 61 | --llm_name ${llm_name} \ 62 | --llm_server_type ${llm_server_type} \ 63 | --llm_server_url ${llm_server_url} \ 64 | --debug_num ${debug_num} \ 65 | --max_depth ${max_depth} \ 66 | --test ${test} \ 67 | --online_concurrency ${online_concurrency} \ 68 | --temperature ${temperature} \ 69 | --backend ${backend} \ 70 | --use_planning_cache ${use_planning_cache} \ 71 | --wo_llm ${wo_llm} \ 72 | --llm_query_few_shot ${llm_query_few_shot} \ 73 | --port ${port} > ./logs/${dataname}_ppo_215_105_ckpt150.log 2>&1 74 | done 75 | done -------------------------------------------------------------------------------- /instruct_sampling_scripts/offline_base_instruct.sh: -------------------------------------------------------------------------------- 1 | 2 | dataname=$1 3 | output_dir=$2 4 | 5 | export VLLM_USE_MODELSCOPE="False" 6 | python=/opt/conda/envs/c3po/bin/python 7 | 8 | model_type=proxy 9 | 10 | # parameters 11 | n_decision_sample=3 12 | n_generate_sample=2 13 | n_plan_sample=$3 14 | checkpoint_dir=$4 15 | 16 | retrieve_server_url=http://10.32.25.199:35004/search 17 | musique_server_url=http://10.32.25.199:35002/search 18 | llm_server_type=offline 19 | 20 | filter=False 21 | filter_path=path/sampling_filter.json 22 | 23 | force_decision=${5:-False} 24 | force_action=${6:-Planning} 25 | filter_key=${7:-existing} 26 | focus_qid=${8:-""} 27 | 28 | max_depth=13 29 | 30 | debug_num=-1 31 | seed=0 32 | test=False 33 | backend=sglang 34 | 35 | ${python} ../C-3PO/main.py \ 36 | --model_type ${model_type} \ 37 | --dataname ${dataname} \ 38 | --output_dir ${output_dir} \ 39 | --n_decision_sample ${n_decision_sample} \ 40 | --n_plan_sample ${n_plan_sample} \ 41 | --n_generate_sample ${n_generate_sample} \ 42 | --checkpoint_dir ${checkpoint_dir} \ 43 | --retrieve_server_url ${retrieve_server_url} \ 44 | --musique_server_url ${musique_server_url} \ 45 | --llm_server_type ${llm_server_type} \ 46 | --debug_num ${debug_num} \ 47 | --filter ${filter} \ 48 | --filter_path ${filter_path} \ 49 | --force_decision ${force_decision} \ 50 | --force_action ${force_action} \ 51 | --filter_key ${filter_key} \ 52 | --max_depth ${max_depth} \ 53 | --focus_qid ${focus_qid} \ 54 | --seed ${seed} \ 55 | --test ${test} \ 56 | --backend ${backend} -------------------------------------------------------------------------------- /instruct_sampling_scripts/run_72b.sh: -------------------------------------------------------------------------------- 1 | timestamp=$( date +"%Y%m%d_%H%M%S") 2 | 3 | echo $timestamp 4 | 5 | mkdir -p ./logs 6 | 7 | n_plan_sample=2 8 | checkpoint_dir=./model_cache/Qwen2-72B-Instruct 9 | 10 | force_decision=False 11 | force_action=Planning 12 | filter_key='None' 13 | focus_qid='None' 14 | 15 | output_dir=./output_dir/run/batch_tree_search/${force_action} 16 | 17 | for dataname in 2WikiMultiHopQA Musique hotpotqa NaturalQuestions PopQA TriviaQA 18 | do 19 | echo "Running ${dataname}" 20 | # get the basename of checkpoint_dir 21 | basename=$(basename ${checkpoint_dir}) 22 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash offline_base_instruct.sh ${dataname} ${output_dir} ${n_plan_sample} ${checkpoint_dir} ${force_decision} ${force_action} ${filter_key} ${focus_qid} > ./logs/${force_action}_${dataname}_${basename}_${timestamp}.log 2>&1 23 | pkill -f /opt/conda/envs/c3po/bin/python 24 | sleep 300 25 | done 26 | 27 | wait -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohappyeyeballs==2.4.3 2 | aiohttp==3.10.10 3 | aiosignal==1.3.1 4 | annotated-types==0.7.0 5 | anthropic==0.36.1 6 | anyio==4.6.2.post1 7 | asttokens==2.4.1 8 | attrs==24.2.0 9 | audioread==3.0.1 10 | certifi==2024.8.30 11 | cffi==1.17.1 12 | charset-normalizer==3.4.0 13 | click==8.1.7 14 | cloudpickle==3.1.0 15 | compressed-tensors==0.6.0 16 | contourpy==1.3.1 17 | cycler==0.12.1 18 | datasets==3.0.1 19 | decorator==5.1.1 20 | decord==0.6.0 21 | dill==0.3.8 22 | diskcache==5.6.3 23 | distro==1.9.0 24 | einops==0.8.0 25 | executing==2.1.0 26 | fastapi==0.115.2 27 | filelock==3.16.1 28 | flashinfer==0.1.6+cu121torch2.4 29 | fonttools==4.55.6 30 | frozenlist==1.4.1 31 | fsspec==2024.6.1 32 | gguf==0.10.0 33 | h11==0.14.0 34 | hf_transfer==0.1.8 35 | httpcore==1.0.6 36 | httptools==0.6.2 37 | httpx==0.27.2 38 | huggingface-hub==0.25.2 39 | idna==3.10 40 | importlib_metadata==8.5.0 41 | interegular==0.3.3 42 | ipython==8.29.0 43 | jedi==0.19.1 44 | Jinja2==3.1.4 45 | jiter==0.6.1 46 | joblib==1.4.2 47 | jsonschema==4.23.0 48 | jsonschema-specifications==2024.10.1 49 | kiwisolver==1.4.8 50 | lark==1.2.2 51 | lazy_loader==0.4 52 | librosa==0.10.2.post1 53 | litellm==1.49.4 54 | llvmlite==0.43.0 55 | lm-format-enforcer==0.10.6 56 | MarkupSafe==3.0.1 57 | matplotlib==3.10.0 58 | matplotlib-inline==0.1.7 59 | mistral_common==1.4.4 60 | modelscope==1.19.0 61 | mpmath==1.3.0 62 | msgpack==1.1.0 63 | msgspec==0.18.6 64 | multidict==6.1.0 65 | multiprocess==0.70.16 66 | nest-asyncio==1.6.0 67 | networkx==3.4.1 68 | numba==0.60.0 69 |
numpy==1.26.4 70 | nvidia-cublas-cu12==12.1.3.1 71 | nvidia-cuda-cupti-cu12==12.1.105 72 | nvidia-cuda-nvrtc-cu12==12.1.105 73 | nvidia-cuda-runtime-cu12==12.1.105 74 | nvidia-cudnn-cu12==9.1.0.70 75 | nvidia-cufft-cu12==11.0.2.54 76 | nvidia-curand-cu12==10.3.2.106 77 | nvidia-cusolver-cu12==11.4.5.107 78 | nvidia-cusparse-cu12==12.1.0.106 79 | nvidia-ml-py==12.560.30 80 | nvidia-nccl-cu12==2.20.5 81 | nvidia-nvjitlink-cu12==12.6.77 82 | nvidia-nvtx-cu12==12.1.105 83 | openai==1.51.2 84 | opencv-python-headless==4.10.0.84 85 | orjson==3.10.11 86 | outlines==0.0.46 87 | packaging==24.1 88 | pandas==2.2.3 89 | parso==0.8.4 90 | partial-json-parser==0.2.1.1.post4 91 | Pebble==5.0.7 92 | pexpect==4.9.0 93 | pillow==10.4.0 94 | pip==24.2 95 | platformdirs==4.3.6 96 | pooch==1.8.2 97 | prometheus_client==0.21.0 98 | prometheus-fastapi-instrumentator==7.0.0 99 | prompt_toolkit==3.0.48 100 | propcache==0.2.0 101 | protobuf==5.28.2 102 | psutil==6.0.0 103 | ptyprocess==0.7.0 104 | pure_eval==0.2.3 105 | py-cpuinfo==9.0.0 106 | pyairports==2.1.1 107 | pyarrow==17.0.0 108 | pycountry==24.6.1 109 | pycparser==2.22 110 | pydantic==2.9.2 111 | pydantic_core==2.23.4 112 | Pygments==2.18.0 113 | pyparsing==3.2.1 114 | python-dateutil==2.9.0.post0 115 | python-dotenv==1.0.1 116 | python-multipart==0.0.12 117 | pytz==2024.2 118 | PyYAML==6.0.2 119 | pyzmq==26.2.0 120 | ray==2.37.0 121 | referencing==0.35.1 122 | regex==2024.9.11 123 | requests==2.32.3 124 | rpds-py==0.20.0 125 | safetensors==0.4.5 126 | scikit-learn==1.5.2 127 | scipy==1.14.1 128 | seaborn==0.13.2 129 | sentence-transformers==3.3.1 130 | sentencepiece==0.2.0 131 | setuptools==75.1.0 132 | sglang==0.3.6 133 | six==1.16.0 134 | sniffio==1.3.1 135 | soundfile==0.12.1 136 | soxr==0.5.0.post1 137 | stack-data==0.6.3 138 | starlette==0.40.0 139 | sympy==1.13.3 140 | threadpoolctl==3.5.0 141 | tiktoken==0.7.0 142 | tokenizers==0.20.1 143 | torch==2.4.0 144 | torchao==0.5.0 145 | torchvision==0.19.0 146 | tqdm==4.66.5 147 | traitlets==5.14.3 148 | transformers==4.45.2 149 | triton==3.0.0 150 | typing_extensions==4.12.2 151 | tzdata==2024.2 152 | urllib3==2.2.3 153 | uvicorn==0.32.0 154 | uvloop==0.21.0 155 | vllm==0.6.3.post1 156 | vllm-flash-attn==2.6.1 157 | watchfiles==0.24.0 158 | wcwidth==0.2.13 159 | websockets==13.1 160 | wheel==0.44.0 161 | xformers==0.0.27.post2 162 | xxhash==3.5.0 163 | yarl==1.15.3 164 | zipp==3.20.2 165 | zmq==0.0.0 166 | -------------------------------------------------------------------------------- /retrieval_requirements.txt: -------------------------------------------------------------------------------- 1 | blinker==1.8.2 2 | certifi==2024.7.4 3 | charset-normalizer==3.3.2 4 | click==8.1.7 5 | faiss==1.8.0 6 | filelock==3.15.4 7 | Flask==3.0.3 8 | fsspec==2024.6.1 9 | gunicorn==23.0.0 10 | huggingface-hub==0.24.6 11 | idna==3.7 12 | itsdangerous==2.2.0 13 | Jinja2==3.1.4 14 | MarkupSafe==2.1.5 15 | mkl-fft==1.3.8 16 | mkl-random==1.2.4 17 | mkl-service==2.4.0 18 | mpmath==1.3.0 19 | networkx==3.3 20 | numpy==1.26.4 21 | nvidia-cublas-cu12==12.1.3.1 22 | nvidia-cuda-cupti-cu12==12.1.105 23 | nvidia-cuda-nvrtc-cu12==12.1.105 24 | nvidia-cuda-runtime-cu12==12.1.105 25 | nvidia-cudnn-cu12==8.9.2.26 26 | nvidia-cufft-cu12==11.0.2.54 27 | nvidia-curand-cu12==10.3.2.106 28 | nvidia-cusolver-cu12==11.4.5.107 29 | nvidia-cusparse-cu12==12.1.0.106 30 | nvidia-nccl-cu12==2.20.5 31 | nvidia-nvjitlink-cu12==12.6.20 32 | nvidia-nvtx-cu12==12.1.105 33 | packaging==24.1 34 | pillow==10.4.0 35 | pip==24.2 36 | PyYAML==6.0.2 37 | 
regex==2024.7.24 38 | requests==2.32.3 39 | safetensors==0.4.4 40 | setuptools==72.1.0 41 | sympy==1.13.2 42 | tokenizers==0.19.1 43 | torch==2.3.1 44 | torchaudio==2.3.1 45 | torchvision==0.18.1 46 | tqdm==4.66.5 47 | transformers==4.44.2 48 | triton==2.3.1 49 | typing_extensions==4.12.2 50 | urllib3==2.2.2 51 | Werkzeug==3.0.4 52 | wheel==0.43.0 53 | -------------------------------------------------------------------------------- /train/sft_scripts/run_base_packing.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | GPU_NUM=$(nvidia-smi -L | wc -l) 4 | 5 | seed=42 6 | 7 | export TOKENIZERS_PARALLELISM=false 8 | export VLLM_USE_MODELSCOPE="False" 9 | export WANDB_DISABLED="True" 10 | 11 | echo "Prepare the conda environment" 12 | 13 | timestamp=$( date +"%Y%m%d_%H%M%S") 14 | echo $timestamp 15 | 16 | root_path=xxx 17 | model_name_or_path=${root_path}/model_cache/Qwen2-0.5B 18 | model_base_name=$(basename ${model_name_or_path}) 19 | dataset=policy_11.25_6data_v2 20 | dataset_dir=${root_path}/proxy_train_data/sft 21 | 22 | finetuning_type=full 23 | learning_rate=4e-5 24 | 25 | if [ $GPU_NUM -eq 8 ]; then 26 | gradient_accumulation_steps=6 27 | per_device_train_batch_size=8 28 | elif [ $GPU_NUM -eq 4 ]; then 29 | gradient_accumulation_steps=12 30 | per_device_train_batch_size=8 31 | else 32 | echo "GPU_NUM must be 4 or 8" 33 | exit 1 34 | fi 35 | 36 | 37 | output_dir=${root_path}/workspace/output_dir/run/proxy_sft/${model_base_name}/${dataset}/${learning_rate}_${GPU_NUM}GPU_${timestamp} 38 | deepspeed_config_file=${root_path}/workspace/LLaMA-Factory/examples/deepspeed/ds_z2_config.json 39 | deepspeed_env=/opt/conda/envs/c3po/bin/deepspeed 40 | 41 | ${deepspeed_env} --num_gpus ${GPU_NUM} ../src/train.py \ 42 | --deepspeed ${deepspeed_config_file} \ 43 | --stage sft \ 44 | --do_train \ 45 | --flash_attn fa2 \ 46 | --packing True \ 47 | --neat_packing True \ 48 | --model_name_or_path ${model_name_or_path} \ 49 | --dataset_dir ${dataset_dir}\ 50 | --dataset ${dataset} \ 51 | --template qwen \ 52 | --finetuning_type ${finetuning_type} \ 53 | --save_safetensors \ 54 | --output_dir ${output_dir} \ 55 | --overwrite_cache \ 56 | --max_new_tokens 2048 \ 57 | --cutoff_len 4096 \ 58 | --per_device_train_batch_size ${per_device_train_batch_size} \ 59 | --gradient_accumulation_steps ${gradient_accumulation_steps} \ 60 | --warmup_ratio 0.03 \ 61 | --weight_decay 0. \ 62 | --lr_scheduler_type cosine \ 63 | --logging_steps 10 \ 64 | --save_steps 20 \ 65 | --learning_rate ${learning_rate} \ 66 | --num_train_epochs 8.0 \ 67 | --dataloader_num_workers 16 \ 68 | --preprocessing_num_workers 128 \ 69 | --ddp_timeout 180000000 \ 70 | --seed $seed \ 71 | --plot_loss \ 72 | --save_only_model \ 73 | --bf16 74 | --------------------------------------------------------------------------------
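For reference, the vendored faiss Indexer in deploy_servers/retrieve_server/retrieve_code/src/index.py can also be driven directly, without going through the Flask server. Below is a minimal sketch, assuming faiss and numpy are installed and the script runs from retrieve_code/ so that src/ is importable; the ids and random embeddings are placeholders, whereas the deployed wiki18 server indexes precomputed contriever passage embeddings of dimension 768 (projection_size in wiki18_config.yaml).

import numpy as np

from src.index import Indexer

dim = 768  # embedding dimension; matches projection_size in wiki18_config.yaml
indexer = Indexer(dim, index_type='FlatIP')  # exact inner-product search, no training step needed

# placeholder external ids and embeddings; real deployments load passage embeddings from disk
ids = list(range(1000))
embeddings = np.random.rand(1000, dim).astype(np.float32)
indexer.index_data(ids, embeddings)

# search_knn returns one (db_ids, scores) tuple per query vector
queries = np.random.rand(2, dim).astype(np.float32)
for db_ids, scores in indexer.search_knn(queries, top_docs=5):
    print(db_ids, scores)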