├── __init__.py ├── utils ├── __init__.py ├── eval_prompt.py ├── collator.py ├── action_space_converter.py └── function_parser.py ├── assets ├── google.png └── screenspot-v2.png ├── requirements.txt ├── .gitignore ├── preprocessing ├── prompts.py ├── aguvis_models.py ├── action_conversion.py └── aguvis_processor.py ├── recipe.ipynb └── README.md /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/google.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/smol2operator/HEAD/assets/google.png -------------------------------------------------------------------------------- /assets/screenspot-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/smol2operator/HEAD/assets/screenspot-v2.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | trl==0.21.0 2 | transformers==4.56.1 3 | Pillow==11.3.0 4 | torchvision==0.23.0 5 | num2words==0.5.14 6 | wandb==0.21.1 7 | python-dotenv==1.1.1 -------------------------------------------------------------------------------- /utils/eval_prompt.py: -------------------------------------------------------------------------------- 1 | # This file compiles the evaluation prompt for the model 2 | 3 | # Screenspot-v2 Dataset: https://huggingface.co/datasets/HongxinLi/ScreenSpot_v2 4 | 5 | # Screenspot-v2 Evaluation Prompt (Phase 1) 6 | 7 | SCREENSPOT_V2_USER_PROMPT_PHASE_1 = """Using the screenshot, you will get an instruction and will need to output a click that completes the instruction or targets the given element. 8 | 9 | Just write your action as follows: 10 | 11 | Action: click(0.XXXX, 0.YYYY) 12 | With 0.XXXX and 0.YYYY the normalized coordinates of the click position on the screenshot, representing relative horizontal (X-axis) and vertical (Y-axis) positions on the screen respectively. 13 | 14 | Now write the click needed to complete the instruction: 15 | Instruction: {instruction} 16 | """ 17 | 18 | # Screenspot-v2 Evaluation Prompt (Phase 2) 19 | 20 | SCREENSPOT_V2_SYSTEM_PROMPT_PHASE_1 = '''You are a helpful GUI agent. You'll be given a task and a screenshot of the screen. Complete the task using Python function calls. 21 | 22 | For each step: 23 | • First, to express the thought process guiding your next action and the reasoning behind it. 24 | • Then, use to perform the action. it will be executed in a stateful environment. 
25 | 26 | The following functions are exposed to the Python interpreter: 27 | 28 | 29 | # OS ACTIONS 30 | 31 | 32 | def click(x: Optional[float] = None, y: Optional[float] = None) -> str: 33 | """ 34 | Performs a left-click at the specified normalized coordinates 35 | Args: 36 | x: The x coordinate (horizontal position) 37 | y: The y coordinate (vertical position) 38 | """ 39 | 40 | 41 | 42 | The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.''' 43 | 44 | SCREENSPOT_V2_USER_PROMPT_PHASE_2 = """Please generate the next move according to the UI screenshot, instruction and previous actions. 45 | 46 | Instruction: {instruction} 47 | 48 | Previous actions: 49 | None""" -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | 173 | # Temp folders 174 | data/ 175 | wandb/ 176 | logs/ 177 | eval_results/ 178 | results/ 179 | 180 | .vscode/ 181 | .python-version -------------------------------------------------------------------------------- /utils/collator.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | 4 | def create_collate_fn(processor, max_length: int): 5 | """Optimized collate function for VLM training that masks system prompt tokens.""" 6 | 7 | def collate_fn(examples: list[dict[str, list | str | Image.Image]]): 8 | batch_messages: list[list[dict[str, list | str | Image.Image]]] = [] 9 | assistant_messages: list[list[str]] = [] 10 | all_image_inputs: list[list[Image.Image]] = [] 11 | for example in examples: 12 | images: list[Image.Image] = example["images"] 13 | is_first_user = True 14 | sample: list[dict[str, list | str | Image.Image]] = [] 15 | assistant: list[str] = [] 16 | for text in example["texts"]: 17 | if "system" in text.keys(): 18 | sample.append( 19 | { 20 | "role": "system", 21 | "content": [{"type": "text", "text": text["system"]}], 22 | } 23 | ) 24 | 25 | if is_first_user: 26 | sample.append( 27 | { 28 | "role": "user", 29 | "content": [ 30 | {"type": "image", "image": images[0]}, 31 | {"type": "text", "text": text["user"]}, 32 | ], 33 | } 34 | ) 35 | is_first_user = False 36 | else: 37 | sample.append( 38 | { 39 | "role": "user", 40 | "content": [ 41 | {"type": "text", "text": text["user"]}, 42 | ], 43 | } 44 | ) 45 | 46 | sample.append( 47 | { 48 | "role": "assistant", 49 | "content": [{"type": "text", "text": "\n" + text["assistant"]}], 50 | } 51 | ) 52 | assistant.append(text["assistant"]) 53 | 54 | batch_messages.append(sample) 55 | assistant_messages.append(assistant) 56 | 
all_image_inputs.append(images) 57 | 58 | texts = [ 59 | processor.apply_chat_template( 60 | messages, tokenize=False, add_generation_prompt=False 61 | ) 62 | for messages in batch_messages 63 | ] 64 | 65 | batch = processor( 66 | text=texts, 67 | images=all_image_inputs if all_image_inputs else None, 68 | max_length=max_length, 69 | truncation=True, 70 | padding=True, 71 | return_tensors="pt", 72 | ) 73 | 74 | input_ids = batch["input_ids"] 75 | labels = input_ids.clone() 76 | 77 | assistant_encodings = [ 78 | processor.tokenizer( 79 | [msg + "" for msg in assistant_message], 80 | add_special_tokens=False, 81 | padding=False, 82 | )["input_ids"] 83 | for assistant_message in assistant_messages 84 | ] 85 | 86 | # Mask out all except the assistant messages 87 | for i, assistant_ids_list in enumerate(assistant_encodings): 88 | seq = input_ids[i].tolist() 89 | assistant_positions: list[int] = [] 90 | for ids in assistant_ids_list: 91 | start_pos = 0 92 | while start_pos < len(seq) - len(ids) + 1: 93 | found = False 94 | for j in range(start_pos, len(seq) - len(ids) + 1): 95 | if seq[j : j + len(ids)] == ids: 96 | assistant_positions.extend(range(j, j + len(ids))) 97 | start_pos = j + len(ids) 98 | found = True 99 | break 100 | if not found: 101 | break 102 | 103 | for pos in range(len(seq)): 104 | if pos not in assistant_positions: 105 | labels[i, pos] = -100 106 | 107 | batch["labels"] = labels 108 | return batch 109 | 110 | return collate_fn 111 | -------------------------------------------------------------------------------- /preprocessing/prompts.py: -------------------------------------------------------------------------------- 1 | OS_ACTIONS = """ 2 | def final_answer(answer: any) -> any: 3 | \"\"\" 4 | Provides a final answer to the given problem. 5 | Args: 6 | answer: The final answer to the problem 7 | \"\"\" 8 | 9 | def move_mouse(self, x: float, y: float) -> str: 10 | \"\"\" 11 | Moves the mouse cursor to the specified coordinates 12 | Args: 13 | x: The x coordinate (horizontal position) 14 | y: The y coordinate (vertical position) 15 | \"\"\" 16 | 17 | def click(x: Optional[float] = None, y: Optional[float] = None) -> str: 18 | \"\"\" 19 | Performs a left-click at the specified normalized coordinates 20 | Args: 21 | x: The x coordinate (horizontal position) 22 | y: The y coordinate (vertical position) 23 | \"\"\" 24 | 25 | def double_click(x: Optional[float] = None, y: Optional[float] = None) -> str: 26 | \"\"\" 27 | Performs a double-click at the specified normalized coordinates 28 | Args: 29 | x: The x coordinate (horizontal position) 30 | y: The y coordinate (vertical position) 31 | \"\"\" 32 | 33 | def type(text: str) -> str: 34 | \"\"\" 35 | Types the specified text at the current cursor position. 36 | Args: 37 | text: The text to type 38 | \"\"\" 39 | 40 | def press(keys: str | list[str]) -> str: 41 | \"\"\" 42 | Presses a keyboard key 43 | Args: 44 | keys: The key or list of keys to press (e.g. "enter", "space", "backspace", "ctrl", etc.). 45 | \"\"\" 46 | 47 | def navigate_back() -> str: 48 | \"\"\" 49 | Goes back to the previous page in the browser. If using this tool doesn't work, just click the button directly. 50 | \"\"\" 51 | 52 | def drag(from_coord: list[float], to_coord: list[float]) -> str: 53 | \"\"\" 54 | Clicks [x1, y1], drags mouse to [x2, y2], then release click. 
55 | Args: 56 | x1: origin x coordinate 57 | y1: origin y coordinate 58 | x2: end x coordinate 59 | y2: end y coordinate 60 | \"\"\" 61 | 62 | def scroll(direction: Literal["up", "down"] = "down", amount: int = 1) -> str: 63 | \"\"\" 64 | Moves the mouse to selected coordinates, then uses the scroll button: this could scroll the page or zoom, depending on the app. DO NOT use scroll to move through linux desktop menus. 65 | Args: 66 | x: The x coordinate (horizontal position) of the element to scroll/zoom, defaults to None to not focus on specific coordinates 67 | y: The y coordinate (vertical position) of the element to scroll/zoom, defaults to None to not focus on specific coordinates 68 | direction: The direction to scroll ("up" or "down"), defaults to "down". For zoom, "up" zooms in, "down" zooms out. 69 | amount: The amount to scroll. A good amount is 1 or 2. 70 | \"\"\" 71 | 72 | def wait(seconds: float) -> str: 73 | \"\"\" 74 | Waits for the specified number of seconds. Very useful in case the prior order is still executing (for example starting very heavy applications like browsers or office apps) 75 | Args: 76 | seconds: Number of seconds to wait, generally 2 is enough. 77 | \"\"\" 78 | """ 79 | 80 | MOBILE_ACTIONS = """ 81 | def navigate_back() -> str: 82 | \"\"\" 83 | Return to home page 84 | \"\"\" 85 | 86 | def open_app(app_name: str) -> str: 87 | \"\"\" 88 | Launches the specified application. 89 | Args: 90 | app_name: the name of the application to launch 91 | \"\"\" 92 | 93 | def swipe(from_coord: list[str], to_coord: list[str]) -> str: 94 | \"\"\" 95 | swipe from 'from_coord' to 'to_coord' 96 | Args: 97 | from_coord: origin coordinates 98 | to_coord: end coordinates 99 | \"\"\" 100 | 101 | def long_press(x: int, y: int) -> str: 102 | \"\"\" 103 | Performs a long-press at the specified coordinates 104 | Args: 105 | x: The x coordinate (horizontal position) 106 | y: The y coordinate (vertical position) 107 | \"\"\" 108 | """ 109 | 110 | OS_SYSTEM_PROMPT = f"""You are a helpful GUI agent. You’ll be given a task and a screenshot of the screen. Complete the task using Python function calls. 111 | 112 | For each step: 113 | • First, to express the thought process guiding your next action and the reasoning behind it. 114 | • Then, use to perform the action. it will be executed in a stateful environment. 115 | 116 | The following functions are exposed to the Python interpreter: 117 | 118 | {OS_ACTIONS} 119 | 120 | 121 | The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist. 122 | """ 123 | 124 | MOBILE_SYSTEM_PROMPT = f"""You are a helpful GUI agent. You’ll be given a task and a screenshot of the screen. Complete the task using Python function calls. 125 | 126 | For each step: 127 | • First, to express the thought process guiding your next action and the reasoning behind it. 128 | • Then, use to perform the action. it will be executed in a stateful environment. 129 | 130 | The following functions are exposed to the Python interpreter: 131 | 132 | 133 | # OS ACTIONS 134 | 135 | {OS_ACTIONS} 136 | 137 | # MOBILE ACTIONS 138 | 139 | {MOBILE_ACTIONS} 140 | 141 | 142 | The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist. 
143 | """ -------------------------------------------------------------------------------- /preprocessing/aguvis_models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Models and data structures for the AGUVIS dataset processing. 3 | Contains all Pydantic models, configuration classes, and data validation logic. 4 | """ 5 | 6 | from typing import List, Optional, Literal, Any, Dict 7 | from pydantic import BaseModel, Field, RootModel, field_validator, model_validator 8 | from PIL import Image 9 | import json 10 | from collections import OrderedDict 11 | 12 | 13 | class ChatMessage(BaseModel): 14 | """Represents a single chat message with role and content.""" 15 | role: Literal["user", "assistant", "system"] 16 | content: str 17 | 18 | @staticmethod 19 | def from_conversation_list(data: list[dict[str, str]]) -> list["ChatMessage"]: 20 | """Convert conversation list to ChatMessage objects.""" 21 | messages = [] 22 | system_added = False 23 | for item in data: 24 | if item["from"] == "system": 25 | if not system_added: 26 | role: Literal["user", "assistant", "system"] = "system" 27 | messages.append(ChatMessage(role=role, content=item["value"])) 28 | system_added = True 29 | elif item["from"] == "human": 30 | role = "user" 31 | messages.append(ChatMessage(role=role, content=item["value"])) 32 | else: 33 | role = "assistant" 34 | messages.append(ChatMessage(role=role, content=item["value"])) 35 | 36 | return messages 37 | 38 | 39 | class ConversationEntry(BaseModel): 40 | """Represents a single entry in a conversation.""" 41 | from_: Literal["system", "human", "gpt"] = Field(alias="from") 42 | value: str 43 | recipient: Optional[str] = None 44 | end_turn: Optional[bool] = None 45 | 46 | def to_chat_message(self) -> ChatMessage: 47 | """Convert conversation entry to ChatMessage.""" 48 | if self.from_ == "system": 49 | role: Literal["user", "assistant", "system"] = "system" 50 | elif self.from_ == "human": 51 | role = "user" 52 | else: 53 | role = "assistant" 54 | return ChatMessage(role=role, content=self.value) 55 | 56 | 57 | class ConversationData(BaseModel): 58 | """Represents conversation data with associated image and conversation entries.""" 59 | image: str 60 | conversations: List[ConversationEntry] 61 | recipient: Optional[str] = None 62 | end_turn: Optional[bool] = None 63 | 64 | @field_validator("image", mode="before") 65 | def validate_image(cls, v): 66 | """Validate and normalize image field.""" 67 | if isinstance(v, list): 68 | if len(v) == 1: 69 | return v[0] 70 | elif len(v) == 2: 71 | return v[1] 72 | else: 73 | raise ValueError("Expected 1 or 2 images, got multiple") 74 | return v 75 | 76 | def to_chat_messages(self) -> list[ChatMessage]: 77 | """Convert all conversation entries to ChatMessage objects.""" 78 | return [conversation.to_chat_message() for conversation in self.conversations] 79 | 80 | 81 | class ConversationDataList(RootModel[List[ConversationData]]): 82 | """Root model for a list of conversation data with validation and optional deduplication.""" 83 | 84 | @classmethod 85 | def from_json_with_deduplication(cls, json_str: str, deduplicate: bool = True) -> "ConversationDataList": 86 | """Create instance from JSON with deduplication control.""" 87 | if deduplicate: 88 | # Use normal validation with deduplication 89 | return cls.model_validate_json(json_str) 90 | else: 91 | data = json.loads(json_str) 92 | conversation_data_list = [ConversationData(**item) for item in data] 93 | 94 | # Create instance directly without 
triggering model validators 95 | instance = cls.__new__(cls) 96 | instance.__dict__.update({'root': conversation_data_list}) 97 | instance.__pydantic_fields_set__ = {'root'} 98 | instance.__pydantic_extra__ = {} 99 | 100 | return instance 101 | 102 | @model_validator(mode="after") 103 | def validate_conversation(self): 104 | """Validate and deduplicate conversations.""" 105 | new_conversations: dict[str, List[ConversationData]] = {} 106 | 107 | # merge image duplicates 108 | for conversation in self.root: 109 | if conversation.image not in new_conversations: 110 | new_conversations[conversation.image] = [conversation] 111 | else: 112 | new_conversations[conversation.image].append(conversation) 113 | 114 | # delete text duplicates 115 | duplicates = 0 116 | for data in new_conversations.values(): 117 | if isinstance(data, list): 118 | index_to_pop = set() 119 | for i in range(len(data) - 1): 120 | for j in range(i + 1, len(data)): 121 | if [c1.model_dump() for c1 in data[i].conversations] == [c2.model_dump() for c2 in data[j].conversations]: 122 | if j not in index_to_pop: 123 | duplicates += 1 124 | index_to_pop.add(j) 125 | for index in sorted(index_to_pop, reverse=True): 126 | data.pop(index) 127 | 128 | # merge conversations for same images 129 | new_data = [] 130 | for data in new_conversations.values(): 131 | for i in range(len(data)): 132 | if i == 0: 133 | new_data.append(data[i]) 134 | else: 135 | new_data[-1].conversations.extend(data[i].conversations) 136 | 137 | self.root = new_data 138 | return self 139 | 140 | 141 | class DataRow(BaseModel): 142 | """Represents a processed data row with images, texts, and source information.""" 143 | images: list[Image.Image] 144 | texts: list[OrderedDict[str, str]] 145 | source: str 146 | 147 | class Config: 148 | arbitrary_types_allowed = True 149 | 150 | @classmethod 151 | def from_chat_messages(cls, messages: list[ChatMessage], image: Image.Image, source: str) -> "DataRow": 152 | """Create DataRow from chat messages and associated image.""" 153 | system, user, assistant = None, None, None 154 | have_system = any(message.role == "system" for message in messages) 155 | texts: list[OrderedDict[str, str]] = [] 156 | images = [image] 157 | chat_messages: OrderedDict[str, str] = OrderedDict() 158 | 159 | for message in messages: 160 | if message.role == "system": 161 | system = message.content 162 | elif message.role == "user": 163 | user = message.content 164 | elif message.role == "assistant": 165 | assistant = message.content 166 | 167 | if have_system and user is not None and assistant is not None and system is not None: 168 | chat_messages["system"] = system 169 | chat_messages["user"] = user 170 | chat_messages["assistant"] = assistant 171 | texts.append(chat_messages) 172 | chat_messages = OrderedDict() 173 | user, assistant = None, None 174 | 175 | elif not have_system and user is not None and assistant is not None: 176 | chat_messages["user"] = user 177 | chat_messages["assistant"] = assistant 178 | texts.append(chat_messages) 179 | chat_messages = OrderedDict() 180 | user, assistant = None, None 181 | 182 | return cls(images=images, texts=texts, source=source) 183 | 184 | def to_model_dump(self) -> dict: 185 | """Convert to dictionary representation.""" 186 | return { 187 | "images": self.images, 188 | "texts": self.texts, 189 | "source": self.source, 190 | } 191 | 192 | 193 | class DatasetConfig(BaseModel): 194 | """Configuration for dataset processing.""" 195 | huggingface_repo_id: str 196 | local_path: str 197 | config_dict: 
List[Dict[str, Any]] 198 | smolagents_repo_id: str 199 | reasoning: bool 200 | deduplicate: bool = False 201 | 202 | 203 | class ProcessingConfig(BaseModel): 204 | """Configuration for processing parameters.""" 205 | subset_name: str 206 | json_path: str 207 | images_folder: str 208 | 209 | @classmethod 210 | def from_config_dict(cls, config: Dict[str, Any]) -> "ProcessingConfig": 211 | """Create ProcessingConfig from configuration dictionary.""" 212 | subset_name = ( 213 | config["json_path"] 214 | .replace(".json", "") 215 | ) 216 | 217 | return cls( 218 | subset_name=subset_name, 219 | json_path=config["json_path"], 220 | images_folder=config["images_folder"] 221 | ) 222 | 223 | 224 | # Mobile action space files configuration 225 | MOBILE_FILES = [ 226 | "android_control.json", 227 | "aitw-l1.json", 228 | "aitw-l2.json", 229 | "aitw-l3.json", 230 | "coat.json", 231 | "amex-l1.json", 232 | "amex-l2.json", 233 | "amex-l3.json", 234 | "gui-odyssey-l1.json", 235 | "gui-odyssey-l2.json", 236 | "gui-odyssey-l3.json", 237 | ] 238 | 239 | # Stage 1 dataset configuration 240 | CONFIG_DICT_STAGE_1 = [ 241 | { 242 | "json_path": "guienv.json", 243 | "images_folder": "guienvs/images/", 244 | }, 245 | { 246 | "json_path": "omniact.json", 247 | "images_folder": "omniact/images/", 248 | }, 249 | { 250 | "json_path": "ricoig16k.json", 251 | "images_folder": "ricoig16k/images/", 252 | }, 253 | { 254 | "json_path": "ricosca.json", 255 | "images_folder": "ricosca/images/", 256 | }, 257 | { 258 | "json_path": "seeclick.json", 259 | "images_folder": "seeclick/seeclick_web_imgs/", 260 | }, 261 | { 262 | "json_path": "webui350k.json", 263 | "images_folder": "webui350k/images/", 264 | }, 265 | { 266 | "json_path": "ui_refexp.json", 267 | "images_folder": "ui_refexp/images/", 268 | }, 269 | { 270 | "json_path": "widget_captioning.json", 271 | "images_folder": "widget_captioning/images/", 272 | }, 273 | ] 274 | 275 | # Stage 2 dataset configuration 276 | CONFIG_DICT_STAGE_2 = [ 277 | { 278 | "json_path": "mind2web-l1.json", 279 | "images_folder": "mind2web/", 280 | }, 281 | { 282 | "json_path": "mind2web-l2.json", 283 | "images_folder": "mind2web/", 284 | }, 285 | { 286 | "json_path": "mind2web-l3.json", 287 | "images_folder": "mind2web/", 288 | }, 289 | { 290 | "json_path": "guiact-web-single.json", 291 | "images_folder": "guiact-web-single/images/", 292 | }, 293 | { 294 | "json_path": "guiact-web-multi-l1.json", 295 | "images_folder": "guiact-web-multi-v2/images", 296 | }, 297 | { 298 | "json_path": "guiact-web-multi-l2.json", 299 | "images_folder": "guiact-web-multi-v2/images", 300 | }, 301 | { 302 | "json_path": "guiact-web-multi-l3.json", 303 | "images_folder": "guiact-web-multi-v2/images", 304 | }, 305 | { 306 | "json_path": "miniwob-l1.json", 307 | "images_folder": "images", 308 | }, 309 | { 310 | "json_path": "miniwob-l2.json", 311 | "images_folder": "images", 312 | }, 313 | { 314 | "json_path": "miniwob-l3.json", 315 | "images_folder": "images", 316 | }, 317 | { 318 | "json_path": "coat.json", 319 | "images_folder": "coat/images/", 320 | }, 321 | { 322 | "json_path": "android_control.json", 323 | "images_folder": "android_control/images/", 324 | }, 325 | { 326 | "json_path": "gui-odyssey-l1.json", 327 | "images_folder": "gui-odyssey/images/", 328 | }, 329 | { 330 | "json_path": "gui-odyssey-l2.json", 331 | "images_folder": "gui-odyssey/images/", 332 | }, 333 | { 334 | "json_path": "gui-odyssey-l3.json", 335 | "images_folder": "gui-odyssey/images/", 336 | }, 337 | { 338 | "json_path": "amex-l1.json", 339 
| "images_folder": "amex/images/", 340 | }, 341 | { 342 | "json_path": "amex-l2.json", 343 | "images_folder": "amex/images/", 344 | }, 345 | { 346 | "json_path": "amex-l3.json", 347 | "images_folder": "amex/images/", 348 | }, 349 | { 350 | "json_path": "aitw-l1.json", 351 | "images_folder": "aitw-v1/images/", 352 | }, 353 | { 354 | "json_path": "aitw-l2.json", 355 | "images_folder": "aitw-v1/images/", 356 | }, 357 | { 358 | "json_path": "aitw-l3.json", 359 | "images_folder": "aitw-v1/images/", 360 | }, 361 | ] 362 | -------------------------------------------------------------------------------- /preprocessing/action_conversion.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from utils.function_parser import FunctionCall 3 | from copy import deepcopy 4 | 5 | # Configure logger for this module 6 | logger = logging.getLogger(__name__) 7 | 8 | """ 9 | ╔══════════════════════════════════════════════════════════════════════════════════════╗ 10 | ║ 🔄 ACTION CONVERSION MAPPINGS 🔄 ║ 11 | ║ Transform Aguvis & PyAutoGUI actions to unified API ║ 12 | ╚══════════════════════════════════════════════════════════════════════════════════════╝ 13 | 14 | 📱 MOBILE ACTIONS (Aguvis → Custom API): 15 | • mobile.home() → navigate_home() 16 | • mobile.open_app(app_name='drupe') → open_app(app_name: str) 17 | • mobile.swipe(from_coord=[x,y], to_coord=[x,y]) → swipe(from_coord: tuple, to_coord: tuple) 18 | • mobile.back() → navigate_back() 19 | • mobile.long_press(x=0.8, y=0.9) → long_press(x: float, y: float) 20 | • mobile.terminate(status='success') → final_answer(answer: str) 21 | • mobile.wait(seconds=3) → wait(seconds: int) 22 | 23 | 💻 DESKTOP ACTIONS (PyAutoGUI → Custom API): 24 | • pyautogui.click(x=0.8, y=0.9) → click(x: float, y: float) 25 | • pyautogui.doubleClick() → double_click() 26 | • pyautogui.rightClick() → right_click() 27 | • pyautogui.hotkey(keys=['ctrl', 'c']) → press(keys: str | list) 28 | • pyautogui.press(keys='enter') → press(keys: str | list) 29 | • pyautogui.moveTo(x=0.04, y=0.4) → move_mouse(x: float, y: float) 30 | • pyautogui.write(message='text') → type(text: str) 31 | • pyautogui.dragTo(from_coord=[x,y], to_coord=[x,y]) → drag(from_coord: tuple, to_coord: tuple) 32 | 33 | 🖱️ SCROLL ACTIONS (Smart Direction Detection): 34 | • pyautogui.scroll(page=-0.1) [negative] → scroll(direction="up", amount=10) 35 | • pyautogui.scroll(page=0.1) [positive] → scroll(direction="down", amount=10) 36 | • pyautogui.hscroll(page=-0.1) [negative] → scroll(direction="left", amount=10) 37 | • pyautogui.hscroll(page=0.1) [positive] → scroll(direction="right", amount=10) 38 | 39 | ✅ COMPLETION ACTIONS: 40 | • answer('text') → final_answer('text') 41 | """ 42 | 43 | 44 | def convert_to_pixel_coordinates(action: FunctionCall, resolution: tuple[int, int]) -> None: 45 | """ 46 | 🎯 Convert normalized coordinates (0.0-1.0) to absolute pixel coordinates. 47 | 48 | Transforms relative coordinates to screen pixels based on the given resolution. 49 | Handles both single coordinates (x, y) and coordinate pairs (from_coord, to_coord). 50 | 51 | Args: 52 | action: FunctionCall object containing coordinate parameters 53 | resolution: Screen resolution as (width, height) in pixels 54 | 55 | Note: Modifies the action object in-place by updating parameter names and values. 
56 | """ 57 | if "arg_0" in action.parameters: 58 | if isinstance(action.parameters["arg_0"], (list, tuple)): 59 | action.parameters["from_coord"] = (int(action.parameters["arg_0"][0] * resolution[0]), int(action.parameters["arg_0"][1] * resolution[1])) 60 | else: 61 | action.parameters["x"] = int(action.parameters["arg_0"] * resolution[0]) 62 | del action.parameters["arg_0"] 63 | if "arg_1" in action.parameters: 64 | if isinstance(action.parameters["arg_1"], (list, tuple)): 65 | action.parameters["to_coord"] = (int(action.parameters["arg_1"][0] * resolution[0]), int(action.parameters["arg_1"][1] * resolution[1])) 66 | else: 67 | action.parameters["y"] = int(action.parameters["arg_1"] * resolution[1]) 68 | del action.parameters["arg_1"] 69 | 70 | def change_argument_name(action: FunctionCall) -> None: 71 | """ 72 | 🔄 Transform generic argument names to semantic parameter names. 73 | 74 | Converts arg_0, arg_1 to meaningful names like 'x', 'y', 'from_coord', 'to_coord'. 75 | Maintains coordinate values as floats for normalized coordinate system. 76 | 77 | Args: 78 | action: FunctionCall object with generic argument names 79 | 80 | Note: Modifies the action object in-place, preserving original coordinate values. 81 | """ 82 | if "arg_0" in action.parameters: 83 | if isinstance(action.parameters["arg_0"], (list, tuple)): 84 | action.parameters["from_coord"] = (float(action.parameters["arg_0"][0]), float(action.parameters["arg_0"][1])) 85 | else: 86 | action.parameters["x"] = float(action.parameters["arg_0"]) 87 | del action.parameters["arg_0"] 88 | if "arg_1" in action.parameters: 89 | if isinstance(action.parameters["arg_1"], (list, tuple)): 90 | action.parameters["to_coord"] = (float(action.parameters["arg_1"][0]), float(action.parameters["arg_1"][1])) 91 | else: 92 | action.parameters["y"] = float(action.parameters["arg_1"]) 93 | del action.parameters["arg_1"] 94 | 95 | 96 | def rename_parameters(action: FunctionCall) -> None: 97 | """ 98 | 🏷️ Standardize parameter names to arg_0, arg_1, arg_2 format. 99 | 100 | Converts named parameters to a generic indexed format while preserving 101 | the original parameter order. This creates a uniform interface for 102 | subsequent processing steps. 103 | 104 | Args: 105 | action: FunctionCall object to standardize parameter names for 106 | 107 | Example: 108 | Before: {"x": 0.5, "y": 0.8} → After: {"arg_0": 0.5, "arg_1": 0.8} 109 | """ 110 | if not action.parameters: 111 | return 112 | 113 | for i, (key, value) in enumerate(deepcopy(action.parameters).items()): 114 | tmp = value 115 | del action.parameters[key] 116 | action.parameters[f"arg_{i}"] = tmp 117 | 118 | 119 | 120 | def action_conversion( 121 | actions: list[FunctionCall], resolution: tuple[int, int] 122 | ) -> list[FunctionCall]: 123 | """ 124 | 🚀 Master conversion function: Transform diverse action formats into unified API. 125 | 126 | This is the main orchestrator that converts actions from different sources 127 | (Aguvis mobile actions, PyAutoGUI desktop actions) into a standardized 128 | action format for consistent processing. 
129 | 130 | Args: 131 | actions: List of FunctionCall objects to convert 132 | resolution: Screen resolution (width, height) for coordinate conversion 133 | 134 | Returns: 135 | List of converted FunctionCall objects with unified naming and structure 136 | 137 | Features: 138 | • 📱 Mobile action normalization (Aguvis → Custom API) 139 | • 💻 Desktop action standardization (PyAutoGUI → Custom API) 140 | • 🎯 Smart coordinate handling (relative ↔ absolute) 141 | • 🖱️ Intelligent scroll direction detection 142 | • ✅ Consistent error handling and validation 143 | """ 144 | for i, action in enumerate(actions): 145 | rename_parameters(action) 146 | 147 | # ═══════════════════════════════════════════════════════════════ 148 | # 📱 MOBILE ACTIONS (Aguvis Framework) 149 | # ═══════════════════════════════════════════════════════════════ 150 | if action.function_name == "mobile.home": 151 | actions[i].function_name = "navigate_home" 152 | 153 | elif action.function_name == "mobile.open_app": 154 | actions[i].function_name = "open_app" 155 | 156 | elif action.function_name == "mobile.swipe": 157 | actions[i].function_name = "swipe" 158 | change_argument_name(actions[i]) 159 | 160 | elif action.function_name == "mobile.back": 161 | actions[i].function_name = "navigate_back" 162 | 163 | elif action.function_name == "mobile.long_press": 164 | actions[i].function_name = "long_press" 165 | change_argument_name(actions[i]) 166 | 167 | elif action.function_name in ["mobile.terminate", "answer"]: 168 | actions[i].function_name = "final_answer" 169 | 170 | elif action.function_name == "mobile.wait": 171 | actions[i].function_name = "wait" 172 | if "arg_0" in actions[i].parameters: 173 | actions[i].parameters["seconds"] = int(actions[i].parameters["arg_0"]) 174 | del actions[i].parameters["arg_0"] 175 | 176 | # ═══════════════════════════════════════════════════════════════ 177 | # 💻 DESKTOP ACTIONS (PyAutoGUI Framework) 178 | # ═══════════════════════════════════════════════════════════════ 179 | elif action.function_name == "pyautogui.click": 180 | actions[i].function_name = "click" 181 | change_argument_name(actions[i]) 182 | 183 | elif action.function_name == "pyautogui.doubleClick": 184 | actions[i].function_name = "double_click" 185 | change_argument_name(actions[i]) 186 | 187 | elif action.function_name == "pyautogui.rightClick": 188 | actions[i].function_name = "right_click" 189 | change_argument_name(actions[i]) 190 | 191 | elif action.function_name in ["pyautogui.hotkey", "pyautogui.press"]: 192 | actions[i].function_name = "press" 193 | if "arg_0" in actions[i].parameters: 194 | actions[i].parameters["keys"] = actions[i].parameters["arg_0"] 195 | del actions[i].parameters["arg_0"] 196 | 197 | elif action.function_name == "pyautogui.moveTo": 198 | actions[i].function_name = "move_mouse" 199 | change_argument_name(actions[i]) 200 | 201 | elif action.function_name == "pyautogui.write": 202 | actions[i].function_name = "type" 203 | 204 | # ────────────────────────────────────────────────────────────── 205 | # 🖱️ SCROLL ACTIONS (Direction Detection) 206 | # ────────────────────────────────────────────────────────────── 207 | elif action.function_name in ["pyautogui.scroll", "pyautogui.hscroll"]: 208 | arg_value = actions[i].parameters["arg_0"] 209 | if arg_value < 0: 210 | if action.function_name == "pyautogui.hscroll": 211 | actions[i].parameters["direction"] = "left" 212 | else: 213 | actions[i].parameters["direction"] = "up" 214 | else: 215 | if action.function_name == "pyautogui.hscroll": 216 | 
actions[i].parameters["direction"] = "right" 217 | else: 218 | actions[i].parameters["direction"] = "down" 219 | del actions[i].parameters["arg_0"] 220 | actions[i].function_name = "scroll" 221 | actions[i].parameters["amount"] = int(abs(arg_value * 100)) 222 | 223 | elif action.function_name == "pyautogui.dragTo": 224 | actions[i].function_name = "drag" 225 | change_argument_name(actions[i]) 226 | 227 | else: 228 | raise ValueError(f"🚫 Unsupported action: {action.function_name}") 229 | 230 | # 💾 Preserve original string representation for debugging 231 | actions[i].original_string = actions[i].to_string() 232 | 233 | return actions 234 | 235 | if __name__ == "__main__": 236 | from utils.function_parser import FunctionCall 237 | 238 | """ 239 | ╔════════════════════════════════════════════════════════════════════════════════╗ 240 | ║ 🧪 TESTING & DEMONSTRATION 🧪 ║ 241 | ║ Comprehensive test suite for action conversion ║ 242 | ╚════════════════════════════════════════════════════════════════════════════════╝ 243 | """ 244 | 245 | # 📋 Complete test dataset covering all supported action types 246 | actions = [ 247 | # ═══════════════════════════════════════════════════════════════ 248 | # 📱 MOBILE ACTIONS (Aguvis Framework) 249 | # ═══════════════════════════════════════════════════════════════ 250 | FunctionCall("mobile.home", {}, "mobile.home()"), 251 | FunctionCall("mobile.open_app", {"app_name": "drupe"}, "mobile.open_app(app_name='drupe')"), 252 | FunctionCall("mobile.swipe", {"from_coord": [0.581, 0.898], "to_coord": [0.601, 0.518]}, "mobile.swipe(from_coord=[0.581,0.898],to_coord=[0.601,0.518])"), 253 | FunctionCall("mobile.back", {}, "mobile.back()"), 254 | FunctionCall("mobile.long_press", {"x": 0.799, "y": 0.911}, "mobile.long_press(x=0.799, y=0.911)"), 255 | FunctionCall("mobile.terminate", {"status": "success"}, "mobile.terminate(status='success')"), 256 | FunctionCall("answer", {"arg_0": "text"}, "answer('text')"), 257 | FunctionCall("mobile.wait", {"seconds": 3}, "mobile.wait(seconds=3)"), 258 | 259 | # ═══════════════════════════════════════════════════════════════ 260 | # 💻 DESKTOP ACTIONS (PyAutoGUI Framework) 261 | # ═══════════════════════════════════════════════════════════════ 262 | FunctionCall("pyautogui.hscroll", {"page": -0.1}, "pyautogui.hscroll(page=-0.1)"), 263 | FunctionCall("pyautogui.scroll", {"page": 0.13}, "pyautogui.scroll(page=0.13)"), 264 | FunctionCall("pyautogui.click", {"x": 0.8102, "y": 0.9463}, "pyautogui.click(x=0.8102, y=0.9463)"), 265 | FunctionCall("pyautogui.doubleClick", {}, "pyautogui.doubleClick()"), 266 | FunctionCall("pyautogui.hotkey", {"keys": ["ctrl", "c"]}, "pyautogui.hotkey(keys=['ctrl','c'])"), 267 | FunctionCall("pyautogui.press", {"keys": "enter"}, "pyautogui.press(keys='enter')"), 268 | FunctionCall("pyautogui.moveTo", {"x": 0.04, "y": 0.405}, "pyautogui.moveTo(x=0.04, y=0.405)"), 269 | FunctionCall("pyautogui.write", {"message": "bread buns"}, "pyautogui.write(message='bread buns')"), 270 | FunctionCall("pyautogui.dragTo", {"from_coord": [0.87, 0.423], "to_coord": [0.8102, 0.9463]}, "pyautogui.dragTo(from_coord=[0.87, 0.423], to_coord=[0.8102, 0.9463])"), 271 | ] 272 | 273 | # 🖥️ Test resolution (Full HD Portrait - typical mobile orientation) 274 | resolution = (1080, 1920) 275 | 276 | logger.info("🔄 BEFORE CONVERSION:") 277 | logger.info("═" * 50) 278 | for action in actions: 279 | logger.info(f" 📋 {action}") 280 | 281 | logger.info(f"\n🚀 AFTER CONVERSION (Resolution: {resolution}):") 282 | logger.info("═" * 50) 283 | converted = 
action_conversion(actions, resolution) 284 | for action in converted: 285 | logger.info(f" ✅ {action}") 286 | -------------------------------------------------------------------------------- /utils/action_space_converter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Flexible Action Space Converter 4 | 5 | A configurable system that allows users to define custom action space mappings 6 | for transforming unified API actions to their own custom action formats. 7 | This enables users to create domain-specific action spaces for training 8 | assistants with different action vocabularies. 9 | """ 10 | 11 | from __future__ import annotations 12 | import logging 13 | from typing import List, Callable, Any, Optional 14 | from copy import deepcopy 15 | from utils.function_parser import FunctionCall 16 | from pydantic import BaseModel 17 | 18 | logging.basicConfig(level=logging.INFO) 19 | 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | class ParameterMapping(BaseModel): 25 | """Defines how to map parameters from source to target function.""" 26 | source_name: str 27 | target_name: str 28 | transform: Optional[Callable[[Any], Any]] = None 29 | default_value: Optional[Any] = None 30 | 31 | 32 | class ActionMapping(BaseModel): 33 | """Defines how to convert one action to another.""" 34 | source_function: str 35 | target_function: str 36 | parameter_mappings: List[ParameterMapping] 37 | custom_transform: Optional[Callable[[FunctionCall], FunctionCall]] = None 38 | description: str = "" 39 | 40 | 41 | class ActionSpaceConverter: 42 | """ 43 | Flexible Action Space Converter 44 | 45 | Allows users to define custom action space mappings to transform 46 | the unified API actions into their own custom action formats. 47 | """ 48 | 49 | def __init__(self, mappings: List[ActionMapping]): 50 | """ 51 | Initialize the converter with action mappings. 52 | 53 | Args: 54 | mappings: List of ActionMapping objects defining the conversions 55 | """ 56 | self.mappings = {mapping.source_function: mapping for mapping in mappings} 57 | self._validate_mappings() 58 | 59 | def _validate_mappings(self) -> None: 60 | """Validate that all mappings are properly configured.""" 61 | for source_func, mapping in self.mappings.items(): 62 | if not mapping.target_function: 63 | raise ValueError(f"Target function not specified for '{source_func}'") 64 | 65 | # Check for duplicate parameter targets 66 | target_params = [pm.target_name for pm in mapping.parameter_mappings] 67 | if len(target_params) != len(set(target_params)): 68 | raise ValueError(f"Duplicate target parameter names in mapping for '{source_func}'") 69 | 70 | def convert_actions(self, actions: List[FunctionCall]) -> List[FunctionCall]: 71 | """ 72 | Convert a list of actions using the defined mappings. 73 | 74 | Args: 75 | actions: List of FunctionCall objects to convert 76 | 77 | Returns: 78 | List of converted FunctionCall objects 79 | 80 | Raises: 81 | ValueError: If an unsupported action is encountered 82 | """ 83 | converted_actions = [] 84 | 85 | for action in actions: 86 | try: 87 | converted_action = self._convert_single_action(action) 88 | converted_actions.append(converted_action) 89 | except Exception as e: 90 | logger.error(f"Failed to convert action '{action.function_name}': {e}") 91 | raise 92 | 93 | return converted_actions 94 | 95 | def _convert_single_action(self, action: FunctionCall) -> FunctionCall: 96 | """ 97 | Convert a single action using the mapping. 
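    For example (illustrative, using the default mappings defined in create_default_unified_to_custom_converter below), click(x=0.5, y=0.8) is converted to touch(x_coord=0.5, y_coord=0.8).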
98 | 99 | Args: 100 | action: FunctionCall object to convert 101 | 102 | Returns: 103 | Converted FunctionCall object 104 | """ 105 | if action.function_name not in self.mappings: 106 | raise ValueError(f"Unsupported action: {action.function_name}") 107 | 108 | mapping = self.mappings[action.function_name] 109 | 110 | if mapping.custom_transform: 111 | return mapping.custom_transform(deepcopy(action)) 112 | 113 | new_parameters = {} 114 | 115 | for param_mapping in mapping.parameter_mappings: 116 | source_value = action.parameters.get(param_mapping.source_name, param_mapping.default_value) 117 | 118 | if source_value is None and param_mapping.default_value is None: 119 | # Skip missing optional parameters 120 | continue 121 | 122 | if param_mapping.transform: 123 | source_value = param_mapping.transform(source_value) 124 | 125 | new_parameters[param_mapping.target_name] = source_value 126 | 127 | # Create new action 128 | new_action = FunctionCall( 129 | function_name=mapping.target_function, 130 | parameters=new_parameters, 131 | original_string="", 132 | description=mapping.description 133 | ) 134 | 135 | # Update the original string representation 136 | new_action.original_string = new_action.to_string() 137 | 138 | return new_action 139 | 140 | def add_mapping(self, mapping: ActionMapping) -> None: 141 | """Add a new action mapping to the converter.""" 142 | self.mappings[mapping.source_function] = mapping 143 | self._validate_mappings() 144 | 145 | def remove_mapping(self, source_function: str) -> None: 146 | """Remove an action mapping from the converter.""" 147 | if source_function in self.mappings: 148 | del self.mappings[source_function] 149 | 150 | def get_supported_actions(self) -> List[str]: 151 | """Get list of supported source actions.""" 152 | return list(self.mappings.keys()) 153 | 154 | def get_mapping_info(self, source_function: str) -> Optional[ActionMapping]: 155 | """Get mapping information for a specific source function.""" 156 | return self.mappings.get(source_function) 157 | 158 | 159 | def create_default_unified_to_custom_converter() -> ActionSpaceConverter: 160 | """ 161 | 🏭 Factory function to create a converter from unified API to a custom action space. 162 | 163 | This demonstrates how to create custom action space mappings. 164 | Users can modify this or create their own mapping configurations. 
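    A minimal sketch of a user-defined converter (the "tap" target action space here is hypothetical; parse_function_call comes from utils.function_parser):

        converter = ActionSpaceConverter([
            ActionMapping(
                source_function="click",
                target_function="tap",
                parameter_mappings=[
                    ParameterMapping(source_name="x", target_name="x"),
                    ParameterMapping(source_name="y", target_name="y"),
                ],
            )
        ])
        converter.convert_actions(parse_function_call("click(x=0.2, y=0.3)"))[0].to_string()
        # -> "tap(x=0.2, y=0.3)"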
165 | 166 | Returns: 167 | ActionSpaceConverter configured with example custom mappings 168 | """ 169 | 170 | # Example custom action space mappings 171 | mappings = [ 172 | # Navigation actions 173 | ActionMapping( 174 | source_function="navigate_home", 175 | target_function="go_home", 176 | parameter_mappings=[], 177 | description="Navigate to home screen" 178 | ), 179 | 180 | ActionMapping( 181 | source_function="navigate_back", 182 | target_function="go_back", 183 | parameter_mappings=[], 184 | description="Navigate back" 185 | ), 186 | 187 | # App interaction 188 | ActionMapping( 189 | source_function="open_app", 190 | target_function="launch_application", 191 | parameter_mappings=[ 192 | ParameterMapping(source_name="arg_0", target_name="application_name") 193 | ], 194 | description="Launch an application" 195 | ), 196 | 197 | # Touch interactions 198 | ActionMapping( 199 | source_function="click", 200 | target_function="touch", 201 | parameter_mappings=[ 202 | ParameterMapping(source_name="x", target_name="x_coord"), 203 | ParameterMapping(source_name="y", target_name="y_coord") 204 | ], 205 | description="Touch screen at coordinates" 206 | ), 207 | 208 | ActionMapping( 209 | source_function="long_press", 210 | target_function="long_touch", 211 | parameter_mappings=[ 212 | ParameterMapping(source_name="x", target_name="x_coord"), 213 | ParameterMapping(source_name="y", target_name="y_coord"), 214 | ParameterMapping(source_name="duration", target_name="hold_time", default_value=1.0) 215 | ], 216 | description="Long touch screen at coordinates" 217 | ), 218 | 219 | # Gesture actions 220 | ActionMapping( 221 | source_function="swipe", 222 | target_function="gesture_swipe", 223 | parameter_mappings=[ 224 | ParameterMapping(source_name="from_coord", target_name="start_point"), 225 | ParameterMapping(source_name="to_coord", target_name="end_point"), 226 | ], 227 | description="Swipe gesture between two points" 228 | ), 229 | 230 | # Scroll actions with custom direction mapping 231 | ActionMapping( 232 | source_function="scroll", 233 | target_function="scroll_view", 234 | parameter_mappings=[ 235 | ParameterMapping( 236 | source_name="direction", 237 | target_name="scroll_direction", 238 | transform=lambda x: {"up": "north", "down": "south", "left": "west", "right": "east"}.get(x, x) 239 | ), 240 | ParameterMapping(source_name="amount", target_name="scroll_distance") 241 | ], 242 | description="Scroll view in specified direction" 243 | ), 244 | 245 | # Input actions 246 | ActionMapping( 247 | source_function="type", 248 | target_function="input_text", 249 | parameter_mappings=[ 250 | ParameterMapping(source_name="text", target_name="content"), 251 | ], 252 | description="Input text" 253 | ), 254 | 255 | ActionMapping( 256 | source_function="press", 257 | target_function="key_press", 258 | parameter_mappings=[ 259 | ParameterMapping(source_name="keys", target_name="key_combination") 260 | ], 261 | description="Press key combination" 262 | ), 263 | 264 | # Mouse actions 265 | ActionMapping( 266 | source_function="move_mouse", 267 | target_function="cursor_move", 268 | parameter_mappings=[ 269 | ParameterMapping(source_name="x", target_name="x_position"), 270 | ParameterMapping(source_name="y", target_name="y_position") 271 | ], 272 | description="Move cursor to position" 273 | ), 274 | 275 | ActionMapping( 276 | source_function="double_click", 277 | target_function="double_touch", 278 | parameter_mappings=[ 279 | ParameterMapping(source_name="x", target_name="x_coord", default_value=0.5), 
280 | ParameterMapping(source_name="y", target_name="y_coord", default_value=0.5) 281 | ], 282 | description="Double touch at coordinates" 283 | ), 284 | 285 | ActionMapping( 286 | source_function="right_click", 287 | target_function="context_menu", 288 | parameter_mappings=[ 289 | ParameterMapping(source_name="x", target_name="x_coord", default_value=0.5), 290 | ParameterMapping(source_name="y", target_name="y_coord", default_value=0.5) 291 | ], 292 | description="Open context menu" 293 | ), 294 | 295 | ActionMapping( 296 | source_function="drag", 297 | target_function="drag_and_drop", 298 | parameter_mappings=[ 299 | ParameterMapping(source_name="from_coord", target_name="start_position"), 300 | ParameterMapping(source_name="to_coord", target_name="end_position") 301 | ], 302 | description="Drag and drop operation" 303 | ), 304 | 305 | # Timing and completion 306 | ActionMapping( 307 | source_function="wait", 308 | target_function="pause", 309 | parameter_mappings=[ 310 | ParameterMapping(source_name="seconds", target_name="duration") 311 | ], 312 | description="Pause execution for specified duration" 313 | ), 314 | 315 | ActionMapping( 316 | source_function="final_answer", 317 | target_function="complete_task", 318 | parameter_mappings=[ 319 | ParameterMapping(source_name="arg_0", target_name="answer") 320 | ], 321 | description="Complete task with result" 322 | ), 323 | ] 324 | 325 | return ActionSpaceConverter(mappings) 326 | 327 | def convert_assistant(chat_message: dict, converter: ActionSpaceConverter) -> dict: 328 | """ 329 | Convert function calls in assistant messages to sentence format. 330 | 331 | Args: 332 | chat_message: Dictionary with format {"user": "...", "assistant": "..."} 333 | 334 | Returns: 335 | Updated chat message with function calls converted to sentences 336 | """ 337 | from utils.function_parser import parse_function_call 338 | 339 | if "assistant" not in chat_message: 340 | return chat_message 341 | 342 | assistant_message = chat_message["assistant"] 343 | 344 | # Parse function calls from the assistant message 345 | old_function_calls = parse_function_call(assistant_message) 346 | new_function_calls = converter.convert_actions(old_function_calls) 347 | 348 | # Replace each function call with its sentence format 349 | updated_message = assistant_message 350 | for new_function_call, old_function_call in zip(new_function_calls, old_function_calls): 351 | updated_message = updated_message.replace(old_function_call.to_string(), new_function_call.to_string()) 352 | 353 | # Return updated chat message 354 | chat_message["assistant"] = updated_message 355 | return chat_message 356 | 357 | 358 | # Testing and demonstration 359 | if __name__ == "__main__": 360 | 361 | logger.info("🧪 Testing Action Space Converter") 362 | logger.info("=" * 50) 363 | 364 | # Test with default custom converter 365 | converter = create_default_unified_to_custom_converter() 366 | 367 | chat_history = [ 368 | {"user": "Click on the home button", "assistant": "I'll click on the home button for you. navigate_home()"}, 369 | {"user": "Type hello world", "assistant": "I'll type that text for you. type(text='hello world')"}, 370 | {"user": "Click at coordinates 0.5, 0.8", "assistant": "I'll click at those coordinates. click(x=0.5, y=0.8)"}, 371 | {"user": "Scroll up by 10 units", "assistant": "I'll scroll up for you. 
scroll(direction='up', amount=10)"} 372 | ] 373 | 374 | # Test the function with chat history 375 | logger.info("🧪 Testing Chat Message Function Call Conversion") 376 | logger.info("=" * 60) 377 | 378 | for i, chat_msg in enumerate(chat_history): 379 | logger.info(f"\n📩 Chat Message {i+1}:") 380 | logger.info(f" User: {chat_msg['user']}") 381 | logger.info(f" Original Assistant: {chat_msg['assistant']}") 382 | 383 | converted_msg = convert_assistant(chat_msg, converter) 384 | logger.info(f" Converted Assistant: {converted_msg['assistant']}") -------------------------------------------------------------------------------- /utils/function_parser.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Function parser for extracting function names, parameter names, and values from string function calls. 4 | Supports both mobile and pyautogui function patterns. 5 | """ 6 | 7 | import re 8 | from typing import Dict, List, Tuple, Any 9 | from collections import OrderedDict 10 | from pydantic import BaseModel 11 | 12 | class FunctionCall(BaseModel): 13 | """Represents a parsed function call with its parameters.""" 14 | function_name: str 15 | parameters: Dict[str, Any] 16 | original_string: str 17 | description: str = "" 18 | 19 | def to_string(self) -> str: 20 | """ 21 | Reconstruct the function call string from the parsed data. 22 | 23 | Returns: 24 | String representation of the function call 25 | 26 | Examples: 27 | >>> call = FunctionCall("mobile.wait", {"seconds": 3}, "mobile.wait(seconds=3)") 28 | >>> call.to_string() 29 | "mobile.wait(seconds=3)" 30 | 31 | >>> call = FunctionCall("function", {"arg_0": 1, "arg_1": 2, "x": 0.5}, "function(1, 2, x=0.5)") 32 | >>> call.to_string() 33 | "function(1, 2, x=0.5)" 34 | """ 35 | if not self.parameters: 36 | return f"{self.function_name}()" 37 | 38 | # Separate positional and named arguments 39 | positional_args = [] 40 | named_args = [] 41 | 42 | for name, value in self.parameters.items(): 43 | if name.startswith("arg_"): 44 | # Positional argument 45 | positional_args.append((int(name.split("_")[1]), value)) 46 | else: 47 | # kwargs 48 | named_args.append((name, value)) 49 | 50 | # Sort positional arguments by index 51 | positional_args.sort(key=lambda x: x[0]) 52 | 53 | # Build parameter string 54 | param_parts = [] 55 | 56 | # Add positional arguments 57 | for _, value in positional_args: 58 | param_parts.append(self._value_to_string(value)) 59 | 60 | # Add named arguments 61 | for name, value in named_args: 62 | param_parts.append(f"{name}={self._value_to_string(value)}") 63 | 64 | return f"{self.function_name}({', '.join(param_parts)})" 65 | 66 | def _value_to_string(self, value: Any) -> str: 67 | """ 68 | Convert a value to its string representation for function calls. 
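        Illustrative examples (assuming an existing FunctionCall instance named call):

            call._value_to_string('hello')     ->  "'hello'"
            call._value_to_string([0.5, 0.6])  ->  "[0.5, 0.6]"
            call._value_to_string(True)        ->  "true"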
69 | 70 | Args: 71 | value: The value to convert 72 | 73 | Returns: 74 | String representation of the value 75 | """ 76 | if isinstance(value, str): 77 | # Quote strings 78 | return f"'{value}'" 79 | elif isinstance(value, (list, tuple)): 80 | # Convert lists/tuples to string representation 81 | items = [self._value_to_string(item) for item in value] 82 | return f"[{', '.join(items)}]" 83 | elif isinstance(value, dict): 84 | # Convert dictionaries to string representation 85 | items = [f"'{k}': {self._value_to_string(v)}" for k, v in value.items()] 86 | return f"{{{', '.join(items)}}}" 87 | elif isinstance(value, bool): 88 | # Convert booleans to lowercase 89 | return str(value).lower() 90 | elif value is None: 91 | return "None" 92 | else: 93 | # Numbers and other types 94 | return str(value) 95 | 96 | 97 | def parse_function_call(function_string: str, pattern_to_match: list[str] = []) -> List[FunctionCall]: 98 | """ 99 | Parse a function call string and extract all function calls found. 100 | 101 | Args: 102 | function_string: String representation of function calls 103 | 104 | Returns: 105 | List of FunctionCall objects with parsed information 106 | 107 | Examples: 108 | >>> parse_function_call("mobile.wait(seconds=3)") 109 | [FunctionCall(function_name='wait', parameters={'seconds': 3}, ...)] 110 | 111 | >>> parse_function_call("mobile. wait(seconds=3)") 112 | [FunctionCall(function_name='wait', parameters={'seconds': 3}, ...)] 113 | 114 | >>> parse_function_call("mobile.wait(seconds=3) mobile.home()") 115 | [FunctionCall(function_name='wait', parameters={'seconds': 3}, ...), FunctionCall(function_name='home', parameters={}, ...)] 116 | """ 117 | # Remove any leading/trailing whitespace 118 | function_string = function_string.strip() 119 | 120 | # Pattern to match function calls with parameters 121 | # Matches: function_name(param1=value1, param2=value2, ...) 122 | # Can have any characters before the function call, extracts just the function name 123 | pattern = r'.*?([a-zA-Z_][a-zA-Z0-9_.]*)\(([^)]*)\)' 124 | 125 | matches = re.findall(pattern, function_string) 126 | if not matches: 127 | # No valid function calls found in: {function_string} 128 | return [] 129 | 130 | results = [] 131 | for match in matches: 132 | function_name = match[0] 133 | params_string = match[1] 134 | 135 | if pattern_to_match and all(pattern not in function_name for pattern in pattern_to_match): 136 | continue 137 | 138 | # Parse parameters 139 | parameters = parse_parameters(params_string) 140 | 141 | # Create the original string for this specific function call 142 | original_string = f"{function_name}({params_string})" 143 | 144 | results.append(FunctionCall( 145 | function_name=function_name, 146 | parameters=parameters, 147 | original_string=original_string 148 | )) 149 | 150 | return results 151 | 152 | 153 | def parse_parameters(params_string: str) -> Dict[str, Any]: 154 | """ 155 | Parse parameter string and extract parameter names and values. 
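    Commas inside brackets are not treated as separators (see split_parameters below), so nested values survive intact. Illustrative example:

        parse_parameters("from_coord=[0.1, 0.2], to_coord=[0.3, 0.4]")
        -> {'from_coord': [0.1, 0.2], 'to_coord': [0.3, 0.4]}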
156 | 157 | Args: 158 | params_string: String containing parameters (e.g., "x=0.5, y=0.6, text='hello'") 159 | 160 | Returns: 161 | Dictionary mapping parameter names to their values 162 | 163 | Examples: 164 | >>> parse_parameters("x=0.5, y=0.6") 165 | {'x': 0.5, 'y': 0.6} 166 | 167 | >>> parse_parameters("app_name='drupe'") 168 | {'app_name': 'drupe'} 169 | 170 | >>> parse_parameters("'text'") 171 | {'arg_0': 'text'} 172 | 173 | >>> parse_parameters("1, 3, 4") 174 | {'arg_0': 1, 'arg_1': 3, 'arg_2': 4} 175 | 176 | >>> parse_parameters("arg1, arg2, x=0.5") 177 | {'arg_0': 'arg1', 'arg_1': 'arg2', 'x': 0.5} 178 | """ 179 | if not params_string.strip(): 180 | return {} 181 | 182 | parameters = OrderedDict() 183 | 184 | # Split by commas, but be careful with commas inside quotes or brackets 185 | param_parts = split_parameters(params_string) 186 | 187 | positional_index = 0 188 | 189 | for part in param_parts: 190 | part = part.strip() 191 | if not part: 192 | continue 193 | 194 | # Parse individual parameter 195 | name, value = parse_single_parameter(part) 196 | 197 | # For positional arguments, use index-based naming 198 | if name.startswith("arg_"): 199 | name = f"arg_{positional_index}" 200 | positional_index += 1 201 | 202 | parameters[name] = value 203 | 204 | return parameters 205 | 206 | 207 | def split_parameters(params_string: str) -> List[str]: 208 | """ 209 | Split parameter string by commas, respecting quotes and brackets. 210 | 211 | Args: 212 | params_string: String containing parameters 213 | 214 | Returns: 215 | List of individual parameter strings 216 | """ 217 | parts = [] 218 | current_part = "" 219 | paren_count = 0 220 | bracket_count = 0 221 | brace_count = 0 222 | in_quotes = False 223 | quote_char = None 224 | 225 | for char in params_string: 226 | if char in ['"', "'"] and (not in_quotes or char == quote_char): 227 | if not in_quotes: 228 | in_quotes = True 229 | quote_char = char 230 | else: 231 | in_quotes = False 232 | quote_char = None 233 | elif not in_quotes: 234 | if char == '(': 235 | paren_count += 1 236 | elif char == ')': 237 | paren_count -= 1 238 | elif char == '[': 239 | bracket_count += 1 240 | elif char == ']': 241 | bracket_count -= 1 242 | elif char == '{': 243 | brace_count += 1 244 | elif char == '}': 245 | brace_count -= 1 246 | elif char == ',' and paren_count == 0 and bracket_count == 0 and brace_count == 0: 247 | parts.append(current_part.strip()) 248 | current_part = "" 249 | continue 250 | 251 | current_part += char 252 | 253 | if current_part.strip(): 254 | parts.append(current_part.strip()) 255 | 256 | return parts 257 | 258 | 259 | def parse_single_parameter(param_string: str) -> Tuple[str, Any]: 260 | """ 261 | Parse a single parameter string into name and value. 
262 | 263 | Args: 264 | param_string: String like "x=0.5" or "app_name='drupe'" or just "value" 265 | 266 | Returns: 267 | Tuple of (parameter_name, parameter_value) 268 | 269 | Examples: 270 | >>> parse_single_parameter("x=0.5") 271 | ('x', 0.5) 272 | 273 | >>> parse_single_parameter("app_name='drupe'") 274 | ('app_name', 'drupe') 275 | 276 | >>> parse_single_parameter("'text'") 277 | ('arg_0', 'text') 278 | 279 | >>> parse_single_parameter("123") 280 | ('arg_0', 123) 281 | 282 | >>> parse_single_parameter("3") 283 | ('arg_0', 3) 284 | """ 285 | # Pattern to match parameter name and value 286 | pattern = r'^([a-zA-Z_][a-zA-Z0-9_]*)\s*=\s*(.+)$' 287 | 288 | match = re.match(pattern, param_string) 289 | if match: 290 | # Named parameter 291 | param_name = match.group(1) 292 | param_value_str = match.group(2).strip() 293 | param_value = parse_value(param_value_str) 294 | return param_name, param_value 295 | else: 296 | # Positional parameter - treat as unnamed argument 297 | param_value = parse_value(param_string) 298 | return "arg_0", param_value 299 | 300 | 301 | def parse_value(value_string: str) -> Any: 302 | """ 303 | Parse a value string into appropriate Python type. 304 | 305 | Args: 306 | value_string: String representation of a value 307 | 308 | Returns: 309 | Parsed value (int, float, str, list, etc.) 310 | 311 | Examples: 312 | >>> parse_value("3") 313 | 3 314 | 315 | >>> parse_value("3.14") 316 | 3.14 317 | 318 | >>> parse_value("'hello'") 319 | 'hello' 320 | 321 | >>> parse_value("[0.581, 0.898]") 322 | [0.581, 0.898] 323 | """ 324 | value_string = value_string.strip() 325 | 326 | # String values (quoted) 327 | if (value_string.startswith("'") and value_string.endswith("'")) or \ 328 | (value_string.startswith('"') and value_string.endswith('"')): 329 | return value_string[1:-1] 330 | 331 | # List values 332 | if value_string.startswith('[') and value_string.endswith(']'): 333 | return parse_list(value_string) 334 | 335 | # Dictionary values 336 | if value_string.startswith('{') and value_string.endswith('}'): 337 | return parse_dict(value_string) 338 | 339 | # Boolean values 340 | if value_string.lower() in ['true', 'false']: 341 | return value_string.lower() == 'true' 342 | 343 | # None value 344 | if value_string.lower() == 'none': 345 | return None 346 | 347 | # Numeric values 348 | try: 349 | # Try integer first 350 | if '.' not in value_string: 351 | return int(value_string) 352 | else: 353 | return float(value_string) 354 | except ValueError: 355 | # If it's not a number, return as string (remove quotes if present) 356 | if value_string.startswith("'") and value_string.endswith("'"): 357 | return value_string[1:-1] 358 | elif value_string.startswith('"') and value_string.endswith('"'): 359 | return value_string[1:-1] 360 | else: 361 | return value_string 362 | 363 | 364 | def parse_list(list_string: str) -> List[Any]: 365 | """ 366 | Parse a list string into a Python list. 
367 | 368 | Args: 369 | list_string: String like "[0.581, 0.898]" 370 | 371 | Returns: 372 | List of parsed values 373 | 374 | Examples: 375 | >>> parse_list("[0.581, 0.898]") 376 | [0.581, 0.898] 377 | """ 378 | # Remove outer brackets 379 | content = list_string[1:-1].strip() 380 | if not content: 381 | return [] 382 | 383 | # Split by commas, respecting nested structures 384 | parts = split_parameters(content) 385 | 386 | return [parse_value(part.strip()) for part in parts] 387 | 388 | 389 | def parse_dict(dict_string: str) -> Dict[str, Any]: 390 | """ 391 | Parse a dictionary string into a Python dict. 392 | 393 | Args: 394 | dict_string: String like "{'key': 'value'}" 395 | 396 | Returns: 397 | Dictionary of parsed key-value pairs 398 | """ 399 | # Remove outer braces 400 | content = dict_string[1:-1].strip() 401 | if not content: 402 | return {} 403 | 404 | # Split by commas, respecting nested structures 405 | parts = split_parameters(content) 406 | 407 | result = {} 408 | for part in parts: 409 | part = part.strip() 410 | if ':' in part: 411 | key_str, value_str = part.split(':', 1) 412 | key = parse_value(key_str.strip()) 413 | value = parse_value(value_str.strip()) 414 | result[key] = value 415 | 416 | return result 417 | 418 | 419 | def parse_multiple_functions(function_strings: List[str]) -> List[FunctionCall]: 420 | """ 421 | Parse multiple function call strings. 422 | 423 | Args: 424 | function_strings: List of function call strings 425 | 426 | Returns: 427 | List of FunctionCall objects 428 | """ 429 | results = [] 430 | for func_str in function_strings: 431 | try: 432 | result_list = parse_function_call(func_str) 433 | results.extend(result_list) 434 | except Exception as e: 435 | print(f"Warning: Could not parse function call '{func_str}': {e}") 436 | continue 437 | 438 | return results 439 | 440 | 441 | def extract_function_calls_from_text(text: str) -> List[FunctionCall]: 442 | """ 443 | Extract and parse function calls from a text block. 
444 | 445 | Args: 446 | text: Text containing function calls 447 | 448 | Returns: 449 | List of FunctionCall objects 450 | """ 451 | # Pattern to find function calls in text 452 | # Matches: function_name(param1=value1, param2=value2) 453 | pattern = r'[a-zA-Z_][a-zA-Z0-9_.]*\([^)]*\)' 454 | 455 | matches = re.findall(pattern, text) 456 | return parse_multiple_functions(matches) 457 | 458 | 459 | # Example usage and testing 460 | if __name__ == "__main__": 461 | test_cases = [ 462 | "mobile.home()", 463 | "mobile.open_app(app_name='drupe')", 464 | "mobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])", 465 | "mobile.back()", 466 | "mobile.long_press(x=0.799, y=0.911)", 467 | "mobile.terminate(status='success')", 468 | "answer('text')", 469 | "pyautogui.hscroll(page=-0.1)", 470 | "pyautogui.scroll(page=-0.1)", 471 | "pyautogui.scroll(0.13)", 472 | "pyautogui.click(x=0.8102, y=0.9463)", 473 | "pyautogui.hotkey(keys=['ctrl', 'c'])", 474 | "pyautogui.doubleClick()", 475 | "pyautogui.press(keys='enter')", 476 | "pyautogui.press(keys=['enter'])", 477 | "pyautogui.moveTo(x=0.04, y=0.405)", 478 | "pyautogui.write(message='bread buns')", 479 | "pyautogui.dragTo(x=0.8102, y=0.9463)", 480 | "mobile.wait(seconds=3)\nmobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])", 481 | # Additional test cases for multiple positional arguments 482 | "function(arg1, arg2, arg3)", 483 | "function('hello', 123, x=0.5)", 484 | "function(arg1, arg2, named_param='value')", 485 | "function(1, 2, 3, 4, 5)", 486 | "function('a', 'b', 'c', x=1, y=2)", 487 | ] 488 | 489 | print("Testing function parser:") 490 | print("=" * 50) 491 | 492 | for test_case in test_cases: 493 | try: 494 | results = parse_function_call(test_case) 495 | print(f"✓ {test_case}") 496 | for result in results: 497 | print(f" Function: {result.function_name}") 498 | print(f" Parameters: {result.parameters}") 499 | print() 500 | except Exception as e: 501 | print(f"✗ {test_case}") 502 | print(f" Error: {e}") 503 | print() 504 | 505 | # Test extracting from text 506 | print("Testing text extraction:") 507 | print("=" * 50) 508 | 509 | sample_text = """ 510 | mobile.wait(seconds=3) 511 | mobile.open_app(app_name='drupe') 512 | pyautogui.click(x=0.8102, y=0.9463) 513 | pyautogui.write(message='bread buns') 514 | """ 515 | 516 | extracted = extract_function_calls_from_text(sample_text) 517 | for func_call in extracted: 518 | print(f"Found: {func_call.function_name} with params: {func_call.parameters}") 519 | 520 | # Test reconstruction 521 | print("\nTesting function call reconstruction:") 522 | print("=" * 50) 523 | 524 | reconstruction_tests = [ 525 | "mobile.wait(seconds=3)", 526 | "mobile.home()", 527 | "mobile.open_app(app_name='drupe')", 528 | "mobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])", 529 | "answer('text')", 530 | "pyautogui.scroll(0.13)", 531 | "pyautogui.click(x=0.8102, y=0.9463)", 532 | "pyautogui.hotkey(keys=['ctrl', 'c'])", 533 | "function(1, 2, 3)", 534 | "function('hello', 123, x=0.5, y=0.8)", 535 | "function([1, 3], 'arg2', named_param='value')", 536 | ] 537 | 538 | for test_case in reconstruction_tests: 539 | parsed_list = parse_function_call(test_case) 540 | for parsed in parsed_list: 541 | reconstructed = parsed.to_string() 542 | print(f"Original: {test_case}") 543 | print(f"Reconstructed: {reconstructed}") 544 | print(f"Match: {test_case == reconstructed}") 545 | assert test_case == reconstructed 546 | print() -------------------------------------------------------------------------------- 
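A quick usage sketch of the parser above (assuming the repository root is on `PYTHONPATH`): positional values are stored as `arg_0`, `arg_1`, … and `to_string()` re-emits them in order, which is what makes the parse → edit → reconstruct round-trip exercised in the `__main__` tests possible.

```python
from utils.function_parser import parse_function_call

# Mixed positional and keyword arguments in a single call string.
call = parse_function_call("function('hello', 123, x=0.5)")[0]
print(call.parameters)   # {'arg_0': 'hello', 'arg_1': 123, 'x': 0.5}

# Edit a parameter in place and re-serialize the call.
call.parameters["x"] = 0.75
print(call.to_string())  # function('hello', 123, x=0.75)
```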
/preprocessing/aguvis_processor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | AGUVIS Dataset Processor Module 4 | 5 | Downloading, processing, and uploading the aguvis-stage1/2 datasets. 6 | Downloads from huggingface.co/datasets/xlangai/aguvis-stage1/2 and uploads to smolagents/aguvis-stage-1/2 7 | """ 8 | 9 | import logging 10 | import re 11 | import gc 12 | import os 13 | import zipfile 14 | import tarfile 15 | from copy import deepcopy 16 | from pathlib import Path 17 | from typing import Any, Dict, List, Generator, Optional 18 | from itertools import islice 19 | from multiprocessing import Pool 20 | from collections import defaultdict 21 | 22 | from tqdm import tqdm 23 | from datasets import Dataset, concatenate_datasets, get_dataset_config_names 24 | from dotenv import load_dotenv 25 | from huggingface_hub import login, snapshot_download 26 | from PIL import Image 27 | 28 | from utils.function_parser import parse_function_call 29 | from preprocessing.prompts import OS_SYSTEM_PROMPT, MOBILE_SYSTEM_PROMPT 30 | from preprocessing.action_conversion import action_conversion 31 | from preprocessing.aguvis_models import ( 32 | ConversationDataList, 33 | ConversationData, 34 | ChatMessage, 35 | DataRow, 36 | DatasetConfig, 37 | ProcessingConfig, 38 | MOBILE_FILES, 39 | CONFIG_DICT_STAGE_1, 40 | CONFIG_DICT_STAGE_2 41 | ) 42 | 43 | # Configure logger for this module 44 | logger = logging.getLogger(__name__) 45 | 46 | def huggingface_authenticate(): 47 | """Authenticate with HuggingFace Hub using token.""" 48 | hf_token = os.getenv("HF_TOKEN") 49 | if hf_token: 50 | logger.info("Authenticating with HuggingFace Hub using token...") 51 | login(token=hf_token) 52 | else: 53 | raise ValueError("HF_TOKEN environment variable not set.") 54 | 55 | class DatasetDownloader: 56 | """Handles dataset downloading and extraction operations.""" 57 | 58 | def download_dataset(self, repo_id: str, local_dir: str) -> str: 59 | """Download the dataset using snapshot_download.""" 60 | logger.info(f"Downloading dataset from {repo_id}...") 61 | local_path = snapshot_download( 62 | repo_id=repo_id, local_dir=local_dir, repo_type="dataset" 63 | ) 64 | logger.info(f"Dataset downloaded to: {local_path}") 65 | return local_path 66 | 67 | def extract_zip_files(self, dataset_path: str): 68 | """Extract all zip files found in the dataset directory, but only if not already extracted.""" 69 | logger.info("Extracting zip files...") 70 | dataset_dir = Path(dataset_path) 71 | 72 | for zip_file in dataset_dir.rglob("*.zip"): 73 | extract_dir = zip_file.parent / zip_file.stem 74 | if extract_dir.exists() and any(extract_dir.iterdir()): 75 | logger.info( 76 | f"Skipping extraction for {zip_file} (already extracted at {extract_dir})" 77 | ) 78 | continue 79 | 80 | logger.info(f"Extracting: {zip_file}") 81 | with zipfile.ZipFile(zip_file, "r") as zip_ref: 82 | zip_ref.extractall(extract_dir) 83 | logger.info(f"Extracted to: {extract_dir}") 84 | 85 | def extract_tar_parts_grouped(self, dataset_path: str): 86 | """ 87 | Finds all .tar.gz.part_* groups, merges them, and extracts them into directories 88 | named after their common prefix. 
89 | """ 90 | dataset_dir = Path(dataset_path) 91 | part_files = list(dataset_dir.glob("*.tar.gz.part_*")) 92 | 93 | if not part_files: 94 | logger.info("No split .tar.gz.part_* files found.") 95 | return 96 | 97 | # Group part files by prefix 98 | groups = defaultdict(list) 99 | for part in part_files: 100 | prefix = part.name.split(".tar.gz.part_")[0] 101 | groups[prefix].append(part) 102 | 103 | for prefix, parts in groups.items(): 104 | parts = sorted(parts) # Ensure correct order 105 | merged_tar_path = dataset_dir / f"{prefix}.tar.gz" 106 | extract_dir = dataset_dir / prefix 107 | 108 | if extract_dir.exists() and any(extract_dir.iterdir()): 109 | logger.info( 110 | f"Skipping extraction for '{prefix}' (already extracted at {extract_dir})" 111 | ) 112 | continue 113 | 114 | # Merge parts 115 | CHUNK_SIZE = 1024 * 1024 116 | logger.info(f"Merging parts for '{prefix}'...") 117 | with open(merged_tar_path, "wb") as outfile: 118 | for part in parts: 119 | logger.info(f" Adding: {part.name}") 120 | with open(part, "rb") as infile: 121 | while chunk := infile.read(CHUNK_SIZE): 122 | outfile.write(chunk) 123 | 124 | logger.info(f"Merged to: {merged_tar_path}") 125 | 126 | # Extract 127 | logger.info(f"Extracting to: {extract_dir}") 128 | with tarfile.open(merged_tar_path, "r:gz") as tar: 129 | tar.extractall(path=extract_dir) 130 | logger.info(f"Done extracting '{prefix}'\n") 131 | 132 | @staticmethod 133 | def discover_dataset_config(dataset_path: str, config_dict: List[Dict[str, Any]]) -> List[ProcessingConfig]: 134 | """Discover dataset configuration by scanning the data directory.""" 135 | dataset_dir = Path(dataset_path) 136 | train_dir = dataset_dir 137 | 138 | if not train_dir.exists(): 139 | raise FileNotFoundError(f"Train directory not found: {train_dir}") 140 | 141 | configs = [] 142 | processed_splits = set() 143 | 144 | # Find all JSON files in the train directory 145 | for config in config_dict: 146 | processing_config = ProcessingConfig.from_config_dict(config) 147 | 148 | # Skip if we already processed this split 149 | if processing_config.subset_name in processed_splits: 150 | continue 151 | 152 | configs.append(processing_config) 153 | processed_splits.add(processing_config.subset_name) 154 | logger.info( 155 | f"Discovered config: {processing_config.subset_name} -> {processing_config.images_folder}" 156 | ) 157 | 158 | return configs 159 | 160 | 161 | 162 | 163 | class SampleProcessor: 164 | """Processes and converts messages to different formats.""" 165 | 166 | @staticmethod 167 | def load_image_from_folder(images_folder: Path, img_path: str) -> Image.Image: 168 | """Load images from the specified folder.""" 169 | full_path = images_folder / img_path 170 | img = Image.open(full_path) 171 | new_img = img.copy() 172 | img.close() 173 | return new_img 174 | 175 | @staticmethod 176 | def convert_to_code_agent_format(messages: list[ChatMessage], json_path: str, reasoning: bool): 177 | """Convert messages to code agent format.""" 178 | for i, message in enumerate(messages): 179 | content = message.content 180 | 181 | if message.role == "system": 182 | if json_path in MOBILE_FILES: 183 | content = MOBILE_SYSTEM_PROMPT 184 | else: 185 | content = OS_SYSTEM_PROMPT 186 | 187 | if message.role == "user": 188 | content = content.replace("\n", "").replace("", "") 189 | 190 | elif message.role == "assistant": 191 | content = ( 192 | content.replace("Action: ", "") 193 | .replace("Observation: ", "") 194 | .replace("Thought: ", "") 195 | ) 196 | if reasoning and i == len(messages) - 1: 
197 | content = ( 198 | "\n" + content.strip() + "\n" 199 | ) 200 | elif reasoning: 201 | # TODO: Check if there is always only 2 assistants 202 | content = ( 203 | "\n" 204 | + content.strip() 205 | + "\n\n" 206 | ) 207 | else: 208 | content = content.strip() 209 | 210 | messages[i].content = content 211 | 212 | # Fuse subsequent messages have the same role, merge it 213 | if i > 0 and messages[i].role == messages[i - 1].role: 214 | # Need to fuse both messages 215 | if reasoning: 216 | messages[i - 1].content += messages[i].content 217 | else: 218 | messages[i - 1].content += "\n" + messages[i].content 219 | messages.pop(i) 220 | 221 | return messages 222 | 223 | @staticmethod 224 | def convert_to_chat_format( 225 | data: ConversationData, json_path: str, reasoning: bool 226 | ) -> list[ChatMessage]: 227 | """Convert data item to chat template format.""" 228 | chat_messages = data.to_chat_messages() 229 | chat_messages = SampleProcessor.convert_to_code_agent_format(chat_messages, json_path, reasoning) 230 | return chat_messages 231 | 232 | @staticmethod 233 | def convert_to_new_action_space( 234 | messages: list[ChatMessage], resolution: tuple[int, int], code_format: bool = True 235 | ) -> list[ChatMessage]: 236 | """Convert messages to new action space format.""" 237 | regex_match: re.Match | str | None = None 238 | index = -1 239 | regex = r"\n(.*?)\n" 240 | assistant_msg = [(i, message) for i, message in enumerate(messages) if message.role == "assistant"] 241 | 242 | if assistant_msg: 243 | for index, msg in assistant_msg: 244 | 245 | if code_format: 246 | regex_match = re.search(regex, msg.content, re.DOTALL) 247 | else: 248 | regex_match = msg.content 249 | 250 | if regex_match is not None: 251 | function_calls = parse_function_call( 252 | regex_match.group(1) if isinstance(regex_match, re.Match) else regex_match, 253 | pattern_to_match=["pyautogui", "mobile", "terminate", "answer"], 254 | ) 255 | 256 | if len(function_calls) > 0: 257 | 258 | for i, function_call in enumerate(deepcopy(function_calls)): 259 | 260 | # pyautogui.dragTo have multiple signatures, we need to unify them before converting to new action space 261 | if function_call.function_name == "pyautogui.dragTo" and not isinstance(list(function_calls[i].parameters.values())[0], (list, tuple)): 262 | x1, y1 = islice(function_calls[i-1].parameters.values(), 2) 263 | x2, y2 = islice(function_calls[i].parameters.values(), 2) 264 | function_calls[i].parameters = {"from_coord": (x1, y1), "to_coord": (x2, y2)} 265 | function_calls[i].original_string = function_calls[i].to_string() 266 | function_calls.pop(i-1) 267 | 268 | function_calls = action_conversion(function_calls, resolution=resolution) 269 | 270 | new_action_string = "\n".join( 271 | [function_call.to_string() for function_call in function_calls] 272 | ) 273 | messages[index].content = messages[index].content.replace( 274 | regex_match.group(1) if isinstance(regex_match, re.Match) else regex_match, new_action_string 275 | ) 276 | 277 | return messages 278 | 279 | 280 | class DataProcessor: 281 | """Handles data processing and generation.""" 282 | 283 | def __init__(self): 284 | self.sample_processor = SampleProcessor() 285 | 286 | def process_subset( 287 | self, config: ProcessingConfig, dataset_path: str, deduplicate: bool = True 288 | ) -> tuple[ConversationDataList, Path]: 289 | """Process a single dataset subset.""" 290 | subset_name = config.subset_name 291 | 292 | logger.info(f"Processing split: {subset_name}") 293 | 294 | dataset_dir = Path(dataset_path) 295 | 
images_folder = dataset_dir / config.subset_name.replace("-l1", "").replace("-l2", "").replace("-l3", "") / config.images_folder 296 | 297 | if not images_folder.exists(): 298 | logger.warning(f"Images folder not found: {images_folder}") 299 | else: 300 | logger.info(f"Images folder: {images_folder}") 301 | 302 | json_config_path = dataset_dir / config.json_path 303 | with open(json_config_path, "r") as f: 304 | # Create ConversationDataList with deduplication control 305 | data = ConversationDataList.from_json_with_deduplication(f.read(), deduplicate) 306 | logger.info(f"Added '{json_config_path}' (deduplication: {deduplicate})") 307 | 308 | return data, images_folder 309 | 310 | def row_generator( 311 | self, data: ConversationDataList, images_folder: Path, json_path: str, reasoning: bool 312 | ) -> Generator[Dict[str, Any], None, None]: 313 | """Generate processed data rows.""" 314 | conversations: list[ConversationData] = data.root 315 | for item in tqdm(conversations): 316 | try: 317 | # Load images 318 | image = self.sample_processor.load_image_from_folder(images_folder, item.image) 319 | chat_message = self.sample_processor.convert_to_chat_format(item, json_path, reasoning) 320 | chat_message = self.sample_processor.convert_to_new_action_space(chat_message, image.size, code_format=reasoning) 321 | if len(chat_message) == 0: 322 | continue 323 | 324 | row = DataRow.from_chat_messages(chat_message, image, source=json_path.split("/")[-1].split(".")[0]) 325 | yield row.model_dump(exclude_none=True) 326 | del image 327 | except Exception as e: 328 | import traceback 329 | traceback.print_exc() 330 | logger.error(f"Error processing item: {e}", item) 331 | continue 332 | 333 | 334 | class SingleConfigProcessor: 335 | """Processes a single configuration in isolation.""" 336 | 337 | def __init__(self): 338 | self.data_processor = DataProcessor() 339 | 340 | @staticmethod 341 | def check_subset_exists(repo_id: str, subset_name: str) -> bool: 342 | """Check if a subset already exists in the remote dataset.""" 343 | try: 344 | config_names = get_dataset_config_names(repo_id) 345 | return subset_name in config_names 346 | except Exception as e: 347 | logger.warning(f"Could not check if subset exists: {e}") 348 | return False 349 | 350 | def process_single_config( 351 | self, config: ProcessingConfig, dataset_path: str, smolagents_repo_id: str, reasoning: bool, deduplicate: bool = True 352 | ) -> bool: 353 | """Process a single config in a separate process.""" 354 | try: 355 | # Authenticate in this process 356 | huggingface_authenticate() 357 | 358 | logger.info(f"\n{'=' * 50}") 359 | logger.info(f"Processing config: {config.subset_name}") 360 | 361 | # Check if the subset already exists in the remote dataset 362 | subset_name = config.subset_name 363 | if SingleConfigProcessor.check_subset_exists(smolagents_repo_id, subset_name): 364 | logger.info( 365 | f"Subset '{subset_name}' already exists in {smolagents_repo_id}, skipping processing." 
366 | ) 367 | return True 368 | 369 | json_path = config.json_path 370 | data, image_folder = self.data_processor.process_subset(config, dataset_path, deduplicate) 371 | 372 | # Collect all rows first 373 | rows = [] 374 | datasets = [] 375 | for row in self.data_processor.row_generator(data, image_folder, json_path, reasoning): 376 | rows.append(row) 377 | if len(rows) > 20000: 378 | logger.info("Creating batch dataset") 379 | dataset = Dataset.from_list(rows) 380 | datasets.append(dataset) 381 | rows = [] 382 | gc.collect() 383 | 384 | if len(rows) > 0: 385 | # Create dataset from collected data 386 | dataset = Dataset.from_list(rows) 387 | datasets.append(dataset) 388 | rows = [] 389 | 390 | dataset_to_push = concatenate_datasets(datasets) 391 | 392 | # Push to hub 393 | dataset_to_push.push_to_hub( 394 | smolagents_repo_id, 395 | config_name=subset_name, 396 | split="train", 397 | ) 398 | 399 | logger.info(f"Processed and uploaded subset: {config.subset_name}") 400 | 401 | # Force garbage collection to manage memory 402 | gc.collect() 403 | 404 | return True 405 | 406 | except Exception as e: 407 | logger.error(f"Error processing config {config.subset_name}: {e}") 408 | import traceback 409 | traceback.print_exc() 410 | return False 411 | 412 | 413 | class AguvisDatasetProcessor: 414 | """Main class for orchestrating the entire AGUVIS dataset processing pipeline.""" 415 | 416 | def __init__(self): 417 | self.downloader = DatasetDownloader() 418 | self.config_processor = SingleConfigProcessor() 419 | 420 | @staticmethod 421 | def authenticate(): 422 | """Authenticate with HuggingFace Hub using token.""" 423 | hf_token = os.getenv("HF_TOKEN") 424 | if hf_token: 425 | logger.info("Authenticating with HuggingFace Hub using token...") 426 | login(token=hf_token) 427 | else: 428 | raise ValueError("HF_TOKEN environment variable not set.") 429 | 430 | def make_dataset_from_original_data( 431 | self, dataset_config: DatasetConfig, max_processes: Optional[int] = None 432 | ): 433 | """Main function to orchestrate the entire process.""" 434 | load_dotenv(override=True) 435 | 436 | logger.info(f"Starting {dataset_config.smolagents_repo_id} dataset processing...") 437 | 438 | self.authenticate() 439 | 440 | dataset_path = self.downloader.download_dataset( 441 | dataset_config.huggingface_repo_id, dataset_config.local_path 442 | ) 443 | 444 | self.downloader.extract_zip_files(dataset_path) 445 | self.downloader.extract_tar_parts_grouped(dataset_path) 446 | 447 | dataset_configs = self.downloader.discover_dataset_config( 448 | dataset_path, dataset_config.config_dict 449 | ) 450 | converted_repo_id = dataset_config.smolagents_repo_id 451 | reasoning = dataset_config.reasoning 452 | deduplicate = dataset_config.deduplicate 453 | 454 | if max_processes is None: 455 | max_processes = 1 456 | num_processes = min(max_processes, len(dataset_configs)) 457 | logger.info(f"Using {num_processes} processes to process {len(dataset_configs)} configs") 458 | 459 | # Prepare arguments for multiprocessing 460 | process_args = [ 461 | (config, dataset_path, converted_repo_id, reasoning, deduplicate) 462 | for config in dataset_configs 463 | ] 464 | 465 | # Process configs in parallel with progress tracking 466 | logger.info(f"Starting parallel processing of {len(process_args)} configs...") 467 | try: 468 | with Pool(processes=num_processes) as pool: 469 | results = [] 470 | for i, result in enumerate(pool.starmap(self.config_processor.process_single_config, process_args)): 471 | results.append(result) 472 | 
logger.info(f"Completed {i+1}/{len(process_args)} configs") 473 | except Exception as e: 474 | logger.error(f"Multiprocessing failed: {e}") 475 | logger.info("Falling back to sequential processing...") 476 | results = [] 477 | for i, args in enumerate(process_args): 478 | result = self.config_processor.process_single_config(*args) 479 | results.append(result) 480 | logger.info(f"Completed {i+1}/{len(process_args)} configs (sequential)") 481 | 482 | # Check results 483 | successful = sum(results) 484 | total = len(process_args) 485 | logger.info(f"\nProcessing complete: {successful}/{total} configs processed successfully") 486 | 487 | if successful < total: 488 | failed_count = total - successful 489 | logger.warning(f"Warning: {failed_count} configs failed to process. Check the logs above for details.") 490 | else: 491 | logger.info("All configs processed successfully!") 492 | 493 | 494 | def main(): 495 | """Main entry point for the script.""" 496 | # Create dataset configurations 497 | dataset_config_1 = DatasetConfig( 498 | huggingface_repo_id="xlangai/aguvis-stage1", 499 | local_path="./aguvis_raw_stage_1", 500 | config_dict=CONFIG_DICT_STAGE_1, 501 | smolagents_repo_id="smolagents/aguvis-stage-test", 502 | reasoning=False, 503 | deduplicate=True, 504 | ) 505 | 506 | dataset_config_2 = DatasetConfig( 507 | huggingface_repo_id="xlangai/aguvis-stage2", 508 | local_path="./aguvis_raw_stage_2", 509 | config_dict=CONFIG_DICT_STAGE_2, 510 | smolagents_repo_id="smolagents/aguvis-stage-test", 511 | reasoning=True, 512 | deduplicate=False, 513 | ) 514 | 515 | # Create processor and run 516 | processor = AguvisDatasetProcessor() 517 | 518 | # You can specify max_processes to limit the number of parallel processes 519 | processor.make_dataset_from_original_data(dataset_config_1, max_processes=4) 520 | processor.make_dataset_from_original_data(dataset_config_2, max_processes=4) 521 | 522 | 523 | if __name__ == "__main__": 524 | # python -m preprocessing.aguvis_processor 525 | main() 526 | -------------------------------------------------------------------------------- /recipe.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Vision-Language Model Training Recipe\n", 8 | "\n", 9 | "Two-phase training pipeline for SmolVLM2 on GUI datasets.\n" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## Setup & Imports\n" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "from transformers import AutoModelForImageTextToText, AutoProcessor, PreTrainedModel\n", 26 | "from trl import SFTTrainer, SFTConfig\n", 27 | "from datasets import load_dataset, concatenate_datasets, DatasetDict\n", 28 | "from PIL import Image\n", 29 | "from typing import Any, Callable\n", 30 | "import torch\n", 31 | "import wandb\n", 32 | "import logging\n", 33 | "import os\n", 34 | "\n", 35 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", 36 | "\n", 37 | "logger = logging.getLogger(__name__)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "## Dataset Configuration\n", 45 | "\n", 46 | "Phase 1: Basic GUI understanding datasets\n" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "dataset_mixture_phase_1 = [\n", 56 | " {\n", 57 | " 
\"id\": \"smolagents/aguvis-stage-1\",\n", 58 | " \"config\": \"guienv\",\n", 59 | " \"split\": \"train\",\n", 60 | " \"columns\": [\"images\", \"texts\"],\n", 61 | " \"weight\": 1.0,\n", 62 | " },\n", 63 | " {\n", 64 | " \"id\": \"smolagents/aguvis-stage-1\",\n", 65 | " \"config\": \"omniact\",\n", 66 | " \"split\": \"train\",\n", 67 | " \"columns\": [\"images\", \"texts\"],\n", 68 | " \"weight\": 1.0,\n", 69 | " },\n", 70 | " {\n", 71 | " \"id\": \"smolagents/aguvis-stage-1\",\n", 72 | " \"config\": \"ricoig16k\",\n", 73 | " \"split\": \"train\",\n", 74 | " \"columns\": [\"images\", \"texts\"],\n", 75 | " \"weight\": 1.0,\n", 76 | " },\n", 77 | " {\n", 78 | " \"id\": \"smolagents/aguvis-stage-1\",\n", 79 | " \"config\": \"ricosca\",\n", 80 | " \"split\": \"train\",\n", 81 | " \"columns\": [\"images\", \"texts\"],\n", 82 | " \"weight\": 1.0,\n", 83 | " },\n", 84 | " {\n", 85 | " \"id\": \"smolagents/aguvis-stage-1\",\n", 86 | " \"config\": \"seeclick\",\n", 87 | " \"split\": \"train\",\n", 88 | " \"columns\": [\"images\", \"texts\"],\n", 89 | " \"weight\": 1.0,\n", 90 | " },\n", 91 | " {\n", 92 | " \"id\": \"smolagents/aguvis-stage-1\",\n", 93 | " \"config\": \"ui_refexp\",\n", 94 | " \"split\": \"train\",\n", 95 | " \"columns\": [\"images\", \"texts\"],\n", 96 | " \"weight\": 1.0,\n", 97 | " },\n", 98 | " {\n", 99 | " \"id\": \"smolagents/aguvis-stage-1\",\n", 100 | " \"config\": \"webui350k\",\n", 101 | " \"split\": \"train\",\n", 102 | " \"columns\": [\"images\", \"texts\"],\n", 103 | " \"weight\": 1.0,\n", 104 | " },\n", 105 | " {\n", 106 | " \"id\": \"smolagents/aguvis-stage-1\",\n", 107 | " \"config\": \"widget_captioning\",\n", 108 | " \"split\": \"train\",\n", 109 | " \"columns\": [\"images\", \"texts\"],\n", 110 | " \"weight\": 1.0,\n", 111 | " },\n", 112 | "]\n" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": {}, 118 | "source": [ 119 | "Phase 2: Advanced agentic behavior datasets\n" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "dataset_mixture_phase_2 = [\n", 129 | " {\n", 130 | " \"id\": \"smolagents/aguvis-stage-2\",\n", 131 | " \"config\": \"mind2web-l1\",\n", 132 | " \"split\": \"train\",\n", 133 | " \"columns\": [\"images\", \"texts\"],\n", 134 | " \"weight\": 1.0,\n", 135 | " },\n", 136 | " {\n", 137 | " \"id\": \"smolagents/aguvis-stage-2\",\n", 138 | " \"config\": \"mind2web-l2\",\n", 139 | " \"split\": \"train\",\n", 140 | " \"columns\": [\"images\", \"texts\"],\n", 141 | " \"weight\": 1.0,\n", 142 | " },\n", 143 | " {\n", 144 | " \"id\": \"smolagents/aguvis-stage-2\",\n", 145 | " \"config\": \"guiact-web-single\",\n", 146 | " \"split\": \"train\",\n", 147 | " \"columns\": [\"images\", \"texts\"],\n", 148 | " \"weight\": 1.0,\n", 149 | " },\n", 150 | " {\n", 151 | " \"id\": \"smolagents/aguvis-stage-2\",\n", 152 | " \"config\": \"guiact-web-multi-l1\",\n", 153 | " \"split\": \"train\",\n", 154 | " \"columns\": [\"images\", \"texts\"],\n", 155 | " \"weight\": 1.0,\n", 156 | " },\n", 157 | " {\n", 158 | " \"id\": \"smolagents/aguvis-stage-2\",\n", 159 | " \"config\": \"guiact-web-multi-l2\",\n", 160 | " \"split\": \"train\",\n", 161 | " \"columns\": [\"images\", \"texts\"],\n", 162 | " \"weight\": 1.0,\n", 163 | " },\n", 164 | " {\n", 165 | " \"id\": \"smolagents/aguvis-stage-2\",\n", 166 | " \"config\": \"miniwob-l1\",\n", 167 | " \"split\": \"train\",\n", 168 | " \"columns\": [\"images\", \"texts\"],\n", 169 | " \"weight\": 1.0,\n", 170 | " 
},\n", 171 | " {\n", 172 | " \"id\": \"smolagents/aguvis-stage-2\",\n", 173 | " \"config\": \"miniwob-l2\",\n", 174 | " \"split\": \"train\",\n", 175 | " \"columns\": [\"images\", \"texts\"],\n", 176 | " \"weight\": 1.0,\n", 177 | " },\n", 178 | " {\n", 179 | " \"id\": \"smolagents/aguvis-stage-2\",\n", 180 | " \"config\": \"coat\",\n", 181 | " \"split\": \"train\",\n", 182 | " \"columns\": [\"images\", \"texts\"],\n", 183 | " \"weight\": 1.0,\n", 184 | " },\n", 185 | " {\n", 186 | " \"id\": \"smolagents/aguvis-stage-2\",\n", 187 | " \"config\": \"android_control\",\n", 188 | " \"split\": \"train\",\n", 189 | " \"columns\": [\"images\", \"texts\"],\n", 190 | " \"weight\": 1.0,\n", 191 | " },\n", 192 | " {\n", 193 | " \"id\": \"smolagents/aguvis-stage-2\",\n", 194 | " \"config\": \"gui-odyssey-l1\",\n", 195 | " \"split\": \"train\",\n", 196 | " \"columns\": [\"images\", \"texts\"],\n", 197 | " \"weight\": 0.33,\n", 198 | " },\n", 199 | " {\n", 200 | " \"id\": \"smolagents/aguvis-stage-2\",\n", 201 | " \"config\": \"gui-odyssey-l2\",\n", 202 | " \"split\": \"train\",\n", 203 | " \"columns\": [\"images\", \"texts\"],\n", 204 | " \"weight\": 0.33,\n", 205 | " },\n", 206 | " {\n", 207 | " \"id\": \"smolagents/aguvis-stage-2\",\n", 208 | " \"config\": \"amex-l1\",\n", 209 | " \"split\": \"train\",\n", 210 | " \"columns\": [\"images\", \"texts\"],\n", 211 | " \"weight\": 0.33,\n", 212 | " },\n", 213 | " {\n", 214 | " \"id\": \"smolagents/aguvis-stage-2\",\n", 215 | " \"config\": \"amex-l2\",\n", 216 | " \"split\": \"train\",\n", 217 | " \"columns\": [\"images\", \"texts\"],\n", 218 | " \"weight\": 0.33,\n", 219 | " },\n", 220 | " {\n", 221 | " \"id\": \"smolagents/aguvis-stage-2\",\n", 222 | " \"config\": \"aitw-l1\",\n", 223 | " \"split\": \"train\",\n", 224 | " \"columns\": [\"images\", \"texts\"],\n", 225 | " \"weight\": 1.0,\n", 226 | " },\n", 227 | " {\n", 228 | " \"id\": \"smolagents/aguvis-stage-2\",\n", 229 | " \"config\": \"aitw-l2\",\n", 230 | " \"split\": \"train\",\n", 231 | " \"columns\": [\"images\", \"texts\"],\n", 232 | " \"weight\": 1.0,\n", 233 | " },\n", 234 | "]\n" 235 | ] 236 | }, 237 | { 238 | "cell_type": "markdown", 239 | "metadata": {}, 240 | "source": [ 241 | "## Dataset Loading Utility\n", 242 | "\n", 243 | "Loads and combines multiple datasets with weights and splits into train/test.\n" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "def get_dataset(dataset_mixture: list[dict[str, Any]], test_split_size: float = 0.01) -> DatasetDict:\n", 253 | " \"\"\"Load a dataset or a mixture of datasets based on the configuration.\n", 254 | "\n", 255 | " Args:\n", 256 | " dataset_mixture (list[dict[str, Any]]): Dataset configuration.\n", 257 | "\n", 258 | " Returns:\n", 259 | " DatasetDict: The loaded datasets.\n", 260 | " \"\"\"\n", 261 | " logger.info(f\"Creating dataset mixture with {len(dataset_mixture)} datasets\")\n", 262 | " seed = 42\n", 263 | " datasets_list = []\n", 264 | "\n", 265 | " for dataset_config in dataset_mixture:\n", 266 | " logger.info(\n", 267 | " f\"Loading dataset for mixture: {dataset_config['id']} (config: {dataset_config['config']})\"\n", 268 | " )\n", 269 | " ds = load_dataset(\n", 270 | " dataset_config[\"id\"],\n", 271 | " dataset_config[\"config\"],\n", 272 | " split=dataset_config[\"split\"],\n", 273 | " )\n", 274 | " ds = ds.select_columns(dataset_config[\"columns\"])\n", 275 | " ds = ds.shuffle(seed=seed).select(\n", 276 | " range(int(len(ds) * 
dataset_config[\"weight\"]))\n", 277 | " )\n", 278 | " logger.info(\n", 279 | " f\"Subsampled dataset '{dataset_config['id']}' (config: {dataset_config['config']}) with weight={dataset_config['weight']} to {len(ds)} examples\"\n", 280 | " )\n", 281 | "\n", 282 | " datasets_list.append(ds)\n", 283 | "\n", 284 | " if datasets_list:\n", 285 | " combined_dataset = concatenate_datasets(datasets_list)\n", 286 | " combined_dataset = combined_dataset.shuffle(seed=seed)\n", 287 | " logger.info(f\"Created dataset mixture with {len(combined_dataset)} examples\")\n", 288 | "\n", 289 | " combined_dataset = combined_dataset.train_test_split(test_size=test_split_size, seed=seed)\n", 290 | " logger.info(\n", 291 | " f\"Split dataset into train and test sets with test size: {test_split_size}\"\n", 292 | " )\n", 293 | " return combined_dataset\n", 294 | " else:\n", 295 | " raise ValueError(\"No datasets were loaded from the mixture configuration\")\n" 296 | ] 297 | }, 298 | { 299 | "cell_type": "markdown", 300 | "metadata": {}, 301 | "source": [ 302 | "## Model & Training Parameters\n" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": null, 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "base_model_name = \"HuggingFaceTB/SmolVLM2-2.2B-Instruct\"\n", 312 | "phase_1_model_name = \"SmolVLM2-2.2B-Instruct-GUI\"\n", 313 | "phase_2_model_name = \"SmolVLM2-2.2B-Instruct-Agentic-GUI\"\n", 314 | "image_size = 1152\n", 315 | "max_length = 16384\n" 316 | ] 317 | }, 318 | { 319 | "cell_type": "markdown", 320 | "metadata": {}, 321 | "source": [ 322 | "## Processor & Data Collator Setup\n", 323 | "\n", 324 | "Configure processor for image/text handling and create custom collator.\n" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": null, 330 | "metadata": {}, 331 | "outputs": [], 332 | "source": [ 333 | "processor = AutoProcessor.from_pretrained(\n", 334 | " base_model_name,\n", 335 | " revision=\"main\",\n", 336 | " trust_remote_code=True,\n", 337 | ")\n", 338 | "processor.image_processor.size = {\"longest_edge\": image_size}\n", 339 | "processor.tokenizer.truncation_side = \"right\"\n", 340 | "processor.tokenizer.padding_side = \"right\"\n" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": null, 346 | "metadata": {}, 347 | "outputs": [], 348 | "source": [ 349 | "def create_collate_fn(processor, max_length: int):\n", 350 | " \"\"\"Optimized collate function for VLM training that masks system prompt tokens.\"\"\"\n", 351 | "\n", 352 | " def collate_fn(examples: list[dict[str, list | str | Image.Image]]):\n", 353 | " batch_messages: list[list[dict[str, list | str | Image.Image]]] = []\n", 354 | " assistant_messages: list[list[str]] = []\n", 355 | " all_image_inputs: list[list[Image.Image]] = []\n", 356 | " for example in examples:\n", 357 | " images: list[Image.Image] = example[\"images\"]\n", 358 | " is_first_user = True\n", 359 | " sample: list[dict[str, list | str | Image.Image]] = []\n", 360 | " assistant: list[str] = []\n", 361 | " for text in example[\"texts\"]:\n", 362 | " if \"system\" in text.keys():\n", 363 | " sample.append(\n", 364 | " {\n", 365 | " \"role\": \"system\",\n", 366 | " \"content\": [{\"type\": \"text\", \"text\": text[\"system\"]}],\n", 367 | " }\n", 368 | " )\n", 369 | "\n", 370 | " if is_first_user:\n", 371 | " sample.append(\n", 372 | " {\n", 373 | " \"role\": \"user\",\n", 374 | " \"content\": [\n", 375 | " {\"type\": \"image\", \"image\": images[0]},\n", 376 | " {\"type\": \"text\", \"text\": 
text[\"user\"]},\n", 377 | " ],\n", 378 | " }\n", 379 | " )\n", 380 | " is_first_user = False\n", 381 | " else:\n", 382 | " sample.append(\n", 383 | " {\n", 384 | " \"role\": \"user\",\n", 385 | " \"content\": [\n", 386 | " {\"type\": \"text\", \"text\": text[\"user\"]},\n", 387 | " ],\n", 388 | " }\n", 389 | " )\n", 390 | "\n", 391 | " sample.append(\n", 392 | " {\n", 393 | " \"role\": \"assistant\",\n", 394 | " \"content\": [{\"type\": \"text\", \"text\": \"\\n\" + text[\"assistant\"]}],\n", 395 | " }\n", 396 | " )\n", 397 | " assistant.append(text[\"assistant\"])\n", 398 | "\n", 399 | " batch_messages.append(sample)\n", 400 | " assistant_messages.append(assistant)\n", 401 | " all_image_inputs.append(images)\n", 402 | "\n", 403 | " texts = [\n", 404 | " processor.apply_chat_template(\n", 405 | " messages, tokenize=False, add_generation_prompt=False\n", 406 | " )\n", 407 | " for messages in batch_messages\n", 408 | " ]\n", 409 | "\n", 410 | " batch = processor(\n", 411 | " text=texts,\n", 412 | " images=all_image_inputs if all_image_inputs else None,\n", 413 | " max_length=max_length,\n", 414 | " truncation=True,\n", 415 | " padding=True,\n", 416 | " return_tensors=\"pt\",\n", 417 | " )\n", 418 | "\n", 419 | " input_ids = batch[\"input_ids\"]\n", 420 | " labels = input_ids.clone()\n", 421 | "\n", 422 | " assistant_encodings = [\n", 423 | " processor.tokenizer(\n", 424 | " [msg + \"\" for msg in assistant_message],\n", 425 | " add_special_tokens=False,\n", 426 | " padding=False,\n", 427 | " )[\"input_ids\"]\n", 428 | " for assistant_message in assistant_messages\n", 429 | " ]\n", 430 | "\n", 431 | " # Mask out all except the assistant messages\n", 432 | " for i, assistant_ids_list in enumerate(assistant_encodings):\n", 433 | " seq = input_ids[i].tolist()\n", 434 | " assistant_positions: list[int] = []\n", 435 | " for ids in assistant_ids_list:\n", 436 | " start_pos = 0\n", 437 | " while start_pos < len(seq) - len(ids) + 1:\n", 438 | " found = False\n", 439 | " for j in range(start_pos, len(seq) - len(ids) + 1):\n", 440 | " if seq[j : j + len(ids)] == ids:\n", 441 | " assistant_positions.extend(range(j, j + len(ids)))\n", 442 | " start_pos = j + len(ids)\n", 443 | " found = True\n", 444 | " break\n", 445 | " if not found:\n", 446 | " break\n", 447 | "\n", 448 | " for pos in range(len(seq)):\n", 449 | " if pos not in assistant_positions:\n", 450 | " labels[i, pos] = -100\n", 451 | "\n", 452 | " batch[\"labels\"] = labels\n", 453 | " return batch\n", 454 | "\n", 455 | " return collate_fn\n", 456 | "\n", 457 | "data_collator = create_collate_fn(processor, max_length)\n" 458 | ] 459 | }, 460 | { 461 | "cell_type": "markdown", 462 | "metadata": {}, 463 | "source": [ 464 | "## Training Utility Function\n", 465 | "\n", 466 | "Handles training loop, metrics logging, and model saving.\n" 467 | ] 468 | }, 469 | { 470 | "cell_type": "code", 471 | "execution_count": null, 472 | "metadata": {}, 473 | "outputs": [], 474 | "source": [ 475 | "def training_and_save(model: PreTrainedModel, dataset: DatasetDict, training_args: SFTConfig, data_collator: Callable, processor: AutoProcessor):\n", 476 | " trainer = SFTTrainer(\n", 477 | " model=model,\n", 478 | " args=training_args,\n", 479 | " data_collator=data_collator,\n", 480 | " train_dataset=dataset[\"train\"],\n", 481 | " eval_dataset=dataset[\"test\"],\n", 482 | " processing_class=processor,\n", 483 | " )\n", 484 | " logger.info(\"*** Training ***\")\n", 485 | " train_result = trainer.train()\n", 486 | " metrics = train_result.metrics\n", 487 | " 
metrics[\"train_samples\"] = len(dataset[\"train\"])\n", 488 | " trainer.log_metrics(\"train\", metrics)\n", 489 | " trainer.save_metrics(\"train\", metrics)\n", 490 | " trainer.save_state()\n", 491 | " logger.info(\"*** Save model ***\")\n", 492 | " trainer.save_model(training_args.output_dir)\n", 493 | " logger.info(f\"Model saved to {training_args.output_dir}\")\n", 494 | " \n", 495 | " if hasattr(trainer, 'state') and trainer.state.is_world_process_zero:\n", 496 | " wandb.finish()\n" 497 | ] 498 | }, 499 | { 500 | "cell_type": "markdown", 501 | "metadata": {}, 502 | "source": [ 503 | "## Phase 1 Training\n", 504 | "\n", 505 | "Load base model and train on GUI understanding datasets.\n" 506 | ] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "execution_count": null, 511 | "metadata": {}, 512 | "outputs": [], 513 | "source": [ 514 | "model = AutoModelForImageTextToText.from_pretrained(\n", 515 | " base_model_name,\n", 516 | " revision=\"main\",\n", 517 | " torch_dtype=torch.bfloat16,\n", 518 | " attn_implementation=\"sdpa\",\n", 519 | " trust_remote_code=True,\n", 520 | ")\n" 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": null, 526 | "metadata": {}, 527 | "outputs": [], 528 | "source": [ 529 | "training_args = SFTConfig(\n", 530 | " max_length=max_length,\n", 531 | " output_dir=f\"./{phase_1_model_name}\",\n", 532 | " optim=\"adamw_torch\",\n", 533 | " lr_scheduler_type=\"cosine_with_min_lr\",\n", 534 | " lr_scheduler_kwargs={\"min_lr_rate\": 0.1},\n", 535 | " max_grad_norm=0.2,\n", 536 | " warmup_ratio=0.03,\n", 537 | " learning_rate=2.0e-05,\n", 538 | " gradient_accumulation_steps=32,\n", 539 | " per_device_eval_batch_size=2,\n", 540 | " per_device_train_batch_size=2,\n", 541 | " max_steps=-1,\n", 542 | " num_train_epochs=2.0,\n", 543 | " bf16=True,\n", 544 | " do_eval=True,\n", 545 | " eval_strategy=\"steps\",\n", 546 | " eval_steps=100,\n", 547 | " gradient_checkpointing=True,\n", 548 | " gradient_checkpointing_kwargs={\"use_reentrant\": False},\n", 549 | " log_level=\"info\",\n", 550 | " logging_steps=5,\n", 551 | " logging_strategy=\"steps\",\n", 552 | " overwrite_output_dir=False,\n", 553 | " report_to=[\"wandb\"],\n", 554 | " run_name=f\"{base_model_name}-phase-1\",\n", 555 | " save_strategy=\"epoch\",\n", 556 | " save_steps=1,\n", 557 | " save_total_limit=1,\n", 558 | " ddp_find_unused_parameters=False,\n", 559 | " dataset_kwargs={\"skip_prepare_dataset\": True},\n", 560 | " remove_unused_columns=False,\n", 561 | " seed=42,\n", 562 | ")\n" 563 | ] 564 | }, 565 | { 566 | "cell_type": "code", 567 | "execution_count": null, 568 | "metadata": {}, 569 | "outputs": [], 570 | "source": [ 571 | "dataset = get_dataset(dataset_mixture_phase_1)\n", 572 | "training_and_save(model, dataset, training_args, data_collator, processor)\n" 573 | ] 574 | }, 575 | { 576 | "cell_type": "markdown", 577 | "metadata": {}, 578 | "source": [ 579 | "## Phase 2 Training\n", 580 | "\n", 581 | "Load Phase 1 model and continue training on agentic datasets.\n" 582 | ] 583 | }, 584 | { 585 | "cell_type": "code", 586 | "execution_count": null, 587 | "metadata": {}, 588 | "outputs": [], 589 | "source": [ 590 | "model = AutoModelForImageTextToText.from_pretrained(\n", 591 | " f\"./{phase_1_model_name}\",\n", 592 | " revision=\"main\",\n", 593 | " torch_dtype=torch.bfloat16,\n", 594 | " attn_implementation=\"sdpa\",\n", 595 | " trust_remote_code=True,\n", 596 | ")\n", 597 | "\n", 598 | "dataset_phase_2 = get_dataset(dataset_mixture_phase_2)\n" 599 | ] 600 | }, 601 | { 602 | 
"cell_type": "markdown", 603 | "metadata": {}, 604 | "source": [ 605 | "Adjust training arguments for Phase 2\n" 606 | ] 607 | }, 608 | { 609 | "cell_type": "code", 610 | "execution_count": null, 611 | "metadata": {}, 612 | "outputs": [], 613 | "source": [ 614 | "training_args.gradient_accumulation_steps = 16\n", 615 | "training_args.per_device_train_batch_size = 4\n", 616 | "training_args.per_device_eval_batch_size = 4\n", 617 | "training_args.run_name = f\"{base_model_name}-phase-2\"\n", 618 | "training_args.output_dir = f\"./{phase_2_model_name}\"\n", 619 | "\n", 620 | "training_and_save(model, dataset_phase_2, training_args, data_collator, processor)\n" 621 | ] 622 | } 623 | ], 624 | "metadata": { 625 | "language_info": { 626 | "name": "python" 627 | } 628 | }, 629 | "nbformat": 4, 630 | "nbformat_minor": 2 631 | } 632 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Smol2Operator: Post-Training GUI Agents for Computer Use 2 | 3 | **TL;DR:** This work shows how a lightweight vision–language model can acquire GUI-grounded skills and evolve into an agentic GUI coder. We release all training recipes, data-processing tools, and datasets to enable full reproducibility and foster further research 🫡. 4 | 5 | --- 6 | 7 | https://github.com/user-attachments/assets/af51ba24-bc9d-431c-b42d-589ff7d134e7 8 | 9 | --- 10 | 11 | ## Table of Contents 12 | 13 | - [Introduction](#introduction) 14 | - [1. Data Transformation and Unified Action Space](#1-data-transformation-and-unified-action-space) 15 | - [The Challenge of Inconsistent Action Spaces](#the-challenge-of-inconsistent-action-spaces) 16 | - [Our Unified Approach](#our-unified-approach) 17 | - [Example Data Transformation](#example-data-transformation) 18 | - [Custom Action Space Adaptation with Action Space Converter](#bonus-custom-action-space-adaptation-with-action-space-converter) 19 | - [Key Features](#key-features) 20 | - [Usage Example](#usage-example) 21 | - [Transformed and Released Datasets](#transformed-and-released-datasets) 22 | - [2. Phase 1: From Zero to Perception](#2-phase-1-from-zero-to-perception) 23 | - [Training Data](#training-data) 24 | - [Optimization Experiments](#optimization-experiments) 25 | - [Image Resolution and Coordinate System Analysis](#image-resolution-and-coordinate-system-analysis) 26 | - [Key Findings](#key-findings) 27 | - [Phase 1 Results](#phase-1-results) 28 | - [3. Phase 2: From Perception to Cognition](#3-phase-2-from-perception-to-cognition) 29 | - [Training Data](#training-data-1) 30 | - [Phase 2 Results](#phase-2-results) 31 | - [4. All you need is Open Source](#4-all-you-need-is-open-source) 32 | - [5. Conclusion](#5-conclusion) 33 | - [What Next?](#what-next) 34 | 35 |
36 | 37 |
44 | 💡 Additional Resources: 45 |

46 | • Datasets: smolagents/aguvis-stage-1, smolagents/aguvis-stage-2 47 |
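The listed subsets can be pulled directly with 🤗 Datasets. A minimal loading sketch (the `guienv` config is one of the Phase-1 configs used in `recipe.ipynb`; any other listed config works the same way, and a Hugging Face login may be needed depending on your setup):

```python
from datasets import load_dataset

# Pull one Phase-1 subset; rows expose the "images" and "texts" columns
# that recipe.ipynb selects for training.
ds = load_dataset("smolagents/aguvis-stage-1", "guienv", split="train")
print(ds)
print(ds[0]["texts"])  # list of system/user/assistant turns for the first sample
```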
48 | 49 | ## Introduction 50 | 51 | Graphical User Interface (GUI) automation is one of the most challenging frontiers in computer vision. Developing models that see and interact with user interfaces enables AI agents to navigate mobile, desktop, and web platforms. This will reshape the future of digital interaction. 52 | 53 | In this blog post, we present a comprehensive approach to training vision-language models for GUI automation through a multi-phase training strategy. We demonstrate how to transform a model with zero grounding capabilities into an agentic coder capable of understanding and interacting with graphical interfaces. 54 | 55 | Rather than aiming for a SOTA model, our goal is to demonstrate the entire process, from data processing to model training, and, in doing so, show how to unlock GUI-grounding capabilities in VLMs. 56 | 57 |
58 | 59 | ![GUI capabilities combine understanding of the interface and precise element localization. These abilities enable the model to translate high-level tasks into low-level GUI actions such as clicking, typing, …](assets/google.png) 60 |

GUI capabilities combine understanding of the interface and precise element localization. These abilities enable the model to translate high-level tasks into low-level GUI actions such as clicking, typing, …

68 | 69 |
70 | 71 |
72 | 73 | 74 | Our approach leverages [**SmolVLM2-2.2B-Instruct**](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) as the baseline model, a small powerful vision-language model that initially has no grounding capabilities for GUI tasks. This makes it an ideal candidate to demonstrate the effectiveness of our training methodology. Through our two-phase training process, we first instill grounding capabilities in the model, then enhance it with agentic reasoning abilities using Supervised Fine-Tuning (SFT). 75 | 76 | We evaluate our approach on an established perception benchmark: [**ScreenSpot-v2**](https://huggingface.co/datasets/HongxinLi/ScreenSpot_v2), which tests the model’s ability to understand and locate elements within screenshots. Our process is inspired by the [AGUVIS](https://huggingface.co/papers/2412.04454) paper, and we leverage their carefully curated datasets to build upon their foundational work. 77 | 78 |
79 | 80 | ![Evolution of ScreenSpot-v2 performance during the training phase of the base model SmolVLM2-2.2B-Instruct.](assets/screenspot-v2.png) 81 |

Evolution of ScreenSpot-v2 performance during the training phase of the base model **SmolVLM2-2.2B-Instruct**.

89 | 90 |
91 | 92 | 93 | ## 1. Data Transformation and Unified Action Space 94 | 95 | *This section explains how we **convert heterogeneous GUI actions format from multiple datasets into a single unified format**. By standardizing function names, signatures, and parameters, we create consistent, high-quality data that forms the foundation for effective model training.* 96 | 97 | ### The Challenge of Inconsistent Action Spaces 98 | 99 | One of the primary challenges when working with multiple GUI automation datasets is the lack of standardization in action representations. Different datasets use varying function signatures, parameter naming conventions, and action taxonomies, making it difficult to train a unified model across diverse data sources. 100 | 101 | ### Our Unified Approach 102 | 103 | We took the open-source datasets ([xlangai/aguvis-stage1](https://huggingface.co/datasets/xlangai/aguvis-stage1), [xlangai/aguvis-stage2](https://huggingface.co/datasets/xlangai/aguvis-stage2)), originally used by [AGUVIS](https://huggingface.co/papers/2412.04454), and implemented a comprehensive data transformation pipeline to create a unified action space. Our approach involved: 104 | 105 | 1. **Function Parsing and Normalization**: We developed a function parser (see `utils/function_parser.py`) that can extract and parse function calls from various formats across all datasets. This parser supports any function signature format, handles complex parameter structures, and can reconstruct function calls with proper parameter ordering. 106 | 2. **Action Space Unification**: We implemented a comprehensive action conversion system (see `preprocessing/action_conversion.py`) that transforms all original action representations into a standardized function naming and argument structure. This process highlighted the significant inconsistencies in function signatures across different datasets and allowed us to: 107 | - Remove undesired or redundant actions 108 | - Standardize parameter naming conventions 109 | - Create a cohesive action vocabulary 110 | 3. 
3. **(Bonus) Flexible Adaptation Framework**: Our transformation pipeline includes utilities that allow users to:
   - Adapt the entire dataset to their own action space naming conventions using the `utils/action_space_converter.py` tool
   - Extract and analyze the current action space structure

### Example Data Transformation

Here are real examples from our action conversion system (`preprocessing/action_conversion.py`) showing how we transform heterogeneous action representations into our unified format (grounding coordinates normalized to [0, 1]):

**Before (Original Action Dataset Formats):**

```python
# Mobile Actions
mobile.home()
mobile.open_app(app_name='drupe')
mobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])
mobile.long_press(x=0.799, y=0.911)
mobile.terminate(status='success')

# Desktop Actions
pyautogui.click(x=0.8102, y=0.9463)
pyautogui.doubleClick(x=0.8102, y=0.9463)
pyautogui.hotkey(keys=['ctrl', 'c'])
pyautogui.scroll(page=-0.1)
pyautogui.write(message='bread buns')
pyautogui.dragTo(from_coord=[0.87, 0.423], to_coord=[0.8102, 0.9463])
```

**After (Unified Action Dataset Formats):**

```python
# Unified Mobile Actions
navigate_home()
open_app(app_name='drupe')
swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])
long_press(x=0.799, y=0.911)
final_answer('success')

# Unified Desktop Actions
click(x=0.8102, y=0.9463)
double_click(x=0.8102, y=0.9463)
press(keys=['ctrl', 'c'])
scroll(direction='up', amount=10)  # Smart direction detection
type(text='bread buns')
drag(from_coord=[0.87, 0.423], to_coord=[0.8102, 0.9463])
```

This unification process was essential for creating coherent training data that allows the model to learn consistent action patterns across diverse GUI environments.
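To make the mapping concrete, below is a small, illustrative sketch of the kind of rename-and-remap logic such a conversion applies. It is not the actual `preprocessing/action_conversion.py` implementation, and the `UNIFIED_NAMES` / `PARAM_RENAMES` tables are hypothetical, covering only a few of the actions shown above:

```python
# Simplified sketch of name-based action unification (illustrative only).
# The real pipeline in preprocessing/action_conversion.py also rewrites
# values (e.g. scroll amounts); here we only rename functions and one parameter.
import re

# Hypothetical subset of the unified vocabulary shown above
UNIFIED_NAMES = {
    "mobile.home": "navigate_home",
    "mobile.long_press": "long_press",
    "pyautogui.click": "click",
    "pyautogui.doubleClick": "double_click",
    "pyautogui.hotkey": "press",
    "pyautogui.write": "type",
    "pyautogui.dragTo": "drag",
}

PARAM_RENAMES = {"message": "text"}  # e.g. pyautogui.write(message=...) -> type(text=...)

def unify_action(call: str) -> str:
    """Rename the function and known parameters of a single action string."""
    name, args = call.split("(", 1)
    name = UNIFIED_NAMES.get(name.strip(), name.strip())
    for old, new in PARAM_RENAMES.items():
        args = re.sub(rf"\b{old}=", f"{new}=", args)
    return f"{name}({args}"

print(unify_action("pyautogui.write(message='bread buns')"))
# -> type(text='bread buns')
```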
> 💡 **Why Normalized Coordinates?**
>
> Using raw pixel coordinates in a text-action datapoint (e.g. `click(x=302, y=63)`) ties them to a single image size. Vision-language models (VLMs) often resize images, causing pixel coordinates to break and require adjustment. Normalized coordinates (relative to the image size) remain valid at any resolution and keep the dataset consistent.
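As a small illustration of the idea (not code from this repo), converting between pixel and normalized coordinates is a simple rescaling by the image size; the 1920x1080 resolution below is an arbitrary example:

```python
# Illustrative only: rescaling between pixel and normalized coordinates.
def to_normalized(x_px: int, y_px: int, width: int, height: int) -> tuple[float, float]:
    """Map pixel coordinates to the [0, 1] range used in the unified datasets."""
    return round(x_px / width, 4), round(y_px / height, 4)

def to_pixels(x: float, y: float, width: int, height: int) -> tuple[int, int]:
    """Map normalized coordinates back to pixels for a given image size."""
    return round(x * width), round(y * height)

# Example with an arbitrary 1920x1080 screenshot
print(to_normalized(302, 63, 1920, 1080))     # -> (0.1573, 0.0583)
print(to_pixels(0.1573, 0.0583, 1920, 1080))  # -> (302, 63)
```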
### (Bonus) Custom Action Space Adaptation with Action Space Converter

To maximize flexibility for different use cases, we developed the **Action Space Converter** (`utils/action_space_converter.py`), a tool that allows users to easily adapt from one action space to their own custom action vocabularies and naming conventions.

You can use this tool to transform one action signature (function names, parameter names, parameter value changes, ...) into another:

**Before**

```python
assistant_message: "Action: click(x=0.5, y=0.3)"
```

**After**

```python
assistant_message: "Action: touch(x_coord=200, y_coord=300)"
```

### Key Features

The Action Space Converter provides:

1. **Configurable Mappings**: Define custom mappings between unified actions and your preferred action names
2. **Parameter Transformation**: Rename parameters, apply value transformations, and set default values
3. **Flexible Architecture**: Support for both simple parameter mappings and complex custom transformation functions
4. **Validation**: Built-in validation to ensure mapping configurations are valid

### Usage Example

```python
from utils.action_space_converter import ActionSpaceConverter, ActionMapping, ParameterMapping
from utils.function_parser import parse_function_call

# Create custom mappings
mappings = [
    ActionMapping(
        source_function="click",
        target_function="touch",
        parameter_mappings=[
            ParameterMapping(source_name="x", target_name="x_coord"),
            ParameterMapping(source_name="y", target_name="y_coord")
        ],
        description="Touch screen at coordinates"
    ),
    ActionMapping(
        source_function="type",   # source_function is the name of the function in the original function call
        target_function="write",  # target_function is the name of the function in the target function call
        parameter_mappings=[
            ParameterMapping(source_name="text", target_name="content")
            # source_name is the name of the parameter in the original function call
            # target_name is the name of the parameter in the target function call
        ],
        description="Input text"
    )
]

text = "I'll interact at those coordinates for you. click(x=0.5, y=0.3) Now I'll input the text. type(text='hello world')"

# Parse function calls
parsed_function_calls = parse_function_call(text)

# Initialize converter
converter = ActionSpaceConverter(mappings)

# Convert actions and rewrite them in place in the original message
converted_actions = converter.convert_actions(parsed_function_calls)
for new_function_call, old_function_call in zip(converted_actions, parsed_function_calls):
    text = text.replace(old_function_call.to_string(), new_function_call.to_string())

print(text)
# Output: I'll interact at those coordinates for you. touch(x_coord=0.5, y_coord=0.3) Now I'll input the text. write(content='hello world')
```

This tool enables researchers and practitioners to:

- **Customize Training Data**: Adapt the dataset to match their specific action vocabulary requirements
- **Domain Adaptation**: Transform actions for different platforms (mobile vs. desktop vs. web)
- **Framework Integration**: Easily align training data with existing automation frameworks
- **Rapid Experimentation**: Quickly test different action space configurations
- **Release Preparation**: Standardize action spaces for production deployment with consistent naming conventions

The Action Space Converter is particularly valuable for preparing datasets for training, as it ensures consistent action vocabularies across different deployment environments while maintaining compatibility with existing automation frameworks.

### Transformed and Released Datasets

Through this pipeline, we transform the open-source datasets [xlangai/aguvis-stage1](https://huggingface.co/datasets/xlangai/aguvis-stage1) and [xlangai/aguvis-stage2](https://huggingface.co/datasets/xlangai/aguvis-stage2) into our unified action space. The output of this process is released as two new fully formatted datasets: [smolagents/aguvis-stage-1](https://huggingface.co/datasets/smolagents/aguvis-stage-1) and [smolagents/aguvis-stage-2](https://huggingface.co/datasets/smolagents/aguvis-stage-2).

## 2. Phase 1: From Zero to Perception

### Training Data

Phase 1 leverages the [smolagents/aguvis-stage-1](https://huggingface.co/datasets/smolagents/aguvis-stage-1) dataset, which introduces **GUI grounding** by pairing low-level instructions with diverse executable actions (expressed in code form). For example, a user/assistant turn in [smolagents/aguvis-stage-1](https://huggingface.co/datasets/smolagents/aguvis-stage-1) follows this structure:

```json
{
  "user": "click on more button",
  "assistant": "click(x=0.8875, y=0.2281)"
}
```

Each sample links a screenshot with multi-turn user/assistant interactions, enabling the model to learn fine-grained action grounding across dialogue turns. During fine-tuning, the data collator masks everything except the assistant's answers when computing the loss (a minimal illustration follows the ablation setup below).

### Optimization Experiments

Before proceeding with full-scale Phase 1 training, we conducted comprehensive ablation studies to determine the optimal training configuration.

### Image Resolution and Coordinate System Analysis

We experimented with different image sizes and coordinate representation systems to identify the optimal configuration for SmolVLM2:

- **Image Sizes Tested**: 384px, 768px, 1152px
- **Coordinate Systems**: Pixel coordinates vs. normalized coordinates (0–1 range)
- **Training Data**: 400K samples from the Aguvis datasets

> Some SOTA GUI VLMs (e.g., Qwen-VL) also appear to use a different normalized range (0–1000), which was not tested in this experiment.
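As an aside, here is a minimal, illustrative sketch of the assistant-only loss masking mentioned in the Training Data subsection above. It is not the repo's `utils/collator.py`; it only shows the general idea of setting label ids to `-100` (the index ignored by the cross-entropy loss) for every token outside the assistant answer:

```python
# Illustrative sketch of assistant-only loss masking (not utils/collator.py).
# Assumes the token spans of the assistant answers are already known; real
# collators typically recover those spans from the chat template.
from typing import List, Tuple

IGNORE_INDEX = -100  # ignored by PyTorch's cross-entropy loss

def mask_labels(input_ids: List[int], assistant_spans: List[Tuple[int, int]]) -> List[int]:
    """Copy input_ids into labels, keeping only tokens inside assistant spans."""
    labels = [IGNORE_INDEX] * len(input_ids)
    for start, end in assistant_spans:  # token index ranges, end exclusive
        labels[start:end] = input_ids[start:end]
    return labels

# Toy example: tokens 5..9 belong to the assistant answer, e.g. "click(x=..., y=...)"
input_ids = list(range(12))
print(mask_labels(input_ids, [(5, 10)]))
# -> [-100, -100, -100, -100, -100, 5, 6, 7, 8, 9, -100, -100]
```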
| Configuration (coords / image size) | ScreenSpot-v2 (%) |
| --- | --- |
| **Normalized coordinates** | |
| Base / – | 0.47 |
| 384 | 31.28 |
| 768 | 32.32 |
| 1152 | 33.72 |
| **Pixel coordinates** | |
| Base / – | 0.55 |
| 384 | 1.17 |
| 768 | 2.67 |
| 1152 | 4.32 |

*Table 1: ScreenSpot-v2 accuracy of HuggingFaceTB/SmolVLM2-2.2B-Instruct, base model and after training on 400K samples of aguvis-stage-1 under each configuration. Higher is better.*
*As demonstrated in our benchmark results, the SmolVLM2-2.2B-Instruct base model initially achieved near-zero performance (below 1%) on perception benchmarks like ScreenSpot-v2. This near-total lack of grounding capability provided us with a clean slate to evaluate the effectiveness of our training methodology.*

### Key Findings

From our experiments, we determined that:

- **Image Size**: 1152px resolution gave the best results
- **Coordinate System**: Normalized coordinates (0–1 range) proved most effective for SmolVLM2
- Note: The optimal choice between pixel and normalized coordinates may vary depending on the base model's pre-training approach

### Phase 1 Results

Using the optimal configuration (1152px resolution with normalized coordinates), we trained for 2 epochs on the smolagents/aguvis-stage-1 dataset. The results were remarkable: a gain of roughly **41 points over the baseline on ScreenSpot-v2** (from 0.47% to 41.27%).

This dramatic improvement demonstrates that our Phase 1 training successfully instilled fundamental grounding capabilities in the model, enabling it to understand and locate visual elements within screenshots.
| Configuration (coords / image size) | ScreenSpot-v2 (%) |
| --- | --- |
| Normalized coordinates / 1152 | 41.27 |

*Table 2: ScreenSpot-v2 accuracy of HuggingFaceTB/SmolVLM2-2.2B-Instruct after Phase 1 training (2 epochs on aguvis-stage-1).*
## 3. Phase 2: From Perception to Cognition

Whereas Phase 1 provided grounding capabilities, Phase 2 targets **agentic reasoning**: the ability to deliberate and plan before acting. This stage transforms the model from a reactive system that identifies GUI elements into a proactive agent capable of executing complex, multi-step interactions.

### Training Data

Phase 2 uses the [smolagents/aguvis-stage-2](https://huggingface.co/datasets/smolagents/aguvis-stage-2) dataset, which introduces agentic scenarios:

- **Explicit reasoning** about upcoming actions
- **Context consistency** across multiple interaction steps
- **High-level instructions** that require multi-step, low-level actions

For example, a [smolagents/aguvis-stage-2](https://huggingface.co/datasets/smolagents/aguvis-stage-2) chat message looks like this:

```json
{
  "system": "You are a helpful GUI agent. ...",
  "user": "Please generate the next move according to the UI screenshot, instruction and previous actions.\n\nInstruction: What information does the site provide about Judith Lauand's career, works and exhibitions?\n\nPrevious actions:\nNone",
  "assistant": "\nClick on the link labeled 'Judith Lauand: Brazilian 1922-2022' to explore more about her career and exhibitions.\n\n\nclick(x=0.41, y=0.178)\n"
}
```

Each sample links a screenshot with a single system/user/assistant turn. During fine-tuning, the data collator masks everything except the assistant's answer when computing the loss.

### Phase 2 Results

Starting from the Phase 1 checkpoint (1152px resolution, normalized coordinates), we fine-tuned the model for two epochs on [smolagents/aguvis-stage-2](https://huggingface.co/datasets/smolagents/aguvis-stage-2). The accuracy on **ScreenSpot-v2 increased from 41% to 61%**, indicating that explicit reasoning improves GUI grounding performance.
| Configuration (coords / image size) | ScreenSpot-v2 (%) |
| --- | --- |
| Normalized coordinates / 1152 | 61.71 |

*Table 3: ScreenSpot-v2 accuracy of the Phase 1 checkpoint of HuggingFaceTB/SmolVLM2-2.2B-Instruct after Phase 2 training (2 epochs on aguvis-stage-2).*
> 💡 We also reproduced the two-phase training on a much smaller VLM (nanoVLM-460M). Despite its reduced capacity, the model achieved ~58% on ScreenSpot-v2, demonstrating that the training strategy scales down effectively and making it SOTA on ScreenSpot-v2 for this model size (460M parameters). In addition, aguvis-stage-1 is already included in the FineVision dataset!
## 4. All you need is Open Source

All training code, data processing pipelines, datasets, and the model are open-source!

1. **Training Recipe** ([`recipe.ipynb`](https://github.com/huggingface/smol2operator/blob/main/recipe.ipynb)): Complete training pipeline for both Phase 1 and Phase 2, including dataset mixture configurations and training orchestration. We leverage the [TRL](https://huggingface.co/docs/trl/en/index) library to train our models.
2. **Datasets** ([`smolagents/aguvis-stage-1`](https://huggingface.co/datasets/smolagents/aguvis-stage-1), [`smolagents/aguvis-stage-2`](https://huggingface.co/datasets/smolagents/aguvis-stage-2)): all datasets used are open-source.
3. **Model** ([`smolagents/SmolVLM2-2.2B-Instruct-Agentic-GUI`](https://huggingface.co/smolagents/SmolVLM2-2.2B-Instruct-Agentic-GUI)): the model produced by applying the training recipe described above (a minimal inference sketch follows this list).
4. **Preprocessing Tools:**
   - **Function Parser** ([`utils/function_parser.py`](https://github.com/huggingface/smol2operator/blob/main/utils/function_parser.py)): Utilities for parsing, normalizing, and reconstructing function calls from diverse dataset formats. Supports complex parameter structures, positional arguments, and multiple function call extraction.
   - **Action Conversion System** ([`preprocessing/action_conversion.py`](https://github.com/huggingface/smol2operator/blob/main/preprocessing/action_conversion.py)): Core unification engine transforming mobile and PyAutoGUI desktop actions into a standardized API format. Features smart coordinate handling, direction detection for scroll actions, and comprehensive parameter normalization.
   - **Action Space Converter** ([`utils/action_space_converter.py`](https://github.com/huggingface/smol2operator/blob/main/utils/action_space_converter.py)): Flexible tool for adapting the unified action space to custom vocabularies and naming conventions. Enables domain-specific customization through configurable parameter mappings.
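To make the model entry above concrete, here is a minimal, illustrative inference sketch (not taken from the repo) using the standard `transformers` chat-template API. The screenshot path is a placeholder, and the instruction mirrors the Phase 1 data format shown earlier:

```python
# Minimal inference sketch for the released checkpoint (illustrative only;
# replace the screenshot path and instruction with your own).
from transformers import AutoModelForImageTextToText, AutoProcessor

model_id = "smolagents/SmolVLM2-2.2B-Instruct-Agentic-GUI"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(model_id)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "path": "path/to/screenshot.png"},  # placeholder
            {"type": "text", "text": "click on more button"},
        ],
    }
]

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
generated = model.generate(**inputs, max_new_tokens=64)
new_tokens = generated[:, inputs["input_ids"].shape[1]:]
print(processor.batch_decode(new_tokens, skip_special_tokens=True)[0])
# Expected to produce an action such as click(x=0.XXXX, y=0.YYYY)
```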
> 💡 We've also opened a Space to experiment with the POC model's agentic grounding capabilities: A-Mahla/Smol2Operator
## 5. Conclusion

Our experiments demonstrate that high-quality, reasoning-oriented data can substantially improve GUI grounding, even for small VLMs, using only supervised fine-tuning (SFT). Beyond raw performance gains, these results show that the capabilities of a "GUI model" are largely determined by the quality of its data. High-quality, task-specific data for GUI interactions is a critical prerequisite for advancing agentic models: carefully curated datasets teach models the structure and semantics of user interfaces, providing the grounding needed for accurate action prediction.

While SFT excels at supervised tasks, emerging methods such as Reinforcement Learning (RL) and Direct Preference Optimization (DPO) support deeper reasoning and enable dynamic, real-time adaptation. These advances point toward a new generation of GUI agents that learn and improve through interaction rather than relying solely on static datasets.

To support the development of GUI agents, we're open-sourcing everything: our complete pipeline, datasets, and trained model. You can reproduce our results, experiment with different models and architectures, or adapt our approach to new domains. The future of agentic AI depends on researchers like you pushing these boundaries further 🤗