├── .gitignore
├── LICENSE
├── README.md
├── alfworld
│   ├── README
│   ├── alfworld_trial.py
│   ├── base_config.yaml
│   ├── generate_reflections.py
│   ├── main.py
│   ├── prompts
│   │   ├── alfworld.json
│   │   ├── alfworld_3prompts.json
│   │   ├── fever.json
│   │   └── prompts_naive.json
│   ├── requirements.txt
│   └── utils.py
├── figures
│   └── illusion.png
├── hotpotqa
│   ├── README
│   ├── base.py
│   ├── hotpot.py
│   ├── hotpotqa.py
│   ├── models.py
│   ├── requirements.txt
│   ├── run.py
│   ├── tot.py
│   └── utils.py
└── webshop
    ├── README
    ├── base.py
    ├── lats.py
    ├── models.py
    ├── prompt.py
    ├── requirements.txt
    ├── run.py
    ├── utils.py
    └── webshop.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | local_settings.py
60 | db.sqlite3
61 | db.sqlite3-journal
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | .pybuilder/
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | # For a library or package, you might want to ignore these files since the code is
86 | # intended to run in multiple environments; otherwise, check them in:
87 | # .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # poetry
97 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
98 | # This is especially recommended for binary packages to ensure reproducibility, and is more
99 | # commonly ignored for libraries.
100 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
101 | #poetry.lock
102 |
103 | # pdm
104 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
105 | #pdm.lock
106 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
107 | # in version control.
108 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
109 | .pdm.toml
110 | .pdm-python
111 | .pdm-build/
112 |
113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
114 | __pypackages__/
115 |
116 | # Celery stuff
117 | celerybeat-schedule
118 | celerybeat.pid
119 |
120 | # SageMath parsed files
121 | *.sage.py
122 |
123 | # Environments
124 | .env
125 | .venv
126 | env/
127 | venv/
128 | ENV/
129 | env.bak/
130 | venv.bak/
131 |
132 | # Spyder project settings
133 | .spyderproject
134 | .spyproject
135 |
136 | # Rope project settings
137 | .ropeproject
138 |
139 | # mkdocs documentation
140 | /site
141 |
142 | # mypy
143 | .mypy_cache/
144 | .dmypy.json
145 | dmypy.json
146 |
147 | # Pyre type checker
148 | .pyre/
149 |
150 | # pytype static type analyzer
151 | .pytype/
152 |
153 | # Cython debug symbols
154 | cython_debug/
155 |
156 | # PyCharm
157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159 | # and can be added to the global gitignore or merged into this file. For a more nuclear
160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161 | #.idea/
162 |
163 | /alfworld_runs/70bins/
164 | /alfworld_runs/8bins/
165 | /alfworld_runs/7embed/
166 | /alfworld_runs/2embed/
167 | /alfworld_runs/reranker/
168 | /alfworld_runs/rerankers/
169 | /alfworld_runs/logs/
170 | /alfworld_runs/.infinity_cache/
171 | /alfworld_runs/e5_7b/
172 | /8bbase/
173 | /70bins/
174 | /8bins/
175 | /7embed/
176 | /2embed/
177 | /reranker/
178 | /rerankers/
179 | /logs/
180 | /.infinity_cache/
181 | /e5_7b/
182 | /8bbase/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CoPS: Empowering LLM Agents with Provable Cross-Task Experience Sharing
2 |
3 | This repo holds the code for [CoPS: Empowering LLM Agents with Provable Cross-Task Experience Sharing](https://arxiv.org/abs/2410.16670) by Chen Yang, Chenyang Zhao, Quanquan Gu and Dongruo Zhou.
4 |
5 | 
6 |
7 | CoPS enhances sequential reasoning in LLM-powered agents by enabling effective sharing and selection of cross-task experiences through a pessimism-based strategy.
8 |
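A minimal sketch of this selection step, assuming trajectories from previously solved tasks have already been embedded (the names below are illustrative; the full logic lives in `alfworld/alfworld_trial.py`): the agent embeds the current task, scores each stored success by inverse L2 distance, and samples in-context examples with exponentially sharpened weights.

```python
import math
import numpy as np

def select_experiences(task_vec, traj_vecs, trajectories, k):
    # Inverse L2 distance: closer successful trajectories score higher.
    scores = [1.0 / np.linalg.norm(np.array(task_vec) - np.array(v)) for v in traj_vecs]
    # Keep the k nearest, sharpen with exp(5 * score), normalize to a distribution.
    ranked = sorted(zip(scores, trajectories), key=lambda p: p[0], reverse=True)[:k]
    weights = np.array([math.exp(5.0 * s) for s, _ in ranked])
    weights /= weights.sum()
    picks = np.random.choice(len(ranked), size=min(k, len(ranked)), p=weights)
    return [ranked[i][1] for i in picks]
```
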
9 | ## Installation
10 |
11 | Our repo is mainly built on SGLang version 0.2.15.
12 |
13 | ```shell
14 | git clone -b v0.2.15 https://github.com/sgl-project/sglang.git
15 | cd sglang
16 | pip install -e "python[all]"
17 |
18 | # Modify the version of flashinfer to your own device
19 | pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
20 | ```
21 |
22 | To install the ALFWorld and WebShop benchmarks locally, please refer to the following links:
23 |
24 | [https://github.com/alfworld/alfworld](https://github.com/alfworld/alfworld)
25 |
26 | [https://github.com/princeton-nlp/WebShop](https://github.com/princeton-nlp/WebShop)
27 |
28 | ## Reproduce
29 |
30 | To reproduce the results on each benchmark, refer to its subdirectory.
31 |
32 | ## Model Server
33 |
34 | To accelerate the running efficiency of CoPS, we implement a data-parallel serving router and engine with an SGLang backend, i.e., the [ModelServer](https://github.com/zhaochenyang20/ModelServer). After configuration, you can get a completion or an embedding with commands like the following:
35 |
36 |
37 | ModelServer Command
38 |
39 | ```python
40 | from model_sever import ModelServer
41 | message = [
42 | {
43 | "role": "user",
44 | "content": "Interact with a household to solve a task. Here are a few examples.\nYou are in the middle of a room. Looking quickly around you, you see a cabinet 4, a cabinet 3, a cabinet 2, a cabinet 1, a coffeemachine 1, a countertop 1, a diningtable 3, a diningtable 2, a diningtable 1, a drawer 1, a fridge 1, a garbagecan 1, a microwave 1, a sidetable 1, a sinkbasin 1, a stoveburner 4, a stoveburner 3, a stoveburner 2, a stoveburner 1, and a toaster 1.\nYour task is to: find some apple and put it in sidetable.\n> think: To solve the task, I need to find and take an apple, then put it in sidetable.\nOK.\n> think: First I need to find an apple. An apple is more likely to appear in fridges (1), diningtables (1-3), sidetables (1), countertops (1), sinkbasins (1), garbagecan (1). I can check one by one, starting with fridge 1.\nOK.\n> go to fridge 1\nThe fridge 1 is closed.\n> open fridge 1\nYou open the fridge 1. The fridge 1 is open. In it, you see a lettuce 2, a mug 2, a potato 2, and a tomato 1.\n> go to diningtable 1\nOn the diningtable 1, you see nothing.\n> go to diningtable 2\nOn the diningtable 2, you see a pineapple 1.\n> go to diningtable 3\nOn the diningtable 3, you see a bowl 2, a bowl 1, a bread 2, and a bread 1.\n> go to sidetable 1\nOn the sidetable 1, you see a cup 1, a lettuce 1, a peppershaker 3, a potato 1, and a saltshaker 1.\n> go to countertop 1\nOn the countertop 1, you see a butterknife 1, a dishsponge 1, and a soapbottle 1.\n> go to sinkbasin 1\nOn the sinkbasin 1, you see a cup 2, a egg 1, a mug 1, and a potato 3.\n> go to garbagecan 1\nOn the garbagecan 1, you see a apple 3, and a egg 3.\n> think: Now I find an apple (3). Next, I need to take it.\n> take apple 3 from garbagecan 1\nYou pick up the apple 3 from the garbagecan 1.\n> think: Now I find an apple (3). Next, I need to put it in/on sidetable 1.\nOK.\n> go to sidetable 1\nOn the sidetable 1, you see a cup 1, a lettuce 1, a peppershaker 3, a potato 1, and a saltshaker 1.\n> put apple 3 in/on sidetable 1\nYou put the apple 3 in/on the sidetable 1.\nYou are in the middle of a room. Looking quickly around you, you see a cabinet 4, a cabinet 3, a cabinet 2, a cabinet 1, a countertop 1, a garbagecan 1, a handtowelholder 2, a handtowelholder 1, a sinkbasin 2, a sinkbasin 1, a toilet 1, a toiletpaperhanger 1, and a towelholder 1.\nYour task is to: put some spraybottle on toilet.\n> think: To solve the task, I need to find and take a sparybottle, then put it on toilet.\nOK.\n> think: First I need to find a spraybottle. A spraybottle is more likely to appear in cabinet (1-4), countertop (1), toilet (1), sinkbasin (1-2), garbagecan (1). I can check one by one, starting with cabinet 1.\nOK.\n> go to cabinet 1\nOn the cabinet 1, you see a cloth 1, a soapbar 1, a soapbottle 1.\n> go to cabinet 2\nThe cabinet 2 is closed.\n> open cabinet 2\nYou open the cabinet 2. The cabinet 2 is open. In it, you see a candle 1, and a spraybottle 2.\n> think: Now I find a spraybottle (2). Next, I need to take it.\n> take spraybottle 2 from cabinet 2\nYou pick up the spraybottle 2 from the cabinet 2.\n> think: Now I take a spraybottle (2). Next, I need to put it in/on toilet 1.\nOK.\n> go to toilet 1\nOn the toilet 1, you see a soapbottle 2.\n> put spraybottle 2 in/on toilet 1\nYou put the spraybottle 2 in/on the toilet 1.\n\n\nHere is the task:\nYou are in the middle of a room. 
Looking quickly around you, you see a bed 1, a desk 2, a desk 1, a drawer 6, a drawer 5, a drawer 4, a drawer 3, a drawer 2, a drawer 1, a garbagecan 1, a laundryhamper 1, a safe 1, a shelf 6, a shelf 5, a shelf 4, a shelf 3, a shelf 2, and a shelf 1.\nYour task is to: put a mug in desk.\n>",
45 | },
46 | {
47 | "role": "assistant",
48 | "content": "think: To solve the task, I need to find and take a mug, then put it in/on desk.",
49 | },
50 | {"role": "user", "content": "OK."},
51 | {
52 | "role": "assistant",
53 | "content": "think: First I need to find a mug. A mug is more likely to appear in desk (1-2), drawer (1-6), shelf (1-6), garbagecan (1), laundryhamper (1), safe (1). I can check one by one, starting with desk 1.",
54 | },
55 | {"role": "user", "content": "OK."},
56 | {"role": "assistant", "content": "go to desk 1"},
57 | {
58 | "role": "user",
59 | "content": "On the desk 1, you see a bowl 1, a cd 2, a creditcard 2, a desklamp 1, a keychain 1, a laptop 2, a pen 2, and a pencil 2.",
60 | },
61 | {
62 | "role": "assistant",
63 | "content": "think: Now I don't find a mug in desk 1. Next, I need to check desk 2.",
64 | },
65 | {"role": "user", "content": "OK."},
66 | {"role": "assistant", "content": "go to desk 2"},
67 | {
68 | "role": "user",
69 | "content": "On the desk 2, you see a alarmclock 1, a book 3, a bowl 2, a creditcard 3, and a keychain 3.",
70 | },
71 | ]
72 |
73 | server = ModelServer()
74 | # `8` here is the completion model size configured in the `client_configs`
75 | completion = server.get_completion_or_embedding("8", message)
76 | # get the completion from an 8B instruct model, i.e., Llama 3.1 8B
77 | print(completion)
78 |
79 | # `7` here is the embedding model size configured in the `client_configs`
80 | embedding = server.get_completion_or_embedding(
81 | "7",
82 | message="As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
83 | get_embedding=True,
84 | )
85 | # get the embedding from a 7B embedding model, i.e., `Alibaba-NLP/gte-Qwen1.5-7B-instruct`
86 | print(embedding[:10])
87 | ```
88 |
89 |
90 |
91 | ## Other Notes
92 |
93 | For all questions, contact [cya2@iu.edu](mailto:cya2@iu.edu).
94 |
95 | ## Cite
96 |
97 | ```bibtex
98 | @misc{yang2024copsempoweringllmagents,
99 | title={CoPS: Empowering LLM Agents with Provable Cross-Task Experience Sharing},
100 | author={Chen Yang and Chenyang Zhao and Quanquan Gu and Dongruo Zhou},
101 | year={2024},
102 | eprint={2410.16670},
103 | archivePrefix={arXiv},
104 | primaryClass={cs.LG},
105 | url={https://arxiv.org/abs/2410.16670},
106 | }
107 | ```
108 |
--------------------------------------------------------------------------------
/alfworld/README:
--------------------------------------------------------------------------------
1 | # Alfworld experiment code
2 |
3 | Code modified from [Reflexion](https://github.com/noahshinn/reflexion)
4 |
5 | Set up your local LLM behind an OpenAI-API-compatible endpoint, then configure it in utils.py. We recommend [SGLang](https://github.com/sgl-project/sglang).
6 |
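As a hypothetical example (the URL and model name are placeholders for your own deployment), an SGLang server launched on its default port exposes an OpenAI-compatible API that can be queried like this:

```python
from openai import OpenAI

# Placeholder endpoint; adjust to wherever your SGLang server is listening.
client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "> go to desk 1"}],
    temperature=0.0,
)
print(response.choices[0].message.content)
```
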
7 | Example command:
8 |
9 | python main.py --num_trials 10 --num_envs 134 --plan_model_size 8 \
10 | --reflect_model_size 8 --cluster_size 134 --mem_size 0 \
11 | --run_name 'alfworld_new' \
12 | --mem_selection_method 'fifo' \
13 | --in_context_trajactory_size 5 \
14 | --online_embedding_model_size 7 \
15 | --trajactory_search_method 'knn' \
16 | --use_success_trajectory
--------------------------------------------------------------------------------
/alfworld/alfworld_trial.py:
--------------------------------------------------------------------------------
1 | """Adapted from https://github.com/ysymyth/ReAct/blob/master/alfworld.ipynb"""
2 |
3 | import os
4 |
5 | os.environ["PYTHONUTF8"] = "1"
6 | import sys
7 | import json
8 | import yaml
9 | import importlib
10 | import alfworld
11 | import re
12 | import numpy as np
13 | import faiss
14 | import random
15 | import math
16 | import alfworld.agents.environment
17 | from utils import EnvironmentHistory, ModelServer
18 | from IPython import embed
19 | from typing import List, Dict, Any, Tuple
20 |
21 |
22 | EMBEDDING_DIM = {"7": 3584, "2": 1536}
23 |
24 | FOLDER = "./prompts"
25 | PROMPT_FILE = "alfworld_3prompts.json"
26 |
27 | with open(os.path.join(FOLDER, PROMPT_FILE), "r", encoding="utf-8") as f:
28 | d = json.load(f)
29 |
30 |
31 | PREFIXES = {
32 | "pick_and_place": "put",
33 | "pick_clean_then_place": "clean",
34 | "pick_heat_then_place": "heat",
35 | "pick_cool_then_place": "cool",
36 | "look_at_obj": "examine",
37 | "pick_two_obj": "puttwo",
38 | }
39 |
40 |
41 | def process_ob(ob):
42 | if ob.startswith("You arrive at loc "):
43 | ob = ob[ob.find(". ") + 2 :]
44 | return ob
45 |
46 |
47 | def alfworld_run(
48 | model_server: ModelServer,
49 | plan_model_size: str,
50 | env,
51 | base_prompt,
52 | memory: List[str],
53 | to_print=True,
54 | ob="",
55 | mem_size: int = 0,
56 | mem_selection_method: str = "fifo",
57 | config_path: str = None,
58 | ) -> Tuple[EnvironmentHistory, bool]:
59 |
60 | assert mem_selection_method in [
61 | "fifo",
62 | "fix",
63 | "random",
64 | ], "Invalid memory selection method"
65 |
66 | if mem_size == 0:
67 | env_history = EnvironmentHistory(base_prompt, ob, [], [])
68 | elif mem_size != -1 and len(memory) >= mem_size:
69 | if mem_selection_method == "fifo":
70 | env_history = EnvironmentHistory(base_prompt, ob, memory[-mem_size:], [])
71 | elif mem_selection_method == "fix":
72 | env_history = EnvironmentHistory(base_prompt, ob, memory[:mem_size], [])
73 | elif mem_selection_method == "random":
74 | env_history = EnvironmentHistory(
75 | base_prompt, ob, random.sample(memory, mem_size), []
76 | )
77 | else:
78 | env_history = EnvironmentHistory(base_prompt, ob, memory, [])
79 |
80 | env_history.reset()
81 | if to_print:
82 | print(ob)
83 | sys.stdout.flush()
84 | cur_step = 0
85 | preva = ""
86 | message = [{"role": "user", "content": str(env_history)}]
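# Cap each episode at 49 steps; when the model repeats its previous action,
# re-sample with progressively higher temperature to break out of loops.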
87 | while cur_step < 49:
88 | tmpr = 0.0
89 | action = preva
90 | while action == preva:
91 | print(f"cur_step: {cur_step}")
92 | print(f"temperature: {tmpr}")
93 | completion = model_server.get_completion_or_embedding(
94 | plan_model_size,
95 | message=message,
96 | temperature=tmpr,
97 | )
98 | lines = completion.split("\n")
99 | action = ""
100 | for line in lines:
101 | stripped_line = line.strip()
102 | if stripped_line != "":
103 | action = stripped_line
104 | break
105 |
106 | if ">" in action:
107 | action = action.replace(">", "").strip()
108 | action_words = action.split(" ")
109 | if "put" in action_words:
110 | for i in range(len(action_words)):
111 | if (
112 | action_words[i].strip().lower() == "in"
113 | or action_words[i].strip().lower() == "on"
114 | ):
115 | action_words[i] = "in/on"
116 | action = " ".join(action_words)
117 | tmpr += 0.1
118 | env_history.add("action", action)
119 | message.append({"role": "assistant", "content": action})
120 | preva = action
121 | observation, _, done, info = env.step([action])
122 | observation, _, done = process_ob(observation[0]), info["won"][0], done[0]
123 | if action.startswith("think:"):
124 | observation = "OK."
125 | env_history.add("observation", observation)
126 | message.append({"role": "user", "content": observation})
127 | if to_print:
128 | print(f"> {action}\n{observation}")
129 | sys.stdout.flush()
130 | if done:
131 | return env_history, True
132 | elif env_history.check_is_exhausted():
133 | return env_history, False
134 | cur_step += 1
135 | return env_history, False
136 |
137 |
138 | def adjust_trial_number(trial_log_path: str, is_fail: bool = False) -> str:
139 | pattern = r"trial_(\d+)\.log"
140 |
141 | def replace(match):
142 | return "trial_fail.json" if is_fail else "trial_inf.json"
143 |
144 | result = re.sub(pattern, replace, trial_log_path)
145 | return result
146 |
147 |
148 | def adjust_trial_number2(trial_log_path: str) -> str:
149 | pattern = r"trial_(\d+)\.log"
150 |
151 | def replace(match):
152 | return "trial_cache.json"
153 |
154 | result = re.sub(pattern, replace, trial_log_path)
155 | return result
156 |
157 |
158 | def get_offline_embedding(des):
159 | with open("embedding.json", "r", encoding="utf-8") as file:
160 | embdata = json.load(file)
161 | for a, b in embdata:
162 | if des == a:
163 | return b
164 | else:
165 | return []
166 |
167 |
168 | def enumerate_splits(string):
169 | lines = [line for line in string.split("\n") if line]
170 | result = []
171 |
172 | for i in range(len(lines) - 1):
173 | first_part = "\n".join(lines[: i + 1])
174 | second_part = "\n".join(lines[i + 1 :])
175 | result.append((first_part, second_part))
176 |
177 | return result
178 |
179 |
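# Wrap every line that starts with ">" (an agent action) in the given prefix/suffix;
# used below to rewrite retrieved trajectories into Llama-3 chat-template turns.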
180 | def replace_lines(text, prefix, suffix):
181 | replaced_text = re.sub(
182 | r"^>(.*)$", lambda m: f"{prefix}{m.group(1)}{suffix}", text, flags=re.MULTILINE
183 | )
184 | return replaced_text
185 |
186 |
187 | def print_first_n_lines(text, n=5):
188 | lines = text.split('\n')
189 | ret=""
190 | for line in lines[:n]:
191 | ret+=line+"\n"
192 | return ret
193 |
194 | def run_trial(
195 | cluster_size,
196 | cluster_counter,
197 | model_server: ModelServer,
198 | plan_model_size: str,
199 | trial_log_path: str,
200 | world_log_path: str,
201 | trial_idx: int,
202 | env_configs: List[Dict[str, Any]],
203 | mem_size: int = 0,
204 | specific_cluster_name: str = None,
205 | mem_selection_method: str = "fifo",
206 | use_success_trajectory: bool = False,
207 | trajactory_search_method: str = "knn",
208 | in_context_trajactory_size: int = 3,
209 | online_embedding_model_size: str = "2",
210 | ) -> List[Dict[str, Any]]:
211 | importlib.reload(alfworld)
212 | importlib.reload(alfworld.agents.environment)
213 | success_in_each_cluster = {
214 | "clean": 0,
215 | "put": 0,
216 | "cool": 0,
217 | "puttwo": 0,
218 | "examine": 0,
219 | "heat": 0,
220 | }
221 | with open("base_config.yaml") as reader:
222 | config = yaml.safe_load(reader)
223 | split = "eval_out_of_distribution"
224 |
225 | assert (
226 | not use_success_trajectory
227 | and trajactory_search_method is None
228 | and in_context_trajactory_size == 0
229 | ) or (
230 | use_success_trajectory
231 | and trajactory_search_method in ["knn", "random"]
232 | and in_context_trajactory_size > 0
233 | ), "If use_success_trajectory is False, then in_context_trajactory_size should be 0."
234 |
235 | env = getattr(alfworld.agents.environment, config["env"]["type"])(
236 | config, train_eval=split
237 | )
238 | env = env.init_env(batch_size=1)
239 |
240 | num_successes: int = 0
241 | num_success_increase: int = 0
242 | actcnt = 0
243 | last_trial_log_path: str = adjust_trial_number(trial_log_path)
244 | fail_db_path = adjust_trial_number(trial_log_path, is_fail=True)
245 | emb_cache_path = adjust_trial_number2(trial_log_path)
246 | success_data = {
247 | "clean": [],
248 | "put": [],
249 | "cool": [],
250 | "puttwo": [],
251 | "examine": [],
252 | "heat": [],
253 | }
254 |
255 | increase_success = {
256 | "clean": [],
257 | "put": [],
258 | "cool": [],
259 | "puttwo": [],
260 | "examine": [],
261 | "heat": [],
262 | }
263 |
264 | if not os.path.exists(last_trial_log_path):
265 | with open(last_trial_log_path, "w", encoding="utf-8") as file:
266 | print("build success log")
267 | json.dump(
268 | {
269 | "clean": [],
270 | "put": [],
271 | "cool": [],
272 | "puttwo": [],
273 | "examine": [],
274 | "heat": [],
275 | },
276 | file,
277 | indent=4,
278 | )
279 |
280 | with open(last_trial_log_path, "r", encoding="utf-8") as file:
281 | success_data = json.load(file)
282 |
283 | fail_data = {}
284 | emb_cache = {}
285 |
286 | if not os.path.exists(emb_cache_path):
287 | with open(emb_cache_path, "w", encoding="utf-8") as file:
288 | print("build embed cache")
289 | json.dump({}, file)
290 |
291 | with open(emb_cache_path, "r", encoding="utf-8") as file:
292 | emb_cache = json.load(file)
293 |
294 | if not os.path.exists(fail_db_path):
295 | with open(fail_db_path, "w", encoding="utf-8") as file:
296 | print("build fail log")
297 | json.dump({}, file)
298 |
299 | with open(fail_db_path, "r", encoding="utf-8") as file:
300 | fail_data = json.load(file)
301 |
302 | trajectories = []
303 | embedding_array = np.zeros((0, EMBEDDING_DIM[online_embedding_model_size]))
304 | huge_trajectories = []
305 | huge_ary = np.zeros((0, EMBEDDING_DIM[online_embedding_model_size]))
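# Collect successful trajectories from earlier trials; for knn retrieval, embed
# each one, reusing the on-disk embedding cache where possible.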
306 | if trajactory_search_method == "knn":
307 | for key in success_data:
308 | for des, trj in success_data[key]:
309 | # vec = get_offline_embedding(des)
310 | trj_cut = (trj + "\n").split("Here is the task:")[-1].strip()
311 | if trj_cut in emb_cache:
312 | vec = emb_cache[trj_cut]
313 | else:
314 | vec = model_server.get_completion_or_embedding(
315 | online_embedding_model_size,
316 | message=trj_cut,
317 | get_embedding=True,
318 | )
319 | emb_cache[trj_cut] = vec
320 | trajectories.append(trj)
321 | embedding_array = np.vstack((embedding_array, np.array(vec)))
322 | elif trajactory_search_method == "random":
323 | for key in success_data:
324 | for des, trj in success_data[key]:
325 | trajectories.append(trj)
326 |
327 | emb_db = faiss.IndexFlatL2(EMBEDDING_DIM[online_embedding_model_size])
328 | emb_db.add(embedding_array.astype("float32"))
329 | huge_db = faiss.IndexFlatL2(EMBEDDING_DIM[online_embedding_model_size])
330 | huge_db.add(huge_ary.astype("float32"))
331 | final_fail_db = {}
332 |
333 | for z, env_config in enumerate(env_configs):
334 |
335 | ob, info = env.reset()
336 | ob = "\n".join(ob[0].split("\n\n")[1:])
337 | env_description = ob.strip()
338 | env_vec = (
339 | np.array(get_offline_embedding(env_description))
340 | .reshape(1, -1)
341 | .astype("float32")
342 | )
343 | name = "/".join(info["extra.gamefile"][0].split("/")[-3:-1])
344 |
345 | print(f"using {name}")
346 |
347 | if env_config["is_success"]:
348 | for i, (k, v) in enumerate(PREFIXES.items()):
349 | if name.startswith(k):
350 | success_in_each_cluster[v] += 1
351 | num_successes += 1
352 | actcnt += 1
353 | with open(world_log_path, "a", encoding="utf-8") as wf:
354 | wf.write(
355 | f"""
356 |
357 | ==============================================================
358 | Environment #{z}
359 | Trial #{trial_idx}
360 | Game: {name}
361 | SUCCESS
362 | ==============================================================
363 |
364 | """
365 | )
366 | with open(trial_log_path, "a", encoding="utf-8") as wf:
367 | wf.write(
368 | f"""
369 |
370 | ==============================================================
371 | Environment #{z}
372 | Trial #{trial_idx}
373 | Game: {name}
374 | SUCCESS
375 | ==============================================================
376 |
377 | """
378 | )
379 | continue
380 |
381 | for i, (k, v) in enumerate(PREFIXES.items()):
382 | if name.startswith(k):
383 | env_configs[z]["cluster"] = v
384 | if v in cluster_counter:
385 | if len(cluster_counter[v]) < cluster_size and (
386 | specific_cluster_name is None or v == specific_cluster_name
387 | ):
388 | cluster_counter[v].append(name)
389 | else:
390 | if specific_cluster_name is None or v == specific_cluster_name:
391 | cluster_counter[v] = [name]
392 | else:
393 | cluster_counter[v] = []
394 | base_prompt = (
395 | "Interact with a household to solve a task. Here are a few examples.\n"
396 | + d[f"react_{v}_1"]
397 | + d[f"react_{v}_0"]
398 | + "\n"
399 | )
400 | if (
401 | trajectories
402 | and use_success_trajectory
403 | and in_context_trajactory_size > 0
404 | ):
405 | assert trajactory_search_method in ["knn", "random"]
406 | if trajactory_search_method == "knn":
407 | if name in fail_data:
408 | fail_key=fail_data[name].strip()
409 | else:
410 | fail_key=env_description.strip()
411 | fail_vec = model_server.get_completion_or_embedding(
412 | online_embedding_model_size,
413 | message=fail_key,
414 | get_embedding=True,
415 | )
416 | indices = [[]]
417 | dist=[]
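# Score every stored trajectory by inverse L2 distance to the failure embedding,
# keep the closest in_context_trajactory_size, sharpen the scores with exp(5*s),
# and turn them into sampling weights below.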
418 | for index, row in np.ndenumerate(embedding_array):
419 | if index[1] == 0:
420 | realN = (1.0 / np.linalg.norm(np.array(fail_vec) - embedding_array[index[0]]))
421 | dist.append((realN,index[0]))
422 | dist.sort(key=lambda x: x[0], reverse=True)
423 | sz_now=min(len(trajectories),in_context_trajactory_size)
424 | dist=dist[:sz_now]
425 | dist=[(math.exp(5.0*x),y) for (x,y) in dist]
426 | original_sum=sum([x for (x,_) in dist])
427 | dist=[(x *float(sz_now) / original_sum,y) for (x,y) in dist]
428 | tot=0
429 | cntD=[]
430 | realD=[]
431 | for (x,y) in dist:
432 | cntD.append(math.floor(x))
433 | realD.append(x)
434 | tot+=math.floor(x)
435 | while tot<sz_now: # hand the leftover slots to the largest fractional remainders
436 | maxx=-1.0
437 | maxi=0
438 | for ig in range(len(realD)):
439 | if realD[ig]-float(cntD[ig])>maxx:
440 | maxx=realD[ig]-float(cntD[ig])
441 | maxi=ig
442 | cntD[maxi]+=1
443 | tot+=1
444 | weights=[]
445 | for (x,y) in dist:
446 | weights.append(x)
447 | normalized_weights = [w / sum(weights) for w in weights]
448 | for ig in range(sz_now):
449 | s_ind = np.random.choice(len(weights), p=normalized_weights)
450 | (_,y)=dist[s_ind]
451 | indices[0].append(y)
452 | elif trajactory_search_method == "random":
453 | indices[0] = random.sample(
454 | range(len(trajectories)), min(len(trajectories),in_context_trajactory_size)
455 | )
456 | for i in indices[0]:
457 | base_prompt += (
458 | (trajectories[i] + "\n")
459 | .split("Here is the task:")[-1]
460 | .strip()
461 | ) + "\n"
462 | base_prompt = replace_lines(
463 | base_prompt,
464 | "<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n",
465 | "<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\n",
466 | )
467 | base_prompt += "\n\nExamples finished.\n\n"
468 | if env_config["principle"]:
469 | base_prompt += "Here are the core principles you should follow as much as possible in your planning:\n[Principles start]\n"
470 | for pi in env_config["principle"]:
471 | base_prompt += pi + "\n"
472 | base_prompt += "\n[Principles end]\n\n"
473 | inclus = False
474 | final_env_history = ""
475 | is_success = False
476 | if name in cluster_counter[v]:
477 | actcnt += 1
478 | inclus = True
479 | final_env_history, is_success = alfworld_run(
480 | model_server,
481 | plan_model_size,
482 | env,
483 | base_prompt,
484 | env_config["memory"],
485 | to_print=True,
486 | ob=ob,
487 | mem_size=mem_size,
488 | mem_selection_method=mem_selection_method,
489 | )
490 | else:
491 | env_configs[z]["skip"] = True
492 |
493 | # update env config
494 | if is_success:
495 | status_str: str = f"""
496 |
497 | ==============================================================
498 | Environment #{z}
499 | Trial #{trial_idx}
500 | Game: {name}
501 | SUCCESS
502 | ==============================================================
503 |
504 | """
505 | env_configs[z]["is_success"] = True
506 | num_successes += 1
507 | num_success_increase += 1
508 | success_in_each_cluster[v] += 1
509 | elif inclus:
510 | status_str: str = f"""
511 |
512 | ==============================================================
513 | Environment #{z}
514 | Trial #{trial_idx}
515 | Game: {name}
516 | FAIL
517 | ==============================================================
518 |
519 | """
520 | else:
521 | status_str: str = f"""
522 |
523 | ==============================================================
524 | Environment #{z}
525 | Trial #{trial_idx}
526 | SKIP
527 | ==============================================================
528 |
529 | """
530 | # log to world log
531 | assert world_log_path.endswith(".log")
532 | if not os.path.exists(world_log_path):
533 | os.system(f"touch {world_log_path}")
534 | with open(world_log_path, "a", encoding="utf-8") as f:
535 | f.write(status_str + "\n")
536 |
537 | if is_success:
538 | increase_success[v].append(
539 | (env_description, str(final_env_history))
540 | )
541 | trj_cut = (str(final_env_history) + "\n").split("Here is the task:")[-1].strip()
542 | vec = model_server.get_completion_or_embedding(
543 | online_embedding_model_size,
544 | message=trj_cut,
545 | get_embedding=True,
546 | )
547 | emb_cache[trj_cut] = vec
548 | trajectories.append(str(final_env_history))
549 | embedding_array = np.vstack((embedding_array, np.array(vec)))
550 | else:
551 | final_fail_db[name] = (
552 | (str(final_env_history) + "\n")
553 | .split("Here is the task:")[-1]
554 | .strip()
555 | )
556 |
557 | with open(trial_log_path, "a", encoding="utf-8") as wf:
558 | wf.write(
559 | f"""
560 |
561 | ==============================================================
562 | Environment #{z}
563 | {str(final_env_history)}
564 | STATUS: {"OK" if is_success else "FAIL"}
565 | ==============================================================
566 |
567 | """
568 | )
569 | env.close()
570 |
571 | log_str: str = f"""
572 |
573 | ==============================================================
574 | SUCCESS: {num_successes}
575 | INCREASE SUCCESS: {num_success_increase}
576 | FAIL: {actcnt - num_successes}
577 | TOTAL: {actcnt}
578 | ACCURACY: {round(num_successes / actcnt, 2)}
579 | CLUSTER SUCCESS: {str(success_in_each_cluster)}
580 | ==============================================================
581 |
582 | """
583 |
584 | with open(last_trial_log_path, "w", encoding="utf-8") as file:
585 | json.dump(
586 | {
587 | key: success_data.get(key, []) + increase_success.get(key, [])
588 | for key in success_data.keys() | increase_success.keys()
589 | },
590 | file,
591 | indent=4,
592 | )
593 | with open(fail_db_path, "w", encoding="utf-8") as file:
594 | json.dump(
595 | final_fail_db,
596 | file,
597 | indent=4,
598 | )
599 |
600 | with open(emb_cache_path, "w", encoding="utf-8") as file:
601 | json.dump(
602 | emb_cache,
603 | file,
604 | indent=4,
605 | )
606 |
607 | with open(trial_log_path, "a", encoding="utf-8") as wf:
608 | wf.write(log_str)
609 | with open(world_log_path, "a", encoding="utf-8") as wf:
610 | wf.write(log_str + "\n")
611 | return env_configs, cluster_counter
612 |
--------------------------------------------------------------------------------
/alfworld/base_config.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | data_path: '$ALFWORLD_DATA/json_2.1.1/train'
3 | eval_id_data_path: '$ALFWORLD_DATA/json_2.1.1/valid_seen' # null/None to disable
4 | eval_ood_data_path: '$ALFWORLD_DATA/json_2.1.1/valid_unseen' # null/None to disable
5 | num_train_games: -1 # max training games (<=0 indicates full dataset)
6 | num_eval_games: -1 # max evaluation games (<=0 indicates full dataset)
7 |
8 | logic:
9 | domain: '$ALFWORLD_DATA/logic/alfred.pddl' # PDDL domain file that defines the world dynamics
10 | grammar: '$ALFWORLD_DATA/logic/alfred.twl2' # Grammar file that defines the text feedbacks
11 |
12 | env:
13 | type: 'AlfredTWEnv' # 'AlfredTWEnv' or 'AlfredThorEnv' or 'AlfredHybrid'
14 | regen_game_files: False # check if game is solvable by expert and save to game.tw-pddl file
15 | domain_randomization: False # shuffle Textworld print order and object id nums
16 | task_types: [1, 2, 3, 4, 5, 6] # task-type ids: 1 - Pick & Place, 2 - Examine in Light, 3 - Clean & Place, 4 - Heat & Place, 5 - Cool & Place, 6 - Pick Two & Place
17 | expert_timeout_steps: 150 # max steps before timeout for expert to solve the task
18 | expert_type: "handcoded" # 'handcoded' or 'downward'. Note: the downward planner is very slow for real-time use
19 | goal_desc_human_anns_prob: 0.0 # prob of using human-annotated goal language instead of templated goals (1.0 indicates all human annotations from ALFRED)
20 |
21 | hybrid:
22 | start_eps: 100000 # starting episode of hybrid training, tw-only training upto this point
23 | thor_prob: 0.5 # prob of AlfredThorEnv during hybrid training
24 | eval_mode: "tw" # 'tw' or 'thor' - env used for evaluation during hybrid training
25 |
26 | thor:
27 | screen_width: 300 # width of THOR window
28 | screen_height: 300 # height of THOR window
29 | smooth_nav: False # smooth rotations, looks, and translations during navigation (very slow)
30 | save_frames_to_disk: False # save frame PNGs to disk (useful for making videos)
31 | save_frames_path: './videos/' # path to save frame PNGs
32 |
33 | controller:
34 | type: 'oracle' # 'oracle' or 'oracle_astar' or 'mrcnn' or 'mrcnn_astar' (aka BUTLER)
35 | debug: False
36 | load_receps: True # load receptacle locations from precomputed dict (if available)
37 |
38 | mask_rcnn:
39 | pretrained_model_path: '$ALFWORLD_DATA/detectors/mrcnn.pth'
40 |
41 | general:
42 | random_seed: 42
43 | use_cuda: True # disable this when running on machine without cuda
44 | visdom: False # plot training/eval curves, run with visdom server
45 | task: 'alfred'
46 | training_method: 'dagger' # 'dqn' or 'dagger'
47 | save_path: './training/' # path to save pytorch models
48 | observation_pool_capacity: 3 # k-size queue, 0 indicates no observation
49 | hide_init_receptacles: False # remove initial observation containing navigable receptacles
50 |
51 | training:
52 | batch_size: 10
53 | max_episode: 50000
54 | smoothing_eps: 0.1
55 | optimizer:
56 | learning_rate: 0.001
57 | clip_grad_norm: 5
58 |
59 | evaluate:
60 | run_eval: True
61 | batch_size: 10
62 | env:
63 | type: "AlfredTWEnv"
64 |
65 | checkpoint:
66 | report_frequency: 1000 # report every N episode
67 | experiment_tag: 'test' # name of experiment
68 | load_pretrained: False # during test, enable this so that the agent load your pretrained model
69 | load_from_tag: 'not loading anything' # name of pre-trained model to load in save_path
70 |
71 | model:
72 | encoder_layers: 1
73 | decoder_layers: 1
74 | encoder_conv_num: 5
75 | block_hidden_dim: 64
76 | n_heads: 1
77 | dropout: 0.1
78 | block_dropout: 0.1
79 | recurrent: True
80 |
81 | rl:
82 | action_space: "admissible" # 'admissible' (candidates from text engine) or 'generation' (seq2seq-style generation) or 'beam_search_choice' or 'exhaustive' (not working)
83 | max_target_length: 20 # max token length for seq2seq generation
84 | beam_width: 10 # 1 means greedy
85 | generate_top_k: 3
86 |
87 | training:
88 | max_nb_steps_per_episode: 50 # terminate after this many steps
89 | learn_start_from_this_episode: 0 # delay updates until this episode
90 | target_net_update_frequency: 500 # sync target net with online net per this many epochs
91 |
92 | replay:
93 | accumulate_reward_from_final: True
94 | count_reward_lambda: 0.0 # 0 to disable
95 | novel_object_reward_lambda: 0.0 # 0 to disable
96 | discount_gamma_game_reward: 0.9
97 | discount_gamma_count_reward: 0.5
98 | discount_gamma_novel_object_reward: 0.5
99 | replay_memory_capacity: 500000 # adjust this depending on your RAM size
100 | replay_memory_priority_fraction: 0.5
101 | update_per_k_game_steps: 5
102 | replay_batch_size: 64
103 | multi_step: 3
104 | replay_sample_history_length: 4
105 | replay_sample_update_from: 2
106 |
107 | epsilon_greedy:
108 | noisy_net: False # if this is true, then epsilon greedy is disabled
109 | epsilon_anneal_episodes: 1000 # -1 if not annealing
110 | epsilon_anneal_from: 0.3
111 | epsilon_anneal_to: 0.1
112 |
113 | dagger:
114 | action_space: "generation" # 'admissible' (candidates from text engine) or 'generation' (seq2seq-style generation) or 'exhaustive' (not working)
115 | max_target_length: 20 # max token length for seq2seq generation
116 | beam_width: 10 # 1 means greedy
117 | generate_top_k: 5
118 | unstick_by_beam_search: False # use beam-search for failed actions, set True during evaluation
119 |
120 | training:
121 | max_nb_steps_per_episode: 50 # terminate after this many steps
122 |
123 | fraction_assist:
124 | fraction_assist_anneal_episodes: 50000
125 | fraction_assist_anneal_from: 1.0
126 | fraction_assist_anneal_to: 0.01
127 |
128 | fraction_random:
129 | fraction_random_anneal_episodes: 0
130 | fraction_random_anneal_from: 0.0
131 | fraction_random_anneal_to: 0.0
132 |
133 | replay:
134 | replay_memory_capacity: 500000
135 | update_per_k_game_steps: 5
136 | replay_batch_size: 64
137 | replay_sample_history_length: 4
138 | replay_sample_update_from: 2
139 |
140 | vision_dagger:
141 | model_type: "resnet" # 'resnet' (whole image features) or 'maskrcnn_whole' (whole image MaskRCNN feats) or 'maskrcnn' (top k MaskRCNN detection feats) or 'no_vision' (zero vision input)
142 | resnet_fc_dim: 64
143 | maskrcnn_top_k_boxes: 10 # top k box features
144 | use_exploration_frame_feats: False # append feats from initial exploration (memory intensive!)
145 | sequence_aggregation_method: "average" # 'sum' or 'average' or 'rnn'
146 |
--------------------------------------------------------------------------------
/alfworld/generate_reflections.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["PYTHONUTF8"] = "1"
4 | import random
5 | from utils import ModelServer
6 |
7 | from typing import List, Dict, Any
8 |
9 | with open("./reflexion_few_shot_examples.txt", "r", encoding="utf-8") as f:
10 | FEW_SHOT_EXAMPLES = f.read()
11 |
12 |
13 | def _get_scenario(s: str) -> str:
14 | """Parses the relevant scenario from the experience log."""
15 | return s.split("Here is the task:")[-1].strip()
16 |
17 |
18 | def _generate_reflection_query(log_str: str, memory: List[str]) -> str:
19 | """Allows the Agent to reflect upon a past experience."""
20 | scenario: str = _get_scenario(log_str)
21 | query: str = f"""You will be given the history of a past experience in which you were placed in an environment and given a task to complete. You were unsuccessful in completing the task. Do not summarize your environment, but rather think about the strategy and path you took to attempt to complete the task. Devise a concise, new plan of action that accounts for your mistake with reference to specific actions that you should have taken. For example, if you tried A and B but forgot C, then devise a plan to achieve C with environment-specific actions. You will need this later when you are solving the same task. Remember that your plan must be very concise and your entire output must be limited to a single line of words. Here are two examples:
22 |
23 | {FEW_SHOT_EXAMPLES}
24 |
25 | Here is your failed experience:
26 | {scenario}"""
27 |
28 | query += "\n\nNew plan:"
29 | return query
30 |
31 |
32 | def update_memory(
33 | model_server: ModelServer,
34 | reflection_model_size: str,
35 | trial_log_path: str,
36 | env_configs: List[Dict[str, Any]],
37 | mem_size: int = -1,
38 | ) -> List[Dict[str, Any]]:
39 | """Updates the given env_config with the appropriate reflections."""
40 | with open(trial_log_path, "r", encoding="utf-8") as f:
41 | full_log: str = f.read()
42 |
43 | env_logs = full_log.split(
44 | "==============================================================\n\n\n\n=============================================================="
45 | )
46 |
47 | faile = []
48 | env_logs = env_logs[:-1]
49 | env_logs[0] = env_logs[0].strip(
50 | "\n\n==============================================================\n"
51 | )
52 | env_logs = [each.strip() for each in env_logs]
53 | assert all(
54 | ["Environment #" in env_logs[i] for i in range(len(env_logs))]
55 | ), "Parsing error"
56 | #! Do not modify the printed separator string; it is used to split the logs
57 | #! env_logs[0] needs special handling, and the original env_logs[-1] is the success-rate summary
58 | for i, env in enumerate(env_configs):
59 |
60 | # if unsolved, get reflection and update env config
61 | if not env["is_success"] and not env["skip"]:
62 | faile.append(i)
63 | assert mem_size != 0, "Memory size should not be 0."
64 | #! Memory selection for reflection is always FIFO
65 | if mem_size != -1 and len(env["memory"]) >= mem_size:
66 | memory: List[str] = env["memory"][-mem_size:]
67 | else:
68 | memory: List[str] = env["memory"]
69 | reflection_query: str = _generate_reflection_query(
70 | env_logs[i].strip(), memory
71 | )
72 | print("start reflection env:")
73 | print(i)
74 | messages = [{"role": "user", "content": reflection_query}]
75 | reflection: str = model_server.get_completion_or_embedding(reflection_model_size, messages) # type: ignore
76 | env_configs[i]["memory"] += [reflection]
77 |
78 | return env_configs #! the general-reflection logic has been removed
79 |
--------------------------------------------------------------------------------
/alfworld/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | from IPython import embed
3 |
4 | os.environ["PYTHONUTF8"] = "1"
5 | import json
6 | import argparse
7 | from utils import ModelServer
8 |
9 | from alfworld_trial import run_trial
10 |
11 | from generate_reflections import update_memory
12 | from typing import Any, List, Dict
13 | import json
14 |
15 |
16 | def get_args():
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument("--num_trials", type=int, help="The number of trials to run")
19 | parser.add_argument(
20 | "--num_envs", type=int, help="The number of environments per trial"
21 | )
22 | parser.add_argument(
23 | "--mem_size",
24 | type=int,
25 | help="The size of the memory that will be used in the memory. -1 means unlimited mem, 0 means 0 mem.",
26 | default=0,
27 | )
28 | parser.add_argument("--run_name", type=str, help="The name of the run")
29 | parser.add_argument("--is_resume", action="store_true", help="To resume run")
30 | parser.add_argument(
31 | "--resume_dir", type=str, help="If resume, the logging directory", default=""
32 | )
33 | parser.add_argument(
34 | "--start_trial_num",
35 | type=int,
36 | help="The start trial num, if not resume, should be 0.",
37 | default=0,
38 | )
39 | parser.add_argument(
40 | "--plan_model_size", type=str, help="only support 8, 70", default="70"
41 | )
42 | parser.add_argument(
43 | "--reflect_model_size", type=str, help="only support 8, 70", default="70"
44 | )
45 | parser.add_argument(
46 | "--online_embedding_model_size", type=str, help="only support 2, 7", default="2"
47 | )
48 | parser.add_argument(
49 | "--cluster_size",
50 | type=int,
51 | help="How many instances to run in each cluster",
52 | default=5,
53 | )
54 | parser.add_argument(
55 | "--specific_cluster_name",
56 | type=str,
57 | default=None,
58 | help="Whether to only run on a specific cluster of envs",
59 | )
60 | parser.add_argument(
61 | "--mem_selection_method",
62 | type=str,
63 | default="fifo",
64 | help="The method to select memory for planning, only fifo/fix/random.",
65 | )
66 | parser.add_argument(
67 | "--log_file_path",
68 | type=str,
69 | help="The path to the log file",
70 | default="",
71 | required=True,
72 | )
73 | parser.add_argument(
74 | "--use_success_trajectory",
75 | action="store_true",
76 | help="Whether to use success trajectories in planning base prompt.",
77 | default=False,
78 | )
79 |
80 | parser.add_argument(
81 | "--trajactory_search_method",
82 | type=str,
83 | default=None,
84 | help="The method to select memory for planning, only knn/random.",
85 | )
86 | parser.add_argument(
87 | "--in_context_trajactory_size",
88 | type=int,
89 | default=3,
90 | help="The size of in context trajactory size.",
91 | )
92 | args = parser.parse_args()
93 |
94 | assert args.log_file_path != "", "Log file path should be provided"
95 | assert args.num_trials > 0, "Number of trials should be positive"
96 | assert args.num_envs > 0, "Number of environments should be positive"
97 | assert (
98 | args.resume_dir == args.run_name or args.resume_dir == ""
99 | ), "Should resume from previous directory"
100 | assert args.mem_selection_method in [
101 | "fifo",
102 | "fix",
103 | "random",
104 | ], "Invalid memory selection method"
105 | assert args.specific_cluster_name in [
106 | None,
107 | "clean",
108 | "put",
109 | "heat",
110 | "cool",
111 | "examine",
112 | "puttwo",
113 | ], "Invalid cluster name used"
114 | assert (
115 | not args.use_success_trajectory
116 | and args.trajactory_search_method is None
117 | and args.in_context_trajactory_size == 0
118 | ) or (
119 | args.use_success_trajectory
120 | and args.trajactory_search_method in ["knn", "random"]
121 | and args.in_context_trajactory_size > 0
122 | ), "If use_success_trajectory is False, then in_context_trajactory_size should be 0."
123 | assert args.trajactory_search_method in [
124 | "knn",
125 | "random",
126 | None,
127 | ], "Invalid search method"
128 | assert args.plan_model_size in ["8", "70"], "Invalid plan model size"
129 | assert args.reflect_model_size in ["8", "70"], "Invalid reflect model size"
130 | assert args.online_embedding_model_size in [
131 | "2",
132 | "7",
133 | ], "Invalid online embedding size"
134 | return args
135 |
136 |
137 | def main(args) -> None:
138 | if args.is_resume:
139 | if not os.path.exists(args.resume_dir):
140 | raise ValueError(f"Resume directory `{args.resume_dir}` does not exist")
141 | logging_dir = args.resume_dir
142 |
143 | # load previous environment configs
144 | env_config_path: str = os.path.join(
145 | args.resume_dir, f"env_results_trial_{args.start_trial_num - 1}.json"
146 | )
147 | if not os.path.exists(env_config_path):
148 | raise ValueError(
149 | f"Environment config file `{env_config_path}` does not exist"
150 | )
151 | with open(env_config_path, "r", encoding="utf-8") as rf:
152 | env_configs: List[Dict[str, Any]] = json.load(rf)
153 | else:
154 | # Create new run directory
155 | if not os.path.exists(args.run_name):
156 | os.makedirs(args.run_name)
157 | logging_dir = args.run_name
158 |
159 | # initialize environment configs
160 | env_configs: List[Dict[str, Any]] = []
161 |
162 | env_configs = [
163 | {
164 | "name": f"env_{i}",
165 | "memory": [],
166 | "principle": [],
167 | "is_success": False,
168 | "skip": False,
169 | "cluster": "",
170 | }
171 | for i in range(args.num_envs)
172 | ]
173 | config_path = os.path.join(logging_dir, "config.json")
174 | model_server = ModelServer(config_path=config_path)
175 | world_log_path: str = os.path.join(logging_dir, "world.log")
176 |     open(world_log_path, "a", encoding="utf-8").close()  # portable replacement for `touch`
177 | with open(world_log_path, "a", encoding="utf-8") as wf:
178 | if args.is_resume:
179 | info_string = f"""
180 |
181 | ==============================================================
182 | RESUME
183 | Run name: {args.run_name}
184 | Number of trials: {args.num_trials}
185 | Number of max environments: {args.num_envs}
186 | Number of instances per cluster: {args.cluster_size}
187 | Resume trial number: {args.start_trial_num}
188 | Sending all logs to: {args.run_name}
189 | Setting memory size to: {args.mem_size}
190 | Plan model size: {args.plan_model_size}
191 | Reflect model size: {args.reflect_model_size}
192 | Online embedding size: {args.online_embedding_model_size}
193 | Specific cluster name: {args.specific_cluster_name}
194 | Log file path: {args.log_file_path}
195 | Use success trajectory: {args.use_success_trajectory}
196 | Trajactory search method: {args.trajactory_search_method}
197 | In context trajactory size: {args.in_context_trajactory_size}
198 | ==============================================================
199 |
200 | """
201 | wf.write(info_string)
202 | else:
203 | info_string = f"""
204 |
205 | ==============================================================
206 | START
207 | Run name: {args.run_name}
208 | Number of trials: {args.num_trials}
209 | Number of max environments: {args.num_envs}
210 | Number of instances per cluster: {args.cluster_size}
211 | Start trial number: {args.start_trial_num}
212 | Sending all logs to: {args.run_name}
213 | Setting memory size to: {args.mem_size}
214 | Plan model size: {args.plan_model_size}
215 | Reflect model size: {args.reflect_model_size}
216 | Online embedding size: {args.online_embedding_model_size}
217 | Specific cluster name: {args.specific_cluster_name}
218 | Log file path: {args.log_file_path}
219 | Use success trajectory: {args.use_success_trajectory}
220 | Trajactory search method: {args.trajactory_search_method}
221 | In context trajactory size: {args.in_context_trajactory_size}
222 | ==============================================================
223 |
224 | """
225 | wf.write(info_string)
226 |
227 | with open(config_path, "w", encoding="utf-8") as wf:
228 | info_dict = vars(args)
229 | info_dict["is_running"] = True
230 | json.dump(info_dict, wf, indent=4)
231 |
232 | trial_idx = args.start_trial_num
233 | cluster_counter = {}
234 |
235 | while trial_idx < args.num_trials:
236 | with open(world_log_path, "a", encoding="utf-8") as wf:
237 | wf.write(
238 | f"""
239 |
240 | ==============================================================
241 | Start Trial #{trial_idx}
242 | ==============================================================
243 |
244 | """
245 | )
246 |
247 | # set paths to log files
248 | trial_log_path: str = os.path.join(args.run_name, f"trial_{trial_idx}.log")
249 |         open(trial_log_path, "a", encoding="utf-8").close()  # portable replacement for `touch`
250 | trial_env_configs_log_path: str = os.path.join(
251 | args.run_name, f"env_results_trial_{trial_idx}.json"
252 | )
253 | if os.path.exists(trial_log_path):
254 | open(trial_log_path, "w").close()
255 | if os.path.exists(trial_env_configs_log_path):
256 | open(trial_env_configs_log_path, "w").close()
257 |
258 | env_configs, cluster_counter = run_trial(
259 | args.cluster_size,
260 | cluster_counter,
261 | model_server,
262 | args.plan_model_size,
263 | trial_log_path,
264 | world_log_path,
265 | trial_idx,
266 | env_configs,
267 | args.mem_size,
268 | specific_cluster_name=args.specific_cluster_name,
269 | mem_selection_method=args.mem_selection_method,
270 | use_success_trajectory=args.use_success_trajectory,
271 | trajactory_search_method=args.trajactory_search_method,
272 | in_context_trajactory_size=args.in_context_trajactory_size,
273 | online_embedding_model_size=args.online_embedding_model_size,
274 | )
275 |
276 |
277 | if args.mem_size != 0:
278 |
279 | env_configs: List[Dict[str, Any]] = update_memory(
280 | model_server,
281 | args.reflect_model_size,
282 | trial_log_path,
283 | env_configs,
284 | mem_size=args.mem_size,
285 | )
286 |
287 | # log env configs for trial
288 | with open(trial_env_configs_log_path, "w", encoding="utf-8") as wf:
289 | json.dump(env_configs, wf, indent=4)
290 |
291 | # log world for trial
292 | with open(world_log_path, "a", encoding="utf-8") as wf:
293 | wf.write(
294 | f"""
295 |
296 | ==============================================================
297 | End Trial #{trial_idx}
298 | ==============================================================
299 |
300 | """
301 | )
302 | trial_idx += 1
303 |
304 |
305 | if __name__ == "__main__":
306 | args = get_args()
307 | main(args)
308 |
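309 | # Example invocation (illustrative placeholder values; flags mirror get_args() above):
310 | #   python main.py --num_trials 10 --num_envs 134 --run_name alfworld_run \
311 | #     --plan_model_size 70 --reflect_model_size 70 --mem_size 3 \
312 | #     --log_file_path alfworld_run/run.log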
--------------------------------------------------------------------------------
/alfworld/prompts/fever.json:
--------------------------------------------------------------------------------
1 | {"webact_simple3": "\nDetermine if there is Observation that SUPPORTS or REFUTES a Claim, or if there is NOT ENOUGH INFORMATION. \nClaim: Nikolaj Coster-Waldau worked with the Fox Broadcasting Company.\nAction 1: Search[Nikolaj Coster-Waldau]\nObservation 1: Nikolaj William Coster-Waldau (born 27 July 1970) is a Danish actor and producer. He graduated from the Danish National School of Performing Arts in Copenhagen in 1993,[1] and had his breakthrough role in Denmark with the film Nightwatch (1994). He played Jaime Lannister in the HBO fantasy drama series Game of Thrones, for which he received two Primetime Emmy Award nominations for Outstanding Supporting Actor in a Drama Series.. Coster-Waldau has appeared in numerous films in his native Denmark and Scandinavia, including Headhunters (2011) and A Thousand Times Good Night (2013). In the U.S, his debut film role was in the war film Black Hawk Down (2001), playing Medal of Honor recipient Gary Gordon.[2] He then played a detective in the short-lived Fox television series New Amsterdam (2008), and appeared in the 2009 Fox television film Virtuality, originally intended as a pilot.\nAction 2: Finish[SUPPORTS]\n\nClaim: Stranger Things is set in Bloomington, Indiana.\nAction 1: Search[Stranger Things]\nObservation 1: Stranger Things is an American science fiction horror drama television series created by the Duffer Brothers. Set in the 1980s, primarily in the fictional town of Hawkins, Indiana, the series centers on a number of mysteries and supernatural events occurring around the town and their impact on an ensemble of child and adult characters. \nAction 2: Finish[REFUTES]\n\nClaim: Beautiful reached number two on the Billboard Hot 100 in 2003.?\nAction 1: Search[Beautiful]\nObservation 1: Could not find [Beautiful]. Similar: ['Beautiful', 'Beautiful, Beautiful', 'A Beautiful Mind (film)', 'Beautiful (Christina Aguilera song)', 'Life Is Beautiful'].\nAction 2: Search[Beautiful (Christina Aguilera song)]\nObservation 2: \"Beautiful\" is a song recorded by American singer Christina Aguilera for her fourth studio album, Stripped (2002).\nAction 3: Lookup[Billboard Hot 100]\nObservation 3: (Result 1 / 3) The song peaked at number two on the Billboard Hot 100 in the United States, where it was certified Gold for 500,000 units shipped.\nAction 4: Finish[NOT ENOUGH INFO]\n\n", "cotqa_simple3": "Determine if there is Observation that SUPPORTS or REFUTES a Claim, or if there is NOT ENOUGH INFORMATION. \nClaim: Nikolaj Coster-Waldau worked with the Fox Broadcasting Company.\nThought: Nikolaj William Coster-Waldau appeared in the 2009 Fox television film Virtuality, so he has worked with the Fox Broadcasting Company.\nAnswer: SUPPORTS\n\nClaim: Stranger Things is set in Bloomington, Indiana.\nThought: Stranger Things is in the fictional town of Hawkins, Indiana, not in Bloomington, Indiana.\nAnswer:REFUTES\n\nClaim: Beautiful reached number two on the Billboard Hot 100 in 2003.?\nThought: The song peaked at number two on the Billboard Hot 100 in the United States, but not sure if it was in 2003.\nAnswer: NOT ENOUGH INFO\n", "webqa_simple3": "Determine if there is Observation that SUPPORTS or REFUTES a Claim, or if there is NOT ENOUGH INFORMATION. 
\nClaim: Nikolaj Coster-Waldau worked with the Fox Broadcasting Company.\nAnswer: SUPPORTS\n\nClaim: Stranger Things is set in Bloomington, Indiana.\nAnswer:REFUTES\n\nClaim: Beautiful reached number two on the Billboard Hot 100 in 2003.?\nAnswer: NOT ENOUGH INFO\n", "webthink_simple3": "\nDetermine if there is Observation that SUPPORTS or REFUTES a Claim, or if there is NOT ENOUGH INFORMATION. \nClaim: Nikolaj Coster-Waldau worked with the Fox Broadcasting Company.\nThought 1: I need to search Nikolaj Coster-Waldau and find if he has worked with the Fox Broadcasting Company.\nAction 1: Search[Nikolaj Coster-Waldau]\nObservation 1: Nikolaj William Coster-Waldau (born 27 July 1970) is a Danish actor and producer. He graduated from the Danish National School of Performing Arts in Copenhagen in 1993,[1] and had his breakthrough role in Denmark with the film Nightwatch (1994). He played Jaime Lannister in the HBO fantasy drama series Game of Thrones, for which he received two Primetime Emmy Award nominations for Outstanding Supporting Actor in a Drama Series.. Coster-Waldau has appeared in numerous films in his native Denmark and Scandinavia, including Headhunters (2011) and A Thousand Times Good Night (2013). In the U.S, his debut film role was in the war film Black Hawk Down (2001), playing Medal of Honor recipient Gary Gordon.[2] He then played a detective in the short-lived Fox television series New Amsterdam (2008), and appeared in the 2009 Fox television film Virtuality, originally intended as a pilot.\nThought 2: Because he \"appeared in the 2009 Fox television film Virtuality\", he should have worked with the Fox Broadcasting Company.\nAction 2: Finish[SUPPORTS]\n\nClaim: Stranger Things is set in Bloomington, Indiana.\nThought 1: I should search for Stranger Things, and see if it is set in Bloomington, Indiana.\nAction 1: Search[Stranger Things]\nObservation 1: Stranger Things is an American science fiction horror drama television series created by the Duffer Brothers. Set in the 1980s, primarily in the fictional town of Hawkins, Indiana, the series centers on a number of mysteries and supernatural events occurring around the town and their impact on an ensemble of child and adult characters. \nThought 2: The observation says that it is set in a \"fictional town of Hawkins, Indiana\", so it is not set in Bloomington.\nAction 2: Finish[REFUTES]\n\nClaim: Beautiful reached number two on the Billboard Hot 100 in 2003.?\nThought 1: I need to search the song Beautiful and find if it reached number two on the Billboard Hot 100 in 2003.\nAction 1: Search[Beautiful]\nObservation 1: Could not find [Beautiful]. Similar: ['Beautiful', 'Beautiful, Beautiful', 'A Beautiful Mind (film)', 'Beautiful (Christina Aguilera song)', 'Life Is Beautiful'].\nThought 2: From suggestions, I should search \"Beautiful (Christina Aguilera song)\" to find the song.\nAction 2: Search[Beautiful (Christina Aguilera song)]\nObservation 2: \"Beautiful\" is a song recorded by American singer Christina Aguilera for her fourth studio album, Stripped (2002).\nThought 3: It does not mention Billboard, so I need to look up \"Billboard Hot 100\" to find if it reached number two on it in 2003.\nAction 3: Lookup[Billboard Hot 100]\nObservation 3: (Result 1 / 3) The song peaked at number two on the Billboard Hot 100 in the United States, where it was certified Gold for 500,000 units shipped.\nThought 4: It only says the song peaked at number two on the Billboard Hot 100, but not if it was in 2003. 
I am not sure if this claim is true or not.\nAction 4: Finish[NOT ENOUGH INFO]\n\n"}
--------------------------------------------------------------------------------
/alfworld/requirements.txt:
--------------------------------------------------------------------------------
1 | faiss-cpu  # published on PyPI as faiss-cpu / faiss-gpu, not "faiss"
2 | openai
3 | alfworld
--------------------------------------------------------------------------------
/alfworld/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
4 | os.environ["PYTHONUTF8"] = "1"
5 | import time
6 | from typing import Dict, List, Optional
7 | import openai
8 | import random, json
9 | from IPython import embed
10 |
11 | import math
12 |
13 |
14 | class EnvironmentHistory:
15 | def __init__(
16 | self,
17 | base_query: str,
18 | start_info,
19 | memory: List[str],
20 |         history: Optional[List[Dict[str, str]]] = None,  # None avoids a shared mutable default
21 | ) -> None:
22 |
23 | def _get_base_query(base_query: str, start_info: str, memory: List[str]) -> str:
24 | query = base_query
25 |
26 | # add memory if it exists
27 |
28 | query += f"\nHere is the task:\n{start_info}"
29 | if len(memory) > 0:
30 | query += "\n\nBelow are your reflection memory for the task, you should apply them wisely in your planning:\n[memory start]\n"
31 | for i, m in enumerate(memory):
32 | query += f"\nReflection from Trial {i}:\n{m.strip()}"
33 | query += "\n[memory end]\n"
34 | return query
35 |
36 | self._cur_query: str = f"{_get_base_query(base_query, start_info, memory)}"
37 |         self._history: List[Dict[str, str]] = history if history is not None else []
38 | self._last_action: str = ""
39 | self._is_exhausted: bool = False
40 |
41 | def add(self, label: str, value: str) -> None:
42 | assert label in ["action", "observation", "human_edit"]
43 | self._history += [
44 | {
45 | "label": label,
46 | "value": value,
47 | }
48 | ]
49 | if label == "action":
50 | if value == self._last_action:
51 | self._is_exhausted = True
52 | else:
53 | self._last_action = value
54 |
55 | def check_is_exhausted(self) -> bool:
56 | return self._is_exhausted
57 |
58 | def reset(self) -> None:
59 | self._history = []
60 |
61 | def __str__(self) -> str:
62 | s: str = self._cur_query + "\n"
63 | for i, item in enumerate(self._history):
64 | if item["label"] == "action":
65 | s += f'> {item["value"]}'
66 | elif item["label"] == "observation":
67 | s += item["value"]
68 | # NOT CURRENTLY SUPPORTED
69 | elif item["label"] == "human_edit":
70 | s += f'[human edit]: {item["value"]}'
71 | if i != len(self._history) - 1:
72 | s += "\n"
73 | return s
74 |
75 |
76 | class ModelServer:
77 |     # The original class body referenced `client` and `model_name` without
78 |     # defining them. As a hedged sketch, the table below maps each model size
79 |     # to an OpenAI-compatible endpoint (per the README, endpoints are meant to
80 |     # be set up here); every URL and model name is a placeholder for your own servers.
81 |     _ENDPOINTS = {
82 |         "70": ("http://localhost:30070/v1", "llama-3.1-70b-instruct"),
83 |         "8": ("http://localhost:30008/v1", "llama-3.1-8b-instruct"),
84 |         "7": ("http://localhost:30007/v1", "gte-qwen2-7b-instruct"),
85 |         "2": ("http://localhost:30002/v1", "gte-qwen2-1.5b-instruct"),
86 |     }
87 |
88 |     def __init__(self, config_path: str) -> None:
89 |         self.config_path = config_path  # run config written by main.py
90 |
91 |     def get_completion_or_embedding(
92 |         self,
93 |         model_size: str,
94 |         message,
95 |         temperature: float = 0.0,
96 |         max_tokens: int = 256,
97 |         get_embedding: bool = False,
98 |     ):
99 |         # "2" is accepted as well: main.py's --online_embedding_model_size
100 |         # defaults to "2", which the original assert would have rejected.
101 |         assert model_size in ["70", "8", "7", "2"]
102 |         base_url, model_name = self._ENDPOINTS[model_size]
103 |         client = openai.OpenAI(base_url=base_url, api_key="EMPTY")
104 |         if not get_embedding:
105 |             assert isinstance(message, list), "Message should be a list."
106 |             response = client.chat.completions.create(
107 |                 model=model_name,
108 |                 messages=message,
109 |                 max_tokens=max_tokens,
110 |                 temperature=temperature,
111 |                 stop=["<|eot_id|>"],
112 |             )
113 |             return str(response.choices[0].message.content)
114 |         assert isinstance(message, str), "Message should be a string."
115 |         response = client.embeddings.create(model=model_name, input=message)
116 |         return response.data[0].embedding
117 |
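118 | # Minimal usage sketch (the placeholder endpoints above must be live):
119 | # server = ModelServer(config_path="alfworld_run/config.json")
120 | # text = server.get_completion_or_embedding("70", [{"role": "user", "content": "hi"}])
121 | # vec = server.get_completion_or_embedding("2", "embed this sentence", get_embedding=True)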
--------------------------------------------------------------------------------
/figures/illusion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uclaml/COPS/989b29882a4ee6cdca828c402268216e4c5f5d85/figures/illusion.png
--------------------------------------------------------------------------------
/hotpotqa/README:
--------------------------------------------------------------------------------
1 | # HotPotQA experiment code
2 |
3 | Code modified from [LATS](https://github.com/lapisrocks/LanguageAgentTreeSearch)
4 |
5 | Set up your local LLM behind an OpenAI-API-compatible endpoint, then point utils.py at it. We recommend [SGLang](https://github.com/sgl-project/sglang).
6 |
7 | Example command:
8 |
9 | python run.py --model_size 8 --temperature 0.0 --task_start_index 0 --task_end_index 100 --cot_size 5 \
10 | --prompt_sample 'cot' --n_generate_sample 1 --n_evaluate_sample 1 \
11 | --iteration 10 --algorithm 'tot' --cot_method 'knn' --run_name 'hotpotqa_new'
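12 |
13 | For example, an OpenAI-compatible endpoint can be launched with SGLang (the model path below is only illustrative):
14 |
15 | python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000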
--------------------------------------------------------------------------------
/hotpotqa/base.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | class Task:
5 |     def __init__(self):
6 |         pass
7 |
8 |     def __len__(self) -> int:
9 |         raise NotImplementedError
10 |
11 |     def get_input(self, idx: int) -> str:
12 |         raise NotImplementedError
13 |
14 |     def test_output(self, idx: int, output: str):
15 |         raise NotImplementedError
16 |
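17 | # Subclasses (e.g. HotPotQATask in hotpotqa.py) are expected to supply
18 | # self.data and override __len__, get_input, and test_output.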
--------------------------------------------------------------------------------
/hotpotqa/hotpotqa.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | from base import Task
4 | from hotpot import *
5 | from models import gpt
6 | import logging
7 | from transformers import GPT2Tokenizer
8 | import random
9 |
10 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
11 |
12 |
13 | def get_token_length(text):
14 | return len(tokenizer.encode(text))
15 |
16 |
17 | max_token_length = 4000
18 |
19 | # Configuring the logging
20 |
21 |
22 | class HotPotQATask(Task):
23 | """
24 | Input (x) : a text instruction
25 | Output (y) : a text generation
26 | Reward (r) : # TODO
27 | Input Example:
28 | Output Example:
29 | """
30 |
31 | def __init__(self):
32 | """
33 | file: a text file, each line is some sentences
34 | """
35 | super().__init__()
36 | self.steps = 7
37 | self.stops = ["\nObservation:\n", None]
38 | self.value_cache = {}
39 |
40 | def __len__(self) -> int:
41 | return len(self.data)
42 |
43 | def get_input(self, idx: int) -> str:
44 | return self.data[idx]
45 |
46 | def test_output(self, idx: int, output: str):
47 | output = output.split("Action:\n")[-1]
48 | prompt = score_prompt + output
49 |         score_outputs = gpt(prompt, n=5)  # NOTE: gpt() takes model_size, not model; the original passed model="gpt-4"
50 | scores = []
51 | for score_output in score_outputs:
52 | # print(score_output)
53 | pattern = r".*correctness score is (\d+).*"
54 | match = re.match(pattern, score_output, re.DOTALL)
55 | if match:
56 | score = int(match.groups()[0])
57 | scores.append(score)
58 | else:
59 | print(f"------------------score no match: {[score_output]}")
60 | print(scores)
61 | # print('------------')
62 | info = {"rs": scores, "r": sum(scores) / len(scores) if scores else 0}
63 | return info
64 |
65 | @staticmethod
66 | def standard_prompt_wrap(x: str, y: str = "") -> str:
67 | return standard_prompt.format(input=x) + y
68 |
69 | @staticmethod
70 | def generate_self_reflection(z, question):
71 | reflection_mapping = []
72 | trajectories = ""
73 |
74 |         sampled_items = random.sample(z, min(3, len(z)))  # NOTE: unused; the loop below iterates over all of z
75 | failed_trajectories = "\n".join([f"{question}\n{traj}\n" for traj in z])
76 | failed_trajectories = [
77 | f"Question: {traj}" for traj in failed_trajectories.split("Question: ")[1:]
78 | ]
79 |
80 | for traj in failed_trajectories:
81 | trajectories += traj
82 |
83 | reflect_prompt = reflection_prompt.format(trajectory=traj)
84 |
85 | reflection = gpt(reflect_prompt)
86 |
87 | trajectories += "Reflection: " + reflection[0] + "\n"
88 |
89 | reflection_mapping.append(
90 | {"question": question, "trajectory": traj, "reflection": reflection[0]}
91 | )
92 |
93 | return reflection_mapping
94 |
95 | @staticmethod
96 | def cot_prompt_wrap(x: str, y: str = "", reflection_mapping_list=[], knn=None):
97 | question = x
98 | input = x + y
99 | trajectories = ""
100 | # z = []
101 | if reflection_mapping_list:
102 | for reflection_mapping in reflection_mapping_list:
103 | traj_with_reflection = (
104 | reflection_mapping["trajectory"]
105 | + "FAILED TRAJECTORY\nReflection: "
106 | + reflection_mapping["reflection"]
107 | + "\n\n"
108 | )
109 | trajectories += traj_with_reflection
110 |
111 | prompt = cot_prompt_feedback.format(trajectories=trajectories, input=input)
112 | if get_token_length(prompt) > max_token_length:
113 | print("Too long")
114 | trajectories = ""
115 | for reflection_mapping in reflection_mapping_list[:3]:
116 | traj_with_reflection = (
117 | reflection_mapping["trajectory"]
118 | + "FAILED TRAJECTORY\nReflection: "
119 | + reflection_mapping["reflection"]
120 | + "\n\n"
121 | )
122 | trajectories += traj_with_reflection
123 | prompt = cot_prompt_feedback_short.format(
124 | trajectories=trajectories, input=input
125 | )
126 |
127 | return prompt
128 | else:
129 | prompt = cot_prompt
130 | if knn:
131 | for traj in knn:
132 | prompt += traj + "\n"
133 | prompt += "\n" + input + "\n"
134 | # if get_token_length(prompt) > max_token_length:
135 | # prompt = cot_prompt_short.format(input=input)
136 | return prompt
137 |
138 | @staticmethod
139 | def vote_prompt_wrap(x: str, ys: list) -> str:
140 | prompt = vote_prompt + "\n" + x + "\n\n"
141 | for i, y in enumerate(ys, 1):
142 | # y = y.replace('Plan:\n', '')
143 | # TODO: truncate the plan part?
144 | prompt += f"Choice {i}:\n{y}\n"
145 | return prompt
146 |
147 | @staticmethod
148 | def vote_outputs_unwrap(vote_outputs: list, n_candidates: int) -> list:
149 | vote_results = [0] * n_candidates
150 | for vote_output in vote_outputs:
151 | pattern = r".*best trajectory is .*(\d+).*"
152 | match = re.match(pattern, vote_output, re.DOTALL)
153 | if match:
154 | vote = int(match.groups()[0]) - 1
155 | if vote in range(n_candidates):
156 | vote_results[vote] += 1
157 | else:
158 | print(f"vote no match: {[vote_output]}")
159 | return vote_results
160 |
161 | @staticmethod
162 | def compare_prompt_wrap(x: str, ys: list) -> str:
163 | assert len(ys) == 2, "compare prompt only supports 2 candidates"
164 |
165 | # Extract the last Action for each trajectory
166 | last_actions = []
167 | for y in ys:
168 | # Split by line and reverse to start from the end
169 | lines = y.split("\n")[::-1]
170 | for line in lines:
171 | # Check for an Action line and get its content
172 | if "Action" in line:
173 | last_actions.append(line.split("Action")[-1].strip(": "))
174 | break
175 |
176 | assert len(last_actions) == 2, "Expected to find 2 Actions"
177 |
178 | # Construct the prompt with the extracted Actions
179 | prompt = (
180 | compare_prompt
181 | + f"Action 1:{last_actions[0]}\n\nAction 2:{last_actions[1]}\n"
182 | )
183 | return prompt
184 |
185 | @staticmethod
186 | def compare_output_unwrap(compare_output: str):
187 | if "more correct trajectory is 1" in compare_output:
188 | return 0
189 | elif "more correct trajectory is 2" in compare_output:
190 | return 1
191 | elif "two trajectories are similarly correct" in compare_output:
192 | return 0.5
193 | else:
194 | print(f"-----------------compare no match: {[compare_output]}")
195 | return -1
196 |
197 | @staticmethod
198 | def value_prompt_wrap(x: str, y: str, z: list = [], reflections: list = []) -> str:
199 | question = x.split("\n")[0]
200 | # z = []
201 | if len(z) != 0:
202 | failed_trajectories = ""
203 |
204 | # Combine the trajectories with their corresponding reflections
205 | for traj, ref in zip(z, reflections):
206 | failed_trajectories += f"{question}\n{traj}\nThis trajectory is incorrect as {ref['reflection']}\nThus the correctness score is 1\n"
207 |
208 | inp = x + y + "\nThis trajectory is "
209 |
210 | prompt = value_prompt_reasoning_feedback.format(
211 | s="", trajectories=failed_trajectories, input=inp
212 | )
213 |
214 | if get_token_length(prompt) > max_token_length:
215 | prompt = value_prompt_reasoning_feedback_short.format(
216 | s="", trajectories=failed_trajectories, input=inp
217 | )
218 |         if len(z) != 0 and False:  # deliberately disabled branch, kept for reference
219 | failed_trajectories = "\n".join(
220 | [f"{question}\n{traj}\nThus the correctness score is 1\n" for traj in z]
221 | )
222 | inp = x + y + "\nThus the correctness score is "
223 | prompt = value_prompt_feedback.format(
224 | s="", trajectories=failed_trajectories, input=inp
225 | )
226 | if get_token_length(prompt) > max_token_length:
227 | print("Token length exceeded", get_token_length(prompt))
228 | failed_trajectories = "\n".join(
229 | [
230 | f"{question}\n{traj}\nThus the correctness score is 1\n"
231 | for traj in z[:2]
232 | ]
233 | )
234 | inp = y + "\nThus the correctness score is "
235 | prompt = value_prompt_feedback.format(
236 | s="", trajectories=failed_trajectories, input=inp
237 | )
238 | print("New length", get_token_length(prompt))
239 | else:
240 | inp = y + "\nThis trajectory is "
241 | # inp = y + "\nThus the correctess score is "
242 | # prompt = value_prompt.format(s="", input=inp)
243 | prompt = value_prompt_reasoning.format(s="", input=inp)
244 |
245 | return prompt
246 |
247 | @staticmethod
248 | def value_outputs_unwrap(evaluate_prompt: str):
249 | evaluate_prompt = evaluate_prompt[0]
250 | if "10" in evaluate_prompt:
251 | return 1.0
252 | elif "9" in evaluate_prompt:
253 | return 0.9
254 | elif "8" in evaluate_prompt:
255 | return 0.8
256 | elif "7" in evaluate_prompt:
257 | return 0.7
258 | elif "6" in evaluate_prompt:
259 | return 0.6
260 | elif "5" in evaluate_prompt:
261 | return 0.5
262 | elif "4" in evaluate_prompt:
263 | return 0.4
264 | elif "3" in evaluate_prompt:
265 | return 0.3
266 | elif "2" in evaluate_prompt:
267 | return 0.2
268 | elif "1" in evaluate_prompt:
269 | return 0.1
270 | else:
271 | return -1
272 |
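273 | # Illustrative check of value_outputs_unwrap's substring cascade (assuming the
274 | # evaluator emits text like "the correctness score is 8"): because "10" is
275 | # tested before "1", a score of 10 maps to 1.0 rather than 0.1.
276 | # assert HotPotQATask.value_outputs_unwrap(["correctness score is 10"]) == 1.0
277 | # assert HotPotQATask.value_outputs_unwrap(["correctness score is 8"]) == 0.8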
--------------------------------------------------------------------------------
/hotpotqa/models.py:
--------------------------------------------------------------------------------
1 | import openai
2 | from utils import ModelServer
3 | import time
4 | import re
5 |
6 | server = None
7 |
8 |
9 | def build_server(config_path):
10 | global server
11 | server = ModelServer(config_path=config_path)
12 |
13 |
14 | def split_and_keep_prefixes(s, delimiters):
15 | regex_pattern = f"({'|'.join(map(re.escape, delimiters))})"
16 | parts = re.split(regex_pattern, s)
17 | result = [parts[0]]
18 | for i in range(1, len(parts), 2):
19 | result.append(parts[i] + (parts[i + 1] if i + 1 < len(parts) else ""))
20 | return result
21 |
22 |
23 | def online_embed(traj):
24 | return server.get_completion_or_embedding(
25 | "7",
26 | message=traj,
27 | get_embedding=True,
28 | )
29 |
30 |
31 | def gpt(prompt, model_size="8", temperature=1.0, max_tokens=100, n=1) -> list:
32 | def call_openai_api(messages, model_size, temperature, max_tokens, n):
33 | outputs = []
34 | while n > 0:
35 | cnt = min(n, 20)
36 | n -= cnt
37 | res = server.get_completion_or_embedding(
38 | model_size, messages, temperature, max_tokens
39 | )
40 | outputs.extend(
41 | [
42 | re.sub(r"^Thought \d+: ", "", choice.message.content)
43 | for choice in res.choices
44 | ]
45 | )
46 | return outputs
47 |
48 | messages = []
49 | parts = re.split(r"(Thought \d+:|Observation \d+:|Question:)", prompt)
50 |
51 | result = [parts[0].strip()]
52 | for i in range(1, len(parts), 2):
53 | if i + 1 < len(parts):
54 | result.append(parts[i] + " " + parts[i + 1].strip())
55 | result.pop()
56 |
57 | for msg in result:
58 | if msg.startswith("Solve"):
59 | messages.append({"role": "user", "content": msg})
60 | if msg.startswith("Thought"):
61 | messages.append({"role": "assistant", "content": msg})
62 | if msg.startswith("Observation"):
63 | messages.append({"role": "user", "content": msg})
64 | if msg.startswith("Question"):
65 | messages.append({"role": "user", "content": msg})
66 |
67 | return call_openai_api(
68 | messages, model_size=model_size, temperature=temperature, max_tokens=max_tokens, n=n
69 | )
70 |
71 |
72 | reflection_prompt = """You are an advanced reasoning agent that can improve based on self reflection. You will be given a previous reasoning trial in which you were given access to a Docstore API environment and a question to answer. You were unsuccessful in answering the question either because you guessed the wrong answer with Finish[], or you used up your set number of reasoning steps. In a few sentences, diagnose a possible reason for failure and devise a new, concise, high level plan that aims to mitigate the same failure. Use complete sentences.
73 | Here are some examples:
74 | Previous Trial:
75 | Question: The Rome Protocols were signed by three Prime Ministers one of which was assassinated as part of what?
76 | Thought 1: I need to search Rome Protocols, find the three Prime Ministers, then find what they were assassinated as part of.
77 | Action 1: Search[Rome Protocols]
78 | Observation 1: The Rome Protocols were a series of three international agreements concluded in Rome on 17 March 1934 between the governments of Austria, Hungary and Italy. They were signed by Italian Prime Minister Benito Mussolini, Austrian Chancellor Engelbert Dollfuss and Hungarian Prime Minister Gyula Gömbös. All of the protocols went into effect on 12 July 1934 and were registered in League of Nations Treaty Series on 12 December 1934.
79 | Thought 2: The three Prime Ministers were Benito Mussolini, Engelbert Dollfuss, and Gyula Gömbös. I need to search Benito Mussolini and find what he was assassinated as part of.
80 | Action 2: Search[Benito Mussolini]
81 | Observation 2: Benito Amilcare Andrea Mussolini (UK: MU(U)SS-ə-LEE-nee, US: MOOSS-, Italian: [beˈniːto aˈmilkare anˈdrɛːa mussoˈliːni]; 29 July 1883 – 28 April 1945) was an Italian politician and journalist who founded and led the National Fascist Party (PNF). He was Prime Minister of Italy from the March on Rome in 1922 until his deposition in 1943, as well as "Duce" of Italian fascism from the establishment of the Italian Fasces of Combat in 1919 until his summary execution in 1945 by Italian partisans. As dictator of Italy and principal founder of fascism, Mussolini inspired and supported the international spread of fascist movements during the inter-war period.Mussolini was originally a socialist politician and a journalist at the Avanti! newspaper. In 1912, he became a member of the National Directorate of the Italian Socialist Party (PSI), but he was expelled from the PSI for advocating military intervention in World War I, in opposition to the party's stance on neutrality. In 1914, Mussolini founded a new journal, Il Popolo d'Italia, and served in the Royal Italian Army during the war until he was wounded and discharged in 1917. Mussolini denounced the PSI, his views now centering on Italian nationalism instead of socialism, and later founded the fascist movement which came to oppose egalitarianism and class conflict, instead advocating "revolutionary nationalism" transcending class lines. On 31 October 1922, following the March on Rome (28–30 October), Mussolini was appointed prime minister by King Victor Emmanuel III, becoming the youngest individual to hold the office up to that time. After removing all political opposition through his secret police and outlawing labor strikes, Mussolini and his followers consolidated power through a series of laws that transformed the nation into a one-party dictatorship. Within five years, Mussolini had established dictatorial authority by both legal and illegal means and aspired to create a totalitarian state. In 1929, Mussolini signed the Lateran Treaty with the Holy See to establish Vatican City.
82 | Mussolini's foreign policy aimed to restore the ancient grandeur of the Roman Empire by expanding Italian colonial possessions and the fascist sphere of influence. In the 1920s, he ordered the Pacification of Libya, instructed the bombing of Corfu over an incident with Greece, established a protectorate over Albania, and incorporated the city of Fiume into the Italian state via agreements with Yugoslavia. In 1936, Ethiopia was conquered following the Second Italo-Ethiopian War and merged into Italian East Africa (AOI) with Eritrea and Somalia. In 1939, Italian forces annexed Albania. Between 1936 and 1939, Mussolini ordered the successful Italian military intervention in Spain in favor of Francisco Franco during the Spanish Civil War. Mussolini's Italy initially tried to avoid the outbreak of a second global war, sending troops at the Brenner Pass to delay Anschluss and taking part in the Stresa Front, the Lytton Report, the Treaty of Lausanne, the Four-Power Pact and the Munich Agreement. However, Italy then alienated itself from Britain and France by aligning with Germany and Japan. Germany invaded Poland on 1 September 1939, resulting in declarations of war by France and the UK and the start of World War II.
83 | On 10 June 1940, Mussolini decided to enter the war on the Axis side. Despite initial success, the subsequent Axis collapse on multiple fronts and eventual Allied invasion of Sicily made Mussolini lose the support of the population and members of the Fascist Party. As a consequence, early on 25 July 1943, the Grand Council of Fascism passed a motion of no confidence in Mussolini; later that day King Victor Emmanuel III dismissed him as head of government and had him placed in custody, appointing Pietro Badoglio to succeed him as Prime Minister. After the king agreed to an armistice with the Allies, on 12 September 1943 Mussolini was rescued from captivity in the Gran Sasso raid by German paratroopers and Waffen-SS commandos led by Major Otto-Harald Mors. Adolf Hitler, after meeting with the rescued former dictator, then put Mussolini in charge of a puppet regime in northern Italy, the Italian Social Republic (Italian: Repubblica Sociale Italiana, RSI), informally known as the Salò Republic, causing a civil war. In late April 1945, in the wake of near total defeat, Mussolini and his mistress Clara Petacci attempted to flee to Switzerland, but both were captured by Italian communist partisans and summarily executed by firing squad on 28 April 1945 near Lake Como. The bodies of Mussolini and his mistress were then taken to Milan, where they were hung upside down at a service station to publicly confirm their demise.
84 | Thought 3: Benito Mussolini was assassinated as part of the collapse of the Axis powers during World War II. So the Rome Protocols were signed by three Prime Ministers one of which was assassinated as part of World War II.
85 | Action 3: Finish[World War II]
86 |
87 | Reflection: I searched one of the prime ministers involved in the signing, then attempted to answer right away. I should have searched each of the prime ministers, then looked up 'death' on each of their pages in order to get more information before answering.
88 |
89 | Previous Trial:
90 | Question: Kam Heskin plays Paige Morgan in a 2004 film directed by who?
91 | Thought 1: I need to search Kam Heskin and Paige Morgan, find the 2004 film they are in, then find the director of the film.
92 | Action 1: Search[Kam Heskin]
93 | Observation 1: Kam Erika Heskin (born May 8, 1973) is an American actress. She began her career playing Caitlin Deschanel on the NBC daytime soap opera Sunset Beach (1998–1999), before appearing in films Planet of the Apes (2001) and Catch Me If You Can (2002). Heskin went on in 2003 to play Elizabeth Bennet in an independent film Pride & Prejudice: A Latter-Day Comedy, Robin in Charmed (2004), and Paige Morgan in The Prince and Me film franchise (2006–2010).
94 | Thought 2: Kam Heskin plays Paige Morgan in The Prince and Me film franchise. I need to search The Prince and Me film franchise and find the director of the 2004 film.
95 | Action 2: Search[The Prince and Me film franchise]
96 | Observation 2: Could not find [The Prince and Me film franchise]. Similar: ['Kam Heskin', 'Prince of Persia', 'Harry Potter and the Half-Blood Prince (film)', 'Prince of Persia: The Sands of Time (film)', 'The Chronicles of Narnia (film series)', 'Majestic Prince (franchise)', 'Kingsman (franchise)', 'The Little Prince (2015 film)', 'Shrek (franchise)', 'Harry Potter (film series)']
97 | Thought 3: I need to search The Prince and Me (2004 film) instead.
98 | Action 3: Search[The Prince and Me (2004 film)]
99 | Observation 3: Could not find [The Prince and Me (2004 film)]. Similar: ['The Prince & Me', 'The Prince & Me 2: The Royal Wedding', 'Kam Heskin', 'Prince of Darkness (film)', 'Prince of Persia: The Sands of Time (film)', 'Rob Knox', 'Alexander (2004 film)', 'Prince of Persia', 'Brooklynn Prince', 'Prince (musician)']
100 | Thought 4: I need to search The Prince & Me (2004 film).
101 | Action 4: Search[The Prince & Me (2004 film)]
102 | Observation 4: Could not find [The Prince & Me (2004 film)]. Similar: ['The Prince & Me', 'The Prince & Me 2: The Royal Wedding', 'Prince of Darkness (film)', 'Prince of Persia: The Sands of Time (film)', 'Rob Knox', 'Alexander (2004 film)', 'Prince (musician)', 'Prince of Persia', 'Kam Heskin', 'Brooklynn Prince']
103 | Thought 5: I need to search The Prince & Me (2004 film) instead.
104 | Action 5: Search[The Prince & Me (2004 film)]
105 | Observation 5: Could not find [The Prince & Me (2004 film)]. Similar: ['The Prince & Me', 'The Prince & Me 2: The Royal Wedding', 'Prince of Darkness (film)', 'Prince of Persia: The Sands of Time (film)', 'Rob Knox', 'Alexander (2004 film)', 'Prince (musician)', 'Prince of Persia', 'Kam Heskin', 'Brooklynn Prince']
106 | Thought 6: I need to search The Prince & Me (2004 film).
107 | Action 6: Search[The Prince & Me (2004 film)]
108 | Observation 6: Could not find [The Prince & Me (2004 film)]. Similar: ['The Prince & Me', 'The Prince & Me 2: The Royal Wedding', 'Prince of Darkness (film)', 'Prince of Persia: The Sands of Time (film)', 'Rob Knox', 'Alexander (2004 film)', 'Prince (musician)', 'Prince of Persia', 'Kam Heskin', 'Brooklynn Prince']
109 |
110 | Reflection: I got stuck in a loop where I kept trying to search 'The Prince & Me (2004 film)' but the page could not be found. Instead I should have tried to search the similar results that had a similar name to see if they were made in 2004.
111 |
112 | Previous Trial:
113 | Question: The Oberoi family is part of a hotel company that has a head office in what city?
114 | Thought 1: I need to search the Oberoi family and find the city where the hotel company's head office is located.
115 | Action 1: Search[Oberoi family]
116 | Observation 1: The Oberoi Group is a luxury hotel group with its head office in New Delhi, India.[2] Founded in 1934, the company owns and operates 32 luxury hotels and two river cruise ships in 7 countries, primarily under its Oberoi Hotels & Resorts and Trident brands.[3] The group also operates The Oberoi Centre for Learning and Development, which is regarded as one of Asia's top institutions for hospitality education.[4]. The foundations of the Oberoi Group date back to 1934 when The Rai Bahadur Mohan Singh Oberoi, the founder of the group, bought two properties: the Maidens in Delhi and the Clarke's in Shimla.[5] In the following years Oberoi, assisted by his two sons, Tilak Raj Singh Oberoi and Prithvi Raj Singh Oberoi (P.R.S. Oberoi
117 | Thought 2: The Oberoi Group has its head office in New Delhi, India. I just need to finish with New Delhi.
118 | Action 2: Finish[New Delhi]
119 | Observation 2: Episode finished, reward = 0
120 |
121 | Reflection: I appeared to have retrieved the correct information about The Oberoi Family and the location of its head office, and provided a corresponding answer. However, this answer does not exactly match the ground truth answer, so I should try a different wording, such as Delhi.
122 |
123 | Previous trial:"""
124 |
125 | def reflgpt(prompt, model_size="8", temperature=1.0, max_tokens=100, n=1) -> list:
126 | def call_openai_api(messages, model_size, temperature, max_tokens, n):
127 | outputs = []
128 | while n > 0:
129 | cnt = min(n, 20)
130 | n -= cnt
131 | res = server.get_completion_or_embedding(
132 | model_size, messages, temperature, max_tokens
133 | )
134 | outputs.extend(
135 | [
136 | re.sub(r"^Thought \d+: ", "", choice.message.content)
137 | for choice in res.choices
138 | ]
139 | )
140 | return outputs
141 |
142 | # messages = [{"role": "user", "content": prompt}]
143 | messages = []
144 | messages.append({"role": "user", "content": reflection_prompt+"\n"+prompt+"\nReflection:"})
145 | # print(messages)
146 | # time.sleep(30)
147 | return call_openai_api(
148 | messages, model_size=model_size, temperature=temperature, max_tokens=max_tokens, n=n
149 | )
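150 |
151 | # Illustrative sketch of how gpt() above turns a ReAct-style prompt into chat
152 | # messages: the regex split keeps the "Question:"/"Thought i:"/"Observation i:"
153 | # prefixes, the trailing generation cue is popped, and each remaining chunk
154 | # becomes a user or assistant turn. For example,
155 | #   "Solve ...\nQuestion: Q\nThought 1: T\nObservation 1: O\nThought 2: "
156 | # yields [user "Solve ...", user "Question: Q", assistant "Thought 1: T",
157 | # user "Observation 1: O"], leaving the model to produce the next Thought.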
--------------------------------------------------------------------------------
/hotpotqa/requirements.txt:
--------------------------------------------------------------------------------
1 | faiss-cpu
2 | openai
3 | transformers
4 | numpy
5 | requests
6 |
7 | # NOTE: logging, re, typing, itertools, and functools are Python standard-
8 | # library modules, not pip packages, so they were removed from this list.
9 | # wikienv and wrappers are local modules (wikienv.py / wrappers.py from the
10 | # ReAct codebase) imported by tot.py, not installable requirements.
11 |
12 |
13 |
--------------------------------------------------------------------------------
/hotpotqa/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import numpy as np
5 | import math
6 | from models import online_embed, build_server
7 | import faiss
8 | from hotpotqa import HotPotQATask
9 | from tot import dfs_search, quote_env
10 | import logging
11 | import random
12 |
13 |
14 | def random_selection(lst, n=5):
15 | if len(lst) <= n:
16 | return lst
17 | else:
18 | return random.sample(lst, n)
19 |
20 |
21 | def run(args):
22 | task = HotPotQATask()
23 | print(task)
24 | logs, cnt_avg, cnt_any = [], 0, 0
25 |
26 | config_path = os.path.join(args.run_name, "config.json")
27 |
28 | with open(config_path, "w", encoding="utf-8") as wf:
29 | info_dict = vars(args)
30 | info_dict["is_running"] = True
31 | json.dump(info_dict, wf, indent=4)
32 |
33 | build_server(config_path=config_path)
34 |
35 | os.makedirs(os.path.dirname(args.log_dir), exist_ok=True)
36 |
37 | logging.basicConfig(
38 | filename=args.log_dir,
39 | level=logging.INFO,
40 | format="%(asctime)s - %(levelname)s - %(message)s",
41 | filemode="a",
42 | )
43 |
44 | wins = {}
45 | lose = {}
46 | trajectories = []
47 | embedding_array = np.zeros((0, 3584))
48 |     tongji = []  # tongji (Chinese for "statistics"): per-trial running count of solved tasks
49 |     cnt = 0
50 | for trial in range(10):
51 | print("Trial")
52 | print(trial)
53 | count = 0
54 | task_accs = []
55 | info = []
56 |         emb_db = faiss.IndexFlatL2(3584)  # NOTE: built but never queried; distances are computed manually below
57 | emb_db.add(embedding_array.astype("float32"))
58 |
59 | for i in range(args.task_start_index, args.task_end_index):
60 | # solve
61 | if i in wins:
62 | continue
63 | prev = None
64 | knnret = []
65 | if trajectories and args.cot_size>0:
66 | if i in lose:
67 | prev = lose[i]
68 | else:
69 | prev=quote_env(i)
70 | if args.cot_method=="knn":
71 | fail_vec = online_embed(str(prev))
72 | dist=[]
73 | for index, row in np.ndenumerate(embedding_array):
74 | if index[1] == 0:
75 | realN = (1.0 / np.linalg.norm(np.array(fail_vec) - embedding_array[index[0]]))
76 | dist.append((realN,index[0]))
77 | dist.sort(key=lambda x: x[0], reverse=True)
78 | sz_now=min(len(trajectories),args.cot_size)
79 | dist=dist[:sz_now]
80 |                     dist = [(math.exp(0.0 * x), y) for (x, y) in dist]  # exp(0.0 * x) == 1, so the weights start out uniform
81 | original_sum=sum([x for (x,_) in dist])
82 | dist=[(x *float(sz_now) / original_sum,y) for (x,y) in dist]
83 | tot=0
84 | cntD=[]
85 | realD=[]
86 | for (x,y) in dist:
87 | cntD.append(math.floor(x))
88 | realD.append(x)
89 | tot+=math.floor(x)
90 |                     while tot < sz_now:
91 |                         maxx = -1.0
92 |                         maxi = -1
93 |                         for ig in range(sz_now):
94 |                             if realD[ig] - float(cntD[ig]) > maxx:
95 |                                 maxx = realD[ig] - float(cntD[ig])
96 |                                 maxi = ig
97 |                         cntD[maxi] += 1
98 |                         tot += 1
99 | weights=[]
100 | for (x,y) in dist:
101 | weights.append(x)
102 | normalized_weights = [w / sum(weights) for w in weights]
103 | for ig in range(sz_now):
104 | s_ind = np.random.choice(len(weights), p=normalized_weights)
105 | (_,y)=dist[s_ind]
106 | knnret.append(trajectories[y])
107 | else:
108 | knnret=random.sample(trajectories, min(args.cot_size,len(trajectories)))
109 | state, value, all_nodes, reward, em, failt, succt = dfs_search(
110 | args, task, i, args.iteration, knnret
111 | )
112 | if failt:
113 | print("Fail")
114 | print(i)
115 | lose[i] = failt[0]
116 | if succt:
117 | cnt=cnt+1
118 | print("Success")
119 | print(i)
120 | wins[i] = 1
121 | vec = online_embed(str(succt[0]))
122 | trajectories.append(succt[0])
123 | embedding_array = np.vstack((embedding_array, np.array(vec)))
124 | # log main metric
125 | if em is None:
126 | em = 0
127 | task_accs.append(em)
128 | cnt_avg = sum(task_accs) / len(task_accs)
129 | print(i, "len(task_accs)", len(task_accs), "cnt_avg", cnt_avg, "\n")
130 | tongji.append(cnt)
131 | print(tongji)
132 | n = args.task_end_index - args.task_start_index
133 |
134 |
135 | def parse_args():
136 | args = argparse.ArgumentParser()
137 | args.add_argument("--model_size", type=str, default="8")
138 | args.add_argument("--temperature", type=float, default=1.0)
139 | args.add_argument("--task_start_index", type=int, default=900)
140 | args.add_argument("--task_end_index", type=int, default=1000)
141 | args.add_argument("--prompt_sample", type=str, choices=["standard", "cot"])
142 | args.add_argument("--n_generate_sample", type=int, default=1)
143 | args.add_argument("--n_evaluate_sample", type=int, default=1)
144 | args.add_argument("--iteration", type=int, default=50)
145 | args.add_argument("--algorithm", type=str, choices=["lats", "rap", "tot"])
146 | args.add_argument("--cot_method", type=str, choices=["knn", "random", "None"])
147 | args.add_argument("--run_name", type=str)
148 | args.add_argument("--log_file_path", type=str)
149 | args.add_argument("--log_dir", type=str)
150 | args.add_argument("--cot_size", type=int, default=0)
151 | args = args.parse_args()
152 | return args
153 |
154 |
155 | if __name__ == "__main__":
156 | args = parse_args()
157 | print(args)
158 | run(args)
159 |
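160 | # Sketch of the knn branch above, for reference: the failed trajectory is
161 | # embedded, stored successes are scored by inverse L2 distance, exp(0.0 * x)
162 | # flattens those scores to uniform weights, a largest-remainder pass rounds
163 | # them into integer counts (cntD, ultimately unused), and the in-context
164 | # examples are finally drawn by weighted resampling via np.random.choice.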
--------------------------------------------------------------------------------
/hotpotqa/tot.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import numpy as np
3 | from functools import partial
4 | from models import gpt
5 | import wikienv, wrappers
6 | import requests
7 | import logging
8 | import re
9 |
10 | # Configuring the logging
11 | logging.basicConfig(
12 | filename="tot_150it.log",
13 | level=logging.INFO,
14 | format="%(asctime)s - %(levelname)s - %(message)s",
15 | filemode="w",
16 | )
17 |
18 | # Test logging
19 | logging.info("This is a test log entry.")
20 |
21 | env = wikienv.WikiEnv()
22 | env = wrappers.HotPotQAWrapper(env, split="train")
23 | env = wrappers.LoggingWrapper(env)
24 |
25 | logging.info("Logging has been configured.")
26 |
27 | global reflection_map
28 | reflection_map = []
29 |
30 |
31 | def step(env, action):
32 | attempts = 0
33 | while attempts < 10:
34 | try:
35 | return env.step(action)
36 | except requests.exceptions.Timeout:
37 |             attempts += 1  # NOTE: after 10 timeouts this falls through and returns None
38 |
39 |
40 | def get_value(task, x, y, n_evaluate_sample, cache_value=True):
41 | global reflection_map
42 | unique_trajectories = get_unique_trajectories(failed_trajectories)
43 | value_prompt = task.value_prompt_wrap(x, y, unique_trajectories, reflection_map)
44 | logging.info(f"Current: {x}")
45 | logging.info(f"Current: {y}")
46 | if cache_value and value_prompt in task.value_cache:
47 | return task.value_cache[value_prompt]
48 | logging.info(f"VALUE PROMPT: {value_prompt}")
49 | value_outputs = gpt(value_prompt, n=n_evaluate_sample)
50 | logging.info(f"VALUE OUTPUTS: {value_outputs}")
51 | value = task.value_outputs_unwrap(value_outputs)
52 | logging.info(f"VALUES: {value}")
53 | if cache_value:
54 | task.value_cache[value_prompt] = value
55 | return value
56 |
57 |
58 | def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
59 | values = []
60 | local_value_cache = {}
61 | for y in ys: # each partial output
62 | if y in local_value_cache: # avoid duplicate candidates
63 | value = 0
64 | else:
65 | value = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value)
66 | local_value_cache[y] = value
67 | values.append(value)
68 | return values
69 |
70 |
71 | def get_samples(task, x, y, n_generate_sample, prompt_sample, knn=None):
72 | unique_trajectories = get_unique_trajectories(failed_trajectories)
73 | global reflection_map
74 | reflection_map = []
75 | if prompt_sample == "standard":
76 |         prompt = task.standard_prompt_wrap(x, y)  # standard_prompt_wrap takes (x, y); an extra [] argument was removed
77 | elif prompt_sample == "cot":
78 | prompt = task.cot_prompt_wrap(x, y, [], knn)
79 | else:
80 | raise ValueError(f"prompt_sample {prompt_sample} not recognized")
81 | logging.info(f"PROMPT: {prompt}")
82 | samples = gpt(prompt, n=n_generate_sample)
83 | return [y + _ for _ in samples]
84 |
85 |
86 | def get_unique_trajectories(failed_trajectories, num=2):
87 | unique_trajectories = []
88 | seen_final_answers = set()
89 | for traj in failed_trajectories:
90 | final_answer = traj.get("final_answer")
91 | if final_answer not in seen_final_answers:
92 | unique_trajectories.append(node_trajectory_to_text(traj["trajectory"]))
93 | seen_final_answers.add(final_answer)
94 | if len(unique_trajectories) >= num:
95 | break
96 | return unique_trajectories
97 |
98 |
99 | class Node:
100 | def __init__(self, state, question, parent=None, knn=None):
101 | self.state = (
102 | {"thought": "", "action": "", "observation": ""} if state is None else state
103 | )
104 | self.parent = parent
105 | self.question = question
106 | self.children = []
107 | self.visits = 0
108 | self.value = 0
109 | self.depth = 0 if parent is None else parent.depth + 1
110 | self.is_terminal = False
111 | self.reward = 0
112 | self.exhausted = False # If all children are terminal
113 | self.em = 0 # Exact match, evaluation metric
114 | self.knn = knn
115 |
116 | def ques(self):
117 | return self.question
118 |
119 | def uct(self):
120 | if self.visits == 0:
121 | # return float('inf')
122 | return self.value * 2
123 | return self.value / self.visits + np.sqrt(
124 | 2 * np.log(self.parent.visits) / self.visits
125 | )
126 |
127 | def uct_with_depth(self, C1=1, C2=1):
128 | if self.visits == 0:
129 | return self.value
130 | exploitation_term = self.value / self.visits
131 | exploration_term = np.sqrt(2 * np.log(self.parent.visits) / self.visits)
132 | depth_term = self.depth
133 | return exploitation_term + C1 * exploration_term + C2 * depth_term
134 |
135 | def __str__(self):
136 | return f"Node(depth={self.depth}, value={self.value:.2f}, visits={self.visits}, thought={self.state['thought']}, action={self.state['action']}, observation={self.state['observation']})"
137 |
138 | def to_dict(self):
139 | return {
140 | "state": self.state,
141 | "question": self.question,
142 | "parent": self.parent.to_dict() if self.parent else None,
143 | "children": [child.to_dict() for child in self.children],
144 | "visits": self.visits,
145 | "value": self.value,
146 | "depth": self.depth,
147 | "is_terminal": self.is_terminal,
148 | "reward": self.reward,
149 | "em": self.em,
150 | "knn": self.knn,
151 | }
152 |
153 |
154 | def node_trajectory_to_text(node_string):
155 | lines = node_string.split("\n")
156 | formatted_lines = []
157 | for line in lines:
158 | if line.startswith("Question"):
159 | formatted_lines.append(line)
160 | continue
161 | try:
162 | depth = int(line.split(",")[0].split("=")[1].strip())
163 | thought = line.split(", thought=")[1].split(", action=")[0].strip()
164 | action = line.split(", action=")[1].split(", observation=")[0].strip()
165 | observation = line.split(", observation=")[1].split(")")[0].strip()
166 | except IndexError:
167 | continue
168 |
169 | if depth != 0:
170 | if thought:
171 | formatted_lines.append(f"Thought {depth}: {thought}")
172 | if action:
173 | formatted_lines.append(f"Action {depth}: {action}")
174 | if observation:
175 | formatted_lines.append(f"Observation {depth}: {observation}")
176 | formatted_lines.pop()
177 | return "\n".join(formatted_lines)
178 |
179 |
180 | def traj_depth(node_string):
181 | lines = node_string.split("\n")
182 | formatted_lines = []
183 | ret = 0
184 | for line in lines:
185 | try:
186 | depth = int(line.split(",")[0].split("=")[1].strip())
187 | thought = line.split(", thought=")[1].split(", action=")[0].strip()
188 | action = line.split(", action=")[1].split(", observation=")[0].strip()
189 | observation = line.split(", observation=")[1].split(")")[0].strip()
190 | except IndexError:
191 | continue
192 | if depth > ret:
193 | ret = depth
194 | return ret
195 |
196 |
197 | def collect_all_nodes(node):
198 | """Recursively collect all nodes starting from the given node."""
199 | nodes = [node]
200 | for child in node.children:
201 | nodes.extend(collect_all_nodes(child))
202 | return nodes
203 |
204 |
205 | def collect_trajectory(node):
206 | trajectory = []
207 | ques = ""
208 | while node:
209 | ques = "Question: " + node.question
210 | trajectory.append(str(node))
211 | node = node.parent
212 | if len(ques) > 0:
213 | trajectory.append(ques)
214 | return "\n".join(reversed(trajectory))
215 |
216 |
217 | def get_substrings_between_brackets(s):
218 |     # Use a regular expression to find all content inside square brackets;
219 |     # brackets are special characters in a regex, so they must be escaped
220 |     pattern = r"\[(.*?)\]"
221 |     # re.findall returns all non-overlapping matches
222 |     matches = re.findall(pattern, s)
223 |     return matches[0]  # only the first bracketed span is used
224 |
225 | def quote_env(idx):
226 | return str(env.reset(idx=idx))
227 |
228 |
229 | def dfs_search(args, task, idx, iterations, knnret, depth_limit=7, to_print=True):
230 | global gpt
231 | global failed_trajectories
232 | global success_trajectories
233 | gpt = partial(gpt, model_size=args.model_size, temperature=args.temperature)
234 | x = env.reset(idx=idx)
235 | if to_print:
236 | print(idx, x)
237 | root = Node(state=None, question=x)
238 | all_nodes = []
239 | failed_trajectories = []
240 | success_trajectories = []
241 | stack = [root]
242 | it = 0
243 | knn = []
244 | if knnret:
245 | for traj in knnret:
246 | format_traj = node_trajectory_to_text(traj["trajectory"])
247 | # format_traj+=f"Action {traj_depth(traj['trajectory'])}: Finish[{get_substrings_between_brackets(traj['final_answer'])}]"+"\n"
248 | knn.append(format_traj)
249 | last_node = None
250 | while stack and it < iterations:
251 | node = stack.pop()
252 | last_node = node
253 | logging.info(f"DFS at node depth {node.depth}...")
254 |
255 | if node.is_terminal and node.reward == 1:
256 | logging.info(f"Terminal node with reward 1 found at depth {node.depth}")
257 | return (
258 | node.state,
259 | node.value,
260 | all_nodes,
261 | node.reward,
262 | node.em,
263 | failed_trajectories,
264 | success_trajectories,
265 | )
266 |
267 | if node.is_terminal and node.reward == 0:
268 | logging.info(f"Terminal node with reward 0 found at depth {node.depth}")
269 | return (
270 | node.state,
271 | node.value,
272 | all_nodes,
273 | node.reward,
274 | node.em,
275 | failed_trajectories,
276 | success_trajectories,
277 | )
278 |
279 | if node.depth >= depth_limit:
280 | logging.info("Depth limit reached")
281 | it += 1
282 | continue # go to next iteration
283 |
284 | expand_node(node, args, task, knn=knn)
285 | stack.extend(reversed(node.children)) # adding all child nodes to stack for DFS
286 |
287 | all_nodes = [(node, node.value) for node in collect_all_nodes(root)]
288 | logging.info(f"State of all_nodes after iteration: {all_nodes}")
289 | it += 1
290 | # If we reach here, no solution was found
291 | logging.info("All paths explored. No solution found.")
292 | if len(failed_trajectories) == 0:
293 | trajectory = collect_trajectory(last_node)
294 | failed_trajectories.append({"trajectory": trajectory, "final_answer": ""})
295 | return root, 0, all_nodes, 0, 0, failed_trajectories, success_trajectories
296 |
297 |
298 | def select_node_dfs(stack):
299 | return stack[-1] if stack else None # return the last node in the stack
300 |
301 |
302 | # FYI: deleted mcts search here
303 |
304 |
305 | def select_node(node):
306 | while node and node.children:
307 | logging.info(
308 | f"Selecting from {len(node.children)} children at depth {node.depth}."
309 | )
310 |
311 | terminal_children = [child for child in node.children if child.is_terminal]
312 | terminal_status = [child.is_terminal for child in node.children]
313 |
314 | if len(terminal_children) == len(node.children):
315 | logging.info(
316 | f"All children are terminal at depth {node.depth}. Backtracking..."
317 | )
318 | if node.parent:
319 | node.parent.children.remove(node)
320 | node = node.parent
321 | continue
322 |
323 | node_with_reward_1 = next(
324 | (child for child in terminal_children if child.reward == 1), None
325 | )
326 | if node_with_reward_1:
327 | logging.info(f"Found terminal node with reward 1 at depth {node.depth}.")
328 | return node_with_reward_1
329 |
330 | node = max(
331 | (child for child in node.children if not child.is_terminal),
332 | key=lambda child: child.uct(),
333 | default=None,
334 | )
335 |
336 |         while node and node.is_terminal and node.reward != 1:
337 | node = max(
338 | (child for child in node.parent.children if not child.is_terminal),
339 | key=lambda child: child.uct(),
340 | default=None,
341 | )
342 |
343 |         if node: logging.info(f"Selected node at depth {node.depth} with UCT {node.uct()}.")
344 |
345 | return node # This will return None if all paths from the root are exhausted
346 |
347 |
348 | def expand_node(node, args, task, knn=None):
349 |     if node.depth >= 7:  # hard-coded to mirror the dfs_search depth_limit default
350 | logging.info("Depth limit reached")
351 | print("Depth limit reached")
352 | node.is_terminal = True
353 | return
354 | new_nodes = generate_new_states(node, args, task, knn=knn)
355 | node.children.extend(new_nodes)
356 |
357 |
358 | def generate_new_states(node, args, task, knn=None):
359 | prompt = generate_prompt(node)
360 | sampled_actions = get_samples(
361 | task,
362 | prompt,
363 | f"Thought {node.depth + 1}: ",
364 | args.n_generate_sample,
365 | prompt_sample=args.prompt_sample,
366 | knn=knn,
367 | )
368 | logging.info(f"SAMPLED ACTION: {sampled_actions}")
369 |
370 | unique_states = {} # Store unique states here
371 | for action in sampled_actions:
372 | new_state = node.state.copy() # Make a copy of the parent node's state
373 |
374 | thought_line = next(
375 | (
376 |                 line.split(":", 1)[1].strip()
377 | for line in action.split("\n")
378 | if line.startswith(f"Thought {node.depth + 1}")
379 | ),
380 | "",
381 | )
382 | action_line = next(
383 | (
384 |                 line.split(":", 1)[1].strip()
385 | for line in action.split("\n")
386 | if line.startswith("Action") and ":" in line
387 | ),
388 | None,
389 | )
390 |
391 | # Use thought and action to form a unique key
392 | unique_key = f"{thought_line}::{action_line}"
393 |
394 | if unique_key in unique_states:
395 | continue # Skip if this state already exists
396 |
397 | if action_line:
398 | action_type = (
399 | action_line.split("[")[0] if "[" in action_line else action_line
400 | )
401 | action_param = (
402 | action_line.split("[")[1].split("]")[0] if "[" in action_line else ""
403 | )
404 | obs, r, done, info = step(env, f"{action_type.lower()}[{action_param}]")
405 |
406 | # Update the new state dictionary
407 | new_state["thought"] = thought_line
408 | new_state["action"] = action_line
409 | new_state["observation"] = obs
410 |
411 | new_node = Node(state=new_state, question=node.question, parent=node)
412 | new_node.is_terminal = r == 1 or done
413 | new_node.reward = r
414 | if r == 1:
415 | new_node.em = info.get("em")
416 | unique_states[unique_key] = new_node # Add this state to unique_states
417 | logging.info(f"NEW NODE: {new_node}")
418 | logging.info(f"Feedback: {info}")
419 |
420 | if new_node.is_terminal and r == 0:
421 | trajectory = collect_trajectory(new_node)
422 | failed_trajectories.append(
423 | {
424 | "trajectory": trajectory,
425 | "final_answer": f"{action_type.lower()}[{action_param}]",
426 | }
427 | )
428 | if new_node.is_terminal and r == 1:
429 | trajectory = collect_trajectory(new_node)
430 | success_trajectories.append(
431 | {
432 | "trajectory": trajectory,
433 | "final_answer": f"{action_type.lower()}[{action_param}]",
434 | }
435 | )
436 |
437 | return list(unique_states.values()) # Return unique nodes as a list
438 |
439 |
440 | def evaluate_node(node, args, task):
441 |     # Score only the non-terminal children, then align the scores back to
442 |     # node.children by position (terminal children receive no prompt).
443 |     child_prompts = [
444 |         generate_prompt(child) for child in node.children if not child.is_terminal
445 |     ]
446 |     votes = get_values(task, node.question, child_prompts, args.n_evaluate_sample)
447 |     logging.info(f"Length of votes: {len(votes)}")
448 |     logging.info(f"Length of node.children: {len(node.children)}")
449 | 
450 |     vote_iter = iter(votes)
451 |     votes = [0 if child.is_terminal else next(vote_iter, 0) for child in node.children]
452 | 
453 |     max_vote = max(votes) if votes else 1
454 |     if max_vote == 0:
455 |         max_vote = 1  # Avoid division by zero
456 | 
457 |     # Terminal children outrank the best scored child.
458 |     for i, child in enumerate(node.children):
459 |         if child.is_terminal:
460 |             votes[i] = max_vote + 1
461 | 
462 |     for i, child in enumerate(node.children):
463 |         child.value = votes[i] / max_vote  # Safe from division by zero
464 |     return sum(votes) / len(votes) if votes else 0
465 |
466 |
467 | def print_tree(node, level=0):
468 | indent = " " * level
469 | print(f"{indent}{node}")
470 | for child in node.children:
471 | print_tree(child, level + 1)
472 |
473 |
474 | def backpropagate(node, value):
475 | while node:
476 | node.visits += 1
477 | if node.is_terminal:
478 | if node.reward == 1:
479 | node.value = (node.value * (node.visits - 1) + value) / node.visits
480 | logging.info(
481 | f"Backpropagating with reward 1 at depth {node.depth}. New value: {node.value}."
482 | )
483 | elif node.reward == 0:
484 | node.value = (node.value * (node.visits - 1) + (-1)) / node.visits
485 | logging.info(
486 | f"Backpropagating with reward 0 at depth {node.depth}. New value: {node.value}."
487 | )
488 | else:
489 | node.value = (node.value * (node.visits - 1) + value) / node.visits
490 | logging.info(
491 | f"Backpropagating at depth {node.depth}. New value: {node.value}."
492 | )
493 |
494 | node = node.parent
495 |
496 |
497 | def generate_prompt(node):
498 | trajectory = []
499 | question = node.question
500 | while node:
501 | new_segment = []
502 | if node.state["thought"]:
503 | new_segment.append(f"Thought {node.depth}: {node.state['thought']}")
504 | if node.state["action"]:
505 | new_segment.append(f"Action {node.depth}: {node.state['action']}")
506 | if (
507 | node.state["observation"] and node.depth != 0
508 | ): # Exclude the observation from the root node
509 | new_segment.append(f"Observation {node.depth}: {node.state['observation']}")
510 | trajectory.append("\n".join(new_segment))
511 | node = node.parent
512 | return question + "\n".join(reversed(trajectory))
513 |
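514 | # The string-parsing helpers above (node_trajectory_to_text, traj_depth) rely
515 | # on the repr produced by Node.__str__ earlier in this file, e.g.:
516 | #   Node(depth=1, value=0.50, visits=2, thought=..., action=..., observation=...)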
--------------------------------------------------------------------------------
/hotpotqa/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
4 | os.environ["PYTHONUTF8"] = "1"
5 | import time
6 | from typing import Dict, List
7 | import openai
8 | import json
9 | import random
10 | import math
11 | from IPython import embed  # optional: interactive debugging helper
12 |
13 |
14 | class ModelServer:
15 |
16 | def get_completion_or_embedding(
17 | self,
18 | model_size: str,
19 | message,
20 | temperature: float = 0.0,
21 | max_tokens: int = 256,
22 | get_embedding: bool = False,
23 |     ):  # returns a chat-completion response, or an embedding vector if get_embedding
24 | assert model_size in ["70", "8", "7"]
25 |
26 | if not get_embedding:
27 | assert type(message) == list, "Message should be a list."
28 | response = client.chat.completions.create(
29 | model=model_name,
30 | messages=message,
31 | max_tokens=max_tokens,
32 | temperature=temperature,
33 | stop=["<|eot_id|>", "\nObservation", "Observation"],
34 | )
35 | else:
36 | assert type(message) == str, "Message should be a string."
37 | response = client.embeddings.create(
38 | model=model_name,
39 | input=message,
40 | )
41 | if get_embedding:
42 | return response.data[0].embedding
43 | else:
44 | return response
45 |
46 |
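47 | # NOTE: `client` and `model_name` are referenced above but not defined in this
48 | # excerpt; they are assumed to be module-level handles to an OpenAI-compatible
49 | # endpoint. A minimal sketch of that assumed setup (URL and model hypothetical):
50 | #
51 | #   client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
52 | #   model_name = "meta-llama/Meta-Llama-3-8B-Instruct"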
--------------------------------------------------------------------------------
/webshop/README:
--------------------------------------------------------------------------------
1 | # Webshop experiment code
2 |
3 | Code modified from [LATS](https://github.com/lapisrocks/LanguageAgentTreeSearch)
4 |
5 | Serve your local LLM behind an OpenAI-API-compatible endpoint and point utils.py at it. We recommend [SGLang](https://github.com/sgl-project/sglang).
6 | 
7 | Note that in lats.py the WebShop server address is set to http://127.0.0.1:3000 (see WEBSHOP_URL).
8 |
9 | Example command:
10 |
11 | python run.py --model_size 8 --temperature 0.0 --task_start_index 0 --task_end_index 50 --cot_size 5 \
12 | --prompt_sample 'cot' --n_generate_sample 1 --n_evaluate_sample 1 \
13 | --iteration 10 --algorithm 'tot' --cot_method 'knn' --run_name 'webshop_new'
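14 | 
15 | A minimal sketch of the endpoint setup assumed by utils.py (port and model name are hypothetical; any OpenAI-compatible server works):
16 | 
17 |     import openai
18 | 
19 |     client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
20 |     resp = client.chat.completions.create(
21 |         model="meta-llama/Meta-Llama-3-8B-Instruct",
22 |         messages=[{"role": "user", "content": "Hello"}],
23 |     )
24 |     print(resp.choices[0].message.content)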
--------------------------------------------------------------------------------
/webshop/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | class Task:
3 | def __init__(self):
4 | pass
5 |
6 |     def __len__(self) -> int:
7 |         raise NotImplementedError
8 | 
9 |     def get_input(self, idx: int) -> str:
10 |         raise NotImplementedError
11 | 
12 |     def test_output(self, idx: int, output: str):
13 |         raise NotImplementedError
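14 | 
15 | # Subclasses (e.g. WebShopTask in webshop.py) implement this interface.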
--------------------------------------------------------------------------------
/webshop/lats.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | 
4 | # Setup
5 | 
6 | import os
7 | import sys
8 | import copy
9 | import itertools
10 | import re
11 | import logging
12 | import random
13 | import requests
14 | import numpy as np
15 | from functools import partial
16 | from models import gpt
17 | from bs4 import BeautifulSoup
18 | from bs4.element import Comment
19 | 
20 | # Running token counters, updated as the LLM is queried
21 | completion_tokens = prompt_tokens = 0
22 | 
23 | # --- WebShop page scraping helpers ---
24 | 
25 |
26 | WEBSHOP_URL = "http://127.0.0.1:3000"
27 | ACTION_TO_TEMPLATE = {
28 | 'Description': 'description_page.html',
29 | 'Features': 'features_page.html',
30 | 'Reviews': 'review_page.html',
31 | 'Attributes': 'attributes_page.html',
32 | }
33 |
34 | def clean_str(p):
35 | return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8")
36 |
37 |
38 | def tag_visible(element):
39 | ignore = {'style', 'script', 'head', 'title', 'meta', '[document]'}
40 | return (
41 | element.parent.name not in ignore and not isinstance(element, Comment)
42 | )
43 |
44 |
45 | def webshop_text(session, page_type, query_string='', page_num=1, asin='', options=None, subpage='', **kwargs):
46 |     options = options if options is not None else {}  # avoid a shared mutable default argument
47 |     if page_type == 'init':
48 |         url = f'{WEBSHOP_URL}/{session}'
49 | 
50 |     elif page_type == 'search':
51 |         url = (
52 | f'{WEBSHOP_URL}/search_results/{session}/'
53 | f'{query_string}/{page_num}'
54 | )
55 | elif page_type == 'item':
56 | url = (
57 | f'{WEBSHOP_URL}/item_page/{session}/'
58 | f'{asin}/{query_string}/{page_num}/{options}'
59 | )
60 | elif page_type == 'item_sub':
61 | url = (
62 | f'{WEBSHOP_URL}/item_sub_page/{session}/'
63 | f'{asin}/{query_string}/{page_num}/{subpage}/{options}'
64 | )
65 | elif page_type == 'end':
66 | url = (
67 | f'{WEBSHOP_URL}/done/{session}/'
68 | f'{asin}/{options}'
69 | )
70 | # print(url)
71 | html = requests.get(url).text
72 | html_obj = BeautifulSoup(html, 'html.parser')
73 | texts = html_obj.findAll(text=True)
74 | visible_texts = list(filter(tag_visible, texts))
75 | # visible_texts = [str(text).strip().strip('\\n') for text in visible_texts]
76 | # if page_type == 'end': import pdb; pdb.set_trace()
77 | if False:
78 | # For `simple` mode, return just [SEP] separators
79 | return ' [SEP] '.join(t.strip() for t in visible_texts if t != '\n')
80 | else:
81 | # Otherwise, return an observation with tags mapped to specific, unique separators
82 | observation = ''
83 | option_type = ''
84 | options = {}
85 | asins = []
86 | cnt = 0
87 | prod_cnt = 0
88 | just_prod = 0
89 | for t in visible_texts:
90 | if t == '\n': continue
91 | if t.replace('\n', '').replace('\\n', '').replace(' ', '') == '': continue
92 | # if t.startswith('Instruction:') and page_type != 'init': continue
93 | # print(t.parent.name, t)
94 | if t.parent.name == 'button': # button
95 | processed_t = f'\n[{t}] '
96 | elif t.parent.name == 'label': # options
97 | if f"'{t}'" in url:
98 | processed_t = f'[[{t}]]'
99 | # observation = f'You have clicked {t}.\n' + observation
100 | else:
101 | processed_t = f'[{t}]'
102 | options[str(t)] = option_type
103 | # options[option_type] = options.get(option_type, []) + [str(t)]
104 | elif t.parent.get('class') == ["product-link"]: # product asins
105 | processed_t = f'\n[{t}] '
106 | if prod_cnt >= 10:
107 | processed_t = ''
108 | prod_cnt += 1
109 | asins.append(str(t))
110 | just_prod = 0
111 | else: # regular, unclickable text
112 | processed_t = '\n' + str(t) + ' '
113 | if cnt < 2 and page_type != 'init': processed_t = ''
114 | if just_prod <= 2 and prod_cnt >= 4: processed_t = ''
115 | option_type = str(t)
116 | cnt += 1
117 | just_prod += 1
118 | observation += processed_t
119 | info = {}
120 | if options:
121 | info['option_types'] = options
122 | if asins:
123 | info['asins'] = asins
124 | if 'Your score (min 0.0, max 1.0)' in visible_texts:
125 | idx = visible_texts.index('Your score (min 0.0, max 1.0)')
126 | info['reward'] = float(visible_texts[idx + 1])
127 | observation = 'Your score (min 0.0, max 1.0): ' + (visible_texts[idx + 1])
128 | return clean_str(observation), info
129 |
130 | class webshopEnv:
131 | def __init__(self):
132 | self.sessions = {}
133 |
134 | def clone_state(self):
135 | return copy.deepcopy(self.sessions)
136 |
137 | def step(self, session, action):
138 | done = False
139 | observation_ = None
140 | logging.info(self.sessions)
141 | if action == 'reset':
142 | self.sessions[session] = {'session': session, 'page_type': 'init'}
143 | elif action.startswith('think['):
144 |             observation_ = 'OK.'  # thoughts don't change the page; kept past the webshop_text call below
145 | elif action.startswith('search['):
146 | assert self.sessions[session]['page_type'] == 'init'
147 | query = action[7:-1]
148 | self.sessions[session] = {'session': session, 'page_type': 'search',
149 | 'query_string': query, 'page_num': 1}
150 | elif action.startswith('click['):
151 | button = action[6:-1]
152 | if button == 'Buy Now':
153 | assert self.sessions[session]['page_type'] == 'item'
154 | self.sessions[session]['page_type'] = 'end'
155 | #done = True
156 | elif button == 'Back to Search':
157 | assert self.sessions[session]['page_type'] in ['search', 'item_sub', 'item']
158 | self.sessions[session] = {'session': session, 'page_type': 'init'}
159 | elif button == 'Next >':
160 | #assert False # ad hoc page limitation
161 | assert self.sessions[session]['page_type'] == 'search'
162 | self.sessions[session]['page_num'] += 1
163 | elif button == '< Prev':
164 | assert self.sessions[session]['page_type'] in ['search', 'item_sub', 'item']
165 | if self.sessions[session]['page_type'] == 'search':
166 | #assert False
167 | self.sessions[session]['page_num'] -= 1
168 | elif self.sessions[session]['page_type'] == 'item_sub':
169 | self.sessions[session]['page_type'] = 'item'
170 | elif self.sessions[session]['page_type'] == 'item':
171 | self.sessions[session]['page_type'] = 'search'
172 | self.sessions[session]['options'] = {}
173 | elif button in ACTION_TO_TEMPLATE:
174 | assert self.sessions[session]['page_type'] == 'item'
175 | self.sessions[session]['page_type'] = 'item_sub'
176 | self.sessions[session]['subpage'] = button
177 | else:
178 | if self.sessions[session]['page_type'] == 'search':
179 | assert button in self.sessions[session].get('asins', []) # must be asins
180 | self.sessions[session]['page_type'] = 'item'
181 | self.sessions[session]['asin'] = button
182 | elif self.sessions[session]['page_type'] == 'item':
183 | assert 'option_types' in self.sessions[session]
184 | assert button in self.sessions[session]['option_types'], (button, self.sessions[session]['option_types']) # must be options
185 | option_type = self.sessions[session]['option_types'][button]
186 | if not 'options' in self.sessions[session]:
187 | self.sessions[session]['options'] = {}
188 | self.sessions[session]['options'][option_type] = button
189 | observation_ = f'You have clicked {button}.'
190 | else:
191 | assert False
192 | observation, info = webshop_text(**self.sessions[session])
193 | if observation_:
194 | observation = observation_
195 | self.sessions[session].update(info)
196 | reward = info.get('reward', 0.0)
197 | if reward != 0.0:
198 | #print(f"Current Session State: {self.sessions[session]}")
199 | #print(f"Action being processed: {action}")
200 | print(f"Resulting Observation: {observation}")
201 | observation+=" Please try again!"
202 | if reward == 1.0:
203 | done = True
204 | print("done")
205 | return observation, reward, done
206 |
207 | env = webshopEnv()
208 |
209 | logging.info("Logging has been configured.")
210 |
211 | global reflection_map
212 | reflection_map = []
213 |
214 |
215 | def step(env, session, action):
216 |     attempts = 0
217 |     while attempts < 10:
218 |         try:
219 |             return env.step(session, action)
220 |         except requests.exceptions.Timeout:
221 |             attempts += 1
222 |
223 |
224 |
225 | def get_samples(task, x, y, n_generate_sample, prompt_sample, knn=None):
226 | unique_trajectories = get_unique_trajectories(failed_trajectories)
227 | global reflection_map
228 | global completion_tokens
229 | global prompt_tokens
230 | reflection_map = []
231 | if prompt_sample == "standard":
232 | prompt = task.standard_prompt_wrap(x, y, [])
233 | elif prompt_sample == "cot":
234 | prompt = task.cot_prompt_wrap(x, y, [], knn)
235 | else:
236 | raise ValueError(f"prompt_sample {prompt_sample} not recognized")
237 | logging.info(f"PROMPT: {prompt}")
238 |     samples, inca, incb = gpt(prompt, n=n_generate_sample)
239 |     prompt_tokens += inca
240 |     completion_tokens += incb
241 | #for trt in samples:
242 | # with open("debug_log1.log", "a", encoding="utf-8") as f:
243 | # f.write(y+"\n"+trt+"\n")
244 | return [y + _ for _ in samples]
245 |
246 |
247 | def get_unique_trajectories(failed_trajectories, num=2):
248 | unique_trajectories = []
249 | seen_final_answers = set()
250 | for traj in failed_trajectories:
251 | final_answer = traj.get("final_answer")
252 | if final_answer not in seen_final_answers:
253 | unique_trajectories.append(node_trajectory_to_text(traj["trajectory"]))
254 | seen_final_answers.add(final_answer)
255 | if len(unique_trajectories) >= num:
256 | break
257 | return unique_trajectories
258 |
259 |
260 | class Node:
261 |     def __init__(self, state, question, parent=None, knn=None, env_state=None):
262 | self.state = (
263 | {"action": "", "observation": ""} if state is None else state
264 | )
265 | self.parent = parent
266 | self.question = question
267 | self.children = []
268 | self.visits = 0
269 | self.value = 0
270 | self.depth = 0 if parent is None else parent.depth + 1
271 | self.is_terminal = False
272 | self.reward = 0
273 | self.exhausted = False # If all children are terminal
274 | self.em = 0 # Exact match, evaluation metric
275 | self.knn = knn
276 | self.env_state = env_state
277 |
278 | def ques(self):
279 | return self.question
280 |
281 |     def uct(self):
282 |         if self.visits == 0:
283 |             # Unvisited: use twice the prior value (rather than float('inf'))
284 |             return self.value * 2
285 |         return self.value / self.visits + np.sqrt(  # mean value + exploration term
286 |             2 * np.log(self.parent.visits) / self.visits
287 |         )
288 |
289 | def uct_with_depth(self, C1=1, C2=1):
290 | if self.visits == 0:
291 | return self.value
292 | exploitation_term = self.value / self.visits
293 | exploration_term = np.sqrt(2 * np.log(self.parent.visits) / self.visits)
294 | depth_term = self.depth
295 | return exploitation_term + C1 * exploration_term + C2 * depth_term
296 |
297 | def __str__(self):
298 | return f"Node(depth={self.depth}, value={self.value:.2f}, visits={self.visits}, action={self.state['action']}, observation={self.state['observation']})"
299 |
300 | def to_dict(self):
301 | return {
302 | "state": self.state,
303 | "question": self.question,
304 | "parent": self.parent.to_dict() if self.parent else None,
305 | "children": [child.to_dict() for child in self.children],
306 | "visits": self.visits,
307 | "value": self.value,
308 | "depth": self.depth,
309 | "is_terminal": self.is_terminal,
310 | "reward": self.reward,
311 | "em": self.em,
312 | "knn": self.knn,
313 | }
314 |
315 |
316 | def node_trajectory_to_text(node_string):
317 | lines = node_string.split("Node(")
318 | formatted_lines = []
319 | for line in lines:
320 | if line.startswith("Instruction"):
321 | formatted_lines.append(line)
322 | continue
323 | try:
324 | depth = int(line.split(",")[0].split("=")[1].strip())
325 | action = line.split(", action=")[1].split(", observation=")[0].strip()
326 | observation = line.split(", observation=")[1].split(")")[0].strip()
327 | except IndexError:
328 | continue
329 |
330 | if depth != 0:
331 | if action:
332 | formatted_lines.append(f"Action: {action}")
333 | if observation:
334 | formatted_lines.append(f"Observation: {observation}")
335 |     if formatted_lines: formatted_lines.pop()  # drop the final line (typically the last observation)
336 | return "\n".join(formatted_lines)
337 |
338 |
339 | def traj_depth(node_string):
340 |     """Return the maximum depth encountered in a serialized Node trajectory."""
341 |     lines = node_string.split("\n")
342 |     ret = 0
343 |     for line in lines:
344 |         try:
345 |             # Only the depth field of the Node repr is needed here.
346 |             depth = int(line.split(",")[0].split("=")[1].strip())
347 |         except (IndexError, ValueError):
348 |             # Skip lines that are not Node reprs (e.g. the instruction line).
349 |             continue
350 |         if depth > ret:
351 |             ret = depth
352 |     return ret
353 |
354 |
355 | def collect_all_nodes(node):
356 | """Recursively collect all nodes starting from the given node."""
357 | nodes = [node]
358 | for child in node.children:
359 | nodes.extend(collect_all_nodes(child))
360 | return nodes
361 |
362 |
363 | def collect_trajectory(node):
364 | trajectory = []
365 | ques = ""
366 | while node:
367 | ques = "Instruction: " + str(node.question).replace("WebShop","").replace("Instruction:","").strip()
368 | trajectory.append(str(node))
369 | node = node.parent
370 | if len(ques) > 0:
371 | trajectory.append(ques)
372 | return "\n".join(reversed(trajectory))
373 |
374 | 
375 | def get_substrings_between_brackets(s):
376 | pattern = r"\[(.*?)\]"
377 | matches = re.findall(pattern, s)
378 | return matches[0]
379 |
380 | def quote_env(idx):
381 | return str(env.step(idx, "reset")[0])
382 |
383 |
384 | def dfs_search(args, task, idx, iterations, knnret, depth_limit=30, to_print=True):
385 | global gpt
386 | global failed_trajectories
387 | global success_trajectories
388 | gpt = partial(gpt, model_size=args.model_size, temperature=args.temperature)
389 |     action = "reset"
390 | x = env.step(idx, action)[0]
391 | if to_print:
392 | print(idx, x)
393 | root = Node(state=None, question=x)
394 | root.env_state = copy.deepcopy(env.sessions)
395 | all_nodes = []
396 | failed_trajectories = []
397 | success_trajectories = []
398 | stack = [root]
399 | it = 0
400 | knn = []
401 | if knnret:
402 | for traj in knnret:
403 | format_traj = node_trajectory_to_text(traj["trajectory"])
404 | # format_traj+=f"Action {traj_depth(traj['trajectory'])}: Finish[{get_substrings_between_brackets(traj['final_answer'])}]"+"\n"
405 | knn.append(format_traj)
406 | print("--------------knn is:")
407 | print(knn)
408 | last_node = None
409 | maxx=0.0
410 | while stack and it < iterations:
411 | node = stack.pop()
412 | last_node = node
413 | logging.info(f"DFS at node depth {node.depth}...")
414 |
415 | if node.is_terminal and node.reward == 1:
416 | logging.info(f"Terminal node with reward 1 found at depth {node.depth}")
417 | return (
418 | node.state,
419 | node.value,
420 | all_nodes,
421 | node.reward,
422 | node.em,
423 | failed_trajectories,
424 | success_trajectories,
425 | )
426 |
427 | if node.is_terminal and node.reward == 0:
428 | logging.info(f"Terminal node with reward 0 found at depth {node.depth}")
429 | return (
430 | node.state,
431 | node.value,
432 | all_nodes,
433 | maxx,
434 | node.em,
435 | failed_trajectories,
436 | success_trajectories,
437 | )
438 |
439 | maxx=max(node.reward,maxx)
440 |
441 | if node.depth >= depth_limit:
442 | logging.info("Depth limit reached")
443 | it += 1
444 | continue # go to next iteration
445 |
446 | expand_node(node, args, task,idx, knn=knn)
447 | stack.extend(reversed(node.children)) # adding all child nodes to stack for DFS
448 |
449 | all_nodes = [(node, node.value) for node in collect_all_nodes(root)]
450 | logging.info(f"State of all_nodes after iteration: {all_nodes}")
451 | it += 1
452 | # If we reach here, no solution was found
453 | logging.info("All paths explored. No solution found.")
454 | if len(failed_trajectories) == 0:
455 | trajectory = collect_trajectory(last_node)
456 | failed_trajectories.append({"trajectory": trajectory, "final_answer": ""})
457 | return root, 0, all_nodes, maxx, 0, failed_trajectories, success_trajectories
458 |
459 |
460 | def select_node_dfs(stack):
461 | return stack[-1] if stack else None # return the last node in the stack
462 |
463 |
464 |
465 |
466 | def select_node(node):
467 | while node and node.children:
468 | logging.info(
469 | f"Selecting from {len(node.children)} children at depth {node.depth}."
470 | )
471 |
472 | terminal_children = [child for child in node.children if child.is_terminal]
473 | terminal_status = [child.is_terminal for child in node.children]
474 |
475 | if len(terminal_children) == len(node.children):
476 | logging.info(
477 | f"All children are terminal at depth {node.depth}. Backtracking..."
478 | )
479 | if node.parent:
480 | node.parent.children.remove(node)
481 | node = node.parent
482 | continue
483 |
484 | node_with_reward_1 = next(
485 | (child for child in terminal_children if child.reward == 1), None
486 | )
487 | if node_with_reward_1:
488 | logging.info(f"Found terminal node with reward 1 at depth {node.depth}.")
489 | return node_with_reward_1
490 |
491 | node = max(
492 | (child for child in node.children if not child.is_terminal),
493 | key=lambda child: child.uct(),
494 | default=None,
495 | )
496 |
497 |         while node and node.is_terminal and node.reward != 1:
498 | node = max(
499 | (child for child in node.parent.children if not child.is_terminal),
500 | key=lambda child: child.uct(),
501 | default=None,
502 | )
503 |
504 |         if node: logging.info(f"Selected node at depth {node.depth} with UCT {node.uct()}.")
505 |
506 | return node # This will return None if all paths from the root are exhausted
507 |
508 |
509 | def expand_node(node, args, task, idx, knn=None):
510 |     if node.depth >= 30:  # hard-coded to mirror the dfs_search depth_limit default
511 |         logging.info("Depth limit reached")
512 |         print("Depth limit reached")
513 |         node.is_terminal = True
514 |         return
515 |     new_nodes = generate_new_states(node, args, task, idx, knn=knn)
516 |     node.children.extend(new_nodes)
517 |
518 |
519 | def generate_new_states(node, args, task, idx, knn=None):
520 | prompt = generate_prompt(node)
521 | #print("the prompt is:")
522 | #print(prompt)
523 | sampled_actions = get_samples(
524 | task,
525 | prompt,
526 |         "Action: ",
527 | args.n_generate_sample,
528 | prompt_sample=args.prompt_sample,
529 | knn=knn,
530 | )
531 | logging.info(f"SAMPLED ACTION: {sampled_actions}")
532 |
533 | unique_states = {} # Store unique states here
534 | for action in sampled_actions:
535 | local_sessions = copy.deepcopy(node.env_state)
536 | env.sessions = local_sessions
537 | new_state = node.state.copy() # Make a copy of the parent node's state
538 |
539 |
540 | action_line = next(
541 | (
542 |                 line.split(":", 1)[1].strip()
543 | for line in action.split("\n")
544 | if line.startswith("Action") and ":" in line
545 | ),
546 | None,
547 | )
548 |
549 |         # Use the action alone as the uniqueness key (webshop has no thought step here)
550 | unique_key = f"{action_line}"
551 |
552 | if unique_key in unique_states:
553 | continue # Skip if this state already exists
554 |
555 | if action_line:
556 | #print("the action line is:")
557 | #print(action_line)
558 | try:
559 | res = env.step(idx, action_line)
560 | #print("res", res)
561 | obs = res[0]
562 | r = res[1]
563 | done = res[2]
564 | except AssertionError:
565 | obs = 'Invalid action!'
566 | # print("err")
567 | r = -1
568 | done = False
569 |
570 |             if action_line.startswith('think'):
571 |                 obs = 'OK.'  # thoughts do not change the page; record 'OK.' as the observation
572 |
573 | # Update the new state dictionary
574 | new_state["action"] = action_line
575 | new_state["observation"] = obs
576 |
577 | env_state_clone = env.clone_state()
578 |             new_node = Node(state=new_state, question=node.question, env_state=env_state_clone, parent=node)
579 | new_node.env_state = local_sessions
580 | new_node.is_terminal = r == 1 or done
581 | new_node.reward = r
582 | unique_states[unique_key] = new_node # Add this state to unique_states
583 | logging.info(f"NEW NODE: {new_node}")
584 |
585 | if new_node.is_terminal and r == 0:
586 | trajectory = collect_trajectory(new_node)
587 | failed_trajectories.append(
588 | {
589 | "trajectory": trajectory,
590 | "final_answer": f"{action_line}",
591 | }
592 | )
593 | if new_node.is_terminal and r == 1:
594 | trajectory = collect_trajectory(new_node)
595 | success_trajectories.append(
596 | {
597 | "trajectory": trajectory,
598 | "final_answer": f"{action_line}",
599 | }
600 | )
601 |
602 | return list(unique_states.values()) # Return unique nodes as a list
603 |
604 |
605 |
606 |
607 | def print_tree(node, level=0):
608 | indent = " " * level
609 | print(f"{indent}{node}")
610 | for child in node.children:
611 | print_tree(child, level + 1)
612 |
613 |
614 |
615 |
616 | def generate_prompt(node):
617 | trajectory = []
618 | question = node.question
619 | while node:
620 | new_segment = []
621 | if node.state["action"]:
622 | new_segment.append(f"Action: {node.state['action']}")
623 | if (
624 | node.state["observation"] and node.depth != 0
625 | ): # Exclude the observation from the root node
626 | new_segment.append(f"Observation: {node.state['observation']}")
627 | trajectory.append("\n".join(new_segment))
628 | node = node.parent
629 | return question + "\n".join(reversed(trajectory))
630 |
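631 | # The string-parsing helpers above (node_trajectory_to_text, traj_depth) rely
632 | # on the repr produced by Node.__str__, e.g.:
633 | #   Node(depth=1, value=0.00, visits=0, action=search[...], observation=...)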
--------------------------------------------------------------------------------
/webshop/models.py:
--------------------------------------------------------------------------------
1 | import openai
2 | from utils import ModelServer
3 | import time
4 | import re
5 |
6 | server = None
7 |
8 |
9 | def build_server(config_path):
10 | global server
11 | server = ModelServer(config_path=config_path)
12 |
13 |
14 | def split_and_keep_prefixes(s, delimiters):
15 | regex_pattern = f"({'|'.join(map(re.escape, delimiters))})"
16 | parts = re.split(regex_pattern, s)
17 | result = [parts[0]]
18 | for i in range(1, len(parts), 2):
19 | result.append(parts[i] + (parts[i + 1] if i + 1 < len(parts) else ""))
20 | return result
21 |
22 |
23 | def online_embed(traj):
24 | return server.get_completion_or_embedding(
25 | "7",
26 | message=traj,
27 | get_embedding=True,
28 | )
29 |
30 | ret_in = 0
31 | ret_out = 0
32 | def gpt(prompt, model_size="8", temperature=1.0, max_tokens=100, n=1) -> list:
33 | def call_openai_api(messages, model_size, temperature, max_tokens, n):
34 | global ret_in
35 | global ret_out
36 | outputs = []
37 | while n > 0:
38 |             cnt = min(n, 20)  # batch size cap; note the server call below is not passed cnt
39 | n -= cnt
40 | res = server.get_completion_or_embedding(
41 | model_size, messages, temperature, max_tokens
42 | )
43 |             ret_out += res.usage.completion_tokens
44 |             ret_in += res.usage.prompt_tokens
45 | outputs.extend(
46 | [
47 | re.sub(r"^Action:", "", choice.message.content)
48 | for choice in res.choices
49 | ]
50 | )
51 |         return outputs, res.usage.prompt_tokens, res.usage.completion_tokens
52 |
53 | messages = []
54 | parts = re.split(r"(Action:|Observation:|Instruction:)", prompt)
55 |
56 | result = [parts[0].strip()]
57 | for i in range(1, len(parts), 2):
58 | if i + 1 < len(parts):
59 | result.append(parts[i] + " " + parts[i + 1].strip())
60 | result.pop()
61 |     last_obs = ""
62 | for msg in result:
63 | if msg.startswith("Action"):
64 | messages.append({"role": "assistant", "content": msg})
65 | if msg.startswith("Observation"):
66 | messages.append({"role": "user", "content": msg})
67 | last_obs=msg
68 | if msg.startswith("Instruction"):
69 | messages.append({"role": "user", "content": msg})
70 | return call_openai_api(
71 | messages, model_size=model_size, temperature=temperature, max_tokens=max_tokens, n=n
72 | )
73 |
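74 | # Example usage (hypothetical prompt and config path): sample 3 continuations
75 | # and collect token counts from the local endpoint.
76 | #
77 | #   build_server(config_path="webshop_new/config.json")
78 | #   samples, n_in, n_out = gpt("Instruction: ...\nAction: search[...]\nObservation: ...", n=3)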
--------------------------------------------------------------------------------
/webshop/prompt.py:
--------------------------------------------------------------------------------
1 | prompt1 = """Instruction: You are an advanced reasoning agent that can improve based on previous experience. Follow the instruction and purchase an item meeting all of the correct specifications by navigating the website.
2 | i would like a 3 ounce bottle of bright citrus deodorant for sensitive skin, and price lower than 50.00 dollars
3 | [Search]
4 |
5 | Action: search[3 ounce bright citrus deodorant sensitive skin]
6 | Observation:
7 | [Back to Search]
8 | Page 1 (Total results: 50)
9 | [Next >]
10 | [B078GWRC1J]
11 | Bright Citrus Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce
12 | $10.99
13 | [B078GTKVXY]
14 | Ginger Fresh Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce
15 | $10.99
16 | [B08KBVJ4XN]
17 | Barrel and Oak - Aluminum-Free Deodorant, Deodorant for Men, Essential Oil-Based Scent, 24-Hour Odor Protection, Cedar & Patchouli Blend, Gentle on Sensitive Skin (Mountain Sage, 2.7 oz, 2-Pack)
18 | $15.95
19 |
20 | Action: think[B078GWRC1J and B078GTKVXY are bright citrus deodorant less then 50 dollars. I can check B078GWRC1J first.]
21 | Observation: OK.
22 |
23 | Action: click[B078GWRC1J]
24 | Observation:
25 | [Back to Search]
26 | [< Prev]
27 | scent [assorted scents][bright citrus][calming lavender][ginger fresh][simply non-scents]
28 | size [travel set (4-pack)][3 ounce (pack of 1)][3-ounce (2-pack)]
29 | Bright Citrus Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce
30 | Price: $10.99
31 | Rating: N.A.
32 | [Description]
33 | [Features]
34 | [Reviews]
35 | [Buy Now]
36 |
37 | Action: think[For 3 ounce bottle of bright citrus deodorant for sensitive skin, the item has options 'bright citrus' and '3 ounce (pack of 1)' and seems good to buy.]
38 | Observation: OK.
39 |
40 | Action: click[bright citrus]
41 | Observation: You have clicked bright citrus.
42 |
43 | Action: click[3 ounce (pack of 1)]
44 | Observation: You have clicked 3 ounce (pack of 1).
45 |
46 | Action: click[Buy Now]
47 |
48 | """
49 |
50 |
51 | prompt1_feedback = """You are an advanced reasoning agent that can improve based on self-reflection. Follow the instruction and purchase an item meeting all of the correct specifications by navigating the website.
52 | Here is an example:
53 | Webshop
54 | Instruction:
55 | i would like a 3 ounce bottle of bright citrus deodorant for sensitive skin, and price lower than 50.00 dollars
56 | [Search]
57 |
58 | Action: search[3 ounce bright citrus deodorant sensitive skin]
59 | Observation:
60 | [Back to Search]
61 | Page 1 (Total results: 50)
62 | [Next >]
63 | [B078GWRC1J]
64 | Bright Citrus Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce
65 | $10.99
66 | [B078GTKVXY]
67 | Ginger Fresh Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce
68 | $10.99
69 | [B08KBVJ4XN]
70 | Barrel and Oak - Aluminum-Free Deodorant, Deodorant for Men, Essential Oil-Based Scent, 24-Hour Odor Protection, Cedar & Patchouli Blend, Gentle on Sensitive Skin (Mountain Sage, 2.7 oz, 2-Pack)
71 | $15.95
72 |
73 | Action: think[B078GWRC1J and B078GTKVXY are bright citrus deodorant less then 50 dollars. I can check B078GWRC1J first.]
74 | Observation: OK.
75 |
76 | Action: click[B078GWRC1J]
77 | Observation:
78 | [Back to Search]
79 | [< Prev]
80 | scent [assorted scents][bright citrus][calming lavender][ginger fresh][simply non-scents]
81 | size [travel set (4-pack)][3 ounce (pack of 1)][3-ounce (2-pack)]
82 | Bright Citrus Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce
83 | Price: $10.99
84 | Rating: N.A.
85 | [Description]
86 | [Features]
87 | [Reviews]
88 | [Buy Now]
89 |
90 | Action: think[For 3 ounce bottle of bright citrus deodorant for sensitive skin, the item has options 'bright citrus' and '3 ounce (pack of 1)' and seems good to buy.]
91 | Observation: OK.
92 |
93 | Action: click[bright citrus]
94 | Observation: You have clicked bright citrus.
95 |
96 | Action: click[3 ounce (pack of 1)]
97 | Observation: You have clicked 3 ounce (pack of 1).
98 |
99 | Action: click[Buy Now]
100 |
101 | Here are some previous trials in which you were given access to relevant context and an item to purchase. You were unsuccessful in buying the correct item either because you did not find an item meeting all of the required specifications or because you did not select the correct item. Carefully consider the errors and make sure not to repeat the same mistakes. {trajectories}
102 |
103 | {input}
104 | """
105 |
106 | # In[4]:
107 | score_prompt = '''Given an item to purchase and a trajectory that aims to buy an item that exactly matches the specification, which corresponds to the ideal score of 1.0, analyze the following trajectory, then at the last line conclude "Thus the correctness score is {s}", where s is an integer from 1 to 10.
108 |
109 | Here are some examples
110 |
111 | Webshop
112 | Instruction:
113 | i am looking for dairy free and apple variety pack of chips, and price lower than 30.00 dollars
114 | [Search]
115 |
116 | Action: search[dairy free and apple variety pack of chips]
117 | Observation:
118 | [Back to Search]
119 | Page 1 (Total results: 50)
120 | [Next >]
121 | [B07HRFJWP8]
122 | Enjoy Life Foods Soft Baked Ovals, Breakfast Bars, Nut Free Bars, Soy Free, Dairy Free, Non GMO, Gluten Free, Vegan, Variety Pack, 4 Boxes (20 Bars Total)
123 | $100.0
124 | [B01KMHY5PG]
125 | Enjoy Life Soft Baked Chewy Bars, Variety Pack, Nut Free Bars, Soy Free, Dairy Free, Gluten Free, 6 Boxes (30 Total Bars)
126 | $21.49
127 | [B008D2X8C4]
128 | Enjoy Life Lentil Chips Variety Pack, Dairy Free Chips, Soy Free, Nut Free, Non GMO, Vegan, Gluten Free, 24 Bags (0.8 oz)
129 | $100.0
130 |
131 | Action: think[The first two items are not what I want. The third item is the dairy free and apple variety pack of chips I want.]
132 | Observation: OK.
133 |
134 | Action: click[B008D2X8C4]
135 | Observation:
136 | [Back to Search]
137 | [< Prev]
138 | flavor name [dill and sour cream][garlic & parmesan][light sea salt][margherita pizza][thai chili lime][variety pack]
139 | size [0.8 ounce (pack of 24)][4 ounce (pack of 12)]
140 | Price: $100.0
141 | Rating: N.A.
142 | [Description]
143 | [Features]
144 | [Reviews]
145 | [Buy Now]
146 |
147 | Action: think[The item has options 'variety pack' and '0.8 ounce (pack of 24)' and seems good to buy.]
148 | Observation: OK.
149 |
150 | Action: click[variety pack]
151 | Observation: You have clicked variety pack.
152 |
153 | Action: click[0.8 ounce (pack of 24)]
154 | Observation: You have clicked 0.8 ounce (pack of 24).
155 |
156 | Action: click[Buy Now]
157 |
158 | Reflection: In this attempt, I was unsuccessful. I accidentally bought a product that was $100, which is more than my budget of $30. Either way, the initial search results were not good. Next time, I will do search["variety pack of chips"] and then check if the results meet the dairy free and the $30 budget constraints. I will continue to refine my searches so that I can find more products.
159 | Thus the correctness score is 5
160 |
161 | Webshop
162 | Instruction:
163 | i would like a 3 ounce bottle of bright citrus deodorant for sensitive skin, and price lower than 50.00 dollars
164 | [Search]
165 |
166 | Action: search[3 ounce bright citrus deodorant sensitive skin]
167 | Observation:
168 | [Back to Search]
169 | Page 1 (Total results: 50)
170 | [Next >]
171 | [B078GWRC1J]
172 | Bright Citrus Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce
173 | $10.99
174 | [B078GTKVXY]
175 | Ginger Fresh Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce
176 | $10.99
177 | [B08KBVJ4XN]
178 | Barrel and Oak - Aluminum-Free Deodorant, Deodorant for Men, Essential Oil-Based Scent, 24-Hour Odor Protection, Cedar & Patchouli Blend, Gentle on Sensitive Skin (Mountain Sage, 2.7 oz, 2-Pack)
179 | $15.95
180 |
181 | Action: think[B078GWRC1J and B078GTKVXY are bright citrus deodorant less then 50 dollars. I can check B078GWRC1J first.]
182 | Observation: OK.
183 | Reflection: Although the task is not yet complete, the first search is correct and returns relevant results. The thought is also logical and has potential to lead to the correct selection.
184 | Thus the correctness score is 9
185 | {input}
186 | '''
187 |
188 | score_prompt_feedback = '''Given an item to purchase and a trajectory, analyze the following trajectory, then at the last line conclude "Thus the correctness score is {s}", where s is an integer from 1 to 10.
189 | Here are some examples
190 |
191 | Webshop
192 | Instruction:
193 | i am looking for dairy free and apple variety pack of chips, and price lower than 30.00 dollars
194 | [Search]
195 |
196 | Action: search[dairy free and apple variety pack of chips]
197 | Observation:
198 | [Back to Search]
199 | Page 1 (Total results: 50)
200 | [Next >]
201 | [B07HRFJWP8]
202 | Enjoy Life Foods Soft Baked Ovals, Breakfast Bars, Nut Free Bars, Soy Free, Dairy Free, Non GMO, Gluten Free, Vegan, Variety Pack, 4 Boxes (20 Bars Total)
203 | $100.0
204 | [B01KMHY5PG]
205 | Enjoy Life Soft Baked Chewy Bars, Variety Pack, Nut Free Bars, Soy Free, Dairy Free, Gluten Free, 6 Boxes (30 Total Bars)
206 | $21.49
207 | [B008D2X8C4]
208 | Enjoy Life Lentil Chips Variety Pack, Dairy Free Chips, Soy Free, Nut Free, Non GMO, Vegan, Gluten Free, 24 Bags (0.8 oz)
209 | $100.0
210 |
211 | Action: think[The first two items are not what I want. The third item is the dairy free and apple variety pack of chips I want.]
212 | Observation: OK.
213 |
214 | Action: click[B008D2X8C4]
215 | Observation:
216 | [Back to Search]
217 | [< Prev]
218 | flavor name [dill and sour cream][garlic & parmesan][light sea salt][margherita pizza][thai chili lime][variety pack]
219 | size [0.8 ounce (pack of 24)][4 ounce (pack of 12)]
220 | Price: $100.0
221 | Rating: N.A.
222 | [Description]
223 | [Features]
224 | [Reviews]
225 | [Buy Now]
226 |
227 | Action: think[The item has options 'variety pack' and '0.8 ounce (pack of 24)' and seems good to buy.]
228 | Observation: OK.
229 |
230 | Action: click[variety pack]
231 | Observation: You have clicked variety pack.
232 |
233 | Action: click[0.8 ounce (pack of 24)]
234 | Observation: You have clicked 0.8 ounce (pack of 24).
235 |
236 | Action: click[Buy Now]
237 |
238 | Reflection: In this attempt, I was unsuccessful. I accidentally bought a product that was $100, which is more than my budget of $30. Either way, the initial search results were not good. Next time, I will do search["variety pack of chips"] and then check if the results meet the dairy free and the $30 budget constraints. I will continue to refine my searches so that I can find more products.
239 | Thus the correctness score is 5
240 |
241 | Webshop
242 | Instruction:
243 | i would like a 3 ounce bottle of bright citrus deodorant for sensitive skin, and price lower than 50.00 dollars
244 | [Search]
245 |
246 | Action: search[3 ounce bright citrus deodorant sensitive skin]
247 | Observation:
248 | [Back to Search]
249 | Page 1 (Total results: 50)
250 | [Next >]
251 | [B078GWRC1J]
252 | Bright Citrus Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce
253 | $10.99
254 | [B078GTKVXY]
255 | Ginger Fresh Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce
256 | $10.99
257 | [B08KBVJ4XN]
258 | Barrel and Oak - Aluminum-Free Deodorant, Deodorant for Men, Essential Oil-Based Scent, 24-Hour Odor Protection, Cedar & Patchouli Blend, Gentle on Sensitive Skin (Mountain Sage, 2.7 oz, 2-Pack)
259 | $15.95
260 |
261 | Action: think[B078GWRC1J and B078GTKVXY are bright citrus deodorant less then 50 dollars. I can check B078GWRC1J first.]
262 | Observation: OK.
263 | Reflection: Although the task is not yet complete, the first search is correct and returns relevant results. The thought is also logical and has potential to lead to the correct selection.
264 | Thus the correctness score is 9
265 | {trajectories}
266 | {input}
267 | '''
268 |
269 | # trivial search & item, choose option
270 | prompt1_actonly = """Webshop
271 | Instruction:
272 | i would like a 3 ounce bottle of bright citrus deodorant for sensitive skin, and price lower than 50.00 dollars
273 | [Search]
274 |
275 | Action: search[3 ounce bright citrus deodorant sensitive skin]
276 | Observation:
277 | [Back to Search]
278 | Page 1 (Total results: 50)
279 | [Next >]
280 | [B078GWRC1J]
281 | Bright Citrus Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce
282 | $10.99
283 | [B078GTKVXY]
284 | Ginger Fresh Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce
285 | $10.99
286 | [B08KBVJ4XN]
287 | Barrel and Oak - Aluminum-Free Deodorant, Deodorant for Men, Essential Oil-Based Scent, 24-Hour Odor Protection, Cedar & Patchouli Blend, Gentle on Sensitive Skin (Mountain Sage, 2.7 oz, 2-Pack)
288 | $15.95
289 |
290 | Action: click[B078GWRC1J]
291 | Observation:
292 | [Back to Search]
293 | [< Prev]
294 | scent [assorted scents][bright citrus][calming lavender][ginger fresh][simply non-scents]
295 | size [travel set (4-pack)][3 ounce (pack of 1)][3-ounce (2-pack)]
296 | Bright Citrus Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce
297 | Price: $10.99
298 | Rating: N.A.
299 | [Description]
300 | [Features]
301 | [Reviews]
302 | [Buy Now]
303 |
304 | Action: click[bright citrus]
305 | Observation: You have clicked bright citrus.
306 |
307 | Action: click[3 ounce (pack of 1)]
308 | Observation: You have clicked 3 ounce (pack of 1).
309 |
310 | Action: click[Buy Now]
311 | """
312 |
313 | reflection_prompt = '''You are an advanced reasoning agent that can improve based on self-reflection. You will be given a previous reasoning trial in which you were given access to a shopping website and a specific type of item to buy. You were given access to relevant context and an item to purchase. You were unsuccessful in buying the correct item either because you did not find an item meeting all of the required specifications or because you did not select the correct item. The ideal score is 1.0, and anything less is incorrect. In a few sentences, diagnose a possible reason for failure and devise a new, concise, high-level plan that aims to mitigate the same failure. Use complete sentences.
314 | Here are some examples:
315 |
316 | Previous Trial
317 | Instruction:
318 | i am looking for dairy free and apple variety pack of chips, and price lower than 30.00 dollars
319 | [Search]
320 |
321 | Action: search[dairy free and apple variety pack of chips]
322 | Observation:
323 | [Back to Search]
324 | Page 1 (Total results: 50)
325 | [Next >]
326 | [B07HRFJWP8]
327 | Enjoy Life Foods Soft Baked Ovals, Breakfast Bars, Nut Free Bars, Soy Free, Dairy Free, Non GMO, Gluten Free, Vegan, Variety Pack, 4 Boxes (20 Bars Total)
328 | $100.0
329 | [B01KMHY5PG]
330 | Enjoy Life Soft Baked Chewy Bars, Variety Pack, Nut Free Bars, Soy Free, Dairy Free, Gluten Free, 6 Boxes (30 Total Bars)
331 | $21.49
332 | [B008D2X8C4]
333 | Enjoy Life Lentil Chips Variety Pack, Dairy Free Chips, Soy Free, Nut Free, Non GMO, Vegan, Gluten Free, 24 Bags (0.8 oz)
334 | $100.0
335 |
336 | Action: think[The first two items are not what I want. The third item is the dairy free and apple variety pack of chips I want.]
337 | Observation: OK.
338 |
339 | Action: click[B008D2X8C4]
340 | Observation:
341 | [Back to Search]
342 | [< Prev]
343 | flavor name [dill and sour cream][garlic & parmesan][light sea salt][margherita pizza][thai chili lime][variety pack]
344 | size [0.8 ounce (pack of 24)][4 ounce (pack of 12)]
345 | Price: $100.0
346 | Rating: N.A.
347 | [Description]
348 | [Features]
349 | [Reviews]
350 | [Buy Now]
351 |
352 | Action: think[The item has options 'variety pack' and '0.8 ounce (pack of 24)' and seems good to buy.]
353 | Observation: OK.
354 |
355 | Action: click[variety pack]
356 | Observation: You have clicked variety pack.
357 |
358 | Action: click[0.8 ounce (pack of 24)]
359 | Observation: You have clicked 0.8 ounce (pack of 24).
360 |
361 | Action: click[Buy Now]
362 |
363 | STATUS: FAIL
364 |
365 | Reflection: In this attempt, I was unsuccessful. I accidentally bought a product that was $100, which is more than my budget of $30. Either way, the initial search results were not good. Next time, I will do search["variety pack of chips"] and then check if the results meet the dairy free and the $30 budget constraints. I will continue to refine my searches so that I can find more products.
366 |
367 | Previous Trial
368 | Instruction:
369 | i need gluten free vegetarian smoked peppered bacon - 4 ounce (pack of 2), and price lower than 40.00 dollars
370 | [Search]
371 |
372 | Action: search[gluten free vegetarian smoked peppered bacon 4 ounce pack of 2]
373 | Observation:
374 | [Back to Search]
375 | Page 1 (Total results: 50)
376 | [Next >]
377 | [B06Y96MXJV]
378 | Smoked Bacon Sea Salt 3-Pack: Smoked Bacon Chipotle, Smoked Bacon and Onion, Smoked Peppered Bacon - All-Natural Sea Salts Slowly Smoked Over Alderwood - No Gluten, No MSG, Non-GMO (12 total oz.)
379 | $29.99
380 | [B06Y96N1KG]
381 | Spicy Hot Pepper Sea Salt 3-Pack: Ghost Pepper, Jalapeno, Habanero - All Natural, Delicious Blends of Peppers & Sea Salt - Gluten-Free, Kosher, No MSG, Non-GMO (12 total oz.)
382 | $29.99
383 | [B07GJTKYJQ]
384 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces)
385 | $42.99
386 |
387 | Action: think[The first two items are not what I want. The third item is not what I want, but it is the closest one.]
388 | Observation: OK.
389 |
390 | Action: click[B07GJTKYJQ]
391 | Observation:
392 | [Back to Search]
393 | [< Prev]
394 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces)
395 | Price: $42.99
396 | Rating: N.A.
397 | [Description]
398 | [Features]
399 | [Reviews]
400 | [Buy Now]
401 |
402 | Action: think[The price is too high and it is not what I want.]
403 | Observation: OK.
404 |
405 | Action: click[< Prev]
406 | Observation:
407 | [Back to Search]
408 | Page 1 (Total results: 50)
409 | [Next >]
410 | [B06Y96MXJV]
411 | Smoked Bacon Sea Salt 3-Pack: Smoked Bacon Chipotle, Smoked Bacon and Onion, Smoked Peppered Bacon - All-Natural Sea Salts Slowly Smoked Over Alderwood - No Gluten, No MSG, Non-GMO (12 total oz.)
412 | $29.99
413 | [B06Y96N1KG]
414 | Spicy Hot Pepper Sea Salt 3-Pack: Ghost Pepper, Jalapeno, Habanero - All Natural, Delicious Blends of Peppers & Sea Salt - Gluten-Free, Kosher, No MSG, Non-GMO (12 total oz.)
415 | $29.99
416 | [B07GJTKYJQ]
417 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces)
418 | $42.99
419 |
420 | Action: think[The first two items are not what I want.]
421 | Observation: OK.
422 |
423 | Action: click[Next >]
424 | Observation: Invalid action!
425 |
426 | Action: think[I need to refine my search.]
427 | Observation: OK.
428 |
429 | Action: click[< Back to Search]
430 | Observation: Invalid action!
431 |
432 | Action: click[Back to Search]
433 |
434 | STATUS: FAIL
435 |
436 | Reflection: In this attempt, I was unsuccessful. The initial search results were not the products that I was looking for. Next time, I will do search["vegetarian bacon"] for a broader search. I will check that the new results will fulfill the gluten free and 4 ounce pack of 2 constraints. I will continue to refine my searches so that I can find more products.
437 |
438 | Previous trial:
439 | {trajectory}
440 | Reflection:'''
--------------------------------------------------------------------------------
/webshop/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohttp==3.8.4
2 | aiosignal==1.3.1
3 | async-timeout==4.0.2
4 | attrs==23.1.0
5 | backoff==2.2.1
6 | certifi==2023.5.7
7 | charset-normalizer==3.1.0
8 | frozenlist==1.3.3
9 | idna==3.4
10 | mpmath==1.3.0
11 | multidict==6.0.4
12 | numpy==1.24.3
13 | openai==0.27.7
14 | requests==2.31.0
15 | sympy==1.12
16 | tqdm==4.65.0
17 | urllib3==2.0.2
18 | yarl==1.9.2
19 | pandas==2.0.3
--------------------------------------------------------------------------------
/webshop/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import numpy as np
5 | import math
6 | from models import online_embed, build_server
7 | import faiss
8 | import logging
9 | import random
10 |
11 | import lats
12 | from lats import dfs_search,quote_env
13 | from webshop import WebShopTask
14 |
15 |
16 | def random_selection(lst, n=5):
17 | if len(lst) <= n:
18 | return lst
19 | else:
20 | return random.sample(lst, n)
21 |
22 |
23 | def run(args):
24 | task = WebShopTask()
25 | print(task)
26 | logs, cnt_avg, cnt_any = [], 0, 0
27 |
28 |     os.makedirs(args.run_name, exist_ok=True)  # ensure the run directory exists before writing config
29 | config_path = os.path.join(args.run_name, "config.json")
30 |
31 | with open(config_path, "w", encoding="utf-8") as wf:
32 | info_dict = vars(args)
33 | info_dict["is_running"] = True
34 | json.dump(info_dict, wf, indent=4)
35 |
36 | build_server(config_path=config_path)
37 |
38 | os.makedirs(os.path.dirname(args.log_dir), exist_ok=True)
39 |
40 | logging.basicConfig(
41 | filename=args.log_dir,
42 | level=logging.INFO,
43 | format="%(asctime)s - %(levelname)s - %(message)s",
44 | filemode="a",
45 | )
46 |
47 | wins = {}
48 | lose = {}
49 | trajectories = []
50 | embedding_array = np.zeros((0, 3584))
51 |     tongji = []  # cumulative success count after each trial ("tongji" = statistics)
52 |     avg_tongji = []  # average best score after each trial
53 |     cnt = 0  # number of tasks solved so far
54 |     scores = []
55 |     for i in range(50):  # best reward per task; assumes tasks 0-49
56 |         scores.append(0.0)
57 | for trial in range(10):
58 | with open("debug_log.log", "a", encoding="utf-8") as f:
59 | f.write("Trial")
60 | f.write(str(trial)+"\n")
61 | count = 0
62 | task_accs = []
63 | info = []
64 |         emb_db = faiss.IndexFlatL2(3584)  # 3584 = embedding dimension
65 | emb_db.add(embedding_array.astype("float32"))
66 | avg=0.0
67 | for i in range(args.task_start_index, args.task_end_index):
68 | with open("debug_log1.log", "a", encoding="utf-8") as f:
69 | f.write("------------new task---------------\n")
70 | f.write(str(i)+"\n")
71 | # solve
72 | if i in wins:
73 | continue
74 | prev = None
75 | knnret = []
76 | if trajectories and args.cot_size>0:
77 | if i in lose:
78 | prev = lose[i]
79 | else:
80 | prev=quote_env(f'fixed_{i}')
81 | if args.cot_method=="knn":
82 | fail_vec = online_embed(str(prev))
83 | dist=[]
84 | for row_idx in range(embedding_array.shape[0]):
85 | # inverse L2 distance to the failed trajectory serves as a similarity weight
86 | realN = 1.0 / np.linalg.norm(np.array(fail_vec) - embedding_array[row_idx])
87 | dist.append((realN, row_idx))
88 | dist.sort(key=lambda x: x[0], reverse=True)
89 | sz_now=min(len(trajectories),args.cot_size)
90 | dist=dist[:sz_now]
91 | dist=[(math.exp(5.0*x),y) for (x,y) in dist]
92 | original_sum=sum([x for (x,_) in dist])
93 | dist=[(x *float(sz_now) / original_sum,y) for (x,y) in dist]
94 | tot=0
95 | cntD=[]
96 | realD=[]
97 | for (x,y) in dist:
98 | cntD.append(math.floor(x))
99 | realD.append(x)
100 | tot+=math.floor(x)
101 | while tot < sz_now:  # largest-remainder top-up of the floored quotas in cntD
102 | maxx = -1.0
103 | maxi = -1
104 | for ig in range(len(cntD)):
105 | if realD[ig] - float(cntD[ig]) > maxx:
106 | maxx = realD[ig] - float(cntD[ig])
107 | maxi = ig
108 | cntD[maxi] += 1
109 | tot += 1
110 | weights=[]
111 | for (x,y) in dist:
112 | weights.append(x)
113 | normalized_weights = [w / sum(weights) for w in weights]
114 | for ig in range(sz_now):
115 | s_ind = np.random.choice(len(weights), p=normalized_weights)
116 | (_,y)=dist[s_ind]
117 | knnret.append(trajectories[y])
118 | else:
119 | knnret=random.sample(trajectories, min(args.cot_size,len(trajectories)))
120 | state, value, all_nodes, reward, em, failt, succt = dfs_search(
121 | args, task, f'fixed_{i}', args.iteration, knnret
122 | )
123 | if failt:
124 | with open("debug_log.log", "a", encoding="utf-8") as f:
125 | f.write("FAIL\n")
126 | print("Fail")
127 | print(i)
128 | lose[i] = failt[0]
129 | if succt:
130 | cnt=cnt+1
131 | with open("debug_log.log", "a", encoding="utf-8") as f:
132 | f.write("SUCCESS\n")
133 | print("Success")
134 | print(i)
135 | wins[i] = 1
136 | vec = online_embed(str(succt[0]))
137 | trajectories.append(succt[0])
138 | embedding_array = np.vstack((embedding_array, np.array(vec)))
139 | scores[i - args.task_start_index] = max(scores[i - args.task_start_index], reward)
140 | avg+=reward
141 | # log main metric
142 | if em is None:
143 | em = 0
144 | task_accs.append(em)
145 | cnt_avg = sum(task_accs) / len(task_accs)
146 | print(i, "len(task_accs)", len(task_accs), "cnt_avg", cnt_avg, "\n")
147 | tongji.append(cnt)
148 | avg = sum(scores) / float(num_tasks)
149 | with open("debug_log.log", "a", encoding="utf-8") as f:
150 | f.write("average: ")
151 | f.write(str(avg)+"\n")
152 | avg_tongji.append(avg)
153 | print("prompt:")
154 | print(lats.prompt_tokens)
155 | print("completion:")
156 | print(lats.completion_tokens)
157 | print(tongji)
158 | print(avg_tongji)
159 | n = args.task_end_index - args.task_start_index
160 |
161 |
162 | def parse_args():
163 | parser = argparse.ArgumentParser()
164 | parser.add_argument("--model_size", type=str, default="8")
165 | parser.add_argument("--temperature", type=float, default=1.0)
166 | parser.add_argument("--task_start_index", type=int, default=900)
167 | parser.add_argument("--task_end_index", type=int, default=1000)
168 | parser.add_argument("--prompt_sample", type=str, choices=["standard", "cot"])
169 | parser.add_argument("--n_generate_sample", type=int, default=1)
170 | parser.add_argument("--n_evaluate_sample", type=int, default=1)
171 | parser.add_argument("--iteration", type=int, default=50)
172 | parser.add_argument("--algorithm", type=str, choices=["lats", "rap", "tot"])
173 | parser.add_argument("--cot_method", type=str, choices=["knn", "random", "None"])
174 | parser.add_argument("--run_name", type=str)
175 | parser.add_argument("--log_file_path", type=str)
176 | parser.add_argument("--log_dir", type=str)  # path to the log *file* passed to logging.basicConfig
177 | parser.add_argument("--cot_size", type=int, default=0)
178 | args = parser.parse_args()
179 | return args
180 |
181 |
182 | if __name__ == "__main__":
183 | args = parse_args()
184 | print(args)
185 | run(args)
186 |
--------------------------------------------------------------------------------
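Note: the exemplar-selection block in run() above (the knn branch) is dense. The standalone sketch below isolates its quota step, assuming the reconstructed largest-remainder loop; the function name apportion_slots and the sample numbers are mine, and run() ultimately draws exemplars by weighted random sampling rather than by these integer quotas.

import math

def apportion_slots(similarities, k):
    """Largest-remainder rounding: share k slots in proportion to exp(5*s)."""
    sharpened = [math.exp(5.0 * s) for s in similarities]
    scale = float(k) / sum(sharpened)
    real = [x * scale for x in sharpened]   # real-valued quotas summing to k
    slots = [math.floor(x) for x in real]   # floor pass
    total = sum(slots)
    while total < k:                        # hand leftovers to the largest remainders
        i = max(range(len(real)), key=lambda j: real[j] - slots[j])
        slots[i] += 1
        total += 1
    return slots

print(apportion_slots([0.9, 0.5, 0.1], 5))  # -> [4, 1, 0]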
/webshop/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
4 | os.environ["PYTHONUTF8"] = "1"
5 | import time
6 | from typing import Dict, List
7 | import openai
8 | import random, json
9 | import math
10 |
11 | # NOTE: assumed wiring -- client/model_name are used below but defined nowhere else in this file; point them at an OpenAI-compatible endpoint (the 1.x client API used here postdates the openai==0.27.7 pin in requirements.txt).
12 | client = openai.OpenAI()  # expects OPENAI_API_KEY (and optionally OPENAI_BASE_URL)
13 | model_name = os.environ.get("MODEL_NAME", "")
14 | tot_in = 0   # running prompt-token total
15 | tot_out = 0  # running completion-token total
16 | class ModelServer:
17 |
18 | def get_completion_or_embedding(
19 | self,
20 | model_size: str,
21 | message,
22 | temperature: float = 0.0,
23 | max_tokens: int = 256,
24 | get_embedding: bool = False,
25 | ) -> str:
26 | global tot_in
27 | global tot_out
28 | assert model_size in ["70", "8", "7"]
29 |
30 | if not get_embedding:
31 | assert type(message) == list, "Message should be a list."
32 | response = client.chat.completions.create(
33 | model=model_name,
34 | messages=message,
35 | max_tokens=max_tokens,
36 | temperature=temperature,
37 | stop=["<|eot_id|>", "\nObservation", "Observation"],
38 | )
39 | tot_out+=response.usage.completion_tokens
40 | tot_in+=response.usage.prompt_tokens
41 | else:
42 | assert type(message) == str, "Message should be a string."
43 | response = client.embeddings.create(
44 | model=model_name,
45 | input=message,
46 | )
47 |
48 |
49 | if get_embedding:
50 | return response.data[0].embedding
51 | else:
52 | return response
53 |
54 |
55 |
--------------------------------------------------------------------------------
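Note: a hypothetical usage of ModelServer, assuming the client/model_name wiring sketched at the top of the file points at a live OpenAI-compatible endpoint; the message contents are made up.

server = ModelServer()

# Chat completion: message must be a list of chat turns.
reply = server.get_completion_or_embedding(
    "8",
    [{"role": "user", "content": "Action: search[vegetarian bacon]"}],
    temperature=0.7,
    max_tokens=128,
)

# Embedding: message must be a plain string; the vector itself is returned.
vec = server.get_completion_or_embedding(
    "8",
    "gluten free vegetarian bacon 4 ounce pack of 2",
    get_embedding=True,
)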
/webshop/webshop.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | from base import Task
4 | from prompt import *
5 | from models import gpt
6 | import logging
7 | import random
8 | from transformers import GPT2Tokenizer
9 |
10 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
11 |
12 | def get_token_length(text):
13 | return len(tokenizer.encode(text))
14 |
15 | max_token_length = 131072  # context budget of the served model; the GPT-2 tokenizer above is used only for counting
16 |
17 | class WebShopTask(Task):
18 | """
19 | Input (x) : a text instruction
20 | Output (y) : a text generation
21 | Reward (r) : # TODO
22 | Input Example:
23 | Output Example:
24 | """
25 | def __init__(self):
26 | """
27 | file: a text file, each line is some sentences
28 | """
29 | super().__init__()
30 | self.steps = 7
31 | self.stops = ['\nObservation:\n', None]
32 | self.value_cache = {}
33 | self.reflections = []
34 |
35 | def test_output(self, idx: int, output: str):
36 | output = output.split('Action:\n')[-1]
37 | prompt = score_prompt + output
38 | score_outputs = gpt(prompt, n=5, model='gpt-4')
39 | scores = []
40 | for score_output in score_outputs:
41 | # print(score_output)
42 | pattern = r".*correctness score is (\d+).*"
43 | match = re.match(pattern, score_output, re.DOTALL)
44 | if match:
45 | score = int(match.groups()[0])
46 | scores.append(score)
47 | else:
48 | print(f'------------------score no match: {[score_output]}')
49 | print(scores)
50 | # print('------------')
51 | info = {'rs': scores, 'r': sum(scores) / len(scores) if scores else 0}
52 | return info
53 |
54 | @staticmethod
55 | def standard_prompt_wrap(x: str, y:str='') -> str:
56 | return standard_prompt.format(input=x) + y
57 |
58 | @staticmethod
59 | def generate_single_reflection(traj, question):
60 | # single-trajectory variant; the batch version below keeps the
61 | # generate_self_reflection name (a duplicate def would shadow this one)
62 | reflect_prompt = reflection_prompt.format(trajectory=traj)
63 | reflection = gpt(reflect_prompt, model='gpt-4')  # gpt4 is not imported here; reuse gpt with model='gpt-4'
64 |
65 | traj_with_reflection = traj + "Reflection: " + reflection[0] + "\n"
66 |
67 | reflection_mapping = {
68 | 'question': question,
69 | 'reflection': reflection[0]
70 | }
71 |
72 | return traj_with_reflection, reflection_mapping
73 |
74 | @staticmethod
75 | def generate_self_reflection(z, question):
76 | reflection_mapping = []
77 | trajectories = ""
78 |
79 | sampled_items = random.sample(z, min(3, len(z)))
80 | failed_trajectories = [item['trajectory'] + f"\nReward: {item['r']}\n" for item in sampled_items if isinstance(item, dict) and 'trajectory' in item and 'r' in item]
81 |
82 | for traj in failed_trajectories:
83 | trajectories += traj
84 | reflect_prompt = reflection_prompt.format(trajectory=traj)
85 |
86 | reflection = gpt(reflect_prompt)
87 |
88 | trajectories += "Reflection: " + reflection[0] + "\n"
89 |
90 | reflection_mapping.append({
91 | 'question': question,
92 | 'trajectory': traj,
93 | 'reflection': reflection[0]
94 | })
95 |
96 | return reflection_mapping
97 |
98 | @staticmethod
99 | def cot_prompt_wrap(x: str, y: str = '', reflection_mapping_list=[], traj_list=[]):
100 | # base few-shot prompt, then the sampled exemplar trajectories, then the
101 | # current input (question x plus partial trajectory y)
102 | inp = x + y  # avoid shadowing the input builtin
103 |
104 | ret = prompt1
105 | for tj in traj_list:
106 | ret += tj + "\n"
107 | ret += inp + "\n"
108 | return ret
109 |
110 |
111 |
112 | @staticmethod
113 | def vote_prompt_wrap(x: str, ys: list) -> str:
114 | prompt = score_prompt + "\n" + x + "\n\n"
115 | for i, y in enumerate(ys, 1):
116 | # y = y.replace('Plan:\n', '')
117 | # TODO: truncate the plan part?
118 | prompt += f'Choice {i}:\n{y}\n'
119 | return prompt
120 |
121 | @staticmethod
122 | def vote_outputs_unwrap(vote_outputs: list, n_candidates: int) -> list:
123 | vote_results = [0] * n_candidates
124 | for vote_output in vote_outputs:
125 | pattern = r".*best trajectory is .*?(\d+).*"  # non-greedy so multi-digit choice numbers are captured whole
126 | match = re.match(pattern, vote_output, re.DOTALL)
127 | if match:
128 | vote = int(match.groups()[0]) - 1
129 | if vote in range(n_candidates):
130 | vote_results[vote] += 1
131 | else:
132 | print(f'vote no match: {[vote_output]}')
133 | return vote_results
134 |
135 | @staticmethod
136 | def compare_prompt_wrap(x: str, ys: list) -> str:
137 | assert len(ys) == 2, 'compare prompt only supports 2 candidates'
138 |
139 | # Extract the last Action for each trajectory
140 | last_actions = []
141 | for y in ys:
142 | # Split by line and reverse to start from the end
143 | lines = y.split('\n')[::-1]
144 | for line in lines:
145 | # Check for an Action line and get its content
146 | if "Action" in line:
147 | last_actions.append(line.split('Action')[-1].strip(': '))
148 | break
149 |
150 | assert len(last_actions) == 2, 'Expected to find 2 Actions'
151 |
152 | # Construct the prompt with the extracted Actions
153 | prompt = compare_prompt + f'Action 1:{last_actions[0]}\n\nAction 2:{last_actions[1]}\n'
154 | return prompt
155 |
156 |
157 | @staticmethod
158 | def compare_output_unwrap(compare_output: str):
159 | if 'more correct trajectory is 1' in compare_output:
160 | return 0
161 | elif 'more correct trajectory is 2' in compare_output:
162 | return 1
163 | elif "two trajectories are similarly correct" in compare_output:
164 | return 0.5
165 | else:
166 | print(f'-----------------compare no match: {[compare_output]}')
167 | return -1
168 |
169 | @staticmethod
170 | def value_prompt_wrap(x: str, y: str, z: list = [], reflections: list = []) -> str:
171 | question = x.split('\n')[0]
172 | if len(z) != 0:
173 | failed_trajectories = ""
174 | for traj, ref in zip(z, reflections):
175 | score = int(traj['r'] * 10) / 2  # map reward in [0, 1] to a 0-5 half-point score
176 | trajectory = traj['trajectory']
177 | split_trajectory = trajectory.split('Action: ')
178 | first_part = split_trajectory[0] # This part will not be modified
179 |
180 | # Remove the first 'Action' and corresponding 'Observation'
181 | remaining_parts = split_trajectory[2:]
182 |
183 | # Reconstruct the trajectory string
184 | new_trajectory = 'Action: '.join([first_part] + remaining_parts)
185 | traj['trajectory'] = new_trajectory
186 | failed_trajectories += f"{y}\n{traj}\nReflection: {ref['reflection']}\nThus the correctness score is {score}\n"
187 |
188 | inp = y + "\n\nReflection: "
189 | prompt = score_prompt_feedback.format(s="", trajectories=failed_trajectories, input=inp)
190 | else:
191 | inp = y + "\n\nReflection: "
192 | prompt = score_prompt.format(s="", input=inp)
193 |
194 | return prompt
195 |
196 |
197 | @staticmethod
198 | def value_outputs_unwrap(evaluate_prompt: str):
199 | evaluate_prompt = evaluate_prompt[0]
200 | # Pull out the first standalone 1-10 score; bare substring checks such as
201 | # '10' in evaluate_prompt misfire on text like "8 out of 10".
202 | match = re.search(r"\b(10|[1-9])\b", evaluate_prompt)
203 | if match:
204 | return int(match.group(1)) / 10.0
205 | return -1
222 |
223 |
--------------------------------------------------------------------------------
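Note: a quick sanity check of the score unwrapping above; the sample model outputs are fabricated for illustration.

samples = [
    ["Thus the correctness score is 8"],   # -> 0.8
    ["I would rate this 10 out of 10"],    # -> 1.0 (first standalone number wins)
    ["No numeric judgment here"],          # -> -1
]
for s in samples:
    print(WebShopTask.value_outputs_unwrap(s))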