├── .DS_Store
├── method.gif
├── Self_Plan
├── .DS_Store
├── Train
│   ├── .DS_Store
│   ├── deepspeed_config_s3.json
│   ├── train.py
│   └── train_lora.py
├── Tool_Selection
│   ├── __pycache__
│   │   └── selector.cpython-39.pyc
│   ├── HotpotQA_Tools.json
│   ├── ScienceQA_Tools.json
│   ├── tool_selected.py
│   ├── llms.py
│   ├── pre_prompt.py
│   └── selector.py
├── Group_Planning
│   ├── benchmark_run
│   │   ├── data
│   │   │   └── hotpotqa
│   │   │   │   ├── easy.joblib
│   │   │   │   ├── hard.joblib
│   │   │   │   └── medium.joblib
│   │   ├── __pycache__
│   │   │   ├── llms.cpython-39.pyc
│   │   │   ├── utils.cpython-39.pyc
│   │   │   ├── config.cpython-39.pyc
│   │   │   ├── fewshots.cpython-39.pyc
│   │   │   ├── utils.cpython-311.pyc
│   │   │   ├── agent_arch.cpython-39.pyc
│   │   │   ├── pre_prompt.cpython-39.pyc
│   │   │   ├── agent_arch.cpython-311.pyc
│   │   │   └── Meta_agent_arch.cpython-39.pyc
│   │   ├── config.py
│   │   ├── evaluate.py
│   │   ├── hotpotqa_env.py
│   │   ├── llms.py
│   │   ├── pre_prompt.py
│   │   ├── wikienv.py
│   │   ├── utils.py
│   │   └── wrappers.py
│   └── run_eval.py
└── Traj_Syn
│   ├── benchmark_run
│   ├── __pycache__
│   │   ├── llms.cpython-39.pyc
│   │   ├── config.cpython-39.pyc
│   │   ├── utils.cpython-311.pyc
│   │   ├── utils.cpython-39.pyc
│   │   ├── fewshots.cpython-39.pyc
│   │   ├── agent_arch.cpython-311.pyc
│   │   ├── agent_arch.cpython-39.pyc
│   │   └── pre_prompt.cpython-39.pyc
│   ├── config.py
│   ├── evaluate.py
│   ├── hotpotqa_env.py
│   ├── llms.py
│   ├── wikienv.py
│   ├── pre_prompt.py
│   ├── utils.py
│   └── wrappers.py
│   └── run_task.py
├── requirements.txt
├── Scripts
├── filter.sh
├── tool_select.sh
├── run_task.sh
├── model_bash
│   ├── single_model.sh
│   └── multi_model.sh
├── self_instruct_build.sh
├── run_eval.sh
├── fastchat_lora.sh
├── analysis.py
└── filter_data.py
├── Self_Instruct
├── Meta_sample
│   ├── Meta_Hotpotqa.json
│   └── Meta_Scienceqa.json
├── pre_prompt.py
├── llms.py
└── data_generation.py
├── Prompts
├── hotpotqa_cot.txt
├── scienceqa_cot.txt
├── hotpotqa_react.txt
└── scienceqa_react.txt
├── README.md
└── LICENSE
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/.DS_Store
--------------------------------------------------------------------------------
/method.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/method.gif
--------------------------------------------------------------------------------
/Self_Plan/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/Self_Plan/.DS_Store
--------------------------------------------------------------------------------
/Self_Plan/Train/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/Self_Plan/Train/.DS_Store
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | langchain==0.0.299
2 | fschat==0.2.35
3 | transformers==4.36.2
4 | peft==0.5.0
5 | accelerate==0.23.0
6 | sentencepiece
7 | openai==0.28.0
--------------------------------------------------------------------------------
/Self_Plan/Tool_Selection/__pycache__/selector.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/Self_Plan/Tool_Selection/__pycache__/selector.cpython-39.pyc
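The pinned requirements.txt above is the repository's only dependency manifest. A minimal environment-setup sketch follows; the Python 3.9 interpreter and the environment name autoact are assumptions (loosely suggested by the committed cpython-39 caches), not something the repo specifies.

# Hypothetical setup; only the package list and version pins come from requirements.txt.
conda create -n autoact python=3.9 -y
conda activate autoact
pip install -r requirements.txt    # langchain, fschat, transformers, peft, accelerate, sentencepiece, openai at the pinned versions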
-------------------------------------------------------------------------------- /Self_Plan/Group_Planning/benchmark_run/data/hotpotqa/easy.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/Self_Plan/Group_Planning/benchmark_run/data/hotpotqa/easy.joblib -------------------------------------------------------------------------------- /Self_Plan/Group_Planning/benchmark_run/data/hotpotqa/hard.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/Self_Plan/Group_Planning/benchmark_run/data/hotpotqa/hard.joblib -------------------------------------------------------------------------------- /Self_Plan/Traj_Syn/benchmark_run/__pycache__/llms.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/Self_Plan/Traj_Syn/benchmark_run/__pycache__/llms.cpython-39.pyc -------------------------------------------------------------------------------- /Self_Plan/Group_Planning/benchmark_run/data/hotpotqa/medium.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/Self_Plan/Group_Planning/benchmark_run/data/hotpotqa/medium.joblib -------------------------------------------------------------------------------- /Self_Plan/Traj_Syn/benchmark_run/__pycache__/config.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/Self_Plan/Traj_Syn/benchmark_run/__pycache__/config.cpython-39.pyc -------------------------------------------------------------------------------- /Self_Plan/Traj_Syn/benchmark_run/__pycache__/utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/Self_Plan/Traj_Syn/benchmark_run/__pycache__/utils.cpython-311.pyc -------------------------------------------------------------------------------- /Self_Plan/Traj_Syn/benchmark_run/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/Self_Plan/Traj_Syn/benchmark_run/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /Self_Plan/Traj_Syn/benchmark_run/__pycache__/fewshots.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/Self_Plan/Traj_Syn/benchmark_run/__pycache__/fewshots.cpython-39.pyc -------------------------------------------------------------------------------- /Self_Plan/Group_Planning/benchmark_run/__pycache__/llms.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/Self_Plan/Group_Planning/benchmark_run/__pycache__/llms.cpython-39.pyc -------------------------------------------------------------------------------- /Self_Plan/Group_Planning/benchmark_run/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/Self_Plan/Group_Planning/benchmark_run/__pycache__/utils.cpython-39.pyc 
-------------------------------------------------------------------------------- /Self_Plan/Traj_Syn/benchmark_run/__pycache__/agent_arch.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/Self_Plan/Traj_Syn/benchmark_run/__pycache__/agent_arch.cpython-311.pyc -------------------------------------------------------------------------------- /Self_Plan/Traj_Syn/benchmark_run/__pycache__/agent_arch.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/Self_Plan/Traj_Syn/benchmark_run/__pycache__/agent_arch.cpython-39.pyc -------------------------------------------------------------------------------- /Self_Plan/Traj_Syn/benchmark_run/__pycache__/pre_prompt.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/Self_Plan/Traj_Syn/benchmark_run/__pycache__/pre_prompt.cpython-39.pyc -------------------------------------------------------------------------------- /Self_Plan/Group_Planning/benchmark_run/__pycache__/config.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/Self_Plan/Group_Planning/benchmark_run/__pycache__/config.cpython-39.pyc -------------------------------------------------------------------------------- /Self_Plan/Group_Planning/benchmark_run/__pycache__/fewshots.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/Self_Plan/Group_Planning/benchmark_run/__pycache__/fewshots.cpython-39.pyc -------------------------------------------------------------------------------- /Self_Plan/Group_Planning/benchmark_run/__pycache__/utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/Self_Plan/Group_Planning/benchmark_run/__pycache__/utils.cpython-311.pyc -------------------------------------------------------------------------------- /Self_Plan/Group_Planning/benchmark_run/__pycache__/agent_arch.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/Self_Plan/Group_Planning/benchmark_run/__pycache__/agent_arch.cpython-39.pyc -------------------------------------------------------------------------------- /Self_Plan/Group_Planning/benchmark_run/__pycache__/pre_prompt.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/Self_Plan/Group_Planning/benchmark_run/__pycache__/pre_prompt.cpython-39.pyc -------------------------------------------------------------------------------- /Self_Plan/Group_Planning/benchmark_run/__pycache__/agent_arch.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/Self_Plan/Group_Planning/benchmark_run/__pycache__/agent_arch.cpython-311.pyc -------------------------------------------------------------------------------- /Self_Plan/Group_Planning/benchmark_run/__pycache__/Meta_agent_arch.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zjunlp/AutoAct/HEAD/Self_Plan/Group_Planning/benchmark_run/__pycache__/Meta_agent_arch.cpython-39.pyc -------------------------------------------------------------------------------- /Scripts/filter.sh: -------------------------------------------------------------------------------- 1 | python Scripts/filter_data.py \ 2 | --source_path Self_Planning/Traj_Syn/output/hotpotqa_train_data.jsonl \ 3 | --save_path Self_Planning/Traj_Syn/output \ 4 | --task_name HotpotQA \ 5 | --filter_num 200 -------------------------------------------------------------------------------- /Scripts/tool_select.sh: -------------------------------------------------------------------------------- 1 | python ../Self_Plan/Tool_Selection/tool_selected.py\ 2 | --model_name llama-2-13b-chat \ 3 | --task_name ScienceQA \ 4 | --top_k 40 \ 5 | --top_p 0.75 \ 6 | --max_tokens 1024 \ 7 | --tool_save_path ../Self_Plan/Tool_Selection/Tools.json \ 8 | -------------------------------------------------------------------------------- /Scripts/run_task.sh: -------------------------------------------------------------------------------- 1 | python Self_Plan/Traj_Syn/run_task.py \ 2 | --agent_name ZeroshotThink_HotPotQA_run_Agent \ 3 | --llm_name llama-2-13b-chat \ 4 | --max_context_len 4096 \ 5 | --task Hotpotqa \ 6 | --task_path Self_Instruct/hotpotqa_metaqa.json \ 7 | --save_path Self_Plan/Traj_Syn/output/hotpotqa_train_data.jsonl -------------------------------------------------------------------------------- /Scripts/model_bash/single_model.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m fastchat.serve.model_worker \ 2 | --port 31021 --worker http://localhost:31021 \ 3 | --host localhost \ 4 | --model-names your-model-name \ 5 | --model-path /model/path \ 6 | --max-gpu-memory 31Gib \ 7 | --dtype float16 \ 8 | --num-gpus 8 -------------------------------------------------------------------------------- /Scripts/self_instruct_build.sh: -------------------------------------------------------------------------------- 1 | python Self_Instruct/data_generation.py \ 2 | --source_data Self_Instruct/Meta_sample/Meta_Hotpotqa.json \ 3 | --target_data Self_Instruct/hotpotqa_metaqa.json \ 4 | --dataset_name hotpotqa \ 5 | --generate_all_num 800 \ 6 | --generate_per_round_num 10 \ 7 | --model_name llama-2-13b-chat \ -------------------------------------------------------------------------------- /Self_Instruct/Meta_sample/Meta_Hotpotqa.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "Question": "from 1969 to 1979, Arno Schmidt was the executive chef of a hotel located in which neighborhood in New York ?", 4 | "Answer": "Manhattan" 5 | }, 6 | { 7 | "Question": "Are both Shangri-La City and Ma'anshan cities in China?", 8 | "Answer": "yes" 9 | } 10 | ] -------------------------------------------------------------------------------- /Scripts/run_eval.sh: -------------------------------------------------------------------------------- 1 | python Self_Plan/Group_Planning/run_eval.py \ 2 | --agent_name ZeroshotThink_HotPotQA_run_Agent \ 3 | --plan_agent plan \ 4 | --tool_agent tool \ 5 | --reflect_agent reflect \ 6 | --max_context_len 4096 \ 7 | --task HotpotQA \ 8 | --task_path Self_Plan/Group_Planning/benchmark_run/data/hotpotqa \ 9 | --save_path Self_Plan/Group_Planning/output/13b -------------------------------------------------------------------------------- 
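The Scripts/ directory collects one driver script per stage of the pipeline (model serving, self-instruct, tool selection, trajectory synthesis, filtering, LoRA training, group-planning evaluation). The sketch below shows the order they plausibly run in from the repo root; the ordering is an assumption inferred from each script's input and output paths, and a FastChat controller plus openai_api_server are additionally assumed to be running, since the Python code defaults to an OpenAI-compatible endpoint at http://localhost:8000/v1.

# Assumed end-to-end order; each command invokes a script that exists under Scripts/.
bash Scripts/model_bash/single_model.sh    # serve the base chat model as a FastChat worker
bash Scripts/self_instruct_build.sh        # Self-Instruct: expand Meta_Hotpotqa.json into hotpotqa_metaqa.json
bash Scripts/tool_select.sh                # select a tool library for the target task
bash Scripts/run_task.sh                   # synthesize trajectories into Self_Plan/Traj_Syn/output/hotpotqa_train_data.jsonl
bash Scripts/filter.sh                     # filter the synthesized trajectories
bash Scripts/fastchat_lora.sh              # LoRA-train the plan / tool / reflect agents
bash Scripts/model_bash/multi_model.sh     # serve the differentiated agents side by side
bash Scripts/run_eval.sh                   # group-planning evaluation on HotpotQA
python Scripts/analysis.py --file_path Self_Plan/Group_Planning/output/13b/easy.jsonl   # assumed log path; summarizes reward and accuracy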
/Self_Plan/Traj_Syn/benchmark_run/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2023, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: Apache License 2.0 5 | For full license text, see the LICENSE file in the repo root or https://www.apache.org/licenses/LICENSE-2.0 6 | """ 7 | 8 | available_agent_names = ["ZeroshotThink_HotPotQA_run_Agent","ZeroshotThink_ScienceQA_run_Agent"] 9 | OPENAI_API_KEY = "" -------------------------------------------------------------------------------- /Self_Plan/Group_Planning/benchmark_run/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2023, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: Apache License 2.0 5 | For full license text, see the LICENSE file in the repo root or https://www.apache.org/licenses/LICENSE-2.0 6 | """ 7 | 8 | available_agent_names = ["ZeroshotThink_HotPotQA_run_Agent","ZeroshotThink_ScienceQA_run_Agent"] 9 | OPENAI_API_KEY = "" -------------------------------------------------------------------------------- /Scripts/model_bash/multi_model.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PEFT_SHARE_BASE_WEIGHTS=true python3 -m fastchat.serve.multi_model_worker \ 2 | --port 31022 --worker http://localhost:31022 \ 3 | --host localhost \ 4 | --model-path /model/path \ 5 | --model-names "plan" \ 6 | --model-path /model/path\ 7 | --model-names "action" \ 8 | --model-path /model/path \ 9 | --model-names "reflect" \ 10 | --max-gpu-memory 31Gib \ 11 | --dtype float16 \ 12 | --num-gpus 8 -------------------------------------------------------------------------------- /Prompts/hotpotqa_cot.txt: -------------------------------------------------------------------------------- 1 | Solve a question answering task with Thought, Action step. You should think step-by-step. 2 | Here are some examples. 3 | 4 | Question: Musician and satirist Allie Goertz wrote a song about the "The Simpsons" character Milhouse, who Matt Groening named after who? 5 | Thought: Milhouse was named after U.S. president Richard Nixon, so the answer is Richard Nixon. 6 | Action: finish[Richard Nixon] 7 | 8 | Question: What is the elevation range for the area that the eastern sector of the Colorado orogeny extends into? 9 | Thought: The eastern sector of Colorado orogeny extends into the High Plains. High Plains rise in elevation from around 1,800 to 7,000 ft, so the answer is 1,800 to 7,000 ft. 
10 | Action: finish[1,800 to 7,000 ft] -------------------------------------------------------------------------------- /Self_Plan/Train/deepspeed_config_s3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "zero_optimization": { 11 | "stage": 3, 12 | "offload_optimizer": { 13 | "device": "cpu", 14 | "pin_memory": true 15 | }, 16 | "offload_param": { 17 | "device": "cpu", 18 | "pin_memory": true 19 | }, 20 | "overlap_comm": true, 21 | "contiguous_gradients": true, 22 | "stage3_max_live_parameters" : 1e9, 23 | "stage3_max_reuse_distance" : 1e9, 24 | "stage3_prefetch_bucket_size" : 5e8, 25 | "stage3_param_persistence_threshold" : 1e6, 26 | "sub_group_size" : 1e12, 27 | "stage3_gather_16bit_weights_on_model_save": true 28 | }, 29 | "train_batch_size": "auto", 30 | "train_micro_batch_size_per_gpu": "auto", 31 | "gradient_accumulation_steps": "auto" 32 | } -------------------------------------------------------------------------------- /Scripts/fastchat_lora.sh: -------------------------------------------------------------------------------- 1 | for agent in plan tool reflect 2 | do 3 | echo "####################" 4 | echo $agent 5 | echo "####################" 6 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 deepspeed Self_Plan/Train/train_lora.py \ 7 | --model_name_or_path llama-2-13b-chat \ 8 | --lora_r 8 \ 9 | --lora_alpha 16 \ 10 | --lora_dropout 0.05 \ 11 | --data_path Self_Plan/Traj_Syn/output/data_$agent.json \ 12 | --output_dir Self_Plan/Train/lora/HotpotQA/13b-$agent-5-epoch \ 13 | --num_train_epochs 5 \ 14 | --per_device_train_batch_size 2 \ 15 | --per_device_eval_batch_size 1 \ 16 | --gradient_accumulation_steps 1 \ 17 | --evaluation_strategy "no" \ 18 | --save_strategy "steps" \ 19 | --save_steps 10000 \ 20 | --save_total_limit 1 \ 21 | --learning_rate 1e-4 \ 22 | --weight_decay 0. \ 23 | --warmup_ratio 0.03 \ 24 | --lr_scheduler_type "cosine" \ 25 | --logging_steps 1 \ 26 | --fp16 True \ 27 | --model_max_length 4096 \ 28 | --gradient_checkpointing True \ 29 | --q_lora False \ 30 | --deepspeed Self_Plan/Train/deepspeed_config_s3.json \ 31 | --resume_from_checkpoint False 32 | done -------------------------------------------------------------------------------- /Self_Plan/Tool_Selection/HotpotQA_Tools.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "name": "BingSearch", 4 | "definition": "BingSearch engine can search for rich external knowledge on the Internet based on keywords, which can compensate for knowledge fallacy and knowledge outdated.", 5 | "usage": "BingSearch[query], which searches the exact detailed query on the Internet and returns the relevant information to the query. Be specific and precise with your query to increase the chances of getting relevant results. For example, Bingsearch[popular dog breeds in the United States]" 6 | }, 7 | { 8 | "name": "Retrieve", 9 | "definition": "Retrieve additional background knowledge crucial for tackling complex problems. It is especially beneficial for specialized domains like science and mathematics, providing context for the task", 10 | "usage": "Retrieve[entity], which retrieves the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to retrieve. 
For example, Retrieve[Milhouse]" 11 | }, 12 | { 13 | "name": "Lookup", 14 | "definition": "A Lookup Tool returns the next sentence containing the target string in the page from the search tool (like BingSearch or Retrieve),so it is recommended to use with Bingsearch and Retrieve, simulating Ctrl+F functionality on the browser to find target answer.", 15 | "usage": "Lookup[keyword], which returns the next sentence containing the keyword in the last passage successfully found by Retrieve or BingSearch. For example, Lookup[river]." 16 | } 17 | ] -------------------------------------------------------------------------------- /Scripts/analysis.py: -------------------------------------------------------------------------------- 1 | import jsonlines 2 | import argparse 3 | 4 | def analyze_results(file_path): 5 | all_correct = 0 6 | all_wrong = 0 7 | right_reflect_wrong = 0 8 | wrong_reflect_wrong = 0 9 | reward = 0 10 | all_row = 0 11 | with open(file_path, "r") as f: 12 | for item in jsonlines.Reader(f): 13 | all_row += 1 14 | reward += item["reward"] 15 | if item["correct"]: 16 | all_correct += 1 17 | if "Reflect[wrong]" in item["prompt"]: 18 | right_reflect_wrong += 1 19 | else: 20 | all_wrong += 1 21 | if "Reflect[wrong]" in item["prompt"]: 22 | wrong_reflect_wrong += 1 23 | print(f'1. Accuracy: {all_correct/all_row}') 24 | print(f'2. Reflect wrong: {(right_reflect_wrong+wrong_reflect_wrong)/all_row} Reflect right: {1-(right_reflect_wrong+wrong_reflect_wrong)/all_row}') 25 | print(f'3. Correct answers - Reflect right: {1-(right_reflect_wrong/all_correct)} Wrong: {right_reflect_wrong/all_correct}') 26 | print(f'4. Wrong answers - Reflect right: {1-(wrong_reflect_wrong/all_wrong)} Wrong: {wrong_reflect_wrong/all_wrong}') 27 | print(f'5. Reward: {reward}') 28 | 29 | 30 | def main(): 31 | parser = argparse.ArgumentParser(description="Analyze results from JSONL file.") 32 | parser.add_argument("--file_path", help="Path to the directory containing JSONL files.") 33 | args = parser.parse_args() 34 | analyze_results(args.file_path) 35 | 36 | 37 | if __name__ == "__main__": 38 | main() 39 | -------------------------------------------------------------------------------- /Self_Plan/Tool_Selection/ScienceQA_Tools.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "name": "BingSearch", 4 | "definition": "BingSearch engine can search for rich external knowledge on the Internet based on keywords, which can compensate for knowledge fallacy and knowledge outdated.", 5 | "usage": "BingSearch[query], which searches the exact detailed query on the Internet and returns the relevant information to the query. Be specific and precise with your query to increase the chances of getting relevant results. For example, Bingsearch[popular dog breeds in the United States]" 6 | }, 7 | { 8 | "name": "Retrieve", 9 | "definition": "Retrieve additional background knowledge crucial for tackling complex problems. It is especially beneficial for specialized domains like science and mathematics, providing context for the task", 10 | "usage": "Retrieve[entity], which retrieves the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to retrieve. For example, Retrieve[Milhouse]" 11 | }, 12 | { 13 | "name": "Image2Text", 14 | "definition": "Image2Text is used to detect words in images convert them into text by OCR and generate captions for images. 
It is particularly valuable when understanding an image semantically, like identifying objects and interactions in a scene.", 15 | "usage": "Image2Text[image], which generates captions for the image and detects words in the image. You are recommended to use it first to get more information about the image to the question. If the question contains an image, it will return the caption and OCR text, else, it will return None. For example, Image2Text[image]." 16 | } 17 | ] -------------------------------------------------------------------------------- /Prompts/scienceqa_cot.txt: -------------------------------------------------------------------------------- 1 | Solve a question answering task with Thought, Action step. You should think step-by-step. 2 | Here are some examples. 3 | 4 | Question: Complete the sentence so that it uses personification. 5 | I felt like my fate was () when I broke my arm right before the biggest game of the season. 6 | Options: (A) scheming against me (B) disastrous 7 | Metadata: {'has_image': False, 'grade': 10, 'subject': 'language science', 'topic': 'writing-strategies', 'category': 'Creative techniques', 'skill': 'Use personification'} 8 | Thought: According to the information provided by Metadata, I need to use personification to complete this sentence. "scheme against" is a personification phrase, so the answer is A. 9 | Action: finish[A] 10 | 11 | 12 | Question: Which statement describes the Sahara Desert ecosystem? 13 | Image: A desert with a bunch of animals on it. 14 | Options: (A) It has warm, wet summers. (B) It has dry, thin soil. (C) It has thick, moist soil 15 | Metadata: {'has_image': True, 'grade': 6, 'subject': 'natural science', 'topic': 'biology', 'category': 'Ecosystems', 'skill': 'Describe ecosystems'} 16 | Thought: A hot desert is a type of ecosystem. Hot deserts have the following features: a small amount of rain, dry, thin soil, many different types of organisms, and It has thick, moist soil. So, the following statement describes the Sahara Desert ecosystem: a small amount of rain, dry, thin soil, many different types of organisms, and It has thick, moist soil. It has dry, thin soil. The following statements do not describe the Sahara Desert: a small amount of rain, dry, thin soil, many different types of organisms, and It has thick, moist soil. It has warm, wet summers. Hence, the answer is B. 17 | Action: finish[B] 18 | 19 | -------------------------------------------------------------------------------- /Self_Instruct/pre_prompt.py: -------------------------------------------------------------------------------- 1 | HOTPOTQA_TASK_NAME = "HotpotQA" 2 | HOTPOTQA_TASK_DESCRIPTION = "This is a question-answering task that includes high-quality multi-hop questions and do not contain images. It tests language modeling abilities for multi-step reasoning and covers a wide range of topics. Some questions are challenging, while others are easier, requiring multiple steps of reasoning to arrive at the final answer." 3 | 4 | SCIENCEQA_TASK_NAME = "ScienceQA" 5 | SCIENCEQA_TASK_DESCRIPTION = " This is a multimodal question-answering task that necessitates a model to utilizetools for transforming image information intotextual data. 
Simultaneously, this task incorporates substantial background knowledge, requiring the language model to acquire external information to enhance its comprehension of the task" 6 | 7 | 8 | DATA_GEN_SYSTEM_PROMPT = """I want you to be a QA pair generator to generate high-quality questions for use in Task 9 | described as follows : 10 | Task Name: {task_name} 11 | Task Description: {task_description} 12 | """ 13 | 14 | HOTPOTQA_DATA_GEN_HUMAN_PROMPT = """{QA_pairs}\nModelled on all examples above,I want you to generate new different {Gen_num} Question-Answer pairs. The format like below: 15 | Question: The Treaty of Versailles, signed in 1919, officially ended which war? 16 | Answer: World War I 17 | """ 18 | 19 | SCIENCEQA_DATA_GEN_HUMAN_PROMPT = """{QA_pairs}\nModelled on all examples above,I want you to generate {Gen_num} new different multimodal multiple-choice science questions.The question format like below: 20 | Question: Which of these states is farthest north? 21 | Options: Options: (A) West Virginia (B) Louisiana (C) Arizona (D) Oklahoma (hint:ensure all choices in one line ,not one choice one line) 22 | Ocr: Oklahoma,West Virginia,Arizona,Louisiana 23 | Caption: An aerial view of a painting of a forest 24 | Answer: A. West Virginia 25 | """ 26 | -------------------------------------------------------------------------------- /Self_Instruct/Meta_sample/Meta_Scienceqa.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "Question": "Which of these states is farthest north?\nOptions: (A) West Virginia (B) Louisiana (C) Arizona (D) Oklahoma", 4 | "caption": "An aerial view of a painting of a forest.", 5 | "orc": "[([[307, 253], [361, 253], [361, 265], [307, 265]], 'Oklahoma', 0.7105310102042172), ([[623, 261], [691, 261], [691, 275], [623, 275]], 'West Virginia', 0.9675798036138213), ([[153, 281], [193, 281], [193, 293], [153, 293]], 'Arizona', 0.9426240560831896), ([[365, 319], [419, 319], [419, 333], [365, 333]], 'Louisiana', 0.9964687228131679)]", 6 | "Answer": "A. West Virginia", 7 | "metadata": { 8 | "has_image": true, 9 | "grade": 2, 10 | "subject": "social science", 11 | "topic": "geography", 12 | "category": "Geography", 13 | "skill": "Read a map: cardinal directions" 14 | } 15 | }, 16 | { 17 | "Question": "Identify the question that Tom and Justin's experiment can best answer.\n\nContext: The passage below describes an experiment. Read the passage and then follow the instructions below.\n\nTom placed a ping pong ball in a catapult, pulled the catapult's arm back to a 45\u00b0 angle, and launched the ball. Then, Tom launched another ping pong ball, this time pulling the catapult's arm back to a 30\u00b0 angle. With each launch, his friend Justin measured the distance between the catapult and the place where the ball hit the ground. Tom and Justin repeated the launches with ping pong balls in four more identical catapults. They compared the distances the balls traveled when launched from a 45\u00b0 angle to the distances the balls traveled when launched from a 30\u00b0 angle.\nFigure: a catapult for launching ping pong balls.\n\nOptions: (A) Do ping pong balls stop rolling along the ground sooner after being launched from a 30\u00b0 angle or a 45\u00b0 angle? (B) Do ping pong balls travel farther when launched from a 30\u00b0 angle compared to a 45\u00b0 angle?", 18 | "caption": "A wooden board with a wooden head on top of it.", 19 | "orc": "[]", 20 | "Answer": "B. 
Do ping pong balls travel farther when launched from a 30\u00b0 angle compared to a 45\u00b0 angle?", 21 | "metadata": { 22 | "has_image": true, 23 | "grade": 8, 24 | "subject": "natural science", 25 | "topic": "science-and-engineering-practices", 26 | "category": "Designing experiments", 27 | "skill": "Identify the experimental question" 28 | } 29 | } 30 | ] -------------------------------------------------------------------------------- /Self_Plan/Traj_Syn/benchmark_run/evaluate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2023, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: Apache License 2.0 5 | For full license text, see the LICENSE file in the repo root or https://www.apache.org/licenses/LICENSE-2.0 6 | """ 7 | 8 | import json 9 | import os 10 | import re 11 | from collections import defaultdict 12 | import pandas as pd 13 | import hotpotqa_run.utils as utils 14 | 15 | def eval_success(result_file) -> list: 16 | df = pd.read_csv(result_file) 17 | return df['success'].tolist() 18 | 19 | def eval_reward(result_file) -> list: 20 | df = pd.read_csv(result_file) 21 | return df['reward'].tolist() 22 | 23 | def eval_llm_agent(llm_name, agent_name): 24 | levels = ['easy','medium','hard'] 25 | all_reward = [] 26 | all_success = [] 27 | for l in levels: 28 | file_name = f"execution_data/hotpotqa/{l}_{agent_name}_{llm_name}.csv" 29 | all_reward += eval_reward(file_name) 30 | all_success += eval_success(file_name) 31 | avg_reward = sum(all_reward)/len(all_reward) 32 | avg_success = sum(all_success)/len(all_success) 33 | return avg_reward, avg_success 34 | 35 | def eval_llm_agent_level(llm_name, agent_name, level): 36 | file_name = f"execution_data/hotpotqa/{level}_{agent_name}_{llm_name}.csv" 37 | all_reward = eval_reward(file_name) 38 | all_success = eval_success(file_name) 39 | avg_reward = sum(all_reward)/len(all_reward) 40 | avg_success = sum(all_success)/len(all_success) 41 | return avg_reward, avg_success 42 | 43 | def eval_sessions(llm_name, agent_name): 44 | levels = ['easy','medium','hard'] 45 | all_reward = [] 46 | all_success = [] 47 | for l in levels: 48 | reward, success = eval_sessions_level((llm_name, agent_name,l)) 49 | all_reward += reward 50 | all_success += success 51 | avg_reward = sum(all_reward)/len(all_reward) 52 | avg_success = sum(all_success)/len(all_success) 53 | return avg_reward, avg_success 54 | 55 | def eval_sessions_level(llm_name, agent_name,level): 56 | file_name = f"execution_data/hotpotqa/{level}_{agent_name}_{llm_name}.jsonl" 57 | sessions = utils.get_all_agent_sessions(file_name) 58 | all_reward = [sess["reward"] for sess in sessions] 59 | all_success = [sess["correct"] for sess in sessions] 60 | avg_reward = sum(all_reward)/len(all_reward) 61 | avg_success = sum(all_success)/len(all_success) 62 | return avg_reward, avg_success 63 | 64 | def get_reward_w_level(llm_name, agent_name): 65 | levels = ['easy','medium','hard'] 66 | ret = [] 67 | for l in levels: 68 | reward, _ = eval_sessions_level(llm_name, agent_name, l) 69 | ret.append(reward) 70 | return ret 71 | 72 | -------------------------------------------------------------------------------- /Self_Plan/Group_Planning/benchmark_run/evaluate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2023, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: Apache License 2.0 5 | For full license text, see the LICENSE file in the repo root or https://www.apache.org/licenses/LICENSE-2.0 6 | """ 7 | 8 | import json 9 | import os 10 | import re 11 | from collections import defaultdict 12 | import pandas as pd 13 | import hotpotqa_run.utils as utils 14 | 15 | def eval_success(result_file) -> list: 16 | df = pd.read_csv(result_file) 17 | return df['success'].tolist() 18 | 19 | def eval_reward(result_file) -> list: 20 | df = pd.read_csv(result_file) 21 | return df['reward'].tolist() 22 | 23 | def eval_llm_agent(llm_name, agent_name): 24 | levels = ['easy','medium','hard'] 25 | all_reward = [] 26 | all_success = [] 27 | for l in levels: 28 | file_name = f"execution_data/hotpotqa/{l}_{agent_name}_{llm_name}.csv" 29 | all_reward += eval_reward(file_name) 30 | all_success += eval_success(file_name) 31 | avg_reward = sum(all_reward)/len(all_reward) 32 | avg_success = sum(all_success)/len(all_success) 33 | return avg_reward, avg_success 34 | 35 | def eval_llm_agent_level(llm_name, agent_name, level): 36 | file_name = f"execution_data/hotpotqa/{level}_{agent_name}_{llm_name}.csv" 37 | all_reward = eval_reward(file_name) 38 | all_success = eval_success(file_name) 39 | avg_reward = sum(all_reward)/len(all_reward) 40 | avg_success = sum(all_success)/len(all_success) 41 | return avg_reward, avg_success 42 | 43 | def eval_sessions(llm_name, agent_name): 44 | levels = ['easy','medium','hard'] 45 | all_reward = [] 46 | all_success = [] 47 | for l in levels: 48 | reward, success = eval_sessions_level((llm_name, agent_name,l)) 49 | all_reward += reward 50 | all_success += success 51 | avg_reward = sum(all_reward)/len(all_reward) 52 | avg_success = sum(all_success)/len(all_success) 53 | return avg_reward, avg_success 54 | 55 | def eval_sessions_level(llm_name, agent_name,level): 56 | file_name = f"execution_data/hotpotqa/{level}_{agent_name}_{llm_name}.jsonl" 57 | sessions = utils.get_all_agent_sessions(file_name) 58 | all_reward = [sess["reward"] for sess in sessions] 59 | all_success = [sess["correct"] for sess in sessions] 60 | avg_reward = sum(all_reward)/len(all_reward) 61 | avg_success = sum(all_success)/len(all_success) 62 | return avg_reward, avg_success 63 | 64 | def get_reward_w_level(llm_name, agent_name): 65 | levels = ['easy','medium','hard'] 66 | ret = [] 67 | for l in levels: 68 | reward, _ = eval_sessions_level(llm_name, agent_name, l) 69 | ret.append(reward) 70 | return ret 71 | 72 | -------------------------------------------------------------------------------- /Self_Plan/Tool_Selection/tool_selected.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import json 4 | from langchain.vectorstores import Chroma 5 | from langchain.embeddings import OpenAIEmbeddings 6 | import sys 7 | sys.path.append('/data/rolnan/') 8 | from pre_prompt import ( 9 | ACTION_SYSTEM_PROMPT, 10 | BENCHMARK_DESCRIPTION, 11 | TOOL_POOL, 12 | TASK_PROMPT_TEMPLATE 13 | ) 14 | from llms import MetaAgent 15 | 16 | 17 | def action_parse(text): 18 | text = text.strip() 19 | lines = text.split("\n") 20 | actions = [line.strip() for line in lines if line != "" and line[0].isdigit()] 21 | return actions 22 | 23 | def main(args): 24 | task_name = args.task_name 25 | task_description = BENCHMARK_DESCRIPTION[args.task_name] 26 | 27 | meta_agent = MetaAgent( 28 | model_name=args.model_name, 29 | openai_key=args.openai_key, 30 | url=args.openai_base, 31 | 
system_prompt=ACTION_SYSTEM_PROMPT 32 | ) 33 | tool_pool = [{"name": tool["name"], "definition": tool["definition"]} for tool in TOOL_POOL] 34 | human_prompt_args = {"task_name": task_name, "task_description": task_description, "tool_pool": tool_pool} 35 | output = meta_agent.generate( 36 | human_prompt_template=TASK_PROMPT_TEMPLATE, 37 | human_prompt_args=human_prompt_args, 38 | temprature=args.temperature, 39 | top_k=args.top_k, 40 | top_p=args.top_p, 41 | max_tokens=args.max_tokens, 42 | update_prompt=False 43 | ) 44 | print(output) 45 | 46 | generated_actions = action_parse(output) 47 | tool_selected = [] 48 | for action in generated_actions: 49 | for tool in TOOL_POOL: 50 | if tool["name"] in action: 51 | tool_selected.append(tool) 52 | break 53 | with open(args.tool_save_path, 'w') as json_file: 54 | json.dump(tool_selected, json_file, indent=2) 55 | 56 | 57 | 58 | if __name__ == "__main__": 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument("--model_name", type=str, default="/data/PLMs/llama-2-converted/7b-chat") 61 | parser.add_argument("--task_name", type=str, default="ScienceQA") 62 | parser.add_argument("--openai_key", type=str, default="EMPTY") 63 | parser.add_argument("--openai_base", type=str, default="http://localhost:8000/v1") 64 | parser.add_argument("--temperature", type=float, default=0.2) 65 | parser.add_argument("--top_k", type=int, default=40) 66 | parser.add_argument("--top_p", type=float, default=0.75) 67 | parser.add_argument("--max_tokens", type=int, default=1024) 68 | parser.add_argument("--retrieve_k", type=int, default=3) 69 | parser.add_argument("--retrieve_p", type=float, default=0.6) 70 | parser.add_argument("--tool_save_path", type=str, default="/data/rolnan/ScienceQA/tool_selected.json") 71 | args = parser.parse_args() 72 | 73 | main(args) -------------------------------------------------------------------------------- /Self_Plan/Traj_Syn/run_task.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import pandas as pd 5 | import concurrent 6 | import joblib 7 | from benchmark_run.utils import summarize_trial_detailed, log_trial 8 | import benchmark_run.utils as utils 9 | from benchmark_run.agent_arch import get_agent 10 | from benchmark_run.llms import get_llm_backend 11 | from benchmark_run.config import available_agent_names 12 | import json 13 | 14 | 15 | parser = argparse.ArgumentParser(description='Parsing the input of agents, llms and llm context length.') 16 | parser.add_argument("--agent_name", type=str, help="Name of the agent.", default="ZeroshotThink_HotPotQA_run_Agent") 17 | parser.add_argument("--llm_name", type=str, help="Name of the llm", default="llama-2-13b") 18 | parser.add_argument("--max_context_len", type=int, help="Maximum context length", default=4096) 19 | parser.add_argument("--task",type=str ,help="task name",default="Hotpotqa") 20 | parser.add_argument("--task_path",type=str,help="task path") 21 | parser.add_argument("--save_path",type=str,help="save path") 22 | args = parser.parse_args() 23 | 24 | agent_name = args.agent_name 25 | 26 | llm_name=args.llm_name 27 | task_path=args.task_path 28 | save_path=args.save_path 29 | max_context_len = args.max_context_len 30 | 31 | assert agent_name in available_agent_names 32 | 33 | def process_agent_run_step(agent): 34 | agent.run() 35 | 36 | def run_one_complex_level_hotpotqa(): 37 | hotpot = json.load(open(task_path)) 38 | agent_save_file = save_path 39 | task_instructions = 
[(row['Question'], row['Answer']) for row in hotpot] 40 | if os.path.exists(agent_save_file): 41 | sessions = utils.get_all_agent_sessions(agent_save_file) 42 | completed_tasks = utils.get_non_error_tasks(sessions) 43 | task_instructions = [task for task in task_instructions if task not in completed_tasks] 44 | utils.delete_error(agent_save_file) 45 | llm = get_llm_backend(llm_name).run 46 | 47 | agent_cls = get_agent(agent_name) 48 | agents = [agent_cls(ques, ans, llm, max_context_len) for ques, ans in task_instructions] 49 | for agent in agents: 50 | process_agent_run_step(agent) 51 | utils.log_agent(agent, agent_save_file) 52 | print(f'Finished Trial. Total: {len(agents)}') 53 | 54 | def run_one_complex_level_scienceqa(): 55 | scienceqa = json.load(open(task_path)) 56 | agent_save_file = save_path 57 | task_instructions = [(row['Question'], row["choices"],row['Answer'], row["caption"], row["orc"]) for row in scienceqa] 58 | if os.path.exists(agent_save_file): 59 | sessions = utils.get_all_agent_sessions(agent_save_file) 60 | completed_tasks = utils.get_non_error_tasks(sessions) 61 | # task_instructions = [task for task in task_instructions if task not in completed_tasks] 62 | task_instructions = task_instructions[len(completed_tasks):] 63 | utils.delete_error(agent_save_file) 64 | llm = get_llm_backend(llm_name).run 65 | agent_cls = get_agent(agent_name) 66 | agents = [agent_cls(ques, ans, llm, choices, cap, ocr, max_context_len) for ques, choices, ans, cap, ocr in task_instructions] 67 | for agent in agents: 68 | process_agent_run_step(agent) 69 | utils.log_agent(agent, agent_save_file) 70 | print(f'Finished Trial. Total: {len(agents)}') 71 | 72 | def main(): 73 | if args.task == "Hotpotqa": 74 | run_one_complex_level_hotpotqa() 75 | elif args.task == "Scienceqa": 76 | run_one_complex_level_scienceqa() 77 | if __name__ == '__main__': 78 | main() -------------------------------------------------------------------------------- /Prompts/hotpotqa_react.txt: -------------------------------------------------------------------------------- 1 | Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be two types: 2 | (1) search[question], which searches a question on Google and returns a short snippet containing the answer. Note that sometimes the snippet does not contain the answer, and some alternative search might be needed. 3 | (2) finish[answer], which returns the answer and finishes the task. 4 | Here are some examples. 5 | 6 | 7 | What is the elevation range for the area that the eastern sector of the Colorado orogeny extends into? 8 | Thought: I need to first find the eastern sector of the Colorado orogeny extends into what, then find its elevation range. 9 | Action: search[the eastern sector of the Colorado orogeny extends into what?] 10 | Observation: the High Plains 11 | Thought: I need to find the elevation range for the High Plains. 12 | Action: search[elevation range of the High Plains?] 13 | Observation: around 1,800 to 7,000 ft 14 | Thought: I have the answer. 15 | Action: finish[1,800 to 7,000 ft] 16 | Observation: Episode finished, reward = 1 17 | 18 | 19 | Musician and satirist Allie Goertz wrote a song about the "The Simpsons" character Milhouse, who Matt Groening named after who? 20 | Thought: I need to search Milhouse is named after who. 21 | Action: search[Milhouse is named after who?] 22 | Observation: U.S. president Richard Nixon 23 | Thought: I find the answer. 
24 | Action: finish[Richard Nixon] 25 | Observation: Episode finished, reward = 1 26 | 27 | 28 | Which documentary is about Finnish rock groups, Adam Clayton Powell or The Saimaa Gesture? 29 | Thought: I need to search documentary Adam Clayton Powell and documentary The Saimaa Gesture to find which is about Finnish rock groups. 30 | Action: search[documentary Adam Clayton Powell] 31 | Observation: Adam Clayton Powell (1989). Documentary. The Academy Award¨-nominated Adam Clayton Powell delves into the gripping life and career of the most influential ... 32 | Thought: I do not get whether it is about Finnish rock groups. I need to search Adam Clayton Powell to make sure. 33 | Action: search[Adam Clayton Powell] 34 | Observation: Re-elected for nearly three decades, Powell became a powerful national politician of the Democratic Party, and served as a national spokesman on civil rights ... 35 | Thought: Adam Clayton Powell is a politican, not Finnish rock groups. I need to search The Saimaa Gesture to make sure. 36 | Action: search[The Saimaa Gesture documentary] 37 | Observation: It is a documentary about three Finnish rock groups aboard the steamboat SS Heinävesi on their tour around Lake Saimaa. The Saimaa Gesture. Directed by, Aki ... 38 | Thought: The Saimaa Gesture is about three Finnish rock groups, so the answer is The Saimaa Gesture. 39 | Action: finish[The Saimaa Gesture] 40 | Observation: Episode finished, reward = 1 41 | 42 | 43 | What profession does Nicholas Ray and Elia Kazan have in common? 44 | Thought: I need to search the profession of Nicholas Ray and Elia Kazan, then find what is common. 45 | Action: search[Nicholas Ray profession] 46 | Observation: New York City, U.S.. Occupation(s), Film director, screenwriter, actor. Years active, 1946–1979. Spouses. 47 | Thought: Nicholas Ray is film director, screenwriter, actor. I need to search Elia Kazan next. 48 | Action: search[Elia Kazan profession] 49 | Observation: Occupations. Actor; director; producer; screenwriter. Years active, 1934–1976. Spouses. Molly Day Thacher Kazan. ​. ​. ( m. 1932, until her death in 1963)​. 50 | Thought: Elia Kazan is actor, director, producer, screenwriter. 
So the common profession is actor, director, screenwriter 51 | Action: finish[actor, director, screenwriter] 52 | Observation: Episode finished, reward = 1 53 | -------------------------------------------------------------------------------- /Self_Instruct/llms.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | from langchain import OpenAI, PromptTemplate 4 | from langchain.chat_models import ChatOpenAI 5 | from langchain.prompts.chat import ( 6 | ChatPromptTemplate, 7 | SystemMessagePromptTemplate, 8 | HumanMessagePromptTemplate, 9 | AIMessagePromptTemplate, 10 | AIMessage 11 | ) 12 | 13 | 14 | class MetaAgent: 15 | def __init__( 16 | self, 17 | model_name: str, 18 | openai_key: str = "EMPTY", 19 | url: str = "http://localhost:8000/v1", 20 | system_prompt: str = None, 21 | ): 22 | self.key = openai_key 23 | self.url = url 24 | self.model_name = model_name 25 | self.system_prompt = system_prompt 26 | self.llm = self._get_llm() 27 | self.prompt = self._init_prompt() 28 | self.prompt_args = dict() 29 | 30 | def _get_llm(self): 31 | os.environ['OPENAI_API_KEY'] = self.key 32 | if "text" in self.model_name: 33 | llm = OpenAI(model=self.model_name) 34 | elif "gpt" in self.model_name: 35 | llm = ChatOpenAI(model=self.model_name) 36 | else: 37 | os.environ['OPENAI_API_BASE'] = self.url 38 | llm = ChatOpenAI(model=self.model_name) 39 | return llm 40 | 41 | def _init_prompt(self): 42 | if not self.system_prompt: 43 | return [] 44 | 45 | if isinstance(self.llm, ChatOpenAI): 46 | system_prompt_template = SystemMessagePromptTemplate.from_template(self.system_prompt) 47 | return [system_prompt_template] 48 | else: 49 | system_prompt_template = "system: " + self.system_prompt 50 | return [system_prompt_template] 51 | 52 | def generate( 53 | self, 54 | human_prompt_template, 55 | human_prompt_args, 56 | temprature=0.2, 57 | top_k=40, 58 | top_p=0.75, 59 | max_tokens=512, 60 | stop=None, 61 | update_prompt=False, 62 | reset_prompt=False 63 | ): 64 | _old_prompt = copy.deepcopy(self.prompt) 65 | _old_prompt_args = copy.deepcopy(self.prompt_args) 66 | self.prompt_args.update(human_prompt_args) 67 | 68 | if isinstance(self.llm, ChatOpenAI): 69 | human_prompt_template = HumanMessagePromptTemplate.from_template(human_prompt_template) 70 | self.prompt.append(human_prompt_template) 71 | prompt_template = ChatPromptTemplate.from_messages(self.prompt) 72 | prompt = prompt_template.format_messages(**self.prompt_args) 73 | else: 74 | self.prompt.append("human: " + human_prompt_template) 75 | prompt_template = PromptTemplate.from_template("\n\n".join(self.prompt)) 76 | prompt = prompt_template.format_prompt(**self.prompt_args) 77 | 78 | response = self.llm( 79 | prompt, 80 | temprature=temprature, 81 | top_k=top_k, 82 | top_p=top_p, 83 | max_tokens=max_tokens, 84 | stop=stop 85 | ) 86 | 87 | output = response.content if isinstance(response, AIMessage) else response 88 | 89 | if update_prompt: 90 | if isinstance(self.llm, ChatOpenAI): 91 | ai_prompt_template = AIMessagePromptTemplate.from_template(output) 92 | self.prompt.append(ai_prompt_template) 93 | else: 94 | self.prompt.append(output) 95 | else: 96 | self.prompt = _old_prompt 97 | self.prompt_args = _old_prompt_args 98 | 99 | if reset_prompt: 100 | self._init_prompt() 101 | 102 | return output -------------------------------------------------------------------------------- /Self_Plan/Tool_Selection/llms.py: -------------------------------------------------------------------------------- 1 | 
import os 2 | import copy 3 | from langchain import OpenAI, PromptTemplate 4 | from langchain.chat_models import ChatOpenAI 5 | from langchain.prompts.chat import ( 6 | ChatPromptTemplate, 7 | SystemMessagePromptTemplate, 8 | HumanMessagePromptTemplate, 9 | AIMessagePromptTemplate, 10 | AIMessage 11 | ) 12 | 13 | 14 | class MetaAgent: 15 | 16 | def __init__( 17 | self, 18 | model_name: str, 19 | openai_key: str = "EMPTY", 20 | url: str = "http://localhost:8000/v1", 21 | system_prompt: str = None, 22 | ): 23 | self.key = openai_key 24 | self.url = url 25 | self.model_name = model_name 26 | self.system_prompt = system_prompt 27 | self.llm = self._get_llm() 28 | self.prompt = self._init_prompt() 29 | self.prompt_args = dict() 30 | 31 | def _get_llm(self): 32 | os.environ['OPENAI_API_KEY'] = self.key 33 | if "text" in self.model_name: 34 | llm = OpenAI(model=self.model_name) 35 | elif "gpt" in self.model_name: 36 | llm = ChatOpenAI(model=self.model_name) 37 | else: 38 | os.environ['OPENAI_API_BASE'] = self.url 39 | llm = ChatOpenAI(model=self.model_name) 40 | return llm 41 | 42 | def _init_prompt(self): 43 | if not self.system_prompt: 44 | return [] 45 | 46 | if isinstance(self.llm, ChatOpenAI): 47 | system_prompt_template = SystemMessagePromptTemplate.from_template(self.system_prompt) 48 | return [system_prompt_template] 49 | else: 50 | system_prompt_template = "system: " + self.system_prompt 51 | return [system_prompt_template] 52 | 53 | def generate( 54 | self, 55 | human_prompt_template, 56 | human_prompt_args, 57 | temprature=0.2, 58 | top_k=40, 59 | top_p=0.75, 60 | max_tokens=512, 61 | stop=None, 62 | update_prompt=False, 63 | reset_prompt=False 64 | ): 65 | _old_prompt = copy.deepcopy(self.prompt) 66 | _old_prompt_args = copy.deepcopy(self.prompt_args) 67 | self.prompt_args.update(human_prompt_args) 68 | 69 | if isinstance(self.llm, ChatOpenAI): 70 | human_prompt_template = HumanMessagePromptTemplate.from_template(human_prompt_template) 71 | self.prompt.append(human_prompt_template) 72 | prompt_template = ChatPromptTemplate.from_messages(self.prompt) 73 | prompt = prompt_template.format_messages(**self.prompt_args) 74 | else: 75 | self.prompt.append("human: " + human_prompt_template) 76 | prompt_template = PromptTemplate.from_template("\n\n".join(self.prompt)) 77 | prompt = prompt_template.format_prompt(**self.prompt_args) 78 | 79 | response = self.llm( 80 | prompt, 81 | temprature=temprature, 82 | top_k=top_k, 83 | top_p=top_p, 84 | max_tokens=max_tokens, 85 | stop=stop 86 | ) 87 | 88 | output = response.content if isinstance(response, AIMessage) else response 89 | 90 | if update_prompt: 91 | if isinstance(self.llm, ChatOpenAI): 92 | ai_prompt_template = AIMessagePromptTemplate.from_template(output) 93 | self.prompt.append(ai_prompt_template) 94 | else: 95 | self.prompt.append(output) 96 | else: 97 | self.prompt = _old_prompt 98 | self.prompt_args = _old_prompt_args 99 | 100 | if reset_prompt: 101 | self._init_prompt() 102 | 103 | return output -------------------------------------------------------------------------------- /Self_Plan/Traj_Syn/benchmark_run/hotpotqa_env.py: -------------------------------------------------------------------------------- 1 | import wikienv, wrappers 2 | 3 | def step(env, action): 4 | attempts = 0 5 | while attempts < 10: 6 | try: 7 | return env.step(action) 8 | except requests.exceptions.Timeout: 9 | attempts += 1 10 | 11 | class HotPotQAEnv: 12 | def __init__(self, n_tasks=100): 13 | self.n_tasks = n_tasks 14 | self.sessions = {} 15 | env = 
wikienv.WikiEnv() 16 | env = wrappers.HotPotQAWrapper(env, split="dev") 17 | env = wrappers.LoggingWrapper(env) 18 | for idx in range(self.n_tasks): 19 | self.sessions[idx] = {'session': idx, 20 | 'question': question = env.reset(idx=idx) 21 | } 22 | 23 | def step(self, session, action): 24 | done = False 25 | observation_ = None 26 | if action == 'reset': 27 | self.sessions[session] = {'session': session, 'page_type': 'init'} 28 | elif action.startswith('think['): 29 | observation = 'OK.' 30 | elif action.startswith('search['): 31 | assert self.sessions[session]['page_type'] == 'init' 32 | query = action[7:-1] 33 | self.sessions[session] = {'session': session, 'page_type': 'search', 34 | 'query_string': query, 'page_num': 1} 35 | elif action.startswith('click['): 36 | button = action[6:-1] 37 | if button == 'Buy Now': 38 | assert self.sessions[session]['page_type'] == 'item' 39 | self.sessions[session]['page_type'] = 'end' 40 | done = True 41 | elif button == 'Back to Search': 42 | assert self.sessions[session]['page_type'] in ['search', 'item_sub', 'item'] 43 | self.sessions[session] = {'session': session, 'page_type': 'init'} 44 | elif button == 'Next >': 45 | assert False # ad hoc page limitation 46 | assert self.sessions[session]['page_type'] == 'search' 47 | self.sessions[session]['page_num'] += 1 48 | elif button == '< Prev': 49 | assert self.sessions[session]['page_type'] in ['search', 'item_sub', 'item'] 50 | if self.sessions[session]['page_type'] == 'search': 51 | assert False 52 | self.sessions[session]['page_num'] -= 1 53 | elif self.sessions[session]['page_type'] == 'item_sub': 54 | self.sessions[session]['page_type'] = 'item' 55 | elif self.sessions[session]['page_type'] == 'item': 56 | self.sessions[session]['page_type'] = 'search' 57 | self.sessions[session]['options'] = {} 58 | elif button in ACTION_TO_TEMPLATE: 59 | assert self.sessions[session]['page_type'] == 'item' 60 | self.sessions[session]['page_type'] = 'item_sub' 61 | self.sessions[session]['subpage'] = button 62 | else: 63 | if self.sessions[session]['page_type'] == 'search': 64 | assert button in self.sessions[session].get('asins', []) # must be asins 65 | self.sessions[session]['page_type'] = 'item' 66 | self.sessions[session]['asin'] = button 67 | elif self.sessions[session]['page_type'] == 'item': 68 | assert 'option_types' in self.sessions[session] 69 | assert button in self.sessions[session]['option_types'], (button, self.sessions[session]['option_types']) # must be options 70 | option_type = self.sessions[session]['option_types'][button] 71 | if not 'options' in self.sessions[session]: 72 | self.sessions[session]['options'] = {} 73 | self.sessions[session]['options'][option_type] = button 74 | observation_ = f'You have clicked {button}.' 
75 | else: 76 | assert False 77 | observation, info, clickable = webshop_text(**self.sessions[session]) 78 | if observation_: 79 | observation = observation_ 80 | self.sessions[session].update(info) 81 | reward = info.get('reward', 0.0) 82 | asins = info.get('asins', []) 83 | return observation, reward, done, asins, clickable 84 | -------------------------------------------------------------------------------- /Self_Plan/Group_Planning/benchmark_run/hotpotqa_env.py: -------------------------------------------------------------------------------- 1 | import wikienv, wrappers 2 | 3 | def step(env, action): 4 | attempts = 0 5 | while attempts < 10: 6 | try: 7 | return env.step(action) 8 | except requests.exceptions.Timeout: 9 | attempts += 1 10 | 11 | class HotPotQAEnv: 12 | def __init__(self, n_tasks=100): 13 | self.n_tasks = n_tasks 14 | self.sessions = {} 15 | env = wikienv.WikiEnv() 16 | env = wrappers.HotPotQAWrapper(env, split="dev") 17 | env = wrappers.LoggingWrapper(env) 18 | for idx in range(self.n_tasks): 19 | self.sessions[idx] = {'session': idx, 20 | 'question': question = env.reset(idx=idx) 21 | } 22 | 23 | def step(self, session, action): 24 | done = False 25 | observation_ = None 26 | if action == 'reset': 27 | self.sessions[session] = {'session': session, 'page_type': 'init'} 28 | elif action.startswith('think['): 29 | observation = 'OK.' 30 | elif action.startswith('search['): 31 | assert self.sessions[session]['page_type'] == 'init' 32 | query = action[7:-1] 33 | self.sessions[session] = {'session': session, 'page_type': 'search', 34 | 'query_string': query, 'page_num': 1} 35 | elif action.startswith('click['): 36 | button = action[6:-1] 37 | if button == 'Buy Now': 38 | assert self.sessions[session]['page_type'] == 'item' 39 | self.sessions[session]['page_type'] = 'end' 40 | done = True 41 | elif button == 'Back to Search': 42 | assert self.sessions[session]['page_type'] in ['search', 'item_sub', 'item'] 43 | self.sessions[session] = {'session': session, 'page_type': 'init'} 44 | elif button == 'Next >': 45 | assert False # ad hoc page limitation 46 | assert self.sessions[session]['page_type'] == 'search' 47 | self.sessions[session]['page_num'] += 1 48 | elif button == '< Prev': 49 | assert self.sessions[session]['page_type'] in ['search', 'item_sub', 'item'] 50 | if self.sessions[session]['page_type'] == 'search': 51 | assert False 52 | self.sessions[session]['page_num'] -= 1 53 | elif self.sessions[session]['page_type'] == 'item_sub': 54 | self.sessions[session]['page_type'] = 'item' 55 | elif self.sessions[session]['page_type'] == 'item': 56 | self.sessions[session]['page_type'] = 'search' 57 | self.sessions[session]['options'] = {} 58 | elif button in ACTION_TO_TEMPLATE: 59 | assert self.sessions[session]['page_type'] == 'item' 60 | self.sessions[session]['page_type'] = 'item_sub' 61 | self.sessions[session]['subpage'] = button 62 | else: 63 | if self.sessions[session]['page_type'] == 'search': 64 | assert button in self.sessions[session].get('asins', []) # must be asins 65 | self.sessions[session]['page_type'] = 'item' 66 | self.sessions[session]['asin'] = button 67 | elif self.sessions[session]['page_type'] == 'item': 68 | assert 'option_types' in self.sessions[session] 69 | assert button in self.sessions[session]['option_types'], (button, self.sessions[session]['option_types']) # must be options 70 | option_type = self.sessions[session]['option_types'][button] 71 | if not 'options' in self.sessions[session]: 72 | self.sessions[session]['options'] = {} 73 | 
self.sessions[session]['options'][option_type] = button 74 | observation_ = f'You have clicked {button}.' 75 | else: 76 | assert False 77 | observation, info, clickable = webshop_text(**self.sessions[session]) 78 | if observation_: 79 | observation = observation_ 80 | self.sessions[session].update(info) 81 | reward = info.get('reward', 0.0) 82 | asins = info.get('asins', []) 83 | return observation, reward, done, asins, clickable 84 | -------------------------------------------------------------------------------- /Self_Plan/Traj_Syn/benchmark_run/llms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2023, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: Apache License 2.0 5 | For full license text, see the LICENSE file in the repo root or https://www.apache.org/licenses/LICENSE-2.0 6 | """ 7 | 8 | import os 9 | import sys 10 | import json 11 | import random 12 | import tiktoken 13 | token_enc = tiktoken.get_encoding("cl100k_base") 14 | import openai 15 | from langchain import PromptTemplate, OpenAI, LLMChain 16 | from transformers import AutoTokenizer, AutoModelForCausalLM 17 | from langchain.chat_models import ChatOpenAI 18 | from langchain.chains import LLMChain 19 | from langchain.prompts.chat import ( 20 | ChatPromptTemplate, 21 | SystemMessagePromptTemplate, 22 | AIMessagePromptTemplate, 23 | HumanMessagePromptTemplate, 24 | ) 25 | OPENAI_API_KEY="" 26 | 27 | OPENAI_CHAT_MODELS = ["gpt-3.5-turbo","gpt-3.5-turbo-16k-0613","gpt-3.5-turbo-16k","gpt-4-0613","gpt-4-32k-0613"] 28 | OPENAI_LLM_MODELS = ["text-davinci-003","text-ada-001"] 29 | FASTCHAT_CHAT_MODELS = ["llama-2-70b-chat"] 30 | FASTCHAT_LLM_MODELS = ["vicuna-7b"] 31 | 32 | 33 | class langchain_openai_chatllm: 34 | def __init__(self, llm_name): 35 | openai.api_key = OPENAI_API_KEY 36 | self.llm_name = llm_name 37 | human_template="{prompt}" 38 | human_message_prompt = HumanMessagePromptTemplate.from_template(human_template) 39 | self.chat_prompt = ChatPromptTemplate.from_messages([human_message_prompt]) 40 | 41 | def run(self, prompt, temperature=0, stop=['\n'], max_tokens=128): 42 | chat = ChatOpenAI(model=self.llm_name, temperature=temperature, stop=stop, max_tokens=max_tokens) 43 | self.chain = LLMChain(llm=chat, prompt=self.chat_prompt) 44 | return self.chain.run(prompt) 45 | 46 | class langchain_openai_llm: 47 | def __init__(self, llm_name): 48 | openai.api_key = OPENAI_API_KEY 49 | self.prompt_temp = PromptTemplate( 50 | input_variables=["prompt"], template="{prompt}" 51 | ) 52 | self.llm_name = llm_name 53 | 54 | def run(self, prompt, temperature=0.9, stop=['\n'], max_tokens=128): 55 | llm = OpenAI(model=self.llm_name, temperature=temperature, stop=stop, max_tokens=max_tokens) 56 | chain = LLMChain(llm=llm, prompt=self.prompt_temp) 57 | return chain.run(prompt) 58 | 59 | 60 | class langchain_fastchat_chatllm: 61 | def __init__(self, llm_name): 62 | os.environ['OPENAI_API_KEY'] = "EMPTY" 63 | os.environ['OPENAI_API_BASE'] = "http://localhost:8000/v1" 64 | self.llm_name = llm_name 65 | human_template="{prompt}" 66 | human_message_prompt = HumanMessagePromptTemplate.from_template(human_template) 67 | self.chat_prompt = ChatPromptTemplate.from_messages([human_message_prompt]) 68 | 69 | def run(self, prompt, temperature=1, stop=['\n'], max_tokens=128): 70 | chat = ChatOpenAI(model=self.llm_name, temperature=temperature, stop=stop, max_tokens=max_tokens) 71 | self.chain = LLMChain(llm=chat, prompt=self.chat_prompt) 72 | return 
self.chain.run(prompt) 73 | 74 | 75 | class langchain_fastchat_llm: 76 | def __init__(self, llm_name): 77 | os.environ['OPENAI_API_KEY'] = "EMPTY" 78 | os.environ['OPENAI_API_BASE'] = "http://localhost:8000/v1" 79 | self.prompt_temp = PromptTemplate( 80 | input_variables=["prompt"], template="{prompt}" 81 | ) 82 | self.llm_name = llm_name 83 | 84 | def run(self, prompt, temperature=0.9, stop=['\n'], max_tokens=128): 85 | llm = OpenAI(model=self.llm_name, temperature=temperature, stop=['\n'], max_tokens=max_tokens) 86 | chain = LLMChain(llm=llm, prompt=self.prompt_temp) 87 | return chain.run(prompt) 88 | 89 | def get_llm_backend(llm_name): 90 | if llm_name in OPENAI_CHAT_MODELS: 91 | return langchain_openai_chatllm(llm_name) 92 | elif llm_name in OPENAI_LLM_MODELS: 93 | return langchain_openai_llm(llm_name) 94 | elif llm_name in FASTCHAT_CHAT_MODELS: 95 | return langchain_fastchat_llm(llm_name) 96 | else: 97 | return langchain_fastchat_llm(llm_name) 98 | -------------------------------------------------------------------------------- /Self_Plan/Group_Planning/run_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import pandas as pd 5 | import concurrent 6 | import joblib 7 | from benchmark_run.utils import summarize_trial_detailed, log_trial 8 | import benchmark_run.utils as utils 9 | from benchmark_run.Meta_agent_arch import get_agent 10 | from benchmark_run.llms import get_llm_backend 11 | from benchmark_run.config import available_agent_names 12 | import json 13 | 14 | 15 | parser = argparse.ArgumentParser(description='Parsing the input of agents, llms and llm context length.') 16 | parser.add_argument("--agent_name", type=str, help="Name of the agent.", default="ZeroshotThink_HotPotQA_run_Agent") 17 | parser.add_argument("--plan_agent", type=str, help="Name of the plan", default="plan_peft") 18 | parser.add_argument("--action_agent", type=str, help="Name of the action", default="action_peft") 19 | parser.add_argument("--reflect_agent", type=str, help="Name of the reflect_agent", default="reflect_peft") 20 | parser.add_argument("--max_context_len", type=int, help="Maximum context length", default=4096) 21 | parser.add_argument("--task",type=str ,help="task name",default="Hotpotqa") 22 | parser.add_argument("--task_path",type=str,help="task path") 23 | parser.add_argument("--save_path",type=str,help="save path") 24 | args = parser.parse_args() 25 | 26 | agent_name = args.agent_name 27 | 28 | plan_agent = args.plan_agent 29 | action_agent = args.action_agent 30 | reflect_agent = args.reflect_agent 31 | max_context_len = args.max_context_len 32 | save_path = args.save_path 33 | if save_path[-1] != "/": 34 | save_path += "/" 35 | task_path = args.task_path 36 | assert agent_name in available_agent_names 37 | 38 | def process_agent_run_step(agent): 39 | agent.run() 40 | 41 | def run_one_complex_level_hotpotqa(level="easy"): 42 | hotpot = joblib.load(f'{task_path}/{level}.joblib').reset_index(drop = True) 43 | agent_save_file = f"{save_path}{level}.jsonl" 44 | task_instructions = [(row['question'], row['answer']) for _, row in hotpot.iterrows()] 45 | if os.path.exists(agent_save_file): 46 | sessions = utils.get_all_agent_sessions(agent_save_file) 47 | completed_tasks = utils.get_non_error_tasks(sessions) 48 | print(f"{level}:{len(completed_tasks)}") 49 | task_instructions = [task for task in task_instructions if task not in completed_tasks] 50 | utils.delete_error(agent_save_file) 51 | 52 | 
llm_plan = get_llm_backend(plan_agent).run 53 | llm_action = get_llm_backend(action_agent).run 54 | llm_reflect = get_llm_backend(reflect_agent).run 55 | 56 | agent_cls = get_agent(agent_name) 57 | agents = [agent_cls(ques, ans, llm_plan, llm_action, llm_reflect, max_context_len) for ques, ans in task_instructions] 58 | for agent in agents: 59 | process_agent_run_step(agent) 60 | utils.log_agent(agent, agent_save_file) 61 | print(f'Finished Trial. Total: {len(agents)}') 62 | def run_one_complex_level_scienceqa(level="1-4"): 63 | f = open(f'{task_path}/format_scienceqa_grade{level}.json') 64 | scienceqa = json.load(f) 65 | agent_save_file = f"{save_path}{level}.jsonl" 66 | task_instructions = [(row['Question'],row['choices'],row['Answer'],row['orc'],row['caption']) for row in scienceqa] 67 | if os.path.exists(agent_save_file): 68 | sessions = utils.get_all_agent_sessions(agent_save_file) 69 | completed_tasks = utils.get_non_error_tasks(sessions) 70 | print(f"{level}:{len(completed_tasks)}") 71 | task_instructions = [task for task in task_instructions if task[0] not in completed_tasks] 72 | utils.delete_error(agent_save_file) 73 | 74 | llm_plan = get_llm_backend(plan_agent).run 75 | llm_action = get_llm_backend(action_agent).run 76 | llm_reflect = get_llm_backend(reflect_agent).run 77 | 78 | agent_cls = get_agent(agent_name) 79 | agents = [agent_cls(ques, choices, ans, caption, orc, llm_plan, llm_action, llm_reflect, max_context_len) for ques,choices,ans,orc,caption in task_instructions] 80 | for agent in agents: 81 | process_agent_run_step(agent) 82 | utils.log_agent(agent, agent_save_file) 83 | print(f'Finished Trial. Total: {len(agents)}') 84 | 85 | def main(): 86 | if args.task == "Hotpotqa": 87 | levels = ['easy', 'medium', 'hard'] 88 | for level in levels: 89 | run_one_complex_level_hotpotqa(level) 90 | elif args.task == "Scienceqa": 91 | levels = ['1-4', '5-8', '9-12'] 92 | for level in levels: 93 | run_one_complex_level_scienceqa(level) 94 | if __name__ == '__main__': 95 | main() -------------------------------------------------------------------------------- /Self_Plan/Group_Planning/benchmark_run/llms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2023, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: Apache License 2.0 5 | For full license text, see the LICENSE file in the repo root or https://www.apache.org/licenses/LICENSE-2.0 6 | """ 7 | 8 | import os 9 | import sys 10 | import json 11 | import random 12 | import tiktoken 13 | token_enc = tiktoken.get_encoding("cl100k_base") 14 | import openai 15 | from langchain import PromptTemplate, OpenAI, LLMChain 16 | from transformers import AutoTokenizer, AutoModelForCausalLM 17 | from langchain.chat_models import ChatOpenAI 18 | from langchain.chains import LLMChain 19 | from langchain import PromptTemplate 20 | from langchain.prompts.chat import ( 21 | ChatPromptTemplate, 22 | SystemMessagePromptTemplate, 23 | AIMessagePromptTemplate, 24 | HumanMessagePromptTemplate, 25 | ) 26 | OPENAI_API_KEY = "" 27 | 28 | OPENAI_CHAT_MODELS = ["gpt-3.5-turbo","gpt-3.5-turbo-16k-0613","gpt-3.5-turbo-16k","gpt-4-0613","gpt-4-32k-0613"] 29 | OPENAI_LLM_MODELS = ["text-davinci-003","text-ada-001"] 30 | FASTCHAT_CHAT_MODELS = ["llama-2-70b-chat"] 31 | FASTCHAT_LLM_MODELS = ["vicuna-7b"] 32 | 33 | 34 | class langchain_openai_chatllm: 35 | def __init__(self, llm_name): 36 | openai.api_key = OPENAI_API_KEY 37 | self.llm_name = llm_name 38 | human_template="{prompt}" 39 | human_message_prompt = HumanMessagePromptTemplate.from_template(human_template) 40 | self.chat_prompt = ChatPromptTemplate.from_messages([human_message_prompt]) 41 | 42 | def run(self, prompt, temperature=1, stop=['\n'], max_tokens=128): 43 | chat = ChatOpenAI(model=self.llm_name, temperature=temperature, stop=stop, max_tokens=max_tokens) 44 | self.chain = LLMChain(llm=chat, prompt=self.chat_prompt) 45 | return self.chain.run(prompt) 46 | 47 | class langchain_openai_llm: 48 | def __init__(self, llm_name): 49 | openai.api_key = OPENAI_API_KEY 50 | self.prompt_temp = PromptTemplate( 51 | input_variables=["prompt"], template="{prompt}" 52 | ) 53 | self.llm_name = llm_name 54 | 55 | def run(self, prompt, temperature=0.9, stop=['\n'], max_tokens=128): 56 | llm = OpenAI(model=self.llm_name, temperature=temperature, stop=stop, max_tokens=max_tokens) 57 | chain = LLMChain(llm=llm, prompt=self.prompt_temp) 58 | return chain.run(prompt) 59 | 60 | 61 | class langchain_fastchat_chatllm: 62 | def __init__(self, llm_name): 63 | os.environ['OPENAI_API_KEY'] = "EMPTY" 64 | os.environ['OPENAI_API_BASE'] = "http://localhost:8000/v1" 65 | self.llm_name = llm_name 66 | human_template="{prompt}" 67 | human_message_prompt = HumanMessagePromptTemplate.from_template(human_template) 68 | self.chat_prompt = ChatPromptTemplate.from_messages([human_message_prompt]) 69 | 70 | def run(self, prompt, temperature=1, stop=['\n'], max_tokens=128): 71 | chat = ChatOpenAI(model=self.llm_name, temperature=temperature, stop=stop, max_tokens=max_tokens) 72 | self.chain = LLMChain(llm=chat, prompt=self.chat_prompt) 73 | return self.chain.run(prompt) 74 | 75 | 76 | class langchain_fastchat_llm: 77 | def __init__(self, llm_name): 78 | os.environ['OPENAI_API_KEY'] = "EMPTY" 79 | os.environ['OPENAI_API_BASE'] = "http://localhost:8000/v1" 80 | self.prompt_temp = PromptTemplate( 81 | input_variables=["prompt"], template="{prompt}\n### Response:\n" 82 | ) 83 | self.llm_name = llm_name 84 | 85 | def run(self, prompt, temperature=0.5, stop=['\n'], max_tokens=256): 86 | print("*********temperature0.5***********************") 87 | llm = OpenAI( 88 | model=self.llm_name, 89 | temperature=temperature, 90 | top_p=0.75, 91 | top_k=40, 92 | num_beams=4, 93 | stop=stop, 94 | max_tokens=max_tokens 95 | ) 96 | chain = 
LLMChain(llm=llm, prompt=self.prompt_temp) 97 | try: 98 | output = chain.run(prompt) 99 | except Exception as e: 100 | output = "" 101 | print(e) 102 | return output 103 | 104 | def get_llm_backend(llm_name): 105 | if llm_name in OPENAI_CHAT_MODELS: 106 | return langchain_openai_chatllm(llm_name) 107 | elif llm_name in OPENAI_LLM_MODELS: 108 | return langchain_openai_llm(llm_name) 109 | elif llm_name in FASTCHAT_CHAT_MODELS: 110 | return langchain_fastchat_llm(llm_name) 111 | else: 112 | return langchain_fastchat_llm(llm_name) 113 | -------------------------------------------------------------------------------- /Prompts/scienceqa_react.txt: -------------------------------------------------------------------------------- 1 | Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be the following types: 2 | (1) search[question], which searches a question on Bing and returns a short snippet containing the answer. Note that sometimes the snippet does not contain the answer, and some alternative search might be needed. 3 | (2) image2text[image], which generates captions for the image and detects words in the image. You are recommended to use it first to get more information about the image related to the question. If the question contains an image, it will return the caption and OCR text; otherwise, it will return None. 4 | (3) retrieve[background], which retrieves background knowledge as a hint for the given question. Normally, we consider using this action when the LLM-generated background knowledge is helpful to guide the solution. 5 | (4) finish[answer], which returns the answer and finishes the task. 6 | The observation closely follows the action and is the result of the action. The trajectory is composed of a cycle of thought, action, and observation. 7 | Here are some examples. 8 | 9 | Question: Which type of relationship is formed when a feather mite lives on a barn swallow's feathers? 10 | Options: (A) parasitic (B) mutualistic (C) commensal 11 | Metadata: {'has_image': True, 'grade': 7, 'subject': 'natural science', 'topic': 'biology', 'category': 'Ecological interactions', 'skill': 'Classify symbiotic relationships'} 12 | Thought: I need to learn the background knowledge. 13 | Action: retrieve[background] 14 | Observation: A parasitic relationship is a type of symbiotic interaction where one organism benefits at the expense of another, often causing harm or detriment to the host organism. A mutualistic relationship is a type of symbiotic interaction where both organisms involved benefit from each other's presence, often contributing to each other's survival and well-being. A commensal relationship is a type of symbiotic interaction where one organism benefits from the association with another, while the other is neither helped nor harmed. 15 | Thought: I need to learn the relationship between feather mites and barn swallows 16 | Action: search[the relationship between feather mites and barn swallows] 17 | Observation: the relationship between feather mites and barn swallows is a commensal one 18 | Thought: I find the answer. 19 | Action: finish[C] 20 | 21 | Question: Which material is this screw driver made of? 22 | Options: (A) cardboard (B) plastic 23 | Metadata: {'pid': 264, 'has_image': True, 'grade': 2, 'subject': 'natural science', 'topic': 'physics', 'category': 'Materials', 'skill': 'Identify multiple materials in objects'} 24 | Thought: I need to get the content of the image to determine the answer.
25 | Action: image2text[image] 26 | Observation: Image: A yellow and black fork sticking out of a yellow frisbee. 27 | Thought: It seems there is a mismatch between the question and the provided image. I need to get more background knowledge. 28 | Action: retrieve[screwdriver] 29 | Observation: Screwdrivers are commonly made of steel for the shaft and tip, with plastic or rubber handles, and are used for various applications including assembling electronics, woodworking, automotive repair, construction, and DIY home projects. Cardboard is a lightweight, durable, and versatile material often used for packaging, crafts, and construction due to its affordability and recyclability. Plastic is a synthetic material known for its durability, versatility, and resistance to corrosion, making it widely used in manufacturing, packaging, construction, and various consumer products. 30 | Thought: Based on the above information, it is harder plastic that is used to manufacture the screw driver. I have my answer. 31 | Action: finish[B] 32 | 33 | Question: Which animal's skin is better adapted as a warning sign to ward off predators? 34 | Image: A yellow and black fire hydrant sitting on top of a rock. 35 | Options: (A) blue poison dart frog (B) gray tree frog 36 | Metadata: {'has_image': True, 'grade': 4, 'subject': 'natural science', 'topic': 'biology', 'category': 'Adaptations', 'skill': 'Animal adaptations: skins and body coverings'} 37 | Thought: I need to search for the function of the skin of tree frogs and the function of the skin of blue poison dart frogs, and analyze which animal's skin has more of a warning effect. 38 | Action: search[the function of the skin of tree frogs] 39 | Observation: The skin of tree frogs serves as a respiratory organ, helps maintain water balance, and contains toxins that act as a defense mechanism against predators. 40 | Thought: Tree frog's skin doesn't serve as a warning to the enemy. I need to search for the function of the skin of blue poison dart frogs. 41 | Action: search[the function of the skin of blue poison dart frogs] 42 | Observation: The bright blue color of this frog's skin warns predators not to eat it 43 | Thought: I find the answer. 44 | Action: finish[B] -------------------------------------------------------------------------------- /Self_Plan/Tool_Selection/pre_prompt.py: -------------------------------------------------------------------------------- 1 | ACTION_SYSTEM_PROMPT = """In order to complete a complex benchmark, we usually need the collaborative work of the following four types of agents: 2 | 1. Plan Agent. This agent is used to plan the specific execution process of the benchmark, solving a given task by determining the order in which other expert language models are invoked; 3 | 2. Tool Agent. This type of agents is used to determine how to execute a specific action when solving a task. Tools can include interactive actions with the benchmark environment and actions to call external tools or models outside of the benchmark; 4 | 3. Answer Agent. This agent is used to generate the final answer for a given task based on historical information; 5 | 4. Reflection Agent. This agent reflects on historical information and answers to determine whether the answer matches the given query. 6 | Above all, the Tool Agent includes many sub-agents that can be flexibly selected. Now your task is to generate 3-5 Tool Agents for solving a given benchmark. Note that all agents are based on language models, and their inputs and outputs must be text. 
You only need to provide the names and descriptions of the tool agents in order, without any additional output.""" 7 | 8 | TOOL_POOL = [ 9 | {"name": "BingSearch","definition":"BingSearch engine can search for rich external knowledge on the Internet based on keywords, which can compensate for knowledge fallacies and outdated knowledge.","usage":"BingSearch[query], which searches the exact detailed query on the Internet and returns the relevant information to the query. Be specific and precise with your query to increase the chances of getting relevant results. For example, BingSearch[popular dog breeds in the United States]"}, 10 | {"name": "Retrieve", "definition": "Retrieve additional background knowledge crucial for tackling complex problems. It is especially beneficial for specialized domains like science and mathematics, providing context for the task.","usage":"Retrieve[entity], which retrieves the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to retrieve. For example, Retrieve[Milhouse]"}, 11 | {"name": "Lookup", "definition": "A Lookup Tool returns the next sentence containing the target string in the page from the search tool (like BingSearch or Retrieve), so it is recommended to use it with BingSearch and Retrieve, simulating Ctrl+F functionality in the browser to find the target answer.","usage":"Lookup[keyword], which returns the next sentence containing the keyword in the last passage successfully found by Retrieve or BingSearch. For example, Lookup[river]."}, 12 | {"name": "Image2Text", "definition":"Image2Text is used to detect words in images, convert them into text by OCR, and generate captions for images. It is particularly valuable when understanding an image semantically, like identifying objects and interactions in a scene.","usage":"Image2Text[image], which generates captions for the image and detects words in the image. You are recommended to use it first to get more information about the image related to the question. If the question contains an image, it will return the caption and OCR text; otherwise, it will return None. For example, Image2Text[image]."}, 13 | {"name": "Text2Image","definition":"Text2Image specializes in converting textual information into visual representations, facilitating the incorporation of textual data into image-based formats within the task.","usage":"Text2Image[text], which generates an image for the text provided by using multimodal models. For example, Text2Image[blue sky]"}, 14 | {"name": "KnowledgeGraph","definition":"KnowledgeGraph is used to query a knowledge graph and get the query results as output.","usage":"KnowledgeGraph[query], which queries the knowledge graph and returns the relevant information to the query. For example, KnowledgeGraph[What is the capital of China?]"}, 15 | {"name": "Database", "definition": "A Database tool can output any valid SQL commands to finish database query and update tasks.","usage":"Database[query], which outputs any valid SQL commands to finish database query and update tasks. For example, Database[SELECT * FROM Customers WHERE Country='Mexico';]"}, 16 | {"name": "Calculator", "definition": "Calculator is used to calculate the result of the given mathematical expression.","usage":"Calculator[query], which calculates the result of the given mathematical expression. 
For example, Calculator[2+2]"}, 17 | {"name": "Table Verbalizer", "definition": "Table Verbalizer is used to convert structured tables into text to enhance the comprehension of tabular information.","usage":"Table Verbalizer[table], which converts structured tables into text to enhance the comprehension of tabular information. For example, Table Verbalizer[table]"}, 18 | {"name": "Code Interpreter", "definition": "Code Interpreter is a tool or software that interprets and executes code written in Python. It analyzes the source code line by line and translates it into machine-readable instructions or directly executes the code and returns execution results.","usage":"Code[python], which interprets and executes Python code, providing a line-by-line analysis of the source code and translating it into machine-readable instructions. For instance, Code[print(\"hello world!\")]"} 19 | ] 20 | 21 | TASK_PROMPT_TEMPLATE = """The following is the given task name and description, and you need to select the three most relevant tool agents according to the above rules, in the format of one tool per line. 22 | Here are the tools to select from: {tool_pool} 23 | Task Name: {task_name} 24 | Task Description: {task_description} 25 | Task Tool Agents: 26 | """ 27 | 28 | HOTPOTQA_TASK_DESCRIPTION = "This is a question-answering task that includes high-quality multi-hop questions and does not contain images. It tests language modeling abilities for multi-step reasoning and covers a wide range of topics. Some questions are challenging, while others are easier, requiring multiple steps of reasoning to arrive at the final answer." 29 | 30 | SCIENCEQA_TASK_DESCRIPTION = "This is a multimodal question-answering task that necessitates a model to utilize tools for transforming image information into textual data. Simultaneously, this task incorporates substantial background knowledge, requiring the language model to acquire external information to enhance its comprehension of the task." 31 | 32 | BENCHMARK_DESCRIPTION = { 33 | "ScienceQA": SCIENCEQA_TASK_DESCRIPTION, 34 | "HotpotQA": HOTPOTQA_TASK_DESCRIPTION, 35 | } -------------------------------------------------------------------------------- /Self_Plan/Group_Planning/benchmark_run/pre_prompt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2023, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: Apache License 2.0 5 | For full license text, see the LICENSE file in the repo root or https://www.apache.org/licenses/LICENSE-2.0 6 | """ 7 | 8 | from langchain.prompts import PromptTemplate 9 | 10 | # ZEROSHOT_INSTRUCTION = """Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types: 11 | # (1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search. For example, Search[Milhouse] 12 | # (2) Lookup[keyword], which returns the next sentence containing keyword in the last passage successfully found by Search. For example, Lookup[named after] 13 | # (3) Finish[answer], which returns the answer and finishes the task. For example, Finish[Richard Nixon] 14 | # You may take as many steps as necessary. 
15 | # Question: {question}{scratchpad}""" 16 | 17 | ZEROSHOT_INSTRUCTION_HOTPOTQA = """I want you to be a good multi-hop question answerer ,solving a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be five types : 18 | {tools} 19 | (4) Finish[answer], which returns a definite answer. For example, Finish[Richard Nixon] (If it is a judgement question, please Finish[yes] or Finish[no]) 20 | (5) Reflect[right/wrong], which reflects the answer right or wrong based on the context history. For example, Reflect[right] 21 | Note that Reflect must be the next Action after Finish. You may take as many steps as necessary. 22 | Question: {question}\n\n{scratchpad}""" 23 | 24 | zeroshot_agent_prompt_hotpotqa = PromptTemplate( 25 | input_variables=["tools","question", "scratchpad"], 26 | template = ZEROSHOT_INSTRUCTION_HOTPOTQA, 27 | ) 28 | 29 | ZEROSHOT_INSTRUCTION_SCIENCEQA = """I want you to be a good multimodal multiple-choice science questions answerer. Select a correct option to a multi-choice multi-modal question with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be five types: 30 | {tools}(4) Finish[option], which returns the answer option and finishes the task. For example, Finish[A] 31 | (5) Reflect[right/wrong], which reflects the answer right or wrong based on the context history. For example, Reflect[right] 32 | Note that to determine the answer, it's needed to consider both the Question and the available Options. 33 | Note that Reflect must be the next Action after Finish. 34 | BingSearch and Retrieve can be used multi-times. 35 | Question: {question}\n{scratchpad}""" 36 | 37 | zeroshot_agent_prompt_scienceqa = PromptTemplate( 38 | input_variables=["tools","question", "scratchpad"], 39 | template = ZEROSHOT_INSTRUCTION_SCIENCEQA, 40 | ) 41 | 42 | REACT_INSTRUCTION = """Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types: 43 | (1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search. 44 | (2) Lookup[keyword], which returns the next sentence containing keyword in the last passage successfully found by Search. 45 | (3) Finish[answer], which returns the answer and finishes the task. 46 | You may take as many steps as necessary. 47 | Here are some examples: 48 | {examples} 49 | (END OF EXAMPLES) 50 | Question: {question}{scratchpad}""" 51 | 52 | react_agent_prompt = PromptTemplate( 53 | input_variables=["examples", "question", "scratchpad"], 54 | template = REACT_INSTRUCTION, 55 | ) 56 | 57 | PLAN_INSTRUCTION = """Setup a plan for answering question with Actions. Action can be three types: 58 | (1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search. 59 | (2) Lookup[keyword], which returns the next sentence containing keyword in the last passage successfully found by Search. 60 | (3) Finish[answer], which returns the answer and finishes the task. 
61 | {examples} 62 | (END OF EXAMPLES) 63 | Question: {question} 64 | Plan:""" 65 | 66 | plan_prompt = PromptTemplate( 67 | input_variables=["examples", "question"], 68 | template = PLAN_INSTRUCTION, 69 | ) 70 | 71 | PLANNER_INSTRUCTION = """Solve a question answering task with Plan, interleaving Action, Observation steps. Plan is decided ahead of Actions. Action can be three types: 72 | (1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search. 73 | (2) Lookup[keyword], which returns the next sentence containing keyword in the last passage successfully found by Search. 74 | (3) Finish[answer], which returns the answer and finishes the task. 75 | You may take as many steps as necessary. 76 | Here are some examples: 77 | {examples} 78 | (END OF EXAMPLES) 79 | Question: {question} 80 | Plan: {plan}{scratchpad}""" 81 | 82 | planner_agent_prompt = PromptTemplate( 83 | input_variables=["examples", "question", "plan", "scratchpad"], 84 | template = PLANNER_INSTRUCTION, 85 | ) 86 | 87 | PLANNERREACT_INSTRUCTION = """Solve a question answering task with Plan, interleaving Thought, Action, Observation steps. Plan is decided ahead of Actions. Thought can reason about the current situation. Action can be three types: 88 | (1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search. 89 | (2) Lookup[keyword], which returns the next sentence containing keyword in the last passage successfully found by Search. 90 | (3) Finish[answer], which returns the answer and finishes the task. 91 | You may take as many steps as necessary. 92 | Here are some examples: 93 | {examples} 94 | (END OF EXAMPLES) 95 | Question: {question} 96 | Plan: {plan}{scratchpad}""" 97 | 98 | plannerreact_agent_prompt = PromptTemplate( 99 | input_variables=["examples", "question", "plan", "scratchpad"], 100 | template = PLANNERREACT_INSTRUCTION, 101 | ) 102 | 103 | 104 | -------------------------------------------------------------------------------- /Self_Plan/Traj_Syn/benchmark_run/wikienv.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2023, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: Apache License 2.0 5 | For full license text, see the LICENSE file in the repo root or https://www.apache.org/licenses/LICENSE-2.0 6 | """ 7 | 8 | import ast 9 | import json 10 | import time 11 | import gym 12 | import requests 13 | from bs4 import BeautifulSoup 14 | 15 | # import wikipedia 16 | 17 | def clean_str(p): 18 | return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8") 19 | 20 | 21 | class textSpace(gym.spaces.Space): 22 | def contains(self, x) -> bool: 23 | """Return boolean specifying if x is a valid member of this space.""" 24 | return isinstance(x, str) 25 | 26 | 27 | class WikiEnv(gym.Env): 28 | 29 | def __init__(self): 30 | """ 31 | Initialize the environment. 
32 | """ 33 | super().__init__() 34 | self.page = None # current Wikipedia page 35 | self.obs = None # current observation 36 | self.lookup_keyword = None # current lookup keyword 37 | self.lookup_list = None # list of paragraphs containing current lookup keyword 38 | self.lookup_cnt = None # current lookup index 39 | self.steps = 0 # current number of steps 40 | self.answer = None # current answer from the agent 41 | self.observation_space = self.action_space = textSpace() 42 | self.search_time = 0 43 | self.num_searches = 0 44 | 45 | def _get_obs(self): 46 | return self.obs 47 | 48 | def _get_info(self): 49 | return {"steps": self.steps, "answer": self.answer} 50 | 51 | def reset(self, seed=None, return_info=False, options=None): 52 | # We need the following line to seed self.np_random 53 | # super().reset(seed=seed) 54 | self.obs = ("Interact with Wikipedia using search[], lookup[], and " 55 | "finish[].\n") 56 | self.page = None 57 | self.lookup_keyword = None 58 | self.lookup_list = None 59 | self.lookup_cnt = None 60 | self.steps = 0 61 | self.answer = None 62 | observation = self._get_obs() 63 | info = self._get_info() 64 | return (observation, info) if return_info else observation 65 | 66 | def construct_lookup_list(self, keyword): 67 | # find all paragraphs 68 | if self.page is None: 69 | return [] 70 | paragraphs = self.page.split("\n") 71 | paragraphs = [p.strip() for p in paragraphs if p.strip()] 72 | 73 | # find all sentence 74 | sentences = [] 75 | for p in paragraphs: 76 | sentences += p.split('. ') 77 | sentences = [s.strip() + '.' for s in sentences if s.strip()] 78 | 79 | parts = sentences 80 | parts = [p for p in parts if keyword.lower() in p.lower()] 81 | return parts 82 | 83 | @staticmethod 84 | def get_page_obs(page): 85 | # find all paragraphs 86 | paragraphs = page.split("\n") 87 | paragraphs = [p.strip() for p in paragraphs if p.strip()] 88 | 89 | # find all sentence 90 | sentences = [] 91 | for p in paragraphs: 92 | sentences += p.split('. ') 93 | sentences = [s.strip() + '.' for s in sentences if s.strip()] 94 | return ' '.join(sentences[:5]) 95 | 96 | # ps = page.split("\n") 97 | # ret = ps[0] 98 | # for i in range(1, len(ps)): 99 | # if len((ret + ps[i]).split(" ")) <= 50: 100 | # ret += ps[i] 101 | # else: 102 | # break 103 | # return ret 104 | 105 | def search_step(self, entity): 106 | entity_ = entity.replace(" ", "+") 107 | search_url = f"https://en.wikipedia.org/w/index.php?search={entity_}" 108 | old_time = time.time() 109 | response_text = requests.get(search_url).text 110 | self.search_time += time.time() - old_time 111 | self.num_searches += 1 112 | soup = BeautifulSoup(response_text, features="html.parser") 113 | result_divs = soup.find_all("div", {"class": "mw-search-result-heading"}) 114 | if result_divs: # mismatch 115 | self.result_titles = [clean_str(div.get_text().strip()) for div in result_divs] 116 | self.obs = f"Could not find {entity}. Similar: {self.result_titles[:5]}." 
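# When the query does not return a list of search results, the branch below handles a direct article hit:
# it retries the search with a bracketed query if the page is a disambiguation page ("may refer to:"); otherwise
# it concatenates the page's paragraph and list text into self.page, exposes only the first five sentences as
# the observation, and clears any pending lookup state.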
117 | else: 118 | page = [p.get_text().strip() for p in soup.find_all("p") + soup.find_all("ul")] 119 | if any("may refer to:" in p for p in page): 120 | self.search_step("[" + entity + "]") 121 | else: 122 | self.page = "" 123 | for p in page: 124 | if len(p.split(" ")) > 2: 125 | self.page += clean_str(p) 126 | if not p.endswith("\n"): 127 | self.page += "\n" 128 | self.obs = self.get_page_obs(self.page) 129 | self.lookup_keyword = self.lookup_list = self.lookup_cnt = None 130 | 131 | def step(self, action): 132 | reward = 0 133 | done = False 134 | action = action.strip() 135 | if self.answer is not None: # already finished 136 | done = True 137 | return self.obs, reward, done, self._get_info() 138 | 139 | if action.startswith("search[") and action.endswith("]"): 140 | entity = action[len("search["):-1] 141 | # entity_ = entity.replace(" ", "_") 142 | # search_url = f"https://en.wikipedia.org/wiki/{entity_}" 143 | self.search_step(entity) 144 | elif action.startswith("lookup[") and action.endswith("]"): 145 | keyword = action[len("lookup["):-1] 146 | if self.lookup_keyword != keyword: # reset lookup 147 | self.lookup_keyword = keyword 148 | self.lookup_list = self.construct_lookup_list(keyword) 149 | self.lookup_cnt = 0 150 | if self.lookup_cnt >= len(self.lookup_list): 151 | self.obs = "No more results.\n" 152 | else: 153 | self.obs = f"(Result {self.lookup_cnt + 1} / {len(self.lookup_list)}) " + self.lookup_list[self.lookup_cnt] 154 | self.lookup_cnt += 1 155 | elif action.startswith("finish[") and action.endswith("]"): 156 | answer = action[len("finish["):-1] 157 | self.answer = answer 158 | done = True 159 | self.obs = f"Episode finished, reward = {reward}\n" 160 | elif action.startswith("think[") and action.endswith("]"): 161 | self.obs = "Nice thought." 162 | else: 163 | self.obs = "Invalid action: {}".format(action) 164 | 165 | self.steps += 1 166 | 167 | return self.obs, reward, done, self._get_info() 168 | 169 | def get_time_info(self): 170 | speed = self.search_time / self.num_searches if self.num_searches else 0 171 | return { 172 | "call_speed": speed, 173 | "call_time": self.search_time, 174 | "num_calls": self.num_searches, 175 | } -------------------------------------------------------------------------------- /Self_Plan/Group_Planning/benchmark_run/wikienv.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2023, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: Apache License 2.0 5 | For full license text, see the LICENSE file in the repo root or https://www.apache.org/licenses/LICENSE-2.0 6 | """ 7 | 8 | import ast 9 | import json 10 | import time 11 | import gym 12 | import requests 13 | from bs4 import BeautifulSoup 14 | 15 | # import wikipedia 16 | 17 | def clean_str(p): 18 | return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8") 19 | 20 | 21 | class textSpace(gym.spaces.Space): 22 | def contains(self, x) -> bool: 23 | """Return boolean specifying if x is a valid member of this space.""" 24 | return isinstance(x, str) 25 | 26 | 27 | class WikiEnv(gym.Env): 28 | 29 | def __init__(self): 30 | """ 31 | Initialize the environment. 
32 | """ 33 | super().__init__() 34 | self.page = None # current Wikipedia page 35 | self.obs = None # current observation 36 | self.lookup_keyword = None # current lookup keyword 37 | self.lookup_list = None # list of paragraphs containing current lookup keyword 38 | self.lookup_cnt = None # current lookup index 39 | self.steps = 0 # current number of steps 40 | self.answer = None # current answer from the agent 41 | self.observation_space = self.action_space = textSpace() 42 | self.search_time = 0 43 | self.num_searches = 0 44 | 45 | def _get_obs(self): 46 | return self.obs 47 | 48 | def _get_info(self): 49 | return {"steps": self.steps, "answer": self.answer} 50 | 51 | def reset(self, seed=None, return_info=False, options=None): 52 | # We need the following line to seed self.np_random 53 | # super().reset(seed=seed) 54 | self.obs = ("Interact with Wikipedia using search[], lookup[], and " 55 | "finish[].\n") 56 | self.page = None 57 | self.lookup_keyword = None 58 | self.lookup_list = None 59 | self.lookup_cnt = None 60 | self.steps = 0 61 | self.answer = None 62 | observation = self._get_obs() 63 | info = self._get_info() 64 | return (observation, info) if return_info else observation 65 | 66 | def construct_lookup_list(self, keyword): 67 | # find all paragraphs 68 | if self.page is None: 69 | return [] 70 | paragraphs = self.page.split("\n") 71 | paragraphs = [p.strip() for p in paragraphs if p.strip()] 72 | 73 | # find all sentence 74 | sentences = [] 75 | for p in paragraphs: 76 | sentences += p.split('. ') 77 | sentences = [s.strip() + '.' for s in sentences if s.strip()] 78 | 79 | parts = sentences 80 | parts = [p for p in parts if keyword.lower() in p.lower()] 81 | return parts 82 | 83 | @staticmethod 84 | def get_page_obs(page): 85 | # find all paragraphs 86 | paragraphs = page.split("\n") 87 | paragraphs = [p.strip() for p in paragraphs if p.strip()] 88 | 89 | # find all sentence 90 | sentences = [] 91 | for p in paragraphs: 92 | sentences += p.split('. ') 93 | sentences = [s.strip() + '.' for s in sentences if s.strip()] 94 | return ' '.join(sentences[:5]) 95 | 96 | # ps = page.split("\n") 97 | # ret = ps[0] 98 | # for i in range(1, len(ps)): 99 | # if len((ret + ps[i]).split(" ")) <= 50: 100 | # ret += ps[i] 101 | # else: 102 | # break 103 | # return ret 104 | 105 | def search_step(self, entity): 106 | entity_ = entity.replace(" ", "+") 107 | search_url = f"https://en.wikipedia.org/w/index.php?search={entity_}" 108 | old_time = time.time() 109 | response_text = requests.get(search_url).text 110 | self.search_time += time.time() - old_time 111 | self.num_searches += 1 112 | soup = BeautifulSoup(response_text, features="html.parser") 113 | result_divs = soup.find_all("div", {"class": "mw-search-result-heading"}) 114 | if result_divs: # mismatch 115 | self.result_titles = [clean_str(div.get_text().strip()) for div in result_divs] 116 | self.obs = f"Could not find {entity}. Similar: {self.result_titles[:5]}." 
117 | else: 118 | page = [p.get_text().strip() for p in soup.find_all("p") + soup.find_all("ul")] 119 | if any("may refer to:" in p for p in page): 120 | self.search_step("[" + entity + "]") 121 | else: 122 | self.page = "" 123 | for p in page: 124 | if len(p.split(" ")) > 2: 125 | self.page += clean_str(p) 126 | if not p.endswith("\n"): 127 | self.page += "\n" 128 | self.obs = self.get_page_obs(self.page) 129 | self.lookup_keyword = self.lookup_list = self.lookup_cnt = None 130 | 131 | def step(self, action): 132 | reward = 0 133 | done = False 134 | action = action.strip() 135 | if self.answer is not None: # already finished 136 | done = True 137 | return self.obs, reward, done, self._get_info() 138 | 139 | if action.startswith("search[") and action.endswith("]"): 140 | entity = action[len("search["):-1] 141 | # entity_ = entity.replace(" ", "_") 142 | # search_url = f"https://en.wikipedia.org/wiki/{entity_}" 143 | self.search_step(entity) 144 | elif action.startswith("lookup[") and action.endswith("]"): 145 | keyword = action[len("lookup["):-1] 146 | if self.lookup_keyword != keyword: # reset lookup 147 | self.lookup_keyword = keyword 148 | self.lookup_list = self.construct_lookup_list(keyword) 149 | self.lookup_cnt = 0 150 | if self.lookup_cnt >= len(self.lookup_list): 151 | self.obs = "No more results.\n" 152 | else: 153 | self.obs = f"(Result {self.lookup_cnt + 1} / {len(self.lookup_list)}) " + self.lookup_list[self.lookup_cnt] 154 | self.lookup_cnt += 1 155 | elif action.startswith("finish[") and action.endswith("]"): 156 | answer = action[len("finish["):-1] 157 | self.answer = answer 158 | done = True 159 | self.obs = f"Episode finished, reward = {reward}\n" 160 | elif action.startswith("think[") and action.endswith("]"): 161 | self.obs = "Nice thought." 162 | else: 163 | self.obs = "Invalid action: {}".format(action) 164 | 165 | self.steps += 1 166 | 167 | return self.obs, reward, done, self._get_info() 168 | 169 | def get_time_info(self): 170 | speed = self.search_time / self.num_searches if self.num_searches else 0 171 | return { 172 | "call_speed": speed, 173 | "call_time": self.search_time, 174 | "num_calls": self.num_searches, 175 | } -------------------------------------------------------------------------------- /Self_Plan/Traj_Syn/benchmark_run/pre_prompt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2023, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: Apache License 2.0 5 | For full license text, see the LICENSE file in the repo root or https://www.apache.org/licenses/LICENSE-2.0 6 | """ 7 | 8 | from langchain.prompts import PromptTemplate 9 | 10 | # ZEROSHOT_INSTRUCTION = """Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types: 11 | # (1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search. For example, Search[Milhouse] 12 | # (2) Lookup[keyword], which returns the next sentence containing keyword in the last passage successfully found by Search. For example, Lookup[named after] 13 | # (3) Finish[answer], which returns the answer and finishes the task. For example, Finish[Richard Nixon] 14 | # You may take as many steps as necessary. 
15 | # Question: {question}{scratchpad}""" 16 | 17 | ZEROSHOT_INSTRUCTION_HOTPOTQA = """I want you to be a good multi-hop question answerer ,solving a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be five types : 18 | {tools} 19 | (4) Finish[answer], which returns a definite answer. For example, Finish[Richard Nixon] (If it is a judgement question, please Finish[yes] or Finish[no]) 20 | (5) Reflect[right/wrong], which reflects the answer right or wrong based on the context history. For example, Reflect[right] 21 | Note that Reflect must be the next Action after Finish. You may take as many steps as necessary. 22 | Question: {question}\n\n{scratchpad}""" 23 | 24 | zeroshot_agent_prompt_hotpotqa = PromptTemplate( 25 | input_variables=["tools","question", "scratchpad"], 26 | template = ZEROSHOT_INSTRUCTION_HOTPOTQA, 27 | ) 28 | 29 | ZEROSHOT_INSTRUCTION_SCIENCEQA = """I want you to be a good multimodal multiple-choice science questions answerer. Select a correct option to a multi-choice multi-modal question with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be five types: 30 | {tools} 31 | (4) Finish[option], which returns the answer option and finishes the task. For example, Finish[A] 32 | (5) Reflect[right/wrong], which reflects the answer right or wrong based on the context history. For example, Reflect[right] 33 | Note that to determine the answer, it's needed to consider both the Question and the available Options. 34 | Note that Reflect must be the next Action after Finish. 35 | BingSearch and Retrieve can be used multi-times. 36 | Question: {question}\n{scratchpad}""" 37 | 38 | zeroshot_agent_prompt_scienceqa = PromptTemplate( 39 | input_variables=["tools","question", "scratchpad"], 40 | template = ZEROSHOT_INSTRUCTION_SCIENCEQA, 41 | ) 42 | 43 | REACT_INSTRUCTION = """Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types: 44 | (1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search. 45 | (2) Lookup[keyword], which returns the next sentence containing keyword in the last passage successfully found by Search. 46 | (3) Finish[answer], which returns the answer and finishes the task. 47 | You may take as many steps as necessary. 48 | Here are some examples: 49 | {examples} 50 | (END OF EXAMPLES) 51 | Question: {question}{scratchpad}""" 52 | 53 | react_agent_prompt = PromptTemplate( 54 | input_variables=["examples", "question", "scratchpad"], 55 | template = REACT_INSTRUCTION, 56 | ) 57 | 58 | PLAN_INSTRUCTION = """Setup a plan for answering question with Actions. Action can be three types: 59 | (1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search. 60 | (2) Lookup[keyword], which returns the next sentence containing keyword in the last passage successfully found by Search. 61 | (3) Finish[answer], which returns the answer and finishes the task. 
62 | {examples} 63 | (END OF EXAMPLES) 64 | Question: {question} 65 | Plan:""" 66 | 67 | plan_prompt = PromptTemplate( 68 | input_variables=["examples", "question"], 69 | template = PLAN_INSTRUCTION, 70 | ) 71 | 72 | PLANNER_INSTRUCTION = """Solve a question answering task with Plan, interleaving Action, Observation steps. Plan is decided ahead of Actions. Action can be three types: 73 | (1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search. 74 | (2) Lookup[keyword], which returns the next sentence containing keyword in the last passage successfully found by Search. 75 | (3) Finish[answer], which returns the answer and finishes the task. 76 | You may take as many steps as necessary. 77 | Here are some examples: 78 | {examples} 79 | (END OF EXAMPLES) 80 | Question: {question} 81 | Plan: {plan}{scratchpad}""" 82 | 83 | planner_agent_prompt = PromptTemplate( 84 | input_variables=["examples", "question", "plan", "scratchpad"], 85 | template = PLANNER_INSTRUCTION, 86 | ) 87 | 88 | PLANNERREACT_INSTRUCTION = """Solve a question answering task with Plan, interleaving Thought, Action, Observation steps. Plan is decided ahead of Actions. Thought can reason about the current situation. Action can be three types: 89 | (1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search. 90 | (2) Lookup[keyword], which returns the next sentence containing keyword in the last passage successfully found by Search. 91 | (3) Finish[answer], which returns the answer and finishes the task. 92 | You may take as many steps as necessary. 93 | Here are some examples: 94 | {examples} 95 | (END OF EXAMPLES) 96 | Question: {question} 97 | Plan: {plan}{scratchpad}""" 98 | 99 | plannerreact_agent_prompt = PromptTemplate( 100 | input_variables=["examples", "question", "plan", "scratchpad"], 101 | template = PLANNERREACT_INSTRUCTION, 102 | ) 103 | 104 | 105 | -------------------------------------------------------------------------------- /Self_Instruct/data_generation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pre_prompt import ( 3 | DATA_GEN_SYSTEM_PROMPT, 4 | HOTPOTQA_TASK_NAME, 5 | HOTPOTQA_TASK_DESCRIPTION, 6 | SCIENCEQA_TASK_NAME, 7 | SCIENCEQA_TASK_DESCRIPTION, 8 | SCIENCEQA_DATA_GEN_HUMAN_PROMPT, 9 | HOTPOTQA_DATA_GEN_HUMAN_PROMPT 10 | ) 11 | from llms import MetaAgent 12 | import json 13 | import random 14 | import re 15 | import os 16 | def get_data_hotpotqa(source_data): 17 | data=json.load(open(source_data)) 18 | data=[{"Question":d['Question'],"Answer":d['Answer']} for d in data] 19 | return data 20 | 21 | def get_data_scienceqa(source_data): 22 | data=json.load(open(source_data)) 23 | data=[{"Question":d['Question'],"Answer":d['Answer'],"Caption":d["caption"]} for d in data] 24 | return data 25 | 26 | def get_random_data(data, num_samples=5): 27 | random_data = random.sample(data, num_samples) 28 | return random_data 29 | 30 | def parse_ouput_scienceqa(output): 31 | output = output.split("\n") 32 | non_empty_lines = [line for line in output if line.strip()] 33 | output="\n".join(non_empty_lines) 34 | # Define a regular expression pattern to extract relevant information 35 | pattern = re.compile(r"Question:(.*?)Options:(.*?)Ocr:(.*?)Caption:(.*?)Answer:(.*?)$", re.DOTALL | re.MULTILINE) 36 | 37 | # Find all matches in the text 38 
| matches = pattern.findall(output) 39 | 40 | # Create a list to store extracted data 41 | questions_data = [] 42 | 43 | # Process each match 44 | for match in matches: 45 | question = match[0].strip() 46 | options = match[1].strip() 47 | ocr = match[2].strip() 48 | caption = match[3].strip() 49 | answer = match[4].strip() 50 | # Create a dictionary for each question 51 | question_data = { 52 | "Question": question, 53 | "Options": options, 54 | "Ocr": ocr, 55 | "Caption": caption, 56 | "Answer": answer 57 | } 58 | # Add the dictionary to the list 59 | print(len(questions_data),'\n') 60 | questions_data.append(question_data) 61 | 62 | 63 | return questions_data 64 | 65 | def parse_ouput_hotpotqa(output): 66 | pattern = r"Question: (.+)\nAnswer: (.+)" 67 | matches = re.findall(pattern, output) 68 | print(matches) 69 | new_qa_pairs=[] 70 | for match in matches: 71 | question = match[0] 72 | answer = match[1] 73 | new_qa_pairs.append({ 74 | 'Question': question, 75 | 'Answer': answer, 76 | }) 77 | return new_qa_pairs 78 | def save_to_json(data,path): 79 | if os.path.exists(path) and os.path.getsize(path) > 0: 80 | with open(path, 'r') as file: 81 | ori_data = json.load(file) 82 | else: 83 | ori_data = [] 84 | with open(path, 'w') as file: 85 | data = data+ori_data 86 | json.dump(data, file,indent=4) 87 | 88 | def main(args): 89 | data_system_prmpt = DATA_GEN_SYSTEM_PROMPT 90 | if args.dataset_name == "hotpotqa": 91 | dataset_system_prompt = data_system_prmpt.format(task_name = HOTPOTQA_TASK_NAME, task_description = HOTPOTQA_TASK_DESCRIPTION) 92 | elif args.dataset_name == "scienceqa": 93 | dataset_system_prompt = data_system_prmpt.format(task_name = SCIENCEQA_TASK_NAME, task_description = SCIENCEQA_TASK_DESCRIPTION) 94 | meta_agent = MetaAgent( 95 | model_name=args.model_name, 96 | openai_key=args.openai_key, 97 | url=args.openai_base, 98 | system_prompt= dataset_system_prompt 99 | ) 100 | qa_pairs=[] 101 | if args.dataset_name=="hotpotqa": 102 | qa_pairs=get_data_hotpotqa(args.source_data) 103 | elif args.dataset_name == "scienceqa": 104 | qa_pairs = get_data_scienceqa(args.source_data) 105 | #else dataset_name 106 | answer_set=set() 107 | unique_qa=[] 108 | ori_qa=get_random_data(qa_pairs,num_samples=2) 109 | unique_qa = ori_qa 110 | for u in unique_qa: 111 | if args.dataset_name == "hotpotqa": 112 | answer_set.add(u["Answer"]) 113 | elif args.dataset_name == "scienceqa": 114 | answer_set.add(u["Answer"].split(' ')[1]) 115 | while(len(answer_set) List[Any]: 13 | """Return a list of values in dict sorted by key.""" 14 | return [values[val] for val in sorted(values)] 15 | 16 | 17 | class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel): 18 | """Example selector that selects examples based on SemanticSimilarity.""" 19 | 20 | vectorstore: VectorStore 21 | """VectorStore than contains information about examples.""" 22 | k: int = 4 23 | """Number of examples to select.""" 24 | example_keys: Optional[List[str]] = None 25 | """Optional keys to filter examples to.""" 26 | input_keys: Optional[List[str]] = None 27 | """Optional keys to filter input to. 
If provided, the search is based on 28 | the input variables instead of all variables.""" 29 | 30 | class Config: 31 | """Configuration for this pydantic object.""" 32 | 33 | extra = Extra.forbid 34 | arbitrary_types_allowed = True 35 | 36 | def add_example(self, example: Dict[str, str]) -> str: 37 | """Add new example to vectorstore.""" 38 | if self.input_keys: 39 | string_example = " ".join( 40 | sorted_values({key: example[key] for key in self.input_keys}) 41 | ) 42 | else: 43 | string_example = " ".join(sorted_values(example)) 44 | ids = self.vectorstore.add_texts([string_example], metadatas=[example]) 45 | return ids[0] 46 | 47 | def select_examples(self, input_variables: Dict[str, str]) -> List[dict,float]: 48 | """Select which examples to use based on semantic similarity.""" 49 | # Get the docs with the highest similarity. 50 | if self.input_keys: 51 | input_variables = {key: input_variables[key] for key in self.input_keys} 52 | query = " ".join(sorted_values(input_variables)) 53 | example_docs = self.vectorstore.similarity_search(query, k=self.k) 54 | #print(example_docs) 55 | # Get the examples from the metadata. 56 | # This assumes that examples are stored in metadata. 57 | examples = [dict(e[0].metadata) for e in example_docs] 58 | 59 | # If example keys are provided, filter examples to those keys. 60 | if self.example_keys: 61 | examples = [{k: eg[k] for k in self.example_keys} for eg in examples] 62 | 63 | return example_docs 64 | 65 | @classmethod 66 | def from_examples( 67 | cls, 68 | examples: List[dict], 69 | embeddings: Embeddings, 70 | vectorstore_cls: Type[VectorStore], 71 | k: int = 4, 72 | input_keys: Optional[List[str]] = None, 73 | **vectorstore_cls_kwargs: Any, 74 | ) -> SemanticSimilarityExampleSelector: 75 | """Create k-shot example selector using example list and embeddings. 76 | 77 | Reshuffles examples dynamically based on query similarity. 78 | 79 | Args: 80 | examples: List of examples to use in the prompt. 81 | embeddings: An initialized embedding API interface, e.g. OpenAIEmbeddings(). 82 | vectorstore_cls: A vector store DB interface class, e.g. FAISS. 83 | k: Number of examples to select 84 | input_keys: If provided, the search is based on the input variables 85 | instead of all variables. 86 | vectorstore_cls_kwargs: optional kwargs containing url for vector store 87 | 88 | Returns: 89 | The ExampleSelector instantiated, backed by a vector store. 90 | """ 91 | if input_keys: 92 | string_examples = [ 93 | " ".join(sorted_values({k: eg[k] for k in input_keys})) 94 | for eg in examples 95 | ] 96 | else: 97 | string_examples = [" ".join(sorted_values(eg)) for eg in examples] 98 | vectorstore = vectorstore_cls.from_texts( 99 | string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs 100 | ) 101 | return cls(vectorstore=vectorstore, k=k, input_keys=input_keys) 102 | 103 | 104 | class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector): 105 | """ExampleSelector that selects examples based on Max Marginal Relevance. 106 | 107 | This was shown to improve performance in this paper: 108 | https://arxiv.org/pdf/2211.13892.pdf 109 | """ 110 | 111 | fetch_k: int = 20 112 | """Number of examples to fetch to rerank.""" 113 | 114 | def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: 115 | """Select which examples to use based on semantic similarity.""" 116 | # Get the docs with the highest similarity. 
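# Max-marginal-relevance search first fetches fetch_k candidates by embedding similarity and then greedily
# re-ranks them to balance similarity to the query against diversity among the selected examples, so the k
# examples returned here are relevant to the query but not near-duplicates of one another.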
117 | if self.input_keys: 118 | input_variables = {key: input_variables[key] for key in self.input_keys} 119 | query = " ".join(sorted_values(input_variables)) 120 | example_docs = self.vectorstore.max_marginal_relevance_search( 121 | query, k=self.k, fetch_k=self.fetch_k 122 | ) 123 | # Get the examples from the metadata. 124 | # This assumes that examples are stored in metadata. 125 | examples = [dict(e.metadata) for e in example_docs] 126 | # If example keys are provided, filter examples to those keys. 127 | if self.example_keys: 128 | examples = [{k: eg[k] for k in self.example_keys} for eg in examples] 129 | return examples 130 | 131 | @classmethod 132 | def from_examples( 133 | cls, 134 | examples: List[dict], 135 | embeddings: Embeddings, 136 | vectorstore_cls: Type[VectorStore], 137 | k: int = 4, 138 | input_keys: Optional[List[str]] = None, 139 | fetch_k: int = 20, 140 | **vectorstore_cls_kwargs: Any, 141 | ) -> MaxMarginalRelevanceExampleSelector: 142 | """Create k-shot example selector using example list and embeddings. 143 | 144 | Reshuffles examples dynamically based on query similarity. 145 | 146 | Args: 147 | examples: List of examples to use in the prompt. 148 | embeddings: An iniialized embedding API interface, e.g. OpenAIEmbeddings(). 149 | vectorstore_cls: A vector store DB interface class, e.g. FAISS. 150 | k: Number of examples to select 151 | input_keys: If provided, the search is based on the input variables 152 | instead of all variables. 153 | vectorstore_cls_kwargs: optional kwargs containing url for vector store 154 | 155 | Returns: 156 | The ExampleSelector instantiated, backed by a vector store. 157 | """ 158 | if input_keys: 159 | string_examples = [ 160 | " ".join(sorted_values({k: eg[k] for k in input_keys})) 161 | for eg in examples 162 | ] 163 | else: 164 | string_examples = [" ".join(sorted_values(eg)) for eg in examples] 165 | vectorstore = vectorstore_cls.from_texts( 166 | string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs 167 | ) 168 | return cls(vectorstore=vectorstore, k=k, fetch_k=fetch_k, input_keys=input_keys) 169 | -------------------------------------------------------------------------------- /Self_Plan/Traj_Syn/benchmark_run/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2023, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: Apache License 2.0 5 | For full license text, see the LICENSE file in the repo root or https://www.apache.org/licenses/LICENSE-2.0 6 | """ 7 | 8 | import os 9 | import joblib 10 | import json 11 | import requests 12 | 13 | def summarize_trial(agents): 14 | correct = [a for a in agents if a.is_correct()] 15 | incorrect = [a for a in agents if a.is_finished() and not a.is_correct()] 16 | not_finish = [a for a in agents if not a.is_finished()] 17 | return correct, incorrect, not_finish 18 | 19 | def remove_fewshot(prompt: str) -> str: 20 | prefix = prompt.split('Here are some examples:')[0] 21 | suffix = prompt.split('(END OF EXAMPLES)')[1] 22 | return prefix.strip('\n').strip() + '\n' + suffix.strip('\n').strip() 23 | 24 | def log_trial(agents, trial_n): 25 | correct, incorrect, not_finish = summarize_trial(agents) 26 | 27 | log = f""" 28 | ######################################## 29 | BEGIN TRIAL {trial_n} 30 | Trial summary: Correct: {len(correct)}, Incorrect: {len(incorrect)} , Not Finished: {len(not_finish)} 31 | ####################################### 32 | """ 33 | 34 | log += '------------- BEGIN CORRECT AGENTS -------------\n\n' 35 | for agent in correct: 36 | # log += remove_fewshot(agent._build_agent_prompt()) + f'\nCorrect answer: {agent.key}\n\n' 37 | log += agent._build_agent_prompt() + f'\nCorrect answer: {agent.key}\n\n' 38 | 39 | log += '------------- BEGIN INCORRECT AGENTS -----------\n\n' 40 | for agent in incorrect: 41 | # log += remove_fewshot(agent._build_agent_prompt()) + f'\nCorrect answer: {agent.key}\n\n' 42 | log += agent._build_agent_prompt() + f'\nCorrect answer: {agent.key}\n\n' 43 | 44 | log += '------------- BEGIN NOT_FINISH AGENTS -----------\n\n' 45 | for agent in not_finish: 46 | log += agent._build_agent_prompt() + f'\nCorrect answer: {agent.key}\n\n' 47 | 48 | return log 49 | 50 | def summarize_trial_detailed(agents): 51 | correct = [a.is_correct() for a in agents] 52 | reward = [a.reward()[0] for a in agents] 53 | halted = [a for a in agents if a.is_halted()] 54 | incorrect = [a for a in agents if a.is_finished() and not a.is_correct()] 55 | error = [a.run_error for a in agents] 56 | return correct, reward, error, halted, incorrect 57 | 58 | def log_agent(agent, file_path): 59 | question = agent.question 60 | g_truth = agent.key 61 | correct = agent.is_correct() 62 | reward = agent.reward()[0] 63 | halted = agent.is_halted() 64 | error = agent.run_error 65 | prompt = agent._build_agent_prompt() 66 | save_dict = {"question":question, "answer":g_truth, "correct":correct, "reward":reward, 67 | "halted":halted, "error":error,"prompt":prompt} 68 | with open(file_path, 'a') as f: 69 | json.dump(save_dict, f) 70 | f.write("\n") 71 | 72 | 73 | def get_all_agent_sessions(file_name): 74 | sessions = [] 75 | with open(file_name) as f: 76 | for line in f: 77 | session = json.loads(line) 78 | sessions.append(session) 79 | return sessions 80 | 81 | def get_error_tasks(sessions): 82 | error_tasks = [] 83 | for sess in sessions: 84 | if sess["error"]: 85 | task = (sess["question"], sess["answer"]) 86 | error_tasks.append(task) 87 | error_tasks = list(set(error_tasks)) 88 | return error_tasks 89 | 90 | def get_non_error_tasks(sessions): 91 | tasks = [] 92 | for sess in sessions: 93 | if not sess["error"]: 94 | task = (sess["question"], sess["answer"]) 95 | tasks.append(task) 96 | tasks = list(set(tasks)) 97 | return tasks 98 | 99 | def delete_error(file_name): 100 | sessions = get_all_agent_sessions(file_name) 101 | non_error_sessions = [sess 
for sess in sessions if not sess["error"]] 102 | with open(file_name+'.back', 'a') as b_f: 103 | for sess in sessions: 104 | json.dump(sess, b_f) 105 | b_f.write('\n') 106 | with open(file_name, 'w') as f: 107 | for sess in non_error_sessions: 108 | json.dump(sess, f) 109 | f.write('\n') 110 | 111 | def summarize_react_trial(agents): 112 | correct = [a for a in agents if a.is_correct()] 113 | halted = [a for a in agents if a.is_halted()] 114 | incorrect = [a for a in agents if a.is_finished() and not a.is_correct()] 115 | return correct, incorrect, halted 116 | 117 | def summarize_react_trial_detailed(agents): 118 | correct = [a.is_correct() for a in agents] 119 | reward = [a.reward()[0] for a in agents] 120 | return correct, reward 121 | 122 | def log_react_trial(agents, trial_n): 123 | correct, incorrect, halted = summarize_react_trial(agents) 124 | 125 | log = f""" 126 | ######################################## 127 | BEGIN TRIAL {trial_n} 128 | Trial summary: Correct: {len(correct)}, Incorrect: {len(incorrect)}, Halted: {len(halted)} 129 | ####################################### 130 | """ 131 | 132 | log += '------------- BEGIN CORRECT AGENTS -------------\n\n' 133 | for agent in correct: 134 | log += remove_fewshot(agent._build_agent_prompt()) + f'\nCorrect answer: {agent.key}\n\n' 135 | 136 | log += '------------- BEGIN INCORRECT AGENTS -----------\n\n' 137 | for agent in incorrect: 138 | log += remove_fewshot(agent._build_agent_prompt()) + f'\nCorrect answer: {agent.key}\n\n' 139 | 140 | log += '------------- BEGIN HALTED AGENTS -----------\n\n' 141 | for agent in halted: 142 | log += remove_fewshot(agent._build_agent_prompt()) + f'\nCorrect answer: {agent.key}\n\n' 143 | 144 | return log 145 | 146 | def save_agents(agents, dir: str): 147 | os.makedirs(dir, exist_ok=True) 148 | for i, agent in enumerate(agents): 149 | agent.enc = None 150 | joblib.dump(agent, os.path.join(dir, f'{i}.joblib')) 151 | 152 | def load_agents(dir:str): 153 | import tiktoken 154 | agents = [] 155 | for f in os.listdir(dir): 156 | agent = joblib.load(os.path.join(dir, f)) 157 | agent.enc = tiktoken.encoding_for_model("text-davinci-003") 158 | agents.append(agent) 159 | return agents 160 | 161 | 162 | 163 | def _validate_server(address): 164 | if not address: 165 | raise ValueError('Must provide a valid server for search') 166 | if address.startswith('http://') or address.startswith('https://'): 167 | return address 168 | PROTOCOL = 'http://' 169 | print(f'No protocol provided, using "{PROTOCOL}"') 170 | return f'{PROTOCOL}{address}' 171 | 172 | 173 | 174 | def call_bing_search(query, count, endpoint="https://api.bing.microsoft.com/v7.0/search", bing_api_key=""): 175 | headers = {'Ocp-Apim-Subscription-Key': bing_api_key} 176 | params = {"q": query, "textDecorations": True, 177 | "textFormat": "HTML", "count": count, "mkt": "en-GB"} 178 | try: 179 | server = _validate_server(endpoint) # server address 180 | server_response = requests.get(server, headers=headers, params=params) 181 | resp_status = server_response.status_code 182 | print(server_response) 183 | if resp_status == 200: 184 | result = server_response.json() 185 | return parse_bing_result(result) 186 | except: 187 | pass 188 | 189 | return None 190 | 191 | def parse_bing_result(result): 192 | responses = [] 193 | try: 194 | value = result["webPages"]["value"] 195 | except: 196 | return responses 197 | 198 | for i in range(len(value)): 199 | snippet = value[i]['snippet'] if 'snippet' in value[i] else "" 200 | snippet = snippet.replace("", 
"").replace("", "").strip() 201 | if snippet != "": 202 | responses.append(snippet) 203 | 204 | return responses -------------------------------------------------------------------------------- /Self_Plan/Group_Planning/benchmark_run/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2023, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: Apache License 2.0 5 | For full license text, see the LICENSE file in the repo root or https://www.apache.org/licenses/LICENSE-2.0 6 | """ 7 | 8 | import os 9 | import joblib 10 | import json 11 | import requests 12 | 13 | def summarize_trial(agents): 14 | correct = [a for a in agents if a.is_correct()] 15 | incorrect = [a for a in agents if a.is_finished() and not a.is_correct()] 16 | not_finish = [a for a in agents if not a.is_finished()] 17 | return correct, incorrect, not_finish 18 | 19 | def remove_fewshot(prompt: str) -> str: 20 | prefix = prompt.split('Here are some examples:')[0] 21 | suffix = prompt.split('(END OF EXAMPLES)')[1] 22 | return prefix.strip('\n').strip() + '\n' + suffix.strip('\n').strip() 23 | 24 | def log_trial(agents, trial_n): 25 | correct, incorrect, not_finish = summarize_trial(agents) 26 | 27 | log = f""" 28 | ######################################## 29 | BEGIN TRIAL {trial_n} 30 | Trial summary: Correct: {len(correct)}, Incorrect: {len(incorrect)} , Not Finished: {len(not_finish)} 31 | ####################################### 32 | """ 33 | 34 | log += '------------- BEGIN CORRECT AGENTS -------------\n\n' 35 | for agent in correct: 36 | # log += remove_fewshot(agent._build_agent_prompt()) + f'\nCorrect answer: {agent.key}\n\n' 37 | log += agent._build_agent_prompt() + f'\nCorrect answer: {agent.key}\n\n' 38 | 39 | log += '------------- BEGIN INCORRECT AGENTS -----------\n\n' 40 | for agent in incorrect: 41 | # log += remove_fewshot(agent._build_agent_prompt()) + f'\nCorrect answer: {agent.key}\n\n' 42 | log += agent._build_agent_prompt() + f'\nCorrect answer: {agent.key}\n\n' 43 | 44 | log += '------------- BEGIN NOT_FINISH AGENTS -----------\n\n' 45 | for agent in not_finish: 46 | log += agent._build_agent_prompt() + f'\nCorrect answer: {agent.key}\n\n' 47 | 48 | return log 49 | 50 | def summarize_trial_detailed(agents): 51 | correct = [a.is_correct() for a in agents] 52 | reward = [a.reward()[0] for a in agents] 53 | halted = [a for a in agents if a.is_halted()] 54 | incorrect = [a for a in agents if a.is_finished() and not a.is_correct()] 55 | error = [a.run_error for a in agents] 56 | return correct, reward, error, halted, incorrect 57 | 58 | def log_agent(agent, file_path): 59 | question = agent.question 60 | g_truth = agent.key 61 | correct = agent.is_correct() 62 | reward = agent.reward()[0] 63 | halted = agent.is_halted() 64 | error = agent.run_error 65 | prompt = agent._build_agent_prompt() 66 | save_dict = {"question":question, "answer":g_truth, "correct":correct, "reward":reward, 67 | "halted":halted, "error":error,"prompt":prompt} 68 | with open(file_path, 'a') as f: 69 | json.dump(save_dict, f) 70 | f.write("\n") 71 | 72 | 73 | def get_all_agent_sessions(file_name): 74 | sessions = [] 75 | with open(file_name) as f: 76 | for line in f: 77 | session = json.loads(line) 78 | sessions.append(session) 79 | return sessions 80 | 81 | def get_error_tasks(sessions): 82 | error_tasks = [] 83 | for sess in sessions: 84 | if sess["error"]: 85 | task = (sess["question"], sess["answer"]) 86 | error_tasks.append(task) 87 | error_tasks = 
list(set(error_tasks)) 88 | return error_tasks 89 | 90 | def get_non_error_tasks(sessions): 91 | tasks = [] 92 | for sess in sessions: 93 | if not sess["error"]: 94 | task = (sess["question"], sess["answer"]) 95 | tasks.append(task) 96 | tasks = list(set(tasks)) 97 | return tasks 98 | 99 | def delete_error(file_name): 100 | sessions = get_all_agent_sessions(file_name) 101 | non_error_sessions = [sess for sess in sessions if not sess["error"]] 102 | with open(file_name+'.back', 'a') as b_f: 103 | for sess in sessions: 104 | json.dump(sess, b_f) 105 | b_f.write('\n') 106 | with open(file_name, 'w') as f: 107 | for sess in non_error_sessions: 108 | json.dump(sess, f) 109 | f.write('\n') 110 | 111 | def summarize_react_trial(agents): 112 | correct = [a for a in agents if a.is_correct()] 113 | halted = [a for a in agents if a.is_halted()] 114 | incorrect = [a for a in agents if a.is_finished() and not a.is_correct()] 115 | return correct, incorrect, halted 116 | 117 | def summarize_react_trial_detailed(agents): 118 | correct = [a.is_correct() for a in agents] 119 | reward = [a.reward()[0] for a in agents] 120 | return correct, reward 121 | 122 | def log_react_trial(agents, trial_n): 123 | correct, incorrect, halted = summarize_react_trial(agents) 124 | 125 | log = f""" 126 | ######################################## 127 | BEGIN TRIAL {trial_n} 128 | Trial summary: Correct: {len(correct)}, Incorrect: {len(incorrect)}, Halted: {len(halted)} 129 | ####################################### 130 | """ 131 | 132 | log += '------------- BEGIN CORRECT AGENTS -------------\n\n' 133 | for agent in correct: 134 | log += remove_fewshot(agent._build_agent_prompt()) + f'\nCorrect answer: {agent.key}\n\n' 135 | 136 | log += '------------- BEGIN INCORRECT AGENTS -----------\n\n' 137 | for agent in incorrect: 138 | log += remove_fewshot(agent._build_agent_prompt()) + f'\nCorrect answer: {agent.key}\n\n' 139 | 140 | log += '------------- BEGIN HALTED AGENTS -----------\n\n' 141 | for agent in halted: 142 | log += remove_fewshot(agent._build_agent_prompt()) + f'\nCorrect answer: {agent.key}\n\n' 143 | 144 | return log 145 | 146 | def save_agents(agents, dir: str): 147 | os.makedirs(dir, exist_ok=True) 148 | for i, agent in enumerate(agents): 149 | agent.enc = None 150 | joblib.dump(agent, os.path.join(dir, f'{i}.joblib')) 151 | 152 | def load_agents(dir:str): 153 | import tiktoken 154 | agents = [] 155 | for f in os.listdir(dir): 156 | agent = joblib.load(os.path.join(dir, f)) 157 | agent.enc = tiktoken.encoding_for_model("text-davinci-003") 158 | agents.append(agent) 159 | return agents 160 | 161 | 162 | 163 | def _validate_server(address): 164 | if not address: 165 | raise ValueError('Must provide a valid server for search') 166 | if address.startswith('http://') or address.startswith('https://'): 167 | return address 168 | PROTOCOL = 'http://' 169 | print(f'No protocol provided, using "{PROTOCOL}"') 170 | return f'{PROTOCOL}{address}' 171 | 172 | def parse_bing_result(result): 173 | responses = [] 174 | try: 175 | value = result["webPages"]["value"] 176 | except: 177 | return responses 178 | 179 | for i in range(len(value)): 180 | snippet = value[i]['snippet'] if 'snippet' in value[i] else "" 181 | snippet = snippet.replace("", "").replace("", "").strip() 182 | if snippet != "": 183 | responses.append(snippet) 184 | 185 | return responses 186 | 187 | def call_bing_search(query, count, endpoint="https://api.bing.microsoft.com/v7.0/search", bing_api_key=""): 188 | headers = {'Ocp-Apim-Subscription-Key': 
bing_api_key} 189 | params = {"q": query, "textDecorations": True, 190 | "textFormat": "HTML", "count": count, "mkt": "en-GB"} 191 | # print() 192 | try: 193 | server = _validate_server(endpoint) # server address 194 | server_response = requests.get(server, headers=headers, params=params) 195 | resp_status = server_response.status_code 196 | print(server_response) 197 | if resp_status == 200: 198 | result = server_response.json() 199 | return parse_bing_result(result) 200 | except: 201 | pass 202 | 203 | return None 204 | 205 | -------------------------------------------------------------------------------- /Scripts/filter_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import jsonlines 3 | import copy 4 | import random 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser(description='Parsing the file_path to filter and to save the data.') 8 | parser.add_argument("--source_path", type=str, help="source data path") 9 | parser.add_argument("--save_path", type=str, help="path to save data") 10 | parser.add_argument("--task_name", type=str, help="task name") 11 | parser.add_argument("--filter_num", type=int, help="filter num") 12 | args = parser.parse_args() 13 | systemprompt_hotpotqa = """I want you to be a good multi-hop question answerer ,solving a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be five types : 14 | (1) BingSearch[query], which search the exact detailed query on the Internet and returns the relevant information to the query. Be specific and precise with your query to increase the chances of getting relevant results. For example, instead of searching for "dogs," you can search for "popular dog breeds in the United States."For example, BingSearch[Which type of computer networking technology, developed in the 1970s, allows devices to communicate over a shared network] 15 | (2) Retrieve[entity], which retrieve the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to retrieve. For example, Retrieve[Milhouse] 16 | (3) Lookup[keyword], which returns the next sentence containing keyword in the last passage successfully found by Retrieve or BingSearch. For example, Lookup[river] 17 | (4) Finish[answer], which returns a definite answer. For example, Finish[Richard Nixon] (If it is a judgement question, please Finish[yes] or Finish[no]) 18 | (5) Reflect[right/wrong], which reflects the answer right or wrong based on the context history. For example, Reflect[right] 19 | Note that Reflect must be the next Action after Finish. You may take as many steps as necessary.""" 20 | 21 | systemprompt_scienceqa = """I want you to be a good multimodal multiple-choice science questions answerer. Select a correct option to a multi-choice multi-modal question with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be five types: 22 | (1) Image2Text[image], which generates captions for the image and detects words in the image.You are recommand to use it first to get more information about the imgage to the question. If the questions contains image, it will return catption and ocr text, else, it will return None. For example, ImageCaptioner[image] 23 | (2) BingSearch[question], which searches the exact detailed question on the Internet and returns the relevant information to the query. 
Be specific and precise with your query to increase the chances of getting relevant results. For example, instead of searching for "dogs," you can search for "popular dog breeds in the United States." For example, BingSearch[Which type of computer networking technology, developed in the 1970s, allows devices to communicate over a shared network] 24 | (3) Retrieve[entity], which retrieves the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to retrieve. For example, Retrieve[Milhouse] 25 | (4) Finish[option], which returns the answer option and finishes the task. For example, Finish[A] 26 | (5) Reflect[right/wrong], which reflects the answer right or wrong based on the context history. For example, Reflect[right] 27 | Note that to determine the answer, it's needed to consider both the Question and the available Options. 28 | Note that Reflect must be the next Action after Finish. 29 | BingSearch and Retrieve can be used multi-times.""" 30 | 31 | def prompt_retrive(prompt:str)->dict: 32 | action=[] 33 | thought=[] 34 | obversation=[] 35 | for line in prompt.split('\n'): 36 | if line.startswith("Action"): 37 | action.append(line[line.find(':')+1:].strip()) 38 | elif line.startswith("Thought"): 39 | thought.append(line[line.find(':')+1:].strip()) 40 | elif line.startswith("Observation"): 41 | obversation.append(line[line.find(':')+1:].strip()) 42 | return {"actions":action,"thoughts":thought,"observations":obversation} 43 | 44 | train_folder=args.save_path 45 | if train_folder[-1]!='/': 46 | train_folder+='/' 47 | f_plan = open(f"{train_folder}data_plan.json","w") 48 | f_action=open(f"{train_folder}data_action.json","w") 49 | f_reflect=open(f"{train_folder}data_reflect.json","w") 50 | 51 | 52 | with open(args.source_path,"r") as f: 53 | data_plan_action=[] 54 | data_plan_thought=[] 55 | data_action=[] 56 | data_reflect_thought=[] 57 | data_reflect_bool=[] 58 | lines = f.readlines() 59 | random.shuffle(lines) 60 | num = 0 61 | for item in lines: 62 | if num == args.filter_num: 63 | break 64 | item = json.loads(item) 65 | #只读取正确的数据 66 | if item['correct'] == False: 67 | continue 68 | else: 69 | num += 1 70 | question=item['question'] 71 | prompt=item["prompt"] 72 | data=prompt_retrive(prompt) 73 | 74 | #plan_data_thought and action 75 | systemprompt = systemprompt_hotpotqa if args.task_name == "HotpotQA" else systemprompt_scienceqa 76 | plan_data={"input":systemprompt+f"\nQuestion:{question}\nThought: ","output":""} 77 | for index,(a,t,o) in enumerate(zip(data["actions"],data["thoughts"],data["observations"])): 78 | if len(a)==0 or len(t)==0 or len(o)==0: 79 | continue 80 | plan_data["output"]=t 81 | if len(t) >0: 82 | data_plan_thought.append(copy.copy(plan_data)) 83 | plan_data["input"]+=t+"\n"+f"Action: " 84 | if '[' in a and ']' in a: 85 | action_type=a[:a.find('[')] 86 | keyword=a[a.find('[')+1:a.find(']')] 87 | if action_type=="Reflect": 88 | break 89 | plan_data["output"]=action_type 90 | if len(action_type) : 91 | data_plan_action.append(copy.copy(plan_data)) 92 | plan_data["input"]+=a+"\n" 93 | plan_data["input"]+=f"Obversation: "+o+"\n"+"Thought: " 94 | action_data={"input":systemprompt+f"\nQuestion:{question}\n","output":""} 95 | for index,(a,t,o) in enumerate(zip(data["actions"],data["thoughts"],data["observations"])): 96 | if len(a)==0 or len(t)==0 or len(o)==0: 97 | continue 98 | if '[' in a and ']' in a: 99 | action_type=a[:a.find('[')] 100 | keyword=a[a.find('[')+1:a.find(']')] 101 | if action_type == 'Reflect' : 
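                    # Reflect steps are routed to the reflect-agent training set
                    # (the reflection thought plus Reflect[right/wrong]); the
                    # bracketed tool calls handled below become tool-agent data,
                    # with the tool parameters as the training target.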
102 | action_data["input"]+=f"Thought: " 103 | action_data["output"]=t 104 | if len(t)>0: 105 | data_reflect_thought.append(copy.copy(action_data)) 106 | action_data["input"]+=t + '\n' 107 | action_data["output"]=f"Reflect[{keyword}]" 108 | data_reflect_bool.append(copy.copy(action_data)) 109 | action_data["input"]+=f"Action: "+a+"\n" 110 | action_data["input"]+=f"Obversation: "+o+"\n" 111 | else : 112 | action_data["input"]+=f"Thought: "+t+"\n" 113 | action_data["input"]+=f"Aciton: "+action_type 114 | action_data["output"]=keyword 115 | if len(keyword)>0: 116 | data_action.append(copy.copy(action_data)) 117 | action_data["input"]+=f"[{keyword}]"+"\n" 118 | action_data["input"]+=f"Obversation: "+o+"\n" 119 | else: 120 | action_data["input"]+=f"Thought: "+t+"\n" 121 | action_data["input"]+=f"Action: "+a+"\n" 122 | action_data["input"]+=f"Obversation: "+o+"\n" 123 | 124 | print(num) 125 | data_plan = data_plan_action+data_plan_thought 126 | random.shuffle(data_plan) 127 | print(len(data_plan)) 128 | json.dump(data_plan, f_plan, ensure_ascii=False) 129 | 130 | print(len(data_action)) 131 | json.dump(data_action,f_action,ensure_ascii=False) 132 | 133 | data_reflect = data_reflect_bool+data_reflect_thought 134 | random.shuffle(data_reflect) 135 | print(len(data_reflect)) 136 | json.dump(data_reflect,f_reflect,ensure_ascii=False) 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | -------------------------------------------------------------------------------- /Self_Plan/Traj_Syn/benchmark_run/wrappers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2023, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: Apache License 2.0 5 | For full license text, see the LICENSE file in the repo root or https://www.apache.org/licenses/LICENSE-2.0 6 | """ 7 | 8 | import json 9 | import os 10 | import gym 11 | import numpy as np 12 | import re 13 | import string 14 | from collections import Counter 15 | 16 | 17 | DATA_DIR = "data" 18 | HOTPOTQA_SPLIT_FILE = { 19 | "train": "hotpot_train_v1.1_simplified.json", 20 | "dev": "hotpot_dev_v1_simplified.json", 21 | "test": "hotpot_test_v1_simplified.json", 22 | } 23 | 24 | FEVER_SPLIT_FILE = { 25 | "train": "train.jsonl", 26 | "dev": "paper_dev.jsonl", 27 | } 28 | 29 | 30 | class HistoryWrapper(gym.ObservationWrapper): 31 | def __init__(self, env, obs_format, prompt=None): 32 | super().__init__(env) 33 | assert obs_format in ["obs", "history"] 34 | if obs_format == "history": 35 | assert hasattr(self.env, "traj") 36 | self.obs_format = obs_format 37 | self.prompt = prompt if prompt is not None else "" 38 | 39 | def observation(self, obs): 40 | if self.obs_format == "obs": 41 | return obs 42 | elif self.obs_format == "history": 43 | observation = self.env.traj["observations"][0] + "\n" 44 | for i, (o, a) in enumerate(zip(self.env.traj["observations"][1:], self.env.traj["actions"]), 1): 45 | observation += f"Action {i}: {a}\nObservation {i}: {o}\n\n" 46 | return self.prompt + observation 47 | 48 | 49 | def normalize_answer(s): 50 | def remove_articles(text): 51 | return re.sub(r"\b(a|an|the)\b", " ", text) 52 | 53 | def white_space_fix(text): 54 | return " ".join(text.split()) 55 | 56 | def remove_punc(text): 57 | exclude = set(string.punctuation) 58 | return "".join(ch for ch in text if ch not in exclude) 59 | 60 | def lower(text): 61 | return text.lower() 62 | 63 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 64 | 65 | def f1_score(prediction, ground_truth): 66 | 
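    # Token-level F1 between the normalized prediction and ground truth
    # (standard HotpotQA-style scoring); returns (f1, precision, recall),
    # or all zeros when a yes/no/noanswer prediction does not match exactly.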
normalized_prediction = normalize_answer(prediction) 67 | normalized_ground_truth = normalize_answer(ground_truth) 68 | 69 | ZERO_METRIC = (0, 0, 0) 70 | 71 | if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: 72 | return ZERO_METRIC 73 | if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: 74 | return ZERO_METRIC 75 | 76 | prediction_tokens = normalized_prediction.split() 77 | ground_truth_tokens = normalized_ground_truth.split() 78 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 79 | num_same = sum(common.values()) 80 | if num_same == 0: 81 | return ZERO_METRIC 82 | precision = 1.0 * num_same / len(prediction_tokens) 83 | recall = 1.0 * num_same / len(ground_truth_tokens) 84 | f1 = (2 * precision * recall) / (precision + recall) 85 | return f1, precision, recall 86 | 87 | class HotPotQAWrapper(gym.Wrapper): 88 | def __init__(self, env, split): 89 | super().__init__(env) 90 | data_file = f"{DATA_DIR}/{HOTPOTQA_SPLIT_FILE[split]}" 91 | self.data = json.load(open(data_file)) 92 | self.data = [(d['question'], d['answer']) for d in self.data] 93 | self.data_idx = 0 94 | self.split = split 95 | 96 | def reset(self, seed=None, return_info=False, options=None, idx=None): 97 | self.env.reset(seed=seed, return_info=return_info, options=options) 98 | try: 99 | self.env.step('') 100 | except: 101 | pass 102 | self.env.reset(seed=seed, return_info=return_info, options=options) 103 | self.data_idx = int(np.random.randint(len(self.data))) if idx is None else idx 104 | observation = f"Question: {self.data[self.data_idx][0]}" 105 | info = self._get_info() 106 | return (observation, info) if return_info else observation 107 | 108 | def _get_info(self): 109 | return { 110 | "steps": self.steps, 111 | "answer": self.answer, 112 | "question": self.data[self.data_idx][0], 113 | "hotpot_split": self.split 114 | } 115 | 116 | def get_reward(self, info): 117 | if info['answer'] is not None: 118 | pred = normalize_answer(self.data[self.data_idx][1]) 119 | gt = normalize_answer(info['answer']) 120 | score = (pred == gt) 121 | return int(score) 122 | return 0 123 | 124 | def get_metrics(self, info): 125 | if info['answer'] is not None: 126 | pred = normalize_answer(self.data[self.data_idx][1]) 127 | gt = normalize_answer(info['answer']) 128 | em = (pred == gt) 129 | f1 = f1_score(pred, gt)[0] 130 | return {'reward': em, 'em': em, 'f1': f1} 131 | return {'reward': 0, 'em': 0, 'f1': 0} 132 | 133 | def step(self, action): 134 | # TODO: first step obs does not have question. 
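        # Reward is exact match between the normalized predicted answer and the
        # gold answer; when the episode finishes, EM/F1 metrics, the gold answer
        # and the question index are merged into `info`.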
135 | obs, _, done, info = self.env.step(action) 136 | reward = self.get_reward(info) 137 | if done: 138 | obs = f"Episode finished, reward = {reward}\n" 139 | info.update({"gt_answer": self.data[self.data_idx][1], "question_idx": self.data_idx}) 140 | info.update(self.get_metrics(info)) 141 | return obs, reward, done, info 142 | 143 | def __len__(self): 144 | return len(self.data) 145 | 146 | class FeverWrapper(gym.Wrapper): 147 | def __init__(self, env, split): 148 | super().__init__(env) 149 | 150 | data_path = f"./data/{FEVER_SPLIT_FILE[split]}" 151 | with open(data_path, "r") as json_file: 152 | json_list = list(json_file) 153 | 154 | data = [] 155 | for json_str in json_list: 156 | json_str = json.loads(json_str) 157 | label = json_str["label"] 158 | claim = json_str["claim"] 159 | data.append((claim, label)) 160 | 161 | self.data = data 162 | self.data_idx = 0 163 | self.split = split 164 | 165 | def reset(self, seed=None, return_info=False, options=None, idx=None): 166 | self.env.reset(seed=seed, return_info=return_info, options=options) 167 | try: 168 | self.env.step('') 169 | except: 170 | pass 171 | self.env.reset(seed=seed, return_info=return_info, options=options) 172 | self.data_idx = int(np.random.randint(len(self.data))) if idx is None else idx 173 | observation = f"Claim: {self.data[self.data_idx][0]}" 174 | info = self._get_info() 175 | return (observation, info) if return_info else observation 176 | 177 | def _get_info(self): 178 | return { 179 | "steps": self.steps, 180 | "answer": self.answer, 181 | "question": self.data[self.data_idx][0], 182 | "fever_split": self.split 183 | } 184 | 185 | def get_reward(self, info): 186 | if info['answer'] is not None: 187 | label = normalize_answer(self.data[self.data_idx][1]) 188 | pred = normalize_answer(info['answer']) 189 | if label == pred: 190 | return 1 191 | return 0 192 | 193 | def step(self, action): 194 | # TODO: first step obs does not have question. 
195 | obs, _, done, info = self.env.step(action) 196 | reward = self.get_reward(info) 197 | if done: 198 | obs = f"Episode finished, reward = {reward}\n" 199 | info.update({"gt_answer": self.data[self.data_idx][1], "question_idx": self.data_idx}) 200 | info.update({'em': reward, 'reward': reward, 'f1': reward}) 201 | return obs, reward, done, info 202 | 203 | def __len__(self): 204 | return len(self.data) 205 | 206 | 207 | class LoggingWrapper(gym.Wrapper): 208 | def __init__(self, env, folder="trajs", file_id=None): 209 | super().__init__(env) 210 | self.trajs = [] 211 | self.traj = {"observations": [], "actions": []} 212 | self.folder = folder 213 | self.file_id = np.random.randint(0, 10000000) if file_id is None else file_id 214 | self.file_path = f"{self.folder}/{self.file_id}.json" 215 | os.makedirs("trajs", exist_ok=True) 216 | 217 | def __len__(self): 218 | return len(self.env.data) 219 | 220 | 221 | def reset(self, seed=None, return_info=False, options=None, idx=None): 222 | output = self.env.reset(seed=seed, return_info=return_info, options=options, idx=idx) 223 | observation = output[0] if return_info else output 224 | self.traj = {"observations": [observation], "actions": []} 225 | return output 226 | 227 | def step(self, action): 228 | obs, reward, done, info = self.env.step(action) 229 | self.traj["observations"].append(obs) 230 | self.traj["actions"].append(action) 231 | if done: 232 | self.traj.update(info) 233 | return obs, reward, done, info 234 | 235 | def update_record(self): 236 | if len(self.traj) > 0: 237 | self.trajs.append(self.traj) 238 | self.traj = {"observations": [], "actions": []} 239 | 240 | def write(self): 241 | self.update_record() 242 | with open(self.file_path, "w") as f: 243 | json.dump(self.trajs, f) 244 | print(f"Saved trajs to trajs/{self.file_id}.json") 245 | 246 | def close(self): 247 | self.write() -------------------------------------------------------------------------------- /Self_Plan/Group_Planning/benchmark_run/wrappers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2023, salesforce.com, inc. 3 | All rights reserved. 
4 | SPDX-License-Identifier: Apache License 2.0 5 | For full license text, see the LICENSE file in the repo root or https://www.apache.org/licenses/LICENSE-2.0 6 | """ 7 | 8 | import json 9 | import os 10 | import gym 11 | import numpy as np 12 | import re 13 | import string 14 | from collections import Counter 15 | 16 | 17 | DATA_DIR = "data" 18 | HOTPOTQA_SPLIT_FILE = { 19 | "train": "hotpot_train_v1.1_simplified.json", 20 | "dev": "hotpot_dev_v1_simplified.json", 21 | "test": "hotpot_test_v1_simplified.json", 22 | } 23 | 24 | FEVER_SPLIT_FILE = { 25 | "train": "train.jsonl", 26 | "dev": "paper_dev.jsonl", 27 | } 28 | 29 | 30 | class HistoryWrapper(gym.ObservationWrapper): 31 | def __init__(self, env, obs_format, prompt=None): 32 | super().__init__(env) 33 | assert obs_format in ["obs", "history"] 34 | if obs_format == "history": 35 | assert hasattr(self.env, "traj") 36 | self.obs_format = obs_format 37 | self.prompt = prompt if prompt is not None else "" 38 | 39 | def observation(self, obs): 40 | if self.obs_format == "obs": 41 | return obs 42 | elif self.obs_format == "history": 43 | observation = self.env.traj["observations"][0] + "\n" 44 | for i, (o, a) in enumerate(zip(self.env.traj["observations"][1:], self.env.traj["actions"]), 1): 45 | observation += f"Action {i}: {a}\nObservation {i}: {o}\n\n" 46 | return self.prompt + observation 47 | 48 | 49 | def normalize_answer(s): 50 | def remove_articles(text): 51 | return re.sub(r"\b(a|an|the)\b", " ", text) 52 | 53 | def white_space_fix(text): 54 | return " ".join(text.split()) 55 | 56 | def remove_punc(text): 57 | exclude = set(string.punctuation) 58 | return "".join(ch for ch in text if ch not in exclude) 59 | 60 | def lower(text): 61 | return text.lower() 62 | 63 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 64 | 65 | def f1_score(prediction, ground_truth): 66 | normalized_prediction = normalize_answer(prediction) 67 | normalized_ground_truth = normalize_answer(ground_truth) 68 | 69 | ZERO_METRIC = (0, 0, 0) 70 | 71 | if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: 72 | return ZERO_METRIC 73 | if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: 74 | return ZERO_METRIC 75 | 76 | prediction_tokens = normalized_prediction.split() 77 | ground_truth_tokens = normalized_ground_truth.split() 78 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 79 | num_same = sum(common.values()) 80 | if num_same == 0: 81 | return ZERO_METRIC 82 | precision = 1.0 * num_same / len(prediction_tokens) 83 | recall = 1.0 * num_same / len(ground_truth_tokens) 84 | f1 = (2 * precision * recall) / (precision + recall) 85 | return f1, precision, recall 86 | 87 | class HotPotQAWrapper(gym.Wrapper): 88 | def __init__(self, env, split): 89 | super().__init__(env) 90 | data_file = f"{DATA_DIR}/{HOTPOTQA_SPLIT_FILE[split]}" 91 | self.data = json.load(open(data_file)) 92 | self.data = [(d['question'], d['answer']) for d in self.data] 93 | self.data_idx = 0 94 | self.split = split 95 | 96 | def reset(self, seed=None, return_info=False, options=None, idx=None): 97 | self.env.reset(seed=seed, return_info=return_info, options=options) 98 | try: 99 | self.env.step('') 100 | except: 101 | pass 102 | self.env.reset(seed=seed, return_info=return_info, options=options) 103 | self.data_idx = int(np.random.randint(len(self.data))) if idx is None else idx 104 | observation = f"Question: {self.data[self.data_idx][0]}" 105 | info = 
self._get_info() 106 | return (observation, info) if return_info else observation 107 | 108 | def _get_info(self): 109 | return { 110 | "steps": self.steps, 111 | "answer": self.answer, 112 | "question": self.data[self.data_idx][0], 113 | "hotpot_split": self.split 114 | } 115 | 116 | def get_reward(self, info): 117 | if info['answer'] is not None: 118 | pred = normalize_answer(self.data[self.data_idx][1]) 119 | gt = normalize_answer(info['answer']) 120 | score = (pred == gt) 121 | return int(score) 122 | return 0 123 | 124 | def get_metrics(self, info): 125 | if info['answer'] is not None: 126 | pred = normalize_answer(self.data[self.data_idx][1]) 127 | gt = normalize_answer(info['answer']) 128 | em = (pred == gt) 129 | f1 = f1_score(pred, gt)[0] 130 | return {'reward': em, 'em': em, 'f1': f1} 131 | return {'reward': 0, 'em': 0, 'f1': 0} 132 | 133 | def step(self, action): 134 | # TODO: first step obs does not have question. 135 | obs, _, done, info = self.env.step(action) 136 | reward = self.get_reward(info) 137 | if done: 138 | obs = f"Episode finished, reward = {reward}\n" 139 | info.update({"gt_answer": self.data[self.data_idx][1], "question_idx": self.data_idx}) 140 | info.update(self.get_metrics(info)) 141 | return obs, reward, done, info 142 | 143 | def __len__(self): 144 | return len(self.data) 145 | 146 | class FeverWrapper(gym.Wrapper): 147 | def __init__(self, env, split): 148 | super().__init__(env) 149 | 150 | data_path = f"./data/{FEVER_SPLIT_FILE[split]}" 151 | with open(data_path, "r") as json_file: 152 | json_list = list(json_file) 153 | 154 | data = [] 155 | for json_str in json_list: 156 | json_str = json.loads(json_str) 157 | label = json_str["label"] 158 | claim = json_str["claim"] 159 | data.append((claim, label)) 160 | 161 | self.data = data 162 | self.data_idx = 0 163 | self.split = split 164 | 165 | def reset(self, seed=None, return_info=False, options=None, idx=None): 166 | self.env.reset(seed=seed, return_info=return_info, options=options) 167 | try: 168 | self.env.step('') 169 | except: 170 | pass 171 | self.env.reset(seed=seed, return_info=return_info, options=options) 172 | self.data_idx = int(np.random.randint(len(self.data))) if idx is None else idx 173 | observation = f"Claim: {self.data[self.data_idx][0]}" 174 | info = self._get_info() 175 | return (observation, info) if return_info else observation 176 | 177 | def _get_info(self): 178 | return { 179 | "steps": self.steps, 180 | "answer": self.answer, 181 | "question": self.data[self.data_idx][0], 182 | "fever_split": self.split 183 | } 184 | 185 | def get_reward(self, info): 186 | if info['answer'] is not None: 187 | label = normalize_answer(self.data[self.data_idx][1]) 188 | pred = normalize_answer(info['answer']) 189 | if label == pred: 190 | return 1 191 | return 0 192 | 193 | def step(self, action): 194 | # TODO: first step obs does not have question. 
195 | obs, _, done, info = self.env.step(action) 196 | reward = self.get_reward(info) 197 | if done: 198 | obs = f"Episode finished, reward = {reward}\n" 199 | info.update({"gt_answer": self.data[self.data_idx][1], "question_idx": self.data_idx}) 200 | info.update({'em': reward, 'reward': reward, 'f1': reward}) 201 | return obs, reward, done, info 202 | 203 | def __len__(self): 204 | return len(self.data) 205 | 206 | 207 | class LoggingWrapper(gym.Wrapper): 208 | def __init__(self, env, folder="trajs", file_id=None): 209 | super().__init__(env) 210 | self.trajs = [] 211 | self.traj = {"observations": [], "actions": []} 212 | self.folder = folder 213 | self.file_id = np.random.randint(0, 10000000) if file_id is None else file_id 214 | self.file_path = f"{self.folder}/{self.file_id}.json" 215 | os.makedirs("trajs", exist_ok=True) 216 | 217 | def __len__(self): 218 | return len(self.env.data) 219 | 220 | 221 | def reset(self, seed=None, return_info=False, options=None, idx=None): 222 | output = self.env.reset(seed=seed, return_info=return_info, options=options, idx=idx) 223 | observation = output[0] if return_info else output 224 | self.traj = {"observations": [observation], "actions": []} 225 | return output 226 | 227 | def step(self, action): 228 | obs, reward, done, info = self.env.step(action) 229 | self.traj["observations"].append(obs) 230 | self.traj["actions"].append(action) 231 | if done: 232 | self.traj.update(info) 233 | return obs, reward, done, info 234 | 235 | def update_record(self): 236 | if len(self.traj) > 0: 237 | self.trajs.append(self.traj) 238 | self.traj = {"observations": [], "actions": []} 239 | 240 | def write(self): 241 | self.update_record() 242 | with open(self.file_path, "w") as f: 243 | json.dump(self.trajs, f) 244 | print(f"Saved trajs to trajs/{self.file_id}.json") 245 | 246 | def close(self): 247 | self.write() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

AutoAct
2 | Automatic Agent Learning from Scratch for QA via Self-Planning
3 |
4 |
5 | 📄arXiv •
6 | 🤗HFPaper •
7 | 🌐Web
8 |

9 | 10 | [![Awesome](https://awesome.re/badge.svg)](https://github.com/zjunlp/AutoAct) 11 | [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT) 12 | ![](https://img.shields.io/github/last-commit/zjunlp/AutoAct?color=green) 13 | 14 | ## Table of Contents 15 | 16 | - 🌻[Acknowledgement](#🌻acknowledgement) 17 | - 🌟[Overview](#🌟overview) 18 | - 🔧[Installation](#🔧installation) 19 | - ✏️[Self-Instruct](#✏️Self-Instruct) 20 | - 📝[Self-Planning](#📝Self-Planning) 21 | - [Automatic Tool Selection](#Automatic-Tool-Selection) 22 | - [Trajectories Synthesis](#Trajectories-Synthesis) 23 | - [Self-Differentiation](#Self-Differentiation) 24 | - [Group Planning](#Group-Planning) 25 | - 🚩[Citation](#🚩Citation) 26 | 27 | --- 28 | 29 | 30 | 31 | ## 🌻Acknowledgement 32 | 33 | Our code of training module is referenced and adapted from [FastChat](https://github.com/lm-sys/FastChat), while the code of inference module is implemented based on [BOLAA](https://github.com/salesforce/BOLAA). Various baseline codes use [ReAct](https://github.com/ysymyth/ReAct), [Reflexion](https://github.com/noahshinn/reflexion), [BOLAA](https://github.com/salesforce/BOLAA), [Chameleon](https://github.com/lupantech/chameleon-llm), [ReWOO](https://github.com/billxbf/ReWOO), [FireAct](https://github.com/anchen1011/FireAct) respectively. We use LangChain with open models via [Fastchat](https://github.com/lm-sys/FastChat/blob/main/docs/langchain_integration.md). Thanks for their great contributions! 34 | 35 | 36 | 37 | ## 🌟Overview 38 | 39 | Language agents have achieved considerable performance on various complex tasks. Despite the incessant exploration in this field, existing language agent systems still struggle with costly, non-reproducible data reliance and face the challenge of compelling a single model for multiple functions. To this end, we introduce **AutoAct**, an automatic agent learning framework that does not rely on large-scale annotated data and synthetic trajectories from closed-source models (e.g., GPT-4). Given limited data with a tool library, **AutoAct** first automatically synthesizes planning trajectories without any assistance from humans or strong closed-source models. Then, **AutoAct** leverages a *division-of-labor* strategy to automatically differentiate based on the target task information and synthesized trajectories, producing a sub-agent group to complete the task. We conduct comprehensive experiments with different LLMs, which demonstrates that **AutoAct** yields better or parallel performance compared to various strong baselines. 40 | 41 | method 42 | 43 | 44 | 45 | ## 🔧Installation 46 | 47 | ```bash 48 | git clone https://github.com/zjunlp/AutoAct 49 | cd AutoAct 50 | pip install -r requirements.txt 51 | ``` 52 | 53 | Before the experiments, you need to apply for a Bing Search key [here](https://www.microsoft.com/en-us/bing/apis/bing-web-search-api) (not free). 54 | 55 | ## ✏️Self-Instruct 56 | 57 | We conduct self-instruct on Meta-Agent to acquire a sufficient amount of task data and provide an ample training resource. 58 | 59 | ```bash 60 | python Self_Instruct/data_generation.py \ 61 | --source_data Self_Instruct/Meta_sample/Meta_Hotpotqa.json \ 62 | --target_data Self_Instruct/hotpotqa_metaqa.json \ 63 | --dataset_name hotpotqa \ 64 | --generate_all_num 800 \ 65 | --generate_per_round_num 10 \ 66 | --model_name llama-2-13b-chat \ 67 | ``` 68 | 69 | The `source_data` contains data examples from the target task information. 
The `target_data` consists of data generated through self-instruct. The variable `generate_all_num` represents the total number of generated data instances. In order to improve generation efficiency and avoid duplication, we generate `generate_per_round_num` data instances per round. 70 | 71 | 72 | 73 | ## 📝Self-Planning 74 | 75 | ### Automatic Tool Selection 76 | 77 | With the tool library at hand, we ask the Meta-Agent to select applicable tools for each task automatically. 78 | 79 | ```bash 80 | python Self_Planning/Tool_Selection/tool_selected.py \ 81 | --model_name llama-2-13b-chat \ 82 | --task_name ScienceQA \ 83 | --top_k 40 \ 84 | --top_p 0.75 \ 85 | --max_tokens 1024 \ 86 | --tool_save_path Self_Planning/Tool_Selection/{task_name}_Tools.json 87 | ``` 88 | 89 | The information of the selected tools will be stored in `tool_save_path`. 90 | 91 | 92 | 93 | ### Trajectories Synthesis 94 | 95 | ```bash 96 | python Self_Plan/Traj_Syn/run_task.py \ 97 | --agent_name ZeroshotThink_HotPotQA_run_Agent \ 98 | --llm_name llama-2-13b-chat \ 99 | --max_context_len 4096 \ 100 | --task Hotpotqa \ 101 | --task_path Self_Instruct/hotpotqa_metaqa.json \ 102 | --save_path Self_Plan/Traj_Syn/output/hotpotqa_train_data.jsonl 103 | ``` 104 | 105 | In order to obtain high-quality synthesized trajectories, we filter out all the trajectories with $\texttt{reward}<1$ and collect trajectories with exactly correct answers ($\texttt{reward}=1$) as the training source for self-differentiation. We release the trajectories synthesized by Llama-{13,70}b-chat after filtering in [Google Drive](https://drive.google.com/drive/folders/1Sh6Ksj8T0fT23ePWRf_dDcOTmpZlulr2?usp=sharing) (but you should also run `filter_data.py` for trajectory differentiation). 106 | 107 | ```bash 108 | python Scripts/filter_data.py \ 109 | --source_path Self_Plan/Traj_Syn/output/hotpotqa_train_data.jsonl \ 110 | --save_path Self_Plan/Traj_Syn/output \ 111 | --task_name HotpotQA \ 112 | --filter_num 200 113 | ``` 114 | 115 | 116 | 117 | ### Self-Differentiation 118 | 119 | In order to establish a clear *division-of-labor*, we leverage synthesized planning trajectories to differentiate the Meta-Agent into three sub-agents with distinct functionalities: 120 | 121 | - **Plan-Agent** undertakes task decomposition and determines which tool to invoke in each planning loop. 122 | - **Tool-Agent** is responsible for how to invoke the tool by deciding the parameters for the tool invocation. 123 | - **Reflect-Agent** engages in reflection by considering all the historical trajectories and providing a reflection result. 124 | 125 | Agent training: 126 | 127 | ```bash 128 | for agent in plan tool reflect 129 | do 130 | echo "####################" 131 | echo $agent 132 | echo "####################" 133 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 deepspeed Self_Plan/Train/train_lora.py \ 134 | --model_name_or_path llama-2-13b-chat \ 135 | --lora_r 8 \ 136 | --lora_alpha 16 \ 137 | --lora_dropout 0.05 \ 138 | --data_path Self_Plan/Traj_Syn/output/data_$agent.json \ 139 | --output_dir Self_Plan/Train/lora/HotpotQA/13b-$agent-5-epoch \ 140 | --num_train_epochs 5 \ 141 | --per_device_train_batch_size 2 \ 142 | --per_device_eval_batch_size 1 \ 143 | --gradient_accumulation_steps 1 \ 144 | --evaluation_strategy "no" \ 145 | --save_strategy "steps" \ 146 | --save_steps 10000 \ 147 | --save_total_limit 1 \ 148 | --learning_rate 1e-4 \ 149 | --weight_decay 0. 
\ 150 | --warmup_ratio 0.03 \ 151 | --lr_scheduler_type "cosine" \ 152 | --logging_steps 1 \ 153 | --fp16 True \ 154 | --model_max_length 4096 \ 155 | --gradient_checkpointing True \ 156 | --q_lora False \ 157 | --deepspeed Self_Plan/Train/deepspeed_config_s3.json \ 158 | --resume_from_checkpoint False 159 | done 160 | ``` 161 | 162 | 163 | 164 | ### Group Planning 165 | 166 | After obtaining the task-specific sub-agents, any new question is processed through group planning among the sub-agents to achieve the desired outcome. 167 | 168 | ```bash 169 | python Self_Planning/Group_Planning/run_eval.py \ 170 | --agent_name ZeroshotThink_HotPotQA_run_Agent \ 171 | --plan_agent plan \ 172 | --tool_agent tool \ 173 | --reflect_agent reflect \ 174 | --max_context_len 4096 \ 175 | --task HotpotQA \ 176 | --task_path Self_Planning/Group_Planning/benchmark_run/data/hotpotqa \ 177 | --save_path Self_Planning/Group_Planning/output/13b 178 | ``` 179 | 180 | We release the trajectories of text sets generated by Llama-{7,13,70}b-chat in [Google Drive](https://drive.google.com/drive/folders/1Sh6Ksj8T0fT23ePWRf_dDcOTmpZlulr2?usp=sharing). 181 | 182 | The prompts used in our experiments are in directory [Prompts]https://github.com/zjunlp/AutoAct/tree/main/Prompts. 183 | 184 | ## 🚩Citation 185 | 186 | Please cite our repository if you use AutoAct in your work. Thanks! 187 | 188 | ```bibtex 189 | @article{DBLP:journals/corr/abs-2401-05268, 190 | author = {Shuofei Qiao and 191 | Ningyu Zhang and 192 | Runnan Fang and 193 | Yujie Luo and 194 | Wangchunshu Zhou and 195 | Yuchen Eleanor Jiang and 196 | Chengfei Lv and 197 | Huajun Chen}, 198 | title = {{AUTOACT:} Automatic Agent Learning from Scratch via Self-Planning}, 199 | journal = {CoRR}, 200 | volume = {abs/2401.05268}, 201 | year = {2024}, 202 | url = {https://doi.org/10.48550/arXiv.2401.05268}, 203 | doi = {10.48550/ARXIV.2401.05268}, 204 | eprinttype = {arXiv}, 205 | eprint = {2401.05268}, 206 | timestamp = {Thu, 25 Jan 2024 15:41:08 +0100}, 207 | biburl = {https://dblp.org/rec/journals/corr/abs-2401-05268.bib}, 208 | bibsource = {dblp computer science bibliography, https://dblp.org} 209 | } 210 | ``` 211 | 212 | 213 | 214 | ## 🎉Contributors 215 | 216 | 217 | 218 | 219 | We will offer long-term maintenance to fix bugs and solve issues. So if you have any problems, please put issues to us. 220 | -------------------------------------------------------------------------------- /Self_Plan/Train/train.py: -------------------------------------------------------------------------------- 1 | # This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright: 2 | # 3 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
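# Supervised fine-tuning entry point adapted from FastChat: conversations are
# rendered with the Vicuna conversation template, user-side tokens are masked
# with IGNORE_TOKEN_ID so the loss is computed only on assistant responses,
# and training runs through the Hugging Face Trainer.
# An illustrative launch (model name, data path, and output dir are placeholders;
# the data must be in vicuna-style "conversations" format) could look like:
#   deepspeed Self_Plan/Train/train.py \
#     --model_name_or_path <base-model> \
#     --data_path <conversations-format JSON> \
#     --output_dir <output-dir> \
#     --num_train_epochs 5 --model_max_length 4096 \
#     --deepspeed Self_Plan/Train/deepspeed_config_s3.json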
16 | 17 | from dataclasses import dataclass, field 18 | import json 19 | import math 20 | import pathlib 21 | from typing import Dict, Optional, Sequence 22 | 23 | import numpy as np 24 | import torch 25 | from torch.utils.data import Dataset 26 | import transformers 27 | from transformers import Trainer 28 | from transformers.trainer_pt_utils import LabelSmoother 29 | 30 | from fastchat.conversation import SeparatorStyle 31 | from fastchat.model.model_adapter import get_conversation_template 32 | 33 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 34 | 35 | 36 | @dataclass 37 | class ModelArguments: 38 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m") 39 | trust_remote_code: bool = field( 40 | default=False, 41 | metadata={ 42 | "help": "Whether or not to allow for custom models defined on the Hub in their own modeling files" 43 | }, 44 | ) 45 | padding_side: str = field( 46 | default="right", metadata={"help": "The padding side in tokenizer"} 47 | ) 48 | 49 | 50 | @dataclass 51 | class DataArguments: 52 | data_path: str = field( 53 | default=None, metadata={"help": "Path to the training data."} 54 | ) 55 | eval_data_path: str = field( 56 | default=None, metadata={"help": "Path to the evaluation data."} 57 | ) 58 | lazy_preprocess: bool = False 59 | 60 | 61 | @dataclass 62 | class TrainingArguments(transformers.TrainingArguments): 63 | cache_dir: Optional[str] = field(default=None) 64 | optim: str = field(default="adamw_torch") 65 | model_max_length: int = field( 66 | default=512, 67 | metadata={ 68 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 69 | }, 70 | ) 71 | 72 | 73 | local_rank = None 74 | 75 | 76 | def rank0_print(*args): 77 | if local_rank == 0: 78 | print(*args) 79 | 80 | 81 | def trainer_save_model_safe(trainer: transformers.Trainer): 82 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 83 | from torch.distributed.fsdp import StateDictType, FullStateDictConfig 84 | 85 | save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) 86 | with FSDP.state_dict_type( 87 | trainer.model, StateDictType.FULL_STATE_DICT, save_policy 88 | ): 89 | trainer.save_model() 90 | 91 | 92 | def preprocess( 93 | sources, 94 | tokenizer: transformers.PreTrainedTokenizer, 95 | ) -> Dict: 96 | conv = get_conversation_template("vicuna") 97 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 98 | 99 | # Apply prompt templates 100 | conversations = [] 101 | for i, source in enumerate(sources): 102 | if roles[source[0]["from"]] != conv.roles[0]: 103 | # Skip the first one if it is not from human 104 | source = source[1:] 105 | 106 | conv.messages = [] 107 | for j, sentence in enumerate(source): 108 | role = roles[sentence["from"]] 109 | assert role == conv.roles[j % 2], f"{i}" 110 | conv.append_message(role, sentence["value"]) 111 | conversations.append(conv.get_prompt()) 112 | 113 | # Tokenize conversations 114 | input_ids = tokenizer( 115 | conversations, 116 | return_tensors="pt", 117 | padding="max_length", 118 | max_length=tokenizer.model_max_length, 119 | truncation=True, 120 | ).input_ids 121 | targets = input_ids.clone() 122 | 123 | assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO 124 | 125 | # Mask targets. Only compute loss on the assistant outputs. 
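    # Each conversation is split into turns at `conv.sep2`; within a turn the
    # text before the assistant separator (the instruction) is masked with
    # IGNORE_TOKEN_ID, leaving only the assistant reply as supervision. If the
    # accumulated length disagrees with the non-pad length, the whole sample is
    # masked out and a tokenization-mismatch warning is printed.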
126 | sep = conv.sep + conv.roles[1] + ": " 127 | for conversation, target in zip(conversations, targets): 128 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 129 | 130 | turns = conversation.split(conv.sep2) 131 | cur_len = 1 132 | target[:cur_len] = IGNORE_TOKEN_ID 133 | for i, turn in enumerate(turns): 134 | if turn == "": 135 | break 136 | turn_len = len(tokenizer(turn).input_ids) 137 | 138 | parts = turn.split(sep) 139 | if len(parts) != 2: 140 | break 141 | parts[0] += sep 142 | # "-2" is hardcoded for the Llama tokenizer to make the offset correct. 143 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2 144 | 145 | if i != 0 and not tokenizer.legacy: 146 | # The legacy and non-legacy modes handle special tokens differently 147 | instruction_len -= 1 148 | 149 | # Ignore the user instructions 150 | target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID 151 | cur_len += turn_len 152 | 153 | if i != 0 and not tokenizer.legacy: 154 | # The legacy and non-legacy modes handle special tokens differently 155 | cur_len -= 1 156 | 157 | target[cur_len:] = IGNORE_TOKEN_ID 158 | 159 | if False: # Inspect and check the correctness of masking 160 | z = target.clone() 161 | z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) 162 | rank0_print(tokenizer.decode(z)) 163 | exit() 164 | 165 | if cur_len < tokenizer.model_max_length: 166 | if cur_len != total_len: 167 | target[:] = IGNORE_TOKEN_ID 168 | rank0_print( 169 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 170 | f" #turn = {len(turns) - 1}. (ignored)" 171 | ) 172 | 173 | return dict( 174 | input_ids=input_ids, 175 | labels=targets, 176 | attention_mask=input_ids.ne(tokenizer.pad_token_id), 177 | ) 178 | 179 | 180 | class SupervisedDataset(Dataset): 181 | """Dataset for supervised fine-tuning.""" 182 | 183 | def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer): 184 | super(SupervisedDataset, self).__init__() 185 | 186 | rank0_print("Formatting inputs...") 187 | sources = [example["conversations"] for example in raw_data] 188 | data_dict = preprocess(sources, tokenizer) 189 | 190 | self.input_ids = data_dict["input_ids"] 191 | self.labels = data_dict["labels"] 192 | self.attention_mask = data_dict["attention_mask"] 193 | 194 | def __len__(self): 195 | return len(self.input_ids) 196 | 197 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 198 | return dict( 199 | input_ids=self.input_ids[i], 200 | labels=self.labels[i], 201 | attention_mask=self.attention_mask[i], 202 | ) 203 | 204 | 205 | class LazySupervisedDataset(Dataset): 206 | """Dataset for supervised fine-tuning.""" 207 | 208 | def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer): 209 | super(LazySupervisedDataset, self).__init__() 210 | self.tokenizer = tokenizer 211 | 212 | rank0_print("Formatting inputs...Skip in lazy mode") 213 | self.tokenizer = tokenizer 214 | self.raw_data = raw_data 215 | self.cached_data_dict = {} 216 | 217 | def __len__(self): 218 | return len(self.raw_data) 219 | 220 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 221 | if i in self.cached_data_dict: 222 | return self.cached_data_dict[i] 223 | 224 | ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer) 225 | ret = dict( 226 | input_ids=ret["input_ids"][0], 227 | labels=ret["labels"][0], 228 | attention_mask=ret["attention_mask"][0], 229 | ) 230 | self.cached_data_dict[i] = ret 231 | 232 | return ret 233 | 234 | 235 | def make_supervised_data_module( 236 | tokenizer: 
237 | ) -> Dict:
238 |     """Make dataset and collator for supervised fine-tuning."""
239 |     dataset_cls = (
240 |         LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
241 |     )
242 |     rank0_print("Loading data...")
243 | 
244 |     train_json = json.load(open(data_args.data_path, "r"))
245 |     train_dataset = dataset_cls(train_json, tokenizer=tokenizer)
246 | 
247 |     if data_args.eval_data_path:
248 |         eval_json = json.load(open(data_args.eval_data_path, "r"))
249 |         eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer)
250 |     else:
251 |         eval_dataset = None
252 | 
253 |     return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)
254 | 
255 | 
256 | def train():
257 |     global local_rank
258 | 
259 |     parser = transformers.HfArgumentParser(
260 |         (ModelArguments, DataArguments, TrainingArguments)
261 |     )
262 |     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
263 |     local_rank = training_args.local_rank
264 | 
265 |     # Set RoPE scaling factor
266 |     config = transformers.AutoConfig.from_pretrained(
267 |         model_args.model_name_or_path,
268 |         cache_dir=training_args.cache_dir,
269 |         trust_remote_code=model_args.trust_remote_code,
270 |     )
271 |     orig_ctx_len = getattr(config, "max_position_embeddings", None)
272 |     if orig_ctx_len and training_args.model_max_length > orig_ctx_len:
273 |         scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
274 |         config.rope_scaling = {"type": "linear", "factor": scaling_factor}
275 |     config.use_cache = False
276 | 
277 |     # Load model and tokenizer
278 |     model = transformers.AutoModelForCausalLM.from_pretrained(
279 |         model_args.model_name_or_path,
280 |         config=config,
281 |         cache_dir=training_args.cache_dir,
282 |         trust_remote_code=model_args.trust_remote_code,
283 |     )
284 |     tokenizer = transformers.AutoTokenizer.from_pretrained(
285 |         model_args.model_name_or_path,
286 |         cache_dir=training_args.cache_dir,
287 |         model_max_length=training_args.model_max_length,
288 |         padding_side=model_args.padding_side,
289 |         use_fast=False,
290 |         trust_remote_code=model_args.trust_remote_code,
291 |     )
292 | 
293 |     if tokenizer.pad_token != tokenizer.unk_token:
294 |         tokenizer.pad_token = tokenizer.unk_token
295 | 
296 |     # Load data
297 |     data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
298 | 
299 |     # Start trainer
300 |     trainer = Trainer(
301 |         model=model, tokenizer=tokenizer, args=training_args, **data_module
302 |     )
303 |     if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
304 |         trainer.train(resume_from_checkpoint=True)
305 |     else:
306 |         trainer.train()
307 | 
308 |     # Save model
309 |     model.config.use_cache = True
310 |     trainer.save_state()
311 |     if trainer.is_deepspeed_enabled:
312 |         trainer.save_model()
313 |     else:
314 |         trainer_save_model_safe(trainer)
315 | 
316 | 
317 | if __name__ == "__main__":
318 |     train()
319 | 
--------------------------------------------------------------------------------
/Self_Plan/Train/train_lora.py:
--------------------------------------------------------------------------------
1 | # Usage: deepspeed train_lora.py --deepspeed <$PATH_TO_DEEPSPEED_CONFIG>
2 | 
3 | # Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
4 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | #     http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | 
18 | import json
19 | from dataclasses import dataclass, field
20 | import logging
21 | import pathlib
22 | import typing
23 | import os
24 | from datasets import load_dataset
25 | 
26 | from deepspeed import zero
27 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
28 | from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
29 | import transformers
30 | from transformers import Trainer, BitsAndBytesConfig, deepspeed
31 | import torch
32 | 
33 | from train import (  # train.py sits in the same directory; a relative import fails under `deepspeed train_lora.py`
34 |     DataArguments,
35 |     ModelArguments,
36 |     rank0_print
37 | )
38 | 
39 | from fastchat.train.llama2_flash_attn_monkey_patch import (
40 |     replace_llama_attn_with_flash_attn,
41 | )
42 | 
43 | # os.environ["CUDA_VISIBLE_DEVICES"] = 0
44 | 
45 | FULL_PROMPT_TEMPLATE = "{input}\n### Response:\n{output}"
46 | NO_OUTPUT_PROMPT_TEMPLATE = "{input}\n### Response:\n"
47 | 
48 | @dataclass
49 | class TrainingArguments(transformers.TrainingArguments):
50 |     cache_dir: typing.Optional[str] = field(default=None)
51 |     optim: str = field(default="adamw_torch")
52 |     model_max_length: int = field(
53 |         default=512,
54 |         metadata={
55 |             "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
56 |         },
57 |     )
58 |     flash_attn: bool = False
59 |     resume_from_checkpoint: bool = False
60 | 
61 | 
62 | @dataclass
63 | class LoraArguments:
64 |     lora_r: int = 8
65 |     lora_alpha: int = 16
66 |     lora_dropout: float = 0.05
67 |     lora_target_modules: typing.List[str] = field(
68 |         default_factory=lambda: ["q_proj", "v_proj"]
69 |     )
70 |     lora_weight_path: str = ""
71 |     lora_bias: str = "none"
72 |     q_lora: bool = False
73 | 
74 | 
75 | def maybe_zero_3(param):
76 |     if hasattr(param, "ds_id"):
77 |         assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
78 |         with zero.GatheredParameters([param]):
79 |             param = param.data.detach().cpu().clone()
80 |     else:
81 |         param = param.detach().cpu().clone()
82 |     return param
83 | 
84 | 
85 | # Borrowed from peft.utils.get_peft_model_state_dict
86 | def get_peft_state_maybe_zero_3(named_params, bias):
87 |     if bias == "none":
88 |         to_return = {k: t for k, t in named_params if "lora_" in k}
89 |     elif bias == "all":
90 |         to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
91 |     elif bias == "lora_only":
92 |         to_return = {}
93 |         maybe_lora_bias = {}
94 |         lora_bias_names = set()
95 |         for k, t in named_params:
96 |             if "lora_" in k:
97 |                 to_return[k] = t
98 |                 bias_name = k.split("lora_")[0] + "bias"
99 |                 lora_bias_names.add(bias_name)
100 |             elif "bias" in k:
101 |                 maybe_lora_bias[k] = t
102 |         # Keep only the bias tensors that belong to LoRA-adapted modules
103 |         for k, t in maybe_lora_bias.items():
104 |             if k in lora_bias_names:
105 |                 to_return[k] = t
106 |     else:
107 |         raise NotImplementedError
108 |     to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
109 |     return to_return
110 | 
111 | 
112 | def load_and_process_dataset(tokenizer, data_args):
113 |     rank0_print("Loading data...")
114 | 
115 |     def tokenize(prompt, add_eos_token=True):
116 |         result = tokenizer(
117 |             prompt,
118 |             truncation=True,
119 |             max_length=tokenizer.model_max_length,
120 |             padding=False,
121 |             return_tensors=None,
122 |         )
123 |         if (
124 |             result["input_ids"][-1] != tokenizer.eos_token_id
125 |             and len(result["input_ids"]) < tokenizer.model_max_length
126 |             and add_eos_token
127 |         ):
128 |             result["input_ids"].append(tokenizer.eos_token_id)
129 |             result["attention_mask"].append(1)
130 | 
131 |         result["labels"] = result["input_ids"].copy()
132 | 
133 |         return result
134 | 
135 |     def generate_and_tokenize_prompt(data_point):
136 |         full_prompt = FULL_PROMPT_TEMPLATE.format(
137 |             input=data_point["input"],
138 |             output=data_point["output"]
139 |         )
140 |         tokenized_full_prompt = tokenize(full_prompt)
141 |         user_prompt = NO_OUTPUT_PROMPT_TEMPLATE.format(input=data_point["input"])
142 |         tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
143 |         user_prompt_len = len(tokenized_user_prompt["input_ids"])
144 |         # Mask the prompt tokens so loss is only computed on the response
145 |         tokenized_full_prompt["labels"] = [
146 |             -100
147 |         ] * user_prompt_len + tokenized_full_prompt["labels"][
148 |             user_prompt_len:
149 |         ]
150 |         return tokenized_full_prompt
151 | 
152 |     train_data = load_dataset("json", data_files=data_args.data_path)
153 |     if data_args.eval_data_path:
154 |         eval_data = load_dataset("json", data_files=data_args.eval_data_path)
155 |         eval_data = eval_data["train"].shuffle().map(generate_and_tokenize_prompt)
156 |     else:
157 |         eval_data = None
158 |     train_data = train_data["train"].shuffle().map(generate_and_tokenize_prompt)
159 |     return train_data, eval_data
160 | 
161 | 
162 | 
163 | 
164 | def train():
165 |     parser = transformers.HfArgumentParser(
166 |         (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
167 |     )
168 |     (
169 |         model_args,
170 |         data_args,
171 |         training_args,
172 |         lora_args,
173 |     ) = parser.parse_args_into_dataclasses()
174 | 
175 |     if training_args.flash_attn:
176 |         replace_llama_attn_with_flash_attn()
177 | 
178 |     device_map = None
179 |     world_size = int(os.environ.get("WORLD_SIZE", 1))
180 |     ddp = world_size != 1
181 |     if lora_args.q_lora:
182 |         device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
183 |         if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
184 |             logging.warning(
185 |                 "FSDP and ZeRO3 are both currently incompatible with QLoRA."
186 |             )
187 | 
188 |     compute_dtype = (
189 |         torch.float16
190 |         if training_args.fp16
191 |         else (torch.bfloat16 if training_args.bf16 else torch.float32)
192 |     )
193 |     print("load_begin")
194 |     model = transformers.AutoModelForCausalLM.from_pretrained(
195 |         model_args.model_name_or_path,
196 |         cache_dir=training_args.cache_dir,
197 |         device_map=device_map,
198 |         quantization_config=BitsAndBytesConfig(
199 |             load_in_4bit=True,
200 |             bnb_4bit_use_double_quant=True,
201 |             bnb_4bit_quant_type="nf4",
202 |             bnb_4bit_compute_dtype=compute_dtype,
203 |         )
204 |         if lora_args.q_lora
205 |         else None,
206 |         torch_dtype=compute_dtype
207 |     )
208 |     print("load_success")
209 |     lora_config = LoraConfig(
210 |         r=lora_args.lora_r,
211 |         lora_alpha=lora_args.lora_alpha,
212 |         target_modules=lora_args.lora_target_modules,
213 |         lora_dropout=lora_args.lora_dropout,
214 |         bias=lora_args.lora_bias,
215 |         task_type="CAUSAL_LM",
216 |     )
217 | 
218 |     if lora_args.q_lora:
219 |         model = prepare_model_for_kbit_training(
220 |             model, use_gradient_checkpointing=training_args.gradient_checkpointing
221 |         )
222 |         if not ddp and torch.cuda.device_count() > 1:
223 |             # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
224 |             model.is_parallelizable = True
225 |             model.model_parallel = True
226 | 
227 |     model = get_peft_model(model, lora_config)
228 |     if training_args.flash_attn:
229 |         for name, module in model.named_modules():
230 |             if "norm" in name:
231 |                 module = module.to(compute_dtype)
232 |             if "lm_head" in name or "embed_tokens" in name:
233 |                 if hasattr(module, "weight"):
234 |                     module = module.to(compute_dtype)
235 |     if training_args.deepspeed is not None and training_args.local_rank == 0:
236 |         model.print_trainable_parameters()
237 | 
238 |     if training_args.gradient_checkpointing:
239 |         model.enable_input_require_grads()
240 | 
241 |     tokenizer = transformers.AutoTokenizer.from_pretrained(
242 |         model_args.model_name_or_path,
243 |         cache_dir=training_args.cache_dir,
244 |         model_max_length=training_args.model_max_length,
245 |         padding_side="left",
246 |         use_fast=False,
247 |     )
248 |     tokenizer.pad_token = tokenizer.unk_token
249 | 
250 |     model.config.pad_token_id = tokenizer.pad_token_id = 0  # same as unk token id
251 |     model.config.bos_token_id = tokenizer.bos_token_id = 1
252 |     model.config.eos_token_id = tokenizer.eos_token_id = 2
253 | 
254 |     # data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
255 |     train_data, eval_data = load_and_process_dataset(tokenizer=tokenizer, data_args=data_args)
256 | 
257 |     trainer = Trainer(
258 |         model=model,
259 |         tokenizer=tokenizer,
260 |         args=training_args,
261 |         # **data_module
262 |         train_dataset=train_data,
263 |         eval_dataset=eval_data,
264 |         data_collator=transformers.DataCollatorForSeq2Seq(
265 |             tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
266 |         ),
267 |     )
268 | 
269 |     model.config.use_cache = False
270 | 
271 |     print("all_right")
272 |     if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")) and training_args.resume_from_checkpoint:
273 |         trainer.train(resume_from_checkpoint=True)
274 |     else:
275 |         trainer.train()
276 |     trainer.save_state()
277 | 
278 |     # check if zero3 mode enabled
279 |     if deepspeed.is_deepspeed_zero3_enabled():
280 |         # use deepspeed engine internal function to gather state dict
281 |         # state_dict_zero3 contains whole parameters of base and lora adapters
282 |         # we will not extract lora parameters since peft save_pretrained will do that
283 |         #
https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/peft_model.py#L125 282 | # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/utils/save_and_load.py#L19 283 | state_dict_zero3 = trainer.model_wrapped._zero3_consolidated_16bit_state_dict() 284 | if training_args.local_rank == 0: 285 | state_dict = state_dict_zero3 286 | else: 287 | # in other mode we use original code from fastchat team, to make sure our change is minimum 288 | state_dict = get_peft_state_maybe_zero_3( 289 | model.named_parameters(), lora_args.lora_bias 290 | ) 291 | 292 | if training_args.local_rank == 0: 293 | model.save_pretrained(training_args.output_dir, state_dict=state_dict) 294 | 295 | 296 | if __name__ == "__main__": 297 | train() 298 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------
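For reference, a minimal sketch of the training-data records the two scripts in Self_Plan/Train expect: train.py consumes Vicuna-style "conversations" lists (see SupervisedDataset and preprocess), while train_lora.py consumes flat "input"/"output" pairs joined by the "### Response:" template (see load_and_process_dataset). The file names and the example question/answer text below are illustrative placeholders, not artifacts of this repository.

import json

# One record in the format train.py expects: alternating human/gpt turns.
full_ft_record = {
    "conversations": [
        {"from": "human", "value": "Which magazine was started first, Arthur's Magazine or First for Women?"},
        {"from": "gpt", "value": "Arthur's Magazine"},
    ]
}

# One record in the format train_lora.py expects: flat input/output fields.
lora_record = {
    "input": "Which magazine was started first, Arthur's Magazine or First for Women?",
    "output": "Arthur's Magazine",
}

# Placeholder paths; pass whichever file matches the script via --data_path.
with open("train_full.json", "w") as f:
    json.dump([full_ft_record], f, indent=2)      # train.py reads a JSON list with json.load
with open("train_lora.jsonl", "w") as f:
    f.write(json.dumps(lora_record) + "\n")       # train_lora.py reads it with datasets' "json" loader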