to perform the action. It will be executed in a stateful environment.
25 |
26 | The following functions are exposed to the Python interpreter:
27 |
28 |
29 | # OS ACTIONS
30 |
31 |
32 | def click(x: Optional[float] = None, y: Optional[float] = None) -> str:
33 | """
34 |     Performs a left-click at the specified normalized coordinates.
35 | Args:
36 | x: The x coordinate (horizontal position)
37 | y: The y coordinate (vertical position)
38 | """
39 |
40 |
41 |
42 | The state persists between code executions: if in one step you've created variables or imported modules, these will all persist.'''
43 |
44 | SCREENSPOT_V2_USER_PROMPT_PHASE_2 = """Please generate the next move according to the UI screenshot, instruction and previous actions.
45 |
46 | Instruction: {instruction}
47 |
48 | Previous actions:
49 | None"""
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 |
173 | # Temp folders
174 | data/
175 | wandb/
176 | logs/
177 | eval_results/
178 | results/
179 |
180 | .vscode/
181 | .python-version
--------------------------------------------------------------------------------
/utils/collator.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 |
3 |
4 | def create_collate_fn(processor, max_length: int):
5 | """Optimized collate function for VLM training that masks system prompt tokens."""
6 |
7 | def collate_fn(examples: list[dict[str, list | str | Image.Image]]):
8 | batch_messages: list[list[dict[str, list | str | Image.Image]]] = []
9 | assistant_messages: list[list[str]] = []
10 | all_image_inputs: list[list[Image.Image]] = []
11 | for example in examples:
12 | images: list[Image.Image] = example["images"]
13 | is_first_user = True
14 | sample: list[dict[str, list | str | Image.Image]] = []
15 | assistant: list[str] = []
16 | for text in example["texts"]:
17 | if "system" in text.keys():
18 | sample.append(
19 | {
20 | "role": "system",
21 | "content": [{"type": "text", "text": text["system"]}],
22 | }
23 | )
24 |
25 | if is_first_user:
26 | sample.append(
27 | {
28 | "role": "user",
29 | "content": [
30 | {"type": "image", "image": images[0]},
31 | {"type": "text", "text": text["user"]},
32 | ],
33 | }
34 | )
35 | is_first_user = False
36 | else:
37 | sample.append(
38 | {
39 | "role": "user",
40 | "content": [
41 | {"type": "text", "text": text["user"]},
42 | ],
43 | }
44 | )
45 |
46 | sample.append(
47 | {
48 | "role": "assistant",
49 | "content": [{"type": "text", "text": "\n" + text["assistant"]}],
50 | }
51 | )
52 | assistant.append(text["assistant"])
53 |
54 | batch_messages.append(sample)
55 | assistant_messages.append(assistant)
56 | all_image_inputs.append(images)
57 |
58 | texts = [
59 | processor.apply_chat_template(
60 | messages, tokenize=False, add_generation_prompt=False
61 | )
62 | for messages in batch_messages
63 | ]
64 |
65 | batch = processor(
66 | text=texts,
67 | images=all_image_inputs if all_image_inputs else None,
68 | max_length=max_length,
69 | truncation=True,
70 | padding=True,
71 | return_tensors="pt",
72 | )
73 |
74 | input_ids = batch["input_ids"]
75 | labels = input_ids.clone()
76 |
77 | assistant_encodings = [
78 | processor.tokenizer(
79 |                 [msg + " to perform the action. It will be executed in a stateful environment.
115 |
116 | The following functions are exposed to the Python interpreter:
117 |
118 | {OS_ACTIONS}
119 |
120 |
121 | The state persists between code executions: if in one step you've created variables or imported modules, these will all persist.
122 | """
123 |
124 | MOBILE_SYSTEM_PROMPT = f"""You are a helpful GUI agent. You’ll be given a task and a screenshot of the screen. Complete the task using Python function calls.
125 |
126 | For each step:
127 | • First, to perform the action. It will be executed in a stateful environment.
129 |
130 | The following functions are exposed to the Python interpreter:
131 |
132 |
133 | # OS ACTIONS
134 |
135 | {OS_ACTIONS}
136 |
137 | # MOBILE ACTIONS
138 |
139 | {MOBILE_ACTIONS}
140 |
141 |
142 | The state persists between code executions: if in one step you've created variables or imported modules, these will all persist.
143 | """
--------------------------------------------------------------------------------
/preprocessing/aguvis_models.py:
--------------------------------------------------------------------------------
1 | """
2 | Models and data structures for the AGUVIS dataset processing.
3 | Contains all Pydantic models, configuration classes, and data validation logic.
4 | """
5 |
6 | from typing import List, Optional, Literal, Any, Dict
7 | from pydantic import BaseModel, Field, RootModel, field_validator, model_validator
8 | from PIL import Image
9 | import json
10 | from collections import OrderedDict
11 |
12 |
13 | class ChatMessage(BaseModel):
14 | """Represents a single chat message with role and content."""
15 | role: Literal["user", "assistant", "system"]
16 | content: str
17 |
18 | @staticmethod
19 | def from_conversation_list(data: list[dict[str, str]]) -> list["ChatMessage"]:
20 | """Convert conversation list to ChatMessage objects."""
21 | messages = []
22 | system_added = False
23 | for item in data:
24 | if item["from"] == "system":
25 | if not system_added:
26 | role: Literal["user", "assistant", "system"] = "system"
27 | messages.append(ChatMessage(role=role, content=item["value"]))
28 | system_added = True
29 | elif item["from"] == "human":
30 | role = "user"
31 | messages.append(ChatMessage(role=role, content=item["value"]))
32 | else:
33 | role = "assistant"
34 | messages.append(ChatMessage(role=role, content=item["value"]))
35 |
36 | return messages
37 |
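# Usage sketch (illustrative values, not from the dataset): raw "from"/"value"
# entries become role-tagged ChatMessage objects, and only the first system
# entry is kept.
#
#   raw = [
#       {"from": "system", "value": "You are a helpful GUI agent."},
#       {"from": "human", "value": "Click the OK button."},
#       {"from": "gpt", "value": "click(x=0.5, y=0.5)"},
#   ]
#   roles = [m.role for m in ChatMessage.from_conversation_list(raw)]
#   assert roles == ["system", "user", "assistant"]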
38 |
39 | class ConversationEntry(BaseModel):
40 | """Represents a single entry in a conversation."""
41 | from_: Literal["system", "human", "gpt"] = Field(alias="from")
42 | value: str
43 | recipient: Optional[str] = None
44 | end_turn: Optional[bool] = None
45 |
46 | def to_chat_message(self) -> ChatMessage:
47 | """Convert conversation entry to ChatMessage."""
48 | if self.from_ == "system":
49 | role: Literal["user", "assistant", "system"] = "system"
50 | elif self.from_ == "human":
51 | role = "user"
52 | else:
53 | role = "assistant"
54 | return ChatMessage(role=role, content=self.value)
55 |
56 |
57 | class ConversationData(BaseModel):
58 | """Represents conversation data with associated image and conversation entries."""
59 | image: str
60 | conversations: List[ConversationEntry]
61 | recipient: Optional[str] = None
62 | end_turn: Optional[bool] = None
63 |
64 | @field_validator("image", mode="before")
65 | def validate_image(cls, v):
66 | """Validate and normalize image field."""
67 | if isinstance(v, list):
68 | if len(v) == 1:
69 | return v[0]
70 | elif len(v) == 2:
71 | return v[1]
72 | else:
73 |                 raise ValueError(f"Expected 1 or 2 images, got {len(v)}")
74 | return v
75 |
76 | def to_chat_messages(self) -> list[ChatMessage]:
77 | """Convert all conversation entries to ChatMessage objects."""
78 | return [conversation.to_chat_message() for conversation in self.conversations]
79 |
80 |
81 | class ConversationDataList(RootModel[List[ConversationData]]):
82 | """Root model for a list of conversation data with validation and optional deduplication."""
83 |
84 | @classmethod
85 | def from_json_with_deduplication(cls, json_str: str, deduplicate: bool = True) -> "ConversationDataList":
86 | """Create instance from JSON with deduplication control."""
87 | if deduplicate:
88 | # Use normal validation with deduplication
89 | return cls.model_validate_json(json_str)
90 | else:
91 | data = json.loads(json_str)
92 | conversation_data_list = [ConversationData(**item) for item in data]
93 |
94 | # Create instance directly without triggering model validators
95 | instance = cls.__new__(cls)
96 | instance.__dict__.update({'root': conversation_data_list})
97 | instance.__pydantic_fields_set__ = {'root'}
98 | instance.__pydantic_extra__ = {}
99 |
100 | return instance
101 |
102 | @model_validator(mode="after")
103 | def validate_conversation(self):
104 | """Validate and deduplicate conversations."""
105 | new_conversations: dict[str, List[ConversationData]] = {}
106 |
107 | # merge image duplicates
108 | for conversation in self.root:
109 | if conversation.image not in new_conversations:
110 | new_conversations[conversation.image] = [conversation]
111 | else:
112 | new_conversations[conversation.image].append(conversation)
113 |
114 | # delete text duplicates
115 | duplicates = 0
116 | for data in new_conversations.values():
117 | if isinstance(data, list):
118 | index_to_pop = set()
119 | for i in range(len(data) - 1):
120 | for j in range(i + 1, len(data)):
121 | if [c1.model_dump() for c1 in data[i].conversations] == [c2.model_dump() for c2 in data[j].conversations]:
122 | if j not in index_to_pop:
123 | duplicates += 1
124 | index_to_pop.add(j)
125 | for index in sorted(index_to_pop, reverse=True):
126 | data.pop(index)
127 |
128 | # merge conversations for same images
129 | new_data = []
130 | for data in new_conversations.values():
131 | for i in range(len(data)):
132 | if i == 0:
133 | new_data.append(data[i])
134 | else:
135 | new_data[-1].conversations.extend(data[i].conversations)
136 |
137 | self.root = new_data
138 | return self
139 |
140 |
141 | class DataRow(BaseModel):
142 | """Represents a processed data row with images, texts, and source information."""
143 | images: list[Image.Image]
144 | texts: list[OrderedDict[str, str]]
145 | source: str
146 |
147 | class Config:
148 | arbitrary_types_allowed = True
149 |
150 | @classmethod
151 | def from_chat_messages(cls, messages: list[ChatMessage], image: Image.Image, source: str) -> "DataRow":
152 | """Create DataRow from chat messages and associated image."""
153 | system, user, assistant = None, None, None
154 | have_system = any(message.role == "system" for message in messages)
155 | texts: list[OrderedDict[str, str]] = []
156 | images = [image]
157 | chat_messages: OrderedDict[str, str] = OrderedDict()
158 |
159 | for message in messages:
160 | if message.role == "system":
161 | system = message.content
162 | elif message.role == "user":
163 | user = message.content
164 | elif message.role == "assistant":
165 | assistant = message.content
166 |
167 | if have_system and user is not None and assistant is not None and system is not None:
168 | chat_messages["system"] = system
169 | chat_messages["user"] = user
170 | chat_messages["assistant"] = assistant
171 | texts.append(chat_messages)
172 | chat_messages = OrderedDict()
173 | user, assistant = None, None
174 |
175 | elif not have_system and user is not None and assistant is not None:
176 | chat_messages["user"] = user
177 | chat_messages["assistant"] = assistant
178 | texts.append(chat_messages)
179 | chat_messages = OrderedDict()
180 | user, assistant = None, None
181 |
182 | return cls(images=images, texts=texts, source=source)
183 |
184 | def to_model_dump(self) -> dict:
185 | """Convert to dictionary representation."""
186 | return {
187 | "images": self.images,
188 | "texts": self.texts,
189 | "source": self.source,
190 | }
191 |
192 |
193 | class DatasetConfig(BaseModel):
194 | """Configuration for dataset processing."""
195 | huggingface_repo_id: str
196 | local_path: str
197 | config_dict: List[Dict[str, Any]]
198 | smolagents_repo_id: str
199 | reasoning: bool
200 | deduplicate: bool = False
201 |
202 |
203 | class ProcessingConfig(BaseModel):
204 | """Configuration for processing parameters."""
205 | subset_name: str
206 | json_path: str
207 | images_folder: str
208 |
209 | @classmethod
210 | def from_config_dict(cls, config: Dict[str, Any]) -> "ProcessingConfig":
211 | """Create ProcessingConfig from configuration dictionary."""
212 | subset_name = (
213 | config["json_path"]
214 | .replace(".json", "")
215 | )
216 |
217 | return cls(
218 | subset_name=subset_name,
219 | json_path=config["json_path"],
220 | images_folder=config["images_folder"]
221 | )
222 |
223 |
224 | # Mobile action space files configuration
225 | MOBILE_FILES = [
226 | "android_control.json",
227 | "aitw-l1.json",
228 | "aitw-l2.json",
229 | "aitw-l3.json",
230 | "coat.json",
231 | "amex-l1.json",
232 | "amex-l2.json",
233 | "amex-l3.json",
234 | "gui-odyssey-l1.json",
235 | "gui-odyssey-l2.json",
236 | "gui-odyssey-l3.json",
237 | ]
238 |
239 | # Stage 1 dataset configuration
240 | CONFIG_DICT_STAGE_1 = [
241 | {
242 | "json_path": "guienv.json",
243 | "images_folder": "guienvs/images/",
244 | },
245 | {
246 | "json_path": "omniact.json",
247 | "images_folder": "omniact/images/",
248 | },
249 | {
250 | "json_path": "ricoig16k.json",
251 | "images_folder": "ricoig16k/images/",
252 | },
253 | {
254 | "json_path": "ricosca.json",
255 | "images_folder": "ricosca/images/",
256 | },
257 | {
258 | "json_path": "seeclick.json",
259 | "images_folder": "seeclick/seeclick_web_imgs/",
260 | },
261 | {
262 | "json_path": "webui350k.json",
263 | "images_folder": "webui350k/images/",
264 | },
265 | {
266 | "json_path": "ui_refexp.json",
267 | "images_folder": "ui_refexp/images/",
268 | },
269 | {
270 | "json_path": "widget_captioning.json",
271 | "images_folder": "widget_captioning/images/",
272 | },
273 | ]
274 |
275 | # Stage 2 dataset configuration
276 | CONFIG_DICT_STAGE_2 = [
277 | {
278 | "json_path": "mind2web-l1.json",
279 | "images_folder": "mind2web/",
280 | },
281 | {
282 | "json_path": "mind2web-l2.json",
283 | "images_folder": "mind2web/",
284 | },
285 | {
286 | "json_path": "mind2web-l3.json",
287 | "images_folder": "mind2web/",
288 | },
289 | {
290 | "json_path": "guiact-web-single.json",
291 | "images_folder": "guiact-web-single/images/",
292 | },
293 | {
294 | "json_path": "guiact-web-multi-l1.json",
295 | "images_folder": "guiact-web-multi-v2/images",
296 | },
297 | {
298 | "json_path": "guiact-web-multi-l2.json",
299 | "images_folder": "guiact-web-multi-v2/images",
300 | },
301 | {
302 | "json_path": "guiact-web-multi-l3.json",
303 | "images_folder": "guiact-web-multi-v2/images",
304 | },
305 | {
306 | "json_path": "miniwob-l1.json",
307 | "images_folder": "images",
308 | },
309 | {
310 | "json_path": "miniwob-l2.json",
311 | "images_folder": "images",
312 | },
313 | {
314 | "json_path": "miniwob-l3.json",
315 | "images_folder": "images",
316 | },
317 | {
318 | "json_path": "coat.json",
319 | "images_folder": "coat/images/",
320 | },
321 | {
322 | "json_path": "android_control.json",
323 | "images_folder": "android_control/images/",
324 | },
325 | {
326 | "json_path": "gui-odyssey-l1.json",
327 | "images_folder": "gui-odyssey/images/",
328 | },
329 | {
330 | "json_path": "gui-odyssey-l2.json",
331 | "images_folder": "gui-odyssey/images/",
332 | },
333 | {
334 | "json_path": "gui-odyssey-l3.json",
335 | "images_folder": "gui-odyssey/images/",
336 | },
337 | {
338 | "json_path": "amex-l1.json",
339 | "images_folder": "amex/images/",
340 | },
341 | {
342 | "json_path": "amex-l2.json",
343 | "images_folder": "amex/images/",
344 | },
345 | {
346 | "json_path": "amex-l3.json",
347 | "images_folder": "amex/images/",
348 | },
349 | {
350 | "json_path": "aitw-l1.json",
351 | "images_folder": "aitw-v1/images/",
352 | },
353 | {
354 | "json_path": "aitw-l2.json",
355 | "images_folder": "aitw-v1/images/",
356 | },
357 | {
358 | "json_path": "aitw-l3.json",
359 | "images_folder": "aitw-v1/images/",
360 | },
361 | ]
362 |
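A small usage sketch for the configuration helpers above, using the first stage-2 entry:

cfg = ProcessingConfig.from_config_dict(CONFIG_DICT_STAGE_2[0])
assert cfg.subset_name == "mind2web-l1"   # ".json" suffix stripped
assert cfg.images_folder == "mind2web/"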
--------------------------------------------------------------------------------
/preprocessing/action_conversion.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from utils.function_parser import FunctionCall
3 | from copy import deepcopy
4 |
5 | # Configure logger for this module
6 | logger = logging.getLogger(__name__)
7 |
8 | """
9 | ╔══════════════════════════════════════════════════════════════════════════════════════╗
10 | ║ 🔄 ACTION CONVERSION MAPPINGS 🔄 ║
11 | ║ Transform Aguvis & PyAutoGUI actions to unified API ║
12 | ╚══════════════════════════════════════════════════════════════════════════════════════╝
13 |
14 | 📱 MOBILE ACTIONS (Aguvis → Custom API):
15 | • mobile.home() → navigate_home()
16 | • mobile.open_app(app_name='drupe') → open_app(app_name: str)
17 | • mobile.swipe(from_coord=[x,y], to_coord=[x,y]) → swipe(from_coord: tuple, to_coord: tuple)
18 | • mobile.back() → navigate_back()
19 | • mobile.long_press(x=0.8, y=0.9) → long_press(x: float, y: float)
20 | • mobile.terminate(status='success') → final_answer(answer: str)
21 | • mobile.wait(seconds=3) → wait(seconds: int)
22 |
23 | 💻 DESKTOP ACTIONS (PyAutoGUI → Custom API):
24 | • pyautogui.click(x=0.8, y=0.9) → click(x: float, y: float)
25 | • pyautogui.doubleClick() → double_click()
26 | • pyautogui.rightClick() → right_click()
27 | • pyautogui.hotkey(keys=['ctrl', 'c']) → press(keys: str | list)
28 | • pyautogui.press(keys='enter') → press(keys: str | list)
29 | • pyautogui.moveTo(x=0.04, y=0.4) → move_mouse(x: float, y: float)
30 | • pyautogui.write(message='text') → type(text: str)
31 | • pyautogui.dragTo(from_coord=[x,y], to_coord=[x,y]) → drag(from_coord: tuple, to_coord: tuple)
32 |
33 | 🖱️ SCROLL ACTIONS (Smart Direction Detection):
34 | • pyautogui.scroll(page=-0.1) [negative] → scroll(direction="up", amount=10)
35 | • pyautogui.scroll(page=0.1) [positive] → scroll(direction="down", amount=10)
36 | • pyautogui.hscroll(page=-0.1) [negative] → scroll(direction="left", amount=10)
37 | • pyautogui.hscroll(page=0.1) [positive] → scroll(direction="right", amount=10)
38 |
39 | ✅ COMPLETION ACTIONS:
40 | • answer('text') → final_answer('text')
41 | """
42 |
43 |
44 | def convert_to_pixel_coordinates(action: FunctionCall, resolution: tuple[int, int]) -> None:
45 | """
46 | 🎯 Convert normalized coordinates (0.0-1.0) to absolute pixel coordinates.
47 |
48 | Transforms relative coordinates to screen pixels based on the given resolution.
49 | Handles both single coordinates (x, y) and coordinate pairs (from_coord, to_coord).
50 |
51 | Args:
52 | action: FunctionCall object containing coordinate parameters
53 | resolution: Screen resolution as (width, height) in pixels
54 |
55 | Note: Modifies the action object in-place by updating parameter names and values.
56 | """
57 | if "arg_0" in action.parameters:
58 | if isinstance(action.parameters["arg_0"], (list, tuple)):
59 | action.parameters["from_coord"] = (int(action.parameters["arg_0"][0] * resolution[0]), int(action.parameters["arg_0"][1] * resolution[1]))
60 | else:
61 | action.parameters["x"] = int(action.parameters["arg_0"] * resolution[0])
62 | del action.parameters["arg_0"]
63 | if "arg_1" in action.parameters:
64 | if isinstance(action.parameters["arg_1"], (list, tuple)):
65 | action.parameters["to_coord"] = (int(action.parameters["arg_1"][0] * resolution[0]), int(action.parameters["arg_1"][1] * resolution[1]))
66 | else:
67 | action.parameters["y"] = int(action.parameters["arg_1"] * resolution[1])
68 | del action.parameters["arg_1"]
69 |
70 | def change_argument_name(action: FunctionCall) -> None:
71 | """
72 | 🔄 Transform generic argument names to semantic parameter names.
73 |
74 | Converts arg_0, arg_1 to meaningful names like 'x', 'y', 'from_coord', 'to_coord'.
75 | Maintains coordinate values as floats for normalized coordinate system.
76 |
77 | Args:
78 | action: FunctionCall object with generic argument names
79 |
80 | Note: Modifies the action object in-place, preserving original coordinate values.
81 | """
82 | if "arg_0" in action.parameters:
83 | if isinstance(action.parameters["arg_0"], (list, tuple)):
84 | action.parameters["from_coord"] = (float(action.parameters["arg_0"][0]), float(action.parameters["arg_0"][1]))
85 | else:
86 | action.parameters["x"] = float(action.parameters["arg_0"])
87 | del action.parameters["arg_0"]
88 | if "arg_1" in action.parameters:
89 | if isinstance(action.parameters["arg_1"], (list, tuple)):
90 | action.parameters["to_coord"] = (float(action.parameters["arg_1"][0]), float(action.parameters["arg_1"][1]))
91 | else:
92 | action.parameters["y"] = float(action.parameters["arg_1"])
93 | del action.parameters["arg_1"]
94 |
95 |
96 | def rename_parameters(action: FunctionCall) -> None:
97 | """
98 | 🏷️ Standardize parameter names to arg_0, arg_1, arg_2 format.
99 |
100 | Converts named parameters to a generic indexed format while preserving
101 | the original parameter order. This creates a uniform interface for
102 | subsequent processing steps.
103 |
104 | Args:
105 | action: FunctionCall object to standardize parameter names for
106 |
107 | Example:
108 | Before: {"x": 0.5, "y": 0.8} → After: {"arg_0": 0.5, "arg_1": 0.8}
109 | """
110 | if not action.parameters:
111 | return
112 |
113 |     # Rebuild in one pass so an existing "arg_N" key is never overwritten mid-loop
114 |     new_parameters = {f"arg_{i}": value for i, value in enumerate(deepcopy(action.parameters).values())}
115 |     action.parameters.clear()
116 |     action.parameters.update(new_parameters)
117 |
118 |
119 |
120 | def action_conversion(
121 | actions: list[FunctionCall], resolution: tuple[int, int]
122 | ) -> list[FunctionCall]:
123 | """
124 | 🚀 Master conversion function: Transform diverse action formats into unified API.
125 |
126 | This is the main orchestrator that converts actions from different sources
127 | (Aguvis mobile actions, PyAutoGUI desktop actions) into a standardized
128 | action format for consistent processing.
129 |
130 | Args:
131 | actions: List of FunctionCall objects to convert
132 | resolution: Screen resolution (width, height) for coordinate conversion
133 |
134 | Returns:
135 | List of converted FunctionCall objects with unified naming and structure
136 |
137 | Features:
138 | • 📱 Mobile action normalization (Aguvis → Custom API)
139 | • 💻 Desktop action standardization (PyAutoGUI → Custom API)
140 | • 🎯 Smart coordinate handling (relative ↔ absolute)
141 | • 🖱️ Intelligent scroll direction detection
142 | • ✅ Consistent error handling and validation
143 | """
144 | for i, action in enumerate(actions):
145 | rename_parameters(action)
146 |
147 | # ═══════════════════════════════════════════════════════════════
148 | # 📱 MOBILE ACTIONS (Aguvis Framework)
149 | # ═══════════════════════════════════════════════════════════════
150 | if action.function_name == "mobile.home":
151 | actions[i].function_name = "navigate_home"
152 |
153 | elif action.function_name == "mobile.open_app":
154 | actions[i].function_name = "open_app"
155 |
156 | elif action.function_name == "mobile.swipe":
157 | actions[i].function_name = "swipe"
158 | change_argument_name(actions[i])
159 |
160 | elif action.function_name == "mobile.back":
161 | actions[i].function_name = "navigate_back"
162 |
163 | elif action.function_name == "mobile.long_press":
164 | actions[i].function_name = "long_press"
165 | change_argument_name(actions[i])
166 |
167 | elif action.function_name in ["mobile.terminate", "answer"]:
168 | actions[i].function_name = "final_answer"
169 |
170 | elif action.function_name == "mobile.wait":
171 | actions[i].function_name = "wait"
172 | if "arg_0" in actions[i].parameters:
173 | actions[i].parameters["seconds"] = int(actions[i].parameters["arg_0"])
174 | del actions[i].parameters["arg_0"]
175 |
176 | # ═══════════════════════════════════════════════════════════════
177 | # 💻 DESKTOP ACTIONS (PyAutoGUI Framework)
178 | # ═══════════════════════════════════════════════════════════════
179 | elif action.function_name == "pyautogui.click":
180 | actions[i].function_name = "click"
181 | change_argument_name(actions[i])
182 |
183 | elif action.function_name == "pyautogui.doubleClick":
184 | actions[i].function_name = "double_click"
185 | change_argument_name(actions[i])
186 |
187 | elif action.function_name == "pyautogui.rightClick":
188 | actions[i].function_name = "right_click"
189 | change_argument_name(actions[i])
190 |
191 | elif action.function_name in ["pyautogui.hotkey", "pyautogui.press"]:
192 | actions[i].function_name = "press"
193 | if "arg_0" in actions[i].parameters:
194 | actions[i].parameters["keys"] = actions[i].parameters["arg_0"]
195 | del actions[i].parameters["arg_0"]
196 |
197 | elif action.function_name == "pyautogui.moveTo":
198 | actions[i].function_name = "move_mouse"
199 | change_argument_name(actions[i])
200 |
201 | elif action.function_name == "pyautogui.write":
202 | actions[i].function_name = "type"
203 |
204 | # ──────────────────────────────────────────────────────────────
205 | # 🖱️ SCROLL ACTIONS (Direction Detection)
206 | # ──────────────────────────────────────────────────────────────
207 | elif action.function_name in ["pyautogui.scroll", "pyautogui.hscroll"]:
208 | arg_value = actions[i].parameters["arg_0"]
209 | if arg_value < 0:
210 | if action.function_name == "pyautogui.hscroll":
211 | actions[i].parameters["direction"] = "left"
212 | else:
213 | actions[i].parameters["direction"] = "up"
214 | else:
215 | if action.function_name == "pyautogui.hscroll":
216 | actions[i].parameters["direction"] = "right"
217 | else:
218 | actions[i].parameters["direction"] = "down"
219 | del actions[i].parameters["arg_0"]
220 | actions[i].function_name = "scroll"
221 | actions[i].parameters["amount"] = int(abs(arg_value * 100))
222 |
223 | elif action.function_name == "pyautogui.dragTo":
224 | actions[i].function_name = "drag"
225 | change_argument_name(actions[i])
226 |
227 | else:
228 | raise ValueError(f"🚫 Unsupported action: {action.function_name}")
229 |
230 | # 💾 Preserve original string representation for debugging
231 | actions[i].original_string = actions[i].to_string()
232 |
233 | return actions
234 |
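# Worked example of the scroll branch in action_conversion above (sketch):
#   pyautogui.scroll(page=0.13)   ->  scroll(direction='down', amount=13)   # int(abs(0.13 * 100))
#   pyautogui.hscroll(page=-0.1)  ->  scroll(direction='left', amount=10)   # int(abs(-0.1 * 100))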
235 | if __name__ == "__main__":
236 |     logging.basicConfig(level=logging.INFO)  # make the demo's logger output visible
237 |
238 | """
239 | ╔════════════════════════════════════════════════════════════════════════════════╗
240 | ║ 🧪 TESTING & DEMONSTRATION 🧪 ║
241 | ║ Comprehensive test suite for action conversion ║
242 | ╚════════════════════════════════════════════════════════════════════════════════╝
243 | """
244 |
245 | # 📋 Complete test dataset covering all supported action types
246 | actions = [
247 | # ═══════════════════════════════════════════════════════════════
248 | # 📱 MOBILE ACTIONS (Aguvis Framework)
249 | # ═══════════════════════════════════════════════════════════════
250 |         FunctionCall(function_name="mobile.home", parameters={}, original_string="mobile.home()"),
251 |         FunctionCall(function_name="mobile.open_app", parameters={"app_name": "drupe"}, original_string="mobile.open_app(app_name='drupe')"),
252 |         FunctionCall(function_name="mobile.swipe", parameters={"from_coord": [0.581, 0.898], "to_coord": [0.601, 0.518]}, original_string="mobile.swipe(from_coord=[0.581,0.898],to_coord=[0.601,0.518])"),
253 |         FunctionCall(function_name="mobile.back", parameters={}, original_string="mobile.back()"),
254 |         FunctionCall(function_name="mobile.long_press", parameters={"x": 0.799, "y": 0.911}, original_string="mobile.long_press(x=0.799, y=0.911)"),
255 |         FunctionCall(function_name="mobile.terminate", parameters={"status": "success"}, original_string="mobile.terminate(status='success')"),
256 |         FunctionCall(function_name="answer", parameters={"arg_0": "text"}, original_string="answer('text')"),
257 |         FunctionCall(function_name="mobile.wait", parameters={"seconds": 3}, original_string="mobile.wait(seconds=3)"),
258 | 
259 |         # ═══════════════════════════════════════════════════════════════
260 |         # 💻 DESKTOP ACTIONS (PyAutoGUI Framework)
261 |         # ═══════════════════════════════════════════════════════════════
262 |         FunctionCall(function_name="pyautogui.hscroll", parameters={"page": -0.1}, original_string="pyautogui.hscroll(page=-0.1)"),
263 |         FunctionCall(function_name="pyautogui.scroll", parameters={"page": 0.13}, original_string="pyautogui.scroll(page=0.13)"),
264 |         FunctionCall(function_name="pyautogui.click", parameters={"x": 0.8102, "y": 0.9463}, original_string="pyautogui.click(x=0.8102, y=0.9463)"),
265 |         FunctionCall(function_name="pyautogui.doubleClick", parameters={}, original_string="pyautogui.doubleClick()"),
266 |         FunctionCall(function_name="pyautogui.hotkey", parameters={"keys": ["ctrl", "c"]}, original_string="pyautogui.hotkey(keys=['ctrl','c'])"),
267 |         FunctionCall(function_name="pyautogui.press", parameters={"keys": "enter"}, original_string="pyautogui.press(keys='enter')"),
268 |         FunctionCall(function_name="pyautogui.moveTo", parameters={"x": 0.04, "y": 0.405}, original_string="pyautogui.moveTo(x=0.04, y=0.405)"),
269 |         FunctionCall(function_name="pyautogui.write", parameters={"message": "bread buns"}, original_string="pyautogui.write(message='bread buns')"),
270 |         FunctionCall(function_name="pyautogui.dragTo", parameters={"from_coord": [0.87, 0.423], "to_coord": [0.8102, 0.9463]}, original_string="pyautogui.dragTo(from_coord=[0.87, 0.423], to_coord=[0.8102, 0.9463])"),
271 | ]
272 |
273 | # 🖥️ Test resolution (Full HD Portrait - typical mobile orientation)
274 | resolution = (1080, 1920)
275 |
276 | logger.info("🔄 BEFORE CONVERSION:")
277 | logger.info("═" * 50)
278 | for action in actions:
279 | logger.info(f" 📋 {action}")
280 |
281 | logger.info(f"\n🚀 AFTER CONVERSION (Resolution: {resolution}):")
282 | logger.info("═" * 50)
283 | converted = action_conversion(actions, resolution)
284 | for action in converted:
285 | logger.info(f" ✅ {action}")
286 |
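`convert_to_pixel_coordinates` is defined above but not exercised by this demo; a short worked sketch of its effect (values follow from the function as written):

action = FunctionCall(
    function_name="click",
    parameters={"arg_0": 0.8102, "arg_1": 0.9463},
    original_string="click(0.8102, 0.9463)",
)
convert_to_pixel_coordinates(action, (1080, 1920))
assert action.parameters == {"x": 875, "y": 1816}  # int(0.8102 * 1080), int(0.9463 * 1920)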
--------------------------------------------------------------------------------
/utils/action_space_converter.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Flexible Action Space Converter
4 |
5 | A configurable system that allows users to define custom action space mappings
6 | for transforming unified API actions to their own custom action formats.
7 | This enables users to create domain-specific action spaces for training
8 | assistants with different action vocabularies.
9 | """
10 |
11 | from __future__ import annotations
12 | import logging
13 | from typing import List, Callable, Any, Optional
14 | from copy import deepcopy
15 | from utils.function_parser import FunctionCall
16 | from pydantic import BaseModel
17 |
18 | logging.basicConfig(level=logging.INFO)
19 |
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 |
24 | class ParameterMapping(BaseModel):
25 | """Defines how to map parameters from source to target function."""
26 | source_name: str
27 | target_name: str
28 | transform: Optional[Callable[[Any], Any]] = None
29 | default_value: Optional[Any] = None
30 |
31 |
32 | class ActionMapping(BaseModel):
33 | """Defines how to convert one action to another."""
34 | source_function: str
35 | target_function: str
36 | parameter_mappings: List[ParameterMapping]
37 | custom_transform: Optional[Callable[[FunctionCall], FunctionCall]] = None
38 | description: str = ""
39 |
40 |
41 | class ActionSpaceConverter:
42 | """
43 | Flexible Action Space Converter
44 |
45 | Allows users to define custom action space mappings to transform
46 | the unified API actions into their own custom action formats.
47 | """
48 |
49 | def __init__(self, mappings: List[ActionMapping]):
50 | """
51 | Initialize the converter with action mappings.
52 |
53 | Args:
54 | mappings: List of ActionMapping objects defining the conversions
55 | """
56 | self.mappings = {mapping.source_function: mapping for mapping in mappings}
57 | self._validate_mappings()
58 |
59 | def _validate_mappings(self) -> None:
60 | """Validate that all mappings are properly configured."""
61 | for source_func, mapping in self.mappings.items():
62 | if not mapping.target_function:
63 | raise ValueError(f"Target function not specified for '{source_func}'")
64 |
65 | # Check for duplicate parameter targets
66 | target_params = [pm.target_name for pm in mapping.parameter_mappings]
67 | if len(target_params) != len(set(target_params)):
68 | raise ValueError(f"Duplicate target parameter names in mapping for '{source_func}'")
69 |
70 | def convert_actions(self, actions: List[FunctionCall]) -> List[FunctionCall]:
71 | """
72 | Convert a list of actions using the defined mappings.
73 |
74 | Args:
75 | actions: List of FunctionCall objects to convert
76 |
77 | Returns:
78 | List of converted FunctionCall objects
79 |
80 | Raises:
81 | ValueError: If an unsupported action is encountered
82 | """
83 | converted_actions = []
84 |
85 | for action in actions:
86 | try:
87 | converted_action = self._convert_single_action(action)
88 | converted_actions.append(converted_action)
89 | except Exception as e:
90 | logger.error(f"Failed to convert action '{action.function_name}': {e}")
91 | raise
92 |
93 | return converted_actions
94 |
95 | def _convert_single_action(self, action: FunctionCall) -> FunctionCall:
96 | """
97 | Convert a single action using the mapping.
98 |
99 | Args:
100 | action: FunctionCall object to convert
101 |
102 | Returns:
103 | Converted FunctionCall object
104 | """
105 | if action.function_name not in self.mappings:
106 | raise ValueError(f"Unsupported action: {action.function_name}")
107 |
108 | mapping = self.mappings[action.function_name]
109 |
110 | if mapping.custom_transform:
111 | return mapping.custom_transform(deepcopy(action))
112 |
113 | new_parameters = {}
114 |
115 | for param_mapping in mapping.parameter_mappings:
116 | source_value = action.parameters.get(param_mapping.source_name, param_mapping.default_value)
117 |
118 | if source_value is None and param_mapping.default_value is None:
119 | # Skip missing optional parameters
120 | continue
121 |
122 | if param_mapping.transform:
123 | source_value = param_mapping.transform(source_value)
124 |
125 | new_parameters[param_mapping.target_name] = source_value
126 |
127 | # Create new action
128 | new_action = FunctionCall(
129 | function_name=mapping.target_function,
130 | parameters=new_parameters,
131 | original_string="",
132 | description=mapping.description
133 | )
134 |
135 | # Update the original string representation
136 | new_action.original_string = new_action.to_string()
137 |
138 | return new_action
139 |
140 | def add_mapping(self, mapping: ActionMapping) -> None:
141 | """Add a new action mapping to the converter."""
142 | self.mappings[mapping.source_function] = mapping
143 | self._validate_mappings()
144 |
145 | def remove_mapping(self, source_function: str) -> None:
146 | """Remove an action mapping from the converter."""
147 | if source_function in self.mappings:
148 | del self.mappings[source_function]
149 |
150 | def get_supported_actions(self) -> List[str]:
151 | """Get list of supported source actions."""
152 | return list(self.mappings.keys())
153 |
154 | def get_mapping_info(self, source_function: str) -> Optional[ActionMapping]:
155 | """Get mapping information for a specific source function."""
156 | return self.mappings.get(source_function)
157 |
158 |
159 | def create_default_unified_to_custom_converter() -> ActionSpaceConverter:
160 | """
161 | 🏭 Factory function to create a converter from unified API to a custom action space.
162 |
163 | This demonstrates how to create custom action space mappings.
164 | Users can modify this or create their own mapping configurations.
165 |
166 | Returns:
167 | ActionSpaceConverter configured with example custom mappings
168 | """
169 |
170 | # Example custom action space mappings
171 | mappings = [
172 | # Navigation actions
173 | ActionMapping(
174 | source_function="navigate_home",
175 | target_function="go_home",
176 | parameter_mappings=[],
177 | description="Navigate to home screen"
178 | ),
179 |
180 | ActionMapping(
181 | source_function="navigate_back",
182 | target_function="go_back",
183 | parameter_mappings=[],
184 | description="Navigate back"
185 | ),
186 |
187 | # App interaction
188 | ActionMapping(
189 | source_function="open_app",
190 | target_function="launch_application",
191 | parameter_mappings=[
192 | ParameterMapping(source_name="arg_0", target_name="application_name")
193 | ],
194 | description="Launch an application"
195 | ),
196 |
197 | # Touch interactions
198 | ActionMapping(
199 | source_function="click",
200 | target_function="touch",
201 | parameter_mappings=[
202 | ParameterMapping(source_name="x", target_name="x_coord"),
203 | ParameterMapping(source_name="y", target_name="y_coord")
204 | ],
205 | description="Touch screen at coordinates"
206 | ),
207 |
208 | ActionMapping(
209 | source_function="long_press",
210 | target_function="long_touch",
211 | parameter_mappings=[
212 | ParameterMapping(source_name="x", target_name="x_coord"),
213 | ParameterMapping(source_name="y", target_name="y_coord"),
214 | ParameterMapping(source_name="duration", target_name="hold_time", default_value=1.0)
215 | ],
216 | description="Long touch screen at coordinates"
217 | ),
218 |
219 | # Gesture actions
220 | ActionMapping(
221 | source_function="swipe",
222 | target_function="gesture_swipe",
223 | parameter_mappings=[
224 | ParameterMapping(source_name="from_coord", target_name="start_point"),
225 | ParameterMapping(source_name="to_coord", target_name="end_point"),
226 | ],
227 | description="Swipe gesture between two points"
228 | ),
229 |
230 | # Scroll actions with custom direction mapping
231 | ActionMapping(
232 | source_function="scroll",
233 | target_function="scroll_view",
234 | parameter_mappings=[
235 | ParameterMapping(
236 | source_name="direction",
237 | target_name="scroll_direction",
238 | transform=lambda x: {"up": "north", "down": "south", "left": "west", "right": "east"}.get(x, x)
239 | ),
240 | ParameterMapping(source_name="amount", target_name="scroll_distance")
241 | ],
242 | description="Scroll view in specified direction"
243 | ),
244 |
245 | # Input actions
246 | ActionMapping(
247 | source_function="type",
248 | target_function="input_text",
249 | parameter_mappings=[
250 | ParameterMapping(source_name="text", target_name="content"),
251 | ],
252 | description="Input text"
253 | ),
254 |
255 | ActionMapping(
256 | source_function="press",
257 | target_function="key_press",
258 | parameter_mappings=[
259 | ParameterMapping(source_name="keys", target_name="key_combination")
260 | ],
261 | description="Press key combination"
262 | ),
263 |
264 | # Mouse actions
265 | ActionMapping(
266 | source_function="move_mouse",
267 | target_function="cursor_move",
268 | parameter_mappings=[
269 | ParameterMapping(source_name="x", target_name="x_position"),
270 | ParameterMapping(source_name="y", target_name="y_position")
271 | ],
272 | description="Move cursor to position"
273 | ),
274 |
275 | ActionMapping(
276 | source_function="double_click",
277 | target_function="double_touch",
278 | parameter_mappings=[
279 | ParameterMapping(source_name="x", target_name="x_coord", default_value=0.5),
280 | ParameterMapping(source_name="y", target_name="y_coord", default_value=0.5)
281 | ],
282 | description="Double touch at coordinates"
283 | ),
284 |
285 | ActionMapping(
286 | source_function="right_click",
287 | target_function="context_menu",
288 | parameter_mappings=[
289 | ParameterMapping(source_name="x", target_name="x_coord", default_value=0.5),
290 | ParameterMapping(source_name="y", target_name="y_coord", default_value=0.5)
291 | ],
292 | description="Open context menu"
293 | ),
294 |
295 | ActionMapping(
296 | source_function="drag",
297 | target_function="drag_and_drop",
298 | parameter_mappings=[
299 | ParameterMapping(source_name="from_coord", target_name="start_position"),
300 | ParameterMapping(source_name="to_coord", target_name="end_position")
301 | ],
302 | description="Drag and drop operation"
303 | ),
304 |
305 | # Timing and completion
306 | ActionMapping(
307 | source_function="wait",
308 | target_function="pause",
309 | parameter_mappings=[
310 | ParameterMapping(source_name="seconds", target_name="duration")
311 | ],
312 | description="Pause execution for specified duration"
313 | ),
314 |
315 | ActionMapping(
316 | source_function="final_answer",
317 | target_function="complete_task",
318 | parameter_mappings=[
319 | ParameterMapping(source_name="arg_0", target_name="answer")
320 | ],
321 | description="Complete task with result"
322 | ),
323 | ]
324 |
325 | return ActionSpaceConverter(mappings)
326 |
327 | def convert_assistant(chat_message: dict, converter: ActionSpaceConverter) -> dict:
328 | """
329 |     Convert function calls in assistant messages to the target action space.
330 | 
331 |     Args:
332 |         chat_message: Dictionary with format {"user": "...", "assistant": "..."}
333 |         converter: ActionSpaceConverter that maps each call to the target space
334 |     Returns:
335 |         Updated chat message with function calls rewritten in the target action space
336 | """
337 | from utils.function_parser import parse_function_call
338 |
339 | if "assistant" not in chat_message:
340 | return chat_message
341 |
342 | assistant_message = chat_message["assistant"]
343 |
344 | # Parse function calls from the assistant message
345 | old_function_calls = parse_function_call(assistant_message)
346 | new_function_calls = converter.convert_actions(old_function_calls)
347 |
348 | # Replace each function call with its sentence format
349 | updated_message = assistant_message
350 | for new_function_call, old_function_call in zip(new_function_calls, old_function_calls):
351 | updated_message = updated_message.replace(old_function_call.to_string(), new_function_call.to_string())
352 |
353 | # Return updated chat message
354 | chat_message["assistant"] = updated_message
355 | return chat_message
356 |
357 |
358 | # Testing and demonstration
359 | if __name__ == "__main__":
360 |
361 | logger.info("🧪 Testing Action Space Converter")
362 | logger.info("=" * 50)
363 |
364 | # Test with default custom converter
365 | converter = create_default_unified_to_custom_converter()
366 |
367 | chat_history = [
368 | {"user": "Click on the home button", "assistant": "I'll click on the home button for you. navigate_home()"},
369 | {"user": "Type hello world", "assistant": "I'll type that text for you. type(text='hello world')"},
370 | {"user": "Click at coordinates 0.5, 0.8", "assistant": "I'll click at those coordinates. click(x=0.5, y=0.8)"},
371 | {"user": "Scroll up by 10 units", "assistant": "I'll scroll up for you. scroll(direction='up', amount=10)"}
372 | ]
373 |
374 | # Test the function with chat history
375 | logger.info("🧪 Testing Chat Message Function Call Conversion")
376 | logger.info("=" * 60)
377 |
378 | for i, chat_msg in enumerate(chat_history):
379 | logger.info(f"\n📩 Chat Message {i+1}:")
380 | logger.info(f" User: {chat_msg['user']}")
381 | logger.info(f" Original Assistant: {chat_msg['assistant']}")
382 |
383 | converted_msg = convert_assistant(chat_msg, converter)
384 | logger.info(f" Converted Assistant: {converted_msg['assistant']}")
--------------------------------------------------------------------------------
/utils/function_parser.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Function parser for extracting function names, parameter names, and values from string function calls.
4 | Supports both mobile and pyautogui function patterns.
5 | """
6 |
7 | import re
8 | from typing import Dict, List, Tuple, Any
9 | from collections import OrderedDict
10 | from pydantic import BaseModel
11 |
12 | class FunctionCall(BaseModel):
13 | """Represents a parsed function call with its parameters."""
14 | function_name: str
15 | parameters: Dict[str, Any]
16 | original_string: str
17 | description: str = ""
18 |
19 | def to_string(self) -> str:
20 | """
21 | Reconstruct the function call string from the parsed data.
22 |
23 | Returns:
24 | String representation of the function call
25 |
26 | Examples:
27 |             >>> call = FunctionCall(function_name="mobile.wait", parameters={"seconds": 3}, original_string="mobile.wait(seconds=3)")
28 |             >>> call.to_string()
29 |             'mobile.wait(seconds=3)'
30 | 
31 |             >>> call = FunctionCall(function_name="function", parameters={"arg_0": 1, "arg_1": 2, "x": 0.5}, original_string="function(1, 2, x=0.5)")
32 |             >>> call.to_string()
33 |             'function(1, 2, x=0.5)'
34 | """
35 | if not self.parameters:
36 | return f"{self.function_name}()"
37 |
38 | # Separate positional and named arguments
39 | positional_args = []
40 | named_args = []
41 |
42 | for name, value in self.parameters.items():
43 | if name.startswith("arg_"):
44 | # Positional argument
45 | positional_args.append((int(name.split("_")[1]), value))
46 | else:
47 | # kwargs
48 | named_args.append((name, value))
49 |
50 | # Sort positional arguments by index
51 | positional_args.sort(key=lambda x: x[0])
52 |
53 | # Build parameter string
54 | param_parts = []
55 |
56 | # Add positional arguments
57 | for _, value in positional_args:
58 | param_parts.append(self._value_to_string(value))
59 |
60 | # Add named arguments
61 | for name, value in named_args:
62 | param_parts.append(f"{name}={self._value_to_string(value)}")
63 |
64 | return f"{self.function_name}({', '.join(param_parts)})"
65 |
66 | def _value_to_string(self, value: Any) -> str:
67 | """
68 | Convert a value to its string representation for function calls.
69 |
70 | Args:
71 | value: The value to convert
72 |
73 | Returns:
74 | String representation of the value
75 | """
76 | if isinstance(value, str):
77 | # Quote strings
78 | return f"'{value}'"
79 | elif isinstance(value, (list, tuple)):
80 | # Convert lists/tuples to string representation
81 | items = [self._value_to_string(item) for item in value]
82 | return f"[{', '.join(items)}]"
83 | elif isinstance(value, dict):
84 | # Convert dictionaries to string representation
85 | items = [f"'{k}': {self._value_to_string(v)}" for k, v in value.items()]
86 | return f"{{{', '.join(items)}}}"
87 | elif isinstance(value, bool):
88 | # Convert booleans to lowercase
89 | return str(value).lower()
90 | elif value is None:
91 | return "None"
92 | else:
93 | # Numbers and other types
94 | return str(value)
95 |
96 |
97 | def parse_function_call(function_string: str, pattern_to_match: list[str] = []) -> List[FunctionCall]:
98 | """
99 | Parse a function call string and extract all function calls found.
100 |
101 | Args:
102 | function_string: String representation of function calls
103 |         pattern_to_match: Optional substrings; when given, only calls whose name contains one of them are kept
104 | Returns:
105 | List of FunctionCall objects with parsed information
106 |
107 | Examples:
108 | >>> parse_function_call("mobile.wait(seconds=3)")
109 |         [FunctionCall(function_name='mobile.wait', parameters={'seconds': 3}, ...)]
110 |
111 | >>> parse_function_call("mobile. wait(seconds=3)")
112 | [FunctionCall(function_name='wait', parameters={'seconds': 3}, ...)]
113 |
114 | >>> parse_function_call("mobile.wait(seconds=3) mobile.home()")
115 |         [FunctionCall(function_name='mobile.wait', parameters={'seconds': 3}, ...), FunctionCall(function_name='mobile.home', parameters={}, ...)]
116 | """
117 | # Remove any leading/trailing whitespace
118 | function_string = function_string.strip()
119 |
120 | # Pattern to match function calls with parameters
121 | # Matches: function_name(param1=value1, param2=value2, ...)
122 | # Can have any characters before the function call, extracts just the function name
123 | pattern = r'.*?([a-zA-Z_][a-zA-Z0-9_.]*)\(([^)]*)\)'
124 |
125 | matches = re.findall(pattern, function_string)
126 | if not matches:
127 | # No valid function calls found in: {function_string}
128 | return []
129 |
130 | results = []
131 | for match in matches:
132 | function_name = match[0]
133 | params_string = match[1]
134 |
135 | if pattern_to_match and all(pattern not in function_name for pattern in pattern_to_match):
136 | continue
137 |
138 | # Parse parameters
139 | parameters = parse_parameters(params_string)
140 |
141 | # Create the original string for this specific function call
142 | original_string = f"{function_name}({params_string})"
143 |
144 | results.append(FunctionCall(
145 | function_name=function_name,
146 | parameters=parameters,
147 | original_string=original_string
148 | ))
149 |
150 | return results
151 |
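# Round-trip sketch: parse a call string, then reconstruct it with to_string().
#
#   calls = parse_function_call("pyautogui.click(x=0.8102, y=0.9463)")
#   assert calls[0].function_name == "pyautogui.click"
#   assert calls[0].parameters == {"x": 0.8102, "y": 0.9463}
#   assert calls[0].to_string() == "pyautogui.click(x=0.8102, y=0.9463)"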
152 |
153 | def parse_parameters(params_string: str) -> Dict[str, Any]:
154 | """
155 | Parse parameter string and extract parameter names and values.
156 |
157 | Args:
158 | params_string: String containing parameters (e.g., "x=0.5, y=0.6, text='hello'")
159 |
160 | Returns:
161 | Dictionary mapping parameter names to their values
162 |
163 | Examples:
164 | >>> parse_parameters("x=0.5, y=0.6")
165 | {'x': 0.5, 'y': 0.6}
166 |
167 | >>> parse_parameters("app_name='drupe'")
168 | {'app_name': 'drupe'}
169 |
170 | >>> parse_parameters("'text'")
171 | {'arg_0': 'text'}
172 |
173 | >>> parse_parameters("1, 3, 4")
174 | {'arg_0': 1, 'arg_1': 3, 'arg_2': 4}
175 |
176 | >>> parse_parameters("arg1, arg2, x=0.5")
177 | {'arg_0': 'arg1', 'arg_1': 'arg2', 'x': 0.5}
178 | """
179 | if not params_string.strip():
180 | return {}
181 |
182 | parameters = OrderedDict()
183 |
184 | # Split by commas, but be careful with commas inside quotes or brackets
185 | param_parts = split_parameters(params_string)
186 |
187 | positional_index = 0
188 |
189 | for part in param_parts:
190 | part = part.strip()
191 | if not part:
192 | continue
193 |
194 | # Parse individual parameter
195 | name, value = parse_single_parameter(part)
196 |
197 | # For positional arguments, use index-based naming
198 | if name.startswith("arg_"):
199 | name = f"arg_{positional_index}"
200 | positional_index += 1
201 |
202 | parameters[name] = value
203 |
204 | return parameters
205 |
206 |
207 | def split_parameters(params_string: str) -> List[str]:
208 | """
209 | Split parameter string by commas, respecting quotes and brackets.
210 |
211 | Args:
212 | params_string: String containing parameters
213 |
214 | Returns:
215 | List of individual parameter strings
216 | """
217 | parts = []
218 | current_part = ""
219 | paren_count = 0
220 | bracket_count = 0
221 | brace_count = 0
222 | in_quotes = False
223 | quote_char = None
224 |
225 | for char in params_string:
226 | if char in ['"', "'"] and (not in_quotes or char == quote_char):
227 | if not in_quotes:
228 | in_quotes = True
229 | quote_char = char
230 | else:
231 | in_quotes = False
232 | quote_char = None
233 | elif not in_quotes:
234 | if char == '(':
235 | paren_count += 1
236 | elif char == ')':
237 | paren_count -= 1
238 | elif char == '[':
239 | bracket_count += 1
240 | elif char == ']':
241 | bracket_count -= 1
242 | elif char == '{':
243 | brace_count += 1
244 | elif char == '}':
245 | brace_count -= 1
246 | elif char == ',' and paren_count == 0 and bracket_count == 0 and brace_count == 0:
247 | parts.append(current_part.strip())
248 | current_part = ""
249 | continue
250 |
251 | current_part += char
252 |
253 | if current_part.strip():
254 | parts.append(current_part.strip())
255 |
256 | return parts
257 |
258 |
259 | def parse_single_parameter(param_string: str) -> Tuple[str, Any]:
260 | """
261 | Parse a single parameter string into name and value.
262 |
263 | Args:
264 | param_string: String like "x=0.5" or "app_name='drupe'" or just "value"
265 |
266 | Returns:
267 | Tuple of (parameter_name, parameter_value)
268 |
269 | Examples:
270 | >>> parse_single_parameter("x=0.5")
271 | ('x', 0.5)
272 |
273 | >>> parse_single_parameter("app_name='drupe'")
274 | ('app_name', 'drupe')
275 |
276 | >>> parse_single_parameter("'text'")
277 | ('arg_0', 'text')
278 |
279 | >>> parse_single_parameter("123")
280 | ('arg_0', 123)
281 |
282 | >>> parse_single_parameter("3")
283 | ('arg_0', 3)
284 | """
285 | # Pattern to match parameter name and value
286 | pattern = r'^([a-zA-Z_][a-zA-Z0-9_]*)\s*=\s*(.+)$'
287 |
288 | match = re.match(pattern, param_string)
289 | if match:
290 | # Named parameter
291 | param_name = match.group(1)
292 | param_value_str = match.group(2).strip()
293 | param_value = parse_value(param_value_str)
294 | return param_name, param_value
295 | else:
296 | # Positional parameter - treat as unnamed argument
297 | param_value = parse_value(param_string)
298 | return "arg_0", param_value
299 |
300 |
301 | def parse_value(value_string: str) -> Any:
302 | """
303 | Parse a value string into appropriate Python type.
304 |
305 | Args:
306 | value_string: String representation of a value
307 |
308 | Returns:
309 | Parsed value (int, float, str, list, etc.)
310 |
311 | Examples:
312 | >>> parse_value("3")
313 | 3
314 |
315 | >>> parse_value("3.14")
316 | 3.14
317 |
318 | >>> parse_value("'hello'")
319 | 'hello'
320 |
321 | >>> parse_value("[0.581, 0.898]")
322 | [0.581, 0.898]
323 | """
324 | value_string = value_string.strip()
325 |
326 | # String values (quoted)
327 | if (value_string.startswith("'") and value_string.endswith("'")) or \
328 | (value_string.startswith('"') and value_string.endswith('"')):
329 | return value_string[1:-1]
330 |
331 | # List values
332 | if value_string.startswith('[') and value_string.endswith(']'):
333 | return parse_list(value_string)
334 |
335 | # Dictionary values
336 | if value_string.startswith('{') and value_string.endswith('}'):
337 | return parse_dict(value_string)
338 |
339 | # Boolean values
340 | if value_string.lower() in ['true', 'false']:
341 | return value_string.lower() == 'true'
342 |
343 | # None value
344 | if value_string.lower() == 'none':
345 | return None
346 |
347 | # Numeric values
348 | try:
349 | # Try integer first
350 | if '.' not in value_string:
351 | return int(value_string)
352 | else:
353 | return float(value_string)
354 | except ValueError:
355 | # Not a number; quoted strings were already stripped at the top of
356 | # this function, so return the raw string unchanged
357 | return value_string
362 |
363 |
364 | def parse_list(list_string: str) -> List[Any]:
365 | """
366 | Parse a list string into a Python list.
367 |
368 | Args:
369 | list_string: String like "[0.581, 0.898]"
370 |
371 | Returns:
372 | List of parsed values
373 |
374 | Examples:
375 | >>> parse_list("[0.581, 0.898]")
376 | [0.581, 0.898]
377 | """
378 | # Remove outer brackets
379 | content = list_string[1:-1].strip()
380 | if not content:
381 | return []
382 |
383 | # Split by commas, respecting nested structures
384 | parts = split_parameters(content)
385 |
386 | return [parse_value(part.strip()) for part in parts]
387 |
388 |
389 | def parse_dict(dict_string: str) -> Dict[str, Any]:
390 | """
391 | Parse a dictionary string into a Python dict.
392 |
393 | Args:
394 | dict_string: String like "{'key': 'value'}"
395 |
396 | Returns:
397 | Dictionary of parsed key-value pairs
398 | """
399 | # Remove outer braces
400 | content = dict_string[1:-1].strip()
401 | if not content:
402 | return {}
403 |
404 | # Split by commas, respecting nested structures
405 | parts = split_parameters(content)
406 |
407 | result = {}
408 | for part in parts:
409 | part = part.strip()
410 | if ':' in part:
411 | key_str, value_str = part.split(':', 1)
412 | key = parse_value(key_str.strip())
413 | value = parse_value(value_str.strip())
414 | result[key] = value
415 |
416 | return result
417 |
418 |
419 | def parse_multiple_functions(function_strings: List[str]) -> List[FunctionCall]:
420 | """
421 | Parse multiple function call strings.
422 |
423 | Args:
424 | function_strings: List of function call strings
425 |
426 | Returns:
427 | List of FunctionCall objects
428 | """
429 | results = []
430 | for func_str in function_strings:
431 | try:
432 | result_list = parse_function_call(func_str)
433 | results.extend(result_list)
434 | except Exception as e:
435 | print(f"Warning: Could not parse function call '{func_str}': {e}")
436 | continue
437 |
438 | return results
439 |
440 |
441 | def extract_function_calls_from_text(text: str) -> List[FunctionCall]:
442 | """
443 | Extract and parse function calls from a text block.
444 |
445 | Args:
446 | text: Text containing function calls
447 |
448 | Returns:
449 | List of FunctionCall objects
450 | """
451 | # Pattern to find function calls in text
452 | # Matches: function_name(param1=value1, param2=value2); '[^)]*' stops at the first ')', so arguments containing parentheses are truncated
453 | pattern = r'[a-zA-Z_][a-zA-Z0-9_.]*\([^)]*\)'
454 |
455 | matches = re.findall(pattern, text)
456 | return parse_multiple_functions(matches)
457 |
458 |
459 | # Example usage and testing
460 | if __name__ == "__main__":
461 | test_cases = [
462 | "mobile.home()",
463 | "mobile.open_app(app_name='drupe')",
464 | "mobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])",
465 | "mobile.back()",
466 | "mobile.long_press(x=0.799, y=0.911)",
467 | "mobile.terminate(status='success')",
468 | "answer('text')",
469 | "pyautogui.hscroll(page=-0.1)",
470 | "pyautogui.scroll(page=-0.1)",
471 | "pyautogui.scroll(0.13)",
472 | "pyautogui.click(x=0.8102, y=0.9463)",
473 | "pyautogui.hotkey(keys=['ctrl', 'c'])",
474 | "pyautogui.doubleClick()",
475 | "pyautogui.press(keys='enter')",
476 | "pyautogui.press(keys=['enter'])",
477 | "pyautogui.moveTo(x=0.04, y=0.405)",
478 | "pyautogui.write(message='bread buns')",
479 | "pyautogui.dragTo(x=0.8102, y=0.9463)",
480 | "mobile.wait(seconds=3)\nmobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])",
481 | # Additional test cases for multiple positional arguments
482 | "function(arg1, arg2, arg3)",
483 | "function('hello', 123, x=0.5)",
484 | "function(arg1, arg2, named_param='value')",
485 | "function(1, 2, 3, 4, 5)",
486 | "function('a', 'b', 'c', x=1, y=2)",
487 | ]
488 |
489 | print("Testing function parser:")
490 | print("=" * 50)
491 |
492 | for test_case in test_cases:
493 | try:
494 | results = parse_function_call(test_case)
495 | print(f"✓ {test_case}")
496 | for result in results:
497 | print(f" Function: {result.function_name}")
498 | print(f" Parameters: {result.parameters}")
499 | print()
500 | except Exception as e:
501 | print(f"✗ {test_case}")
502 | print(f" Error: {e}")
503 | print()
504 |
505 | # Test extracting from text
506 | print("Testing text extraction:")
507 | print("=" * 50)
508 |
509 | sample_text = """
510 | mobile.wait(seconds=3)
511 | mobile.open_app(app_name='drupe')
512 | pyautogui.click(x=0.8102, y=0.9463)
513 | pyautogui.write(message='bread buns')
514 | """
515 |
516 | extracted = extract_function_calls_from_text(sample_text)
517 | for func_call in extracted:
518 | print(f"Found: {func_call.function_name} with params: {func_call.parameters}")
519 |
520 | # Test reconstruction
521 | print("\nTesting function call reconstruction:")
522 | print("=" * 50)
523 |
524 | reconstruction_tests = [
525 | "mobile.wait(seconds=3)",
526 | "mobile.home()",
527 | "mobile.open_app(app_name='drupe')",
528 | "mobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])",
529 | "answer('text')",
530 | "pyautogui.scroll(0.13)",
531 | "pyautogui.click(x=0.8102, y=0.9463)",
532 | "pyautogui.hotkey(keys=['ctrl', 'c'])",
533 | "function(1, 2, 3)",
534 | "function('hello', 123, x=0.5, y=0.8)",
535 | "function([1, 3], 'arg2', named_param='value')",
536 | ]
537 |
538 | for test_case in reconstruction_tests:
539 | parsed_list = parse_function_call(test_case)
540 | for parsed in parsed_list:
541 | reconstructed = parsed.to_string()
542 | print(f"Original: {test_case}")
543 | print(f"Reconstructed: {reconstructed}")
544 | print(f"Match: {test_case == reconstructed}")
545 | assert test_case == reconstructed
546 | print()
--------------------------------------------------------------------------------
/preprocessing/aguvis_processor.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | AGUVIS Dataset Processor Module
4 |
5 | Downloads, processes, and uploads the aguvis-stage1/2 datasets.
6 | Fetches huggingface.co/datasets/xlangai/aguvis-stage1/2 and uploads to smolagents/aguvis-stage-1/2.
7 | """
8 |
9 | import logging
10 | import re
11 | import gc
12 | import os
13 | import zipfile
14 | import tarfile
15 | from copy import deepcopy
16 | from pathlib import Path
17 | from typing import Any, Dict, List, Generator, Optional
18 | from itertools import islice
19 | from multiprocessing import Pool
20 | from collections import defaultdict
21 |
22 | from tqdm import tqdm
23 | from datasets import Dataset, concatenate_datasets, get_dataset_config_names
24 | from dotenv import load_dotenv
25 | from huggingface_hub import login, snapshot_download
26 | from PIL import Image
27 |
28 | from utils.function_parser import parse_function_call
29 | from preprocessing.prompts import OS_SYSTEM_PROMPT, MOBILE_SYSTEM_PROMPT
30 | from preprocessing.action_conversion import action_conversion
31 | from preprocessing.aguvis_models import (
32 | ConversationDataList,
33 | ConversationData,
34 | ChatMessage,
35 | DataRow,
36 | DatasetConfig,
37 | ProcessingConfig,
38 | MOBILE_FILES,
39 | CONFIG_DICT_STAGE_1,
40 | CONFIG_DICT_STAGE_2
41 | )
42 |
43 | # Configure logger for this module
44 | logger = logging.getLogger(__name__)
45 |
46 | def huggingface_authenticate():
47 | """Authenticate with HuggingFace Hub using token."""
48 | hf_token = os.getenv("HF_TOKEN")
49 | if hf_token:
50 | logger.info("Authenticating with HuggingFace Hub using token...")
51 | login(token=hf_token)
52 | else:
53 | raise ValueError("HF_TOKEN environment variable not set.")
54 |
55 | class DatasetDownloader:
56 | """Handles dataset downloading and extraction operations."""
57 |
58 | def download_dataset(self, repo_id: str, local_dir: str) -> str:
59 | """Download the dataset using snapshot_download."""
60 | logger.info(f"Downloading dataset from {repo_id}...")
61 | local_path = snapshot_download(
62 | repo_id=repo_id, local_dir=local_dir, repo_type="dataset"
63 | )
64 | logger.info(f"Dataset downloaded to: {local_path}")
65 | return local_path
66 |
67 | def extract_zip_files(self, dataset_path: str):
68 | """Extract all zip files found in the dataset directory, but only if not already extracted."""
69 | logger.info("Extracting zip files...")
70 | dataset_dir = Path(dataset_path)
71 |
72 | for zip_file in dataset_dir.rglob("*.zip"):
73 | extract_dir = zip_file.parent / zip_file.stem
74 | if extract_dir.exists() and any(extract_dir.iterdir()):
75 | logger.info(
76 | f"Skipping extraction for {zip_file} (already extracted at {extract_dir})"
77 | )
78 | continue
79 |
80 | logger.info(f"Extracting: {zip_file}")
81 | with zipfile.ZipFile(zip_file, "r") as zip_ref:
82 | zip_ref.extractall(extract_dir)
83 | logger.info(f"Extracted to: {extract_dir}")
84 |
85 | def extract_tar_parts_grouped(self, dataset_path: str):
86 | """
87 | Finds all .tar.gz.part_* groups, merges them, and extracts them into directories
88 | named after their common prefix.
89 | """
90 | dataset_dir = Path(dataset_path)
91 | part_files = list(dataset_dir.glob("*.tar.gz.part_*"))
92 |
93 | if not part_files:
94 | logger.info("No split .tar.gz.part_* files found.")
95 | return
96 |
97 | # Group part files by prefix
98 | groups = defaultdict(list)
99 | for part in part_files:
100 | prefix = part.name.split(".tar.gz.part_")[0]
101 | groups[prefix].append(part)
102 |
103 | for prefix, parts in groups.items():
104 | parts = sorted(parts)  # lexicographic sort; assumes part suffixes are zero-padded (part_00, part_01, ...)
105 | merged_tar_path = dataset_dir / f"{prefix}.tar.gz"
106 | extract_dir = dataset_dir / prefix
107 |
108 | if extract_dir.exists() and any(extract_dir.iterdir()):
109 | logger.info(
110 | f"Skipping extraction for '{prefix}' (already extracted at {extract_dir})"
111 | )
112 | continue
113 |
114 | # Merge parts
115 | CHUNK_SIZE = 1024 * 1024
116 | logger.info(f"Merging parts for '{prefix}'...")
117 | with open(merged_tar_path, "wb") as outfile:
118 | for part in parts:
119 | logger.info(f" Adding: {part.name}")
120 | with open(part, "rb") as infile:
121 | while chunk := infile.read(CHUNK_SIZE):
122 | outfile.write(chunk)
123 |
124 | logger.info(f"Merged to: {merged_tar_path}")
125 |
126 | # Extract
127 | logger.info(f"Extracting to: {extract_dir}")
128 | with tarfile.open(merged_tar_path, "r:gz") as tar:
129 | tar.extractall(path=extract_dir)
130 | logger.info(f"Done extracting '{prefix}'\n")
131 |
132 | @staticmethod
133 | def discover_dataset_config(dataset_path: str, config_dict: List[Dict[str, Any]]) -> List[ProcessingConfig]:
134 | """Build ProcessingConfig objects for the dataset from the provided config dicts."""
135 | dataset_dir = Path(dataset_path)
136 | train_dir = dataset_dir
137 |
138 | if not train_dir.exists():
139 | raise FileNotFoundError(f"Train directory not found: {train_dir}")
140 |
141 | configs = []
142 | processed_splits = set()
143 |
144 | # Build one ProcessingConfig per config entry, skipping duplicate subsets
145 | for config in config_dict:
146 | processing_config = ProcessingConfig.from_config_dict(config)
147 |
148 | # Skip if we already processed this split
149 | if processing_config.subset_name in processed_splits:
150 | continue
151 |
152 | configs.append(processing_config)
153 | processed_splits.add(processing_config.subset_name)
154 | logger.info(
155 | f"Discovered config: {processing_config.subset_name} -> {processing_config.images_folder}"
156 | )
157 |
158 | return configs
159 |
160 |
161 |
162 |
163 | class SampleProcessor:
164 | """Processes and converts messages to different formats."""
165 |
166 | @staticmethod
167 | def load_image_from_folder(images_folder: Path, img_path: str) -> Image.Image:
168 | """Load images from the specified folder."""
169 | full_path = images_folder / img_path
170 | img = Image.open(full_path)
171 | new_img = img.copy()
172 | img.close()
173 | return new_img
174 |
175 | @staticmethod
176 | def convert_to_code_agent_format(messages: list[ChatMessage], json_path: str, reasoning: bool):
177 | """Convert messages to code agent format."""
178 | for i, message in enumerate(messages):
179 | content = message.content
180 |
181 | if message.role == "system":
182 | if json_path in MOBILE_FILES:
183 | content = MOBILE_SYSTEM_PROMPT
184 | else:
185 | content = OS_SYSTEM_PROMPT
186 |
187 | if message.role == "user":
188 | content = content.replace("\n" + content.strip() + "\n"
199 | )
200 | elif reasoning:
201 | # TODO: Check whether there are always exactly 2 assistant messages
202 | content = (
203 | "\n(.*?)\n"
240 | assistant_msg = [(i, message) for i, message in enumerate(messages) if message.role == "assistant"]
241 |
242 | if assistant_msg:
243 | for index, msg in assistant_msg:
244 |
245 | if code_format:
246 | regex_match = re.search(regex, msg.content, re.DOTALL)
247 | else:
248 | regex_match = msg.content
249 |
250 | if regex_match is not None:
251 | function_calls = parse_function_call(
252 | regex_match.group(1) if isinstance(regex_match, re.Match) else regex_match,
253 | pattern_to_match=["pyautogui", "mobile", "terminate", "answer"],
254 | )
255 |
256 | if len(function_calls) > 0:
257 |
258 | for i, function_call in enumerate(deepcopy(function_calls)):
259 |
260 | # pyautogui.dragTo has multiple signatures; unify them (merge with the preceding move call that supplies the start coordinates) before converting to the new action space
261 | if function_call.function_name == "pyautogui.dragTo" and not isinstance(list(function_calls[i].parameters.values())[0], (list, tuple)):
262 | x1, y1 = islice(function_calls[i-1].parameters.values(), 2)
263 | x2, y2 = islice(function_calls[i].parameters.values(), 2)
264 | function_calls[i].parameters = {"from_coord": (x1, y1), "to_coord": (x2, y2)}
265 | function_calls[i].original_string = function_calls[i].to_string()
266 | function_calls.pop(i-1)
267 |
268 | function_calls = action_conversion(function_calls, resolution=resolution)
269 |
270 | new_action_string = "\n".join(
271 | [function_call.to_string() for function_call in function_calls]
272 | )
273 | messages[index].content = messages[index].content.replace(
274 | regex_match.group(1) if isinstance(regex_match, re.Match) else regex_match, new_action_string
275 | )
276 |
277 | return messages
278 |
279 |
280 | class DataProcessor:
281 | """Handles data processing and generation."""
282 |
283 | def __init__(self):
284 | self.sample_processor = SampleProcessor()
285 |
286 | def process_subset(
287 | self, config: ProcessingConfig, dataset_path: str, deduplicate: bool = True
288 | ) -> tuple[ConversationDataList, Path]:
289 | """Process a single dataset subset."""
290 | subset_name = config.subset_name
291 |
292 | logger.info(f"Processing split: {subset_name}")
293 |
294 | dataset_dir = Path(dataset_path)
295 | images_folder = dataset_dir / config.subset_name.replace("-l1", "").replace("-l2", "").replace("-l3", "") / config.images_folder
296 |
297 | if not images_folder.exists():
298 | logger.warning(f"Images folder not found: {images_folder}")
299 | else:
300 | logger.info(f"Images folder: {images_folder}")
301 |
302 | json_config_path = dataset_dir / config.json_path
303 | with open(json_config_path, "r") as f:
304 | # Create ConversationDataList with deduplication control
305 | data = ConversationDataList.from_json_with_deduplication(f.read(), deduplicate)
306 | logger.info(f"Added '{json_config_path}' (deduplication: {deduplicate})")
307 |
308 | return data, images_folder
309 |
310 | def row_generator(
311 | self, data: ConversationDataList, images_folder: Path, json_path: str, reasoning: bool
312 | ) -> Generator[Dict[str, Any], None, None]:
313 | """Generate processed data rows."""
314 | conversations: list[ConversationData] = data.root
315 | for item in tqdm(conversations):
316 | try:
317 | # Load images
318 | image = self.sample_processor.load_image_from_folder(images_folder, item.image)
319 | chat_message = self.sample_processor.convert_to_chat_format(item, json_path, reasoning)
320 | chat_message = self.sample_processor.convert_to_new_action_space(chat_message, image.size, code_format=reasoning)
321 | if len(chat_message) == 0:
322 | continue
323 |
324 | row = DataRow.from_chat_messages(chat_message, image, source=json_path.split("/")[-1].split(".")[0])
325 | yield row.model_dump(exclude_none=True)
326 | del image
327 | except Exception as e:
328 | # logger.exception records the full traceback; the previous logger.error call
329 | # passed `item` as a stray %-format argument, which broke log formatting
330 | logger.exception(f"Error processing item {item}: {e}")
331 | continue
332 |
333 |
334 | class SingleConfigProcessor:
335 | """Processes a single configuration in isolation."""
336 |
337 | def __init__(self):
338 | self.data_processor = DataProcessor()
339 |
340 | @staticmethod
341 | def check_subset_exists(repo_id: str, subset_name: str) -> bool:
342 | """Check if a subset already exists in the remote dataset."""
343 | try:
344 | config_names = get_dataset_config_names(repo_id)
345 | return subset_name in config_names
346 | except Exception as e:
347 | logger.warning(f"Could not check if subset exists: {e}")
348 | return False
349 |
350 | def process_single_config(
351 | self, config: ProcessingConfig, dataset_path: str, smolagents_repo_id: str, reasoning: bool, deduplicate: bool = True
352 | ) -> bool:
353 | """Process a single config in a separate process."""
354 | try:
355 | # Authenticate in this process
356 | huggingface_authenticate()
357 |
358 | logger.info(f"\n{'=' * 50}")
359 | logger.info(f"Processing config: {config.subset_name}")
360 |
361 | # Check if the subset already exists in the remote dataset
362 | subset_name = config.subset_name
363 | if SingleConfigProcessor.check_subset_exists(smolagents_repo_id, subset_name):
364 | logger.info(
365 | f"Subset '{subset_name}' already exists in {smolagents_repo_id}, skipping processing."
366 | )
367 | return True
368 |
369 | json_path = config.json_path
370 | data, image_folder = self.data_processor.process_subset(config, dataset_path, deduplicate)
371 |
372 | # Accumulate rows, flushing to an intermediate Dataset every 20k rows to limit memory use
373 | rows = []
374 | datasets = []
375 | for row in self.data_processor.row_generator(data, image_folder, json_path, reasoning):
376 | rows.append(row)
377 | if len(rows) > 20000:
378 | logger.info("Creating batch dataset")
379 | dataset = Dataset.from_list(rows)
380 | datasets.append(dataset)
381 | rows = []
382 | gc.collect()
383 |
384 | if len(rows) > 0:
385 | # Create dataset from collected data
386 | dataset = Dataset.from_list(rows)
387 | datasets.append(dataset)
388 | rows = []
389 |
390 | dataset_to_push = concatenate_datasets(datasets)
391 |
392 | # Push to hub
393 | dataset_to_push.push_to_hub(
394 | smolagents_repo_id,
395 | config_name=subset_name,
396 | split="train",
397 | )
398 |
399 | logger.info(f"Processed and uploaded subset: {config.subset_name}")
400 |
401 | # Force garbage collection to manage memory
402 | gc.collect()
403 |
404 | return True
405 |
406 | except Exception as e:
407 | logger.error(f"Error processing config {config.subset_name}: {e}")
408 | import traceback
409 | traceback.print_exc()
410 | return False
411 |
412 |
413 | class AguvisDatasetProcessor:
414 | """Main class for orchestrating the entire AGUVIS dataset processing pipeline."""
415 |
416 | def __init__(self):
417 | self.downloader = DatasetDownloader()
418 | self.config_processor = SingleConfigProcessor()
419 |
420 | @staticmethod
421 | def authenticate():
422 | """Authenticate with HuggingFace Hub using token."""
423 | hf_token = os.getenv("HF_TOKEN")
424 | if hf_token:
425 | logger.info("Authenticating with HuggingFace Hub using token...")
426 | login(token=hf_token)
427 | else:
428 | raise ValueError("HF_TOKEN environment variable not set.")
429 |
430 | def make_dataset_from_original_data(
431 | self, dataset_config: DatasetConfig, max_processes: Optional[int] = None
432 | ):
433 | """Main function to orchestrate the entire process."""
434 | load_dotenv(override=True)
435 |
436 | logger.info(f"Starting {dataset_config.smolagents_repo_id} dataset processing...")
437 |
438 | self.authenticate()
439 |
440 | dataset_path = self.downloader.download_dataset(
441 | dataset_config.huggingface_repo_id, dataset_config.local_path
442 | )
443 |
444 | self.downloader.extract_zip_files(dataset_path)
445 | self.downloader.extract_tar_parts_grouped(dataset_path)
446 |
447 | dataset_configs = self.downloader.discover_dataset_config(
448 | dataset_path, dataset_config.config_dict
449 | )
450 | converted_repo_id = dataset_config.smolagents_repo_id
451 | reasoning = dataset_config.reasoning
452 | deduplicate = dataset_config.deduplicate
453 |
454 | if max_processes is None:
455 | max_processes = 1
456 | num_processes = min(max_processes, len(dataset_configs))
457 | logger.info(f"Using {num_processes} processes to process {len(dataset_configs)} configs")
458 |
459 | # Prepare arguments for multiprocessing
460 | process_args = [
461 | (config, dataset_path, converted_repo_id, reasoning, deduplicate)
462 | for config in dataset_configs
463 | ]
464 |
465 | # Process configs in parallel with progress tracking
466 | logger.info(f"Starting parallel processing of {len(process_args)} configs...")
467 | try:
468 | with Pool(processes=num_processes) as pool:
469 | results = []
470 | for i, result in enumerate(pool.starmap(self.config_processor.process_single_config, process_args)):
471 | results.append(result)
472 | logger.info(f"Completed {i+1}/{len(process_args)} configs")
473 | except Exception as e:
474 | logger.error(f"Multiprocessing failed: {e}")
475 | logger.info("Falling back to sequential processing...")
476 | results = []
477 | for i, args in enumerate(process_args):
478 | result = self.config_processor.process_single_config(*args)
479 | results.append(result)
480 | logger.info(f"Completed {i+1}/{len(process_args)} configs (sequential)")
481 |
482 | # Check results
483 | successful = sum(results)
484 | total = len(process_args)
485 | logger.info(f"\nProcessing complete: {successful}/{total} configs processed successfully")
486 |
487 | if successful < total:
488 | failed_count = total - successful
489 | logger.warning(f"Warning: {failed_count} configs failed to process. Check the logs above for details.")
490 | else:
491 | logger.info("All configs processed successfully!")
492 |
493 |
494 | def main():
495 | """Main entry point for the script."""
496 | # Create dataset configurations
497 | dataset_config_1 = DatasetConfig(
498 | huggingface_repo_id="xlangai/aguvis-stage1",
499 | local_path="./aguvis_raw_stage_1",
500 | config_dict=CONFIG_DICT_STAGE_1,
501 | smolagents_repo_id="smolagents/aguvis-stage-1",
502 | reasoning=False,
503 | deduplicate=True,
504 | )
505 |
506 | dataset_config_2 = DatasetConfig(
507 | huggingface_repo_id="xlangai/aguvis-stage2",
508 | local_path="./aguvis_raw_stage_2",
509 | config_dict=CONFIG_DICT_STAGE_2,
510 | smolagents_repo_id="smolagents/aguvis-stage-2",
511 | reasoning=True,
512 | deduplicate=False,
513 | )
514 |
515 | # Create processor and run
516 | processor = AguvisDatasetProcessor()
517 |
518 | # You can specify max_processes to limit the number of parallel processes
519 | processor.make_dataset_from_original_data(dataset_config_1, max_processes=4)
520 | processor.make_dataset_from_original_data(dataset_config_2, max_processes=4)
521 |
522 |
523 | if __name__ == "__main__":
524 | # python -m preprocessing.aguvis_processor
525 | main()
526 |
--------------------------------------------------------------------------------
/recipe.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Vision-Language Model Training Recipe\n",
8 | "\n",
9 | "Two-phase training pipeline for SmolVLM2 on GUI datasets.\n"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "## Setup & Imports\n"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "from transformers import AutoModelForImageTextToText, AutoProcessor, PreTrainedModel\n",
26 | "from trl import SFTTrainer, SFTConfig\n",
27 | "from datasets import load_dataset, concatenate_datasets, DatasetDict\n",
28 | "from PIL import Image\n",
29 | "from typing import Any, Callable\n",
30 | "import torch\n",
31 | "import wandb\n",
32 | "import logging\n",
33 | "import os\n",
34 | "\n",
35 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
36 | "\n",
37 | "logger = logging.getLogger(__name__)"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {},
43 | "source": [
44 | "## Dataset Configuration\n",
45 | "\n",
46 | "Phase 1: Basic GUI understanding datasets\n"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": null,
52 | "metadata": {},
53 | "outputs": [],
54 | "source": [
55 | "dataset_mixture_phase_1 = [\n",
56 | " {\n",
57 | " \"id\": \"smolagents/aguvis-stage-1\",\n",
58 | " \"config\": \"guienv\",\n",
59 | " \"split\": \"train\",\n",
60 | " \"columns\": [\"images\", \"texts\"],\n",
61 | " \"weight\": 1.0,\n",
62 | " },\n",
63 | " {\n",
64 | " \"id\": \"smolagents/aguvis-stage-1\",\n",
65 | " \"config\": \"omniact\",\n",
66 | " \"split\": \"train\",\n",
67 | " \"columns\": [\"images\", \"texts\"],\n",
68 | " \"weight\": 1.0,\n",
69 | " },\n",
70 | " {\n",
71 | " \"id\": \"smolagents/aguvis-stage-1\",\n",
72 | " \"config\": \"ricoig16k\",\n",
73 | " \"split\": \"train\",\n",
74 | " \"columns\": [\"images\", \"texts\"],\n",
75 | " \"weight\": 1.0,\n",
76 | " },\n",
77 | " {\n",
78 | " \"id\": \"smolagents/aguvis-stage-1\",\n",
79 | " \"config\": \"ricosca\",\n",
80 | " \"split\": \"train\",\n",
81 | " \"columns\": [\"images\", \"texts\"],\n",
82 | " \"weight\": 1.0,\n",
83 | " },\n",
84 | " {\n",
85 | " \"id\": \"smolagents/aguvis-stage-1\",\n",
86 | " \"config\": \"seeclick\",\n",
87 | " \"split\": \"train\",\n",
88 | " \"columns\": [\"images\", \"texts\"],\n",
89 | " \"weight\": 1.0,\n",
90 | " },\n",
91 | " {\n",
92 | " \"id\": \"smolagents/aguvis-stage-1\",\n",
93 | " \"config\": \"ui_refexp\",\n",
94 | " \"split\": \"train\",\n",
95 | " \"columns\": [\"images\", \"texts\"],\n",
96 | " \"weight\": 1.0,\n",
97 | " },\n",
98 | " {\n",
99 | " \"id\": \"smolagents/aguvis-stage-1\",\n",
100 | " \"config\": \"webui350k\",\n",
101 | " \"split\": \"train\",\n",
102 | " \"columns\": [\"images\", \"texts\"],\n",
103 | " \"weight\": 1.0,\n",
104 | " },\n",
105 | " {\n",
106 | " \"id\": \"smolagents/aguvis-stage-1\",\n",
107 | " \"config\": \"widget_captioning\",\n",
108 | " \"split\": \"train\",\n",
109 | " \"columns\": [\"images\", \"texts\"],\n",
110 | " \"weight\": 1.0,\n",
111 | " },\n",
112 | "]\n"
113 | ]
114 | },
115 | {
116 | "cell_type": "markdown",
117 | "metadata": {},
118 | "source": [
119 | "Phase 2: Advanced agentic behavior datasets\n"
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": null,
125 | "metadata": {},
126 | "outputs": [],
127 | "source": [
128 | "dataset_mixture_phase_2 = [\n",
129 | " {\n",
130 | " \"id\": \"smolagents/aguvis-stage-2\",\n",
131 | " \"config\": \"mind2web-l1\",\n",
132 | " \"split\": \"train\",\n",
133 | " \"columns\": [\"images\", \"texts\"],\n",
134 | " \"weight\": 1.0,\n",
135 | " },\n",
136 | " {\n",
137 | " \"id\": \"smolagents/aguvis-stage-2\",\n",
138 | " \"config\": \"mind2web-l2\",\n",
139 | " \"split\": \"train\",\n",
140 | " \"columns\": [\"images\", \"texts\"],\n",
141 | " \"weight\": 1.0,\n",
142 | " },\n",
143 | " {\n",
144 | " \"id\": \"smolagents/aguvis-stage-2\",\n",
145 | " \"config\": \"guiact-web-single\",\n",
146 | " \"split\": \"train\",\n",
147 | " \"columns\": [\"images\", \"texts\"],\n",
148 | " \"weight\": 1.0,\n",
149 | " },\n",
150 | " {\n",
151 | " \"id\": \"smolagents/aguvis-stage-2\",\n",
152 | " \"config\": \"guiact-web-multi-l1\",\n",
153 | " \"split\": \"train\",\n",
154 | " \"columns\": [\"images\", \"texts\"],\n",
155 | " \"weight\": 1.0,\n",
156 | " },\n",
157 | " {\n",
158 | " \"id\": \"smolagents/aguvis-stage-2\",\n",
159 | " \"config\": \"guiact-web-multi-l2\",\n",
160 | " \"split\": \"train\",\n",
161 | " \"columns\": [\"images\", \"texts\"],\n",
162 | " \"weight\": 1.0,\n",
163 | " },\n",
164 | " {\n",
165 | " \"id\": \"smolagents/aguvis-stage-2\",\n",
166 | " \"config\": \"miniwob-l1\",\n",
167 | " \"split\": \"train\",\n",
168 | " \"columns\": [\"images\", \"texts\"],\n",
169 | " \"weight\": 1.0,\n",
170 | " },\n",
171 | " {\n",
172 | " \"id\": \"smolagents/aguvis-stage-2\",\n",
173 | " \"config\": \"miniwob-l2\",\n",
174 | " \"split\": \"train\",\n",
175 | " \"columns\": [\"images\", \"texts\"],\n",
176 | " \"weight\": 1.0,\n",
177 | " },\n",
178 | " {\n",
179 | " \"id\": \"smolagents/aguvis-stage-2\",\n",
180 | " \"config\": \"coat\",\n",
181 | " \"split\": \"train\",\n",
182 | " \"columns\": [\"images\", \"texts\"],\n",
183 | " \"weight\": 1.0,\n",
184 | " },\n",
185 | " {\n",
186 | " \"id\": \"smolagents/aguvis-stage-2\",\n",
187 | " \"config\": \"android_control\",\n",
188 | " \"split\": \"train\",\n",
189 | " \"columns\": [\"images\", \"texts\"],\n",
190 | " \"weight\": 1.0,\n",
191 | " },\n",
192 | " {\n",
193 | " \"id\": \"smolagents/aguvis-stage-2\",\n",
194 | " \"config\": \"gui-odyssey-l1\",\n",
195 | " \"split\": \"train\",\n",
196 | " \"columns\": [\"images\", \"texts\"],\n",
197 | " \"weight\": 0.33,\n",
198 | " },\n",
199 | " {\n",
200 | " \"id\": \"smolagents/aguvis-stage-2\",\n",
201 | " \"config\": \"gui-odyssey-l2\",\n",
202 | " \"split\": \"train\",\n",
203 | " \"columns\": [\"images\", \"texts\"],\n",
204 | " \"weight\": 0.33,\n",
205 | " },\n",
206 | " {\n",
207 | " \"id\": \"smolagents/aguvis-stage-2\",\n",
208 | " \"config\": \"amex-l1\",\n",
209 | " \"split\": \"train\",\n",
210 | " \"columns\": [\"images\", \"texts\"],\n",
211 | " \"weight\": 0.33,\n",
212 | " },\n",
213 | " {\n",
214 | " \"id\": \"smolagents/aguvis-stage-2\",\n",
215 | " \"config\": \"amex-l2\",\n",
216 | " \"split\": \"train\",\n",
217 | " \"columns\": [\"images\", \"texts\"],\n",
218 | " \"weight\": 0.33,\n",
219 | " },\n",
220 | " {\n",
221 | " \"id\": \"smolagents/aguvis-stage-2\",\n",
222 | " \"config\": \"aitw-l1\",\n",
223 | " \"split\": \"train\",\n",
224 | " \"columns\": [\"images\", \"texts\"],\n",
225 | " \"weight\": 1.0,\n",
226 | " },\n",
227 | " {\n",
228 | " \"id\": \"smolagents/aguvis-stage-2\",\n",
229 | " \"config\": \"aitw-l2\",\n",
230 | " \"split\": \"train\",\n",
231 | " \"columns\": [\"images\", \"texts\"],\n",
232 | " \"weight\": 1.0,\n",
233 | " },\n",
234 | "]\n"
235 | ]
236 | },
237 | {
238 | "cell_type": "markdown",
239 | "metadata": {},
240 | "source": [
241 | "## Dataset Loading Utility\n",
242 | "\n",
243 | "Loads and combines multiple datasets with weights and splits into train/test.\n"
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": null,
249 | "metadata": {},
250 | "outputs": [],
251 | "source": [
252 | "def get_dataset(dataset_mixture: list[dict[str, Any]], test_split_size: float = 0.01) -> DatasetDict:\n",
253 | " \"\"\"Load a dataset or a mixture of datasets based on the configuration.\n",
254 | "\n",
255 | " Args:\n",
256 | "        dataset_mixture (list[dict[str, Any]]): Dataset configuration.\n        test_split_size (float): Fraction of examples held out for the test split.\n",
257 | "\n",
258 | " Returns:\n",
259 | " DatasetDict: The loaded datasets.\n",
260 | " \"\"\"\n",
261 | " logger.info(f\"Creating dataset mixture with {len(dataset_mixture)} datasets\")\n",
262 | " seed = 42\n",
263 | " datasets_list = []\n",
264 | "\n",
265 | " for dataset_config in dataset_mixture:\n",
266 | " logger.info(\n",
267 | " f\"Loading dataset for mixture: {dataset_config['id']} (config: {dataset_config['config']})\"\n",
268 | " )\n",
269 | " ds = load_dataset(\n",
270 | " dataset_config[\"id\"],\n",
271 | " dataset_config[\"config\"],\n",
272 | " split=dataset_config[\"split\"],\n",
273 | " )\n",
274 | " ds = ds.select_columns(dataset_config[\"columns\"])\n",
275 | " ds = ds.shuffle(seed=seed).select(\n",
276 | " range(int(len(ds) * dataset_config[\"weight\"]))\n",
277 | " )\n",
278 | " logger.info(\n",
279 | " f\"Subsampled dataset '{dataset_config['id']}' (config: {dataset_config['config']}) with weight={dataset_config['weight']} to {len(ds)} examples\"\n",
280 | " )\n",
281 | "\n",
282 | " datasets_list.append(ds)\n",
283 | "\n",
284 | " if datasets_list:\n",
285 | " combined_dataset = concatenate_datasets(datasets_list)\n",
286 | " combined_dataset = combined_dataset.shuffle(seed=seed)\n",
287 | " logger.info(f\"Created dataset mixture with {len(combined_dataset)} examples\")\n",
288 | "\n",
289 | " combined_dataset = combined_dataset.train_test_split(test_size=test_split_size, seed=seed)\n",
290 | " logger.info(\n",
291 | " f\"Split dataset into train and test sets with test size: {test_split_size}\"\n",
292 | " )\n",
293 | " return combined_dataset\n",
294 | " else:\n",
295 | " raise ValueError(\"No datasets were loaded from the mixture configuration\")\n"
296 | ]
297 | },
298 | {
299 | "cell_type": "markdown",
300 | "metadata": {},
301 | "source": [
302 | "## Model & Training Parameters\n"
303 | ]
304 | },
305 | {
306 | "cell_type": "code",
307 | "execution_count": null,
308 | "metadata": {},
309 | "outputs": [],
310 | "source": [
311 | "base_model_name = \"HuggingFaceTB/SmolVLM2-2.2B-Instruct\"\n",
312 | "phase_1_model_name = \"SmolVLM2-2.2B-Instruct-GUI\"\n",
313 | "phase_2_model_name = \"SmolVLM2-2.2B-Instruct-Agentic-GUI\"\n",
314 | "image_size = 1152\n",
315 | "max_length = 16384\n"
316 | ]
317 | },
318 | {
319 | "cell_type": "markdown",
320 | "metadata": {},
321 | "source": [
322 | "## Processor & Data Collator Setup\n",
323 | "\n",
324 | "Configure processor for image/text handling and create custom collator.\n"
325 | ]
326 | },
327 | {
328 | "cell_type": "code",
329 | "execution_count": null,
330 | "metadata": {},
331 | "outputs": [],
332 | "source": [
333 | "processor = AutoProcessor.from_pretrained(\n",
334 | " base_model_name,\n",
335 | " revision=\"main\",\n",
336 | " trust_remote_code=True,\n",
337 | ")\n",
338 | "processor.image_processor.size = {\"longest_edge\": image_size}\n",
339 | "processor.tokenizer.truncation_side = \"right\"\n",
340 | "processor.tokenizer.padding_side = \"right\"\n"
341 | ]
342 | },
343 | {
344 | "cell_type": "code",
345 | "execution_count": null,
346 | "metadata": {},
347 | "outputs": [],
348 | "source": [
349 | "def create_collate_fn(processor, max_length: int):\n",
350 | "    \"\"\"Collate function for VLM training that masks every token except the assistant answers.\"\"\"\n",
351 | "\n",
352 | " def collate_fn(examples: list[dict[str, list | str | Image.Image]]):\n",
353 | " batch_messages: list[list[dict[str, list | str | Image.Image]]] = []\n",
354 | " assistant_messages: list[list[str]] = []\n",
355 | " all_image_inputs: list[list[Image.Image]] = []\n",
356 | " for example in examples:\n",
357 | " images: list[Image.Image] = example[\"images\"]\n",
358 | " is_first_user = True\n",
359 | " sample: list[dict[str, list | str | Image.Image]] = []\n",
360 | " assistant: list[str] = []\n",
361 | " for text in example[\"texts\"]:\n",
362 | " if \"system\" in text.keys():\n",
363 | " sample.append(\n",
364 | " {\n",
365 | " \"role\": \"system\",\n",
366 | " \"content\": [{\"type\": \"text\", \"text\": text[\"system\"]}],\n",
367 | " }\n",
368 | " )\n",
369 | "\n",
370 | " if is_first_user:\n",
371 | " sample.append(\n",
372 | " {\n",
373 | " \"role\": \"user\",\n",
374 | " \"content\": [\n",
375 | " {\"type\": \"image\", \"image\": images[0]},\n",
376 | " {\"type\": \"text\", \"text\": text[\"user\"]},\n",
377 | " ],\n",
378 | " }\n",
379 | " )\n",
380 | " is_first_user = False\n",
381 | " else:\n",
382 | " sample.append(\n",
383 | " {\n",
384 | " \"role\": \"user\",\n",
385 | " \"content\": [\n",
386 | " {\"type\": \"text\", \"text\": text[\"user\"]},\n",
387 | " ],\n",
388 | " }\n",
389 | " )\n",
390 | "\n",
391 | " sample.append(\n",
392 | " {\n",
393 | " \"role\": \"assistant\",\n",
394 | " \"content\": [{\"type\": \"text\", \"text\": \"\\n\" + text[\"assistant\"]}],\n",
395 | " }\n",
396 | " )\n",
397 | " assistant.append(text[\"assistant\"])\n",
398 | "\n",
399 | " batch_messages.append(sample)\n",
400 | " assistant_messages.append(assistant)\n",
401 | " all_image_inputs.append(images)\n",
402 | "\n",
403 | " texts = [\n",
404 | " processor.apply_chat_template(\n",
405 | " messages, tokenize=False, add_generation_prompt=False\n",
406 | " )\n",
407 | " for messages in batch_messages\n",
408 | " ]\n",
409 | "\n",
410 | " batch = processor(\n",
411 | " text=texts,\n",
412 | " images=all_image_inputs if all_image_inputs else None,\n",
413 | " max_length=max_length,\n",
414 | " truncation=True,\n",
415 | " padding=True,\n",
416 | " return_tensors=\"pt\",\n",
417 | " )\n",
418 | "\n",
419 | " input_ids = batch[\"input_ids\"]\n",
420 | " labels = input_ids.clone()\n",
421 | "\n",
422 | " assistant_encodings = [\n",
423 | " processor.tokenizer(\n",
424 | "                [msg + \"
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
smolagents/aguvis-stage-1, smolagents/aguvis-stage-2
47 | GUI capabilities combine understanding of the interface and precise element localization. These abilities enable the model to translate high-level tasks into low-level GUI actions such as clicking, typing, …
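For a concrete feel of this action space, here is a hypothetical grounding step (the instruction and coordinates are invented for illustration) parsed with the repo's `utils/function_parser.py`:

```python
# Hypothetical example: the instruction "Click the search bar" grounded into a
# single low-level action, then parsed with this repo's function parser.
from utils.function_parser import parse_function_call

action = "pyautogui.click(x=0.8102, y=0.9463)"  # coordinates normalized to [0, 1]
for call in parse_function_call(action):
    print(call.function_name, call.parameters)
# -> pyautogui.click {'x': 0.8102, 'y': 0.9463}
```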
68 | Figure: Evolution of ScreenSpot-v2 performance during the training phase of the base model **SmolVLM2-2.2B-Instruct**.
89 | Actions expressed in absolute pixel coordinates (e.g. `click(x=302, y=63)`) are tied to a single image size. Vision Language Models (VLMs) often resize images, which breaks pixel coordinates and requires adjusting them. Normalized coordinates (relative to the image size) remain valid at any resolution and keep the dataset consistent.
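As a minimal sketch (the helper names below are ours, not part of the repo), the conversion between the two conventions is a simple rescale:

```python
# Minimal sketch: normalized coordinates survive image resizing, pixels do not.
def to_normalized(x_px: int, y_px: int, width: int, height: int) -> tuple[float, float]:
    """Map absolute pixel coordinates to [0, 1], relative to the image size."""
    return x_px / width, y_px / height

def to_pixels(x: float, y: float, width: int, height: int) -> tuple[int, int]:
    """Map normalized coordinates back to pixels for a (possibly resized) image."""
    return round(x * width), round(y * height)

# click(x=302, y=63) on a 1920x1080 screenshot...
x, y = to_normalized(302, 63, 1920, 1080)
# ...still lands on the same element after the VLM resizes it to 1152x648:
print(to_pixels(x, y, 1152, 648))  # (181, 38)
```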
| Configuration (coords / image size) | ScreenSpot-v2 (%) |
|---|---|
| **Normalized coordinates** | |
| Base / – | 0.47 |
| 384 | 31.28 |
| 764 | 32.32 |
| 1152 | 33.72 |
| **Pixel coordinates** | |
| Base / – | 0.55 |
| 384 | 1.17 |
| 764 | 2.67 |
| 1152 | 4.32 |

Table 1: Baseline on HuggingFaceTB/SmolVLM2-2.2B-Instruct (400k samples, aguvis-stage-1). Higher is better.
| Configuration (coords / image size) | ScreenSpot-v2 (%) |
|---|---|
| Normalized coordinates / 1152 | 41.27 |

Table 2: Baseline on HuggingFaceTB/SmolVLM2-2.2B-Instruct (2 epochs, aguvis-stage-1).
562 |\nclick(x=0.41, y=0.178)\n",
587 | }
588 | ```
589 |
590 | Each sample links a screenshot with its system/user/assistant turns. During fine-tuning, the data collator masks everything except the assistant’s answers when computing the loss.
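A minimal sketch of that masking (here the assistant token spans are passed in directly; in the recipe's collator they are located by tokenizing the assistant answers separately and matching them inside the full sequence):

```python
import torch

def mask_non_assistant(input_ids: torch.Tensor, assistant_spans: list[tuple[int, int]]) -> torch.Tensor:
    """Build labels in which only assistant tokens contribute to the loss."""
    labels = torch.full_like(input_ids, -100)  # -100 is ignored by cross-entropy
    for start, end in assistant_spans:
        labels[start:end] = input_ids[start:end]
    return labels

ids = torch.tensor([5, 6, 7, 8, 9, 10])   # token ids of a full conversation
print(mask_non_assistant(ids, [(3, 6)]))  # tensor([-100, -100, -100,  8,  9, 10])
```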
591 |
599 | ### Phase 2 Results
600 |
601 | Starting from the Phase 1 checkpoint (1152 px resolution, normalized coordinates), we fine-tuned the model for two epochs on [smolagents/aguvis-stage-2](https://huggingface.co/datasets/smolagents/aguvis-stage-2). The accuracy on **ScreenSpot-v2 increased from 41% to 61%**, indicating that explicit reasoning improves GUI grounding performance.
602 |
603 |
604 |
| Configuration (coords / image size) | ScreenSpot-v2 (%) |
|---|---|
| Normalized coordinates / 1152 | 61.71 |

Table 3: HuggingFaceTB/SmolVLM2-2.2B-Instruct after Phase 1 fine-tuning, further trained for 2 epochs on aguvis-stage-2.
659 |