├── .gitignore ├── LICENSE ├── README.md ├── alfworld ├── README ├── alfworld_trial.py ├── base_config.yaml ├── generate_reflections.py ├── main.py ├── prompts │ ├── alfworld.json │ ├── alfworld_3prompts.json │ ├── fever.json │ └── prompts_naive.json ├── requirements.txt └── utils.py ├── figures └── illusion.png ├── hotpotqa ├── README ├── base.py ├── hotpot.py ├── hotpotqa.py ├── models.py ├── requirements.txt ├── run.py ├── tot.py └── utils.py └── webshop ├── README ├── base.py ├── lats.py ├── models.py ├── prompt.py ├── requirements.txt ├── run.py ├── utils.py └── webshop.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | .pybuilder/ 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | # For a library or package, you might want to ignore these files since the code is 86 | # intended to run in multiple environments; otherwise, check them in: 87 | # .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # poetry 97 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 98 | # This is especially recommended for binary packages to ensure reproducibility, and is more 99 | # commonly ignored for libraries. 100 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 101 | #poetry.lock 102 | 103 | # pdm 104 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 105 | #pdm.lock 106 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 107 | # in version control. 108 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 109 | .pdm.toml 110 | .pdm-python 111 | .pdm-build/ 112 | 113 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 161 | #.idea/ 162 | 163 | /alfworld_runs/70bins/ 164 | /alfworld_runs/8bins/ 165 | /alfworld_runs/7embed/ 166 | /alfworld_runs/2embed/ 167 | /alfworld_runs/reranker/ 168 | /alfworld_runs/rerankers/ 169 | /alfworld_runs/logs/ 170 | /alfworld_runs/.infinity_cache/ 171 | /alfworld_runs/e5_7b/ 172 | /8bbase/ 173 | /70bins/ 174 | /8bins/ 175 | /7embed/ 176 | /2embed/ 177 | /reranker/ 178 | /rerankers/ 179 | /logs/ 180 | /.infinity_cache/ 181 | /e5_7b/ 182 | /8bbase/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CoPS: Empowering LLM Agents with Provable Cross-Task Experience Sharing 2 | 3 | This repo holds the code for [CoPS: Empowering LLM Agents with Provable Cross-Task Experience Sharing](https://arxiv.org/abs/2410.16670) by Chen Yang, Chenyang Zhao, Quanquan Gu and Dongruo Zhou. 
4 | 5 | ![CoPS](./figures/illusion.png) 6 | 7 | CoPS enhances sequential reasoning in LLM-powered agents by enabling effective sharing and selection of cross-task experiences through a pessimism-based strategy. 8 | 9 | ## Installation 10 | 11 | Our repo is mainly built on SGLang version 0.2.15. 12 | 13 | ```shell 14 | git clone -b v0.2.15 https://github.com/sgl-project/sglang.git 15 | cd sglang 16 | pip install -e "python[all]" 17 | 18 | # Adjust the flashinfer wheel index to match your device's CUDA and PyTorch versions 19 | pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ 20 | ``` 21 | 22 | To install the ALFWorld and WebShop benchmarks locally, please refer to the following links: 23 | 24 | [https://github.com/alfworld/alfworld](https://github.com/alfworld/alfworld) 25 | 26 | [https://github.com/princeton-nlp/WebShop](https://github.com/princeton-nlp/WebShop) 27 | 28 | ## Reproduce 29 | 30 | To reproduce the results on each benchmark, refer to its subdirectory. 31 | 32 | ## Model Server 33 | 34 | To accelerate CoPS, we implement a data-parallel serving router and engine on the SGLang backend, i.e., the [ModelServer](https://github.com/zhaochenyang20/ModelServer). After configuration, you can get a completion or embedding with commands like the following: 35 | 36 |
37 | ModelSever Command 38 | 39 | ```python 40 | from model_sever import ModelServer 41 | message = [ 42 | { 43 | "role": "user", 44 | "content": "Interact with a household to solve a task. Here are a few examples.\nYou are in the middle of a room. Looking quickly around you, you see a cabinet 4, a cabinet 3, a cabinet 2, a cabinet 1, a coffeemachine 1, a countertop 1, a diningtable 3, a diningtable 2, a diningtable 1, a drawer 1, a fridge 1, a garbagecan 1, a microwave 1, a sidetable 1, a sinkbasin 1, a stoveburner 4, a stoveburner 3, a stoveburner 2, a stoveburner 1, and a toaster 1.\nYour task is to: find some apple and put it in sidetable.\n> think: To solve the task, I need to find and take an apple, then put it in sidetable.\nOK.\n> think: First I need to find an apple. An apple is more likely to appear in fridges (1), diningtables (1-3), sidetables (1), countertops (1), sinkbasins (1), garbagecan (1). I can check one by one, starting with fridge 1.\nOK.\n> go to fridge 1\nThe fridge 1 is closed.\n> open fridge 1\nYou open the fridge 1. The fridge 1 is open. In it, you see a lettuce 2, a mug 2, a potato 2, and a tomato 1.\n> go to diningtable 1\nOn the diningtable 1, you see nothing.\n> go to diningtable 2\nOn the diningtable 2, you see a pineapple 1.\n> go to diningtable 3\nOn the diningtable 3, you see a bowl 2, a bowl 1, a bread 2, and a bread 1.\n> go to sidetable 1\nOn the sidetable 1, you see a cup 1, a lettuce 1, a peppershaker 3, a potato 1, and a saltshaker 1.\n> go to countertop 1\nOn the countertop 1, you see a butterknife 1, a dishsponge 1, and a soapbottle 1.\n> go to sinkbasin 1\nOn the sinkbasin 1, you see a cup 2, a egg 1, a mug 1, and a potato 3.\n> go to garbagecan 1\nOn the garbagecan 1, you see a apple 3, and a egg 3.\n> think: Now I find an apple (3). Next, I need to take it.\n> take apple 3 from garbagecan 1\nYou pick up the apple 3 from the garbagecan 1.\n> think: Now I find an apple (3). Next, I need to put it in/on sidetable 1.\nOK.\n> go to sidetable 1\nOn the sidetable 1, you see a cup 1, a lettuce 1, a peppershaker 3, a potato 1, and a saltshaker 1.\n> put apple 3 in/on sidetable 1\nYou put the apple 3 in/on the sidetable 1.\nYou are in the middle of a room. Looking quickly around you, you see a cabinet 4, a cabinet 3, a cabinet 2, a cabinet 1, a countertop 1, a garbagecan 1, a handtowelholder 2, a handtowelholder 1, a sinkbasin 2, a sinkbasin 1, a toilet 1, a toiletpaperhanger 1, and a towelholder 1.\nYour task is to: put some spraybottle on toilet.\n> think: To solve the task, I need to find and take a sparybottle, then put it on toilet.\nOK.\n> think: First I need to find a spraybottle. A spraybottle is more likely to appear in cabinet (1-4), countertop (1), toilet (1), sinkbasin (1-2), garbagecan (1). I can check one by one, starting with cabinet 1.\nOK.\n> go to cabinet 1\nOn the cabinet 1, you see a cloth 1, a soapbar 1, a soapbottle 1.\n> go to cabinet 2\nThe cabinet 2 is closed.\n> open cabinet 2\nYou open the cabinet 2. The cabinet 2 is open. In it, you see a candle 1, and a spraybottle 2.\n> think: Now I find a spraybottle (2). Next, I need to take it.\n> take spraybottle 2 from cabinet 2\nYou pick up the spraybottle 2 from the cabinet 2.\n> think: Now I take a spraybottle (2). Next, I need to put it in/on toilet 1.\nOK.\n> go to toilet 1\nOn the toilet 1, you see a soapbottle 2.\n> put spraybottle 2 in/on toilet 1\nYou put the spraybottle 2 in/on the toilet 1.\n\n\nHere is the task:\nYou are in the middle of a room. 
Looking quickly around you, you see a bed 1, a desk 2, a desk 1, a drawer 6, a drawer 5, a drawer 4, a drawer 3, a drawer 2, a drawer 1, a garbagecan 1, a laundryhamper 1, a safe 1, a shelf 6, a shelf 5, a shelf 4, a shelf 3, a shelf 2, and a shelf 1.\nYour task is to: put a mug in desk.\n>", 45 | }, 46 | { 47 | "role": "assistant", 48 | "content": "think: To solve the task, I need to find and take a mug, then put it in/on desk.", 49 | }, 50 | {"role": "user", "content": "OK."}, 51 | { 52 | "role": "assistant", 53 | "content": "think: First I need to find a mug. A mug is more likely to appear in desk (1-2), drawer (1-6), shelf (1-6), garbagecan (1), laundryhamper (1), safe (1). I can check one by one, starting with desk 1.", 54 | }, 55 | {"role": "user", "content": "OK."}, 56 | {"role": "assistant", "content": "go to desk 1"}, 57 | { 58 | "role": "user", 59 | "content": "On the desk 1, you see a bowl 1, a cd 2, a creditcard 2, a desklamp 1, a keychain 1, a laptop 2, a pen 2, and a pencil 2.", 60 | }, 61 | { 62 | "role": "assistant", 63 | "content": "think: Now I don't find a mug in desk 1. Next, I need to check desk 2.", 64 | }, 65 | {"role": "user", "content": "OK."}, 66 | {"role": "assistant", "content": "go to desk 2"}, 67 | { 68 | "role": "user", 69 | "content": "On the desk 2, you see a alarmclock 1, a book 3, a bowl 2, a creditcard 3, and a keychain 3.", 70 | }, 71 | ] 72 | 73 | server = ModelServer() 74 | # `8` here is the completion model size configured in the `client_configs` 75 | completion = server.get_completion_or_embedding("8", message) 76 | # get the completion from an 8B instruct model, i.e., Llama 3.1 8B 77 | print(completion) 78 | 79 | # `7` here is the embedding model size configured in the `client_configs` 80 | embedding = server.get_completion_or_embedding( 81 | "7", 82 | message="As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.", 83 | get_embedding=True, 84 | ) 85 | # get the embedding from a 7B embedding model, i.e., `Alibaba-NLP/gte-Qwen1.5-7B-instruct` 86 | print(embedding[:10]) 87 | ``` 88 |
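For intuition, the sketch below condenses how CoPS combines these two primitives in `alfworld/alfworld_trial.py`: the current failed trace is embedded, each stored successful trajectory is scored by inverse L2 distance to it, the top-k scores are sharpened exponentially (the repo uses a factor of 5.0), and in-context demonstrations are sampled in proportion to the resulting weights. The function name, signature, and the epsilon guard are illustrative additions; the repo inlines this logic in `run_trial`.

```python
import numpy as np

def select_demonstrations(fail_vec, traj_embeddings, trajectories, k, sharpness=5.0):
    """Sample k in-context demonstrations for the current (failed) task.

    fail_vec: embedding of the failed trace (or the task description).
    traj_embeddings: (N, d) array of embeddings of stored successful trajectories.
    trajectories: the N stored trajectory strings.
    """
    k = min(k, len(trajectories))
    # Closer stored experiences get larger scores (inverse L2 distance);
    # the small epsilon guards against a zero distance.
    scores = 1.0 / (np.linalg.norm(traj_embeddings - np.asarray(fail_vec), axis=1) + 1e-8)
    top = np.argsort(-scores)[:k]  # keep only the k nearest trajectories
    # Exponential sharpening, mirroring exp(5.0 * score) in alfworld_trial.py.
    weights = np.exp(sharpness * scores[top])
    weights /= weights.sum()
    # Sample k demonstrations (with replacement) in proportion to their weights.
    chosen = np.random.choice(top, size=k, p=weights)
    return [trajectories[i] for i in chosen]
```

The selected trajectories are appended after the two fixed ReAct demonstrations in the planning prompt, before the "Examples finished." marker.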
90 | 91 | ## Other Notes 92 | 93 | For all questions, contact [cya2@iu.edu](mailto:cya2@iu.edu). 94 | 95 | ## Cite 96 | 97 | ```bibtex 98 | @misc{yang2024copsempoweringllmagents, 99 | title={CoPS: Empowering LLM Agents with Provable Cross-Task Experience Sharing}, 100 | author={Chen Yang and Chenyang Zhao and Quanquan Gu and Dongruo Zhou}, 101 | year={2024}, 102 | eprint={2410.16670}, 103 | archivePrefix={arXiv}, 104 | primaryClass={cs.LG}, 105 | url={https://arxiv.org/abs/2410.16670}, 106 | } 107 | ``` 108 | -------------------------------------------------------------------------------- /alfworld/README: -------------------------------------------------------------------------------- 1 | # Alfworld experiment code 2 | 3 | Code modified from [Reflexion](https://github.com/noahshinn/reflexion). 4 | 5 | Set up your local LLM behind an OpenAI-API-compatible endpoint, then configure it in utils.py. We recommend [SGLang](https://github.com/sgl-project/sglang). 6 | 7 | Example command: 8 | 9 | python main.py --num_trials 10 --num_envs 134 --plan_model_size 8 \ 10 | --reflect_model_size 8 --cluster_size 134 --mem_size 0 \ 11 | --run_name 'alfworld_new' \ 12 | --mem_selection_method 'fifo' \ 13 | --in_context_trajactory_size 5 \ 14 | --online_embedding_model_size 7 \ 15 | --trajactory_search_method 'knn' \ 16 | --use_success_trajectory -------------------------------------------------------------------------------- /alfworld/alfworld_trial.py: -------------------------------------------------------------------------------- 1 | """Adapted from https://github.com/ysymyth/ReAct/blob/master/alfworld.ipynb""" 2 | 3 | import os 4 | 5 | os.environ["PYTHONUTF8"] = "1" 6 | import sys 7 | import json 8 | import yaml 9 | import importlib 10 | import alfworld 11 | import re 12 | import numpy as np 13 | import faiss 14 | import random 15 | import math 16 | import alfworld.agents.environment 17 | from utils import EnvironmentHistory, ModelServer 18 | from IPython import embed 19 | from typing import List, Dict, Any, Tuple 20 | 21 | 22 | EMBEDDING_DIM = {"7": 3584, "2": 1536} 23 | 24 | FOLDER = "./prompts" 25 | PROMPT_FILE = "alfworld_3prompts.json" 26 | 27 | with open(os.path.join(FOLDER, PROMPT_FILE), "r", encoding="utf-8") as f: 28 | d = json.load(f) 29 | 30 | 31 | PREFIXES = { 32 | "pick_and_place": "put", 33 | "pick_clean_then_place": "clean", 34 | "pick_heat_then_place": "heat", 35 | "pick_cool_then_place": "cool", 36 | "look_at_obj": "examine", 37 | "pick_two_obj": "puttwo", 38 | } 39 | 40 | 41 | def process_ob(ob): 42 | if ob.startswith("You arrive at loc "): 43 | ob = ob[ob.find(". 
") + 2 :] 44 | return ob 45 | 46 | 47 | def alfworld_run( 48 | model_server: ModelServer, 49 | plan_model_size: str, 50 | env, 51 | base_prompt, 52 | memory: List[str], 53 | to_print=True, 54 | ob="", 55 | mem_size: int = 0, 56 | mem_selection_method: str = "fifo", 57 | config_path: str = None, 58 | ) -> Tuple[EnvironmentHistory, bool]: 59 | 60 | assert mem_selection_method in [ 61 | "fifo", 62 | "fix", 63 | "random", 64 | ], "Invalid memory selection method" 65 | 66 | if mem_size == 0: 67 | env_history = EnvironmentHistory(base_prompt, ob, [], []) 68 | elif mem_size != -1 and len(memory) >= mem_size: 69 | if mem_selection_method == "fifo": 70 | env_history = EnvironmentHistory(base_prompt, ob, memory[-mem_size:], []) 71 | elif mem_selection_method == "fix": 72 | env_history = EnvironmentHistory(base_prompt, ob, memory[:mem_size], []) 73 | elif mem_selection_method == "random": 74 | env_history = EnvironmentHistory( 75 | base_prompt, ob, random.sample(memory, mem_size), [] 76 | ) 77 | else: 78 | env_history = EnvironmentHistory(base_prompt, ob, memory, []) 79 | 80 | env_history.reset() 81 | if to_print: 82 | print(ob) 83 | sys.stdout.flush() 84 | cur_step = 0 85 | preva = "" 86 | message = [{"role": "user", "content": str(env_history)}] 87 | while cur_step < 49: 88 | tmpr = 0.0 89 | action = preva 90 | while action == preva: 91 | print(f"cur_step: {cur_step}") 92 | print(f"temperature: {tmpr}") 93 | completion = model_server.get_completion_or_embedding( 94 | plan_model_size, 95 | message=message, 96 | temperature=tmpr, 97 | ) 98 | lines = completion.split("\n") 99 | action = "" 100 | for line in lines: 101 | stripped_line = line.strip() 102 | if stripped_line != "": 103 | action = stripped_line 104 | break 105 | 106 | if ">" in action: 107 | action = action.replace(">", "").strip() 108 | action_words = action.split(" ") 109 | if "put" in action_words: 110 | for i in range(len(action_words)): 111 | if ( 112 | action_words[i].strip().lower() == "in" 113 | or action_words[i].strip().lower() == "on" 114 | ): 115 | action_words[i] = "in/on" 116 | action = " ".join(action_words) 117 | tmpr += 0.1 118 | env_history.add("action", action) 119 | message.append({"role": "assistant", "content": action}) 120 | preva = action 121 | observation, _, done, info = env.step([action]) 122 | observation, _, done = process_ob(observation[0]), info["won"][0], done[0] 123 | if action.startswith("think:"): 124 | observation = "OK." 
125 | env_history.add("observation", observation) 126 | message.append({"role": "user", "content": observation}) 127 | if to_print: 128 | print(f"> {action}\n{observation}") 129 | sys.stdout.flush() 130 | if done: 131 | return env_history, True 132 | elif env_history.check_is_exhausted(): 133 | return env_history, False 134 | cur_step += 1 135 | return env_history, False 136 | 137 | 138 | def adjust_trial_number(trial_log_path: str, is_fail: bool = False) -> str: 139 | pattern = r"trial_(\d+)\.log" 140 | 141 | def replace(match): 142 | return "trial_fail.json" if is_fail else "trial_inf.json" 143 | 144 | result = re.sub(pattern, replace, trial_log_path) 145 | return result 146 | 147 | 148 | def adjust_trial_number2(trial_log_path: str) -> str: 149 | pattern = r"trial_(\d+)\.log" 150 | 151 | def replace(match): 152 | return "trial_cache.json" 153 | 154 | result = re.sub(pattern, replace, trial_log_path) 155 | return result 156 | 157 | 158 | def get_offline_embedding(des): 159 | with open("embedding.json", "r", encoding="utf-8") as file: 160 | embdata = json.load(file) 161 | for a, b in embdata: 162 | if des == a: 163 | return b 164 | else: 165 | return [] 166 | 167 | 168 | def enumerate_splits(string): 169 | lines = [line for line in string.split("\n") if line] 170 | result = [] 171 | 172 | for i in range(len(lines) - 1): 173 | first_part = "\n".join(lines[: i + 1]) 174 | second_part = "\n".join(lines[i + 1 :]) 175 | result.append((first_part, second_part)) 176 | 177 | return result 178 | 179 | 180 | def replace_lines(text, prefix, suffix): 181 | replaced_text = re.sub( 182 | r"^>(.*)$", lambda m: f"{prefix}{m.group(1)}{suffix}", text, flags=re.MULTILINE 183 | ) 184 | return replaced_text 185 | 186 | 187 | def print_first_n_lines(text, n=5): 188 | lines = text.split('\n') 189 | ret="" 190 | for line in lines[:n]: 191 | ret+=line+"\n" 192 | return ret 193 | 194 | def run_trial( 195 | cluster_size, 196 | cluster_counter, 197 | model_server: ModelServer, 198 | plan_model_size: str, 199 | trial_log_path: str, 200 | world_log_path: str, 201 | trial_idx: int, 202 | env_configs: List[Dict[str, Any]], 203 | mem_size: int = 0, 204 | specific_cluster_name: str = None, 205 | mem_selection_method: str = "fifo", 206 | use_success_trajectory: bool = False, 207 | trajactory_search_method: str = "knn", 208 | in_context_trajactory_size: int = 3, 209 | online_embedding_model_size: str = "2", 210 | ) -> List[Dict[str, Any]]: 211 | importlib.reload(alfworld) 212 | importlib.reload(alfworld.agents.environment) 213 | success_in_each_cluster = { 214 | "clean": 0, 215 | "put": 0, 216 | "cool": 0, 217 | "puttwo": 0, 218 | "examine": 0, 219 | "heat": 0, 220 | } 221 | with open("base_config.yaml") as reader: 222 | config = yaml.safe_load(reader) 223 | split = "eval_out_of_distribution" 224 | 225 | assert ( 226 | not use_success_trajectory 227 | and trajactory_search_method is None 228 | and in_context_trajactory_size == 0 229 | ) or ( 230 | use_success_trajectory 231 | and trajactory_search_method in ["knn", "random"] 232 | and in_context_trajactory_size > 0 233 | ), "If use_success_trajectory is False, then in_context_trajactory_size should be 0." 
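# NOTE: the assert above couples the three experience-sharing knobs: either
# success-trajectory reuse is fully disabled (no search method, in-context size 0),
# or it is enabled with a valid search method ("knn" or "random") and a positive
# in-context size. main.py enforces the same invariant on the CLI arguments.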
234 | 235 | env = getattr(alfworld.agents.environment, config["env"]["type"])( 236 | config, train_eval=split 237 | ) 238 | env = env.init_env(batch_size=1) 239 | 240 | num_successes: int = 0 241 | num_success_increase: int = 0 242 | actcnt = 0 243 | last_trial_log_path: str = adjust_trial_number(trial_log_path) 244 | fail_db_path = adjust_trial_number(trial_log_path, is_fail=True) 245 | emb_cache_path = adjust_trial_number2(trial_log_path) 246 | success_data = { 247 | "clean": [], 248 | "put": [], 249 | "cool": [], 250 | "puttwo": [], 251 | "examine": [], 252 | "heat": [], 253 | } 254 | 255 | increase_success = { 256 | "clean": [], 257 | "put": [], 258 | "cool": [], 259 | "puttwo": [], 260 | "examine": [], 261 | "heat": [], 262 | } 263 | 264 | if not os.path.exists(last_trial_log_path): 265 | with open(last_trial_log_path, "w", encoding="utf-8") as file: 266 | print("build success log") 267 | json.dump( 268 | { 269 | "clean": [], 270 | "put": [], 271 | "cool": [], 272 | "puttwo": [], 273 | "examine": [], 274 | "heat": [], 275 | }, 276 | file, 277 | indent=4, 278 | ) 279 | 280 | with open(last_trial_log_path, "r", encoding="utf-8") as file: 281 | success_data = json.load(file) 282 | 283 | fail_data = {} 284 | emb_cache = {} 285 | 286 | if not os.path.exists(emb_cache_path): 287 | with open(emb_cache_path, "w", encoding="utf-8") as file: 288 | print("build embed cache") 289 | json.dump({}, file) 290 | 291 | with open(emb_cache_path, "r", encoding="utf-8") as file: 292 | emb_cache = json.load(file) 293 | 294 | if not os.path.exists(fail_db_path): 295 | with open(fail_db_path, "w", encoding="utf-8") as file: 296 | print("build fail log") 297 | json.dump({}, file) 298 | 299 | with open(fail_db_path, "r", encoding="utf-8") as file: 300 | fail_data = json.load(file) 301 | 302 | trajectories = [] 303 | embedding_array = np.zeros((0, EMBEDDING_DIM[online_embedding_model_size])) 304 | huge_trajectories = [] 305 | huge_ary = np.zeros((0, EMBEDDING_DIM[online_embedding_model_size])) 306 | if trajactory_search_method == "knn": 307 | for key in success_data: 308 | for des, trj in success_data[key]: 309 | # vec = get_offline_embedding(des) 310 | trj_cut = (trj + "\n").split("Here is the task:")[-1].strip() 311 | if trj_cut in emb_cache: 312 | vec = emb_cache[trj_cut] 313 | else: 314 | vec = model_server.get_completion_or_embedding( 315 | online_embedding_model_size, 316 | message=trj_cut, 317 | get_embedding=True, 318 | ) 319 | emb_cache[trj_cut] = vec 320 | trajectories.append(trj) 321 | embedding_array = np.vstack((embedding_array, np.array(vec))) 322 | elif trajactory_search_method == "random": 323 | for key in success_data: 324 | for des, trj in success_data[key]: 325 | trajectories.append(trj) 326 | 327 | emb_db = faiss.IndexFlatL2(EMBEDDING_DIM[online_embedding_model_size]) 328 | emb_db.add(embedding_array.astype("float32")) 329 | huge_db = faiss.IndexFlatL2(EMBEDDING_DIM[online_embedding_model_size]) 330 | huge_db.add(huge_ary.astype("float32")) 331 | final_fail_db = {} 332 | 333 | for z, env_config in enumerate(env_configs): 334 | 335 | ob, info = env.reset() 336 | ob = "\n".join(ob[0].split("\n\n")[1:]) 337 | env_description = ob.strip() 338 | env_vec = ( 339 | np.array(get_offline_embedding(env_description)) 340 | .reshape(1, -1) 341 | .astype("float32") 342 | ) 343 | name = "/".join(info["extra.gamefile"][0].split("/")[-3:-1]) 344 | 345 | print(f"using {name}") 346 | 347 | if env_config["is_success"]: 348 | for i, (k, v) in enumerate(PREFIXES.items()): 349 | if name.startswith(k): 350 | 
success_in_each_cluster[v] += 1 351 | num_successes += 1 352 | actcnt += 1 353 | with open(world_log_path, "a", encoding="utf-8") as wf: 354 | wf.write( 355 | f""" 356 | 357 | ============================================================== 358 | Environment #{z} 359 | Trial #{trial_idx} 360 | Game: {name} 361 | SUCCESS 362 | ============================================================== 363 | 364 | """ 365 | ) 366 | with open(trial_log_path, "a", encoding="utf-8") as wf: 367 | wf.write( 368 | f""" 369 | 370 | ============================================================== 371 | Environment #{z} 372 | Trial #{trial_idx} 373 | Game: {name} 374 | SUCCESS 375 | ============================================================== 376 | 377 | """ 378 | ) 379 | continue 380 | 381 | for i, (k, v) in enumerate(PREFIXES.items()): 382 | if name.startswith(k): 383 | env_configs[z]["cluster"] = v 384 | if v in cluster_counter: 385 | if len(cluster_counter[v]) < cluster_size and ( 386 | specific_cluster_name is None or v == specific_cluster_name 387 | ): 388 | cluster_counter[v].append(name) 389 | else: 390 | if specific_cluster_name is None or v == specific_cluster_name: 391 | cluster_counter[v] = [name] 392 | else: 393 | cluster_counter[v] = [] 394 | base_prompt = ( 395 | "Interact with a household to solve a task. Here are a few examples.\n" 396 | + d[f"react_{v}_1"] 397 | + d[f"react_{v}_0"] 398 | + "\n" 399 | ) 400 | if ( 401 | trajectories 402 | and use_success_trajectory 403 | and in_context_trajactory_size > 0 404 | ): 405 | assert trajactory_search_method in ["knn", "random"] 406 | if trajactory_search_method == "knn": 407 | if name in fail_data: 408 | fail_key=fail_data[name].strip() 409 | else: 410 | fail_key=env_description.strip() 411 | fail_vec = model_server.get_completion_or_embedding( 412 | online_embedding_model_size, 413 | message=fail_key, 414 | get_embedding=True, 415 | ) 416 | indices = [[]] 417 | dist=[] 418 | for index, row in np.ndenumerate(embedding_array):  # elementwise scan; index[1] == 0 selects each row exactly once 419 | if index[1] == 0: 420 | realN = (1.0 / np.linalg.norm(np.array(fail_vec) - embedding_array[index[0]])) 421 | dist.append((realN,index[0])) 422 | dist.sort(key=lambda x: x[0], reverse=True) 423 | sz_now=min(len(trajectories),in_context_trajactory_size) 424 | dist=dist[:sz_now] 425 | dist=[(math.exp(5.0*x),y) for (x,y) in dist] 426 | original_sum=sum([x for (x,_) in dist]) 427 | dist=[(x *float(sz_now) / original_sum,y) for (x,y) in dist] 428 | tot=0 429 | cntD=[] 430 | realD=[] 431 | for (x,y) in dist: 432 | cntD.append(math.floor(x)) 433 | realD.append(x) 434 | tot+=math.floor(x) 435 | while tot < sz_now:  # top up the floored counts, giving leftover slots to the largest fractional remainders 436 | maxx = -1.0 437 | maxi = 0 438 | for ig in range(sz_now): 439 | if realD[ig] - float(cntD[ig]) > maxx: 440 | maxx=realD[ig]-float(cntD[ig]) 441 | maxi=ig 442 | cntD[maxi]+=1 443 | tot+=1 444 | weights=[] 445 | for (x,y) in dist: 446 | weights.append(x) 447 | normalized_weights = [w / sum(weights) for w in weights] 448 | for ig in range(sz_now): 449 | s_ind = np.random.choice(len(weights), p=normalized_weights) 450 | (_,y)=dist[s_ind] 451 | indices[0].append(y) 452 | elif trajactory_search_method == "random": 453 | indices[0] = random.sample( 454 | range(len(trajectories)), min(len(trajectories),in_context_trajactory_size) 455 | ) 456 | for i in indices[0]: 457 | base_prompt += ( 458 | (trajectories[i] + "\n") 459 | .split("Here is the task:")[-1] 460 | .strip() 461 | ) + "\n" 462 | base_prompt = replace_lines( 463 | base_prompt, 464 | "<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n", 465 | "<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\n", 466 | ) 467 | base_prompt += "\n\nExamples finished.\n\n" 468 | if
env_config["principle"]: 469 | base_prompt += "Here are the core principles you should follow as much as possible in your planning:\n[Principles start]\n" 470 | for pi in env_config["principle"]: 471 | base_prompt += pi + "\n" 472 | base_prompt += "\n[Principles end]\n\n" 473 | inclus = False 474 | final_env_history = "" 475 | is_success = False 476 | if name in cluster_counter[v]: 477 | actcnt += 1 478 | inclus = True 479 | final_env_history, is_success = alfworld_run( 480 | model_server, 481 | plan_model_size, 482 | env, 483 | base_prompt, 484 | env_config["memory"], 485 | to_print=True, 486 | ob=ob, 487 | mem_size=mem_size, 488 | mem_selection_method=mem_selection_method, 489 | ) 490 | else: 491 | env_configs[z]["skip"] = True 492 | 493 | # update env config 494 | if is_success: 495 | status_str: str = f""" 496 | 497 | ============================================================== 498 | Environment #{z} 499 | Trial #{trial_idx} 500 | Game: {name} 501 | SUCCESS 502 | ============================================================== 503 | 504 | """ 505 | env_configs[z]["is_success"] = True 506 | num_successes += 1 507 | num_success_increase += 1 508 | success_in_each_cluster[v] += 1 509 | elif inclus: 510 | status_str: str = f""" 511 | 512 | ============================================================== 513 | Environment #{z} 514 | Trial #{trial_idx} 515 | Game: {name} 516 | FAIL 517 | ============================================================== 518 | 519 | """ 520 | else: 521 | status_str: str = f""" 522 | 523 | ============================================================== 524 | Environment #{z} 525 | Trial #{trial_idx} 526 | SKIP 527 | ============================================================== 528 | 529 | """ 530 | # log to world log 531 | assert world_log_path.endswith(".log") 532 | if not os.path.exists(world_log_path): 533 | os.system(f"touch {world_log_path}") 534 | with open(world_log_path, "a", encoding="utf-8") as f: 535 | f.write(status_str + "\n") 536 | 537 | if is_success: 538 | increase_success[v].append( 539 | (env_description, str(final_env_history)) 540 | ) 541 | trj_cut = (str(final_env_history) + "\n").split("Here is the task:")[-1].strip() 542 | vec = model_server.get_completion_or_embedding( 543 | online_embedding_model_size, 544 | message=trj_cut, 545 | get_embedding=True, 546 | ) 547 | emb_cache[trj_cut] = vec 548 | trajectories.append(str(final_env_history)) 549 | embedding_array = np.vstack((embedding_array, np.array(vec))) 550 | else: 551 | final_fail_db[name] = ( 552 | (str(final_env_history) + "\n") 553 | .split("Here is the task:")[-1] 554 | .strip() 555 | ) 556 | 557 | with open(trial_log_path, "a", encoding="utf-8") as wf: 558 | wf.write( 559 | f""" 560 | 561 | ============================================================== 562 | Environment #{z} 563 | {str(final_env_history)} 564 | STATUS: {"OK" if is_success else "FAIL"} 565 | ============================================================== 566 | 567 | """ 568 | ) 569 | env.close() 570 | 571 | log_str: str = f""" 572 | 573 | ============================================================== 574 | SUCCESS: {num_successes} 575 | INCREASE SUCCESS: {num_success_increase} 576 | FAIL: {actcnt - num_successes} 577 | TOTAL: {actcnt} 578 | ACCURACY: {round(num_successes / actcnt, 2)} 579 | CLUSTER SUCCEESS: {str(success_in_each_cluster)} 580 | ============================================================== 581 | 582 | """ 583 | 584 | with open(last_trial_log_path, "w", encoding="utf-8") as file: 585 | json.dump( 586 | { 
587 | key: success_data.get(key, []) + increase_success.get(key, []) 588 | for key in success_data.keys() | increase_success.keys() 589 | }, 590 | file, 591 | indent=4, 592 | ) 593 | with open(fail_db_path, "w", encoding="utf-8") as file: 594 | json.dump( 595 | final_fail_db, 596 | file, 597 | indent=4, 598 | ) 599 | 600 | with open(emb_cache_path, "w", encoding="utf-8") as file: 601 | json.dump( 602 | emb_cache, 603 | file, 604 | indent=4, 605 | ) 606 | 607 | with open(trial_log_path, "a", encoding="utf-8") as wf: 608 | wf.write(log_str) 609 | with open(world_log_path, "a", encoding="utf-8") as wf: 610 | wf.write(log_str + "\n") 611 | return env_configs, cluster_counter 612 | -------------------------------------------------------------------------------- /alfworld/base_config.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | data_path: '$ALFWORLD_DATA/json_2.1.1/train' 3 | eval_id_data_path: '$ALFWORLD_DATA/json_2.1.1/valid_seen' # null/None to disable 4 | eval_ood_data_path: '$ALFWORLD_DATA/json_2.1.1/valid_unseen' # null/None to disable 5 | num_train_games: -1 # max training games (<=0 indicates full dataset) 6 | num_eval_games: -1 # max evaluation games (<=0 indicates full dataset) 7 | 8 | logic: 9 | domain: '$ALFWORLD_DATA/logic/alfred.pddl' # PDDL domain file that defines the world dynamics 10 | grammar: '$ALFWORLD_DATA/logic/alfred.twl2' # Grammar file that defines the text feedbacks 11 | 12 | env: 13 | type: 'AlfredTWEnv' # 'AlfredTWEnv' or 'AlfredThorEnv' or 'AlfredHybrid' 14 | regen_game_files: False # check if game is solvable by expert and save to game.tw-pddl file 15 | domain_randomization: False # shuffle Textworld print order and object id nums 16 | task_types: [1, 2, 3, 4, 5, 6] # task-type ids: 1 - Pick & Place, 2 - Examine in Light, 3 - Clean & Place, 4 - Heat & Place, 5 - Cool & Place, 6 - Pick Two & Place 17 | expert_timeout_steps: 150 # max steps before timeout for expert to solve the task 18 | expert_type: "handcoded" # 'handcoded' or 'downward'. 
Note: the downward planner is very slow for real-time use 19 | goal_desc_human_anns_prob: 0.0 # prob of using human-annotated goal language instead of templated goals (1.0 indicates all human annotations from ALFRED) 20 | 21 | hybrid: 22 | start_eps: 100000 # starting episode of hybrid training, tw-only training up to this point 23 | thor_prob: 0.5 # prob of AlfredThorEnv during hybrid training 24 | eval_mode: "tw" # 'tw' or 'thor' - env used for evaluation during hybrid training 25 | 26 | thor: 27 | screen_width: 300 # width of THOR window 28 | screen_height: 300 # height of THOR window 29 | smooth_nav: False # smooth rotations, looks, and translations during navigation (very slow) 30 | save_frames_to_disk: False # save frame PNGs to disk (useful for making videos) 31 | save_frames_path: './videos/' # path to save frame PNGs 32 | 33 | controller: 34 | type: 'oracle' # 'oracle' or 'oracle_astar' or 'mrcnn' or 'mrcnn_astar' (aka BUTLER) 35 | debug: False 36 | load_receps: True # load receptacle locations from precomputed dict (if available) 37 | 38 | mask_rcnn: 39 | pretrained_model_path: '$ALFWORLD_DATA/detectors/mrcnn.pth' 40 | 41 | general: 42 | random_seed: 42 43 | use_cuda: True # disable this when running on a machine without cuda 44 | visdom: False # plot training/eval curves, run with visdom server 45 | task: 'alfred' 46 | training_method: 'dagger' # 'dqn' or 'dagger' 47 | save_path: './training/' # path to save pytorch models 48 | observation_pool_capacity: 3 # k-size queue, 0 indicates no observation 49 | hide_init_receptacles: False # remove initial observation containing navigable receptacles 50 | 51 | training: 52 | batch_size: 10 53 | max_episode: 50000 54 | smoothing_eps: 0.1 55 | optimizer: 56 | learning_rate: 0.001 57 | clip_grad_norm: 5 58 | 59 | evaluate: 60 | run_eval: True 61 | batch_size: 10 62 | env: 63 | type: "AlfredTWEnv" 64 | 65 | checkpoint: 66 | report_frequency: 1000 # report every N episode 67 | experiment_tag: 'test' # name of experiment 68 | load_pretrained: False # during test, enable this so that the agent loads your pretrained model 69 | load_from_tag: 'not loading anything' # name of pre-trained model to load in save_path 70 | 71 | model: 72 | encoder_layers: 1 73 | decoder_layers: 1 74 | encoder_conv_num: 5 75 | block_hidden_dim: 64 76 | n_heads: 1 77 | dropout: 0.1 78 | block_dropout: 0.1 79 | recurrent: True 80 | 81 | rl: 82 | action_space: "admissible" # 'admissible' (candidates from text engine) or 'generation' (seq2seq-style generation) or 'beam_search_choice' or 'exhaustive' (not working) 83 | max_target_length: 20 # max token length for seq2seq generation 84 | beam_width: 10 # 1 means greedy 85 | generate_top_k: 3 86 | 87 | training: 88 | max_nb_steps_per_episode: 50 # terminate after this many steps 89 | learn_start_from_this_episode: 0 # delay updates until this episode 90 | target_net_update_frequency: 500 # sync target net with online net per this many epochs 91 | 92 | replay: 93 | accumulate_reward_from_final: True 94 | count_reward_lambda: 0.0 # 0 to disable 95 | novel_object_reward_lambda: 0.0 # 0 to disable 96 | discount_gamma_game_reward: 0.9 97 | discount_gamma_count_reward: 0.5 98 | discount_gamma_novel_object_reward: 0.5 99 | replay_memory_capacity: 500000 # adjust this depending on your RAM size 100 | replay_memory_priority_fraction: 0.5 101 | update_per_k_game_steps: 5 102 | replay_batch_size: 64 103 | multi_step: 3 104 | replay_sample_history_length: 4 105 | replay_sample_update_from: 2 106 | 107 | epsilon_greedy: 108 | noisy_net:
False # if this is true, then epsilon greedy is disabled 109 | epsilon_anneal_episodes: 1000 # -1 if not annealing 110 | epsilon_anneal_from: 0.3 111 | epsilon_anneal_to: 0.1 112 | 113 | dagger: 114 | action_space: "generation" # 'admissible' (candidates from text engine) or 'generation' (seq2seq-style generation) or 'exhaustive' (not working) 115 | max_target_length: 20 # max token length for seq2seq generation 116 | beam_width: 10 # 1 means greedy 117 | generate_top_k: 5 118 | unstick_by_beam_search: False # use beam-search for failed actions, set True during evaluation 119 | 120 | training: 121 | max_nb_steps_per_episode: 50 # terminate after this many steps 122 | 123 | fraction_assist: 124 | fraction_assist_anneal_episodes: 50000 125 | fraction_assist_anneal_from: 1.0 126 | fraction_assist_anneal_to: 0.01 127 | 128 | fraction_random: 129 | fraction_random_anneal_episodes: 0 130 | fraction_random_anneal_from: 0.0 131 | fraction_random_anneal_to: 0.0 132 | 133 | replay: 134 | replay_memory_capacity: 500000 135 | update_per_k_game_steps: 5 136 | replay_batch_size: 64 137 | replay_sample_history_length: 4 138 | replay_sample_update_from: 2 139 | 140 | vision_dagger: 141 | model_type: "resnet" # 'resnet' (whole image features) or 'maskrcnn_whole' (whole image MaskRCNN feats) or 'maskrcnn' (top k MaskRCNN detection feats) or 'no_vision' (zero vision input) 142 | resnet_fc_dim: 64 143 | maskrcnn_top_k_boxes: 10 # top k box features 144 | use_exploration_frame_feats: False # append feats from initial exploration (memory intensive!) 145 | sequence_aggregation_method: "average" # 'sum' or 'average' or 'rnn' 146 | -------------------------------------------------------------------------------- /alfworld/generate_reflections.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["PYTHONUTF8"] = "1" 4 | import random 5 | from utils import ModelServer 6 | 7 | from typing import List, Dict, Any 8 | 9 | with open("./reflexion_few_shot_examples.txt", "r", encoding="utf-8") as f: 10 | FEW_SHOT_EXAMPLES = f.read() 11 | 12 | 13 | def _get_scenario(s: str) -> str: 14 | """Parses the relevant scenario from the experience log.""" 15 | return s.split("Here is the task:")[-1].strip() 16 | 17 | 18 | def _generate_reflection_query(log_str: str, memory: List[str]) -> str: 19 | """Allows the Agent to reflect upon a past experience.""" 20 | scenario: str = _get_scenario(log_str) 21 | query: str = f"""You will be given the history of a past experience in which you were placed in an environment and given a task to complete. You were unsuccessful in completing the task. Do not summarize your environment, but rather think about the strategy and path you took to attempt to complete the task. Devise a concise, new plan of action that accounts for your mistake with reference to specific actions that you should have taken. For example, if you tried A and B but forgot C, then devise a plan to achieve C with environment-specific actions. You will need this later when you are solving the same task. Remember that your plan must be very concise and your entire output must be limited to a single line of words. 
Here are two examples: 22 | 23 | {FEW_SHOT_EXAMPLES} 24 | 25 | Here is your failed experience: 26 | {scenario}""" 27 | 28 | query += "\n\nNew plan:" 29 | return query 30 | 31 | 32 | def update_memory( 33 | model_server: ModelServer, 34 | reflection_model_size: str, 35 | trial_log_path: str, 36 | env_configs: List[Dict[str, Any]], 37 | mem_size: int = -1, 38 | ) -> List[Dict[str, Any]]: 39 | """Updates the given env_config with the appropriate reflections.""" 40 | with open(trial_log_path, "r", encoding="utf-8") as f: 41 | full_log: str = f.read() 42 | 43 | env_logs = full_log.split( 44 | "==============================================================\n\n\n\n==============================================================" 45 | ) 46 | 47 | failed_envs = [] 48 | env_logs = env_logs[:-1] 49 | env_logs[0] = env_logs[0].strip( 50 | "\n\n==============================================================\n" 51 | ) 52 | env_logs = [each.strip() for each in env_logs] 53 | assert all( 54 | ["Environment #" in env_logs[i] for i in range(len(env_logs))] 55 | ), "Parsing error" 56 | #! Do not modify this printed separator string; it is used to split the log. 57 | #! env_logs[0] needs special handling, and the original env_logs[-1] is the success-rate summary. 58 | for i, env in enumerate(env_configs): 59 | 60 | # if unsolved, get reflection and update env config 61 | if not env["is_success"] and not env["skip"]: 62 | failed_envs.append(i) 63 | assert mem_size != 0, "Memory size should not be 0." 64 | #! Memory selection for reflection is always FIFO. 65 | if mem_size != -1 and len(env["memory"]) >= mem_size: 66 | memory: List[str] = env["memory"][-mem_size:] 67 | else: 68 | memory: List[str] = env["memory"] 69 | reflection_query: str = _generate_reflection_query( 70 | env_logs[i].strip(), memory 71 | ) 72 | print("start reflection env:") 73 | print(i) 74 | messages = [{"role": "user", "content": reflection_query}] 75 | reflection: str = model_server.get_completion_or_embedding(reflection_model_size, messages) # type: ignore 76 | env_configs[i]["memory"] += [reflection] 77 | 78 | return env_configs #! the general-reflection logic has been removed 79 | -------------------------------------------------------------------------------- /alfworld/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from IPython import embed 3 | 4 | os.environ["PYTHONUTF8"] = "1" 5 | import json 6 | import argparse 7 | from utils import ModelServer 8 | 9 | from alfworld_trial import run_trial 10 | 11 | from generate_reflections import update_memory 12 | from typing import Any, List, Dict 13 | import json 14 | 15 | 16 | def get_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--num_trials", type=int, help="The number of trials to run") 19 | parser.add_argument( 20 | "--num_envs", type=int, help="The number of environments per trial" 21 | ) 22 | parser.add_argument( 23 | "--mem_size", 24 | type=int, 25 | help="The size of the reflection memory.
-1 means unlimited mem, 0 means 0 mem.", 26 | default=0, 27 | ) 28 | parser.add_argument("--run_name", type=str, help="The name of the run") 29 | parser.add_argument("--is_resume", action="store_true", help="To resume run") 30 | parser.add_argument( 31 | "--resume_dir", type=str, help="If resume, the logging directory", default="" 32 | ) 33 | parser.add_argument( 34 | "--start_trial_num", 35 | type=int, 36 | help="The start trial num, if not resume, should be 0.", 37 | default=0, 38 | ) 39 | parser.add_argument( 40 | "--plan_model_size", type=str, help="only support 8, 70", default="70" 41 | ) 42 | parser.add_argument( 43 | "--reflect_model_size", type=str, help="only support 8, 70", default="70" 44 | ) 45 | parser.add_argument( 46 | "--online_embedding_model_size", type=str, help="only support 2, 7", default="2" 47 | ) 48 | parser.add_argument( 49 | "--cluster_size", 50 | type=int, 51 | help="How many instances to run in each cluster", 52 | default=5, 53 | ) 54 | parser.add_argument( 55 | "--specific_cluster_name", 56 | type=str, 57 | default=None, 58 | help="Whether to only run on a specific cluster of envs", 59 | ) 60 | parser.add_argument( 61 | "--mem_selection_method", 62 | type=str, 63 | default="fifo", 64 | help="The method to select memory for planning, only fifo/fix/random.", 65 | ) 66 | parser.add_argument( 67 | "--log_file_path", 68 | type=str, 69 | help="The path to the log file", 70 | default="", 71 | required=True, 72 | ) 73 | parser.add_argument( 74 | "--use_success_trajectory", 75 | action="store_true", 76 | help="Whether to use success trajectories in planning base prompt.", 77 | default=False, 78 | ) 79 | 80 | parser.add_argument( 81 | "--trajactory_search_method", 82 | type=str, 83 | default=None, 84 | help="The method to select memory for planning, only knn/random.", 85 | ) 86 | parser.add_argument( 87 | "--in_context_trajactory_size", 88 | type=int, 89 | default=3, 90 | help="The size of in context trajactory size.", 91 | ) 92 | args = parser.parse_args() 93 | 94 | assert args.log_file_path != "", "Log file path should be provided" 95 | assert args.num_trials > 0, "Number of trials should be positive" 96 | assert args.num_envs > 0, "Number of environments should be positive" 97 | assert ( 98 | args.resume_dir == args.run_name or args.resume_dir == "" 99 | ), "Should resume from previous directory" 100 | assert args.mem_selection_method in [ 101 | "fifo", 102 | "fix", 103 | "random", 104 | ], "Invalid memory selection method" 105 | assert args.specific_cluster_name in [ 106 | None, 107 | "clean", 108 | "put", 109 | "heat", 110 | "cool", 111 | "examine", 112 | "puttwo", 113 | ], "Invalid cluster name used" 114 | assert ( 115 | not args.use_success_trajectory 116 | and args.trajactory_search_method is None 117 | and args.in_context_trajactory_size == 0 118 | ) or ( 119 | args.use_success_trajectory 120 | and args.trajactory_search_method in ["knn", "random"] 121 | and args.in_context_trajactory_size > 0 122 | ), "If use_success_trajectory is False, then in_context_trajactory_size should be 0." 
123 | assert args.trajactory_search_method in [ 124 | "knn", 125 | "random", 126 | None, 127 | ], "Invalid search method" 128 | assert args.plan_model_size in ["8", "70"], "Invalid plan model size" 129 | assert args.reflect_model_size in ["8", "70"], "Invalid reflect model size" 130 | assert args.online_embedding_model_size in [ 131 | "2", 132 | "7", 133 | ], "Invalid online embedding size" 134 | return args 135 | 136 | 137 | def main(args) -> None: 138 | if args.is_resume: 139 | if not os.path.exists(args.resume_dir): 140 | raise ValueError(f"Resume directory `{args.resume_dir}` does not exist") 141 | logging_dir = args.resume_dir 142 | 143 | # load previous environment configs 144 | env_config_path: str = os.path.join( 145 | args.resume_dir, f"env_results_trial_{args.start_trial_num - 1}.json" 146 | ) 147 | if not os.path.exists(env_config_path): 148 | raise ValueError( 149 | f"Environment config file `{env_config_path}` does not exist" 150 | ) 151 | with open(env_config_path, "r", encoding="utf-8") as rf: 152 | env_configs: List[Dict[str, Any]] = json.load(rf) 153 | else: 154 | # Create new run directory 155 | if not os.path.exists(args.run_name): 156 | os.makedirs(args.run_name) 157 | logging_dir = args.run_name 158 | 159 | # initialize environment configs 160 | env_configs: List[Dict[str, Any]] = [] 161 | 162 | env_configs = [ 163 | { 164 | "name": f"env_{i}", 165 | "memory": [], 166 | "principle": [], 167 | "is_success": False, 168 | "skip": False, 169 | "cluster": "", 170 | } 171 | for i in range(args.num_envs) 172 | ] 173 | config_path = os.path.join(logging_dir, "config.json") 174 | model_server = ModelServer(config_path=config_path) 175 | world_log_path: str = os.path.join(logging_dir, "world.log") 176 | os.system(f"touch {world_log_path}") 177 | with open(world_log_path, "a", encoding="utf-8") as wf: 178 | if args.is_resume: 179 | info_string = f""" 180 | 181 | ============================================================== 182 | RESUME 183 | Run name: {args.run_name} 184 | Number of trials: {args.num_trials} 185 | Number of max environments: {args.num_envs} 186 | Number of instances per cluster: {args.cluster_size} 187 | Resume trial number: {args.start_trial_num} 188 | Sending all logs to: {args.run_name} 189 | Setting memory size to: {args.mem_size} 190 | Plan model size: {args.plan_model_size} 191 | Reflect model size: {args.reflect_model_size} 192 | Online embedding size: {args.online_embedding_model_size} 193 | Specific cluster name: {args.specific_cluster_name} 194 | Log file path: {args.log_file_path} 195 | Use success trajectory: {args.use_success_trajectory} 196 | Trajactory search method: {args.trajactory_search_method} 197 | In context trajactory size: {args.in_context_trajactory_size} 198 | ============================================================== 199 | 200 | """ 201 | wf.write(info_string) 202 | else: 203 | info_string = f""" 204 | 205 | ============================================================== 206 | START 207 | Run name: {args.run_name} 208 | Number of trials: {args.num_trials} 209 | Number of max environments: {args.num_envs} 210 | Number of instances per cluster: {args.cluster_size} 211 | Start trial number: {args.start_trial_num} 212 | Sending all logs to: {args.run_name} 213 | Setting memory size to: {args.mem_size} 214 | Plan model size: {args.plan_model_size} 215 | Reflect model size: {args.reflect_model_size} 216 | Online embedding size: {args.online_embedding_model_size} 217 | Specific cluster name: {args.specific_cluster_name} 218 | Log file 
path: {args.log_file_path} 219 | Use success trajectory: {args.use_success_trajectory} 220 | Trajactory search method: {args.trajactory_search_method} 221 | In context trajactory size: {args.in_context_trajactory_size} 222 | ============================================================== 223 | 224 | """ 225 | wf.write(info_string) 226 | 227 | with open(config_path, "w", encoding="utf-8") as wf: 228 | info_dict = vars(args) 229 | info_dict["is_running"] = True 230 | json.dump(info_dict, wf, indent=4) 231 | 232 | trial_idx = args.start_trial_num 233 | cluster_counter = {} 234 | 235 | while trial_idx < args.num_trials: 236 | with open(world_log_path, "a", encoding="utf-8") as wf: 237 | wf.write( 238 | f""" 239 | 240 | ============================================================== 241 | Start Trial #{trial_idx} 242 | ============================================================== 243 | 244 | """ 245 | ) 246 | 247 | # set paths to log files 248 | trial_log_path: str = os.path.join(args.run_name, f"trial_{trial_idx}.log") 249 | os.system(f"touch {trial_log_path}") 250 | trial_env_configs_log_path: str = os.path.join( 251 | args.run_name, f"env_results_trial_{trial_idx}.json" 252 | ) 253 | if os.path.exists(trial_log_path): 254 | open(trial_log_path, "w").close() 255 | if os.path.exists(trial_env_configs_log_path): 256 | open(trial_env_configs_log_path, "w").close() 257 | 258 | env_configs, cluster_counter = run_trial( 259 | args.cluster_size, 260 | cluster_counter, 261 | model_server, 262 | args.plan_model_size, 263 | trial_log_path, 264 | world_log_path, 265 | trial_idx, 266 | env_configs, 267 | args.mem_size, 268 | specific_cluster_name=args.specific_cluster_name, 269 | mem_selection_method=args.mem_selection_method, 270 | use_success_trajectory=args.use_success_trajectory, 271 | trajactory_search_method=args.trajactory_search_method, 272 | in_context_trajactory_size=args.in_context_trajactory_size, 273 | online_embedding_model_size=args.online_embedding_model_size, 274 | ) 275 | 276 | 277 | if args.mem_size != 0: 278 | 279 | env_configs: List[Dict[str, Any]] = update_memory( 280 | model_server, 281 | args.reflect_model_size, 282 | trial_log_path, 283 | env_configs, 284 | mem_size=args.mem_size, 285 | ) 286 | 287 | # log env configs for trial 288 | with open(trial_env_configs_log_path, "w", encoding="utf-8") as wf: 289 | json.dump(env_configs, wf, indent=4) 290 | 291 | # log world for trial 292 | with open(world_log_path, "a", encoding="utf-8") as wf: 293 | wf.write( 294 | f""" 295 | 296 | ============================================================== 297 | End Trial #{trial_idx} 298 | ============================================================== 299 | 300 | """ 301 | ) 302 | trial_idx += 1 303 | 304 | 305 | if __name__ == "__main__": 306 | args = get_args() 307 | main(args) 308 | -------------------------------------------------------------------------------- /alfworld/prompts/fever.json: -------------------------------------------------------------------------------- 1 | {"webact_simple3": "\nDetermine if there is Observation that SUPPORTS or REFUTES a Claim, or if there is NOT ENOUGH INFORMATION. \nClaim: Nikolaj Coster-Waldau worked with the Fox Broadcasting Company.\nAction 1: Search[Nikolaj Coster-Waldau]\nObservation 1: Nikolaj William Coster-Waldau (born 27 July 1970) is a Danish actor and producer. He graduated from the Danish National School of Performing Arts in Copenhagen in 1993,[1] and had his breakthrough role in Denmark with the film Nightwatch (1994). 
He played Jaime Lannister in the HBO fantasy drama series Game of Thrones, for which he received two Primetime Emmy Award nominations for Outstanding Supporting Actor in a Drama Series.. Coster-Waldau has appeared in numerous films in his native Denmark and Scandinavia, including Headhunters (2011) and A Thousand Times Good Night (2013). In the U.S, his debut film role was in the war film Black Hawk Down (2001), playing Medal of Honor recipient Gary Gordon.[2] He then played a detective in the short-lived Fox television series New Amsterdam (2008), and appeared in the 2009 Fox television film Virtuality, originally intended as a pilot.\nAction 2: Finish[SUPPORTS]\n\nClaim: Stranger Things is set in Bloomington, Indiana.\nAction 1: Search[Stranger Things]\nObservation 1: Stranger Things is an American science fiction horror drama television series created by the Duffer Brothers. Set in the 1980s, primarily in the fictional town of Hawkins, Indiana, the series centers on a number of mysteries and supernatural events occurring around the town and their impact on an ensemble of child and adult characters. \nAction 2: Finish[REFUTES]\n\nClaim: Beautiful reached number two on the Billboard Hot 100 in 2003.?\nAction 1: Search[Beautiful]\nObservation 1: Could not find [Beautiful]. Similar: ['Beautiful', 'Beautiful, Beautiful', 'A Beautiful Mind (film)', 'Beautiful (Christina Aguilera song)', 'Life Is Beautiful'].\nAction 2: Search[Beautiful (Christina Aguilera song)]\nObservation 2: \"Beautiful\" is a song recorded by American singer Christina Aguilera for her fourth studio album, Stripped (2002).\nAction 3: Lookup[Billboard Hot 100]\nObservation 3: (Result 1 / 3) The song peaked at number two on the Billboard Hot 100 in the United States, where it was certified Gold for 500,000 units shipped.\nAction 4: Finish[NOT ENOUGH INFO]\n\n", "cotqa_simple3": "Determine if there is Observation that SUPPORTS or REFUTES a Claim, or if there is NOT ENOUGH INFORMATION. \nClaim: Nikolaj Coster-Waldau worked with the Fox Broadcasting Company.\nThought: Nikolaj William Coster-Waldau appeared in the 2009 Fox television film Virtuality, so he has worked with the Fox Broadcasting Company.\nAnswer: SUPPORTS\n\nClaim: Stranger Things is set in Bloomington, Indiana.\nThought: Stranger Things is in the fictional town of Hawkins, Indiana, not in Bloomington, Indiana.\nAnswer:REFUTES\n\nClaim: Beautiful reached number two on the Billboard Hot 100 in 2003.?\nThought: The song peaked at number two on the Billboard Hot 100 in the United States, but not sure if it was in 2003.\nAnswer: NOT ENOUGH INFO\n", "webqa_simple3": "Determine if there is Observation that SUPPORTS or REFUTES a Claim, or if there is NOT ENOUGH INFORMATION. \nClaim: Nikolaj Coster-Waldau worked with the Fox Broadcasting Company.\nAnswer: SUPPORTS\n\nClaim: Stranger Things is set in Bloomington, Indiana.\nAnswer:REFUTES\n\nClaim: Beautiful reached number two on the Billboard Hot 100 in 2003.?\nAnswer: NOT ENOUGH INFO\n", "webthink_simple3": "\nDetermine if there is Observation that SUPPORTS or REFUTES a Claim, or if there is NOT ENOUGH INFORMATION. \nClaim: Nikolaj Coster-Waldau worked with the Fox Broadcasting Company.\nThought 1: I need to search Nikolaj Coster-Waldau and find if he has worked with the Fox Broadcasting Company.\nAction 1: Search[Nikolaj Coster-Waldau]\nObservation 1: Nikolaj William Coster-Waldau (born 27 July 1970) is a Danish actor and producer. 
He graduated from the Danish National School of Performing Arts in Copenhagen in 1993,[1] and had his breakthrough role in Denmark with the film Nightwatch (1994). He played Jaime Lannister in the HBO fantasy drama series Game of Thrones, for which he received two Primetime Emmy Award nominations for Outstanding Supporting Actor in a Drama Series.. Coster-Waldau has appeared in numerous films in his native Denmark and Scandinavia, including Headhunters (2011) and A Thousand Times Good Night (2013). In the U.S, his debut film role was in the war film Black Hawk Down (2001), playing Medal of Honor recipient Gary Gordon.[2] He then played a detective in the short-lived Fox television series New Amsterdam (2008), and appeared in the 2009 Fox television film Virtuality, originally intended as a pilot.\nThought 2: Because he \"appeared in the 2009 Fox television film Virtuality\", he should have worked with the Fox Broadcasting Company.\nAction 2: Finish[SUPPORTS]\n\nClaim: Stranger Things is set in Bloomington, Indiana.\nThought 1: I should search for Stranger Things, and see if it is set in Bloomington, Indiana.\nAction 1: Search[Stranger Things]\nObservation 1: Stranger Things is an American science fiction horror drama television series created by the Duffer Brothers. Set in the 1980s, primarily in the fictional town of Hawkins, Indiana, the series centers on a number of mysteries and supernatural events occurring around the town and their impact on an ensemble of child and adult characters. \nThought 2: The observation says that it is set in a \"fictional town of Hawkins, Indiana\", so it is not set in Bloomington.\nAction 2: Finish[REFUTES]\n\nClaim: Beautiful reached number two on the Billboard Hot 100 in 2003.?\nThought 1: I need to search the song Beautiful and find if it reached number two on the Billboard Hot 100 in 2003.\nAction 1: Search[Beautiful]\nObservation 1: Could not find [Beautiful]. Similar: ['Beautiful', 'Beautiful, Beautiful', 'A Beautiful Mind (film)', 'Beautiful (Christina Aguilera song)', 'Life Is Beautiful'].\nThought 2: From suggestions, I should search \"Beautiful (Christina Aguilera song)\" to find the song.\nAction 2: Search[Beautiful (Christina Aguilera song)]\nObservation 2: \"Beautiful\" is a song recorded by American singer Christina Aguilera for her fourth studio album, Stripped (2002).\nThought 3: It does not mention Billboard, so I need to look up \"Billboard Hot 100\" to find if it reached number two on it in 2003.\nAction 3: Lookup[Billboard Hot 100]\nObservation 3: (Result 1 / 3) The song peaked at number two on the Billboard Hot 100 in the United States, where it was certified Gold for 500,000 units shipped.\nThought 4: It only says the song peaked at number two on the Billboard Hot 100, but not if it was in 2003. 
I am not sure if this claim is true or not.\nAction 4: Finish[NOT ENOUGH INFO]\n\n"}
--------------------------------------------------------------------------------
/alfworld/requirements.txt:
--------------------------------------------------------------------------------
1 | faiss-cpu  # faiss is published on PyPI as faiss-cpu / faiss-gpu
2 | openai
3 | alfworld
--------------------------------------------------------------------------------
/alfworld/utils.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | 
 3 | os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
 4 | os.environ["PYTHONUTF8"] = "1"
 5 | import time
 6 | from typing import Dict, List, Optional
 7 | import openai
 8 | import random, json
 9 | from IPython import embed
10 | 
11 | import math
12 | 
13 | 
14 | class EnvironmentHistory:
15 |     def __init__(
16 |         self,
17 |         base_query: str,
18 |         start_info,
19 |         memory: List[str],
20 |         history: Optional[List[Dict[str, str]]] = None,  # None instead of a shared mutable [] default
21 |     ) -> None:
22 | 
23 |         def _get_base_query(base_query: str, start_info: str, memory: List[str]) -> str:
24 |             query = base_query
25 | 
26 |             # add memory if it exists
27 | 
28 |             query += f"\nHere is the task:\n{start_info}"
29 |             if len(memory) > 0:
30 |                 query += "\n\nBelow are your reflection memory for the task, you should apply them wisely in your planning:\n[memory start]\n"
31 |                 for i, m in enumerate(memory):
32 |                     query += f"\nReflection from Trial {i}:\n{m.strip()}"
33 |                 query += "\n[memory end]\n"
34 |             return query
35 | 
36 |         self._cur_query: str = f"{_get_base_query(base_query, start_info, memory)}"
37 |         self._history: List[Dict[str, str]] = history if history is not None else []
38 |         self._last_action: str = ""
39 |         self._is_exhausted: bool = False
40 | 
41 |     def add(self, label: str, value: str) -> None:
42 |         assert label in ["action", "observation", "human_edit"]
43 |         self._history += [
44 |             {
45 |                 "label": label,
46 |                 "value": value,
47 |             }
48 |         ]
49 |         if label == "action":
50 |             if value == self._last_action:
51 |                 self._is_exhausted = True
52 |             else:
53 |                 self._last_action = value
54 | 
55 |     def check_is_exhausted(self) -> bool:
56 |         return self._is_exhausted
57 | 
58 |     def reset(self) -> None:
59 |         self._history = []
60 | 
61 |     def __str__(self) -> str:
62 |         s: str = self._cur_query + "\n"
63 |         for i, item in enumerate(self._history):
64 |             if item["label"] == "action":
65 |                 s += f'> {item["value"]}'
66 |             elif item["label"] == "observation":
67 |                 s += item["value"]
68 |             # NOT CURRENTLY SUPPORTED
69 |             elif item["label"] == "human_edit":
70 |                 s += f'[human edit]: {item["value"]}'
71 |             if i != len(self._history) - 1:
72 |                 s += "\n"
73 |         return s
74 | 
75 | 
76 | class ModelServer:
77 | 
78 |     def get_completion_or_embedding(
79 |         self,
80 |         model_size: str,
81 |         message,
82 |         temperature: float = 0.0,
83 |         max_tokens: int = 256,
84 |         get_embedding: bool = False,
85 |     ) -> str:
86 |         assert model_size in ["70", "8", "7"]
87 | 
88 |         if not get_embedding:
89 |             assert type(message) == list, "Message should be a list."
90 |             response = client.chat.completions.create(
91 |                 model=model_name,
92 |                 messages=message,
93 |                 max_tokens=max_tokens,
94 |                 temperature=temperature,
95 |                 stop=["<|eot_id|>"],
96 |             )
97 |         else:
98 |             assert type(message) == str, "Message should be a string."
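            # `client` and `model_name` are not defined in this file as shown; they
            # are assumed to be set up from the run's config.json (ModelServer is
            # constructed with config_path=... in main.py). A minimal sketch of the
            # assumed initialization -- the config field names below are illustrative
            # guesses, not the actual schema:
            #   client = openai.OpenAI(base_url=config["endpoint"], api_key="EMPTY")
            #   model_name = config["model_names"][model_size]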
 99 |             response = client.embeddings.create(
100 |                 model=model_name,
101 |                 input=message,
102 |             )
103 |         return (
104 |             str(response.choices[0].message.content)
105 |             if not get_embedding
106 |             else response.data[0].embedding
107 |         )
108 | 
--------------------------------------------------------------------------------
/figures/illusion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uclaml/COPS/989b29882a4ee6cdca828c402268216e4c5f5d85/figures/illusion.png
--------------------------------------------------------------------------------
/hotpotqa/README:
--------------------------------------------------------------------------------
1 | # HotPotQA experiment code
2 | 
3 | Code modified from [LATS](https://github.com/lapisrocks/LanguageAgentTreeSearch)
4 | 
5 | Set up your local LLM behind an OpenAI-API-compatible endpoint, then configure it in utils.py. We recommend [SGLang](https://github.com/sgl-project/sglang).
6 | 
7 | Example command:
8 | 
9 | python run.py --model_size 8 --temperature 0.0 --task_start_index 0 --task_end_index 100 --cot_size 5 \
10 | --prompt_sample 'cot' --n_generate_sample 1 --n_evaluate_sample 1 \
11 | --iteration 10 --algorithm 'tot' --cot_method 'knn' --run_name 'hotpotqa_new'
--------------------------------------------------------------------------------
/hotpotqa/base.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | 
 3 | 
 4 | class Task:
 5 |     def __init__(self):
 6 |         pass
 7 | 
 8 |     def __len__(self) -> int:
 9 |         pass
10 | 
11 |     def get_input(self, idx: int) -> str:
12 |         pass
13 | 
14 |     def test_output(self, idx: int, output: str):
15 |         pass
--------------------------------------------------------------------------------
/hotpotqa/hotpotqa.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import re
 3 | from base import Task
 4 | from hotpot import *
 5 | from models import gpt
 6 | import logging
 7 | from transformers import GPT2Tokenizer
 8 | import random
 9 | 
10 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
11 | 
12 | 
13 | def get_token_length(text):
14 |     return len(tokenizer.encode(text))
15 | 
16 | 
17 | max_token_length = 4000
18 | 
19 | # Configuring the logging
20 | 
21 | 
22 | class HotPotQATask(Task):
23 |     """
24 |     Input (x) : a text instruction
25 |     Output (y) : a text generation
26 |     Reward (r) : # TODO
27 |     Input Example:
28 |     Output Example:
29 |     """
30 | 
31 |     def __init__(self):
32 |         """
33 |         file: a text file, each line is some sentences
34 |         """
35 |         super().__init__()
36 |         self.steps = 7
37 |         self.stops = ["\nObservation:\n", None]
38 |         self.value_cache = {}
39 | 
40 |     def __len__(self) -> int:
41 |         return len(self.data)
42 | 
43 |     def get_input(self, idx: int) -> str:
44 |         return self.data[idx]
45 | 
46 |     def test_output(self, idx: int, output: str):
47 |         output = output.split("Action:\n")[-1]
48 |         prompt = score_prompt + output
49 |         score_outputs = gpt(prompt, n=5, model_size="8")  # gpt() takes model_size, not model; upstream LATS passed model="gpt-4" here
50 |         scores = []
51 |         for score_output in score_outputs:
52 |             # print(score_output)
53 |             pattern = r".*correctness score is (\d+).*"
54 |             match = re.match(pattern, score_output, re.DOTALL)
55 |             if match:
56 |                 score = int(match.groups()[0])
57 |                 scores.append(score)
58 |             else:
59 |                 print(f"------------------score no match: {[score_output]}")
60 |         print(scores)
61 |         # print('------------')
62 |         info = {"rs": scores, "r": sum(scores) / len(scores) if scores else 0}
63 |         return info
64 | 
65 |     @staticmethod
66 | 
def standard_prompt_wrap(x: str, y: str = "") -> str: 67 | return standard_prompt.format(input=x) + y 68 | 69 | @staticmethod 70 | def generate_self_reflection(z, question): 71 | reflection_mapping = [] 72 | trajectories = "" 73 | 74 | sampled_items = random.sample(z, min(3, len(z))) 75 | failed_trajectories = "\n".join([f"{question}\n{traj}\n" for traj in z]) 76 | failed_trajectories = [ 77 | f"Question: {traj}" for traj in failed_trajectories.split("Question: ")[1:] 78 | ] 79 | 80 | for traj in failed_trajectories: 81 | trajectories += traj 82 | 83 | reflect_prompt = reflection_prompt.format(trajectory=traj) 84 | 85 | reflection = gpt(reflect_prompt) 86 | 87 | trajectories += "Reflection: " + reflection[0] + "\n" 88 | 89 | reflection_mapping.append( 90 | {"question": question, "trajectory": traj, "reflection": reflection[0]} 91 | ) 92 | 93 | return reflection_mapping 94 | 95 | @staticmethod 96 | def cot_prompt_wrap(x: str, y: str = "", reflection_mapping_list=[], knn=None): 97 | question = x 98 | input = x + y 99 | trajectories = "" 100 | # z = [] 101 | if reflection_mapping_list: 102 | for reflection_mapping in reflection_mapping_list: 103 | traj_with_reflection = ( 104 | reflection_mapping["trajectory"] 105 | + "FAILED TRAJECTORY\nReflection: " 106 | + reflection_mapping["reflection"] 107 | + "\n\n" 108 | ) 109 | trajectories += traj_with_reflection 110 | 111 | prompt = cot_prompt_feedback.format(trajectories=trajectories, input=input) 112 | if get_token_length(prompt) > max_token_length: 113 | print("Too long") 114 | trajectories = "" 115 | for reflection_mapping in reflection_mapping_list[:3]: 116 | traj_with_reflection = ( 117 | reflection_mapping["trajectory"] 118 | + "FAILED TRAJECTORY\nReflection: " 119 | + reflection_mapping["reflection"] 120 | + "\n\n" 121 | ) 122 | trajectories += traj_with_reflection 123 | prompt = cot_prompt_feedback_short.format( 124 | trajectories=trajectories, input=input 125 | ) 126 | 127 | return prompt 128 | else: 129 | prompt = cot_prompt 130 | if knn: 131 | for traj in knn: 132 | prompt += traj + "\n" 133 | prompt += "\n" + input + "\n" 134 | # if get_token_length(prompt) > max_token_length: 135 | # prompt = cot_prompt_short.format(input=input) 136 | return prompt 137 | 138 | @staticmethod 139 | def vote_prompt_wrap(x: str, ys: list) -> str: 140 | prompt = vote_prompt + "\n" + x + "\n\n" 141 | for i, y in enumerate(ys, 1): 142 | # y = y.replace('Plan:\n', '') 143 | # TODO: truncate the plan part? 
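            # Formatting contract: each candidate is rendered under a numbered
            # "Choice {i}:" header (1-indexed); vote_outputs_unwrap below recovers
            # the winner by regex-matching "best trajectory is <i>" in the voter's
            # output, so these indices must line up. Rendered shape (illustrative):
            #   Choice 1:
            #   Thought 1: ... Action 1: ...
            #   Choice 2:
            #   ...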
144 | prompt += f"Choice {i}:\n{y}\n" 145 | return prompt 146 | 147 | @staticmethod 148 | def vote_outputs_unwrap(vote_outputs: list, n_candidates: int) -> list: 149 | vote_results = [0] * n_candidates 150 | for vote_output in vote_outputs: 151 | pattern = r".*best trajectory is .*(\d+).*" 152 | match = re.match(pattern, vote_output, re.DOTALL) 153 | if match: 154 | vote = int(match.groups()[0]) - 1 155 | if vote in range(n_candidates): 156 | vote_results[vote] += 1 157 | else: 158 | print(f"vote no match: {[vote_output]}") 159 | return vote_results 160 | 161 | @staticmethod 162 | def compare_prompt_wrap(x: str, ys: list) -> str: 163 | assert len(ys) == 2, "compare prompt only supports 2 candidates" 164 | 165 | # Extract the last Action for each trajectory 166 | last_actions = [] 167 | for y in ys: 168 | # Split by line and reverse to start from the end 169 | lines = y.split("\n")[::-1] 170 | for line in lines: 171 | # Check for an Action line and get its content 172 | if "Action" in line: 173 | last_actions.append(line.split("Action")[-1].strip(": ")) 174 | break 175 | 176 | assert len(last_actions) == 2, "Expected to find 2 Actions" 177 | 178 | # Construct the prompt with the extracted Actions 179 | prompt = ( 180 | compare_prompt 181 | + f"Action 1:{last_actions[0]}\n\nAction 2:{last_actions[1]}\n" 182 | ) 183 | return prompt 184 | 185 | @staticmethod 186 | def compare_output_unwrap(compare_output: str): 187 | if "more correct trajectory is 1" in compare_output: 188 | return 0 189 | elif "more correct trajectory is 2" in compare_output: 190 | return 1 191 | elif "two trajectories are similarly correct" in compare_output: 192 | return 0.5 193 | else: 194 | print(f"-----------------compare no match: {[compare_output]}") 195 | return -1 196 | 197 | @staticmethod 198 | def value_prompt_wrap(x: str, y: str, z: list = [], reflections: list = []) -> str: 199 | question = x.split("\n")[0] 200 | # z = [] 201 | if len(z) != 0: 202 | failed_trajectories = "" 203 | 204 | # Combine the trajectories with their corresponding reflections 205 | for traj, ref in zip(z, reflections): 206 | failed_trajectories += f"{question}\n{traj}\nThis trajectory is incorrect as {ref['reflection']}\nThus the correctness score is 1\n" 207 | 208 | inp = x + y + "\nThis trajectory is " 209 | 210 | prompt = value_prompt_reasoning_feedback.format( 211 | s="", trajectories=failed_trajectories, input=inp 212 | ) 213 | 214 | if get_token_length(prompt) > max_token_length: 215 | prompt = value_prompt_reasoning_feedback_short.format( 216 | s="", trajectories=failed_trajectories, input=inp 217 | ) 218 | if len(z) != 0 and False: 219 | failed_trajectories = "\n".join( 220 | [f"{question}\n{traj}\nThus the correctness score is 1\n" for traj in z] 221 | ) 222 | inp = x + y + "\nThus the correctness score is " 223 | prompt = value_prompt_feedback.format( 224 | s="", trajectories=failed_trajectories, input=inp 225 | ) 226 | if get_token_length(prompt) > max_token_length: 227 | print("Token length exceeded", get_token_length(prompt)) 228 | failed_trajectories = "\n".join( 229 | [ 230 | f"{question}\n{traj}\nThus the correctness score is 1\n" 231 | for traj in z[:2] 232 | ] 233 | ) 234 | inp = y + "\nThus the correctness score is " 235 | prompt = value_prompt_feedback.format( 236 | s="", trajectories=failed_trajectories, input=inp 237 | ) 238 | print("New length", get_token_length(prompt)) 239 | else: 240 | inp = y + "\nThis trajectory is " 241 | # inp = y + "\nThus the correctess score is " 242 | # prompt = value_prompt.format(s="", 
input=inp) 243 | prompt = value_prompt_reasoning.format(s="", input=inp) 244 | 245 | return prompt 246 | 247 | @staticmethod 248 | def value_outputs_unwrap(evaluate_prompt: str): 249 | evaluate_prompt = evaluate_prompt[0] 250 | if "10" in evaluate_prompt: 251 | return 1.0 252 | elif "9" in evaluate_prompt: 253 | return 0.9 254 | elif "8" in evaluate_prompt: 255 | return 0.8 256 | elif "7" in evaluate_prompt: 257 | return 0.7 258 | elif "6" in evaluate_prompt: 259 | return 0.6 260 | elif "5" in evaluate_prompt: 261 | return 0.5 262 | elif "4" in evaluate_prompt: 263 | return 0.4 264 | elif "3" in evaluate_prompt: 265 | return 0.3 266 | elif "2" in evaluate_prompt: 267 | return 0.2 268 | elif "1" in evaluate_prompt: 269 | return 0.1 270 | else: 271 | return -1 272 | -------------------------------------------------------------------------------- /hotpotqa/models.py: -------------------------------------------------------------------------------- 1 | import openai 2 | from utils import ModelServer 3 | import time 4 | import re 5 | 6 | server = None 7 | 8 | 9 | def build_server(config_path): 10 | global server 11 | server = ModelServer(config_path=config_path) 12 | 13 | 14 | def split_and_keep_prefixes(s, delimiters): 15 | regex_pattern = f"({'|'.join(map(re.escape, delimiters))})" 16 | parts = re.split(regex_pattern, s) 17 | result = [parts[0]] 18 | for i in range(1, len(parts), 2): 19 | result.append(parts[i] + (parts[i + 1] if i + 1 < len(parts) else "")) 20 | return result 21 | 22 | 23 | def online_embed(traj): 24 | return server.get_completion_or_embedding( 25 | "7", 26 | message=traj, 27 | get_embedding=True, 28 | ) 29 | 30 | 31 | def gpt(prompt, model_size="8", temperature=1.0, max_tokens=100, n=1) -> list: 32 | def call_openai_api(messages, model_size, temperature, max_tokens, n): 33 | outputs = [] 34 | while n > 0: 35 | cnt = min(n, 20) 36 | n -= cnt 37 | res = server.get_completion_or_embedding( 38 | model_size, messages, temperature, max_tokens 39 | ) 40 | outputs.extend( 41 | [ 42 | re.sub(r"^Thought \d+: ", "", choice.message.content) 43 | for choice in res.choices 44 | ] 45 | ) 46 | return outputs 47 | 48 | messages = [] 49 | parts = re.split(r"(Thought \d+:|Observation \d+:|Question:)", prompt) 50 | 51 | result = [parts[0].strip()] 52 | for i in range(1, len(parts), 2): 53 | if i + 1 < len(parts): 54 | result.append(parts[i] + " " + parts[i + 1].strip()) 55 | result.pop() 56 | 57 | for msg in result: 58 | if msg.startswith("Solve"): 59 | messages.append({"role": "user", "content": msg}) 60 | if msg.startswith("Thought"): 61 | messages.append({"role": "assistant", "content": msg}) 62 | if msg.startswith("Observation"): 63 | messages.append({"role": "user", "content": msg}) 64 | if msg.startswith("Question"): 65 | messages.append({"role": "user", "content": msg}) 66 | 67 | return call_openai_api( 68 | messages, model_size=model_size, temperature=temperature, max_tokens=max_tokens, n=n 69 | ) 70 | 71 | 72 | reflection_prompt = """You are an advanced reasoning agent that can improve based on self refection. You will be given a previous reasoning trial in which you were given access to an Docstore API environment and a question to answer. You were unsuccessful in answering the question either because you guessed the wrong answer with Finish[], or you used up your set number of reasoning steps. In a few sentences, Diagnose a possible reason for failure and devise a new, concise, high level plan that aims to mitigate the same failure. Use complete sentences. 
73 | Here are some examples: 74 | Previous Trial: 75 | Question: The Rome Protocols were signed by three Prime Ministers one of which was assassinated as part of what? 76 | Thought 1: I need to search Rome Protocols, find the three Prime Ministers, then find what they were assassinated as part of. 77 | Action 1: Search[Rome Protocols] 78 | Observation 1: The Rome Protocols were a series of three international agreements concluded in Rome on 17 March 1934 between the governments of Austria, Hungary and Italy. They were signed by Italian Prime Minister Benito Mussolini, Austrian Chancellor Engelbert Dollfuss and Hungarian Prime Minister Gyula Gömbös. All of the protocols went into effect on 12 July 1934 and were registered in League of Nations Treaty Series on 12 December 1934. 79 | Thought 2: The three Prime Ministers were Benito Mussolini, Engelbert Dollfuss, and Gyula Gömbös. I need to search Benito Mussolini and find what he was assassinated as part of. 80 | Action 2: Search[Benito Mussolini] 81 | Observation 2: Benito Amilcare Andrea Mussolini (UK: MU(U)SS-ə-LEE-nee, US: MOOSS-, Italian: [beˈniːto aˈmilkare anˈdrɛːa mussoˈliːni]; 29 July 1883 – 28 April 1945) was an Italian politician and journalist who founded and led the National Fascist Party (PNF). He was Prime Minister of Italy from the March on Rome in 1922 until his deposition in 1943, as well as "Duce" of Italian fascism from the establishment of the Italian Fasces of Combat in 1919 until his summary execution in 1945 by Italian partisans. As dictator of Italy and principal founder of fascism, Mussolini inspired and supported the international spread of fascist movements during the inter-war period.Mussolini was originally a socialist politician and a journalist at the Avanti! newspaper. In 1912, he became a member of the National Directorate of the Italian Socialist Party (PSI), but he was expelled from the PSI for advocating military intervention in World War I, in opposition to the party's stance on neutrality. In 1914, Mussolini founded a new journal, Il Popolo d'Italia, and served in the Royal Italian Army during the war until he was wounded and discharged in 1917. Mussolini denounced the PSI, his views now centering on Italian nationalism instead of socialism, and later founded the fascist movement which came to oppose egalitarianism and class conflict, instead advocating "revolutionary nationalism" transcending class lines. On 31 October 1922, following the March on Rome (28–30 October), Mussolini was appointed prime minister by King Victor Emmanuel III, becoming the youngest individual to hold the office up to that time. After removing all political opposition through his secret police and outlawing labor strikes, Mussolini and his followers consolidated power through a series of laws that transformed the nation into a one-party dictatorship. Within five years, Mussolini had established dictatorial authority by both legal and illegal means and aspired to create a totalitarian state. In 1929, Mussolini signed the Lateran Treaty with the Holy See to establish Vatican City. 82 | Mussolini's foreign policy aimed to restore the ancient grandeur of the Roman Empire by expanding Italian colonial possessions and the fascist sphere of influence. In the 1920s, he ordered the Pacification of Libya, instructed the bombing of Corfu over an incident with Greece, established a protectorate over Albania, and incorporated the city of Fiume into the Italian state via agreements with Yugoslavia. 
In 1936, Ethiopia was conquered following the Second Italo-Ethiopian War and merged into Italian East Africa (AOI) with Eritrea and Somalia. In 1939, Italian forces annexed Albania. Between 1936 and 1939, Mussolini ordered the successful Italian military intervention in Spain in favor of Francisco Franco during the Spanish Civil War. Mussolini's Italy initially tried to avoid the outbreak of a second global war, sending troops at the Brenner Pass to delay Anschluss and taking part in the Stresa Front, the Lytton Report, the Treaty of Lausanne, the Four-Power Pact and the Munich Agreement. However, Italy then alienated itself from Britain and France by aligning with Germany and Japan. Germany invaded Poland on 1 September 1939, resulting in declarations of war by France and the UK and the start of World War II. 83 | On 10 June 1940, Mussolini decided to enter the war on the Axis side. Despite initial success, the subsequent Axis collapse on multiple fronts and eventual Allied invasion of Sicily made Mussolini lose the support of the population and members of the Fascist Party. As a consequence, early on 25 July 1943, the Grand Council of Fascism passed a motion of no confidence in Mussolini; later that day King Victor Emmanuel III dismissed him as head of government and had him placed in custody, appointing Pietro Badoglio to succeed him as Prime Minister. After the king agreed to an armistice with the Allies, on 12 September 1943 Mussolini was rescued from captivity in the Gran Sasso raid by German paratroopers and Waffen-SS commandos led by Major Otto-Harald Mors. Adolf Hitler, after meeting with the rescued former dictator, then put Mussolini in charge of a puppet regime in northern Italy, the Italian Social Republic (Italian: Repubblica Sociale Italiana, RSI), informally known as the Salò Republic, causing a civil war. In late April 1945, in the wake of near total defeat, Mussolini and his mistress Clara Petacci attempted to flee to Switzerland, but both were captured by Italian communist partisans and summarily executed by firing squad on 28 April 1945 near Lake Como. The bodies of Mussolini and his mistress were then taken to Milan, where they were hung upside down at a service station to publicly confirm their demise. 84 | Thought 3: Benito Mussolini was assassinated as part of the collapse of the Axis powers during World War II. So the Rome Protocols were signed by three Prime Ministers one of which was assassinated as part of World War II. 85 | Action 3: Finish[World War II] 86 | 87 | Reflection: I searched one of the prime ministers involved in the signing, then attemted to answer right away. I should have searched each of the prime ministers, then looked up 'death' on each of their pages in order to get more information before answering. 88 | 89 | Previous Trial: 90 | Question: Kam Heskin plays Paige Morgan in a 2004 film directed by who? 91 | Thought 1: I need to search Kam Heskin and Paige Morgan, find the 2004 film they are in, then find the director of the film. 92 | Action 1: Search[Kam Heskin] 93 | Observation 1: Kam Erika Heskin (born May 8, 1973) is an American actress. She began her career playing Caitlin Deschanel on the NBC daytime soap opera Sunset Beach (1998–1999), before appearing in films Planet of the Apes (2001) and Catch Me If You Can (2002). Heskin went on in 2003 to play Elizabeth Bennet in an independent film Pride & Prejudice: A Latter-Day Comedy, Robin in Charmed (2004), and Paige Morgan in The Prince and Me film franchise (2006–2010). 
94 | Thought 2: Kam Heskin plays Paige Morgan in The Prince and Me film franchise. I need to search The Prince and Me film franchise and find the director of the 2004 film. 95 | Action 2: Search[The Prince and Me film franchise] 96 | Observation 2: Could not find [The Prince and Me film franchise]. Similar: ['Kam Heskin', 'Prince of Persia', 'Harry Potter and the Half-Blood Prince (film)', 'Prince of Persia: The Sands of Time (film)', 'The Chronicles of Narnia (film series)', 'Majestic Prince (franchise)', 'Kingsman (franchise)', 'The Little Prince (2015 film)', 'Shrek (franchise)', 'Harry Potter (film series)'] 97 | Thought 3: I need to search The Prince and Me (2004 film) instead. 98 | Action 3: Search[The Prince and Me (2004 film)] 99 | Observation 3: Could not find [The Prince and Me (2004 film)]. Similar: ['The Prince & Me', 'The Prince & Me 2: The Royal Wedding', 'Kam Heskin', 'Prince of Darkness (film)', 'Prince of Persia: The Sands of Time (film)', 'Rob Knox', 'Alexander (2004 film)', 'Prince of Persia', 'Brooklynn Prince', 'Prince (musician)'] 100 | Thought 4: I need to search The Prince & Me (2004 film). 101 | Action 4: Search[The Prince & Me (2004 film)] 102 | Observation 4: Could not find [The Prince & Me (2004 film)]. Similar: ['The Prince & Me', 'The Prince & Me 2: The Royal Wedding', 'Prince of Darkness (film)', 'Prince of Persia: The Sands of Time (film)', 'Rob Knox', 'Alexander (2004 film)', 'Prince (musician)', 'Prince of Persia', 'Kam Heskin', 'Brooklynn Prince'] 103 | Thought 5: I need to search The Prince & Me (2004 film) instead. 104 | Action 5: Search[The Prince & Me (2004 film)] 105 | Observation 5: Could not find [The Prince & Me (2004 film)]. Similar: ['The Prince & Me', 'The Prince & Me 2: The Royal Wedding', 'Prince of Darkness (film)', 'Prince of Persia: The Sands of Time (film)', 'Rob Knox', 'Alexander (2004 film)', 'Prince (musician)', 'Prince of Persia', 'Kam Heskin', 'Brooklynn Prince'] 106 | Thought 6: I need to search The Prince & Me (2004 film). 107 | Action 6: Search[The Prince & Me (2004 film)] 108 | Observation 6: Could not find [The Prince & Me (2004 film)]. Similar: ['The Prince & Me', 'The Prince & Me 2: The Royal Wedding', 'Prince of Darkness (film)', 'Prince of Persia: The Sands of Time (film)', 'Rob Knox', 'Alexander (2004 film)', 'Prince (musician)', 'Prince of Persia', 'Kam Heskin', 'Brooklynn Prince'] 109 | 110 | Reflection: I got stuck in a loop where I kept trying to search 'The Prince & Me (2004 film)' but the page could not be found. Instead I should have tried to search the similar results that had a similar name to see and they were made in 2004. 111 | 112 | Previous Trial: 113 | Question: The Oberoi family is part of a hotel company that has a head office in what city? 114 | Thought 1: I need to search the Oberoi family and find the city where the hotel company's head office is located. 115 | Action 1: Search[Oberoi family] 116 | Observation 1: The Oberoi Group is a luxury hotel group with its head office in New Delhi, India.[2] Founded in 1934, the company owns and operates 32 luxury hotels and two river cruise ships in 7 countries, primarily under its Oberoi Hotels & Resorts and Trident brands.[3] The group also operates The Oberoi Centre for Learning and Development, which is regarded as one of Asia's top institutions for hospitality education.[4]. 
The foundations of the Oberoi Group date back to 1934 when The Rai Bahadur Mohan Singh Oberoi, the founder of the group, bought two properties: the Maidens in Delhi and the Clarke's in Shimla.[5] In the following years Oberoi, assisted by his two sons, Tilak Raj Singh Oberoi and Prithvi Raj Singh Oberoi (P.R.S. Oberoi
117 | Thought 2: The Oberoi Group has its head office in New Delhi, India. I just need to finish with New Delhi.
118 | Action 2: Finish[New Delhi]
119 | Observation 2: Episode finished, reward = 0
120 | 
121 | Reflection: I appeared to have retrieved the correct information about The Oberoi Family and the location of its head office, and provided a corresponding answer. However this answer does not exactly match the ground truth answer so I should try a different wording, such as Delhi.
122 | 
123 | Previous trial:"""
124 | 
125 | def reflgpt(prompt, model_size="8", temperature=1.0, max_tokens=100, n=1) -> list:
126 |     def call_openai_api(messages, model_size, temperature, max_tokens, n):
127 |         outputs = []
128 |         while n > 0:
129 |             cnt = min(n, 20)
130 |             n -= cnt
131 |             res = server.get_completion_or_embedding(
132 |                 model_size, messages, temperature, max_tokens
133 |             )
134 |             outputs.extend(
135 |                 [
136 |                     re.sub(r"^Thought \d+: ", "", choice.message.content)
137 |                     for choice in res.choices
138 |                 ]
139 |             )
140 |         return outputs
141 | 
142 |     # messages = [{"role": "user", "content": prompt}]
143 |     messages = []
144 |     messages.append({"role": "user", "content": reflection_prompt + "\n" + prompt + "\nReflection:"})
145 |     # print(messages)
146 |     # time.sleep(30)
147 |     return call_openai_api(
148 |         messages, model_size=model_size, temperature=temperature, max_tokens=max_tokens, n=n
149 |     )
--------------------------------------------------------------------------------
/hotpotqa/requirements.txt:
--------------------------------------------------------------------------------
1 | faiss-cpu  # faiss is published on PyPI as faiss-cpu / faiss-gpu
2 | openai
3 | transformers
4 | requests
5 | # wikienv and wrappers are local modules (from the ReAct codebase), not pip packages.
6 | # logging, re, typing, itertools and functools are standard-library modules and need no entry here.
7 | 
--------------------------------------------------------------------------------
/hotpotqa/run.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import json
 3 | import argparse
 4 | import numpy as np
 5 | import math
 6 | from models import online_embed, build_server
 7 | import faiss
 8 | from hotpotqa import HotPotQATask
 9 | from tot import dfs_search, quote_env
10 | import logging
11 | import random
12 | 
13 | 
14 | def random_selection(lst, n=5):
15 |     if len(lst) <= n:
16 |         return lst
17 |     else:
18 |         return random.sample(lst, n)
19 | 
20 | 
21 | def run(args):
22 |     task = HotPotQATask()
23 |     print(task)
24 |     logs, cnt_avg, cnt_any = [], 0, 0
25 | 
26 |     config_path = os.path.join(args.run_name, "config.json")
27 | 
28 |     with open(config_path, "w", encoding="utf-8") as wf:
29 |         info_dict = vars(args)
30 |         info_dict["is_running"] = True
31 |         json.dump(info_dict, wf, indent=4)
32 | 
33 |     build_server(config_path=config_path)
34 | 
35 |     os.makedirs(os.path.dirname(args.log_dir), exist_ok=True)
36 | 
37 |     logging.basicConfig(
38 |         filename=args.log_dir,
39 |         level=logging.INFO,
40 |         format="%(asctime)s - %(levelname)s - %(message)s",
41 |         filemode="a",
42 |     )
43 | 
44 |     wins = {}
45 |     lose = {}
46 |     trajectories = []
47 |     embedding_array = np.zeros((0, 3584))
48 |     tongji = []  # "tongji" = statistics: cumulative number of solved tasks after each trial
49 |     cnt = 0
50 |     for trial in range(10):
51 |         print("Trial")
52 |         print(trial)
53 |         count = 0
54 |         task_accs = []
55 |         info = []
56 |         emb_db = faiss.IndexFlatL2(3584)
57 |         emb_db.add(embedding_array.astype("float32"))
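        # Note: the FAISS index built above is never queried -- the KNN retrieval
        # below computes inverse L2 distances directly over `embedding_array`.
        # The equivalent FAISS lookup (illustrative sketch only) would be:
        #   dists, ids = emb_db.search(np.asarray([fail_vec], dtype="float32"),
        #                              min(args.cot_size, len(trajectories)))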
58 | 
59 |         for i in range(args.task_start_index, args.task_end_index):
60 |             # solve
61 |             if i in wins:
62 |                 continue
63 |             prev = None
64 |             knnret = []
65 |             if trajectories and args.cot_size > 0:
66 |                 if i in lose:
67 |                     prev = lose[i]
68 |                 else:
69 |                     prev = quote_env(i)
70 |                 if args.cot_method == "knn":
71 |                     fail_vec = online_embed(str(prev))
72 |                     dist = []
73 |                     for index, row in np.ndenumerate(embedding_array):
74 |                         if index[1] == 0:
75 |                             realN = (1.0 / np.linalg.norm(np.array(fail_vec) - embedding_array[index[0]]))
76 |                             dist.append((realN, index[0]))
77 |                     dist.sort(key=lambda x: x[0], reverse=True)
78 |                     sz_now = min(len(trajectories), args.cot_size)
79 |                     dist = dist[:sz_now]
80 |                     dist = [(math.exp(0.0 * x), y) for (x, y) in dist]
81 |                     original_sum = sum([x for (x, _) in dist])
82 |                     dist = [(x * float(sz_now) / original_sum, y) for (x, y) in dist]
83 |                     tot = 0
84 |                     cntD = []
85 |                     realD = []
86 |                     for (x, y) in dist:
87 |                         cntD.append(math.floor(x))
88 |                         realD.append(x)
89 |                         tot += math.floor(x)
90 |                     while tot < sz_now:  # largest-remainder rounding: grow the integer counts until they sum to sz_now
91 |                         maxx = -1.0
92 |                         maxi = 0
93 |                         for ig in range(sz_now):
94 |                             if realD[ig] - float(cntD[ig]) > maxx:
95 |                                 maxx = realD[ig] - float(cntD[ig])
96 |                                 maxi = ig
97 |                         cntD[maxi] += 1
98 |                         tot += 1
99 |                     weights = []
100 |                     for (x, y) in dist:
101 |                         weights.append(x)
102 |                     normalized_weights = [w / sum(weights) for w in weights]
103 |                     for ig in range(sz_now):
104 |                         s_ind = np.random.choice(len(weights), p=normalized_weights)
105 |                         (_, y) = dist[s_ind]
106 |                         knnret.append(trajectories[y])
107 |                 else:
108 |                     knnret = random.sample(trajectories, min(args.cot_size, len(trajectories)))
109 |             state, value, all_nodes, reward, em, failt, succt = dfs_search(
110 |                 args, task, i, args.iteration, knnret
111 |             )
112 |             if failt:
113 |                 print("Fail")
114 |                 print(i)
115 |                 lose[i] = failt[0]
116 |             if succt:
117 |                 cnt = cnt + 1
118 |                 print("Success")
119 |                 print(i)
120 |                 wins[i] = 1
121 |                 vec = online_embed(str(succt[0]))
122 |                 trajectories.append(succt[0])
123 |                 embedding_array = np.vstack((embedding_array, np.array(vec)))
124 |             # log main metric
125 |             if em is None:
126 |                 em = 0
127 |             task_accs.append(em)
128 |             cnt_avg = sum(task_accs) / len(task_accs)
129 |             print(i, "len(task_accs)", len(task_accs), "cnt_avg", cnt_avg, "\n")
130 |         tongji.append(cnt)
131 |         print(tongji)
132 |     n = args.task_end_index - args.task_start_index
133 | 
134 | 
135 | def parse_args():
136 |     args = argparse.ArgumentParser()
137 |     args.add_argument("--model_size", type=str, default="8")
138 |     args.add_argument("--temperature", type=float, default=1.0)
139 |     args.add_argument("--task_start_index", type=int, default=900)
140 |     args.add_argument("--task_end_index", type=int, default=1000)
141 |     args.add_argument("--prompt_sample", type=str, choices=["standard", "cot"])
142 |     args.add_argument("--n_generate_sample", type=int, default=1)
143 |     args.add_argument("--n_evaluate_sample", type=int, default=1)
144 |     args.add_argument("--iteration", type=int, default=50)
145 |     args.add_argument("--algorithm", type=str, choices=["lats", "rap", "tot"])
146 |     args.add_argument("--cot_method", type=str, choices=["knn", "random", "None"])
147 |     args.add_argument("--run_name", type=str)
148 |     args.add_argument("--log_file_path", type=str)
149 |     args.add_argument("--log_dir", type=str)
150 |     args.add_argument("--cot_size", type=int, default=0)
151 |     args = args.parse_args()
152 |     return args
153 | 
154 | 
155 | if __name__ == "__main__":
156 |     args = parse_args()
157 |     print(args)
158 |     run(args)
159 | 
--------------------------------------------------------------------------------
/hotpotqa/tot.py:
-------------------------------------------------------------------------------- 1 | import itertools 2 | import numpy as np 3 | from functools import partial 4 | from models import gpt 5 | import wikienv, wrappers 6 | import requests 7 | import logging 8 | import re 9 | 10 | # Configuring the logging 11 | logging.basicConfig( 12 | filename="tot_150it.log", 13 | level=logging.INFO, 14 | format="%(asctime)s - %(levelname)s - %(message)s", 15 | filemode="w", 16 | ) 17 | 18 | # Test logging 19 | logging.info("This is a test log entryyyyyy.") 20 | 21 | env = wikienv.WikiEnv() 22 | env = wrappers.HotPotQAWrapper(env, split="train") 23 | env = wrappers.LoggingWrapper(env) 24 | 25 | logging.info("Logging has been configured.") 26 | 27 | global reflection_map 28 | reflection_map = [] 29 | 30 | 31 | def step(env, action): 32 | attempts = 0 33 | while attempts < 10: 34 | try: 35 | return env.step(action) 36 | except requests.exceptions.Timeout: 37 | attempts += 1 38 | 39 | 40 | def get_value(task, x, y, n_evaluate_sample, cache_value=True): 41 | global reflection_map 42 | unique_trajectories = get_unique_trajectories(failed_trajectories) 43 | value_prompt = task.value_prompt_wrap(x, y, unique_trajectories, reflection_map) 44 | logging.info(f"Current: {x}") 45 | logging.info(f"Current: {y}") 46 | if cache_value and value_prompt in task.value_cache: 47 | return task.value_cache[value_prompt] 48 | logging.info(f"VALUE PROMPT: {value_prompt}") 49 | value_outputs = gpt(value_prompt, n=n_evaluate_sample) 50 | logging.info(f"VALUE OUTPUTS: {value_outputs}") 51 | value = task.value_outputs_unwrap(value_outputs) 52 | logging.info(f"VALUES: {value}") 53 | if cache_value: 54 | task.value_cache[value_prompt] = value 55 | return value 56 | 57 | 58 | def get_values(task, x, ys, n_evaluate_sample, cache_value=True): 59 | values = [] 60 | local_value_cache = {} 61 | for y in ys: # each partial output 62 | if y in local_value_cache: # avoid duplicate candidates 63 | value = 0 64 | else: 65 | value = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value) 66 | local_value_cache[y] = value 67 | values.append(value) 68 | return values 69 | 70 | 71 | def get_samples(task, x, y, n_generate_sample, prompt_sample, knn=None): 72 | unique_trajectories = get_unique_trajectories(failed_trajectories) 73 | global reflection_map 74 | reflection_map = [] 75 | if prompt_sample == "standard": 76 | prompt = task.standard_prompt_wrap(x, y, []) 77 | elif prompt_sample == "cot": 78 | prompt = task.cot_prompt_wrap(x, y, [], knn) 79 | else: 80 | raise ValueError(f"prompt_sample {prompt_sample} not recognized") 81 | logging.info(f"PROMPT: {prompt}") 82 | samples = gpt(prompt, n=n_generate_sample) 83 | return [y + _ for _ in samples] 84 | 85 | 86 | def get_unique_trajectories(failed_trajectories, num=2): 87 | unique_trajectories = [] 88 | seen_final_answers = set() 89 | for traj in failed_trajectories: 90 | final_answer = traj.get("final_answer") 91 | if final_answer not in seen_final_answers: 92 | unique_trajectories.append(node_trajectory_to_text(traj["trajectory"])) 93 | seen_final_answers.add(final_answer) 94 | if len(unique_trajectories) >= num: 95 | break 96 | return unique_trajectories 97 | 98 | 99 | class Node: 100 | def __init__(self, state, question, parent=None, knn=None): 101 | self.state = ( 102 | {"thought": "", "action": "", "observation": ""} if state is None else state 103 | ) 104 | self.parent = parent 105 | self.question = question 106 | self.children = [] 107 | self.visits = 0 108 | self.value = 0 109 | 
self.depth = 0 if parent is None else parent.depth + 1 110 | self.is_terminal = False 111 | self.reward = 0 112 | self.exhausted = False # If all children are terminal 113 | self.em = 0 # Exact match, evaluation metric 114 | self.knn = knn 115 | 116 | def ques(self): 117 | return self.question 118 | 119 | def uct(self): 120 | if self.visits == 0: 121 | # return float('inf') 122 | return self.value * 2 123 | return self.value / self.visits + np.sqrt( 124 | 2 * np.log(self.parent.visits) / self.visits 125 | ) 126 | 127 | def uct_with_depth(self, C1=1, C2=1): 128 | if self.visits == 0: 129 | return self.value 130 | exploitation_term = self.value / self.visits 131 | exploration_term = np.sqrt(2 * np.log(self.parent.visits) / self.visits) 132 | depth_term = self.depth 133 | return exploitation_term + C1 * exploration_term + C2 * depth_term 134 | 135 | def __str__(self): 136 | return f"Node(depth={self.depth}, value={self.value:.2f}, visits={self.visits}, thought={self.state['thought']}, action={self.state['action']}, observation={self.state['observation']})" 137 | 138 | def to_dict(self): 139 | return { 140 | "state": self.state, 141 | "question": self.question, 142 | "parent": self.parent.to_dict() if self.parent else None, 143 | "children": [child.to_dict() for child in self.children], 144 | "visits": self.visits, 145 | "value": self.value, 146 | "depth": self.depth, 147 | "is_terminal": self.is_terminal, 148 | "reward": self.reward, 149 | "em": self.em, 150 | "knn": self.knn, 151 | } 152 | 153 | 154 | def node_trajectory_to_text(node_string): 155 | lines = node_string.split("\n") 156 | formatted_lines = [] 157 | for line in lines: 158 | if line.startswith("Question"): 159 | formatted_lines.append(line) 160 | continue 161 | try: 162 | depth = int(line.split(",")[0].split("=")[1].strip()) 163 | thought = line.split(", thought=")[1].split(", action=")[0].strip() 164 | action = line.split(", action=")[1].split(", observation=")[0].strip() 165 | observation = line.split(", observation=")[1].split(")")[0].strip() 166 | except IndexError: 167 | continue 168 | 169 | if depth != 0: 170 | if thought: 171 | formatted_lines.append(f"Thought {depth}: {thought}") 172 | if action: 173 | formatted_lines.append(f"Action {depth}: {action}") 174 | if observation: 175 | formatted_lines.append(f"Observation {depth}: {observation}") 176 | formatted_lines.pop() 177 | return "\n".join(formatted_lines) 178 | 179 | 180 | def traj_depth(node_string): 181 | lines = node_string.split("\n") 182 | formatted_lines = [] 183 | ret = 0 184 | for line in lines: 185 | try: 186 | depth = int(line.split(",")[0].split("=")[1].strip()) 187 | thought = line.split(", thought=")[1].split(", action=")[0].strip() 188 | action = line.split(", action=")[1].split(", observation=")[0].strip() 189 | observation = line.split(", observation=")[1].split(")")[0].strip() 190 | except IndexError: 191 | continue 192 | if depth > ret: 193 | ret = depth 194 | return ret 195 | 196 | 197 | def collect_all_nodes(node): 198 | """Recursively collect all nodes starting from the given node.""" 199 | nodes = [node] 200 | for child in node.children: 201 | nodes.extend(collect_all_nodes(child)) 202 | return nodes 203 | 204 | 205 | def collect_trajectory(node): 206 | trajectory = [] 207 | ques = "" 208 | while node: 209 | ques = "Question: " + node.question 210 | trajectory.append(str(node)) 211 | node = node.parent 212 | if len(ques) > 0: 213 | trajectory.append(ques) 214 | return "\n".join(reversed(trajectory)) 215 | 216 | 217 | def 
get_substrings_between_brackets(s):
218 |     # Use a regular expression to find everything inside the square brackets.
219 |     # Square brackets are special characters in a regex and need to be escaped.
220 |     pattern = r"\[(.*?)\]"
221 |     # re.findall returns all non-overlapping matches.
222 |     matches = re.findall(pattern, s)
223 |     return matches[0]
224 | 
225 | def quote_env(idx):
226 |     return str(env.reset(idx=idx))
227 | 
228 | 
229 | def dfs_search(args, task, idx, iterations, knnret, depth_limit=7, to_print=True):
230 |     global gpt
231 |     global failed_trajectories
232 |     global success_trajectories
233 |     gpt = partial(gpt, model_size=args.model_size, temperature=args.temperature)
234 |     x = env.reset(idx=idx)
235 |     if to_print:
236 |         print(idx, x)
237 |     root = Node(state=None, question=x)
238 |     all_nodes = []
239 |     failed_trajectories = []
240 |     success_trajectories = []
241 |     stack = [root]
242 |     it = 0
243 |     knn = []
244 |     if knnret:
245 |         for traj in knnret:
246 |             format_traj = node_trajectory_to_text(traj["trajectory"])
247 |             # format_traj+=f"Action {traj_depth(traj['trajectory'])}: Finish[{get_substrings_between_brackets(traj['final_answer'])}]"+"\n"
248 |             knn.append(format_traj)
249 |     last_node = None
250 |     while stack and it < iterations:
251 |         node = stack.pop()
252 |         last_node = node
253 |         logging.info(f"DFS at node depth {node.depth}...")
254 | 
255 |         if node.is_terminal and node.reward == 1:
256 |             logging.info(f"Terminal node with reward 1 found at depth {node.depth}")
257 |             return (
258 |                 node.state,
259 |                 node.value,
260 |                 all_nodes,
261 |                 node.reward,
262 |                 node.em,
263 |                 failed_trajectories,
264 |                 success_trajectories,
265 |             )
266 | 
267 |         if node.is_terminal and node.reward == 0:
268 |             logging.info(f"Terminal node with reward 0 found at depth {node.depth}")
269 |             return (
270 |                 node.state,
271 |                 node.value,
272 |                 all_nodes,
273 |                 node.reward,
274 |                 node.em,
275 |                 failed_trajectories,
276 |                 success_trajectories,
277 |             )
278 | 
279 |         if node.depth >= depth_limit:
280 |             logging.info("Depth limit reached")
281 |             it += 1
282 |             continue  # go to next iteration
283 | 
284 |         expand_node(node, args, task, knn=knn)
285 |         stack.extend(reversed(node.children))  # adding all child nodes to stack for DFS
286 | 
287 |         all_nodes = [(node, node.value) for node in collect_all_nodes(root)]
288 |         logging.info(f"State of all_nodes after iteration: {all_nodes}")
289 |         it += 1
290 |     # If we reach here, no solution was found
291 |     logging.info("All paths explored. No solution found.")
292 |     if len(failed_trajectories) == 0:
293 |         trajectory = collect_trajectory(last_node)
294 |         failed_trajectories.append({"trajectory": trajectory, "final_answer": ""})
295 |     return root, 0, all_nodes, 0, 0, failed_trajectories, success_trajectories
296 | 
297 | 
298 | def select_node_dfs(stack):
299 |     return stack[-1] if stack else None  # return the last node in the stack
300 | 
301 | 
302 | # FYI: deleted mcts search here
303 | 
304 | 
305 | def select_node(node):
306 |     while node and node.children:
307 |         logging.info(
308 |             f"Selecting from {len(node.children)} children at depth {node.depth}."
309 |         )
310 | 
311 |         terminal_children = [child for child in node.children if child.is_terminal]
312 |         terminal_status = [child.is_terminal for child in node.children]
313 | 
314 |         if len(terminal_children) == len(node.children):
315 |             logging.info(
316 |                 f"All children are terminal at depth {node.depth}. Backtracking..."
317 | ) 318 | if node.parent: 319 | node.parent.children.remove(node) 320 | node = node.parent 321 | continue 322 | 323 | node_with_reward_1 = next( 324 | (child for child in terminal_children if child.reward == 1), None 325 | ) 326 | if node_with_reward_1: 327 | logging.info(f"Found terminal node with reward 1 at depth {node.depth}.") 328 | return node_with_reward_1 329 | 330 | node = max( 331 | (child for child in node.children if not child.is_terminal), 332 | key=lambda child: child.uct(), 333 | default=None, 334 | ) 335 | 336 | while node.is_terminal and node.reward != 1: 337 | node = max( 338 | (child for child in node.parent.children if not child.is_terminal), 339 | key=lambda child: child.uct(), 340 | default=None, 341 | ) 342 | 343 | logging.info(f"Selected node at depth {node.depth} with UCT {node.uct()}.") 344 | 345 | return node # This will return None if all paths from the root are exhausted 346 | 347 | 348 | def expand_node(node, args, task, knn=None): 349 | if node.depth >= 7: 350 | logging.info("Depth limit reached") 351 | print("Depth limit reached") 352 | node.is_terminal = True 353 | return 354 | new_nodes = generate_new_states(node, args, task, knn=knn) 355 | node.children.extend(new_nodes) 356 | 357 | 358 | def generate_new_states(node, args, task, knn=None): 359 | prompt = generate_prompt(node) 360 | sampled_actions = get_samples( 361 | task, 362 | prompt, 363 | f"Thought {node.depth + 1}: ", 364 | args.n_generate_sample, 365 | prompt_sample=args.prompt_sample, 366 | knn=knn, 367 | ) 368 | logging.info(f"SAMPLED ACTION: {sampled_actions}") 369 | 370 | unique_states = {} # Store unique states here 371 | for action in sampled_actions: 372 | new_state = node.state.copy() # Make a copy of the parent node's state 373 | 374 | thought_line = next( 375 | ( 376 | line.split(":")[1].strip() 377 | for line in action.split("\n") 378 | if line.startswith(f"Thought {node.depth + 1}") 379 | ), 380 | "", 381 | ) 382 | action_line = next( 383 | ( 384 | line.split(":")[1].strip() 385 | for line in action.split("\n") 386 | if line.startswith("Action") and ":" in line 387 | ), 388 | None, 389 | ) 390 | 391 | # Use thought and action to form a unique key 392 | unique_key = f"{thought_line}::{action_line}" 393 | 394 | if unique_key in unique_states: 395 | continue # Skip if this state already exists 396 | 397 | if action_line: 398 | action_type = ( 399 | action_line.split("[")[0] if "[" in action_line else action_line 400 | ) 401 | action_param = ( 402 | action_line.split("[")[1].split("]")[0] if "[" in action_line else "" 403 | ) 404 | obs, r, done, info = step(env, f"{action_type.lower()}[{action_param}]") 405 | 406 | # Update the new state dictionary 407 | new_state["thought"] = thought_line 408 | new_state["action"] = action_line 409 | new_state["observation"] = obs 410 | 411 | new_node = Node(state=new_state, question=node.question, parent=node) 412 | new_node.is_terminal = r == 1 or done 413 | new_node.reward = r 414 | if r == 1: 415 | new_node.em = info.get("em") 416 | unique_states[unique_key] = new_node # Add this state to unique_states 417 | logging.info(f"NEW NODE: {new_node}") 418 | logging.info(f"Feedback: {info}") 419 | 420 | if new_node.is_terminal and r == 0: 421 | trajectory = collect_trajectory(new_node) 422 | failed_trajectories.append( 423 | { 424 | "trajectory": trajectory, 425 | "final_answer": f"{action_type.lower()}[{action_param}]", 426 | } 427 | ) 428 | if new_node.is_terminal and r == 1: 429 | trajectory = collect_trajectory(new_node) 430 | 
success_trajectories.append( 431 | { 432 | "trajectory": trajectory, 433 | "final_answer": f"{action_type.lower()}[{action_param}]", 434 | } 435 | ) 436 | 437 | return list(unique_states.values()) # Return unique nodes as a list 438 | 439 | 440 | def evaluate_node(node, args, task): 441 | child_prompts = [ 442 | generate_prompt(child) for child in node.children if not child.is_terminal 443 | ] 444 | votes = get_values(task, node.question, child_prompts, args.n_evaluate_sample) 445 | 446 | logging.info(f"Length of votes: {len(votes)}") 447 | logging.info(f"Length of node.children: {len(node.children)}") 448 | 449 | # Pre-allocate votes list 450 | votes = votes + [0] * (len(node.children) - len(votes)) 451 | 452 | max_vote = max(votes) if votes else 1 453 | if max_vote == 0: 454 | max_vote = 1 # Avoid division by zero 455 | 456 | terminal_conditions = [1 if child.is_terminal else 0 for child in node.children] 457 | for i, condition in enumerate(terminal_conditions): 458 | if condition == 1: 459 | votes[i] = max_vote + 1 460 | 461 | for i, child in enumerate(node.children): 462 | child.value = votes[i] / max_vote # Now safe from division by zero 463 | 464 | return sum(votes) / len(votes) if votes else 0 465 | 466 | 467 | def print_tree(node, level=0): 468 | indent = " " * level 469 | print(f"{indent}{node}") 470 | for child in node.children: 471 | print_tree(child, level + 1) 472 | 473 | 474 | def backpropagate(node, value): 475 | while node: 476 | node.visits += 1 477 | if node.is_terminal: 478 | if node.reward == 1: 479 | node.value = (node.value * (node.visits - 1) + value) / node.visits 480 | logging.info( 481 | f"Backpropagating with reward 1 at depth {node.depth}. New value: {node.value}." 482 | ) 483 | elif node.reward == 0: 484 | node.value = (node.value * (node.visits - 1) + (-1)) / node.visits 485 | logging.info( 486 | f"Backpropagating with reward 0 at depth {node.depth}. New value: {node.value}." 487 | ) 488 | else: 489 | node.value = (node.value * (node.visits - 1) + value) / node.visits 490 | logging.info( 491 | f"Backpropagating at depth {node.depth}. New value: {node.value}." 
492 | ) 493 | 494 | node = node.parent 495 | 496 | 497 | def generate_prompt(node): 498 | trajectory = [] 499 | question = node.question 500 | while node: 501 | new_segment = [] 502 | if node.state["thought"]: 503 | new_segment.append(f"Thought {node.depth}: {node.state['thought']}") 504 | if node.state["action"]: 505 | new_segment.append(f"Action {node.depth}: {node.state['action']}") 506 | if ( 507 | node.state["observation"] and node.depth != 0 508 | ): # Exclude the observation from the root node 509 | new_segment.append(f"Observation {node.depth}: {node.state['observation']}") 510 | trajectory.append("\n".join(new_segment)) 511 | node = node.parent 512 | return question + "\n".join(reversed(trajectory)) 513 | -------------------------------------------------------------------------------- /hotpotqa/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" 4 | os.environ["PYTHONUTF8"] = "1" 5 | import time 6 | from typing import Dict, List 7 | import openai 8 | import random, json 9 | from IPython import embed 10 | import random 11 | import math 12 | 13 | 14 | class ModelServer: 15 | 16 | def get_completion_or_embedding( 17 | self, 18 | model_size: str, 19 | message, 20 | temperature: float = 0.0, 21 | max_tokens: int = 256, 22 | get_embedding: bool = False, 23 | ) -> str: 24 | assert model_size in ["70", "8", "7"] 25 | 26 | if not get_embedding: 27 | assert type(message) == list, "Message should be a list." 28 | response = client.chat.completions.create( 29 | model=model_name, 30 | messages=message, 31 | max_tokens=max_tokens, 32 | temperature=temperature, 33 | stop=["<|eot_id|>", "\nObservation", "Observation"], 34 | ) 35 | else: 36 | assert type(message) == str, "Message should be a string." 37 | response = client.embeddings.create( 38 | model=model_name, 39 | input=message, 40 | ) 41 | if get_embedding: 42 | return response.data[0].embedding 43 | else: 44 | return response 45 | 46 | -------------------------------------------------------------------------------- /webshop/README: -------------------------------------------------------------------------------- 1 | # Webshop experiment code 2 | 3 | Code modified from [LATS](https://github.com/lapisrocks/LanguageAgentTreeSearch) 4 | 5 | Setup your local LLM using an OpenAI API compatible endpoint, then set it up in utils.py. 
We recommend [SGLang](https://github.com/sgl-project/sglang) 6 | 7 | Moreover, in lats.py, the address for Webshop server is set as http://127.0.0.1:3000 8 | 9 | Example command: 10 | 11 | python run.py --model_size 8 --temperature 0.0 --task_start_index 0 --task_end_index 50 --cot_size 5 \ 12 | --prompt_sample 'cot' --n_generate_sample 1 --n_evaluate_sample 1 \ 13 | --iteration 10 --algorithm 'tot' --cot_method 'knn' --run_name 'webshop_new' -------------------------------------------------------------------------------- /webshop/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | class Task: 3 | def __init__(self): 4 | pass 5 | 6 | def __len__(self) -> int: 7 | pass 8 | 9 | def get_input(self, idx: int) -> str: 10 | pass 11 | 12 | def test_output(self, idx: int, output: str): 13 | pass -------------------------------------------------------------------------------- /webshop/lats.py: -------------------------------------------------------------------------------- 1 | 2 | #!/usr/bin/env python 3 | # coding: utf-8 4 | 5 | # # Setup 6 | 7 | # In[1]: 8 | 9 | import os 10 | import sys 11 | import copy 12 | import itertools 13 | import numpy as np 14 | from functools import partial 15 | from models import gpt 16 | import requests 17 | import logging 18 | import random 19 | 20 | completion_tokens = prompt_tokens = 0 21 | 22 | import requests 23 | from bs4 import BeautifulSoup 24 | from bs4.element import Comment 25 | 26 | WEBSHOP_URL = "http://127.0.0.1:3000" 27 | ACTION_TO_TEMPLATE = { 28 | 'Description': 'description_page.html', 29 | 'Features': 'features_page.html', 30 | 'Reviews': 'review_page.html', 31 | 'Attributes': 'attributes_page.html', 32 | } 33 | 34 | def clean_str(p): 35 | return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8") 36 | 37 | 38 | def tag_visible(element): 39 | ignore = {'style', 'script', 'head', 'title', 'meta', '[document]'} 40 | return ( 41 | element.parent.name not in ignore and not isinstance(element, Comment) 42 | ) 43 | 44 | 45 | def webshop_text(session, page_type, query_string='', page_num=1, asin='', options={}, subpage='', **kwargs): 46 | if page_type == 'init': 47 | url = ( 48 | f'{WEBSHOP_URL}/{session}' 49 | ) 50 | if page_type == 'search': 51 | url = ( 52 | f'{WEBSHOP_URL}/search_results/{session}/' 53 | f'{query_string}/{page_num}' 54 | ) 55 | elif page_type == 'item': 56 | url = ( 57 | f'{WEBSHOP_URL}/item_page/{session}/' 58 | f'{asin}/{query_string}/{page_num}/{options}' 59 | ) 60 | elif page_type == 'item_sub': 61 | url = ( 62 | f'{WEBSHOP_URL}/item_sub_page/{session}/' 63 | f'{asin}/{query_string}/{page_num}/{subpage}/{options}' 64 | ) 65 | elif page_type == 'end': 66 | url = ( 67 | f'{WEBSHOP_URL}/done/{session}/' 68 | f'{asin}/{options}' 69 | ) 70 | # print(url) 71 | html = requests.get(url).text 72 | html_obj = BeautifulSoup(html, 'html.parser') 73 | texts = html_obj.findAll(text=True) 74 | visible_texts = list(filter(tag_visible, texts)) 75 | # visible_texts = [str(text).strip().strip('\\n') for text in visible_texts] 76 | # if page_type == 'end': import pdb; pdb.set_trace() 77 | if False: 78 | # For `simple` mode, return just [SEP] separators 79 | return ' [SEP] '.join(t.strip() for t in visible_texts if t != '\n') 80 | else: 81 | # Otherwise, return an observation with tags mapped to specific, unique separators 82 | observation = '' 83 | option_type = '' 84 | options = {} 85 | asins = [] 86 | cnt = 0 87 | prod_cnt = 0 88 | just_prod = 0 89 | for t in visible_texts: 90 
| if t == '\n': continue 91 | if t.replace('\n', '').replace('\\n', '').replace(' ', '') == '': continue 92 | # if t.startswith('Instruction:') and page_type != 'init': continue 93 | # print(t.parent.name, t) 94 | if t.parent.name == 'button': # button 95 | processed_t = f'\n[{t}] ' 96 | elif t.parent.name == 'label': # options 97 | if f"'{t}'" in url: 98 | processed_t = f'[[{t}]]' 99 | # observation = f'You have clicked {t}.\n' + observation 100 | else: 101 | processed_t = f'[{t}]' 102 | options[str(t)] = option_type 103 | # options[option_type] = options.get(option_type, []) + [str(t)] 104 | elif t.parent.get('class') == ["product-link"]: # product asins 105 | processed_t = f'\n[{t}] ' 106 | if prod_cnt >= 10: 107 | processed_t = '' 108 | prod_cnt += 1 109 | asins.append(str(t)) 110 | just_prod = 0 111 | else: # regular, unclickable text 112 | processed_t = '\n' + str(t) + ' ' 113 | if cnt < 2 and page_type != 'init': processed_t = '' 114 | if just_prod <= 2 and prod_cnt >= 4: processed_t = '' 115 | option_type = str(t) 116 | cnt += 1 117 | just_prod += 1 118 | observation += processed_t 119 | info = {} 120 | if options: 121 | info['option_types'] = options 122 | if asins: 123 | info['asins'] = asins 124 | if 'Your score (min 0.0, max 1.0)' in visible_texts: 125 | idx = visible_texts.index('Your score (min 0.0, max 1.0)') 126 | info['reward'] = float(visible_texts[idx + 1]) 127 | observation = 'Your score (min 0.0, max 1.0): ' + (visible_texts[idx + 1]) 128 | return clean_str(observation), info 129 | 130 | class webshopEnv: 131 | def __init__(self): 132 | self.sessions = {} 133 | 134 | def clone_state(self): 135 | return copy.deepcopy(self.sessions) 136 | 137 | def step(self, session, action): 138 | done = False 139 | observation_ = None 140 | logging.info(self.sessions) 141 | if action == 'reset': 142 | self.sessions[session] = {'session': session, 'page_type': 'init'} 143 | elif action.startswith('think['): 144 | observation = 'OK.' 
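        # The elif chain below enforces WebShop's page-state machine: search[...]
        # is only legal from the init page, and every click[...] target ('Buy Now',
        # 'Back to Search', 'Next >', '< Prev', the item sub-page tabs, an ASIN, or
        # an option label) asserts that the current page_type allows it. A violated
        # assert propagates to the caller; generate_new_states below catches the
        # AssertionError and turns it into an 'Invalid action!' observation with
        # reward -1.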
145 | elif action.startswith('search['): 146 | assert self.sessions[session]['page_type'] == 'init' 147 | query = action[7:-1] 148 | self.sessions[session] = {'session': session, 'page_type': 'search', 149 | 'query_string': query, 'page_num': 1} 150 | elif action.startswith('click['): 151 | button = action[6:-1] 152 | if button == 'Buy Now': 153 | assert self.sessions[session]['page_type'] == 'item' 154 | self.sessions[session]['page_type'] = 'end' 155 | #done = True 156 | elif button == 'Back to Search': 157 | assert self.sessions[session]['page_type'] in ['search', 'item_sub', 'item'] 158 | self.sessions[session] = {'session': session, 'page_type': 'init'} 159 | elif button == 'Next >': 160 | #assert False # ad hoc page limitation 161 | assert self.sessions[session]['page_type'] == 'search' 162 | self.sessions[session]['page_num'] += 1 163 | elif button == '< Prev': 164 | assert self.sessions[session]['page_type'] in ['search', 'item_sub', 'item'] 165 | if self.sessions[session]['page_type'] == 'search': 166 | #assert False 167 | self.sessions[session]['page_num'] -= 1 168 | elif self.sessions[session]['page_type'] == 'item_sub': 169 | self.sessions[session]['page_type'] = 'item' 170 | elif self.sessions[session]['page_type'] == 'item': 171 | self.sessions[session]['page_type'] = 'search' 172 | self.sessions[session]['options'] = {} 173 | elif button in ACTION_TO_TEMPLATE: 174 | assert self.sessions[session]['page_type'] == 'item' 175 | self.sessions[session]['page_type'] = 'item_sub' 176 | self.sessions[session]['subpage'] = button 177 | else: 178 | if self.sessions[session]['page_type'] == 'search': 179 | assert button in self.sessions[session].get('asins', []) # must be asins 180 | self.sessions[session]['page_type'] = 'item' 181 | self.sessions[session]['asin'] = button 182 | elif self.sessions[session]['page_type'] == 'item': 183 | assert 'option_types' in self.sessions[session] 184 | assert button in self.sessions[session]['option_types'], (button, self.sessions[session]['option_types']) # must be options 185 | option_type = self.sessions[session]['option_types'][button] 186 | if not 'options' in self.sessions[session]: 187 | self.sessions[session]['options'] = {} 188 | self.sessions[session]['options'][option_type] = button 189 | observation_ = f'You have clicked {button}.' 190 | else: 191 | assert False 192 | observation, info = webshop_text(**self.sessions[session]) 193 | if observation_: 194 | observation = observation_ 195 | self.sessions[session].update(info) 196 | reward = info.get('reward', 0.0) 197 | if reward != 0.0: 198 | #print(f"Current Session State: {self.sessions[session]}") 199 | #print(f"Action being processed: {action}") 200 | print(f"Resulting Observation: {observation}") 201 | observation+=" Please try again!" 
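        # The reward is scraped from the 'done' page by webshop_text ('Your score
        # (min 0.0, max 1.0)'), so a non-zero score below 1.0 means an item was
        # bought that only partially matches the instruction; the agent is nudged
        # to try again, and only a perfect score of 1.0 ends the episode below.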
202 | if reward == 1.0: 203 | done = True 204 | print("done") 205 | return observation, reward, done 206 | 207 | env = webshopEnv() 208 | 209 | logging.info("Logging has been configured.") 210 | 211 | global reflection_map 212 | reflection_map = [] 213 | 214 | 215 | def step(env, action): 216 | attempts = 0 217 | while attempts < 10: 218 | try: 219 | return env.step(action) 220 | except requests.exceptions.Timeout: 221 | attempts += 1 222 | 223 | 224 | 225 | def get_samples(task, x, y, n_generate_sample, prompt_sample, knn=None): 226 | unique_trajectories = get_unique_trajectories(failed_trajectories) 227 | global reflection_map 228 | global completion_tokens 229 | global prompt_tokens 230 | reflection_map = [] 231 | if prompt_sample == "standard": 232 | prompt = task.standard_prompt_wrap(x, y, []) 233 | elif prompt_sample == "cot": 234 | prompt = task.cot_prompt_wrap(x, y, [], knn) 235 | else: 236 | raise ValueError(f"prompt_sample {prompt_sample} not recognized") 237 | logging.info(f"PROMPT: {prompt}") 238 | samples,inca,incb = gpt(prompt, n=n_generate_sample) 239 | prompt_tokens+=inca 240 | completion_tokens+=incb 241 | #for trt in samples: 242 | # with open("debug_log1.log", "a", encoding="utf-8") as f: 243 | # f.write(y+"\n"+trt+"\n") 244 | return [y + _ for _ in samples] 245 | 246 | 247 | def get_unique_trajectories(failed_trajectories, num=2): 248 | unique_trajectories = [] 249 | seen_final_answers = set() 250 | for traj in failed_trajectories: 251 | final_answer = traj.get("final_answer") 252 | if final_answer not in seen_final_answers: 253 | unique_trajectories.append(node_trajectory_to_text(traj["trajectory"])) 254 | seen_final_answers.add(final_answer) 255 | if len(unique_trajectories) >= num: 256 | break 257 | return unique_trajectories 258 | 259 | 260 | class Node: 261 | def __init__(self, state, question, parent=None, knn=None,env_state=None): 262 | self.state = ( 263 | {"action": "", "observation": ""} if state is None else state 264 | ) 265 | self.parent = parent 266 | self.question = question 267 | self.children = [] 268 | self.visits = 0 269 | self.value = 0 270 | self.depth = 0 if parent is None else parent.depth + 1 271 | self.is_terminal = False 272 | self.reward = 0 273 | self.exhausted = False # If all children are terminal 274 | self.em = 0 # Exact match, evaluation metric 275 | self.knn = knn 276 | self.env_state = env_state 277 | 278 | def ques(self): 279 | return self.question 280 | 281 | def uct(self): 282 | if self.visits == 0: 283 | # return float('inf') 284 | return self.value * 2 285 | return self.value / self.visits + np.sqrt( 286 | 2 * np.log(self.parent.visits) / self.visits 287 | ) 288 | 289 | def uct_with_depth(self, C1=1, C2=1): 290 | if self.visits == 0: 291 | return self.value 292 | exploitation_term = self.value / self.visits 293 | exploration_term = np.sqrt(2 * np.log(self.parent.visits) / self.visits) 294 | depth_term = self.depth 295 | return exploitation_term + C1 * exploration_term + C2 * depth_term 296 | 297 | def __str__(self): 298 | return f"Node(depth={self.depth}, value={self.value:.2f}, visits={self.visits}, action={self.state['action']}, observation={self.state['observation']})" 299 | 300 | def to_dict(self): 301 | return { 302 | "state": self.state, 303 | "question": self.question, 304 | "parent": self.parent.to_dict() if self.parent else None, 305 | "children": [child.to_dict() for child in self.children], 306 | "visits": self.visits, 307 | "value": self.value, 308 | "depth": self.depth, 309 | "is_terminal": self.is_terminal, 310 | 
"reward": self.reward, 311 | "em": self.em, 312 | "knn": self.knn, 313 | } 314 | 315 | 316 | def node_trajectory_to_text(node_string): 317 | lines = node_string.split("Node(") 318 | formatted_lines = [] 319 | for line in lines: 320 | if line.startswith("Instruction"): 321 | formatted_lines.append(line) 322 | continue 323 | try: 324 | depth = int(line.split(",")[0].split("=")[1].strip()) 325 | action = line.split(", action=")[1].split(", observation=")[0].strip() 326 | observation = line.split(", observation=")[1].split(")")[0].strip() 327 | except IndexError: 328 | continue 329 | 330 | if depth != 0: 331 | if action: 332 | formatted_lines.append(f"Action: {action}") 333 | if observation: 334 | formatted_lines.append(f"Observation: {observation}") 335 | formatted_lines.pop() 336 | return "\n".join(formatted_lines) 337 | 338 | 339 | def traj_depth(node_string): 340 | lines = node_string.split("\n") 341 | formatted_lines = [] 342 | ret = 0 343 | for line in lines: 344 | try: 345 | depth = int(line.split(",")[0].split("=")[1].strip()) 346 | action = line.split(", action=")[1].split(", observation=")[0].strip() 347 | observation = line.split(", observation=")[1].split(")")[0].strip() 348 | except IndexError: 349 | continue 350 | if depth > ret: 351 | ret = depth 352 | return ret 353 | 354 | 355 | def collect_all_nodes(node): 356 | """Recursively collect all nodes starting from the given node.""" 357 | nodes = [node] 358 | for child in node.children: 359 | nodes.extend(collect_all_nodes(child)) 360 | return nodes 361 | 362 | 363 | def collect_trajectory(node): 364 | trajectory = [] 365 | ques = "" 366 | while node: 367 | ques = "Instruction: " + str(node.question).replace("WebShop","").replace("Instruction:","").strip() 368 | trajectory.append(str(node)) 369 | node = node.parent 370 | if len(ques) > 0: 371 | trajectory.append(ques) 372 | return "\n".join(reversed(trajectory)) 373 | 374 | import re 375 | def get_substrings_between_brackets(s): 376 | pattern = r"\[(.*?)\]" 377 | matches = re.findall(pattern, s) 378 | return matches[0] 379 | 380 | def quote_env(idx): 381 | return str(env.step(idx, "reset")[0]) 382 | 383 | 384 | def dfs_search(args, task, idx, iterations, knnret, depth_limit=30, to_print=True): 385 | global gpt 386 | global failed_trajectories 387 | global success_trajectories 388 | gpt = partial(gpt, model_size=args.model_size, temperature=args.temperature) 389 | action="reset" 390 | x = env.step(idx, action)[0] 391 | if to_print: 392 | print(idx, x) 393 | root = Node(state=None, question=x) 394 | root.env_state = copy.deepcopy(env.sessions) 395 | all_nodes = [] 396 | failed_trajectories = [] 397 | success_trajectories = [] 398 | stack = [root] 399 | it = 0 400 | knn = [] 401 | if knnret: 402 | for traj in knnret: 403 | format_traj = node_trajectory_to_text(traj["trajectory"]) 404 | # format_traj+=f"Action {traj_depth(traj['trajectory'])}: Finish[{get_substrings_between_brackets(traj['final_answer'])}]"+"\n" 405 | knn.append(format_traj) 406 | print("--------------knn is:") 407 | print(knn) 408 | last_node = None 409 | maxx=0.0 410 | while stack and it < iterations: 411 | node = stack.pop() 412 | last_node = node 413 | logging.info(f"DFS at node depth {node.depth}...") 414 | 415 | if node.is_terminal and node.reward == 1: 416 | logging.info(f"Terminal node with reward 1 found at depth {node.depth}") 417 | return ( 418 | node.state, 419 | node.value, 420 | all_nodes, 421 | node.reward, 422 | node.em, 423 | failed_trajectories, 424 | success_trajectories, 425 | ) 426 | 427 | if 
node.is_terminal and node.reward == 0: 428 | logging.info(f"Terminal node with reward 0 found at depth {node.depth}") 429 | return ( 430 | node.state, 431 | node.value, 432 | all_nodes, 433 | maxx, 434 | node.em, 435 | failed_trajectories, 436 | success_trajectories, 437 | ) 438 | 439 | maxx=max(node.reward,maxx) 440 | 441 | if node.depth >= depth_limit: 442 | logging.info("Depth limit reached") 443 | it += 1 444 | continue # go to next iteration 445 | 446 | expand_node(node, args, task,idx, knn=knn) 447 | stack.extend(reversed(node.children)) # adding all child nodes to stack for DFS 448 | 449 | all_nodes = [(node, node.value) for node in collect_all_nodes(root)] 450 | logging.info(f"State of all_nodes after iteration: {all_nodes}") 451 | it += 1 452 | # If we reach here, no solution was found 453 | logging.info("All paths explored. No solution found.") 454 | if len(failed_trajectories) == 0: 455 | trajectory = collect_trajectory(last_node) 456 | failed_trajectories.append({"trajectory": trajectory, "final_answer": ""}) 457 | return root, 0, all_nodes, maxx, 0, failed_trajectories, success_trajectories 458 | 459 | 460 | def select_node_dfs(stack): 461 | return stack[-1] if stack else None # return the last node in the stack 462 | 463 | 464 | 465 | 466 | def select_node(node): 467 | while node and node.children: 468 | logging.info( 469 | f"Selecting from {len(node.children)} children at depth {node.depth}." 470 | ) 471 | 472 | terminal_children = [child for child in node.children if child.is_terminal] 473 | terminal_status = [child.is_terminal for child in node.children] 474 | 475 | if len(terminal_children) == len(node.children): 476 | logging.info( 477 | f"All children are terminal at depth {node.depth}. Backtracking..." 478 | ) 479 | if node.parent: 480 | node.parent.children.remove(node) 481 | node = node.parent 482 | continue 483 | 484 | node_with_reward_1 = next( 485 | (child for child in terminal_children if child.reward == 1), None 486 | ) 487 | if node_with_reward_1: 488 | logging.info(f"Found terminal node with reward 1 at depth {node.depth}.") 489 | return node_with_reward_1 490 | 491 | node = max( 492 | (child for child in node.children if not child.is_terminal), 493 | key=lambda child: child.uct(), 494 | default=None, 495 | ) 496 | 497 | while node.is_terminal and node.reward != 1: 498 | node = max( 499 | (child for child in node.parent.children if not child.is_terminal), 500 | key=lambda child: child.uct(), 501 | default=None, 502 | ) 503 | 504 | logging.info(f"Selected node at depth {node.depth} with UCT {node.uct()}.") 505 | 506 | return node # This will return None if all paths from the root are exhausted 507 | 508 | 509 | def expand_node(node, args, task, idx,knn=None): 510 | if node.depth >= 30: 511 | logging.info("Depth limit reached") 512 | print("Depth limit reached") 513 | node.is_terminal = True 514 | return 515 | new_nodes = generate_new_states(node, args, task, idx,knn=knn) 516 | node.children.extend(new_nodes) 517 | 518 | 519 | def generate_new_states(node, args, task, idx, knn=None): 520 | prompt = generate_prompt(node) 521 | #print("the prompt is:") 522 | #print(prompt) 523 | sampled_actions = get_samples( 524 | task, 525 | prompt, 526 | f"Action: ", 527 | args.n_generate_sample, 528 | prompt_sample=args.prompt_sample, 529 | knn=knn, 530 | ) 531 | logging.info(f"SAMPLED ACTION: {sampled_actions}") 532 | 533 | unique_states = {} # Store unique states here 534 | for action in sampled_actions: 535 | local_sessions = copy.deepcopy(node.env_state) 536 | env.sessions 
= local_sessions 537 | new_state = node.state.copy() # Make a copy of the parent node's state 538 | 539 | 540 | action_line = next( 541 | ( 542 | line.split(":")[1].strip() 543 | for line in action.split("\n") 544 | if line.startswith("Action") and ":" in line 545 | ), 546 | None, 547 | ) 548 | 549 | # Use thought and action to form a unique key 550 | unique_key = f"{action_line}" 551 | 552 | if unique_key in unique_states: 553 | continue # Skip if this state already exists 554 | 555 | if action_line: 556 | #print("the action line is:") 557 | #print(action_line) 558 | try: 559 | res = env.step(idx, action_line) 560 | #print("res", res) 561 | obs = res[0] 562 | r = res[1] 563 | done = res[2] 564 | except AssertionError: 565 | obs = 'Invalid action!' 566 | # print("err") 567 | r = -1 568 | done = False 569 | 570 | if action.startswith('think'): 571 | observation = 'OK.' 572 | 573 | # Update the new state dictionary 574 | new_state["action"] = action_line 575 | new_state["observation"] = obs 576 | 577 | env_state_clone = env.clone_state() 578 | new_node = Node(state=new_state, question=node.question, env_state=env_state_clone,parent=node) 579 | new_node.env_state = local_sessions 580 | new_node.is_terminal = r == 1 or done 581 | new_node.reward = r 582 | unique_states[unique_key] = new_node # Add this state to unique_states 583 | logging.info(f"NEW NODE: {new_node}") 584 | 585 | if new_node.is_terminal and r == 0: 586 | trajectory = collect_trajectory(new_node) 587 | failed_trajectories.append( 588 | { 589 | "trajectory": trajectory, 590 | "final_answer": f"{action_line}", 591 | } 592 | ) 593 | if new_node.is_terminal and r == 1: 594 | trajectory = collect_trajectory(new_node) 595 | success_trajectories.append( 596 | { 597 | "trajectory": trajectory, 598 | "final_answer": f"{action_line}", 599 | } 600 | ) 601 | 602 | return list(unique_states.values()) # Return unique nodes as a list 603 | 604 | 605 | 606 | 607 | def print_tree(node, level=0): 608 | indent = " " * level 609 | print(f"{indent}{node}") 610 | for child in node.children: 611 | print_tree(child, level + 1) 612 | 613 | 614 | 615 | 616 | def generate_prompt(node): 617 | trajectory = [] 618 | question = node.question 619 | while node: 620 | new_segment = [] 621 | if node.state["action"]: 622 | new_segment.append(f"Action: {node.state['action']}") 623 | if ( 624 | node.state["observation"] and node.depth != 0 625 | ): # Exclude the observation from the root node 626 | new_segment.append(f"Observation: {node.state['observation']}") 627 | trajectory.append("\n".join(new_segment)) 628 | node = node.parent 629 | return question + "\n".join(reversed(trajectory)) 630 | -------------------------------------------------------------------------------- /webshop/models.py: -------------------------------------------------------------------------------- 1 | import openai 2 | from utils import ModelServer 3 | import time 4 | import re 5 | 6 | server = None 7 | 8 | 9 | def build_server(config_path): 10 | global server 11 | server = ModelServer(config_path=config_path) 12 | 13 | 14 | def split_and_keep_prefixes(s, delimiters): 15 | regex_pattern = f"({'|'.join(map(re.escape, delimiters))})" 16 | parts = re.split(regex_pattern, s) 17 | result = [parts[0]] 18 | for i in range(1, len(parts), 2): 19 | result.append(parts[i] + (parts[i + 1] if i + 1 < len(parts) else "")) 20 | return result 21 | 22 | 23 | def online_embed(traj): 24 | return server.get_completion_or_embedding( 25 | "7", 26 | message=traj, 27 | get_embedding=True, 28 | ) 29 | 30 | ret_in=0 
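# ret_in / ret_out are module-level running totals of prompt and completion
# tokens accumulated across every gpt() call. gpt() itself rebuilds a chat
# message list from the flat ReAct-style prompt by splitting on the
# Instruction:/Action:/Observation: markers (Action turns become assistant
# messages, Instruction and Observation turns become user messages).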
31 | ret_out=0 32 | def gpt(prompt, model_size="8", temperature=1.0, max_tokens=100, n=1) -> list: 33 | def call_openai_api(messages, model_size, temperature, max_tokens, n): 34 | global ret_in 35 | global ret_out 36 | outputs = [] 37 | while n > 0: 38 | cnt = min(n, 20) 39 | n -= cnt 40 | res = server.get_completion_or_embedding( 41 | model_size, messages, temperature, max_tokens 42 | ) 43 | ret_out+=res.usage.completion_tokens 44 | ret_in+=res.usage.prompt_tokens 45 | outputs.extend( 46 | [ 47 | re.sub(r"^Action:", "", choice.message.content) 48 | for choice in res.choices 49 | ] 50 | ) 51 | return outputs,res.usage.prompt_tokens,res.usage.completion_tokens 52 | 53 | messages = [] 54 | parts = re.split(r"(Action:|Observation:|Instruction:)", prompt) 55 | 56 | result = [parts[0].strip()] 57 | for i in range(1, len(parts), 2): 58 | if i + 1 < len(parts): 59 | result.append(parts[i] + " " + parts[i + 1].strip()) 60 | result.pop() 61 | last_obs="" 62 | for msg in result: 63 | if msg.startswith("Action"): 64 | messages.append({"role": "assistant", "content": msg}) 65 | if msg.startswith("Observation"): 66 | messages.append({"role": "user", "content": msg}) 67 | last_obs=msg 68 | if msg.startswith("Instruction"): 69 | messages.append({"role": "user", "content": msg}) 70 | return call_openai_api( 71 | messages, model_size=model_size, temperature=temperature, max_tokens=max_tokens, n=n 72 | ) 73 | -------------------------------------------------------------------------------- /webshop/prompt.py: -------------------------------------------------------------------------------- 1 | prompt1 = """Instruction: You are also an advanced reasoning agent that can improve based on previous experience. Follow the instruct and purchase an item meeting all of the correct specifications by navigating the website. 2 | i would like a 3 ounce bottle of bright citrus deodorant for sensitive skin, and price lower than 50.00 dollars 3 | [Search] 4 | 5 | Action: search[3 ounce bright citrus deodorant sensitive skin] 6 | Observation: 7 | [Back to Search] 8 | Page 1 (Total results: 50) 9 | [Next >] 10 | [B078GWRC1J] 11 | Bright Citrus Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce 12 | $10.99 13 | [B078GTKVXY] 14 | Ginger Fresh Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce 15 | $10.99 16 | [B08KBVJ4XN] 17 | Barrel and Oak - Aluminum-Free Deodorant, Deodorant for Men, Essential Oil-Based Scent, 24-Hour Odor Protection, Cedar & Patchouli Blend, Gentle on Sensitive Skin (Mountain Sage, 2.7 oz, 2-Pack) 18 | $15.95 19 | 20 | Action: think[B078GWRC1J and B078GTKVXY are bright citrus deodorant less then 50 dollars. I can check B078GWRC1J first.] 21 | Observation: OK. 22 | 23 | Action: click[B078GWRC1J] 24 | Observation: 25 | [Back to Search] 26 | [< Prev] 27 | scent [assorted scents][bright citrus][calming lavender][ginger fresh][simply non-scents] 28 | size [travel set (4-pack)][3 ounce (pack of 1)][3-ounce (2-pack)] 29 | Bright Citrus Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce 30 | Price: $10.99 31 | Rating: N.A. 32 | [Description] 33 | [Features] 34 | [Reviews] 35 | [Buy Now] 36 | 37 | Action: think[For 3 ounce bottle of bright citrus deodorant for sensitive skin, the item has options 'bright citrus' and '3 ounce (pack of 1)' and seems good to buy.] 38 | Observation: OK. 
39 | 40 | Action: click[bright citrus] 41 | Observation: You have clicked bright citrus. 42 | 43 | Action: click[3 ounce (pack of 1)] 44 | Observation: You have clicked 3 ounce (pack of 1). 45 | 46 | Action: click[Buy Now] 47 | 48 | """ 49 | 50 | 51 | prompt1_feedback = """You are also an advanced reasoning agent that can improve based on self refection. Follow the instruction and purchase an item meeting all of the correct specifications by navigating the website. 52 | Here is an example: 53 | Webshop 54 | Instruction: 55 | i would like a 3 ounce bottle of bright citrus deodorant for sensitive skin, and price lower than 50.00 dollars 56 | [Search] 57 | 58 | Action: search[3 ounce bright citrus deodorant sensitive skin] 59 | Observation: 60 | [Back to Search] 61 | Page 1 (Total results: 50) 62 | [Next >] 63 | [B078GWRC1J] 64 | Bright Citrus Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce 65 | $10.99 66 | [B078GTKVXY] 67 | Ginger Fresh Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce 68 | $10.99 69 | [B08KBVJ4XN] 70 | Barrel and Oak - Aluminum-Free Deodorant, Deodorant for Men, Essential Oil-Based Scent, 24-Hour Odor Protection, Cedar & Patchouli Blend, Gentle on Sensitive Skin (Mountain Sage, 2.7 oz, 2-Pack) 71 | $15.95 72 | 73 | Action: think[B078GWRC1J and B078GTKVXY are bright citrus deodorant less then 50 dollars. I can check B078GWRC1J first.] 74 | Observation: OK. 75 | 76 | Action: click[B078GWRC1J] 77 | Observation: 78 | [Back to Search] 79 | [< Prev] 80 | scent [assorted scents][bright citrus][calming lavender][ginger fresh][simply non-scents] 81 | size [travel set (4-pack)][3 ounce (pack of 1)][3-ounce (2-pack)] 82 | Bright Citrus Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce 83 | Price: $10.99 84 | Rating: N.A. 85 | [Description] 86 | [Features] 87 | [Reviews] 88 | [Buy Now] 89 | 90 | Action: think[For 3 ounce bottle of bright citrus deodorant for sensitive skin, the item has options 'bright citrus' and '3 ounce (pack of 1)' and seems good to buy.] 91 | Observation: OK. 92 | 93 | Action: click[bright citrus] 94 | Observation: You have clicked bright citrus. 95 | 96 | Action: click[3 ounce (pack of 1)] 97 | Observation: You have clicked 3 ounce (pack of 1). 98 | 99 | Action: click[Buy Now] 100 | 101 | Here are some previous trials in which you were given access to relevant context and a item to purchase. You were unsuccessful in buying the correct item either because you did not find an item meeting all of the required specifications or because you did not select the correct item. Carefully consider the errors and make sure not to repeat the same mistakes. {trajectories} 102 | 103 | {input} 104 | """ 105 | 106 | # In[4]: 107 | score_prompt = '''Given an item to purchase and a trajectory that aims to buy an item that exactly matches the specification, which corresponds to the ideal score of 1.0, analyze the following trajectory, then at the last line conclude "Thus the correctness score is {s}", where s is an integer from 1 to 10. 
108 | 109 | Here are some examples 110 | 111 | Webshop 112 | Instruction: 113 | i am looking for dairy free and apple variety pack of chips, and price lower than 30.00 dollars 114 | [Search] 115 | 116 | Action: search[dairy free and apple variety pack of chips] 117 | Observation: 118 | [Back to Search] 119 | Page 1 (Total results: 50) 120 | [Next >] 121 | [B07HRFJWP8] 122 | Enjoy Life Foods Soft Baked Ovals, Breakfast Bars, Nut Free Bars, Soy Free, Dairy Free, Non GMO, Gluten Free, Vegan, Variety Pack, 4 Boxes (20 Bars Total) 123 | $100.0 124 | [B01KMHY5PG] 125 | Enjoy Life Soft Baked Chewy Bars, Variety Pack, Nut Free Bars, Soy Free, Dairy Free, Gluten Free, 6 Boxes (30 Total Bars) 126 | $21.49 127 | [B008D2X8C4] 128 | Enjoy Life Lentil Chips Variety Pack, Dairy Free Chips, Soy Free, Nut Free, Non GMO, Vegan, Gluten Free, 24 Bags (0.8 oz) 129 | $100.0 130 | 131 | Action: think[The first two items are not what I want. The third item is the dairy free and apple variety pack of chips I want.] 132 | Observation: OK. 133 | 134 | Action: click[B008D2X8C4] 135 | Observation: 136 | [Back to Search] 137 | [< Prev] 138 | flavor name [dill and sour cream][garlic & parmesan][light sea salt][margherita pizza][thai chili lime][variety pack] 139 | size [0.8 ounce (pack of 24)][4 ounce (pack of 12)] 140 | Price: $100.0 141 | Rating: N.A. 142 | [Description] 143 | [Features] 144 | [Reviews] 145 | [Buy Now] 146 | 147 | Action: think[The item has options 'variety pack' and '0.8 ounce (pack of 24)' and seems good to buy.] 148 | Observation: OK. 149 | 150 | Action: click[variety pack] 151 | Observation: You have clicked variety pack. 152 | 153 | Action: click[0.8 ounce (pack of 24)] 154 | Observation: You have clicked 0.8 ounce (pack of 24). 155 | 156 | Action: click[Buy Now] 157 | 158 | Reflection: In this attempt, I was unsuccessful. I accidentally bought a product that was $100, which is more than my budget of $30. Either way, the initial search results were not good. Next time, I will do search["variety pack of chips"] and then check if the results meet the dairy free and the $30 budget constraints. I will continue to refine my searches so that I can find more products. 159 | Thus the correctness score is 5 160 | 161 | Webshop 162 | Instruction: 163 | i would like a 3 ounce bottle of bright citrus deodorant for sensitive skin, and price lower than 50.00 dollars 164 | [Search] 165 | 166 | Action: search[3 ounce bright citrus deodorant sensitive skin] 167 | Observation: 168 | [Back to Search] 169 | Page 1 (Total results: 50) 170 | [Next >] 171 | [B078GWRC1J] 172 | Bright Citrus Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce 173 | $10.99 174 | [B078GTKVXY] 175 | Ginger Fresh Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce 176 | $10.99 177 | [B08KBVJ4XN] 178 | Barrel and Oak - Aluminum-Free Deodorant, Deodorant for Men, Essential Oil-Based Scent, 24-Hour Odor Protection, Cedar & Patchouli Blend, Gentle on Sensitive Skin (Mountain Sage, 2.7 oz, 2-Pack) 179 | $15.95 180 | 181 | Action: think[B078GWRC1J and B078GTKVXY are bright citrus deodorant less then 50 dollars. I can check B078GWRC1J first.] 182 | Observation: OK. 183 | Reflection: Although the task is not yet complete, the first search is correct and returns relevant results. The thought is also logical and has potential to lead to the correct selection. 
184 | Thus the correctness score is 9 185 | {input} 186 | ''' 187 | 188 | score_prompt_feedback = '''Given an item to purchase and a trajectory, analyze the following trajectory, then at the last line conclude "Thus the correctness score is {s}", where s is an integer from 1 to 10. 189 | Here are some examples 190 | 191 | Webshop 192 | Instruction: 193 | i am looking for dairy free and apple variety pack of chips, and price lower than 30.00 dollars 194 | [Search] 195 | 196 | Action: search[dairy free and apple variety pack of chips] 197 | Observation: 198 | [Back to Search] 199 | Page 1 (Total results: 50) 200 | [Next >] 201 | [B07HRFJWP8] 202 | Enjoy Life Foods Soft Baked Ovals, Breakfast Bars, Nut Free Bars, Soy Free, Dairy Free, Non GMO, Gluten Free, Vegan, Variety Pack, 4 Boxes (20 Bars Total) 203 | $100.0 204 | [B01KMHY5PG] 205 | Enjoy Life Soft Baked Chewy Bars, Variety Pack, Nut Free Bars, Soy Free, Dairy Free, Gluten Free, 6 Boxes (30 Total Bars) 206 | $21.49 207 | [B008D2X8C4] 208 | Enjoy Life Lentil Chips Variety Pack, Dairy Free Chips, Soy Free, Nut Free, Non GMO, Vegan, Gluten Free, 24 Bags (0.8 oz) 209 | $100.0 210 | 211 | Action: think[The first two items are not what I want. The third item is the dairy free and apple variety pack of chips I want.] 212 | Observation: OK. 213 | 214 | Action: click[B008D2X8C4] 215 | Observation: 216 | [Back to Search] 217 | [< Prev] 218 | flavor name [dill and sour cream][garlic & parmesan][light sea salt][margherita pizza][thai chili lime][variety pack] 219 | size [0.8 ounce (pack of 24)][4 ounce (pack of 12)] 220 | Price: $100.0 221 | Rating: N.A. 222 | [Description] 223 | [Features] 224 | [Reviews] 225 | [Buy Now] 226 | 227 | Action: think[The item has options 'variety pack' and '0.8 ounce (pack of 24)' and seems good to buy.] 228 | Observation: OK. 229 | 230 | Action: click[variety pack] 231 | Observation: You have clicked variety pack. 232 | 233 | Action: click[0.8 ounce (pack of 24)] 234 | Observation: You have clicked 0.8 ounce (pack of 24). 235 | 236 | Action: click[Buy Now] 237 | 238 | Reflection: In this attempt, I was unsuccessful. I accidentally bought a product that was $100, which is more than my budget of $30. Either way, the initial search results were not good. Next time, I will do search["variety pack of chips"] and then check if the results meet the dairy free and the $30 budget constraints. I will continue to refine my searches so that I can find more products. 239 | Thus the correctness score is 5 240 | 241 | Webshop 242 | Instruction: 243 | i would like a 3 ounce bottle of bright citrus deodorant for sensitive skin, and price lower than 50.00 dollars 244 | [Search] 245 | 246 | Action: search[3 ounce bright citrus deodorant sensitive skin] 247 | Observation: 248 | [Back to Search] 249 | Page 1 (Total results: 50) 250 | [Next >] 251 | [B078GWRC1J] 252 | Bright Citrus Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce 253 | $10.99 254 | [B078GTKVXY] 255 | Ginger Fresh Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce 256 | $10.99 257 | [B08KBVJ4XN] 258 | Barrel and Oak - Aluminum-Free Deodorant, Deodorant for Men, Essential Oil-Based Scent, 24-Hour Odor Protection, Cedar & Patchouli Blend, Gentle on Sensitive Skin (Mountain Sage, 2.7 oz, 2-Pack) 259 | $15.95 260 | 261 | Action: think[B078GWRC1J and B078GTKVXY are bright citrus deodorant less then 50 dollars. 
I can check B078GWRC1J first.] 262 | Observation: OK. 263 | Reflection: Although the task is not yet complete, the first search is correct and returns relevant results. The thought is also logical and has potential to lead to the correct selection. 264 | Thus the correctness score is 9 265 | {trajectories} 266 | {input} 267 | ''' 268 | 269 | # trivial search & item, choose option 270 | prompt1_actonly = """Webshop 271 | Instruction: 272 | i would like a 3 ounce bottle of bright citrus deodorant for sensitive skin, and price lower than 50.00 dollars 273 | [Search] 274 | 275 | Action: search[3 ounce bright citrus deodorant sensitive skin] 276 | Observation: 277 | [Back to Search] 278 | Page 1 (Total results: 50) 279 | [Next >] 280 | [B078GWRC1J] 281 | Bright Citrus Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce 282 | $10.99 283 | [B078GTKVXY] 284 | Ginger Fresh Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce 285 | $10.99 286 | [B08KBVJ4XN] 287 | Barrel and Oak - Aluminum-Free Deodorant, Deodorant for Men, Essential Oil-Based Scent, 24-Hour Odor Protection, Cedar & Patchouli Blend, Gentle on Sensitive Skin (Mountain Sage, 2.7 oz, 2-Pack) 288 | $15.95 289 | 290 | Action: click[B078GWRC1J] 291 | Observation: 292 | [Back to Search] 293 | [< Prev] 294 | scent [assorted scents][bright citrus][calming lavender][ginger fresh][simply non-scents] 295 | size [travel set (4-pack)][3 ounce (pack of 1)][3-ounce (2-pack)] 296 | Bright Citrus Deodorant by Earth Mama | Natural and Safe for Sensitive Skin, Pregnancy and Breastfeeding, Contains Organic Calendula 3-Ounce 297 | Price: $10.99 298 | Rating: N.A. 299 | [Description] 300 | [Features] 301 | [Reviews] 302 | [Buy Now] 303 | 304 | Action: click[bright citrus] 305 | Observation: You have clicked bright citrus. 306 | 307 | Action: click[3 ounce (pack of 1)] 308 | Observation: You have clicked 3 ounce (pack of 1). 309 | 310 | Action: click[Buy Now] 311 | """ 312 | 313 | reflection_prompt = '''You are an advanced reasoning agent that can improve based on self refection. You will be given a previous reasoning trial in which you were given access to an shopping website and a specific type of item to buy. You were given access to relevant context and a item to purchase. You were unsuccessful in buying the correct item either because you did not find an item meeting all of the required specifications or because you did not select the correct item. The ideal score is 1.0, and anything less is incorrect. In a few sentences, Diagnose a possible reason for failure and devise a new, concise, high level plan that aims to mitigate the same failure. Use complete sentences. 
314 | Here are some examples: 315 | 316 | Previous Trial 317 | Instruction: 318 | i am looking for dairy free and apple variety pack of chips, and price lower than 30.00 dollars 319 | [Search] 320 | 321 | Action: search[dairy free and apple variety pack of chips] 322 | Observation: 323 | [Back to Search] 324 | Page 1 (Total results: 50) 325 | [Next >] 326 | [B07HRFJWP8] 327 | Enjoy Life Foods Soft Baked Ovals, Breakfast Bars, Nut Free Bars, Soy Free, Dairy Free, Non GMO, Gluten Free, Vegan, Variety Pack, 4 Boxes (20 Bars Total) 328 | $100.0 329 | [B01KMHY5PG] 330 | Enjoy Life Soft Baked Chewy Bars, Variety Pack, Nut Free Bars, Soy Free, Dairy Free, Gluten Free, 6 Boxes (30 Total Bars) 331 | $21.49 332 | [B008D2X8C4] 333 | Enjoy Life Lentil Chips Variety Pack, Dairy Free Chips, Soy Free, Nut Free, Non GMO, Vegan, Gluten Free, 24 Bags (0.8 oz) 334 | $100.0 335 | 336 | Action: think[The first two items are not what I want. The third item is the dairy free and apple variety pack of chips I want.] 337 | Observation: OK. 338 | 339 | Action: click[B008D2X8C4] 340 | Observation: 341 | [Back to Search] 342 | [< Prev] 343 | flavor name [dill and sour cream][garlic & parmesan][light sea salt][margherita pizza][thai chili lime][variety pack] 344 | size [0.8 ounce (pack of 24)][4 ounce (pack of 12)] 345 | Price: $100.0 346 | Rating: N.A. 347 | [Description] 348 | [Features] 349 | [Reviews] 350 | [Buy Now] 351 | 352 | Action: think[The item has options 'variety pack' and '0.8 ounce (pack of 24)' and seems good to buy.] 353 | Observation: OK. 354 | 355 | Action: click[variety pack] 356 | Observation: You have clicked variety pack. 357 | 358 | Action: click[0.8 ounce (pack of 24)] 359 | Observation: You have clicked 0.8 ounce (pack of 24). 360 | 361 | Action: click[Buy Now] 362 | 363 | STATUS: FAIL 364 | 365 | Reflection: In this attempt, I was unsuccessful. I accidentally bought a product that was $100, which is more than my budget of $30. Either way, the initial search results were not good. Next time, I will do search["variety pack of chips"] and then check if the results meet the dairy free and the $30 budget constraints. I will continue to refine my searches so that I can find more products. 366 | 367 | Previous Trial 368 | Instruction: 369 | i need gluten free vegetarian smoked peppered bacon - 4 ounce (pack of 2), and price lower than 40.00 dollars 370 | [Search] 371 | 372 | Action: search[gluten free vegetarian smoked peppered bacon 4 ounce pack of 2] 373 | Observation: 374 | [Back to Search] 375 | Page 1 (Total results: 50) 376 | [Next >] 377 | [B06Y96MXJV] 378 | Smoked Bacon Sea Salt 3-Pack: Smoked Bacon Chipotle, Smoked Bacon and Onion, Smoked Peppered Bacon - All-Natural Sea Salts Slowly Smoked Over Alderwood - No Gluten, No MSG, Non-GMO (12 total oz.) 379 | $29.99 380 | [B06Y96N1KG] 381 | Spicy Hot Pepper Sea Salt 3-Pack: Ghost Pepper, Jalapeno, Habanero - All Natural, Delicious Blends of Peppers & Sea Salt - Gluten-Free, Kosher, No MSG, Non-GMO (12 total oz.) 382 | $29.99 383 | [B07GJTKYJQ] 384 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces) 385 | $42.99 386 | 387 | Action: think[The first two items are not what I want. The third item is not what I want, but it is the closest one.] 388 | Observation: OK. 
389 | 390 | Action: click[B07GJTKYJQ] 391 | Observation: 392 | [Back to Search] 393 | [< Prev] 394 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces) 395 | Price: $42.99 396 | Rating: N.A. 397 | [Description] 398 | [Features] 399 | [Reviews] 400 | [Buy Now] 401 | 402 | Action: think[The price is too high and it is not what I want.] 403 | Observation: OK. 404 | 405 | Action: click[< Prev] 406 | Observation: 407 | [Back to Search] 408 | Page 1 (Total results: 50) 409 | [Next >] 410 | [B06Y96MXJV] 411 | Smoked Bacon Sea Salt 3-Pack: Smoked Bacon Chipotle, Smoked Bacon and Onion, Smoked Peppered Bacon - All-Natural Sea Salts Slowly Smoked Over Alderwood - No Gluten, No MSG, Non-GMO (12 total oz.) 412 | $29.99 413 | [B06Y96N1KG] 414 | Spicy Hot Pepper Sea Salt 3-Pack: Ghost Pepper, Jalapeno, Habanero - All Natural, Delicious Blends of Peppers & Sea Salt - Gluten-Free, Kosher, No MSG, Non-GMO (12 total oz.) 415 | $29.99 416 | [B07GJTKYJQ] 417 | Louisville Vegan Jerky - 5 Flavor Variety Pack, Non-GMO Soy Protein, Gluten-Free (Black Pepper, Buffalo Dill, Pepperoni, Maple Bacon, and Carolina BBQ, 3 Ounces) 418 | $42.99 419 | 420 | Action: think[The first two items are not what I want.] 421 | Observation: OK. 422 | 423 | Action: click[Next >] 424 | Observation: Invalid action! 425 | 426 | Action: think[I need to refine my search.] 427 | Observation: OK. 428 | 429 | Action: click[< Back to Search] 430 | Observation: Invalid action! 431 | 432 | Action: click[Back to Search] 433 | 434 | STATUS: FAIL 435 | 436 | Reflection: In this attempt, I was unsuccessful. The initial search results were not the products that I was looking for. Next time, I will do search["vegetarian bacon"] for a broader search. I will check that the new results will fulfill the gluten free and 4 ounce pack of 2 constraints. I will continue to refine my searches so that I can find more products. 
437 | 438 | Previous trial: 439 | {trajectory} 440 | Reflection:''' -------------------------------------------------------------------------------- /webshop/requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.8.4 2 | aiosignal==1.3.1 3 | async-timeout==4.0.2 4 | attrs==23.1.0 5 | backoff==2.2.1 6 | certifi==2023.5.7 7 | charset-normalizer==3.1.0 8 | frozenlist==1.3.3 9 | idna==3.4 10 | mpmath==1.3.0 11 | multidict==6.0.4 12 | numpy==1.24.3 13 | openai==0.27.7 14 | requests==2.31.0 15 | sympy==1.12 16 | tqdm==4.65.0 17 | urllib3==2.0.2 18 | yarl==1.9.2 19 | pandas==2.0.3 -------------------------------------------------------------------------------- /webshop/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import numpy as np 5 | import math 6 | from models import online_embed, build_server 7 | import faiss 8 | import logging 9 | import random 10 | 11 | import lats 12 | from lats import dfs_search,quote_env 13 | from webshop import WebShopTask 14 | 15 | 16 | def random_selection(lst, n=5): 17 | if len(lst) <= n: 18 | return lst 19 | else: 20 | return random.sample(lst, n) 21 | 22 | 23 | def run(args): 24 | task = WebShopTask() 25 | print(task) 26 | logs, cnt_avg, cnt_any = [], 0, 0 27 | 28 | # create log directories if they don't exist 29 | config_path = os.path.join(args.run_name, "config.json") 30 | 31 | with open(config_path, "w", encoding="utf-8") as wf: 32 | info_dict = vars(args) 33 | info_dict["is_running"] = True 34 | json.dump(info_dict, wf, indent=4) 35 | 36 | build_server(config_path=config_path) 37 | 38 | os.makedirs(os.path.dirname(args.log_dir), exist_ok=True) 39 | 40 | logging.basicConfig( 41 | filename=args.log_dir, 42 | level=logging.INFO, 43 | format="%(asctime)s - %(levelname)s - %(message)s", 44 | filemode="a", 45 | ) 46 | 47 | wins = {} 48 | lose = {} 49 | trajectories = [] 50 | embedding_array = np.zeros((0, 3584)) 51 | tongji=[] 52 | avg_tongji=[] 53 | cnt=0 54 | scores=[] 55 | for i in range(50): 56 | scores.append(0.0) 57 | for trial in range(10): 58 | with open("debug_log.log", "a", encoding="utf-8") as f: 59 | f.write("Trial") 60 | f.write(str(trial)+"\n") 61 | count = 0 62 | task_accs = [] 63 | info = [] 64 | emb_db = faiss.IndexFlatL2(3584) 65 | emb_db.add(embedding_array.astype("float32")) 66 | avg=0.0 67 | for i in range(args.task_start_index, args.task_end_index): 68 | with open("debug_log1.log", "a", encoding="utf-8") as f: 69 | f.write("------------new task---------------\n") 70 | f.write(str(i)+"\n") 71 | # solve 72 | if i in wins: 73 | continue 74 | prev = None 75 | knnret = [] 76 | if trajectories and args.cot_size>0: 77 | if i in lose: 78 | prev = lose[i] 79 | else: 80 | prev=quote_env(f'fixed_{i}') 81 | if args.cot_method=="knn": 82 | fail_vec = online_embed(str(prev)) 83 | dist=[] 84 | for index, row in np.ndenumerate(embedding_array): 85 | if index[1] == 0: 86 | realN = (1.0 / np.linalg.norm(np.array(fail_vec) - embedding_array[index[0]])) 87 | dist.append((realN,index[0])) 88 | dist.sort(key=lambda x: x[0], reverse=True) 89 | sz_now=min(len(trajectories),args.cot_size) 90 | dist=dist[:sz_now] 91 | dist=[(math.exp(5.0*x),y) for (x,y) in dist] 92 | original_sum=sum([x for (x,_) in dist]) 93 | dist=[(x *float(sz_now) / original_sum,y) for (x,y) in dist] 94 | tot=0 95 | cntD=[] 96 | realD=[] 97 | for (x,y) in dist: 98 | cntD.append(math.floor(x)) 99 | realD.append(x) 100 | tot+=math.floor(x) 101 | 
while tot<sz_now:
102 |                         maxx=-1.0
103 |                         maxi=0
104 |                         for ig in range(sz_now):
105 |                             if realD[ig]-float(cntD[ig])>maxx:
106 |                                 maxx=realD[ig]-float(cntD[ig])
107 |                                 maxi=ig
108 |                         cntD[maxi]+=1
109 |                         tot+=1
110 |                     weights=[]
111 |                     for (x,y) in dist:
112 |                         weights.append(x)
113 |                     normalized_weights = [w / sum(weights) for w in weights]
114 |                     for ig in range(sz_now):
115 |                         s_ind = np.random.choice(len(weights), p=normalized_weights)
116 |                         (_,y)=dist[s_ind]
117 |                         knnret.append(trajectories[y])
118 |                 else:
119 |                     knnret=random.sample(trajectories, min(args.cot_size,len(trajectories)))
120 |             state, value, all_nodes, reward, em, failt, succt = dfs_search(
121 |                 args, task, f'fixed_{i}', args.iteration, knnret
122 |             )
123 |             if failt:
124 |                 with open("debug_log.log", "a", encoding="utf-8") as f:
125 |                     f.write("FAIL\n")
126 |                 print("Fail")
127 |                 print(i)
128 |                 lose[i] = failt[0]
129 |             if succt:
130 |                 cnt=cnt+1
131 |                 with open("debug_log.log", "a", encoding="utf-8") as f:
132 |                     f.write("SUCCESS\n")
133 |                 print("Success")
134 |                 print(i)
135 |                 wins[i] = 1
136 |                 vec = online_embed(str(succt[0]))
137 |                 trajectories.append(succt[0])
138 |                 embedding_array = np.vstack((embedding_array, np.array(vec)))
139 |             scores[i]=max(scores[i],reward)
140 |             avg+=reward
141 |             # log main metric
142 |             if em is None:
143 |                 em = 0
144 |             task_accs.append(em)
145 |             cnt_avg = sum(task_accs) / len(task_accs)
146 |             print(i, "len(task_accs)", len(task_accs), "cnt_avg", cnt_avg, "\n")
147 |         tongji.append(cnt)
148 |         avg=sum(scores)/50.0
149 |         with open("debug_log.log", "a", encoding="utf-8") as f:
150 |             f.write("average: ")
151 |             f.write(str(avg)+"\n")
152 |         avg_tongji.append(avg)
153 |     print("prompt:")
154 |     print(lats.prompt_tokens)
155 |     print("completion:")
156 |     print(lats.completion_tokens)
157 |     print(tongji)
158 |     print(avg_tongji)
159 |     n = args.task_end_index - args.task_start_index
160 | 
161 | 
162 | def parse_args():
163 |     args = argparse.ArgumentParser()
164 |     args.add_argument("--model_size", type=str, default="8")
165 |     args.add_argument("--temperature", type=float, default=1.0)
166 |     args.add_argument("--task_start_index", type=int, default=900)
167 |     args.add_argument("--task_end_index", type=int, default=1000)
168 |     args.add_argument("--prompt_sample", type=str, choices=["standard", "cot"])
169 |     args.add_argument("--n_generate_sample", type=int, default=1)
170 |     args.add_argument("--n_evaluate_sample", type=int, default=1)
171 |     args.add_argument("--iteration", type=int, default=50)
172 |     args.add_argument("--algorithm", type=str, choices=["lats", "rap", "tot"])
173 |     args.add_argument("--cot_method", type=str, choices=["knn", "random", "None"])
174 |     args.add_argument("--run_name", type=str)
175 |     args.add_argument("--log_file_path", type=str)
176 |     args.add_argument("--log_dir", type=str)
177 |     args.add_argument("--cot_size", type=int, default=0)
178 |     args = args.parse_args()
179 |     return args
180 | 
181 | 
182 | if __name__ == "__main__":
183 |     args = parse_args()
184 |     print(args)
185 |     run(args)
186 | 
--------------------------------------------------------------------------------
/webshop/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
4 | os.environ["PYTHONUTF8"] = "1"
5 | import time
6 | from typing import Dict, List
7 | import openai
8 | import random, json
9 | from IPython import embed
10 | import random
11 | import math
12 | 
13 | tot_in=0
14 | tot_out=0
15 | 
16 | class ModelServer:
17 | 
18 |     def get_completion_or_embedding(
19 |         self,
20 |         model_size: str,
21 |         message,
22 |         temperature: float = 
--------------------------------------------------------------------------------
/webshop/utils.py:
--------------------------------------------------------------------------------
import os

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["PYTHONUTF8"] = "1"
import time
from typing import Dict, List
import openai
import random, json
from IPython import embed
import math

tot_in = 0
tot_out = 0

# NOTE: `client` (an OpenAI-compatible API client) and `model_name` are used
# below but are never defined in this file; the importing code must supply them.

class ModelServer:

    def get_completion_or_embedding(
        self,
        model_size: str,
        message,
        temperature: float = 0.0,
        max_tokens: int = 256,
        get_embedding: bool = False,
    ):
        # Returns a chat-completion response object, or an embedding vector
        # when get_embedding=True (so the `-> str` annotation was dropped).
        global tot_in
        global tot_out
        assert model_size in ["70", "8", "7"]

        if not get_embedding:
            assert isinstance(message, list), "Message should be a list."
            response = client.chat.completions.create(
                model=model_name,
                messages=message,
                max_tokens=max_tokens,
                temperature=temperature,
                stop=["<|eot_id|>", "\nObservation", "Observation"],
            )
            # Track token usage across all calls
            tot_out += response.usage.completion_tokens
            tot_in += response.usage.prompt_tokens
            return response
        else:
            assert isinstance(message, str), "Message should be a string."
            response = client.embeddings.create(
                model=model_name,
                input=message,
            )
            return response.data[0].embedding
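
As written, utils.py raises NameError on first use because `client` and `model_name` are never defined. A minimal sketch of how they might be injected, assuming an OpenAI-compatible endpoint (e.g. a vLLM server); the URL and model id below are placeholders, not values from the repo:

import openai
import utils  # the module above

# Illustrative wiring: set the missing module-level globals from outside.
utils.client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
utils.model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

server = utils.ModelServer()
response = server.get_completion_or_embedding(
    "8",
    [{"role": "user", "content": "Hello, WebShop!"}],
    temperature=0.7,
)
print(response.choices[0].message.content)

# Embedding path: pass a raw string and set get_embedding=True.
vector = server.get_completion_or_embedding("8", "trajectory text", get_embedding=True)

Assigning onto the utils module object (rather than plain local variables) matters here: the methods resolve `client` and `model_name` in utils.py's own global namespace.
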
--------------------------------------------------------------------------------
/webshop/webshop.py:
--------------------------------------------------------------------------------
import os
import re
from base import Task
from prompt import *
from models import gpt
import logging
import random
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

def get_token_length(text):
    return len(tokenizer.encode(text))

max_token_length = 131072

class WebShopTask(Task):
    """
    Input (x)  : a text instruction
    Output (y) : a text generation
    Reward (r) : mean LLM-judged correctness score of the final action (see test_output)
    """
    def __init__(self):
        super().__init__()
        self.steps = 7
        self.stops = ['\nObservation:\n', None]
        self.value_cache = {}
        self.reflections = []

    def test_output(self, idx: int, output: str):
        # Judge the final action with an LLM (5 samples) and average the parsed scores.
        output = output.split('Action:\n')[-1]
        prompt = score_prompt + output
        score_outputs = gpt(prompt, n=5, model='gpt-4')
        scores = []
        for score_output in score_outputs:
            pattern = r".*correctness score is (\d+).*"
            match = re.match(pattern, score_output, re.DOTALL)
            if match:
                score = int(match.groups()[0])
                scores.append(score)
            else:
                print(f'------------------score no match: {[score_output]}')
        print(scores)
        info = {'rs': scores, 'r': sum(scores) / len(scores) if scores else 0}
        return info

    @staticmethod
    def standard_prompt_wrap(x: str, y: str = '') -> str:
        return standard_prompt.format(input=x) + y

    @staticmethod
    def generate_self_reflection(z, question):
        # Sample up to three failed trajectories, ask the model to reflect on
        # each, and return the (question, trajectory, reflection) mappings.
        reflection_mapping = []
        trajectories = ""

        sampled_items = random.sample(z, min(3, len(z)))
        failed_trajectories = [
            item['trajectory'] + f"\nReward: {item['r']}\n"
            for item in sampled_items
            if isinstance(item, dict) and 'trajectory' in item and 'r' in item
        ]

        for traj in failed_trajectories:
            trajectories += traj
            reflect_prompt = reflection_prompt.format(trajectory=traj)
            reflection = gpt(reflect_prompt)
            trajectories += "Reflection: " + reflection[0] + "\n"
            reflection_mapping.append({
                'question': question,
                'trajectory': traj,
                'reflection': reflection[0]
            })

        return reflection_mapping

    @staticmethod
    def cot_prompt_wrap(x: str, y: str = '', reflection_mapping_list=[], traj_list=[]):
        # Base CoT prompt, then any retrieved exemplar trajectories, then the query.
        ret = prompt1
        for tj in traj_list:
            ret += tj + "\n"
        ret += x + y + "\n"
        return ret

    @staticmethod
    def vote_prompt_wrap(x: str, ys: list) -> str:
        prompt = score_prompt + "\n" + x + "\n\n"
        for i, y in enumerate(ys, 1):
            # TODO: truncate the plan part?
            prompt += f'Choice {i}:\n{y}\n'
        return prompt

    @staticmethod
    def vote_outputs_unwrap(vote_outputs: list, n_candidates: int) -> list:
        vote_results = [0] * n_candidates
        for vote_output in vote_outputs:
            pattern = r".*best trajectory is .*(\d+).*"
            match = re.match(pattern, vote_output, re.DOTALL)
            if match:
                vote = int(match.groups()[0]) - 1
                if vote in range(n_candidates):
                    vote_results[vote] += 1
            else:
                print(f'vote no match: {[vote_output]}')
        return vote_results

    @staticmethod
    def compare_prompt_wrap(x: str, ys: list) -> str:
        assert len(ys) == 2, 'compare prompt only supports 2 candidates'

        # Extract the last Action from each trajectory
        last_actions = []
        for y in ys:
            # Walk the lines in reverse so the final Action is found first
            lines = y.split('\n')[::-1]
            for line in lines:
                if "Action" in line:
                    last_actions.append(line.split('Action')[-1].strip(': '))
                    break

        assert len(last_actions) == 2, 'Expected to find 2 Actions'

        # Construct the prompt from the two extracted Actions
        prompt = compare_prompt + f'Action 1:{last_actions[0]}\n\nAction 2:{last_actions[1]}\n'
        return prompt

    @staticmethod
    def compare_output_unwrap(compare_output: str):
        if 'more correct trajectory is 1' in compare_output:
            return 0
        elif 'more correct trajectory is 2' in compare_output:
            return 1
        elif "two trajectories are similarly correct" in compare_output:
            return 0.5
        else:
            print(f'-----------------compare no match: {[compare_output]}')
            return -1
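
    # Illustrative checks (hypothetical judge outputs, not from the repo):
    #   compare_output_unwrap("Thus, the more correct trajectory is 2")      -> 1
    #   compare_output_unwrap("The two trajectories are similarly correct")  -> 0.5
    #   vote_outputs_unwrap(["The best trajectory is 2"] * 3, 2)             -> [0, 3]
    # Candidate indices are 1-based in the judge text and 0-based in the tally.
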
"\n\nReflection: " 192 | prompt = score_prompt.format(s="", input=inp) 193 | 194 | return prompt 195 | 196 | 197 | @staticmethod 198 | def value_outputs_unwrap(evaluate_prompt: str): 199 | evaluate_prompt = evaluate_prompt[0] 200 | if '10' in evaluate_prompt: 201 | return 1.0 202 | elif '9' in evaluate_prompt: 203 | return 0.9 204 | elif '8' in evaluate_prompt: 205 | return 0.8 206 | elif '7' in evaluate_prompt: 207 | return 0.7 208 | elif '6' in evaluate_prompt: 209 | return 0.6 210 | elif '5' in evaluate_prompt: 211 | return 0.5 212 | elif '4' in evaluate_prompt: 213 | return 0.4 214 | elif '3' in evaluate_prompt: 215 | return 0.3 216 | elif '2' in evaluate_prompt: 217 | return 0.2 218 | elif '1' in evaluate_prompt: 219 | return 0.1 220 | else: 221 | return -1 222 | 223 | --------------------------------------------------------------------------------