├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── blocksworld_control.py ├── blocksworld_rot.sh ├── examples ├── __init__.py ├── blocksworld │ ├── README.md │ ├── bfs.py │ ├── cot.py │ ├── data │ │ ├── bw_basic.json │ │ ├── bw_basic_sup.json │ │ ├── bw_config.yaml │ │ ├── dataset.ipynb │ │ ├── deceptive_dataset.ipynb │ │ ├── full.json │ │ ├── full_data │ │ │ ├── step_10.json │ │ │ ├── step_12.json │ │ │ ├── step_14.json │ │ │ ├── step_16.json │ │ │ ├── step_2.json │ │ │ ├── step_4.json │ │ │ ├── step_6.json │ │ │ └── step_8.json │ │ ├── full_mystory.json │ │ ├── generated_domain.pddl │ │ ├── process_pddl.ipynb │ │ ├── split_v1 │ │ │ ├── split_v1_step_10_data.json │ │ │ ├── split_v1_step_12_data.json │ │ │ ├── split_v1_step_2_data.json │ │ │ ├── split_v1_step_4_data.json │ │ │ ├── split_v1_step_6_data.json │ │ │ └── split_v1_step_8_data.json │ │ ├── split_v2 │ │ │ ├── split_v2_step_10_data.json │ │ │ ├── split_v2_step_12_data.json │ │ │ ├── split_v2_step_2_data.json │ │ │ ├── split_v2_step_4_data.json │ │ │ ├── split_v2_step_6_data.json │ │ │ └── split_v2_step_8_data.json │ │ ├── step_2.json │ │ ├── step_4.json │ │ ├── step_6.json │ │ ├── step_8.json │ │ ├── under_6.json │ │ └── under_8.json │ ├── rap.py │ ├── search_config.py │ └── world_model.py └── gsm8k │ ├── __init__.py │ ├── aggregate.py │ ├── bfs.py │ ├── cot.py │ ├── rap.py │ ├── search_config.py │ ├── utils.py │ └── world_model.py ├── gsm8k_control.py ├── gsm8k_rot.sh ├── gsm8k_summarization.json ├── prompts ├── bw │ ├── mistral │ │ ├── pool_prompt_v2_step_10_rot.json │ │ ├── pool_prompt_v2_step_2_rot.json │ │ ├── pool_prompt_v2_step_4_rot.json │ │ ├── pool_prompt_v2_step_6_rot.json │ │ └── pool_prompt_v2_step_8_rot.json │ ├── mixtral │ │ ├── pool_prompt_v2_step_10_rot.json │ │ ├── pool_prompt_v2_step_2_rot.json │ │ ├── pool_prompt_v2_step_4_rot.json │ │ ├── pool_prompt_v2_step_6_rot.json │ │ └── pool_prompt_v2_step_8_rot.json │ ├── phi-2 │ │ ├── pool_prompt_v2_step_10_rot.json │ │ ├── 
pool_prompt_v2_step_2_rot.json │ │ ├── pool_prompt_v2_step_4_rot.json │ │ ├── pool_prompt_v2_step_6_rot.json │ │ └── pool_prompt_v2_step_8_rot.json │ ├── pool_prompt.json │ ├── pool_prompt_v0.json │ ├── pool_prompt_v1.json │ ├── pool_prompt_v2_step_10.json │ ├── pool_prompt_v2_step_10_rot.json │ ├── pool_prompt_v2_step_10_template.json │ ├── pool_prompt_v2_step_2.json │ ├── pool_prompt_v2_step_2_rot.json │ ├── pool_prompt_v2_step_2_template.json │ ├── pool_prompt_v2_step_4.json │ ├── pool_prompt_v2_step_4_rot.json │ ├── pool_prompt_v2_step_4_template.json │ ├── pool_prompt_v2_step_6.json │ ├── pool_prompt_v2_step_6_rot.json │ ├── pool_prompt_v2_step_6_rot_iter_1.json │ ├── pool_prompt_v2_step_6_rot_iter_2.json │ ├── pool_prompt_v2_step_6_rot_iter_3.json │ ├── pool_prompt_v2_step_6_rot_iter_4.json │ ├── pool_prompt_v2_step_6_rot_iter_5.json │ ├── pool_prompt_v2_step_6_rot_phi_2.json │ ├── pool_prompt_v2_step_6_template.json │ ├── pool_prompt_v2_step_8.json │ ├── pool_prompt_v2_step_8_rot.json │ ├── pool_prompt_v2_step_8_template.json │ └── prompt.json └── gsm8k │ ├── cot_default.json │ ├── cot_rot.json │ ├── mistral │ ├── cot_rot_mistral.json │ └── prompt_pool_improved_mistral.json │ ├── mixtral │ ├── cot_rot_mixtral.json │ └── prompt_pool_improved_mixtral.json │ ├── phi-2 │ ├── cot_rot_mixtral.json │ └── prompt_pool_improved_phi-2.json │ ├── prompt_pool.json │ ├── prompt_pool_rot.json │ ├── prompt_pool_rot_2_iter.json │ ├── prompt_pool_rot_3_iter.json │ ├── prompt_pool_rot_4_iter.json │ ├── prompt_pool_rot_5_iter.json │ ├── prompt_pool_rot_all.json │ ├── prompt_pool_rot_promsing_0.5.json │ ├── prompt_pool_rot_random.json │ ├── prompt_pool_rot_task.json │ ├── prompt_pool_template.json │ └── useful_examples.json ├── reasoners ├── __init__.py ├── algorithm │ ├── __init__.py │ ├── beam_search.py │ ├── bfs.py │ └── mcts.py ├── base.py ├── benchmark │ ├── __init__.py │ ├── blocksworld.py │ ├── bw_utils.py │ └── gsm8k.py ├── visualization │ ├── __init__.py │ ├── 
__main__.py │ ├── analyze.py │ ├── tree_log.py │ ├── tree_snapshot.py │ └── visualizer_client.py └── vllm_model.py ├── requirements.txt ├── rot_bargain ├── .gitignore ├── README.md ├── bargain_control.py ├── core.py ├── mcts_reflect.py ├── prompts │ ├── rot_w_panelty.json │ └── rot_wo_panelty.json ├── run │ ├── run_cot.py │ └── run_mcts.py ├── run_bargain.sh ├── user_simulators │ ├── __init__.py │ ├── mcts_simulator.py │ ├── naive_simulator.py │ └── strategy_simulator.py └── utils.py ├── rot_scripts ├── blocksworld_analysis.py ├── blocksworld_generate_rot_prompt.py ├── gpt4_utils.py ├── gsm8k_analysis.py └── gsm8k_generate_rot_prompt.py └── vllm-server ├── mixtral.sh ├── phi-2.sh └── vllm_api.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | logs 3 | *.out 4 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "LLMs-Planning"] 2 | path = LLMs-Planning 3 | url = https://github.com/karthikv792/LLMs-Planning -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RoT: Enhancing Large Language Models with Reflection on Search Trees 2 | 3 | RoT is a reflection framework designed to improve the performance of tree-search-based prompting methods and non-tree-search-based prompting methods such as [RAP](https://arxiv.org/abs/2305.14992), [ToT](https://arxiv.org/abs/2305.10601), and CoT based on previous valuable search experiences. 4 | 5 | --- 6 | This repo contains the implementation and experiment code for Blocksworld and GSM8K. For the implementation of CraigslistBargain, see [RoT dialogue](rot_bargain), as the tree search process with a stochastic environment is much different from the deterministic ones. 
7 | 8 | ## Quick Start 9 | Install the required libraries. 10 | ```bash 11 | conda create -n rot python=3.10 12 | conda activate rot 13 | 14 | git clone https://github.com/huiwy/reflection-on-trees --recursive 15 | cd reflection-on-trees 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | RoT uses [vllm](https://github.com/vllm-project/vllm) to support efficient text generation, so you first need to launch a vllm service. 20 | ```bash 21 | cd vllm-server 22 | sh phi-2.sh 23 | ``` 24 | 25 | Then you can run RoT to generate the new prompts with guidelines based on the served model. 26 | ```bash 27 | export OPENAI_API_BASE=xxx 28 | export OPENAI_API_KEY=xxx 29 | 30 | sh blocksworld_rot.sh prompts/bw/pool_prompt_rot.json # the prompts with RoT are generated at prompts/bw/pool_prompt_rot.json 31 | sh gsm8k_rot.sh prompts/gsm8k/prompt_pool_rot.json # the prompts with RoT are generated at prompts/gsm8k/prompt_pool_rot.json 32 | ``` 33 | 34 | Finally, add the generated prompt to the `prompt_path` dict in `blocksworld_control.py` or `gsm8k_control.py`: 35 | ```python 36 | prompt_path = { 37 | 'default': 'prompts/gsm8k/prompt_pool.json', 38 | 'rot': 'prompts/gsm8k/prompt_pool_rot.json', 39 | ... 40 | + 'rot-new': 'prompts/gsm8k/prompt_pool_rot.json' 41 | } 42 | ``` 43 | 44 | and then run with the new prompt with guidelines: 45 | 46 | ```bash 47 | python gsm8k_control.py --mode mcts --n_iter 10 --split train --prompt rot-new 48 | ``` 49 | 50 | ## Acknowledgement 51 | This repo is built on [llm-reasoners](https://github.com/Ber666/llm-reasoners). 
-------------------------------------------------------------------------------- /blocksworld_control.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | from typing import Literal, Optional 5 | from dataclasses import dataclass 6 | from datetime import datetime 7 | 8 | os.environ['VAL'] = 'LLMs-Planning/planner_tools/VAL' 9 | 10 | model_ip = 'http://10.234.38.2:23100/v1' 11 | 12 | prompt_path = { 13 | 'default': 'prompts/bw/pool_prompt_v2_step_{step}.json', 14 | 'rot': 'prompts/bw/pool_prompt_v2_step_{step}_rot.json', 15 | 'rot-iter-2': 'prompts/bw/pool_prompt_v2_step_{step}_rot_iter_2.json', 16 | 'rot-iter-3': 'prompts/bw/pool_prompt_v2_step_{step}_rot_iter_3.json', 17 | 'rot-iter-4': 'prompts/bw/pool_prompt_v2_step_{step}_rot_iter_4.json', 18 | 'rot-iter-5': 'prompts/bw/pool_prompt_v2_step_{step}_rot_iter_5.json', 19 | } 20 | 21 | if __name__ == '__main__': 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--prompt', type=str, default='default') 24 | parser.add_argument('--model', type=str, default='phi-2') 25 | parser.add_argument('--mode', type=str, default='cot') 26 | parser.add_argument('--n_iters', type=int, default=-1) 27 | parser.add_argument('--width', type=int, default=-1) 28 | parser.add_argument('--step', type=int, default=4) 29 | parser.add_argument('--split', type=str, default='test') 30 | args = parser.parse_args() 31 | 32 | os.environ['VLLM_API_BASE'] = model_ip 33 | 34 | if args.mode == 'cot': 35 | command = f'python examples/blocksworld/cot.py --hf_path {args.model}' 36 | elif args.mode == 'bfs': 37 | command = f'python examples/blocksworld/bfs.py --hf_path {args.model} --depth_limit {args.step}' 38 | if args.width != -1: 39 | command += f' --width {args.width}' 40 | else: 41 | command = f'python examples/blocksworld/rap.py --hf_path {args.model} --depth_limit {args.step} --output_trace_in_each_iter' 42 | if args.n_iters != -1: 43 | command += 
f' --n_iters {args.n_iters}' 44 | 45 | command += f' --prompt_path {prompt_path[args.prompt].format(step=args.step)}' 46 | 47 | log_path = f'logs/bw/step_{args.step}/{args.model}_{args.prompt}_{args.mode}_{args.n_iters}_{datetime.now().strftime("%Y%m%d-%H%M")}' 48 | 49 | command += f' --data_path examples/blocksworld/data/split_v2/split_v2_step_{args.step}_data.json' 50 | command += f' --batch_size 1' 51 | command += f' --log_dir {log_path}' 52 | command += ' --config_file examples/blocksworld/data/bw_config.yaml' 53 | print(command) 54 | 55 | import time 56 | t = time.time() 57 | os.system(command) 58 | time_consumed = time.time() - t 59 | 60 | with open(f'{log_path}/time_consumed.txt', 'w') as f: 61 | f.write(str(time_consumed)) -------------------------------------------------------------------------------- /blocksworld_rot.sh: -------------------------------------------------------------------------------- 1 | output_name=$1 2 | 3 | outputs=$(python blocksworld_control.py --mode mcts --n_iter 10) 4 | log_dir=$(echo $outputs | grep -oP 'log_dir [^ ]*' | cut -d' ' -f2) 5 | 6 | python rot_scripts/blocksworld_analysis.py --path $log_dir/algo_output --steps 4 --output_name $log_dir/rot_analysis.json 7 | python rot_scripts/blocksworld_generate_rot_prompt.py --path $log_dir/rot_analysis.json --output_name $output_name --steps 4 -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huiwy/reflection-on-trees/dffe87451e98396bc4cc55fc3fbd49a8e2ed557f/examples/__init__.py -------------------------------------------------------------------------------- /examples/blocksworld/README.md: -------------------------------------------------------------------------------- 1 | # Blocksworld 2 | 3 | ## Data 4 | The full [Blocksworld](https://arxiv.org/abs/2305.15771) datasets contain 602 samples. 
5 | 6 | Our experiments are conducted in two distinct settings: Hard (v1) and Easy (v2). In the Easy setting, we assume prior knowledge of the minimum number of actions for each case. Leveraging this information, we use demonstration cases that share the same minimum number of actions as the test case. For each group of cases (cases grouped by their minimum number of actions), we randomly select 10 cases to create a pool of demonstration cases, leaving the remaining cases as the test set. During inference, we randomly sample 4-shot demonstration cases from this pool and use them to formulate prompts. In the Hard setting, we randomly select 10 cases from the full dataset to form a demonstration pool and subsequently exclude these cases from the test set. During inference, we randomly sample 4-shot demonstration cases from this global pool, irrespective of the minimum number of actions required for the test case. 7 | 8 | We provide scripts to reproduce the results of [CoT](test_cot.sh) (for both easy and hard) and RAP ([easy](test_rap_v2.sh) and [hard](test_rap_v1.sh)). 9 | 10 | If you want to modify the experiment settings or develop your own method, you may look at `cot.py` or `rap.py`. 11 | 
-------------------------------------------------------------------------------- /examples/blocksworld/bfs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) 5 | sys.path.insert(1, path) 6 | 7 | from typing import Type, Callable, Optional 8 | 9 | import numpy as np 10 | 11 | from reasoners import Reasoner, SearchAlgorithm 12 | from reasoners.benchmark import BWEvaluator, blocksworld 13 | from reasoners.algorithm import BFSNode, BFS 14 | 15 | from world_model import BlocksWorldModel 16 | from search_config import BWConfig 17 | import json 18 | def rap_bw(hf_path: str, 19 | prompt_path: str, 20 | search_algo: Type[SearchAlgorithm] = BFS, 21 | data_path: str = 'data', 22 | resume: int = 0, 23 | depth_limit: int = 6, 24 | reward_alpha: float = 0.5, 25 | batch_size = 1, 26 | goal_reached_reward = 100, 27 | goal_reward_default = 0., 28 | log_dir: Optional[str] = None, 29 | disable_log: bool = False, 30 | domain_file: str = "examples/blocksworld/data/generated_domain.pddl", 31 | config_file: str = "", 32 | lm_plan_file: str = 'lm_plan.tmp', 33 | **search_algo_params): 34 | 35 | 36 | from reasoners import VLLMModel 37 | base_model = VLLMModel(model=hf_path) 38 | 39 | prompt = json.load(open(prompt_path, 'r')) 40 | 41 | search_algo_params |= {"depth_limit": depth_limit, "disable_tqdm": False} 42 | world_model = BlocksWorldModel(base_model=base_model, prompt=prompt, batch_size=batch_size, max_steps=depth_limit) 43 | config = BWConfig(base_model=base_model, prompt=prompt, batch_size=batch_size, 44 | reward_alpha=reward_alpha, goal_reached_reward=goal_reached_reward, 45 | goal_reward_default=goal_reward_default) 46 | search_algo = BFS(**search_algo_params) 47 | reasoner = Reasoner(world_model=world_model, search_config=config, search_algo=search_algo) 48 | evaluator = BWEvaluator(config_file=config_file, 
domain_file=domain_file, data_path=data_path, init_prompt=prompt, disable_log=disable_log) 49 | accuracy = evaluator.evaluate(reasoner, shuffle_prompt=True, num_shot=4, resume=resume, log_dir=log_dir) 50 | print(accuracy) 51 | 52 | if __name__ == '__main__': 53 | import fire 54 | fire.Fire(rap_bw) # user will need to switch the model in the code 55 | -------------------------------------------------------------------------------- /examples/blocksworld/cot.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) 5 | sys.path.insert(1, path) 6 | 7 | import json 8 | from reasoners.benchmark import BWEvaluator 9 | import fire 10 | 11 | class CoTReasoner(): 12 | def __init__(self, base_model, temperature=0.8): 13 | self.base_model = base_model 14 | self.temperature = temperature 15 | def __call__(self, example, prompt=None): 16 | inputs = prompt["icl"].replace("", example["init"])\ 17 | .replace("", example["goal"]).replace("", "") 18 | output = self.base_model.generate([inputs], 19 | do_sample=True, 20 | temperature=self.temperature, 21 | stop='\n[').text[0].strip() 22 | return output 23 | 24 | def main(hf_path, data_path, prompt_path, disable_log=False, batch_size=1, config_file: str = "examples/blocksworld/data/bw_config.yaml", domain_file: str = "examples/blocksworld/data/generated_domain.pddl", resume=0, log_dir=None, temperature=0.8, exllama_mem_map: str = None): 25 | 26 | from reasoners import VLLMModel 27 | base_model = VLLMModel(model=hf_path) 28 | 29 | 30 | with open(prompt_path) as f: 31 | prompt = json.load(f) 32 | 33 | reasoner = CoTReasoner(base_model, temperature=temperature) 34 | evaluator = BWEvaluator(config_file=config_file, domain_file=domain_file, data_path=data_path, init_prompt=prompt, disable_log=disable_log, output_extractor=lambda x:x, sample_prompt_type="rap") # rap prompt includes cot 35 | 
accuracy = evaluator.evaluate(reasoner, shuffle_prompt=True, num_shot=4, resume=resume, log_dir=log_dir) 36 | print(f'accuracy: {accuracy:.4f}') 37 | return 0 38 | 39 | if __name__ == '__main__': 40 | fire.Fire(main) -------------------------------------------------------------------------------- /examples/blocksworld/data/bw_config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | domain_intro: | 3 | I am playing with a set of blocks where I need to arrange the blocks into stacks. Here are the actions I can do 4 | 5 | Pick up a block 6 | Unstack a block from on top of another block 7 | Put down a block 8 | Stack a block on top of another block 9 | 10 | I have the following restrictions on my actions: 11 | I can only pick up or unstack one block at a time. 12 | I can only pick up or unstack a block if my hand is empty. 13 | I can only pick up a block if the block is on the table and the block is clear. A block is clear if the block has no other blocks on top of it and if the block is not picked up. 14 | I can only unstack a block from on top of another block if the block I am unstacking was really on top of the other block. 15 | I can only unstack a block from on top of another block if the block I am unstacking is clear. 16 | Once I pick up or unstack a block, I am holding the block. 17 | I can only put down a block that I am holding. 18 | I can only stack a block on top of another block if I am holding the block being stacked. 19 | I can only stack a block on top of another block if the block onto which I am stacking the block is clear. 20 | Once I put down or stack a block, my hand becomes empty. 21 | 22 | domain_intro_cost: | 23 | I am playing with a set of blocks where I need to arrange the blocks into stacks. Here are the actions I can do: 24 | 25 | Pick up a block. It takes 1 minute to pick up a block. 26 | Unstack a block from on top of another block. 
It takes 1 minute to unstack a block from on top of another block. 27 | Put down a block. It takes 1 minute to put down a block. 28 | Stack a block on top of another block. It takes 1 minute to stack a block on top of another block. 29 | 30 | I have the following restrictions on my actions: 31 | I can only pick up or unstack one block at a time. 32 | I can only pick up or unstack a block if my hand is empty. 33 | I can only pick up a block if the block is on the table and the block is clear. A block is clear if the block has no other blocks on top of it and if the block is not picked up. 34 | I can only unstack a block from on top of another block if the block I am unstacking was really on top of the other block. 35 | I can only unstack a block from on top of another block if the block I am unstacking is clear. 36 | Once I pick up or unstack a block, I am holding the block. 37 | I can only put down a block that I am holding. 38 | I can only stack a block on top of another block if I am holding the block being stacked. 39 | I can only stack a block on top of another block if the block onto which I am stacking the block is clear. 40 | Once I put down or stack a block, my hand becomes empty. 
41 | 42 | actions: 43 | {pick-up: "pick up the {}", 44 | put-down: "put down the {}", 45 | stack: "stack the {} on top of the {}", 46 | unstack: "unstack the {} from on top of the {}"} 47 | 48 | objects: 49 | - blue block 50 | - orange block 51 | - red block 52 | - yellow block 53 | 54 | predicates: 55 | {ontable: "the {} is on the table", 56 | clear: "the {} is clear", 57 | handempty: "the hand is empty", 58 | "on": "the {} is on top of the {}"} 59 | 60 | # encoded_objects: dictionary of object names, have to be alphabetical 61 | encoded_objects: 62 | {"a": "red block", "b": "blue block", "c": "orange block", "d": "yellow block", 63 | "e": "white block", "f": "magenta block", "g": "black block", "h": "cyan block", 64 | "i": "green block", "j": "violet block", "k": "silver block", "l": "gold block" } 65 | 66 | callbacks: 67 | - t1_gen_goal_directed_instances -------------------------------------------------------------------------------- /examples/blocksworld/data/full_data/step_14.json: -------------------------------------------------------------------------------- 1 | [["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-452.pddl", "(unstack d a)\n(put-down d)\n(unstack a e)\n(stack a d)\n(unstack e c)\n(stack e a)\n(unstack c b)\n(put-down c)\n(pick-up b)\n(stack b c)\n(unstack e a)\n(stack e b)\n(unstack a d)\n(stack a e)\n", 14]] -------------------------------------------------------------------------------- /examples/blocksworld/data/full_data/step_16.json: -------------------------------------------------------------------------------- 1 | [["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-464.pddl", "(unstack c a)\n(put-down c)\n(unstack a b)\n(put-down a)\n(unstack b e)\n(put-down b)\n(unstack e d)\n(put-down e)\n(pick-up d)\n(stack d a)\n(pick-up c)\n(stack c d)\n(pick-up b)\n(stack b c)\n(pick-up e)\n(stack e b)\n", 16]] 
-------------------------------------------------------------------------------- /examples/blocksworld/data/full_data/step_2.json: -------------------------------------------------------------------------------- 1 | [["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-192.pddl", "(unstack d c)\n(stack d a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-149.pddl", "(pick-up b)\n(stack b c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-396.pddl", "(unstack c d)\n(stack c b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-161.pddl", "(pick-up d)\n(stack d b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-378.pddl", "(unstack a c)\n(stack a d)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-273.pddl", "(unstack a c)\n(stack a b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-79.pddl", "(pick-up b)\n(stack b d)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-124.pddl", "(pick-up c)\n(stack c a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-223.pddl", "(pick-up a)\n(stack a d)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-34.pddl", "(pick-up d)\n(stack d b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-71.pddl", "(pick-up d)\n(stack d a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-224.pddl", "(pick-up d)\n(stack d a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-428.pddl", "(pick-up d)\n(stack d b)\n", 2], 
["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-426.pddl", "(unstack c b)\n(stack c a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-495.pddl", "(pick-up a)\n(stack a b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-430.pddl", "(unstack c d)\n(stack c a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-46.pddl", "(unstack b d)\n(stack b c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-41.pddl", "(unstack b c)\n(stack b d)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-5.pddl", "(pick-up d)\n(stack d c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-123.pddl", "(unstack a d)\n(stack a c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-252.pddl", "(unstack c d)\n(stack c a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-362.pddl", "(unstack a c)\n(stack a b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-164.pddl", "(pick-up c)\n(stack c a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-312.pddl", "(pick-up d)\n(stack d b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-31.pddl", "(unstack b d)\n(stack b a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-21.pddl", "(pick-up b)\n(stack b d)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-142.pddl", "(pick-up a)\n(stack a d)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-170.pddl", "(pick-up d)\n(stack d b)\n", 2], 
["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-70.pddl", "(unstack c b)\n(stack c d)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-233.pddl", "(pick-up c)\n(stack c d)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-20.pddl", "(pick-up b)\n(stack b a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-73.pddl", "(pick-up b)\n(stack b c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-100.pddl", "(pick-up b)\n(stack b c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-36.pddl", "(unstack c a)\n(stack c b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-8.pddl", "(pick-up a)\n(stack a c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-93.pddl", "(pick-up c)\n(stack c a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-45.pddl", "(pick-up a)\n(stack a c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-86.pddl", "(pick-up c)\n(stack c b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-41.pddl", "(pick-up c)\n(stack c b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-96.pddl", "(unstack b a)\n(stack b c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-92.pddl", "(unstack b c)\n(stack b a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-48.pddl", "(pick-up a)\n(stack a b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-33.pddl", "(pick-up c)\n(stack c a)\n", 2], 
["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-47.pddl", "(unstack a c)\n(stack a b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-77.pddl", "(pick-up b)\n(stack b c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-89.pddl", "(pick-up b)\n(stack b a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-83.pddl", "(pick-up b)\n(stack b a)\n", 2]] -------------------------------------------------------------------------------- /examples/blocksworld/data/generated_domain.pddl: -------------------------------------------------------------------------------- 1 | (define (domain blocksworld-4ops) 2 | (:requirements :strips) 3 | (:predicates (clear ?x) 4 | (ontable ?x) 5 | (handempty) 6 | (holding ?x) 7 | (on ?x ?y)) 8 | 9 | (:action pick-up 10 | :parameters (?ob) 11 | :precondition (and (clear ?ob) (ontable ?ob) (handempty)) 12 | :effect (and (holding ?ob) (not (clear ?ob)) (not (ontable ?ob)) 13 | (not (handempty)))) 14 | 15 | (:action put-down 16 | :parameters (?ob) 17 | :precondition (holding ?ob) 18 | :effect (and (clear ?ob) (handempty) (ontable ?ob) 19 | (not (holding ?ob)))) 20 | 21 | (:action stack 22 | :parameters (?ob ?underob) 23 | :precondition (and (clear ?underob) (holding ?ob)) 24 | :effect (and (handempty) (clear ?ob) (on ?ob ?underob) 25 | (not (clear ?underob)) (not (holding ?ob)))) 26 | 27 | (:action unstack 28 | :parameters (?ob ?underob) 29 | :precondition (and (on ?ob ?underob) (clear ?ob) (handempty)) 30 | :effect (and (holding ?ob) (clear ?underob) 31 | (not (on ?ob ?underob)) (not (clear ?ob)) (not (handempty))))) 32 | -------------------------------------------------------------------------------- /examples/blocksworld/data/split_v1/split_v1_step_2_data.json: -------------------------------------------------------------------------------- 1 | 
[["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-41.pddl", "(unstack b c)\n(stack b d)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-123.pddl", "(unstack a d)\n(stack a c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-86.pddl", "(pick-up c)\n(stack c b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-93.pddl", "(pick-up c)\n(stack c a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-252.pddl", "(unstack c d)\n(stack c a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-142.pddl", "(pick-up a)\n(stack a d)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-20.pddl", "(pick-up b)\n(stack b a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-79.pddl", "(pick-up b)\n(stack b d)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-430.pddl", "(unstack c d)\n(stack c a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-233.pddl", "(pick-up c)\n(stack c d)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-83.pddl", "(pick-up b)\n(stack b a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-71.pddl", "(pick-up d)\n(stack d a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-73.pddl", "(pick-up b)\n(stack b c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-41.pddl", "(pick-up c)\n(stack c b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-273.pddl", "(unstack a c)\n(stack a b)\n", 2], 
["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-34.pddl", "(pick-up d)\n(stack d b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-495.pddl", "(pick-up a)\n(stack a b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-192.pddl", "(unstack d c)\n(stack d a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-164.pddl", "(pick-up c)\n(stack c a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-46.pddl", "(unstack b d)\n(stack b c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-45.pddl", "(pick-up a)\n(stack a c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-396.pddl", "(unstack c d)\n(stack c b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-170.pddl", "(pick-up d)\n(stack d b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-149.pddl", "(pick-up b)\n(stack b c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-77.pddl", "(pick-up b)\n(stack b c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-5.pddl", "(pick-up d)\n(stack d c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-312.pddl", "(pick-up d)\n(stack d b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-48.pddl", "(pick-up a)\n(stack a b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-36.pddl", "(unstack c a)\n(stack c b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-124.pddl", "(pick-up c)\n(stack c a)\n", 2], 
["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-378.pddl", "(unstack a c)\n(stack a d)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-89.pddl", "(pick-up b)\n(stack b a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-33.pddl", "(pick-up c)\n(stack c a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-70.pddl", "(unstack c b)\n(stack c d)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-223.pddl", "(pick-up a)\n(stack a d)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-92.pddl", "(unstack b c)\n(stack b a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-8.pddl", "(pick-up a)\n(stack a c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-96.pddl", "(unstack b a)\n(stack b c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-21.pddl", "(pick-up b)\n(stack b d)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-47.pddl", "(unstack a c)\n(stack a b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-362.pddl", "(unstack a c)\n(stack a b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-161.pddl", "(pick-up d)\n(stack d b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-100.pddl", "(pick-up b)\n(stack b c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-31.pddl", "(unstack b d)\n(stack b a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-426.pddl", "(unstack c b)\n(stack c a)\n", 2]] 
-------------------------------------------------------------------------------- /examples/blocksworld/data/split_v2/split_v2_step_2_data.json: -------------------------------------------------------------------------------- 1 | [["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-96.pddl", "(unstack b a)\n(stack b c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-430.pddl", "(unstack c d)\n(stack c a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-45.pddl", "(pick-up a)\n(stack a c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-31.pddl", "(unstack b d)\n(stack b a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-46.pddl", "(unstack b d)\n(stack b c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-123.pddl", "(unstack a d)\n(stack a c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-224.pddl", "(pick-up d)\n(stack d a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-86.pddl", "(pick-up c)\n(stack c b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-93.pddl", "(pick-up c)\n(stack c a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-92.pddl", "(unstack b c)\n(stack b a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-79.pddl", "(pick-up b)\n(stack b d)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-252.pddl", "(unstack c d)\n(stack c a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-170.pddl", "(pick-up d)\n(stack d b)\n", 2], 
["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-428.pddl", "(pick-up d)\n(stack d b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-73.pddl", "(pick-up b)\n(stack b c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-223.pddl", "(pick-up a)\n(stack a d)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-362.pddl", "(unstack a c)\n(stack a b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-142.pddl", "(pick-up a)\n(stack a d)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-426.pddl", "(unstack c b)\n(stack c a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-192.pddl", "(unstack d c)\n(stack d a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-378.pddl", "(unstack a c)\n(stack a d)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-164.pddl", "(pick-up c)\n(stack c a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-41.pddl", "(pick-up c)\n(stack c b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-89.pddl", "(pick-up b)\n(stack b a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-312.pddl", "(pick-up d)\n(stack d b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-70.pddl", "(unstack c b)\n(stack c d)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-71.pddl", "(pick-up d)\n(stack d a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-100.pddl", "(pick-up b)\n(stack b c)\n", 2], 
["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-161.pddl", "(pick-up d)\n(stack d b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-5.pddl", "(pick-up d)\n(stack d c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-21.pddl", "(pick-up b)\n(stack b d)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-273.pddl", "(unstack a c)\n(stack a b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-396.pddl", "(unstack c d)\n(stack c b)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-83.pddl", "(pick-up b)\n(stack b a)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-77.pddl", "(pick-up b)\n(stack b c)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic/instance-233.pddl", "(pick-up c)\n(stack c d)\n", 2], ["LLMs-Planning/llm_planning_analysis/instances/blocksworld/generated_basic_3/instance-36.pddl", "(unstack c a)\n(stack c b)\n", 2]] -------------------------------------------------------------------------------- /examples/blocksworld/data/step_2.json: -------------------------------------------------------------------------------- 1 | [["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-5.pddl", "(pick-up yellow)\n(stack yellow orange)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-21.pddl", "(pick-up blue)\n(stack blue yellow)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-31.pddl", "(unstack blue yellow)\n(stack blue red)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-34.pddl", "(pick-up yellow)\n(stack yellow blue)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-41.pddl", "(unstack blue 
orange)\n(stack blue yellow)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-46.pddl", "(unstack blue yellow)\n(stack blue orange)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-70.pddl", "(unstack orange blue)\n(stack orange yellow)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-71.pddl", "(pick-up yellow)\n(stack yellow red)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-79.pddl", "(pick-up blue)\n(stack blue yellow)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-123.pddl", "(unstack red yellow)\n(stack red orange)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-124.pddl", "(pick-up orange)\n(stack orange red)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-142.pddl", "(pick-up red)\n(stack red yellow)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-149.pddl", "(pick-up blue)\n(stack blue orange)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-161.pddl", "(pick-up yellow)\n(stack yellow blue)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-164.pddl", "(pick-up orange)\n(stack orange red)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-170.pddl", "(pick-up yellow)\n(stack yellow blue)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-192.pddl", "(unstack yellow orange)\n(stack yellow red)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-223.pddl", "(pick-up red)\n(stack red yellow)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-224.pddl", "(pick-up yellow)\n(stack yellow red)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-233.pddl", "(pick-up orange)\n(stack orange yellow)\n", 2], 
["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-252.pddl", "(unstack orange yellow)\n(stack orange red)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-273.pddl", "(unstack red orange)\n(stack red blue)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-312.pddl", "(pick-up yellow)\n(stack yellow blue)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-362.pddl", "(unstack red orange)\n(stack red blue)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-378.pddl", "(unstack red orange)\n(stack red yellow)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-396.pddl", "(unstack orange yellow)\n(stack orange blue)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-426.pddl", "(unstack orange blue)\n(stack orange red)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-428.pddl", "(pick-up yellow)\n(stack yellow blue)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-430.pddl", "(unstack orange yellow)\n(stack orange red)\n", 2], ["gpt-plan-benchmark/gpt_plan_test/instances/generated_basic/instance-495.pddl", "(pick-up red)\n(stack red blue)\n", 2]] -------------------------------------------------------------------------------- /examples/blocksworld/rap.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) 5 | sys.path.insert(1, path) 6 | 7 | from typing import Type, Callable, Optional 8 | 9 | import numpy as np 10 | 11 | from reasoners import Reasoner, SearchAlgorithm 12 | from reasoners.benchmark import BWEvaluator, blocksworld 13 | from reasoners.algorithm import MCTS, MCTSAggregation 14 | 15 | from world_model import BlocksWorldModel 16 | from search_config import BWConfig 17 | 
import json


def rap_bw(hf_path: str,
           prompt_path: str,
           search_algo: Type[SearchAlgorithm] = MCTS,
           data_path: str = 'data',
           resume: int = 0,
           depth_limit: int = 6,
           reward_alpha: float = 0.5,
           batch_size: int = 1,
           goal_reached_reward: float = 100,
           goal_reward_default: float = 0.,
           cum_reward: Callable[[list[float]], float] = sum,
           calc_q: Callable[[list[float]], float] = np.mean,
           log_dir: Optional[str] = None,
           disable_log: bool = False,
           domain_file: str = "examples/blocksworld/data/generated_domain.pddl",
           config_file: str = "",
           lm_plan_file: str = 'lm_plan.tmp',
           **search_algo_params):
    """Run RAP-style tree search on Blocksworld and print the evaluator's accuracy.

    :param hf_path: model identifier/path passed to ``VLLMModel(model=...)``.
    :param prompt_path: path to a JSON prompt file. It is loaded once and shared by
        the world model, the search config and the evaluator (as ``init_prompt``).
    :param search_algo: search-algorithm class instantiated with
        ``**search_algo_params``. Defaults to ``MCTS``. The params always include
        ``cum_reward``, ``calc_q``, ``depth_limit``, ``disable_tqdm`` and
        ``aggregator``, so any replacement class must accept those keywords.
    :param data_path: dataset path forwarded to ``BWEvaluator``.
    :param resume: index to resume evaluation from, forwarded to ``evaluate``.
    :param depth_limit: used both as the search depth limit and as the world
        model's ``max_steps``.
    :param reward_alpha: weight forwarded to ``BWConfig``.
    :param batch_size: forwarded to both the world model and ``BWConfig``.
    :param goal_reached_reward: forwarded to ``BWConfig``.
    :param goal_reward_default: forwarded to ``BWConfig``.
    :param cum_reward: forwarded to the search algorithm.
    :param calc_q: forwarded to the search algorithm.
    :param log_dir: forwarded to ``evaluate``.
    :param disable_log: forwarded to ``BWEvaluator``.
    :param domain_file: PDDL domain file forwarded to ``BWEvaluator``.
    :param config_file: config file forwarded to ``BWEvaluator``.
    :param lm_plan_file: accepted for CLI compatibility; not used here.
    :param search_algo_params: extra keyword arguments for the search algorithm.
    :return: None. The accuracy is printed.
    """
    from reasoners import VLLMModel
    base_model = VLLMModel(model=hf_path)

    # Aggregate search results keyed on each state's accumulated action history.
    aggregator = MCTSAggregation(lambda x: x.history_actions, weight_policy='edge')

    # Use a context manager so the prompt file handle is closed deterministically.
    with open(prompt_path, 'r') as f:
        prompt = json.load(f)
    print(search_algo)

    search_algo_params |= {'cum_reward': cum_reward, 'calc_q': calc_q, "depth_limit": depth_limit,
                           "disable_tqdm": False, 'aggregator': aggregator}
    world_model = BlocksWorldModel(base_model=base_model, prompt=prompt, batch_size=batch_size,
                                   max_steps=depth_limit)
    config = BWConfig(base_model=base_model, prompt=prompt, batch_size=batch_size,
                      reward_alpha=reward_alpha, goal_reached_reward=goal_reached_reward,
                      goal_reward_default=goal_reward_default)
    # Honour the ``search_algo`` argument. It was previously ignored in favour of a
    # hard-coded MCTS, which is still the default. This matches examples/gsm8k/rap.py.
    algo = search_algo(**search_algo_params)
    reasoner = Reasoner(world_model=world_model, search_config=config, search_algo=algo)
    evaluator = BWEvaluator(config_file=config_file, domain_file=domain_file, data_path=data_path,
                            init_prompt=prompt, disable_log=disable_log)
    accuracy = evaluator.evaluate(reasoner, shuffle_prompt=True, num_shot=4, resume=resume,
                                  log_dir=log_dir)
    print(accuracy)


if __name__ == '__main__':
    import fire

    fire.Fire(rap_bw)  # user will need to switch the model in the code
-------------------------------------------------------------------------------- /examples/blocksworld/search_config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import reasoners.benchmark.bw_utils as utils 4 | from world_model import BWState, BWAction 5 | from reasoners import SearchConfig 6 | 7 | class BWConfig(SearchConfig): 8 | def __init__(self, 9 | base_model, 10 | prompt: dict, 11 | batch_size=2, 12 | reward_alpha=0.5, 13 | goal_reward_default=0., 14 | goal_reached_reward=100) -> None: 15 | super().__init__() 16 | self.base_model = base_model 17 | self.example = None 18 | self.prompt = prompt 19 | self.batch_size = batch_size 20 | self.reward_alpha = reward_alpha 21 | self.goal_reward_default = goal_reward_default 22 | self.goal_reached_reward = goal_reached_reward 23 | 24 | def get_actions(self, state: BWState) -> list[BWAction]: 25 | blocks_state = state.blocks_state 26 | return utils.generate_all_actions(blocks_state) 27 | 28 | def fast_reward(self, state: BWState, action: BWAction) -> tuple[float, dict]: 29 | if state.buffered_action == "": 30 | # if no action buffered 31 | current_blocks_state = state.blocks_state 32 | else: 33 | # if action buffered 34 | current_blocks_state = state.last_blocks_state 35 | previous_action = state.buffered_action + "\n" if state.buffered_action != "" else "" 36 | 37 | icl_template = self.prompt["icl_list"][state.step_idx // 2] 38 | # every two step, we will deduct the icl prompt 39 | # so that the distribution of step length is more reasonable 40 | 41 | inputs = icl_template.replace("", current_blocks_state)\ 42 | .replace("", utils.extract_goals(self.example, return_raw=True)).replace("", previous_action) 43 | 44 | intuition = self.base_model.get_loglikelihood([inputs], [inputs + action])[0] 45 | 46 | self_eval_prompt = self.prompt["self-eval"].replace("", current_blocks_state)\ 47 | .replace("", utils.extract_goals(self.example, 
return_raw=True)).replace("", action) 48 | self_eval = self.base_model.get_loglikelihood([self_eval_prompt], 49 | [self_eval_prompt + "good"])[0] 50 | 51 | return self.calculate_reward(intuition, self_eval), {'intuition': intuition, "self_eval": self_eval} 52 | 53 | def calculate_reward(self, intuition, self_eval, goal_reached=None): 54 | # to provide a unified interface for reward and fast_reward 55 | if goal_reached is None: 56 | goal_reward = self.goal_reward_default 57 | elif goal_reached[0]: 58 | goal_reward = self.goal_reached_reward 59 | else: 60 | goal_reward = goal_reached[1] 61 | return (intuition + self_eval) * self.reward_alpha + goal_reward * (1 - self.reward_alpha) 62 | 63 | def reward(self, state: BWState, action: BWAction, 64 | intuition: float = None, 65 | self_eval: float = None, 66 | goal_reached: tuple[bool, float] = None) -> float: 67 | assert intuition is not None, "intuition is required to calculate reward in this search config, consider passing it in fast_reward" 68 | assert self_eval is not None, "self_eval is required to calculate reward in this search config, consider passing it in fast_reward" 69 | assert goal_reached is not None, "goal_reached is required to calculate reward in this search config, consider passing it in world model's step" 70 | return (self.calculate_reward(intuition, self_eval, goal_reached), 71 | {'intuition': intuition, 'goal_reached': goal_reached}) 72 | 73 | def update_example(self, example, prompt=None) -> None: 74 | super().update_example(example, prompt=prompt) 75 | -------------------------------------------------------------------------------- /examples/blocksworld/world_model.py: -------------------------------------------------------------------------------- 1 | """The world model for the Blocksworld.""" 2 | 3 | from typing import NamedTuple 4 | import reasoners.benchmark.bw_utils as utils 5 | from reasoners import WorldModel 6 | import copy 7 | 8 | BWAction = str 9 | class BWState(NamedTuple): 10 | """The 
state of the Blocksworld. 11 | 12 | See the docstring of BlocksWorldModel for more details. 13 | """ 14 | step_idx: int 15 | last_blocks_state: str 16 | blocks_state: str 17 | buffered_action: BWAction 18 | history_actions: BWAction = '' 19 | 20 | 21 | class BlocksWorldModel(WorldModel): 22 | """Blocks World World Model 23 | State: (step_idx, last_blocks_state, blocks_state, buffered_action) 24 | Action: e.g. "put the red block on the green block" 25 | Additional notes about the state: 26 | the block state is updated every two actions. When there is a block in hand, 27 | the block state is not updated, but the action is buffered. With the next action, 28 | the block state is updated and the buffer is cleared. 29 | """ 30 | 31 | def __init__(self, 32 | base_model, 33 | prompt: dict, 34 | max_steps: int = 6, 35 | batch_size=2) -> None: 36 | super().__init__() 37 | self.max_steps = max_steps 38 | self.base_model = base_model 39 | self.prompt = prompt 40 | self.batch_size = batch_size 41 | 42 | def init_state(self) -> BWState: 43 | """Initialize the world model. 44 | 45 | :return: the initial state 46 | """ 47 | return BWState(step_idx=0, last_blocks_state="", blocks_state=utils. 48 | extract_init_state(self.example), buffered_action="") 49 | 50 | def step(self, state: BWState, action: BWAction) -> tuple[BWState, dict]: 51 | """Take a step in the world model. 
52 | 53 | :param state: the current state (see the docstring of BlocksWorldModel) 54 | :param action: the action to take 55 | :return: the next state and additional information cached for reward calculation 56 | """ 57 | state = copy.deepcopy(state) 58 | buffered_action = state.buffered_action 59 | blocks_state = state.blocks_state 60 | step_idx = state.step_idx 61 | blocks_state = self.update_blocks(blocks_state, action) 62 | if state.buffered_action == "": 63 | # if no action buffered, buffer the action 64 | new_buffered_action = action 65 | else: 66 | # if action buffered, clear the buffer 67 | new_buffered_action = "" 68 | 69 | state = BWState(step_idx=step_idx+1, last_blocks_state=state.blocks_state, 70 | blocks_state=blocks_state, buffered_action=new_buffered_action, history_actions=state.history_actions + ',' + action) 71 | return state, {"goal_reached": utils.goal_check(utils.extract_goals(self.example), blocks_state)} 72 | 73 | def update_blocks(self, block_states: str, action: BWAction) -> str: 74 | """Update the block states with the action. 75 | 76 | :param block_states: the current block states. 
Note that this argument is a string, 77 | and it's only a part of 'BWState' 78 | :param action: the action to take 79 | :return: the updated block states 80 | """ 81 | if "pick" in action: 82 | key = "world_update_pickup" 83 | elif "unstack" in action: 84 | key = "world_update_unstack" 85 | elif "put" in action: 86 | key = "world_update_putdown" 87 | elif "stack" in action: 88 | key = "world_update_stack" 89 | else: 90 | raise ValueError("Invalid action") 91 | world_update_prompt = self.prompt[key].format(block_states, action.capitalize() + ".") 92 | world_output = self.base_model.generate([world_update_prompt], 93 | stop="\n", hide_input=True, temperature=0).text[0].strip() 94 | new_state = utils.apply_change(world_output, block_states) 95 | return new_state 96 | 97 | def is_terminal(self, state: BWState) -> bool: 98 | if utils.goal_check(utils.extract_goals(self.example), state.blocks_state)[0]: 99 | return True 100 | elif state.step_idx == self.max_steps: 101 | return True 102 | return False 103 | -------------------------------------------------------------------------------- /examples/gsm8k/__init__.py: -------------------------------------------------------------------------------- 1 | from world_model import GSM8kState, GSM8kAction, GSM8kWorldModel 2 | from search_config import GSM8kConfig 3 | -------------------------------------------------------------------------------- /examples/gsm8k/aggregate.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from typing import Optional 3 | import glob 4 | import os 5 | 6 | from tqdm import tqdm 7 | from datasets import load_dataset 8 | 9 | from reasoners.algorithm import MCTSAggregation, MCTSResult 10 | 11 | import utils 12 | 13 | 14 | def aggregate_rap_gsm8k(log_dir: str, 15 | start: int = 0): 16 | aggregator = MCTSAggregation(utils.retrieve_answer, weight_policy='edge') 17 | files = glob.glob(f'{log_dir}/algo_output/*.pkl') 18 | indices = sorted(filter(lambda 
index: index >= start, (int(os.path.basename(f)[:-4]) for f in files))) 19 | dataset = load_dataset("gsm8k", "main", split=f'test') 20 | correct_count = 0 21 | for i, index in enumerate(tqdm(indices)): 22 | with open(f'{log_dir}/algo_output/{index}.pkl', 'rb') as f: 23 | result: MCTSResult = pickle.load(f) 24 | output = aggregator(result.tree_state) 25 | # output = utils.retrieve_answer(result.terminal_state) 26 | answer = utils.retrieve_answer_from_dataset(dataset[index - 1]['answer']) 27 | correct = utils.judge_answer(output, answer) 28 | 29 | correct_count += correct 30 | accuracy = correct_count / (i + 1) 31 | log_str = f'Case #{i + 1}({index}): {correct=}, {output=}, {answer=} ; {accuracy=:.3f} ({correct_count}/{i+1})' 32 | tqdm.write(log_str) 33 | 34 | 35 | if __name__ == '__main__': 36 | import fire 37 | fire.Fire(aggregate_rap_gsm8k) 38 | -------------------------------------------------------------------------------- /examples/gsm8k/bfs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) 5 | sys.path.insert(1, path) 6 | 7 | from typing import Type, Callable, Optional, Literal 8 | 9 | from reasoners.benchmark import GSM8KEvaluator 10 | 11 | from reasoners import Reasoner, SearchAlgorithm 12 | from reasoners.algorithm import BFSNode, BFS 13 | from reasoners import VLLMModel 14 | from world_model import GSM8kWorldModel, GSM8kState, GSM8kAction, GSM8kPromptDict 15 | from search_config import GSM8kConfig, GSM8kUsefulPrompt 16 | import utils 17 | 18 | 19 | def node_visualizer(x: BFSNode[GSM8kState, GSM8kAction]): 20 | if not x.state: 21 | return {} 22 | return {"question": x.state[-1].sub_question, "answer": x.state[-1].sub_answer} 23 | 24 | def rap_gsm8k(base_model, 25 | prompt: GSM8kPromptDict, 26 | useful_prompt: GSM8kUsefulPrompt, 27 | search_algo: Type[SearchAlgorithm] = BFS, 28 | resume: int = 0, 29 | 
n_action: int = 4, 30 | n_confidence: int = 8, 31 | depth_limit: int = 5, 32 | force_terminating_on_depth_limit: bool = True, 33 | batch_size: int = 4, 34 | temperature: float = 0.8, 35 | early_stop_base: int = 2, 36 | early_stop_threshold: float = 0.5, 37 | reward_alpha: float = 0.5, 38 | reward_confidence_default: float = 0.8, 39 | log_dir: Optional[str] = None, 40 | disable_log: bool = False, 41 | disable_tqdm: bool = False, 42 | samples: int = -1, 43 | split: str = 'test', 44 | **search_algo_params): 45 | 46 | search_algo_params |= {'disable_tqdm': disable_tqdm, 47 | 'node_visualizer': node_visualizer} 48 | 49 | world_model = GSM8kWorldModel(base_model=base_model, 50 | n_confidence=n_confidence, batch_size=batch_size, temperature=temperature, 51 | early_stop_base=early_stop_base, early_stop_threshold=early_stop_threshold) 52 | config = GSM8kConfig(base_model=base_model, useful_prompt=useful_prompt, 53 | n_actions=n_action, batch_size=batch_size, temperature=temperature, 54 | reward_alpha=reward_alpha, reward_confidence_default=reward_confidence_default, 55 | force_terminating_on_depth_limit=force_terminating_on_depth_limit, depth_limit=depth_limit) 56 | 57 | search_algo = search_algo(**search_algo_params) 58 | reasoner = Reasoner(world_model=world_model, search_config=config, search_algo=search_algo) 59 | 60 | evaluator = GSM8KEvaluator(output_extractor=utils.retrieve_answer, 61 | answer_extractor=utils.retrieve_answer_from_dataset, 62 | init_prompt=prompt, 63 | sample_prompt_type="rap", 64 | disable_log=disable_log, 65 | disable_tqdm=disable_tqdm, samples=samples) 66 | 67 | accuracy = evaluator.evaluate(reasoner, num_shot=2, resume=resume, log_dir=log_dir, split=split) 68 | print(accuracy) 69 | 70 | 71 | if __name__ == '__main__': 72 | import os 73 | import sys 74 | import json 75 | import warnings 76 | import fire 77 | 78 | def main(prompt: str, 79 | hf_path: str = 'microsoft/phi-2', 80 | batch_size: int = 1, 81 | useful_prompt: str = 
'prompts/gsm8k/useful_examples.json', 82 | disable_log: bool = False, 83 | disable_tqdm: bool = False, 84 | split: str = 'test', 85 | **kwargs): 86 | 87 | with open(useful_prompt) as f: 88 | useful_prompt = json.load(f) 89 | with open(prompt) as f: 90 | prompt = json.load(f) 91 | base_model = VLLMModel(model=hf_path) 92 | rap_gsm8k(base_model=base_model, 93 | useful_prompt=useful_prompt, 94 | prompt=prompt, 95 | batch_size=batch_size, 96 | disable_log=disable_log, 97 | disable_tqdm=disable_tqdm, 98 | split=split, 99 | **kwargs) 100 | 101 | 102 | fire.Fire(main) 103 | -------------------------------------------------------------------------------- /examples/gsm8k/cot.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) 5 | sys.path.insert(1, path) 6 | 7 | import json 8 | from reasoners.benchmark import GSM8KEvaluator 9 | from reasoners import VLLMModel 10 | 11 | import utils 12 | import fire 13 | 14 | class CoTReasoner(): 15 | def __init__(self, base_model, n_sc=1, temperature=0.8, bs=10): 16 | assert n_sc == 1 or temperature > 0, \ 17 | "Temperature = 0 indicates greedy decoding. 
There is no point running multiple chains (n_sc > 1)" 18 | self.base_model = base_model 19 | self.temperature = temperature 20 | self.n_sc = n_sc 21 | self.bs = bs 22 | def __call__(self, example, prompt=None): 23 | inputs = prompt["cot"].replace("{QUESTION}", example) 24 | outputs = [] 25 | for i in range((self.n_sc - 1) // self.bs + 1): 26 | local_bs = min(self.bs, self.n_sc - i * self.bs) 27 | outputs += self.base_model.generate([inputs] * local_bs, 28 | do_sample=True, 29 | temperature=self.temperature, 30 | stop='\n').text 31 | return [o.strip() for o in outputs] 32 | 33 | def main(hf_path, prompt, batch_size=1, resume=0, log_dir=None, temperature=0.8, n_sc=1, split='test'): 34 | 35 | 36 | base_model = VLLMModel(model=hf_path) 37 | 38 | with open(prompt) as f: 39 | prompt = json.load(f) 40 | 41 | reasoner = CoTReasoner(base_model, temperature=temperature, n_sc=n_sc, bs=batch_size) 42 | evaluator = GSM8KEvaluator( 43 | output_extractor=utils.cot_sc_extractor, 44 | answer_extractor=lambda x: utils.retrieve_answer_from_dataset(x["answer"]), 45 | init_prompt=prompt, # will update dynamically 46 | disable_log=False, 47 | disable_tqdm=False, 48 | sample_prompt_type="cot", 49 | split=split 50 | ) 51 | 52 | accuracy = evaluator.evaluate(reasoner, shuffle_prompt=True, num_shot=4, resume=resume, log_dir=log_dir) 53 | print(f'accuracy: {accuracy:.4f}') 54 | return 0 55 | 56 | if __name__ == '__main__': 57 | fire.Fire(main) -------------------------------------------------------------------------------- /examples/gsm8k/rap.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) 5 | sys.path.insert(1, path) 6 | 7 | from typing import Type, Callable, Optional, Literal 8 | 9 | import numpy as np 10 | 11 | from reasoners.benchmark import GSM8KEvaluator 12 | 13 | from reasoners import Reasoner, SearchAlgorithm 14 | from 
reasoners.algorithm import MCTS, MCTSNode, MCTSAggregation 15 | 16 | from world_model import GSM8kWorldModel, GSM8kState, GSM8kAction, GSM8kPromptDict 17 | from search_config import GSM8kConfig, GSM8kUsefulPrompt 18 | import utils 19 | 20 | 21 | def node_visualizer(x: MCTSNode[GSM8kState, GSM8kAction]): 22 | if not x.state: 23 | return {} 24 | return {"question": x.state[-1].sub_question, "answer": x.state[-1].sub_answer} 25 | 26 | def rap_gsm8k(base_model, 27 | prompt: GSM8kPromptDict, 28 | useful_prompt: GSM8kUsefulPrompt, 29 | search_algo: Type[SearchAlgorithm] = MCTS, 30 | resume: int = 0, 31 | n_action: int = 4, 32 | n_confidence: int = 8, 33 | depth_limit: int = 5, 34 | force_terminating_on_depth_limit: bool = True, 35 | batch_size: int = 4, 36 | temperature: float = 0.8, 37 | early_stop_base: int = 2, 38 | early_stop_threshold: float = 0.5, 39 | reward_alpha: float = 0.5, 40 | reward_confidence_default: float = 0.8, 41 | cum_reward: Callable[[list[float]], float] = np.mean, 42 | calc_q: Callable[[list[float]], float] = max, 43 | log_dir: Optional[str] = None, 44 | disable_log: bool = False, 45 | disable_tqdm: bool = False, 46 | output_trace_in_each_iter: bool = True, 47 | aggregate: bool = True, 48 | samples: int = -1, 49 | split: str = 'test', 50 | **search_algo_params): 51 | 52 | if aggregate: 53 | aggregator = MCTSAggregation(utils.retrieve_answer, weight_policy='edge') 54 | else: 55 | aggregator = None 56 | 57 | search_algo_params |= {'cum_reward': cum_reward, 'calc_q': calc_q, 'disable_tqdm': disable_tqdm, 58 | 'output_trace_in_each_iter': output_trace_in_each_iter, 59 | 'node_visualizer': node_visualizer, 'aggregator': aggregator} 60 | world_model = GSM8kWorldModel(base_model=base_model, 61 | n_confidence=n_confidence, batch_size=batch_size, temperature=temperature, 62 | early_stop_base=early_stop_base, early_stop_threshold=early_stop_threshold) 63 | config = GSM8kConfig(base_model=base_model, useful_prompt=useful_prompt, 64 | n_actions=n_action, 
batch_size=batch_size, temperature=temperature, 65 | reward_alpha=reward_alpha, reward_confidence_default=reward_confidence_default, 66 | force_terminating_on_depth_limit=force_terminating_on_depth_limit, depth_limit=depth_limit) 67 | search_algo = search_algo(**search_algo_params) 68 | reasoner = Reasoner(world_model=world_model, search_config=config, search_algo=search_algo) 69 | 70 | evaluator = GSM8KEvaluator(output_extractor=utils.retrieve_answer, 71 | answer_extractor=utils.retrieve_answer_from_dataset, 72 | init_prompt=prompt, 73 | sample_prompt_type="rap", 74 | disable_log=disable_log, 75 | disable_tqdm=disable_tqdm, samples=samples, split=split) 76 | 77 | accuracy = evaluator.evaluate(reasoner, num_shot=2, resume=resume, log_dir=log_dir) 78 | print(accuracy) 79 | 80 | 81 | if __name__ == '__main__': 82 | import json 83 | import fire 84 | from reasoners import VLLMModel 85 | 86 | 87 | def main(prompt: str, 88 | hf_path: str = 'microsoft/phi-2', 89 | batch_size: int = 1, 90 | useful_prompt: str = 'prompts/gsm8k/useful_examples.json', 91 | disable_log: bool = False, 92 | disable_tqdm: bool = False, 93 | split: str = 'test', 94 | **kwargs): 95 | 96 | with open(useful_prompt) as f: 97 | useful_prompt = json.load(f) 98 | with open(prompt) as f: 99 | prompt = json.load(f) 100 | 101 | base_model = VLLMModel(model=hf_path) 102 | 103 | rap_gsm8k(base_model=base_model, 104 | useful_prompt=useful_prompt, 105 | prompt=prompt, 106 | batch_size=batch_size, 107 | disable_log=disable_log, 108 | disable_tqdm=disable_tqdm, 109 | split=split, 110 | **kwargs) 111 | 112 | 113 | fire.Fire(main) 114 | -------------------------------------------------------------------------------- /examples/gsm8k/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Optional, Union 3 | 4 | from reasoners.base import AlgorithmOutput 5 | from collections import Counter 6 | 7 | def retrieve_answer(output: Union[list, str, 
AlgorithmOutput]) -> Optional[str]: 8 | ''' 9 | output should be a world_model.GSM8kState if being a list 10 | ''' 11 | if isinstance(output, AlgorithmOutput): 12 | if (result := getattr(output, 'aggregated_result', None)) is not None: 13 | return result 14 | output = output.terminal_state 15 | if isinstance(output, list): 16 | output = output[-1].sub_answer 17 | match = re.match(r'.*The answer is .*?([ $.0-9,\-=]+).*\..*', output) 18 | if match is None: 19 | return None 20 | answer = match[1].replace(',', '').replace('$', '').replace(' ', '') 21 | if '=' in answer: 22 | answer = answer[answer.rindex('=') + 1:] 23 | return answer 24 | 25 | 26 | def retrieve_answer_from_dataset(answer: Union[str, dict]) -> str: 27 | if isinstance(answer, dict): 28 | answer = answer['answer'] 29 | return re.match(r'[\S\s]*#### (.*)$', answer)[1].replace(',', '').replace(' ', '') 30 | 31 | 32 | def judge_answer(output: Optional[str], answer: str) -> bool: 33 | if output is None: 34 | return False 35 | try: 36 | output = int(output) 37 | answer = int(answer) 38 | return output == answer 39 | except ValueError: 40 | pass 41 | try: 42 | output = float(output) 43 | answer = float(answer) 44 | return output == answer 45 | except ValueError: 46 | pass 47 | return output == answer 48 | 49 | def retrieve_answer(output: Union[list, str]) -> Optional[str]: 50 | ''' 51 | output should be a world_model.GSM8kState if being a list 52 | ''' 53 | if isinstance(output, AlgorithmOutput): 54 | if (result := getattr(output, 'aggregated_result', None)) is not None: 55 | return result 56 | output = output.terminal_state 57 | if isinstance(output, list): 58 | output = output[-1].sub_answer 59 | match = re.match(r'.*[Tt]he answer is .*?([ $.0-9,\-]+).*\..*', output) 60 | if match is None: 61 | return None 62 | answer = match[1].replace(',', '').replace('$', '').replace(' ', '') 63 | if '=' in answer: 64 | answer = answer[answer.rindex('=') + 1:] 65 | return answer 66 | 67 | 68 | 69 | def 
rap_extractor(algo_output, aggregate=True): 70 | 71 | from reasoners.algorithm import MCTSAggregation 72 | if aggregate: 73 | aggregator = MCTSAggregation(retrieve_answer, weight_policy='edge_inverse_depth') 74 | output = aggregator(algo_output.tree_state) 75 | else: 76 | if algo_output.terminal_state is None: 77 | output = None 78 | else: 79 | output = retrieve_answer(algo_output.terminal_state) 80 | return output 81 | 82 | def cot_sc_extractor(algo_output, sc=True): 83 | # aggregate the results from multiple reasoning chains with majority vote 84 | answers = [retrieve_answer(x) for x in algo_output] 85 | answers = [x for x in answers if x is not None] 86 | if len(answers) == 0: 87 | return '' 88 | counter = Counter(answers) 89 | return counter.most_common(1)[0][0] -------------------------------------------------------------------------------- /examples/gsm8k/world_model.py: -------------------------------------------------------------------------------- 1 | import io 2 | from typing import NamedTuple, TypedDict 3 | from collections import defaultdict 4 | from reasoners import WorldModel 5 | import utils 6 | from reasoners.base import Example 7 | 8 | 9 | class SubResult(NamedTuple): 10 | sub_question: str 11 | sub_answer: str 12 | confidence: float 13 | 14 | 15 | GSM8kState = list[SubResult] 16 | GSM8kAction = str 17 | GSM8kExample = str 18 | 19 | 20 | class GSM8kPromptDict(TypedDict): 21 | instruction: str 22 | interactive_examples: list[str] 23 | useful_examples: list[str] 24 | question_prefix: str 25 | subquestion_prefix: str 26 | overall_question_prefix: str 27 | answer_prefix: str 28 | 29 | 30 | class GSM8kWorldModel(WorldModel[GSM8kState, GSM8kAction, GSM8kExample]): 31 | """ 32 | GSM8k World Model 33 | State: [[sub_question_1, sub_answer_1, confidence_1], [sub_question_2, sub_answer_2, confidence_2], ...] 
34 | Action: sub_question 35 | """ 36 | 37 | def __init__(self, 38 | base_model, 39 | n_confidence=8, 40 | batch_size=2, 41 | temperature=0.8, 42 | top_k=50, 43 | top_p=0.95, 44 | early_stop_base=None, 45 | early_stop_threshold=1.) -> None: 46 | super().__init__() 47 | self.base_model = base_model 48 | self.batch_size = batch_size 49 | self.n_confidence = n_confidence 50 | self.temperature = temperature 51 | self.early_stop_base = early_stop_base if early_stop_base is not None else n_confidence 52 | self.early_stop_threshold = early_stop_threshold 53 | self.prompt_examples = "" 54 | self.n_shots = 0 55 | self.top_k = top_k 56 | self.top_p = top_p 57 | 58 | def update_example(self, example: Example, prompt: GSM8kPromptDict = None) -> None: 59 | super().update_example(example, prompt) 60 | assert prompt is not None 61 | self.prompt = prompt 62 | with io.StringIO() as f: 63 | f.write(self.prompt['instruction'] + '\n\n') 64 | for idx, example in enumerate(self.prompt['interactive_examples']): 65 | f.write(example.format(idx=idx + 1) + '\n\n') 66 | self.n_shots = len(self.prompt['interactive_examples']) 67 | self.prompt_examples = f.getvalue() 68 | 69 | def init_state(self) -> list: 70 | return [] 71 | 72 | def step(self, state: GSM8kState, action: GSM8kAction) -> tuple[GSM8kState, dict]: 73 | state = state.copy() 74 | with io.StringIO() as f: 75 | f.write(self.prompt_examples) 76 | f.write(self.prompt["question_prefix"].format(idx=self.n_shots + 1, question=self.example) + "\n") 77 | for idx, (q, a, _) in enumerate(state): 78 | f.write( 79 | self.prompt["subquestion_prefix"].format(idx=self.n_shots + 1, sub_idx=idx + 1) + " " + q + "\n") 80 | f.write(self.prompt["answer_prefix"].format(idx=self.n_shots + 1, sub_idx=idx + 1) + " " + a + "\n") 81 | f.write(self.prompt["subquestion_prefix"].format(idx=self.n_shots + 1, 82 | sub_idx=len(state) + 1) + " " + action + "\n") 83 | f.write(self.prompt["answer_prefix"].format(idx=self.n_shots + 1, sub_idx=len(state) + 1)) 84 | 
model_input = f.getvalue() 85 | 86 | answer_dict = defaultdict(list) # map from answer to list of thoughts 87 | result = "" 88 | for start1 in range(0, self.n_confidence, self.early_stop_base): 89 | stop1 = min(start1 + self.early_stop_base, self.n_confidence) 90 | 91 | for start in range(start1, stop1, self.batch_size): 92 | stop = min(start + self.batch_size, stop1) 93 | num = stop - start 94 | 95 | outputs = self.base_model.generate(model_input, 96 | temperature=self.temperature, 97 | num_return_sequences=num, 98 | stop='\n').text 99 | for output in outputs: 100 | result = output.strip() 101 | answer = utils.retrieve_answer(result) 102 | if answer is not None: 103 | answer_dict[answer].append(result) 104 | 105 | # Early stop if confidence is high enough 106 | if len(answer_dict) == 0: # no answer yet 107 | continue 108 | 109 | sorted_answer_dict = sorted(answer_dict.items(), key=lambda p: len(p[1]), reverse=True) 110 | max_len = len(sorted_answer_dict[0][1]) 111 | if max_len / stop1 >= self.early_stop_threshold: 112 | if len(sorted_answer_dict) >= 2 and max_len == len(sorted_answer_dict[1][1]): 113 | pass # Tie with the second best answer 114 | else: 115 | break 116 | 117 | if len(answer_dict) == 0: 118 | print("Warning: no answer found") 119 | confidence, answer = 0, result # No reasonable answer found. 
Fall back to choose the last response 120 | else: 121 | sorted_answer_dict = sorted(answer_dict.items(), key=lambda p: len(p[1]), reverse=True) 122 | max_answer = sorted_answer_dict[0] 123 | max_answer_output_list = max_answer[1] 124 | max_len = len(max_answer_output_list) 125 | answer = max_answer_output_list[0] # Here we simply choose the first appearance of the answer 126 | confidence = max_len / sum(len(v) for v in answer_dict.values()) 127 | 128 | state.append(SubResult(action, answer, confidence)) 129 | aux = {'confidence': confidence} 130 | return state, aux 131 | 132 | def is_terminal(self, state: GSM8kState) -> bool: 133 | if len(state) > 0 and "Now we can answer" in state[-1].sub_question: 134 | return True 135 | else: 136 | return False 137 | -------------------------------------------------------------------------------- /gsm8k_control.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | from typing import Literal, Optional 5 | from dataclasses import dataclass 6 | from datetime import datetime 7 | 8 | model_ip = 'http://10.234.38.2:23100/v1' 9 | 10 | prompt_path = { 11 | 'default': 'prompts/gsm8k/prompt_pool.json', 12 | 'rot': 'prompts/gsm8k/prompt_pool_rot.json', 13 | 'rot-2-iter': 'prompts/gsm8k/prompt_pool_rot_2_iter.json', 14 | 'rot-3-iter': 'prompts/gsm8k/prompt_pool_rot_3_iter.json', 15 | 'rot-4-iter': 'prompts/gsm8k/prompt_pool_rot_4_iter.json', 16 | 'rot-5-iter': 'prompts/gsm8k/prompt_pool_rot_5_iter.json', 17 | 'rot-6-iter': 'prompts/gsm8k/prompt_pool_rot_6_iter.json', 18 | '0.5': 'prompts/gsm8k/prompt_pool_rot_promising_0.5.json', 19 | 'all': 'prompts/gsm8k/prompt_pool_rot_all.json', 20 | 'random': 'prompts/gsm8k/prompt_pool_rot_random.json', 21 | 'cot-default': 'prompts/gsm8k/cot_default.json', 22 | 'cot-rot': 'prompts/gsm8k/cot_rot.json', 23 | } 24 | 25 | if __name__ == '__main__': 26 | parser = argparse.ArgumentParser() 27 | 
parser.add_argument('--prompt', type=str, default='default') 28 | parser.add_argument('--model', type=str, default='phi-2') 29 | parser.add_argument('--mode', type=str, default='mcts') 30 | parser.add_argument('--n_iters', type=int, default=-1) 31 | parser.add_argument('--split', type=str, default='test') 32 | args = parser.parse_args() 33 | 34 | os.environ['VLLM_API_BASE'] = model_ip 35 | 36 | if args.mode == 'cot': 37 | command = f'python examples/gsm8k/cot.py --hf_path {args.model}' 38 | elif args.mode == 'bfs': 39 | command = f'python examples/gsm8k/bfs.py --hf_path {args.model} --width {args.n_iters}' 40 | elif args.mode == 'mcts': 41 | command = f'python examples/gsm8k/rap.py --hf_path {args.model}' 42 | if args.n_iters != -1: 43 | command += f' --n_iters {args.n_iters}' 44 | 45 | command += f' --prompt {prompt_path[args.prompt]} --split {args.split}' 46 | 47 | log_path = f'logs/gsm8k/{args.model}_{args.prompt}_{args.mode}_{datetime.now().strftime("%Y%m%d-%H%M%S")}' 48 | 49 | command += f' --log_dir {log_path}' 50 | print(command) 51 | 52 | import time 53 | t = time.time() 54 | os.system(command) 55 | time_consumed = time.time() - t 56 | with open(f'{log_path}/time_consumed.txt', 'w') as f: 57 | f.write(str(time_consumed)) -------------------------------------------------------------------------------- /gsm8k_rot.sh: -------------------------------------------------------------------------------- 1 | output_name=$1 2 | 3 | outputs=$(python gsm8k_control.py --mode mcts --n_iter 3 --split train) 4 | log_dir=$(echo $outputs | grep -oP 'log_dir [^ ]*' | cut -d' ' -f2) 5 | 6 | python rot_scripts/gsm8k_analysis.py --path $log_dir/algo_output --output_name $log_dir/rot_analysis.json 7 | python rot_scripts/gsm8k_generate_rot_prompt.py --path $log_dir/rot_analysis.json --output_name $output_name --mode mcts -------------------------------------------------------------------------------- /gsm8k_summarization.json: 
-------------------------------------------------------------------------------- 1 | [] -------------------------------------------------------------------------------- /prompts/gsm8k/cot_default.json: -------------------------------------------------------------------------------- 1 | { 2 | "cot_pool": 3 | [ 4 | "Q: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\nA: Natalia sold 48 clips in April and half as many clips in May, so she sold 48 / 2 = 24 clips in May. Altogether, she sold 48 + 24 = 72 clips. The answer is 72.\n\n", 5 | "Q: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?\nA: Since Weng earns $12 an hour for babysitting, she earns $12 / 60 = $0.2 per minute. Working 50 minutes, she earned $0.2 x 50 = $10. The answer is 10.\n\n", 6 | "Q: Betty is saving money for a new wallet which costs $100. Betty has only half of the money she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?\nA: In the beginning, Betty has only half of the money she needs, which is 100 / 2 = $50. Her grandparents gave her twice as much as her parents, so they gave her 15 * 2 = $30. Now that she got $15 from her parents and $30 from her grandparents, she will need $100 - $15 - $30 = $55. Since she already has $50, she needs $55 - $50 = $5 more. The answer is 5.\n\n", 7 | "Q: Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?\nA: Julie read twice as many pages as yesterday, so she read 12 * 2 = 24 pages today. Since yesterday, Julie read 12 + 24 = 36 pages. So, there are 120 - 36 = 84 pages left to be read. 
Since she wants to read half of the remaining pages, she should read 84 / 2 = 42 pages. The answer is 42.\n\n", 8 | "Q: James writes a 3-page letter to 2 different friends twice a week. How many pages does he write a year?\nA: James writes a 3-page letter to 2 different friends twice a week, so he writes 3 * 2 * 2 = 12 pages every week. There are 52 weeks in a year, so he writes 12 * 52 = 624 pages a year. The answer is 624.\n\n", 9 | "Q: Mark has a garden with flowers. He planted plants of three different colors in it. Ten of them are yellow, and there are 80% more of those in purple. There are only 25% as many green flowers as there are yellow and purple flowers. How many flowers does Mark have in his garden?\nA: There are 80% more purple flowers than yellow flowers, so there are 10 * 1.8 = 18 purple flowers. Because there are 10 yellow flowers, so there are 10 + 18 = 28 yellow and purple flowers. There are 25% as many green flowers as there are yellow and purple flowers, so there are 28 * 0.25 = 7 green flowers. In summary, he has 10 + 18 + 7 = 35 flowers in his garden. The answer is 35.\n\n", 10 | "Q: Albert is wondering how much pizza he can eat in one day. He buys 2 large pizzas and 2 small pizzas. A large pizza has 16 slices and a small pizza has 8 slices. If he eats it all, how many pieces does he eat that day?\nA: He buys 2 large pizzas, so he has 2 * 16 = 32 slices. He buys 2 small pizzas, so he has 2 * 8 = 16 slices. There are 32 slices from the large pizzas and 16 slices from the small pizzas, so he eats 32 + 16 = 48 pieces that day. The answer is 48.\n\n", 11 | "Q: Ken created a care package to send to his brother, who was away at boarding school. Ken placed a box on a scale, and then he poured into the box enough jelly beans to bring the weight to 2 pounds. Then, he added enough brownies to cause the weight to triple. Next, he added another 2 pounds of jelly beans. And finally, he added enough gummy worms to double the weight once again. 
What was the final weight of the box of goodies, in pounds?\nA: Ken poured jelly beans into the box until the weight was 2 pounds, so the weight of the box was 2 pounds at first. Then Ken added enough brownies to cause the weight to triple, so the weight of the box was 2 * 3 = 6 pounds. After Ken added another 2 pounds of jelly beans, the weight of the box was 6 + 2 = 8 pounds. Finally, he added enough gummy worms to double the weight once again, so the weight of the box was 8 * 2 = 16 pounds. The answer is 16.\n\n", 12 | "Q: Alexis is applying for a new job and bought a new set of business clothes to wear to the interview. She went to a department store with a budget of $200 and spent $30 on a button-up shirt, $46 on suit pants, $38 on a suit coat, $11 on socks, and $18 on a belt. She also purchased a pair of shoes, but lost the receipt for them. She has $16 left from her budget. How much did Alexis pay for the shoes?\nA: Alexis spent $30 on a button-up shirt, $46 on suit pants, $38 on a suit coat, $11 on socks, and $18 on a belt, so she spent 30 + 46 + 38 + 11 + 18 = $143 on everything else. Alexis had a budget of $200 and finally there was $16 left, so she spent 200 - 16 = $184 in total. Since Alexis has spent $143 on everything else, she spent 184 - 143 = $41 on the shoes. The answer is 41.\n\n", 13 | "Q: Tina makes $18.00 an hour. If she works more than 8 hours per shift, she is eligible for overtime, which is paid by your hourly wage + 1/2 your hourly wage. If she works 10 hours every day for 5 days, how much money does she make?\nA: Tina makes $18.00 an hour, so she makes 18 * 8 = $144.00 in an 8-hour shift. Tina works 10 hours every day for 5 days, so she works 10 * 5 = 50 hours. Since she works 8 hours every day, she gets 50 - 8 * 5 = 10 hours of overtime. Her hourly overtime wage is 18 + 18 / 2 = $27.00. Tina works 10 hours a day, and 8 hours of that is paid at her regular hourly wage, so she makes 10 - 8 = 2 hours of overtime every day. 
Since her hourly overtime wage is $27.00, she makes 27 * 2 = $54.00 in overtime each day. Every day, tina makes $144.00 in an 8-hour shift and $54.00 in overtime, so she makes 144 + 54 = $198.00 each day. Tina works 5 days a week, so she makes 198 * 5 = $990.00. The answer is 990.\n\n" 14 | ], 15 | "prefix": "Q: {QUESTION}\nA:" 16 | } -------------------------------------------------------------------------------- /prompts/gsm8k/cot_rot.json: -------------------------------------------------------------------------------- 1 | { 2 | "guidelines": "To solve math word problems effectively and minimize errors, integrate the various suggested policies into the following comprehensive approach:\n\n1. Clarify the quantities involved, distinguishing between total amounts, unit costs, and multiplicative factors. Use correct mathematical operations based on these relationships.\n\n2. When necessary, accurately convert between units (e.g., minutes to hours) using proper mathematical operations, and maintain consistency in units throughout the problem.\n\n3. Directly translate the word problem’s conditions into mathematical equations or expressions, and apply correct mathematical operations.\n\n4. Address each sub-question in a logical order, one at a time, methodically performing operations and maintaining at least three decimal places of precision for operations with hourly rates or similar calculations.\n\n5. When dealing with ratios or relational costs, ensure the correct base amounts are used for calculations.\n\n6. Sequentially solve the sub-questions, constantly cross-referencing with the problem's conditions, and checking that each sub-answer is logical and consistent with the overall scenario.\n\n7. Refrain from rounding intermediate results to maintain accuracy, and only round the final result if necessary, ensuring the level of precision matches that of the given variables.\n\n8. 
Regularly ensure that all units and logical consistencies are maintained throughout the solution process, avoiding impossible situations (such as fractions of indivisible items).\n\n9. Double-check each step against the provided information and common sense narrative to avoid repetition of errors or misinterpretation.\n\n10. Before finalizing the answer, confirm that each part of the problem has been addressed and that the math operations have been applied correctly. Validate each sub-answer and the final answer, ensuring that they make sense and address the question accurately and fully.\n\n11. After combining sub-answers for the final solution, question its reasonableness in the context of the problem, verifying if quantities add up correctly and whether the overall solution is plausible.", 3 | "cot_pool": 4 | [ 5 | "Q: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\nA: Natalia sold 48 clips in April and half as many clips in May, so she sold 48 / 2 = 24 clips in May. Altogether, she sold 48 + 24 = 72 clips. The answer is 72.\n\n", 6 | "Q: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?\nA: Since Weng earns $12 an hour for babysitting, she earns $12 / 60 = $0.2 per minute. Working 50 minutes, she earned $0.2 x 50 = $10. The answer is 10.\n\n", 7 | "Q: Betty is saving money for a new wallet which costs $100. Betty has only half of the money she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?\nA: In the beginning, Betty has only half of the money she needs, which is 100 / 2 = $50. Her grandparents gave her twice as much as her parents, so they gave her 15 * 2 = $30. Now that she got $15 from her parents and $30 from her grandparents, she will need $100 - $15 - $30 = $55. 
Since she already has $50, she needs $55 - $50 = $5 more. The answer is 5.\n\n", 8 | "Q: Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?\nA: Julie read twice as many pages as yesterday, so she read 12 * 2 = 24 pages today. Since yesterday, Julie read 12 + 24 = 36 pages. So, there are 120 - 36 = 84 pages left to be read. Since she wants to read half of the remaining pages, she should read 84 / 2 = 42 pages. The answer is 42.\n\n", 9 | "Q: James writes a 3-page letter to 2 different friends twice a week. How many pages does he write a year?\nA: James writes a 3-page letter to 2 different friends twice a week, so he writes 3 * 2 * 2 = 12 pages every week. There are 52 weeks in a year, so he writes 12 * 52 = 624 pages a year. The answer is 624.\n\n", 10 | "Q: Mark has a garden with flowers. He planted plants of three different colors in it. Ten of them are yellow, and there are 80% more of those in purple. There are only 25% as many green flowers as there are yellow and purple flowers. How many flowers does Mark have in his garden?\nA: There are 80% more purple flowers than yellow flowers, so there are 10 * 1.8 = 18 purple flowers. Because there are 10 yellow flowers, so there are 10 + 18 = 28 yellow and purple flowers. There are 25% as many green flowers as there are yellow and purple flowers, so there are 28 * 0.25 = 7 green flowers. In summary, he has 10 + 18 + 7 = 35 flowers in his garden. The answer is 35.\n\n", 11 | "Q: Albert is wondering how much pizza he can eat in one day. He buys 2 large pizzas and 2 small pizzas. A large pizza has 16 slices and a small pizza has 8 slices. If he eats it all, how many pieces does he eat that day?\nA: He buys 2 large pizzas, so he has 2 * 16 = 32 slices. He buys 2 small pizzas, so he has 2 * 8 = 16 slices. 
There are 32 slices from the large pizzas and 16 slices from the small pizzas, so he eats 32 + 16 = 48 pieces that day. The answer is 48.\n\n", 12 | "Q: Ken created a care package to send to his brother, who was away at boarding school. Ken placed a box on a scale, and then he poured into the box enough jelly beans to bring the weight to 2 pounds. Then, he added enough brownies to cause the weight to triple. Next, he added another 2 pounds of jelly beans. And finally, he added enough gummy worms to double the weight once again. What was the final weight of the box of goodies, in pounds?\nA: Ken poured jelly beans into the box until the weight was 2 pounds, so the weight of the box was 2 pounds at first. Then Ken added enough brownies to cause the weight to triple, so the weight of the box was 2 * 3 = 6 pounds. After Ken added another 2 pounds of jelly beans, the weight of the box was 6 + 2 = 8 pounds. Finally, he added enough gummy worms to double the weight once again, so the weight of the box was 8 * 2 = 16 pounds. The answer is 16.\n\n", 13 | "Q: Alexis is applying for a new job and bought a new set of business clothes to wear to the interview. She went to a department store with a budget of $200 and spent $30 on a button-up shirt, $46 on suit pants, $38 on a suit coat, $11 on socks, and $18 on a belt. She also purchased a pair of shoes, but lost the receipt for them. She has $16 left from her budget. How much did Alexis pay for the shoes?\nA: Alexis spent $30 on a button-up shirt, $46 on suit pants, $38 on a suit coat, $11 on socks, and $18 on a belt, so she spent 30 + 46 + 38 + 11 + 18 = $143 on everything else. Alexis had a budget of $200 and finally there was $16 left, so she spent 200 - 16 = $184 in total. Since Alexis has spent $143 on everything else, she spent 184 - 143 = $41 on the shoes. The answer is 41.\n\n", 14 | "Q: Tina makes $18.00 an hour. 
If she works more than 8 hours per shift, she is eligible for overtime, which is paid by your hourly wage + 1/2 your hourly wage. If she works 10 hours every day for 5 days, how much money does she make?\nA: Tina makes $18.00 an hour, so she makes 18 * 8 = $144.00 in an 8-hour shift. Tina works 10 hours every day for 5 days, so she works 10 * 5 = 50 hours. Since she works 8 hours every day, she gets 50 - 8 * 5 = 10 hours of overtime. Her hourly overtime wage is 18 + 18 / 2 = $27.00. Tina works 10 hours a day, and 8 hours of that is paid at her regular hourly wage, so she makes 10 - 8 = 2 hours of overtime every day. Since her hourly overtime wage is $27.00, she makes 27 * 2 = $54.00 in overtime each day. Every day, tina makes $144.00 in an 8-hour shift and $54.00 in overtime, so she makes 144 + 54 = $198.00 each day. Tina works 5 days a week, so she makes 198 * 5 = $990.00. The answer is 990.\n\n" 15 | ], 16 | "prefix": "Q: {QUESTION}\nA:" 17 | } -------------------------------------------------------------------------------- /prompts/gsm8k/mistral/cot_rot_mistral.json: -------------------------------------------------------------------------------- 1 | { 2 | "guidelines": "To effectively address mathematical word problems while minimizing errors, follow this comprehensive approach that integrates the recommended guidelines:\n\n1. When necessary, accurately convert between units (e.g., minutes to hours) using precise mathematical procedures and maintain consistency in units throughout the problem.\n2. Directly translate the conditions of the word problem into mathematical equations or expressions and apply accurate mathematical operations.\n3. Address each sub-question in a logical sequence, one at a time, executing operations methodically and ensuring at least three decimal places of precision for tasks involving hourly rates or similar calculations.\n4. When dealing with ratios or relational costs, ensure the correct base amounts are used for computations.\n5. 
Progress through the sub-questions systematically, continually referencing the problem's conditions, and verifying that each sub-answer is logical and aligns with the overall scenario.\n6. Avoid rounding intermediate results to maintain precision, and only round the final outcome if necessary, ensuring the accuracy aligns with the given variables.\n7. Regularly verify that all units and logical consistencies are maintained throughout the solution process, avoiding implausible scenarios (e.g., fractions of indivisible items).\n8. Validate each step against the provided information and common-sense narrative to prevent the recurrence of errors or misinterpretation.\n9. Before finalizing the solution, confirm that each aspect of the problem has been addressed and that mathematical operations have been applied accurately. Validate each sub-answer and the final solution to ensure they are sensible and fully address the question.\n", 3 | "cot_pool": 4 | [ 5 | "Q: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\nA: Natalia sold 48 clips in April and half as many clips in May, so she sold 48 / 2 = 24 clips in May. Altogether, she sold 48 + 24 = 72 clips. The answer is 72.\n\n", 6 | "Q: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?\nA: Since Weng earns $12 an hour for babysitting, she earns $12 / 60 = $0.2 per minute. Working 50 minutes, she earned $0.2 x 50 = $10. The answer is 10.\n\n", 7 | "Q: Betty is saving money for a new wallet which costs $100. Betty has only half of the money she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?\nA: In the beginning, Betty has only half of the money she needs, which is 100 / 2 = $50. 
Her grandparents gave her twice as much as her parents, so they gave her 15 * 2 = $30. Now that she got $15 from her parents and $30 from her grandparents, she will need $100 - $15 - $30 = $55. Since she already has $50, she needs $55 - $50 = $5 more. The answer is 5.\n\n", 8 | "Q: Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?\nA: Julie read twice as many pages as yesterday, so she read 12 * 2 = 24 pages today. Since yesterday, Julie read 12 + 24 = 36 pages. So, there are 120 - 36 = 84 pages left to be read. Since she wants to read half of the remaining pages, she should read 84 / 2 = 42 pages. The answer is 42.\n\n", 9 | "Q: James writes a 3-page letter to 2 different friends twice a week. How many pages does he write a year?\nA: James writes a 3-page letter to 2 different friends twice a week, so he writes 3 * 2 * 2 = 12 pages every week. There are 52 weeks in a year, so he writes 12 * 52 = 624 pages a year. The answer is 624.\n\n", 10 | "Q: Mark has a garden with flowers. He planted plants of three different colors in it. Ten of them are yellow, and there are 80% more of those in purple. There are only 25% as many green flowers as there are yellow and purple flowers. How many flowers does Mark have in his garden?\nA: There are 80% more purple flowers than yellow flowers, so there are 10 * 1.8 = 18 purple flowers. Because there are 10 yellow flowers, so there are 10 + 18 = 28 yellow and purple flowers. There are 25% as many green flowers as there are yellow and purple flowers, so there are 28 * 0.25 = 7 green flowers. In summary, he has 10 + 18 + 7 = 35 flowers in his garden. The answer is 35.\n\n", 11 | "Q: Albert is wondering how much pizza he can eat in one day. He buys 2 large pizzas and 2 small pizzas. A large pizza has 16 slices and a small pizza has 8 slices. 
If he eats it all, how many pieces does he eat that day?\nA: He buys 2 large pizzas, so he has 2 * 16 = 32 slices. He buys 2 small pizzas, so he has 2 * 8 = 16 slices. There are 32 slices from the large pizzas and 16 slices from the small pizzas, so he eats 32 + 16 = 48 pieces that day. The answer is 48.\n\n", 12 | "Q: Ken created a care package to send to his brother, who was away at boarding school. Ken placed a box on a scale, and then he poured into the box enough jelly beans to bring the weight to 2 pounds. Then, he added enough brownies to cause the weight to triple. Next, he added another 2 pounds of jelly beans. And finally, he added enough gummy worms to double the weight once again. What was the final weight of the box of goodies, in pounds?\nA: Ken poured jelly beans into the box until the weight was 2 pounds, so the weight of the box was 2 pounds at first. Then Ken added enough brownies to cause the weight to triple, so the weight of the box was 2 * 3 = 6 pounds. After Ken added another 2 pounds of jelly beans, the weight of the box was 6 + 2 = 8 pounds. Finally, he added enough gummy worms to double the weight once again, so the weight of the box was 8 * 2 = 16 pounds. The answer is 16.\n\n", 13 | "Q: Alexis is applying for a new job and bought a new set of business clothes to wear to the interview. She went to a department store with a budget of $200 and spent $30 on a button-up shirt, $46 on suit pants, $38 on a suit coat, $11 on socks, and $18 on a belt. She also purchased a pair of shoes, but lost the receipt for them. She has $16 left from her budget. How much did Alexis pay for the shoes?\nA: Alexis spent $30 on a button-up shirt, $46 on suit pants, $38 on a suit coat, $11 on socks, and $18 on a belt, so she spent 30 + 46 + 38 + 11 + 18 = $143 on everything else. Alexis had a budget of $200 and finally there was $16 left, so she spent 200 - 16 = $184 in total. Since Alexis has spent $143 on everything else, she spent 184 - 143 = $41 on the shoes. 
The answer is 41.\n\n", 14 | "Q: Tina makes $18.00 an hour. If she works more than 8 hours per shift, she is eligible for overtime, which is paid by your hourly wage + 1/2 your hourly wage. If she works 10 hours every day for 5 days, how much money does she make?\nA: Tina makes $18.00 an hour, so she makes 18 * 8 = $144.00 in an 8-hour shift. Tina works 10 hours every day for 5 days, so she works 10 * 5 = 50 hours. Since she works 8 hours every day, she gets 50 - 8 * 5 = 10 hours of overtime. Her hourly overtime wage is 18 + 18 / 2 = $27.00. Tina works 10 hours a day, and 8 hours of that is paid at her regular hourly wage, so she makes 10 - 8 = 2 hours of overtime every day. Since her hourly overtime wage is $27.00, she makes 27 * 2 = $54.00 in overtime each day. Every day, tina makes $144.00 in an 8-hour shift and $54.00 in overtime, so she makes 144 + 54 = $198.00 each day. Tina works 5 days a week, so she makes 198 * 5 = $990.00. The answer is 990.\n\n" 15 | ], 16 | "prefix": "Q: {QUESTION}\nA:" 17 | } -------------------------------------------------------------------------------- /prompts/gsm8k/mistral/prompt_pool_improved_mistral.json: -------------------------------------------------------------------------------- 1 | { 2 | "instruction": "To effectively address mathematical word problems while minimizing errors, follow this comprehensive approach that integrates the recommended guidelines:\n\n1. When necessary, accurately convert between units (e.g., minutes to hours) using precise mathematical procedures and maintain consistency in units throughout the problem.\n2. Directly translate the conditions of the word problem into mathematical equations or expressions and apply accurate mathematical operations.\n3. Address each sub-question in a logical sequence, one at a time, executing operations methodically and ensuring at least three decimal places of precision for tasks involving hourly rates or similar calculations.\n4. 
When dealing with ratios or relational costs, ensure the correct base amounts are used for computations.\n5. Progress through the sub-questions systematically, continually referencing the problem's conditions, and verifying that each sub-answer is logical and aligns with the overall scenario.\n6. Avoid rounding intermediate results to maintain precision, and only round the final outcome if necessary, ensuring the accuracy aligns with the given variables.\n7. Regularly verify that all units and logical consistencies are maintained throughout the solution process, avoiding implausible scenarios (e.g., fractions of indivisible items).\n8. Validate each step against the provided information and common-sense narrative to prevent the recurrence of errors or misinterpretation.\n9. Before finalizing the solution, confirm that each aspect of the problem has been addressed and that mathematical operations have been applied accurately. Validate each sub-answer and the final solution to ensure they are sensible and fully address the question.\n", 3 | "interactive_examples": [ 4 | "Question {idx}: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\nQuestion {idx}.1: How many clips did Natalia sell in May?\nAnswer 1.1: Natalia sold 48 clips in April and half as many clips in May, so she sold 48 / 2 = 24 clips in May. The answer is 24.\nQuestion {idx}.2: Now we can answer the question: How many clips did Natalia sell altogether in April and May?\nAnswer 1.2: Natalia sold 48 clips in April and 24 clips in May, so altogether she sold 48 + 24 = 72 clips. The answer is 72.", 5 | "Question {idx}: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?\nQuestion {idx}.1: How much does Weng earn per minute?\nAnswer 1.1: Since Weng earns $12 an hour for babysitting, she earns $12 / 60 = $0.2 per minute. 
The answer is 0.2.\nQuestion {idx}.2: Now we can answer the question: How much did she earn?\nAnswer 1.2: Working 50 minutes, she earned $0.2 x 50 = $10. The answer is 10.", 6 | "Question {idx}: Betty is saving money for a new wallet which costs $100. Betty has only half of the money she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?\nQuestion {idx}.1: How much money does Betty have in the beginning?\nAnswer 1.1: In the beginning, Betty has only half of the money she needs, which is 100 / 2 = $50. The answer is 50.\nQuestion {idx}.2: How much money did Betty's grandparents give her?\nAnswer 1.2: Her grandparents gave her twice as much as her parents, so they gave her 15 * 2 = $30. The answer is 30.\nQuestion {idx}.3: Now we can answer the question: How much more money does Betty need to buy the wallet?\nAnswer 1.3: Now that she got $15 from her parents and $30 from her grandparents, she will need $100 - $15 - $30 = $55. Since she already has $50, she needs $55 - $50 = $5 more. The answer is 5.", 7 | "Question {idx}: Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?\nQuestion {idx}.1: How many pages did Julie read today?\nAnswer 1.1: Julie read twice as many pages as yesterday, so she read 12 * 2 = 24 pages. The answer is 24.\nQuestion {idx}.2: How many pages did Julie read since yesterday?\nAnswer 1.2: Since yesterday, Julie read 12 + 24 = 36 pages. The answer is 36.\nQuestion {idx}.3: How many pages are left to be read?\nAnswer 1.3: There are 120 - 36 = 84 pages left to be read. The answer is 84.\nQuestion {idx}.4: Now we can answer the question: How many pages should she read?\nAnswer 1.4: She wants to read half of the remaining pages, so she should read 84 / 2 = 42 pages. 
The answer is 42.", 8 | "Question {idx}: James writes a 3-page letter to 2 different friends twice a week. How many pages does he write a year?\nQuestion {idx}.1: How many pages does he write every week?\nAnswer 1.1: James writes a 3-page letter to 2 different friends twice a week, so he writes 3 * 2 * 2 = 12 pages every week. The answer is 12.\nQuestion {idx}.2: How many weeks are there in a year?\nAnswer 1.2: There are 52 weeks in a year. The answer is 52.\nQuestion {idx}.3: Now we can answer the question: How many pages does he write a year?\nAnswer 1.3: James writes 12 pages every week, so he writes 12 * 52 = 624 pages a year. The answer is 624.", 9 | "Question {idx}: Mark has a garden with flowers. He planted plants of three different colors in it. Ten of them are yellow, and there are 80% more of those in purple. There are only 25% as many green flowers as there are yellow and purple flowers. How many flowers does Mark have in his garden?\nQuestion {idx}.1: How many purple flowers are there?\nAnswer 1.1: There are 80% more purple flowers than yellow flowers, so there are 10 * 1.8 = 18 purple flowers. The answer is 18.\nQuestion {idx}.2: How many yellow and purple flowers are there in total?\nAnswer 1.2: There are 10 yellow flowers and 18 purple flowers, so there are 10 + 18 = 28 yellow and purple flowers. The answer is 28.\nQuestion {idx}.3: How many green flowers are there?\nAnswer 1.3: There are 25% as many green flowers as there are yellow and purple flowers, so there are 28 * 0.25 = 7 green flowers. The answer is 7.\nQuestion {idx}.4: Now we can answer the question: How many flowers does Mark have in his garden?\nAnswer 1.4: Mark has 10 yellow flowers, 18 purple flowers, and 7 green flowers, so he has 10 + 18 + 7 = 35 flowers in his garden. The answer is 35.", 10 | "Question {idx}: Albert is wondering how much pizza he can eat in one day. He buys 2 large pizzas and 2 small pizzas. A large pizza has 16 slices and a small pizza has 8 slices. 
If he eats it all, how many pieces does he eat that day?\nQuestion {idx}.1: How many slices do the large pizzas have?\nAnswer 1.1: He buys 2 large pizzas, so he has 2 * 16 = 32 slices. The answer is 32.\nQuestion {idx}.2: How many slices do the small pizzas have?\nAnswer 1.2: He buys 2 small pizzas, so he has 2 * 8 = 16 slices. The answer is 16.\nQuestion {idx}.3: Now we can answer the question: How many pieces does he eat that day?\nAnswer 1.3: There are 32 slices from the large pizzas and 16 slices from the small pizzas, so he eats 32 + 16 = 48 pieces that day. The answer is 48.", 11 | "Question {idx}: Alexis is applying for a new job and bought a new set of business clothes to wear to the interview. She went to a department store with a budget of $200 and spent $30 on a button-up shirt, $46 on suit pants, $38 on a suit coat, $11 on socks, and $18 on a belt. She also purchased a pair of shoes, but lost the receipt for them. She has $16 left from her budget. How much did Alexis pay for the shoes?\nQuestion {idx}.1: How much did Alexis pay for everything else?\nAnswer 1.1: Alexis spent $30 on a button-up shirt, $46 on suit pants, $38 on a suit coat, $11 on socks, and $18 on a belt, so she spent 30 + 46 + 38 + 11 + 18 = $143 on everything else. The answer is 143.\nQuestion {idx}.2: How much money did Alexis spend in total?\nAnswer 1.2: Alexis had a budget of $200 and finally there was $16 left, so she spent 200 - 16 = $184 in total. The answer is 184.\nQuestion {idx}.3: Now we can answer the question: How much did Alexis pay for the shoes?\nAnswer 1.3: Alexis spent $143 on everything else, so she spent 184 - 143 = $41 on the shoes. The answer is 41." 
12 | ], 13 | "useful_examples": [ 14 | "", 15 | "", 16 | "", 17 | "", 18 | "", 19 | "", 20 | "", 21 | "", 22 | "", 23 | "" 24 | ], 25 | "question_prefix": "Question {idx}: {question}", 26 | "subquestion_prefix": "Question {idx}.{sub_idx}:", 27 | "overall_question_prefix": "Now we can answer the question:", 28 | "answer_prefix": "Answer {idx}.{sub_idx}:" 29 | } -------------------------------------------------------------------------------- /prompts/gsm8k/mixtral/cot_rot_mixtral.json: -------------------------------------------------------------------------------- 1 | { 2 | "guidelines": "To solve math word problems effectively and minimize errors, integrate the various suggested policies into the following comprehensive approach:\n\n1. Clarify the quantities involved, distinguishing between total amounts, unit costs, and multiplicative factors. Use correct mathematical operations based on these relationships.\n\n2. When necessary, accurately convert between units (e.g., minutes to hours) using proper mathematical operations, and maintain consistency in units throughout the problem.\n\n3. Directly translate the word problem’s conditions into mathematical equations or expressions, and apply correct mathematical operations.\n\n4. Address each sub-question in a logical order, one at a time, methodically performing operations and maintaining at least three decimal places of precision for operations with hourly rates or similar calculations.\n\n5. When dealing with ratios or relational costs, ensure the correct base amounts are used for calculations.\n\n6. Sequentially solve the sub-questions, constantly cross-referencing with the problem's conditions, and checking that each sub-answer is logical and consistent with the overall scenario.\n\n7. Refrain from rounding intermediate results to maintain accuracy, and only round the final result if necessary, ensuring the level of precision matches that of the given variables.\n\n8. 
Regularly ensure that all units and logical consistencies are maintained throughout the solution process, avoiding impossible situations (such as fractions of indivisible items).\n\n9. Double-check each step against the provided information and common sense narrative to avoid repetition of errors or misinterpretation.\n\n10. Before finalizing the answer, confirm that each part of the problem has been addressed and that the math operations have been applied correctly. Validate each sub-answer and the final answer, ensuring that they make sense and address the question accurately and fully.\n\n11. After combining sub-answers for the final solution, question its reasonableness in the context of the problem, verifying if quantities add up correctly and whether the overall solution is plausible.", 3 | "cot_pool": 4 | [ 5 | "Q: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\nA: Natalia sold 48 clips in April and half as many clips in May, so she sold 48 / 2 = 24 clips in May. Altogether, she sold 48 + 24 = 72 clips. The answer is 72.\n\n", 6 | "Q: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?\nA: Since Weng earns $12 an hour for babysitting, she earns $12 / 60 = $0.2 per minute. Working 50 minutes, she earned $0.2 x 50 = $10. The answer is 10.\n\n", 7 | "Q: Betty is saving money for a new wallet which costs $100. Betty has only half of the money she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?\nA: In the beginning, Betty has only half of the money she needs, which is 100 / 2 = $50. Her grandparents gave her twice as much as her parents, so they gave her 15 * 2 = $30. Now that she got $15 from her parents and $30 from her grandparents, she will need $100 - $15 - $30 = $55. 
Since she already has $50, she needs $55 - $50 = $5 more. The answer is 5.\n\n", 8 | "Q: Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?\nA: Julie read twice as many pages as yesterday, so she read 12 * 2 = 24 pages today. Since yesterday, Julie read 12 + 24 = 36 pages. So, there are 120 - 36 = 84 pages left to be read. Since she wants to read half of the remaining pages, she should read 84 / 2 = 42 pages. The answer is 42.\n\n", 9 | "Q: James writes a 3-page letter to 2 different friends twice a week. How many pages does he write a year?\nA: James writes a 3-page letter to 2 different friends twice a week, so he writes 3 * 2 * 2 = 12 pages every week. There are 52 weeks in a year, so he writes 12 * 52 = 624 pages a year. The answer is 624.\n\n", 10 | "Q: Mark has a garden with flowers. He planted plants of three different colors in it. Ten of them are yellow, and there are 80% more of those in purple. There are only 25% as many green flowers as there are yellow and purple flowers. How many flowers does Mark have in his garden?\nA: There are 80% more purple flowers than yellow flowers, so there are 10 * 1.8 = 18 purple flowers. Because there are 10 yellow flowers, so there are 10 + 18 = 28 yellow and purple flowers. There are 25% as many green flowers as there are yellow and purple flowers, so there are 28 * 0.25 = 7 green flowers. In summary, he has 10 + 18 + 7 = 35 flowers in his garden. The answer is 35.\n\n", 11 | "Q: Albert is wondering how much pizza he can eat in one day. He buys 2 large pizzas and 2 small pizzas. A large pizza has 16 slices and a small pizza has 8 slices. If he eats it all, how many pieces does he eat that day?\nA: He buys 2 large pizzas, so he has 2 * 16 = 32 slices. He buys 2 small pizzas, so he has 2 * 8 = 16 slices. 
There are 32 slices from the large pizzas and 16 slices from the small pizzas, so he eats 32 + 16 = 48 pieces that day. The answer is 48.\n\n", 12 | "Q: Ken created a care package to send to his brother, who was away at boarding school. Ken placed a box on a scale, and then he poured into the box enough jelly beans to bring the weight to 2 pounds. Then, he added enough brownies to cause the weight to triple. Next, he added another 2 pounds of jelly beans. And finally, he added enough gummy worms to double the weight once again. What was the final weight of the box of goodies, in pounds?\nA: Ken poured jelly beans into the box until the weight was 2 pounds, so the weight of the box was 2 pounds at first. Then Ken added enough brownies to cause the weight to triple, so the weight of the box was 2 * 3 = 6 pounds. After Ken added another 2 pounds of jelly beans, the weight of the box was 6 + 2 = 8 pounds. Finally, he added enough gummy worms to double the weight once again, so the weight of the box was 8 * 2 = 16 pounds. The answer is 16.\n\n", 13 | "Q: Alexis is applying for a new job and bought a new set of business clothes to wear to the interview. She went to a department store with a budget of $200 and spent $30 on a button-up shirt, $46 on suit pants, $38 on a suit coat, $11 on socks, and $18 on a belt. She also purchased a pair of shoes, but lost the receipt for them. She has $16 left from her budget. How much did Alexis pay for the shoes?\nA: Alexis spent $30 on a button-up shirt, $46 on suit pants, $38 on a suit coat, $11 on socks, and $18 on a belt, so she spent 30 + 46 + 38 + 11 + 18 = $143 on everything else. Alexis had a budget of $200 and finally there was $16 left, so she spent 200 - 16 = $184 in total. Since Alexis has spent $143 on everything else, she spent 184 - 143 = $41 on the shoes. The answer is 41.\n\n", 14 | "Q: Tina makes $18.00 an hour. 
If she works more than 8 hours per shift, she is eligible for overtime, which is paid by your hourly wage + 1/2 your hourly wage. If she works 10 hours every day for 5 days, how much money does she make?\nA: Tina makes $18.00 an hour, so she makes 18 * 8 = $144.00 in an 8-hour shift. Tina works 10 hours every day for 5 days, so she works 10 * 5 = 50 hours. Since she works 8 hours every day, she gets 50 - 8 * 5 = 10 hours of overtime. Her hourly overtime wage is 18 + 18 / 2 = $27.00. Tina works 10 hours a day, and 8 hours of that is paid at her regular hourly wage, so she makes 10 - 8 = 2 hours of overtime every day. Since her hourly overtime wage is $27.00, she makes 27 * 2 = $54.00 in overtime each day. Every day, tina makes $144.00 in an 8-hour shift and $54.00 in overtime, so she makes 144 + 54 = $198.00 each day. Tina works 5 days a week, so she makes 198 * 5 = $990.00. The answer is 990.\n\n" 15 | ], 16 | "prefix": "Q: {QUESTION}\nA:" 17 | } -------------------------------------------------------------------------------- /prompts/gsm8k/prompt_pool_rot_task.json: -------------------------------------------------------------------------------- 1 | { 2 | "instruction": "Given a question, please decompose it into sub-questions. For each sub-question, please answer it in a complete sentence, ending with \"The answer is\". When the original question is answerable, please start the subquestion with \"Now we can answer the question: \". To effectively solve math word problems, it's important to follow a structured approach. Here is a general guideline that can be applied to many types of math word problems:\n\n1. Read the Problem Carefully: Understand what the problem is asking. Identify the key information and what the problem is asking you to find.\n\n2. Identify the Variables: Determine what quantities the problem is dealing with and assign symbols if necessary.\n\n3. 
Translate Words into Math: Convert the words into mathematical expressions or equations using the identified variables.\n\n4. Develop a Plan: Decide on the steps you need to take to solve the problem using the information given.\n\n5. Execute the Plan: Carry out the steps you have decided on to find the solution.\n\n6. Check Your Work: Verify that your solution makes sense in the context of the problem and that you have answered what was asked.", 3 | "interactive_examples": [ 4 | "Question {idx}: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\nQuestion {idx}.1: How many clips did Natalia sell in May?\nAnswer 1.1: Natalia sold 48 clips in April and half as many clips in May, so she sold 48 / 2 = 24 clips in May. The answer is 24.\nQuestion {idx}.2: Now we can answer the question: How many clips did Natalia sell altogether in April and May?\nAnswer 1.2: Natalia sold 48 clips in April and 24 clips in May, so altogether she sold 48 + 24 = 72 clips. The answer is 72.", 5 | "Question {idx}: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?\nQuestion {idx}.1: How much does Weng earn per minute?\nAnswer 1.1: Since Weng earns $12 an hour for babysitting, she earns $12 / 60 = $0.2 per minute. The answer is 0.2.\nQuestion {idx}.2: Now we can answer the question: How much did she earn?\nAnswer 1.2: Working 50 minutes, she earned $0.2 x 50 = $10. The answer is 10.", 6 | "Question {idx}: Betty is saving money for a new wallet which costs $100. Betty has only half of the money she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?\nQuestion {idx}.1: How much money does Betty have in the beginning?\nAnswer 1.1: In the beginning, Betty has only half of the money she needs, which is 100 / 2 = $50. 
The answer is 50.\nQuestion {idx}.2: How much money did Betty's grandparents give her?\nAnswer 1.2: Her grandparents gave her twice as much as her parents, so they gave her 15 * 2 = $30. The answer is 30.\nQuestion {idx}.3: Now we can answer the question: How much more money does Betty need to buy the wallet?\nAnswer 1.3: Now that she got $15 from her parents and $30 from her grandparents, she will need $100 - $15 - $30 = $55. Since she already has $50, she needs $55 - $50 = $5 more. The answer is 5.", 7 | "Question {idx}: Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?\nQuestion {idx}.1: How many pages did Julie read today?\nAnswer 1.1: Julie read twice as many pages as yesterday, so she read 12 * 2 = 24 pages. The answer is 24.\nQuestion {idx}.2: How many pages did Julie read since yesterday?\nAnswer 1.2: Since yesterday, Julie read 12 + 24 = 36 pages. The answer is 36.\nQuestion {idx}.3: How many pages are left to be read?\nAnswer 1.3: There are 120 - 36 = 84 pages left to be read. The answer is 84.\nQuestion {idx}.4: Now we can answer the question: How many pages should she read?\nAnswer 1.4: She wants to read half of the remaining pages, so she should read 84 / 2 = 42 pages. The answer is 42.", 8 | "Question {idx}: James writes a 3-page letter to 2 different friends twice a week. How many pages does he write a year?\nQuestion {idx}.1: How many pages does he write every week?\nAnswer 1.1: James writes a 3-page letter to 2 different friends twice a week, so he writes 3 * 2 * 2 = 12 pages every week. The answer is 12.\nQuestion {idx}.2: How many weeks are there in a year?\nAnswer 1.2: There are 52 weeks in a year. 
The answer is 52.\nQuestion {idx}.3: Now we can answer the question: How many pages does he write a year?\nAnswer 1.3: James writes 12 pages every week, so he writes 12 * 52 = 624 pages a year. The answer is 624.", 9 | "Question {idx}: Mark has a garden with flowers. He planted plants of three different colors in it. Ten of them are yellow, and there are 80% more of those in purple. There are only 25% as many green flowers as there are yellow and purple flowers. How many flowers does Mark have in his garden?\nQuestion {idx}.1: How many purple flowers are there?\nAnswer 1.1: There are 80% more purple flowers than yellow flowers, so there are 10 * 1.8 = 18 purple flowers. The answer is 18.\nQuestion {idx}.2: How many yellow and purple flowers are there in total?\nAnswer 1.2: There are 10 yellow flowers and 18 purple flowers, so there are 10 + 18 = 28 yellow and purple flowers. The answer is 28.\nQuestion {idx}.3: How many green flowers are there?\nAnswer 1.3: There are 25% as many green flowers as there are yellow and purple flowers, so there are 28 * 0.25 = 7 green flowers. The answer is 7.\nQuestion {idx}.4: Now we can answer the question: How many flowers does Mark have in his garden?\nAnswer 1.4: Mark has 10 yellow flowers, 18 purple flowers, and 7 green flowers, so he has 10 + 18 + 7 = 35 flowers in his garden. The answer is 35.", 10 | "Question {idx}: Albert is wondering how much pizza he can eat in one day. He buys 2 large pizzas and 2 small pizzas. A large pizza has 16 slices and a small pizza has 8 slices. If he eats it all, how many pieces does he eat that day?\nQuestion {idx}.1: How many slices do the large pizzas have?\nAnswer 1.1: He buys 2 large pizzas, so he has 2 * 16 = 32 slices. The answer is 32.\nQuestion {idx}.2: How many slices do the small pizzas have?\nAnswer 1.2: He buys 2 small pizzas, so he has 2 * 8 = 16 slices. 
The answer is 16.\nQuestion {idx}.3: Now we can answer the question: How many pieces does he eat that day?\nAnswer 1.3: There are 32 slices from the large pizzas and 16 slices from the small pizzas, so he eats 32 + 16 = 48 pieces that day. The answer is 48.", 11 | "Question {idx}: Alexis is applying for a new job and bought a new set of business clothes to wear to the interview. She went to a department store with a budget of $200 and spent $30 on a button-up shirt, $46 on suit pants, $38 on a suit coat, $11 on socks, and $18 on a belt. She also purchased a pair of shoes, but lost the receipt for them. She has $16 left from her budget. How much did Alexis pay for the shoes?\nQuestion {idx}.1: How much did Alexis pay for everything else?\nAnswer 1.1: Alexis spent $30 on a button-up shirt, $46 on suit pants, $38 on a suit coat, $11 on socks, and $18 on a belt, so she spent 30 + 46 + 38 + 11 + 18 = $143 on everything else. The answer is 143.\nQuestion {idx}.2: How much money did Alexis spend in total?\nAnswer 1.2: Alexis had a budget of $200 and finally there was $16 left, so she spent 200 - 16 = $184 in total. The answer is 184.\nQuestion {idx}.3: Now we can answer the question: How much did Alexis pay for the shoes?\nAnswer 1.3: Alexis spent $143 on everything else, so she spent 184 - 143 = $41 on the shoes. The answer is 41." 
12 | ], 13 | "useful_examples": [ 14 | "", 15 | "", 16 | "", 17 | "", 18 | "", 19 | "", 20 | "", 21 | "", 22 | "", 23 | "" 24 | ], 25 | "question_prefix": "Question {idx}: {question}", 26 | "subquestion_prefix": "Question {idx}.{sub_idx}:", 27 | "overall_question_prefix": "Now we can answer the question:", 28 | "answer_prefix": "Answer {idx}.{sub_idx}:" 29 | } -------------------------------------------------------------------------------- /prompts/gsm8k/useful_examples.json: -------------------------------------------------------------------------------- 1 | { 2 | "input": "Given a question and some sub-questions, determine whether the last sub-question is useful to answer the question. Output 'Yes' or 'No', and a reason.\n\nQuestion 1: Four years ago, Kody was only half as old as Mohamed. If Mohamed is currently twice as 30 years old, how old is Kody?\nQuestion 1.1: How old is Mohamed?\nQuestion 1.2: How old was Mohamed four years ago?\nNew question 1.3: How old was Kody four years ago?\nIs the new question useful? Yes. We need the answer to calculate how old is Kody now.\n\nQuestion 2: Traci and Harris are baking cakes together. Traci has brought flour from her own house and Harris has 400g of flour in his house. Each cake needs 100g of flour and Traci and Harris have created 9 cakes each. How much flour, in grams, did Traci bring from her own house?\nNew question 2.1: How many cakes did Traci bring from her own house?\nIs the new question useful? No. The new question is not related to the original question.\n\nQuestion 3: A quantity surveyor is figuring the construction costs for a couple that wishes to build a house. The costs are as follows: land costs $50 per square meter, bricks cost $100 per 1000 bricks and roof tiles cost $10 per roof tile. 
If the house they wish to build requires 2000 square meters, 10000 bricks, and 500 roof tiles, how much construction costs are required for this project?\nQuestion 3.1: How much does the land cost?\nQuestion 3.2: How much do the bricks cost?\nNew question 3.3: How much do the roof tiles cost?\nIs the new question useful? Yes. We need the answer to calculate the total construction costs.\n\nQuestion 4: Wallace's water heater is twice the size of Catherine's water heater. If the capacity of Wallace's water heater is 40 gallons and it's 3/4 full, calculate the total number of gallons of water they both have if Catherine's water heater is also full with water to 3/4 of its capacity.\nQuestion 4.1: How much water is in Wallace's water heater?\nNew question 4.2: How much water do they have in total?\nIs the new question useful? No. It is too hard to answer the new question based on the current information.\n\n", 3 | "question_prefix": "Question 5: ", 4 | "subquestion_prefix": "Question 5.{}:", 5 | "new_subquestion_prefix": "New question 5.{}:", 6 | "useful_prefix": "Is the new question useful? 
Answer:" 7 | } 8 | -------------------------------------------------------------------------------- /reasoners/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import WorldModel, SearchConfig, Reasoner, SearchAlgorithm, GenerateOutput, State, Action, Example, Trace, Evaluator 2 | from .benchmark import * 3 | from .vllm_model import VLLMModel -------------------------------------------------------------------------------- /reasoners/algorithm/__init__.py: -------------------------------------------------------------------------------- 1 | from .beam_search import BeamSearch, BeamSearchNode, BeamSearchResult 2 | from .mcts import MCTS, MCTSNode, MCTSResult, MCTSAggregation 3 | from .bfs import BFS, BFSNode, BFSResult -------------------------------------------------------------------------------- /reasoners/algorithm/bfs.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from os import PathLike 3 | import math 4 | from copy import deepcopy 5 | from typing import Generic, Optional, NamedTuple, Callable, Hashable, List, Tuple 6 | import itertools 7 | from abc import ABC 8 | from collections import defaultdict 9 | 10 | import numpy as np 11 | from tqdm import trange 12 | 13 | from .. 
import SearchAlgorithm, WorldModel, SearchConfig, State, Action, Example, Trace 14 | 15 | 16 | class BFSNode(Generic[State, Action]): 17 | id_iter = itertools.count() 18 | 19 | @classmethod 20 | def reset_id(cls): 21 | cls.id_iter = itertools.count() 22 | 23 | def __init__(self, state: Optional[State], action: Optional[Action], parent: "Optional[BFSNode]" = None, 24 | fast_reward: float = 0., fast_reward_details=None, 25 | is_terminal: bool = False): 26 | """ 27 | A node in the MCTS search tree 28 | 29 | :param state: the current state 30 | :param action: the action of the last step, i.e., the action from parent node to current node 31 | :param parent: the parent node, None if root of the tree 32 | :param fast_reward: an estimation of the reward of the last step 33 | :param is_terminal: whether the current state is a terminal state 34 | :param calc_q: the way to calculate the Q value from histories. Defaults: np.mean 35 | """ 36 | self.id = next(BFSNode.id_iter) 37 | if fast_reward_details is None: 38 | fast_reward_details = {} 39 | self.cum_rewards: List[float] = [] 40 | self.fast_reward = self.reward = fast_reward 41 | self.fast_reward_details = fast_reward_details 42 | self.is_terminal = is_terminal 43 | self.action = action 44 | self.state = state 45 | self.parent = parent 46 | self.children: 'Optional[List[BFSNode]]' = None 47 | if parent is None: 48 | self.depth = 0 49 | else: 50 | self.depth = parent.depth + 1 51 | 52 | class BFSResult(NamedTuple): 53 | terminal_state: State 54 | trace: Trace 55 | trace_of_nodes: List[BFSNode] 56 | tree_state: BFSNode 57 | 58 | class BFS(SearchAlgorithm, Generic[State, Action, Example]): 59 | def __init__(self, 60 | depth_limit: int = 5, 61 | width: int = 5, 62 | disable_tqdm: bool = True, 63 | node_visualizer: Callable[[BFSNode], dict] = lambda x: x.__dict__): 64 | super().__init__() 65 | self.world_model = None 66 | self.search_config = None 67 | self.terminals = [] 68 | self.width = width 69 | self.depth_limit = 
depth_limit 70 | self._output_iter: List[BFSNode] = None 71 | self.trace_in_each_iter: List[List[BFSNode]] = None 72 | self.root: Optional[BFSNode] = None 73 | self.disable_tqdm = disable_tqdm 74 | self.node_visualizer = node_visualizer 75 | 76 | def _is_terminal_with_depth_limit(self, node: BFSNode): 77 | return node.is_terminal or node.depth >= self.depth_limit 78 | 79 | def _expand(self, node: BFSNode): 80 | if node.state is None: 81 | node.state, aux = self.world_model.step(node.parent.state, node.action) 82 | # reward is calculated after the state is updated, so that the 83 | # information can be cached and passed from the world model 84 | # to the reward function with **aux without repetitive computation 85 | node.reward, node.reward_details = self.search_config. \ 86 | reward(node.parent.state, node.action, **node.fast_reward_details, **aux) 87 | node.is_terminal = self.world_model.is_terminal(node.state) 88 | 89 | if node.is_terminal: 90 | return 91 | 92 | children = [] 93 | actions = self.search_config.get_actions(node.state) 94 | 95 | if getattr(self.search_config, 'fast_rewards', None) is None: 96 | for action in actions: 97 | fast_reward, fast_reward_details = self.search_config.fast_reward(node.state, action) 98 | child = BFSNode(state=None, action=action, parent=node, 99 | fast_reward=fast_reward, fast_reward_details=fast_reward_details) 100 | children.append(child) 101 | else: 102 | fast_rewards, fast_reward_detailss = self.search_config.fast_rewards(node.state, actions) 103 | for r, d, a in zip(fast_rewards, fast_reward_detailss, actions): 104 | children.append(BFSNode(state=None, action=a, parent=node, 105 | fast_reward=r, fast_reward_details=d)) 106 | 107 | node.children = children 108 | 109 | return children 110 | 111 | def search(self): 112 | self._output_cum_reward = -math.inf 113 | self._output_iter = None 114 | self.root = BFSNode(state=self.world_model.init_state(), action=None, parent=None) 115 | terminals = [] 116 | frontier = [self.root] 
117 | for _ in trange(self.depth_limit+1, disable=self.disable_tqdm): 118 | next_frontier = [] 119 | for node in frontier: 120 | children = self._expand(node) 121 | if children is None: 122 | continue 123 | 124 | non_terminal_children = [child for child in children if not child.is_terminal] 125 | next_frontier.extend(non_terminal_children) 126 | 127 | terminals.extend([node for node in frontier if node.is_terminal]) 128 | frontier = sorted(next_frontier, key=lambda x: x.fast_reward, reverse=True)[:self.width] 129 | if len(terminals) == 0: 130 | breakpoint() 131 | best_state = sorted(terminals, key=lambda x: x.reward, reverse=True)[0] 132 | 133 | self._output_iter = [] 134 | 135 | while best_state.parent is not None: 136 | self._output_iter.append(best_state) 137 | best_state = best_state.parent 138 | self._output_iter.append(best_state) 139 | self._output_iter = self._output_iter[::-1] 140 | 141 | def __call__(self, 142 | world_model: WorldModel[State, Action, Example], 143 | search_config: SearchConfig[State, Action, Example], 144 | **kwargs) -> BFSResult: 145 | BFSNode.reset_id() 146 | self.world_model = world_model 147 | self.search_config = search_config 148 | self.search() 149 | 150 | if self._output_iter is None: 151 | terminal_state = trace = None 152 | else: 153 | terminal_state = self._output_iter[-1].state 154 | trace = [node.state for node in self._output_iter], [node.action for node in self._output_iter[1:]] 155 | 156 | result = BFSResult( 157 | terminal_state=terminal_state, 158 | trace=trace, 159 | trace_of_nodes=self._output_iter, 160 | tree_state=self.root 161 | ) 162 | 163 | l = [self.root] 164 | while l: 165 | node = l.pop() 166 | if node.children is not None: 167 | l.extend(node.children) 168 | 169 | return result 170 | -------------------------------------------------------------------------------- /reasoners/base.py: -------------------------------------------------------------------------------- 1 | from typing import Generic, TypeVar, Union, 
NamedTuple, Protocol, Optional, runtime_checkable, Tuple 2 | from abc import ABC, abstractmethod 3 | 4 | import numpy as np 5 | from transformers import StoppingCriteriaList 6 | from datetime import datetime 7 | import os, sys, pickle 8 | from tqdm import tqdm 9 | import torch 10 | 11 | State = TypeVar("State") 12 | Action = TypeVar("Action") 13 | Example = TypeVar("Example") 14 | Trace = tuple[list[State], list[Action]] 15 | 16 | 17 | class GenerateOutput(NamedTuple): 18 | text: list[str] 19 | log_prob: list[Union[np.ndarray, float]] = None 20 | 21 | class WorldModel(ABC, Generic[State, Action, Example]): 22 | def __init__(self) -> None: 23 | self.example = None 24 | self.prompt = None 25 | 26 | @abstractmethod 27 | def init_state(self) -> State: ... 28 | 29 | @abstractmethod 30 | def step(self, state: State, action: Action) -> Union[State, Tuple[State, dict]]: 31 | """ Returns the next state and optionally an auxiliary data dict 32 | 33 | :param state: The current state 34 | :param action: The action to take 35 | :return: The next state and optionally an auxiliary data dict 36 | """ 37 | ... 38 | 39 | @abstractmethod 40 | def is_terminal(self, state: State) -> bool: ... 41 | 42 | def update_example(self, example: Example, prompt = None) -> None: 43 | if prompt is not None: 44 | self.prompt = prompt 45 | self.example = example 46 | 47 | 48 | class SearchConfig(ABC, Generic[State, Action, Example]): 49 | def __init__(self) -> None: 50 | self.example = None 51 | self.prompt = None 52 | 53 | @abstractmethod 54 | def get_actions(self, state: State) -> list[Action]: ... 55 | 56 | @abstractmethod 57 | def fast_reward(self, state: State, action: Action) -> tuple[float, dict]: 58 | return 0, {} 59 | 60 | @abstractmethod 61 | def reward(self, state, action, **kwargs) -> tuple[float, dict]: ... 
62 | 63 | def update_example(self, example: Example, prompt = None) -> None: 64 | if prompt is not None: 65 | self.prompt = prompt 66 | self.example = example 67 | 68 | 69 | @runtime_checkable 70 | class AlgorithmOutput(Protocol[State]): 71 | terminal_state: State 72 | trace: Trace 73 | 74 | 75 | class SearchAlgorithm(ABC): 76 | def __init__(self, **kwargs): ... 77 | 78 | @abstractmethod 79 | def __call__(self, world_model: WorldModel, search_config: SearchConfig, **kwargs) -> AlgorithmOutput: ... 80 | 81 | 82 | class Reasoner(ABC, Generic[State, Action, Example]): 83 | def __init__(self, 84 | world_model: WorldModel[State, Action, Example], 85 | search_config: SearchConfig[State, Action, Example], 86 | search_algo: SearchAlgorithm) -> None: 87 | self.world_model = world_model 88 | self.search_config = search_config 89 | self.search_algo = search_algo 90 | 91 | def __call__(self, example: Example, prompt = None, **kwargs) -> AlgorithmOutput[State]: 92 | self.world_model.update_example(example, prompt=prompt) 93 | self.search_config.update_example(example, prompt=prompt) 94 | return self.search_algo(self.world_model, self.search_config, **kwargs) 95 | 96 | class Evaluator(): 97 | @abstractmethod 98 | def __init__(self) -> None: 99 | pass 100 | 101 | @abstractmethod 102 | def sample_prompt( 103 | self, 104 | shuffle_prompt, 105 | num_shot, 106 | sample_prompt_type 107 | ): 108 | pass 109 | 110 | def evaluate(self, 111 | reasoner, 112 | shuffle_prompt=True, 113 | num_shot=4, 114 | resume=0, 115 | log_dir=None): 116 | 117 | self.dataset = list(self.full_dataset)[resume:] 118 | try: 119 | algo_name = reasoner.search_algo.__class__.__name__ 120 | except: 121 | algo_name = "unknown" 122 | 123 | 124 | if log_dir is None: 125 | log_dir = f'logs/{self._dataset_name}_'\ 126 | f'{algo_name}/'\ 127 | f'{datetime.now().strftime("%m%d%Y-%H%M%S")}' 128 | os.makedirs(log_dir, exist_ok=True) 129 | os.makedirs(os.path.join(log_dir, 'algo_output'), exist_ok=True) 130 | 131 | with 
open(os.path.join(log_dir, 'args.txt'), 'w') as f: 132 | print(sys.argv, file=f) 133 | 134 | correct_count = 0 135 | 136 | for i, example in enumerate(tqdm(self.dataset, 137 | total=resume + len(self.dataset), 138 | initial=resume, 139 | desc=self._dataset_name, 140 | disable=self.disable_tqdm)): 141 | 142 | algo_output = reasoner(self.input_processor(example), 143 | prompt=self.sample_prompt( 144 | shuffle_prompt=shuffle_prompt, 145 | num_shot=num_shot, 146 | sample_prompt_type=self.sample_prompt_type)) 147 | 148 | output = self.output_extractor(algo_output) 149 | answer = self.answer_extractor(example) 150 | 151 | correct = self.eval_output(answer, output) 152 | correct_count += correct 153 | accuracy = correct_count / (i + 1) 154 | log_str = f'Case #{resume + i + 1}: {correct=}, {output=}, {answer=};'\ 155 | f'{accuracy=:.3f} ({correct_count}/{i + 1})' 156 | tqdm.write(log_str) 157 | 158 | if not self.disable_log: 159 | with open(os.path.join(log_dir, 'result.log'), 'a') as f: 160 | print(log_str, file=f) 161 | 162 | with open(os.path.join(log_dir, 'algo_output', f'{resume + i + 1}.pkl'), 'wb') as f: 163 | pickle.dump(algo_output, f) 164 | 165 | return accuracy 166 | 167 | @abstractmethod 168 | def eval_output(self, answer, output): 169 | pass -------------------------------------------------------------------------------- /reasoners/benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | from .gsm8k import GSM8KEvaluator 2 | from .blocksworld import BWEvaluator -------------------------------------------------------------------------------- /reasoners/benchmark/blocksworld.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import json 3 | from tqdm import tqdm 4 | import torch 5 | import os, pickle 6 | from datetime import datetime 7 | import sys 8 | import random 9 | from reasoners import Evaluator 10 | import copy 11 | 12 | import 
reasoners.benchmark.bw_utils as bw_utils 13 | 14 | def rap_bw_extractor(algo_output): 15 | if torch.distributed.is_initialized(): 16 | torch.distributed.barrier() 17 | # to make sure the plan is saved before evaluation in multi-process setting 18 | if algo_output.trace is None: 19 | print("No plan found") 20 | return "" 21 | else: 22 | return "\n".join(algo_output.trace[1]) 23 | 24 | def get_icl(init_prompt, examples): 25 | icl = init_prompt["intro"] + \ 26 | "\n".join([ 27 | "[STATEMENT]\nAs initial conditions I have that, " + \ 28 | example["init"] + \ 29 | ".\nMy goal is to have that " +\ 30 | example["goal"] + \ 31 | ".\n\nMy plan is as follows:\n\n[PLAN]" + \ 32 | example["plan"] 33 | for example in examples 34 | ]) 35 | icl += "\n[STATEMENT]\nAs initial conditions I have that, \nMy goal is to \n\nMy plan is as follows:\n\n[PLAN]\n" 36 | return icl 37 | 38 | class BWEvaluator(Evaluator): 39 | def __init__(self, 40 | config_file, 41 | domain_file, 42 | data_path, 43 | init_prompt, 44 | disable_log=False, 45 | disable_tqdm=False, 46 | output_extractor=rap_bw_extractor, 47 | answer_extractor=lambda x:x, 48 | sample_prompt_type="rap") -> None: 49 | 50 | self.init_prompt = init_prompt 51 | self.output_extractor = output_extractor 52 | self.answer_extractor = answer_extractor 53 | self.input_processor = lambda x: x 54 | self.full_dataset = bw_utils.load_blocksworld(config_file, domain_file, data_path) # [{"goal": str, "init": str}] 55 | self._dataset_name = 'blocksworld' 56 | self.disable_log = disable_log 57 | self.disable_tqdm = disable_tqdm 58 | self.sample_prompt_type = sample_prompt_type 59 | 60 | self.lm_plan_file = "tmp_plan.txt" 61 | self.config_file = config_file 62 | self.domain_file = domain_file 63 | 64 | def sample_prompt(self, 65 | shuffle_prompt=True, 66 | num_shot=4, 67 | sample_prompt_type="rap"): 68 | 69 | if sample_prompt_type == "rap": 70 | if shuffle_prompt: 71 | examples = random.sample(self.init_prompt["example_pool"], num_shot) 72 | else: 73 
| examples = self.init_prompt["example_pool"][:num_shot] 74 | 75 | icl = get_icl(self.init_prompt, examples) 76 | 77 | prompt = copy.deepcopy(self.init_prompt) 78 | prompt["icl"] = icl 79 | prompt["icl_list"] = [icl] 80 | examples = copy.deepcopy(examples) 81 | for i in range(5): 82 | new_examples = [] 83 | for example in examples: 84 | if len(example["states"]) > 1: 85 | new_examples.append({ 86 | "init": example["states"][0], 87 | "goal": example["goal"], 88 | "plan": "\n" + "\n".join(example["plan"].split("\n")[3:]), 89 | "states": example["states"][1:] 90 | }) 91 | else: 92 | new_examples.append(example) 93 | examples = copy.deepcopy(new_examples) 94 | icl = get_icl(self.init_prompt, examples) 95 | prompt["icl_list"].append(icl) 96 | else: 97 | raise NotImplementedError 98 | # print("prompt:", prompt) 99 | return prompt 100 | 101 | def eval_output(self, answer, output): 102 | bw_utils.text_to_plan_blocksworld(output, answer["instance_file"], self.config_file, self.domain_file, self.lm_plan_file) 103 | correct = bw_utils.validate_plan(self.domain_file, answer["instance_file"], self.lm_plan_file)[0] 104 | return correct -------------------------------------------------------------------------------- /reasoners/benchmark/gsm8k.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import json 3 | from tqdm import tqdm 4 | import torch 5 | import os, pickle 6 | from datetime import datetime 7 | import sys 8 | import random 9 | import copy 10 | from reasoners import Evaluator 11 | 12 | class GSM8KEvaluator(Evaluator): 13 | def __init__(self, 14 | output_extractor, 15 | answer_extractor, 16 | init_prompt=None, 17 | disable_log=False, 18 | disable_tqdm=False, 19 | sample_prompt_type="l2m", 20 | samples=100, 21 | split='test') -> None: 22 | 23 | self.init_prompt = init_prompt 24 | self.output_extractor = output_extractor 25 | self.answer_extractor = answer_extractor 26 | self.input_processor = lambda x: 
x["question"] 27 | self.full_dataset = datasets.load_dataset('gsm8k', 'main', split=split,) 28 | 29 | self._dataset_name = 'gsm8k' 30 | self.disable_log = disable_log 31 | self.disable_tqdm = disable_tqdm 32 | self.sample_prompt_type = sample_prompt_type 33 | 34 | def sample_prompt(self, 35 | shuffle_prompt=False, 36 | num_shot=4, 37 | sample_prompt_type="l2m"): 38 | 39 | if sample_prompt_type == "l2m": 40 | prompt = {} 41 | if shuffle_prompt: 42 | decomp_examples = random.sample(self.init_prompt["decomposition_pool"], num_shot) 43 | solv_examples = random.sample(self.init_prompt["solving_pool"], num_shot) 44 | else: 45 | decomp_examples = self.init_prompt["decomposition_pool"][:num_shot] 46 | solv_examples = self.init_prompt["solving_pool"][:num_shot] 47 | prompt["decomposition"] = "".join(decomp_examples) + self.init_prompt["composition_prefix"] 48 | prompt["overall"] = "".join(decomp_examples) + self.init_prompt["overall_prefix"] 49 | prompt["solving"] = "".join(solv_examples) + self.init_prompt["solving_prefix"] 50 | 51 | elif sample_prompt_type == "cot": 52 | prompt = {} 53 | if shuffle_prompt: 54 | examples = random.sample(self.init_prompt["cot_pool"], num_shot) 55 | else: 56 | examples = self.init_prompt["cot_pool"][:num_shot] 57 | 58 | if 'guidelines' in self.init_prompt: 59 | prompt["cot"] = self.init_prompt['guidelines']+ '\n' + "".join(examples) + self.init_prompt["prefix"] 60 | else: 61 | prompt["cot"] = "".join(examples) + self.init_prompt["prefix"] 62 | 63 | elif sample_prompt_type == "rap": 64 | 65 | ret = copy.deepcopy(self.init_prompt) 66 | ret['interactive_examples'], ret['useful_examples'] = zip(*random.sample( 67 | list(zip(ret['interactive_examples'], ret['useful_examples'])), k=num_shot)) 68 | if 'answer_examples' in ret: 69 | ret['answer_examples'] = ret['answer_examples'][:num_shot] 70 | ret['subquestion_examples'] = ret['subquestion_examples'][:num_shot] 71 | 72 | return ret 73 | 74 | else: 75 | raise NotImplementedError 76 | return prompt 
77 | 78 | def eval_output(self, answer, output): 79 | if output is None: 80 | return False 81 | try: 82 | output = int(output) 83 | answer = int(answer) 84 | return output == answer 85 | except ValueError: 86 | pass 87 | try: 88 | output = float(output) 89 | answer = float(answer) 90 | return output == answer 91 | except ValueError: 92 | pass 93 | return output == answer 94 | -------------------------------------------------------------------------------- /reasoners/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | from .tree_log import TreeLog, TreeLogEncoder 2 | from .tree_snapshot import TreeSnapshot 3 | from .visualizer_client import VisualizerClient, visualize 4 | from .__main__ import main 5 | -------------------------------------------------------------------------------- /reasoners/visualization/__main__.py: -------------------------------------------------------------------------------- 1 | def main(): 2 | import argparse 3 | 4 | from reasoners.visualization import VisualizerClient 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('tree_log', type=str) 8 | parser.add_argument('--base_url', type=str) 9 | args = parser.parse_args() 10 | 11 | if args.base_url is None: 12 | client = VisualizerClient() 13 | else: 14 | client = VisualizerClient(args.base_url) 15 | 16 | with open(args.tree_log) as f: 17 | data = f.read() 18 | result = client.post_log(data) 19 | print(result.access_url) 20 | -------------------------------------------------------------------------------- /reasoners/visualization/tree_log.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Sequence, Union 3 | 4 | from reasoners.algorithm import MCTSNode, MCTSResult, BeamSearchNode, BeamSearchResult, BFSNode 5 | from reasoners.visualization.tree_snapshot import NodeId, EdgeId, TreeSnapshot, NodeData, EdgeData 6 | 7 | 8 | class 
TreeLogEncoder(json.JSONEncoder): 9 | def default(self, o): 10 | from numpy import float32 11 | 12 | if isinstance(o, TreeSnapshot.Node): 13 | return o.__dict__ 14 | elif isinstance(o, TreeSnapshot.Edge): 15 | return o.__dict__ 16 | elif isinstance(o, TreeSnapshot): 17 | return o.__dict__() 18 | elif isinstance(o, float32): 19 | return float(o) 20 | if isinstance(o, TreeLog): 21 | return {"logs": list(o)} 22 | else: 23 | return super().default(o) 24 | 25 | 26 | class TreeLog: 27 | def __init__(self, tree_snapshots: Sequence[TreeSnapshot]) -> None: 28 | self._tree_snapshots = tree_snapshots 29 | 30 | def __getitem__(self, item): 31 | return self._tree_snapshots[item] 32 | 33 | def __iter__(self): 34 | return iter(self._tree_snapshots) 35 | 36 | def __len__(self): 37 | return len(self._tree_snapshots) 38 | 39 | def __str__(self): 40 | return json.dumps(self, cls=TreeLogEncoder, indent=2) 41 | 42 | @classmethod 43 | def from_mcts_results(cls, mcts_results: MCTSResult, node_data_factory: callable = None, 44 | edge_data_factory: callable = None) -> 'TreeLog': 45 | 46 | def get_reward_details(n: MCTSNode) -> Union[dict, None]: 47 | if hasattr(n, "reward_details"): 48 | return n.reward_details 49 | return n.fast_reward_details if hasattr(n, "fast_reward_details") else None 50 | 51 | def default_node_data_factory(n: MCTSNode) -> NodeData: 52 | if not n.state: 53 | return NodeData({}) 54 | # transform any object to dict 55 | if hasattr(n.state, "_asdict"): 56 | # if the state is a NamedTuple 57 | state_dict = n.state._asdict() 58 | elif isinstance(n.state, list): 59 | state_dict = {idx: value for idx, value in enumerate(n.state)} 60 | else: 61 | try: 62 | state_dict = dict(n.state) 63 | except TypeError: 64 | raise TypeError("The type of the state is not supported. 
" 65 | "Please provide a node_data_factory function to transform the state to a dict.") 66 | return NodeData(state_dict) 67 | 68 | def default_edge_data_factory(n: MCTSNode) -> EdgeData: 69 | if isinstance(n, BFSNode): 70 | return EdgeData({"Q": n.reward, "action": n.action, **get_reward_details(n)}) 71 | else: 72 | return EdgeData({"Q": n.Q, "reward": n.reward, **get_reward_details(n)}) 73 | 74 | node_data_factory = node_data_factory or default_node_data_factory 75 | edge_data_factory = edge_data_factory or default_edge_data_factory 76 | 77 | snapshots = [] 78 | 79 | def all_nodes(node: MCTSNode): 80 | node_id = NodeId(node.id) 81 | 82 | nodes[node_id] = TreeSnapshot.Node(node_id, node_data_factory(node)) 83 | if node.children is None: 84 | return 85 | for child in node.children: 86 | edge_id = EdgeId(len(edges)) 87 | edges.append(TreeSnapshot.Edge(edge_id, node.id, child.id, edge_data_factory(child))) 88 | all_nodes(child) 89 | 90 | 91 | if getattr(mcts_results, 'tree_state_after_each_iter', None) is None: 92 | tree_states = [mcts_results.tree_state] 93 | else: 94 | tree_states = mcts_results.tree_state_after_each_iter 95 | for step in range(len(tree_states)): 96 | edges = [] 97 | nodes = {} 98 | 99 | root = tree_states[step] 100 | all_nodes(root) 101 | tree = TreeSnapshot(list(nodes.values()), edges) 102 | 103 | # select edges following the MCTS trace 104 | if getattr(mcts_results, 'trace_in_each_iter', None) is not None: 105 | trace = mcts_results.trace_in_each_iter[step] 106 | for step_idx in range(len(trace) - 1): 107 | in_node_id = trace[step_idx].id 108 | out_node_id = trace[step_idx + 1].id 109 | for edges in tree.out_edges(in_node_id): 110 | if edges.target == out_node_id: 111 | nodes[in_node_id].selected_edge = edges.id 112 | break 113 | 114 | # for all other nodes, select edges with highest Q 115 | for node in tree.nodes.values(): 116 | if node.selected_edge is None and tree.children(node.id): 117 | node.selected_edge = max( 118 | 
tree.out_edges(node.id), 119 | key=lambda edge: edge.data.get("Q", -float("inf")) 120 | ).id 121 | 122 | snapshots.append(tree) 123 | 124 | return cls(snapshots) 125 | 126 | @classmethod 127 | def from_beam_search_results(cls, bs_results: Union[BeamSearchResult, Sequence[BeamSearchResult]], 128 | node_data_factory: callable = None, edge_data_factory: callable = None) -> 'TreeLog': 129 | 130 | if isinstance(bs_results, BeamSearchResult): 131 | bs_results = [bs_results] 132 | bs_results = bs_results[0] 133 | 134 | def default_node_data_factory(n: BeamSearchNode) -> NodeData: 135 | if not n.state: 136 | return NodeData({}) 137 | # transform any object to dict 138 | if hasattr(n.state, "_asdict"): 139 | # if the state is a NamedTuple 140 | state_dict = n.state._asdict() 141 | elif isinstance(n.state, list): 142 | state_dict = {idx: value for idx, value in enumerate(n.state)} 143 | else: 144 | try: 145 | state_dict = dict(n.state) 146 | except TypeError: 147 | raise TypeError("The type of the state is not supported. 
" 148 | "Please provide a node_data_factory function to transform the state to a dict.") 149 | return NodeData(state_dict) 150 | 151 | def default_edge_data_factory(n: BeamSearchNode) -> EdgeData: 152 | return EdgeData({"reward": n.reward, "action": n.action}) 153 | 154 | node_data_factory = node_data_factory or default_node_data_factory 155 | edge_data_factory = edge_data_factory or default_edge_data_factory 156 | 157 | snapshots = [] 158 | 159 | def all_nodes(node: BeamSearchNode): 160 | node_id = NodeId(node.id) 161 | 162 | nodes[node_id] = TreeSnapshot.Node(node_id, node_data_factory(node)) 163 | for child in node.children: 164 | edge_id = EdgeId(len(edges)) 165 | edges.append(TreeSnapshot.Edge(edge_id, node.id, child.id, edge_data_factory(child))) 166 | all_nodes(child) 167 | 168 | root = bs_results.tree 169 | edges = [] 170 | nodes = {} 171 | all_nodes(root) 172 | tree = TreeSnapshot(list(nodes.values()), edges) 173 | 174 | # select edges with highest reward 175 | for node in tree.nodes.values(): 176 | if node.selected_edge is None and tree.children(node.id): 177 | node.selected_edge = max( 178 | tree.out_edges(node.id), 179 | key=lambda edge: edge.data.get("reward", -float("inf")) 180 | ).id 181 | 182 | snapshots.append(tree) 183 | 184 | return cls(snapshots) 185 | -------------------------------------------------------------------------------- /reasoners/visualization/tree_snapshot.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from dataclasses import dataclass 3 | from typing import NewType, Optional, Collection 4 | 5 | NodeId = NewType("NodeId", int) 6 | EdgeId = NewType("EdgeId", int) 7 | NodeData = NewType("NodeData", dict) 8 | EdgeData = NewType("EdgeData", dict) 9 | 10 | 11 | class TreeSnapshot: 12 | @dataclass 13 | class Node: 14 | id: NodeId 15 | data: NodeData 16 | selected_edge: Optional[EdgeId] = None 17 | 18 | @dataclass 19 | class Edge: 20 | id: EdgeId 21 | 
source: NodeId 22 | target: NodeId 23 | data: EdgeData 24 | 25 | def __init__(self, nodes: Collection[Node], edges: Collection[Edge]) -> None: 26 | self.nodes: dict[NodeId, TreeSnapshot.Node] = {node.id: node for node in nodes} 27 | self.edges: dict[EdgeId, TreeSnapshot.Edge] = {edge.id: edge for edge in edges} 28 | self._parent = {} 29 | self._children: dict[NodeId, set[NodeId]] = defaultdict(set) 30 | 31 | for edge in edges: 32 | self._parent[edge.target] = edge.source 33 | self._children[edge.source].add(edge.target) 34 | 35 | assert len(self._parent) == len(self.nodes) - 1 36 | assert self._connected() 37 | 38 | def _connected(self) -> bool: 39 | visited = set() 40 | queue = [next(iter(self.nodes))] 41 | while queue: 42 | node = queue.pop() 43 | visited.add(node) 44 | queue.extend(self._children[node] - visited) 45 | return len(visited) == len(self.nodes) 46 | 47 | def node(self, node_id: NodeId) -> Node: 48 | return self.nodes[node_id] 49 | 50 | def edge(self, edge_id: EdgeId) -> Edge: 51 | return self.edges[edge_id] 52 | 53 | def out_edges(self, node_id: NodeId) -> Collection[Edge]: 54 | return [self.edge(edge_id) for edge_id in self.edges if self.edge(edge_id).source == node_id] 55 | 56 | def in_edges(self, node_id: NodeId) -> Collection[Edge]: 57 | return [self.edge(edge_id) for edge_id in self.edges if self.edge(edge_id).target == node_id] 58 | 59 | def parent(self, node_id: NodeId) -> NodeId: 60 | return self._parent[node_id] 61 | 62 | def children(self, node_id: NodeId) -> Collection[NodeId]: 63 | return self._children[node_id] 64 | 65 | def trace(self, node_id: NodeId) -> [NodeId]: 66 | trace_ = [] 67 | while node_id != 0: 68 | trace_.append(node_id) 69 | node_id = self.parent(node_id) 70 | 71 | return trace_[::-1] 72 | 73 | def __dict__(self): 74 | return { 75 | "nodes": self.nodes, 76 | "edges": self.edges, 77 | } 78 | -------------------------------------------------------------------------------- /reasoners/visualization/visualizer_client.py: 
-------------------------------------------------------------------------------- 1 | import dataclasses 2 | import json 3 | from typing import Optional, Union 4 | 5 | import requests 6 | 7 | from reasoners.algorithm import MCTSResult, BeamSearchResult 8 | from reasoners.visualization import TreeLog, TreeLogEncoder 9 | 10 | _API_DEFAULT_BASE_URL = "https://2wz3t0av30.execute-api.us-west-1.amazonaws.com/staging" 11 | _VISUALIZER_DEFAULT_BASE_URL = "https://www.llm-reasoners.net" 12 | 13 | 14 | class VisualizerClient: 15 | def __init__(self, base_url: str = _API_DEFAULT_BASE_URL) -> None: 16 | self.base_url = base_url 17 | 18 | @dataclasses.dataclass 19 | class TreeLogReceipt: 20 | id: str 21 | access_key: str 22 | 23 | @property 24 | def access_url(self) -> str: 25 | return f"{_VISUALIZER_DEFAULT_BASE_URL}/visualizer/{self.id}?accessKey={self.access_key}" 26 | 27 | def post_log(self, data: Union[TreeLog, str, dict]) -> Optional[TreeLogReceipt]: 28 | if isinstance(data, TreeLog): 29 | data = json.dumps(data, cls=TreeLogEncoder) 30 | if isinstance(data, dict): 31 | data = json.dumps(data) 32 | 33 | url = f"{self.base_url}/logs" 34 | headers = {'Content-Type': 'application/json'} 35 | response = requests.post(url, headers=headers, data=data) 36 | 37 | if response.status_code != 200: 38 | print(f"POST Log failed with status code: {response.status_code}, message: {response.text}") 39 | return None 40 | 41 | return self.TreeLogReceipt(**response.json()) 42 | 43 | 44 | def present_visualizer(receipt: VisualizerClient.TreeLogReceipt): 45 | import webbrowser 46 | print(f"Visualizer URL: {receipt.access_url}") 47 | webbrowser.open(receipt.access_url) 48 | 49 | 50 | def visualize(result: Union[TreeLog, MCTSResult, BeamSearchResult], **kwargs): 51 | tree_log: TreeLog 52 | 53 | if isinstance(result, TreeLog): 54 | tree_log = result 55 | elif isinstance(result, MCTSResult): 56 | tree_log = TreeLog.from_mcts_results(result, **kwargs) 57 | elif isinstance(result, BeamSearchResult): 
58 | tree_log = TreeLog.from_beam_search_results(result, **kwargs) 59 | elif isinstance(result, ...): 60 | raise NotImplementedError() 61 | else: 62 | raise TypeError(f"Unsupported result type: {type(result)}") 63 | 64 | receipt = VisualizerClient().post_log(tree_log) 65 | 66 | if receipt is not None: 67 | present_visualizer(receipt) 68 | -------------------------------------------------------------------------------- /reasoners/vllm_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from typing import Optional, Union 4 | import time 5 | import httpx 6 | from . import GenerateOutput 7 | 8 | from requests_futures.sessions import FuturesSession 9 | 10 | API_BASE = os.getenv("VLLM_API_BASE", None) 11 | 12 | class VLLMModel: 13 | def __init__(self, model: str, max_tokens:int = 200, temperature=0.7): 14 | self.model = model 15 | self.max_tokens = max_tokens 16 | self.temperature = temperature 17 | 18 | 19 | def generate(self, 20 | prompt: str, 21 | max_tokens: int = None, 22 | num_return_sequences: int = 1, 23 | stop: Optional[str] = None, 24 | temperature = None, 25 | logprobs=False, 26 | **kwargs) -> GenerateOutput: 27 | if type(prompt) == str: 28 | prompt = [prompt] 29 | 30 | gpt_temperature = self.temperature if temperature is None else temperature 31 | 32 | if max_tokens is None: 33 | max_tokens = self.max_tokens 34 | 35 | try: 36 | session = FuturesSession() 37 | url = API_BASE + "/completions" 38 | 39 | rs = [] 40 | for p in prompt: 41 | data = { 42 | 'model': self.model, 43 | 'prompt': p, 44 | 'temperature': gpt_temperature, 45 | 'n': num_return_sequences, 46 | 'max_tokens': max_tokens, 47 | 'stop': stop, 48 | 'logprobs': logprobs 49 | } 50 | rs.append(session.post(url, json=data, timeout=60)) 51 | 52 | rs = [r.result() for r in rs] 53 | responses = [r.json() for r in rs] 54 | session.close() 55 | 56 | log_prob = None 57 | if logprobs: 58 | log_prob = [ 59 | 
sum(choice["logprobs"]["token_logprobs"][:-1])/(len(choice['logprobs']['token_logprobs'])-1) 60 | for r in responses for choice in r["choices"] 61 | ] 62 | 63 | return GenerateOutput( 64 | text=[choice["text"] for r in responses for choice in r["choices"]], 65 | log_prob=log_prob 66 | ) 67 | 68 | except Exception as e: 69 | print(rs) 70 | raise e 71 | 72 | def get_next_token_logits(self, 73 | prompt: Union[str, list[str]], 74 | candidates: Union[list[str], list[list[str]]], 75 | **kwargs) -> list[np.ndarray]: 76 | if isinstance(prompt, str): 77 | prompt = [prompt] 78 | 79 | session = FuturesSession() 80 | 81 | res = [] 82 | for p in prompt: 83 | data = { 84 | 'model': self.model, 85 | 'prompt': p, 86 | 'n': 1, 87 | 'max_tokens': 1, 88 | 'logprobs': 1000, 89 | } 90 | 91 | r = session.post(API_BASE + "/completions", json=data, timeout=60) 92 | res.append(r) 93 | 94 | responses = [r.result().json() for r in res] 95 | results = [] 96 | for r in responses: 97 | cand_log_probs = [] 98 | top_log_probs = r["choices"][0]["logprobs"]["top_logprobs"][0] 99 | for cand in candidates: 100 | r = -1000 101 | for key in top_log_probs: 102 | if cand in key and top_log_probs[key] > r: 103 | r = top_log_probs[key] 104 | 105 | cand_log_probs.append(r) 106 | 107 | results.append(np.array(cand_log_probs)) 108 | 109 | session.close() 110 | 111 | return results 112 | 113 | def get_loglikelihood(self, 114 | prefixs: list[str], 115 | prompts: list[str], 116 | **kwargs) -> list[np.ndarray]: 117 | 118 | session = FuturesSession() 119 | res = [] 120 | 121 | for prefix, prompt in zip(prefixs, prompts): 122 | 123 | data = { 124 | 'prompt': prompt, 125 | 'prefix': prefix, 126 | } 127 | 128 | r = session.post(API_BASE + "/logprobs", json=data, timeout=60) 129 | res.append(r) 130 | 131 | responses = [r.result().json() for r in res] 132 | results = [] 133 | 134 | for r in responses: 135 | results.append(sum(r["logprobs"])) 136 | 137 | session.close() 138 | 139 | return np.array(results) 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | uvicorn 2 | datasets 3 | fastapi 4 | fire 5 | httpx 6 | numpy 7 | scipy 8 | openai==0.28.0 9 | requests~=2.31.0 10 | PyYAML~=6.0 11 | tarski~=0.8.2 12 | pddl==0.4.0 13 | requests_futures 14 | torch 15 | tqdm 16 | transformers==4.38.2 17 | vllm==0.3.0 18 | pyvis -------------------------------------------------------------------------------- /rot_bargain/.gitignore: -------------------------------------------------------------------------------- 1 | mixtral-vllm 2 | logs 3 | data 4 | __pycache__ 5 | analyze_mcts_log.py 6 | mcts_vis.ipynb 7 | mcts.pkl 8 | lib 9 | *.json 10 | apis 11 | *.html 12 | *.pkl 13 | 14 | !prompts/*.json -------------------------------------------------------------------------------- /rot_bargain/README.md: -------------------------------------------------------------------------------- 1 | # RoT on CraigslistBargain 2 | 3 | ## Quick Start 4 | As with RoT in deterministic environments, we use [vLLM](https://github.com/vllm-project/vllm) for efficient text generation, so you need to launch a vLLM server first when using a self-deployed model. 5 | ```bash 6 | cd ../vllm-server 7 | sh mixtral.sh 8 | ``` 9 | 10 | Then you can run RoT to generate new prompts with guidelines based on the served model. `--seller` is the model name of the seller, and `--rot_output_path` is the path of the file that stores the generated guidelines. 11 | ```bash 12 | python bargain_control.py \ 13 | --mode mcts --n_iter 8 \ 14 | --seller mixtral \ 15 | --stage reflect \ 16 | --rot_output_path prompts/rot.json
17 | ``` 18 | 19 | Finally run LLM with the guidelines: 20 | ```bash 21 | python bargain_control.py \ 22 | --mode mcts --n_iter 8 \ 23 | --seller mixtral --stage evaluate \ 24 | --prompt_path prompts/rot.json 25 | ``` -------------------------------------------------------------------------------- /rot_bargain/bargain_control.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | from datetime import datetime 5 | 6 | if __name__ == '__main__': 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--seller', type=str, default='mixtral') 9 | parser.add_argument('--buyer', type=str, default='gpt4') 10 | parser.add_argument('--n_iters', type=int, default=-1) 11 | parser.add_argument('--panelty', type=float, default=0.0) 12 | parser.add_argument('--mode', type=str, default='mcts') 13 | parser.add_argument('--stage', type=str, default='evaluate') 14 | parser.add_argument('--prompt_path', type=str, default='none') 15 | parser.add_argument('--rot_output_path', type=str, default='prompts/rot.json') 16 | args = parser.parse_args() 17 | 18 | if args.mode == 'mcts': 19 | command = f'python run/run_mcts.py --stage {args.stage}' 20 | else: 21 | command = f'python run/run_cot.py --stage {args.stage}' 22 | 23 | command += f' --buyer {args.buyer} --seller {args.seller}' 24 | 25 | if args.n_iters != -1: 26 | command += f' --n_iters {args.n_iters}' 27 | 28 | command += f' --prompt_path {args.prompt_path} --panelty {args.panelty}' 29 | 30 | log_path = f'logs/{args.mode}_{args.buyer}_{args.seller}_{args.prompt_path.split("/")[-1]}_{args.n_iters}_{args.panelty}_{datetime.now().strftime("%Y%m%d-%H%M")}' 31 | 32 | command += f' --log_path {log_path}' 33 | print(command) 34 | 35 | import time 36 | t = time.time() 37 | os.system(command) 38 | time_consumed = time.time() - t 39 | with open(f'{log_path}/time_consumed.txt', 'w') as f: 40 | f.write(str(time_consumed)) 41 | 42 | if args.stage == 
'reflect': 43 | os.system(f'python mcts_reflect.py --path {log_path} --output-path {args.rot_output_path}') -------------------------------------------------------------------------------- /rot_bargain/core.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | import json 3 | import random 4 | import re 5 | import copy 6 | 7 | class DialogueGame: 8 | def __init__(self, moderator, agents, max_turn, last_message, history) -> None: 9 | self.moderator = moderator 10 | self.agents = agents 11 | self.max_turn = max_turn 12 | self.current_turn = 0 13 | self.history = history 14 | self.last_message = last_message 15 | 16 | def step(self, **kwargs): 17 | agent = self.agents[self.current_turn % 2] 18 | self.last_message = agent.chat(**kwargs) 19 | self.history.append(self.last_message) 20 | 21 | if self.current_turn % 2 == 1: 22 | finished, res = self.moderator.generate_result(self.history) 23 | else: 24 | finished, res = False, 0 25 | 26 | self.current_turn += 1 27 | if self.current_turn >= self.max_turn: 28 | return True, res 29 | 30 | if finished: 31 | return True, res 32 | 33 | return False, None 34 | 35 | def step_and_state(self, **kwargs): 36 | finished, res = self.step(**kwargs) 37 | if finished: 38 | reward = self.compute_reward(res) 39 | else: 40 | reward = 0 41 | return finished, reward, self.gamestate() 42 | 43 | def compute_reward(self, res): 44 | return self.moderator.compute_result(res)['utility'] 45 | 46 | def get_valid_actions(self): 47 | agent = self.agents[self.current_turn % 2] 48 | return agent.get_valid_actions() 49 | 50 | def predict(self): 51 | agent = self.agents[self.current_turn % 2] 52 | return agent.predict() 53 | 54 | def run(self): 55 | while True: 56 | finished, res = self.step() 57 | if finished: 58 | break 59 | return res 60 | 61 | @classmethod 62 | def resume_gamestate(cls, gamestate): 63 | game = cls(gamestate.moderator, gamestate.agents, gamestate.max_turn, gamestate.history[-1], 
gamestate.history) 64 | game.current_turn = gamestate.current_turn 65 | 66 | for agent in game.agents: 67 | agent.game = game 68 | 69 | new_game = copy.copy(game) 70 | new_game.history = game.history[:] 71 | new_game.agents = gamestate.agents 72 | 73 | for a in new_game.agents: 74 | a.game = new_game 75 | 76 | return new_game 77 | 78 | def gamestate(self): 79 | return DialogueGameState(self.history[:], self.moderator, self.agents, self.max_turn, self.current_turn) 80 | 81 | class DialogueGameState: 82 | def __init__(self, history, moderator, agents, max_turn, current_turn) -> None: 83 | self.history = history 84 | self.agents = agents 85 | self.max_turn = max_turn 86 | self.current_turn = current_turn 87 | self.moderator = moderator 88 | 89 | def __str__(self): 90 | return json.dumps(self.history) 91 | 92 | def __hash__(self): 93 | return hash(str(self)) 94 | 95 | def __eq__(self, o: object) -> bool: 96 | return str(self) == str(o) 97 | 98 | def copy(self): 99 | gamestates = copy.copy(self) 100 | gamestates.history = self.history[:] 101 | return gamestates 102 | 103 | class BarginRollout: 104 | def __init__(self): 105 | self.agent = None 106 | 107 | def simulate(self, state): 108 | history = state.history 109 | history = '\n'.join(['{}: {}'.format('Buyer' if i % 2 == 0 else 'Seller', m) for i, m in enumerate(history)]) 110 | 111 | messages = [ 112 | { 113 | "role": "system", 114 | "content": "You are GPT4 developed by OpenAI. Now please act as a bargining dialogue simulator to simulate a bargining where the seller and buyer are maximizing their profit." 115 | }, 116 | { 117 | "role": "user", 118 | "content": f"Dialogue History:\n{history}\n\nPlease complete the dialogue. Note both the seller and buyer are not require to reach an agreement. After the dialogue is completed, show whether this deal is made in [YES] or [NO]. If the answer is [YES], please also show the price in <$xx>. Your completion should start with Buyer. 
I will tip you $100 if you complete this task successfully." 119 | } 120 | ] 121 | 122 | for i in range(10): 123 | try: 124 | dialogue_history = None 125 | price = 0 126 | if self.agent.model == 'gpt4': 127 | res = query_gpt4(messages) 128 | elif self.agent.model == 'chatgpt': 129 | res = query_chatgpt(messages) 130 | else: 131 | res = query_mixtral(messages) 132 | 133 | if res.startswith('Seller'): 134 | continue 135 | 136 | dialogue_history = res.split('[YES]')[0].split('[NO]')[0].strip() 137 | if '[YES]' in res.upper(): 138 | price = re.findall('<\$[0-9,\.]+>', res)[0].replace(',', '') 139 | # find the number in the price and change it into float 140 | price = float(re.findall('[0-9\.]+', price)[0]) 141 | break 142 | elif '[NO]' in res.upper(): 143 | break 144 | except: 145 | pass 146 | if i == 10: 147 | return '', 0 148 | 149 | r = self.agent.game.moderator.compute_result(price)['utility'] 150 | 151 | return dialogue_history, r 152 | 153 | 154 | class BarginingModerator(): 155 | def __init__(self, data, mode, panelty) -> None: 156 | self.data = data 157 | self.mode = mode 158 | self.panelty = panelty 159 | 160 | def generate_result(self, history): 161 | history = history 162 | history = '\n'.join(['{}: {}'.format('Buyer' if i % 2 == 0 else 'Seller', m) for i, m in enumerate(history)]) 163 | 164 | messages_2 = [ 165 | { 166 | "role": "system", 167 | "content": "You are GPT4, a large language model trained by OpenAI. Now please act as a bargining system to determine whether the seller and buyer has reached an agreement." 168 | }, 169 | { 170 | "role": "user", 171 | "content": f"Dialogue History:\n{history}\n\n\nPlease first summarize the buyer and seller's opinion and determine whether the buyer and seller has reach an agreement in [YES] or [NO]. If an agreement is reached, please also provide the price of the item in <$XX>. 
You should follow this format:\nReview: ...\nResult: [YES]/[NO] <$...>" 172 | } 173 | ] 174 | 175 | res = query_gpt4(messages_2) 176 | 177 | finished = 'YES' in res 178 | 179 | if finished: 180 | try: 181 | price = re.findall('<\$[0-9,\.]+>', res)[0].replace(',', '') 182 | # find the number in the price and change it into float 183 | price = float(re.findall('[0-9\.]+', price)[0]) 184 | return finished, price 185 | except: 186 | return False, 0 187 | return finished, 0 188 | 189 | def compute_result(self, price): 190 | seller_price = self.data['seller-price'] 191 | buyer_price = self.data['buyer-price'] 192 | 193 | return { 194 | 'utility': (2*price - seller_price - buyer_price) / (seller_price - buyer_price) if price != 0 else self.panelty, 195 | 'agreement': 1 if price != 0 else 0, 196 | } -------------------------------------------------------------------------------- /rot_bargain/mcts_reflect.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from collections import defaultdict 4 | import numpy as np 5 | from utils import * 6 | 7 | 8 | def covnert_graph(relations, nodes): 9 | old_relations = relations 10 | 11 | actions_nodes = defaultdict(dict) 12 | for key, value in old_relations.items(): 13 | if '__' in key: 14 | node = key.split('__')[0] 15 | action = key.split('__')[1] 16 | 17 | if action not in actions_nodes[node]: 18 | actions_nodes[node][action] = { 19 | 'value': 0, 20 | 'children': [] 21 | } 22 | actions_nodes[node][action]['children'].extend(value) 23 | 24 | return actions_nodes 25 | 26 | def compute_action_value(relations, nodes): 27 | for node, actions in relations.items(): 28 | for action, info in actions.items(): 29 | info['value'] = sum([nodes[child]['value'] for child in info['children']]) / len(info['children']) 30 | 31 | return relations 32 | 33 | def find_good_actions_nodes(action_nodes): 34 | node_value = { 35 | node: [action_nodes[node][a]['value'] for a in 
action_nodes[node]] for node in action_nodes 36 | } 37 | 38 | node_avg = { 39 | node: np.mean(node_value[node]) for node in node_value 40 | } 41 | 42 | node_var = { 43 | node: np.var(node_value[node]) for node in node_value 44 | } 45 | 46 | node_max = { 47 | node: np.max(node_value[node]) for node in node_value 48 | } 49 | 50 | node_gain = { 51 | node: node_max[node] - node_avg[node] for node in node_value 52 | } 53 | 54 | return node_avg, node_var, node_max, node_gain 55 | 56 | def merge_guidelines(guidelines): 57 | 58 | messages = [{ 59 | "role": "user", 60 | "content": "Please merge the following policies into one:\n" + "\n".join(guidelines) 61 | }] 62 | 63 | res = query_gpt4(messages) 64 | return res 65 | 66 | def process(args): 67 | gpt4_analysis = [] 68 | 69 | path = args.path 70 | 71 | bargains = [f for f in os.listdir(path) if os.path.isdir(os.path.join(path, f))] 72 | 73 | 74 | for i in bargains: 75 | p = os.path.join(path, i) 76 | 77 | relations = json.load(open(f'{p}/relationship.json')) 78 | nodes = json.load(open(f'{p}/node_info.json')) 79 | background = json.load(open(f'{p}/game.json')) 80 | 81 | recover_dialogue(relations, nodes, 'node_0', background['history'][:2]) 82 | action_nodes = covnert_graph(relations, nodes) 83 | 84 | action_nodes = compute_action_value(action_nodes, nodes) 85 | good_nodes = find_good_actions_nodes(action_nodes) 86 | 87 | selected_nodes = [k for k in good_nodes[3] if good_nodes[3][k] > 0.1] 88 | 89 | for n in selected_nodes: 90 | res = query_gpt4_for_advise_node(action_nodes, n, nodes) 91 | 92 | gpt4_analysis.append({ 93 | 'node': n, 94 | 'res': res.split('Summary:')[1].strip() 95 | }) 96 | 97 | guidelines = merge_guidelines([g['res'] for g in gpt4_analysis]) 98 | 99 | with open(args.output_path, 'w') as f: 100 | json.dump(guidelines, f, indent=4) 101 | 102 | 103 | def query_gpt4_for_advise_node(action_nodes, node, nodes): 104 | history = nodes[node]['history'] 105 | history = '\n\n'.join([f'Buyer: {h}' if i % 2 == 0 else 
f'Seller: {h}' for i, h in enumerate(history)]) + '\n\n' 106 | 107 | strategies = action_nodes[node].keys() 108 | strategies = '\n'.join([f'Strategy {i}: {s} Value: {action_nodes[node][s]["value"]}' for i, s in enumerate(strategies) if s != 'history']) 109 | 110 | prompt = f"Dialogue history:\n{history}\nThe strategies of the seller used to reply in the current response and their rewards are listed below:\n{strategies}\n\n" + \ 111 | f"Can you analyze the reason and then summarize the findings into a policy in one sentence in the format of Analysis: ...\nSummary:\n..." 112 | 113 | res = query_gpt4([ 114 | { 115 | 'role': 'system', 116 | 'content': 'You are GPT4 trained by OpenAI. Now please act as a dialogue analyzer who can evaluate the behaviour of the seller in a bargining dialogue.', 117 | }, 118 | { 119 | 'role': 'user', 120 | 'content': prompt 121 | } 122 | ]) 123 | return res 124 | 125 | 126 | import argparse 127 | 128 | def recover_dialogue(relation, node, current_node, history): 129 | node[current_node]['history'] = history + [node[current_node]['response']] 130 | 131 | for r in relation: 132 | if r.split('__')[0] == current_node: 133 | for child in relation[r]: 134 | recover_dialogue(relation, node, child, history) 135 | 136 | if __name__ == '__main__': 137 | parser = argparse.ArgumentParser() 138 | parser.add_argument('--path', type=str, default='') 139 | parser.add_argument('--output-path', type=str, default='guidelines.json') 140 | args = parser.parse_args() 141 | 142 | process(args) -------------------------------------------------------------------------------- /rot_bargain/prompts/rot_w_panelty.json: -------------------------------------------------------------------------------- 1 | "To create a negotiation strategy that places a higher value on the agreement of a bargain, consider the following approach:\n\n1. Establish the Value Proposition: Begin by clearly communicating the unique benefits and features of your product. 
Ensure that the buyer understands the long-term advantages and potential cost savings associated with the purchase. This sets the stage for a value-driven conversation rather than one focused solely on price.\n\n2. Build Strong Relationships: Prioritize the development of a positive, collaborative relationship with the buyer. Show genuine interest in their needs and how your product can meet them. A strong rapport can create a foundation of trust and make the buyer more receptive to your value proposition.\n\n3. Leverage Social Proof: Use testimonials and case studies from satisfied customers to underscore the value of your product. This can help to validate your asking price and demonstrate the satisfaction of past clients, making the perceived value more tangible.\n\n4. Highlight Urgency and Exclusivity: Create a sense of urgency by indicating that the availability of the product at this level of value is limited. Additionally, make the buyer feel exclusive by suggesting that this offer is unique to them or to a select group of customers.\n\n5. Offer Non-Price Concessions: If price becomes a sticking point, shift the focus to non-price concessions that can add value to the deal without lowering the price. This might include extended warranties, additional services, or flexible delivery options.\n\n6. Anchor Your Price: Start negotiations with a price that is higher than what you expect to receive. This sets a psychological benchmark for the negotiation, making your actual desired price seem more reasonable by comparison.\n\n7. Emphasize Confidence in the Product: Throughout the negotiation, maintain a strong belief in the worth of your product. This confidence can be contagious and can influence the buyer's perception of value.\n\n8. Maintain Flexibility: While you should negotiate from a position of strength, be prepared to make strategic concessions that can lead to an agreement. 
Flexibility demonstrates a willingness to work with the buyer and can help close the deal.\n\n9. Focus on the Win-Win Outcome: Frame the negotiation as a partnership aimed at achieving a mutually beneficial outcome. When the buyer sees the agreement as a win for both sides, they are more likely to value the deal and less likely to fixate on price alone.\n\n10. Close with Confidence: When you sense that the buyer acknowledges the value of your product and is close to making a decision, confidently guide them towards finalizing the deal. Reiterate the key benefits and the unique opportunity the agreement represents.\n\nBy adopting this strategy, you shift the focus from price to the overall value of the agreement, fostering a negotiation environment where both parties feel they are getting a fair deal." -------------------------------------------------------------------------------- /rot_bargain/prompts/rot_wo_panelty.json: -------------------------------------------------------------------------------- 1 | "You should emphasize the demand for their product, thus preserving the value perception without substantial price compromises. It's important to maintain this perception of value by offering non-price concessions and using strategies that highlight urgency, exclusivity, and added value, thereby negotiating from a position of strength. Confidence in the product’s inherent value should be the cornerstone of a seller’s negotiation policy, which should also prioritize building a strong rapport with the buyer, facilitating a transaction that feels mutually beneficial. Additionally, sellers should focus on articulating the long-term value and potential cost savings of their product, leveraging testimonials and the anchoring effect to support their pricing structure. All of these strategies should be employed with a degree of negotiation flexibility, aiming to secure a final price that approaches the original asking price as closely as possible." 
-------------------------------------------------------------------------------- /rot_bargain/run/run_cot.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir)) 4 | from core import * 5 | from user_simulators.naive_simulator import * 6 | from user_simulators.strategy_simulator import * 7 | from user_simulators.mcts_simulator import * 8 | import json 9 | import argparse 10 | 11 | data = json.load(open('data/CraigslistBargain/dev-luis-post.json')) 12 | 13 | stat = {'utility': 0, 'agreement': 0} 14 | 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--log_path', type=str) 18 | parser.add_argument('--buyer', type=str) 19 | parser.add_argument('--seller', type=str) 20 | parser.add_argument('--prompt_path', type=str, default='none') 21 | parser.add_argument('--stage', type=str, default='evaluate') 22 | 23 | args = parser.parse_args() 24 | 25 | if args.stage == 'reflect': 26 | data = data[:2] 27 | else: 28 | data = data[2:20] 29 | 30 | for i, d in enumerate(data): 31 | buyer = UserSimulator('You are GPT4, trained by OpenAI. Now enter role-playing mode. In the following conversation, you will play as a buyer in a price bargaining game. Please make your response short and succinct.', [{ 32 | 'role': 'user', 33 | 'content': f'You are the buyer who is trying to buy a product at the price of ${d["buyer-price"]}. The product information are described as follows: \n\nTitle:{d["title"]}\nDescription:\n{d["description"]}\n\n\nNow start the game.' 34 | }], model=args.buyer) 35 | 36 | if args.prompt_path == 'none': 37 | seller = UserSimulator('You are GPT4, trained by OpenAI. Now enter role-playing mode. In the following conversation, you will play as a seller in a price bargaining game. 
Please make your response short and succinct.', [{ 38 | 'role': 'user', 39 | 'content': f'You are the seller who is trying to sell a product at the price of ${d["seller-price"]}. The product information are described as follows: \n\nTitle:{d["title"]}\nDescription:\n{d["description"]}\n\nNow start the game.\nHow much is this product?' 40 | },], history_start_idx=1, model=args.seller) 41 | 42 | else: 43 | strategy = json.load(open(f'prompt/{args.prompt}')) 44 | 45 | seller = FixedStrategyUserSimulator('You are GPT4, trained by OpenAI. Now enter role-playing mode. In the following conversation, you will play as a seller in a price bargaining game. Please make your response short and succinct.', [{ 46 | 'role': 'user', 47 | 'content': f'You are the seller who is trying to sell a product at the price of ${d["seller-price"]}. The product information are described as follows: \n\nTitle:{d["title"]}\nDescription:\n{d["description"]}\n\nNow start the game.\nHow much is this product?' 48 | }], history_start_idx=1, strategy=strategy, model=args.seller) 49 | 50 | 51 | res = f'Hi, its price is ${d["seller-price"]}.' 52 | 53 | history = [ 54 | 'How much is this product?', 55 | f'Hi, its price is ${d["seller-price"]}.' 
56 | ] 57 | 58 | moderator = BarginingModerator(d, mode='seller', panelty=0.0) 59 | 60 | game = DialogueGame(moderator, [buyer, seller], max_turn=8, last_message=res, history=history) 61 | 62 | seller.game = game 63 | buyer.game = game 64 | 65 | res = game.run() 66 | res = moderator.compute_result(res) 67 | 68 | import os 69 | os.makedirs(f'{args.log_path}/{i}', exist_ok=True) 70 | 71 | with open(f'{args.log_path}/{i}/game.json', 'w') as f: 72 | json.dump({ 73 | 'item': d['title'], 74 | 'description': d['description'], 75 | 'buyer-price': d['buyer-price'], 76 | 'seller-price': d['seller-price'], 77 | 'history': game.history, 78 | 'utility': res['utility'], 79 | 'agreement': res['agreement'] 80 | }, f) 81 | 82 | stat['utility'] += res['utility'] 83 | stat['agreement'] += res['agreement'] 84 | 85 | 86 | print(res) 87 | 88 | stat['utility'] = stat['utility'] / stat['agreement'] if stat['agreement'] != 0 else 0 89 | stat['utility'] = stat['utility'] / len(data) 90 | stat['agreement'] = stat['agreement'] / len(data) 91 | 92 | with open(f'{args.log_path}/stat.json', 'w') as f: 93 | json.dump(stat, f) -------------------------------------------------------------------------------- /rot_bargain/run/run_mcts.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir)) 4 | 5 | from core import * 6 | from user_simulators import * 7 | import json 8 | import argparse 9 | 10 | mcts_config = MCTSConfig() 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--log_path', type=str) 15 | parser.add_argument('--buyer', type=str) 16 | parser.add_argument('--seller', type=str) 17 | parser.add_argument('--n_iters', type=int, default=8) 18 | parser.add_argument('--panelty', type=float, default=0.0) 19 | parser.add_argument('--stage', type=str, default='evaluate') 20 | parser.add_argument('--prompt_path', type=str, 
default='none') 21 | 22 | args = parser.parse_args() 23 | 24 | if args.prompt_path == 'none': 25 | prompt = '' 26 | else: 27 | prompt = json.load(open(args.prompt_path)) 28 | 29 | mcts_config.threshold_offset = 0.2 30 | mcts_config.reuse_mcts = True 31 | mcts_config.num_simulations = args.n_iters 32 | mcts_config.max_realizations = 1 33 | 34 | data = json.load(open('data/CraigslistBargain/dev-luis-post.json')) 35 | 36 | if args.stage == 'reflect': 37 | data = data[:2] 38 | else: 39 | data = data[2:20] 40 | 41 | stat = {'utility': 0, 'agreement': 0} 42 | 43 | for i, d in enumerate(data): 44 | buyer = UserSimulator('Now enter the role-playing mode. In the following conversation, you will play as a buyer in a price bargaining game.', [{ 45 | 'role': 'user', 46 | 'content': f'You are the buyer who is trying to buy a product at the price of ${d["buyer-price"]}. The product information are described as follows: \n\nTitle:{d["title"]}\nDescription:\n{d["description"]}\nPlease make your response short and succinct.\n\nNow start the game.' 47 | }], model=args.buyer) 48 | 49 | seller = MCTSUserSimulator('Now enter the role-playing mode. In the following conversation, you will play as a seller in a price bargaining game.', [{ 50 | 'role': 'user', 51 | 'content': f'You are the seller who is trying to sell a product at the price of ${d["seller-price"]}. The product information are described as follows: \n\nTitle:{d["title"]}\nDescription:\n{d["description"]}\nPlease make your response short and succinct.{prompt}\n\nNow start the game.\nHow much is this product?' 
52 | },], history_start_idx=1, actions=[ 53 | "Strategy 1: Emphasize Exclusivity\nHighlight the uniqueness and limited availability of the product to create a sense of urgency and justify its value.", 54 | "Strategy 2: Payment Plans\nOffer flexible payment options or installment plans that enable the customer to purchase at the asking price but spread out the payment over time.", 55 | "Strategy 3: Customer Loyalty\nConsider a small discount or added benefit for repeat customers, to reinforce loyalty and encourage future full-price purchases.", 56 | "Strategy 4: Price Anchoring\nMention higher-priced comparable items first, making the asking price seem more reasonable by comparison.", 57 | "Strategy 5: Create a Win-Win Situation\nFind out what's most valuable to the buyer that doesn't significantly affect your bottom line and leverage that in the negotiation.", 58 | ], configs=mcts_config, simulator=BarginRollout(), model=args.seller) 59 | 60 | res = f'Hi, its price is ${d["seller-price"]}.' 61 | 62 | history = [ 63 | 'How much is this product?', 64 | f'Hi, its price is ${d["seller-price"]}.' 
65 | ] 66 | 67 | moderator = BarginingModerator(d, mode='seller', panelty=args.panelty) 68 | 69 | game = DialogueGame(moderator, [buyer, seller], max_turn=8, last_message=res, history=history) 70 | 71 | seller.game = game 72 | buyer.game = game 73 | 74 | res = game.run() 75 | res = moderator.compute_result(res) 76 | 77 | import os 78 | 79 | os.makedirs(f'{args.log_path}/{i}', exist_ok=True) 80 | with open(f'{args.log_path}/{i}/game.json', 'w') as f: 81 | json.dump({ 82 | 'item': d['title'], 83 | 'description': d['description'], 84 | 'buyer-price': d['buyer-price'], 85 | 'seller-price': d['seller-price'], 86 | 'history': game.history, 87 | 'utility': res['utility'], 88 | 'agreement': res['agreement'] 89 | }, f, indent=4) 90 | 91 | seller.mcts.dump(f'{args.log_path}/{i}/mcts.pkl') 92 | 93 | os.system(f'python analyze_mcts_log.py --path {args.log_path}/{i}') 94 | stat['utility'] += res['utility'] 95 | stat['agreement'] += res['agreement'] 96 | 97 | print(res) 98 | 99 | stat['utility'] = stat['utility'] / stat['agreement'] if stat['agreement'] != 0 else 0 100 | stat['utility'] = stat['utility'] / len(data) 101 | stat['agreement'] = stat['agreement'] / len(data) 102 | 103 | with open(f'{args.log_path}/stat.json', 'w') as f: 104 | json.dump(stat, f, indent=4) 105 | 106 | print(stat) -------------------------------------------------------------------------------- /rot_bargain/run_bargain.sh: -------------------------------------------------------------------------------- 1 | export OPENAI_API_KEY='XXX' 2 | 3 | # generate the rot guidelines 4 | python bargain_control.py \ 5 | --mode mcts --n_iter 8 \ 6 | --seller mixtral --stage reflect \ 7 | --rot_output_path prompts/rot.json 8 | 9 | # evaluate the rot guidelines 10 | python bargain_control.py \ 11 | --mode mcts --n_iter 8 \ 12 | --seller mixtral --stage evaluate \ 13 | --prompt_path prompts/rot.json -------------------------------------------------------------------------------- /rot_bargain/user_simulators/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .mcts_simulator import * 2 | from .naive_simulator import * 3 | from .strategy_simulator import * -------------------------------------------------------------------------------- /rot_bargain/user_simulators/mcts_simulator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | import math 4 | 5 | from core import * 6 | from user_simulators.naive_simulator import UserSimulator 7 | from user_simulators.strategy_simulator import StrategyUserSimulator 8 | 9 | from collections import defaultdict 10 | from functools import partial 11 | 12 | import time 13 | 14 | class MCTSConfig: 15 | def __init__(self) -> None: 16 | self.cpuct = 1 17 | self.Q_0 = -0.5 18 | self.num_simulations = 5 19 | self.max_realizations = 2 20 | self.smoothing = 1 21 | 22 | class DialogueNode: 23 | def __init__(self, data, parent) -> None: 24 | self.data = data 25 | self.parent = parent 26 | self.children = [] 27 | self.visits = 0 28 | self.value = 0 29 | 30 | parent.children.append(self) 31 | 32 | class OpenLoopMCTS: 33 | def __init__(self, configs, simulator=None) -> None: 34 | self.configs = configs 35 | 36 | self.Ns: dict = {} 37 | self.Nsa: dict = {} 38 | self.Q: dict = {} 39 | # utility 40 | self.valid_moves: dict = {} 41 | self.terminals: dict = {} 42 | self.simulator = simulator 43 | # debugging / more information 44 | self.Vs: dict = {} 45 | self.realizations: dict = defaultdict(list) 46 | self.realizations_Vs: dict = defaultdict(dict) 47 | self.realizations_Ns: dict = defaultdict(dict) 48 | self.max_realizations = configs.max_realizations 49 | 50 | self.simulation_results = defaultdict(list) 51 | 52 | def dump(self, name): 53 | import pickle 54 | with open(f'{name}', 'wb') as f: 55 | pickle.dump(self, f) 56 | 57 | def _init_node(self, state: DialogueGameState): 58 | game = DialogueGame.resume_gamestate(state) 59 | 
allowed_actions = game.get_valid_actions() 60 | 61 | self.valid_moves[state] = allowed_actions 62 | 63 | self.Ns[state] = 0 64 | self.Nsa[state] = {action: 0 for action in self.valid_moves[state]} 65 | self.Q[state] = {action: self.configs.Q_0 for action in self.valid_moves[state]} 66 | 67 | v = [1/len(allowed_actions) for _ in allowed_actions] 68 | self.Vs[state] = v 69 | 70 | return v 71 | 72 | def search(self, state): 73 | if not state in self.valid_moves: 74 | self._init_node(state) 75 | 76 | best_action = self.select_action(state) 77 | game = DialogueGame.resume_gamestate(state) 78 | 79 | realization_key = f"{state}__{best_action}" 80 | 81 | if not realization_key in self.realizations or len(self.realizations[realization_key]) < self.max_realizations: 82 | # mcts agent action 83 | finished, r, realization_state = game.step_and_state(action=best_action) 84 | rs = realization_state.copy() 85 | self.realizations[realization_key].append(rs) 86 | 87 | if finished: 88 | self.update(state, best_action, r) 89 | self.update_realizations(realization_key, realization_state, r) 90 | 91 | return r 92 | 93 | history, r = self.simulator.simulate(realization_state) 94 | self.simulation_results[rs].append((history, r)) 95 | self.update(state, best_action, r) 96 | self.update_realizations(realization_key, realization_state, r) 97 | 98 | return r 99 | else: 100 | realization_state = random.choice(self.realizations[realization_key]) 101 | game = DialogueGame.resume_gamestate(realization_state) 102 | 103 | # the other agent action 104 | finished, r, next_state = game.step_and_state() 105 | 106 | if finished: 107 | self.update(state, best_action, r) 108 | self.update_realizations(realization_key, realization_state, r) 109 | 110 | self.Vs[next_state] = r 111 | self.Q[next_state] = {'': r} 112 | 113 | return r 114 | 115 | v = self.search(next_state) 116 | 117 | self.update(state, best_action, v) 118 | self.update_realizations(realization_key, realization_state, v) 119 | 120 | return 
v 121 | 122 | def get_best_realization(self, state, action): 123 | best_v = -10000 124 | best_realization = None 125 | for realization in self.realizations[f"{state}__{action}"]: 126 | if realization in self.realizations_Vs[f"{state}__{action}"]: 127 | v = self.realizations_Vs[f"{state}__{action}"][realization] 128 | if v > best_v: 129 | best_v = v 130 | best_realization = realization 131 | return best_realization 132 | 133 | def select_action(self, state): 134 | best_uct = -100 135 | best_actions = [] 136 | for i, a in enumerate(self.valid_moves[state]): 137 | Ns = self.Ns[state] 138 | uct = self.Q[state][a] + self.configs.cpuct * math.sqrt(Ns) / (1 + self.Nsa[state][a]) 139 | if abs(uct - best_uct) < 1e-5: 140 | best_actions.append(a) 141 | elif uct > best_uct: 142 | best_uct = uct 143 | best_actions = [a] 144 | return random.choice(best_actions) 145 | 146 | def update(self, state, action, v): 147 | self.Q[state][action] = (self.Nsa[state][action] * self.Q[state][action] + v) / (self.Nsa[state][action] + 1) 148 | self.Ns[state] += 1 149 | self.Nsa[state][action] += 1 150 | 151 | def update_realizations(self, realization_key, realization_state, v): 152 | if realization_state in self.realizations_Vs[realization_key]: 153 | vv = self.realizations_Vs[realization_key][realization_state] 154 | n = self.realizations_Ns[realization_key][realization_state] 155 | else: 156 | vv = 0 157 | n = 0 158 | 159 | self.realizations_Vs[realization_key][realization_state] = (n * vv + v) / (n + 1) 160 | self.realizations_Ns[realization_key][realization_state] = n + 1 161 | 162 | def get_best_action(self, state): 163 | best_action = None 164 | best_v = -10000 165 | for a in self.valid_moves[state]: 166 | v = self.Q[state][a] 167 | if v > best_v: 168 | best_v = v 169 | best_action = a 170 | return best_action 171 | 172 | class MCTSUserSimulator(UserSimulator): 173 | def __init__(self, system, messages, actions, configs, history_start_idx=0, simulator=None, model='gpt4') -> None: 174 | 
self.actions = actions 175 | self.configs = configs 176 | self.simulator = simulator 177 | self.mcts = None 178 | if self.simulator is not None: 179 | self.simulator.agent = self 180 | super().__init__(system, messages, history_start_idx, model=model) 181 | 182 | def chat(self, **kwargs): 183 | best_action, utterance = self.get_action_utterance() 184 | return utterance 185 | 186 | def get_valid_actions(self): 187 | return self.actions 188 | 189 | def get_action_utterance(self): 190 | if self.configs.reuse_mcts and self.mcts is not None: 191 | mcts = self.mcts 192 | else: 193 | mcts = OpenLoopMCTS(self.configs, self.simulator) 194 | self.mcts = mcts 195 | 196 | game = self.game 197 | 198 | gamestate = self.game.gamestate() 199 | # replace all MCTSSimulator with StrategySimulator 200 | for i in range(len(gamestate.agents)): 201 | if isinstance(gamestate.agents[i], MCTSUserSimulator): 202 | gamestate.agents[i].chat = partial(StrategyUserSimulator.chat, gamestate.agents[i]) 203 | 204 | t = time.time() 205 | for i in range(self.configs.num_simulations): 206 | mcts.search(gamestate) 207 | print(f'{i+1}/{self.configs.num_simulations} {time.time()-t}') 208 | 209 | for a in game.agents: 210 | a.game = game 211 | 212 | best_action = mcts.get_best_action(gamestate) 213 | best_realization = mcts.get_best_realization(gamestate, best_action) 214 | 215 | if best_realization is None: 216 | utterance = self.chat(action=best_action) 217 | else: 218 | utterance = best_realization.history[-1] 219 | 220 | for i in range(len(gamestate.agents)): 221 | if isinstance(gamestate.agents[i], MCTSUserSimulator): 222 | gamestate.agents[i].chat = partial(type(gamestate.agents[i]).chat, gamestate.agents[i]) 223 | 224 | return best_action, utterance -------------------------------------------------------------------------------- /rot_bargain/user_simulators/naive_simulator.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | import numpy 
as np 3 | class UserSimulator: 4 | def __init__(self, system, bg_messages, history_start_idx=0, model='gpt4') -> None: 5 | self.system = system 6 | self.bg_messages = bg_messages 7 | self.history_start_idx = history_start_idx 8 | self.model=model 9 | self.game = None 10 | 11 | def chat(self, **kwargs): 12 | messages = self.prepare_messages() 13 | 14 | if self.model == 'gpt4': 15 | res = query_gpt4(messages) 16 | elif self.model == 'chatgpt': 17 | res = query_chatgpt(messages) 18 | elif self.model == 'mixtral': 19 | res = query_mixtral(messages) 20 | 21 | return res 22 | 23 | def get_valid_actions(self): 24 | return ['None'] 25 | 26 | def predict(self): 27 | return np.array([1]), 0 28 | 29 | def prepare_messages(self): 30 | guidelines = getattr(self, 'guidelines', '') 31 | 32 | system_message = [{ 33 | 'role': 'system', 34 | 'content': self.system + guidelines 35 | }] 36 | 37 | history = self.game.history[self.history_start_idx:] 38 | if len(self.bg_messages) == 0: 39 | user = 'assistant' 40 | else: 41 | user = self.bg_messages[-1]['role'] 42 | 43 | if user == 'assistant': 44 | history_messages = [{ 45 | 'role': 'user' if i % 2 == 0 else 'assistant', 46 | 'content': m 47 | } for i, m in enumerate(history)] 48 | else: 49 | history_messages = [{ 50 | 'role': 'user' if i % 2 == 1 else 'assistant', 51 | 'content': m 52 | } for i, m in enumerate(history)] 53 | 54 | messages = system_message + self.bg_messages + history_messages 55 | 56 | return messages 57 | -------------------------------------------------------------------------------- /rot_bargain/user_simulators/strategy_simulator.py: -------------------------------------------------------------------------------- 1 | from .naive_simulator import UserSimulator 2 | from utils import * 3 | import random 4 | 5 | class StrategyUserSimulator(UserSimulator): 6 | def __init__(self, system, messages, actions, configs, history_start_idx=0, model='gpt4') -> None: 7 | self.actions = actions 8 | self.configs = configs 9 | 10 | 
super().__init__(system, messages, history_start_idx, model=model) 11 | 12 | def chat(self, **kwargs): 13 | messages = self.prepare_messages() 14 | 15 | if 'action' in kwargs: 16 | action = kwargs['action'] 17 | else: 18 | action = self.get_action() 19 | 20 | if not action == '': 21 | messages = messages[:-1] + [{ 22 | 'role': 'user', 23 | 'content': f'Please reply with the following strategy: {action}. ' + messages[-1]['content'] 24 | }] 25 | 26 | if self.model == 'gpt4': 27 | res = query_gpt4(messages) 28 | elif self.model == 'chatgpt': 29 | res = query_chatgpt(messages) 30 | else: 31 | res = query_mixtral(messages) 32 | 33 | return res 34 | 35 | def get_action(self): 36 | return random.choice(self.actions) 37 | 38 | def get_valid_actions(self): 39 | return self.actions 40 | 41 | 42 | class FixedStrategyUserSimulator(UserSimulator): 43 | def __init__(self, system, messages, history_start_idx=0, strategy='', model='gpt4') -> None: 44 | self.strategy = strategy 45 | 46 | super().__init__(system, messages, history_start_idx, model=model) 47 | 48 | def chat(self, **kwargs): 49 | messages = self.prepare_messages() 50 | 51 | new_messages = messages[:-1] + [{ 52 | 'role': 'user', 53 | 'content': f'Please reply with the following strategy: {self.strategy}. 
' + messages[-1]['content'] 54 | }] 55 | 56 | if self.model == 'gpt4': 57 | res = query_gpt4(new_messages) 58 | elif self.model == 'chatgpt': 59 | res = query_chatgpt(new_messages) 60 | elif self.model == 'mistral': 61 | res = query_mixtral(new_messages) 62 | 63 | return res 64 | 65 | def get_action(self): 66 | return random.choice(self.actions) 67 | 68 | def get_valid_actions(self): 69 | return self.actions -------------------------------------------------------------------------------- /rot_bargain/utils.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import time 3 | import httpx 4 | from requests_futures.sessions import FuturesSession 5 | 6 | 7 | def query_gpt4(messages): 8 | for i in range(5): 9 | try: 10 | response = openai.ChatCompletion.create( 11 | messages=messages, temperature=0.7, n=1, model="gpt-4-turbo") 12 | 13 | return response['choices'][0]['message']['content'] 14 | except Exception as e: 15 | print(e, response) 16 | time.sleep(5) 17 | 18 | raise e 19 | 20 | def query_chatgpt(messages): 21 | for i in range(5): 22 | try: 23 | response = openai.ChatCompletion.create( 24 | messages=messages, 25 | n=1, 26 | temperature=0.7, 27 | model="gpt-3.5-turbo-1106" 28 | ) 29 | return response['choices'][0]['message']['content'] 30 | except Exception as e: 31 | print(e, response) 32 | time.sleep(10) 33 | raise e 34 | 35 | def query_mixtral(messages, n=1): 36 | if n == 1: 37 | return httpx.post('http://0.0.0.0:23100/v1/chat/completions', json={ 38 | 'messages': messages, 39 | "model": 'mixtral', 40 | 'n': n, 41 | 'temperature': 0.7 42 | }, timeout=60).json()['choices'][0]['message']['content'].strip() 43 | else: 44 | return [c['message']['content'].strip() for c in httpx.post('http://0.0.0.0:23100/v1/chat/completions', json={ 45 | 'messages': messages, 'temperature': 0.7, "model": 'mixtral', 'n': n}, timeout=60).json()['choices']] 46 | 47 | def query_multiple_mixtral(messages): 48 | session = 
FuturesSession() 49 | futures = [] 50 | for m in messages: 51 | futures.append(session.post('http://0.0.0.0:23100/v1/chat/completions', json={'messages': m, 'temperature': 0.7, "model": 'mixtral'}, timeout=60)) 52 | 53 | res = [] 54 | for f in futures: 55 | res.append(f.result().json()['choices'][0]['message']['content'].strip()) 56 | 57 | return res 58 | 59 | def join_bargining_history(history): 60 | history = '\n'.join(['{}: {}'.format('Buyer' if i % 2 == 0 else 'Seller', m) for i, m in enumerate(history)]) 61 | return history -------------------------------------------------------------------------------- /rot_scripts/blocksworld_analysis.py: -------------------------------------------------------------------------------- 1 | from gpt4_utils import query_gpt4 2 | 3 | import sys 4 | import os 5 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))) 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, 'examples', 'blocksworld'))) 7 | 8 | import os 9 | import json 10 | 11 | from tqdm import trange 12 | from reasoners.visualization import visualize, analyze, TreeLog 13 | import reasoners.benchmark.bw_utils as bw_utils 14 | 15 | os.environ['VAL'] = 'LLMs-Planning/planner_tools/VAL' 16 | 17 | 18 | class Evaluator: 19 | def __init__(self, data, depth, config_file, domain_file) -> None: 20 | self.data = data 21 | self.depth = depth 22 | self.config_file = config_file 23 | self.domain_file = domain_file 24 | 25 | def is_terminal(self, node): 26 | if 'step_idx' not in node.data: 27 | return False 28 | return node.data['step_idx'] == self.depth 29 | 30 | def get_score(self, node): 31 | output = node.data['history_actions'].replace(',', '\n') 32 | bw_utils.text_to_plan_blocksworld(output, self.data["instance_file"], self.config_file, self.domain_file, 'tmp.plan') 33 | correct = bw_utils.validate_plan(self.domain_file, self.data["instance_file"], 'tmp.plan')[0] 34 | 35 | return correct 36 | 37 | config_file = 
'examples/blocksworld/data/bw_config.yaml' 38 | domain_file = 'examples/blocksworld/data/generated_domain.pddl' 39 | 40 | if __name__ == '__main__': 41 | import argparse 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument('--path', type=str) 44 | parser.add_argument('--output_name', type=str, default='bw_summarization.json') 45 | parser.add_argument('--thres', type=float, default=0.1) 46 | parser.add_argument('--steps', type=int) 47 | args = parser.parse_args() 48 | 49 | data_file = f'examples/blocksworld/data/split_v2/split_v2_step_{args.steps}_data.json' 50 | path = args.path 51 | data = bw_utils.load_blocksworld(config_file, domain_file, data_file) 52 | 53 | descriptions = [] 54 | for i, d in enumerate(data): 55 | aa = analyze.Analysis.from_file(f"{path}/{i+1}.pkl", Evaluator(d, args.steps, config_file, domain_file)) 56 | 57 | aa.backprop_rewards() 58 | nodes = aa.get_improving_nodes(args.thres) 59 | 60 | for n in nodes: 61 | descriptions.append({ 62 | "goal": d['goal'], 63 | "details": aa.get_node_details(n) 64 | }) 65 | 66 | def get_gpt4_summarizations(descriptions): 67 | summarizations = [] 68 | NL= '\n' 69 | batchsize = 5 70 | for i in trange(0, len(descriptions), batchsize): 71 | batch = descriptions[i:i+batchsize] 72 | prompt = '' 73 | for i, b in enumerate(batch): 74 | prompt += f"Goal {i}: {b['goal']}\n\n" 75 | prompt += f"State {i}: {b['details'][0]}\nActions {i}:\n{NL.join(b['details'][1])}\n\n" 76 | 77 | messages = [ 78 | { 79 | "role": "system", 80 | "content": "BlocksWorld is game that requires the agent to apply a sequence of actions to make configurations of blocks match a goal configuration. Please summarize the following action and corresponding rewards given a state into a policy to achieve higher reward." 81 | }, 82 | { 83 | "role": "user", 84 | "content": prompt + "Note: Since the reward is not given when playing the game, your policy should avoid directly use the reward as information. 
Your policy should be specific to help avoid making the same mistake. Please follow this format: Summarization: ...\nPolicy: ..." 85 | } 86 | ] 87 | 88 | summarizations.append(query_gpt4(messages)) 89 | 90 | if len(summarizations) > 8: 91 | break 92 | 93 | with open('summarizations.json', 'r') as f: 94 | summarizations = json.load(f)[0] 95 | 96 | policies = [ 97 | s.split("Policy:")[1].strip() for s in summarizations 98 | ] 99 | 100 | messages = [{ 101 | "role": "user", 102 | "content": "Please merge the following policies into one:\n" + NL.join(policies) 103 | }] 104 | 105 | merged_policy = query_gpt4(messages) 106 | 107 | return summarizations, merged_policy 108 | 109 | summarizations = get_gpt4_summarizations(descriptions) 110 | 111 | with open('summarizations.json', 'w') as f: 112 | json.dump(summarizations, f, indent=4) 113 | 114 | -------------------------------------------------------------------------------- /rot_scripts/blocksworld_generate_rot_prompt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--path', type=str) 6 | parser.add_argument('--output_name', type=str, default='bw_summarization.json') 7 | parser.add_argument('--steps', type=str, default='4') 8 | 9 | args = parser.parse_args() 10 | 11 | prompts = json.load(open(f'prompts/bw/pool_prompt_v2_step_{args.steps}_template.json')) 12 | rot = json.load(open(f'{args.path}')) 13 | 14 | prompts['intro'] = prompts['intro'].format(template=rot) 15 | prompts['self-eval'] = prompts['self-eval'].format(template=rot) 16 | 17 | with open(args.output_name, 'w') as f: 18 | json.dump(prompts, f, indent=2) -------------------------------------------------------------------------------- /rot_scripts/gpt4_utils.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import time 3 | 4 | def query_gpt4(messages): 5 | for _ in 
range(5): 6 | try: 7 | response = openai.ChatCompletion.create(messages=messages, n=1, model="gpt-4-turbo-1106") 8 | return response['choices'][0]['message']['content'] 9 | except Exception as e: 10 | print(e, response) 11 | time.sleep(15) -------------------------------------------------------------------------------- /rot_scripts/gsm8k_analysis.py: -------------------------------------------------------------------------------- 1 | from gpt4_utils import query_gpt4 2 | 3 | import sys 4 | import os 5 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))) 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, 'examples', 'gsm8k'))) 7 | 8 | import datasets 9 | import json 10 | 11 | from reasoners.visualization import analyze 12 | from reasoners.algorithm.mcts import MCTSNode 13 | from utils import retrieve_answer, retrieve_answer_from_dataset, judge_answer 14 | 15 | class Evaluator: 16 | def __init__(self, data) -> None: 17 | self.data = data 18 | def is_terminal(self, node): 19 | if node.data['question'] is None: 20 | return False 21 | return 'Now we can answer the question:' in node.data['question'] 22 | 23 | def get_score(self, node): 24 | pred = retrieve_answer(node.data['answer']) 25 | gold = retrieve_answer_from_dataset(self.data) 26 | return int(judge_answer(pred, gold)) 27 | 28 | def compose_solution(solution): 29 | text = '' 30 | for i, s in enumerate(solution): 31 | text += (f"sub-question {i}: {s.data['question']}\nsub-answer {i}: {s.data['answer']}\n") 32 | return text 33 | 34 | def compose_summarization_prompt(case): 35 | question = case[0]['question'] 36 | answer = case[0]['answer'].split('####')[1].strip() 37 | correct = compose_solution(case[1][0]) 38 | wrong = compose_solution(case[1][1]) 39 | 40 | return f"Below is a math word problem:\n{question}\nIts answer is: {answer}\nThe following are two solutions.\n\nCorrect solution:\n{correct}\nWrong solution:\n{wrong}\nPlease compare the 
above 2 solutions, briefly analyze how the mistake is made and give a practical policy to avoid this mistake.\nPlease write in the following format: Analysis: ...\nPolicy: ..." 41 | 42 | def gsm_node_data_factory(x: MCTSNode): 43 | if not x.state: 44 | return {"question": x.action, "answer": "Not finished"} 45 | return {"question": x.action, "answer": x.state[-1].sub_answer} 46 | 47 | if __name__ == '__main__': 48 | import argparse 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument('--path', type=str) 51 | parser.add_argument('--output_name', type=str, default='gsm8k_summarization.json') 52 | parser.add_argument('--split', type=str, default='train') 53 | parser.add_argument('--thres', type=float, default=0.1) 54 | parser.add_argument('--method', type=str, default='improving') 55 | 56 | args = parser.parse_args() 57 | 58 | full_dataset = datasets.load_dataset('gsm8k', 'main', split=args.split) 59 | max_idx = max([int(p.replace('.pkl', '')) for p in os.listdir(args.path)]) 60 | 61 | res = [] 62 | for i in range(max_idx): 63 | try: 64 | path = os.path.join(args.path, f'{i+1}.pkl') 65 | 66 | analysis = analyze.Analysis.from_file(path, evaluator=Evaluator(full_dataset[i]), node_data_factory=gsm_node_data_factory) 67 | analysis.backprop_rewards() 68 | 69 | if args.method == 'improving': 70 | nodes = analysis.get_improving_nodes(threshold=args.thres) 71 | else: 72 | nodes = analysis.get_improving_nodes(method=args.method) 73 | for node in nodes: 74 | res.append((full_dataset[i], analysis.get_node_details_gsm8k(node))) 75 | except Exception as e: 76 | pass 77 | 78 | with open('gsm8k_summarization.json', 'w') as f: 79 | json.dump(res, f, indent=4) 80 | 81 | 82 | cases = [x for x in res if x is not None][:30] 83 | 84 | batchsize = 5 85 | analysises = [] 86 | for i in range(0, len(cases), batchsize): 87 | batch = cases[i:i+batchsize] 88 | NL = '\n' 89 | prompt = '' 90 | for i, bb in enumerate(batch): 91 | prompt += f"State {i}:\nQuestion {i}: 
{bb[0]['question']}\nPartial solution {i}: {bb[1][0]}" 92 | prompt += f"Possible Action & Response & Rewards:\n\n{NL.join([d.format(idx=ii) for ii, d in enumerate(bb[1][1])])}" 93 | 94 | messages = [ 95 | { 96 | "role": "system", 97 | "content": "You are a math tutor. Now you are teaching a student who is solving a math word problem. Please summarize the following action, response and corresponding rewards given a state into a policy to achieve higher reward." 98 | }, 99 | { 100 | "role": "user", 101 | "content": "Below are some examples of the process.\n" + prompt + "\n\nPlease give suggestion on how to ask subquestions (take action) and answer the subquestions (response) to achieve higher reward. Please first analyze the mistakes and then summarize a policy. Your policy should be specific to help avoid making the same mistake. Please follow this format: Summarization: ...\nPolicy: ..." 102 | } 103 | ] 104 | 105 | res = query_gpt4(messages) 106 | 107 | analysises.append({ 108 | 'sample': str(bb), 109 | 'analysis': res.split('Summarization:')[1].split('Policy:')[0].strip(), 110 | 'policy': res.split('Policy:')[1].strip(), 111 | }) 112 | 113 | with open('gsm8k_summarization.json', 'w') as f: 114 | json.dump(analysises, f) 115 | 116 | analysises = json.load(open('gsm8k_summarization.json')) 117 | 118 | policies = [a['policy'] for a in analysises] 119 | 120 | messages = [ 121 | { 122 | 'content': f"The following are some policies. Please merge them as a comprehensive policy. 
Note you should keep the details in the merged policy and make sure that the merged policy is not too general.\n" + '\n\n'.join([f"Policy {i}:\n"+ s for i, s in enumerate(policies)]), 123 | 'role': 'user' 124 | } 125 | ] 126 | 127 | res = query_gpt4(messages) 128 | 129 | with open(args.output_name, 'w') as f: 130 | json.dump(res, f) -------------------------------------------------------------------------------- /rot_scripts/gsm8k_generate_rot_prompt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--path', type=str) 6 | parser.add_argument('--output_name', type=str) 7 | parser.add_argument('--mode', type=str, default='cot') 8 | 9 | args = parser.parse_args() 10 | 11 | rot = json.load(open(f'{args.path}')) 12 | 13 | if args.mode == 'cot': 14 | prompts = json.load(open(f'prompts/gsm8k/cot_default.json')) 15 | prompts['guidelines'] = rot 16 | else: 17 | prompts = json.load(open(f'/home/wenyanghui/projects/reflection-on-trees/prompts/gsm8k/prompt_pool_template.json')) 18 | prompts['instruction'] = prompts['instruction'].format(rot=rot) 19 | 20 | with open(args.output_name, 'w') as f: 21 | json.dump(prompts, f, indent=2) -------------------------------------------------------------------------------- /vllm-server/mixtral.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1 python vllm_api.py \ 2 | --model mistralai/Mixtral-8x7B-Instruct-v0.1 \ 3 | --host 0.0.0.0 \ 4 | --served-model-name mixtral \ 5 | --dtype float16 \ 6 | --port 23100 \ 7 | --disable-log-requests \ 8 | --trust-remote-code \ 9 | --tp 2 -------------------------------------------------------------------------------- /vllm-server/phi-2.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2 python vllm_api.py \ 2 | --model microsoft/phi-2 \ 3 | --host 
0.0.0.0 \ 4 | --served-model-name phi-2 \ 5 | --dtype float16 \ 6 | --port 23100 \ 7 | --disable-log-requests \ 8 | --trust-remote-code --------------------------------------------------------------------------------