├── visual-tree-search-backend
├── app
│ ├── __init__.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── lwats
│ │ │ ├── __init__.py
│ │ │ ├── core_async
│ │ │ │ ├── __init__.py
│ │ │ │ ├── config.py
│ │ │ │ └── agent_factory.py
│ │ │ ├── agents_async
│ │ │ │ ├── __init__.py
│ │ │ │ └── SearchAgents
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── tree_vis.py
│ │ │ │ │ ├── lats_agent.py
│ │ │ │ │ ├── trajectory_score.py
│ │ │ │ │ └── lats_node.py
│ │ │ ├── webagent_utils_async
│ │ │ │ ├── __init__.py
│ │ │ │ ├── action
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── action_parser.py
│ │ │ │ │ ├── parsers.py
│ │ │ │ │ └── base.py
│ │ │ │ ├── tools
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── upload_file.py
│ │ │ │ │ ├── select_option.py
│ │ │ │ │ ├── navigation.py
│ │ │ │ │ ├── registry.py
│ │ │ │ │ ├── shared_utils.py
│ │ │ │ │ └── webscraping.py
│ │ │ │ ├── utils
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── browser_env
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── constants.py
│ │ │ │ │ └── javascript
│ │ │ │ │ │ └── frame_unmark_elements.js
│ │ │ │ └── evaluation
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── evaluators.py
│ │ │ │ │ └── feedback.py
│ │ │ └── evaluation_async
│ │ │ │ └── evaluators.py
│ │ ├── services
│ │ │ └── hello_service.py
│ │ ├── controllers
│ │ │ └── hello_controller.py
│ │ ├── routes
│ │ │ ├── hello.py
│ │ │ ├── terminate_session.py
│ │ │ ├── sse.py
│ │ │ ├── websocket.py
│ │ │ ├── tree_traversal.py
│ │ │ └── tree_websocket.py
│ │ └── run_demo_treesearch_async.py
│ ├── lib
│ │ └── supabase.py
│ └── main.py
├── test
│ ├── __init__.py
│ ├── test-ws.py
│ ├── test-tree-ws.py
│ ├── test_tree_search_depth.py
│ ├── websocket-client.html
│ ├── test-tree-search-ws-simple.py
│ └── test-tree-search-ws-lats.py
├── .env.example
├── Dockerfile
├── requirements.txt
└── README.md
├── visual-tree-search-browser-service
├── README.md
├── app
│ └── __init__.py
├── requirements.txt
└── Dockerfile
├── visual-tree-search-state-reset
├── app
│ ├── __init__.py
│ ├── api
│ │ ├── services
│ │ │ └── hello_service.py
│ │ ├── controllers
│ │ │ └── hello_controller.py
│ │ └── routes
│ │ │ ├── hello.py
│ │ │ ├── test_sql.py
│ │ │ ├── test_container.py
│ │ │ └── test_db.py
│ └── main.py
├── .gitignore
├── requirements.txt
├── README.md
└── Server_Setup.md
├── visual-tree-search-app
├── .env.example
├── postcss.config.mjs
├── public
│ ├── favicon.ico
│ ├── vercel.svg
│ ├── window.svg
│ ├── file.svg
│ ├── globe.svg
│ └── next.svg
├── lib
│ └── utils.ts
├── next.config.ts
├── pages
│ ├── _document.tsx
│ ├── api
│ │ └── hello.ts
│ └── _app.tsx
├── eslint.config.mjs
├── components.json
├── components
│ ├── Footer.tsx
│ ├── ui
│ │ ├── label.tsx
│ │ ├── input.tsx
│ │ ├── checkbox.tsx
│ │ ├── scroll-area.tsx
│ │ ├── button.tsx
│ │ ├── card.tsx
│ │ └── dialog.tsx
│ ├── Layout.tsx
│ ├── LiveBrowserView.tsx
│ ├── Header.tsx
│ └── Sidebar.tsx
├── tsconfig.json
├── .gitignore
├── package.json
├── README.md
└── styles
│ └── globals.css
├── .env.example
├── .gitignore
├── .github
└── workflows
│ ├── backend-ecs.yml
│ └── browser-service-ecs.yml
├── LICENSE
└── README.md
/visual-tree-search-backend/app/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/visual-tree-search-browser-service/README.md:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/visual-tree-search-state-reset/app/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/visual-tree-search-browser-service/app/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/visual-tree-search-app/.env.example:
--------------------------------------------------------------------------------
1 | NEXT_PUBLIC_BACKEND_URL=""
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/core_async/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/agents_async/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/webagent_utils_async/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/webagent_utils_async/action/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/webagent_utils_async/tools/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/webagent_utils_async/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/webagent_utils_async/browser_env/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/webagent_utils_async/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/test/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Test package for Visual Tree Search
3 | """
--------------------------------------------------------------------------------
/visual-tree-search-state-reset/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .DS_Store
3 | venv
4 | .env
5 |
6 |
--------------------------------------------------------------------------------
/.env.example:
--------------------------------------------------------------------------------
1 | OPENAI_API_KEY=
2 | BROWSERBASE_PROJECT_ID=
3 | BROWSERBASE_API_KEY=
4 | ACCOUNT_RESET_URL=
--------------------------------------------------------------------------------
/visual-tree-search-backend/.env.example:
--------------------------------------------------------------------------------
1 | OPENAI_API_KEY=""
2 | BROWSERBASE_PROJECT_ID=""
3 | BROWSERBASE_API_KEY=""
4 | ACCOUNT_RESET_URL=""
--------------------------------------------------------------------------------
/visual-tree-search-app/postcss.config.mjs:
--------------------------------------------------------------------------------
// PostCSS pipeline: Tailwind CSS is applied through its dedicated PostCSS
// plugin package ("@tailwindcss/postcss", Tailwind v4 per package.json).
const tailwindPlugin = "@tailwindcss/postcss";

const config = {
  plugins: [tailwindPlugin],
};

export default config;
6 |
--------------------------------------------------------------------------------
/visual-tree-search-app/public/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PathOnAIOrg/VisualTreeSearch-Demo/HEAD/visual-tree-search-app/public/favicon.ico
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/services/hello_service.py:
--------------------------------------------------------------------------------
class HelloService:
    """Supply the static payload served by the hello endpoint."""

    def get_hello_data(self):
        """Build and return a new dict holding the placeholder name."""
        payload = dict(name="John Doe")
        return payload
--------------------------------------------------------------------------------
/visual-tree-search-state-reset/app/api/services/hello_service.py:
--------------------------------------------------------------------------------
class HelloService:
    """Supply the static payload served by the hello endpoint."""

    def get_hello_data(self):
        """Build and return a new dict holding the placeholder name."""
        payload = dict(name="John Doe")
        return payload
--------------------------------------------------------------------------------
/visual-tree-search-app/public/vercel.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/visual-tree-search-browser-service/requirements.txt:
--------------------------------------------------------------------------------
1 | fastapi==0.109.2
2 | uvicorn==0.27.1
3 | python-multipart==0.0.9
4 | pydantic==2.6.1
5 | playwright==1.41.2
6 | python-dotenv==1.0.1
7 | boto3==1.34.34
--------------------------------------------------------------------------------
/visual-tree-search-state-reset/requirements.txt:
--------------------------------------------------------------------------------
1 | fastapi>=0.103.1
2 | uvicorn>=0.23.2
3 | websockets>=11.0.3
4 | python-dotenv>=1.0.0
5 | supabase>=1.0.3
6 | httpx>=0.24.1
7 | sqlalchemy>=2.0.27
8 | pymysql>=1.1.0
9 | docker
--------------------------------------------------------------------------------
/visual-tree-search-app/lib/utils.ts:
--------------------------------------------------------------------------------
1 | import { clsx, type ClassValue } from "clsx"
2 | import { twMerge } from "tailwind-merge"
3 |
/**
 * Combine class-name inputs: resolve them with `clsx`, then pass the
 * resulting string through `twMerge` from tailwind-merge.
 */
export function cn(...inputs: ClassValue[]) {
  const combined = clsx(inputs)
  return twMerge(combined)
}
7 |
--------------------------------------------------------------------------------
/visual-tree-search-app/next.config.ts:
--------------------------------------------------------------------------------
1 | import type { NextConfig } from "next";
2 |
// Next.js configuration: React Strict Mode is enabled; add further options
// to this object as needed.
const nextConfig: NextConfig = { reactStrictMode: true };

export default nextConfig;
9 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | visual-tree-search-backend/app/api/test_logs/*
2 | visual-tree-search-backend/log/*
3 | log/*
4 | shopping.json
5 | Log.md
6 | visual-tree-search-backend/log/
7 | visual-tree-search-backend/app/api/log/
8 | __pycache__/
9 | .DS_Store
10 | venv
11 | .env
12 |
13 |
--------------------------------------------------------------------------------
/visual-tree-search-app/pages/_document.tsx:
--------------------------------------------------------------------------------
1 | import { Html, Head, Main, NextScript } from "next/document";
2 |
3 | export default function Document() {
4 | return (
5 |
6 |
8 |
9 |
10 |
11 |
12 | );
13 | }
14 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/webagent_utils_async/browser_env/constants.py:
--------------------------------------------------------------------------------
# Largest unsigned 32-bit value (2**32 - 1); presumably an upper bound on
# text length -- confirm against the code that reads it.
TEXT_MAX_LENGTH = (1 << 32) - 1

# Attribute names written onto DOM elements.
# Note: Playwright's default test-id attribute is "data-testid".
BROWSERGYM_ID_ATTRIBUTE = "data-unique-test-id"
BROWSERGYM_VISIBILITY_ATTRIBUTE = "browsergym_visibility_ratio"
BROWSERGYM_SETOFMARKS_ATTRIBUTE = "browsergym_set_of_marks"

# Presumably the retry limit for observation extraction -- verify at call sites.
EXTRACT_OBS_MAX_TRIES = 5
8 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/controllers/hello_controller.py:
--------------------------------------------------------------------------------
1 | from fastapi import Response
2 | from app.api.services.hello_service import HelloService
3 |
class HelloController:
    """Thin controller that forwards hello requests to a HelloService."""

    def __init__(self):
        # Every controller instance owns its own service object.
        self.hello_service = HelloService()

    def get_hello(self):
        """Return the service's hello payload unchanged."""
        return self.hello_service.get_hello_data()
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/routes/hello.py:
--------------------------------------------------------------------------------
1 | from fastapi import APIRouter
2 | from app.api.controllers.hello_controller import HelloController
3 |
router = APIRouter(prefix="", tags=["hello"])

# Single controller shared by every request handled by this router.
hello_controller = HelloController()


@router.get("", include_in_schema=True)
async def get_hello():
    """Respond with the payload produced by the shared HelloController."""
    payload = hello_controller.get_hello()
    return payload
--------------------------------------------------------------------------------
/visual-tree-search-state-reset/app/api/controllers/hello_controller.py:
--------------------------------------------------------------------------------
1 | from fastapi import Response
2 | from app.api.services.hello_service import HelloService
3 |
class HelloController:
    """Thin controller that forwards hello requests to a HelloService."""

    def __init__(self):
        # Every controller instance owns its own service object.
        self.hello_service = HelloService()

    def get_hello(self):
        """Return the service's hello payload unchanged."""
        return self.hello_service.get_hello_data()
--------------------------------------------------------------------------------
/visual-tree-search-state-reset/app/api/routes/hello.py:
--------------------------------------------------------------------------------
1 | from fastapi import APIRouter
2 | from app.api.controllers.hello_controller import HelloController
3 |
router = APIRouter(prefix="", tags=["hello"])

# Single controller shared by every request handled by this router.
hello_controller = HelloController()


@router.get("", include_in_schema=True)
async def get_hello():
    """Respond with the payload produced by the shared HelloController."""
    payload = hello_controller.get_hello()
    return payload
--------------------------------------------------------------------------------
/visual-tree-search-app/pages/api/hello.ts:
--------------------------------------------------------------------------------
1 | // Next.js API route support: https://nextjs.org/docs/api-routes/introduction
2 | import type { NextApiRequest, NextApiResponse } from "next";
3 |
type Data = {
  name: string;
};

/**
 * Example API route: answers every request with HTTP 200 and the JSON body
 * `{ name: "John Doe" }`. The response is typed with `Data`, so the compiler
 * checks the body shape (previously `Data` was declared but never used).
 */
export default function handler(
  req: NextApiRequest,
  res: NextApiResponse<Data>,
) {
  res.status(200).json({ name: "John Doe" });
}
14 |
--------------------------------------------------------------------------------
/visual-tree-search-app/public/window.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/visual-tree-search-app/public/file.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/visual-tree-search-app/pages/_app.tsx:
--------------------------------------------------------------------------------
1 | import "../styles/globals.css";
2 | import type { AppProps } from "next/app";
3 | import Layout from "../components/Layout";
4 | import { ThemeProvider } from "next-themes";
5 |
6 |
7 | export default function App({ Component, pageProps }: AppProps) {
8 | return (
9 |
10 |
11 |
12 |
13 |
14 | );
15 | }
16 |
--------------------------------------------------------------------------------
/visual-tree-search-app/eslint.config.mjs:
--------------------------------------------------------------------------------
1 | import { dirname } from "path";
2 | import { fileURLToPath } from "url";
3 | import { FlatCompat } from "@eslint/eslintrc";
4 |
// ES modules have no __filename/__dirname; derive them from import.meta.url
// so FlatCompat can resolve the legacy configs relative to this file.
const __filename = fileURLToPath(import.meta.url);
const __dirname = dirname(__filename);

const compat = new FlatCompat({ baseDirectory: __dirname });

// Flat config assembled from Next.js' "core-web-vitals" and "typescript" presets.
export default [...compat.extends("next/core-web-vitals", "next/typescript")];
17 |
--------------------------------------------------------------------------------
/visual-tree-search-app/components.json:
--------------------------------------------------------------------------------
1 | {
2 | "$schema": "https://ui.shadcn.com/schema.json",
3 | "style": "new-york",
4 | "rsc": false,
5 | "tsx": true,
6 | "tailwind": {
7 | "config": "",
8 | "css": "styles/globals.css",
9 | "baseColor": "neutral",
10 | "cssVariables": true,
11 | "prefix": ""
12 | },
13 | "aliases": {
14 | "components": "@/components",
15 | "utils": "@/lib/utils",
16 | "ui": "@/components/ui",
17 | "lib": "@/lib",
18 | "hooks": "@/hooks"
19 | },
20 | "iconLibrary": "lucide"
21 | }
--------------------------------------------------------------------------------
/visual-tree-search-app/components/Footer.tsx:
--------------------------------------------------------------------------------
1 | const Footer = () => {
2 | return (
3 |
12 | );
13 | };
14 |
15 | export default Footer;
16 |
17 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.11-slim
2 |
3 | WORKDIR /app
4 |
5 | # Set environment variables
6 | ENV PYTHONDONTWRITEBYTECODE=1
7 | ENV PYTHONUNBUFFERED=1
8 |
9 | # Install dependencies
10 | COPY requirements.txt .
11 | RUN pip install --no-cache-dir -r requirements.txt
12 |
13 | # Copy all files
14 | COPY . .
15 |
16 | # Expose the port your app runs on
17 | EXPOSE 3000
18 |
19 | # Set production environment
20 | ENV APP_ENV=production
21 |
22 | # Start the application
23 | CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "3000"]
--------------------------------------------------------------------------------
/visual-tree-search-backend/requirements.txt:
--------------------------------------------------------------------------------
1 | fastapi>=0.103.1
2 | uvicorn>=0.23.2
3 | websockets>=11.0.3
4 | python-dotenv>=1.0.0
5 | supabase>=1.0.3
6 | httpx>=0.24.1
7 | openai==1.42.0
8 | playwright==1.39.0
9 | watchdog==4.0.2
10 | pyparsing==3.1.2
11 | litellm==1.42.1
12 | Pillow>=10.1
13 | beautifulsoup4>=4.12
14 | elevenlabs==1.6.1
15 | numpy>= 1.14
16 | lxml>= 4.9
17 | httpx==0.27.2
18 | evaluate==0.4.0
19 | beartype==0.12.0
20 | scikit-image==0.22.0
21 | numpy==1.25.2
22 | aiolimiter==1.1.0
23 | transformers==4.34.0
24 | nltk==3.8.1
25 | browserbase
26 | aiohttp
27 | boto3==1.34.34
--------------------------------------------------------------------------------
/visual-tree-search-browser-service/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM mcr.microsoft.com/playwright/python:v1.41.2-jammy
2 |
3 | WORKDIR /app
4 |
5 | # Copy requirements first to leverage Docker cache
6 | COPY requirements.txt .
7 | RUN pip install --no-cache-dir -r requirements.txt
8 |
9 | # Copy all files
10 | COPY . .
11 |
12 | # Expose the port your app runs on
13 | EXPOSE 3000
14 |
15 | # Set production environment
16 | ENV APP_ENV=production
17 |
18 | # Start the application
18 |
20 | CMD ["python", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "3000"]
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/lib/supabase.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dotenv import load_dotenv
3 | from supabase import create_client, Client
4 |
# Pull variables from a .env file into the process environment.
load_dotenv()

# Supabase connection settings, read from the environment.
SUPABASE_URL = os.environ.get("SUPABASE_URL")
SUPABASE_KEY = os.environ.get("SUPABASE_KEY")

# Both settings are mandatory: fail at import time with an explicit message.
if not SUPABASE_URL:
    raise ValueError("Missing SUPABASE_URL environment variable")
if not SUPABASE_KEY:
    raise ValueError("Missing SUPABASE_KEY environment variable")

# Module-level client shared by every importer of this module.
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
--------------------------------------------------------------------------------
/visual-tree-search-app/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "target": "ES2017",
4 | "lib": ["dom", "dom.iterable", "esnext"],
5 | "allowJs": true,
6 | "skipLibCheck": true,
7 | "strict": true,
8 | "noEmit": true,
9 | "esModuleInterop": true,
10 | "module": "esnext",
11 | "moduleResolution": "bundler",
12 | "resolveJsonModule": true,
13 | "isolatedModules": true,
14 | "jsx": "preserve",
15 | "incremental": true,
16 | "paths": {
17 | "@/*": ["./*"]
18 | }
19 | },
20 | "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx"],
21 | "exclude": ["node_modules"]
22 | }
23 |
--------------------------------------------------------------------------------
/visual-tree-search-app/.gitignore:
--------------------------------------------------------------------------------
1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
2 |
3 | # dependencies
4 | /node_modules
5 | /.pnp
6 | .pnp.*
7 | .yarn/*
8 | !.yarn/patches
9 | !.yarn/plugins
10 | !.yarn/releases
11 | !.yarn/versions
12 |
13 | # testing
14 | /coverage
15 |
16 | # next.js
17 | /.next/
18 | /out/
19 |
20 | # production
21 | /build
22 |
23 | # misc
24 | .DS_Store
25 | *.pem
26 |
27 | # debug
28 | npm-debug.log*
29 | yarn-debug.log*
30 | yarn-error.log*
31 | .pnpm-debug.log*
32 |
33 | # env files (can opt-in for committing if needed)
34 | .env*
35 |
36 | # vercel
37 | .vercel
38 |
39 | # typescript
40 | *.tsbuildinfo
41 | next-env.d.ts
42 |
--------------------------------------------------------------------------------
/visual-tree-search-app/components/ui/label.tsx:
--------------------------------------------------------------------------------
1 | import * as React from "react"
2 | import * as LabelPrimitive from "@radix-ui/react-label"
3 |
4 | import { cn } from "@/lib/utils"
5 |
6 | function Label({
7 | className,
8 | ...props
9 | }: React.ComponentProps) {
10 | return (
11 |
19 | )
20 | }
21 |
22 | export { Label }
23 |
--------------------------------------------------------------------------------
/visual-tree-search-state-reset/README.md:
--------------------------------------------------------------------------------
1 | # VisualTreeSearch Website State Reset
2 |
3 | This repo contains the FastAPI server that resets the state of the WebArena shopping website's backend database.
4 |
5 | ## Usage
6 |
7 |
8 | ### Setup
9 |
10 |
11 | [Setup WebArena shopping website on AWS](./Server_Setup.md)
12 |
13 |
14 |
15 | ### Start the FastAPI server
16 |
17 | `uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload`
18 |
19 | ### Send HTTP requests
20 |
21 | - Check that the server is running
22 |
23 | `curl -N http://localhost:8000/api/hello`
24 |
25 | - Test DB status
26 |
27 | `curl -N http://localhost:8000/api/db/status`
28 |
29 | - Reset account information with DB SQL operations
30 |
31 | `curl -N http://localhost:8000/api/sql/restore`
32 |
33 | - Reset the whole container
34 |
35 | `curl -N http://localhost:8000/api/container/reset/****PUBLICIP****`
36 |
37 |
38 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/webagent_utils_async/tools/upload_file.py:
--------------------------------------------------------------------------------
1 | from .shared_utils import take_action
2 | from .registry import ToolRegistry, Tool
3 |
def upload_file(task_description, features=None, branching_factor=None, playwright_manager=None, log_folder='log', elements_filter=None):
    """Delegate a file-upload task to ``take_action`` using the ``["file"]`` action set.

    All arguments are forwarded positionally to ``take_action``, and its result
    is returned unchanged.
    NOTE(review): if ``take_action`` is a coroutine function, this returns an
    un-awaited coroutine for the caller to await -- confirm in shared_utils.
    """
    return take_action(
        task_description,
        ["file"],
        features,
        branching_factor,
        playwright_manager,
        log_folder,
        elements_filter,
    )
7 |
8 |
def register_upload_file_tool():
    """Add the ``upload_file`` tool to the global ToolRegistry."""
    task_param = {
        "type": "string",
        "description": "The description of the file upload task",
    }
    tool = Tool(
        name="upload_file",
        func=upload_file,
        description="Upload a file.",
        parameters={"task_description": task_param},
    )
    ToolRegistry.register(tool)
21 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/webagent_utils_async/tools/select_option.py:
--------------------------------------------------------------------------------
1 | from .shared_utils import take_action
2 | from .registry import ToolRegistry, Tool
3 |
4 |
def select_option(task_description, features=None, branching_factor=None, playwright_manager=None, log_folder='log', elements_filter=None):
    """Delegate an option-selection task to ``take_action`` using the ``["select_option"]`` action set.

    All arguments are forwarded positionally to ``take_action``, and its result
    is returned unchanged.
    NOTE(review): if ``take_action`` is a coroutine function, this returns an
    un-awaited coroutine for the caller to await -- confirm in shared_utils.
    """
    return take_action(
        task_description,
        ["select_option"],
        features,
        branching_factor,
        playwright_manager,
        log_folder,
        elements_filter,
    )
9 |
10 |
def register_select_option_tool():
    """Add the ``select_option`` tool to the global ToolRegistry."""
    task_param = {
        "type": "string",
        "description": "The description of the option selection task",
    }
    tool = Tool(
        name="select_option",
        func=select_option,
        description="Select an option from a dropdown or list.",
        parameters={"task_description": task_param},
    )
    ToolRegistry.register(tool)
23 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/webagent_utils_async/tools/navigation.py:
--------------------------------------------------------------------------------
1 | from .shared_utils import take_action
2 | from .registry import ToolRegistry, Tool
3 |
4 |
def navigation(task_description, features=None, branching_factor=None, playwright_manager=None, log_folder='log', elements_filter=None):
    """Delegate a navigation task to ``take_action`` using the ``["bid", "nav"]`` action sets.

    All arguments are forwarded positionally to ``take_action``, and its result
    is returned unchanged.
    NOTE(review): if ``take_action`` is a coroutine function, this returns an
    un-awaited coroutine for the caller to await -- confirm in shared_utils.
    """
    return take_action(
        task_description,
        ["bid", "nav"],
        features,
        branching_factor,
        playwright_manager,
        log_folder,
        elements_filter,
    )
8 |
9 |
def register_navigation_tool():
    """Add the ``navigation`` tool to the global ToolRegistry."""
    task_param = {
        "type": "string",
        "description": "The description of the web navigation task, including fill text, click, search, go to new page",
    }
    tool = Tool(
        name="navigation",
        func=navigation,
        description="Perform a web navigation task, including fill text, click, search, go to new page",
        parameters={"task_description": task_param},
    )
    ToolRegistry.register(tool)
--------------------------------------------------------------------------------
/visual-tree-search-app/public/globe.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/visual-tree-search-app/components/ui/input.tsx:
--------------------------------------------------------------------------------
1 | import * as React from "react"
2 |
3 | import { cn } from "@/lib/utils"
4 |
5 | function Input({ className, type, ...props }: React.ComponentProps<"input">) {
6 | return (
7 |
18 | )
19 | }
20 |
21 | export { Input }
22 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/routes/terminate_session.py:
--------------------------------------------------------------------------------
1 | # Terminate a Browserbase browser session
2 | from fastapi import APIRouter, HTTPException
3 | from dotenv import load_dotenv
4 | from browserbase import Browserbase
5 | from playwright.async_api import async_playwright
6 | import os
7 |
# Load environment variables from .env file
load_dotenv()

# Fail fast at import time (KeyError) if Browserbase credentials are missing.
API_KEY = os.environ["BROWSERBASE_API_KEY"]
PROJECT_ID = os.environ["BROWSERBASE_PROJECT_ID"]

router = APIRouter()


@router.post("/{session_id}")
async def terminate_session(session_id: str):
    """Request that Browserbase release the browser session ``session_id``.

    ``bb.sessions.update`` is a synchronous call (the original invoked it
    without ``await``). Calling it directly inside this coroutine would block
    the event loop for the whole HTTP round-trip, so it is run in FastAPI's
    threadpool instead.

    Returns:
        ``{"status": "success", "message": ...}`` once the release is requested.

    Raises:
        HTTPException: status 500 carrying the underlying error text if the
            Browserbase request fails.
    """
    from fastapi.concurrency import run_in_threadpool

    try:
        bb = Browserbase(api_key=API_KEY)
        # Update the session's status to request release, off the event loop.
        await run_in_threadpool(
            bb.sessions.update,
            session_id,
            project_id=PROJECT_ID,
            status="REQUEST_RELEASE",
        )
    except Exception as e:
        # Preserve the original cause for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"status": "success", "message": f"Session {session_id} termination requested"}
--------------------------------------------------------------------------------
/visual-tree-search-app/components/Layout.tsx:
--------------------------------------------------------------------------------
1 | import Header from "./Header";
2 | import Sidebar from "./Sidebar";
3 | import Head from "next/head";
4 |
5 | interface LayoutProps {
6 | children: React.ReactNode;
7 | }
8 |
9 | const Layout: React.FC = ({ children }) => {
10 | return (
11 | <>
12 |
13 | Visual Tree Search
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | {children}
25 |
26 |
27 |
28 |
29 | >
30 | );
31 | };
32 |
33 | export default Layout;
--------------------------------------------------------------------------------
/visual-tree-search-app/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "visual-tree-search-app",
3 | "version": "0.1.0",
4 | "private": true,
5 | "scripts": {
6 | "dev": "next dev",
7 | "build": "next build",
8 | "start": "next start",
9 | "lint": "next lint"
10 | },
11 | "dependencies": {
12 | "@radix-ui/react-checkbox": "^1.1.4",
13 | "@radix-ui/react-dialog": "^1.1.6",
14 | "@radix-ui/react-label": "^2.1.2",
15 | "@radix-ui/react-scroll-area": "^1.2.3",
16 | "@radix-ui/react-slot": "^1.1.2",
17 | "@types/d3": "^7.4.3",
18 | "class-variance-authority": "^0.7.1",
19 | "clsx": "^2.1.1",
20 | "d3": "^7.9.0",
21 | "framer-motion": "^12.9.2",
22 | "lucide-react": "^0.479.0",
23 | "next": "15.2.2",
24 | "next-themes": "^0.4.6",
25 | "react": "^19.0.0",
26 | "react-dom": "^19.0.0",
27 | "react-icons": "^5.5.0",
28 | "tailwind-merge": "^3.0.2",
29 | "tailwindcss-animate": "^1.0.7"
30 | },
31 | "devDependencies": {
32 | "@eslint/eslintrc": "^3",
33 | "@tailwindcss/postcss": "^4",
34 | "@types/node": "^20",
35 | "@types/react": "^19",
36 | "@types/react-dom": "^19",
37 | "eslint": "^9",
38 | "eslint-config-next": "15.2.2",
39 | "tailwindcss": "^4",
40 | "typescript": "^5"
41 | }
42 | }
43 |
--------------------------------------------------------------------------------
/visual-tree-search-app/components/ui/checkbox.tsx:
--------------------------------------------------------------------------------
1 | import * as React from "react"
2 | import * as CheckboxPrimitive from "@radix-ui/react-checkbox"
3 | import { CheckIcon } from "lucide-react"
4 |
5 | import { cn } from "@/lib/utils"
6 |
7 | function Checkbox({
8 | className,
9 | ...props
10 | }: React.ComponentProps) {
11 | return (
12 |
20 |
24 |
25 |
26 |
27 | )
28 | }
29 |
30 | export { Checkbox }
31 |
--------------------------------------------------------------------------------
/visual-tree-search-app/public/next.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/visual-tree-search-state-reset/app/api/routes/test_sql.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | from ...sql.ops import delete_data, restore_data, fetch_data
5 | from ...sql.verify import verifyAccountStatus
6 |
7 |
8 | import logging
9 | from fastapi import APIRouter, HTTPException
10 |
# Root-logger configuration applied when this module is imported.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

router = APIRouter()
18 |
19 |
def runcmd(cmd):
    """Run ``cmd`` through the system shell and return its captured stdout.

    The command and its output are echoed between separator lines, which helps
    when debugging.

    SECURITY: ``cmd`` is interpreted by the shell (``os.popen``), so it must
    never be built from untrusted input.
    """
    print('------------------------------------------------------------')
    print(cmd)
    # Use the pipe as a context manager so it is closed, and the child process
    # reaped, deterministically instead of whenever the GC collects it.
    with os.popen(cmd) as pipe:
        res = pipe.read()
    print(res)
    print('------------------------------------------------------------')
    return res
27 |
28 |
29 |
30 | # curl -N http://localhost:8000/api/sql/restore
@router.get("/restore")
async def sql_restore():
    """Delete the current data, then restore it; return both operation results.

    ``delete_data`` and ``restore_data`` are synchronous (called without
    ``await``), so they run in FastAPI's worker thread pool rather than
    blocking the event loop while the database work is in progress.

    NOTE(review): a GET request triggers a destructive operation here; confirm
    this is intended (caches/crawlers may issue GETs).

    Returns:
        dict with ``"delres"`` (result of ``delete_data``) and ``"resres"``
        (result of ``restore_data``).
    """
    # Local import keeps the module's top-level import block untouched.
    from fastapi.concurrency import run_in_threadpool

    delres = await run_in_threadpool(delete_data)
    resres = await run_in_threadpool(restore_data)
    return {
        "delres": delres,
        "resres": resres,
    }
39 |
40 |
41 | # curl -N http://localhost:8000/api/sql/extract
@router.get("/extract")
async def sql_extract():
    """Fetch the current data and return the result under ``"status"``.

    ``fetch_data`` is synchronous (called without ``await``), so it runs in
    FastAPI's worker thread pool instead of blocking the event loop.
    """
    # Local import keeps the module's top-level import block untouched.
    from fastapi.concurrency import run_in_threadpool

    res = await run_in_threadpool(fetch_data)
    return {
        "status": res,
    }
48 |
49 |
50 |
51 |
52 | # curl -N http://localhost:8000/api/sql/verify
@router.get("/verify")
async def sql_verify():
    """Run ``verifyAccountStatus`` and return its result under ``"status"``.

    The check is synchronous (called without ``await``), so it runs in
    FastAPI's worker thread pool instead of blocking the event loop.
    """
    # Local import keeps the module's top-level import block untouched.
    from fastapi.concurrency import run_in_threadpool

    res = await run_in_threadpool(verifyAccountStatus)
    return {
        "status": res,
    }
59 |
60 |
61 |
--------------------------------------------------------------------------------
/visual-tree-search-app/components/LiveBrowserView.tsx:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import { Globe } from 'lucide-react';
3 |
/** Props for the LiveBrowserView panel. */
interface LiveBrowserViewProps {
  // URL of the live browser session; null means no session yet, in which case
  // a placeholder message is shown instead.
  liveBrowserUrl: string | null;
  // Panel width, presumably a CSS width value — TODO confirm expected format.
  width: string;
}

/**
 * Panel titled "Live Browser View". When `liveBrowserUrl` is set it renders
 * the live browser; otherwise it shows "Browser view will appear here when
 * search starts".
 *
 * NOTE(review): this copy is incomplete. The JSX markup is missing (original
 * line numbers jump 11 → 15 and 23 → 28), and `React.FC` appears to have lost
 * its `LiveBrowserViewProps` type argument. Restore before editing.
 */
const LiveBrowserView: React.FC = ({ liveBrowserUrl, width }) => {
  return (
      Live Browser View
      {liveBrowserUrl ? (
      ) : (
          Browser view will appear here when search starts
      )}
  );
};
37 |
38 | export default LiveBrowserView;
39 |
--------------------------------------------------------------------------------
/visual-tree-search-app/components/Header.tsx:
--------------------------------------------------------------------------------
1 | import { Moon, Sun } from "lucide-react";
2 | import { Button } from "@/components/ui/button";
3 | import { useTheme } from "next-themes";
4 | import Link from "next/link";
5 |
/**
 * Page header. Reads the current theme and its setter from `next-themes`;
 * the imported Sun/Moon icons and Button suggest a light/dark toggle
 * (TODO confirm — the markup is not visible here).
 *
 * NOTE(review): the JSX markup is missing from this copy (original line
 * numbers jump 10 → 33), so `return ( )` is empty; restore before editing.
 */
const Header = () => {
  const { theme, setTheme } = useTheme();

  return (
  );
};
35 |
36 | export default Header;
--------------------------------------------------------------------------------
/.github/workflows/backend-ecs.yml:
--------------------------------------------------------------------------------
name: Backend Service Deploy to Amazon ECS

on:
  push:
    branches: [ "main" ]
    paths:
      - 'visual-tree-search-backend/**'
  workflow_dispatch:

env:
  AWS_REGION: us-east-1
  ECR_REPOSITORY: visual-tree-search-backend
  ECS_CLUSTER: fargate-cluster
  ECS_SERVICE: visual-tree-search-backend-service

permissions:
  contents: read

jobs:
  deploy:
    name: Deploy to ECS
    # x86-64 hosted runner; the ARM64 image is cross-built below via QEMU + Buildx.
    runs-on: ubuntu-latest
    environment: production

    steps:
      - name: Checkout
        uses: actions/checkout@v3

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v1
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          aws-region: ${{ env.AWS_REGION }}

      - name: Login to Amazon ECR
        id: login-ecr
        uses: aws-actions/amazon-ecr-login@v1

      # The runner is x86-64, so ARM64 emulation must be registered before
      # Buildx can execute the image's RUN steps for linux/arm64.
      - name: Set up QEMU
        uses: docker/setup-qemu-action@v2
        with:
          platforms: arm64

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v2

      - name: Build, tag, and push image to Amazon ECR
        env:
          ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }}
        run: |
          # Cross-build an ARM64 image and push it to ECR
          docker buildx build --platform linux/arm64 \
            --tag $ECR_REGISTRY/$ECR_REPOSITORY:latest \
            --push ./visual-tree-search-backend

      - name: Force new deployment of ECS service
        run: |
          aws ecs update-service --cluster ${{ env.ECS_CLUSTER }} --service ${{ env.ECS_SERVICE }} --force-new-deployment
--------------------------------------------------------------------------------
/.github/workflows/browser-service-ecs.yml:
--------------------------------------------------------------------------------
name: Browser Service Deploy to Amazon ECS

on:
  push:
    branches: [ "main" ]
    paths:
      - 'visual-tree-search-browser-service/**'
  workflow_dispatch:

env:
  AWS_REGION: us-east-1
  ECR_REPOSITORY: visual-tree-search-browser-service
  ECS_CLUSTER: fargate-cluster
  ECS_SERVICE: visual-tree-search-browser-service

permissions:
  contents: read

jobs:
  deploy:
    name: Deploy to ECS
    # x86-64 hosted runner; the ARM64 image is cross-built below via QEMU + Buildx.
    runs-on: ubuntu-latest
    environment: production

    steps:
      - name: Checkout
        uses: actions/checkout@v3

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v1
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          aws-region: ${{ env.AWS_REGION }}

      - name: Login to Amazon ECR
        id: login-ecr
        uses: aws-actions/amazon-ecr-login@v1

      # The runner is x86-64, so ARM64 emulation must be registered before
      # Buildx can execute the image's RUN steps for linux/arm64.
      - name: Set up QEMU
        uses: docker/setup-qemu-action@v2
        with:
          platforms: arm64

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v2

      - name: Build, tag, and push image to Amazon ECR
        env:
          ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }}
        run: |
          # Cross-build an ARM64 image and push it to ECR
          docker buildx build --platform linux/arm64 \
            --tag $ECR_REGISTRY/$ECR_REPOSITORY:latest \
            --push ./visual-tree-search-browser-service

      - name: Force new deployment of ECS service
        run: |
          aws ecs update-service --cluster ${{ env.ECS_CLUSTER }} --service ${{ env.ECS_SERVICE }} --force-new-deployment
--------------------------------------------------------------------------------
/visual-tree-search-state-reset/app/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from fastapi import FastAPI, WebSocket
4 | from fastapi.middleware.cors import CORSMiddleware
5 | import uvicorn
6 | from dotenv import load_dotenv
7 |
# Service-wide logging: INFO level, timestamped records.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Pull settings from a local .env file into the process environment.
load_dotenv()

# Trailing-slash redirects are turned off, so "/x" and "/x/" are not
# automatically redirected to one another.
app = FastAPI(redirect_slashes=False, title="Core API Backend")

# Fully open CORS policy: every origin, method and header, credentials allowed.
# NOTE(review): wildcard origins combined with credentials is very permissive;
# confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
32 |
# Root route: a static greeting showing the service is reachable.
# (Kept as a comment, not a docstring, so the generated OpenAPI text is unchanged.)
@app.get("/")
async def root():
    greeting = "Welcome to FastAPI backend"
    return {"message": greeting}
37 |
# Routers are imported only after ``app`` exists, to avoid circular imports.
from app.api.routes.hello import router as hello_router
from app.api.routes.test_db import router as test_db_router
from app.api.routes.test_container import router as test_container_router
from app.api.routes.test_sql import router as test_sql_router

# Mount each router under its URL prefix with a single OpenAPI tag, in this order.
for _router, _prefix, _tag in (
    (hello_router, "/api/hello", "hello"),
    (test_db_router, "/api/db", "database"),
    (test_container_router, "/api/container", "container"),
    (test_sql_router, "/api/sql", "sql"),
):
    app.include_router(_router, prefix=_prefix, tags=[_tag])
del _router, _prefix, _tag


if __name__ == "__main__":
    # Development entry point: PORT env var, defaulting to 3000, with auto-reload.
    port = int(os.getenv("PORT", 3000))
    uvicorn.run("app.main:app", host="0.0.0.0", port=port, reload=True)
--------------------------------------------------------------------------------
/visual-tree-search-app/components/ui/scroll-area.tsx:
--------------------------------------------------------------------------------
1 | import * as React from "react"
2 | import * as ScrollAreaPrimitive from "@radix-ui/react-scroll-area"
3 |
4 | import { cn } from "@/lib/utils"
5 |
/**
 * Scrollable container, presumably built on the imported Radix
 * `ScrollAreaPrimitive` and rendering `children` inside it.
 *
 * NOTE(review): the JSX is missing from this copy (original line numbers jump
 * 12 → 17 → 21), and `React.ComponentProps` appears to have lost its type
 * argument; restore before editing.
 */
function ScrollArea({
  className,
  children,
  ...props
}: React.ComponentProps) {
  return (
      {children}
  )
}

/**
 * Scrollbar for ScrollArea; `orientation` defaults to "vertical".
 *
 * NOTE(review): the JSX is missing from this copy (original line numbers jump
 * 35 → 48 → 52); restore before editing.
 */
function ScrollBar({
  className,
  orientation = "vertical",
  ...props
}: React.ComponentProps) {
  return (
  )
}
55 |
56 | export { ScrollArea, ScrollBar }
57 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/webagent_utils_async/browser_env/javascript/frame_unmark_elements.js:
--------------------------------------------------------------------------------
/**
 * Go through all DOM elements in the frame (including shadowDOMs),
 * and cleanup previously stored data in the aria-roledescription attribute.
 *
 * A marked value carries a prefix ending in "_" (the browsergym element id);
 * everything up to and including the first "_" is stripped. If no "_" is
 * present the value is kept as-is; if nothing remains the attribute is removed.
 */
() => {
    // get all DOM elements in the current frame (does not include elements in shadowDOMs)
    const elements = Array.from(document.querySelectorAll('*'));
    let i = 0;
    while (i < elements.length) {
        const elem = elements[i];
        // Insert shadowDOM elements right after their host, so they are visited
        // next (and their own shadow roots get expanded too).
        // splice() inserts in place instead of re-spreading the whole element
        // list as call arguments, which can overflow argument limits on large pages.
        if (elem.shadowRoot !== null) {
            const shadowElements = Array.from(elem.shadowRoot.querySelectorAll("*"));
            elements.splice(i + 1, 0, ...shadowElements);
        }
        i++;
        // Hack: remove custom data stored inside the aria-roledescription tag
        // - elem_global_id: global browsergym identifier
        if (elem.hasAttribute("aria-roledescription")) {
            const content = elem.getAttribute("aria-roledescription");
            // TODO: handle more data if needed
            const n_data_items = 1; // bid
            let post_data_index = 0;
            for (let j = 0; j < n_data_items; j++) {
                // indexOf returns -1 when absent, so a missing "_" yields 0 (no strip).
                post_data_index = content.indexOf("_", post_data_index) + 1;
            }
            // Declared locally: previously this was an implicit global, which
            // leaks into the page and throws a ReferenceError in strict mode.
            const original_content = content.substring(post_data_index);
            if (original_content) {
                elem.setAttribute("aria-roledescription", original_content);
            }
            else {
                elem.removeAttribute("aria-roledescription");
            }
        }
    }
}
41 |
--------------------------------------------------------------------------------
/visual-tree-search-app/README.md:
--------------------------------------------------------------------------------
1 | This is a [Next.js](https://nextjs.org) project bootstrapped with [`create-next-app`](https://nextjs.org/docs/pages/api-reference/create-next-app).
2 |
3 | ## Getting Started
4 |
5 | First, run the development server:
6 |
7 | ```bash
8 | npm run dev
9 | # or
10 | yarn dev
11 | # or
12 | pnpm dev
13 | # or
14 | bun dev
15 | ```
16 |
17 | Open [http://localhost:3000](http://localhost:3000) with your browser to see the result.
18 |
19 | You can start editing the page by modifying `pages/index.tsx`. The page auto-updates as you edit the file.
20 |
21 | [API routes](https://nextjs.org/docs/pages/building-your-application/routing/api-routes) can be accessed on [http://localhost:3000/api/hello](http://localhost:3000/api/hello). This endpoint can be edited in `pages/api/hello.ts`.
22 |
23 | The `pages/api` directory is mapped to `/api/*`. Files in this directory are treated as [API routes](https://nextjs.org/docs/pages/building-your-application/routing/api-routes) instead of React pages.
24 |
25 | This project uses [`next/font`](https://nextjs.org/docs/pages/building-your-application/optimizing/fonts) to automatically optimize and load [Geist](https://vercel.com/font), a new font family for Vercel.
26 |
27 | ## Learn More
28 |
29 | To learn more about Next.js, take a look at the following resources:
30 |
31 | - [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API.
32 | - [Learn Next.js](https://nextjs.org/learn-pages-router) - an interactive Next.js tutorial.
33 |
34 | You can check out [the Next.js GitHub repository](https://github.com/vercel/next.js) - your feedback and contributions are welcome!
35 |
36 | ## Deploy on Vercel
37 |
38 | The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js.
39 |
40 | Check out our [Next.js deployment documentation](https://nextjs.org/docs/pages/building-your-application/deploying) for more details.
41 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/webagent_utils_async/evaluation/evaluators.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 | import math
3 | import re
4 |
5 |
# Structured-output schema passed as ``response_format`` by
# goal_finished_evaluator: the model's reply is parsed into one boolean field.
# Documented with comments rather than a docstring on purpose: pydantic would
# turn a class docstring into the schema "description" sent to the API.
class Plan(BaseModel):
    goal_finished: bool  # the model's verdict on whether the goal is complete
8 |
9 |
def parse_oai_logprob(response):
    """Return the probability of the whole first completion, from its token logprobs.

    The per-token log-probabilities of ``response.choices[0]`` are summed and
    exponentiated, and the result is rounded to 5 decimals. If anything goes
    wrong (e.g. the response carries no logprobs), a message is printed and
    ``None`` is returned.
    """
    try:
        token_logprobs = response.choices[0].logprobs.content
        total = sum(token.logprob for token in token_logprobs)
        return round(math.exp(total), 5)
    except Exception as e:
        print(f"An error occurred when checking oai logprob: {e}")
        return None
23 | extracted_text = match.group(1)
24 | return extracted_text # Output: click('249')
25 | else:
26 | raise Exception("No exact action found.")
27 |
28 |
29 | def goal_finished_evaluator(messages, openai_client):
30 | new_response = openai_client.beta.chat.completions.parse(
31 | model='gpt-4o-mini',
32 | messages=messages,
33 | response_format=Plan,
34 | logprobs=True,
35 | )
36 | message = new_response.choices[0].message.parsed
37 | confidence_score = parse_oai_logprob(new_response)
38 |
39 | goal_finished = message.goal_finished
40 | return goal_finished, confidence_score
41 |
42 |
43 | def goal_finished_value_function():
44 | pass
45 |
46 |
47 | def early_stop(
48 | trajectory: list, action_set: dict, max_steps: int, thresholds: dict[str, int]
49 | ) -> tuple[bool, str]:
50 | """Check whether need to stop early"""
51 |
52 | # reach the max step
53 | num_steps = (len(trajectory) - 1) / 2
54 | if num_steps >= max_steps:
55 | return True, f"Reach max steps {max_steps}"
56 |
57 | last_k_actions = []
58 | action_seq = []
59 |
60 | # Case: same action for k times
61 | k = thresholds["repeating_action"]
62 | if len(trajectory) >= k:
63 | last_k_actions = [extract_action(tra['action']) for tra in trajectory[-k:]]
64 | last_action = last_k_actions[-1]
65 | if (sum([action == last_action for action in
66 | last_k_actions]) > k): # not >=k as last_action in last_k_actions as well
67 | return True, f"Same typing action for {k} times"
68 |
--------------------------------------------------------------------------------
/visual-tree-search-app/components/ui/button.tsx:
--------------------------------------------------------------------------------
1 | import * as React from "react"
2 | import { Slot } from "@radix-ui/react-slot"
3 | import { cva, type VariantProps } from "class-variance-authority"
4 |
5 | import { cn } from "@/lib/utils"
6 |
/**
 * `cva` class-name recipe for buttons: a shared base class list plus
 * `variant` (default | destructive | outline | secondary | ghost | link) and
 * `size` (default | sm | lg | icon) options; both default to "default".
 */
const buttonVariants = cva(
  "inline-flex items-center justify-center gap-2 whitespace-nowrap rounded-md text-sm font-medium transition-all disabled:pointer-events-none disabled:opacity-50 [&_svg]:pointer-events-none [&_svg:not([class*='size-'])]:size-4 shrink-0 [&_svg]:shrink-0 outline-none focus-visible:border-ring focus-visible:ring-ring/50 focus-visible:ring-[3px] aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive",
  {
    variants: {
      variant: {
        default:
          "bg-primary text-primary-foreground shadow-xs hover:bg-primary/90",
        destructive:
          "bg-destructive text-white shadow-xs hover:bg-destructive/90 focus-visible:ring-destructive/20 dark:focus-visible:ring-destructive/40 dark:bg-destructive/60",
        outline:
          "border bg-background shadow-xs hover:bg-accent hover:text-accent-foreground dark:bg-input/30 dark:border-input dark:hover:bg-input/50",
        secondary:
          "bg-secondary text-secondary-foreground shadow-xs hover:bg-secondary/80",
        ghost:
          "hover:bg-accent hover:text-accent-foreground dark:hover:bg-accent/50",
        link: "text-primary underline-offset-4 hover:underline",
      },
      size: {
        default: "h-9 px-4 py-2 has-[>svg]:px-3",
        sm: "h-8 rounded-md gap-1.5 px-3 has-[>svg]:px-2.5",
        lg: "h-10 rounded-md px-6 has-[>svg]:px-4",
        icon: "size-9",
      },
    },
    defaultVariants: {
      variant: "default",
      size: "default",
    },
  }
)

/**
 * Button component. `Comp` is the Radix `Slot` when `asChild` is true,
 * otherwise a native "button".
 *
 * NOTE(review): the returned JSX is missing from this copy (original line
 * numbers jump 51 → 56), and `VariantProps` appears to have lost its type
 * argument (presumably `typeof buttonVariants`); restore before editing.
 */
function Button({
  className,
  variant,
  size,
  asChild = false,
  ...props
}: React.ComponentProps<"button"> &
  VariantProps & {
    asChild?: boolean
  }) {
  const Comp = asChild ? Slot : "button"

  return (
  )
}
58 |
59 | export { Button, buttonVariants }
60 |
--------------------------------------------------------------------------------
/visual-tree-search-app/components/ui/card.tsx:
--------------------------------------------------------------------------------
1 | import * as React from "react"
2 |
3 | import { cn } from "@/lib/utils"
4 |
/**
 * Card layout primitives; each accepts div props plus a `className`
 * (presumably merged via the imported `cn` — TODO confirm).
 *
 * NOTE(review): every `return ( )` below is empty because the JSX was lost
 * in this copy (original line numbers jump inside each return); restore from
 * the source component before editing.
 */
function Card({ className, ...props }: React.ComponentProps<"div">) {
  return (
  )
}

// Card section intended for the title/description area.
function CardHeader({ className, ...props }: React.ComponentProps<"div">) {
  return (
  )
}

// Card title element.
function CardTitle({ className, ...props }: React.ComponentProps<"div">) {
  return (
  )
}

// Secondary descriptive text for the card.
function CardDescription({ className, ...props }: React.ComponentProps<"div">) {
  return (
  )
}

// Slot for card-level actions.
function CardAction({ className, ...props }: React.ComponentProps<"div">) {
  return (
  )
}

// Main card body.
function CardContent({ className, ...props }: React.ComponentProps<"div">) {
  return (
  )
}

// Card footer area.
function CardFooter({ className, ...props }: React.ComponentProps<"div">) {
  return (
  )
}

export {
  Card,
  CardHeader,
  CardFooter,
  CardTitle,
  CardAction,
  CardDescription,
  CardContent,
}
93 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/test/test-ws.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | import websockets
4 | import logging
5 | import uuid
6 |
# Set up logging to see more details
logging.basicConfig(level=logging.INFO)


async def test_websocket():
    """Manually smoke-test the backend WebSocket at ws://localhost:3000/ws.

    Sends a ``ping`` and a ``hello`` message, then prints every message
    received (pretty-printed when it is JSON). It stops when ``timeout``
    seconds have passed, the server closes the connection, or another error
    occurs. The reply to the ping is bounded by the same timeout, so the script
    no longer hangs forever if the server never answers.
    """
    uri = "ws://localhost:3000/ws"
    # Set a timeout to prevent hanging indefinitely
    timeout = 60  # 1 minute

    print(f"Connecting to {uri}")

    async with websockets.connect(uri) as websocket:
        print("Connected to WebSocket")

        # Send a ping message
        await websocket.send(json.dumps({
            "type": "ping"
        }))

        # Wait (bounded) for the ping response
        try:
            response = await asyncio.wait_for(websocket.recv(), timeout=timeout)
        except asyncio.TimeoutError:
            print("Timeout waiting for response from WebSocket")
            return
        print(f"Received: {response}")

        # You can send other test messages here
        # For example, if the websocket supports it:
        await websocket.send(json.dumps({
            "type": "hello",
            "message": "Testing the websocket connection"
        }))

        loop = asyncio.get_running_loop()
        start_time = loop.time()

        # Continuously receive messages until the shared deadline passes
        while True:
            try:
                elapsed = loop.time() - start_time
                remaining = max(0, timeout - elapsed)

                if remaining <= 0:
                    print("Timeout reached. No more messages received.")
                    break

                response = await asyncio.wait_for(websocket.recv(), timeout=remaining)
                print(f"Received: {response}")

                try:
                    data = json.loads(response)
                    print(f"Message type: {data.get('type', 'unknown')}")
                    print(json.dumps(data, indent=2))
                except json.JSONDecodeError:
                    print("Received non-JSON message")

            except asyncio.TimeoutError:
                print("Timeout waiting for response from WebSocket")
                break
            except websockets.exceptions.ConnectionClosed:
                print("Connection closed")
                break
            except Exception as e:
                print(f"Error: {e}")
                break


# Only connect when executed as a script, not when the module is imported.
if __name__ == "__main__":
    asyncio.run(test_websocket())
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | # VisualTreeSearch-Demo License
2 |
3 | VisualTreeSearch-Demo is a project that builds upon the LiteWebAgent and BrowserGym projects, incorporating modifications and new contributions to enhance web automation and interaction.
4 |
5 | ## License
6 |
7 | Copyright 2024 PathOnAI.org
8 |
9 | Licensed under the Apache License, Version 2.0 (the "License");
10 | you may not use this file except in compliance with the License.
11 | You may obtain a copy of the License at:
12 |
13 | http://www.apache.org/licenses/LICENSE-2.0
14 |
15 | Unless required by applicable law or agreed to in writing, software
16 | distributed under the License is distributed on an "AS IS" BASIS,
17 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 | See the License for the specific language governing permissions and
19 | limitations under the License.
20 |
21 | ## Third-Party Licenses
22 |
23 | This project includes code inherited from the LiteWebAgent project, which also incorporates code from BrowserGym. Both projects are licensed under the Apache License, Version 2.0.
24 |
25 | ### BrowserGym
26 |
27 | Copyright 2024 ServiceNow
28 |
29 | You may obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0
30 |
31 | The original BrowserGym code is available at: https://github.com/ServiceNow/BrowserGym
32 |
33 | ### LiteWebAgent
34 |
35 | Modifications to the original BrowserGym code within LiteWebAgent include:
36 | * Reuse of action and observation modules
37 | * Extensive modifications to `action/highlevel.py` for improved frontend visualization
38 | * Introduction of `observation/extract_elements.py` to enable extraction of interactive elements
39 |
40 | ## Additional Modifications in VisualTreeSearch-Demo
41 |
42 | Further modifications and new contributions following the LiteWebAgent project are attributed to **Danqing Zhang** (danqing.zhang.personal@gmail.com). Permission must be obtained from Danqing Zhang for use of these specific contributions.
43 |
44 | For inquiries regarding the use of Danqing Zhang's contributions, please contact:
45 | **Danqing Zhang**
46 | Email: danqing.zhang.personal@gmail.com
47 |
48 | ## Contributing
49 |
50 | Please read the license information carefully before contributing to or using this project. Ensure you have the necessary permissions, especially when working with components attributed to specific individuals or organizations.
51 |
52 | ## Contact
53 |
54 | For general inquiries about the project, please refer to the PathOnAI.org copyright notice.
55 | For specific questions about additional modifications, contact Danqing Zhang at the email provided above.
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/run_demo_treesearch_async.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from dotenv import load_dotenv
3 | import json
4 | import logging
5 | import asyncio
6 |
7 | from lwats.core_async.config import AgentConfig, add_agent_config_arguments, filter_valid_config_args
8 | load_dotenv()
9 | from lwats.core_async.agent_factory import setup_search_agent
10 |
11 |
async def main(args):
    """Build the configured search agent, run it, and return its results.

    Args:
        args: Parsed CLI namespace providing ``starting_url``, ``agent_type``,
            ``goal`` and ``images``, plus the agent options picked out by
            ``filter_valid_config_args``.

    Returns:
        The value of ``agent.run()``, or ``{"error": "starting_url is required"}``
        when no starting URL was supplied.
    """
    # Log the arguments to help debug
    logging.info(f"Running tree search with args: {args.__dict__}")

    # Ensure starting_url is set correctly
    if not hasattr(args, 'starting_url') or not args.starting_url:
        logging.error("starting_url is not set or is empty")
        return {"error": "starting_url is required"}

    logging.info(f"Using starting URL: {args.starting_url}")

    agent_config = AgentConfig(**filter_valid_config_args(args.__dict__))
    print(agent_config)

    agent, playwright_manager = await setup_search_agent(
        agent_type=args.agent_type,
        starting_url=args.starting_url,
        goal=args.goal,
        images=args.images,
        agent_config=agent_config
    )

    try:
        # Run the search
        results = await agent.run()
    finally:
        # Close the playwright_manager even when the search raises, so its
        # browser resources are not leaked.
        await playwright_manager.close()

    return results
41 |
if __name__ == "__main__":
    # When running as a script, we need to use absolute imports
    # NOTE(review): this runs after the module-level ``from lwats...`` imports
    # have already executed, and it appends app/ (two levels above this file),
    # not app/api/ where ``lwats`` lives, so it likely has no effect. Confirm
    # whether it is still needed.
    import sys
    import os
    script_dir = os.path.dirname(os.path.abspath(__file__))
    sys.path.append(os.path.dirname(script_dir))

    parser = argparse.ArgumentParser(description="Run web agent with specified configuration")
    parser.add_argument("--agent-type", type=str, default="LATSAgent",
                        help="Type of agent to use (default: LATSAgent)")

    # Task options: (flag, default, help), all plain strings.
    task_options = (
        ("--starting-url", "http://xwebarena.pathonai.org:7770/",
         "Starting URL for the web agent"),
        ("--goal", "search running shoes, click on the first result",
         "Goal for the web agent to accomplish"),
        ("--images", "",
         "Comma-separated paths to image files (e.g., 'path1.jpg,path2.jpg')"),
    )
    for flag, default, help_text in task_options:
        parser.add_argument(flag, type=str, default=default, help=help_text)

    add_agent_config_arguments(parser)
    args = parser.parse_args()

    # Turn the comma-separated --images value into a list of trimmed paths.
    raw_images = args.images
    args.images = [path.strip() for path in raw_images.split(',')] if raw_images else []

    results = asyncio.run(main(args))
    print(results)
66 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/evaluation_async/evaluators.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 | import math
3 | import re
4 | from ..webagent_utils_async.evaluation.evaluators import parse_oai_logprob
5 |
6 | import base64
7 |
# Structured-output schema passed as ``response_format`` by
# goal_finished_evaluator: the model's verdict is parsed into one boolean.
# Documented with comments rather than a docstring on purpose: pydantic would
# turn a class docstring into the schema "description" sent to the API.
class IsGoalFinished(BaseModel):
    goal_finished: bool  # True when the model judges the overall goal complete
10 |
def _image_mime_type(image_bytes):
    """Guess an image's MIME type from its leading magic bytes (JPEG if unknown)."""
    if image_bytes.startswith(b'\x89PNG\r\n\x1a\n'):
        return 'image/png'
    if image_bytes.startswith((b'GIF87a', b'GIF89a')):
        return 'image/gif'
    if image_bytes[:4] == b'RIFF' and image_bytes[8:12] == b'WEBP':
        return 'image/webp'
    return 'image/jpeg'


def goal_finished_evaluator(messages, openai_client, goal, screenshot, model='gpt-4o-mini'):
    """Ask a vision model whether *goal* is accomplished, given history and a final screenshot.

    Args:
        messages: Prior chat messages, inserted between the system prompt and
            the final screenshot question.
        openai_client: OpenAI client exposing ``beta.chat.completions.parse``.
        goal: Natural-language task goal.
        screenshot: Raw image bytes of the final page state.
        model: Model name to query.

    Returns:
        ``(goal_finished, confidence_score)``: the parsed boolean verdict and the
        completion probability from ``parse_oai_logprob`` (``None`` on failure).
    """
    system_message = """You are an AI assistant evaluating the progress of a web browsing task. Your role is to determine if the overall goal of the task has been accomplished based on the actions taken and the conversation history.

Guidelines for determining if the goal is finished:
1. Review the list of actions taken during the web browsing session.
2. Compare these actions to the stated goal of the task.
3. Consider if the necessary information has been found or if the required interactions have been completed.
4. Look for indicators of task completion, such as finding specific information, completing a purchase, or submitting a form.
5. If the last actions suggest that the main objective has been achieved, consider the goal finished.
6. If there are clear next steps that haven't been taken, the goal may not be finished.

Remember:
- Web browsing tasks often involve multiple steps like searching, navigating pages, filling forms, or extracting information.
- The goal is finished when all necessary actions to complete the task have been taken.
- If the actions deviate significantly from the goal without resolution, the goal may not be finished.

Respond with 'goal_finished: true' if you determine the goal has been accomplished, or 'goal_finished: false' if it's still in progress or incomplete."""

    # screenshot bytes -> data URL; label it with the real image format instead
    # of always claiming JPEG (browser screenshots are often PNG).
    base64_image = base64.b64encode(screenshot).decode('utf-8')
    mime_type = _image_mime_type(screenshot)
    new_response = openai_client.beta.chat.completions.parse(
        model=model,
        messages=[
            {"role": "system", "content": system_message},
            *messages,
            {"role": "user",
             "content": [
                 {"type": "text", "text": f"The final screenshot is as in the image, and the goal is {goal}, Is the overall goal finished?"},
                 {"type": "image_url",
                  "image_url": {
                      "url": f"data:{mime_type};base64,{base64_image}",
                      "detail": "high"
                  }
                  }
             ]
             },
        ],
        response_format=IsGoalFinished,
        logprobs=True,
    )
    message = new_response.choices[0].message.parsed
    confidence_score = parse_oai_logprob(new_response)

    goal_finished = message.goal_finished
    return goal_finished, confidence_score
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from fastapi import FastAPI, WebSocket
4 | from fastapi.middleware.cors import CORSMiddleware
5 | import uvicorn
6 | from dotenv import load_dotenv
7 |
# Service-wide logging: INFO level, timestamped records.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Pull settings from a local .env file into the process environment.
load_dotenv()

# Trailing-slash redirects are turned off, so "/x" and "/x/" are not
# automatically redirected to one another.
app = FastAPI(redirect_slashes=False, title="Core API Backend")

# Fully open CORS policy: every origin, method and header, credentials allowed.
# NOTE(review): wildcard origins combined with credentials is very permissive;
# confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
31 |
# Root route: a static greeting showing the service is reachable.
# (Kept as a comment, not a docstring, so the generated OpenAPI text is unchanged.)
@app.get("/")
async def root():
    greeting = "Welcome to FastAPI backend"
    return {"message": greeting}
36 |
# Routers are imported only after ``app`` exists, to avoid circular imports.
# NOTE(review): routes/tree_search.py and routes/tree_search_websocket.py do not
# appear in the repository tree listing (routes/ shows hello, terminate_session,
# sse, websocket, tree_traversal, tree_websocket); confirm these modules exist.
from app.api.routes.hello import router as hello_router
from app.api.routes.sse import router as sse_router
from app.api.routes.websocket import router as ws_router
from app.api.routes.tree_websocket import router as tree_ws_router
from app.api.routes.tree_search import router as tree_search_router
from app.api.routes.tree_search_websocket import router as tree_search_ws_router
from app.api.routes.terminate_session import router as terminate_session_router

# Mount each router under its URL prefix with a single OpenAPI tag, in this order.
for _router, _prefix, _tag in (
    (hello_router, "/api/hello", "hello"),
    (sse_router, "/api/sse", "sse"),
    (ws_router, "/api/ws", "websocket"),
    (tree_ws_router, "/api/tree", "tree"),
    (tree_search_router, "/api/tree-search", "tree-search"),
    (tree_search_ws_router, "/api/tree-search-ws", "tree-search-ws"),
    (terminate_session_router, "/api/terminate-session", "terminate-session"),
):
    app.include_router(_router, prefix=_prefix, tags=[_tag])
del _router, _prefix, _tag

# Top-level WebSocket paths that delegate to the route modules' handlers.
from app.api.routes.websocket import websocket_endpoint
from app.api.routes.tree_websocket import tree_websocket_endpoint
from app.api.routes.tree_search_websocket import tree_search_websocket_endpoint


@app.websocket("/ws")
async def websocket_route(websocket: WebSocket):
    """Serve /ws by delegating to ``websocket_endpoint``."""
    await websocket_endpoint(websocket)


@app.websocket("/tree-ws")
async def tree_websocket_route(websocket: WebSocket):
    """Serve /tree-ws by delegating to ``tree_websocket_endpoint``."""
    await tree_websocket_endpoint(websocket)


@app.websocket("/tree-search-ws")
async def tree_search_websocket_route(websocket: WebSocket):
    """Serve /tree-search-ws by delegating to ``tree_search_websocket_endpoint``."""
    await tree_search_websocket_endpoint(websocket)


if __name__ == "__main__":
    # Development entry point: PORT env var, defaulting to 3000, with auto-reload.
    port = int(os.getenv("PORT", 3000))
    uvicorn.run("app.main:app", host="0.0.0.0", port=port, reload=True)
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/webagent_utils_async/tools/registry.py:
--------------------------------------------------------------------------------
import importlib
import logging
from typing import Dict, Any, Callable, List
from typing import Optional
3 |
4 |
class Tool:
    """A named callable plus the metadata used to describe it to a model.

    ``parameters`` maps each parameter name to a dict of schema details.
    ``ToolRegistry.get_tool_description`` places these under ``properties``
    and reads an optional per-parameter ``"required"`` flag.
    """

    def __init__(self, name: str, func: Callable, description: str, parameters: Dict[str, Any]):
        self.name, self.func = name, func
        self.description, self.parameters = description, parameters
11 |
12 |
class ToolRegistry:
    """Process-wide registry of Tool objects, keyed by tool name.

    Tools are stored in a class-level dict, so the classmethods work without an
    instance. ``ToolRegistry()`` always returns the same singleton. The first
    instantiation also imports each bundled tool module and calls its
    ``register_*_tool`` function.
    """

    _instance = None
    # Annotations are quoted so the class body does not evaluate ``Tool`` eagerly.
    _tools: Dict[str, "Tool"] = {}
    _logger = logging.getLogger(__name__)

    # (sibling module, registration function) pairs loaded by _register_all_tools.
    _TOOL_MODULES = (
        ("navigation", "register_navigation_tool"),
        ("select_option", "register_select_option_tool"),
        ("upload_file", "register_upload_file_tool"),
        ("webscraping", "register_webscraping_tool"),
    )

    def __new__(cls):
        if cls._instance is None:
            # Publish the instance before registering. A re-entrant
            # ToolRegistry() call made during registration then gets the
            # singleton instead of starting registration again.
            cls._instance = super().__new__(cls)
            cls._register_all_tools()
        return cls._instance

    @classmethod
    def register(cls, tool: "Tool") -> None:
        """Store *tool* under ``tool.name``, replacing any tool with the same name."""
        cls._logger.debug("Registering tool: %s", tool.name)
        cls._tools[tool.name] = tool

    @classmethod
    def get_tool(cls, name: str) -> Optional["Tool"]:
        """Return the tool registered as *name*, or ``None`` if there is none."""
        return cls._tools.get(name)

    @classmethod
    def get_all_tools(cls) -> Dict[str, "Tool"]:
        """Return the live name -> Tool mapping (not a copy)."""
        return cls._tools

    @classmethod
    def get_tool_description(cls, name: str) -> Optional[Dict[str, Any]]:
        """Describe tool *name* as a ``{"type": "function", ...}`` tool schema.

        Returns:
            ``None`` if no tool is registered as *name*. Otherwise a dict whose
            ``function.parameters`` is an object schema. Its ``properties`` are
            the tool's parameter details minus the per-parameter ``"required"``
            flag. Its ``required`` list names every parameter flagged truthy.
        """
        tool = cls.get_tool(name)
        if tool is None:
            return None

        # "required" is a per-parameter flag in Tool.parameters. The object
        # schema expects it as a list instead, so strip it from each property.
        properties = {
            param: {key: value for key, value in details.items() if key != "required"}
            for param, details in tool.parameters.items()
        }
        required_params = [
            param for param, details in tool.parameters.items() if details.get("required", False)
        ]

        return {
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required_params,  # always a list, possibly empty
                },
            },
        }

    @classmethod
    def _register_all_tools(cls) -> None:
        """Import each bundled tool module and call its registration function.

        Each module is handled independently, so a failing import or
        registration no longer prevents the remaining tools from registering.
        Failures are logged with a traceback, not raised.
        """
        for module_name, func_name in cls._TOOL_MODULES:
            try:
                # Equivalent to ``from .<module_name> import <func_name>``.
                module = importlib.import_module(f".{module_name}", package=__package__)
                getattr(module, func_name)()
            except Exception:
                cls._logger.exception("Error while registering tools from %r", module_name)
77 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/README.md:
--------------------------------------------------------------------------------
1 | # Backend
2 | ## 1. test playwright_manager
3 | ```
4 | cd visual-tree-search-backend/app/api/lwats/webagent_utils_async/utils
5 | python playwright_manager.py
6 | ```
7 | This script is also used to renew the cookies stored in `app/api/shopping.json`.
8 |
9 | ## 2. Run demo tree search
10 | chromium mode
11 | ```
12 | python run_demo_treesearch_async.py \
13 | --browser-mode chromium \
14 | --storage-state shopping.json \
15 | --starting-url "http://xwebarena.pathonai.org:7770/" \
16 | --agent-type "SimpleSearchAgent" \
17 | --action_generation_model "gpt-4o-mini" \
18 | --goal "search running shoes, click on the first result" \
19 | --iterations 3 \
20 | --max_depth 3 \
21 | --search_algorithm bfs
22 | ```
23 |
24 | browserbase mode
25 | ```
26 | python run_demo_treesearch_async.py \
27 | --browser-mode browserbase \
28 | --storage-state shopping.json \
29 | --starting-url "http://xwebarena.pathonai.org:7770/" \
30 | --agent-type "SimpleSearchAgent" \
31 | --action_generation_model "gpt-4o-mini" \
32 | --goal "search running shoes, click on the first result" \
33 | --iterations 3 \
34 | --max_depth 3 \
35 | --search_algorithm bfs
36 | ```
37 |
38 |
39 | ## 3. test websocket
40 | ```
41 | uvicorn app.main:app --host 0.0.0.0 --port 3000
42 |
43 | python test/test-tree-search-ws-simple.py --algorithm dfs
44 | python test/test-tree-search-ws-simple.py --algorithm bfs
45 | ```
46 |
47 | ## 4. end-to-end test with frontend
48 | ```
49 | backend: uvicorn app.main:app --host 0.0.0.0 --port 3000
50 | frontend: npm run dev -- -p 3001
51 | then go to http://localhost:3001/tree-search-playground
52 | to test the message passing from the backend to the frontend
53 | ```
54 |
55 |
56 |
57 | ## 5. terminate session from backend
58 | ```
59 | curl -X POST http://localhost:3000/api/terminate-session/647f4021-2402-4733-84a3-255f0d20c151
60 | {"status":"success","message":"Session 647f4021-2402-4733-84a3-255f0d20c151 termination requested"}
61 | ```
62 |
63 | ## 6. Add more search agents
64 | ```
65 | python run_demo_treesearch_async.py \
66 | --browser-mode chromium \
67 | --storage-state shopping.json \
68 | --starting-url "http://xwebarena.pathonai.org:7770/" \
69 | --agent-type "LATSAgent" \
70 | --action_generation_model "gpt-4o-mini" \
71 | --goal "search running shoes, click on the first result" \
72 | --iterations 3 \
73 | --max_depth 3
74 | ```
75 |
76 | ```
77 | python run_demo_treesearch_async.py \
78 | --browser-mode chromium \
79 | --storage-state shopping.json \
80 | --starting-url "http://xwebarena.pathonai.org:7770/" \
81 | --agent-type "MCTSAgent" \
82 | --action_generation_model "gpt-4o-mini" \
83 | --goal "search running shoes, click on the first result" \
84 | --iterations 3 \
85 | --max_depth 3
86 | ```
87 |
88 | ## 7. Add LATS agent
89 | * test run_demo_treesearch_async.py
90 | * test web socket
91 | ```
92 | uvicorn app.main:app --host 0.0.0.0 --port 3000
93 | python test/test-tree-search-ws-lats.py
94 | ```
95 |
96 | ## 8. Add MCTS agent
97 | * test run_demo_treesearch_async.py
98 | * test web socket
99 | ```
100 | uvicorn app.main:app --host 0.0.0.0 --port 3000
101 | python test/test-tree-search-ws-mcts.py
102 | ```
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/routes/sse.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | from datetime import datetime
4 | from typing import Dict, List, Any
5 | import logging
6 |
# Root logging is configured as a side effect of importing this module.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse

router = APIRouter()

# One asyncio.Queue per connected SSE client; broadcasts are put on every queue.
clients: List[asyncio.Queue] = []

# Background heartbeat task (ping_clients); stays None until the first client connects.
ping_task = None
23 |
async def send_event_to_all(event_data: Dict[str, Any]):
    """Broadcast *event_data* by enqueueing it on each connected client's queue."""
    for client_queue in clients:
        await client_queue.put(event_data)
28 |
async def ping_clients():
    """Broadcast a heartbeat event to all SSE clients about once per second.

    Loops forever and sends nothing while ``clients`` is empty.
    ``sse_endpoint`` starts it as a background task.
    """
    from datetime import timezone  # the module imports only the datetime class

    while True:
        if clients:
            event_data = {
                "type": "ping",
                "message": "Server heartbeat",
                # Naive-UTC ISO string in the same format datetime.utcnow()
                # produced, without the call deprecated in Python 3.12.
                "timestamp": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
            }
            await send_event_to_all(event_data)
            logging.info(f"Sent ping to {len(clients)} clients")
        await asyncio.sleep(1)
41 |
async def sse_generator(request: Request, client_queue: asyncio.Queue):
    """Yield queued messages as SSE ``data:`` frames until the client disconnects.

    Dict messages are JSON-encoded; anything else is interpolated as-is.
    On exit, for any reason, the queue is removed from ``clients``.
    """
    try:
        while not await request.is_disconnected():
            # Waits for the next broadcast; the heartbeat keeps this loop turning.
            payload = await client_queue.get()
            text = json.dumps(payload) if isinstance(payload, dict) else payload
            yield f"data: {text}\n\n"
    finally:
        if client_queue in clients:
            clients.remove(client_queue)
            logging.info(f"Client disconnected - remaining clients: {len(clients)}")
63 |
@router.get("/events")
async def sse_endpoint(request: Request):
    """Register a new SSE client and stream its queue as ``text/event-stream``."""
    global ping_task

    queue = asyncio.Queue()
    clients.append(queue)

    # Lazily start the shared heartbeat, restarting it if it has finished.
    if ping_task is None or ping_task.done():
        ping_task = asyncio.create_task(ping_clients())

    logging.info(f"Client connected - total clients: {len(clients)}")

    # The first frame every client receives.
    await queue.put({"type": "connection", "message": "Connected to SSE"})

    stream_headers = {"Cache-Control": "no-cache", "Connection": "keep-alive"}
    return StreamingResponse(
        sse_generator(request, queue),
        media_type="text/event-stream",
        headers=stream_headers,
    )
93 |
@router.get("/status")
async def sse_status():
    """Report how many SSE clients are currently connected."""
    status = {"active": len(clients), "status": "running"}
    return status
--------------------------------------------------------------------------------
/visual-tree-search-state-reset/app/api/routes/test_container.py:
--------------------------------------------------------------------------------
1 | import traceback
2 | import logging
3 | from typing import Dict, List, Any
4 | import json
5 | import docker,time,os,sys
6 |
7 | from fastapi import APIRouter, HTTPException
8 |
# Root logging setup, applied when this module is imported.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

router = APIRouter()
16 |
17 |
def runsleep(dur, tik=5):
    """Block for about *dur* seconds, printing a countdown before each step.

    Args:
        dur: total time to wait; truncated to whole seconds with ``int()``.
        tik: countdown step in seconds; truncated with ``int()``.

    Raises:
        ValueError: if there is time left to wait but ``int(tik) <= 0``.
            Previously that case looped forever without making progress.
    """
    remaining = int(dur)
    step = int(tik)
    if remaining > 0 and step <= 0:
        raise ValueError(f"tik must be at least 1 second, got {tik!r}")
    while remaining > 0:
        print('sleeping... ', remaining)
        # Sleep only what is left on the final step instead of a full tik.
        time.sleep(min(step, remaining))
        remaining -= step
24 |
25 |
def _reset_shopping_container(myip: str) -> Dict[str, str]:
    """Blocking part of container_reset: stop/remove, re-run, then reconfigure Magento."""
    res = ""
    try:
        client = docker.from_env()

        # First container (running or stopped) created from the shopping image.
        target = next(
            (
                c for c in client.containers.list(all=True)
                if c.attrs['Config']['Image'] == 'shopping_final_0712'
            ),
            None,
        )
        if target is None:
            return {
                "status": "No container found using the image shopping_final_0712."
            }
        target.stop()
        target.remove()
        res += f"Container {target.short_id} stopped and removed.\n"

        new_container = client.containers.run(
            'shopping_final_0712',
            detach=True,
            name='shopping',
            ports={'80/tcp': 7770, '3306/tcp': 33061},
        )
        res += f"New container created with ID: {new_container.short_id}\n"

        # Give the services inside the new container time to start.
        runsleep(20)
        new_container = client.containers.get('shopping')

        # Equivalent shell commands:
        #   docker exec shopping /var/www/magento2/bin/magento setup:store-config:set --base-url="http://SERVERIP:7770"
        #   docker exec shopping mysql -u magentouser -pMyPassword magentodb -e 'UPDATE core_config_data SET value="http://SERVERIP:7770/" WHERE path = "web/secure/base_url";'
        #   docker exec shopping /var/www/magento2/bin/magento cache:flush
        # NOTE(review): the database credentials are hard-coded here.
        cmdlist = [
            f'/var/www/magento2/bin/magento setup:store-config:set --base-url="http://{myip}:7770"',
            f'mysql -u magentouser -pMyPassword magentodb -e \'UPDATE core_config_data SET value="http://{myip}:7770/" WHERE path = "web/secure/base_url";\'',
            '/var/www/magento2/bin/magento cache:flush',
        ]
        for cmd in cmdlist:
            res += " cmd: " + cmd + "\n"
            exec_result = new_container.exec_run(cmd, stdout=True, stderr=True)
            # errors="replace": command output is not guaranteed to be valid UTF-8.
            res += " result: " + exec_result.output.decode('utf-8', errors='replace') + "\n"
            res += " --------------------- \n"

        return {"status": res}

    except docker.errors.DockerException as e:
        # DockerException is the SDK's base class (APIError included), so daemon,
        # stop/remove and run failures are all reported the same way.
        return {"status": f"Failed to execute commands: {e}"}


# curl -N http://localhost:8000/api/container/reset/SERVERIP
@router.get("/reset/{myip}")
async def container_reset(myip: str):
    """Recreate the ``shopping`` container and point its base URL at *myip*.

    The Docker SDK calls and the fixed start-up wait are blocking. They run in
    a worker thread so the event loop keeps serving other requests.

    Args:
        myip: host name or IPv4 address used to build ``http://<myip>:7770``.
            It is spliced into an SQL statement and a CLI argument, so it is
            validated first.

    Returns:
        ``{"status": <step log or error message>}``.

    Raises:
        HTTPException: 400 if *myip* is not a plain host name or IPv4 address.
    """
    import asyncio
    import re

    # Letters, digits, dots and hyphens only: rules out quote/semicolon
    # injection into the mysql -e statement and the magento CLI argument.
    if not re.fullmatch(r"[A-Za-z0-9](?:[A-Za-z0-9.-]*[A-Za-z0-9])?", myip):
        raise HTTPException(status_code=400, detail=f"Invalid host: {myip!r}")
    return await asyncio.to_thread(_reset_shopping_container, myip)
89 |
90 |
91 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/webagent_utils_async/action/action_parser.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Dict, List, Tuple, Any, Union, Literal
3 | import ast
4 |
def parse_action(action_str: str) -> Union[Tuple[str, List[Any], Dict[str, Any]], List[Any]]:
    """Parse one Python-style call such as ``click("123")`` into its parts.

    The string is parsed with :mod:`ast`. Commas, ``=`` signs, brackets and
    parentheses inside string literals therefore no longer confuse argument
    splitting; for example, ``fill("237", "hello, world")`` used to be rejected.

    Args:
        action_str: a single function call whose arguments are Python literals
            (strings, numbers, lists, tuples, dicts, booleans, ``None``).

    Returns:
        ``(function_name, positional_args, keyword_args)``. Returns an empty list
        when *action_str* is empty or whitespace, kept for backward compatibility.

    Raises:
        ValueError: if the text is not exactly one call of a plain name, or if
            any argument is not a literal.

    Examples:
        parse_action('click("123")')                  -> ('click', ['123'], {})
        parse_action('fill("237", "example value")')  -> ('fill', ['237', 'example value'], {})
    """
    action_str = action_str.strip()
    if not action_str:
        return []

    try:
        tree = ast.parse(action_str, mode="eval")
    except (SyntaxError, ValueError):  # ValueError: e.g. embedded NUL bytes
        raise ValueError(f"Invalid action format: {action_str}") from None

    call = tree.body
    if not isinstance(call, ast.Call) or not isinstance(call.func, ast.Name):
        raise ValueError(f"Invalid action format: {action_str}")

    def _literal(node, kind, shown):
        # Evaluate one argument node as a literal; *shown* is the node quoted in errors.
        try:
            return ast.literal_eval(node)
        except (SyntaxError, ValueError, TypeError):
            segment = ast.get_source_segment(action_str, shown) or action_str
            raise ValueError(f"Invalid {kind} format in: {segment}") from None

    args = [_literal(arg, "argument", arg) for arg in call.args]

    kwargs = {}
    for keyword in call.keywords:
        if keyword.arg is None:  # ``**mapping`` unpacking is not a named literal argument
            segment = ast.get_source_segment(action_str, keyword) or action_str
            raise ValueError(f"Invalid keyword argument format in: {segment}")
        kwargs[keyword.arg] = _literal(keyword.value, "keyword argument", keyword)

    return (call.func.id, args, kwargs)
56 |
57 |
if __name__ == "__main__":
    # Manual smoke test: print how representative action strings parse.
    # Each entry is passed to parse_action verbatim, in this order.
    demo_actions = [
        'noop(1000)',
        # select_option('a48', "blue") / select_option('c48', ["red", "green", "blue"])
        'select_option("a48", "blue")',
        'select_option("c48", ["red", "green", "blue"])',
        # click('a51') / click('b22', button="right") / click('48', button="middle", modifiers=["Shift"])
        'click("a51")',
        'click("b22", button="right")',
        'click("48", button="middle", modifiers=["Shift"])',
        # upload_file("572", "my_receipt.pdf") / upload_file("63", [<two paths>])
        'upload_file("572", "my_receipt.pdf")',
        'upload_file("63", ["/home/bob/Documents/image.jpg", "/home/bob/Documents/)file.zip"])',
        # fill('237', 'example value') / multi-line value / value containing escaped quotes
        'fill("237", "example value")',
        'fill("45", "multi-line\\nexample")',
        'fill("a12", "example with \\"quotes\\"")',
    ]
    for demo in demo_actions:
        print(parse_action(demo))
92 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/webagent_utils_async/evaluation/feedback.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import logging
4 | from openai import OpenAI
5 | import base64
6 | from ..utils.utils import encode_image
7 | from pydantic import BaseModel
8 |
9 | logger = logging.getLogger(__name__)  # module-level logger
10 | openai_client = OpenAI()  # shared client built at import time; NOTE(review): OpenAI() presumably reads its API key from the environment, so importing this module may fail without it — confirm
11 |
12 |
async def capture_post_action_feedback(page, action, goal, log_folder):
    """Screenshot *page* after *action* and ask gpt-4o whether *goal* is now met.

    Args:
        page: object with an async ``screenshot()`` returning image bytes
            (presumably a Playwright page — confirm with callers).
        action: the action just taken; interpolated into the prompt.
        goal: the task goal; interpolated into the prompt.
        log_folder: unused here; kept for call-site compatibility.

    Returns:
        The model's free-text answer (``choices[0].message.content``).
    """
    import asyncio

    # Let the page settle. asyncio.sleep (not time.sleep) keeps the event
    # loop free for other coroutines during the wait.
    await asyncio.sleep(3)
    screenshot_bytes = await page.screenshot()

    base64_image = base64.b64encode(screenshot_bytes).decode('utf-8')
    # Label the data URL with the real format: PNG bytes start with \x89PNG.
    # Anything else keeps the previous image/jpeg label.
    mime_type = "image/png" if bytes(screenshot_bytes[:4]) == b"\x89PNG" else "image/jpeg"
    prompt = f"""
After we take action {action}, a screenshot was captured.

# Screenshot description:
The image provided is a screenshot of the application state after the action was performed.

# The original goal:
{goal}

Based on the screenshot and the updated Accessibility Tree, is the goal finished now? Provide an answer and explanation, referring to visual elements from the screenshot if relevant.
"""

    # The SDK call is synchronous; run it in a worker thread instead of
    # blocking the event loop for the whole request.
    response = await asyncio.to_thread(
        openai_client.chat.completions.create,
        model="gpt-4o",
        messages=[
            {"role": "user",
             "content": [
                 {"type": "text", "text": prompt},
                 {"type": "image_url",
                  "image_url": {
                      "url": f"data:{mime_type};base64,{base64_image}",
                      "detail": "high"
                  }
                  }
             ]
             },
        ],
    )

    return response.choices[0].message.content
52 |
53 |
54 | class Feedback(BaseModel):  # structured-output schema passed as response_format in generate_feedback_with_screenshot; NOTE(review): documented with comments only because pydantic may copy a class docstring into the JSON schema sent to the model
55 | is_done: bool  # the model's verdict on "is the goal finished now?"
56 | explanation: str  # the model's free-text justification for that verdict
57 |
# System prompt for generate_feedback_with_screenshot.
FEEDBACK_SYSTEM_PROMPT_TEMPLATE = (
    "You are a helpful assistant. Given a goal, a description of the action just taken, "
    "and a screenshot of the web page after the action was taken, your task is to provide "
    "feedback on web task completion by evaluating the current state against the desired goal."
)

# User prompt; filled in with str.format(goal=..., action_description=...).
FEEDBACK_USER_PROMPT_TEMPLATE = """
# The goal:
{goal}

# The action description:
{action_description}

Based on the screenshot of the web page, is the goal finished now? Provide an answer and explanation, referring to visual elements from the screenshot if relevant."""
70 |
async def generate_feedback_with_screenshot(goal, action_description, screenshot, model):
    """Ask *model* whether *goal* is complete, given a screenshot taken after the action.

    Args:
        goal: task goal, substituted into FEEDBACK_USER_PROMPT_TEMPLATE.
        action_description: description of the action just taken.
        screenshot: raw image bytes (bytes-like).
        model: model name passed to the structured-output parse call.

    Returns:
        ``response.choices[0].message.parsed``, expected to be a ``Feedback``.
    """
    import asyncio

    system_prompt = FEEDBACK_SYSTEM_PROMPT_TEMPLATE
    user_prompt = FEEDBACK_USER_PROMPT_TEMPLATE.format(goal=goal, action_description=action_description)
    base64_image = base64.b64encode(screenshot).decode('utf-8')
    # Label the data URL with the real format: PNG bytes start with \x89PNG.
    # Anything else keeps the previous image/jpeg label.
    mime_type = "image/png" if bytes(screenshot[:4]) == b"\x89PNG" else "image/jpeg"

    # The SDK call is synchronous; run it in a worker thread so the event loop is not blocked.
    response = await asyncio.to_thread(
        openai_client.beta.chat.completions.parse,
        model=model,
        response_format=Feedback,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user",
             "content": [
                 {"type": "text", "text": user_prompt},
                 {
                     "type": "image_url",
                     "image_url": {
                         "url": f"data:{mime_type};base64,{base64_image}",
                         "detail": "low"
                     }
                 }
             ]
             },
        ],
    )
    feedback = response.choices[0].message.parsed
    return feedback
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/routes/websocket.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | from datetime import datetime
4 | from typing import List, Dict, Any
5 | import logging
6 |
# Root logging is configured as a side effect of importing this module.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

from fastapi import APIRouter, WebSocket, WebSocketDisconnect

router = APIRouter()

# Sockets currently accepted by websocket_endpoint; heartbeats go to all of them.
active_connections: List[WebSocket] = []

# Background heartbeat task (ping_clients); stays None until the first client connects.
ping_task = None
22 |
async def send_to_all(data: Dict[str, Any]):
    """Send *data*, JSON-encoded once, to every connected client.

    A failed send is logged and does not stop the broadcast to the others.
    """
    message = json.dumps(data)
    # Iterate over a snapshot. Each send awaits, and websocket_endpoint may
    # remove a socket from the shared list meanwhile. Iterating the live list
    # would then silently skip the next client.
    for connection in list(active_connections):
        try:
            await connection.send_text(message)
        except Exception as e:
            logging.error(f"Error sending message: {e}")
31 |
async def ping_clients():
    """Send a heartbeat message to all WebSocket clients about once per second.

    Loops forever and sends nothing while ``active_connections`` is empty.
    ``websocket_endpoint`` starts it as a background task.
    """
    from datetime import timezone  # the module imports only the datetime class

    while True:
        if active_connections:
            data = {
                "type": "ping",
                "message": "Server heartbeat",
                # Naive-UTC ISO string in the same format datetime.utcnow()
                # produced, without the call deprecated in Python 3.12.
                "timestamp": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
            }
            await send_to_all(data)
            logging.info(f"Sent ping to {len(active_connections)} websocket clients")
        await asyncio.sleep(1)
44 |
# Define the WebSocket endpoint that will be used in main.py
async def websocket_endpoint(websocket: WebSocket):
    """Accept a client, then echo each JSON text message back until it disconnects.

    The socket is added to ``active_connections`` so it receives heartbeats.
    It is always removed again on exit: on disconnect, on any error
    (including a failed greeting), and on task cancellation.
    Invalid JSON gets an ``{"type": "error"}`` reply instead of closing the socket.
    """
    from datetime import timezone  # the module imports only the datetime class

    global ping_task

    def _utc_now_iso():
        # Same naive-UTC ISO format as the deprecated datetime.utcnow().isoformat().
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    await websocket.accept()
    active_connections.append(websocket)

    # Start the shared heartbeat if it is not running yet (or has finished).
    if ping_task is None or ping_task.done():
        ping_task = asyncio.create_task(ping_clients())

    logging.info(f"WebSocket client connected - total clients: {len(active_connections)}")

    client_disconnected = False
    try:
        # The greeting is inside the try so a failed send cannot leak the socket.
        await websocket.send_json({
            "type": "connection",
            "message": "Connected to WebSocket server"
        })

        while True:
            data = await websocket.receive_text()
            try:
                parsed_data = json.loads(data)
                logging.info(f"Received message: {parsed_data}")
                await websocket.send_json({
                    "type": "echo",
                    "message": parsed_data,
                    "timestamp": _utc_now_iso()
                })
            except json.JSONDecodeError as e:
                logging.error(f"Error parsing message: {e}")
                await websocket.send_json({
                    "type": "error",
                    "message": f"Invalid JSON: {str(e)}"
                })
    except WebSocketDisconnect:
        client_disconnected = True
    except Exception as e:
        logging.error(f"WebSocket error: {e}")
    finally:
        # Runs for every exit path, including CancelledError.
        if websocket in active_connections:
            active_connections.remove(websocket)
        if client_disconnected:
            logging.info(f"WebSocket client disconnected - remaining clients: {len(active_connections)}")
95 |
# HTTP route for checking the WebSocket server without opening a socket.
@router.get("/status")
async def websocket_status():
    """Report the number of open WebSocket connections."""
    payload = {"active": len(active_connections), "status": "running"}
    return payload
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/webagent_utils_async/action/parsers.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import pyparsing as pp
3 |
4 | from dataclasses import dataclass
5 | from typing import Any
6 |
7 |
@dataclass
class NamedArgument:
    """A ``name=value`` keyword argument produced by the high-level action parser."""

    name: str
    value: Any

    def __repr__(self):
        # Render in call syntax, e.g. button='right'.
        return f"{self.name}={self.value!r}"
15 |
16 |
def _build_highlevel_action_parser() -> pp.ParserElement:
    """
    Returns:
        A parser for Python-like function calls. Arguments may be string,
        number, list, tuple or dict literals, or True/False/None, optionally
        followed by ``name=value`` keyword arguments (yielded as NamedArgument).
        Example:
            func("a", 42, None, True, [2, 4, "s"], {"a_key": "a_value"}, )
        The parser is loose and accepts multi-line or single-line combinations of calls.
        Example:
            func() func()
            \tfunc()
        Python comments are ignored.
        Example:
            # this is a comment
            func() # this function call will be parsed
            # func() # this one will not
        Parsing yields one group per call: the function name followed by a
        group holding its arguments. Input that does not match makes
        ``parse_string`` raise a pyparsing parse exception.
    """

    def make_keyword(kwd_str, kwd_value):
        return pp.Keyword(kwd_str).set_parse_action(pp.replace_with(kwd_value))

    TRUE = make_keyword("True", True)
    FALSE = make_keyword("False", False)
    NONE = make_keyword("None", None)

    LBRACK, RBRACK, LBRACE, RBRACE, LPAREN, RPAREN, COLON = map(pp.Suppress, "[]{}():")

    def literal_eval(toks):
        return ast.literal_eval(toks[0])

    string = pp.python_quoted_string().set_parse_action(literal_eval)
    number = pp.pyparsing_common.number()
    # Forward declarations let containers nest. The *_parser names avoid
    # shadowing the dict/list builtins.
    dict_parser = pp.Forward().set_name("dict")
    list_parser = pp.Forward().set_name("list")
    tuple_parser = pp.Forward().set_name("tuple")
    element = (string | number | dict_parser | list_parser | tuple_parser | TRUE | FALSE | NONE).set_name("element")

    list_items = pp.DelimitedList(element, allow_trailing_delim=True).set_name(None)
    list_parser << pp.Group(LBRACK + pp.Optional(list_items) + RBRACK, aslist=True)
    tuple_parser << pp.Group(LPAREN + pp.Optional(list_items) + RPAREN, aslist=True).set_parse_action(
        lambda tokens: tuple(tokens[0])
    )

    dict_item = pp.Group(string + COLON + element, aslist=True).set_name("dict item")
    dict_items = pp.DelimitedList(dict_item, allow_trailing_delim=True).set_name(None)
    dict_parser << pp.Dict(LBRACE + pp.Optional(dict_items) + RBRACE, asdict=True)

    arg = element
    list_args = pp.DelimitedList(arg, allow_trailing_delim=True).set_name(None)
    named_arg = (pp.pyparsing_common.identifier() + pp.Literal("=") + element).set_parse_action(
        lambda tokens: NamedArgument(name=tokens[0], value=tokens[2])
    )
    list_named_args = pp.DelimitedList(named_arg, allow_trailing_delim=True).set_name(None)
    function_call = pp.pyparsing_common.identifier() + pp.Group(
        LPAREN + pp.Optional(list_args) + pp.Optional(list_named_args) + RPAREN, aslist=True
    )

    # delim="" lets consecutive calls be separated by whitespace alone.
    multiple_function_calls = pp.DelimitedList(pp.Group(function_call), delim="")
    multiple_function_calls.ignore(pp.python_style_comment())

    return multiple_function_calls
82 |
83 |
# Shared parser for python-like action calls, built once at import time.
highlevel_action_parser: pp.ParserElement = _build_highlevel_action_parser()

# Splits a high-level action's docstring at the "Examples:" marker into a group
# of description words and a group of parsed example calls. Used to describe
# the action space.
action_docstring_parser: pp.ParserElement = pp.Group(
    pp.OneOrMore(pp.Word(pp.printables), stop_on=pp.Literal("Examples:"))
) + pp.Literal("Examples:").suppress() + pp.Group(highlevel_action_parser)
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/core_async/config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field, fields
2 | from typing import List, Optional
3 |
@dataclass
class AgentConfig:
    """Settings shared by the search agents, grouped by concern.

    Field order matters: it fixes the positional order of the generated ``__init__``.
    """

    # --- Browser ---
    headless: bool = False
    browser_mode: str = "browserbase"
    storage_state: str = "state.json"

    # --- Models (the CLI help says only OpenAI models are supported) ---
    default_model: str = "gpt-4o-mini"
    action_generation_model: str = "gpt-4o-mini"
    feedback_model: str = "gpt-4o-mini"
    planning_model: str = "gpt-4o"
    action_grounding_model: str = "gpt-4o"
    evaluation_model: str = "gpt-4o"

    # --- Search ---
    search_algorithm: str = "bfs"
    exploration_weight: float = 1.41
    branching_factor: int = 5
    iterations: int = 1
    max_depth: int = 3
    num_simulations: int = 1
    account_reset: bool = True

    # LATS-specific
    simulation_score: float = 0.75

    # MCTS-specific
    reflection_score: float = 0.75
    set_prior_value: bool = False

    # --- Observation features (fresh list per instance) ---
    features: List[str] = field(default_factory=lambda: ["axtree"])
    fullpage: bool = False
    elements_filter: str = "som"

    # --- Logging ---
    log_folder: str = "log"
42 |
def add_agent_config_arguments(parser):
    """Register optional CLI flags whose destinations match AgentConfig fields.

    No flag sets a default, so each is ``None`` unless given. That lets
    filter_valid_config_args drop unset flags and keep AgentConfig's defaults.

    Args:
        parser: an ``argparse.ArgumentParser`` (or compatible) to extend.
    """
    import argparse

    def _str_to_bool(value):
        # type=bool would turn every non-empty string, including "False", into True.
        lowered = str(value).strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

    def _comma_list(value):
        # AgentConfig.features is List[str]; accept "a" or comma-separated "a,b".
        return [item.strip() for item in value.split(",") if item.strip()]

    # Environment
    parser.add_argument("--browser-mode", type=str, required=False,
                        help="Specify the browser mode")
    parser.add_argument("--storage-state", type=str, required=False,
                        help="Storage state json file")
    # Model
    parser.add_argument("--action_generation_model", type=str, required=False,
                        help="action generation model, right now only supports openai models")
    parser.add_argument("--feedback_model", type=str, required=False,
                        help="feedback model, right now only supports openai models")
    parser.add_argument("--planning_model", type=str, required=False,
                        help="planning model, right now only supports openai models")
    parser.add_argument("--action_grounding_model", type=str, required=False,
                        help="action grounding model, right now only supports openai models")
    parser.add_argument("--evaluation_model", type=str, required=False,
                        help="evaluation model, right now only supports openai models")
    # Search
    parser.add_argument("--search_algorithm", type=str, required=False,
                        help="bfs or dfs")
    parser.add_argument("--exploration_weight", type=float, required=False,
                        help="exploration weight")
    parser.add_argument("--branching_factor", type=int, required=False,
                        help="branching factor")
    parser.add_argument("--iterations", type=int, required=False,
                        help="Number of iterations to run")
    parser.add_argument("--max_depth", type=int, required=False,
                        help="max depth of rollout")
    parser.add_argument("--num_simulations", type=int, required=False,
                        help="Number of simulations to run")

    # Features
    parser.add_argument("--features", type=_comma_list, required=False,
                        help="features to use")
    parser.add_argument("--fullpage", type=_str_to_bool, required=False,
                        help="fullpage")
    parser.add_argument("--elements_filter", type=str, required=False,
                        help="elements filter")

    # Logging
    parser.add_argument("--log_folder", type=str, required=False,
                        help="log folder")
85 |
def filter_valid_config_args(args_dict):
    """Return the entries of *args_dict* that name an AgentConfig field and are not None."""
    config_field_names = {f.name for f in fields(AgentConfig)}
    filtered = {}
    for key, value in args_dict.items():
        if key in config_field_names and value is not None:
            filtered[key] = value
    return filtered
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/webagent_utils_async/tools/shared_utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import logging
3 | from openai import OpenAI
4 | from ..action.highlevel import HighLevelActionSet
5 | from collections import defaultdict
6 | from ..utils.utils import query_openai_model
7 | from ..action.utils import prepare_prompt, execute_action
8 | from ..action.utils import build_highlevel_action_parser
9 | from ..browser_env.observation import extract_page_info
10 | from ..evaluation.feedback import capture_post_action_feedback
11 |
12 | logger = logging.getLogger(__name__)  # module logger; take_action logs its failures here
13 | openai_client = OpenAI()  # built at import time; NOTE(review): not referenced by the functions shown in this module (they call query_openai_model) — confirm it is still needed
14 |
15 |
def get_action_probability(responses, branching_factor):
    """Vote over sampled model responses and return the most frequent actions.

    Args:
        responses: raw model outputs, each expected to contain a high-level action call.
        branching_factor: maximum number of distinct actions to return; ``None`` keeps all.

    Returns:
        Up to *branching_factor* dicts, most frequent first, with ties kept in
        first-seen order. Each is ``{'action': <first raw response with that
        parse>, 'prob': <its share of votes among the returned actions>}``.

    Raises:
        Whatever ``parse_string`` raises for a response that does not parse
        (presumably a pyparsing ParseException); take_action catches it.
    """
    highlevel_action_parser = build_highlevel_action_parser()
    logger.debug("Raw action responses: %s", responses)

    parsed_actions_count = defaultdict(int)  # vote key -> count, in first-seen order
    all_actions = {}  # vote key -> {'action': first raw response with that key}
    for response in responses:
        result = highlevel_action_parser.parse_string(response)
        # Vote on the string form of the first parsed call.
        # NOTE(review): ParseResults appear to hash by identity, which would
        # give every response its own bucket; the string form groups equal
        # parses either way.
        key = str(result[0]) if result else ""
        if key not in all_actions:
            all_actions[key] = {'action': response}
        parsed_actions_count[key] += 1
    logger.debug("Parsed action counts: %s", dict(parsed_actions_count))

    # sorted() is stable under reverse=True, so ties keep first-seen order.
    top_actions = sorted(parsed_actions_count, key=parsed_actions_count.get, reverse=True)[:branching_factor]
    top_action_count = sum(parsed_actions_count[key] for key in top_actions)

    updated_actions = []
    for key in top_actions:
        entry = all_actions[key]
        entry['prob'] = parsed_actions_count[key] / top_action_count
        updated_actions.append(entry)

    logger.debug("Top actions with probabilities: %s", updated_actions)
    return updated_actions
38 |
39 |
async def take_action(task_description, agent_type, features=None, branching_factor=None, playwright_manager=None,
                      log_folder='log', elements_filter=None):
    """Pick one browser action for *task_description* by sampling votes, run it, and report back.

    Args:
        task_description: the goal, embedded in the system prompt.
        agent_type: action subsets passed to HighLevelActionSet.
        features: observation features forwarded to prepare_prompt.
        branching_factor: number of candidate actions kept by the vote
            (``None`` keeps all); it also scales the number of samples.
        playwright_manager: provides the browser context and page.
        log_folder: forwarded to page extraction, prompt building and execution.
        elements_filter: forwarded to prepare_prompt.

    Returns:
        A status string. Failures are reported in the string, never raised.
    """
    import asyncio

    try:
        context = await playwright_manager.get_context()
        page = await playwright_manager.get_page()
        action_set = HighLevelActionSet(
            subsets=agent_type,
            strict=False,
            multiaction=True,
            demo_mode="default"
        )
        # Let the page settle; asyncio.sleep keeps the event loop responsive.
        await asyncio.sleep(3)
        page_info = await extract_page_info(page, log_folder)

        # Prepare messages for AI model
        system_msg = f"""
# Instructions
Review the current state of the page and all other information to find the best
possible next action to accomplish your goal. Your answer will be interpreted
and executed by a program, make sure to follow the formatting instructions.
Provide ONLY ONE action. Do not suggest multiple actions or a sequence of actions.
# Goal:
{task_description}"""

        prompt = prepare_prompt(page_info, action_set, features, log_folder, elements_filter)

        # Sample 20 candidates, or twice the branching factor when that is larger.
        # NOTE(review): query_openai_model is called without await, so it blocks
        # the event loop while the model responds.
        num_outputs = 20 if branching_factor is None else max(branching_factor * 2, 20)
        responses = query_openai_model(system_msg, prompt, page_info['screenshot_som'],
                                       num_outputs=num_outputs)

        updated_actions = get_action_probability(responses, branching_factor)
        action = updated_actions[0]['action']
        logger.info("action is: %s", action)

        # Execute the action
        try:
            await execute_action(action, action_set, page, context, task_description,
                                 page_info['interactive_elements'], log_folder)
        except Exception as e:
            last_action_error = f"{type(e).__name__}: {str(e)}"
            logger.error(f"Action execution failed: {last_action_error}")
            return f"Task failed with action error: {last_action_error}"

        feedback = await capture_post_action_feedback(page, action, task_description, log_folder)
        return f"The action is: {action} - the result is: {feedback}"

    except Exception as e:
        error_msg = f"{type(e).__name__}: {str(e)}"
        logger.error(f"Task failed: {error_msg}")
        return f"Task failed: {error_msg}"
95 |
--------------------------------------------------------------------------------
/visual-tree-search-app/components/ui/dialog.tsx:
--------------------------------------------------------------------------------
1 | import * as React from "react"
2 | import * as DialogPrimitive from "@radix-ui/react-dialog"
3 | import { XIcon } from "lucide-react"
4 |
5 | import { cn } from "@/lib/utils"
6 |
// NOTE(review): this listing of dialog.tsx is damaged. The JSX returned by each
// component was stripped during extraction. For example, `Dialog` shows a bare
// `return`, and several `return (` bodies jump over line numbers (36 -> 44,
// 55 -> 63, ...). The generic arguments of `React.ComponentProps<...>` are also
// missing, except for the `<"div">` ones. As shown, the file does not
// type-check; restore it from the original source before relying on it.
// Presumably these are thin wrappers around the @radix-ui/react-dialog
// primitives imported above as `DialogPrimitive` -- confirm against the original.
7 | function Dialog({
8 | ...props
9 | }: React.ComponentProps) {
10 | return
11 | }
12 |
13 | function DialogTrigger({
14 | ...props
15 | }: React.ComponentProps) {
16 | return
17 | }
18 |
19 | function DialogPortal({
20 | ...props
21 | }: React.ComponentProps) {
22 | return
23 | }
24 |
25 | function DialogClose({
26 | ...props
27 | }: React.ComponentProps) {
28 | return
29 | }
30 |
// Accepts a className; per the imports it presumably merges it with cn() -- body lost.
31 | function DialogOverlay({
32 | className,
33 | ...props
34 | }: React.ComponentProps) {
35 | return (
36 |
44 | )
45 | }
46 |
// Renders `children` plus a "Close" label; the surrounding markup (and the
// XIcon usage implied by the imports) was lost in extraction.
47 | function DialogContent({
48 | className,
49 | children,
50 | ...props
51 | }: React.ComponentProps) {
52 | return (
53 |
54 |
55 |
63 | {children}
64 |
65 |
66 | Close
67 |
68 |
69 |
70 | )
71 | }
72 |
73 | function DialogHeader({ className, ...props }: React.ComponentProps<"div">) {
74 | return (
75 |
80 | )
81 | }
82 |
83 | function DialogFooter({ className, ...props }: React.ComponentProps<"div">) {
84 | return (
85 |
93 | )
94 | }
95 |
96 | function DialogTitle({
97 | className,
98 | ...props
99 | }: React.ComponentProps) {
100 | return (
101 |
106 | )
107 | }
108 |
109 | function DialogDescription({
110 | className,
111 | ...props
112 | }: React.ComponentProps) {
113 | return (
114 |
119 | )
120 | }
121 |
122 | export {
123 | Dialog,
124 | DialogClose,
125 | DialogContent,
126 | DialogDescription,
127 | DialogFooter,
128 | DialogHeader,
129 | DialogOverlay,
130 | DialogPortal,
131 | DialogTitle,
132 | DialogTrigger,
133 | }
134 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/test/test-tree-ws.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | import websockets
4 | import logging
5 | import uuid
6 |
7 | # Set up logging to see more details
8 | logging.basicConfig(level=logging.INFO)
9 |
async def test_websocket(uri="ws://localhost:3000/tree-ws", timeout=60):
    """Manually exercise the /tree-ws endpoint with a small sample tree.

    The exchange:
      1. Send a ``ping`` and print the reply.
      2. Send a ``tree`` message.
      3. Print every message received until the server announces
         "BFS traversal completed", the connection closes, an error occurs,
         or ``timeout`` seconds have elapsed since the tree was sent.

    Args:
        uri: WebSocket endpoint (defaults to the local dev server).
        timeout: Overall time budget, in seconds, for the traversal messages.
    """
    print(f"Connecting to {uri}")

    async with websockets.connect(uri) as websocket:
        print("Connected to WebSocket")

        # Send a ping message and wait for the reply
        await websocket.send(json.dumps({
            "type": "ping"
        }))
        response = await websocket.recv()
        print(f"Received: {response}")

        # Create a sample tree structure for testing
        sample_tree = {
            "id": "root-1",
            "name": "Root Node",
            "children": [
                {
                    "id": "child-1",
                    "name": "Child 1",
                    "children": [
                        {
                            "id": "grandchild-1",
                            "name": "Grandchild 1",
                            "children": []
                        },
                        {
                            "id": "grandchild-2",
                            "name": "Grandchild 2",
                            "children": []
                        }
                    ]
                },
                {
                    "id": "child-2",
                    "name": "Child 2",
                    "children": [
                        {
                            "id": "grandchild-3",
                            "name": "Grandchild 3",
                            "children": []
                        }
                    ]
                }
            ]
        }

        # Send tree data for traversal
        await websocket.send(json.dumps({
            "type": "tree",
            "content": sample_tree
        }))
        print("Sent tree data for traversal")

        # Alternatively, a traversal request can be sent instead:
        # await websocket.send(json.dumps({
        #     "type": "traversal_request",
        #     "algorithm": "bfs",
        #     "tree": sample_tree
        # }))

        # get_running_loop() is the supported way to reach the loop from a coroutine.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout

        # Continuously receive messages until completion, close, error or timeout
        while True:
            try:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    print("Timeout reached. No more messages received.")
                    break

                response = await asyncio.wait_for(websocket.recv(), timeout=remaining)
                print(f"Received: {response}")

                try:
                    data = json.loads(response)
                except json.JSONDecodeError:
                    print("Received non-JSON message")
                    continue

                print(f"Message type: {data.get('type', 'unknown')}")
                print(json.dumps(data, indent=2))

                # Must match the completion message sent by the server
                if data.get("type") == "info" and data.get("message") == "BFS traversal completed":
                    print("Traversal completed successfully!")
                    break

            except asyncio.TimeoutError:
                print("Timeout waiting for response from WebSocket")
                break
            except websockets.exceptions.ConnectionClosed:
                print("Connection closed")
                break
            except Exception as e:
                print(f"Error: {e}")
                break
117 |
if __name__ == "__main__":
    # Guarded so that importing this module (e.g. during test collection)
    # does not open a network connection to the local server.
    asyncio.run(test_websocket())
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/routes/tree_traversal.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | from datetime import datetime
4 | import logging
5 | from typing import Dict, Any, List, Set
6 |
7 | # Configure basic logging
# NOTE(review): logging.basicConfig() configures the *root* logger for the
# whole process, as a side effect of importing this route module. It does
# nothing if the root logger already has handlers (for example, if the server
# configured logging first). Consider moving logging setup to the application
# entry point.
8 | logging.basicConfig(
9 | level=logging.INFO,
10 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
11 | )
12 |
13 | from fastapi import WebSocket, WebSocketDisconnect
14 |
def _utc_timestamp():
    """Return the current UTC time as a naive ISO-8601 string.

    The output format is the same as ``datetime.utcnow().isoformat()`` (no
    offset suffix), without the deprecation warning ``utcnow()`` emits on
    Python 3.12+.
    """
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


async def _send_logged(websocket, message):
    """Log *message* as JSON, then send it to the client over *websocket*."""
    logging.info(f"Sending to websocket: {json.dumps(message)}")
    await websocket.send_json(message)


async def bfs_traversal(websocket: "WebSocket", tree_data: Dict[str, Any], delay: float = 1.0):
    """Perform a breadth-first traversal of *tree_data* and stream each visit to the client.

    Messages sent, in order:
      * an ``info`` message "Starting BFS traversal";
      * one ``traversal`` message per node, root first (``parentId=None``,
        ``isRoot=True``), then nodes level by level;
      * an ``info`` message "BFS traversal completed".
    A node whose ``id`` was already visited is skipped together with its subtree.

    Args:
        websocket: Connected WebSocket used to send JSON messages.
        tree_data: Nested dict with ``id``, ``name`` and optional ``children``.
        delay: Seconds to pause after each visited node so the traversal is
            visible in the UI. Defaults to the original 1 second.

    A client disconnect is only logged. Any other error is reported to the
    client as an ``error`` message; if that report cannot be sent either
    (e.g. the socket is already closed), it is logged instead of raised.
    """
    from collections import deque  # O(1) popleft; list.pop(0) is O(n)

    try:
        await _send_logged(websocket, {
            "type": "info",
            "message": "Starting BFS traversal",
            "timestamp": _utc_timestamp()
        })

        # Send the root node first
        await _send_logged(websocket, {
            "type": "traversal",  # 'traversal' for consistency with frontend
            "algorithm": "bfs",
            "nodeId": tree_data["id"],
            "nodeName": tree_data["name"],
            "parentId": None,  # Root has no parent
            "isRoot": True,
            "message": f"Root node: {tree_data['name']} (ID: {tree_data['id']})",
            "timestamp": _utc_timestamp()
        })
        await asyncio.sleep(delay)

        # Queue entries are (node, parent) so each message can name its parent.
        queue = deque((child, tree_data) for child in tree_data.get("children") or [])
        visited = {tree_data["id"]}

        while queue:
            child, parent = queue.popleft()
            if child["id"] in visited:
                continue
            visited.add(child["id"])

            await _send_logged(websocket, {
                "type": "traversal",  # 'traversal' for consistency with frontend
                "algorithm": "bfs",
                "nodeId": child["id"],
                "nodeName": child["name"],
                "parentId": parent["id"],
                "isRoot": False,
                "message": f"Visiting node: {child['name']} (ID: {child['id']}, Parent: {parent['name']})",
                "timestamp": _utc_timestamp()
            })
            await asyncio.sleep(delay)

            queue.extend((grandchild, child) for grandchild in child.get("children") or [])

        await _send_logged(websocket, {
            "type": "info",
            "message": "BFS traversal completed",
            "timestamp": _utc_timestamp()
        })

    except WebSocketDisconnect:
        logging.info("Client disconnected during tree traversal")
    except Exception as e:
        logging.error(f"Error during BFS tree traversal: {e}")
        try:
            await websocket.send_json({
                "type": "error",
                "message": f"Error during BFS traversal: {str(e)}",
                "timestamp": _utc_timestamp()
            })
        except Exception as send_error:
            # The socket may already be closed; don't let the report itself crash the handler.
            logging.error(f"Could not report BFS traversal error to client: {send_error}")
--------------------------------------------------------------------------------
/visual-tree-search-backend/test/test_tree_search_depth.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import asyncio
3 | from unittest.mock import AsyncMock, MagicMock, patch
4 | from datetime import datetime
5 | import sys
6 | import os
7 |
8 | # Add the parent directory to the Python path
# The parent of test/ is the backend root, so this makes the `app` package
# importable regardless of the current working directory.
9 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
10 |
11 | # Configure pytest-asyncio
# Module-level mark: every test in this file is run as an asyncio test
# (requires pytest-asyncio). An explicit @pytest.mark.asyncio on a test is redundant.
12 | pytestmark = pytest.mark.asyncio
13 |
14 | from app.api.lwats.agents_async.SimpleSearchAgents.simple_search_agent import SimpleSearchAgent
15 | from app.api.lwats.core_async.config import AgentConfig
16 | from app.api.lwats.agents_async.SimpleSearchAgents.lats_node import LATSNode
17 |
@pytest.fixture
def mock_websocket():
    """WebSocket double whose awaited ``send_json`` calls are recorded for assertions."""
    # Mock keyword arguments configure attributes, so send_json is an attached AsyncMock.
    return AsyncMock(send_json=AsyncMock())
23 |
@pytest.fixture
def mock_config():
    """AgentConfig double with the settings the BFS depth test relies on."""
    settings = {
        "max_depth": 2,  # Set a small max depth for testing
        "search_algorithm": "bfs",
        "evaluation_model": "gpt-4",
        "browser_mode": "browserbase",
        "headless": True,
        "storage_state": None,
        "log_folder": "test_logs",
        "fullpage": False,
        "features": [],
        "elements_filter": [],
        "branching_factor": 2,
        "action_generation_model": "gpt-4",
        "action_grounding_model": "gpt-4",
    }
    config = MagicMock(spec=AgentConfig)
    for attribute, value in settings.items():
        setattr(config, attribute, value)
    return config
41 |
@pytest.fixture
def mock_playwright_manager():
    """Async browser-manager double; ``get_live_browser_url`` resolves to a fixed URL."""
    manager = AsyncMock()
    manager.configure_mock(
        get_page=AsyncMock(),
        close=AsyncMock(),
        get_live_browser_url=AsyncMock(return_value="http://test-url"),
    )
    return manager
49 |
# Checks that SimpleSearchAgent.bfs_with_websocket respects config.max_depth.
# expand() and _reset_browser() are patched out. The test then asserts that:
#   - at least one {"type": "node_terminal", "reason": "depth_limit"} message
#     is sent to the websocket;
#   - no node in the resulting tree is deeper than max_depth.
# NOTE(review): this module imports from agents_async.SimpleSearchAgents, while
# core_async/agent_factory.py imports SimpleSearchAgent from
# agents_async.SearchAgents -- confirm which package path is current.
50 | @pytest.mark.asyncio
51 | async def test_bfs_with_websocket_depth_limit(mock_websocket, mock_config, mock_playwright_manager):
52 | # Create a SimpleSearchAgent instance
53 | agent = SimpleSearchAgent(
54 | starting_url="http://test.com",
55 | messages=[],
56 | goal="test goal",
57 | images=[],
58 | playwright_manager=mock_playwright_manager,
59 | config=mock_config
60 | )
61 |
62 | # Create a simple tree structure for testing
# Two hand-built children of the root, each linked back via parent=root.
63 | root = agent.root_node
64 | child1 = LATSNode(
65 | natural_language_description="child1",
66 | action="action1",
67 | prob=1.0,
68 | element=None,
69 | goal="test goal",
70 | parent=root
71 | )
72 | child2 = LATSNode(
73 | natural_language_description="child2",
74 | action="action2",
75 | prob=1.0,
76 | element=None,
77 | goal="test goal",
78 | parent=root
79 | )
80 | root.children = [child1, child2]
81 |
82 | # Mock the expand method to simulate node expansion
# Stand-in for agent.expand. While node.depth < max_depth it replaces
# node.children with a single new child; at or beyond the limit it does nothing.
83 | async def mock_expand(node, websocket=None):
84 | if node.depth < mock_config.max_depth:
85 | new_child = LATSNode(
86 | natural_language_description=f"child{node.depth + 1}",
87 | action=f"action{node.depth + 1}",
88 | prob=1.0,
89 | element=None,
90 | goal="test goal",
91 | parent=node
92 | )
93 | node.children = [new_child]
94 |
95 | # Mock both expand and _reset_browser methods
96 | with patch.object(agent, 'expand', side_effect=mock_expand), \
97 | patch.object(agent, '_reset_browser', return_value="http://test-url"):
98 | # Run the bfs_with_websocket method
# NOTE(review): `result` is not asserted on below.
99 | result = await agent.bfs_with_websocket(mock_websocket)
100 |
101 | # Verify that the depth limit was respected
102 | # Check if any node beyond max_depth was processed
# Each call_args entry is (args, kwargs); call[0][0] is the message dict passed to send_json.
103 | depth_limit_messages = [
104 | call for call in mock_websocket.send_json.call_args_list
105 | if call[0][0].get("type") == "node_terminal" and
106 | call[0][0].get("reason") == "depth_limit"
107 | ]
108 |
109 | assert len(depth_limit_messages) > 0, "No depth limit messages were sent"
110 |
111 | # Verify the structure of the tree
# NOTE(review): expected_depth is passed down the recursion but never compared
# with node.depth; only the max_depth upper bound is actually checked.
112 | def check_node_depth(node, expected_depth):
113 | assert node.depth <= mock_config.max_depth, f"Node at depth {node.depth} exceeds max_depth"
114 | for child in node.children:
115 | check_node_depth(child, expected_depth + 1)
116 |
117 | check_node_depth(root, 0)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VisualTreeSearch
2 |
3 | A powerful web agent visualization tool that helps you understand and analyze web automation processes through visual tree search.
4 |
5 | **Disclaimer: Please note that VisualTreeSearch is not affiliated with any for-profit company. This is a collaborative project from PathOnAI.org, an open-source AI research community, where Danqing Zhang (danqing.zhang.personal@gmail.com) is the main contributor and lead author of the ECML-PKDD paper. If anyone claims VisualTreeSearch is affiliated with any for-profit company, please contact Danqing Zhang (danqing.zhang.personal@gmail.com) for verification.**
6 |
7 |
8 | ## News
9 | * 06/16/2025 - "VisualTreeSearch: Understanding Web Agent Test-time Scaling" was accepted by ECML-PKDD 2025.
10 | * 04/28/2025 - Released this Open Source Repo: [visual tree search](https://github.com/PathOnAIOrg/VisualTreeSearch-Demo).
11 |
12 |
13 | ## 🌐 Live Demo
14 |
15 | [▶ Watch the VisualTreeSearch demo video](https://www.youtube.com/embed/stRNDePQGV0)
16 |
17 | Visit our live demo at: [visual-tree-search.pathonai.org](https://visual-tree-search.pathonai.org)
18 |
19 | ## 🌟 Features
20 |
21 | - Interactive visualization of web agent actions
22 | - Real-time tree search visualization
23 | - Modern and responsive UI
24 | - Comprehensive web automation analysis
25 |
26 | ## 🛠️ Tech Stack
27 |
28 | ### Frontend
29 | - **Framework**: Next.js 14
30 | - **Styling**: TailwindCSS
31 | - **UI Components**: Shadcn UI
32 | - **Deployment**: Vercel
33 |
34 | ### Backend
35 | - **Framework**: FastAPI
36 | - **Deployment**: AWS ECS
37 |
38 | ### Browser Service
39 | - **Framework**: FastAPI
40 | - **Deployment**: AWS ECS
41 | - **Browser Engine**: Chromium (via Playwright)
42 |
43 | ### State Reset
44 | - **Framework**: FastAPI
45 | - **Deployment**: AWS EC2
46 | - **Database Access**: SQLAlchemy ORM connecting to MariaDB (using MySQL-compatible interface)
47 |
48 | ## 🚀 Getting Started
49 |
50 | ### Prerequisites
51 | - Node.js (Latest LTS version)
52 | - Python 3.9+ (the backend uses built-in generic annotations such as `list[...]`)
53 | - npm or yarn
54 | - Git
55 |
56 | ### Installation
57 |
58 | 1. Clone the repository
59 | ```bash
60 | git clone https://github.com/PathOnAI/VisualTreeSearch-Demo.git
61 | cd VisualTreeSearch-Demo
62 | ```
63 |
64 | 2. Backend Setup
65 | ```bash
66 | # Navigate to backend directory
67 | cd visual-tree-search-backend
68 |
69 | # Create and activate virtual environment (recommended)
70 | python -m venv venv
71 | source venv/bin/activate # On Windows use: venv\Scripts\activate
72 |
73 | # Install dependencies
74 | pip install -r requirements.txt
75 | pip install "uvicorn[standard]" # Install uvicorn with standard extras (quotes stop shells like zsh from expanding the brackets)
76 | ```
77 |
78 | 3. Frontend Setup
79 | ```bash
80 | # Navigate to frontend directory
81 | cd ../visual-tree-search-app
82 |
83 | # Install dependencies
84 | npm install
85 |
86 | # Create .env file
87 | echo "NEXT_PUBLIC_BACKEND_URL=http://127.0.0.1:3000" > .env
88 | ```
89 |
90 | ### Local Development
91 |
92 | #### Backend
93 | 1. Navigate to backend directory:
94 | ```bash
95 | cd visual-tree-search-backend
96 | ```
97 |
98 | 2. Activate virtual environment (if not already activated):
99 | ```bash
100 | source venv/bin/activate # On Windows use: venv\Scripts\activate
101 | ```
102 |
103 | 3. Run the FastAPI server:
104 | ```bash
105 | uvicorn app.main:app --host 0.0.0.0 --port 3000 --reload
106 | ```
107 |
108 | Note: The `--reload` flag enables auto-reload when code changes are detected. Remove it in production.
109 |
110 | #### Frontend
111 | 1. Open a new terminal and navigate to frontend directory:
112 | ```bash
113 | cd visual-tree-search-app
114 | ```
115 |
116 | 2. Start the development server:
117 | ```bash
118 | npm run dev -- -p 3001
119 | ```
120 |
121 | The application should now be running at:
122 | - Frontend: http://localhost:3001
123 | - Backend: http://localhost:3000
124 |
125 | ## 📝 Project Structure
126 |
127 | ```
128 | VisualTreeSearch-Demo/
129 | ├── visual-tree-search-app/ # Frontend application
130 | │ ├── src/ # Source code
131 | │ ├── public/ # Static files
132 | │ └── package.json # Frontend dependencies
133 | ├── visual-tree-search-backend/ # Backend API service
134 | │ ├── app/ # Backend source code
135 | │ ├── requirements.txt # Backend dependencies
136 | │ └── test/ # Test files
137 | ├── visual-tree-search-browser-service/ # Browser automation service
138 | └── visual-tree-search-state-reset/ # State management service
139 | ```
140 |
141 |
142 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/core_async/agent_factory.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import logging
4 | from dotenv import load_dotenv
5 | from openai import OpenAI
6 |
7 | from .config import AgentConfig
8 | from ..agents_async.SearchAgents.simple_search_agent import SimpleSearchAgent
9 | from ..agents_async.SearchAgents.lats_agent import LATSAgent
10 | from ..agents_async.SearchAgents.mcts_agent import MCTSAgent
11 | from ..webagent_utils_async.utils.utils import setup_logger
12 | from ..webagent_utils_async.utils.playwright_manager import setup_playwright
13 |
# NOTE(review): this sys.path change runs after the imports above, so it cannot
# affect how those (relative) imports resolve. Confirm whether anything still needs it.
14 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
15 |
# Load variables from a .env file into the process environment. Modules
# imported above that read the environment at import time will not see these values.
16 | _ = load_dotenv()
17 |
18 | logger = logging.getLogger(__name__)
# NOTE(review): openai_client is not referenced in this file's visible code.
19 | openai_client = OpenAI()
20 |
21 | # Define the default features
# NOTE(review): DEFAULT_FEATURES is not referenced in this file's visible code;
# presumably imported by callers -- verify.
22 | DEFAULT_FEATURES = ['screenshot', 'dom', 'axtree', 'focused_element', 'extra_properties', 'interactive_elements']
23 |
24 |
# System message placed first in the `messages` list that setup_search_agent
# passes to every agent it creates.
25 | SEARCH_AGENT_SYSTEM_PROMPT = \
26 | """You are a web search agent designed to perform specific tasks on web pages as instructed by the user. Your primary objectives are:
27 |
28 | 1. Execute ONLY the task explicitly provided by the user.
29 | 2. Perform the task efficiently and accurately using the available functions.
30 | 3. If there are errors, retry using a different approach within the scope of the given task.
31 | 4. Once the current task is completed, stop and wait for further instructions.
32 |
33 | Critical guidelines:
34 | - Strictly limit your actions to the current task. Do not attempt additional tasks or next steps.
35 | - Use only the functions provided to you. Do not attempt to use functions or methods that are not explicitly available.
36 | - For navigation or interaction with page elements, always use the appropriate bid (browser element ID) when required by a function.
37 | - Do not try to navigate to external websites or use URLs directly.
38 | - If a task cannot be completed with the available functions, report the limitation rather than attempting unsupported actions.
39 | - After completing a task, report its completion and await new instructions. Do not suggest or initiate further actions.
40 |
41 | Remember: Your role is to execute the given task precisely as instructed, using only the provided functions and within the confines of the current web page. Do not exceed these boundaries under any circumstances."""
42 |
async def setup_search_agent(
    agent_type,
    starting_url,
    goal,
    images,
    agent_config: AgentConfig
):
    """Create a search agent of the requested type, backed by a fresh browser session.

    The agent type is validated first. The function then:
      1. Writes the goal and starting URL to ``<log_folder>/flow/steps.json``.
      2. Launches a browser via ``setup_playwright`` and opens ``starting_url``.
      3. Instantiates the agent with the search-agent system prompt.

    Args:
        agent_type: One of ``"SimpleSearchAgent"``, ``"LATSAgent"`` or ``"MCTSAgent"``.
        starting_url: URL the browser opens before the agent starts.
        goal: Natural-language task for the agent.
        images: Images passed through to the agent constructor.
        agent_config: Supplies ``log_folder``, ``headless``, ``browser_mode`` and
            ``storage_state`` here, and is handed to the agent.

    Returns:
        ``(agent, playwright_manager)`` on success, or ``{"error": message}``
        when ``agent_type`` is not supported. In the error case no browser is
        launched and no files are written.
    """
    logger = setup_logger()

    agent_classes = {
        "SimpleSearchAgent": SimpleSearchAgent,
        "LATSAgent": LATSAgent,
        "MCTSAgent": MCTSAgent,
    }
    agent_cls = agent_classes.get(agent_type)
    if agent_cls is None:
        # Validate before launching a browser, so an unsupported type cannot
        # leave an orphaned browser session behind.
        supported = ", ".join(f"'{name}'" for name in agent_classes)
        error_message = f"Unsupported agent type: {agent_type}. Please use one of: {supported}."
        logger.error(error_message)
        return {"error": error_message}

    file_path = os.path.join(agent_config.log_folder, 'flow', 'steps.json')
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(goal + '\n')
        file.write(starting_url + '\n')

    playwright_manager = await setup_playwright(
        headless=agent_config.headless,
        mode=agent_config.browser_mode,
        storage_state=agent_config.storage_state
    )

    page = await playwright_manager.get_page()
    await page.goto(starting_url)

    messages = [{
        "role": "system",
        "content": SEARCH_AGENT_SYSTEM_PROMPT,
    }]

    logger.info("Creating %s", agent_type)
    agent = agent_cls(
        starting_url=starting_url,
        messages=messages,
        goal=goal,
        images=images,
        playwright_manager=playwright_manager,
        config=agent_config,
    )
    return agent, playwright_manager
--------------------------------------------------------------------------------
/visual-tree-search-state-reset/Server_Setup.md:
--------------------------------------------------------------------------------
1 | # Setup
2 |
3 | ### Format disk
4 |
5 |
6 | ```
7 | sudo fdisk -l
8 | sudo mkfs.ext4 /dev/xvdk -F # WARNING: erases the disk; replace /dev/xvdk with the empty data volume shown by fdisk -l
9 | mkdir ~/data0
10 | sudo mount -t ext4 /dev/xvdk /home/ubuntu/data0 -o defaults,nodelalloc,noatime
11 | sudo chmod o+w /home/ubuntu/data0
12 | sudo chown ubuntu /home/ubuntu/data0
13 | cd /home/ubuntu/data0
14 | ```
15 |
16 |
17 |
18 | ### Firewall (optional)
19 |
20 | ```
21 | sudo ufw allow 22/tcp && echo y | sudo ufw enable # allow SSH first so enabling ufw cannot lock you out of the server
22 | sudo systemctl start ufw
23 | sudo systemctl enable ufw
24 | sudo ufw allow from ****YOURIP****
25 | sudo ufw status
26 |
27 |
28 | sudo ufw allow 7770
29 | sudo ufw allow 7780
30 | sudo ufw allow 9999
31 | sudo ufw allow 8023
32 | sudo ufw allow 3000
33 | sudo ufw allow 8888
34 | sudo ufw allow 4399
35 | sudo ufw allow 9980
36 | sudo ufw allow 4040
37 | sudo ufw status
38 |
39 | ```
40 |
41 |
42 | ### docker
43 |
44 | ```
45 | # Add Docker's official GPG key:
46 | sudo apt-get update
47 | sudo apt-get install ca-certificates curl -y
48 | sudo install -m 0755 -d /etc/apt/keyrings
49 | sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
50 | sudo chmod a+r /etc/apt/keyrings/docker.asc
51 | # Add the repository to Apt sources:
52 | echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
53 | $(. /etc/os-release && echo "$VERSION_CODENAME") stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
54 | sudo apt-get update
55 | sudo apt-get install docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin -y
56 | sudo docker run hello-world
57 |
58 |
59 | # docker without sudo
60 | getent group docker # Check if the docker group exists. If it doesn't exist, create it: sudo groupadd docker
61 | sudo usermod -aG docker ubuntu
62 | sudo systemctl restart docker
63 | newgrp docker
64 | docker run hello-world
65 |
66 | sudo apt install htop iotop -y
67 | ```
68 |
69 | ### Move the Docker data directory
70 |
71 | Move the Docker data directory to another filesystem to save space on the system disk
72 |
73 | (e.g., `/home/ubuntu/data0/docker`)
74 |
75 |
76 | ```
77 | sudo systemctl stop docker
78 | sudo systemctl status docker # make sure docker service is stopped
79 |
80 | mkdir /home/ubuntu/data0/docker
81 |
82 | # Copy the Data to the New Location: Use rsync or cp to preserve file permissions and symbolic links
83 | sudo rsync -aHAX /var/lib/docker/ /home/ubuntu/data0/docker/
84 | # Rename the Old Directory (as a backup)
85 | sudo mv /var/lib/docker /var/lib/docker.bak
86 |
87 | ```
88 |
89 | Then edit the Docker configuration file (e.g., /etc/docker/daemon.json), creating it if it does not exist:
90 | `sudo vi /etc/docker/daemon.json`
91 |
92 | ```
93 | {
94 | "data-root": "/home/ubuntu/data0/docker/"
95 | }
96 | ```
97 |
98 | ```
99 | sudo systemctl restart docker
100 |
101 | # Verify Docker is running:
102 | sudo systemctl status docker
103 |
104 | # Check if containers and images are accessible:
105 | docker ps -a
106 | docker images
107 |
108 | # Start any necessary containers:
109 | docker start <container_name> # e.g. docker start shopping
110 |
111 | # Once everything works as expected, you can delete the backup:
112 | sudo rm -rf /var/lib/docker.bak
113 | ```
114 |
115 | ## Set up WebArena website server
116 |
117 | Replace `****PUBLICIP****` below with the public IP address or hostname of the server.
118 |
119 | ```
120 | wget http://metis.lti.cs.cmu.edu/webarena-images/shopping_final_0712.tar
121 | ```
122 |
123 | ### Shopping Website (OneStopShop)
124 |
125 | ```
126 | docker load --input shopping_final_0712.tar
127 | docker run --name shopping -p 7770:80 -d shopping_final_0712
128 | # wait ~1 minute for all services to start
129 |
130 | docker exec shopping /var/www/magento2/bin/magento setup:store-config:set --base-url="http://****PUBLICIP****:7770" # no trailing slash
131 | docker exec shopping mysql -u magentouser -pMyPassword magentodb -e 'UPDATE core_config_data SET value="http://****PUBLICIP****:7770/" WHERE path = "web/secure/base_url";'
132 | docker exec shopping /var/www/magento2/bin/magento cache:flush
133 | ```
134 |
135 | ### Allow remote connection to the DB inside docker container
136 |
137 |
138 | ```
139 | docker exec -it shopping mysql -u root -p
140 | # Then enter the password (1234567890) when prompted
141 |
142 | # Then run:
143 | CREATE USER 'root'@'172.17.0.1' IDENTIFIED BY '1234567890';
144 | GRANT ALL PRIVILEGES ON *.* TO 'root'@'172.17.0.1' WITH GRANT OPTION;
145 | FLUSH PRIVILEGES;
146 | ```
147 |
148 |
149 | ### Default DB information
150 |
151 | ```
152 | # Connection parameters
153 | host = '172.17.0.2'
154 | port = 3306
155 | user = 'root'
156 | password = '1234567890'
157 | database = 'magentodb'
158 | ```
159 |
160 |
161 | ### Default user information
162 |
163 |
164 | ```
165 | "username": "emma.lopez@gmail.com",
166 | "password": "Password.123",
167 | ```
168 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/webagent_utils_async/tools/webscraping.py:
--------------------------------------------------------------------------------
1 | from bs4 import BeautifulSoup
2 | from urllib.parse import urljoin, urlparse
3 | import time
4 | import random
5 | from .registry import ToolRegistry, Tool
6 |
async def webscraping(task_description, features=None, branching_factor=None, playwright_manager=None, log_folder='log', elements_filter=None):
    """Scrape structured content from the page currently open in the browser.

    This module belongs to the async tool set. As in ``take_action``,
    ``playwright_manager.get_page()`` is awaited here. A synchronous call would
    receive a coroutine instead of a page and fail on every attempt.

    Each attempt:
      1. Waits for network idle.
      2. Retries after 5 s while a Cloudflare interstitial is showing.
      3. Otherwise parses the HTML with BeautifulSoup, after removing
         ``aside``/``header``/``nav``/``footer`` elements.

    Args:
        task_description: Description of the scraping task (not used by the scraper).
        features, branching_factor, log_folder, elements_filter: Accepted for
            signature compatibility with the other tools; unused here.
        playwright_manager: Provides an awaitable ``get_page()`` returning a
            Playwright async ``Page``.

    Returns:
        A dict with ``url``, ``title``, ``main_content``, ``paragraphs``,
        ``headings``, ``meta_data``, ``internal_links`` and ``formatted_content``.
        Returns ``None`` if every attempt still saw a Cloudflare page.

    Raises:
        The last exception from loading or parsing the page, once all
        ``max_retries`` attempts have failed.
    """
    import asyncio  # non-blocking sleeps; time.sleep would stall the event loop

    max_retries = 3
    page = await playwright_manager.get_page()

    for attempt in range(max_retries):
        try:
            # Wait for potential Cloudflare challenge to be solved
            await page.wait_for_load_state('networkidle')

            # Check if we're still on a Cloudflare page
            title = await page.title()
            if 'cloudflare' in title.lower():
                print(f"Cloudflare detected, waiting... (Attempt {attempt + 1})")
                await asyncio.sleep(5)  # Wait for 5 seconds before retrying
                continue

            current_url = page.url
            page_content = await page.content()
            soup = BeautifulSoup(page_content, 'html.parser')

            # Remove unwanted elements
            for element in soup.select('aside, header, nav, footer'):
                element.decompose()

            return {
                'url': current_url,
                'title': title,
                'main_content': get_main_content(soup),
                'paragraphs': get_paragraphs(soup),
                'headings': get_headings(soup),
                'meta_data': get_meta_data(soup),
                'internal_links': get_internal_links(soup, current_url),
                'formatted_content': get_formatted_content(soup)
            }

        except Exception as e:
            print(f"Error occurred: {e}")
            if attempt == max_retries - 1:
                raise
            # Jittered back-off before the next attempt
            await asyncio.sleep(random.uniform(1, 3))

    # Every attempt was blocked by a Cloudflare page.
    return None
52 |
53 |
def get_main_content(soup):
    """Return the text of the element whose id is "main", or a placeholder if absent."""
    node = soup.find(id='main')
    if not node:
        return "Main content not found"
    return node.text
57 |
58 |
def get_paragraphs(soup):
    """Return the text of every <p> element, in document order."""
    paragraph_nodes = soup.find_all('p')
    return [node.text for node in paragraph_nodes]
61 |
62 |
def get_headings(soup):
    """Return the text of all <h1>..<h6> elements, in document order."""
    heading_tags = [f'h{level}' for level in range(1, 7)]
    return [heading.text for heading in soup.find_all(heading_tags)]
65 |
66 |
def get_meta_data(soup):
    """Map each <meta> tag's ``name`` (or ``property``) to its ``content``.

    Tags with an empty or missing key or content are skipped. When a key
    repeats, the later tag wins.
    """
    pairs = (
        (tag.get('name') or tag.get('property'), tag.get('content'))
        for tag in soup.find_all('meta')
    )
    return {key: value for key, value in pairs if key and value}
75 |
76 |
def get_internal_links(soup, url):
    """Return the links in *soup* that point to the same site as *url*.

    A link is internal when it is either:
      * root-relative (``/path``), returned resolved to an absolute URL, or
      * an absolute URL whose scheme and host (including the port) match *url*,
        returned unchanged.
    Other relative links (``page.html``) are ignored, as before.

    Args:
        soup: Parsed page (BeautifulSoup) to search for ``<a href>`` tags.
        url: URL of the page the soup came from.

    Returns:
        A list of internal link URLs in document order, duplicates kept. It is
        empty when *url* has no host (e.g. ``about:blank``).
    """
    page = urlparse(url)
    if not page.netloc:
        return []
    # netloc keeps the port (e.g. "host:7770"); .hostname would drop it and
    # yield links to the wrong server.
    page_host = page.netloc.lower()
    base_url = f"{page.scheme}://{page.netloc}"

    internal_links = []
    for a_tag in soup.find_all('a', href=True):
        href = a_tag['href']
        if href.startswith('/'):
            full_url = urljoin(base_url, href)
            # "//other.host/x" is protocol-relative and may leave the site.
            if urlparse(full_url).netloc.lower() == page_host:
                internal_links.append(full_url)
        else:
            target = urlparse(href)
            # Compare parsed hosts rather than a string prefix, so that
            # "http://site.com.evil.net" is not mistaken for "http://site.com".
            if target.scheme == page.scheme and target.netloc.lower() == page_host:
                internal_links.append(href)

    return internal_links
90 |
91 |
def get_formatted_content(soup):
    """Return the page's readable text blocks, de-duplicated, one per line.

    Tags are walked in document order. A heading (h1-h6), paragraph or list
    is kept only if:
      * neither it nor any ancestor carries one of the ``ch2-*`` classes below
        (presumably a consent/cookie widget -- confirm), and
      * it contains no direct child of those same kinds, so the innermost
        block wins.
    Each kept block's stripped text is emitted once (compared
    case-insensitively), followed by a space and a newline. Empty blocks are
    skipped.
    """
    block_tags = {'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'p', 'ul', 'ol'}
    excluded_classes = {'ch2-container', 'ch2-theme-bar', 'ch2-style-light',
                        'ch2-dialog', 'ch2-dialog-bottom', 'ch2-visible',
                        'ch2-settings', 'ch2-settings-scan'}

    def has_excluded_class(node):
        return any(name in excluded_classes for name in node.get('class', []))

    seen = set()
    pieces = []
    for tag in soup.find_all(True):
        if tag.name not in block_tags:
            continue
        if has_excluded_class(tag) or any(has_excluded_class(ancestor) for ancestor in tag.parents):
            continue
        if any(child.name in block_tags for child in tag.children):
            continue
        text = tag.get_text().strip()
        key = text.lower()
        if not text or key in seen:
            continue
        seen.add(key)
        pieces.append(f"{text} \n")

    return ''.join(pieces)
112 |
113 |
def register_webscraping_tool():
    """Register ``webscraping`` with ``ToolRegistry`` under the name "webscraping"."""
    parameters = {
        "task_description": {
            "type": "string",
            "description": "The description of the webscraping task",
        }
    }
    tool = Tool(
        name="webscraping",
        func=webscraping,
        description="Scrape content from the current web page",
        parameters=parameters,
    )
    ToolRegistry.register(tool)
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/tree_vis.py:
--------------------------------------------------------------------------------
1 | """Utilities for visualizing LATS tree structures."""
2 |
3 | from typing import Optional
4 | from .lats_node import LATSNode
5 |
# ANSI escape sequences used to colour terminal output; RESET restores the default.
GREEN, RED, RESET = '\033[92m', '\033[91m', '\033[0m'
10 |
def collect_all_nodes(node: LATSNode) -> list[LATSNode]:
    """
    Return *node* and every descendant in depth-first pre-order.

    Children are visited in the order ``node.children`` yields them.

    Args:
        node: The root node to start collection from

    Returns:
        list[LATSNode]: List of all nodes in the tree
    """
    collected = []
    pending = [node]
    while pending:
        current = pending.pop()
        collected.append(current)
        # Push children reversed so the first child is popped (visited) first,
        # reproducing the recursive pre-order without recursion-depth limits.
        pending.extend(list(current.children)[::-1])
    return collected
25 |
def better_print(node: LATSNode, level: int = 0, selected_node: Optional[LATSNode] = None) -> None:
    """
    Recursively print *node* and its subtree, one indented line per node.

    Each line shows the node's action (newlines stripped when it is a string)
    and its visits/value/reward. The node equal to *selected_node* is printed
    in green with a "← Selected" marker.

    Args:
        node: The node to print
        level: Current indentation level (default=0)
        selected_node: The currently selected node to highlight
    """
    label = 'None' if node.action is None else node.action
    if isinstance(label, str):
        label = label.replace('\n', '')

    parts = [f"visits: {node.visits}"]
    parts.append(f"value: {node.value:.3f}" if hasattr(node, 'value') else "value: N/A")
    parts.append(f"reward: {node.reward:.3f}" if hasattr(node, 'reward') else "reward: N/A")
    stats = "[" + ", ".join(parts) + "]"

    head = f"{' ' * level}├── Level {level}: "
    is_selected = node == selected_node
    if is_selected:
        print(f"{head}{GREEN}{label}{RESET} {stats} ← Selected")
    else:
        print(f"{head}{label} {stats}")

    for child in node.children:
        better_print(child, level + 1, selected_node)
53 |
def print_trajectory(terminal_node: LATSNode) -> None:
    """
    Print the root-to-leaf path that ends at *terminal_node*.

    The path is recovered by following ``parent`` links, then printed from
    the root down, one indented line per node with its statistics. The
    starting node is tagged "← Terminal"; a parentless node is tagged
    "(Root)". Nothing is printed when *terminal_node* is None.

    Args:
        terminal_node: The leaf node to start the trajectory from
    """
    # Walk up to the root, then flip so the root prints first.
    lineage = []
    cursor = terminal_node
    while cursor is not None:
        lineage.append(cursor)
        cursor = cursor.parent
    lineage.reverse()

    for depth, node in enumerate(lineage):
        fields = [
            f"visits: {node.visits}",
            f"value: {node.value:.3f}" if hasattr(node, 'value') else "value: N/A",
            f"reward: {node.reward:.3f}" if hasattr(node, 'reward') else "reward: N/A",
            f"terminal: {node.is_terminal}",
            f"feedback: {node.feedback if node.feedback else 'N/A'}",
        ]
        stats = "[" + ", ".join(fields) + "]"

        if node == terminal_node:
            marker = "← Terminal"
        elif not hasattr(node, 'parent') or node.parent is None:
            marker = "(Root)"
        else:
            marker = ""

        print(f"{' ' * depth}├── Level {depth}: {GREEN}{node.action}{RESET} {stats} {marker}")
85 |
def print_entire_tree(root: LATSNode) -> None:
    """
    Print every node of the tree rooted at *root* with box-drawing connectors.

    Nodes are printed depth-first; siblings are ordered by descending
    ``visits`` (ties keep their original order). Each line shows the node's
    ``id()``, depth, action and statistics. Childless nodes are tagged
    "← Terminal", and the depth-0 node (when it has children) "(Root)".

    Args:
        root: The root node of the tree to print
    """
    # Explicit stack of (node, depth, inherited prefix, is-last-sibling) frames
    # replaces the recursive helper; children are pushed in reverse so the
    # highest-visit child is printed first.
    stack = [(root, 0, "", True)]
    while stack:
        node, level, prefix, is_last = stack.pop()

        connector = "└── " if is_last else "├── "
        visits = f"visits: {node.visits}"
        value = f"value: {node.value:.3f}" if hasattr(node, 'value') else "value: N/A"
        reward = f"reward: {node.reward:.3f}" if hasattr(node, 'reward') else "reward: N/A"
        is_terminal = f"terminal: {node.is_terminal}"
        feedback = f"feedback: {node.feedback if node.feedback else 'N/A'}"
        stats = f"[{visits}, {value}, {reward}, {is_terminal}, {feedback}]"

        if not node.children:
            indicator = "← Terminal"
        elif level == 0:
            indicator = "(Root)"
        else:
            indicator = ""

        print(f"{prefix}{connector}id: {id(node)} Level {level}: {GREEN}{node.action}{RESET} {stats} {indicator}")

        child_prefix = prefix + (" " if is_last else "│ ")
        ordered = sorted(node.children, key=lambda x: x.visits, reverse=True) if node.children else []
        last = len(ordered) - 1
        for index, child in reversed(list(enumerate(ordered))):
            stack.append((child, level + 1, child_prefix, index == last))
--------------------------------------------------------------------------------
/visual-tree-search-app/components/Sidebar.tsx:
--------------------------------------------------------------------------------
1 | import Link from 'next/link';
2 | import { useRouter } from 'next/router';
3 | import { Home, LayoutDashboard, Network, Search, ChevronLeft, ChevronRight } from 'lucide-react';
4 | import { Button } from "@/components/ui/button";
5 | import { ScrollArea } from "@/components/ui/scroll-area";
6 | import { useState, useEffect } from 'react';
7 |
8 | const Sidebar = () => {
9 | const router = useRouter();
10 | const [collapsed, setCollapsed] = useState(false);
11 | const [isMobile, setIsMobile] = useState(false);
12 |
useEffect(() => {
  // Write the CSS variable that the layout reads for the sidebar width.
  const applyWidth = (isCollapsed: boolean) =>
    document.documentElement.style.setProperty('--sidebar-width', isCollapsed ? '3.5rem' : '14rem');

  // Track the <1024px breakpoint and auto-collapse/expand only when the
  // breakpoint is actually crossed, so a manual toggle is not overridden.
  const handleResize = () => {
    const nowMobile = window.innerWidth < 1024;
    setIsMobile(nowMobile);
    if (nowMobile && !isMobile) {
      setCollapsed(true);
      applyWidth(true);
    } else if (!nowMobile && isMobile) {
      setCollapsed(false);
      applyWidth(false);
    }
  };

  // Sync the width with the current state, then evaluate the screen size once.
  applyWidth(collapsed);
  handleResize();

  window.addEventListener('resize', handleResize);
  return () => window.removeEventListener('resize', handleResize);
}, [isMobile, collapsed]);
36 |
// Navigation entries shown in the sidebar: label, route and lucide icon.
const menuItems = [
  { name: 'Home', href: '/', icon: Home },
  { name: 'MCTS', href: '/MCTSAgent', icon: LayoutDashboard },
  { name: 'LATS', href: '/LATSAgent', icon: Network },
  { name: 'BFS/DFS', href: '/SimpleSearchAgent', icon: Search },
];
59 |
// '/' must match exactly (every path starts with '/'); other routes also
// match their sub-paths via a prefix check.
const isActive = (path: string) =>
  path === '/' ? router.pathname === '/' : router.pathname.startsWith(path);
66 |
// Flip the collapsed state and update the sidebar-width CSS variable to match.
const toggleSidebar = () => {
  // `collapsed` in this closure still holds the pre-toggle value, so compute the
  // next state once. The old code applied the width for the *previous* state and
  // relied on the effect re-running to correct it, which caused a visible flash.
  const nextCollapsed = !collapsed;
  setCollapsed(nextCollapsed);
  document.documentElement.style.setProperty('--sidebar-width', nextCollapsed ? '3.5rem' : '14rem');
};
71 |
72 | return (
73 |
76 |
77 |
78 | {!collapsed &&
VisualTreeSearch
}
79 |
87 |
88 |
89 |
92 |
93 |
94 |
120 |
121 |
122 |
123 |
124 | {!collapsed && (
125 |
126 | VisualTreeSearch Demo
127 |
128 | )}
129 |
130 |
131 |
132 | );
133 | };
134 |
135 | export default Sidebar;
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/webagent_utils_async/action/base.py:
--------------------------------------------------------------------------------
1 | # copied and modified from https://github.com/ServiceNow/BrowserGym
2 | import playwright.async_api
3 | from abc import ABC, abstractmethod
4 | import ast
5 | import sys
6 | import os
7 | import importlib.util
8 | import logging
9 | from typing import Any, Callable, Optional, Tuple
10 | from pathlib import Path
11 | from datetime import datetime
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
class AbstractActionSet(ABC):
    """Interface that every action space implementation must provide.

    Args:
        strict: stored as ``self.strict``; its meaning is left to subclasses
            (presumably stricter action parsing — confirm in implementations).
    """

    def __init__(self, strict: bool = False):
        self.strict = strict

    @abstractmethod
    def describe(self, with_long_description: bool = True, with_examples: bool = True) -> str:
        """Return a human-readable description of this action space.

        Args:
            with_long_description: include the long-form description.
            with_examples: include example actions.
        """

    @abstractmethod
    def example_action(self, abstract: bool) -> str:
        """Return one example action, formatted as a string."""

    @abstractmethod
    def to_python_code(self, action) -> str:
        """Translate *action* into browsergym-compatible Python source.

        Args:
            action: the action to convert.

        Returns:
            Executable Python code that performs the action in a browsergym
            environment.
        """
43 |
44 |
def validate_python_syntax(code: str) -> Tuple[bool, str]:
    """
    Check whether *code* parses as Python, without executing it.

    Args:
        code: String containing Python code

    Returns:
        ``(True, "")`` when the code parses; otherwise ``(False, message)``
        where the message gives the line/column of a ``SyntaxError``, or the
        text of any other exception raised by the parser.
    """
    try:
        ast.parse(code)
    except SyntaxError as err:
        return False, f"Syntax error at line {err.lineno}, column {err.offset}: {err.msg}"
    except Exception as err:
        return False, f"Parsing error: {str(err)}"
    return True, ""
63 |
64 |
def save_code_to_file(code: str, log_folder: str) -> str:
    """Save *code* under ``<log_folder>/code/`` and return the new file's path.

    The file is named ``code_<YYYYmmdd_HHMMSS_ffffff>.py`` and starts with a
    comment header (title, ISO timestamp, file name) followed by a blank line.
    If a file with that name already exists (two calls within the clock's
    resolution), a numeric suffix is appended instead of overwriting it.

    Args:
        code: Python source to write.
        log_folder: directory under which the ``code`` subfolder is created.

    Returns:
        Path of the written file.
    """
    # Take a single timestamp so the filename and the header always agree
    # (previously two separate datetime.now() calls could differ).
    now = datetime.now()
    code_logs_dir = os.path.join(log_folder, "code")
    os.makedirs(code_logs_dir, exist_ok=True)

    stem = f"code_{now.strftime('%Y%m%d_%H%M%S_%f')}"
    attempt = 0
    while True:
        filename = f"{stem}.py" if attempt == 0 else f"{stem}_{attempt}.py"
        file_path = os.path.join(code_logs_dir, filename)
        header = (
            "# Generated Code\n"
            f"# Timestamp: {now.isoformat()}\n"
            f"# File: {filename}\n"
        )
        try:
            # 'x' mode fails rather than silently overwriting an existing file.
            with open(file_path, 'x', encoding='utf-8') as f:
                f.write(header + '\n' + code)
            break
        except FileExistsError:
            attempt += 1

    logger.info("Saved code to: %s", file_path)
    return file_path
83 |
84 |
async def execute_python_code(
    code: str,
    page: 'playwright.async_api.Page',
    context,
    send_message_to_user: callable,
    report_infeasible_instructions: callable,
):
    """
    Run *code* as the body of an ``async`` function, so it may use ``await``.

    The code sees ``page``, ``context``, ``send_message_to_user`` and
    ``report_infeasible_instructions`` as globals. It is executed with no
    sandboxing, so only pass code you are prepared to run in this process.

    Args:
        code: the Python code to execute, as a string.
        page: the playwright page that will be made accessible to the code.
        context: object made accessible to the code as ``context``.
        send_message_to_user: utility function that will be made accessible to the code.
        report_infeasible_instructions: utility function that will be made accessible to the code.

    Returns:
        The value of a top-level ``return`` statement in *code*, else ``None``.

    Note:
        Every line is indented by four spaces to form the function body, so
        multi-line string literals inside *code* gain that indentation as well.
    """
    # Named to avoid shadowing the builtin ``globals``.
    namespace = {
        "page": page,
        "context": context,
        "send_message_to_user": send_message_to_user,
        "report_infeasible_instructions": report_infeasible_instructions,
    }

    body = "\n".join("    " + line for line in code.splitlines())
    # The trailing ``pass`` keeps the function valid when *code* is empty or
    # contains only comments (previously an IndentationError).
    wrapper = f"async def __ex():\n{body}\n    pass"

    local_ns = {}
    exec(wrapper, namespace, local_ns)
    return await local_ns['__ex']()
119 |
120 |
async def execute_python_code_safely(
    code: str,
    page: 'playwright.async_api.Page',
    context: Any,
    log_folder: str,
    send_message_to_user: Optional[Callable[[str], None]] = None,
    report_infeasible_instructions: Optional[Callable[[str], None]] = None
) -> str:
    """Save *code* to a log file, then execute it as a module.

    ``page``, ``context``, ``send_message_to_user`` and
    ``report_infeasible_instructions`` are injected as module globals. The
    code is compiled with top-level ``await`` enabled, so it may ``await``
    page operations directly; if it does, the resulting coroutine is awaited.
    The code directory is put on ``sys.path`` for the duration of the run.

    Returns:
        Path of the saved code file.

    Raises:
        ImportError: if no module spec can be built for the saved file.
        Exception: anything raised by the generated code (logged, then re-raised).
    """
    import inspect  # local: only this function needs it

    file_path = save_code_to_file(code, log_folder)
    code_dir = os.path.dirname(file_path)

    sys.path.insert(0, code_dir)
    try:
        spec = importlib.util.spec_from_file_location("generated_code", file_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Could not load spec for {file_path}")

        module = importlib.util.module_from_spec(spec)

        # Set the global variables in the module
        module.page = page
        module.context = context
        module.send_message_to_user = send_message_to_user
        module.report_infeasible_instructions = report_infeasible_instructions

        # The previous ``await spec.loader.exec_module(module)`` awaited the
        # ``None`` returned by the synchronous exec_module, raising TypeError
        # after every successful run, and module code could not use ``await``.
        # Compile with top-level await instead; such code evaluates to a coroutine.
        source = Path(file_path).read_text(encoding='utf-8')
        compiled = compile(
            source,
            file_path,
            "exec",
            flags=ast.PyCF_ALLOW_TOP_LEVEL_AWAIT,
            dont_inherit=True,
        )
        result = eval(compiled, module.__dict__)
        if inspect.iscoroutine(result):
            await result

    except Exception as e:
        logger.error("Error executing code: %s", e)
        raise

    finally:
        # Remove the directory from sys.path
        if code_dir in sys.path:
            sys.path.remove(code_dir)

    return file_path
164 |
165 |
--------------------------------------------------------------------------------
/visual-tree-search-state-reset/app/api/routes/test_db.py:
--------------------------------------------------------------------------------
import json
import logging
import os
import traceback
from typing import Any, Dict, List

from fastapi import APIRouter, HTTPException
from sqlalchemy import create_engine, text
8 |
# Root logging configuration, applied when this module is imported.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

router = APIRouter()
16 |
# Database connection parameters of the webarena container. Each value can be
# overridden through the environment so credentials need not live in source;
# the defaults reproduce the previously hard-coded values.
DB_CONFIG = {
    "username": os.environ.get("WEBARENA_DB_USERNAME", "root"),
    "password": os.environ.get("WEBARENA_DB_PASSWORD", "1234567890"),
    "database": os.environ.get("WEBARENA_DB_NAME", "magentodb"),
    "host": os.environ.get("WEBARENA_DB_HOST", "172.17.0.2"),
    "port": int(os.environ.get("WEBARENA_DB_PORT", "3306")),
}
25 |
def get_db_connection():
    """Create and return a SQLAlchemy engine for the database in ``DB_CONFIG``.

    Returns:
        Engine using the ``mysql+pymysql`` driver with ``connect_timeout=30``.

    Raises:
        HTTPException: status 500 if the engine cannot be created.
    """
    try:
        from sqlalchemy.engine import URL  # local: only needed to build the DSN

        # URL.create escapes special characters in the credentials; the former
        # f-string DSN broke on passwords containing '@', '/', ':' etc.
        url = URL.create(
            "mysql+pymysql",
            username=DB_CONFIG['username'],
            password=DB_CONFIG['password'],
            host=DB_CONFIG['host'],
            port=int(DB_CONFIG['port']),
            database=DB_CONFIG['database'],
        )
        # Never write the plaintext password to the logs.
        logging.info(f"Creating connection with: {url.render_as_string(hide_password=True)}")

        engine = create_engine(
            url,
            connect_args={'connect_timeout': 30},
            echo=False  # Set to True for SQL debugging
        )
        return engine
    except Exception as e:
        logging.error(f"Error creating database engine: {str(e)}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Database connection error: {str(e)}") from e
52 | return {
53 | "status": "connected",
54 | "version": version,
55 | "config": {
56 | "host": DB_CONFIG["host"],
57 | "port": DB_CONFIG["port"],
58 | "database": DB_CONFIG["database"],
59 | "username": DB_CONFIG["username"]
60 | }
61 | }
62 | except Exception as e:
63 | logging.error(f"Database status check failed: {str(e)}")
64 | return {
65 | "status": "error",
66 | "message": str(e),
67 | "config": {
68 | "host": DB_CONFIG["host"],
69 | "port": DB_CONFIG["port"],
70 | "database": DB_CONFIG["database"],
71 | "username": DB_CONFIG["username"]
72 | }
73 | }
74 |
@router.get("/tables")
async def list_tables():
    """Return the names of all tables (``SHOW TABLES``) and their count.

    Raises:
        HTTPException: 500 on any failure, including connection errors.
    """
    try:
        engine = get_db_connection()
        with engine.connect() as connection:
            tables = [row[0] for row in connection.execute(text("SHOW TABLES"))]
        return {
            "status": "success",
            "table_count": len(tables),
            "tables": tables,
        }
    except Exception as e:
        logging.error(f"Error listing tables: {str(e)}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Error listing tables: {str(e)}")
96 |
@router.get("/tables/{table_name}/structure")
async def get_table_structure(table_name: str):
    """Return the column definitions (``DESCRIBE``) of *table_name*.

    Raises:
        HTTPException: 404 if the table does not exist; 500 on any other failure.
    """
    try:
        engine = get_db_connection()

        with engine.connect() as connection:
            # Only names reported by SHOW TABLES are accepted (set for O(1) lookup).
            existing = {row[0] for row in connection.execute(text("SHOW TABLES"))}
            if table_name not in existing:
                raise HTTPException(status_code=404, detail=f"Table '{table_name}' not found")

            # Identifiers cannot be bind parameters, so quote as a MySQL
            # identifier: unquoted, a table named after a reserved word
            # (e.g. `order`) made DESCRIBE fail with a syntax error.
            quoted_name = "`" + table_name.replace("`", "``") + "`"
            columns_result = connection.execute(text(f"DESCRIBE {quoted_name}"))
            columns = [
                {
                    "field": row[0],
                    "type": row[1],
                    "null": row[2],
                    "key": row[3],
                    "default": row[4],
                    "extra": row[5],
                }
                for row in columns_result
            ]

        return {
            "status": "success",
            "table": table_name,
            "columns": columns,
        }
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Error getting table structure: {str(e)}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Error getting table structure: {str(e)}")
135 |
@router.get("/query")
async def run_test_query(query: str = "SELECT VERSION()"):
    """Execute *query* and return every row as a dict.

    Only statements whose text begins with ``SELECT`` (case-insensitive,
    after stripping whitespace) are accepted. This prefix test is a
    convenience rather than a security boundary: anything starting with
    SELECT passes.

    Raises:
        HTTPException: 400 for non-SELECT queries; 500 on execution failure.
    """
    is_select = query.strip().upper().startswith("SELECT")
    if not is_select:
        raise HTTPException(status_code=400, detail="Only SELECT queries are allowed")

    try:
        engine = get_db_connection()
        with engine.connect() as connection:
            rows = [dict(record._mapping) for record in connection.execute(text(query))]
        return {
            "status": "success",
            "query": query,
            "row_count": len(rows),
            "results": rows,
        }
    except Exception as e:
        logging.error(f"Error executing query: {str(e)}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Error executing query: {str(e)}")
160 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/routes/tree_websocket.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | from datetime import datetime
4 | from typing import List, Dict, Any
5 | import logging
6 |
# Root logging configuration, applied when this module is imported.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
12 |
13 | from fastapi import APIRouter, WebSocket, WebSocketDisconnect
14 |
15 | # Import the tree traversal handler
16 | from app.api.routes.tree_traversal import bfs_traversal
17 |
router = APIRouter()

# WebSocket clients currently registered for tree-visualization broadcasts.
active_tree_connections: List[WebSocket] = []

# Background heartbeat task (ping_tree_clients); None until a client connects,
# recreated by the endpoint if it has finished.
tree_ping_task = None
25 |
async def send_to_all_tree_clients(data: Dict[str, Any]):
    """Broadcast *data* (JSON-encoded once) to every registered tree client.

    A client whose send fails is logged and unregistered, so a dead socket
    is not retried on every subsequent broadcast.
    """
    message = json.dumps(data)
    # Iterate over a snapshot: each ``await`` yields to the event loop, where the
    # endpoint may add/remove connections, and mutating a list during iteration
    # silently skips elements.
    for connection in list(active_tree_connections):
        try:
            await connection.send_text(message)
        except Exception as e:
            logging.error(f"Error sending message to tree client: {e}")
            if connection in active_tree_connections:
                active_tree_connections.remove(connection)
34 |
async def ping_tree_clients():
    """Broadcast a heartbeat to all tree clients once per second, forever.

    When no clients are registered, the loop only sleeps.
    """
    heartbeat_interval = 1
    while True:
        if len(active_tree_connections) > 0:
            heartbeat = {
                "type": "ping",
                "message": "Tree server heartbeat",
                "timestamp": datetime.utcnow().isoformat(),
            }
            await send_to_all_tree_clients(heartbeat)
            logging.info(f"Sent ping to {len(active_tree_connections)} tree websocket clients")
        await asyncio.sleep(heartbeat_interval)
47 |
# Define the WebSocket endpoint for tree visualization
async def tree_websocket_endpoint(websocket: WebSocket):
    """Serve one tree-visualization WebSocket client.

    Accepts the socket, sends a ``connection`` greeting, registers the client
    for heartbeat broadcasts (starting the heartbeat task if needed), then
    handles incoming messages until the client disconnects or an error
    occurs. The client is always unregistered on exit, including on task
    cancellation.
    """
    global tree_ping_task

    await websocket.accept()
    try:
        # Greet before registering, so the greeting is always the first frame
        # the client sees (the heartbeat broadcasts only to registered clients),
        # and so a failed greeting is handled by the cleanup below.
        await websocket.send_json({
            "type": "connection",
            "message": "Connected to Tree WebSocket server"
        })

        active_tree_connections.append(websocket)
        if tree_ping_task is None or tree_ping_task.done():
            tree_ping_task = asyncio.create_task(ping_tree_clients())
        logging.info(f"Tree WebSocket client connected - total clients: {len(active_tree_connections)}")

        while True:
            data = await websocket.receive_text()
            await _handle_tree_message(websocket, data)
    except WebSocketDisconnect:
        _unregister_tree_connection(websocket)
        logging.info(f"Tree WebSocket client disconnected - remaining clients: {len(active_tree_connections)}")
    except Exception as e:
        logging.error(f"Tree WebSocket error: {e}")
    finally:
        # Also runs on cancellation (CancelledError is not an Exception subclass).
        _unregister_tree_connection(websocket)


def _unregister_tree_connection(websocket: WebSocket) -> None:
    """Remove *websocket* from the broadcast list if it is still registered."""
    if websocket in active_tree_connections:
        active_tree_connections.remove(websocket)


async def _handle_tree_message(websocket: WebSocket, data: str) -> None:
    """Parse one text frame from a tree client and dispatch it on its ``type``."""
    try:
        parsed_data = json.loads(data)
    except json.JSONDecodeError as e:
        logging.error(f"Error parsing tree message: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"Invalid JSON: {str(e)}"
        })
        return

    logging.info(f"Received tree message: {parsed_data}")

    # Valid JSON need not be an object; calling .get on a list/str/number used to
    # raise AttributeError and terminate the whole connection handler.
    if not isinstance(parsed_data, dict):
        await websocket.send_json({
            "type": "error",
            "message": "Expected a JSON object"
        })
        return

    message_type = parsed_data.get("type")

    if message_type == "tree" and "content" in parsed_data:
        tree_data = parsed_data["content"]
        logging.info(f"Received tree data: {json.dumps(tree_data)[:100]}...")
        await websocket.send_json({
            "type": "info",
            "message": "Tree data received",
            "timestamp": datetime.utcnow().isoformat()
        })
        await bfs_traversal(websocket, tree_data)

    elif message_type == "traversal_request":
        algorithm = parsed_data.get("algorithm", "bfs")
        tree_data = parsed_data.get("tree")
        if not tree_data:
            await websocket.send_json({
                "type": "error",
                "message": "No tree data provided for traversal"
            })
        elif algorithm == "bfs":
            await bfs_traversal(websocket, tree_data)
        else:
            await websocket.send_json({
                "type": "error",
                "message": f"Unsupported algorithm: {algorithm}"
            })

    else:
        # Echo back other message types
        await websocket.send_json({
            "type": "echo",
            "message": parsed_data,
            "timestamp": datetime.utcnow().isoformat()
        })
134 |
# HTTP route for checking the Tree WebSocket server without opening a socket.
@router.get("/status")
async def tree_websocket_status():
    """Report how many tree WebSocket clients are currently registered."""
    connected = len(active_tree_connections)
    return {"active": connected, "status": "running"}
--------------------------------------------------------------------------------
/visual-tree-search-app/styles/globals.css:
--------------------------------------------------------------------------------
1 | @import "tailwindcss";
2 |
3 | @plugin "tailwindcss-animate";
4 |
5 | @custom-variant dark (&:is(.dark *));
6 |
:root {
  /* Page surface */
  --background: #f8fafc;
  --foreground: #334155;

  /* Primary brand colour */
  --primary: #3b82f6;
  --primary-foreground: #ffffff;
  --primary-hover: #2563eb;

  /* Secondary surfaces */
  --secondary: #f1f5f9;
  --secondary-foreground: #475569;
  --secondary-hover: #e2e8f0;

  /* Accent highlights */
  --accent: #6366f1;
  --accent-foreground: #ffffff;
  --accent-hover: #4f46e5;

  /* De-emphasised content */
  --muted: #f1f5f9;
  --muted-foreground: #64748b;

  /* Cards */
  --card: #ffffff;
  --card-foreground: #334155;
  --card-border: #e2e8f0;

  /* Borders, inputs and focus ring */
  --border: #e2e8f0;
  --input: #e2e8f0;
  --ring: #3b82f6;

  /* Feedback states */
  --success: #16a34a;
  --warning: #d97706;
  --error: #dc2626;
  --info: #3b82f6;

  /* Data-visualisation palette */
  --chart-1: #3b82f6;
  --chart-2: #6366f1;
  --chart-3: #db2777;
  --chart-4: #16a34a;
  --chart-5: #d97706;

  /* Sidebar */
  --sidebar: #f1f5f9;
  --sidebar-foreground: #475569;
  --sidebar-border: #e2e8f0;
  --sidebar-hover: #e2e8f0;

  /* Shape: base corner radius (scaled in @theme) */
  --radius: 0.6rem;
}
60 |
/* Dark-theme overrides; --radius is inherited from :root. */
.dark {
  /* Page surface */
  --background: #0f172a;
  --foreground: #e2e8f0;

  /* Primary brand colour */
  --primary: #60a5fa;
  --primary-foreground: #0f172a;
  --primary-hover: #3b82f6;

  /* Secondary surfaces */
  --secondary: #1e293b;
  --secondary-foreground: #e2e8f0;
  --secondary-hover: #334155;

  /* Accent highlights */
  --accent: #818cf8;
  --accent-foreground: #0f172a;
  --accent-hover: #6366f1;

  /* De-emphasised content */
  --muted: #1e293b;
  --muted-foreground: #94a3b8;

  /* Cards */
  --card: #1e293b;
  --card-foreground: #e2e8f0;
  --card-border: #334155;

  /* Borders, inputs and focus ring */
  --border: #334155;
  --input: #334155;
  --ring: #60a5fa;

  /* Feedback states */
  --success: #34d399;
  --warning: #fbbf24;
  --error: #f87171;
  --info: #60a5fa;

  /* Data-visualisation palette */
  --chart-1: #60a5fa;
  --chart-2: #818cf8;
  --chart-3: #f472b6;
  --chart-4: #34d399;
  --chart-5: #fbbf24;

  /* Sidebar */
  --sidebar: #1e293b;
  --sidebar-foreground: #e2e8f0;
  --sidebar-border: #334155;
  --sidebar-hover: #334155;
}
113 |
/* Expose the runtime design tokens to Tailwind as theme values
   (e.g. --color-primary backs utilities such as bg-primary). */
@theme inline {
  /* Typography */
  --font-sans: var(--font-geist-sans);
  --font-mono: var(--font-geist-mono);

  /* Page surface */
  --color-background: var(--background);
  --color-foreground: var(--foreground);

  /* Primary */
  --color-primary: var(--primary);
  --color-primary-foreground: var(--primary-foreground);
  --color-primary-hover: var(--primary-hover);

  /* Secondary */
  --color-secondary: var(--secondary);
  --color-secondary-foreground: var(--secondary-foreground);
  --color-secondary-hover: var(--secondary-hover);

  /* Accent */
  --color-accent: var(--accent);
  --color-accent-foreground: var(--accent-foreground);
  --color-accent-hover: var(--accent-hover);

  /* Muted */
  --color-muted: var(--muted);
  --color-muted-foreground: var(--muted-foreground);

  /* Card */
  --color-card: var(--card);
  --color-card-foreground: var(--card-foreground);
  --color-card-border: var(--card-border);

  /* Border / input / ring */
  --color-border: var(--border);
  --color-input: var(--input);
  --color-ring: var(--ring);

  /* Status */
  --color-success: var(--success);
  --color-warning: var(--warning);
  --color-error: var(--error);
  --color-info: var(--info);

  /* Chart */
  --color-chart-1: var(--chart-1);
  --color-chart-2: var(--chart-2);
  --color-chart-3: var(--chart-3);
  --color-chart-4: var(--chart-4);
  --color-chart-5: var(--chart-5);

  /* Sidebar */
  --color-sidebar: var(--sidebar);
  --color-sidebar-foreground: var(--sidebar-foreground);
  --color-sidebar-border: var(--sidebar-border);
  --color-sidebar-hover: var(--sidebar-hover);

  /* Border radius scale, derived from --radius */
  --radius-sm: calc(var(--radius) - 2px);
  --radius-md: var(--radius);
  --radius-lg: calc(var(--radius) + 2px);
  --radius-xl: calc(var(--radius) + 4px);
}
174 |
/* Unlayered page defaults driven by the theme tokens. */
body {
  color: var(--foreground);
  background: var(--background);
  font-family: var(--font-sans);
}
180 |
/* Base-layer element defaults: borders, body colours, heading scale, scrollbars. */
@layer base {
  * {
    @apply border-border;
  }
  body {
    @apply bg-background text-foreground antialiased;
  }
  h1, h2, h3, h4, h5, h6 {
    @apply font-semibold tracking-tight;
  }
  h1 {
    @apply text-4xl md:text-5xl font-bold;
  }
  h2 {
    @apply text-3xl md:text-4xl font-bold;
  }
  h3 {
    @apply text-2xl md:text-3xl;
  }
  h4 {
    @apply text-xl md:text-2xl;
  }
  h5 {
    @apply text-lg md:text-xl;
  }
  h6 {
    @apply text-base md:text-lg;
  }

  /* Modern scrollbar (WebKit/Blink) */
  ::-webkit-scrollbar {
    width: 8px;
    height: 8px;
  }

  ::-webkit-scrollbar-track {
    background: transparent;
  }

  ::-webkit-scrollbar-thumb {
    /* `opacity` is not reliably applied to scrollbar pseudo-elements, so the
       intended 50% transparency is mixed into the colour itself. */
    background: color-mix(in srgb, var(--muted-foreground) 50%, transparent);
    border-radius: 999px;
  }

  ::-webkit-scrollbar-thumb:hover {
    background: var(--primary);
  }
}
230 |
/* Cards animate every property change over 200ms (see .card:hover). */
.card {
  transition: all 200ms ease;
}
235 |
/* Lift the card slightly and add a soft primary-tinted shadow on hover. */
.card:hover {
  transform: translateY(-2px);
  /* --primary holds a hex colour, so the former rgba(var(--primary), 0.1) was
     invalid at computed-value time and the whole box-shadow was dropped.
     color-mix derives the translucent tints from the hex value directly. */
  box-shadow: 0 10px 25px -5px color-mix(in srgb, var(--primary) 10%, transparent),
    0 8px 10px -6px color-mix(in srgb, var(--primary) 5%, transparent);
}
241 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_agent.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Optional, Tuple, List
2 | from datetime import datetime
3 | from dotenv import load_dotenv
4 | load_dotenv()
5 |
6 | from .tree_vis import RED, better_print, print_trajectory, collect_all_nodes, GREEN, RESET, print_entire_tree
7 | from .lats_node import LATSNode
8 | from .base_agent import BaseAgent
9 |
10 | class LATSAgent(BaseAgent):
async def run(self, websocket=None) -> Optional[LATSNode]:
    """Run the LATS search, print the resulting trajectory and return its node.

    Args:
        websocket: forwarded unchanged to ``lats_search``, which sends progress
            updates over it when one is provided.

    Returns:
        The node returned by ``lats_search`` (on the early-exit path, the
        terminal node of a simulation that met ``config.simulation_score``).
        Previously annotated ``list[LATSNode]``, which did not match the single
        node returned. ``Optional`` because ``lats_search``'s fall-through
        return is not visible here; TODO confirm.
    """
    best_node = await self.lats_search(websocket)
    print_trajectory(best_node)
    return best_node
23 |
24 | async def lats_search(self, websocket=None):
25 | terminal_nodes = []
26 |
27 | for i in range(self.config.iterations):
28 | await self.websocket_iteration_start(i, websocket=websocket)
29 |
30 | print(f"Iteration {i}/{self.config.iterations} ...")
31 |
32 | # Step 1: Node Selection
33 | ## TODO: move websocket node selection into node_selection method
34 | print(f"{GREEN}Step 1: node selection{RESET}")
35 | await self.websocket_step_start(step=1, step_name="node_selection", websocket=websocket)
36 | node = await self.node_selection(self.root_node)
37 | await self.websocket_node_selection(node, websocket=websocket)
38 |
39 | if node is None:
40 | print("All paths lead to terminal nodes with value 0. Ending search.")
41 | break
42 |
43 | # Step 2: Node Expansion
44 | print(f"{GREEN}Step 2: node expansion{RESET}")
45 | await self.websocket_step_start(step=2, step_name="node_expansion", websocket=websocket)
46 | if node.depth < self.config.max_depth :
47 | await self.node_expansion(node, websocket)
48 | if node is None:
49 | # all the nodes are terminal, stop the search
50 | print(f"{RED}All nodes are terminal, stopping search{RESET}")
51 | break
52 | tree_data = self._get_tree_data()
53 | if websocket:
54 | await self.websocket_tree_update(type="tree_update_node_expansion", websocket=websocket, tree_data=tree_data)
55 | else:
56 | print_entire_tree(self.root_node)
57 |
58 |
59 | # Step 3: Evaluation
60 | print(f"{GREEN}Step 3: node chilren evaluation{RESET}")
61 | await self.websocket_step_start(step=3, step_name="node_children_evaluation", websocket=websocket)
62 | await self.node_children_evaluation(node)
63 | tree_data = self._get_tree_data()
64 | if websocket:
65 | await self.websocket_tree_update(type="tree_update_node_children_evaluation", websocket=websocket, tree_data=tree_data)
66 | else:
67 | print("after evaluation")
68 | print_entire_tree(self.root_node)
69 |
70 |
71 | # Step 4: Simulation
72 | print(f"{GREEN}Step 4: simulation{RESET}")
73 | await self.websocket_step_start(step=4, step_name="simulation", websocket=websocket)
74 | selected_node = max(node.children, key=lambda child: child.value)
75 | await self.websocket_node_selection(selected_node, websocket=websocket, type="node_selected_for_simulation")
76 | reward, terminal_node = await self.simulation(selected_node, websocket=websocket)
77 | terminal_nodes.append(terminal_node)
78 | await self.websocket_simulation_result(reward, terminal_node, websocket=websocket)
79 |
80 | # simulation score threshold
81 | if reward >= self.config.simulation_score:
82 | await self.websocket_search_complete("success", reward, terminal_node.get_trajectory(), websocket=websocket)
83 | await self.playwright_manager.close()
84 | return terminal_node
85 |
86 | # Step 5: Backpropagation
87 | print(f"{GREEN}Step 5: backpropagation{RESET}")
88 | await self.websocket_step_start(step=5, step_name="backpropagation", websocket=websocket)
89 | self.backpropagate(selected_node, reward)
90 | tree_data = self._get_tree_data()
91 | print_entire_tree(self.root_node)
92 | print(tree_data)
93 | if websocket:
94 | await self.websocket_tree_update(type="tree_update_node_backpropagation", websocket=websocket, tree_data=tree_data)
95 | else:
96 | print("after backpropagation")
97 | print_entire_tree(self.root_node)
98 |
99 | # Find best node
100 | all_nodes_list = collect_all_nodes(self.root_node)
101 | all_nodes_list.extend(terminal_nodes)
102 |
103 | ## temp change: if value is the same, choose the deeper node
104 | best_child = max(all_nodes_list, key=lambda x: (x.value, x.depth))
105 |
106 | if best_child.value >= 0.75:
107 | print("Successful trajectory found")
108 | await self.websocket_search_complete("success", best_child.value, best_child.get_trajectory(), websocket=websocket)
109 | else:
110 | print("Unsuccessful trajectory found")
111 | await self.websocket_search_complete("partial_success", best_child.value, best_child.get_trajectory(), websocket=websocket)
112 | await self.playwright_manager.close()
113 |
114 | return best_child if best_child is not None else self.root_node
115 |
116 | async def node_selection(self, node: LATSNode, websocket=None) -> Optional[LATSNode]:
117 | if node.is_terminal:
118 | return None
119 | ## TODO; move this node selection logic from LATSNode to LATSAgent
120 | selected_node = node.get_best_leaf()
121 | await self.websocket_node_selection(selected_node, websocket=websocket)
122 | return selected_node
--------------------------------------------------------------------------------
/visual-tree-search-backend/test/websocket-client.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | WebSocket Test Client
7 |
76 |
77 |
78 | WebSocket Test Client
79 |
80 | Disconnected
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
220 |
221 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/test/test-tree-search-ws-simple.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | import websockets
4 | import argparse
5 | import logging
6 | import sys
7 | import os
8 | from datetime import datetime
9 |
# Configure logging: timestamped INFO-level records on the root logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Message types the server emits for BFS and DFS searches:
#   account_reset, browser_setup, node_created, node_selected,
#   tree_update_node_expansion, tree_update_node_evaluation

# ANSI escape sequences, named once so the table below reads by colour.
_BLUE = '\033[94m'
_GREEN = '\033[92m'
_YELLOW = '\033[93m'
_MAGENTA = '\033[95m'
_CYAN = '\033[96m'
_RED = '\033[91m'
_RESET = '\033[0m'

# Colour used when printing each websocket message type ('reset' ends a span).
COLORS = {
    # Core updates
    **dict.fromkeys(('iteration_start', 'step_start'), _BLUE),
    # Node operations
    **dict.fromkeys(
        ('node_selected', 'node_created', 'node_simulated', 'node_terminal'),
        _GREEN,
    ),
    # Tree/Path updates
    **dict.fromkeys(
        (
            'tree_update',
            'tree_update_node_expansion',
            'tree_update_node_evaluation',
            'trajectory_update',
            'removed_simulation',
        ),
        _CYAN,
    ),
    # Results/Completion
    'simulation_result': _YELLOW,
    'search_complete': _MAGENTA,
    'success': _MAGENTA,
    'partial_success': _YELLOW,
    'failure': _RED,
    # System messages
    **dict.fromkeys(('account_reset', 'browser_setup', 'error'), _RED),
    # Status updates
    'status_update': _BLUE,
    'reset': _RESET,
}

# Default values for the command-line options
DEFAULT_WS_URL = "ws://localhost:3000/tree-search-ws"
DEFAULT_STARTING_URL = "http://xwebarena.pathonai.org:7770/"
DEFAULT_GOAL = "search running shoes, click on the first result"
66 |
async def connect_and_test_search(
    ws_url: str,
    starting_url: str,
    goal: str,
    search_algorithm: str = "bfs",
    max_depth: int = 3
):
    """
    Connect to the WebSocket endpoint and test the tree search functionality.

    Sends one ``start_search`` request for ``SimpleSearchAgent``. Every
    message received is printed until ``search_complete`` arrives, the
    connection closes, or an unexpected error occurs. Frames that are not
    JSON objects are logged and skipped; they no longer abort the run.

    Args:
        ws_url: WebSocket URL to connect to
        starting_url: URL to start the search from
        goal: Goal to achieve
        search_algorithm: Search algorithm to use (bfs or dfs)
        max_depth: Maximum depth for the search tree
    """
    logger.info("Connecting to WebSocket at %s", ws_url)

    async with websockets.connect(ws_url) as websocket:
        logger.info("Connected to WebSocket")

        # The server is expected to greet us with a connection_established message.
        greeting = json.loads(await websocket.recv())
        if greeting.get("type") == "connection_established":
            logger.info("Connection established with ID: %s", greeting.get("connection_id"))
        else:
            # Surface unexpected greetings (e.g. an error) instead of dropping them.
            logger.warning("Unexpected first message: %s", greeting)

        # Send search request
        request = {
            "type": "start_search",
            "agent_type": "SimpleSearchAgent",
            "starting_url": starting_url,
            "goal": goal,
            "search_algorithm": search_algorithm,
            "max_depth": max_depth
        }

        logger.info("Sending search request: %s", request)
        await websocket.send(json.dumps(request))

        # Process responses
        while True:
            try:
                response = await websocket.recv()
                data = json.loads(response)
            except websockets.exceptions.ConnectionClosed:
                logger.warning("WebSocket connection closed")
                break
            except json.JSONDecodeError as e:
                # One malformed frame should not end the whole test run.
                logger.error("Skipping non-JSON message: %s", e)
                continue
            except Exception as e:
                logger.error("Error processing message: %s", e)
                break

            if not isinstance(data, dict):
                logger.error("Skipping message that is not a JSON object: %r", data)
                continue

            # Print the raw websocket message with colored type
            msg_type = data.get("type", "unknown")
            color = COLORS.get(msg_type, COLORS['reset'])
            print(f"\nWebSocket message - Type: {color}{msg_type}{COLORS['reset']}")
            print(f"Raw message: {json.dumps(data, indent=2)}")

            if msg_type == "search_complete":
                break

    logger.info("Test completed")
131 |
def parse_arguments():
    """Parse command-line options for the simple tree-search websocket test.

    Returns:
        argparse.Namespace: ``ws_url``, ``starting_url``, ``goal``,
        ``algorithm``, ``max_depth`` and ``log_file`` (None unless given).
    """
    option_table = (
        ("--ws-url", {
            "type": str,
            "default": DEFAULT_WS_URL,
            "help": f"WebSocket URL (default: {DEFAULT_WS_URL})",
        }),
        ("--starting-url", {
            "type": str,
            "default": DEFAULT_STARTING_URL,
            "help": f"Starting URL for the search (default: {DEFAULT_STARTING_URL})",
        }),
        ("--goal", {
            "type": str,
            "default": DEFAULT_GOAL,
            "help": f"Goal to achieve (default: {DEFAULT_GOAL})",
        }),
        ("--algorithm", {
            "type": str,
            "choices": ["bfs", "dfs"],
            "default": "bfs",
            "help": "Search algorithm to use (default: bfs)",
        }),
        ("--max-depth", {
            "type": int,
            "default": 3,
            "help": "Maximum depth for the search tree (default: 3)",
        }),
        # Optional file that receives a copy of everything printed.
        ("--log-file", {
            "type": str,
            "help": "File to save the colored output to",
        }),
    )

    cli = argparse.ArgumentParser(
        description="Test the tree search WebSocket functionality"
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
156 |
async def main():
    """Main entry point.

    Parses CLI arguments and runs the websocket search test. With
    ``--log-file``, printed output and log records are also copied into that
    file, and the console streams are restored afterwards.
    """
    args = parse_arguments()

    # Setup logging to file if requested
    original_stdout = sys.stdout
    original_stderr = sys.stderr
    log_file = None
    file_handler = None

    if args.log_file:
        class TeeOutput:
            """Minimal file-like object that duplicates writes into a log file."""

            def __init__(self, terminal, log_file):
                self.terminal = terminal
                self.log_file = log_file

            def write(self, message):
                written = self.terminal.write(message)
                self.log_file.write(message)
                return written

            def flush(self):
                self.terminal.flush()
                self.log_file.flush()

            def __getattr__(self, name):
                # Delegate everything else (isatty, encoding, fileno, ...) to the
                # real stream, so code that probes sys.stdout keeps working.
                return getattr(self.terminal, name)

        log_file = open(args.log_file, 'w', encoding='utf-8')
        sys.stdout = TeeOutput(sys.stdout, log_file)
        sys.stderr = TeeOutput(sys.stderr, log_file)
        # logging.basicConfig() bound its StreamHandler to the *original*
        # sys.stderr at import time, so the stderr tee never sees log records.
        # A dedicated handler is attached to send them to the file as well.
        file_handler = logging.StreamHandler(log_file)
        file_handler.setFormatter(
            logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        )
        logging.getLogger().addHandler(file_handler)
        logger.info("Logging colored output to %s", args.log_file)

    logger.info("Starting tree search WebSocket test")
    logger.info("WebSocket URL: %s", args.ws_url)
    logger.info("Starting URL: %s", args.starting_url)
    logger.info("Goal: %s", args.goal)
    logger.info("Algorithm: %s", args.algorithm)
    logger.info("Max depth: %s", args.max_depth)

    try:
        await connect_and_test_search(
            ws_url=args.ws_url,
            starting_url=args.starting_url,
            goal=args.goal,
            search_algorithm=args.algorithm,
            max_depth=args.max_depth
        )
    finally:
        # Clean up if logging to file: detach the handler before closing its stream.
        if log_file is not None:
            if file_handler is not None:
                logging.getLogger().removeHandler(file_handler)
            sys.stdout = original_stdout
            sys.stderr = original_stderr
            log_file.close()
            logger.info("Closed log file: %s", args.log_file)


if __name__ == "__main__":
    asyncio.run(main())
210 |
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/trajectory_score.py:
--------------------------------------------------------------------------------
1 | """Module for scoring and evaluating action trajectories using LLMs."""
2 |
3 | import base64
4 | import json
5 | import datetime
6 | from typing import Any, Optional, List, Dict, TypedDict
7 | from openai import OpenAI
8 |
# Functional TypedDict form: the schema is spelled out as a plain mapping.
TrajectoryMetrics = TypedDict(
    "TrajectoryMetrics",
    {
        # Scores: the LLM is asked for 0-10; normalize_scores rescales them to 0-1.
        "overall_score": float,
        "efficiency_score": float,
        "accuracy_score": float,
        "robustness_score": float,
        # Free-text and list feedback returned by the evaluator.
        "detailed_explanation": str,
        "improvement_suggestions": List[str],
        "key_achievements": List[str],
        "potential_issues": List[str],
        # Bookkeeping added after the LLM call (model, timestamp, ...).
        "metadata": Dict[str, Any],
    },
)
TrajectoryMetrics.__doc__ = """Structured metrics for trajectory evaluation."""
20 |
# System prompt for the trajectory-scoring request in score_trajectory_with_openai.
# It asks for 0-10 scores, which normalize_scores later rescales to 0-1, and for
# the same keys that validate_evaluation requires.
# NOTE: the request uses response_format={"type": "json_object"}; keep the
# explicit instruction to return JSON (OpenAI's JSON mode expects the messages
# to ask for JSON).
SYSTEM_PROMPT = \
"""You are an expert web task completion evaluator. Your task is to provide a comprehensive evaluation of web task completion
by analyzing the trajectory against the desired goal. Consider multiple aspects of the task execution and provide detailed feedback.

Analyze the provided trajectory and screenshot of the web page, return a JSON response with:
1. overall_score (float 0-10): Overall task completion score
2. efficiency_score (float 0-10): How well the task was completed (minimal steps, optimal path)
3. accuracy_score (float 0-10): How precisely the actions were executed
4. robustness_score (float 0-10): How well the solution handles edge cases
5. detailed_explanation (string): Comprehensive analysis of the execution
6. improvement_suggestions (list of strings): Specific ways to improve the solution
7. key_achievements (list of strings): Important milestones reached
8. potential_issues (list of strings): Areas that could be problematic

Example format:
{
"overall_score": 8.5,
"efficiency_score": 9.0,
"accuracy_score": 8.0,
"robustness_score": 7.5,
"detailed_explanation": "The trajectory effectively achieves the goal with minimal steps...",
"improvement_suggestions": ["Could have used more efficient selectors", "Consider adding error handling"],
"key_achievements": ["Successfully logged in", "Found target element"],
"potential_issues": ["No timeout handling", "Assumes specific page layout"]
}
"""

# User-message template. create_llm_prompt fills it with str.format() using
# goal, trajectory_str and page_state, so any literal braces added here must be
# doubled ({{ }}).
USER_PROMPT_TEMPLATE = \
"""Goal: {goal}

Trajectory:
{trajectory_str}

Current Page State:
{page_state}

Please provide a comprehensive evaluation of the task completion."""
58 |
def format_trajectory_step(step: Dict[str, Any], index: int) -> str:
    """Render one trajectory step as a labelled, multi-line text block.

    ``step['action']`` and ``step['natural_language_description']`` are
    required (KeyError otherwise). ``target``, ``status`` and ``output`` fall
    back to ``'N/A'``, ``'completed'`` and ``'N/A'``.
    """
    labelled_fields = (
        ("Action", step['action']),
        ("Description", step['natural_language_description']),
        ("Target", step.get('target', 'N/A')),
        ("Status", step.get('status', 'completed')),
        ("Output", step.get('output', 'N/A')),
    )
    lines = [f"Step {index}:"]
    lines.extend(f"{label}: {value}" for label, value in labelled_fields)
    return "\n".join(lines)
67 |
def create_llm_prompt(
    trajectory: List[Dict[str, Any]],
    goal: str,
    page_state: Optional[Dict[str, Any]] = None
) -> str:
    """
    Build the user prompt for LLM scoring of a trajectory.

    Args:
        trajectory: Ordered step dicts; each needs ``action`` and
            ``natural_language_description`` (see format_trajectory_step).
        goal: The goal the trajectory is trying to achieve.
        page_state: Optional page-state dict, embedded as indented JSON; a
            falsy value (None or {}) yields a placeholder sentence instead.

    Returns:
        str: USER_PROMPT_TEMPLATE filled with the goal, the numbered steps
        (separated by blank lines) and the page state.
    """
    numbered_steps = [
        format_trajectory_step(step, position)
        for position, step in enumerate(trajectory, start=1)
    ]

    if page_state:
        page_state_text = json.dumps(page_state, indent=2)
    else:
        page_state_text = "No page state information available"

    return USER_PROMPT_TEMPLATE.format(
        goal=goal,
        trajectory_str="\n\n".join(numbered_steps),
        page_state=page_state_text,
    )
101 |
def validate_evaluation(evaluation: Dict[str, Any]) -> bool:
    """Validate the evaluation output has all required fields and correct types.

    Args:
        evaluation: Parsed JSON returned by the scoring LLM.

    Returns:
        bool: True only if ``evaluation`` is a dict with every required key.
        Each score must be a real number (not a bool) within [0, 10]. Each
        text field must be a str, and each list field must be a list.
    """
    # json.loads may yield a list/str/number; only an object can be valid.
    if not isinstance(evaluation, dict):
        return False

    required_fields = {
        'overall_score': (int, float),
        'efficiency_score': (int, float),
        'accuracy_score': (int, float),
        'robustness_score': (int, float),
        'detailed_explanation': str,
        'improvement_suggestions': list,
        'key_achievements': list,
        'potential_issues': list
    }

    for field, expected_type in required_fields.items():
        if field not in evaluation:
            return False
        value = evaluation[field]
        # bool is a subclass of int, so isinstance(True, (int, float)) holds;
        # reject it explicitly so JSON true/false is never accepted as a score.
        if isinstance(value, bool) or not isinstance(value, expected_type):
            return False
        # Chained comparison is False for NaN, so NaN scores are rejected too.
        if isinstance(value, (int, float)) and not 0 <= value <= 10:
            return False

    return True
125 |
def normalize_scores(evaluation: Dict[str, Any]) -> Dict[str, Any]:
    """Rescale the 0-10 score fields of ``evaluation`` to 0-1, in place.

    Score fields that are absent stay absent, and other keys are untouched.
    The same (mutated) dict is returned for convenience.
    """
    score_keys = ('overall_score', 'efficiency_score', 'accuracy_score', 'robustness_score')
    present_keys = [key for key in score_keys if key in evaluation]
    for key in present_keys:
        evaluation[key] = evaluation[key] / 10.0
    return evaluation
133 |
def score_trajectory_with_openai(
    prompt: str,
    openai_client: OpenAI,
    model: str = "gpt-4o",
    screenshot: Optional[bytes] = None
) -> Dict[str, Any]:
    """
    Uses OpenAI to score the trajectory based on the provided prompt.

    Args:
        prompt: The prompt to send to OpenAI
        openai_client: OpenAI client instance
        model: OpenAI model to use
        screenshot: Optional screenshot of the current page. PNG data is sent
            as image/png; anything else is sent as image/jpeg.

    Returns:
        dict: The evaluation with scores normalized to 0-1 plus a ``metadata``
        dict (``model_used``, ``timestamp``, ``has_screenshot``). This is a
        best-effort scorer: on any failure it returns a zero-score evaluation
        whose metadata carries the same keys plus ``error``, instead of raising.
    """
    has_screenshot = screenshot is not None

    try:
        content = [
            {"type": "text", "text": prompt},
        ]
        if has_screenshot:
            base64_image = base64.b64encode(screenshot).decode('utf-8')
            # Label the data URL with the real format; PNG has a fixed 8-byte signature.
            png_signature = b"\x89PNG\r\n\x1a\n"
            mime_type = "image/png" if bytes(screenshot[:8]) == png_signature else "image/jpeg"
            content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:{mime_type};base64,{base64_image}",
                    "detail": "high"
                }
            })

        response = openai_client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": content}
            ],
            response_format={"type": "json_object"}
        )

        evaluation = json.loads(response.choices[0].message.content)

        # Validate evaluation
        if not validate_evaluation(evaluation):
            raise ValueError("Invalid evaluation format")

        # Normalize scores
        evaluation = normalize_scores(evaluation)

        # Add metadata
        evaluation["metadata"] = {
            "model_used": model,
            "timestamp": datetime.datetime.now().isoformat(),
            "has_screenshot": has_screenshot
        }

        return evaluation

    except Exception as e:
        # Deliberate best-effort fallback: the search keeps going with a zero score.
        # The metadata keys mirror the success path so consumers can read them uniformly.
        return {
            "overall_score": 0.0,
            "efficiency_score": 0.0,
            "accuracy_score": 0.0,
            "robustness_score": 0.0,
            "detailed_explanation": f"Error occurred during evaluation: {str(e)}",
            "improvement_suggestions": ["Check API connection and try again"],
            "key_achievements": [],
            "potential_issues": ["Evaluation failed"],
            "metadata": {
                "error": str(e),
                "timestamp": datetime.datetime.now().isoformat(),
                "model_used": model,
                "has_screenshot": has_screenshot
            }
        }
--------------------------------------------------------------------------------
/visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_node.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from dataclasses import dataclass
3 | from typing import Optional
4 | from pydantic import BaseModel
5 | import base64
6 | from ...webagent_utils_async.evaluation.feedback import Feedback
7 |
@dataclass
class Element:
    """Represents a DOM element with its properties.

    A plain attribute record: the field order below is also the positional
    argument order of the generated ``__init__``. LATSNode's constructor takes
    ``element`` as a plain dict rather than this class.
    """
    text: str  # presumably the element's visible text; confirm against the producer
    tag: str  # presumably the HTML tag name
    id: str  # field named after the HTML id attribute (shadows the builtin only as a field name)
    title: str
    ariaLabel: str  # camelCase mirrors the DOM aria-label naming
    name: str
    value: str
    placeholder: str
    class_name: str  # Changed from 'class' as it's a reserved keyword
    role: str
    unique_selector: str  # presumably a selector intended to match only this element
    selector_uniqueness_validated: bool  # presumably whether that uniqueness was checked; TODO confirm
23 |
class Observation(BaseModel):
    """An observation recorded on a node: text plus optional raw image bytes.

    ``image_base64`` caches the base64 encoding of ``image`` so that
    ``get_base64_image`` encodes at most once.
    """

    text: str
    image: Optional[bytes] = None
    image_base64: Optional[str] = None

    def get_base64_image(self) -> Optional[str]:
        """Return the image as a base64 string, computing and caching it lazily.

        Returns:
            Optional[str]: The cached/encoded string. If no encoding was cached
            and there is no image, returns None; previously
            ``base64.b64encode(None)`` raised TypeError in that case.
        """
        if self.image_base64 is None and self.image is not None:
            self.image_base64 = base64.b64encode(self.image).decode('utf-8')
        return self.image_base64
33 |
class LATSNode:
    """
    A node class for Language-based Action Tree Search (LATS).

    This class implements a tree structure for MCTS-like search algorithms,
    specifically designed for language-based action planning in UI interactions.

    Attributes:
        natural_language_description (str): Human-readable description of the action
        action (str): The actual action to be executed
        prob (float): Probability or confidence score for this action
        element (dict): DOM element associated with this action
        goal (str): The target goal state
        feedback (str): Feedback text attached to this node (empty by default)
        goal_finish_feedback (Optional[Feedback]): Goal-completion feedback, if any
        parent (Optional[LATSNode]): Parent node in the tree
        children (list[LATSNode]): Child nodes in the tree
        visits (int): Number of times this node has been visited
        value (float): Accumulated value/score of this node
        depth (int): Depth of this node in the tree
        is_terminal (bool): Whether this node is a terminal state
        exhausted (bool): Whether all children have been explored
        em (float): Exact match score for evaluation
        observation (Optional[Observation]): Observation recorded at this node, if any
    """

    def __init__(
        self,
        natural_language_description: str,
        action: str,
        prob: float,
        element: dict,  # Using dict instead of Element for backward compatibility
        goal: str,
        parent: Optional['LATSNode'] = None
    ) -> None:
        """
        Initialize a new LATSNode.

        Args:
            natural_language_description: Human-readable description of the action
            action: The actual action to be executed
            prob: Probability or confidence score for this action
            element: DOM element associated with this action
            goal: The target goal state
            parent: Parent node in the tree, if any. It sets ``depth`` but does
                not register this node as a child; use ``add_child`` for that.
        """
        self.natural_language_description = natural_language_description
        self.action = action
        self.prob = prob
        self.element = element
        self.feedback = ''
        self.goal_finish_feedback: Optional[Feedback] = None
        self.parent = parent
        self.goal = goal
        self.children: list[LATSNode] = []
        self.visits = 0
        self.value = 0.0
        self.depth = 0 if parent is None else parent.depth + 1
        # A node is terminal when the goal has been achieved, the maximum
        # depth has been reached, or a failure condition has been triggered.
        self.is_terminal = False
        self.exhausted = False  # If all children are terminal
        self.em = 0.0  # Exact match, evaluation metric
        self.observation: Optional[Observation] = None

    def uct(self) -> float:
        """
        Calculate the UCT (Upper Confidence Bound for Trees) value for this node.

        Returns:
            float: ``value / visits + sqrt(2 * ln(parent.visits) / visits)``.
            An unvisited node returns its raw ``value``. A visited node with no
            parent, or whose parent has no visits, returns only the exploitation
            term ``value / visits``: the exploration term is undefined there
            (AttributeError / log(0) previously).
        """
        if self.visits == 0:
            return self.value
        exploitation = self.value / self.visits
        parent_visits = self.parent.visits if self.parent is not None else 0
        if parent_visits <= 0:
            return exploitation
        return exploitation + np.sqrt(2 * np.log(parent_visits) / self.visits)

    def get_best_leaf(self) -> 'LATSNode':
        """
        Recursively get the best leaf node from the current node.

        The method searches through unfinished (non-terminal) children,
        selects the one with the highest UCT score, and continues the
        search recursively until a leaf node (with no unfinished children) is reached.

        Returns:
            LATSNode: The best leaf node for expansion based on UCT.
        """
        unfinished_children = [c for c in self.children if not c.is_terminal]
        if not unfinished_children:
            return self

        best_child = max(unfinished_children, key=lambda x: x.uct())
        return best_child.get_best_leaf()

    def _path_from_root(self) -> list['LATSNode']:
        """Return the nodes from just below the root down to this node (root excluded)."""
        path = []
        node = self
        while node.parent is not None:
            path.append(node)
            node = node.parent
        path.reverse()
        return path

    def get_action_trajectory(self) -> list[dict]:
        """Return action, description and element for each step from the root (excluded) to here."""
        return [
            {
                "action": node.action,
                "natural_language_description": node.natural_language_description,
                "element": node.element
            }
            for node in self._path_from_root()
        ]

    def get_trajectory(self) -> list[dict]:
        """Return description and action for each step from the root (excluded) to here."""
        return [
            {
                "natural_language_description": node.natural_language_description,
                "action": node.action
            }
            for node in self._path_from_root()
        ]

    def add_child(self, child: 'LATSNode') -> None:
        """Attach ``child`` below this node and set its parent and depth."""
        self.children.append(child)
        child.parent = self
        child.depth = self.depth + 1

    def check_terminal(self) -> bool:
        """Mark this node terminal if it has no children or all children are terminal.

        When the node becomes terminal, the check is propagated to its parent.

        Returns:
            bool: Whether this node is terminal after the check.
        """
        if not self.children or all(child.is_terminal for child in self.children):
            self.is_terminal = True
            if self.parent is not None:
                self.parent.check_terminal()
        return self.is_terminal

    def __str__(self) -> str:
        """
        Get a string representation of the node.

        Returns:
            str: A string describing the node's key attributes
        """
        return (f"Node(depth={self.depth}, value={self.value:.2f}, "
                f"visits={self.visits}, action={self.action}, "
                f"feedback={self.feedback})")

    def to_dict(self) -> dict:
        """
        Convert the node and its subtree to a dictionary representation.

        Children are serialized recursively, downwards only. The parent is
        represented by its flat ``state`` dict (or None for the root) rather
        than its full ``to_dict()``. Having parent and child each serialize the
        other recursed forever for any node with a parent or children.

        Returns:
            dict: The node's attributes, its parent's state and its serialized children
        """
        return {
            'state': self.state,
            'question': self.question,
            'parent': self.parent.state if self.parent is not None else None,
            'children': [child.to_dict() for child in self.children],
            'visits': self.visits,
            'value': self.value,
            'depth': self.depth,
            'is_terminal': self.is_terminal,
            'em': self.em,
        }

    @property
    def state(self) -> dict:
        """
        Get the current state representation of the node.

        Returns:
            dict: A dictionary containing the node's state information
        """
        return {
            'natural_language_description': self.natural_language_description,
            'action': self.action,
            'prob': self.prob,
            'element': self.element
        }

    @property
    def question(self) -> str:
        """
        Get the goal/question associated with this node.

        Returns:
            str: The goal or question string
        """
        return self.goal
--------------------------------------------------------------------------------
/visual-tree-search-backend/test/test-tree-search-ws-lats.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | import websockets
4 | import argparse
5 | import logging
6 | import sys
7 | import os
8 | from datetime import datetime
9 |
# Configure logging: timestamped INFO-level records on the root logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Message types emitted by the server:
#   setup:  account_reset, browser_setup
#   LATS:   step_start, node_created, node_selected, node_selected_for_simulation,
#           tree_update_node_expansion, tree_update_node_children_evaluation,
#           tree_update_node_backpropagation, removed_simulation

# ANSI escape sequences, named once so the table below reads by colour.
_BLUE = '\033[94m'
_GREEN = '\033[92m'
_YELLOW = '\033[93m'
_MAGENTA = '\033[95m'
_CYAN = '\033[96m'
_RED = '\033[91m'
_RESET = '\033[0m'

# Colour used when printing each websocket message type ('reset' ends a span).
COLORS = {
    # Core updates
    **dict.fromkeys(('iteration_start', 'step_start'), _BLUE),
    # Node operations
    **dict.fromkeys(
        (
            'node_selected',
            'node_selected_for_simulation',
            'node_created',
            'node_simulated',
            'node_terminal',
        ),
        _GREEN,
    ),
    # Tree/Path updates
    **dict.fromkeys(
        (
            'tree_update',
            'tree_update_node_expansion',
            'tree_update_node_evaluation',
            'tree_update_node_children_evaluation',
            'tree_update_node_backpropagation',
            'tree_update_simulation',
            'trajectory_update',
            'removed_simulation',
        ),
        _CYAN,
    ),
    # Results/Completion
    'simulation_result': _YELLOW,
    'search_complete': _MAGENTA,
    'success': _MAGENTA,
    'partial_success': _YELLOW,
    'failure': _RED,
    # System messages
    **dict.fromkeys(('account_reset', 'browser_setup', 'error'), _RED),
    # Status updates
    'status_update': _BLUE,
    'reset': _RESET,
}

# Default values for the command-line options
DEFAULT_WS_URL = "ws://localhost:3000/tree-search-ws"
DEFAULT_STARTING_URL = "http://xwebarena.pathonai.org:7770/"
DEFAULT_GOAL = "search running shoes, click on the first result"
74 |
async def connect_and_test_search(
    ws_url: str,
    starting_url: str,
    goal: str,
    search_algorithm: str = "bfs",
    max_depth: int = 3,
    iterations: int = 5
):
    """
    Connect to the WebSocket endpoint and test the tree search functionality.

    Sends one ``start_search`` request for ``LATSAgent``. Every message
    received is printed until ``search_complete`` arrives, the connection
    closes, or an unexpected error occurs. Frames that are not JSON objects
    are logged and skipped; they no longer abort the run.

    Args:
        ws_url: WebSocket URL to connect to
        starting_url: URL to start the search from
        goal: Goal to achieve
        search_algorithm: Search algorithm name forwarded to the server
            (the CLI offers bfs, dfs, lats, mcts)
        max_depth: Maximum depth for the search tree
        iterations: Number of iterations for LATS algorithm
    """
    logger.info("Connecting to WebSocket at %s", ws_url)

    async with websockets.connect(ws_url) as websocket:
        logger.info("Connected to WebSocket")

        # The server is expected to greet us with a connection_established message.
        greeting = json.loads(await websocket.recv())
        if greeting.get("type") == "connection_established":
            logger.info("Connection established with ID: %s", greeting.get("connection_id"))
        else:
            # Surface unexpected greetings (e.g. an error) instead of dropping them.
            logger.warning("Unexpected first message: %s", greeting)

        # Send search request
        request = {
            "type": "start_search",
            "agent_type": "LATSAgent",
            "starting_url": starting_url,
            "goal": goal,
            "search_algorithm": search_algorithm,
            "max_depth": max_depth,
            "iterations": iterations
        }

        logger.info("Sending search request: %s", request)
        await websocket.send(json.dumps(request))

        # Process responses
        while True:
            try:
                response = await websocket.recv()
                data = json.loads(response)
            except websockets.exceptions.ConnectionClosed:
                logger.warning("WebSocket connection closed")
                break
            except json.JSONDecodeError as e:
                # One malformed frame should not end the whole test run.
                logger.error("Skipping non-JSON message: %s", e)
                continue
            except Exception as e:
                logger.error("Error processing message: %s", e)
                break

            if not isinstance(data, dict):
                logger.error("Skipping message that is not a JSON object: %r", data)
                continue

            # Print the raw websocket message with colored type
            msg_type = data.get("type", "unknown")
            color = COLORS.get(msg_type, COLORS['reset'])
            print(f"\nWebSocket message - Type: {color}{msg_type}{COLORS['reset']}")
            print(f"Raw message: {json.dumps(data, indent=2)}")

            if msg_type == "search_complete":
                break

    logger.info("Test completed")
142 |
def parse_arguments():
    """Parse command-line options for the LATS tree-search websocket test.

    Returns:
        argparse.Namespace: ``ws_url``, ``starting_url``, ``goal``,
        ``algorithm``, ``max_depth``, ``iterations`` and ``log_file``
        (None unless given).
    """
    option_table = (
        ("--ws-url", {
            "type": str,
            "default": DEFAULT_WS_URL,
            "help": f"WebSocket URL (default: {DEFAULT_WS_URL})",
        }),
        ("--starting-url", {
            "type": str,
            "default": DEFAULT_STARTING_URL,
            "help": f"Starting URL for the search (default: {DEFAULT_STARTING_URL})",
        }),
        ("--goal", {
            "type": str,
            "default": DEFAULT_GOAL,
            "help": f"Goal to achieve (default: {DEFAULT_GOAL})",
        }),
        ("--algorithm", {
            "type": str,
            "choices": ["bfs", "dfs", "lats", "mcts"],
            "default": "lats",
            "help": "Search algorithm to use (default: lats)",
        }),
        ("--max-depth", {
            "type": int,
            "default": 3,
            "help": "Maximum depth for the search tree (default: 3)",
        }),
        ("--iterations", {
            "type": int,
            "default": 5,
            "help": "Number of iterations for LATS algorithm (default: 5)",
        }),
        # Optional file that receives a copy of everything printed.
        ("--log-file", {
            "type": str,
            "help": "File to save the colored output to",
        }),
    )

    cli = argparse.ArgumentParser(
        description="Test the tree search WebSocket functionality"
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
170 |
async def main():
    """Main entry point.

    Parses CLI arguments and runs the LATS websocket search test. With
    ``--log-file``, printed output and log records are also copied into that
    file, and the console streams are restored afterwards.
    """
    args = parse_arguments()

    # Setup logging to file if requested
    original_stdout = sys.stdout
    original_stderr = sys.stderr
    log_file = None
    file_handler = None

    if args.log_file:
        class TeeOutput:
            """Minimal file-like object that duplicates writes into a log file."""

            def __init__(self, terminal, log_file):
                self.terminal = terminal
                self.log_file = log_file

            def write(self, message):
                written = self.terminal.write(message)
                self.log_file.write(message)
                return written

            def flush(self):
                self.terminal.flush()
                self.log_file.flush()

            def __getattr__(self, name):
                # Delegate everything else (isatty, encoding, fileno, ...) to the
                # real stream, so code that probes sys.stdout keeps working.
                return getattr(self.terminal, name)

        log_file = open(args.log_file, 'w', encoding='utf-8')
        sys.stdout = TeeOutput(sys.stdout, log_file)
        sys.stderr = TeeOutput(sys.stderr, log_file)
        # logging.basicConfig() bound its StreamHandler to the *original*
        # sys.stderr at import time, so the stderr tee never sees log records.
        # A dedicated handler is attached to send them to the file as well.
        file_handler = logging.StreamHandler(log_file)
        file_handler.setFormatter(
            logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        )
        logging.getLogger().addHandler(file_handler)
        logger.info("Logging colored output to %s", args.log_file)

    logger.info("Starting tree search WebSocket test")
    logger.info("WebSocket URL: %s", args.ws_url)
    logger.info("Starting URL: %s", args.starting_url)
    logger.info("Goal: %s", args.goal)
    logger.info("Algorithm: %s", args.algorithm)
    logger.info("Max depth: %s", args.max_depth)
    logger.info("Iterations: %s", args.iterations)

    try:
        await connect_and_test_search(
            ws_url=args.ws_url,
            starting_url=args.starting_url,
            goal=args.goal,
            search_algorithm=args.algorithm,
            max_depth=args.max_depth,
            iterations=args.iterations
        )
    finally:
        # Clean up if logging to file: detach the handler before closing its stream.
        if log_file is not None:
            if file_handler is not None:
                logging.getLogger().removeHandler(file_handler)
            sys.stdout = original_stdout
            sys.stderr = original_stderr
            log_file.close()
            logger.info("Closed log file: %s", args.log_file)


if __name__ == "__main__":
    asyncio.run(main())
--------------------------------------------------------------------------------