├── .gitignore ├── Eval ├── game24 │ ├── README.md │ ├── reflection_few_shot_examples.txt │ ├── run.py │ ├── run_mcts.py │ ├── run_re.py │ ├── scripts │ │ ├── ARMAP-B │ │ │ └── rm_best_of_n_llama8b.sh │ │ ├── ARMAP-M │ │ │ └── mcts_rm.sh │ │ └── ARMAP-R │ │ │ └── re_rm_llama8b.sh │ ├── setup.py │ └── src │ │ ├── tot │ │ ├── __init__.py │ │ ├── argparser.py │ │ ├── argparser_bk.py │ │ ├── data │ │ │ ├── 24 │ │ │ │ └── 24.csv │ │ │ ├── crosswords │ │ │ │ ├── mini0505.json │ │ │ │ └── mini0505_0_100_5.json │ │ │ └── text │ │ │ │ └── data_100_random_text.txt │ │ ├── methods │ │ │ ├── bfs.py │ │ │ ├── bfs_re.py │ │ │ ├── data_generation │ │ │ │ ├── data_analysis.py │ │ │ │ └── data_generation_pipeline.py │ │ │ └── mcts_agents │ │ │ │ ├── debug_prompt.py │ │ │ │ ├── env_game24_wrapper.py │ │ │ │ ├── env_game24_wrapper_bk.py │ │ │ │ ├── env_scienceworld_wrapper.py │ │ │ │ ├── fact_rlhf_reward_prompt_wj2.txt │ │ │ │ ├── mcts.py │ │ │ │ ├── not_used.png │ │ │ │ ├── not_used_0.png │ │ │ │ ├── not_used_1.png │ │ │ │ ├── not_used_2.png │ │ │ │ ├── not_used_3.png │ │ │ │ ├── not_used_4.png │ │ │ │ ├── not_used_5.png │ │ │ │ ├── not_used_6.png │ │ │ │ ├── not_used_7.png │ │ │ │ ├── not_used_8.png │ │ │ │ ├── not_used_factual.json │ │ │ │ ├── performance_analysis.py │ │ │ │ ├── performance_analysis_bk.py │ │ │ │ ├── performance_analysis_bk3.py │ │ │ │ ├── performance_analysis_bk_st.py │ │ │ │ ├── performance_analysis_bk_st2.py │ │ │ │ ├── performance_analysis_v2.py │ │ │ │ ├── test_average.py │ │ │ │ ├── uct.py │ │ │ │ ├── utils_mcts.py │ │ │ │ └── utils_mcts_bk.py │ │ ├── models.py │ │ ├── models_new.py │ │ ├── prompts │ │ │ ├── crosswords.py │ │ │ ├── game24.py │ │ │ ├── game24_data.py │ │ │ ├── game24_mcts.py │ │ │ └── text.py │ │ └── tasks │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── crosswords.py │ │ │ ├── game24.py │ │ │ ├── game24_mcts.py │ │ │ └── text.py │ │ └── tree_of_thoughts_llm.egg-info │ │ ├── PKG-INFO │ │ ├── SOURCES.txt │ │ ├── dependency_links.txt │ │ ├── 
requires.txt │ │ └── top_level.txt ├── sciworld │ ├── README.md │ ├── agents │ │ ├── __init__.py │ │ ├── base.py │ │ ├── fastchat_agent.py │ │ ├── local_agent.py │ │ ├── openai_lm_agent.py │ │ ├── openllm_lm_agent.py │ │ └── reflection_few_shot_examples.txt │ ├── configs │ │ ├── model │ │ │ ├── fastchat.json │ │ │ ├── fastchat_explore.json │ │ │ ├── local.json │ │ │ ├── local_0.5.json │ │ │ └── local_base.json │ │ └── task │ │ │ ├── alfworld.json │ │ │ └── sciworld.json │ ├── envs │ │ ├── __init__.py │ │ ├── alfworld_env.py │ │ ├── base.py │ │ └── sciworld_env.py │ ├── main.py │ ├── main_mcts.py │ ├── main_re.py │ ├── mcts_agents │ │ ├── env_scienceworld_wrapper.py │ │ ├── mcts.py │ │ ├── my_random_agent.py │ │ ├── sci_world_sample.py │ │ ├── uct.py │ │ └── utils_mcts.py │ ├── prompt │ │ ├── __init__.py │ │ ├── batch_server_fix_sciworld.json │ │ ├── icl_examples │ │ │ ├── alfworld_icl.json │ │ │ └── sciworld_icl.json │ │ ├── instructions │ │ │ ├── alfworld_inst.txt │ │ │ ├── sciworld_inst.txt │ │ │ ├── sciworld_inst_ibm.txt │ │ │ ├── sciworld_inst_llama.txt │ │ │ └── webshop_inst.txt │ │ ├── scienceworld_reflection.txt │ │ ├── sciworld.py │ │ └── templates.py │ ├── scripts │ │ ├── ARMAP-B │ │ │ └── calc_sample.py │ │ ├── ARMAP-M │ │ │ ├── run_mcts.sh │ │ │ ├── test_sampling.py │ │ │ ├── trans_local_llm.sh │ │ │ └── trans_rm.sh │ │ ├── ARMAP-R │ │ │ └── llama8b_re_seen.sh │ │ ├── Greedy │ │ │ └── llama8b_base_seen.sh │ │ └── Sampling │ │ │ └── llama8b_sample_seen.sh │ ├── tasks │ │ ├── __init__.py │ │ ├── alfworld.py │ │ ├── base.py │ │ └── sciworld.py │ └── utils │ │ ├── __init__.py │ │ ├── datatypes.py │ │ └── replace_sciworld_score.py └── webshop │ ├── README.md │ ├── configs │ ├── agents │ │ ├── api_agents.yaml │ │ ├── fs_agent.yaml │ │ ├── gemini.yaml │ │ ├── new_local.yaml │ │ ├── openai-chat.yaml │ │ └── openai-text.yaml │ ├── assignments │ │ ├── new_llama8b_base.yaml │ │ ├── new_llama8b_ref.yaml │ │ └── new_llama8b_sample.yaml │ ├── start_task.yaml │ └── 
tasks │ │ ├── alfworld.yaml │ │ ├── avalon.yaml │ │ ├── card_game.yaml │ │ ├── dbbench.yaml │ │ ├── kg.yaml │ │ ├── ltp.yaml │ │ ├── mind2web.yaml │ │ ├── os.yaml │ │ ├── task_assembly.yaml │ │ └── webshop.yaml │ ├── data_generation │ ├── data_modules.py │ ├── data_pipeline.py │ ├── llm_models.py │ └── utils.py │ ├── llms │ ├── __init__.py │ ├── lm_config.py │ ├── providers │ │ ├── cogagent_utils.py │ │ ├── gemini_utils.py │ │ ├── hf_utils.py │ │ ├── local_utils.py │ │ ├── openai_azure_utils.py │ │ └── openai_utils.py │ ├── tokenizers.py │ └── utils.py │ ├── prompt │ └── reflection_few_shot_examples.txt │ ├── scripts │ ├── ARMAP-B │ │ └── calc_sample.py │ ├── ARMAP-M │ │ ├── run_mcts.sh │ │ ├── trans_local_llm.sh │ │ └── trans_rm.sh │ ├── ARMAP-R │ │ └── new_llama8b_ref.sh │ ├── Greedy │ │ └── new_llama8b_base.sh │ └── Sampling │ │ └── new_llama8b_sample.sh │ └── src │ ├── __init__.py │ ├── analysis.py │ ├── assigner.py │ ├── client │ ├── __init__.py │ ├── __pycache__bk │ │ ├── __init__.cpython-312.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── agent.cpython-312.pyc │ │ ├── agent.cpython-38.pyc │ │ ├── agent.cpython-39.pyc │ │ ├── task.cpython-38.pyc │ │ └── task.cpython-39.pyc │ ├── agent.py │ ├── agent_test.py │ ├── agents │ │ ├── __init__.py │ │ ├── claude_agent.py │ │ ├── fastchat_client.py │ │ ├── gemini_agent.py │ │ ├── http_agent.py │ │ ├── key_pool.py │ │ ├── new_local_agent.py │ │ └── test_agent.py │ └── task.py │ ├── configs.py │ ├── mcts_agents │ ├── env_webshop_wrapper.py │ ├── mcts.py │ ├── my_random_agent.py │ ├── performance_analysis.py │ ├── uct.py │ └── utils.py │ ├── server │ ├── __init__.py │ ├── task.py │ ├── task_controller.py │ ├── task_worker.py │ └── tasks │ │ ├── __init__.py │ │ ├── alfworld │ │ ├── __init__.py │ │ ├── configs │ │ │ └── base_config.yaml │ │ ├── environment.py │ │ ├── prompts │ │ │ ├── alfworld_multiturn.json │ │ │ ├── alfworld_multiturn_cot.json │ │ │ ├── alfworld_multiturn_new.json │ │ │ ├── 
alfworld_multiturn_plan_first.json │ │ │ ├── alfworld_multiturn_pure.json │ │ │ ├── alfworld_multiturn_react.json │ │ │ └── alfworld_singleturn.json │ │ ├── task.py │ │ └── utils.py │ │ ├── avalon │ │ ├── README.md │ │ ├── __init__.py │ │ ├── agents │ │ │ ├── agent.py │ │ │ ├── baseline_agents.py │ │ │ └── llm_with_discussion.py │ │ ├── avalon_exception.py │ │ ├── engine.py │ │ ├── prompts.py │ │ ├── task.py │ │ ├── task_scoring.py │ │ ├── utils.py │ │ └── wrapper.py │ │ ├── card_game │ │ ├── AI │ │ │ ├── AI_En.py │ │ │ ├── __init__.py │ │ │ ├── action1.py │ │ │ ├── action2.py │ │ │ ├── basline1.py │ │ │ ├── basline2.py │ │ │ ├── client.py │ │ │ ├── main.py │ │ │ ├── prompt │ │ │ │ ├── __init__.py │ │ │ │ ├── cn.py │ │ │ │ ├── cn_en.py │ │ │ │ └── en.py │ │ │ └── sdk │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── ai_client.cpython-310-x86_64-linux-gnu.so │ │ │ │ ├── ai_client.cpython-38-x86_64-linux-gnu.so │ │ │ │ ├── ai_client.hpp │ │ │ │ ├── ai_client.pyi │ │ │ │ ├── build │ │ │ │ ├── CMakeCache.txt │ │ │ │ ├── CMakeFiles │ │ │ │ │ ├── 3.25.0 │ │ │ │ │ │ ├── CMakeCXXCompiler.cmake │ │ │ │ │ │ ├── CMakeDetermineCompilerABI_CXX.bin │ │ │ │ │ │ ├── CMakeSystem.cmake │ │ │ │ │ │ └── CompilerIdCXX │ │ │ │ │ │ │ ├── CMakeCXXCompilerId.cpp │ │ │ │ │ │ │ └── a.out │ │ │ │ │ ├── CMakeDirectoryInformation.cmake │ │ │ │ │ ├── CMakeOutput.log │ │ │ │ │ ├── Makefile.cmake │ │ │ │ │ ├── Makefile2 │ │ │ │ │ ├── TargetDirectories.txt │ │ │ │ │ ├── ai_client.dir │ │ │ │ │ │ ├── DependInfo.cmake │ │ │ │ │ │ ├── build.make │ │ │ │ │ │ ├── cmake_clean.cmake │ │ │ │ │ │ ├── compiler_depend.make │ │ │ │ │ │ ├── compiler_depend.ts │ │ │ │ │ │ ├── depend.make │ │ │ │ │ │ ├── flags.make │ │ │ │ │ │ ├── jsoncpp │ │ │ │ │ │ │ ├── jsoncpp.cpp.o │ │ │ │ │ │ │ └── jsoncpp.cpp.o.d │ │ │ │ │ │ ├── link.txt │ │ │ │ │ │ ├── progress.make │ │ │ │ │ │ ├── py_ai_sdk.cpp.o │ │ │ │ │ │ └── py_ai_sdk.cpp.o.d │ │ │ │ │ ├── cmake.check_cache │ │ │ │ │ └── progress.marks │ │ │ │ ├── Makefile │ │ │ │ ├── bin │ │ │ 
│ │ └── ai_client.cpython-310-x86_64-linux-gnu.so │ │ │ │ └── cmake_install.cmake │ │ │ │ ├── jsoncpp │ │ │ │ ├── json │ │ │ │ │ ├── json-forwards.h │ │ │ │ │ └── json.h │ │ │ │ └── jsoncpp.cpp │ │ │ │ ├── py_ai_sdk.cpp │ │ │ │ ├── py_json_cast.hpp │ │ │ │ └── pyd │ │ │ │ ├── ai_client.cp38-win_amd64.pyd │ │ │ │ ├── ai_client.cpython-310-x86_64-linux-gnu.so │ │ │ │ ├── ai_client.cpython-36m-x86_64-linux-gnu.so │ │ │ │ ├── ai_client.cpython-38-darwin.so │ │ │ │ └── ai_client.cpython-38-x86_64-linux-gnu.so │ │ ├── __init__.py │ │ ├── judger │ │ │ ├── __init__.py │ │ │ ├── cal_metric.py │ │ │ ├── judger.py │ │ │ ├── rserver.py │ │ │ └── run_all.py │ │ ├── logic │ │ │ ├── .gitignore │ │ │ ├── Makefile │ │ │ ├── bin │ │ │ │ └── main │ │ │ └── src │ │ │ │ ├── aqua_sdk.cpp │ │ │ │ ├── fish.cpp │ │ │ │ ├── fish.h │ │ │ │ ├── fishset.cpp │ │ │ │ ├── fishset.h │ │ │ │ ├── game.cpp │ │ │ │ ├── game.h │ │ │ │ ├── jsoncpp │ │ │ │ ├── json │ │ │ │ │ ├── json-forwards.h │ │ │ │ │ └── json.h │ │ │ │ └── jsoncpp.cpp │ │ │ │ ├── main │ │ │ │ ├── main.cpp │ │ │ │ ├── player.cpp │ │ │ │ ├── player.h │ │ │ │ └── timer.h │ │ ├── server.py │ │ ├── task.py │ │ └── utils.py │ │ ├── dbbench │ │ ├── Interaction.py │ │ ├── __init__.py │ │ └── requirements.txt │ │ ├── knowledgegraph │ │ ├── __init__.py │ │ ├── api.py │ │ ├── ontology │ │ │ ├── fb_roles │ │ │ ├── fb_types │ │ │ ├── reverse_properties │ │ │ └── vocab.json │ │ ├── requirements.txt │ │ ├── task.py │ │ └── utils │ │ │ ├── logic_form_util.py │ │ │ ├── semparse_util.py │ │ │ └── sparql_executer.py │ │ ├── ltp │ │ ├── __init__.py │ │ └── task.py │ │ ├── mind2web │ │ ├── __init__.py │ │ ├── data_utils │ │ │ ├── __init__.py │ │ │ └── dom_utils.py │ │ ├── dataloader.py │ │ └── task.py │ │ ├── os_interaction │ │ ├── __init__.py │ │ ├── images.py │ │ ├── requirements.txt │ │ └── task.py │ │ ├── webshop │ │ ├── .github │ │ │ ├── ISSUE_TEMPLATE.md │ │ │ ├── PULL_REQUEST_TEMPLATE.md │ │ │ └── workflows │ │ │ │ └── pytest.yml │ │ ├── 
.gitignore │ │ ├── LICENSE.md │ │ ├── README.md │ │ ├── __init__.py │ │ ├── __init__re.py │ │ ├── assets │ │ │ ├── diagram.gif │ │ │ ├── model_ckpts.png │ │ │ └── transfer-logic.png │ │ ├── baseline_models │ │ │ ├── .gitignore │ │ │ ├── README.md │ │ │ ├── agent.py │ │ │ ├── env.py │ │ │ ├── generate_search.py │ │ │ ├── logger.py │ │ │ ├── models │ │ │ │ ├── bert.py │ │ │ │ ├── modules.py │ │ │ │ └── rnn.py │ │ │ ├── requirements.txt │ │ │ ├── test.py │ │ │ ├── train_choice_il.py │ │ │ ├── train_rl.py │ │ │ └── train_search_il.py │ │ ├── conftest.py │ │ ├── requirements.txt │ │ ├── run_dev.sh │ │ ├── run_envs │ │ │ ├── run_web_agent_site_env.py │ │ │ └── run_web_agent_text_env.py │ │ ├── run_prod.sh │ │ ├── run_web_agent_site_env.sh │ │ ├── run_web_agent_text_env.sh │ │ ├── search_engine │ │ │ ├── convert_product_file_format.py │ │ │ ├── lucene_searcher.py │ │ │ └── run_indexing.sh │ │ ├── setup.sh │ │ ├── tests │ │ │ ├── transfer │ │ │ │ ├── mocks │ │ │ │ │ ├── mock_parse_item_page_amz │ │ │ │ │ ├── mock_parse_item_page_ebay │ │ │ │ │ ├── mock_parse_item_page_ws │ │ │ │ │ ├── mock_parse_item_page_ws_desc │ │ │ │ │ ├── mock_parse_item_page_ws_feat │ │ │ │ │ ├── mock_parse_results_amz │ │ │ │ │ ├── mock_parse_results_ebay │ │ │ │ │ └── mock_parse_results_ws │ │ │ │ └── test_predict_help.py │ │ │ └── web-agent-site │ │ │ │ ├── engine │ │ │ │ ├── test_goal.py │ │ │ │ └── test_normalize.py │ │ │ │ └── test_utils.py │ │ ├── transfer │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── app.py │ │ │ ├── predict_help.py │ │ │ └── webshop_lite.py │ │ └── web_agent_site │ │ │ ├── __init__.py │ │ │ ├── app.py │ │ │ ├── attributes │ │ │ ├── annotate.py │ │ │ └── generate_attrs.py │ │ │ ├── engine │ │ │ ├── __init__.py │ │ │ ├── engine.py │ │ │ ├── goal.py │ │ │ └── normalize.py │ │ │ ├── envs │ │ │ ├── __init__.py │ │ │ ├── web_agent_site_env.py │ │ │ └── web_agent_text_env.py │ │ │ ├── models │ │ │ ├── __init__.py │ │ │ └── models.py │ │ │ ├── static │ │ │ ├── images │ │ │ │ 
└── no-image-available.png │ │ │ └── style.css │ │ │ ├── templates │ │ │ ├── attributes_page.html │ │ │ ├── description_page.html │ │ │ ├── done_page.html │ │ │ ├── features_page.html │ │ │ ├── item_page.html │ │ │ ├── results_page.html │ │ │ ├── review_page.html │ │ │ └── search_page.html │ │ │ └── utils.py │ │ └── webshop_docker │ ├── start_task.py │ ├── typings │ ├── __init__.py │ ├── config.py │ ├── exception.py │ ├── general.py │ ├── output.py │ ├── request.py │ └── status.py │ └── utils │ ├── __init__.py │ ├── max_flow.py │ ├── others.py │ └── rules.py ├── README.md ├── RM ├── README.md ├── data │ └── put_training_data_here ├── data_utils │ ├── common_utils.py │ ├── constants.py │ ├── data_utils_ppo.py │ ├── data_utils_rm.py │ └── data_utils_sft.py ├── finetune_lora_rm.py ├── lora_utils.py ├── models │ ├── distributed_utils.py │ ├── ppo_trainer.py │ ├── qlora_model.py │ ├── reward_model.py │ ├── rl_models.py │ ├── rl_trainer.py │ └── trainer_utils.py ├── modifications │ ├── config.json │ └── multimodal_encoder │ │ ├── builder.py │ │ ├── clip_encoder.py │ │ ├── siglip │ │ ├── __init__.py │ │ ├── configuration_siglip.py │ │ ├── convert_siglip_to_hf.py │ │ ├── image_processing_siglip.py │ │ ├── modeling_siglip.py │ │ ├── processing_siglip.py │ │ └── tokenization_siglip.py │ │ ├── siglip_encoder.py │ │ └── vision_encoder.py ├── prompts │ └── multi_frame_prompt.txt ├── requirements.txt ├── rm │ └── __init__.py ├── scripts │ ├── server_reward_sciworld.sh │ ├── train_reward_alfworld.sh │ ├── train_reward_game24.sh │ ├── train_reward_sciworld.sh │ └── train_reward_webshop.sh └── server_lora_rm.py └── assets ├── armap.gif └── teaser.svg /.gitignore: -------------------------------------------------------------------------------- 1 | info.txt 2 | __pycache__ 3 | TODO.txt 4 | data 5 | rm -------------------------------------------------------------------------------- /Eval/game24/README.md: -------------------------------------------------------------------------------- 
1 | # Game24 Evaluation 2 | ## Install Dependencies 3 | Install dependencies of [Tree of Thoughts](https://github.com/princeton-nlp/tree-of-thought-llm/tree/ab400345c5ea39d28ea6d7d3be0e417b11113c87). 4 | 5 | 6 | ## Inference 7 | 8 | 9 | ### Greedy 10 | 11 | TBD 12 | 13 | ### Sampling 14 | 15 | TBD 16 | 17 | ### ARMAP-R 18 | 19 | ``` 20 | bash scripts/ARMAP-R/re_rm_llama8b.sh 21 | ``` 22 | 23 | ### ARMAP-B 24 | 25 | TBD 26 | 27 | ### ARMAP-M 28 | 29 | TBD -------------------------------------------------------------------------------- /Eval/game24/reflection_few_shot_examples.txt: -------------------------------------------------------------------------------- 1 | You will be given the history of a past experience in which you were placed in an environment and given a task to complete. You were unsuccessful in completing the task. Do not summarize your environment, but rather think about the strategy and path you took to attempt to complete the task. Devise a concise, new plan of action that accounts for your mistake with reference to specific actions that you should have taken. For example, if you tried A and B but forgot C, then devise a plan to achieve C with environment-specific actions. You will need this later when you are solving the same task. Give your plan after "Plan". Here are two examples: 2 | 3 | Task Description: 4 | Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Each step, you are only allowed to choose two of the remaining numbers to obtain a new number. 5 | 6 | Input: 7 1 2 4 7 | Steps: 8 | 9 | 7 - 1 = 6 (left: 2 4 6) 10 | 4 * 6 = 24 (left: 2 24) 11 | 24 + 2 = 26 (left: 26) 12 | 26 - 2 = 24 (left: 24) 13 | Answer: ((7 - 1) * 4) + 2 - 2 = 24 14 | 15 | STATUS: FAIL 16 | 17 | New plan: I realized that I mistakenly used the number 2 twice in my previous attempt, which violates the rule of using each number only once. To correct this, I need to ensure that each number is used exactly once and not reused. 
I will carefully track the remaining numbers after each operation to avoid duplication. I will explore alternative combinations, such as using 4 * ((7 - 1) + 2) = 32, then adjusting operations to reach 24 while adhering to the rules. 18 | 19 | Task Description: 20 | Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Each step, you are only allowed to choose two of the remaining numbers to obtain a new number. 21 | 22 | Input: 2 5 6 8 23 | Steps: 24 | 25 | 5 - 2 = 3 (left: 3 6 8) 26 | 3 * 8 = 24 (left: 6 24) 27 | Answer: (5 - 2) * 3 = 24 28 | 29 | STATUS: FAIL 30 | 31 | New plan: In my previous attempt, I failed because I did not use all the numbers. To correct this, I will ensure that I use all four numbers to reach 24. 32 | -------------------------------------------------------------------------------- /Eval/game24/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | from tot.tasks import get_task 6 | from tot.methods.bfs import solve, naive_solve 7 | from tot.models import gpt_usage 8 | import pdb 9 | from tot.argparser import parse_args 10 | 11 | def set_debugger(): 12 | from IPython.core import ultratb 13 | import sys 14 | sys.excepthook = ultratb.FormattedTB(call_pdb=True) 15 | set_debugger() 16 | 17 | def run(args): 18 | task = get_task(args.task) 19 | logs, cnt_avg, cnt_any = [], 0, 0 20 | if args.naive_run: 21 | file = f'./logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json' 22 | else: 23 | file = f'./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json' 24 | os.makedirs(os.path.dirname(file), exist_ok=True) 25 | 26 | for i in 
range(args.task_start_index, args.task_end_index): 27 | 28 | # solve 29 | if args.naive_run: 30 | ys, info = naive_solve(args, task, i) 31 | else: 32 | ys, info = solve(args, task, i, to_print=True) 33 | 34 | # log 35 | infos = [task.test_output(i, y) for y in ys] 36 | info.update({'idx': i, 'ys': ys, 'infos': infos, 'usage_so_far': gpt_usage(args.backend)}) 37 | logs.append(info) 38 | with open(file, 'w') as f: 39 | json.dump(logs, f, indent=4) 40 | 41 | # log main metric 42 | accs = [info['r'] for info in infos] 43 | cnt_avg += sum(accs) / len(accs) 44 | cnt_any += any(accs) 45 | print(i, 'sum(accs)', sum(accs), 'cnt_avg', cnt_avg, 'cnt_any', cnt_any, '\n') 46 | 47 | n = args.task_end_index - args.task_start_index 48 | print(cnt_avg / n, cnt_any / n) 49 | print('usage_so_far', gpt_usage(args.backend)) 50 | 51 | 52 | 53 | 54 | if __name__ == '__main__': 55 | args = parse_args() 56 | print(args) 57 | run(args) 58 | -------------------------------------------------------------------------------- /Eval/game24/run_mcts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | from tot.tasks import get_task 6 | from tot.methods.bfs import solve, naive_solve 7 | from tot.models import gpt_usage 8 | from tot.methods.mcts_agents.env_game24_wrapper import mcts_solve 9 | import pdb 10 | from tot.argparser import parse_args 11 | 12 | def run(args): 13 | task = get_task(args.task) 14 | logs, cnt_avg, cnt_any = [], 0, 0 15 | if args.naive_run: 16 | file = f'./logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json' 17 | else: 18 | file = 
f'./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json' 19 | os.makedirs(os.path.dirname(file), exist_ok=True) 20 | 21 | for i in range(args.task_start_index, args.task_end_index): 22 | # solve 23 | mcts_solve(args, task, i) 24 | 25 | if __name__ == '__main__': 26 | args = parse_args() 27 | print(args) 28 | run(args) 29 | -------------------------------------------------------------------------------- /Eval/game24/scripts/ARMAP-B/rm_best_of_n_llama8b.sh: -------------------------------------------------------------------------------- 1 | LM_API="" 2 | MD_ID="meta-llama/Meta-Llama-3.1-8B-Instruct" 3 | echo ${LM_API} 4 | python run.py \ 5 | --task game24 \ 6 | --task_start_index 900 \ 7 | --task_end_index 1000 \ 8 | --naive_run \ 9 | --prompt_sample cot \ 10 | --n_generate_sample 100 \ 11 | --temperature 0.7 \ 12 | --llm_api ${LM_API} \ 13 | --backend ${MD_ID} \ 14 | -------------------------------------------------------------------------------- /Eval/game24/scripts/ARMAP-M/mcts_rm.sh: -------------------------------------------------------------------------------- 1 | LM_API="" 2 | RM_API="" 3 | echo ${LM_API} 4 | echo ${RM_API} 5 | python run_mcts.py \ 6 | --task game24MCTS \ 7 | --task_start_index 900 \ 8 | --task_end_index 1000 \ 9 | --method_generate propose \ 10 | --method_evaluate value \ 11 | --method_select greedy \ 12 | --n_evaluate_sample 3 \ 13 | --n_select_sample 5 \ 14 | --mcts_run \ 15 | --prompt_sample cot \ 16 | --output_dir output/llm380b4bit_rm07_debug_v3 \ 17 | --rollouts 100 \ 18 | --llm_api ${LM_API} \ 19 | --reward_func ${RM_API} \ 20 | --temperature 0.7 \ 21 | --max_token 999 \ 22 | ${@} 23 | 24 | #--method_generate sample \ 25 | #--task_start_index 900 \ 26 | -------------------------------------------------------------------------------- 
/Eval/game24/scripts/ARMAP-R/re_rm_llama8b.sh: -------------------------------------------------------------------------------- 1 | #LM_API="http://9.33.168.174:7778/v1" 2 | LM_API="http://localhost:17777/v1" 3 | MD_ID="meta-llama/Meta-Llama-3.1-8B-Instruct" 4 | echo ${LM_API} 5 | ulimit -n 10000 6 | python run_re.py \ 7 | --task game24 \ 8 | --task_start_index 900 \ 9 | --task_end_index 1000 \ 10 | --naive_run \ 11 | --prompt_sample cot \ 12 | --n_generate_sample 1 \ 13 | --temperature 0.0 \ 14 | --llm_api ${LM_API} \ 15 | --backend ${MD_ID} \ 16 | --reward http://localhost:15601/api/generate \ 17 | --re_iterations 10 \ 18 | -------------------------------------------------------------------------------- /Eval/game24/setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open('README.md', 'r', encoding='utf-8') as fh: 4 | long_description = fh.read() 5 | 6 | 7 | setuptools.setup( 8 | name='tree-of-thoughts-llm', 9 | author='Shunyu Yao', 10 | author_email='shunyuyao.cs@gmail.com', 11 | description='Official Implementation of "Tree of Thoughts: Deliberate Problem Solving with Large Language Models"', 12 | keywords='tree-search, large-language-models, llm, prompting, tree-of-thoughts', 13 | long_description=long_description, 14 | long_description_content_type='text/markdown', 15 | url='https://github.com/princeton-nlp/tree-of-thought-llm', 16 | project_urls={ 17 | 'Homepage': 'https://github.com/princeton-nlp/tree-of-thought-llm', 18 | }, 19 | package_dir={'': 'src'}, 20 | packages=setuptools.find_packages(where='src'), 21 | classifiers=[ 22 | "License :: OSI Approved :: MIT License", 23 | "Programming Language :: Python :: 3", 24 | "Programming Language :: Python :: 3.7", 25 | "Programming Language :: Python :: 3.8", 26 | "Programming Language :: Python :: 3.9", 27 | "Programming Language :: Python :: 3.10", 28 | "Programming Language :: Python :: 3.11", 29 | 'Intended Audience :: 
Science/Research', 30 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 31 | ], 32 | python_requires='>=3.7', 33 | install_requires=[ 34 | 'setuptools', 35 | ], 36 | include_package_data=True, 37 | ) 38 | -------------------------------------------------------------------------------- /Eval/game24/src/tot/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0" -------------------------------------------------------------------------------- /Eval/game24/src/tot/argparser_bk.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pdb 3 | 4 | def parse_args(): 5 | # task runing 6 | args = argparse.ArgumentParser() 7 | args.add_argument('--backend', type=str, default="hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4") 8 | args.add_argument('--temperature', type=float, default=0.7) 9 | 10 | args.add_argument('--task', type=str, default='game24', choices=['game24', 'text', 'crosswords']) 11 | args.add_argument('--task_start_index', type=int, default=900) 12 | args.add_argument('--task_end_index', type=int, default=1000) 13 | 14 | args.add_argument('--naive_run', action='store_true') 15 | args.add_argument('--prompt_sample', type=str, choices=['standard', 'cot']) # only used when method_generate = sample, or naive_run 16 | 17 | args.add_argument('--method_generate', type=str, choices=['sample', 'propose']) 18 | args.add_argument('--method_evaluate', type=str, choices=['value', 'vote']) 19 | args.add_argument('--method_select', type=str, choices=['sample', 'greedy'], default='greedy') 20 | args.add_argument('--n_generate_sample', type=int, default=1) # only thing needed if naive_run 21 | args.add_argument('--n_evaluate_sample', type=int, default=1) 22 | args.add_argument('--n_select_sample', type=int, default=1) 23 | args.add_argument('--llm_api', type=str) 24 | # data generation 25 | args.add_argument('--stage', type=str, 26 | 
choices=['stage1', 'stage2', 'stage3'], default='stage2') 27 | args.add_argument('--output_dir', type=str, 28 | default="output/data_generation/stage1_v1") 29 | args.add_argument('--start_idx', type=int, default=0) 30 | args.add_argument('--end_idx', type=int, default=1000) 31 | args.add_argument('--batch_size', type=int, default=10) 32 | args.add_argument('--stage2_output_dir', type=str, 33 | default="output/data_generation/stage2_v1") 34 | args.add_argument('--stage3_output_dir', type=str, 35 | default="output/data_generation/stage3_v1") 36 | args.add_argument('--max_tokens', type=int, default=100) 37 | args.add_argument('--stage3_previous_output_dir', type=str, 38 | default="output/data_generation/stage3_v1") 39 | args = args.parse_args() 40 | return args 41 | -------------------------------------------------------------------------------- /Eval/game24/src/tot/methods/mcts_agents/debug_prompt.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tot.prompts.game24_mcts import * 3 | from tot.models_new import gpt 4 | import pdb 5 | prompt_llama = '''<|begin_of_text|><|start_header_id|>system<|end_header_id|> 6 | 7 | You are a helpful assistant to complete steps for the Game24.<|eot_id|><|start_header_id|>user<|end_header_id|> 8 | 9 | {USER_INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|> 10 | ''' 11 | 12 | def debug_prompt(): 13 | x = '3 4 4 13' 14 | y = '13 + 3 = 16 (left: 4 4 16)\n4 + 4 = 8 (left: 8 16)\n8 + 16 = 24 (left: 24)\n' 15 | prompt = cot_prompt.format(input=x) + 'Steps:\n' + y + "Answer: " 16 | #prompt = prompt_llama.format(USER_INPUT=prompt) 17 | raw_output = gpt(prompt, n=1, stop=None) 18 | proposals = raw_output[0].split('\n') 19 | print(prompt) 20 | #print(proposals) 21 | print(raw_output[0]) 22 | pdb.set_trace() 23 | 24 | def debug_prompt2(): 25 | x = '3 4 4 13' 26 | #x = '4 5 6 10' 27 | #y = '4 + 5 = 9 (left: 6 9 10)\n' 28 | #y = '4 + 5 = 9 (left: 6 9 10)\n6 + 9 = 15 (left: 10 
15)\n' 29 | #x = '3 4 4 13' 30 | #y = '13 + 3 = 16 (left: 4 4 16)\n4 + 4 = 8 (left: 8 16)\n8 + 16 = 24 (left: 24)\n' 31 | #y = '13 + 3 = 16 (left: 4 4 16)\n4 + 4 = 8 (left: 8 16)\n' 32 | y = '13 + 3 = 16 (left: 4 4 16)\n' 33 | #prompt = prompt_llama.format(USER_INPUT=prompt) 34 | prompt = cot_prompt_rollout_v3.format(input=x) + 'Steps:\n' + y 35 | #prompt = cot_prompt_rollout.format(input=x) + 'Steps:\n' + y + "Answer: " 36 | #prompt = prompt_llama.format(USER_INPUT=prompt) 37 | raw_output = gpt(prompt, n=1, stop=None, temperature=0.0) 38 | proposals = raw_output[0].split('\n') 39 | print(prompt) 40 | #print(proposals) 41 | print(raw_output[0]) 42 | pdb.set_trace() 43 | 44 | 45 | if __name__=="__main__": 46 | #debug_prompt() 47 | debug_prompt2() -------------------------------------------------------------------------------- /Eval/game24/src/tot/methods/mcts_agents/fact_rlhf_reward_prompt_wj2.txt: -------------------------------------------------------------------------------- 1 | USER: Please evaluate whether your last response achieves the "Navigation Intent" or not. 2 | 3 | Here's the information you'll have: 4 | The current web page screenshot: This is a screenshot of the webpage, with each interactable element assigned a unique numerical id. Each bounding box and its respective id shares the same color. 5 | The observation, which lists the IDs of all interactable elements on the current web page with their text content if any, in the format [id] [tagType] [text content]. tagType is the type of the element, such as button, link, or textbox. text content is the text content of the element. 
6 | {factual_prompt} 7 | 8 | ASSISTANT: Following your definitions, the score of my last response is -------------------------------------------------------------------------------- /Eval/game24/src/tot/methods/mcts_agents/not_used.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/game24/src/tot/methods/mcts_agents/not_used.png -------------------------------------------------------------------------------- /Eval/game24/src/tot/methods/mcts_agents/not_used_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/game24/src/tot/methods/mcts_agents/not_used_0.png -------------------------------------------------------------------------------- /Eval/game24/src/tot/methods/mcts_agents/not_used_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/game24/src/tot/methods/mcts_agents/not_used_1.png -------------------------------------------------------------------------------- /Eval/game24/src/tot/methods/mcts_agents/not_used_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/game24/src/tot/methods/mcts_agents/not_used_2.png -------------------------------------------------------------------------------- /Eval/game24/src/tot/methods/mcts_agents/not_used_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/game24/src/tot/methods/mcts_agents/not_used_3.png -------------------------------------------------------------------------------- 
/Eval/game24/src/tot/methods/mcts_agents/not_used_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/game24/src/tot/methods/mcts_agents/not_used_4.png -------------------------------------------------------------------------------- /Eval/game24/src/tot/methods/mcts_agents/not_used_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/game24/src/tot/methods/mcts_agents/not_used_5.png -------------------------------------------------------------------------------- /Eval/game24/src/tot/methods/mcts_agents/not_used_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/game24/src/tot/methods/mcts_agents/not_used_6.png -------------------------------------------------------------------------------- /Eval/game24/src/tot/methods/mcts_agents/not_used_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/game24/src/tot/methods/mcts_agents/not_used_7.png -------------------------------------------------------------------------------- /Eval/game24/src/tot/methods/mcts_agents/not_used_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/game24/src/tot/methods/mcts_agents/not_used_8.png -------------------------------------------------------------------------------- /Eval/game24/src/tot/methods/mcts_agents/not_used_factual.json: -------------------------------------------------------------------------------- 1 | { 2 | "not_used.png" : [ 3 | 
"not_used" 4 | ] 5 | } -------------------------------------------------------------------------------- /Eval/game24/src/tot/methods/mcts_agents/performance_analysis_bk.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pdb 4 | 5 | 6 | def analyze_performance(): 7 | result_dir = "output/llm380b4bit_GT_Debug" 8 | fn_list = os.listdir(result_dir) 9 | best_reward_list = [] 10 | for fn in fn_list: 11 | full_fn = os.path.join(result_dir, fn) 12 | with open(full_fn) as fh: 13 | fdict = json.load(fh) 14 | max_re = max(fdict["rewards"]) 15 | best_reward_list.append(max_re) 16 | if max_re !=1 or len(fdict["rewards"])==1: 17 | continue 18 | visualize_example(fdict) 19 | best_reward = sum(best_reward_list) / len(best_reward_list) 20 | print("Evaluating number: %d\n"%(len(best_reward_list))) 21 | print("Best rewards: %f\n"%(best_reward)) 22 | pdb.set_trace() 23 | 24 | def visualize_example(fdict): 25 | print(fdict['trajectories'][-1][2][-1]) 26 | pdb.set_trace() 27 | 28 | def analyze_performance(): 29 | 30 | if __name__ == "__main__": 31 | analyze_performance() 32 | -------------------------------------------------------------------------------- /Eval/game24/src/tot/methods/mcts_agents/test_average.py: -------------------------------------------------------------------------------- 1 | Average = (53.9 + 50.6 + 52.0 + 9.6) / 4 2 | Greedy = (57.2 + 55.1 + 50.4 + 6.0) / 4 3 | Sampling = (57.3 + 57.0 + 61.3 + 19.0) / 4 4 | MCTS = (58.2 + 55.9 + 66.8 + 24.0) / 4 5 | print("llama70b") 6 | print(Average) 7 | print(Greedy) 8 | print(Sampling) 9 | print(MCTS) 10 | 11 | Average = (24.5 + 20.6 + 56.4 + 2.0) / 4 12 | Greedy = (30.8 + 23.8 + 57.7 + 2.0) / 4 13 | Sampling = (35.7 + 28.1 + 59.3 + 11.0) / 4 14 | MCTS = (31.2 + 24.9 + 59.1 + 9.0) / 4 15 | print("llama8b") 16 | print(Average) 17 | print(Greedy) 18 | print(Sampling) 19 | print(MCTS) 20 | 21 | Average = (18.4 + 17.1 + 52.0 + 1.0) / 4 22 | Greedy = (21.1 + 
19.6 + 50.4 + 1.0) / 4 23 | Sampling = (24.5 + 21.2 + 61.2 + 2.0) / 4 24 | MCTS = (30.0 + 23.4 + 66.8 + 4.0 ) / 4 25 | print("mixtral") 26 | print(Average) 27 | print(Greedy) 28 | print(Sampling) 29 | print(MCTS) 30 | 31 | Average = (10.0 + 7.6 + 34.7 + 2.0) / 4 32 | Greedy = (9.6 + 6.5 + 42.4 + 2.1) / 4 33 | Sampling = (20.0 + 17.0 + 51.1 + 9.0) / 4 34 | MCTS = (27.0 + 26.2 + 53.7 + 10.0) / 4 35 | print("phi") 36 | print(Average) 37 | print(Greedy) 38 | print(Sampling) 39 | print(MCTS) -------------------------------------------------------------------------------- /Eval/game24/src/tot/prompts/text.py: -------------------------------------------------------------------------------- 1 | standard_prompt = ''' 2 | Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input} 3 | ''' 4 | 5 | cot_prompt = ''' 6 | Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input} 7 | 8 | Make a plan then write. Your output should be of the following format: 9 | 10 | Plan: 11 | Your plan here. 12 | 13 | Passage: 14 | Your passage here. 15 | ''' 16 | 17 | 18 | vote_prompt = '''Given an instruction and several choices, decide which choice is most promising. Analyze each choice in detail, then conclude in the last line "The best choice is {s}", where s the integer id of the choice. 19 | ''' 20 | 21 | compare_prompt = '''Briefly analyze the coherency of the following two passages. Conclude in the last line "The more coherent passage is 1", "The more coherent passage is 2", or "The two passages are similarly coherent". 22 | ''' 23 | 24 | score_prompt = '''Analyze the following passage, then at the last line conclude "Thus the coherency score is {s}", where s is an integer from 1 to 10. 
25 | ''' -------------------------------------------------------------------------------- /Eval/game24/src/tot/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | def get_task(name): 2 | if name == 'game24': 3 | from tot.tasks.game24 import Game24Task 4 | return Game24Task() 5 | elif name == 'game24MCTS': 6 | from tot.tasks.game24_mcts import Game24Task 7 | return Game24Task() 8 | elif name == 'text': 9 | from tot.tasks.text import TextTask 10 | return TextTask() 11 | elif name == 'crosswords': 12 | from tot.tasks.crosswords import MiniCrosswordsTask 13 | return MiniCrosswordsTask() 14 | else: 15 | raise NotImplementedError -------------------------------------------------------------------------------- /Eval/game24/src/tot/tasks/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | DATA_PATH = os.path.join(os.path.dirname(__file__), '..', 'data') 3 | 4 | class Task: 5 | def __init__(self): 6 | pass 7 | 8 | def __len__(self) -> int: 9 | pass 10 | 11 | def get_input(self, idx: int) -> str: 12 | pass 13 | 14 | def test_output(self, idx: int, output: str): 15 | pass -------------------------------------------------------------------------------- /Eval/game24/src/tree_of_thoughts_llm.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | LICENSE 2 | MANIFEST.in 3 | README.md 4 | pyproject.toml 5 | requirements.txt 6 | setup.py 7 | src/tot/__init__.py 8 | src/tot/models.py 9 | src/tot/data/24/24.csv 10 | src/tot/data/crosswords/mini0505.json 11 | src/tot/data/crosswords/mini0505_0_100_5.json 12 | src/tot/data/text/data_100_random_text.txt 13 | src/tot/methods/bfs.py 14 | src/tot/prompts/crosswords.py 15 | src/tot/prompts/game24.py 16 | src/tot/prompts/text.py 17 | src/tot/tasks/__init__.py 18 | src/tot/tasks/base.py 19 | src/tot/tasks/crosswords.py 20 | src/tot/tasks/game24.py 21 | src/tot/tasks/text.py 22 | 
src/tree_of_thoughts_llm.egg-info/PKG-INFO 23 | src/tree_of_thoughts_llm.egg-info/SOURCES.txt 24 | src/tree_of_thoughts_llm.egg-info/dependency_links.txt 25 | src/tree_of_thoughts_llm.egg-info/requires.txt 26 | src/tree_of_thoughts_llm.egg-info/top_level.txt -------------------------------------------------------------------------------- /Eval/game24/src/tree_of_thoughts_llm.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Eval/game24/src/tree_of_thoughts_llm.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.8.4 2 | aiosignal==1.3.1 3 | async-timeout==4.0.2 4 | attrs==23.1.0 5 | backoff==2.2.1 6 | certifi==2023.5.7 7 | charset-normalizer==3.1.0 8 | frozenlist==1.3.3 9 | idna==3.4 10 | mpmath==1.3.0 11 | multidict==6.0.4 12 | numpy==1.24.3 13 | openai==0.27.7 14 | requests==2.31.0 15 | sympy==1.12 16 | tqdm==4.65.0 17 | urllib3==2.0.2 18 | yarl==1.9.2 19 | pandas==2.0.3 20 | -------------------------------------------------------------------------------- /Eval/game24/src/tree_of_thoughts_llm.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | tot 2 | -------------------------------------------------------------------------------- /Eval/sciworld/README.md: -------------------------------------------------------------------------------- 1 | # Sciworld Evaluation 2 | ## Install Dependencies 3 | Install dependencies of [Exploration-Based Trajectory Optimization for LLM Agents](https://github.com/Yifan-Song793/ETO/tree/a2fc5da38f8d00cfaf3f9b6370d586eebaf72904). 
4 | 5 | 6 | ## Inference 7 | 8 | 9 | ### Greedy 10 | 11 | ``` 12 | bash scripts/Greedy/llama8b_base_seen.sh 13 | ``` 14 | 15 | ### Sampling 16 | 17 | ``` 18 | bash scripts/Sampling/llama8b_sample_seen.sh 19 | ``` 20 | 21 | ### ARMAP-R 22 | 23 | ``` 24 | bash scripts/ARMAP-R/llama8b_re_seen.sh 25 | ``` 26 | 27 | ### ARMAP-B 28 | 29 | First, run the process to obtain the Sampling results. 30 | ``` 31 | bash scripts/Sampling/llama8b_sample_seen.sh 32 | ``` 33 | 34 | Then, use the reward model to obtain the best-of-n results. 35 | ``` 36 | python scripts/ARMAP-B/calc_sample.py outputs/llama8b_sample_seen 12345 37 | ``` 38 | 39 | ### ARMAP-M 40 | 41 | 42 | #### port transfer 43 | ``` 44 | bash scripts/ARMAP-M/trans_local_llm.sh 45 | bash scripts/ARMAP-M/trans_rm.sh 46 | ``` 47 | 48 | #### mcts 49 | ``` 50 | bash scripts/ARMAP-M/run_mcts.sh 51 | ``` 52 | #### performance analysis 53 | ``` 54 | python scripts/ARMAP-M/test_sampling.py 55 | ``` 56 | -------------------------------------------------------------------------------- /Eval/sciworld/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import LMAgent 2 | from .openai_lm_agent import OpenAILMAgent 3 | from .fastchat_agent import FastChatAgent 4 | from .openllm_lm_agent import OPENLLMLMAgent 5 | from .local_agent import LocalAgent 6 | -------------------------------------------------------------------------------- /Eval/sciworld/agents/base.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Dict, Any, Mapping 3 | 4 | logger = logging.getLogger("agent_frame") 5 | 6 | 7 | class LMAgent: 8 | """Base class for an agent.""" 9 | 10 | def __init__(self, config: Mapping[str, Any]): 11 | self.config = config 12 | logger.debug(f"Initialized {self.__class__.__name__} with config: {config}") 13 | # The agent should not generate observations or expert feedback 14 | self.stop_words = 
["\nObservation:", "\nTask:", "\n---"] 15 | 16 | def __call__(self) -> str: 17 | pass 18 | 19 | def add_system_message( 20 | self, messages: List[Dict[str, str]] 21 | ) -> List[Dict[str, str]]: 22 | # Prepend the prompt with the system message 23 | first_msg = messages[0] 24 | assert first_msg["role"] == "user" 25 | system, examples, task = first_msg["content"].split("\n---\n") 26 | messages = [ 27 | {"role": "system", "content": system}, 28 | {"role": "user", "content": examples + "\n---\n" + task}, 29 | ] + messages[1:] 30 | return messages 31 | -------------------------------------------------------------------------------- /Eval/sciworld/agents/openai_lm_agent.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import logging 3 | import backoff 4 | 5 | from .base import LMAgent 6 | 7 | logger = logging.getLogger("agent_frame") 8 | 9 | 10 | class OpenAILMAgent(LMAgent): 11 | def __init__(self, config): 12 | super().__init__(config) 13 | assert "model_name" in config.keys() 14 | if "api_base" in config: 15 | openai.api_base = config['api_base'] 16 | if "api_key" in config: 17 | openai.api_key = config['api_key'] 18 | 19 | def __call__(self, messages) -> str: 20 | # Prepend the prompt with the system message 21 | response = openai.ChatCompletion.create( 22 | model=self.config["model_name"], 23 | messages=messages, 24 | max_tokens=self.config.get("max_tokens", 512), 25 | temperature=self.config.get("temperature", 0), 26 | stop=self.stop_words, 27 | top_p=self.config.get("top_p", 1), 28 | ) 29 | return response.choices[0].message["content"] 30 | -------------------------------------------------------------------------------- /Eval/sciworld/configs/model/fastchat.json: -------------------------------------------------------------------------------- 1 | { 2 | "agent_class": "FastChatAgent", 3 | "config": { 4 | "controller_address": "http://localhost:21001", 5 | "model_name": "vicuna-7b-v1.5", 6 | "max_new_tokens": 
512, 7 | "temperature": 0.0 8 | } 9 | } -------------------------------------------------------------------------------- /Eval/sciworld/configs/model/fastchat_explore.json: -------------------------------------------------------------------------------- 1 | { 2 | "agent_class": "FastChatAgent", 3 | "config": { 4 | "controller_address": "http://localhost:21001", 5 | "model_name": "vicuna-7b-v1.5", 6 | "max_new_tokens": 512, 7 | "temperature": 0.7 8 | } 9 | } -------------------------------------------------------------------------------- /Eval/sciworld/configs/model/local.json: -------------------------------------------------------------------------------- 1 | { 2 | "agent_class": "LocalAgent", 3 | "config": { 4 | "model_name": "llama-3-70b-instruct", 5 | "temperature": 1, 6 | "top_p": 1, 7 | "top_k": 40, 8 | "max_new_tokens": 150, 9 | "min_new_tokens": 20, 10 | "data_gen": false 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /Eval/sciworld/configs/model/local_0.5.json: -------------------------------------------------------------------------------- 1 | { 2 | "agent_class": "LocalAgent", 3 | "config": { 4 | "model_name": "llama-3-70b-instruct", 5 | "temperature": 0.5, 6 | "top_p": 1, 7 | "top_k": 40, 8 | "max_new_tokens": 150, 9 | "min_new_tokens": 20, 10 | "data_gen": false 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /Eval/sciworld/configs/model/local_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "agent_class": "LocalAgent", 3 | "config": { 4 | "model_name": "llama-3-70b-instruct", 5 | "temperature": 0, 6 | "top_p": 1, 7 | "top_k": 40, 8 | "max_new_tokens": 150, 9 | "min_new_tokens": 20, 10 | "data_gen": false 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /Eval/sciworld/configs/task/alfworld.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "task": { 3 | "task_class": "AlfWorldTask", 4 | "filepath": "eval_agent/data" 5 | }, 6 | "env_config": { 7 | "env_class": "AlfWorldEnv", 8 | "instruction_path": "eval_agent/prompt/instructions/alfworld_inst.txt", 9 | "icl_path": "eval_agent/prompt/icl_examples/alfworld_icl.json", 10 | "icl_format": "conversation", 11 | "max_steps": 40 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /Eval/sciworld/configs/task/sciworld.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": { 3 | "task_class": "SciWorldTask" 4 | }, 5 | "env_config": { 6 | "env_class": "SciWorldEnv", 7 | "instruction_path": "prompt/instructions/sciworld_inst_ibm.txt", 8 | "env_jar_path": "/nobackup/users/zfchen/cdl/ScienceWorld_Planning/envs/scienceworld/scienceworld.jar", 9 | "icl_path": "prompt/icl_examples/sciworld_icl.json", 10 | "icl_format": "conversation", 11 | "max_steps": 10 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /Eval/sciworld/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseEnv 2 | from .sciworld_env import SciWorldEnv 3 | # from .alfworld_env import AlfWorldEnv -------------------------------------------------------------------------------- /Eval/sciworld/mcts_agents/my_random_agent.py: -------------------------------------------------------------------------------- 1 | """ 2 | A Random Agent given as an example 3 | """ 4 | 5 | from gym import spaces 6 | import dyna_gym.utils.utils as utils 7 | 8 | class MyRandomAgent(object): 9 | def __init__(self, action_space): 10 | self.action_space = action_space 11 | 12 | def reset(self, p=None): 13 | """ 14 | Reset the attributes. 15 | Expect to receive them in the same order as init. 
16 | p : list of parameters 17 | """ 18 | if p is not None: 19 | utils.assert_types(p,[spaces.discrete.Discrete]) 20 | self.__init__(p[0]) 21 | 22 | def display(self): 23 | """ 24 | Display infos about the attributes. 25 | """ 26 | print('Displaying Random agent:') 27 | print('Action space :', self.action_space) 28 | 29 | def act(self, observation=None, reward=None, done=None): 30 | return self.action_space.sample() 31 | -------------------------------------------------------------------------------- /Eval/sciworld/prompt/__init__.py: -------------------------------------------------------------------------------- 1 | from .templates import prompt_with_icl -------------------------------------------------------------------------------- /Eval/sciworld/prompt/icl_examples/alfworld_icl.json: -------------------------------------------------------------------------------- 1 | [ 2 | [ 3 | { 4 | "role": "user", 5 | "content": "You are in the middle of a room. Looking quickly around you, you see a bathtubbasin 1, a cabinet 2, a cabinet 1, a countertop 1, a garbagecan 1, a handtowelholder 1, a sinkbasin 1, a toilet 1, a toiletpaperhanger 1, and a towelholder 1.\n\nYour task is to: put a toiletpaper in toiletpaperhanger." 6 | }, 7 | { 8 | "from": "gpt", 9 | "content": "Thought: I need to place a toilet paper roll on the toilet paper hanger. First, I should check if there's already a roll on the hanger or if it's empty.\nAction: go to toiletpaperhanger 1" 10 | }, 11 | { 12 | "role": "user", 13 | "content": "Observation: On the toiletpaperhanger 1, you see nothing." 14 | }, 15 | { 16 | "from": "gpt", 17 | "content": "Thought: The toilet paper hanger is empty, so I need to find a new roll of toilet paper to put on it.\nAction: go to toilet 1" 18 | }, 19 | { 20 | "role": "user", 21 | "content": "Observation: On the toilet 1, you see a soapbottle 1, and a toiletpaper 1." 22 | }, 23 | { 24 | "from": "gpt", 25 | "content": "Thought: I've found a toilet paper roll on the toilet. 
I should pick it up so I can place it on the hanger.\nAction: take toiletpaper 1 from toilet 1" 26 | }, 27 | { 28 | "role": "user", 29 | "content": "Observation: You pick up the toiletpaper 1 from the toilet 1." 30 | }, 31 | { 32 | "from": "gpt", 33 | "content": "Thought: Now that I have the toilet paper roll, I need to return to the toilet paper hanger to install it.\nAction: go to toiletpaperhanger 1" 34 | }, 35 | { 36 | "role": "user", 37 | "content": "Observation: On the toiletpaperhanger 1, you see nothing." 38 | }, 39 | { 40 | "from": "gpt", 41 | "content": "Thought: I'm back at the toilet paper hanger with the roll in hand. It's time to put the toilet paper on the hanger.\nAction: put toiletpaper 1 in/on toiletpaperhanger 1" 42 | } 43 | ] 44 | ] -------------------------------------------------------------------------------- /Eval/sciworld/prompt/instructions/alfworld_inst.txt: -------------------------------------------------------------------------------- 1 | Interact with a household to solve a task. Imagine you are an intelligent agent in a household environment and your target is to perform actions to complete the task goal. At the beginning of your interactions, you will be given the detailed description of the current environment and your goal to accomplish. 2 | For each of your turn, you will be given the observation of the last turn. You should first think about the current condition and plan for your future actions, and then output your action in this turn. Your output must strictly follow this format:"Thought: your thoughts.\nAction: your next action". 3 | 4 | The available actions are: 5 | 1. go to {recep} 6 | 2. task {obj} from {recep} 7 | 3. put {obj} in/on {recep} 8 | 4. open {recep} 9 | 5. close {recep} 10 | 6. toggle {obj} {recep} 11 | 7. clean {obj} with {recep} 12 | 8. heat {obj} with {recep} 13 | 9. cool {obj} with {recep} 14 | where {obj} and {recep} correspond to objects and receptacles. 
15 | After your each turn, the environment will give you immediate feedback based on which you plan your next few steps. if the envrionment output "Nothing happened", that means the previous action is invalid and you should try more options. 16 | 17 | Your response should use the following format: 18 | 19 | Thought: 20 | Action: -------------------------------------------------------------------------------- /Eval/sciworld/prompt/instructions/sciworld_inst.txt: -------------------------------------------------------------------------------- 1 | You are a helpful assistant to do some scientific experiment in an environment. 2 | In the environment, there are several rooms: kitchen, foundry, workshop, bathroom, outside, living room, bedroom, greenhouse, art studio, hallway 3 | You should explore the environment and find the items you need to complete the experiment. 4 | You can teleport to any room in one step. 5 | All containers in the environment have already been opened, you can directly get items from the containers. 
6 | 7 | The available actions are: 8 | open OBJ: open a container 9 | close OBJ: close a container 10 | activate OBJ: activate a device 11 | deactivate OBJ: deactivate a device 12 | connect OBJ to OBJ: connect electrical components 13 | disconnect OBJ: disconnect electrical components 14 | use OBJ [on OBJ]: use a device/item 15 | look around: describe the current room 16 | examine OBJ: describe an object in detail 17 | look at OBJ: describe a container's contents 18 | read OBJ: read a note or book 19 | move OBJ to OBJ: move an object to a container 20 | pick up OBJ: move an object to the inventory 21 | pour OBJ into OBJ: pour a liquid into a container 22 | mix OBJ: chemically mix a container 23 | teleport to LOC: teleport to a specific room 24 | focus on OBJ: signal intent on a task object 25 | wait: task no action for 10 steps 26 | wait1: task no action for a step -------------------------------------------------------------------------------- /Eval/sciworld/prompt/instructions/sciworld_inst_ibm.txt: -------------------------------------------------------------------------------- 1 | You are a helpful assistant to do some scientific experiment in an environment. 2 | In the environment, there are several rooms: kitchen, foundry, workshop, bathroom, outside, living room, bedroom, greenhouse, art studio, hallway 3 | You should explore the environment and find the items you need to complete the experiment. 4 | You can teleport to any room in one step. 5 | All containers in the environment have already been opened, you can directly get items from the containers. 
6 | 7 | The available actions are: 8 | open OBJ: open a container 9 | close OBJ: close a container 10 | activate OBJ: activate a device 11 | deactivate OBJ: deactivate a device 12 | connect OBJ to OBJ: connect electrical components 13 | disconnect OBJ: disconnect electrical components 14 | use OBJ [on OBJ]: use a device/item 15 | look around: describe the current room 16 | examine OBJ: describe an object in detail 17 | look at OBJ: describe a container's contents 18 | read OBJ: read a note or book 19 | move OBJ to OBJ: move an object to a container 20 | pick up OBJ: move an object to the inventory 21 | pour OBJ into OBJ: pour a liquid into a container 22 | mix OBJ: chemically mix a container 23 | teleport to LOC: teleport to a specific room 24 | focus on OBJ: signal intent on a task object 25 | wait: task no action for 10 steps 26 | wait1: task no action for a step 27 | 28 | You will be given one example and one task below. The example is to showcase how you should do when you are given the task. 29 | 30 | In this example, there are several key elements. 31 | First is the task you need to complete. You need to break this task down into several steps, each corresponding to a thought, action, and observation. 32 | The thought is your reasoning for this step, and the action is the move you make based on this reasoning. 33 | The action must be selected from the aforementioned available actions. 34 | You need to generate the thought and action without any other words, while the observation is based on the environment you see after performing your action. 35 | You do not need to generate the observation; it will be provided to you after you generate the thought and action. 
36 | 37 | Here is the example: 38 | -------------------------------------------------------------------------------- /Eval/sciworld/prompt/instructions/sciworld_inst_llama.txt: -------------------------------------------------------------------------------- 1 | You are a helpful assistant to do some scientific experiment in an environment. 2 | In the environment, there are several rooms: kitchen, foundry, workshop, bathroom, outside, living room, bedroom, greenhouse, art studio, hallway 3 | You should explore the environment and find the items you need to complete the experiment. 4 | You can teleport to any room in one step. 5 | All containers in the environment have already been opened, you can directly get items from the containers. 6 | 7 | The available actions are: 8 | open OBJ: open a container 9 | close OBJ: close a container 10 | activate OBJ: activate a device 11 | deactivate OBJ: deactivate a device 12 | connect OBJ to OBJ: connect electrical components 13 | disconnect OBJ: disconnect electrical components 14 | use OBJ [on OBJ]: use a device/item 15 | look around: describe the current room 16 | examine OBJ: describe an object in detail 17 | look at OBJ: describe a container's contents 18 | read OBJ: read a note or book 19 | move OBJ to OBJ: move an object to a container 20 | pick up OBJ: move an object to the inventory 21 | pour OBJ into OBJ: pour a liquid into a container 22 | mix OBJ: chemically mix a container 23 | teleport to LOC: teleport to a specific room 24 | focus on OBJ: signal intent on a task object 25 | wait: task no action for 10 steps 26 | wait1: task no action for a step 27 | 28 | You will be given one example and one task below. The example is to showcase how you should do when you are given the task. 29 | 30 | In this example, there are several key elements. 31 | First is the task you need to complete. You need to break this task down into several steps, each corresponding to a thought, action, and observation. 
32 | The thought is your reasoning for this step, and the action is the move you make based on this reasoning. 33 | The action must be selected from the aforementioned available actions. 34 | You need to generate the thought and action without any other words, while the observation is based on the environment you see after performing your action. 35 | You do not need to generate the observation; it will be provided to you after you generate the thought and action. 36 | 37 | Here is the example: 38 | -------------------------------------------------------------------------------- /Eval/sciworld/prompt/instructions/webshop_inst.txt: -------------------------------------------------------------------------------- 1 | You are web shopping. 2 | I will give you instructions about what to do. 3 | You have to follow the instructions. 4 | Every round I will give you an observation and a list of available actions, you have to respond an action based on the state and instruction. 5 | You can use search action if search is available. 6 | You can click one of the buttons in clickables. 7 | An action should be of the following structure: 8 | search[keywords] 9 | click[value] 10 | If the action is not valid, perform nothing. 11 | Keywords in search are up to you, but the value in click must be a value in the list of available actions. 12 | Remember that your keywords in search should be carefully designed. 13 | Your response should use the following format: 14 | 15 | Thought: I think ... 
16 | Action: click[something] -------------------------------------------------------------------------------- /Eval/sciworld/scripts/ARMAP-M/run_mcts.sh: -------------------------------------------------------------------------------- 1 | reward_url=15678 2 | model_port=7780 3 | reward_func=http://172.30.150.31:${reward_url}/api/generate 4 | model_url="http://172.30.150.31:${model_port}/v1" 5 | start_id=0 6 | echo $job_id 7 | echo $reward_func 8 | echo $model_url 9 | echo $start_id 10 | 11 | python main_mcts.py --agent_config openllm_llama31_70b_mcts \ 12 | --model_name llama-3.1-70b \ 13 | --exp_config sciworld \ 14 | --split dev \ 15 | --verbose \ 16 | --exp_name _mcts_vllm_llama31_70b \ 17 | --seq_num 10 \ 18 | --horizon 10 \ 19 | --rollouts 10 \ 20 | --start_id ${start_id} \ 21 | --minimal_sample_num 10 \ 22 | --model_url_add ${model_url} \ 23 | --reward_func ${reward_func} -------------------------------------------------------------------------------- /Eval/sciworld/scripts/ARMAP-M/test_sampling.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(".") 4 | from mcts_agents.sci_world_sample import score_sampling 5 | 6 | 7 | def test_sampling(): 8 | data_dir = "PATH_TO_OUTPUTS_DIR" 9 | url = "http://172.30.150.31:15678/api/generate" 10 | out_fn = "sampling_mcts.txt" 11 | fh = open(out_fn, "w+") 12 | print(url+"\n") 13 | fh.flush() 14 | output_file = os.path.basename( 15 | data_dir) + "mcts_result.json" 16 | if os.path.isfile(output_file): 17 | continue 18 | score_sampling(data_dir, output_file, url, fh) 19 | fh.flush() 20 | 21 | 22 | if __name__ == "__main__": 23 | test_sampling() 24 | -------------------------------------------------------------------------------- /Eval/sciworld/scripts/ARMAP-M/trans_local_llm.sh: -------------------------------------------------------------------------------- 1 | llm_url=LLM_URL 2 | llm_port=LLM_PORT 3 | ssh -L 
0.0.0.0:${llm_port}:${llm_url}:${llm_port} REMOTE_SERVER_ADDRESS -------------------------------------------------------------------------------- /Eval/sciworld/scripts/ARMAP-M/trans_rm.sh: -------------------------------------------------------------------------------- 1 | rm_url=RM_URL 2 | rm_port=RM_PORT 3 | ssh -L 0.0.0.0:${rm_port}:${rm_url}:${rm_port} REMOTE_SERVER_ADDRESS -------------------------------------------------------------------------------- /Eval/sciworld/scripts/ARMAP-R/llama8b_re_seen.sh: -------------------------------------------------------------------------------- 1 | RESULT_DIR="outputs/llama8b_re_seen" 2 | python main_re.py --agent_config local_base \ 3 | --model_name meta-llama/Meta-Llama-3.1-8B-Instruct \ 4 | --exp_config sciworld \ 5 | --split dev \ 6 | --verbose \ 7 | --output_path ${RESULT_DIR} \ 8 | --task sample-test \ 9 | --rm http://0.0.0.0:15678/api/generate \ 10 | --threshold 7.25 \ 11 | --re_iters 10 \ 12 | -------------------------------------------------------------------------------- /Eval/sciworld/scripts/Greedy/llama8b_base_seen.sh: -------------------------------------------------------------------------------- 1 | RESULT_DIR="outputs/llama8b_base_seen" 2 | python main.py --agent_config local_base \ 3 | --model_name meta-llama/Meta-Llama-3.1-8B-Instruct \ 4 | --exp_config sciworld \ 5 | --split dev \ 6 | --verbose \ 7 | --output_path ${RESULT_DIR} \ 8 | --task sample-test \ 9 | -------------------------------------------------------------------------------- /Eval/sciworld/scripts/Sampling/llama8b_sample_seen.sh: -------------------------------------------------------------------------------- 1 | for idx in {0..9} 2 | do 3 | echo "Loop Iteration: $idx" 4 | 5 | RESULT_DIR="outputs/llama8b_sample_seen/llama8b_sample_seen_${idx}" 6 | python main.py --agent_config local_0.5 \ 7 | --model_name meta-llama/Meta-Llama-3.1-8B-Instruct \ 8 | --exp_config sciworld \ 9 | --split dev \ 10 | --output_path ${RESULT_DIR} \ 11 | --task 
class AlfWorldTask(Task):
    """Alfworld task instance.

    Wraps a single ALFWorld TextWorld game together with the shared
    batched environment used to step through it.
    """

    task_name = "alfworld"

    def __init__(
        self,
        game_file: str,
        env: envs.AlfredTWEnv,
        obs: str,
        **kwargs,
    ):
        """Store the game file, initial observation, and the shared env.

        Args:
            game_file: path of the underlying game file for this task.
            env: the already-initialised ALFWorld environment; it is shared
                by all tasks yielded from ``load_tasks``.
            obs: initial observation text shown to the agent.
        """
        super().__init__(**kwargs)
        self.game_file = game_file
        self.observation = obs

        self.env = env

    @classmethod
    def load_tasks(cls, split: str, part_num: int, part_idx: int = -1) -> Tuple[Iterable[Task], int]:
        """Build a lazy generator of ALFWorld tasks for a dataset split.

        Args:
            split: one of ``'train'``, ``'dev'``, ``'test'``.
            part_num: number of shards to divide the split into.
            part_idx: which shard this process handles (required when
                ``part_num > 1``).

        Returns:
            A ``(generator, n_tasks)`` pair; the generator yields one
            ``AlfWorldTask`` per environment reset.

        Raises:
            ValueError: if ``split`` is not a recognised split name.
        """
        os.environ["ALFWORLD_DATA"] = "eval_agent/data/alfworld"
        alfworld_data_path = os.environ.get("ALFWORLD_DATA")

        with open(os.path.join(alfworld_data_path, "base_config.yaml")) as f:
            config = yaml.safe_load(f)

        if split == 'train':
            split = "train"
            N_TASKS = 3321
        elif split == 'dev':
            split = "eval_in_distribution"
            N_TASKS = 140
        elif split == 'test':
            split = "eval_out_of_distribution"
            N_TASKS = 134
        else:
            # Fail fast with a clear error instead of hitting a NameError
            # on N_TASKS below (mirrors SciWorldTask.load_tasks).
            raise ValueError(f"Unknown split: {split!r}")

        env = getattr(alfworld.agents.environment, config["env"]["type"])(
            config, train_eval=split
        )
        assert isinstance(env, alfworld.agents.environment.AlfredTWEnv)
        env = env.init_env(batch_size=1)

        if part_num > 1:
            assert part_idx != -1
            part_inst_num = [N_TASKS // part_num] * part_num
            part_inst_num[-1] += N_TASKS % part_num
            # jump to the start of the part
            env.skip(sum(part_inst_num[:part_idx]))
            N_TASKS = part_inst_num[part_idx]

        def generator():
            for idx in range(N_TASKS):
                obs, info = env.reset()
                # Drop the TextWorld banner (first "\n\n"-separated chunk).
                obs = "\n".join(obs[0].split("\n\n")[1:])
                game_file = info["extra.gamefile"][0]

                yield cls(
                    task_id=idx,
                    game_file=game_file,
                    env=env,
                    obs=obs,
                )

        return generator(), N_TASKS
class SciWorldTask(Task):
    """ScienceWorld task instance."""

    task_name = "sciworld"

    def __init__(
        self,
        sub_task_name: str,
        variation_idx: int,
        **kwargs,
    ):
        """Store the ScienceWorld sub-task name and variation index.

        Args:
            sub_task_name: name of the ScienceWorld sub-task.
            variation_idx: index of the task variation to run.
        """
        super().__init__(**kwargs)
        self.sub_task_name = sub_task_name
        self.variation_idx = variation_idx

    @classmethod
    def load_tasks(cls, split: str, part_num: int, part_idx: int = -1):
        """Build a lazy generator of SciWorld tasks for a dataset split.

        Args:
            split: one of ``'train'``, ``'dev'``, ``'test'``, ``'data_fake'``.
            part_num: number of shards to divide the split into.
            part_idx: which shard this process handles (required when
                ``part_num`` is not 1).

        Returns:
            A ``(generator, n_tasks)`` pair; the generator yields one
            ``SciWorldTask`` per ``(task_name, variation_idx)`` entry.

        Raises:
            ValueError: if ``split`` is not a recognised split name.
        """
        split_files = {
            'train': "data/sciworld/train_indices.json",
            'dev': "data/sciworld/dev_indices.json",
            'test': "data/sciworld/test_indices.json",
            'data_fake': "data/sciworld/train_indices_sampled.json",
        }
        if split not in split_files:
            raise ValueError(f"Unknown split: {split!r}")
        # Use context managers so file handles are closed promptly;
        # json.load(open(...)) leaks the descriptor until GC.
        with open(split_files[split]) as f:
            task_idxs = json.load(f)
        with open("data/sciworld/taskname2id.json") as f:
            taskname2id = json.load(f)
        if part_num != 1:
            assert part_idx != -1
            # Ceiling-style shard length; the last shard may be shorter.
            part_len = len(task_idxs) // part_num + 1
            task_idxs = task_idxs[part_len * part_idx: part_len * (part_idx + 1)]
        N_TASKS = len(task_idxs)

        def generator():
            for item in task_idxs:
                task_name = item[0]
                variation_idx = item[1]
                yield cls(
                    task_id=f"{taskname2id[task_name]}_{variation_idx}",
                    sub_task_name=task_name,
                    variation_idx=variation_idx,
                )

        return generator(), N_TASKS
def step(self, inputStr: str):
    """Replacement for ``ScienceWorldEnv.step`` that reports the raw score.

    Behaves like the stock step but additionally exposes ``raw_score``
    (the server's native 0-1 score) in the returned ``infos`` dict.

    Args:
        inputStr: the action string to send to the ScienceWorld server.

    Returns:
        A ``(observation, reward, isCompleted, infos)`` tuple mirroring the
        Jericho API.
    """
    observation = self.server.step(inputStr)
    raw_score = self.server.getScore()
    score = int(round(100 * raw_score))  # Convert from 0-1 to 0-100
    isCompleted = self.server.getCompleted()
    numMoves = self.getNumMoves()

    # Calculate reward
    reward = score - self.lastStepScore  # Calculate reward (delta score) for this step
    self.lastStepScore = score  # Store current score for reward calculation on the next step

    # If the number of moves exceeds the environment step limit, then set isCompleted to be true
    if (numMoves > self.envStepLimit):
        isCompleted = True

    # New: Handle this in the API rather than the agent -- if the score is less than zero, then set the isCompleted flag to true.
    if (score < 0):
        isCompleted = True

    # Mirror of Jericho API
    infos = {
        'moves': numMoves,
        'raw_score': raw_score,
        'score': score,
        'reward': reward,
        'look': self.look(),
        'inv': self.inventory(),
        'taskDesc': self.taskdescription(),
        'valid': self.getValidActionObjectCombinations(),
        'variationIdx': self.variationIdx,
        'taskName': self.taskName,
        'simplificationStr': self.simplificationStr,
    }

    return observation, reward, isCompleted, infos


def sciworld_monkey_patch():
    """Install the ``step`` replacement above onto ``ScienceWorldEnv``.

    Must be called once before any environment is stepped so that every
    subsequent ``env.step`` call returns the augmented ``infos`` dict.
    """
    ScienceWorldEnv.step = step
    print("Monkey Patched ScienceWorldEnv.step")
47 | ``` 48 | python scripts/ARMAP-B/calc_sample.py outputs/new_llama8b_sample webshop-std 12345 49 | ``` 50 | 51 | ### ARMAP-M 52 | 53 | #### port transfer 54 | ``` 55 | bash script/ARMAP-M/trans_local_llm.sh 56 | bash script/ARMAP-M/trans_rm.sh 57 | ``` 58 | 59 | #### docker setup 60 | ``` 61 | docker ps 62 | docker exec -it {docker_id} /bin/bash 63 | ``` 64 | 65 | #### mcts 66 | ``` 67 | bash script/ARMAP-M/run_mcts.sh 68 | ``` 69 | #### performance analysis 70 | ``` 71 | python src/mcts_agents/performance_analysis.py 72 | ``` -------------------------------------------------------------------------------- /Eval/webshop/configs/agents/fs_agent.yaml: -------------------------------------------------------------------------------- 1 | default: 2 | module: "src.client.agents.FastChatAgent" 3 | parameters: 4 | name: "FastChat" 5 | controller_address: "http://localhost:55555" 6 | max_new_tokens: 512 7 | temperature: 0 8 | 9 | vicuna-33b: 10 | parameters: 11 | model_name: "vicuna-33b-v1.3" 12 | 13 | wizard-30b: 14 | parameters: 15 | model_name: "WizardLM-30B-V1.0-merged" 16 | 17 | vicuna-13b: 18 | parameters: 19 | model_name: "vicuna-13b-v1.5" 20 | 21 | vicuna-7b: 22 | parameters: 23 | model_name: "vicuna-7b-v1.5" 24 | -------------------------------------------------------------------------------- /Eval/webshop/configs/agents/gemini.yaml: -------------------------------------------------------------------------------- 1 | module: src.client.agents.Gemini 2 | parameters: 3 | name: <% NAME %> -------------------------------------------------------------------------------- /Eval/webshop/configs/agents/new_local.yaml: -------------------------------------------------------------------------------- 1 | module: src.client.agents.NewLocal 2 | parameters: 3 | name: mixtral-8x7b-instruct-v01 4 | temperature: 0.5 -------------------------------------------------------------------------------- /Eval/webshop/configs/agents/openai-chat.yaml: 
-------------------------------------------------------------------------------- 1 | module: src.client.agents.HTTPAgent 2 | parameters: 3 | url: https://api.openai.com/v1/chat/completions 4 | headers: 5 | Content-Type: application/json 6 | Authorization: Bearer <% PUT-YOUR-OPENAI-KEY-HERE %> 7 | body: 8 | temperature: 0 9 | prompter: 10 | name: role_content_dict 11 | args: 12 | agent_role: assistant 13 | return_format: "{response[choices][0][message][content]}" 14 | -------------------------------------------------------------------------------- /Eval/webshop/configs/agents/openai-text.yaml: -------------------------------------------------------------------------------- 1 | module: src.client.agents.HTTPAgent 2 | parameters: 3 | name: <% NAME %> 4 | url: https://api.openai.com/v1/completions 5 | headers: 6 | Content-Type: application/json 7 | Authorization: Bearer <% PUT-YOUR-OPENAI-KEY-HERE %> 8 | body: 9 | model: <% NAME %> 10 | temperature: 0 11 | prompter: 12 | name: prompt_string 13 | return_format: "{response[choices][0][text]}" 14 | 15 | -------------------------------------------------------------------------------- /Eval/webshop/configs/assignments/new_llama8b_base.yaml: -------------------------------------------------------------------------------- 1 | import: definition.yaml 2 | 3 | concurrency: 4 | task: 5 | webshop-std: 1 6 | agent: 7 | new_llama8b_base: 3 8 | 9 | assignments: 10 | - agent: 11 | - new_llama8b_base 12 | task: 13 | - webshop-std 14 | 15 | output: "outputs/new_llama8b_base" 16 | -------------------------------------------------------------------------------- /Eval/webshop/configs/assignments/new_llama8b_ref.yaml: -------------------------------------------------------------------------------- 1 | import: definition.yaml 2 | 3 | concurrency: 4 | task: 5 | webshop-std: 1 6 | agent: 7 | new_llama8b_base: 3 8 | 9 | assignments: 10 | - agent: 11 | - new_llama8b_base 12 | task: 13 | - webshop-std 14 | 15 | output: "outputs/new_llama8b_ref" 
16 | -------------------------------------------------------------------------------- /Eval/webshop/configs/assignments/new_llama8b_sample.yaml: -------------------------------------------------------------------------------- 1 | import: definition.yaml 2 | 3 | concurrency: 4 | task: 5 | webshop-std: 1 6 | agent: 7 | new_llama8b_sample0: 1 8 | new_llama8b_sample1: 1 9 | new_llama8b_sample2: 1 10 | new_llama8b_sample3: 1 11 | new_llama8b_sample4: 1 12 | new_llama8b_sample5: 1 13 | new_llama8b_sample6: 1 14 | new_llama8b_sample7: 1 15 | new_llama8b_sample8: 1 16 | new_llama8b_sample9: 1 17 | 18 | assignments: 19 | - agent: 20 | - new_llama8b_sample0 21 | - new_llama8b_sample1 22 | - new_llama8b_sample2 23 | - new_llama8b_sample3 24 | - new_llama8b_sample4 25 | - new_llama8b_sample5 26 | - new_llama8b_sample6 27 | - new_llama8b_sample7 28 | - new_llama8b_sample8 29 | - new_llama8b_sample9 30 | task: 31 | - webshop-std 32 | 33 | output: "outputs/new_llama8b_sample" 34 | -------------------------------------------------------------------------------- /Eval/webshop/configs/start_task.yaml: -------------------------------------------------------------------------------- 1 | definition: 2 | import: tasks/task_assembly.yaml 3 | 4 | start: 5 | webshop-std: 2 6 | # webshop-std: 2 7 | # webshop-train: 4 8 | # webshop-train: 3 9 | -------------------------------------------------------------------------------- /Eval/webshop/configs/tasks/alfworld.yaml: -------------------------------------------------------------------------------- 1 | default: 2 | module: src.server.tasks.alfworld.ALFWorld 3 | docker: 4 | image: longinyu/agentbench-alfworld 5 | command: umask 0; [ -f /root/.setup.sh ] && bash /root/.setup.sh; 6 | parameters: 7 | name: alfworld-std 8 | data_path: "/AgentBench/data/alfworld" 9 | config_path: "src/server/tasks/alfworld/configs/base_config.yaml" 10 | prompts_path: "src/server/tasks/alfworld/prompts/alfworld_multiturn_plan_first.json" 11 | split: "standard" 12 | 
max_step: 35 13 | 14 | alfworld-dev: 15 | parameters: 16 | name: alfworld-dev 17 | split: "dev" 18 | 19 | alfworld-std: 20 | parameters: 21 | name: alfworld-std 22 | split: "standard" 23 | -------------------------------------------------------------------------------- /Eval/webshop/configs/tasks/avalon.yaml: -------------------------------------------------------------------------------- 1 | default: 2 | module: "src.server.tasks.avalon.AvalonBench" 3 | parameters: 4 | concurrency: 5 5 | num_players: 5 6 | discussion: False 7 | 8 | avalon-dev-naive: 9 | parameters: 10 | name: "AvalonBench-dev-naive" 11 | data_file: "data/avalon/dev.json" 12 | agent_list: ["naive", "naive", "naive", "naive", "naive"] 13 | 14 | avalon-dev-single: 15 | parameters: 16 | name: "AvalonBench-dev-single" 17 | data_file: "data/avalon/dev.json" 18 | agent_list: ["llm", "naive", "naive", "naive", "naive"] -------------------------------------------------------------------------------- /Eval/webshop/configs/tasks/card_game.yaml: -------------------------------------------------------------------------------- 1 | default: 2 | module: src.server.tasks.card_game.CardGame 3 | docker: 4 | image: longinyu/agentbench-card_game 5 | parameters: 6 | port: 12347 7 | workers: 10 8 | 9 | cg-dev: 10 | parameters: 11 | name: "CardGame-dev" 12 | test_time: 3 13 | 14 | cg-std: 15 | parameters: 16 | name: "CardGame-std" 17 | test_time: 5 18 | -------------------------------------------------------------------------------- /Eval/webshop/configs/tasks/dbbench.yaml: -------------------------------------------------------------------------------- 1 | default: 2 | module: src.server.tasks.dbbench.DBBench 3 | parameters: 4 | concurrency: 1 5 | max_round: 15 6 | 7 | dbbench-dev: 8 | parameters: 9 | name: dbbench-dev 10 | data_file: "data/dbbench/dev.jsonl" 11 | 12 | dbbench-std: 13 | parameters: 14 | name: dbbench-std 15 | data_file: "data/dbbench/standard.jsonl" 16 | 
-------------------------------------------------------------------------------- /Eval/webshop/configs/tasks/kg.yaml: -------------------------------------------------------------------------------- 1 | default: 2 | module: "src.server.tasks.knowledgegraph.KnowledgeGraph" 3 | parameters: 4 | round: 15 5 | sparql_url: "http://164.107.116.56:3093/sparql" 6 | 7 | kg-dev: 8 | parameters: 9 | name: "KnowledgeGraph-dev" 10 | data_file: "data/knowledgegraph/dev.json" 11 | 12 | kg-std: 13 | parameters: 14 | name: "KnowledgeGraph-std" 15 | data_file: "data/knowledgegraph/std.json" 16 | -------------------------------------------------------------------------------- /Eval/webshop/configs/tasks/ltp.yaml: -------------------------------------------------------------------------------- 1 | default: 2 | module: src.server.tasks.ltp.LateralThinkingPuzzle 3 | docker: 4 | image: longinyu/agentbench-ltp 5 | parameters: 6 | round: 25 7 | eval_yaml: "configs/agents/api_agents.yaml" 8 | 9 | ltp-dev: 10 | parameters: 11 | name: ltp-dev 12 | filepath: "data/lateralthinkingpuzzle/dev.xlsx" 13 | 14 | ltp-std: 15 | parameters: 16 | name: ltp-std 17 | filepath: "data/lateralthinkingpuzzle/standard.xlsx" -------------------------------------------------------------------------------- /Eval/webshop/configs/tasks/mind2web.yaml: -------------------------------------------------------------------------------- 1 | default: 2 | module: src.server.tasks.mind2web.Mind2Web 3 | docker: 4 | image: longinyu/agentbench-mind2web 5 | parameters: 6 | concurrency: 10 7 | data: 8 | data_path: "." 
9 | cache_path: "./data/mind2web/.cache/data" 10 | score_file: /root/work/data/scores_all_data.pkl 11 | train: 12 | neg_ratio: 0.2 13 | num_candidates: 5 14 | max_context_len: 512 15 | model: 16 | mode: "multichoice" 17 | name: flan-t5-base 18 | model_name_or_path: "google/flan-t5-base" 19 | max_seq_length: 2048 20 | eval: 21 | topk: 10 22 | seed: 123 23 | llm_prompt: data/mind2web/prompt/llm_prompt_cot.json 24 | 25 | m2w-dev: 26 | parameters: 27 | name: "Mind2Web-dev" 28 | data: 29 | test_split_files: 30 | test_domain: /root/work/data/data_dev/*.json 31 | 32 | m2w-std: 33 | parameters: 34 | name: "Mind2Web-std" 35 | data: 36 | test_split_files: 37 | test_domain: /root/work/data/data_std/*.json 38 | -------------------------------------------------------------------------------- /Eval/webshop/configs/tasks/os.yaml: -------------------------------------------------------------------------------- 1 | os-dev: 2 | module: "src.server.tasks.os_interaction.OSInteraction" 3 | parameters: 4 | name: "os-dev" 5 | concurrency: 24 6 | round_limit: 8 7 | 8 | docker_config: 9 | localhost: local-os 10 | directory: data/os_interaction/res/dockerfiles 11 | 12 | scripts: 13 | directory: data/os_interaction/res/scripts 14 | 15 | data_config: 16 | files: 17 | - problem_file: data/os_interaction/data/dev.json 18 | script_dir: data/os_interaction/scripts/dev/ 19 | index_prefix: "dev-001-" 20 | 21 | bk: [ ] 22 | ignore: [ ] 23 | 24 | os-std: 25 | module: "src.server.tasks.os_interaction.OSInteraction" 26 | parameters: 27 | name: "os-std" 28 | concurrency: 24 29 | round_limit: 8 30 | 31 | docker_config: 32 | localhost: local-os 33 | directory: data/os_interaction/res/dockerfiles 34 | 35 | scripts: 36 | directory: data/os_interaction/res/scripts 37 | 38 | data_config: 39 | files: 40 | - problem_file: data/os_interaction/data/1/*.json 41 | script_dir: data/os_interaction/scripts/1/ 42 | index_prefix: "std-001-" 43 | - problem_file: data/os_interaction/data/2/*.json 44 | script_dir: 
data/os_interaction/scripts/2/ 45 | index_prefix: "std-002-" 46 | - problem_file: data/os_interaction/data/3/*.json 47 | script_dir: data/os_interaction/scripts/3/ 48 | index_prefix: "std-003-" 49 | - problem_file: data/os_interaction/data/4/*.json 50 | script_dir: data/os_interaction/scripts/4/ 51 | index_prefix: "std-004-" 52 | - problem_file: data/os_interaction/data/5/*.json 53 | script_dir: data/os_interaction/scripts/5/ 54 | index_prefix: "std-005-" 55 | - problem_file: data/os_interaction/data/6/*.json 56 | script_dir: data/os_interaction/scripts/6/ 57 | index_prefix: "std-006-" 58 | - problem_file: data/os_interaction/data/7/*.json 59 | script_dir: data/os_interaction/scripts/7/ 60 | index_prefix: "std-007-" 61 | 62 | bk: [ ] 63 | ignore: [ ] 64 | -------------------------------------------------------------------------------- /Eval/webshop/configs/tasks/task_assembly.yaml: -------------------------------------------------------------------------------- 1 | default: 2 | docker: 3 | command: umask 0; [ -f /root/.setup.sh ] && bash /root/.setup.sh; 4 | 5 | import: 6 | - webshop.yaml 7 | - dbbench.yaml 8 | - mind2web.yaml 9 | - card_game.yaml 10 | - kg.yaml 11 | - os.yaml 12 | - ltp.yaml 13 | - alfworld.yaml 14 | - avalon.yaml -------------------------------------------------------------------------------- /Eval/webshop/configs/tasks/webshop.yaml: -------------------------------------------------------------------------------- 1 | default: 2 | module: src.server.tasks.webshop_docker.WebShop 3 | parameters: 4 | concurrency: 1 5 | docker: 6 | #image: longinyu/agentbench-webshop3 7 | #image: zfchen2021/agentbench-webshop 8 | image: zfchen2021/agentbench-webshop 9 | command: ln -s /root/webshop /root/workspace/src/server/tasks/webshop_docker;cp /root/workspace/src/server/tasks/webshop/__init__.py /root/webshop/__init__.py; 10 | 11 | webshop-test: 12 | parameters: 13 | name: webshop-dev 14 | start: 200 15 | end: 210 16 | 17 | webshop-dev: 18 | parameters: 19 | 
def run_data_generation(args):
    """Run one stage of the data-generation pipeline, selected by step id.

    Step 1 samples products from the database, step 2 generates intent
    (goal) samples from attributes, step 3 packs data into webshop format,
    and step 4 constructs preference triplets.

    Raises:
        NotImplementedError: for any step id outside 1-4.
    """
    stage_handlers = {
        1: step_1_sample_data,
        2: step_2_generate_goal,
        3: step_3_pack_data,
        4: step_4_construct_triplet_v2,
    }
    handler = stage_handlers.get(args.step_id)
    if handler is None:
        raise NotImplementedError
    handler(args)
def generate_from_huggingface_completion(
    prompt: str,
    model_endpoint: str,
    temperature: float,
    top_p: float,
    max_new_tokens: int,
    stop_sequences: list[str] | None = None,
) -> str:
    """Request a completion from a text-generation-inference endpoint.

    Args:
        prompt: the prompt text to complete.
        model_endpoint: base URL of the TGI server.
        temperature: sampling temperature.
        top_p: nucleus-sampling threshold.
        max_new_tokens: maximum number of tokens to generate.
        stop_sequences: optional strings that terminate generation.

    Returns:
        The generated completion text.
    """
    tgi_client = Client(model_endpoint, timeout=60)
    response = tgi_client.generate(
        prompt=prompt,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        stop_sequences=stop_sequences,
    )
    return response.generated_text
class Tokenizer(object):
    """Uniform front-end over per-provider tokenizer backends.

    OpenAI models use ``tiktoken``, HuggingFace models use
    ``LlamaTokenizer``, and Google/local providers use no tokenizer at all
    (``self.tokenizer`` is ``None``) since their length accounting is
    character based.
    """

    def __init__(self, provider: str, model_name: str) -> None:
        if "openai" in provider:
            self.tokenizer = tiktoken.encoding_for_model(model_name)
        elif provider == "huggingface":
            llama_tok = LlamaTokenizer.from_pretrained(model_name)
            # turn off adding special tokens automatically
            llama_tok.add_special_tokens = False  # type: ignore[attr-defined]
            llama_tok.add_bos_token = False  # type: ignore[attr-defined]
            llama_tok.add_eos_token = False  # type: ignore[attr-defined]
            self.tokenizer = llama_tok
        elif provider in ("google", "local"):
            # Not used for input length computation, as Gemini is based on characters
            self.tokenizer = None
        else:
            raise NotImplementedError

    def encode(self, text: str) -> list[int]:
        """Encode *text* to token ids with the underlying backend."""
        return self.tokenizer.encode(text)

    def decode(self, ids: list[int]) -> str:
        """Decode token *ids* back to text with the underlying backend."""
        return self.tokenizer.decode(ids)

    def __call__(self, text: str) -> list[int]:
        return self.encode(text)
-------------------------------------------------------------------------------- /Eval/webshop/scripts/ARMAP-M/trans_rm.sh: -------------------------------------------------------------------------------- 1 | rm_url=RM_URL 2 | rm_port=RM_PORT 3 | ssh -L 0.0.0.0:${rm_port}:${rm_url}:${rm_port} REMOTE_SERVER_ADDRESS -------------------------------------------------------------------------------- /Eval/webshop/scripts/ARMAP-R/new_llama8b_ref.sh: -------------------------------------------------------------------------------- 1 | python -m src.assigner -c configs/assignments/new_llama8b_ref.yaml -------------------------------------------------------------------------------- /Eval/webshop/scripts/Greedy/new_llama8b_base.sh: -------------------------------------------------------------------------------- 1 | python -m src.assigner -c configs/assignments/new_llama8b_base.yaml -------------------------------------------------------------------------------- /Eval/webshop/scripts/Sampling/new_llama8b_sample.sh: -------------------------------------------------------------------------------- 1 | python -m src.assigner -c configs/assignments/new_llama8b_sample.yaml -------------------------------------------------------------------------------- /Eval/webshop/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/__init__.py -------------------------------------------------------------------------------- /Eval/webshop/src/client/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent import AgentClient 2 | from .agents import * 3 | from .task import TaskClient 4 | -------------------------------------------------------------------------------- /Eval/webshop/src/client/__pycache__bk/__init__.cpython-312.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/client/__pycache__bk/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /Eval/webshop/src/client/__pycache__bk/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/client/__pycache__bk/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /Eval/webshop/src/client/__pycache__bk/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/client/__pycache__bk/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /Eval/webshop/src/client/__pycache__bk/agent.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/client/__pycache__bk/agent.cpython-312.pyc -------------------------------------------------------------------------------- /Eval/webshop/src/client/__pycache__bk/agent.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/client/__pycache__bk/agent.cpython-38.pyc -------------------------------------------------------------------------------- /Eval/webshop/src/client/__pycache__bk/agent.cpython-39.pyc: -------------------------------------------------------------------------------- 
class AgentClient:
    """Abstract base for agent backends.

    Subclasses implement :meth:`inference` to map a conversation history
    to the agent's next utterance.
    """

    def __init__(self, *args, **kwargs):
        # The base class accepts (and ignores) arbitrary configuration so
        # subclasses can forward their constructor arguments freely.
        pass

    def inference(self, history: List[dict]) -> str:
        """Return the agent's reply for the given message history.

        Args:
            history: list of ``{"role": ..., "content": ...}`` messages.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()
print("================= USER ===================") 20 | user = input(">>> ") 21 | history.append({"role": "user", "content": user,"image":"/home/haoyu/zfchen/dl/AgentBench/assets/agentbench.png"}) 22 | try: 23 | agent_response = agent.inference(history) 24 | print("================ AGENT ====================") 25 | print(agent_response) 26 | history.append({"role": "agent", "content": agent_response}) 27 | except Exception as e: 28 | print(e) 29 | exit(0) 30 | except KeyboardInterrupt: 31 | print("\n[Exit] KeyboardInterrupt") 32 | exit(0) 33 | 34 | 35 | if __name__ == '__main__': 36 | args = parse_args() 37 | loader = ConfigLoader() 38 | config = loader.load_from(args.config) 39 | assert args.agent in config, f"Agent {args.agent} not found in {args.config}" 40 | agent_config = config[args.agent] 41 | factory = InstanceFactory(**agent_config) 42 | agent_client: AgentClient = factory.create() 43 | interaction(agent_client) 44 | -------------------------------------------------------------------------------- /Eval/webshop/src/client/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .fastchat_client import FastChatAgent 2 | from .http_agent import HTTPAgent 3 | from .gemini_agent import Gemini,GeminiText 4 | from .new_local_agent import NewLocal -------------------------------------------------------------------------------- /Eval/webshop/src/client/agents/claude_agent.py: -------------------------------------------------------------------------------- 1 | import anthropic 2 | import os 3 | from copy import deepcopy 4 | from typing import List 5 | 6 | from ..agent import AgentClient 7 | 8 | 9 | class Claude(AgentClient): 10 | def __init__(self, api_args=None, *args, **config): 11 | super().__init__(*args, **config) 12 | if not api_args: 13 | api_args = {} 14 | api_args = deepcopy(api_args) 15 | self.key = api_args.pop("key", None) or os.getenv('Claude_API_KEY') 16 | api_args["model"] = api_args.pop("model", 
None) 17 | if not self.key: 18 | raise ValueError("Claude API KEY is required, please assign api_args.key or set OPENAI_API_KEY " 19 | "environment variable.") 20 | if not api_args["model"]: 21 | raise ValueError("Claude model is required, please assign api_args.model.") 22 | self.api_args = api_args 23 | if not self.api_args.get("stop_sequences"): 24 | self.api_args["stop_sequences"] = [anthropic.HUMAN_PROMPT] 25 | 26 | def inference(self, history: List[dict]) -> str: 27 | prompt = "" 28 | for message in history: 29 | if message["role"] == "user": 30 | prompt += anthropic.HUMAN_PROMPT + message["content"] 31 | else: 32 | prompt += anthropic.AI_PROMPT + message["content"] 33 | prompt += anthropic.AI_PROMPT 34 | c = anthropic.Client(api_key=self.key) 35 | resp = c.completions.create(prompt=prompt, **self.api_args) 36 | return str(resp.completion) 37 | -------------------------------------------------------------------------------- /Eval/webshop/src/client/agents/key_pool.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | import openai 4 | import random 5 | 6 | 7 | class KeyPool: 8 | MAX_REQUESTS_PER_MINUTE = 3 9 | keys = [] 10 | origin_key = "" 11 | index = -1 12 | 13 | @classmethod 14 | def set_all_keys(cls, keys): 15 | cls.keys = keys 16 | 17 | @classmethod 18 | def set_key(cls): 19 | random.seed(int(time.time())) 20 | index = random.randint(0, len(cls.keys)-1) 21 | wait_seconds = 0.1 22 | # print(f"\nWaiting {wait_seconds}s for next request.\n") 23 | time.sleep(wait_seconds) 24 | return cls.keys[index] 25 | 26 | @classmethod 27 | def reset_key(cls): 28 | openai.api_key = cls.origin_key 29 | 30 | def __enter__(self): 31 | return KeyPool.set_key() 32 | 33 | def __exit__(self, type, val, tb): 34 | # KeyPool.reset_key() 35 | pass -------------------------------------------------------------------------------- /Eval/webshop/src/client/agents/test_agent.py: 
-------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from src.client import AgentClient 4 | 5 | 6 | class CountHistoryAgent(AgentClient): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | 10 | def inference(self, history: List[dict]) -> str: 11 | return "I received {} items in history.".format(len(history)) 12 | -------------------------------------------------------------------------------- /Eval/webshop/src/mcts_agents/my_random_agent.py: -------------------------------------------------------------------------------- 1 | """ 2 | A Random Agent given as an example 3 | """ 4 | 5 | from gym import spaces 6 | import dyna_gym.utils.utils as utils 7 | 8 | class MyRandomAgent(object): 9 | def __init__(self, action_space): 10 | self.action_space = action_space 11 | 12 | def reset(self, p=None): 13 | """ 14 | Reset the attributes. 15 | Expect to receive them in the same order as init. 16 | p : list of parameters 17 | """ 18 | if p is not None: 19 | utils.assert_types(p,[spaces.discrete.Discrete]) 20 | self.__init__(p[0]) 21 | 22 | def display(self): 23 | """ 24 | Display infos about the attributes. 
25 | """ 26 | print('Displaying Random agent:') 27 | print('Action space :', self.action_space) 28 | 29 | def act(self, observation=None, reward=None, done=None): 30 | return self.action_space.sample() 31 | -------------------------------------------------------------------------------- /Eval/webshop/src/mcts_agents/performance_analysis.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pdb 4 | 5 | 6 | def parse_performance_mcts(): 7 | result_dir = "PATH_TO_RESULT_DIR" 8 | fn_list = os.listdir(result_dir) 9 | mcts_perf, avg_perf, max_perf = [], [], [] 10 | for fn in fn_list: 11 | full_fn = os.path.join(result_dir, fn) 12 | with open(full_fn, "r") as fh: 13 | fdict = json.load(fh) 14 | reward_pds = fdict["rewards_pred"] 15 | reward_gts = fdict["rewards_gt"] 16 | max_val = max(reward_pds) 17 | max_idx_list = [idx for idx, ele in enumerate(reward_pds) if ele==max_val] 18 | max_gt_list = [reward_gts[idx] for idx in max_idx_list] 19 | mcts_rewards = sum(max_gt_list) / len(max_gt_list) 20 | mcts_perf.append(mcts_rewards) 21 | 22 | mean_reward = sum(reward_gts) / len(reward_gts) 23 | max_perf.append(max(reward_gts)) 24 | avg_perf.append(mean_reward) 25 | print(reward_gts) 26 | if max(reward_gts)==0: 27 | print(fn) 28 | 29 | mean_p = sum(avg_perf)/len(avg_perf) 30 | mcts_p = sum(mcts_perf)/len(mcts_perf) 31 | max_p = sum(max_perf)/len(max_perf) 32 | print("Testing sample number: %d\n"%(len(fn_list))) 33 | print("Average MEAN performance: %f\n"%(mean_p)) 34 | print("Average MCTS performance: %f\n"%(mcts_p)) 35 | print("Average MAX performance: %f\n"%(max_p)) 36 | pdb.set_trace() 37 | 38 | if __name__=="__main__": 39 | parse_performance_mcts() -------------------------------------------------------------------------------- /Eval/webshop/src/server/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/server/__init__.py -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/server/tasks/__init__.py -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/alfworld/__init__.py: -------------------------------------------------------------------------------- 1 | from .task import ALFWorld -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/alfworld/environment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import glob 5 | import random 6 | import numpy as np 7 | 8 | import textworld 9 | import textworld.agents 10 | import textworld.gym 11 | import gym 12 | 13 | from alfworld.agents.utils.misc import Demangler, get_templated_task_desc, add_task_to_grammar 14 | import alfworld.agents.modules.generic as generic 15 | from alfworld.agents.environment.alfred_tw_env import AlfredTWEnv 16 | 17 | class SingleAlfredTWEnv(AlfredTWEnv): 18 | ''' 19 | Interface for Textworld Env 20 | Contains only one game_file per environment 21 | ''' 22 | 23 | def __init__(self, config, name, train_eval="eval_out_of_distribution"): 24 | print("Initializing AlfredTWEnv...") 25 | self.config = config 26 | self.train_eval = train_eval 27 | 28 | self.goal_desc_human_anns_prob = self.config['env']['goal_desc_human_anns_prob'] 29 | self.get_game_logic() 30 | # self.gen_game_files(regen_game_files=self.config['env']['regen_game_files']) 31 | 32 | self.random_seed = 42 33 | 34 | self.game_files = [name] 35 | self.num_games = 1 36 | 37 | def 
get_all_game_files(config, split="eval_out_of_distribution"): 38 | env = AlfredTWEnv(config, train_eval=split) 39 | game_files = env.game_files 40 | del env 41 | return game_files 42 | 43 | # if __name__=="__main__": 44 | # os.environ["ALFWORLD_DATA"] = "/data/share/leixy/ReAct/alfworld/data" 45 | # config=load_config("/data/share/leixy/AgentBench/src/tasks/alfworld/configs/base_config.yaml") 46 | # game_files = get_all_game_files(config, "train") 47 | # game_files = [game.split("data/")[-1] for game in game_files] 48 | # with open("train.json", "w") as f: 49 | # f.write(json.dumps(game_files, indent=2)) 50 | # f.close() 51 | # print(len(game_files)) 52 | # print(game_files[0]) 53 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/alfworld/prompts/alfworld_multiturn.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e5f20e9e713309c9a61dd940b49456d8654f59ab1785832928afdf8442a8e359 3 | size 57778 4 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/alfworld/prompts/alfworld_multiturn_cot.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:03880d41baa1308088d1c2cee127632d1218cc2897aa11085ee970dd31604169 3 | size 1616 4 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/alfworld/prompts/alfworld_multiturn_new.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:43274ec5b83e586f673cc257e35034cc505e550596978005e9d2307c37beca77 3 | size 11546 4 | -------------------------------------------------------------------------------- 
/Eval/webshop/src/server/tasks/alfworld/prompts/alfworld_multiturn_plan_first.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:1ed562e55f0f7b6ca3588b2fefb7f5d1d233ba981d456fb79afbc081e1ac38b7 3 | size 7832 4 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/alfworld/prompts/alfworld_multiturn_pure.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5a44202878ad73b98a6a11c78dad905b0db685b71e36f43b37685dff8ed8f244 3 | size 6732 4 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/alfworld/prompts/alfworld_multiturn_react.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:79a92edddd49691cde9d7a05beda354658355882911fcf5c5a7b38a4777bc3b2 3 | size 8932 4 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/alfworld/prompts/alfworld_singleturn.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d9036576672710b76c242d393ea1441a6cb02b97554e056543576170e4573433 3 | size 51028 4 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/alfworld/utils.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from typing import List 3 | import re 4 | import threading 5 | import jsonlines 6 | import yaml 7 | import json 8 | import numpy as np 9 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 10 | 11 | def bleu_score(reference, candidate): 12 | reference_tokens = 
reference.split() 13 | candidate_tokens = candidate.split() 14 | 15 | smoothie = SmoothingFunction().method4 16 | score = sentence_bleu([reference_tokens], candidate_tokens, smoothing_function=smoothie) 17 | return score 18 | 19 | def process_ob(ob): 20 | if ob.startswith('You arrive at loc '): 21 | ob = ob[ob.find('. ')+2:] 22 | return ob 23 | 24 | def process_action(action, choices, limit=0.01, to_print=False): 25 | if to_print: 26 | print("preprocess action: ", action) 27 | match = re.search("ACTION:(.*)", action) 28 | if match: 29 | action = match.group(1) 30 | else: 31 | return False 32 | 33 | action = action.strip().lower().split("\n")[0] 34 | if not choices: 35 | return action 36 | if action in choices: 37 | return action 38 | try: 39 | bleus = [bleu_score(choice, action) for choice in choices] 40 | max_index = np.argmax(np.array(bleus)) 41 | max_score = bleus[max_index] 42 | if max_score > limit: 43 | if to_print: 44 | print("processed action: ", choices[max_index], " score: ", max_score) 45 | return choices[max_index] 46 | except Exception as e: 47 | print("encounter exception: ", e) 48 | print("choices: ", choices) 49 | print("action: ", action) 50 | return action 51 | 52 | def load_prompts(prompts_file): 53 | with open(prompts_file, 'r') as f: 54 | d = json.load(f) 55 | f.close() 56 | return d 57 | 58 | def load_config(config_file): 59 | with open(config_file) as reader: 60 | config = yaml.safe_load(reader) 61 | return config -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/avalon/README.md: -------------------------------------------------------------------------------- 1 | # AvalonBench 2 | 3 | ## Quick Start 4 | 5 | ### Start the task server and the assigner 6 | 7 | Start the game (3 is the number of workers) 8 | ```bash 9 | python -m src.start_task -a --start avalon-dev-single 3 10 | ``` 11 | Start the assigner 12 | ```bash 13 | python -m src.assigner --config 
./configs/assignments/test_avalon.yaml 14 | ``` 15 | 16 | ### Customize configurations and data 17 | 18 | 1. You can modify the file `configs/tasks/avalon.yaml` to configure the agent list. A config file looks like this: 19 | ```yaml 20 | default: 21 | module: "src.server.tasks.avalon.AvalonBench" 22 | parameters: 23 | num_players: 5 24 | discussion: False 25 | 26 | avalon-dev-naive: 27 | parameters: 28 | name: "AvalonBench-dev-naive" 29 | data_file: "data/avalon/dev.json" 30 | agent_list: ["naive", "naive", "naive", "naive", "naive"] 31 | 32 | avalon-dev-single: 33 | parameters: 34 | name: "AvalonBench-dev-single" 35 | data_file: "data/avalon/dev.json" 36 | agent_list: ["llm", "naive", "naive", "naive", "naive"] 37 | ``` 38 | where `naive` stands for the naive bots. Agents will play the roles with the same index in the data file (see following). 39 | ```plaintext 40 | Note: There should only be one "llm" in the `agent_list` 41 | ``` 42 | 43 | 2. You can also add data in `data/avalon/dev.json` (Note: Currently we only support the 5-player game setting, which includes 1 Merlin, 2 Servants, 1 Minion and 1 Assassin). A data item looks like this: 44 | 45 | ```json 46 | { 47 | "num_players": 5, 48 | "quest_leader": 0, 49 | "role_names": ["Assassin", "Servant", "Servant", "Merlin", "Minion"] 50 | } 51 | ``` 52 | where `quest_leader` is the id of the initial quest leader in this game. You can change the game setup by altering `quest_leader` with number from 0 to 4, and by permuting `role_names`. 53 | 54 | ### Naive experiment 55 | 56 | You can also start a naive experiment using: 57 | ```bash 58 | python -m src.start_task -a --start avalon-dev-naive 3 59 | ``` 60 | where all the agents are naive bots. For details of the naive strategies, please refer to the [paper](https://arxiv.org/pdf/2310.05036.pdf). 61 | 62 | ## Prompts 63 | 64 | All the prompts are maintained in `src/server/tasks/avalon/prompt.py`. 
You can find the respective prompts used in `src/server/tasks/avalon/agents/llm_with_discussion.py` and `src/server/tasks/avalon/wrapper.py`. -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/avalon/__init__.py: -------------------------------------------------------------------------------- 1 | from .task import AvalonBench -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/avalon/avalon_exception.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from src.typings import AgentBenchException 3 | 4 | class AvalonException(AgentBenchException): 5 | def __init__(self, reason: str, detail: Union[str, None] = None) -> None: 6 | super().__init__() 7 | self.reason = reason 8 | self.detail = detail 9 | 10 | def __str__(self) -> str: 11 | if not self.detail: 12 | return "{CLASS_NAME}[{REASON}]".format( 13 | CLASS_NAME=self.__class__.__name__, REASON=self.reason 14 | ) 15 | else: 16 | return "{CLASS_NAME}[{REASON}]: {DETAIL}".format( 17 | CLASS_NAME=self.__class__.__name__, 18 | REASON=self.reason, 19 | DETAIL=self.detail, 20 | ) 21 | 22 | class AvalonEnvException(AvalonException): 23 | def __init__(self, detail: Union[str, None] = None) -> None: 24 | super().__init__("Avalon Environment Exception", detail) 25 | 26 | class AvalonAgentActionException(AvalonException): 27 | def __init__(self, detail: Union[str, None] = None) -> None: 28 | super().__init__("Invalid action (result) with retry", detail) -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/avalon/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | def get_vote_result(answer: str): 4 | match_vote = "Yes|No" 5 | vote_result = [] 6 | 7 | vote_result = re.findall(match_vote, answer) 8 | 9 | result = '' if 
len(vote_result) == 0 else vote_result[-1] 10 | 11 | return result 12 | 13 | def get_team_result(answer: str): 14 | match_num = r"\d+" 15 | player_list = [] 16 | 17 | player_list = re.findall(match_num, answer) 18 | 19 | player_list = [int(id) for id in player_list] 20 | 21 | return player_list 22 | 23 | def get_assassination_result(message: str, answer: str): 24 | match_num = r"\d+" 25 | player_id = [] 26 | 27 | player_id = re.findall(match_num, str(message)+str(answer)) 28 | 29 | player_id = int(player_id[-1]) 30 | 31 | return player_id 32 | 33 | def get_believed_player_sides(answer): 34 | scores = eval(answer.split("Answer: ")[-1]) 35 | 36 | return scores 37 | 38 | def verbalize_team_result(team: frozenset, votes, outcome: bool): 39 | verbal_vote = { 40 | 0: "reject", 41 | 1: "approve" 42 | } 43 | verbalized_result = "" 44 | if outcome == True: 45 | verbalized_result = f"The team {str(list(team))} is approved." 46 | elif outcome == False: 47 | verbalized_result = f"The team {str(list(team))} is rejected." 48 | else: 49 | raise ValueError("Invalid outcome %s" % outcome) 50 | 51 | for idx, vote in enumerate(votes): 52 | verbalized_result += " Player %d voted %s." % (idx, verbal_vote[vote]) 53 | 54 | return verbalized_result 55 | 56 | def verbalize_mission_result(team: frozenset, outcome: bool): 57 | verbalized_result = "" 58 | if outcome == True: 59 | verbalized_result = "The mission succeeded." 60 | elif outcome == False: 61 | verbalized_result = "The mission failed." 
62 | else: 63 | raise ValueError("Invalid outcome %s" % outcome) 64 | 65 | verbalized_result += " The team is %s, which contains" % str(list(team)) 66 | for member in team: 67 | verbalized_result += " Player %s" % str(member) 68 | 69 | return verbalized_result -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/server/tasks/card_game/AI/__init__.py -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/action1.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | import random 3 | 4 | from sdk.ai_client import Action, AIClient, Game 5 | from prompt.en import enemy_fish, your_fish 6 | 7 | class AI(AIClient): 8 | def __init__(self, stage) -> None: 9 | super().__init__() 10 | self.stage = stage 11 | self.name_to_id = {"spray": 1, "flame": 2, "eel": 3, "sunfish": 4, "barracuda": 5, "mobula": 6, "octopus": 8, "whiteshark": 9, "hammerhead": 10} 12 | self.id_to_name = {} 13 | for name, id in self.name_to_id.items(): 14 | self.id_to_name[id] = name 15 | self.id_to_name[-1] = "unknown" 16 | self.ai_fish = [] 17 | random.seed(42) 18 | 19 | def Pick(self, game: Game) -> List[int]: 20 | pick_list = [] 21 | self.pos_to_name = enemy_fish 22 | 23 | for i in range(4): 24 | pick_list.append(self.name_to_id[self.pos_to_name[i]]) 25 | self.ai_fish.append(self.name_to_id[your_fish[i]]) 26 | 27 | #random.shuffle(pick_list) 28 | return pick_list 29 | 30 | def Assert(self, game: Game) -> Tuple[int, int]: 31 | if self.stage == 1: 32 | return (-1, -1) 33 | else: 34 | return (self.get_enemy_living_fishes()[0], random.choice(self.ai_fish)) 35 | 36 | def Act(self, game: Game) -> Action: 
37 | action = Action(game) 38 | my_pos: int = self.get_my_living_fishes()[0] 39 | action.set_action_fish(my_pos) 40 | return self.auto_valid_action(my_pos, action) 41 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/basline1.py: -------------------------------------------------------------------------------- 1 | from action1 import AI 2 | import sys 3 | 4 | if __name__ == "__main__": 5 | stage = int(sys.argv[1]) 6 | 7 | myAI = AI(stage) 8 | myAI.run() 9 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/basline2.py: -------------------------------------------------------------------------------- 1 | from action2 import AI 2 | import sys 3 | 4 | if __name__ == "__main__": 5 | stage = int(sys.argv[1]) 6 | 7 | myAI = AI(stage) 8 | myAI.run() 9 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/client.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import threading 3 | import json 4 | import time 5 | 6 | class Client: 7 | def __init__(self, port, host='localhost'): 8 | print(333) 9 | self.host = host 10 | self.port = port 11 | self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 12 | self.socket.connect((self.host, self.port)) 13 | self.stop_flag = threading.Event() 14 | 15 | def quit(self): 16 | self.send_message("") 17 | self.stop() 18 | 19 | def llm_call(self, history, prompt, system): 20 | message = [] 21 | 22 | if system: 23 | history.append((system, "Okay, I will play the game with you according to the rules.")) 24 | 25 | for ix, chat in enumerate(history): 26 | message.append({ 27 | "role": "user", 28 | "content": chat[0] 29 | }) 30 | message.append({ 31 | "role": "agent", 32 | "content": chat[1] 33 | }) 34 | 35 | message.append({ 36 | "role": "user", 37 | "content": 
prompt 38 | }) 39 | #with open("client.txt", "a") as f: 40 | # f.write(json.dumps(message) + "\n") 41 | self.send_message(json.dumps(message)) 42 | output = self.receive_messages() 43 | 44 | #with open("client.txt", "a") as f: 45 | # f.write(json.dumps(output) + "\n######################\n") 46 | return output 47 | 48 | def receive_messages(self): 49 | while not self.stop_flag.is_set(): 50 | data = self.socket.recv(1000000).decode() 51 | if data: 52 | return data 53 | 54 | def send_message(self, message): 55 | self.socket.sendall(message.encode()) 56 | 57 | def stop(self): 58 | self.stop_flag.set() 59 | self.socket.close() -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/main.py: -------------------------------------------------------------------------------- 1 | print(777) 2 | import sys 3 | from client import Client 4 | 5 | print(666) 6 | if __name__ == "__main__": 7 | print(555) 8 | language = sys.argv[1] 9 | stage = int(sys.argv[2]) 10 | order = int(sys.argv[3]) 11 | save_dir = sys.argv[4] 12 | port = int(sys.argv[5]) 13 | client = Client(port=port) 14 | if language == 'en': 15 | from AI_En import Agent 16 | myAI = Agent(client, stage, order, save_dir) 17 | else: 18 | from AI_Cn import Agent 19 | myAI = Agent(client, stage, order, save_dir) 20 | 21 | myAI.run() 22 | client.quit() -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/prompt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/server/tasks/card_game/AI/prompt/__init__.py -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/prompt/cn.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 
| action_format = { 4 | "选择的鱼": "选择我方仍活着的一条鱼", 5 | "动作": "选择主动技能或普通攻击", 6 | "目标位置": "目标位置,必须在 [0, 3] 范围内" 7 | } 8 | 9 | guess_format = { 10 | "猜测种类": "选择你想要猜测的种类", 11 | "目标位置": "目标位置,必须在 [0, 3] 范围内" 12 | } 13 | 14 | fish_description = { 15 | "射水鱼": { 16 | "被动技能": r"反弹:队友被直接攻击后,若其生命值少于30%,对来源造成60的伤害", 17 | "主动技能": r"范围伤害:对敌方所有目标发动攻击,每个造成35%攻击力的伤害" 18 | }, 19 | "喷火鱼": { 20 | "被动技能": r"反弹:队友被直接攻击后,若其生命值少于30%,对来源造成60的伤害", 21 | "主动技能": r"内斗:对一名队友发动攻击造成75伤害,并使自身攻击力永久提升140" 22 | }, 23 | "电鳗": { 24 | "被动技能": r"伤害转移:自身受到敌方直接攻击而受到伤害时,如果还有队友存活,会让队友帮忙承受伤害(自身承受70%,其余队友平摊30%);每累计受到的实际伤害值达到200,自身攻击值永久提升40", 25 | "主动技能": r"范围伤害:对敌方所有目标发动攻击,每个造成35%攻击力的伤害" 26 | }, 27 | "翻车鱼": { 28 | "被动技能": r"伤害转移:自身受到敌方直接攻击而受到伤害时,如果还有队友存活,会让队友帮忙承受伤害(自身承受70%,其余队友平摊30%);每累计受到的实际伤害值达到200,自身攻击值永久提升40", 29 | "主动技能": r"内斗:对一名队友发动攻击造成75伤害,并使自身攻击力永久提升140" 30 | } 31 | } 32 | 33 | ACTION_PROMPT = """这是一个双人对战游戏,每队拥有四条宠物鱼。 34 | 每条鱼都有400初始生命值、200攻击力,以及自己的主动技能和被动技能。 35 | 每个回合,你可以选择一条活着的鱼,并且使用主动技能,或用普通攻击敌方鱼。 36 | 鱼的被动技能都会自动触发,只要被动技能条件满足。 37 | %s 38 | 获胜的条件是在比赛结束时有更多的鱼活着。 39 | 40 | 以下是四种宠物鱼: 41 | %s 42 | 43 | 与我一起进行对战。在每一轮中,你都应该输出你的思考过程,并用以下JSON格式返回你的动作: 44 | %s 45 | 46 | 注意!你必须在每一轮中返回你的动作。否则,你将被直接判负。""" 47 | 48 | GUESS_RULE = """你拥有的鱼的初始信息是隐藏的,敌方可以在每回合内猜测你的一条鱼的种类,如果敌方猜对了,你的那条鱼的种类将被揭示,并且你的所有鱼会受到50的伤害""" 49 | 50 | GUESS_PROMPT = """这是一个双人对战游戏,每队拥有四条宠物鱼。 51 | 每条鱼都有初始生命值、攻击力,以及自己的主动技能和被动技能。 52 | 所有的鱼的初始信息是隐藏的,你可以在每回合内猜测敌方的一条鱼的种类,如果你猜对了,敌方那条鱼的种类将被揭示,并且敌方的所有鱼会受到50的伤害。 53 | 获胜的条件是在比赛结束时有更多的鱼活着。 54 | 55 | 以下是四种宠物鱼: 56 | %s 57 | 58 | 与我一起进行对战。在每一轮中,你都应该输出你的思考过程,并用以下JSON格式返回你的动作: 59 | %s 60 | 61 | 注意!你必须在每一轮中返回你的动作。否则,你将被直接判负。""" 62 | 63 | ACTION_DESCRIPTION = { 64 | 1: ACTION_PROMPT % ('', fish_description, action_format), 65 | 2: ACTION_PROMPT % (GUESS_RULE, fish_description, action_format) 66 | } 67 | 68 | GUESS_DESCRIPTION = { 69 | 2: GUESS_PROMPT % (fish_description, guess_format), 70 | } -------------------------------------------------------------------------------- 
/Eval/webshop/src/server/tasks/card_game/AI/prompt/cn_en.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | action_format = { 4 | "pick_fish": "pick an alive fish", 5 | "action": "choose from [normal, active]", 6 | "target_position": "target's position, you must choose from [0,3]" 7 | } 8 | 9 | guess_format = { 10 | "guess_type": "the enemy's fish type you may guess", 11 | "target_position": "guess target's position, you must choose from [0,3]" 12 | } 13 | 14 | fish_description = { 15 | "射水鱼": { 16 | "被动技能": r"反弹:队友被直接攻击后,若其生命值少于30%,对来源造成60的伤害", 17 | "主动技能": r"范围伤害:对敌方所有目标发动攻击,每个造成35%攻击力的伤害" 18 | }, 19 | "喷火鱼": { 20 | "被动技能": r"反弹:队友被直接攻击后,若其生命值少于30%,对来源造成60的伤害", 21 | "主动技能": r"内斗:对一名队友发动攻击造成100伤害,并使自身攻击力永久提升140" 22 | }, 23 | "电鳗": { 24 | "被动技能": r"伤害转移:自身受到敌方直接攻击而受到伤害时,如果还有队友存活,会让队友帮忙承受伤害(自身承受70%,其余队友平摊30%);每累计受到的实际伤害值达到200,自身攻击值永久提升40", 25 | "主动技能": r"范围伤害:对敌方所有目标发动攻击,每个造成35%攻击力的伤害" 26 | }, 27 | "翻车鱼": { 28 | "被动技能": r"伤害转移:自身受到敌方直接攻击而受到伤害时,如果还有队友存活,会让队友帮忙承受伤害(自身承受70%,其余队友平摊30%);每累计受到的实际伤害值达到200,自身攻击值永久提升40", 29 | "主动技能": r"内斗:对一名队友发动攻击造成100伤害,并使自身攻击力永久提升140" 30 | } 31 | } 32 | 33 | ACTION_PROMPT = """这是一个双人对战游戏,每队拥有四条宠物鱼。 34 | 每条鱼都有400初始生命值、200攻击力,以及自己的主动技能和被动技能。 35 | 每个回合,你可以选择一条活着的鱼,并且使用主动技能,或用普通攻击敌方鱼。 36 | 鱼的被动技能都会自动触发,只要被动技能条件满足。 37 | %s 38 | 获胜的条件是在比赛结束时有更多的鱼活着。 39 | 40 | 以下是四种宠物鱼: 41 | %s 42 | 43 | 与我一起进行对战。在每一轮中,你都应该输出你的思考过程,并用以下JSON格式返回你的动作: 44 | %s 45 | 46 | 注意!你必须在每一轮中返回你的动作。否则,你将被直接判负。""" 47 | 48 | GUESS_RULE = """你拥有的鱼的初始信息是隐藏的,敌方可以在每回合内猜测你的一条鱼的种类,如果敌方猜对了,你的那条鱼的种类将被揭示,并且你的所有鱼会受到50的伤害""" 49 | 50 | GUESS_PROMPT = """这是一个双人对战游戏,每队拥有四条宠物鱼。 51 | 每条鱼都有初始生命值、攻击力,以及自己的主动技能和被动技能。 52 | 所有的鱼的初始信息是隐藏的,你可以在每回合内猜测敌方的一条鱼的种类,如果你猜对了,敌方那条鱼的种类将被揭示,并且敌方的所有鱼会受到50的伤害。 53 | 获胜的条件是在比赛结束时有更多的鱼活着。 54 | 55 | 以下是四种宠物鱼: 56 | %s 57 | 58 | 与我一起进行对战。在每一轮中,你都应该输出你的思考过程,并用以下JSON格式返回你的动作: 59 | %s 60 | 61 | 注意!你必须在每一轮中返回你的动作。否则,你将被直接判负。""" 62 | 63 | ACTION_DESCRIPTION = { 64 | 1: ACTION_PROMPT % ('', fish_description, 
action_format), 65 | 2: ACTION_PROMPT % (GUESS_RULE, fish_description, action_format) 66 | } 67 | 68 | GUESS_DESCRIPTION = { 69 | 2: GUESS_PROMPT % (fish_description, guess_format), 70 | } -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.15) 2 | 3 | project(py_ai_sdk CXX) 4 | 5 | set(CMAKE_BUILD_TYPE Release) 6 | if(MSVC) 7 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /utf-8 /std:c++17") 8 | else(GCC or Clang) 9 | set(CMAKE_CXX_STANDARD 17) 10 | endif(MSVC) 11 | set(LIBRARY_OUTPUT_PATH ${PROJECT_BINARY_DIR}/bin) 12 | 13 | # set(CMAKE_PREFIX_PATH "") 14 | # find_package(Python EXACT COMPONENTS Interpreter Development REQUIRED) 15 | set(pybind11_DIR 16 | "/home/piano/mambaforge/lib/python3.10/site-packages/pybind11/share/cmake/pybind11") 17 | find_package(pybind11 REQUIRED) 18 | 19 | file(GLOB_RECURSE jsoncpp "jsoncpp/*.cpp") 20 | pybind11_add_module(ai_client py_ai_sdk.cpp ${jsoncpp}) 21 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/ai_client.cpython-310-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/server/tasks/card_game/AI/sdk/ai_client.cpython-310-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/ai_client.cpython-38-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/server/tasks/card_game/AI/sdk/ai_client.cpython-38-x86_64-linux-gnu.so 
-------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/build/CMakeFiles/3.25.0/CMakeDetermineCompilerABI_CXX.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/server/tasks/card_game/AI/sdk/build/CMakeFiles/3.25.0/CMakeDetermineCompilerABI_CXX.bin -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/build/CMakeFiles/3.25.0/CMakeSystem.cmake: -------------------------------------------------------------------------------- 1 | set(CMAKE_HOST_SYSTEM "Linux-5.10.16.3-microsoft-standard-WSL2") 2 | set(CMAKE_HOST_SYSTEM_NAME "Linux") 3 | set(CMAKE_HOST_SYSTEM_VERSION "5.10.16.3-microsoft-standard-WSL2") 4 | set(CMAKE_HOST_SYSTEM_PROCESSOR "x86_64") 5 | 6 | 7 | 8 | set(CMAKE_SYSTEM "Linux-5.10.16.3-microsoft-standard-WSL2") 9 | set(CMAKE_SYSTEM_NAME "Linux") 10 | set(CMAKE_SYSTEM_VERSION "5.10.16.3-microsoft-standard-WSL2") 11 | set(CMAKE_SYSTEM_PROCESSOR "x86_64") 12 | 13 | set(CMAKE_CROSSCOMPILING "FALSE") 14 | 15 | set(CMAKE_SYSTEM_LOADED 1) 16 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/build/CMakeFiles/3.25.0/CompilerIdCXX/a.out: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/server/tasks/card_game/AI/sdk/build/CMakeFiles/3.25.0/CompilerIdCXX/a.out -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/build/CMakeFiles/CMakeDirectoryInformation.cmake: -------------------------------------------------------------------------------- 1 | # CMAKE generated file: DO NOT EDIT! 
2 | # Generated by "Unix Makefiles" Generator, CMake Version 3.25 3 | 4 | # Relative path conversion top directories. 5 | set(CMAKE_RELATIVE_PATH_TOP_SOURCE "/mnt/c/Users/piano/Downloads/AquaWarAI/AI_SDK/Python/sdk") 6 | set(CMAKE_RELATIVE_PATH_TOP_BINARY "/mnt/c/Users/piano/Downloads/AquaWarAI/AI_SDK/Python/sdk/build") 7 | 8 | # Force unix paths in dependencies. 9 | set(CMAKE_FORCE_UNIX_PATHS 1) 10 | 11 | 12 | # The C and CXX include file regular expressions for this directory. 13 | set(CMAKE_C_INCLUDE_REGEX_SCAN "^.*$") 14 | set(CMAKE_C_INCLUDE_REGEX_COMPLAIN "^$") 15 | set(CMAKE_CXX_INCLUDE_REGEX_SCAN ${CMAKE_C_INCLUDE_REGEX_SCAN}) 16 | set(CMAKE_CXX_INCLUDE_REGEX_COMPLAIN ${CMAKE_C_INCLUDE_REGEX_COMPLAIN}) 17 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/build/CMakeFiles/TargetDirectories.txt: -------------------------------------------------------------------------------- 1 | /mnt/c/Users/piano/Downloads/AquaWarAI/AI_SDK/Python/sdk/build/CMakeFiles/ai_client.dir 2 | /mnt/c/Users/piano/Downloads/AquaWarAI/AI_SDK/Python/sdk/build/CMakeFiles/edit_cache.dir 3 | /mnt/c/Users/piano/Downloads/AquaWarAI/AI_SDK/Python/sdk/build/CMakeFiles/rebuild_cache.dir 4 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/build/CMakeFiles/ai_client.dir/DependInfo.cmake: -------------------------------------------------------------------------------- 1 | 2 | # Consider dependencies only in project. 
3 | set(CMAKE_DEPENDS_IN_PROJECT_ONLY OFF) 4 | 5 | # The set of languages for which implicit dependencies are needed: 6 | set(CMAKE_DEPENDS_LANGUAGES 7 | ) 8 | 9 | # The set of dependency files which are needed: 10 | set(CMAKE_DEPENDS_DEPENDENCY_FILES 11 | "/mnt/c/Users/piano/Downloads/AquaWarAI/AI_SDK/Python/sdk/jsoncpp/jsoncpp.cpp" "CMakeFiles/ai_client.dir/jsoncpp/jsoncpp.cpp.o" "gcc" "CMakeFiles/ai_client.dir/jsoncpp/jsoncpp.cpp.o.d" 12 | "/mnt/c/Users/piano/Downloads/AquaWarAI/AI_SDK/Python/sdk/py_ai_sdk.cpp" "CMakeFiles/ai_client.dir/py_ai_sdk.cpp.o" "gcc" "CMakeFiles/ai_client.dir/py_ai_sdk.cpp.o.d" 13 | ) 14 | 15 | # Targets to which this target links. 16 | set(CMAKE_TARGET_LINKED_INFO_FILES 17 | ) 18 | 19 | # Fortran module output directory. 20 | set(CMAKE_Fortran_TARGET_MODULE_DIR "") 21 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/build/CMakeFiles/ai_client.dir/cmake_clean.cmake: -------------------------------------------------------------------------------- 1 | file(REMOVE_RECURSE 2 | "CMakeFiles/ai_client.dir/jsoncpp/jsoncpp.cpp.o" 3 | "CMakeFiles/ai_client.dir/jsoncpp/jsoncpp.cpp.o.d" 4 | "CMakeFiles/ai_client.dir/py_ai_sdk.cpp.o" 5 | "CMakeFiles/ai_client.dir/py_ai_sdk.cpp.o.d" 6 | "bin/ai_client.cpython-310-x86_64-linux-gnu.so" 7 | "bin/ai_client.pdb" 8 | ) 9 | 10 | # Per-language clean rules from dependency scanning. 11 | foreach(lang CXX) 12 | include(CMakeFiles/ai_client.dir/cmake_clean_${lang}.cmake OPTIONAL) 13 | endforeach() 14 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/build/CMakeFiles/ai_client.dir/compiler_depend.make: -------------------------------------------------------------------------------- 1 | # Empty compiler generated dependencies file for ai_client. 2 | # This may be replaced when dependencies are built. 
3 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/build/CMakeFiles/ai_client.dir/compiler_depend.ts: -------------------------------------------------------------------------------- 1 | # CMAKE generated file: DO NOT EDIT! 2 | # Timestamp file for compiler generated dependencies management for ai_client. 3 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/build/CMakeFiles/ai_client.dir/depend.make: -------------------------------------------------------------------------------- 1 | # Empty dependencies file for ai_client. 2 | # This may be replaced when dependencies are built. 3 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/build/CMakeFiles/ai_client.dir/flags.make: -------------------------------------------------------------------------------- 1 | # CMAKE generated file: DO NOT EDIT! 
2 | # Generated by "Unix Makefiles" Generator, CMake Version 3.25 3 | 4 | # compile CXX with /usr/bin/c++ 5 | CXX_DEFINES = -Dai_client_EXPORTS 6 | 7 | CXX_INCLUDES = -isystem /home/piano/mambaforge/lib/python3.10/site-packages/pybind11/include -isystem /mnt/e/conda/envs/ml/include/python3.10 8 | 9 | CXX_FLAGS = -O3 -DNDEBUG -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -std=gnu++17 10 | 11 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/build/CMakeFiles/ai_client.dir/jsoncpp/jsoncpp.cpp.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/server/tasks/card_game/AI/sdk/build/CMakeFiles/ai_client.dir/jsoncpp/jsoncpp.cpp.o -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/build/CMakeFiles/ai_client.dir/link.txt: -------------------------------------------------------------------------------- 1 | /usr/bin/c++ -fPIC -O3 -DNDEBUG -flto -shared -o bin/ai_client.cpython-310-x86_64-linux-gnu.so CMakeFiles/ai_client.dir/py_ai_sdk.cpp.o CMakeFiles/ai_client.dir/jsoncpp/jsoncpp.cpp.o 2 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/build/CMakeFiles/ai_client.dir/progress.make: -------------------------------------------------------------------------------- 1 | CMAKE_PROGRESS_1 = 1 2 | CMAKE_PROGRESS_2 = 2 3 | CMAKE_PROGRESS_3 = 3 4 | 5 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/build/CMakeFiles/ai_client.dir/py_ai_sdk.cpp.o: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/server/tasks/card_game/AI/sdk/build/CMakeFiles/ai_client.dir/py_ai_sdk.cpp.o -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/build/CMakeFiles/cmake.check_cache: -------------------------------------------------------------------------------- 1 | # This file is generated by cmake for dependency checking of the CMakeCache.txt file 2 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/build/CMakeFiles/progress.marks: -------------------------------------------------------------------------------- 1 | 3 2 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/build/bin/ai_client.cpython-310-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/server/tasks/card_game/AI/sdk/build/bin/ai_client.cpython-310-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/build/cmake_install.cmake: -------------------------------------------------------------------------------- 1 | # Install script for directory: /mnt/c/Users/piano/Downloads/AquaWarAI/AI_SDK/Python/sdk 2 | 3 | # Set the install prefix 4 | if(NOT DEFINED CMAKE_INSTALL_PREFIX) 5 | set(CMAKE_INSTALL_PREFIX "/usr/local") 6 | endif() 7 | string(REGEX REPLACE "/$" "" CMAKE_INSTALL_PREFIX "${CMAKE_INSTALL_PREFIX}") 8 | 9 | # Set the install configuration name. 
10 | if(NOT DEFINED CMAKE_INSTALL_CONFIG_NAME) 11 | if(BUILD_TYPE) 12 | string(REGEX REPLACE "^[^A-Za-z0-9_]+" "" 13 | CMAKE_INSTALL_CONFIG_NAME "${BUILD_TYPE}") 14 | else() 15 | set(CMAKE_INSTALL_CONFIG_NAME "Release") 16 | endif() 17 | message(STATUS "Install configuration: \"${CMAKE_INSTALL_CONFIG_NAME}\"") 18 | endif() 19 | 20 | # Set the component getting installed. 21 | if(NOT CMAKE_INSTALL_COMPONENT) 22 | if(COMPONENT) 23 | message(STATUS "Install component: \"${COMPONENT}\"") 24 | set(CMAKE_INSTALL_COMPONENT "${COMPONENT}") 25 | else() 26 | set(CMAKE_INSTALL_COMPONENT) 27 | endif() 28 | endif() 29 | 30 | # Install shared libraries without execute permission? 31 | if(NOT DEFINED CMAKE_INSTALL_SO_NO_EXE) 32 | set(CMAKE_INSTALL_SO_NO_EXE "1") 33 | endif() 34 | 35 | # Is this installation the result of a crosscompile? 36 | if(NOT DEFINED CMAKE_CROSSCOMPILING) 37 | set(CMAKE_CROSSCOMPILING "FALSE") 38 | endif() 39 | 40 | # Set default install directory permissions. 41 | if(NOT DEFINED CMAKE_OBJDUMP) 42 | set(CMAKE_OBJDUMP "/usr/bin/objdump") 43 | endif() 44 | 45 | if(CMAKE_INSTALL_COMPONENT) 46 | set(CMAKE_INSTALL_MANIFEST "install_manifest_${CMAKE_INSTALL_COMPONENT}.txt") 47 | else() 48 | set(CMAKE_INSTALL_MANIFEST "install_manifest.txt") 49 | endif() 50 | 51 | string(REPLACE ";" "\n" CMAKE_INSTALL_MANIFEST_CONTENT 52 | "${CMAKE_INSTALL_MANIFEST_FILES}") 53 | file(WRITE "/mnt/c/Users/piano/Downloads/AquaWarAI/AI_SDK/Python/sdk/build/${CMAKE_INSTALL_MANIFEST}" 54 | "${CMAKE_INSTALL_MANIFEST_CONTENT}") 55 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/py_json_cast.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifndef PY_JSON_CAST_HPP_INCLUDED 4 | #define PY_JSON_CAST_HPP_INCLUDED 5 | 6 | #include 7 | 8 | #include "jsoncpp/json/json.h" 9 | 10 | namespace pybind11::detail { 11 | template <> 12 | class type_caster { 13 
| public: 14 | PYBIND11_TYPE_CASTER(Json::Value, _("Json")); 15 | 16 | bool load(handle src, bool) try { 17 | value = handle_to_json(src); 18 | return true; 19 | } catch (type_error) { 20 | return false; 21 | } 22 | 23 | static handle cast(Json::Value src, return_value_policy, handle) { 24 | return json_to_handle(src).release(); 25 | } 26 | 27 | private: 28 | static Json::Value handle_to_json(const handle &hdl) { 29 | if (hdl.ptr() == nullptr || hdl.is_none()) return Json::nullValue; 30 | if (isinstance(hdl)) return hdl.cast(); 31 | if (isinstance(hdl)) return hdl.cast(); 32 | if (isinstance(hdl)) return hdl.cast(); 33 | if (isinstance(hdl)) return hdl.cast(); 34 | if (isinstance(hdl) || isinstance(hdl) || 35 | isinstance(hdl)) { 36 | Json::Value ret = Json::arrayValue; 37 | for (const handle &h : hdl) ret.append(handle_to_json(h)); 38 | return ret; 39 | } 40 | if (isinstance(hdl)) { 41 | Json::Value ret = Json::objectValue; 42 | for (const handle &key : hdl) 43 | ret[str(key)] = handle_to_json(hdl[key]); 44 | return ret; 45 | } 46 | throw type_error("Bad cast from Python to C++: " + 47 | repr(hdl).cast()); 48 | } 49 | static object json_to_handle(const Json::Value &json) { 50 | if (json.isNull()) return none(); 51 | if (json.isBool()) return bool_(json.asBool()); 52 | if (json.isIntegral()) return int_(json.asLargestInt()); 53 | if (json.isNumeric()) return float_(json.asDouble()); 54 | if (json.isString()) return str(json.asCString()); 55 | if (json.isArray()) { 56 | list ret; 57 | for (const Json::Value &j : json) ret.append(json_to_handle(j)); 58 | return ret; 59 | } 60 | if (json.isObject()) { 61 | dict ret; 62 | for (Json::ValueConstIterator iter = json.begin(); 63 | iter != json.end(); ++iter) 64 | ret[str(iter.key().asCString())] = json_to_handle(*iter); 65 | return ret; 66 | } 67 | throw type_error("Bad cast from C++ to Python: " + json.asString()); 68 | } 69 | }; 70 | } // namespace pybind11::detail 71 | 72 | #endif // PY_JSON_CAST_HPP_INCLUDED 
-------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/pyd/ai_client.cp38-win_amd64.pyd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/server/tasks/card_game/AI/sdk/pyd/ai_client.cp38-win_amd64.pyd -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/pyd/ai_client.cpython-310-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/server/tasks/card_game/AI/sdk/pyd/ai_client.cpython-310-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/pyd/ai_client.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/server/tasks/card_game/AI/sdk/pyd/ai_client.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/pyd/ai_client.cpython-38-darwin.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/server/tasks/card_game/AI/sdk/pyd/ai_client.cpython-38-darwin.so -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/AI/sdk/pyd/ai_client.cpython-38-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/server/tasks/card_game/AI/sdk/pyd/ai_client.cpython-38-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/__init__.py: -------------------------------------------------------------------------------- 1 | from .task import CardGame 2 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/judger/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/server/tasks/card_game/judger/__init__.py -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/logic/.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | cmake-build-debug -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/logic/Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | g++ src/aqua_sdk.cpp src/fish.cpp src/fishset.cpp src/game.cpp src/main.cpp src/player.cpp src/jsoncpp/jsoncpp.cpp -o bin/main -std=c++17 -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/logic/bin/main: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/server/tasks/card_game/logic/bin/main -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/logic/src/fishset.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "fish.h" 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | extern std::ofstream debuggerfs; 10 | void debugmsg(std::string str); 11 | 12 | class FishSet 13 | { 14 | private: 15 | std::vector fishs; 16 | 17 | public: 18 | int player_id; 19 | int timestamp; 20 | FishSet* another = nullptr; 21 | FishSet(int flag = 0); 22 | FishSet(const FishSet& base); // 根据一个已知鱼集合创建,得到一个鱼编号集合和给定集合一致的鱼集合 23 | Json::Value to_json() const; 24 | std::vector get_fishs() const; 25 | int get_size(); 26 | bool empty(); 27 | void clear(); 28 | void add(Fish* fish); 29 | void add(int id); // 添加对应id的鱼 30 | void remove(int id); // 移除对应id的鱼 31 | void to_fight(); // 所有鱼变为战斗状态 32 | void to_dead(); // 所有鱼变为阵亡状态 33 | void update_state(); // 将所有鱼从战斗状态更新到最新状态(可能是阵亡状态) 34 | // void hp_debuff(double rate); 35 | void hp_debuff(int dec); 36 | int count_live_fish(); 37 | int living_fish_count() const; // 存活的鱼的个数 38 | int hp_sum() const; // 鱼的 hp 之和 39 | int hp_max() const; // 鱼的 hp 最大值 40 | bool is_all_dead(); 41 | bool count(Fish* tar); 42 | int update_timestamp(); 43 | 44 | std::optional on_damaged(Fish *src, Fish *target, int dmg = -1); 45 | void set_fishset(FishSet *_ally, FishSet *_hostile); 46 | void start_turn(); 47 | // void debug_msg(std::string str); 48 | }; -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/logic/src/game.cpp: -------------------------------------------------------------------------------- 1 | #include "game.h" 2 | 3 | // 将 Game::State 转为对应阶段名称 4 | std::string Game::state_info(Game::State state){ 5 | switch(state){ 6 | case READY: return "准备阶段"; 7 | case PICK: return "选择阶段"; 8 | case ASSERT: return "断言阶段"; 9 | case ACTION: return "行动阶段"; 10 | case END: return "游戏结束"; 11 | } 12 | return ""; 13 | } 14 | 15 | Game::Game(){ 16 | players.clear(); 17 | players.push_back(Player(0)); 18 
| players.push_back(Player(1)); 19 | } 20 | 21 | /* 22 | 按照成员变量定义顺序转成 JSON 23 | 样例: 24 | { 25 | "winner": winner, 26 | "players": [ 27 | player1, 28 | player2, 29 | ... 30 | ], 31 | "state": state, 32 | "gamestate": gamestate, 33 | "cur_turn": cur_turn, 34 | "over": over, 35 | "cnt": cnt, 36 | "score": score, 37 | "rounds": rounds 38 | } 39 | */ 40 | Json::Value Game::to_json() const{ 41 | Json::Value json; 42 | json["winner"] = winner; 43 | json["players"].resize(0); 44 | for(const auto& player : players) 45 | json["players"].append(player.to_json()); 46 | json["state"] = state; 47 | json["gamestate"] = gamestate; 48 | json["cur_turn"] = cur_turn; 49 | json["over"] = over; 50 | json["cnt"] = cnt; 51 | json["score"] = score; 52 | json["rounds"] = rounds; 53 | return json; 54 | } -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/logic/src/game.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include "player.h" 4 | 5 | class Game { 6 | public: 7 | enum State { 8 | READY = 0, 9 | PICK = 2, 10 | ASSERT = 3, 11 | ACTION = 4, 12 | END = 5 13 | }; 14 | static constexpr int STATE_LIMIT = 128; 15 | static std::string state_info(State state); 16 | int winner = -1; // 获胜玩家id 17 | int last_winner = -1; // 上一轮获胜玩家id 18 | std::vector players; 19 | int state = 1; // 当前回合(传递给Judger的) 20 | int last_round_state = 0; 21 | int gamestate = READY; // 当前阶段 22 | int cur_turn = 0; // 当前操作者 23 | int first_mover; // 当前轮的先手 24 | bool over = false; // 游戏是否结束 25 | int cnt = 0; 26 | int imiid[2] = {-1, -1}; 27 | int errorai = 0; 28 | int score = 0; // 记录比分 29 | int rounds = 0; // 记录战斗轮数 30 | Game(); 31 | Json::Value to_json() const; 32 | }; -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/logic/src/main: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/server/tasks/card_game/logic/src/main -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/logic/src/main.cpp: -------------------------------------------------------------------------------- 1 | #include "aqua_sdk.cpp" 2 | 3 | int main(int argc, char* argv[]){ 4 | AquaWarSDK aw; 5 | aw.start(); 6 | aw.run(); 7 | 8 | return 0; 9 | } -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/logic/src/player.cpp: -------------------------------------------------------------------------------- 1 | #include "player.h" 2 | 3 | Player::Player(int _id):id(_id){ 4 | my_fish = FishSet(1); 5 | my_fish.player_id = _id; 6 | fight_fish = FishSet(); 7 | fight_fish.player_id = _id; 8 | } 9 | 10 | Player::Player(int _id, FishSet _my_fish):id(_id),my_fish(_my_fish){ 11 | my_fish = FishSet(1); 12 | fight_fish = FishSet(); 13 | } 14 | 15 | Player::Player(const Player &p){ 16 | id = p.id; 17 | type = p.type; 18 | my_fish = p.my_fish; 19 | fight_fish = p.fight_fish; 20 | } 21 | 22 | /* 23 | 按照成员变量定义顺序转成 JSON 24 | 样例: 25 | { 26 | "id": id, 27 | "type": type, 28 | "my_fish": my_fish, 29 | "fight_fish": fight_fish 30 | } 31 | */ 32 | Json::Value Player::to_json() const{ 33 | Json::Value json; 34 | json["id"] = id; 35 | json["type"] = type; 36 | json["my_fish"] = my_fish.to_json(); 37 | json["fight_fish"] = fight_fish.to_json(); 38 | return json; 39 | } 40 | 41 | int Player::get_id(){ 42 | return id; 43 | } 44 | 45 | std::vector Player::get_fishs() const{ 46 | return my_fish.get_fishs(); 47 | } 48 | 49 | std::vector Player::get_fight_fishs() const{ 50 | return fight_fish.get_fishs(); 51 | } 52 | 53 | int Player::get_size(){ 54 | return my_fish.get_size(); 55 | } 56 | 57 | bool Player::empty(){ 58 | return my_fish.empty(); 59 | } 60 | 61 | void 
Player::clear(){ 62 | my_fish.clear(); 63 | } 64 | 65 | void Player::add(Fish* fish){ 66 | my_fish.add(fish); 67 | } 68 | 69 | void Player::add(int id){ 70 | my_fish.add(id); 71 | } 72 | 73 | void Player::remove(int id){ 74 | my_fish.remove(id); 75 | } 76 | 77 | void Player::to_war(int id){ 78 | my_fish.remove(id); 79 | fight_fish.add(id); 80 | } -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/logic/src/player.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "fishset.h" 3 | #include 4 | 5 | class Player { 6 | private: 7 | int id; // 玩家id 8 | 9 | public: 10 | int type = 1; // 1 => AI, 2 => 播放器 11 | FishSet my_fish; // 队伍里的所有剩余鱼 12 | FishSet fight_fish; // 当前上场战斗的鱼 13 | Player(int _id); 14 | Player(int _id, FishSet _my_fish); 15 | Player(const Player &p); 16 | Json::Value to_json() const; 17 | int get_id(); 18 | std::vector get_fishs() const; 19 | std::vector get_fight_fishs() const; 20 | int get_size(); 21 | bool empty(); 22 | void clear(); 23 | void add(Fish* fish); 24 | void add(int id); // 添加对应id的鱼 25 | void remove(int id); // 移除对应id的鱼 26 | void to_war(int id); // 编号为id的一只鱼上场战斗 27 | }; -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/logic/src/timer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | class Timer { 6 | public: 7 | using BaseClock = std::chrono::steady_clock; 8 | using TimePoint = BaseClock::time_point; 9 | using Duration = BaseClock::duration; 10 | 11 | Timer() : m_start(Timer::now()) {} 12 | int runtime() const { 13 | Duration time = Timer::now() - m_start; 14 | return std::chrono::duration_cast(time).count(); 15 | } 16 | static TimePoint now() { return BaseClock::now(); } 17 | 18 | private: 19 | TimePoint m_start; 20 | }; 
-------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/server.py: -------------------------------------------------------------------------------- 1 | import json 2 | import socket 3 | 4 | from src.typings.general import ChatHistoryItem 5 | 6 | 7 | class Server: 8 | def __init__(self, port, workers, host="localhost"): 9 | self.host = host 10 | self.port = port 11 | self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 12 | 13 | while True: 14 | try: 15 | self.socket.bind((self.host, self.port)) 16 | break 17 | except OSError as e: 18 | self.port += 1 19 | self.socket.listen(workers + 2) 20 | self.log = {} 21 | self.status = {} 22 | 23 | async def start(self, folder, session): 24 | log_file = [] 25 | print(111) 26 | client_socket, client_address = self.socket.accept() 27 | print(222) 28 | while True: 29 | data = client_socket.recv(1000000).decode() 30 | if data == "": 31 | self.stop(client_socket) 32 | break 33 | elif data.startswith("#[ERROR]"): 34 | status = int(data[-1]) 35 | self.status[folder] = status 36 | else: 37 | try: 38 | session.history = json.loads(data) 39 | session.history = [ChatHistoryItem(**item) for item in session.history] 40 | log_file.append({"role": "user", "content": data}) 41 | ret = await session.action() 42 | if ret.content is None: 43 | self.status[folder] = 3 44 | self.send_message(client_socket, "### LLM ERROR EXIT ###") 45 | break 46 | else: 47 | ret = ret.content 48 | print("\n######\n") 49 | print(ret) 50 | log_file.append({"role": "agent", "content": ret}) 51 | print("sending message") 52 | self.send_message(client_socket, ret) 53 | print("message sent") 54 | except json.decoder.JSONDecodeError: 55 | log_file.append({"role": "agent", "content": ""}) 56 | print("except sending") 57 | self.send_message(client_socket, "") 58 | print("except message sent") 59 | self.log[folder] = log_file 60 | 61 | def send_message(self, client_socket, message): 62 | 
client_socket.sendall(message.encode()) 63 | 64 | def stop(self, client_socket): 65 | client_socket.close() 66 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/card_game/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | import signal 4 | import subprocess 5 | 6 | 7 | def run_cmd(cmd_string, timeout=600): 8 | print("命令为:" + cmd_string) 9 | p = subprocess.Popen( 10 | cmd_string, 11 | stderr=subprocess.PIPE, 12 | stdout=subprocess.PIPE, 13 | shell=True, 14 | close_fds=True, 15 | start_new_session=True, 16 | ) 17 | print("created") 18 | encoding = "utf-8" 19 | if platform.system() == "Windows": 20 | encoding = "gbk" 21 | 22 | try: 23 | print("trying") 24 | (msg, errs) = p.communicate(timeout=timeout) 25 | print("comed") 26 | ret_code = p.poll() 27 | print("polled") 28 | if ret_code: 29 | code = 1 30 | msg = "[Error]Called Error : " + str(msg.decode(encoding)) 31 | else: 32 | code = 0 33 | msg = str(msg.decode(encoding)) 34 | print(ret_code) 35 | except subprocess.TimeoutExpired: 36 | p.kill() 37 | p.terminate() 38 | os.killpg(p.pid, signal.SIGTERM) 39 | 40 | code = 1 41 | msg = ( 42 | "[ERROR]Timeout Error : Command '" 43 | + cmd_string 44 | + "' timed out after " 45 | + str(timeout) 46 | + " seconds" 47 | ) 48 | except Exception as e: 49 | code = 1 50 | msg = "[ERROR]Unknown Error : " + str(e) 51 | 52 | print("returning") 53 | 54 | return code, msg 55 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/dbbench/requirements.txt: -------------------------------------------------------------------------------- 1 | mysql-connector-python==8.0.33 2 | docker==6.1.2 -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/knowledgegraph/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .task import KnowledgeGraph -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/knowledgegraph/ontology/vocab.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:6a178fa7f348631d816e37a384a89039b6ef1903c8797c4a20fbe97074ba6453 3 | size 1137427 4 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/knowledgegraph/requirements.txt: -------------------------------------------------------------------------------- 1 | SPARQLWrapper 2 | networkx -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/knowledgegraph/utils/semparse_util.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | def lisp_to_nested_expression(lisp_string: str) -> List: 5 | """ 6 | Takes a logical form as a lisp string and returns a nested list representation of the lisp. 7 | For example, "(count (division first))" would get mapped to ['count', ['division', 'first']]. 
8 | """ 9 | stack: List = [] 10 | current_expression: List = [] 11 | tokens = lisp_string.split() 12 | for token in tokens: 13 | while token[0] == '(': 14 | nested_expression: List = [] 15 | current_expression.append(nested_expression) 16 | stack.append(current_expression) 17 | current_expression = nested_expression 18 | token = token[1:] 19 | current_expression.append(token.replace(')', '')) 20 | while token[-1] == ')': 21 | current_expression = stack.pop() 22 | token = token[:-1] 23 | return current_expression[0] 24 | 25 | def expression_to_lisp(expression) -> str: 26 | rtn = '(' 27 | for i, e in enumerate(expression): 28 | if isinstance(e, list): 29 | rtn += expression_to_lisp(e) 30 | else: 31 | rtn += e 32 | if i != len(expression) - 1: 33 | rtn += ' ' 34 | 35 | rtn += ')' 36 | return rtn 37 | 38 | 39 | def get_nesting_level(expression) -> int: 40 | max_sub = 0 41 | for item in expression: 42 | if isinstance(item, list): 43 | level = get_nesting_level(item) 44 | if level > max_sub: 45 | max_sub = level 46 | 47 | return 1 + max_sub 48 | 49 | 50 | 51 | if __name__ == '__main__': 52 | lisp = '(AND common.topic (AND (JOIN common.topic.notable_types Comic Strip) (JOIN common.topic.notable_types Comic Strip)))' 53 | print(get_nesting_level(lisp_to_nested_expression(lisp))) 54 | 55 | print(expression_to_lisp(lisp_to_nested_expression(lisp))) 56 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/ltp/__init__.py: -------------------------------------------------------------------------------- 1 | from .task import LateralThinkingPuzzle -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/mind2web/__init__.py: -------------------------------------------------------------------------------- 1 | from .task import Mind2Web 2 | -------------------------------------------------------------------------------- 
/Eval/webshop/src/server/tasks/mind2web/data_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/server/tasks/mind2web/data_utils/__init__.py -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/os_interaction/__init__.py: -------------------------------------------------------------------------------- 1 | from .task import OSInteraction -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/os_interaction/requirements.txt: -------------------------------------------------------------------------------- 1 | docker 2 | PyYAML 3 | PyYAML 4 | Requests 5 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/webshop/.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## I'm submitting a ... 2 | 3 | - [ ] bug report 4 | - [ ] feature request 5 | 6 | ## What is the current behavior? 7 | 8 | Please describe the current behavior of the WebShop app or agent 9 | 10 | ### Steps to Reproduce 11 | 12 | If the current behavior is a bug, please provide the steps to reproduce and if possible a minimal demo of the problem 13 | 14 | ## What is the expected behavior? 15 | 16 | Please describe the desired behavior of the WebShop app or agent 17 | 18 | ## Motivation for Change 19 | 20 | What is the motivation / use case for changing the behavior? 
21 | 22 | ## Please tell us about your environment: 23 | 24 | * Version: 2.0.0-beta.X 25 | * Browser: 26 | * Language: 27 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/webshop/.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | 3 | Provide a general summary of your changes 4 | 5 | ## Description of Changes 6 | 7 | Describe your changes + testing (if appropriate) in technical detail 8 | 9 | ## Screenshots 10 | 11 | Include visuals such as screenshots or recordings if necessary to show changes in effect 12 | 13 | ## Checklist 14 | - [ ] My code follows the style guidelines of this project 15 | - [ ] I have performed a self-review of my code 16 | - [ ] I have commented my code + updated documentation (if necessary) 17 | - [ ] I have added tests to define the behavior of the feature(s) and verify it is working 18 | - [ ] New + existing unit tests pass -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/webshop/.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | # .github/workflows/pytest.yml 2 | name: PyTest 3 | on: 4 | push: 5 | branches: [ "master" ] 6 | pull_request: 7 | branches: [ "master" ] 8 | 9 | permissions: 10 | contents: read 11 | 12 | jobs: 13 | test: 14 | runs-on: ubuntu-latest 15 | timeout-minutes: 10 16 | 17 | steps: 18 | - name: Check out repository code 19 | uses: actions/checkout@v3 20 | 21 | # Setup Python (faster than using Python container) 22 | - name: Setup Python 23 | uses: actions/setup-python@v3 24 | with: 25 | python-version: "3.8" 26 | 27 | # Install pip dependencies + setup for testing 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install -r requirements.txt 32 | python -m spacy download en_core_web_lg 33 | 34 | # Run 
testing suite 35 | - name: Run test suite 36 | run: | 37 | pytest -v -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/webshop/.gitignore: -------------------------------------------------------------------------------- 1 | *.ipynb* 2 | *.pyc 3 | *.swp 4 | 5 | .DS_Store 6 | .idea/ 7 | .pytest_cache/ 8 | .vscode/ 9 | 10 | __pycache__/ 11 | data/ 12 | search_engine/indexes* 13 | search_engine/resources* 14 | transfer/flagged 15 | user_session_logs/ 16 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/webshop/LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Princeton Natural Language Processing 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/webshop/assets/diagram.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/server/tasks/webshop/assets/diagram.gif -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/webshop/assets/model_ckpts.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/server/tasks/webshop/assets/model_ckpts.png -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/webshop/assets/transfer-logic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/Eval/webshop/src/server/tasks/webshop/assets/transfer-logic.png -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/webshop/baseline_models/.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | wandb/ 3 | ckpts/ 4 | __pycache__/ 5 | data/il_trajs_finalized_images.jsonl 6 | *.ipynb 7 | 8 | 9 | *.txt 10 | !requirements.txt 11 | scripts/ 12 | *.out 13 | wandb/ 14 | *.swp 15 | logs/ 16 | .DS_Store 17 | .idea/ 18 | nbs/ 19 | 20 | crawl-* 21 | 22 | # Byte-compiled / optimized / DLL files 23 | __pycache__/ 24 | *.py[cod] 25 | *$py.class 26 | 27 | # C extensions 28 | *.so 29 | 30 | # Distribution / packaging 31 | .Python 32 | build/ 33 | develop-eggs/ 34 | dist/ 35 | downloads/ 36 | eggs/ 37 | .eggs/ 38 | lib/ 39 | lib64/ 40 | parts/ 41 | sdist/ 42 | var/ 43 | wheels/ 44 | *.egg-info/ 45 | .installed.cfg 46 | *.egg 47 
"""Decode candidate search queries for every goal with a fine-tuned BART model.

Writes ./data/goal_query_predict.json mapping each goal string to its
N_CANDIDATES beam-searched queries.  (Removed the unused `import time`.)
"""
import json

import torch
from tqdm import tqdm
from transformers import BartForConditionalGeneration

from train_search import get_data, get_dataset, tokenizer

# Queries decoded per goal; must equal num_beams/num_return_sequences below.
N_CANDIDATES = 10

if __name__ == "__main__":
    # Load the fine-tuned checkpoint and run inference-only on GPU.
    model = BartForConditionalGeneration.from_pretrained(
        './ckpts/web_search/checkpoint-800')
    model.eval()
    model = model.to('cuda')
    dataset = get_dataset("web_search")
    dataloader = torch.utils.data.DataLoader(dataset["all"], batch_size=32)
    _, all_goals = get_data("all")
    all_dec = []
    for batch in tqdm(dataloader):
        # Beam search returns N_CANDIDATES sequences per input, flattened.
        output = model.generate(
            input_ids=batch["input_ids"].to('cuda'),
            attention_mask=batch["attention_mask"].to('cuda'),
            num_beams=N_CANDIDATES, num_return_sequences=N_CANDIDATES,
            max_length=512, early_stopping=True
        )
        dec = tokenizer.batch_decode(
            output, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        # Regroup the flat decode list into per-goal chunks of N_CANDIDATES.
        assert len(dec) % N_CANDIDATES == 0
        for i in range(0, len(dec), N_CANDIDATES):
            all_dec.append(dec[i:i + N_CANDIDATES])
    assert len(all_goals) == len(all_dec)
    d = {goal: dec for goal, dec in zip(all_goals, all_dec)}
    with open('./data/goal_query_predict.json', 'w') as f:
        json.dump(d, f)
scikit_learn==1.1.1 18 | selenium==4.2.0 19 | spacy 20 | thefuzz==0.19.0 21 | torch==1.11.0 22 | tqdm==4.64.0 23 | train==0.0.5 24 | transformers==4.19.2 25 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/webshop/run_dev.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export FLASK_ENV=development 3 | python -m web_agent_site.app --log --attrs 4 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/webshop/run_envs/run_web_agent_site_env.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the site gym environment. 3 | 4 | TODO: move to testing dir for more rigorous tests 5 | """ 6 | import gym 7 | from rich import print 8 | from rich.markup import escape 9 | 10 | from web_agent_site.envs import WebAgentSiteEnv 11 | from web_agent_site.models import * 12 | from web_agent_site.utils import DEBUG_PROD_SIZE 13 | 14 | 15 | if __name__ == '__main__': 16 | #env = gym.make('WebAgentSite-v0') 17 | #env = WebAgentSiteEnv(render=True, pause=2.0) 18 | #env = WebAgentSiteEnv(observation_mode='html', render=False) 19 | env = WebAgentSiteEnv(observation_mode='text', render=False, num_products=DEBUG_PROD_SIZE) 20 | global_step = 0 21 | 22 | try: 23 | #policy = HumanPolicy() 24 | policy = GPTPolicy() 25 | 26 | observation = env.observation 27 | while True: 28 | print(observation) 29 | available_actions = env.get_available_actions() 30 | print('Available actions:', available_actions) 31 | action = policy.forward(observation, available_actions) 32 | observation, reward, done, info = env.step(action) 33 | print(f'Taking action "{escape(action)}" -> Reward = {reward}') 34 | if done: 35 | break 36 | global_step += 1 37 | finally: 38 | env.close() -------------------------------------------------------------------------------- 
"""Convert scraped product records into pyserini JsonCollection documents.

Emits the same corpus at several sizes (100 / all / 1k / 100k) for the
differently-sized search indexes built by run_indexing.sh.  The four
copy-pasted write blocks were collapsed into one loop over (path, limit).
"""
import sys
import json
from tqdm import tqdm
sys.path.insert(0, '../')

from web_agent_site.utils import DEFAULT_FILE_PATH
from web_agent_site.engine.engine import load_products


def _product_to_doc(p):
    """Flatten one product record into an indexable {id, contents, product} doc."""
    option_texts = []
    for option_name, option_contents in p.get('options', {}).items():
        option_contents_text = ', '.join(option_contents)
        option_texts.append(f'{option_name}: {option_contents_text}')
    option_text = ', and '.join(option_texts)
    doc = dict()
    doc['id'] = p['asin']
    # Searchable text = title + description + first bullet + option summary.
    doc['contents'] = ' '.join([
        p['Title'],
        p['Description'],
        p['BulletPoints'][0],
        option_text,
    ]).lower()
    doc['product'] = p
    return doc


def _write_docs(path, subset):
    """Write one JSON document per line (pyserini JsonCollection format)."""
    with open(path, 'w+') as f:
        for doc in subset:
            f.write(json.dumps(doc) + '\n')


all_products, *_ = load_products(filepath=DEFAULT_FILE_PATH)

docs = [_product_to_doc(p) for p in tqdm(all_products)]

# Same write order as before; None means "all products".
for out_path, limit in [
    ('./resources_100/documents.jsonl', 100),
    ('./resources/documents.jsonl', None),
    ('./resources_1k/documents.jsonl', 1000),
    ('./resources_100k/documents.jsonl', 100000),
]:
    _write_docs(out_path, docs if limit is None else docs[:limit])
--storeDocvectors --storeRaw 24 | 25 | python -m pyserini.index.lucene \ 26 | --collection JsonCollection \ 27 | --input resources_100k \ 28 | --index indexes_100k \ 29 | --generator DefaultLuceneDocumentGenerator \ 30 | --threads 1 \ 31 | --storePositions --storeDocvectors --storeRaw 32 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/webshop/setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Displays information on how to use script 4 | helpFunction() 5 | { 6 | echo "Usage: $0 [-d small|all]" 7 | echo -e "\t-d small|all - Specify whether to download entire dataset (all) or just 1000 (small)" 8 | exit 1 # Exit script after printing help 9 | } 10 | 11 | # Get values of command line flags 12 | while getopts d: flag 13 | do 14 | case "${flag}" in 15 | d) data=${OPTARG};; 16 | esac 17 | done 18 | 19 | if [ -z "$data" ]; then 20 | echo "[ERROR]: Missing -d flag" 21 | helpFunction 22 | fi 23 | 24 | # Install Python Dependencies 25 | pip install -r requirements.txt; 26 | 27 | # Install Environment Dependencies via `conda` 28 | conda install -c pytorch faiss-cpu==1.7.2; 29 | conda install -c conda-forge openjdk=11; 30 | 31 | # Download dataset into `data` folder via `gdown` command 32 | mkdir -p data; 33 | cd data; 34 | if [ "$data" == "small" ]; then 35 | gdown https://drive.google.com/uc?id=1EgHdxQ_YxqIQlvvq5iKlCrkEKR6-j0Ib; # items_shuffle_1000 - product scraped info 36 | gdown https://drive.google.com/uc?id=1IduG0xl544V_A_jv3tHXC0kyFi7PnyBu; # items_ins_v2_1000 - product attributes 37 | elif [ "$data" == "all" ]; then 38 | gdown https://drive.google.com/uc?id=1A2whVgOO0euk5O13n2iYDM0bQRkkRduB; # items_shuffle 39 | gdown https://drive.google.com/uc?id=1s2j6NgHljiZzQNL3veZaAiyW_qDEgBNi; # items_ins_v2 40 | else 41 | echo "[ERROR]: argument for `-d` flag not recognized" 42 | helpFunction 43 | fi 44 | gdown 
https://drive.google.com/uc?id=14Kb5SPBk_jfdLZ_CDBNitW98QLDlKR5O # items_human_ins 45 | cd .. 46 | 47 | # Download spaCy large NLP model 48 | python -m spacy download en_core_web_lg 49 | 50 | # Build search engine index 51 | cd search_engine 52 | mkdir -p resources resources_100 resources_1k resources_100k 53 | python convert_product_file_format.py # convert items.json => required doc format 54 | mkdir -p indexes 55 | ./run_indexing.sh 56 | cd .. 57 | 58 | # Create logging folder + samples of log data 59 | get_human_trajs () { 60 | PYCMD=$(cat < 13 | 14 | The contents of this directory each serve the following purposes: 15 | * `app.py`: Run to launch interactive [Gradio](https://gradio.app/) demo of app 16 | * `predict_help.py`: Amazon, eBay web scraping code 17 | * `webshop_lite.py`: A condensed version of WebShop's templating engine 18 | 19 | If you are interested in *transferring an agent's functionality to an new website or platform*, you will need to... 20 | 1. implement two new functions: `parse_results_.py` and `parse_item_page_.py`. The corresponding interfaces and working examples for Amazon can be found [here](https://github.com/princeton-nlp/webshop/tree/master/transfer/predict_help.py#L262) and [here](https://github.com/princeton-nlp/webshop/tree/master/transfer/predict_help.py#L296). 21 | 2. Invoke these functions in the [`run_episode`](https://github.com/princeton-nlp/webshop/tree/master/transfer/app.py#L105) function in the `app.py` file. Specifically, you should add a single call to... 
"""Interactive CLI for hand-tagging mined product attributes.

Removed the dead `annotated_attrs_by_cat` dict that was built but never
read anywhere in the module.
"""
import yaml
from pathlib import Path
from rich import print

ATTR_DIR = './data/attributes'

ATTR_PATHS = [
    'narrow_2-gram.yaml',
    'narrow_1-gram.yaml',
    'broad_2-gram.yaml',
    'broad_1-gram.yaml',
]
ATTR_PATHS = [Path(ATTR_DIR) / af for af in ATTR_PATHS]


def annotate(attr_path):
    """Interactively tag each `attr | score` entry of one YAML attribute file.

    Entries are shown category by category; the annotator types a tag string
    (1: ITEM, 2: PROP, 3: USE), 'q' skips to the next category, and 'n' at a
    category boundary stops the file entirely.
    """
    with open(attr_path) as f:
        attrs_by_cat = yaml.safe_load(f)

    # Pre-annotation summary over every category.
    unique_attrs = set()
    all_attrs = []
    for _, attrs in attrs_by_cat.items():
        attrs = [a.split('|')[0].strip() for a in attrs]
        unique_attrs.update(attrs)
        all_attrs += attrs
    print(f'Total unique attributes: {len(unique_attrs)}')
    total = len(all_attrs)
    num_left = len(all_attrs)

    for category, attrs in attrs_by_cat.items():
        print(
            f'Category: [ {category} ] | '
            f'Number of attributes: {len(attrs)}\n'
        )
        # NOTE(review): the annotations collected below are never written to
        # disk anywhere in this module -- presumably they should be persisted;
        # confirm intended behavior.
        annotated_attrs = []
        for i, attr in enumerate(attrs):
            attr, score = attr.split(' | ')
            print(
                f'{"[" + str(i) + "]":<5} '
                f'[bold green]{attr:<30}[/bold green] | '
                f'[red]{category}[/red] | '
                f'{score}'
            )
            tags = input(
                'Annotate [1: ITEM, 2: PROP, 3: USE, '
                '⎵: next example, q: next category] > '
            )
            print('\n')
            tags = tags.strip()
            annotated_attrs.append(f'{attr} | {score} | {tags}')
            if 'q' in tags:
                break

        num_left -= len(attrs)
        print(f'{num_left} / {total} total attributes left.')

        ans = input('Starting the next category... [y/n] > ')
        if ans == 'n':
            break

def main():
    for attr_path in ATTR_PATHS:
        annotate(attr_path)

if __name__ == '__main__':
    """
    python -m web_agent_site.attributes.annotate
    """
    main()
"""
Model implementations. The model interface should be suitable for both
the ``site env'' and the ``text env''.

Removed imports that nothing in this module uses: `json`, `re`, and the
third-party `requests` (dropping it also removes a needless dependency).
"""
import random

# Fixed seed so RandomPolicy rollouts are reproducible across runs.
random.seed(4)


class BasePolicy:
    """Abstract policy: maps an observation to a WebShop action string."""

    def __init__(self):
        pass

    def forward(self, observation, available_actions):
        """
        Args:
            observation (`str`):
                HTML string

            available_actions ():
                ...
        Returns:
            action (`str`):
                Return string of the format ``action_name[action_arg]''.
        Examples:
            - search[white shoes]
            - click[button=Reviews]
            - click[button=Buy Now]
        """
        raise NotImplementedError


class HumanPolicy(BasePolicy):
    """Reads the action from stdin -- for interactive/manual control."""

    def __init__(self):
        super().__init__()

    def forward(self, observation, available_actions):
        action = input('> ')
        return action


class RandomPolicy(BasePolicy):
    """Searches a fixed query on search pages, otherwise clicks uniformly at random."""

    def __init__(self):
        super().__init__()

    def forward(self, observation, available_actions):
        if available_actions['has_search_bar']:
            action = 'search[shoes]'
        else:
            action_arg = random.choice(available_actions['clickables'])
            action = f'click[{action_arg}]'
        return action
13 |
14 |
15 |
16 |

Instruction:
{{ instruction_text }}

17 |
18 |
19 |
20 |
21 |
22 | 23 |
24 |
25 |
26 |
27 | 28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |

{{product_info.Description}}

36 |
37 |
38 |
39 |
40 |
41 |
42 | 43 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/webshop/web_agent_site/templates/features_page.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 |
14 |
15 |
16 |

Instruction:
{{ instruction_text }}

17 |
18 |
19 |
20 |
21 |
22 | 23 |
24 |
25 |
26 |
27 | 28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
    36 | {% for bulletpoint in product_info.BulletPoints %} 37 |
  • {{bulletpoint}}

  • 38 | {% endfor %} 39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 | 47 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/webshop/web_agent_site/templates/review_page.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 |
14 |
15 |
16 |

Instruction:
{{ instruction_text }}

17 |
18 |
19 |
20 |
21 |
22 | 23 |
24 |
25 |
26 |
27 | 28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 | {% for review in product_info.Reviews %} 36 |
37 |
38 |

"{{review.title}}"

39 |

40 | {{review.score}} 41 | {% for i in range(review.score | int) %} 42 | 43 | {% endfor %} 44 | {% for i in range(5 - review.score | int) %} 45 | 46 | {% endfor %} 47 |

48 |

{{review.body}}

49 |
50 |
51 | {% endfor %} 52 |
53 |
54 |
55 |
56 |
57 |
58 | 59 | -------------------------------------------------------------------------------- /Eval/webshop/src/server/tasks/webshop/web_agent_site/templates/search_page.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 |
13 |
14 | 17 |
18 |

Instruction:
{{ instruction_text }}

19 |
20 |
21 |
22 |
23 | 24 | 25 | 26 | 27 |
28 |
29 |
30 |
31 |
32 |
"""Paths and small helpers shared by the WebShop site/text environments.

Removed the accidental duplicate `HUMAN_ATTR_PATH` assignment (the same
line appeared twice back to back).
"""
import bisect
import hashlib
import logging
import random
from os.path import dirname, abspath, join

BASE_DIR = dirname(abspath(__file__))
DEBUG_PROD_SIZE = None  # set to `None` to disable

DEFAULT_ATTR_PATH = join(BASE_DIR, '../data/items_ins_v2_1000.json')
# DEFAULT_ATTR_PATH = join(BASE_DIR, '../data/items_ins_v2.json')
DEFAULT_FILE_PATH = join(BASE_DIR, '../data/items_shuffle_1000.json')
# DEFAULT_FILE_PATH = join(BASE_DIR, '../data/items_shuffle.json')
DEFAULT_REVIEW_PATH = join(BASE_DIR, '../data/reviews.json')

FEAT_CONV = join(BASE_DIR, '../data/feat_conv.pt')
FEAT_IDS = join(BASE_DIR, '../data/feat_ids.pt')

HUMAN_ATTR_PATH = join(BASE_DIR, '../data/items_human_ins.json')

def random_idx(cum_weights):
    """Generate random index by sampling uniformly from sum of all weights, then
    selecting the `min` between the position to keep the list sorted (via bisect)
    and the value of the second to last index
    """
    pos = random.uniform(0, cum_weights[-1])
    idx = bisect.bisect(cum_weights, pos)
    # Clamp so the index never points past the last valid weight interval.
    idx = min(idx, len(cum_weights) - 2)
    return idx

def setup_logger(session_id, user_log_dir):
    """Creates a log file and logging object for the corresponding session ID

    `user_log_dir` is expected to be a pathlib.Path (the `/` operator is used).
    """
    logger = logging.getLogger(session_id)
    formatter = logging.Formatter('%(message)s')
    file_handler = logging.FileHandler(
        user_log_dir / f'{session_id}.jsonl',
        mode='w'
    )
    file_handler.setFormatter(formatter)
    logger.setLevel(logging.INFO)
    logger.addHandler(file_handler)
    return logger

def generate_mturk_code(session_id: str) -> str:
    """Generates a redeem code corresponding to the session ID for an MTurk
    worker once the session is completed
    """
    # sha1 is fine here: the code is a receipt token, not a security credential.
    sha = hashlib.sha1(session_id.encode())
    return sha.hexdigest()[:10].upper()
None: 41 | super().__init__("agent_context_limit", detail) 42 | 43 | 44 | class AgentTimeoutException(AgentClientException): 45 | def __init__(self, detail: Union[str, None] = None) -> None: 46 | super().__init__("agent_timeout", detail) 47 | 48 | 49 | class AgentNetworkException(AgentClientException): 50 | def __init__(self, detail: Union[str, None] = None) -> None: 51 | super().__init__("agent_network", detail) 52 | 53 | 54 | class TaskTimeoutException(TaskClientException): 55 | def __init__(self, detail: Union[str, None] = None) -> None: 56 | super().__init__("task_timeout", detail) 57 | 58 | 59 | class TaskNetworkException(TaskClientException): 60 | def __init__(self, detail: Union[str, None] = None) -> None: 61 | super().__init__("task_network", detail) 62 | -------------------------------------------------------------------------------- /Eval/webshop/src/typings/general.py: -------------------------------------------------------------------------------- 1 | import builtins 2 | from typing import List, Dict, Union, Any, Literal 3 | 4 | from pydantic import BaseModel, validator 5 | 6 | JSONSerializable = Union[None, bool, int, float, str, List[Any], Dict[str, Any]] 7 | SampleIndex = Union[int, str] 8 | 9 | 10 | class InstanceFactory(BaseModel): 11 | module: str 12 | parameters: Dict[str, Any] = {} 13 | 14 | @validator("parameters", pre=True) 15 | def _ensure_dict(cls, v): 16 | if v is None: 17 | return {} 18 | return v 19 | 20 | def create(self): 21 | # print('>>>>>>>> ', self.module, self.parameters) 22 | splits = self.module.split(".") 23 | if len(splits) == 0: 24 | raise Exception("Invalid module name: {}".format(self.module)) 25 | if len(splits) == 1: 26 | g = globals() 27 | if self.module in g: 28 | class_type = g[self.module] 29 | else: 30 | class_type = getattr(builtins, self.module) 31 | return class_type(**self.parameters) 32 | else: 33 | path = ".".join(self.module.split(".")[:-1]) 34 | mod = __import__(path, fromlist=[self.module.split(".")[-1]]) 35 
| #import pdb 36 | #pdb.set_trace() 37 | return getattr(mod, self.module.split(".")[-1])(**self.parameters) 38 | 39 | 40 | class Assignment(BaseModel): 41 | agent: str 42 | task: str 43 | 44 | 45 | class ChatHistoryItem(BaseModel): 46 | role: Literal["user", "agent"] 47 | content: str 48 | -------------------------------------------------------------------------------- /Eval/webshop/src/typings/output.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | 3 | from pydantic import BaseModel, root_validator 4 | 5 | from . import ChatHistoryItem 6 | from .general import JSONSerializable, SampleIndex 7 | from .status import SampleStatus, AgentOutputStatus 8 | 9 | 10 | class TaskOutput(BaseModel): 11 | index: Union[None, SampleIndex] = None 12 | status: SampleStatus = SampleStatus.RUNNING 13 | result: JSONSerializable = None 14 | history: Union[None, List[ChatHistoryItem]] = None 15 | 16 | 17 | class TaskSampleExecutionResult(BaseModel): 18 | status: SampleStatus = SampleStatus.COMPLETED 19 | result: JSONSerializable = None 20 | 21 | 22 | class AgentOutput(BaseModel): 23 | status: AgentOutputStatus = AgentOutputStatus.NORMAL 24 | content: Union[str, None] = None 25 | 26 | # at least one of them should be not None 27 | @root_validator(pre=False, skip_on_failure=True) 28 | def post_validate(cls, instance: dict): 29 | assert ( 30 | instance.get("status") is not AgentOutputStatus.NORMAL 31 | or instance.get("content") is not None 32 | ), "If status is NORMAL, content should not be None" 33 | return instance 34 | 35 | 36 | class TaskClientOutput(BaseModel): 37 | error: Union[str, None] = None 38 | info: Union[str, None] = None 39 | output: Union[TaskOutput, None] = None 40 | -------------------------------------------------------------------------------- /Eval/webshop/src/typings/request.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 
from pydantic import BaseModel 4 | 5 | from .general import SampleIndex 6 | from .output import AgentOutput, TaskOutput 7 | 8 | 9 | class RegisterRequest(BaseModel): 10 | name: str 11 | address: str 12 | concurrency: int 13 | indices: list 14 | 15 | 16 | class StartSampleRequest(BaseModel): 17 | name: str 18 | index: SampleIndex 19 | 20 | 21 | class InteractRequest(BaseModel): 22 | session_id: int 23 | agent_response: AgentOutput 24 | 25 | 26 | class CancelRequest(BaseModel): 27 | session_id: int 28 | 29 | 30 | class HeartbeatRequest(BaseModel): 31 | name: str 32 | address: str 33 | 34 | 35 | class CalculateOverallRequest(BaseModel): 36 | name: str 37 | results: List[TaskOutput] 38 | 39 | 40 | class WorkerStartSampleRequest(BaseModel): 41 | index: SampleIndex 42 | session_id: int 43 | 44 | 45 | class SampleStatusRequest(BaseModel): 46 | session_id: int 47 | -------------------------------------------------------------------------------- /Eval/webshop/src/typings/status.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum, Enum 2 | 3 | 4 | class SampleStatus(str, Enum): 5 | RUNNING = "running" 6 | COMPLETED = "completed" 7 | AGENT_CONTEXT_LIMIT = "agent context limit" 8 | AGENT_VALIDATION_FAILED = "agent validation failed" 9 | AGENT_INVALID_ACTION = "agent invalid action" 10 | TASK_LIMIT_REACHED = "task limit reached" 11 | UNKNOWN = "unknown" 12 | TASK_ERROR = "task error" 13 | 14 | 15 | class WorkerStatus(IntEnum): 16 | ALIVE = 0 17 | COMA = 1 18 | DEAD = 2 19 | 20 | 21 | class AgentOutputStatus(str, Enum): 22 | NORMAL = "normal" 23 | CANCELLED = "cancelled" 24 | AGENT_CONTEXT_LIMIT = "agent context limit" 25 | -------------------------------------------------------------------------------- /Eval/webshop/src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .max_flow import Graph, MaxFlow 2 | from .others import * 3 | from .rules import * 4 | 
-------------------------------------------------------------------------------- /Eval/webshop/src/utils/others.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import numpy as np 4 | 5 | 6 | class JsonEncoder(json.JSONEncoder): 7 | """Convert numpy classes to JSON serializable objects.""" 8 | 9 | def default(self, obj): 10 | if isinstance(obj, (np.integer, np.floating, np.bool_)): 11 | return obj.item() 12 | elif isinstance(obj, np.ndarray): 13 | return obj.tolist() 14 | else: 15 | return super(JsonEncoder, self).default(obj) 16 | 17 | 18 | def serialize(obj, max_depth=5, compress=False): 19 | """ 20 | dump into json, including only basic types, list types and dict types. 21 | If other types are included, they will be converted into string. 22 | """ 23 | if max_depth <= 0: 24 | return "..." 25 | if isinstance(obj, (int, float, str, bool, type(None))): 26 | return obj 27 | elif isinstance(obj, list) or isinstance(obj, tuple): 28 | if not compress or len(obj) <= 5: 29 | return [serialize(item, max_depth - 1, compress) for item in obj] 30 | else: 31 | return [serialize(item, max_depth - 1, True) for item in obj[:5]] + [ 32 | "...(total: %d)" % len(obj) 33 | ] 34 | elif isinstance(obj, dict): 35 | if not compress or len(obj) <= 5: 36 | return { 37 | str(key): serialize(obj[key], max_depth - 1, compress) for key in obj 38 | } 39 | else: 40 | ret = { 41 | str(key): serialize(obj[key], max_depth - 1, True) 42 | for key in list(obj.keys())[:5] 43 | } 44 | ret["...total..."] = len(obj) 45 | return ret 46 | elif hasattr(obj, "__dict__"): 47 | return serialize(obj.__dict__, max_depth, True) 48 | else: 49 | ret = str(obj) 50 | if len(ret) > 100: 51 | ret = ret[:45] + " ... 
" + ret[-45:] 52 | return ret 53 | 54 | 55 | class ColorMessage: 56 | @staticmethod 57 | def red(msg): 58 | return "\033[91m" + msg + "\033[0m" 59 | 60 | @staticmethod 61 | def green(msg): 62 | return "\033[92m" + msg + "\033[0m" 63 | 64 | @staticmethod 65 | def cyan(msg): 66 | return "\033[96m" + msg + "\033[0m" 67 | 68 | @staticmethod 69 | def yellow(msg): 70 | return "\033[93m" + msg + "\033[0m" 71 | -------------------------------------------------------------------------------- /Eval/webshop/src/utils/rules.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | class RuleBase: 5 | def check(self, obj) -> bool: 6 | raise NotImplementedError() 7 | 8 | 9 | class ContainRule(RuleBase): 10 | def __init__(self, target, reverse=False) -> None: 11 | """ 12 | Check if target is in obj. 13 | `reverse`: if True, check if obj is in target 14 | """ 15 | self.target = target 16 | self.reverse = reverse 17 | 18 | def check(self, obj) -> bool: 19 | if self.reverse: 20 | return obj in self.target 21 | else: 22 | return self.target in obj 23 | 24 | 25 | class NotRule(RuleBase): 26 | def __init__(self, rule: RuleBase) -> None: 27 | self.rule = rule 28 | 29 | def check(self, obj) -> bool: 30 | return not self.rule.check(obj) 31 | 32 | 33 | class AndRule(RuleBase): 34 | def __init__(self, rules: List[RuleBase]) -> None: 35 | self.rules = rules 36 | 37 | def check(self, obj) -> bool: 38 | return all(rule.check(obj) for rule in self.rules) 39 | 40 | 41 | class OrRule(RuleBase): 42 | def __init__(self, rules: List[RuleBase]) -> None: 43 | self.rules = rules 44 | 45 | def check(self, obj) -> bool: 46 | return any(rule.check(obj) for rule in self.rules) 47 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ARMAP: Scaling Autonomous Agents via Automatic Reward Modeling And Planning 2 | 3 | 
[ [Project Page](https://armap-agent.github.io) | [Paper](https://arxiv.org/abs/2502.12130) | [Model Weights](https://huggingface.co/Heaplax/ARMAP-RM-LoRA)] 4 | 5 | ![teaser](assets/teaser.svg) 6 | 7 | This repository contains the official code for ARMAP: Scaling Autonomous Agents via Automatic Reward Modeling And Planning 8 | 9 | ## Reward Model 10 | Reward model training pipeline is provided in [the reward model directory](./RM). 11 | 12 | ## Inference 13 | After starting the reward model server, please refer to the Eval folder for each task: 14 | 15 | [webshop](./Eval/webshop) 16 | 17 | [game24](./Eval/game24) 18 | 19 | [sciworld](./Eval/sciworld) 20 | 21 | 22 | ## Citation 23 | If you use this work or find it helpful, please consider citing: (bibtex) 24 | ``` 25 | @misc{chen2025scalingautonomousagentsautomatic, 26 | title={Scaling Autonomous Agents via Automatic Reward Modeling And Planning}, 27 | author={Zhenfang Chen and Delin Chen and Rui Sun and Wenjun Liu and Chuang Gan}, 28 | year={2025}, 29 | eprint={2502.12130}, 30 | archivePrefix={arXiv}, 31 | primaryClass={cs.AI}, 32 | url={https://arxiv.org/abs/2502.12130}, 33 | } 34 | ``` 35 | -------------------------------------------------------------------------------- /RM/data/put_training_data_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/RM/data/put_training_data_here -------------------------------------------------------------------------------- /RM/data_utils/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The LLaVA-RLHF Team 2 | # Copyright 2023 The Alpaca Team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # PPO Constants 17 | from enum import Enum 18 | 19 | # FACTUAL_PROMPT = "Specifically, the AI's response should be fully supported by the combination of the following captions:\n" 20 | FACTUAL_PROMPT = "\n" 21 | 22 | class AnswerType(Enum): 23 | GENERAL = 1 24 | A_IN_ABCD = 2 25 | B_IN_ABCD = 3 26 | C_IN_ABCD = 4 27 | D_IN_ABCD = 5 28 | NO_IN_YESNO = 6 29 | YES_IN_YESNO = 7 30 | -------------------------------------------------------------------------------- /RM/modifications/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "./llm", 3 | "architectures": [ 4 | "LlavaLlamaForCausalLM" 5 | ], 6 | "attention_bias": false, 7 | "attention_dropout": 0.0, 8 | "bos_token_id": 1, 9 | "eos_token_id": 2, 10 | "hidden_act": "silu", 11 | "hidden_size": 2560, 12 | "initializer_range": 0.02, 13 | "intermediate_size": 6912, 14 | "max_position_embeddings": 4096, 15 | "model_max_length": 4096, 16 | "model_type": "llama", 17 | "num_attention_heads": 20, 18 | "num_hidden_layers": 32, 19 | "num_key_value_heads": 20, 20 | "pad_token_id": 0, 21 | "pretraining_tp": 1, 22 | "rms_norm_eps": 1e-05, 23 | "rope_scaling": null, 24 | "rope_theta": 10000.0, 25 | "tie_word_embeddings": false, 26 | "tokenizer_model_max_length": 4096, 27 | "tokenizer_padding_side": "right", 28 | "torch_dtype": "bfloat16", 29 | "transformers_version": "4.36.2", 30 | "use_cache": true, 31 | "vocab_size": 32000, 32 | "mm_vision_select_feature": "patch", 33 | "mm_vision_select_layer": -2, 34 | 
"mm_hidden_size": 1152, 35 | "mm_vision_tower": "./rm/VILA1.5-3b/vision_tower/" 36 | } 37 | -------------------------------------------------------------------------------- /RM/modifications/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | from .clip_encoder import CLIPVisionTower 2 | from .siglip_encoder import SiglipVisionTower 3 | 4 | def build_vision_tower(vision_tower_cfg, **kwargs): 5 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 6 | if vision_tower.startswith("openai") or vision_tower.startswith("laion"): 7 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 8 | elif "vila" in vision_tower.lower(): 9 | vision_tower = SiglipVisionTower(vision_tower, vision_tower_cfg) 10 | return vision_tower 11 | 12 | raise ValueError(f'Unknown vision tower: {vision_tower}') 13 | -------------------------------------------------------------------------------- /RM/modifications/multimodal_encoder/siglip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from llava.model.multimodal_encoder.vision_encoder import VisionTower, VisionTowerS2 3 | 4 | from transformers import AutoConfig, PretrainedConfig, AutoModel 5 | from .siglip import ( 6 | SiglipVisionConfig, 7 | SiglipVisionModel, 8 | SiglipImageProcessor, 9 | ) 10 | 11 | 12 | class SiglipVisionTower(VisionTower): 13 | def __init__(self, model_name_or_path: str, config: PretrainedConfig, state_dict=None): 14 | super().__init__(model_name_or_path, config) 15 | current_device = torch.cuda.current_device() 16 | self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path) 17 | #self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path,device_map={"": current_device},) 18 | # import pdb 19 | # pdb.set_trace() 20 | self.vision_tower = SiglipVisionModel.from_pretrained( 21 | # 
TODO(ligeng): why pass config here leading to errors? 22 | model_name_or_path,state_dict=state_dict 23 | #model_name_or_path,device_map={"": current_device}, state_dict=state_dict 24 | ) 25 | self.is_loaded = True 26 | 27 | 28 | class SiglipVisionTowerS2(VisionTowerS2): 29 | def __init__(self, model_name_or_path: str, config: PretrainedConfig): 30 | super().__init__(model_name_or_path, config) 31 | self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path) 32 | self.vision_tower = SiglipVisionModel.from_pretrained( 33 | model_name_or_path, torch_dtype=eval(config.model_dtype) 34 | ) 35 | 36 | # Make sure it crops/resizes the image to the largest scale in self.scales to maintain high-res information 37 | self.image_processor.size['height'] = self.image_processor.size['width'] = self.scales[-1] 38 | 39 | self.is_loaded = True 40 | 41 | 42 | AutoConfig.register("siglip_vision_model", SiglipVisionConfig) 43 | AutoModel.register(SiglipVisionConfig, SiglipVisionModel) 44 | 45 | -------------------------------------------------------------------------------- /RM/prompts/multi_frame_prompt.txt: -------------------------------------------------------------------------------- 1 | USER: Current screenshot: . Observation: {observation0}. Navigation Intent: {Intent} 2 | 3 | ASSISTANT: {response0} 4 | 5 | USER: Current screenshot: . Observation: {observation1}. 6 | 7 | ASSISTANT: {response1} 8 | 9 | USER: Please evaluate whether your last response achieves the "Navigation Intent" or not. 
10 | 11 | ASSISTANT: Following your definitions, the score of my last response is -------------------------------------------------------------------------------- /RM/rm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/RM/rm/__init__.py -------------------------------------------------------------------------------- /RM/scripts/server_reward_sciworld.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | set -x 3 | 4 | export FLASK_PORT=15678 5 | 6 | export CUDA_VISIBLE_DEVICES=0 7 | export MODEL_DIR="rm" 8 | export PYTHONPATH="$PWD:$PYTHONPATH" 9 | export GPUS_PER_NODE=1 10 | export OMP_NUM_THREADS=1 11 | 12 | # MODEL CONFIG 13 | VISION_TOWER=VILA1.5-3b/vision_tower 14 | LM_MODEL_NAME=VILA1.5-3b 15 | MM_PROJECTOR=VILA1.5-3b/mm_projector 16 | 17 | # SAVE CONFIG 18 | MODEL_NAME=RM-sciworld 19 | LORA_CKPT=checkpoint-120 20 | 21 | # TRAINING CONFIG 22 | NUM_EPOCHS=3 23 | LEARNING_RATE=5e-4 24 | BATCH_SIZE=1 25 | GRAD_ACCUMULATION=1 26 | 27 | python server_lora_rm.py \ 28 | --do_eval \ 29 | --seed 42 \ 30 | --per_device_train_batch_size $BATCH_SIZE \ 31 | --per_device_eval_batch_size $BATCH_SIZE \ 32 | --gradient_accumulation_steps $GRAD_ACCUMULATION \ 33 | --model_name_or_path $MODEL_DIR/$LM_MODEL_NAME/llm \ 34 | --vision_tower $MODEL_DIR/$VISION_TOWER \ 35 | --learning_rate $LEARNING_RATE \ 36 | --mm_vision_select_layer -2 \ 37 | --mm_use_im_start_end False \ 38 | --mm_use_im_patch_token False \ 39 | --freeze_mm_mlp_adapter True \ 40 | --model_max_length 4096 \ 41 | --query_len 4000 \ 42 | --response_len 2000 \ 43 | --dataset_name "none" \ 44 | --eval_dataset_name "none" \ 45 | --eval_size 1 \ 46 | --bits 16 \ 47 | --lora_r 64 \ 48 | --lora_modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj \ 49 | --output_dir "$MODEL_DIR/$MODEL_NAME" \ 50 | --lora_dir 
"$MODEL_DIR/$MODEL_NAME/$LORA_CKPT" \ 51 | --num_train_epochs $NUM_EPOCHS \ 52 | --group_by_length False \ 53 | --evaluation_strategy "steps" \ 54 | --eval_steps 50 \ 55 | --save_strategy "steps" \ 56 | --save_steps 50 \ 57 | --save_total_limit 10 \ 58 | --weight_decay 0.0 \ 59 | --warmup_ratio 0.03 \ 60 | --lr_scheduler_type "constant_with_warmup" \ 61 | --logging_steps 5 \ 62 | --report_to "tensorboard" \ 63 | --ddp_backend "nccl" \ 64 | --bf16 True \ 65 | --ddp_find_unused_parameters False \ 66 | --resume_from_training True \ 67 | --image_aspect_ratio 'resize' 68 | -------------------------------------------------------------------------------- /RM/scripts/train_reward_alfworld.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | export CUDA_VISIBLE_DEVICES=0 7 | export DATA_DIR="data" 8 | export MODEL_DIR="rm" 9 | export PYTHONPATH="$PWD:$PYTHONPATH" 10 | export GPUS_PER_NODE=1 11 | export OMP_NUM_THREADS=1 12 | export CUDA_LAUNCH_BLOCKING=1 13 | 14 | # MODEL CONFIG 15 | VISION_TOWER=VILA1.5-3b/vision_tower 16 | LM_MODEL_NAME=VILA1.5-3b 17 | MM_PROJECTOR=VILA1.5-3b/mm_projector 18 | 19 | # DATA CONFIG 20 | PREFERENCE_DATA=alfworld_preference.json 21 | CAPTION_DATA=alfworld_factual.json 22 | IMAGE_FOLDER=alfworld_image 23 | 24 | 25 | # SAVE CONFIG 26 | MODEL_NAME=RM-alfworld 27 | 28 | # TRAINING CONFIG 29 | NUM_EPOCHS=10 30 | LEARNING_RATE=1e-5 31 | BATCH_SIZE=1 32 | GRAD_ACCUMULATION=8 33 | 34 | python finetune_lora_rm.py \ 35 | --do_train \ 36 | --do_eval \ 37 | --seed 42 \ 38 | --per_device_train_batch_size $BATCH_SIZE \ 39 | --per_device_eval_batch_size $BATCH_SIZE \ 40 | --gradient_accumulation_steps $GRAD_ACCUMULATION \ 41 | --model_name_or_path $MODEL_DIR/$LM_MODEL_NAME/llm \ 42 | --tokenizer $MODEL_DIR/$LM_MODEL_NAME/llm \ 43 | --image_folder $DATA_DIR/$IMAGE_FOLDER \ 44 | --vision_tower $MODEL_DIR/$VISION_TOWER \ 45 | --mm_projector_path $MODEL_DIR/$MM_PROJECTOR \ 46 | 
--learning_rate $LEARNING_RATE \ 47 | --mm_vision_select_layer -2 \ 48 | --mm_vision_select_feature cls_patch \ 49 | --mm_use_im_start_end False \ 50 | --mm_use_im_patch_token False \ 51 | --freeze_mm_mlp_adapter False \ 52 | --model_max_length 4096 \ 53 | --query_len 8000 \ 54 | --response_len 2000 \ 55 | --dataset_path $DATA_DIR/$PREFERENCE_DATA \ 56 | --eval_dataset_path $DATA_DIR/$PREFERENCE_DATA \ 57 | --dataset_name "none" \ 58 | --eval_dataset_name "none" \ 59 | --eval_size 50 \ 60 | --bits 16 \ 61 | --lora_r 64 \ 62 | --lora_modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj \ 63 | --output_dir "$MODEL_DIR/$MODEL_NAME" \ 64 | --num_train_epochs $NUM_EPOCHS \ 65 | --group_by_length False \ 66 | --evaluation_strategy "steps" \ 67 | --eval_steps 50 \ 68 | --save_strategy "steps" \ 69 | --save_steps 100 \ 70 | --save_total_limit 10 \ 71 | --weight_decay 0.0 \ 72 | --warmup_ratio 0.1 \ 73 | --lr_scheduler_type "constant_with_warmup" \ 74 | --logging_steps 5 \ 75 | --report_to "tensorboard" \ 76 | --ddp_backend "nccl" \ 77 | --bf16 True \ 78 | --ddp_find_unused_parameters False \ 79 | --resume_from_training True \ 80 | --reward_prompt_file "./prompts/multi_frame_prompt.txt" \ 81 | --image_to_caption_file $DATA_DIR/$CAPTION_DATA \ 82 | --image_aspect_ratio 'resize' 83 | -------------------------------------------------------------------------------- /RM/scripts/train_reward_game24.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | 7 | export CUDA_VISIBLE_DEVICES=0,1,2,3 8 | export DATA_DIR="data" 9 | export MODEL_DIR="rm" 10 | export PYTHONPATH="$PWD:$PYTHONPATH" 11 | export GPUS_PER_NODE=4 12 | export OMP_NUM_THREADS=4 13 | 14 | # MODEL CONFIG 15 | VISION_TOWER=VILA1.5-3b/vision_tower 16 | LM_MODEL_NAME=VILA1.5-3b 17 | MM_PROJECTOR=VILA1.5-3b/mm_projector 18 | 19 | # DATA CONFIG 20 | PREFERENCE_DATA=game24_preference.json 21 | CAPTION_DATA=game24_factual.json 22 | 
IMAGE_FOLDER=game24_image 23 | 24 | # SAVE CONFIG 25 | MODEL_NAME=RM-game24 26 | 27 | # TRAINING CONFIG 28 | NUM_EPOCHS=10 29 | LEARNING_RATE=1e-5 30 | BATCH_SIZE=1 31 | GRAD_ACCUMULATION=16 32 | 33 | torchrun \ 34 | --standalone \ 35 | --nnodes=1 \ 36 | --nproc-per-node=$GPUS_PER_NODE \ 37 | finetune_lora_rm.py \ 38 | --do_train \ 39 | --do_eval \ 40 | --seed 42 \ 41 | --per_device_train_batch_size $BATCH_SIZE \ 42 | --per_device_eval_batch_size $BATCH_SIZE \ 43 | --gradient_accumulation_steps $GRAD_ACCUMULATION \ 44 | --model_name_or_path $MODEL_DIR/$LM_MODEL_NAME/llm \ 45 | --tokenizer $MODEL_DIR/$LM_MODEL_NAME/llm \ 46 | --image_folder $DATA_DIR/$IMAGE_FOLDER \ 47 | --vision_tower $MODEL_DIR/$VISION_TOWER \ 48 | --mm_projector_path $MODEL_DIR/$MM_PROJECTOR \ 49 | --learning_rate $LEARNING_RATE \ 50 | --mm_vision_select_layer -2 \ 51 | --mm_vision_select_feature cls_patch \ 52 | --mm_use_im_start_end False \ 53 | --mm_use_im_patch_token False \ 54 | --freeze_mm_mlp_adapter False \ 55 | --model_max_length 4000 \ 56 | --query_len 4000 \ 57 | --response_len 450 \ 58 | --dataset_path $DATA_DIR/$PREFERENCE_DATA \ 59 | --eval_dataset_path $DATA_DIR/$PREFERENCE_DATA \ 60 | --dataset_name "none" \ 61 | --eval_dataset_name "none" \ 62 | --eval_size 50 \ 63 | --bits 16 \ 64 | --lora_r 64 \ 65 | --lora_modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj \ 66 | --output_dir "$MODEL_DIR/$MODEL_NAME" \ 67 | --num_train_epochs $NUM_EPOCHS \ 68 | --group_by_length False \ 69 | --evaluation_strategy "steps" \ 70 | --eval_steps 25 \ 71 | --save_strategy "steps" \ 72 | --save_steps 500 \ 73 | --save_total_limit 30 \ 74 | --weight_decay 0.0 \ 75 | --warmup_ratio 0.1 \ 76 | --lr_scheduler_type "constant_with_warmup" \ 77 | --logging_steps 5 \ 78 | --report_to "tensorboard" \ 79 | --ddp_backend "nccl" \ 80 | --bf16 True \ 81 | --ddp_find_unused_parameters False \ 82 | --resume_from_training True \ 83 | --reward_prompt_file "./prompts/multi_frame_prompt.txt" \ 84 | 
--image_to_caption_file $DATA_DIR/$CAPTION_DATA \ 85 | --image_aspect_ratio 'resize' 86 | -------------------------------------------------------------------------------- /RM/scripts/train_reward_sciworld.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | 7 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 8 | export DATA_DIR="data" 9 | export MODEL_DIR="rm" 10 | export PYTHONPATH="$PWD:$PYTHONPATH" 11 | export GPUS_PER_NODE=8 12 | export OMP_NUM_THREADS=8 13 | 14 | # MODEL CONFIG 15 | VISION_TOWER=VILA1.5-3b/vision_tower 16 | LM_MODEL_NAME=VILA1.5-3b 17 | MM_PROJECTOR=VILA1.5-3b/mm_projector 18 | 19 | # DATA CONFIG 20 | PREFERENCE_DATA=sciworld_preference.json 21 | CAPTION_DATA=sciworld_factual.json 22 | IMAGE_FOLDER=sciworld_image 23 | 24 | # SAVE CONFIG 25 | MODEL_NAME=RM-sciworld 26 | 27 | # TRAINING CONFIG 28 | NUM_EPOCHS=10 29 | LEARNING_RATE=1e-5 30 | BATCH_SIZE=1 31 | GRAD_ACCUMULATION=16 32 | 33 | torchrun \ 34 | --standalone \ 35 | --nnodes=1 \ 36 | --nproc-per-node=$GPUS_PER_NODE \ 37 | finetune_lora_rm.py \ 38 | --do_train \ 39 | --do_eval \ 40 | --seed 42 \ 41 | --per_device_train_batch_size $BATCH_SIZE \ 42 | --per_device_eval_batch_size $BATCH_SIZE \ 43 | --gradient_accumulation_steps $GRAD_ACCUMULATION \ 44 | --model_name_or_path $MODEL_DIR/$LM_MODEL_NAME/llm \ 45 | --tokenizer $MODEL_DIR/$LM_MODEL_NAME/llm \ 46 | --image_folder $DATA_DIR/$IMAGE_FOLDER \ 47 | --vision_tower $MODEL_DIR/$VISION_TOWER \ 48 | --mm_projector_path $MODEL_DIR/$MM_PROJECTOR \ 49 | --learning_rate $LEARNING_RATE \ 50 | --mm_vision_select_layer -2 \ 51 | --mm_vision_select_feature cls_patch \ 52 | --mm_use_im_start_end False \ 53 | --mm_use_im_patch_token False \ 54 | --freeze_mm_mlp_adapter False \ 55 | --model_max_length 4096 \ 56 | --query_len 4000 \ 57 | --response_len 2000 \ 58 | --dataset_path $DATA_DIR/$PREFERENCE_DATA \ 59 | --eval_dataset_path $DATA_DIR/$PREFERENCE_DATA \ 60 | 
--dataset_name "none" \ 61 | --eval_dataset_name "none" \ 62 | --eval_size 500 \ 63 | --bits 16 \ 64 | --lora_r 64 \ 65 | --lora_modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj \ 66 | --output_dir "$MODEL_DIR/$MODEL_NAME" \ 67 | --num_train_epochs $NUM_EPOCHS \ 68 | --group_by_length False \ 69 | --evaluation_strategy "steps" \ 70 | --eval_steps 20 \ 71 | --save_strategy "steps" \ 72 | --save_steps 40 \ 73 | --save_total_limit 10 \ 74 | --weight_decay 0.0 \ 75 | --warmup_ratio 0.1 \ 76 | --lr_scheduler_type "constant_with_warmup" \ 77 | --logging_steps 5 \ 78 | --report_to "tensorboard" \ 79 | --ddp_backend "nccl" \ 80 | --bf16 True \ 81 | --ddp_find_unused_parameters False \ 82 | --resume_from_training True \ 83 | --reward_prompt_file "./prompts/multi_frame_prompt.txt" \ 84 | --image_to_caption_file $DATA_DIR/$CAPTION_DATA \ 85 | --image_aspect_ratio 'resize' 86 | -------------------------------------------------------------------------------- /RM/scripts/train_reward_webshop.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | export CUDA_VISIBLE_DEVICES=0 7 | export DATA_DIR="data" 8 | export MODEL_DIR="rm" 9 | export PYTHONPATH="$PWD:$PYTHONPATH" 10 | export GPUS_PER_NODE=1 11 | export OMP_NUM_THREADS=1 12 | export CUDA_LAUNCH_BLOCKING=1 13 | 14 | # MODEL CONFIG 15 | VISION_TOWER=VILA1.5-3b/vision_tower 16 | LM_MODEL_NAME=VILA1.5-3b 17 | MM_PROJECTOR=VILA1.5-3b/mm_projector 18 | 19 | # DATA CONFIG 20 | PREFERENCE_DATA=webshop_preference.json 21 | CAPTION_DATA=webshop_factual.json 22 | IMAGE_FOLDER=webshop_image 23 | 24 | # SAVE CONFIG 25 | MODEL_NAME=RM-webshop 26 | 27 | # TRAINING CONFIG 28 | NUM_EPOCHS=3 29 | LEARNING_RATE=2e-5 30 | BATCH_SIZE=1 31 | GRAD_ACCUMULATION=4 32 | 33 | python finetune_lora_rm.py \ 34 | --do_train \ 35 | --do_eval \ 36 | --seed 42 \ 37 | --per_device_train_batch_size $BATCH_SIZE \ 38 | --per_device_eval_batch_size $BATCH_SIZE \ 39 | 
--gradient_accumulation_steps $GRAD_ACCUMULATION \ 40 | --model_name_or_path $MODEL_DIR/$LM_MODEL_NAME/llm \ 41 | --tokenizer $MODEL_DIR/$LM_MODEL_NAME/llm \ 42 | --image_folder $DATA_DIR/$IMAGE_FOLDER \ 43 | --vision_tower $MODEL_DIR/$VISION_TOWER \ 44 | --mm_projector_path $MODEL_DIR/$MM_PROJECTOR \ 45 | --learning_rate $LEARNING_RATE \ 46 | --mm_vision_select_layer -2 \ 47 | --mm_vision_select_feature cls_patch \ 48 | --mm_use_im_start_end False \ 49 | --mm_use_im_patch_token False \ 50 | --freeze_mm_mlp_adapter False \ 51 | --model_max_length 8192 \ 52 | --query_len 8000 \ 53 | --response_len 450 \ 54 | --dataset_path $DATA_DIR/$PREFERENCE_DATA \ 55 | --eval_dataset_path $DATA_DIR/$PREFERENCE_DATA \ 56 | --dataset_name "none" \ 57 | --eval_dataset_name "none" \ 58 | --eval_size 50 \ 59 | --bits 16 \ 60 | --lora_r 64 \ 61 | --lora_modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj \ 62 | --output_dir "$MODEL_DIR/$MODEL_NAME" \ 63 | --num_train_epochs $NUM_EPOCHS \ 64 | --group_by_length False \ 65 | --evaluation_strategy "steps" \ 66 | --eval_steps 50 \ 67 | --save_strategy "steps" \ 68 | --save_steps 100 \ 69 | --save_total_limit 10 \ 70 | --weight_decay 0.0 \ 71 | --warmup_ratio 0.1 \ 72 | --lr_scheduler_type "constant_with_warmup" \ 73 | --logging_steps 5 \ 74 | --report_to "tensorboard" \ 75 | --ddp_backend "nccl" \ 76 | --bf16 True \ 77 | --ddp_find_unused_parameters False \ 78 | --resume_from_training True \ 79 | --reward_prompt_file "./prompts/multi_frame_prompt.txt" \ 80 | --image_to_caption_file $DATA_DIR/$CAPTION_DATA \ 81 | --image_aspect_ratio 'resize' 82 | -------------------------------------------------------------------------------- /assets/armap.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heaplax/ARMAP/9a240bbb718b05ad72145dfa75be9a77aa942f14/assets/armap.gif --------------------------------------------------------------------------------