├── .python-version
├── src
├── vectorstore
│ ├── __init__.py
│ ├── views.py
│ ├── base.py
│ └── chromadb.py
├── agent
│ ├── web
│ │ ├── history
│ │ │ ├── utils.py
│ │ │ ├── views.py
│ │ │ └── __init__.py
│ │ ├── prompt
│ │ │ ├── answer.md
│ │ │ ├── action.md
│ │ │ ├── observation.md
│ │ │ └── system.md
│ │ ├── state.py
│ │ ├── context
│ │ │ ├── script.js
│ │ │ ├── views.py
│ │ │ ├── config.py
│ │ │ └── __init__.py
│ │ ├── browser
│ │ │ ├── config.py
│ │ │ └── __init__.py
│ │ ├── utils.py
│ │ ├── dom
│ │ │ ├── views.py
│ │ │ └── __init__.py
│ │ ├── tools
│ │ │ ├── views.py
│ │ │ └── __init__.py
│ │ └── __init__.py
│ └── __init__.py
├── router
│ ├── utils.py
│ ├── __init__.py
│ └── prompt.md
├── memory
│ ├── episodic
│ │ ├── utils.py
│ │ ├── prompt
│ │ │ ├── memory.md
│ │ │ ├── add.md
│ │ │ ├── replace.md
│ │ │ ├── retrieve.md
│ │ │ └── update.md
│ │ ├── routes.json
│ │ ├── views.py
│ │ └── __init__.py
│ ├── semantic
│ │ └── __init__.py
│ └── __init__.py
├── tool
│ ├── thinking.py
│ ├── registry
│ │ ├── views.py
│ │ └── __init__.py
│ └── __init__.py
├── embedding
│ ├── __init__.py
│ ├── ollama.py
│ ├── mistral.py
│ └── gemini.py
├── inference
│ ├── __init__.py
│ ├── anthropic.py
│ ├── nvidia.py
│ ├── open_router.py
│ ├── mistral.py
│ ├── ollama.py
│ ├── openai.py
│ ├── groq.py
│ └── gemini.py
├── speech
│ └── __init__.py
└── message
│ └── __init__.py
├── assets
└── demo1.mov
├── pyproject.toml
├── main.py
├── setup.py
├── LICENSE
├── README.md
├── .gitignore
└── CONTRIBUTING.md
/.python-version:
--------------------------------------------------------------------------------
1 | 3.13
2 |
--------------------------------------------------------------------------------
/src/vectorstore/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/agent/web/history/utils.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/demo1.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CursorTouch/Web-Navigator/HEAD/assets/demo1.mov
--------------------------------------------------------------------------------
/src/router/utils.py:
--------------------------------------------------------------------------------
def read_markdown_file(file_path: str) -> str:
    """Return the entire contents of a UTF-8 markdown file as one string."""
    with open(file_path, 'r', encoding='utf-8') as file:
        contents = file.read()
    return contents
--------------------------------------------------------------------------------
/src/memory/episodic/utils.py:
--------------------------------------------------------------------------------
def read_markdown_file(file_path: str) -> str:
    """Read a markdown prompt file (UTF-8) and return its text."""
    with open(file_path, mode='r', encoding='utf-8') as handle:
        return handle.read()
--------------------------------------------------------------------------------
/src/agent/web/prompt/answer.md:
--------------------------------------------------------------------------------
1 | ```xml
2 |
8 | ```
--------------------------------------------------------------------------------
/src/agent/__init__.py:
--------------------------------------------------------------------------------
from abc import ABC, abstractmethod


class BaseAgent(ABC):
    """Abstract contract for every agent: a one-shot `invoke` and an incremental `stream`."""

    @abstractmethod
    def invoke(self, input: str):
        """Run the agent on `input` and return the final result."""
        ...

    @abstractmethod
    def stream(self, input: str):
        """Run the agent on `input`, producing output incrementally."""
        ...
--------------------------------------------------------------------------------
/src/agent/web/prompt/action.md:
--------------------------------------------------------------------------------
1 | ```xml
2 |
9 | ```
--------------------------------------------------------------------------------
/src/vectorstore/views.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass, field
from uuid import uuid4


@dataclass
class Document:
    """One vector-store entry: text plus arbitrary metadata, keyed by a UUID string."""

    # Fresh random UUID4 unless the caller supplies an id explicitly.
    id: str = field(default_factory=lambda: str(uuid4()))
    content: str = field(default_factory=str)
    metadata: dict = field(default_factory=dict)
--------------------------------------------------------------------------------
/src/memory/semantic/__init__.py:
--------------------------------------------------------------------------------
from src.memory import BaseMemory

class SemanticMemory(BaseMemory):
    """Semantic-memory backend; every operation is currently an unimplemented stub."""

    def store(self, input: str):
        # Not implemented yet.
        pass

    def retrieve(self, input: str):
        # Not implemented yet.
        pass

    def attach_memory(self) -> str:
        # Not implemented yet.
        pass
--------------------------------------------------------------------------------
/src/memory/episodic/prompt/memory.md:
--------------------------------------------------------------------------------
1 | ### EPISODIC MEMORIES
2 | Following are the relevant memories found based on past interactions. Use these memories as context to enhance your thought process.
3 |
4 | {memories}
5 |
6 | NOTE:
7 | - Incorporate past experiences to align with the user's current needs and objectives.
8 | - Use memory as a guide to improve reasoning, maintain consistency, and refine responses.
9 |
--------------------------------------------------------------------------------
/src/agent/web/state.py:
--------------------------------------------------------------------------------
from src.agent.web.context import BrowserState
from src.agent.web.dom.views import DOMState
from typing import TypedDict, Annotated
from src.message import BaseMessage
from operator import add

class AgentState(TypedDict):
    """State dictionary threaded through the web agent's graph nodes."""

    input: str             # the user's original query
    output: str            # final answer once the agent finishes
    agent_data: dict       # parsed fields from the last LLM response
    prev_observation: str  # feedback from the previously executed action
    browser_state: BrowserState | None
    dom_state: DOMState | None
    # Annotated with `add` so new messages are appended rather than replacing the list.
    messages: Annotated[list[BaseMessage], add]
--------------------------------------------------------------------------------
/src/tool/thinking.py:
--------------------------------------------------------------------------------
from src.tool import Tool
from pydantic import BaseModel, Field

class Thinking(BaseModel):
    """Parameter schema for the Thinking Tool."""
    thought: str = Field(..., description="Your extended thinking goes here")

@Tool('Thinking Tool', params=Thinking)
async def thinking_tool(thought: str, context=None):
    '''
    To think about something. It will not obtain new information or make any changes, but just log the thought.
    Use it when complex reasoning or brainstorming is needed.
    '''
    # The docstring above doubles as the tool description shown to the LLM;
    # the tool itself simply echoes the thought back.
    return thought
--------------------------------------------------------------------------------
/src/tool/registry/views.py:
--------------------------------------------------------------------------------
from pydantic import BaseModel, Field, ConfigDict
from typing import Callable, Type

class Function(BaseModel):
    """Registry record describing one callable action/tool."""
    name: str = Field(..., description="the name of the action")
    description: str = Field(..., description="the description of the action")
    params: Type[BaseModel] | None
    function: Callable | None
    # Callable / Type[BaseModel] are not standard pydantic field types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

class ToolResult(BaseModel):
    """Outcome of executing a registered action."""
    name: str = Field(..., description="the action taken")
    content: str = Field(..., description="the output of the action")
--------------------------------------------------------------------------------
/src/memory/episodic/prompt/add.md:
--------------------------------------------------------------------------------
1 | You are asked to generate episodic memory by analyzing conversations to extract key elements for guiding future interactions. Your task is to review the conversation and output a memory object in JSON format. Follow these guidelines:
2 |
3 | 1. Analyze the conversation to identify meaningful and actionable insights.
4 | 2. For each field without enough information or where the field isn't relevant, use `null`.
5 | 3. Be concise and ensure each string is clear and actionable.
6 | 4. Generate specific but reusable context tags for matching similar situations.
7 | 5. Your response should only have one json object.
--------------------------------------------------------------------------------
/src/agent/web/context/script.js:
--------------------------------------------------------------------------------
// Stealth patches: make an automated browser session resemble a normal one.

// Hide the tell-tale automation flag.
Object.defineProperty(navigator, 'webdriver', {
    get: () => undefined
});

// Report a typical language preference list.
Object.defineProperty(navigator, 'languages', {
    get: () => ['en-US', 'en']
});

// Pretend some plugins are installed (automation contexts often report none).
Object.defineProperty(navigator, 'plugins', {
    get: () => [1, 2, 3, 4, 5],
});

// Real Chrome exposes a window.chrome object.
window.chrome = { runtime: {} };

// Answer notification-permission queries via the Notification API so the
// result matches a real browser; delegate everything else to the original.
const originalQuery = window.navigator.permissions.query;
window.navigator.permissions.query = (parameters) => (
    parameters.name === 'notifications' ?
        Promise.resolve({ state: Notification.permission }) :
        originalQuery(parameters)
);
--------------------------------------------------------------------------------
/src/memory/episodic/prompt/replace.md:
--------------------------------------------------------------------------------
1 | You are asked to generate episodic memory by analyzing conversations to extract key elements for guiding future interactions. Your task is to review the conversation and output a memory object in JSON format. Follow these guidelines:
2 |
3 | 1. Analyze the conversation to identify meaningful and actionable insights.
4 | 2. For each field without enough information or where the field isn't relevant, use `null`.
5 | 3. Be concise and ensure each string is clear and actionable.
6 | 4. Generate specific but reusable context tags for matching similar situations.
7 | 5. Your response should only have one json object.
--------------------------------------------------------------------------------
/src/memory/episodic/prompt/retrieve.md:
--------------------------------------------------------------------------------
1 | You are asked to retrieve relevant episodic memories to assist with the current user query.
2 | ### Follow these instructions:
3 | 1. You will be provided with a set of past memories in JSON format.
4 | 2. The user will provide a query, and your goal is to identify and retrieve the most relevant memories that could assist the user in achieving their goal.
5 | 3. Include memories that align with the context or goal of the user's query (i.e., the underlying methodology is similar).
6 | 4. Output the selected memories as a JSON array. Each memory object should retain its original structure and data.
7 |
8 | ### Memories
9 | {memories}
10 |
--------------------------------------------------------------------------------
/src/embedding/__init__.py:
--------------------------------------------------------------------------------
from chromadb import Documents, EmbeddingFunction, Embeddings
from abc import ABC, abstractmethod

class BaseEmbedding(ABC, EmbeddingFunction):
    """Base class for embedding providers, directly usable as a chromadb EmbeddingFunction."""

    def __init__(self, model: str = '', api_key: str = '', base_url: str = ''):
        # Provider name derived from the subclass, e.g. OllamaEmbedding -> 'Ollama'.
        self.name = self.__class__.__name__.replace('Embedding', '')
        self.api_key = api_key
        self.model = model
        self.base_url = base_url
        self.headers = {'Content-Type': 'application/json'}

    def __call__(self, input: Documents) -> Embeddings:
        # chromadb calls the instance directly; delegate to the provider's embed().
        return self.embed(input)

    @abstractmethod
    def embed(self, text: list[str] | str) -> list:
        """Return the embedding vector(s) for `text`."""
        ...
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "Web-Navigator"
3 | version = "0.1.0"
4 | description = "Web-Navigator is an ai agent for web browsing and scraping websites."
5 | readme = "README.md"
6 | license = { file = "LICENSE" }
7 | requires-python = ">=3.13"
8 | dependencies = [
9 | "httpx>=0.28.1",
10 | "keyboard>=0.13.5",
11 | "langgraph>=0.6.7",
12 | "markdownify>=1.2.0",
13 | "nest-asyncio>=1.6.0",
14 | "playwright>=1.55.0",
15 | "pyaudio>=0.2.14",
16 | "pydantic>=2.11.9",
17 | "pyperclip>=1.10.0",
18 | "python-dotenv>=1.1.1",
19 | "ratelimit>=2.2.1",
20 | "requests>=2.32.5",
21 | "rich>=14.1.0",
22 | "tenacity>=9.1.2",
23 | "termcolor>=3.1.0",
24 | ]
25 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from src.agent.web.browser.config import BrowserConfig
from src.inference.gemini import ChatGemini
from src.agent.web import Agent
from dotenv import load_dotenv
import os

def main():
    """Entry point: build the browser agent from .env settings and answer one query.

    BUG FIX: setup.py declares the console script `web-agent=main:main`, but no
    `main()` function existed — the script body ran only at import time. Wrapping
    it keeps `python main.py` working and makes the entry point resolvable.
    """
    # Pull the API key and browser paths from the local .env file.
    load_dotenv()

    api_key = os.getenv('GOOGLE_API_KEY')
    browser_instance_dir = os.getenv('BROWSER_INSTANCE_DIR')
    user_data_dir = os.getenv('USER_DATA_DIR')

    # temperature=0 keeps the model deterministic for reproducible runs.
    llm = ChatGemini(model='gemini-2.5-flash-lite', api_key=api_key, temperature=0)
    config = BrowserConfig(browser='edge', browser_instance_dir=browser_instance_dir, user_data_dir=user_data_dir, headless=False)

    agent = Agent(config=config, llm=llm, verbose=True, use_vision=True, max_iteration=100)
    user_query = input('Enter your query: ')
    agent.print_response(user_query)

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/src/agent/web/history/views.py:
--------------------------------------------------------------------------------
from src.agent.web.dom.views import BoundingBox, CenterCord
from pydantic import BaseModel, Field
from dataclasses import dataclass

class DOMHistoryElementNode(BaseModel):
    """Serializable snapshot of a DOM element captured at an earlier agent step."""

    tag: str
    role: str
    name: str
    center: CenterCord
    bounding_box: BoundingBox
    xpath: dict[str, str] = Field(default_factory=dict)
    attributes: dict[str, str] = Field(default_factory=dict)
    viewport: tuple[int, int] = Field(default_factory=tuple)

    def to_dict(self) -> dict[str, str]:
        """Export the fields used for persistence/matching."""
        return {
            'tag': self.tag,
            'role': self.role,
            'xpath': self.xpath,
            'attributes': self.attributes,
            'bounding_box': self.bounding_box.to_dict(),
        }

@dataclass
class HashElement:
    """Pair of sha256 hex digests identifying an element by attributes and xpath."""
    attributes: str
    xpath: str
--------------------------------------------------------------------------------
/src/memory/episodic/prompt/update.md:
--------------------------------------------------------------------------------
1 | You are a memory updater tasked with refining and enhancing episodic memories based on new insights from the current conversation. Your role is to analyze the provided relevant memories and the current conversation, updating them to incorporate the new information while preserving their original purpose and clarity. Follow these rules:
2 |
3 | 1. Only update the provided relevant memories with information from the current conversation that adds value or clarity.
4 | 2. Ensure updates are concise, actionable, and maintain the structure of the original memory.
5 | 3. If a field becomes irrelevant or lacks enough information after the update, set it to `null`.
6 | 4. Output all updated memories as a valid JSON array, preserving the format of the input memories.
--------------------------------------------------------------------------------
/src/memory/episodic/routes.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "name": "ADD",
4 | "description": "Add the converstion to the memory because relevant memories is empty."
5 | },
6 | {
7 | "name": "UPDATE",
8 | "description": "Update the relevant memories with information from the conversation if there is new information or approach that can be added to the existing relevant memories."
9 | },
10 | {
11 | "name": "IDLE",
12 | "description": "Do nothing as the conversation is redundant when compared to the relevant memories in the knowledge base."
13 | },
14 | {
15 | "name": "REPLACE",
16 | "description": "Replace the relevant memories because it is less valuable, and the conversation provides more significant insights or conversation is a combine of multiple relevant memories."
17 | }
18 | ]
--------------------------------------------------------------------------------
/src/agent/web/context/views.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass, field
from playwright.async_api import Page, BrowserContext as PlaywrightBrowserContext
from src.agent.web.dom.views import DOMState
from typing import Optional

@dataclass
class Tab:
    """A single open browser tab and its playwright page handle."""
    id: int
    url: str
    title: str
    page: Page

    def to_string(self) -> str:
        """One-line human-readable summary of the tab."""
        return f'{self.id} - Title: {self.title} - URL: {self.url}'

@dataclass
class BrowserState:
    """Snapshot of the browser (tabs, screenshot, DOM) at one agent step."""
    current_tab: Optional[Tab] = None
    tabs: list[Tab] = field(default_factory=list)
    screenshot: Optional[str] = None
    # BUG FIX: default_factory requires a zero-argument callable. The original
    # passed `DOMState([])` — an already-constructed instance — which dataclasses
    # would try to *call* on instantiation, raising TypeError.
    dom_state: DOMState = field(default_factory=lambda: DOMState([]))

    def tabs_to_string(self) -> str:
        """All open tabs, one summary line each."""
        return '\n'.join(tab.to_string() for tab in self.tabs)

@dataclass
class BrowserSession:
    """Live playwright objects plus the captured state for this session."""
    context: PlaywrightBrowserContext
    current_page: Page
    state: BrowserState
--------------------------------------------------------------------------------
/src/embedding/ollama.py:
--------------------------------------------------------------------------------
from requests import RequestException, HTTPError, ConnectionError
from src.embedding import BaseEmbedding
from httpx import Client, HTTPStatusError, ConnectError
from typing import Literal

class OllamaEmbedding(BaseEmbedding):
    """Embedding provider backed by a local Ollama server."""

    def embed(self, text):
        """Return the embedding vector for `text` (str or list of str); None on error."""
        # An explicit base_url wins; otherwise use the default local endpoint.
        url = self.base_url or 'http://localhost:11434/api/embed'
        headers = self.headers
        payload = {
            'model': self.model,
            'input': text
        }
        try:
            with Client() as client:
                response = client.post(url=url, json=payload, headers=headers)
                response.raise_for_status()
                # /api/embed responds with {'embeddings': [[...], ...]}; take the first vector.
                return response.json()['embeddings'][0]
        # BUG FIX: the request is made with httpx, whose raise_for_status() raises
        # httpx.HTTPStatusError and whose connection failures are httpx.ConnectError.
        # The requests.HTTPError/ConnectionError previously caught here can never
        # be raised by this code path, so errors escaped unhandled.
        except HTTPStatusError as err:
            print(f'Error: {err.response.text}, Status Code: {err.response.status_code}')
        except ConnectError as err:
            print(err)
--------------------------------------------------------------------------------
/src/agent/web/prompt/observation.md:
--------------------------------------------------------------------------------
1 | ```xml
2 |
3 |
4 | Current Step: {iteration}
5 |
6 | Max. Steps: {max_iteration}
7 |
8 | Action Response: {observation}
9 |
10 |
11 | [Begin of Tab Info]
12 | Current Tab: {current_tab}
13 |
14 | Open Tabs:
15 | {tabs}
16 | [End of Tab Info]
17 |
18 | [Begin of Viewport]
19 | List of Interactive Elements:
20 | {interactive_elements}
21 |
22 | List of Scrollable Elements:
23 | {scrollable_elements}
24 |
25 | List of Informative Elements:
26 | {informative_elements}
27 | [End of Viewport]
28 |
29 | {query}
30 |
31 |
32 | Note: Use the `Done Tool` if the task is completely over else continue solving.
33 |
34 |
35 | ```
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

# NOTE(review): the repository keeps README.md at the root and has no docs/
# directory, so './docs/README.md' could never be opened — read from the root.
with open('./README.md') as f:
    long_description = f.read()

setup(
    name='Web-Agent',
    version='0.1',
    description='The web agent is designed to automate the process of gathering information from the internet, such as to navigate websites, perform searches, and retrieve data.',
    author='jeomon',
    # NOTE(review): address looks truncated (missing '.com'?) — confirm with author.
    author_email='jeogeoalukka@gmail',
    url='https://github.com/Jeomon/Web-Agent',
    packages=find_packages(),
    # BUG FIX: the original list was missing commas after 'tenacity' through
    # 'nest_asyncio', so Python's implicit string concatenation merged them into
    # one bogus requirement ('tenacityrequestsplaywright...').
    install_requires=[
        'langgraph',
        'tenacity',
        'requests',
        'playwright',
        'termcolor',
        'python-dotenv',
        'httpx',
        'nest_asyncio',
        'MainContentExtractor',
    ],
    entry_points={
        'console_scripts': [
            'web-agent=main:main'
        ]
    },
    long_description=long_description,
    long_description_content_type='text/markdown',
    license='MIT'
)
--------------------------------------------------------------------------------
/src/agent/web/history/__init__.py:
--------------------------------------------------------------------------------
from src.agent.web.history.views import DOMHistoryElementNode, HashElement
from src.agent.web.dom.views import DOMElementNode
from hashlib import sha256

class History:
    """Matches live DOM elements against previously recorded history elements."""

    def convert_dom_element_to_history_element(self, element: DOMElementNode) -> DOMHistoryElementNode:
        # DOMElementNode.to_dict() fields map one-to-one onto the history model.
        return DOMHistoryElementNode(**element.to_dict())

    def compare_dom_element_with_history_element(self, element: DOMElementNode, history_element: DOMHistoryElementNode) -> bool:
        """True when both elements have identical attribute and xpath digests."""
        return self.hash_element(element) == self.hash_element(history_element)

    def hash_element(self, element: DOMElementNode | DOMHistoryElementNode):
        """Digest an element's attributes and xpath into a HashElement."""
        data: dict = element.to_dict()
        attributes_digest = sha256(str(data.get('attributes')).encode()).hexdigest()
        xpath_digest = sha256(str(data.get('xpath')).encode()).hexdigest()
        return HashElement(attributes=attributes_digest, xpath=xpath_digest)
20 |
--------------------------------------------------------------------------------
/src/embedding/mistral.py:
--------------------------------------------------------------------------------
from src.embedding import BaseEmbedding
from httpx import Client, HTTPStatusError, ConnectError
from typing import Literal
from requests import RequestException, HTTPError, ConnectionError
import json

class MistralEmbedding(BaseEmbedding):
    """Embedding provider backed by the Mistral embeddings API."""

    def embed(self, text):
        """Return the embedding vector for `text` via /v1/embeddings; None on error."""
        url = self.base_url or 'https://api.mistral.ai/v1/embeddings'
        self.headers['Authorization'] = f'Bearer {self.api_key}'
        headers = self.headers
        payload = {
            'model': self.model,
            'input': text,
            'encoding_format': 'float'
        }
        try:
            with Client() as client:
                response = client.post(url=url, json=payload, headers=headers)
                response.raise_for_status()
                # BUG FIX: Mistral returns 'data' as a LIST of embedding objects;
                # the original indexed it as a dict (['data']['embedding']), which
                # raises TypeError on every successful response.
                return response.json()['data'][0]['embedding']
        # BUG FIX: the request is made with httpx, so httpx exceptions must be
        # caught — requests.HTTPError/ConnectionError are never raised here.
        except HTTPStatusError as err:
            print(f'Error: {err.response.text}, Status Code: {err.response.status_code}')
        except ConnectError as err:
            print(err)
26 |
27 |
--------------------------------------------------------------------------------
/src/vectorstore/base.py:
--------------------------------------------------------------------------------
from src.vectorstore.views import Document
from abc import ABC, abstractmethod
from typing import Sequence

class BaseVectorStore(ABC):
    """Interface every vector-store backend must implement."""

    @abstractmethod
    def create_collection(self, collection_name: str) -> None:
        """Create (or open) the named collection."""
        ...

    @abstractmethod
    def insert(self, documents: list[Document]) -> None:
        """Add documents to the active collection."""
        ...

    @abstractmethod
    def search(self, query: str, k: int) -> list[Document]:
        """Return the k documents most similar to `query`."""
        ...

    @abstractmethod
    def delete(self, collection_name: str) -> None:
        """Remove entries from the named collection."""
        ...

    @abstractmethod
    def update(self, id: str, content: str, metadata: dict) -> None:
        """Overwrite a stored document's content and metadata by id."""
        ...

    @abstractmethod
    def delete_collection(self, collection_name: str) -> None:
        """Drop the named collection entirely."""
        ...

    @abstractmethod
    def get(self, id: str) -> Document:
        """Fetch a single document by id."""
        ...

    @abstractmethod
    def all(self) -> list[Document]:
        """Return every document in the active collection."""
        ...

    @abstractmethod
    def all_collections(self) -> Sequence:
        """List every collection in the store."""
        ...
41 |
42 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 CursorTouch
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/src/router/__init__.py:
--------------------------------------------------------------------------------
from src.message import HumanMessage, SystemMessage
from src.router.utils import read_markdown_file
from src.inference import BaseInference
from json import dumps

class LLMRouter:
    """Routes a user query to one of a set of named routes by asking an LLM."""

    def __init__(self, instructions: list[str] | None = None, routes: list[dict] | None = None, llm: BaseInference = None, verbose=False):
        # BUG FIX: the original used mutable default arguments ([]), which Python
        # shares across all instances; None sentinels avoid that trap.
        self.system_prompt = read_markdown_file('./src/router/prompt.md')
        self.instructions = self.__get_instructions(instructions or [])
        self.routes = dumps(routes or [], indent=2)
        self.llm = llm
        self.verbose = verbose

    def __get_instructions(self, instructions):
        # Render as a numbered list: "1. first\n2. second\n..."
        return '\n'.join(f'{i+1}. {instruction}' for i, instruction in enumerate(instructions))

    def invoke(self, query: str) -> str:
        """Ask the LLM to choose a route for `query`; returns the route name."""
        parameters = {'instructions': self.instructions, 'routes': self.routes}
        messages = [SystemMessage(self.system_prompt.format(**parameters)), HumanMessage(query)]
        response = self.llm.invoke(messages, json=True)
        route = response.content.get('route')
        if self.verbose:
            print(f"Going to {route.upper()} route")
        return route
25 |
--------------------------------------------------------------------------------
/src/agent/web/browser/config.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import Literal, Optional
from pathlib import Path

@dataclass
class BrowserConfig:
    """Launch-time settings for the agent's playwright-driven browser."""

    headless: bool = False
    # Annotation fix: fields defaulting to None are Optional, not plain str.
    wss_url: Optional[str] = None
    device: Optional[str] = None
    browser_instance_dir: Optional[str] = None
    downloads_dir: str = (Path.home()/'Downloads').as_posix()
    browser: Literal['chrome', 'firefox', 'edge'] = 'edge'
    user_data_dir: Optional[str] = None
    timeout: int = 60*1000  # playwright timeouts are in milliseconds
    slow_mo: int = 300      # ms delay between playwright actions

# Switches that relax cross-origin protections so the agent can inspect any page.
SECURITY_ARGS = [
    '--disable-web-security',
    '--disable-site-isolation-trials',
    '--disable-features=IsolateOrigins,site-per-process',
]

# Chromium switches reducing automation fingerprinting and background throttling.
# NOTE(review): '--disable-sandbox' is not a standard Chromium switch — the usual
# flag is '--no-sandbox'; confirm against the targeted browser version.
BROWSER_ARGS = [
    '--disable-sandbox',
    '--enable-blink-features=IdleDetection',
    '--disable-blink-features=AutomationControlled',
    '--disable-infobars',
    '--disable-background-timer-throttling',
    '--disable-popup-blocking',
    '--disable-backgrounding-occluded-windows',
    '--disable-renderer-backgrounding',
    '--disable-window-activation',
    '--disable-focus-on-load',
    '--no-first-run',
    '--no-default-browser-check',
    '--no-startup-window',
    '--window-position=0,0',
    '--remote-debugging-port=9222'
]

IGNORE_DEFAULT_ARGS = ['--enable-automation']
--------------------------------------------------------------------------------
/src/memory/episodic/views.py:
--------------------------------------------------------------------------------
from pydantic import BaseModel, Field
from uuid import uuid4

class Memory(BaseModel):
    """One episodic memory extracted from a past conversation."""

    id: str = Field(description='The id of the memory to be filled by the user', examples=['32cd40b6-5db1-48b3-9434-ebd9502f06f0'], default_factory=lambda: str(uuid4()))
    tags: list[str] = Field(..., description='Tags to help identify similar future conversations.', examples=[['google', 'weather']])
    summary: str = Field(..., description='Describes what the conversation accomplished.')
    what_worked: str = Field(..., description='Highlights the most effective strategy used.')
    what_to_avoid: str = Field(..., description='Describes the important pitfalls to avoid.')

    def to_dict(self):
        return self.model_dump()

    class Config:
        # LLM responses may carry extra keys; keep them instead of erroring.
        extra = 'allow'

class Memories(BaseModel):
    """Container holding a list of Memory objects."""

    memories: list[Memory] = Field(description='The list of memories', default_factory=list)

    def model_dump(self, *args, **kwargs):
        # Dump as a bare list instead of {'memories': [...]}.
        return super().model_dump(*args, **kwargs)["memories"]

    def all(self):
        return [memory.to_dict() for memory in self.memories]

    def to_string(self):
        # BUG FIX: the Summary label was '***Summary:**' (three leading asterisks),
        # producing malformed markdown in the rendered memory block.
        return '\n\n'.join([f'**Tags:** {memory.tags}\n**Summary:** {memory.summary}\n**What Worked:** {memory.what_worked}\n**What to Avoid:** {memory.what_to_avoid}' for memory in self.memories])
28 |
29 |
--------------------------------------------------------------------------------
/src/memory/__init__.py:
--------------------------------------------------------------------------------
from abc import ABC, abstractmethod
from src.inference import BaseInference
from src.message import BaseMessage, SystemMessage
import json
import os

class BaseMemory(ABC):
    """Base class for memory backends persisted as JSON under ./memory_data/."""

    def __init__(self, knowledge_base: str = 'knowledge_base.json', llm: BaseInference = None, verbose=False):
        self.llm = llm
        self.verbose = verbose
        self.knowledge_base = knowledge_base
        self.__initialize_memory()

    @abstractmethod
    def store(self, conversation: list[BaseMessage]) -> None:
        """Persist insights extracted from a conversation."""
        pass

    @abstractmethod
    def retrieve(self, query: str) -> list[dict]:
        """Return memories relevant to `query`."""
        pass

    @abstractmethod
    def attach_memory(self) -> str:
        """Render stored memories as prompt text."""
        pass

    def __initialize_memory(self):
        """Load the knowledge base from disk, creating an empty one on first run."""
        path = f'./memory_data/{self.knowledge_base}'
        if not os.path.exists(path):
            # BUG FIX: the original serialized `self.memories` here before the
            # attribute was ever assigned, raising AttributeError on a fresh run.
            self.memories = []
            os.makedirs('./memory_data', exist_ok=True)
            with open(path, 'w') as f:
                f.write(json.dumps(self.memories, indent=2))
        else:
            with open(path, 'r') as f:
                self.memories = json.loads(f.read())

    def conversation_to_text(self, conversation: list[BaseMessage]):
        """Render a conversation (minus system messages) as 'role: content' lines."""
        conversation = list(self.__filter_conversation(conversation))
        return '\n'.join(f'{message.role}: {message.content}' for message in conversation)

    def __filter_conversation(self, conversation: list[BaseMessage]):
        # System prompts are boilerplate; exclude them from the memory text.
        return filter(lambda message: not isinstance(message, SystemMessage), conversation)
--------------------------------------------------------------------------------
/src/inference/__init__.py:
--------------------------------------------------------------------------------
from src.message import AIMessage, SystemMessage
from abc import ABC, abstractmethod
from pydantic import BaseModel
from typing import Optional
from src.tool import Tool

class Token(BaseModel):
    """Token-usage counters reported by a provider."""
    input: Optional[int] = None
    output: Optional[int] = None
    cache: Optional[int] = None
    total: Optional[int] = None

structured_output_prompt = '''
Integrate the JSON output as part of the structured response, ensuring it strictly follows the provided schema.
```json
{json_schema}
```
Validate all fields, use `null` or empty values for missing data, and format the JSON in a clear, indented code block.
'''

class BaseInference(ABC):
    """Shared constructor and contract for chat-model providers."""

    def __init__(self, model: str, api_key: str = '', base_url: str = '', tools: list[Tool] | None = None, temperature: float = 0.5):
        self.model = model
        self.api_key = api_key
        self.base_url = base_url
        # BUG FIX: the original default `tools=[]` is a mutable default argument,
        # shared across every instance; normalize a None sentinel instead.
        self.tools = tools if tools is not None else []
        self.temperature = temperature
        self.headers = {'Content-Type': 'application/json'}
        self.structured_output_prompt = structured_output_prompt
        self.tokens: Token = Token(input=0, output=0, total=0)

    @abstractmethod
    def invoke(self, messages: list[dict], json: bool = False, model: BaseModel = None) -> AIMessage | BaseModel:
        """Synchronous chat completion."""
        pass

    @abstractmethod
    async def async_invoke(self, messages: list[dict], json: bool = False, model: BaseModel = None) -> AIMessage | BaseModel:
        """Asynchronous chat completion."""
        pass

    @abstractmethod
    def stream(self, messages: list[dict], json: bool = False) -> AIMessage:
        """Streaming chat completion."""
        pass

    def structured(self, message: SystemMessage, model: BaseModel):
        """Append the JSON-schema instruction block to a system message."""
        return f'{message.content}\n{structured_output_prompt.format(json_schema=model.model_json_schema())}'
46 |
47 |
--------------------------------------------------------------------------------
/src/agent/web/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import ast
3 |
def read_markdown_file(file_path: str) -> str:
    """Return the entire contents of the UTF-8 markdown file at *file_path*."""
    with open(file_path, encoding='utf-8') as file:
        return file.read()
8 |
def extract_agent_data(text):
    """Parse the agent's XML-tagged response *text* into a dict.

    Extracts the <Memory>, <Evaluate>, <Thought>, <Action-Name> and
    <Action-Input> sections; Action-Input is additionally converted from a
    JSON-ish literal into a Python object when possible.
    """
    result = {}
    # Anchor each pattern on the OPENING tag as well as the closing one;
    # `(.*?)</Tag>` alone starts matching at the beginning of the text and
    # captures unrelated content (including other tags).  DOTALL lets a
    # section span multiple lines.
    sections = {
        'Memory': r"<Memory>(.*?)</Memory>",
        'Evaluate': r"<Evaluate>(.*?)</Evaluate>",
        'Thought': r"<Thought>(.*?)</Thought>",
        'Action Name': r"<Action-Name>(.*?)</Action-Name>",
    }
    for key, pattern in sections.items():
        match = re.search(pattern, text, re.DOTALL)
        if match:
            result[key] = match.group(1).strip()
    # Extract and convert Action-Input to a dictionary.
    action_input_match = re.search(r"<Action-Input>(.*?)</Action-Input>", text, re.DOTALL)
    if action_input_match:
        action_input_str = action_input_match.group(1).strip()
        try:
            # Map JSON literals to Python ones, then parse safely (no eval).
            result['Action Input'] = ast.literal_eval(
                action_input_str.replace('null', 'None').replace('true', 'True').replace('false', 'False')
            )
        except (ValueError, SyntaxError):
            # If conversion fails, keep the raw string so nothing is lost.
            result['Action Input'] = action_input_str
    return result
39 |
--------------------------------------------------------------------------------
/src/router/prompt.md:
--------------------------------------------------------------------------------
1 | ### **LLM Router**
2 | You are an advanced intelligent LLM Router responsible for determining the most accurate route for a given user query. Your primary task is to analyze the query, reason about its complexity, and map it to the most appropriate route from the available routes.
3 |
4 | **Instructions (optional):**
5 | {instructions}
6 |
7 | **Available Routes:**
8 | {routes}
9 |
10 | ---
11 |
12 | ### **Enhanced Reasoning and Decision-Making Process**:
13 | 0. **Instructions Priority**: If instructions are provided, they must be given top priority. Always refer to the instructions before making any decisions.
14 |
15 | 1. **Thorough Query Understanding**: Analyze the query to capture nuances, objectives, and any hidden complexities.
16 |
17 | 2. **Route Comparison**: Use detailed reasoning to compare the query against the available route descriptions. Ensure you consider both simple and advanced requirements within the query.
18 |
19 | 3. **Contextual Mapping**: Factor in the user’s intention and potential requirements (e.g., tool access, complex reasoning, or multiple steps) before choosing a route.
20 |
21 | 4. **Complex Scenario Handling**: In cases of ambiguity or complex queries, apply advanced reasoning by weighing potential routes based on their descriptions and the needs of the query.
22 |
23 | 5. **Judgment Enhancement**: Use a higher level of reasoning to ensure that no errors in routing occur, especially in situations where the query is multifaceted or when multiple routes seem plausible.
24 |
25 | 6. **Avoid Redundancy**: Avoid mapping the query to multiple routes unless explicitly stated, ensuring that tasks are distinct and non-overlapping.
26 |
27 | 7. **Confidence and Precision**: Always make confident and precise routing decisions, ensuring that the final output matches the query's requirements perfectly.
28 |
29 | ---
30 |
31 | ### **Response Format**:
32 | Your task is to return the correct route based on the query in the following JSON format:
33 |
34 | ```json
35 | {{
36 | "route": "the route name goes over here"
37 | }}
38 | ```
--------------------------------------------------------------------------------
/src/vectorstore/chromadb.py:
--------------------------------------------------------------------------------
1 | from src.vectorstore.base import BaseVectorStore,Document
2 | from src.embedding import BaseEmbedding
3 | from chromadb.config import Settings
4 | from chromadb import Client
5 | from pathlib import Path
6 |
class ChromaDBVectorStore(BaseVectorStore):
    """Vector store backed by a persistent on-disk ChromaDB collection."""

    def __init__(self,collection_name:str,embedding:BaseEmbedding,path:Path=Path.cwd()/'.chroma'):
        self.settings=Settings(anonymized_telemetry=False,is_persistent=True,persist_directory=path.as_posix())
        self.client=Client(settings=self.settings)
        self.db=self.create_collection(collection_name=collection_name,embedding=embedding)

    def create_collection(self,collection_name:str,embedding:BaseEmbedding=None):
        """Fetch the named collection, creating it on first use."""
        return self.client.get_or_create_collection(name=collection_name,embedding_function=embedding)

    def insert(self,documents:list[Document]):
        """Add a batch of documents (ids, texts, metadata) to the collection."""
        self.db.add(
            ids=[document.id for document in documents],
            documents=[document.content for document in documents],
            metadatas=[document.metadata for document in documents],
        )

    def search(self, query:str, k=5):
        """Return the raw ChromaDB result for the *k* nearest matches to *query*."""
        return self.db.query(query_texts=[query],n_results=k)

    def update(self, id:str, content:str, metadata:dict):
        """Overwrite the stored text and metadata of document *id*."""
        self.db.update(ids=[id],documents=[content],metadatas=[metadata])

    def delete(self, id):
        """Remove document *id* from the collection."""
        self.db.delete(ids=[id])

    def get(self, id):
        """Fetch document *id*, parsed into Document objects."""
        return self.parse_db_response(self.db.get(ids=[id]))

    def delete_collection(self, collection_name):
        """Drop an entire collection by name."""
        self.client.delete_collection(name=collection_name)

    def all_collections(self):
        """List every collection known to the client."""
        return self.client.list_collections()

    def all(self):
        """Return every document in the collection."""
        return self.parse_db_response(self.db.get())

    def parse_db_response(self, response: dict) -> list[Document]:
        """Convert a raw ChromaDB get() payload into Document objects."""
        ids = response.get('ids', [])
        contents = response.get('documents', [])
        metadatas = response.get('metadatas', [])
        # The parallel lists can be shorter than `ids`, so index defensively.
        return [
            Document(
                id=doc_id,
                content=contents[index] if index < len(contents) else '',
                metadata=metadatas[index] if index < len(metadatas) else {},
            )
            for index, doc_id in enumerate(ids)
        ]
58 |
--------------------------------------------------------------------------------
/src/agent/web/context/config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass,field
2 | from typing import Optional,Any
3 |
@dataclass
class ContextConfig:
    """Tunable settings for a browser context (waits, identity, security)."""
    # Site credentials made available to the context; exact schema depends on
    # the consumer — TODO confirm against the code that reads it.
    credentials:dict[str,Any]=field(default_factory=dict)
    # Minimum time (seconds) to wait after a navigation before proceeding.
    minimum_wait_page_load_time:float=0.5
    # How long (seconds) to wait for network activity to go idle after a load.
    wait_for_network_idle_page_load_time:float=1
    # Hard cap (seconds) on waiting for a page to finish loading.
    maximum_wait_page_load_time:float=5
    # When True, the consumer of this config is expected to relax browser
    # security features — NOTE(review): confirm where this flag is read.
    disable_security:bool=True
    # Desktop-Chrome user agent string presented to sites.
    user_agent:str="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36"
12 |
13 |
# File extensions treated as relevant (e.g. downloadable) content.
# Written as a set literal; each extension appears exactly once.
RELEVANT_FILE_EXTENSIONS = {
    # Documents
    '.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx', '.txt',
    '.csv', '.json',
    # Images
    '.png', '.jpg', '.jpeg', '.gif', '.svg',
    # Archives
    '.zip', '.rar', '.7z', '.tar', '.gz', '.bz2',
    # Audio
    '.mp3', '.wav', '.ogg', '.flac',
    # Video
    '.webm', '.mp4', '.avi', '.mkv', '.mov', '.wmv', '.mpg', '.mpeg',
    '.m4v', '.3gp',
}
25 |
# MIME types of responses worth keeping (archives, documents, media, images).
RELEVANT_CONTEXT_TYPES = {
    # Archive files
    'application/x-7z-compressed',
    'application/zip',
    'application/x-rar-compressed',
    'application/x-iso9660-image',
    'application/x-tar',
    'application/x-gzip',
    'application/x-bzip2',
    # Document files
    'application/vnd.ms-excel',
    'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
    'application/vnd.ms-powerpoint',
    'application/vnd.openxmlformats-officedocument.presentationml.presentation',
    'application/pdf',
    'application/msword',
    'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
    'application/vnd.oasis.opendocument.text',
    # Audio files
    'audio/mpeg', 'audio/wav', 'audio/mp3', 'audio/ogg', 'audio/flac', 'audio/webm',
    # Video files
    'video/mp4', 'video/ogg', 'video/webm', 'video/quicktime',
    # Image files
    'image/jpeg', 'image/png', 'image/gif', 'image/bmp', 'image/svg+xml',
}
50 |
# Resource types considered useful to load — presumably compared against
# Playwright's request.resource_type when filtering network traffic;
# TODO confirm against the request-routing code in the context module.
RELEVANT_RESOURCE_TYPES = [
    'document',
    'stylesheet',
    'image',
    'font',
    'script',
    'iframe',
]
59 |
# URL substrings for requests that are safe to drop: trackers, ads, social
# embeds, support widgets, push services, keep-alives, streaming endpoints
# and CDN hosts.
IGNORED_URL_PATTERNS = {
    # Analytics / tracking
    'analytics', 'tracking', 'telemetry', 'googletagmanager', 'beacon', 'metrics',
    # Advertising
    'doubleclick', 'adsystem', 'adserver', 'advertising', 'cdn.optimizely',
    # Social embeds
    'facebook.com/plugins', 'platform.twitter', 'linkedin.com/embed',
    # Live chat / support widgets
    'livechat', 'zendesk', 'intercom', 'crisp.chat', 'hotjar',
    # Push notifications
    'push-notifications', 'onesignal', 'pushwoosh',
    # Keep-alive / heartbeat
    'heartbeat', 'ping', 'alive',
    # Streaming
    'webrtc', 'rtmp://', 'wss://',
    # CDNs
    'cloudfront.net', 'fastly.net',
}
--------------------------------------------------------------------------------
/src/speech/__init__.py:
--------------------------------------------------------------------------------
1 | from src.inference import BaseInference
2 | from pyaudio import PyAudio,paInt16,Stream
3 | from tempfile import NamedTemporaryFile
4 | from src.message import AIMessage
5 | import keyboard
6 | import wave
7 | import os
8 |
class Speech:
    """Press-Enter-to-stop microphone recorder that hands the audio to an LLM.

    The recording is written to a temporary .wav file, transcribed via
    `llm.invoke(file_path=...)`, and all resources are released afterwards.
    """
    def __init__(self,llm:BaseInference=None):
        self.chunk_size=1024     # frames read from the microphone per iteration
        self.frame_rate=44100    # sample rate in Hz
        self.channels=1          # mono capture
        self.audio=PyAudio()
        self.stream=None
        self.llm=llm             # speech-to-text backend; must accept invoke(file_path=...)
        self.tempfile_path=''

    def setup_stream(self):
        """Open a 16-bit mono input stream on the default microphone."""
        audio=self.audio
        self.stream=audio.open(**{
            'format':paInt16,
            'channels':self.channels,
            'rate':self.frame_rate,
            'input':True,
            'frames_per_buffer':self.chunk_size
        })

    def get_stream(self)->Stream:
        """Return the open input stream, creating it on first use."""
        if self.stream is None:
            self.setup_stream()
        return self.stream

    def record_audio(self)->bytes:
        """Capture microphone audio until the user presses Enter; return raw PCM bytes."""
        stream=self.get_stream()
        frames=[]
        is_recording=True
        print('Recording audio...')
        print('Press enter to stop recording...')
        while is_recording:
            data=stream.read(self.chunk_size)
            frames.append(data)
            if keyboard.is_pressed('enter'):
                is_recording=False
                stream.stop_stream()
        print('Recording stopped...')
        return b''.join(frames)

    def bytes_to_tempfile(self, bytes: bytes):
        """Write raw PCM *bytes* to a fresh temporary .wav file (path kept on self)."""
        temp_file = NamedTemporaryFile(delete=False, suffix='.wav')
        self.tempfile_path = temp_file.name
        temp_file.close()
        try:
            with wave.open(self.tempfile_path, 'wb') as wf:
                wf.setnchannels(self.channels)
                wf.setsampwidth(self.audio.get_sample_size(paInt16))
                wf.setframerate(self.frame_rate)
                wf.writeframes(bytes)
        except Exception as e:
            # Chain the original cause for easier debugging.
            raise Exception(f"Export failed. {e}") from e

    def close(self):
        """Release the audio stream/device and delete the temp recording, if any."""
        if self.stream is not None:
            self.stream.close()
        if self.audio is not None:
            self.audio.terminate()
        self.stream=None
        self.audio=None
        # Guard the cleanup: close() may be called before anything was
        # recorded, in which case os.remove('') would raise.
        if self.tempfile_path and os.path.exists(self.tempfile_path):
            os.remove(self.tempfile_path)
            self.tempfile_path=''

    def invoke(self)->AIMessage:
        """Record from the mic, transcribe via the configured LLM, return its message."""
        audio_bytes=self.record_audio()
        self.bytes_to_tempfile(audio_bytes)
        print(f'Using {self.llm.model} audio to text...')
        response=self.llm.invoke(file_path=self.tempfile_path)
        self.close()
        return response
78 |
79 |
--------------------------------------------------------------------------------
/src/message/__init__.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 | from abc import ABC
3 | import requests
4 | import base64
5 | import re
6 |
class BaseMessage(ABC):
    """Base type for all chat messages; subclasses set `role` and `content`."""

    def to_dict(self)->dict[str,str]:
        """Serialize to the {'role', 'content'} mapping used by chat APIs."""
        return {
            'role': self.role,
            'content': f'''{self.content}'''
        }

    def __repr__(self):
        attribute_text = ", ".join(f"{key}={value}" for key, value in self.__dict__.items())
        return f"{self.__class__.__name__}({attribute_text})"
17 |
class HumanMessage(BaseMessage):
    """A chat message authored by the end user (role 'user')."""
    def __init__(self,content):
        self.role='user'
        self.content=content
22 |
class AIMessage(BaseMessage):
    """A chat message produced by the model (role 'assistant')."""
    def __init__(self,content):
        self.role='assistant'
        self.content=content
27 |
class SystemMessage(BaseMessage):
    """A system instruction message (role 'system')."""
    def __init__(self,content):
        self.role='system'
        self.content=content
32 |
class ImageMessage(BaseMessage):
    """A user message pairing optional text with a base64-encoded image.

    Exactly one of *image_path* (local path or http(s) URL) or *image_obj*
    (raw image bytes) must be provided; `content` becomes a
    (text, base64_string) tuple.
    """
    def __init__(self,text:str=None,image_path:str=None,image_obj:bytes=None):
        self.role='user'
        # Prefer raw bytes when given; otherwise load from the path/URL.
        # (The previous `or` conditions sent None to the encoder when both
        # arguments were omitted, so the error branch was unreachable.)
        if image_obj is not None:
            self.content=(text,self.__encoder(image_obj))
        elif image_path is not None:
            self.content=(text,self.__image_to_base64(image_path))
        else:
            raise ValueError('Either image_path or image_obj must be provided.')

    def __is_url(self,image_path:str)->bool:
        """True when *image_path* starts with http:// or https://."""
        url_pattern = re.compile(r'^https?://')
        return url_pattern.match(image_path) is not None

    def __is_file_path(self,image_path:str)->bool:
        """Heuristic check that *image_path* looks like a filesystem path."""
        file_path_pattern = re.compile(r'^([./~]|([a-zA-Z]:)|\\|//)?\.?\/?[a-zA-Z0-9._-]+(\.[a-zA-Z0-9]+)?$')
        return file_path_pattern.match(image_path) is not None

    def __image_to_base64(self,image_source: str) -> str:
        """Fetch (URL) or read (file) the image and return it base64-encoded."""
        if self.__is_url(image_source):
            response = requests.get(image_source)
            image_bytes = response.content
        elif self.__is_file_path(image_source):
            with open(image_source, 'rb') as image:
                image_bytes = image.read()
        else:
            raise ValueError("Invalid image source. Must be a URL or file path.")
        return base64.b64encode(image_bytes).decode('utf-8')

    def __encoder(self,b:bytes):
        """Base64-encode raw bytes to an ASCII string."""
        return base64.b64encode(b).decode('utf-8')
65 |
class ToolMessage(BaseMessage):
    """A tool-call message: the call id plus the tool's name and arguments."""
    def __init__(self,id:str,name:str,args:dict):
        self.id=id
        self.role='tool'
        self.name=name
        self.args=args
--------------------------------------------------------------------------------
/src/tool/__init__.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 | from typing import Optional,Callable
3 | from inspect import getdoc
4 | from json import dumps
5 |
class Tool:
    """Wraps a callable with a name, description and flattened input schema.

    Typical use is as a decorator:

        @Tool(name='Search', params=SearchParams)
        def search(...): ...
    """
    def __init__(self, name: str='',description: Optional[str]=None, params: Optional[BaseModel]=None,schema:Optional[dict]=None,func:Optional[Callable]=None):
        self.name = name
        self.params = params
        self.func = func
        self.description = description
        self.schema = schema

    def __call__(self, func):
        """Decorator entry point: capture *func* and derive description/schema."""
        # Fall back to the wrapped function's docstring for the description.
        self.description = self.description or getdoc(func)
        skip_keys=['title']
        # Flatten the 'properties' mapping from either a pydantic model or a
        # raw JSON schema.  (Previously the raw-schema branch was nested under
        # `if self.params:` and therefore unreachable.)
        if self.params is not None:
            self.schema = {k:{term:content for term,content in v.items() if term not in skip_keys} for k,v in self.params.model_json_schema().get('properties').items() if k not in skip_keys}
        elif self.schema is not None:
            self.schema = {k:{term:content for term,content in v.items() if term not in skip_keys} for k,v in self.schema.get('properties').items() if k not in skip_keys}
        self.func = func
        return self # Return the Tool Instance

    def invoke(self, **kwargs):
        """Validate *kwargs* against the params model (if any) and call the function.

        Exceptions are converted to an 'Error: ...' string for the caller.
        """
        try:
            if self.params:
                args = self.params(**kwargs)  # validate arguments
                # pydantic v2 API; .dict() is deprecated.
                return self.func(**args.model_dump())
            return self.func(**kwargs)
        except Exception as e:
            return f"Error: {str(e)}"

    async def async_invoke(self, **kwargs):
        """Async counterpart of `invoke`; awaits the wrapped coroutine function."""
        try:
            if self.params:
                args = self.params(**kwargs)  # validate arguments
                return await self.func(**args.model_dump())
            return await self.func(**kwargs)
        except Exception as e:
            return f"Error: {str(e)}"

    def __repr__(self):
        # `schema` may be a raw JSON schema (with a 'properties' key) or the
        # flattened mapping built in __call__; handle both, and don't crash
        # when neither params nor schema is set.
        if self.params is not None:
            params = list(self.params.model_json_schema().get('properties', {}).keys())
        elif isinstance(self.schema, dict):
            params = list(self.schema.get('properties', self.schema).keys())
        else:
            params = []
        return f"Tool(name={self.name}, description={self.description}, params={params})"

    def get_prompt(self):
        """Render the tool as a prompt fragment for the LLM."""
        return f'''Tool Name: {self.name}\nTool Description: {self.description}\nTool Input: {dumps(self.schema,indent=2)}'''
--------------------------------------------------------------------------------
/src/embedding/gemini.py:
--------------------------------------------------------------------------------
from typing import Literal

from httpx import Client, HTTPStatusError, TransportError
from requests import RequestException,HTTPError,ConnectionError

from src.embedding import BaseEmbedding
5 |
class GeminiEmbedding(BaseEmbedding):
    """Embedding client for the Gemini `embedContent`/`batchEmbedContents` REST API."""
    def __init__(self,model:str='',output_dimensionality:int=None,task_type:Literal['TASK_TYPE_UNSPECIFIED','RETRIEVAL_QUERY','RETRIEVAL_DOCUMENT','SEMANTIC_SIMILARITY','CLASSIFICATION','CLUSTERING']='',api_key:str='',base_url:str=''):
        self.api_key=api_key
        self.model=model
        self.base_url=base_url  # overrides the default Google endpoint when set
        self.output_dimensionality=output_dimensionality
        self.task_type=task_type
        self.headers={'Content-Type': 'application/json'}

    def embed(self,text:list[str]|str='',title:str=''):
        """Embed a single string or a batch of strings.

        Returns a vector (list of floats) for a single string, or a list of
        vectors for a batch.  On an HTTP or transport failure the error is
        printed and None is returned.
        """
        headers=self.headers
        # Batches and single strings use different API endpoints.
        mode='batchEmbedContents' if isinstance(text,list) else 'embedContent'
        url=self.base_url or f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:{mode}"
        params={'key':self.api_key}
        if isinstance(text,list):
            payload={
                'requests':[
                    {
                        'model':f'models/{self.model}',
                        'content':{
                            'parts':[
                                {
                                    'text':_text
                                }
                            ]
                        }
                    }
                for _text in text]
            }
        else:
            payload={
                'model':f'models/{self.model}',
                'content':{
                    'parts':[
                        {
                            'text':text
                        }
                    ]
                }
            }
        if self.task_type:
            payload['task_type']=self.task_type
        if self.output_dimensionality:
            payload['output_dimensionality']=self.output_dimensionality
        if title:
            payload['title']=title
        try:
            with Client() as client:
                response=client.post(url=url,json=payload,headers=headers,params=params)
                response.raise_for_status()
                data=response.json()
                if isinstance(text,list):
                    return [e['values'] for e in data['embeddings']]
                return data['embedding']['values']
        # httpx raises its own exception types; the previous handlers caught
        # requests' HTTPError/ConnectionError, which httpx never raises, so
        # failures propagated uncaught.
        except HTTPStatusError as err:
            print(f'Error: {err.response.text}, Status Code: {err.response.status_code}')
        except TransportError as err:
            print(err)
--------------------------------------------------------------------------------
/src/agent/web/browser/__init__.py:
--------------------------------------------------------------------------------
1 | from src.agent.web.browser.config import BrowserConfig,BROWSER_ARGS,SECURITY_ARGS,IGNORE_DEFAULT_ARGS
2 | from playwright.async_api import async_playwright,Browser as PlaywrightBrowser,Playwright
3 |
class Browser:
    """Owns the Playwright driver and one browser instance (chrome/firefox/edge)."""
    def __init__(self,config:BrowserConfig=None):
        self.playwright:Playwright = None
        self.config = config if config else BrowserConfig()
        self.playwright_browser:PlaywrightBrowser = None

    async def __aenter__(self):
        # Async context manager: launch on entry, close on exit.
        await self.init_browser()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close_browser()

    async def init_browser(self):
        """Start the Playwright driver and launch/connect the configured browser."""
        self.playwright =await async_playwright().start()
        self.playwright_browser = await self.setup_browser(self.config.browser)

    async def get_playwright_browser(self)->PlaywrightBrowser:
        """Return the live browser, lazily initialising it on first call."""
        if self.playwright_browser is None:
            await self.init_browser()
        return self.playwright_browser

    async def setup_browser(self,browser:str)->PlaywrightBrowser:
        """Launch or connect a browser instance according to the config.

        Resolution order: remote websocket endpoint, then persistent-profile
        mode (returns None here — presumably handled by the caller; TODO
        confirm), then a local launch.
        """
        parameters={
            'headless':self.config.headless,
            'downloads_path':self.config.downloads_dir,
            'timeout':self.config.timeout,
            'slow_mo':self.config.slow_mo,
            'args':BROWSER_ARGS + SECURITY_ARGS,
            'ignore_default_args': IGNORE_DEFAULT_ARGS
        }
        if self.config.wss_url is not None:
            # Remote browser over websocket; 'edge' connects via chromium too.
            if browser=='chrome':
                browser_instance=await self.playwright.chromium.connect(self.config.wss_url)
            elif browser=='firefox':
                browser_instance=await self.playwright.firefox.connect(self.config.wss_url)
            elif browser=='edge':
                browser_instance=await self.playwright.chromium.connect(self.config.wss_url)
            else:
                raise Exception('Invalid Browser Type')
        elif self.config.browser_instance_dir is not None:
            browser_instance=None
        else:
            if self.config.device is not None:
                # NOTE(review): the device descriptor REPLACES the launch
                # parameters entirely (headless/args/timeout are dropped) —
                # confirm this is intentional.
                parameters={**self.playwright.devices.get(self.config.device)}
                parameters.pop('default_browser_type',None)
            if browser=='chrome':
                browser_instance=await self.playwright.chromium.launch(channel='chrome',**parameters)
            elif browser=='firefox':
                browser_instance=await self.playwright.firefox.launch(**parameters)
            elif browser=='edge':
                browser_instance=await self.playwright.chromium.launch(channel='msedge',**parameters)
            else:
                raise Exception('Invalid Browser Type')
        return browser_instance

    async def close_browser(self):
        """Best-effort shutdown of the browser and driver; always clears state."""
        try:
            if self.playwright_browser:
                await self.playwright_browser.close()
            if self.playwright:
                await self.playwright.stop()
        except Exception as e:
            print('Browser failed to close')
        finally:
            self.playwright=None
            self.playwright_browser=None
--------------------------------------------------------------------------------
/src/agent/web/dom/views.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass,field
2 | from textwrap import shorten
3 |
@dataclass
class BoundingBox:
    """Axis-aligned element rectangle in page pixels."""
    left:int
    top:int
    width:int
    height:int

    def to_string(self):
        """Render as '(left,top,width,height)'."""
        return '({},{},{},{})'.format(self.left,self.top,self.width,self.height)

    def to_dict(self):
        """Render as a plain dict of the four coordinates."""
        return dict(left=self.left,top=self.top,width=self.width,height=self.height)
16 |
@dataclass
class CenterCord:
    """Center point (x, y) of an element in page pixels."""
    x:int
    y:int

    def to_string(self)->str:
        """Render as '(x,y)'."""
        return '({},{})'.format(self.x,self.y)

    def to_dict(self):
        """Render as a plain {'x', 'y'} dict."""
        return dict(x=self.x,y=self.y)
27 |
@dataclass
class DOMElementNode:
    """An interactive page element with its geometry and identifying attributes."""
    tag: str
    role: str
    name: str
    bounding_box: BoundingBox
    center: CenterCord
    attributes: dict[str,str] = field(default_factory=dict)
    xpath: dict[str,str]=field(default_factory=dict)
    viewport: tuple[int,int]=field(default_factory=tuple)

    def __repr__(self):
        return (
            f"DOMElementNode(tag='{self.tag}', role='{self.role}', name='{self.name}', "
            f"attributes={self.attributes}, cordinates={self.center}, "
            f"bounding_box={self.bounding_box}, xpath='{self.xpath}')"
        )

    def to_dict(self)->dict[str,str]:
        return {
            'tag':self.tag,
            'role':self.role,
            'name':self.name,
            'bounding_box':self.bounding_box.to_dict(),
            'attributes':self.attributes,
            'cordinates':self.center.to_dict(),
        }
44 |
@dataclass
class ScrollElementNode:
    """A scrollable container element on the page."""
    tag: str
    role: str
    name: str
    attributes: dict[str,str] = field(default_factory=dict)
    xpath: dict[str,str]=field(default_factory=dict)
    viewport: tuple[int,int]=field(default_factory=tuple)

    def __repr__(self):
        # Long accessible names are clipped to keep logs readable.
        clipped=shorten(self.name,width=500)
        return f"ScrollableElementNode(tag='{self.tag}', role='{self.role}', name='{clipped}', attributes={self.attributes}, xpath='{self.xpath}')"

    def to_dict(self)->dict[str,str]:
        return dict(tag=self.tag,role=self.role,name=self.name,attributes=self.attributes)
59 |
@dataclass
class DOMTextualNode:
    """A non-interactive text node with its position on the page."""
    tag:str
    role:str
    content:str
    center: CenterCord
    xpath: dict[str,str]=field(default_factory=dict)
    viewport: tuple[int,int]=field(default_factory=tuple)

    def __repr__(self):
        return (
            f'DOMTextualNode(tag={self.tag}, role={self.role}, content={self.content}, '
            f'cordinates={self.center}, xpath={self.xpath})'
        )

    def to_dict(self)->dict[str,str]:
        return dict(tag=self.tag,role=self.role,content=self.content,center=self.center.to_dict())
74 |
@dataclass
class DOMState:
    """Snapshot of the DOM, split into interactive, informative and scrollable nodes."""
    interactive_nodes: list[DOMElementNode]=field(default_factory=list)
    informative_nodes:list[DOMTextualNode]=field(default_factory=list)
    scrollable_nodes:list[ScrollElementNode]=field(default_factory=list)
    selector_map: dict[str,DOMElementNode|ScrollElementNode]=field(default_factory=dict)

    def interactive_elements_to_string(self)->str:
        """One numbered line per interactive element."""
        lines=[]
        for index,node in enumerate(self.interactive_nodes):
            lines.append(f'{index} - Tag: {node.tag} Role: {node.role} Name: {node.name} Attributes: {node.attributes} Cordinates: {node.center.to_string()}')
        return '\n'.join(lines)

    def informative_elements_to_string(self)->str:
        """One line per text node."""
        return '\n'.join(f'Tag: {node.tag} Role: {node.role} Content: {node.content}' for node in self.informative_nodes)

    def scrollable_elements_to_string(self)->str:
        """Numbered continuation of the interactive listing, so indexes stay unique."""
        offset=len(self.interactive_nodes)
        return '\n'.join(
            f'{offset+index} - Tag: {node.tag} Role: {node.role} Name: {shorten(node.name,width=500)} Attributes: {node.attributes}'
            for index,node in enumerate(self.scrollable_nodes)
        )
91 |
92 |
--------------------------------------------------------------------------------
/src/tool/registry/__init__.py:
--------------------------------------------------------------------------------
1 | # src/tool/registry/__init__.py
2 | from src.tool.registry.views import Function,ToolResult
3 | from src.tool import Tool
4 |
class Registry:
    """Lookup and execution layer over the agent's tools."""
    def __init__(self,tools:list[Tool]):
        self.tools=tools
        self.tools_registry=self.registry()

    def tools_prompt(self,excluded_tools:list[str]|None=None)->str:
        """Render the prompt of every tool except those named in *excluded_tools*."""
        # Default to None rather than a mutable `[]` default argument.
        excluded=set(excluded_tools or [])
        return '\n\n'.join(tool.get_prompt() for tool in self.tools if tool.name not in excluded)

    def registry(self)->dict[str,Function]:
        """Build the name -> Function mapping from the tool list."""
        return {
            tool.name: Function(name=tool.name,description=tool.description,params=tool.params,function=tool.func)
            for tool in self.tools
        }

    def _resolve(self,name:str,input:dict,**kwargs)->tuple[Function,dict]:
        """Validate the tool name and input; return the tool and its final params.

        Raises ValueError for a missing/unknown tool name and TypeError for a
        non-dict input when the tool declares a params model.
        """
        # An empty name indicates the LLM failed to choose a valid action.
        if not name:
            raise ValueError('Action Name was None or empty. The LLM failed to choose a valid action.')
        tool=self.tools_registry.get(name)
        if tool is None:
            raise ValueError(f'Tool "{name}" not found. Please choose from the available tools.')
        if tool.params:
            # Ensure input is a dictionary before validation.
            if not isinstance(input, dict):
                raise TypeError(f"Action Input for tool '{name}' must be a dictionary, but got {type(input)}: {input}")
            params=tool.params.model_validate(input).model_dump()|kwargs
        else:
            params=input|kwargs
        return tool,params

    def _error_result(self,name:str,input:dict,error:Exception)->ToolResult:
        """Package a tool failure as a ToolResult; 'Invalid Action' when name was missing."""
        error_name = name if name else "Invalid Action"
        print(f"DEBUG: Tool execution failed. Name: {name}, Input: {input}, Error: {error}")
        return ToolResult(name=error_name, content=f"Error executing tool '{error_name}': {str(error)}")

    async def async_execute(self,name:str,input:dict,**kwargs)->ToolResult:
        """Run a (possibly async) tool; failures come back as a ToolResult, never an exception."""
        try:
            tool,params=self._resolve(name,input,**kwargs)
            content=await tool.function(**params)
            return ToolResult(name=name,content=content)
        except Exception as e:
            return self._error_result(name,input,e)

    def execute(self,name:str,input:dict,**kwargs)->ToolResult:
        """Synchronous variant of `async_execute`."""
        try:
            tool,params=self._resolve(name,input,**kwargs)
            content=tool.function(**params)
            return ToolResult(name=name,content=content)
        except Exception as e:
            return self._error_result(name,input,e)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
20 |
21 |
22 |
23 | **Web Navigator** is your intelligent browsing companion, built to seamlessly navigate websites, interact with dynamic content, perform smart searches, download files, and adapt to ever-changing pages — all with minimal effort from you. Powered by advanced LLMs and the robust Playwright framework, it transforms complex web tasks into streamlined, automated workflows that boost productivity and save time.
24 |
25 | ## 🛠️Installation Guide
26 |
27 | ### **Prerequisites**
28 |
29 | - Python 3.11 or higher
30 | - UV
31 |
32 | ### **Installation Steps**
33 |
34 | **Clone the repository:**
35 |
36 | ```bash
37 | git clone https://github.com/CursorTouch/Web-Navigator.git
38 | cd Web-Navigator
39 | ```
40 |
41 | **Install dependencies:**
42 |
43 | ```bash
44 | uv sync
45 | ```
46 |
47 | **Setup Playwright:**
48 |
49 | ```bash
50 | playwright install
51 | ```
52 |
53 | ---
54 |
55 | **Setting up the `.env` file:**
56 |
57 | ```bash
58 | GOOGLE_API_KEY=""
59 | ```
60 |
61 | Basic setup of the agent.
62 |
63 | ```python
64 | from src.inference.gemini import ChatGemini
65 | from src.agent.web import WebAgent
66 | from dotenv import load_dotenv
67 | import os
68 |
69 | load_dotenv()
70 | google_api_key=os.getenv('GOOGLE_API_KEY')
71 |
72 | llm=ChatGemini(model='gemini-2.0-flash',api_key=google_api_key,temperature=0)
agent=WebAgent(llm=llm,verbose=True,use_vision=False)
74 |
75 | user_query=input('Enter your query: ')
76 | agent_response=agent.invoke(user_query)
77 | print(agent_response.get('output'))
78 |
79 | ```
80 |
81 | Execute the following command to start the agent:
82 |
83 | ```bash
python main.py
85 | ```
86 |
87 | ## 🎥Demos
88 |
**Prompt:** I want to know the price details of the RTX 4060 laptop GPU from various sellers from amazon.in
90 |
91 | https://github.com/user-attachments/assets/c729dda9-0ecc-4b07-9113-62fddccca52f
92 |
93 | **Prompt:** Make a twitter post about AI on X
94 |
95 | https://github.com/user-attachments/assets/126ef697-f506-4630-9a0a-1dbbfead9f7e
96 |
97 | **Prompt:** Can you play the trailer of GTA 6 on youtube
98 |
99 | https://github.com/user-attachments/assets/7abde708-7fe0-46f8-96ac-16124aaf2ef4
100 |
101 | **Prompt:** Can you go to my github account and visit the Windows MCP
102 |
103 | https://github.com/user-attachments/assets/cb8ad60c-0609-42e3-9fb9-584ad77c4e3a
104 |
105 | ---
106 |
107 | ## 🪪License
108 |
109 | This project is licensed under MIT License - see the [LICENSE](LICENSE) file for details.
110 |
111 | ## 🤝Contributing
112 |
113 | Contributions are welcome! Please see [CONTRIBUTING](CONTRIBUTING.md) for setup instructions and development guidelines.
114 |
115 | Made with ❤️ by [Jeomon George](https://github.com/Jeomon), [Muhammad Yaseen](https://github.com/mhmdyaseen)
116 |
117 | ---
118 |
119 | ## 📒References
120 |
121 | - **[Playwright Documentation](https://playwright.dev/docs/intro)**
122 | - **[LangGraph Examples](https://github.com/langchain-ai/langgraph/blob/main/examples/web-navigation/web_voyager.ipynb)**
123 | - **[vimGPT](https://github.com/ishan0102/vimGPT)**
124 | - **[WebVoyager](https://github.com/MinorJerry/WebVoyager)**
125 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | uploads/
16 | screenshots/
17 | user_data/
18 | memory_data/
19 | db/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 | cover/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | .pybuilder/
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | # For a library or package, you might want to ignore these files since the code is
92 | # intended to run in multiple environments; otherwise, check them in:
93 | # .python-version
94 | /Web312venv/
95 |
96 |
97 | # pipenv
98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
101 | # install all needed dependencies.
102 | #Pipfile.lock
103 |
104 | # poetry
105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
106 | # This is especially recommended for binary packages to ensure reproducibility, and is more
107 | # commonly ignored for libraries.
108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
109 | #poetry.lock
110 |
111 | # pdm
112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113 | #pdm.lock
114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
115 | # in version control.
116 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
117 | .pdm.toml
118 | .pdm-python
119 | .pdm-build/
120 |
121 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
122 | __pypackages__/
123 |
124 | # Celery stuff
125 | celerybeat-schedule
126 | celerybeat.pid
127 |
128 | # SageMath parsed files
129 | *.sage.py
130 |
131 | # Environments
132 | .env
133 | .venv
134 | env/
135 | venv/
136 | ENV/
137 | env.bak/
138 | venv.bak/
139 |
140 | # Spyder project settings
141 | .spyderproject
142 | .spyproject
143 |
144 | # Rope project settings
145 | .ropeproject
146 |
147 | # mkdocs documentation
148 | /site
149 |
150 | # mypy
151 | .mypy_cache/
152 | .dmypy.json
153 | dmypy.json
154 |
155 | # Pyre type checker
156 | .pyre/
157 |
158 | # pytype static type analyzer
159 | .pytype/
160 |
161 | # Cython debug symbols
162 | cython_debug/
163 |
164 | # PyCharm
165 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
166 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
167 | # and can be added to the global gitignore or merged into this file. For a more nuclear
168 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
169 | #.idea/
170 |
171 | test.py
172 | notebook.ipynb
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to Web-Navigator
2 |
3 | Thank you for your interest in contributing to Web Agent! This document provides guidelines and instructions for contributing to this project.
4 |
5 | ## Table of Contents
6 |
7 | - [Getting Started](#getting-started)
8 | - [Development Environment](#development-environment)
9 | - [Installation](#installation-from-source)
10 |
11 | - [Development Workflow](#development-workflow)
12 | - [Branching Strategy](#branching-strategy)
13 | - [Commit Messages](#commit-messages)
14 | - [Code Style](#code-style)
15 | - [Pre-commit Hooks](#pre-commit-hooks)
16 |
17 | - [Testing](#testing)
18 | - [Running Tests](#running-tests)
19 | - [Adding Tests](#adding-tests)
20 |
21 | - [Pull Requests](#pull-requests)
22 | - [Creating a Pull Request](#creating-a-pull-request)
23 | - [Pull Request Template](#pull-request-template)
24 |
25 | - [Documentation](#documentation)
26 | - [Getting Help](#getting-help)
27 |
28 | ## Getting Started
29 |
30 | ### Development Environment
31 |
32 | WebAgent requires:
33 | - Python 3.11 or later
34 |
35 | ### Installation
36 |
37 | 1. Fork the repository on GitHub.
38 | 2. Clone your fork locally:
39 |
40 | ```bash
git clone https://github.com/CursorTouch/Web-Navigator.git
cd Web-Navigator
43 | ```
44 |
45 | 3. Install the package in development mode:
46 |
47 | ```bash
48 | pip install -e ".[dev,search]"
49 | ```
50 |
51 | 4. Set up pre-commit hooks:
52 |
53 | ```bash
54 | pip install pre-commit
55 | pre-commit install
56 | ```
57 |
58 | ## Development Workflow
59 |
60 | ### Branching Strategy
61 |
62 | - `main` branch contains the latest stable code
63 | - Create feature branches from `main` named according to the feature you're implementing: `feature/your-feature-name`
64 | - For bug fixes, use: `fix/bug-description`
65 |
66 | ### Commit Messages
67 |
68 | For now no commit style is enforced, try to keep your commit messages informational.
69 |
70 | ### Code Style
71 |
72 | Key style guidelines:
73 |
74 | - Line length: 100 characters
75 | - Use double quotes for strings
76 | - Follow PEP 8 naming conventions
77 | - Add type hints to function signatures
78 |
79 | ### Pre-commit Hooks
80 |
81 | We use pre-commit hooks to ensure code quality before committing. The configuration is in `.pre-commit-config.yaml`.
82 |
83 | The hooks will:
84 |
85 | - Run linting checks
86 | - Check for trailing whitespace and fix it
87 | - Ensure files end with a newline
88 | - Validate YAML files
89 | - Check for large files
90 | - Remove debug statements
91 |
92 | ## Testing
93 |
94 | ### Running Tests
95 |
96 | Run the test suite with pytest:
97 |
98 | ```bash
99 | pytest
100 | ```
101 |
102 | To run specific test categories:
103 |
104 | ```bash
105 | pytest tests/
106 | ```
107 |
108 | ### Adding Tests
109 |
110 | - Add unit tests for new functionality in `tests/unit/`
111 | - For slow or network-dependent tests, mark them with `@pytest.mark.slow` or `@pytest.mark.integration`
112 | - Aim for high test coverage of new code
113 |
114 | ## Pull Requests
115 |
116 | ### Creating a Pull Request
117 |
118 | 1. Ensure your code passes all tests and pre-commit hooks
119 | 2. Push your changes to your fork
120 | 3. Submit a pull request to the main repository
121 | 4. Follow the pull request template
122 |
123 | ## Documentation
124 |
125 | - Update docstrings for new or modified functions, classes, and methods
126 | - Use Google-style docstrings:
127 |
128 | ```python
129 | def function_name(param1: type, param2: type) -> return_type:
130 | """Short description.
131 | Longer description if needed.
132 |
133 | Args:
134 | param1: Description of param1
135 | param2: Description of param2
136 |
137 | Returns:
138 | Description of return value
139 |
140 | Raises:
141 | ExceptionType: When and why this exception is raised
142 | """
143 | ```
144 |
145 |
146 | ## Getting Help
147 |
148 | If you need help with your contribution:
149 |
150 | - Open an issue for discussion
- Reach out to any of our maintainers
152 |
153 | We look forward to your contributions!
--------------------------------------------------------------------------------
/src/agent/web/tools/views.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel,Field
2 | from typing import Literal
3 |
class SharedBaseModel(BaseModel):
    """Common base for all tool parameter schemas.

    Allows extra fields so an LLM-produced payload with additional keys
    does not fail validation.
    """
    class Config:
        extra="allow"
7 |
class Done(SharedBaseModel):
    """Parameters for the Done tool: final markdown summary of the completed task."""
    content:str = Field(...,description="Summary of the completed task in proper markdown format explaining what was accomplished",examples=["The task is completed successfully. User profile updated with new email address."])
10 |
class Click(SharedBaseModel):
    """Parameters for the Click tool: click an interactive element by index."""
    index:int = Field(...,description="The index/label of the interactive element to click (buttons, links, checkboxes, tabs, etc.)",examples=[0])
13 |
class Type(SharedBaseModel):
    """Parameters for the Type tool: type text into an input element.

    Note: `clear` and `press_enter` are string literals ('True'/'False'),
    not booleans — presumably to match the LLM's textual tool output.
    """
    index:int = Field(...,description="The index/label of the input element to type text into (text fields, search boxes, text areas)",examples=[0])
    text:str = Field(...,description="The text content to type into the input field",examples=["hello world","user@example.com","My search query"])
    clear:Literal['True','False']=Field(description="Whether to clear existing text before typing new content",default="False",examples=['True'])
    press_enter:Literal['True','False']=Field(description="Whether to press Enter after typing",default="False",examples=['True'])
19 |
class Wait(SharedBaseModel):
    """Parameters for the Wait tool: pause for a number of seconds."""
    time:int = Field(...,description="Number of seconds to wait for page loading, animations, or content to appear",examples=[1,3,5])
22 |
class Scroll(SharedBaseModel):
    """Parameters for the Scroll tool: scroll the whole page or a specific
    scrollable container."""
    direction: Literal['up','down'] = Field(description="The direction to scroll content", examples=['up','down'], default='up')
    # The descriptions document None as a valid value ("scroll the entire
    # page" / "scroll by page height"), so the annotations must admit None;
    # a bare `int` would reject an explicit None during validation.
    index: int | None = Field(description="Index of specific scrollable element, if None then scrolls the entire page", examples=[0, 5, 12,None],default=None)
    amount: int | None = Field(description="Number of pixels to scroll, if None then scrolls by page/container height. Must required for scrollable container elements and the amount should be small", examples=[100, 25, 50],default=500)
27 |
class GoTo(SharedBaseModel):
    """Parameters for the GoTo tool: navigate the current tab to a URL."""
    url:str = Field(...,description="The complete URL to navigate to including protocol (http/https)",examples=["https://www.example.com","https://google.com/search?q=test"])
30 |
class Back(SharedBaseModel):
    """Parameters for the Back tool: go to the previous page in history (no fields)."""
    pass
33 |
class Forward(SharedBaseModel):
    """Parameters for the Forward tool: go to the next page in history (no fields)."""
    pass
36 |
class Key(SharedBaseModel):
    """Parameters for the Key tool: press a key or key combination."""
    keys:str = Field(...,description="Keyboard key or key combination to press (supports modifiers like Control, Alt, Shift)",examples=["Enter","Control+A","Escape","Tab","Control+C"])
    times:int = Field(description="Number of times to repeat the key press sequence",examples=[1,2,3],default=1)
40 |
class Download(SharedBaseModel):
    """Parameters for the Download tool: fetch a file from a direct URL."""
    url:str = Field(...,description="Direct URL of the file to download (supports various file types: PDF, images, videos, documents)",examples=["https://www.example.com/document.pdf","https://site.com/image.jpg"])
    filename:str=Field(...,description="Local filename to save the downloaded file as (include file extension)",examples=["document.pdf","image.jpg","data.xlsx"])
44 |
class Scrape(SharedBaseModel):
    """Parameters for the Scrape tool: scrape the current page (no fields)."""
    pass
47 |
class Tab(SharedBaseModel):
    """Parameters for the Tab tool: open, close or switch browser tabs."""
    mode:Literal['open','close','switch'] = Field(...,description="Tab operation: 'open' creates new tab, 'close' closes current tab, 'switch' changes to existing tab",examples=['open','close','switch'])
    # Only required for 'switch' mode; None otherwise — so the annotation
    # must admit None (a bare `int` would reject an explicit None).
    tab_index: int | None = Field(description="Zero-based index of the tab to switch to (only required for 'switch' mode)",examples=[0,1,2],default=None)
51 |
class Upload(SharedBaseModel):
    """Parameters for the Upload tool: attach files to a file input element."""
    index:int = Field(...,description="Index of the file input element to upload files to",examples=[0])
    filenames:list[str] = Field(...,description="List of filenames to upload from the ./uploads directory (supports single or multiple files)",examples=[["document.pdf"],["image1.jpg","image2.png"]])
55 |
class Menu(SharedBaseModel):
    """Parameters for the Menu tool: select option(s) from a dropdown element."""
    index:int = Field(...,description="Index of the dropdown/select element to interact with",examples=[0])
    labels:list[str] = Field(...,description="List of visible option labels to select from the dropdown menu (supports single or multiple selection)",examples=[["BMW"],["Option 1","Option 2"]])
59 |
class Script(SharedBaseModel):
    """Parameters for the Script tool: run JavaScript in the current page."""
    script:str = Field(...,description="The JavaScript code to execute in the current webpage to scrape data. Make sure the script is well-formatted",examples=["console.log('Hello, world!')"])
62 |
class HumanInput(SharedBaseModel):
    """Parameters for the Human tool: ask the human user for assistance."""
    prompt: str = Field(..., description="Clear question or instruction to ask the human user when assistance is needed", examples=["Please enter the OTP code sent to your phone", "What is your preferred payment method?", "Please solve this CAPTCHA"])
--------------------------------------------------------------------------------
/src/agent/web/dom/__init__.py:
--------------------------------------------------------------------------------
1 | from src.agent.web.dom.views import DOMElementNode, DOMTextualNode, ScrollElementNode, DOMState, CenterCord, BoundingBox
2 | from playwright.async_api import Page, Frame
3 | from typing import TYPE_CHECKING
4 | from asyncio import sleep
5 |
6 | if TYPE_CHECKING:
7 | from src.agent.web.context import Context
8 |
class DOM:
    """Extracts the DOM state (interactive, informative and scrollable nodes)
    of the current page and its visible iframes via an injected JS script."""

    def __init__(self, context:'Context'):
        self.context=context

    async def get_state(self,use_vision:bool=False,freeze:bool=False)->tuple[str|None,DOMState]:
        '''Get the state of the webpage.

        Args:
            use_vision: when True, draw bounding boxes over the interactive
                elements, capture a screenshot, then remove the boxes.
            freeze: when True, sleep before (and after capture) — debug aid.

        Returns:
            (screenshot, DOMState). screenshot is None when use_vision is
            False or extraction failed; on failure the DOMState is empty.
        '''
        page=None  # defined up-front so the except clause can't raise NameError if we fail before assignment
        try:
            if freeze:
                await sleep(5)
            with open('./src/agent/web/dom/script.js') as f:
                script=f.read()
            page=await self.context.get_current_page()
            await page.wait_for_load_state('domcontentloaded',timeout=10*1000)
            await self.context.execute_script(page,script)
            # Access from frames
            frames=page.frames
            interactive_nodes,informative_nodes,scrollable_nodes=await self.get_elements(frames=frames)
            if use_vision:
                # Add bounding boxes to the interactive elements
                boxes=map(lambda node:node.bounding_box.to_dict(),interactive_nodes)
                await self.context.execute_script(page,'boxes=>{mark_page(boxes)}',list(boxes))
                screenshot=await self.context.get_screenshot(save_screenshot=False)
                # Remove bounding boxes from the interactive elements
                if freeze:
                    await sleep(10)
                await sleep(0.1)
                await self.context.execute_script(page,'unmark_page()')
            else:
                screenshot=None
        except Exception as e:
            # `page` is None when the failure happened before it was fetched.
            print(f"Failed to get elements from page: {page.url if page else 'unknown'}\nError: {e}")
            interactive_nodes,informative_nodes,scrollable_nodes=[],[],[]
            screenshot=None
        # Tools address elements by index over interactive + scrollable nodes.
        selector_map=dict(enumerate(interactive_nodes+scrollable_nodes))
        return (screenshot,DOMState(interactive_nodes=interactive_nodes,informative_nodes=informative_nodes,scrollable_nodes=scrollable_nodes,selector_map=selector_map))

    async def get_elements(self,frames:list[Frame|Page])->tuple[list[DOMElementNode],list[DOMTextualNode],list[ScrollElementNode]]:
        '''Get the interactive elements of the webpage.

        Iterates every attached, non-blank frame, injects the DOM script and
        collects interactive, informative and scrollable nodes. Index 0 is
        the main frame; other frames are skipped when not visible.
        '''
        interactive_elements,informative_elements,scrollable_elements=[],[],[]
        with open('./src/agent/web/dom/script.js') as f:
            script=f.read()
        frame=None  # defined up-front so the except clause can't raise NameError when `frames` is empty
        try:
            for index,frame in enumerate(frames):
                if frame.is_detached() or frame.url=='about:blank':
                    continue
                await self.context.execute_script(frame,script) # Inject JS
                # index=0 means Main Frame
                if index>0 and not await self.context.is_frame_visible(frame=frame):
                    continue
                # (the script was already injected above; the original injected it a second time here)
                nodes:dict=await self.context.execute_script(frame,'getElements()')
                element_nodes,textual_nodes,scrollable_nodes=nodes.values()
                if index>0:
                    # Resolve the frame's own xpath so element xpaths can be scoped to it.
                    frame_element =await frame.frame_element()
                    frame_xpath=await self.context.execute_script(frame,'(frame_element)=>getXPath(frame_element)',frame_element)
                else:
                    frame_xpath=''
                viewport =await self.context.get_viewport()
                for element in element_nodes:
                    element_xpath=element.get('xpath')
                    node=DOMElementNode(**{
                        'tag':element.get('tag'),
                        'role':element.get('role'),
                        'name':element.get('name'),
                        'attributes':element.get('attributes'),
                        'center':CenterCord(**element.get('center')),
                        'bounding_box':BoundingBox(**element.get('box')),
                        'xpath':{'frame':frame_xpath,'element':element_xpath},
                        'viewport':viewport
                    })
                    interactive_elements.append(node)

                for element in scrollable_nodes:
                    element_xpath=element.get('xpath')
                    node=ScrollElementNode(**{
                        'tag':element.get('tag'),
                        'role':element.get('role'),
                        'name':element.get('name'),
                        'attributes':element.get('attributes'),
                        'xpath':{'frame':frame_xpath,'element':element_xpath},
                        'viewport':viewport
                    })
                    scrollable_elements.append(node)

                for element in textual_nodes:
                    element_xpath=element.get('xpath')
                    node=DOMTextualNode(**{
                        'tag':element.get('tag'),
                        'role':element.get('role'),
                        'content':element.get('content'),
                        'center':CenterCord(**element.get('center')),
                        'xpath':{'frame':frame_xpath,'element':element_xpath},
                        'viewport':viewport
                    })
                    informative_elements.append(node)

        except Exception as e:
            print(f"Failed to get elements from frame: {frame.url if frame else 'unknown'}\nError: {e}")
        return interactive_elements,informative_elements,scrollable_elements
--------------------------------------------------------------------------------
/src/memory/episodic/__init__.py:
--------------------------------------------------------------------------------
1 | from src.memory.episodic.utils import read_markdown_file
2 | from src.memory.episodic.views import Memory,Memories
3 | from src.message import SystemMessage,HumanMessage
4 | from src.inference import BaseInference
5 | from src.message import BaseMessage
6 | from src.memory import BaseMemory
7 | from src.router import LLMRouter
8 | from termcolor import colored
9 | from uuid import uuid4
10 | import json
11 |
12 | with open('./src/memory/episodic/routes.json','r') as f:
13 | routes=json.load(f)
14 |
15 | class EpisodicMemory(BaseMemory):
16 | def __init__(self,knowledge_base:str='knowledge_base.json',llm:BaseInference=None,verbose=False):
17 | self.memories:Memories=[]
18 | super().__init__(knowledge_base=knowledge_base,llm=llm,verbose=verbose)
19 |
20 | def router(self,conversation:list[BaseMessage]):
21 | instructions=[
22 | 'Go through the conversation and revelant memories to determine which route to take.',
23 | 'Be proactive in your choice and perform this task carefully and with utmost accuracy.'
24 | ]
25 | router=LLMRouter(instructions=instructions,routes=routes,llm=self.llm,verbose=False)
26 | route=router.invoke(f'### Revelant Memories from Knowledge Base:\n{self.memories.to_string()}\n### Conversation:\n{self.conversation_to_text(conversation)}')
27 | return route
28 |
29 | def store(self, conversation: list[BaseMessage]):
30 | route=self.router(conversation)
31 | if route=='ADD':
32 | self.add_memory(conversation)
33 | elif route=='UPDATE':
34 | self.update_memory(conversation)
35 | elif route=='REPLACE':
36 | self.replace_memory(conversation)
37 | else:
38 | self.idle_memory(conversation)
39 |
40 | def idle_memory(self,conversation:list[BaseMessage]):
41 | if self.verbose:
42 | print(f'{colored(f'Idle memory:',color='yellow',attrs=['bold'])}\n{json.dumps(self.memories.all(),indent=2)}')
43 | return None
44 |
45 | def add_memory(self,conversation:list[BaseMessage]):
46 | system_prompt=read_markdown_file('src/memory/episodic/prompt/add.md')
47 | text_conversation=self.conversation_to_text(conversation)
48 | user_prompt=f'### Conversation:\n{text_conversation}'
49 | messages=[SystemMessage(system_prompt),HumanMessage(user_prompt)]
50 | memory=self.llm.invoke(messages,model=Memory)
51 | if self.verbose:
52 | print(f'{colored(f'Adding memory to Knowledge Base:',color='yellow',attrs=['bold'])}\n{json.dumps(memory.to_dict(),indent=2)}')
53 | with open(f'./memory_data/{self.knowledge_base}','r+') as f:
54 | knowledge_base:list[dict] = json.load(f)
55 | knowledge_base.append(memory.model_dump())
56 | f.seek(0)
57 | json.dump(knowledge_base, f, indent=2)
58 |
59 | def update_memory(self,conversation:list[BaseMessage]):
60 | system_prompt=read_markdown_file('src/memory/episodic/prompt/update.md')
61 | text_conversation=self.conversation_to_text(conversation)
62 | user_prompt=f'### Revelant memories from Knowledge Base:\n{self.memories.to_string()}\n### Conversation:\n{text_conversation}'
63 | messages=[SystemMessage(system_prompt),HumanMessage(user_prompt)]
64 | memory:Memory=self.llm.invoke(messages,model=Memory)
65 | if self.verbose:
66 | print(f'{colored(f'Updated memory from Knowledge Base:',color='yellow',attrs=['bold'])}\n{json.dumps(memory.to_dict(),indent=2)}')
67 | with open(f'./memory_data/{self.knowledge_base}','r+') as f:
68 | knowledge_base = [Memory.model_validate(memory) for memory in json.loads(f)]
69 | memory_ids=[memory.get('id') for memory in self.memories.all()]
70 | updated_knowledge_base=list(filter(lambda memory:memory.id not in memory_ids,knowledge_base))
71 | updated_knowledge_base.extend(memory)
72 | f.seek(0)
73 | json.dump(updated_knowledge_base, f, indent=2)
74 | f.truncate()
75 |
76 | def replace_memory(self,conversation:list[BaseMessage]):
77 | system_prompt=read_markdown_file('src/memory/episodic/prompt/replace.md')
78 | text_conversation=self.conversation_to_text(conversation)
79 | user_prompt=f'### Conversation:\n{text_conversation}'
80 | messages=[SystemMessage(system_prompt),HumanMessage(user_prompt)]
81 | memory:Memory=self.llm.invoke(messages,model=Memory)
82 | if self.verbose:
83 | print(f'{colored(f'Replacing memory from Knowledge Base:',color='yellow',attrs=['bold'])}\n{json.dumps(memory.to_dict(),indent=2)}')
84 | with open(f'./memory_data/{self.knowledge_base}','r+') as f:
85 | knowledge_base = [Memory.model_validate(memory) for memory in json.loads(f)]
86 | memory_ids=[memory.id for memory in self.memories.all()]
87 | updated_knowledge_base=list(filter(lambda memory:memory.id not in memory_ids,knowledge_base))
88 | updated_knowledge_base.append(memory)
89 | f.seek(0)
90 | json.dump([memory.model_dump() for memory in updated_knowledge_base], f, indent=2)
91 | f.truncate()
92 |
93 | def retrieve(self, query: str)->list[dict]:
94 | memories=[memory for memory in self.memories]
95 | system_prompt=read_markdown_file('src/memory/episodic/prompt/retrieve.md')
96 | user_prompt=f'### Query: {query}\n Now, select the memories those are relevant to solve the query.'
97 | messages=[SystemMessage(system_prompt.format(memories=memories)),HumanMessage(user_prompt)]
98 | memories=self.llm.invoke(messages,model=Memories)
99 | self.memories=memories
100 | if self.verbose:
101 | print(f'{colored(f'Retrieved memories from Knowledge Base:',color='yellow',attrs=['bold'])}\n{json.dumps(self.memories.all(),indent=2)}')
102 | return self.memories
103 |
104 | def attach_memory(self,system_prompt:str)->str:
105 | episodic_prompt=read_markdown_file('src/memory/episodic/prompt/memory.md')
106 | memory_prompt=episodic_prompt.format(memories=self.memories.to_string())
107 | return f'{system_prompt}\n\n{memory_prompt}'
108 |
109 |
--------------------------------------------------------------------------------
/src/agent/web/prompt/system.md:
--------------------------------------------------------------------------------
1 | # 🕸️Web Navigator
2 |
3 | You are Web Navigator designed by CursorTouch to solve the web related queries given by the USER in the .
4 |
5 | The current date is {current_datetime}
6 |
7 | Web Navigator can perform deep research. For the tasks that requires more contextual information perform research on that area of the topic. This can be performed in the intermediate stages or in the beginning itself and continue solving the task.
8 |
Web Navigator can go both in-depth and in-breadth on any given topic by looking through different sources, articles, blogs, ...etc. and this is an inherent feature of deep research.
10 |
11 | Web Navigator enjoys helping the user to achieve the .
12 |
13 | Additional Instructions:
14 |
15 | {instructions}
16 |
17 | Available Tools:
18 |
19 | {tools_prompt}
20 |
21 | IMPORTANT: Only use tools that is available. Never hallucinate using tools.
22 |
23 | ## System Information:
24 |
25 | - **Operating System:** {os}
26 | - **Browser:** {browser}
27 | - **Home Directory:** {home_dir}
28 | - **Downloads Folder:** {downloads_dir}
29 |
30 | At every step, Web Agent will be given the state:
31 |
32 | ```xml
33 |
34 |
35 | Current Step: How many steps over
36 | Max. Steps: Max. steps allowed with in which, solve the task
Action Response: Result of executing the previous action
38 |
39 |
40 | [Begin of Tab Info]
41 | Current Tab: The info related to current tab agent is working on.
42 | Open Tabs: The info related to other tabs those are open in the browser.
43 | [End of Tab Info]
44 |
45 | [Begin of Viewport]
46 | List of Interactive Elements: the interactable elements on the current tab like buttons,links and more.
47 | List of Scrollable Elements: these elements enable the agent to scroll on specific sections of the webpage.
48 | List of Informative Elements: these elements provide the text in the webpage.
49 | [End of Viewport]
50 |
51 |
52 | The ultimate goal for Web Navigator given by the user, use it to track progress.
53 |
54 |
55 | ```
56 |
57 | Web Navigator must follow the following rules while browsing the web:
58 |
1. ALWAYS start solving the given query using the appropriate search domains like google, youtube, wikipedia, twitter ...etc.
2. When performing deep research make sure to conduct it in a separate tab using `Tab Tool` and not on the current working tab.
61 | 3. If any banners or ads those are obstructing the way close it and accept cookies if you see in the page.
62 | 4. If a captcha appears, attempt solving it if possible or else use fallback strategies (ex: go back, alternative site).
63 | 5. You can scroll through specific sections of the webpage if there are Scrollable Elements to get relevant content from those sections.
64 | 6. Develop search queries that are clear and optimistic to the .
65 | 7. To scrape the entire webpage use the `Scrape Tool`. It would include all the text and links present in the page.
66 |
67 | Web Navigator must follow the following rules for better reasoning and planning in :
68 |
69 | 1. Use the recent steps to track the progress and context towards .
70 | 2. Incorporate , , , screenshot (if available) in your reasoning process and explain what you want to achieve next from based on the current state.
71 | 3. You can create plan in this stage to clearly define your objectives to achieve and even self-reflect to correct yourself from mistakes.
4. Analyze whether you are stuck at the same goal for a few steps. If so, try alternative methods.
5. When you are ready to finish, state that you are preparing to answer the user by gathering the findings you got and then use the `Done Tool`.
74 | 6. Explicitly judge the effectiveness of the previous action and keep it in .
75 | 7. Valuable information gained so far will be present in use it as needed. Use this information to connect the dots to gain new insights.
76 |
77 | Web Navigator must follow the following rules during the agentic loop:
78 |
79 | 1. Start by `GoTo Tool` going to the current search domain.
80 | 2. Use `Done Tool` when you have performed/completed the ultimate task, this include sufficient knowledge gained from browsing the internet. This tool provides you an opportunity to terminate and share your findings with the user.
81 | 3. The contains elements within the viewport only are listed. Use `Scroll Tool` if you suspect relevant content is offscreen which you want to interact with. Scroll ONLY if there is more content above or below the webpage.
82 | 4. When browsing especially in search engines keep an eye on the auto suggestions that pops up under the input field.
83 | 5. If the page isn't fully loaded, use `Wait Tool` to wait and if any changes are not seen in the webpage after performing an action then wait.
84 | 6. For clicking only use `Click Tool` and for clicking and typing use `Type Tool`.
85 | 7. When you respond provide thorough, well-detailed explanations of all findings and also mention the sources you referred based on the .
86 | 8. When clicking on certain links using `Click Tool` then sometimes the site opens in a new tab then the shall be w.r.t to this tab.
9. Don't get stuck in loops while solving the given task. Each step is an attempt to reach the goal.
88 | 10. If the query includes specific information like location, size, price then efficiently apply those filter while in the webpage.
89 | 11. NEVER close the last tab on the browser (the browser will close automatically).
90 | 12. You can ask the user for clarification or more data to continue using `Human Tool`.
91 | 13. The contains the information gained from the internet and essential context this included the data from such as credentials.
92 | 14. Remember to complete the task within `{max_iteration} steps` and ALWAYS output 1 reasonable action per step.
93 |
94 | Web Navigator must follow the following rules for :
95 |
96 | 1. ALWAYS remember solving the is the ultimate agenda.
2. Analyze the query, understand its complexity and break it into atomic subtasks.
3. If the task contains explicit steps or instructions, follow those with high priority.
4. Always look for the latest information for the unless explicitly specified.
100 | 5. You can do deep research to understand more on the topic to gain more insight for .
101 | 6. If additional instructions are given pay a good attention to that and act accordingly.
7. Give utmost importance to the user preference.
103 |
104 | Web Navigator must follow the following communication guidelines:
105 |
106 | 1. Maintain professional yet conversational tone.
107 | 2. The response highlight indepth findings and explained in detail.
108 | 3. Highlight key insights for the .
109 | 4. Format the responses in clean markdown format.
110 | 5. Only give verified information to the USER.
111 |
112 | ALWAYS respond exclusively in the following XML format:
113 |
114 | ```xml
115 |
122 | ```
123 |
124 | Begin!!!
125 |
--------------------------------------------------------------------------------
/src/inference/anthropic.py:
--------------------------------------------------------------------------------
1 | from src.message import AIMessage,BaseMessage,SystemMessage,ImageMessage,HumanMessage,ToolMessage
2 | from tenacity import retry,stop_after_attempt,retry_if_exception_type
3 | from requests import RequestException,HTTPError,ConnectionError
4 | from ratelimit import limits,sleep_and_retry
5 | from httpx import Client,AsyncClient
6 | from src.inference import BaseInference,Token
7 | from pydantic import BaseModel
8 | from typing import Generator
9 | from typing import Literal
10 | from pathlib import Path
11 | from json import loads
12 | from uuid import uuid4
13 | import mimetypes
14 | import requests
15 |
class ChatAnthropic(BaseInference):
    """Inference backend for the Anthropic Messages API.

    Implements the ``BaseInference`` contract: ``invoke``/``async_invoke``
    send a conversation and return an ``AIMessage`` (plain text or parsed
    JSON), a ``ToolMessage`` (when the model calls a tool), or a validated
    pydantic ``model`` instance when structured output is requested.
    """

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    def invoke(self, messages: list[BaseMessage],json:bool=False,model:BaseModel=None)->AIMessage|ToolMessage|BaseModel:
        """Synchronously call the Messages API.

        Args:
            messages: conversation so far (system/human/ai/image messages).
            json: when True, parse the reply text as JSON.
            model: optional pydantic model for structured output; takes
                precedence over ``json``.

        Returns:
            AIMessage, ToolMessage or a validated ``model`` instance.
        """
        self.headers.update({
            'x-api-key': self.api_key,
            "anthropic-version": "2023-06-01",
        })
        headers=self.headers
        temperature=self.temperature
        url=self.base_url or "https://api.anthropic.com/v1/messages"
        contents=[]
        system_instruct=None
        for message in messages:
            if isinstance(message,(HumanMessage,AIMessage)):
                contents.append(message.to_dict())
            elif isinstance(message,ImageMessage):
                text,image=message.content
                # Fix: append the message dict itself; the original wrapped it
                # in an extra list, producing an invalid entry in `messages`.
                contents.append({
                    'role':'user',
                    'content':[
                        {
                            'type':'text',
                            'text':text
                        },
                        {
                            'type':'image',
                            'source':{
                                'type':'base64',
                                'media_type':'image/png',
                                'data':image
                            }
                        }
                    ]
                })
            elif isinstance(message,SystemMessage):
                # Anthropic takes the system prompt as a top-level field,
                # not as a message; see `payload['system']` below.
                system_instruct=self.structured(message,model) if model else message.content
            else:
                raise Exception("Invalid Message")

        payload={
            "model": self.model,
            "messages": contents,
            "temperature": temperature,
            # NOTE(review): the Messages API defines no `response_format`
            # parameter and normally requires `max_tokens` — confirm against
            # the Anthropic API reference before relying on this payload.
            "response_format": {
                "type": "json_object" if json or model else "text"
            },
            "stream":False,
        }
        if self.tools:
            # Fix: Anthropic expects flat tool entries
            # ({name, description, input_schema}), not the OpenAI-style
            # {'type':'function','function':{...}} nesting.
            payload["tools"]=[{
                'name':tool.name,
                'description':tool.description,
                'input_schema':tool.schema
            } for tool in self.tools]
        if system_instruct:
            payload['system']=system_instruct
        try:
            with Client() as client:
                response=client.post(url=url,json=payload,headers=headers,timeout=None)
            json_object=response.json()
            if json_object.get('error'):
                raise HTTPError(json_object['error']['message'])
            message=json_object['content'][0]
            usage_metadata=json_object['usage']
            # Fix: the original unpacked three names from two values
            # (`input,output,total = a,b`), raising ValueError on every call.
            input,output=usage_metadata['input_tokens'],usage_metadata['output_tokens']
            total=input+output
            self.tokens=Token(input=input,output=output,total=total)
            if model:
                return model.model_validate_json(message.get('text'))
            if json:
                return AIMessage(loads(message.get('text')))
            # Fix: content blocks carry 'type'/'text' keys — 'content' never
            # exists here, so the original check always chose the tool branch
            # and crashed on plain text replies.
            if message.get('type')=='text':
                return AIMessage(message.get('text'))
            else:
                tool_call=message
                return ToolMessage(id=tool_call['id'] or str(uuid4()),name=tool_call['name'],args=tool_call['input'])
        except HTTPError as err:
            # Fix: the HTTPError raised above carries no `response` attribute;
            # guard before dereferencing (the original raised AttributeError).
            if getattr(err,'response',None) is not None:
                err_object=loads(err.response.text)
                print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}')
            else:
                print(f'\nError: {err}')
        except ConnectionError as err:
            print(err)
            exit()

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    async def async_invoke(self, messages: list[BaseMessage], json: bool = False, model: BaseModel = None) -> AIMessage | ToolMessage | BaseModel:
        """Asynchronous counterpart of :meth:`invoke` (same contract)."""
        self.headers.update({
            'x-api-key': self.api_key,
            "anthropic-version": "2023-06-01",
        })
        headers = self.headers
        temperature = self.temperature
        url = self.base_url or "https://api.anthropic.com/v1/messages"
        contents = []
        system_instruct = None

        for message in messages:
            if isinstance(message, (HumanMessage, AIMessage)):
                contents.append(message.to_dict())
            elif isinstance(message, ImageMessage):
                text, image = message.content
                # Fix: append the message dict itself (no extra list wrapper).
                contents.append({
                    'role': 'user',
                    'content': [
                        {
                            'type': 'text',
                            'text': text
                        },
                        {
                            'type': 'image',
                            'source': {
                                'type': 'base64',
                                'media_type': 'image/png',
                                'data': image
                            }
                        }
                    ]
                })
            elif isinstance(message, SystemMessage):
                system_instruct = self.structured(message, model) if model else message.content
            else:
                raise Exception("Invalid Message")

        payload = {
            "model": self.model,
            "messages": contents,
            "temperature": temperature,
            # NOTE(review): see invoke() — `response_format` is not an
            # Anthropic parameter and `max_tokens` is normally required.
            "response_format": {
                "type": "json_object" if json or model else "text"
            },
            "stream": False,
        }
        if self.tools:
            # Fix: flat Anthropic tool schema (see invoke()).
            payload["tools"] = [{
                'name': tool.name,
                'description': tool.description,
                'input_schema': tool.schema
            } for tool in self.tools]
        if system_instruct:
            payload['system'] = system_instruct

        try:
            async with AsyncClient() as client:
                response = await client.post(url, json=payload, headers=headers)
                response.raise_for_status()
                json_object = response.json()
                if json_object.get('error'):
                    raise HTTPError(json_object['error']['message'])
                message = json_object['content'][0]
                usage_metadata = json_object['usage']
                input, output = usage_metadata['input_tokens'], usage_metadata['output_tokens']
                total = input + output
                self.tokens = Token(input=input, output=output, total=total)
                if model:
                    return model.model_validate_json(message.get('text'))
                if json:
                    return AIMessage(loads(message.get('text')))
                # Fix: detect text blocks via their 'type' key (see invoke()).
                if message.get('type') == 'text':
                    return AIMessage(message.get('text'))
                else:
                    tool_call = message
                    return ToolMessage(id=tool_call['id'] or str(uuid4()), name=tool_call['name'], args=tool_call['input'])
        except HTTPError as err:
            # Fix: guard against a manually-raised HTTPError without .response.
            if getattr(err, 'response', None) is not None:
                err_object = loads(err.response.text)
                print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}')
            else:
                print(f'\nError: {err}')
        except Exception as err:
            print(err)

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    def stream(self, messages: list[BaseMessage],json=False)->Generator[str,None,None]:
        # Streaming is not implemented for this backend yet.
        pass

    def available_models(self):
        """Return the ids of the models available to this API key."""
        # Fix: the original queried the Groq endpoint with Bearer auth;
        # Anthropic lists models at /v1/models with x-api-key auth.
        url='https://api.anthropic.com/v1/models'
        self.headers.update({
            'x-api-key': self.api_key,
            'anthropic-version': '2023-06-01',
        })
        headers=self.headers
        response=requests.get(url=url,headers=headers)
        response.raise_for_status()
        models=response.json()
        # Anthropic's listing carries no 'active' flag; return every id.
        return [model['id'] for model in models['data']]
--------------------------------------------------------------------------------
/src/inference/nvidia.py:
--------------------------------------------------------------------------------
1 | from src.message import AIMessage,BaseMessage,SystemMessage,ImageMessage,HumanMessage,ToolMessage
2 | from tenacity import retry,stop_after_attempt,retry_if_exception_type
3 | from requests import RequestException,HTTPError,ConnectionError
4 | from ratelimit import limits,sleep_and_retry
5 | from httpx import Client,AsyncClient
6 | from src.inference import BaseInference,Token
7 | from pydantic import BaseModel
8 | from typing import Generator
9 | from json import loads
10 | from uuid import uuid4
11 | import requests
12 |
class ChatNvidia(BaseInference):
    """Inference backend for the NVIDIA NIM (OpenAI-compatible) chat API."""

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    def invoke(self, messages: list[BaseMessage],json:bool=False,model:BaseModel=None)->AIMessage|ToolMessage|BaseModel:
        """Synchronously call the chat-completions endpoint.

        Args:
            messages: conversation so far (system/human/ai/image messages).
            json: when True, parse the reply content as JSON.
            model: optional pydantic model for structured output; takes
                precedence over ``json``.

        Returns:
            AIMessage, ToolMessage or a validated ``model`` instance.
        """
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers=self.headers
        temperature=self.temperature
        url=self.base_url or "https://integrate.api.nvidia.com/v1/chat/completions"
        contents=[]
        for message in messages:
            if isinstance(message,SystemMessage):
                if model:
                    # Inject the structured-output schema into the system prompt.
                    message.content=self.structured(message,model)
                contents.append(message.to_dict())
            elif isinstance(message,(HumanMessage,AIMessage)):
                contents.append(message.to_dict())
            elif isinstance(message,ImageMessage):
                text,image=message.content
                # Fix: append the message dict itself; the original wrapped it
                # in an extra list, producing an invalid `messages` entry.
                contents.append({
                    'role':'user',
                    'content':[
                        {
                            'type':'text',
                            'text':text
                        },
                        {
                            'type':'image_url',
                            'image_url':{
                                'url':image
                            }
                        }
                    ]
                })

        payload={
            "model": self.model,
            "messages": contents,
            "temperature": temperature,
            "response_format": {
                "type": "json_object" if json or model else "text"
            },
            "stream":False,
        }
        if self.tools:
            payload["tools"]=[{
                'type':'function',
                'function':{
                    'name':tool.name,
                    'description':tool.description,
                    'parameters':tool.schema
                }
            } for tool in self.tools]
        try:
            with Client() as client:
                response=client.post(url=url,json=payload,headers=headers,timeout=None)
            json_object=response.json()
            if json_object.get('error'):
                raise HTTPError(json_object['error']['message'])
            message=json_object['choices'][0]['message']
            usage_metadata=json_object['usage']
            input,output,total=usage_metadata['prompt_tokens'],usage_metadata['completion_tokens'],usage_metadata['total_tokens']
            self.tokens=Token(input=input,output=output,total=total)
            if model:
                return model.model_validate_json(message.get('content'))
            if json:
                return AIMessage(loads(message.get('content')))
            if message.get('content'):
                return AIMessage(message.get('content'))
            else:
                tool_call=message.get('tool_calls')[0]['function']
                return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['arguments'])
        except HTTPError as err:
            # Fix: the HTTPError raised above carries no `response` attribute;
            # guard before dereferencing (the original raised AttributeError).
            if getattr(err,'response',None) is not None:
                err_object=loads(err.response.text)
                print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}')
            else:
                print(f'\nError: {err}')
        except ConnectionError as err:
            print(err)
            exit()

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    async def async_invoke(self, messages: list[BaseMessage],json=False,model:BaseModel=None) -> AIMessage|ToolMessage|BaseModel:
        """Asynchronous counterpart of :meth:`invoke` (same contract)."""
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers=self.headers
        temperature=self.temperature
        url=self.base_url or "https://integrate.api.nvidia.com/v1/chat/completions"
        contents=[]
        for message in messages:
            if isinstance(message,SystemMessage):
                if model:
                    message.content=self.structured(message,model)
                contents.append(message.to_dict())
            elif isinstance(message,(HumanMessage,AIMessage)):
                contents.append(message.to_dict())
            elif isinstance(message,ImageMessage):
                text,image=message.content
                # Fix: append the message dict itself (no extra list wrapper).
                contents.append({
                    'role':'user',
                    'content':[
                        {
                            'type':'text',
                            'text':text
                        },
                        {
                            'type':'image_url',
                            'image_url':{
                                'url':image
                            }
                        }
                    ]
                })

        payload={
            "model": self.model,
            "messages": contents,
            "temperature": temperature,
            "response_format": {
                "type": "json_object" if json or model else "text"
            },
            "stream":False,
        }
        if self.tools:
            payload["tools"]=[{
                'type':'function',
                'function':{
                    'name':tool.name,
                    'description':tool.description,
                    'parameters':tool.schema
                }
            } for tool in self.tools]
        try:
            async with AsyncClient() as client:
                response=await client.post(url=url,json=payload,headers=headers,timeout=None)
            json_object=response.json()
            if json_object.get('error'):
                raise HTTPError(json_object['error']['message'])
            message=json_object['choices'][0]['message']
            usage_metadata=json_object['usage']
            input,output,total=usage_metadata['prompt_tokens'],usage_metadata['completion_tokens'],usage_metadata['total_tokens']
            self.tokens=Token(input=input,output=output,total=total)
            if model:
                return model.model_validate_json(message.get('content'))
            if json:
                return AIMessage(loads(message.get('content')))
            if message.get('content'):
                return AIMessage(message.get('content'))
            else:
                tool_call=message.get('tool_calls')[0]['function']
                return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['arguments'])
        except HTTPError as err:
            # Fix: guard against a manually-raised HTTPError without .response.
            if getattr(err,'response',None) is not None:
                err_object=loads(err.response.text)
                print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}')
            else:
                print(f'\nError: {err}')
        except ConnectionError as err:
            print(err)
            exit()

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    def stream(self, messages: list[BaseMessage],json=False)->Generator[str,None,None]:
        """Yield the reply incrementally as server-sent-event text chunks."""
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers=self.headers
        temperature=self.temperature
        url=self.base_url or "https://integrate.api.nvidia.com/v1/chat/completions"
        messages=[message.to_dict() for message in messages]
        payload={
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "response_format": {
                "type": "json_object" if json else "text"
            },
            "stream":True,
        }
        try:
            response=requests.post(url=url,json=payload,headers=headers)
            response.raise_for_status()
            chunks=response.iter_lines(decode_unicode=True)
            for chunk in chunks:
                # SSE lines are prefixed "data: "; the stream ends with [DONE].
                chunk=chunk.replace('data: ','')
                if chunk and chunk!='[DONE]':
                    delta=loads(chunk)['choices'][0]['delta']
                    yield delta.get('content','')
        except HTTPError as err:
            err_object=loads(err.response.text)
            print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}')
        except ConnectionError as err:
            print(err)
            exit()
--------------------------------------------------------------------------------
/src/inference/open_router.py:
--------------------------------------------------------------------------------
1 | from src.message import AIMessage,BaseMessage,HumanMessage,ImageMessage,SystemMessage,ToolMessage
2 | from requests import get,RequestException,HTTPError,ConnectionError
3 | from tenacity import retry,stop_after_attempt,retry_if_exception_type
4 | from ratelimit import limits,sleep_and_retry
5 | from src.inference import BaseInference,Token
6 | from httpx import Client,AsyncClient
7 | from pydantic import BaseModel
8 | from typing import Literal
9 | from json import loads
10 | from uuid import uuid4
11 |
class ChatOpenRouter(BaseInference):
    """Inference backend for the OpenRouter (OpenAI-compatible) chat API.

    ``invoke``/``async_invoke`` return an AIMessage (plain text or parsed
    JSON), a ToolMessage when the model requests a tool call, or a
    validated pydantic ``model`` instance when structured output is
    requested. Errors are reported and re-raised so the tenacity retry
    decorator can retry transient failures.
    """

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    def invoke(self, messages: list[BaseMessage],json=False,model:BaseModel|None=None) -> AIMessage|ToolMessage|BaseModel:
        """Synchronously call the chat-completions endpoint.

        Args:
            messages: conversation so far (system/human/ai/image messages).
            json: when True, parse the reply content as JSON.
            model: optional pydantic model for structured output; takes
                precedence over ``json``.
        """
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers=self.headers
        temperature=self.temperature
        url=self.base_url or "https://openrouter.ai/api/v1/chat/completions"
        contents=[]
        for message in messages:
            if isinstance(message,SystemMessage):
                if model:
                    # Inject the structured-output schema into the system prompt.
                    message.content=self.structured(message,model)
                contents.append(message.to_dict())
            elif isinstance(message,(HumanMessage,AIMessage)):
                contents.append(message.to_dict())
            elif isinstance(message,ImageMessage):
                text,image=message.content
                # Fix: Don't wrap in an extra list
                contents.append({
                    'role':'user',
                    'content':[
                        {
                            'type':'text',
                            'text':text
                        },
                        {
                            'type':'image_url',
                            'image_url':{
                                'url':image
                            }
                        }
                    ]
                })

        payload={
            "model": self.model,
            "messages": contents,
            "temperature": temperature,
            "response_format": {
                "type": "json_object" if json or model else "text"
            },
            "stream":False,
        }
        if self.tools:
            # OpenAI-style tool declarations.
            payload["tools"]=[{
                'type':'function',
                'function':{
                    'name':tool.name,
                    'description':tool.description,
                    'parameters':tool.schema
                }
            } for tool in self.tools]
        try:
            with Client() as client:
                response=client.post(url=url,json=payload,headers=headers,timeout=None)

            # Check HTTP status first
            # NOTE(review): httpx raises its own HTTPStatusError here, which
            # the `except HTTPError` (requests') below may not catch — such
            # errors fall through to the generic Exception handler. Confirm.
            response.raise_for_status()

            json_object=response.json()
            # print(json_object)
            if json_object.get('error'):
                raise HTTPError(json_object['error']['message'])
            message=json_object['choices'][0]['message']
            usage_metadata=json_object['usage']
            input,output,total=usage_metadata['prompt_tokens'],usage_metadata['completion_tokens'],usage_metadata['total_tokens']
            self.tokens=Token(input=input,output=output,total=total)
            if model:
                return model.model_validate_json(message.get('content'))
            if json:
                return AIMessage(loads(message.get('content')))
            if message.get('content'):
                return AIMessage(message.get('content'))
            else:
                # No text content means the model requested a tool call.
                tool_call=message.get('tool_calls')[0]['function']
                return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['arguments'])
        except HTTPError as err:
            # Fix: Proper error handling for HTTPError
            try:
                if hasattr(err, 'response') and err.response is not None:
                    err_object = err.response.json()
                    error_msg = err_object.get("error", {}).get("message", "Unknown API error")
                    status_code = err.response.status_code
                    print(f'\nError: {error_msg}\nStatus Code: {status_code}')
                else:
                    print(f'\nHTTP Error: {str(err)}')
            except Exception as parse_err:
                print(f'\nHTTP Error: {str(err)} (Could not parse error response: {parse_err})')
            raise err  # Re-raise instead of exit()
        except ConnectionError as err:
            print(f'\nConnection Error: {err}')
            raise err  # Re-raise instead of exit()
        except Exception as err:
            print(f'\nUnexpected Error: {err}')
            raise err  # Re-raise instead of exit()

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    async def async_invoke(self, messages: list[BaseMessage],json=False,model:BaseModel=None) -> AIMessage|ToolMessage|BaseModel:
        """Asynchronous counterpart of :meth:`invoke` (same contract)."""
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers=self.headers
        temperature=self.temperature
        # Fix: Use OpenRouter URL instead of Groq URL
        url=self.base_url or "https://openrouter.ai/api/v1/chat/completions"
        contents=[]
        for message in messages:
            if isinstance(message,SystemMessage):
                if model:
                    # Inject the structured-output schema into the system prompt.
                    message.content=self.structured(message,model)
                contents.append(message.to_dict())
            elif isinstance(message,(HumanMessage,AIMessage)):
                contents.append(message.to_dict())
            elif isinstance(message,ImageMessage):
                text,image=message.content
                # Fix: Don't wrap in an extra list
                contents.append({
                    'role':'user',
                    'content':[
                        {
                            'type':'text',
                            'text':text
                        },
                        {
                            'type':'image_url',
                            'image_url':{
                                'url':image
                            }
                        }
                    ]
                })

        payload={
            "model": self.model,
            "messages": contents,
            "temperature": temperature,
            "response_format": {
                "type": "json_object" if json or model else "text"
            },
            "stream":False,
        }
        if self.tools:
            payload["tools"]=[{
                'type':'function',
                'function':{
                    'name':tool.name,
                    'description':tool.description,
                    'parameters':tool.schema
                }
            } for tool in self.tools]
        try:
            async with AsyncClient() as client:
                response=await client.post(url=url,json=payload,headers=headers,timeout=None)

                # Check HTTP status first
                response.raise_for_status()

                json_object=response.json()
                # print(json_object)
                if json_object.get('error'):
                    raise HTTPError(json_object['error']['message'])
                message=json_object['choices'][0]['message']
                usage_metadata=json_object['usage']
                input,output,total=usage_metadata['prompt_tokens'],usage_metadata['completion_tokens'],usage_metadata['total_tokens']
                self.tokens=Token(input=input,output=output,total=total)
                if model:
                    return model.model_validate_json(message.get('content'))
                if json:
                    return AIMessage(loads(message.get('content')))
                if message.get('content'):
                    return AIMessage(message.get('content'))
                else:
                    # No text content means the model requested a tool call.
                    tool_call=message.get('tool_calls')[0]['function']
                    return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['arguments'])
        except HTTPError as err:
            # Fix: Proper error handling for HTTPError
            try:
                if hasattr(err, 'response') and err.response is not None:
                    err_object = err.response.json()
                    error_msg = err_object.get("error", {}).get("message", "Unknown API error")
                    status_code = err.response.status_code
                    print(f'\nError: {error_msg}\nStatus Code: {status_code}')
                else:
                    print(f'\nHTTP Error: {str(err)}')
            except Exception as parse_err:
                print(f'\nHTTP Error: {str(err)} (Could not parse error response: {parse_err})')
            raise err  # Re-raise instead of exit()
        except ConnectionError as err:
            print(f'\nConnection Error: {err}')
            raise err  # Re-raise instead of exit()
        except Exception as err:
            print(f'\nUnexpected Error: {err}')
            raise err  # Re-raise instead of exit()

    def stream(self, messages, json = False):
        # Streaming is not implemented for this backend yet.
        pass
--------------------------------------------------------------------------------
/src/inference/mistral.py:
--------------------------------------------------------------------------------
1 | from src.message import AIMessage,BaseMessage,SystemMessage,ImageMessage,HumanMessage,ToolMessage
2 | from requests import RequestException,HTTPError,ConnectionError
3 | from tenacity import retry,stop_after_attempt,retry_if_exception_type
4 | from ratelimit import limits,sleep_and_retry
5 | from src.inference import BaseInference,Token
6 | from httpx import Client,AsyncClient
7 | from pydantic import BaseModel
8 | from typing import Generator
9 | from json import loads
10 | from uuid import uuid4
11 | import requests
12 |
class ChatMistral(BaseInference):
    """Inference backend for the Mistral AI (OpenAI-compatible) chat API."""

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    def invoke(self, messages: list[BaseMessage],json:bool=False,model:BaseModel=None)->AIMessage|ToolMessage|BaseModel:
        """Synchronously call the chat-completions endpoint.

        Args:
            messages: conversation so far (system/human/ai/image messages).
            json: when True, parse the reply content as JSON.
            model: optional pydantic model for structured output; takes
                precedence over ``json``.

        Returns:
            AIMessage, ToolMessage or a validated ``model`` instance.
        """
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers=self.headers
        temperature=self.temperature
        url=self.base_url or "https://api.mistral.ai/v1/chat/completions"
        contents=[]
        for message in messages:
            if isinstance(message,SystemMessage):
                if model:
                    # Inject the structured-output schema into the system prompt.
                    message.content=self.structured(message,model)
                contents.append(message.to_dict())
            elif isinstance(message,(HumanMessage,AIMessage)):
                contents.append(message.to_dict())
            elif isinstance(message,ImageMessage):
                text,image_data=message.content
                # Fix: the original built 'content' as a set literal of dicts
                # (dicts are unhashable -> TypeError at runtime) and wrapped
                # the message in an extra list; content must be a plain list
                # of part dicts appended directly.
                contents.append({
                    'role':'user',
                    'content':[
                        {
                            'type':'text',
                            'text':text
                        },
                        {
                            'type':'image_url',
                            'image_url':image_data
                        }
                    ]
                })

        payload={
            "model": self.model,
            "messages": contents,
            "temperature": temperature,
            "response_format": {
                "type": "json_object" if json or model else "text"
            },
            "stream":False,
        }
        if self.tools:
            payload["tools"]=[{
                'type':'function',
                'function':{
                    'name':tool.name,
                    'description':tool.description,
                    'parameters':tool.schema
                }
            } for tool in self.tools]
        try:
            with Client() as client:
                response=client.post(url=url,json=payload,headers=headers,timeout=None)
            json_object=response.json()
            if json_object.get('error'):
                # Fix: raise HTTPError (as the sibling backends do) so the
                # handler below reports it; the original raised a bare
                # Exception that escaped both except clauses.
                raise HTTPError(json_object['error']['message'])
            message=json_object['choices'][0]['message']
            usage_metadata=json_object['usage']
            input,output,total=usage_metadata['prompt_tokens'],usage_metadata['completion_tokens'],usage_metadata['total_tokens']
            self.tokens=Token(input=input,output=output,total=total)
            if model:
                return model.model_validate_json(message.get('content'))
            if json:
                return AIMessage(loads(message.get('content')))
            if message.get('content'):
                return AIMessage(message.get('content'))
            else:
                tool_call=message.get('tool_calls')[0]['function']
                return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['arguments'])
        except HTTPError as err:
            # Guard: a manually-raised HTTPError carries no `response`.
            if getattr(err,'response',None) is not None:
                err_object=loads(err.response.text)
                print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}')
            else:
                print(f'\nError: {err}')
        except ConnectionError as err:
            print(err)
            exit()

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    async def async_invoke(self, messages: list[BaseMessage],json:bool=False,model:BaseModel=None)->AIMessage|ToolMessage|BaseModel:
        """Asynchronous counterpart of :meth:`invoke` (same contract)."""
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers=self.headers
        temperature=self.temperature
        url=self.base_url or "https://api.mistral.ai/v1/chat/completions"
        contents=[]
        for message in messages:
            if isinstance(message,SystemMessage):
                if model:
                    message.content=self.structured(message,model)
                contents.append(message.to_dict())
            elif isinstance(message,(HumanMessage,AIMessage)):
                contents.append(message.to_dict())
            elif isinstance(message,ImageMessage):
                text,image_data=message.content
                # Fix: plain list of part dicts, appended directly (see invoke()).
                contents.append({
                    'role':'user',
                    'content':[
                        {
                            'type':'text',
                            'text':text
                        },
                        {
                            'type':'image_url',
                            'image_url':image_data
                        }
                    ]
                })

        payload={
            "model": self.model,
            "messages": contents,
            "temperature": temperature,
            "response_format": {
                "type": "json_object" if json or model else "text"
            },
            "stream":False,
        }
        if self.tools:
            payload["tools"]=[{
                'type':'function',
                'function':{
                    'name':tool.name,
                    'description':tool.description,
                    'parameters':tool.schema
                }
            } for tool in self.tools]
        try:
            async with AsyncClient() as client:
                response=await client.post(url=url,json=payload,headers=headers,timeout=None)
            json_object=response.json()
            if json_object.get('error'):
                # Fix: raise HTTPError so the handler below reports it.
                raise HTTPError(json_object['error']['message'])
            message=json_object['choices'][0]['message']
            usage_metadata=json_object['usage']
            input,output,total=usage_metadata['prompt_tokens'],usage_metadata['completion_tokens'],usage_metadata['total_tokens']
            self.tokens=Token(input=input,output=output,total=total)
            if model:
                return model.model_validate_json(message.get('content'))
            if json:
                return AIMessage(loads(message.get('content')))
            if message.get('content'):
                return AIMessage(message.get('content'))
            else:
                tool_call=message.get('tool_calls')[0]['function']
                return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['arguments'])
        except HTTPError as err:
            # Guard: a manually-raised HTTPError carries no `response`.
            if getattr(err,'response',None) is not None:
                err_object=loads(err.response.text)
                print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}')
            else:
                print(f'\nError: {err}')
        except ConnectionError as err:
            print(err)
            exit()

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    def stream(self, messages: list[BaseMessage],json=False)->Generator[str,None,None]:
        """Yield the reply incrementally as server-sent-event text chunks."""
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers=self.headers
        temperature=self.temperature
        # Fix: the original streamed against the Groq endpoint; this class
        # targets Mistral's chat-completions API.
        url=self.base_url or "https://api.mistral.ai/v1/chat/completions"
        messages=[message.to_dict() for message in messages]
        payload={
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "stream":True,
        }
        if json:
            payload["response_format"]={
                "type": "json_object"
            }
        try:
            response=requests.post(url=url,json=payload,headers=headers,timeout=None)
            response.raise_for_status()
            chunks=response.iter_lines(decode_unicode=True)
            for chunk in chunks:
                # SSE lines are prefixed "data: "; the stream ends with [DONE].
                chunk=chunk.replace('data: ','')
                if chunk and chunk!='[DONE]':
                    delta=loads(chunk)['choices'][0]['delta']
                    yield delta.get('content','')
        except HTTPError as err:
            err_object=loads(err.response.text)
            print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}')
        except ConnectionError as err:
            print(err)
            exit()

    def available_models(self):
        """Return the ids of all models listed by the Mistral API."""
        url="https://api.mistral.ai/v1/models"
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers=self.headers
        response=requests.get(url=url,headers=headers)
        response.raise_for_status()
        models=response.json()
        return [model['id'] for model in models['data']]
--------------------------------------------------------------------------------
/src/agent/web/context/__init__.py:
--------------------------------------------------------------------------------
1 | from playwright.async_api import Page,Browser as PlaywrightBrowser,Frame,ElementHandle,BrowserContext as PlaywrightContext
2 | from src.agent.web.context.config import IGNORED_URL_PATTERNS,RELEVANT_FILE_EXTENSIONS,RELEVANT_CONTEXT_TYPES
3 | from src.agent.web.browser.config import BROWSER_ARGS,SECURITY_ARGS,IGNORE_DEFAULT_ARGS
4 | from src.agent.web.context.views import BrowserSession,BrowserState,Tab
5 | from src.agent.web.context.config import ContextConfig
6 | from src.agent.web.dom.views import DOMElementNode
7 | from src.agent.web.browser import Browser
8 | from src.agent.web.dom import DOM
9 | from urllib.parse import urlparse
10 | from datetime import datetime
11 | from pathlib import Path
12 | from uuid import uuid4
13 | from os import getcwd
14 |
15 | class Context:
    def __init__(self,browser:Browser,config:ContextConfig=ContextConfig()):
        """Bind this context to a Browser; the Playwright session is created lazily."""
        # NOTE(review): the default ContextConfig() is evaluated once at import
        # time and shared by every instance — confirm ContextConfig is immutable.
        self.browser=browser
        self.config=config
        self.context_id=str(uuid4())  # unique identifier for this context
        self.session:BrowserSession=None  # populated by init_session()
21 |
    async def __aenter__(self):
        """Async context-manager entry: eagerly create the browser session."""
        await self.init_session()
        return self
25 |
    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context-manager exit: tear down the session (exceptions propagate)."""
        await self.close_session()
28 |
29 | async def close_session(self):
30 | if self.session is None:
31 | return None
32 | try:
33 | await self.session.context.close()
34 | except Exception as e:
35 | print('Context failed to close',e)
36 | finally:
37 | self.browser_context=None
38 |
    async def init_session(self):
        """Create the browser context, ensure a page exists and record the session."""
        browser=await self.browser.get_playwright_browser()
        context=await self.setup_context(browser)
        if browser is not None: # Case: no user_data provided — a fresh context starts with no pages
            page=await context.new_page()
        else: # Case: persistent context (user_data provided) may already own pages
            pages=context.pages
            if len(pages):
                page=pages[0]  # reuse the first existing page
            else:
                page=await context.new_page()
        state=await self.initial_state(page)
        self.session=BrowserSession(context,page,state)
52 |
53 | async def initial_state(self,page:Page):
54 | screenshot,dom_state=None,[]
55 | current_tab=Tab(0,page.url,await page.title(),page)
56 | tabs=[]
57 | state=BrowserState(current_tab=current_tab,tabs=tabs,screenshot=screenshot,dom_state=dom_state)
58 | return state
59 |
60 | async def update_state(self,use_vision:bool=False):
61 | dom=DOM(self)
62 | screenshot,dom_state=await dom.get_state(use_vision=use_vision)
63 | tabs=await self.get_all_tabs()
64 | current_tab=await self.get_current_tab()
65 | state=BrowserState(current_tab=current_tab,tabs=tabs,screenshot=screenshot,dom_state=dom_state)
66 | return state
67 |
68 | async def get_state(self,use_vision=False)->BrowserState:
69 | session=await self.get_session()
70 | state=await self.update_state(use_vision=use_vision)
71 | session.state=state
72 | return session.state
73 |
74 | async def get_session(self)->BrowserSession:
75 | if self.session is None:
76 | await self.init_session()
77 | return self.session
78 |
79 | async def get_current_page(self)->Page:
80 | session=await self.get_session()
81 | if session.current_page is None:
82 | raise ValueError("No current page found")
83 | return session.current_page
84 |
85 | async def setup_context(self,browser:PlaywrightBrowser|None=None)->PlaywrightContext:
86 | if self.browser.config.device is not None:
87 | parameters={**self.browser.playwright.devices.get(self.browser.config.device)}
88 | parameters.pop('default_browser_type',None)
89 | else:
90 | parameters={
91 | 'ignore_https_errors':self.config.disable_security,
92 | 'user_agent':self.config.user_agent,
93 | 'bypass_csp':self.config.disable_security,
94 | 'java_script_enabled':True,
95 | 'accept_downloads':True,
96 | 'no_viewport':True
97 | }
98 | if browser is not None:
99 | context=await browser.new_context(**parameters)
100 | with open('./src/agent/web/context/script.js') as f:
101 | script=f.read()
102 | await context.add_init_script(script)
103 | else:
104 | args=['--no-sandbox','--disable-blink-features=AutomationControlled','--disable-blink-features=IdleDetection','--no-infobars']
105 | parameters=parameters|{
106 | 'headless':self.browser.config.headless,
107 | 'slow_mo':self.browser.config.slow_mo,
108 | 'ignore_default_args': IGNORE_DEFAULT_ARGS,
109 | 'args': args+SECURITY_ARGS,
110 | 'user_data_dir': self.browser.config.user_data_dir,
111 | 'downloads_path': self.browser.config.downloads_dir,
112 | 'executable_path': self.browser.config.browser_instance_dir,
113 | }
114 | # browser is None if the user_data_dir is not None in the Browser class
115 | browser=self.browser.config.browser
116 | if browser=='chrome':
117 | context=await self.browser.playwright.chromium.launch_persistent_context(channel='chrome',**parameters)
118 | elif browser=='firefox':
119 | context=await self.browser.playwright.firefox.launch_persistent_context(**parameters)
120 | elif browser=='edge':
121 | context=await self.browser.playwright.chromium.launch_persistent_context(channel='msedge',**parameters)
122 | else:
123 | raise Exception('Invalid Browser Type')
124 | return context
125 |
126 | async def get_all_tabs(self)->list[Tab]:
127 | session=await self.get_session()
128 | pages=session.context.pages
129 | tabs:list[Tab]=[]
130 | for id,page in enumerate(pages):
131 | await page.wait_for_load_state('domcontentloaded')
132 | try:
133 | url=page.url
134 | title=await page.title()
135 | except Exception as e:
136 | print(f'Tab failed to load: {e}')
137 | continue
138 | tabs.append(Tab(id=id,url=url,title=title,page=page))
139 | return tabs
140 |
141 | async def get_current_tab(self)->Tab:
142 | tabs=await self.get_all_tabs()
143 | current_page=await self.get_current_page()
144 | return next((tab for tab in tabs if tab.page==current_page),None)
145 |
146 | async def get_selector_map(self)->dict[int,DOMElementNode]:
147 | session=await self.get_session()
148 | return session.state.dom_state.selector_map
149 |
150 | async def get_element_by_index(self,index:int)->DOMElementNode:
151 | selector_map=await self.get_selector_map()
152 | if index not in selector_map.keys():
153 | raise Exception(f'Element under index {index} not found')
154 | element=selector_map.get(index)
155 | return element
156 |
157 | async def get_handle_by_xpath(self,xpath:dict[str,str])->ElementHandle:
158 | frame=await self.get_frame_by_xpath(xpath)
159 | _,element_xpath=xpath.values()
160 | element=await frame.locator(f'xpath={element_xpath}').element_handle()
161 | return element
162 |
163 | async def get_frame_by_xpath(self,xpath:dict[str,str])->Frame:
164 | page=await self.get_current_page()
165 | frame_xpath,_=xpath.values()
166 | if frame_xpath: # handle elements from iframe
167 | frame=page.frame_locator(f'xpath={frame_xpath}')
168 | else: # handle elements from main frame
169 | frame=page.main_frame
170 | return frame
171 |
172 | async def execute_script(self,obj:Frame|Page,script:str,args:list=None,enable_handle:bool=False):
173 | if enable_handle:
174 | handle=await obj.evaluate_handle(script,args)
175 | return handle.as_element()
176 | return await obj.evaluate(script,args)
177 |
178 | async def get_viewport(self)->tuple[int,int]:
179 | page=await self.get_current_page()
180 | viewport:dict=await self.execute_script(page,'({width: window.innerWidth, height: window.innerHeight})')
181 | return(viewport.get('width'),viewport.get('height'))
182 |
183 | def is_ad_url(self,url:str)->bool:
184 | url_pattern=urlparse(url).netloc
185 | if not url_pattern:
186 | return True
187 | return any(pattern in url_pattern for pattern in IGNORED_URL_PATTERNS)
188 |
189 | async def is_frame_visible(self,frame:Frame)->bool:
190 | if frame.is_detached() or self.is_ad_url(frame.url):
191 | return False
192 | frame_element=await frame.frame_element()
193 | if frame_element is None:
194 | return False
195 | style=await frame_element.get_attribute('style')
196 | if style is not None:
197 | css:dict=self.inline_style_parser(style)
198 | if any([css.get('display')=='none',css.get('visibility')=='hidden']):
199 | return False
200 | bbox=await frame_element.bounding_box()
201 | if bbox is None:
202 | return False
203 | area=bbox.get('width')*bbox.get('height')
204 | if any([bbox.get('x')<0,bbox.get('y')<0,area<10]):
205 | return False
206 | return True
207 |
208 | async def get_screenshot(self,save_screenshot:bool=False,full_page:bool=False):
209 | page=await self.get_current_page()
210 | if save_screenshot:
211 | date_time=datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
212 | folder_path=Path(getcwd()).joinpath('./screenshots')
213 | folder_path.mkdir(parents=True,exist_ok=True)
214 | path=folder_path.joinpath(f'screenshot_{date_time}.jpeg')
215 | else:
216 | path=None
217 | await page.wait_for_timeout(2*1000)
218 | screenshot=await page.screenshot(path=path,full_page=full_page,animations='disabled',type='jpeg')
219 | return screenshot
220 |
221 | def inline_style_parser(self,style:str)->dict[str,str]:
222 | styles = {}
223 | if not style:
224 | return styles
225 | for rule in style.split(";"):
226 | if ":" in rule:
227 | prop, val = rule.split(":", 1)
228 | styles[prop.strip()] = val.strip()
229 | return styles
--------------------------------------------------------------------------------
/src/inference/ollama.py:
--------------------------------------------------------------------------------
1 | from tenacity import retry,stop_after_attempt,retry_if_exception_type
2 | from src.message import AIMessage,BaseMessage,ToolMessage
3 | from requests import get,RequestException,ConnectionError
4 | from httpx import Client,AsyncClient,HTTPError
5 | from ratelimit import limits,sleep_and_retry
6 | from src.inference import BaseInference,Token
7 | from pydantic import BaseModel
8 | from typing import Generator
9 | from json import loads
10 | from uuid import uuid4
11 |
class ChatOllama(BaseInference):
    """Chat-style client for a local Ollama server (/api/chat)."""

    @sleep_and_retry
    @limits(calls=15, period=60)
    @retry(stop=stop_after_attempt(3), retry=retry_if_exception_type(RequestException))
    def invoke(self, messages: list[BaseMessage], json=False, model: BaseModel = None) -> AIMessage:
        """Send a chat request and return an AIMessage, a ToolMessage (tool
        call) or an instance of `model` when structured output was requested."""
        headers = self.headers
        temperature = self.temperature
        url = self.base_url or "http://localhost:11434/api/chat"
        payload = {
            "model": self.model,
            "messages": [message.to_dict() for message in messages],
            "options": {
                "temperature": temperature,
            },
            "stream": False
        }
        if json:
            payload['format'] = 'json'
        if model:
            # Structured output: Ollama accepts a JSON schema as `format`.
            payload['format'] = model.model_json_schema()
        if self.tools:
            payload["tools"] = [{
                'type': 'function',
                'function': {
                    'name': tool.name,
                    'description': tool.description,
                    'parameters': tool.schema
                }
            } for tool in self.tools]
        try:
            with Client() as client:
                response = client.post(url=url, json=payload, headers=headers, timeout=None)
                response.raise_for_status()
            json_object = response.json()
            message = json_object['message']
            prompt_tokens = json_object['prompt_eval_count']
            completion_tokens = json_object['eval_count']
            self.tokens = Token(input=prompt_tokens, output=completion_tokens, total=prompt_tokens + completion_tokens)
            if model:
                return model.model_validate_json(message.get('content'))
            if json:
                return AIMessage(loads(message.get('content')))
            if message.get('content'):
                return AIMessage(message.get('content'))
            # No content means the model answered with a tool call instead.
            tool_call = message.get('tool_calls')[0]['function']
            return ToolMessage(id=str(uuid4()), name=tool_call['name'], args=tool_call['arguments'])
        except HTTPError as err:
            print(f'Error: {err.response.text}, Status Code: {err.response.status_code}')

    @sleep_and_retry
    @limits(calls=15, period=60)
    @retry(stop=stop_after_attempt(3), retry=retry_if_exception_type(RequestException))
    async def async_invoke(self, messages: list[BaseMessage], json=False, model: BaseModel = None) -> AIMessage:
        """Async variant of `invoke` with identical semantics."""
        headers = self.headers
        temperature = self.temperature
        url = self.base_url or "http://localhost:11434/api/chat"
        payload = {
            "model": self.model,
            "messages": [message.to_dict() for message in messages],
            "options": {
                "temperature": temperature,
            },
            "stream": False
        }
        if json:
            payload['format'] = 'json'
        if model:
            # Structured output: Ollama accepts a JSON schema as `format`.
            payload['format'] = model.model_json_schema()
        if self.tools:
            payload["tools"] = [{
                'type': 'function',
                'function': {
                    'name': tool.name,
                    'description': tool.description,
                    'parameters': tool.schema
                }
            } for tool in self.tools]
        try:
            async with AsyncClient() as client:
                response = await client.post(url=url, json=payload, headers=headers, timeout=None)
                response.raise_for_status()
            json_object = response.json()
            message = json_object['message']
            prompt_tokens = json_object['prompt_eval_count']
            completion_tokens = json_object['eval_count']
            self.tokens = Token(input=prompt_tokens, output=completion_tokens, total=prompt_tokens + completion_tokens)
            if model:
                return model.model_validate_json(message.get('content'))
            if json:
                return AIMessage(loads(message.get('content')))
            if message.get('content'):
                return AIMessage(message.get('content'))
            # No content means the model answered with a tool call instead.
            tool_call = message.get('tool_calls')[0]['function']
            return ToolMessage(id=str(uuid4()), name=tool_call['name'], args=tool_call['arguments'])
        except HTTPError as err:
            print(f'Error: {err.response.text}, Status Code: {err.response.status_code}')

    @sleep_and_retry
    @limits(calls=15, period=60)
    @retry(stop=stop_after_attempt(3), retry=retry_if_exception_type(RequestException))
    def stream(self, messages: list[BaseMessage], json=False) -> Generator[str, None, None]:
        """Yield response text chunks as they arrive from the server.

        BUGFIX: the previous implementation passed the requests-style
        ``stream=True`` / ``iter_lines(decode_unicode=True)`` arguments to
        httpx (a TypeError at call time) and returned a generator that
        outlived the closed client. It is now a real generator built on
        ``Client.stream`` so the connection stays open while consuming.
        """
        headers = self.headers
        temperature = self.temperature
        url = self.base_url or "http://localhost:11434/api/chat"
        payload = {
            "model": self.model,
            "messages": [message.to_dict() for message in messages],
            "options": {
                "temperature": temperature,
            },
            "stream": True
        }
        if json:
            payload['format'] = 'json'
        try:
            with Client() as client:
                with client.stream('POST', url, json=payload, headers=headers, timeout=None) as response:
                    response.raise_for_status()
                    for chunk in response.iter_lines():
                        if chunk:
                            yield loads(chunk)['message']['content']
        except HTTPError as err:
            # The streamed body may not be readable here, so report the error
            # object itself rather than err.response.text.
            print(f'Error: {err}')
        except ConnectionError as err:
            print(err)
            exit()

    def available_models(self):
        """List the model names installed on the local Ollama server."""
        url = 'http://localhost:11434/api/tags'
        headers = self.headers
        response = get(url=url, headers=headers)
        response.raise_for_status()
        models = response.json()
        return [model['name'] for model in models['models']]
145 |
class Ollama(BaseInference):
    """Completion-style client for a local Ollama server (/api/generate)."""

    @sleep_and_retry
    @limits(calls=15, period=60)
    @retry(stop=stop_after_attempt(3), retry=retry_if_exception_type(RequestException))
    def invoke(self, query: str, json=False, model: BaseModel = None) -> AIMessage:
        """Generate a completion for `query`; returns an AIMessage or an
        instance of `model` when structured output was requested."""
        headers = self.headers
        temperature = self.temperature
        url = self.base_url or "http://localhost:11434/api/generate"
        payload = {
            "model": self.model,
            "prompt": query,
            "options": {
                "temperature": temperature,
            },
            "stream": False
        }
        # `format` is set once here (previously it was assigned in the payload
        # literal and then immediately overwritten by these branches).
        if json:
            payload['format'] = 'json'
        if model:
            payload['format'] = model.model_json_schema()
        try:
            with Client() as client:
                # timeout=None matches ChatOllama: generation can easily
                # exceed httpx's default timeout.
                response = client.post(url=url, json=payload, headers=headers, timeout=None)
                response.raise_for_status()
            json_object = response.json()
            prompt_tokens = json_object['prompt_eval_count']
            completion_tokens = json_object['eval_count']
            self.tokens = Token(input=prompt_tokens, output=completion_tokens, total=prompt_tokens + completion_tokens)
            if model:
                return model.model_validate_json(json_object.get('response'))
            if json:
                return AIMessage(loads(json_object.get('response')))
            return AIMessage(json_object.get('response'))
        except HTTPError as err:
            print(f'Error: {err.response.text}, Status Code: {err.response.status_code}')

    @sleep_and_retry
    @limits(calls=15, period=60)
    @retry(stop=stop_after_attempt(3), retry=retry_if_exception_type(RequestException))
    async def async_invoke(self, query: str, json=False, model: BaseModel = None) -> AIMessage:
        """Async variant of `invoke` with identical semantics."""
        headers = self.headers
        temperature = self.temperature
        url = self.base_url or "http://localhost:11434/api/generate"
        payload = {
            "model": self.model,
            "prompt": query,
            "options": {
                "temperature": temperature,
            },
            "stream": False
        }
        if json:
            payload['format'] = 'json'
        if model:
            payload['format'] = model.model_json_schema()
        try:
            async with AsyncClient() as client:
                response = await client.post(url=url, json=payload, headers=headers, timeout=None)
                response.raise_for_status()
            json_object = response.json()
            prompt_tokens = json_object['prompt_eval_count']
            completion_tokens = json_object['eval_count']
            self.tokens = Token(input=prompt_tokens, output=completion_tokens, total=prompt_tokens + completion_tokens)
            if model:
                return model.model_validate_json(json_object.get('response'))
            if json:
                return AIMessage(loads(json_object.get('response')))
            return AIMessage(json_object.get('response'))
        except HTTPError as err:
            print(f'Error: {err.response.text}, Status Code: {err.response.status_code}')

    @sleep_and_retry
    @limits(calls=15, period=60)
    @retry(stop=stop_after_attempt(3), retry=retry_if_exception_type(RequestException))
    def stream(self, query: str, json=False) -> Generator[str, None, None]:
        """Yield completion text chunks as they arrive.

        BUGFIX: the previous implementation passed requests-only arguments
        (``stream=True``, ``iter_lines(decode_unicode=True)``) to httpx and
        returned a generator that outlived the closed client; it now streams
        through ``Client.stream`` while the connection is held open.
        """
        headers = self.headers
        temperature = self.temperature
        url = self.base_url or "http://localhost:11434/api/generate"
        payload = {
            "model": self.model,
            "prompt": query,
            "options": {
                "temperature": temperature,
            },
            "stream": True
        }
        if json:
            payload['format'] = 'json'
        try:
            with Client() as client:
                with client.stream('POST', url, json=payload, headers=headers, timeout=None) as response:
                    response.raise_for_status()
                    for chunk in response.iter_lines():
                        if chunk:
                            yield loads(chunk)['response']
        except HTTPError as err:
            # The streamed body may not be readable here, so report the error
            # object itself rather than err.response.text.
            print(f'Error: {err}')
        except ConnectionError as err:
            print(err)
            exit()

    def available_models(self):
        """List the model names installed on the local Ollama server."""
        url = 'http://localhost:11434/api/tags'
        headers = self.headers
        response = get(url=url, headers=headers)
        response.raise_for_status()
        models = response.json()
        return [model['name'] for model in models['models']]
252 |
--------------------------------------------------------------------------------
/src/inference/openai.py:
--------------------------------------------------------------------------------
1 | from src.message import AIMessage,BaseMessage,SystemMessage,ImageMessage,HumanMessage,ToolMessage
2 | from tenacity import retry,stop_after_attempt,retry_if_exception_type
3 | from requests import RequestException,HTTPError,ConnectionError
4 | from ratelimit import limits,sleep_and_retry
5 | from httpx import Client,AsyncClient
6 | from src.inference import BaseInference,Token
7 | from pydantic import BaseModel
8 | from typing import Generator
9 | from typing import Literal
10 | from pathlib import Path
11 | from json import loads
12 | from uuid import uuid4
13 | import mimetypes
14 | import requests
15 |
class ChatOpenAI(BaseInference):
    """OpenAI chat-completions client over the REST API."""

    def _to_contents(self, messages: list[BaseMessage], model: BaseModel = None) -> list[dict]:
        """Convert project message objects to the OpenAI wire format.

        BUGFIX: ImageMessage entries were appended wrapped in an extra list,
        producing an invalid `messages` array; they are now appended as plain
        message dicts.
        """
        contents = []
        for message in messages:
            if isinstance(message, SystemMessage):
                if model:
                    # Inject the structured-output instructions into the system prompt.
                    message.content = self.structured(message, model)
                contents.append(message.to_dict())
            if isinstance(message, (HumanMessage, AIMessage)):
                contents.append(message.to_dict())
            if isinstance(message, ImageMessage):
                text, image = message.content
                contents.append({
                    'role': 'user',
                    'content': [
                        {'type': 'text', 'text': text},
                        {'type': 'image_url', 'image_url': {'url': image}}
                    ]
                })
        return contents

    @sleep_and_retry
    @limits(calls=15, period=60)
    @retry(stop=stop_after_attempt(3), retry=retry_if_exception_type(RequestException))
    def invoke(self, messages: list[BaseMessage], json: bool = False, model: BaseModel = None) -> AIMessage | ToolMessage | BaseModel:
        """Send a chat completion request.

        Returns an AIMessage (text/JSON), a ToolMessage (tool call) or an
        instance of `model` when structured output was requested.
        """
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers = self.headers
        temperature = self.temperature
        url = self.base_url or "https://api.openai.com/v1/chat/completions"
        contents = self._to_contents(messages, model)
        payload = {
            "model": self.model,
            "messages": contents,
            "temperature": temperature,
            "response_format": {
                "type": "json_object" if json or model else "text"
            },
            "stream": False,
        }
        if self.tools:
            payload["tools"] = [{
                'type': 'function',
                'function': {
                    'name': tool.name,
                    'description': tool.description,
                    'parameters': tool.schema
                }
            } for tool in self.tools]
        try:
            with Client() as client:
                response = client.post(url=url, json=payload, headers=headers, timeout=None)
            json_object = response.json()
            if json_object.get('error'):
                # API-level error payload; raise so the handler reports it.
                raise HTTPError(json_object['error']['message'])
            message = json_object['choices'][0]['message']
            usage_metadata = json_object['usage']
            self.tokens = Token(input=usage_metadata['prompt_tokens'], output=usage_metadata['completion_tokens'], total=usage_metadata['total_tokens'])
            if model:
                return model.model_validate_json(message.get('content'))
            if json:
                return AIMessage(loads(message.get('content')))
            if message.get('content'):
                return AIMessage(message.get('content'))
            # No content means the model answered with a tool call instead.
            tool_call = message.get('tool_calls')[0]['function']
            return ToolMessage(id=str(uuid4()), name=tool_call['name'], args=tool_call['arguments'])
        except HTTPError as err:
            # BUGFIX: the HTTPError raised above carries no `response`
            # attribute; dereferencing err.response here crashed the handler
            # and masked the real API error message.
            print(f'\nError: {err}')
        except ConnectionError as err:
            print(err)
            exit()

    @sleep_and_retry
    @limits(calls=15, period=60)
    @retry(stop=stop_after_attempt(3), retry=retry_if_exception_type(RequestException))
    async def async_invoke(self, messages: list[BaseMessage], json=False, model: BaseModel = None) -> AIMessage | ToolMessage | BaseModel:
        """Async variant of `invoke` with identical semantics."""
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers = self.headers
        temperature = self.temperature
        url = self.base_url or "https://api.openai.com/v1/chat/completions"
        contents = self._to_contents(messages, model)
        payload = {
            "model": self.model,
            "messages": contents,
            "temperature": temperature,
            "response_format": {
                "type": "json_object" if json or model else "text"
            },
            "stream": False,
        }
        if self.tools:
            payload["tools"] = [{
                'type': 'function',
                'function': {
                    'name': tool.name,
                    'description': tool.description,
                    'parameters': tool.schema
                }
            } for tool in self.tools]
        try:
            async with AsyncClient() as client:
                response = await client.post(url=url, json=payload, headers=headers, timeout=None)
            json_object = response.json()
            if json_object.get('error'):
                raise HTTPError(json_object['error']['message'])
            message = json_object['choices'][0]['message']
            usage_metadata = json_object['usage']
            self.tokens = Token(input=usage_metadata['prompt_tokens'], output=usage_metadata['completion_tokens'], total=usage_metadata['total_tokens'])
            if model:
                return model.model_validate_json(message.get('content'))
            if json:
                return AIMessage(loads(message.get('content')))
            if message.get('content'):
                return AIMessage(message.get('content'))
            tool_call = message.get('tool_calls')[0]['function']
            return ToolMessage(id=str(uuid4()), name=tool_call['name'], args=tool_call['arguments'])
        except HTTPError as err:
            # See invoke(): the raised error has no usable `response`.
            print(f'\nError: {err}')
        except ConnectionError as err:
            print(err)
            exit()

    @sleep_and_retry
    @limits(calls=15, period=60)
    @retry(stop=stop_after_attempt(3), retry=retry_if_exception_type(RequestException))
    def stream(self, messages: list[BaseMessage], json=False) -> Generator[str, None, None]:
        """Yield delta text chunks from a server-sent-events completion stream."""
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers = self.headers
        temperature = self.temperature
        url = self.base_url or "https://api.openai.com/v1/chat/completions"
        messages = [message.to_dict() for message in messages]
        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "response_format": {
                "type": "json_object" if json else "text"
            },
            "stream": True,
        }
        try:
            # BUGFIX: stream=True was missing, so requests buffered the whole
            # SSE body before any chunk could be yielded.
            response = requests.post(url=url, json=payload, headers=headers, stream=True)
            response.raise_for_status()
            chunks = response.iter_lines(decode_unicode=True)
            for chunk in chunks:
                chunk = chunk.replace('data: ', '')
                if chunk and chunk != '[DONE]':
                    delta = loads(chunk)['choices'][0]['delta']
                    yield delta.get('content', '')
        except HTTPError as err:
            err_object = loads(err.response.text)
            print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}')
        except ConnectionError as err:
            print(err)
            exit()

    def available_models(self):
        """List model ids available to this API key."""
        url = 'https://api.openai.com/v1/models'
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers = self.headers
        response = requests.get(url=url, headers=headers)
        response.raise_for_status()
        models = response.json()
        # BUGFIX: OpenAI's /v1/models entries have no 'active' field (that is
        # a Groq extension); filtering on it raised KeyError.
        return [model['id'] for model in models['data']]
221 |
class AudioOpenAI(BaseInference):
    """OpenAI audio client for speech transcription or translation."""

    def __init__(self, mode: Literal['transcriptions', 'translations'] = 'transcriptions', model: str = '', api_key: str = '', base_url: str = '', temperature: float = 0.5):
        # `mode` selects the endpoint: /v1/audio/transcriptions or /v1/audio/translations
        self.mode = mode
        super().__init__(model, api_key, base_url, temperature)

    def invoke(self, file_path: str = '', language: str = 'en', json: bool = False) -> AIMessage:
        """Upload the audio file and return its transcription/translation text."""
        path = Path(file_path)
        headers = {'Authorization': f'Bearer {self.api_key}'}
        url = self.base_url or f"https://api.openai.com/v1/audio/{self.mode}"
        data = {
            "model": self.model,
            "temperature": self.temperature,
            # BUGFIX: the audio endpoints accept 'json'/'text' (not the
            # chat-completions value 'json_object').
            "response_format": "json" if json else "text"
        }
        if self.mode == 'transcriptions':
            # `language` only applies to transcription (translation targets English).
            data['language'] = language
        # Get the MIME type for the file
        mime_type, _ = mimetypes.guess_type(path.name)
        files = {
            'file': (path.name, self.__read_audio(path), mime_type)
        }
        try:
            with Client() as client:
                response = client.post(url=url, data=data, files=files, headers=headers, timeout=None)
            # BUGFIX: response.raise_for_status() raised httpx.HTTPStatusError,
            # which the requests.HTTPError clause below never caught; raise the
            # caught type explicitly instead.
            if response.status_code >= 400:
                raise HTTPError(f'{response.text} (status {response.status_code})')
            if json:
                content = loads(response.text)['text']
            else:
                content = response.text
            return AIMessage(content)
        except HTTPError as err:
            print(f'\nError: {err}')
        except ConnectionError as err:
            print(err)
            exit()

    def __read_audio(self, file_path: str):
        """Read the audio file's raw bytes."""
        with open(file_path, 'rb') as f:
            return f.read()

    def async_invoke(self, messages: BaseMessage = None):
        # Not implemented for audio; present to satisfy the BaseInference interface.
        pass

    def stream(self, messages: BaseMessage = None):
        # Not implemented for audio; present to satisfy the BaseInference interface.
        pass

    def available_models(self):
        """List model ids available to this API key."""
        # BUGFIX: this OpenAI class previously queried the Groq endpoint and
        # filtered on Groq's 'active' field (absent from OpenAI's response).
        url = 'https://api.openai.com/v1/models'
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers = self.headers
        response = requests.get(url=url, headers=headers)
        response.raise_for_status()
        models = response.json()
        return [model['id'] for model in models['data']]
278 |
--------------------------------------------------------------------------------
/src/inference/groq.py:
--------------------------------------------------------------------------------
1 | from src.message import AIMessage,BaseMessage,SystemMessage,ImageMessage,HumanMessage,ToolMessage
2 | from tenacity import retry,stop_after_attempt,retry_if_exception_type
3 | from requests import RequestException,HTTPError,ConnectionError
4 | from ratelimit import limits,sleep_and_retry
5 | from httpx import Client,AsyncClient
6 | from src.inference import BaseInference,Token
7 | from pydantic import BaseModel
8 | from typing import Generator
9 | from typing import Literal
10 | from pathlib import Path
11 | from json import loads
12 | from uuid import uuid4
13 | import mimetypes
14 | import requests
15 |
16 | class ChatGroq(BaseInference):
17 | @sleep_and_retry
18 | @limits(calls=15,period=60)
19 | @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
20 | def invoke(self, messages: list[BaseMessage],json:bool=False,model:BaseModel=None)->AIMessage|ToolMessage|BaseModel:
21 | self.headers.update({'Authorization': f'Bearer {self.api_key}'})
22 | headers=self.headers
23 | temperature=self.temperature
24 | url=self.base_url or "https://api.groq.com/openai/v1/chat/completions"
25 | contents=[]
26 | for message in messages:
27 | if isinstance(message,SystemMessage):
28 | if model:
29 | message.content=self.structured(message,model)
30 | contents.append(message.to_dict())
31 | if isinstance(message,(HumanMessage,AIMessage)):
32 | contents.append(message.to_dict())
33 | if isinstance(message,ImageMessage):
34 | text,image=message.content
35 | contents.append([
36 | {
37 | 'role':'user',
38 | 'content':[
39 | {
40 | 'type':'text',
41 | 'text':text
42 | },
43 | {
44 | 'type':'image_url',
45 | 'image_url':{
46 | 'url':image
47 | }
48 | }
49 | ]
50 | }
51 | ])
52 |
53 | payload={
54 | "model": self.model,
55 | "messages": contents,
56 | "temperature": temperature,
57 | "response_format": {
58 | "type": "json_object" if json or model else "text"
59 | },
60 | "stream":False,
61 | }
62 | if self.tools:
63 | payload["tools"]=[{
64 | 'type':'function',
65 | 'function':{
66 | 'name':tool.name,
67 | 'description':tool.description,
68 | 'parameters':tool.schema
69 | }
70 | } for tool in self.tools]
71 | try:
72 | with Client() as client:
73 | response=client.post(url=url,json=payload,headers=headers,timeout=None)
74 | json_object=response.json()
75 | # print(json_object)
76 | if json_object.get('error'):
77 | raise HTTPError(json_object['error']['message'])
78 | message=json_object['choices'][0]['message']
79 | usage_metadata=json_object['usage']
80 | input,output,total=usage_metadata['prompt_tokens'],usage_metadata['completion_tokens'],usage_metadata['total_tokens']
81 | self.tokens=Token(input=input,output=output,total=total)
82 | if model:
83 | return model.model_validate_json(message.get('content'))
84 | if json:
85 | return AIMessage(loads(message.get('content')))
86 | if message.get('content'):
87 | return AIMessage(message.get('content'))
88 | else:
89 | tool_call=message.get('tool_calls')[0]['function']
90 | return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['arguments'])
91 | except HTTPError as err:
92 | err_object=loads(err.response.text)
93 | print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}')
94 | except ConnectionError as err:
95 | print(err)
96 | exit()
97 |
98 | @sleep_and_retry
99 | @limits(calls=15,period=60)
100 | @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
101 | async def async_invoke(self, messages: list[BaseMessage],json=False,model:BaseModel=None) -> AIMessage|ToolMessage|BaseModel:
102 | self.headers.update({'Authorization': f'Bearer {self.api_key}'})
103 | headers=self.headers
104 | temperature=self.temperature
105 | url=self.base_url or "https://api.groq.com/openai/v1/chat/completions"
106 | contents=[]
107 | for message in messages:
108 | if isinstance(message,SystemMessage):
109 | if model:
110 | message.content=self.structured(message,model)
111 | contents.append(message.to_dict())
112 | if isinstance(message,(HumanMessage,AIMessage)):
113 | contents.append(message.to_dict())
114 | if isinstance(message,ImageMessage):
115 | text,image=message.content
116 | contents.append([
117 | {
118 | 'role':'user',
119 | 'content':[
120 | {
121 | 'type':'text',
122 | 'text':text
123 | },
124 | {
125 | 'type':'image_url',
126 | 'image_url':{
127 | 'url':image
128 | }
129 | }
130 | ]
131 | }
132 | ])
133 |
134 | payload={
135 | "model": self.model,
136 | "messages": contents,
137 | "temperature": temperature,
138 | "response_format": {
139 | "type": "json_object" if json or model else "text"
140 | },
141 | "stream":False,
142 | }
143 | if self.tools:
144 | payload["tools"]=[{
145 | 'type':'function',
146 | 'function':{
147 | 'name':tool.name,
148 | 'description':tool.description,
149 | 'parameters':tool.schema
150 | }
151 | } for tool in self.tools]
152 | try:
153 | async with AsyncClient() as client:
154 | response=await client.post(url=url,json=payload,headers=headers,timeout=None)
155 | json_object=response.json()
156 | # print(json_object)
157 | if json_object.get('error'):
158 | raise HTTPError(json_object['error']['message'])
159 | message=json_object['choices'][0]['message']
160 | usage_metadata=json_object['usage']
161 | input,output,total=usage_metadata['prompt_tokens'],usage_metadata['completion_tokens'],usage_metadata['total_tokens']
162 | self.tokens=Token(input=input,output=output,total=total)
163 | if model:
164 | return model.model_validate_json(message.get('content'))
165 | if json:
166 | return AIMessage(loads(message.get('content')))
167 | if message.get('content'):
168 | return AIMessage(message.get('content'))
169 | else:
170 | tool_call=message.get('tool_calls')[0]['function']
171 | return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['arguments'])
172 | except HTTPError as err:
173 | err_object=loads(err.response.text)
174 | print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}')
175 | except ConnectionError as err:
176 | print(err)
177 | exit()
178 |
179 | @sleep_and_retry
180 | @limits(calls=15,period=60)
181 | @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
182 | def stream(self, messages: list[BaseMessage],json=False)->Generator[str,None,None]:
183 | self.headers.update({'Authorization': f'Bearer {self.api_key}'})
184 | headers=self.headers
185 | temperature=self.temperature
186 | url=self.base_url or "https://api.groq.com/openai/v1/chat/completions"
187 | messages=[message.to_dict() for message in messages]
188 | payload={
189 | "model": self.model,
190 | "messages": messages,
191 | "temperature": temperature,
192 | "response_format": {
193 | "type": "json_object" if json else "text"
194 | },
195 | "stream":True,
196 | }
197 | try:
198 | response=requests.post(url=url,json=payload,headers=headers)
199 | response.raise_for_status()
200 | chunks=response.iter_lines(decode_unicode=True)
201 | for chunk in chunks:
202 | chunk=chunk.replace('data: ','')
203 | if chunk and chunk!='[DONE]':
204 | delta=loads(chunk)['choices'][0]['delta']
205 | yield delta.get('content','')
206 | except HTTPError as err:
207 | err_object=loads(err.response.text)
208 | print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}')
209 | except ConnectionError as err:
210 | print(err)
211 | exit()
212 |
213 | def available_models(self):
214 | url='https://api.groq.com/openai/v1/models'
215 | self.headers.update({'Authorization': f'Bearer {self.api_key}'})
216 | headers=self.headers
217 | response=requests.get(url=url,headers=headers)
218 | response.raise_for_status()
219 | models=response.json()
220 | return [model['id'] for model in models['data'] if model['active']]
221 |
class AudioGroq(BaseInference):
    """Speech-to-text client for Groq's OpenAI-compatible audio endpoints.

    Wraps the `transcriptions` (same-language) and `translations`
    (to-English) endpoints; only the synchronous `invoke` path is
    implemented — `async_invoke` and `stream` are interface stubs.
    """
    def __init__(self,mode:Literal['transcriptions','translations']='transcriptions', model: str = '', api_key: str = '', base_url: str = '', temperature: float = 0.5):
        # mode selects which audio endpoint the request is sent to.
        self.mode=mode
        super().__init__(model, api_key, base_url, temperature)

    def invoke(self,file_path:str='', language:str='en', json:bool=False)->AIMessage:
        """Upload an audio file and return its transcription/translation.

        Args:
            file_path: Path to the local audio file.
            language: Language hint; only sent in `transcriptions` mode.
            json: When True, parse the response as JSON and extract 'text'.

        Returns:
            AIMessage wrapping the resulting text; on an HTTP error the
            error is printed and the method falls through returning None.
        """
        path=Path(file_path)
        headers={'Authorization': f'Bearer {self.api_key}'}
        url=self.base_url or f"https://api.groq.com/openai/v1/audio/{self.mode}"
        data={
            "model": self.model,
            "temperature": self.temperature,
            # NOTE(review): Groq's audio API documents "json"/"text"/"verbose_json"
            # as response_format values — confirm "json_object" is accepted here.
            "response_format": "json_object" if json else "text"
        }
        if self.mode=='transcriptions':
            data['language']=language
        # Get the MIME type for the file so the multipart part is labelled correctly.
        mime_type, _ = mimetypes.guess_type(path.name)
        files={
            'file': (path.name,self.__read_audio(path),mime_type)
        }
        try:
            with Client() as client:
                response=client.post(url=url,data=data,files=files,headers=headers,timeout=None)
                response.raise_for_status()
                if json:
                    content=loads(response.text)['text']
                else:
                    content=response.text
                return AIMessage(content)
        except HTTPError as err:
            err_object=loads(err.response.text)
            print(f'\nError: {err_object["error"]["message"]}\nStatus Code: {err.response.status_code}')
        except ConnectionError as err:
            print(err)
            exit()

    def __read_audio(self,file_path:str):
        # Reads the whole file into memory; acceptable for typical short clips.
        with open(file_path,'rb') as f:
            audio_data=f.read()
        return audio_data

    def async_invoke(self, messages:BaseMessage=[]):
        # Not supported for audio; present only to satisfy the BaseInference interface.
        pass

    def stream(self, messages:BaseMessage=[]):
        # Not supported for audio; present only to satisfy the BaseInference interface.
        pass

    def available_models(self):
        """Return the ids of every active model visible to this API key."""
        url='https://api.groq.com/openai/v1/models'
        self.headers.update({'Authorization': f'Bearer {self.api_key}'})
        headers=self.headers
        response=requests.get(url=url,headers=headers)
        response.raise_for_status()
        models=response.json()
        return [model['id'] for model in models['data'] if model['active']]
278 |
--------------------------------------------------------------------------------
/src/agent/web/tools/__init__.py:
--------------------------------------------------------------------------------
1 | from src.agent.web.tools.views import Click,Type,Wait,Scroll,GoTo,Back,Key,Download,Scrape,Tab,Upload,Menu,Done,Forward,HumanInput,Script
2 | from src.agent.web.context import Context
3 | from markdownify import markdownify
4 | from typing import Literal,Optional
5 | from termcolor import colored
6 | from src.tool import Tool
7 | from asyncio import sleep
8 | from pathlib import Path
9 | from os import getcwd
10 | import httpx
11 |
@Tool('Done Tool',params=Done)
async def done_tool(content:str,context:Context=None):
    '''Indicates that the current task has been completed successfully. Use this to signal completion and provide a summary of what was accomplished.'''
    # Terminal action: echoes the final summary back to the agent loop unchanged.
    return content
16 |
@Tool('Click Tool',params=Click)
async def click_tool(index:int,context:Context=None):
    '''Clicks on interactive elements like buttons, links, checkboxes, radio buttons, tabs, or any clickable UI component. Automatically scrolls the element into view if needed and handles hidden elements.'''
    page=await context.get_current_page()
    await page.wait_for_load_state('load')
    element=await context.get_element_by_index(index=index)
    handle=await context.get_handle_by_xpath(element.xpath)
    is_hidden=await handle.is_hidden()
    # BUG FIX: the old code silently skipped hidden elements yet still
    # reported success, misleading the agent into believing it clicked.
    if is_hidden:
        return f'Element at label {index} is hidden, click was not performed'
    await handle.scroll_into_view_if_needed()
    await handle.click(force=True)
    return f'Clicked on the element at label {index}'
29 |
@Tool('Type Tool',params=Type)
async def type_tool(index:int,text:str,clear:Literal['True','False']='False',press_enter:Literal['True','False']='False',context:Context=None):
    '''Types text into input fields, text areas, search boxes, or any editable element. Can optionally clear existing content before typing. Includes natural typing delay for better compatibility.'''
    page=await context.get_current_page()
    element=await context.get_element_by_index(index=index)
    handle=await context.get_handle_by_xpath(element.xpath)
    await page.wait_for_load_state('load')
    is_hidden=await handle.is_hidden()
    if not is_hidden:
        # Focus the field first so the subsequent keystrokes land in it.
        await handle.scroll_into_view_if_needed()
        await handle.click(force=True)
    # NOTE(review): typing proceeds even when the element is hidden and was
    # never focused — keystrokes then go to whatever currently has focus;
    # confirm this is intended.
    if clear=='True':
        # Select-all + Backspace clears any pre-existing value.
        await page.keyboard.press('Control+A')
        await page.keyboard.press('Backspace')
    # delay=80 ms per key simulates human typing for sites that debounce input.
    await page.keyboard.type(text,delay=80)
    if press_enter=='True':
        await page.keyboard.press('Enter')
    return f'Typed {text} in element at label {index}'
48 |
@Tool('Wait Tool',params=Wait)
async def wait_tool(time:int,context:Context=None):
    '''Pauses execution for a specified number of seconds. Use this to wait for page loading, animations to complete, or content to appear after an action.'''
    # Cooperative (non-blocking) pause so other coroutines keep running.
    seconds=time
    await sleep(seconds)
    return f'Waited for {seconds}s'
54 |
@Tool('Scroll Tool',params=Scroll)
async def scroll_tool(direction:Literal['up','down']='up',index:int=None,amount:int=500,context:Context=None):
    '''Scrolls either the webpage or a specific scrollable container. Can scroll by page increments or by specific pixel amounts. If index is provided, scrolls the specific element container; otherwise scrolls the page. Automatically detects scrollable containers and prevents unnecessary scroll attempts.'''
    page=await context.get_current_page()
    if index is not None:
        # Scroll inside a specific container element.
        element=await context.get_element_by_index(index=index)
        handle=await context.get_handle_by_xpath(xpath=element.xpath)
        if direction=='up':
            await page.evaluate(f'(element)=> element.scrollBy(0,{-amount})', handle)
        elif direction=='down':
            await page.evaluate(f'(element)=> element.scrollBy(0,{amount})', handle)
        else:
            raise ValueError('Invalid direction')
        return f'Scrolled {direction} inside the element at label {index} by {amount}'
    else:
        scroll_y_before = await context.execute_script(page,"() => window.scrollY")
        max_scroll_y = await context.execute_script(page,"() => document.documentElement.scrollHeight - window.innerHeight")
        # Check whether scrolling is possible before issuing input events.
        if scroll_y_before >= max_scroll_y and direction == 'down':
            return "Already at the bottom, cannot scroll further."
        # BUG FIX: the top of the page is scrollY == 0; the old check compared
        # against document.documentElement.scrollHeight and so never fired.
        elif scroll_y_before <= 0 and direction == 'up':
            return "Already at the top, cannot scroll further."
        if direction=='up':
            if amount is None:
                await page.keyboard.press('PageUp')
            else:
                await page.mouse.wheel(0,-amount)
        elif direction=='down':
            if amount is None:
                await page.keyboard.press('PageDown')
            else:
                await page.mouse.wheel(0,amount)
        else:
            raise ValueError('Invalid direction')
        # Get scroll position after scrolling to verify the action had an effect.
        scroll_y_after = await context.execute_script(page,"() => window.scrollY")
        if scroll_y_before == scroll_y_after:
            return "Scrolling has no effect, the entire content fits within the viewport."
        amount=amount if amount else 'one page'
        return f'Scrolled {direction} by {amount}'
97 |
@Tool('GoTo Tool',params=GoTo)
async def goto_tool(url:str,context:Context=None):
    '''Navigates directly to a specified URL in the current tab. Supports HTTP/HTTPS URLs and waits for the DOM content to load before proceeding.'''
    current_page=await context.get_current_page()
    await current_page.goto(url=url,wait_until='domcontentloaded')
    # Grace period for client-side scripts to settle after the DOM is ready.
    await current_page.wait_for_timeout(2.5*1000)
    return f'Navigated to {url}'
105 |
@Tool('Back Tool',params=Back)
async def back_tool(context:Context=None):
    '''Navigates to the previous page in the browser history, equivalent to clicking the browser's back button. Waits for the page to fully load.'''
    current_page=await context.get_current_page()
    await current_page.go_back()
    # Block until the destination page is fully loaded.
    await current_page.wait_for_load_state('load')
    return 'Navigated to previous page'
113 |
@Tool('Forward Tool',params=Forward)
async def forward_tool(context:Context=None):
    '''Navigates to the next page in the browser history, equivalent to clicking the browser's forward button. Waits for the page to fully load.'''
    current_page=await context.get_current_page()
    await current_page.go_forward()
    # Block until the destination page is fully loaded.
    await current_page.wait_for_load_state('load')
    return 'Navigated to next page'
121 |
@Tool('Key Tool',params=Key)
async def key_tool(keys:str,times:int=1,context:Context=None):
    '''Performs keyboard shortcuts and key combinations (e.g., "Control+C", "Enter", "Escape", "Tab"). Can repeat the key press multiple times. Supports all standard keyboard keys and modifiers.'''
    active_page=await context.get_current_page()
    await active_page.wait_for_load_state('domcontentloaded')
    # Repeat the same key chord `times` times.
    for _repeat in range(times):
        await active_page.keyboard.press(keys)
    return f'Pressed {keys}'
130 |
@Tool('Download Tool',params=Download)
async def download_tool(url:str=None,filename:str=None,context:Context=None):
    '''Downloads files from the internet (PDFs, images, videos, audio, documents) and saves them to the system's downloads directory. Handles various file types and formats.'''
    folder_path=Path(context.browser.config.downloads_dir)
    async with httpx.AsyncClient() as client:
        response=await client.get(url)
        # BUG FIX: fail loudly on 4xx/5xx instead of saving an error page to disk.
        response.raise_for_status()
        path=folder_path.joinpath(filename)
        with open(path,'wb') as f:
            async for chunk in response.aiter_bytes():
                f.write(chunk)
    # BUG FIX: report the actual filename instead of the literal "(unknown)".
    return f'Downloaded {filename} from {url} and saved it to {path}'
142 |
@Tool('Scrape Tool',params=Scrape)
async def scrape_tool(context:Context=None):
    '''Extracts and returns the main content from the current webpage. Can output in markdown format (preserving links and structure). Filters out navigation, ads, and other non-essential content.'''
    active_page=await context.get_current_page()
    await active_page.wait_for_load_state('domcontentloaded')
    # Convert the raw HTML into markdown for the agent to read.
    markdown_content=markdownify(await active_page.content())
    return f'Scraped the contents of the entire webpage:\n{markdown_content}'
151 |
@Tool('Tab Tool', params=Tab)
async def tab_tool(mode: Literal['open', 'close', 'switch'], tab_index: Optional[int] = None, context: Context = None):
    '''Manages browser tabs: opens new blank tabs, closes the current tab (if not the last one), or switches between existing tabs by index. Automatically handles focus and loading states.'''
    session = await context.get_session()
    open_tabs = session.context.pages  # snapshot of every tab currently open
    if mode == 'open':
        new_page = await session.context.new_page()
        session.current_page = new_page
        await new_page.wait_for_load_state('load')
        return 'Opened a new blank tab and switched to it.'
    if mode == 'close':
        # Refuse to close the only tab; the browser context must keep one page.
        if len(open_tabs) == 1:
            return 'Cannot close the last remaining tab.'
        await session.current_page.close()
        remaining = session.context.pages
        session.current_page = remaining[-1]  # focus falls to the last remaining tab
        await session.current_page.bring_to_front()
        await session.current_page.wait_for_load_state('load')
        return 'Closed current tab and switched to the next last tab.'
    if mode == 'switch':
        if tab_index is None or not (0 <= tab_index < len(open_tabs)):
            raise IndexError(f'Tab index {tab_index} is out of range. Available tabs: {len(open_tabs)}')
        session.current_page = open_tabs[tab_index]
        await session.current_page.bring_to_front()
        await session.current_page.wait_for_load_state('load')
        return f'Switched to tab {tab_index} (Total tabs: {len(open_tabs)}).'
    raise ValueError("Invalid mode. Use 'open', 'close', or 'switch'.")
182 |
@Tool('Upload Tool',params=Upload)
async def upload_tool(index:int,filenames:list[str],context:Context=None):
    '''Uploads one or more files to file input elements on webpages. Handles both single and multiple file uploads. Files should be placed in the ./uploads directory before using this tool.'''
    element=await context.get_element_by_index(index=index)
    handle=await context.get_handle_by_xpath(element.xpath)
    # Files are resolved relative to the ./uploads directory under the CWD.
    files=[Path(getcwd()).joinpath('./uploads',filename) for filename in filenames]
    page=await context.get_current_page()
    # Clicking the element opens the native file chooser, which Playwright intercepts.
    async with page.expect_file_chooser() as file_chooser_info:
        await handle.click()
    file_chooser=await file_chooser_info.value
    handle=file_chooser.element
    if file_chooser.is_multiple():
        await handle.set_input_files(files=files)
    else:
        # Single-file input: only the first filename is uploaded.
        await handle.set_input_files(files=files[0])
    await page.wait_for_load_state('load')
    return f'Uploaded {filenames} to element at label {index}'
200 |
@Tool('Menu Tool',params=Menu)
async def menu_tool(index:int,labels:list[str],context:Context=None):
    '''Interacts with dropdown menus, select elements, and multi-select lists. Can select single or multiple options by their visible labels. Handles both simple dropdowns and complex multi-selection interfaces.'''
    element=await context.get_element_by_index(index=index)
    handle=await context.get_handle_by_xpath(element.xpath)
    # select_option accepts a single label (str) or a list for multi-selects.
    selection=labels if len(labels)>1 else labels[0]
    await handle.select_option(label=selection)
    # BUG FIX: join the original list — the old code rebound `labels` to a bare
    # string for single selections, so ", ".join spliced it per character
    # (e.g. "Red" -> "R, e, d").
    return f'Opened context menu of element at label {index} and selected {", ".join(labels)}'
209 |
210 | @Tool("Script Tool",params=Script)
211 | async def script_tool(script:str,context:Context=None):
212 | '''Executes arbitrary JavaScript code on the page. Can be used to manipulate the DOM or trigger events or scrape data. Returns the result of the executed script.'''
213 | page=await context.get_current_page()
214 | result=await context.execute_script(page,script)
215 | return f"Result of the executed script: {result}"
216 |
@Tool('Human Tool',params=HumanInput)
async def human_tool(prompt:str,context:Context=None):
    '''Requests human assistance when encountering challenges that require human intervention such as CAPTCHAs, OTP codes, complex decisions, or when explicitly asked to involve a human user.'''
    # Relay the agent's question to the console and block until the user replies.
    print(colored(f"Agent: {prompt}", color='cyan', attrs=['bold']))
    reply = input("User: ")
    return f"User provided the following input: '{reply}'"
--------------------------------------------------------------------------------
/src/inference/gemini.py:
--------------------------------------------------------------------------------
1 | from src.message import AIMessage,BaseMessage,HumanMessage,ImageMessage,SystemMessage,ToolMessage
2 | from requests import get,RequestException,HTTPError,ConnectionError
3 | from tenacity import retry,stop_after_attempt,retry_if_exception_type
4 | from ratelimit import limits,sleep_and_retry
5 | from src.inference import BaseInference,Token
6 | from httpx import Client,AsyncClient
7 | from pydantic import BaseModel
8 | from typing import Optional
9 | from typing import Literal
10 | from src.tool import Tool
11 | from json import loads
12 | from uuid import uuid4
13 |
class ChatGemini(BaseInference):
    """Chat client for Google's Gemini Generative Language REST API.

    Supports plain-text and JSON-mode replies, structured output via a
    pydantic model, image input, function calling and explicit content
    caching. `stream` is currently an unimplemented stub.
    """
    def __init__(self,model:str,api_version:Literal['v1','v1beta','v1alpha']='v1beta',modality:Literal['text','audio']='text',api_key:str='',base_url:str='',tools:list=[],temperature:float=0.5):
        # NOTE(review): `tools:list=[]` is a shared mutable default — confirm
        # callers never mutate it in place.
        super().__init__(model,api_key=api_key,base_url=base_url,tools=tools,temperature=temperature)
        # Gemini-specific knobs on top of the shared BaseInference state.
        self.api_version=api_version
        self.modality=modality


    def cache_content(self,system_message:Optional[SystemMessage]=None,tools:Optional[list[Tool]]=None,messages:Optional[list[BaseMessage]]=None,display_name:Optional[str]=None,ttl:int=60):
        """Create a server-side cached-content entry and return its resource name.

        Args:
            system_message: Optional system instruction to cache.
            tools: Optional tool declarations to cache.
            messages: Accepted but not used by this implementation.
            display_name: Optional human-readable cache label.
            ttl: Cache lifetime in seconds (default 60).

        Returns:
            The cache resource name on success, None on HTTP/connection errors.
        """
        url = f"https://generativelanguage.googleapis.com/{self.api_version}/cachedContents?key={self.api_key}"
        payload = {
            "ttl": f"{ttl}s",
            "model": f"models/{self.model}",
        }
        # Add display name if provided
        if display_name:
            payload["display_name"] = display_name
        # Add system instruction if provided
        if system_message:
            payload["systemInstruction"] = {
                "parts": [
                    {
                        "text": system_message.content
                    }
                ]
            }
        # Add tools if provided
        if tools:
            payload["tools"] = [
                {
                    "function_declarations": [
                        {
                            "name": tool.name,
                            "description": tool.description,
                            "parameters": tool.schema
                        }
                        for tool in tools
                    ]
                }
            ]
        try:
            with Client() as client:
                response = client.post(url, json=payload, headers=self.headers, timeout=None)
                json_obj=response.json()
                # API errors arrive as a 200-level body with an 'error' key.
                if json_obj.get('error'):
                    raise Exception(json_obj['error']['message'])
                usage_metadata=json_obj['usageMetadata']
                # Only the cached-token count is tracked for cache creation.
                self.tokens=Token(cache=usage_metadata["totalTokenCount"])
                return json_obj['name']
        except HTTPError as err:
            print(f'Error: {err.response.text}, Status Code: {err.response.status_code}')
            return None
        except ConnectionError as err:
            print(f'Connection Error: {err}')
            return None

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    def invoke(self, messages: list[BaseMessage],json=False,model:BaseModel|None=None,cache_name:Optional[str]=None) -> AIMessage|ToolMessage|BaseModel:
        """Synchronous generateContent call.

        Args:
            messages: Conversation history as project message objects.
            json: When True, request and parse an application/json reply.
            model: Optional pydantic model for structured output validation.
            cache_name: Optional cachedContents id created by cache_content().

        Returns:
            AIMessage, ToolMessage (function call), or a validated pydantic
            instance when `model` is supplied.
        """
        self.headers.update({'x-goog-api-key':self.api_key})
        temperature=self.temperature
        url=self.base_url or f"https://generativelanguage.googleapis.com/{self.api_version}/models/{self.model}:generateContent"
        contents=[]
        system_instruct=None
        # Translate project message types into the Gemini 'contents' schema.
        for message in messages:
            if isinstance(message,HumanMessage):
                contents.append({
                    'role':'user',
                    'parts':[{
                        'text':message.content
                    }]
                })
            elif isinstance(message,AIMessage):
                contents.append({
                    'role':'model',
                    'parts':[{
                        'text':message.content
                    }]
                })
            elif isinstance(message,ImageMessage):
                text,image=message.content
                contents.append({
                    'role':'user',
                    'parts':[{
                        'text':text
                    },
                    {
                        'inline_data':{
                            'mime_type':'image/jpeg',
                            'data': image
                        }
                    }]
                })
            elif isinstance(message,SystemMessage):
                # System prompts travel out-of-band as system_instruction;
                # with structured output requested, the prompt is rewritten.
                system_instruct={
                    'parts':{
                        'text': self.structured(message,model) if model else message.content
                    }
                }
            else:
                raise Exception("Invalid Message")

        payload={
            'contents': contents,
            'generationConfig':{
                'temperature': temperature,
                'responseMimeType':'application/json' if json or model else 'text/plain',
                'responseModalities': [self.modality]
            }
        }
        if self.tools:
            payload['tools']=[
                {
                    'function_declarations':[
                        {
                            'name': tool.name,
                            'description': tool.description,
                            'parameters': tool.schema
                        }
                    for tool in self.tools]
                }
            ]
        if system_instruct:
            payload['system_instruction']=system_instruct

        if cache_name:
            payload['cachedContent']=f"cachedContents/{cache_name}"

        try:
            with Client() as client:
                response=client.post(url=url,headers=self.headers,json=payload,timeout=None)
                json_obj=response.json()
                if json_obj.get('error'):
                    raise Exception(json_obj['error']['message'])
                message=json_obj['candidates'][0]['content']['parts'][0]
                usage_metadata=json_obj['usageMetadata']
                input,output,total=usage_metadata['promptTokenCount'],usage_metadata['candidatesTokenCount'],usage_metadata['totalTokenCount']
                self.tokens=Token(input=input,output=output,total=total)
                if model:
                    return model.model_validate_json(message['text'])
                if json:
                    content=loads(message['text'])
                    return AIMessage(content)
                # NOTE(review): a function-call part has no 'text' key, so this
                # subscript may raise KeyError before reaching the else branch —
                # confirm against real function-call responses.
                if message['text']:
                    content=message['text']
                    return AIMessage(content)
                else:
                    tool_call=message['functionCall']
                    return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['args'])

        except HTTPError as err:
            print(f'Error: {err.response.text}, Status Code: {err.response.status_code}')
        except ConnectionError as err:
            print(err)
            exit()

    @sleep_and_retry
    @limits(calls=15,period=60)
    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    async def async_invoke(self, messages: list[BaseMessage],json=False,model:BaseModel=None) -> AIMessage|ToolMessage|BaseModel:
        """Async generateContent call; mirrors invoke() without cache support.

        NOTE(review): @sleep_and_retry/@limits are synchronous decorators on an
        async function — any rate-limit sleep would block the event loop; confirm.
        """
        temperature=self.temperature
        url=self.base_url or f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:generateContent"
        self.headers.update({'x-goog-api-key':self.api_key})
        contents=[]
        system_instruction=None
        # Translate project message types into the Gemini 'contents' schema.
        for message in messages:
            if isinstance(message,HumanMessage):
                contents.append({
                    'role':'user',
                    'parts':[{
                        'text':message.content
                    }]
                })
            elif isinstance(message,AIMessage):
                contents.append({
                    'role':'model',
                    'parts':[{
                        'text':message.content
                    }]
                })
            elif isinstance(message,ImageMessage):
                text,image=message.content
                contents.append({
                    'role':'user',
                    'parts':[{
                        'text':text
                    },
                    {
                        'inline_data':{
                            'mime_type':'image/jpeg',
                            'data': image
                        }
                    }]
                })
            elif isinstance(message,SystemMessage):
                system_instruction={
                    'parts':{
                        'text': self.structured(message,model) if model else message.content
                    }
                }
            else:
                raise Exception("Invalid Message")

        payload={
            'contents': contents,
            'generationConfig':{
                'temperature': temperature,
                'responseMimeType':'application/json' if json or model else 'text/plain',
                'responseModalities': [self.modality]
            }
        }
        if self.tools:
            payload['tools']=[
                {
                    'function_declarations':[
                        {
                            'name': tool.name,
                            'description': tool.description,
                            'parameters': tool.schema
                        }
                    for tool in self.tools]
                }
            ]
        if system_instruction:
            payload['system_instruction']=system_instruction
        try:
            async with AsyncClient() as client:
                response=await client.post(url=url,headers=self.headers,json=payload,timeout=None)
                json_obj=response.json()
                if json_obj.get('error'):
                    raise Exception(json_obj['error']['message'])
                message=json_obj['candidates'][0]['content']['parts'][0]
                usage_metadata=json_obj['usageMetadata']
                input,output,total=usage_metadata['promptTokenCount'],usage_metadata['candidatesTokenCount'],usage_metadata['totalTokenCount']
                self.tokens=Token(input=input,output=output,total=total)
                if model:
                    return model.model_validate_json(message['text'])
                if json:
                    content=loads(message['text'])
                    return AIMessage(content)
                # NOTE(review): same 'text'-key caveat as invoke() above.
                if message['text']:
                    content=message['text']
                    return AIMessage(content)
                else:
                    tool_call=message['functionCall']
                    return ToolMessage(id=str(uuid4()),name=tool_call['name'],args=tool_call['args'])

        except HTTPError as err:
            print(f'Error: {err.response.text}, Status Code: {err.response.status_code}')
        except ConnectionError as err:
            print(err)
            exit()

    @retry(stop=stop_after_attempt(3),retry=retry_if_exception_type(RequestException))
    def stream(self, query:str):
        # NOTE(review): stub — builds request parameters but never performs the
        # call and returns nothing; streaming is not implemented yet.
        headers=self.headers
        temperature=self.temperature
        url=self.base_url or f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:generateContent"

    def available_models(self):
        """Return the display names of all models visible to this API key."""
        url='https://generativelanguage.googleapis.com/v1beta/models'
        headers=self.headers
        params={'key':self.api_key}
        try:
            response=get(url=url,headers=headers,params=params)
            response.raise_for_status()
            json_obj=response.json()
            models=json_obj['models']
        except HTTPError as err:
            print(f'Error: {err.response.text}, Status Code: {err.response.status_code}')
            exit()
        except ConnectionError as err:
            print(err)
            exit()
        return [model['displayName'] for model in models]
--------------------------------------------------------------------------------
/src/agent/web/__init__.py:
--------------------------------------------------------------------------------
1 | from src.agent.web.tools import click_tool,goto_tool,type_tool,scroll_tool,wait_tool,back_tool,key_tool,scrape_tool,tab_tool,forward_tool,done_tool,download_tool,human_tool,script_tool
2 | from src.message import SystemMessage,HumanMessage,ImageMessage,AIMessage
3 | from src.agent.web.utils import read_markdown_file,extract_agent_data
4 | from src.agent.web.browser import Browser,BrowserConfig
5 | from langgraph.graph import StateGraph,END,START
6 | from src.agent.web.state import AgentState
7 | from src.agent.web.context import Context
8 | from src.inference import BaseInference
9 | from src.tool.registry import Registry
10 | from rich.markdown import Markdown
11 | from src.memory import BaseMemory
12 | from rich.console import Console
13 | from src.agent import BaseAgent
14 | from pydantic import BaseModel
15 | from datetime import datetime
16 | from termcolor import colored
17 | from textwrap import dedent
18 | from src.tool import Tool
19 | from pathlib import Path
20 | import textwrap
21 | import platform
22 | import asyncio
23 | import json
24 |
# Default browser tools wired into every Agent's registry (see Agent.__init__,
# which may extend this set with additional_tools and the human tool).
main_tools=[
    click_tool,goto_tool,key_tool,scrape_tool,
    type_tool,scroll_tool,wait_tool,back_tool,
    tab_tool,done_tool,forward_tool,download_tool,
    script_tool
]
31 |
32 | class Agent(BaseAgent):
33 | def __init__(self,config:BrowserConfig=None,additional_tools:list[Tool]=[],
34 | instructions:list=[],memory:BaseMemory=None,llm:BaseInference=None,max_iteration:int=10,
35 | use_vision:bool=False,include_human_in_loop:bool=False,verbose:bool=False,token_usage:bool=False) -> None:
36 | """
37 | Initializes the WebAgent object.
38 |
39 | Args:
40 | config (BrowserConfig, optional): Browser configuration. Defaults to None.
41 | additional_tools (list[Tool], optional): Additional tools to be used. Defaults to [].
42 | instructions (list, optional): Instructions for the agent. Defaults to [].
43 | memory (BaseMemory, optional): Memory object for the agent. Defaults to None.
44 | llm (BaseInference, optional): Large Language Model object. Defaults to None.
45 | max_iteration (int, optional): Maximum number of iterations. Defaults to 10.
46 | use_vision (bool, optional): Whether to use vision or not. Defaults to False.
47 | include_human_in_loop (bool, optional): Whether to include human in the loop or not. Defaults to False.
48 | verbose (bool, optional): Whether to print verbose output or not. Defaults to False.
49 | token_usage (bool, optional): Whether to track token usage or not. Defaults to False.
50 |
51 | Returns:
52 | None
53 | """
54 | self.name='Web Agent'
55 | self.description='The Web Agent is designed to automate the process of gathering information from the internet, such as to navigate websites, perform searches, and retrieve data.'
56 | self.observation_prompt=read_markdown_file('./src/agent/web/prompt/observation.md')
57 | self.system_prompt=read_markdown_file('./src/agent/web/prompt/system.md')
58 | self.action_prompt=read_markdown_file('./src/agent/web/prompt/action.md')
59 | self.answer_prompt=read_markdown_file('./src/agent/web/prompt/answer.md')
60 | self.instructions=self.format_instructions(instructions)
61 | self.registry=Registry(main_tools+additional_tools+([human_tool] if include_human_in_loop else []))
62 | self.include_human_in_loop=include_human_in_loop
63 | self.browser=Browser(config=config)
64 | self.context=Context(browser=self.browser)
65 | self.max_iteration=max_iteration
66 | self.token_usage=token_usage
67 | self.structured_output=None
68 | self.use_vision=use_vision
69 | self.verbose=verbose
70 | self.start_time=None
71 | self.memory=memory
72 | self.end_time=None
73 | self.iteration=0
74 | self.llm=llm
75 | self.graph=self.create_graph()
76 |
77 | def format_instructions(self,instructions):
78 | return '\n'.join([f'{i+1}. {instruction}' for (i,instruction) in enumerate(instructions)])
79 |
    async def reason(self,state:AgentState):
        """Ask the LLM for the next decision given the current browser state.

        Formats the system prompt with run configuration (OS, browser name,
        tool descriptions, iteration limits), sends it plus the accumulated
        message history to the LLM, and parses the reply into a structured
        ``agent_data`` dict with labelled sections (Memory/Evaluate/Thought).

        Args:
            state (AgentState): Graph state. ``messages`` is read and mutated
                (the last browser-state message is popped and replaced by a
                compact progress summary); ``prev_observation`` is read.

        Returns:
            dict: Copy of ``state`` updated with ``agent_data`` and the
            replacement message list.
        """
        system_prompt=self.system_prompt.format(**{
            'os':platform.system(),
            'instructions':self.instructions,
            'home_dir':Path.home().as_posix(),
            'max_iteration':self.max_iteration,
            'human_in_loop':self.include_human_in_loop,
            'tools_prompt':self.registry.tools_prompt(),
            'browser':self.browser.config.browser.capitalize(),
            'downloads_dir':self.browser.config.downloads_dir,
            'current_datetime':datetime.now().strftime('%A, %B %d, %Y')
        })
        # Prepend the freshly formatted system prompt to the running history.
        messages=[SystemMessage(system_prompt)]+state.get('messages')
        ai_message=await self.llm.async_invoke(messages=messages)
        # Parse the raw LLM text into its labelled sections.
        agent_data=extract_agent_data(ai_message.content)
        memory=agent_data.get('Memory')
        evaluate=agent_data.get("Evaluate")
        thought=agent_data.get('Thought')
        if self.verbose:
            print(colored(f'Evaluate: {evaluate}',color='light_yellow',attrs=['bold']))
            print(colored(f'Memory: {memory}',color='light_green',attrs=['bold']))
            print(colored(f'Thought: {thought}',color='light_magenta',attrs=['bold']))
        last_message=state.get('messages').pop() # ImageMessage/HumanMessage. To remove the past browser state
        if isinstance(last_message,(ImageMessage,HumanMessage)):
            # Swap the heavy browser-state message for a compact progress summary
            # so the history does not grow with full page snapshots.
            message=HumanMessage(dedent(f'''

            Current Step: {self.iteration}

            Max. Steps: {self.max_iteration}

            Action Response: {state.get('prev_observation')}

            '''))
        # NOTE(review): if the popped message is neither ImageMessage nor
        # HumanMessage, `message` is unbound here — confirm that case cannot occur.
        return {**state,'agent_data': agent_data,'messages':[message]}
117 |
    async def action(self,state:AgentState):
        """Execute the tool action chosen by :meth:`reason` and snapshot the browser.

        Runs the action named in ``agent_data`` through the tool registry,
        then captures the resulting browser state (tabs, DOM elements and an
        optional screenshot) and builds the AI/observation message pair that
        feeds the next reasoning step.

        Args:
            state (AgentState): Graph state carrying ``agent_data`` with
                'Action Name' and 'Action Input', plus the original 'input'.

        Returns:
            dict: Copy of ``state`` updated with the new messages,
            ``browser_state``, ``dom_state`` and ``prev_observation``.
        """
        agent_data=state.get('agent_data')
        memory=agent_data.get('Memory')
        evaluate=agent_data.get("Evaluate")
        thought=agent_data.get('Thought')
        action_name=agent_data.get('Action Name')
        action_input:dict=agent_data.get('Action Input')
        if self.verbose:
            print(colored(f'Action: {action_name}({','.join([f'{k}={v}' for k,v in action_input.items()])})',color='blue',attrs=['bold']))
        action_result=await self.registry.async_execute(action_name,action_input,context=self.context)
        observation=action_result.content
        if self.verbose:
            # Truncate long observations so the console stays readable.
            print(colored(f'Observation: {textwrap.shorten(observation,width=1000,placeholder='...')}',color='green',attrs=['bold']))
        if self.verbose and self.token_usage:
            print(f'Input Tokens: {self.llm.tokens.input} Output Tokens: {self.llm.tokens.output} Total Tokens: {self.llm.tokens.total}')
        # Get the current screenshot,browser state and dom state
        browser_state=await self.context.get_state(use_vision=self.use_vision)
        current_tab=browser_state.current_tab
        dom_state=browser_state.dom_state
        image_obj=browser_state.screenshot
        # Redefining the AIMessage and adding the new observation
        action_prompt=self.action_prompt.format(**{
            'memory':memory,
            'evaluate':evaluate,
            'thought':thought,
            'action_name':action_name,
            'action_input':json.dumps(action_input,indent=2)
        })
        observation_prompt=self.observation_prompt.format(**{
            'iteration':self.iteration,
            'max_iteration':self.max_iteration,
            'observation':observation,
            'current_tab':current_tab.to_string(),
            'tabs':browser_state.tabs_to_string(),
            'interactive_elements':dom_state.interactive_elements_to_string(),
            'informative_elements':dom_state.informative_elements_to_string(),
            'scrollable_elements':dom_state.scrollable_elements_to_string(),
            'query':state.get('input')
        })
        # Screenshot goes into an ImageMessage only when vision is enabled and a
        # screenshot was actually captured; otherwise fall back to plain text.
        messages=[AIMessage(action_prompt),ImageMessage(text=observation_prompt,image_obj=image_obj) if self.use_vision and image_obj is not None else HumanMessage(observation_prompt)]
        return {**state,'messages':messages,'browser_state':browser_state,'dom_state':dom_state,'prev_observation':observation}
160 |
161 | async def answer(self,state:AgentState):
162 | "Give the final answer"
163 | if self.iterationdict|BaseModel:
216 | self.iteration=0
217 | observation_prompt=self.observation_prompt.format(**{
218 | 'iteration':self.iteration,
219 | 'max_iteration':self.max_iteration,
220 | 'memory':'Nothing to remember',
221 | 'evaluate':'Nothing to evaluate',
222 | 'thought':'Nothing to think',
223 | 'action':'No Action',
224 | 'observation':'No Observation',
225 | 'current_tab':'No tabs open',
226 | 'tabs':'No tabs open',
227 | 'interactive_elements':'No interactive elements found',
228 | 'informative_elements':'No informative elements found',
229 | 'scrollable_elements':'No scrollable elements found',
230 | 'query':input
231 | })
232 | state={
233 | 'input':input,
234 | 'agent_data':{},
235 | 'prev_observation':'No Observation',
236 | 'browser_state':None,
237 | 'dom_state':None,
238 | 'output':'',
239 | 'messages':[HumanMessage(observation_prompt)]
240 | }
241 | self.start_time=datetime.now()
242 | response=await self.graph.ainvoke(state,config={'recursion_limit':self.max_iteration})
243 | self.end_time=datetime.now()
244 | total_seconds=(self.end_time-self.start_time).total_seconds()
245 | if self.verbose and self.token_usage:
246 | print(f'Input Tokens: {self.llm.tokens.input} Output Tokens: {self.llm.tokens.output} Total Tokens: {self.llm.tokens.total}')
247 | print(f'Total Time Taken: {total_seconds} seconds Number of Steps: {self.iteration}')
248 | # Extract and store the key takeaways of the task performed by the agent
249 | if self.memory:
250 | self.memory.store(response.get('messages'))
251 | return response
252 |
253 | def invoke(self, input: str)->dict|BaseModel:
254 | if self.verbose:
255 | print('Entering '+colored(self.name,'black','on_white'))
256 | try:
257 | loop = asyncio.get_running_loop()
258 | except RuntimeError:
259 | loop = asyncio.new_event_loop()
260 | asyncio.set_event_loop(loop)
261 | response = loop.run_until_complete(self.async_invoke(input=input))
262 | return response
263 |
264 | async def invoke_history(self, input: str)->dict|BaseModel:
265 | pass
266 |
267 | def print_response(self,input: str):
268 | console=Console()
269 | response=self.invoke(input)
270 | console.print(Markdown(response.get('output')))
271 |
272 | async def close(self):
273 | '''Close the browser and context followed by clean up'''
274 | try:
275 | await self.context.close_session()
276 | await self.browser.close_browser()
277 | except Exception:
278 | print('Failed to finish clean up')
279 | finally:
280 | self.context=None
281 | self.browser=None
282 |
283 | def stream(self, input:str):
284 | pass
285 |
--------------------------------------------------------------------------------