├── .gitignore
├── LICENSE
├── README.md
├── assets
├── mind2web-results.jpg
├── online-offline.jpg
├── teaser.jpg
└── webarena-leaderboard.jpg
├── mind2web
├── README.md
├── data
│ └── memory
│ │ └── exemplars.json
├── memory.py
├── offline_induction.py
├── online_induction.py
├── pipeline.py
├── prompt
│ ├── instruction_abstract.txt
│ ├── instruction_action.txt
│ ├── one_shot_abstract.txt
│ └── one_shot_action.txt
├── requirements.txt
├── results
│ └── calc_score.py
├── run_mind2web.py
├── utils
│ ├── data.py
│ ├── env.py
│ └── llm.py
└── workflow
│ └── retrieve.py
└── webarena
├── README.md
├── agents
├── __init__.py
├── basic
│ ├── __init__.py
│ └── agent.py
└── legacy
│ ├── __init__.py
│ ├── agent.py
│ ├── dynamic_prompting.py
│ └── utils
│ ├── __init__.py
│ ├── chat_api.py
│ ├── llm_utils.py
│ └── prompt_templates.py
├── autoeval
├── clients.py
├── evaluate_trajectory.py
├── evaluator.py
├── prompts.py
└── requirements.txt
├── config_files
├── generate_test_data.py
├── test.json
└── test.raw.json
├── induce_prompt.py
├── induce_rule.py
├── pipeline.py
├── prompt
├── instruction.txt
└── one_shot.txt
├── requirements.txt
└── run.py
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__/**
2 | **/results/**
3 | webarena/config_files/**
4 | webarena/autoeval/log
5 | **/.DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
10 |
11 | ## Quickstart :boom:
12 | To run AWM on WebArena under `webarena/`:
13 | ```bash
14 | cd webarena
15 | python pipeline.py --website "shopping" # choose one from ['shopping', 'shopping_admin', 'reddit', 'gitlab', 'map']
16 | ```
17 |
18 | To run AWM on Mind2Web under `mind2web/`:
19 | ```bash
20 | cd mind2web
21 | python pipeline.py --setup "offline" # or "online"
22 | ```
23 | Check `webarena/` and `mind2web/` folders for more detailed instructions about environment and data setups.
24 |
25 | ## What is Agent Workflow Memory? 🧠
26 | Agent Workflow Memory (AWM) proposes to induce, integrate, and utilize workflows via an agent memory.
27 | A workflow is usually a common sub-routine in solving tasks, with example-specific contexts being abstracted out.
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 | AWM can operate in both offline and online settings:
36 | - *offline* (left): when additional (e.g., training) examples are available, agents induce workflows from ground-truth annotated examples
37 | - *online* (right): without any auxiliary data, agents induce workflows from past experiences on the fly.
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 | ## How does AWM work? 📈
46 |
47 | ### On WebArena
48 | We achieve the state-of-the-art result -- 35.6% success rate.
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 | Check the code in `./webarena/` directory.
57 |
58 | ### On Mind2Web
59 |
60 | We also get the best scores among text-based agents. Particularly, AWM offline effectively generalizes across a wide range of tasks, websites, and domains.
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 | Check the code in `./mind2web/` directory.
69 |
70 | ## Citation 📜
71 |
72 | ```bibtex
73 | @article{awm2024wang,
74 | title = {Agent Workflow Memory},
75 | author = {Wang, Zhiruo and Mao, Jiayuan and Fried, Daniel and Neubig, Graham},
76 | journal={arXiv preprint arXiv:2409.07429},
77 | year = {2024},
78 | }
79 | ```
80 |
--------------------------------------------------------------------------------
/assets/mind2web-results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zorazrw/agent-workflow-memory/907d3cbafcae021fe3a4577c5a10539752e63596/assets/mind2web-results.jpg
--------------------------------------------------------------------------------
/assets/online-offline.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zorazrw/agent-workflow-memory/907d3cbafcae021fe3a4577c5a10539752e63596/assets/online-offline.jpg
--------------------------------------------------------------------------------
/assets/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zorazrw/agent-workflow-memory/907d3cbafcae021fe3a4577c5a10539752e63596/assets/teaser.jpg
--------------------------------------------------------------------------------
/assets/webarena-leaderboard.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zorazrw/agent-workflow-memory/907d3cbafcae021fe3a4577c5a10539752e63596/assets/webarena-leaderboard.jpg
--------------------------------------------------------------------------------
/mind2web/README.md:
--------------------------------------------------------------------------------
1 | # AWM for Mind2Web
2 |
3 | ## Install
4 |
5 | ```bash
6 | pip install -r requirements.txt
7 | ```
8 |
9 | Download data from the [mind2web](https://github.com/OSU-NLP-Group/Mind2Web) project, make sure you have `test_task`, `test_website`, `test_domain`, and `train` under the `data` directory; download `scores_all_data.pkl` for HTML filtering at [[link]](https://buckeyemailosu-my.sharepoint.com/personal/deng_595_buckeyemail_osu_edu/_layouts/15/onedrive.aspx?id=%2Fpersonal%2Fdeng%5F595%5Fbuckeyemail%5Fosu%5Fedu%2FDocuments%2FMind2Web%2Fscores%5Fall%5Fdata%2Epkl&parent=%2Fpersonal%2Fdeng%5F595%5Fbuckeyemail%5Fosu%5Fedu%2FDocuments%2FMind2Web&ga=1).
10 |
11 | ## Offline Workflow Induction + Test Inference
12 |
13 | To run offline workflow induction with training examples:
14 | ```bash
15 | python offline_induction.py \
16 | --mode auto --domain Travel --subdomain Airlines --website aa \
17 | --model "gpt-4o" --output_dir "workflow"
18 | ```
19 | You can also switch to `--mode input` to dynamically input your desired website(s).
20 |
21 | The above command will produce a workflow file `workflow/aa.txt`, to augment this workflow in agent memory and run inference on test examples from the *aa* website:
22 |
23 | ```bash
24 | python run_mind2web.py --website "aa" --workflow_path "workflow/aa.txt"
25 | ```
26 |
27 | ## Online Induction with Test Queries
28 |
29 | To run online workflow induction and utilization:
30 | ```bash
31 | python pipeline.py --setup online \
32 | --benchmark "test_task" --website aa \
33 | --results_dir results/aa/workflow \
34 | --workflow_path workflow/aa.txt
35 | ```
36 |
37 | Simply change to `--benchmark 'train'` if you want to run the online setting on the training (or other) queries, but remember to apply the induced workflows and run inference on test examples afterwards.
38 |
39 |
40 | ## Overall
41 | To run the entire pipeline for both online and offline settings, you can use:
42 | ```bash
43 | python pipeline.py --setup "offline" # or "online"
44 | ```
45 | with other arguments specified as above.
46 |
--------------------------------------------------------------------------------
/mind2web/memory.py:
--------------------------------------------------------------------------------
1 | import os, json, random
2 | import numpy as np
3 | from pathlib import Path
4 | from openai import BadRequestError
5 | from utils.env import *
6 | from utils.llm import (
7 | generate_response, num_tokens_from_messages,
8 | MAX_TOKENS, extract_from_response,
9 | )
10 |
11 | import logging
12 | logger = logging.getLogger(__name__)
13 |
14 |
def get_exemplars(args) -> list:
    """Collect exemplar messages for the prompt: workflow memory plus concrete examples.

    Args:
        args: namespace with `workflow_path`, `memory_path`, `website`,
            `domain`, `subdomain`, and `retrieve_top_k` attributes.
    Rets:
        list of chat-format exemplars (each a list of role/content dicts),
        with the bookkeeping "specifier" key stripped out.
    """
    # workflow memory: a single user turn holding the induced workflow text, if any
    memory = []
    with open(args.workflow_path, 'r') as f:  # close the handle instead of leaking it
        workflow_text = f.read().strip()
    if workflow_text:
        memory = [[{"role": "user", "content": workflow_text}]]

    # concrete examples: prefer website-specific ones, then subdomain-specific ones
    with open(os.path.join(args.memory_path, "exemplars.json"), "r") as f:
        concrete_examples = json.load(f)
    if any(args.website in cex[0].get("specifier", "") for cex in concrete_examples):
        concrete_examples = [
            cex for cex in concrete_examples
            if all(tag in cex[0]["specifier"] for tag in (args.domain, args.subdomain, args.website))
        ]
    elif any(args.subdomain in cex[0].get("specifier", "") for cex in concrete_examples):
        concrete_examples = [
            cex for cex in concrete_examples
            if all(tag in cex[0]["specifier"] for tag in (args.domain, args.subdomain))
        ]

    # randomly pick up to `retrieve_top_k` concrete examples
    memory += random.sample(concrete_examples,
                            min(args.retrieve_top_k, len(concrete_examples)))
    # drop the "specifier" bookkeeping key before prompting
    memory = [[{k: v for k, v in m.items() if k != "specifier"} for m in e] for e in memory]
    return memory
41 |
42 |
def eval_sample(task_id, args, sample):
    """Evaluate the agent on one Mind2Web sample and write a per-task JSON log.

    Replays the sample's action sequence step by step: for each step it builds a
    chat prompt (system message + as many exemplars as fit + trajectory so far),
    queries the model, and scores the predicted action against the ground truth.

    Args:
        task_id: identifier used to name the output log file.
        args: namespace with model/prompt settings (`model`, `temperature`,
            `top_k_elements`, `previous_top_k_elements`, `log_dir`, `benchmark`,
            `website`, `suffix`, plus the fields `get_exemplars` reads).
        sample: one Mind2Web example with "actions", "action_reprs", and
            "confirmed_task" entries.
    """
    # initialize metrics
    element_acc, action_f1, step_success, success = [], [], [], []
    token_stats = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
    conversation = []  # running log of prompts, responses, and per-step scores
    episode_length = len(sample["action_reprs"])

    exemplars = get_exemplars(args)
    # print(exemplars)

    sys_message = [
        {
            "role": "system",
            "content": "You are a large language model trained to navigate the web. Output the next action and wait for the next observation. Here is the action space:\n1. `CLICK [id]`: Click on an HTML element with its id.\n2. `TYPE [id] [value]`: Type a string into the element with the id.\n3. `SELECT [id] [value]`: Select a value for an HTML element by its id.",
        }
    ]

    prev_actions, prev_obs = [], []
    previous_k = 5  # NOTE(review): unused in this function — possibly vestigial

    for s, act_repr in zip(sample["actions"], sample["action_reprs"]):
        _, target_act = get_target_obs_and_act(s)
        # keep only ground-truth candidates ranked within the top-k elements
        pos_candidates = [
            c for c in s["pos_candidates"] if c["rank"] < args.top_k_elements
        ]

        # get query, obs, act
        target_obs, _ = get_top_k_obs(s, args.previous_top_k_elements)
        # Continue next loop if the ground truth element is not in the cleaned html
        if len(pos_candidates) == 0:
            # score the step as failed but still extend the trajectory history
            element_acc.append(0)
            action_f1.append(0)
            step_success.append(0)
            prev_obs.append("Observation: `" + target_obs + "`")
            prev_actions.append("Action: `" + target_act + "` (" + act_repr + ")")
            conversation.append("The ground truth element is not in cleaned html")
            continue

        # construct query: interleave past observations and actions as chat turns
        query = []
        for o, a in zip(prev_obs, prev_actions):
            if len(query) == 0:
                # first user turn carries the task description as a header
                query.append({
                    "role": "user",
                    "content": f"Task: {sample['confirmed_task']}\nTrajectory:\n" + o,
                })
            else:
                query.append({"role": "user", "content": o})
            query.append({"role": "assistant", "content": a})

        obs, _ = get_top_k_obs(s, args.top_k_elements, use_raw=False)
        if len(query) == 0:
            query.append({
                "role": "user",
                "content": f"Task: {sample['confirmed_task']}\nTrajectory:\n"
                + "Observation: `" + obs + "`",
            })
        else:
            query.append({"role": "user", "content": "Observation: `" + obs + "`"})

        # record ground-truth obs/act for use in later steps' trajectory history
        prev_obs.append("Observation: `" + target_obs + "`")
        prev_actions.append("Action: `" + target_act + "` (" + act_repr + ")")

        # token limit: skip (and score as failed) steps whose prompt is too long
        total_num_tokens = num_tokens_from_messages(sys_message + query, args.model)
        if total_num_tokens > MAX_TOKENS[args.model]:
            logger.info(
                f"Too many tokens in acting ({total_num_tokens} / {MAX_TOKENS[args.model]}), skipping..."
            )
            element_acc.append(0)
            action_f1.append(0)
            step_success.append(0)
            conversation.append(
                {
                    "input": sys_message + query,
                    "output": f"FAILED DUE TO THE CONTEXT LIMIT: {total_num_tokens}",
                }
            )
            continue

        # message: greedily pack exemplars until the context budget is exhausted
        demo_message = []
        for e_id, e in enumerate(exemplars):
            total_num_tokens = num_tokens_from_messages(
                sys_message + demo_message + e + query, args.model
            )
            if total_num_tokens > MAX_TOKENS[args.model]:
                logger.info(
                    f"Using {e_id} / {len(exemplars)} exemplars due to context limit"
                )
                break
            else:
                demo_message.extend(e)

        message = sys_message + demo_message + query
        try:
            response, info = generate_response(
                messages=message,
                model=args.model,
                temperature=args.temperature,
                stop_tokens=["Task:", "obs:"],
            )
        except BadRequestError:
            # treat API rejections as an empty prediction with zero token usage
            response = ""
            info = {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
            }
        conversation.append({"input": message, "output": response, "token_stats": info})
        for k, v in info.items():
            token_stats[k] += v
        # parse predicted action (backtick-delimited) and the ground truth
        pred_act = extract_from_response(response, "`")
        pred_op, pred_id, pred_val = parse_act_str(pred_act)
        target_op, _, target_val = parse_act_str(target_act)

        # calculate metrics: element accuracy only credits the top-ranked candidate
        pos_ids = [c["backend_node_id"] for c in s["pos_candidates"]][:1]
        if pred_id in pos_ids:
            element_acc.append(1)
        else:
            element_acc.append(0)
        action_f1.append(
            calculate_f1(
                construct_act_str(pred_op, pred_val),
                construct_act_str(target_op, target_val),
            )
        )
        conversation.append({"pred_act": pred_act, "target_act": target_act})
        # step success requires an exact string match with the target action
        if pred_act == target_act:
            step_success.append(1)
        else:
            step_success.append(0)

    # check the last episode_length of step_success, if all 1, then success = 1
    if np.sum(step_success[-episode_length:]) == episode_length:
        success.append(1)
    else:
        success.append(0)

    # append the aggregate metrics as the final log record and write to disk
    conversation.append(
        {
            "element_acc": element_acc,
            "action_f1": action_f1,
            "step_success": step_success,
            "success": success,
        }
    )
    log_dir = Path(f"{args.log_dir}/{args.model}/{args.benchmark}/{args.website}/{args.suffix}")
    log_dir.mkdir(parents=True, exist_ok=True)
    with open(os.path.join(log_dir, f"{task_id}.json"), "w") as f:
        json.dump(conversation, f, indent=2)
--------------------------------------------------------------------------------
/mind2web/offline_induction.py:
--------------------------------------------------------------------------------
1 | """Induce Website-Specific Workflows Offline from Training Examples."""
2 |
3 | import os
4 | import json
5 | import pickle
6 | import argparse
7 | from utils.data import add_scores, format_examples, filter_workflows
8 |
9 | import openai
10 | openai.api_key = os.environ["OPENAI_API_KEY"]
11 | from openai import OpenAI
12 | client = OpenAI()
13 |
14 | # %% Data loading and processing
def get_data_dict(paths: list[str]) -> dict:
    """Create dict for examples in domain-subdomain-website hierarchy.
    Args:
        paths: list[str], list of data path strings
    Rets:
        data_dict: dict[str, dict], (domain, subdomain, website) dict
    """
    print("Start loading data files...")
    data_dict = {}
    for p in paths:
        print(p)
        with open(p, 'r') as f:  # close the file handle after loading
            data = json.load(f)
        for ex in data:
            domain, subdomain, website = ex["domain"], ex["subdomain"], ex["website"]
            # build the domain -> subdomain -> website hierarchy lazily
            website_examples = (
                data_dict.setdefault(domain, {})
                         .setdefault(subdomain, {})
                         .setdefault(website, [])
            )
            website_examples.append(ex)
    print(f"Finished loading {len(paths)} files!")
    return data_dict
38 |
def get_split(data_dict: dict) -> tuple[str, dict]:
    """Prompt the user to pick a key of `data_dict`; return (key, sub-dict).

    The return annotation is corrected: the function has always returned a
    (split, value) tuple, not a bare dict.
    """
    options = list(data_dict.keys())
    split = input(f"Select from {options} >> ")
    while split not in options:  # re-prompt until a valid option is entered
        split = input(f"Select from {options} >> ")
    return split, data_dict[split]
46 |
def get_examples(data_dict: dict, tags: tuple[str, str, str]) -> list[dict]:
    """Fetch the examples filed under the given (domain, subdomain, website) tags."""
    dom, sub, site = tags
    return data_dict[dom][sub][site]
51 |
52 |
53 | # %% Prompt and generate
def llm_generate(tags: tuple[str, str, str], examples: list[dict], args, verbose: bool = False):
    """Call the chat model to induce workflows from formatted examples.

    Args:
        tags: (domain, subdomain, website) identifying the split.
        examples: examples to summarize into workflows.
        args: namespace carrying prompt pieces (`INSTRUCTION`, `ONE_SHOT`,
            `prefix`, `suffix`) and generation settings (`model_name`,
            `temperature`).
        verbose: if True, print the prompt and the model response.
    Rets:
        str, the raw model response text.
    """
    prompt = "Website: " + ','.join(tags) + '\n'  # plain literal; no f-string needed
    prompt += format_examples(examples, args.prefix, args.suffix)
    prompt = '\n\n'.join([args.INSTRUCTION, args.ONE_SHOT, prompt])
    if verbose: print("Prompt:\n", prompt, '\n\n')
    response = client.chat.completions.create(
        model=args.model_name,
        messages=[{"role": "user", "content": prompt}],
        temperature=args.temperature,
        max_tokens=1024,
    )
    response = response.choices[0].message.content
    if verbose: print(response)
    return response
69 |
70 | # %% Save outputs
def save_to_txt(text: str, args):
    """Write `text` to `{output_dir}/{website}[_{output_suffix}].txt`.

    The website name is lowercased in both branches; previously it was only
    lowercased when an output suffix was given, which produced inconsistent
    filenames for mixed-case website names.
    """
    website = args.website.lower()
    if args.output_suffix is not None:
        output_name = f"{website}_{args.output_suffix}.txt"
    else:
        output_name = f"{website}.txt"
    output_path = os.path.join(args.output_dir, output_name)
    with open(output_path, 'w') as fw:
        fw.write(text)
78 |
79 |
80 | # %% Main pipeline
def main():
    """End-to-end offline induction: load data, induce workflows per website, save.

    Relies on the module-global `args` namespace populated in `__main__`;
    prompt texts are attached to `args` so nested helpers can reach them.
    """
    # load data into dict keyed by domain -> subdomain -> website
    data_paths = [os.path.join(args.data_dir, f) for f in os.listdir(args.data_dir)]
    data_dict = get_data_dict(paths=data_paths)

    # load candidate scores and ranks (precomputed element-ranking pickle)
    with open(os.path.join("data", "scores_all_data.pkl"), "rb") as f:
        candidate_results = pickle.load(f)

    # load prompt contexts and stash them on `args` for llm_generate
    args.INSTRUCTION = open(args.instruction_path, 'r').read()
    args.ONE_SHOT = open(args.one_shot_path, 'r').read()

    def single_website_loop(tags: tuple[str, str, str]):
        """Pipeline to induce, filter, and save workflows on a single website."""
        examples = get_examples(data_dict, tags=tags)
        print(f"Split {tags} with #{len(examples)} examples")
        add_scores(examples, candidate_results)
        response = llm_generate(tags, examples, args)
        workflows = filter_workflows(response, website=tags[-1])
        save_to_txt(workflows, args)

    if args.mode == "auto":
        # non-interactive: tags were validated in `__main__`
        single_website_loop(args.tags)
    elif args.mode == "input":
        stop = False
        while not stop:
            # select split interactively, narrowing domain -> subdomain -> website
            args.domain, domain_dict = get_split(data_dict)
            args.subdomain, subdomain_dict = get_split(domain_dict)
            args.website, examples = get_split(subdomain_dict)

            # generate workflows for the chosen website
            tags = [args.domain, args.subdomain, args.website]
            single_website_loop(tags)

            if input("Stop? [y/n] ").strip() == 'y':
                stop = True
    else:
        raise ValueError("Please enter a valid `mode` ('input' or 'auto')")
121 |
122 |
if __name__ == "__main__":
    # CLI entry point: parse arguments, validate 'auto' mode tags, then run.
    parser = argparse.ArgumentParser()
    # data / output locations
    parser.add_argument("--data_dir", type=str, default="data/train")
    parser.add_argument("--output_dir", type=str, default="workflow")
    parser.add_argument("--output_suffix", type=str, default=None)
    parser.add_argument("--verbose", action="store_true",
                        help="Whether to print prompt and response.")

    # mode
    parser.add_argument("--mode", type=str, default="input",
                        choices=["input", "auto"])
    parser.add_argument("--domain", type=str, default=None,
                        help="Specify in 'auto' mode.")
    parser.add_argument("--subdomain", type=str, default=None,
                        help="Specify in 'auto' mode.")
    parser.add_argument("--website", type=str, default=None,
                        help="Specify in 'auto' mode.")
    # model
    parser.add_argument("--model_name", type=str, default="gpt-4o")
    parser.add_argument("--temperature", type=float, default=0.0)

    # prompt
    parser.add_argument("--instruction_path", type=str, default="prompt/instruction_action.txt")
    parser.add_argument("--one_shot_path", type=str, default="prompt/one_shot_action.txt")
    parser.add_argument("--prefix", type=str, default=None)
    parser.add_argument("--suffix", type=str, default="# Summary Workflows")

    args = parser.parse_args()

    # sanity check: 'auto' mode needs all three tags supplied up front
    if args.mode == "auto":
        args.tags = [args.domain, args.subdomain, args.website]
        assert not any([tag is None for tag in args.tags])

    main()
158 |
--------------------------------------------------------------------------------
/mind2web/online_induction.py:
--------------------------------------------------------------------------------
1 | """Induce Workflows from Past Agent Experiences."""
2 |
3 | import os
4 | import json
5 | import argparse
6 | from utils.data import load_json, format_examples, filter_workflows
7 |
8 | import openai
9 | openai.api_key = os.environ["OPENAI_API_KEY"]
10 | from openai import OpenAI
11 | client = OpenAI()
12 |
13 |
14 | def is_io_dict(item: dict | str) -> bool:
15 | if isinstance(item, dict) and ("input" in item) and ("output" in item): return True
16 | return False
17 |
def get_trajectory(path: str) -> list[dict]:
    """Load one result file and convert it into a list of {"env", "action"} steps.

    Items that are not input/output dicts (e.g. trailing score records) are
    skipped via `is_io_dict`.
    """
    # bug fix: close the file handle instead of leaking it
    with open(path, 'r') as f:
        result = json.load(f)
    trajectory = []
    for item in result:
        if not is_io_dict(item):
            continue
        trajectory.append({
            # the last input message holds the environment observation text
            "env": "# " + item["input"][-1]["content"],
            "action": item["output"],
        })
    return trajectory
29 |
30 |
def main():
    """Induce workflows from saved trajectories for one website and write them out."""
    samples = load_json(args.data_dir, args.benchmark)
    print(f"Loaded #{len(samples)} test examples")
    samples = [s for s in samples if s["website"] == args.website]
    print(f"Filtering down to #{len(samples)} examples on website [{args.website}]")

    # load model predictions and format examples
    # NOTE(review): os.listdir order is arbitrary; this assumes result files
    # align one-to-one with `samples` order — verify against the result writer.
    result_files = [os.path.join(args.results_dir, f) for f in os.listdir(args.results_dir)]
    result_list = [get_trajectory(rf) for rf in result_files]
    examples = []
    for trajectory, sample in zip(result_list, samples):
        examples.append({
            "confirmed_task": sample["confirmed_task"],
            "action_reprs": [step["env"] + '\n' + step["action"] for step in trajectory],
        })
    prompt = format_examples(examples, args.prefix, args.suffix)

    # transform to workflows
    # bug fix: close the prompt-file handles instead of leaking them
    with open(args.instruction_path, 'r') as f:
        INSTRUCTION = f.read()
    with open(args.one_shot_path, 'r') as f:
        ONE_SHOT = f.read()
    domain, subdomain, website = samples[0]["domain"], samples[0]["subdomain"], samples[0]["website"]
    prompt = '\n\n'.join([INSTRUCTION, ONE_SHOT, f"Website: {domain}, {subdomain}, {website}\n{prompt}"])
    response = client.chat.completions.create(
        model=args.model_name,
        messages=[{"role": "user", "content": prompt}],
        temperature=args.temperature,
    ).choices[0].message.content
    response = filter_workflows(response, args.website)

    # save induced workflows to file
    with open(args.output_path, 'w') as fw:
        fw.write(response)
63 |
64 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # data
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--benchmark", type=str, default="test_task",
                        choices=["test_task", "test_website", "test_domain", "train"])
    parser.add_argument("--website", type=str, required=True)
    parser.add_argument("--results_dir", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)

    # model
    parser.add_argument("--model_name", type=str, default="gpt-4o")
    # bug fix: temperature must parse as a float (was type=str, so any CLI
    # value reached the OpenAI API as a string)
    parser.add_argument("--temperature", type=float, default=0.0)
    # prompt
    parser.add_argument("--instruction_path", type=str, default="prompt/instruction_action.txt")
    parser.add_argument("--one_shot_path", type=str, default="prompt/one_shot_action.txt")
    parser.add_argument("--prefix", type=str, default=None)
    parser.add_argument("--suffix", type=str, default="# Summary Workflows")

    args = parser.parse_args()

    main()
86 |
--------------------------------------------------------------------------------
/mind2web/pipeline.py:
--------------------------------------------------------------------------------
1 | """Online Induction and Workflow Utilization Pipeline."""
2 |
3 | import argparse
4 | import subprocess
5 | from utils.data import load_json
6 |
def offline():
    """Offline setup: induce workflows from training data, then run test inference."""
    # workflow induction
    process = subprocess.Popen([
        'python', 'offline_induction.py',
        '--mode', 'auto', '--website', args.website,
        '--domain', args.domain, '--subdomain', args.subdomain,
        # consistency fix: spell the destination flag out instead of relying on
        # argparse prefix abbreviation ('--model' only matched '--model_name')
        '--model_name', args.model, '--output_dir', "workflow",
        '--instruction_path', args.instruction_path,
        '--one_shot_path', args.one_shot_path,
    ])
    process.wait()

    # test inference with the induced workflows
    process = subprocess.Popen([
        'python', 'run_mind2web.py',
        '--website', args.website,
        '--workflow_path', f"workflow/{args.website}.txt"
    ])
    process.wait()
26 |
27 |
def online():
    """Streaming setup: alternate chunked test inference with workflow induction.

    Processes examples in chunks of `args.induce_steps`; after each chunk with
    examples still remaining, re-induces workflows from the results so far.
    """
    # load all examples for streaming
    samples = load_json(args.data_dir, args.benchmark)
    print(f"Loaded #{len(samples)} test examples")
    if args.website is not None:
        samples = [s for s in samples if s["website"] == args.website]
        print(f"Filtering down to #{len(samples)} examples on website [{args.website}]")
    n = len(samples)

    for i in range(0, n, args.induce_steps):
        j = min(n, i + args.induce_steps)
        print(f"Running inference on {i}-{j} th example..")

        cmd = [
            'python', 'run_mind2web.py',
            '--benchmark', args.benchmark,
            '--workflow_path', args.workflow_path,
            '--website', args.website,
            '--start_idx', f'{i}', '--end_idx', f'{j}',
        ]
        # bug fix: subprocess arguments must all be strings — only forward the
        # optional tags when they were actually provided (defaults are None)
        if args.domain is not None:
            cmd += ['--domain', args.domain]
        if args.subdomain is not None:
            cmd += ['--subdomain', args.subdomain]
        process = subprocess.Popen(cmd)
        process.wait()
        print(f"Finished inference on {i}-{j} th example!\n")

        # bug fix: induce whenever examples remain; the previous `(j + 1) < n`
        # skipped induction when exactly one example was left
        if j < n:
            process = subprocess.Popen([
                'python', 'online_induction.py',
                '--benchmark', args.benchmark,
                '--website', args.website,
                '--results_dir', args.results_dir,
                '--output_path', args.workflow_path,
            ])
            process.wait()
            print(f"Finished workflow induction with 0-{j} th examples!\n")
62 |
63 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # examples
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--benchmark", type=str, default="test_task",
                        choices=["test_task", "test_website", "test_domain", "train"])
    parser.add_argument("--website", type=str, required=True)
    parser.add_argument("--domain", type=str, default=None)
    parser.add_argument("--subdomain", type=str, default=None)

    # results and workflows
    parser.add_argument("--results_dir", type=str, default=None)
    parser.add_argument("--workflow_path", type=str, default=None)

    # prompt
    parser.add_argument("--instruction_path", type=str, default="prompt/instruction_action.txt")
    parser.add_argument("--one_shot_path", type=str, default="prompt/one_shot_action.txt")
    parser.add_argument("--prefix", type=str, default=None)
    parser.add_argument("--suffix", type=str, default="# Summary Workflows")

    # gpt
    parser.add_argument("--model", type=str, default="gpt-4o")
    # bug fix: temperature must parse as a float (was type=str)
    parser.add_argument("--temperature", type=float, default=0.0)

    # induction frequency
    parser.add_argument("--induce_steps", type=int, default=1)

    # setup
    parser.add_argument("--setup", type=str, required=True,
                        choices=["online", "offline"])

    args = parser.parse_args()

    # each setup has its own required arguments
    if args.setup == "online":
        assert (args.results_dir is not None) and (args.workflow_path is not None)
        online()
    elif args.setup == "offline":
        assert (args.domain is not None) and (args.subdomain is not None)
        offline()
103 |
--------------------------------------------------------------------------------
/mind2web/prompt/instruction_abstract.txt:
--------------------------------------------------------------------------------
Given a list of web navigation tasks, your task is to extract the common workflows to solve these tasks.
2 | Each given task contains a natural language instruction, and a series of actions to solve the task. You need to find the repetitive subset of actions across multiple tasks, and extract each of them out as a workflow.
3 | Each workflow should be a commonly-reused sub-routine of the tasks. Do not generate similar or overlapping workflows. Each workflow should have at least two steps. Represent the non-fixed elements (input text, button strings) with descriptive variable names as shown in the example.
--------------------------------------------------------------------------------
/mind2web/prompt/instruction_action.txt:
--------------------------------------------------------------------------------
Given a list of web navigation tasks, your task is to extract the common workflows to solve these tasks.
2 | Each given task contains a natural language instruction, and a series of actions to solve the task. You need to find the repetitive subset of actions across multiple tasks, and extract each of them out as a workflow.
3 | Each workflow should be a commonly-reused sub-routine of the tasks. Do not generate similar or overlapping workflows. Each workflow should have at least two steps.
--------------------------------------------------------------------------------
/mind2web/prompt/one_shot_abstract.txt:
--------------------------------------------------------------------------------
1 | Website: Travel, Airlines, delta
2 | ## Query 1: Find flights from Seattle to New York on June 5th and only show those that can be purchased with miles.
3 | Actions:
4 | [link] From Departure Airport or City Your Origin -> CLICK
5 | [textbox] Origin City or Airport -> TYPE: Seattle
6 | [link] SEA Seattle, WA -> CLICK
7 | [link] To Destination Airport or City Your Destination -> CLICK
8 | [textbox] Destination City or Airport -> TYPE: New York
9 | [link] NYC New York City Area Airports, NY -> CLICK
10 | [combobox] Trip Type:, changes will reload the page -> CLICK
11 | [option] One Way -> CLICK
12 | [button] Depart and Return Calendar Use enter to open, es... -> CLICK
13 | [link] Next -> CLICK
14 | [link] 5 June 2023, Monday -> CLICK
15 | [button] done -> CLICK
16 | [label] Shop with Miles -> CLICK
17 | [button] SUBMIT -> CLICK
18 |
19 | ## Query 2: Find my trip with confirmation number SFTBAO including first and last name Joe Lukeman
20 | Actions:
21 | [tab] MY TRIPS -> CLICK
22 | [combobox] Find Your Trip By -> CLICK
23 | [option] Confirmation Number -> CLICK
24 | [input] -> TYPE: SFTBAO
25 | [input] -> TYPE: Joe
26 | [input] -> TYPE: Lukeman
27 | [button] SUBMIT -> CLICK
28 |
29 | ## Query 3: Find the status of March 25 flights from New York airports to Columbus in Ohio.
30 | Actions:
31 | [tab] FLIGHT STATUS -> CLICK
32 | [button] Search by date required selected as 19 March 202... -> CLICK
33 | [link] 25 March 2023, Saturday -> CLICK
34 | [button] done -> CLICK
35 | [link] Depart City required From -> CLICK
36 | [textbox] Origin City or Airport -> TYPE: New York
37 | [link] NYC New York City Area Airports, NY -> CLICK
38 | [link] Arrival City required To -> CLICK
39 | [textbox] Destination City or Airport -> TYPE: Columbus
40 | [li] CMH -> CLICK
41 | [button] SUBMIT -> CLICK
42 |
43 | ## Query 4: Check all available one way flights for a single passenger from Manhattan to Philadelphia on May 23rd in first class.
44 | Actions:
45 | [link] From Departure Airport or City Your Origin -> CLICK
46 | [textbox] Origin City or Airport -> TYPE: Manhattan
47 | [link] MHK Manhattan Regl, USA -> CLICK
48 | [link] To Destination Airport or City Your Destination -> CLICK
49 | [textbox] Destination City or Airport -> TYPE: Philadelphia
50 | [link] PHL Philadelphia, PA -> CLICK
51 | [combobox] Trip Type:, changes will reload the page -> CLICK
52 | [option] One Way -> CLICK
53 | [button] Depart and Return Calendar Use enter to open, es... -> CLICK
54 | [link] 23 March 2023, Thursday -> CLICK
55 | [button] done -> CLICK
56 | [link] Advanced Search -> CLICK
57 | [combobox] Best Fares For -> CLICK
58 | [option] First Class -> CLICK
59 | [button] SUBMIT -> CLICK
60 |
61 | Extracted Workflows:
62 | # enter_flight_locations
63 | Given that you are on the Delta flight booking page, this workflow enters the departure and destination city/airport for your flight.
64 | [link] {link to enter departure city} -> CLICK
65 | [textbox] {textbox to input departure city} -> TYPE: {your-origin-city}
66 | [link] {best-popup-option} -> CLICK
67 | [link] {link to enter destination city} -> CLICK
68 | [textbox] {textbox to enter destination city} -> TYPE: {your-destination-city}
69 | [link] {best-popup-option} -> CLICK
70 |
71 | # select_oneway_trip
72 | Given that you are on the Delta flight booking page, this workflow changes the flight to be one-way.
73 | [combobox] {option to select trip type} -> CLICK
74 | [option] One Way -> CLICK
75 |
76 | # select_date_for_travel
77 | Given that you are on the Delta flight booking page, this workflow selects the travel date.
78 | [button] {calendar to select flight dates} -> CLICK
79 | [link] {travel-date} -> CLICK
80 | [button] done -> CLICK
81 |
82 | # find_trip
83 | Given that you are on the Delta flight searching page, this workflow finds a trip with the confirmation number and passenger name.
84 | [tab] MY TRIPS -> CLICK
85 | [combobox] {button to instantiate search} -> CLICK
86 | [option] Confirmation Number -> CLICK
87 | [input] -> TYPE: {confirmation-number}
88 | [input] -> TYPE: {passenger-name}
89 | [button] SUBMIT -> CLICK
--------------------------------------------------------------------------------
/mind2web/prompt/one_shot_action.txt:
--------------------------------------------------------------------------------
1 | Website: Travel, Airlines, delta
2 | ## Query 1: Find flights from Seattle to New York on June 5th and only show those that can be purchased with miles.
3 | Actions:
4 | [link] From Departure Airport or City Your Origin -> CLICK
5 | [textbox] Origin City or Airport -> TYPE: Seattle
6 | [link] SEA Seattle, WA -> CLICK
7 | [link] To Destination Airport or City Your Destination -> CLICK
8 | [textbox] Destination City or Airport -> TYPE: New York
9 | [link] NYC New York City Area Airports, NY -> CLICK
10 | [combobox] Trip Type:, changes will reload the page -> CLICK
11 | [option] One Way -> CLICK
12 | [button] Depart and Return Calendar Use enter to open, es... -> CLICK
13 | [link] Next -> CLICK
14 | [link] 5 June 2023, Monday -> CLICK
15 | [button] done -> CLICK
16 | [label] Shop with Miles -> CLICK
17 | [button] SUBMIT -> CLICK
18 |
19 | ## Query 2: Find my trip with confirmation number SFTBAO including first and last name Joe Lukeman
20 | Actions:
21 | [tab] MY TRIPS -> CLICK
22 | [combobox] Find Your Trip By -> CLICK
23 | [option] Confirmation Number -> CLICK
24 | [input] -> TYPE: SFTBAO
25 | [input] -> TYPE: Joe
26 | [input] -> TYPE: Lukeman
27 | [button] SUBMIT -> CLICK
28 |
29 | ## Query 3: Find the status of March 25 flights from New York airports to Columbus in Ohio.
30 | Actions:
31 | [tab] FLIGHT STATUS -> CLICK
32 | [button] Search by date required selected as 19 March 202... -> CLICK
33 | [link] 25 March 2023, Saturday -> CLICK
34 | [button] done -> CLICK
35 | [link] Depart City required From -> CLICK
36 | [textbox] Origin City or Airport -> TYPE: New York
37 | [link] NYC New York City Area Airports, NY -> CLICK
38 | [link] Arrival City required To -> CLICK
39 | [textbox] Destination City or Airport -> TYPE: Columbus
40 | [li] CMH -> CLICK
41 | [button] SUBMIT -> CLICK
42 |
43 | ## Query 4: Check all available one way flights for a single passenger from Manhattan to Philadelphia on May 23rd in first class.
44 | Actions:
45 | [link] From Departure Airport or City Your Origin -> CLICK
46 | [textbox] Origin City or Airport -> TYPE: Manhattan
47 | [link] MHK Manhattan Regl, USA -> CLICK
48 | [link] To Destination Airport or City Your Destination -> CLICK
49 | [textbox] Destination City or Airport -> TYPE: Philadelphia
50 | [link] PHL Philadelphia, PA -> CLICK
51 | [combobox] Trip Type:, changes will reload the page -> CLICK
52 | [option] One Way -> CLICK
53 | [button] Depart and Return Calendar Use enter to open, es... -> CLICK
54 | [link] 23 March 2023, Thursday -> CLICK
55 | [button] done -> CLICK
56 | [link] Advanced Search -> CLICK
57 | [combobox] Best Fares For -> CLICK
58 | [option] First Class -> CLICK
59 | [button] SUBMIT -> CLICK
60 |
61 |
62 | Summary Workflows:
63 | ## enter_flight_locations
64 | Given that you are on the Delta flight booking page, this workflow enters the departure and destination city/airport for your flight.
65 | [link] From Departure Airport or City Your Origin -> CLICK
66 | [textbox] Origin City or Airport -> TYPE: {your-origin-city}
67 | [link] {best-popup-option} -> CLICK
68 | [link] To Destination Airport or City Your Destination -> CLICK
69 | [textbox] Destination City or Airport -> TYPE: {your-destination-city}
70 | [link] {best-popup-option} -> CLICK
71 |
72 | ## select_oneway_trip
73 | Given that you are on the Delta flight booking page, this workflow changes the flight to be one-way.
74 | [combobox] Trip Type:, changes will reload the page -> CLICK
75 | [option] One Way -> CLICK
76 |
77 | ## select_date_for_travel
78 | Given that you are on the Delta flight booking page, this workflow selects the travel date.
79 | [button] Depart and Return Calendar Use enter to open, es... -> CLICK
80 | [link] {travel-date} -> CLICK
81 | [button] done -> CLICK
82 |
83 | ## find_trip
84 | Given that you are on the Delta flight searching page, this workflow finds a trip with the confirmation number and passenger name.
85 | [tab] MY TRIPS -> CLICK
86 | [combobox] Find Your Trip By -> CLICK
87 | [option] Confirmation Number -> CLICK
88 | [input] -> TYPE: {confirmation-number}
89 | [input] -> TYPE: {passenger-name}
90 | [button] SUBMIT -> CLICK
--------------------------------------------------------------------------------
/mind2web/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | numpy
3 | openai
4 | backoff==2.2.1
5 | tiktoken
6 | lxml==4.9.3
7 | matplotlib
--------------------------------------------------------------------------------
/mind2web/results/calc_score.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import matplotlib.pyplot as plt
5 |
def get_average(score_list: list[float], percentage: bool = False) -> float:
    """Mean of `score_list`, scaled to the 0-100 range when `percentage` is True."""
    mean = sum(score_list) / len(score_list)
    if percentage:
        return mean * 100
    return mean
9 |
10 |
def main():
    """Aggregate per-task metrics from result files and plot the
    accumulative step success rate over the task stream."""
    files = os.listdir(args.results_dir)
    # NOTE(review): os.listdir order is arbitrary, so the accumulative curve
    # below depends on filesystem ordering — confirm whether sorting is needed.
    file_paths = [os.path.join(args.results_dir, f) for f in files]
    ele_acc, act_f1, step_sr, sr = [], [], [], []
    for fp in file_paths:
        # bug fix: close the file handle instead of leaking it
        with open(fp, 'r') as f:
            res = json.load(f)[-1]  # last record holds the task's final metrics
        ele_acc.append(get_average(res["element_acc"]))
        act_f1.append(get_average(res["action_f1"]))
        step_sr.append(get_average(res["step_success"]))
        sr.append(get_average(res["success"]))

    print(f"Element Acc: {get_average(ele_acc, True):5.1f}")
    print(f"Action F1 : {get_average(act_f1, True):5.1f}")
    print(f"Step SR : {get_average(step_sr, True):5.1f}")
    print(f"SR : {get_average(sr, True):5.1f}")

    # accumulative step success rate: mean over the first k tasks, for each k
    n = len(step_sr)
    x = [i + 1 for i in range(n)]
    asr = [get_average(step_sr[:i + 1]) for i in range(n)]
    plt.plot(x, asr)

    if args.viz_path is not None:
        plt.savefig(args.viz_path)
43 |
44 |
if __name__ == "__main__":
    # command-line interface: where to read results and (optionally) save the plot
    cli = argparse.ArgumentParser()
    cli.add_argument("--results_dir", type=str, required=True)
    cli.add_argument("--viz_path", type=str, default=None)
    args = cli.parse_args()

    main()
52 |
--------------------------------------------------------------------------------
/mind2web/run_mind2web.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from tqdm import tqdm
4 | from memory import eval_sample
5 | from utils.data import load_json, add_scores
6 |
import logging
# Module-wide logger named "atm": emits timestamped INFO-level messages to the
# stream handler; other modules can retrieve it via logging.getLogger("atm").
logger = logging.getLogger("atm")
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
logger.addHandler(handler)
13 |
14 |
def main():
    """Run agent inference over the filtered website examples, one sample at a time."""
    examples = load_json(args.data_dir, args.benchmark)
    examples = [s for s in examples if s["website"] == args.website]
    print(f"Filtering down to #{len(examples)} examples on website [{args.website}]")
    examples = add_scores(examples)  # add prediction scores and ranks to elements

    if args.end_idx is None:
        args.end_idx = len(examples)
    for i in tqdm(range(args.start_idx, args.end_idx)):
        if args.mode == "memory":
            eval_sample(i, args, examples[i])
        elif args.mode == "action":
            raise NotImplementedError
        else:
            # bug fix: `args.workflow_format` does not exist, so this error path
            # raised AttributeError instead; report the offending mode itself
            raise ValueError(f"Unsupported mode: {args.mode}")
30 |
31 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # data
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--benchmark", type=str, default="test_task",
                        choices=["test_task", "test_website", "test_domain", "train"])
    parser.add_argument("--memory_path", type=str, default="data/memory")
    parser.add_argument("--log_dir", type=str, default="results")

    # model
    parser.add_argument("--model", type=str, default="gpt-4")
    parser.add_argument("--temperature", type=float, default=0.0)

    # env context
    parser.add_argument("--previous_top_k_elements", type=int, default=3)
    parser.add_argument("--top_k_elements", type=int, default=5)
    parser.add_argument("--retrieve_top_k", type=int, default=1)

    # workflow
    parser.add_argument("--website", type=str, required=True)
    parser.add_argument("--domain", type=str, default=None)
    parser.add_argument("--subdomain", type=str, default=None)
    parser.add_argument("--workflow_path", type=str, required=True)
    parser.add_argument("--suffix", type=str, default="workflow")

    # ablation
    parser.add_argument("--mode", type=str, default="memory", choices=["memory", "action"])
    parser.add_argument("--start_idx", type=int, default=0, help="Select example index.")
    parser.add_argument("--end_idx", type=int, default=None, help="Select example index.")

    args = parser.parse_args()

    # sanity checks: the workflow file must exist (start empty if not), and
    # retrieval is expected to fetch a single workflow set
    if not os.path.exists(args.workflow_path):
        with open(args.workflow_path, 'w'):
            pass
    if args.retrieve_top_k != 1:
        print(f"Suggest set `retrieve_top_k` to 1, currently as {args.retrieve_top_k}")

    main()
68 |
--------------------------------------------------------------------------------
/mind2web/utils/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 |
5 | # %% load data
def load_json(data_dir, folder_name):
    """Concatenate the contents of every `*.json` file under
    `data_dir/folder_name`, ordered by the integer after the last underscore."""
    folder_path = os.path.join(data_dir, folder_name)
    print(f"Data path: {folder_path}")

    def _order_key(path):
        # e.g. ".../train_12.json" -> 12
        return int(path.split("_")[-1].split(".")[0])

    data_paths = sorted(
        (os.path.join(folder_path, name)
         for name in os.listdir(folder_path)
         if name.endswith(".json")),
        key=_order_key,
    )

    # Construct trajectory dataset
    samples = []
    for data_path in data_paths:
        with open(data_path, "r") as f:
            samples.extend(json.load(f))
    print("# of samples:", len(samples))

    return samples
24 |
25 |
def add_scores(
    examples: list[dict], candidate_results: dict = None,
    score_path: str = "data/scores_all_data.pkl"
):
    """Attach a prediction `score` and `rank` to every pos/neg candidate element.

    Falls back to the pre-computed ranker output pickled at `score_path` when
    `candidate_results` is not supplied. Mutates `examples` in place and
    returns it.
    """
    if candidate_results is None:
        with open(score_path, "rb") as f:
            candidate_results = pickle.load(f)

    scores, ranks = candidate_results["scores"], candidate_results["ranks"]
    for sample in examples:
        # zipping with action_reprs also truncates to the shorter of the two lists
        for action, _repr in zip(sample["actions"], sample["action_reprs"]):
            sample_id = f"{sample['annotation_id']}_{action['action_uid']}"
            for candidate in action["pos_candidates"] + action["neg_candidates"]:
                node_id = candidate["backend_node_id"]
                candidate["score"] = scores[sample_id][node_id]
                candidate["rank"] = ranks[sample_id][node_id]

    return examples
45 |
46 |
47 | # %% workflow induction
def format_examples(examples: list[dict], prefix: str = None, suffix: str = None) -> str:
    """Render examples as prompt text: each task's query followed by its action
    representations, with optional prefix/suffix wrapping."""
    chunks = []
    for idx, example in enumerate(examples, start=1):
        chunks.append(f"Query #{idx}: {example['confirmed_task']}")
        chunks.append("Actions and Environments:")
        chunks.extend(example["action_reprs"])
        chunks.append("")
    prompt = '\n'.join(chunks)
    if prefix is not None:
        prompt = prefix + '\n' + prompt
    if suffix is not None:
        prompt = prompt + '\n\n' + suffix
    return prompt
61 |
62 |
63 |
64 | # %% model generation
def is_website_header(block: str, website: str) -> bool:
    """True iff `block` is a single header line of the form '# ... <website>'."""
    lines = block.strip().split('\n')
    if len(lines) != 1:
        return False
    header = lines[0].strip()
    return header.startswith("#") and header.lower().endswith(website)
72 |
def filter_workflows(text: str, website: str) -> str:
    """Strip echoed one-shot ("delta") content from a model response, keeping
    only the workflow blocks that belong to `website`."""
    blocks = text.split('\n\n')

    # drop everything up to and including the header for the target website
    for idx, block in enumerate(blocks):
        if is_website_header(block, website):
            blocks = blocks[idx + 1:]
            break

    # cut at any regurgitated header for the one-shot website ("delta")
    for idx, block in enumerate(blocks):
        if is_website_header(block, "delta"):
            blocks = blocks[:idx]
            break

    # finally, discard blocks that still mention the one-shot website
    kept = [block for block in blocks if "delta" not in block.lower()]
    return '\n\n'.join(kept)
--------------------------------------------------------------------------------
/mind2web/utils/env.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import re
3 | import json
4 | import os
5 | import string
6 | import ast
7 | from lxml import etree
8 |
9 |
def get_target_obs(dom_tree, target_element_ids):
    """Render `dom_tree` pruned to `target_element_ids` as an HTML-like string.

    Relies on `prune_tree` and `get_tree_repr` defined elsewhere in this module;
    the id_mapping is left empty so original backend node ids are preserved.
    """
    pruned_tree = prune_tree(dom_tree, target_element_ids)
    tree_repr, _ = get_tree_repr(pruned_tree, id_mapping={}, keep_html_brackets=True)

    return tree_repr
15 |
16 |
def get_target_act(example, target_element_id):
    """Format the ground-truth action as 'OP [element_id]', appending
    ' [value]' for every operation except CLICK."""
    operation = example["operation"]
    action = f"{operation['op']} [{target_element_id}]"
    if operation["op"] != "CLICK":
        action += f" [{operation['value']}]"
    return action
25 |
26 |
def parse_act_str(act_str):
    """Parse 'OP [element_id] [value]' into an (op, element_id, value) triple.

    Either the op keyword or the value may be absent (returned as None);
    returns (None, None, None) when nothing matches at all.
    """
    pattern = re.compile(r"(?:^|\s)(CLICK|SELECT|TYPE)?\s?\[(.+?)\](?:\s\[(.+?)\])?")
    match = pattern.search(act_str)
    if not match:
        return None, None, None
    return match.group(1), match.group(2), match.group(3)
40 |
41 |
42 |
def find_node_id_by_text(example, text):
    """Return the backend_node_id of the first element in the example's
    cleaned HTML whose aria_label equals `text` (IndexError when absent)."""
    root = etree.fromstring(example["cleaned_html"])
    matches = root.xpath(f'//*[@aria_label="{text}"]')
    return matches[0].attrib["backend_node_id"]
47 |
48 |
def parse_act_str_workflow(act_str, example):
    """Yield (op, element_id, value) triples parsed from `act_str`.

    Two formats are supported:
      * a plain action string 'OP [id] [value]' — one triple via the regex;
      * a workflow call 'workflow_name(arg, ...)' — one triple per expanded
        step, grounding each step's element text to a backend_node_id found
        in `example`'s cleaned HTML.
    Yields (None, None, None) whenever parsing or grounding fails.
    """
    # Compile the regular expression pattern
    pattern = re.compile(r"(?:^|\s)(CLICK|SELECT|TYPE)?\s?\[(.+?)\](?:\s\[(.+?)\])?")
    # Search for the pattern in the string
    match = pattern.search(act_str)
    if match:
        action_op = match.group(1)  # None if op keyword absent
        target_element_id = match.group(2)
        action_value = match.group(3)  # None if not present
        yield action_op, target_element_id, action_value
    elif ('(' in act_str) and (')' in act_str):
        print("Action Str: ", act_str)
        from data.workflow.code import WORKFLOW_DICT
        s, e = act_str.index('('), act_str.index(')')
        wkfl = act_str[: s].strip()
        print("Workflow: ", wkfl)
        if wkfl not in WORKFLOW_DICT:
            # bug fix: previously fell through after the yield and crashed
            # with a KeyError on WORKFLOW_DICT[wkfl] below
            yield None, None, None
            return
        args = [arg.value for arg in ast.parse(act_str).body[0].value.args]
        if args == []:
            # bug fix: slice from s+1 so '(' is not glued onto the first argument
            args = act_str[s + 1: e].split(',')
            for i in range(len(args)):
                if '=' in args[i]:
                    args[i] = args[i].split('=')[1]
        print("Args: ", args)
        steps = WORKFLOW_DICT[wkfl](*args)
        for step in steps:
            print("Org Step: ", step)
            element_name, op_name = step.split('->')
            # keep only the human-readable element text after the '[tag]' prefix
            element_name = element_name[element_name.index(']') + 1:].strip()
            if ':' in op_name:
                op_name, arg = [item.strip() for item in op_name.split(':')]
            else:
                arg = None
            try:
                element_id = find_node_id_by_text(example, element_name)
                step = f"{op_name} [{element_id}]"
                if arg is not None:
                    step += f" [{arg}]"
                step = step.replace('"', '')
                print("New Step: ", step)
                yield parse_act_str(step)
            except Exception:
                # grounding failed (element not found / malformed step);
                # narrowed from a bare `except:` which also swallowed
                # KeyboardInterrupt and SystemExit
                yield None, None, None

    else:
        yield None, None, None
97 |
98 |
def construct_act_str(op, val):
    """Join an op and a value into one display string, tolerating missing parts."""
    if op is None:
        return " " if val is None else " " + val
    if op == "CLICK" or val is None:
        return op + " "
    return f"{op} {val}"
107 |
108 |
def get_target_obs_and_act(example):
    """Return the (observation, action) pair for the ground-truth step.

    When the positive candidate is absent from `cleaned_html`, the target
    element is recovered from `raw_html` instead: the pruned representation is
    scanned for the element's own open/close tag pair (by counting same-named
    tags) so only that element's markup is kept as the observation.
    """
    if len(example["pos_candidates"]) == 0:
        # Simplify the raw_html if pos_candidates is empty (not in the cleaned html)
        dom_tree = etree.fromstring(example["raw_html"])
        gt_element = dom_tree.xpath(
            f"//*[@data_pw_testid_buckeye='{example['action_uid']}']"
        )
        element_id = gt_element[0].get("backend_node_id")
        raw_obs = get_target_obs(dom_tree, [element_id])
        # Find the start index of the target element using the element ID
        start_idx = raw_obs.find(f"id={element_id}")
        # Find the start tag for the target element
        start_tag_idx = raw_obs.rfind("<", 0, start_idx)
        end_tag_idx = raw_obs.find(">", start_idx)
        # Extract the tag name
        tag_name = raw_obs[start_tag_idx + 1 : end_tag_idx].split()[0]
        # Initialize count for open and close tags
        open_count = 0
        close_count = 0
        search_idx = start_tag_idx
        # NOTE(review): this substring scan assumes the pruned repr has
        # well-nested tags of `tag_name`; verify against get_tree_repr output.
        while True:
            # Find the next open or close tag of the same type
            next_open_tag = raw_obs.find(f"<{tag_name}", search_idx)
            next_close_tag = raw_obs.find(f"{tag_name}>", search_idx)
            # No more tags found, break
            if next_open_tag == -1 and next_close_tag == -1:
                break
            # Decide whether the next tag is an open or close tag
            if next_open_tag != -1 and (
                next_open_tag < next_close_tag or next_close_tag == -1
            ):
                open_count += 1
                search_idx = raw_obs.find(">", next_open_tag) + 1
            else:
                close_count += 1
                search_idx = next_close_tag + len(f"{tag_name}>")
            # If we've closed all open tags, break
            if open_count == close_count:
                break
        # Extract the target element
        o = f" {raw_obs[start_tag_idx:search_idx]} "
        a = get_target_act(example, element_id)
    else:
        # Positive candidate exists in the cleaned html: prune directly to it
        dom_tree = etree.fromstring(example["cleaned_html"])
        element_id = example["pos_candidates"][0]["backend_node_id"]
        o = get_target_obs(dom_tree, [element_id])
        a = get_target_act(example, element_id)

    return o, a
158 |
159 |
160 | def get_top_k_obs(s: dict, top_k: int, use_raw: bool = True) -> tuple[str, str]:
161 | # Find one positive candidate (it can be zero)
162 | pos_candidates = s["pos_candidates"]
163 | pos_ids = [c["backend_node_id"] for c in pos_candidates][:1]
164 | # Find top_k - 1 negative candidates
165 | neg_candidates = s["neg_candidates"]
166 | neg_candidates = sorted(neg_candidates, key=lambda c: c["rank"])[: top_k - 1]
167 | neg_ids = [c["backend_node_id"] for c in neg_candidates]
168 | # Prune html with all candidates
169 | all_candidates = pos_ids + neg_ids
170 | obs = get_target_obs(etree.fromstring(s["cleaned_html"]), all_candidates)
171 | # If there is no positive candidate in cleaned_html, get it from raw_html
172 | if len(s["pos_candidates"]) == 0:
173 | assert use_raw
174 | # Simplify the raw_html if pos_candidates is empty (not in the cleaned html)
175 | dom_tree = etree.fromstring(s["raw_html"])
176 | gt_element = dom_tree.xpath(f"//*[@data_pw_testid_buckeye='{s['action_uid']}']")
177 | element_id = gt_element[0].get("backend_node_id")
178 | raw_obs = get_target_obs(dom_tree, [element_id])
179 | # Find the start index of the target element using the element ID
180 | start_idx = raw_obs.find(f"id={element_id}")
181 | # Find the start tag for the target element
182 | start_tag_idx = raw_obs.rfind("<", 0, start_idx)
183 | end_tag_idx = raw_obs.find(">", start_idx)
184 | # Extract the tag name
185 | tag_name = raw_obs[start_tag_idx + 1 : end_tag_idx].split()[0]
186 | # Initialize count for open and close tags
187 | open_count = 0
188 | close_count = 0
189 | search_idx = start_tag_idx
190 | while True:
191 | # Find the next open or close tag of the same type
192 | next_open_tag = raw_obs.find(f"<{tag_name}", search_idx)
193 | next_close_tag = raw_obs.find(f"{tag_name}>", search_idx)
194 | # No more tags found, break
195 | if next_open_tag == -1 and next_close_tag == -1:
196 | break
197 | # Decide whether the next tag is an open or close tag
198 | if next_open_tag != -1 and (
199 | next_open_tag < next_close_tag or next_close_tag == -1
200 | ):
201 | open_count += 1
202 | search_idx = raw_obs.find(">", next_open_tag) + 1
203 | else:
204 | close_count += 1
205 | search_idx = next_close_tag + len(f"{tag_name}>")
206 | # If we've closed all open tags, break
207 | if open_count == close_count:
208 | break
209 | # Extract the target element
210 | target_element = raw_obs[start_tag_idx:search_idx]
211 | obs = obs.replace("