├── .gitignore
├── LICENSE
├── README.md
├── agent
├── __init__.py
├── base.py
├── expel.py
├── react.py
└── reflect.py
├── assets
├── eval_example.png
├── expel.png
└── header_baby.png
├── configs
├── agent
│ └── expel.yaml
├── benchmark
│ ├── alfworld.yaml
│ ├── fever.yaml
│ ├── hotpotqa.yaml
│ └── webshop.yaml
├── eval.yaml
├── insight_extraction.yaml
├── train.yaml
└── visualize_logs.yaml
├── data
├── alfworld
│ └── alfworld_tasks_suffix.json
├── fever
│ └── paper_dev.jsonl
├── hotpotqa
│ └── hotpot-qa-distractor-sample.joblib
└── webshop
│ └── webshop.fixed100.json
├── envs
├── __init__.py
├── alfworld
│ └── alfworld.py
├── base.py
├── fever
│ ├── fever.py
│ ├── wikienv.py
│ └── wrappers.py
├── hotpotqa
│ └── hotpotqa.py
└── webshop
│ └── webshop.py
├── eval.py
├── insight_extraction.py
├── memory
├── __init__.py
└── episode.py
├── models
├── __init__.py
└── llm.py
├── prompts
├── __init__.py
├── alfworld.py
├── fever.py
├── hotpotQA.py
├── templates
│ ├── human.py
│ └── system.py
└── webshop.py
├── requirements.txt
├── tools
└── visualize_logs.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | logs/
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | cover/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | .pybuilder/
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | # For a library or package, you might want to ignore these files since the code is
88 | # intended to run in multiple environments; otherwise, check them in:
89 | # .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # poetry
99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100 | # This is especially recommended for binary packages to ensure reproducibility, and is more
101 | # commonly ignored for libraries.
102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103 | #poetry.lock
104 |
105 | # pdm
106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107 | #pdm.lock
108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109 | # in version control.
110 | # https://pdm.fming.dev/#use-with-ide
111 | .pdm.toml
112 |
113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
114 | __pypackages__/
115 |
116 | # Celery stuff
117 | celerybeat-schedule
118 | celerybeat.pid
119 |
120 | # SageMath parsed files
121 | *.sage.py
122 |
123 | # Environments
124 | .env
125 | .venv
126 | env/
127 | venv/
128 | ENV/
129 | env.bak/
130 | venv.bak/
131 |
132 | # Spyder project settings
133 | .spyderproject
134 | .spyproject
135 |
136 | # Rope project settings
137 | .ropeproject
138 |
139 | # mkdocs documentation
140 | /site
141 |
142 | # mypy
143 | .mypy_cache/
144 | .dmypy.json
145 | dmypy.json
146 |
147 | # Pyre type checker
148 | .pyre/
149 |
150 | # pytype static type analyzer
151 | .pytype/
152 |
153 | # Cython debug symbols
154 | cython_debug/
155 |
156 | # PyCharm
157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159 | # and can be added to the global gitignore or merged into this file. For a more nuclear
160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161 | #.idea/
162 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
| ExpeL: LLM Agents are Experiential Learners
2 | :---:|:---:
3 |
4 |
5 |
6 | ⚡ [AAAI 2024 *(Oral)*] Official implementation of the ExpeL Agent ⚡
7 |
8 | ~ by Andrew Zhao, Daniel Huang, Quentin Xu, Matthieu Lin, Yong-Jin Liu, Gao Huang ~
9 |
10 |
11 | [](https://github.com/LeapLabTHU/ExpeL/releases)
12 | 
13 | [](https://star-history.com/#LeapLabTHU/ExpeL)
14 | [](https://github.com/LeapLabTHU/ExpeL/issues)
15 |
16 | ---
17 | ### 🌐 $\cdot$ [Project Page](https://andrewzh112.github.io/expel) 📄 $\cdot$ [Paper](https://arxiv.org/pdf/2308.10144.pdf)
18 |
19 | > "A computer program is said to learn from experience E with respect to some class of tasks T and performance measure P, if its performance at tasks in T, as measured by P, improves with experience E." - Tom Mitchell
20 |
21 | # 📖 Table of Contents
22 |
23 |
24 |
25 |
26 |
27 |
28 | [👋 Introduction](#-introduction)
29 |
30 | [🛠️ Installation](#%EF%B8%8F-installation)
31 |
32 | - [🌳 Environments](#-environments)
33 | - [🏠 ALFWorld](#-alfworld)
34 | - [🛒 Webshop](#-webshop)
35 |
36 | [🚀 Quick start](#-quick-start)
37 | 1. [Experience Gathering](#1-for-the-experience-gathering-stage)
38 | 2. [Insights Extraction](#2-for-the-insights-extraction-stage)
39 | 3. [Evaluation](#3-for-evaluation)
40 |
41 | [🫡 Cite us !](#-cite-us-)
42 |
43 | [💌 Contact us !](#-contact-us-)
44 |
45 | [🏛️ License](#%EF%B8%8F-license)
46 |
47 | [⚠️ Issues](#%EF%B8%8F-issues)
48 |
49 |
50 |
51 |
52 |
53 |
54 |

55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 | ## 👋 Introduction
63 |
64 | This repo is the official implementation of [Expel: LLM Agents are Experiential Learners](https://arxiv.org/pdf/2308.10144.pdf).
65 |
66 | Our agent autonomously gathers experiences and extracts knowledge using natural language from a collection of training tasks. At inference, the agent recalls its extracted insights and past experiences to make informed decisions. Our empirical results highlight the robust learning efficacy of the ExpeL agent, indicating a consistent enhancement in its performance as it accumulates experiences.
67 |
68 | ## 🛠️ Installation
69 | Python version : 3.9.17
70 |
71 | 1. Create a virtual environment using [Anaconda](https://anaconda.org/anaconda/python) (or your favorite package manager), activate it, clone the repo and install the requirements.
72 |
73 | ```sh
74 | conda create -n expel python=3.9.17
75 | conda activate expel
76 |
77 | git clone https://github.com/LeapLabTHU/ExpeL.git expel
78 | cd expel
79 |
80 | pip install -r requirements.txt
81 | ```
82 |
83 | Next you need to setup the environments.
84 |
85 | ## 🌳 Environments
86 |
87 | Baby ExpeL has been playing around with the following environments:
88 |
89 | - ❓[HotpotQA](https://github.com/hotpotqa/hotpot)
90 | - 🏠 [ALFWorld](https://github.com/alfworld/alfworld)
91 | - 🛒 [WebShop](https://github.com/princeton-nlp/WebShop)
92 | - 🌡️ [FEVER](https://github.com/awslabs/fever)
93 |
94 | Among these, ALFWorld and WebShop require manual installation (+ loading a server (can be local) for WebShop). Details below:
95 |
96 | ### 🏠 ALFWorld
97 | The installation instructions are shown below. Use the previously created environment to install ALFWorld.
98 | You will also need to download the data at the specified location: ``data/alfworld``.
99 | ```Bash
100 | conda activate expel
101 | pip install alfworld[full]
102 |
103 | export ALFWORLD_DATA="data/alfworld"
104 | alfworld-download
105 | ```
106 | If you need more details, please refer to the [official repo](https://github.com/alfworld/alfworld#quickstart).
107 |
108 | ### 🛒 WebShop
109 |
110 | WebShop installation is different from the other environments. You will have to install it and **manually** run the server (can be local) in parallel of ExpeL to interact with the environment.
111 | The succinct installation instructions are shown below.
112 |
113 | ```bash
114 | git clone https://github.com/princeton-nlp/webshop.git webshop
115 | cd webshop
116 |
117 | # Create another env for the webshop server to avoid conflicts
118 | conda create -n webshop python=3.8.13
119 | conda activate webshop
120 |
121 | ./setup.sh -d all
122 | ```
123 |
124 | By default the WebShop only loads 1,000 products. But we need ALL OF THEM (🤯). So change ``web_agent_site/utils.py``:
125 |
126 | ```python
127 | # DEFAULT_ATTR_PATH = join(BASE_DIR, '../data/items_ins_v2_1000.json')
128 | # DEFAULT_FILE_PATH = join(BASE_DIR, '../data/items_shuffle_1000.json')
129 | DEFAULT_ATTR_PATH = join(BASE_DIR, '../data/items_ins_v2.json')
130 | DEFAULT_FILE_PATH = join(BASE_DIR, '../data/items_shuffle.json')
131 | ```
132 | To run the server, run the following command:
133 | ```bash
134 | ./run_dev.sh
135 | ```
136 | You will be given an URL (and port) once the website is on:
137 | - Go back to the cloned ExpeL repo
138 | - Modify the config file and add the given URL in ``envs/webshop/webshop.py``:
139 | ```python
140 | WEBSHOP_URL = "http://127.0.0.1:3000" # Example URL
141 | ```
142 |
143 | Note that you will have to run the WebShop server in the background to interact with the environment. We gathered some bugs we encountered during the WebShop Server setup [here](#issues).
144 |
145 | If you need more details, please refer to the [official repo](https://github.com/princeton-nlp/WebShop?tab=readme-ov-file#-setup).
146 |
147 |
148 | ## 🚀 Quick start
149 |
150 | Below are the commands to run the ExpeL Agent.
151 |
152 | **Either put your OpenAI API key in a ``.env`` file (OPENAI_API_KEY=XXX) or get prompted in command line**
153 |
154 | ### 1. For the **Experience Gathering** stage:
155 | ```bash
156 | python train.py benchmark= \
157 | run_name= \
158 | testing=false \
159 | resume=false
160 |
161 | # resume = true/false if you want to resume a previous run
162 | # benchmark = {hotpotqa, alfworld, webshop, fever}
163 | # agent.llm = {gpt-3.5-turbo (default), gpt-4}
164 | ```
165 | Below are the commands to run the experience gathering stage as in the paper:
166 | ```bash
167 | # 🏠 ALFWorld
168 | python train.py benchmark=alfworld run_name= testing=false resume=false
169 | # 🛒 WebShop
170 | python train.py benchmark=webshop run_name= testing=false resume=false
171 | # ❓ HotpotQA
172 | python train.py benchmark=hotpotqa run_name= testing=false resume=false
173 | ```
174 |
175 | By default, the result files (logs, dictionnaries) will be saved in ``logs//expel`` referenced by ````. You can change the log directory by adding ``log_dir=`` to the command line.
176 |
177 |
178 | ### 2. For the **Insights Extraction** stage:
179 | Use the collected experiences to extract insights.
180 |
181 | ```bash
182 | python insight_extraction.py \
183 | benchmark= \
184 | load_run_name= \
185 | run_name= \
186 | agent.llm= \
187 | agent.max_num_rules= \
188 | agent.success_critique_num= \
189 | testing=true \
190 | resume=false
191 |
192 | # agent.success_critique_num = number of experiences to give per iteration
193 | # agent.max_num_rules = target number of insights to extract
194 | ```
195 |
196 | To resume a run that stopped at a specific fold, remove ``load_run_name`` from the parameters and specify the fold ``resume_fold`` it stopped at and ``resume=true``.
197 |
198 |
199 | Below are the commands to run the insights extraction stage as in the paper:
200 | ```bash
201 | # 🏠 ALFWorld
202 | python insight_extraction.py benchmark=alfworld load_run_name= run_name= agent.llm=gpt-4 agent.max_num_rules=10 agent.success_critique_num=8 testing=false resume=false
203 | # 🛒 WebShop
204 | python insight_extraction.py benchmark=webshop load_run_name= run_name= agent.llm=gpt-4 agent.max_num_rules=8 agent.success_critique_num=4 testing=false resume=false
205 | # ❓ HotpotQA
206 | python insight_extraction.py benchmark=hotpotqa load_run_name= run_name= agent.llm=gpt-4 agent.max_num_rules=10 agent.success_critique_num=8 testing=false resume=false
207 | ```
208 |
209 | The final result files will be saved in ``logs//expel/extracted_insights`` referenced by ````.
210 |
211 |
212 | ### 3. For **Evaluation**:
213 | ```bash
214 | python eval.py benchmark= \
215 | load_run_name=extracted_insights/ \
216 | run_name= \
217 | benchmark.eval_configs.k_folds= \
218 | agent.fewshot_strategy=task_similarity \
219 | agent.retrieval_kwargs.max_fewshot_tokens= \
220 | agent.retrieval_kwargs.buffer_retrieve_ratio = \
221 | testing=false \
222 | resume=false
223 |
224 | # agent.fewshot_strategy = {task_similarity, thought_similarity,task_thought_similarity)
225 | # agent.llm = {gpt-3.5-turbo (default), gpt-4}
226 | # agent.retrieval_kwargs.max_fewshot_tokens=auto
227 | # benchmark.eval_configs.k_folds=2
228 | # agent.retrieval_kwargs.buffer_retrieve_ratio = safety measure to not retrieve 0 examples (bigger is safer)
229 | ```
230 |
231 | To resume a run that stopped, remove ``load_run_name`` from the parameters and add ``resume=true`` at the end of the command line.
232 |
233 | Below are the commands to evalute ExpeL as in the paper:
234 |
235 | ```bash
236 | # 🏠 ALFWorld
237 | python eval.py benchmark=alfworld load_run_name=extracted_insights/ run_name= agent.fewshot_strategy=task_similarity agent.retrieval_kwargs.max_fewshot_tokens=auto testing=false resume=false
238 | # 🛒 WebShop
239 | python eval.py benchmark=webshop load_run_name=extracted_insights/ run_name= agent.fewshot_strategy=task_similarity agent.retrieval_kwargs.max_fewshot_tokens=auto agent.retrieval_kwargs.buffer_retrieve_ratio=20 testing=false resume=false
240 | # ❓ HotpotQA
241 | python eval.py benchmark=hotpotqa load_run_name=extracted_insights/ run_name= agent.fewshot_strategy=task_similarity testing=false resume=false
242 | ```
243 | The result files will be saved in ``logs//expel/eval`` referenced by ````.
244 |
245 | ## 🫡 Cite us !
246 |
247 | This repository contains code for reproducing results. If you find this work useful in your research (and/or daily life), please cite:
248 |
249 | ```
250 | @inproceedings{zhao2024expel,
251 | author = {Andrew Zhao and Daniel Huang and Quentin Xu and Matthieu Lin and Yong-Jin Liu and Gao Huang},
252 | title = {ExpeL: LLM Agents Are Experiential Learners},
253 | booktitle = {Thirty-Eighth {AAAI} Conference on Artificial Intelligence, {AAAI}
254 | 2024, Thirty-Sixth Conference on Innovative Applications of Artificial
255 | Intelligence, {IAAI} 2024, Fourteenth Symposium on Educational Advances
256 | in Artificial Intelligence, {EAAI} 2024, February 20-27, 2024, Vancouver,
257 | Canada},
258 | editor = {Michael J. Wooldridge and Jennifer G. Dy and Sriraam Natarajan},
259 | year = {2024},
260 | pages = {19632--19642},
261 | publisher = {{AAAI} Press},
262 | url = {https://ojs.aaai.org/index.php/AAAI/article/view/29936},
263 | doi = {10.1609/aaai.v38i17.29936}
264 | }
265 | ```
266 |
267 | ## 💌 Contact us !
268 |
269 | If you have any questions, feel free to contact [Andrew Zhao](mailto:zqc21@mails.tsinghua.edu.cn), [Daniel Huang](mailto:huang-jy22@mails.tsinghua.edu.cn) or [Quentin Xu](mailto:quentinxu1@gmail.com).
270 |
271 |
272 | ## 🏛️ License
273 | Check `LICENSE.md`
274 |
275 | ## ⚠️ Issues
276 | We encountered some errors and gathered them here (note that at time of reading, they might have been fixed). If you don't encountered them, lucky you 😒.
277 |
278 | **🛒 WebShop**-server installation:
279 |
280 | ```Bash
281 | # install
282 | python -m spacy download en_core_web_lg #
283 | pip install lightgbm nmslib # need to have c compiler and stuff
284 | conda install mkl # if ImportError: libmkl_intel_lp64.so.1:
285 | pip install pysernini
286 | pip install pyserini --no-cache-dir # if low on ram
287 | pip install typing-inspect==0.8.0 typing_extensions==4.5.0 # if issubclass errors
288 | # if libjvm.so something, need to export JAVA_HOME
289 | ./setup.sh -d all
290 | ```
291 | On Mac, if you have problem with lightgbm or nmslib you might have to replace their pip to:
292 | ```Bash
293 | brew install cmake libomp
294 | pip install lightgbm
295 |
296 | CFLAGS="-mavx -DWARN(a)=(a)" pip install --use-pep517 nmslib
297 | ```
298 |
--------------------------------------------------------------------------------
/agent/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import BaseAgent
2 | from .react import ReactAgent
3 | from .reflect import ReflectAgent
4 | from .expel import ExpelAgent
5 |
6 |
7 | AGENT = dict(reflection=ReflectAgent, react=ReactAgent, expel=ExpelAgent)
8 |
--------------------------------------------------------------------------------
/agent/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | import re
3 |
4 | from prompts.templates.human import human_task_message_prompt
5 |
6 |
7 | class BaseAgent(ABC):
8 | """
9 | Base agent class.
10 | """
11 | @abstractmethod
12 | def run(self, *args, **kwargs) -> None:
13 | pass
14 |
15 | @abstractmethod
16 | def step(self) -> None:
17 | pass
18 |
19 | @abstractmethod
20 | def prompt_agent(self) -> str:
21 | pass
22 |
23 | @abstractmethod
24 | def _build_agent_prompt(self) -> str:
25 | pass
26 |
27 | @abstractmethod
28 | def after_step(self, *args, **kwargs) -> None:
29 | pass
30 |
31 | def is_terminated(self) -> bool:
32 | return self.env.is_terminated()
33 |
34 | def is_truncated(self) -> bool:
35 | return self.env.is_truncated() or (self.token_counter(self.log_history(include_all=True)) > 15800)
36 |
37 | @abstractmethod
38 | def reset(self, *args, **kwargs) -> None:
39 | pass
40 |
41 | def log_history(self, include_task: bool = True, include_all: bool = False) -> str:
42 | all_history = '\n'.join([prompt.content for prompt in self.prompt_history])
43 | if include_all:
44 | return all_history
45 |
46 | # only log the task prompt and the agent's response
47 | reflection_pattern = r'{}'.format(self.format_reflections(self.reflections, include_prefix=False))
48 | match = re.search(re.escape(reflection_pattern), all_history)
49 | if not match or match.group() == '' or not include_task:
50 | task_text_list = human_task_message_prompt.format_messages(task=self.remove_task_suffix(self.task))[0].content.split('\n')
51 | task_text = '\n'.join(task_text_list)
52 | pattern = r'{}'.format(re.escape(task_text.strip()) + '.*')
53 | match = re.search(pattern, all_history)
54 | if include_task:
55 | return match.group().lstrip("Now it's your turn!\n") + match.string[match.end():]
56 | return match.string[match.end():].strip()
57 |
58 | def remove_task_suffix(self, task: str) -> str:
59 | if self.benchmark_name == 'alfworld':
60 | return task.split('___')[0]
61 | return task
--------------------------------------------------------------------------------
/agent/react.py:
--------------------------------------------------------------------------------
1 | from typing import List, Callable, Tuple, Dict, Any, Union
2 | from functools import partial
3 | from copy import deepcopy
4 |
5 | from langchain.prompts import PromptTemplate
6 | from langchain.schema import ChatMessage
7 | from openai.error import InvalidRequestError
8 |
9 | from envs import BaseEnv
10 | from agent import BaseAgent
11 | from prompts.templates.human import (
12 | human_instruction_fewshot_message_prompt,
13 | human_task_message_prompt,
14 | )
15 | from utils import print_message, token_counter
16 |
17 | class ReactAgent(BaseAgent):
18 | """
19 | A Generic ReAct Agent.
20 | """
21 | def __init__(self,
22 | name: str,
23 | system_instruction: Union[str, Dict[str, str]],
24 | human_instruction: Callable,
25 | fewshots: Union[List[str], Dict[str, List[str]]],
26 | system_prompt: Callable,
27 | env: BaseEnv,
28 | llm: str,
29 | llm_builder: Callable,
30 | openai_api_key: str,
31 | tasks: List[Dict[str, Any]],
32 | max_steps: int,
33 | llm_parser: Callable,
34 | observation_formatter: Callable,
35 | testing: bool = False,
36 | task_idx: int = 0,
37 | benchmark_name = None,
38 | *args,
39 | **kwargs,
40 | ) -> None:
41 | self.benchmark_name = benchmark_name
42 | self.name = name
43 | self.tasks = tasks
44 | self.task_idx = task_idx
45 | self.all_system_instruction = system_instruction
46 | self.human_instruction = human_instruction
47 | self.human_instruction_kwargs = {'max_steps': max_steps}
48 | self.all_fewshots = fewshots
49 | self.system_prompt = system_prompt
50 | self.prompt_history = []
51 | self.testing = testing
52 | self.max_steps = max_steps
53 | self.llm_parser = llm_parser
54 | self.observation_formatter = observation_formatter
55 | self._last_observation_history = None
56 |
57 | self.env = env(**self.tasks[self.task_idx]['env_kwargs'], max_steps=self.max_steps)
58 | self.env.reset()
59 | self.task = self.tasks[self.task_idx]['task']
60 | self.reset()
61 | self.truncated, self.reward, self.terminated = False, False, False
62 | self.print_message = partial(print_message, testing=testing)
63 |
64 | self.success, self.fail, self.halted = 0, 0, 0
65 |
66 | self.llm = llm_builder(llm_name=llm, openai_api_key=openai_api_key, long_ver=False)
67 | self.long_context_llm = llm_builder(llm_name=llm, openai_api_key=openai_api_key, long_ver=True)
68 | del openai_api_key
69 | self.token_counter = partial(token_counter, llm=llm, tokenizer=getattr(self.llm, 'tokenizer', None))
70 |
71 | # build base prompt
72 | self._build_agent_prompt()
73 | self.update_dynamic_prompt_components()
74 |
75 | self.long_pass = None
76 |
77 | def is_success(self) -> bool:
78 | return self.env.success_fn()
79 |
80 | def set_env(self, task_kwargs: Dict[str, Any], max_steps: int):
81 | self.env.__init__(**task_kwargs, max_steps=max_steps)
82 |
83 | def run(self, reset: bool = True, *args, **kwargs) -> None:
84 | if reset:
85 | self.env.reset()
86 | self.reset()
87 |
88 | while not (self.is_truncated() or self.is_terminated()):
89 | self.step()
90 |
91 | def step(self) -> None:
92 | message, message_type, others = self.llm_parser(self.prompt_agent(), self.curr_step, False)
93 | self.prompt_history.append(message)
94 | self.print_message(message)
95 |
96 | thought_num = 1
97 | # loops while in thinking mode
98 | while message_type == 'thought':
99 | thought_num += 1
100 | message, message_type, others = self.llm_parser(self.prompt_agent(), self.curr_step, False)
101 | self.prompt_history.append(message)
102 | self.print_message(message)
103 |
104 | if thought_num > 2:
105 | if message_type == 'thought':
106 | others['action'] = 'N/A'
107 | break
108 |
109 | # Observe
110 | observation, self.reward, self.terminated, self.truncated, _ = self.env.step(others['action'])
111 | if others['action'] == 'N/A' and thought_num > 2:
112 | observation = "You are thinking too many times without taking action."
113 | observation_history, operation = self.observation_formatter(observation, step=self.curr_step)
114 | if operation == 'append':
115 | self.prompt_history.append(observation_history)
116 | elif operation == 'replace':
117 | for message in self.prompt_history:
118 | if self._last_observation_history.content in message.content:
119 | message.content = message.content.replace(self._last_observation_history.content, observation_history.content)
120 | break
121 | self._last_observation_history = deepcopy(observation_history)
122 | self.print_message(observation_history)
123 |
124 | BaseAgent.after_step(self)
125 |
126 | self.prompt_history = self.collapse_prompts(self.prompt_history)
127 |
128 | self.curr_step += 1
129 |
130 | def prompt_agent(self) -> str:
131 | self.prompt_history = self.collapse_prompts(self.prompt_history)
132 | self.update_dynamic_prompt_components()
133 | prompt_history = self.collapse_prompts(self.prompt_history)
134 | if self.testing:
135 | print('###################################')
136 | for prompt in prompt_history:
137 | self.print_message(prompt, self.token_counter)
138 | return input()
139 | try:
140 | return self.llm(prompt_history, stop=['\n', '\n\n'])
141 | except InvalidRequestError:
142 | while self.long_pass is None:
143 | res = input('Changing to long context LLM. Press Enter to continue.\n')
144 | if res == 'pass':
145 | self.long_pass = True
146 | elif res != '':
147 | continue
148 | break
149 |
150 | return self.long_context_llm(prompt_history, stop=['\n', '\n\n'])
151 |
152 | def _build_fewshot_prompt(
153 | self,
154 | fewshots: List[str],
155 | prompt_history: List[ChatMessage],
156 | instruction_prompt: PromptTemplate,
157 | instruction_prompt_kwargs: Dict[str, Any],
158 | prompt_type: str,
159 | ) -> str:
160 | if human_instruction_fewshot_message_prompt is not None and instruction_prompt is not None:
161 | prompt_history.append(
162 | human_instruction_fewshot_message_prompt('message_style_kwargs').format_messages(
163 | instruction=instruction_prompt.format_messages(
164 | **instruction_prompt_kwargs)[0].content,
165 | fewshots='\n\n'.join(fewshots)
166 | )[0]
167 | )
168 |
169 | def _build_agent_prompt(self) -> None:
170 | system_prompt = self.system_prompt.format_messages(
171 | instruction=self.system_instruction, ai_name=self.name
172 | )
173 | self.prompt_history.extend(system_prompt)
174 | self._build_fewshot_prompt(
175 | fewshots=self.fewshots, prompt_history=self.prompt_history,
176 | instruction_prompt=self.human_instruction,
177 | instruction_prompt_kwargs=self.human_instruction_kwargs,
178 | prompt_type='react_type',
179 | )
180 | self.prompt_history = self.collapse_prompts(self.prompt_history)
181 | self.log_idx = len(self.prompt_history)
182 | self.insert_before_task_prompt()
183 |
184 | self.prompt_history.append(human_task_message_prompt.format_messages(task=self.remove_task_suffix(self.task))[0])
185 | self.insert_after_task_prompt()
186 | self.prompt_history = self.collapse_prompts(self.prompt_history)
187 | self.pretask_idx = len(self.prompt_history)
188 | return self.prompt_history
189 |
190 | def reset(self, *args, **kwargs) -> None:
191 | self.prompt_history = []
192 | self.update_dynamic_prompt_components(reset=True)
193 | self.curr_step = 1
194 | self._build_agent_prompt()
195 |
196 | def insert_after_task_prompt(self) -> None:
197 | return
198 |
199 | def job_not_done(self) -> bool:
200 | return self.task_idx < len(self.tasks)
201 |
202 | def next_task(self):
203 | self.task_idx += 1
204 | # if there are more tasks, reset the env and the agent
205 | if self.job_not_done():
206 | self.task = self.tasks[self.task_idx]['task']
207 | self.set_env(self.tasks[self.task_idx]['env_kwargs'], max_steps=self.max_steps)
208 | self.env.reset()
209 | self.reset()
210 |
211 | def reset_stats(self) -> None:
212 | self.success = 0
213 | self.fail = 0
214 | self.halted = 0
215 |
216 | def update_stats(self) -> None:
217 | if not self.is_success() and self.is_truncated():
218 | self.halted += 1
219 | else:
220 | if self.reward:
221 | self.success += 1
222 | else:
223 | self.fail += 1
224 |
225 | def get_stats(self) -> Tuple[int, int, int]:
226 | return self.success, self.fail, self.halted
227 |
228 | def collapse_prompts(self, prompt_history: List[ChatMessage]) -> List[ChatMessage]:
229 | """Courtesy of GPT4"""
230 | if not prompt_history:
231 | return []
232 |
233 | new_prompt_history = []
234 | scratch_pad = prompt_history[0].content
235 | last_message_type = type(prompt_history[0])
236 |
237 | for message in prompt_history[1:]:
238 | current_message_type = type(message)
239 | if current_message_type == last_message_type:
240 | scratch_pad += '\n' + message.content
241 | else:
242 | new_prompt_history.append(last_message_type(content=scratch_pad))
243 | scratch_pad = message.content
244 | last_message_type = current_message_type
245 |
246 | # Handle the last accumulated message
247 | new_prompt_history.append(last_message_type(content=scratch_pad))
248 |
249 | return new_prompt_history
250 |
251 | def update_dynamic_prompt_components(self):
252 | #####################
253 | # Updating fewshots #
254 | #####################
255 | if isinstance(self.all_fewshots, dict):
256 | self.fewshots = self.all_fewshots[self.env.env_name]
257 | elif isinstance(self.all_fewshots, list):
258 | self.fewshots = self.all_fewshots
259 |
260 | #########################
261 | # Updating instructions #
262 | #########################
263 | if isinstance(self.all_system_instruction, str):
264 | self.system_instruction = self.all_system_instruction
265 | elif isinstance(self.all_system_instruction, dict):
266 | self.system_instruction = self.all_system_instruction[self.env.env_name]
267 | # if system gives instruction, then human instruction is empty
268 | self.human_instruction_kwargs['instruction'] = ''
269 | self.num_fewshots = len(self.fewshots)
270 |
271 | def load_checkpoint(self, loaded_dict: Dict[str, Any], no_load_list: List['str'] = []) -> None:
272 | for k, v in loaded_dict.items():
273 | if k in no_load_list:
274 | continue
275 | setattr(self, k, v)
276 | # following attributes are not saved in pickle but correctely initialized back: ['rule_template', 'truncate_strategy', 'embedder', 'retriever_cls', 'manual', 'reflection_task_prompt', 'message_splitter', 'identifier', 'message_step_splitter', 'format_reflections', 'formatted_reflection', 'human_instruction', 'system_prompt', 'llm_parser', 'observation_formatter', 'env', 'print_message', 'llm', 'long_context_llm', 'token_counter']
277 |
--------------------------------------------------------------------------------
/agent/reflect.py:
--------------------------------------------------------------------------------
1 | from typing import List, Callable
2 |
3 | from openai.error import InvalidRequestError
4 | from langchain.schema import HumanMessage
5 |
6 | from agent.react import ReactAgent
7 | from utils import Count
8 | import re
9 |
10 | class ReflectAgent(ReactAgent):
11 | """
12 | A Generic Reflection Agent.
13 | """
14 | def __init__(self,
15 | reflection_fewshots: List[str],
16 | reflection_task_prompt: Callable,
17 | reflection_system_instruction: str,
18 | max_relfection_depth: int,
19 | message_splitter: Callable,
20 | identifier: Callable,
21 | message_step_splitter: Callable,
22 | reflection_prefix: str,
23 | previous_trials_formatter: Callable,
24 | *args,
25 | **kwargs,
26 | ) -> None:
27 | self.reflection_counter = Count(max_relfection_depth)
28 | self.reflection_fewshots = reflection_fewshots
29 | self.reflection_task_prompt = reflection_task_prompt
30 | self.message_splitter = message_splitter
31 | self.identifier = identifier
32 | self.message_step_splitter = message_step_splitter
33 | self.reflection_prefix = reflection_prefix
34 | self.format_reflections = previous_trials_formatter
35 | self.reflection_prompt_history = []
36 | self.reflections = []
37 | self.previous_trial = []
38 | self.formatted_reflection = None
39 | self.perform_reflection = False
40 | self.increment_task = False
41 | ai_name = 'an advanced reasoning agent that can improve based on self refection'
42 | self.reflection_system_kwargs = dict(instruction=reflection_system_instruction, ai_name=ai_name)
43 | super().__init__(*args, **kwargs)
44 |
45 | def run(self, reset: bool = True, *args, **kwargs) -> None:
46 | if self.perform_reflection and not self.is_success():
47 | self.reflect()
48 | ReactAgent.run(self, reset)
49 | if self.reflection_counter.is_maximum() or self.is_success():
50 | self.increment_task = True
51 |
52 |
53 | def step(self) -> None:
54 | ReactAgent.step(self)
55 | trial = self.prompt_history[self.history_index].content.split(self.remove_task_suffix(self.task), 1)[-1].strip()
56 | steps = self.message_step_splitter(
57 | lines=trial,
58 | cycler=self.message_splitter,
59 | step_identifier=self.identifier)
60 | self.previous_trial.append(HumanMessage(content=steps[-1]))
61 |
62 |
63 | def reflect(self) -> None:
64 | self._format_reflection_scratchpad()
65 | self.reflection_prompt_history.append(HumanMessage(content=self.reflection_prefix))
66 | reflection = self.prompt_reflection()
67 | self.reflections.append(reflection)
68 | self.formatted_reflection = self.format_reflections(self.reflections)
69 | print(self.formatted_reflection)
70 | # wipe the history for a new round
71 | self.previous_trial = []
72 |
73 | def insert_before_task_prompt(self) -> None:
74 | if self.formatted_reflection is not None:
75 | self.prompt_history.append(HumanMessage(content=self.formatted_reflection))
76 |
77 | def prompt_reflection(self) -> str:
78 | self.reflection_prompt_history = self.collapse_prompts(self.reflection_prompt_history)
79 | if self.benchmark_name == 'webshop':
80 | # match the last "Observation:"
81 | pattern = r"\nObservation: (.*[\n]+)+Next plan:.*"
82 | matches = re.findall(pattern, self.reflection_prompt_history[-1].content)
83 | if 'Ran out of steps' in matches[-1]:
84 | add_text = "\nObservation: Ran out of steps! TASK FAILED\n\nNext plan:\n"
85 | elif 'Repeated action' in matches[-1]:
86 | add_text = "\nObservation: Repeated action! TASK FAILED\n\nNext plan:\n"
87 | else:
88 | add_text = "\nObservation: Wrong item! TASK FAILED\n\nNext plan:\n"
89 |
90 | new_history = self.reflection_prompt_history[-1].content.split(matches[-1])
91 | new_history = ''.join(new_history[:-1]) + add_text
92 |
93 | self.reflection_prompt_history[-1].content = new_history
94 |
95 | if self.testing:
96 | print('###################################')
97 | for prompt in self.reflection_prompt_history:
98 | self.print_message(prompt, self.token_counter)
99 | return input()
100 | try:
101 | return self.llm(self.reflection_prompt_history, stop=['\n', '\n\n'])
102 | except InvalidRequestError:
103 | return self.long_context_llm(self.reflection_prompt_history, stop=['\n', '\n\n'])
104 |
105 | def _build_reflection_prompt(self) -> None:
106 | # avoid building reflection prompt if it already exists
107 | if self.reflection_prompt_history != []:
108 | return
109 | system_prompt = self.system_prompt.format_messages(**self.reflection_system_kwargs)
110 | self.reflection_prompt_history.extend(system_prompt)
111 | self._build_fewshot_prompt(
112 | fewshots=self.reflection_fewshots,
113 | prompt_history=self.reflection_prompt_history,
114 | instruction_prompt=self.reflection_task_prompt,
115 | instruction_prompt_kwargs={},
116 | prompt_type='reflect_type',
117 | )
118 | self.reflection_prompt_history.append(HumanMessage(content=f'Previous trial:\n{self.remove_task_suffix(self.task)}'))
119 | self.reflect_interaction_idx = len(self.reflection_prompt_history)
120 | for message in self.previous_trial:
121 | self.reflection_prompt_history.append(message)
122 |
123 | def _format_reflection_scratchpad(self) -> str:
124 | lines = [ref.content for ref in self.reflection_prompt_history[self.reflect_interaction_idx:]]
125 | lines_by_tokens = sorted(lines, key=lambda x: self.token_counter(x))
126 | while self.token_counter(''.join(lines)) > 12000:
127 | ind = lines.index(lines_by_tokens.pop(-1))
128 | line = lines[ind]
129 | lines[ind] = line.split(':')[0] + ': ...'
130 | combined_message = HumanMessage(content='\n'.join(lines))
131 | self.reflection_prompt_history = self.reflection_prompt_history[:self.reflect_interaction_idx]
132 | self.reflection_prompt_history.append(combined_message)
133 |
134 | def reset(self, *args, **kwargs) -> None:
135 | ReactAgent.reset(self, *args, **kwargs)
136 | self.reflection_prompt_history = []
137 | self._build_reflection_prompt()
138 | if self.increment_task:
139 | self.reflections = []
140 | self.reflection_counter.reset()
141 | self.formatted_reflection = None
142 | self.previous_trial = []
143 |
144 | @property
145 | def history_index(self) -> int:
146 | return -1
147 |
148 | def next_task(self) -> bool:
149 | # increment task if reflection counter is at max OR if the agent is successful
150 | if self.increment_task:
151 | self.task_idx += 1
152 | if self.job_not_done():
153 | self.task = self.tasks[self.task_idx]['task']
154 | self.set_env(self.tasks[self.task_idx]['env_kwargs'], max_steps=self.max_steps)
155 | self.perform_reflection = False
156 | # wipe the history for a new task
157 | self.previous_trial = []
158 | # if there are more tasks, perform reflection
159 | if self.job_not_done() and not self.increment_task:
160 | self.perform_reflection = True
161 | self.reflection_counter.increment()
162 | self.reset()
163 | self.env.reset()
164 | if self.increment_task:
165 | self.increment_task = False
166 | return True
167 | return False
168 |
169 | def update_stats(self) -> None:
170 | # only count when finished trying for this task
171 | if self.increment_task:
172 | if not self.is_success() and self.is_truncated():
173 | self.halted += 1
174 | else:
175 | if self.reward:
176 | self.success += 1
177 | else:
178 | self.fail += 1
179 |
--------------------------------------------------------------------------------
/assets/eval_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/ExpeL/e41ec9a24823e7b560c561ab191441b56d9bcefc/assets/eval_example.png
--------------------------------------------------------------------------------
/assets/expel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/ExpeL/e41ec9a24823e7b560c561ab191441b56d9bcefc/assets/expel.png
--------------------------------------------------------------------------------
/assets/header_baby.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/ExpeL/e41ec9a24823e7b560c561ab191441b56d9bcefc/assets/header_baby.png
--------------------------------------------------------------------------------
/configs/agent/expel.yaml:
--------------------------------------------------------------------------------
1 | name: expel
2 | llm: gpt-3.5-turbo # gpt-3.5-turbo-0301
3 | max_reflection_depth: 3
4 | max_num_rules: 20
5 | truncate_strategy: null
6 | fewshot_strategy: task_similarity # task_similarity, step_similarity, reflection_similarity, thought_similarity, action_similarity, rotation, task_thought_similarity
7 | critique_truncate_strategy: random
8 | success_critique_num: 8
9 | retrieval_kwargs:
10 | embedder_path: "all-mpnet-base-v2" # sentence-transformers/all-mpnet-base-v2
11 | embedder_type: huggingface
12 | retriever_type: knn # knn or svm
13 | buffer_retrieve_ratio: 4
14 | reranker: none
15 | max_fewshot_tokens: auto
--------------------------------------------------------------------------------
/configs/benchmark/alfworld.yaml:
--------------------------------------------------------------------------------
1 | name: alfworld
2 | task_prefix: ""
3 | task_file: data/alfworld/alfworld_tasks_suffix.json
4 | max_steps: 20
5 | num_fewshots: 2
6 | eval_configs:
7 | k_folds: 2
8 | k_folds_seed: 42
9 | split: eval_out_of_distribution
10 | ai_name: alfred
11 |
12 | dataset:
13 | data_path: 'data/alfworld/json_2.1.1/train'
14 | eval_id_data_path: 'data/alfworld/json_2.1.1/valid_seen' # null/None to disable
15 | eval_ood_data_path: 'data/alfworld/json_2.1.1/valid_unseen' # null/None to disable
16 | num_train_games: -1 # max training games (<=0 indicates full dataset)
17 | num_eval_games: -1 # max evaluation games (<=0 indicates full dataset)
18 |
19 | logic:
20 | domain: 'data/alfworld/logic/alfred.pddl' # PDDL domain file that defines the world dynamics
21 | grammar: 'data/alfworld/logic/alfred.twl2' # Grammar file that defines the text feedbacks
22 |
23 | env:
24 | type: 'AlfredTWEnv' # 'AlfredTWEnv' or 'AlfredThorEnv' or 'AlfredHybrid'
25 | regen_game_files: False # check if game is solvable by expert and save to game.tw-pddl file
26 | domain_randomization: False # shuffle Textworld print order and object id nums
27 | task_types: [1, 2, 3, 4, 5, 6] # task-type ids: 1 - Pick & Place, 2 - Examine in Light, 3 - Clean & Place, 4 - Heat & Place, 5 - Cool & Place, 6 - Pick Two & Place
28 | expert_timeout_steps: 150 # max steps before timeout for expert to solve the task
29 | expert_type: "handcoded" # 'handcoded' or 'downward'. Note: the downward planner is very slow for real-time use
30 | goal_desc_human_anns_prob: 0.0 # prob of using human-annotated goal language instead of templated goals (1.0 indicates all human annotations from ALFRED)
31 |
32 | hybrid:
33 | start_eps: 100000 # starting episode of hybrid training, tw-only training upto this point
34 | thor_prob: 0.5 # prob of AlfredThorEnv during hybrid training
35 | eval_mode: "tw" # 'tw' or 'thor' - env used for evaluation during hybrid training
36 |
37 | thor:
38 | screen_width: 300 # width of THOR window
39 | screen_height: 300 # height of THOR window
40 | smooth_nav: False # smooth rotations, looks, and translations during navigation (very slow)
41 | save_frames_to_disk: False # save frame PNGs to disk (useful for making videos)
42 | save_frames_path: './videos/' # path to save frame PNGs
43 |
44 | controller:
45 | type: 'oracle' # 'oracle' or 'oracle_astar' or 'mrcnn' or 'mrcnn_astar' (aka BUTLER)
46 | debug: False
47 | load_receps: True # load receptacle locations from precomputed dict (if available)
48 |
49 | mask_rcnn:
50 | pretrained_model_path: 'data/alfworld/detectors/mrcnn.pth'
51 |
52 | general:
53 | random_seed: 42
54 | use_cuda: True # disable this when running on machine without cuda
55 | visdom: False # plot training/eval curves, run with visdom server
56 | task: 'alfred'
57 | training_method: 'dagger' # 'dqn' or 'dagger'
58 | save_path: './training/' # path to save pytorch models
59 | observation_pool_capacity: 3 # k-size queue, 0 indicates no observation
60 | hide_init_receptacles: False # remove initial observation containing navigable receptacles
61 |
62 | training:
63 | batch_size: 10
64 | max_episode: 50000
65 | smoothing_eps: 0.1
66 | optimizer:
67 | learning_rate: 0.001
68 | clip_grad_norm: 5
69 |
70 | evaluate:
71 | run_eval: True
72 | batch_size: 10
73 | env:
74 | type: "AlfredTWEnv"
75 |
76 | checkpoint:
77 | report_frequency: 1000 # report every N episode
78 | experiment_tag: 'test' # name of experiment
79 | load_pretrained: False # during test, enable this so that the agent load your pretrained model
80 | load_from_tag: 'not loading anything' # name of pre-trained model to load in save_path
81 |
82 | model:
83 | encoder_layers: 1
84 | decoder_layers: 1
85 | encoder_conv_num: 5
86 | block_hidden_dim: 64
87 | n_heads: 1
88 | dropout: 0.1
89 | block_dropout: 0.1
90 | recurrent: True
91 |
92 | rl:
93 | action_space: "admissible" # 'admissible' (candidates from text engine) or 'generation' (seq2seq-style generation) or 'beam_search_choice' or 'exhaustive' (not working)
94 | max_target_length: 20 # max token length for seq2seq generation
95 | beam_width: 10 # 1 means greedy
96 | generate_top_k: 3
97 |
98 | training:
99 | max_nb_steps_per_episode: 50 # terminate after this many steps
100 | learn_start_from_this_episode: 0 # delay updates until this epsiode
101 | target_net_update_frequency: 500 # sync target net with online net per this many epochs
102 |
103 | replay:
104 | accumulate_reward_from_final: True
105 | count_reward_lambda: 0.0 # 0 to disable
106 | novel_object_reward_lambda: 0.0 # 0 to disable
107 | discount_gamma_game_reward: 0.9
108 | discount_gamma_count_reward: 0.5
109 | discount_gamma_novel_object_reward: 0.5
110 | replay_memory_capacity: 500000 # adjust this depending on your RAM size
111 | replay_memory_priority_fraction: 0.5
112 | update_per_k_game_steps: 5
113 | replay_batch_size: 64
114 | multi_step: 3
115 | replay_sample_history_length: 4
116 | replay_sample_update_from: 2
117 |
118 | epsilon_greedy:
119 | noisy_net: False # if this is true, then epsilon greedy is disabled
120 | epsilon_anneal_episodes: 1000 # -1 if not annealing
121 | epsilon_anneal_from: 0.3
122 | epsilon_anneal_to: 0.1
123 |
124 | dagger:
125 | action_space: "generation" # 'admissible' (candidates from text engine) or 'generation' (seq2seq-style generation) or 'exhaustive' (not working)
126 | max_target_length: 20 # max token length for seq2seq generation
127 | beam_width: 10 # 1 means greedy
128 | generate_top_k: 5
129 | unstick_by_beam_search: False # use beam-search for failed actions, set True during evaluation
130 |
131 | training:
132 | max_nb_steps_per_episode: 50 # terminate after this many steps
133 |
134 | fraction_assist:
135 | fraction_assist_anneal_episodes: 50000
136 | fraction_assist_anneal_from: 1.0
137 | fraction_assist_anneal_to: 0.01
138 |
139 | fraction_random:
140 | fraction_random_anneal_episodes: 0
141 | fraction_random_anneal_from: 0.0
142 | fraction_random_anneal_to: 0.0
143 |
144 | replay:
145 | replay_memory_capacity: 500000
146 | update_per_k_game_steps: 5
147 | replay_batch_size: 64
148 | replay_sample_history_length: 4
149 | replay_sample_update_from: 2
150 |
151 | vision_dagger:
152 | model_type: "resnet" # 'resnet' (whole image features) or 'maskrcnn_whole' (whole image MaskRCNN feats) or 'maskrcnn' (top k MaskRCNN detection feats) or 'no_vision' (zero vision input)
153 | resnet_fc_dim: 64
154 | maskrcnn_top_k_boxes: 10 # top k box features
155 | use_exploration_frame_feats: False # append feats from initial exploration (memory intensive!)
156 | sequence_aggregation_method: "average" # 'sum' or 'average' or 'rnn'
--------------------------------------------------------------------------------
/configs/benchmark/fever.yaml:
--------------------------------------------------------------------------------
1 | name: fever
2 | task_prefix: "Claim: "
3 | task_file: null
4 | max_steps: 7
5 | ai_name: fact verifier
6 | num_fewshots: 5
7 | eval_configs:
8 | k_folds: 2
9 | k_folds_seed: 42
10 |
--------------------------------------------------------------------------------
/configs/benchmark/hotpotqa.yaml:
--------------------------------------------------------------------------------
1 | name: hotpotqa
2 | task_prefix: "Question: "
3 | task_file: data/hotpotqa/hotpot-qa-distractor-sample.joblib
4 | max_steps: 7
5 | num_fewshots: 6
6 | eval_configs:
7 | k_folds: 2
8 | k_folds_seed: 42
9 | ai_name: QA system
--------------------------------------------------------------------------------
/configs/benchmark/webshop.yaml:
--------------------------------------------------------------------------------
1 | name: webshop
2 | task_prefix: "Instruction:\n"
3 | task_file: data/webshop/webshop.fixed100.json
4 | max_steps: 15
5 | num_fewshots: 2
6 | eval_configs:
7 | k_folds: 2
8 | k_folds_seed: 42
9 | ai_name: webshop
10 | num_tasks: 100
--------------------------------------------------------------------------------
/configs/eval.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - benchmark: hotpotqa
4 | - agent: expel
5 | - override hydra/hydra_logging: disabled
6 | - override hydra/job_logging: disabled
7 |
8 | ai_name: ${benchmark.ai_name}
9 | agent_type: ${agent.name}
10 | log_dir: logs
11 | testing: true
12 | resume: false
13 | load_cache_rules: true
14 | no_rules: false
15 | run_name: run
16 | load_run_name: ${run_name}
17 |
18 | # remove hydra output files
19 | hydra:
20 | output_subdir: null
21 | run:
22 | dir: .
--------------------------------------------------------------------------------
/configs/insight_extraction.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - benchmark: hotpotqa
4 | - agent: expel
5 | - override hydra/hydra_logging: disabled
6 | - override hydra/job_logging: disabled
7 |
8 | ai_name: ${benchmark.ai_name}
9 | agent_type: ${agent.name}
10 | log_dir: logs
11 | run_name: run
12 | load_run_name: ${run_name}
13 | testing: true
14 | folded: true
15 | resume: false
16 | resume_fold: -1
17 | seed: 42
18 |
19 | # remove hydra output files
20 | hydra:
21 | output_subdir: null
22 | run:
23 | dir: .
--------------------------------------------------------------------------------
/configs/train.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - benchmark: hotpotqa
4 | - agent: expel
5 | - override hydra/hydra_logging: disabled
6 | - override hydra/job_logging: disabled
7 |
8 | ai_name: ${benchmark.ai_name}
9 | agent_type: ${agent.name}
10 | log_dir: logs
11 | run_name: run
12 | testing: true
13 | resume: false
14 |
15 | # remove hydra output files
16 | hydra:
17 | output_subdir: null
18 | run:
19 | dir: .
--------------------------------------------------------------------------------
/configs/visualize_logs.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - benchmark: hotpotqa
4 | - agent: expel
5 | - override hydra/hydra_logging: disabled
6 | - override hydra/job_logging: disabled
7 |
8 | hydra:
9 | output_subdir: null
10 | run:
11 | dir: .
12 |
13 | log_path: logs/${benchmark.name}/${agent.name}
14 | run_name: run
15 | num_tasks: 100
16 | save_fig: false
17 |
--------------------------------------------------------------------------------
/data/hotpotqa/hotpot-qa-distractor-sample.joblib:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/ExpeL/e41ec9a24823e7b560c561ab191441b56d9bcefc/data/hotpotqa/hotpot-qa-distractor-sample.joblib
--------------------------------------------------------------------------------
/envs/__init__.py:
--------------------------------------------------------------------------------
1 | import joblib
2 | import json
3 | import random
4 | import json
5 |
6 | from .base import BaseEnv
7 | from .hotpotqa.hotpotqa import QAEnv
8 | from .fever.fever import FeverEnv
9 | from .alfworld.alfworld import AlfworldEnv
10 | from .webshop.webshop import WebshopEnv
11 | from utils import get_env_name_from_gamefile
12 |
13 | # Taken from ReAct Github
14 | idxs = list(range(7405))
15 | random.Random(233).shuffle(idxs)
16 |
17 | INIT_TASKS_FN = dict(
18 | hotpotqa=lambda cfg: [
19 | {
20 | 'task': f'{cfg.benchmark.task_prefix}{row["question"]}',
21 | 'env_kwargs': {
22 | 'question': row['question'],
23 | 'key': row['answer'],
24 | },
25 | 'env_name': 'hotpotqa',
26 | } for _, row in joblib.load(cfg.benchmark.task_file).reset_index(drop=True).iterrows()],
27 | # 100 tasks for fever
28 | fever=lambda cfg: [{
29 | 'task': cfg.benchmark.task_prefix + FeverEnv(idx).reset().replace('Claim: ', ''),
30 | 'env_kwargs': {
31 | 'idx': idx,
32 | },
33 | 'env_name': 'fever',
34 | } for idx in idxs[:100]],
35 | alfworld=lambda cfg: [
36 | {
37 | 'task': f'{cfg.benchmark.task_prefix}{row["goal"]}',
38 | 'env_kwargs': {
39 | 'config': cfg.benchmark,
40 | "gamefile": row["gamefile"],
41 | },
42 | 'env_name': get_env_name_from_gamefile(row['gamefile'])
43 | } for row in json.load(open(cfg.benchmark.task_file, "r"))
44 | ],
45 | webshop=lambda cfg: [
46 | {
47 | 'task': f'{cfg.benchmark.task_prefix}{row["task"]}',
48 | 'env_kwargs': {
49 | 'session_idx': row["session_idx"],
50 | },
51 | 'env_name': 'webshop'
52 | } for row in json.load(open(cfg.benchmark.task_file, "r"))
53 | ],
54 | )
55 |
56 | ENVS = dict(hotpotqa=QAEnv, fever=FeverEnv, alfworld=AlfworldEnv, webshop=WebshopEnv)
57 |
--------------------------------------------------------------------------------
/envs/alfworld/alfworld.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import List, Dict, Any, Tuple
3 | from envs.base import BaseEnv
4 |
5 | import alfworld.agents.environment
6 | from utils import get_env_name_from_gamefile
7 |
8 | class AlfworldEnv(BaseEnv):
9 | def __init__(self,
10 | gamefile: str,
11 | config: Dict[str, Any],
12 | max_steps: int = 50,
13 | ):
14 | self.max_steps = max_steps
15 | self.gamefile = gamefile
16 | self.config = config
17 | self.main_env = getattr(alfworld.agents.environment, self.config.env.type)(self.config, train_eval=self.config.split)
18 | self.main_env.game_files = [self.gamefile]
19 | self.task = "housekeeper robot. The agent was placed in a household environment and a task to complete."
20 | self.env_name = get_env_name_from_gamefile(gamefile)
21 |
22 | self.reset()
23 |
24 | def reset(self):
25 | self.curr_step = 1
26 | self.answer = ''
27 | self.terminated = False
28 | self.reward = False
29 | self.is_exhausted = False
30 | self.env = self.main_env.init_env(batch_size=1)
31 | self.env.reset()
32 | self.last_action = None
33 |
34 | def step(self, action: str) -> Tuple[str, bool, bool, bool, int]:
35 | if action.startswith('put'):
36 | pattern = r'put (\w+\s*\d+) (?:in|on) (\w+\s*\d+)'
37 | match = re.match(pattern, action)
38 | if match is not None:
39 | action = 'put ' + match.group(1) + ' in/on ' + match.group(2)
40 |
41 | observation, reward, _ = self.alfworld_run(action)
42 | observation = observation.replace(' In it, you see nothing.', '').replace(', you see nothing', '')
43 | if self.last_action == action:
44 | self.truncated = True
45 | self.terminated = True
46 |
47 | self.last_action = action
48 |
49 | if reward:
50 | observation = 'Task is SOLVED.'
51 | self.terminated = True
52 | else:
53 | if self.is_truncated():
54 | observation = 'Max steps reached.'
55 | pass
56 |
57 | self.curr_step += 1
58 | self.terminated = self.is_terminated()
59 | self.truncated = self.is_truncated()
60 | self.reward = reward
61 |
62 | return observation, self.reward, self.terminated, self.truncated, self.curr_step
63 |
64 | def success_fn(self) -> bool:
65 | return self.reward
66 |
67 | def alfworld_run(self, action):
68 | observation, reward, done, info = self.env.step([action])
69 | observation, reward, done = process_observation(observation[0]), info['won'][0], done[0]
70 |
71 | return observation, reward, done
72 |
73 | def process_observation(obs):
74 | if obs.startswith('You arrive at loc '):
75 | obs = obs[obs.find('. ')+2:]
76 | return obs
--------------------------------------------------------------------------------
/envs/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Dict, Any
3 | import gym
4 |
5 | class BaseEnv(gym.Env):
6 | @abstractmethod
7 | def reset(self):
8 | pass
9 |
10 | @abstractmethod
11 | def step(self, action: str, *args, **kwargs) -> Dict[str, Any]:
12 | pass
13 |
14 | @abstractmethod
15 | def success_fn(self) -> bool:
16 | pass
17 |
18 | def is_terminated(self) -> bool:
19 | return self.terminated
20 |
21 | def is_truncated(self) -> bool:
22 | return self.curr_step > self.max_steps
23 |
--------------------------------------------------------------------------------
/envs/fever/fever.py:
--------------------------------------------------------------------------------
1 | import requests
2 |
3 | from .wikienv import WikiEnv
4 | from .wrappers import FeverWrapper
5 | from envs.base import BaseEnv
6 | from utils import parse_action
7 |
8 |
9 | class FeverEnv(BaseEnv):
10 | def __init__(self, idx, max_steps: int = 6) -> None:
11 | self.env = FeverWrapper(WikiEnv(), split="dev")
12 | self.task = """fact extract and verification. The agent was given access to a Docstore API environment and a fact to verify. The agent can search for pages related to the fact, lookup keywords in the pages, and finish with an answer."""
13 | self.max_steps = max_steps
14 | self.idx = idx
15 | self.env_name = 'fever'
16 |
17 | self.reset()
18 |
19 | def reset(self):
20 | self.curr_step = 1
21 | self.question = self.env.reset(idx=self.idx).replace('Claim: ', '')
22 | self.key = self.env.data[self.idx][1]
23 | self.terminated = False
24 | return self.question
25 |
26 | def step(self, action):
27 | action_type, argument = parse_action(action)
28 | if action_type == 'Finish':
29 | self.terminated = True
30 |
31 | attempts = 0
32 | while attempts < 10:
33 | try:
34 | obs, self.reward, done, info = self.env.step(action[0].lower() + action[1:])
35 | self.curr_step += 1
36 | terminated = self.is_terminated()
37 | truncated = self.is_truncated()
38 | return obs, self.reward, terminated, truncated, self.curr_step
39 | except requests.exceptions.Timeout:
40 | attempts += 1
41 |
42 | def success_fn(self) -> bool:
43 | return self.reward == 1
44 |
--------------------------------------------------------------------------------
/envs/fever/wikienv.py:
--------------------------------------------------------------------------------
1 | """
2 | Taken from ReAct Github
3 | """
4 |
5 | import ast
6 | import json
7 | import time
8 | import gym
9 | import requests
10 | from bs4 import BeautifulSoup
11 |
12 | # import wikipedia
13 |
14 | def clean_str(p):
15 | return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8")
16 |
17 |
18 | class textSpace(gym.spaces.Space):
19 | def contains(self, x) -> bool:
20 | """Return boolean specifying if x is a valid member of this space."""
21 | return isinstance(x, str)
22 |
23 |
24 | class WikiEnv(gym.Env):
25 |
26 | def __init__(self):
27 | """
28 | Initialize the environment.
29 | """
30 | super().__init__()
31 | self.page = None # current Wikipedia page
32 | self.obs = None # current observation
33 | self.lookup_keyword = None # current lookup keyword
34 | self.lookup_list = None # list of paragraphs containing current lookup keyword
35 | self.lookup_cnt = None # current lookup index
36 | self.steps = 0 # current number of steps
37 | self.answer = None # current answer from the agent
38 | self.observation_space = self.action_space = textSpace()
39 | self.search_time = 0
40 | self.num_searches = 0
41 |
42 | def _get_obs(self):
43 | return self.obs
44 |
45 | def _get_info(self):
46 | return {"steps": self.steps, "answer": self.answer}
47 |
48 | def reset(self, seed=None, return_info=False, options=None):
49 | # We need the following line to seed self.np_random
50 | # super().reset(seed=seed)
51 | self.obs = ("Interact with Wikipedia using search[], lookup[], and "
52 | "finish[].\n")
53 | self.page = None
54 | self.lookup_keyword = None
55 | self.lookup_list = None
56 | self.lookup_cnt = None
57 | self.steps = 0
58 | self.answer = None
59 | observation = self._get_obs()
60 | info = self._get_info()
61 | return (observation, info) if return_info else observation
62 |
63 | def construct_lookup_list(self, keyword):
64 | # find all paragraphs
65 | if self.page is None:
66 | return []
67 | paragraphs = self.page.split("\n")
68 | paragraphs = [p.strip() for p in paragraphs if p.strip()]
69 |
70 | # find all sentence
71 | sentences = []
72 | for p in paragraphs:
73 | sentences += p.split('. ')
74 | sentences = [s.strip() + '.' for s in sentences if s.strip()]
75 |
76 | parts = sentences
77 | parts = [p for p in parts if keyword.lower() in p.lower()]
78 | return parts
79 |
80 | @staticmethod
81 | def get_page_obs(page):
82 | # find all paragraphs
83 | paragraphs = page.split("\n")
84 | paragraphs = [p.strip() for p in paragraphs if p.strip()]
85 |
86 | # find all sentence
87 | sentences = []
88 | for p in paragraphs:
89 | sentences += p.split('. ')
90 | sentences = [s.strip() + '.' for s in sentences if s.strip()]
91 | return ' '.join(sentences[:5])
92 |
93 | # ps = page.split("\n")
94 | # ret = ps[0]
95 | # for i in range(1, len(ps)):
96 | # if len((ret + ps[i]).split(" ")) <= 50:
97 | # ret += ps[i]
98 | # else:
99 | # break
100 | # return ret
101 |
102 | def search_step(self, entity):
103 | entity_ = entity.replace(" ", "+")
104 | search_url = f"https://en.wikipedia.org/w/index.php?search={entity_}"
105 | old_time = time.time()
106 | while True:
107 | try:
108 | response_text = requests.get(search_url).text
109 | break
110 | except:
111 | pass
112 | self.search_time += time.time() - old_time
113 | self.num_searches += 1
114 | soup = BeautifulSoup(response_text, features="html.parser")
115 | result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})
116 | if result_divs: # mismatch
117 | self.result_titles = [clean_str(div.get_text().strip()) for div in result_divs]
118 | self.obs = f"Could not find {entity}. Similar: {self.result_titles[:5]}."
119 | else:
120 | page = [p.get_text().strip() for p in soup.find_all("p") + soup.find_all("ul")]
121 | if any("may refer to:" in p for p in page):
122 | self.search_step("[" + entity + "]")
123 | else:
124 | self.page = ""
125 | for p in page:
126 | if len(p.split(" ")) > 2:
127 | self.page += clean_str(p)
128 | if not p.endswith("\n"):
129 | self.page += "\n"
130 | self.obs = self.get_page_obs(self.page)
131 | self.lookup_keyword = self.lookup_list = self.lookup_cnt = None
132 |
133 | def step(self, action):
134 | reward = 0
135 | done = False
136 | action = action.strip()
137 | if self.answer is not None: # already finished
138 | done = True
139 | return self.obs, reward, done, self._get_info()
140 |
141 | if action.startswith("search[") and action.endswith("]"):
142 | entity = action[len("search["):-1]
143 | # entity_ = entity.replace(" ", "_")
144 | # search_url = f"https://en.wikipedia.org/wiki/{entity_}"
145 | self.search_step(entity)
146 | elif action.startswith("lookup[") and action.endswith("]"):
147 | keyword = action[len("lookup["):-1]
148 | if self.lookup_keyword != keyword: # reset lookup
149 | self.lookup_keyword = keyword
150 | self.lookup_list = self.construct_lookup_list(keyword)
151 | self.lookup_cnt = 0
152 | if self.lookup_cnt >= len(self.lookup_list):
153 | self.obs = "No more results.\n"
154 | else:
155 | self.obs = f"(Result {self.lookup_cnt + 1} / {len(self.lookup_list)}) " + self.lookup_list[self.lookup_cnt]
156 | self.lookup_cnt += 1
157 | elif action.startswith("finish[") and action.endswith("]"):
158 | answer = action[len("finish["):-1]
159 | self.answer = answer
160 | done = True
161 | self.obs = f"Episode finished, reward = {reward}\n"
162 | elif action.startswith("think[") and action.endswith("]"):
163 | self.obs = "Nice thought."
164 | else:
165 | self.obs = "Invalid action: {}".format(action)
166 |
167 | self.steps += 1
168 |
169 | return self.obs, reward, done, self._get_info()
170 |
171 | def get_time_info(self):
172 | speed = self.search_time / self.num_searches if self.num_searches else 0
173 | return {
174 | "call_speed": speed,
175 | "call_time": self.search_time,
176 | "num_calls": self.num_searches,
177 | }
--------------------------------------------------------------------------------
/envs/fever/wrappers.py:
--------------------------------------------------------------------------------
1 | import json
2 | import gym
3 | import numpy as np
4 |
5 | from utils import normalize_answer
6 |
7 | DATA_DIR = "data/fever"
8 | HOTPOTQA_SPLIT_FILE = {
9 | "train": "hotpot_train_v1.1_simplified.json",
10 | "dev": "hotpot_dev_v1_simplified.json",
11 | "test": "hotpot_test_v1_simplified.json",
12 | }
13 |
14 | FEVER_SPLIT_FILE = {
15 | "train": "train.jsonl",
16 | "dev": "paper_dev.jsonl",
17 | }
18 |
19 | class FeverWrapper(gym.Wrapper):
20 | def __init__(self, env, split):
21 | super().__init__(env)
22 |
23 | data_path = f"{DATA_DIR}/{FEVER_SPLIT_FILE[split]}"
24 | with open(data_path, "r") as json_file:
25 | json_list = list(json_file)
26 |
27 | data = []
28 | for json_str in json_list:
29 | json_str = json.loads(json_str)
30 | label = json_str["label"]
31 | claim = json_str["claim"]
32 | data.append((claim, label))
33 |
34 | self.data = data
35 | self.data_idx = 0
36 | self.split = split
37 |
38 | def reset(self, seed=None, return_info=False, options=None, idx=None):
39 | self.env.reset(seed=seed, return_info=return_info, options=options)
40 | try:
41 | self.env.step('')
42 | except:
43 | pass
44 | self.env.reset(seed=seed, return_info=return_info, options=options)
45 | self.data_idx = int(np.random.randint(len(self.data))) if idx is None else idx
46 | observation = f"Claim: {self.data[self.data_idx][0]}"
47 | info = self._get_info()
48 | return (observation, info) if return_info else observation
49 |
50 | def _get_info(self):
51 | return {
52 | "steps": self.steps,
53 | "answer": self.answer,
54 | "question": self.data[self.data_idx][0],
55 | "fever_split": self.split
56 | }
57 |
58 | def get_reward(self, info):
59 | if info['answer'] is not None:
60 | label = normalize_answer(self.data[self.data_idx][1])
61 | pred = normalize_answer(info['answer'])
62 | if label == pred:
63 | return 1
64 | return 0
65 |
66 | def step(self, action):
67 | obs, _, done, info = self.env.step(action)
68 | reward = self.get_reward(info)
69 | if done:
70 | obs = f"Episode finished, reward = {reward}\n"
71 | info.update({"gt_answer": self.data[self.data_idx][1], "question_idx": self.data_idx})
72 | info.update({'em': reward, 'reward': reward, 'f1': reward})
73 | return obs, reward, done, info
74 |
75 | def __len__(self):
76 | return len(self.data)
77 |
--------------------------------------------------------------------------------
/envs/hotpotqa/hotpotqa.py:
--------------------------------------------------------------------------------
1 | import re
2 | import string
3 | from typing import Tuple
4 | import time
5 |
6 | from langchain import Wikipedia
7 | from langchain.agents.react.base import DocstoreExplorer
8 |
9 | from envs.base import BaseEnv
10 | from utils import parse_action, EM
11 |
12 | class QAEnv(BaseEnv):
13 | def __init__(self,
14 | question: str,
15 | key: str,
16 | max_steps: int = 6,
17 | explorer: DocstoreExplorer = DocstoreExplorer(Wikipedia())):
18 |
19 | self.question = question
20 | self.key = key
21 | self.max_steps = max_steps
22 | self.explorer = explorer
23 | self.task = """multi-hop QA. The agent was given access to a Docstore API environment and a question to answer. The agent can search for pages related to the question, lookup keywords in the pages, and finish with an answer."""
24 | self.env_name = 'hotpotqa'
25 |
26 | self.reset()
27 |
28 | def reset(self):
29 | self.curr_step = 1
30 | self.answer = ''
31 | self.terminated = False
32 |
33 | def step(self, action: str) -> Tuple[str, bool, bool, bool, bool]:
34 | action_type, argument = parse_action(action)
35 |
36 | if action_type == 'Finish':
37 | self.answer = argument
38 | if self.success_fn():
39 | observation = 'Answer is CORRECT'
40 | else:
41 | observation = f'Answer is INCORRECT'
42 | self.terminated = True
43 | elif action_type == 'Search':
44 | while True:
45 | try:
46 | observation = self.explorer.search(argument).strip('\n').strip()
47 | break
48 | except Exception as e:
49 | print(e)
50 | time.sleep(5)
51 | elif action_type == 'Lookup':
52 | try:
53 | observation = self.explorer.lookup(argument).strip('\n').strip()
54 | except ValueError:
55 | observation = f'The last page Searched was not found, so you cannot Lookup a keyword in it. Please try one of the similar pages given.'
56 | else:
57 | observation = 'Invalid Action. Valid Actions are Lookup[] Search[] and Finish[].'
58 |
59 | self.curr_step += 1
60 | self.reward = self.success_fn()
61 | self.terminated = self.is_terminated()
62 | self.truncated = self.is_truncated()
63 |
64 | return observation, self.reward, self.terminated, self.truncated, self.curr_step
65 |
66 | def success_fn(self) -> bool:
67 | return EM(self.answer, self.key)
68 |
--------------------------------------------------------------------------------
/envs/webshop/webshop.py:
--------------------------------------------------------------------------------
1 | import math
2 | import requests
3 | from bs4 import BeautifulSoup
4 | from bs4.element import Comment
5 | import re
6 | import numpy as np
7 |
8 | from typing import Any, Dict, List, Tuple
9 | from envs.base import BaseEnv
10 |
11 |
12 | # Type in the URL of the webshop server:
13 | # If local:
14 | WEBSHOP_URL = "http://127.0.0.1:3000"
15 |
16 | ACTION_TO_TEMPLATE = {
17 | 'Description': 'description_page.html',
18 | 'Features': 'features_page.html',
19 | 'Reviews': 'review_page.html',
20 | 'Attributes': 'attributes_page.html',
21 | }
22 |
23 | def clean_str(p):
24 | return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8")
25 |
26 | def tag_visible(element):
27 | ignore = {'style', 'script', 'head', 'title', 'meta', '[document]'}
28 | return (
29 | element.parent.name not in ignore and not isinstance(element, Comment)
30 | )
31 |
32 | def webshop_text(session, page_type, query_string='', page_num=1, asin='', options={}, subpage='', **kwargs):
33 | if page_type == 'init':
34 | url = (
35 | f'{WEBSHOP_URL}/{session}'
36 | )
37 | if page_type == 'search':
38 | url = (
39 | f'{WEBSHOP_URL}/search_results/{session}/'
40 | f'{query_string}/{page_num}'
41 | )
42 | elif page_type == 'item':
43 | url = (
44 | f'{WEBSHOP_URL}/item_page/{session}/'
45 | f'{asin}/{query_string}/{page_num}/{str(options).replace("#","%23")}' # FIXING '#' in url ISSUE
46 | )
47 | elif page_type == 'item_sub':
48 | url = (
49 | f'{WEBSHOP_URL}/item_sub_page/{session}/'
50 | f'{asin}/{query_string}/{page_num}/{subpage}/{str(options).replace("#","%23")}' # FIXING '#' in url ISSUE
51 | )
52 | elif page_type == 'end':
53 | url = (
54 | f'{WEBSHOP_URL}/done/{session}/'
55 | f'{asin}/{str(options).replace("#","%23")}'
56 | )
57 | html = requests.get(url).text # type: ignore
58 | html_obj = BeautifulSoup(html, 'html.parser')
59 | texts = html_obj.findAll(text=True)
60 | visible_texts = list(filter(tag_visible, texts))
61 | if False:
62 | # For `simple` mode, return just [SEP] separators
63 | return ' [SEP] '.join(t.strip() for t in visible_texts if t != '\n')
64 | else:
65 | # Otherwise, return an observation with tags mapped to specific, unique separators
66 | observation = ''
67 | option_type = ''
68 | options = {}
69 | asins = []
70 | cnt = 0
71 | prod_cnt = 0
72 | just_prod = 0
73 | for t in visible_texts:
74 | if t == '\n': continue
75 | if t.replace('\n', '').replace('\\n', '').replace(' ', '') == '': continue
76 | if t.parent.name == 'button': # button
77 | processed_t = f'\n[{t}] '
78 | elif t.parent.name == 'label': # options
79 | if f"'{t}'" in url: # type: ignore
80 | processed_t = f'[[{t}]]'
81 | else:
82 | processed_t = f'[{t}]'
83 | options[str(t)] = option_type
84 | elif t.parent.get('class') == ["product-link"]: # product asins
85 | processed_t = f'\n[{t}] '
86 | prod_cnt += 1
87 | asins.append(str(t))
88 | just_prod = 0
89 | else: # regular, unclickable text
90 | processed_t = '\n' + str(t) + ' '
91 | option_type = str(t)
92 | cnt += 1
93 | just_prod += 1
94 | observation += processed_t
95 | info = {}
96 | if options:
97 | info['option_types'] = options
98 | if asins:
99 | info['asins'] = asins
100 | if 'Your score (min 0.0, max 1.0)' in visible_texts:
101 | idx = visible_texts.index('Your score (min 0.0, max 1.0)')
102 | info['reward'] = round(float(visible_texts[idx + 1]), 2)
103 | observation = 'Your score (min 0.0, max 1.0): ' + (visible_texts[idx + 1])
104 |
105 | # Prompt harmonization
106 | observation = observation.replace('\nWebShop ', '')
107 | observation = observation.replace('\nInstruction: ', '')
108 | observation = observation.replace('[Search]\n', '[Search]')
109 |
110 | return clean_str(observation), info
111 |
112 | class WebshopEnv(BaseEnv):
113 | def __init__(self, session_idx: str, max_steps: int = 15):
114 | self.session_idx = session_idx
115 | self.max_steps = max_steps
116 | self.env_name = 'webshop'
117 | self.task = "instruction-following shopping task. The agent interacted with an online store website in order to buy the requested item based on an instruction."
118 |
119 | self.reset()
120 |
121 |
122 | def reset(self):
123 | self.curr_step = 1
124 | self.reward = False
125 | self.terminated = False
126 | self.last_action = None
127 |
128 | self.session = {'session': self.session_idx, 'page_type': 'init'}
129 | observation, info = webshop_text(**self.session)
130 | self.session.update(info)
131 |
132 | def success_fn(self) -> bool:
133 | return self.reward
134 |
135 | def step(self, action: str) -> Tuple[str, bool, bool, bool, int]:
136 | done = False
137 | observation_ = None
138 |
139 | try:
140 | if action == 'reset':
141 | self.session = {'session': self.session_idx, 'page_type': 'init'}
142 | elif action.startswith('think['):
143 | observation = 'OK.'
144 | elif action.startswith('search['):
145 | assert self.session['page_type'] == 'init'
146 | query = action[7:-1]
147 | self.session = {'session': self.session_idx, 'page_type': 'search',
148 | 'query_string': query, 'page_num': 1, 'fake_page_num' : 1}
149 | elif action.startswith('click['):
150 | button = action[6:-1]
151 | if button == 'Buy Now':
152 | assert self.session['page_type'] == 'item'
153 | self.session['page_type'] = 'end'
154 | self.terminated = True
155 | elif button == 'Back to Search':
156 | assert self.session['page_type'] in ['search', 'item_sub', 'item']
157 | self.session = {'session': self.session_idx, 'page_type': 'init'}
158 | elif button == 'Next >':
159 | # assert False # ad hoc page limitation
160 | assert self.session['page_type'] == 'search'
161 | assert self.session['page_num'] < math.ceil(self.session['max_products'] / 10) # Seemed already capped at 50
162 | self.session['page_type'] = 'search'
163 | self.session['page_num'] += 1
164 |
165 | elif button == '< Prev':
166 | assert self.session['page_type'] in ['search', 'item_sub', 'item']
167 | if self.session['page_type'] == 'search':
168 | assert self.session['page_num'] > 1
169 | self.session['page_num'] -= 1
170 | self.session['page_type'] = 'search'
171 | elif self.session['page_type'] == 'item_sub':
172 | self.session['page_type'] = 'item'
173 | elif self.session['page_type'] == 'item':
174 | self.session['page_type'] = 'search'
175 | self.session['options'] = {}
176 | elif button in ACTION_TO_TEMPLATE:
177 | assert self.session['page_type'] == 'item'
178 | self.session['page_type'] = 'item_sub'
179 | self.session['subpage'] = button
180 | else:
181 | if self.session['page_type'] == 'search':
182 | assert button in self.session.get('asins', []) # must be asins
183 | self.session['page_type'] = 'item'
184 | self.session['asin'] = button
185 | elif self.session['page_type'] == 'item':
186 | assert 'option_types' in self.session
187 | assert button in self.session['option_types'], (button, self.session['option_types']) # must be options
188 | option_type = self.session['option_types'][button]
189 | if not 'options' in self.session:
190 | self.session['options'] = {}
191 | self.session['options'][option_type] = button
192 | observation_ = f'You have clicked {button}.'
193 | else:
194 | assert False
195 | except AssertionError:
196 | observation_ = 'Invalid action!'
197 | if invalid_repeat(action=action, last_action=self.last_action):
198 | self.truncated = True
199 | self.terminated = True
200 | observation_ = 'Repeated action!'
201 |
202 | observation, info = webshop_text(**self.session)
203 | # update the max number of products of a query when we search
204 | if observation_ not in ['Invalid action!','Repeated action!'] and action.startswith('search['):
205 | pattern = r'\(Total results: (\d+)\)'
206 | max_products = int(re.findall(pattern, observation)[0])
207 | self.session.update({'max_products': max_products})
208 | if observation_:
209 | observation = observation_
210 |
211 | # OK from react code for observation:
212 | if action.startswith('think['):
213 | observation = 'OK.'
214 |
215 | self.session.update(info)
216 | reward = info.get('reward', 0.0)
217 |
218 | self.curr_step += 1
219 |
220 | if self.is_truncated() and not self.is_terminated():
221 | observation += ('\n\n' if observation != '' else '') + 'Ran out of steps! TASK FAILED'
222 |
223 | self.reward = (reward==1)
224 | self.last_action = action
225 |
226 | return observation, self.reward, self.is_terminated(), self.is_truncated(), self.curr_step
227 |
228 | def invalid_repeat(action: str, last_action: str) -> bool:
229 | if last_action is None:
230 | return False
231 | not_start_list = ['search[', 'think[i apologize', 'think[end', 'think[i\'m sorry', 'think[apolog']
232 | for word in not_start_list:
233 | if action.lower().startswith(word) and last_action.lower().startswith(word):
234 | return True
235 | return False
236 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import getpass
3 | import hydra
4 | from omegaconf import DictConfig
5 | from pathlib import Path
6 | import os
7 | from copy import deepcopy
8 | from functools import partial
9 | import dotenv
10 | dotenv.load_dotenv()
11 |
12 | from agent import AGENT
13 | from prompts.templates.system import system_message_prompt
14 | from prompts.templates.human import HUMAN_CRITIQUES
15 | from prompts import (
16 | SYSTEM_INSTRUCTION,
17 | HUMAN_INSTRUCTION,
18 | FEWSHOTS,
19 | REFLECTION_FEWSHOTS,
20 | HUMAN_REFLECTION_INSTRUCTION,
21 | SYSTEM_REFLECTION_INSTRUCTION,
22 | SYSTEM_CRITIQUE_INSTRUCTION,
23 | RULE_TEMPLATE,
24 | LLM_PARSER,
25 | OBSERVATION_FORMATTER,
26 | STEP_IDENTIFIER,
27 | CYCLER,
28 | STEP_CYCLER,
29 | REFLECTION_PREFIX,
30 | PREVIOUS_TRIALS_FORMATTER,
31 | STEP_STRIPPER,
32 | CRITIQUE_SUMMARY_SUFFIX,
33 | )
34 | from envs import ENVS, INIT_TASKS_FN
35 | from memory import (
36 | EMBEDDERS,
37 | RETRIEVERS,
38 | )
39 | from models import LLM_CLS
40 | from utils import get_fewshot_max_tokens, load_trajectories_log, save_trajectories_log, split_logs_by_task, plot_trial_stats, alfworld_results_per_env_name_log, get_webshop_mean_score, get_split_eval_idx_list
41 |
42 |
43 | def get_eval_num(eval_idx: int, eval_idx_list: List[List[int]]) -> int:
44 | eval_num = 0
45 | for eval_idxs in eval_idx_list:
46 | if eval_idx in eval_idxs:
47 | break
48 | eval_num += len(eval_idxs)
49 | return eval_num + eval_idxs.index(eval_idx)
50 |
51 |
52 | @hydra.main(version_base=None, config_path="configs", config_name="eval")
53 | def main(cfg : DictConfig) -> None:
54 | if cfg.testing:
55 | openai_api_key = 'NO_KEY_FOR_TESTING'
56 | else:
57 | openai_api_key = os.environ['OPENAI_API_KEY'] if 'OPENAI_API_KEY' in os.environ else getpass.getpass("Enter or paste your OpenAI API Key: ")
58 | LOG_PATH = Path('/'.join([cfg.log_dir, cfg.benchmark.name, cfg.agent_type]))
59 | SAVE_PATH = LOG_PATH / 'eval'
60 | SAVE_PATH.mkdir(exist_ok=True)
61 |
62 | print(f"{SAVE_PATH}/{cfg.run_name}.pkl")
63 |
64 | # Overwriting confirmation
65 | if not cfg.resume and os.path.exists(f"{SAVE_PATH}/{cfg.run_name}.pkl") and cfg.run_name != 'test':
66 | while True:
67 | res = input(f"Are you sure to overwrite '{cfg.run_name}'? (Y/N)\n").lower()
68 | if res == 'n':
69 | exit(0)
70 | elif res == 'y':
71 | break
72 |
73 | # Load trajectory checkpoint
74 | out = load_trajectories_log(
75 | SAVE_PATH if cfg.resume else LOG_PATH,
76 | run_name=cfg.load_run_name,
77 | load_log=cfg.resume,
78 | load_true_log=cfg.resume)
79 | dicts = out['dicts']
80 | eval_idx_list = dicts[-1].get(
81 | 'eval_idx_list',
82 | get_split_eval_idx_list(dicts[-1], cfg.benchmark.eval_configs.k_folds))
83 | log = out['log'] if cfg.resume else f'### EVAL ORDER ###\n{eval_idx_list}\n'
84 | true_log = out['true_log'] if cfg.resume else f'### EVAL ORDER ###\n{eval_idx_list}\n{str(cfg)}\n'
85 |
86 | num_training_tasks = len(INIT_TASKS_FN[cfg.benchmark.name](cfg))
87 |
88 | # we start at fold 0 if we are starting a new run
89 | starting_fold = dicts[-1].get('starting_fold', 0)
90 | # we start at the first task in the fold if we are starting a new run
91 | starting_idx = dicts[-1].get('starting_idx', eval_idx_list[0][0])
92 |
93 | react_agent = AGENT[cfg.agent_type](
94 | name=cfg.ai_name,
95 | system_instruction=SYSTEM_INSTRUCTION[cfg.benchmark.name],
96 | human_instruction=HUMAN_INSTRUCTION[cfg.benchmark.name],
97 | tasks=INIT_TASKS_FN[cfg.benchmark.name](cfg),
98 | fewshots=FEWSHOTS[cfg.benchmark.name],
99 | system_prompt=system_message_prompt,
100 | env=ENVS[cfg.benchmark.name],
101 | max_steps=cfg.benchmark.max_steps,
102 | openai_api_key=openai_api_key,
103 | llm=cfg.agent.llm,
104 | llm_builder=LLM_CLS,
105 | reflection_fewshots=REFLECTION_FEWSHOTS[cfg.benchmark.name],
106 | reflection_task_prompt=HUMAN_REFLECTION_INSTRUCTION[cfg.benchmark.name],
107 | reflection_system_instruction=SYSTEM_REFLECTION_INSTRUCTION[cfg.benchmark.name],
108 | reflection_system_prompt=SYSTEM_INSTRUCTION[cfg.benchmark.name],
109 | max_relfection_depth=cfg.agent.max_reflection_depth if 'max_reflection_depth' in cfg.agent.keys() else 0,
110 | system_critique_instructions=SYSTEM_CRITIQUE_INSTRUCTION[cfg.benchmark.name],
111 | human_critiques=HUMAN_CRITIQUES,
112 | max_num_rules=cfg.agent.max_num_rules if 'max_num_rules' in cfg.agent.keys() else 0,
113 | rule_template=RULE_TEMPLATE[cfg.benchmark.name],
114 | truncate_strategy=cfg.agent.truncate_strategy if 'truncate_strategy' in cfg.agent.keys() else None,
115 | llm_parser=LLM_PARSER[cfg.benchmark.name],
116 | observation_formatter=OBSERVATION_FORMATTER[cfg.benchmark.name],
117 | embedder=EMBEDDERS(cfg.agent.retrieval_kwargs.embedder_type),
118 | embedder_path=cfg.agent.retrieval_kwargs.embedder_path,
119 | step_stripper=STEP_STRIPPER[cfg.benchmark.name],
120 | retriever_cls=RETRIEVERS(cfg.agent.retrieval_kwargs.retriever_type),
121 | message_splitter=CYCLER[cfg.benchmark.name],
122 | identifier=STEP_IDENTIFIER[cfg.benchmark.name],
123 | message_step_splitter=partial(STEP_CYCLER, benchmark=cfg.benchmark.name),
124 | reflection_prefix=REFLECTION_PREFIX[cfg.benchmark.name],
125 | previous_trials_formatter=PREVIOUS_TRIALS_FORMATTER[cfg.benchmark.name],
126 | success_critique_num=cfg.agent.success_critique_num,
127 | fewshot_strategy=cfg.agent.fewshot_strategy,
128 | benchmark_name=cfg.benchmark.name,
129 | reranker=cfg.agent.retrieval_kwargs.reranker,
130 | buffer_retrieve_ratio=cfg.agent.retrieval_kwargs.buffer_retrieve_ratio,
131 | critique_truncate_strategy=cfg.agent.critique_truncate_strategy,
132 | critique_summary_suffix=CRITIQUE_SUMMARY_SUFFIX,
133 | testing=cfg.testing,
134 | task_idx=starting_idx,
135 | max_fewshot_tokens=get_fewshot_max_tokens(cfg.benchmark.name) if cfg.agent.retrieval_kwargs.max_fewshot_tokens == 'auto' else cfg.agent.retrieval_kwargs.max_fewshot_tokens,
136 | )
137 |
138 | if len(dicts) > 0:
139 | no_load_list = ['ai_message', 'message_type_format', 'max_num_rules', 'testing', 'human_critiques', 'system_critique_instructions', 'fewshot_strategy', 'success', 'halted', 'fail', 'task_idx', 'prompt_history', 'critique_truncate_strategy', 'success_critique_num', 'reflection_fewshots', 'reflection_system_prompt', 'reflection_prefix', 'reflection_prompt_history', 'reflections', 'previous_trial', 'perform_reflection', 'increment_task', 'reflection_system_kwargs', 'prepend_human_instruction', 'name', 'tasks', 'human_instruction_kwargs', 'all_system_instruction', 'all_fewshots', 'max_steps', 'ordered_summary', 'fewshots', 'system_instruction', 'num_fewshots', 'curr_step', 'log_idx', 'pretask_idx', 'reflect_interaction_idx', 'truncated', 'reward', 'terminated', 'autoregressive_model_instruction', 'failed_training_task_idx', '_train', 'task',
140 | 'eval_idx_list', 'starting_fold', 'starting_idx', 'rule_template', 'max_fewshot_tokens', 'buffer_retrieve_ratio']
141 | react_agent.load_checkpoint(dicts[-1], no_load_list=no_load_list)
142 | # resetting task_idx
143 | react_agent.task = react_agent.tasks[starting_idx]['task']
144 | react_agent.reset()
145 |
146 | react_agent.eval()
147 | start_processing = False # Flag for starting_fold
148 | start_eval_idx = False # Flag for starting_idx
149 | first_flag = True
150 | react_agent.no_rules = cfg.no_rules
151 |
152 | print(f'*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*\n\nWe are using the following model: {react_agent.llm.model_name}\n\n*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*')
153 | true_log += str(react_agent.llm.llm) + '\n'
154 | first_idxs = [eval_idxs[0] for eval_idxs in eval_idx_list]
155 |
156 | # start evaluating
157 | for k, eval_idxs in enumerate(eval_idx_list):
158 | # syncing fold
159 | if k == starting_fold or start_processing:
160 | start_processing = True
161 | if not cfg.no_rules:
162 | # make sure every start of fold does create_rules
163 | if not first_flag:
164 | starting_idx = eval_idxs[0]
165 | # create rules for each fold, if starting a new fold
166 | if starting_idx == eval_idxs[0]:
167 | training_ids = set(range(num_training_tasks)) - set(eval_idxs)
168 | react_agent.create_rules(
169 | list(training_ids),
170 | cache_fold=None,
171 | load_cache_fold=k if cfg.load_cache_rules else None,
172 | )
173 | first_flag = False
174 |
175 | # evaluate on each task in the fold
176 | for eval_idx in eval_idxs:
177 | # syncing idx
178 | if eval_idx == starting_idx or start_eval_idx:
179 | start_eval_idx = True
180 | # Skip the first matching eval_idx
181 | if eval_idx == starting_idx and starting_idx not in first_idxs:
182 | continue
183 | prefix = f"#######################################\nTASK {get_eval_num(eval_idx, eval_idx_list)} \nFOLD: {k}, EVAL_IDX: {eval_idx}\n" # the space after TASK \d+ is needed for log results parsing
184 | prefix += react_agent.remove_task_suffix(react_agent.tasks[eval_idx]['task']) + '\n'
185 | print(prefix)
186 |
187 | react_agent.run(mode='eval', eval_idx=eval_idx)
188 |
189 | # logging
190 | react_agent.update_stats()
191 | log += prefix + react_agent.log_history(include_task=False) + '\n\n'
192 | true_log += prefix + react_agent.log_history(include_all=True, include_task=False) + '\n\n'
193 | # not saving other complicated objects
194 | eval_dict = {k: deepcopy(v) for k, v in react_agent.__dict__.items() if type(v) in [list, set, str, bool, int, dict]}
195 | eval_dict.update({
196 | 'eval_idx_list': eval_idx_list,
197 | 'starting_fold': k,
198 | 'starting_idx': eval_idx,
199 | })
200 | dicts.append(eval_dict)
201 | save_trajectories_log(
202 | path=SAVE_PATH,
203 | log=log,
204 | dicts=dicts,
205 | true_log=true_log,
206 | run_name=f'{cfg.run_name}'
207 | )
208 |
209 | # logging to files
210 | success, fail, halted = react_agent.get_stats()
211 | log += f"########################################\nEND TRIAL\nTrial summary: Success: {success}/{success + fail + halted}, Fail: {fail}/{success + fail + halted}, Halted: {halted}/{success + fail + halted}"
212 | true_log += f"########################################\nEND TRIAL\nTrial summary: Success: {success}/{success + fail + halted}, Fail: {fail}/{success + fail + halted}, Halted: {halted}/{success + fail + halted}"
213 |
214 | print(f'Finished. Success: {success}, Fail: {fail}, Halted: {halted}')
215 |
216 | parsed_result = split_logs_by_task(text=log, num_tasks=len(react_agent.tasks))
217 | reflection_results = plot_trial_stats(parsed_result=parsed_result, benchmark=cfg.benchmark.name, max_trials=1, save_path=f"{LOG_PATH}/{cfg.run_name}_logs_stats.png")
218 |
219 | results = ', '.join([f"{k}: {v}" for k, v in reflection_results.items()]) + '\n'
220 | if cfg.benchmark.name == 'alfworld':
221 | results += str(alfworld_results_per_env_name_log(log, len(react_agent.tasks), 1))
222 | elif cfg.benchmark.name == 'webshop':
223 | results += str(get_webshop_mean_score(log, len(react_agent.tasks), 1))
224 | log += f'\n\n{results}\n########################################'
225 | true_log += f'\n\n{results}\n########################################'
226 | print(results)
227 |
228 | save_trajectories_log(
229 | path=SAVE_PATH,
230 | log=log,
231 | dicts=dicts,
232 | true_log=true_log,
233 | run_name=f'{cfg.run_name}'
234 | )
235 |
236 | if __name__ == "__main__":
237 | main()
238 |
--------------------------------------------------------------------------------
/insight_extraction.py:
--------------------------------------------------------------------------------
1 | import getpass
2 | import hydra
3 | from omegaconf import DictConfig
4 | from pathlib import Path
5 | from functools import partial
6 | import os
7 | import random
8 |
9 | from agent import AGENT
10 | from langchain.chat_models import ChatOpenAI
11 | from prompts.templates.system import system_message_prompt
12 | from prompts.templates.human import HUMAN_CRITIQUES
13 | from prompts import (
14 | SYSTEM_INSTRUCTION,
15 | HUMAN_INSTRUCTION,
16 | FEWSHOTS,
17 | REFLECTION_FEWSHOTS,
18 | HUMAN_REFLECTION_INSTRUCTION,
19 | SYSTEM_REFLECTION_INSTRUCTION,
20 | SYSTEM_CRITIQUE_INSTRUCTION,
21 | RULE_TEMPLATE,
22 | LLM_PARSER,
23 | OBSERVATION_FORMATTER,
24 | STEP_IDENTIFIER,
25 | CYCLER,
26 | STEP_CYCLER,
27 | REFLECTION_PREFIX,
28 | PREVIOUS_TRIALS_FORMATTER,
29 | STEP_STRIPPER,
30 | CRITIQUE_SUMMARY_SUFFIX,
31 | )
32 | from envs import ENVS, INIT_TASKS_FN
33 | from memory import (
34 | EMBEDDERS,
35 | RETRIEVERS,
36 | )
37 | from models import LLM_CLS
38 | from utils import save_trajectories_log, load_trajectories_log, shuffled_chunks, get_split_eval_idx_list
39 | from agent.reflect import Count
40 |
41 | from dotenv import load_dotenv
42 | load_dotenv()
43 |
44 |
45 | @hydra.main(version_base=None, config_path="configs", config_name="insight_extraction")
46 | def main(cfg : DictConfig) -> None:
47 | if cfg.testing:
48 | openai_api_key = 'NO_KEY_FOR_TESTING'
49 | else:
50 | openai_api_key = os.environ['OPENAI_API_KEY'] if 'OPENAI_API_KEY' in os.environ else getpass.getpass("Enter or paste your OpenAI API Key: ")
51 | LOG_PATH = Path('/'.join([cfg.log_dir, cfg.benchmark.name, cfg.agent_type]))
52 | SAVE_PATH = LOG_PATH / 'extracted_insights'
53 | SAVE_PATH.mkdir(exist_ok=True)
54 |
55 | # Overwriting confirmation
56 | if not cfg.resume and os.path.exists(f"{SAVE_PATH}/{cfg.run_name}.pkl") and cfg.run_name != 'test':
57 | while True:
58 | res = input(f"Are you sure to overwrite '{cfg.run_name}'? (Y/N)\n").lower()
59 | if res == 'n':
60 | exit(0)
61 | elif res == 'y':
62 | break
63 | if cfg.resume and cfg.resume_fold < 0:
64 | print('Specify a fold to resume when resuming a run! (resume_fold=X)')
65 | exit(1)
66 | out = load_trajectories_log(
67 | SAVE_PATH / f"fold_{cfg.resume_fold}" if cfg.resume_fold > -1 else LOG_PATH,
68 | run_name=cfg.load_run_name,
69 | load_log=cfg.resume,
70 | load_true_log=False
71 | )
72 | dicts = out['dicts']
73 | log = out['log'] if cfg.resume else ''
74 |
75 | cfg.folded = True
76 | react_agent = AGENT[cfg.agent_type](
77 | name=cfg.ai_name,
78 | system_instruction=SYSTEM_INSTRUCTION[cfg.benchmark.name],
79 | human_instruction=HUMAN_INSTRUCTION[cfg.benchmark.name],
80 | tasks=INIT_TASKS_FN[cfg.benchmark.name](cfg),
81 | fewshots=FEWSHOTS[cfg.benchmark.name],
82 | system_prompt=system_message_prompt,
83 | env=ENVS[cfg.benchmark.name],
84 | max_steps=cfg.benchmark.max_steps,
85 | openai_api_key=openai_api_key,
86 | llm=cfg.agent.llm,
87 | llm_builder=LLM_CLS,
88 | reflection_fewshots=REFLECTION_FEWSHOTS[cfg.benchmark.name],
89 | reflection_task_prompt=HUMAN_REFLECTION_INSTRUCTION[cfg.benchmark.name],
90 | reflection_system_instruction=SYSTEM_REFLECTION_INSTRUCTION[cfg.benchmark.name],
91 | reflection_system_prompt=SYSTEM_INSTRUCTION[cfg.benchmark.name],
92 | max_relfection_depth=cfg.agent.max_reflection_depth if 'max_reflection_depth' in cfg.agent.keys() else 0,
93 | system_critique_instructions=SYSTEM_CRITIQUE_INSTRUCTION[cfg.benchmark.name],
94 | human_critiques=HUMAN_CRITIQUES,
95 | max_num_rules=cfg.agent.max_num_rules if 'max_num_rules' in cfg.agent.keys() else 0,
96 | rule_template=RULE_TEMPLATE[cfg.benchmark.name],
97 | truncate_strategy=cfg.agent.truncate_strategy if 'truncate_strategy' in cfg.agent.keys() else None,
98 | llm_parser=LLM_PARSER[cfg.benchmark.name],
99 | observation_formatter=OBSERVATION_FORMATTER[cfg.benchmark.name],
100 | embedder=EMBEDDERS(cfg.agent.retrieval_kwargs.embedder_type),
101 | embedder_path=cfg.agent.retrieval_kwargs.embedder_path,
102 | step_stripper=STEP_STRIPPER[cfg.benchmark.name],
103 | retriever_cls=RETRIEVERS(cfg.agent.retrieval_kwargs.retriever_type),
104 | message_splitter=CYCLER[cfg.benchmark.name],
105 | identifier=STEP_IDENTIFIER[cfg.benchmark.name],
106 | message_step_splitter=partial(STEP_CYCLER, benchmark=cfg.benchmark.name),
107 | reflection_prefix=REFLECTION_PREFIX[cfg.benchmark.name],
108 | previous_trials_formatter=PREVIOUS_TRIALS_FORMATTER[cfg.benchmark.name],
109 | success_critique_num=cfg.agent.success_critique_num,
110 | fewshot_strategy=cfg.agent.fewshot_strategy,
111 | benchmark_name=cfg.benchmark.name,
112 | reranker=cfg.agent.retrieval_kwargs.reranker,
113 | buffer_retrieve_ratio=cfg.agent.retrieval_kwargs.buffer_retrieve_ratio,
114 | critique_truncate_strategy=cfg.agent.critique_truncate_strategy,
115 | critique_summary_suffix=CRITIQUE_SUMMARY_SUFFIX,
116 | testing=cfg.testing,
117 | max_fewshot_tokens = cfg.agent.retrieval_kwargs.max_fewshot_tokens,
118 | )
119 |
120 | print(f'Loading agent from {LOG_PATH}')
121 | no_load_list = ['ai_message', 'message_type_format', 'max_num_rules', 'testing', 'human_critiques', 'system_critique_instructions', 'fewshot_strategy', 'success', 'halted', 'fail', 'task_idx', 'prompt_history', 'critique_truncate_strategy', 'success_critique_num', 'reflection_fewshots', 'reflection_system_prompt', 'reflection_prefix', 'reflection_prompt_history', 'reflections', 'previous_trial', 'perform_reflection', 'increment_task', 'reflection_system_kwargs', 'prepend_human_instruction', 'name', 'tasks', 'human_instruction_kwargs', 'all_system_instruction', 'all_fewshots', 'max_steps', 'ordered_summary', 'fewshots', 'system_instruction', 'num_fewshots', 'curr_step', 'log_idx', 'pretask_idx', 'reflect_interaction_idx', 'truncated', 'reward', 'terminated', 'autoqregressive_model_instruction', 'failed_training_task_idx', '_train', 'task', 'eval_idx_list', 'starting_fold', 'starting_idx', 'critique_summary_suffix']
122 | react_agent.load_checkpoint(dicts[-1], no_load_list=no_load_list)
123 |
124 | random.seed(cfg.seed)
125 | num_training_tasks = len(INIT_TASKS_FN[cfg.benchmark.name](cfg))
126 | if not cfg.resume:
127 | resume = False
128 | else:
129 | resume = 'eval_idx_list' in dicts[-1]
130 | eval_idx_list = dicts[-1].get('eval_idx_list', get_split_eval_idx_list(dicts[-1], cfg.benchmark.eval_configs.k_folds))
131 |
132 | print(f'eval_idx_list: {eval_idx_list}')
133 | starting_fold = dicts[-1]['starting_fold'] = dicts[-1].get('critique_summary_fold', 0)
134 |
135 | resume_starting_fold = starting_fold
136 | critique_summary_log = dicts[-1].get('critique_summary_log', '')
137 |
138 | for k, eval_idxs in enumerate(eval_idx_list):
139 | if k < starting_fold:
140 | continue
141 | training_ids = set(range(num_training_tasks)) - set(eval_idxs)
142 | (SAVE_PATH / f"fold_{k}").mkdir(exist_ok=True)
143 | log += f'################## FOLD {k} ##################\n'
144 | log += react_agent.create_rules(
145 | list(training_ids),
146 | cache_fold=k,
147 | logging_dir=str(SAVE_PATH / f"fold_{k}"),
148 | run_name=cfg.run_name,
149 | loaded_dict=dicts[-1] if resume and resume_starting_fold == starting_fold else None,
150 | loaded_log=critique_summary_log if resume and resume_starting_fold == starting_fold else None,
151 | eval_idx_list=eval_idx_list,
152 | saving_dict=True,
153 | )
154 | starting_fold += 1
155 |
156 | save_dict = {k: v for k, v in react_agent.__dict__.items() if type(v) in [list, set, str, bool, int, dict, Count] and k not in ['openai_api_key', 'llm']}
157 | if cfg.folded:
158 | save_dict['eval_idx_list'] = eval_idx_list
159 | dicts.append(save_dict)
160 | save_trajectories_log(
161 | path=SAVE_PATH,
162 | log=log,
163 | dicts=dicts,
164 | run_name=f'{cfg.run_name}',
165 | save_true_log=False
166 | )
167 |
168 | if __name__ == "__main__":
169 | main()
170 |
--------------------------------------------------------------------------------
/memory/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Any, Dict, List, Optional
2 |
3 | from pydantic import BaseModel, Extra, Field
4 | from langchain.retrievers import SVMRetriever, KNNRetriever
5 | from langchain.embeddings import OpenAIEmbeddings, LlamaCppEmbeddings, HuggingFaceEmbeddings
6 | from pydantic import BaseModel
7 | from langchain.embeddings.base import Embeddings
8 |
9 | from .episode import Trajectory
10 |
11 | # COPIED FROM HUGGINGFACE EMDDEDING CLASS
12 | class GPT4ALLEmbeddings(BaseModel, Embeddings):
13 | """Wrapper around sentence_transformers embedding models.
14 |
15 | To use, you should have the ``sentence_transformers`` python package installed.
16 |
17 | Example:
18 | .. code-block:: python
19 |
20 | from langchain.embeddings import HuggingFaceEmbeddings
21 |
22 | model_name = "sentence-transformers/all-mpnet-base-v2"
23 | model_kwargs = {'device': 'cpu'}
24 | encode_kwargs = {'normalize_embeddings': False}
25 | hf = HuggingFaceEmbeddings(
26 | model_name=model_name,
27 | model_kwargs=model_kwargs,
28 | encode_kwargs=encode_kwargs
29 | )
30 | """
31 |
32 | client: Any #: :meta private:
33 | model_name: str = 'gpt4all'
34 | """Model name to use."""
35 | cache_folder: Optional[str] = None
36 | """Path to store models.
37 | Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
38 | model_kwargs: Dict[str, Any] = Field(default_factory=dict)
39 | """Key word arguments to pass to the model."""
40 | encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
41 | """Key word arguments to pass when calling the `encode` method of the model."""
42 |
43 | def __init__(self, **kwargs: Any):
44 | """Initialize the sentence_transformer."""
45 | super().__init__(**kwargs)
46 | try:
47 | from gpt4all import Embed4All
48 |
49 | except ImportError as exc:
50 | raise ImportError(
51 | "Could not import sentence_transformers python package. "
52 | "Please install it with `pip install sentence_transformers`."
53 | ) from exc
54 |
55 | self.client = Embed4All()
56 |
57 | class Config:
58 | """Configuration for this pydantic object."""
59 |
60 | extra = Extra.forbid
61 |
62 | def embed_documents(self, texts: List[str]) -> List[List[float]]:
63 | """Compute doc embeddings using a HuggingFace transformer model.
64 |
65 | Args:
66 | texts: The list of texts to embed.
67 |
68 | Returns:
69 | List of embeddings, one for each text.
70 | """
71 | texts = list(map(lambda x: x.replace("\n", " "), texts))
72 | embeddings = [self.client.embed(text) for text in texts]
73 | return embeddings
74 |
75 | def embed_query(self, text: str) -> List[float]:
76 | """Compute query embeddings using a HuggingFace transformer model.
77 |
78 | Args:
79 | text: The text to embed.
80 |
81 | Returns:
82 | Embeddings for the text.
83 | """
84 | text = text.replace("\n", " ")
85 | embedding = self.client.embed(text)
86 | return embedding
87 |
88 |
89 | def choose_embedder(key: str) -> Callable:
90 | if key == 'openai':
91 | return OpenAIEmbeddings
92 | if key == 'llama':
93 | return LlamaCppEmbeddings
94 | if key == 'gpt4all':
95 | return GPT4ALLEmbeddings
96 | return HuggingFaceEmbeddings
97 |
98 | def choose_retriever(key: str) -> Callable:
99 | if key == 'knn':
100 | return KNNRetriever
101 | if key == 'svm':
102 | return SVMRetriever
103 | return KNNRetriever
104 |
105 | EMBEDDERS = choose_embedder
106 | RETRIEVERS = choose_retriever
--------------------------------------------------------------------------------
/memory/episode.py:
--------------------------------------------------------------------------------
1 | from typing import List, Callable, Dict, Optional
2 | from copy import deepcopy
3 |
4 | class Trajectory:
5 | def __init__(
6 | self,
7 | task: str,
8 | trajectory: str,
9 | splitter: Callable,
10 | identifier: Callable,
11 | step_splitter: Callable,
12 | embedder: Optional[Callable] = None,
13 | reflections: List[str] = None
14 | ):
15 | self._task = task
16 | self._trajectory = trajectory
17 | self._reflections = deepcopy(reflections)
18 | self._observations, self._actions, self._thoughts = [], [], []
19 | for line in splitter(self._trajectory):
20 | setattr(self, f'_{identifier(line)}s', getattr(self, f'_{identifier(line)}s') + [line])
21 | self._steps = step_splitter(lines=trajectory, cycler=splitter, step_identifier=identifier)
22 | self._keys = {'thought': [], 'step': []}
23 | if embedder is not None:
24 | self._keys['task'] = [embedder(self.task)]
25 | for step in self.steps:
26 | self._keys['step'].append(embedder(step))
27 | for thought in self.thoughts:
28 | self._keys['thought'].append(embedder(thought))
29 |
30 | @property
31 | def task(self) -> str:
32 | return self._task
33 |
34 | @property
35 | def steps(self) -> List[str]:
36 | return self._steps
37 |
38 | @property
39 | def trajectory(self) -> str:
40 | return self._trajectory
41 |
42 | @property
43 | def num_steps(self) -> int:
44 | return max(len(self.thoughts), len(self.actions), len(self.observations))
45 |
46 | @property
47 | def observations(self) -> List[str]:
48 | return self._observations
49 |
50 | @property
51 | def actions(self) -> List[str]:
52 | return self._actions
53 |
54 | @property
55 | def thoughts(self) -> List[str]:
56 | return self._thoughts
57 |
58 | @property
59 | def reflections(self) -> List[str]:
60 | return self._reflections
61 |
62 | @property
63 | def keys(self) -> Dict[str, List[float]]:
64 | return self._keys
65 |
66 | def _replace(self, attr, value):
67 | self.__setattr__(self, attr, value)
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .llm import LLM_CLS
--------------------------------------------------------------------------------
/models/llm.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List
2 | import time
3 |
4 | from langchain.chat_models import ChatOpenAI
5 | from langchain.schema import ChatMessage
6 | import openai
7 |
8 |
9 | class GPTWrapper:
10 | def __init__(self, llm_name: str, openai_api_key: str, long_ver: bool):
11 | self.model_name = llm_name
12 | if long_ver:
13 | llm_name = 'gpt-3.5-turbo-16k'
14 | self.llm = ChatOpenAI(
15 | model=llm_name,
16 | temperature=0.0,
17 | openai_api_key=openai_api_key,
18 | )
19 |
20 | def __call__(self, messages: List[ChatMessage], stop: List[str] = [], replace_newline: bool = True) -> str:
21 | kwargs = {}
22 | if stop != []:
23 | kwargs['stop'] = stop
24 | for i in range(6):
25 | try:
26 | output = self.llm(messages, **kwargs).content.strip('\n').strip()
27 | break
28 | except openai.error.RateLimitError:
29 | print(f'\nRetrying {i}...')
30 | time.sleep(1)
31 | else:
32 | raise RuntimeError('Failed to generate response')
33 |
34 | if replace_newline:
35 | output = output.replace('\n', '')
36 | return output
37 |
38 | def LLM_CLS(llm_name: str, openai_api_key: str, long_ver: bool) -> Callable:
39 | if 'gpt' in llm_name:
40 | return GPTWrapper(llm_name, openai_api_key, long_ver)
41 | else:
42 | raise ValueError(f"Unknown LLM model name: {llm_name}")
43 |
--------------------------------------------------------------------------------
/prompts/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List
2 |
3 | from . import alfworld
4 | from . import hotpotQA, fever, webshop
5 | from .templates.human import *
6 | from .templates.system import *
7 |
8 | FEWSHOTS = dict(
9 | hotpotqa=hotpotQA.FEWSHOTS,
10 | fever=fever.FEWSHOTS,
11 | alfworld=alfworld.FEWSHOTS,
12 | webshop=webshop.FEWSHOTS,
13 | )
14 | REFLECTION_FEWSHOTS = dict(
15 | hotpotqa=hotpotQA.REFLECTION_FEWSHOTS,
16 | fever=None,#fever.REFLECTION_FEWSHOTS,
17 | alfworld=alfworld.REFLECTION_FEWSHOTS,
18 | webshop=webshop.REFLECTION_FEWSHOTS,
19 | )
20 | SYSTEM_INSTRUCTION = dict(
21 | hotpotqa=hotpotQA.SYSTEM_INSTRUCTION,
22 | fever=fever.SYSTEM_INSTRUCTION,
23 | alfworld=alfworld.SYSTEM_INSTRUCTION,
24 | webshop=webshop.SYSTEM_INSTRUCTION,
25 | )
26 | SYSTEM_REFLECTION_INSTRUCTION = dict(
27 | hotpotqa=hotpotQA.SYSTEM_REFLECTION_INSTRUCTION,
28 | fever=None,#fever.SYSTEM_REFLECTION_INSTRUCTION,
29 | alfworld=None,#alfworld.SYSTEM_REFLECTION_INSTRUCTION,
30 | webshop=None,#webshop.SYSTEM_REFLECTION_INSTRUCTION,
31 | )
32 | HUMAN_INSTRUCTION = dict(
33 | hotpotqa=hotpotQA.HUMAN_INSTRUCTION,
34 | fever=fever.HUMAN_INSTRUCTION,
35 | alfworld=alfworld.HUMAN_INSTRUCTION,
36 | webshop=webshop.HUMAN_INSTRUCTION,
37 | )
38 | HUMAN_REFLECTION_INSTRUCTION = dict(
39 | hotpotqa=hotpotQA.HUMAN_REFLECTION_INSTRUCTION,
40 | fever=None,
41 | alfworld=alfworld.HUMAN_REFLECTION_INSTRUCTION,
42 | webshop=webshop.HUMAN_REFLECTION_INSTRUCTION,
43 | )
44 | SYSTEM_CRITIQUE_INSTRUCTION = dict(
45 | hotpotqa=dict(
46 | compare_existing_rules=hotpotQA.SYSTEM_CRITIQUE_EXISTING_RULES_INSTRUCTION,
47 | all_success_existing_rules=hotpotQA.SYSTEM_CRITIQUE_ALL_SUCCESS_EXISTING_RULES_INSTRUCTION,
48 | ),
49 | fever=dict(
50 | compare_existing_rules=None,
51 | all_success_existing_riles=None
52 | ),
53 | alfworld=dict(
54 | compare_existing_rules=alfworld.SYSTEM_CRITIQUE_EXISTING_RULES_INSTRUCTION,
55 | all_success_existing_rules=alfworld.SYSTEM_CRITIQUE_ALL_SUCCESS_EXISTING_RULES_INSTRUCTION,
56 | ),
57 | webshop=dict(
58 | compare_existing_rules=webshop.SYSTEM_CRITIQUE_EXISTING_RULES_INSTRUCTION,
59 | all_success_existing_rules=webshop.SYSTEM_CRITIQUE_ALL_SUCCESS_EXISTING_RULES_INSTRUCTION,
60 | ),
61 | )
62 |
63 | LLM_PARSER = dict(
64 | hotpotqa=hotpotQA.LLM_PARSER,
65 | # fever and hotpotQA has same format
66 | fever=hotpotQA.LLM_PARSER,
67 | alfworld=alfworld.LLM_PARSER,
68 | webshop=webshop.LLM_PARSER,
69 | )
70 |
71 | OBSERVATION_FORMATTER = dict(
72 | hotpotqa=hotpotQA.OBSERVATION_FORMATTER,
73 | # fever and hotpotQA has same format
74 | fever=hotpotQA.OBSERVATION_FORMATTER,
75 | alfworld=alfworld.OBSERVATION_FORMATTER,
76 | webshop=webshop.OBSERVATION_FORMATTER,
77 | )
78 |
79 | STEP_IDENTIFIER = dict(
80 | hotpotqa=hotpotQA.STEP_IDENTIFIER,
81 | # fever and hotpotQA has same format
82 | fever=hotpotQA.STEP_IDENTIFIER,
83 | webshop=webshop.STEP_IDENTIFIER,
84 | alfworld=alfworld.STEP_IDENTIFIER,
85 | )
86 |
87 | CYCLER = dict(
88 | hotpotqa=hotpotQA.CYCLER,
89 | fever=hotpotQA.CYCLER,
90 | # fever and hotpotQA has same format
91 | webshop=webshop.CYCLER,
92 | alfworld=alfworld.CYCLER,
93 | )
94 | REFLECTION_PREFIX = dict(
95 | hotpotqa=hotpotQA.REFLECTION_PREFIX,
96 | fever=hotpotQA.REFLECTION_PREFIX,
97 | alfworld=alfworld.REFLECTION_PREFIX,
98 | # same format as alfworld
99 | webshop=webshop.REFLECTION_PREFIX,
100 | )
101 | PREVIOUS_TRIALS_FORMATTER=dict(
102 | hotpotqa=hotpotQA.PREVIOUS_TRIALS_FORMATTER,
103 | fever=hotpotQA.PREVIOUS_TRIALS_FORMATTER,
104 | alfworld=alfworld.PREVIOUS_TRIALS_FORMATTER,
105 | # same format as alfworld
106 | webshop=webshop.PREVIOUS_TRIALS_FORMATTER,
107 | )
108 |
109 | STEP_STRIPPER = dict(
110 | hotpotqa=hotpotQA.STEP_STRIPPER,
111 | fever=hotpotQA.STEP_STRIPPER,
112 | alfworld=alfworld.STEP_STRIPPER,
113 | # same format as alfworld
114 | webshop=webshop.STEP_STRIPPER,
115 | )
116 |
117 | def STEP_CYCLER(benchmark: str, lines: str, cycler: Callable, step_identifier: Callable, stripper: Callable = lambda x, y: x) -> List[str]:
118 | steps = []
119 | scratch_pad = ''
120 | for line in cycler(lines):
121 | step_type = step_identifier(line)
122 | stripped_line = stripper(line, step_type)
123 | scratch_pad += stripped_line + '\n'
124 | if step_type == 'observation':
125 | steps.append(scratch_pad.strip())
126 | scratch_pad = ''
127 | if scratch_pad != '':
128 | steps.append(scratch_pad.strip())
129 | return steps
130 |
--------------------------------------------------------------------------------
/prompts/fever.py:
--------------------------------------------------------------------------------
1 | from langchain.prompts.chat import HumanMessagePromptTemplate
2 |
3 | FEWSHOTS = ["""Claim: Nikolaj Coster-Waldau worked with the Fox Broadcasting Company.
4 | Thought 1: I need to search Nikolaj Coster-Waldau and find if he has worked with the Fox Broadcasting Company.
5 | Action 1: Search[Nikolaj Coster-Waldau]
6 | Observation 1: Nikolaj William Coster-Waldau (born 27 July 1970) is a Danish actor and producer. He graduated from the Danish National School of Performing Arts in Copenhagen in 1993,[1] and had his breakthrough role in Denmark with the film Nightwatch (1994). He played Jaime Lannister in the HBO fantasy drama series Game of Thrones, for which he received two Primetime Emmy Award nominations for Outstanding Supporting Actor in a Drama Series.. Coster-Waldau has appeared in numerous films in his native Denmark and Scandinavia, including Headhunters (2011) and A Thousand Times Good Night (2013). In the U.S, his debut film role was in the war film Black Hawk Down (2001), playing Medal of Honor recipient Gary Gordon.[2] He then played a detective in the short-lived Fox television series New Amsterdam (2008), and appeared in the 2009 Fox television film Virtuality, originally intended as a pilot.
7 | Thought 2: Because he "appeared in the 2009 Fox television film Virtuality", he should have worked with the Fox Broadcasting Company.
8 | Action 2: Finish[SUPPORTS]""",
9 | """Claim: Stranger Things is set in Bloomington, Indiana.
10 | Thought 1: I should search for Stranger Things, and see if it is set in Bloomington, Indiana.
11 | Action 1: Search[Stranger Things]
12 | Observation 1: Stranger Things is an American science fiction horror drama television series created by the Duffer Brothers. Set in the 1980s, primarily in the fictional town of Hawkins, Indiana, the series centers on a number of mysteries and supernatural events occurring around the town and their impact on an ensemble of child and adult characters.
13 | Thought 2: The observation says that it is set in a "fictional town of Hawkins, Indiana", so it is not set in Bloomington.
14 | Action 2: Finish[REFUTES]""",
15 | """Claim: Beautiful reached number two on the Billboard Hot 100 in 2003.?
16 | Thought 1: I need to search the song Beautiful and find if it reached number two on the Billboard Hot 100 in 2003.
17 | Action 1: Search[Beautiful]
18 | Observation 1: Could not find [Beautiful]. Similar: ['Beautiful', 'Beautiful, Beautiful', 'A Beautiful Mind (film)', 'Beautiful (Christina Aguilera song)', 'Life Is Beautiful'].
19 | Thought 2: From suggestions, I should search "Beautiful (Christina Aguilera song)" to find the song.
20 | Action 2: Search[Beautiful (Christina Aguilera song)]
21 | Observation 2: "Beautiful" is a song recorded by American singer Christina Aguilera for her fourth studio album, Stripped (2002).
22 | Thought 3: It does not mention Billboard, so I need to look up "Billboard Hot 100" to find if it reached number two on it in 2003.
23 | Action 3: Lookup[Billboard Hot 100]
24 | Observation 3: (Result 1 / 3) The song peaked at number two on the Billboard Hot 100 in the United States, where it was certified Gold for 500,000 units shipped.
25 | Thought 4: It only says the song peaked at number two on the Billboard Hot 100, but not if it was in 2003. I am not sure if this claim is true or not.
26 | Action 4: Finish[NOT ENOUGH INFO]"""]
27 |
28 | SYSTEM_INSTRUCTION = """Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types:
29 | (1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search.
30 | (2) Lookup[keyword], which returns the next sentence containing keyword in the last passage successfully found by Search.
31 | (3) Finish[answer], which returns the answer and finishes the task, answer should be one of REFUTES, SUPPORTS or NOT ENOUGH INFO.
32 | """
33 |
34 | human_instruction_template = """{instruction}You may take maximum of {max_steps} steps.
35 | Here are some examples:"""
36 |
37 | HUMAN_INSTRUCTION = HumanMessagePromptTemplate.from_template(human_instruction_template)
38 |
--------------------------------------------------------------------------------
/prompts/hotpotQA.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Dict, Any, List
2 | import pickle
3 | import re
4 |
5 | from langchain.prompts.chat import HumanMessagePromptTemplate
6 | from langchain.schema import (
7 | AIMessage,
8 | HumanMessage,
9 | ChatMessage,
10 | )
11 |
12 | FEWSHOTS = ["""Question: What is the elevation range for the area that the eastern sector of the Colorado orogeny extends into?
13 | Thought 1: I need to search Colorado orogeny, find the area that the eastern sector of the Colorado orogeny extends into, then find the elevation range of the area.
14 | Action 1: Search[Colorado orogeny]
15 | Observation 1: The Colorado orogeny was an episode of mountain building (an orogeny) in Colorado and surrounding areas.
16 | Thought 2: It does not mention the eastern sector. So I need to look up eastern sector.
17 | Action 2: Lookup[eastern sector]
18 | Observation 2: (Result 1 / 1) The eastern sector extends into the High Plains and is called the Central Plains orogeny.
19 | Thought 3: The eastern sector of Colorado orogeny extends into the High Plains. So I need to search High Plains and find its elevation range.
20 | Action 3: Search[High Plains]
21 | Observation 3: High Plains refers to one of two distinct land regions:
22 | Thought 4: I need to instead search High Plains (United States).
23 | Action 4: Search[High Plains (United States)]
24 | Observation 4: The High Plains are a subregion of the Great Plains. From east to west, the High Plains rise in elevation from around 1,800 to 7,000 ft (550 to 2,130 m).[3]
25 | Thought 5: High Plains rise in elevation from around 1,800 to 7,000 ft, so the answer is 1,800 to 7,000 ft.
26 | Action 5: Finish[1,800 to 7,000 ft]""",
27 | """Question: Musician and satirist Allie Goertz wrote a song about the "The Simpsons" character Milhouse, who Matt Groening named after who?
28 | Thought 1: The question simplifies to "The Simpsons" character Milhouse is named after who. I only need to search Milhouse and find who it is named after.
29 | Action 1: Search[Milhouse]
30 | Observation 1: Milhouse Mussolini Van Houten is a recurring character in the Fox animated television series The Simpsons voiced by Pamela Hayden and created by Matt Groening.
31 | Thought 2: The paragraph does not tell who Milhouse is named after, maybe I can look up "named after".
32 | Action 2: Lookup[named after]
33 | Observation 2: (Result 1 / 1) Milhouse was named after U.S. president Richard Nixon, whose middle name was Milhous.
34 | Thought 3: Milhouse was named after U.S. president Richard Nixon, so the answer is Richard Nixon.
35 | Action 3: Finish[Richard Nixon]""",
36 | """Question: Which documentary is about Finnish rock groups, Adam Clayton Powell or The Saimaa Gesture?
37 | Thought 1: I need to search Adam Clayton Powell and The Saimaa Gesture, and find which documentary is about Finnish rock groups.
38 | Action 1: Search[Adam Clayton Powell]
39 | Observation 1: Could not find [Adam Clayton Powell]. Similar: ['Adam Clayton Powell III', 'Seventh Avenue (Manhattan)', 'Adam Clayton Powell Jr. State Office Building', 'Isabel Washington Powell', 'Adam Powell', 'Adam Clayton Powell (film)', 'Giancarlo Esposito'].
40 | Thought 2: To find the documentary, I can search Adam Clayton Powell (film).
41 | Action 2: Search[Adam Clayton Powell (film)]
42 | Observation 2: Adam Clayton Powell is a 1989 American documentary film directed by Richard Kilberg.
43 | The film is about the rise and fall of influential African-American politician Adam Clayton Powell Jr.[3][4] It was later aired as part of the PBS series The American Experience.
44 | Thought 3: Adam Clayton Powell (film) is a documentary about an African-American politician, not Finnish rock groups. So the documentary about Finnish rock groups must instead be The Saimaa Gesture.
45 | Action 3: Finish[The Saimaa Gesture]""",
46 | """Question: What profession does Nicholas Ray and Elia Kazan have in common?
47 | Thought 1: I need to search Nicholas Ray and Elia Kazan, find their professions, then find the profession they have in common.
48 | Action 1: Search[Nicholas Ray]
49 | Observation 1: Nicholas Ray (born Raymond Nicholas Kienzle Jr., August 7, 1911 – June 16, 1979) was an American film director, screenwriter, and actor best known for the 1955 film Rebel Without a Cause.
50 | Thought 2: Professions of Nicholas Ray are director, screenwriter, and actor. I need to search Elia Kazan next and find his professions.
51 | Action 2: Search[Elia Kazan]
52 | Observation 2: Elia Kazan was an American film and theatre director, producer, screenwriter and actor.
53 | Thought 3: Professions of Elia Kazan are director, producer, screenwriter, and actor. So profession Nicholas Ray and Elia Kazan have in common is director, screenwriter, and actor.
54 | Action 3: Finish[director, screenwriter, actor]""",
55 | """Question: Which magazine was started first Arthur's Magazine or First for Women?
56 | Thought 1: I need to search Arthur's Magazine and First for Women, and find which was started first.
57 | Action 1: Search[Arthur's Magazine]
58 | Observation 1: Arthur's Magazine (1844-1846) was an American literary periodical published in Philadelphia in the 19th century.
59 | Thought 2: Arthur's Magazine was started in 1844. I need to search First for Women next.
60 | Action 2: Search[First for Women]
61 | Observation 2: First for Women is a woman's magazine published by Bauer Media Group in the USA.[1] The magazine was started in 1989.
62 | Thought 3: First for Women was started in 1989. 1844 (Arthur's Magazine) < 1989 (First for Women), so Arthur's Magazine was started first.
63 | Action 3: Finish[Arthur's Magazine]""",
64 | """Question: Were Pavel Urysohn and Leonid Levin known for the same type of work?
65 | Thought 1: I need to search Pavel Urysohn and Leonid Levin, find their types of work, then find if they are the same.
66 | Action 1: Search[Pavel Urysohn]
67 | Observation 1: Pavel Samuilovich Urysohn (February 3, 1898 â August 17, 1924) was a Soviet mathematician who is best known for his contributions in dimension theory.
68 | Thought 2: Pavel Urysohn is a mathematician. I need to search Leonid Levin next and find its type of work.
69 | Action 2: Search[Leonid Levin]
70 | Observation 2: Leonid Anatolievich Levin is a Soviet-American mathematician and computer scientist.
71 | Thought 3: Leonid Levin is a mathematician and computer scientist. So Pavel Urysohn and Leonid Levin have the same type of work.
72 | Action 3: Finish[yes]"""]
73 |
74 | REFLECTION_FEWSHOTS = [
75 | """
76 | Previous Trial:
77 | Question: The Rome Protocols were signed by three Prime Ministers one of which was assassinated as part of what?
78 | Thought 1: I need to search Rome Protocols, find the three Prime Ministers, then find what they were assassinated as part of.
79 | Action 1: Search[Rome Protocols]
80 | Observation 1: The Rome Protocols were a series of three international agreements concluded in Rome on 17 March 1934 between the governments of Austria, Hungary and Italy. They were signed by Italian Prime Minister Benito Mussolini, Austrian Chancellor Engelbert Dollfuss and Hungarian Prime Minister Gyula Gömbös. All of the protocols went into effect on 12 July 1934 and were registered in League of Nations Treaty Series on 12 December 1934.
81 | Thought 2: The three Prime Ministers were Benito Mussolini, Engelbert Dollfuss, and Gyula Gömbös. I need to search Benito Mussolini and find what he was assassinated as part of.
82 | Action 2: Search[Benito Mussolini]
83 | Observation 2: Benito Amilcare Andrea Mussolini (UK: MU(U)SS-ə-LEE-nee, US: MOOSS-, Italian: [beˈniːto aˈmilkare anˈdrɛːa mussoˈliːni]; 29 July 1883 – 28 April 1945) was an Italian politician and journalist who founded and led the National Fascist Party (PNF). He was Prime Minister of Italy from the March on Rome in 1922 until his deposition in 1943, as well as "Duce" of Italian fascism from the establishment of the Italian Fasces of Combat in 1919 until his summary execution in 1945 by Italian partisans. As dictator of Italy and principal founder of fascism, Mussolini inspired and supported the international spread of fascist movements during the inter-war period.Mussolini was originally a socialist politician and a journalist at the Avanti! newspaper. In 1912, he became a member of the National Directorate of the Italian Socialist Party (PSI), but he was expelled from the PSI for advocating military intervention in World War I, in opposition to the party's stance on neutrality. In 1914, Mussolini founded a new journal, Il Popolo d'Italia, and served in the Royal Italian Army during the war until he was wounded and discharged in 1917. Mussolini denounced the PSI, his views now centering on Italian nationalism instead of socialism, and later founded the fascist movement which came to oppose egalitarianism and class conflict, instead advocating "revolutionary nationalism" transcending class lines. On 31 October 1922, following the March on Rome (28–30 October), Mussolini was appointed prime minister by King Victor Emmanuel III, becoming the youngest individual to hold the office up to that time. After removing all political opposition through his secret police and outlawing labor strikes, Mussolini and his followers consolidated power through a series of laws that transformed the nation into a one-party dictatorship. Within five years, Mussolini had established dictatorial authority by both legal and illegal means and aspired to create a totalitarian state. In 1929, Mussolini signed the Lateran Treaty with the Holy See to establish Vatican City.
84 | Mussolini's foreign policy aimed to restore the ancient grandeur of the Roman Empire by expanding Italian colonial possessions and the fascist sphere of influence. In the 1920s, he ordered the Pacification of Libya, instructed the bombing of Corfu over an incident with Greece, established a protectorate over Albania, and incorporated the city of Fiume into the Italian state via agreements with Yugoslavia. In 1936, Ethiopia was conquered following the Second Italo-Ethiopian War and merged into Italian East Africa (AOI) with Eritrea and Somalia. In 1939, Italian forces annexed Albania. Between 1936 and 1939, Mussolini ordered the successful Italian military intervention in Spain in favor of Francisco Franco during the Spanish Civil War. Mussolini's Italy initially tried to avoid the outbreak of a second global war, sending troops at the Brenner Pass to delay Anschluss and taking part in the Stresa Front, the Lytton Report, the Treaty of Lausanne, the Four-Power Pact and the Munich Agreement. However, Italy then alienated itself from Britain and France by aligning with Germany and Japan. Germany invaded Poland on 1 September 1939, resulting in declarations of war by France and the UK and the start of World War II.
85 | On 10 June 1940, Mussolini decided to enter the war on the Axis side. Despite initial success, the subsequent Axis collapse on multiple fronts and eventual Allied invasion of Sicily made Mussolini lose the support of the population and members of the Fascist Party. As a consequence, early on 25 July 1943, the Grand Council of Fascism passed a motion of no confidence in Mussolini; later that day King Victor Emmanuel III dismissed him as head of government and had him placed in custody, appointing Pietro Badoglio to succeed him as Prime Minister. After the king agreed to an armistice with the Allies, on 12 September 1943 Mussolini was rescued from captivity in the Gran Sasso raid by German paratroopers and Waffen-SS commandos led by Major Otto-Harald Mors. Adolf Hitler, after meeting with the rescued former dictator, then put Mussolini in charge of a puppet regime in northern Italy, the Italian Social Republic (Italian: Repubblica Sociale Italiana, RSI), informally known as the Salò Republic, causing a civil war. In late April 1945, in the wake of near total defeat, Mussolini and his mistress Clara Petacci attempted to flee to Switzerland, but both were captured by Italian communist partisans and summarily executed by firing squad on 28 April 1945 near Lake Como. The bodies of Mussolini and his mistress were then taken to Milan, where they were hung upside down at a service station to publicly confirm their demise.
86 | Thought 3: Benito Mussolini was assassinated as part of the collapse of the Axis powers during World War II. So the Rome Protocols were signed by three Prime Ministers one of which was assassinated as part of World War II.
87 | Action 3: Finish[World War II]
88 |
89 | Reflection: I searched one of the prime ministers involved in the signing, then attemted to answer right away. I should have searched each of the prime ministers, then looked up 'death' on each of their pages in order to get more information before answering.
90 | """,
91 | """
92 | Previous Trial:
93 | Question: Kam Heskin plays Paige Morgan in a 2004 film directed by who?
94 | Thought 1: I need to search Kam Heskin and Paige Morgan, find the 2004 film they are in, then find the director of the film.
95 | Action 1: Search[Kam Heskin]
96 | Observation 1: Kam Erika Heskin (born May 8, 1973) is an American actress. She began her career playing Caitlin Deschanel on the NBC daytime soap opera Sunset Beach (1998–1999), before appearing in films Planet of the Apes (2001) and Catch Me If You Can (2002). Heskin went on in 2003 to play Elizabeth Bennet in an independent film Pride & Prejudice: A Latter-Day Comedy, Robin in Charmed (2004), and Paige Morgan in The Prince and Me film franchise (2006–2010).
97 | Thought 2: Kam Heskin plays Paige Morgan in The Prince and Me film franchise. I need to search The Prince and Me film franchise and find the director of the 2004 film.
98 | Action 2: Search[The Prince and Me film franchise]
99 | Observation 2: Could not find [The Prince and Me film franchise]. Similar: ['Kam Heskin', 'Prince of Persia', 'Harry Potter and the Half-Blood Prince (film)', 'Prince of Persia: The Sands of Time (film)', 'The Chronicles of Narnia (film series)', 'Majestic Prince (franchise)', 'Kingsman (franchise)', 'The Little Prince (2015 film)', 'Shrek (franchise)', 'Harry Potter (film series)']
100 | Thought 3: I need to search The Prince and Me (2004 film) instead.
101 | Action 3: Search[The Prince and Me (2004 film)]
102 | Observation 3: Could not find [The Prince and Me (2004 film)]. Similar: ['The Prince & Me', 'The Prince & Me 2: The Royal Wedding', 'Kam Heskin', 'Prince of Darkness (film)', 'Prince of Persia: The Sands of Time (film)', 'Rob Knox', 'Alexander (2004 film)', 'Prince of Persia', 'Brooklynn Prince', 'Prince (musician)']
103 | Thought 4: I need to search The Prince & Me (2004 film).
104 | Action 4: Search[The Prince & Me (2004 film)]
105 | Observation 4: Could not find [The Prince & Me (2004 film)]. Similar: ['The Prince & Me', 'The Prince & Me 2: The Royal Wedding', 'Prince of Darkness (film)', 'Prince of Persia: The Sands of Time (film)', 'Rob Knox', 'Alexander (2004 film)', 'Prince (musician)', 'Prince of Persia', 'Kam Heskin', 'Brooklynn Prince']
106 | Thought 5: I need to search The Prince & Me (2004 film) instead.
107 | Action 5: Search[The Prince & Me (2004 film)]
108 | Observation 5: Could not find [The Prince & Me (2004 film)]. Similar: ['The Prince & Me', 'The Prince & Me 2: The Royal Wedding', 'Prince of Darkness (film)', 'Prince of Persia: The Sands of Time (film)', 'Rob Knox', 'Alexander (2004 film)', 'Prince (musician)', 'Prince of Persia', 'Kam Heskin', 'Brooklynn Prince']
109 | Thought 6: I need to search The Prince & Me (2004 film).
110 | Action 6: Search[The Prince & Me (2004 film)]
111 | Observation 6: Could not find [The Prince & Me (2004 film)]. Similar: ['The Prince & Me', 'The Prince & Me 2: The Royal Wedding', 'Prince of Darkness (film)', 'Prince of Persia: The Sands of Time (film)', 'Rob Knox', 'Alexander (2004 film)', 'Prince (musician)', 'Prince of Persia', 'Kam Heskin', 'Brooklynn Prince']
112 |
113 | Reflection: I got stuck in a loop where I kept trying to search 'The Prince & Me (2004 film)' but the page could not be found. Instead I should have tried to search the similar results that had a similar name to see and they were made in 2004.
114 | """]
115 |
116 | SYSTEM_INSTRUCTION = """Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types:
117 | (1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search.
118 | (2) Lookup[keyword], which returns the next sentence containing keyword in the last passage successfully found by Search.
119 | (3) Finish[answer], which returns the answer and finishes the task.
120 | """
121 |
122 | human_instruction_template = """{instruction}You may take maximum of {max_steps} steps.
123 | Here are some examples:"""
124 |
125 | HUMAN_INSTRUCTION = HumanMessagePromptTemplate.from_template(human_instruction_template)
126 |
127 | human_instruction_reflection_template = """Here are some examples:"""
128 | HUMAN_REFLECTION_INSTRUCTION = HumanMessagePromptTemplate.from_template(human_instruction_reflection_template)
129 |
130 | SYSTEM_CRITIQUE_EXISTING_RULES_INSTRUCTION = """You will be given two previous task trials in which you were given access to a Docstore API environment and a question to answer: one successful and one unsuccessful trial. You failed the trial either because you guessed the wrong answer with Finish[], or you used up your set number of reasoning steps."""
131 | SYSTEM_CRITIQUE_ALL_SUCCESS_EXISTING_RULES_INSTRUCTION = """You will be given successful tasks trials in which you were given access to a Docstore API environment and a question to answer."""
132 | SYSTEM_REFLECTION_INSTRUCTION = """You will be given a previous reasoning trial in which you were given access to a Docstore API environment and a question to answer. You were unsuccessful in answering the question either because you guessed the wrong answer with Finish[], or you used up your set number of reasoning steps. In a few sentences, Diagnose a possible reason for failure and devise a new, concise, high level plan that aims to mitigate the same failure. Use complete sentences."""
133 |
134 | def LLM_PARSER(llm_output, step: int, ai_message: bool) -> Tuple[ChatMessage, str, Dict[str, Any]]:
135 | pattern = r'(?i)action\s*(?:\d+|)\s*(?::|)\s*'
136 | action_pattern = r'(?i)\w+\[[^\]]+(?:\]|)'
137 |
138 | match = re.match(pattern, llm_output)
139 | if match:
140 | action = llm_output[match.end():]
141 | content = f"Action {step}: {action}"
142 |
143 | if len(re.findall(action_pattern, action)) > 1:
144 | return (
145 | AIMessage(content=content) if ai_message else HumanMessage(content=content),
146 | 'action',
147 | {'action': ''} # triggers invalid action
148 | )
149 |
150 | return (
151 | AIMessage(content=content) if ai_message else HumanMessage(content=content),
152 | 'action',
153 | {'action': action}
154 | )
155 |
156 | actions = re.findall(action_pattern, llm_output)
157 | if len(actions) == 1:
158 | action = actions[0]
159 | if action[-1] != ']':
160 | action += ']'
161 | content = f"Action {step}: {action}"
162 | return (
163 | AIMessage(content=content) if ai_message else HumanMessage(content=content),
164 | 'action',
165 | {'action': action}
166 | )
167 |
168 | if len(actions) > 1:
169 | content = re.sub(r"(?i)action\s*(?:\d*|)\s*(?::|)", "", llm_output)
170 | return (
171 | AIMessage(content=f"Action {step}: {content}"),
172 | 'action',
173 | {'action': ''} # triggers invalid action
174 | )
175 |
176 | # everthing else will be assumed to be a inner thought
177 | thought_pattern = r'(?i)thought\s*(?:\d+|)\s*(?::|)\s*(.*)'
178 | match = re.match(thought_pattern, llm_output)
179 | if match:
180 | # Extract the thought word and content
181 | thought_word = match.group(1)
182 | content = f"Thought {step}: {thought_word.rstrip(':')}"
183 | else:
184 | content = f"Thought {step}: {llm_output.rstrip(':')}"
185 | return (
186 | AIMessage(content=content) if ai_message else HumanMessage(content=content),
187 | 'thought',
188 | {}
189 | )
190 |
191 | def OBSERVATION_FORMATTER(observation: str, step: int, *args, **kwargs) -> Tuple[ChatMessage, str]:
192 | return HumanMessage(content=f"Observation {step}: " + observation.rstrip(':')), 'append'
193 |
194 | def STEP_IDENTIFIER(line: str) -> str:
195 | line = line.strip()
196 | pattern = re.compile(r'^(?i)action(?:\s+(\d+))?:')
197 | match = pattern.match(line)
198 | if match:
199 | return 'action'
200 | pattern = re.compile(r'^(?i)observation(?:\s+(\d+))?:')
201 | match = pattern.match(line)
202 | if match:
203 | return 'observation'
204 | return 'thought'
205 |
206 | def CYCLER(lines: str) -> List[str]:
207 | new_lines = []
208 | scratch_pad = ''
209 | for line in lines.split('\n'):
210 |
211 | # line is action
212 | pattern = re.compile(r'^(?i)action(?:\s+(\d+))?:')
213 | match = pattern.match(line)
214 | if match:
215 | if scratch_pad != '':
216 | new_lines.append(scratch_pad.strip())
217 | scratch_pad = ''
218 | new_lines.append(line)
219 | continue
220 |
221 | # line is thought
222 | pattern = re.compile(r'^(?i)thought(?:\s+(\d+))?:')
223 | match = pattern.match(line)
224 | if match:
225 | if scratch_pad != '':
226 | new_lines.append(scratch_pad.strip())
227 | scratch_pad = ''
228 | new_lines.append(line)
229 | continue
230 |
231 | # step is observation
232 | scratch_pad += line + '\n'
233 |
234 | # the rest of the scratch pad
235 | if scratch_pad != '':
236 | new_lines.append(scratch_pad.strip())
237 | return new_lines
238 |
239 | REFLECTION_PREFIX = '\nReflection:'
240 | def PREVIOUS_TRIALS_FORMATTER(reflections: List[str], include_prefix: bool = True) -> str:
241 | if reflections == []:
242 | return ''
243 | if include_prefix:
244 | memory_prefix = "You have attempted to solve the task before but failed. The following reflection(s) give a plan to avoid failing the task in the same way you did previously. Use them to improve your strategy of solving the task successfully."
245 | else:
246 | memory_prefix = ''
247 | memory_prefix += '\nReflections:'
248 | for reflection in reflections:
249 | memory_prefix += f"\n- {reflection.strip()}"
250 | return memory_prefix
251 |
252 | def STEP_STRIPPER(step: str, step_type: str):
253 | if step_type == 'observation':
254 | return re.sub(r'^(?i)observation(?:\s+(\d+))?:', 'Observation:', step)
255 | if step_type == 'action':
256 | return re.sub(r'^(?i)action(?:\s+(\d+))?:', 'Action:', step)
257 | if step_type == 'thought':
258 | return re.sub(r'^(?i)thought(?:\s+(\d+))?:', 'Thought:', step)
259 | return step
260 |
--------------------------------------------------------------------------------
/prompts/templates/human.py:
--------------------------------------------------------------------------------
1 | from langchain.prompts.chat import HumanMessagePromptTemplate, SystemMessagePromptTemplate
2 |
3 | human_instruction_fewshots_template = """{instruction}
4 |
5 | {fewshots}
6 |
7 | (END OF EXAMPLES)
8 | """
9 | human_instruction_fewshot_message_prompt = lambda message_type: \
10 | SystemMessagePromptTemplate.from_template(
11 | human_instruction_fewshots_template,
12 | ) if message_type == 'all_system' else \
13 | HumanMessagePromptTemplate.from_template(human_instruction_fewshots_template,)
14 |
15 | human_task_template = """Now it's your turn!
16 | {task}"""
17 | human_task_message_prompt = HumanMessagePromptTemplate.from_template(
18 | human_task_template,
19 | )
20 |
21 | FORMAT_RULES_OPERATION_TEMPLATE = """ :
22 |
23 | The available operations are: AGREE (if the existing rule is strongly relevant for the task), REMOVE (if one existing rule is contradictory or similar/duplicated to other existing rules), EDIT (if any existing rule is not general enough or can be enhanced, rewrite and improve it), ADD (add new rules that are very different from existing rules and relevant for other tasks). Each needs to CLOSELY follow their corresponding formatting below (any existing rule not edited, not agreed, nor removed is considered copied):
24 |
25 | AGREE :
26 | REMOVE :
27 | EDIT :
28 | ADD :
29 |
30 | Do not mention the trials in the rules because all the rules should be GENERALLY APPLICABLE. Each rule should be concise and easy to follow. Any operation can be used MULTIPLE times. Do at most 4 operations and each existing rule can only get a maximum of 1 operation. """
31 |
32 | CRITIQUE_SUMMARY_SUFFIX = dict(full = """Focus on REMOVE rules first, and stop ADD rule unless the new rule is VERY insightful and different from EXISTING RULES. Below are the operations you do to the above list of EXISTING RULES:
33 | """, not_full = """Below are the operations you do to the above list of EXISTING RULES:
34 | """)
35 |
36 | human_critique_existing_rules_all_success_template = """{instruction}
37 | Here are the trials:
38 | {success_history}
39 |
40 | Here are the EXISTING RULES:
41 | {existing_rules}
42 |
43 | By examining the successful trials, and the list of existing rules, you can perform the following operations: add, edit, remove, or agree so that the new list of rules are general and high level insights of the successful trials or proposed way of Thought so they can be used as helpful tips to different tasks in the future. Have an emphasis on tips that help the agent perform better Thought and Action. Follow the below format:
44 |
45 | """ + FORMAT_RULES_OPERATION_TEMPLATE
46 |
47 | human_all_success_existing_rules_critique = HumanMessagePromptTemplate.from_template(human_critique_existing_rules_all_success_template)
48 |
49 | human_critique_existing_rules_template = """{instruction}
50 | Here are the two previous trials to compare and critique:
51 | TRIAL TASK:
52 | {task}
53 |
54 | SUCCESSFUL TRIAL:
55 | {success_history}
56 |
57 | FAILED TRIAL:
58 | {fail_history}
59 |
60 | Here are the EXISTING RULES:
61 | {existing_rules}
62 |
63 | By examining and contrasting to the successful trial, and the list of existing rules, you can perform the following operations: add, edit, remove, or agree so that the new list of rules is GENERAL and HIGH LEVEL critiques of the failed trial or proposed way of Thought so they can be used to avoid similar failures when encountered with different questions in the future. Have an emphasis on critiquing how to perform better Thought and Action. Follow the below format:
64 |
65 | """ + FORMAT_RULES_OPERATION_TEMPLATE
66 |
67 | human_existing_rules_critique = HumanMessagePromptTemplate.from_template(human_critique_existing_rules_template)
68 |
69 | HUMAN_CRITIQUES = dict(
70 | compare_existing_rules=human_existing_rules_critique,
71 | all_success_existing_rules=human_all_success_existing_rules_critique,
72 | )
73 |
74 | RULE_TEMPLATE = dict(
75 | hotpotqa=HumanMessagePromptTemplate.from_template("""The following are some experience you gather on a similar task of question answering using Wikipedia API. Use these as references to help you perform this task:
76 | {rules}
77 | """),
78 | webshop=HumanMessagePromptTemplate.from_template("""
79 | The following are some experiences (in decreasing order of importance) you gathered on tasks of purchasing items requested by an user by interacting with an online website. Use these experiences as useful references to help you perform better on this task:
80 | {rules}
81 | """),
82 | alfworld=HumanMessagePromptTemplate.from_template("""The following are some experience you gather on a similar task of completing a household task by interacting in a household environment. Use these as references to help you perform this task:
83 | {rules}
84 | """),
85 | fever=HumanMessagePromptTemplate.from_template("""The following paragraph are insights a teacher agent provided to you. It is MANDATORY for you to follow these insights as CLOSELY as possible as they will help you perform the fact verification tasks efficiently:
86 |
87 | In order to successfully complete factual verification tasks, begin by clearly understanding the claim. Then formulate a search query that is precise and directly related to the claim. Include the main subjects or context from the claim in your query. If the initial search doesn't yield desired results, consider refining the query, using synonyms, or breaking down the claim into smaller parts. Always verify the information you obtain against the claim before drawing a conclusion. If multiple searches fail, consider changing the search strategy or looking for related information that might indirectly provide the necessary information. If all else fails, consider that the answer might be found in the observations already made. When you're ready to draw a conclusion, double-check it against the information obtained and ensure its accuracy. Lastly, always be prepared to exhaust all possible search queries related to the task at hand before concluding. Remember, the claim can either be supported, refuted, or there might not be enough information to draw a conclusion.{rules}
88 | """),
89 | )
90 |
--------------------------------------------------------------------------------
/prompts/templates/system.py:
--------------------------------------------------------------------------------
1 | from langchain.prompts.chat import HumanMessagePromptTemplate
2 |
3 |
4 | system_template = """You are {ai_name}. {instruction}"""
5 |
6 | system_message_prompt = HumanMessagePromptTemplate.from_template(system_template)
7 |
--------------------------------------------------------------------------------
/prompts/webshop.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Dict, Any, List
2 | import re
3 |
4 | from langchain.prompts.chat import HumanMessagePromptTemplate, SystemMessagePromptTemplate
5 | from langchain.schema import (
6 | AIMessage,
7 | HumanMessage,
8 | ChatMessage,
9 | )
10 |
11 | FEWSHOTS = ["""Instruction:
12 | i would like a 3 ounce bottle of bright citrus deodorant for sensitive skin, and price lower than 50.00 dollars
13 | [Search]
14 |
15 | Action: search[3 ounce bright citrus deodorant sensitive skin]
16 | Observation:
17 | [Back to Search]
18 | Page 1 (Total results: 50)
19 | [Next >]
20 | [B078GWRC1J]
21 | Bright Citrus Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce
22 | $10.99
23 | [B078GTKVXY]
24 | Ginger Fresh Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce
25 | $10.99
26 | [B08KBVJ4XN]
27 | Barrel and Oak - Aluminum-Free Deodorant, Deodorant for Men, Essential Oil-Based Scent, 24-Hour Odor Protection, Cedar & Patchouli Blend, Gentle on Sensitive Skin (Mountain Sage, 2.7 oz, 2-Pack)
28 | $15.95
29 |
30 | Action: think[B078GWRC1J and B078GTKVXY are bright citrus deodorant less then 50 dollars. I can check B078GWRC1J first.]
31 | Observation: OK.
32 |
33 | Action: click[B078GWRC1J]
34 | Observation:
35 | [Back to Search]
36 | [< Prev]
37 | scent [assorted scents][bright citrus][calming lavender][ginger fresh][simply non-scents]
38 | size [travel set (4-pack)][3 ounce (pack of 1)][3-ounce (2-pack)]
39 | Bright Citrus Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce
40 | Price: $10.99
41 | Rating: N.A.
42 | [Description]
43 | [Features]
44 | [Reviews]
45 | [Buy Now]
46 |
47 | Action: think[For 3 ounce bottle of bright citrus deodorant for sensitive skin, the item has options 'bright citrus' and '3 ounce (pack of 1)' and seems good to buy.]
48 | Observation: OK.
49 |
50 | Action: click[bright citrus]
51 | Observation: You have clicked bright citrus.
52 |
53 | Action: click[3 ounce (pack of 1)]
54 | Observation: You have clicked 3 ounce (pack of 1).
55 |
56 | Action: think[All options are properly selected, I can now buy it.]
57 | Observation: OK.
58 |
59 | Action: click[Buy Now]
60 | Observation: Your score (min 0.0, max 1.0): 1.0""",
61 | """Instruction:
62 | i am looking for long sleeve men t-shirt.and please also choose the black one, and price lower than 30.00 dollars
63 | [Search]
64 | Action: search[shirt man]
65 | Observation:
66 | [Back to Search]
67 | Page 1 (Total results: 50)
68 | [Next >]
69 | [B09NNMV9LN]
70 | JSPOYOU Mens Short Sleeve Crewneck 3D Graphic Tunic Shirts Big & Tall Tie Dye Summer Top Basic Designed Classic Cotton Shirt
71 | $3.99 to $10.99
72 | [B09Q67H373]
73 | JSPOYOU 2022 Newly T-Shirt for Mens Funny 3D Graphics Pattern Crewneck Short Sleeve Tees Big and Tall Summer Casual Comfy Top
74 | $1.99 to $8.99
75 | [B00I3O3BSS]
76 | Blue by Betsey Johnson Women's SB Ever Ballet Flat
77 | $89.99
78 |
79 | Action: think[It seems the first page doesn't mention long sleeves. I should check the next page]
80 | Observation: OK.
81 |
82 | Action: click[Next >]
83 | Observation:
84 | [Back to Search]
85 | Page 2 (Total results: 50)
86 | [< Prev]
87 | [Next >]
88 | [B09Q8RD8YN]
89 | Bungo Stray Anime Dogs Anime Character, Long Sleeve, Sweatshirt, Hoodie, T shirt
90 | $19.99
91 | [B09KLQLLT2]
92 | Long Sleeve Superhero T Shirt Tank Top Mens Compression Shirt Men Workout Fitness Gym Shirt
93 | $19.99
94 | [B09R9YCM6R]
95 | XXBR Summer T-shirts for Mens, Soldier Short Sleeve 3D Street Vintage Printed Shirt Slim Fit Muscle Casual Tee Tops
96 | $8.98 to $11.99
97 |
98 | Action: think[It looks like we have long sleeves tshirt but not black. I will not bother looking at the others pages and will focus on rephrasing the query. I first need to click 'Back to Search' button.]
99 | Observation: OK.
100 |
101 | Action: click[Back to Search]
102 | Observation:
103 | i am looking for long sleeve men t-shirt and please also choose the black one, and price lower than 40.00 dollars
104 | [Search]
105 |
106 | Action: search[long sleeves t-shirt man black]
107 | Observation:
108 | [Back to Search]
109 | Page 1 (Total results: 50)
110 | [Next >]
111 | [B09QGK5XHZ]
112 | WENKOMG1 Men's Long Sleeve Undershirt with Mask Turtleneck Hooded T-Shirt Solid Color Workout Tops Zipper Side Slit Shirts Slim Fit Sweatshirt Spring/Summer Tee Shirts(Gray,)
113 | $8.39
114 | [B09QQP3356]
115 | HAUKLIE Men's Sports Waffle Ribbed Polo Shirts Summer Long Sleeve Cotton Muscle Quarter-Zip Henley T-Shirt Tunics Tops
116 | $10.99
117 | [B09Q67H373]
118 | JSPOYOU 2022 Newly T-Shirt for Mens Funny 3D Graphics Pattern Crewneck Short Sleeve Tees Big and Tall Summer Casual Comfy Top
119 | $1.99 to $8.99
120 |
121 | Action: think[B09QQP3356 seems to be a good match now. It is under 40.00 dollars and has long sleeves. I will check it first.]
122 | Observation: OK.
123 |
124 | Action: click[B09QQP3356]
125 | Observation:
126 | [Back to Search]
127 | [< Prev]
128 | color [black][gray][light blue]
129 | size [small][medium][large][x-large][xx-large][3x-large]
130 | HAUKLIE Men's Sports Waffle Ribbed Polo Shirts Summer Long Sleeve Cotton Muscle Quarter-Zip Henley T-Shirt Tunics Tops
131 | Price: $10.99
132 | Rating: N.A.
133 | [Description]
134 | [Features]
135 | [Reviews]
136 | [Buy Now]
137 |
138 | Action: think[There is 2 options, color and size. I need to select them one by one according to the instruction. First, I will choose the black color.]
139 | Observation: OK.
140 |
141 | Action: click[black]
142 | Observation: You have clicked black.
143 |
144 | Action: think[The item has also a size option, I must select one. Since no information is given about the size, I will the first one.]
145 | Observation: OK.
146 |
147 | Action: click[small]
148 | Observation: You have clicked small.
149 |
150 | Action: think[All options are proprely selected, I can now buy it.]
151 | Observation: OK.
152 |
153 | Action: click[Buy Now]
154 | Observation: Your score (min 0.0, max 1.0): 1.0"""
155 | ]
156 |
157 | NEW_REFLECTION_FEWSHOTS = [
158 | """Previous Trial:
159 | Instruction:
160 | i am looking for dairy free and apple variety pack of chips, and price lower than 30.00 dollars
161 | [Search]
162 |
163 | Action: search[dairy free and apple variety pack of chips]
164 | Observation:
165 | [Back to Search]
166 | Page 1 (Total results: 50)
167 | [Next >]
168 | [B07HRFJWP8]
169 | Enjoy Life Foods Soft Baked Ovals, Breakfast Bars, Nut Free Bars, Soy Free, Dairy Free, Non GMO, Gluten Free, Vegan, Variety Pack, 4 Boxes (20 Bars Total)
170 | $100.0
171 | [B01KMHY5PG]
172 | Enjoy Life Soft Baked Chewy Bars, Variety Pack, Nut Free Bars, Soy Free, Dairy Free, Gluten Free, 6 Boxes (30 Total Bars)
173 | $21.49
174 | [B008D2X8C4]
175 | Enjoy Life Lentil Chips Variety Pack, Dairy Free Chips, Soy Free, Nut Free, Non GMO, Vegan, Gluten Free, 24 Bags (0.8 oz)
176 | $100.0
177 |
178 | Action: think[The first two items are not what I want. The third item is the dairy free and apple variety pack of chips I want.]
179 | Observation: OK.
180 |
181 | Action: click[B008D2X8C4]
182 | Observation:
183 | [Back to Search]
184 | [< Prev]
185 | flavor name [dill and sour cream][garlic & parmesan][light sea salt][margherita pizza][thai chili lime][variety pack]
186 | size [0.8 ounce (pack of 24)][4 ounce (pack of 12)]
187 | Enjoy Life Lentil Chips Variety Pack, Dairy Free Chips, Soy Free, Nut Free, Non GMO, Vegan, Gluten Free, 24 Bags (0.8 oz)
188 | Price: $100.0
189 | Rating: N.A.
190 | [Description]
191 | [Features]
192 | [Reviews]
193 | [Buy Now]
194 |
195 | Action: think[The item has 2 options, flavor and size. I need to select them one by one according to the instruction. First, I will choose the variety pack.]
196 | Observation: OK.
197 |
198 | Action: click[variety pack]
199 | Observation: You have clicked variety pack.
200 |
201 | Action: think[Then, since there is no required size, I will choose the first one.]
202 | Observation: OK.
203 |
204 | Action: click[0.8 ounce (pack of 24)]
205 | Observation: You have clicked 0.8 ounce (pack of 24).
206 |
207 | Action: click[Buy Now]
208 | Observation: Wrong item! TASK FAILED
209 |
210 | Next plan: In this attempt, I was unsuccessful. I accidentally bought a product that was $100, which is more than my budget of $30. Either way, the initial search results were not good. Next time, I will do search["variety pack of chips"] and then check if the results meet the dairy free and the $30 budget constraints. I will continue to refine my searches so that I can find more products.""",
211 | """Previous Trial:
212 | Instruction:
213 | i need gluten free vegetarian smoked peppered bacon - 4 ounce (pack of 2), and price lower than 40.00 dollars
214 | [Search]
215 |
216 | Action: search[gluten free vegetarian smoked peppered bacon 4 ounce pack of 2]
217 | Observation:
218 | [Back to Search]
219 | Page 1 (Total results: 50)
220 | [Next >]
221 | [B06Y96MXJV]
222 | Smoked Bacon Sea Salt 3-Pack: Smoked Bacon Chipotle, Smoked Bacon and Onion, Smoked Peppered Bacon - All-Natural Sea Salts Slowly Smoked Over Alderwood - No Gluten, No MSG, Non-GMO (12 total oz.)
223 | $29.99
224 | [B06Y96N1KG]
225 | Spicy Hot Pepper Sea Salt 3-Pack: Ghost Pepper, Jalapeno, Habanero - All Natural, Delicious Blends of Peppers & Sea Salt - Gluten-Free, Kosher, No MSG, Non-GMO (12 total oz.)
226 | $29.99
227 | [B07GJTKYJQ]
228 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces)
229 | $42.99
230 |
231 | Action: think[The first two items are not what I want. The third item is not what I want, but it is the closest one.]
232 | Observation: OK.
233 |
234 | Action: think[The third item is not what I want, but it is the closest one.]
235 | Observation: OK.
236 |
237 | Action: click[B07GJTKYJQ]
238 | Observation:
239 | [Back to Search]
240 | [< Prev]
241 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces)
242 | Price: $42.99
243 | Rating: N.A.
244 | [Description]
245 | [Features]
246 | [Reviews]
247 | [Buy Now]
248 |
249 | Action: think[The price is too high and it is not what I want.]
250 | Observation: OK.
251 |
252 | Action: click[< Prev]
253 | Observation:
254 | [Back to Search]
255 | Page 1 (Total results: 50)
256 | [Next >]
257 | [B06Y96MXJV]
258 | Smoked Bacon Sea Salt 3-Pack: Smoked Bacon Chipotle, Smoked Bacon and Onion, Smoked Peppered Bacon - All-Natural Sea Salts Slowly Smoked Over Alderwood - No Gluten, No MSG, Non-GMO (12 total oz.)
259 | $29.99
260 | [B06Y96N1KG]
261 | Spicy Hot Pepper Sea Salt 3-Pack: Ghost Pepper, Jalapeno, Habanero - All Natural, Delicious Blends of Peppers & Sea Salt - Gluten-Free, Kosher, No MSG, Non-GMO (12 total oz.)
262 | $29.99
263 | [B07GJTKYJQ]
264 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces)
265 | $42.99
266 |
267 | Action: think[The first two items are not what I want.]
268 | Observation: OK.
269 |
270 | Action: think[I need to refine my search.]
271 | Observation: OK.
272 |
273 | Action: click[< Back to Search]
274 | Observation: Invalid action!
275 |
276 | Action: click[Back to Search]
277 | Observation: Ran out of steps! TASK FAILED
278 |
279 | Next plan: In this attempt, I was unsuccessful. The initial search results were not the products that I was looking for. Next time, I will do search["vegetarian bacon"] for a broader search. I will check that the new results will fulfill the gluten free and 4 ounce pack of 2 constraints. I will continue to refine my searches and think less so that I can find more products."""
280 | ]
281 |
282 | ACT_REFLECTION_FEWSHOTS = [
283 | """Previous Trial:
284 | Instruction:
285 | i am looking for dairy free and apple variety pack of chips, and price lower than 30.00 dollars
286 | [Search]
287 |
288 | Action: search[dairy free and apple variety pack of chips]
289 | Observation:
290 | [Back to Search]
291 | Page 1 (Total results: 50)
292 | [Next >]
293 | [B07HRFJWP8]
294 | Enjoy Life Foods Soft Baked Ovals, Breakfast Bars, Nut Free Bars, Soy Free, Dairy Free, Non GMO, Gluten Free, Vegan, Variety Pack, 4 Boxes (20 Bars Total)
295 | $100.0
296 | [B01KMHY5PG]
297 | Enjoy Life Soft Baked Chewy Bars, Variety Pack, Nut Free Bars, Soy Free, Dairy Free, Gluten Free, 6 Boxes (30 Total Bars)
298 | $21.49
299 | [B008D2X8C4]
300 | Enjoy Life Lentil Chips Variety Pack, Dairy Free Chips, Soy Free, Nut Free, Non GMO, Vegan, Gluten Free, 24 Bags (0.8 oz)
301 | $100.0
302 |
303 | Action: click[B008D2X8C4]
304 | Observation:
305 | [Back to Search]
306 | [< Prev]
307 | flavor name [dill and sour cream][garlic & parmesan][light sea salt][margherita pizza][thai chili lime][variety pack]
308 | size [0.8 ounce (pack of 24)][4 ounce (pack of 12)]
309 | Price: $100.0
310 | Rating: N.A.
311 | [Description]
312 | [Features]
313 | [Reviews]
314 | [Buy Now]
315 |
316 | Action: click[variety pack]
317 | Observation: You have clicked variety pack.
318 |
319 | Action: click[0.8 ounce (pack of 24)]
320 | Observation: You have clicked 0.8 ounce (pack of 24).
321 |
322 | Action: click[Buy Now]
323 | Observation: Wrong item! TASK FAILED
324 |
325 | Next plan: In this attempt, I was unsuccessful. I accidentally bought a product that was $100, which is more than my budget of $30. Either way, the initial search results were not good. Next time, I will do search["variety pack of chips"] and then check if the results meet the dairy free and the $30 budget constraints. I will continue to refine my searches so that I can find more products.""",
326 | """Previous Trial:
327 | Instruction:
328 | i need gluten free vegetarian smoked peppered bacon - 4 ounce (pack of 2), and price lower than 40.00 dollars
329 | [Search]
330 |
331 | Action: search[gluten free vegetarian smoked peppered bacon 4 ounce pack of 2]
332 | Observation:
333 | [Back to Search]
334 | Page 1 (Total results: 50)
335 | [Next >]
336 | [B06Y96MXJV]
337 | Smoked Bacon Sea Salt 3-Pack: Smoked Bacon Chipotle, Smoked Bacon and Onion, Smoked Peppered Bacon - All-Natural Sea Salts Slowly Smoked Over Alderwood - No Gluten, No MSG, Non-GMO (12 total oz.)
338 | $29.99
339 | [B06Y96N1KG]
340 | Spicy Hot Pepper Sea Salt 3-Pack: Ghost Pepper, Jalapeno, Habanero - All Natural, Delicious Blends of Peppers & Sea Salt - Gluten-Free, Kosher, No MSG, Non-GMO (12 total oz.)
341 | $29.99
342 | [B07GJTKYJQ]
343 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces)
344 | $42.99
345 |
346 | Action: click[B07GJTKYJQ]
347 | Observation:
348 | [Back to Search]
349 | [< Prev]
350 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces)
351 | Price: $42.99
352 | Rating: N.A.
353 | [Description]
354 | [Features]
355 | [Reviews]
356 | [Buy Now]
357 |
358 | Action: click[< Prev]
359 | Observation:
360 | [Back to Search]
361 | Page 1 (Total results: 50)
362 | [Next >]
363 | [B06Y96MXJV]
364 | Smoked Bacon Sea Salt 3-Pack: Smoked Bacon Chipotle, Smoked Bacon and Onion, Smoked Peppered Bacon - All-Natural Sea Salts Slowly Smoked Over Alderwood - No Gluten, No MSG, Non-GMO (12 total oz.)
365 | $29.99
366 | [B06Y96N1KG]
367 | Spicy Hot Pepper Sea Salt 3-Pack: Ghost Pepper, Jalapeno, Habanero - All Natural, Delicious Blends of Peppers & Sea Salt - Gluten-Free, Kosher, No MSG, Non-GMO (12 total oz.)
368 | $29.99
369 | [B07GJTKYJQ]
370 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces)
371 | $42.99
372 |
373 | Action: click[< Back to Search]
374 | Observation: Invalid action!
375 |
376 | Action: click[Back to Search]
377 | Observation: Ran out of steps! TASK FAILED
378 |
379 | Next plan: In this attempt, I was unsuccessful. The initial search results were not the products that I was looking for. Next time, I will do search["vegetarian bacon"] for a broader search. I will check that the new results will fulfill the gluten free and 4 ounce pack of 2 constraints. I will continue to refine my searches and think less so that I can find more products."""
380 | ]
381 |
382 | REFLECTION_FEWSHOTS = [
383 | """Previous Trial:
384 | Instruction:
385 | i am looking for dairy free and apple variety pack of chips, and price lower than 30.00 dollars
386 | [Search]
387 |
388 | Action: search[dairy free and apple variety pack of chips]
389 | Observation:
390 | [Back to Search]
391 | Page 1 (Total results: 50)
392 | [Next >]
393 | [B07HRFJWP8]
394 | Enjoy Life Foods Soft Baked Ovals, Breakfast Bars, Nut Free Bars, Soy Free, Dairy Free, Non GMO, Gluten Free, Vegan, Variety Pack, 4 Boxes (20 Bars Total)
395 | $100.0
396 | [B01KMHY5PG]
397 | Enjoy Life Soft Baked Chewy Bars, Variety Pack, Nut Free Bars, Soy Free, Dairy Free, Gluten Free, 6 Boxes (30 Total Bars)
398 | $21.49
399 | [B008D2X8C4]
400 | Enjoy Life Lentil Chips Variety Pack, Dairy Free Chips, Soy Free, Nut Free, Non GMO, Vegan, Gluten Free, 24 Bags (0.8 oz)
401 | $100.0
402 |
403 | Action: think[The first two items are not what I want. The third item is the dairy free and apple variety pack of chips I want.]
404 | Observation: OK.
405 |
406 | Action: click[B008D2X8C4]
407 | Observation:
408 | [Back to Search]
409 | [< Prev]
410 | flavor name [dill and sour cream][garlic & parmesan][light sea salt][margherita pizza][thai chili lime][variety pack]
411 | size [0.8 ounce (pack of 24)][4 ounce (pack of 12)]
412 | Price: $100.0
413 | Rating: N.A.
414 | [Description]
415 | [Features]
416 | [Reviews]
417 | [Buy Now]
418 |
419 | Action: think[The item has options 'variety pack' and '0.8 ounce (pack of 24)' and seems good to buy.]
420 | Observation: OK.
421 |
422 | Action: click[variety pack]
423 | Observation: You have clicked variety pack.
424 |
425 | Action: click[0.8 ounce (pack of 24)]
426 | Observation: You have clicked 0.8 ounce (pack of 24).
427 |
428 | Action: click[Buy Now]
429 | Observation: Wrong item! TASK FAILED
430 |
431 | Next plan: In this attempt, I was unsuccessful. I accidentally bought a product that was $100, which is more than my budget of $30. Either way, the initial search results were not good. Next time, I will do search["variety pack of chips"] and then check if the results meet the dairy free and the $30 budget constraints. I will continue to refine my searches so that I can find more products.""",
432 | """Previous Trial:
433 | Instruction:
434 | i need gluten free vegetarian smoked peppered bacon - 4 ounce (pack of 2), and price lower than 40.00 dollars
435 | [Search]
436 |
437 | Action: search[gluten free vegetarian smoked peppered bacon 4 ounce pack of 2]
438 | Observation:
439 | [Back to Search]
440 | Page 1 (Total results: 50)
441 | [Next >]
442 | [B06Y96MXJV]
443 | Smoked Bacon Sea Salt 3-Pack: Smoked Bacon Chipotle, Smoked Bacon and Onion, Smoked Peppered Bacon - All-Natural Sea Salts Slowly Smoked Over Alderwood - No Gluten, No MSG, Non-GMO (12 total oz.)
444 | $29.99
445 | [B06Y96N1KG]
446 | Spicy Hot Pepper Sea Salt 3-Pack: Ghost Pepper, Jalapeno, Habanero - All Natural, Delicious Blends of Peppers & Sea Salt - Gluten-Free, Kosher, No MSG, Non-GMO (12 total oz.)
447 | $29.99
448 | [B07GJTKYJQ]
449 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces)
450 | $42.99
451 |
452 | Action: think[The first two items are not what I want. The third item is not what I want, but it is the closest one.]
453 | Observation: OK.
454 |
455 | Action: think[The third item is not what I want, but it is the closest one.]
456 | Observation: OK.
457 |
458 | Action: click[B07GJTKYJQ]
459 | Observation:
460 | [Back to Search]
461 | [< Prev]
462 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces)
463 | Price: $42.99
464 | Rating: N.A.
465 | [Description]
466 | [Features]
467 | [Reviews]
468 | [Buy Now]
469 |
470 | Action: think[The price is too high and it is not what I want.]
471 | Observation: OK.
472 |
473 | Action: click[< Prev]
474 | Observation:
475 | [Back to Search]
476 | Page 1 (Total results: 50)
477 | [Next >]
478 | [B06Y96MXJV]
479 | Smoked Bacon Sea Salt 3-Pack: Smoked Bacon Chipotle, Smoked Bacon and Onion, Smoked Peppered Bacon - All-Natural Sea Salts Slowly Smoked Over Alderwood - No Gluten, No MSG, Non-GMO (12 total oz.)
480 | $29.99
481 | [B06Y96N1KG]
482 | Spicy Hot Pepper Sea Salt 3-Pack: Ghost Pepper, Jalapeno, Habanero - All Natural, Delicious Blends of Peppers & Sea Salt - Gluten-Free, Kosher, No MSG, Non-GMO (12 total oz.)
483 | $29.99
484 | [B07GJTKYJQ]
485 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces)
486 | $42.99
487 |
488 | Action: think[The first two items are not what I want.]
489 | Observation: OK.
490 |
491 | Action: think[I need to refine my search.]
492 | Observation: OK.
493 |
494 | Action: click[< Back to Search]
495 | Observation: Invalid action!
496 |
497 | Action: click[Back to Search]
498 | Observation: Ran out of steps! TASK FAILED
499 |
500 | Next plan: In this attempt, I was unsuccessful. The initial search results were not the products that I was looking for. Next time, I will do search["vegetarian bacon"] for a broader search. I will check that the new results will fulfill the gluten free and 4 ounce pack of 2 constraints. I will continue to refine my searches and think less so that I can find more products."""
501 | ]
502 |
503 | SYSTEM_INSTRUCTION = """Follow the syntax of the examples closely when taking actions."""
504 |
505 | human_instruction_template = """{instruction}You may take maximum of {max_steps} steps.
506 | Here are two examples:"""
507 |
508 | HUMAN_INSTRUCTION = HumanMessagePromptTemplate.from_template(human_instruction_template)
509 |
510 | human_instruction_reflection_template = """Here are two examples:"""
511 | HUMAN_REFLECTION_INSTRUCTION = HumanMessagePromptTemplate.from_template(human_instruction_reflection_template)
512 |
513 | SYSTEM_CRITIQUE_EXISTING_RULES_INSTRUCTION = """You will be given two previous task trials in which you could interact with an online store website in order to buy the requested item: one successful and one unsuccessful trial. You failed the trial either because you reached the maximum allowed number of steps or you bought a wrong item."""
514 | SYSTEM_CRITIQUE_ALL_SUCCESS_EXISTING_RULES_INSTRUCTION = """You will be given successful task trials in which you could interact with an online store website in order to buy the requested items."""
515 |
516 | def LLM_PARSER(llm_output, step: int, ai_message: bool) -> Tuple[ChatMessage, str, Dict[str, Any]]:
517 | pattern = r'(?i)action\s*\d*\s*:\s*'
518 | formatted_output = re.sub(pattern, '', llm_output)
519 | if '[' not in formatted_output:
520 | formatted_output = f'think[{formatted_output.lstrip("Observation:").strip()}]'
521 | elif '[' in formatted_output and not formatted_output.endswith(']'):
522 | formatted_output += ']'
523 | return (
524 | AIMessage(content=f'Action: {formatted_output}') if ai_message else HumanMessage(content=f'Action: {formatted_output}'),
525 | 'action',
526 | {'action': formatted_output}
527 | )
528 |
529 | def OBSERVATION_FORMATTER(observation: str, *args, **kwargs) -> Tuple[ChatMessage, str]:
530 | return HumanMessage(content=f"Observation: " + observation + '\n'), 'append'
531 |
532 | def STEP_IDENTIFIER(line: str) -> str:
533 | line = line.strip()
534 | if line.lower().startswith('action: think['):
535 | return 'thought'
536 | if line.lower().startswith('action'):
537 | return 'action'
538 | return 'observation'
539 |
540 | # Should return [action1, observation1, action2, observation2, ...]
541 | def CYCLER(lines: str) -> List[str]:
542 | listed_lines = []
543 | listed_steps = lines.split('\n\n')
544 | for step in listed_steps:
545 | # assuming action only takes one line
546 | listed_lines.extend(step.split('\n', 1))
547 | return listed_lines
548 |
549 | REFLECTION_PREFIX = '\nNext plan:'
550 | def PREVIOUS_TRIALS_FORMATTER(reflections: List[str], *args, **kwargs) -> str:
551 | if reflections == []:
552 | return ''
553 | memory_prefix = 'Your memory for the task below:'
554 | for i, reflection in enumerate(reflections):
555 | memory_prefix += f"\nTrial {i}:\n{reflection.strip()}"
556 | return memory_prefix
557 |
558 | def STEP_STRIPPER(step: str, step_type: str):
559 | if step_type != 'task':
560 | return step
561 | return step.strip('Instruction:\n').strip()
562 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | gym==0.26.2
2 | hydra-core==1.3.2
3 | ipykernel==6.19.2
4 | joblib==1.2.0
5 | langchain==0.0.181
6 | omegaconf==2.3.0
7 | openai==0.27.7
8 | pandas==1.5.3
9 | python-dotenv==1.0.0
10 | tiktoken==0.4.0
11 | wikipedia==1.4.0
12 | torch # ==1.9.0+cu111
13 | transformers==4.30.2
14 | einops==0.6.1
15 | sentence_transformers
16 | matplotlib
17 | # gdown # for downloading files from google drive (webshop)
--------------------------------------------------------------------------------
/tools/visualize_logs.py:
--------------------------------------------------------------------------------
1 | from omegaconf import DictConfig
2 | import hydra
3 |
4 | import sys
5 | import os
6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7 |
8 | from utils import load_trajectories_log, split_logs_by_task, plot_trial_stats, alfworld_results_per_env_name_log, get_webshop_mean_score, get_webshop_mean_scores, mode_results
9 |
10 | @hydra.main(version_base=None, config_path="../configs", config_name="visualize_logs")
11 | def main(cfg: DictConfig) -> None:
12 | out = load_trajectories_log(path=cfg.log_path, run_name=cfg.run_name, load_dict=False)
13 | log = out['log']
14 |
15 | parsed_result = split_logs_by_task(text=log, num_tasks=cfg.num_tasks)
16 | assert len(parsed_result) == cfg.num_tasks
17 |
18 | res = plot_trial_stats(parsed_result=parsed_result, benchmark=cfg.benchmark.name, max_trials=4, save_path=f"{cfg.log_path}/{cfg.run_name}_logs_stats.png" if cfg.save_fig else None)
19 | if 'eval' in cfg.run_name:
20 | if cfg.benchmark.name == 'alfworld':
21 | print(alfworld_results_per_env_name_log(log, cfg.num_tasks, 1))
22 | elif cfg.benchmark.name == 'webshop':
23 | print(get_webshop_mean_score(log, cfg.num_tasks, 1))
24 | res = {k: v[-1] for k, v in res.items()}
25 | else:
26 | if cfg.benchmark.name == 'alfworld':
27 | print(alfworld_results_per_env_name_log(log, cfg.num_tasks, cfg.agent.max_reflection_depth+1))
28 | elif cfg.benchmark.name == 'webshop':
29 | print(get_webshop_mean_scores(log, cfg.num_tasks, cfg.agent.max_reflection_depth+1))
30 |
31 | print(res)
32 |
33 | ############################################
34 | ### MODE OPTIONS SELECTION ###
35 | ############################################
36 | # [react, reflection]
37 | # _[sum, mean, list]
38 | # _[token, count]
39 | # _[thought, action, observation, invalid]
40 | # _[traj, step]
41 | ############################################
42 | modes = [
43 | 'react_mean_count_of_thought_per_traj',
44 | 'react_mean_count_of_action_per_traj',
45 | 'react_mean_count_of_observation_per_traj',
46 | 'react_mean_tokens_per_traj',
47 | 'react_mean_count_of_invalid_per_traj',
48 | ]
49 | if 'eval' not in cfg.run_name:
50 | modes += [
51 | 'reflection1_mean_tokens_per_traj',
52 | 'reflection2_mean_tokens_per_traj',
53 | 'reflection3_mean_tokens_per_traj',
54 | 'reflection1_mean_count_of_invalid_per_traj',
55 | 'reflection2_mean_count_of_invalid_per_traj',
56 | 'reflection3_mean_count_of_invalid_per_traj',]
57 | for mode in modes:
58 | print(f"{mode}: {mode_results(cfg.benchmark.name, log, cfg.num_tasks, mode)}")
59 |
60 |
61 | if __name__ == '__main__':
62 | main()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import getpass
2 | import hydra
3 | from omegaconf import DictConfig
4 | from pathlib import Path
5 | import os
6 | from copy import deepcopy
7 | from functools import partial
8 | import dotenv
9 | dotenv.load_dotenv()
10 |
11 | from agent import AGENT
12 | from prompts.templates.system import system_message_prompt
13 | from prompts.templates.human import HUMAN_CRITIQUES
14 | from prompts import (
15 | SYSTEM_INSTRUCTION,
16 | HUMAN_INSTRUCTION,
17 | FEWSHOTS,
18 | REFLECTION_FEWSHOTS,
19 | HUMAN_REFLECTION_INSTRUCTION,
20 | SYSTEM_REFLECTION_INSTRUCTION,
21 | SYSTEM_CRITIQUE_INSTRUCTION,
22 | RULE_TEMPLATE,
23 | LLM_PARSER,
24 | OBSERVATION_FORMATTER,
25 | STEP_IDENTIFIER,
26 | CYCLER,
27 | STEP_CYCLER,
28 | REFLECTION_PREFIX,
29 | PREVIOUS_TRIALS_FORMATTER,
30 | STEP_STRIPPER,
31 | CRITIQUE_SUMMARY_SUFFIX,
32 | )
33 | from envs import ENVS, INIT_TASKS_FN
34 | from memory import (
35 | EMBEDDERS,
36 | RETRIEVERS,
37 | )
38 | from models import LLM_CLS
39 | from utils import save_trajectories_log, load_trajectories_log, plot_trial_stats, split_logs_by_task, alfworld_results_per_env_name, get_webshop_mean_scores, get_fewshot_max_tokens
40 | from agent.reflect import Count
41 |
42 | @hydra.main(version_base=None, config_path="configs", config_name="train")
43 | def main(cfg : DictConfig) -> None:
44 | if cfg.testing:
45 | openai_api_key = 'NO_KEY_FOR_TESTING'
46 | else:
47 | openai_api_key = os.environ['OPENAI_API_KEY'] if 'OPENAI_API_KEY' in os.environ else getpass.getpass("Enter or paste your OpenAI API Key: ")
48 | LOG_PATH = Path('/'.join([cfg.log_dir, cfg.benchmark.name, cfg.agent_type]))
49 | LOG_PATH.mkdir(parents=True, exist_ok=True)
50 |
51 | # Load trajectory checkpoint, init as empty if not exist
52 | if cfg.resume:
53 | out = load_trajectories_log(
54 | LOG_PATH,
55 | run_name=cfg.run_name,
56 | load_true_log=True)
57 | else:
58 | # Overwriting confirmation
59 | if os.path.exists(f"{LOG_PATH}/{cfg.run_name}.pkl") and cfg.run_name != 'test':
60 | while True:
61 | res = input(f"Are you sure to overwrite '{cfg.run_name}'? (Y/N)\n").lower()
62 | if res == 'n':
63 | exit(0)
64 | elif res == 'y':
65 | break
66 | out = {'log': '', 'dicts': [], 'true_log': f'{str(cfg)}'}
67 | log, dicts, true_log = out['log'], out['dicts'], out['true_log']
68 |
69 | react_agent = AGENT[cfg.agent_type](
70 | name=cfg.ai_name,
71 | system_instruction=SYSTEM_INSTRUCTION[cfg.benchmark.name],
72 | human_instruction=HUMAN_INSTRUCTION[cfg.benchmark.name],
73 | tasks=INIT_TASKS_FN[cfg.benchmark.name](cfg),
74 | fewshots=FEWSHOTS[cfg.benchmark.name],
75 | system_prompt=system_message_prompt,
76 | env=ENVS[cfg.benchmark.name],
77 | max_steps=cfg.benchmark.max_steps,
78 | openai_api_key=openai_api_key,
79 | llm=cfg.agent.llm,
80 | llm_builder=LLM_CLS,
81 | reflection_fewshots=REFLECTION_FEWSHOTS[cfg.benchmark.name],
82 | reflection_task_prompt=HUMAN_REFLECTION_INSTRUCTION[cfg.benchmark.name],
83 | reflection_system_instruction=SYSTEM_REFLECTION_INSTRUCTION[cfg.benchmark.name],
84 | reflection_system_prompt=SYSTEM_INSTRUCTION[cfg.benchmark.name],
85 | max_relfection_depth=cfg.agent.max_reflection_depth if 'max_reflection_depth' in cfg.agent.keys() else 0,
86 | system_critique_instructions=SYSTEM_CRITIQUE_INSTRUCTION[cfg.benchmark.name],
87 | human_critiques=HUMAN_CRITIQUES,
88 | max_num_rules=cfg.agent.max_num_rules if 'max_num_rules' in cfg.agent.keys() else 0,
89 | rule_template=RULE_TEMPLATE[cfg.benchmark.name],
90 | truncate_strategy=cfg.agent.truncate_strategy if 'truncate_strategy' in cfg.agent.keys() else None,
91 | llm_parser=LLM_PARSER[cfg.benchmark.name],
92 | observation_formatter=OBSERVATION_FORMATTER[cfg.benchmark.name],
93 | embedder=EMBEDDERS(cfg.agent.retrieval_kwargs.embedder_type),
94 | embedder_path=cfg.agent.retrieval_kwargs.embedder_path,
95 | step_stripper=STEP_STRIPPER[cfg.benchmark.name],
96 | retriever_cls=RETRIEVERS(cfg.agent.retrieval_kwargs.retriever_type),
97 | message_splitter=CYCLER[cfg.benchmark.name],
98 | identifier=STEP_IDENTIFIER[cfg.benchmark.name],
99 | message_step_splitter=partial(STEP_CYCLER, benchmark=cfg.benchmark.name),
100 | reflection_prefix=REFLECTION_PREFIX[cfg.benchmark.name],
101 | previous_trials_formatter=PREVIOUS_TRIALS_FORMATTER[cfg.benchmark.name],
102 | success_critique_num=cfg.agent.success_critique_num,
103 | fewshot_strategy=cfg.agent.fewshot_strategy,
104 | critique_truncate_strategy=cfg.agent.critique_truncate_strategy,
105 | critique_summary_suffix=CRITIQUE_SUMMARY_SUFFIX,
106 | testing=cfg.testing,
107 | task_idx=dicts[-1]['task_idx'] if len(dicts) > 0 else 0,
108 | benchmark_name=cfg.benchmark.name,
109 | reranker=cfg.agent.retrieval_kwargs.reranker,
110 | buffer_retrieve_ratio=cfg.agent.retrieval_kwargs.buffer_retrieve_ratio,
111 | max_fewshot_tokens=get_fewshot_max_tokens(cfg.benchmark.name) if cfg.agent.retrieval_kwargs.max_fewshot_tokens == 'auto' else cfg.agent.retrieval_kwargs.max_fewshot_tokens,
112 | )
113 | if len(dicts) > 0:
114 | react_agent.load_checkpoint(loaded_dict=dicts[-1], no_load_list=['testing', 'max_relfection_depth', 'fewshot_strategy', 'max_fewshot_tokens'])
115 | if 'eval_idx_list' in dicts[-1]:
116 | react_agent.eval_idx_list = dicts[-1]['eval_idx_list']
117 |
118 | print(f"""*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
119 |
120 | You are using the following language model: {react_agent.llm.model_name}
121 |
122 | *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*""")
123 |
124 | while react_agent.job_not_done():
125 | prefix = f"#######################################\nTASK {react_agent.task_idx}"
126 | if cfg.agent_type in ['reflection', 'expel']:
127 | prefix += f' Reflection {react_agent.reflection_counter.count}\n\n'
128 | else:
129 | prefix += '\n\n'
130 | print(prefix + react_agent.remove_task_suffix(react_agent.task)) # remove_task_suffix used for alfworld
131 |
132 | react_agent.run(mode='train')
133 |
134 | #############################################
135 | ### Update & Save trajectory logs + dicts ###
136 | #############################################
137 | react_agent.update_stats()
138 | log += prefix + react_agent.log_history() + '\n\n'
139 | true_log += prefix + react_agent.log_history(include_all=True) + '\n\n'
140 |
141 | # next task
142 | react_agent.next_task()
143 |
144 | dicts.append({k: deepcopy(v) for k, v in react_agent.__dict__.items() if type(v) in [list, set, str, bool, int, dict, Count] and k not in ['openai_api_key', 'llm']}) # not saving complicated objects
145 |
146 | save_trajectories_log(
147 | LOG_PATH, log, dicts, true_log,
148 | run_name=cfg.run_name
149 | )
150 | #############################################
151 |
152 | ######################################
153 | ### Final Log & Save stats + PRINT ###
154 | ######################################
155 | success, fail, halted = react_agent.get_stats()
156 | log += f"########################################\nEND TRIAL\nTrial summary: Success: {success}/{success + fail + halted}, Fail: {fail}/{success + fail + halted}, Halted: {halted}/{success + fail + halted}"
157 | true_log += f"########################################\nEND TRIAL\nTrial summary: Success: {success}/{success + fail + halted}, Fail: {fail}/{success + fail + halted}, Halted: {halted}/{success + fail + halted}"
158 | print(f'Finished. Success: {success}, Fail: {fail}, Halted: {halted}')
159 |
160 | parsed_result = split_logs_by_task(text=log, num_tasks=len(react_agent.tasks))
161 | reflection_results = plot_trial_stats(parsed_result=parsed_result, benchmark=cfg.benchmark.name, max_trials=cfg.agent.max_reflection_depth + 1, save_path=f"{LOG_PATH}/{cfg.run_name}_logs_stats.png")
162 |
163 | results = ', '.join([f"{k}: {v}" for k, v in reflection_results.items()]) + '\n'
164 | if cfg.benchmark.name == 'alfworld':
165 | results += str(alfworld_results_per_env_name(dicts[-1]))
166 | elif cfg.benchmark.name == 'webshop':
167 | results += str(get_webshop_mean_scores(log, len(react_agent.tasks), cfg.agent.max_reflection_depth + 1))
168 | log += f'\n\n{results}\n#######################################'
169 | true_log += f'\n\n{results}\n#######################################'
170 | print(results)
171 |
172 | save_trajectories_log(
173 | LOG_PATH, log, dicts, true_log,
174 | run_name=cfg.run_name
175 | )
176 |
177 | log, dicts, true_log = '', [], ''
178 | react_agent.reset_stats()
179 | ################################
180 |
181 | if __name__ == "__main__":
182 | main()
183 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import string
3 | from typing import List, Dict, Callable, Tuple, Any, Union
4 | from matplotlib import pyplot as plt
5 |
6 | import tiktoken
7 | from langchain.schema import (
8 | ChatMessage
9 | )
10 |
11 | from prompts import FEWSHOTS
12 |
13 | import math
14 | import pickle
15 | import re
16 |
17 |
18 | ENV_NAMES = [
19 | 'pick_and_place',
20 | 'pick_clean_then_place',
21 | 'pick_heat_then_place',
22 | 'pick_cool_then_place',
23 | 'look_at_obj',
24 | 'pick_two_obj'
25 | ]
26 |
27 | TASK_ENV_NAMES = [
28 | ('clean', 'pick_clean_then_place'),
29 | ('hot', 'pick_heat_then_place'),
30 | ('heat', 'pick_heat_then_place'),
31 | ('cool', 'pick_cool_then_place'),
32 | ('look', 'look_at_obj'),
33 | ('examine', 'look_at_obj'),
34 | ('two', 'pick_two_obj'),
35 | ('put', 'pick_and_place') # last one must be at last position
36 | ]
37 |
38 | class Count:
39 | """
40 | Class for reflection counting.
41 | """
42 | def __init__(self, maximum):
43 | self.count = 0
44 | self.maximum = maximum
45 |
46 | def increment(self):
47 | self.count += 1
48 | if self.count > self.maximum:
49 | self.count = self.maximum
50 |
51 | def is_beginning(self):
52 | return self.count == 1
53 |
54 | def reset(self):
55 | self.count = 0
56 |
57 | def is_maximum(self):
58 | return self.count == self.maximum
59 |
60 | def random_divide_list(lst: List[Any], k: int):
61 | """
62 | Divides the list into chunks, each with maximum length k.
63 |
64 | Args:
65 | lst: The list to be divided.
66 | k: The maximum length of each chunk.
67 |
68 | Returns:
69 | A list of chunks.
70 | """
71 | random.shuffle(lst)
72 | if len(lst) <= k:
73 | return [lst]
74 | else:
75 | num_chunks = math.ceil(len(lst) / k)
76 | chunk_size = math.ceil(len(lst) / num_chunks)
77 | return [lst[i*chunk_size:(i+1)*chunk_size] for i in range(num_chunks)]
78 |
79 | def shuffled_chunks(lst: List[Any], num_chunks: int):
80 | """
81 | Divides the list into chunks as equally as possible.
82 |
83 | Args:
84 | lst: The list to be divided.
85 | num_chunks: The number of chunks.
86 |
87 | Returns:
88 | A list of chunks.
89 | """
90 | random.shuffle(lst)
91 | chunk_size = len(lst) // num_chunks
92 | remainder = len(lst) % num_chunks
93 | chunks = [lst[i*chunk_size:(i+1)*chunk_size] for i in range(num_chunks)]
94 |
95 | # Distribute the remainder elements across the chunks
96 | if remainder > 0:
97 | for i in range(remainder):
98 | chunks[i].append(lst[num_chunks * chunk_size + i])
99 |
100 | random.shuffle(chunks)
101 | return chunks
102 |
103 | def token_counter(text: str, llm: str = 'gpt-3.5-turbo', tokenizer: Callable = None) -> int:
104 | """
105 | Counts the number of tokens in the text.
106 |
107 | Args:
108 | text: The text to be counted.
109 | llm: The language model name.
110 | tokenizer: The tokenizer to be used.
111 | """
112 | if 'gpt' in llm:
113 | return len(tiktoken.encoding_for_model(llm).encode(text))
114 |
115 | raise NotImplementedError
116 |
117 | def print_message(message: ChatMessage, token_counter: Callable = None, testing: bool = True, extra_text: str = '') -> None:
118 | """
119 | Prints the formatted message.
120 |
121 | Args:
122 | message: The message to be printed.
123 | token_counter: A function that takes in a string and returns the number of tokens in the string.
124 | testing: Add message type and token count in testing mode.
125 | extra_text: Extra text to be printed after the message in testing mode.
126 | """
127 | if testing:
128 | message = f"$$${message.type}$$$\t{message.content}\t$$${message.type}$$$"
129 | if token_counter is not None:
130 | message += f"\t***{token_counter(message)} tokens***"
131 | message += extra_text
132 | print(message)
133 | else:
134 | print(message.content)
135 |
136 | def parse_action(string: str):
137 | """
138 | Parse action string into action type and argument for HotpotQA and Fever.
139 |
140 | Args:
141 | string: action string
142 |
143 | Returns:
144 | action_type: action type
145 | argument: argument
146 | """
147 | pattern = r'^(\w+)\[(.+)\]$'
148 | match = re.match(pattern, string)
149 |
150 | if match:
151 | action_type = match.group(1)
152 | argument = match.group(2)
153 | return action_type, argument
154 |
155 | else:
156 | return None, None
157 |
158 | def normalize_answer(s: str):
159 | """
160 | Lower text and remove punctuation, articles and extra whitespace.
161 |
162 | Args:
163 | s: string to normalize
164 |
165 | Returns:
166 | normalized string
167 | """
168 | def remove_articles(text):
169 | return re.sub(r"\b(a|an|the)\b", " ", text)
170 |
171 | def white_space_fix(text):
172 | return " ".join(text.split())
173 |
174 | def remove_punc(text):
175 | exclude = set(string.punctuation)
176 | return "".join(ch for ch in text if ch not in exclude)
177 |
178 | def lower(text):
179 | return text.lower()
180 |
181 | return white_space_fix(remove_articles(remove_punc(lower(s))))
182 |
183 | def EM(answer, key) -> bool:
184 | """
185 | Exact match between answer and key.
186 |
187 | Args:
188 | answer: answer
189 | key: key
190 |
191 | Returns:
192 | True if exact match, else False
193 | """
194 | return normalize_answer(answer) == normalize_answer(key)
195 |
196 |
197 | def save_trajectories_log(path: str, log: str = None, dicts: list = None, true_log: str = None, save_log: bool = True, save_dict: bool = True, save_true_log: bool = True, run_name: str = 'run') -> None:
198 | """
199 | Saves the log and the dict to the path.
200 |
201 | Args:
202 | path: The path to save the log and the dictionaries.
203 | log: The log to be saved.
204 | dicts: The dict to be saved.
205 | true_log: The true log to be saved.
206 | save_log: Whether to save the log.
207 | save_dict: Whether to save the dictionaries.
208 | save_true_log: Whether to save the true log.
209 | run_name: The name of the run.
210 | """
211 | if save_log:
212 | with open(f'{path}/{run_name}.txt', 'w') as f:
213 | f.write(log)
214 | if save_dict:
215 | with open(f'{path}/{run_name}.pkl', 'wb') as f:
216 | pickle.dump(dicts, f)
217 | if save_true_log:
218 | with open(f'{path}/{run_name}_true.txt', 'w') as f:
219 | f.write(true_log)
220 |
221 | def load_trajectories_log(path: str, load_log: bool = True, load_dict: bool = True, load_true_log: bool = False, run_name: str = 'run') -> Dict[str, Any]:
222 | """
223 | Loads the log and the dict from the path.
224 |
225 | Args:
226 | path: The path to load the logs and the dictionaries.
227 | load_log: Whether to load the log.
228 | load_dict: Whether to load the dictionaries.
229 | load_true_log: Whether to load the true log.
230 | run_name: The name of the run.
231 |
232 | Returns:
233 | A dictionary containing the logs and the dict.
234 | """
235 | out = dict()
236 | if load_log:
237 | with open(f'{path}/{run_name}.txt', 'r') as f:
238 | out['log'] = f.read()
239 | if load_dict:
240 | with open(f'{path}/{run_name}.pkl', 'rb') as f:
241 | out['dicts'] = pickle.load(f)
242 | if load_true_log:
243 | with open(f'{path}/{run_name}_true.txt', 'r') as f:
244 | out['true_log'] = f.read()
245 |
246 | return out
247 |
248 | def split_logs_by_task(text: str, num_tasks: int) -> List[List[str]]:
249 | """
250 | Splits the log text by task.
251 |
252 | Args:
253 | text: The log text.
254 | num_tasks: The number of tasks.
255 |
256 | Returns:
257 | A list of lists of log texts, each list corresponding to a task.
258 | """
259 | remaining_text = text
260 | parsed_result = []
261 | for task_i in range(num_tasks+1):
262 | if task_i == num_tasks:
263 | pattern_i = r'########################################\nEND TRIAL'
264 | else:
265 | pattern_i = rf'#######################################\n.*TASK {str(task_i)} '
266 | matches_i = re.split(pattern_i, remaining_text)
267 | remaining_text = matches_i[-1]
268 | parsed_result.append(matches_i[1:-1])
269 | if task_i != 0:
270 | parsed_result[task_i-1].append(matches_i[0])
271 |
272 | # remove the last empty list
273 | parsed_result.pop()
274 |
275 | return parsed_result
276 |
277 | def recompute_stats(parsed_result: List[List[str]], benchmark: str, trial: int = -1) -> Dict[str, int]:
278 | """
279 | Recomputes the stats from the parsed log text.
280 |
281 | Args:
282 | parsed_result: The parsed log text.
283 | benchmark: The benchmark name.
284 | trial: The number of trials.
285 |
286 | Returns:
287 | The stats for the given benchmark.
288 | """
289 | stats = {"success": 0, "fail": 0, "halted": 0} if benchmark != 'alfworld' else {"success": 0, "fail": 0}
290 |
291 | for task_i in range(len(parsed_result)):
292 | trajectories = parsed_result[task_i]
293 | last_trajectory = trajectories[min(trial, len(trajectories) - 1)].strip()
294 | last_step = last_trajectory.split('\n')[-1]
295 |
296 | if benchmark == 'hotpotqa':
297 | if ' CORRECT' in last_step:
298 | stats["success"] += 1
299 | elif 'INCORRECT' in last_step:
300 | stats["fail"] += 1
301 | else:
302 | stats["halted"] += 1
303 | elif benchmark == 'alfworld':
304 | if 'SOLVED' in last_step:
305 | stats["success"] += 1
306 | else:
307 | stats["fail"] += 1
308 | elif benchmark == 'webshop':
309 | if ': 1.0' in last_step or 'Your score' in last_step:
310 | stats["success"] += 1
311 | else:
312 | stats["halted"] += 1
313 | elif benchmark == 'fever':
314 | if 'reward = 1' in last_step:
315 | stats["success"] += 1
316 | elif 'reward = 0' in last_step:
317 | stats["fail"] += 1
318 | else:
319 | stats["halted"] += 1
320 | else:
321 | raise NotImplementedError(f'recompute_stats for {benchmark} not implemented')
322 |
323 | return stats
324 |
325 | def plot_trial_stats(parsed_result: List[List[str]], benchmark: str, max_trials: int = 4, save_path: str = None) -> Dict[str, List[int]]:
326 | """
327 | Plots the stats from the parsed log text.
328 |
329 | Args:
330 | parsed_result: The parsed log text.
331 | benchmark: The benchmark name.
332 | max_trials: The number of trials.
333 | save_path: The path to save the figure.
334 |
335 | Returns:
336 | The stats for the given benchmark.
337 | """
338 | results = dict()
339 | colors = {'success': 'green', 'fail': 'red', 'halted': 'orange'}
340 |
341 | for i in range(max_trials):
342 | stats = recompute_stats(parsed_result, benchmark, i)
343 | for key, value in stats.items():
344 | results[key] = results.get(key, []) + [value]
345 |
346 | if benchmark == 'alfworld':
347 | assert len(parsed_result) == 134
348 | results = {k: [round(x / 134 * 100, 2) for x in v] for k, v in results.items()}
349 | else:
350 | assert len(parsed_result) == 100
351 |
352 | for i, (key, value) in enumerate(results.items()):
353 | plt.plot(value, label=key, marker='o', color=colors[key])
354 |
355 | # annotate all points with their values next to them
356 | for i in range(max_trials):
357 | for key, value in results.items():
358 | plt.annotate(value[i], (i, value[i]), textcoords="offset points", xytext=(0,10), ha='center')
359 |
360 | plt.legend(loc='best')
361 | plt.xlabel("Reflection numbers")
362 | plt.ylabel("Task SR %")
363 | plt.xticks(range(max_trials))
364 | plt.show()
365 | if save_path:
366 | plt.savefig(save_path)
367 |
368 | return results
369 |
370 | def get_env_name_from_gamefile(gamefile: str) -> Union[str, None]:
371 | """
372 | Gets the environment name from the gamefile for ALFWorld.
373 |
374 | Args:
375 | gamefile: The gamefile.
376 |
377 | Returns:
378 | The environment name.
379 | """
380 | for k in ENV_NAMES:
381 | if k in gamefile:
382 | return k
383 |
384 | def get_env_name_from_task(task: str, benchmark: str) -> Union[str, None]:
385 | """
386 | Gets the environment name from the task instruction for ALFWorld.
387 |
388 | Args:
389 | task: The task.
390 | benchmark: The benchmark name.
391 |
392 | Returns:
393 | The environment name.
394 | """
395 | if benchmark == 'alfworld':
396 | for k, v in TASK_ENV_NAMES:
397 | if k in task:
398 | return v
399 | else:
400 | return benchmark
401 |
402 | def alfworld_results_per_env_name(agent: Dict[str, Any]) -> Dict[str, int]:
403 | """
404 | Computes the results per environment name for ALFWorld from agent dict.
405 |
406 | Args:
407 | agent: The agent dictionary.
408 |
409 | Returns:
410 | The results per environment name.
411 | """
412 | tasks = agent['tasks']
413 | succeeded_trial_history = agent['succeeded_trial_history']
414 | failed_trial_history = agent['failed_trial_history']
415 |
416 | results = {k: 0 for k in ENV_NAMES}
417 | totals = {k: 0 for k in ENV_NAMES}
418 | seen = {t['task']: 0 for t in tasks}
419 | for task in tasks:
420 | t = task['task']
421 | env_name = get_env_name_from_gamefile(task['env_kwargs']['gamefile'])
422 | if len(succeeded_trial_history[t]) > seen[t]:
423 | results[env_name] += 1
424 | totals[env_name] += 1
425 | seen[t] += 1
426 | elif len(failed_trial_history[t]) > 0:
427 | totals[env_name] += 1
428 |
429 | return {k: round(results[k] / totals[k], 2) if totals[k] != 0 else 0 for k in results.keys()}
430 |
431 | def alfworld_results_per_env_name_log(log: str, num_tasks: int, num_trials: int) -> Dict[str, int]:
432 | """
433 | Computes the results per environment name for ALFWorld from log text.
434 |
435 | Args:
436 | log: The log text.
437 | num_tasks: The number of tasks.
438 | num_trials: The number of trials.
439 |
440 | Returns:
441 | The results per environment name.
442 | """
443 | results = {k: [0, 0, 0, 0] for k in ENV_NAMES}
444 | totals = {k: [0, 0, 0, 0] for k in ENV_NAMES}
445 |
446 | parsed_results = split_logs_by_task(log, num_tasks)
447 | for i in range(num_trials):
448 | for task in parsed_results:
449 | if 'You are in the middle of a room' not in task[i if i < len(task) else -1]:
450 | continue
451 | task_desc = re.findall(r'Your task is to: (.*)', task[i if i < len(task) else -1])[0]
452 | env_name = get_env_name_from_task(task_desc, 'alfworld')
453 | if env_name is None:
454 | raise ValueError(f'env_name is None for task {task_desc}')
455 | if 'SOLVED' in task[i if i < len(task) else -1].strip().split('\n')[-1]:
456 | results[env_name][i] += 1
457 | totals[env_name][i] += 1
458 |
459 | assert all([sum([totals[k][i] for k in ENV_NAMES]) == num_tasks for i in range(num_trials)])
460 | print(totals)
461 |
462 | return {k: [round(results[k][i] / totals[k][i], 2) if totals[k][i] != 0 else 0 for i in range(num_trials)] for k in ENV_NAMES}
463 |
464 | def get_webshop_mean_score(log: str, num_tasks: int, num_trials: int) -> float:
465 | """
466 | Computes the mean score for WebShop from log text.
467 |
468 | Args:
469 | log: The log text.
470 | num_tasks: The number of tasks.
471 | num_trials: The number of trials.
472 |
473 | Returns:
474 | The mean score.
475 | """
476 | parsed_result = split_logs_by_task(text=log, num_tasks=num_tasks)
477 |
478 | assert len(parsed_result) == num_tasks
479 | return sum([
480 | max([
481 | float(parsed_result[k][i].strip().split('\n')[-1].split()[-1]
482 | if 'Your score' in parsed_result[k][i].strip().split('\n')[-1] else 0)
483 | for i in range(min(len(parsed_result[k]), num_trials))
484 | ]
485 | ) for k in range(len(parsed_result))]
486 | ) / len(parsed_result)
487 |
488 | def get_webshop_mean_scores(log: str, num_tasks: int, num_trials: int) -> float:
489 | """
490 | Computes the mean scores for WebShop from log text.
491 |
492 | Args:
493 | log: The log text.
494 | num_tasks: The number of tasks.
495 | num_trials: The number of trials.
496 |
497 | Returns:
498 | The mean scores.
499 | """
500 | return [get_webshop_mean_score(log, num_tasks, i) for i in range(1, num_trials + 1)]
501 |
502 | def get_fewshot_max_tokens(benchmark: str) -> int:
503 | """
504 | Gets the maximum number of tokens in the fewshot tasks for the given benchmark.
505 |
506 | Args:
507 | benchmark: The benchmark name.
508 |
509 | Returns:
510 | The maximum number of tokens.
511 | """
512 | fewshots = FEWSHOTS[benchmark]
513 | if isinstance(fewshots, dict):
514 | return max([max([token_counter(f) for f in fs]) for fs in fewshots.values()])
515 | elif isinstance(fewshots, list):
516 | return max([token_counter(f) for f in fewshots])
517 |
518 | def get_split_eval_idx_list(agent_dict: Dict[str, Any], n_folds: int) -> List[List[int]]:
519 | """
520 | Gets the split evaluation index list.
521 |
522 | Args:
523 | agent_dict: The agent dictionary.
524 | n_folds: The number of folds.
525 |
526 | Returns:
527 | The split evaluation index list.
528 | """
529 | eval_idx_list = [[] for _ in range(n_folds)]
530 | env_names = set(x['env_name'] for x in agent_dict['tasks'])
531 | task2idx = agent_dict['task2idx']
532 |
533 | # compare success/failure
534 | compare_dict = {env_name: [] for env_name in env_names}
535 | success_dict = {env_name: [] for env_name in env_names}
536 | fail_dict = {env_name: [] for env_name in env_names}
537 | for task, trials in agent_dict['succeeded_trial_history'].items():
538 | if len(trials) > 0:
539 | if len(agent_dict['failed_trial_history'][task]) > 0:
540 | compare_dict[get_env_name_from_task(task, agent_dict['benchmark_name'])].append(task2idx[task])
541 | else:
542 | success_dict[get_env_name_from_task(task, agent_dict['benchmark_name'])].append(task2idx[task])
543 | else:
544 | assert len(agent_dict['failed_trial_history'][task]) > 0
545 | fail_dict[get_env_name_from_task(task, agent_dict['benchmark_name'])].append(task2idx[task])
546 |
547 | # split into n_folds
548 | j = 0
549 | for idx_list in list(compare_dict.values()) + list(success_dict.values()) + list(fail_dict.values()):
550 | random.shuffle(idx_list)
551 | for idx in idx_list:
552 | eval_idx_list[j % n_folds].append(idx)
553 | j += 1
554 |
555 | assert set.intersection(*[set(x) for x in eval_idx_list]) == set()
556 |
557 | return eval_idx_list
558 |
559 | def mode_results(benchmark: str, log: str, num_tasks: int, mode: str) -> Any:
560 | """
561 | Computes the statistic results for the given mode.
562 |
563 | Args:
564 | benchmark: The benchmark name.
565 | log: The log text.
566 | num_tasks: The number of tasks.
567 | mode: The mode.
568 |
569 | Returns:
570 | The results.
571 | """
572 | parsed_result = split_logs_by_task(text=log, num_tasks=num_tasks)
573 | if 'react' in mode:
574 | parsed_result = [x[0] for x in parsed_result]
575 | elif 'reflection' in mode:
576 | pattern = r'reflection(\d+)'
577 | res = re.findall(pattern, mode)
578 | if len(res) > 0:
579 | i = int(res[0])
580 | else:
581 | pattern = r'(\d+)reflection'
582 | res = re.findall(pattern, mode)
583 | if len(res) > 0:
584 | i = int(res[0])
585 | else:
586 | i = 0
587 | parsed_result = [x[i if i < len(x) else -1] for x in parsed_result]
588 |
589 | if 'token' in mode:
590 | f_trial = token_counter
591 | elif 'count' in mode:
592 | f_trial = lambda x: 1
593 | else:
594 | raise NotImplementedError(f'mode must contain token or count')
595 |
596 | invalid = None
597 | if 'invalid' in mode:
598 | if benchmark == 'alfworld':
599 | invalid = 'nothing happens'
600 | elif benchmark == 'webshop':
601 | invalid = 'invalid action'
602 | elif benchmark in ['hotpotqa', 'fever']:
603 | # invalid = 'similar: ['
604 | invalid = 'invalid action'
605 | mode += 'observation'
606 |
607 | if 'thought' in mode:
608 | if benchmark == 'webshop':
609 | lambda_filter = lambda y: y.strip().startswith('Action: think[')
610 | elif benchmark == 'alfworld':
611 | lambda_filter = lambda y: y.strip().startswith('> think:')
612 | elif benchmark in ['hotpotqa', 'fever']:
613 | lambda_filter = lambda y: y.strip().startswith('Thought')
614 | else:
615 | raise NotImplementedError(f'benchmark {benchmark} not implemented')
616 | elif 'action' in mode:
617 | if benchmark == 'webshop':
618 | lambda_filter = lambda y: y.strip().startswith('Action: click[') or y.strip().startswith('Action: search[') # valid actions
619 | elif benchmark == 'alfworld':
620 | lambda_filter = lambda y: y.strip().startswith('> ') and not y.strip().startswith('> think:') # valid and invalid actions
621 | elif benchmark in ['hotpotqa', 'fever']:
622 | lambda_filter = lambda y: y.strip().startswith('Action')
623 | else:
624 | raise NotImplementedError(f'benchmark {benchmark} not implemented')
625 | elif 'observation' in mode:
626 | if benchmark == 'webshop':
627 | lambda_filter = lambda y: y.strip().startswith('Observation:')
628 | elif benchmark == 'alfworld':
629 | lambda_filter = lambda y: not y.strip().startswith('> ')
630 | elif benchmark in ['hotpotqa', 'fever']:
631 | lambda_filter = lambda y: y.strip().startswith('Observation')
632 | else:
633 | raise NotImplementedError(f'benchmark {benchmark} not implemented')
634 | else:
635 | lambda_filter = lambda y: True
636 |
637 | if 'step' in mode:
638 | if invalid is not None:
639 | parsed_result = [y for x in parsed_result for y in x.strip().split('\n') if lambda_filter(y) and invalid in y.lower()]
640 | else:
641 | parsed_result = [y for x in parsed_result for y in x.strip().split('\n') if lambda_filter(y)]
642 |
643 | elif 'traj' in mode:
644 | if invalid is not None:
645 | parsed_result = [[y for y in x.strip().split('\n') if lambda_filter(y) and invalid in y.lower()] for x in parsed_result]
646 | else:
647 | parsed_result = [[y for y in x.strip().split('\n') if lambda_filter(y)] for x in parsed_result]
648 | assert len(parsed_result) == num_tasks
649 | else:
650 | raise NotImplementedError('mode must contain traj or step')
651 |
652 | if 'sum' in mode:
653 | f_return = sum
654 | elif 'mean' in mode:
655 | f_return = lambda x: sum(x) / len(x) if len(x) > 0 else 0
656 | elif 'list' in mode:
657 | f_return = lambda x: x
658 | else:
659 | raise NotImplementedError('mode must contain sum, mean or list')
660 |
661 | if 'print' in mode:
662 | print(parsed_result)
663 |
664 | if len(parsed_result) > 0 and isinstance(parsed_result[0], List):
665 | return f_return([sum([f_trial(trial) for trial in result]) for result in parsed_result])
666 |
667 | return f_return([f_trial(trial) for trial in parsed_result])
668 |
--------------------------------------------------------------------------------