├── tests ├── regression │ ├── __init__.py │ ├── test_ionwave.py │ ├── test_wichita.py │ └── test_extract_aigrant_companies.py ├── __init__.py ├── unit │ ├── __init__.py │ ├── llm │ │ └── test_llm_integration.py │ ├── handlers │ │ ├── test_act_handler.py │ │ ├── test_observe_handler.py │ │ └── test_extract_handler.py │ ├── core │ │ ├── test_live_page_proxy.py │ │ ├── test_page.py │ │ ├── test_wait_for_settled_dom.py │ │ └── test_frame_id_tracking.py │ ├── test_client_initialization.py │ └── test_client_api.py ├── mocks │ └── __init__.py └── integration │ ├── local │ └── test_core_local.py │ └── api │ ├── test_core_api.py │ └── test_frame_id_integration.py ├── stagehand ├── handlers │ ├── __init__.py │ ├── extract_handler.py │ └── observe_handler.py ├── agent │ ├── __init__.py │ ├── utils.py │ ├── client.py │ └── image_compression_utils.py ├── llm │ ├── __init__.py │ └── client.py ├── a11y │ └── __init__.py ├── types │ ├── llm.py │ ├── __init__.py │ ├── a11y.py │ ├── agent.py │ └── page.py ├── __init__.py ├── metrics.py ├── config.py ├── api.py └── context.py ├── media ├── logo.png ├── qrcode.jpg ├── director_icon.svg ├── dark_license.svg └── light_license.svg ├── requirements.txt ├── MANIFEST.in ├── format ├── format.sh ├── pytest.ini ├── run_tests.sh ├── .gitignore ├── pyproject.toml ├── CACHE_GUIDE.md └── examples └── cache_manager_tool.py /tests/regression/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Tests package for stagehand-python 2 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- 1 | # Unit tests for stagehand-python 2 | -------------------------------------------------------------------------------- /stagehand/handlers/__init__.py: -------------------------------------------------------------------------------- 1 | # Initialize handlers package 2 | -------------------------------------------------------------------------------- /stagehand/agent/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent import Agent 2 | 3 | __all__ = [] 4 | -------------------------------------------------------------------------------- /media/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srszzw/stagehand-glm/HEAD/media/logo.png -------------------------------------------------------------------------------- /media/qrcode.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srszzw/stagehand-glm/HEAD/media/qrcode.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Installs the package in editable mode with development dependencies 2 | -e .[dev] 3 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include requirements.txt 3 | global-exclude *.pyc 4 | global-exclude __pycache__ 5 | global-exclude .DS_Store 6 | global-exclude */node_modules/* -------------------------------------------------------------------------------- /stagehand/llm/__init__.py: -------------------------------------------------------------------------------- 1 | from .client import LLMClient 2 | from .inference import extract, observe 3 | from .prompts import ( 4 | build_extract_system_prompt, 5 | build_extract_user_prompt, 6 | build_metadata_prompt, 7 | build_metadata_system_prompt, 8 | build_observe_system_prompt, 9 | build_observe_user_message, 10 | ) 11 | -------------------------------------------------------------------------------- /stagehand/a11y/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import ( 2 | build_hierarchical_tree, 3 | find_scrollable_element_ids, 4 | format_simplified_tree, 5 | get_accessibility_tree, 6 | get_xpath_by_resolved_object_id, 7 | ) 8 | 9 | __all__ = [ 10 | "format_simplified_tree", 11 | "build_hierarchical_tree", 12 | "get_accessibility_tree", 13 | "get_xpath_by_resolved_object_id", 14 | "find_scrollable_element_ids", 15 | ] 16 | -------------------------------------------------------------------------------- /tests/mocks/__init__.py: -------------------------------------------------------------------------------- 1 | """Mock implementations for Stagehand testing""" 2 | 3 | from .mock_llm import MockLLMClient, MockLLMResponse 4 | from .mock_browser import MockBrowser, MockBrowserContext, MockPlaywrightPage 5 | from .mock_server import MockStagehandServer 6 | 7 | __all__ = [ 8 | "MockLLMClient", 9 | "MockLLMResponse", 10 | "MockBrowser", 11 | "MockBrowserContext", 12 | "MockPlaywrightPage", 13 | "MockStagehandServer" 14 | ] -------------------------------------------------------------------------------- /stagehand/agent/utils.py: -------------------------------------------------------------------------------- 1 | def sanitize_message(msg: dict) -> dict: 2 | """Return a copy of the message with image_url omitted for computer_call_output messages.""" 3 | if msg.get("type") == "computer_call_output": 4 | output = msg.get("output", {}) 5 | if isinstance(output, dict): 6 | sanitized = msg.copy() 7 | sanitized["output"] = {**output, "image_url": "[omitted]"} 8 | return sanitized 9 | return msg 10 | -------------------------------------------------------------------------------- /format: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define source directories (adjust as needed) 4 | SOURCE_DIRS="stagehand" 5 | 6 | # Apply Black formatting first 7 | echo "Applying Black formatting..." 8 | black $SOURCE_DIRS 9 | 10 | # Apply Ruff with autofix for all issues (including import sorting) 11 | echo "Applying Ruff autofixes (including import sorting)..." 12 | ruff check --fix $SOURCE_DIRS 13 | 14 | echo "Checking for remaining issues..." 15 | ruff check $SOURCE_DIRS 16 | 17 | echo "Done! Code has been formatted and linted." -------------------------------------------------------------------------------- /format.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define source directories (adjust as needed) 4 | SOURCE_DIRS="stagehand" 5 | 6 | # Apply Black formatting first 7 | echo "Applying Black formatting..." 8 | black $SOURCE_DIRS 9 | 10 | # Apply Ruff with autofix for all issues (including import sorting) 11 | echo "Applying Ruff autofixes (including import sorting)..." 12 | ruff check --fix $SOURCE_DIRS 13 | 14 | echo "Checking for remaining issues..." 15 | ruff check $SOURCE_DIRS 16 | 17 | echo "Done! Code has been formatted and linted." -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths = tests 3 | python_files = test_*.py 4 | python_classes = Test* 5 | python_functions = test_* 6 | asyncio_mode = auto 7 | 8 | markers = 9 | unit: marks tests as unit tests 10 | integration: marks tests as integration tests 11 | smoke: marks tests as smoke tests 12 | local: marks tests as local integration tests 13 | api: marks tests as API integration tests 14 | e2e: marks tests as end-to-end tests 15 | regression: marks tests as regression tests 16 | 17 | log_cli = true 18 | log_cli_level = INFO -------------------------------------------------------------------------------- /stagehand/types/llm.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional, TypedDict, Union 2 | 3 | 4 | class ChatMessageImageUrl(TypedDict): 5 | url: str 6 | 7 | 8 | class ChatMessageSource(TypedDict): 9 | type: str 10 | media_type: str 11 | data: str 12 | 13 | 14 | class ChatMessageImageContent(TypedDict): 15 | type: Literal["image_url"] 16 | image_url: Optional[ChatMessageImageUrl] # Make optional based on TS def 17 | text: Optional[str] # Added based on TS def 18 | source: Optional[ChatMessageSource] # Added based on TS def 19 | 20 | 21 | class ChatMessageTextContent(TypedDict): 22 | type: Literal["text"] 23 | text: str 24 | 25 | 26 | # ChatMessageContent can be a string or a list of text/image content parts 27 | ChatMessageContent = Union[ 28 | str, list[Union[ChatMessageImageContent, ChatMessageTextContent]] 29 | ] 30 | 31 | 32 | # Updated ChatMessage type definition 33 | class ChatMessage(TypedDict): 34 | role: Literal["system", "user", "assistant"] 35 | content: ChatMessageContent 36 | -------------------------------------------------------------------------------- /run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Run tests with coverage reporting 3 | 4 | # Make sure we're in the right directory 5 | cd "$(dirname "$0")" 6 | 7 | # Install dev requirements if needed 8 | if [[ -z $(pip3 list | grep pytest) ]]; then 9 | echo "Installing development requirements..." 10 | pip3 install -r requirements-dev.txt 11 | fi 12 | 13 | # Install package in development mode if needed 14 | if [[ -z $(pip3 list | grep stagehand) ]]; then 15 | echo "Installing stagehand package in development mode..." 16 | pip3 install -e . 17 | fi 18 | 19 | # Run the tests 20 | echo "Running tests with coverage..." 21 | python3 -m pytest tests/ -v --cov=stagehand --cov-report=term --cov-report=html 22 | 23 | echo "Tests complete. HTML coverage report is in htmlcov/ directory." 24 | 25 | # Check if we should open the report 26 | if [[ "$1" == "--open" || "$1" == "-o" ]]; then 27 | echo "Opening HTML coverage report..." 28 | if [[ "$OSTYPE" == "darwin"* ]]; then 29 | # macOS 30 | open htmlcov/index.html 31 | elif [[ "$OSTYPE" == "linux-gnu"* ]]; then 32 | # Linux with xdg-open 33 | xdg-open htmlcov/index.html 34 | elif [[ "$OSTYPE" == "msys" || "$OSTYPE" == "win32" ]]; then 35 | # Windows 36 | start htmlcov/index.html 37 | else 38 | echo "Couldn't automatically open the report. Please open htmlcov/index.html manually." 39 | fi 40 | fi -------------------------------------------------------------------------------- /stagehand/__init__.py: -------------------------------------------------------------------------------- 1 | """Stagehand - The AI Browser Automation Framework""" 2 | 3 | from importlib.metadata import version as get_version 4 | 5 | from .agent import Agent 6 | from .config import StagehandConfig, default_config 7 | from .handlers.observe_handler import ObserveHandler 8 | from .llm import LLMClient 9 | from .logging import LogConfig, configure_logging 10 | from .main import Stagehand 11 | from .metrics import StagehandFunctionName, StagehandMetrics 12 | from .page import StagehandPage 13 | from .schemas import ( 14 | ActOptions, 15 | ActResult, 16 | AgentConfig, 17 | AgentExecuteOptions, 18 | AgentExecuteResult, 19 | AgentProvider, 20 | ExtractOptions, 21 | ExtractResult, 22 | ObserveOptions, 23 | ObserveResult, 24 | ) 25 | 26 | __version__ = get_version("stagehand") 27 | 28 | __all__ = [ 29 | "Stagehand", 30 | "StagehandConfig", 31 | "StagehandPage", 32 | "Agent", 33 | "AgentConfig", 34 | "AgentExecuteOptions", 35 | "AgentExecuteResult", 36 | "AgentProvider", 37 | "ActOptions", 38 | "ActResult", 39 | "ExtractOptions", 40 | "ExtractResult", 41 | "ObserveOptions", 42 | "ObserveResult", 43 | "ObserveHandler", 44 | "LLMClient", 45 | "configure_logging", 46 | "StagehandFunctionName", 47 | "StagehandMetrics", 48 | "LogConfig", 49 | ] 50 | -------------------------------------------------------------------------------- /stagehand/types/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Exports for accessibility types. 3 | """ 4 | 5 | from .a11y import ( 6 | AccessibilityNode, 7 | AXNode, 8 | AXProperty, 9 | AXValue, 10 | CDPSession, 11 | Locator, 12 | PlaywrightCommandError, 13 | PlaywrightMethodNotSupportedError, 14 | TreeResult, 15 | ) 16 | from .agent import ( 17 | AgentConfig, 18 | ) 19 | from .llm import ( 20 | ChatMessage, 21 | ) 22 | from .page import ( 23 | ActOptions, 24 | ActResult, 25 | DefaultExtractSchema, 26 | EmptyExtractSchema, 27 | ExtractOptions, 28 | ExtractResult, 29 | MetadataSchema, 30 | ObserveElementSchema, 31 | ObserveInferenceSchema, 32 | ObserveOptions, 33 | ObserveResult, 34 | ) 35 | 36 | __all__ = [ 37 | "AXProperty", 38 | "AXValue", 39 | "AXNode", 40 | "AccessibilityNode", 41 | "TreeResult", 42 | "CDPSession", 43 | "Locator", 44 | "PlaywrightCommandError", 45 | "PlaywrightMethodNotSupportedError", 46 | "ChatMessage", 47 | "ObserveElementSchema", 48 | "ObserveInferenceSchema", 49 | "ActOptions", 50 | "ActResult", 51 | "ObserveOptions", 52 | "ObserveResult", 53 | "MetadataSchema", 54 | "DefaultExtractSchema", 55 | "ExtractOptions", 56 | "ExtractResult", 57 | "AgentConfig", 58 | "AgentExecuteOptions", 59 | "AgentResult", 60 | "EmptyExtractSchema", 61 | ] 62 | -------------------------------------------------------------------------------- /media/director_icon.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /stagehand/metrics.py: -------------------------------------------------------------------------------- 1 | import time 2 | from dataclasses import dataclass 3 | from enum import Enum 4 | 5 | 6 | class StagehandFunctionName(str, Enum): 7 | """Function names for tracking metrics.""" 8 | 9 | ACT = "act" 10 | EXTRACT = "extract" 11 | OBSERVE = "observe" 12 | AGENT = "agent" 13 | 14 | 15 | @dataclass 16 | class StagehandMetrics: 17 | """Metrics for token usage and inference time across different functions.""" 18 | 19 | act_prompt_tokens: int = 0 20 | act_completion_tokens: int = 0 21 | act_inference_time_ms: int = 0 22 | 23 | extract_prompt_tokens: int = 0 24 | extract_completion_tokens: int = 0 25 | extract_inference_time_ms: int = 0 26 | 27 | observe_prompt_tokens: int = 0 28 | observe_completion_tokens: int = 0 29 | observe_inference_time_ms: int = 0 30 | 31 | agent_prompt_tokens: int = 0 32 | agent_completion_tokens: int = 0 33 | agent_inference_time_ms: int = 0 34 | 35 | total_prompt_tokens: int = 0 36 | total_completion_tokens: int = 0 37 | total_inference_time_ms: int = 0 38 | 39 | 40 | def start_inference_timer() -> float: 41 | """Start timing inference latency. 42 | 43 | Returns: 44 | The start time as a float timestamp. 45 | """ 46 | return time.time() 47 | 48 | 49 | def get_inference_time_ms(start_time: float) -> int: 50 | """Get elapsed inference time in milliseconds. 51 | 52 | Args: 53 | start_time: The timestamp when timing started. 54 | 55 | Returns: 56 | The elapsed time in milliseconds. 57 | """ 58 | if start_time == 0: 59 | return 0 60 | return int((time.time() - start_time) * 1000) 61 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 2 | 3 | # dependencies 4 | **/node_modules 5 | /.pnp 6 | .pnp.* 7 | .yarn/* 8 | !.yarn/patches 9 | !.yarn/plugins 10 | !.yarn/releases 11 | !.yarn/versions 12 | 13 | # testing 14 | /coverage 15 | 16 | # next.js 17 | /.next/ 18 | /out/ 19 | 20 | # production 21 | /build 22 | 23 | # misc 24 | .DS_Store 25 | *.pem 26 | 27 | # debug 28 | npm-debug.log* 29 | yarn-debug.log* 30 | yarn-error.log* 31 | 32 | # env files (can opt-in for committing if needed) 33 | .env* 34 | 35 | # vercel 36 | .vercel 37 | 38 | # typescript 39 | *.tsbuildinfo 40 | next-env.d.ts 41 | 42 | # Python bytecode 43 | __pycache__/ 44 | *.py[cod] 45 | *$py.class 46 | 47 | # Python virtual environments 48 | env/ 49 | venv/ 50 | .venv/ 51 | pip-wheel-metadata/ 52 | *.egg-info/ 53 | *.egg 54 | 55 | # Python build directories 56 | build/ 57 | dist/ 58 | develop-eggs/ 59 | .eggs/ 60 | 61 | # Python packaging 62 | MANIFEST 63 | *.manifest 64 | *.spec 65 | 66 | # Installer logs 67 | pip-log.txt 68 | pip-delete-this-directory.txt 69 | 70 | # Unit test / coverage reports 71 | htmlcov/ 72 | .tox/ 73 | .nox/ 74 | .coverage 75 | .coverage.* 76 | .cache 77 | .pytest_cache/ 78 | 79 | # MyPy type checker 80 | .mypy_cache/ 81 | .dmypy.json 82 | dmypy.json 83 | 84 | # Pyre type checker 85 | .pyre/ 86 | 87 | # Jupyter Notebook checkpoints 88 | .ipynb_checkpoints 89 | 90 | # PyCharm project files 91 | .idea/ 92 | 93 | # VSCode settings 94 | .vscode/ 95 | 96 | # Local scripts 97 | scripts/ 98 | 99 | # Logs 100 | *.log 101 | /uv.lock 102 | -------------------------------------------------------------------------------- /stagehand/types/a11y.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, TypedDict, Union 2 | 3 | 4 | class AXProperty(TypedDict): 5 | name: str 6 | value: Any # Can be more specific if needed 7 | 8 | 9 | class AXValue(TypedDict): 10 | type: str 11 | value: Optional[Union[str, int, float, bool]] 12 | 13 | 14 | class AXNode(TypedDict): 15 | nodeId: str 16 | role: Optional[AXValue] 17 | name: Optional[AXValue] 18 | description: Optional[AXValue] 19 | value: Optional[AXValue] 20 | backendDOMNodeId: Optional[int] 21 | parentId: Optional[str] 22 | childIds: Optional[list[str]] 23 | properties: Optional[list[AXProperty]] 24 | 25 | 26 | class AccessibilityNode(TypedDict, total=False): 27 | nodeId: str 28 | role: str 29 | name: Optional[str] 30 | description: Optional[str] 31 | value: Optional[str] 32 | backendDOMNodeId: Optional[int] 33 | parentId: Optional[str] 34 | childIds: Optional[list[str]] 35 | children: Optional[list["AccessibilityNode"]] 36 | properties: Optional[list[AXProperty]] # Assuming structure from AXNode 37 | 38 | 39 | class TreeResult(TypedDict): 40 | tree: list[AccessibilityNode] 41 | simplified: str 42 | iframes: list[AccessibilityNode] # Simplified iframe info 43 | idToUrl: dict[str, str] 44 | 45 | 46 | # Placeholder for Playwright Page/CDPSession/Locator if not using StagehandPage directly 47 | # from playwright.async_api import Page, CDPSession, Locator 48 | # Assuming types are imported if StagehandPage is not used directly 49 | CDPSession = Any # Replace with actual Playwright CDPSession type if needed 50 | Locator = Any # Replace with actual Playwright Locator type if needed 51 | 52 | 53 | # --- Placeholder Exceptions --- 54 | class PlaywrightCommandError(Exception): 55 | pass 56 | 57 | 58 | class PlaywrightMethodNotSupportedError(Exception): 59 | pass 60 | -------------------------------------------------------------------------------- /tests/unit/llm/test_llm_integration.py: -------------------------------------------------------------------------------- 1 | """Test LLM integration functionality including different providers and response handling""" 2 | 3 | import pytest 4 | from unittest.mock import AsyncMock, MagicMock, patch 5 | import json 6 | 7 | from stagehand.llm.client import LLMClient 8 | from tests.mocks.mock_llm import MockLLMClient, MockLLMResponse 9 | from stagehand.logging import StagehandLogger 10 | 11 | 12 | class TestLLMClientInitialization: 13 | """Test LLM client initialization and setup""" 14 | 15 | def test_llm_client_creation_with_openai(self): 16 | """Test LLM client creation with OpenAI provider""" 17 | client = LLMClient( 18 | api_key="test-openai-key", 19 | default_model="gpt-4o", 20 | stagehand_logger=StagehandLogger(), 21 | ) 22 | 23 | assert client.default_model == "gpt-4o" 24 | # Note: api_key is set globally on litellm, not stored on client 25 | 26 | def test_llm_client_creation_with_anthropic(self): 27 | """Test LLM client creation with Anthropic provider""" 28 | client = LLMClient( 29 | api_key="test-anthropic-key", 30 | default_model="claude-3-sonnet", 31 | stagehand_logger=StagehandLogger(), 32 | ) 33 | 34 | assert client.default_model == "claude-3-sonnet" 35 | # Note: api_key is set globally on litellm, not stored on client 36 | 37 | def test_llm_client_with_custom_options(self): 38 | """Test LLM client with custom configuration options""" 39 | client = LLMClient( 40 | api_key="test-key", 41 | default_model="gpt-4o-mini", 42 | stagehand_logger=StagehandLogger(), 43 | ) 44 | 45 | assert client.default_model == "gpt-4o-mini" 46 | # Note: LLMClient doesn't store temperature, max_tokens, timeout as instance attributes 47 | # These are passed as kwargs to the completion method 48 | 49 | 50 | # TODO: let's do these in integration rather than simulation 51 | class TestLLMErrorHandling: 52 | """Test LLM error handling and recovery""" 53 | 54 | @pytest.mark.asyncio 55 | async def test_api_rate_limit_error(self): 56 | """Test handling of API rate limit errors""" 57 | mock_llm = MockLLMClient() 58 | mock_llm.simulate_failure(True, "Rate limit exceeded") 59 | 60 | messages = [{"role": "user", "content": "Test rate limit"}] 61 | 62 | with pytest.raises(Exception) as exc_info: 63 | await mock_llm.completion(messages) 64 | 65 | assert "Rate limit exceeded" in str(exc_info.value) -------------------------------------------------------------------------------- /tests/unit/handlers/test_act_handler.py: -------------------------------------------------------------------------------- 1 | """Test ActHandler functionality for AI-powered action execution""" 2 | 3 | import pytest 4 | from unittest.mock import AsyncMock, MagicMock, patch 5 | 6 | from stagehand.handlers.act_handler import ActHandler 7 | from stagehand.types import ActOptions, ActResult, ObserveResult 8 | from tests.mocks.mock_llm import MockLLMClient, MockLLMResponse 9 | 10 | 11 | class TestActHandlerInitialization: 12 | """Test ActHandler initialization and setup""" 13 | 14 | def test_act_handler_creation(self, mock_stagehand_page): 15 | """Test basic ActHandler creation""" 16 | mock_client = MagicMock() 17 | mock_client.llm = MockLLMClient() 18 | mock_client.logger = MagicMock() 19 | 20 | handler = ActHandler( 21 | mock_stagehand_page, 22 | mock_client, 23 | user_provided_instructions="Test instructions", 24 | self_heal=True 25 | ) 26 | 27 | assert handler.stagehand_page == mock_stagehand_page 28 | assert handler.stagehand == mock_client 29 | assert handler.user_provided_instructions == "Test instructions" 30 | assert handler.self_heal is True 31 | 32 | 33 | class TestActExecution: 34 | """Test action execution functionality""" 35 | 36 | @pytest.mark.smoke 37 | @pytest.mark.asyncio 38 | async def test_act_with_string_action(self, mock_stagehand_page): 39 | """Test executing action with string instruction""" 40 | mock_client = MagicMock() 41 | mock_llm = MockLLMClient() 42 | mock_client.llm = mock_llm 43 | mock_client.start_inference_timer = MagicMock() 44 | mock_client.update_metrics = MagicMock() 45 | mock_client.logger = MagicMock() 46 | 47 | handler = ActHandler(mock_stagehand_page, mock_client, "", True) 48 | 49 | # Mock the observe handler to return a successful result 50 | mock_observe_result = ObserveResult( 51 | selector="xpath=//button[@id='submit-btn']", 52 | description="Submit button", 53 | method="click", 54 | arguments=[] 55 | ) 56 | mock_stagehand_page._observe_handler = MagicMock() 57 | mock_stagehand_page._observe_handler.observe = AsyncMock(return_value=[mock_observe_result]) 58 | 59 | # Mock the playwright method execution 60 | handler._perform_playwright_method = AsyncMock() 61 | 62 | result = await handler.act({"action": "click on the submit button"}) 63 | 64 | assert isinstance(result, ActResult) 65 | assert result.success is True 66 | assert "performed successfully" in result.message 67 | assert result.action == "Submit button" 68 | 69 | 70 | -------------------------------------------------------------------------------- /stagehand/agent/client.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Optional 3 | 4 | # Forward declaration or direct import. Assuming direct import is fine. 5 | # If circular dependency issues arise, a forward declaration string might be needed for CUAHandler type hint. 6 | from ..handlers.cua_handler import CUAHandler 7 | from ..types.agent import AgentAction, AgentConfig, AgentExecuteOptions, AgentResult 8 | 9 | 10 | class AgentClient(ABC): 11 | def __init__( 12 | self, 13 | model: str, 14 | instructions: Optional[str], 15 | config: Optional[AgentConfig], 16 | logger: Any, 17 | handler: CUAHandler, 18 | ): 19 | self.model = model 20 | self.instructions = instructions # System prompt/base instructions 21 | self.config = config if config else AgentConfig() # Ensure config is never None 22 | self.logger = logger 23 | self.handler: CUAHandler = handler # Client holds a reference to the handler 24 | 25 | @abstractmethod 26 | async def run_task( 27 | self, instruction: str, options: Optional[AgentExecuteOptions] 28 | ) -> AgentResult: 29 | """ 30 | Manages the entire multi-step interaction with the CUA provider. 31 | This includes: 32 | - Getting initial page state (screenshot). 33 | - Sending initial messages to the provider. 34 | - Looping through provider responses and actions. 35 | - Calling CUAHandler to perform actions on the page. 36 | - Getting updated page state after actions. 37 | - Formatting and sending results/state back to the provider. 38 | - Returning the final AgentResult. 39 | """ 40 | pass 41 | 42 | @abstractmethod 43 | def _format_initial_messages( 44 | self, instruction: str, screenshot_base64: Optional[str] 45 | ) -> list[Any]: 46 | """ 47 | Prepares the initial list of messages to send to the CUA provider. 48 | Specific to each provider's API format. 49 | """ 50 | pass 51 | 52 | @abstractmethod 53 | def _process_provider_response( 54 | self, response: Any 55 | ) -> tuple[Optional[AgentAction], Optional[str], bool, Optional[str]]: 56 | """ 57 | Parses the raw response from the CUA provider. 58 | Returns: 59 | - AgentAction (if an action is to be performed) 60 | - Reasoning text (if provided by the model) 61 | - Boolean indicating if the task is complete 62 | - Message from the model (if any, e.g., final summary) 63 | """ 64 | pass 65 | 66 | @abstractmethod 67 | def _format_action_feedback( 68 | self, action: AgentAction, action_result: dict, new_screenshot_base64: str 69 | ) -> list[Any]: 70 | """ 71 | Formats the feedback to the provider after an action is performed. 72 | This typically includes the result of the action and the new page state (screenshot). 73 | """ 74 | pass 75 | 76 | @abstractmethod 77 | def format_screenshot(self, screenshot_base64: str) -> Any: 78 | """Format a screenshot for the agent. Takes the base64 encoded screenshot and returns a client-specific message part.""" 79 | pass 80 | 81 | @abstractmethod 82 | def key_to_playwright(self, key: str) -> str: 83 | """Convert a key to a playwright key if needed by the client before creating an AgentAction.""" 84 | pass 85 | -------------------------------------------------------------------------------- /stagehand/agent/image_compression_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | 4 | def find_items_with_images(items: list[dict[str, Any]]) -> list[int]: 5 | """ 6 | Finds all items in the conversation history that contain images 7 | 8 | Args: 9 | items: Array of conversation items to check 10 | 11 | Returns: 12 | Array of indices where images were found 13 | """ 14 | items_with_images = [] 15 | 16 | for index, item in enumerate(items): 17 | has_image = False 18 | 19 | if isinstance(item.get("content"), list): 20 | has_image = any( 21 | content_item.get("type") == "tool_result" 22 | and "content" in content_item 23 | and isinstance(content_item["content"], list) 24 | and any( 25 | nested_item.get("type") == "image" 26 | for nested_item in content_item["content"] 27 | if isinstance(nested_item, dict) 28 | ) 29 | for content_item in item["content"] 30 | if isinstance(content_item, dict) 31 | ) 32 | 33 | if has_image: 34 | items_with_images.append(index) 35 | 36 | return items_with_images 37 | 38 | 39 | def compress_conversation_images( 40 | items: list[dict[str, Any]], keep_most_recent_count: int = 2 41 | ) -> dict[str, list[dict[str, Any]]]: 42 | """ 43 | Compresses conversation history by removing images from older items 44 | while keeping the most recent images intact 45 | 46 | Args: 47 | items: Array of conversation items to process 48 | keep_most_recent_count: Number of most recent image-containing items to preserve (default: 2) 49 | 50 | Returns: 51 | Dictionary with processed items 52 | """ 53 | items_with_images = find_items_with_images(items) 54 | 55 | for index, item in enumerate(items): 56 | image_index = -1 57 | if index in items_with_images: 58 | image_index = items_with_images.index(index) 59 | 60 | should_compress = ( 61 | image_index >= 0 62 | and image_index < len(items_with_images) - keep_most_recent_count 63 | ) 64 | 65 | if should_compress: 66 | if isinstance(item.get("content"), list): 67 | new_content = [] 68 | for content_item in item["content"]: 69 | if isinstance(content_item, dict): 70 | if ( 71 | content_item.get("type") == "tool_result" 72 | and "content" in content_item 73 | and isinstance(content_item["content"], list) 74 | and any( 75 | nested_item.get("type") == "image" 76 | for nested_item in content_item["content"] 77 | if isinstance(nested_item, dict) 78 | ) 79 | ): 80 | # Replace the content with a text placeholder 81 | new_content.append( 82 | {**content_item, "content": "screenshot taken"} 83 | ) 84 | else: 85 | new_content.append(content_item) 86 | else: 87 | new_content.append(content_item) 88 | 89 | item["content"] = new_content 90 | 91 | return {"items": items} 92 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ "setuptools>=61.0", "wheel",] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "stagehand" 7 | version = "0.5.0" 8 | description = "Python SDK for Stagehand" 9 | readme = "README.md" 10 | classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent",] 11 | requires-python = ">=3.9" 12 | dependencies = [ "httpx>=0.24.0", "python-dotenv>=1.0.0", "pydantic>=1.10.0", "playwright>=1.42.1", "requests>=2.31.0", "browserbase>=1.4.0", "rich>=13.7.0", "openai>=1.83.0", "anthropic>=0.51.0", "litellm>=1.72.0",] 13 | [[project.authors]] 14 | name = "Browserbase, Inc." 15 | email = "support@browserbase.com" 16 | 17 | [project.license] 18 | text = "MIT" 19 | 20 | [project.optional-dependencies] 21 | dev = [ "pytest>=7.3.1", "pytest-asyncio>=0.21.0", "pytest-mock>=3.10.0", "pytest-cov>=4.1.0", "black>=23.3.0", "isort>=5.12.0", "mypy>=1.3.0", "ruff", "psutil>=5.9.0",] 22 | agent-cache = [ "opencv-python>=4.8.0", "scikit-image>=0.21.0",] 23 | 24 | [project.urls] 25 | Homepage = "https://github.com/browserbase/stagehand-python" 26 | Repository = "https://github.com/browserbase/stagehand-python" 27 | 28 | [tool.ruff] 29 | line-length = 88 30 | target-version = "py39" 31 | exclude = [ ".git", ".ruff_cache", "__pycache__", "venv", ".venv", "dist", "tests",] 32 | 33 | [tool.black] 34 | line-length = 88 35 | target-version = [ "py39",] 36 | include = "\\.pyi?$" 37 | exclude = "/(\n \\.git\n | \\.hg\n | \\.mypy_cache\n | \\.tox\n | \\.venv\n | _build\n | buck-out\n | build\n | dist\n | __pycache__\n | python-sdk\n)/\n" 38 | skip-string-normalization = false 39 | preview = true 40 | 41 | [tool.isort] 42 | profile = "black" 43 | line_length = 88 44 | multi_line_output = 3 45 | include_trailing_comma = true 46 | force_grid_wrap = 0 47 | use_parentheses = true 48 | ensure_newline_before_comments = true 49 | skip_gitignore = true 50 | skip_glob = [ "**/venv/**", "**/.venv/**", "**/__pycache__/**",] 51 | 52 | [tool.ruff.lint] 53 | select = [ "E", "F", "B", "C4", "UP", "N", "I", "C",] 54 | ignore = [ "E501", "C901",] 55 | fixable = [ "ALL",] 56 | unfixable = [] 57 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 58 | 59 | [tool.ruff.format] 60 | quote-style = "double" 61 | indent-style = "space" 62 | line-ending = "auto" 63 | 64 | [tool.setuptools.package-data] 65 | stagehand = [ "domScripts.js",] 66 | 67 | [tool.pytest.ini_options] 68 | testpaths = [ "tests",] 69 | python_files = [ "test_*.py",] 70 | python_classes = [ "Test*",] 71 | python_functions = [ "test_*",] 72 | asyncio_mode = "auto" 73 | addopts = [ "--cov=stagehand", "--cov-report=html:htmlcov", "--cov-report=term-missing", "--cov-report=xml", "--strict-markers", "--strict-config", "-ra", "--tb=short",] 74 | markers = [ "unit: Unit tests for individual components", "integration: Integration tests requiring multiple components", "e2e: End-to-end tests with full workflows", "slow: Tests that take longer to run", "browserbase: Tests requiring Browserbase connection", "local: Tests for local browser functionality", "llm: Tests involving LLM interactions", "mock: Tests using mock objects only", "performance: Performance and load tests", "smoke: Quick smoke tests for basic functionality",] 75 | filterwarnings = [ "ignore::DeprecationWarning", "ignore::PendingDeprecationWarning", "ignore::UserWarning:pytest_asyncio", "ignore::RuntimeWarning",] 76 | minversion = "7.0" 77 | 78 | [tool.ruff.lint.pep8-naming] 79 | classmethod-decorators = [ "classmethod", "validator",] 80 | 81 | [tool.ruff.lint.per-file-ignores] 82 | "__init__.py" = [ "F401",] 83 | "tests/*" = [ "F401", "F811",] 84 | 85 | [tool.ruff.lint.pydocstyle] 86 | convention = "google" 87 | 88 | [tool.setuptools.packages.find] 89 | where = [ ".",] 90 | include = [ "stagehand*",] 91 | -------------------------------------------------------------------------------- /tests/regression/test_ionwave.py: -------------------------------------------------------------------------------- 1 | """ 2 | Regression test for ionwave functionality. 3 | 4 | This test verifies that navigation actions work correctly by clicking on links, 5 | based on the TypeScript ionwave evaluation. 6 | """ 7 | 8 | import os 9 | import pytest 10 | import pytest_asyncio 11 | 12 | from stagehand import Stagehand, StagehandConfig 13 | 14 | 15 | class TestIonwave: 16 | """Regression test for ionwave functionality""" 17 | 18 | @pytest.fixture(scope="class") 19 | def local_config(self): 20 | """Configuration for LOCAL mode testing""" 21 | return StagehandConfig( 22 | env="LOCAL", 23 | model_name="gpt-4o-mini", 24 | headless=True, 25 | verbose=1, 26 | dom_settle_timeout_ms=2000, 27 | model_client_options={"apiKey": os.getenv("MODEL_API_KEY") or os.getenv("OPENAI_API_KEY")}, 28 | ) 29 | 30 | @pytest.fixture(scope="class") 31 | def browserbase_config(self): 32 | """Configuration for BROWSERBASE mode testing""" 33 | return StagehandConfig( 34 | env="BROWSERBASE", 35 | api_key=os.getenv("BROWSERBASE_API_KEY"), 36 | project_id=os.getenv("BROWSERBASE_PROJECT_ID"), 37 | model_name="gpt-4o", 38 | headless=False, 39 | verbose=2, 40 | model_client_options={"apiKey": os.getenv("MODEL_API_KEY") or os.getenv("OPENAI_API_KEY")}, 41 | ) 42 | 43 | @pytest_asyncio.fixture 44 | async def local_stagehand(self, local_config): 45 | """Create a Stagehand instance for LOCAL testing""" 46 | stagehand = Stagehand(config=local_config) 47 | await stagehand.init() 48 | yield stagehand 49 | await stagehand.close() 50 | 51 | @pytest_asyncio.fixture 52 | async def browserbase_stagehand(self, browserbase_config): 53 | """Create a Stagehand instance for BROWSERBASE testing""" 54 | if not (os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")): 55 | pytest.skip("Browserbase credentials not available") 56 | 57 | stagehand = Stagehand(config=browserbase_config) 58 | await stagehand.init() 59 | yield stagehand 60 | await stagehand.close() 61 | 62 | @pytest.mark.asyncio 63 | @pytest.mark.regression 64 | @pytest.mark.local 65 | async def test_ionwave_local(self, local_stagehand): 66 | """ 67 | Regression test: ionwave 68 | 69 | Mirrors the TypeScript ionwave evaluation: 70 | - Navigate to ionwave test site 71 | - Click on "Closed Bids" link 72 | - Verify navigation to closed-bids.html page 73 | """ 74 | stagehand = local_stagehand 75 | 76 | await stagehand.page.goto("https://browserbase.github.io/stagehand-eval-sites/sites/ionwave/") 77 | 78 | result = await stagehand.page.act('Click on "Closed Bids"') 79 | 80 | current_url = stagehand.page.url 81 | expected_url = "https://browserbase.github.io/stagehand-eval-sites/sites/ionwave/closed-bids.html" 82 | 83 | # Test passes if we successfully navigated to the expected URL 84 | assert current_url.startswith(expected_url), f"Expected URL to start with {expected_url}, but got {current_url}" 85 | 86 | @pytest.mark.asyncio 87 | @pytest.mark.regression 88 | @pytest.mark.api 89 | @pytest.mark.skipif( 90 | not (os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")), 91 | reason="Browserbase credentials not available" 92 | ) 93 | async def test_ionwave_browserbase(self, browserbase_stagehand): 94 | """ 95 | Regression test: ionwave (Browserbase) 96 | 97 | Same test as local but running in Browserbase environment. 98 | """ 99 | stagehand = browserbase_stagehand 100 | 101 | await stagehand.page.goto("https://browserbase.github.io/stagehand-eval-sites/sites/ionwave/") 102 | 103 | result = await stagehand.page.act('Click on "Closed Bids"') 104 | 105 | current_url = stagehand.page.url 106 | expected_url = "https://browserbase.github.io/stagehand-eval-sites/sites/ionwave/closed-bids.html" 107 | 108 | # Test passes if we successfully navigated to the expected URL 109 | assert current_url.startswith(expected_url), f"Expected URL to start with {expected_url}, but got {current_url}" -------------------------------------------------------------------------------- /tests/unit/handlers/test_observe_handler.py: -------------------------------------------------------------------------------- 1 | """Test ObserveHandler functionality for AI-powered element observation""" 2 | 3 | import pytest 4 | from unittest.mock import AsyncMock, MagicMock, patch 5 | 6 | from stagehand.handlers.observe_handler import ObserveHandler 7 | from stagehand.schemas import ObserveOptions, ObserveResult 8 | from tests.mocks.mock_llm import MockLLMClient 9 | 10 | 11 | def setup_observe_mocks(mock_stagehand_page): 12 | """Set up common mocks for observe handler tests""" 13 | mock_stagehand_page._wait_for_settled_dom = AsyncMock() 14 | mock_stagehand_page.send_cdp = AsyncMock() 15 | mock_stagehand_page.get_cdp_client = AsyncMock() 16 | 17 | # Mock the accessibility tree and xpath utilities 18 | with patch('stagehand.handlers.observe_handler.get_accessibility_tree') as mock_tree, \ 19 | patch('stagehand.handlers.observe_handler.get_xpath_by_resolved_object_id') as mock_xpath: 20 | 21 | mock_tree.return_value = {"simplified": "mocked tree", "iframes": []} 22 | mock_xpath.return_value = "//button[@id='test']" 23 | 24 | return mock_tree, mock_xpath 25 | 26 | 27 | class TestObserveHandlerInitialization: 28 | """Test ObserveHandler initialization""" 29 | 30 | def test_observe_handler_creation(self, mock_stagehand_page): 31 | """Test basic handler creation""" 32 | mock_client = MagicMock() 33 | mock_client.logger = MagicMock() 34 | 35 | handler = ObserveHandler(mock_stagehand_page, mock_client, "") 36 | 37 | assert handler.stagehand_page == mock_stagehand_page 38 | assert handler.stagehand == mock_client 39 | assert handler.user_provided_instructions == "" 40 | 41 | 42 | class TestObserveExecution: 43 | """Test observe execution and response processing""" 44 | 45 | @pytest.mark.smoke 46 | @pytest.mark.asyncio 47 | async def test_observe_single_element(self, mock_stagehand_page): 48 | """Test observing a single element""" 49 | # Set up mock client with proper LLM response 50 | mock_client = MagicMock() 51 | mock_client.logger = MagicMock() 52 | mock_client.logger.info = MagicMock() 53 | mock_client.logger.debug = MagicMock() 54 | mock_client.start_inference_timer = MagicMock() 55 | mock_client.update_metrics = MagicMock() 56 | 57 | # Create a MockLLMClient instance 58 | mock_llm = MockLLMClient() 59 | mock_client.llm = mock_llm 60 | 61 | # Set up the LLM to return the observe response in the format expected by observe_inference 62 | # The MockLLMClient should return this when the response_type is "observe" 63 | mock_llm.set_custom_response("observe", [ 64 | { 65 | "element_id": 12345, 66 | "description": "Submit button in the form", 67 | "method": "click", 68 | "arguments": [] 69 | } 70 | ]) 71 | 72 | # Mock the CDP and accessibility tree functions 73 | with patch('stagehand.handlers.observe_handler.get_accessibility_tree') as mock_get_tree, \ 74 | patch('stagehand.handlers.observe_handler.get_xpath_by_resolved_object_id') as mock_get_xpath: 75 | 76 | mock_get_tree.return_value = { 77 | "simplified": "[1] button: Submit button", 78 | "iframes": [] 79 | } 80 | mock_get_xpath.return_value = "//button[@id='submit-button']" 81 | 82 | # Mock CDP responses 83 | mock_stagehand_page.send_cdp = AsyncMock(return_value={ 84 | "object": {"objectId": "mock-object-id"} 85 | }) 86 | mock_cdp_client = AsyncMock() 87 | mock_stagehand_page.get_cdp_client = AsyncMock(return_value=mock_cdp_client) 88 | 89 | # Create handler and run observe 90 | handler = ObserveHandler(mock_stagehand_page, mock_client, "") 91 | options = ObserveOptions(instruction="find the submit button") 92 | result = await handler.observe(options) 93 | 94 | # Verify results 95 | assert isinstance(result, list) 96 | assert len(result) == 1 97 | assert isinstance(result[0], ObserveResult) 98 | assert result[0].selector == "xpath=//button[@id='submit-button']" 99 | assert result[0].description == "Submit button in the form" 100 | assert result[0].method == "click" 101 | 102 | # Verify that LLM was called 103 | assert mock_llm.call_count == 1 104 | -------------------------------------------------------------------------------- /tests/integration/local/test_core_local.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pytest_asyncio 3 | import os 4 | 5 | from stagehand import Stagehand, StagehandConfig 6 | 7 | 8 | class TestStagehandLocalIntegration: 9 | """Integration tests for Stagehand Python SDK in LOCAL mode.""" 10 | 11 | @pytest.fixture(scope="class") 12 | def local_config(self): 13 | """Configuration for LOCAL mode testing""" 14 | return StagehandConfig( 15 | env="LOCAL", 16 | model_name="gpt-4o-mini", 17 | headless=True, # Use headless mode for CI 18 | verbose=1, 19 | dom_settle_timeout_ms=2000, 20 | self_heal=True, 21 | wait_for_captcha_solves=False, 22 | system_prompt="You are a browser automation assistant for testing purposes.", 23 | model_client_options={"apiKey": os.getenv("MODEL_API_KEY")}, 24 | use_api=False, 25 | ) 26 | 27 | @pytest_asyncio.fixture 28 | async def stagehand_local(self, local_config): 29 | """Create a Stagehand instance for LOCAL testing""" 30 | stagehand = Stagehand(config=local_config) 31 | await stagehand.init() 32 | yield stagehand 33 | await stagehand.close() 34 | 35 | @pytest.mark.asyncio 36 | @pytest.mark.integration 37 | @pytest.mark.local 38 | async def test_stagehand_local_initialization(self, stagehand_local): 39 | """Ensure that Stagehand initializes correctly in LOCAL mode.""" 40 | assert stagehand_local._initialized is True 41 | 42 | @pytest.mark.asyncio 43 | @pytest.mark.integration 44 | @pytest.mark.local 45 | async def test_local_observe_and_act_workflow(self, stagehand_local): 46 | """Test core observe and act workflow in LOCAL mode - extracted from e2e tests.""" 47 | stagehand = stagehand_local 48 | 49 | # Navigate to a form page for testing 50 | await stagehand.page.goto("https://httpbin.org/forms/post") 51 | 52 | # Test OBSERVE primitive: Find form elements 53 | form_elements = await stagehand.page.observe("Find all form input elements") 54 | 55 | # Verify observations 56 | assert form_elements is not None 57 | assert len(form_elements) > 0 58 | 59 | # Verify observation structure 60 | for obs in form_elements: 61 | assert hasattr(obs, "selector") 62 | assert obs.selector # Not empty 63 | 64 | # Test ACT primitive: Fill form fields 65 | await stagehand.page.act("Fill the customer name field with 'Local Integration Test'") 66 | await stagehand.page.act("Fill the telephone field with '555-LOCAL'") 67 | await stagehand.page.act("Fill the email field with 'local@integration.test'") 68 | 69 | # Verify actions worked by observing filled fields 70 | filled_fields = await stagehand.page.observe("Find all filled form input fields") 71 | assert filled_fields is not None 72 | assert len(filled_fields) > 0 73 | 74 | # Test interaction with specific elements 75 | customer_field = await stagehand.page.observe("Find the customer name input field") 76 | assert customer_field is not None 77 | assert len(customer_field) > 0 78 | 79 | @pytest.mark.asyncio 80 | @pytest.mark.integration 81 | @pytest.mark.local 82 | async def test_local_basic_navigation_and_observe(self, stagehand_local): 83 | """Test basic navigation and observe functionality in LOCAL mode""" 84 | stagehand = stagehand_local 85 | 86 | # Navigate to a simple page 87 | await stagehand.page.goto("https://example.com") 88 | 89 | # Observe elements on the page 90 | observations = await stagehand.page.observe("Find all the links on the page") 91 | 92 | # Verify we got some observations 93 | assert observations is not None 94 | assert len(observations) > 0 95 | 96 | # Verify observation structure 97 | for obs in observations: 98 | assert hasattr(obs, "selector") 99 | assert obs.selector # Not empty 100 | 101 | @pytest.mark.asyncio 102 | @pytest.mark.integration 103 | @pytest.mark.local 104 | async def test_local_extraction_functionality(self, stagehand_local): 105 | """Test extraction functionality in LOCAL mode""" 106 | stagehand = stagehand_local 107 | 108 | # Navigate to a content-rich page 109 | await stagehand.page.goto("https://news.ycombinator.com") 110 | 111 | # Extract article titles using simple string instruction 112 | articles_text = await stagehand.page.extract( 113 | "Extract the titles of the first 3 articles on the page as a JSON list" 114 | ) 115 | 116 | # Verify extraction worked 117 | assert articles_text is not None -------------------------------------------------------------------------------- /CACHE_GUIDE.md: -------------------------------------------------------------------------------- 1 | # Stagehand 缓存机制使用指南 2 | 3 | ## 🚀 概述 4 | 5 | Stagehand 缓存机制通过智能缓存 LLM 分析结果,显著减少 AI 调用次数,提升自动化性能和降低成本。 6 | 7 | ## ✨ 主要特性 8 | 9 | - 🧠 **智能缓存**: 自动缓存 LLM 分析的元素定位结果 10 | - ⚡ **性能提升**: 减少 70% 的 LLM 调用,提升 3-5 倍执行速度 11 | - 🛡️ **自动验证**: 使用前验证缓存的 XPath 是否仍然有效 12 | - 🔄 **智能降级**: 缓存失效时自动回退到 LLM 分析 13 | - 💾 **持久存储**: 支持文件和内存双层缓存 14 | - 🎯 **精确匹配**: 基于指令、页面URL和标题的复合键匹配 15 | 16 | ## 📖 使用方法 17 | 18 | ### 基础用法 19 | 20 | ```python 21 | # 启用缓存 (默认) 22 | await page.act("点击登录按钮", use_cache=True) 23 | await page.observe("找到用户名输入框", use_cache=True) 24 | 25 | # 禁用缓存 26 | await page.act("点击登录按钮", use_cache=False) 27 | await page.observe("找到用户名输入框", use_cache=False) 28 | ``` 29 | 30 | ### 自定义缓存时间 31 | 32 | ```python 33 | # 设置缓存过期时间 (秒) 34 | await page.act("点击登录按钮", use_cache=True, cache_ttl=3600) # 1小时 35 | await page.observe("找到搜索框", use_cache=True, cache_ttl=1800) # 30分钟 36 | ``` 37 | 38 | ### 不同场景的推荐设置 39 | 40 | ```python 41 | # 登录表单元素 - 长缓存 (2小时) 42 | await page.act("输入用户名", use_cache=True, cache_ttl=7200) 43 | 44 | # 动态内容 - 短缓存 (5分钟) 45 | await page.observe("查找错误信息", use_cache=True, cache_ttl=300) 46 | 47 | # 一次性操作 - 禁用缓存 48 | await page.act("点击验证码刷新", use_cache=False) 49 | ``` 50 | 51 | ## 🔧 缓存管理 52 | 53 | ### 命令行工具 54 | 55 | ```bash 56 | # 查看缓存统计 57 | python cache_manager_tool.py stats 58 | 59 | # 列出所有缓存 60 | python cache_manager_tool.py list 61 | 62 | # 清理过期缓存 63 | python cache_manager_tool.py clear --expired-only 64 | 65 | # 清理所有缓存 66 | python cache_manager_tool.py clear 67 | 68 | # 搜索缓存 69 | python cache_manager_tool.py search "用户名" 70 | 71 | # 导出缓存 72 | python cache_manager_tool.py export backup.json 73 | 74 | # 导入缓存 75 | python cache_manager_tool.py import backup.json 76 | ``` 77 | 78 | ### 程序化管理 79 | 80 | ```python 81 | from stagehand.cache import StagehandCache 82 | 83 | # 创建缓存管理器 84 | cache = StagehandCache() 85 | 86 | # 获取统计信息 87 | stats = cache.get_cache_stats() 88 | print(f"缓存数量: {stats['total_caches']}") 89 | print(f"命中次数: {stats['total_hits']}") 90 | 91 | # 清理过期缓存 92 | cleared = cache.clear_cache(expired_only=True, ttl=3600) 93 | print(f"清理了 {cleared} 条过期缓存") 94 | ``` 95 | 96 | ## 📊 性能对比 97 | 98 | ### 首次运行 vs 缓存命中 99 | 100 | | 操作类型 | 首次运行 | 缓存命中 | 性能提升 | 101 | |---------|---------|---------|---------| 102 | | 简单定位 | 2-3秒 | 0.1-0.2秒 | **15x** | 103 | | 复杂查找 | 5-8秒 | 0.2-0.3秒 | **25x** | 104 | | 表单填充 | 3-4秒 | 0.1秒 | **30x** | 105 | 106 | ### 成本节省 107 | 108 | - 🎯 **LLM调用减少**: 70-90% 109 | - 💰 **API成本降低**: 显著节省token消费 110 | - ⚡ **响应时间**: 提升3-5倍 111 | 112 | ## 🎯 最佳实践 113 | 114 | ### 1. 缓存策略 115 | 116 | ```python 117 | # 稳定元素 - 长缓存 118 | await page.act("点击导航菜单", cache_ttl=86400) # 24小时 119 | 120 | # 动态元素 - 中等缓存 121 | await page.observe("找到商品列表", cache_ttl=3600) # 1小时 122 | 123 | # 临时元素 - 短缓存 124 | await page.observe("查找弹窗", cache_ttl=300) # 5分钟 125 | 126 | # 验证码等 - 禁用缓存 127 | await page.act("输入验证码", use_cache=False) 128 | ``` 129 | 130 | ### 2. 缓存键优化 131 | 132 | 缓存键基于以下信息生成: 133 | - 📝 **指令内容**: "找到用户名输入框" 134 | - 🌐 **页面URL**: "https://example.com/login" 135 | - 📄 **页面标题**: "登录页面" 136 | 137 | 确保指令描述准确和一致: 138 | 139 | ```python 140 | # 推荐 - 具体明确 141 | await page.observe("找到用户名输入框") 142 | await page.observe("找到密码输入框") 143 | await page.observe("找到登录按钮") 144 | 145 | # 不推荐 - 模糊不清 146 | await page.observe("找到输入框") 147 | await page.observe("找到按钮") 148 | ``` 149 | 150 | ### 3. 错误处理 151 | 152 | ```python 153 | try: 154 | # 使用缓存 155 | result = await page.act("点击按钮", use_cache=True) 156 | if not result.success: 157 | # 缓存可能失效,重试不使用缓存 158 | result = await page.act("点击按钮", use_cache=False) 159 | except Exception as e: 160 | print(f"操作失败: {e}") 161 | ``` 162 | 163 | ## 🔍 故障排除 164 | 165 | ### 常见问题 166 | 167 | **Q: 缓存没有命中?** 168 | A: 检查指令是否一致、页面URL是否相同、缓存是否过期 169 | 170 | **Q: 缓存的XPath失效?** 171 | A: 系统会自动验证并重新分析,无需手动处理 172 | 173 | **Q: 如何强制刷新缓存?** 174 | A: 设置 `use_cache=False` 或清理相关缓存 175 | 176 | **Q: 缓存文件太大?** 177 | A: 定期清理过期缓存,或设置较短的TTL 178 | 179 | ### 调试技巧 180 | 181 | ```python 182 | # 启用详细日志 183 | import logging 184 | logging.basicConfig(level=logging.DEBUG) 185 | 186 | # 检查缓存状态 187 | cache_stats = page._observe_handler.cache_manager.get_cache_stats() 188 | print(f"缓存统计: {cache_stats}") 189 | 190 | # 手动验证XPath 191 | xpath_valid = await page._observe_handler.cache_manager.validate_cached_xpath( 192 | page, "xpath=//button[@id='login']" 193 | ) 194 | print(f"XPath有效性: {xpath_valid}") 195 | ``` 196 | 197 | ## 📈 监控和优化 198 | 199 | ### 性能监控 200 | 201 | ```python 202 | import time 203 | 204 | start_time = time.time() 205 | await page.act("执行操作", use_cache=True) 206 | execution_time = time.time() - start_time 207 | 208 | print(f"执行时间: {execution_time:.2f}秒") 209 | ``` 210 | 211 | ### 缓存命中率 212 | 213 | ```python 214 | stats = cache.get_cache_stats() 215 | hit_rate = stats['total_hits'] / max(stats['total_caches'], 1) * 100 216 | print(f"缓存命中率: {hit_rate:.1f}%") 217 | ``` 218 | 219 | ## 🎉 示例项目 220 | 221 | 查看以下示例了解完整用法: 222 | 223 | - `examples/test_cache_functionality.py` - 缓存功能测试 224 | - `examples/admin_login_cached.py` - 带缓存的登录自动化 225 | - `examples/cache_manager_tool.py` - 缓存管理工具 226 | 227 | ## 🤝 贡献 228 | 229 | 欢迎提交改进建议和bug报告!缓存机制是一个持续优化的功能,您的反馈非常宝贵。 230 | -------------------------------------------------------------------------------- /stagehand/types/agent.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Literal, Optional, Union 2 | 3 | from pydantic import BaseModel, RootModel 4 | 5 | 6 | class AgentConfig(BaseModel): 7 | """ 8 | Configuration for agent execution. 9 | 10 | Attributes: 11 | model (Optional[str]): The model name to use. 12 | instructions (Optional[str]): Custom instructions for the agent (system prompt). 13 | options (Optional[dict[str, Any]]): Additional provider-specific options. 14 | """ 15 | 16 | model: Optional[str] = None 17 | instructions: Optional[str] = None 18 | options: Optional[dict[str, Any]] = None 19 | max_steps: Optional[int] = 20 20 | 21 | 22 | class ClickAction(BaseModel): 23 | type: Literal["click"] 24 | x: int 25 | y: int 26 | button: Optional[Literal["left", "right", "middle", "back", "forward"]] 27 | 28 | 29 | class DoubleClickAction(BaseModel): 30 | type: Literal["double_click", "doubleClick"] 31 | x: int 32 | y: int 33 | 34 | 35 | class TypeAction(BaseModel): 36 | type: Literal["type"] 37 | text: str 38 | x: Optional[int] = None 39 | y: Optional[int] = None 40 | press_enter_after: Optional[bool] = False 41 | clear_before: Optional[bool] = True # 默认清空输入框 42 | 43 | 44 | class KeyPressAction(BaseModel): 45 | type: Literal["keypress"] 46 | keys: list[str] # e.g., ["CONTROL", "A"] 47 | 48 | 49 | class ScrollAction(BaseModel): 50 | type: Literal["scroll"] 51 | x: int 52 | y: int 53 | scroll_x: Optional[int] 54 | scroll_y: Optional[int] 55 | 56 | 57 | class Point(BaseModel): 58 | x: int 59 | y: int 60 | 61 | 62 | class DragAction(BaseModel): 63 | type: Literal["drag"] 64 | path: list[Point] 65 | 66 | 67 | class MoveAction(BaseModel): 68 | type: Literal["move"] 69 | x: int 70 | y: int 71 | 72 | 73 | class WaitAction(BaseModel): 74 | type: Literal["wait"] 75 | miliseconds: Optional[int] = 0 76 | # No specific args, implies a default wait time 77 | 78 | 79 | class ScreenshotAction(BaseModel): 80 | type: Literal["screenshot"] 81 | # No specific args, screenshot is handled by client 82 | 83 | 84 | class FunctionArguments(BaseModel): 85 | url: str 86 | # Add other function arguments as needed 87 | 88 | 89 | class FunctionAction(BaseModel): 90 | type: Literal["function"] 91 | name: str 92 | arguments: Optional[FunctionArguments] 93 | 94 | 95 | class KeyAction(BaseModel): # From Anthropic 96 | type: Literal["key"] 97 | text: str 98 | 99 | 100 | AgentActionType = RootModel[ 101 | Union[ 102 | ClickAction, 103 | DoubleClickAction, 104 | TypeAction, 105 | KeyPressAction, 106 | ScrollAction, 107 | DragAction, 108 | MoveAction, 109 | WaitAction, 110 | ScreenshotAction, 111 | FunctionAction, 112 | KeyAction, 113 | ] 114 | ] 115 | 116 | 117 | class AgentAction(BaseModel): 118 | action_type: str 119 | reasoning: Optional[str] = None 120 | action: AgentActionType 121 | status: Optional[str] = None 122 | step: Optional[list[dict[str, Any]]] = None 123 | 124 | 125 | class AgentUsage(BaseModel): 126 | input_tokens: int 127 | output_tokens: int 128 | inference_time_ms: int 129 | 130 | 131 | class AgentResult(BaseModel): 132 | actions: list[AgentActionType] 133 | message: Optional[str] 134 | usage: Optional[AgentUsage] 135 | completed: bool 136 | 137 | 138 | class ActionExecutionResult(BaseModel): 139 | success: bool 140 | error: Optional[str] 141 | 142 | 143 | class AgentClientOptions(BaseModel): 144 | api_key: Optional[str] 145 | base_url: Optional[str] 146 | model_name: Optional[str] # For specific models like gpt-4, claude-2 147 | max_tokens: Optional[int] 148 | temperature: Optional[float] 149 | wait_between_actions: Optional[int] # in milliseconds 150 | # other client-specific options 151 | 152 | 153 | class AgentHandlerOptions(BaseModel): 154 | model_name: str # e.g., "openai", "anthropic" 155 | client_options: Optional[AgentClientOptions] 156 | user_provided_instructions: Optional[str] 157 | # Add other handler options as needed 158 | 159 | 160 | class AgentExecuteOptions(BaseModel): 161 | """ 162 | Agent execution parameters. 163 | 164 | Attributes: 165 | instruction (str): The instruction to execute. 166 | max_steps (Optional[int]): Maximum number of steps the agent can take. Defaults to 15. 167 | auto_screenshot (Optional[bool]): Whether to automatically capture screenshots after each action. False will let the agent choose when to capture screenshots. Defaults to True. 168 | wait_between_actions (Optional[int]): Milliseconds to wait between actions. 169 | context (Optional[str]): Additional context for the agent. 170 | """ 171 | 172 | instruction: str 173 | max_steps: Optional[int] = 15 174 | auto_screenshot: Optional[bool] = True 175 | wait_between_actions: Optional[int] = 1000 176 | context: Optional[str] = None 177 | 178 | 179 | class EnvState(BaseModel): 180 | # The screenshot in PNG format. 181 | screenshot: bytes 182 | url: str 183 | -------------------------------------------------------------------------------- /media/dark_license.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /media/light_license.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /stagehand/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Callable, Literal, Optional 3 | 4 | from browserbase.types import SessionCreateParams as BrowserbaseSessionCreateParams 5 | from pydantic import BaseModel, ConfigDict, Field 6 | 7 | from stagehand.schemas import AvailableModel 8 | 9 | 10 | class StagehandConfig(BaseModel): 11 | """ 12 | Configuration for the Stagehand client. 13 | 14 | Attributes: 15 | env (str): Environment type. 'BROWSERBASE' for remote usage 16 | api_key (Optional[str]): BrowserbaseAPI key for authentication. 17 | project_id (Optional[str]): Browserbase Project identifier. 18 | api_url (Optional[str]): Stagehand API URL. 19 | browserbase_session_create_params (Optional[BrowserbaseSessionCreateParams]): Browserbase session create params. 20 | browserbase_session_id (Optional[str]): Session ID for resuming Browserbase sessions. 21 | model_name (Optional[str]): Name of the model to use. 22 | model_api_key (Optional[str]): Model API key. 23 | logger (Optional[Callable[[Any], None]]): Custom logging function. 24 | verbose (Optional[int]): Verbosity level for logs (1=minimal, 2=medium, 3=detailed). 25 | use_rich_logging (bool): Whether to use Rich for colorized logging. 26 | dom_settle_timeout_ms (Optional[int]): Timeout for DOM to settle (in milliseconds). 27 | enable_caching (Optional[bool]): Enable caching functionality. 28 | self_heal (Optional[bool]): Enable self-healing functionality. 29 | wait_for_captcha_solves (Optional[bool]): Whether to wait for CAPTCHA to be solved. 30 | act_timeout_ms (Optional[int]): Timeout for act commands (in milliseconds). 31 | headless (bool): Run browser in headless mode 32 | system_prompt (Optional[str]): System prompt to use for LLM interactions. 33 | local_browser_launch_options (Optional[dict[str, Any]]): Local browser launch options. 34 | use_api (bool): Whether to use API mode. 35 | experimental (bool): Enable experimental features. 36 | """ 37 | 38 | env: Literal["BROWSERBASE", "LOCAL"] = "BROWSERBASE" 39 | api_key: Optional[str] = Field( 40 | None, alias="apiKey", description="Browserbase API key for authentication" 41 | ) 42 | project_id: Optional[str] = Field( 43 | None, alias="projectId", description="Browserbase project ID" 44 | ) 45 | api_url: Optional[str] = Field( 46 | os.environ.get("STAGEHAND_API_URL", "https://api.stagehand.browserbase.com/v1"), 47 | alias="apiUrl", 48 | description="Stagehand API URL", 49 | ) 50 | model_api_key: Optional[str] = Field( 51 | None, alias="modelApiKey", description="Model API key" 52 | ) 53 | model_api_base: Optional[str] = Field( 54 | None, 55 | alias="modelApiBase", 56 | description="Model API base URL (for OpenAI-compatible endpoints)", 57 | ) 58 | verbose: Optional[int] = Field( 59 | 1, 60 | description="Verbosity level for logs: 0=minimal (ERROR), 1=medium (INFO), 2=detailed (DEBUG)", 61 | ) 62 | logger: Optional[Callable[[Any], None]] = Field( 63 | None, description="Custom logging function" 64 | ) 65 | use_rich_logging: Optional[bool] = Field( 66 | True, description="Whether to use Rich for colorized logging" 67 | ) 68 | dom_settle_timeout_ms: Optional[int] = Field( 69 | 3000, 70 | alias="domSettleTimeoutMs", 71 | description="Timeout for DOM to settle (in ms)", 72 | ) 73 | browserbase_session_create_params: Optional[BrowserbaseSessionCreateParams] = Field( 74 | None, 75 | alias="browserbaseSessionCreateParams", 76 | description="Browserbase session create params", 77 | ) 78 | enable_caching: Optional[bool] = Field( 79 | False, alias="enableCaching", description="Enable caching functionality" 80 | ) 81 | browserbase_session_id: Optional[str] = Field( 82 | None, 83 | alias="browserbaseSessionID", 84 | description="Session ID for resuming Browserbase sessions", 85 | ) 86 | model_name: Optional[str] = Field( 87 | AvailableModel.GPT_4O, alias="modelName", description="Name of the model to use" 88 | ) 89 | self_heal: Optional[bool] = Field( 90 | True, alias="selfHeal", description="Enable self-healing functionality" 91 | ) 92 | wait_for_captcha_solves: Optional[bool] = Field( 93 | False, 94 | alias="waitForCaptchaSolves", 95 | description="Whether to wait for CAPTCHA to be solved", 96 | ) 97 | system_prompt: Optional[str] = Field( 98 | None, 99 | alias="systemPrompt", 100 | description="System prompt to use for LLM interactions", 101 | ) 102 | local_browser_launch_options: Optional[dict[str, Any]] = Field( 103 | {}, 104 | alias="localBrowserLaunchOptions", 105 | description="Local browser launch options", 106 | ) 107 | use_api: Optional[bool] = Field( 108 | True, 109 | alias=None, 110 | description="Whether to use the Stagehand API", 111 | ) 112 | experimental: Optional[bool] = Field( 113 | False, 114 | alias=None, 115 | description="Whether to use experimental features", 116 | ) 117 | 118 | model_config = ConfigDict(populate_by_name=True) 119 | 120 | def with_overrides(self, **overrides) -> "StagehandConfig": 121 | """ 122 | Create a new config instance with the specified overrides. 123 | 124 | Args: 125 | **overrides: Key-value pairs to override in the config 126 | 127 | Returns: 128 | StagehandConfig: New config instance with overrides applied 129 | """ 130 | config_dict = self.model_dump() 131 | config_dict.update(overrides) 132 | return StagehandConfig(**config_dict) 133 | 134 | 135 | # Default configuration instance 136 | default_config = StagehandConfig() 137 | -------------------------------------------------------------------------------- /stagehand/llm/client.py: -------------------------------------------------------------------------------- 1 | """LLM client for model interactions.""" 2 | 3 | from typing import TYPE_CHECKING, Any, Callable, Optional 4 | 5 | import litellm 6 | 7 | from stagehand.metrics import get_inference_time_ms, start_inference_timer 8 | 9 | if TYPE_CHECKING: 10 | from ..logging import StagehandLogger 11 | 12 | 13 | class LLMClient: 14 | """ 15 | Client for making LLM calls using the litellm library. 16 | Provides a simplified interface for chat completions. 17 | """ 18 | 19 | def __init__( 20 | self, 21 | stagehand_logger: "StagehandLogger", 22 | api_key: Optional[str] = None, 23 | default_model: Optional[str] = None, 24 | metrics_callback: Optional[Callable[[Any, int, Optional[str]], None]] = None, 25 | **kwargs: Any, # To catch other potential litellm global settings 26 | ): 27 | """ 28 | Initializes the LiteLLMClient. 29 | 30 | Args: 31 | stagehand_logger: StagehandLogger instance for centralized logging 32 | api_key: An API key for the default provider, if required. 33 | It's often better to set provider-specific environment variables 34 | (e.g., OPENAI_API_KEY, ANTHROPIC_API_KEY) which litellm reads automatically. 35 | Passing api_key here might set litellm.api_key globally, which may 36 | not be desired if using multiple providers. 37 | default_model: The default model to use if none is specified in chat_completion 38 | (e.g., "gpt-4o", "claude-3-opus-20240229"). 39 | metrics_callback: Optional callback to track metrics from responses 40 | **kwargs: Additional global settings for litellm (e.g., api_base). 41 | See litellm documentation for available settings. 42 | """ 43 | self.logger = stagehand_logger 44 | self.default_model = default_model 45 | self.metrics_callback = metrics_callback 46 | 47 | # Warning:Prefer environment variables for specific providers. 48 | if api_key: 49 | litellm.api_key = api_key 50 | 51 | # Apply other global settings if provided 52 | for key, value in kwargs.items(): 53 | if hasattr(litellm, key): 54 | setattr(litellm, key, value) 55 | self.logger.debug(f"Set global litellm.{key}", category="llm") 56 | # Handle common aliases or expected config names if necessary 57 | elif key == "api_base": # Example: map api_base if needed 58 | litellm.api_base = value 59 | self.logger.debug( 60 | f"Set global litellm.api_base to {value}", category="llm" 61 | ) 62 | 63 | def create_response( 64 | self, 65 | *, 66 | messages: list[dict[str, str]], 67 | model: Optional[str] = None, 68 | function_name: Optional[str] = None, 69 | **kwargs: Any, 70 | ) -> dict[str, Any]: 71 | """ 72 | Generate a chat completion response using litellm. 73 | 74 | Args: 75 | messages: A list of message dictionaries, e.g., [{"role": "user", "content": "Hello"}]. 76 | model: The specific model to use (e.g., "gpt-4o", "claude-3-opus-20240229"). 77 | Overrides the default_model if provided. 78 | function_name: The name of the Stagehand function calling this method (ACT, OBSERVE, etc.) 79 | Used for metrics tracking. 80 | **kwargs: Additional parameters to pass directly to litellm.completion 81 | (e.g., temperature, max_tokens, stream=True, specific provider arguments). 82 | 83 | Returns: 84 | A dictionary containing the completion response from litellm, typically 85 | including choices, usage statistics, etc. Structure depends on the model 86 | provider and whether streaming is used. 87 | 88 | Raises: 89 | ValueError: If no model is specified (neither default nor in the call). 90 | Exception: Propagates exceptions from litellm.completion. 91 | """ 92 | completion_model = model or self.default_model 93 | if not completion_model: 94 | raise ValueError( 95 | "No model specified for chat completion (neither default_model nor model argument)." 96 | ) 97 | 98 | # Standardize gemini provider to google 99 | if completion_model.startswith("google/"): 100 | completion_model = completion_model.replace("google/", "gemini/") 101 | 102 | # Prepare arguments directly from kwargs 103 | params = { 104 | "model": completion_model, 105 | "messages": messages, 106 | **kwargs, # Pass through any extra arguments 107 | } 108 | # Filter out None values only for keys explicitly present in kwargs to avoid sending nulls 109 | # unless they were intentionally provided as None. 110 | filtered_params = { 111 | k: v for k, v in params.items() if v is not None or k in kwargs 112 | } 113 | 114 | self.logger.debug( 115 | f"Calling litellm.completion with model={completion_model} and params: {filtered_params}", 116 | category="llm", 117 | ) 118 | 119 | try: 120 | # Start tracking inference time 121 | start_time = start_inference_timer() 122 | 123 | # Use litellm's completion function 124 | response = litellm.completion(**filtered_params) 125 | 126 | # Calculate inference time 127 | inference_time_ms = get_inference_time_ms(start_time) 128 | 129 | # Update metrics if callback is provided 130 | if self.metrics_callback: 131 | self.metrics_callback(response, inference_time_ms, function_name) 132 | 133 | return response 134 | 135 | except Exception as e: 136 | self.logger.error(f"Error calling litellm.completion: {e}", category="llm") 137 | # Consider more specific exception handling based on litellm errors 138 | raise 139 | -------------------------------------------------------------------------------- /examples/cache_manager_tool.py: -------------------------------------------------------------------------------- 1 | """ 2 | Stagehand 缓存管理工具 3 | 提供缓存查看、清理、导入导出等功能 4 | """ 5 | 6 | import argparse 7 | import json 8 | import os 9 | from datetime import datetime 10 | 11 | from stagehand.cache import StagehandCache 12 | 13 | 14 | def display_cache_stats(cache_manager: StagehandCache): 15 | """显示缓存统计信息""" 16 | stats = cache_manager.get_cache_stats() 17 | 18 | print("📊 缓存统计信息:") 19 | print("=" * 50) 20 | print(f"总缓存数量: {stats['total_caches']}") 21 | print(f"总命中次数: {stats['total_hits']}") 22 | print(f"内存缓存大小: {stats['memory_cache_size']}") 23 | print(f"缓存文件: {stats['cache_file']}") 24 | print(f"缓存版本: {stats['version']}") 25 | print("=" * 50) 26 | 27 | 28 | def display_cache_details(cache_manager: StagehandCache): 29 | """显示缓存详细信息""" 30 | caches = cache_manager.cache_data.get("caches", {}) 31 | 32 | if not caches: 33 | print("📭 暂无缓存记录") 34 | return 35 | 36 | print(f"📋 缓存详细信息 (共 {len(caches)} 条):") 37 | print("=" * 80) 38 | 39 | for i, (cache_key, cache_item) in enumerate(caches.items(), 1): 40 | print(f"\n[{i}] 缓存记录:") 41 | print(f" 🔑 Key: {cache_key[:16]}...") 42 | print(f" 📝 指令: {cache_item.get('instruction', 'N/A')[:60]}...") 43 | print(f" 🌐 页面: {cache_item.get('page_url', 'N/A')}") 44 | print(f" 🎯 XPath: {cache_item.get('result', {}).get('selector', 'N/A')}") 45 | print(f" 📅 创建时间: {cache_item.get('created_at', 'N/A')}") 46 | print(f" 🔥 命中次数: {cache_item.get('hit_count', 0)}") 47 | print(f" ⏰ 最后使用: {cache_item.get('last_used', 'N/A')}") 48 | 49 | 50 | def clear_cache(cache_manager: StagehandCache, expired_only: bool = False): 51 | """清理缓存""" 52 | if expired_only: 53 | cleared = cache_manager.clear_cache(expired_only=True, ttl=3600) 54 | print(f"🧹 已清理 {cleared} 条过期缓存") 55 | else: 56 | cleared = cache_manager.clear_cache(expired_only=False) 57 | print(f"🧹 已清理所有缓存 ({cleared} 条)") 58 | 59 | 60 | def export_cache(cache_manager: StagehandCache, export_file: str): 61 | """导出缓存到文件""" 62 | try: 63 | with open(export_file, "w", encoding="utf-8") as f: 64 | json.dump(cache_manager.cache_data, f, ensure_ascii=False, indent=2) 65 | print(f"📤 缓存已导出到: {export_file}") 66 | except Exception as e: 67 | print(f"❌ 导出失败: {e}") 68 | 69 | 70 | def import_cache(cache_manager: StagehandCache, import_file: str): 71 | """从文件导入缓存""" 72 | try: 73 | if not os.path.exists(import_file): 74 | print(f"❌ 文件不存在: {import_file}") 75 | return 76 | 77 | with open(import_file, "r", encoding="utf-8") as f: 78 | imported_data = json.load(f) 79 | 80 | # 合并缓存数据 81 | current_caches = cache_manager.cache_data.get("caches", {}) 82 | imported_caches = imported_data.get("caches", {}) 83 | 84 | merged_count = 0 85 | for key, value in imported_caches.items(): 86 | if key not in current_caches: 87 | current_caches[key] = value 88 | merged_count += 1 89 | 90 | cache_manager._save_cache() 91 | print(f"📥 已导入 {merged_count} 条新缓存记录") 92 | 93 | except Exception as e: 94 | print(f"❌ 导入失败: {e}") 95 | 96 | 97 | def search_cache(cache_manager: StagehandCache, keyword: str): 98 | """搜索缓存""" 99 | caches = cache_manager.cache_data.get("caches", {}) 100 | matched_caches = [] 101 | 102 | for cache_key, cache_item in caches.items(): 103 | instruction = cache_item.get("instruction", "").lower() 104 | page_url = cache_item.get("page_url", "").lower() 105 | description = cache_item.get("result", {}).get("description", "").lower() 106 | 107 | if ( 108 | keyword.lower() in instruction 109 | or keyword.lower() in page_url 110 | or keyword.lower() in description 111 | ): 112 | matched_caches.append((cache_key, cache_item)) 113 | 114 | if not matched_caches: 115 | print(f"🔍 未找到包含 '{keyword}' 的缓存记录") 116 | return 117 | 118 | print(f"🔍 找到 {len(matched_caches)} 条匹配记录:") 119 | print("=" * 60) 120 | 121 | for i, (cache_key, cache_item) in enumerate(matched_caches, 1): 122 | print(f"\n[{i}] 匹配记录:") 123 | print(f" 📝 指令: {cache_item.get('instruction', 'N/A')}") 124 | print(f" 🌐 页面: {cache_item.get('page_url', 'N/A')}") 125 | print(f" 🎯 描述: {cache_item.get('result', {}).get('description', 'N/A')}") 126 | print(f" 🔥 命中次数: {cache_item.get('hit_count', 0)}") 127 | 128 | 129 | def main(): 130 | parser = argparse.ArgumentParser(description="Stagehand 缓存管理工具") 131 | parser.add_argument( 132 | "--cache-file", default="stagehand_cache.json", help="缓存文件路径" 133 | ) 134 | 135 | subparsers = parser.add_subparsers(dest="command", help="可用命令") 136 | 137 | # 统计信息 138 | subparsers.add_parser("stats", help="显示缓存统计信息") 139 | 140 | # 详细信息 141 | subparsers.add_parser("list", help="显示缓存详细信息") 142 | 143 | # 清理缓存 144 | clear_parser = subparsers.add_parser("clear", help="清理缓存") 145 | clear_parser.add_argument( 146 | "--expired-only", action="store_true", help="只清理过期缓存" 147 | ) 148 | 149 | # 导出缓存 150 | export_parser = subparsers.add_parser("export", help="导出缓存") 151 | export_parser.add_argument("file", help="导出文件路径") 152 | 153 | # 导入缓存 154 | import_parser = subparsers.add_parser("import", help="导入缓存") 155 | import_parser.add_argument("file", help="导入文件路径") 156 | 157 | # 搜索缓存 158 | search_parser = subparsers.add_parser("search", help="搜索缓存") 159 | search_parser.add_argument("keyword", help="搜索关键词") 160 | 161 | args = parser.parse_args() 162 | 163 | # 创建缓存管理器 164 | cache_manager = StagehandCache(cache_file=args.cache_file) 165 | 166 | if args.command == "stats": 167 | display_cache_stats(cache_manager) 168 | elif args.command == "list": 169 | display_cache_details(cache_manager) 170 | elif args.command == "clear": 171 | clear_cache(cache_manager, args.expired_only) 172 | elif args.command == "export": 173 | export_cache(cache_manager, args.file) 174 | elif args.command == "import": 175 | import_cache(cache_manager, args.file) 176 | elif args.command == "search": 177 | search_cache(cache_manager, args.keyword) 178 | else: 179 | parser.print_help() 180 | 181 | 182 | if __name__ == "__main__": 183 | main() 184 | -------------------------------------------------------------------------------- /stagehand/types/page.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Union 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | 6 | # Ignore linting error for this class name since it's used as a constant 7 | # ruff: noqa: N801 8 | class DefaultExtractSchema(BaseModel): 9 | extraction: str 10 | 11 | 12 | class EmptyExtractSchema(BaseModel): 13 | page_text: str 14 | 15 | 16 | class ObserveElementSchema(BaseModel): 17 | element_id: int 18 | description: str = Field( 19 | ..., description="A description of the observed element and its purpose." 20 | ) 21 | method: str 22 | arguments: list[str] 23 | 24 | 25 | class ObserveInferenceSchema(BaseModel): 26 | elements: list[ObserveElementSchema] 27 | 28 | 29 | class MetadataSchema(BaseModel): 30 | completed: bool 31 | progress: str 32 | 33 | 34 | class ActOptions(BaseModel): 35 | """ 36 | Options for the 'act' command. 37 | 38 | Attributes: 39 | action (str): The action command to be executed by the AI. 40 | variables (Optional[dict[str, str]]): Key-value pairs for variable substitution. 41 | model_name (Optional[str]): The model to use for processing. 42 | dom_settle_timeout_ms (Optional[int]): Additional time for DOM to settle after an action. 43 | timeout_ms (Optional[int]): Timeout for the action in milliseconds. 44 | """ 45 | 46 | action: str = Field(..., description="The action command to be executed by the AI.") 47 | variables: Optional[dict[str, str]] = None 48 | model_name: Optional[str] = None 49 | dom_settle_timeout_ms: Optional[int] = None 50 | timeout_ms: Optional[int] = None 51 | model_client_options: Optional[dict[str, Any]] = None 52 | 53 | 54 | class ActResult(BaseModel): 55 | """ 56 | Result of the 'act' command. 57 | 58 | Attributes: 59 | success (bool): Whether the action was successful. 60 | message (str): Message from the AI about the action. 61 | action (str): The action command that was executed. 62 | """ 63 | 64 | success: bool = Field(..., description="Whether the action was successful.") 65 | message: str = Field(..., description="Message from the AI about the action.") 66 | action: str = Field(description="The action command that was executed.") 67 | 68 | 69 | class ObserveOptions(BaseModel): 70 | """ 71 | Options for the 'observe' command. 72 | 73 | Attributes: 74 | instruction (str): Instruction detailing what the AI should observe. 75 | model_name (Optional[AvailableModel]): The model to use for processing. 76 | draw_overlay (Optional[bool]): Whether to draw an overlay on observed elements. 77 | dom_settle_timeout_ms (Optional[int]): Additional time for DOM to settle before observation. 78 | """ 79 | 80 | instruction: str = Field( 81 | ..., description="Instruction detailing what the AI should observe." 82 | ) 83 | model_name: Optional[str] = None 84 | draw_overlay: Optional[bool] = None 85 | dom_settle_timeout_ms: Optional[int] = None 86 | model_client_options: Optional[dict[str, Any]] = None 87 | 88 | 89 | class ObserveResult(BaseModel): 90 | """ 91 | Result of the 'observe' command. 92 | 93 | Attributes: 94 | selector (str): The selector of the observed element. 95 | description (str): The description of the observed element. 96 | backend_node_id (Optional[int]): The backend node ID. 97 | method (Optional[str]): The method to execute. 98 | arguments (Optional[list[str]]): The arguments for the method. 99 | """ 100 | 101 | selector: str = Field(..., description="The selector of the observed element.") 102 | description: str = Field( 103 | ..., description="The description of the observed element." 104 | ) 105 | backend_node_id: Optional[int] = None 106 | method: Optional[str] = None 107 | arguments: Optional[list[str]] = None 108 | 109 | def __getitem__(self, key): 110 | """ 111 | Enable dictionary-style access to attributes. 112 | This allows usage like result["selector"] in addition to result.selector 113 | """ 114 | return getattr(self, key) 115 | 116 | 117 | class ExtractOptions(BaseModel): 118 | """ 119 | Options for the 'extract' command. 120 | 121 | Attributes: 122 | instruction (str): Instruction specifying what data to extract using AI. 123 | model_name (Optional[AvailableModel]): The model to use for processing. 124 | selector (Optional[str]): CSS selector to limit extraction to. 125 | schema_definition (Union[dict[str, Any], type[BaseModel]]): A JSON schema or Pydantic model that defines the structure of the expected data. 126 | Note: If passing a Pydantic model, invoke its .model_json_schema() method to ensure the schema is JSON serializable. 127 | use_text_extract (Optional[bool]): Whether to use text-based extraction. 128 | dom_settle_timeout_ms (Optional[int]): Additional time for DOM to settle before extraction. 129 | """ 130 | 131 | instruction: str = Field( 132 | ..., description="Instruction specifying what data to extract using AI." 133 | ) 134 | model_name: Optional[str] = None 135 | selector: Optional[str] = None 136 | # IMPORTANT: If using a Pydantic model for schema_definition, please call its .model_json_schema() method 137 | # to convert it to a JSON serializable dictionary before sending it with the extract command. 138 | schema_definition: Union[dict[str, Any], type[BaseModel]] = Field( 139 | default=DefaultExtractSchema, 140 | description="A JSON schema or Pydantic model that defines the structure of the expected data.", 141 | ) 142 | use_text_extract: Optional[bool] = None 143 | dom_settle_timeout_ms: Optional[int] = None 144 | model_client_options: Optional[dict[Any, Any]] = None 145 | 146 | 147 | class ExtractResult(BaseModel): 148 | """ 149 | Result of the 'extract' command. 150 | 151 | The 'data' field will contain the Pydantic model instance if a schema was provided 152 | and validation was successful, otherwise it may contain the raw extracted dictionary. 153 | """ 154 | 155 | data: Optional[Any] = None 156 | 157 | def __getitem__(self, key): 158 | """ 159 | Enable dictionary-style access to attributes. 160 | This allows usage like result["selector"] in addition to result.selector 161 | """ 162 | return getattr(self, key) 163 | -------------------------------------------------------------------------------- /tests/unit/core/test_live_page_proxy.py: -------------------------------------------------------------------------------- 1 | """Test the LivePageProxy functionality""" 2 | 3 | import asyncio 4 | import pytest 5 | from unittest.mock import AsyncMock, MagicMock, patch 6 | from stagehand.main import LivePageProxy, Stagehand 7 | from stagehand.page import StagehandPage 8 | 9 | 10 | @pytest.mark.asyncio 11 | async def test_live_page_proxy_basic_delegation(mock_stagehand_config): 12 | """Test that LivePageProxy delegates to the active page""" 13 | # Create a Stagehand instance 14 | stagehand = Stagehand(config=mock_stagehand_config) 15 | 16 | # Mock page 17 | mock_page = MagicMock(spec=StagehandPage) 18 | mock_page.url = "https://active.com" 19 | mock_page.title = AsyncMock(return_value="Active Page") 20 | 21 | # Set up the page 22 | stagehand._page = mock_page 23 | stagehand._initialized = True 24 | 25 | # Get the proxy 26 | proxy = stagehand.page 27 | 28 | # Test that it delegates to the page 29 | assert proxy.url == "https://active.com" 30 | title = await proxy.title() 31 | assert title == "Active Page" 32 | 33 | 34 | @pytest.mark.asyncio 35 | async def test_live_page_proxy_no_page_fallback(mock_stagehand_config): 36 | """Test that LivePageProxy raises error when no page is set""" 37 | # Create a Stagehand instance 38 | stagehand = Stagehand(config=mock_stagehand_config) 39 | 40 | # No page set 41 | stagehand._page = None 42 | stagehand._initialized = True 43 | 44 | # Get the proxy 45 | proxy = stagehand.page 46 | 47 | # Accessing attributes should raise RuntimeError 48 | with pytest.raises(RuntimeError, match="No active page available"): 49 | _ = proxy.url 50 | 51 | 52 | @pytest.mark.asyncio 53 | async def test_live_page_proxy_page_stability(mock_stagehand_config): 54 | """Test that LivePageProxy waits for page stability on async operations""" 55 | # Create a Stagehand instance 56 | stagehand = Stagehand(config=mock_stagehand_config) 57 | 58 | # Track lock acquisition 59 | lock_acquired = False 60 | lock_released = False 61 | 62 | class TestLock: 63 | async def __aenter__(self): 64 | nonlocal lock_acquired 65 | lock_acquired = True 66 | await asyncio.sleep(0.1) # Simulate some work 67 | return self 68 | 69 | async def __aexit__(self, *args): 70 | nonlocal lock_released 71 | lock_released = True 72 | 73 | stagehand._page_switch_lock = TestLock() 74 | 75 | # Mock page with async method 76 | mock_page = MagicMock(spec=StagehandPage) 77 | mock_page.click = AsyncMock(return_value=None) 78 | 79 | # Set up the page 80 | stagehand._page = mock_page 81 | stagehand._initialized = True 82 | 83 | # Get the proxy 84 | proxy = stagehand.page 85 | 86 | # Call an async method (should wait for stability) 87 | await proxy.click("button") 88 | 89 | # Verify lock was acquired and released 90 | assert lock_acquired 91 | assert lock_released 92 | mock_page.click.assert_called_once_with("button") 93 | 94 | 95 | @pytest.mark.asyncio 96 | async def test_live_page_proxy_navigation_no_stability_check(mock_stagehand_config): 97 | """Test that navigation methods don't wait for page stability""" 98 | # Create a Stagehand instance 99 | stagehand = Stagehand(config=mock_stagehand_config) 100 | 101 | # Track lock acquisition (should not happen) 102 | lock_acquired = False 103 | 104 | class TestLock: 105 | async def __aenter__(self): 106 | nonlocal lock_acquired 107 | lock_acquired = True 108 | return self 109 | 110 | async def __aexit__(self, *args): 111 | pass 112 | 113 | stagehand._page_switch_lock = TestLock() 114 | 115 | # Mock page with navigation methods 116 | mock_page = MagicMock(spec=StagehandPage) 117 | mock_page.goto = AsyncMock(return_value=None) 118 | mock_page.reload = AsyncMock(return_value=None) 119 | mock_page.go_back = AsyncMock(return_value=None) 120 | mock_page.go_forward = AsyncMock(return_value=None) 121 | 122 | # Set up the page 123 | stagehand._page = mock_page 124 | stagehand._initialized = True 125 | 126 | # Get the proxy 127 | proxy = stagehand.page 128 | 129 | # Call navigation methods (should NOT wait for stability) 130 | await proxy.goto("https://example.com") 131 | await proxy.reload() 132 | await proxy.go_back() 133 | await proxy.go_forward() 134 | 135 | # Verify lock was NOT acquired 136 | assert not lock_acquired 137 | 138 | # Verify methods were called 139 | mock_page.goto.assert_called_once_with("https://example.com") 140 | mock_page.reload.assert_called_once() 141 | mock_page.go_back.assert_called_once() 142 | mock_page.go_forward.assert_called_once() 143 | 144 | 145 | @pytest.mark.asyncio 146 | async def test_live_page_proxy_dynamic_page_switching(mock_stagehand_config): 147 | """Test that LivePageProxy dynamically switches between pages""" 148 | # Create a Stagehand instance 149 | stagehand = Stagehand(config=mock_stagehand_config) 150 | 151 | # Mock pages 152 | page1 = MagicMock(spec=StagehandPage) 153 | page1.url = "https://page1.com" 154 | 155 | page2 = MagicMock(spec=StagehandPage) 156 | page2.url = "https://page2.com" 157 | 158 | # Set up initial state 159 | stagehand._page = page1 160 | stagehand._initialized = True 161 | 162 | # Get the proxy 163 | proxy = stagehand.page 164 | 165 | # Initially points to page1 166 | assert proxy.url == "https://page1.com" 167 | 168 | # Switch page 169 | stagehand._page = page2 170 | 171 | # Now points to page2 without creating a new proxy 172 | assert proxy.url == "https://page2.com" 173 | 174 | 175 | def test_live_page_proxy_no_page_error(mock_stagehand_config): 176 | """Test that LivePageProxy raises error when no page is available""" 177 | # Create a Stagehand instance 178 | stagehand = Stagehand(config=mock_stagehand_config) 179 | 180 | # No page set 181 | stagehand._page = None 182 | stagehand._initialized = True 183 | 184 | # Get the proxy 185 | proxy = stagehand.page 186 | 187 | # Accessing attributes should raise RuntimeError 188 | with pytest.raises(RuntimeError, match="No active page available"): 189 | _ = proxy.url 190 | 191 | 192 | def test_live_page_proxy_not_initialized(mock_stagehand_config): 193 | """Test that page property returns None when not initialized""" 194 | # Create a Stagehand instance 195 | stagehand = Stagehand(config=mock_stagehand_config) 196 | stagehand._initialized = False 197 | 198 | # Should return None 199 | assert stagehand.page is None -------------------------------------------------------------------------------- /tests/unit/core/test_page.py: -------------------------------------------------------------------------------- 1 | """Test StagehandPage wrapper functionality and AI primitives""" 2 | 3 | import pytest 4 | from unittest.mock import AsyncMock, MagicMock, patch 5 | from pydantic import BaseModel 6 | 7 | from stagehand.page import StagehandPage 8 | from stagehand.schemas import ( 9 | ActOptions, 10 | ActResult, 11 | ExtractOptions, 12 | ExtractResult, 13 | ObserveOptions, 14 | ObserveResult, 15 | DEFAULT_EXTRACT_SCHEMA 16 | ) 17 | from tests.mocks.mock_browser import MockPlaywrightPage, setup_page_with_content 18 | from tests.mocks.mock_llm import MockLLMClient 19 | 20 | 21 | class TestStagehandPageInitialization: 22 | """Test StagehandPage initialization and setup""" 23 | 24 | def test_page_initialization(self, mock_playwright_page): 25 | """Test basic page initialization""" 26 | mock_client = MagicMock() 27 | mock_client.env = "LOCAL" 28 | mock_client.logger = MagicMock() 29 | 30 | page = StagehandPage(mock_playwright_page, mock_client) 31 | 32 | assert page._page == mock_playwright_page 33 | assert page._stagehand == mock_client 34 | # The fixture creates a MagicMock, not a MockPlaywrightPage 35 | assert hasattr(page._page, 'evaluate') # Check for expected method instead 36 | 37 | def test_page_attribute_forwarding(self, mock_playwright_page): 38 | """Test that page attributes are forwarded to underlying Playwright page""" 39 | mock_client = MagicMock() 40 | mock_client.env = "LOCAL" 41 | mock_client.logger = MagicMock() 42 | 43 | # Ensure keyboard is a regular MagicMock, not AsyncMock 44 | mock_playwright_page.keyboard = MagicMock() 45 | mock_playwright_page.keyboard.press = MagicMock(return_value=None) 46 | 47 | page = StagehandPage(mock_playwright_page, mock_client) 48 | 49 | # Should forward attribute access to underlying page 50 | assert page.url == mock_playwright_page.url 51 | 52 | # Should forward method calls 53 | page.keyboard.press("Enter") 54 | mock_playwright_page.keyboard.press.assert_called_with("Enter") 55 | 56 | 57 | class TestPageNavigation: 58 | """Test page navigation functionality""" 59 | 60 | @pytest.mark.asyncio 61 | async def test_goto_local_mode(self, mock_stagehand_page): 62 | """Test navigation in LOCAL mode""" 63 | mock_stagehand_page._stagehand.env = "LOCAL" 64 | 65 | await mock_stagehand_page.goto("https://example.com") 66 | 67 | # Should call Playwright's goto directly 68 | mock_stagehand_page._page.goto.assert_called_with( 69 | "https://example.com", 70 | referer=None, 71 | timeout=None, 72 | wait_until=None 73 | ) 74 | 75 | @pytest.mark.asyncio 76 | async def test_goto_browserbase_mode(self, mock_stagehand_page): 77 | """Test navigation in BROWSERBASE mode""" 78 | mock_stagehand_page._stagehand.env = "BROWSERBASE" 79 | mock_stagehand_page._stagehand.use_api = True 80 | mock_stagehand_page._stagehand._execute = AsyncMock(return_value={"success": True}) 81 | 82 | lock = AsyncMock() 83 | mock_stagehand_page._stagehand._get_lock_for_session.return_value = lock 84 | 85 | await mock_stagehand_page.goto("https://example.com") 86 | 87 | # Should call server execute method 88 | mock_stagehand_page._stagehand._execute.assert_called_with( 89 | "navigate", 90 | {"url": "https://example.com"} 91 | ) 92 | 93 | 94 | class TestActFunctionality: 95 | """Test the act() method for AI-powered actions""" 96 | 97 | @pytest.mark.asyncio 98 | async def test_act_with_string_instruction_local(self, mock_stagehand_page): 99 | """Test act() with string instruction in LOCAL mode""" 100 | mock_stagehand_page._stagehand.env = "LOCAL" 101 | 102 | # Mock the act handler 103 | mock_act_handler = MagicMock() 104 | mock_act_handler.act = AsyncMock(return_value=ActResult( 105 | success=True, 106 | message="Button clicked successfully", 107 | action="click on submit button" 108 | )) 109 | mock_stagehand_page._act_handler = mock_act_handler 110 | 111 | result = await mock_stagehand_page.act("click on the submit button") 112 | 113 | assert isinstance(result, ActResult) 114 | assert result.success is True 115 | assert "clicked" in result.message 116 | mock_act_handler.act.assert_called_once() 117 | 118 | 119 | class TestObserveFunctionality: 120 | """Test the observe() method for AI-powered element observation""" 121 | 122 | @pytest.mark.asyncio 123 | async def test_observe_with_string_instruction_local(self, mock_stagehand_page): 124 | """Test observe() with string instruction in LOCAL mode""" 125 | mock_stagehand_page._stagehand.env = "LOCAL" 126 | 127 | # Mock the observe handler 128 | mock_observe_handler = MagicMock() 129 | mock_observe_handler.observe = AsyncMock(return_value=[ 130 | ObserveResult( 131 | selector="#submit-btn", 132 | description="Submit button", 133 | backend_node_id=123, 134 | method="click", 135 | arguments=[] 136 | ) 137 | ]) 138 | mock_stagehand_page._observe_handler = mock_observe_handler 139 | 140 | result = await mock_stagehand_page.observe("find the submit button") 141 | 142 | assert isinstance(result, list) 143 | assert len(result) == 1 144 | assert isinstance(result[0], ObserveResult) 145 | assert result[0].selector == "#submit-btn" 146 | mock_observe_handler.observe.assert_called_once() 147 | 148 | 149 | class TestExtractFunctionality: 150 | """Test the extract() method for AI-powered data extraction""" 151 | 152 | @pytest.mark.asyncio 153 | async def test_extract_with_string_instruction_local(self, mock_stagehand_page): 154 | """Test extract() with string instruction in LOCAL mode""" 155 | mock_stagehand_page._stagehand.env = "LOCAL" 156 | 157 | # Mock the extract handler 158 | mock_extract_handler = MagicMock() 159 | mock_extract_result = MagicMock() 160 | mock_extract_result.data = {"title": "Sample Title", "description": "Sample description"} 161 | mock_extract_handler.extract = AsyncMock(return_value=mock_extract_result) 162 | mock_stagehand_page._extract_handler = mock_extract_handler 163 | 164 | result = await mock_stagehand_page.extract("extract the page title") 165 | 166 | assert result == {"title": "Sample Title", "description": "Sample description"} 167 | mock_extract_handler.extract.assert_called_once() 168 | -------------------------------------------------------------------------------- /tests/integration/api/test_core_api.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import pytest_asyncio 5 | from pydantic import BaseModel, Field 6 | 7 | from stagehand import Stagehand, StagehandConfig 8 | from stagehand.schemas import ExtractOptions 9 | 10 | 11 | class Article(BaseModel): 12 | """Schema for article extraction tests""" 13 | title: str = Field(..., description="The title of the article") 14 | summary: str = Field(None, description="A brief summary or description of the article") 15 | 16 | 17 | class TestStagehandAPIIntegration: 18 | """Integration tests for Stagehand Python SDK in BROWSERBASE API mode""" 19 | 20 | @pytest.fixture(scope="class") 21 | def browserbase_config(self): 22 | """Configuration for BROWSERBASE mode testing""" 23 | return StagehandConfig( 24 | env="BROWSERBASE", 25 | api_key=os.getenv("BROWSERBASE_API_KEY"), 26 | project_id=os.getenv("BROWSERBASE_PROJECT_ID"), 27 | model_name="gpt-4o", 28 | headless=False, 29 | verbose=2, 30 | model_client_options={"apiKey": os.getenv("MODEL_API_KEY") or os.getenv("OPENAI_API_KEY")}, 31 | ) 32 | 33 | @pytest_asyncio.fixture 34 | async def stagehand_api(self, browserbase_config): 35 | """Create a Stagehand instance for BROWSERBASE API testing""" 36 | if not (os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")): 37 | pytest.skip("Browserbase credentials not available") 38 | 39 | stagehand = Stagehand(config=browserbase_config) 40 | await stagehand.init() 41 | yield stagehand 42 | await stagehand.close() 43 | 44 | @pytest.mark.asyncio 45 | @pytest.mark.integration 46 | @pytest.mark.api 47 | @pytest.mark.skipif( 48 | not (os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")), 49 | reason="Browserbase credentials are not available for API integration tests", 50 | ) 51 | async def test_stagehand_api_initialization(self, stagehand_api): 52 | """Ensure that Stagehand initializes correctly against the Browserbase API.""" 53 | assert stagehand_api.session_id is not None 54 | 55 | @pytest.mark.asyncio 56 | @pytest.mark.integration 57 | @pytest.mark.api 58 | @pytest.mark.skipif( 59 | not (os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")), 60 | reason="Browserbase credentials are not available for API integration tests", 61 | ) 62 | async def test_api_observe_and_act_workflow(self, stagehand_api): 63 | """Test core observe and act workflow in API mode - replicated from local tests.""" 64 | stagehand = stagehand_api 65 | 66 | # Navigate to a form page for testing 67 | await stagehand.page.goto("https://httpbin.org/forms/post") 68 | 69 | # Test OBSERVE primitive: Find form elements 70 | form_elements = await stagehand.page.observe("Find all form input elements") 71 | 72 | # Verify observations 73 | assert form_elements is not None 74 | assert len(form_elements) > 0 75 | 76 | # Verify observation structure 77 | for obs in form_elements: 78 | assert hasattr(obs, "selector") 79 | assert obs.selector # Not empty 80 | 81 | # Test ACT primitive: Fill form fields 82 | await stagehand.page.act("Fill the customer name field with 'API Integration Test'") 83 | await stagehand.page.act("Fill the telephone field with '555-API'") 84 | await stagehand.page.act("Fill the email field with 'api@integration.test'") 85 | 86 | # Verify actions worked by observing filled fields 87 | filled_fields = await stagehand.page.observe("Find all filled form input fields") 88 | assert filled_fields is not None 89 | assert len(filled_fields) > 0 90 | 91 | # Test interaction with specific elements 92 | customer_field = await stagehand.page.observe("Find the customer name input field") 93 | assert customer_field is not None 94 | assert len(customer_field) > 0 95 | 96 | @pytest.mark.asyncio 97 | @pytest.mark.integration 98 | @pytest.mark.api 99 | @pytest.mark.skipif( 100 | not (os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")), 101 | reason="Browserbase credentials are not available for API integration tests", 102 | ) 103 | async def test_api_basic_navigation_and_observe(self, stagehand_api): 104 | """Test basic navigation and observe functionality in API mode - replicated from local tests.""" 105 | stagehand = stagehand_api 106 | 107 | # Navigate to a simple page 108 | await stagehand.page.goto("https://example.com") 109 | 110 | # Observe elements on the page 111 | observations = await stagehand.page.observe("Find all the links on the page") 112 | 113 | # Verify we got some observations 114 | assert observations is not None 115 | assert len(observations) > 0 116 | 117 | # Verify observation structure 118 | for obs in observations: 119 | assert hasattr(obs, "selector") 120 | assert obs.selector # Not empty 121 | 122 | @pytest.mark.asyncio 123 | @pytest.mark.integration 124 | @pytest.mark.api 125 | @pytest.mark.skipif( 126 | not (os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")), 127 | reason="Browserbase credentials are not available for API integration tests", 128 | ) 129 | async def test_api_extraction_functionality(self, stagehand_api): 130 | """Test extraction functionality in API mode - replicated from local tests.""" 131 | stagehand = stagehand_api 132 | 133 | # Navigate to a content-rich page 134 | await stagehand.page.goto("https://news.ycombinator.com") 135 | 136 | # Test simple text-based extraction 137 | titles_text = await stagehand.page.extract( 138 | "Extract the titles of the first 3 articles on the page as a JSON array" 139 | ) 140 | 141 | # Verify extraction worked 142 | assert titles_text is not None 143 | 144 | # Test schema-based extraction 145 | extract_options = ExtractOptions( 146 | instruction="Extract the first article's title and any available summary", 147 | schema_definition=Article 148 | ) 149 | 150 | article_data = await stagehand.page.extract(extract_options) 151 | assert article_data is not None 152 | 153 | # Validate the extracted data structure (Browserbase format) 154 | if hasattr(article_data, 'data') and article_data.data: 155 | # BROWSERBASE mode format 156 | article = Article.model_validate(article_data.data) 157 | assert article.title 158 | assert len(article.title) > 0 159 | elif hasattr(article_data, 'title'): 160 | # Fallback format 161 | article = Article.model_validate(article_data.model_dump()) 162 | assert article.title 163 | assert len(article.title) > 0 164 | 165 | # Verify API session is active 166 | assert stagehand.session_id is not None -------------------------------------------------------------------------------- /stagehand/api.py: -------------------------------------------------------------------------------- 1 | import json 2 | from datetime import datetime 3 | from importlib.metadata import PackageNotFoundError, version 4 | from typing import Any 5 | 6 | from .utils import convert_dict_keys_to_camel_case 7 | 8 | __all__ = ["_create_session", "_execute"] 9 | 10 | 11 | async def _create_session(self): 12 | """ 13 | Create a new session by calling /sessions/start on the server. 14 | Depends on browserbase_api_key, browserbase_project_id, and model_api_key. 15 | """ 16 | if not self.browserbase_api_key: 17 | raise ValueError("browserbase_api_key is required to create a session.") 18 | if not self.browserbase_project_id: 19 | raise ValueError("browserbase_project_id is required to create a session.") 20 | if not self.model_api_key: 21 | raise ValueError("model_api_key is required to create a session.") 22 | 23 | browserbase_session_create_params = ( 24 | convert_dict_keys_to_camel_case(self.browserbase_session_create_params) 25 | if self.browserbase_session_create_params 26 | else None 27 | ) 28 | 29 | payload = { 30 | "modelName": self.model_name, 31 | "verbose": 2 if self.verbose == 3 else self.verbose, 32 | "domSettleTimeoutMs": self.dom_settle_timeout_ms, 33 | "browserbaseSessionID": self.session_id, 34 | "browserbaseSessionCreateParams": ( 35 | browserbase_session_create_params 36 | if browserbase_session_create_params 37 | else { 38 | "browserSettings": { 39 | "blockAds": True, 40 | "viewport": { 41 | "width": 1024, 42 | "height": 768, 43 | }, 44 | }, 45 | } 46 | ), 47 | } 48 | 49 | # Add the new parameters if they have values 50 | if hasattr(self, "self_heal") and self.self_heal is not None: 51 | payload["selfHeal"] = self.self_heal 52 | 53 | if ( 54 | hasattr(self, "wait_for_captcha_solves") 55 | and self.wait_for_captcha_solves is not None 56 | ): 57 | payload["waitForCaptchaSolves"] = self.wait_for_captcha_solves 58 | 59 | if hasattr(self, "act_timeout_ms") and self.act_timeout_ms is not None: 60 | payload["actTimeoutMs"] = self.act_timeout_ms 61 | 62 | if hasattr(self, "system_prompt") and self.system_prompt: 63 | payload["systemPrompt"] = self.system_prompt 64 | 65 | if hasattr(self, "model_client_options") and self.model_client_options: 66 | payload["modelClientOptions"] = self.model_client_options 67 | 68 | def get_version(package_str): 69 | try: 70 | result = version(package_str) 71 | except PackageNotFoundError: 72 | self.logger.error(package_str + " not installed") 73 | result = None 74 | return result 75 | 76 | headers = { 77 | "x-bb-api-key": self.browserbase_api_key, 78 | "x-bb-project-id": self.browserbase_project_id, 79 | "x-model-api-key": self.model_api_key, 80 | "Content-Type": "application/json", 81 | "x-sent-at": datetime.now().isoformat(), 82 | "x-language": "python", 83 | "x-sdk-version": get_version("stagehand"), 84 | } 85 | 86 | # async with self._client: 87 | resp = await self._client.post( 88 | f"{self.api_url}/sessions/start", 89 | json=payload, 90 | headers=headers, 91 | ) 92 | if resp.status_code != 200: 93 | raise RuntimeError(f"Failed to create session: {resp.text}") 94 | data = resp.json() 95 | self.logger.debug(f"Session created: {data}") 96 | if not data.get("success") or "sessionId" not in data.get("data", {}): 97 | raise RuntimeError(f"Invalid response format: {resp.text}") 98 | 99 | self.session_id = data["data"]["sessionId"] 100 | 101 | 102 | async def _execute(self, method: str, payload: dict[str, Any]) -> Any: 103 | """ 104 | Internal helper to call /sessions/{session_id}/{method} with the given method and payload. 105 | Streams line-by-line, returning the 'result' from the final message (if any). 106 | """ 107 | headers = { 108 | "x-bb-api-key": self.browserbase_api_key, 109 | "x-bb-project-id": self.browserbase_project_id, 110 | "Content-Type": "application/json", 111 | "Connection": "keep-alive", 112 | "x-sent-at": datetime.now().isoformat(), 113 | # Always enable streaming for better log handling 114 | "x-stream-response": "true", 115 | } 116 | if self.model_api_key: 117 | headers["x-model-api-key"] = self.model_api_key 118 | 119 | # Convert snake_case keys to camelCase for the API 120 | modified_payload = convert_dict_keys_to_camel_case(payload) 121 | 122 | # async with self._client: 123 | try: 124 | # Always use streaming for consistent log handling 125 | async with self._client.stream( 126 | "POST", 127 | f"{self.api_url}/sessions/{self.session_id}/{method}", 128 | json=modified_payload, 129 | headers=headers, 130 | ) as response: 131 | if response.status_code != 200: 132 | error_text = await response.aread() 133 | error_message = error_text.decode("utf-8") 134 | self.logger.error( 135 | f"[HTTP ERROR] Status {response.status_code}: {error_message}" 136 | ) 137 | raise RuntimeError( 138 | f"Request failed with status {response.status_code}: {error_message}" 139 | ) 140 | result = None 141 | 142 | async for line in response.aiter_lines(): 143 | # Skip empty lines 144 | if not line.strip(): 145 | continue 146 | 147 | try: 148 | # Handle SSE-style messages that start with "data: " 149 | if line.startswith("data: "): 150 | line = line[len("data: ") :] 151 | 152 | message = json.loads(line) 153 | # Handle different message types 154 | msg_type = message.get("type") 155 | 156 | if msg_type == "system": 157 | status = message.get("data", {}).get("status") 158 | if status == "error": 159 | error_msg = message.get("data", {}).get( 160 | "error", "Unknown error" 161 | ) 162 | self.logger.error(f"[ERROR] {error_msg}") 163 | raise RuntimeError(f"Server returned error: {error_msg}") 164 | elif status == "finished": 165 | result = message.get("data", {}).get("result") 166 | elif msg_type == "log": 167 | # Process log message using _handle_log 168 | await self._handle_log(message) 169 | else: 170 | # Log any other message types 171 | self.logger.debug(f"[UNKNOWN] Message type: {msg_type}") 172 | except json.JSONDecodeError: 173 | self.logger.error(f"Could not parse line as JSON: {line}") 174 | 175 | # Return the final result 176 | return result 177 | except Exception as e: 178 | self.logger.error(f"[EXCEPTION] {str(e)}") 179 | raise 180 | -------------------------------------------------------------------------------- /tests/unit/core/test_wait_for_settled_dom.py: -------------------------------------------------------------------------------- 1 | """Test the CDP-based _wait_for_settled_dom implementation""" 2 | 3 | import asyncio 4 | import pytest 5 | from unittest.mock import AsyncMock, MagicMock, call 6 | from stagehand.page import StagehandPage 7 | 8 | 9 | @pytest.mark.asyncio 10 | async def test_wait_for_settled_dom_basic(mock_stagehand_client, mock_playwright_page): 11 | """Test basic functionality of _wait_for_settled_dom""" 12 | # Create a StagehandPage instance 13 | page = StagehandPage(mock_playwright_page, mock_stagehand_client) 14 | 15 | # Mock CDP client 16 | mock_cdp_client = MagicMock() 17 | mock_cdp_client.send = AsyncMock() 18 | mock_cdp_client.on = MagicMock() 19 | mock_cdp_client.remove_listener = MagicMock() 20 | 21 | # Mock get_cdp_client to return our mock 22 | page.get_cdp_client = AsyncMock(return_value=mock_cdp_client) 23 | 24 | # Mock page title to simulate document exists 25 | mock_playwright_page.title = AsyncMock(return_value="Test Page") 26 | 27 | # Create a task that will call _wait_for_settled_dom 28 | async def run_wait(): 29 | await page._wait_for_settled_dom(timeout_ms=1000) 30 | 31 | # Start the wait task 32 | wait_task = asyncio.create_task(run_wait()) 33 | 34 | # Give it a moment to set up event handlers 35 | await asyncio.sleep(0.1) 36 | 37 | # Verify CDP domains were enabled 38 | assert mock_cdp_client.send.call_count >= 3 39 | mock_cdp_client.send.assert_any_call("Network.enable") 40 | mock_cdp_client.send.assert_any_call("Page.enable") 41 | mock_cdp_client.send.assert_any_call("Target.setAutoAttach", { 42 | "autoAttach": True, 43 | "waitForDebuggerOnStart": False, 44 | "flatten": True, 45 | "filter": [ 46 | {"type": "worker", "exclude": True}, 47 | {"type": "shared_worker", "exclude": True}, 48 | ], 49 | }) 50 | 51 | # Verify event handlers were registered 52 | assert mock_cdp_client.on.call_count >= 6 53 | event_names = [call[0][0] for call in mock_cdp_client.on.call_args_list] 54 | assert "Network.requestWillBeSent" in event_names 55 | assert "Network.loadingFinished" in event_names 56 | assert "Network.loadingFailed" in event_names 57 | assert "Network.requestServedFromCache" in event_names 58 | assert "Network.responseReceived" in event_names 59 | assert "Page.frameStoppedLoading" in event_names 60 | 61 | # Cancel the task (it would timeout otherwise) 62 | wait_task.cancel() 63 | try: 64 | await wait_task 65 | except asyncio.CancelledError: 66 | pass 67 | 68 | # Verify event handlers were unregistered 69 | assert mock_cdp_client.remove_listener.call_count >= 6 70 | 71 | 72 | @pytest.mark.asyncio 73 | async def test_wait_for_settled_dom_with_requests(mock_stagehand_client, mock_playwright_page): 74 | """Test _wait_for_settled_dom with network requests""" 75 | # Create a StagehandPage instance 76 | page = StagehandPage(mock_playwright_page, mock_stagehand_client) 77 | 78 | # Mock CDP client 79 | mock_cdp_client = MagicMock() 80 | mock_cdp_client.send = AsyncMock() 81 | 82 | # Store event handlers 83 | event_handlers = {} 84 | 85 | def mock_on(event_name, handler): 86 | event_handlers[event_name] = handler 87 | 88 | def mock_remove_listener(event_name, handler): 89 | if event_name in event_handlers: 90 | del event_handlers[event_name] 91 | 92 | mock_cdp_client.on = mock_on 93 | mock_cdp_client.remove_listener = mock_remove_listener 94 | 95 | # Mock get_cdp_client to return our mock 96 | page.get_cdp_client = AsyncMock(return_value=mock_cdp_client) 97 | 98 | # Mock page title to simulate document exists 99 | mock_playwright_page.title = AsyncMock(return_value="Test Page") 100 | 101 | # Create a task that will call _wait_for_settled_dom 102 | async def run_wait(): 103 | await page._wait_for_settled_dom(timeout_ms=5000) 104 | 105 | # Start the wait task 106 | wait_task = asyncio.create_task(run_wait()) 107 | 108 | # Give it a moment to set up event handlers 109 | await asyncio.sleep(0.1) 110 | 111 | # Simulate a network request 112 | if "Network.requestWillBeSent" in event_handlers: 113 | event_handlers["Network.requestWillBeSent"]({ 114 | "requestId": "req1", 115 | "type": "Document", 116 | "frameId": "frame1", 117 | "request": {"url": "https://example.com"} 118 | }) 119 | 120 | # Give it a moment 121 | await asyncio.sleep(0.1) 122 | 123 | # The task should still be running (request in flight) 124 | assert not wait_task.done() 125 | 126 | # Finish the request 127 | if "Network.loadingFinished" in event_handlers: 128 | event_handlers["Network.loadingFinished"]({"requestId": "req1"}) 129 | 130 | # Wait for the quiet period (0.5s) plus a bit 131 | await asyncio.sleep(0.6) 132 | 133 | # The task should now be complete 134 | assert wait_task.done() 135 | await wait_task # Should complete without error 136 | 137 | 138 | @pytest.mark.asyncio 139 | async def test_wait_for_settled_dom_timeout(mock_stagehand_client, mock_playwright_page): 140 | """Test _wait_for_settled_dom timeout behavior""" 141 | # Create a StagehandPage instance 142 | page = StagehandPage(mock_playwright_page, mock_stagehand_client) 143 | 144 | # Mock CDP client 145 | mock_cdp_client = MagicMock() 146 | mock_cdp_client.send = AsyncMock() 147 | mock_cdp_client.on = MagicMock() 148 | mock_cdp_client.remove_listener = MagicMock() 149 | 150 | # Mock get_cdp_client to return our mock 151 | page.get_cdp_client = AsyncMock(return_value=mock_cdp_client) 152 | 153 | # Mock page title to simulate document exists 154 | mock_playwright_page.title = AsyncMock(return_value="Test Page") 155 | 156 | # Set a very short timeout 157 | mock_stagehand_client.dom_settle_timeout_ms = 100 158 | 159 | # Run wait with timeout 160 | await page._wait_for_settled_dom() 161 | 162 | # Should complete without error due to timeout 163 | assert True # If we get here, the timeout worked 164 | 165 | 166 | @pytest.mark.asyncio 167 | async def test_wait_for_settled_dom_no_document(mock_stagehand_client, mock_playwright_page): 168 | """Test _wait_for_settled_dom when document doesn't exist initially""" 169 | # Create a StagehandPage instance 170 | page = StagehandPage(mock_playwright_page, mock_stagehand_client) 171 | 172 | # Mock CDP client 173 | mock_cdp_client = MagicMock() 174 | mock_cdp_client.send = AsyncMock() 175 | mock_cdp_client.on = MagicMock() 176 | mock_cdp_client.remove_listener = MagicMock() 177 | 178 | # Mock get_cdp_client to return our mock 179 | page.get_cdp_client = AsyncMock(return_value=mock_cdp_client) 180 | 181 | # Mock page title to throw exception (no document) 182 | mock_playwright_page.title = AsyncMock(side_effect=Exception("No document")) 183 | mock_playwright_page.wait_for_load_state = AsyncMock() 184 | 185 | # Set a short timeout 186 | mock_stagehand_client.dom_settle_timeout_ms = 500 187 | 188 | # Run wait 189 | await page._wait_for_settled_dom() 190 | 191 | # Should have waited for domcontentloaded 192 | mock_playwright_page.wait_for_load_state.assert_called_once_with("domcontentloaded") -------------------------------------------------------------------------------- /tests/regression/test_wichita.py: -------------------------------------------------------------------------------- 1 | """ 2 | Regression test for wichita functionality. 3 | 4 | This test verifies that combination actions (act + extract) work correctly, 5 | based on the TypeScript wichita evaluation. 6 | """ 7 | 8 | import os 9 | import pytest 10 | import pytest_asyncio 11 | from pydantic import BaseModel, Field, ConfigDict 12 | 13 | from stagehand import Stagehand, StagehandConfig 14 | from stagehand.schemas import ExtractOptions, StagehandBaseModel 15 | 16 | 17 | class BidResults(StagehandBaseModel): 18 | """Schema for bid results extraction""" 19 | total_results: str = Field(..., description="The total number of bids that the search produced", alias="totalResults") 20 | 21 | model_config = ConfigDict(populate_by_name=True) # Allow both total_results and totalResults 22 | 23 | 24 | class TestWichita: 25 | """Regression test for wichita functionality""" 26 | 27 | @pytest.fixture(scope="class") 28 | def local_config(self): 29 | """Configuration for LOCAL mode testing""" 30 | return StagehandConfig( 31 | env="LOCAL", 32 | model_name="gpt-4o-mini", 33 | headless=True, 34 | verbose=1, 35 | dom_settle_timeout_ms=2000, 36 | model_client_options={"apiKey": os.getenv("MODEL_API_KEY") or os.getenv("OPENAI_API_KEY")}, 37 | ) 38 | 39 | @pytest.fixture(scope="class") 40 | def browserbase_config(self): 41 | """Configuration for BROWSERBASE mode testing""" 42 | return StagehandConfig( 43 | env="BROWSERBASE", 44 | api_key=os.getenv("BROWSERBASE_API_KEY"), 45 | project_id=os.getenv("BROWSERBASE_PROJECT_ID"), 46 | model_name="gpt-4o", 47 | headless=False, 48 | verbose=2, 49 | model_client_options={"apiKey": os.getenv("MODEL_API_KEY") or os.getenv("OPENAI_API_KEY")}, 50 | ) 51 | 52 | @pytest_asyncio.fixture 53 | async def local_stagehand(self, local_config): 54 | """Create a Stagehand instance for LOCAL testing""" 55 | stagehand = Stagehand(config=local_config) 56 | await stagehand.init() 57 | yield stagehand 58 | await stagehand.close() 59 | 60 | @pytest_asyncio.fixture 61 | async def browserbase_stagehand(self, browserbase_config): 62 | """Create a Stagehand instance for BROWSERBASE testing""" 63 | if not (os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")): 64 | pytest.skip("Browserbase credentials not available") 65 | 66 | stagehand = Stagehand(config=browserbase_config) 67 | await stagehand.init() 68 | yield stagehand 69 | await stagehand.close() 70 | 71 | @pytest.mark.asyncio 72 | @pytest.mark.regression 73 | @pytest.mark.local 74 | async def test_wichita_local(self, local_stagehand): 75 | """ 76 | Regression test: wichita 77 | 78 | Mirrors the TypeScript wichita evaluation: 79 | - Navigate to Wichita Falls TX government bids page 80 | - Click on "Show Closed/Awarded/Cancelled bids" 81 | - Extract the total number of bids 82 | - Verify the count is within expected range (updated range: 400-430 to accommodate recent values) 83 | """ 84 | stagehand = local_stagehand 85 | 86 | await stagehand.page.goto("https://www.wichitafallstx.gov/Bids.aspx") 87 | 88 | # Click to show closed/awarded/cancelled bids 89 | await stagehand.page.act('Click on "Show Closed/Awarded/Cancelled bids"') 90 | 91 | # Extract the total number of results using proper Python schema-based extraction 92 | extract_options = ExtractOptions( 93 | instruction="Extract the total number of bids that the search produced.", 94 | schema_definition=BidResults 95 | ) 96 | 97 | result = await stagehand.page.extract(extract_options) 98 | 99 | # Both LOCAL and BROWSERBASE modes return the Pydantic model instance directly 100 | total_results = result.total_results 101 | 102 | # Ensure we got some result 103 | assert total_results is not None, f"Failed to extract total_results from the page. Result: {result}" 104 | 105 | # Parse the number from the result with better extraction 106 | import re 107 | numbers = re.findall(r'\d+', str(total_results)) 108 | assert numbers, f"No numbers found in extracted result: {total_results}" 109 | 110 | # Get the largest number (assuming it's the total count) 111 | extracted_number = max(int(num) for num in numbers) 112 | 113 | # Updated range to accommodate recent results (417 observed consistently) 114 | # Expanding from 405 ± 10 to 400-430 to be more realistic 115 | min_expected = 400 116 | max_expected = 430 117 | 118 | # Check if the number is within the updated range 119 | is_within_range = min_expected <= extracted_number <= max_expected 120 | 121 | assert is_within_range, ( 122 | f"Total number of results {extracted_number} is not within the expected range " 123 | f"{min_expected}-{max_expected}. Raw extraction result: {total_results}" 124 | ) 125 | 126 | @pytest.mark.asyncio 127 | @pytest.mark.regression 128 | @pytest.mark.api 129 | @pytest.mark.skipif( 130 | not (os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")), 131 | reason="Browserbase credentials not available" 132 | ) 133 | async def test_wichita_browserbase(self, browserbase_stagehand): 134 | """ 135 | Regression test: wichita (Browserbase) 136 | 137 | Same test as local but running in Browserbase environment. 138 | """ 139 | stagehand = browserbase_stagehand 140 | 141 | await stagehand.page.goto("https://www.wichitafallstx.gov/Bids.aspx") 142 | 143 | # Click to show closed/awarded/cancelled bids 144 | await stagehand.page.act('Click on "Show Closed/Awarded/Cancelled bids"') 145 | 146 | # Extract the total number of results using proper Python schema-based extraction 147 | extract_options = ExtractOptions( 148 | instruction="Extract the total number of bids that the search produced.", 149 | schema_definition=BidResults 150 | ) 151 | 152 | result = await stagehand.page.extract(extract_options) 153 | 154 | # Both LOCAL and BROWSERBASE modes return the Pydantic model instance directly 155 | total_results = result.total_results 156 | 157 | # Ensure we got some result 158 | assert total_results is not None, f"Failed to extract total_results from the page. Result: {result}" 159 | 160 | # Parse the number from the result with better extraction 161 | import re 162 | numbers = re.findall(r'\d+', str(total_results)) 163 | assert numbers, f"No numbers found in extracted result: {total_results}" 164 | 165 | # Get the largest number (assuming it's the total count) 166 | extracted_number = max(int(num) for num in numbers) 167 | 168 | # Updated range to accommodate recent results (417 observed consistently) 169 | # Expanding from 405 ± 10 to 400-430 to be more realistic 170 | min_expected = 400 171 | max_expected = 430 172 | 173 | # Check if the number is within the updated range 174 | is_within_range = min_expected <= extracted_number <= max_expected 175 | 176 | assert is_within_range, ( 177 | f"Total number of results {extracted_number} is not within the expected range " 178 | f"{min_expected}-{max_expected}. Raw extraction result: {total_results}" 179 | ) -------------------------------------------------------------------------------- /stagehand/handlers/extract_handler.py: -------------------------------------------------------------------------------- 1 | """Extract handler for performing data extraction from page elements using LLMs.""" 2 | 3 | from typing import Optional, TypeVar 4 | 5 | from pydantic import BaseModel 6 | 7 | from stagehand.a11y.utils import get_accessibility_tree 8 | from stagehand.llm.inference import extract as extract_inference 9 | from stagehand.metrics import StagehandFunctionName # Changed import location 10 | from stagehand.types import ( 11 | DefaultExtractSchema, 12 | EmptyExtractSchema, 13 | ExtractOptions, 14 | ExtractResult, 15 | ) 16 | from stagehand.utils import inject_urls, transform_url_strings_to_ids 17 | 18 | T = TypeVar("T", bound=BaseModel) 19 | 20 | 21 | class ExtractHandler: 22 | """Handler for processing extract operations locally.""" 23 | 24 | def __init__( 25 | self, stagehand_page, stagehand_client, user_provided_instructions=None 26 | ): 27 | """ 28 | Initialize the ExtractHandler. 29 | 30 | Args: 31 | stagehand_page: StagehandPage instance 32 | stagehand_client: Stagehand client instance 33 | user_provided_instructions: Optional custom system instructions 34 | """ 35 | self.stagehand_page = stagehand_page 36 | self.stagehand = stagehand_client 37 | self.logger = stagehand_client.logger 38 | self.user_provided_instructions = user_provided_instructions 39 | 40 | async def extract( 41 | self, 42 | options: Optional[ExtractOptions] = None, 43 | schema: Optional[type[BaseModel]] = None, 44 | ) -> ExtractResult: 45 | """ 46 | Execute an extraction operation locally. 47 | 48 | Args: 49 | options: ExtractOptions containing the instruction and other parameters 50 | schema: Optional Pydantic model for structured output 51 | 52 | Returns: 53 | ExtractResult instance 54 | """ 55 | if not options: 56 | # If no options provided, extract the entire page text 57 | self.logger.info("Extracting entire page text") 58 | return await self._extract_page_text() 59 | 60 | instruction = options.instruction 61 | # TODO add targeted extract 62 | # selector = options.selector 63 | 64 | # TODO: add schema to log 65 | self.logger.debug( 66 | "extract", 67 | category="extract", 68 | auxiliary={"instruction": instruction}, 69 | ) 70 | self.logger.info( 71 | f"Starting extraction with instruction: '{instruction}'", category="extract" 72 | ) 73 | 74 | # Start inference timer if available 75 | if hasattr(self.stagehand, "start_inference_timer"): 76 | self.stagehand.start_inference_timer() 77 | 78 | # Wait for DOM to settle 79 | await self.stagehand_page._wait_for_settled_dom() 80 | 81 | # TODO add targeted extract 82 | # target_xpath = ( 83 | # selector.replace("xpath=", "") 84 | # if selector and selector.startswith("xpath=") 85 | # else "" 86 | # ) 87 | 88 | # Get accessibility tree data 89 | tree = await get_accessibility_tree(self.stagehand_page, self.logger) 90 | self.logger.info("Getting accessibility tree data") 91 | output_string = tree["simplified"] 92 | id_to_url_mapping = tree.get("idToUrl", {}) 93 | 94 | # Transform schema URL fields to numeric IDs if necessary 95 | transformed_schema = schema 96 | url_paths = [] 97 | if schema: 98 | # TODO: Remove this once we have a better way to handle URLs 99 | transformed_schema, url_paths = transform_url_strings_to_ids(schema) 100 | else: 101 | schema = transformed_schema = DefaultExtractSchema 102 | 103 | # Use inference to call the LLM 104 | extraction_response = extract_inference( 105 | instruction=instruction, 106 | tree_elements=output_string, 107 | schema=transformed_schema, 108 | llm_client=self.stagehand.llm, 109 | user_provided_instructions=self.user_provided_instructions, 110 | logger=self.logger, 111 | log_inference_to_file=False, # TODO: Implement logging to file if needed 112 | ) 113 | 114 | # Extract metrics from response and update them directly 115 | prompt_tokens = extraction_response.get("prompt_tokens", 0) 116 | completion_tokens = extraction_response.get("completion_tokens", 0) 117 | inference_time_ms = extraction_response.get("inference_time_ms", 0) 118 | 119 | # Update metrics directly using the Stagehand client 120 | self.stagehand.update_metrics( 121 | StagehandFunctionName.EXTRACT, 122 | prompt_tokens, 123 | completion_tokens, 124 | inference_time_ms, 125 | ) 126 | 127 | # Process extraction response 128 | raw_data_dict = extraction_response.get("data", {}) 129 | metadata = extraction_response.get("metadata", {}) 130 | 131 | self.logger.info( 132 | f"DEBUG: extraction_response keys: {list(extraction_response.keys())}" 133 | ) 134 | self.logger.info(f"DEBUG: raw_data_dict: {raw_data_dict}") 135 | self.logger.info(f"DEBUG: metadata: {metadata}") 136 | self.logger.info(f"DEBUG: schema: {schema}") 137 | 138 | # Inject URLs back into result if necessary 139 | if url_paths: 140 | inject_urls( 141 | raw_data_dict, url_paths, id_to_url_mapping 142 | ) # Modifies raw_data_dict in place 143 | 144 | if metadata.get("completed"): 145 | self.logger.debug( 146 | "Extraction completed successfully", 147 | auxiliary={"result": raw_data_dict}, 148 | ) 149 | else: 150 | self.logger.debug( 151 | "Extraction incomplete after processing all data", 152 | auxiliary={"result": raw_data_dict}, 153 | ) 154 | 155 | processed_data_payload = raw_data_dict # Default to the raw dictionary 156 | 157 | if schema and isinstance( 158 | raw_data_dict, dict 159 | ): # schema is the Pydantic model type 160 | try: 161 | self.logger.info( 162 | f"DEBUG: Attempting to validate raw_data_dict: {raw_data_dict}" 163 | ) 164 | self.logger.info(f"DEBUG: Against schema: {schema}") 165 | validated_model_instance = schema.model_validate(raw_data_dict) 166 | self.logger.info( 167 | f"DEBUG: Validation successful: {validated_model_instance}" 168 | ) 169 | processed_data_payload = validated_model_instance # Payload is now the Pydantic model instance 170 | except Exception as e: 171 | self.logger.error( 172 | f"Failed to validate extracted data against schema {schema.__name__}: {e}. Keeping raw data dict in .data field." 173 | ) 174 | self.logger.info( 175 | f"DEBUG: Validation failed, raw_data_dict: {raw_data_dict}" 176 | ) 177 | 178 | # Create ExtractResult object 179 | result = ExtractResult( 180 | data=processed_data_payload, 181 | ) 182 | 183 | return result 184 | 185 | async def _extract_page_text(self) -> ExtractResult: 186 | """Extract just the text content from the page.""" 187 | await self.stagehand_page._wait_for_settled_dom() 188 | 189 | tree = await get_accessibility_tree(self.stagehand_page, self.logger) 190 | output_string = tree["simplified"] 191 | output_dict = {"page_text": output_string} 192 | validated_model = EmptyExtractSchema.model_validate(output_dict) 193 | return ExtractResult(data=validated_model).data 194 | -------------------------------------------------------------------------------- /tests/unit/handlers/test_extract_handler.py: -------------------------------------------------------------------------------- 1 | """Test ExtractHandler functionality for AI-powered data extraction""" 2 | 3 | import pytest 4 | from unittest.mock import AsyncMock, MagicMock, patch 5 | from pydantic import BaseModel 6 | 7 | from stagehand.handlers.extract_handler import ExtractHandler 8 | from stagehand.types import ExtractOptions, ExtractResult, DefaultExtractSchema 9 | from tests.mocks.mock_llm import MockLLMClient, MockLLMResponse 10 | 11 | 12 | class TestExtractHandlerInitialization: 13 | """Test ExtractHandler initialization and setup""" 14 | 15 | def test_extract_handler_creation(self, mock_stagehand_page): 16 | """Test basic ExtractHandler creation""" 17 | mock_client = MagicMock() 18 | mock_client.llm = MockLLMClient() 19 | 20 | handler = ExtractHandler( 21 | mock_stagehand_page, 22 | mock_client, 23 | user_provided_instructions="Test extraction instructions" 24 | ) 25 | 26 | assert handler.stagehand_page == mock_stagehand_page 27 | assert handler.stagehand == mock_client 28 | assert handler.user_provided_instructions == "Test extraction instructions" 29 | 30 | 31 | class TestExtractExecution: 32 | """Test data extraction functionality""" 33 | 34 | @pytest.mark.asyncio 35 | async def test_extract_with_default_schema(self, mock_stagehand_page): 36 | """Test extracting data with default schema""" 37 | mock_client = MagicMock() 38 | mock_llm = MockLLMClient() 39 | mock_client.llm = mock_llm 40 | mock_client.start_inference_timer = MagicMock() 41 | mock_client.update_metrics = MagicMock() 42 | 43 | handler = ExtractHandler(mock_stagehand_page, mock_client, "") 44 | 45 | # Mock page content 46 | mock_stagehand_page._page.content = AsyncMock(return_value="Sample content") 47 | 48 | # Mock extract_inference 49 | with patch('stagehand.handlers.extract_handler.extract_inference') as mock_extract_inference: 50 | mock_extract_inference.return_value = { 51 | "data": {"extraction": "Sample extracted text from the page"}, 52 | "metadata": {"completed": True}, 53 | "prompt_tokens": 100, 54 | "completion_tokens": 50, 55 | "inference_time_ms": 1000 56 | } 57 | 58 | # Also need to mock _wait_for_settled_dom 59 | mock_stagehand_page._wait_for_settled_dom = AsyncMock() 60 | 61 | options = ExtractOptions(instruction="extract the main content") 62 | result = await handler.extract(options) 63 | 64 | assert isinstance(result, ExtractResult) 65 | # The handler should now properly populate the result with extracted data 66 | assert result.data is not None 67 | # The handler returns a validated Pydantic model instance, not a raw dict 68 | assert isinstance(result.data, DefaultExtractSchema) 69 | assert result.data.extraction == "Sample extracted text from the page" 70 | 71 | # Verify the mocks were called 72 | mock_extract_inference.assert_called_once() 73 | 74 | @pytest.mark.asyncio 75 | async def test_extract_with_no_schema_returns_default_schema(self, mock_stagehand_page): 76 | """Test extracting data with no schema returns DefaultExtractSchema instance""" 77 | mock_client = MagicMock() 78 | mock_llm = MockLLMClient() 79 | mock_client.llm = mock_llm 80 | mock_client.start_inference_timer = MagicMock() 81 | mock_client.update_metrics = MagicMock() 82 | 83 | handler = ExtractHandler(mock_stagehand_page, mock_client, "") 84 | mock_stagehand_page._page.content = AsyncMock(return_value="Sample content") 85 | 86 | # Mock extract_inference - return data compatible with DefaultExtractSchema 87 | with patch('stagehand.handlers.extract_handler.extract_inference') as mock_extract_inference: 88 | mock_extract_inference.return_value = { 89 | "data": {"extraction": "Sample extracted text from the page"}, 90 | "metadata": {"completed": True}, 91 | "prompt_tokens": 100, 92 | "completion_tokens": 50, 93 | "inference_time_ms": 1000 94 | } 95 | 96 | mock_stagehand_page._wait_for_settled_dom = AsyncMock() 97 | 98 | options = ExtractOptions(instruction="extract the main content") 99 | # No schema parameter passed - should use DefaultExtractSchema 100 | result = await handler.extract(options) 101 | 102 | assert isinstance(result, ExtractResult) 103 | assert result.data is not None 104 | # Should return DefaultExtractSchema instance 105 | assert isinstance(result.data, DefaultExtractSchema) 106 | assert result.data.extraction == "Sample extracted text from the page" 107 | 108 | # Verify the mocks were called 109 | mock_extract_inference.assert_called_once() 110 | 111 | @pytest.mark.asyncio 112 | async def test_extract_with_pydantic_model_returns_validated_model(self, mock_stagehand_page): 113 | """Test extracting data with custom Pydantic model returns validated model instance""" 114 | mock_client = MagicMock() 115 | mock_llm = MockLLMClient() 116 | mock_client.llm = mock_llm 117 | mock_client.start_inference_timer = MagicMock() 118 | mock_client.update_metrics = MagicMock() 119 | 120 | class ProductModel(BaseModel): 121 | name: str 122 | price: float 123 | in_stock: bool = True 124 | 125 | handler = ExtractHandler(mock_stagehand_page, mock_client, "") 126 | mock_stagehand_page._page.content = AsyncMock(return_value="Product page") 127 | 128 | # Mock transform_url_strings_to_ids to avoid the subscripted generics bug 129 | with patch('stagehand.handlers.extract_handler.transform_url_strings_to_ids') as mock_transform: 130 | mock_transform.return_value = (ProductModel, []) 131 | 132 | # Mock extract_inference - return data compatible with ProductModel 133 | with patch('stagehand.handlers.extract_handler.extract_inference') as mock_extract_inference: 134 | mock_extract_inference.return_value = { 135 | "data": { 136 | "name": "Wireless Mouse", 137 | "price": 29.99, 138 | "in_stock": True 139 | }, 140 | "metadata": {"completed": True}, 141 | "prompt_tokens": 150, 142 | "completion_tokens": 80, 143 | "inference_time_ms": 1200 144 | } 145 | 146 | mock_stagehand_page._wait_for_settled_dom = AsyncMock() 147 | 148 | options = ExtractOptions(instruction="extract product details") 149 | # Pass ProductModel as schema parameter - should return ProductModel instance 150 | result = await handler.extract(options, ProductModel) 151 | 152 | assert isinstance(result, ExtractResult) 153 | assert result.data is not None 154 | # Should return ProductModel instance due to validation 155 | assert isinstance(result.data, ProductModel) 156 | assert result.data.name == "Wireless Mouse" 157 | assert result.data.price == 29.99 158 | assert result.data.in_stock is True 159 | 160 | # Verify the mocks were called 161 | mock_extract_inference.assert_called_once() 162 | -------------------------------------------------------------------------------- /stagehand/handlers/observe_handler.py: -------------------------------------------------------------------------------- 1 | """Observe handler for performing observations of page elements using LLMs.""" 2 | 3 | from typing import Any, Optional 4 | 5 | from stagehand.a11y.utils import get_accessibility_tree, get_xpath_by_resolved_object_id 6 | from stagehand.llm.inference import observe as observe_inference 7 | from stagehand.metrics import StagehandFunctionName # Changed import location 8 | from stagehand.schemas import ObserveOptions, ObserveResult 9 | from stagehand.utils import draw_observe_overlay 10 | from stagehand.cache import StagehandCache 11 | 12 | 13 | class ObserveHandler: 14 | """Handler for processing observe operations locally.""" 15 | 16 | def __init__( 17 | self, stagehand_page, stagehand_client, user_provided_instructions=None 18 | ): 19 | """ 20 | Initialize the ObserveHandler. 21 | 22 | Args: 23 | stagehand_page: StagehandPage instance 24 | stagehand_client: Stagehand client instance 25 | user_provided_instructions: Optional custom system instructions 26 | """ 27 | self.stagehand_page = stagehand_page 28 | self.stagehand = stagehand_client 29 | self.logger = stagehand_client.logger 30 | self.user_provided_instructions = user_provided_instructions 31 | # 初始化缓存管理器 32 | self.cache_manager = StagehandCache(logger=self.logger) 33 | 34 | # TODO: better kwargs 35 | async def observe( 36 | self, 37 | options: ObserveOptions, 38 | from_act: bool = False, 39 | use_cache: bool = True, 40 | cache_ttl: int = 3600, 41 | ) -> list[ObserveResult]: 42 | """ 43 | Execute an observation operation locally. 44 | 45 | Args: 46 | options: ObserveOptions containing the instruction and other parameters 47 | from_act: Whether this observe call is from an act operation 48 | use_cache: Whether to use caching mechanism 49 | cache_ttl: Cache time-to-live in seconds 50 | 51 | Returns: 52 | list of ObserveResult instances 53 | """ 54 | instruction = options.instruction 55 | if not instruction: 56 | instruction = ( 57 | "Find elements that can be used for any future actions in the page. " 58 | "These may be navigation links, related pages, section/subsection links, " 59 | "buttons, or other interactive elements. Be comprehensive: if there are " 60 | "multiple elements that may be relevant for future actions, return all of them." 61 | ) 62 | 63 | if not from_act: 64 | self.logger.info( 65 | f"Starting observation for task: '{instruction}'", 66 | category="observe", 67 | ) 68 | 69 | # Start inference timer if available 70 | if hasattr(self.stagehand, "start_inference_timer"): 71 | self.stagehand.start_inference_timer() 72 | 73 | # Get DOM representation 74 | output_string = "" 75 | iframes = [] 76 | 77 | await self.stagehand_page._wait_for_settled_dom() 78 | 79 | # 获取页面信息用于缓存 80 | page_url = self.stagehand_page._page.url 81 | page_title = None 82 | try: 83 | page_title = await self.stagehand_page._page.title() 84 | except: 85 | pass 86 | 87 | # 检查缓存(如果启用) 88 | if use_cache: 89 | cached_result = self.cache_manager.get_cached_result( 90 | instruction, page_url, page_title, cache_ttl 91 | ) 92 | if cached_result: 93 | # 验证缓存的xpath是否仍然有效 94 | is_valid = await self.cache_manager.validate_cached_xpath( 95 | self.stagehand_page, cached_result.selector 96 | ) 97 | if is_valid: 98 | self.logger.info("🚀 使用缓存结果,跳过LLM调用") 99 | return [cached_result] 100 | else: 101 | self.logger.info("⚠️ 缓存的xpath已失效,将重新分析") 102 | 103 | # Get accessibility tree data using our utility function 104 | self.logger.info("Getting accessibility tree data") 105 | tree = await get_accessibility_tree(self.stagehand_page, self.logger) 106 | output_string = tree["simplified"] 107 | iframes = tree.get("iframes", []) 108 | 109 | # use inference to call the llm 110 | observation_response = observe_inference( 111 | instruction=instruction, 112 | tree_elements=output_string, 113 | llm_client=self.stagehand.llm, 114 | user_provided_instructions=self.user_provided_instructions, 115 | logger=self.logger, 116 | log_inference_to_file=False, # TODO: Implement logging to file if needed 117 | from_act=from_act, 118 | ) 119 | 120 | # Extract metrics from response 121 | prompt_tokens = observation_response.get("prompt_tokens", 0) 122 | completion_tokens = observation_response.get("completion_tokens", 0) 123 | inference_time_ms = observation_response.get("inference_time_ms", 0) 124 | 125 | # Update metrics directly using the Stagehand client 126 | function_name = ( 127 | StagehandFunctionName.ACT if from_act else StagehandFunctionName.OBSERVE 128 | ) 129 | self.stagehand.update_metrics( 130 | function_name, prompt_tokens, completion_tokens, inference_time_ms 131 | ) 132 | 133 | # Add iframes to the response if any 134 | elements = observation_response.get("elements", []) 135 | for iframe in iframes: 136 | elements.append( 137 | { 138 | "element_id": int(iframe.get("nodeId", 0)), 139 | "description": "an iframe", 140 | "method": "not-supported", 141 | "arguments": [], 142 | } 143 | ) 144 | 145 | # Generate selectors for all elements 146 | elements_with_selectors = await self._add_selectors_to_elements(elements) 147 | 148 | self.logger.debug( 149 | "Found elements", auxiliary={"elements": elements_with_selectors} 150 | ) 151 | 152 | # 保存到缓存(如果启用且有结果) 153 | if use_cache and elements_with_selectors: 154 | try: 155 | # 保存第一个最相关的结果到缓存 156 | first_result = elements_with_selectors[0] 157 | self.cache_manager.set_cache( 158 | instruction, page_url, first_result, page_title 159 | ) 160 | except Exception as e: 161 | self.logger.warning(f"保存缓存失败: {e}") 162 | 163 | # Draw overlay if requested 164 | if options.draw_overlay: 165 | await draw_observe_overlay(self.stagehand_page, elements_with_selectors) 166 | 167 | # Return the list of results without trying to attach _llm_response 168 | return elements_with_selectors 169 | 170 | async def _add_selectors_to_elements( 171 | self, 172 | elements: list[dict[str, Any]], 173 | ) -> list[ObserveResult]: 174 | """ 175 | Add selectors to elements based on their element IDs. 176 | 177 | Args: 178 | elements: list of elements from LLM response 179 | 180 | Returns: 181 | list of elements with selectors added (xpaths) 182 | """ 183 | result = [] 184 | 185 | for element in elements: 186 | element_id = element.get("element_id") 187 | rest = {k: v for k, v in element.items() if k != "element_id"} 188 | 189 | # Generate xpath for element using CDP 190 | self.logger.info( 191 | "Getting xpath for element", 192 | auxiliary={"elementId": str(element_id)}, 193 | ) 194 | 195 | args = {"backendNodeId": element_id} 196 | response = await self.stagehand_page.send_cdp("DOM.resolveNode", args) 197 | object_id = response.get("object", {}).get("objectId") 198 | 199 | if not object_id: 200 | self.logger.info( 201 | f"Invalid object ID returned for element: {element_id}" 202 | ) 203 | continue 204 | 205 | # Use our utility function to get the XPath 206 | cdp_client = await self.stagehand_page.get_cdp_client() 207 | xpath = await get_xpath_by_resolved_object_id(cdp_client, object_id) 208 | 209 | if not xpath: 210 | self.logger.info(f"Empty xpath returned for element: {element_id}") 211 | continue 212 | 213 | result.append(ObserveResult(**{**rest, "selector": f"xpath={xpath}"})) 214 | 215 | return result 216 | -------------------------------------------------------------------------------- /tests/regression/test_extract_aigrant_companies.py: -------------------------------------------------------------------------------- 1 | """ 2 | Regression test for extract_aigrant_companies functionality. 3 | 4 | This test verifies that data extraction works correctly by extracting 5 | companies that received AI grants along with their batch numbers, 6 | based on the TypeScript extract_aigrant_companies evaluation. 7 | """ 8 | 9 | import os 10 | import pytest 11 | import pytest_asyncio 12 | from pydantic import BaseModel, Field 13 | from typing import List 14 | 15 | from stagehand import Stagehand, StagehandConfig 16 | from stagehand.schemas import ExtractOptions 17 | 18 | 19 | class Company(BaseModel): 20 | company: str = Field(..., description="The name of the company") 21 | batch: str = Field(..., description="The batch number of the grant") 22 | 23 | 24 | class Companies(BaseModel): 25 | companies: List[Company] = Field(..., description="List of companies that received AI grants") 26 | 27 | 28 | class TestExtractAigrantCompanies: 29 | """Regression test for extract_aigrant_companies functionality""" 30 | 31 | @pytest.fixture(scope="class") 32 | def local_config(self): 33 | """Configuration for LOCAL mode testing""" 34 | return StagehandConfig( 35 | env="LOCAL", 36 | model_name="gpt-4o-mini", 37 | headless=True, 38 | verbose=1, 39 | dom_settle_timeout_ms=2000, 40 | model_client_options={"apiKey": os.getenv("MODEL_API_KEY") or os.getenv("OPENAI_API_KEY")}, 41 | ) 42 | 43 | @pytest.fixture(scope="class") 44 | def browserbase_config(self): 45 | """Configuration for BROWSERBASE mode testing""" 46 | return StagehandConfig( 47 | env="BROWSERBASE", 48 | api_key=os.getenv("BROWSERBASE_API_KEY"), 49 | project_id=os.getenv("BROWSERBASE_PROJECT_ID"), 50 | model_name="gpt-4o", 51 | headless=False, 52 | verbose=2, 53 | model_client_options={"apiKey": os.getenv("MODEL_API_KEY") or os.getenv("OPENAI_API_KEY")}, 54 | ) 55 | 56 | @pytest_asyncio.fixture 57 | async def local_stagehand(self, local_config): 58 | """Create a Stagehand instance for LOCAL testing""" 59 | stagehand = Stagehand(config=local_config) 60 | await stagehand.init() 61 | yield stagehand 62 | await stagehand.close() 63 | 64 | @pytest_asyncio.fixture 65 | async def browserbase_stagehand(self, browserbase_config): 66 | """Create a Stagehand instance for BROWSERBASE testing""" 67 | if not (os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")): 68 | pytest.skip("Browserbase credentials not available") 69 | 70 | stagehand = Stagehand(config=browserbase_config) 71 | await stagehand.init() 72 | yield stagehand 73 | await stagehand.close() 74 | 75 | @pytest.mark.asyncio 76 | @pytest.mark.regression 77 | @pytest.mark.local 78 | async def test_extract_aigrant_companies_local(self, local_stagehand): 79 | """ 80 | Regression test: extract_aigrant_companies 81 | 82 | Mirrors the TypeScript extract_aigrant_companies evaluation: 83 | - Navigate to AI grant companies test site 84 | - Extract all companies that received AI grants with their batch numbers 85 | - Verify total count is 91 86 | - Verify first company is "Goodfire" in batch "4" 87 | - Verify last company is "Forefront" in batch "1" 88 | """ 89 | stagehand = local_stagehand 90 | 91 | await stagehand.page.goto("https://browserbase.github.io/stagehand-eval-sites/sites/aigrant/") 92 | 93 | # Extract all companies with their batch numbers 94 | extract_options = ExtractOptions( 95 | instruction=( 96 | "Extract all companies that received the AI grant and group them with their " 97 | "batch numbers as an array of objects. Each object should contain the company " 98 | "name and its corresponding batch number." 99 | ), 100 | schema_definition=Companies 101 | ) 102 | 103 | result = await stagehand.page.extract(extract_options) 104 | 105 | # Both LOCAL and BROWSERBASE modes return the Pydantic model instance directly 106 | companies = result.companies 107 | 108 | # Verify total count 109 | expected_length = 91 110 | assert len(companies) == expected_length, ( 111 | f"Expected {expected_length} companies, but got {len(companies)}" 112 | ) 113 | 114 | # Verify first company 115 | expected_first_item = { 116 | "company": "Goodfire", 117 | "batch": "4" 118 | } 119 | assert len(companies) > 0, "No companies were extracted" 120 | first_company = companies[0] 121 | assert first_company.company == expected_first_item["company"], ( 122 | f"Expected first company to be '{expected_first_item['company']}', " 123 | f"but got '{first_company.company}'" 124 | ) 125 | assert first_company.batch == expected_first_item["batch"], ( 126 | f"Expected first company batch to be '{expected_first_item['batch']}', " 127 | f"but got '{first_company.batch}'" 128 | ) 129 | 130 | # Verify last company 131 | expected_last_item = { 132 | "company": "Forefront", 133 | "batch": "1" 134 | } 135 | last_company = companies[-1] 136 | assert last_company.company == expected_last_item["company"], ( 137 | f"Expected last company to be '{expected_last_item['company']}', " 138 | f"but got '{last_company.company}'" 139 | ) 140 | assert last_company.batch == expected_last_item["batch"], ( 141 | f"Expected last company batch to be '{expected_last_item['batch']}', " 142 | f"but got '{last_company.batch}'" 143 | ) 144 | 145 | @pytest.mark.asyncio 146 | @pytest.mark.regression 147 | @pytest.mark.api 148 | @pytest.mark.skipif( 149 | not (os.getenv("BROWSERBASE_API_KEY") and os.getenv("BROWSERBASE_PROJECT_ID")), 150 | reason="Browserbase credentials not available" 151 | ) 152 | async def test_extract_aigrant_companies_browserbase(self, browserbase_stagehand): 153 | """ 154 | Regression test: extract_aigrant_companies (Browserbase) 155 | 156 | Same test as local but running in Browserbase environment. 157 | """ 158 | stagehand = browserbase_stagehand 159 | 160 | await stagehand.page.goto("https://browserbase.github.io/stagehand-eval-sites/sites/aigrant/") 161 | 162 | # Extract all companies with their batch numbers 163 | extract_options = ExtractOptions( 164 | instruction=( 165 | "Extract all companies that received the AI grant and group them with their " 166 | "batch numbers as an array of objects. Each object should contain the company " 167 | "name and its corresponding batch number." 168 | ), 169 | schema_definition=Companies 170 | ) 171 | 172 | result = await stagehand.page.extract(extract_options) 173 | 174 | # Both LOCAL and BROWSERBASE modes return the Pydantic model instance directly 175 | companies = result.companies 176 | 177 | # Verify total count 178 | expected_length = 91 179 | assert len(companies) == expected_length, ( 180 | f"Expected {expected_length} companies, but got {len(companies)}" 181 | ) 182 | 183 | # Verify first company 184 | expected_first_item = { 185 | "company": "Goodfire", 186 | "batch": "4" 187 | } 188 | assert len(companies) > 0, "No companies were extracted" 189 | first_company = companies[0] 190 | assert first_company.company == expected_first_item["company"], ( 191 | f"Expected first company to be '{expected_first_item['company']}', " 192 | f"but got '{first_company.company}'" 193 | ) 194 | assert first_company.batch == expected_first_item["batch"], ( 195 | f"Expected first company batch to be '{expected_first_item['batch']}', " 196 | f"but got '{first_company.batch}'" 197 | ) 198 | 199 | # Verify last company 200 | expected_last_item = { 201 | "company": "Forefront", 202 | "batch": "1" 203 | } 204 | last_company = companies[-1] 205 | assert last_company.company == expected_last_item["company"], ( 206 | f"Expected last company to be '{expected_last_item['company']}', " 207 | f"but got '{last_company.company}'" 208 | ) 209 | assert last_company.batch == expected_last_item["batch"], ( 210 | f"Expected last company batch to be '{expected_last_item['batch']}', " 211 | f"but got '{last_company.batch}'" 212 | ) -------------------------------------------------------------------------------- /tests/integration/api/test_frame_id_integration.py: -------------------------------------------------------------------------------- 1 | """ 2 | Integration tests for frame ID functionality with the API. 3 | Tests that frame IDs are properly tracked and sent to the server. 4 | """ 5 | 6 | import pytest 7 | import os 8 | from unittest.mock import patch, AsyncMock, MagicMock 9 | from stagehand import Stagehand 10 | 11 | 12 | @pytest.mark.skipif( 13 | not os.getenv("BROWSERBASE_API_KEY") or not os.getenv("BROWSERBASE_PROJECT_ID"), 14 | reason="Browserbase credentials not configured" 15 | ) 16 | @pytest.mark.asyncio 17 | class TestFrameIdIntegration: 18 | """Integration tests for frame ID tracking with the API.""" 19 | 20 | async def test_frame_id_initialization_and_api_calls(self): 21 | """Test that frame IDs are initialized and included in API calls.""" 22 | # Mock the HTTP client to capture API calls 23 | with patch('stagehand.main.httpx.AsyncClient') as MockClient: 24 | mock_client = AsyncMock() 25 | MockClient.return_value = mock_client 26 | 27 | # Mock session creation response 28 | mock_response = MagicMock() 29 | mock_response.status_code = 200 30 | mock_response.json.return_value = { 31 | "success": True, 32 | "data": { 33 | "sessionId": "test-session-123", 34 | "available": True 35 | } 36 | } 37 | mock_client.post = AsyncMock(return_value=mock_response) 38 | 39 | # Mock streaming response for execute calls 40 | mock_stream_response = AsyncMock() 41 | mock_stream_response.status_code = 200 42 | mock_stream_response.__aenter__ = AsyncMock(return_value=mock_stream_response) 43 | mock_stream_response.__aexit__ = AsyncMock() 44 | 45 | # Mock the async iterator for streaming lines 46 | async def mock_aiter_lines(): 47 | yield 'data: {"type": "system", "data": {"status": "finished", "result": {"success": true}}}' 48 | 49 | mock_stream_response.aiter_lines = mock_aiter_lines 50 | mock_client.stream = MagicMock(return_value=mock_stream_response) 51 | 52 | # Initialize Stagehand 53 | stagehand = Stagehand( 54 | env="BROWSERBASE", 55 | use_api=True, 56 | browserbase_api_key="test-api-key", 57 | browserbase_project_id="test-project", 58 | model_api_key="test-model-key" 59 | ) 60 | 61 | try: 62 | # Initialize browser (this will create session via API) 63 | await stagehand.init() 64 | 65 | # Verify session was created 66 | assert mock_client.post.called 67 | 68 | # Get the page and context 69 | page = stagehand.page 70 | context = stagehand.context 71 | 72 | # Verify frame tracking attributes exist 73 | assert hasattr(page, 'frame_id') 74 | assert hasattr(context, 'frame_id_map') 75 | 76 | # Simulate setting a frame ID (normally done by CDP listener) 77 | test_frame_id = "test-frame-456" 78 | page.update_root_frame_id(test_frame_id) 79 | context.register_frame_id(test_frame_id, page) 80 | 81 | # Test that frame ID is included in navigate call 82 | await page.goto("https://example.com") 83 | 84 | # Check the stream call was made with frameId 85 | stream_call_args = mock_client.stream.call_args 86 | if stream_call_args: 87 | payload = stream_call_args[1].get('json', {}) 88 | assert 'frameId' in payload 89 | assert payload['frameId'] == test_frame_id 90 | 91 | finally: 92 | await stagehand.close() 93 | 94 | async def test_multiple_pages_frame_id_tracking(self): 95 | """Test frame ID tracking with multiple pages.""" 96 | with patch('stagehand.main.httpx.AsyncClient') as MockClient: 97 | mock_client = AsyncMock() 98 | MockClient.return_value = mock_client 99 | 100 | # Setup mocks as in previous test 101 | mock_response = MagicMock() 102 | mock_response.status_code = 200 103 | mock_response.json.return_value = { 104 | "success": True, 105 | "data": { 106 | "sessionId": "test-session-789", 107 | "available": True 108 | } 109 | } 110 | mock_client.post = AsyncMock(return_value=mock_response) 111 | 112 | stagehand = Stagehand( 113 | env="BROWSERBASE", 114 | use_api=True, 115 | browserbase_api_key="test-api-key", 116 | browserbase_project_id="test-project", 117 | model_api_key="test-model-key" 118 | ) 119 | 120 | try: 121 | await stagehand.init() 122 | 123 | # Get first page 124 | page1 = stagehand.page 125 | context = stagehand.context 126 | 127 | # Simulate frame IDs for testing 128 | frame_id_1 = "frame-page1" 129 | page1.update_root_frame_id(frame_id_1) 130 | context.register_frame_id(frame_id_1, page1) 131 | 132 | # Create second page 133 | page2 = await context.new_page() 134 | frame_id_2 = "frame-page2" 135 | page2.update_root_frame_id(frame_id_2) 136 | context.register_frame_id(frame_id_2, page2) 137 | 138 | # Verify both pages are tracked 139 | assert len(context.frame_id_map) == 2 140 | assert context.get_stagehand_page_by_frame_id(frame_id_1) == page1 141 | assert context.get_stagehand_page_by_frame_id(frame_id_2) == page2 142 | 143 | # Verify each page has its own frame ID 144 | assert page1.frame_id == frame_id_1 145 | assert page2.frame_id == frame_id_2 146 | 147 | finally: 148 | await stagehand.close() 149 | 150 | async def test_frame_id_persistence_across_navigation(self): 151 | """Test that frame IDs are updated when navigating to new pages.""" 152 | with patch('stagehand.main.httpx.AsyncClient') as MockClient: 153 | mock_client = AsyncMock() 154 | MockClient.return_value = mock_client 155 | 156 | # Setup basic mocks 157 | mock_response = MagicMock() 158 | mock_response.status_code = 200 159 | mock_response.json.return_value = { 160 | "success": True, 161 | "data": { 162 | "sessionId": "test-session-nav", 163 | "available": True 164 | } 165 | } 166 | mock_client.post = AsyncMock(return_value=mock_response) 167 | 168 | stagehand = Stagehand( 169 | env="BROWSERBASE", 170 | use_api=True, 171 | browserbase_api_key="test-api-key", 172 | browserbase_project_id="test-project", 173 | model_api_key="test-model-key" 174 | ) 175 | 176 | try: 177 | await stagehand.init() 178 | 179 | page = stagehand.page 180 | context = stagehand.context 181 | 182 | # Initial frame ID 183 | initial_frame_id = "frame-initial" 184 | page.update_root_frame_id(initial_frame_id) 185 | context.register_frame_id(initial_frame_id, page) 186 | 187 | assert page.frame_id == initial_frame_id 188 | assert initial_frame_id in context.frame_id_map 189 | 190 | # Simulate navigation causing frame ID change 191 | # (In real scenario, CDP listener would handle this) 192 | new_frame_id = "frame-after-nav" 193 | context.unregister_frame_id(initial_frame_id) 194 | page.update_root_frame_id(new_frame_id) 195 | context.register_frame_id(new_frame_id, page) 196 | 197 | # Verify frame ID was updated 198 | assert page.frame_id == new_frame_id 199 | assert initial_frame_id not in context.frame_id_map 200 | assert new_frame_id in context.frame_id_map 201 | assert context.get_stagehand_page_by_frame_id(new_frame_id) == page 202 | 203 | finally: 204 | await stagehand.close() -------------------------------------------------------------------------------- /tests/unit/core/test_frame_id_tracking.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for frame ID tracking functionality. 3 | Tests the implementation of frame ID map in StagehandContext and StagehandPage. 4 | """ 5 | 6 | import pytest 7 | from unittest.mock import AsyncMock, MagicMock, patch 8 | from stagehand.context import StagehandContext 9 | from stagehand.page import StagehandPage 10 | 11 | 12 | @pytest.fixture 13 | def mock_stagehand(): 14 | """Create a mock Stagehand client.""" 15 | mock = MagicMock() 16 | mock.logger = MagicMock() 17 | mock.logger.debug = MagicMock() 18 | mock.logger.error = MagicMock() 19 | mock._page_switch_lock = AsyncMock() 20 | return mock 21 | 22 | 23 | @pytest.fixture 24 | def mock_browser_context(): 25 | """Create a mock Playwright BrowserContext.""" 26 | mock_context = MagicMock() 27 | mock_context.pages = [] 28 | mock_context.new_page = AsyncMock() 29 | mock_context.new_cdp_session = AsyncMock() 30 | return mock_context 31 | 32 | 33 | @pytest.fixture 34 | def mock_page(): 35 | """Create a mock Playwright Page.""" 36 | page = MagicMock() 37 | page.url = "https://example.com" 38 | page.evaluate = AsyncMock(return_value=False) 39 | page.add_init_script = AsyncMock() 40 | page.context = MagicMock() 41 | page.once = MagicMock() 42 | return page 43 | 44 | 45 | class TestFrameIdTracking: 46 | """Test suite for frame ID tracking functionality.""" 47 | 48 | def test_stagehand_context_initialization(self, mock_browser_context, mock_stagehand): 49 | """Test that StagehandContext initializes with frame_id_map.""" 50 | context = StagehandContext(mock_browser_context, mock_stagehand) 51 | 52 | assert hasattr(context, 'frame_id_map') 53 | assert isinstance(context.frame_id_map, dict) 54 | assert len(context.frame_id_map) == 0 55 | 56 | def test_register_frame_id(self, mock_browser_context, mock_stagehand, mock_page): 57 | """Test registering a frame ID.""" 58 | context = StagehandContext(mock_browser_context, mock_stagehand) 59 | stagehand_page = StagehandPage(mock_page, mock_stagehand, context) 60 | 61 | # Register frame ID 62 | frame_id = "frame-123" 63 | context.register_frame_id(frame_id, stagehand_page) 64 | 65 | assert frame_id in context.frame_id_map 66 | assert context.frame_id_map[frame_id] == stagehand_page 67 | 68 | def test_unregister_frame_id(self, mock_browser_context, mock_stagehand, mock_page): 69 | """Test unregistering a frame ID.""" 70 | context = StagehandContext(mock_browser_context, mock_stagehand) 71 | stagehand_page = StagehandPage(mock_page, mock_stagehand, context) 72 | 73 | # Register and then unregister 74 | frame_id = "frame-456" 75 | context.register_frame_id(frame_id, stagehand_page) 76 | context.unregister_frame_id(frame_id) 77 | 78 | assert frame_id not in context.frame_id_map 79 | 80 | def test_get_stagehand_page_by_frame_id(self, mock_browser_context, mock_stagehand, mock_page): 81 | """Test retrieving a StagehandPage by frame ID.""" 82 | context = StagehandContext(mock_browser_context, mock_stagehand) 83 | stagehand_page = StagehandPage(mock_page, mock_stagehand, context) 84 | 85 | frame_id = "frame-789" 86 | context.register_frame_id(frame_id, stagehand_page) 87 | 88 | retrieved_page = context.get_stagehand_page_by_frame_id(frame_id) 89 | assert retrieved_page == stagehand_page 90 | 91 | # Test non-existent frame ID 92 | non_existent = context.get_stagehand_page_by_frame_id("non-existent") 93 | assert non_existent is None 94 | 95 | def test_stagehand_page_frame_id_property(self, mock_page, mock_stagehand): 96 | """Test StagehandPage frame_id property and update method.""" 97 | stagehand_page = StagehandPage(mock_page, mock_stagehand) 98 | 99 | # Initially None 100 | assert stagehand_page.frame_id is None 101 | 102 | # Update frame ID 103 | new_frame_id = "frame-abc" 104 | stagehand_page.update_root_frame_id(new_frame_id) 105 | 106 | assert stagehand_page.frame_id == new_frame_id 107 | mock_stagehand.logger.debug.assert_called_with( 108 | f"Updated frame ID to {new_frame_id}", category="page" 109 | ) 110 | 111 | @pytest.mark.asyncio 112 | async def test_attach_frame_navigated_listener(self, mock_browser_context, mock_stagehand, mock_page): 113 | """Test attaching CDP frame navigation listener.""" 114 | context = StagehandContext(mock_browser_context, mock_stagehand) 115 | stagehand_page = StagehandPage(mock_page, mock_stagehand, context) 116 | 117 | # Mock CDP session 118 | mock_cdp_session = MagicMock() 119 | mock_cdp_session.send = AsyncMock() 120 | mock_cdp_session.on = MagicMock() 121 | mock_browser_context.new_cdp_session = AsyncMock(return_value=mock_cdp_session) 122 | 123 | # Mock frame tree response 124 | mock_cdp_session.send.return_value = { 125 | "frameTree": { 126 | "frame": { 127 | "id": "initial-frame-id" 128 | } 129 | } 130 | } 131 | 132 | # Attach listener 133 | await context._attach_frame_navigated_listener(mock_page, stagehand_page) 134 | 135 | # Verify CDP session was created and Page domain was enabled 136 | mock_browser_context.new_cdp_session.assert_called_once_with(mock_page) 137 | mock_cdp_session.send.assert_any_call("Page.enable") 138 | mock_cdp_session.send.assert_any_call("Page.getFrameTree") 139 | 140 | # Verify frame ID was set 141 | assert stagehand_page.frame_id == "initial-frame-id" 142 | assert "initial-frame-id" in context.frame_id_map 143 | 144 | # Verify event listener was registered 145 | mock_cdp_session.on.assert_called_once() 146 | assert mock_cdp_session.on.call_args[0][0] == "Page.frameNavigated" 147 | 148 | @pytest.mark.asyncio 149 | async def test_frame_id_in_api_calls(self, mock_page, mock_stagehand): 150 | """Test that frame ID is included in API payloads.""" 151 | stagehand_page = StagehandPage(mock_page, mock_stagehand) 152 | stagehand_page.update_root_frame_id("test-frame-123") 153 | 154 | # Mock the stagehand client for API mode 155 | mock_stagehand.use_api = True 156 | mock_stagehand._get_lock_for_session = MagicMock() 157 | mock_stagehand._get_lock_for_session.return_value = AsyncMock() 158 | mock_stagehand._execute = AsyncMock(return_value={"success": True}) 159 | 160 | # Test goto with frame ID 161 | await stagehand_page.goto("https://example.com") 162 | 163 | # Verify frame ID was included in the payload 164 | call_args = mock_stagehand._execute.call_args 165 | assert call_args[0][0] == "navigate" 166 | assert "frameId" in call_args[0][1] 167 | assert call_args[0][1]["frameId"] == "test-frame-123" 168 | 169 | @pytest.mark.asyncio 170 | async def test_frame_navigation_event_handling(self, mock_browser_context, mock_stagehand, mock_page): 171 | """Test handling of frame navigation events.""" 172 | context = StagehandContext(mock_browser_context, mock_stagehand) 173 | stagehand_page = StagehandPage(mock_page, mock_stagehand, context) 174 | 175 | # Set initial frame ID 176 | initial_frame_id = "frame-initial" 177 | stagehand_page.update_root_frame_id(initial_frame_id) 178 | context.register_frame_id(initial_frame_id, stagehand_page) 179 | 180 | # Mock CDP session 181 | mock_cdp_session = MagicMock() 182 | mock_cdp_session.send = AsyncMock() 183 | mock_cdp_session.on = MagicMock() 184 | mock_browser_context.new_cdp_session = AsyncMock(return_value=mock_cdp_session) 185 | 186 | # Mock initial frame tree 187 | mock_cdp_session.send.return_value = { 188 | "frameTree": { 189 | "frame": { 190 | "id": initial_frame_id 191 | } 192 | } 193 | } 194 | 195 | await context._attach_frame_navigated_listener(mock_page, stagehand_page) 196 | 197 | # Get the registered event handler 198 | event_handler = mock_cdp_session.on.call_args[0][1] 199 | 200 | # Simulate frame navigation event 201 | new_frame_id = "frame-new" 202 | event_handler({ 203 | "frame": { 204 | "id": new_frame_id, 205 | "parentId": None # Root frame has no parent 206 | } 207 | }) 208 | 209 | # Verify old frame ID was unregistered and new one registered 210 | assert initial_frame_id not in context.frame_id_map 211 | assert new_frame_id in context.frame_id_map 212 | assert stagehand_page.frame_id == new_frame_id 213 | assert context.frame_id_map[new_frame_id] == stagehand_page -------------------------------------------------------------------------------- /tests/unit/test_client_initialization.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import unittest.mock as mock 3 | import os 4 | 5 | import pytest 6 | 7 | from stagehand import Stagehand 8 | from stagehand.config import StagehandConfig 9 | 10 | 11 | class TestClientInitialization: 12 | """Tests for the Stagehand client initialization and configuration.""" 13 | 14 | @pytest.mark.smoke 15 | @mock.patch.dict(os.environ, {}, clear=True) 16 | def test_init_with_direct_params(self): 17 | """Test initialization with direct parameters.""" 18 | # Create a config with LOCAL env to avoid BROWSERBASE validation issues 19 | config = StagehandConfig(env="LOCAL") 20 | client = Stagehand( 21 | config=config, 22 | api_url="http://test-server.com", 23 | browserbase_session_id="test-session", 24 | api_key="test-api-key", 25 | project_id="test-project-id", 26 | model_api_key="test-model-api-key", 27 | verbose=2, 28 | ) 29 | 30 | assert client.api_url == "http://test-server.com" 31 | assert client.session_id == "test-session" 32 | # In LOCAL mode, browserbase keys are not used 33 | assert client.model_api_key == "test-model-api-key" 34 | assert client.verbose == 2 35 | assert client._initialized is False 36 | assert client._closed is False 37 | 38 | @pytest.mark.smoke 39 | @mock.patch.dict(os.environ, {}, clear=True) 40 | def test_init_with_config(self): 41 | """Test initialization with a configuration object.""" 42 | config = StagehandConfig( 43 | env="LOCAL", # Use LOCAL to avoid BROWSERBASE validation 44 | api_key="config-api-key", 45 | project_id="config-project-id", 46 | browserbase_session_id="config-session-id", 47 | model_name="gpt-4", 48 | dom_settle_timeout_ms=500, 49 | self_heal=True, 50 | wait_for_captcha_solves=True, 51 | system_prompt="Custom system prompt for testing", 52 | ) 53 | 54 | client = Stagehand(config=config, api_url="http://test-server.com") 55 | 56 | assert client.api_url == "http://test-server.com" 57 | assert client.session_id == "config-session-id" 58 | assert client.browserbase_api_key == "config-api-key" 59 | assert client.browserbase_project_id == "config-project-id" 60 | assert client.model_name == "gpt-4" 61 | assert client.dom_settle_timeout_ms == 500 62 | assert hasattr(client, "self_heal") 63 | assert client.self_heal is True 64 | assert hasattr(client, "wait_for_captcha_solves") 65 | assert client.wait_for_captcha_solves is True 66 | assert hasattr(client, "config") 67 | assert hasattr(client, "system_prompt") 68 | assert client.system_prompt == "Custom system prompt for testing" 69 | 70 | @mock.patch.dict(os.environ, {}, clear=True) 71 | def test_config_priority_over_direct_params(self): 72 | """Test that config parameters take precedence over direct parameters (except session_id).""" 73 | config = StagehandConfig( 74 | env="LOCAL", # Use LOCAL to avoid BROWSERBASE validation 75 | api_key="config-api-key", 76 | project_id="config-project-id", 77 | browserbase_session_id="config-session-id", 78 | ) 79 | 80 | client = Stagehand( 81 | config=config, 82 | api_key="direct-api-key", 83 | project_id="direct-project-id", 84 | browserbase_session_id="direct-session-id", 85 | ) 86 | 87 | # Override parameters take precedence over config parameters 88 | assert client.browserbase_api_key == "direct-api-key" 89 | assert client.browserbase_project_id == "direct-project-id" 90 | # session_id parameter overrides config since it's passed as browserbase_session_id override 91 | assert client.session_id == "direct-session-id" 92 | 93 | def test_init_with_missing_required_fields(self): 94 | """Test initialization with missing required fields.""" 95 | # No error when initialized without session_id 96 | client = Stagehand( 97 | api_key="test-api-key", project_id="test-project-id" 98 | ) 99 | assert client.session_id is None 100 | 101 | # Test that error handling for missing API key is functioning 102 | # by patching the ValueError that should be raised 103 | with mock.patch.object( 104 | Stagehand, 105 | "__init__", 106 | side_effect=ValueError("browserbase_api_key is required"), 107 | ): 108 | with pytest.raises(ValueError, match="browserbase_api_key is required"): 109 | Stagehand( 110 | browserbase_session_id="test-session", project_id="test-project-id" 111 | ) 112 | 113 | def test_init_as_context_manager(self): 114 | """Test the client as a context manager.""" 115 | client = Stagehand( 116 | api_url="http://test-server.com", 117 | browserbase_session_id="test-session", 118 | api_key="test-api-key", 119 | project_id="test-project-id", 120 | ) 121 | 122 | # Mock the async context manager methods 123 | client.__aenter__ = mock.AsyncMock(return_value=client) 124 | client.__aexit__ = mock.AsyncMock() 125 | client.init = mock.AsyncMock() 126 | client.close = mock.AsyncMock() 127 | 128 | # We can't easily test an async context manager in a non-async test, 129 | # so we just verify the methods exist and are async 130 | assert hasattr(client, "__aenter__") 131 | assert hasattr(client, "__aexit__") 132 | 133 | # Verify init is called in __aenter__ 134 | assert client.init is not None 135 | 136 | # Verify close is called in __aexit__ 137 | assert client.close is not None 138 | 139 | @pytest.mark.asyncio 140 | async def test_init_playwright_timeout(self): 141 | """Test that init() raises TimeoutError when playwright takes too long to start.""" 142 | config = StagehandConfig(env="LOCAL") 143 | client = Stagehand(config=config) 144 | 145 | # Mock async_playwright to simulate a hanging start() method 146 | mock_playwright_instance = mock.AsyncMock() 147 | mock_start = mock.AsyncMock() 148 | 149 | # Make start() hang indefinitely 150 | async def hanging_start(): 151 | await asyncio.sleep(100) # Sleep longer than the 30s timeout 152 | 153 | mock_start.side_effect = hanging_start 154 | mock_playwright_instance.start = mock_start 155 | 156 | with mock.patch("stagehand.main.async_playwright", return_value=mock_playwright_instance): 157 | # The init() method should raise TimeoutError due to the 30-second timeout 158 | with pytest.raises(asyncio.TimeoutError): 159 | await client.init() 160 | 161 | # Ensure the client is not marked as initialized 162 | assert client._initialized is False 163 | 164 | @pytest.mark.asyncio 165 | async def test_create_session(self): 166 | """Test session creation.""" 167 | client = Stagehand( 168 | api_url="http://test-server.com", 169 | api_key="test-api-key", 170 | project_id="test-project-id", 171 | model_api_key="test-model-api-key", 172 | ) 173 | 174 | # Override the _create_session method for easier testing 175 | original_create_session = client._create_session 176 | 177 | async def mock_create_session(): 178 | client.session_id = "new-test-session-id" 179 | 180 | client._create_session = mock_create_session 181 | 182 | # Call _create_session 183 | await client._create_session() 184 | 185 | # Verify session ID was set 186 | assert client.session_id == "new-test-session-id" 187 | 188 | @pytest.mark.asyncio 189 | async def test_create_session_failure(self): 190 | """Test session creation failure.""" 191 | client = Stagehand( 192 | api_url="http://test-server.com", 193 | api_key="test-api-key", 194 | project_id="test-project-id", 195 | model_api_key="test-model-api-key", 196 | ) 197 | 198 | # Override the _create_session method to raise an error 199 | original_create_session = client._create_session 200 | 201 | async def mock_create_session(): 202 | raise RuntimeError("Failed to create session: Invalid request") 203 | 204 | client._create_session = mock_create_session 205 | 206 | # Call _create_session and expect error 207 | with pytest.raises(RuntimeError, match="Failed to create session"): 208 | await client._create_session() 209 | 210 | @pytest.mark.asyncio 211 | async def test_create_session_invalid_response(self): 212 | """Test session creation with invalid response format.""" 213 | client = Stagehand( 214 | api_url="http://test-server.com", 215 | api_key="test-api-key", 216 | project_id="test-project-id", 217 | model_api_key="test-model-api-key", 218 | ) 219 | 220 | # Override the _create_session method to raise a specific error 221 | original_create_session = client._create_session 222 | 223 | async def mock_create_session(): 224 | raise RuntimeError("Invalid response format: {'success': true, 'data': {}}") 225 | 226 | client._create_session = mock_create_session 227 | 228 | # Call _create_session and expect error 229 | with pytest.raises(RuntimeError, match="Invalid response format"): 230 | await client._create_session() 231 | -------------------------------------------------------------------------------- /tests/unit/test_client_api.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import unittest.mock as mock 4 | 5 | import pytest 6 | from httpx import AsyncClient, Response 7 | 8 | from stagehand import Stagehand 9 | 10 | 11 | class TestClientAPI: 12 | """Tests for the Stagehand client API interactions.""" 13 | 14 | @pytest.fixture 15 | async def mock_client(self): 16 | """Create a mock Stagehand client for testing.""" 17 | client = Stagehand( 18 | api_url="http://test-server.com", 19 | browserbase_session_id="test-session-123", 20 | api_key="test-api-key", 21 | project_id="test-project-id", 22 | model_api_key="test-model-api-key", 23 | ) 24 | return client 25 | 26 | @pytest.mark.asyncio 27 | async def test_execute_success(self, mock_client): 28 | """Test successful execution of a streaming API request.""" 29 | 30 | # Create a custom implementation of _execute for testing 31 | async def mock_execute(method, payload): 32 | # Print debug info 33 | print("\n==== EXECUTING TEST_METHOD ====") 34 | print( 35 | f"URL: {mock_client.api_url}/sessions/{mock_client.session_id}/{method}" 36 | ) 37 | print(f"Payload: {payload}") 38 | print( 39 | f"Headers: {{'x-bb-api-key': '{mock_client.browserbase_api_key}', 'x-bb-project-id': '{mock_client.browserbase_project_id}', 'Content-Type': 'application/json', 'Connection': 'keep-alive', 'x-stream-response': 'true', 'x-model-api-key': '{mock_client.model_api_key}'}}" 40 | ) 41 | 42 | # Return the expected result directly 43 | return {"key": "value"} 44 | 45 | # Replace the method with our mock 46 | mock_client._execute = mock_execute 47 | 48 | # Call _execute and check results 49 | result = await mock_client._execute("test_method", {"param": "value"}) 50 | 51 | # Verify result matches the expected value 52 | assert result == {"key": "value"} 53 | 54 | @pytest.mark.asyncio 55 | async def test_execute_error_response(self, mock_client): 56 | """Test handling of error responses.""" 57 | # Create a mock implementation that simulates an error response 58 | async def mock_execute(method, payload): 59 | # Simulate the error handling that would happen in the real _execute method 60 | raise RuntimeError("Request failed with status 400: Bad request") 61 | 62 | # Replace the method with our mock 63 | mock_client._execute = mock_execute 64 | 65 | # Call _execute and expect it to raise the error 66 | with pytest.raises(RuntimeError, match="Request failed with status 400"): 67 | await mock_client._execute("test_method", {"param": "value"}) 68 | 69 | @pytest.mark.asyncio 70 | async def test_execute_connection_error(self, mock_client): 71 | """Test handling of connection errors.""" 72 | 73 | # Create a custom implementation of _execute that raises an exception 74 | async def mock_execute(method, payload): 75 | # Print debug info 76 | print("\n==== EXECUTING TEST_METHOD ====") 77 | print( 78 | f"URL: {mock_client.api_url}/sessions/{mock_client.session_id}/{method}" 79 | ) 80 | print(f"Payload: {payload}") 81 | print( 82 | f"Headers: {{'x-bb-api-key': '{mock_client.browserbase_api_key}', 'x-bb-project-id': '{mock_client.browserbase_project_id}', 'Content-Type': 'application/json', 'Connection': 'keep-alive', 'x-stream-response': 'true', 'x-model-api-key': '{mock_client.model_api_key}'}}" 83 | ) 84 | 85 | # Raise the expected exception 86 | raise Exception("Connection failed") 87 | 88 | # Replace the method with our mock 89 | mock_client._execute = mock_execute 90 | 91 | # Call _execute and check it raises the exception 92 | with pytest.raises(Exception, match="Connection failed"): 93 | await mock_client._execute("test_method", {"param": "value"}) 94 | 95 | @pytest.mark.asyncio 96 | async def test_execute_invalid_json(self, mock_client): 97 | """Test handling of invalid JSON in streaming response.""" 98 | # Create a mock log method 99 | mock_client._log = mock.MagicMock() 100 | 101 | # Create a custom implementation of _execute for testing 102 | async def mock_execute(method, payload): 103 | # Print debug info 104 | print("\n==== EXECUTING TEST_METHOD ====") 105 | print( 106 | f"URL: {mock_client.api_url}/sessions/{mock_client.session_id}/{method}" 107 | ) 108 | print(f"Payload: {payload}") 109 | print( 110 | f"Headers: {{'x-bb-api-key': '{mock_client.browserbase_api_key}', 'x-bb-project-id': '{mock_client.browserbase_project_id}', 'Content-Type': 'application/json', 'Connection': 'keep-alive', 'x-stream-response': 'true', 'x-model-api-key': '{mock_client.model_api_key}'}}" 111 | ) 112 | 113 | # Log an error for the invalid JSON 114 | mock_client._log("Could not parse line as JSON: invalid json here", level=2) 115 | 116 | # Return the expected result 117 | return {"key": "value"} 118 | 119 | # Replace the method with our mock 120 | mock_client._execute = mock_execute 121 | 122 | # Call _execute and check results 123 | result = await mock_client._execute("test_method", {"param": "value"}) 124 | 125 | # Should return the result despite the invalid JSON line 126 | assert result == {"key": "value"} 127 | 128 | # Verify error was logged 129 | mock_client._log.assert_called_with( 130 | "Could not parse line as JSON: invalid json here", level=2 131 | ) 132 | 133 | @pytest.mark.asyncio 134 | async def test_execute_no_finished_message(self, mock_client): 135 | """Test handling of streaming response with no 'finished' message.""" 136 | # Create a mock implementation that simulates no finished message 137 | async def mock_execute(method, payload): 138 | # Simulate processing log messages but not receiving a finished message 139 | # In the real implementation, this would return None 140 | return None 141 | 142 | # Replace the method with our mock 143 | mock_client._execute = mock_execute 144 | 145 | # Mock the _handle_log method to track calls 146 | log_calls = [] 147 | async def mock_handle_log(message): 148 | log_calls.append(message) 149 | 150 | mock_client._handle_log = mock_handle_log 151 | 152 | # Call _execute - it should return None when no finished message is received 153 | result = await mock_client._execute("test_method", {"param": "value"}) 154 | 155 | # Should return None when no finished message is found 156 | assert result is None 157 | 158 | @pytest.mark.asyncio 159 | async def test_execute_on_log_callback(self, mock_client): 160 | """Test the on_log callback is called for log messages.""" 161 | # Setup a mock on_log callback 162 | on_log_mock = mock.AsyncMock() 163 | mock_client.on_log = on_log_mock 164 | 165 | # Create a mock implementation that simulates processing log messages 166 | async def mock_execute(method, payload): 167 | # Simulate processing two log messages and then a finished message 168 | # Mock calling _handle_log for each log message 169 | await mock_client._handle_log({"type": "log", "data": {"message": "Log message 1"}}) 170 | await mock_client._handle_log({"type": "log", "data": {"message": "Log message 2"}}) 171 | # Return the final result 172 | return {"key": "value"} 173 | 174 | # Replace the method with our mock 175 | mock_client._execute = mock_execute 176 | 177 | # Mock the _handle_log method and track calls 178 | log_calls = [] 179 | async def mock_handle_log(message): 180 | log_calls.append(message) 181 | 182 | mock_client._handle_log = mock_handle_log 183 | 184 | # Call _execute 185 | result = await mock_client._execute("test_method", {"param": "value"}) 186 | 187 | # Should return the result from the finished message 188 | assert result == {"key": "value"} 189 | 190 | # Verify _handle_log was called for each log message 191 | assert len(log_calls) == 2 192 | 193 | @pytest.mark.asyncio 194 | async def test_check_server_health(self, mock_client): 195 | """Test server health check.""" 196 | # Since _check_server_health doesn't exist in the actual code, 197 | # we'll test a basic health check simulation 198 | mock_client._health_check = mock.AsyncMock(return_value=True) 199 | 200 | result = await mock_client._health_check() 201 | assert result is True 202 | mock_client._health_check.assert_called_once() 203 | 204 | @pytest.mark.asyncio 205 | async def test_check_server_health_failure(self, mock_client): 206 | """Test server health check failure and retry.""" 207 | # Mock a health check that fails 208 | mock_client._health_check = mock.AsyncMock(return_value=False) 209 | 210 | result = await mock_client._health_check() 211 | assert result is False 212 | mock_client._health_check.assert_called_once() 213 | 214 | @pytest.mark.asyncio 215 | async def test_api_timeout_handling(self, mock_client): 216 | """Test API timeout handling.""" 217 | # Mock the _execute method to simulate a timeout 218 | async def timeout_execute(method, payload): 219 | raise TimeoutError("Request timed out after 30 seconds") 220 | 221 | mock_client._execute = timeout_execute 222 | 223 | # Test that timeout errors are properly raised 224 | with pytest.raises(TimeoutError, match="Request timed out after 30 seconds"): 225 | await mock_client._execute("test_method", {"param": "value"}) 226 | -------------------------------------------------------------------------------- /stagehand/context.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | import weakref 4 | 5 | from playwright.async_api import BrowserContext, Page 6 | 7 | from .page import StagehandPage 8 | 9 | 10 | class StagehandContext: 11 | def __init__(self, context: BrowserContext, stagehand): 12 | self._context = context 13 | self.stagehand = stagehand 14 | # Use a weak key dictionary to map Playwright Pages to our StagehandPage wrappers 15 | self.page_map = weakref.WeakKeyDictionary() 16 | self.active_stagehand_page = None 17 | # Map frame IDs to StagehandPage instances 18 | self.frame_id_map = {} 19 | 20 | async def new_page(self) -> StagehandPage: 21 | pw_page: Page = await self._context.new_page() 22 | stagehand_page = await self.create_stagehand_page(pw_page) 23 | self.set_active_page(stagehand_page) 24 | return stagehand_page 25 | 26 | async def create_stagehand_page(self, pw_page: Page) -> StagehandPage: 27 | # Create a StagehandPage wrapper for the given Playwright page 28 | stagehand_page = StagehandPage(pw_page, self.stagehand, self) 29 | await self.inject_custom_scripts(pw_page) 30 | self.page_map[pw_page] = stagehand_page 31 | 32 | # Initialize frame tracking for this page 33 | await self._attach_frame_navigated_listener(pw_page, stagehand_page) 34 | 35 | return stagehand_page 36 | 37 | async def inject_custom_scripts(self, pw_page: Page): 38 | script_path = os.path.join(os.path.dirname(__file__), "domScripts.js") 39 | try: 40 | with open(script_path) as f: 41 | script = f.read() 42 | except Exception as e: 43 | self.stagehand.logger.error(f"Error reading domScripts.js: {e}") 44 | script = "/* fallback injection script */" 45 | await pw_page.add_init_script(script) 46 | 47 | async def get_stagehand_page(self, pw_page: Page) -> StagehandPage: 48 | if pw_page not in self.page_map: 49 | return await self.create_stagehand_page(pw_page) 50 | stagehand_page = self.page_map[pw_page] 51 | return stagehand_page 52 | 53 | async def get_stagehand_pages(self) -> list: 54 | # Return a list of StagehandPage wrappers for all pages in the context 55 | pages = self._context.pages 56 | result = [] 57 | for pw_page in pages: 58 | stagehand_page = await self.get_stagehand_page(pw_page) 59 | result.append(stagehand_page) 60 | return result 61 | 62 | def set_active_page(self, stagehand_page: StagehandPage): 63 | self.active_stagehand_page = stagehand_page 64 | # Update the active page in the stagehand client 65 | if hasattr(self.stagehand, "_set_active_page"): 66 | self.stagehand._set_active_page(stagehand_page) 67 | self.stagehand.logger.debug( 68 | f"Set active page to: {stagehand_page.url}", category="context" 69 | ) 70 | else: 71 | self.stagehand.logger.debug( 72 | "Stagehand does not have _set_active_page method", category="context" 73 | ) 74 | 75 | def get_active_page(self) -> StagehandPage: 76 | return self.active_stagehand_page 77 | 78 | def register_frame_id(self, frame_id: str, page: StagehandPage): 79 | """Register a frame ID to StagehandPage mapping.""" 80 | self.frame_id_map[frame_id] = page 81 | 82 | def unregister_frame_id(self, frame_id: str): 83 | """Unregister a frame ID from the mapping.""" 84 | if frame_id in self.frame_id_map: 85 | del self.frame_id_map[frame_id] 86 | 87 | def get_stagehand_page_by_frame_id(self, frame_id: str) -> StagehandPage: 88 | """Get StagehandPage by frame ID.""" 89 | return self.frame_id_map.get(frame_id) 90 | 91 | @classmethod 92 | async def init(cls, context: BrowserContext, stagehand): 93 | instance = cls(context, stagehand) 94 | # Pre-initialize StagehandPages for any existing pages 95 | stagehand.logger.debug( 96 | f"Found {len(instance._context.pages)} existing pages", category="context" 97 | ) 98 | for pw_page in instance._context.pages: 99 | await instance.create_stagehand_page(pw_page) 100 | if instance._context.pages: 101 | first_page = instance._context.pages[0] 102 | stagehand_page = await instance.get_stagehand_page(first_page) 103 | instance.set_active_page(stagehand_page) 104 | 105 | # Add event listener for new pages (popups, new tabs from window.open, etc.) 106 | def handle_page_event(pw_page): 107 | # Playwright expects sync handler, so we schedule the async work 108 | asyncio.create_task(instance._handle_new_page(pw_page)) 109 | 110 | context.on("page", handle_page_event) 111 | 112 | return instance 113 | 114 | async def _handle_new_page(self, pw_page: Page): 115 | """ 116 | Handle new pages created by the browser (popups, window.open, etc.). 117 | Uses the page switch lock to prevent race conditions with ongoing operations. 118 | """ 119 | try: 120 | # Use wait_for for Python 3.10 compatibility (timeout prevents indefinite blocking) 121 | async def handle_with_lock(): 122 | async with self.stagehand._page_switch_lock: 123 | self.stagehand.logger.debug( 124 | f"Creating StagehandPage for new page with URL: {pw_page.url}", 125 | category="context", 126 | ) 127 | stagehand_page = await self.create_stagehand_page(pw_page) 128 | self.set_active_page(stagehand_page) 129 | self.stagehand.logger.debug( 130 | "New page detected and initialized", category="context" 131 | ) 132 | 133 | await asyncio.wait_for(handle_with_lock(), timeout=30) 134 | except asyncio.TimeoutError: 135 | self.stagehand.logger.error( 136 | f"Timeout waiting for page switch lock when handling new page: {pw_page.url}", 137 | category="context", 138 | ) 139 | except Exception as e: 140 | self.stagehand.logger.error( 141 | f"Failed to initialize new page: {str(e)}", category="context" 142 | ) 143 | 144 | def __getattr__(self, name): 145 | # Forward attribute lookups to the underlying BrowserContext 146 | attr = getattr(self._context, name) 147 | 148 | # Special handling for methods that return pages 149 | if name == "new_page": 150 | # Replace with our own implementation that wraps the page 151 | async def wrapped_new_page(*args, **kwargs): 152 | pw_page = await self._context.new_page(*args, **kwargs) 153 | stagehand_page = await self.create_stagehand_page(pw_page) 154 | self.set_active_page(stagehand_page) 155 | return stagehand_page 156 | 157 | return wrapped_new_page 158 | elif name == "pages": 159 | 160 | async def wrapped_pages(): 161 | pw_pages = self._context.pages 162 | # Return StagehandPage objects 163 | result = [] 164 | for pw_page in pw_pages: 165 | stagehand_page = await self.get_stagehand_page(pw_page) 166 | result.append(stagehand_page) 167 | return result 168 | 169 | return wrapped_pages 170 | return attr 171 | 172 | async def _attach_frame_navigated_listener( 173 | self, pw_page: Page, stagehand_page: StagehandPage 174 | ): 175 | """ 176 | Attach CDP listener for frame navigation events to track frame IDs. 177 | This mirrors the TypeScript implementation's frame tracking. 178 | """ 179 | try: 180 | # Create CDP session for the page 181 | cdp_session = await self._context.new_cdp_session(pw_page) 182 | await cdp_session.send("Page.enable") 183 | 184 | # Get the current root frame ID 185 | frame_tree = await cdp_session.send("Page.getFrameTree") 186 | root_frame_id = frame_tree.get("frameTree", {}).get("frame", {}).get("id") 187 | 188 | if root_frame_id: 189 | # Initialize the page with its frame ID 190 | stagehand_page.update_root_frame_id(root_frame_id) 191 | self.register_frame_id(root_frame_id, stagehand_page) 192 | 193 | # Set up event listener for frame navigation 194 | def on_frame_navigated(params): 195 | """Handle Page.frameNavigated events""" 196 | frame = params.get("frame", {}) 197 | frame_id = frame.get("id") 198 | parent_id = frame.get("parentId") 199 | 200 | # Only track root frames (no parent) 201 | if not parent_id and frame_id: 202 | # Skip if it's the same frame ID 203 | if frame_id == stagehand_page.frame_id: 204 | return 205 | 206 | # Unregister old frame ID if exists 207 | old_id = stagehand_page.frame_id 208 | if old_id: 209 | self.unregister_frame_id(old_id) 210 | 211 | # Register new frame ID 212 | self.register_frame_id(frame_id, stagehand_page) 213 | stagehand_page.update_root_frame_id(frame_id) 214 | 215 | self.stagehand.logger.debug( 216 | f"Frame navigated from {old_id} to {frame_id}", 217 | category="context", 218 | ) 219 | 220 | # Register the event listener 221 | cdp_session.on("Page.frameNavigated", on_frame_navigated) 222 | 223 | # Clean up frame ID when page closes 224 | def on_page_close(): 225 | if stagehand_page.frame_id: 226 | self.unregister_frame_id(stagehand_page.frame_id) 227 | 228 | pw_page.once("close", on_page_close) 229 | 230 | except Exception as e: 231 | self.stagehand.logger.error( 232 | f"Failed to attach frame navigation listener: {str(e)}", 233 | category="context", 234 | ) 235 | --------------------------------------------------------------------------------