├── .gitignore ├── LICENSE ├── README.md ├── agent ├── __init__.py ├── base.py ├── expel.py ├── react.py └── reflect.py ├── assets ├── eval_example.png ├── expel.png └── header_baby.png ├── configs ├── agent │ └── expel.yaml ├── benchmark │ ├── alfworld.yaml │ ├── fever.yaml │ ├── hotpotqa.yaml │ └── webshop.yaml ├── eval.yaml ├── insight_extraction.yaml ├── train.yaml └── visualize_logs.yaml ├── data ├── alfworld │ └── alfworld_tasks_suffix.json ├── fever │ └── paper_dev.jsonl ├── hotpotqa │ └── hotpot-qa-distractor-sample.joblib └── webshop │ └── webshop.fixed100.json ├── envs ├── __init__.py ├── alfworld │ └── alfworld.py ├── base.py ├── fever │ ├── fever.py │ ├── wikienv.py │ └── wrappers.py ├── hotpotqa │ └── hotpotqa.py └── webshop │ └── webshop.py ├── eval.py ├── insight_extraction.py ├── memory ├── __init__.py └── episode.py ├── models ├── __init__.py └── llm.py ├── prompts ├── __init__.py ├── alfworld.py ├── fever.py ├── hotpotQA.py ├── templates │ ├── human.py │ └── system.py └── webshop.py ├── requirements.txt ├── tools └── visualize_logs.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 
102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 161 | #.idea/ 162 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Smoll baby robot | ExpeL: LLM Agents are Experiential Learners 2 | :---:|:---: 3 | 4 | 5 | 6 | ⚡ [AAAI 2024 *(Oral)*] Official implementation of the ExpeL Agent ⚡ 7 | 8 | ~ by Andrew Zhao, Daniel Huang, Quentin Xu, Matthieu Lin, Yong-Jin Liu, Gao Huang ~ 9 | 10 | 11 | [![Release Notes](https://img.shields.io/github/release/LeapLabTHU/ExpeL)](https://github.com/LeapLabTHU/ExpeL/releases) 12 | ![License: Apache 2.0](https://img.shields.io/github/license/LeapLabTHU/ExpeL) 13 | [![GitHub star chart](https://img.shields.io/github/stars/LeapLabTHU/ExpeL?style=social)](https://star-history.com/#LeapLabTHU/ExpeL) 14 | [![Open Issues](https://img.shields.io/github/issues-raw/LeapLabTHU/ExpeL)](https://github.com/LeapLabTHU/ExpeL/issues) 15 | 16 | --- 17 | ### 🌐 $\cdot$ [Project Page](https://andrewzh112.github.io/expel)   📄 $\cdot$ [Paper](https://arxiv.org/pdf/2308.10144.pdf) 18 | 19 | > "A computer program is said to learn from experience E with respect to some class of tasks T and performance measure P, if its performance at tasks in T, as measured by P, improves with experience E." - Tom Mitchell 20 | 21 | # 📖 Table of Contents 22 | 23 |
24 | 25 | 26 |
27 | 28 | [👋 Introduction](#-introduction) 29 | 30 | [🛠️ Installation](#%EF%B8%8F-installation) 31 | 32 | - [🌳 Environments](#-environments) 33 | - [🏠 ALFWorld](#-alfworld) 34 | - [🛒 Webshop](#-webshop) 35 | 36 | [🚀 Quick start](#-quick-start) 37 | 1. [Experience Gathering](#1-for-the-experience-gathering-stage) 38 | 2. [Insights Extraction](#2-for-the-insights-extraction-stage) 39 | 3. [Evaluation](#3-for-evaluation) 40 | 41 | [🫡 Cite us !](#-cite-us-) 42 | 43 | [💌 Contact us !](#-contact-us-) 44 | 45 | [🏛️ License](#%EF%B8%8F-license) 46 | 47 | [⚠️ Issues](#%EF%B8%8F-issues) 48 | 49 |
50 | 51 | 52 |
53 | 54 | Smoll baby robot 55 | 56 |
57 | 58 |
59 | 60 | 61 | 62 | ## 👋 Introduction 63 | 64 | This repo is the official implementation of [ExpeL: LLM Agents are Experiential Learners](https://arxiv.org/pdf/2308.10144.pdf). 65 | 66 | Our agent autonomously gathers experiences and extracts knowledge using natural language from a collection of training tasks. At inference, the agent recalls its extracted insights and past experiences to make informed decisions. Our empirical results highlight the robust learning efficacy of the ExpeL agent, indicating a consistent enhancement in its performance as it accumulates experiences. 67 | 68 | ## 🛠️ Installation 69 | Python version: 3.9.17 70 | 71 | 1. Create a virtual environment using [Anaconda](https://anaconda.org/anaconda/python) (or your favorite package manager), activate it, clone the repo and install the requirements. 72 | 73 | ```sh 74 | conda create -n expel python=3.9.17 75 | conda activate expel 76 | 77 | git clone https://github.com/LeapLabTHU/ExpeL.git expel 78 | cd expel 79 | 80 | pip install -r requirements.txt 81 | ``` 82 | 83 | Next, you need to set up the environments. 84 | 85 | ## 🌳 Environments 86 | 87 | Baby ExpeL has been playing around with the following environments: 88 | 89 | - ❓[HotpotQA](https://github.com/hotpotqa/hotpot) 90 | - 🏠 [ALFWorld](https://github.com/alfworld/alfworld) 91 | - 🛒 [WebShop](https://github.com/princeton-nlp/WebShop) 92 | - 🌡️ [FEVER](https://github.com/awslabs/fever) 93 | 94 | Among these, ALFWorld and WebShop require manual installation (plus, for WebShop, running a server, which can be local). Details below: 95 | 96 | ### 🏠 ALFWorld 97 | The installation instructions are shown below. Use the previously created environment to install ALFWorld. 98 | You will also need to download the data to the specified location: ``data/alfworld``. 99 | ```Bash 100 | conda activate expel 101 | pip install alfworld[full] 102 | 103 | export ALFWORLD_DATA="data/alfworld" 104 | alfworld-download 105 | ``` 106 | If you need more details, please refer to the [official repo](https://github.com/alfworld/alfworld#quickstart). 107 | 108 | ### 🛒 WebShop 109 | 110 | WebShop installation is different from the other environments. You will have to install it and **manually** run its server (which can be local) in parallel with ExpeL to interact with the environment. 111 | The succinct installation instructions are shown below. 112 | 113 | ```bash 114 | git clone https://github.com/princeton-nlp/webshop.git webshop 115 | cd webshop 116 | 117 | # Create another env for the webshop server to avoid conflicts 118 | conda create -n webshop python=3.8.13 119 | conda activate webshop 120 | 121 | ./setup.sh -d all 122 | ``` 123 | 124 | By default, WebShop only loads 1,000 products. But we need ALL OF THEM (🤯).
So change ``web_agent_site/utils.py``: 125 | 126 | ```python 127 | # DEFAULT_ATTR_PATH = join(BASE_DIR, '../data/items_ins_v2_1000.json') 128 | # DEFAULT_FILE_PATH = join(BASE_DIR, '../data/items_shuffle_1000.json') 129 | DEFAULT_ATTR_PATH = join(BASE_DIR, '../data/items_ins_v2.json') 130 | DEFAULT_FILE_PATH = join(BASE_DIR, '../data/items_shuffle.json') 131 | ``` 132 | To run the server, run the following command: 133 | ```bash 134 | ./run_dev.sh 135 | ``` 136 | You will be given a URL (and port) once the website is up: 137 | - Go back to the cloned ExpeL repo 138 | - Modify ``envs/webshop/webshop.py`` to point to the given URL: 139 | ```python 140 | WEBSHOP_URL = "http://127.0.0.1:3000" # Example URL 141 | ``` 142 | 143 | Note that you will have to run the WebShop server in the background to interact with the environment. We gathered some bugs we encountered during the WebShop server setup [here](#issues). 144 | 145 | If you need more details, please refer to the [official repo](https://github.com/princeton-nlp/WebShop?tab=readme-ov-file#-setup). 146 | 147 | 148 | ## 🚀 Quick start 149 | 150 | Below are the commands to run the ExpeL Agent. 151 | 152 | **Either put your OpenAI API key in a ``.env`` file (OPENAI_API_KEY=XXX) or you will be prompted for it on the command line.** 153 | 154 | ### 1. For the **Experience Gathering** stage: 155 | ```bash 156 | python train.py benchmark=<benchmark> \ 157 | run_name=<train_run_name> \ 158 | testing=false \ 159 | resume=false 160 | 161 | # resume = true/false if you want to resume a previous run 162 | # benchmark = {hotpotqa, alfworld, webshop, fever} 163 | # agent.llm = {gpt-3.5-turbo (default), gpt-4} 164 | ``` 165 | Below are the commands to run the experience gathering stage as in the paper: 166 | ```bash 167 | # 🏠 ALFWorld 168 | python train.py benchmark=alfworld run_name=<train_run_name> testing=false resume=false 169 | # 🛒 WebShop 170 | python train.py benchmark=webshop run_name=<train_run_name> testing=false resume=false 171 | # ❓ HotpotQA 172 | python train.py benchmark=hotpotqa run_name=<train_run_name> testing=false resume=false 173 | ``` 174 | 175 | By default, the result files (logs, dictionaries) will be saved in ``logs/<benchmark>/expel``, referenced by ``<train_run_name>``. You can change the log directory by adding ``log_dir=<log_dir>`` to the command line. 176 | 177 | 178 | ### 2. For the **Insights Extraction** stage: 179 | Use the collected experiences to extract insights. 180 | 181 | ```bash 182 | python insight_extraction.py \ 183 | benchmark=<benchmark> \ 184 | load_run_name=<train_run_name> \ 185 | run_name=<insights_run_name> \ 186 | agent.llm=<llm> \ 187 | agent.max_num_rules=<max_num_rules> \ 188 | agent.success_critique_num=<success_critique_num> \ 189 | testing=true \ 190 | resume=false 191 | 192 | # agent.success_critique_num = number of experiences to give per iteration 193 | # agent.max_num_rules = target number of insights to extract 194 | ``` 195 | 196 | To resume a run that stopped at a specific fold, remove ``load_run_name`` from the parameters, specify the fold it stopped at with ``resume_fold``, and set ``resume=true``.
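For instance, a resumed insight-extraction run might look like the following (the run name and fold index here are illustrative):

```bash
# Resume insight extraction from fold 1 of a previously started run
python insight_extraction.py benchmark=hotpotqa run_name=my_insights_run resume=true resume_fold=1
```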
197 | 198 | 199 | Below are the commands to run the insights extraction stage as in the paper: 200 | ```bash 201 | # 🏠 ALFWorld 202 | python insight_extraction.py benchmark=alfworld load_run_name=<train_run_name> run_name=<insights_run_name> agent.llm=gpt-4 agent.max_num_rules=10 agent.success_critique_num=8 testing=false resume=false 203 | # 🛒 WebShop 204 | python insight_extraction.py benchmark=webshop load_run_name=<train_run_name> run_name=<insights_run_name> agent.llm=gpt-4 agent.max_num_rules=8 agent.success_critique_num=4 testing=false resume=false 205 | # ❓ HotpotQA 206 | python insight_extraction.py benchmark=hotpotqa load_run_name=<train_run_name> run_name=<insights_run_name> agent.llm=gpt-4 agent.max_num_rules=10 agent.success_critique_num=8 testing=false resume=false 207 | ``` 208 | 209 | The final result files will be saved in ``logs/<benchmark>/expel/extracted_insights``, referenced by ``<insights_run_name>``. 210 | 211 | 212 | ### 3. For **Evaluation**: 213 | ```bash 214 | python eval.py benchmark=<benchmark> \ 215 | load_run_name=extracted_insights/<insights_run_name> \ 216 | run_name=<eval_run_name> \ 217 | benchmark.eval_configs.k_folds=<k_folds> \ 218 | agent.fewshot_strategy=task_similarity \ 219 | agent.retrieval_kwargs.max_fewshot_tokens=<max_fewshot_tokens> \ 220 | agent.retrieval_kwargs.buffer_retrieve_ratio=<buffer_retrieve_ratio> \ 221 | testing=false \ 222 | resume=false 223 | 224 | # agent.fewshot_strategy = {task_similarity, thought_similarity, task_thought_similarity} 225 | # agent.llm = {gpt-3.5-turbo (default), gpt-4} 226 | # agent.retrieval_kwargs.max_fewshot_tokens=auto 227 | # benchmark.eval_configs.k_folds=2 228 | # agent.retrieval_kwargs.buffer_retrieve_ratio = safety measure to avoid retrieving 0 examples (bigger is safer) 229 | ``` 230 | 231 | To resume a run that stopped, remove ``load_run_name`` from the parameters and add ``resume=true`` at the end of the command line. 232 | 233 | Below are the commands to evaluate ExpeL as in the paper: 234 | 235 | ```bash 236 | # 🏠 ALFWorld 237 | python eval.py benchmark=alfworld load_run_name=extracted_insights/<insights_run_name> run_name=<eval_run_name> agent.fewshot_strategy=task_similarity agent.retrieval_kwargs.max_fewshot_tokens=auto testing=false resume=false 238 | # 🛒 WebShop 239 | python eval.py benchmark=webshop load_run_name=extracted_insights/<insights_run_name> run_name=<eval_run_name> agent.fewshot_strategy=task_similarity agent.retrieval_kwargs.max_fewshot_tokens=auto agent.retrieval_kwargs.buffer_retrieve_ratio=20 testing=false resume=false 240 | # ❓ HotpotQA 241 | python eval.py benchmark=hotpotqa load_run_name=extracted_insights/<insights_run_name> run_name=<eval_run_name> agent.fewshot_strategy=task_similarity testing=false resume=false 242 | ``` 243 | The result files will be saved in ``logs/<benchmark>/expel/eval``, referenced by ``<eval_run_name>``. 244 | 245 | ## 🫡 Cite us ! 246 | 247 | This repository contains code for reproducing results. If you find this work useful in your research (and/or daily life), please cite: 248 | 249 | ``` 250 | @inproceedings{zhao2024expel, 251 | author = {Andrew Zhao and Daniel Huang and Quentin Xu and Matthieu Lin and Yong-Jin Liu and Gao Huang}, 252 | title = {ExpeL: LLM Agents Are Experiential Learners}, 253 | booktitle = {Thirty-Eighth {AAAI} Conference on Artificial Intelligence, {AAAI} 254 | 2024, Thirty-Sixth Conference on Innovative Applications of Artificial 255 | Intelligence, {IAAI} 2024, Fourteenth Symposium on Educational Advances 256 | in Artificial Intelligence, {EAAI} 2024, February 20-27, 2024, Vancouver, 257 | Canada}, 258 | editor = {Michael J. Wooldridge and Jennifer G.
Dy and Sriraam Natarajan}, 259 | year = {2024}, 260 | pages = {19632--19642}, 261 | publisher = {{AAAI} Press}, 262 | url = {https://ojs.aaai.org/index.php/AAAI/article/view/29936}, 263 | doi = {10.1609/aaai.v38i17.29936} 264 | } 265 | ``` 266 | 267 | ## 💌 Contact us ! 268 | 269 | If you have any questions, feel free to contact [Andrew Zhao](mailto:zqc21@mails.tsinghua.edu.cn), [Daniel Huang](mailto:huang-jy22@mails.tsinghua.edu.cn) or [Quentin Xu](mailto:quentinxu1@gmail.com). 270 | 271 | 272 | ## 🏛️ License 273 | Check `LICENSE` 274 | 275 | ## ⚠️ Issues 276 | We encountered some errors and gathered them here (note that by the time you read this, they might have been fixed). If you don't encounter them, lucky you 😒. 277 | 278 | **🛒 WebShop**-server installation: 279 | 280 | ```Bash 281 | # install 282 | python -m spacy download en_core_web_lg 283 | pip install lightgbm nmslib # needs a C compiler and build tools 284 | conda install mkl # if ImportError: libmkl_intel_lp64.so.1 285 | pip install pyserini 286 | pip install pyserini --no-cache-dir # if low on RAM 287 | pip install typing-inspect==0.8.0 typing_extensions==4.5.0 # if issubclass errors 288 | # if you get a libjvm.so error, you need to export JAVA_HOME 289 | ./setup.sh -d all 290 | ``` 291 | On Mac, if you have problems with lightgbm or nmslib, you may need to replace their pip installs with: 292 | ```Bash 293 | brew install cmake libomp 294 | pip install lightgbm 295 | 296 | CFLAGS="-mavx -DWARN(a)=(a)" pip install --use-pep517 nmslib 297 | ``` 298 | -------------------------------------------------------------------------------- /agent/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseAgent 2 | from .react import ReactAgent 3 | from .reflect import ReflectAgent 4 | from .expel import ExpelAgent 5 | 6 | 7 | AGENT = dict(reflection=ReflectAgent, react=ReactAgent, expel=ExpelAgent) 8 | -------------------------------------------------------------------------------- /agent/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import re 3 | 4 | from prompts.templates.human import human_task_message_prompt 5 | 6 | 7 | class BaseAgent(ABC): 8 | """ 9 | Base agent class.
10 | """ 11 | @abstractmethod 12 | def run(self, *args, **kwargs) -> None: 13 | pass 14 | 15 | @abstractmethod 16 | def step(self) -> None: 17 | pass 18 | 19 | @abstractmethod 20 | def prompt_agent(self) -> str: 21 | pass 22 | 23 | @abstractmethod 24 | def _build_agent_prompt(self) -> str: 25 | pass 26 | 27 | @abstractmethod 28 | def after_step(self, *args, **kwargs) -> None: 29 | pass 30 | 31 | def is_terminated(self) -> bool: 32 | return self.env.is_terminated() 33 | 34 | def is_truncated(self) -> bool: 35 | return self.env.is_truncated() or (self.token_counter(self.log_history(include_all=True)) > 15800) 36 | 37 | @abstractmethod 38 | def reset(self, *args, **kwargs) -> None: 39 | pass 40 | 41 | def log_history(self, include_task: bool = True, include_all: bool = False) -> str: 42 | all_history = '\n'.join([prompt.content for prompt in self.prompt_history]) 43 | if include_all: 44 | return all_history 45 | 46 | # only log the task prompt and the agent's response 47 | reflection_pattern = r'{}'.format(self.format_reflections(self.reflections, include_prefix=False)) 48 | match = re.search(re.escape(reflection_pattern), all_history) 49 | if not match or match.group() == '' or not include_task: 50 | task_text_list = human_task_message_prompt.format_messages(task=self.remove_task_suffix(self.task))[0].content.split('\n') 51 | task_text = '\n'.join(task_text_list) 52 | pattern = r'{}'.format(re.escape(task_text.strip()) + '.*') 53 | match = re.search(pattern, all_history) 54 | if include_task: 55 | return match.group().lstrip("Now it's your turn!\n") + match.string[match.end():] 56 | return match.string[match.end():].strip() 57 | 58 | def remove_task_suffix(self, task: str) -> str: 59 | if self.benchmark_name == 'alfworld': 60 | return task.split('___')[0] 61 | return task -------------------------------------------------------------------------------- /agent/react.py: -------------------------------------------------------------------------------- 1 | from typing import List, Callable, Tuple, Dict, Any, Union 2 | from functools import partial 3 | from copy import deepcopy 4 | 5 | from langchain.prompts import PromptTemplate 6 | from langchain.schema import ChatMessage 7 | from openai.error import InvalidRequestError 8 | 9 | from envs import BaseEnv 10 | from agent import BaseAgent 11 | from prompts.templates.human import ( 12 | human_instruction_fewshot_message_prompt, 13 | human_task_message_prompt, 14 | ) 15 | from utils import print_message, token_counter 16 | 17 | class ReactAgent(BaseAgent): 18 | """ 19 | A Generic ReAct Agent. 
20 | """ 21 | def __init__(self, 22 | name: str, 23 | system_instruction: Union[str, Dict[str, str]], 24 | human_instruction: Callable, 25 | fewshots: Union[List[str], Dict[str, List[str]]], 26 | system_prompt: Callable, 27 | env: BaseEnv, 28 | llm: str, 29 | llm_builder: Callable, 30 | openai_api_key: str, 31 | tasks: List[Dict[str, Any]], 32 | max_steps: int, 33 | llm_parser: Callable, 34 | observation_formatter: Callable, 35 | testing: bool = False, 36 | task_idx: int = 0, 37 | benchmark_name = None, 38 | *args, 39 | **kwargs, 40 | ) -> None: 41 | self.benchmark_name = benchmark_name 42 | self.name = name 43 | self.tasks = tasks 44 | self.task_idx = task_idx 45 | self.all_system_instruction = system_instruction 46 | self.human_instruction = human_instruction 47 | self.human_instruction_kwargs = {'max_steps': max_steps} 48 | self.all_fewshots = fewshots 49 | self.system_prompt = system_prompt 50 | self.prompt_history = [] 51 | self.testing = testing 52 | self.max_steps = max_steps 53 | self.llm_parser = llm_parser 54 | self.observation_formatter = observation_formatter 55 | self._last_observation_history = None 56 | 57 | self.env = env(**self.tasks[self.task_idx]['env_kwargs'], max_steps=self.max_steps) 58 | self.env.reset() 59 | self.task = self.tasks[self.task_idx]['task'] 60 | self.reset() 61 | self.truncated, self.reward, self.terminated = False, False, False 62 | self.print_message = partial(print_message, testing=testing) 63 | 64 | self.success, self.fail, self.halted = 0, 0, 0 65 | 66 | self.llm = llm_builder(llm_name=llm, openai_api_key=openai_api_key, long_ver=False) 67 | self.long_context_llm = llm_builder(llm_name=llm, openai_api_key=openai_api_key, long_ver=True) 68 | del openai_api_key 69 | self.token_counter = partial(token_counter, llm=llm, tokenizer=getattr(self.llm, 'tokenizer', None)) 70 | 71 | # build base prompt 72 | self._build_agent_prompt() 73 | self.update_dynamic_prompt_components() 74 | 75 | self.long_pass = None 76 | 77 | def is_success(self) -> bool: 78 | return self.env.success_fn() 79 | 80 | def set_env(self, task_kwargs: Dict[str, Any], max_steps: int): 81 | self.env.__init__(**task_kwargs, max_steps=max_steps) 82 | 83 | def run(self, reset: bool = True, *args, **kwargs) -> None: 84 | if reset: 85 | self.env.reset() 86 | self.reset() 87 | 88 | while not (self.is_truncated() or self.is_terminated()): 89 | self.step() 90 | 91 | def step(self) -> None: 92 | message, message_type, others = self.llm_parser(self.prompt_agent(), self.curr_step, False) 93 | self.prompt_history.append(message) 94 | self.print_message(message) 95 | 96 | thought_num = 1 97 | # loops while in thinking mode 98 | while message_type == 'thought': 99 | thought_num += 1 100 | message, message_type, others = self.llm_parser(self.prompt_agent(), self.curr_step, False) 101 | self.prompt_history.append(message) 102 | self.print_message(message) 103 | 104 | if thought_num > 2: 105 | if message_type == 'thought': 106 | others['action'] = 'N/A' 107 | break 108 | 109 | # Observe 110 | observation, self.reward, self.terminated, self.truncated, _ = self.env.step(others['action']) 111 | if others['action'] == 'N/A' and thought_num > 2: 112 | observation = "You are thinking too many times without taking action." 
113 | observation_history, operation = self.observation_formatter(observation, step=self.curr_step) 114 | if operation == 'append': 115 | self.prompt_history.append(observation_history) 116 | elif operation == 'replace': 117 | for message in self.prompt_history: 118 | if self._last_observation_history.content in message.content: 119 | message.content = message.content.replace(self._last_observation_history.content, observation_history.content) 120 | break 121 | self._last_observation_history = deepcopy(observation_history) 122 | self.print_message(observation_history) 123 | 124 | BaseAgent.after_step(self) 125 | 126 | self.prompt_history = self.collapse_prompts(self.prompt_history) 127 | 128 | self.curr_step += 1 129 | 130 | def prompt_agent(self) -> str: 131 | self.prompt_history = self.collapse_prompts(self.prompt_history) 132 | self.update_dynamic_prompt_components() 133 | prompt_history = self.collapse_prompts(self.prompt_history) 134 | if self.testing: 135 | print('###################################') 136 | for prompt in prompt_history: 137 | self.print_message(prompt, self.token_counter) 138 | return input() 139 | try: 140 | return self.llm(prompt_history, stop=['\n', '\n\n']) 141 | except InvalidRequestError: 142 | while self.long_pass is None: 143 | res = input('Changing to long context LLM. Press Enter to continue.\n') 144 | if res == 'pass': 145 | self.long_pass = True 146 | elif res != '': 147 | continue 148 | break 149 | 150 | return self.long_context_llm(prompt_history, stop=['\n', '\n\n']) 151 | 152 | def _build_fewshot_prompt( 153 | self, 154 | fewshots: List[str], 155 | prompt_history: List[ChatMessage], 156 | instruction_prompt: PromptTemplate, 157 | instruction_prompt_kwargs: Dict[str, Any], 158 | prompt_type: str, 159 | ) -> str: 160 | if human_instruction_fewshot_message_prompt is not None and instruction_prompt is not None: 161 | prompt_history.append( 162 | human_instruction_fewshot_message_prompt('message_style_kwargs').format_messages( 163 | instruction=instruction_prompt.format_messages( 164 | **instruction_prompt_kwargs)[0].content, 165 | fewshots='\n\n'.join(fewshots) 166 | )[0] 167 | ) 168 | 169 | def _build_agent_prompt(self) -> None: 170 | system_prompt = self.system_prompt.format_messages( 171 | instruction=self.system_instruction, ai_name=self.name 172 | ) 173 | self.prompt_history.extend(system_prompt) 174 | self._build_fewshot_prompt( 175 | fewshots=self.fewshots, prompt_history=self.prompt_history, 176 | instruction_prompt=self.human_instruction, 177 | instruction_prompt_kwargs=self.human_instruction_kwargs, 178 | prompt_type='react_type', 179 | ) 180 | self.prompt_history = self.collapse_prompts(self.prompt_history) 181 | self.log_idx = len(self.prompt_history) 182 | self.insert_before_task_prompt() 183 | 184 | self.prompt_history.append(human_task_message_prompt.format_messages(task=self.remove_task_suffix(self.task))[0]) 185 | self.insert_after_task_prompt() 186 | self.prompt_history = self.collapse_prompts(self.prompt_history) 187 | self.pretask_idx = len(self.prompt_history) 188 | return self.prompt_history 189 | 190 | def reset(self, *args, **kwargs) -> None: 191 | self.prompt_history = [] 192 | self.update_dynamic_prompt_components(reset=True) 193 | self.curr_step = 1 194 | self._build_agent_prompt() 195 | 196 | def insert_after_task_prompt(self) -> None: 197 | return 198 | 199 | def job_not_done(self) -> bool: 200 | return self.task_idx < len(self.tasks) 201 | 202 | def next_task(self): 203 | self.task_idx += 1 204 | # if there are more tasks, 
reset the env and the agent 205 | if self.job_not_done(): 206 | self.task = self.tasks[self.task_idx]['task'] 207 | self.set_env(self.tasks[self.task_idx]['env_kwargs'], max_steps=self.max_steps) 208 | self.env.reset() 209 | self.reset() 210 | 211 | def reset_stats(self) -> None: 212 | self.success = 0 213 | self.fail = 0 214 | self.halted = 0 215 | 216 | def update_stats(self) -> None: 217 | if not self.is_success() and self.is_truncated(): 218 | self.halted += 1 219 | else: 220 | if self.reward: 221 | self.success += 1 222 | else: 223 | self.fail += 1 224 | 225 | def get_stats(self) -> Tuple[int, int, int]: 226 | return self.success, self.fail, self.halted 227 | 228 | def collapse_prompts(self, prompt_history: List[ChatMessage]) -> List[ChatMessage]: 229 | """Merge consecutive messages of the same type into single messages. Courtesy of GPT4""" 230 | if not prompt_history: 231 | return [] 232 | 233 | new_prompt_history = [] 234 | scratch_pad = prompt_history[0].content 235 | last_message_type = type(prompt_history[0]) 236 | 237 | for message in prompt_history[1:]: 238 | current_message_type = type(message) 239 | if current_message_type == last_message_type: 240 | scratch_pad += '\n' + message.content 241 | else: 242 | new_prompt_history.append(last_message_type(content=scratch_pad)) 243 | scratch_pad = message.content 244 | last_message_type = current_message_type 245 | 246 | # Handle the last accumulated message 247 | new_prompt_history.append(last_message_type(content=scratch_pad)) 248 | 249 | return new_prompt_history 250 | 251 | def update_dynamic_prompt_components(self): 252 | ##################### 253 | # Updating fewshots # 254 | ##################### 255 | if isinstance(self.all_fewshots, dict): 256 | self.fewshots = self.all_fewshots[self.env.env_name] 257 | elif isinstance(self.all_fewshots, list): 258 | self.fewshots = self.all_fewshots 259 | 260 | ######################### 261 | # Updating instructions # 262 | ######################### 263 | if isinstance(self.all_system_instruction, str): 264 | self.system_instruction = self.all_system_instruction 265 | elif isinstance(self.all_system_instruction, dict): 266 | self.system_instruction = self.all_system_instruction[self.env.env_name] 267 | # if system gives instruction, then human instruction is empty 268 | self.human_instruction_kwargs['instruction'] = '' 269 | self.num_fewshots = len(self.fewshots) 270 | 271 | def load_checkpoint(self, loaded_dict: Dict[str, Any], no_load_list: List[str] = []) -> None: 272 | for k, v in loaded_dict.items(): 273 | if k in no_load_list: 274 | continue 275 | setattr(self, k, v) 276 | # the following attributes are not saved in pickle but are correctly initialized back: ['rule_template', 'truncate_strategy', 'embedder', 'retriever_cls', 'manual', 'reflection_task_prompt', 'message_splitter', 'identifier', 'message_step_splitter', 'format_reflections', 'formatted_reflection', 'human_instruction', 'system_prompt', 'llm_parser', 'observation_formatter', 'env', 'print_message', 'llm', 'long_context_llm', 'token_counter'] 277 | -------------------------------------------------------------------------------- /agent/reflect.py: -------------------------------------------------------------------------------- 1 | from typing import List, Callable 2 | 3 | from openai.error import InvalidRequestError 4 | from langchain.schema import HumanMessage 5 | 6 | from agent.react import ReactAgent 7 | from utils import Count 8 | import re 9 | 10 | class ReflectAgent(ReactAgent): 11 | """ 12 | A Generic Reflection Agent.
13 | """ 14 | def __init__(self, 15 | reflection_fewshots: List[str], 16 | reflection_task_prompt: Callable, 17 | reflection_system_instruction: str, 18 | max_relfection_depth: int, 19 | message_splitter: Callable, 20 | identifier: Callable, 21 | message_step_splitter: Callable, 22 | reflection_prefix: str, 23 | previous_trials_formatter: Callable, 24 | *args, 25 | **kwargs, 26 | ) -> None: 27 | self.reflection_counter = Count(max_relfection_depth) 28 | self.reflection_fewshots = reflection_fewshots 29 | self.reflection_task_prompt = reflection_task_prompt 30 | self.message_splitter = message_splitter 31 | self.identifier = identifier 32 | self.message_step_splitter = message_step_splitter 33 | self.reflection_prefix = reflection_prefix 34 | self.format_reflections = previous_trials_formatter 35 | self.reflection_prompt_history = [] 36 | self.reflections = [] 37 | self.previous_trial = [] 38 | self.formatted_reflection = None 39 | self.perform_reflection = False 40 | self.increment_task = False 41 | ai_name = 'an advanced reasoning agent that can improve based on self refection' 42 | self.reflection_system_kwargs = dict(instruction=reflection_system_instruction, ai_name=ai_name) 43 | super().__init__(*args, **kwargs) 44 | 45 | def run(self, reset: bool = True, *args, **kwargs) -> None: 46 | if self.perform_reflection and not self.is_success(): 47 | self.reflect() 48 | ReactAgent.run(self, reset) 49 | if self.reflection_counter.is_maximum() or self.is_success(): 50 | self.increment_task = True 51 | 52 | 53 | def step(self) -> None: 54 | ReactAgent.step(self) 55 | trial = self.prompt_history[self.history_index].content.split(self.remove_task_suffix(self.task), 1)[-1].strip() 56 | steps = self.message_step_splitter( 57 | lines=trial, 58 | cycler=self.message_splitter, 59 | step_identifier=self.identifier) 60 | self.previous_trial.append(HumanMessage(content=steps[-1])) 61 | 62 | 63 | def reflect(self) -> None: 64 | self._format_reflection_scratchpad() 65 | self.reflection_prompt_history.append(HumanMessage(content=self.reflection_prefix)) 66 | reflection = self.prompt_reflection() 67 | self.reflections.append(reflection) 68 | self.formatted_reflection = self.format_reflections(self.reflections) 69 | print(self.formatted_reflection) 70 | # wipe the history for a new round 71 | self.previous_trial = [] 72 | 73 | def insert_before_task_prompt(self) -> None: 74 | if self.formatted_reflection is not None: 75 | self.prompt_history.append(HumanMessage(content=self.formatted_reflection)) 76 | 77 | def prompt_reflection(self) -> str: 78 | self.reflection_prompt_history = self.collapse_prompts(self.reflection_prompt_history) 79 | if self.benchmark_name == 'webshop': 80 | # match the last "Observation:" 81 | pattern = r"\nObservation: (.*[\n]+)+Next plan:.*" 82 | matches = re.findall(pattern, self.reflection_prompt_history[-1].content) 83 | if 'Ran out of steps' in matches[-1]: 84 | add_text = "\nObservation: Ran out of steps! TASK FAILED\n\nNext plan:\n" 85 | elif 'Repeated action' in matches[-1]: 86 | add_text = "\nObservation: Repeated action! TASK FAILED\n\nNext plan:\n" 87 | else: 88 | add_text = "\nObservation: Wrong item! 
TASK FAILED\n\nNext plan:\n" 89 | 90 | new_history = self.reflection_prompt_history[-1].content.split(matches[-1]) 91 | new_history = ''.join(new_history[:-1]) + add_text 92 | 93 | self.reflection_prompt_history[-1].content = new_history 94 | 95 | if self.testing: 96 | print('###################################') 97 | for prompt in self.reflection_prompt_history: 98 | self.print_message(prompt, self.token_counter) 99 | return input() 100 | try: 101 | return self.llm(self.reflection_prompt_history, stop=['\n', '\n\n']) 102 | except InvalidRequestError: 103 | return self.long_context_llm(self.reflection_prompt_history, stop=['\n', '\n\n']) 104 | 105 | def _build_reflection_prompt(self) -> None: 106 | # avoid building reflection prompt if it already exists 107 | if self.reflection_prompt_history != []: 108 | return 109 | system_prompt = self.system_prompt.format_messages(**self.reflection_system_kwargs) 110 | self.reflection_prompt_history.extend(system_prompt) 111 | self._build_fewshot_prompt( 112 | fewshots=self.reflection_fewshots, 113 | prompt_history=self.reflection_prompt_history, 114 | instruction_prompt=self.reflection_task_prompt, 115 | instruction_prompt_kwargs={}, 116 | prompt_type='reflect_type', 117 | ) 118 | self.reflection_prompt_history.append(HumanMessage(content=f'Previous trial:\n{self.remove_task_suffix(self.task)}')) 119 | self.reflect_interaction_idx = len(self.reflection_prompt_history) 120 | for message in self.previous_trial: 121 | self.reflection_prompt_history.append(message) 122 | 123 | def _format_reflection_scratchpad(self) -> str: 124 | lines = [ref.content for ref in self.reflection_prompt_history[self.reflect_interaction_idx:]] 125 | lines_by_tokens = sorted(lines, key=lambda x: self.token_counter(x)) 126 | while self.token_counter(''.join(lines)) > 12000: 127 | ind = lines.index(lines_by_tokens.pop(-1)) 128 | line = lines[ind] 129 | lines[ind] = line.split(':')[0] + ': ...' 
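# Merge the (possibly elided) lines back into a single message and rebuild the
# reflection history from the start of the interaction onward.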
130 | combined_message = HumanMessage(content='\n'.join(lines)) 131 | self.reflection_prompt_history = self.reflection_prompt_history[:self.reflect_interaction_idx] 132 | self.reflection_prompt_history.append(combined_message) 133 | 134 | def reset(self, *args, **kwargs) -> None: 135 | ReactAgent.reset(self, *args, **kwargs) 136 | self.reflection_prompt_history = [] 137 | self._build_reflection_prompt() 138 | if self.increment_task: 139 | self.reflections = [] 140 | self.reflection_counter.reset() 141 | self.formatted_reflection = None 142 | self.previous_trial = [] 143 | 144 | @property 145 | def history_index(self) -> int: 146 | return -1 147 | 148 | def next_task(self) -> bool: 149 | # increment task if reflection counter is at max OR if the agent is successful 150 | if self.increment_task: 151 | self.task_idx += 1 152 | if self.job_not_done(): 153 | self.task = self.tasks[self.task_idx]['task'] 154 | self.set_env(self.tasks[self.task_idx]['env_kwargs'], max_steps=self.max_steps) 155 | self.perform_reflection = False 156 | # wipe the history for a new task 157 | self.previous_trial = [] 158 | # if there are more tasks, perform reflection 159 | if self.job_not_done() and not self.increment_task: 160 | self.perform_reflection = True 161 | self.reflection_counter.increment() 162 | self.reset() 163 | self.env.reset() 164 | if self.increment_task: 165 | self.increment_task = False 166 | return True 167 | return False 168 | 169 | def update_stats(self) -> None: 170 | # only count when finished trying for this task 171 | if self.increment_task: 172 | if not self.is_success() and self.is_truncated(): 173 | self.halted += 1 174 | else: 175 | if self.reward: 176 | self.success += 1 177 | else: 178 | self.fail += 1 179 | -------------------------------------------------------------------------------- /assets/eval_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeapLabTHU/ExpeL/e41ec9a24823e7b560c561ab191441b56d9bcefc/assets/eval_example.png -------------------------------------------------------------------------------- /assets/expel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeapLabTHU/ExpeL/e41ec9a24823e7b560c561ab191441b56d9bcefc/assets/expel.png -------------------------------------------------------------------------------- /assets/header_baby.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeapLabTHU/ExpeL/e41ec9a24823e7b560c561ab191441b56d9bcefc/assets/header_baby.png -------------------------------------------------------------------------------- /configs/agent/expel.yaml: -------------------------------------------------------------------------------- 1 | name: expel 2 | llm: gpt-3.5-turbo # gpt-3.5-turbo-0301 3 | max_reflection_depth: 3 4 | max_num_rules: 20 5 | truncate_strategy: null 6 | fewshot_strategy: task_similarity # task_similarity, step_similarity, reflection_similarity, thought_similarity, action_similarity, rotation, task_thought_similarity 7 | critique_truncate_strategy: random 8 | success_critique_num: 8 9 | retrieval_kwargs: 10 | embedder_path: "all-mpnet-base-v2" # sentence-transformers/all-mpnet-base-v2 11 | embedder_type: huggingface 12 | retriever_type: knn # knn or svm 13 | buffer_retrieve_ratio: 4 14 | reranker: none 15 | max_fewshot_tokens: auto -------------------------------------------------------------------------------- 
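The ``retrieval_kwargs`` above control how ExpeL picks few-shot examples at evaluation time: tasks are embedded with the configured sentence embedder, the kNN (or SVM) retriever ranks stored trajectories by similarity, and ``buffer_retrieve_ratio`` over-fetches candidates so that downstream filters such as ``max_fewshot_tokens`` still leave enough usable examples. As a rough, hypothetical sketch of the idea (the class name and structure here are invented; see ``agent/expel.py`` and ``memory/`` for the actual implementation):

```python
# Illustrative sketch only, not the repo's actual retriever.
from typing import List, Tuple

import numpy as np
from sentence_transformers import SentenceTransformer


class KNNFewshotRetriever:
    def __init__(self, examples: List[Tuple[str, str]], model_name: str = "all-mpnet-base-v2"):
        # examples: (task string, successful trajectory) pairs from the experience pool
        self.tasks, self.trajectories = zip(*examples)
        self.embedder = SentenceTransformer(model_name)
        self.task_embeddings = self.embedder.encode(list(self.tasks), normalize_embeddings=True)

    def retrieve(self, query_task: str, k: int = 2, buffer_retrieve_ratio: int = 4) -> List[str]:
        # Over-fetch k * buffer_retrieve_ratio candidates so that later
        # token-budget filtering still leaves at least k usable examples.
        query = self.embedder.encode([query_task], normalize_embeddings=True)[0]
        scores = self.task_embeddings @ query  # cosine similarity on normalized vectors
        top = np.argsort(-scores)[: k * buffer_retrieve_ratio]
        return [self.trajectories[i] for i in top]
```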
/configs/benchmark/alfworld.yaml: -------------------------------------------------------------------------------- 1 | name: alfworld 2 | task_prefix: "" 3 | task_file: data/alfworld/alfworld_tasks_suffix.json 4 | max_steps: 20 5 | num_fewshots: 2 6 | eval_configs: 7 | k_folds: 2 8 | k_folds_seed: 42 9 | split: eval_out_of_distribution 10 | ai_name: alfred 11 | 12 | dataset: 13 | data_path: 'data/alfworld/json_2.1.1/train' 14 | eval_id_data_path: 'data/alfworld/json_2.1.1/valid_seen' # null/None to disable 15 | eval_ood_data_path: 'data/alfworld/json_2.1.1/valid_unseen' # null/None to disable 16 | num_train_games: -1 # max training games (<=0 indicates full dataset) 17 | num_eval_games: -1 # max evaluation games (<=0 indicates full dataset) 18 | 19 | logic: 20 | domain: 'data/alfworld/logic/alfred.pddl' # PDDL domain file that defines the world dynamics 21 | grammar: 'data/alfworld/logic/alfred.twl2' # Grammar file that defines the text feedbacks 22 | 23 | env: 24 | type: 'AlfredTWEnv' # 'AlfredTWEnv' or 'AlfredThorEnv' or 'AlfredHybrid' 25 | regen_game_files: False # check if game is solvable by expert and save to game.tw-pddl file 26 | domain_randomization: False # shuffle Textworld print order and object id nums 27 | task_types: [1, 2, 3, 4, 5, 6] # task-type ids: 1 - Pick & Place, 2 - Examine in Light, 3 - Clean & Place, 4 - Heat & Place, 5 - Cool & Place, 6 - Pick Two & Place 28 | expert_timeout_steps: 150 # max steps before timeout for expert to solve the task 29 | expert_type: "handcoded" # 'handcoded' or 'downward'. Note: the downward planner is very slow for real-time use 30 | goal_desc_human_anns_prob: 0.0 # prob of using human-annotated goal language instead of templated goals (1.0 indicates all human annotations from ALFRED) 31 | 32 | hybrid: 33 | start_eps: 100000 # starting episode of hybrid training, tw-only training upto this point 34 | thor_prob: 0.5 # prob of AlfredThorEnv during hybrid training 35 | eval_mode: "tw" # 'tw' or 'thor' - env used for evaluation during hybrid training 36 | 37 | thor: 38 | screen_width: 300 # width of THOR window 39 | screen_height: 300 # height of THOR window 40 | smooth_nav: False # smooth rotations, looks, and translations during navigation (very slow) 41 | save_frames_to_disk: False # save frame PNGs to disk (useful for making videos) 42 | save_frames_path: './videos/' # path to save frame PNGs 43 | 44 | controller: 45 | type: 'oracle' # 'oracle' or 'oracle_astar' or 'mrcnn' or 'mrcnn_astar' (aka BUTLER) 46 | debug: False 47 | load_receps: True # load receptacle locations from precomputed dict (if available) 48 | 49 | mask_rcnn: 50 | pretrained_model_path: 'data/alfworld/detectors/mrcnn.pth' 51 | 52 | general: 53 | random_seed: 42 54 | use_cuda: True # disable this when running on machine without cuda 55 | visdom: False # plot training/eval curves, run with visdom server 56 | task: 'alfred' 57 | training_method: 'dagger' # 'dqn' or 'dagger' 58 | save_path: './training/' # path to save pytorch models 59 | observation_pool_capacity: 3 # k-size queue, 0 indicates no observation 60 | hide_init_receptacles: False # remove initial observation containing navigable receptacles 61 | 62 | training: 63 | batch_size: 10 64 | max_episode: 50000 65 | smoothing_eps: 0.1 66 | optimizer: 67 | learning_rate: 0.001 68 | clip_grad_norm: 5 69 | 70 | evaluate: 71 | run_eval: True 72 | batch_size: 10 73 | env: 74 | type: "AlfredTWEnv" 75 | 76 | checkpoint: 77 | report_frequency: 1000 # report every N episode 78 | experiment_tag: 'test' # name of experiment 79 | 
load_pretrained: False # during test, enable this so that the agent loads your pretrained model 80 | load_from_tag: 'not loading anything' # name of pre-trained model to load in save_path 81 | 82 | model: 83 | encoder_layers: 1 84 | decoder_layers: 1 85 | encoder_conv_num: 5 86 | block_hidden_dim: 64 87 | n_heads: 1 88 | dropout: 0.1 89 | block_dropout: 0.1 90 | recurrent: True 91 | 92 | rl: 93 | action_space: "admissible" # 'admissible' (candidates from text engine) or 'generation' (seq2seq-style generation) or 'beam_search_choice' or 'exhaustive' (not working) 94 | max_target_length: 20 # max token length for seq2seq generation 95 | beam_width: 10 # 1 means greedy 96 | generate_top_k: 3 97 | 98 | training: 99 | max_nb_steps_per_episode: 50 # terminate after this many steps 100 | learn_start_from_this_episode: 0 # delay updates until this episode 101 | target_net_update_frequency: 500 # sync target net with online net per this many epochs 102 | 103 | replay: 104 | accumulate_reward_from_final: True 105 | count_reward_lambda: 0.0 # 0 to disable 106 | novel_object_reward_lambda: 0.0 # 0 to disable 107 | discount_gamma_game_reward: 0.9 108 | discount_gamma_count_reward: 0.5 109 | discount_gamma_novel_object_reward: 0.5 110 | replay_memory_capacity: 500000 # adjust this depending on your RAM size 111 | replay_memory_priority_fraction: 0.5 112 | update_per_k_game_steps: 5 113 | replay_batch_size: 64 114 | multi_step: 3 115 | replay_sample_history_length: 4 116 | replay_sample_update_from: 2 117 | 118 | epsilon_greedy: 119 | noisy_net: False # if this is true, then epsilon greedy is disabled 120 | epsilon_anneal_episodes: 1000 # -1 if not annealing 121 | epsilon_anneal_from: 0.3 122 | epsilon_anneal_to: 0.1 123 | 124 | dagger: 125 | action_space: "generation" # 'admissible' (candidates from text engine) or 'generation' (seq2seq-style generation) or 'exhaustive' (not working) 126 | max_target_length: 20 # max token length for seq2seq generation 127 | beam_width: 10 # 1 means greedy 128 | generate_top_k: 5 129 | unstick_by_beam_search: False # use beam-search for failed actions, set True during evaluation 130 | 131 | training: 132 | max_nb_steps_per_episode: 50 # terminate after this many steps 133 | 134 | fraction_assist: 135 | fraction_assist_anneal_episodes: 50000 136 | fraction_assist_anneal_from: 1.0 137 | fraction_assist_anneal_to: 0.01 138 | 139 | fraction_random: 140 | fraction_random_anneal_episodes: 0 141 | fraction_random_anneal_from: 0.0 142 | fraction_random_anneal_to: 0.0 143 | 144 | replay: 145 | replay_memory_capacity: 500000 146 | update_per_k_game_steps: 5 147 | replay_batch_size: 64 148 | replay_sample_history_length: 4 149 | replay_sample_update_from: 2 150 | 151 | vision_dagger: 152 | model_type: "resnet" # 'resnet' (whole image features) or 'maskrcnn_whole' (whole image MaskRCNN feats) or 'maskrcnn' (top k MaskRCNN detection feats) or 'no_vision' (zero vision input) 153 | resnet_fc_dim: 64 154 | maskrcnn_top_k_boxes: 10 # top k box features 155 | use_exploration_frame_feats: False # append feats from initial exploration (memory intensive!)
156 | sequence_aggregation_method: "average" # 'sum' or 'average' or 'rnn' -------------------------------------------------------------------------------- /configs/benchmark/fever.yaml: -------------------------------------------------------------------------------- 1 | name: fever 2 | task_prefix: "Claim: " 3 | task_file: null 4 | max_steps: 7 5 | ai_name: fact verifier 6 | num_fewshots: 5 7 | eval_configs: 8 | k_folds: 2 9 | k_folds_seed: 42 10 | -------------------------------------------------------------------------------- /configs/benchmark/hotpotqa.yaml: -------------------------------------------------------------------------------- 1 | name: hotpotqa 2 | task_prefix: "Question: " 3 | task_file: data/hotpotqa/hotpot-qa-distractor-sample.joblib 4 | max_steps: 7 5 | num_fewshots: 6 6 | eval_configs: 7 | k_folds: 2 8 | k_folds_seed: 42 9 | ai_name: QA system -------------------------------------------------------------------------------- /configs/benchmark/webshop.yaml: -------------------------------------------------------------------------------- 1 | name: webshop 2 | task_prefix: "Instruction:\n" 3 | task_file: data/webshop/webshop.fixed100.json 4 | max_steps: 15 5 | num_fewshots: 2 6 | eval_configs: 7 | k_folds: 2 8 | k_folds_seed: 42 9 | ai_name: webshop 10 | num_tasks: 100 -------------------------------------------------------------------------------- /configs/eval.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - benchmark: hotpotqa 4 | - agent: expel 5 | - override hydra/hydra_logging: disabled 6 | - override hydra/job_logging: disabled 7 | 8 | ai_name: ${benchmark.ai_name} 9 | agent_type: ${agent.name} 10 | log_dir: logs 11 | testing: true 12 | resume: false 13 | load_cache_rules: true 14 | no_rules: false 15 | run_name: run 16 | load_run_name: ${run_name} 17 | 18 | # remove hydra output files 19 | hydra: 20 | output_subdir: null 21 | run: 22 | dir: . -------------------------------------------------------------------------------- /configs/insight_extraction.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - benchmark: hotpotqa 4 | - agent: expel 5 | - override hydra/hydra_logging: disabled 6 | - override hydra/job_logging: disabled 7 | 8 | ai_name: ${benchmark.ai_name} 9 | agent_type: ${agent.name} 10 | log_dir: logs 11 | run_name: run 12 | load_run_name: ${run_name} 13 | testing: true 14 | folded: true 15 | resume: false 16 | resume_fold: -1 17 | seed: 42 18 | 19 | # remove hydra output files 20 | hydra: 21 | output_subdir: null 22 | run: 23 | dir: . -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - benchmark: hotpotqa 4 | - agent: expel 5 | - override hydra/hydra_logging: disabled 6 | - override hydra/job_logging: disabled 7 | 8 | ai_name: ${benchmark.ai_name} 9 | agent_type: ${agent.name} 10 | log_dir: logs 11 | run_name: run 12 | testing: true 13 | resume: false 14 | 15 | # remove hydra output files 16 | hydra: 17 | output_subdir: null 18 | run: 19 | dir: . 
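20 | 21 | # Example invocations (illustrative; assumes the Hydra entry points and config groups in this repo): 22 | #   python train.py benchmark=hotpotqa agent=expel run_name=my_run 23 | #   python eval.py benchmark=alfworld agent=expel run_name=my_run load_run_name=my_run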
-------------------------------------------------------------------------------- /configs/visualize_logs.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - benchmark: hotpotqa 4 | - agent: expel 5 | - override hydra/hydra_logging: disabled 6 | - override hydra/job_logging: disabled 7 | 8 | hydra: 9 | output_subdir: null 10 | run: 11 | dir: . 12 | 13 | log_path: logs/${benchmark.name}/${agent.name} 14 | run_name: run 15 | num_tasks: 100 16 | save_fig: false 17 | -------------------------------------------------------------------------------- /data/hotpotqa/hotpot-qa-distractor-sample.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeapLabTHU/ExpeL/e41ec9a24823e7b560c561ab191441b56d9bcefc/data/hotpotqa/hotpot-qa-distractor-sample.joblib -------------------------------------------------------------------------------- /envs/__init__.py: -------------------------------------------------------------------------------- 1 | import joblib 2 | import json 3 | import random 4 | 5 | 6 | from .base import BaseEnv 7 | from .hotpotqa.hotpotqa import QAEnv 8 | from .fever.fever import FeverEnv 9 | from .alfworld.alfworld import AlfworldEnv 10 | from .webshop.webshop import WebshopEnv 11 | from utils import get_env_name_from_gamefile 12 | 13 | # Taken from ReAct GitHub 14 | idxs = list(range(7405)) 15 | random.Random(233).shuffle(idxs) 16 | 17 | INIT_TASKS_FN = dict( 18 | hotpotqa=lambda cfg: [ 19 | { 20 | 'task': f'{cfg.benchmark.task_prefix}{row["question"]}', 21 | 'env_kwargs': { 22 | 'question': row['question'], 23 | 'key': row['answer'], 24 | }, 25 | 'env_name': 'hotpotqa', 26 | } for _, row in joblib.load(cfg.benchmark.task_file).reset_index(drop=True).iterrows()], 27 | # 100 tasks for fever 28 | fever=lambda cfg: [{ 29 | 'task': cfg.benchmark.task_prefix + FeverEnv(idx).reset().replace('Claim: ', ''), 30 | 'env_kwargs': { 31 | 'idx': idx, 32 | }, 33 | 'env_name': 'fever', 34 | } for idx in idxs[:100]], 35 | alfworld=lambda cfg: [ 36 | { 37 | 'task': f'{cfg.benchmark.task_prefix}{row["goal"]}', 38 | 'env_kwargs': { 39 | 'config': cfg.benchmark, 40 | "gamefile": row["gamefile"], 41 | }, 42 | 'env_name': get_env_name_from_gamefile(row['gamefile']) 43 | } for row in json.load(open(cfg.benchmark.task_file, "r")) 44 | ], 45 | webshop=lambda cfg: [ 46 | { 47 | 'task': f'{cfg.benchmark.task_prefix}{row["task"]}', 48 | 'env_kwargs': { 49 | 'session_idx': row["session_idx"], 50 | }, 51 | 'env_name': 'webshop' 52 | } for row in json.load(open(cfg.benchmark.task_file, "r")) 53 | ], 54 | ) 55 | 56 | ENVS = dict(hotpotqa=QAEnv, fever=FeverEnv, alfworld=AlfworldEnv, webshop=WebshopEnv) 57 | -------------------------------------------------------------------------------- /envs/alfworld/alfworld.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List, Dict, Any, Tuple 3 | from envs.base import BaseEnv 4 | 5 | import alfworld.agents.environment 6 | from utils import get_env_name_from_gamefile 7 | 8 | class AlfworldEnv(BaseEnv): 9 | def __init__(self, 10 | gamefile: str, 11 | config: Dict[str, Any], 12 | max_steps: int = 50, 13 | ): 14 | self.max_steps = max_steps 15 | self.gamefile = gamefile 16 | self.config = config 17 | self.main_env = getattr(alfworld.agents.environment, self.config.env.type)(self.config, train_eval=self.config.split) 18 | self.main_env.game_files = [self.gamefile] 19 |
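# Restrict the underlying TextWorld env to this single game file so that reset() always replays the same task instance.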
self.task = "housekeeper robot. The agent was placed in a household environment and a task to complete." 20 | self.env_name = get_env_name_from_gamefile(gamefile) 21 | 22 | self.reset() 23 | 24 | def reset(self): 25 | self.curr_step = 1 26 | self.answer = '' 27 | self.terminated = False 28 | self.reward = False 29 | self.is_exhausted = False 30 | self.env = self.main_env.init_env(batch_size=1) 31 | self.env.reset() 32 | self.last_action = None 33 | 34 | def step(self, action: str) -> Tuple[str, bool, bool, bool, int]: 35 | if action.startswith('put'): 36 | pattern = r'put (\w+\s*\d+) (?:in|on) (\w+\s*\d+)' 37 | match = re.match(pattern, action) 38 | if match is not None: 39 | action = 'put ' + match.group(1) + ' in/on ' + match.group(2) 40 | 41 | observation, reward, _ = self.alfworld_run(action) 42 | observation = observation.replace(' In it, you see nothing.', '').replace(', you see nothing', '') 43 | if self.last_action == action: 44 | self.truncated = True 45 | self.terminated = True 46 | 47 | self.last_action = action 48 | 49 | if reward: 50 | observation = 'Task is SOLVED.' 51 | self.terminated = True 52 | else: 53 | if self.is_truncated(): 54 | observation = 'Max steps reached.' 55 | pass 56 | 57 | self.curr_step += 1 58 | self.terminated = self.is_terminated() 59 | self.truncated = self.is_truncated() 60 | self.reward = reward 61 | 62 | return observation, self.reward, self.terminated, self.truncated, self.curr_step 63 | 64 | def success_fn(self) -> bool: 65 | return self.reward 66 | 67 | def alfworld_run(self, action): 68 | observation, reward, done, info = self.env.step([action]) 69 | observation, reward, done = process_observation(observation[0]), info['won'][0], done[0] 70 | 71 | return observation, reward, done 72 | 73 | def process_observation(obs): 74 | if obs.startswith('You arrive at loc '): 75 | obs = obs[obs.find('. ')+2:] 76 | return obs -------------------------------------------------------------------------------- /envs/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, Any 3 | import gym 4 | 5 | class BaseEnv(gym.Env): 6 | @abstractmethod 7 | def reset(self): 8 | pass 9 | 10 | @abstractmethod 11 | def step(self, action: str, *args, **kwargs) -> Dict[str, Any]: 12 | pass 13 | 14 | @abstractmethod 15 | def success_fn(self) -> bool: 16 | pass 17 | 18 | def is_terminated(self) -> bool: 19 | return self.terminated 20 | 21 | def is_truncated(self) -> bool: 22 | return self.curr_step > self.max_steps 23 | -------------------------------------------------------------------------------- /envs/fever/fever.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | from .wikienv import WikiEnv 4 | from .wrappers import FeverWrapper 5 | from envs.base import BaseEnv 6 | from utils import parse_action 7 | 8 | 9 | class FeverEnv(BaseEnv): 10 | def __init__(self, idx, max_steps: int = 6) -> None: 11 | self.env = FeverWrapper(WikiEnv(), split="dev") 12 | self.task = """fact extract and verification. The agent was given access to a Docstore API environment and a fact to verify. 
The agent can search for pages related to the fact, lookup keywords in the pages, and finish with an answer.""" 13 | self.max_steps = max_steps 14 | self.idx = idx 15 | self.env_name = 'fever' 16 | 17 | self.reset() 18 | 19 | def reset(self): 20 | self.curr_step = 1 21 | self.question = self.env.reset(idx=self.idx).replace('Claim: ', '') 22 | self.key = self.env.data[self.idx][1] 23 | self.terminated = False 24 | return self.question 25 | 26 | def step(self, action): 27 | action_type, argument = parse_action(action) 28 | if action_type == 'Finish': 29 | self.terminated = True 30 | 31 | attempts = 0 32 | while attempts < 10: 33 | try: 34 | obs, self.reward, done, info = self.env.step(action[0].lower() + action[1:]) 35 | self.curr_step += 1 36 | terminated = self.is_terminated() 37 | truncated = self.is_truncated() 38 | return obs, self.reward, terminated, truncated, self.curr_step 39 | except requests.exceptions.Timeout: 40 | attempts += 1 41 | raise RuntimeError('FeverEnv.step: request timed out 10 times in a row') 42 | 43 | def success_fn(self) -> bool: 44 | return self.reward == 1 45 | -------------------------------------------------------------------------------- /envs/fever/wikienv.py: -------------------------------------------------------------------------------- 1 | """ 2 | Taken from ReAct GitHub 3 | """ 4 | 5 | import ast 6 | import json 7 | import time 8 | import gym 9 | import requests 10 | from bs4 import BeautifulSoup 11 | 12 | # import wikipedia 13 | 14 | def clean_str(p): 15 | return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8") 16 | 17 | 18 | class textSpace(gym.spaces.Space): 19 | def contains(self, x) -> bool: 20 | """Return boolean specifying if x is a valid member of this space.""" 21 | return isinstance(x, str) 22 | 23 | 24 | class WikiEnv(gym.Env): 25 | 26 | def __init__(self): 27 | """ 28 | Initialize the environment. 29 | """ 30 | super().__init__() 31 | self.page = None # current Wikipedia page 32 | self.obs = None # current observation 33 | self.lookup_keyword = None # current lookup keyword 34 | self.lookup_list = None # list of paragraphs containing current lookup keyword 35 | self.lookup_cnt = None # current lookup index 36 | self.steps = 0 # current number of steps 37 | self.answer = None # current answer from the agent 38 | self.observation_space = self.action_space = textSpace() 39 | self.search_time = 0 40 | self.num_searches = 0 41 | 42 | def _get_obs(self): 43 | return self.obs 44 | 45 | def _get_info(self): 46 | return {"steps": self.steps, "answer": self.answer} 47 | 48 | def reset(self, seed=None, return_info=False, options=None): 49 | # We need the following line to seed self.np_random 50 | # super().reset(seed=seed) 51 | self.obs = ("Interact with Wikipedia using search[], lookup[], and " 52 | "finish[].\n") 53 | self.page = None 54 | self.lookup_keyword = None 55 | self.lookup_list = None 56 | self.lookup_cnt = None 57 | self.steps = 0 58 | self.answer = None 59 | observation = self._get_obs() 60 | info = self._get_info() 61 | return (observation, info) if return_info else observation 62 | 63 | def construct_lookup_list(self, keyword): 64 | # find all paragraphs 65 | if self.page is None: 66 | return [] 67 | paragraphs = self.page.split("\n") 68 | paragraphs = [p.strip() for p in paragraphs if p.strip()] 69 | 70 | # find all sentences 71 | sentences = [] 72 | for p in paragraphs: 73 | sentences += p.split('. ') 74 | sentences = [s.strip() + '.'
for s in sentences if s.strip()] 75 | 76 | parts = sentences 77 | parts = [p for p in parts if keyword.lower() in p.lower()] 78 | return parts 79 | 80 | @staticmethod 81 | def get_page_obs(page): 82 | # find all paragraphs 83 | paragraphs = page.split("\n") 84 | paragraphs = [p.strip() for p in paragraphs if p.strip()] 85 | 86 | # find all sentences 87 | sentences = [] 88 | for p in paragraphs: 89 | sentences += p.split('. ') 90 | sentences = [s.strip() + '.' for s in sentences if s.strip()] 91 | return ' '.join(sentences[:5]) 92 | 93 | # ps = page.split("\n") 94 | # ret = ps[0] 95 | # for i in range(1, len(ps)): 96 | # if len((ret + ps[i]).split(" ")) <= 50: 97 | # ret += ps[i] 98 | # else: 99 | # break 100 | # return ret 101 | 102 | def search_step(self, entity): 103 | entity_ = entity.replace(" ", "+") 104 | search_url = f"https://en.wikipedia.org/w/index.php?search={entity_}" 105 | old_time = time.time() 106 | while True: 107 | try: 108 | response_text = requests.get(search_url).text 109 | break 110 | except requests.exceptions.RequestException: # retry until the search request succeeds 111 | pass 112 | self.search_time += time.time() - old_time 113 | self.num_searches += 1 114 | soup = BeautifulSoup(response_text, features="html.parser") 115 | result_divs = soup.find_all("div", {"class": "mw-search-result-heading"}) 116 | if result_divs: # no exact match; Wikipedia returned a list of similar titles 117 | self.result_titles = [clean_str(div.get_text().strip()) for div in result_divs] 118 | self.obs = f"Could not find {entity}. Similar: {self.result_titles[:5]}." 119 | else: 120 | page = [p.get_text().strip() for p in soup.find_all("p") + soup.find_all("ul")] 121 | if any("may refer to:" in p for p in page): 122 | self.search_step("[" + entity + "]") 123 | else: 124 | self.page = "" 125 | for p in page: 126 | if len(p.split(" ")) > 2: 127 | self.page += clean_str(p) 128 | if not p.endswith("\n"): 129 | self.page += "\n" 130 | self.obs = self.get_page_obs(self.page) 131 | self.lookup_keyword = self.lookup_list = self.lookup_cnt = None 132 | 133 | def step(self, action): 134 | reward = 0 135 | done = False 136 | action = action.strip() 137 | if self.answer is not None: # already finished 138 | done = True 139 | return self.obs, reward, done, self._get_info() 140 | 141 | if action.startswith("search[") and action.endswith("]"): 142 | entity = action[len("search["):-1] 143 | # entity_ = entity.replace(" ", "_") 144 | # search_url = f"https://en.wikipedia.org/wiki/{entity_}" 145 | self.search_step(entity) 146 | elif action.startswith("lookup[") and action.endswith("]"): 147 | keyword = action[len("lookup["):-1] 148 | if self.lookup_keyword != keyword: # reset lookup 149 | self.lookup_keyword = keyword 150 | self.lookup_list = self.construct_lookup_list(keyword) 151 | self.lookup_cnt = 0 152 | if self.lookup_cnt >= len(self.lookup_list): 153 | self.obs = "No more results.\n" 154 | else: 155 | self.obs = f"(Result {self.lookup_cnt + 1} / {len(self.lookup_list)}) " + self.lookup_list[self.lookup_cnt] 156 | self.lookup_cnt += 1 157 | elif action.startswith("finish[") and action.endswith("]"): 158 | answer = action[len("finish["):-1] 159 | self.answer = answer 160 | done = True 161 | self.obs = f"Episode finished, reward = {reward}\n" 162 | elif action.startswith("think[") and action.endswith("]"): 163 | self.obs = "Nice thought."
164 | else: 165 | self.obs = "Invalid action: {}".format(action) 166 | 167 | self.steps += 1 168 | 169 | return self.obs, reward, done, self._get_info() 170 | 171 | def get_time_info(self): 172 | speed = self.search_time / self.num_searches if self.num_searches else 0 173 | return { 174 | "call_speed": speed, 175 | "call_time": self.search_time, 176 | "num_calls": self.num_searches, 177 | } -------------------------------------------------------------------------------- /envs/fever/wrappers.py: -------------------------------------------------------------------------------- 1 | import json 2 | import gym 3 | import numpy as np 4 | 5 | from utils import normalize_answer 6 | 7 | DATA_DIR = "data/fever" 8 | HOTPOTQA_SPLIT_FILE = { 9 | "train": "hotpot_train_v1.1_simplified.json", 10 | "dev": "hotpot_dev_v1_simplified.json", 11 | "test": "hotpot_test_v1_simplified.json", 12 | } 13 | 14 | FEVER_SPLIT_FILE = { 15 | "train": "train.jsonl", 16 | "dev": "paper_dev.jsonl", 17 | } 18 | 19 | class FeverWrapper(gym.Wrapper): 20 | def __init__(self, env, split): 21 | super().__init__(env) 22 | 23 | data_path = f"{DATA_DIR}/{FEVER_SPLIT_FILE[split]}" 24 | with open(data_path, "r") as json_file: 25 | json_list = list(json_file) 26 | 27 | data = [] 28 | for json_str in json_list: 29 | json_str = json.loads(json_str) 30 | label = json_str["label"] 31 | claim = json_str["claim"] 32 | data.append((claim, label)) 33 | 34 | self.data = data 35 | self.data_idx = 0 36 | self.split = split 37 | 38 | def reset(self, seed=None, return_info=False, options=None, idx=None): 39 | self.env.reset(seed=seed, return_info=return_info, options=options) 40 | try: 41 | self.env.step('') 42 | except: 43 | pass 44 | self.env.reset(seed=seed, return_info=return_info, options=options) 45 | self.data_idx = int(np.random.randint(len(self.data))) if idx is None else idx 46 | observation = f"Claim: {self.data[self.data_idx][0]}" 47 | info = self._get_info() 48 | return (observation, info) if return_info else observation 49 | 50 | def _get_info(self): 51 | return { 52 | "steps": self.steps, 53 | "answer": self.answer, 54 | "question": self.data[self.data_idx][0], 55 | "fever_split": self.split 56 | } 57 | 58 | def get_reward(self, info): 59 | if info['answer'] is not None: 60 | label = normalize_answer(self.data[self.data_idx][1]) 61 | pred = normalize_answer(info['answer']) 62 | if label == pred: 63 | return 1 64 | return 0 65 | 66 | def step(self, action): 67 | obs, _, done, info = self.env.step(action) 68 | reward = self.get_reward(info) 69 | if done: 70 | obs = f"Episode finished, reward = {reward}\n" 71 | info.update({"gt_answer": self.data[self.data_idx][1], "question_idx": self.data_idx}) 72 | info.update({'em': reward, 'reward': reward, 'f1': reward}) 73 | return obs, reward, done, info 74 | 75 | def __len__(self): 76 | return len(self.data) 77 | -------------------------------------------------------------------------------- /envs/hotpotqa/hotpotqa.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | from typing import Tuple 4 | import time 5 | 6 | from langchain import Wikipedia 7 | from langchain.agents.react.base import DocstoreExplorer 8 | 9 | from envs.base import BaseEnv 10 | from utils import parse_action, EM 11 | 12 | class QAEnv(BaseEnv): 13 | def __init__(self, 14 | question: str, 15 | key: str, 16 | max_steps: int = 6, 17 | explorer: DocstoreExplorer = DocstoreExplorer(Wikipedia())): 18 | 19 | self.question = question 20 | self.key = key 21 | 
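# Note: the DocstoreExplorer default argument above is evaluated once at import time, so every QAEnv created without an explicit explorer shares the same instance.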
self.max_steps = max_steps 22 | self.explorer = explorer 23 | self.task = """multi-hop QA. The agent was given access to a Docstore API environment and a question to answer. The agent can search for pages related to the question, lookup keywords in the pages, and finish with an answer.""" 24 | self.env_name = 'hotpotqa' 25 | 26 | self.reset() 27 | 28 | def reset(self): 29 | self.curr_step = 1 30 | self.answer = '' 31 | self.terminated = False 32 | 33 | def step(self, action: str) -> Tuple[str, bool, bool, bool, int]: 34 | action_type, argument = parse_action(action) 35 | 36 | if action_type == 'Finish': 37 | self.answer = argument 38 | if self.success_fn(): 39 | observation = 'Answer is CORRECT' 40 | else: 41 | observation = 'Answer is INCORRECT' 42 | self.terminated = True 43 | elif action_type == 'Search': 44 | while True: 45 | try: 46 | observation = self.explorer.search(argument).strip('\n').strip() 47 | break 48 | except Exception as e: 49 | print(e) 50 | time.sleep(5) 51 | elif action_type == 'Lookup': 52 | try: 53 | observation = self.explorer.lookup(argument).strip('\n').strip() 54 | except ValueError: 55 | observation = 'The last page Searched was not found, so you cannot Lookup a keyword in it. Please try one of the similar pages given.' 56 | else: 57 | observation = 'Invalid Action. Valid Actions are Lookup[] Search[] and Finish[].' 58 | 59 | self.curr_step += 1 60 | self.reward = self.success_fn() 61 | self.terminated = self.is_terminated() 62 | self.truncated = self.is_truncated() 63 | 64 | return observation, self.reward, self.terminated, self.truncated, self.curr_step 65 | 66 | def success_fn(self) -> bool: 67 | return EM(self.answer, self.key) 68 | -------------------------------------------------------------------------------- /envs/webshop/webshop.py: -------------------------------------------------------------------------------- 1 | import math 2 | import requests 3 | from bs4 import BeautifulSoup 4 | from bs4.element import Comment 5 | import re 6 | import numpy as np 7 | 8 | from typing import Any, Dict, List, Tuple 9 | from envs.base import BaseEnv 10 | 11 | 12 | # URL of the running WebShop server. 13 | # The default below assumes a locally hosted instance: 14 | WEBSHOP_URL = "http://127.0.0.1:3000" 15 | 16 | ACTION_TO_TEMPLATE = { 17 | 'Description': 'description_page.html', 18 | 'Features': 'features_page.html', 19 | 'Reviews': 'review_page.html', 20 | 'Attributes': 'attributes_page.html', 21 | } 22 | 23 | def clean_str(p): 24 | return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8") 25 | 26 | def tag_visible(element): 27 | ignore = {'style', 'script', 'head', 'title', 'meta', '[document]'} 28 | return ( 29 | element.parent.name not in ignore and not isinstance(element, Comment) 30 | ) 31 | 32 | def webshop_text(session, page_type, query_string='', page_num=1, asin='', options={}, subpage='', **kwargs): 33 | if page_type == 'init': 34 | url = ( 35 | f'{WEBSHOP_URL}/{session}' 36 | ) 37 | elif page_type == 'search': 38 | url = ( 39 | f'{WEBSHOP_URL}/search_results/{session}/' 40 | f'{query_string}/{page_num}' 41 | ) 42 | elif page_type == 'item': 43 | url = ( 44 | f'{WEBSHOP_URL}/item_page/{session}/' 45 | f'{asin}/{query_string}/{page_num}/{str(options).replace("#","%23")}' # escape '#' in the URL 46 | ) 47 | elif page_type == 'item_sub': 48 | url = ( 49 | f'{WEBSHOP_URL}/item_sub_page/{session}/' 50 | f'{asin}/{query_string}/{page_num}/{subpage}/{str(options).replace("#","%23")}' # escape '#' in the URL 51 | ) 52 | elif page_type == 'end': 53 | url = ( 54 |
f'{WEBSHOP_URL}/done/{session}/' 55 | f'{asin}/{str(options).replace("#","%23")}' 56 | ) 57 | html = requests.get(url).text # type: ignore 58 | html_obj = BeautifulSoup(html, 'html.parser') 59 | texts = html_obj.findAll(text=True) 60 | visible_texts = list(filter(tag_visible, texts)) 61 | if False: 62 | # For `simple` mode, return just [SEP] separators 63 | return ' [SEP] '.join(t.strip() for t in visible_texts if t != '\n') 64 | else: 65 | # Otherwise, return an observation with tags mapped to specific, unique separators 66 | observation = '' 67 | option_type = '' 68 | options = {} 69 | asins = [] 70 | cnt = 0 71 | prod_cnt = 0 72 | just_prod = 0 73 | for t in visible_texts: 74 | if t == '\n': continue 75 | if t.replace('\n', '').replace('\\n', '').replace(' ', '') == '': continue 76 | if t.parent.name == 'button': # button 77 | processed_t = f'\n[{t}] ' 78 | elif t.parent.name == 'label': # options 79 | if f"'{t}'" in url: # type: ignore 80 | processed_t = f'[[{t}]]' 81 | else: 82 | processed_t = f'[{t}]' 83 | options[str(t)] = option_type 84 | elif t.parent.get('class') == ["product-link"]: # product asins 85 | processed_t = f'\n[{t}] ' 86 | prod_cnt += 1 87 | asins.append(str(t)) 88 | just_prod = 0 89 | else: # regular, unclickable text 90 | processed_t = '\n' + str(t) + ' ' 91 | option_type = str(t) 92 | cnt += 1 93 | just_prod += 1 94 | observation += processed_t 95 | info = {} 96 | if options: 97 | info['option_types'] = options 98 | if asins: 99 | info['asins'] = asins 100 | if 'Your score (min 0.0, max 1.0)' in visible_texts: 101 | idx = visible_texts.index('Your score (min 0.0, max 1.0)') 102 | info['reward'] = round(float(visible_texts[idx + 1]), 2) 103 | observation = 'Your score (min 0.0, max 1.0): ' + (visible_texts[idx + 1]) 104 | 105 | # Prompt harmonization 106 | observation = observation.replace('\nWebShop ', '') 107 | observation = observation.replace('\nInstruction: ', '') 108 | observation = observation.replace('[Search]\n', '[Search]') 109 | 110 | return clean_str(observation), info 111 | 112 | class WebshopEnv(BaseEnv): 113 | def __init__(self, session_idx: str, max_steps: int = 15): 114 | self.session_idx = session_idx 115 | self.max_steps = max_steps 116 | self.env_name = 'webshop' 117 | self.task = "instruction-following shopping task. The agent interacted with an online store website in order to buy the requested item based on an instruction." 118 | 119 | self.reset() 120 | 121 | 122 | def reset(self): 123 | self.curr_step = 1 124 | self.reward = False 125 | self.terminated = False 126 | self.last_action = None 127 | 128 | self.session = {'session': self.session_idx, 'page_type': 'init'} 129 | observation, info = webshop_text(**self.session) 130 | self.session.update(info) 131 | 132 | def success_fn(self) -> bool: 133 | return self.reward 134 | 135 | def step(self, action: str) -> Tuple[str, bool, bool, bool, int]: 136 | done = False 137 | observation_ = None 138 | 139 | try: 140 | if action == 'reset': 141 | self.session = {'session': self.session_idx, 'page_type': 'init'} 142 | elif action.startswith('think['): 143 | observation = 'OK.' 
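# Each click branch below validates the action against the current page type (via assert) and then mutates self.session; webshop_text(**self.session) re-renders the page from that state, and AssertionError is reported as an invalid action.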
144 | elif action.startswith('search['): 145 | assert self.session['page_type'] == 'init' 146 | query = action[7:-1] 147 | self.session = {'session': self.session_idx, 'page_type': 'search', 148 | 'query_string': query, 'page_num': 1, 'fake_page_num': 1} 149 | elif action.startswith('click['): 150 | button = action[6:-1] 151 | if button == 'Buy Now': 152 | assert self.session['page_type'] == 'item' 153 | self.session['page_type'] = 'end' 154 | self.terminated = True 155 | elif button == 'Back to Search': 156 | assert self.session['page_type'] in ['search', 'item_sub', 'item'] 157 | self.session = {'session': self.session_idx, 'page_type': 'init'} 158 | elif button == 'Next >': 159 | # assert False # ad hoc page limitation 160 | assert self.session['page_type'] == 'search' 161 | assert self.session['page_num'] < math.ceil(self.session['max_products'] / 10) # Seemed already capped at 50 162 | self.session['page_type'] = 'search' 163 | self.session['page_num'] += 1 164 | 165 | elif button == '< Prev': 166 | assert self.session['page_type'] in ['search', 'item_sub', 'item'] 167 | if self.session['page_type'] == 'search': 168 | assert self.session['page_num'] > 1 169 | self.session['page_num'] -= 1 170 | self.session['page_type'] = 'search' 171 | elif self.session['page_type'] == 'item_sub': 172 | self.session['page_type'] = 'item' 173 | elif self.session['page_type'] == 'item': 174 | self.session['page_type'] = 'search' 175 | self.session['options'] = {} 176 | elif button in ACTION_TO_TEMPLATE: 177 | assert self.session['page_type'] == 'item' 178 | self.session['page_type'] = 'item_sub' 179 | self.session['subpage'] = button 180 | else: 181 | if self.session['page_type'] == 'search': 182 | assert button in self.session.get('asins', []) # must be asins 183 | self.session['page_type'] = 'item' 184 | self.session['asin'] = button 185 | elif self.session['page_type'] == 'item': 186 | assert 'option_types' in self.session 187 | assert button in self.session['option_types'], (button, self.session['option_types']) # must be options 188 | option_type = self.session['option_types'][button] 189 | if 'options' not in self.session: 190 | self.session['options'] = {} 191 | self.session['options'][option_type] = button 192 | observation_ = f'You have clicked {button}.' 193 | else: 194 | assert False 195 | except AssertionError: 196 | observation_ = 'Invalid action!' 197 | if invalid_repeat(action=action, last_action=self.last_action): 198 | self.truncated = True 199 | self.terminated = True 200 | observation_ = 'Repeated action!' 201 | 202 | observation, info = webshop_text(**self.session) 203 | # update the max number of products for the current query after a successful search 204 | if observation_ not in ['Invalid action!', 'Repeated action!'] and action.startswith('search['): 205 | pattern = r'\(Total results: (\d+)\)' 206 | max_products = int(re.findall(pattern, observation)[0]) 207 | self.session.update({'max_products': max_products}) 208 | if observation_: 209 | observation = observation_ 210 | 211 | # as in the ReAct code, think[...] actions always observe 'OK.': 212 | if action.startswith('think['): 213 | observation = 'OK.' 214 | 215 | self.session.update(info) 216 | reward = info.get('reward', 0.0) 217 | 218 | self.curr_step += 1 219 | 220 | if self.is_truncated() and not self.is_terminated(): 221 | observation += ('\n\n' if observation != '' else '') + 'Ran out of steps!
TASK FAILED' 222 | 223 | self.reward = (reward == 1) 224 | self.last_action = action 225 | 226 | return observation, self.reward, self.is_terminated(), self.is_truncated(), self.curr_step 227 | 228 | def invalid_repeat(action: str, last_action: str) -> bool: 229 | if last_action is None: 230 | return False 231 | not_start_list = ['search[', 'think[i apologize', 'think[end', 'think[i\'m sorry', 'think[apolog'] 232 | for word in not_start_list: 233 | if action.lower().startswith(word) and last_action.lower().startswith(word): 234 | return True 235 | return False 236 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import getpass 3 | import hydra 4 | from omegaconf import DictConfig 5 | from pathlib import Path 6 | import os 7 | from copy import deepcopy 8 | from functools import partial 9 | import dotenv 10 | dotenv.load_dotenv() 11 | 12 | from agent import AGENT 13 | from prompts.templates.system import system_message_prompt 14 | from prompts.templates.human import HUMAN_CRITIQUES 15 | from prompts import ( 16 | SYSTEM_INSTRUCTION, 17 | HUMAN_INSTRUCTION, 18 | FEWSHOTS, 19 | REFLECTION_FEWSHOTS, 20 | HUMAN_REFLECTION_INSTRUCTION, 21 | SYSTEM_REFLECTION_INSTRUCTION, 22 | SYSTEM_CRITIQUE_INSTRUCTION, 23 | RULE_TEMPLATE, 24 | LLM_PARSER, 25 | OBSERVATION_FORMATTER, 26 | STEP_IDENTIFIER, 27 | CYCLER, 28 | STEP_CYCLER, 29 | REFLECTION_PREFIX, 30 | PREVIOUS_TRIALS_FORMATTER, 31 | STEP_STRIPPER, 32 | CRITIQUE_SUMMARY_SUFFIX, 33 | ) 34 | from envs import ENVS, INIT_TASKS_FN 35 | from memory import ( 36 | EMBEDDERS, 37 | RETRIEVERS, 38 | ) 39 | from models import LLM_CLS 40 | from utils import get_fewshot_max_tokens, load_trajectories_log, save_trajectories_log, split_logs_by_task, plot_trial_stats, alfworld_results_per_env_name_log, get_webshop_mean_score, get_split_eval_idx_list 41 | 42 | 43 | def get_eval_num(eval_idx: int, eval_idx_list: List[List[int]]) -> int: 44 | eval_num = 0 45 | for eval_idxs in eval_idx_list: 46 | if eval_idx in eval_idxs: 47 | break 48 | eval_num += len(eval_idxs) 49 | return eval_num + eval_idxs.index(eval_idx) 50 | 51 | 52 | @hydra.main(version_base=None, config_path="configs", config_name="eval") 53 | def main(cfg: DictConfig) -> None: 54 | if cfg.testing: 55 | openai_api_key = 'NO_KEY_FOR_TESTING' 56 | else: 57 | openai_api_key = os.environ['OPENAI_API_KEY'] if 'OPENAI_API_KEY' in os.environ else getpass.getpass("Enter or paste your OpenAI API Key: ") 58 | LOG_PATH = Path('/'.join([cfg.log_dir, cfg.benchmark.name, cfg.agent_type])) 59 | SAVE_PATH = LOG_PATH / 'eval' 60 | SAVE_PATH.mkdir(exist_ok=True) 61 | 62 | print(f"{SAVE_PATH}/{cfg.run_name}.pkl") 63 | 64 | # Overwriting confirmation 65 | if not cfg.resume and os.path.exists(f"{SAVE_PATH}/{cfg.run_name}.pkl") and cfg.run_name != 'test': 66 | while True: 67 | res = input(f"Are you sure you want to overwrite '{cfg.run_name}'?
(Y/N)\n").lower() 68 | if res == 'n': 69 | exit(0) 70 | elif res == 'y': 71 | break 72 | 73 | # Load trajectory checkpoint 74 | out = load_trajectories_log( 75 | SAVE_PATH if cfg.resume else LOG_PATH, 76 | run_name=cfg.load_run_name, 77 | load_log=cfg.resume, 78 | load_true_log=cfg.resume) 79 | dicts = out['dicts'] 80 | eval_idx_list = dicts[-1].get( 81 | 'eval_idx_list', 82 | get_split_eval_idx_list(dicts[-1], cfg.benchmark.eval_configs.k_folds)) 83 | log = out['log'] if cfg.resume else f'### EVAL ORDER ###\n{eval_idx_list}\n' 84 | true_log = out['true_log'] if cfg.resume else f'### EVAL ORDER ###\n{eval_idx_list}\n{str(cfg)}\n' 85 | 86 | num_training_tasks = len(INIT_TASKS_FN[cfg.benchmark.name](cfg)) 87 | 88 | # we start at fold 0 if we are starting a new run 89 | starting_fold = dicts[-1].get('starting_fold', 0) 90 | # we start at the first task in the fold if we are starting a new run 91 | starting_idx = dicts[-1].get('starting_idx', eval_idx_list[0][0]) 92 | 93 | react_agent = AGENT[cfg.agent_type]( 94 | name=cfg.ai_name, 95 | system_instruction=SYSTEM_INSTRUCTION[cfg.benchmark.name], 96 | human_instruction=HUMAN_INSTRUCTION[cfg.benchmark.name], 97 | tasks=INIT_TASKS_FN[cfg.benchmark.name](cfg), 98 | fewshots=FEWSHOTS[cfg.benchmark.name], 99 | system_prompt=system_message_prompt, 100 | env=ENVS[cfg.benchmark.name], 101 | max_steps=cfg.benchmark.max_steps, 102 | openai_api_key=openai_api_key, 103 | llm=cfg.agent.llm, 104 | llm_builder=LLM_CLS, 105 | reflection_fewshots=REFLECTION_FEWSHOTS[cfg.benchmark.name], 106 | reflection_task_prompt=HUMAN_REFLECTION_INSTRUCTION[cfg.benchmark.name], 107 | reflection_system_instruction=SYSTEM_REFLECTION_INSTRUCTION[cfg.benchmark.name], 108 | reflection_system_prompt=SYSTEM_INSTRUCTION[cfg.benchmark.name], 109 | max_relfection_depth=cfg.agent.max_reflection_depth if 'max_reflection_depth' in cfg.agent.keys() else 0, 110 | system_critique_instructions=SYSTEM_CRITIQUE_INSTRUCTION[cfg.benchmark.name], 111 | human_critiques=HUMAN_CRITIQUES, 112 | max_num_rules=cfg.agent.max_num_rules if 'max_num_rules' in cfg.agent.keys() else 0, 113 | rule_template=RULE_TEMPLATE[cfg.benchmark.name], 114 | truncate_strategy=cfg.agent.truncate_strategy if 'truncate_strategy' in cfg.agent.keys() else None, 115 | llm_parser=LLM_PARSER[cfg.benchmark.name], 116 | observation_formatter=OBSERVATION_FORMATTER[cfg.benchmark.name], 117 | embedder=EMBEDDERS(cfg.agent.retrieval_kwargs.embedder_type), 118 | embedder_path=cfg.agent.retrieval_kwargs.embedder_path, 119 | step_stripper=STEP_STRIPPER[cfg.benchmark.name], 120 | retriever_cls=RETRIEVERS(cfg.agent.retrieval_kwargs.retriever_type), 121 | message_splitter=CYCLER[cfg.benchmark.name], 122 | identifier=STEP_IDENTIFIER[cfg.benchmark.name], 123 | message_step_splitter=partial(STEP_CYCLER, benchmark=cfg.benchmark.name), 124 | reflection_prefix=REFLECTION_PREFIX[cfg.benchmark.name], 125 | previous_trials_formatter=PREVIOUS_TRIALS_FORMATTER[cfg.benchmark.name], 126 | success_critique_num=cfg.agent.success_critique_num, 127 | fewshot_strategy=cfg.agent.fewshot_strategy, 128 | benchmark_name=cfg.benchmark.name, 129 | reranker=cfg.agent.retrieval_kwargs.reranker, 130 | buffer_retrieve_ratio=cfg.agent.retrieval_kwargs.buffer_retrieve_ratio, 131 | critique_truncate_strategy=cfg.agent.critique_truncate_strategy, 132 | critique_summary_suffix=CRITIQUE_SUMMARY_SUFFIX, 133 | testing=cfg.testing, 134 | task_idx=starting_idx, 135 | max_fewshot_tokens=get_fewshot_max_tokens(cfg.benchmark.name) if cfg.agent.retrieval_kwargs.max_fewshot_tokens == 
'auto' else cfg.agent.retrieval_kwargs.max_fewshot_tokens, 136 | ) 137 | 138 | if len(dicts) > 0: 139 | no_load_list = ['ai_message', 'message_type_format', 'max_num_rules', 'testing', 'human_critiques', 'system_critique_instructions', 'fewshot_strategy', 'success', 'halted', 'fail', 'task_idx', 'prompt_history', 'critique_truncate_strategy', 'success_critique_num', 'reflection_fewshots', 'reflection_system_prompt', 'reflection_prefix', 'reflection_prompt_history', 'reflections', 'previous_trial', 'perform_reflection', 'increment_task', 'reflection_system_kwargs', 'prepend_human_instruction', 'name', 'tasks', 'human_instruction_kwargs', 'all_system_instruction', 'all_fewshots', 'max_steps', 'ordered_summary', 'fewshots', 'system_instruction', 'num_fewshots', 'curr_step', 'log_idx', 'pretask_idx', 'reflect_interaction_idx', 'truncated', 'reward', 'terminated', 'autoregressive_model_instruction', 'failed_training_task_idx', '_train', 'task', 140 | 'eval_idx_list', 'starting_fold', 'starting_idx', 'rule_template', 'max_fewshot_tokens', 'buffer_retrieve_ratio'] 141 | react_agent.load_checkpoint(dicts[-1], no_load_list=no_load_list) 142 | # resetting task_idx 143 | react_agent.task = react_agent.tasks[starting_idx]['task'] 144 | react_agent.reset() 145 | 146 | react_agent.eval() 147 | start_processing = False # Flag for starting_fold 148 | start_eval_idx = False # Flag for starting_idx 149 | first_flag = True 150 | react_agent.no_rules = cfg.no_rules 151 | 152 | print(f'*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*\n\nWe are using the following model: {react_agent.llm.model_name}\n\n*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*') 153 | true_log += str(react_agent.llm.llm) + '\n' 154 | first_idxs = [eval_idxs[0] for eval_idxs in eval_idx_list] 155 | 156 | # start evaluating 157 | for k, eval_idxs in enumerate(eval_idx_list): 158 | # syncing fold 159 | if k == starting_fold or start_processing: 160 | start_processing = True 161 | if not cfg.no_rules: 162 | # make sure every start of fold does create_rules 163 | if not first_flag: 164 | starting_idx = eval_idxs[0] 165 | # create rules for each fold, if starting a new fold 166 | if starting_idx == eval_idxs[0]: 167 | training_ids = set(range(num_training_tasks)) - set(eval_idxs) 168 | react_agent.create_rules( 169 | list(training_ids), 170 | cache_fold=None, 171 | load_cache_fold=k if cfg.load_cache_rules else None, 172 | ) 173 | first_flag = False 174 | 175 | # evaluate on each task in the fold 176 | for eval_idx in eval_idxs: 177 | # syncing idx 178 | if eval_idx == starting_idx or start_eval_idx: 179 | start_eval_idx = True 180 | # Skip the first matching eval_idx 181 | if eval_idx == starting_idx and starting_idx not in first_idxs: 182 | continue 183 | prefix = f"#######################################\nTASK {get_eval_num(eval_idx, eval_idx_list)} \nFOLD: {k}, EVAL_IDX: {eval_idx}\n" # the space after TASK \d+ is needed for log results parsing 184 | prefix += react_agent.remove_task_suffix(react_agent.tasks[eval_idx]['task']) + '\n' 185 | print(prefix) 186 | 187 | react_agent.run(mode='eval', eval_idx=eval_idx) 188 | 189 | # logging 190 | react_agent.update_stats() 191 | log += prefix + react_agent.log_history(include_task=False) + '\n\n' 192 | true_log += prefix + react_agent.log_history(include_all=True, include_task=False) + '\n\n' 193 | # not saving other complicated objects 194 | eval_dict = {k: deepcopy(v) for k, v in react_agent.__dict__.items() if type(v) in [list, set, str, bool, int, dict]} 195 | 
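# Record resume markers (current fold and task index) alongside the agent snapshot below so an interrupted eval run can pick up exactly where it left off.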
eval_dict.update({ 196 | 'eval_idx_list': eval_idx_list, 197 | 'starting_fold': k, 198 | 'starting_idx': eval_idx, 199 | }) 200 | dicts.append(eval_dict) 201 | save_trajectories_log( 202 | path=SAVE_PATH, 203 | log=log, 204 | dicts=dicts, 205 | true_log=true_log, 206 | run_name=f'{cfg.run_name}' 207 | ) 208 | 209 | # logging to files 210 | success, fail, halted = react_agent.get_stats() 211 | log += f"########################################\nEND TRIAL\nTrial summary: Success: {success}/{success + fail + halted}, Fail: {fail}/{success + fail + halted}, Halted: {halted}/{success + fail + halted}" 212 | true_log += f"########################################\nEND TRIAL\nTrial summary: Success: {success}/{success + fail + halted}, Fail: {fail}/{success + fail + halted}, Halted: {halted}/{success + fail + halted}" 213 | 214 | print(f'Finished. Success: {success}, Fail: {fail}, Halted: {halted}') 215 | 216 | parsed_result = split_logs_by_task(text=log, num_tasks=len(react_agent.tasks)) 217 | reflection_results = plot_trial_stats(parsed_result=parsed_result, benchmark=cfg.benchmark.name, max_trials=1, save_path=f"{LOG_PATH}/{cfg.run_name}_logs_stats.png") 218 | 219 | results = ', '.join([f"{k}: {v}" for k, v in reflection_results.items()]) + '\n' 220 | if cfg.benchmark.name == 'alfworld': 221 | results += str(alfworld_results_per_env_name_log(log, len(react_agent.tasks), 1)) 222 | elif cfg.benchmark.name == 'webshop': 223 | results += str(get_webshop_mean_score(log, len(react_agent.tasks), 1)) 224 | log += f'\n\n{results}\n########################################' 225 | true_log += f'\n\n{results}\n########################################' 226 | print(results) 227 | 228 | save_trajectories_log( 229 | path=SAVE_PATH, 230 | log=log, 231 | dicts=dicts, 232 | true_log=true_log, 233 | run_name=f'{cfg.run_name}' 234 | ) 235 | 236 | if __name__ == "__main__": 237 | main() 238 | -------------------------------------------------------------------------------- /insight_extraction.py: -------------------------------------------------------------------------------- 1 | import getpass 2 | import hydra 3 | from omegaconf import DictConfig 4 | from pathlib import Path 5 | from functools import partial 6 | import os 7 | import random 8 | 9 | from agent import AGENT 10 | from langchain.chat_models import ChatOpenAI 11 | from prompts.templates.system import system_message_prompt 12 | from prompts.templates.human import HUMAN_CRITIQUES 13 | from prompts import ( 14 | SYSTEM_INSTRUCTION, 15 | HUMAN_INSTRUCTION, 16 | FEWSHOTS, 17 | REFLECTION_FEWSHOTS, 18 | HUMAN_REFLECTION_INSTRUCTION, 19 | SYSTEM_REFLECTION_INSTRUCTION, 20 | SYSTEM_CRITIQUE_INSTRUCTION, 21 | RULE_TEMPLATE, 22 | LLM_PARSER, 23 | OBSERVATION_FORMATTER, 24 | STEP_IDENTIFIER, 25 | CYCLER, 26 | STEP_CYCLER, 27 | REFLECTION_PREFIX, 28 | PREVIOUS_TRIALS_FORMATTER, 29 | STEP_STRIPPER, 30 | CRITIQUE_SUMMARY_SUFFIX, 31 | ) 32 | from envs import ENVS, INIT_TASKS_FN 33 | from memory import ( 34 | EMBEDDERS, 35 | RETRIEVERS, 36 | ) 37 | from models import LLM_CLS 38 | from utils import save_trajectories_log, load_trajectories_log, shuffled_chunks, get_split_eval_idx_list 39 | from agent.reflect import Count 40 | 41 | from dotenv import load_dotenv 42 | load_dotenv() 43 | 44 | 45 | @hydra.main(version_base=None, config_path="configs", config_name="insight_extraction") 46 | def main(cfg : DictConfig) -> None: 47 | if cfg.testing: 48 | openai_api_key = 'NO_KEY_FOR_TESTING' 49 | else: 50 | openai_api_key = os.environ['OPENAI_API_KEY'] if 'OPENAI_API_KEY' in 
os.environ else getpass.getpass("Enter or paste your OpenAI API Key: ") 51 | LOG_PATH = Path('/'.join([cfg.log_dir, cfg.benchmark.name, cfg.agent_type])) 52 | SAVE_PATH = LOG_PATH / 'extracted_insights' 53 | SAVE_PATH.mkdir(exist_ok=True) 54 | 55 | # Overwriting confirmation 56 | if not cfg.resume and os.path.exists(f"{SAVE_PATH}/{cfg.run_name}.pkl") and cfg.run_name != 'test': 57 | while True: 58 | res = input(f"Are you sure you want to overwrite '{cfg.run_name}'? (Y/N)\n").lower() 59 | if res == 'n': 60 | exit(0) 61 | elif res == 'y': 62 | break 63 | if cfg.resume and cfg.resume_fold < 0: 64 | print('Specify a fold to resume when resuming a run! (resume_fold=X)') 65 | exit(1) 66 | out = load_trajectories_log( 67 | SAVE_PATH / f"fold_{cfg.resume_fold}" if cfg.resume_fold > -1 else LOG_PATH, 68 | run_name=cfg.load_run_name, 69 | load_log=cfg.resume, 70 | load_true_log=False 71 | ) 72 | dicts = out['dicts'] 73 | log = out['log'] if cfg.resume else '' 74 | 75 | cfg.folded = True 76 | react_agent = AGENT[cfg.agent_type]( 77 | name=cfg.ai_name, 78 | system_instruction=SYSTEM_INSTRUCTION[cfg.benchmark.name], 79 | human_instruction=HUMAN_INSTRUCTION[cfg.benchmark.name], 80 | tasks=INIT_TASKS_FN[cfg.benchmark.name](cfg), 81 | fewshots=FEWSHOTS[cfg.benchmark.name], 82 | system_prompt=system_message_prompt, 83 | env=ENVS[cfg.benchmark.name], 84 | max_steps=cfg.benchmark.max_steps, 85 | openai_api_key=openai_api_key, 86 | llm=cfg.agent.llm, 87 | llm_builder=LLM_CLS, 88 | reflection_fewshots=REFLECTION_FEWSHOTS[cfg.benchmark.name], 89 | reflection_task_prompt=HUMAN_REFLECTION_INSTRUCTION[cfg.benchmark.name], 90 | reflection_system_instruction=SYSTEM_REFLECTION_INSTRUCTION[cfg.benchmark.name], 91 | reflection_system_prompt=SYSTEM_INSTRUCTION[cfg.benchmark.name], 92 | max_relfection_depth=cfg.agent.max_reflection_depth if 'max_reflection_depth' in cfg.agent.keys() else 0, 93 | system_critique_instructions=SYSTEM_CRITIQUE_INSTRUCTION[cfg.benchmark.name], 94 | human_critiques=HUMAN_CRITIQUES, 95 | max_num_rules=cfg.agent.max_num_rules if 'max_num_rules' in cfg.agent.keys() else 0, 96 | rule_template=RULE_TEMPLATE[cfg.benchmark.name], 97 | truncate_strategy=cfg.agent.truncate_strategy if 'truncate_strategy' in cfg.agent.keys() else None, 98 | llm_parser=LLM_PARSER[cfg.benchmark.name], 99 | observation_formatter=OBSERVATION_FORMATTER[cfg.benchmark.name], 100 | embedder=EMBEDDERS(cfg.agent.retrieval_kwargs.embedder_type), 101 | embedder_path=cfg.agent.retrieval_kwargs.embedder_path, 102 | step_stripper=STEP_STRIPPER[cfg.benchmark.name], 103 | retriever_cls=RETRIEVERS(cfg.agent.retrieval_kwargs.retriever_type), 104 | message_splitter=CYCLER[cfg.benchmark.name], 105 | identifier=STEP_IDENTIFIER[cfg.benchmark.name], 106 | message_step_splitter=partial(STEP_CYCLER, benchmark=cfg.benchmark.name), 107 | reflection_prefix=REFLECTION_PREFIX[cfg.benchmark.name], 108 | previous_trials_formatter=PREVIOUS_TRIALS_FORMATTER[cfg.benchmark.name], 109 | success_critique_num=cfg.agent.success_critique_num, 110 | fewshot_strategy=cfg.agent.fewshot_strategy, 111 | benchmark_name=cfg.benchmark.name, 112 | reranker=cfg.agent.retrieval_kwargs.reranker, 113 | buffer_retrieve_ratio=cfg.agent.retrieval_kwargs.buffer_retrieve_ratio, 114 | critique_truncate_strategy=cfg.agent.critique_truncate_strategy, 115 | critique_summary_suffix=CRITIQUE_SUMMARY_SUFFIX, 116 | testing=cfg.testing, 117 | max_fewshot_tokens=cfg.agent.retrieval_kwargs.max_fewshot_tokens, 118 | ) 119 | 120 | print(f'Loading agent from {LOG_PATH}') 121 | no_load_list =
['ai_message', 'message_type_format', 'max_num_rules', 'testing', 'human_critiques', 'system_critique_instructions', 'fewshot_strategy', 'success', 'halted', 'fail', 'task_idx', 'prompt_history', 'critique_truncate_strategy', 'success_critique_num', 'reflection_fewshots', 'reflection_system_prompt', 'reflection_prefix', 'reflection_prompt_history', 'reflections', 'previous_trial', 'perform_reflection', 'increment_task', 'reflection_system_kwargs', 'prepend_human_instruction', 'name', 'tasks', 'human_instruction_kwargs', 'all_system_instruction', 'all_fewshots', 'max_steps', 'ordered_summary', 'fewshots', 'system_instruction', 'num_fewshots', 'curr_step', 'log_idx', 'pretask_idx', 'reflect_interaction_idx', 'truncated', 'reward', 'terminated', 'autoregressive_model_instruction', 'failed_training_task_idx', '_train', 'task', 'eval_idx_list', 'starting_fold', 'starting_idx', 'critique_summary_suffix'] 122 | react_agent.load_checkpoint(dicts[-1], no_load_list=no_load_list) 123 | 124 | random.seed(cfg.seed) 125 | num_training_tasks = len(INIT_TASKS_FN[cfg.benchmark.name](cfg)) 126 | if not cfg.resume: 127 | resume = False 128 | else: 129 | resume = 'eval_idx_list' in dicts[-1] 130 | eval_idx_list = dicts[-1].get('eval_idx_list', get_split_eval_idx_list(dicts[-1], cfg.benchmark.eval_configs.k_folds)) 131 | 132 | print(f'eval_idx_list: {eval_idx_list}') 133 | starting_fold = dicts[-1]['starting_fold'] = dicts[-1].get('critique_summary_fold', 0) 134 | 135 | resume_starting_fold = starting_fold 136 | critique_summary_log = dicts[-1].get('critique_summary_log', '') 137 | 138 | for k, eval_idxs in enumerate(eval_idx_list): 139 | if k < starting_fold: 140 | continue 141 | training_ids = set(range(num_training_tasks)) - set(eval_idxs) 142 | (SAVE_PATH / f"fold_{k}").mkdir(exist_ok=True) 143 | log += f'################## FOLD {k} ##################\n' 144 | log += react_agent.create_rules( 145 | list(training_ids), 146 | cache_fold=k, 147 | logging_dir=str(SAVE_PATH / f"fold_{k}"), 148 | run_name=cfg.run_name, 149 | loaded_dict=dicts[-1] if resume and resume_starting_fold == starting_fold else None, 150 | loaded_log=critique_summary_log if resume and resume_starting_fold == starting_fold else None, 151 | eval_idx_list=eval_idx_list, 152 | saving_dict=True, 153 | ) 154 | starting_fold += 1 155 | 156 | save_dict = {k: v for k, v in react_agent.__dict__.items() if type(v) in [list, set, str, bool, int, dict, Count] and k not in ['openai_api_key', 'llm']} 157 | if cfg.folded: 158 | save_dict['eval_idx_list'] = eval_idx_list 159 | dicts.append(save_dict) 160 | save_trajectories_log( 161 | path=SAVE_PATH, 162 | log=log, 163 | dicts=dicts, 164 | run_name=cfg.run_name, 165 | save_true_log=False 166 | ) 167 | 168 | if __name__ == "__main__": 169 | main() 170 | -------------------------------------------------------------------------------- /memory/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Any, Dict, List, Optional 2 | 3 | from pydantic import BaseModel, Extra, Field 4 | from langchain.retrievers import SVMRetriever, KNNRetriever 5 | from langchain.embeddings import OpenAIEmbeddings, LlamaCppEmbeddings, HuggingFaceEmbeddings 6 | 7 | from langchain.embeddings.base import Embeddings 8 | 9 | from .episode import Trajectory 10 | 11 | # COPIED FROM THE HUGGINGFACE EMBEDDING CLASS 12 | class GPT4ALLEmbeddings(BaseModel, Embeddings): 13 | """Wrapper around the GPT4All Embed4All embedding model.
14 | 15 | To use, you should have the ``gpt4all`` python package installed. 16 | 17 | Example: 18 | .. code-block:: python 19 | 20 | from memory import GPT4ALLEmbeddings 21 | 22 | embedder = GPT4ALLEmbeddings() 23 | doc_embeddings = embedder.embed_documents( 24 | ["first document", "second document"] 25 | ) 26 | query_embedding = embedder.embed_query( 27 | "what does the first document say?" 28 | ) 29 | 30 | """ 31 | 32 | client: Any #: :meta private: 33 | model_name: str = 'gpt4all' 34 | """Model name to use.""" 35 | cache_folder: Optional[str] = None 36 | """Path to store models. 37 | Unused by GPT4All; retained for interface compatibility with HuggingFaceEmbeddings.""" 38 | model_kwargs: Dict[str, Any] = Field(default_factory=dict) 39 | """Key word arguments to pass to the model.""" 40 | encode_kwargs: Dict[str, Any] = Field(default_factory=dict) 41 | """Key word arguments to pass when calling the `encode` method of the model.""" 42 | 43 | def __init__(self, **kwargs: Any): 44 | """Initialize the GPT4All embedder.""" 45 | super().__init__(**kwargs) 46 | try: 47 | from gpt4all import Embed4All 48 | 49 | except ImportError as exc: 50 | raise ImportError( 51 | "Could not import gpt4all python package. " 52 | "Please install it with `pip install gpt4all`." 53 | ) from exc 54 | 55 | self.client = Embed4All() 56 | 57 | class Config: 58 | """Configuration for this pydantic object.""" 59 | 60 | extra = Extra.forbid 61 | 62 | def embed_documents(self, texts: List[str]) -> List[List[float]]: 63 | """Compute document embeddings using the GPT4All Embed4All model. 64 | 65 | Args: 66 | texts: The list of texts to embed. 67 | 68 | Returns: 69 | List of embeddings, one for each text. 70 | """ 71 | texts = list(map(lambda x: x.replace("\n", " "), texts)) 72 | embeddings = [self.client.embed(text) for text in texts] 73 | return embeddings 74 | 75 | def embed_query(self, text: str) -> List[float]: 76 | """Compute query embeddings using the GPT4All Embed4All model. 77 | 78 | Args: 79 | text: The text to embed. 80 | 81 | Returns: 82 | Embeddings for the text.
83 | """ 84 | text = text.replace("\n", " ") 85 | embedding = self.client.embed(text) 86 | return embedding 87 | 88 | 89 | def choose_embedder(key: str) -> Callable: 90 | if key == 'openai': 91 | return OpenAIEmbeddings 92 | if key == 'llama': 93 | return LlamaCppEmbeddings 94 | if key == 'gpt4all': 95 | return GPT4ALLEmbeddings 96 | return HuggingFaceEmbeddings 97 | 98 | def choose_retriever(key: str) -> Callable: 99 | if key == 'knn': 100 | return KNNRetriever 101 | if key == 'svm': 102 | return SVMRetriever 103 | return KNNRetriever 104 | 105 | EMBEDDERS = choose_embedder 106 | RETRIEVERS = choose_retriever -------------------------------------------------------------------------------- /memory/episode.py: -------------------------------------------------------------------------------- 1 | from typing import List, Callable, Dict, Optional 2 | from copy import deepcopy 3 | 4 | class Trajectory: 5 | def __init__( 6 | self, 7 | task: str, 8 | trajectory: str, 9 | splitter: Callable, 10 | identifier: Callable, 11 | step_splitter: Callable, 12 | embedder: Optional[Callable] = None, 13 | reflections: List[str] = None 14 | ): 15 | self._task = task 16 | self._trajectory = trajectory 17 | self._reflections = deepcopy(reflections) 18 | self._observations, self._actions, self._thoughts = [], [], [] 19 | for line in splitter(self._trajectory): 20 | setattr(self, f'_{identifier(line)}s', getattr(self, f'_{identifier(line)}s') + [line]) 21 | self._steps = step_splitter(lines=trajectory, cycler=splitter, step_identifier=identifier) 22 | self._keys = {'thought': [], 'step': []} 23 | if embedder is not None: 24 | self._keys['task'] = [embedder(self.task)] 25 | for step in self.steps: 26 | self._keys['step'].append(embedder(step)) 27 | for thought in self.thoughts: 28 | self._keys['thought'].append(embedder(thought)) 29 | 30 | @property 31 | def task(self) -> str: 32 | return self._task 33 | 34 | @property 35 | def steps(self) -> List[str]: 36 | return self._steps 37 | 38 | @property 39 | def trajectory(self) -> str: 40 | return self._trajectory 41 | 42 | @property 43 | def num_steps(self) -> int: 44 | return max(len(self.thoughts), len(self.actions), len(self.observations)) 45 | 46 | @property 47 | def observations(self) -> List[str]: 48 | return self._observations 49 | 50 | @property 51 | def actions(self) -> List[str]: 52 | return self._actions 53 | 54 | @property 55 | def thoughts(self) -> List[str]: 56 | return self._thoughts 57 | 58 | @property 59 | def reflections(self) -> List[str]: 60 | return self._reflections 61 | 62 | @property 63 | def keys(self) -> Dict[str, List[float]]: 64 | return self._keys 65 | 66 | def _replace(self, attr, value): 67 | self.__setattr__(self, attr, value) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .llm import LLM_CLS -------------------------------------------------------------------------------- /models/llm.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List 2 | import time 3 | 4 | from langchain.chat_models import ChatOpenAI 5 | from langchain.schema import ChatMessage 6 | import openai 7 | 8 | 9 | class GPTWrapper: 10 | def __init__(self, llm_name: str, openai_api_key: str, long_ver: bool): 11 | self.model_name = llm_name 12 | if long_ver: 13 | llm_name = 'gpt-3.5-turbo-16k' 14 | self.llm = ChatOpenAI( 15 | model=llm_name, 16 | temperature=0.0, 17 | 
openai_api_key=openai_api_key, 18 | ) 19 | 20 | def __call__(self, messages: List[ChatMessage], stop: List[str] = [], replace_newline: bool = True) -> str: 21 | kwargs = {} 22 | if stop != []: 23 | kwargs['stop'] = stop 24 | for i in range(6): 25 | try: 26 | output = self.llm(messages, **kwargs).content.strip('\n').strip() 27 | break 28 | except openai.error.RateLimitError: 29 | print(f'\nRetrying {i}...') 30 | time.sleep(1) 31 | else: 32 | raise RuntimeError('Failed to generate response') 33 | 34 | if replace_newline: 35 | output = output.replace('\n', '') 36 | return output 37 | 38 | def LLM_CLS(llm_name: str, openai_api_key: str, long_ver: bool) -> Callable: 39 | if 'gpt' in llm_name: 40 | return GPTWrapper(llm_name, openai_api_key, long_ver) 41 | else: 42 | raise ValueError(f"Unknown LLM model name: {llm_name}") 43 | -------------------------------------------------------------------------------- /prompts/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List 2 | 3 | from . import alfworld 4 | from . import hotpotQA, fever, webshop 5 | from .templates.human import * 6 | from .templates.system import * 7 | 8 | FEWSHOTS = dict( 9 | hotpotqa=hotpotQA.FEWSHOTS, 10 | fever=fever.FEWSHOTS, 11 | alfworld=alfworld.FEWSHOTS, 12 | webshop=webshop.FEWSHOTS, 13 | ) 14 | REFLECTION_FEWSHOTS = dict( 15 | hotpotqa=hotpotQA.REFLECTION_FEWSHOTS, 16 | fever=None,#fever.REFLECTION_FEWSHOTS, 17 | alfworld=alfworld.REFLECTION_FEWSHOTS, 18 | webshop=webshop.REFLECTION_FEWSHOTS, 19 | ) 20 | SYSTEM_INSTRUCTION = dict( 21 | hotpotqa=hotpotQA.SYSTEM_INSTRUCTION, 22 | fever=fever.SYSTEM_INSTRUCTION, 23 | alfworld=alfworld.SYSTEM_INSTRUCTION, 24 | webshop=webshop.SYSTEM_INSTRUCTION, 25 | ) 26 | SYSTEM_REFLECTION_INSTRUCTION = dict( 27 | hotpotqa=hotpotQA.SYSTEM_REFLECTION_INSTRUCTION, 28 | fever=None,#fever.SYSTEM_REFLECTION_INSTRUCTION, 29 | alfworld=None,#alfworld.SYSTEM_REFLECTION_INSTRUCTION, 30 | webshop=None,#webshop.SYSTEM_REFLECTION_INSTRUCTION, 31 | ) 32 | HUMAN_INSTRUCTION = dict( 33 | hotpotqa=hotpotQA.HUMAN_INSTRUCTION, 34 | fever=fever.HUMAN_INSTRUCTION, 35 | alfworld=alfworld.HUMAN_INSTRUCTION, 36 | webshop=webshop.HUMAN_INSTRUCTION, 37 | ) 38 | HUMAN_REFLECTION_INSTRUCTION = dict( 39 | hotpotqa=hotpotQA.HUMAN_REFLECTION_INSTRUCTION, 40 | fever=None, 41 | alfworld=alfworld.HUMAN_REFLECTION_INSTRUCTION, 42 | webshop=webshop.HUMAN_REFLECTION_INSTRUCTION, 43 | ) 44 | SYSTEM_CRITIQUE_INSTRUCTION = dict( 45 | hotpotqa=dict( 46 | compare_existing_rules=hotpotQA.SYSTEM_CRITIQUE_EXISTING_RULES_INSTRUCTION, 47 | all_success_existing_rules=hotpotQA.SYSTEM_CRITIQUE_ALL_SUCCESS_EXISTING_RULES_INSTRUCTION, 48 | ), 49 | fever=dict( 50 | compare_existing_rules=None, 51 | all_success_existing_rules=None 52 | ), 53 | alfworld=dict( 54 | compare_existing_rules=alfworld.SYSTEM_CRITIQUE_EXISTING_RULES_INSTRUCTION, 55 | all_success_existing_rules=alfworld.SYSTEM_CRITIQUE_ALL_SUCCESS_EXISTING_RULES_INSTRUCTION, 56 | ), 57 | webshop=dict( 58 | compare_existing_rules=webshop.SYSTEM_CRITIQUE_EXISTING_RULES_INSTRUCTION, 59 | all_success_existing_rules=webshop.SYSTEM_CRITIQUE_ALL_SUCCESS_EXISTING_RULES_INSTRUCTION, 60 | ), 61 | ) 62 | 63 | LLM_PARSER = dict( 64 | hotpotqa=hotpotQA.LLM_PARSER, 65 | # fever and hotpotQA have the same format 66 | fever=hotpotQA.LLM_PARSER, 67 | alfworld=alfworld.LLM_PARSER, 68 | webshop=webshop.LLM_PARSER, 69 | ) 70 | 71 | OBSERVATION_FORMATTER = dict( 72 | hotpotqa=hotpotQA.OBSERVATION_FORMATTER, 73 | # fever and hotpotQA
74 |     fever=hotpotQA.OBSERVATION_FORMATTER,
75 |     alfworld=alfworld.OBSERVATION_FORMATTER,
76 |     webshop=webshop.OBSERVATION_FORMATTER,
77 | )
78 | 
79 | STEP_IDENTIFIER = dict(
80 |     hotpotqa=hotpotQA.STEP_IDENTIFIER,
81 |     # fever and hotpotQA have the same format
82 |     fever=hotpotQA.STEP_IDENTIFIER,
83 |     webshop=webshop.STEP_IDENTIFIER,
84 |     alfworld=alfworld.STEP_IDENTIFIER,
85 | )
86 | 
87 | CYCLER = dict(
88 |     hotpotqa=hotpotQA.CYCLER,
89 |     fever=hotpotQA.CYCLER,
90 |     # fever and hotpotQA have the same format
91 |     webshop=webshop.CYCLER,
92 |     alfworld=alfworld.CYCLER,
93 | )
94 | REFLECTION_PREFIX = dict(
95 |     hotpotqa=hotpotQA.REFLECTION_PREFIX,
96 |     fever=hotpotQA.REFLECTION_PREFIX,
97 |     alfworld=alfworld.REFLECTION_PREFIX,
98 |     # same format as alfworld
99 |     webshop=webshop.REFLECTION_PREFIX,
100 | )
101 | PREVIOUS_TRIALS_FORMATTER=dict(
102 |     hotpotqa=hotpotQA.PREVIOUS_TRIALS_FORMATTER,
103 |     fever=hotpotQA.PREVIOUS_TRIALS_FORMATTER,
104 |     alfworld=alfworld.PREVIOUS_TRIALS_FORMATTER,
105 |     # same format as alfworld
106 |     webshop=webshop.PREVIOUS_TRIALS_FORMATTER,
107 | )
108 | 
109 | STEP_STRIPPER = dict(
110 |     hotpotqa=hotpotQA.STEP_STRIPPER,
111 |     fever=hotpotQA.STEP_STRIPPER,
112 |     alfworld=alfworld.STEP_STRIPPER,
113 |     # same format as alfworld
114 |     webshop=webshop.STEP_STRIPPER,
115 | )
116 | 
117 | def STEP_CYCLER(benchmark: str, lines: str, cycler: Callable, step_identifier: Callable, stripper: Callable = lambda x, y: x) -> List[str]:
118 |     steps = []
119 |     scratch_pad = ''
120 |     for line in cycler(lines):
121 |         step_type = step_identifier(line)
122 |         stripped_line = stripper(line, step_type)
123 |         scratch_pad += stripped_line + '\n'
124 |         if step_type == 'observation':
125 |             steps.append(scratch_pad.strip())
126 |             scratch_pad = ''
127 |     if scratch_pad != '':
128 |         steps.append(scratch_pad.strip())
129 |     return steps
130 | 
--------------------------------------------------------------------------------
/prompts/fever.py:
--------------------------------------------------------------------------------
1 | from langchain.prompts.chat import HumanMessagePromptTemplate
2 | 
3 | FEWSHOTS = ["""Claim: Nikolaj Coster-Waldau worked with the Fox Broadcasting Company.
4 | Thought 1: I need to search Nikolaj Coster-Waldau and find if he has worked with the Fox Broadcasting Company.
5 | Action 1: Search[Nikolaj Coster-Waldau]
6 | Observation 1: Nikolaj William Coster-Waldau (born 27 July 1970) is a Danish actor and producer. He graduated from the Danish National School of Performing Arts in Copenhagen in 1993,[1] and had his breakthrough role in Denmark with the film Nightwatch (1994). He played Jaime Lannister in the HBO fantasy drama series Game of Thrones, for which he received two Primetime Emmy Award nominations for Outstanding Supporting Actor in a Drama Series.. Coster-Waldau has appeared in numerous films in his native Denmark and Scandinavia, including Headhunters (2011) and A Thousand Times Good Night (2013). In the U.S, his debut film role was in the war film Black Hawk Down (2001), playing Medal of Honor recipient Gary Gordon.[2] He then played a detective in the short-lived Fox television series New Amsterdam (2008), and appeared in the 2009 Fox television film Virtuality, originally intended as a pilot.
7 | Thought 2: Because he "appeared in the 2009 Fox television film Virtuality", he should have worked with the Fox Broadcasting Company.
8 | Action 2: Finish[SUPPORTS]""",
9 | """Claim: Stranger Things is set in Bloomington, Indiana.
10 | Thought 1: I should search for Stranger Things, and see if it is set in Bloomington, Indiana.
11 | Action 1: Search[Stranger Things]
12 | Observation 1: Stranger Things is an American science fiction horror drama television series created by the Duffer Brothers. Set in the 1980s, primarily in the fictional town of Hawkins, Indiana, the series centers on a number of mysteries and supernatural events occurring around the town and their impact on an ensemble of child and adult characters.
13 | Thought 2: The observation says that it is set in a "fictional town of Hawkins, Indiana", so it is not set in Bloomington.
14 | Action 2: Finish[REFUTES]""",
15 | """Claim: Beautiful reached number two on the Billboard Hot 100 in 2003.
16 | Thought 1: I need to search the song Beautiful and find if it reached number two on the Billboard Hot 100 in 2003.
17 | Action 1: Search[Beautiful]
18 | Observation 1: Could not find [Beautiful]. Similar: ['Beautiful', 'Beautiful, Beautiful', 'A Beautiful Mind (film)', 'Beautiful (Christina Aguilera song)', 'Life Is Beautiful'].
19 | Thought 2: From suggestions, I should search "Beautiful (Christina Aguilera song)" to find the song.
20 | Action 2: Search[Beautiful (Christina Aguilera song)]
21 | Observation 2: "Beautiful" is a song recorded by American singer Christina Aguilera for her fourth studio album, Stripped (2002).
22 | Thought 3: It does not mention Billboard, so I need to look up "Billboard Hot 100" to find if it reached number two on it in 2003.
23 | Action 3: Lookup[Billboard Hot 100]
24 | Observation 3: (Result 1 / 3) The song peaked at number two on the Billboard Hot 100 in the United States, where it was certified Gold for 500,000 units shipped.
25 | Thought 4: It only says the song peaked at number two on the Billboard Hot 100, but not if it was in 2003. I am not sure if this claim is true or not.
26 | Action 4: Finish[NOT ENOUGH INFO]"""]
27 | 
28 | SYSTEM_INSTRUCTION = """Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types:
29 | (1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search.
30 | (2) Lookup[keyword], which returns the next sentence containing keyword in the last passage successfully found by Search.
31 | (3) Finish[answer], which returns the answer and finishes the task; answer should be one of REFUTES, SUPPORTS or NOT ENOUGH INFO.
32 | """
33 | 
34 | human_instruction_template = """{instruction}You may take a maximum of {max_steps} steps.
35 | Here are some examples:"""
36 | 
37 | HUMAN_INSTRUCTION = HumanMessagePromptTemplate.from_template(human_instruction_template)
--------------------------------------------------------------------------------
/prompts/hotpotQA.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Dict, Any, List
2 | import pickle
3 | import re
4 | 
5 | from langchain.prompts.chat import HumanMessagePromptTemplate
6 | from langchain.schema import (
7 |     AIMessage,
8 |     HumanMessage,
9 |     ChatMessage,
10 | )
11 | 
12 | FEWSHOTS = ["""Question: What is the elevation range for the area that the eastern sector of the Colorado orogeny extends into?
13 | Thought 1: I need to search Colorado orogeny, find the area that the eastern sector of the Colorado orogeny extends into, then find the elevation range of the area.
14 | Action 1: Search[Colorado orogeny] 15 | Observation 1: The Colorado orogeny was an episode of mountain building (an orogeny) in Colorado and surrounding areas. 16 | Thought 2: It does not mention the eastern sector. So I need to look up eastern sector. 17 | Action 2: Lookup[eastern sector] 18 | Observation 2: (Result 1 / 1) The eastern sector extends into the High Plains and is called the Central Plains orogeny. 19 | Thought 3: The eastern sector of Colorado orogeny extends into the High Plains. So I need to search High Plains and find its elevation range. 20 | Action 3: Search[High Plains] 21 | Observation 3: High Plains refers to one of two distinct land regions: 22 | Thought 4: I need to instead search High Plains (United States). 23 | Action 4: Search[High Plains (United States)] 24 | Observation 4: The High Plains are a subregion of the Great Plains. From east to west, the High Plains rise in elevation from around 1,800 to 7,000 ft (550 to 2,130 m).[3] 25 | Thought 5: High Plains rise in elevation from around 1,800 to 7,000 ft, so the answer is 1,800 to 7,000 ft. 26 | Action 5: Finish[1,800 to 7,000 ft]""", 27 | """Question: Musician and satirist Allie Goertz wrote a song about the "The Simpsons" character Milhouse, who Matt Groening named after who? 28 | Thought 1: The question simplifies to "The Simpsons" character Milhouse is named after who. I only need to search Milhouse and find who it is named after. 29 | Action 1: Search[Milhouse] 30 | Observation 1: Milhouse Mussolini Van Houten is a recurring character in the Fox animated television series The Simpsons voiced by Pamela Hayden and created by Matt Groening. 31 | Thought 2: The paragraph does not tell who Milhouse is named after, maybe I can look up "named after". 32 | Action 2: Lookup[named after] 33 | Observation 2: (Result 1 / 1) Milhouse was named after U.S. president Richard Nixon, whose middle name was Milhous. 34 | Thought 3: Milhouse was named after U.S. president Richard Nixon, so the answer is Richard Nixon. 35 | Action 3: Finish[Richard Nixon]""", 36 | """Question: Which documentary is about Finnish rock groups, Adam Clayton Powell or The Saimaa Gesture? 37 | Thought 1: I need to search Adam Clayton Powell and The Saimaa Gesture, and find which documentary is about Finnish rock groups. 38 | Action 1: Search[Adam Clayton Powell] 39 | Observation 1: Could not find [Adam Clayton Powell]. Similar: ['Adam Clayton Powell III', 'Seventh Avenue (Manhattan)', 'Adam Clayton Powell Jr. State Office Building', 'Isabel Washington Powell', 'Adam Powell', 'Adam Clayton Powell (film)', 'Giancarlo Esposito']. 40 | Thought 2: To find the documentary, I can search Adam Clayton Powell (film). 41 | Action 2: Search[Adam Clayton Powell (film)] 42 | Observation 2: Adam Clayton Powell is a 1989 American documentary film directed by Richard Kilberg. 43 | The film is about the rise and fall of influential African-American politician Adam Clayton Powell Jr.[3][4] It was later aired as part of the PBS series The American Experience. 44 | Thought 3: Adam Clayton Powell (film) is a documentary about an African-American politician, not Finnish rock groups. So the documentary about Finnish rock groups must instead be The Saimaa Gesture. 45 | Action 3: Finish[The Saimaa Gesture]""", 46 | """Question: What profession does Nicholas Ray and Elia Kazan have in common? 47 | Thought 1: I need to search Nicholas Ray and Elia Kazan, find their professions, then find the profession they have in common. 
48 | Action 1: Search[Nicholas Ray]
49 | Observation 1: Nicholas Ray (born Raymond Nicholas Kienzle Jr., August 7, 1911 – June 16, 1979) was an American film director, screenwriter, and actor best known for the 1955 film Rebel Without a Cause.
50 | Thought 2: Professions of Nicholas Ray are director, screenwriter, and actor. I need to search Elia Kazan next and find his professions.
51 | Action 2: Search[Elia Kazan]
52 | Observation 2: Elia Kazan was an American film and theatre director, producer, screenwriter and actor.
53 | Thought 3: Professions of Elia Kazan are director, producer, screenwriter, and actor. So profession Nicholas Ray and Elia Kazan have in common is director, screenwriter, and actor.
54 | Action 3: Finish[director, screenwriter, actor]""",
55 | """Question: Which magazine was started first Arthur's Magazine or First for Women?
56 | Thought 1: I need to search Arthur's Magazine and First for Women, and find which was started first.
57 | Action 1: Search[Arthur's Magazine]
58 | Observation 1: Arthur's Magazine (1844–1846) was an American literary periodical published in Philadelphia in the 19th century.
59 | Thought 2: Arthur's Magazine was started in 1844. I need to search First for Women next.
60 | Action 2: Search[First for Women]
61 | Observation 2: First for Women is a woman's magazine published by Bauer Media Group in the USA.[1] The magazine was started in 1989.
62 | Thought 3: First for Women was started in 1989. 1844 (Arthur's Magazine) < 1989 (First for Women), so Arthur's Magazine was started first.
63 | Action 3: Finish[Arthur's Magazine]""",
64 | """Question: Were Pavel Urysohn and Leonid Levin known for the same type of work?
65 | Thought 1: I need to search Pavel Urysohn and Leonid Levin, find their types of work, then find if they are the same.
66 | Action 1: Search[Pavel Urysohn]
67 | Observation 1: Pavel Samuilovich Urysohn (February 3, 1898 – August 17, 1924) was a Soviet mathematician who is best known for his contributions in dimension theory.
68 | Thought 2: Pavel Urysohn is a mathematician. I need to search Leonid Levin next and find its type of work.
69 | Action 2: Search[Leonid Levin]
70 | Observation 2: Leonid Anatolievich Levin is a Soviet-American mathematician and computer scientist.
71 | Thought 3: Leonid Levin is a mathematician and computer scientist. So Pavel Urysohn and Leonid Levin have the same type of work.
72 | Action 3: Finish[yes]"""]
73 | 
74 | REFLECTION_FEWSHOTS = [
75 | """
76 | Previous Trial:
77 | Question: The Rome Protocols were signed by three Prime Ministers one of which was assassinated as part of what?
78 | Thought 1: I need to search Rome Protocols, find the three Prime Ministers, then find what they were assassinated as part of.
79 | Action 1: Search[Rome Protocols]
80 | Observation 1: The Rome Protocols were a series of three international agreements concluded in Rome on 17 March 1934 between the governments of Austria, Hungary and Italy. They were signed by Italian Prime Minister Benito Mussolini, Austrian Chancellor Engelbert Dollfuss and Hungarian Prime Minister Gyula Gömbös. All of the protocols went into effect on 12 July 1934 and were registered in League of Nations Treaty Series on 12 December 1934.
81 | Thought 2: The three Prime Ministers were Benito Mussolini, Engelbert Dollfuss, and Gyula Gömbös. I need to search Benito Mussolini and find what he was assassinated as part of.
82 | Action 2: Search[Benito Mussolini] 83 | Observation 2: Benito Amilcare Andrea Mussolini (UK: MU(U)SS-ə-LEE-nee, US: MOOSS-, Italian: [beˈniːto aˈmilkare anˈdrɛːa mussoˈliːni]; 29 July 1883 – 28 April 1945) was an Italian politician and journalist who founded and led the National Fascist Party (PNF). He was Prime Minister of Italy from the March on Rome in 1922 until his deposition in 1943, as well as "Duce" of Italian fascism from the establishment of the Italian Fasces of Combat in 1919 until his summary execution in 1945 by Italian partisans. As dictator of Italy and principal founder of fascism, Mussolini inspired and supported the international spread of fascist movements during the inter-war period.Mussolini was originally a socialist politician and a journalist at the Avanti! newspaper. In 1912, he became a member of the National Directorate of the Italian Socialist Party (PSI), but he was expelled from the PSI for advocating military intervention in World War I, in opposition to the party's stance on neutrality. In 1914, Mussolini founded a new journal, Il Popolo d'Italia, and served in the Royal Italian Army during the war until he was wounded and discharged in 1917. Mussolini denounced the PSI, his views now centering on Italian nationalism instead of socialism, and later founded the fascist movement which came to oppose egalitarianism and class conflict, instead advocating "revolutionary nationalism" transcending class lines. On 31 October 1922, following the March on Rome (28–30 October), Mussolini was appointed prime minister by King Victor Emmanuel III, becoming the youngest individual to hold the office up to that time. After removing all political opposition through his secret police and outlawing labor strikes, Mussolini and his followers consolidated power through a series of laws that transformed the nation into a one-party dictatorship. Within five years, Mussolini had established dictatorial authority by both legal and illegal means and aspired to create a totalitarian state. In 1929, Mussolini signed the Lateran Treaty with the Holy See to establish Vatican City. 84 | Mussolini's foreign policy aimed to restore the ancient grandeur of the Roman Empire by expanding Italian colonial possessions and the fascist sphere of influence. In the 1920s, he ordered the Pacification of Libya, instructed the bombing of Corfu over an incident with Greece, established a protectorate over Albania, and incorporated the city of Fiume into the Italian state via agreements with Yugoslavia. In 1936, Ethiopia was conquered following the Second Italo-Ethiopian War and merged into Italian East Africa (AOI) with Eritrea and Somalia. In 1939, Italian forces annexed Albania. Between 1936 and 1939, Mussolini ordered the successful Italian military intervention in Spain in favor of Francisco Franco during the Spanish Civil War. Mussolini's Italy initially tried to avoid the outbreak of a second global war, sending troops at the Brenner Pass to delay Anschluss and taking part in the Stresa Front, the Lytton Report, the Treaty of Lausanne, the Four-Power Pact and the Munich Agreement. However, Italy then alienated itself from Britain and France by aligning with Germany and Japan. Germany invaded Poland on 1 September 1939, resulting in declarations of war by France and the UK and the start of World War II. 85 | On 10 June 1940, Mussolini decided to enter the war on the Axis side. 
Despite initial success, the subsequent Axis collapse on multiple fronts and eventual Allied invasion of Sicily made Mussolini lose the support of the population and members of the Fascist Party. As a consequence, early on 25 July 1943, the Grand Council of Fascism passed a motion of no confidence in Mussolini; later that day King Victor Emmanuel III dismissed him as head of government and had him placed in custody, appointing Pietro Badoglio to succeed him as Prime Minister. After the king agreed to an armistice with the Allies, on 12 September 1943 Mussolini was rescued from captivity in the Gran Sasso raid by German paratroopers and Waffen-SS commandos led by Major Otto-Harald Mors. Adolf Hitler, after meeting with the rescued former dictator, then put Mussolini in charge of a puppet regime in northern Italy, the Italian Social Republic (Italian: Repubblica Sociale Italiana, RSI), informally known as the Salò Republic, causing a civil war. In late April 1945, in the wake of near total defeat, Mussolini and his mistress Clara Petacci attempted to flee to Switzerland, but both were captured by Italian communist partisans and summarily executed by firing squad on 28 April 1945 near Lake Como. The bodies of Mussolini and his mistress were then taken to Milan, where they were hung upside down at a service station to publicly confirm their demise. 86 | Thought 3: Benito Mussolini was assassinated as part of the collapse of the Axis powers during World War II. So the Rome Protocols were signed by three Prime Ministers one of which was assassinated as part of World War II. 87 | Action 3: Finish[World War II] 88 | 89 | Reflection: I searched one of the prime ministers involved in the signing, then attemted to answer right away. I should have searched each of the prime ministers, then looked up 'death' on each of their pages in order to get more information before answering. 90 | """, 91 | """ 92 | Previous Trial: 93 | Question: Kam Heskin plays Paige Morgan in a 2004 film directed by who? 94 | Thought 1: I need to search Kam Heskin and Paige Morgan, find the 2004 film they are in, then find the director of the film. 95 | Action 1: Search[Kam Heskin] 96 | Observation 1: Kam Erika Heskin (born May 8, 1973) is an American actress. She began her career playing Caitlin Deschanel on the NBC daytime soap opera Sunset Beach (1998–1999), before appearing in films Planet of the Apes (2001) and Catch Me If You Can (2002). Heskin went on in 2003 to play Elizabeth Bennet in an independent film Pride & Prejudice: A Latter-Day Comedy, Robin in Charmed (2004), and Paige Morgan in The Prince and Me film franchise (2006–2010). 97 | Thought 2: Kam Heskin plays Paige Morgan in The Prince and Me film franchise. I need to search The Prince and Me film franchise and find the director of the 2004 film. 98 | Action 2: Search[The Prince and Me film franchise] 99 | Observation 2: Could not find [The Prince and Me film franchise]. Similar: ['Kam Heskin', 'Prince of Persia', 'Harry Potter and the Half-Blood Prince (film)', 'Prince of Persia: The Sands of Time (film)', 'The Chronicles of Narnia (film series)', 'Majestic Prince (franchise)', 'Kingsman (franchise)', 'The Little Prince (2015 film)', 'Shrek (franchise)', 'Harry Potter (film series)'] 100 | Thought 3: I need to search The Prince and Me (2004 film) instead. 101 | Action 3: Search[The Prince and Me (2004 film)] 102 | Observation 3: Could not find [The Prince and Me (2004 film)]. 
Similar: ['The Prince & Me', 'The Prince & Me 2: The Royal Wedding', 'Kam Heskin', 'Prince of Darkness (film)', 'Prince of Persia: The Sands of Time (film)', 'Rob Knox', 'Alexander (2004 film)', 'Prince of Persia', 'Brooklynn Prince', 'Prince (musician)']
103 | Thought 4: I need to search The Prince & Me (2004 film).
104 | Action 4: Search[The Prince & Me (2004 film)]
105 | Observation 4: Could not find [The Prince & Me (2004 film)]. Similar: ['The Prince & Me', 'The Prince & Me 2: The Royal Wedding', 'Prince of Darkness (film)', 'Prince of Persia: The Sands of Time (film)', 'Rob Knox', 'Alexander (2004 film)', 'Prince (musician)', 'Prince of Persia', 'Kam Heskin', 'Brooklynn Prince']
106 | Thought 5: I need to search The Prince & Me (2004 film) instead.
107 | Action 5: Search[The Prince & Me (2004 film)]
108 | Observation 5: Could not find [The Prince & Me (2004 film)]. Similar: ['The Prince & Me', 'The Prince & Me 2: The Royal Wedding', 'Prince of Darkness (film)', 'Prince of Persia: The Sands of Time (film)', 'Rob Knox', 'Alexander (2004 film)', 'Prince (musician)', 'Prince of Persia', 'Kam Heskin', 'Brooklynn Prince']
109 | Thought 6: I need to search The Prince & Me (2004 film).
110 | Action 6: Search[The Prince & Me (2004 film)]
111 | Observation 6: Could not find [The Prince & Me (2004 film)]. Similar: ['The Prince & Me', 'The Prince & Me 2: The Royal Wedding', 'Prince of Darkness (film)', 'Prince of Persia: The Sands of Time (film)', 'Rob Knox', 'Alexander (2004 film)', 'Prince (musician)', 'Prince of Persia', 'Kam Heskin', 'Brooklynn Prince']
112 | 
113 | Reflection: I got stuck in a loop where I kept trying to search 'The Prince & Me (2004 film)' but the page could not be found. Instead I should have tried to search the similar results that had a similar name to see if they were made in 2004.
114 | """]
115 | 
116 | SYSTEM_INSTRUCTION = """Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types:
117 | (1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search.
118 | (2) Lookup[keyword], which returns the next sentence containing keyword in the last passage successfully found by Search.
119 | (3) Finish[answer], which returns the answer and finishes the task.
120 | """
121 | 
122 | human_instruction_template = """{instruction}You may take a maximum of {max_steps} steps.
123 | Here are some examples:"""
124 | 
125 | HUMAN_INSTRUCTION = HumanMessagePromptTemplate.from_template(human_instruction_template)
126 | 
127 | human_instruction_reflection_template = """Here are some examples:"""
128 | HUMAN_REFLECTION_INSTRUCTION = HumanMessagePromptTemplate.from_template(human_instruction_reflection_template)
129 | 
130 | SYSTEM_CRITIQUE_EXISTING_RULES_INSTRUCTION = """You will be given two previous task trials in which you were given access to a Docstore API environment and a question to answer: one successful and one unsuccessful trial. You failed the trial either because you guessed the wrong answer with Finish[], or you used up your set number of reasoning steps."""
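# A minimal sketch (illustrative, not part of the original file) of how this
# critique instruction is typically paired with the human critique templates in
# prompts/templates/human.py; the trial strings and the llm callable are assumed.
#
#   from langchain.schema import SystemMessage
#   from prompts.templates.human import HUMAN_CRITIQUES
#
#   system = SystemMessage(content=SYSTEM_CRITIQUE_EXISTING_RULES_INSTRUCTION)
#   human = HUMAN_CRITIQUES['compare_existing_rules'].format(
#       instruction='', task=task, success_history=success_traj,
#       fail_history=fail_traj, existing_rules='0. (no rules yet)')
#   operations = llm([system, human])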
131 | SYSTEM_CRITIQUE_ALL_SUCCESS_EXISTING_RULES_INSTRUCTION = """You will be given successful task trials in which you were given access to a Docstore API environment and a question to answer."""
132 | SYSTEM_REFLECTION_INSTRUCTION = """You will be given a previous reasoning trial in which you were given access to a Docstore API environment and a question to answer. You were unsuccessful in answering the question either because you guessed the wrong answer with Finish[], or you used up your set number of reasoning steps. In a few sentences, diagnose a possible reason for failure and devise a new, concise, high level plan that aims to mitigate the same failure. Use complete sentences."""
133 | 
134 | def LLM_PARSER(llm_output, step: int, ai_message: bool) -> Tuple[ChatMessage, str, Dict[str, Any]]:
135 |     pattern = r'(?i)action\s*(?:\d+|)\s*(?::|)\s*'
136 |     action_pattern = r'(?i)\w+\[[^\]]+(?:\]|)'
137 | 
138 |     match = re.match(pattern, llm_output)
139 |     if match:
140 |         action = llm_output[match.end():]
141 |         content = f"Action {step}: {action}"
142 | 
143 |         if len(re.findall(action_pattern, action)) > 1:
144 |             return (
145 |                 AIMessage(content=content) if ai_message else HumanMessage(content=content),
146 |                 'action',
147 |                 {'action': ''}  # triggers invalid action
148 |             )
149 | 
150 |         return (
151 |             AIMessage(content=content) if ai_message else HumanMessage(content=content),
152 |             'action',
153 |             {'action': action}
154 |         )
155 | 
156 |     actions = re.findall(action_pattern, llm_output)
157 |     if len(actions) == 1:
158 |         action = actions[0]
159 |         if action[-1] != ']':
160 |             action += ']'
161 |         content = f"Action {step}: {action}"
162 |         return (
163 |             AIMessage(content=content) if ai_message else HumanMessage(content=content),
164 |             'action',
165 |             {'action': action}
166 |         )
167 | 
168 |     if len(actions) > 1:
169 |         content = re.sub(r"(?i)action\s*(?:\d*|)\s*(?::|)", "", llm_output)
170 |         return (
171 |             AIMessage(content=f"Action {step}: {content}"),
172 |             'action',
173 |             {'action': ''}  # triggers invalid action
174 |         )
175 | 
176 |     # everything else will be assumed to be an inner thought
177 |     thought_pattern = r'(?i)thought\s*(?:\d+|)\s*(?::|)\s*(.*)'
178 |     match = re.match(thought_pattern, llm_output)
179 |     if match:
180 |         # Extract the thought content
181 |         thought_content = match.group(1)
182 |         content = f"Thought {step}: {thought_content.rstrip(':')}"
183 |     else:
184 |         content = f"Thought {step}: {llm_output.rstrip(':')}"
185 |     return (
186 |         AIMessage(content=content) if ai_message else HumanMessage(content=content),
187 |         'thought',
188 |         {}
189 |     )
190 | 
191 | def OBSERVATION_FORMATTER(observation: str, step: int, *args, **kwargs) -> Tuple[ChatMessage, str]:
192 |     return HumanMessage(content=f"Observation {step}: " + observation.rstrip(':')), 'append'
193 | 
194 | def STEP_IDENTIFIER(line: str) -> str:
195 |     line = line.strip()
196 |     pattern = re.compile(r'(?i)^action(?:\s+(\d+))?:')
197 |     match = pattern.match(line)
198 |     if match:
199 |         return 'action'
200 |     pattern = re.compile(r'(?i)^observation(?:\s+(\d+))?:')
201 |     match = pattern.match(line)
202 |     if match:
203 |         return 'observation'
204 |     return 'thought'
205 | 
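# A few illustrative checks (not part of the original module) of how the parser
# and identifier above behave; the sample strings are invented for demonstration.
#
#   msg, kind, extra = LLM_PARSER('Action 2: Search[High Plains]', step=2, ai_message=True)
#   # kind == 'action', extra == {'action': 'Search[High Plains]'}
#   STEP_IDENTIFIER('Action 3: Finish[yes]')        # -> 'action'
#   STEP_IDENTIFIER('Observation 1: High Plains')   # -> 'observation'
#   STEP_IDENTIFIER('Thought 1: I should search')   # -> 'thought'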
206 | def CYCLER(lines: str) -> List[str]:
207 |     new_lines = []
208 |     scratch_pad = ''
209 |     for line in lines.split('\n'):
210 | 
211 |         # line is action
212 |         pattern = re.compile(r'(?i)^action(?:\s+(\d+))?:')
213 |         match = pattern.match(line)
214 |         if match:
215 |             if scratch_pad != '':
216 |                 new_lines.append(scratch_pad.strip())
217 |                 scratch_pad = ''
218 |             new_lines.append(line)
219 |             continue
220 | 
221 |         # line is thought
222 |         pattern = re.compile(r'(?i)^thought(?:\s+(\d+))?:')
223 |         match = pattern.match(line)
224 |         if match:
225 |             if scratch_pad != '':
226 |                 new_lines.append(scratch_pad.strip())
227 |                 scratch_pad = ''
228 |             new_lines.append(line)
229 |             continue
230 | 
231 |         # line is observation
232 |         scratch_pad += line + '\n'
233 | 
234 |     # the rest of the scratch pad
235 |     if scratch_pad != '':
236 |         new_lines.append(scratch_pad.strip())
237 |     return new_lines
238 | 
239 | REFLECTION_PREFIX = '\nReflection:'
240 | def PREVIOUS_TRIALS_FORMATTER(reflections: List[str], include_prefix: bool = True) -> str:
241 |     if reflections == []:
242 |         return ''
243 |     if include_prefix:
244 |         memory_prefix = "You have attempted to solve the task before but failed. The following reflection(s) give a plan to avoid failing the task in the same way you did previously. Use them to improve your strategy of solving the task successfully."
245 |     else:
246 |         memory_prefix = ''
247 |     memory_prefix += '\nReflections:'
248 |     for reflection in reflections:
249 |         memory_prefix += f"\n- {reflection.strip()}"
250 |     return memory_prefix
251 | 
252 | def STEP_STRIPPER(step: str, step_type: str):
253 |     if step_type == 'observation':
254 |         return re.sub(r'(?i)^observation(?:\s+(\d+))?:', 'Observation:', step)
255 |     if step_type == 'action':
256 |         return re.sub(r'(?i)^action(?:\s+(\d+))?:', 'Action:', step)
257 |     if step_type == 'thought':
258 |         return re.sub(r'(?i)^thought(?:\s+(\d+))?:', 'Thought:', step)
259 |     return step
--------------------------------------------------------------------------------
/prompts/templates/human.py:
--------------------------------------------------------------------------------
1 | from langchain.prompts.chat import HumanMessagePromptTemplate, SystemMessagePromptTemplate
2 | 
3 | human_instruction_fewshots_template = """{instruction}
4 | 
5 | {fewshots}
6 | 
7 | (END OF EXAMPLES)
8 | """
9 | human_instruction_fewshot_message_prompt = lambda message_type: \
10 |     SystemMessagePromptTemplate.from_template(
11 |         human_instruction_fewshots_template,
12 |     ) if message_type == 'all_system' else \
13 |     HumanMessagePromptTemplate.from_template(human_instruction_fewshots_template,)
14 | 
15 | human_task_template = """Now it's your turn!
16 | {task}"""
17 | human_task_message_prompt = HumanMessagePromptTemplate.from_template(
18 |     human_task_template,
19 | )
20 | 
21 | FORMAT_RULES_OPERATION_TEMPLATE = """<OPERATION> <RULE NUMBER>: <RULE>
22 | 
23 | The available operations are: AGREE (if the existing rule is strongly relevant for the task), REMOVE (if one existing rule is contradictory or similar/duplicated to other existing rules), EDIT (if any existing rule is not general enough or can be enhanced, rewrite and improve it), ADD (add new rules that are very different from existing rules and relevant for other tasks). Each needs to CLOSELY follow their corresponding formatting below (any existing rule not edited, not agreed, nor removed is considered copied):
24 | 
25 | AGREE <EXISTING RULE NUMBER>: <EXISTING RULE>
26 | REMOVE <EXISTING RULE NUMBER>: <EXISTING RULE>
27 | EDIT <EXISTING RULE NUMBER>: <NEW MODIFIED RULE>
28 | ADD <NEW RULE NUMBER>: <NEW RULE>
29 | 
30 | Do not mention the trials in the rules because all the rules should be GENERALLY APPLICABLE. Each rule should be concise and easy to follow. Any operation can be used MULTIPLE times. Do at most 4 operations and each existing rule can only get a maximum of 1 operation."""
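# For illustration (not part of the original file), a model response that
# follows the operation format above might look like this; the rule numbers
# and texts are invented examples:
#
#   AGREE 1: Verify each candidate entity before committing to an answer.
#   EDIT 3: Prefer Lookup[keyword] over repeated Search calls once a page is open.
#   ADD 4: If a search fails, try the closest title among the suggested similar entities.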
""" 31 | 32 | CRITIQUE_SUMMARY_SUFFIX = dict(full = """Focus on REMOVE rules first, and stop ADD rule unless the new rule is VERY insightful and different from EXISTING RULES. Below are the operations you do to the above list of EXISTING RULES: 33 | """, not_full = """Below are the operations you do to the above list of EXISTING RULES: 34 | """) 35 | 36 | human_critique_existing_rules_all_success_template = """{instruction} 37 | Here are the trials: 38 | {success_history} 39 | 40 | Here are the EXISTING RULES: 41 | {existing_rules} 42 | 43 | By examining the successful trials, and the list of existing rules, you can perform the following operations: add, edit, remove, or agree so that the new list of rules are general and high level insights of the successful trials or proposed way of Thought so they can be used as helpful tips to different tasks in the future. Have an emphasis on tips that help the agent perform better Thought and Action. Follow the below format: 44 | 45 | """ + FORMAT_RULES_OPERATION_TEMPLATE 46 | 47 | human_all_success_existing_rules_critique = HumanMessagePromptTemplate.from_template(human_critique_existing_rules_all_success_template) 48 | 49 | human_critique_existing_rules_template = """{instruction} 50 | Here are the two previous trials to compare and critique: 51 | TRIAL TASK: 52 | {task} 53 | 54 | SUCCESSFUL TRIAL: 55 | {success_history} 56 | 57 | FAILED TRIAL: 58 | {fail_history} 59 | 60 | Here are the EXISTING RULES: 61 | {existing_rules} 62 | 63 | By examining and contrasting to the successful trial, and the list of existing rules, you can perform the following operations: add, edit, remove, or agree so that the new list of rules is GENERAL and HIGH LEVEL critiques of the failed trial or proposed way of Thought so they can be used to avoid similar failures when encountered with different questions in the future. Have an emphasis on critiquing how to perform better Thought and Action. Follow the below format: 64 | 65 | """ + FORMAT_RULES_OPERATION_TEMPLATE 66 | 67 | human_existing_rules_critique = HumanMessagePromptTemplate.from_template(human_critique_existing_rules_template) 68 | 69 | HUMAN_CRITIQUES = dict( 70 | compare_existing_rules=human_existing_rules_critique, 71 | all_success_existing_rules=human_all_success_existing_rules_critique, 72 | ) 73 | 74 | RULE_TEMPLATE = dict( 75 | hotpotqa=HumanMessagePromptTemplate.from_template("""The following are some experience you gather on a similar task of question answering using Wikipedia API. Use these as references to help you perform this task: 76 | {rules} 77 | """), 78 | webshop=HumanMessagePromptTemplate.from_template(""" 79 | The following are some experiences (in decreasing order of importance) you gathered on tasks of purchasing items requested by an user by interacting with an online website. Use these experiences as useful references to help you perform better on this task: 80 | {rules} 81 | """), 82 | alfworld=HumanMessagePromptTemplate.from_template("""The following are some experience you gather on a similar task of completing a household task by interacting in a household environment. Use these as references to help you perform this task: 83 | {rules} 84 | """), 85 | fever=HumanMessagePromptTemplate.from_template("""The following paragraph are insights a teacher agent provided to you. 
It is MANDATORY for you to follow these insights as CLOSELY as possible as they will help you perform the fact verification tasks efficiently:
86 | 
87 | In order to successfully complete factual verification tasks, begin by clearly understanding the claim. Then formulate a search query that is precise and directly related to the claim. Include the main subjects or context from the claim in your query. If the initial search doesn't yield desired results, consider refining the query, using synonyms, or breaking down the claim into smaller parts. Always verify the information you obtain against the claim before drawing a conclusion. If multiple searches fail, consider changing the search strategy or looking for related information that might indirectly provide the necessary information. If all else fails, consider that the answer might be found in the observations already made. When you're ready to draw a conclusion, double-check it against the information obtained and ensure its accuracy. Lastly, always be prepared to exhaust all possible search queries related to the task at hand before concluding. Remember, the claim can either be supported, refuted, or there might not be enough information to draw a conclusion.{rules}
88 | """),
89 | )
--------------------------------------------------------------------------------
/prompts/templates/system.py:
--------------------------------------------------------------------------------
1 | from langchain.prompts.chat import HumanMessagePromptTemplate
2 | 
3 | 
4 | system_template = """You are {ai_name}. {instruction}"""
5 | 
6 | system_message_prompt = HumanMessagePromptTemplate.from_template(system_template)
--------------------------------------------------------------------------------
/prompts/webshop.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Dict, Any, List
2 | import re
3 | 
4 | from langchain.prompts.chat import HumanMessagePromptTemplate, SystemMessagePromptTemplate
5 | from langchain.schema import (
6 |     AIMessage,
7 |     HumanMessage,
8 |     ChatMessage,
9 | )
10 | 
11 | FEWSHOTS = ["""Instruction:
12 | i would like a 3 ounce bottle of bright citrus deodorant for sensitive skin, and price lower than 50.00 dollars
13 | [Search]
14 | 
15 | Action: search[3 ounce bright citrus deodorant sensitive skin]
16 | Observation:
17 | [Back to Search]
18 | Page 1 (Total results: 50)
19 | [Next >]
20 | [B078GWRC1J]
21 | Bright Citrus Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce
22 | $10.99
23 | [B078GTKVXY]
24 | Ginger Fresh Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce
25 | $10.99
26 | [B08KBVJ4XN]
27 | Barrel and Oak - Aluminum-Free Deodorant, Deodorant for Men, Essential Oil-Based Scent, 24-Hour Odor Protection, Cedar & Patchouli Blend, Gentle on Sensitive Skin (Mountain Sage, 2.7 oz, 2-Pack)
28 | $15.95
29 | 
30 | Action: think[B078GWRC1J and B078GTKVXY are bright citrus deodorant less than 50 dollars. I can check B078GWRC1J first.]
31 | Observation: OK.
32 | 
33 | Action: click[B078GWRC1J]
34 | Observation:
35 | [Back to Search]
36 | [< Prev]
37 | scent [assorted scents][bright citrus][calming lavender][ginger fresh][simply non-scents]
38 | size [travel set (4-pack)][3 ounce (pack of 1)][3-ounce (2-pack)]
39 | Bright Citrus Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce
40 | Price: $10.99
41 | Rating: N.A.
42 | [Description]
43 | [Features]
44 | [Reviews]
45 | [Buy Now]
46 | 
47 | Action: think[For 3 ounce bottle of bright citrus deodorant for sensitive skin, the item has options 'bright citrus' and '3 ounce (pack of 1)' and seems good to buy.]
48 | Observation: OK.
49 | 
50 | Action: click[bright citrus]
51 | Observation: You have clicked bright citrus.
52 | 
53 | Action: click[3 ounce (pack of 1)]
54 | Observation: You have clicked 3 ounce (pack of 1).
55 | 
56 | Action: think[All options are properly selected, I can now buy it.]
57 | Observation: OK.
58 | 
59 | Action: click[Buy Now]
60 | Observation: Your score (min 0.0, max 1.0): 1.0""",
61 | """Instruction:
62 | i am looking for long sleeve men t-shirt.and please also choose the black one, and price lower than 30.00 dollars
63 | [Search]
64 | Action: search[shirt man]
65 | Observation:
66 | [Back to Search]
67 | Page 1 (Total results: 50)
68 | [Next >]
69 | [B09NNMV9LN]
70 | JSPOYOU Mens Short Sleeve Crewneck 3D Graphic Tunic Shirts Big & Tall Tie Dye Summer Top Basic Designed Classic Cotton Shirt
71 | $3.99 to $10.99
72 | [B09Q67H373]
73 | JSPOYOU 2022 Newly T-Shirt for Mens Funny 3D Graphics Pattern Crewneck Short Sleeve Tees Big and Tall Summer Casual Comfy Top
74 | $1.99 to $8.99
75 | [B00I3O3BSS]
76 | Blue by Betsey Johnson Women's SB Ever Ballet Flat
77 | $89.99
78 | 
79 | Action: think[It seems the first page doesn't mention long sleeves. I should check the next page]
80 | Observation: OK.
81 | 
82 | Action: click[Next >]
83 | Observation:
84 | [Back to Search]
85 | Page 2 (Total results: 50)
86 | [< Prev]
87 | [Next >]
88 | [B09Q8RD8YN]
89 | Bungo Stray Anime Dogs Anime Character, Long Sleeve, Sweatshirt, Hoodie, T shirt
90 | $19.99
91 | [B09KLQLLT2]
92 | Long Sleeve Superhero T Shirt Tank Top Mens Compression Shirt Men Workout Fitness Gym Shirt
93 | $19.99
94 | [B09R9YCM6R]
95 | XXBR Summer T-shirts for Mens, Soldier Short Sleeve 3D Street Vintage Printed Shirt Slim Fit Muscle Casual Tee Tops
96 | $8.98 to $11.99
97 | 
98 | Action: think[It looks like we have long sleeve t-shirts but not black. I will not bother looking at the other pages and will focus on rephrasing the query. I first need to click 'Back to Search' button.]
99 | Observation: OK.
100 | 
101 | Action: click[Back to Search]
102 | Observation:
103 | i am looking for long sleeve men t-shirt and please also choose the black one, and price lower than 40.00 dollars
104 | [Search]
105 | 
106 | Action: search[long sleeves t-shirt man black]
107 | Observation:
108 | [Back to Search]
109 | Page 1 (Total results: 50)
110 | [Next >]
111 | [B09QGK5XHZ]
112 | WENKOMG1 Men's Long Sleeve Undershirt with Mask Turtleneck Hooded T-Shirt Solid Color Workout Tops Zipper Side Slit Shirts Slim Fit Sweatshirt Spring/Summer Tee Shirts(Gray,)
113 | $8.39
114 | [B09QQP3356]
115 | HAUKLIE Men's Sports Waffle Ribbed Polo Shirts Summer Long Sleeve Cotton Muscle Quarter-Zip Henley T-Shirt Tunics Tops
116 | $10.99
117 | [B09Q67H373]
118 | JSPOYOU 2022 Newly T-Shirt for Mens Funny 3D Graphics Pattern Crewneck Short Sleeve Tees Big and Tall Summer Casual Comfy Top
119 | $1.99 to $8.99
120 | 
121 | Action: think[B09QQP3356 seems to be a good match now. It is under 40.00 dollars and has long sleeves. I will check it first.]
122 | Observation: OK.
123 | 
124 | Action: click[B09QQP3356]
125 | Observation:
126 | [Back to Search]
127 | [< Prev]
128 | color [black][gray][light blue]
129 | size [small][medium][large][x-large][xx-large][3x-large]
130 | HAUKLIE Men's Sports Waffle Ribbed Polo Shirts Summer Long Sleeve Cotton Muscle Quarter-Zip Henley T-Shirt Tunics Tops
131 | Price: $10.99
132 | Rating: N.A.
133 | [Description]
134 | [Features]
135 | [Reviews]
136 | [Buy Now]
137 | 
138 | Action: think[There are 2 options, color and size. I need to select them one by one according to the instruction. First, I will choose the black color.]
139 | Observation: OK.
140 | 
141 | Action: click[black]
142 | Observation: You have clicked black.
143 | 
144 | Action: think[The item also has a size option, I must select one. Since no information is given about the size, I will choose the first one.]
145 | Observation: OK.
146 | 
147 | Action: click[small]
148 | Observation: You have clicked small.
149 | 
150 | Action: think[All options are properly selected, I can now buy it.]
151 | Observation: OK.
152 | 
153 | Action: click[Buy Now]
154 | Observation: Your score (min 0.0, max 1.0): 1.0"""
155 | ]
156 | 
157 | NEW_REFLECTION_FEWSHOTS = [
158 |     """Previous Trial:
159 | Instruction:
160 | i am looking for dairy free and apple variety pack of chips, and price lower than 30.00 dollars
161 | [Search]
162 | 
163 | Action: search[dairy free and apple variety pack of chips]
164 | Observation:
165 | [Back to Search]
166 | Page 1 (Total results: 50)
167 | [Next >]
168 | [B07HRFJWP8]
169 | Enjoy Life Foods Soft Baked Ovals, Breakfast Bars, Nut Free Bars, Soy Free, Dairy Free, Non GMO, Gluten Free, Vegan, Variety Pack, 4 Boxes (20 Bars Total)
170 | $100.0
171 | [B01KMHY5PG]
172 | Enjoy Life Soft Baked Chewy Bars, Variety Pack, Nut Free Bars, Soy Free, Dairy Free, Gluten Free, 6 Boxes (30 Total Bars)
173 | $21.49
174 | [B008D2X8C4]
175 | Enjoy Life Lentil Chips Variety Pack, Dairy Free Chips, Soy Free, Nut Free, Non GMO, Vegan, Gluten Free, 24 Bags (0.8 oz)
176 | $100.0
177 | 
178 | Action: think[The first two items are not what I want. The third item is the dairy free and apple variety pack of chips I want.]
179 | Observation: OK.
180 | 181 | Action: click[B008D2X8C4] 182 | Observation: 183 | [Back to Search] 184 | [< Prev] 185 | flavor name [dill and sour cream][garlic & parmesan][light sea salt][margherita pizza][thai chili lime][variety pack] 186 | size [0.8 ounce (pack of 24)][4 ounce (pack of 12)] 187 | Enjoy Life Lentil Chips Variety Pack, Dairy Free Chips, Soy Free, Nut Free, Non GMO, Vegan, Gluten Free, 24 Bags (0.8 oz) 188 | Price: $100.0 189 | Rating: N.A. 190 | [Description] 191 | [Features] 192 | [Reviews] 193 | [Buy Now] 194 | 195 | Action: think[The item has 2 options, flavor and size. I need to select them one by one according to the instruction. First, I will choose the variety pack.] 196 | Observation: OK. 197 | 198 | Action: click[variety pack] 199 | Observation: You have clicked variety pack. 200 | 201 | Action: think[Then, since there is no required size, I will choose the first one.] 202 | Observation: OK. 203 | 204 | Action: click[0.8 ounce (pack of 24)] 205 | Observation: You have clicked 0.8 ounce (pack of 24). 206 | 207 | Action: click[Buy Now] 208 | Observation: Wrong item! TASK FAILED 209 | 210 | Next plan: In this attempt, I was unsuccessful. I accidentally bought a product that was $100, which is more than my budget of $30. Either way, the initial search results were not good. Next time, I will do search["variety pack of chips"] and then check if the results meet the dairy free and the $30 budget constraints. I will continue to refine my searches so that I can find more products.""", 211 | """Previous Trial: 212 | Instruction: 213 | i need gluten free vegetarian smoked peppered bacon - 4 ounce (pack of 2), and price lower than 40.00 dollars 214 | [Search] 215 | 216 | Action: search[gluten free vegetarian smoked peppered bacon 4 ounce pack of 2] 217 | Observation: 218 | [Back to Search] 219 | Page 1 (Total results: 50) 220 | [Next >] 221 | [B06Y96MXJV] 222 | Smoked Bacon Sea Salt 3-Pack: Smoked Bacon Chipotle, Smoked Bacon and Onion, Smoked Peppered Bacon - All-Natural Sea Salts Slowly Smoked Over Alderwood - No Gluten, No MSG, Non-GMO (12 total oz.) 223 | $29.99 224 | [B06Y96N1KG] 225 | Spicy Hot Pepper Sea Salt 3-Pack: Ghost Pepper, Jalapeno, Habanero - All Natural, Delicious Blends of Peppers & Sea Salt - Gluten-Free, Kosher, No MSG, Non-GMO (12 total oz.) 226 | $29.99 227 | [B07GJTKYJQ] 228 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces) 229 | $42.99 230 | 231 | Action: think[The first two items are not what I want. The third item is not what I want, but it is the closest one.] 232 | Observation: OK. 233 | 234 | Action: think[The third item is not what I want, but it is the closest one.] 235 | Observation: OK. 236 | 237 | Action: click[B07GJTKYJQ] 238 | Observation: 239 | [Back to Search] 240 | [< Prev] 241 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces) 242 | Price: $42.99 243 | Rating: N.A. 244 | [Description] 245 | [Features] 246 | [Reviews] 247 | [Buy Now] 248 | 249 | Action: think[The price is too high and it is not what I want.] 250 | Observation: OK. 
251 | 252 | Action: click[< Prev] 253 | Observation: 254 | [Back to Search] 255 | Page 1 (Total results: 50) 256 | [Next >] 257 | [B06Y96MXJV] 258 | Smoked Bacon Sea Salt 3-Pack: Smoked Bacon Chipotle, Smoked Bacon and Onion, Smoked Peppered Bacon - All-Natural Sea Salts Slowly Smoked Over Alderwood - No Gluten, No MSG, Non-GMO (12 total oz.) 259 | $29.99 260 | [B06Y96N1KG] 261 | Spicy Hot Pepper Sea Salt 3-Pack: Ghost Pepper, Jalapeno, Habanero - All Natural, Delicious Blends of Peppers & Sea Salt - Gluten-Free, Kosher, No MSG, Non-GMO (12 total oz.) 262 | $29.99 263 | [B07GJTKYJQ] 264 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces) 265 | $42.99 266 | 267 | Action: think[The first two items are not what I want.] 268 | Observation: OK. 269 | 270 | Action: think[I need to refine my search.] 271 | Observation: OK. 272 | 273 | Action: click[< Back to Search] 274 | Observation: Invalid action! 275 | 276 | Action: click[Back to Search] 277 | Observation: Ran out of steps! TASK FAILED 278 | 279 | Next plan: In this attempt, I was unsuccessful. The initial search results were not the products that I was looking for. Next time, I will do search["vegetarian bacon"] for a broader search. I will check that the new results will fulfill the gluten free and 4 ounce pack of 2 constraints. I will continue to refine my searches and think less so that I can find more products.""" 280 | ] 281 | 282 | ACT_REFLECTION_FEWSHOTS = [ 283 | """Previous Trial: 284 | Instruction: 285 | i am looking for dairy free and apple variety pack of chips, and price lower than 30.00 dollars 286 | [Search] 287 | 288 | Action: search[dairy free and apple variety pack of chips] 289 | Observation: 290 | [Back to Search] 291 | Page 1 (Total results: 50) 292 | [Next >] 293 | [B07HRFJWP8] 294 | Enjoy Life Foods Soft Baked Ovals, Breakfast Bars, Nut Free Bars, Soy Free, Dairy Free, Non GMO, Gluten Free, Vegan, Variety Pack, 4 Boxes (20 Bars Total) 295 | $100.0 296 | [B01KMHY5PG] 297 | Enjoy Life Soft Baked Chewy Bars, Variety Pack, Nut Free Bars, Soy Free, Dairy Free, Gluten Free, 6 Boxes (30 Total Bars) 298 | $21.49 299 | [B008D2X8C4] 300 | Enjoy Life Lentil Chips Variety Pack, Dairy Free Chips, Soy Free, Nut Free, Non GMO, Vegan, Gluten Free, 24 Bags (0.8 oz) 301 | $100.0 302 | 303 | Action: click[B008D2X8C4] 304 | Observation: 305 | [Back to Search] 306 | [< Prev] 307 | flavor name [dill and sour cream][garlic & parmesan][light sea salt][margherita pizza][thai chili lime][variety pack] 308 | size [0.8 ounce (pack of 24)][4 ounce (pack of 12)] 309 | Price: $100.0 310 | Rating: N.A. 311 | [Description] 312 | [Features] 313 | [Reviews] 314 | [Buy Now] 315 | 316 | Action: click[variety pack] 317 | Observation: You have clicked variety pack. 318 | 319 | Action: click[0.8 ounce (pack of 24)] 320 | Observation: You have clicked 0.8 ounce (pack of 24). 321 | 322 | Action: click[Buy Now] 323 | Observation: Wrong item! TASK FAILED 324 | 325 | Next plan: In this attempt, I was unsuccessful. I accidentally bought a product that was $100, which is more than my budget of $30. Either way, the initial search results were not good. Next time, I will do search["variety pack of chips"] and then check if the results meet the dairy free and the $30 budget constraints. 
I will continue to refine my searches so that I can find more products.""", 326 | """Previous Trial: 327 | Instruction: 328 | i need gluten free vegetarian smoked peppered bacon - 4 ounce (pack of 2), and price lower than 40.00 dollars 329 | [Search] 330 | 331 | Action: search[gluten free vegetarian smoked peppered bacon 4 ounce pack of 2] 332 | Observation: 333 | [Back to Search] 334 | Page 1 (Total results: 50) 335 | [Next >] 336 | [B06Y96MXJV] 337 | Smoked Bacon Sea Salt 3-Pack: Smoked Bacon Chipotle, Smoked Bacon and Onion, Smoked Peppered Bacon - All-Natural Sea Salts Slowly Smoked Over Alderwood - No Gluten, No MSG, Non-GMO (12 total oz.) 338 | $29.99 339 | [B06Y96N1KG] 340 | Spicy Hot Pepper Sea Salt 3-Pack: Ghost Pepper, Jalapeno, Habanero - All Natural, Delicious Blends of Peppers & Sea Salt - Gluten-Free, Kosher, No MSG, Non-GMO (12 total oz.) 341 | $29.99 342 | [B07GJTKYJQ] 343 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces) 344 | $42.99 345 | 346 | Action: click[B07GJTKYJQ] 347 | Observation: 348 | [Back to Search] 349 | [< Prev] 350 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces) 351 | Price: $42.99 352 | Rating: N.A. 353 | [Description] 354 | [Features] 355 | [Reviews] 356 | [Buy Now] 357 | 358 | Action: click[< Prev] 359 | Observation: 360 | [Back to Search] 361 | Page 1 (Total results: 50) 362 | [Next >] 363 | [B06Y96MXJV] 364 | Smoked Bacon Sea Salt 3-Pack: Smoked Bacon Chipotle, Smoked Bacon and Onion, Smoked Peppered Bacon - All-Natural Sea Salts Slowly Smoked Over Alderwood - No Gluten, No MSG, Non-GMO (12 total oz.) 365 | $29.99 366 | [B06Y96N1KG] 367 | Spicy Hot Pepper Sea Salt 3-Pack: Ghost Pepper, Jalapeno, Habanero - All Natural, Delicious Blends of Peppers & Sea Salt - Gluten-Free, Kosher, No MSG, Non-GMO (12 total oz.) 368 | $29.99 369 | [B07GJTKYJQ] 370 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces) 371 | $42.99 372 | 373 | Action: click[< Back to Search] 374 | Observation: Invalid action! 375 | 376 | Action: click[Back to Search] 377 | Observation: Ran out of steps! TASK FAILED 378 | 379 | Next plan: In this attempt, I was unsuccessful. The initial search results were not the products that I was looking for. Next time, I will do search["vegetarian bacon"] for a broader search. I will check that the new results will fulfill the gluten free and 4 ounce pack of 2 constraints. 
I will continue to refine my searches and think less so that I can find more products.""" 380 | ] 381 | 382 | REFLECTION_FEWSHOTS = [ 383 | """Previous Trial: 384 | Instruction: 385 | i am looking for dairy free and apple variety pack of chips, and price lower than 30.00 dollars 386 | [Search] 387 | 388 | Action: search[dairy free and apple variety pack of chips] 389 | Observation: 390 | [Back to Search] 391 | Page 1 (Total results: 50) 392 | [Next >] 393 | [B07HRFJWP8] 394 | Enjoy Life Foods Soft Baked Ovals, Breakfast Bars, Nut Free Bars, Soy Free, Dairy Free, Non GMO, Gluten Free, Vegan, Variety Pack, 4 Boxes (20 Bars Total) 395 | $100.0 396 | [B01KMHY5PG] 397 | Enjoy Life Soft Baked Chewy Bars, Variety Pack, Nut Free Bars, Soy Free, Dairy Free, Gluten Free, 6 Boxes (30 Total Bars) 398 | $21.49 399 | [B008D2X8C4] 400 | Enjoy Life Lentil Chips Variety Pack, Dairy Free Chips, Soy Free, Nut Free, Non GMO, Vegan, Gluten Free, 24 Bags (0.8 oz) 401 | $100.0 402 | 403 | Action: think[The first two items are not what I want. The third item is the dairy free and apple variety pack of chips I want.] 404 | Observation: OK. 405 | 406 | Action: click[B008D2X8C4] 407 | Observation: 408 | [Back to Search] 409 | [< Prev] 410 | flavor name [dill and sour cream][garlic & parmesan][light sea salt][margherita pizza][thai chili lime][variety pack] 411 | size [0.8 ounce (pack of 24)][4 ounce (pack of 12)] 412 | Price: $100.0 413 | Rating: N.A. 414 | [Description] 415 | [Features] 416 | [Reviews] 417 | [Buy Now] 418 | 419 | Action: think[The item has options 'variety pack' and '0.8 ounce (pack of 24)' and seems good to buy.] 420 | Observation: OK. 421 | 422 | Action: click[variety pack] 423 | Observation: You have clicked variety pack. 424 | 425 | Action: click[0.8 ounce (pack of 24)] 426 | Observation: You have clicked 0.8 ounce (pack of 24). 427 | 428 | Action: click[Buy Now] 429 | Observation: Wrong item! TASK FAILED 430 | 431 | Next plan: In this attempt, I was unsuccessful. I accidentally bought a product that was $100, which is more than my budget of $30. Either way, the initial search results were not good. Next time, I will do search["variety pack of chips"] and then check if the results meet the dairy free and the $30 budget constraints. I will continue to refine my searches so that I can find more products.""", 432 | """Previous Trial: 433 | Instruction: 434 | i need gluten free vegetarian smoked peppered bacon - 4 ounce (pack of 2), and price lower than 40.00 dollars 435 | [Search] 436 | 437 | Action: search[gluten free vegetarian smoked peppered bacon 4 ounce pack of 2] 438 | Observation: 439 | [Back to Search] 440 | Page 1 (Total results: 50) 441 | [Next >] 442 | [B06Y96MXJV] 443 | Smoked Bacon Sea Salt 3-Pack: Smoked Bacon Chipotle, Smoked Bacon and Onion, Smoked Peppered Bacon - All-Natural Sea Salts Slowly Smoked Over Alderwood - No Gluten, No MSG, Non-GMO (12 total oz.) 444 | $29.99 445 | [B06Y96N1KG] 446 | Spicy Hot Pepper Sea Salt 3-Pack: Ghost Pepper, Jalapeno, Habanero - All Natural, Delicious Blends of Peppers & Sea Salt - Gluten-Free, Kosher, No MSG, Non-GMO (12 total oz.) 447 | $29.99 448 | [B07GJTKYJQ] 449 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces) 450 | $42.99 451 | 452 | Action: think[The first two items are not what I want. The third item is not what I want, but it is the closest one.] 453 | Observation: OK. 
454 | 
455 | Action: think[The third item is not what I want, but it is the closest one.]
456 | Observation: OK.
457 | 
458 | Action: click[B07GJTKYJQ]
459 | Observation:
460 | [Back to Search]
461 | [< Prev]
462 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces)
463 | Price: $42.99
464 | Rating: N.A.
465 | [Description]
466 | [Features]
467 | [Reviews]
468 | [Buy Now]
469 | 
470 | Action: think[The price is too high and it is not what I want.]
471 | Observation: OK.
472 | 
473 | Action: click[< Prev]
474 | Observation:
475 | [Back to Search]
476 | Page 1 (Total results: 50)
477 | [Next >]
478 | [B06Y96MXJV]
479 | Smoked Bacon Sea Salt 3-Pack: Smoked Bacon Chipotle, Smoked Bacon and Onion, Smoked Peppered Bacon - All-Natural Sea Salts Slowly Smoked Over Alderwood - No Gluten, No MSG, Non-GMO (12 total oz.)
480 | $29.99
481 | [B06Y96N1KG]
482 | Spicy Hot Pepper Sea Salt 3-Pack: Ghost Pepper, Jalapeno, Habanero - All Natural, Delicious Blends of Peppers & Sea Salt - Gluten-Free, Kosher, No MSG, Non-GMO (12 total oz.)
483 | $29.99
484 | [B07GJTKYJQ]
485 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces)
486 | $42.99
487 | 
488 | Action: think[The first two items are not what I want.]
489 | Observation: OK.
490 | 
491 | Action: think[I need to refine my search.]
492 | Observation: OK.
493 | 
494 | Action: click[< Back to Search]
495 | Observation: Invalid action!
496 | 
497 | Action: click[Back to Search]
498 | Observation: Ran out of steps! TASK FAILED
499 | 
500 | Next plan: In this attempt, I was unsuccessful. The initial search results were not the products that I was looking for. Next time, I will do search["vegetarian bacon"] for a broader search. I will check that the new results will fulfill the gluten free and 4 ounce pack of 2 constraints. I will continue to refine my searches and think less so that I can find more products."""
501 | ]
502 | 
503 | SYSTEM_INSTRUCTION = """Follow the syntax of the examples closely when taking actions."""
504 | 
505 | human_instruction_template = """{instruction}You may take a maximum of {max_steps} steps.
506 | Here are two examples:"""
507 | 
508 | HUMAN_INSTRUCTION = HumanMessagePromptTemplate.from_template(human_instruction_template)
509 | 
510 | human_instruction_reflection_template = """Here are two examples:"""
511 | HUMAN_REFLECTION_INSTRUCTION = HumanMessagePromptTemplate.from_template(human_instruction_reflection_template)
512 | 
513 | SYSTEM_CRITIQUE_EXISTING_RULES_INSTRUCTION = """You will be given two previous task trials in which you could interact with an online store website in order to buy the requested item: one successful and one unsuccessful trial.
514 | SYSTEM_CRITIQUE_ALL_SUCCESS_EXISTING_RULES_INSTRUCTION = """You will be given successful task trials in which you could interact with an online store website in order to buy the requested items."""
515 | 
516 | def LLM_PARSER(llm_output: str, step: int, ai_message: bool) -> Tuple[ChatMessage, str, Dict[str, Any]]:
517 |     pattern = r'(?i)action\s*\d*\s*:\s*'
518 |     formatted_output = re.sub(pattern, '', llm_output)
519 |     if '[' not in formatted_output:
520 |         formatted_output = f'think[{re.sub("^Observation: *", "", formatted_output).strip()}]'  # drop an optional leading "Observation:" label as a prefix match (str.lstrip would strip any of those characters, not the prefix)
521 |     elif '[' in formatted_output and not formatted_output.endswith(']'):
522 |         formatted_output += ']'
523 |     return (
524 |         AIMessage(content=f'Action: {formatted_output}') if ai_message else HumanMessage(content=f'Action: {formatted_output}'),
525 |         'action',
526 |         {'action': formatted_output}
527 |     )
528 | 
529 | def OBSERVATION_FORMATTER(observation: str, *args, **kwargs) -> Tuple[ChatMessage, str]:
530 |     return HumanMessage(content=f"Observation: {observation}\n"), 'append'
531 | 
532 | def STEP_IDENTIFIER(line: str) -> str:
533 |     line = line.strip()
534 |     if line.lower().startswith('action: think['):
535 |         return 'thought'
536 |     if line.lower().startswith('action'):
537 |         return 'action'
538 |     return 'observation'
539 | 
540 | # Should return [action1, observation1, action2, observation2, ...]
541 | def CYCLER(lines: str) -> List[str]:
542 |     listed_lines = []
543 |     listed_steps = lines.split('\n\n')
544 |     for step in listed_steps:
545 |         # assuming action only takes one line
546 |         listed_lines.extend(step.split('\n', 1))
547 |     return listed_lines
548 | 
549 | REFLECTION_PREFIX = '\nNext plan:'
550 | def PREVIOUS_TRIALS_FORMATTER(reflections: List[str], *args, **kwargs) -> str:
551 |     if reflections == []:
552 |         return ''
553 |     memory_prefix = 'Your memory for the task below:'
554 |     for i, reflection in enumerate(reflections):
555 |         memory_prefix += f"\nTrial {i}:\n{reflection.strip()}"
556 |     return memory_prefix
557 | 
558 | def STEP_STRIPPER(step: str, step_type: str) -> str:
559 |     if step_type != 'task':
560 |         return step
561 |     return re.sub(r'^Instruction:\s*', '', step).strip()  # remove the leading "Instruction:" header as a prefix match (str.strip would drop any of those characters anywhere at the ends)
562 | 
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | gym==0.26.2
2 | hydra-core==1.3.2
3 | ipykernel==6.19.2
4 | joblib==1.2.0
5 | langchain==0.0.181
6 | omegaconf==2.3.0
7 | openai==0.27.7
8 | pandas==1.5.3
9 | python-dotenv==1.0.0
10 | tiktoken==0.4.0
11 | wikipedia==1.4.0
12 | torch # ==1.9.0+cu111
13 | transformers==4.30.2
14 | einops==0.6.1
15 | sentence_transformers
16 | matplotlib
17 | # gdown # for downloading files from google drive (webshop)
-------------------------------------------------------------------------------- /tools/visualize_logs.py: --------------------------------------------------------------------------------
1 | from omegaconf import DictConfig
2 | import hydra
3 | 
4 | import sys
5 | import os
6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7 | 
8 | from utils import load_trajectories_log, split_logs_by_task, plot_trial_stats, alfworld_results_per_env_name_log, get_webshop_mean_score, get_webshop_mean_scores, mode_results
9 | 
10 | @hydra.main(version_base=None, config_path="../configs", config_name="visualize_logs")
11 | def main(cfg: DictConfig) -> None:
12 |     out = 
load_trajectories_log(path=cfg.log_path, run_name=cfg.run_name, load_dict=False) 13 | log = out['log'] 14 | 15 | parsed_result = split_logs_by_task(text=log, num_tasks=cfg.num_tasks) 16 | assert len(parsed_result) == cfg.num_tasks 17 | 18 | res = plot_trial_stats(parsed_result=parsed_result, benchmark=cfg.benchmark.name, max_trials=4, save_path=f"{cfg.log_path}/{cfg.run_name}_logs_stats.png" if cfg.save_fig else None) 19 | if 'eval' in cfg.run_name: 20 | if cfg.benchmark.name == 'alfworld': 21 | print(alfworld_results_per_env_name_log(log, cfg.num_tasks, 1)) 22 | elif cfg.benchmark.name == 'webshop': 23 | print(get_webshop_mean_score(log, cfg.num_tasks, 1)) 24 | res = {k: v[-1] for k, v in res.items()} 25 | else: 26 | if cfg.benchmark.name == 'alfworld': 27 | print(alfworld_results_per_env_name_log(log, cfg.num_tasks, cfg.agent.max_reflection_depth+1)) 28 | elif cfg.benchmark.name == 'webshop': 29 | print(get_webshop_mean_scores(log, cfg.num_tasks, cfg.agent.max_reflection_depth+1)) 30 | 31 | print(res) 32 | 33 | ############################################ 34 | ### MODE OPTIONS SELECTION ### 35 | ############################################ 36 | # [react, reflection] 37 | # _[sum, mean, list] 38 | # _[token, count] 39 | # _[thought, action, observation, invalid] 40 | # _[traj, step] 41 | ############################################ 42 | modes = [ 43 | 'react_mean_count_of_thought_per_traj', 44 | 'react_mean_count_of_action_per_traj', 45 | 'react_mean_count_of_observation_per_traj', 46 | 'react_mean_tokens_per_traj', 47 | 'react_mean_count_of_invalid_per_traj', 48 | ] 49 | if 'eval' not in cfg.run_name: 50 | modes += [ 51 | 'reflection1_mean_tokens_per_traj', 52 | 'reflection2_mean_tokens_per_traj', 53 | 'reflection3_mean_tokens_per_traj', 54 | 'reflection1_mean_count_of_invalid_per_traj', 55 | 'reflection2_mean_count_of_invalid_per_traj', 56 | 'reflection3_mean_count_of_invalid_per_traj',] 57 | for mode in modes: 58 | print(f"{mode}: {mode_results(cfg.benchmark.name, log, cfg.num_tasks, mode)}") 59 | 60 | 61 | if __name__ == '__main__': 62 | main() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import getpass 2 | import hydra 3 | from omegaconf import DictConfig 4 | from pathlib import Path 5 | import os 6 | from copy import deepcopy 7 | from functools import partial 8 | import dotenv 9 | dotenv.load_dotenv() 10 | 11 | from agent import AGENT 12 | from prompts.templates.system import system_message_prompt 13 | from prompts.templates.human import HUMAN_CRITIQUES 14 | from prompts import ( 15 | SYSTEM_INSTRUCTION, 16 | HUMAN_INSTRUCTION, 17 | FEWSHOTS, 18 | REFLECTION_FEWSHOTS, 19 | HUMAN_REFLECTION_INSTRUCTION, 20 | SYSTEM_REFLECTION_INSTRUCTION, 21 | SYSTEM_CRITIQUE_INSTRUCTION, 22 | RULE_TEMPLATE, 23 | LLM_PARSER, 24 | OBSERVATION_FORMATTER, 25 | STEP_IDENTIFIER, 26 | CYCLER, 27 | STEP_CYCLER, 28 | REFLECTION_PREFIX, 29 | PREVIOUS_TRIALS_FORMATTER, 30 | STEP_STRIPPER, 31 | CRITIQUE_SUMMARY_SUFFIX, 32 | ) 33 | from envs import ENVS, INIT_TASKS_FN 34 | from memory import ( 35 | EMBEDDERS, 36 | RETRIEVERS, 37 | ) 38 | from models import LLM_CLS 39 | from utils import save_trajectories_log, load_trajectories_log, plot_trial_stats, split_logs_by_task, alfworld_results_per_env_name, get_webshop_mean_scores, get_fewshot_max_tokens 40 | from agent.reflect import Count 41 | 42 | @hydra.main(version_base=None, config_path="configs", config_name="train") 43 | def 
main(cfg: DictConfig) -> None:
44 |     if cfg.testing:
45 |         openai_api_key = 'NO_KEY_FOR_TESTING'
46 |     else:
47 |         openai_api_key = os.environ['OPENAI_API_KEY'] if 'OPENAI_API_KEY' in os.environ else getpass.getpass("Enter or paste your OpenAI API Key: ")
48 |     LOG_PATH = Path('/'.join([cfg.log_dir, cfg.benchmark.name, cfg.agent_type]))
49 |     LOG_PATH.mkdir(parents=True, exist_ok=True)
50 | 
51 |     # Load the trajectory checkpoint; initialize as empty if it does not exist
52 |     if cfg.resume:
53 |         out = load_trajectories_log(
54 |             LOG_PATH,
55 |             run_name=cfg.run_name,
56 |             load_true_log=True)
57 |     else:
58 |         # Overwriting confirmation
59 |         if os.path.exists(f"{LOG_PATH}/{cfg.run_name}.pkl") and cfg.run_name != 'test':
60 |             while True:
61 |                 res = input(f"Are you sure you want to overwrite '{cfg.run_name}'? (Y/N)\n").lower()
62 |                 if res == 'n':
63 |                     exit(0)
64 |                 elif res == 'y':
65 |                     break
66 |         out = {'log': '', 'dicts': [], 'true_log': f'{str(cfg)}'}
67 |     log, dicts, true_log = out['log'], out['dicts'], out['true_log']
68 | 
69 |     react_agent = AGENT[cfg.agent_type](
70 |         name=cfg.ai_name,
71 |         system_instruction=SYSTEM_INSTRUCTION[cfg.benchmark.name],
72 |         human_instruction=HUMAN_INSTRUCTION[cfg.benchmark.name],
73 |         tasks=INIT_TASKS_FN[cfg.benchmark.name](cfg),
74 |         fewshots=FEWSHOTS[cfg.benchmark.name],
75 |         system_prompt=system_message_prompt,
76 |         env=ENVS[cfg.benchmark.name],
77 |         max_steps=cfg.benchmark.max_steps,
78 |         openai_api_key=openai_api_key,
79 |         llm=cfg.agent.llm,
80 |         llm_builder=LLM_CLS,
81 |         reflection_fewshots=REFLECTION_FEWSHOTS[cfg.benchmark.name],
82 |         reflection_task_prompt=HUMAN_REFLECTION_INSTRUCTION[cfg.benchmark.name],
83 |         reflection_system_instruction=SYSTEM_REFLECTION_INSTRUCTION[cfg.benchmark.name],
84 |         reflection_system_prompt=SYSTEM_INSTRUCTION[cfg.benchmark.name],
85 |         max_relfection_depth=cfg.agent.max_reflection_depth if 'max_reflection_depth' in cfg.agent.keys() else 0,  # note: the 'relfection' spelling matches the agent constructor's keyword
86 |         system_critique_instructions=SYSTEM_CRITIQUE_INSTRUCTION[cfg.benchmark.name],
87 |         human_critiques=HUMAN_CRITIQUES,
88 |         max_num_rules=cfg.agent.max_num_rules if 'max_num_rules' in cfg.agent.keys() else 0,
89 |         rule_template=RULE_TEMPLATE[cfg.benchmark.name],
90 |         truncate_strategy=cfg.agent.truncate_strategy if 'truncate_strategy' in cfg.agent.keys() else None,
91 |         llm_parser=LLM_PARSER[cfg.benchmark.name],
92 |         observation_formatter=OBSERVATION_FORMATTER[cfg.benchmark.name],
93 |         embedder=EMBEDDERS(cfg.agent.retrieval_kwargs.embedder_type),
94 |         embedder_path=cfg.agent.retrieval_kwargs.embedder_path,
95 |         step_stripper=STEP_STRIPPER[cfg.benchmark.name],
96 |         retriever_cls=RETRIEVERS(cfg.agent.retrieval_kwargs.retriever_type),
97 |         message_splitter=CYCLER[cfg.benchmark.name],
98 |         identifier=STEP_IDENTIFIER[cfg.benchmark.name],
99 |         message_step_splitter=partial(STEP_CYCLER, benchmark=cfg.benchmark.name),
100 |         reflection_prefix=REFLECTION_PREFIX[cfg.benchmark.name],
101 |         previous_trials_formatter=PREVIOUS_TRIALS_FORMATTER[cfg.benchmark.name],
102 |         success_critique_num=cfg.agent.success_critique_num,
103 |         fewshot_strategy=cfg.agent.fewshot_strategy,
104 |         critique_truncate_strategy=cfg.agent.critique_truncate_strategy,
105 |         critique_summary_suffix=CRITIQUE_SUMMARY_SUFFIX,
106 |         testing=cfg.testing,
107 |         task_idx=dicts[-1]['task_idx'] if len(dicts) > 0 else 0,
108 |         benchmark_name=cfg.benchmark.name,
109 |         reranker=cfg.agent.retrieval_kwargs.reranker,
110 |         buffer_retrieve_ratio=cfg.agent.retrieval_kwargs.buffer_retrieve_ratio,
111 |         max_fewshot_tokens=get_fewshot_max_tokens(cfg.benchmark.name) if 
cfg.agent.retrieval_kwargs.max_fewshot_tokens == 'auto' else cfg.agent.retrieval_kwargs.max_fewshot_tokens, 112 | ) 113 | if len(dicts) > 0: 114 | react_agent.load_checkpoint(loaded_dict=dicts[-1], no_load_list=['testing', 'max_relfection_depth', 'fewshot_strategy', 'max_fewshot_tokens']) 115 | if 'eval_idx_list' in dicts[-1]: 116 | react_agent.eval_idx_list = dicts[-1]['eval_idx_list'] 117 | 118 | print(f"""*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* 119 | 120 | You are using the following language model: {react_agent.llm.model_name} 121 | 122 | *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*""") 123 | 124 | while react_agent.job_not_done(): 125 | prefix = f"#######################################\nTASK {react_agent.task_idx}" 126 | if cfg.agent_type in ['reflection', 'expel']: 127 | prefix += f' Reflection {react_agent.reflection_counter.count}\n\n' 128 | else: 129 | prefix += '\n\n' 130 | print(prefix + react_agent.remove_task_suffix(react_agent.task)) # remove_task_suffix used for alfworld 131 | 132 | react_agent.run(mode='train') 133 | 134 | ############################################# 135 | ### Update & Save trajectory logs + dicts ### 136 | ############################################# 137 | react_agent.update_stats() 138 | log += prefix + react_agent.log_history() + '\n\n' 139 | true_log += prefix + react_agent.log_history(include_all=True) + '\n\n' 140 | 141 | # next task 142 | react_agent.next_task() 143 | 144 | dicts.append({k: deepcopy(v) for k, v in react_agent.__dict__.items() if type(v) in [list, set, str, bool, int, dict, Count] and k not in ['openai_api_key', 'llm']}) # not saving complicated objects 145 | 146 | save_trajectories_log( 147 | LOG_PATH, log, dicts, true_log, 148 | run_name=cfg.run_name 149 | ) 150 | ############################################# 151 | 152 | ###################################### 153 | ### Final Log & Save stats + PRINT ### 154 | ###################################### 155 | success, fail, halted = react_agent.get_stats() 156 | log += f"########################################\nEND TRIAL\nTrial summary: Success: {success}/{success + fail + halted}, Fail: {fail}/{success + fail + halted}, Halted: {halted}/{success + fail + halted}" 157 | true_log += f"########################################\nEND TRIAL\nTrial summary: Success: {success}/{success + fail + halted}, Fail: {fail}/{success + fail + halted}, Halted: {halted}/{success + fail + halted}" 158 | print(f'Finished. 
Success: {success}, Fail: {fail}, Halted: {halted}') 159 | 160 | parsed_result = split_logs_by_task(text=log, num_tasks=len(react_agent.tasks)) 161 | reflection_results = plot_trial_stats(parsed_result=parsed_result, benchmark=cfg.benchmark.name, max_trials=cfg.agent.max_reflection_depth + 1, save_path=f"{LOG_PATH}/{cfg.run_name}_logs_stats.png") 162 | 163 | results = ', '.join([f"{k}: {v}" for k, v in reflection_results.items()]) + '\n' 164 | if cfg.benchmark.name == 'alfworld': 165 | results += str(alfworld_results_per_env_name(dicts[-1])) 166 | elif cfg.benchmark.name == 'webshop': 167 | results += str(get_webshop_mean_scores(log, len(react_agent.tasks), cfg.agent.max_reflection_depth + 1)) 168 | log += f'\n\n{results}\n#######################################' 169 | true_log += f'\n\n{results}\n#######################################' 170 | print(results) 171 | 172 | save_trajectories_log( 173 | LOG_PATH, log, dicts, true_log, 174 | run_name=cfg.run_name 175 | ) 176 | 177 | log, dicts, true_log = '', [], '' 178 | react_agent.reset_stats() 179 | ################################ 180 | 181 | if __name__ == "__main__": 182 | main() 183 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import string 3 | from typing import List, Dict, Callable, Tuple, Any, Union 4 | from matplotlib import pyplot as plt 5 | 6 | import tiktoken 7 | from langchain.schema import ( 8 | ChatMessage 9 | ) 10 | 11 | from prompts import FEWSHOTS 12 | 13 | import math 14 | import pickle 15 | import re 16 | 17 | 18 | ENV_NAMES = [ 19 | 'pick_and_place', 20 | 'pick_clean_then_place', 21 | 'pick_heat_then_place', 22 | 'pick_cool_then_place', 23 | 'look_at_obj', 24 | 'pick_two_obj' 25 | ] 26 | 27 | TASK_ENV_NAMES = [ 28 | ('clean', 'pick_clean_then_place'), 29 | ('hot', 'pick_heat_then_place'), 30 | ('heat', 'pick_heat_then_place'), 31 | ('cool', 'pick_cool_then_place'), 32 | ('look', 'look_at_obj'), 33 | ('examine', 'look_at_obj'), 34 | ('two', 'pick_two_obj'), 35 | ('put', 'pick_and_place') # last one must be at last position 36 | ] 37 | 38 | class Count: 39 | """ 40 | Class for reflection counting. 41 | """ 42 | def __init__(self, maximum): 43 | self.count = 0 44 | self.maximum = maximum 45 | 46 | def increment(self): 47 | self.count += 1 48 | if self.count > self.maximum: 49 | self.count = self.maximum 50 | 51 | def is_beginning(self): 52 | return self.count == 1 53 | 54 | def reset(self): 55 | self.count = 0 56 | 57 | def is_maximum(self): 58 | return self.count == self.maximum 59 | 60 | def random_divide_list(lst: List[Any], k: int): 61 | """ 62 | Divides the list into chunks, each with maximum length k. 63 | 64 | Args: 65 | lst: The list to be divided. 66 | k: The maximum length of each chunk. 67 | 68 | Returns: 69 | A list of chunks. 70 | """ 71 | random.shuffle(lst) 72 | if len(lst) <= k: 73 | return [lst] 74 | else: 75 | num_chunks = math.ceil(len(lst) / k) 76 | chunk_size = math.ceil(len(lst) / num_chunks) 77 | return [lst[i*chunk_size:(i+1)*chunk_size] for i in range(num_chunks)] 78 | 79 | def shuffled_chunks(lst: List[Any], num_chunks: int): 80 | """ 81 | Divides the list into chunks as equally as possible. 82 | 83 | Args: 84 | lst: The list to be divided. 85 | num_chunks: The number of chunks. 86 | 87 | Returns: 88 | A list of chunks. 
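
    Example (illustrative; the list is shuffled in place first, so the exact
    contents and chunk order vary between runs, only the chunk sizes are stable):
        >>> shuffled_chunks(list(range(5)), 2)
        [[3, 0, 4], [1, 2]]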
89 | """ 90 | random.shuffle(lst) 91 | chunk_size = len(lst) // num_chunks 92 | remainder = len(lst) % num_chunks 93 | chunks = [lst[i*chunk_size:(i+1)*chunk_size] for i in range(num_chunks)] 94 | 95 | # Distribute the remainder elements across the chunks 96 | if remainder > 0: 97 | for i in range(remainder): 98 | chunks[i].append(lst[num_chunks * chunk_size + i]) 99 | 100 | random.shuffle(chunks) 101 | return chunks 102 | 103 | def token_counter(text: str, llm: str = 'gpt-3.5-turbo', tokenizer: Callable = None) -> int: 104 | """ 105 | Counts the number of tokens in the text. 106 | 107 | Args: 108 | text: The text to be counted. 109 | llm: The language model name. 110 | tokenizer: The tokenizer to be used. 111 | """ 112 | if 'gpt' in llm: 113 | return len(tiktoken.encoding_for_model(llm).encode(text)) 114 | 115 | raise NotImplementedError 116 | 117 | def print_message(message: ChatMessage, token_counter: Callable = None, testing: bool = True, extra_text: str = '') -> None: 118 | """ 119 | Prints the formatted message. 120 | 121 | Args: 122 | message: The message to be printed. 123 | token_counter: A function that takes in a string and returns the number of tokens in the string. 124 | testing: Add message type and token count in testing mode. 125 | extra_text: Extra text to be printed after the message in testing mode. 126 | """ 127 | if testing: 128 | message = f"$$${message.type}$$$\t{message.content}\t$$${message.type}$$$" 129 | if token_counter is not None: 130 | message += f"\t***{token_counter(message)} tokens***" 131 | message += extra_text 132 | print(message) 133 | else: 134 | print(message.content) 135 | 136 | def parse_action(string: str): 137 | """ 138 | Parse action string into action type and argument for HotpotQA and Fever. 139 | 140 | Args: 141 | string: action string 142 | 143 | Returns: 144 | action_type: action type 145 | argument: argument 146 | """ 147 | pattern = r'^(\w+)\[(.+)\]$' 148 | match = re.match(pattern, string) 149 | 150 | if match: 151 | action_type = match.group(1) 152 | argument = match.group(2) 153 | return action_type, argument 154 | 155 | else: 156 | return None, None 157 | 158 | def normalize_answer(s: str): 159 | """ 160 | Lower text and remove punctuation, articles and extra whitespace. 161 | 162 | Args: 163 | s: string to normalize 164 | 165 | Returns: 166 | normalized string 167 | """ 168 | def remove_articles(text): 169 | return re.sub(r"\b(a|an|the)\b", " ", text) 170 | 171 | def white_space_fix(text): 172 | return " ".join(text.split()) 173 | 174 | def remove_punc(text): 175 | exclude = set(string.punctuation) 176 | return "".join(ch for ch in text if ch not in exclude) 177 | 178 | def lower(text): 179 | return text.lower() 180 | 181 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 182 | 183 | def EM(answer, key) -> bool: 184 | """ 185 | Exact match between answer and key. 186 | 187 | Args: 188 | answer: answer 189 | key: key 190 | 191 | Returns: 192 | True if exact match, else False 193 | """ 194 | return normalize_answer(answer) == normalize_answer(key) 195 | 196 | 197 | def save_trajectories_log(path: str, log: str = None, dicts: list = None, true_log: str = None, save_log: bool = True, save_dict: bool = True, save_true_log: bool = True, run_name: str = 'run') -> None: 198 | """ 199 | Saves the log and the dict to the path. 200 | 201 | Args: 202 | path: The path to save the log and the dictionaries. 203 | log: The log to be saved. 204 | dicts: The dict to be saved. 205 | true_log: The true log to be saved. 
206 |         save_log: Whether to save the log.
207 |         save_dict: Whether to save the dictionaries.
208 |         save_true_log: Whether to save the true log.
209 |         run_name: The name of the run.
210 |     """
211 |     if save_log:
212 |         with open(f'{path}/{run_name}.txt', 'w') as f:
213 |             f.write(log)
214 |     if save_dict:
215 |         with open(f'{path}/{run_name}.pkl', 'wb') as f:
216 |             pickle.dump(dicts, f)
217 |     if save_true_log:
218 |         with open(f'{path}/{run_name}_true.txt', 'w') as f:
219 |             f.write(true_log)
220 | 
221 | def load_trajectories_log(path: str, load_log: bool = True, load_dict: bool = True, load_true_log: bool = False, run_name: str = 'run') -> Dict[str, Any]:
222 |     """
223 |     Loads the log and the dict from the path.
224 | 
225 |     Args:
226 |         path: The path to load the logs and the dictionaries.
227 |         load_log: Whether to load the log.
228 |         load_dict: Whether to load the dictionaries.
229 |         load_true_log: Whether to load the true log.
230 |         run_name: The name of the run.
231 | 
232 |     Returns:
233 |         A dictionary containing the logs and the dict.
234 |     """
235 |     out = dict()
236 |     if load_log:
237 |         with open(f'{path}/{run_name}.txt', 'r') as f:
238 |             out['log'] = f.read()
239 |     if load_dict:
240 |         with open(f'{path}/{run_name}.pkl', 'rb') as f:
241 |             out['dicts'] = pickle.load(f)
242 |     if load_true_log:
243 |         with open(f'{path}/{run_name}_true.txt', 'r') as f:
244 |             out['true_log'] = f.read()
245 | 
246 |     return out
247 | 
248 | def split_logs_by_task(text: str, num_tasks: int) -> List[List[str]]:
249 |     """
250 |     Splits the log text by task.
251 | 
252 |     Args:
253 |         text: The log text.
254 |         num_tasks: The number of tasks.
255 | 
256 |     Returns:
257 |         A list of lists of log texts, each list corresponding to a task.
258 |     """
259 |     remaining_text = text
260 |     parsed_result = []
261 |     for task_i in range(num_tasks+1):
262 |         if task_i == num_tasks:
263 |             pattern_i = r'########################################\nEND TRIAL'
264 |         else:
265 |             pattern_i = rf'#######################################\n.*TASK {str(task_i)} '
266 |         matches_i = re.split(pattern_i, remaining_text)
267 |         remaining_text = matches_i[-1]
268 |         parsed_result.append(matches_i[1:-1])
269 |         if task_i != 0:
270 |             parsed_result[task_i-1].append(matches_i[0])
271 | 
272 |     # remove the last empty list
273 |     parsed_result.pop()
274 | 
275 |     return parsed_result
276 | 
277 | def recompute_stats(parsed_result: List[List[str]], benchmark: str, trial: int = -1) -> Dict[str, int]:
278 |     """
279 |     Recomputes the stats from the parsed log text.
280 | 
281 |     Args:
282 |         parsed_result: The parsed log text.
283 |         benchmark: The benchmark name.
284 |         trial: The trial index to evaluate, capped at the last available trial (-1 for the last trial).
285 | 
286 |     Returns:
287 |         The stats for the given benchmark.
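
    Example (illustrative; actual counts depend on the parsed log):
        >>> recompute_stats(parsed_result, benchmark='hotpotqa', trial=-1)
        {'success': 62, 'fail': 30, 'halted': 8}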
288 | """ 289 | stats = {"success": 0, "fail": 0, "halted": 0} if benchmark != 'alfworld' else {"success": 0, "fail": 0} 290 | 291 | for task_i in range(len(parsed_result)): 292 | trajectories = parsed_result[task_i] 293 | last_trajectory = trajectories[min(trial, len(trajectories) - 1)].strip() 294 | last_step = last_trajectory.split('\n')[-1] 295 | 296 | if benchmark == 'hotpotqa': 297 | if ' CORRECT' in last_step: 298 | stats["success"] += 1 299 | elif 'INCORRECT' in last_step: 300 | stats["fail"] += 1 301 | else: 302 | stats["halted"] += 1 303 | elif benchmark == 'alfworld': 304 | if 'SOLVED' in last_step: 305 | stats["success"] += 1 306 | else: 307 | stats["fail"] += 1 308 | elif benchmark == 'webshop': 309 | if ': 1.0' in last_step or 'Your score' in last_step: 310 | stats["success"] += 1 311 | else: 312 | stats["halted"] += 1 313 | elif benchmark == 'fever': 314 | if 'reward = 1' in last_step: 315 | stats["success"] += 1 316 | elif 'reward = 0' in last_step: 317 | stats["fail"] += 1 318 | else: 319 | stats["halted"] += 1 320 | else: 321 | raise NotImplementedError(f'recompute_stats for {benchmark} not implemented') 322 | 323 | return stats 324 | 325 | def plot_trial_stats(parsed_result: List[List[str]], benchmark: str, max_trials: int = 4, save_path: str = None) -> Dict[str, List[int]]: 326 | """ 327 | Plots the stats from the parsed log text. 328 | 329 | Args: 330 | parsed_result: The parsed log text. 331 | benchmark: The benchmark name. 332 | max_trials: The number of trials. 333 | save_path: The path to save the figure. 334 | 335 | Returns: 336 | The stats for the given benchmark. 337 | """ 338 | results = dict() 339 | colors = {'success': 'green', 'fail': 'red', 'halted': 'orange'} 340 | 341 | for i in range(max_trials): 342 | stats = recompute_stats(parsed_result, benchmark, i) 343 | for key, value in stats.items(): 344 | results[key] = results.get(key, []) + [value] 345 | 346 | if benchmark == 'alfworld': 347 | assert len(parsed_result) == 134 348 | results = {k: [round(x / 134 * 100, 2) for x in v] for k, v in results.items()} 349 | else: 350 | assert len(parsed_result) == 100 351 | 352 | for i, (key, value) in enumerate(results.items()): 353 | plt.plot(value, label=key, marker='o', color=colors[key]) 354 | 355 | # annotate all points with their values next to them 356 | for i in range(max_trials): 357 | for key, value in results.items(): 358 | plt.annotate(value[i], (i, value[i]), textcoords="offset points", xytext=(0,10), ha='center') 359 | 360 | plt.legend(loc='best') 361 | plt.xlabel("Reflection numbers") 362 | plt.ylabel("Task SR %") 363 | plt.xticks(range(max_trials)) 364 | plt.show() 365 | if save_path: 366 | plt.savefig(save_path) 367 | 368 | return results 369 | 370 | def get_env_name_from_gamefile(gamefile: str) -> Union[str, None]: 371 | """ 372 | Gets the environment name from the gamefile for ALFWorld. 373 | 374 | Args: 375 | gamefile: The gamefile. 376 | 377 | Returns: 378 | The environment name. 379 | """ 380 | for k in ENV_NAMES: 381 | if k in gamefile: 382 | return k 383 | 384 | def get_env_name_from_task(task: str, benchmark: str) -> Union[str, None]: 385 | """ 386 | Gets the environment name from the task instruction for ALFWorld. 387 | 388 | Args: 389 | task: The task. 390 | benchmark: The benchmark name. 391 | 392 | Returns: 393 | The environment name. 
394 | """ 395 | if benchmark == 'alfworld': 396 | for k, v in TASK_ENV_NAMES: 397 | if k in task: 398 | return v 399 | else: 400 | return benchmark 401 | 402 | def alfworld_results_per_env_name(agent: Dict[str, Any]) -> Dict[str, int]: 403 | """ 404 | Computes the results per environment name for ALFWorld from agent dict. 405 | 406 | Args: 407 | agent: The agent dictionary. 408 | 409 | Returns: 410 | The results per environment name. 411 | """ 412 | tasks = agent['tasks'] 413 | succeeded_trial_history = agent['succeeded_trial_history'] 414 | failed_trial_history = agent['failed_trial_history'] 415 | 416 | results = {k: 0 for k in ENV_NAMES} 417 | totals = {k: 0 for k in ENV_NAMES} 418 | seen = {t['task']: 0 for t in tasks} 419 | for task in tasks: 420 | t = task['task'] 421 | env_name = get_env_name_from_gamefile(task['env_kwargs']['gamefile']) 422 | if len(succeeded_trial_history[t]) > seen[t]: 423 | results[env_name] += 1 424 | totals[env_name] += 1 425 | seen[t] += 1 426 | elif len(failed_trial_history[t]) > 0: 427 | totals[env_name] += 1 428 | 429 | return {k: round(results[k] / totals[k], 2) if totals[k] != 0 else 0 for k in results.keys()} 430 | 431 | def alfworld_results_per_env_name_log(log: str, num_tasks: int, num_trials: int) -> Dict[str, int]: 432 | """ 433 | Computes the results per environment name for ALFWorld from log text. 434 | 435 | Args: 436 | log: The log text. 437 | num_tasks: The number of tasks. 438 | num_trials: The number of trials. 439 | 440 | Returns: 441 | The results per environment name. 442 | """ 443 | results = {k: [0, 0, 0, 0] for k in ENV_NAMES} 444 | totals = {k: [0, 0, 0, 0] for k in ENV_NAMES} 445 | 446 | parsed_results = split_logs_by_task(log, num_tasks) 447 | for i in range(num_trials): 448 | for task in parsed_results: 449 | if 'You are in the middle of a room' not in task[i if i < len(task) else -1]: 450 | continue 451 | task_desc = re.findall(r'Your task is to: (.*)', task[i if i < len(task) else -1])[0] 452 | env_name = get_env_name_from_task(task_desc, 'alfworld') 453 | if env_name is None: 454 | raise ValueError(f'env_name is None for task {task_desc}') 455 | if 'SOLVED' in task[i if i < len(task) else -1].strip().split('\n')[-1]: 456 | results[env_name][i] += 1 457 | totals[env_name][i] += 1 458 | 459 | assert all([sum([totals[k][i] for k in ENV_NAMES]) == num_tasks for i in range(num_trials)]) 460 | print(totals) 461 | 462 | return {k: [round(results[k][i] / totals[k][i], 2) if totals[k][i] != 0 else 0 for i in range(num_trials)] for k in ENV_NAMES} 463 | 464 | def get_webshop_mean_score(log: str, num_tasks: int, num_trials: int) -> float: 465 | """ 466 | Computes the mean score for WebShop from log text. 467 | 468 | Args: 469 | log: The log text. 470 | num_tasks: The number of tasks. 471 | num_trials: The number of trials. 472 | 473 | Returns: 474 | The mean score. 475 | """ 476 | parsed_result = split_logs_by_task(text=log, num_tasks=num_tasks) 477 | 478 | assert len(parsed_result) == num_tasks 479 | return sum([ 480 | max([ 481 | float(parsed_result[k][i].strip().split('\n')[-1].split()[-1] 482 | if 'Your score' in parsed_result[k][i].strip().split('\n')[-1] else 0) 483 | for i in range(min(len(parsed_result[k]), num_trials)) 484 | ] 485 | ) for k in range(len(parsed_result))] 486 | ) / len(parsed_result) 487 | 488 | def get_webshop_mean_scores(log: str, num_tasks: int, num_trials: int) -> float: 489 | """ 490 | Computes the mean scores for WebShop from log text. 491 | 492 | Args: 493 | log: The log text. 
494 | num_tasks: The number of tasks. 495 | num_trials: The number of trials. 496 | 497 | Returns: 498 | The mean scores. 499 | """ 500 | return [get_webshop_mean_score(log, num_tasks, i) for i in range(1, num_trials + 1)] 501 | 502 | def get_fewshot_max_tokens(benchmark: str) -> int: 503 | """ 504 | Gets the maximum number of tokens in the fewshot tasks for the given benchmark. 505 | 506 | Args: 507 | benchmark: The benchmark name. 508 | 509 | Returns: 510 | The maximum number of tokens. 511 | """ 512 | fewshots = FEWSHOTS[benchmark] 513 | if isinstance(fewshots, dict): 514 | return max([max([token_counter(f) for f in fs]) for fs in fewshots.values()]) 515 | elif isinstance(fewshots, list): 516 | return max([token_counter(f) for f in fewshots]) 517 | 518 | def get_split_eval_idx_list(agent_dict: Dict[str, Any], n_folds: int) -> List[List[int]]: 519 | """ 520 | Gets the split evaluation index list. 521 | 522 | Args: 523 | agent_dict: The agent dictionary. 524 | n_folds: The number of folds. 525 | 526 | Returns: 527 | The split evaluation index list. 528 | """ 529 | eval_idx_list = [[] for _ in range(n_folds)] 530 | env_names = set(x['env_name'] for x in agent_dict['tasks']) 531 | task2idx = agent_dict['task2idx'] 532 | 533 | # compare success/failure 534 | compare_dict = {env_name: [] for env_name in env_names} 535 | success_dict = {env_name: [] for env_name in env_names} 536 | fail_dict = {env_name: [] for env_name in env_names} 537 | for task, trials in agent_dict['succeeded_trial_history'].items(): 538 | if len(trials) > 0: 539 | if len(agent_dict['failed_trial_history'][task]) > 0: 540 | compare_dict[get_env_name_from_task(task, agent_dict['benchmark_name'])].append(task2idx[task]) 541 | else: 542 | success_dict[get_env_name_from_task(task, agent_dict['benchmark_name'])].append(task2idx[task]) 543 | else: 544 | assert len(agent_dict['failed_trial_history'][task]) > 0 545 | fail_dict[get_env_name_from_task(task, agent_dict['benchmark_name'])].append(task2idx[task]) 546 | 547 | # split into n_folds 548 | j = 0 549 | for idx_list in list(compare_dict.values()) + list(success_dict.values()) + list(fail_dict.values()): 550 | random.shuffle(idx_list) 551 | for idx in idx_list: 552 | eval_idx_list[j % n_folds].append(idx) 553 | j += 1 554 | 555 | assert set.intersection(*[set(x) for x in eval_idx_list]) == set() 556 | 557 | return eval_idx_list 558 | 559 | def mode_results(benchmark: str, log: str, num_tasks: int, mode: str) -> Any: 560 | """ 561 | Computes the statistic results for the given mode. 562 | 563 | Args: 564 | benchmark: The benchmark name. 565 | log: The log text. 566 | num_tasks: The number of tasks. 567 | mode: The mode. 568 | 569 | Returns: 570 | The results. 
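
    Example (illustrative; the returned number depends on the log, and the mode
    grammar is documented in the MODE OPTIONS comment block in tools/visualize_logs.py):
        >>> mode_results('webshop', log, 100, 'react_mean_count_of_thought_per_traj')
        2.41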
571 | """ 572 | parsed_result = split_logs_by_task(text=log, num_tasks=num_tasks) 573 | if 'react' in mode: 574 | parsed_result = [x[0] for x in parsed_result] 575 | elif 'reflection' in mode: 576 | pattern = r'reflection(\d+)' 577 | res = re.findall(pattern, mode) 578 | if len(res) > 0: 579 | i = int(res[0]) 580 | else: 581 | pattern = r'(\d+)reflection' 582 | res = re.findall(pattern, mode) 583 | if len(res) > 0: 584 | i = int(res[0]) 585 | else: 586 | i = 0 587 | parsed_result = [x[i if i < len(x) else -1] for x in parsed_result] 588 | 589 | if 'token' in mode: 590 | f_trial = token_counter 591 | elif 'count' in mode: 592 | f_trial = lambda x: 1 593 | else: 594 | raise NotImplementedError(f'mode must contain token or count') 595 | 596 | invalid = None 597 | if 'invalid' in mode: 598 | if benchmark == 'alfworld': 599 | invalid = 'nothing happens' 600 | elif benchmark == 'webshop': 601 | invalid = 'invalid action' 602 | elif benchmark in ['hotpotqa', 'fever']: 603 | # invalid = 'similar: [' 604 | invalid = 'invalid action' 605 | mode += 'observation' 606 | 607 | if 'thought' in mode: 608 | if benchmark == 'webshop': 609 | lambda_filter = lambda y: y.strip().startswith('Action: think[') 610 | elif benchmark == 'alfworld': 611 | lambda_filter = lambda y: y.strip().startswith('> think:') 612 | elif benchmark in ['hotpotqa', 'fever']: 613 | lambda_filter = lambda y: y.strip().startswith('Thought') 614 | else: 615 | raise NotImplementedError(f'benchmark {benchmark} not implemented') 616 | elif 'action' in mode: 617 | if benchmark == 'webshop': 618 | lambda_filter = lambda y: y.strip().startswith('Action: click[') or y.strip().startswith('Action: search[') # valid actions 619 | elif benchmark == 'alfworld': 620 | lambda_filter = lambda y: y.strip().startswith('> ') and not y.strip().startswith('> think:') # valid and invalid actions 621 | elif benchmark in ['hotpotqa', 'fever']: 622 | lambda_filter = lambda y: y.strip().startswith('Action') 623 | else: 624 | raise NotImplementedError(f'benchmark {benchmark} not implemented') 625 | elif 'observation' in mode: 626 | if benchmark == 'webshop': 627 | lambda_filter = lambda y: y.strip().startswith('Observation:') 628 | elif benchmark == 'alfworld': 629 | lambda_filter = lambda y: not y.strip().startswith('> ') 630 | elif benchmark in ['hotpotqa', 'fever']: 631 | lambda_filter = lambda y: y.strip().startswith('Observation') 632 | else: 633 | raise NotImplementedError(f'benchmark {benchmark} not implemented') 634 | else: 635 | lambda_filter = lambda y: True 636 | 637 | if 'step' in mode: 638 | if invalid is not None: 639 | parsed_result = [y for x in parsed_result for y in x.strip().split('\n') if lambda_filter(y) and invalid in y.lower()] 640 | else: 641 | parsed_result = [y for x in parsed_result for y in x.strip().split('\n') if lambda_filter(y)] 642 | 643 | elif 'traj' in mode: 644 | if invalid is not None: 645 | parsed_result = [[y for y in x.strip().split('\n') if lambda_filter(y) and invalid in y.lower()] for x in parsed_result] 646 | else: 647 | parsed_result = [[y for y in x.strip().split('\n') if lambda_filter(y)] for x in parsed_result] 648 | assert len(parsed_result) == num_tasks 649 | else: 650 | raise NotImplementedError('mode must contain traj or step') 651 | 652 | if 'sum' in mode: 653 | f_return = sum 654 | elif 'mean' in mode: 655 | f_return = lambda x: sum(x) / len(x) if len(x) > 0 else 0 656 | elif 'list' in mode: 657 | f_return = lambda x: x 658 | else: 659 | raise NotImplementedError('mode must contain sum, mean or list') 660 | 
661 | if 'print' in mode: 662 | print(parsed_result) 663 | 664 | if len(parsed_result) > 0 and isinstance(parsed_result[0], List): 665 | return f_return([sum([f_trial(trial) for trial in result]) for result in parsed_result]) 666 | 667 | return f_return([f_trial(trial) for trial in parsed_result]) 668 | --------------------------------------------------------------------------------