├── webexp
│   ├── explore
│   │   ├── __init__.py
│   │   ├── core
│   │   │   ├── trace.py
│   │   │   ├── task.py
│   │   │   ├── agent.py
│   │   │   ├── graph.py
│   │   │   ├── trajectory.py
│   │   │   ├── evaluator.py
│   │   │   ├── episode.py
│   │   │   └── node.py
│   │   └── algorithms
│   │       └── web_explore.py
│   ├── agents
│   │   ├── __init__.py
│   │   ├── prompt_builders
│   │   │   ├── nav_explorer_prompt_builder.py
│   │   │   ├── __init__.py
│   │   │   ├── page_explorer_prompt_builder.py
│   │   │   └── solver_prompt_builder.py
│   │   ├── run_episode.py
│   │   ├── trajectory_data.py
│   │   ├── nav_explorer_agent.py
│   │   ├── base_agent.py
│   │   ├── page_explorer_agent.py
│   │   └── solver_agent.py
│   ├── __init__.py
│   ├── benchmark
│   │   └── run_webarena.py
│   └── train
│       └── sft_policy.py
├── figures
│   ├── go-browse-main-figure-colored.pdf
│   └── go-browse-main-figure-colored.png
├── setup.py
├── requirements.txt
├── configs
│   ├── benchmark_webarena.yaml
│   ├── agent_run_episode.yaml
│   └── go_browse_config.yaml
├── webarena-reset
│   ├── reddit-reset.sh
│   ├── shopping-reset.sh
│   ├── shopping-admin-reset.sh
│   ├── gitlab-reset.sh
│   └── reset_server.py
├── LICENSE
├── projects
│   └── go-browse
│       └── data
│           ├── process_dataset.py
│           ├── process_nnetnav_data.py
│           └── generate_dataset.py
└── README.md

/webexp/explore/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/figures/go-browse-main-figure-colored.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ApGa/Go-Browse/HEAD/figures/go-browse-main-figure-colored.pdf
--------------------------------------------------------------------------------
/figures/go-browse-main-figure-colored.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ApGa/Go-Browse/HEAD/figures/go-browse-main-figure-colored.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | setup(
4 |     name = "webexp",
5 |     version = "0.0.1",
6 |     packages=find_packages(),
7 | )
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | agentlab
3 | browsergym
4 | evaluate
5 | kaleido
6 | scikit-learn
7 | bitsandbytes
8 | deepspeed
9 | flash-attn
10 | liger-kernel
11 | omegaconf
12 | pillow
13 | plotly
14 | torch
15 | transformers
16 | trl
17 | wandb
--------------------------------------------------------------------------------
/webexp/agents/__init__.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 | import os
3 | 
4 | for module in glob(os.path.join(os.path.dirname(__file__), "*.py")):
5 |     if os.path.basename(module) == "__init__.py":
6 |         continue
7 |     __import__(f"webexp.agents.{os.path.basename(module)[:-3]}")
--------------------------------------------------------------------------------
/webexp/__init__.py:
--------------------------------------------------------------------------------
1 | from browsergym.core.registration import register_task
2 | from browsergym.webarena import config, task
3 | 
4 | import browsergym.webarena
5 | 
6 | class ExplorationTaskWrapper(task.GenericWebArenaTask):
7 |     def setup(self, page):
8 |         super().setup(page)
9 |         self.evaluator = lambda *args, **kwargs: 0.0
10 |         return None, {}
11 | 
12 | # TODO: We probably only need to register one task per webarena domain and give each a human-readable name
13 | 
14 | for task_id in config.TASK_IDS:
15 |     gym_id = f"webarena.exploration.{task_id}"
16 |     register_task(
17 |         gym_id,
18 |         ExplorationTaskWrapper,
19 |         task_kwargs={"task_id": task_id},
20 |     )
21 | 
--------------------------------------------------------------------------------
/configs/benchmark_webarena.yaml:
--------------------------------------------------------------------------------
1 | agent_factory_args:
2 |   name: SolverAgent
3 |   model_id: 
4 |   base_url: 
5 |   base_url_2: 
6 |   api_key: 
7 |   temperature: 0.
8 |   char_limit: 80000 # Character limit for truncating prompt.
9 | 
10 | exp_dir: 
11 | n_jobs: 12
12 | resume_dir: null # If resuming from a previous, incomplete run, you can specify the specific run directory path here (exp_dir/run_dir).
--------------------------------------------------------------------------------
/webarena-reset/reddit-reset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | HOST=$BASE_URL
4 | PORT="9999"
5 | 
6 | # Construct the full URL
7 | FULL_URL="${HOST}:${PORT}"
8 | 
9 | # Stop and remove the reddit container if it exists
10 | if [ "$(docker ps -q -f name=forum)" ]; then
11 |     echo "Stopping and removing existing forum container..."
12 |     docker stop forum
13 |     docker rm forum
14 | fi
15 | 
16 | echo "Starting the forum container..."
17 | docker run --name forum -p ${PORT}:80 -d postmill-populated-exposed-withimg
18 | 
19 | # Wait for the forum container to start
20 | echo "Waiting for the forum container to start..."
21 | sleep 60
22 | 
23 | echo "Configuring the forum container..."
24 | docker exec forum sed -i "s/^ENABLE_EXPERIMENTAL_REST_API.*/ENABLE_EXPERIMENTAL_REST_API=1/" .env
25 | docker exec -it forum psql -U postmill -d postmill -c "UPDATE users SET trusted = true WHERE username = 'MarvelsGrantMan136';"
26 | 
27 | curl "$FULL_URL"
28 | 
29 | # Additional wait time to be safe
30 | sleep 30
--------------------------------------------------------------------------------
/webarena-reset/shopping-reset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | HOST=$BASE_URL
4 | PORT="7770"
5 | 
6 | # Construct the full URL
7 | FULL_URL="${HOST}:${PORT}"
8 | 
9 | # Stop and remove the shopping container if it exists
10 | if [ "$(docker ps -q -f name=shopping)" ]; then
11 |     echo "Stopping and removing existing shopping container..."
12 |     docker stop shopping
13 |     docker rm shopping
14 | fi
15 | 
16 | 
17 | echo "Starting the shopping container..."
18 | docker run --name shopping -p ${PORT}:80 -d shopping_final_0712
19 | 
20 | # Wait for the shopping container to start
21 | echo "Waiting for the shopping container to start..."
22 | sleep 30
23 | 
24 | 
25 | echo "Configuring the shopping container..."
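# The three commands below point Magento at the public URL: the CLI call sets
# the unsecure base URL, the SQL UPDATE covers web/secure/base_url (using the
# magentouser/MyPassword credentials visible below, which ship with the
# webarena container image), and the cache flush makes both take effect. This
# assumes BASE_URL is exported first, e.g. BASE_URL=http://<your-host> ./shopping-reset.sh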
26 | docker exec shopping /var/www/magento2/bin/magento setup:store-config:set --base-url="${FULL_URL}" # no trailing /
27 | docker exec shopping mysql -u magentouser -pMyPassword magentodb -e "UPDATE core_config_data SET value=\"${FULL_URL}/\" WHERE path = \"web/secure/base_url\";"
28 | docker exec shopping /var/www/magento2/bin/magento cache:flush
29 | 
30 | curl "$FULL_URL"
31 | 
32 | # Additional wait time to be safe
33 | sleep 30
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2025 Apurva Gandhi
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/configs/agent_run_episode.yaml:
--------------------------------------------------------------------------------
1 | agent_factory_args:
2 |   name: SolverAgent
3 |   model_id: 
4 |   base_url: 
5 |   base_url_2: 
6 |   api_key: 
7 |   temperature: 0.
8 |   char_limit: 80000 # Character limit for truncating prompt.
9 | 
10 | 
11 | env_args:
12 |   # task name can be any task registered with browsergym.
13 |   # you can also run on any live website by setting task_name to "openended"
14 |   # and providing a "start_url" in env_args.task_kwargs with the URL you want to run the agent on.
15 |   task_name: webarena.608
16 |   #task_name: openended
17 |   #task_kwargs:
18 |   #  start_url: https://www.google.com/maps
19 |   max_steps: 30
20 |   # If headless is set to true, browser GUI will not be shown. Useful for running on a remote machine.
21 |   headless: false
22 |   viewport:
23 |     width: 1280
24 |     height: 1440
25 | 
26 | exp_dir: "./results"
--------------------------------------------------------------------------------
/webexp/agents/prompt_builders/nav_explorer_prompt_builder.py:
--------------------------------------------------------------------------------
1 | from .solver_prompt_builder import SolverPromptBuilder
2 | 
3 | class NavExplorerPromptBuilder(SolverPromptBuilder):
4 | 
5 |     def cot_examples(self) -> list[dict]:
6 |         return [
7 |             {"thought": "It seems that we can navigate to different pages, including Reviews, Home, and Recommendations, from this page. Before adding these as navigation tasks, I will first try navigating to them to make sure these indeed take me to new webpages. I will start with the Reviews page.", "action": "click('42')"},
8 |             {"thought": "It seems I was successfully able to navigate to the Reviews page and have now returned. I will add this to the list of navigation tasks and then try other places to navigate to.", "action": "add_tasks_to_dataset('[NAV] Navigate to the Reviews page.')"},
9 |             {"thought": "I see a menu item for categories. Perhaps expanding this menu by clicking on it will show us additional places to navigate to.", "action": "click('5')"},
10 |             {"thought": "I have thoroughly explored this web page and found a good variety of tasks for the user to perform on this page. I will now respond to the user confirming that I have finished collecting data on this web page and now am ready to explore a new web page.", "action": "send_msg_to_user('I have finished exploring and collecting a variety of tasks on this web page. We can move on to the next page.')"},
11 |         ]
12 | 
--------------------------------------------------------------------------------
/webarena-reset/shopping-admin-reset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | HOST=$BASE_URL
4 | PORT="7780"
5 | 
6 | # Construct the full URL
7 | FULL_URL="${HOST}:${PORT}"
8 | 
9 | # Stop and remove the shopping admin container if it exists
10 | if [ "$(docker ps -q -f name=shopping_admin)" ]; then
11 |     echo "Stopping and removing existing shopping admin container..."
12 |     docker stop shopping_admin
13 |     docker rm shopping_admin
14 | fi
15 | 
16 | echo "Starting the shopping admin container..."
17 | docker run --name shopping_admin -p ${PORT}:80 -d shopping_admin_final_0719
18 | 
19 | # Wait for the shopping admin container to start
20 | echo "Waiting for the shopping admin container to start..."
21 | sleep 30
22 | 
23 | echo "Configuring the shopping admin container..."
24 | # remove the requirement to reset password
25 | docker exec shopping_admin php /var/www/magento2/bin/magento config:set admin/security/password_is_forced 0
26 | docker exec shopping_admin php /var/www/magento2/bin/magento config:set admin/security/password_lifetime 0
27 | docker exec shopping_admin /var/www/magento2/bin/magento setup:store-config:set --base-url="${FULL_URL}" # no trailing /
28 | docker exec shopping_admin mysql -u magentouser -pMyPassword magentodb -e "UPDATE core_config_data SET value=\"${FULL_URL}/\" WHERE path = \"web/secure/base_url\";"
29 | docker exec shopping_admin /var/www/magento2/bin/magento cache:flush
30 | 
31 | curl "$FULL_URL/admin"
32 | 
33 | # Additional wait time to be safe
34 | sleep 30
--------------------------------------------------------------------------------
/webarena-reset/gitlab-reset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | HOST=$BASE_URL
4 | PORT="8023"
5 | 
6 | # Construct the full URL
7 | FULL_URL="${HOST}:${PORT}"
8 | 
9 | # Stop and remove the gitlab container if it exists
10 | if [ "$(docker ps -q -f name=gitlab)" ]; then
11 |     echo "Stopping and removing existing gitlab container..."
12 |     docker stop gitlab
13 |     docker rm gitlab
14 | fi
15 | 
16 | echo "Starting the gitlab container..."
17 | docker run --name gitlab -d -p ${PORT}:8023 gitlab-populated-final-port8023 /opt/gitlab/embedded/bin/runsvdir-start
18 | 
19 | echo "Waiting for the gitlab container to start..."
20 | # Wait for the gitlab container to start
21 | sleep 30
22 | 
23 | echo "Configuring the gitlab container..."
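# The two commands below rewrite external_url in /etc/gitlab/gitlab.rb and then
# regenerate GitLab's configuration. `gitlab-ctl reconfigure` can take several
# minutes, which is why the healthcheck loop further down polls /users/sign_in
# before declaring the instance ready.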
24 | docker exec gitlab sed -i "s|^external_url.*|external_url '${FULL_URL}'|" /etc/gitlab/gitlab.rb
25 | docker exec gitlab gitlab-ctl reconfigure
26 | 
27 | curl "$FULL_URL"
28 | 
29 | # Wait for GitLab to be fully initialized with a healthcheck
30 | echo "Waiting for GitLab to be fully ready..."
31 | MAX_RETRIES=30
32 | COUNT=0
33 | DELAY=10
34 | 
35 | while [ $COUNT -lt $MAX_RETRIES ]; do
36 |     HTTP_CODE=$(curl -s -o /dev/null -w "%{http_code}" "$FULL_URL/users/sign_in")
37 | 
38 |     if [ "$HTTP_CODE" = "200" ]; then
39 |         echo "GitLab is ready! (Status code: $HTTP_CODE)"
40 |         break
41 |     else
42 |         echo "GitLab not ready yet. Status code: $HTTP_CODE. Retrying in ${DELAY}s..."
43 |         sleep $DELAY
44 |         COUNT=$((COUNT+1))
45 |     fi
46 | done
47 | 
48 | if [ $COUNT -eq $MAX_RETRIES ]; then
49 |     echo "Warning: Reached maximum retries. GitLab may not be fully initialized."
50 | fi
51 | 
52 | # Additional wait time to be safe
53 | sleep 60
--------------------------------------------------------------------------------
/webexp/explore/core/trace.py:
--------------------------------------------------------------------------------
1 | from .trajectory import TrajectoryStep
2 | from dataclasses import dataclass
3 | import json
4 | import os
5 | 
6 | @dataclass
7 | class Trace:
8 |     """A trace of a sequence of steps"""
9 | 
10 |     steps: list[TrajectoryStep]
11 |     start_url: str
12 |     end_url: str
13 |     misc: dict | None = None
14 | 
15 |     @classmethod
16 |     def from_trajectory_steps(
17 |         cls,
18 |         steps: list[TrajectoryStep | str],
19 |         start_url: str,
20 |         end_url: str,
21 |         misc: dict | None = None,
22 |     ):
23 |         return cls(steps, start_url, end_url, misc)
24 | 
25 |     def save(self, save_dir: str):
26 |         if not os.path.exists(save_dir):
27 |             os.makedirs(save_dir)
28 | 
29 |         trace_info = {
30 |             "start_url": self.start_url,
31 |             "end_url": self.end_url,
32 |             "misc": self.misc
33 |         }
34 | 
35 |         with open(os.path.join(save_dir, "trace_info.json"), "w") as f:
36 |             json.dump(trace_info, f, indent=4)
37 | 
38 |         for i, step in enumerate(self.steps):
39 |             step_save_dir = os.path.join(save_dir, f"step_{i}")
40 |             os.makedirs(step_save_dir, exist_ok=True)
41 |             step.save(step_save_dir, keep_image_in_memory=True, save_image=False)
42 | 
43 |     @staticmethod
44 |     def load(load_dir: str, load_steps: bool=True, load_images: bool=False):
45 |         with open(os.path.join(load_dir, "trace_info.json"), "r") as f:
46 |             trace_info = json.load(f)
47 | 
48 |         steps = []
49 |         if load_steps:
50 |             for i in range(len(os.listdir(load_dir)) - 1):
51 |                 step_load_dir = os.path.join(load_dir, f"step_{i}")
52 |                 steps.append(TrajectoryStep.load(step_load_dir, load_image=load_images))
53 |         else:
54 |             steps = os.listdir(load_dir)
55 | 
56 |         return Trace(steps, trace_info["start_url"], trace_info["end_url"], trace_info["misc"])
57 | 
58 |     def __len__(self):
59 |         return len(self.steps)
--------------------------------------------------------------------------------
/webexp/agents/prompt_builders/__init__.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from ..trajectory_data import StepData, TrajectoryData
3 | 
4 | def flatten_messages(messages: list[dict]):
5 |     flattened_messages = []
6 |     for message in messages:
7 |         role = message['role']
8 |         text = '\n\n'.join(c['text'] for c in message['content'])
9 |         flattened_messages.append({'role': role, 'content': text})
10 |     return flattened_messages
11 | 
12 | 
13 | class BasePromptBuilder:
14 | 
15 |     def build_trajectory_messages(self, trajectory_data: TrajectoryData, char_limit: int = -1) -> list[dict]:
16 |         raise NotImplementedError
17 | 
18 |     def build_messages(self, goal: str, current_step: StepData, history: list[StepData], char_limit: int = -1) -> dict:
19 |         raise NotImplementedError
20 | 
21 |     def pretty_prompt_string(self, messages: list[dict]) -> str:
22 |         """
23 |         Convert a list of prompt messages into a single string suitable for printing.
24 | 
25 |         Args:
26 |             messages (list[dict]): The list of messages to convert.
27 | 
28 |         Returns:
29 |             str: The processed stringified prompt suitable for printing.
30 |         """
31 |         prompt_text_strings = []
32 |         for message in messages:
33 |             prompt_text_strings.append(message["content"])
34 |         full_prompt_txt = "\n".join(prompt_text_strings)
35 |         return full_prompt_txt
36 | 
37 | PROMPT_BUILDER_REGISTRY: dict[str, type[BasePromptBuilder]] = {}
38 | 
39 | class PromptBuilderFactory:
40 | 
41 |     def create_prompt_builder(self, name: str):
42 |         if name not in PROMPT_BUILDER_REGISTRY:
43 |             raise ValueError(f"Unknown prompt builder: {name}")
44 |         return PROMPT_BUILDER_REGISTRY[name]()
45 | 
46 |     @staticmethod
47 |     def register(cls, aliases: str | tuple[str] = tuple()):
48 |         PROMPT_BUILDER_REGISTRY[cls.__name__] = cls
49 | 
50 |         if isinstance(aliases, str):
51 |             aliases = (aliases,)
52 | 
53 |         for name in aliases:
54 |             PROMPT_BUILDER_REGISTRY[name] = cls
55 | 
--------------------------------------------------------------------------------
/webexp/agents/run_episode.py:
--------------------------------------------------------------------------------
1 | from .base_agent import AgentFactory
2 | from browsergym.experiments import AbstractAgentArgs, EnvArgs, ExpArgs, get_exp_result
3 | from dataclasses import dataclass
4 | from omegaconf import OmegaConf as oc
5 | import argparse
6 | 
7 | @dataclass
8 | class RunEpisodeConfig:
9 |     """
10 |     Configuration for running an agent for an episode.
11 | 
12 |     Attributes:
13 |         agent_factory_args (dict): Arguments for the agent factory.
14 |         env_args (dict): Arguments for the environment.
15 |         exp_dir (str): Directory for storing experiment results. Default is "./results".
16 |     """
17 |     agent_factory_args: dict
18 |     env_args: dict
19 |     exp_dir: str
20 | 
21 | 
22 | class BrowserGymAgentArgsWrapper(AbstractAgentArgs):
23 |     def __init__(self, agent_factory_args: dict):
24 |         super().__init__()
25 |         self.agent_factory_args = agent_factory_args
26 | 
27 |     def make_agent(self):
28 |         return AgentFactory.create_agent(**self.agent_factory_args)
29 | 
30 | def main():
31 | 
32 |     # Need to get config file from command line
33 |     parser = argparse.ArgumentParser(description="Run an episode with a browser gym agent.")
34 |     parser.add_argument(
35 |         "--config",
36 |         "-c",
37 |         type=str,
38 |         required=True,
39 |         help="Path to the configuration file.",
40 |     )
41 |     args = parser.parse_args()
42 | 
43 |     config: RunEpisodeConfig = oc.load(args.config)
44 |     oc.resolve(config)
45 |     config_dict = oc.to_container(config)
46 | 
47 |     agent_args = BrowserGymAgentArgsWrapper(config.agent_factory_args)
48 |     env_args = EnvArgs(**config_dict['env_args'])
49 | 
50 |     exp_args = ExpArgs(
51 |         env_args=env_args,
52 |         agent_args=agent_args,
53 |     )
54 | 
55 |     exp_args.prepare(config.exp_dir)
56 |     exp_args.run()
57 | 
58 |     exp_result = get_exp_result(exp_args.exp_dir)
59 |     exp_record = exp_result.get_exp_record()
60 | 
61 |     for key, val in exp_record.items():
62 |         print(f"{key}: {val}")
63 | 
64 | 
65 | if __name__ == "__main__":
66 |     main()
67 | 
--------------------------------------------------------------------------------
/webexp/benchmark/run_webarena.py:
--------------------------------------------------------------------------------
1 | from agentlab.experiments.study import make_study, AgentArgs, Study
2 | from ..agents.base_agent import AgentFactory
3 | from dataclasses import dataclass
4 | from omegaconf import OmegaConf as oc
5 | import argparse
6 | import os
7 | 
8 | @dataclass
9 | class RunBenchmarkConfig:
10 |     """
11 |     Configuration for running an agent on the WebArena benchmark.
12 | 
13 |     Attributes:
14 |         agent_factory_args (dict): Arguments for the agent factory.
15 |         exp_dir (str): Directory to save the experiment results.
16 |         n_jobs (int): Number of workers to use.
17 |     """
18 |     agent_factory_args: dict
19 |     exp_dir: str
20 |     n_jobs: int
21 |     resume_dir: str | None = None
22 | 
23 | 
24 | class AgentLabAgentArgsWrapper(AgentArgs):
25 |     def __init__(self, agent_factory_args: dict):
26 |         super().__init__()
27 |         self.agent_factory_args = agent_factory_args
28 | 
29 |     def make_agent(self):
30 |         return AgentFactory.create_agent(**self.agent_factory_args)
31 | 
32 | 
33 | def run():
34 | 
35 |     parser = argparse.ArgumentParser(description="Run webarena benchmark.")
36 |     parser.add_argument(
37 |         "--config",
38 |         "-c",
39 |         type=str,
40 |         required=True,
41 |         help="Path to the configuration file.",
42 |     )
43 |     args = parser.parse_args()
44 | 
45 |     config: RunBenchmarkConfig = oc.load(args.config)
46 |     oc.resolve(config)
47 |     config_dict = oc.to_container(config)
48 | 
49 |     agent_args = AgentLabAgentArgsWrapper(config_dict['agent_factory_args'])
50 | 
51 |     if config.resume_dir is None:
52 | 
53 |         study = make_study(
54 |             benchmark="webarena",
55 |             agent_args=[agent_args],
56 |             comment="WebArena benchmark run",
57 | 
58 |         )
59 |     else:
60 |         study = Study.load(config.resume_dir)
61 |         study.find_incomplete(include_errors=True)
62 | 
63 |     study.run(
64 |         n_jobs=config.n_jobs,
65 |         exp_root=config.exp_dir,
66 |         n_relaunch=8
67 |     )
68 | 
69 | 
70 | if __name__ == "__main__":
71 |     run()
72 | 
--------------------------------------------------------------------------------
/webexp/agents/prompt_builders/page_explorer_prompt_builder.py:
--------------------------------------------------------------------------------
1 | from .solver_prompt_builder import SolverPromptBuilder
2 | 
3 | class PageExplorerPromptBuilder(SolverPromptBuilder):
4 | 
5 |     def cot_examples(self) -> list[dict]:
6 |         return [
7 |             {"thought": "I see selectors for choosing the date on this orders page. An example task could be to find the orders in a particular time period (e.g., in Jan 2022). Before we add this as a task to our dataset, we should first try to see if this is possible by trying out the task ourselves. I will click on the date selector.", "action": "click('12')"},
8 |             {"thought": "This page lists information about customers that we can create information extraction tasks about. I will add such tasks to the dataset.", "action": "add_tasks_to_dataset('[INFO] What is the email address of the customer Joe Bloggs?', '[INFO] List the names of all customers from Texas')"},
9 |             {"thought": "It seems that we can navigate to different pages, including Reviews, Home, and Recommendations, from this page. Let me add these as navigation tasks to the dataset.", "action": "add_tasks_to_dataset('[NAV] Navigate to the Reviews page.', '[NAV] Go to the Home page.', '[NAV] Visit the Recommendations page.')"},
10 |             {"thought": "I see a menu item for 'Product Information' on this page. Perhaps there are some interesting information extraction tasks based on this. Let me click on the 'Product Information' menu item to explore further to see if I can find some concrete tasks to add to the dataset.", "action": "click('5')"},
11 |             {"thought": "My last action has taken me to a new URL/page. Since my goal is to find tasks on the original page, I will first go back to the previous page to continue finding tasks on that page.", "action": "go_back()"},
12 |             {"thought": "This is the product page for the Nintendo Switch. I see that we can perform content modification tasks such as adding the product to the cart or adding a review. I will add these tasks to the dataset.", "action": "add_tasks_to_dataset('[MOD] Add a Nintendo Switch with 256 GB storage to the cart.', '[MOD] Leave a negative review for the Nintendo Switch saying that the joycon controllers started drifting after a month of usage.')"},
13 |             {"thought": "I have thoroughly explored this web page and found a good variety of tasks for the user to perform on this page. I will now respond to the user confirming that I have finished collecting data on this web page and now am ready to explore a new web page.", "action": "send_msg_to_user('I have finished exploring and collecting a variety of tasks on this web page. We can move on to the next page.')"}
14 |         ]
15 | 
--------------------------------------------------------------------------------
/webarena-reset/reset_server.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI, HTTPException
2 | import asyncio
3 | import os
4 | 
5 | app = FastAPI()
6 | 
7 | # Define the mapping of domain names to their reset scripts
8 | DOMAIN_SCRIPTS = {
9 |     "shopping": "shopping-reset.sh",
10 |     "shopping_admin": "shopping-admin-reset.sh",
11 |     "reddit": "reddit-reset.sh",
12 |     "gitlab": "gitlab-reset.sh",
13 | }
14 | 
15 | @app.post("/reset/{domain}")
16 | async def reset_domain(domain: str):
17 |     # Map does not require reset as it is stateless
18 |     if domain == "map":
19 |         return {"message": "Map domain does not require reset."}
20 | 
21 |     if domain == "all":
22 |         # Reset all domains
23 |         results = {}
24 |         for domain_name, script_path in DOMAIN_SCRIPTS.items():
25 |             if os.path.exists(script_path):
26 |                 try:
27 |                     proc = await asyncio.create_subprocess_exec(
28 |                         "/bin/bash", script_path,
29 |                         stdout=asyncio.subprocess.PIPE,
30 |                         stderr=asyncio.subprocess.PIPE
31 |                     )
32 |                     stdout, stderr = await proc.communicate()
33 | 
34 |                     if proc.returncode == 0:
35 |                         results[domain_name] = {
36 |                             "message": f"{domain_name} reset successfully",
37 |                             "output": stdout.decode()
38 |                         }
39 |                     else:
40 |                         results[domain_name] = {
41 |                             "error": f"Error resetting {domain_name}: {stderr.decode()}"
42 |                         }
43 |                 except Exception as e:
44 |                     results[domain_name] = {"error": f"Exception while resetting {domain_name}: {str(e)}"}
45 |             else:
46 |                 results[domain_name] = {"error": f"Reset script not found for {domain_name}"}
47 |         return results
48 | 
49 |     if domain not in DOMAIN_SCRIPTS:
50 |         raise HTTPException(status_code=404, detail="Domain not found")
51 | 
52 |     script_path = DOMAIN_SCRIPTS[domain]
53 | 
54 |     if not os.path.exists(script_path):
55 |         raise HTTPException(status_code=500, detail=f"Reset script not found for {domain}")
56 | 
57 |     try:
58 |         proc = await asyncio.create_subprocess_exec(
59 |             "/bin/bash", script_path,
60 |             stdout=asyncio.subprocess.PIPE,
61 |             stderr=asyncio.subprocess.PIPE
62 |         )
63 |         stdout, stderr = await proc.communicate()
64 | 
65 |         if proc.returncode == 0:
66 |             return {"message": f"{domain} reset successfully", "output": stdout.decode()}
67 |         else:
68 |             raise HTTPException(status_code=500, detail=f"Error resetting {domain}: {stderr.decode()}")
69 |     except Exception as e:
70 |         raise HTTPException(status_code=500, detail=f"Exception while resetting {domain}: {str(e)}")
--------------------------------------------------------------------------------
/configs/go_browse_config.yaml:
--------------------------------------------------------------------------------
1 | 
2 | env_args:
3 |   # For webarena domains, replace the task name with the following per domain:
4 |   #   webarena.exploration.7 (map)
5 |   #   webarena.exploration.44 (gitlab)
6 |   #   webarena.exploration.0 (shopping_admin)
7 |   #   webarena.exploration.66 (reddit)
8 |   #   webarena.exploration.51 (shopping)
9 |   # To run on any live website, set task_name to "openended" and provide a start_url in task_kwargs.
10 |   task_name: 
11 |   #task_name: openended
12 |   #task_kwargs:
13 |   #  start_url: https://www.google.com/maps
14 |   max_steps: 10000000
15 |   headless: true
16 |   viewport:
17 |     width: 1280
18 |     height: 1440
19 | 
20 | evaluator:
21 |   model_name: gpt-4o-2024-05-13
22 | 
23 | max_nodes: 20
24 | resume_from: null # If resuming from a previous, incomplete run, you can specify the specific run directory path here (exp_dir).
25 | max_feasible_page_explorer_tasks_per_node: 20
26 | max_feasible_nav_explorer_tasks_per_node: 10
27 | 
28 | page_explorers:
29 |   - agent_factory_args:
30 |       name: "PageExplorerAgent"
31 |       model_id: "gpt-4o-2024-08-06"
32 |       base_url: "$"
33 |       api_key: "$"
34 |       temperature: 0.7
35 |     max_steps: 20
36 |     retries: 3
37 | 
38 |   - agent_factory_args:
39 |       name: "PageExplorerAgent"
40 |       model_id: "claude-3-7-sonnet-20250219"
41 |       base_url: "$"
42 |       api_key: "$"
43 |       temperature: 0.7
44 |     max_steps: 10
45 |     retries: 3
46 | 
47 | nav_explorers:
48 |   - agent_factory_args:
49 |       name: "NavExplorerAgent"
50 |       model_id: "claude-3-7-sonnet-20250219"
51 |       base_url: "$"
52 |       api_key: "$"
53 |       temperature: 0.7
54 |     max_steps: 15
55 |     retries: 3
56 | 
57 | feasibility_checkers:
58 |   - agent_factory_args:
59 |       name: "SolverAgent"
60 |       model_id: "claude-3-7-sonnet-20250219"
61 |       base_url: "$"
62 |       api_key: "$"
63 |       temperature: 0.7
64 |       char_limit: 80000
65 |     max_steps: 10
66 |     retries: 3
67 | 
68 | solvers:
69 |   - agent_factory_args:
70 |       name: "SolverAgent"
71 |       model_id: "gpt-4o-mini-2024-07-18"
72 |       base_url: "$"
73 |       api_key: "$"
74 |       temperature: 0.7
75 |       char_limit: 80000
76 |     max_steps: 10
77 |     retries: 2
78 | 
79 |   - agent_factory_args:
80 |       name: "SolverAgent"
81 |       model_id: "Qwen/Qwen2.5-7B-Instruct"
82 |       base_url: "$"
83 |       temperature: 0.7
84 |       char_limit: 80000
85 |     max_steps: 10
86 |     retries: 2
87 | 
88 | 
89 | allowlist_patterns:
90 |   - /.*
91 | 
92 | 
93 | denylist_patterns:
94 |   - ".*/admin/admin/system_config"
95 |   - ".*/admin/admin/system_config/.*"
96 | 
97 | exp_dir: 
98 | 
99 | # You can optionally reset the environment after exploration of each node, by setting the full_reset_url here.
100 | # To set up the reset url, copy/clone over the webarena-reset folder to your webarena hosting server and run reset_server.py with fastapi.
101 | # Note: the reset scripts that reset_server.py uses expect that you set the BASE_URL env var on the server to the public base url of the server.
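# A hypothetical example, assuming the exploration loop POSTs directly to this URL
# (reset_server.py routes POST /reset/{domain}; the special domain "all" resets everything):
# full_reset_url: "http://<your-webarena-server>:8000/reset/all"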
102 | full_reset_url: null
103 | 
--------------------------------------------------------------------------------
/webexp/agents/trajectory_data.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from dataclasses import dataclass
3 | from copy import deepcopy
4 | 
5 | @dataclass
6 | class StepData:
7 |     misc: dict | None = None
8 | 
9 |     def process_for_dataset(self) -> StepData:
10 |         return deepcopy(self)
11 | 
12 |     def process_for_prompt(self) -> StepData:
13 |         return deepcopy(self)
14 | 
15 | @dataclass
16 | class TrajectoryData:
17 |     steps: list[StepData]
18 |     goal: str
19 |     reward: float
20 |     misc: dict | None = None
21 | 
22 |     def process_for_dataset(self) -> TrajectoryData:
23 |         return deepcopy(self)
24 | 
25 |     def process_for_prompt(self) -> TrajectoryData:
26 |         return deepcopy(self)
27 | 
28 | 
29 | @dataclass
30 | class BrowserGymAgentStepData(StepData):
31 |     action: str | None = None
32 |     thought: str | None = None
33 |     axtree: str | None = None
34 |     last_action_error: str | None = None
35 | 
36 |     def process_for_dataset(self) -> BrowserGymAgentStepData:
37 |         if self.misc is None:
38 |             self.misc = {}
39 | 
40 |         if self.action is None or self.thought is None or self.axtree is None:
41 |             self.misc["skip"] = True
42 | 
43 | 
44 |         return BrowserGymAgentStepData(
45 |             action=self.action,
46 |             thought=self.thought,
47 |             axtree=self.axtree,
48 |             last_action_error=self.last_action_error,
49 |             misc=self.misc
50 |         )
51 | 
52 | 
53 | @dataclass
54 | class BrowserGymAgentTrajectoryData(TrajectoryData):
55 |     steps: list[BrowserGymAgentStepData]
56 |     goal: str
57 |     reward: float
58 |     misc: dict | None = None
59 | 
60 |     def process_for_dataset(self) -> BrowserGymAgentTrajectoryData:
61 |         # Guard against a missing misc dict before any "skip" flags are set below.
62 |         if self.misc is None:
63 |             self.misc = {}
64 | 
65 |         processed_steps = [step.process_for_dataset() for step in self.steps]
66 |         processed_steps = [step for step in processed_steps if not step.misc.get("skip", False)]
67 | 
68 |         if not processed_steps:
69 |             self.misc["skip"] = True
70 |             return BrowserGymAgentTrajectoryData(
71 |                 steps=processed_steps,
72 |                 goal=self.goal,
73 |                 reward=self.reward,
74 |                 misc=self.misc
75 |             )
76 | 
77 |         reward = self.reward
78 | 
79 |         if processed_steps[-1].action and "report_infeasible" in processed_steps[-1].action and reward > 0:
80 |             self.misc["skip"] = True
81 |             return BrowserGymAgentTrajectoryData(
82 |                 steps=processed_steps,
83 |                 goal=self.goal,
84 |                 reward=0.0,
85 |                 misc=self.misc
86 |             )
87 | 
88 |         # If there is a trailing run of "noop" actions, remove them from the end of the trajectory
89 |         while len(processed_steps) > 1 and processed_steps[-1].action and "noop" in processed_steps[-1].action:
90 |             processed_steps.pop()
91 | 
92 |         # TODO: Consider removing repeated actions. Though, this might not always be accurate.
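        # A possible sketch for that TODO (not enabled): collapse consecutive
        # duplicate actions, keeping only the first occurrence:
        #   processed_steps = [s for i, s in enumerate(processed_steps)
        #                      if i == 0 or s.action != processed_steps[i - 1].action]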
93 | 
94 |         return BrowserGymAgentTrajectoryData(
95 |             steps=processed_steps,
96 |             goal=self.goal,
97 |             reward=self.reward,
98 |             misc=self.misc
99 |         )
100 | 
101 |     def process_for_prompt(self) -> BrowserGymAgentTrajectoryData:
102 |         processed_steps = [step.process_for_prompt() for step in self.steps]
103 |         return BrowserGymAgentTrajectoryData(
104 |             steps=processed_steps,
105 |             goal=self.goal,
106 |             reward=self.reward,
107 |             misc=self.misc
108 |         )
109 | 
--------------------------------------------------------------------------------
/webexp/train/sft_policy.py:
--------------------------------------------------------------------------------
1 | #%%
2 | import os
3 | import torch
4 | from datasets import load_dataset
5 | from datetime import timedelta
6 | from transformers import AutoTokenizer
7 | from transformers.utils.import_utils import is_torch_bf16_gpu_available
8 | from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
9 | 
10 | #%% Set paths
11 | web_explore_data_dir = "" # Replace with your web explore data directory
12 | dataset_path = os.path.join(web_explore_data_dir, "") # Replace with your dataset path
13 | 
14 | model_id = "Qwen/Qwen2.5-7B-Instruct"
15 | output_dir = os.path.join(web_explore_data_dir, "go-browse/outputs_qwen_7B_bc/")
16 | 
17 | os.environ["WANDB_API_KEY"] = "" # Replace with your Weights & Biases API key
18 | os.environ["WANDB_PROJECT"] = "Go-Browse"
19 | 
20 | #%%
21 | max_seq_length = 24000
22 | 
23 | trainer_config = SFTConfig(
24 |     per_device_train_batch_size = 1,
25 |     gradient_accumulation_steps = 4,
26 |     warmup_steps = 100,
27 |     num_train_epochs = 2,
28 |     learning_rate = 2e-5,
29 |     fp16 = not is_torch_bf16_gpu_available(),
30 |     bf16 = is_torch_bf16_gpu_available(),
31 |     logging_steps = 1,
32 |     optim = "paged_adamw_8bit",
33 |     weight_decay = 0.01,
34 |     lr_scheduler_type = "linear",
35 |     seed = 3407,
36 |     output_dir = output_dir,
37 |     save_strategy = "steps",
38 |     save_steps=500,
39 |     dataset_text_field = "text",
40 |     max_seq_length = max_seq_length,
41 |     packing = False,
42 |     gradient_checkpointing=True,
43 |     gradient_checkpointing_kwargs={'use_reentrant': False},
44 |     use_liger=False,
45 |     save_total_limit=1,
46 |     dataset_num_proc=8,
47 |     report_to=["wandb"],
48 |     model_init_kwargs = {
49 |         "attn_implementation": "flash_attention_2",
50 |         "trust_remote_code": True,
51 |         "torch_dtype": "bfloat16" if is_torch_bf16_gpu_available() else "float16",
52 |         "device_map": "auto"
53 |     },
54 | )
55 | 
56 | tokenizer = AutoTokenizer.from_pretrained(model_id)
57 | data_collator = DataCollatorForCompletionOnlyLM(
58 |     response_template = "<|im_start|>assistant\n",
59 |     tokenizer = tokenizer
60 | )
61 | 
62 | 
63 | #%%
64 | 
65 | def formatting_prompts_func(sample):
66 |     convos = sample['flattened']
67 |     texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False, enable_thinking=False) for convo in convos]
68 |     return { "text" : texts, }
69 | 
70 | def flatten_messages(sample):
71 |     prompt = sample['prompt']
72 |     completion = sample['completion']
73 |     messages = prompt + completion
74 |     return {'flattened': messages}
75 | 
76 | #%%
77 | dataset = load_dataset('json', data_files=os.path.join(web_explore_data_dir, "datasets/go-browse-data-processed-no-prefix.jsonl"))
78 | dataset = dataset['train']
79 | dataset = dataset.filter(lambda x: x['traj_reward'] > 0).map(lambda x: x['step_data'])
80 | 
81 | #%%
82 | dataset = dataset.filter(lambda x: 'prompt' in x and 'completion' in x and x['prompt'] and x['completion'])
83 | dataset = dataset.map(flatten_messages, remove_columns=['prompt', 'completion'])
84 | dataset = dataset.map(formatting_prompts_func, batched=True)
85 | 
86 | #%%
87 | print(dataset[0]['text'])
88 | #%%
89 | trainer = SFTTrainer(
90 |     model = model_id,
91 |     tokenizer = tokenizer,
92 |     train_dataset = dataset,
93 |     data_collator = data_collator,
94 |     args = trainer_config
95 | )
96 | 
97 | #%%
98 | trainer_stats = trainer.train()
99 | 
100 | #%%
101 | print(trainer_stats)
102 | #%%
103 | trainer.save_model(os.path.join(output_dir, "final_checkpoint"))
--------------------------------------------------------------------------------
/webexp/agents/nav_explorer_agent.py:
--------------------------------------------------------------------------------
1 | from .base_agent import AgentFactory
2 | from .solver_agent import SolverAgent
3 | from .prompt_builders.nav_explorer_prompt_builder import NavExplorerPromptBuilder
4 | from browsergym.core.action.highlevel import HighLevelActionSet
5 | from textwrap import dedent
6 | 
7 | TASK_COLLECTOR = []
8 | 
9 | def add_tasks_to_dataset(*tasks: str):
10 |     """Given one or more navigation task strings, add them to the dataset we are collecting.
11 |     You should add tags to the start of the task string to indicate the type of task it is.
12 |     Since your job is to find navigation tasks, you should use the [NAV] tag to indicate that the task is a navigation task.
13 | 
14 |     Examples:
15 |         add_tasks_to_dataset('[NAV] Navigate to the Recommendations page.', '[NAV] Visit the Home page.')
16 |         add_tasks_to_dataset('[NAV] Take me to my cart.')
17 |     """
18 |     TASK_COLLECTOR.extend(tasks)
19 | 
20 | 
21 | @AgentFactory.register
22 | class NavExplorerAgent(SolverAgent):
23 |     """
24 |     Agent used to propose exploration tasks for a web page.
25 |     """
26 | 
27 |     def __init__(
28 |         self,
29 |         model_id: str,
30 |         base_url: str | None = None,
31 |         api_key: str | None = None,
32 |         temperature: float = 1.0,
33 |         char_limit: int = -1,
34 |         demo_mode: str = 'off',
35 |     ):
36 |         """
37 |         Initialize the agent.
38 |         """
39 |         super().__init__(model_id=model_id, base_url=base_url, api_key=api_key, temperature=temperature, char_limit=char_limit, demo_mode=demo_mode)
40 | 
41 |         self.action_set = HighLevelActionSet(
42 |             subsets=["chat", "bid", "infeas", "nav", "tab", "custom"],
43 |             custom_actions=[add_tasks_to_dataset],
44 |             strict=False,
45 |             multiaction=False,
46 |             demo_mode=demo_mode,
47 |         )
48 | 
49 |         self.action_set.python_includes = "from webexp.agents.nav_explorer_agent import TASK_COLLECTOR, add_tasks_to_dataset\n" + self.action_set.python_includes
50 | 
51 |         self.prompt_builder = NavExplorerPromptBuilder(action_set=self.action_set)
52 | 
53 |         self.config['goal_str'] = self.goal_str
54 | 
55 |     def reset(self):
56 |         super().reset()
57 |         TASK_COLLECTOR.clear()
58 | 
59 |     def get_proposed_tasks(self) -> list[str]:
60 |         return TASK_COLLECTOR.copy()
61 | 
62 |     @property
63 |     def goal_str(self) -> str:
64 |         return dedent("""\
65 |             I am trying to collect a dataset to train a better web browser agent that can perform actions for users in a web browser. For this, we are particularly interested in collecting **navigation tasks** that are feasible to perform from the current web page.
66 |             Navigation tasks are tasks requiring navigating to a specific page.
67 | 
68 |             Collect navigation tasks that require navigating to another webpage from this current page. You may click on links to try finding other interesting pages to collect tasks from. But if you do navigate to another page, instead of collecting tasks on that page, make sure to navigate back to the previous page using `go_back` or `goto`. We will collect tasks from these new pages later. When collecting navigation tasks, prioritize those that would likely have interesting/useful tasks on them over ones that likely won't give many useful tasks to collect.
69 | 
70 |             As you are exploring, you can add navigation tasks to the dataset using the `add_tasks_to_dataset` function.
71 | 
72 |             When you are done exploring the current page, send a message to the user using `send_msg_to_user` confirming this.
73 | 
74 |             Be sure to prioritize adding navigation tasks to pages that a typical user of this web page would most often want to navigate to, over niche pages that the typical user would rarely frequent.
75 | 
76 |             **Important**
77 |             Remember that if you are successful at navigating to a new page, you should add a corresponding task to the dataset as your next action before finding new pages."""
78 |         )
--------------------------------------------------------------------------------
/webexp/explore/core/task.py:
--------------------------------------------------------------------------------
1 | from .trajectory import Trajectory
2 | from dataclasses import dataclass
3 | import json
4 | import logging
5 | import os
6 | import re
7 | 
8 | logger = logging.getLogger(__name__)
9 | logger.setLevel(logging.INFO)
10 | 
11 | @dataclass
12 | class Task:
13 |     goal: str
14 |     positive_trajs: list[Trajectory]
15 |     negative_trajs: list[Trajectory]
16 |     exp_dir: str
17 |     misc: dict = None
18 | 
19 |     def __post_init__(self):
20 |         if not os.path.exists(self.exp_dir):
21 |             os.makedirs(self.exp_dir)
22 |             os.makedirs(os.path.join(self.exp_dir, "positive_trajs"))
23 |             os.makedirs(os.path.join(self.exp_dir, "negative_trajs"))
24 |         task_info = {
25 |             "goal": self.goal,
26 |             "misc": self.misc,
27 |         }
28 |         with open(os.path.join(self.exp_dir, "task_info.json"), "w") as f:
29 |             json.dump(task_info, f, indent=4)
30 | 
31 |     def is_feasible(self) -> bool:
32 |         return len(self.positive_trajs) > 0
33 | 
34 |     def add_trajectory(self, traj: Trajectory, subdirectory=None):
35 | 
36 |         exp_dir = self.exp_dir
37 | 
38 |         if subdirectory is not None:
39 |             exp_dir = os.path.join(exp_dir, subdirectory)
40 | 
41 |         if traj.success:
42 |             traj_save_dir = os.path.join(exp_dir, "positive_trajs", f"{len(self.positive_trajs)}")
43 |             self.positive_trajs.append(traj)
44 |             logger.info(f"Saving positive trajectory to {traj_save_dir}")
45 |             os.makedirs(traj_save_dir, exist_ok=True)
46 |             traj.save(traj_save_dir)
47 |         else:
48 |             traj_save_dir = os.path.join(exp_dir, "negative_trajs", f"{len(self.negative_trajs)}")
49 |             self.negative_trajs.append(traj)
50 |             logger.info(f"Saving negative trajectory to {traj_save_dir}")
51 |             os.makedirs(traj_save_dir, exist_ok=True)
52 |             traj.save(traj_save_dir)
53 | 
54 |     @staticmethod
55 |     def load(load_dir: str, load_steps: bool=True, load_images: bool=True):
56 |         with open(os.path.join(load_dir, "task_info.json"), "r") as f:
57 |             task_info = json.load(f)
58 | 
59 |         positive_trajs = []
60 |         for i in range(len(os.listdir(os.path.join(load_dir, "positive_trajs")))):
61 |             traj_load_dir = os.path.join(load_dir, "positive_trajs", f"{i}")
62 |             positive_trajs.append(Trajectory.load(traj_load_dir, load_steps=load_steps, load_images=load_images))
63 | 
64 |         negative_trajs = []
65 |         for i in range(len(os.listdir(os.path.join(load_dir, "negative_trajs")))):
66 |             traj_load_dir = os.path.join(load_dir, "negative_trajs", f"{i}")
67 |             negative_trajs.append(Trajectory.load(traj_load_dir, load_steps=load_steps, load_images=load_images))
68 | 
69 |         return Task(task_info["goal"], positive_trajs, negative_trajs, load_dir, task_info["misc"])
70 | 
71 |     @staticmethod
72 |     def process_raw_goal(goal: str) -> tuple[str, list[str]]:
73 |         """
74 |         Process the raw goal string to remove any tags and return the cleaned goal and any tags.
75 |         """
76 |         # Parse tags from the goal string
77 |         tags = []
78 |         # Updated pattern to match any tag inside square brackets
79 |         tag_pattern = r'\[([A-Za-z0-9_-]+)\]'
80 | 
81 |         # Find all tags in the goal string
82 |         found_tags = re.findall(tag_pattern, goal)
83 |         if found_tags:
84 |             tags = found_tags
85 |             # Remove the tags from the goal string
86 |             goal = re.sub(r'\[[A-Za-z0-9_-]+\]', '', goal).strip()
87 |         return goal, tags
88 | 
89 |     @staticmethod
90 |     def from_goal(goal: str, exp_dir: str, misc: dict = None):
91 |         if misc is None:
92 |             misc = {}
93 | 
94 |         goal, tags = Task.process_raw_goal(goal)
95 | 
96 |         # Add the tags to the misc dictionary
97 |         if tags:
98 |             misc["tags"] = tags
99 | 
100 |         return Task(goal, [], [], exp_dir, misc)
--------------------------------------------------------------------------------
/projects/go-browse/data/process_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import re
4 | from glob import glob
5 | from tqdm import tqdm
6 | from browsergym.core.action.highlevel import HighLevelActionSet
7 | 
8 | 
9 | from webexp.agents.solver_agent import SolverAgent
10 | from webexp.agents.prompt_builders.solver_prompt_builder import SolverPromptBuilder
11 | from webexp.agents.trajectory_data import BrowserGymAgentStepData, BrowserGymAgentTrajectoryData
12 | 
13 | # Note: This script assumes input data is ordered by trajectory and starts with traj_num 0
14 | 
15 | INCLUDE_PREFIX = False
16 | 
17 | INPUT_DATA_FILE = "" # Replace with your desired input file path
18 | OUTPUT_DATA_FILE = "" # Replace with your desired output file path
19 | 
20 | PREFIX_PROB = 0
21 | 
22 | if __name__ == "__main__":
23 | 
24 |     action_set = HighLevelActionSet(
25 |         subsets=["chat", "bid", "infeas", "nav"],
26 |         strict=False,
27 |         multiaction=False,
28 |         demo_mode=False
29 |     )
30 | 
31 |     prompt_builder = SolverPromptBuilder(action_set)
32 | 
33 |     total_trajs = 0
34 |     total_steps = 0
35 |     prev_traj_data = None
36 | 
37 | 
38 |     with open(INPUT_DATA_FILE, "r") as f:
39 |         with open(OUTPUT_DATA_FILE, "w") as out_f:
40 | 
41 |             steps = []
42 |             curr_traj = 0
43 |             traj_goal = ""
44 |             for line in tqdm(f, desc="Processing lines", position=0, leave=False):
45 | 
46 |                 line_data = json.loads(line)
47 |                 traj_data = line_data['traj_data']
48 | 
49 |                 if curr_traj == traj_data['traj_num'] - 1:
50 | 
51 |                     curr_traj += 1
52 | 
53 |                     traj = BrowserGymAgentTrajectoryData(steps, prev_traj_data['goal'], prev_traj_data['reward'], prev_traj_data['misc']).process_for_dataset()
54 | 
55 |                     steps = []
56 | 
57 |                     if 'skip' in traj.misc and traj.misc['skip']:
58 |                         continue
59 | 
60 |                     total_trajs += 1
61 | 
62 |                     step_data_points = prompt_builder.build_trajectory_messages(traj, char_limit=80000)
63 |                     for i, d in enumerate(step_data_points):
64 |                         d_wrap = {'step_idx': total_steps, 'step_data': d, 'traj_reward': traj.reward, 'next_step_idx': (total_steps + 1) if i != (len(step_data_points) - 1) else -1, 'traj_length': len(step_data_points), 'step_number': i}
65 |                         out_f.write(json.dumps(d_wrap) + "\n")
66 |                         total_steps += 1
67 | 
68 |                 prev_traj_data = traj_data
69 | 
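                # Each input line holds one step; steps are buffered below until
                # traj_num advances, at which point the buffer is flushed (above)
                # as one complete trajectory. Steps whose misc dict flags
                # is_prefix_step are dropped unless INCLUDE_PREFIX is set.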
70 |                 step_data = line_data['step_data']
71 |                 if INCLUDE_PREFIX or not ('is_prefix_step' in step_data['misc'] and step_data['misc']['is_prefix_step']):
72 |                     steps.append(BrowserGymAgentStepData(
73 |                         action=step_data['parsed_action'],
74 |                         thought=step_data['thought'],
75 |                         axtree=step_data['obs']['axtree_txt'],
76 |                         last_action_error=step_data['obs']['last_action_error'],
77 |                         misc=step_data['misc'],
78 |                     ))
79 | 
80 |             if steps: # If there are steps collected for the last trajectory
81 |                 traj = BrowserGymAgentTrajectoryData(steps, prev_traj_data['goal'], prev_traj_data['reward'], prev_traj_data['misc']).process_for_dataset()
82 | 
83 |                 step_data_points = prompt_builder.build_trajectory_messages(traj, char_limit=80000)
84 |                 for i, d in enumerate(step_data_points):
85 |                     d_wrap = {'step_idx': total_steps, 'step_data': d, 'traj_reward': traj.reward, 'next_step_idx': (total_steps + 1) if i != (len(step_data_points) - 1) else -1, 'traj_length': len(step_data_points), 'step_number': i}
86 |                     out_f.write(json.dumps(d_wrap) + "\n")
87 |                     total_steps += 1
88 |                 total_trajs += 1
89 | 
90 | 
91 |     print("Total Trajs: ", total_trajs)
92 |     print("Total Steps: ", total_steps)
--------------------------------------------------------------------------------
/webexp/explore/core/agent.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from .graph import Graph
3 | from .node import Node
4 | from .trajectory import Trajectory
5 | from ...agents.base_agent import Agent, ExplorerAgent
6 | from browsergym.core.env import BrowserEnv
7 | from typing import Protocol, runtime_checkable
8 | 
9 | 
10 | @runtime_checkable
11 | class AgentWithExplorationCallbacks(Agent, Protocol):
12 |     """
13 |     Protocol for an agent that supports exploration callbacks.
14 |     """
15 | 
16 |     def register_pre_step_callbacks(self, callbacks: list[callable]) -> AgentWithExplorationCallbacks:
17 |         """Register a callback to be called before each step."""
18 |         ...
19 | 
20 |     def register_post_step_callbacks(self, callbacks: list[callable]) -> AgentWithExplorationCallbacks:
21 |         """Register a callback to be called after each step."""
22 |         ...
23 | 
24 |     def run_pre_step_callbacks(self, step_num: int, goal: str, env: BrowserEnv, graph: Graph, node: Node, traj: Trajectory, obs: dict, reward: float, terminated: bool, truncated: bool, env_info: dict, callback_context: dict) -> tuple:
25 |         """Run all registered pre-step callbacks and return potentially modified versions of the inputs."""
26 |         ...
27 | 
28 |     def run_post_step_callbacks(self, step_num: int, goal: str, env: BrowserEnv, graph: Graph, node: Node, traj: Trajectory, obs: dict, reward: float, terminated: bool, truncated: bool, env_info: dict, callback_context: dict) -> tuple:
29 |         """Run all registered post-step callbacks and return potentially modified versions of the inputs."""
30 |         ...
31 | 
32 | 
33 | def wrap_agent_for_callback_protocol(agent: Agent, pre_step_callbacks: list[callable] = None, post_step_callbacks: list[callable] = None) -> AgentWithExplorationCallbacks:
34 |     """
35 |     Wrap an agent to implement the AgentWithExplorationCallbacks protocol.
36 |     """
37 |     if isinstance(agent, AgentWithExplorationCallbacks):
38 |         return agent
39 | 
40 |     class CallbackProtocolAgentWrapper(type(agent)):
41 |         def __init__(self, agent):
42 |             self._agent = agent
43 |             self._pre_step_callbacks = []
44 |             self._post_step_callbacks = []
45 | 
46 |         def register_pre_step_callbacks(self, callbacks):
47 |             self._pre_step_callbacks.extend(callbacks)
48 |             return self
49 | 
50 |         def register_post_step_callbacks(self, callbacks):
51 |             self._post_step_callbacks.extend(callbacks)
52 |             return self
53 | 
54 |         def run_pre_step_callbacks(self, step_num: int, goal: str, env: BrowserEnv, graph: Graph, node: Node, traj: Trajectory, obs: dict, reward: float, terminated: bool, truncated: bool, env_info: dict, callback_context: dict) -> tuple:
55 |             for callback in self._pre_step_callbacks:
56 |                 step_num, obs, reward, terminated, truncated, env_info, goal, callback_context = callback(self, step_num, goal, env, graph, node, traj, obs, reward, terminated, truncated, env_info, callback_context)
57 |             return step_num, obs, reward, terminated, truncated, env_info, goal, callback_context
58 | 
59 |         def run_post_step_callbacks(self, step_num: int, goal: str, env: BrowserEnv, graph: Graph, node: Node, traj: Trajectory, obs: dict, reward: float, terminated: bool, truncated: bool, env_info: dict, callback_context: dict) -> tuple:
60 |             for callback in self._post_step_callbacks:
61 |                 step_num, obs, reward, terminated, truncated, env_info, goal, callback_context = callback(self, step_num, goal, env, graph, node, traj, obs, reward, terminated, truncated, env_info, callback_context)
62 |             return step_num, obs, reward, terminated, truncated, env_info, goal, callback_context
63 | 
64 |         def __getattr__(self, name):
65 |             return getattr(self._agent, name)
66 | 
67 |     wrapped_agent = CallbackProtocolAgentWrapper(agent)
68 | 
69 |     if pre_step_callbacks:
70 |         wrapped_agent.register_pre_step_callbacks(pre_step_callbacks)
71 | 
72 |     if post_step_callbacks:
73 |         wrapped_agent.register_post_step_callbacks(post_step_callbacks)
74 | 
75 |     return wrapped_agent
76 | 
77 | class ExplorerAgentWithExplorationCallbacks(ExplorerAgent, AgentWithExplorationCallbacks, Protocol):
78 |     """
79 |     Intersection type for an ExplorerAgent that supports exploration callbacks.
80 |     """
81 |     ...
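# A minimal usage sketch (hypothetical, not part of the module): attach a
# pre-step callback that records each visited URL. `my_agent` stands in for any
# object implementing the Agent protocol; the 13-argument callback signature and
# the 8-tuple return value mirror run_pre_step_callbacks above.
#
# def log_url_callback(agent, step_num, goal, env, graph, node, traj,
#                      obs, reward, terminated, truncated, env_info,
#                      callback_context):
#     # Accumulate visited URLs in the shared callback context.
#     callback_context.setdefault("visited_urls", []).append(obs.get("url"))
#     # Return the (possibly modified) loop state in the expected order.
#     return (step_num, obs, reward, terminated, truncated, env_info,
#             goal, callback_context)
#
# wrapped = wrap_agent_for_callback_protocol(
#     my_agent, pre_step_callbacks=[log_url_callback]
# )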
--------------------------------------------------------------------------------
/webexp/explore/core/graph.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from .node import Node
3 | from .trace import Trace
4 | from .trajectory import TrajectoryStep
5 | from typing import Sequence
6 | import json
7 | import logging
8 | import os
9 | import re
10 | 
11 | logger = logging.getLogger(__name__)
12 | logger.setLevel(logging.INFO)
13 | 
14 | if not logger.handlers:
15 |     handler = logging.StreamHandler()
16 |     formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
17 |     handler.setFormatter(formatter)
18 |     logger.addHandler(handler)
19 | 
20 | class Graph:
21 |     def __init__(
22 |         self,
23 |         root_url: str,
24 |         exp_dir: str,
25 |         allowlist_patterns: Sequence[str] = tuple(),
26 |         denylist_patterns: Sequence[str] = tuple(),
27 |         resume: bool=False
28 |     ):
29 | 
30 |         self.nodes = {}
31 |         self.explored_nodes = []
32 |         self.unexplored_nodes = []
33 |         self.exp_dir = os.path.join(exp_dir, "graph")
34 |         self.allowlist_patterns = allowlist_patterns
35 |         self.denylist_patterns = denylist_patterns
36 | 
37 |         if not resume:
38 |             self.root = self.add_url(root_url, None, [])
39 | 
40 |             # Save graph info
41 |             graph_info = {
42 |                 "root_url": self.root.url,
43 |                 "allowlist_patterns": self.allowlist_patterns,
44 |                 "denylist_patterns": self.denylist_patterns,
45 |             }
46 |             with open(os.path.join(self.exp_dir, "graph_info.json"), "w") as f:
47 |                 json.dump(graph_info, f, indent=4)
48 | 
49 |     def get_node(self, url: str) -> Node | None:
50 |         return self.nodes.get(url, None)
51 | 
52 |     def add_url(self, url: str, parent: Node, prefixes: list[Trace], node_misc: dict = None) -> Node:
53 | 
54 |         if url in self.nodes:
55 |             logger.warning(f"In Graph.add_url: Node {url} already exists in the graph.")
56 |             return self.nodes[url]
57 | 
58 |         node_exp_dir = os.path.join(self.exp_dir, f"node_{len(self.nodes)}")
59 |         node = Node(url, {}, {}, [], "", prefixes, False, node_exp_dir, misc=node_misc)
60 |         if parent:
61 |             parent.children.append(node.url)
62 |             parent.update_save(save_prefix=False)
63 |         self.nodes[url] = node
64 |         self.unexplored_nodes.append(node)
65 |         return node
66 | 
67 |     def add_to_explored(self, node: Node):
68 |         self.explored_nodes.append(node)
69 |         self.unexplored_nodes.remove(node)
70 |         node.visited = True
71 |         node.update_save(save_prefix=False)
72 |         logger.info(f"Node {node.url} has been explored.")
73 | 
74 |     def get_next_node(self) -> Node | None:
75 |         if len(self.unexplored_nodes) == 0:
76 |             logger.info("No nodes left to explore.")
77 |             return None
78 |         return self.unexplored_nodes[0] # TODO: Can add user-defined prioritization here.
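        # A possible shape for that TODO (sketch, not wired in): accept a
        # caller-supplied scoring function and return the best-scoring node, e.g.
        #   return max(self.unexplored_nodes, key=priority_fn)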
79 | 
80 | 
81 |     def check_if_url_allowed(self, url: str) -> bool:
82 |         for pattern in self.allowlist_patterns:
83 |             if re.match(pattern, url):
84 |                 return True
85 |         for pattern in self.denylist_patterns:
86 |             if re.match(pattern, url):
87 |                 return False
88 |         return True
89 | 
90 | 
91 |     @staticmethod
92 |     def load(path: str, load_steps: bool=True, load_prefixes: bool=True, load_images: bool=True, max_nodes=-1) -> Graph:
93 |         nodes = {}
94 |         explored_nodes = []
95 |         unexplored_nodes = []
96 | 
97 |         logger.info(f"Loading graph from {path}")
98 | 
99 |         if max_nodes == -1:
100 |             max_nodes = len(os.listdir(path)) - 1
101 |         else:
102 |             max_nodes = min(max_nodes, len(os.listdir(path)) - 1)
103 | 
104 |         for i in range(max_nodes):
105 |             logger.info(f"Loading node {i} from {path}")
106 |             node_load_dir = os.path.join(path, f"node_{i}")
107 |             node = Node.load(node_load_dir, load_steps=load_steps, load_prefix=load_prefixes, load_images=load_images)
108 |             nodes[node.url] = node
109 |             if node.visited:
110 |                 explored_nodes.append(node)
111 |             else:
112 |                 unexplored_nodes.append(node)
113 | 
114 |         graph_info = {}
115 |         with open(os.path.join(path, "graph_info.json"), "r") as f:
116 |             graph_info = json.load(f)
117 | 
118 |         graph = Graph(graph_info["root_url"], path, graph_info["allowlist_patterns"], graph_info["denylist_patterns"], resume=True)
119 |         graph.root = nodes[graph_info["root_url"]]
120 |         graph.nodes = nodes
121 |         graph.explored_nodes = explored_nodes
122 |         graph.unexplored_nodes = unexplored_nodes
123 |         graph.exp_dir = path
124 | 
125 |         logger.info(f"Loaded graph with {len(nodes)} nodes, {len(explored_nodes)} explored nodes, and {len(unexplored_nodes)} unexplored nodes.")
126 | 
127 |         return graph
128 | 
--------------------------------------------------------------------------------
/webexp/agents/base_agent.py:
--------------------------------------------------------------------------------
1 | from typing import Protocol, runtime_checkable
2 | 
3 | @runtime_checkable
4 | class Agent(Protocol):
5 |     """
6 |     Protocol for an agent.
7 |     """
8 | 
9 |     def get_action(self, obs: dict, oracle_action: tuple[str, str]=None, **kwargs) -> tuple[str, dict]:
10 |         """
11 |         Get the action for the given observation.
12 | 
13 |         Args:
14 |             obs (dict): The observation from the environment.
15 |             oracle_action (tuple[str, str], optional): Tuple of (action, reason) to use if available instead of generating a new one.
16 | 
17 |         Returns:
18 |             tuple[str, dict]: The action to take and additional agent info.
19 |         """
20 |         ...
21 | 
22 |     def reset(self):
23 |         """
24 |         Reset the agent's state.
25 |         """
26 |         ...
27 | 
28 |     def get_config(self) -> dict:
29 |         """
30 |         Get the agent's configuration.
31 | 
32 |         Returns:
33 |             dict: The agent's configuration.
34 |         """
35 |         ...
36 | 
37 |     def obs_preprocessor(self, obs: dict) -> dict:
38 |         """
39 |         Preprocess the observation before it is passed to get_action.
40 | 
41 |         Args:
42 |             obs (dict): The observation from the environment.
43 | 
44 |         Returns:
45 |             dict: The preprocessed observation.
46 |         """
47 |         ...
48 | 
49 |     def action_processor(self, action: str) -> str:
50 |         """
51 |         Process the action before it is passed to the environment.
52 | 
53 |         Args:
54 |             action (str): The action to process.
55 | 
56 |         Returns:
57 |             str: The processed action.
58 |         """
59 |         ...
60 | 
61 | @runtime_checkable
62 | class ExplorerAgent(Agent, Protocol):
63 |     """
64 |     Agent used to propose and collect exploration tasks.
65 |     """
66 | 
67 |     def get_proposed_tasks(self) -> list[str]:
68 |         ...
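    # Concrete implementations typically collect tasks as a side effect of
    # acting; e.g., NavExplorerAgent's add_tasks_to_dataset custom action
    # appends to TASK_COLLECTOR, which its get_proposed_tasks then returns.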
69 | 70 | @property 71 | def goal_str(self) -> str: 72 | """ 73 | Get the exploration goal/task string for the agent. 74 | 75 | Returns: 76 | str: The goal string. 77 | """ 78 | ... 79 | 80 | 81 | class BaseAgent: 82 | def __init__(self, **kwargs): 83 | """ 84 | Initialize the agent. 85 | 86 | Args: 87 | **kwargs: Keyword arguments for the agent; stored in the agent's config. 88 | (The config also records the agent class name under the "name" key.) 89 | """ 90 | self.config = { 91 | "name": self.__class__.__name__, 92 | **kwargs 93 | } 94 | 95 | def reset(self): 96 | """ 97 | Reset the agent's state. 98 | """ 99 | raise NotImplementedError 100 | 101 | def get_action(self, obs: dict, oracle_action: tuple[str, str] | None = None, **kwargs) -> tuple[str, dict]: 102 | """ 103 | Get the action for the given observation. 104 | 105 | Args: 106 | obs (dict): The observation from the environment. 107 | oracle_action (tuple[str, str], optional): Tuple of (action, thought) to use if available instead of generating a new one. 108 | 109 | Returns: 110 | tuple[str, dict]: The action to take and a dict of extra action info. 111 | """ 112 | raise NotImplementedError 113 | 114 | def get_config(self) -> dict: 115 | """ 116 | Get the agent's configuration. 117 | 118 | Returns: 119 | dict: The agent's configuration. 120 | """ 121 | return self.config 122 | 123 | def obs_preprocessor(self, obs: dict) -> dict: 124 | """ 125 | Preprocess the observation before it is passed to get_action. 126 | 127 | Args: 128 | obs (dict): The observation from the environment. 129 | 130 | Returns: 131 | dict: The preprocessed observation. 132 | """ 133 | return obs 134 | 135 | def action_processor(self, action: str) -> str: 136 | """ 137 | Process the action before it is passed to the environment. 138 | 139 | Args: 140 | action (str): The action to process. 141 | 142 | Returns: 143 | str: The processed action. 144 | """ 145 | return action 146 | 147 | @classmethod 148 | def create_agent(cls, *args, **kwargs): 149 | """ 150 | Create an agent instance. 151 | 152 | Args: 153 | *args: Positional arguments for the agent. 154 | **kwargs: Keyword arguments for the agent. 155 | 156 | Returns: 157 | BaseAgent: An instance of a class derived from BaseAgent. 158 | """ 159 | return cls(*args, **kwargs) 160 | 161 | 162 | AGENT_FACTORY_REGISTRY: dict[str, type[BaseAgent]] = {} 163 | 164 | class AgentFactory: 165 | 166 | @staticmethod 167 | def create_agent(name: str, *args, **kwargs): 168 | """ 169 | Create an agent instance. 170 | 171 | Args: 172 | name (str): The registered name (or alias) of the agent class to instantiate. 173 | *args, **kwargs: Arguments forwarded to the agent class's create_agent. 174 | 175 | Returns: 176 | BaseAgent: An instance of a class derived from BaseAgent. 177 | """ 178 | if name not in AGENT_FACTORY_REGISTRY: 179 | raise ValueError(f"Unknown agent: {name}") 180 | return AGENT_FACTORY_REGISTRY[name].create_agent(*args, **kwargs) 181 | 182 | 183 | @staticmethod 184 | def register(cls, aliases: str | tuple[str, ...] = tuple()): 185 | """ 186 | Register an agent class. 187 | 188 | Args: 189 | cls (type): The agent class to register. 190 | aliases (str | tuple[str, ...]): Optional alias name(s) to also register the class under. 191 | Returns: 192 | type: The agent class that was registered.
193 | """ 194 | AGENT_FACTORY_REGISTRY[cls.__name__] = cls 195 | 196 | if isinstance(aliases, str): 197 | aliases = (aliases,) 198 | 199 | for name in aliases: 200 | AGENT_FACTORY_REGISTRY[name] = cls 201 | 202 | return cls 203 | -------------------------------------------------------------------------------- /webexp/explore/core/trajectory.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import numpy as np 3 | from PIL import Image 4 | import json 5 | import os 6 | 7 | def _extract_text_obs(observation): 8 | """ 9 | Recursively extract all textual observations from a nested structure. 10 | """ 11 | if isinstance(observation, str): 12 | return observation 13 | elif isinstance(observation, dict): 14 | return {k: _extract_text_obs(v) for k, v in observation.items() if isinstance(v, (str, dict, list))} 15 | elif isinstance(observation, list): 16 | return [_extract_text_obs(item) for item in observation if isinstance(item, (str, dict, list))] 17 | return None 18 | 19 | @dataclass 20 | class TrajectoryStep: 21 | action: str | None 22 | parsed_action: str | None 23 | thought: str | None 24 | observation: dict 25 | misc: dict | None = None 26 | 27 | def __post_init__(self): 28 | self._last_saved_dir = None 29 | 30 | def save(self, save_dir: str, keep_image_in_memory: bool=False, save_image: bool=True): 31 | # Extract all textual observations, including nested collections 32 | text_obs = _extract_text_obs(self.observation) 33 | step_info = { 34 | "action": self.action, 35 | "parsed_action": self.parsed_action, 36 | "thought": self.thought, 37 | "observation": text_obs, 38 | "misc": self.misc 39 | } 40 | 41 | with open(os.path.join(save_dir, "step_info.json"), "w") as f: 42 | json.dump(step_info, f, indent=4) 43 | 44 | if 'screenshot' in self.observation: 45 | if save_image: 46 | # Save screenshot 47 | screenshot = self.observation["screenshot"] 48 | img = Image.fromarray(screenshot) 49 | img.save(os.path.join(save_dir, "screenshot.png")) 50 | 51 | if not keep_image_in_memory: 52 | # Remove the screenshot from memory to save space 53 | del self.observation["screenshot"] 54 | self.observation = {k: v for k, v in self.observation.items() if k != "screenshot"} 55 | 56 | self._last_saved_dir = save_dir 57 | 58 | @staticmethod 59 | def load(load_dir: str, load_image: bool=True): 60 | with open(os.path.join(load_dir, "step_info.json"), "r") as f: 61 | step_info = json.load(f) 62 | 63 | if load_image: 64 | screenshot = np.asarray(Image.open(os.path.join(load_dir, "screenshot.png"))) 65 | step_info["observation"]["screenshot"] = screenshot 66 | 67 | return TrajectoryStep(step_info["action"], step_info["parsed_action"], step_info["thought"], step_info["observation"], step_info["misc"]) 68 | 69 | @property 70 | def last_saved_dir(self) -> str | None: 71 | return self._last_saved_dir 72 | 73 | 74 | @dataclass 75 | class Trajectory: 76 | steps: list[TrajectoryStep] 77 | final_state: TrajectoryStep | None 78 | goal: str 79 | reward: float 80 | success: bool 81 | response: str 82 | agent_info: dict 83 | misc: dict 84 | 85 | def add_step(self, action: str, parsed_action: str | None, thought: str | None, observation: dict, misc: dict = None): 86 | self.steps.append(TrajectoryStep(action, parsed_action, thought, observation, misc)) 87 | 88 | def extract_response(self, env): 89 | chat_messages = env.chat.messages 90 | if chat_messages and chat_messages[-1]["role"] == "assistant": 91 | self.response = chat_messages[-1]["message"] 92 | elif 
chat_messages and chat_messages[-1]["role"] == "infeasible": 93 | self.response = "User goal/request is infeasible." 94 | 95 | return self.response 96 | 97 | def save(self, save_dir: str): 98 | traj_info = { 99 | "goal": self.goal, 100 | "reward": self.reward, 101 | "success": self.success, 102 | "response": self.response, 103 | "agent_info": self.agent_info, 104 | "misc": self.misc 105 | } 106 | 107 | with open(os.path.join(save_dir, "traj_info.json"), "w") as f: 108 | json.dump(traj_info, f, indent=4) 109 | 110 | for i, step in enumerate(self.steps): 111 | 112 | step_save_dir = os.path.join(save_dir, f"step_{i}") 113 | os.makedirs(step_save_dir, exist_ok=True) 114 | 115 | step.save(step_save_dir) 116 | 117 | final_state_save_dir = os.path.join(save_dir, "final_state") 118 | os.makedirs(final_state_save_dir, exist_ok=True) 119 | if self.final_state is not None: 120 | self.final_state.save(final_state_save_dir) 121 | 122 | @staticmethod 123 | def load(load_dir: str, load_steps: bool=True, load_images: bool=True): 124 | with open(os.path.join(load_dir, "traj_info.json"), "r") as f: 125 | traj_info = json.load(f) 126 | 127 | steps = [] 128 | if load_steps: 129 | i = 0 130 | while os.path.exists(os.path.join(load_dir, f"step_{i}")): 131 | step_load_dir = os.path.join(load_dir, f"step_{i}") 132 | steps.append(TrajectoryStep.load(step_load_dir, load_image=load_images)) 133 | i += 1 134 | 135 | final_state_load_dir = os.path.join(load_dir, "final_state") 136 | final_state = TrajectoryStep.load(final_state_load_dir, load_image=load_images) 137 | 138 | return Trajectory(steps, final_state, traj_info["goal"], traj_info["reward"], traj_info["success"], traj_info["response"], traj_info["agent_info"], traj_info["misc"]) 139 | 140 | def __len__(self): 141 | return len(self.steps) 142 | 143 | @staticmethod 144 | def from_goal(goal: str, agent_info: dict = None, misc: dict = None): 145 | if agent_info is None: 146 | agent_info = {} 147 | if misc is None: 148 | misc = {} 149 | return Trajectory([], None, goal, 0.0, False, "N/A", agent_info=agent_info, misc=misc) 150 | -------------------------------------------------------------------------------- /webexp/agents/page_explorer_agent.py: -------------------------------------------------------------------------------- 1 | from .base_agent import AgentFactory 2 | from .solver_agent import SolverAgent 3 | from .prompt_builders.page_explorer_prompt_builder import PageExplorerPromptBuilder 4 | from browsergym.core.action.highlevel import HighLevelActionSet 5 | from textwrap import dedent 6 | 7 | TASK_COLLECTOR = [] 8 | 9 | def add_tasks_to_dataset(*tasks: str): 10 | """Given one or more task strings, add them to the dataset we are collecting. 11 | You should add tags to the start of the task string to indicate the type of task it is. 12 | For example, you can use the following tags: 13 | [INFO] for information seeking tasks, 14 | [NAV] for navigation tasks, and 15 | [MOD] for content modification tasks, configuration changes, or anything that modifies the state of the webpage. 16 | You can add multiple tags to a single task string if it is a combination of different types of tasks. 17 | For example, you can use [INFO][NAV] for a task that requires both information seeking and navigation. 
18 | 19 | Examples: 20 | add_tasks_to_dataset('[MOD] Add the Apple iPhone 13 to the cart.', '[MOD] Leave a review for iPhone 13 saying that I loved it.') 21 | add_tasks_to_dataset('[INFO] List the best-selling product for first quarter of 2023.') 22 | add_tasks_to_dataset("[INFO] Compare the driving and walking times from University of Washington to Amazon's headquarters in Seattle.") 23 | add_tasks_to_dataset('[NAV] Navigate to the product page for the Apple iPhone 13.') 24 | """ 25 | TASK_COLLECTOR.extend(tasks) 26 | 27 | 28 | @AgentFactory.register 29 | class PageExplorerAgent(SolverAgent): 30 | """ 31 | Agent used to propose exploration tasks for a web page. 32 | """ 33 | 34 | def __init__( 35 | self, 36 | model_id: str, 37 | base_url: str | None = None, 38 | api_key: str | None = None, 39 | temperature: float = 1.0, 40 | char_limit: int = -1, 41 | demo_mode: str = 'off', 42 | ): 43 | """ 44 | Initialize the agent. 45 | """ 46 | super().__init__(model_id=model_id, base_url=base_url, api_key=api_key, temperature=temperature, char_limit=char_limit, demo_mode=demo_mode) 47 | 48 | self.action_set = HighLevelActionSet( 49 | subsets=["chat", "bid", "infeas", "nav", "tab", "custom"], 50 | custom_actions=[add_tasks_to_dataset], 51 | strict=False, 52 | multiaction=False, 53 | demo_mode=demo_mode, 54 | ) 55 | 56 | self.action_set.python_includes = "from webexp.agents.page_explorer_agent import TASK_COLLECTOR, add_tasks_to_dataset\n" + self.action_set.python_includes 57 | 58 | self.prompt_builder = PageExplorerPromptBuilder(action_set=self.action_set) 59 | 60 | def reset(self): 61 | super().reset() 62 | TASK_COLLECTOR.clear() 63 | 64 | def get_proposed_tasks(self) -> list[str]: 65 | return TASK_COLLECTOR.copy() 66 | 67 | @property 68 | def goal_str(self) -> str: 69 | return dedent("""\ 70 | I am trying to collect a dataset to train a better web browser agent that can perform actions for users in a web browser. For this, I need to first collect tasks that are feasible to perform on the current web page. 71 | The tasks should be concrete (e.g., on an amazon product page for product X, an appropriate task could be "Leave a positive review for X" or on a maps website a task could be "Show me driving directions from X to Y." where X and Y are specific locations). 72 | You may explore by performing actions on this web page if that helps to determine concrete tasks that are feasible. 73 | 74 | Find the tasks that are possible to perform on the current web page itself, without having to navigate to other links/urls. Though, you may find it helpful to navigate through menus on this page to get a better idea of what types of tasks are feasible. If you accidentally go to a new url while trying to navigate items on the page, you can go back to the previous page using the `go_back` function. 75 | 76 | Tasks are usually of three types: 77 | 1. Information seeking: The user wants to obtain certain information from the webpage, such as the information of a product, reviews, map info, comparison of map routes, etc. 78 | 2. Site navigation: The user wants to navigate to a specific page. 79 | 3. Content modification: The user wants to modify the content of a webpage or configuration. 80 | 81 | Be as specific as you can while creating tasks. The web agent may start from a different web page when asked to complete the task and so may not have the current page context to understand the task. So, for example, avoid creating generic tasks like "Add item to cart" or "Print receipt for this order."
Instead, you want to create specific tasks like "Add a Sony PS5 to cart" or "Print a receipt for Martha Jones's order of the Nike Velocity Sweatpants from May 21, 2021." 82 | 83 | I recommend the following order for collecting tasks: 84 | 1. First look for information seeking/extraction tasks that can be answered simply using information on the current page, requiring no additional actions. 85 | 2. Collect navigation tasks that require navigating to another webpage from this current page. You may click on links to try to find other interesting pages to collect tasks from. But if you do navigate to another page, instead of collecting tasks on that page, make sure to navigate back to the previous page using `go_back`. We will collect tasks from these new pages later. When collecting navigation tasks, prioritize those that would likely have interesting/useful tasks on them over ones that likely won't give many useful tasks to collect. 86 | 3. Finally, you can try to find content modification tasks on the current page that require performing actions on the current page itself. 87 | 88 | As you are exploring the page, you may find it helpful to click on buttons, links, and other elements on the page to see if they reveal any additional information or options that could lead to new tasks. You can also hover over elements to see if they provide any tooltips or additional context. 89 | 90 | **Important**: 91 | When collecting tasks, focus more on the common tasks that a typical user of this webpage would want to perform. Avoid niche tasks that are unlikely to be relevant to the typical user of this website. 92 | For most common styles of tasks, it may be useful to include a few variants or related tasks to help the web agent learn frequently used skills. 93 | 94 | As you are exploring, you can add tasks to the dataset using the `add_tasks_to_dataset` function.
95 | 96 | When you are done exploring, send a message to the user using `send_msg_to_user` confirming this.""" 97 | ) 98 | -------------------------------------------------------------------------------- /projects/go-browse/data/process_nnetnav_data.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import os 3 | import json 4 | from datasets import load_dataset 5 | from tqdm import tqdm 6 | from browsergym.core.action.highlevel import HighLevelActionSet 7 | from webexp.agents.solver_agent import SolverAgent 8 | from webexp.agents.prompt_builders.solver_prompt_builder import SolverPromptBuilder 9 | from webexp.agents.trajectory_data import BrowserGymAgentStepData, BrowserGymAgentTrajectoryData 10 | 11 | OUTPUT_DATA_FILE = "" # Replace with your desired output file path 12 | 13 | #%% 14 | def action_convert(action): 15 | if 'type' in action: 16 | action = action.replace("type", "fill") 17 | id = action.split("[")[1].split("]")[0] 18 | content = action.split("[")[2].split("]")[0] 19 | if action.count("[") > 3: 20 | press_enter_after = action.split("[")[3].split("]")[0] 21 | else: 22 | press_enter_after = '0' 23 | action = f"fill(\'{id}\', \'{content}\')" 24 | if press_enter_after == "1": 25 | action += f"\npress(\'{id}\', 'Enter')" 26 | elif 'press' in action: 27 | key_comb = action.split("[")[1].split("]")[0] 28 | action = f"press(None, \'{key_comb}\')" 29 | elif 'scroll' in action: 30 | direction = action.split("[")[1].split("]")[0] 31 | if direction == 'down': 32 | action = f"scroll(0, 500)" 33 | elif direction == 'up': 34 | action = f"scroll(0, -500)" 35 | elif 'close_tab' in action: 36 | action = action.replace("close_tab", "tab_close") 37 | action = f"tab_close()" 38 | elif 'stop' in action: 39 | action = action.replace("stop", "send_msg_to_user") 40 | text = action.split("[")[1].split("]")[0] 41 | action = f"send_msg_to_user(\'{text}\')" 42 | elif 'click' in action: 43 | id = action.split("[")[1].split("]")[0] 44 | action = f"click(\'{id}\')" 45 | elif 'hover' in action: 46 | id = action.split("[")[1].split("]")[0] 47 | action = f"hover(\'{id}\')" 48 | elif 'new_tab' in action: 49 | action = f"new_tab()" 50 | elif 'tab_focus' in action: 51 | id = action.split("[")[1].split("]")[0] 52 | action = f"tab_focus(\'{id}\')" 53 | elif 'goto' in action: 54 | url = action.split("[")[1].split("]")[0] 55 | action = f"goto(\'{url}\')" 56 | elif 'go_back' in action: 57 | action = f"go_back()" 58 | elif 'go_forward' in action: 59 | action = f"go_forward()" 60 | return action 61 | 62 | #%% 63 | def parse_example(example): 64 | parts = example["output"].split("In summary, ", 1) 65 | think = parts[0].strip() 66 | original_output = example["output"] 67 | 68 | # Remove assistant header from think section 69 | think = think.replace("<|start_header_id|>assistant<|end_header_id|>", "")\ 70 | .replace("Let's think step-by-step.", "")\ 71 | .replace("Let's think step by step.", "")\ 72 | .strip() 73 | 74 | if len(parts) > 1: 75 | action = parts[1].split("<|eot_id|>")[0].strip() 76 | if "``````" in action: 77 | action_parts = action.split("```") 78 | if len(action_parts) > 2: 79 | action = action_parts[2].split("``````")[0] 80 | action = action_convert(action) 81 | else: 82 | action = "" 83 | elif "```" in action: 84 | # Split by triple backticks and process each action 85 | action_parts = action.split("```") 86 | # Extract only the action part after "In summary, my next action should be" 87 | if len(action_parts) > 1: 88 | action = action_parts[1].strip() 89 | 
action = action_convert(action) 90 | else: 91 | action = "" 92 | else: 93 | action = "" 94 | 95 | prompt = example['prompt'] 96 | 97 | # Extract content between OBJECTIVE and PREVIOUS ACTIONS 98 | objective_start = prompt.find("OBJECTIVE:") 99 | previous_actions_start = prompt.find("PREVIOUS ACTIONS:") 100 | 101 | if objective_start != -1 and previous_actions_start != -1: 102 | # Start after "OBJECTIVE:" and remove any leading/trailing whitespace 103 | objective_content = prompt[objective_start + len("OBJECTIVE:"):previous_actions_start].strip() 104 | else: 105 | objective_content = "" 106 | 107 | axtree = prompt.split("OBSERVATION:")[1].split("\nURL: ")[0].strip() 108 | 109 | step_data = BrowserGymAgentStepData( 110 | action=action, 111 | thought=think, 112 | axtree=axtree, 113 | last_action_error=None, 114 | misc={'goal': objective_content}, 115 | ) 116 | 117 | return step_data 118 | 119 | 120 | #%% 121 | dataset = load_dataset("stanfordnlp/nnetnav-wa", split='train') 122 | 123 | # %% 124 | unique_task_names = sorted(set(dataset['task_name'])) 125 | 126 | 127 | #%% 128 | action_set = HighLevelActionSet( 129 | subsets=["chat", "bid", "infeas", "nav"], 130 | strict=False, 131 | multiaction=False, 132 | demo_mode='off' 133 | ) 134 | 135 | prompt_builder = SolverPromptBuilder(action_set) 136 | 137 | # %% Process each task group separately 138 | 139 | total_trajs = len(unique_task_names) 140 | total_steps = 0 141 | skipped_steps = 0 142 | 143 | 144 | processed_data = [] 145 | 146 | with open(OUTPUT_DATA_FILE, "w") as out_f: 147 | 148 | curr_task_name = dataset[0]['task_name'] 149 | steps = [] 150 | 151 | for t, example in tqdm(enumerate(dataset), desc="Processing examples", position=0, leave=False): 152 | if example['task_name'] != curr_task_name: 153 | curr_task_name = example['task_name'] 154 | 155 | traj = BrowserGymAgentTrajectoryData(steps, steps[0].misc['goal'], 1, {}).process_for_dataset() 156 | 157 | step_data_points = prompt_builder.build_trajectory_messages(traj, char_limit=80000) 158 | for i, d in enumerate(step_data_points): 159 | d_wrap = {'step_idx': total_steps, 'step_data': d, 'traj_reward': traj.reward, 'next_step_idx': (total_steps + 1) if i != (len(step_data_points) - 1) else -1, 'traj_length': len(step_data_points), 'step_number': i} 160 | out_f.write(json.dumps(d_wrap) + "\n") 161 | total_steps += 1 162 | 163 | steps = [] 164 | 165 | try: 166 | steps.append(parse_example(example)) 167 | except Exception as e: 168 | print(f"Error processing example {t}: {e}") 169 | skipped_steps += 1 170 | continue 171 | 172 | # Flush the final task group; the loop above only writes a group out when the task name changes. 173 | if steps: 174 | traj = BrowserGymAgentTrajectoryData(steps, steps[0].misc['goal'], 1, {}).process_for_dataset() 175 | step_data_points = prompt_builder.build_trajectory_messages(traj, char_limit=80000) 176 | for i, d in enumerate(step_data_points): 177 | d_wrap = {'step_idx': total_steps, 'step_data': d, 'traj_reward': traj.reward, 'next_step_idx': (total_steps + 1) if i != (len(step_data_points) - 1) else -1, 'traj_length': len(step_data_points), 'step_number': i} 178 | out_f.write(json.dumps(d_wrap) + "\n") 179 | total_steps += 1 180 | 181 | print(f"Total steps: {total_steps}, Skipped steps: {skipped_steps}") 182 | print(f"Total trajectories: {total_trajs}") 183 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | # Go-Browse: Training Web Agents with Structured Exploration 3 | 4 | [arXiv](https://arxiv.org/abs/2506.03533) | PRs Welcome 10 |
11 | 12 | ## Table of Contents 13 | - [Overview](#overview) 14 | - [Setup](#setup) 15 | - [Collect Dataset](#collect-dataset) 16 | - [Process Collected Go-Browse Dataset for Training](#process-collected-go-browse-dataset-for-training) 17 | - [Process NNetNav Dataset for Training](#process-nnetnav-dataset-for-training) 18 | - [Finetune a Model](#finetune-a-model) 19 | - [Benchmark a Model on WebArena](#benchmark-a-model-on-webarena) 20 | - [Run an Episode on a Website](#run-an-episode-on-a-website) 21 | - [Go-Browse-WA Dataset and Trained Models Release](#go-browse-wa-dataset-and-trained-models-release) 22 | - [Citation](#citation) 23 | 24 | 25 | ## Overview 26 | 27 | Go-Browse is a method for automatic, unsupervised collection of high-quality and diverse web agent training data via structured exploration of websites. 28 | 29 | Go-Browse has an outer loop that iteratively builds up a graph of previously visited webpages on a website (incentivizing global website coverage) and an inner loop that thoroughly explores each discovered webpage by: (1) Proposing tasks to solve on that page and tasks to discover neighboring pages; (2) Filtering these tasks to feasible ones by trying to solve them and judging successes with a strong computer-use LM + a VLM-as-a-judge and (3) Sampling additional task-solving trajectories with various other pretrained LMs. 30 | 31 | ![image](figures/go-browse-main-figure-colored.png) 32 | 33 | By resetting the inner loop to previously discovered webpages, the outer loop helps Go-Browse reuse information across the multiple inner loop invocations, enabling more efficient and deeper exploration of websites. 34 | 35 | We release [Go-Browse-WA](#go-browse-wa-dataset-and-trained-models-release), a dataset collected by running Go-Browse on 100 webpages from WebArena websites, collecting ~10K successful task-solving trajectories and ~17K unsuccessful ones. 36 | 37 | Finetuning Qwen-2.5-7B-Instruct on Go-Browse-WA achieves state-of-the-art performance for sub-10B parameter models on the WebArena benchmark with a overall success rate of 21.7%, beating the previous best finetuned sub-10B model by 2.9 percentage points and beating GPT-4o-mini by 2.4 percentage points. 38 | 39 | ## Setup 40 | 41 | Note, we ran our experiments with Python 3.12, though earlier python versions may also work. 42 | 43 | 1. Follow the instructions here to install browsergym with webarena and playwright with chromium: https://github.com/ServiceNow/BrowserGym 44 | 2. Install `webexp` and dependencies: 45 | ```sh 46 | pip install -r requirements.txt 47 | pip install -e . 48 | ``` 49 | 3. Setup a WebArena Server using the instructions here: [webarena readme](https://github.com/web-arena-x/webarena/blob/main/environment_docker/README.md). You can also optionally setup the a reset server to remotely reset the webarena environments by: 50 | - Copy/clone over the webarena-reset folder to your webarena hosting instance 51 | - `pip install fastapi[standard]` on this instance. 52 | - `cd webarena-reset` 53 | - `export BASE_URL=` 54 | - `fastapi run reset_server.py` 55 | - You can now reset a specific domain at once (e.g. map with `/reset/map`) or all domains at once with (e.g., `/reset/all`). 
56 | 57 | ## Collect Dataset 58 | The example config file used for Go-Browse-WA data generation is `configs/go_browse_config.yaml`. 59 | 60 | For each domain (website) that you want to run data generation for, duplicate/modify the config file by filling in placeholders and then run: 61 | ```sh 62 | python -m webexp.explore.algorithms.web_explore -c <path/to/your_config.yaml> 63 | ``` 64 | 65 | ### Process Collected Go-Browse Dataset for Training 66 | First, set the input and output paths as appropriate in `projects/go-browse/data/generate_dataset.py` and `projects/go-browse/data/process_dataset.py` 67 | 68 | Then: 69 | ```sh 70 | python projects/go-browse/data/generate_dataset.py 71 | python projects/go-browse/data/process_dataset.py 72 | ``` 73 | 74 | ### Process NNetNav Dataset for Training 75 | First, set the output path as appropriate in `projects/go-browse/data/process_nnetnav_data.py` 76 | 77 | Then: 78 | ```sh 79 | python projects/go-browse/data/process_nnetnav_data.py 80 | ``` 81 | 82 | ## Finetune a Model 83 | First, replace the placeholder paths/env vars as appropriate in `webexp/train/sft_policy.py` 84 | 85 | Then: 86 | ```sh 87 | python webexp/train/sft_policy.py 88 | ``` 89 | 90 | ## Benchmark a Model on WebArena 91 | If benchmarking a finetuned model, first serve the model using an inference server like [vllm](https://docs.vllm.ai/en/latest/) or [sglang](https://docs.sglang.ai/). We used `vllm` in our experiments. 92 | 93 | Duplicate/edit the following config file by filling in the placeholders: `configs/benchmark_webarena.yaml`. 94 | 95 | Then: 96 | ```sh 97 | python -m webexp.benchmark.run_webarena -c configs/benchmark_webarena.yaml 98 | ``` 99 | 100 | ## Run an Episode on a Website 101 | If performing inference with a finetuned model, first serve the model using an inference server like [vllm](https://docs.vllm.ai/en/latest/) or [sglang](https://docs.sglang.ai/). 102 | 103 | Duplicate/edit the following config file by filling in the placeholders: `configs/agent_run_episode.yaml`. 104 | 105 | Then: 106 | ```sh 107 | python -m webexp.agents.run_episode -c configs/agent_run_episode.yaml 108 | ``` 109 | 110 | ## Go-Browse-WA Dataset and Trained Models Release 111 | 112 | Datasets (on HF Hub): 113 | - Processed dataset (output of `projects/go-browse/data/process_dataset.py`): [apurvaga/go-browse-wa](https://huggingface.co/datasets/apurvaga/go-browse-wa). 114 | 115 | This includes both successful and unsuccessful trajectories processed for finetuning. Page observations are represented as accessibility trees (potentially truncated to fit context-length limits during training). A minimal loading example is shown below. 116 | 117 | - Raw dataset: [apurvaga/go-browse-wa-raw](https://huggingface.co/datasets/apurvaga/go-browse-wa-raw) 118 | 119 | The raw version includes screenshots, `pruned_html`, the full accessibility-tree text, and additional metadata.
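As a quick sanity check, the processed dataset can be pulled straight from the Hub with the `datasets` library (the `train` split name and the printed fields are assumptions; see the dataset card for the actual schema):

```python
from datasets import load_dataset

# Processed Go-Browse-WA trajectories with accessibility-tree observations
ds = load_dataset("apurvaga/go-browse-wa", split="train")

print(len(ds), "rows")
print(ds[0].keys())  # inspect the fields of one training example
```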
120 | 121 | Finetuned models (on HF Hub): 122 | - [apurvaga/go-browse-wa-qwen-7B](https://huggingface.co/apurvaga/go-browse-wa-qwen-7B) 123 | - [apurvaga/nnetnav-wa-qwen-7B](https://huggingface.co/apurvaga/nnetnav-wa-qwen-7B) 124 | 125 | ## Citation 126 | ```bibtex 127 | @misc{gandhi2025gobrowse, 128 | title={Go-Browse: Training Web Agents with Structured Exploration}, 129 | author={Apurva Gandhi and Graham Neubig}, 130 | year={2025}, 131 | eprint={2506.03533}, 132 | archivePrefix={arXiv}, 133 | primaryClass={cs.CL}, 134 | url={https://arxiv.org/abs/2506.03533}, 135 | } 136 | ``` -------------------------------------------------------------------------------- /webexp/explore/core/evaluator.py: -------------------------------------------------------------------------------- 1 | from .trajectory import Trajectory 2 | from PIL import Image 3 | from typing import Union, Optional 4 | from openai import OpenAI 5 | from openai.types.chat import ChatCompletion 6 | from textwrap import dedent 7 | import base64 8 | import io 9 | import logging 10 | import numpy as np 11 | import os 12 | 13 | logger = logging.getLogger(__name__) 14 | logger.setLevel(logging.INFO) 15 | 16 | if not logger.handlers: 17 | handler = logging.StreamHandler() 18 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 19 | handler.setFormatter(formatter) 20 | logger.addHandler(handler) 21 | 22 | client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_BASE_URL", None)) 23 | 24 | def extract_content(text, start_tag): 25 | """ 26 | Extract the content on the first line that starts with the given tag (e.g. 'Status:'). 27 | :param text: A multi-line string to search. 28 | :param start_tag: The tag expected at the start of a line. 29 | :return: The content that follows the tag, or an empty string if no line matches. 30 | """ 31 | # Split the text into lines 32 | lines = text.split("\n") 33 | 34 | # Loop through each line to find one that starts with the tag 35 | for line in lines: 36 | if line.startswith(start_tag): 37 | # Extract and return the content after the tag 38 | return line[len(start_tag) :].strip() 39 | 40 | # Return an empty string if the tag is not found on any line 41 | return "" 42 | 43 | def image_to_jpg_base64_url(image: np.ndarray | Image.Image): 44 | """Convert an image (numpy array or PIL Image) to a base64-encoded JPEG string.""" 45 | 46 | if isinstance(image, np.ndarray): 47 | image = Image.fromarray(image) 48 | if image.mode in ("RGBA", "LA"): 49 | image = image.convert("RGB") 50 | 51 | with io.BytesIO() as buffer: 52 | image.save(buffer, format="JPEG") 53 | image_base64 = base64.b64encode(buffer.getvalue()).decode() 54 | 55 | return image_base64 56 | 57 | 58 | def build_vision_eval_prompt( 59 | intent, response, last_actions, axtree_txt 60 | ) -> tuple[str, str]: 61 | system_msg = dedent("""\ 62 | You are an expert in evaluating the performance of a web navigation agent. The agent is designed to help a human user navigate a website to complete a task. Given the user's intent, the agent's action history, the final state of the webpage, and the agent's response to the user, your goal is to decide whether the agent's execution is successful or not. 63 | 64 | There are three types of tasks: 65 | 1. Information seeking: The user wants to obtain certain information from the webpage, such as the information of a product, reviews, map info, comparison of map routes, etc. The bot's response must contain the information the user wants, or explicitly state that the information is not available. Otherwise, e.g.,
if the bot encounters an exception and responds with the error content, the task is considered a failure. Besides, be careful about the sufficiency of the agent's actions. For example, when asked to list the top-searched items in a shop, the agent should order the items by the number of searches, and then return the top items. If the ordering action is missing, the task is likely to fail. 66 | 2. Site navigation: The user wants to navigate to a specific page. Carefully examine the bot's action history and the final state of the webpage to determine whether the bot successfully completes the task. No need to consider the bot's response. 67 | 3. Content modification: The user wants to modify the content of a webpage or configuration. Carefully examine the bot's action history and the final state of the webpage to determine whether the bot successfully completes the task. No need to consider the bot's response. 68 | 69 | *IMPORTANT* 70 | Format your response into two lines as shown below: 71 | 72 | Thoughts: <your thoughts> 73 | Status: "success" or "failure" 74 | """ 75 | ) 76 | prompt = ( 77 | f"User Intent: {intent}\n\n" 78 | "Action History:\n" 79 | f"{last_actions}\n\nAgent's Response to the User: {response}\n\n" 80 | "The final state of the webpage provided as an accessibility tree:\n" 81 | f"{axtree_txt}\n\n" 82 | "The last snapshot of the web page is shown in the image." 83 | ) 84 | 85 | return prompt, system_msg 86 | 87 | class GPT4V_Client: 88 | def __init__(self, model_name: str = "gpt-4o", max_tokens: int = 512): 89 | self.model_name = model_name 90 | self.max_tokens = max_tokens 91 | 92 | def encode_image(self, path: str | np.ndarray): 93 | if isinstance(path, np.ndarray): 94 | return image_to_jpg_base64_url(path) 95 | 96 | with open(path, 'rb') as f: 97 | return base64.b64encode(f.read()).decode('utf-8') 98 | 99 | def one_step_chat( 100 | self, text, image: Union[Image.Image, np.ndarray], 101 | system_msg: Optional[str] = None, 102 | ) -> tuple[str, ChatCompletion]: 103 | jpg_base64_str = self.encode_image(image) 104 | messages = [] 105 | if system_msg is not None: 106 | messages.append({"role": "system", "content": system_msg}) 107 | messages += [{ 108 | "role": "user", 109 | "content": [ 110 | {"type": "text", "text": text}, 111 | {"type": "image_url", 112 | "image_url": {"url": f"data:image/jpeg;base64,{jpg_base64_str}"},}, 113 | ], 114 | }] 115 | response = client.chat.completions.create( 116 | model=self.model_name, 117 | messages=messages, 118 | max_tokens=self.max_tokens, 119 | ) 120 | return response.choices[0].message.content, response 121 | 122 | 123 | class Evaluator: 124 | def __init__(self, model_name: str): 125 | self.model_name = model_name 126 | self.client = GPT4V_Client(model_name) #LM_Client(model_name) 127 | 128 | def evaluate(self, trajectory: Trajectory): 129 | action_history = "" 130 | for idx, step in enumerate(trajectory.steps): 131 | action_history += f"{idx+1}: {step.action}\n" 132 | 133 | response = trajectory.response if trajectory.response else "None" 134 | 135 | prompt, sys_msg = build_vision_eval_prompt( 136 | trajectory.goal, response, action_history, trajectory.steps[-1].observation["axtree_txt"] 137 | ) 138 | img = trajectory.steps[-1].observation["screenshot"] 139 | msg_str, llm_response_obj = self.client.one_step_chat(text=prompt, image=img, system_msg=sys_msg) 140 | 141 | msg_dict = { 142 | "thoughts": extract_content(msg_str, "Thoughts:"), 143 | "status": extract_content(msg_str, "Status:").replace('"', ""), 144 | } 145 | 146 | logger.info(f"Evaluating trajectory with goal: {trajectory.goal}") 147 |
logger.info(f"Model Response: {msg_str}") 148 | 149 | trajectory.success = msg_dict["status"].lower() == "success" 150 | trajectory.reward = 1.0 if trajectory.success else 0.0 151 | 152 | evaluation_info = { 153 | "output": msg_dict, 154 | "reward": trajectory.reward, 155 | "model_usage": llm_response_obj.usage.to_dict() 156 | } 157 | 158 | if trajectory.misc is None: 159 | trajectory.misc = {} 160 | 161 | trajectory.misc["evaluation_info"] = evaluation_info 162 | -------------------------------------------------------------------------------- /projects/go-browse/data/generate_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import re 4 | from glob import glob 5 | from tqdm import tqdm 6 | from webexp.explore.core.trajectory import Trajectory 7 | from webexp.explore.core.node import Node 8 | from webexp.explore.core.graph import Graph 9 | import networkx as nx 10 | import random 11 | 12 | OUTPUT_FILE_NAME = "" # Replace with your desired output file path 13 | 14 | 15 | DIRS_TO_LOAD = [ 16 | "", # Replace with your desired graph directory path 17 | "", # Replace with your desired graph directory path 18 | ... 19 | ] 20 | 21 | MAX_NODES_PER_DIR = 20 22 | 23 | TRAJ_PATTERNS = [ 24 | 'tasks/*/*_trajs/*/', 25 | ] 26 | 27 | def extract_domain(path): 28 | """ 29 | Extract domain name from a run directory path. 30 | 31 | Example path: "data/runs/4432878-map/graph" 32 | Returns: "map" 33 | """ 34 | # Use regex to extract the domain name between the last hyphen and /graph 35 | match = re.search(r'/runs/\d+-([^/]+)/graph', path) 36 | if match: 37 | return match.group(1) 38 | 39 | # Alternative approach using string manipulation if regex fails 40 | try: 41 | # Split the path and locate the directory containing 'graph' 42 | parts = path.split('/') 43 | for i, part in enumerate(parts): 44 | if part == 'graph' and i > 0: 45 | # Get the parent directory name and extract domain after hyphen 46 | parent_dir = parts[i-1] 47 | if '-' in parent_dir: 48 | return parent_dir.split('-', 1)[1] 49 | except: 50 | pass 51 | 52 | # Return None or the original path if parsing fails 53 | return None 54 | 55 | 56 | def build_nx_graph(g: Graph): 57 | 58 | netG = nx.DiGraph() 59 | 60 | node_urls = list(g.nodes.keys()) 61 | for i, node in enumerate(g.nodes.values()): 62 | for prefix in node.prefixes: 63 | prefix_misc = prefix.misc 64 | if prefix_misc and prefix.start_url in node_urls: 65 | task_misc = prefix_misc.get("task_misc", None) 66 | sourced_from_nav_explore_task = task_misc is not None and 'agent_info' in task_misc and task_misc['agent_info']['name'] == 'NavExplorerAgent' 67 | 68 | tags = prefix_misc.get("tags", None) 69 | task_tagged_as_nav = "NAV" in tags if tags else False 70 | 71 | if task_tagged_as_nav or sourced_from_nav_explore_task: 72 | if netG.has_edge(prefix.start_url, prefix.end_url): 73 | old_weight = netG[prefix.start_url][prefix.end_url]['weight'] 74 | if old_weight > len(prefix): 75 | netG[prefix.start_url][prefix.end_url]['weight'] = len(prefix) 76 | netG[prefix.start_url][prefix.end_url]['trace'] = prefix 77 | else: 78 | netG.add_edge(prefix.start_url, prefix.end_url, weight=len(prefix), trace=prefix) 79 | 80 | return netG 81 | 82 | 83 | def get_prefixes(g: Graph, netG: nx.DiGraph): 84 | shortest_paths = dict(nx.single_source_all_shortest_paths(netG, source=g.root.url, weight='weight')) 85 | 86 | prefixes = {} 87 | 88 | for k in shortest_paths.keys(): 89 | paths = shortest_paths[k] 90 | 91 | prefixes[k] = [] 92 | 93 | for 
path in paths: 94 | 95 | curr_prefix = [] 96 | 97 | for i in range(len(path) - 1): 98 | start_url = path[i] 99 | end_url = path[i + 1] 100 | 101 | curr_prefix.extend(netG[start_url][end_url]['trace'].steps) 102 | 103 | for step in curr_prefix: 104 | step.misc['is_prefix_step'] = True 105 | 106 | prefixes[k].append(curr_prefix) 107 | 108 | return prefixes 109 | 110 | 111 | if __name__ == "__main__": 112 | 113 | total_trajs = 0 114 | total_steps = 0 115 | 116 | 117 | with open(OUTPUT_FILE_NAME, "w") as f: 118 | for directory in tqdm(DIRS_TO_LOAD, desc="Domains", position=0, leave=False): 119 | 120 | g = Graph.load(directory, load_steps=True, load_prefixes=True, load_images=False, max_nodes=MAX_NODES_PER_DIR) 121 | 122 | netG = build_nx_graph(g) 123 | 124 | prefixes = get_prefixes(g, netG) 125 | 126 | for _, node in tqdm(enumerate(g.nodes.values()), desc="Nodes", position=1, leave=False): 127 | for task in tqdm(node.tasks.values(), desc="Tasks", position=2, leave=False): 128 | for traj in task.positive_trajs + task.negative_trajs: 129 | sampled_prefix = random.choice(prefixes[node.url]) if node.url in prefixes else [] 130 | traj_steps = traj.steps 131 | 132 | is_traj_prefixed = False 133 | 134 | if (not ("needs_prefix" in traj.misc and not traj.misc["needs_prefix"])) and traj.reward > 0 and sampled_prefix: 135 | traj_steps = sampled_prefix + traj_steps 136 | is_traj_prefixed = True 137 | 138 | for i, step in enumerate(traj_steps): 139 | 140 | processed_obs = {k: v for k, v in step.observation.items() if k != "extra_element_properties"} 141 | 142 | step_data = { 143 | "action": step.action, 144 | "parsed_action": step.parsed_action, 145 | "thought": step.thought, 146 | "obs": processed_obs, 147 | "misc": step.misc, 148 | "step_number": i, 149 | } 150 | 151 | traj_data = { 152 | "goal": traj.goal, 153 | "reward": traj.reward, 154 | "success": traj.success, 155 | "response": traj.response, 156 | "misc": traj.misc, 157 | "traj_num": total_trajs, 158 | "traj_length": len(traj_steps), 159 | "is_traj_prefixed": is_traj_prefixed, 160 | } 161 | 162 | node_data = { 163 | "node_url": node.url 164 | } 165 | 166 | graph_data = { 167 | "domain": extract_domain(directory), 168 | "root_url": g.root.url, 169 | } 170 | 171 | row_data = { 172 | "step_data": step_data, 173 | "traj_data": traj_data, 174 | "node_data": node_data, 175 | "graph_data": graph_data, 176 | } 177 | 178 | f.write(json.dumps(row_data) + "\n") 179 | 180 | total_steps += 1 181 | total_trajs += 1 182 | 183 | print("Total Trajs: ", total_trajs) 184 | print("Total Steps: ", total_steps) -------------------------------------------------------------------------------- /webexp/explore/core/episode.py: -------------------------------------------------------------------------------- 1 | from .agent import AgentWithExplorationCallbacks, ExplorerAgentWithExplorationCallbacks, wrap_agent_for_callback_protocol 2 | from .evaluator import Evaluator 3 | from .graph import Graph 4 | from .node import Node 5 | from .task import Task 6 | from .trajectory import Trajectory, TrajectoryStep 7 | from ...agents.base_agent import Agent, ExplorerAgent 8 | from browsergym.core.env import BrowserEnv 9 | from browsergym.experiments.loop import _send_chat_info 10 | from tenacity import retry, stop_after_attempt, wait_exponential 11 | import logging 12 | 13 | logger = logging.getLogger(__name__) 14 | logger.setLevel(logging.INFO) 15 | 16 | if not logger.handlers: 17 | handler = logging.StreamHandler() 18 | formatter = logging.Formatter('%(asctime)s - %(name)s - 
%(levelname)s - %(message)s') 19 | handler.setFormatter(formatter) 20 | logger.addHandler(handler) 21 | 22 | @retry( 23 | stop=stop_after_attempt(3), 24 | wait=wait_exponential(multiplier=1, min=4, max=20), 25 | ) 26 | def reset_env_to_node( 27 | env: BrowserEnv, 28 | node: Node, 29 | agent: Agent, 30 | goal: str, 31 | ): 32 | """ 33 | Reset the environment to a given node. 34 | 35 | Args: 36 | env: The environment to reset. 37 | node (Node): The node to reset to. 38 | agent (BaseAgent): The agent to reset. 39 | goal (str): The goal for the episode. 40 | """ 41 | 42 | logger.info(f"Resetting environment to node: {node.url}...") 43 | 44 | env.reset() 45 | nav_action = f"goto('{node.url}')" 46 | env.action_mapping = agent.action_processor 47 | env.goal_object = [{"type": "text", "text": goal}] 48 | env.chat.add_message(role="user", msg=goal) 49 | obs, _, _, _, _ = env.step(nav_action) 50 | return agent.obs_preprocessor(obs) 51 | 52 | 53 | def get_fresh_obs(env: BrowserEnv): 54 | """ 55 | Get a fresh observation from the environment. 56 | 57 | Args: 58 | env: The environment to get the observation from. 59 | 60 | Returns: 61 | dict: The observation from the environment. 62 | """ 63 | # TODO: We can make an ExplorationBrowserEnv that has a more reliable api. 64 | obs = env._get_obs() 65 | return obs 66 | 67 | # TODO: Can move this into a ExplorationBrowserEnv class as part of the is_done logic. 68 | def has_new_assistant_message(env: BrowserEnv): 69 | """ 70 | Check if there is a new assistant message in the environment. 71 | 72 | Args: 73 | env: The environment to check. 74 | 75 | Returns: 76 | bool: True if there is a new assistant message, False otherwise. 77 | """ 78 | chat_messages = env.chat.messages 79 | if chat_messages[-1]["role"] == "assistant" or chat_messages[-1]["role"] == "infeasible": 80 | return True 81 | return False 82 | 83 | def get_action( 84 | env: BrowserEnv, 85 | agent: Agent, 86 | obs: dict, 87 | traj: Trajectory, 88 | oracle_action: str = None 89 | ) -> tuple[str, dict]: 90 | """ 91 | Get the action from the agent. 92 | 93 | Args: 94 | env: The environment to get the action from. 95 | agent (BaseAgent): The agent to get the action from. 96 | obs (dict): The observation from the environment. 97 | traj (Trajectory): The trajectory of the episode. 98 | oracle_action (str, optional): The oracle action to use if available. 99 | 100 | Returns: 101 | tuple: The action and action extras dict from the agent. 102 | """ 103 | action, action_extras = agent.get_action(obs, oracle_action=oracle_action) 104 | thought = action_extras.get("thought", None) 105 | parsed_action = action_extras.get("parsed_action", None) 106 | 107 | if thought and "think" not in action_extras: 108 | action_extras["think"] = thought 109 | 110 | logger.info(f"Agent chose action: \n{action}") 111 | 112 | traj.add_step(action, parsed_action, thought, obs, {'model_usage': action_extras.get("model_usage", None), 'agent_config': agent.get_config()}) 113 | 114 | # TODO: Need a more stable api for modifying the chat pane. Perhaps we can create an env wrapper that exposes such as an api. 115 | _send_chat_info(env.chat, action, action_extras) 116 | 117 | return action, action_extras 118 | 119 | 120 | def perform_env_step( 121 | env: BrowserEnv, 122 | agent: Agent, 123 | action: str, 124 | ) -> tuple: 125 | """ 126 | Perform a step in the environment. 127 | 128 | Args: 129 | env: The environment to perform the step in. 130 | agent (BaseAgent): The agent to perform the step with. 
131 | action (str): The action to perform. 132 | 133 | 134 | 135 | Returns: 136 | tuple: The observation, reward, terminated, truncated, and env_info from the environment. 137 | """ 138 | obs, reward, terminated, truncated, env_info = env.step(action) 139 | obs = agent.obs_preprocessor(obs) 140 | return obs, reward, terminated, truncated, env_info 141 | 142 | 143 | 144 | def run_episode( 145 | goal: str, 146 | node: Node, 147 | env: BrowserEnv, 148 | agent: AgentWithExplorationCallbacks, 149 | evaluator: Evaluator, 150 | graph: Graph, 151 | max_steps: int, 152 | callback_context: dict | None = None, 153 | ) -> Trajectory: 154 | """ 155 | Run an episode with an agent in the environment. 156 | 157 | Args: 158 | goal (str): The goal for the episode. 159 | node (Node): The current node in the graph. 160 | env: The environment. 161 | agent (AgentWithExplorationCallbacks): The agent to run the episode with. 162 | evaluator (Evaluator): The evaluator to evaluate the episode. 163 | graph (Graph): The graph of nodes. 164 | max_steps (int): The maximum number of steps in the episode. 165 | callback_context (dict, optional): Seed context merged into the callback context at each step. 166 | Returns: 167 | 168 | Trajectory: The trajectory of the episode. 169 | """ 170 | 171 | logger.info(f"Running episode for goal: {goal}, for node {node.url}...") 172 | 173 | obs = reset_env_to_node( 174 | env=env, 175 | node=node, 176 | agent=agent, 177 | goal=goal, 178 | ) 179 | 180 | agent.reset() 181 | 182 | traj = Trajectory.from_goal(goal, agent.get_config()) 183 | 184 | num_steps = 0 185 | done = False 186 | 187 | callback_context_seed = callback_context if callback_context else {} 188 | 189 | while not done and num_steps < max_steps: 190 | 191 | logger.info(f"Step {num_steps} for goal {goal}.") 192 | num_steps += 1 193 | 194 | callback_context = {**callback_context_seed} 195 | 196 | num_steps, obs, reward, terminated, truncated, env_info, goal, callback_context = agent.run_pre_step_callbacks( 197 | num_steps, goal, env, graph, node, traj, obs, 0.0, False, False, {}, callback_context 198 | ) 199 | 200 | action, action_extras = get_action( 201 | env=env, 202 | agent=agent, 203 | obs=obs, 204 | traj=traj 205 | ) 206 | 207 | if action is None: 208 | logger.info("Agent returned None action. Ending episode.") 209 | break 210 | 211 | obs, reward, terminated, truncated, env_info = perform_env_step( 212 | env=env, 213 | agent=agent, 214 | action=action, 215 | ) 216 | 217 | if has_new_assistant_message(env): 218 | logger.info("New assistant message received.") 219 | terminated = True 220 | 221 | # We only need to evaluate when we are not exploring.
222 | if not isinstance(agent, ExplorerAgent): 223 | logger.info("Evaluating episode...") 224 | evaluator.evaluate(traj) 225 | logger.info(f"Episode evaluated and received reward {traj.reward}.") 226 | 227 | num_steps, obs, reward, terminated, truncated, env_info, goal, callback_context = agent.run_post_step_callbacks( 228 | num_steps, goal, env, graph, node, traj, obs, reward, terminated, truncated, env_info, callback_context 229 | ) 230 | 231 | done = terminated or truncated 232 | 233 | traj.extract_response(env) 234 | traj.final_state = TrajectoryStep( 235 | action=None, 236 | parsed_action=None, 237 | thought=None, 238 | observation=obs, 239 | misc=None, 240 | ) 241 | 242 | return traj 243 | -------------------------------------------------------------------------------- /webexp/explore/core/node.py: -------------------------------------------------------------------------------- 1 | from .task import Task 2 | from .trace import Trace 3 | from .trajectory import TrajectoryStep, Trajectory 4 | from dataclasses import dataclass 5 | import json 6 | import logging 7 | import os 8 | import shutil 9 | import glob 10 | 11 | logger = logging.getLogger(__name__) 12 | logger.setLevel(logging.INFO) 13 | 14 | if not logger.handlers: 15 | handler = logging.StreamHandler() 16 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 17 | handler.setFormatter(formatter) 18 | logger.addHandler(handler) 19 | 20 | def delete_folder_contents(folder: str): 21 | for filename in os.listdir(folder): 22 | file_path = os.path.join(folder, filename) 23 | try: 24 | if os.path.isfile(file_path) or os.path.islink(file_path): 25 | os.unlink(file_path) 26 | elif os.path.isdir(file_path): 27 | shutil.rmtree(file_path) 28 | except Exception as e: 29 | print(f'Failed to delete {file_path}. 
Reason: {e}') 30 | 31 | 32 | @dataclass 33 | class Node: 34 | url: str 35 | tasks: dict[str, Task] 36 | exploration_tasks: dict[str, Task] 37 | children: list[str] 38 | description: str 39 | prefixes: list[Trace] 40 | visited: bool 41 | exp_dir: str 42 | misc: dict = None 43 | 44 | 45 | def __post_init__(self): 46 | if not os.path.exists(self.exp_dir): 47 | os.makedirs(self.exp_dir) 48 | node_info = { 49 | "url": self.url, 50 | "description": self.description, 51 | "children": self.children, 52 | #"prefix_source": self.prefix_source, 53 | "visited": self.visited, 54 | "misc": self.misc 55 | } 56 | with open(os.path.join(self.exp_dir, "node_info.json"), "w") as f: 57 | json.dump(node_info, f, indent=4) 58 | 59 | 60 | # Save prefixes 61 | prefix_save_dir = os.path.join(self.exp_dir, "prefixes") 62 | os.makedirs(prefix_save_dir) 63 | for i, trace in enumerate(self.prefixes): 64 | trace_save_dir = os.path.join(prefix_save_dir, f"prefix_{i}") 65 | os.makedirs(trace_save_dir) 66 | trace.save(trace_save_dir) 67 | 68 | 69 | @staticmethod 70 | def load(load_dir: str, load_steps: bool=True, load_prefix: bool=True, load_images: bool=True): 71 | 72 | logger.info(f"Loading node from {load_dir}") 73 | 74 | with open(os.path.join(load_dir, "node_info.json"), "r") as f: 75 | node_info = json.load(f) 76 | 77 | visited = node_info["visited"] 78 | 79 | tasks = {} 80 | exploration_tasks = {} 81 | 82 | if visited: 83 | if os.path.exists(os.path.join(load_dir, "tasks")): 84 | for i in range(len(os.listdir(os.path.join(load_dir, "tasks")))): 85 | task_load_dir = os.path.join(load_dir, "tasks", f"task_{i}") 86 | task = Task.load(task_load_dir, load_steps=load_steps, load_images=load_images) 87 | tasks[task.goal] = task 88 | 89 | if os.path.exists(os.path.join(load_dir, "exploration_tasks")): 90 | for i in range(len(os.listdir(os.path.join(load_dir, "exploration_tasks")))): 91 | task_load_dir = os.path.join(load_dir, "exploration_tasks", f"task_{i}") 92 | task = Task.load(task_load_dir, load_steps=load_steps, load_images=load_images) 93 | exploration_tasks[task.goal] = task 94 | 95 | prefixes = [] 96 | if load_prefix: 97 | 98 | if os.path.exists(os.path.join(load_dir, "prefixes")): 99 | # Use glob to find all prefix directories 100 | prefix_dirs = glob.glob(os.path.join(load_dir, "prefixes", "prefix_*")) 101 | 102 | # Sort numerically by extracting the number from each path 103 | prefix_dirs.sort(key=lambda x: int(x.split("_")[-1])) 104 | 105 | for prefix_dir in prefix_dirs: 106 | try: 107 | prefixes.append(Trace.load(prefix_dir, load_steps=load_steps, load_images=False)) 108 | except Exception as e: 109 | logger.warning(f"Failed to load prefix from {prefix_dir}: {e}") 110 | 111 | 112 | return Node( 113 | node_info["url"], 114 | tasks, 115 | exploration_tasks, 116 | node_info["children"], 117 | node_info["description"], 118 | prefixes, 119 | node_info["visited"], 120 | load_dir, 121 | misc=node_info.get("misc", None) 122 | ) 123 | 124 | def update_save(self, save_prefix=False, save_info=True): 125 | if save_info: 126 | node_info = { 127 | "url": self.url, 128 | "description": self.description, 129 | "children": self.children, 130 | "visited": self.visited, 131 | "misc": self.misc 132 | } 133 | with open(os.path.join(self.exp_dir, "node_info.json"), "w") as f: 134 | json.dump(node_info, f, indent=4) 135 | 136 | if save_prefix: 137 | 138 | prefix_save_dir = os.path.join(self.exp_dir, "prefixes") 139 | #delete existing prefixes 140 | delete_folder_contents(prefix_save_dir) 141 | for i, trace in 
enumerate(self.prefixes): 142 | trace_save_dir = os.path.join(prefix_save_dir, f"prefix_{i}") 143 | os.makedirs(trace_save_dir, exist_ok=True) 144 | trace.save(trace_save_dir) 145 | 146 | 147 | def add_task(self, goal: str, task_misc: dict = None) -> Task: 148 | task_dir = os.path.join(self.exp_dir, "tasks", f"task_{len(self.tasks)}") 149 | processed_goal, _ = Task.process_raw_goal(goal) 150 | if processed_goal not in self.tasks: 151 | self.tasks[processed_goal] = Task.from_goal(goal, task_dir, misc=task_misc) 152 | logger.info(f"Added task for goal: {processed_goal}") 153 | else: 154 | logger.warning(f"Task for goal: {processed_goal} already exists. Not adding again.") 155 | 156 | return self.tasks[processed_goal] 157 | 158 | 159 | def add_tasks(self, goals: list[str], task_misc: dict = None) -> list[Task]: 160 | return [self.add_task(goal, task_misc) for goal in goals] 161 | 162 | 163 | def add_exploration_task(self, goal: str, task_misc: dict = None) -> Task: 164 | task_exp_dir = os.path.join(self.exp_dir, "exploration_tasks", f"task_{len(self.exploration_tasks)}") 165 | processed_goal, _ = Task.process_raw_goal(goal) 166 | if processed_goal not in self.exploration_tasks: 167 | self.exploration_tasks[processed_goal] = Task.from_goal(goal, task_exp_dir, misc=task_misc) 168 | logger.info(f"Added exploration task for goal: {processed_goal}") 169 | else: 170 | logger.warning(f"Exploration task for goal: {processed_goal} already exists. Not adding again.") 171 | 172 | return self.exploration_tasks[processed_goal] 173 | 174 | def add_exploration_tasks(self, goals: list[str], task_misc: dict = None) -> list[Task]: 175 | return [self.add_exploration_task(goal, task_misc) for goal in goals] 176 | 177 | def add_trajectory(self, traj: Trajectory, task_misc: dict = None): 178 | task = self.add_task(traj.goal, task_misc) 179 | task.add_trajectory(traj) 180 | 181 | 182 | def add_trajectories(self, trajs: list[Trajectory], task_misc: dict = None): 183 | for traj in trajs: 184 | self.add_trajectory(traj, task_misc) 185 | 186 | 187 | def add_exploration_traj(self, traj: Trajectory): 188 | task = self.add_exploration_task(traj.goal) 189 | task.add_trajectory(traj) 190 | 191 | 192 | def get_feasible_tasks(self) -> list[Task]: 193 | 194 | feasible_tasks = [] 195 | 196 | for task in self.tasks.values(): 197 | if task.is_feasible(): 198 | feasible_tasks.append(task) 199 | 200 | return feasible_tasks 201 | 202 | 203 | def add_prefix(self, prefix: Trace): 204 | prefix_save_dir = os.path.join(self.exp_dir, "prefixes") 205 | os.makedirs(prefix_save_dir, exist_ok=True) 206 | 207 | # Find the highest existing prefix number 208 | existing_prefixes = glob.glob(os.path.join(prefix_save_dir, "prefix_*")) 209 | if existing_prefixes: 210 | # Extract numbers and find the highest one 211 | highest_num = max([int(p.split("_")[-1]) for p in existing_prefixes]) 212 | next_num = highest_num + 1 213 | else: 214 | next_num = 0 215 | 216 | # Use the next available number 217 | trace_save_dir = os.path.join(prefix_save_dir, f"prefix_{next_num}") 218 | os.makedirs(trace_save_dir) 219 | prefix.save(trace_save_dir) 220 | self.prefixes.append(prefix) 221 | -------------------------------------------------------------------------------- /webexp/agents/solver_agent.py: -------------------------------------------------------------------------------- 1 | from .base_agent import AgentFactory, BaseAgent 2 | from .prompt_builders.solver_prompt_builder import SolverPromptBuilder 3 | from .trajectory_data import BrowserGymAgentStepData 4 | 
--------------------------------------------------------------------------------
/webexp/agents/solver_agent.py:
--------------------------------------------------------------------------------
1 | from .base_agent import AgentFactory, BaseAgent
2 | from .prompt_builders.solver_prompt_builder import SolverPromptBuilder
3 | from .trajectory_data import BrowserGymAgentStepData
4 | from browsergym.core.action.highlevel import HighLevelActionSet
5 | from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, prune_html
6 | from openai import OpenAI
7 | from tenacity import retry, before_sleep_log, stop_after_attempt, wait_exponential, wait_random
8 | import ast
9 | import logging
10 | import os
11 | import re
12 | import time
13 | 
14 | logger = logging.getLogger(__name__)
15 | logger.setLevel(logging.INFO)
16 | 
17 | def messages_to_string(messages: list[dict]) -> str:
18 |     prompt_text_strings = []
19 |     for message in messages:
20 |         prompt_text_strings.append(message["content"])
21 |     full_prompt_txt = "\n".join(prompt_text_strings)
22 |     return full_prompt_txt
23 | 
24 | 
25 | def extract_action_and_thought(raw_string):
26 |     """Extract thought and action from a potentially malformed JSON string.
27 | 
28 |     Args:
29 |         raw_string (str): Raw string containing thought and action
30 | 
31 |     Returns:
32 |         tuple: (action, thought) or (None, None) if extraction fails
33 |     """
34 |     # Initialize defaults
35 |     thought = None
36 |     action = None
37 | 
38 |     try:
39 |         # Look for thought pattern using non-greedy match
40 |         thought_match = re.search(r'"thought"\s*:\s*"(.*?)"(?=\s*[,}])', raw_string, re.DOTALL)
41 |         if thought_match:
42 |             thought = thought_match.group(1)
43 |             # Clean up escaped quotes
44 |             thought = thought.replace('\\"', '"')
45 | 
46 |         # Look for action pattern using non-greedy match
47 |         action_match = re.search(r'"action"\s*:\s*"(.*?)"(?=\s*[,}])', raw_string, re.DOTALL)
48 |         if action_match:
49 |             action = action_match.group(1)
50 |             # Clean up escaped quotes
51 |             action = action.replace('\\"', '"')
52 | 
53 |     except Exception as e:
54 |         print(f"Error parsing string: {e}")
55 |         return None, None
56 | 
57 |     return action, thought
58 | 
59 | 
60 | @AgentFactory.register
61 | class SolverAgent(BaseAgent):
62 |     """
63 |     Agent used to fulfill/solve user requests.
64 |     """
65 | 
66 |     def __init__(
67 |         self,
68 |         model_id: str,
69 |         model_id_2: str | None = None,
70 |         base_url: str | None = None,
71 |         base_url_2: str | None = None,
72 |         api_key: str | None = None,
73 |         temperature: float = 1.0,
74 |         char_limit: int = -1,
75 |         demo_mode: str = 'off',
76 |     ):
77 |         """
78 |         Initialize the agent.
79 | 
80 |         Args:
81 |             model_id (str): The name of the model to use (model_id_2, when set, serves retries).
82 |             temperature (float): The temperature to use for sampling.
83 |             demo_mode (str): Demo mode setting for the browsergym action set (e.g. 'off').
84 |         """
85 | 
86 |         # These are args that will be specified in the config.
87 |         super().__init__(model_id=model_id, temperature=temperature, char_limit=char_limit, demo_mode=demo_mode)
88 | 
89 |         self.model_id = model_id
90 |         self.model_id_2 = model_id_2 or model_id
91 |         self.temperature = temperature
92 |         self.char_limit = char_limit
93 |         self.demo_mode = demo_mode
94 | 
95 | 
96 |         base_url = base_url or os.getenv("OPENAI_BASE_URL")
97 |         base_url_2 = base_url_2 or os.getenv("OPENAI_BASE_URL")
98 |         api_key = api_key or os.getenv("OPENAI_API_KEY", "Unspecified!")
99 |         self.client = OpenAI(base_url=base_url, api_key=api_key)
100 |         self.client_long = OpenAI(base_url=base_url_2, api_key=api_key)
101 | 
102 |         self.action_set = HighLevelActionSet(
103 |             subsets=["chat", "bid", "infeas", "nav"],
104 |             strict=False,
105 |             multiaction=False,
106 |             demo_mode=demo_mode
107 |         )
108 | 
109 |         self.prompt_builder = SolverPromptBuilder(self.action_set)
110 | 
111 |         self.history: list[BrowserGymAgentStepData] = []
112 | 
113 |     def reset(self):
114 |         self.history.clear()
115 | 
116 |     def obs_preprocessor(self, obs: dict) -> dict:
117 | 
118 |         return {
119 |             "chat_messages": obs["chat_messages"],
120 |             "screenshot": obs["screenshot"],
121 |             "goal_object": obs["goal_object"],
122 |             "last_action": obs["last_action"],
123 |             "last_action_error": obs["last_action_error"],
124 |             "open_pages_urls": obs["open_pages_urls"],
125 |             "open_pages_titles": obs["open_pages_titles"],
126 |             "active_page_index": obs["active_page_index"],
127 |             "axtree_txt": flatten_axtree_to_str(obs["axtree_object"], filter_visible_only=False, extra_properties=obs["extra_element_properties"]),
128 |             "axtree_visible_only_txt": flatten_axtree_to_str(obs["axtree_object"], filter_visible_only=True, extra_properties=obs["extra_element_properties"]),
129 |             "pruned_html": prune_html(flatten_dom_to_str(obs["dom_object"])),
130 |             "extra_element_properties": obs["extra_element_properties"],
131 |         }
132 | 
133 |     def action_processor(self, action: str) -> str:
134 |         """
135 |         Process the action before it is passed to the environment.
136 | 
137 |         Args:
138 |             action (str): The action to process.
139 | 
140 |         Returns:
141 |             str: The processed action.
142 |         """
143 |         parsed_action, thought = extract_action_and_thought(action)
144 |         return self.action_set.to_python_code(parsed_action if parsed_action else action)
145 | 
146 | 
147 |     def get_action(self, obs: dict, oracle_action: tuple[str, str] = None, **kwargs) -> tuple[str, dict]:
148 |         """
149 |         Get the action for the given observation.
150 | 
151 |         Args:
152 |             obs (dict): The observation from the environment.
153 |             oracle_action (tuple[str, str]): Tuple of (action, thought) to use, if available, instead of generating a new one.
154 | 
155 |         Returns:
156 |             tuple[str, dict]: The raw action string and the step's misc info.
157 |         """
158 | 
159 |         current_step = BrowserGymAgentStepData(
160 |             action=None,
161 |             thought=None,
162 |             axtree=obs["axtree_txt"],
163 |             last_action_error=obs.get("last_action_error"),
164 |             misc={}
165 |         )
166 | 
167 |         if oracle_action is None:
168 |             # Use adaptive retry mechanism with character limit reduction
169 |             response = self.make_llm_call_with_adaptive_retry(obs, current_step)
170 | 
171 |             raw_action = response.choices[0].message.content
172 |             action, thought = extract_action_and_thought(raw_action)
173 |             current_step.misc["model_usage"] = response.usage.to_dict()
174 | 
175 |         else:
176 |             action, thought = oracle_action
177 |             raw_action = f'{{"thought": "{thought}", "action": "{action}"}}'
178 | 
179 |         print(f"Raw Action:\n {raw_action}")
180 | 
181 |         current_step.action = action
182 |         current_step.thought = thought
183 |         current_step.misc.update({"thought": thought, "parsed_action": action})
184 | 
185 |         self.history.append(current_step)
186 | 
187 |         return raw_action, current_step.misc
188 | 
189 |     def make_llm_call_with_adaptive_retry(self, obs: dict, current_step: BrowserGymAgentStepData):
190 |         """
191 |         Make a call to the LLM with adaptive retry that reduces the character limit on failures.
192 | 
193 |         Args:
194 |             obs (dict): The observation from the environment.
195 |             current_step (BrowserGymAgentStepData): The current step data.
196 | 
197 |         Returns:
198 |             The chat completion response from the LLM.
199 |         """
200 |         max_attempts = 5
201 |         attempt = 0
202 |         current_char_limit = self.char_limit
203 | 
204 |         while attempt < max_attempts:
205 |             try:
206 |                 # Build messages with current character limit
207 |                 messages = self.prompt_builder.build_messages(
208 |                     goal=obs["goal_object"][0]["text"],
209 |                     current_step=current_step,
210 |                     history=self.history,
211 |                     char_limit=current_char_limit if (attempt == 0) or (current_char_limit < 0) else current_char_limit * 2  # TODO: Ad-hoc!
212 |                 )['prompt']
213 | 
214 |                 print(f"Attempt {attempt+1}: Using char_limit={current_char_limit}")
215 | 
216 |                 if attempt == 0:
217 |                     # Make the actual API call
218 |                     return self.client.chat.completions.create(
219 |                         model=self.model_id,
220 |                         messages=messages,
221 |                         temperature=self.temperature
222 |                     )
223 |                 else:
224 |                     return self.client_long.chat.completions.create(
225 |                         model=self.model_id_2,
226 |                         messages=messages,
227 |                         temperature=self.temperature
228 |                     )
229 | 
230 |             except Exception as e:
231 |                 attempt += 1
232 |                 if attempt >= max_attempts:
233 |                     logger.error(f"Failed after {max_attempts} attempts: {str(e)}")
234 |                     raise
235 | 
236 |                 if attempt > 1:
237 |                     current_char_limit = int(current_char_limit * 0.95)
238 |                     logger.warning(f"Retrying with {current_char_limit} character limit after error: {str(e)}")
239 | 
240 |                 if attempt > 1:  # Skip delay for first retry
241 |                     wait_time = 1.5 * (2 ** (attempt-1)) + (0.1 * attempt)
242 |                     logger.info(f"Waiting {wait_time:.2f} seconds before retry")
243 |                     time.sleep(wait_time)
244 |                 else:
245 |                     logger.info("Retrying immediately")
246 | 
--------------------------------------------------------------------------------
/webexp/agents/prompt_builders/solver_prompt_builder.py:
--------------------------------------------------------------------------------
1 | from . import BasePromptBuilder, flatten_messages
2 | from ..trajectory_data import BrowserGymAgentStepData, BrowserGymAgentTrajectoryData
3 | from browsergym.core.action.base import AbstractActionSet
4 | from dataclasses import dataclass
5 | from textwrap import dedent
6 | import json
7 | 
8 | class SolverPromptBuilder(BasePromptBuilder):
9 | 
10 |     def __init__(self, action_set: AbstractActionSet):
11 |         self.action_set = action_set
12 |         self.action_set_description = action_set.describe(with_long_description=True, with_examples=True)
13 | 
14 | 
15 |     def build_chat_messages(self, obs: dict):  # unused helper; named distinctly so it does not shadow build_messages below
16 |         messages = []
17 |         if "message" in obs:
18 |             messages.append({"text": obs["message"]})
19 |         return messages
20 | 
21 |     def format_thought_and_action(self, thought: str, action: str) -> str:
22 |         d = {}
23 |         if thought:
24 |             d['thought'] = thought
25 |         if action:
26 |             d['action'] = action
27 |         return json.dumps(d)
28 | 
29 |     def trim_axtree(self, axtree: str, num_chars_overflow: int) -> str:
30 |         trim_str = "...trimmed due to context size limit"
31 |         return axtree[:-num_chars_overflow - len(trim_str)] + trim_str
32 | 
33 | 
34 |     def build_trajectory_messages(self, trajectory_data: BrowserGymAgentTrajectoryData, char_limit: int=-1) -> list[dict]:
35 |         messages = []
36 |         for i, step in enumerate(trajectory_data.steps):
37 |             messages.append(self.build_messages(trajectory_data.goal, step, trajectory_data.steps[:i], char_limit))
38 |         return messages
39 | 
40 | 
41 |     def build_messages(self, goal: str, current_step: BrowserGymAgentStepData, history: list[BrowserGymAgentStepData], char_limit: int=-1) -> dict:
42 |         past_thoughts = [step.thought for step in history]
43 |         past_actions = [step.misc['parsed_action'] if 'parsed_action' in step.misc else step.action for step in history]
44 | 
45 |         axtree = current_step.axtree
46 |         last_action_error = current_step.last_action_error
47 |         completion_thought = current_step.thought
48 |         completion_action = current_step.misc['parsed_action'] if current_step.misc and 'parsed_action' in current_step.misc else current_step.action
49 | 
50 |         add_completion = completion_thought or completion_action
51 | 
52 |         messages = self._build_messages(
53 |             goal,
54 |             past_thoughts,
55 |             past_actions,
56 |             axtree,
57 |             last_action_error,
58 |             completion_thought,
59 |             completion_action
60 |         )
61 |         curr_char_count = self.count_message_chars(messages['prompt'] + (messages['completion'] if add_completion else []))
62 |         if char_limit > 0 and curr_char_count > char_limit:
63 |             past_thoughts, past_actions = self.trim_past_thoughts_and_actions(past_thoughts, past_actions, max_allowed=8)
64 |             messages = self._build_messages(
65 |                 goal,
66 |                 past_thoughts,
67 |                 past_actions,
68 |                 axtree,
69 |                 last_action_error,
70 |                 completion_thought,
71 |                 completion_action
72 |             )
73 | 
74 |             curr_char_count = self.count_message_chars(messages['prompt'] + (messages['completion'] if add_completion else []))
75 |             remaining_overflow = curr_char_count - char_limit
76 |             if remaining_overflow > 0:
77 |                 axtree = self.trim_axtree(axtree, remaining_overflow)
78 |                 messages = self._build_messages(
79 |                     goal,
80 |                     past_thoughts,
81 |                     past_actions,
82 |                     axtree,
83 |                     last_action_error,
84 |                     completion_thought,
85 |                     completion_action
86 |                 )
87 | 
88 |         return {k : flatten_messages(v) for k, v in messages.items() if v}
89 | 
90 | 
91 |     def count_message_chars(self, messages: list[dict]) -> int:
92 |         return sum([len(m['text']) for message in messages for m in message['content']])
93 | 
94 |     def trim_past_thoughts_and_actions(self, past_thoughts: list[str | None], past_actions: list[str], max_allowed: int=3) -> tuple[list[str | None], list[str]]:
95 |         if len(past_thoughts) > max_allowed:
96 |             past_thoughts = past_thoughts[-max_allowed:]
97 |             past_actions = past_actions[-max_allowed:]
98 |         return past_thoughts, past_actions
99 | 
100 | 
101 |     def _build_messages(
102 |         self,
103 |         goal: str,
104 |         thoughts: list[str | None],
105 |         actions: list[str | None],
106 |         axtree: str,
107 |         last_action_error: str | None = None,
108 |         completion_thought: str | None = None,
109 |         completion_action: str | None = None
110 |     ):
111 |         system_messages = {"role": "system", "content": [self.system_message()]}
112 |         user_messages = {
113 |             "role": "user",
114 |             "content": [
115 |                 self.goal_message(goal),
116 |                 self.axtree_message(axtree),
117 |                 self.action_space_message(self.action_set),
118 |                 self.action_history_messages(thoughts, actions),
119 |             ]
120 |         }
121 |         if last_action_error:
122 |             user_messages["content"].append(self.last_action_error_message(last_action_error))
123 | 
124 |         user_messages["content"].append(self.next_action_request_message())
125 | 
126 |         output = { "prompt": [system_messages, user_messages] }
127 | 
128 |         if completion_thought or completion_action:
129 |             assistant_messages = {
130 |                 "role": "assistant",
131 |                 "content": [self.completion_message(completion_thought, completion_action)]
132 |             }
133 |             output["completion"] = [assistant_messages]
134 | 
135 |         return output
136 | 
137 | 
138 |     #TODO: Make sure this updated system prompt works just as well.
139 |     def system_message(self):
140 |         return {
141 |             "type": "text",
142 |             "text": dedent("""\
143 |                 # Instructions
144 |                 You are a UI Assistant, your goal is to help the user perform tasks using a web browser.
145 |                 Review the instructions from the user, the current state of the page and all other information to find the best possible next action to accomplish your goal. Your answer will be interpreted and executed by a program, so make sure to follow the formatting instructions.
146 |                 """
147 |             )
148 |         }
149 | 
150 |     def goal_message(self, goal: str):
151 |         return {
152 |             "type": "text",
153 |             "text": (
154 |                 "# Goal\n"
155 |                 f"{goal}"
156 |             )
157 |         }
158 | 
159 | 
160 |     def action_space_message(self, action_set: AbstractActionSet):
161 |         newline = "\n"
162 |         return {
163 |             "type": "text",
164 |             "text": ("# Action Space\n"
165 |                 f"{self.action_set_description}\n\n"
166 |                 "Here are examples of actions with chain-of-thought reasoning:\n\n"
167 |                 f"{newline.join(newline + json.dumps(cot_example) for cot_example in self.cot_examples())}\n\n\n"
168 |             )
169 |         }
170 | 
171 |     def cot_examples(self) -> list[dict]:
172 |         return [
173 |             {"thought": "I now need to click on the Submit button to send the form. I will use the click action on the button, which has bid 12.", "action": "click('12')"},
174 |             {"thought": "I found the information requested by the user, so I will send it to the chat.", "action": "send_msg_to_user('The price for a 15 inch laptop is 1499 USD.')"},
175 |             {"thought": "I have finished navigating to the Products page. I will inform the user that I have completed the task.", "action": "send_msg_to_user('I have finished navigating to the Products page.')"},
176 |         ]
177 | 
178 | 
179 |     def axtree_message(self, axtree: str):
180 |         return {
181 |             "type": "text",
182 |             "text": (
183 |                 "# Current page Accessibility Tree\n"
184 |                 f"{axtree}"
185 |             )
186 |         }
187 | 
188 |     def last_action_error_message(self, last_action_error: str):
189 |         return {
190 |             "type": "text",
191 |             "text": (
192 |                 "# Error message from last action\n"
193 |                 f"{last_action_error}"
194 |             )
195 |         }
196 | 
197 |     def action_history_messages(self, thoughts: list[str | None], actions: list[str]):
198 |         newline = "\n"
199 |         return {
200 |             "type": "text",
201 |             "text": (
202 |                 "# History of past actions\n"
203 |                 f"{newline.join(self.format_thought_and_action(thought, action) for thought, action in zip(thoughts, actions))}"
204 |             )
205 |         }
206 | 
207 | 
208 |     def next_action_request_message(self):
209 |         return {
210 |             "type": "text",
211 |             "text": (
212 |                 "# Next action\n\n"
213 |                 "You will now think step by step and produce your next best action. Reflect on your past actions, any resulting error message, and the current state of the page before deciding on your next action. Provide your output as a single JSON object with a thought and an action. All reasoning must be contained within the thought key of the JSON output, and only a single action must be provided for the action key. Future actions will be taken subsequently. If you have finished performing the request, send a message to the user in a concise and to-the-point manner."
214 |             )
215 |         }
216 | 
217 | 
218 |     def completion_message(self, completion_thought: str, completion_action: str):
219 |         return {
220 |             "type": "text",
221 |             "text": f"{self.format_thought_and_action(completion_thought, completion_action)}"
222 |         }
223 | 
224 | 
225 | 
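
A rough sketch of the dict shape build_messages produces: a 'prompt' list of system/user messages, plus a 'completion' list when the step already carries a thought/action pair. Not part of the repository; it assumes browsergym is installed and that BrowserGymAgentStepData takes the fields used in solver_agent.py; the axtree snippet is fabricated:

from browsergym.core.action.highlevel import HighLevelActionSet
from webexp.agents.prompt_builders.solver_prompt_builder import SolverPromptBuilder
from webexp.agents.trajectory_data import BrowserGymAgentStepData

builder = SolverPromptBuilder(HighLevelActionSet(subsets=["chat", "bid", "nav"]))
step = BrowserGymAgentStepData(
    action="click('12')",
    thought="The Submit button has bid 12.",
    axtree="RootWebArea 'Checkout'\n  button 'Submit' [12]",  # fabricated snippet
    last_action_error=None,
    misc={},
)
out = builder.build_messages(goal="Submit the form.", current_step=step, history=[], char_limit=-1)
print(out.keys())  # dict_keys(['prompt', 'completion'])
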
--------------------------------------------------------------------------------
/webexp/explore/algorithms/web_explore.py:
--------------------------------------------------------------------------------
1 | from ..core.agent import AgentWithExplorationCallbacks, ExplorerAgentWithExplorationCallbacks, wrap_agent_for_callback_protocol
2 | from ..core.evaluator import Evaluator
3 | from ..core.episode import run_episode, get_action, perform_env_step
4 | from ..core.graph import Graph
5 | from ..core.node import Node
6 | from ..core.task import Task
7 | from ..core.trace import Trace
8 | from ..core.trajectory import Trajectory
9 | from ...agents.base_agent import AgentFactory
10 | from browsergym.core.env import BrowserEnv
11 | from browsergym.experiments.loop import EnvArgs
12 | from dataclasses import dataclass
13 | from omegaconf import OmegaConf as oc
14 | from pathlib import Path
15 | from typing import Sequence, List, Dict, Optional
16 | import argparse
17 | import logging
18 | import os
19 | import random
20 | import requests
21 | import traceback
22 | 
23 | logger = logging.getLogger(__name__)
24 | logger.setLevel(logging.INFO)
25 | 
26 | if not logger.handlers:
27 |     handler = logging.StreamHandler()
28 |     formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
29 |     handler.setFormatter(formatter)
30 |     logger.addHandler(handler)
31 | 
32 | @dataclass
33 | class WebExploreAgentConfig:
34 |     """
35 |     Configuration for the Explorer agents.
36 | 
37 |     Attributes:
38 |         agent_factory_args (Dict): Arguments passed to the agent factory
39 |             when constructing the agent.
40 |         max_steps (int): Maximum steps for the agent.
41 |         retries (int): Number of retries for the agent.
42 |     """
43 |     agent_factory_args: Dict
44 |     max_steps: int
45 |     retries: int
46 | 
47 | @dataclass
48 | class WebExploreConfig:
49 |     """
50 |     Configuration for the WebExplore algorithm.
51 | 
52 |     Attributes:
53 |         env_args (Dict): Environment configuration forwarded to EnvArgs.
54 |         evaluator (Dict): Evaluator configuration.
55 |         max_nodes (int): Maximum number of nodes to explore.
56 |         resume_from (Optional[str]): Path to resume from.
57 |         page_explorers (List[WebExploreAgentConfig]): List of page explorer agent configurations.
58 |         nav_explorers (List[WebExploreAgentConfig]): List of navigation explorer agent configurations.
59 |         feasibility_checkers (List[WebExploreAgentConfig]): List of feasibility checker agent configurations.
60 |         solvers (List[WebExploreAgentConfig]): List of solver agent configurations.
61 |         allowlist_patterns (List[str]): List of URL patterns to allow.
62 |         denylist_patterns (List[str]): List of URL patterns to block/deny.
63 |         max_feasible_page_explorer_tasks_per_node (int): Maximum feasible tasks per node for page explorers.
64 |         max_feasible_nav_explorer_tasks_per_node (int): Maximum feasible tasks per node for navigation explorers.
65 |         exp_dir (str): Directory for saving exploration data.
66 |         full_reset_url (Optional[str]): URL for full reset.
67 |     """
68 |     env_args: Dict
69 |     evaluator: Dict
70 |     max_nodes: int
71 |     resume_from: Optional[str]
72 |     page_explorers: List[WebExploreAgentConfig]
73 |     nav_explorers: List[WebExploreAgentConfig]
74 |     feasibility_checkers: List[WebExploreAgentConfig]
75 |     solvers: List[WebExploreAgentConfig]
76 |     allowlist_patterns: List[str]
77 |     denylist_patterns: List[str]
78 |     exp_dir: str
79 |     max_feasible_page_explorer_tasks_per_node: int
80 |     max_feasible_nav_explorer_tasks_per_node: int
81 |     full_reset_url: Optional[str]
82 | 
83 | 
84 | def perform_full_reset(full_reset_url: str, num_retries: int = 3):
85 |     """
86 |     Perform a full reset of the environment by sending a POST request to the specified URL.
87 |     """
88 |     for _ in range(num_retries):
89 |         try:
90 |             response = requests.post(full_reset_url)
91 |             if response.status_code == 200:
92 |                 logger.info(f"Full reset successful: {response.text}")
93 |                 return
94 |             else:
95 |                 logger.error(f"Full reset failed: {response.status_code} - {response.text}")
96 |         except requests.exceptions.RequestException as e:
97 |             logger.error(f"Error during full reset: {e}")
98 | 
99 |     logger.error("Failed to perform full reset after multiple attempts.")
100 | 
101 | def backtrack_if_needed(
102 |     agent, step_num: int, goal: str, env: BrowserEnv, graph: Graph, node: Node, traj: Trajectory, obs: dict,
103 |     reward: float, terminated: bool, truncated: bool, env_info: dict, callback_context: dict
104 | ):
105 |     """
106 |     Callback to check if we are on a blocked page and backtrack if needed.
107 |     """
108 |     open_urls = obs['open_pages_urls']
109 |     for url in open_urls:
110 |         if not graph.check_if_url_allowed(url):
111 | 
112 |             logger.info(f"Blocked page detected: {url}")
113 | 
114 |             oracle_action = (
115 |                 "go_back()",
116 |                 "I am not permitted to view this page as it is on a blocklist, "
117 |                 "so I will go back to the previous page and try something else."
118 |             )
119 | 
120 |             action = get_action(
121 |                 env=env,
122 |                 agent=agent,
123 |                 obs=obs,
124 |                 traj=traj,
125 |                 oracle_action=oracle_action
126 |             )
127 | 
128 |             obs, reward, terminated, truncated, env_info = perform_env_step(
129 |                 env=env,
130 |                 agent=agent,
131 |                 action=action,
132 |             )
133 | 
134 |             logger.info(f"Backtracked to {obs['open_pages_urls'][-1]}")
135 | 
136 |     return step_num, obs, reward, terminated, truncated, env_info, goal, callback_context
137 | 
138 | def prestep_store_url(
139 |     agent, step_num: int, goal: str, env: BrowserEnv, graph: Graph, node: Node, traj: Trajectory, obs: dict,
140 |     reward: float, terminated: bool, truncated: bool, env_info: dict, callback_context: dict
141 | ):
142 |     """
143 |     Callback to log the active url before the step.
144 |     """
145 |     callback_context['pre_step_url'] = env.page.url
146 |     return step_num, obs, reward, terminated, truncated, env_info, goal, callback_context
147 | 
148 | def backtrack_when_new_page_found(
149 |     agent, step_num: int, goal: str, env: BrowserEnv, graph: Graph, node: Node, traj: Trajectory, obs: dict,
150 |     reward: float, terminated: bool, truncated: bool, env_info: dict, callback_context: dict
151 | ):
152 |     """
153 |     Callback to check if we are on a new page and backtrack if needed.
154 |     """
155 | 
156 |     open_urls = obs['open_pages_urls']
157 |     if len(open_urls) > 1:
158 |         for i in range(len(open_urls) - 1):
159 |             oracle_action = (
160 |                 "close_tab()",
161 |                 "I have opened a new tab. It is better to just use a single tab when exploring, "
162 |                 "so I will close this tab and return to the original tab to resume exploring."
163 |             )
164 | 
165 |             action = get_action(
166 |                 env=env,
167 |                 agent=agent,
168 |                 obs=obs,
169 |                 traj=traj,
170 |                 oracle_action=oracle_action
171 |             )
172 | 
173 |             obs, reward, terminated, truncated, env_info = perform_env_step(
174 |                 env=env,
175 |                 agent=agent,
176 |                 action=action,
177 |             )
178 | 
179 |             logger.info(f"Closed tab {open_urls[i]}")
180 | 
181 |     open_urls = obs['open_pages_urls']
182 | 
183 |     if open_urls[0] != callback_context['pre_step_url']:
184 |         oracle_action = (
185 |             "go_back()",
186 |             "I was able to successfully navigate to a new page, so I should add a corresponding "
187 |             "navigation task to the dataset next. But first, I will navigate back to the previous page."
188 |         )
189 | 
190 |         action = get_action(
191 |             env=env,
192 |             agent=agent,
193 |             obs=obs,
194 |             traj=traj,
195 |             oracle_action=oracle_action
196 |         )
197 |         obs, reward, terminated, truncated, env_info = perform_env_step(
198 |             env=env,
199 |             agent=agent,
200 |             action=action,
201 |         )
202 |         logger.info(f"Backtracked to {obs['open_pages_urls'][-1]}")
203 | 
204 |     return step_num, obs, reward, terminated, truncated, env_info, goal, callback_context
205 | 
206 | 
207 | def sample_task_candidates_for_node(
208 |     env: BrowserEnv,
209 |     explorer: ExplorerAgentWithExplorationCallbacks,
210 |     evaluator: Evaluator,
211 |     graph: Graph,
212 |     node: Node,
213 |     max_steps: int,
214 |     max_retries: int = 3
215 | ) -> list[Task]:
216 | 
217 |     goal = explorer.goal_str
218 |     tasks = []
219 | 
220 |     logger.info(f"Sampling tasks for node {node.url} with agent config:\n{explorer.get_config()}")
221 | 
222 |     retry = 0
223 |     while not tasks and retry < max_retries:
224 |         logger.info(f"Sampling tasks for node {node.url}. On retry {retry}/{max_retries}.")
225 | 
226 |         traj = run_episode(
227 |             goal=goal,
228 |             node=node,
229 |             env=env,
230 |             agent=explorer,
231 |             evaluator=evaluator,
232 |             graph=graph,
233 |             max_steps=max_steps
234 |         )
235 | 
236 |         node.add_exploration_traj(traj)
237 | 
238 |         tasks.extend(explorer.get_proposed_tasks())
239 | 
240 |         logger.info(f"On retry {retry}. Sampled tasks for node {node.url}:\n{tasks}.")
241 | 
242 |         retry += 1
243 | 
244 | 
245 |     return node.add_tasks(tasks, task_misc={'agent_info': explorer.get_config()})
246 | 
247 | 
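# --- Illustrative sketch, not part of this file: the exploration callbacks
# above all share one signature and must return the full
# (step_num, obs, reward, terminated, truncated, env_info, goal, callback_context)
# tuple so the wrapper can chain them. A minimal custom post-step callback
# might look like this (hypothetical; it could be registered via
# wrap_agent_for_callback_protocol(..., post_step_callbacks=[log_active_page])):
def log_active_page(
    agent, step_num, goal, env, graph, node, traj, obs,
    reward, terminated, truncated, env_info, callback_context
):
    # Log the currently active page, then pass everything through unchanged.
    logger.info(f"Step {step_num}: active page is {env.page.url}")
    return step_num, obs, reward, terminated, truncated, env_info, goal, callback_context
# --- end sketch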
248 | def filter_to_feasible_tasks_for_node(
249 |     tasks: List[Task],
250 |     env: BrowserEnv,
251 |     feasibility_checker: AgentWithExplorationCallbacks,
252 |     evaluator: Evaluator,
253 |     graph: Graph,
254 |     node: Node,
255 |     max_steps: int | Sequence[int] = 10,
256 |     max_retries: int = 3,
257 |     max_feasible_tasks: Optional[int] = None,
258 | ):
259 | 
260 |     # Shuffle tasks if max_feasible_tasks is provided to ensure diversity
261 |     if max_feasible_tasks is not None:
262 |         random.shuffle(tasks)
263 | 
264 |     # TODO: We may want to account for the more general case where we can have multiple feasibility checkers.
265 |     # In that case, this count should be initialized to the number of feasible tasks found so far, restricted to tasks whose agent_configs are similar to those of the input tasks.
266 |     feasible_count = 0
267 | 
268 |     for i, task in enumerate(tasks):
269 |         trajs = []
270 |         for r in range(max_retries):
271 |             try:
272 |                 traj = run_episode(
273 |                     goal=task.goal,
274 |                     node=node,
275 |                     env=env,
276 |                     agent=feasibility_checker,
277 |                     evaluator=evaluator,
278 |                     graph=graph,
279 |                     max_steps=max_steps,
280 |                     callback_context={"task_misc": task.misc}  # Pass task misc to the callback context
281 |                 )
282 | 
283 |                 trajs.append(traj)
284 | 
285 |                 if traj.success:
286 |                     feasible_count += 1
287 |                     break
288 | 
289 |             except Exception as e:
290 |                 logger.error(f"Error checking feasibility for node {node.url} and task {task} on retry {r}: {e}")
291 |                 logger.error(traceback.format_exc())
292 | 
293 |         node.add_trajectories(trajs)
294 | 
295 |         # Early termination if we've found enough feasible tasks
296 |         if max_feasible_tasks is not None and feasible_count >= max_feasible_tasks:
297 |             logger.info(f"Found {feasible_count} feasible tasks (max: {max_feasible_tasks}). Stopping feasibility checking early.")
298 |             break
299 | 
300 | 
301 | def sample_task_solving_trajectories_for_node(
302 |     node: Node,
303 |     env: BrowserEnv,
304 |     agent: AgentWithExplorationCallbacks,
305 |     evaluator: Evaluator,
306 |     graph: Graph,
307 |     max_steps: int,
308 |     num_trajs_per_task: int,
309 | ):
310 |     tasks = node.get_feasible_tasks()
311 | 
312 |     logger.info(f"Sampling trajectories for node {node.url} with agent config:\n{agent.get_config()}")
313 |     logger.info(f"Node has {len(tasks)} feasible tasks.")
314 | 
315 |     for task in tasks:
316 | 
317 |         logger.info(f"Sampling prefixed trajectories for node {node.url} and task {task.goal}.")
318 | 
319 |         for _ in range(num_trajs_per_task):
320 | 
321 |             try:
322 |                 traj = run_episode(
323 |                     goal=task.goal,
324 |                     node=node,
325 |                     env=env,
326 |                     agent=agent,
327 |                     evaluator=evaluator,
328 |                     graph=graph,
329 |                     max_steps=max_steps,
330 |                     callback_context={"task_misc": task.misc}
331 |                 )
332 | 
333 |                 traj.misc["needs_prefix"] = True
334 | 
335 |                 node.add_trajectory(traj)
336 | 
337 |             except Exception as e:
338 |                 logger.error(f"Error sampling trajectories for node {node.url} and task {task.goal}: {e}")
339 |                 logger.error(traceback.format_exc())
340 | 
341 | 
342 |         for _ in range(num_trajs_per_task):
343 | 
344 |             try:
345 |                 traj = run_episode(
346 |                     goal=task.goal,
347 |                     node=graph.root,
348 |                     env=env,
349 |                     agent=agent,
350 |                     evaluator=evaluator,
351 |                     graph=graph,
352 |                     max_steps=max_steps,
353 |                     callback_context={"task_misc": task.misc}  # keep the key that process_open_urls_callback expects
354 |                 )
355 | 
356 |                 traj.misc["needs_prefix"] = False
357 | 
358 |                 node.add_trajectory(traj)
359 | 
360 |             except Exception as e:
361 |                 logger.error(f"Error sampling trajectories for node {node.url} and task {task.goal}: {e}")
362 |                 logger.error(traceback.format_exc())
363 | 
364 | def process_open_urls_callback(
365 |     agent: AgentWithExplorationCallbacks, step_num: int, goal: str, env: BrowserEnv, graph: Graph, node: Node, traj: Trajectory, obs: dict,
366 |     reward: float, terminated: bool, truncated: bool, env_info: dict, callback_context: dict
367 | ):
368 |     """
369 |     Callback to process the open urls after each step.
370 | """ 371 | open_urls = obs['open_pages_urls'] 372 | 373 | for url in open_urls: 374 | curr_prefix = Trace.from_trajectory_steps( 375 | steps=traj.steps, 376 | start_url=node.url, 377 | end_url=url, 378 | misc={'agent_info': agent.get_config(), 'goal': goal, 'task_misc': callback_context.get('task_misc', {})} 379 | ) 380 | 381 | if graph.check_if_url_allowed(url): 382 | 383 | update_prefix = url != node.url # No self-edges 384 | 385 | url_node = graph.get_node(url) 386 | if url_node: 387 | if update_prefix: 388 | url_node.add_prefix(curr_prefix) 389 | else: 390 | graph.add_url( 391 | url=url, 392 | parent=node, 393 | prefixes=[curr_prefix] if update_prefix else [], 394 | node_misc={'discovered_by': agent.get_config(), 'goal': goal, 'task_misc': callback_context.get('task_misc', {})} 395 | ) 396 | 397 | if url not in node.children: 398 | node.children.append(url) 399 | node.update_save(save_prefix=False, save_info=True) 400 | 401 | return step_num, obs, reward, terminated, truncated, env_info, goal, callback_context 402 | 403 | 404 | def web_explore_loop(): 405 | 406 | parser = argparse.ArgumentParser(description="Run an episode with a browser gym agent.") 407 | parser.add_argument( 408 | "--config", 409 | "-c", 410 | type=str, 411 | required=True, 412 | help="Path to the configuration file.", 413 | ) 414 | args = parser.parse_args() 415 | 416 | config: WebExploreConfig = oc.load(args.config) 417 | oc.resolve(config) 418 | config_dict = oc.to_container(config) 419 | 420 | logger.info(f"WebExploreConfig:\n{config}") 421 | 422 | os.makedirs(config.exp_dir, exist_ok=True) 423 | 424 | page_explorers = [ 425 | wrap_agent_for_callback_protocol( 426 | AgentFactory.create_agent(**explorer['agent_factory_args']), 427 | pre_step_callbacks=[prestep_store_url, ], 428 | post_step_callbacks=[backtrack_if_needed, process_open_urls_callback], 429 | ) 430 | for explorer in config_dict['page_explorers'] 431 | ] 432 | 433 | nav_explorers = [ 434 | wrap_agent_for_callback_protocol( 435 | AgentFactory.create_agent(**explorer['agent_factory_args']), 436 | pre_step_callbacks=[prestep_store_url,], 437 | post_step_callbacks=[backtrack_if_needed, process_open_urls_callback, backtrack_when_new_page_found], 438 | ) 439 | for explorer in config_dict['nav_explorers'] 440 | ] 441 | 442 | feasibility_checkers = [ 443 | wrap_agent_for_callback_protocol( 444 | AgentFactory.create_agent(**feasibility_checker['agent_factory_args']), 445 | pre_step_callbacks=[prestep_store_url], 446 | post_step_callbacks=[backtrack_if_needed, process_open_urls_callback], 447 | ) 448 | for feasibility_checker in config_dict['feasibility_checkers'] 449 | ] 450 | 451 | solvers = [ 452 | wrap_agent_for_callback_protocol( 453 | AgentFactory.create_agent(**solver['agent_factory_args']), 454 | pre_step_callbacks=[prestep_store_url,], 455 | post_step_callbacks=[backtrack_if_needed, process_open_urls_callback], 456 | ) 457 | for solver in config_dict['solvers'] 458 | ] 459 | 460 | env: BrowserEnv = EnvArgs(**config_dict['env_args']).make_env( 461 | action_mapping=lambda x: x, 462 | exp_dir=config.exp_dir 463 | ) 464 | env = env.unwrapped 465 | env.reset() 466 | root_url = env.page.url 467 | 468 | evaluator = Evaluator(**config.evaluator) 469 | 470 | if config.resume_from: 471 | graph = Graph.load(os.path.join(config.resume_from, "graph"), load_images=False) 472 | else: 473 | graph = Graph( 474 | root_url=root_url, 475 | exp_dir=config.exp_dir, 476 | denylist_patterns=config_dict['denylist_patterns'], 477 | 
allowlist_patterns=config_dict['allowlist_patterns'] 478 | ) 479 | 480 | try: 481 | curr_node = graph.get_next_node() 482 | exploration_count = len(graph.explored_nodes) 483 | 484 | while curr_node and exploration_count < config.max_nodes: 485 | 486 | logger.info(f"Exploring node {curr_node.url} ...") 487 | 488 | if hasattr(config, 'full_reset_url') and config.full_reset_url: 489 | logger.info(f"Performing full env reset with url: {config.full_reset_url}") 490 | perform_full_reset(config.full_reset_url) 491 | 492 | if not len(curr_node.tasks): 493 | 494 | page_explorer_tasks = [] 495 | for i, page_explorer in enumerate(page_explorers): 496 | page_explorer_tasks.extend(sample_task_candidates_for_node( 497 | env=env, 498 | explorer=page_explorer, 499 | evaluator=evaluator, 500 | graph=graph, 501 | node=curr_node, 502 | max_steps=config.page_explorers[i].max_steps, 503 | max_retries=config.page_explorers[i].retries, 504 | )) 505 | 506 | nav_explorer_tasks = [] 507 | for i, nav_explorer in enumerate(nav_explorers): 508 | nav_explorer_tasks.extend(sample_task_candidates_for_node( 509 | env=env, 510 | explorer=nav_explorer, 511 | evaluator=evaluator, 512 | graph=graph, 513 | node=curr_node, 514 | max_steps=config.nav_explorers[i].max_steps, 515 | max_retries=config.nav_explorers[i].retries, 516 | )) 517 | 518 | 519 | for i, feasibility_checker in enumerate(feasibility_checkers): 520 | filter_to_feasible_tasks_for_node( 521 | tasks=page_explorer_tasks, 522 | env=env, 523 | feasibility_checker=feasibility_checker, 524 | evaluator=evaluator, 525 | graph=graph, 526 | node=curr_node, 527 | max_steps=config.feasibility_checkers[i].max_steps, 528 | max_retries=config.feasibility_checkers[i].retries, 529 | max_feasible_tasks=config.max_feasible_page_explorer_tasks_per_node 530 | ) 531 | 532 | filter_to_feasible_tasks_for_node( 533 | tasks=nav_explorer_tasks, 534 | env=env, 535 | feasibility_checker=feasibility_checker, 536 | evaluator=evaluator, 537 | graph=graph, 538 | node=curr_node, 539 | max_steps=config.feasibility_checkers[i].max_steps, 540 | max_retries=config.feasibility_checkers[i].retries, 541 | max_feasible_tasks=config.max_feasible_nav_explorer_tasks_per_node 542 | ) 543 | 544 | for i, solver in enumerate(solvers): 545 | sample_task_solving_trajectories_for_node( 546 | node=curr_node, 547 | env=env, 548 | agent=solver, 549 | evaluator=evaluator, 550 | graph=graph, 551 | max_steps=config.solvers[i].max_steps, 552 | num_trajs_per_task=config.solvers[i].retries 553 | ) 554 | 555 | graph.add_to_explored(curr_node) 556 | exploration_count += 1 557 | curr_node = graph.get_next_node() 558 | 559 | if exploration_count == config.max_nodes: 560 | logger.info(f"Max nodes to explore reached: {config.max_nodes}") 561 | else: 562 | logger.info(f"We will now explore the next node: {curr_node.url if curr_node else 'No nodes left to explore!'}") 563 | 564 | except Exception as e: 565 | logger.error(f"Error during exploration: {e}") 566 | logger.error(traceback.format_exc()) 567 | raise e 568 | 569 | finally: 570 | env.close() 571 | 572 | if __name__ == "__main__": 573 | web_explore_loop() 574 | --------------------------------------------------------------------------------
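
A sketch of the configuration shape web_explore_loop expects, built programmatically with OmegaConf. Field names follow WebExploreConfig and the env_args usage above; all values, including the agent_factory_args keys, are hypothetical placeholders and not the shipped configs/go_browse_config.yaml:

from omegaconf import OmegaConf as oc

config = oc.create({
    "env_args": {"task_name": "webarena.0", "headless": True},  # forwarded to browsergym EnvArgs
    "evaluator": {"model_id": "gpt-4o"},                        # assumed Evaluator kwargs
    "max_nodes": 20,
    "resume_from": None,
    "page_explorers": [{"agent_factory_args": {"agent_name": "PageExplorerAgent"}, "max_steps": 15, "retries": 3}],
    "nav_explorers": [{"agent_factory_args": {"agent_name": "NavExplorerAgent"}, "max_steps": 15, "retries": 3}],
    "feasibility_checkers": [{"agent_factory_args": {"agent_name": "SolverAgent"}, "max_steps": 10, "retries": 3}],
    "solvers": [{"agent_factory_args": {"agent_name": "SolverAgent"}, "max_steps": 10, "retries": 2}],
    "allowlist_patterns": ["http://localhost:7770/.*"],
    "denylist_patterns": [".*logout.*"],
    "exp_dir": "/tmp/go-browse-exp",
    "max_feasible_page_explorer_tasks_per_node": 10,
    "max_feasible_nav_explorer_tasks_per_node": 5,
    "full_reset_url": None,
})
oc.save(config, "my_config.yaml")
# then, e.g.: python -m webexp.explore.algorithms.web_explore -c my_config.yaml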