├── .gitignore ├── LICENSE ├── README.md ├── assets ├── mind2web-results.jpg ├── online-offline.jpg ├── teaser.jpg └── webarena-leaderboard.jpg ├── mind2web ├── README.md ├── data │ └── memory │ │ └── exemplars.json ├── memory.py ├── offline_induction.py ├── online_induction.py ├── pipeline.py ├── prompt │ ├── instruction_abstract.txt │ ├── instruction_action.txt │ ├── one_shot_abstract.txt │ └── one_shot_action.txt ├── requirements.txt ├── results │ └── calc_score.py ├── run_mind2web.py ├── utils │ ├── data.py │ ├── env.py │ └── llm.py └── workflow │ └── retrieve.py └── webarena ├── README.md ├── agents ├── __init__.py ├── basic │ ├── __init__.py │ └── agent.py └── legacy │ ├── __init__.py │ ├── agent.py │ ├── dynamic_prompting.py │ └── utils │ ├── __init__.py │ ├── chat_api.py │ ├── llm_utils.py │ └── prompt_templates.py ├── autoeval ├── clients.py ├── evaluate_trajectory.py ├── evaluator.py ├── prompts.py └── requirements.txt ├── config_files ├── generate_test_data.py ├── test.json └── test.raw.json ├── induce_prompt.py ├── induce_rule.py ├── pipeline.py ├── prompt ├── instruction.txt └── one_shot.txt ├── requirements.txt └── run.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/** 2 | **/results/** 3 | webarena/config_files/** 4 | webarena/autoeval/log 5 | **/.DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | <h1>Agent Workflow Memory</h1>
3 | 4 | <a href="https://arxiv.org/abs/2409.07429">arXiv</a> 5 | 6 | 7 | <a href="https://github.com/zorazrw/agent-workflow-memory/pulls">PRs Welcome</a> 8 | 9 |
10 | 11 | ## Quickstart :boom: 12 | To run AWM on WebArena under `webarena/`: 13 | ```bash 14 | cd webarena 15 | python pipeline.py --website "shopping" # choose one from ['shopping', 'shopping_admin', 'reddit', 'gitlab', 'map'] 16 | ``` 17 | 18 | To run AWM on Mind2Web under `mind2web/`: 19 | ```bash 20 | cd mind2web 21 | python pipeline.py --setup "offline" # or "online" 22 | ``` 23 | Check `webarena/` and `mind2web/` folders for more detailed instructions about environment and data setups. 24 | 25 | ## What is Agent Workflow Memory? 🧠 26 | Agent Workflow Memory (AWM) proposes to induce, integrate, and utilize workflows via an agent memory. 27 | A workflow is usually a common sub-routine in solving tasks, with example-specific contexts being abstracted out. 28 | 29 |
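For instance, here is one such workflow induced from Delta flight-booking tasks, excerpted verbatim from `mind2web/prompt/one_shot_action.txt`; the example-specific date is abstracted into a `{travel-date}` slot:

```
## select_date_for_travel
Given that you are on the Delta flight booking page, this workflow selects the travel date.
[button] Depart and Return Calendar Use enter to open, es... -> CLICK
[link] {travel-date} -> CLICK
[button] done -> CLICK
```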

30 | 31 | <img src="assets/teaser.jpg"> 32 | 33 |
34 | 35 | AWM can operate in both offline and online settings: 36 | - *offline* (left): when additional (e.g., training) examples are available, agents induce workflows from ground-truth annotated examples; 37 | - *online* (right): without any auxiliary data, agents induce workflows from their own past experiences on the fly (see the sketch below). 38 | 39 |
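In the online setting, induction and utilization interleave: the agent answers a batch of test queries, induces workflows from the resulting trajectories, and adds them to memory before the next batch. Below is a minimal sketch of this loop; `run_inference` and `induce_workflows` are hypothetical stand-ins for the `run_mind2web.py` and `online_induction.py` subprocess calls in `mind2web/pipeline.py`:

```python
# Minimal sketch of the online loop in mind2web/pipeline.py;
# run_inference and induce_workflows are hypothetical stand-ins for
# the run_mind2web.py and online_induction.py subprocess calls.
def online(samples: list, workflow_path: str, induce_steps: int = 1) -> None:
    n = len(samples)
    for i in range(0, n, induce_steps):
        j = min(n, i + induce_steps)
        # act on the next batch of queries, conditioning on current memory
        run_inference(start_idx=i, end_idx=j, workflow_path=workflow_path)
        # induce workflows from all trajectories so far and refresh memory
        induce_workflows(results_dir="results", output_path=workflow_path)
```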

40 | 41 | <img src="assets/online-offline.jpg"> 42 | 43 |
44 | 45 | ## How does AWM work? 📈 46 | 47 | ### On WebArena 48 | We achieve a state-of-the-art result: 35.6% success rate. 49 | 50 |

51 | 52 | <img src="assets/webarena-leaderboard.jpg"> 53 | 54 |
55 | 56 | Check the code in the `./webarena/` directory. 57 | 58 | ### On Mind2Web 59 | 60 | We also achieve the best scores among text-based agents. In particular, offline AWM generalizes effectively across a wide range of tasks, websites, and domains. 61 | 62 |

63 | 64 | <img src="assets/mind2web-results.jpg"> 65 | 66 |
67 | 68 | Check the code in the `./mind2web/` directory. 69 | 70 | ## Citation 📜 71 | 72 | ```bibtex 73 | @article{awm2024wang, 74 | title = {Agent Workflow Memory}, 75 | author = {Wang, Zhiruo and Mao, Jiayuan and Fried, Daniel and Neubig, Graham}, 76 | journal = {arXiv preprint arXiv:2409.07429}, 77 | year = {2024}, 78 | } 79 | ``` 80 | -------------------------------------------------------------------------------- /assets/mind2web-results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zorazrw/agent-workflow-memory/907d3cbafcae021fe3a4577c5a10539752e63596/assets/mind2web-results.jpg -------------------------------------------------------------------------------- /assets/online-offline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zorazrw/agent-workflow-memory/907d3cbafcae021fe3a4577c5a10539752e63596/assets/online-offline.jpg -------------------------------------------------------------------------------- /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zorazrw/agent-workflow-memory/907d3cbafcae021fe3a4577c5a10539752e63596/assets/teaser.jpg -------------------------------------------------------------------------------- /assets/webarena-leaderboard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zorazrw/agent-workflow-memory/907d3cbafcae021fe3a4577c5a10539752e63596/assets/webarena-leaderboard.jpg -------------------------------------------------------------------------------- /mind2web/README.md: -------------------------------------------------------------------------------- 1 | # AWM for Mind2Web 2 | 3 | ## Install 4 | 5 | ```bash 6 | pip install -r requirements.txt 7 | ``` 8 | 9 | Download data from the [mind2web](https://github.com/OSU-NLP-Group/Mind2Web) project; make sure you have `test_task`, `test_website`, `test_domain`, and `train` under the `data` directory; download `scores_all_data.pkl` for HTML filtering at [[link]](https://buckeyemailosu-my.sharepoint.com/personal/deng_595_buckeyemail_osu_edu/_layouts/15/onedrive.aspx?id=%2Fpersonal%2Fdeng%5F595%5Fbuckeyemail%5Fosu%5Fedu%2FDocuments%2FMind2Web%2Fscores%5Fall%5Fdata%2Epkl&parent=%2Fpersonal%2Fdeng%5F595%5Fbuckeyemail%5Fosu%5Fedu%2FDocuments%2FMind2Web&ga=1). 10 | 11 | ## Offline Workflow Induction + Test Inference 12 | 13 | To run offline workflow induction with training examples: 14 | ```bash 15 | python offline_induction.py \ 16 | --mode auto --domain Travel --subdomain Airlines --website aa \ 17 | --model "gpt-4o" --output_dir "workflow" 18 | ``` 19 | You can also switch to `--mode input` to dynamically input your desired website(s).
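The induction command above writes its output as plain-text workflows to `workflow/aa.txt`. As a purely hypothetical illustration (the element names below are made up; actual contents depend on the training examples), the file might contain entries such as:

```
## find_flights
Given that you are on the American Airlines homepage, this workflow searches for flights between two airports.
[textbox] From -> TYPE: {origin-airport}
[link] {best-popup-option} -> CLICK
[textbox] To -> TYPE: {destination-airport}
[link] {best-popup-option} -> CLICK
[button] Search -> CLICK
```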
20 | 21 | To add this induced workflow to agent memory and run inference on test examples from the *aa* website: 22 | 23 | ```bash 24 | python run_mind2web.py --website "aa" --workflow_path "workflow/aa.txt" 25 | ``` 26 | 27 | ## Online Induction with Test Queries 28 | 29 | To run online workflow induction and utilization: 30 | ```bash 31 | python pipeline.py --setup online \ 32 | --benchmark "test_task" --website aa \ 33 | --results_dir results/aa/workflow \ 34 | --workflow_path workflow/aa.txt 35 | ``` 36 | 37 | Simply change to `--benchmark 'train'` if you want to run the online setting on the training (or other) queries, but remember to apply the induced workflows and run inference on test examples afterwards. 38 | 39 | 40 | ## Overall 41 | To run the entire pipeline in either the offline or online setting, you can use: 42 | ```bash 43 | python pipeline.py --setup "offline" # or "online" 44 | ``` 45 | with other arguments specified as above. -------------------------------------------------------------------------------- /mind2web/memory.py: -------------------------------------------------------------------------------- 1 | import os, json, random 2 | import numpy as np 3 | from pathlib import Path 4 | from openai import BadRequestError 5 | from utils.env import * 6 | from utils.llm import ( 7 | generate_response, num_tokens_from_messages, 8 | MAX_TOKENS, extract_from_response, 9 | ) 10 | 11 | import logging 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def get_exemplars(args) -> list: 16 | """Get exemplar workflows in the prompt.""" 17 | # workflow memory 18 | memory = [] 19 | workflow_text = open(args.workflow_path, 'r').read().strip() 20 | if len(workflow_text): 21 | memory = [[{"role": "user", "content": workflow_text}]] 22 | 23 | # concrete examples 24 | with open(os.path.join(args.memory_path, "exemplars.json"), "r") as f: 25 | concrete_examples = json.load(f) 26 | if any([args.website in cex[0].get("specifier", "") for cex in concrete_examples]): 27 | concrete_examples = [ 28 | cex for cex in concrete_examples 29 | if all([tag in cex[0]["specifier"] for tag in [args.domain, args.subdomain, args.website]]) 30 | ] 31 | elif any([args.subdomain in cex[0].get("specifier", "") for cex in concrete_examples]): 32 | concrete_examples = [ 33 | cex for cex in concrete_examples 34 | if all([tag in cex[0]["specifier"] for tag in [args.domain, args.subdomain]]) 35 | ] 36 | 37 | memory += random.sample(concrete_examples, 38 | min(args.retrieve_top_k, len(concrete_examples))) 39 | memory = [[{k:v for k,v in m.items() if k!="specifier"} for m in e] for e in memory] 40 | return memory 41 | 42 | 43 | def eval_sample(task_id, args, sample): 44 | # initialize metrics 45 | element_acc, action_f1, step_success, success = [], [], [], [] 46 | token_stats = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} 47 | conversation = [] 48 | episode_length = len(sample["action_reprs"]) 49 | 50 | exemplars = get_exemplars(args) 51 | # print(exemplars) 52 | 53 | sys_message = [ 54 | { 55 | "role": "system", 56 | "content": "You are a large language model trained to navigate the web. Output the next action and wait for the next observation. Here is the action space:\n1. `CLICK [id]`: Click on an HTML element with its id.\n2. `TYPE [id] [value]`: Type a string into the element with the id.\n3. 
`SELECT [id] [value]`: Select a value for an HTML element by its id.", 57 | } 58 | ] 59 | 60 | prev_actions, prev_obs = [], [] 61 | previous_k = 5 62 | 63 | for s, act_repr in zip(sample["actions"], sample["action_reprs"]): 64 | _, target_act = get_target_obs_and_act(s) 65 | pos_candidates = [ 66 | c for c in s["pos_candidates"] if c["rank"] < args.top_k_elements 67 | ] 68 | 69 | # get query, obs, act 70 | target_obs, _ = get_top_k_obs(s, args.previous_top_k_elements) 71 | # Continue next loop if the ground truth element is not in the cleaned html 72 | if len(pos_candidates) == 0: 73 | element_acc.append(0) 74 | action_f1.append(0) 75 | step_success.append(0) 76 | prev_obs.append("Observation: `" + target_obs + "`") 77 | prev_actions.append("Action: `" + target_act + "` (" + act_repr + ")") 78 | conversation.append("The ground truth element is not in cleaned html") 79 | continue 80 | 81 | # construct query 82 | query = [] 83 | for o, a in zip(prev_obs, prev_actions): 84 | if len(query) == 0: 85 | query.append({ 86 | "role": "user", 87 | "content": f"Task: {sample['confirmed_task']}\nTrajectory:\n" + o, 88 | }) 89 | else: 90 | query.append({"role": "user", "content": o}) 91 | query.append({"role": "assistant", "content": a}) 92 | 93 | obs, _ = get_top_k_obs(s, args.top_k_elements, use_raw=False) 94 | if len(query) == 0: 95 | query.append({ 96 | "role": "user", 97 | "content": f"Task: {sample['confirmed_task']}\nTrajectory:\n" 98 | + "Observation: `" + obs + "`", 99 | }) 100 | else: 101 | query.append({"role": "user", "content": "Observation: `" + obs + "`"}) 102 | 103 | prev_obs.append("Observation: `" + target_obs + "`") 104 | prev_actions.append("Action: `" + target_act + "` (" + act_repr + ")") 105 | 106 | # token limit 107 | total_num_tokens = num_tokens_from_messages(sys_message + query, args.model) 108 | if total_num_tokens > MAX_TOKENS[args.model]: 109 | logger.info( 110 | f"Too many tokens in acting ({total_num_tokens} / {MAX_TOKENS[args.model]}), skipping..." 
111 | ) 112 | element_acc.append(0) 113 | action_f1.append(0) 114 | step_success.append(0) 115 | conversation.append( 116 | { 117 | "input": sys_message + query, 118 | "output": f"FAILED DUE TO THE CONTEXT LIMIT: {total_num_tokens}", 119 | } 120 | ) 121 | continue 122 | 123 | # message 124 | demo_message = [] 125 | for e_id, e in enumerate(exemplars): 126 | total_num_tokens = num_tokens_from_messages( 127 | sys_message + demo_message + e + query, args.model 128 | ) 129 | if total_num_tokens > MAX_TOKENS[args.model]: 130 | logger.info( 131 | f"Using {e_id} / {len(exemplars)} exemplars due to context limit" 132 | ) 133 | break 134 | else: 135 | demo_message.extend(e) 136 | 137 | message = sys_message + demo_message + query 138 | try: 139 | response, info = generate_response( 140 | messages=message, 141 | model=args.model, 142 | temperature=args.temperature, 143 | stop_tokens=["Task:", "obs:"], 144 | ) 145 | except BadRequestError: 146 | response = "" 147 | info = { 148 | "prompt_tokens": 0, 149 | "completion_tokens": 0, 150 | "total_tokens": 0, 151 | } 152 | conversation.append({"input": message, "output": response, "token_stats": info}) 153 | for k, v in info.items(): 154 | token_stats[k] += v 155 | pred_act = extract_from_response(response, "`") 156 | pred_op, pred_id, pred_val = parse_act_str(pred_act) 157 | target_op, _, target_val = parse_act_str(target_act) 158 | 159 | # calculate metrics 160 | pos_ids = [c["backend_node_id"] for c in s["pos_candidates"]][:1] 161 | if pred_id in pos_ids: 162 | element_acc.append(1) 163 | else: 164 | element_acc.append(0) 165 | action_f1.append( 166 | calculate_f1( 167 | construct_act_str(pred_op, pred_val), 168 | construct_act_str(target_op, target_val), 169 | ) 170 | ) 171 | conversation.append({"pred_act": pred_act, "target_act": target_act}) 172 | if pred_act == target_act: 173 | step_success.append(1) 174 | else: 175 | step_success.append(0) 176 | 177 | # check the last episode_length of step_success, if all 1, then success = 1 178 | if np.sum(step_success[-episode_length:]) == episode_length: 179 | success.append(1) 180 | else: 181 | success.append(0) 182 | 183 | conversation.append( 184 | { 185 | "element_acc": element_acc, 186 | "action_f1": action_f1, 187 | "step_success": step_success, 188 | "success": success, 189 | } 190 | ) 191 | log_dir = Path(f"{args.log_dir}/{args.model}/{args.benchmark}/{args.website}/{args.suffix}") 192 | log_dir.mkdir(parents=True, exist_ok=True) 193 | with open(os.path.join(log_dir, f"{task_id}.json"), "w") as f: 194 | json.dump(conversation, f, indent=2) 195 | -------------------------------------------------------------------------------- /mind2web/offline_induction.py: -------------------------------------------------------------------------------- 1 | """Induce Website-Specific Workflows Offline from Training Examples.""" 2 | 3 | import os 4 | import json 5 | import pickle 6 | import argparse 7 | from utils.data import add_scores, format_examples, filter_workflows 8 | 9 | import openai 10 | openai.api_key = os.environ["OPENAI_API_KEY"] 11 | from openai import OpenAI 12 | client = OpenAI() 13 | 14 | # %% Data loading and processing 15 | def get_data_dict(paths: list[str]) -> dict: 16 | """Create dict for examples in domain-subdomain-website hierarchy. 
17 | Args: 18 | paths: list[str], list of data path strings 19 | Returns: 20 | data_dict: dict[str, dict], (domain, subdomain, website) dict 21 | """ 22 | print("Start loading data files...") 23 | data_dict = {} 24 | for p in paths: 25 | print(p) 26 | data = json.load(open(p, 'r')) 27 | for ex in data: 28 | domain, subdomain, website = ex["domain"], ex["subdomain"], ex["website"] 29 | if domain not in data_dict: 30 | data_dict[domain] = {} 31 | if subdomain not in data_dict[domain]: 32 | data_dict[domain][subdomain] = {} 33 | if website not in data_dict[domain][subdomain]: 34 | data_dict[domain][subdomain][website] = [] 35 | data_dict[domain][subdomain][website].append(ex) 36 | print(f"Finished loading {len(paths)} files!") 37 | return data_dict 38 | 39 | def get_split(data_dict: dict) -> tuple[str, dict]: 40 | """Prompt the user to select a split; return its name and sub-dict.""" 41 | options = list(data_dict.keys()) 42 | split = input(f"Select from {options} >> ") 43 | while split not in options: 44 | split = input(f"Select from {options} >> ") 45 | return split, data_dict[split] 46 | 47 | def get_examples(data_dict: dict, tags: tuple[str, str, str]) -> list[dict]: 48 | """Return the examples satisfying the tags.""" 49 | domain, subdomain, website = tags 50 | return data_dict[domain][subdomain][website] 51 | 52 | 53 | # %% Prompt and generate 54 | def llm_generate(tags: tuple[str, str, str], examples: list[dict], args, verbose: bool = False): 55 | """Call the GPT model to generate workflows.""" 56 | prompt = "Website: " + ','.join(tags) + '\n' 57 | prompt += format_examples(examples, args.prefix, args.suffix) 58 | prompt = '\n\n'.join([args.INSTRUCTION, args.ONE_SHOT, prompt]) 59 | if verbose: print("Prompt:\n", prompt, '\n\n') 60 | response = client.chat.completions.create( 61 | model=args.model_name, 62 | messages=[{"role": "user", "content": prompt}], 63 | temperature=args.temperature, 64 | max_tokens=1024, 65 | ) 66 | response = response.choices[0].message.content 67 | if verbose: print(response) 68 | return response 69 | 70 | # %% Save outputs 71 | def save_to_txt(text: str, args): 72 | """Save text to a .txt file.""" 73 | output_name = f"{args.website.lower()}_{args.output_suffix}.txt" \ 74 | if args.output_suffix is not None else f"{args.website.lower()}.txt" 75 | output_path = os.path.join(args.output_dir, output_name) 76 | with open(output_path, 'w') as fw: 77 | fw.write(text) 78 | 79 | 80 | # %% Main pipeline 81 | def main(): 82 | # load data into dict 83 | data_paths = [os.path.join(args.data_dir, f) for f in os.listdir(args.data_dir)] 84 | data_dict = get_data_dict(paths=data_paths) 85 | 86 | # load candidate scores and ranks 87 | with open(os.path.join("data", "scores_all_data.pkl"), "rb") as f: 88 | candidate_results = pickle.load(f) 89 | 90 | # load prompt contexts 91 | args.INSTRUCTION = open(args.instruction_path, 'r').read() 92 | args.ONE_SHOT = open(args.one_shot_path, 'r').read() 93 | 94 | def single_website_loop(tags: tuple[str, str, str]): 95 | """Pipeline to induce, filter, and save workflows on a single website.""" 96 | examples = get_examples(data_dict, tags=tags) 97 | print(f"Split {tags} with #{len(examples)} examples") 98 | add_scores(examples, candidate_results) 99 | response = llm_generate(tags, examples, args) 100 | workflows = filter_workflows(response, website=tags[-1]) 101 | save_to_txt(workflows, args) 102 | 103 | if args.mode == "auto": 104 | single_website_loop(args.tags) 105 | elif args.mode == "input": 106 | stop = False 107 | while not stop: 108 | # select split 109 | args.domain, 
domain_dict = get_split(data_dict) 110 | args.subdomain, subdomain_dict = get_split(domain_dict) 111 | args.website, examples = get_split(subdomain_dict) 112 | 113 | # generate workflows 114 | tags = [args.domain, args.subdomain, args.website] 115 | single_website_loop(tags) 116 | 117 | if input("Stop? [y/n] ").strip() == 'y': 118 | stop = True 119 | else: 120 | raise ValueError("Please enter a valid `mode` ('input' or 'auto')") 121 | 122 | 123 | if __name__ == "__main__": 124 | parser = argparse.ArgumentParser() 125 | parser.add_argument("--data_dir", type=str, default="data/train") 126 | parser.add_argument("--output_dir", type=str, default="workflow") 127 | parser.add_argument("--output_suffix", type=str, default=None) 128 | parser.add_argument("--verbose", action="store_true", 129 | help="Whether to print prompt and response.") 130 | 131 | # mode 132 | parser.add_argument("--mode", type=str, default="input", 133 | choices=["input", "auto"]) 134 | parser.add_argument("--domain", type=str, default=None, 135 | help="Specify in 'auto' mode.") 136 | parser.add_argument("--subdomain", type=str, default=None, 137 | help="Specify in 'auto' mode.") 138 | parser.add_argument("--website", type=str, default=None, 139 | help="Specify in 'auto' mode.") 140 | # model 141 | parser.add_argument("--model_name", type=str, default="gpt-4o") 142 | parser.add_argument("--temperature", type=float, default=0.0) 143 | 144 | # prompt 145 | parser.add_argument("--instruction_path", type=str, default="prompt/instruction_action.txt") 146 | parser.add_argument("--one_shot_path", type=str, default="prompt/one_shot_action.txt") 147 | parser.add_argument("--prefix", type=str, default=None) 148 | parser.add_argument("--suffix", type=str, default="# Summary Workflows") 149 | 150 | args = parser.parse_args() 151 | 152 | # sanity check 153 | if args.mode == "auto": 154 | args.tags = [args.domain, args.subdomain, args.website] 155 | assert not any([tag is None for tag in args.tags]) 156 | 157 | main() 158 | -------------------------------------------------------------------------------- /mind2web/online_induction.py: -------------------------------------------------------------------------------- 1 | """Induce Workflows from Past Agent Experiences.""" 2 | 3 | import os 4 | import json 5 | import argparse 6 | from utils.data import load_json, format_examples, filter_workflows 7 | 8 | import openai 9 | openai.api_key = os.environ["OPENAI_API_KEY"] 10 | from openai import OpenAI 11 | client = OpenAI() 12 | 13 | 14 | def is_io_dict(item: dict | str) -> bool: 15 | if isinstance(item, dict) and ("input" in item) and ("output" in item): return True 16 | return False 17 | 18 | def get_trajectory(path: str): 19 | trajectory = [] 20 | result = json.load(open(path, 'r')) 21 | for item in result: 22 | if not is_io_dict(item): continue 23 | step = { 24 | "env": "# " + item["input"][-1]["content"], 25 | "action": item["output"], 26 | } 27 | trajectory.append(step) 28 | return trajectory 29 | 30 | 31 | def main(): 32 | samples = load_json(args.data_dir, args.benchmark) 33 | print(f"Loaded #{len(samples)} test examples") 34 | samples = [s for s in samples if s["website"] == args.website] 35 | print(f"Filtering down to #{len(samples)} examples on website [{args.website}]") 36 | 37 | # load model predictions and format examples 38 | result_files = [os.path.join(args.results_dir, f) for f in os.listdir(args.results_dir)] 39 | result_list = [get_trajectory(rf) for rf in result_files] 40 | examples = [] 41 | for r, s in zip(result_list, 
samples): 42 | examples.append({ 43 | "confirmed_task": s["confirmed_task"], 44 | "action_reprs": [step["env"] + '\n' + step["action"] for step in r], 45 | }) 46 | prompt = format_examples(examples, args.prefix, args.suffix) 47 | 48 | # transform to workflows 49 | INSTRUCTION = open(args.instruction_path, 'r').read() 50 | ONE_SHOT = open(args.one_shot_path, 'r').read() 51 | domain, subdomain, website = samples[0]["domain"], samples[0]["subdomain"], samples[0]["website"] 52 | prompt = '\n\n'.join([INSTRUCTION, ONE_SHOT, f"Website: {domain}, {subdomain}, {website}\n{prompt}"]) 53 | response = client.chat.completions.create( 54 | model=args.model_name, 55 | messages=[{"role": "user", "content": prompt}], 56 | temperature=args.temperature, 57 | ).choices[0].message.content 58 | response = filter_workflows(response, args.website) 59 | 60 | # save to file 61 | with open(args.output_path, 'w') as fw: 62 | fw.write(response) 63 | 64 | 65 | if __name__ == "__main__": 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument("--data_dir", type=str, default="data") 68 | parser.add_argument("--benchmark", type=str, default="test_task", 69 | choices=["test_task", "test_website", "test_domain", "train"]) 70 | parser.add_argument("--website", type=str, required=True) 71 | parser.add_argument("--results_dir", type=str, required=True) 72 | parser.add_argument("--output_path", type=str, required=True) 73 | 74 | # model 75 | parser.add_argument("--model_name", type=str, default="gpt-4o") 76 | parser.add_argument("--temperature", type=float, default=0.0) 77 | # prompt 78 | parser.add_argument("--instruction_path", type=str, default="prompt/instruction_action.txt") 79 | parser.add_argument("--one_shot_path", type=str, default="prompt/one_shot_action.txt") 80 | parser.add_argument("--prefix", type=str, default=None) 81 | parser.add_argument("--suffix", type=str, default="# Summary Workflows") 82 | 83 | args = parser.parse_args() 84 | 85 | main() 86 | -------------------------------------------------------------------------------- /mind2web/pipeline.py: -------------------------------------------------------------------------------- 1 | """Offline/Online Workflow Induction and Utilization Pipeline.""" 2 | 3 | import argparse 4 | import subprocess 5 | from utils.data import load_json 6 | 7 | def offline(): 8 | # workflow induction 9 | process = subprocess.Popen([ 10 | 'python', 'offline_induction.py', 11 | '--mode', 'auto', '--website', args.website, 12 | '--domain', args.domain, '--subdomain', args.subdomain, 13 | '--model', args.model, '--output_dir', "workflow", 14 | '--instruction_path', args.instruction_path, 15 | '--one_shot_path', args.one_shot_path, 16 | ]) 17 | process.wait() 18 | 19 | # test inference 20 | process = subprocess.Popen([ 21 | 'python', 'run_mind2web.py', 22 | '--website', args.website, 23 | '--workflow_path', f"workflow/{args.website}.txt" 24 | ]) 25 | process.wait() 26 | 27 | 28 | def online(): 29 | # load all examples for streaming 30 | samples = load_json(args.data_dir, args.benchmark) 31 | print(f"Loaded #{len(samples)} test examples") 32 | if args.website is not None: 33 | samples = [s for s in samples if s["website"] == args.website] 34 | print(f"Filtering down to #{len(samples)} examples on website [{args.website}]") 35 | n = len(samples) 36 | 37 | for i in range(0, n, args.induce_steps): 38 | j = min(n, i + args.induce_steps) 39 | print(f"Running inference on examples {i}-{j}...") 40 | 41 | process = subprocess.Popen([ 42 | 'python', 'run_mind2web.py', 43 | '--benchmark', 
args.benchmark, 44 | '--workflow_path', args.workflow_path, 45 | '--website', args.website, 46 | '--start_idx', f'{i}', '--end_idx', f'{j}', 47 | '--domain', args.domain, '--subdomain', args.subdomain, 48 | ]) 49 | process.wait() 50 | print(f"Finished inference on examples {i}-{j}!\n") 51 | 52 | if (j + 1) < len(samples): 53 | process = subprocess.Popen([ 54 | 'python', 'online_induction.py', 55 | '--benchmark', args.benchmark, 56 | '--website', args.website, 57 | '--results_dir', args.results_dir, 58 | '--output_path', args.workflow_path, 59 | ]) 60 | process.wait() 61 | print(f"Finished workflow induction with examples 0-{j}!\n") 62 | 63 | 64 | if __name__ == "__main__": 65 | parser = argparse.ArgumentParser() 66 | # examples 67 | parser.add_argument("--data_dir", type=str, default="data") 68 | parser.add_argument("--benchmark", type=str, default="test_task", 69 | choices=["test_task", "test_website", "test_domain", "train"]) 70 | parser.add_argument("--website", type=str, required=True) 71 | parser.add_argument("--domain", type=str, default=None) 72 | parser.add_argument("--subdomain", type=str, default=None) 73 | 74 | # results and workflows 75 | parser.add_argument("--results_dir", type=str, default=None) 76 | parser.add_argument("--workflow_path", type=str, default=None) 77 | 78 | # prompt 79 | parser.add_argument("--instruction_path", type=str, default="prompt/instruction_action.txt") 80 | parser.add_argument("--one_shot_path", type=str, default="prompt/one_shot_action.txt") 81 | parser.add_argument("--prefix", type=str, default=None) 82 | parser.add_argument("--suffix", type=str, default="# Summary Workflows") 83 | 84 | # gpt 85 | parser.add_argument("--model", type=str, default="gpt-4o") 86 | parser.add_argument("--temperature", type=float, default=0.0) 87 | 88 | # induction frequency 89 | parser.add_argument("--induce_steps", type=int, default=1) 90 | 91 | # setup 92 | parser.add_argument("--setup", type=str, required=True, 93 | choices=["online", "offline"]) 94 | 95 | args = parser.parse_args() 96 | 97 | if args.setup == "online": 98 | assert (args.results_dir is not None) and (args.workflow_path is not None) 99 | online() 100 | elif args.setup == "offline": 101 | assert (args.domain is not None) and (args.subdomain is not None) 102 | offline() 103 | -------------------------------------------------------------------------------- /mind2web/prompt/instruction_abstract.txt: -------------------------------------------------------------------------------- 1 | Given a list of web navigation tasks, your task is to extract the common workflows to solve these tasks. 2 | Each given task contains a natural language instruction, and a series of actions to solve the task. You need to find the repetitive subset of actions across multiple tasks, and extract each of them out as a workflow. 3 | Each workflow should be a commonly-reused sub-routine of the tasks. Do not generate similar or overlapping workflows. Each workflow should have at least two steps. Represent the non-fixed elements (input text, button strings) with descriptive variable names as shown in the example. -------------------------------------------------------------------------------- /mind2web/prompt/instruction_action.txt: -------------------------------------------------------------------------------- 1 | Given a list of web navigation tasks, your task is to extract the common workflows to solve these tasks. 2 | Each given task contains a natural language instruction, and a series of actions to solve the task. 
You need to find the repetitive subset of actions across multiple tasks, and extract each of them out as a workflow. 3 | Each workflow should be a commonly-reused sub-routine of the tasks. Do not generate similar or overlapping workflows. Each workflow should have at least two steps. -------------------------------------------------------------------------------- /mind2web/prompt/one_shot_abstract.txt: -------------------------------------------------------------------------------- 1 | Website: Travel, Airlines, delta 2 | ## Query 1: Find flights from Seattle to New York on June 5th and only show those that can be purchased with miles. 3 | Actions: 4 | [link] From Departure Airport or City Your Origin -> CLICK 5 | [textbox] Origin City or Airport -> TYPE: Seattle 6 | [link] SEA Seattle, WA -> CLICK 7 | [link] To Destination Airport or City Your Destination -> CLICK 8 | [textbox] Destination City or Airport -> TYPE: New York 9 | [link] NYC New York City Area Airports, NY -> CLICK 10 | [combobox] Trip Type:, changes will reload the page -> CLICK 11 | [option] One Way -> CLICK 12 | [button] Depart and Return Calendar Use enter to open, es... -> CLICK 13 | [link] Next -> CLICK 14 | [link] 5 June 2023, Monday -> CLICK 15 | [button] done -> CLICK 16 | [label] Shop with Miles -> CLICK 17 | [button] SUBMIT -> CLICK 18 | 19 | ## Query 2: Find my trip with confirmation number SFTBAO including first and last name Joe Lukeman 20 | Actions: 21 | [tab] MY TRIPS -> CLICK 22 | [combobox] Find Your Trip By -> CLICK 23 | [option] Confirmation Number -> CLICK 24 | [input] -> TYPE: SFTBAO 25 | [input] -> TYPE: Joe 26 | [input] -> TYPE: Lukeman 27 | [button] SUBMIT -> CLICK 28 | 29 | ## Query 3: Find the status of March 25 flights from New York airports to Columbus in Ohio. 30 | Actions: 31 | [tab] FLIGHT STATUS -> CLICK 32 | [button] Search by date required selected as 19 March 202... -> CLICK 33 | [link] 25 March 2023, Saturday -> CLICK 34 | [button] done -> CLICK 35 | [link] Depart City required From -> CLICK 36 | [textbox] Origin City or Airport -> TYPE: New York 37 | [link] NYC New York City Area Airports, NY -> CLICK 38 | [link] Arrival City required To -> CLICK 39 | [textbox] Destination City or Airport -> TYPE: Columbus 40 | [li] CMH -> CLICK 41 | [button] SUBMIT -> CLICK 42 | 43 | ## Query 4: Check all available one way flights for a single passenger from Manhattan to Philadelphia on May 23rd in first class. 44 | Actions: 45 | [link] From Departure Airport or City Your Origin -> CLICK 46 | [textbox] Origin City or Airport -> TYPE: Manhattan 47 | [link] MHK Manhattan Regl, USA -> CLICK 48 | [link] To Destination Airport or City Your Destination -> CLICK 49 | [textbox] Destination City or Airport -> TYPE: Philadelphia 50 | [link] PHL Philadelphia, PA -> CLICK 51 | [combobox] Trip Type:, changes will reload the page -> CLICK 52 | [option] One Way -> CLICK 53 | [button] Depart and Return Calendar Use enter to open, es... -> CLICK 54 | [link] 23 March 2023, Thursday -> CLICK 55 | [button] done -> CLICK 56 | [link] Advanced Search -> CLICK 57 | [combobox] Best Fares For -> CLICK 58 | [option] First Class -> CLICK 59 | [button] SUBMIT -> CLICK 60 | 61 | Extracted Workflows: 62 | # enter_flight_locations 63 | Given that you are on the Delta flight booking page, this workflow enters the departure and destination city/airport for your flight. 
64 | [link] {link to enter departure city} -> CLICK 65 | [textbox] {textbox to input departure city} -> TYPE: {your-origin-city} 66 | [link] {best-popup-option} -> CLICK 67 | [link] {link to enter destination city} -> CLICK 68 | [textbox] {textbox to enter destination city} -> TYPE: {your-destination-city} 69 | [link] {best-popup-option} -> CLICK 70 | 71 | # select_oneway_trip 72 | Given that you are on the Delta flight booking page, this workflow changes the flight to be one-way. 73 | [combobox] {option to select trip type} -> CLICK 74 | [option] One Way -> CLICK 75 | 76 | # select_date_for_travel 77 | Given that you are on the Delta flight booking page, this workflow selects the travel date. 78 | [button] {calendar to select flight dates} -> CLICK 79 | [link] {travel-date} -> CLICK 80 | [button] done -> CLICK 81 | 82 | # find_trip 83 | Given that you are on the Delta flight searching page, this workflow finds a trip with the confirmation number and passenger name. 84 | [tab] MY TRIPS -> CLICK 85 | [combobox] {button to instantiate search} -> CLICK 86 | [option] Confirmation Number -> CLICK 87 | [input] -> TYPE: {confirmation-number} 88 | [input] -> TYPE: {passenger-name} 89 | [button] SUBMIT -> CLICK -------------------------------------------------------------------------------- /mind2web/prompt/one_shot_action.txt: -------------------------------------------------------------------------------- 1 | Website: Travel, Airlines, delta 2 | ## Query 1: Find flights from Seattle to New York on June 5th and only show those that can be purchased with miles. 3 | Actions: 4 | [link] From Departure Airport or City Your Origin -> CLICK 5 | [textbox] Origin City or Airport -> TYPE: Seattle 6 | [link] SEA Seattle, WA -> CLICK 7 | [link] To Destination Airport or City Your Destination -> CLICK 8 | [textbox] Destination City or Airport -> TYPE: New York 9 | [link] NYC New York City Area Airports, NY -> CLICK 10 | [combobox] Trip Type:, changes will reload the page -> CLICK 11 | [option] One Way -> CLICK 12 | [button] Depart and Return Calendar Use enter to open, es... -> CLICK 13 | [link] Next -> CLICK 14 | [link] 5 June 2023, Monday -> CLICK 15 | [button] done -> CLICK 16 | [label] Shop with Miles -> CLICK 17 | [button] SUBMIT -> CLICK 18 | 19 | ## Query 2: Find my trip with confirmation number SFTBAO including first and last name Joe Lukeman 20 | Actions: 21 | [tab] MY TRIPS -> CLICK 22 | [combobox] Find Your Trip By -> CLICK 23 | [option] Confirmation Number -> CLICK 24 | [input] -> TYPE: SFTBAO 25 | [input] -> TYPE: Joe 26 | [input] -> TYPE: Lukeman 27 | [button] SUBMIT -> CLICK 28 | 29 | ## Query 3: Find the status of March 25 flights from New York airports to Columbus in Ohio. 30 | Actions: 31 | [tab] FLIGHT STATUS -> CLICK 32 | [button] Search by date required selected as 19 March 202... -> CLICK 33 | [link] 25 March 2023, Saturday -> CLICK 34 | [button] done -> CLICK 35 | [link] Depart City required From -> CLICK 36 | [textbox] Origin City or Airport -> TYPE: New York 37 | [link] NYC New York City Area Airports, NY -> CLICK 38 | [link] Arrival City required To -> CLICK 39 | [textbox] Destination City or Airport -> TYPE: Columbus 40 | [li] CMH -> CLICK 41 | [button] SUBMIT -> CLICK 42 | 43 | ## Query 4: Check all available one way flights for a single passenger from Manhattan to Philadelphia on May 23rd in first class. 
44 | Actions: 45 | [link] From Departure Airport or City Your Origin -> CLICK 46 | [textbox] Origin City or Airport -> TYPE: Manhattan 47 | [link] MHK Manhattan Regl, USA -> CLICK 48 | [link] To Destination Airport or City Your Destination -> CLICK 49 | [textbox] Destination City or Airport -> TYPE: Philadelphia 50 | [link] PHL Philadelphia, PA -> CLICK 51 | [combobox] Trip Type:, changes will reload the page -> CLICK 52 | [option] One Way -> CLICK 53 | [button] Depart and Return Calendar Use enter to open, es... -> CLICK 54 | [link] 23 March 2023, Thursday -> CLICK 55 | [button] done -> CLICK 56 | [link] Advanced Search -> CLICK 57 | [combobox] Best Fares For -> CLICK 58 | [option] First Class -> CLICK 59 | [button] SUBMIT -> CLICK 60 | 61 | 62 | Summary Workflows: 63 | ## enter_flight_locations 64 | Given that you are on the Delta flight booking page, this workflow enters the departure and destination city/airport for your flight. 65 | [link] From Departure Airport or City Your Origin -> CLICK 66 | [textbox] Origin City or Airport -> TYPE: {your-origin-city} 67 | [link] {best-popup-option} -> CLICK 68 | [link] To Destination Airport or City Your Destination -> CLICK 69 | [textbox] Destination City or Airport -> TYPE: {your-destination-city} 70 | [link] {best-popup-option} -> CLICK 71 | 72 | ## select_oneway_trip 73 | Given that you are on the Delta flight booking page, this workflow changes the flight to be one-way. 74 | [combobox] Trip Type:, changes will reload the page -> CLICK 75 | [option] One Way -> CLICK 76 | 77 | ## select_date_for_travel 78 | Given that you are on the Delta flight booking page, this workflow selects the travel date. 79 | [button] Depart and Return Calendar Use enter to open, es... -> CLICK 80 | [link] {travel-date} -> CLICK 81 | [button] done -> CLICK 82 | 83 | ## find_trip 84 | Given that you are on the Delta flight searching page, this workflow finds a trip with the confirmation number and passenger name. 
85 | [tab] MY TRIPS -> CLICK 86 | [combobox] Find Your Trip By -> CLICK 87 | [option] Confirmation Number -> CLICK 88 | [input] -> TYPE: {confirmation-number} 89 | [input] -> TYPE: {passenger-name} 90 | [button] SUBMIT -> CLICK -------------------------------------------------------------------------------- /mind2web/requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | numpy 3 | openai 4 | backoff==2.2.1 5 | tiktoken 6 | lxml==4.9.3 7 | matplotlib -------------------------------------------------------------------------------- /mind2web/results/calc_score.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import matplotlib.pyplot as plt 5 | 6 | def get_average(score_list: list[float], percentage: bool = False) -> float: 7 | score = sum(score_list) / len(score_list) 8 | return score * 100 if percentage else score 9 | 10 | 11 | def main(): 12 | files = os.listdir(args.results_dir) 13 | file_paths = [os.path.join(args.results_dir, f) for f in files] 14 | ele_acc, act_f1, step_sr, sr = [], [], [], [] 15 | for fp in file_paths: 16 | res = json.load(open(fp, 'r'))[-1] 17 | ele_acc.append(get_average(res["element_acc"])) 18 | act_f1.append(get_average(res["action_f1"])) 19 | step_sr.append(get_average(res["step_success"])) 20 | sr.append(get_average(res["success"])) 21 | 22 | print(f"Element Acc: {get_average(ele_acc, True):5.1f}") 23 | print(f"Action F1 : {get_average(act_f1, True):5.1f}") 24 | print(f"Step SR : {get_average(step_sr, True):5.1f}") 25 | print(f"SR : {get_average(sr, True):5.1f}") 26 | 27 | # accumulative step success rate 28 | n = len(step_sr) 29 | x = [i+1 for i in range(n)] 30 | asr = [get_average(step_sr[:i+1]) for i in range(n)] 31 | plt.plot(x, asr) 32 | 33 | # moving average 34 | # window_size = 5 35 | # x, mavg = [], [] 36 | # for i in range(n-window_size+1): 37 | # x.append(i) 38 | # mavg.append(get_average(step_sr[i:i+window_size])) 39 | # plt.plot(x, mavg) 40 | 41 | if args.viz_path is not None: 42 | plt.savefig(args.viz_path) 43 | 44 | 45 | if __name__ == "__main__": 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument("--results_dir", type=str, required=True) 48 | parser.add_argument("--viz_path", type=str, default=None) 49 | args = parser.parse_args() 50 | 51 | main() 52 | -------------------------------------------------------------------------------- /mind2web/run_mind2web.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from tqdm import tqdm 4 | from memory import eval_sample 5 | from utils.data import load_json, add_scores 6 | 7 | import logging 8 | logger = logging.getLogger("atm") 9 | logger.setLevel(logging.INFO) 10 | handler = logging.StreamHandler() 11 | handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) 12 | logger.addHandler(handler) 13 | 14 | 15 | def main(): 16 | examples = load_json(args.data_dir, args.benchmark) 17 | examples = [s for s in examples if s["website"] == args.website] 18 | print(f"Filtering down to #{len(examples)} examples on website [{args.website}]") 19 | examples = add_scores(examples) # add prediction scores and ranks to elements 20 | 21 | if args.end_idx is None: 22 | args.end_idx = len(examples) 23 | for i in tqdm(range(args.start_idx, args.end_idx)): 24 | if args.mode == "memory": 25 | eval_sample(i, args, examples[i]) 26 | elif args.mode == "action": 27 | raise 
NotImplementedError 28 | else: 29 | raise ValueError(f"Unsupported mode: {args.mode}") 30 | 31 | 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument("--data_dir", type=str, default="data") 35 | parser.add_argument("--benchmark", type=str, default="test_task", 36 | choices=["test_task", "test_website", "test_domain", "train"]) 37 | parser.add_argument("--memory_path", type=str, default="data/memory") 38 | parser.add_argument("--log_dir", type=str, default="results") 39 | 40 | # model 41 | parser.add_argument("--model", type=str, default="gpt-4") 42 | parser.add_argument("--temperature", type=float, default=0.0) 43 | 44 | # env context 45 | parser.add_argument("--previous_top_k_elements", type=int, default=3) 46 | parser.add_argument("--top_k_elements", type=int, default=5) 47 | parser.add_argument("--retrieve_top_k", type=int, default=1) 48 | 49 | # workflow 50 | parser.add_argument("--website", type=str, required=True) 51 | parser.add_argument("--domain", type=str, default=None) 52 | parser.add_argument("--subdomain", type=str, default=None) 53 | parser.add_argument("--workflow_path", type=str, required=True) 54 | parser.add_argument("--suffix", type=str, default="workflow") 55 | 56 | # ablation 57 | parser.add_argument("--mode", type=str, default="memory", choices=["memory", "action"]) 58 | parser.add_argument("--start_idx", type=int, default=0, help="Select example index.") 59 | parser.add_argument("--end_idx", type=int, default=None, help="Select example index.") 60 | 61 | args = parser.parse_args() 62 | 63 | # sanity check 64 | if not os.path.exists(args.workflow_path): open(args.workflow_path, 'w').close() 65 | if args.retrieve_top_k != 1: print(f"Suggest setting `retrieve_top_k` to 1; currently {args.retrieve_top_k}") 66 | 67 | main() 68 | -------------------------------------------------------------------------------- /mind2web/utils/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | 5 | # %% load data 6 | def load_json(data_dir, folder_name): 7 | folder_path = os.path.join(data_dir, folder_name) 8 | print(f"Data path: {folder_path}") 9 | data_paths = [ 10 | os.path.join(folder_path, file) 11 | for file in os.listdir(folder_path) 12 | if file.endswith(".json") 13 | ] 14 | data_paths = sorted(data_paths, key=lambda x: int(x.split("_")[-1].split(".")[0])) 15 | 16 | # Construct trajectory dataset 17 | samples = [] 18 | for data_path in data_paths: 19 | with open(data_path, "r") as f: 20 | samples.extend(json.load(f)) 21 | print("# of samples:", len(samples)) 22 | 23 | return samples 24 | 25 | 26 | def add_scores( 27 | examples: list[dict], candidate_results: dict = None, 28 | score_path: str = "data/scores_all_data.pkl" 29 | ): 30 | """Add prediction scores and ranks to candidate elements.""" 31 | if candidate_results is None: 32 | with open(score_path, "rb") as f: 33 | candidate_results = pickle.load(f) 34 | 35 | for sample in examples: 36 | for s, act_repr in zip(sample["actions"], sample["action_reprs"]): 37 | sample_id = f"{sample['annotation_id']}_{s['action_uid']}" 38 | for candidates in [s["pos_candidates"], s["neg_candidates"]]: 39 | for candidate in candidates: 40 | candidate_id = candidate["backend_node_id"] 41 | candidate["score"] = candidate_results["scores"][sample_id][candidate_id] 42 | candidate["rank"] = candidate_results["ranks"][sample_id][candidate_id] 43 | 44 | return examples 45 | 46 | 47 | # %% workflow 
induction 48 | def format_examples(examples: list[dict], prefix: str = None, suffix: str = None) -> str: 49 | lines = [] 50 | for i, ex in enumerate(examples): 51 | lines.append(f"Query #{i+1}: {ex['confirmed_task']}") 52 | lines.append("Actions and Environments:") 53 | lines.extend(ex["action_reprs"]) 54 | lines.append("") 55 | prompt = '\n'.join(lines) 56 | if prefix is not None: 57 | prompt = prefix + '\n' + prompt 58 | if suffix is not None: 59 | prompt += '\n\n' + suffix 60 | return prompt 61 | 62 | 63 | 64 | # %% model generation 65 | def is_website_header(block: str, website: str) -> bool: 66 | lines = block.strip().split('\n') 67 | if len(lines) > 1: return False 68 | text = lines[0].strip() 69 | if text.startswith("#") and text.lower().endswith(website): 70 | return True 71 | return False 72 | 73 | def filter_workflows(text: str, website: str) -> str: 74 | blocks = text.split('\n\n') 75 | for i,b in enumerate(blocks): 76 | if is_website_header(b, website): 77 | blocks = blocks[i+1: ] 78 | break 79 | 80 | for i,b in enumerate(blocks): 81 | if is_website_header(b, "delta"): 82 | blocks = blocks[: i] 83 | break 84 | 85 | blocks = [b for b in blocks if "delta" not in b.lower()] 86 | return '\n\n'.join(blocks) -------------------------------------------------------------------------------- /mind2web/utils/env.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import re 3 | import json 4 | import os 5 | import string 6 | import ast 7 | from lxml import etree 8 | 9 | 10 | def get_target_obs(dom_tree, target_element_ids): 11 | pruned_tree = prune_tree(dom_tree, target_element_ids) 12 | tree_repr, _ = get_tree_repr(pruned_tree, id_mapping={}, keep_html_brackets=True) 13 | 14 | return tree_repr 15 | 16 | 17 | def get_target_act(example, target_element_id): 18 | action_op = example["operation"]["op"] 19 | action_value = example["operation"]["value"] 20 | target_action = f"{action_op} [{target_element_id}]" 21 | if action_op != "CLICK": 22 | target_action += f" [{action_value}]" 23 | 24 | return target_action 25 | 26 | 27 | def parse_act_str(act_str): 28 | # Compile the regular expression pattern 29 | pattern = re.compile(r"(?:^|\s)(CLICK|SELECT|TYPE)?\s?\[(.+?)\](?:\s\[(.+?)\])?") 30 | # Search for the pattern in the string 31 | match = pattern.search(act_str) 32 | if match: 33 | # Extract the matching groups 34 | action_op = match.group(1) # This will be None if not in the list 35 | target_element_id = match.group(2) 36 | action_value = match.group(3) # This will be None if not present 37 | return action_op, target_element_id, action_value 38 | else: 39 | return None, None, None 40 | 41 | 42 | 43 | def find_node_id_by_text(example, text): 44 | dom_tree = etree.fromstring(example["cleaned_html"]) 45 | node = dom_tree.xpath(f'//*[@aria_label="{text}"]')[0] 46 | return node.attrib["backend_node_id"] 47 | 48 | 49 | def parse_act_str_workflow(act_str, example): 50 | # Compile the regular expression pattern 51 | pattern = re.compile(r"(?:^|\s)(CLICK|SELECT|TYPE)?\s?\[(.+?)\](?:\s\[(.+?)\])?") 52 | # Search for the pattern in the string 53 | match = pattern.search(act_str) 54 | if match: 55 | # Extract the matching groups 56 | action_op = match.group(1) # This will be None if not in the list 57 | target_element_id = match.group(2) 58 | action_value = match.group(3) # This will be None if not present 59 | yield action_op, target_element_id, action_value 60 | elif ('(' in act_str) and (')' in act_str): 61 | print("Action Str: ", act_str) 62 | 
from data.workflow.code import WORKFLOW_DICT 63 | s, e = act_str.index('('), act_str.index(')') 64 | wkfl = act_str[: s].strip() 65 | print("Workflow: ", wkfl) 66 | if wkfl not in WORKFLOW_DICT: yield None, None, None; return 67 | args = [arg.value for arg in ast.parse(act_str).body[0].value.args] 68 | if args == []: 69 | args = act_str[s: e].split(',') 70 | for i in range(len(args)): 71 | if '=' in args[i]: args[i] = args[i].split('=')[1] 72 | print("Args: ", args) 73 | # args = [arg.strip() for arg in act_str[s+1: e].split(',')] 74 | steps = WORKFLOW_DICT[wkfl](*args) 75 | for step in steps: 76 | print("Org Step: ", step) 77 | element_name, op_name = step.split('->') 78 | s = element_name.index(']') + 1 79 | element_name = element_name[s:].strip() 80 | if ':' in op_name: 81 | op_name, arg = [item.strip() for item in op_name.split(':')] 82 | else: 83 | arg = None 84 | try: 85 | element_id = find_node_id_by_text(example, element_name) 86 | step = f"{op_name} [{element_id}]" 87 | if arg is not None: 88 | step += f" [{arg}]" 89 | step = step.replace('"', '') 90 | print("New Step: ", step) 91 | yield parse_act_str(step) 92 | except Exception: 93 | yield None, None, None 94 | 95 | else: 96 | yield None, None, None 97 | 98 | 99 | def construct_act_str(op, val): 100 | if op is None: 101 | if val is None: 102 | return " " 103 | return " " + val 104 | if op == "CLICK" or val is None: 105 | return op + " " 106 | return f"{op} {val}" 107 | 108 | 109 | def get_target_obs_and_act(example): 110 | if len(example["pos_candidates"]) == 0: 111 | # Simplify the raw_html if pos_candidates is empty (not in the cleaned html) 112 | dom_tree = etree.fromstring(example["raw_html"]) 113 | gt_element = dom_tree.xpath( 114 | f"//*[@data_pw_testid_buckeye='{example['action_uid']}']" 115 | ) 116 | element_id = gt_element[0].get("backend_node_id") 117 | raw_obs = get_target_obs(dom_tree, [element_id]) 118 | # Find the start index of the target element using the element ID 119 | start_idx = raw_obs.find(f"id={element_id}") 120 | # Find the start tag for the target element 121 | start_tag_idx = raw_obs.rfind("<", 0, start_idx) 122 | end_tag_idx = raw_obs.find(">", start_idx) 123 | # Extract the tag name 124 | tag_name = raw_obs[start_tag_idx + 1 : end_tag_idx].split()[0] 125 | # Initialize count for open and close tags 126 | open_count = 0 127 | close_count = 0 128 | search_idx = start_tag_idx 129 | while True: 130 | # Find the next open or close tag of the same type 131 | next_open_tag = raw_obs.find(f"<{tag_name}", search_idx) 132 | next_close_tag = raw_obs.find(f"</{tag_name}>", search_idx) 133 | # No more tags found, break 134 | if next_open_tag == -1 and next_close_tag == -1: 135 | break 136 | # Decide whether the next tag is an open or close tag 137 | if next_open_tag != -1 and ( 138 | next_open_tag < next_close_tag or next_close_tag == -1 139 | ): 140 | open_count += 1 141 | search_idx = raw_obs.find(">", next_open_tag) + 1 142 | else: 143 | close_count += 1 144 | search_idx = next_close_tag + len(f"</{tag_name}>") 145 | # If we've closed all open tags, break 146 | if open_count == close_count: 147 | break 148 | # Extract the target element 149 | o = f"<html> {raw_obs[start_tag_idx:search_idx]} </html>" 150 | a = get_target_act(example, element_id) 151 | else: 152 | dom_tree = etree.fromstring(example["cleaned_html"]) 153 | element_id = example["pos_candidates"][0]["backend_node_id"] 154 | o = get_target_obs(dom_tree, [element_id]) 155 | a = get_target_act(example, element_id) 156 | 157 | return o, a 158 | 159 | 160 | def get_top_k_obs(s: dict, top_k: int, use_raw: bool
= True) -> tuple[str, str]: 161 | # Find one positive candidate (it can be zero) 162 | pos_candidates = s["pos_candidates"] 163 | pos_ids = [c["backend_node_id"] for c in pos_candidates][:1] 164 | # Find top_k - 1 negative candidates 165 | neg_candidates = s["neg_candidates"] 166 | neg_candidates = sorted(neg_candidates, key=lambda c: c["rank"])[: top_k - 1] 167 | neg_ids = [c["backend_node_id"] for c in neg_candidates] 168 | # Prune html with all candidates 169 | all_candidates = pos_ids + neg_ids 170 | obs = get_target_obs(etree.fromstring(s["cleaned_html"]), all_candidates) 171 | # If there is no positive candidate in cleaned_html, get it from raw_html 172 | if len(s["pos_candidates"]) == 0: 173 | assert use_raw 174 | # Simplify the raw_html if pos_candidates is empty (not in the cleaned html) 175 | dom_tree = etree.fromstring(s["raw_html"]) 176 | gt_element = dom_tree.xpath(f"//*[@data_pw_testid_buckeye='{s['action_uid']}']") 177 | element_id = gt_element[0].get("backend_node_id") 178 | raw_obs = get_target_obs(dom_tree, [element_id]) 179 | # Find the start index of the target element using the element ID 180 | start_idx = raw_obs.find(f"id={element_id}") 181 | # Find the start tag for the target element 182 | start_tag_idx = raw_obs.rfind("<", 0, start_idx) 183 | end_tag_idx = raw_obs.find(">", start_idx) 184 | # Extract the tag name 185 | tag_name = raw_obs[start_tag_idx + 1 : end_tag_idx].split()[0] 186 | # Initialize count for open and close tags 187 | open_count = 0 188 | close_count = 0 189 | search_idx = start_tag_idx 190 | while True: 191 | # Find the next open or close tag of the same type 192 | next_open_tag = raw_obs.find(f"<{tag_name}", search_idx) 193 | next_close_tag = raw_obs.find(f"</{tag_name}>", search_idx) 194 | # No more tags found, break 195 | if next_open_tag == -1 and next_close_tag == -1: 196 | break 197 | # Decide whether the next tag is an open or close tag 198 | if next_open_tag != -1 and ( 199 | next_open_tag < next_close_tag or next_close_tag == -1 200 | ): 201 | open_count += 1 202 | search_idx = raw_obs.find(">", next_open_tag) + 1 203 | else: 204 | close_count += 1 205 | search_idx = next_close_tag + len(f"</{tag_name}>") 206 | # If we've closed all open tags, break 207 | if open_count == close_count: 208 | break 209 | # Extract the target element 210 | target_element = raw_obs[start_tag_idx:search_idx] 211 | obs = obs.replace("</html>", f"{target_element} </html>") 212 | 213 | return obs, all_candidates 214 | 215 | 216 | def calculate_f1(pred, label): 217 | pred = set(pred.strip().split()) 218 | label = set(label.strip().split()) 219 | # remove punctuation 220 | pred = set([x for x in pred if x not in string.punctuation]) 221 | label = set([x for x in label if x not in string.punctuation]) 222 | if len(pred) == 0 and len(label) == 0: 223 | return 1 224 | if len(pred) == 0 or len(label) == 0: 225 | return 0 226 | 227 | tp = len(pred & label) 228 | fp = len(pred - label) 229 | fn = len(label - pred) 230 | precision = tp / (tp + fp) 231 | recall = tp / (tp + fn) 232 | if precision == 0 or recall == 0: 233 | return 0 234 | f1 = 2 * precision * recall / (precision + recall) 235 | return f1 236 | 237 | 238 | def get_descendants(node, max_depth, current_depth=0): 239 | if current_depth > max_depth: 240 | return [] 241 | 242 | descendants = [] 243 | for child in node: 244 | descendants.append(child) 245 | descendants.extend(get_descendants(child, max_depth, current_depth + 1)) 246 | 247 | return descendants 248 | 249 | 250 | def get_attribute_repr(node, max_value_length=5, max_length=20): 251 | # get
attribute values in order 252 | attr_values_set = set() 253 | attr_values = "" 254 | for attr in [ 255 | "role", 256 | "aria_role", 257 | "type", 258 | "alt", 259 | "aria_description", 260 | "aria_label", 261 | "label", 262 | "title", 263 | "name", 264 | "text_value", 265 | "value", 266 | "placeholder", 267 | "input_checked", 268 | "input_value", 269 | "option_selected", 270 | "class", 271 | ]: 272 | if attr in node.attrib and node.attrib[attr] is not None: 273 | value = node.attrib[attr].lower() 274 | # less meaningful values 275 | if value in [ 276 | "hidden", 277 | "none", 278 | "presentation", 279 | "null", 280 | "undefined", 281 | ] or value.startswith("http"): 282 | continue 283 | value = value.split() 284 | value = " ".join([v for v in value if len(v) < 15][:max_value_length]) 285 | if value and value not in attr_values_set: 286 | attr_values_set.add(value) 287 | attr_values += value + " " 288 | uid = node.attrib.get("backend_node_id", "") 289 | # clear all attributes 290 | node.attrib.clear() 291 | if uid: 292 | node.attrib["id"] = uid 293 | # add meta attribute 294 | if attr_values: 295 | node.attrib["meta"] = " ".join(attr_values.split()[:max_length]) 296 | 297 | 298 | 299 | def prune_tree( 300 | dom_tree, 301 | candidate_set, 302 | max_depth=5, 303 | max_children=50, 304 | max_sibling=3, 305 | ): 306 | nodes_to_keep = set() 307 | for candidate_id in candidate_set: 308 | candidate_node = dom_tree.xpath(f'//*[@backend_node_id="{candidate_id}"]')[0] 309 | nodes_to_keep.add(candidate_node.attrib["backend_node_id"]) 310 | # get all ancestors 311 | nodes_to_keep.update( 312 | [ 313 | x.attrib.get("backend_node_id", "") 314 | for x in candidate_node.xpath("ancestor::*") 315 | ] 316 | ) 317 | # get descendants with max depth 318 | nodes_to_keep.update( 319 | [ 320 | x.attrib.get("backend_node_id", "") 321 | for x in get_descendants(candidate_node, max_depth) 322 | ][:max_children] 323 | ) 324 | # get siblings within range 325 | parent = candidate_node.getparent() 326 | if parent is not None: 327 | siblings = [x for x in parent.getchildren() if x.tag != "text"] 328 | idx_in_sibling = siblings.index(candidate_node) 329 | nodes_to_keep.update( 330 | [ 331 | x.attrib.get("backend_node_id", "") 332 | for x in siblings[ 333 | max(0, idx_in_sibling - max_sibling) : idx_in_sibling 334 | + max_sibling 335 | + 1 336 | ] 337 | ] 338 | ) 339 | # clone the tree 340 | new_tree = copy.deepcopy(dom_tree) 341 | # remove nodes not in nodes_to_keep 342 | for node in new_tree.xpath("//*")[::-1]: 343 | if node.tag != "text": 344 | is_keep = node.attrib.get("backend_node_id", "") in nodes_to_keep 345 | is_candidate = node.attrib.get("backend_node_id", "") in candidate_set 346 | else: 347 | is_keep = ( 348 | node.getparent().attrib.get("backend_node_id", "") in nodes_to_keep 349 | ) 350 | is_candidate = ( 351 | node.getparent().attrib.get("backend_node_id", "") in candidate_set 352 | ) 353 | if not is_keep and node.getparent() is not None: 354 | node.getparent().remove(node) 355 | else: 356 | if not is_candidate or node.tag == "text": 357 | node.attrib.pop("backend_node_id", None) 358 | if ( 359 | len(node.attrib) == 0 360 | and not any([x.tag == "text" for x in node.getchildren()]) 361 | and node.getparent() is not None 362 | and node.tag != "text" 363 | and len(node.getchildren()) <= 1 364 | ): 365 | # insert all children into parent 366 | for child in node.getchildren(): 367 | node.addprevious(child) 368 | node.getparent().remove(node) 369 | return new_tree 370 | 371 | 372 | def get_tree_repr( 373 | tree,
max_value_length=5, max_length=20, id_mapping={}, keep_html_brackets=False 374 | ): 375 | if isinstance(tree, str): 376 | tree = etree.fromstring(tree) 377 | else: 378 | tree = copy.deepcopy(tree) 379 | for node in tree.xpath("//*"): 380 | if node.tag != "text": 381 | if "backend_node_id" in node.attrib: 382 | if node.attrib["backend_node_id"] not in id_mapping: 383 | id_mapping[node.attrib["backend_node_id"]] = len(id_mapping) 384 | # node.attrib["backend_node_id"] = str( 385 | # id_mapping[node.attrib["backend_node_id"]] 386 | # ) 387 | get_attribute_repr(node, max_value_length, max_length) 388 | else: 389 | node.text = " ".join(node.text.split()[:max_length]) 390 | tree_repr = etree.tostring(tree, encoding="unicode") 391 | 392 | tree_repr = tree_repr.replace('"', " ") 393 | tree_repr = ( 394 | tree_repr.replace("meta= ", "").replace("id= ", "id=").replace(" >", ">") 395 | ) 396 | tree_repr = re.sub(r"<text>(.*?)</text>", r"\1", tree_repr) 397 | if not keep_html_brackets: 398 | tree_repr = tree_repr.replace("/>", "$/$>") 399 | tree_repr = re.sub(r"</(.+?)>", r")", tree_repr) 400 | tree_repr = re.sub(r"<(.+?)>", r"(\1", tree_repr) 401 | tree_repr = tree_repr.replace("$/$", ")") 402 | 403 | html_escape_table = [ 404 | ("&quot;", '"'), 405 | ("&amp;", "&"), 406 | ("&lt;", "<"), 407 | ("&gt;", ">"), 408 | ("&nbsp;", " "), 409 | ("&ndash;", "-"), 410 | ("&rsquo;", "'"), 411 | ("&lsquo;", "'"), 412 | ("&ldquo;", '"'), 413 | ("&rdquo;", '"'), 414 | ("&#39;", "'"), 415 | ("&#40;", "("), 416 | ("&#41;", ")"), 417 | ] 418 | for k, v in html_escape_table: 419 | tree_repr = tree_repr.replace(k, v) 420 | tree_repr = re.sub(r"\s+", " ", tree_repr).strip() 421 | 422 | return tree_repr, id_mapping 423 | -------------------------------------------------------------------------------- /mind2web/utils/llm.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | import os 4 | import inspect 5 | import tiktoken 6 | 7 | logger = logging.getLogger("main") 8 | 9 | import openai 10 | openai.api_key = os.environ["OPENAI_API_KEY"] 11 | from openai import OpenAI 12 | client = OpenAI() 13 | 14 | 15 | def num_tokens_from_messages(messages, model): 16 | """Return the number of tokens used by a list of messages. 17 | Borrowed from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb 18 | """ 19 | try: 20 | encoding = tiktoken.encoding_for_model(model) 21 | except KeyError: 22 | # print("Warning: model not found. Using cl100k_base encoding.") 23 | encoding = tiktoken.get_encoding("cl100k_base") 24 | if model in { 25 | "GPT-3-5-turbo-chat", 26 | "GPT-3-5-16k-turbo-chat", 27 | "gpt-3.5-16k-turbo-chat", 28 | "gpt-3.5-turbo-0613", 29 | "gpt-3.5-turbo-16k-0613", 30 | "gpt-3.5-turbo-1106", 31 | "gpt-4-0314", 32 | "gpt-4-32k-0314", 33 | "gpt-4-0613", 34 | "gpt-4-32k-0613", 35 | "gpt-4o", 36 | }: 37 | tokens_per_message = 3 38 | tokens_per_name = 1 39 | elif model == "gpt-3.5-turbo-0301": 40 | tokens_per_message = ( 41 | 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n 42 | ) 43 | tokens_per_name = -1 # if there's a name, the role is omitted 44 | else: 45 | raise NotImplementedError( 46 | f"""num_tokens_from_messages() is not implemented for model {model}.
See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""" 47 | ) 48 | num_tokens = 0 49 | for message in messages: 50 | num_tokens += tokens_per_message 51 | for key, value in message.items(): 52 | num_tokens += len(encoding.encode(value)) 53 | if key == "name": 54 | num_tokens += tokens_per_name 55 | num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> 56 | return num_tokens 57 | 58 | 59 | MAX_TOKENS = { 60 | "GPT-3-5-turbo-chat": 4097, 61 | "gpt-3.5-turbo-0301": 4097, 62 | "gpt-3.5-turbo-0613": 4097, 63 | "gpt-3.5-turbo-16k-0613": 16385, 64 | "gpt-3.5-turbo-1106": 16385, 65 | "gpt-4": 8192, 66 | "gpt-4o": 16385, 67 | "GPT-3-5-16k-turbo-chat": 16385, 68 | "gpt-4-32k": 32000, 69 | } 70 | 71 | 72 | def get_mode(model: str) -> str: 73 | """Check if the model is a chat model.""" 74 | if model in [ 75 | "GPT-3-5-turbo-chat", 76 | "GPT-3-5-16k-turbo-chat", 77 | "gpt-3.5-16k-turbo-chat", 78 | "gpt-3.5-turbo-0301", 79 | "gpt-3.5-turbo-0613", 80 | "gpt-3.5-turbo-1106", 81 | "gpt-3.5-turbo-16k-0613", 82 | "gpt-4-0314", 83 | "gpt-4-32k-0314", 84 | "gpt-4-0613", 85 | "gpt-4-32k-0613", 86 | "gpt-4", 87 | "gpt-4o", 88 | "gpt-4-32k-0613", 89 | ]: 90 | return "chat" 91 | elif model in [ 92 | "davinci-002", 93 | "gpt-3.5-turbo-instruct-0914", 94 | ]: 95 | return "completion" 96 | else: 97 | raise ValueError(f"Unknown model: {model}") 98 | 99 | 100 | # @backoff.on_exception( 101 | # backoff.constant, 102 | # (APIError, RateLimitError, APIConnectionError), 103 | # interval=10, 104 | # ) 105 | def generate_response( 106 | messages: list[dict[str, str]], 107 | model: str, 108 | temperature: float, 109 | stop_tokens: list[str] | None = None, 110 | use_tools: bool = False, 111 | ) -> tuple[str, dict[str, int]]: 112 | """Send a request to the OpenAI API.""" 113 | 114 | logger.info( 115 | f"Send a request to the language model from {inspect.stack()[1].function}" 116 | ) 117 | gen_kwargs = {} 118 | 119 | if get_mode(model) == "chat": 120 | response = client.chat.completions.create( 121 | model=model, 122 | messages=messages, 123 | temperature=temperature, 124 | stop=stop_tokens if stop_tokens else None, 125 | **gen_kwargs 126 | ) 127 | message = response.choices[0].message.content 128 | else: 129 | prompt = "\n\n".join(m["content"] for m in messages) + "\n\n" 130 | response = openai.Completion.create( 131 | prompt=prompt, 132 | engine=model, 133 | temperature=temperature, 134 | stop=stop_tokens if stop_tokens else None, 135 | ) 136 | message = response["choices"][0]["text"] 137 | info = { 138 | "prompt_tokens": response.usage.prompt_tokens, 139 | "completion_tokens": response.usage.completion_tokens, 140 | "total_tokens": response.usage.total_tokens, 141 | } 142 | if message is None: message = "" 143 | 144 | return message, info 145 | 146 | 147 | def extract_from_response(response: str, backtick="```") -> str: 148 | if backtick == "```": 149 | # Matches anything between ```\n and \n``` 150 | pattern = r"```(?:[a-zA-Z]*)\n?(.*?)\n?```" 151 | elif backtick == "`": 152 | pattern = r"`(.*?)`" 153 | else: 154 | raise ValueError(f"Unknown backtick: {backtick}") 155 | match = re.search( 156 | pattern, response, re.DOTALL 157 | ) # re.DOTALL makes . 
match also newlines 158 | if match: 159 | extracted_string = match.group(1) 160 | else: 161 | extracted_string = "" 162 | 163 | return extracted_string 164 | -------------------------------------------------------------------------------- /mind2web/workflow/retrieve.py: -------------------------------------------------------------------------------- 1 | """Retrieve workflows given a query.""" 2 | 3 | import os 4 | import json 5 | import random 6 | import argparse 7 | 8 | from langchain.vectorstores import FAISS 9 | from langchain.embeddings.openai import OpenAIEmbeddings 10 | 11 | # %% load test examples 12 | def get_examples(data_dir: str, website: str = None) -> list[dict]: 13 | print("Start loading data files...") 14 | paths = [os.path.join(data_dir, f) for f in os.listdir(data_dir)] 15 | examples = [] 16 | for p in paths: 17 | data = json.load(open(p, 'r')) 18 | if website is not None: 19 | data = [ex for ex in data if website==ex["website"]] 20 | examples.extend(data) 21 | print(f"Collected {len(examples)} examples about website {website}") 22 | return examples 23 | 24 | 25 | # %% load workflows 26 | def clean_workflow_name(name: str) -> str: 27 | if ':' in name: 28 | name = name[name.index(':')+1:].strip() 29 | if '`' in name: 30 | s = name.index('`') 31 | name = name[s+1: ] 32 | if '`' in name: 33 | e = name.index('`') 34 | name = name[: e] 35 | return f"## {name}" 36 | 37 | def load_workflows(path: str) -> list[dict]: 38 | """Load workflow blocks in the given file path.""" 39 | website = path.split('/')[-1].split('.')[0].split('_')[0] 40 | blocks = open(path, 'r').read().split('\n\n') 41 | 42 | def check_workflow(text: str) -> dict | None: 43 | lines = text.strip().split('\n') 44 | if len(lines) < 4: return None 45 | name = clean_workflow_name(lines[0].lstrip('#').strip()) 46 | docstr = lines[1].strip() 47 | return { 48 | "website": website, 49 | "name": name, "docstring": docstr, 50 | "content": '\n'.join([name] + lines[1: ]) 51 | } 52 | 53 | workflows = [check_workflow(b) for b in blocks] 54 | workflows = [w for w in workflows if w is not None] 55 | return workflows 56 | 57 | # %% retrieve workflows 58 | def build_memory(workflows: list[dict], memory_path: str): 59 | embedding = OpenAIEmbeddings(model="text-embedding-ada-002") 60 | metadatas = [{"name": i} for i in range(len(workflows))] 61 | texts = [f"{w['name']}\n{w['docstring']}" for w in workflows] 62 | memory = FAISS.from_texts( 63 | texts=texts, 64 | embedding=embedding, 65 | metadatas=metadatas, 66 | ) 67 | if memory_path is not None: 68 | memory.save_local(memory_path) 69 | return memory 70 | 71 | 72 | def get_ids_and_scores(memory, query: str, top_k: int) -> tuple[list[str], list[float]]: 73 | docs_and_similarities = memory.similarity_search_with_score(query, top_k) 74 | retrieved_ids, scores = [], [] 75 | for doc, score in docs_and_similarities: 76 | retrieved_ids.append(doc.metadata["name"]) 77 | scores.append(score) 78 | return retrieved_ids, scores 79 | 80 | 81 | # %% main pipeline 82 | def main(): 83 | # enumerate workflow files 84 | suffix = ".txt" if args.workflow_suffix is None else f"_{args.workflow_suffix}.txt" 85 | workflow_files = [] 86 | for (dirpath, dirnames, filenames) in os.walk(args.workflow_dir): 87 | workflow_files.extend([ 88 | os.path.join(dirpath, f) 89 | for f in filenames if f.endswith(suffix) 90 | ]) 91 | print(f"Collected {len(workflow_files)} workflow files in total.") 92 | 93 | # parse workflows from each file 94 | workflows = [] 95 | for wf in workflow_files: 96 |
workflows.extend(load_workflows(wf)) 97 | print(f"Collected {len(workflows)} workflows from these files.") 98 | 99 | # select workflows 100 | if args.mode == "random": 101 | selected_workflows = random.sample(workflows, args.top_k) 102 | elif args.mode == "semantic": 103 | memory = build_memory(workflows, args.memory_path) 104 | examples = get_examples(args.data_dir, args.website) 105 | queries = [ex["confirmed_task"] for ex in examples] 106 | retrieved_ids_and_scores = [] 107 | for ex in examples: 108 | rids, rscores = get_ids_and_scores(memory, query=ex["confirmed_task"], top_k=args.top_k) 109 | retrieved_ids_and_scores.extend([(rid,rscr) for rid,rscr in zip(rids, rscores)]) 110 | retrieved_ids_and_scores = sorted(retrieved_ids_and_scores, key=lambda x:-x[1]) 111 | selected_workflows = [workflows[i] for i,s in retrieved_ids_and_scores[:args.top_k]] 112 | else: 113 | raise ValueError(f"Unknown mode: {args.mode}") 114 | 115 | # write selected workflows to the output path 116 | with open(args.output_path, 'w') as fw: 117 | fw.write('\n\n'.join([w["content"] for w in selected_workflows])) 118 | 119 | 120 | def ablation(): # retrieve from training examples 121 | train_examples = json.load(open("../data/memory/exemplars.json", 'r')) 122 | embedding = OpenAIEmbeddings(model="text-embedding-ada-002") 123 | metadatas = [{"name": i} for i in range(len(train_examples))] 124 | texts = [ex[0]["specifier"] for ex in train_examples] 125 | memory = FAISS.from_texts(texts=texts, embedding=embedding, metadatas=metadatas) 126 | 127 | # get test examples 128 | examples = get_examples(args.data_dir, args.website) 129 | retrieved_ids_and_scores = [] 130 | for ex in examples: 131 | rids, rscores = get_ids_and_scores(memory, query=ex["confirmed_task"], top_k=args.top_k) 132 | retrieved_ids_and_scores.extend([(rid,rscr) for rid,rscr in zip(rids, rscores)]) 133 | print("Top Retrieved Item: ",retrieved_ids_and_scores[0]) 134 | retrieved_ids_and_scores = sorted(retrieved_ids_and_scores, key=lambda x:-x[1]) 135 | selected_examples = [train_examples[i] for i,s in retrieved_ids_and_scores[:args.top_k]] 136 | 137 | # write selected examples to the output path 138 | selected_examples = [ 139 | '\n'.join([item["content"] for item in sex]) 140 | for sex in selected_examples 141 | ] 142 | with open(args.output_path, 'w') as fw: 143 | fw.write('\n\n'.join(selected_examples)) 144 | 145 | 146 | 147 | if __name__ == "__main__": 148 | parser = argparse.ArgumentParser() 149 | # offline workflows 150 | parser.add_argument("--workflow_dir", type=str, default=None, 151 | help="Directory of workflows to retrieve from.") 152 | parser.add_argument("--workflow_suffix", type=str, default=None, 153 | help="Specified suffix of workflow files to load.") 154 | 155 | # test data 156 | parser.add_argument("--data_dir", type=str, default=None) 157 | parser.add_argument("--website", type=str, default=None) 158 | 159 | # retrieval 160 | parser.add_argument("--mode", type=str, default="random", 161 | choices=["random", "semantic"]) 162 | parser.add_argument("--memory_path", type=str, default="memory") 163 | parser.add_argument("--top_k", type=int, default=10, 164 | help="Number of top-relevant workflows to save.") 165 | 166 | parser.add_argument("--output_path", type=str, default=None, 167 | help="Path to output the collected (relevant) workflows.") 168 | # ablation 169 | parser.add_argument("--run_ablation", action="store_true", 170 | help="Whether to run the ablation study that retrieves entire examples.") 171 | 172 | args = parser.parse_args() 173 | 174 | if args.mode == "semantic": 175 | assert
(args.data_dir is not None) and (args.website is not None) and (args.workflow_dir is not None) 176 | 177 | if args.run_ablation: ablation() 178 | else: main() 179 | -------------------------------------------------------------------------------- /webarena/README.md: -------------------------------------------------------------------------------- 1 | # AWM for WebArena 2 | 3 | ## Install 4 | 5 | *Install `browsergym`*: Follow the instructions in [this README](https://github.com/ServiceNow/BrowserGym) to install `browsergym`. 6 | 7 | ```bash 8 | pip install browsergym 9 | playwright install chromium 10 | ``` 11 | 12 | *Set up the `webarena` specifics*: 13 | 14 | ```bash 15 | pip install browsergym-webarena 16 | python -c "import nltk; nltk.download('punkt')" 17 | ``` 18 | 19 | Set up the web servers and environment URLs (find more details in the [webarena readme](https://github.com/web-arena-x/webarena/blob/main/environment_docker/README.md)). 20 | 21 | ```bash 22 | BASE_URL= 23 | export WA_SHOPPING="$BASE_URL:7770/" 24 | export WA_SHOPPING_ADMIN="$BASE_URL:7780/admin" 25 | export WA_REDDIT="$BASE_URL:9999" 26 | export WA_GITLAB="$BASE_URL:8023" 27 | export WA_WIKIPEDIA="$BASE_URL:8888/wikipedia_en_all_maxi_2022-05/A/User:The_other_Kiwix_guy/Landing" 28 | export WA_MAP="$BASE_URL:3000" 29 | export WA_HOMEPAGE="$BASE_URL:4399" 30 | ``` 31 | 32 | Then, generate the config files for each task, which will be used during workflow induction. 33 | 34 | ```bash 35 | cd config_files 36 | python generate_test_data.py 37 | cd ../ 38 | ``` 39 | 40 | *Install agent and evaluation requirements*: 41 | 42 | ```bash 43 | pip install -r requirements.txt # agent 44 | pip install -r autoeval/requirements.txt # model-based evaluation 45 | ``` 46 | 47 | *Set up `openai` keys*: 48 | 49 | ```bash 50 | export OPENAI_API_KEY= 51 | ``` 52 | 53 | 54 | ## Run Agent 55 | 56 | ### Baseline Agent: No Memory 57 | 58 | ```bash 59 | python run.py --task_name "webarena.0" # switch task id from 0 to 811 60 | ``` 61 | 62 | You can check the results in `./results/...webarena.0.../`. 63 | 64 | ### with Agent Task Memory :atm: 65 | 66 | **Step 1**. To run inference on a task: 67 | 68 | ```bash 69 | python run.py --task_name "webarena.0" \ 70 | --workflow_path "workflow/shopping.txt" 71 | ``` 72 | 73 | Make sure the workflow path matches the website associated with the input task id. By default, workflow files are named after the website, e.g., `shopping_admin.txt`, `reddit.txt`, `gitlab.txt`, `map.txt`. 74 | 75 | **Step 2**. To evaluate an agent-generated task trajectory, run: 76 | 77 | ```bash 78 | python -m autoeval.evaluate_trajectory --result_dir "../results/webarena.0" 79 | ``` 80 | 81 | This will produce a `{model}_autoeval.txt` file in the `../results/webarena.0/` directory. 82 | 83 | Change the `model` and `prompt` format options if necessary. 84 | 85 | **Step 3**. Integrate the trajectory workflows into agent memory: 86 | 87 | ```bash 88 | python induce_rule.py --results_dir results/shopping 89 | # or `induce_prompt.py` for neural-based induction 90 | ``` 91 | 92 | Switch `criteria` to "gt" (in which case you can skip step 2) if you want to use the ground-truth reward as the signal for integrating workflows. 93 | 94 | Repeat the loop of steps 1-3 after each task, or after every n tasks, as suits your setup; for example:
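A minimal manual loop over a block of shopping tasks might look like the following sketch (the task-id range, result paths, and workflow file are placeholders to adapt to your setup):

```bash
for i in $(seq 0 9); do
  python run.py --task_name "webarena.$i" --workflow_path "workflow/shopping.txt"
  python -m autoeval.evaluate_trajectory --result_dir "results/webarena.$i"
  python induce_rule.py --results_dir results/shopping
done
```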
95 | We provide a pipeline script to iteratively execute the above steps for all tasks: 96 | 97 | ```bash 98 | python pipeline.py --website "shopping" 99 | ``` 100 | -------------------------------------------------------------------------------- /webarena/agents/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zorazrw/agent-workflow-memory/907d3cbafcae021fe3a4577c5a10539752e63596/webarena/agents/__init__.py -------------------------------------------------------------------------------- /webarena/agents/basic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zorazrw/agent-workflow-memory/907d3cbafcae021fe3a4577c5a10539752e63596/webarena/agents/basic/__init__.py -------------------------------------------------------------------------------- /webarena/agents/basic/agent.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | from browsergym.experiments import Agent, AbstractAgentArgs 4 | from browsergym.core.action.highlevel import HighLevelActionSet 5 | from browsergym.core.action.python import PythonActionSet 6 | from browsergym.utils.obs import flatten_axtree_to_str 7 | 8 | 9 | class DemoAgent(Agent): 10 | """A basic agent using OpenAI API, to demonstrate BrowserGym's functionalities.""" 11 | 12 | action_set = HighLevelActionSet( 13 | subsets=["chat", "bid"], # define a subset of the action space 14 | # subsets=["chat", "bid", "coord"] # allow the agent to also use x,y coordinates 15 | strict=False, # less strict on the parsing of the actions 16 | multiaction=True, # enable the agent to take multiple actions at once 17 | demo_mode="default", # add visual effects 18 | ) 19 | # use this instead to allow the agent to directly use Python code 20 | # action_set = PythonActionSet() 21 | 22 | def obs_preprocessor(self, obs: dict) -> dict: 23 | return { 24 | "goal": obs["goal"], 25 | "axtree_txt": flatten_axtree_to_str(obs["axtree_object"]), 26 | } 27 | 28 | def __init__(self, model_name) -> None: 29 | super().__init__() 30 | self.model_name = model_name 31 | 32 | from openai import OpenAI 33 | 34 | self.openai_client = OpenAI() 35 | 36 | def get_action(self, obs: dict) -> tuple[str, dict]: 37 | system_msg = f"""\ 38 | # Instructions 39 | Review the current state of the page and all other information to find the best 40 | possible next action to accomplish your goal. Your answer will be interpreted 41 | and executed by a program, make sure to follow the formatting instructions.
42 | 43 | # Goal: 44 | {obs["goal"]}""" 45 | 46 | prompt = f"""\ 47 | # Current Accessibility Tree: 48 | {obs["axtree_txt"]} 49 | 50 | # Action Space 51 | {self.action_set.describe(with_long_description=False, with_examples=True)} 52 | 53 | Here is an example with chain of thought of a valid action when clicking on a button: 54 | " 55 | In order to accomplish my goal I need to click on the button with bid 12 56 | ```click("12")``` 57 | " 58 | """ 59 | 60 | # query OpenAI model 61 | response = self.openai_client.chat.completions.create( 62 | model=self.model_name, 63 | messages=[ 64 | {"role": "system", "content": system_msg}, 65 | {"role": "user", "content": prompt}, 66 | ], 67 | ) 68 | action = response.choices[0].message.content 69 | 70 | return action, {} 71 | 72 | 73 | @dataclasses.dataclass 74 | class DemoAgentArgs(AbstractAgentArgs): 75 | """ 76 | This class is meant to store the arguments that define the agent. 77 | 78 | By isolating them in a dataclass, this ensures serialization without storing 79 | internal states of the agent. 80 | """ 81 | 82 | model_name: str = "gpt-3.5-turbo" 83 | 84 | def make_agent(self): 85 | return DemoAgent(model_name=self.model_name) 86 | 87 | 88 | def main(): 89 | from browsergym.experiments import EnvArgs, ExpArgs, get_exp_result 90 | from pathlib import Path 91 | 92 | exp_root = Path().home() / "agent_experiments" 93 | exp_root.mkdir(exist_ok=True) 94 | 95 | exp_args = ExpArgs( 96 | agent_args=DemoAgentArgs(model_name="gpt-3.5-turbo"), 97 | env_args=EnvArgs( 98 | task_name="miniwob.click-test", 99 | task_seed=42, 100 | headless=False, # shows the browser 101 | ), 102 | ) 103 | 104 | exp_args.prepare(exp_root=exp_root) 105 | exp_args.run() 106 | 107 | exp_result = get_exp_result(exp_args.exp_dir) 108 | exp_record = exp_result.get_exp_record() 109 | 110 | for key, val in exp_record.items(): 111 | print(f"{key}: {val}") 112 | -------------------------------------------------------------------------------- /webarena/agents/legacy/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent import GenericAgentArgs 2 | from .dynamic_prompting import Flags 3 | -------------------------------------------------------------------------------- /webarena/agents/legacy/agent.py: -------------------------------------------------------------------------------- 1 | """ 2 | WARNING DEPRECATED WILL BE REMOVED SOON 3 | """ 4 | 5 | from dataclasses import asdict, dataclass, field 6 | import traceback 7 | from warnings import warn 8 | from langchain.schema import HumanMessage, SystemMessage 9 | 10 | from browsergym.core.action.base import AbstractActionSet 11 | from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, prune_html 12 | from browsergym.experiments import Agent, AbstractAgentArgs 13 | 14 | from ..legacy import dynamic_prompting 15 | from .utils.llm_utils import ParseError, retry 16 | from .utils.chat_api import ChatModelArgs 17 | 18 | 19 | @dataclass 20 | class GenericAgentArgs(AbstractAgentArgs): 21 | chat_model_args: ChatModelArgs = None 22 | flags: dynamic_prompting.Flags = field(default_factory=lambda: dynamic_prompting.Flags()) 23 | max_retry: int = 4 24 | 25 | def make_agent(self): 26 | return GenericAgent( 27 | chat_model_args=self.chat_model_args, flags=self.flags, max_retry=self.max_retry 28 | ) 29 | 30 | 31 | class GenericAgent(Agent): 32 | 33 | def obs_preprocessor(self, obs: dict) -> dict: 34 | """ 35 | Augment observations with text HTML and AXTree representations, which will be 
stored in 36 | the experiment traces. 37 | """ 38 | 39 | obs = obs.copy() 40 | obs["dom_txt"] = flatten_dom_to_str( 41 | obs["dom_object"], 42 | with_visible=self.flags.extract_visible_tag, 43 | with_center_coords=self.flags.extract_coords == "center", 44 | with_bounding_box_coords=self.flags.extract_coords == "box", 45 | filter_visible_only=self.flags.extract_visible_elements_only, 46 | ) 47 | obs["axtree_txt"] = flatten_axtree_to_str( 48 | obs["axtree_object"], 49 | with_visible=self.flags.extract_visible_tag, 50 | with_center_coords=self.flags.extract_coords == "center", 51 | with_bounding_box_coords=self.flags.extract_coords == "box", 52 | filter_visible_only=self.flags.extract_visible_elements_only, 53 | ) 54 | obs["pruned_html"] = prune_html(obs["dom_txt"]) 55 | 56 | return obs 57 | 58 | def __init__( 59 | self, 60 | chat_model_args: ChatModelArgs = None, 61 | flags: dynamic_prompting.Flags = None, 62 | max_retry: int = 4, 63 | ): 64 | self.chat_model_args = chat_model_args if chat_model_args is not None else ChatModelArgs() 65 | self.flags = flags if flags is not None else dynamic_prompting.Flags() 66 | self.max_retry = max_retry 67 | 68 | self.chat_llm = chat_model_args.make_chat_model() 69 | self.action_set = dynamic_prompting._get_action_space(self.flags) 70 | 71 | # consistency check 72 | if self.flags.use_screenshot: 73 | if not self.chat_model_args.has_vision(): 74 | warn( 75 | """\ 76 | 77 | Warning: use_screenshot is set to True, but the chat model \ 78 | does not support vision. Disabling use_screenshot.""" 79 | ) 80 | self.flags.use_screenshot = False 81 | 82 | # reset episode memory 83 | self.obs_history = [] 84 | self.actions = [] 85 | self.memories = [] 86 | self.thoughts = [] 87 | 88 | def get_action(self, obs): 89 | 90 | self.obs_history.append(obs) 91 | 92 | main_prompt = dynamic_prompting.MainPrompt( 93 | obs_history=self.obs_history, 94 | actions=self.actions, 95 | memories=self.memories, 96 | thoughts=self.thoughts, 97 | flags=self.flags, 98 | ) 99 | 100 | # Determine the minimum non-None token limit from prompt, total, and input tokens, or set to None if all are None. 101 | maxes = ( 102 | self.flags.max_prompt_tokens, 103 | self.chat_model_args.max_total_tokens, 104 | self.chat_model_args.max_input_tokens, 105 | ) 106 | maxes = [m for m in maxes if m is not None] 107 | max_prompt_tokens = min(maxes) if maxes else None 108 | 109 | prompt = dynamic_prompting.fit_tokens( 110 | main_prompt, 111 | max_prompt_tokens=max_prompt_tokens, 112 | model_name=self.chat_model_args.model_name, 113 | ) 114 | 115 | sys_msg = dynamic_prompting.SystemPrompt().prompt 116 | if self.flags.workflow_path is not None: 117 | sys_msg += '\n\n' + open(self.flags.workflow_path).read() 118 | 119 | chat_messages = [ 120 | SystemMessage(content=sys_msg), 121 | HumanMessage(content=prompt), 122 | ] 123 | 124 | def parser(text): 125 | try: 126 | ans_dict = main_prompt._parse_answer(text) 127 | except ParseError as e: 128 | # these parse errors will be caught by the retry function and 129 | # the chat_llm will have a chance to recover 130 | return None, False, str(e) 131 | 132 | return ans_dict, True, "" 133 | 134 | try: 135 | ans_dict = retry(self.chat_llm, chat_messages, n_retry=self.max_retry, parser=parser) 136 | # inferring the number of retries, TODO: make this less hacky 137 | ans_dict["n_retry"] = (len(chat_messages) - 3) / 2 138 | except ValueError as e: 139 | # Likely due to maximum retry. 
We catch it here to be able to return 140 | # the list of messages for further analysis 141 | ans_dict = {"action": None} 142 | ans_dict["err_msg"] = str(e) 143 | ans_dict["stack_trace"] = traceback.format_exc() 144 | ans_dict["n_retry"] = self.max_retry 145 | 146 | self.actions.append(ans_dict["action"]) 147 | self.memories.append(ans_dict.get("memory", None)) 148 | self.thoughts.append(ans_dict.get("think", None)) 149 | 150 | ans_dict["chat_messages"] = [m.content for m in chat_messages] 151 | ans_dict["chat_model_args"] = asdict(self.chat_model_args) 152 | 153 | return ans_dict["action"], ans_dict 154 | -------------------------------------------------------------------------------- /webarena/agents/legacy/dynamic_prompting.py: -------------------------------------------------------------------------------- 1 | """ 2 | WARNING DEPRECATED WILL BE REMOVED SOON 3 | """ 4 | 5 | import abc 6 | import difflib 7 | import logging 8 | import platform 9 | 10 | from copy import deepcopy 11 | from dataclasses import asdict, dataclass 12 | from textwrap import dedent 13 | from typing import Literal 14 | from warnings import warn 15 | 16 | from browsergym.core.action.base import AbstractActionSet 17 | from browsergym.core.action.highlevel import HighLevelActionSet 18 | from browsergym.core.action.python import PythonActionSet 19 | 20 | from .utils.llm_utils import ParseError 21 | from .utils.llm_utils import ( 22 | count_tokens, 23 | image_to_jpg_base64_url, 24 | parse_html_tags_raise, 25 | ) 26 | 27 | 28 | @dataclass 29 | class Flags: 30 | use_html: bool = True 31 | use_ax_tree: bool = False 32 | drop_ax_tree_first: bool = True # This flag is no longer active TODO delete 33 | use_thinking: bool = False 34 | use_error_logs: bool = False 35 | use_past_error_logs: bool = False 36 | use_history: bool = False 37 | use_action_history: bool = False 38 | use_memory: bool = False 39 | use_diff: bool = False 40 | html_type: str = "pruned_html" 41 | use_concrete_example: bool = True 42 | use_abstract_example: bool = False 43 | multi_actions: bool = False 44 | action_space: Literal[ 45 | "python", "bid", "coord", "bid+coord", "bid+nav", "coord+nav", "bid+coord+nav" 46 | ] = "bid" 47 | is_strict: bool = False 48 | # This flag will be automatically disabled `if not chat_model_args.has_vision()` 49 | use_screenshot: bool = True 50 | enable_chat: bool = False 51 | max_prompt_tokens: int = None 52 | extract_visible_tag: bool = False 53 | extract_coords: Literal["False", "center", "box"] = "False" 54 | extract_visible_elements_only: bool = False 55 | demo_mode: Literal["off", "default", "only_visible_elements"] = "off" 56 | workflow_path: str = None 57 | 58 | def copy(self): 59 | return deepcopy(self) 60 | 61 | def asdict(self): 62 | """Helper for JSON serializable requirement.""" 63 | return asdict(self) 64 | 65 | @classmethod 66 | def from_dict(self, flags_dict): 67 | """Helper for JSON serializable requirement.""" 68 | if isinstance(flags_dict, Flags): 69 | return flags_dict 70 | 71 | if not isinstance(flags_dict, dict): 72 | raise ValueError(f"Unrecognized type for flags_dict of type {type(flags_dict)}.") 73 | return Flags(**flags_dict) 74 | 75 | 76 | class PromptElement: 77 | """Base class for all prompt elements. Prompt elements can be hidden. 78 | 79 | Prompt elements are used to build the prompt. Use flags to control which 80 | prompt elements are visible.
We use class attributes as a convenient way 81 | to implement static prompts, but feel free to override them with instance 82 | attributes or @property decorator.""" 83 | 84 | _prompt = "" 85 | _abstract_ex = "" 86 | _concrete_ex = "" 87 | 88 | def __init__(self, visible: bool = True) -> None: 89 | """Prompt element that can be hidden. 90 | 91 | Parameters 92 | ---------- 93 | visible : bool, optional 94 | Whether the prompt element should be visible, by default True. Can 95 | be a callable that returns a bool. This is useful when a specific 96 | flag changes during a shrink iteration. 97 | """ 98 | self._visible = visible 99 | 100 | @property 101 | def prompt(self): 102 | """Avoid overriding this method. Override _prompt instead.""" 103 | return self._hide(self._prompt) 104 | 105 | @property 106 | def abstract_ex(self): 107 | """Useful when this prompt element is requesting an answer from the llm. 108 | Provide an abstract example of the answer here. See Memory for an 109 | example. 110 | 111 | Avoid overriding this method. Override _abstract_ex instead 112 | """ 113 | return self._hide(self._abstract_ex) 114 | 115 | @property 116 | def concrete_ex(self): 117 | """Useful when this prompt element is requesting an answer from the llm. 118 | Provide a concrete example of the answer here. See Memory for an 119 | example. 120 | 121 | Avoid overriding this method. Override _concrete_ex instead 122 | """ 123 | return self._hide(self._concrete_ex) 124 | 125 | @property 126 | def is_visible(self): 127 | """Handle the case where visible is a callable.""" 128 | visible = self._visible 129 | if callable(visible): 130 | visible = visible() 131 | return visible 132 | 133 | def _hide(self, value): 134 | """Return value if visible is True, else return empty string.""" 135 | if self.is_visible: 136 | return value 137 | else: 138 | return "" 139 | 140 | def _parse_answer(self, text_answer) -> dict: 141 | if self.is_visible: 142 | return self._parse_answer(text_answer) 143 | else: 144 | return {} 145 | 146 | 147 | class Shrinkable(PromptElement, abc.ABC): 148 | @abc.abstractmethod 149 | def shrink(self) -> None: 150 | """Implement shrinking of this prompt element. 151 | 152 | You need to recursively call all shrinkable elements that are part of 153 | this prompt. You can also implement a shrinking strategy for this prompt. 154 | Shrinking can be called multiple times to progressively shrink the 155 | prompt until it fits max_tokens. Default max shrink iterations is 20. 156 | """ 157 | pass 158 | 159 | 160 | class Trunkater(Shrinkable): 161 | def __init__(self, visible, shrink_speed=0.3, start_trunkate_iteration=10): 162 | super().__init__(visible=visible) 163 | self.shrink_speed = shrink_speed 164 | self.start_trunkate_iteration = start_trunkate_iteration 165 | self.shrink_calls = 0 166 | self.deleted_lines = 0 167 | 168 | def shrink(self) -> None: 169 | if self.is_visible and self.shrink_calls >= self.start_trunkate_iteration: 170 | # remove a fraction of the lines in _prompt 171 | lines = self._prompt.splitlines() 172 | new_line_count = int(len(lines) * (1 - self.shrink_speed)) 173 | self.deleted_lines += len(lines) - new_line_count 174 | self._prompt = "\n".join(lines[:new_line_count]) 175 | self._prompt += f"\n... Deleted {self.deleted_lines} lines to reduce prompt size." 176 | 177 | self.shrink_calls += 1 178 | 179 | 180 | def fit_tokens( 181 | shrinkable: Shrinkable, max_prompt_tokens=None, max_iterations=20, model_name="openai/gpt-4" 182 | ): 183 | """Shrink a prompt element until it fits max_prompt_tokens.
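Each iteration re-renders the prompt, counts its tokens, and calls shrink() on the element, stopping once the prompt fits or max_iterations is exhausted.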
184 | 185 | Parameters 186 | ---------- 187 | shrinkable : Shrinkable 188 | The prompt element to shrink. 189 | max_prompt_tokens : int, optional 190 | The maximum number of tokens allowed; if None, the prompt is returned unchanged. 191 | max_iterations : int, optional 192 | The maximum number of shrink iterations, by default 20. 193 | model_name : str, optional 194 | The name of the model used when tokenizing. 195 | 196 | Returns 197 | ------- 198 | str : the prompt after shrinking. 199 | """ 200 | 201 | if max_prompt_tokens is None: 202 | return shrinkable.prompt 203 | 204 | for _ in range(max_iterations): 205 | prompt = shrinkable.prompt 206 | if isinstance(prompt, str): 207 | prompt_str = prompt 208 | elif isinstance(prompt, list): 209 | prompt_str = "\n".join([p["text"] for p in prompt if p["type"] == "text"]) 210 | else: 211 | raise ValueError(f"Unrecognized type for prompt: {type(prompt)}") 212 | n_token = count_tokens(prompt_str, model=model_name) 213 | if n_token <= max_prompt_tokens: 214 | return prompt 215 | shrinkable.shrink() 216 | 217 | logging.info( 218 | dedent( 219 | f"""\ 220 | After {max_iterations} shrink iterations, the prompt is still 221 | {count_tokens(prompt_str)} tokens (greater than {max_prompt_tokens}). Returning the prompt as is.""" 222 | ) 223 | ) 224 | return prompt 225 | 226 | 227 | class HTML(Trunkater): 228 | def __init__(self, html, visible: bool = True, prefix="") -> None: 229 | super().__init__(visible=visible, start_trunkate_iteration=5) 230 | self._prompt = f"\n{prefix}HTML:\n{html}\n" 231 | 232 | 233 | class AXTree(Trunkater): 234 | def __init__(self, ax_tree, visible: bool = True, coord_type=None, prefix="") -> None: 235 | super().__init__(visible=visible, start_trunkate_iteration=10) 236 | if coord_type == "center": 237 | coord_note = """\ 238 | Note: center coordinates are provided in parentheses and are 239 | relative to the top left corner of the page.\n\n""" 240 | elif coord_type == "box": 241 | coord_note = """\ 242 | Note: the bounding box of each object is provided in parentheses and is 243 | relative to the top left corner of the page.\n\n""" 244 | else: 245 | coord_note = "" 246 | self._prompt = f"\n{prefix}AXTree:\n{coord_note}{ax_tree}\n" 247 | 248 | 249 | class Error(PromptElement): 250 | def __init__(self, error, visible: bool = True, prefix="") -> None: 251 | super().__init__(visible=visible) 252 | self._prompt = f"\n{prefix}Error from previous action:\n{error}\n" 253 | 254 | 255 | class Observation(Shrinkable): 256 | """Observation of the current step. 257 | 258 | Contains the html, the accessibility tree and the error logs.
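Visibility of each part is controlled by the corresponding flags (use_html, use_ax_tree, use_error_logs).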
259 | """ 260 | 261 | def __init__(self, obs, flags: Flags) -> None: 262 | super().__init__() 263 | self.flags = flags 264 | self.obs = obs 265 | self.html = HTML(obs[flags.html_type], visible=lambda: flags.use_html, prefix="## ") 266 | self.ax_tree = AXTree( 267 | obs["axtree_txt"], 268 | visible=lambda: flags.use_ax_tree, 269 | coord_type=flags.extract_coords, 270 | prefix="## ", 271 | ) 272 | self.error = Error( 273 | obs["last_action_error"], 274 | visible=lambda: flags.use_error_logs and obs["last_action_error"], 275 | prefix="## ", 276 | ) 277 | 278 | def shrink(self): 279 | self.ax_tree.shrink() 280 | self.html.shrink() 281 | 282 | @property 283 | def _prompt(self) -> str: 284 | return f"\n# Observation of current step:\n{self.html.prompt}{self.ax_tree.prompt}{self.error.prompt}\n\n" 285 | 286 | def add_screenshot(self, prompt): 287 | if self.flags.use_screenshot: 288 | if isinstance(prompt, str): 289 | prompt = [{"type": "text", "text": prompt}] 290 | img_url = image_to_jpg_base64_url(self.obs["screenshot"]) 291 | prompt.append({"type": "image_url", "image_url": {"url": img_url}}) 292 | 293 | return prompt 294 | 295 | 296 | class MacNote(PromptElement): 297 | def __init__(self) -> None: 298 | super().__init__(visible=platform.system() == "Darwin") 299 | self._prompt = ( 300 | "\nNote: you are on mac so you should use Meta instead of Control for Control+C etc.\n" 301 | ) 302 | 303 | 304 | class BeCautious(PromptElement): 305 | def __init__(self, visible: bool = True) -> None: 306 | super().__init__(visible=visible) 307 | self._prompt = f"""\ 308 | \nBe very cautious. Avoid submitting anything before verifying the effect of your 309 | actions. Take the time to explore the effect of safe actions first. For example 310 | you can fill a few elements of a form, but don't click submit before verifying 311 | that everything was filled correctly.\n""" 312 | 313 | 314 | class GoalInstructions(PromptElement): 315 | def __init__(self, goal, visible: bool = True) -> None: 316 | super().__init__(visible) 317 | self._prompt = f"""\ 318 | # Instructions 319 | Review the current state of the page and all other information to find the best 320 | possible next action to accomplish your goal. Your answer will be interpreted 321 | and executed by a program, make sure to follow the formatting instructions. 322 | 323 | ## Goal: 324 | {goal} 325 | """ 326 | 327 | 328 | class ChatInstructions(PromptElement): 329 | def __init__(self, chat_messages, visible: bool = True) -> None: 330 | super().__init__(visible) 331 | self._prompt = f"""\ 332 | # Instructions 333 | 334 | You are a UI Assistant, your goal is to help the user perform tasks using a web browser. You can 335 | communicate with the user via a chat, in which the user gives you instructions and in which you 336 | can send back messages. You have access to a web browser that both you and the user can see, 337 | and with which only you can interact via specific commands. 338 | 339 | Review the instructions from the user, the current state of the page and all other information 340 | to find the best possible next action to accomplish your goal. Your answer will be interpreted 341 | and executed by a program, make sure to follow the formatting instructions. 
342 | 343 | ## Chat messages: 344 | 345 | """ 346 | self._prompt += "\n".join( 347 | [ 348 | f"""\ 349 | - [{msg['role']}] {msg['message']}""" 350 | for msg in chat_messages 351 | ] 352 | ) 353 | 354 | # workflow_path = "examples/map_autoeval.txt" 355 | # workflow = "\n\n" + open(workflow_path, 'r').read() 356 | # workflow = "" #"click('id') # when clicking, use the element id in string format." 357 | 358 | class SystemPrompt(PromptElement): 359 | _prompt = """\ 360 | You are an agent trying to solve a web task based on the content of the page and 361 | user instructions. You can interact with the page and explore. Each time you 362 | submit an action it will be sent to the browser and you will receive a new page.""" 363 | 364 | 365 | class MainPrompt(Shrinkable): 366 | def __init__( 367 | self, 368 | obs_history, 369 | actions, 370 | memories, 371 | thoughts, 372 | flags: Flags, 373 | ) -> None: 374 | super().__init__() 375 | self.flags = flags 376 | self.history = History(obs_history, actions, memories, thoughts, flags) 377 | if self.flags.enable_chat: 378 | self.instructions = ChatInstructions(obs_history[-1]["chat_messages"]) 379 | else: 380 | if sum([msg["role"] == "user" for msg in obs_history[-1]["chat_messages"]]) > 1: 381 | logging.warning( 382 | "Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`." 383 | ) 384 | self.instructions = GoalInstructions(obs_history[-1]["goal"]) 385 | 386 | self.obs = Observation(obs_history[-1], self.flags) 387 | self.action_space = ActionSpace(self.flags) 388 | 389 | self.think = Think(visible=lambda: flags.use_thinking) 390 | self.memory = Memory(visible=lambda: flags.use_memory) 391 | 392 | @property 393 | def _prompt(self) -> str: 394 | prompt = f"""\ 395 | {self.instructions.prompt}\ 396 | {self.obs.prompt}\ 397 | {self.history.prompt}\ 398 | {self.action_space.prompt}\ 399 | {self.think.prompt}\ 400 | {self.memory.prompt}\ 401 | """ 402 | 403 | if self.flags.use_abstract_example: 404 | prompt += f""" 405 | # Abstract Example 406 | 407 | Here is an abstract version of the answer with description of the content of 408 | each tag. Make sure you follow this structure, but replace the content with your 409 | answer: 410 | {self.think.abstract_ex}\ 411 | {self.memory.abstract_ex}\ 412 | {self.action_space.abstract_ex}\ 413 | """ 414 | 415 | if self.flags.use_concrete_example: 416 | prompt += f""" 417 | # Concrete Example 418 | 419 | Here is a concrete example of how to format your answer.
420 | Make sure to follow the template with proper tags: 421 | {self.think.concrete_ex}\ 422 | {self.memory.concrete_ex}\ 423 | {self.action_space.concrete_ex}\ 424 | """ 425 | return self.obs.add_screenshot(prompt) 426 | 427 | def shrink(self): 428 | self.history.shrink() 429 | self.obs.shrink() 430 | 431 | def _parse_answer(self, text_answer): 432 | ans_dict = {} 433 | ans_dict.update(self.think._parse_answer(text_answer)) 434 | ans_dict.update(self.memory._parse_answer(text_answer)) 435 | ans_dict.update(self.action_space._parse_answer(text_answer)) 436 | return ans_dict 437 | 438 | 439 | class ActionSpace(PromptElement): 440 | def __init__(self, flags: Flags) -> None: 441 | super().__init__() 442 | self.flags = flags 443 | self.action_space = _get_action_space(flags) 444 | 445 | self._prompt = f"# Action space:\n{self.action_space.describe()}{MacNote().prompt}\n" 446 | self._abstract_ex = f""" 447 | 448 | {self.action_space.example_action(abstract=True)} 449 | 450 | """ 451 | self._concrete_ex = f""" 452 | 453 | {self.action_space.example_action(abstract=False)} 454 | 455 | """ 456 | 457 | def _parse_answer(self, text_answer): 458 | ans_dict = parse_html_tags_raise(text_answer, keys=["action"], merge_multiple=True) 459 | 460 | try: 461 | # just check if action can be mapped to python code but keep action as is 462 | # the environment will be responsible for mapping it to python 463 | self.action_space.to_python_code(ans_dict["action"]) 464 | except Exception as e: 465 | raise ParseError( 466 | f"Error while parsing action\n: {e}\n" 467 | "Make sure your answer is restricted to the allowed actions." 468 | ) 469 | 470 | return ans_dict 471 | 472 | 473 | def _get_action_space(flags: Flags) -> AbstractActionSet: 474 | match flags.action_space: 475 | case "python": 476 | action_space = PythonActionSet(strict=flags.is_strict) 477 | if flags.multi_actions: 478 | warn( 479 | f"Flag action_space={repr(flags.action_space)} incompatible with multi_actions={repr(flags.multi_actions)}." 480 | ) 481 | if flags.demo_mode != "off": 482 | warn( 483 | f"Flag action_space={repr(flags.action_space)} incompatible with demo_mode={repr(flags.demo_mode)}." 484 | ) 485 | return action_space 486 | case "bid": 487 | action_subsets = ["chat", "bid"] 488 | case "coord": 489 | action_subsets = ["chat", "coord"] 490 | case "bid+coord": 491 | action_subsets = ["chat", "bid", "coord"] 492 | case "bid+nav": 493 | action_subsets = ["chat", "bid", "nav"] 494 | case "coord+nav": 495 | action_subsets = ["chat", "coord", "nav"] 496 | case "bid+coord+nav": 497 | action_subsets = ["chat", "bid", "coord", "nav"] 498 | case _: 499 | raise NotImplementedError(f"Unknown action_space {repr(flags.action_space)}") 500 | 501 | action_space = HighLevelActionSet( 502 | subsets=action_subsets, 503 | multiaction=flags.multi_actions, 504 | strict=flags.is_strict, 505 | demo_mode=flags.demo_mode, 506 | ) 507 | 508 | return action_space 509 | 510 | 511 | class Memory(PromptElement): 512 | _prompt = "" # provided in the abstract and concrete examples 513 | 514 | _abstract_ex = """ 515 | 516 | Write down anything you need to remember for next steps. You will be presented 517 | with the list of previous memories and past actions. 518 | 519 | """ 520 | 521 | _concrete_ex = """ 522 | 523 | I clicked on bid 32 to activate tab 2. The accessibility tree should mention 524 | focusable for elements of the form at next step. 
525 | 526 | """ 527 | 528 | def _parse_answer(self, text_answer): 529 | return parse_html_tags_raise(text_answer, optional_keys=["memory"], merge_multiple=True) 530 | 531 | 532 | class Think(PromptElement): 533 | _prompt = "" 534 | 535 | _abstract_ex = """ 536 | 537 | Think step by step. If you need to make calculations such as coordinates, write them here. Describe the effect 538 | that your previous action had on the current content of the page. 539 | 540 | """ 541 | _concrete_ex = """ 542 | 543 | My memory says that I filled the first name and last name, but I can't see any 544 | content in the form. I need to explore different ways to fill the form. Perhaps 545 | the form is not visible yet or some fields are disabled. I need to replan. 546 | 547 | """ 548 | 549 | def _parse_answer(self, text_answer): 550 | return parse_html_tags_raise(text_answer, optional_keys=["think"], merge_multiple=True) 551 | 552 | 553 | def diff(previous, new): 554 | """Return a summary header and the list of diff lines between previous and new. 555 | 556 | The header counts added and removed lines; identical inputs yield ("Identical", []).""" 557 | 558 | if previous == new: 559 | return "Identical", [] 560 | 561 | if previous is None or len(previous) == 0: 562 | return "previous is empty", [] 563 | 564 | diff_gen = difflib.ndiff(previous.splitlines(), new.splitlines()) 565 | 566 | diff_lines = [] 567 | plus_count = 0 568 | minus_count = 0 569 | for line in diff_gen: 570 | if line.strip().startswith("+"): 571 | diff_lines.append(line) 572 | plus_count += 1 573 | elif line.strip().startswith("-"): 574 | diff_lines.append(line) 575 | minus_count += 1 576 | else: 577 | continue 578 | 579 | header = f"{plus_count} lines added and {minus_count} lines removed:" 580 | 581 | return header, diff_lines 582 | 583 | 584 | class Diff(Shrinkable): 585 | def __init__( 586 | self, previous, new, prefix="", max_line_diff=20, shrink_speed=2, visible=True 587 | ) -> None: 588 | super().__init__(visible=visible) 589 | self.max_line_diff = max_line_diff 590 | self.header, self.diff_lines = diff(previous, new) 591 | self.shrink_speed = shrink_speed 592 | self.prefix = prefix 593 | 594 | def shrink(self): 595 | self.max_line_diff -= self.shrink_speed 596 | self.max_line_diff = max(1, self.max_line_diff) 597 | 598 | @property 599 | def _prompt(self) -> str: 600 | diff_str = "\n".join(self.diff_lines[: self.max_line_diff]) 601 | if len(self.diff_lines) > self.max_line_diff: 602 | original_count = len(self.diff_lines) 603 | diff_str = f"{diff_str}\nDiff truncated, {original_count - self.max_line_diff} changes not shown."
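# For example, with max_line_diff=2 and five diff lines, only the first two lines are kept and the suffix reads "Diff truncated, 3 changes not shown."; each shrink() call lowers max_line_diff by shrink_speed, never below 1.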
604 | return f"{self.prefix}{self.header}\n{diff_str}\n"
605 | 
606 | 
607 | class HistoryStep(Shrinkable):
608 | def __init__(
609 | self, previous_obs, current_obs, action, memory, flags: Flags, shrink_speed=1
610 | ) -> None:
611 | super().__init__()
612 | self.html_diff = Diff(
613 | previous_obs[flags.html_type],
614 | current_obs[flags.html_type],
615 | prefix="\n### HTML diff:\n",
616 | shrink_speed=shrink_speed,
617 | visible=lambda: flags.use_html and flags.use_diff,
618 | )
619 | self.ax_tree_diff = Diff(
620 | previous_obs["axtree_txt"],
621 | current_obs["axtree_txt"],
622 | prefix="\n### Accessibility tree diff:\n",
623 | shrink_speed=shrink_speed,
624 | visible=lambda: flags.use_ax_tree and flags.use_diff,
625 | )
626 | self.error = Error(
627 | current_obs["last_action_error"],
628 | visible=(
629 | lambda: flags.use_error_logs
630 | and current_obs["last_action_error"]
631 | and flags.use_past_error_logs
632 | ),
633 | prefix="### ",
634 | )
635 | self.shrink_speed = shrink_speed
636 | self.action = action
637 | self.memory = memory
638 | self.flags = flags
639 | 
640 | def shrink(self):
641 | super().shrink()
642 | self.html_diff.shrink()
643 | self.ax_tree_diff.shrink()
644 | 
645 | @property
646 | def _prompt(self) -> str:
647 | prompt = ""
648 | 
649 | if self.flags.use_action_history:
650 | prompt += f"\n### Action:\n{self.action}\n"
651 | 
652 | prompt += f"{self.error.prompt}{self.html_diff.prompt}{self.ax_tree_diff.prompt}"
653 | 
654 | if self.flags.use_memory and self.memory is not None:
655 | prompt += f"\n### Memory:\n{self.memory}\n"
656 | 
657 | return prompt
658 | 
659 | 
660 | class History(Shrinkable):
661 | def __init__(
662 | self, history_obs, actions, memories, thoughts, flags: Flags, shrink_speed=1
663 | ) -> None:
664 | super().__init__(visible=lambda: flags.use_history)
665 | assert len(history_obs) == len(actions) + 1
666 | assert len(history_obs) == len(memories) + 1
667 | 
668 | self.shrink_speed = shrink_speed
669 | self.history_steps: list[HistoryStep] = []
670 | 
671 | for i in range(1, len(history_obs)):
672 | self.history_steps.append(
673 | HistoryStep(
674 | history_obs[i - 1],
675 | history_obs[i],
676 | actions[i - 1],
677 | memories[i - 1],
678 | flags,
679 | )
680 | )
681 | 
682 | def shrink(self):
683 | """Shrink individual steps"""
684 | # TODO set the shrink speed of older steps to be higher
685 | super().shrink()
686 | for step in self.history_steps:
687 | step.shrink()
688 | 
689 | @property
690 | def _prompt(self):
691 | prompts = ["# History of interaction with the task:\n"]
692 | for i, step in enumerate(self.history_steps):
693 | prompts.append(f"## step {i}")
694 | prompts.append(step.prompt)
695 | return "\n".join(prompts) + "\n"
696 | 
697 | 
698 | if __name__ == "__main__":
699 | html_template = """
700 | <html>
701 | <body>
702 | <div>
703 | Hello World.
704 | Step {}.
705 | </div>
706 | </body>
707 | </html>
708 | """
709 | 
710 | OBS_HISTORY = [
711 | {
712 | "goal": "do this and that",
713 | "pruned_html": html_template.format(1),
714 | "axtree_txt": "[1] Click me",
715 | "last_action_error": "",
716 | },
717 | {
718 | "goal": "do this and that",
719 | "pruned_html": html_template.format(2),
720 | "axtree_txt": "[1] Click me",
721 | "last_action_error": "",
722 | },
723 | {
724 | "goal": "do this and that",
725 | "pruned_html": html_template.format(3),
726 | "axtree_txt": "[1] Click me",
727 | "last_action_error": "Hey, there is an error now",
728 | },
729 | ]
730 | ACTIONS = ["click('41')", "click('42')"]
731 | MEMORIES = ["memory A", "memory B"]
732 | THOUGHTS = ["thought A", "thought B"]
733 | 
734 | flags = Flags(
735 | use_html=True,
736 | use_ax_tree=True,
737 | use_thinking=True,
738 | use_error_logs=True,
739 | use_past_error_logs=True,
740 | use_history=True,
741 | use_action_history=True,
742 | use_memory=True,
743 | use_diff=True,
744 | html_type="pruned_html",
745 | use_concrete_example=True,
746 | use_abstract_example=True,
747 | multi_actions=True,
748 | )
749 | 
750 | print(
751 | MainPrompt(
752 | obs_history=OBS_HISTORY,
753 | actions=ACTIONS,
754 | memories=MEMORIES,
755 | thoughts=THOUGHTS,
756 | step=0,
757 | flags=flags,
758 | ).prompt
759 | )
760 | 
--------------------------------------------------------------------------------
/webarena/agents/legacy/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zorazrw/agent-workflow-memory/907d3cbafcae021fe3a4577c5a10539752e63596/webarena/agents/legacy/utils/__init__.py
--------------------------------------------------------------------------------
/webarena/agents/legacy/utils/chat_api.py:
--------------------------------------------------------------------------------
1 | from dataclasses import asdict, dataclass
2 | import io
3 | import json
4 | from .prompt_templates import PromptTemplate, get_prompt_template
5 | from langchain.schema import BaseMessage, SystemMessage, HumanMessage, AIMessage
6 | from functools import partial
7 | from typing import Optional, List, Any
8 | import logging
9 | from typing import Tuple
10 | import time
11 | 
12 | from langchain_community.llms import HuggingFaceHub, HuggingFacePipeline
13 | from langchain_openai import ChatOpenAI
14 | from langchain.schema import BaseMessage
15 | from langchain.chat_models.base import SimpleChatModel
16 | from langchain.callbacks.manager import CallbackManagerForLLMRun
17 | from pydantic import Field
18 | from transformers import pipeline
19 | from dataclasses import dataclass
20 | from huggingface_hub import InferenceClient
21 | from transformers import AutoTokenizer
22 | from transformers import GPT2TokenizerFast
23 | 
24 | 
25 | @dataclass
26 | class ChatModelArgs:
27 | """Serializable object for instantiating a generic chat model.
28 | 
29 | Attributes
30 | ----------
31 | model_name : str
32 | The name or path of the model to use.
33 | model_url : str, optional
34 | The url of the model to use, e.g. via TGI. If None, then model_name or model_path must
35 | be specified.
36 | max_input_tokens : int
37 | The maximum number of input tokens. Defaults to None.
38 | temperature : float
39 | The temperature to use for the model.
40 | max_new_tokens : int
41 | The maximum number of tokens to generate.
42 | hf_hosted : bool
43 | Whether the model is hosted on HuggingFace Hub. Defaults to False.
44 | info : dict, optional
45 | Any other information about how the model was finetuned.
46 | max_total_tokens : int
47 | The maximum number of total tokens (input + output). Defaults to 4096.
48 | n_retry_server : int
49 | The number of times to retry the server if it fails to respond.
50 | Defaults to 4.
51 | 
52 | 
53 | 
54 | 
55 | 
56 | 
57 | """
58 | 
59 | model_name: str = "openai/gpt-3.5-turbo"
60 | model_url: str = None
61 | temperature: float = 0.1
62 | max_new_tokens: int = None
63 | max_total_tokens: int = None
64 | max_input_tokens: int = None
65 | hf_hosted: bool = False
66 | info: dict = None
67 | n_retry_server: int = 4
68 | 
69 | def __post_init__(self):
70 | if self.model_url is not None and self.hf_hosted:
71 | raise ValueError("model_url cannot be specified when hf_hosted is True")
72 | 
73 | def make_chat_model(self):
74 | if self.model_name.startswith("openai"):
75 | _, model_name = self.model_name.split("/")
76 | return ChatOpenAI(
77 | model_name=model_name,
78 | temperature=self.temperature,
79 | max_tokens=self.max_new_tokens,
80 | )
81 | else:
82 | return HuggingFaceChatModel(
83 | model_name=self.model_name,
84 | hf_hosted=self.hf_hosted,
85 | temperature=self.temperature,
86 | max_new_tokens=self.max_new_tokens,
87 | max_total_tokens=self.max_total_tokens,
88 | max_input_tokens=self.max_input_tokens,
89 | model_url=self.model_url,
90 | n_retry_server=self.n_retry_server,
91 | )
92 | 
93 | @property
94 | def model_short_name(self):
95 | if "/" in self.model_name:
96 | return self.model_name.split("/")[1]
97 | else:
98 | return self.model_name
99 | 
100 | def key(self):
101 | """Return a unique key for these arguments."""
102 | return json.dumps(asdict(self), sort_keys=True)
103 | 
104 | def has_vision(self):
105 | # TODO make sure to upgrade this as we add more models
106 | name_patterns_with_vision = [
107 | "vision",
108 | "4o",
109 | ]
110 | return any(pattern in self.model_name for pattern in name_patterns_with_vision)
111 | 
112 | 
113 | class HuggingFaceChatModel(SimpleChatModel):
114 | """
115 | Custom LLM Chatbot that can interface with HuggingFace models.
116 | 
117 | This class allows for the creation of a custom chatbot using models hosted
118 | on HuggingFace Hub or a local checkpoint. It provides flexibility in defining
119 | the temperature for response sampling and the maximum number of new tokens
120 | in the response.
121 | 
122 | Attributes:
123 | llm (Any): The HuggingFaceHub model instance.
124 | prompt_template (Any): Template for the prompt to be used for the model's input sequence.
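Example (illustrative sketch added for clarity; the endpoint URL and model
name are assumptions, not part of the original docstring):
    chat = HuggingFaceChatModel(
        model_name="HuggingFaceH4/starchat-beta", hf_hosted=False,
        temperature=0.1, max_new_tokens=256, max_total_tokens=4096,
        max_input_tokens=3840, model_url="http://localhost:8080",
        eai_token=None, n_retry_server=4,
    )
    answer = chat([HumanMessage(content="Say hi.")])  # returns an AIMessage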
125 | """
126 | 
127 | llm: Any = Field(description="The HuggingFaceHub model instance")
128 | tokenizer: Any = Field(
129 | default=None,
130 | description="The tokenizer to use for the model",
131 | )
132 | prompt_template: Optional[PromptTemplate] = Field(
133 | default=None,
134 | description="Template for the prompt to be used for the model's input sequence",
135 | )
136 | n_retry_server: int = Field(
137 | default=4,
138 | description="The number of times to retry the server if it fails to respond",
139 | )
140 | 
141 | def __init__(
142 | self,
143 | model_name: str,
144 | hf_hosted: bool,
145 | temperature: float,
146 | max_new_tokens: int,
147 | max_total_tokens: int,
148 | max_input_tokens: int,
149 | model_url: str,
150 | eai_token: str = None,
151 | n_retry_server: int = 4,
152 | ):
153 | """
154 | Initializes the CustomLLMChatbot with the specified configurations.
155 | 
156 | Args:
157 | model_name (str): The path to the model checkpoint.
158 | eai_token (str, optional): Token used to authenticate against a private inference endpoint. Defaults to None.
159 | hf_hosted (bool, optional): Whether the model is hosted on HuggingFace Hub. Defaults to False.
160 | temperature (float, optional): Sampling temperature. Defaults to 0.1.
161 | max_new_tokens (int, optional): Maximum length for the response. Defaults to 64.
162 | model_url (str, optional): The url of the model to use. If None, the model is loaded from model_name. Defaults to None.
163 | """
164 | super().__init__()
165 | 
166 | self.n_retry_server = n_retry_server
167 | 
168 | if max_new_tokens is None:
169 | max_new_tokens = max_total_tokens - max_input_tokens
170 | logging.warning(
171 | f"max_new_tokens is not specified. Setting it to {max_new_tokens} (max_total_tokens - max_input_tokens)."
172 | )
173 | 
174 | self.tokenizer = AutoTokenizer.from_pretrained(model_name)
175 | if isinstance(self.tokenizer, GPT2TokenizerFast):
176 | # TODO: make this less hacky once tokenizer.apply_chat_template is more mature
177 | logging.warning(
178 | f"No chat template is defined for {model_name}. Resolving to the hard-coded templates."
179 | )
180 | self.tokenizer = None
181 | self.prompt_template = get_prompt_template(model_name)
182 | 
183 | if temperature < 1e-3:
184 | logging.warning(
185 | "Sampling can misbehave when the temperature is too low for some models."
186 | ) 187 | 188 | model_kwargs = { 189 | "temperature": temperature, 190 | } 191 | 192 | if model_url is not None: 193 | logging.info("Loading the LLM from a URL") 194 | client = InferenceClient(model=model_url, token=eai_token) 195 | self.llm = partial( 196 | client.text_generation, temperature=temperature, max_new_tokens=max_new_tokens 197 | ) 198 | elif hf_hosted: 199 | logging.info("Serving the LLM on HuggingFace Hub") 200 | model_kwargs["max_length"] = max_new_tokens 201 | self.llm = HuggingFaceHub(repo_id=model_name, model_kwargs=model_kwargs) 202 | else: 203 | logging.info("Loading the LLM locally") 204 | pipe = pipeline( 205 | task="text-generation", 206 | model=model_name, 207 | device_map="auto", 208 | max_new_tokens=max_new_tokens, 209 | model_kwargs=model_kwargs, 210 | ) 211 | self.llm = HuggingFacePipeline(pipeline=pipe) 212 | 213 | def _call( 214 | self, 215 | messages: List[BaseMessage], 216 | stop: Optional[List[str]] = None, 217 | run_manager: Optional[CallbackManagerForLLMRun] = None, 218 | **kwargs: Any, 219 | ) -> str: 220 | if stop is not None or run_manager is not None or kwargs: 221 | logging.warning( 222 | "The `stop`, `run_manager`, and `kwargs` arguments are ignored in this implementation." 223 | ) 224 | 225 | if self.tokenizer: 226 | messages_formated = _convert_messages_to_dict(messages) 227 | prompt = self.tokenizer.apply_chat_template(messages_formated, tokenize=False) 228 | 229 | elif self.prompt_template: 230 | prompt = self.prompt_template.construct_prompt(messages) 231 | 232 | itr = 0 233 | while True: 234 | try: 235 | response = self.llm(prompt) 236 | return response 237 | except Exception as e: 238 | if itr == self.n_retry_server - 1: 239 | raise e 240 | logging.warning( 241 | f"Failed to get a response from the server: \n{e}\n" 242 | f"Retrying... ({itr+1}/{self.n_retry_server})" 243 | ) 244 | time.sleep(5) 245 | itr += 1 246 | 247 | def _llm_type(self): 248 | return "huggingface" 249 | 250 | 251 | def _convert_messages_to_dict(messages): 252 | """ 253 | Converts a list of message objects into a list of dictionaries, categorizing each message by its role. 254 | 255 | Each message is expected to be an instance of one of the following types: SystemMessage, HumanMessage, AIMessage. 256 | The function maps each message to its corresponding role ('system', 'user', 'assistant') and formats it into a dictionary. 257 | 258 | Args: 259 | messages (list): A list of message objects. 260 | 261 | Returns: 262 | list: A list of dictionaries where each dictionary represents a message and contains 'role' and 'content' keys. 263 | 264 | Raises: 265 | ValueError: If an unsupported message type is encountered. 
266 | 267 | Example: 268 | >>> messages = [SystemMessage("System initializing..."), HumanMessage("Hello!"), AIMessage("How can I assist?")] 269 | >>> _convert_messages_to_dict(messages) 270 | [ 271 | {"role": "system", "content": "System initializing..."}, 272 | {"role": "user", "content": "Hello!"}, 273 | {"role": "assistant", "content": "How can I assist?"} 274 | ] 275 | """ 276 | 277 | # Mapping of message types to roles 278 | message_type_to_role = { 279 | SystemMessage: "system", 280 | HumanMessage: "user", 281 | AIMessage: "assistant", 282 | } 283 | 284 | chat = [] 285 | for message in messages: 286 | message_role = message_type_to_role.get(type(message)) 287 | if message_role: 288 | chat.append({"role": message_role, "content": message.content}) 289 | else: 290 | raise ValueError(f"Message type {type(message)} not supported") 291 | 292 | return chat 293 | -------------------------------------------------------------------------------- /webarena/agents/legacy/utils/llm_utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import json 3 | from pathlib import Path 4 | import re 5 | import time 6 | from warnings import warn 7 | import logging 8 | 9 | from functools import cache 10 | import numpy as np 11 | import tiktoken 12 | import yaml 13 | from langchain_openai import ChatOpenAI 14 | 15 | from langchain.schema import SystemMessage, HumanMessage 16 | from openai import BadRequestError 17 | from joblib import Memory 18 | from transformers import AutoModel 19 | from transformers import AutoTokenizer 20 | import io 21 | import base64 22 | from PIL import Image 23 | from openai import RateLimitError 24 | 25 | 26 | def _extract_wait_time(error_message, min_retry_wait_time=60): 27 | """Extract the wait time from an OpenAI RateLimitError message.""" 28 | match = re.search(r"try again in (\d+(\.\d+)?)s", error_message) 29 | if match: 30 | return max(min_retry_wait_time, float(match.group(1))) 31 | return min_retry_wait_time 32 | 33 | 34 | def retry( 35 | chat: ChatOpenAI, 36 | messages, 37 | n_retry, 38 | parser, 39 | log=True, 40 | min_retry_wait_time=60, 41 | rate_limit_max_wait_time=60 * 30, 42 | ): 43 | """Retry querying the chat models with the response from the parser until it 44 | returns a valid value. 45 | 46 | If the answer is not valid, it will retry and append to the chat the retry 47 | message. It will stop after `n_retry`. 48 | 49 | Note, each retry has to resend the whole prompt to the API. This can be slow 50 | and expensive. 51 | 52 | Parameters: 53 | ----------- 54 | chat (function) : a langchain ChatOpenAI taking a list of messages and 55 | returning a list of answers. 56 | messages (list) : the list of messages so far. 57 | n_retry (int) : the maximum number of sequential retries. 58 | parser (function): a function taking a message and returning a tuple 59 | with the following fields: 60 | value : the parsed value, 61 | valid : a boolean indicating if the value is valid, 62 | retry_message : a message to send to the chat if the value is not valid 63 | log (bool): whether to log the retry messages. 64 | min_retry_wait_time (float): the minimum wait time in seconds 65 | after RateLimtError. will try to parse the wait time from the error 66 | message. 
67 | 
68 | Returns:
69 | --------
70 | value: the parsed value
71 | """
72 | tries = 0
73 | rate_limit_total_delay = 0
74 | while tries < n_retry and rate_limit_total_delay < rate_limit_max_wait_time:
75 | try:
76 | answer = chat.invoke(messages)
77 | except RateLimitError as e:
78 | wait_time = _extract_wait_time(e.args[0], min_retry_wait_time)
79 | logging.warning(f"RateLimitError, waiting {wait_time}s before retrying.")
80 | time.sleep(wait_time)
81 | rate_limit_total_delay += wait_time
82 | if rate_limit_total_delay >= rate_limit_max_wait_time:
83 | logging.warning(
84 | f"Total wait time for rate limit exceeded. Waited {rate_limit_total_delay}s > {rate_limit_max_wait_time}s."
85 | )
86 | raise
87 | continue
88 | 
89 | messages.append(answer)
90 | 
91 | value, valid, retry_message = parser(answer.content)
92 | if valid:
93 | return value
94 | 
95 | tries += 1
96 | if log:
97 | msg = f"Query failed. Retrying {tries}/{n_retry}.\n[LLM]:\n{answer.content}\n[User]:\n{retry_message}"
98 | logging.info(msg)
99 | messages.append(HumanMessage(content=retry_message))
100 | 
101 | raise ValueError(f"Could not parse a valid value after {n_retry} retries.")
102 | 
103 | 
104 | def retry_parallel(chat: ChatOpenAI, messages, n_retry, parser):
105 | """Retry querying the chat models with the response from the parser until it returns a valid value.
106 | 
107 | It will stop after `n_retry`. It assumes that chat will generate n_parallel answers for each message.
108 | The best answer is selected according to the score returned by the parser. If no answer is valid,
109 | it will retry with the best answer so far and append to the chat the retry message. If there is a
110 | single parallel generation, it behaves like retry.
111 | 
112 | This function is, in principle, more robust than retry. The speed and cost overhead is minimal when
113 | the prompt is large and the length of the generated message is small.
114 | 
115 | Parameters:
116 | -----------
117 | chat (function) : a langchain ChatOpenAI taking a list of messages and returning a list of answers.
118 | The number of parallel generations is specified at the creation of the chat object.
119 | messages (list) : the list of messages so far.
120 | n_retry (int) : the maximum number of sequential retries.
121 | parser (function): a function taking a message and returning a tuple with the following fields:
122 | value : the parsed value,
123 | valid : a boolean indicating if the value is valid,
124 | retry_message : a message to send to the chat if the value is not valid,
125 | score : a score to select the best answer from the parallel generations
126 | 
127 | Returns:
128 | --------
129 | value: the parsed value
130 | """
131 | 
132 | for i in range(n_retry):
133 | try:
134 | answers = chat.generate([messages]).generations[0] # chat.n parallel completions
135 | except BadRequestError as e:
136 | # most likely, the added messages triggered a message too long error
137 | # we thus retry without the last two messages
138 | if i == 0:
139 | raise e
140 | msg = "BadRequestError, most likely the message is too long; retrying without the last two messages."
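# Added note (not in the original file): `messages = messages[:-2]` below drops
# the previously appended assistant answer and retry instruction, shrinking the
# context before querying again.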
141 | warn(msg)
142 | messages = messages[:-2]
143 | answers = chat.generate([messages]).generations[0]
144 | 
145 | values, valids, retry_messages, scores = zip(
146 | *[parser(answer.message.content) for answer in answers]
147 | )
148 | idx = np.argmax(scores)
149 | value = values[idx]
150 | valid = valids[idx]
151 | retry_message = retry_messages[idx]
152 | answer = answers[idx].message
153 | 
154 | if valid:
155 | return value
156 | 
157 | msg = f"Query failed. Retrying {i+1}/{n_retry}.\n[LLM]:\n{answer.content}\n[User]:\n{retry_message}"
158 | warn(msg)
159 | messages.append(answer) # already of type AIMessage
160 | messages.append(SystemMessage(content=retry_message))
161 | 
162 | raise ValueError(f"Could not parse a valid value after {n_retry} retries.")
163 | 
164 | 
165 | def truncate_tokens(text, max_tokens=8000, start=0, model_name="gpt-4"):
166 | """Use tiktoken to truncate a text to a maximum number of tokens."""
167 | enc = tiktoken.encoding_for_model(model_name)
168 | tokens = enc.encode(text)
169 | if len(tokens) - start > max_tokens:
170 | return enc.decode(tokens[start : (start + max_tokens)])
171 | else:
172 | return text
173 | 
174 | 
175 | @cache
176 | def get_tokenizer(model_name="openai/gpt-4"):
177 | if model_name.startswith("openai"):
178 | return tiktoken.encoding_for_model(model_name.split("/")[-1])
179 | else:
180 | return AutoTokenizer.from_pretrained(model_name)
181 | 
182 | 
183 | def count_tokens(text, model="openai/gpt-4"):
184 | enc = get_tokenizer(model)
185 | return len(enc.encode(text))
186 | 
187 | 
188 | def count_messages_token(messages, model="openai/gpt-4"):
189 | """Count the number of tokens in a list of messages.
190 | 
191 | Args:
192 | messages (list): a list of messages, each message can be a string or a
193 | list of dicts or an object with a content attribute.
194 | model (str): the model to use for tokenization.
195 | 
196 | Returns:
197 | int: the number of tokens.
198 | """
199 | token_count = 0
200 | for message in messages:
201 | if hasattr(message, "content"):
202 | message = message.content
203 | 
204 | if isinstance(message, str):
205 | token_count += count_tokens(message, model)
206 | # handles messages with image content
207 | elif isinstance(message, (list, tuple)):
208 | for part in message:
209 | if not isinstance(part, dict):
210 | raise ValueError(
211 | f"The message is expected to be a list of dicts, but got list of {type(part)}"
212 | )
213 | if part["type"] == "text":
214 | token_count += count_tokens(part["text"], model)
215 | else:
216 | raise ValueError(
217 | f"The message is expected to be a string or a list of dicts, but got {type(message)}"
218 | )
219 | return token_count
220 | 
221 | 
222 | def json_parser(message):
223 | """Parse a json message for the retry function."""
224 | 
225 | try:
226 | value = json.loads(message)
227 | valid = True
228 | retry_message = ""
229 | except json.JSONDecodeError as e:
230 | warn(e)
231 | value = {}
232 | valid = False
233 | retry_message = "Your response is not a valid json. Please try again and be careful with the format. Don't add any apology or comment, just the answer."
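# Illustrative behaviour of json_parser (added example, not original code;
# the retry_message in the second case is shortened here):
#   json_parser('{"a": 1}')  -> ({'a': 1}, True, "")
#   json_parser('not json')  -> ({}, False, "Your response is not a valid json. ...")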
234 | return value, valid, retry_message 235 | 236 | 237 | def yaml_parser(message): 238 | """Parse a yaml message for the retry function.""" 239 | 240 | # saves gpt-3.5 from some yaml parsing errors 241 | message = re.sub(r":\s*\n(?=\S|\n)", ": ", message) 242 | 243 | try: 244 | value = yaml.safe_load(message) 245 | valid = True 246 | retry_message = "" 247 | except yaml.YAMLError as e: 248 | warn(str(e)) 249 | value = {} 250 | valid = False 251 | retry_message = "Your response is not a valid yaml. Please try again and be careful to the format. Don't add any apology or comment, just the answer." 252 | return value, valid, retry_message 253 | 254 | 255 | def _compress_chunks(text, identifier, skip_list, split_regex="\n\n+"): 256 | """Compress a string by replacing redundant chunks by identifiers. Chunks are defined by the split_regex.""" 257 | text_list = re.split(split_regex, text) 258 | text_list = [chunk.strip() for chunk in text_list] 259 | counter = collections.Counter(text_list) 260 | def_dict = {} 261 | id = 0 262 | 263 | # Store items that occur more than once in a dictionary 264 | for item, count in counter.items(): 265 | if count > 1 and item not in skip_list and len(item) > 10: 266 | def_dict[f"{identifier}-{id}"] = item 267 | id += 1 268 | 269 | # Replace redundant items with their identifiers in the text 270 | compressed_text = "\n".join(text_list) 271 | for key, value in def_dict.items(): 272 | compressed_text = compressed_text.replace(value, key) 273 | 274 | return def_dict, compressed_text 275 | 276 | 277 | def compress_string(text): 278 | """Compress a string by replacing redundant paragraphs and lines with identifiers.""" 279 | 280 | # Perform paragraph-level compression 281 | def_dict, compressed_text = _compress_chunks( 282 | text, identifier="§", skip_list=[], split_regex="\n\n+" 283 | ) 284 | 285 | # Perform line-level compression, skipping any paragraph identifiers 286 | line_dict, compressed_text = _compress_chunks( 287 | compressed_text, "¶", list(def_dict.keys()), split_regex="\n+" 288 | ) 289 | def_dict.update(line_dict) 290 | 291 | # Create a definitions section 292 | def_lines = [""] 293 | for key, value in def_dict.items(): 294 | def_lines.append(f"{key}:\n{value}") 295 | def_lines.append("") 296 | definitions = "\n".join(def_lines) 297 | 298 | return definitions + "\n" + compressed_text 299 | 300 | 301 | def extract_html_tags(text, keys): 302 | """Extract the content within HTML tags for a list of keys. 303 | 304 | Parameters 305 | ---------- 306 | text : str 307 | The input string containing the HTML tags. 308 | keys : list of str 309 | The HTML tags to extract the content from. 310 | 311 | Returns 312 | ------- 313 | dict 314 | A dictionary mapping each key to a list of subset in `text` that match the key. 315 | 316 | Notes 317 | ----- 318 | All text and keys will be converted to lowercase before matching. 
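Example (illustrative, added; not in the original docstring)
-------
>>> extract_html_tags("<action>click('12')</action>", ["action"])
{'action': ["click('12')"]}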
319 | 
320 | """
321 | content_dict = {}
322 | # text = text.lower()
323 | # keys = set([k.lower() for k in keys])
324 | for key in keys:
325 | pattern = f"<{key}>(.*?)</{key}>"
326 | matches = re.findall(pattern, text, re.DOTALL)
327 | if matches:
328 | content_dict[key] = [match.strip() for match in matches]
329 | return content_dict
330 | 
331 | 
332 | class ParseError(Exception):
333 | pass
334 | 
335 | 
336 | def parse_html_tags_raise(text, keys=(), optional_keys=(), merge_multiple=False):
337 | """A version of parse_html_tags that raises an exception if the parsing is not successful."""
338 | content_dict, valid, retry_message = parse_html_tags(
339 | text, keys, optional_keys, merge_multiple=merge_multiple
340 | )
341 | if not valid:
342 | raise ParseError(retry_message)
343 | return content_dict
344 | 
345 | 
346 | def parse_html_tags(text, keys=(), optional_keys=(), merge_multiple=False):
347 | """Satisfy the parse API: extract one match per key and validate that all required keys are present.
348 | 
349 | Parameters
350 | ----------
351 | text : str
352 | The input string containing the HTML tags.
353 | keys : list of str
354 | The HTML tags to extract the content from.
355 | optional_keys : list of str
356 | The HTML tags to extract the content from; these keys are optional.
357 | 
358 | Returns
359 | -------
360 | dict
361 | A dictionary mapping each key to the subset of `text` that matches the key.
362 | bool
363 | Whether the parsing was successful.
364 | str
365 | A message to be displayed to the agent if the parsing was not successful.
366 | """
367 | all_keys = tuple(keys) + tuple(optional_keys)
368 | content_dict = extract_html_tags(text, all_keys)
369 | retry_messages = []
370 | 
371 | for key in all_keys:
372 | if key not in content_dict:
373 | if key not in optional_keys:
374 | retry_messages.append(f"Missing the key <{key}> in the answer.")
375 | else:
376 | val = content_dict[key]
377 | content_dict[key] = val[0]
378 | if len(val) > 1:
379 | if not merge_multiple:
380 | retry_messages.append(
381 | f"Found multiple instances of the key {key}. You should have only one of them."
382 | )
383 | else:
384 | # merge the multiple instances
385 | content_dict[key] = "\n".join(val)
386 | 
387 | valid = len(retry_messages) == 0
388 | retry_message = "\n".join(retry_messages)
389 | return content_dict, valid, retry_message
390 | 
391 | 
392 | class ChatCached:
393 | # I wish I could extend ChatOpenAI, but it is somehow locked, I don't know if it's pydantic sorcery.
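# Usage sketch (added illustration, not part of the original file; assumes an
# OpenAI key is configured in the environment):
#   chat = ChatCached(ChatOpenAI(model_name="gpt-4"))
#   answer = chat(messages)  # repeated identical calls are served from the joblib cache on disk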
394 | 395 | def __init__(self, chat, memory=None): 396 | self.chat = chat 397 | self.memory = memory if memory else Memory(location=Path.home() / "llm-cache", verbose=10) 398 | self._call = self.memory.cache(self.chat.__call__, ignore=["self"]) 399 | self._generate = self.memory.cache(self.chat.generate, ignore=["self"]) 400 | 401 | def __call__(self, messages): 402 | return self._call(messages) 403 | 404 | def generate(self, messages): 405 | return self._generate(messages) 406 | 407 | 408 | def download_and_save_model(model_name: str, save_dir: str = "."): 409 | model = AutoModel.from_pretrained(model_name) 410 | model.save_pretrained(save_dir) 411 | print(f"Model downloaded and saved to {save_dir}") 412 | 413 | 414 | def image_to_jpg_base64_url(image: np.ndarray | Image.Image): 415 | """Convert a numpy array to a base64 encoded image url.""" 416 | 417 | if isinstance(image, np.ndarray): 418 | image = Image.fromarray(image) 419 | if image.mode in ("RGBA", "LA"): 420 | image = image.convert("RGB") 421 | buffered = io.BytesIO() 422 | image.save(buffered, format="JPEG") 423 | 424 | image_base64 = base64.b64encode(buffered.getvalue()).decode() 425 | return f"data:image/jpeg;base64,{image_base64}" 426 | -------------------------------------------------------------------------------- /webarena/agents/legacy/utils/prompt_templates.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from langchain.schema import BaseMessage, SystemMessage, HumanMessage, AIMessage 4 | from dataclasses import dataclass 5 | 6 | """ 7 | To use this class, you should have the ``openai`` python package installed, and the 8 | environment variable ``OPENAI_API_KEY`` set with your API key. 9 | """ 10 | 11 | 12 | @dataclass 13 | class PromptTemplate: 14 | """ 15 | Base class for prompt templates. 16 | 17 | Defines a standard interface for prompt templates, ensuring that they contain 18 | the required fields for the CustomLLMChatbot. 19 | """ 20 | 21 | system: str 22 | human: str 23 | ai: str 24 | prompt_end: str = "" 25 | 26 | def format_message(self, message: BaseMessage) -> str: 27 | """ 28 | Formats a given message based on its type. 29 | 30 | Args: 31 | message (BaseMessage): The message to be formatted. 32 | 33 | Returns: 34 | str: The formatted message. 35 | 36 | Raises: 37 | ValueError: If the message type is not supported. 38 | """ 39 | if isinstance(message, SystemMessage): 40 | return self.system.format(input=message.content) 41 | elif isinstance(message, HumanMessage): 42 | return self.human.format(input=message.content) 43 | elif isinstance(message, AIMessage): 44 | return self.ai.format(input=message.content) 45 | else: 46 | raise ValueError(f"Message type {type(message)} not supported") 47 | 48 | def construct_prompt(self, messages: List[BaseMessage]) -> str: 49 | """ 50 | Constructs a prompt from a list of messages. 51 | 52 | Args: 53 | messages (List[BaseMessage]): The list of messages to be formatted. 54 | 55 | Returns: 56 | str: The constructed prompt. 
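Example (illustrative doctest added for clarity, using the STARCHAT
template defined below):
    >>> STARCHAT_PROMPT_TEMPLATE.construct_prompt([HumanMessage(content="hi")])
    '<|user|>\nhi<|end|>\n<|assistant|>'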
57 | """ 58 | if not all(isinstance(m, BaseMessage) for m in messages): 59 | raise ValueError("All elements in the list must be of type BaseMessage") 60 | 61 | prompt = "".join([self.format_message(m) for m in messages]) 62 | prompt += self.prompt_end 63 | return prompt 64 | 65 | 66 | def get_prompt_template(model_name): 67 | for key, value in MODEL_PREFIX_TO_PROMPT_TEMPLATES.items(): 68 | if key in model_name: 69 | return value 70 | raise NotImplementedError(f"Model {model_name} has no supported chat template") 71 | 72 | 73 | ## Prompt templates 74 | 75 | STARCHAT_PROMPT_TEMPLATE = PromptTemplate( 76 | system="<|system|>\n{input}<|end|>\n", 77 | human="<|user|>\n{input}<|end|>\n", 78 | ai="<|assistant|>\n{input}<|end|>\n", 79 | prompt_end="<|assistant|>", 80 | ) 81 | 82 | 83 | ## Model prefix to prompt template mapping 84 | 85 | MODEL_PREFIX_TO_PROMPT_TEMPLATES = { 86 | "starcoder": STARCHAT_PROMPT_TEMPLATE, 87 | "starchat": STARCHAT_PROMPT_TEMPLATE, 88 | } 89 | -------------------------------------------------------------------------------- /webarena/autoeval/clients.py: -------------------------------------------------------------------------------- 1 | import os 2 | import base64 3 | import openai 4 | import numpy as np 5 | from PIL import Image 6 | from typing import Union, Optional 7 | from openai import OpenAI, ChatCompletion 8 | openai.api_key = os.environ["OPENAI_API_KEY"] 9 | openai.organization = os.environ.get("OPENAI_ORGANIZATION", "") 10 | client = OpenAI() 11 | 12 | 13 | class LM_Client: 14 | def __init__(self, model_name: str = "gpt-3.5-turbo") -> None: 15 | self.model_name = model_name 16 | 17 | def chat(self, messages, json_mode: bool = False) -> tuple[str, ChatCompletion]: 18 | """ 19 | messages=[ 20 | {"role": "system", "content": "You are a helpful assistant."}, 21 | {"role": "user", "content": "hi"}, 22 | ]) 23 | """ 24 | chat_completion = client.chat.completions.create( 25 | model=self.model_name, 26 | messages=messages, 27 | response_format={"type": "json_object"} if json_mode else None, 28 | temperature=0, 29 | ) 30 | response = chat_completion.choices[0].message.content 31 | return response, chat_completion 32 | 33 | def one_step_chat( 34 | self, text, system_msg: str = None, json_mode=False 35 | ) -> tuple[str, ChatCompletion]: 36 | messages = [] 37 | if system_msg is not None: 38 | messages.append({"role": "system", "content": system_msg}) 39 | messages.append({"role": "user", "content": text}) 40 | return self.chat(messages, json_mode=json_mode) 41 | 42 | 43 | class GPT4V_Client: 44 | def __init__(self, model_name: str = "gpt-4o", max_tokens: int = 512): 45 | self.model_name = model_name 46 | self.max_tokens = max_tokens 47 | 48 | def encode_image(self, path: str): 49 | with open(path, 'rb') as f: 50 | return base64.b64encode(f.read()).decode('utf-8') 51 | 52 | def one_step_chat( 53 | self, text, image: Union[Image.Image, np.ndarray], 54 | system_msg: Optional[str] = None, 55 | ) -> tuple[str, ChatCompletion]: 56 | jpg_base64_str = self.encode_image(image) 57 | messages = [] 58 | if system_msg is not None: 59 | messages.append({"role": "system", "content": system_msg}) 60 | messages += [{ 61 | "role": "user", 62 | "content": [ 63 | {"type": "text", "text": text}, 64 | {"type": "image_url", 65 | "image_url": {"url": f"data:image/jpeg;base64,{jpg_base64_str}"},}, 66 | ], 67 | }] 68 | response = client.chat.completions.create( 69 | model=self.model_name, 70 | messages=messages, 71 | max_tokens=self.max_tokens, 72 | ) 73 | return 
response.choices[0].message.content, response 74 | 75 | 76 | CLIENT_DICT = { 77 | "gpt-3.5-turbo": LM_Client, 78 | "gpt-4": LM_Client, 79 | "gpt-4o": GPT4V_Client, 80 | } -------------------------------------------------------------------------------- /webarena/autoeval/evaluate_trajectory.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import traceback 5 | from autoeval.evaluator import Evaluator 6 | from autoeval.clients import CLIENT_DICT 7 | 8 | 9 | def load_blocks(path: str) -> list[list[str]]: 10 | """Load blank-line separated blocks from the log file.""" 11 | blocks, block = [], [] 12 | for line in open(path, 'r'): 13 | if line.strip() == "": 14 | blocks.append(block) 15 | block = [] 16 | else: 17 | if line.strip(): 18 | block.append(line.strip()) 19 | assert len(blocks) % 2 == 0 20 | return blocks 21 | 22 | def remove_invalid_steps(actions: list[str]) -> list[str]: 23 | """Remove invalid steps from the action sequence.""" 24 | valid_actions = [] 25 | for a in actions: 26 | if "click(" in a: 27 | arg = a[a.index("(")+1: a.index(")")] 28 | if type(eval(arg)) == str: 29 | valid_actions.append(a) 30 | elif "fill(" in a: 31 | arg = a[a.index("(")+1: a.index(",")].strip() 32 | if type(eval(arg)) == str: 33 | valid_actions.append(a) 34 | else: 35 | valid_actions.append(a) 36 | return valid_actions 37 | 38 | def extract_think_and_action(path: str) -> tuple[list[str], list[str]]: 39 | """Extract the task trajectory from the log file.""" 40 | blocks = load_blocks(path) 41 | think_list, action_list = [], [] 42 | for i in range(1, len(blocks), 2): 43 | # action 44 | b = blocks[i] 45 | actions = remove_invalid_steps(b[1:]) 46 | if len(actions) == 0: continue 47 | action_list.append(actions) 48 | # think 49 | b = blocks[i-1] 50 | idx = b[-1].index("browsergym.experiments.loop - INFO -") 51 | think_list.append(b[-1][idx+36: ].strip()) 52 | 53 | assert len(think_list) == len(action_list) 54 | 55 | # TODO: merge same actions 56 | return think_list, action_list 57 | 58 | def extract_response(action: str) -> str: 59 | s, e = action.index("(")+1, action.index(")") 60 | return action[s: e] 61 | 62 | 63 | def process_sample( 64 | idx: str, traj_info: dict, log_save_path, 65 | model: str, eval_version: str, 66 | ) -> list[dict]: 67 | clients = {model: CLIENT_DICT[model](model_name=model)} 68 | evaluator = Evaluator(clients, log_save_path=log_save_path + "/trajs") 69 | try: 70 | out, _ = evaluator(traj_info, model, eval_version) 71 | eval_result = None 72 | if out["status"].lower() == "success": eval_result = True 73 | else: eval_result = False 74 | return [{ 75 | "idx": idx, 76 | "gt": traj_info["eval"], 77 | "rm": eval_result, 78 | "thoughts": out["thoughts"], 79 | "uid": traj_info["traj_name"], 80 | }] 81 | except Exception as e: 82 | print(f"Error on {idx}, {e}") 83 | print(traceback.format_exc()) 84 | return { 85 | "idx": idx, 86 | "gt": traj_info["eval"], 87 | "rm": None, 88 | "thoughts": None, 89 | "uid": traj_info["traj_name"], 90 | } 91 | 92 | 93 | def main(): 94 | # load task config 95 | task_id = args.result_dir.split('/')[-1].split(".")[1] 96 | config_path = os.path.join("config_files", f"{task_id}.json") 97 | config = json.load(open(config_path)) 98 | 99 | # load trajectory log 100 | log_path = os.path.join(args.result_dir, "experiment.log") 101 | think_list, action_list = extract_think_and_action(log_path) 102 | actions = [act for acts in action_list for act in acts] 103 | if "send_msg_to_user" in 
action_list[-1][0]:
104 | response = extract_response(action_list[-1][0])
105 | else:
106 | response = ""
107 | 
108 | # load summary info
109 | summary_path = os.path.join(args.result_dir, "summary_info.json")
110 | summary = json.load(open(summary_path, 'r'))
111 | 
112 | # collect traj info
113 | image_paths = [
114 | os.path.join(args.result_dir, f) for f in os.listdir(args.result_dir)
115 | if f.startswith("screenshot_step_") and f.endswith(".jpg")
116 | ]
117 | image_paths = sorted(image_paths, key=lambda x: int(x.split('/')[-1].split("_")[-1].split(".")[0]))
118 | traj_info = {
119 | "intent": config["intent"],
120 | "response": response,
121 | "captions": think_list,
122 | "actions": actions,
123 | "traj_name": config["task_id"],
124 | "image_paths": image_paths,
125 | "images": image_paths,
126 | "eval": summary["cum_reward"]
127 | }
128 | 
129 | # evaluate trajectory
130 | log_save_path = os.path.join("autoeval/log", args.result_dir.split('/')[-1])
131 | print("Log Save Path:", log_save_path)
132 | if not os.path.exists(log_save_path):
133 | os.makedirs(log_save_path)
134 | os.makedirs(log_save_path + "/trajs")
135 | eval_info = process_sample(
136 | idx=config["task_id"], traj_info=traj_info,
137 | log_save_path=log_save_path,
138 | model=args.model, eval_version=args.prompt,
139 | )
140 | output_eval_path = os.path.join(args.result_dir, f"{args.model}_autoeval.json")
141 | json.dump(eval_info, open(output_eval_path, 'w'))
142 | 
143 | 
144 | if __name__ == "__main__":
145 | parser = argparse.ArgumentParser()
146 | parser.add_argument("--result_dir", type=str, required=True,
147 | help="Path to the result directory, e.g., 'webarena.0'.")
148 | # autoeval
149 | parser.add_argument("--model", type=str, default="gpt-3.5-turbo",
150 | choices=["gpt-3.5-turbo", "gpt-4", "gpt-4o"])
151 | parser.add_argument("--prompt", type=str, default="text",
152 | choices=["text", "vision"])
153 | 
154 | args = parser.parse_args()
155 | 
156 | if args.model == "gpt-4o" and args.prompt != "vision":
157 | print(f"Warning: use vision prompt by default for {args.model}.")
158 | args.prompt = "vision"
159 | 
160 | main()
161 | 
--------------------------------------------------------------------------------
/webarena/autoeval/evaluator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from autoeval.prompts import *
4 | 
5 | 
6 | class Evaluator:
7 | def __init__(self, lm_clients, log_save_path=None):
8 | self.lm_clients = lm_clients
9 | self.log_save_path = log_save_path
10 | 
11 | def __call__(self, info, client="gpt-3.5-turbo", version="naive"):
12 | assert (client in self.lm_clients), \
13 | f"Client {client} not found in {self.lm_clients.keys()}"
14 | if version == "text":
15 | eval_info, eval_str, prompt = self.eval_text(info, client)
16 | elif version == "vision":
17 | eval_info, eval_str, prompt = self.eval_vision(info, client)
18 | else:
19 | raise NotImplementedError(f"Version {version} not implemented")
20 | 
21 | if self.log_save_path:
22 | with open(self.log_save_path + "/outputs.jsons", "a") as f:
23 | f.write(
24 | json.dumps(
25 | {
26 | "id": info["traj_name"],
27 | "eval_info": eval_info,
28 | }
29 | )
30 | + "\n"
31 | )
32 | with open(f"{self.log_save_path}/{info['traj_name']}.md", "w") as md_file:
33 | md_file.write(f"## Intent\n\n{info['intent']}\n\n")
34 | md_file.write(f"## RM\n\n{eval_str}\n\n")
35 | md_file.write(f"## Final Response\n\n{info['response']}\n\n")
36 | if "captions" in info and info['captions'] is not None:
37 |
md_file.write("## Captions\n\n") 38 | for idx, cap in enumerate(info["captions"]): 39 | md_file.write(f"===============") 40 | md_file.write(f"{cap}\n") 41 | md_file.write("\n## Images\n\n") 42 | for idx, img in enumerate(info["image_paths"]): 43 | rel_img_path = os.path.relpath(img, self.log_save_path) 44 | md_file.write(f"![Image {idx+1}]({rel_img_path})\n") 45 | 46 | if "config" in info: 47 | md_file.write("## Config\n\n") 48 | cofig_str = json.dumps(info["config"], indent=4) 49 | md_file.write(f"```json\n{cofig_str}\n```\n") 50 | if prompt: 51 | md_file.write("## Main Prompt\n\n") 52 | md_file.write(f"```md\n{prompt}\n```\n") 53 | return eval_info, prompt 54 | 55 | def eval_text(self, info, client): 56 | response = info["response"] if info["response"] else "None" 57 | lm_client = self.lm_clients[client] 58 | action_history = "" 59 | for idx, act in enumerate(info["actions"]): 60 | action_history += f"{idx+1}: {act}\n" 61 | prompt, sys_msg = build_text_eval_prompt( 62 | info["captions"][-1], info["intent"], response, action_history 63 | ) 64 | msg_str, _ = lm_client.one_step_chat(prompt, system_msg=sys_msg) 65 | msg_dict = { 66 | "thoughts": extract_content(msg_str, "Thoughts:"), 67 | "status": extract_content(msg_str, "Status:").replace('"', ""), 68 | } 69 | return msg_dict, msg_str, prompt 70 | 71 | def eval_vision(self, info, client): 72 | assert client == "gpt-4v" or client == "gpt-4o" 73 | action_history = "" 74 | for idx, act in enumerate(info["actions"]): 75 | action_history += f"{idx+1}: {act}\n" 76 | prompt, sys_msg = build_vision_eval_prompt( 77 | info["intent"], info["response"], action_history 78 | ) 79 | img = info["images"][-1] 80 | 81 | lm_client = self.lm_clients[client] 82 | msg_str, _ = lm_client.one_step_chat( 83 | text=prompt, image=img, system_msg=sys_msg 84 | ) 85 | del info["images"] 86 | msg_dict = { 87 | "thoughts": extract_content(msg_str, "Thoughts:"), 88 | "status": extract_content(msg_str, "Status:").replace('"', ""), 89 | } 90 | return msg_dict, msg_str, prompt 91 | -------------------------------------------------------------------------------- /webarena/autoeval/prompts.py: -------------------------------------------------------------------------------- 1 | def build_obs_simplifier_prompt(cap, intent, response) -> str: 2 | prompt = f"""Given the following user question and context, extract part of the context that is unbiased, so that using that text alone would be good context for providing an unbiased answer to the user query. 3 | 4 | **User Query**: The bot responded with "{response}", does it execute this task "{intent}" successfully? 5 | 6 | **Full Context**: 7 | ```md 8 | {cap} 9 | ``` 10 | 11 | Start your answer with “Unbiased text context (includes all relevant content):" 12 | """ 13 | return prompt 14 | 15 | 16 | def build_naive_last_frame_eval_prompt(cap, intent, response) -> str: 17 | prompt = f"""**User Intent**: {intent} 18 | 19 | **Bot's Final Observation**: 20 | 21 | ```md 22 | {cap} 23 | ``` 24 | 25 | **Bot's response to the user**: {response if response else "None"}. 26 | 27 | --- 28 | 29 | Based on the provided user intent, the caption of bot's final observation and its response, did the bot successfully execute the task? Please reason step by step. 30 | 31 | Note: 32 | - The trajectory descriptions are essentially noisy captions of the screenshots captured during bot's execution. And you should infer what actions the bot took yourself. 
33 | - You should categorize the execution into one of the three status: 34 | - task-possible-bot-success: The bot successfully executed the task. 35 | - task-possible-bot-fail: The bot failed to execute the task. 36 | - task-impossible: The task is impossible to execute in nature given the user intent and the environment. For example, if the user wants to buy a product that does not exist in the environment. You should carefully distinguish this from bot-fail. 37 | 38 | Format your response as a valid json: 39 | {{ 40 | "thoughts": "{{Your thoughts here, discuss if and how the trajectory progress towards the task and then reason about the final status. You should provide an explicit reason when determining the final status.}}", 41 | "status": "task-possible-bot-success" or "task-possible-bot-fail" or "task-impossible" 42 | }}""" 43 | return prompt 44 | 45 | 46 | def build_naive_multi_frame_eval_prompt(caps, intent, response) -> str: 47 | captions_str = "\n".join( 48 | [f"{idx+1}:\n```md\n{caption}\n```\n" for idx, caption in enumerate(caps[-3:])] 49 | ) 50 | prompt = f"""**User Intent**: {intent} 51 | 52 | **Bot's observation through execution**: 53 | 54 | {captions_str} 55 | 56 | **Bot's response to the user**: {response if response else "None"}. 57 | 58 | --- 59 | 60 | Based on the provided user intent, bot's observation in captions and its response, did the bot successfully execute the task? Please reason step by step. 61 | 62 | Note: 63 | - You should categorize the execution into one of the three status: 64 | - task-possible-bot-success: The bot successfully executed the task. 65 | - task-possible-bot-fail: The bot failed to execute the task. 66 | - task-impossible: The task is impossible to execute in nature given the user intent and the environment. For example, if the user wants to buy a product that does not exist in the environment. You should carefully distinguish this from bot-fail. 67 | 68 | Format your response as a valid json: 69 | {{ 70 | "thoughts": "{{Your thoughts here, discuss if and how the trajectory progress towards the task and then reason about the final status. You should provide an explicit reason when determining the final status.}}", 71 | "status": "task-possible-bot-success" or "task-possible-bot-fail" or "task-impossible" 72 | }}""" 73 | return prompt 74 | 75 | 76 | def extract_content(text, start_tag): 77 | """ 78 | Extract the content that follows 'Info:' in a given string. 79 | 80 | :param text: A string that may contain lines starting with 'Info:' 81 | :return: The content that follows 'Info:' or None if not found 82 | """ 83 | # Split the text into lines 84 | lines = text.split("\n") 85 | 86 | # Loop through each line to find a line that starts with 'Info:' 87 | for line in lines: 88 | if line.startswith(start_tag): 89 | # Extract and return the content after 'Info:' 90 | return line[len(start_tag) :].strip() 91 | 92 | # Return None if 'Info:' is not found in any line 93 | return "" 94 | 95 | 96 | def build_text_eval_prompt( 97 | cap, intent, response, last_actions 98 | ) -> tuple[str, str]: 99 | system_msg = """You are an expert in evaluating the performance of a web navigation agent. The agent is designed to help a human user navigate a website to complete a task. Given the user's intent, the agent's action history, the final state of the webpage, and the agent's response to the user, your goal is to decide whether the agent's execution is successful or not. 100 | 101 | There are three types of tasks: 102 | 1. 
Information seeking: The user wants to obtain certain information from the webpage, such as the information of a product, reviews, map info, comparison of map routes, etc. The bot's response must contain the information the user wants, or explicitly state that the information is not available. Otherwise, e.g. the bot encounters an exception and responds with the error content, the task is considered a failure. Besides, be careful about the sufficiency of the agent's actions. For example, when asked to list the top-searched items in a shop, the agent should order the items by the number of searches, and then return the top items. If the ordering action is missing, the task is likely to fail.
103 | 2. Site navigation: The user wants to navigate to a specific page. Carefully examine the bot's action history and the final state of the webpage to determine whether the bot successfully completes the task. No need to consider the bot's response.
104 | 3. Content modification: The user wants to modify the content of a webpage or configuration. Carefully examine the bot's action history and the final state of the webpage to determine whether the bot successfully completes the task. No need to consider the bot's response.
105 | 
106 | *IMPORTANT*
107 | Format your response into two lines as shown below:
108 | 
109 | Thoughts: <your thoughts and reasoning process>
110 | Status: "success" or "failure"
111 | """
112 | prompt = f"""User Intent: {intent}
113 | 
114 | Action History:
115 | {last_actions}
116 | 
117 | The detailed final state of the webpage:
118 | 
119 | ```md
120 | {cap}
121 | ```
122 | 
123 | Bot response to the user: {response if response else "N/A"}."""
124 | return prompt, system_msg
125 | 
126 | 
127 | def build_vision_eval_prompt(
128 | intent, response, last_actions
129 | ) -> tuple[str, str]:
130 | system_msg = """You are an expert in evaluating the performance of a web navigation agent. The agent is designed to help a human user navigate a website to complete a task. Given the user's intent, the agent's action history, the final state of the webpage, and the agent's response to the user, your goal is to decide whether the agent's execution is successful or not.
131 | 
132 | There are three types of tasks:
133 | 1. Information seeking: The user wants to obtain certain information from the webpage, such as the information of a product, reviews, map info, comparison of map routes, etc. The bot's response must contain the information the user wants, or explicitly state that the information is not available. Otherwise, e.g. the bot encounters an exception and responds with the error content, the task is considered a failure. Besides, be careful about the sufficiency of the agent's actions. For example, when asked to list the top-searched items in a shop, the agent should order the items by the number of searches, and then return the top items. If the ordering action is missing, the task is likely to fail.
134 | 2. Site navigation: The user wants to navigate to a specific page. Carefully examine the bot's action history and the final state of the webpage to determine whether the bot successfully completes the task. No need to consider the bot's response.
135 | 3. Content modification: The user wants to modify the content of a webpage or configuration. Carefully examine the bot's action history and the final state of the webpage to determine whether the bot successfully completes the task. No need to consider the bot's response.
136 | 137 | *IMPORTANT* 138 | Format your response into two lines as shown below: 139 | 140 | Thoughts: 141 | Status: "success" or "failure" 142 | """ 143 | prompt = f"""User Intent: {intent} 144 | 145 | Action History: 146 | {last_actions} 147 | 148 | The last snapshot of the web page is shown in the image.""" 149 | return prompt, system_msg 150 | -------------------------------------------------------------------------------- /webarena/autoeval/requirements.txt: -------------------------------------------------------------------------------- 1 | openai 2 | requests 3 | pillow 4 | bs4 5 | matplotlib 6 | termcolor 7 | human_id 8 | pandas 9 | easy_ocr 10 | einops 11 | transformers_stream_generator 12 | tiktoken -------------------------------------------------------------------------------- /webarena/config_files/generate_test_data.py: -------------------------------------------------------------------------------- 1 | """Replace the website placeholders with website domains from env_config 2 | Generate the test data""" 3 | 4 | import os 5 | import json 6 | 7 | 8 | def main() -> None: 9 | with open("test.raw.json", "r") as f: 10 | raw = f.read() 11 | raw = raw.replace("__GITLAB__", os.environ.get("GITLAB")) 12 | raw = raw.replace("__REDDIT__", os.environ.get("REDDIT")) 13 | raw = raw.replace("__SHOPPING__", os.environ.get("SHOPPING")) 14 | raw = raw.replace("__SHOPPING_ADMIN__", os.environ.get("SHOPPING_ADMIN")) 15 | raw = raw.replace("__WIKIPEDIA__", os.environ.get("WIKIPEDIA")) 16 | raw = raw.replace("__MAP__", os.environ.get("MAP")) 17 | with open("test.json", "w") as f: 18 | f.write(raw) 19 | # split to multiple files 20 | data = json.loads(raw) 21 | for idx, item in enumerate(data): 22 | with open(f"{idx}.json", "w") as f: 23 | json.dump(item, f, indent=2) 24 | 25 | 26 | if __name__ == "__main__": 27 | main() -------------------------------------------------------------------------------- /webarena/induce_prompt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import argparse 5 | 6 | import openai 7 | openai.api_key = os.environ["OPENAI_API_KEY"] 8 | from openai import OpenAI 9 | client = OpenAI() 10 | 11 | # %% load examples 12 | def load_blocks(path: str) -> list[list[str]]: 13 | """Load blank-line separated blocks from the log file.""" 14 | blocks, block = [], [] 15 | for line in open(path, 'r'): 16 | if line.strip() == "": 17 | blocks.append(block) 18 | block = [] 19 | else: 20 | if line.strip(): 21 | block.append(line.strip()) 22 | assert len(blocks) % 2 == 0 23 | return blocks 24 | 25 | def remove_invalid_steps(actions: list[str]) -> list[str]: 26 | """Remove invalid steps from the action sequence.""" 27 | valid_actions = [] 28 | for a in actions: 29 | if "click(" in a: 30 | arg = a[a.index("(")+1: a.index(")")] 31 | try: 32 | if type(eval(arg)) == str and type(eval(arg[1:-1])) == int: 33 | valid_actions.append(a) 34 | except: 35 | continue 36 | elif "fill(" in a: 37 | arg = a[a.index("(")+1: a.index(",")].strip() 38 | if type(eval(arg)) == str: 39 | valid_actions.append(a) 40 | elif "scroll(" in a or "noop(" in a: 41 | continue 42 | else: 43 | valid_actions.append(a) 44 | return valid_actions 45 | 46 | def extract_think_and_action(path: str) -> tuple[list[str], list[str]]: 47 | """Extract the task trajectory from the log file.""" 48 | blocks = load_blocks(path) 49 | think_list, action_list = [], [] 50 | for i in range(1, len(blocks), 2): 51 | # action 52 | b = blocks[i] 53 | actions = 
remove_invalid_steps(b[1:])
54 | if len(actions) == 0: continue
55 | action_list.append(actions)
56 | # think
57 | b = blocks[i-1]
58 | idx = b[-1].index("browsergym.experiments.loop - INFO -")
59 | think_list.append(b[-1][idx+36: ].strip())
60 | 
61 | assert len(think_list) == len(action_list)
62 | 
63 | # TODO: merge same actions
64 | return think_list, action_list
65 | 
66 | def format_trajectory(think_list: list[str], action_list: list[list[str]]) -> str:
67 | trajectory = []
68 | for t, a in zip(think_list, action_list):
69 | acts = '\n'.join(a)
70 | trajectory.append(f"<think>\n{t}\n</think>\n\n<action>\n{acts}\n</action>")
71 | return '\n\n'.join(trajectory)
72 | 
73 | def random_group_sample(d: dict, n) -> list:
74 | """Randomly sample up to n examples from each group in the dictionary."""
75 | return [ex for v in d.values() for ex in random.sample(v, min(n, len(v)))]
76 | 
77 | # %% prompt model
78 | def format_examples(examples: list[dict]) -> str:
79 | """Format examples to the prompt."""
80 | formatted_examples = []
81 | for ex in examples:
82 | trajectory = format_trajectory(ex["think_list"], ex["action_list"])
83 | formatted_examples.append(f"Query: {ex['query']}\nActions:\n{trajectory}")
84 | return '\n\n'.join(["## Concrete Examples"] + formatted_examples + ["## Summary Workflows"])
85 | 
86 | def llm_generate(examples: list[dict], args, verbose: bool = False):
87 | """Call gpt model to generate workflows."""
88 | prompt = format_examples(examples)
89 | prompt = '\n\n'.join([args.INSTRUCTION, args.ONE_SHOT, prompt])
90 | if verbose: print("Prompt:\n", prompt, '\n\n')
91 | response = client.chat.completions.create(
92 | model=args.model,
93 | messages=[{"role": "user", "content": prompt}],
94 | temperature=1.0,
95 | max_tokens=2048,
96 | )
97 | response = response.choices[0].message.content
98 | if verbose: print(response)
99 | return response
100 | 
101 | 
102 | def main():
103 | # collect result directories, e.g., ["results/webarena.0", ...]
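# Added note (not in the original file): --result_dir is a single space-separated
# string such as "results/run1 results/run2"; split() below turns it into a list of paths.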
102 | def main():
103 |     # collect result directories, e.g., ["results/webarena.0", ...]
104 |     args.result_dir = args.result_dir.split()
105 |     if args.criteria == "gt":
106 |         file_dirs = [
107 |             os.path.join(res_dir, f) for res_dir in args.result_dir for f in os.listdir(res_dir)
108 |             if json.load(
109 |                 open(os.path.join(res_dir, f, "summary_info.json"))
110 |             )["cum_reward"]
111 |         ]
112 |     elif args.criteria == "autoeval":
113 |         file_dirs = []
114 |         for res_dir in args.result_dir:
115 |             for f in os.listdir(res_dir):
116 |                 record_path = os.path.join(res_dir, f, f"{args.model}_autoeval.json")
117 |                 if not os.path.exists(record_path): continue
118 |                 record = json.load(open(record_path))
119 |                 if record[0]["rm"]:  # keep runs the autoeval judge marked successful
120 |                     file_dirs.append(os.path.join(res_dir, f))
121 |     else:
122 |         raise ValueError(f"Invalid criteria: {args.criteria}.")
123 | 
124 |     print(f"Collected {len(file_dirs)} result directories.")
125 | 
126 |     # template id based deduplication
127 |     template_dict = {}
128 |     for f in file_dirs:
129 |         # get the query, i.e., the task objective, from the task config
130 |         task_id = f.split('/')[-1].split("_")[0].split(".")[1]
131 |         config_path = os.path.join("config_files", f"{task_id}.json")
132 |         config = json.load(open(config_path))
133 |         query = config["intent"]
134 | 
135 |         template_id = config["intent_template_id"]  # for deduplication
136 | 
137 |         # parse trajectory
138 |         log_path = os.path.join(f, "experiment.log")
139 |         try:
140 |             think_list, action_list = extract_think_and_action(log_path)
141 |         except Exception:  # skip runs whose logs cannot be parsed
142 |             continue
143 | 
144 |         # add to template dict
145 |         wdict = {"query": query, "think_list": think_list, "action_list": action_list}
146 |         if template_id not in template_dict: template_dict[template_id] = []
147 |         template_dict[template_id].append(wdict)
148 |     selected_examples = random_group_sample(template_dict, args.num_samples)
149 |     print(f"#{len(selected_examples)} result dirs after template dedup..")
150 | 
151 |     # prompt model to induce workflows
152 |     workflows = llm_generate(selected_examples, args, True)
153 |     workflows += "\n\nclick('id') # input string id value for all actions\n\nselect_option('id', 'value') # for dropdown menu"
154 | 
155 |     if args.output_path is None:
156 |         website = config["sites"][0]  # assumes all results are about the same website
157 |         args.output_path = f"workflow/{website}_neural.txt"
158 |         print(f"[Warning] no output path specified, using '{args.output_path}' by default")
159 | 
160 |     with open(args.output_path, 'w') as fw:
161 |         fw.write(workflows)
162 | 
163 | 
164 | if __name__ == "__main__":
165 |     parser = argparse.ArgumentParser()
166 |     parser.add_argument("--result_dir", type=str, default="results",
167 |                         help="Path to the result directory. Supports multiple directories separated by spaces.")
168 |     parser.add_argument("--output_path", type=str, default=None,
169 |                         help="Path to the output file.")
170 |     parser.add_argument("--criteria", type=str, default="autoeval",
171 |                         choices=["gt", "autoeval"],
172 |                         help="'gt': only use examples with gold reward; 'autoeval': use examples with autoeval reward.")
173 |     parser.add_argument("--model", type=str, default="gpt-4o",
174 |                         choices=["gpt-3.5", "gpt-4", "gpt-4o"])
175 |     parser.add_argument("--num_samples", type=int, default=1, help="Max number of samples to input per template.")
176 |     args = parser.parse_args()
177 | 
178 |     args.INSTRUCTION = open("prompt/instruction.txt", 'r').read()
179 |     args.ONE_SHOT = open("prompt/one_shot.txt", 'r').read()
180 | 
181 |     main()
182 | 
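A typical offline induction call over one or more result directories might look like the following sketch (the directory and output names are hypothetical; note that --result_dir takes a single space-separated string, which main() splits):

from subprocess import Popen

# Induce workflows from two runs and write them to a single workflow file.
Popen([
    "python", "induce_prompt.py",
    "--result_dir", "results_a results_b",   # one string, split on spaces inside main()
    "--criteria", "autoeval",
    "--model", "gpt-4o",
    "--output_path", "workflow/shopping_neural.txt",
]).wait()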
--------------------------------------------------------------------------------
/webarena/induce_rule.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | import argparse
5 | 
6 | 
7 | def load_blocks(path: str) -> list[list[str]]:
8 |     """Load blank-line separated blocks from the log file."""
9 |     blocks, block = [], []
10 |     for line in open(path, 'r'):
11 |         if line.strip() == "":
12 |             blocks.append(block)
13 |             block = []
14 |         else:
15 |             # non-empty line: accumulate onto the current block
16 |             block.append(line.strip())
17 |     assert len(blocks) % 2 == 0  # blocks alternate: think, action, think, action, ...
18 |     return blocks
19 | 
20 | def remove_invalid_steps(actions: list[str]) -> list[str]:
21 |     """Remove invalid steps from the action sequence."""
22 |     valid_actions = []
23 |     for a in actions:
24 |         if "click(" in a:
25 |             arg = a[a.index("(")+1: a.index(")")]
26 |             try:  # keep clicks whose argument is a quoted integer element id
27 |                 if isinstance(eval(arg), str) and isinstance(eval(arg[1:-1]), int):
28 |                     valid_actions.append(a)
29 |             except Exception:
30 |                 continue
31 |         elif "fill(" in a:
32 |             arg = a[a.index("(")+1: a.index(",")].strip()
33 |             if isinstance(eval(arg), str):
34 |                 valid_actions.append(a)
35 |         else:
36 |             valid_actions.append(a)
37 |     return valid_actions
38 | 
39 | def extract_think_and_action(path: str) -> tuple[list[str], list[list[str]]]:
40 |     """Extract the task trajectory from the log file."""
41 |     blocks = load_blocks(path)
42 |     think_list, action_list = [], []
43 |     for i in range(1, len(blocks), 2):
44 |         # action
45 |         b = blocks[i]
46 |         actions = remove_invalid_steps(b[1:])
47 |         if len(actions) == 0: continue
48 |         action_list.append(actions)
49 |         # think: strip the logger prefix from the last line of the preceding block
50 |         b = blocks[i-1]
51 |         marker = "browsergym.experiments.loop - INFO -"
52 |         think_list.append(b[-1][b[-1].index(marker) + len(marker):].strip())
53 | 
54 |     assert len(think_list) == len(action_list)
55 | 
56 |     # TODO: merge same actions
57 |     return think_list, action_list
58 | 
59 | def format_trajectory(think_list: list[str], action_list: list[list[str]]) -> str:
60 |     trajectory = []
61 |     for t, a in zip(think_list, action_list):
62 |         acts = '\n'.join(a)
63 |         trajectory.append(f"<think>\n{t}\n</think>\n<action>\n{acts}\n</action>")
64 |     return '\n\n'.join(trajectory)
65 | 
66 | def get_abstract_trajectory(action_list: list[list[str]]) -> str:
67 |     abstract = []
68 |     for acts in action_list:
69 |         for a in acts:
70 |             s = a.index("(")
71 |             e = a.index(',', s) if ',' in a[s:] else a.index(")", s)
72 |             action = a[:s]
73 |             if action != "send_msg_to_user":
74 |                 arg = a[s+1: e]
75 |                 abstract.append(f"{action}({arg})")
76 |             else:
77 |                 abstract.append(f"{action}")
78 |     return '_'.join(abstract)
79 | 
80 | def random_group_sample(d: dict, n: int) -> list:
81 |     """Randomly sample up to n examples from each group in the dictionary."""
82 |     return [ex for v in d.values() for ex in random.sample(v, min(n, len(v)))]
83 | 
84 | 
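# Illustration (made-up trajectory): the abstract signature keeps the action
# name plus its first argument (the element id), so runs that differ only in
# free text collapse onto the same key for deduplication:
#   get_abstract_trajectory([["click('227')", "click('1843')"], ["send_msg_to_user('done')"]])
#   returns "click('227')_click('1843')_send_msg_to_user"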
85 | def main():
86 |     # collect result directories, e.g., ["results/webarena.0", ...]
87 |     args.result_dir = args.result_dir.split()
88 |     if args.criteria == "gt":
89 |         file_dirs = [
90 |             os.path.join(res_dir, f) for res_dir in args.result_dir for f in os.listdir(res_dir)
91 |             if json.load(
92 |                 open(os.path.join(res_dir, f, "summary_info.json"))
93 |             )["cum_reward"]
94 |         ]
95 |     elif args.criteria == "autoeval":
96 |         file_dirs = []
97 |         for res_dir in args.result_dir:
98 |             for f in os.listdir(res_dir):
99 |                 record_path = os.path.join(res_dir, f, f"{args.model}_autoeval.json")
100 |                 if not os.path.exists(record_path): continue
101 |                 record = json.load(open(record_path))
102 |                 if record[0]["rm"]:  # keep runs the autoeval judge marked successful
103 |                     file_dirs.append(os.path.join(res_dir, f))
104 |     else:
105 |         raise ValueError(f"Invalid criteria: {args.criteria}.")
106 | 
107 |     print(f"Collected {len(file_dirs)} result directories.")
108 | 
109 |     # template id based deduplication
110 |     template_dict = {}
111 |     for f in file_dirs:
112 |         # get the query, i.e., the task objective, from the task config
113 |         task_id = f.split('/')[-1].split("_")[0].split(".")[1]
114 |         config_path = os.path.join("config_files", f"{task_id}.json")
115 |         config = json.load(open(config_path))
116 |         query = config["intent"]
117 | 
118 |         template_id = config["intent_template_id"]  # for deduplication
119 | 
120 |         # parse trajectory
121 |         log_path = os.path.join(f, "experiment.log")
122 |         try:
123 |             think_list, action_list = extract_think_and_action(log_path)
124 |         except Exception:  # skip runs whose logs cannot be parsed
125 |             continue
126 | 
127 |         # add to template dict
128 |         wdict = {"query": query, "think_list": think_list, "action_list": action_list}
129 |         if template_id not in template_dict: template_dict[template_id] = []
130 |         template_dict[template_id].append(wdict)
131 |     selected_workflows = random_group_sample(template_dict, 1)
132 |     print(f"#{len(selected_workflows)} result dirs after template dedup..")
133 | 
134 |     # deduplicate by abstract trajectory
135 |     abstraj_dict = {}
136 |     for w in selected_workflows:
137 |         abs_traj = get_abstract_trajectory(w['action_list'])
138 |         if abs_traj not in abstraj_dict:
139 |             abstraj_dict[abs_traj] = []
140 |         abstraj_dict[abs_traj].append(w)
141 |     selected_workflows = random_group_sample(abstraj_dict, 1)
142 |     print(f"#{len(selected_workflows)} result dirs after trajectory dedup..")
143 | 
144 |     # manual inspection
145 |     def get_workflow(d: dict) -> str:
146 |         return f"Query: {d['query']}\n" + format_trajectory(d['think_list'], d['action_list'])
147 |     manual_workflows = []
148 |     for w in selected_workflows:
149 |         w = get_workflow(w)
150 |         if args.auto:
151 |             to_add = 'y'
152 |         else:
153 |             to_add = input("Workflow: \n" + w + "\n\nAdd? (y/n): ")
154 |         if to_add == 'y':
155 |             manual_workflows.append(w)
156 |     print(f"#{len(manual_workflows)} result dirs after manual inspection..")
157 | 
158 | 
159 | 
160 |     if args.output_path is None:
161 |         website = config["sites"][0]  # assumes all results are about the same website
162 |         args.output_path = f"workflow/{website}.txt"
163 |         print(f"[Warning] no output path specified, using '{args.output_path}' by default")
164 | 
165 |     with open(args.output_path, 'w') as fw:
166 |         fw.write('\n\n\n'.join(["## Concrete Examples"] + manual_workflows))
167 | 
168 | 
169 | if __name__ == "__main__":
170 |     parser = argparse.ArgumentParser()
171 |     parser.add_argument("--result_dir", type=str, default="results",
172 |                         help="Path to the result directory. Supports multiple directories separated by spaces.")
173 |     parser.add_argument("--output_path", type=str, default=None,
174 |                         help="Path to the output file.")
175 |     parser.add_argument("--criteria", type=str, default="autoeval",
176 |                         choices=["gt", "autoeval"],
177 |                         help="'gt': only use examples with gold reward; 'autoeval': use examples with autoeval reward.")
178 |     parser.add_argument("--model", type=str, default="gpt-3.5-turbo",
179 |                         choices=["gpt-3.5", "gpt-4", "gpt-4o"])
180 |     parser.add_argument("--auto", action="store_true", help="Skip manual workflow inspection.")
181 |     args = parser.parse_args()
182 | 
183 |     main()
184 | 
--------------------------------------------------------------------------------
/webarena/pipeline.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | from subprocess import Popen
5 | 
6 | def main():
7 |     # collect the task ids of all examples for the target website
8 |     config_files = [
9 |         os.path.join("config_files", f) for f in os.listdir("config_files")
10 |         if f.endswith(".json") and f.split(".")[0].isdigit()
11 |     ]
12 |     config_files = sorted(config_files, key=lambda x: int(x.split("/")[-1].split(".")[0]))
13 |     config_list = [json.load(open(f)) for f in config_files]
14 |     config_flags = [config["sites"][0] == args.website for config in config_list]
15 |     task_ids = [config["task_id"] for config, flag in zip(config_list, config_flags) if flag]
16 | 
17 |     if args.end_index is None: args.end_index = len(task_ids)
18 |     for tid in task_ids[args.start_index: args.end_index]:
19 |         # step 1: run inference ("--task" abbreviates run.py's "--task_name" via argparse prefix matching)
20 |         process = Popen([
21 |             "python", "run.py",
22 |             "--task", f"webarena.{tid}",
23 |             "--workflow_path", f"workflow/{args.website}.txt"
24 |         ])
25 |         process.wait()
26 | 
27 |         # step 2: run evaluation
28 |         process = Popen([
29 |             "python", "-m", "autoeval.evaluate_trajectory",
30 |             "--result_dir", f"results/webarena.{tid}"
31 |         ])
32 |         process.wait()
33 | 
34 |         # step 3: update the workflow memory from all results so far
35 |         process = Popen([
36 |             "python", "induce_prompt.py",
37 |             "--result_dir", "results",
38 |             "--output_path", f"workflow/{args.website}.txt",
39 |         ])
40 |         process.wait()
41 | 
42 | 
43 | if __name__ == "__main__":
44 |     parser = argparse.ArgumentParser()
45 |     parser.add_argument("--website", type=str, required=True,
46 |                         choices=["shopping", "shopping_admin", "gitlab", "reddit", "map"])
47 |     parser.add_argument("--start_index", type=int, default=0)
48 |     parser.add_argument("--end_index", type=int, default=None)
49 |     args = parser.parse_args()
50 | 
51 |     main()
52 | 
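Because each iteration shells out to separate processes, an interrupted run can be resumed by skipping tasks that already have results. A minimal sketch (assuming the results/webarena.<task_id> layout above and the gpt-4o autoeval record name; the helper name is hypothetical):

import os

def unfinished(task_ids: list[int], model: str = "gpt-4o") -> list[int]:
    """Keep only task ids that have no completed autoeval record yet."""
    todo = []
    for tid in task_ids:
        record = os.path.join("results", f"webarena.{tid}", f"{model}_autoeval.json")
        if not os.path.exists(record):
            todo.append(tid)
    return todo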
--------------------------------------------------------------------------------
/webarena/prompt/instruction.txt:
--------------------------------------------------------------------------------
1 | Given a list of web navigation tasks, your task is to extract the common workflows needed to solve these tasks.
2 | Each given task contains a natural language instruction and a series of actions that solve the task. You need to find the repetitive subsets of actions across multiple tasks, and extract each of them as a workflow.
3 | Each workflow should be a commonly-reused sub-routine of the tasks. Do not generate similar or overlapping workflows. Each workflow should have at least two steps. Represent the non-fixed elements (input text, button strings) with descriptive variable names, as shown in the example.
4 | Keep the values of invariant elements, e.g., the id of "Search" or "Customers", as these are shared and stay invariant across tasks.
5 | Try to generate as many workflows as needed to cover all the tasks in the input list.
--------------------------------------------------------------------------------
/webarena/prompt/one_shot.txt:
--------------------------------------------------------------------------------
1 | ## Concrete Examples
2 | 
3 | Query: What is the date when I made my first purchase on this site?
4 | Actions:
5 | <think>
6 | To find the date of the first purchase, I need to navigate to the order history section of the user's account. I will start by clicking on the "My Account" link.
7 | </think>
8 | <action>
9 | click('227')
10 | </action>
11 | 
12 | <think>
13 | To find the date of the first purchase, I need to navigate to the "My Orders" section where the order history is listed. From there, I can look for the earliest order date. I will start by clicking on the "My Orders" link in the left sidebar.
14 | </think>
15 | <action>
16 | click('1843')
17 | </action>
18 | 
19 | <think>
20 | To find the date of the first purchase, I need to look at the earliest date in the "Date" column of the "Orders" table. The current page shows the first 10 orders, but there are multiple pages. I will navigate to the last page to find the earliest order date.
21 | </think>
22 | <action>
23 | click('1788')
24 | </action>
25 | 
26 | <think>
27 | To find the date of the first purchase, I need to look for the earliest date in the "Date" column of the "My Orders" table. From the current view, the earliest date visible is "5/2/23". However, there might be more orders on other pages. I will navigate to the last page to ensure I find the earliest date.
28 | </think>
29 | <action>
30 | click('1784')
31 | </action>
32 | 
33 | <think>
34 | To find the earliest purchase date, I need to look at the "Date" column in the "Orders" table and identify the earliest date listed. The dates currently visible are: 4/27/22, 4/5/22, 3/10/22, 3/2/22, 3/2/22. The earliest date visible is 3/2/22. I will respond with this date.
35 | </think>
36 | <action>
37 | send_msg_to_user('The earliest purchase date is 3/2/22.')
38 | </action>
39 | 
40 | Query: I want to browse the products in the Cabinets, Racks & Shelves category
41 | 
42 | <think>
43 | To browse the products in the Cabinets, Racks & Shelves category, I need to navigate to the "Home & Kitchen" section, which likely contains the Cabinets, Racks & Shelves category. I will start by hovering over the "Home & Kitchen" menu item to reveal the subcategories.
44 | </think>
45 | <action>
46 | hover('735')
47 | </action>
48 | 
49 | <think>
50 | To browse the products in the Cabinets, Racks & Shelves category, I need to navigate to the "Storage & Organization" section under the "Home & Kitchen" menu. I will start by hovering over the "Storage & Organization" menu item to reveal the subcategories.
51 | </think>
52 | <action>
53 | hover('786')
54 | </action>
55 | 
56 | <think>
57 | To browse the products in the Cabinets, Racks & Shelves category, I need to click on the "Racks, Shelves & Drawers" link under the "Storage & Organization" menu. This will take me to the desired category.
58 | </think>
59 | <action>
60 | click('797')
61 | </action>
62 | 
63 | <think>
64 | The current page is displaying the "Racks, Shelves & Drawers" category under "Storage & Organization" in the "Home & Kitchen" section. The user wants to browse products in the "Cabinets, Racks & Shelves" category. Since "Racks, Shelves & Drawers" is already a subcategory of "Storage & Organization," I will confirm with the user if this is the correct category they want to browse.
65 | </think>
66 | <action>
67 | send_msg_to_user('You are currently viewing the "Racks, Shelves & Drawers" category under "Storage & Organization." Is this the correct category you want to browse, or are you looking for a different category?')
68 | </action>
69 | 
70 | 
71 | Query: Get the order number of my most recent cancelled order
72 | <think>
73 | To get the order number of the most recent cancelled order, I need to navigate to the "My Account" section where order history is typically managed. From the current page, I can see the "My Account" link in the top right corner. I will click on the "My Account" link to proceed.
74 | </think>
75 | <action>
76 | click('227')
77 | </action>
78 | 
79 | <think>
80 | To get the order number of the most recent cancelled order, I need to navigate to the "My Orders" section. From there, I can filter or sort the orders to find the most recent cancelled order. I will start by clicking on the "My Orders" link in the left sidebar.
81 | </think>
82 | <action>
83 | click('1843')
84 | </action>
85 | 
86 | <think>
87 | To get the order number of the most recent cancelled order, I need to look for the order with the status "Canceled" and note its order number. From the current page, I can see that the most recent cancelled order is order number 000000170.
88 | </think>
89 | <action>
90 | send_msg_to_user('The order number of your most recent cancelled order is 000000170.')
91 | </action>
92 | 
93 | 
94 | ## Summary Workflows
95 | 
96 | Workflow 1: Find all orders
97 | <think>
98 | To find the date of the first purchase, I need to navigate to the order history section of the user's account. I will start by clicking on the "My Account" link.
99 | </think>
100 | <action>
101 | click('227')
102 | </action>
103 | 
104 | <think>
105 | To find the date of the first purchase, I need to navigate to the "My Orders" section where the order history is listed. From there, I can look for the earliest order date. I will start by clicking on the "My Orders" link in the left sidebar.
106 | </think>
107 | <action>
108 | click('1843')
109 | </action>
110 | 
111 | Workflow 2: Obtain order information, such as date or order number
112 | <think>
113 | To find the earliest purchase date, I need to look at the "Date" column in the "Orders" table and identify the earliest date listed. The dates currently visible are: 4/27/22, 4/5/22, 3/10/22, 3/2/22, 3/2/22. The earliest date visible is 3/2/22. I will respond with this date.
114 | </think>
115 | <action>
116 | send_msg_to_user('The earliest purchase date is 3/2/2022.')
117 | </action>
118 | 
119 | Workflow 3: Browse products in a specific (sub-)category
120 | <think>
121 | To browse the products in the Cabinets, Racks & Shelves category, I need to navigate to the "Home & Kitchen" section, which likely contains the Cabinets, Racks & Shelves category. I will start by hovering over the "Home & Kitchen" menu item to reveal the subcategories.
122 | </think>
123 | <action>
124 | hover('735')
125 | </action>
126 | 
127 | <think>
128 | To browse the products in the Cabinets, Racks & Shelves category, I need to navigate to the "Storage & Organization" section under the "Home & Kitchen" menu. I will start by hovering over the "Storage & Organization" menu item to reveal the subcategories.
129 | </think>
130 | <action>
131 | hover('786')
132 | </action>
133 | 
134 | <think>
135 | To browse the products in the Cabinets, Racks & Shelves category, I need to click on the "Racks, Shelves & Drawers" link under the "Storage & Organization" menu. This will take me to the desired category.
136 | </think>
137 | <action>
138 | click('797')
139 | </action>
--------------------------------------------------------------------------------
/webarena/requirements.txt:
--------------------------------------------------------------------------------
1 | browsergym-core>=0.3
2 | browsergym-experiments>=0.3
3 | openai>=1.35.4,<1.36
4 | langchain>=0.2,<0.3
5 | langchain_openai>=0.1.10,<0.2
6 | tiktoken
7 | huggingface_hub
8 | contexttimer
9 | ipython
10 | pyyaml>=6
11 | pandas
12 | joblib
13 | transformers
14 | langchain_community>=0.2.6,<0.3
--------------------------------------------------------------------------------
/webarena/run.py:
--------------------------------------------------------------------------------
1 | """
2 | WARNING: DEPRECATED, WILL BE REMOVED SOON
3 | """
4 | 
5 | import os
6 | import argparse
7 | from pathlib import Path
8 | 
9 | from browsergym.experiments import ExpArgs, EnvArgs
10 | 
11 | from agents.legacy.agent import GenericAgentArgs
12 | from agents.legacy.dynamic_prompting import Flags
13 | from agents.legacy.utils.chat_api import ChatModelArgs
14 | 
15 | 
16 | def str2bool(v):
17 |     if isinstance(v, bool):
18 |         return v
19 |     if v.lower() in ("yes", "true", "t", "y", "1"):
20 |         return True
21 |     elif v.lower() in ("no", "false", "f", "n", "0"):
22 |         return False
23 |     else:
24 |         raise argparse.ArgumentTypeError("Boolean value expected.")
25 | 
26 | 
27 | def parse_args():
28 |     parser = argparse.ArgumentParser(description="Run experiment with hyperparameters.")
29 |     parser.add_argument(
30 |         "--model_name",
31 |         type=str,
32 |         default="openai/gpt-4o",
33 |         help="Model name for the chat model.",
34 |     )
35 |     parser.add_argument(
36 |         "--task_name",
37 |         type=str,
38 |         default="openended",
39 |         help="Name of the Browsergym task to run. If 'openended', you need to specify a 'start_url'.",
40 |     )
41 |     parser.add_argument(
42 |         "--start_url",
43 |         type=str,
44 |         default="https://www.google.com",
45 |         help="Starting URL (only for the openended task).",
46 |     )
47 |     parser.add_argument(
48 |         "--slow_mo", type=int, default=30, help="Slow motion delay for the playwright actions."
49 |     )
50 |     parser.add_argument(
51 |         "--headless",
52 |         type=str2bool,
53 |         default=False,
54 |         help="Run the experiment in headless mode (hides the browser windows).",
55 |     )
56 |     parser.add_argument(
57 |         "--demo_mode",
58 |         type=str2bool,
59 |         default=True,
60 |         help="Add visual effects when the agent performs actions.",
61 |     )
62 |     parser.add_argument(
63 |         "--use_html", type=str2bool, default=False, help="Use HTML in the agent's observation space."
64 |     )
65 |     parser.add_argument(
66 |         "--use_ax_tree",
67 |         type=str2bool,
68 |         default=True,
69 |         help="Use AX tree in the agent's observation space.",
70 |     )
71 |     parser.add_argument(
72 |         "--use_screenshot",
73 |         type=str2bool,
74 |         default=True,
75 |         help="Use screenshot in the agent's observation space.",
76 |     )
77 |     parser.add_argument(
78 |         "--multi_actions", type=str2bool, default=True, help="Allow multi-actions in the agent."
79 |     )
80 |     parser.add_argument(
81 |         "--action_space",
82 |         type=str,
83 |         default="bid",
84 |         choices=["python", "bid", "coord", "bid+coord", "bid+nav", "coord+nav", "bid+coord+nav"],
85 |         help="Action space for the agent.",
86 |     )
87 |     parser.add_argument(
88 |         "--use_history",
89 |         type=str2bool,
90 |         default=True,
91 |         help="Use history in the agent's observation space.",
92 |     )
93 |     parser.add_argument(
94 |         "--use_thinking",
95 |         type=str2bool,
96 |         default=True,
97 |         help="Use thinking in the agent (chain-of-thought prompting).",
98 |     )
99 |     parser.add_argument(
100 |         "--max_steps",
101 |         type=int,
102 |         default=10,
103 |         help="Maximum number of steps to take for each task.",
104 |     )
105 |     parser.add_argument(
106 |         "--workflow_path",
107 |         type=str,
108 |         default=None,
109 |         help="Path to the memory file to load for the agent.",
110 |     )
111 | 
112 |     return parser.parse_args()
113 | 
114 | 
115 | def main():
116 |     print(
117 |         """\
118 | WARNING this demo agent will soon be moved elsewhere. Expect it to be removed at some point."""
119 |     )
120 | 
121 |     args = parse_args()
122 |     if (args.workflow_path is not None) and (not os.path.exists(args.workflow_path)):
123 |         open(args.workflow_path, "w").close()  # start from an empty workflow memory file
124 | 
125 |     env_args = EnvArgs(
126 |         task_name=args.task_name,
127 |         task_seed=None,
128 |         max_steps=args.max_steps,
129 |         headless=args.headless,
130 |         viewport={"width": 1500, "height": 1280},
131 |         slow_mo=args.slow_mo,
132 |     )
133 | 
134 |     if args.task_name == "openended":
135 |         env_args.wait_for_user_message = True
136 |         env_args.task_kwargs = {"start_url": args.start_url}
137 | 
138 |     exp_args = ExpArgs(
139 |         env_args=env_args,
140 |         agent_args=GenericAgentArgs(
141 |             chat_model_args=ChatModelArgs(
142 |                 model_name=args.model_name,
143 |                 max_total_tokens=128_000,  # maximum total tokens for the chat model
144 |                 max_input_tokens=126_000,  # maximum tokens for the input to the chat model
145 |                 max_new_tokens=2_000,  # maximum new tokens to generate per model call
146 |             ),
147 |             flags=Flags(
148 |                 use_html=args.use_html,
149 |                 use_ax_tree=args.use_ax_tree,
150 |                 use_thinking=args.use_thinking,  # enable chain-of-thought prompting
151 |                 use_error_logs=True,  # prompt the agent with the error logs
152 |                 use_memory=False,  # enable the agent with a memory (scratchpad)
153 |                 use_history=args.use_history,
154 |                 use_diff=False,  # prompt the agent with the difference between the current and past observation
155 |                 use_past_error_logs=True,  # prompt the agent with the past error logs
156 |                 use_action_history=True,  # prompt the agent with the action history
157 |                 multi_actions=args.multi_actions,
158 |                 use_abstract_example=True,  # prompt the agent with an abstract example
159 |                 use_concrete_example=True,  # prompt the agent with a concrete example
160 |                 use_screenshot=args.use_screenshot,
161 |                 enable_chat=True,
162 |                 demo_mode="default" if args.demo_mode else "off",
163 |                 workflow_path=args.workflow_path,
164 |             ),
165 |         ),
166 |     )
167 | 
168 |     exp_args.prepare(Path("./results"))
169 |     exp_args.run()
170 | 
171 |     os.rename(exp_args.exp_dir, f"results/{args.task_name}")
172 | 
173 | 
174 | if __name__ == "__main__":
175 |     main()
176 | 
--------------------------------------------------------------------------------
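Taken together, run.py, the autoeval module, and induce_prompt.py form the loop that pipeline.py automates. A minimal sketch of one iteration for a single task (the task id and workflow path are hypothetical, following the conventions above):

from subprocess import Popen

task = "webarena.599"               # hypothetical task id
workflow = "workflow/shopping.txt"  # workflow memory file read and rewritten each iteration

for cmd in (
    ["python", "run.py", "--task_name", task, "--workflow_path", workflow],
    ["python", "-m", "autoeval.evaluate_trajectory", "--result_dir", f"results/{task}"],
    ["python", "induce_prompt.py", "--result_dir", "results", "--output_path", workflow],
):
    Popen(cmd).wait()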