├── .python-version ├── src ├── vectorstore │ ├── __init__.py │ ├── views.py │ ├── base.py │ └── chromadb.py ├── agent │ ├── web │ │ ├── history │ │ │ ├── utils.py │ │ │ ├── views.py │ │ │ └── __init__.py │ │ ├── prompt │ │ │ ├── answer.md │ │ │ ├── action.md │ │ │ ├── observation.md │ │ │ └── system.md │ │ ├── state.py │ │ ├── context │ │ │ ├── script.js │ │ │ ├── views.py │ │ │ ├── config.py │ │ │ └── __init__.py │ │ ├── browser │ │ │ ├── config.py │ │ │ └── __init__.py │ │ ├── utils.py │ │ ├── dom │ │ │ ├── views.py │ │ │ └── __init__.py │ │ ├── tools │ │ │ ├── views.py │ │ │ └── __init__.py │ │ └── __init__.py │ └── __init__.py ├── router │ ├── utils.py │ ├── __init__.py │ └── prompt.md ├── memory │ ├── episodic │ │ ├── utils.py │ │ ├── prompt │ │ │ ├── memory.md │ │ │ ├── add.md │ │ │ ├── replace.md │ │ │ ├── retrieve.md │ │ │ └── update.md │ │ ├── routes.json │ │ ├── views.py │ │ └── __init__.py │ ├── semantic │ │ └── __init__.py │ └── __init__.py ├── tool │ ├── thinking.py │ ├── registry │ │ ├── views.py │ │ └── __init__.py │ └── __init__.py ├── embedding │ ├── __init__.py │ ├── ollama.py │ ├── mistral.py │ └── gemini.py ├── inference │ ├── __init__.py │ ├── anthropic.py │ ├── nvidia.py │ ├── open_router.py │ ├── mistral.py │ ├── ollama.py │ ├── openai.py │ ├── groq.py │ └── gemini.py ├── speech │ └── __init__.py └── message │ └── __init__.py ├── assets └── demo1.mov ├── pyproject.toml ├── main.py ├── setup.py ├── LICENSE ├── README.md ├── .gitignore └── CONTRIBUTING.md /.python-version: -------------------------------------------------------------------------------- 1 | 3.13 2 | -------------------------------------------------------------------------------- /src/vectorstore/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/agent/web/history/utils.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/demo1.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CursorTouch/Web-Navigator/HEAD/assets/demo1.mov -------------------------------------------------------------------------------- /src/router/utils.py: -------------------------------------------------------------------------------- 1 | def read_markdown_file(file_path: str) -> str: 2 | with open(file_path, 'r',encoding='utf-8') as f: 3 | markdown_content = f.read() 4 | return markdown_content -------------------------------------------------------------------------------- /src/memory/episodic/utils.py: -------------------------------------------------------------------------------- 1 | def read_markdown_file(file_path: str) -> str: 2 | with open(file_path, 'r',encoding='utf-8') as f: 3 | markdown_content = f.read() 4 | return markdown_content -------------------------------------------------------------------------------- /src/agent/web/prompt/answer.md: -------------------------------------------------------------------------------- 1 | ```xml 2 | 3 | {evaluate} 4 | {memory} 5 | {thought} 6 | {final_answer} 7 | 8 | ``` -------------------------------------------------------------------------------- /src/agent/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import ABC,abstractmethod 2 | 3 | class BaseAgent(ABC): 4 | @abstractmethod 5 | def invoke(self,input:str): 6 | pass 7 | @abstractmethod 8 | def stream(self,input:str): 9 | pass -------------------------------------------------------------------------------- /src/agent/web/prompt/action.md: -------------------------------------------------------------------------------- 1 | ```xml 2 | 3 | {evaluate} 4 | {memory} 5 | {thought} 6 | {action_name} 7 | 
{action_input} 8 | 9 | ``` -------------------------------------------------------------------------------- /src/vectorstore/views.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass,field 2 | from uuid import uuid4 3 | 4 | @dataclass 5 | class Document: 6 | id:str=field(default_factory=lambda: str(uuid4())) 7 | content:str=field(default_factory=str) 8 | metadata:dict=field(default_factory=dict) -------------------------------------------------------------------------------- /src/memory/semantic/__init__.py: -------------------------------------------------------------------------------- 1 | from src.memory import BaseMemory 2 | 3 | class SemanticMemory(BaseMemory): 4 | def store(self, input: str): 5 | pass 6 | 7 | def retrieve(self, input: str): 8 | pass 9 | 10 | def attach_memory(self)->str: 11 | pass -------------------------------------------------------------------------------- /src/memory/episodic/prompt/memory.md: -------------------------------------------------------------------------------- 1 | ### EPISODIC MEMORIES 2 | Following are the revelant memories found based on the past interactions. Use these memories as context to enhance your thought process. 3 | 4 | {memories} 5 | 6 | NOTE: 7 | - Incorporate past experiences to align with the user's current needs and objectives. 8 | - Use memory as a guide to improve reasoning, maintain consistency, and refine responses. 
9 | -------------------------------------------------------------------------------- /src/agent/web/state.py: -------------------------------------------------------------------------------- 1 | from src.agent.web.context import BrowserState 2 | from src.agent.web.dom.views import DOMState 3 | from typing import TypedDict,Annotated 4 | from src.message import BaseMessage 5 | from operator import add 6 | 7 | class AgentState(TypedDict): 8 | input:str 9 | output:str 10 | agent_data:dict 11 | prev_observation:str 12 | browser_state:BrowserState|None 13 | dom_state:DOMState|None 14 | messages: Annotated[list[BaseMessage],add] -------------------------------------------------------------------------------- /src/tool/thinking.py: -------------------------------------------------------------------------------- 1 | from src.tool import Tool 2 | from pydantic import BaseModel,Field 3 | 4 | class Thinking(BaseModel): 5 | thought: str=Field(...,description="Your extended thinking goes here") 6 | 7 | @Tool('Thinking Tool',params=Thinking) 8 | async def thinking_tool(thought:str,context=None): 9 | ''' 10 | To think about something. It will not obtain new information or make any changes, but just log the thought. 11 | Use it when complex reasoning or brainstorming is needed. 
12 | ''' 13 | return thought -------------------------------------------------------------------------------- /src/tool/registry/views.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel,Field,ConfigDict 2 | from typing import Callable,Type 3 | 4 | class Function(BaseModel): 5 | name:str=Field(...,description="the name of the action") 6 | description:str=Field(...,description="the description of the action") 7 | params:Type[BaseModel]|None 8 | function:Callable|None 9 | model_config=ConfigDict(arbitrary_types_allowed=True) 10 | 11 | class ToolResult(BaseModel): 12 | name: str = Field(...,description="the action taken") 13 | content: str = Field(...,description="the output of the action") -------------------------------------------------------------------------------- /src/memory/episodic/prompt/add.md: -------------------------------------------------------------------------------- 1 | You are asked to generate episodic memory by analyzing conversations to extract key elements for guiding future interactions. Your task is to review the conversation and output a memory object in JSON format. Follow these guidelines: 2 | 3 | 1. Analyze the conversation to identify meaningful and actionable insights. 4 | 2. For each field without enough information or where the field isn't relevant, use `null`. 5 | 3. Be concise and ensure each string is clear and actionable. 6 | 4. Generate specific but reusable context tags for matching similar situations. 7 | 5. Your response should only have one json object. 
-------------------------------------------------------------------------------- /src/agent/web/context/script.js: -------------------------------------------------------------------------------- 1 | Object.defineProperty(navigator, 'webdriver', { 2 | get: () => undefined 3 | }); 4 | 5 | Object.defineProperty(navigator, 'languages', { 6 | get: () => ['en-US', 'en'] 7 | }); 8 | 9 | Object.defineProperty(navigator, 'plugins', { 10 | get: () => [1, 2, 3, 4, 5], 11 | }); 12 | 13 | window.chrome = { runtime: {} }; 14 | 15 | const originalQuery = window.navigator.permissions.query; 16 | window.navigator.permissions.query = (parameters) => ( 17 | parameters.name === 'notifications' ? 18 | Promise.resolve({ state: Notification.permission }) : 19 | originalQuery(parameters) 20 | ); -------------------------------------------------------------------------------- /src/memory/episodic/prompt/replace.md: -------------------------------------------------------------------------------- 1 | You are asked to generate episodic memory by analyzing conversations to extract key elements for guiding future interactions. Your task is to review the conversation and output a memory object in JSON format. Follow these guidelines: 2 | 3 | 1. Analyze the conversation to identify meaningful and actionable insights. 4 | 2. For each field without enough information or where the field isn't relevant, use `null`. 5 | 3. Be concise and ensure each string is clear and actionable. 6 | 4. Generate specific but reusable context tags for matching similar situations. 7 | 5. Your response should only have one json object. -------------------------------------------------------------------------------- /src/memory/episodic/prompt/retrieve.md: -------------------------------------------------------------------------------- 1 | You are asked to retrieve revelant episodic memories to assist in the current user query. 2 | ### Follow these instructions: 3 | 1. 
You will be provided with a set of past memories in JSON format. 4 | 2. The user will provide a query, and your goal is to identify and retrieve the most relevant memories that could assist the user in achieving their goal. 5 | 3. Include memories that align similar with the context or goal of the user's query (The underlying methodology is similar). 6 | 4. Output the selected memories as a JSON array. Each memory object should retain its original structure and data. 7 | 8 | ### Memories 9 | {memories} 10 | -------------------------------------------------------------------------------- /src/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | from chromadb import Documents,EmbeddingFunction,Embeddings 2 | from abc import ABC,abstractmethod 3 | 4 | class BaseEmbedding(ABC,EmbeddingFunction): 5 | def __init__(self,model:str='',api_key:str='',base_url:str=''): 6 | self.name=self.__class__.__name__.replace('Embedding','') 7 | self.api_key=api_key 8 | self.model=model 9 | self.base_url=base_url 10 | self.headers={'Content-Type': 'application/json'} 11 | 12 | def __call__(self, input:Documents)->Embeddings: 13 | return self.embed(input) 14 | 15 | @abstractmethod 16 | def embed(self,text:list[str]|str)->list: 17 | pass -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "Web-Navigator" 3 | version = "0.1.0" 4 | description = "Web-Navigator is an ai agent for web browsing and scraping websites." 
5 | readme = "README.md" 6 | license = { file = "LICENSE" } 7 | requires-python = ">=3.13" 8 | dependencies = [ 9 | "httpx>=0.28.1", 10 | "keyboard>=0.13.5", 11 | "langgraph>=0.6.7", 12 | "markdownify>=1.2.0", 13 | "nest-asyncio>=1.6.0", 14 | "playwright>=1.55.0", 15 | "pyaudio>=0.2.14", 16 | "pydantic>=2.11.9", 17 | "pyperclip>=1.10.0", 18 | "python-dotenv>=1.1.1", 19 | "ratelimit>=2.2.1", 20 | "requests>=2.32.5", 21 | "rich>=14.1.0", 22 | "tenacity>=9.1.2", 23 | "termcolor>=3.1.0", 24 | ] 25 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from src.agent.web.browser.config import BrowserConfig 2 | from src.inference.gemini import ChatGemini 3 | from src.agent.web import Agent 4 | from dotenv import load_dotenv 5 | import os 6 | 7 | load_dotenv() 8 | 9 | api_key = os.getenv('GOOGLE_API_KEY') 10 | browser_instance_dir = os.getenv('BROWSER_INSTANCE_DIR') 11 | user_data_dir = os.getenv('USER_DATA_DIR') 12 | 13 | llm=ChatGemini(model='gemini-2.5-flash-lite',api_key=api_key,temperature=0) 14 | config=BrowserConfig(browser='edge',browser_instance_dir=browser_instance_dir,user_data_dir=user_data_dir,headless=False) 15 | 16 | agent=Agent(config=config,llm=llm,verbose=True,use_vision=True,max_iteration=100) 17 | user_query = input('Enter your query: ') 18 | agent.print_response(user_query) -------------------------------------------------------------------------------- /src/agent/web/history/views.py: -------------------------------------------------------------------------------- 1 | from src.agent.web.dom.views import BoundingBox,CenterCord 2 | from pydantic import BaseModel,Field 3 | from dataclasses import dataclass 4 | 5 | class DOMHistoryElementNode(BaseModel): 6 | tag:str 7 | role:str 8 | name:str 9 | center:CenterCord 10 | bounding_box:BoundingBox 11 | xpath:dict[str,str]=Field(default_factory=dict) 12 | 
attributes:dict[str,str]=Field(default_factory=dict) 13 | viewport:tuple[int,int]=Field(default_factory=tuple) 14 | 15 | def to_dict(self)->dict[str,str]: 16 | return {'tag':self.tag,'role':self.role,'xpath':self.xpath,'attributes':self.attributes,'bounding_box':self.bounding_box.to_dict()} 17 | 18 | @dataclass 19 | class HashElement: 20 | attributes:str 21 | xpath:str -------------------------------------------------------------------------------- /src/memory/episodic/prompt/update.md: -------------------------------------------------------------------------------- 1 | You are a memory updater tasked with refining and enhancing episodic memories based on new insights from the current conversation. Your role is to analyze the provided relevant memories and the current conversation, updating them to incorporate the new information while preserving their original purpose and clarity. Follow these rules: 2 | 3 | 1. Only update the provided relevant memories with information from the current conversation that adds value or clarity. 4 | 2. Ensure updates are concise, actionable, and maintain the structure of the original memory. 5 | 3. If a field becomes irrelevant or lacks enough information after the update, set it to `null`. 6 | 4. Output all updated memories as a valid JSON array, preserving the format of the input memories. -------------------------------------------------------------------------------- /src/memory/episodic/routes.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "name": "ADD", 4 | "description": "Add the converstion to the memory because relevant memories is empty." 5 | }, 6 | { 7 | "name": "UPDATE", 8 | "description": "Update the relevant memories with information from the conversation if there is new information or approach that can be added to the existing relevant memories." 
9 | }, 10 | { 11 | "name": "IDLE", 12 | "description": "Do nothing as the conversation is redundant when compared to the relevant memories in the knowledge base." 13 | }, 14 | { 15 | "name": "REPLACE", 16 | "description": "Replace the relevant memories because it is less valuable, and the conversation provides more significant insights or conversation is a combine of multiple relevant memories." 17 | } 18 | ] -------------------------------------------------------------------------------- /src/agent/web/context/views.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass,field 2 | from playwright.async_api import Page,BrowserContext as PlaywrightBrowserContext 3 | from src.agent.web.dom.views import DOMState 4 | from typing import Optional 5 | 6 | @dataclass 7 | class Tab: 8 | id:int 9 | url:str 10 | title:str 11 | page:Page 12 | 13 | def to_string(self)->str: 14 | return f'{self.id} - Title: {self.title} - URL: {self.url}' 15 | 16 | @dataclass 17 | class BrowserState: 18 | current_tab:Optional[Tab]=None 19 | tabs:list[Tab]=field(default_factory=list) 20 | screenshot:Optional[str]=None 21 | dom_state:DOMState=field(default_factory=DOMState([])) 22 | 23 | def tabs_to_string(self)->str: 24 | return '\n'.join([tab.to_string() for tab in self.tabs]) 25 | 26 | @dataclass 27 | class BrowserSession: 28 | context: PlaywrightBrowserContext 29 | current_page: Page 30 | state: BrowserState -------------------------------------------------------------------------------- /src/embedding/ollama.py: -------------------------------------------------------------------------------- 1 | from requests import RequestException,HTTPError,ConnectionError 2 | from src.embedding import BaseEmbedding 3 | from httpx import Client 4 | from typing import Literal 5 | 6 | class OllamaEmbedding(BaseEmbedding): 7 | def embed(self, text): 8 | url=self.base_url or f'http://localhost:11434/api/embed' 9 | headers=self.headers 10 | 
payload={ 11 | 'model':self.model, 12 | 'input':text 13 | } 14 | try: 15 | with Client() as client: 16 | response=client.post(url=url,json=payload,headers=headers) 17 | response.raise_for_status() 18 | return response.json()['embeddings'][0] 19 | except HTTPError as err: 20 | print(f'Error: {err.response.text}, Status Code: {err.response.status_code}') 21 | except ConnectionError as err: 22 | print(err) -------------------------------------------------------------------------------- /src/agent/web/prompt/observation.md: -------------------------------------------------------------------------------- 1 | ```xml 2 | 3 | 4 | Current Step: {iteration} 5 | 6 | Max. Steps: {max_iteration} 7 | 8 | Action Response: {observation} 9 | 10 | 11 | [Begin of Tab Info] 12 | Current Tab: {current_tab} 13 | 14 | Open Tabs: 15 | {tabs} 16 | [End of Tab Info] 17 | 18 | [Begin of Viewport] 19 | List of Interactive Elements: 20 | {interactive_elements} 21 | 22 | List of Scrollable Elements: 23 | {scrollable_elements} 24 | 25 | List of Informative Elements: 26 | {informative_elements} 27 | [End of Viewport] 28 | 29 | {query} 30 | 31 | 32 | Note: Use the `Done Tool` if the task is completely over else continue solving. 
33 | 34 | 35 | ``` -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open('./docs/README.md') as f: 4 | long_description = f.read() 5 | setup( 6 | name='Web-Agent', 7 | version='0.1', 8 | description='The web agent is designed to automate the process of gathering information from the internet, such as to navigate websites, perform searches, and retrieve data.', 9 | author='jeomon', 10 | author_email='jeogeoalukka@gmail', 11 | url='https://github.com/Jeomon/Web-Agent', 12 | packages=find_packages(), 13 | install_requires=[ 14 | 'langgraph', 15 | 'tenacity' 16 | 'requests' 17 | 'playwright' 18 | 'termcolor' 19 | 'python-dotenv' 20 | 'httpx' 21 | 'nest_asyncio' 22 | 'MainContentExtractor' 23 | ], 24 | entry_points={ 25 | 'console_scripts': [ 26 | 'web-agent=main:main' 27 | ] 28 | }, 29 | long_description=long_description, 30 | long_description_content_type='text/markdown', 31 | license='MIT' 32 | ) -------------------------------------------------------------------------------- /src/agent/web/history/__init__.py: -------------------------------------------------------------------------------- 1 | from src.agent.web.history.views import DOMHistoryElementNode, HashElement 2 | from src.agent.web.dom.views import DOMElementNode 3 | from hashlib import sha256 4 | 5 | class History: 6 | 7 | def convert_dom_element_to_history_element(self,element:DOMElementNode)->DOMHistoryElementNode: 8 | return DOMHistoryElementNode(**element.to_dict()) 9 | 10 | def compare_dom_element_with_history_element(self,element:DOMElementNode,history_element:DOMHistoryElementNode)->bool: 11 | hash_dom_element=self.hash_element(element) 12 | hash_history_element=self.hash_element(history_element) 13 | return hash_dom_element==hash_history_element 14 | 15 | def 
hash_element(self,element:DOMElementNode|DOMHistoryElementNode): 16 | element:dict=element.to_dict() 17 | attributes=sha256(str(element.get('attributes')).encode()).hexdigest() 18 | xpath=sha256(str(element.get('xpath')).encode()).hexdigest() 19 | return HashElement(attributes=attributes,xpath=xpath) 20 | -------------------------------------------------------------------------------- /src/embedding/mistral.py: -------------------------------------------------------------------------------- 1 | from src.embedding import BaseEmbedding 2 | from httpx import Client 3 | from typing import Literal 4 | from requests import RequestException,HTTPError,ConnectionError 5 | import json 6 | 7 | class MistralEmbedding(BaseEmbedding): 8 | def embed(self, text): 9 | url=self.base_url or 'https://api.mistral.ai/v1/embeddings' 10 | self.headers['Authorization'] = f'Bearer {self.api_key}' 11 | headers=self.headers 12 | payload={ 13 | 'model':self.model, 14 | 'input':text, 15 | 'encoding_format':'float' 16 | } 17 | try: 18 | with Client() as client: 19 | response=client.post(url=url,json=payload,headers=headers) 20 | response.raise_for_status() 21 | return response.json()['data']['embedding'] 22 | except HTTPError as err: 23 | print(f'Error: {err.response.text}, Status Code: {err.response.status_code}') 24 | except ConnectionError as err: 25 | print(err) 26 | 27 | -------------------------------------------------------------------------------- /src/vectorstore/base.py: -------------------------------------------------------------------------------- 1 | from src.vectorstore.views import Document 2 | from abc import ABC,abstractmethod 3 | from typing import Sequence 4 | 5 | class BaseVectorStore(ABC): 6 | @abstractmethod 7 | def create_collection(self,collection_name:str)->None: 8 | pass 9 | 10 | @abstractmethod 11 | def insert(self,documents:list[Document])->None: 12 | pass 13 | 14 | @abstractmethod 15 | def search(self,query:str,k:int)->list[Document]: 16 | pass 17 | 18 | 
@abstractmethod 19 | def delete(self,collection_name:str)->None: 20 | pass 21 | 22 | @abstractmethod 23 | def update(self,id:str,content:str,metadata:dict)->None: 24 | pass 25 | 26 | @abstractmethod 27 | def delete_collection(self,collection_name:str)->None: 28 | pass 29 | 30 | @abstractmethod 31 | def get(self,id:str)->Document: 32 | pass 33 | 34 | @abstractmethod 35 | def all(self)->list[Document]: 36 | pass 37 | 38 | @abstractmethod 39 | def all_collections(self)->Sequence: 40 | pass 41 | 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 CursorTouch 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /src/router/__init__.py: -------------------------------------------------------------------------------- 1 | from src.message import HumanMessage,SystemMessage 2 | from src.router.utils import read_markdown_file 3 | from src.inference import BaseInference 4 | from json import dumps 5 | 6 | class LLMRouter: 7 | def __init__(self,instructions:list[str]=[],routes:list[dict]=[],llm:BaseInference=None,verbose=False): 8 | self.system_prompt=read_markdown_file('./src/router/prompt.md') 9 | self.instructions=self.__get_instructions(instructions) 10 | self.routes=dumps(routes,indent=2) 11 | self.llm=llm 12 | self.verbose=verbose 13 | 14 | def __get_instructions(self,instructions): 15 | return '\n'.join([f'{i+1}. {instruction}' for i,instruction in enumerate(instructions)]) 16 | 17 | def invoke(self,query:str)->str: 18 | parameters={'instructions':self.instructions,'routes':self.routes} 19 | messages=[SystemMessage(self.system_prompt.format(**parameters)),HumanMessage(query)] 20 | response=self.llm.invoke(messages,json=True) 21 | route=response.content.get('route') 22 | if self.verbose: 23 | print(f"Going to {route.upper()} route") 24 | return route 25 | -------------------------------------------------------------------------------- /src/agent/web/browser/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal 3 | from pathlib import Path 4 | 5 | @dataclass 6 | class BrowserConfig: 7 | headless:bool=False 8 | wss_url:str=None 9 | device:str=None 10 | browser_instance_dir:str=None 11 | downloads_dir:str=(Path.home()/'Downloads').as_posix() 12 | browser:Literal['chrome','firefox','edge']='edge' 13 | user_data_dir:str=None 14 | timeout:int=60*1000 15 | slow_mo:int=300 16 | 17 | SECURITY_ARGS = [ 18 | '--disable-web-security', 19 | '--disable-site-isolation-trials', 20 | 
'--disable-features=IsolateOrigins,site-per-process', 21 | ] 22 | 23 | BROWSER_ARGS=[ 24 | '--disable-sandbox', 25 | '--enable-blink-features=IdleDetection', 26 | '--disable-blink-features=AutomationControlled', 27 | '--disable-infobars', 28 | '--disable-background-timer-throttling', 29 | '--disable-popup-blocking', 30 | '--disable-backgrounding-occluded-windows', 31 | '--disable-renderer-backgrounding', 32 | '--disable-window-activation', 33 | '--disable-focus-on-load', 34 | '--no-first-run', 35 | '--no-default-browser-check', 36 | '--no-startup-window', 37 | '--window-position=0,0', 38 | '--remote-debugging-port=9222' 39 | ] 40 | 41 | IGNORE_DEFAULT_ARGS=['--enable-automation'] -------------------------------------------------------------------------------- /src/memory/episodic/views.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel,Field 2 | from uuid import uuid4 3 | 4 | class Memory(BaseModel): 5 | id: str = Field(description='The id of the memory to be filled by the user',examples=['32cd40b6-5db1-48b3-9434-ebd9502f06f0'],default_factory=lambda: str(uuid4())) 6 | tags: list[str] = Field(...,description='Tags to help identify similar future conversations.',examples=[['google','weather']]) 7 | summary: str = Field(...,description='Describes what the conversation accomplished.') 8 | what_worked: str = Field(...,description='Highlights the most effective strategy used.') 9 | what_to_avoid: str = Field(...,description='Describes the important pitfalls to avoid.') 10 | 11 | def to_dict(self): 12 | return self.model_dump() 13 | 14 | class Config: 15 | extra = 'allow' 16 | 17 | class Memories(BaseModel): 18 | memories: list[Memory]=Field(description='The list of memories',default_factory=list) 19 | 20 | def model_dump(self, *args, **kwargs): 21 | return super().model_dump(*args, **kwargs)["memories"] 22 | 23 | def all(self): 24 | return [memory.to_dict() for memory in self.memories] 25 | 26 | def 
to_string(self): 27 | return '\n\n'.join([f'**Tags:** {memory.tags}\n***Summary:** {memory.summary}\n**What Worked:** {memory.what_worked}\n**What to Avoid:** {memory.what_to_avoid}' for memory in self.memories]) 28 | 29 | -------------------------------------------------------------------------------- /src/memory/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import ABC,abstractmethod 2 | from src.inference import BaseInference 3 | from src.message import BaseMessage,SystemMessage 4 | import json 5 | import os 6 | 7 | class BaseMemory(ABC): 8 | def __init__(self,knowledge_base:str='knowledge_base.json',llm:BaseInference=None,verbose=False): 9 | self.llm=llm 10 | self.verbose=verbose 11 | self.knowledge_base=knowledge_base 12 | self.__initialize_memory() 13 | 14 | @abstractmethod 15 | def store(self,conversation:list[BaseMessage])->None: 16 | pass 17 | 18 | @abstractmethod 19 | def retrieve(self,query:str)->list[dict]: 20 | pass 21 | 22 | @abstractmethod 23 | def attach_memory(self)->str: 24 | pass 25 | 26 | def __initialize_memory(self): 27 | if not os.path.exists(f'./memory_data/{self.knowledge_base}'): 28 | os.makedirs('./memory_data',exist_ok=True) 29 | with open(f'./memory_data/{self.knowledge_base}','w') as f: 30 | f.write(json.dumps(self.memories,indent=2)) 31 | else: 32 | with open(f'./memory_data/{self.knowledge_base}','r') as f: 33 | self.memories=json.loads(f.read()) 34 | def conversation_to_text(self,conversation:list[BaseMessage]): 35 | conversation=list(self.__filter_conversation(conversation)) 36 | return '\n'.join([f'{message.role}: {message.content}' for message in conversation]) 37 | def __filter_conversation(self,conversation:list[BaseMessage]): 38 | return filter(lambda message: not isinstance(message,SystemMessage),conversation) -------------------------------------------------------------------------------- /src/inference/__init__.py: 
-------------------------------------------------------------------------------- 1 | from src.message import AIMessage,SystemMessage 2 | from abc import ABC,abstractmethod 3 | from pydantic import BaseModel 4 | from typing import Optional 5 | from src.tool import Tool 6 | 7 | class Token(BaseModel): 8 | input: Optional[int]=None 9 | output: Optional[int]=None 10 | cache: Optional[int]=None 11 | total: Optional[int]=None 12 | 13 | structured_output_prompt=''' 14 | Integrate the JSON output as part of the structured response, ensuring it strictly follows the provided schema. 15 | ```json 16 | {json_schema} 17 | ``` 18 | Validate all fields, use `null` or empty values for missing data, and format the JSON in a clear, indented code block. 19 | ''' 20 | 21 | class BaseInference(ABC): 22 | def __init__(self,model:str,api_key:str='',base_url:str='',tools:list[Tool]=[],temperature:float=0.5): 23 | self.model=model 24 | self.api_key=api_key 25 | self.base_url=base_url 26 | self.tools=tools 27 | self.temperature=temperature 28 | self.headers={'Content-Type': 'application/json'} 29 | self.structured_output_prompt=structured_output_prompt 30 | self.tokens:Token=Token(input=0,output=0,total=0) 31 | 32 | @abstractmethod 33 | def invoke(self,messages:list[dict],json:bool=False,model:BaseModel=None)->AIMessage|BaseModel: 34 | pass 35 | 36 | @abstractmethod 37 | async def async_invoke(self,messages:list[dict],json:bool=False,model:BaseModel=None)->AIMessage|BaseModel: 38 | pass 39 | 40 | @abstractmethod 41 | def stream(self,messages:list[dict],json:bool=False)->AIMessage: 42 | pass 43 | 44 | def structured(self,message:SystemMessage,model:BaseModel): 45 | return f'{message.content}\n{structured_output_prompt.format(json_schema=model.model_json_schema())}' 46 | 47 | -------------------------------------------------------------------------------- /src/agent/web/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import ast 3 | 
import re
import ast
import json

def read_markdown_file(file_path: str) -> str:
    """Read a UTF-8 markdown file and return its full text."""
    with open(file_path, 'r', encoding='utf-8') as f:
        return f.read()

def extract_agent_data(text):
    """Parse the agent's XML-tagged response into a dictionary.

    Recognized sections: <Memory>, <Evaluate>, <Thought>, <Action-Name>
    and <Action-Input>. Missing sections are simply absent from the result.
    <Action-Input> is additionally converted to a Python object when it holds
    a valid Python/JSON literal; otherwise the raw string is kept.
    """
    result = {}
    # Output key -> XML tag wrapping that section. Note: each pattern matches
    # content BETWEEN an opening and closing tag (the original patterns lacked
    # the opening tag and captured everything from the start of the text).
    sections = {
        'Memory': 'Memory',
        'Evaluate': 'Evaluate',
        'Thought': 'Thought',
        'Action Name': 'Action-Name',
    }
    for key, tag in sections.items():
        match = re.search(rf"<{tag}>(.*?)</{tag}>", text, re.DOTALL)
        if match:
            result[key] = match.group(1).strip()
    action_input_match = re.search(r"<Action-Input>(.*?)</Action-Input>", text, re.DOTALL)
    if action_input_match:
        result['Action Input'] = _parse_action_input(action_input_match.group(1).strip())
    return result

def _parse_action_input(raw: str):
    """Best-effort conversion of the <Action-Input> payload to a Python object.

    Tries, in order: a Python literal, strict JSON, then a Python literal with
    JSON keywords substituted. The substitution pass runs LAST because a blind
    null/true/false replace can corrupt string values containing those words.
    """
    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        pass
    try:
        return json.loads(raw)
    except (ValueError, TypeError):
        pass
    try:
        return ast.literal_eval(
            raw.replace('null', 'None').replace('true', 'True').replace('false', 'False')
        )
    except (ValueError, SyntaxError):
        # Unparseable: preserve the raw string so the caller can report it.
        return raw
Your primary task is to analyze the query, reason about its complexity, and map it to the most appropriate route from the available routes. 3 | 4 | **Instructions (optional):** 5 | {instructions} 6 | 7 | **Available Routes:** 8 | {routes} 9 | 10 | --- 11 | 12 | ### **Enhanced Reasoning and Decision-Making Process**: 13 | 0. **Instructions Priority**: If instructions are provided, they must be given top priority. Always refer to the instructions before making any decisions. 14 | 15 | 1. **Thorough Query Understanding**: Analyze the query to capture nuances, objectives, and any hidden complexities. 16 | 17 | 2. **Route Comparison**: Use detailed reasoning to compare the query against the available route descriptions. Ensure you consider both simple and advanced requirements within the query. 18 | 19 | 3. **Contextual Mapping**: Factor in the user’s intention and potential requirements (e.g., tool access, complex reasoning, or multiple steps) before choosing a route. 20 | 21 | 4. **Complex Scenario Handling**: In cases of ambiguity or complex queries, apply advanced reasoning by weighing potential routes based on their descriptions and the needs of the query. 22 | 23 | 5. **Judgment Enhancement**: Use a higher level of reasoning to ensure that no errors in routing occur, especially in situations where the query is multifaceted or when multiple routes seem plausible. 24 | 25 | 6. **Avoid Redundancy**: Avoid mapping the query to multiple routes unless explicitly stated, ensuring that tasks are distinct and non-overlapping. 26 | 27 | 7. **Confidence and Precision**: Always make confident and precise routing decisions, ensuring that the final output matches the query's requirements perfectly. 
from src.vectorstore.base import BaseVectorStore,Document
from src.embedding import BaseEmbedding
from chromadb.config import Settings
from chromadb import Client
from pathlib import Path

class ChromaDBVectorStore(BaseVectorStore):
    """Vector store backed by a persistent on-disk ChromaDB collection."""

    def __init__(self,collection_name:str,embedding:BaseEmbedding,path:Path=Path.cwd()/'.chroma'):
        # Persist everything under `path`; telemetry stays off.
        self.settings=Settings(anonymized_telemetry=False,is_persistent=True,persist_directory=path.as_posix())
        self.client=Client(settings=self.settings)
        self.db=self.create_collection(collection_name=collection_name,embedding=embedding)

    def create_collection(self,collection_name:str,embedding:BaseEmbedding=None):
        """Fetch the named collection, creating it on first use."""
        return self.client.get_or_create_collection(name=collection_name,embedding_function=embedding)

    def insert(self,documents:list[Document]):
        """Add a batch of documents (ids, contents and metadata) in one call."""
        self.db.add(
            ids=[document.id for document in documents],
            documents=[document.content for document in documents],
            metadatas=[document.metadata for document in documents],
        )

    def search(self, query:str, k=5):
        """Return the raw ChromaDB result for the k nearest neighbours of `query`."""
        return self.db.query(query_texts=[query],n_results=k)

    def update(self, id:str, content:str, metadata:dict):
        """Overwrite content and metadata of the document with this id."""
        self.db.update(ids=[id],documents=[content],metadatas=[metadata])

    def delete(self, id):
        """Remove the document with this id from the collection."""
        self.db.delete(ids=[id])

    def get(self, id):
        """Fetch a single document by id, parsed into Document objects."""
        return self.parse_db_response(self.db.get(ids=[id]))

    def delete_collection(self, collection_name):
        """Drop an entire collection by name."""
        self.client.delete_collection(name=collection_name)

    def all_collections(self):
        """List every collection known to this client."""
        return self.client.list_collections()

    def all(self):
        """Return every document in the current collection."""
        return self.parse_db_response(self.db.get())

    def parse_db_response(self, response: dict) -> list[Document]:
        """Convert a raw ChromaDB response mapping into Document objects.

        The documents/metadatas arrays may be shorter than ids; missing
        entries fall back to '' and {} respectively.
        """
        ids = response.get('ids', [])
        documents = response.get('documents', [])
        metadatas = response.get('metadatas', [])
        parsed = []
        for position, doc_id in enumerate(ids):
            content = documents[position] if position < len(documents) else ''
            metadata = metadatas[position] if position < len(metadatas) else {}
            parsed.append(Document(id=doc_id, content=content, metadata=metadata))
        return parsed
# Playwright resource types the agent cares about; everything else can be
# skipped while navigating. Kept as a list to preserve the original type
# and ordering for any caller that iterates it.
RELEVANT_RESOURCE_TYPES = [
    'document',
    'stylesheet',
    'image',
    'font',
    'script',
    'iframe',
]

# URL substrings identifying analytics / ads / chat-widget / streaming traffic
# that is safe to ignore during navigation. A set literal replaces the
# redundant set([...]) wrapper while keeping O(1) membership tests.
IGNORED_URL_PATTERNS = {
    'analytics',
    'tracking',
    'telemetry',
    'googletagmanager',
    'beacon',
    'metrics',
    'doubleclick',
    'adsystem',
    'adserver',
    'advertising',
    'cdn.optimizely',
    'facebook.com/plugins',
    'platform.twitter',
    'linkedin.com/embed',
    'livechat',
    'zendesk',
    'intercom',
    'crisp.chat',
    'hotjar',
    'push-notifications',
    'onesignal',
    'pushwoosh',
    'heartbeat',
    'ping',
    'alive',
    'webrtc',
    'rtmp://',
    'wss://',
    'cloudfront.net',
    'fastly.net',
}
    def setup_stream(self):
        """Open a PyAudio input stream configured for 16-bit mono capture."""
        audio=self.audio
        self.stream=audio.open(**{
            'format':paInt16,
            'channels':self.channels,
            'rate':self.frame_rate,
            'input':True,
            'frames_per_buffer':self.chunk_size
        })

    def get_stream(self)->Stream:
        """Return the open input stream, lazily creating it on first use."""
        if self.stream is None:
            self.setup_stream()
        return self.stream

    def record_audio(self)->bytes:
        """Record microphone audio until the user presses Enter.

        Blocks the calling thread. Returns the raw PCM frames joined into a
        single bytes object (no WAV header — see bytes_to_tempfile).
        """
        stream=self.get_stream()
        frames=[]
        is_recording=True
        print('Recording audio...')
        print('Press enter to stop recording...')
        while is_recording:
            # One buffer at a time so the stop key is polled frequently.
            data=stream.read(self.chunk_size)
            frames.append(data)
            if keyboard.is_pressed('enter'):
                is_recording=False
        stream.stop_stream()
        print('Recording stopped...')
        return b''.join(frames)
{e}") 61 | 62 | def close(self): 63 | if self.stream is not None: 64 | self.stream.close() 65 | if self.audio is not None: 66 | self.audio.terminate() 67 | self.stream=None 68 | self.audio=None 69 | os.remove(self.tempfile_path) 70 | 71 | def invoke(self)->AIMessage: 72 | audio_bytes=self.record_audio() 73 | self.bytes_to_tempfile(audio_bytes) 74 | print(f'Using {self.llm.model} audio to text...') 75 | response=self.llm.invoke(file_path=self.tempfile_path) 76 | self.close() 77 | return response 78 | 79 | -------------------------------------------------------------------------------- /src/message/__init__.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | from abc import ABC 3 | import requests 4 | import base64 5 | import re 6 | 7 | class BaseMessage(ABC): 8 | def to_dict(self)->dict[str,str]: 9 | return { 10 | 'role': self.role, 11 | 'content': f'''{self.content}''' 12 | } 13 | def __repr__(self): 14 | class_name = self.__class__.__name__ 15 | attributes = ", ".join(f"{key}={value}" for key, value in self.__dict__.items()) 16 | return f"{class_name}({attributes})" 17 | 18 | class HumanMessage(BaseMessage): 19 | def __init__(self,content): 20 | self.role='user' 21 | self.content=content 22 | 23 | class AIMessage(BaseMessage): 24 | def __init__(self,content): 25 | self.role='assistant' 26 | self.content=content 27 | 28 | class SystemMessage(BaseMessage): 29 | def __init__(self,content): 30 | self.role='system' 31 | self.content=content 32 | 33 | class ImageMessage(BaseMessage): 34 | def __init__(self,text:str=None,image_path:str=None,image_obj:str=None): 35 | self.role='user' 36 | if image_obj is not None or image_path is None: 37 | self.content=(text,self.__encoder(image_obj)) 38 | elif image_path is not None or image_obj is None: 39 | self.content=(text,self.__image_to_base64(image_path)) 40 | else: 41 | raise Exception('image_path and image_base_64 cannot be both None or both not None') 42 | 43 | 
def __is_url(self,image_path:str)->bool: 44 | url_pattern = re.compile(r'^https?://') 45 | return url_pattern.match(image_path) is not None 46 | 47 | def __is_file_path(self,image_path:str)->bool: 48 | file_path_pattern = re.compile(r'^([./~]|([a-zA-Z]:)|\\|//)?\.?\/?[a-zA-Z0-9._-]+(\.[a-zA-Z0-9]+)?$') 49 | return file_path_pattern.match(image_path) is not None 50 | 51 | def __image_to_base64(self,image_source: str) -> str: 52 | if self.__is_url(image_source): 53 | response = requests.get(image_source) 54 | bytes = BytesIO(response.content) 55 | image_bytes = bytes.read() 56 | elif self.__is_file_path(image_source): 57 | with open(image_source, 'rb') as image: 58 | image_bytes = image.read() 59 | else: 60 | raise ValueError("Invalid image source. Must be a URL or file path.") 61 | return base64.b64encode(image_bytes).decode('utf-8') 62 | 63 | def __encoder(self,b:bytes): 64 | return base64.b64encode(b).decode('utf-8') 65 | 66 | class ToolMessage(BaseMessage): 67 | def __init__(self,id:str,name:str,args:dict): 68 | self.id=id 69 | self.role='tool' 70 | self.name=name 71 | self.args=args -------------------------------------------------------------------------------- /src/tool/__init__.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from typing import Optional,Callable 3 | from inspect import getdoc 4 | from json import dumps 5 | 6 | class Tool: 7 | def __init__(self, name: str='',description: Optional[str]=None, params: Optional[BaseModel]=None,schema:Optional[dict]=None,func:Optional[Callable]=None): 8 | self.name = name 9 | self.params = params 10 | self.func = func 11 | self.description = description 12 | self.schema = schema 13 | 14 | def __call__(self, func): 15 | if self.params: 16 | # Store the decorated function and its metadata 17 | self.description = self.description or getdoc(func) 18 | skip_keys=['title'] 19 | if self.params is not None: 20 | self.schema = {k:{term:content for 
from typing import Optional, Callable, TYPE_CHECKING
from inspect import getdoc
from json import dumps

if TYPE_CHECKING:
    # BaseModel is referenced only in annotations; avoid a hard runtime import.
    from pydantic import BaseModel

class Tool:
    """Decorator-style wrapper that turns a function into an agent-callable tool.

    Usage: ``tool = Tool(name=..., params=Model)(func)``; metadata (description
    and a trimmed JSON schema) is derived from the params model or an explicit
    schema when the function is registered.
    """

    def __init__(self, name: str = '', description: Optional[str] = None, params: Optional['BaseModel'] = None, schema: Optional[dict] = None, func: Optional[Callable] = None):
        self.name = name
        self.params = params
        self.func = func
        self.description = description
        self.schema = schema

    def __call__(self, func):
        """Register ``func`` as the tool implementation and derive metadata.

        Fix: the original wrapped everything in ``if self.params:``, which
        made the ``elif self.schema`` branch unreachable and left schema-only
        tools without a description.
        """
        self.description = self.description or getdoc(func)
        skip_keys = ['title']  # pydantic adds a 'title' per property; drop it
        if self.params is not None:
            self.schema = {k: {term: content for term, content in v.items() if term not in skip_keys}
                           for k, v in self.params.model_json_schema().get('properties').items() if k not in skip_keys}
        elif self.schema is not None:
            self.schema = {k: {term: content for term, content in v.items() if term not in skip_keys}
                           for k, v in self.schema.get('properties').items() if k not in skip_keys}
        self.func = func
        return self  # the Tool instance replaces the decorated function

    def invoke(self, **kwargs):
        """Validate kwargs against the params model (if any) and call the tool.

        Returns the function result, or an "Error: ..." string on failure.
        """
        try:
            if self.params:
                args = self.params(**kwargs)
                # pydantic v2 API, consistent with model_json_schema above
                # (the original used the deprecated v1 .dict()).
                return self.func(**args.model_dump())
            return self.func(**kwargs)
        except Exception as e:
            return f"Error: {str(e)}"

    async def async_invoke(self, **kwargs):
        """Async variant of :meth:`invoke`; the wrapped function must be awaitable."""
        try:
            if self.params:
                args = self.params(**kwargs)
                return await self.func(**args.model_dump())
            return await self.func(**kwargs)
        except Exception as e:
            return f"Error: {str(e)}"

    def __repr__(self):
        # Fix: the original left `params` unbound (NameError) when neither a
        # params model nor a schema was provided.
        params = []
        if self.params is not None:
            params = list(self.params.model_json_schema().get('properties').keys())
        elif self.schema is not None:
            params = list(self.schema.get('properties').keys())
        return f"Tool(name={self.name}, description={self.description}, params={params})"

    def get_prompt(self):
        """Render the tool as a prompt fragment for the LLM."""
        return f'''Tool Name: {self.name}\nTool Description: {self.description}\nTool Input: {dumps(self.schema,indent=2)}'''
from src.embedding import BaseEmbedding
from httpx import Client, HTTPStatusError, ConnectError
from typing import Literal

class GeminiEmbedding(BaseEmbedding):
    """Embedding client for Google's Gemini embedContent REST endpoints."""

    def __init__(self,model:str='',output_dimensionality:int=None,task_type:Literal['TASK_TYPE_UNSPECIFIED','RETRIEVAL_QUERY','RETRIEVAL_DOCUMENT','SEMANTIC_SIMILARITY','CLASSIFICATION','CLUSTERING']='',api_key:str='',base_url:str=''):
        self.api_key=api_key
        self.model=model
        self.base_url=base_url
        self.output_dimensionality=output_dimensionality
        self.task_type=task_type
        self.headers={'Content-Type': 'application/json'}

    def embed(self,text:list[str]|str='',title:str=''):
        """Embed a single string or a batch of strings.

        Returns a list of floats for a single input, or a list of float lists
        for a batch. On HTTP or connection failure an error is printed and
        None is returned (best-effort behavior preserved from the original).
        """
        headers=self.headers
        # Batches use the batchEmbedContents endpoint, single inputs embedContent.
        mode='batchEmbedContents' if isinstance(text,list) else 'embedContent'
        url=self.base_url or f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:{mode}"
        params={'key':self.api_key}
        if isinstance(text,list):
            payload={
                'requests':[
                    {
                        'model':f'models/{self.model}',
                        'content':{
                            'parts':[
                                {
                                    'text':_text
                                }
                            ]
                        }
                    }
                for _text in text]
            }
        else:
            payload={
                'model':f'models/{self.model}',
                'content':{
                    'parts':[
                        {
                            'text':text
                        }
                    ]
                }
            }
        if self.task_type:
            payload['task_type']=self.task_type
        if self.output_dimensionality:
            payload['output_dimensionality']=self.output_dimensionality
        if title:
            payload['title']=title
        try:
            with Client() as client:
                response=client.post(url=url,json=payload,headers=headers,params=params)
                response.raise_for_status()
                data=response.json()
                if isinstance(text,list):
                    return [e['values'] for e in data['embeddings']]
                return data['embedding']['values']
        except HTTPStatusError as err:
            # Fix: httpx's raise_for_status raises httpx.HTTPStatusError, but
            # the original caught requests.HTTPError / requests.ConnectionError,
            # which httpx never raises — so these handlers could never fire and
            # failures escaped uncaught.
            print(f'Error: {err.response.text}, Status Code: {err.response.status_code}')
        except ConnectError as err:
            print(err)
from src.agent.web.browser.config import BrowserConfig,BROWSER_ARGS,SECURITY_ARGS,IGNORE_DEFAULT_ARGS
from playwright.async_api import async_playwright,Browser as PlaywrightBrowser,Playwright

class Browser:
    """Async owner of a Playwright browser process.

    Supports `async with Browser(...) as b:` for automatic startup/teardown,
    or explicit init_browser()/close_browser() calls.
    """
    def __init__(self,config:BrowserConfig=None):
        # Both handles are created lazily in init_browser().
        self.playwright:Playwright = None
        self.config = config if config else BrowserConfig()
        self.playwright_browser:PlaywrightBrowser = None

    async def __aenter__(self):
        await self.init_browser()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close_browser()

    async def init_browser(self):
        """Start the Playwright driver, then launch/connect the configured browser."""
        self.playwright =await async_playwright().start()
        self.playwright_browser = await self.setup_browser(self.config.browser)

    async def get_playwright_browser(self)->PlaywrightBrowser:
        """Return the underlying Playwright browser, initializing on first use."""
        if self.playwright_browser is None:
            await self.init_browser()
        return self.playwright_browser

    async def setup_browser(self,browser:str)->PlaywrightBrowser:
        """Launch or connect a browser ('chrome' | 'firefox' | 'edge').

        Priority: remote connect via wss_url, then persistent-profile mode,
        then a local launch (optionally with device emulation).
        Raises on an unrecognized browser name.
        """
        parameters={
            'headless':self.config.headless,
            'downloads_path':self.config.downloads_dir,
            'timeout':self.config.timeout,
            'slow_mo':self.config.slow_mo,
            'args':BROWSER_ARGS + SECURITY_ARGS,
            'ignore_default_args': IGNORE_DEFAULT_ARGS
        }
        if self.config.wss_url is not None:
            # Remote browser over WebSocket; Edge and Chrome both ride chromium.
            if browser=='chrome':
                browser_instance=await self.playwright.chromium.connect(self.config.wss_url)
            elif browser=='firefox':
                browser_instance=await self.playwright.firefox.connect(self.config.wss_url)
            elif browser=='edge':
                browser_instance=await self.playwright.chromium.connect(self.config.wss_url)
            else:
                raise Exception('Invalid Browser Type')
        elif self.config.browser_instance_dir is not None:
            # NOTE(review): persistent-profile mode deliberately returns no
            # browser here — presumably the context layer launches a persistent
            # context itself; confirm against the context module.
            browser_instance=None
        else:
            if self.config.device is not None:
                # Device emulation replaces the launch parameters entirely;
                # 'default_browser_type' is a descriptor key, not a launch kwarg.
                parameters={**self.playwright.devices.get(self.config.device)}
                parameters.pop('default_browser_type',None)
            if browser=='chrome':
                browser_instance=await self.playwright.chromium.launch(channel='chrome',**parameters)
            elif browser=='firefox':
                browser_instance=await self.playwright.firefox.launch(**parameters)
            elif browser=='edge':
                browser_instance=await self.playwright.chromium.launch(channel='msedge',**parameters)
            else:
                raise Exception('Invalid Browser Type')
        return browser_instance

    async def close_browser(self):
        """Close the browser and stop Playwright, always clearing both handles."""
        try:
            if self.playwright_browser:
                await self.playwright_browser.close()
            if self.playwright:
                await self.playwright.stop()
        except Exception as e:
            print('Browser failed to close')
        finally:
            # Clear even on failure so a retry starts from a clean state.
            self.playwright=None
            self.playwright_browser=None
return f"DOMElementNode(tag='{self.tag}', role='{self.role}', name='{self.name}', attributes={self.attributes}, cordinates={self.center}, bounding_box={self.bounding_box}, xpath='{self.xpath}')" 41 | 42 | def to_dict(self)->dict[str,str]: 43 | return {'tag':self.tag,'role':self.role,'name':self.name,'bounding_box':self.bounding_box.to_dict(),'attributes':self.attributes, 'cordinates':self.center.to_dict()} 44 | 45 | @dataclass 46 | class ScrollElementNode: 47 | tag: str 48 | role: str 49 | name: str 50 | attributes: dict[str,str] = field(default_factory=dict) 51 | xpath: dict[str,str]=field(default_factory=dict) 52 | viewport: tuple[int,int]=field(default_factory=tuple) 53 | 54 | def __repr__(self): 55 | return f"ScrollableElementNode(tag='{self.tag}', role='{self.role}', name='{shorten(self.name,width=500)}', attributes={self.attributes}, xpath='{self.xpath}')" 56 | 57 | def to_dict(self)->dict[str,str]: 58 | return {'tag':self.tag,'role':self.role,'name':self.name,'attributes':self.attributes} 59 | 60 | @dataclass 61 | class DOMTextualNode: 62 | tag:str 63 | role:str 64 | content:str 65 | center: CenterCord 66 | xpath: dict[str,str]=field(default_factory=dict) 67 | viewport: tuple[int,int]=field(default_factory=tuple) 68 | 69 | def __repr__(self): 70 | return f'DOMTextualNode(tag={self.tag}, role={self.role}, content={self.content}, cordinates={self.center}, xpath={self.xpath})' 71 | 72 | def to_dict(self)->dict[str,str]: 73 | return {'tag':self.tag,'role':self.role,'content':self.content, 'center':self.center.to_dict()} 74 | 75 | @dataclass 76 | class DOMState: 77 | interactive_nodes: list[DOMElementNode]=field(default_factory=list) 78 | informative_nodes:list[DOMTextualNode]=field(default_factory=list) 79 | scrollable_nodes:list[ScrollElementNode]=field(default_factory=list) 80 | selector_map: dict[str,DOMElementNode|ScrollElementNode]=field(default_factory=dict) 81 | 82 | def interactive_elements_to_string(self)->str: 83 | return '\n'.join([f'{index} - Tag: 
{node.tag} Role: {node.role} Name: {node.name} Attributes: {node.attributes} Cordinates: {node.center.to_string()}' for index,(node) in enumerate(self.interactive_nodes)]) 84 | 85 | def informative_elements_to_string(self)->str: 86 | return '\n'.join([f'Tag: {node.tag} Role: {node.role} Content: {node.content}' for node in self.informative_nodes]) 87 | 88 | def scrollable_elements_to_string(self)->str: 89 | n=len(self.interactive_nodes) 90 | return '\n'.join([f'{n+index} - Tag: {node.tag} Role: {node.role} Name: {shorten(node.name,width=500)} Attributes: {node.attributes}' for index,node in enumerate(self.scrollable_nodes)]) 91 | 92 | -------------------------------------------------------------------------------- /src/tool/registry/__init__.py: -------------------------------------------------------------------------------- 1 | # src/tool/registry/__init__.py 2 | from src.tool.registry.views import Function,ToolResult 3 | from src.tool import Tool 4 | 5 | class Registry: 6 | def __init__(self,tools:list[Tool]): 7 | self.tools=tools 8 | self.tools_registry=self.registry() 9 | 10 | def tools_prompt(self,excluded_tools:list[str]=[])->str: 11 | prompts=[] 12 | for tool in self.tools: 13 | if tool.name in excluded_tools: 14 | continue 15 | prompts.append(tool.get_prompt()) 16 | return '\n\n'.join(prompts) 17 | 18 | def registry(self)->dict[str,Function]: 19 | tools_registry={} 20 | for tool in self.tools: 21 | tools_registry.update({tool.name : Function(name=tool.name,description=tool.description,params=tool.params,function=tool.func)}) 22 | return tools_registry 23 | 24 | async def async_execute(self,name:str,input:dict,**kwargs)->ToolResult: 25 | tool=self.tools_registry.get(name) 26 | try: 27 | # Check if name is None or empty, which indicates an LLM failure 28 | if not name: 29 | raise ValueError('Action Name was None or empty. The LLM failed to choose a valid action.') 30 | 31 | if tool is None: 32 | raise ValueError(f'Tool "{name}" not found. 
    def execute(self,name:str,input:dict,**kwargs)->ToolResult:
        """Synchronously run the tool registered under ``name``.

        ``input`` holds the LLM-supplied arguments; ``kwargs`` are extra
        runtime parameters merged in by the caller (and override ``input``
        on key collision, since they appear last in the ``|`` merge).
        Never raises: any failure is converted into a ToolResult whose
        content describes the error, so the agent loop can continue.
        """
        tool=self.tools_registry.get(name)
        try:
            # Check if name is None or empty — indicates the LLM failed to
            # choose a valid action at all.
            if not name:
                raise ValueError('Action Name was None or empty. The LLM failed to choose a valid action.')

            if tool is None:
                raise ValueError(f'Tool "{name}" not found. Please choose from the available tools.')

            if tool.params:
                # Ensure input is a dictionary before validation
                if not isinstance(input, dict):
                    raise TypeError(f"Action Input for tool '{name}' must be a dictionary, but got {type(input)}: {input}")
                tool_params=tool.params.model_validate(input)
                params=tool_params.model_dump()|kwargs
            else:
                params=input|kwargs

            content=tool.function(**params)
            return ToolResult(name=name,content=content)
        except Exception as e:
            # If 'name' was the issue, report under a placeholder instead.
            error_name = name if name else "Invalid Action"
            error_content = f"Error executing tool '{error_name}': {str(e)}"
            print(f"DEBUG: Tool execution failed. Name: {name}, Input: {input}, Error: {e}")
            return ToolResult(name=error_name,content=error_content)
2 | 3 |

🌐 Web-Navigator

4 | 5 | 6 | License 7 | 8 | Python 9 | Powered by Playwright 10 |
11 | 12 | 13 | Follow on Twitter 14 | 15 | 16 | Join us on Discord 17 | 18 | 19 |
20 | 21 |
22 | 23 | **Web Navigator** is your intelligent browsing companion, built to seamlessly navigate websites, interact with dynamic content, perform smart searches, download files, and adapt to ever-changing pages — all with minimal effort from you. Powered by advanced LLMs and the robust Playwright framework, it transforms complex web tasks into streamlined, automated workflows that boost productivity and save time. 24 | 25 | ## 🛠️Installation Guide 26 | 27 | ### **Prerequisites** 28 | 29 | - Python 3.11 or higher 30 | - UV 31 | 32 | ### **Installation Steps** 33 | 34 | **Clone the repository:** 35 | 36 | ```bash 37 | git clone https://github.com/CursorTouch/Web-Navigator.git 38 | cd Web-Navigator 39 | ``` 40 | 41 | **Install dependencies:** 42 | 43 | ```bash 44 | uv sync 45 | ``` 46 | 47 | **Setup Playwright:** 48 | 49 | ```bash 50 | playwright install 51 | ``` 52 | 53 | --- 54 | 55 | **Setting up the `.env` file:** 56 | 57 | ```bash 58 | GOOGLE_API_KEY="" 59 | ``` 60 | 61 | Basic setup of the agent. 
62 | 63 | ```python 64 | from src.inference.gemini import ChatGemini 65 | from src.agent.web import WebAgent 66 | from dotenv import load_dotenv 67 | import os 68 | 69 | load_dotenv() 70 | google_api_key=os.getenv('GOOGLE_API_KEY') 71 | 72 | llm=ChatGemini(model='gemini-2.0-flash',api_key=google_api_key,temperature=0) 73 | agent=WebAgent(llm=llm,verbose=True,use_vision=False) 74 | 75 | user_query=input('Enter your query: ') 76 | agent_response=agent.invoke(user_query) 77 | print(agent_response.get('output')) 78 | 79 | ``` 80 | 81 | Execute the following command to start the agent: 82 | 83 | ```bash 84 | python main.py 85 | ``` 86 | 87 | ## 🎥Demos 88 | 89 | **Prompt:** I want to know the price details of the RTX 4060 laptop gpu from various sellers from amazon.in 90 | 91 | https://github.com/user-attachments/assets/c729dda9-0ecc-4b07-9113-62fddccca52f 92 | 93 | **Prompt:** Make a twitter post about AI on X 94 | 95 | https://github.com/user-attachments/assets/126ef697-f506-4630-9a0a-1dbbfead9f7e 96 | 97 | **Prompt:** Can you play the trailer of GTA 6 on youtube 98 | 99 | https://github.com/user-attachments/assets/7abde708-7fe0-46f8-96ac-16124aaf2ef4 100 | 101 | **Prompt:** Can you go to my github account and visit the Windows MCP 102 | 103 | https://github.com/user-attachments/assets/cb8ad60c-0609-42e3-9fb9-584ad77c4e3a 104 | 105 | --- 106 | 107 | ## 🪪License 108 | 109 | This project is licensed under MIT License - see the [LICENSE](LICENSE) file for details. 110 | 111 | ## 🤝Contributing 112 | 113 | Contributions are welcome! Please see [CONTRIBUTING](CONTRIBUTING.md) for setup instructions and development guidelines. 
114 | 115 | Made with ❤️ by [Jeomon George](https://github.com/Jeomon), [Muhammad Yaseen](https://github.com/mhmdyaseen) 116 | 117 | --- 118 | 119 | ## 📒References 120 | 121 | - **[Playwright Documentation](https://playwright.dev/docs/intro)** 122 | - **[LangGraph Examples](https://github.com/langchain-ai/langgraph/blob/main/examples/web-navigation/web_voyager.ipynb)** 123 | - **[vimGPT](https://github.com/ishan0102/vimGPT)** 124 | - **[WebVoyager](https://github.com/MinorJerry/WebVoyager)** 125 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | uploads/ 16 | screenshots/ 17 | user_data/ 18 | memory_data/ 19 | db/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | /Web312venv/ 95 | 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # poetry 105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 109 | #poetry.lock 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
113 | #pdm.lock 114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 115 | # in version control. 116 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 117 | .pdm.toml 118 | .pdm-python 119 | .pdm-build/ 120 | 121 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 122 | __pypackages__/ 123 | 124 | # Celery stuff 125 | celerybeat-schedule 126 | celerybeat.pid 127 | 128 | # SageMath parsed files 129 | *.sage.py 130 | 131 | # Environments 132 | .env 133 | .venv 134 | env/ 135 | venv/ 136 | ENV/ 137 | env.bak/ 138 | venv.bak/ 139 | 140 | # Spyder project settings 141 | .spyderproject 142 | .spyproject 143 | 144 | # Rope project settings 145 | .ropeproject 146 | 147 | # mkdocs documentation 148 | /site 149 | 150 | # mypy 151 | .mypy_cache/ 152 | .dmypy.json 153 | dmypy.json 154 | 155 | # Pyre type checker 156 | .pyre/ 157 | 158 | # pytype static type analyzer 159 | .pytype/ 160 | 161 | # Cython debug symbols 162 | cython_debug/ 163 | 164 | # PyCharm 165 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 166 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 167 | # and can be added to the global gitignore or merged into this file. For a more nuclear 168 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 169 | #.idea/ 170 | 171 | test.py 172 | notebook.ipynb -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Web-Navigator 2 | 3 | Thank you for your interest in contributing to Web Agent! This document provides guidelines and instructions for contributing to this project. 
4 | 5 | ## Table of Contents 6 | 7 | - [Getting Started](#getting-started) 8 | - [Development Environment](#development-environment) 9 | - [Installation](#installation-from-source) 10 | 11 | - [Development Workflow](#development-workflow) 12 | - [Branching Strategy](#branching-strategy) 13 | - [Commit Messages](#commit-messages) 14 | - [Code Style](#code-style) 15 | - [Pre-commit Hooks](#pre-commit-hooks) 16 | 17 | - [Testing](#testing) 18 | - [Running Tests](#running-tests) 19 | - [Adding Tests](#adding-tests) 20 | 21 | - [Pull Requests](#pull-requests) 22 | - [Creating a Pull Request](#creating-a-pull-request) 23 | - [Pull Request Template](#pull-request-template) 24 | 25 | - [Documentation](#documentation) 26 | - [Getting Help](#getting-help) 27 | 28 | ## Getting Started 29 | 30 | ### Development Environment 31 | 32 | WebAgent requires: 33 | - Python 3.11 or later 34 | 35 | ### Installation 36 | 37 | 1. Fork the repository on GitHub. 38 | 2. Clone your fork locally: 39 | 40 | ```bash 41 | git clone https://github.com/CursorTouch/Web-Agent.git 42 | cd web-agent 43 | ``` 44 | 45 | 3. Install the package in development mode: 46 | 47 | ```bash 48 | pip install -e ".[dev,search]" 49 | ``` 50 | 51 | 4. Set up pre-commit hooks: 52 | 53 | ```bash 54 | pip install pre-commit 55 | pre-commit install 56 | ``` 57 | 58 | ## Development Workflow 59 | 60 | ### Branching Strategy 61 | 62 | - `main` branch contains the latest stable code 63 | - Create feature branches from `main` named according to the feature you're implementing: `feature/your-feature-name` 64 | - For bug fixes, use: `fix/bug-description` 65 | 66 | ### Commit Messages 67 | 68 | For now no commit style is enforced, try to keep your commit messages informational. 
69 | 70 | ### Code Style 71 | 72 | Key style guidelines: 73 | 74 | - Line length: 100 characters 75 | - Use double quotes for strings 76 | - Follow PEP 8 naming conventions 77 | - Add type hints to function signatures 78 | 79 | ### Pre-commit Hooks 80 | 81 | We use pre-commit hooks to ensure code quality before committing. The configuration is in `.pre-commit-config.yaml`. 82 | 83 | The hooks will: 84 | 85 | - Run linting checks 86 | - Check for trailing whitespace and fix it 87 | - Ensure files end with a newline 88 | - Validate YAML files 89 | - Check for large files 90 | - Remove debug statements 91 | 92 | ## Testing 93 | 94 | ### Running Tests 95 | 96 | Run the test suite with pytest: 97 | 98 | ```bash 99 | pytest 100 | ``` 101 | 102 | To run specific test categories: 103 | 104 | ```bash 105 | pytest tests/ 106 | ``` 107 | 108 | ### Adding Tests 109 | 110 | - Add unit tests for new functionality in `tests/unit/` 111 | - For slow or network-dependent tests, mark them with `@pytest.mark.slow` or `@pytest.mark.integration` 112 | - Aim for high test coverage of new code 113 | 114 | ## Pull Requests 115 | 116 | ### Creating a Pull Request 117 | 118 | 1. Ensure your code passes all tests and pre-commit hooks 119 | 2. Push your changes to your fork 120 | 3. Submit a pull request to the main repository 121 | 4. Follow the pull request template 122 | 123 | ## Documentation 124 | 125 | - Update docstrings for new or modified functions, classes, and methods 126 | - Use Google-style docstrings: 127 | 128 | ```python 129 | def function_name(param1: type, param2: type) -> return_type: 130 | """Short description. 131 | Longer description if needed. 
132 | 133 | Args: 134 | param1: Description of param1 135 | param2: Description of param2 136 | 137 | Returns: 138 | Description of return value 139 | 140 | Raises: 141 | ExceptionType: When and why this exception is raised 142 | """ 143 | ``` 144 | 145 | 146 | ## Getting Help 147 | 148 | If you need help with your contribution: 149 | 150 | - Open an issue for discussion 151 | - Reach out to any of the our maintainers 152 | 153 | We look forward to your contributions! -------------------------------------------------------------------------------- /src/agent/web/tools/views.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel,Field 2 | from typing import Literal 3 | 4 | class SharedBaseModel(BaseModel): 5 | class Config: 6 | extra="allow" 7 | 8 | class Done(SharedBaseModel): 9 | content:str = Field(...,description="Summary of the completed task in proper markdown format explaining what was accomplished",examples=["The task is completed successfully. 
User profile updated with new email address."]) 10 | 11 | class Click(SharedBaseModel): 12 | index:int = Field(...,description="The index/label of the interactive element to click (buttons, links, checkboxes, tabs, etc.)",examples=[0]) 13 | 14 | class Type(SharedBaseModel): 15 | index:int = Field(...,description="The index/label of the input element to type text into (text fields, search boxes, text areas)",examples=[0]) 16 | text:str = Field(...,description="The text content to type into the input field",examples=["hello world","user@example.com","My search query"]) 17 | clear:Literal['True','False']=Field(description="Whether to clear existing text before typing new content",default="False",examples=['True']) 18 | press_enter:Literal['True','False']=Field(description="Whether to press Enter after typing",default="False",examples=['True']) 19 | 20 | class Wait(SharedBaseModel): 21 | time:int = Field(...,description="Number of seconds to wait for page loading, animations, or content to appear",examples=[1,3,5]) 22 | 23 | class Scroll(SharedBaseModel): 24 | direction: Literal['up','down'] = Field(description="The direction to scroll content", examples=['up','down'], default='up') 25 | index: int = Field(description="Index of specific scrollable element, if None then scrolls the entire page", examples=[0, 5, 12,None],default=None) 26 | amount: int = Field(description="Number of pixels to scroll, if None then scrolls by page/container height. 
Must required for scrollable container elements and the amount should be small", examples=[100, 25, 50],default=500) 27 | 28 | class GoTo(SharedBaseModel): 29 | url:str = Field(...,description="The complete URL to navigate to including protocol (http/https)",examples=["https://www.example.com","https://google.com/search?q=test"]) 30 | 31 | class Back(SharedBaseModel): 32 | pass 33 | 34 | class Forward(SharedBaseModel): 35 | pass 36 | 37 | class Key(SharedBaseModel): 38 | keys:str = Field(...,description="Keyboard key or key combination to press (supports modifiers like Control, Alt, Shift)",examples=["Enter","Control+A","Escape","Tab","Control+C"]) 39 | times:int = Field(description="Number of times to repeat the key press sequence",examples=[1,2,3],default=1) 40 | 41 | class Download(SharedBaseModel): 42 | url:str = Field(...,description="Direct URL of the file to download (supports various file types: PDF, images, videos, documents)",examples=["https://www.example.com/document.pdf","https://site.com/image.jpg"]) 43 | filename:str=Field(...,description="Local filename to save the downloaded file as (include file extension)",examples=["document.pdf","image.jpg","data.xlsx"]) 44 | 45 | class Scrape(SharedBaseModel): 46 | pass 47 | 48 | class Tab(SharedBaseModel): 49 | mode:Literal['open','close','switch'] = Field(...,description="Tab operation: 'open' creates new tab, 'close' closes current tab, 'switch' changes to existing tab",examples=['open','close','switch']) 50 | tab_index:int = Field(description="Zero-based index of the tab to switch to (only required for 'switch' mode)",examples=[0,1,2],default=None) 51 | 52 | class Upload(SharedBaseModel): 53 | index:int = Field(...,description="Index of the file input element to upload files to",examples=[0]) 54 | filenames:list[str] = Field(...,description="List of filenames to upload from the ./uploads directory (supports single or multiple files)",examples=[["document.pdf"],["image1.jpg","image2.png"]]) 55 | 56 | class 
Menu(SharedBaseModel): 57 | index:int = Field(...,description="Index of the dropdown/select element to interact with",examples=[0]) 58 | labels:list[str] = Field(...,description="List of visible option labels to select from the dropdown menu (supports single or multiple selection)",examples=[["BMW"],["Option 1","Option 2"]]) 59 | 60 | class Script(SharedBaseModel): 61 | script:str = Field(...,description="The JavaScript code to execute in the current webpage to scrape data. Make sure the script is well-formatted",examples=["console.log('Hello, world!')"]) 62 | 63 | class HumanInput(SharedBaseModel): 64 | prompt: str = Field(..., description="Clear question or instruction to ask the human user when assistance is needed", examples=["Please enter the OTP code sent to your phone", "What is your preferred payment method?", "Please solve this CAPTCHA"]) -------------------------------------------------------------------------------- /src/agent/web/dom/__init__.py: -------------------------------------------------------------------------------- 1 | from src.agent.web.dom.views import DOMElementNode, DOMTextualNode, ScrollElementNode, DOMState, CenterCord, BoundingBox 2 | from playwright.async_api import Page, Frame 3 | from typing import TYPE_CHECKING 4 | from asyncio import sleep 5 | 6 | if TYPE_CHECKING: 7 | from src.agent.web.context import Context 8 | 9 | class DOM: 10 | def __init__(self, context:'Context'): 11 | self.context=context 12 | 13 | async def get_state(self,use_vision:bool=False,freeze:bool=False)->tuple[str|None,DOMState]: 14 | '''Get the state of the webpage.''' 15 | try: 16 | if freeze: 17 | await sleep(5) 18 | with open('./src/agent/web/dom/script.js') as f: 19 | script=f.read() 20 | page=await self.context.get_current_page() 21 | await page.wait_for_load_state('domcontentloaded',timeout=10*1000) 22 | await self.context.execute_script(page,script) 23 | #Access from frames 24 | frames=page.frames 25 | 
interactive_nodes,informative_nodes,scrollable_nodes=await self.get_elements(frames=frames) 26 | if use_vision: 27 | # Add bounding boxes to the interactive elements 28 | boxes=map(lambda node:node.bounding_box.to_dict(),interactive_nodes) 29 | await self.context.execute_script(page,'boxes=>{mark_page(boxes)}',list(boxes)) 30 | screenshot=await self.context.get_screenshot(save_screenshot=False) 31 | # Remove bounding boxes from the interactive elements 32 | if freeze: 33 | await sleep(10) 34 | await sleep(0.1) 35 | await self.context.execute_script(page,'unmark_page()') 36 | else: 37 | screenshot=None 38 | except Exception as e: 39 | print(f"Failed to get elements from page: {page.url}\nError: {e}") 40 | interactive_nodes,informative_nodes,scrollable_nodes=[],[],[] 41 | screenshot=None 42 | selector_map=dict(enumerate(interactive_nodes+scrollable_nodes)) 43 | return (screenshot,DOMState(interactive_nodes=interactive_nodes,informative_nodes=informative_nodes,scrollable_nodes=scrollable_nodes,selector_map=selector_map)) 44 | 45 | async def get_elements(self,frames:list[Frame|Page])->tuple[list[DOMElementNode],list[DOMTextualNode],list[ScrollElementNode]]: 46 | '''Get the interactive elements of the webpage.''' 47 | interactive_elements,informative_elements,scrollable_elements=[],[],[] 48 | with open('./src/agent/web/dom/script.js') as f: 49 | script=f.read() 50 | try: 51 | for index,frame in enumerate(frames): 52 | if frame.is_detached() or frame.url=='about:blank': 53 | continue 54 | # print(f"Getting elements from frame: {frame.url}") 55 | await self.context.execute_script(frame,script) # Inject JS 56 | #index=0 means Main Frame 57 | if index>0 and not await self.context.is_frame_visible(frame=frame): 58 | continue 59 | # print(f"Getting elements from frame: {frame.url}") 60 | await self.context.execute_script(frame,script) 61 | nodes:dict=await self.context.execute_script(frame,'getElements()') 62 | element_nodes,textual_nodes,scrollable_nodes=nodes.values() 63 | 
if index>0: 64 | frame_element =await frame.frame_element() 65 | frame_xpath=await self.context.execute_script(frame,'(frame_element)=>getXPath(frame_element)',frame_element) 66 | else: 67 | frame_xpath='' 68 | viewport =await self.context.get_viewport() 69 | for element in element_nodes: 70 | element_xpath=element.get('xpath') 71 | node=DOMElementNode(**{ 72 | 'tag':element.get('tag'), 73 | 'role':element.get('role'), 74 | 'name':element.get('name'), 75 | 'attributes':element.get('attributes'), 76 | 'center':CenterCord(**element.get('center')), 77 | 'bounding_box':BoundingBox(**element.get('box')), 78 | 'xpath':{'frame':frame_xpath,'element':element_xpath}, 79 | 'viewport':viewport 80 | }) 81 | interactive_elements.append(node) 82 | 83 | for element in scrollable_nodes: 84 | element_xpath=element.get('xpath') 85 | node=ScrollElementNode(**{ 86 | 'tag':element.get('tag'), 87 | 'role':element.get('role'), 88 | 'name':element.get('name'), 89 | 'attributes':element.get('attributes'), 90 | 'xpath':{'frame':frame_xpath,'element':element_xpath}, 91 | 'viewport':viewport 92 | }) 93 | scrollable_elements.append(node) 94 | 95 | for element in textual_nodes: 96 | element_xpath=element.get('xpath') 97 | node=DOMTextualNode(**{ 98 | 'tag':element.get('tag'), 99 | 'role':element.get('role'), 100 | 'content':element.get('content'), 101 | 'center':CenterCord(**element.get('center')), 102 | 'xpath':{'frame':frame_xpath,'element':element_xpath}, 103 | 'viewport':viewport 104 | }) 105 | informative_elements.append(node) 106 | 107 | except Exception as e: 108 | print(f"Failed to get elements from frame: {frame.url}\nError: {e}") 109 | return interactive_elements,informative_elements,scrollable_elements 110 | -------------------------------------------------------------------------------- /src/memory/episodic/__init__.py: -------------------------------------------------------------------------------- 1 | from src.memory.episodic.utils import read_markdown_file 2 | from 
src.memory.episodic.views import Memory,Memories 3 | from src.message import SystemMessage,HumanMessage 4 | from src.inference import BaseInference 5 | from src.message import BaseMessage 6 | from src.memory import BaseMemory 7 | from src.router import LLMRouter 8 | from termcolor import colored 9 | from uuid import uuid4 10 | import json 11 | 12 | with open('./src/memory/episodic/routes.json','r') as f: 13 | routes=json.load(f) 14 | 15 | class EpisodicMemory(BaseMemory): 16 | def __init__(self,knowledge_base:str='knowledge_base.json',llm:BaseInference=None,verbose=False): 17 | self.memories:Memories=[] 18 | super().__init__(knowledge_base=knowledge_base,llm=llm,verbose=verbose) 19 | 20 | def router(self,conversation:list[BaseMessage]): 21 | instructions=[ 22 | 'Go through the conversation and revelant memories to determine which route to take.', 23 | 'Be proactive in your choice and perform this task carefully and with utmost accuracy.' 24 | ] 25 | router=LLMRouter(instructions=instructions,routes=routes,llm=self.llm,verbose=False) 26 | route=router.invoke(f'### Revelant Memories from Knowledge Base:\n{self.memories.to_string()}\n### Conversation:\n{self.conversation_to_text(conversation)}') 27 | return route 28 | 29 | def store(self, conversation: list[BaseMessage]): 30 | route=self.router(conversation) 31 | if route=='ADD': 32 | self.add_memory(conversation) 33 | elif route=='UPDATE': 34 | self.update_memory(conversation) 35 | elif route=='REPLACE': 36 | self.replace_memory(conversation) 37 | else: 38 | self.idle_memory(conversation) 39 | 40 | def idle_memory(self,conversation:list[BaseMessage]): 41 | if self.verbose: 42 | print(f'{colored(f'Idle memory:',color='yellow',attrs=['bold'])}\n{json.dumps(self.memories.all(),indent=2)}') 43 | return None 44 | 45 | def add_memory(self,conversation:list[BaseMessage]): 46 | system_prompt=read_markdown_file('src/memory/episodic/prompt/add.md') 47 | text_conversation=self.conversation_to_text(conversation) 48 | 
user_prompt=f'### Conversation:\n{text_conversation}' 49 | messages=[SystemMessage(system_prompt),HumanMessage(user_prompt)] 50 | memory=self.llm.invoke(messages,model=Memory) 51 | if self.verbose: 52 | print(f'{colored(f'Adding memory to Knowledge Base:',color='yellow',attrs=['bold'])}\n{json.dumps(memory.to_dict(),indent=2)}') 53 | with open(f'./memory_data/{self.knowledge_base}','r+') as f: 54 | knowledge_base:list[dict] = json.load(f) 55 | knowledge_base.append(memory.model_dump()) 56 | f.seek(0) 57 | json.dump(knowledge_base, f, indent=2) 58 | 59 | def update_memory(self,conversation:list[BaseMessage]): 60 | system_prompt=read_markdown_file('src/memory/episodic/prompt/update.md') 61 | text_conversation=self.conversation_to_text(conversation) 62 | user_prompt=f'### Revelant memories from Knowledge Base:\n{self.memories.to_string()}\n### Conversation:\n{text_conversation}' 63 | messages=[SystemMessage(system_prompt),HumanMessage(user_prompt)] 64 | memory:Memory=self.llm.invoke(messages,model=Memory) 65 | if self.verbose: 66 | print(f'{colored(f'Updated memory from Knowledge Base:',color='yellow',attrs=['bold'])}\n{json.dumps(memory.to_dict(),indent=2)}') 67 | with open(f'./memory_data/{self.knowledge_base}','r+') as f: 68 | knowledge_base = [Memory.model_validate(memory) for memory in json.loads(f)] 69 | memory_ids=[memory.get('id') for memory in self.memories.all()] 70 | updated_knowledge_base=list(filter(lambda memory:memory.id not in memory_ids,knowledge_base)) 71 | updated_knowledge_base.extend(memory) 72 | f.seek(0) 73 | json.dump(updated_knowledge_base, f, indent=2) 74 | f.truncate() 75 | 76 | def replace_memory(self,conversation:list[BaseMessage]): 77 | system_prompt=read_markdown_file('src/memory/episodic/prompt/replace.md') 78 | text_conversation=self.conversation_to_text(conversation) 79 | user_prompt=f'### Conversation:\n{text_conversation}' 80 | messages=[SystemMessage(system_prompt),HumanMessage(user_prompt)] 81 | 
memory:Memory=self.llm.invoke(messages,model=Memory) 82 | if self.verbose: 83 | print(f'{colored(f'Replacing memory from Knowledge Base:',color='yellow',attrs=['bold'])}\n{json.dumps(memory.to_dict(),indent=2)}') 84 | with open(f'./memory_data/{self.knowledge_base}','r+') as f: 85 | knowledge_base = [Memory.model_validate(memory) for memory in json.loads(f)] 86 | memory_ids=[memory.id for memory in self.memories.all()] 87 | updated_knowledge_base=list(filter(lambda memory:memory.id not in memory_ids,knowledge_base)) 88 | updated_knowledge_base.append(memory) 89 | f.seek(0) 90 | json.dump([memory.model_dump() for memory in updated_knowledge_base], f, indent=2) 91 | f.truncate() 92 | 93 | def retrieve(self, query: str)->list[dict]: 94 | memories=[memory for memory in self.memories] 95 | system_prompt=read_markdown_file('src/memory/episodic/prompt/retrieve.md') 96 | user_prompt=f'### Query: {query}\n Now, select the memories those are relevant to solve the query.' 97 | messages=[SystemMessage(system_prompt.format(memories=memories)),HumanMessage(user_prompt)] 98 | memories=self.llm.invoke(messages,model=Memories) 99 | self.memories=memories 100 | if self.verbose: 101 | print(f'{colored(f'Retrieved memories from Knowledge Base:',color='yellow',attrs=['bold'])}\n{json.dumps(self.memories.all(),indent=2)}') 102 | return self.memories 103 | 104 | def attach_memory(self,system_prompt:str)->str: 105 | episodic_prompt=read_markdown_file('src/memory/episodic/prompt/memory.md') 106 | memory_prompt=episodic_prompt.format(memories=self.memories.to_string()) 107 | return f'{system_prompt}\n\n{memory_prompt}' 108 | 109 | -------------------------------------------------------------------------------- /src/agent/web/prompt/system.md: -------------------------------------------------------------------------------- 1 | # 🕸️Web Navigator 2 | 3 | You are Web Navigator designed by CursorTouch to solve the web related queries given by the USER in the . 
4 | 5 | The current date is {current_datetime} 6 | 7 | Web Navigator can perform deep research. For the tasks that requires more contextual information perform research on that area of the topic. This can be performed in the intermediate stages or in the beginning itself and continue solving the task. 8 | 9 | Web Navigator can go both in-depth and breath on any given topic by looking through different sources, articles, blogs, ...etc. and this is an inheritant feature in deep research. 10 | 11 | Web Navigator enjoys helping the user to achieve the . 12 | 13 | Additional Instructions: 14 | 15 | {instructions} 16 | 17 | Available Tools: 18 | 19 | {tools_prompt} 20 | 21 | IMPORTANT: Only use tools that is available. Never hallucinate using tools. 22 | 23 | ## System Information: 24 | 25 | - **Operating System:** {os} 26 | - **Browser:** {browser} 27 | - **Home Directory:** {home_dir} 28 | - **Downloads Folder:** {downloads_dir} 29 | 30 | At every step, Web Agent will be given the state: 31 | 32 | ```xml 33 | 34 | 35 | Current Step: How many steps over 36 | Max. Steps: Max. steps allowed with in which, solve the task 37 | Action Reponse : Result of executing the previous action 38 | 39 | 40 | [Begin of Tab Info] 41 | Current Tab: The info related to current tab agent is working on. 42 | Open Tabs: The info related to other tabs those are open in the browser. 43 | [End of Tab Info] 44 | 45 | [Begin of Viewport] 46 | List of Interactive Elements: the interactable elements on the current tab like buttons,links and more. 47 | List of Scrollable Elements: these elements enable the agent to scroll on specific sections of the webpage. 48 | List of Informative Elements: these elements provide the text in the webpage. 49 | [End of Viewport] 50 | 51 | 52 | The ultimate goal for Web Navigator given by the user, use it to track progress. 53 | 54 | 55 | ``` 56 | 57 | Web Navigator must follow the following rules while browsing the web: 58 | 59 | 1. 
ALWAYS start solving the given query using the appropriate search domains like google, youtube, wikipedia, twitter ...etc.
60 | 2. When performing deep research make sure to conduct it in a separate tab using `Tab Tool` and not on the current working tab.
61 | 3. If any banners or ads are obstructing the way, close them and accept cookies if you see them on the page.
62 | 4. If a captcha appears, attempt solving it if possible or else use fallback strategies (ex: go back, alternative site).
63 | 5. You can scroll through specific sections of the webpage if there are Scrollable Elements to get relevant content from those sections.
64 | 6. Develop search queries that are clear and optimistic to the .
65 | 7. To scrape the entire webpage use the `Scrape Tool`. It would include all the text and links present in the page.
66 | 
67 | Web Navigator must follow the following rules for better reasoning and planning in :
68 | 
69 | 1. Use the recent steps to track the progress and context towards .
70 | 2. Incorporate , , , screenshot (if available) in your reasoning process and explain what you want to achieve next from based on the current state.
71 | 3. You can create a plan in this stage to clearly define the objectives to achieve and even self-reflect to correct yourself from mistakes.
72 | 4. Analyse whether you are stuck at the same goal for a few steps. If so, try alternative methods.
73 | 5. When you are ready to finish, state that you are preparing the answer for the user by gathering the findings you got and then use the `Done Tool`.
74 | 6. Explicitly judge the effectiveness of the previous action and keep it in .
75 | 7. Valuable information gained so far will be present in ; use it as needed. Use this information to connect the dots to gain new insights.
76 | 
77 | Web Navigator must follow the following rules during the agentic loop:
78 | 
79 | 1. Start by using the `GoTo Tool` to go to the current search domain.
80 | 2. Use `Done Tool` when you have performed/completed the ultimate task; this includes sufficient knowledge gained from browsing the internet. This tool provides you an opportunity to terminate and share your findings with the user.
81 | 3. The contains only elements within the viewport. Use `Scroll Tool` if you suspect relevant content is offscreen which you want to interact with. Scroll ONLY if there is more content above or below the webpage.
82 | 4. When browsing especially in search engines keep an eye on the auto suggestions that pop up under the input field.
83 | 5. If the page isn't fully loaded, use `Wait Tool` to wait and if any changes are not seen in the webpage after performing an action then wait.
84 | 6. For clicking only use `Click Tool` and for clicking and typing use `Type Tool`.
85 | 7. When you respond, provide thorough, well-detailed explanations of all findings and also mention the sources you referred based on the .
86 | 8. When clicking certain links using `Click Tool`, the site sometimes opens in a new tab; the shall then be w.r.t. this tab.
87 | 9. Don't get stuck in loops while solving the given task. Each step is an attempt to reach the goal.
88 | 10. If the query includes specific information like location, size, price then efficiently apply those filters while in the webpage.
89 | 11. NEVER close the last tab on the browser (the browser will close automatically).
90 | 12. You can ask the user for clarification or more data to continue using `Human Tool`.
91 | 13. The contains the information gained from the internet and essential context; this includes the data from such as credentials.
92 | 14. Remember to complete the task within `{max_iteration} steps` and ALWAYS output 1 reasonable action per step.
93 | 
94 | Web Navigator must follow the following rules for :
95 | 
96 | 1. ALWAYS remember solving the is the ultimate agenda.
97 | 2. Analyse the query, understand its complexity and break it into atomic subtasks.
98 | 3. If the task contains explicit steps or instructions, follow them with high priority.
99 | 4. Always look for the latest information for the unless explicitly specified.
100 | 5. You can do deep research to understand more on the topic to gain more insight for .
101 | 6. If additional instructions are given, pay close attention to them and act accordingly.
102 | 7. Give utmost importance to the user's preferences.
103 | 
104 | Web Navigator must follow the following communication guidelines:
105 | 
106 | 1. Maintain a professional yet conversational tone.
107 | 2. The response should highlight in-depth findings, explained in detail.
108 | 3. Highlight key insights for the .
109 | 4. Format the responses in clean markdown format.
110 | 5. Only give verified information to the USER.
111 | 
112 | ALWAYS respond exclusively in the following XML format:
113 | 
114 | ```xml
115 | 
116 | Success|Neutral|Failure - [Brief analysis of current state and progress]
117 | [Key information gathered from progress and current step also critical context for the problem statement from web]
118 | [Strategic planning and reasoning for next action based on analysis of the current state and what has been done so far]
119 | [Selected tool name (example: ABC Tool)]
120 | {{'param1':'value1','param2':'value2'}}
121 | 
122 | ```
123 | 
124 | Begin!!!
class ChatAnthropic(BaseInference):
    """Client for the Anthropic Messages API (https://api.anthropic.com/v1/messages).

    Converts the project's message objects into Anthropic's wire format, sends a
    (rate-limited, retried) request and maps the response back to an AIMessage,
    ToolMessage or a validated pydantic model.
    """

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    def invoke(self, messages: list[BaseMessage],json:bool=False,model:BaseModel=None)->AIMessage|ToolMessage|BaseModel:
        """Send `messages` synchronously.

        Args:
            messages: Conversation history (System/Human/AI/Image messages).
            json: When True, parse the reply text as JSON into an AIMessage.
            model: Optional pydantic model; when given, the reply is validated into it.

        Returns:
            AIMessage for text replies, ToolMessage for tool-use blocks, or `model`.
        """
        self.headers.update({
            'x-api-key': self.api_key,
            "anthropic-version": "2023-06-01",
        })
        headers=self.headers
        temperature=self.temperature
        url=self.base_url or "https://api.anthropic.com/v1/messages"
        contents=[]
        system_instruct=None
        for message in messages:
            if isinstance(message,(HumanMessage,AIMessage)):
                contents.append(message.to_dict())
            elif isinstance(message,ImageMessage):
                text,image=message.content
                # Fix: append the message dict directly — the original wrapped it in
                # an extra list, producing an invalid entry in the messages array.
                contents.append({
                    'role':'user',
                    'content':[
                        {
                            'type':'text',
                            'text':text
                        },
                        {
                            'type':'image',
                            'source':{
                                'type':'base64',
                                'media_type':'image/png',
                                'data':image
                            }
                        }
                    ]
                })
            elif isinstance(message,SystemMessage):
                # Anthropic takes the system prompt as a top-level field, not a message.
                system_instruct=self.structured(message,model) if model else message.content
            else:
                raise Exception("Invalid Message")

        payload={
            "model": self.model,
            "messages": contents,
            "temperature": temperature,
            # Fix: `response_format` is not an Anthropic Messages API field and was
            # rejected by the server; `max_tokens` on the other hand is REQUIRED.
            "max_tokens": 4096,
            "stream":False,
        }
        if self.tools:
            # Fix: Anthropic tool definitions are flat objects
            # (name/description/input_schema), not OpenAI-style
            # {'type':'function','function':{...}} wrappers.
            payload["tools"]=[{
                'name':tool.name,
                'description':tool.description,
                'input_schema':tool.schema
            } for tool in self.tools]
        if system_instruct:
            payload['system']=system_instruct
        try:
            with Client() as client:
                response=client.post(url=url,json=payload,headers=headers,timeout=None)
                json_object=response.json()
                # print(json_object)
                if json_object.get('error'):
                    raise HTTPError(json_object['error']['message'])
                message = json_object['content'][0]
                usage_metadata=json_object['usage']
                # Fix: the original unpacked two values into three names (ValueError).
                input,output=usage_metadata['input_tokens'],usage_metadata['output_tokens']
                total=input+output
                self.tokens=Token(input=input,output=output,total=total)
                if model:
                    return model.model_validate_json(message.get('text'))
                if json:
                    return AIMessage(loads(message.get('text')))
                # Fix: Anthropic content blocks carry a 'type' field; the original
                # tested message.get('content'), which never exists on a block, so
                # every text reply fell through to the tool branch.
                if message.get('type')=='text':
                    return AIMessage(message.get('text'))
                else:
                    tool_call=message
                    return ToolMessage(id=tool_call['id'] or str(uuid4()),name=tool_call['name'],args=tool_call['input'])
        except HTTPError as err:
            # Fix: the HTTPError raised above carries no attached response; guard
            # before dereferencing it.
            if getattr(err,'response',None) is not None:
                err_object=loads(err.response.text)
                print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}')
            else:
                print(f'\nError: {err}')
        except ConnectionError as err:
            print(err)
            exit()

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    async def async_invoke(self, messages: list[BaseMessage], json: bool = False, model: BaseModel = None) -> AIMessage | ToolMessage | BaseModel:
        """Async variant of `invoke`; same conversion, payload and return mapping."""
        self.headers.update({
            'x-api-key': self.api_key,
            "anthropic-version": "2023-06-01",
        })
        headers = self.headers
        temperature = self.temperature
        url = self.base_url or "https://api.anthropic.com/v1/messages"
        contents = []
        system_instruct = None

        for message in messages:
            if isinstance(message, (HumanMessage, AIMessage)):
                contents.append(message.to_dict())
            elif isinstance(message, ImageMessage):
                text, image = message.content
                # Fix: append the message dict directly (was wrapped in an extra list).
                contents.append({
                    'role': 'user',
                    'content': [
                        {
                            'type': 'text',
                            'text': text
                        },
                        {
                            'type': 'image',
                            'source': {
                                'type': 'base64',
                                'media_type': 'image/png',
                                'data': image
                            }
                        }
                    ]
                })
            elif isinstance(message, SystemMessage):
                system_instruct = self.structured(message, model) if model else message.content
            else:
                raise Exception("Invalid Message")

        payload = {
            "model": self.model,
            "messages": contents,
            "temperature": temperature,
            # Fix: dropped the non-Anthropic `response_format`; added required max_tokens.
            "max_tokens": 4096,
            "stream": False,
        }
        if self.tools:
            # Fix: flat Anthropic tool schema (see sync invoke).
            payload["tools"] = [{
                'name': tool.name,
                'description': tool.description,
                'input_schema': tool.schema
            } for tool in self.tools]
        if system_instruct:
            payload['system'] = system_instruct

        try:
            async with AsyncClient() as client:
                response = await client.post(url, json=payload, headers=headers)
                response.raise_for_status()
                json_object = response.json()
                if json_object.get('error'):
                    raise HTTPError(json_object['error']['message'])
                message = json_object['content'][0]
                usage_metadata = json_object['usage']
                input, output = usage_metadata['input_tokens'], usage_metadata['output_tokens']
                total = input + output
                self.tokens = Token(input=input, output=output, total=total)
                if model:
                    return model.model_validate_json(message.get('text'))
                if json:
                    return AIMessage(loads(message.get('text')))
                # Fix: dispatch on the block's 'type' field (see sync invoke).
                if message.get('type') == 'text':
                    return AIMessage(message.get('text'))
                else:
                    tool_call = message
                    return ToolMessage(id=tool_call['id'] or str(uuid4()), name=tool_call['name'], args=tool_call['input'])
        except HTTPError as err:
            # Fix: guard against an HTTPError without an attached response.
            if getattr(err, 'response', None) is not None:
                err_object = loads(err.response.text)
                print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}')
            else:
                print(f'\nError: {err}')
        except Exception as err:
            print(err)

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    def stream(self, messages: list[BaseMessage],json=False)->Generator[str,None,None]:
        """Streaming is not implemented for Anthropic yet."""
        pass

    def available_models(self):
        """Return the ids of the models this API key can use.

        Fix: the original queried the Groq endpoint with a Bearer header and
        filtered on a non-existent 'active' field; Anthropic's model listing
        lives at /v1/models and authenticates with x-api-key.
        """
        url='https://api.anthropic.com/v1/models'
        self.headers.update({
            'x-api-key': self.api_key,
            "anthropic-version": "2023-06-01",
        })
        headers=self.headers
        response=requests.get(url=url,headers=headers)
        response.raise_for_status()
        models=response.json()
        return [model['id'] for model in models['data']]
class ChatNvidia(BaseInference):
    """Client for NVIDIA's OpenAI-compatible chat completions endpoint
    (https://integrate.api.nvidia.com/v1/chat/completions)."""

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    def invoke(self, messages: list[BaseMessage],json:bool=False,model:BaseModel=None)->AIMessage|ToolMessage|BaseModel:
        """Send `messages` synchronously and map the reply to an AIMessage,
        ToolMessage or a validated pydantic `model`."""
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers=self.headers
        temperature=self.temperature
        url=self.base_url or "https://integrate.api.nvidia.com/v1/chat/completions"
        contents=[]
        for message in messages:
            # Message types are mutually exclusive; elif makes that explicit.
            if isinstance(message,SystemMessage):
                if model:
                    # NOTE(review): mutates the caller's message in place; on a
                    # tenacity retry the schema is injected again — confirm intent.
                    message.content=self.structured(message,model)
                contents.append(message.to_dict())
            elif isinstance(message,(HumanMessage,AIMessage)):
                contents.append(message.to_dict())
            elif isinstance(message,ImageMessage):
                text,image=message.content
                # Fix: append the message dict directly — the original wrapped it
                # in an extra list, producing an invalid entry in the messages array.
                contents.append({
                    'role':'user',
                    'content':[
                        {
                            'type':'text',
                            'text':text
                        },
                        {
                            'type':'image_url',
                            'image_url':{
                                'url':image
                            }
                        }
                    ]
                })

        payload={
            "model": self.model,
            "messages": contents,
            "temperature": temperature,
            "response_format": {
                "type": "json_object" if json or model else "text"
            },
            "stream":False,
        }
        if self.tools:
            payload["tools"]=[{
                'type':'function',
                'function':{
                    'name':tool.name,
                    'description':tool.description,
                    'parameters':tool.schema
                }
            } for tool in self.tools]
        try:
            with Client() as client:
                response=client.post(url=url,json=payload,headers=headers,timeout=None)
                json_object=response.json()
                # print(json_object)
                if json_object.get('error'):
                    raise HTTPError(json_object['error']['message'])
                message=json_object['choices'][0]['message']
                usage_metadata=json_object['usage']
                input,output,total=usage_metadata['prompt_tokens'],usage_metadata['completion_tokens'],usage_metadata['total_tokens']
                self.tokens=Token(input=input,output=output,total=total)
                if model:
                    return model.model_validate_json(message.get('content'))
                if json:
                    return AIMessage(loads(message.get('content')))
                if message.get('content'):
                    return AIMessage(message.get('content'))
                else:
                    tool_call=message.get('tool_calls')[0]['function']
                    return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['arguments'])
        except HTTPError as err:
            # Fix: the HTTPError raised above carries no attached response; guard
            # before dereferencing it.
            if getattr(err,'response',None) is not None:
                err_object=loads(err.response.text)
                print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}')
            else:
                print(f'\nError: {err}')
        except ConnectionError as err:
            print(err)
            exit()

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    async def async_invoke(self, messages: list[BaseMessage],json=False,model:BaseModel=None) -> AIMessage|ToolMessage|BaseModel:
        """Async variant of `invoke`; same conversion, payload and return mapping."""
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers=self.headers
        temperature=self.temperature
        url=self.base_url or "https://integrate.api.nvidia.com/v1/chat/completions"
        contents=[]
        for message in messages:
            if isinstance(message,SystemMessage):
                if model:
                    message.content=self.structured(message,model)
                contents.append(message.to_dict())
            elif isinstance(message,(HumanMessage,AIMessage)):
                contents.append(message.to_dict())
            elif isinstance(message,ImageMessage):
                text,image=message.content
                # Fix: append the message dict directly (was wrapped in an extra list).
                contents.append({
                    'role':'user',
                    'content':[
                        {
                            'type':'text',
                            'text':text
                        },
                        {
                            'type':'image_url',
                            'image_url':{
                                'url':image
                            }
                        }
                    ]
                })

        payload={
            "model": self.model,
            "messages": contents,
            "temperature": temperature,
            "response_format": {
                "type": "json_object" if json or model else "text"
            },
            "stream":False,
        }
        if self.tools:
            payload["tools"]=[{
                'type':'function',
                'function':{
                    'name':tool.name,
                    'description':tool.description,
                    'parameters':tool.schema
                }
            } for tool in self.tools]
        try:
            async with AsyncClient() as client:
                response=await client.post(url=url,json=payload,headers=headers,timeout=None)
                json_object=response.json()
                # print(json_object)
                if json_object.get('error'):
                    raise HTTPError(json_object['error']['message'])
                message=json_object['choices'][0]['message']
                usage_metadata=json_object['usage']
                input,output,total=usage_metadata['prompt_tokens'],usage_metadata['completion_tokens'],usage_metadata['total_tokens']
                self.tokens=Token(input=input,output=output,total=total)
                if model:
                    return model.model_validate_json(message.get('content'))
                if json:
                    return AIMessage(loads(message.get('content')))
                if message.get('content'):
                    return AIMessage(message.get('content'))
                else:
                    tool_call=message.get('tool_calls')[0]['function']
                    return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['arguments'])
        except HTTPError as err:
            # Fix: guard against an HTTPError without an attached response.
            if getattr(err,'response',None) is not None:
                err_object=loads(err.response.text)
                print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}')
            else:
                print(f'\nError: {err}')
        except ConnectionError as err:
            print(err)
            exit()

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    def stream(self, messages: list[BaseMessage],json=False)->Generator[str,None,None]:
        """Yield the reply incrementally by consuming the server-sent-events stream."""
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers=self.headers
        temperature=self.temperature
        url=self.base_url or "https://integrate.api.nvidia.com/v1/chat/completions"
        messages=[message.to_dict() for message in messages]
        payload={
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "response_format": {
                "type": "json_object" if json else "text"
            },
            "stream":True,
        }
        try:
            response=requests.post(url=url,json=payload,headers=headers)
            response.raise_for_status()
            chunks=response.iter_lines(decode_unicode=True)
            for chunk in chunks:
                # Each SSE line is prefixed with 'data: '; '[DONE]' terminates the stream.
                chunk=chunk.replace('data: ','')
                if chunk and chunk!='[DONE]':
                    delta=loads(chunk)['choices'][0]['delta']
                    yield delta.get('content','')
        except HTTPError as err:
            err_object=loads(err.response.text)
            print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}')
        except ConnectionError as err:
            print(err)
            exit()
class ChatOpenRouter(BaseInference):
    """Client for OpenRouter's OpenAI-compatible chat completions endpoint
    (https://openrouter.ai/api/v1/chat/completions).

    Unlike the sibling providers in this package, errors here are re-raised
    (so tenacity can retry) instead of calling exit().
    """

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    def invoke(self, messages: list[BaseMessage],json=False,model:BaseModel|None=None) -> AIMessage|ToolMessage|BaseModel:
        """Send `messages` synchronously.

        Args:
            messages: Conversation history (System/Human/AI/Image messages).
            json: When True, parse the reply text as JSON into an AIMessage.
            model: Optional pydantic model; when given, the reply is validated into it.

        Returns:
            AIMessage for text replies, ToolMessage for tool calls, or `model`.
        """
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers=self.headers
        temperature=self.temperature
        url=self.base_url or "https://openrouter.ai/api/v1/chat/completions"
        contents=[]
        for message in messages:
            if isinstance(message,SystemMessage):
                if model:
                    # NOTE(review): mutates the caller's message in place; on a
                    # tenacity retry the schema is injected again — confirm intent.
                    message.content=self.structured(message,model)
                contents.append(message.to_dict())
            elif isinstance(message,(HumanMessage,AIMessage)):
                contents.append(message.to_dict())
            elif isinstance(message,ImageMessage):
                text,image=message.content
                # Fix: Don't wrap in an extra list
                contents.append({
                    'role':'user',
                    'content':[
                        {
                            'type':'text',
                            'text':text
                        },
                        {
                            'type':'image_url',
                            'image_url':{
                                'url':image
                            }
                        }
                    ]
                })

        payload={
            "model": self.model,
            "messages": contents,
            "temperature": temperature,
            "response_format": {
                "type": "json_object" if json or model else "text"
            },
            "stream":False,
        }
        if self.tools:
            # OpenAI-style tool declarations, as expected by OpenRouter.
            payload["tools"]=[{
                'type':'function',
                'function':{
                    'name':tool.name,
                    'description':tool.description,
                    'parameters':tool.schema
                }
            } for tool in self.tools]
        try:
            with Client() as client:
                response=client.post(url=url,json=payload,headers=headers,timeout=None)

                # Check HTTP status first
                response.raise_for_status()

                json_object=response.json()
                # print(json_object)
                if json_object.get('error'):
                    raise HTTPError(json_object['error']['message'])
                message=json_object['choices'][0]['message']
                usage_metadata=json_object['usage']
                input,output,total=usage_metadata['prompt_tokens'],usage_metadata['completion_tokens'],usage_metadata['total_tokens']
                self.tokens=Token(input=input,output=output,total=total)
                if model:
                    return model.model_validate_json(message.get('content'))
                if json:
                    return AIMessage(loads(message.get('content')))
                if message.get('content'):
                    return AIMessage(message.get('content'))
                else:
                    # No content means the model issued a tool call instead.
                    tool_call=message.get('tool_calls')[0]['function']
                    return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['arguments'])
        except HTTPError as err:
            # Fix: Proper error handling for HTTPError
            # (the HTTPError raised above has no attached response, so guard).
            try:
                if hasattr(err, 'response') and err.response is not None:
                    err_object = err.response.json()
                    error_msg = err_object.get("error", {}).get("message", "Unknown API error")
                    status_code = err.response.status_code
                    print(f'\nError: {error_msg}\nStatus Code: {status_code}')
                else:
                    print(f'\nHTTP Error: {str(err)}')
            except Exception as parse_err:
                print(f'\nHTTP Error: {str(err)} (Could not parse error response: {parse_err})')
            raise err # Re-raise instead of exit()
        except ConnectionError as err:
            print(f'\nConnection Error: {err}')
            raise err # Re-raise instead of exit()
        except Exception as err:
            print(f'\nUnexpected Error: {err}')
            raise err # Re-raise instead of exit()

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    async def async_invoke(self, messages: list[BaseMessage],json=False,model:BaseModel=None) -> AIMessage|ToolMessage|BaseModel:
        """Async variant of `invoke`; same conversion, payload and return mapping."""
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers=self.headers
        temperature=self.temperature
        # Fix: Use OpenRouter URL instead of Groq URL
        url=self.base_url or "https://openrouter.ai/api/v1/chat/completions"
        contents=[]
        for message in messages:
            if isinstance(message,SystemMessage):
                if model:
                    message.content=self.structured(message,model)
                contents.append(message.to_dict())
            elif isinstance(message,(HumanMessage,AIMessage)):
                contents.append(message.to_dict())
            elif isinstance(message,ImageMessage):
                text,image=message.content
                # Fix: Don't wrap in an extra list
                contents.append({
                    'role':'user',
                    'content':[
                        {
                            'type':'text',
                            'text':text
                        },
                        {
                            'type':'image_url',
                            'image_url':{
                                'url':image
                            }
                        }
                    ]
                })

        payload={
            "model": self.model,
            "messages": contents,
            "temperature": temperature,
            "response_format": {
                "type": "json_object" if json or model else "text"
            },
            "stream":False,
        }
        if self.tools:
            payload["tools"]=[{
                'type':'function',
                'function':{
                    'name':tool.name,
                    'description':tool.description,
                    'parameters':tool.schema
                }
            } for tool in self.tools]
        try:
            async with AsyncClient() as client:
                response=await client.post(url=url,json=payload,headers=headers,timeout=None)

                # Check HTTP status first
                response.raise_for_status()

                json_object=response.json()
                # print(json_object)
                if json_object.get('error'):
                    raise HTTPError(json_object['error']['message'])
                message=json_object['choices'][0]['message']
                usage_metadata=json_object['usage']
                input,output,total=usage_metadata['prompt_tokens'],usage_metadata['completion_tokens'],usage_metadata['total_tokens']
                self.tokens=Token(input=input,output=output,total=total)
                if model:
                    return model.model_validate_json(message.get('content'))
                if json:
                    return AIMessage(loads(message.get('content')))
                if message.get('content'):
                    return AIMessage(message.get('content'))
                else:
                    tool_call=message.get('tool_calls')[0]['function']
                    return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['arguments'])
        except HTTPError as err:
            # Fix: Proper error handling for HTTPError
            try:
                if hasattr(err, 'response') and err.response is not None:
                    err_object = err.response.json()
                    error_msg = err_object.get("error", {}).get("message", "Unknown API error")
                    status_code = err.response.status_code
                    print(f'\nError: {error_msg}\nStatus Code: {status_code}')
                else:
                    print(f'\nHTTP Error: {str(err)}')
            except Exception as parse_err:
                print(f'\nHTTP Error: {str(err)} (Could not parse error response: {parse_err})')
            raise err # Re-raise instead of exit()
        except ConnectionError as err:
            print(f'\nConnection Error: {err}')
            raise err # Re-raise instead of exit()
        except Exception as err:
            print(f'\nUnexpected Error: {err}')
            raise err # Re-raise instead of exit()

    def stream(self, messages, json = False):
        """Streaming is not implemented for OpenRouter yet."""
        pass
class ChatMistral(BaseInference):
    """Client for the Mistral chat completions API
    (https://api.mistral.ai/v1/chat/completions)."""

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    def invoke(self, messages: list[BaseMessage],json:bool=False,model:BaseModel=None)->AIMessage|ToolMessage|BaseModel:
        """Send `messages` synchronously and map the reply to an AIMessage,
        ToolMessage or a validated pydantic `model`."""
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers=self.headers
        temperature=self.temperature
        url=self.base_url or "https://api.mistral.ai/v1/chat/completions"
        contents=[]
        for message in messages:
            # Message types are mutually exclusive; elif makes that explicit.
            if isinstance(message,SystemMessage):
                if model:
                    # NOTE(review): mutates the caller's message in place; on a
                    # tenacity retry the schema is injected again — confirm intent.
                    message.content=self.structured(message,model)
                contents.append(message.to_dict())
            elif isinstance(message,(HumanMessage,AIMessage)):
                contents.append(message.to_dict())
            elif isinstance(message,ImageMessage):
                text,image_data=message.content
                # Fix: the original built 'content' as a SET literal of dicts
                # ({ {...},{...} }), which raises TypeError (dict is unhashable)
                # at runtime, and additionally wrapped the message in an extra
                # list. 'content' must be a list of content parts.
                contents.append({
                    'role':'user',
                    'content':[
                        {
                            'type':'text',
                            'text':text
                        },
                        {
                            'type':'image_url',
                            'image_url':image_data
                        }
                    ]
                })

        payload={
            "model": self.model,
            "messages": contents,
            "temperature": temperature,
            "response_format": {
                "type": "json_object" if json or model else "text"
            },
            "stream":False,
        }
        if self.tools:
            payload["tools"]=[{
                'type':'function',
                'function':{
                    'name':tool.name,
                    'description':tool.description,
                    'parameters':tool.schema
                }
            } for tool in self.tools]
        try:
            with Client() as client:
                response=client.post(url=url,json=payload,headers=headers,timeout=None)
                json_object=response.json()
                # print(json_object)
                if json_object.get('error'):
                    raise Exception(json_object['error']['message'])
                message=json_object['choices'][0]['message']
                usage_metadata=json_object['usage']
                input,output,total=usage_metadata['prompt_tokens'],usage_metadata['completion_tokens'],usage_metadata['total_tokens']
                self.tokens=Token(input=input,output=output,total=total)
                if model:
                    return model.model_validate_json(message.get('content'))
                if json:
                    return AIMessage(loads(message.get('content')))
                if message.get('content'):
                    return AIMessage(message.get('content'))
                else:
                    tool_call=message.get('tool_calls')[0]['function']
                    return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['arguments'])
        except HTTPError as err:
            # Fix: guard against an HTTPError without an attached response.
            if getattr(err,'response',None) is not None:
                err_object=loads(err.response.text)
                print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}')
            else:
                print(f'\nError: {err}')
        except ConnectionError as err:
            print(err)
            exit()

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    async def async_invoke(self, messages: list[BaseMessage],json:bool=False,model:BaseModel=None)->AIMessage|ToolMessage|BaseModel:
        """Async variant of `invoke`; same conversion, payload and return mapping."""
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers=self.headers
        temperature=self.temperature
        url=self.base_url or "https://api.mistral.ai/v1/chat/completions"
        contents=[]
        for message in messages:
            if isinstance(message,SystemMessage):
                if model:
                    message.content=self.structured(message,model)
                contents.append(message.to_dict())
            elif isinstance(message,(HumanMessage,AIMessage)):
                contents.append(message.to_dict())
            elif isinstance(message,ImageMessage):
                text,image_data=message.content
                # Fix: list of content parts (was a set literal of dicts — TypeError).
                contents.append({
                    'role':'user',
                    'content':[
                        {
                            'type':'text',
                            'text':text
                        },
                        {
                            'type':'image_url',
                            'image_url':image_data
                        }
                    ]
                })

        payload={
            "model": self.model,
            "messages": contents,
            "temperature": temperature,
            "response_format": {
                "type": "json_object" if json or model else "text"
            },
            "stream":False,
        }
        if self.tools:
            payload["tools"]=[{
                'type':'function',
                'function':{
                    'name':tool.name,
                    'description':tool.description,
                    'parameters':tool.schema
                }
            } for tool in self.tools]
        try:
            async with AsyncClient() as client:
                response=await client.post(url=url,json=payload,headers=headers,timeout=None)
                json_object=response.json()
                # print(json_object)
                if json_object.get('error'):
                    raise Exception(json_object['error']['message'])
                message=json_object['choices'][0]['message']
                usage_metadata=json_object['usage']
                input,output,total=usage_metadata['prompt_tokens'],usage_metadata['completion_tokens'],usage_metadata['total_tokens']
                self.tokens=Token(input=input,output=output,total=total)
                if model:
                    return model.model_validate_json(message.get('content'))
                if json:
                    return AIMessage(loads(message.get('content')))
                if message.get('content'):
                    return AIMessage(message.get('content'))
                else:
                    tool_call=message.get('tool_calls')[0]['function']
                    return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['arguments'])
        except HTTPError as err:
            # Fix: guard against an HTTPError without an attached response.
            if getattr(err,'response',None) is not None:
                err_object=loads(err.response.text)
                print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}')
            else:
                print(f'\nError: {err}')
        except ConnectionError as err:
            print(err)
            exit()

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    def stream(self, messages: list[BaseMessage],json=False)->Generator[str,None,None]:
        """Yield the reply incrementally by consuming the server-sent-events stream."""
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers=self.headers
        temperature=self.temperature
        # Fix: the original streamed against the Groq endpoint; this is the
        # Mistral client, so use the Mistral chat completions URL.
        url=self.base_url or "https://api.mistral.ai/v1/chat/completions"
        messages=[message.to_dict() for message in messages]
        payload={
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "stream":True,
        }
        if json:
            payload["response_format"]={
                "type": "json_object"
            }
        try:
            response=requests.post(url=url,json=payload,headers=headers,timeout=None)
            response.raise_for_status()
            chunks=response.iter_lines(decode_unicode=True)
            for chunk in chunks:
                # Each SSE line is prefixed with 'data: '; '[DONE]' terminates the stream.
                chunk=chunk.replace('data: ','')
                if chunk and chunk!='[DONE]':
                    delta=loads(chunk)['choices'][0]['delta']
                    yield delta.get('content','')
        except HTTPError as err:
            err_object=loads(err.response.text)
            print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}')
        except ConnectionError as err:
            print(err)
            exit()

    def available_models(self):
        """Return the ids of all models exposed by the Mistral /v1/models endpoint."""
        url="https://api.mistral.ai/v1/models"
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers=self.headers
        response=requests.get(url=url,headers=headers)
        response.raise_for_status()
        models=response.json()
        return [model['id'] for model in models['data']]
class Context:
    """Owns one Playwright browser context and its per-step state.

    Wraps context/page lifecycle (creation, teardown), DOM/state snapshots,
    tab bookkeeping and element lookup for the web agent. Usable as an async
    context manager.
    """

    def __init__(self,browser:Browser,config:ContextConfig=None):
        # Fix: the original default `config=ContextConfig()` is evaluated once at
        # definition time, so every Context silently shared one config object.
        self.browser=browser
        self.config=config if config is not None else ContextConfig()
        self.context_id=str(uuid4())
        self.session:BrowserSession=None

    async def __aenter__(self):
        await self.init_session()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close_session()

    async def close_session(self):
        """Close the underlying Playwright context (best effort) and drop the session."""
        if self.session is None:
            return None
        try:
            await self.session.context.close()
        except Exception as e:
            print('Context failed to close',e)
        finally:
            # Fix: the original cleared `self.browser_context` (an attribute that is
            # never read), leaving `self.session` pointing at a closed context so
            # get_session() would keep handing it out. Clear the session instead.
            self.session=None

    async def init_session(self):
        """Create the Playwright context and a current page, then record the session."""
        browser=await self.browser.get_playwright_browser()
        context=await self.setup_context(browser)
        if browser is not None: # no user_data_dir: fresh browser, open a new page
            page=await context.new_page()
        else: # user_data_dir given: persistent context, reuse an existing page if any
            pages=context.pages
            if len(pages):
                page=pages[0]
            else:
                page=await context.new_page()
        state=await self.initial_state(page)
        self.session=BrowserSession(context,page,state)

    async def initial_state(self,page:Page):
        """Build a minimal BrowserState for a freshly opened page (no screenshot/DOM yet)."""
        screenshot,dom_state=None,[]
        current_tab=Tab(0,page.url,await page.title(),page)
        tabs=[]
        state=BrowserState(current_tab=current_tab,tabs=tabs,screenshot=screenshot,dom_state=dom_state)
        return state

    async def update_state(self,use_vision:bool=False):
        """Recompute the full BrowserState (DOM snapshot, tabs, optional screenshot)."""
        dom=DOM(self)
        screenshot,dom_state=await dom.get_state(use_vision=use_vision)
        tabs=await self.get_all_tabs()
        current_tab=await self.get_current_tab()
        state=BrowserState(current_tab=current_tab,tabs=tabs,screenshot=screenshot,dom_state=dom_state)
        return state

    async def get_state(self,use_vision=False)->BrowserState:
        """Refresh the session's state and return it."""
        session=await self.get_session()
        state=await self.update_state(use_vision=use_vision)
        session.state=state
        return session.state

    async def get_session(self)->BrowserSession:
        """Return the active session, lazily initializing one if needed."""
        if self.session is None:
            await self.init_session()
        return self.session

    async def get_current_page(self)->Page:
        """Return the session's current page; raise if none exists."""
        session=await self.get_session()
        if session.current_page is None:
            raise ValueError("No current page found")
        return session.current_page

    async def setup_context(self,browser:PlaywrightBrowser|None=None)->PlaywrightContext:
        """Create a Playwright context.

        With a live `browser` (no user_data_dir), a normal context is created and
        the stealth init script injected. With `browser` None, a persistent
        context is launched directly from the configured browser channel.
        """
        if self.browser.config.device is not None:
            # Device emulation: copy the device descriptor, dropping the key
            # Playwright does not accept as a context option.
            parameters={**self.browser.playwright.devices.get(self.browser.config.device)}
            parameters.pop('default_browser_type',None)
        else:
            parameters={
                'ignore_https_errors':self.config.disable_security,
                'user_agent':self.config.user_agent,
                'bypass_csp':self.config.disable_security,
                'java_script_enabled':True,
                'accept_downloads':True,
                'no_viewport':True
            }
        if browser is not None:
            context=await browser.new_context(**parameters)
            with open('./src/agent/web/context/script.js') as f:
                script=f.read()
            await context.add_init_script(script)
        else:
            args=['--no-sandbox','--disable-blink-features=AutomationControlled','--disable-blink-features=IdleDetection','--no-infobars']
            parameters=parameters|{
                'headless':self.browser.config.headless,
                'slow_mo':self.browser.config.slow_mo,
                'ignore_default_args': IGNORE_DEFAULT_ARGS,
                'args': args+SECURITY_ARGS,
                'user_data_dir': self.browser.config.user_data_dir,
                'downloads_path': self.browser.config.downloads_dir,
                'executable_path': self.browser.config.browser_instance_dir,
            }
            # browser is None exactly when user_data_dir was provided to Browser.
            browser=self.browser.config.browser
            if browser=='chrome':
                context=await self.browser.playwright.chromium.launch_persistent_context(channel='chrome',**parameters)
            elif browser=='firefox':
                context=await self.browser.playwright.firefox.launch_persistent_context(**parameters)
            elif browser=='edge':
                context=await self.browser.playwright.chromium.launch_persistent_context(channel='msedge',**parameters)
            else:
                raise Exception('Invalid Browser Type')
        return context

    async def get_all_tabs(self)->list[Tab]:
        """Enumerate all open pages as Tab records, skipping pages that fail to load."""
        session=await self.get_session()
        pages=session.context.pages
        tabs:list[Tab]=[]
        for id,page in enumerate(pages):
            await page.wait_for_load_state('domcontentloaded')
            try:
                url=page.url
                title=await page.title()
            except Exception as e:
                print(f'Tab failed to load: {e}')
                continue
            tabs.append(Tab(id=id,url=url,title=title,page=page))
        return tabs

    async def get_current_tab(self)->Tab:
        """Return the Tab wrapping the current page (None if it is not listed)."""
        tabs=await self.get_all_tabs()
        current_page=await self.get_current_page()
        return next((tab for tab in tabs if tab.page==current_page),None)

    async def get_selector_map(self)->dict[int,DOMElementNode]:
        """Return the index -> DOM element map from the last DOM snapshot."""
        session=await self.get_session()
        return session.state.dom_state.selector_map

    async def get_element_by_index(self,index:int)->DOMElementNode:
        """Look up a DOM element by its selector-map index; raise if absent."""
        selector_map=await self.get_selector_map()
        if index not in selector_map.keys():
            raise Exception(f'Element under index {index} not found')
        element=selector_map.get(index)
        return element

    async def get_handle_by_xpath(self,xpath:dict[str,str])->ElementHandle:
        """Resolve an ElementHandle from a {frame_xpath, element_xpath} pair."""
        frame=await self.get_frame_by_xpath(xpath)
        _,element_xpath=xpath.values()
        element=await frame.locator(f'xpath={element_xpath}').element_handle()
        return element

    async def get_frame_by_xpath(self,xpath:dict[str,str])->Frame:
        """Return the frame addressed by the xpath pair (main frame when no frame xpath)."""
        page=await self.get_current_page()
        frame_xpath,_=xpath.values()
        if frame_xpath: # handle elements from iframe
            frame=page.frame_locator(f'xpath={frame_xpath}')
        else: # handle elements from main frame
            frame=page.main_frame
        return frame

    async def execute_script(self,obj:Frame|Page,script:str,args:list=None,enable_handle:bool=False):
        """Evaluate JS in a page or frame; return an element handle when requested."""
        if enable_handle:
            handle=await obj.evaluate_handle(script,args)
            return handle.as_element()
        return await obj.evaluate(script,args)

    async def get_viewport(self)->tuple[int,int]:
        """Return the current page's (width, height) in CSS pixels."""
        page=await self.get_current_page()
        viewport:dict=await self.execute_script(page,'({width: window.innerWidth, height: window.innerHeight})')
        return(viewport.get('width'),viewport.get('height'))

    def is_ad_url(self,url:str)->bool:
        """True when the URL's host matches a known ad/tracker pattern (or has no host)."""
        url_pattern=urlparse(url).netloc
        if not url_pattern:
            return True
        return any(pattern in url_pattern for pattern in IGNORED_URL_PATTERNS)

    async def is_frame_visible(self,frame:Frame)->bool:
        """Heuristically decide whether an iframe is visible and worth inspecting.

        Rejects detached/ad frames, display:none / visibility:hidden styles,
        frames without a bounding box, and frames that are off-screen or tiny.
        """
        if frame.is_detached() or self.is_ad_url(frame.url):
            return False
        frame_element=await frame.frame_element()
        if frame_element is None:
            return False
        style=await frame_element.get_attribute('style')
        if style is not None:
            css:dict=self.inline_style_parser(style)
            if any([css.get('display')=='none',css.get('visibility')=='hidden']):
                return False
        bbox=await frame_element.bounding_box()
        if bbox is None:
            return False
        area=bbox.get('width')*bbox.get('height')
        if any([bbox.get('x')<0,bbox.get('y')<0,area<10]):
            return False
        return True

    async def get_screenshot(self,save_screenshot:bool=False,full_page:bool=False):
        """Take a JPEG screenshot of the current page, optionally saving it to ./screenshots."""
        page=await self.get_current_page()
        if save_screenshot:
            date_time=datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
            folder_path=Path(getcwd()).joinpath('./screenshots')
            folder_path.mkdir(parents=True,exist_ok=True)
            path=folder_path.joinpath(f'screenshot_{date_time}.jpeg')
        else:
            path=None
        # Give the page a moment to settle before capturing.
        await page.wait_for_timeout(2*1000)
        screenshot=await page.screenshot(path=path,full_page=full_page,animations='disabled',type='jpeg')
        return screenshot

    def inline_style_parser(self,style:str)->dict[str,str]:
        """Parse an inline CSS style string into a {property: value} dict."""
        styles = {}
        if not style:
            return styles
        for rule in style.split(";"):
            if ":" in rule:
                prop, val = rule.split(":", 1)
                styles[prop.strip()] = val.strip()
        return styles
-------------------------------------------------------------------------------- 1 | from tenacity import retry,stop_after_attempt,retry_if_exception_type 2 | from src.message import AIMessage,BaseMessage,ToolMessage 3 | from requests import get,RequestException,ConnectionError 4 | from httpx import Client,AsyncClient,HTTPError 5 | from ratelimit import limits,sleep_and_retry 6 | from src.inference import BaseInference,Token 7 | from pydantic import BaseModel 8 | from typing import Generator 9 | from json import loads 10 | from uuid import uuid4 11 | 12 | class ChatOllama(BaseInference): 13 | @sleep_and_retry 14 | @limits(calls=15,period=60) 15 | @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException)) 16 | def invoke(self,messages: list[BaseMessage],json=False,model:BaseModel=None)->AIMessage: 17 | headers=self.headers 18 | temperature=self.temperature 19 | url=self.base_url or "http://localhost:11434/api/chat" 20 | payload={ 21 | "model": self.model, 22 | "messages": [message.to_dict() for message in messages], 23 | "options":{ 24 | "temperature": temperature, 25 | }, 26 | "stream":False 27 | } 28 | if json: 29 | payload['format']='json' 30 | if model: 31 | payload['format']=model.model_json_schema() 32 | if self.tools: 33 | payload["tools"]=[{ 34 | 'type':'function', 35 | 'function':{ 36 | 'name':tool.name, 37 | 'description':tool.description, 38 | 'parameters':tool.schema 39 | } 40 | } for tool in self.tools] 41 | try: 42 | with Client() as client: 43 | response=client.post(url=url,json=payload,headers=headers,timeout=None) 44 | response.raise_for_status() 45 | json_object=response.json() 46 | message=json_object['message'] 47 | input,output,total=json_object['prompt_eval_count'],json_object['eval_count'],json_object['prompt_eval_count']+json_object['eval_count'] 48 | self.tokens=Token(input=input,output=output,total=total) 49 | if model: 50 | return model.model_validate_json(message.get('content')) 51 | if json: 52 | return 
AIMessage(loads(message.get('content'))) 53 | if message.get('content'): 54 | return AIMessage(message.get('content')) 55 | else: 56 | tool_call=message.get('tool_calls')[0]['function'] 57 | return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['arguments']) 58 | except HTTPError as err: 59 | print(f'Error: {err.response.text}, Status Code: {err.response.status_code}') 60 | 61 | @sleep_and_retry 62 | @limits(calls=15,period=60) 63 | @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException)) 64 | async def async_invoke(self,messages: list[BaseMessage],json=False,model:BaseModel=None)->AIMessage: 65 | headers=self.headers 66 | temperature=self.temperature 67 | url=self.base_url or "http://localhost:11434/api/chat" 68 | payload={ 69 | "model": self.model, 70 | "messages": [message.to_dict() for message in messages], 71 | "options":{ 72 | "temperature": temperature, 73 | }, 74 | "stream":False 75 | } 76 | if json: 77 | payload['format']='json' 78 | if model: 79 | payload['format']=model.model_json_schema() 80 | if self.tools: 81 | payload["tools"]=[{ 82 | 'type':'function', 83 | 'function':{ 84 | 'name':tool.name, 85 | 'description':tool.description, 86 | 'parameters':tool.schema 87 | } 88 | } for tool in self.tools] 89 | try: 90 | async with AsyncClient() as client: 91 | response=await client.post(url=url,json=payload,headers=headers,timeout=None) 92 | response.raise_for_status() 93 | json_object=response.json() 94 | message=json_object['message'] 95 | input,output,total=json_object['prompt_eval_count'],json_object['eval_count'],json_object['prompt_eval_count']+json_object['eval_count'] 96 | self.tokens=Token(input=input,output=output,total=total) 97 | if model: 98 | return model.model_validate_json(message.get('content')) 99 | if json: 100 | return AIMessage(loads(message.get('content'))) 101 | if message.get('content'): 102 | return AIMessage(message.get('content')) 103 | else: 104 | 
tool_call=message.get('tool_calls')[0]['function'] 105 | return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['arguments']) 106 | except HTTPError as err: 107 | print(f'Error: {err.response.text}, Status Code: {err.response.status_code}') 108 | 109 | @sleep_and_retry 110 | @limits(calls=15,period=60) 111 | @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException)) 112 | def stream(self,messages: list[BaseMessage],json=False)->Generator[str,None,None]: 113 | headers=self.headers 114 | temperature=self.temperature 115 | url=self.base_url or "http://localhost:11434/api/chat" 116 | payload={ 117 | "model": self.model, 118 | "messages": [message.to_dict() for message in messages], 119 | "options":{ 120 | "temperature": temperature, 121 | }, 122 | "stream":True 123 | } 124 | if json: 125 | payload['format']='json' 126 | try: 127 | with Client() as client: 128 | response=client.post(url=url,json=payload,headers=headers,stream=True,timeout=None) 129 | response.raise_for_status() 130 | chunks=response.iter_lines(decode_unicode=True) 131 | return (loads(chunk)['message']['content'] for chunk in chunks) 132 | except HTTPError as err: 133 | print(f'Error: {err.response.text}, Status Code: {err.response.status_code}') 134 | except ConnectionError as err: 135 | print(err) 136 | exit() 137 | 138 | def available_models(self): 139 | url='http://localhost:11434/api/tags' 140 | headers=self.headers 141 | response=get(url=url,headers=headers) 142 | response.raise_for_status() 143 | models=response.json() 144 | return [model['name'] for model in models['models']] 145 | 146 | class Ollama(BaseInference): 147 | @sleep_and_retry 148 | @limits(calls=15,period=60) 149 | @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException)) 150 | def invoke(self, query:str,json=False,model:BaseModel=None)->AIMessage: 151 | headers=self.headers 152 | temperature=self.temperature 153 | url=self.base_url or 
"http://localhost:11434/api/generate" 154 | payload={ 155 | "model": self.model, 156 | "prompt": query, 157 | "options":{ 158 | "temperature": temperature, 159 | }, 160 | "format":'json' if json else '', 161 | "stream":False 162 | } 163 | if json: 164 | payload['format']='json' 165 | if model: 166 | payload['format']=model.model_json_schema() 167 | try: 168 | with Client() as client: 169 | response=client.post(url=url,json=payload,headers=headers) 170 | response.raise_for_status() 171 | json_object=response.json() 172 | input,output,total=json_object['prompt_eval_count'],json_object['eval_count'],json_object['prompt_eval_count']+json_object['eval_count'] 173 | self.tokens=Token(input=input,output=output,total=total) 174 | if model: 175 | return model.model_validate_json(json_object.get('response')) 176 | if json: 177 | return AIMessage(loads(json_object.get('response'))) 178 | return AIMessage(json_object.get('response')) 179 | except HTTPError as err: 180 | print(f'Error: {err.response.text}, Status Code: {err.response.status_code}') 181 | 182 | @sleep_and_retry 183 | @limits(calls=15,period=60) 184 | @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException)) 185 | async def async_invoke(self, query:str,json=False,model:BaseModel=None)->AIMessage: 186 | headers=self.headers 187 | temperature=self.temperature 188 | url=self.base_url or "http://localhost:11434/api/generate" 189 | payload={ 190 | "model": self.model, 191 | "prompt": query, 192 | "options":{ 193 | "temperature": temperature, 194 | }, 195 | "format":'json' if json else '', 196 | "stream":False 197 | } 198 | if json: 199 | payload['format']='json' 200 | if model: 201 | payload['format']=model.model_json_schema() 202 | try: 203 | async with AsyncClient() as client: 204 | response=await client.post(url=url,json=payload,headers=headers) 205 | response.raise_for_status() 206 | json_object=response.json() 207 | 
input,output,total=json_object['prompt_eval_count'],json_object['eval_count'],json_object['prompt_eval_count']+json_object['eval_count'] 208 | self.tokens=Token(input=input,output=output,total=total) 209 | if model: 210 | return model.model_validate_json(json_object.get('response')) 211 | if json: 212 | return AIMessage(loads(json_object.get('response'))) 213 | return AIMessage(json_object.get('response')) 214 | except HTTPError as err: 215 | print(f'Error: {err.response.text}, Status Code: {err.response.status_code}') 216 | 217 | @sleep_and_retry 218 | @limits(calls=15,period=60) 219 | @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException)) 220 | def stream(self,query:str,json=False)->Generator[str,None,None]: 221 | headers=self.headers 222 | temperature=self.temperature 223 | url=self.base_url or "http://localhost:11434/api/generate" 224 | payload={ 225 | "model": self.model, 226 | "prompt": query, 227 | "options":{ 228 | "temperature": temperature, 229 | }, 230 | "format":'json' if json else '', 231 | "stream":True 232 | } 233 | try: 234 | with Client() as client: 235 | response=client.post(url=url,json=payload,headers=headers,stream=True) 236 | response.raise_for_status() 237 | chunks=response.iter_lines(decode_unicode=True) 238 | return (loads(chunk)['response'] for chunk in chunks) 239 | except HTTPError as err: 240 | print(f'Error: {err.response.text}, Status Code: {err.response.status_code}') 241 | except ConnectionError as err: 242 | print(err) 243 | exit() 244 | 245 | def available_models(self): 246 | url='http://localhost:11434/api/tags' 247 | headers=self.headers 248 | response=get(url=url,headers=headers) 249 | response.raise_for_status() 250 | models=response.json() 251 | return [model['name'] for model in models['models']] 252 | -------------------------------------------------------------------------------- /src/inference/openai.py: -------------------------------------------------------------------------------- 1 | from 
src.message import AIMessage,BaseMessage,SystemMessage,ImageMessage,HumanMessage,ToolMessage 2 | from tenacity import retry,stop_after_attempt,retry_if_exception_type 3 | from requests import RequestException,HTTPError,ConnectionError 4 | from ratelimit import limits,sleep_and_retry 5 | from httpx import Client,AsyncClient 6 | from src.inference import BaseInference,Token 7 | from pydantic import BaseModel 8 | from typing import Generator 9 | from typing import Literal 10 | from pathlib import Path 11 | from json import loads 12 | from uuid import uuid4 13 | import mimetypes 14 | import requests 15 | 16 | class ChatOpenAI(BaseInference): 17 | @sleep_and_retry 18 | @limits(calls=15,period=60) 19 | @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException)) 20 | def invoke(self, messages: list[BaseMessage],json:bool=False,model:BaseModel=None)->AIMessage|ToolMessage|BaseModel: 21 | self.headers.update({'Authorization': f'Bearer {self.api_key}'}) 22 | headers=self.headers 23 | temperature=self.temperature 24 | url=self.base_url or "https://api.openai.com/v1/chat/completions" 25 | contents=[] 26 | for message in messages: 27 | if isinstance(message,SystemMessage): 28 | if model: 29 | message.content=self.structured(message,model) 30 | contents.append(message.to_dict()) 31 | if isinstance(message,(HumanMessage,AIMessage)): 32 | contents.append(message.to_dict()) 33 | if isinstance(message,ImageMessage): 34 | text,image=message.content 35 | contents.append([ 36 | { 37 | 'role':'user', 38 | 'content':[ 39 | { 40 | 'type':'text', 41 | 'text':text 42 | }, 43 | { 44 | 'type':'image_url', 45 | 'image_url':{ 46 | 'url':image 47 | } 48 | } 49 | ] 50 | } 51 | ]) 52 | 53 | payload={ 54 | "model": self.model, 55 | "messages": contents, 56 | "temperature": temperature, 57 | "response_format": { 58 | "type": "json_object" if json or model else "text" 59 | }, 60 | "stream":False, 61 | } 62 | if self.tools: 63 | payload["tools"]=[{ 64 | 'type':'function', 65 | 
'function':{ 66 | 'name':tool.name, 67 | 'description':tool.description, 68 | 'parameters':tool.schema 69 | } 70 | } for tool in self.tools] 71 | try: 72 | with Client() as client: 73 | response=client.post(url=url,json=payload,headers=headers,timeout=None) 74 | json_object=response.json() 75 | # print(json_object) 76 | if json_object.get('error'): 77 | raise HTTPError(json_object['error']['message']) 78 | message=json_object['choices'][0]['message'] 79 | usage_metadata=json_object['usage'] 80 | input,output,total=usage_metadata['prompt_tokens'],usage_metadata['completion_tokens'],usage_metadata['total_tokens'] 81 | self.tokens=Token(input=input,output=output,total=total) 82 | if model: 83 | return model.model_validate_json(message.get('content')) 84 | if json: 85 | return AIMessage(loads(message.get('content'))) 86 | if message.get('content'): 87 | return AIMessage(message.get('content')) 88 | else: 89 | tool_call=message.get('tool_calls')[0]['function'] 90 | return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['arguments']) 91 | except HTTPError as err: 92 | err_object=loads(err.response.text) 93 | print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}') 94 | except ConnectionError as err: 95 | print(err) 96 | exit() 97 | 98 | @sleep_and_retry 99 | @limits(calls=15,period=60) 100 | @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException)) 101 | async def async_invoke(self, messages: list[BaseMessage],json=False,model:BaseModel=None) -> AIMessage|ToolMessage|BaseModel: 102 | self.headers.update({'Authorization': f'Bearer {self.api_key}'}) 103 | headers=self.headers 104 | temperature=self.temperature 105 | url=self.base_url or "https://api.openai.com/v1/chat/completions" 106 | contents=[] 107 | for message in messages: 108 | if isinstance(message,SystemMessage): 109 | if model: 110 | message.content=self.structured(message,model) 111 | contents.append(message.to_dict()) 112 | if 
isinstance(message,(HumanMessage,AIMessage)): 113 | contents.append(message.to_dict()) 114 | if isinstance(message,ImageMessage): 115 | text,image=message.content 116 | contents.append([ 117 | { 118 | 'role':'user', 119 | 'content':[ 120 | { 121 | 'type':'text', 122 | 'text':text 123 | }, 124 | { 125 | 'type':'image_url', 126 | 'image_url':{ 127 | 'url':image 128 | } 129 | } 130 | ] 131 | } 132 | ]) 133 | 134 | payload={ 135 | "model": self.model, 136 | "messages": contents, 137 | "temperature": temperature, 138 | "response_format": { 139 | "type": "json_object" if json or model else "text" 140 | }, 141 | "stream":False, 142 | } 143 | if self.tools: 144 | payload["tools"]=[{ 145 | 'type':'function', 146 | 'function':{ 147 | 'name':tool.name, 148 | 'description':tool.description, 149 | 'parameters':tool.schema 150 | } 151 | } for tool in self.tools] 152 | try: 153 | async with AsyncClient() as client: 154 | response=await client.post(url=url,json=payload,headers=headers,timeout=None) 155 | json_object=response.json() 156 | # print(json_object) 157 | if json_object.get('error'): 158 | raise HTTPError(json_object['error']['message']) 159 | message=json_object['choices'][0]['message'] 160 | usage_metadata=json_object['usage'] 161 | input,output,total=usage_metadata['prompt_tokens'],usage_metadata['completion_tokens'],usage_metadata['total_tokens'] 162 | self.tokens=Token(input=input,output=output,total=total) 163 | if model: 164 | return model.model_validate_json(message.get('content')) 165 | if json: 166 | return AIMessage(loads(message.get('content'))) 167 | if message.get('content'): 168 | return AIMessage(message.get('content')) 169 | else: 170 | tool_call=message.get('tool_calls')[0]['function'] 171 | return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['arguments']) 172 | except HTTPError as err: 173 | err_object=loads(err.response.text) 174 | print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}') 175 | 
except ConnectionError as err: 176 | print(err) 177 | exit() 178 | 179 | @sleep_and_retry 180 | @limits(calls=15,period=60) 181 | @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException)) 182 | def stream(self, messages: list[BaseMessage],json=False)->Generator[str,None,None]: 183 | self.headers.update({'Authorization': f'Bearer {self.api_key}'}) 184 | headers=self.headers 185 | temperature=self.temperature 186 | url=self.base_url or "https://api.openai.com/v1/chat/completions" 187 | messages=[message.to_dict() for message in messages] 188 | payload={ 189 | "model": self.model, 190 | "messages": messages, 191 | "temperature": temperature, 192 | "response_format": { 193 | "type": "json_object" if json else "text" 194 | }, 195 | "stream":True, 196 | } 197 | try: 198 | response=requests.post(url=url,json=payload,headers=headers) 199 | response.raise_for_status() 200 | chunks=response.iter_lines(decode_unicode=True) 201 | for chunk in chunks: 202 | chunk=chunk.replace('data: ','') 203 | if chunk and chunk!='[DONE]': 204 | delta=loads(chunk)['choices'][0]['delta'] 205 | yield delta.get('content','') 206 | except HTTPError as err: 207 | err_object=loads(err.response.text) 208 | print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}') 209 | except ConnectionError as err: 210 | print(err) 211 | exit() 212 | 213 | def available_models(self): 214 | url='https://api.openai.com/v1/models' 215 | self.headers.update({'Authorization': f'Bearer {self.api_key}'}) 216 | headers=self.headers 217 | response=requests.get(url=url,headers=headers) 218 | response.raise_for_status() 219 | models=response.json() 220 | return [model['id'] for model in models['data'] if model['active']] 221 | 222 | class AudioOpenAI(BaseInference): 223 | def __init__(self,mode:Literal['transcriptions','translations']='transcriptions', model: str = '', api_key: str = '', base_url: str = '', temperature: float = 0.5): 224 | self.mode=mode 225 | 
super().__init__(model, api_key, base_url, temperature) 226 | 227 | def invoke(self,file_path:str='', language:str='en', json:bool=False)->AIMessage: 228 | path=Path(file_path) 229 | headers={'Authorization': f'Bearer {self.api_key}'} 230 | url=self.base_url or f"https://api.openai.com/v1/audio/{self.mode}" 231 | data={ 232 | "model": self.model, 233 | "temperature": self.temperature, 234 | "response_format": "json_object" if json else "text" 235 | } 236 | if self.mode=='transcriptions': 237 | data['language']=language 238 | # Get the MIME type for the file 239 | mime_type, _ = mimetypes.guess_type(path.name) 240 | files={ 241 | 'file': (path.name,self.__read_audio(path),mime_type) 242 | } 243 | try: 244 | with Client() as client: 245 | response=client.post(url=url,data=data,files=files,headers=headers,timeout=None) 246 | response.raise_for_status() 247 | if json: 248 | content=loads(response.text)['text'] 249 | else: 250 | content=response.text 251 | return AIMessage(content) 252 | except HTTPError as err: 253 | err_object=loads(err.response.text) 254 | print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}') 255 | except ConnectionError as err: 256 | print(err) 257 | exit() 258 | 259 | def __read_audio(self,file_path:str): 260 | with open(file_path,'rb') as f: 261 | audio_data=f.read() 262 | return audio_data 263 | 264 | def async_invoke(self, messages:BaseMessage=[]): 265 | pass 266 | 267 | def stream(self, messages:BaseMessage=[]): 268 | pass 269 | 270 | def available_models(self): 271 | url='https://api.groq.com/openai/v1/models' 272 | self.headers.update({'Authorization': f'Bearer {self.api_key}'}) 273 | headers=self.headers 274 | response=requests.get(url=url,headers=headers) 275 | response.raise_for_status() 276 | models=response.json() 277 | return [model['id'] for model in models['data'] if model['active']] 278 | -------------------------------------------------------------------------------- /src/inference/groq.py: 
-------------------------------------------------------------------------------- 1 | from src.message import AIMessage,BaseMessage,SystemMessage,ImageMessage,HumanMessage,ToolMessage 2 | from tenacity import retry,stop_after_attempt,retry_if_exception_type 3 | from requests import RequestException,HTTPError,ConnectionError 4 | from ratelimit import limits,sleep_and_retry 5 | from httpx import Client,AsyncClient 6 | from src.inference import BaseInference,Token 7 | from pydantic import BaseModel 8 | from typing import Generator 9 | from typing import Literal 10 | from pathlib import Path 11 | from json import loads 12 | from uuid import uuid4 13 | import mimetypes 14 | import requests 15 | 16 | class ChatGroq(BaseInference): 17 | @sleep_and_retry 18 | @limits(calls=15,period=60) 19 | @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException)) 20 | def invoke(self, messages: list[BaseMessage],json:bool=False,model:BaseModel=None)->AIMessage|ToolMessage|BaseModel: 21 | self.headers.update({'Authorization': f'Bearer {self.api_key}'}) 22 | headers=self.headers 23 | temperature=self.temperature 24 | url=self.base_url or "https://api.groq.com/openai/v1/chat/completions" 25 | contents=[] 26 | for message in messages: 27 | if isinstance(message,SystemMessage): 28 | if model: 29 | message.content=self.structured(message,model) 30 | contents.append(message.to_dict()) 31 | if isinstance(message,(HumanMessage,AIMessage)): 32 | contents.append(message.to_dict()) 33 | if isinstance(message,ImageMessage): 34 | text,image=message.content 35 | contents.append([ 36 | { 37 | 'role':'user', 38 | 'content':[ 39 | { 40 | 'type':'text', 41 | 'text':text 42 | }, 43 | { 44 | 'type':'image_url', 45 | 'image_url':{ 46 | 'url':image 47 | } 48 | } 49 | ] 50 | } 51 | ]) 52 | 53 | payload={ 54 | "model": self.model, 55 | "messages": contents, 56 | "temperature": temperature, 57 | "response_format": { 58 | "type": "json_object" if json or model else "text" 59 | }, 60 | 
"stream":False, 61 | } 62 | if self.tools: 63 | payload["tools"]=[{ 64 | 'type':'function', 65 | 'function':{ 66 | 'name':tool.name, 67 | 'description':tool.description, 68 | 'parameters':tool.schema 69 | } 70 | } for tool in self.tools] 71 | try: 72 | with Client() as client: 73 | response=client.post(url=url,json=payload,headers=headers,timeout=None) 74 | json_object=response.json() 75 | # print(json_object) 76 | if json_object.get('error'): 77 | raise HTTPError(json_object['error']['message']) 78 | message=json_object['choices'][0]['message'] 79 | usage_metadata=json_object['usage'] 80 | input,output,total=usage_metadata['prompt_tokens'],usage_metadata['completion_tokens'],usage_metadata['total_tokens'] 81 | self.tokens=Token(input=input,output=output,total=total) 82 | if model: 83 | return model.model_validate_json(message.get('content')) 84 | if json: 85 | return AIMessage(loads(message.get('content'))) 86 | if message.get('content'): 87 | return AIMessage(message.get('content')) 88 | else: 89 | tool_call=message.get('tool_calls')[0]['function'] 90 | return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['arguments']) 91 | except HTTPError as err: 92 | err_object=loads(err.response.text) 93 | print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}') 94 | except ConnectionError as err: 95 | print(err) 96 | exit() 97 | 98 | @sleep_and_retry 99 | @limits(calls=15,period=60) 100 | @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException)) 101 | async def async_invoke(self, messages: list[BaseMessage],json=False,model:BaseModel=None) -> AIMessage|ToolMessage|BaseModel: 102 | self.headers.update({'Authorization': f'Bearer {self.api_key}'}) 103 | headers=self.headers 104 | temperature=self.temperature 105 | url=self.base_url or "https://api.groq.com/openai/v1/chat/completions" 106 | contents=[] 107 | for message in messages: 108 | if isinstance(message,SystemMessage): 109 | if model: 110 | 
message.content=self.structured(message,model) 111 | contents.append(message.to_dict()) 112 | if isinstance(message,(HumanMessage,AIMessage)): 113 | contents.append(message.to_dict()) 114 | if isinstance(message,ImageMessage): 115 | text,image=message.content 116 | contents.append([ 117 | { 118 | 'role':'user', 119 | 'content':[ 120 | { 121 | 'type':'text', 122 | 'text':text 123 | }, 124 | { 125 | 'type':'image_url', 126 | 'image_url':{ 127 | 'url':image 128 | } 129 | } 130 | ] 131 | } 132 | ]) 133 | 134 | payload={ 135 | "model": self.model, 136 | "messages": contents, 137 | "temperature": temperature, 138 | "response_format": { 139 | "type": "json_object" if json or model else "text" 140 | }, 141 | "stream":False, 142 | } 143 | if self.tools: 144 | payload["tools"]=[{ 145 | 'type':'function', 146 | 'function':{ 147 | 'name':tool.name, 148 | 'description':tool.description, 149 | 'parameters':tool.schema 150 | } 151 | } for tool in self.tools] 152 | try: 153 | async with AsyncClient() as client: 154 | response=await client.post(url=url,json=payload,headers=headers,timeout=None) 155 | json_object=response.json() 156 | # print(json_object) 157 | if json_object.get('error'): 158 | raise HTTPError(json_object['error']['message']) 159 | message=json_object['choices'][0]['message'] 160 | usage_metadata=json_object['usage'] 161 | input,output,total=usage_metadata['prompt_tokens'],usage_metadata['completion_tokens'],usage_metadata['total_tokens'] 162 | self.tokens=Token(input=input,output=output,total=total) 163 | if model: 164 | return model.model_validate_json(message.get('content')) 165 | if json: 166 | return AIMessage(loads(message.get('content'))) 167 | if message.get('content'): 168 | return AIMessage(message.get('content')) 169 | else: 170 | tool_call=message.get('tool_calls')[0]['function'] 171 | return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['arguments']) 172 | except HTTPError as err: 173 | err_object=loads(err.response.text) 174 | 
print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}') 175 | except ConnectionError as err: 176 | print(err) 177 | exit() 178 | 179 | @sleep_and_retry 180 | @limits(calls=15,period=60) 181 | @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException)) 182 | def stream(self, messages: list[BaseMessage],json=False)->Generator[str,None,None]: 183 | self.headers.update({'Authorization': f'Bearer {self.api_key}'}) 184 | headers=self.headers 185 | temperature=self.temperature 186 | url=self.base_url or "https://api.groq.com/openai/v1/chat/completions" 187 | messages=[message.to_dict() for message in messages] 188 | payload={ 189 | "model": self.model, 190 | "messages": messages, 191 | "temperature": temperature, 192 | "response_format": { 193 | "type": "json_object" if json else "text" 194 | }, 195 | "stream":True, 196 | } 197 | try: 198 | response=requests.post(url=url,json=payload,headers=headers) 199 | response.raise_for_status() 200 | chunks=response.iter_lines(decode_unicode=True) 201 | for chunk in chunks: 202 | chunk=chunk.replace('data: ','') 203 | if chunk and chunk!='[DONE]': 204 | delta=loads(chunk)['choices'][0]['delta'] 205 | yield delta.get('content','') 206 | except HTTPError as err: 207 | err_object=loads(err.response.text) 208 | print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}') 209 | except ConnectionError as err: 210 | print(err) 211 | exit() 212 | 213 | def available_models(self): 214 | url='https://api.groq.com/openai/v1/models' 215 | self.headers.update({'Authorization': f'Bearer {self.api_key}'}) 216 | headers=self.headers 217 | response=requests.get(url=url,headers=headers) 218 | response.raise_for_status() 219 | models=response.json() 220 | return [model['id'] for model in models['data'] if model['active']] 221 | 222 | class AudioGroq(BaseInference): 223 | def __init__(self,mode:Literal['transcriptions','translations']='transcriptions', model: 
str = '', api_key: str = '', base_url: str = '', temperature: float = 0.5): 224 | self.mode=mode 225 | super().__init__(model, api_key, base_url, temperature) 226 | 227 | def invoke(self,file_path:str='', language:str='en', json:bool=False)->AIMessage: 228 | path=Path(file_path) 229 | headers={'Authorization': f'Bearer {self.api_key}'} 230 | url=self.base_url or f"https://api.groq.com/openai/v1/audio/{self.mode}" 231 | data={ 232 | "model": self.model, 233 | "temperature": self.temperature, 234 | "response_format": "json_object" if json else "text" 235 | } 236 | if self.mode=='transcriptions': 237 | data['language']=language 238 | # Get the MIME type for the file 239 | mime_type, _ = mimetypes.guess_type(path.name) 240 | files={ 241 | 'file': (path.name,self.__read_audio(path),mime_type) 242 | } 243 | try: 244 | with Client() as client: 245 | response=client.post(url=url,data=data,files=files,headers=headers,timeout=None) 246 | response.raise_for_status() 247 | if json: 248 | content=loads(response.text)['text'] 249 | else: 250 | content=response.text 251 | return AIMessage(content) 252 | except HTTPError as err: 253 | err_object=loads(err.response.text) 254 | print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}') 255 | except ConnectionError as err: 256 | print(err) 257 | exit() 258 | 259 | def __read_audio(self,file_path:str): 260 | with open(file_path,'rb') as f: 261 | audio_data=f.read() 262 | return audio_data 263 | 264 | def async_invoke(self, messages:BaseMessage=[]): 265 | pass 266 | 267 | def stream(self, messages:BaseMessage=[]): 268 | pass 269 | 270 | def available_models(self): 271 | url='https://api.groq.com/openai/v1/models' 272 | self.headers.update({'Authorization': f'Bearer {self.api_key}'}) 273 | headers=self.headers 274 | response=requests.get(url=url,headers=headers) 275 | response.raise_for_status() 276 | models=response.json() 277 | return [model['id'] for model in models['data'] if model['active']] 278 | 
@Tool('Click Tool',params=Click)
async def click_tool(index:int,context:Context=None):
    '''Clicks on interactive elements like buttons, links, checkboxes, radio buttons, tabs, or any clickable UI component. Automatically scrolls the element into view if needed and handles hidden elements.'''
    page=await context.get_current_page()
    await page.wait_for_load_state('load')
    element=await context.get_element_by_index(index=index)
    handle=await context.get_handle_by_xpath(element.xpath)
    # Bug fix: the original reported 'Clicked ...' even when the element was
    # hidden and no click was performed, misleading the agent's next step.
    if await handle.is_hidden():
        return f'Element at label {index} is hidden, click was not performed'
    await handle.scroll_into_view_if_needed()
    await handle.click(force=True)
    return f'Clicked on the element at label {index}'
@Tool('Wait Tool',params=Wait)
async def wait_tool(time:int,context:Context=None):
    '''Pauses execution for a specified number of seconds. Use this to wait for page loading, animations to complete, or content to appear after an action.'''
    await sleep(time)
    return f'Waited for {time}s'

@Tool('Scroll Tool',params=Scroll)
async def scroll_tool(direction:Literal['up','down']='up',index:int=None,amount:int=500,context:Context=None):
    '''Scrolls either the webpage or a specific scrollable container. If index is provided, scrolls that element's container; otherwise scrolls the page. Detects when no further scrolling is possible and reports it instead of silently doing nothing.'''
    page=await context.get_current_page()
    if index is not None:
        # Scroll inside a specific element container.
        element=await context.get_element_by_index(index=index)
        handle=await context.get_handle_by_xpath(xpath=element.xpath)
        if direction=='up':
            await page.evaluate(f'(element)=> element.scrollBy(0,{-amount})', handle)
        elif direction=='down':
            await page.evaluate(f'(element)=> element.scrollBy(0,{amount})', handle)
        else:
            raise ValueError('Invalid direction')
        return f'Scrolled {direction} inside the element at label {index} by {amount}'
    scroll_y_before = await context.execute_script(page,"() => window.scrollY")
    max_scroll_y = await context.execute_script(page,"() => document.documentElement.scrollHeight - window.innerHeight")
    # Bug fix: the original compared scrollY against the full document
    # scrollHeight to detect the top of the page; the top is scrollY == 0.
    if direction == 'down' and scroll_y_before >= max_scroll_y:
        return "Already at the bottom, cannot scroll further."
    if direction == 'up' and scroll_y_before <= 0:
        return "Already at the top, cannot scroll further."
    if direction=='up':
        if amount is None:
            await page.keyboard.press('PageUp')
        else:
            await page.mouse.wheel(0,-amount)
    elif direction=='down':
        if amount is None:
            await page.keyboard.press('PageDown')
        else:
            await page.mouse.wheel(0,amount)
    else:
        raise ValueError('Invalid direction')
    # Verify the scroll actually moved the viewport.
    scroll_y_after = await context.execute_script(page,"() => window.scrollY")
    if scroll_y_before == scroll_y_after:
        return "Scrolling has no effect, the entire content fits within the viewport."
    amount=amount if amount else 'one page'
    return f'Scrolled {direction} by {amount}'

@Tool('GoTo Tool',params=GoTo)
async def goto_tool(url:str,context:Context=None):
    '''Navigates directly to a specified URL in the current tab. Supports HTTP/HTTPS URLs and waits for the DOM content to load before proceeding.'''
    page=await context.get_current_page()
    await page.goto(url=url,wait_until='domcontentloaded')
    await page.wait_for_timeout(2.5*1000)
    return f'Navigated to {url}'

@Tool('Back Tool',params=Back)
async def back_tool(context:Context=None):
    '''Navigates to the previous page in the browser history, equivalent to clicking the browser's back button. Waits for the page to fully load.'''
    page=await context.get_current_page()
    await page.go_back()
    await page.wait_for_load_state('load')
    return 'Navigated to previous page'

@Tool('Forward Tool',params=Forward)
async def forward_tool(context:Context=None):
    '''Navigates to the next page in the browser history, equivalent to clicking the browser's forward button. Waits for the page to fully load.'''
    page=await context.get_current_page()
    await page.go_forward()
    await page.wait_for_load_state('load')
    return 'Navigated to next page'
@Tool('Key Tool',params=Key)
async def key_tool(keys:str,times:int=1,context:Context=None):
    '''Performs keyboard shortcuts and key combinations (e.g., "Control+C", "Enter", "Escape", "Tab"). Can repeat the key press multiple times. Supports all standard keyboard keys and modifiers.'''
    page=await context.get_current_page()
    await page.wait_for_load_state('domcontentloaded')
    for _ in range(times):
        await page.keyboard.press(keys)
    return f'Pressed {keys}'

@Tool('Download Tool',params=Download)
async def download_tool(url:str=None,filename:str=None,context:Context=None):
    '''Downloads files from the internet (PDFs, images, videos, audio, documents) and saves them to the system's downloads directory. Handles various file types and formats.'''
    folder_path=Path(context.browser.config.downloads_dir)
    path=folder_path.joinpath(filename)
    async with httpx.AsyncClient() as client:
        # Stream the body to disk so large files are not buffered fully in
        # memory, and fail loudly on HTTP errors instead of saving an error page.
        async with client.stream('GET',url,follow_redirects=True) as response:
            response.raise_for_status()
            with open(path,'wb') as f:
                async for chunk in response.aiter_bytes():
                    f.write(chunk)
    # Bug fix: the result message omitted the filename (literal "(unknown)").
    return f'Downloaded {filename} from {url} and saved it to {path}'

@Tool('Scrape Tool',params=Scrape)
async def scrape_tool(context:Context=None):
    '''Extracts and returns the main content from the current webpage. Can output in markdown format (preserving links and structure). Filters out navigation, ads, and other non-essential content.'''
    page=await context.get_current_page()
    await page.wait_for_load_state('domcontentloaded')
    html=await page.content()
    content= markdownify(html)
    return f'Scraped the contents of the entire webpage:\n{content}'
Automatically handles focus and loading states.''' 155 | session = await context.get_session() 156 | pages = session.context.pages # Get all open tabs 157 | if mode == 'open': 158 | page = await session.context.new_page() 159 | session.current_page = page 160 | await page.wait_for_load_state('load') 161 | return 'Opened a new blank tab and switched to it.' 162 | elif mode == 'close': 163 | if len(pages) == 1: 164 | return 'Cannot close the last remaining tab.' 165 | page = session.current_page 166 | await page.close() 167 | # Get remaining pages after closing 168 | pages = session.context.pages 169 | session.current_page = pages[-1] # Switch to last remaining tab 170 | await session.current_page.bring_to_front() 171 | await session.current_page.wait_for_load_state('load') 172 | return 'Closed current tab and switched to the next last tab.' 173 | elif mode == 'switch': 174 | if tab_index is None or tab_index < 0 or tab_index >= len(pages): 175 | raise IndexError(f'Tab index {tab_index} is out of range. Available tabs: {len(pages)}') 176 | session.current_page = pages[tab_index] 177 | await session.current_page.bring_to_front() 178 | await session.current_page.wait_for_load_state('load') 179 | return f'Switched to tab {tab_index} (Total tabs: {len(pages)}).' 180 | else: 181 | raise ValueError("Invalid mode. Use 'open', 'close', or 'switch'.") 182 | 183 | @Tool('Upload Tool',params=Upload) 184 | async def upload_tool(index:int,filenames:list[str],context:Context=None): 185 | '''Uploads one or more files to file input elements on webpages. Handles both single and multiple file uploads. 
@Tool('Menu Tool',params=Menu)
async def menu_tool(index:int,labels:list[str],context:Context=None):
    '''Interacts with dropdown menus, select elements, and multi-select lists. Can select single or multiple options by their visible labels. Handles both simple dropdowns and complex multi-selection interfaces.'''
    element=await context.get_element_by_index(index=index)
    handle=await context.get_handle_by_xpath(element.xpath)
    # select_option accepts either a single label or a list of labels.
    await handle.select_option(label=labels if len(labels)>1 else labels[0])
    # Bug fix: the original rebound `labels` to a plain string for single
    # selections, so ", ".join(labels) joined its characters (e.g. "R, e, d").
    return f'Opened context menu of element at label {index} and selected {", ".join(labels)}'
Returns the result of the executed script.''' 213 | page=await context.get_current_page() 214 | result=await context.execute_script(page,script) 215 | return f"Result of the executed script: {result}" 216 | 217 | @Tool('Human Tool',params=HumanInput) 218 | async def human_tool(prompt:str,context:Context=None): 219 | '''Requests human assistance when encountering challenges that require human intervention such as CAPTCHAs, OTP codes, complex decisions, or when explicitly asked to involve a human user.''' 220 | print(colored(f"Agent: {prompt}", color='cyan', attrs=['bold'])) 221 | human_response = input("User: ") 222 | return f"User provided the following input: '{human_response}'" -------------------------------------------------------------------------------- /src/inference/gemini.py: -------------------------------------------------------------------------------- 1 | from src.message import AIMessage,BaseMessage,HumanMessage,ImageMessage,SystemMessage,ToolMessage 2 | from requests import get,RequestException,HTTPError,ConnectionError 3 | from tenacity import retry,stop_after_attempt,retry_if_exception_type 4 | from ratelimit import limits,sleep_and_retry 5 | from src.inference import BaseInference,Token 6 | from httpx import Client,AsyncClient 7 | from pydantic import BaseModel 8 | from typing import Optional 9 | from typing import Literal 10 | from src.tool import Tool 11 | from json import loads 12 | from uuid import uuid4 13 | 14 | class ChatGemini(BaseInference): 15 | def __init__(self,model:str,api_version:Literal['v1','v1beta','v1alpha']='v1beta',modality:Literal['text','audio']='text',api_key:str='',base_url:str='',tools:list=[],temperature:float=0.5): 16 | super().__init__(model,api_key=api_key,base_url=base_url,tools=tools,temperature=temperature) 17 | self.api_version=api_version 18 | self.modality=modality 19 | 20 | 21 | def 
cache_content(self,system_message:Optional[SystemMessage]=None,tools:Optional[list[Tool]]=None,messages:Optional[list[BaseMessage]]=None,display_name:Optional[str]=None,ttl:int=60): 22 | url = f"https://generativelanguage.googleapis.com/{self.api_version}/cachedContents?key={self.api_key}" 23 | payload = { 24 | "ttl": f"{ttl}s", 25 | "model": f"models/{self.model}", 26 | } 27 | # Add display name if provided 28 | if display_name: 29 | payload["display_name"] = display_name 30 | # Add system instruction if provided 31 | if system_message: 32 | payload["systemInstruction"] = { 33 | "parts": [ 34 | { 35 | "text": system_message.content 36 | } 37 | ] 38 | } 39 | # Add tools if provided 40 | if tools: 41 | payload["tools"] = [ 42 | { 43 | "function_declarations": [ 44 | { 45 | "name": tool.name, 46 | "description": tool.description, 47 | "parameters": tool.schema 48 | } 49 | for tool in tools 50 | ] 51 | } 52 | ] 53 | try: 54 | with Client() as client: 55 | response = client.post(url, json=payload, headers=self.headers, timeout=None) 56 | json_obj=response.json() 57 | if json_obj.get('error'): 58 | raise Exception(json_obj['error']['message']) 59 | usage_metadata=json_obj['usageMetadata'] 60 | self.tokens=Token(cache=usage_metadata["totalTokenCount"]) 61 | return json_obj['name'] 62 | except HTTPError as err: 63 | print(f'Error: {err.response.text}, Status Code: {err.response.status_code}') 64 | return None 65 | except ConnectionError as err: 66 | print(f'Connection Error: {err}') 67 | return None 68 | 69 | @sleep_and_retry 70 | @limits(calls=15,period=60) 71 | @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException)) 72 | def invoke(self, messages: list[BaseMessage],json=False,model:BaseModel|None=None,cache_name:Optional[str]=None) -> AIMessage|ToolMessage|BaseModel: 73 | self.headers.update({'x-goog-api-key':self.api_key}) 74 | temperature=self.temperature 75 | url=self.base_url or 
f"https://generativelanguage.googleapis.com/{self.api_version}/models/{self.model}:generateContent" 76 | contents=[] 77 | system_instruct=None 78 | for message in messages: 79 | if isinstance(message,HumanMessage): 80 | contents.append({ 81 | 'role':'user', 82 | 'parts':[{ 83 | 'text':message.content 84 | }] 85 | }) 86 | elif isinstance(message,AIMessage): 87 | contents.append({ 88 | 'role':'model', 89 | 'parts':[{ 90 | 'text':message.content 91 | }] 92 | }) 93 | elif isinstance(message,ImageMessage): 94 | text,image=message.content 95 | contents.append({ 96 | 'role':'user', 97 | 'parts':[{ 98 | 'text':text 99 | }, 100 | { 101 | 'inline_data':{ 102 | 'mime_type':'image/jpeg', 103 | 'data': image 104 | } 105 | }] 106 | }) 107 | elif isinstance(message,SystemMessage): 108 | system_instruct={ 109 | 'parts':{ 110 | 'text': self.structured(message,model) if model else message.content 111 | } 112 | } 113 | else: 114 | raise Exception("Invalid Message") 115 | 116 | payload={ 117 | 'contents': contents, 118 | 'generationConfig':{ 119 | 'temperature': temperature, 120 | 'responseMimeType':'application/json' if json or model else 'text/plain', 121 | 'responseModalities': [self.modality] 122 | } 123 | } 124 | if self.tools: 125 | payload['tools']=[ 126 | { 127 | 'function_declarations':[ 128 | { 129 | 'name': tool.name, 130 | 'description': tool.description, 131 | 'parameters': tool.schema 132 | } 133 | for tool in self.tools] 134 | } 135 | ] 136 | if system_instruct: 137 | payload['system_instruction']=system_instruct 138 | 139 | if cache_name: 140 | payload['cachedContent']=f"cachedContents/{cache_name}" 141 | 142 | try: 143 | with Client() as client: 144 | response=client.post(url=url,headers=self.headers,json=payload,timeout=None) 145 | json_obj=response.json() 146 | # print(json_obj) 147 | if json_obj.get('error'): 148 | raise Exception(json_obj['error']['message']) 149 | message=json_obj['candidates'][0]['content']['parts'][0] 150 | 
usage_metadata=json_obj['usageMetadata'] 151 | input,output,total=usage_metadata['promptTokenCount'],usage_metadata['candidatesTokenCount'],usage_metadata['totalTokenCount'] 152 | self.tokens=Token(input=input,output=output,total=total) 153 | # print(message) 154 | if model: 155 | return model.model_validate_json(message['text']) 156 | if json: 157 | content=loads(message['text']) 158 | return AIMessage(content) 159 | if message['text']: 160 | content=message['text'] 161 | return AIMessage(content) 162 | else: 163 | tool_call=message['functionCall'] 164 | return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['args']) 165 | 166 | except HTTPError as err: 167 | print(f'Error: {err.response.text}, Status Code: {err.response.status_code}') 168 | except ConnectionError as err: 169 | print(err) 170 | exit() 171 | 172 | @sleep_and_retry 173 | @limits(calls=15,period=60) 174 | @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException)) 175 | async def async_invoke(self, messages: list[BaseMessage],json=False,model:BaseModel=None) -> AIMessage|ToolMessage|BaseModel: 176 | temperature=self.temperature 177 | url=self.base_url or f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:generateContent" 178 | self.headers.update({'x-goog-api-key':self.api_key}) 179 | contents=[] 180 | system_instruction=None 181 | for message in messages: 182 | if isinstance(message,HumanMessage): 183 | contents.append({ 184 | 'role':'user', 185 | 'parts':[{ 186 | 'text':message.content 187 | }] 188 | }) 189 | elif isinstance(message,AIMessage): 190 | contents.append({ 191 | 'role':'model', 192 | 'parts':[{ 193 | 'text':message.content 194 | }] 195 | }) 196 | elif isinstance(message,ImageMessage): 197 | text,image=message.content 198 | contents.append({ 199 | 'role':'user', 200 | 'parts':[{ 201 | 'text':text 202 | }, 203 | { 204 | 'inline_data':{ 205 | 'mime_type':'image/jpeg', 206 | 'data': image 207 | } 208 | }] 209 | }) 210 | elif 
isinstance(message,SystemMessage): 211 | system_instruction={ 212 | 'parts':{ 213 | 'text': self.structured(message,model) if model else message.content 214 | } 215 | } 216 | else: 217 | raise Exception("Invalid Message") 218 | 219 | payload={ 220 | 'contents': contents, 221 | 'generationConfig':{ 222 | 'temperature': temperature, 223 | 'responseMimeType':'application/json' if json or model else 'text/plain', 224 | 'responseModalities': [self.modality] 225 | } 226 | } 227 | if self.tools: 228 | payload['tools']=[ 229 | { 230 | 'function_declarations':[ 231 | { 232 | 'name': tool.name, 233 | 'description': tool.description, 234 | 'parameters': tool.schema 235 | } 236 | for tool in self.tools] 237 | } 238 | ] 239 | if system_instruction: 240 | payload['system_instruction']=system_instruction 241 | try: 242 | async with AsyncClient() as client: 243 | response=await client.post(url=url,headers=self.headers,json=payload,timeout=None) 244 | json_obj=response.json() 245 | # print(json_obj) 246 | if json_obj.get('error'): 247 | raise Exception(json_obj['error']['message']) 248 | message=json_obj['candidates'][0]['content']['parts'][0] 249 | usage_metadata=json_obj['usageMetadata'] 250 | input,output,total=usage_metadata['promptTokenCount'],usage_metadata['candidatesTokenCount'],usage_metadata['totalTokenCount'] 251 | self.tokens=Token(input=input,output=output,total=total) 252 | if model: 253 | return model.model_validate_json(message['text']) 254 | if json: 255 | content=loads(message['text']) 256 | return AIMessage(content) 257 | if message['text']: 258 | content=message['text'] 259 | return AIMessage(content) 260 | else: 261 | tool_call=message['functionCall'] 262 | return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['args']) 263 | 264 | except HTTPError as err: 265 | print(f'Error: {err.response.text}, Status Code: {err.response.status_code}') 266 | except ConnectionError as err: 267 | print(err) 268 | exit() 269 | 270 | 
@retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException)) 271 | def stream(self, query:str): 272 | headers=self.headers 273 | temperature=self.temperature 274 | url=self.base_url or f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:generateContent" 275 | 276 | def available_models(self): 277 | url='https://generativelanguage.googleapis.com/v1beta/models' 278 | headers=self.headers 279 | params={'key':self.api_key} 280 | try: 281 | response=get(url=url,headers=headers,params=params) 282 | response.raise_for_status() 283 | json_obj=response.json() 284 | models=json_obj['models'] 285 | except HTTPError as err: 286 | print(f'Error: {err.response.text}, Status Code: {err.response.status_code}') 287 | exit() 288 | except ConnectionError as err: 289 | print(err) 290 | exit() 291 | return [model['displayName'] for model in models] -------------------------------------------------------------------------------- /src/agent/web/__init__.py: -------------------------------------------------------------------------------- 1 | from src.agent.web.tools import click_tool,goto_tool,type_tool,scroll_tool,wait_tool,back_tool,key_tool,scrape_tool,tab_tool,forward_tool,done_tool,download_tool,human_tool,script_tool 2 | from src.message import SystemMessage,HumanMessage,ImageMessage,AIMessage 3 | from src.agent.web.utils import read_markdown_file,extract_agent_data 4 | from src.agent.web.browser import Browser,BrowserConfig 5 | from langgraph.graph import StateGraph,END,START 6 | from src.agent.web.state import AgentState 7 | from src.agent.web.context import Context 8 | from src.inference import BaseInference 9 | from src.tool.registry import Registry 10 | from rich.markdown import Markdown 11 | from src.memory import BaseMemory 12 | from rich.console import Console 13 | from src.agent import BaseAgent 14 | from pydantic import BaseModel 15 | from datetime import datetime 16 | from termcolor import colored 17 | from textwrap import dedent 
18 | from src.tool import Tool 19 | from pathlib import Path 20 | import textwrap 21 | import platform 22 | import asyncio 23 | import json 24 | 25 | main_tools=[ 26 | click_tool,goto_tool,key_tool,scrape_tool, 27 | type_tool,scroll_tool,wait_tool,back_tool, 28 | tab_tool,done_tool,forward_tool,download_tool, 29 | script_tool 30 | ] 31 | 32 | class Agent(BaseAgent): 33 | def __init__(self,config:BrowserConfig=None,additional_tools:list[Tool]=[], 34 | instructions:list=[],memory:BaseMemory=None,llm:BaseInference=None,max_iteration:int=10, 35 | use_vision:bool=False,include_human_in_loop:bool=False,verbose:bool=False,token_usage:bool=False) -> None: 36 | """ 37 | Initializes the WebAgent object. 38 | 39 | Args: 40 | config (BrowserConfig, optional): Browser configuration. Defaults to None. 41 | additional_tools (list[Tool], optional): Additional tools to be used. Defaults to []. 42 | instructions (list, optional): Instructions for the agent. Defaults to []. 43 | memory (BaseMemory, optional): Memory object for the agent. Defaults to None. 44 | llm (BaseInference, optional): Large Language Model object. Defaults to None. 45 | max_iteration (int, optional): Maximum number of iterations. Defaults to 10. 46 | use_vision (bool, optional): Whether to use vision or not. Defaults to False. 47 | include_human_in_loop (bool, optional): Whether to include human in the loop or not. Defaults to False. 48 | verbose (bool, optional): Whether to print verbose output or not. Defaults to False. 49 | token_usage (bool, optional): Whether to track token usage or not. Defaults to False. 50 | 51 | Returns: 52 | None 53 | """ 54 | self.name='Web Agent' 55 | self.description='The Web Agent is designed to automate the process of gathering information from the internet, such as to navigate websites, perform searches, and retrieve data.' 
56 | self.observation_prompt=read_markdown_file('./src/agent/web/prompt/observation.md') 57 | self.system_prompt=read_markdown_file('./src/agent/web/prompt/system.md') 58 | self.action_prompt=read_markdown_file('./src/agent/web/prompt/action.md') 59 | self.answer_prompt=read_markdown_file('./src/agent/web/prompt/answer.md') 60 | self.instructions=self.format_instructions(instructions) 61 | self.registry=Registry(main_tools+additional_tools+([human_tool] if include_human_in_loop else [])) 62 | self.include_human_in_loop=include_human_in_loop 63 | self.browser=Browser(config=config) 64 | self.context=Context(browser=self.browser) 65 | self.max_iteration=max_iteration 66 | self.token_usage=token_usage 67 | self.structured_output=None 68 | self.use_vision=use_vision 69 | self.verbose=verbose 70 | self.start_time=None 71 | self.memory=memory 72 | self.end_time=None 73 | self.iteration=0 74 | self.llm=llm 75 | self.graph=self.create_graph() 76 | 77 | def format_instructions(self,instructions): 78 | return '\n'.join([f'{i+1}. 
{instruction}' for (i,instruction) in enumerate(instructions)]) 79 | 80 | async def reason(self,state:AgentState): 81 | "Call LLM to make decision based on the current state of the browser" 82 | system_prompt=self.system_prompt.format(**{ 83 | 'os':platform.system(), 84 | 'instructions':self.instructions, 85 | 'home_dir':Path.home().as_posix(), 86 | 'max_iteration':self.max_iteration, 87 | 'human_in_loop':self.include_human_in_loop, 88 | 'tools_prompt':self.registry.tools_prompt(), 89 | 'browser':self.browser.config.browser.capitalize(), 90 | 'downloads_dir':self.browser.config.downloads_dir, 91 | 'current_datetime':datetime.now().strftime('%A, %B %d, %Y') 92 | }) 93 | messages=[SystemMessage(system_prompt)]+state.get('messages') 94 | ai_message=await self.llm.async_invoke(messages=messages) 95 | agent_data=extract_agent_data(ai_message.content) 96 | memory=agent_data.get('Memory') 97 | evaluate=agent_data.get("Evaluate") 98 | thought=agent_data.get('Thought') 99 | if self.verbose: 100 | print(colored(f'Evaluate: {evaluate}',color='light_yellow',attrs=['bold'])) 101 | print(colored(f'Memory: {memory}',color='light_green',attrs=['bold'])) 102 | print(colored(f'Thought: {thought}',color='light_magenta',attrs=['bold'])) 103 | last_message=state.get('messages').pop() # ImageMessage/HumanMessage. To remove the past browser state 104 | if isinstance(last_message,(ImageMessage,HumanMessage)): 105 | message=HumanMessage(dedent(f''' 106 | 107 | 108 | Current Step: {self.iteration} 109 | 110 | Max. 
Steps: {self.max_iteration} 111 | 112 | Action Response: {state.get('prev_observation')} 113 | 114 | 115 | ''')) 116 | return {**state,'agent_data': agent_data,'messages':[message]} 117 | 118 | async def action(self,state:AgentState): 119 | "Execute the provided action" 120 | agent_data=state.get('agent_data') 121 | memory=agent_data.get('Memory') 122 | evaluate=agent_data.get("Evaluate") 123 | thought=agent_data.get('Thought') 124 | action_name=agent_data.get('Action Name') 125 | action_input:dict=agent_data.get('Action Input') 126 | if self.verbose: 127 | print(colored(f'Action: {action_name}({','.join([f'{k}={v}' for k,v in action_input.items()])})',color='blue',attrs=['bold'])) 128 | action_result=await self.registry.async_execute(action_name,action_input,context=self.context) 129 | observation=action_result.content 130 | if self.verbose: 131 | print(colored(f'Observation: {textwrap.shorten(observation,width=1000,placeholder='...')}',color='green',attrs=['bold'])) 132 | if self.verbose and self.token_usage: 133 | print(f'Input Tokens: {self.llm.tokens.input} Output Tokens: {self.llm.tokens.output} Total Tokens: {self.llm.tokens.total}') 134 | # Get the current screenshot,browser state and dom state 135 | browser_state=await self.context.get_state(use_vision=self.use_vision) 136 | current_tab=browser_state.current_tab 137 | dom_state=browser_state.dom_state 138 | image_obj=browser_state.screenshot 139 | # Redefining the AIMessage and adding the new observation 140 | action_prompt=self.action_prompt.format(**{ 141 | 'memory':memory, 142 | 'evaluate':evaluate, 143 | 'thought':thought, 144 | 'action_name':action_name, 145 | 'action_input':json.dumps(action_input,indent=2) 146 | }) 147 | observation_prompt=self.observation_prompt.format(**{ 148 | 'iteration':self.iteration, 149 | 'max_iteration':self.max_iteration, 150 | 'observation':observation, 151 | 'current_tab':current_tab.to_string(), 152 | 'tabs':browser_state.tabs_to_string(), 153 | 
'interactive_elements':dom_state.interactive_elements_to_string(), 154 | 'informative_elements':dom_state.informative_elements_to_string(), 155 | 'scrollable_elements':dom_state.scrollable_elements_to_string(), 156 | 'query':state.get('input') 157 | }) 158 | messages=[AIMessage(action_prompt),ImageMessage(text=observation_prompt,image_obj=image_obj) if self.use_vision and image_obj is not None else HumanMessage(observation_prompt)] 159 | return {**state,'messages':messages,'browser_state':browser_state,'dom_state':dom_state,'prev_observation':observation} 160 | 161 | async def answer(self,state:AgentState): 162 | "Give the final answer" 163 | if self.iterationdict|BaseModel: 216 | self.iteration=0 217 | observation_prompt=self.observation_prompt.format(**{ 218 | 'iteration':self.iteration, 219 | 'max_iteration':self.max_iteration, 220 | 'memory':'Nothing to remember', 221 | 'evaluate':'Nothing to evaluate', 222 | 'thought':'Nothing to think', 223 | 'action':'No Action', 224 | 'observation':'No Observation', 225 | 'current_tab':'No tabs open', 226 | 'tabs':'No tabs open', 227 | 'interactive_elements':'No interactive elements found', 228 | 'informative_elements':'No informative elements found', 229 | 'scrollable_elements':'No scrollable elements found', 230 | 'query':input 231 | }) 232 | state={ 233 | 'input':input, 234 | 'agent_data':{}, 235 | 'prev_observation':'No Observation', 236 | 'browser_state':None, 237 | 'dom_state':None, 238 | 'output':'', 239 | 'messages':[HumanMessage(observation_prompt)] 240 | } 241 | self.start_time=datetime.now() 242 | response=await self.graph.ainvoke(state,config={'recursion_limit':self.max_iteration}) 243 | self.end_time=datetime.now() 244 | total_seconds=(self.end_time-self.start_time).total_seconds() 245 | if self.verbose and self.token_usage: 246 | print(f'Input Tokens: {self.llm.tokens.input} Output Tokens: {self.llm.tokens.output} Total Tokens: {self.llm.tokens.total}') 247 | print(f'Total Time Taken: {total_seconds} seconds 
Number of Steps: {self.iteration}') 248 | # Extract and store the key takeaways of the task performed by the agent 249 | if self.memory: 250 | self.memory.store(response.get('messages')) 251 | return response 252 | 253 | def invoke(self, input: str)->dict|BaseModel: 254 | if self.verbose: 255 | print('Entering '+colored(self.name,'black','on_white')) 256 | try: 257 | loop = asyncio.get_running_loop() 258 | except RuntimeError: 259 | loop = asyncio.new_event_loop() 260 | asyncio.set_event_loop(loop) 261 | response = loop.run_until_complete(self.async_invoke(input=input)) 262 | return response 263 | 264 | async def invoke_history(self, input: str)->dict|BaseModel: 265 | pass 266 | 267 | def print_response(self,input: str): 268 | console=Console() 269 | response=self.invoke(input) 270 | console.print(Markdown(response.get('output'))) 271 | 272 | async def close(self): 273 | '''Close the browser and context followed by clean up''' 274 | try: 275 | await self.context.close_session() 276 | await self.browser.close_browser() 277 | except Exception: 278 | print('Failed to finish clean up') 279 | finally: 280 | self.context=None 281 | self.browser=None 282 | 283 | def stream(self, input:str): 284 | pass 285 | --------------------------------------------------------------------------------