├── workflow ├── pre.sh ├── aitw │ ├── prompts.py │ ├── aitw_test.py │ └── action_matching.py ├── mind2web │ ├── prompts.py │ ├── mind2web_test.py │ └── action_matching.py └── odyssey │ ├── prompts.py │ ├── odyssey_test.py │ └── action_matching.py ├── LEGAL.md ├── document_construction ├── pre.sh ├── odyssey_document │ ├── prompts.py │ └── main.py ├── aitw_document │ ├── prompts.py │ └── main.py └── mind2web_document │ ├── prompts.py │ └── main.py ├── LICENSE └── README.md /workflow/pre.sh: -------------------------------------------------------------------------------- 1 | pip install langchain_community 2 | pip install langchain_huggingface 3 | pip install jax 4 | pip install jaxlib 5 | pip install faiss-gpu 6 | pip install sentence-transformers 7 | 8 | python -m vllm.entrypoints.openai.api_server --served-model-name Qwen2.5-VL-72B-Instruct --model Qwen2.5-VL-72B-Instruct -tp 4 --limit_mm_per_prompt image=2 9 | -------------------------------------------------------------------------------- /LEGAL.md: -------------------------------------------------------------------------------- 1 | Legal Disclaimer 2 | 3 | Within this source code, the comments in Chinese shall be the original, governing version. Any comment in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail. 4 | 5 | 法律免责声明 6 | 7 | 关于代码注释部分,中文注释为官方版本,其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致,当中文注释与其它语言注释存在不一致时,请以中文注释为准。 -------------------------------------------------------------------------------- /document_construction/pre.sh: -------------------------------------------------------------------------------- 1 | 2 | pip install langchain 3 | conda install -c pytorch faiss-gpu 4 | pip install -U langchain-community 5 | pip install sentence-transformers 6 | pip install numpy==1.23.2 7 | pip install -U langchain-huggingface 8 | pip install jax 9 | pip install jaxlib 10 | pip install --upgrade vllm 11 | 12 | python -m vllm.entrypoints.openai.api_server --served-model-name Qwen2.5-VL-72B-Instruct --model Qwen2.5-VL-72B-Instruct -tp 4 --limit_mm_per_prompt image=2 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 chenweizhi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
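Both pre.sh scripts end by launching a vLLM OpenAI-compatible server for Qwen2.5-VL-72B-Instruct. For orientation, here is a minimal sketch of how the rest of the repository talks to that server; it mirrors the chat() helpers defined in document_construction/*/main.py and in the workflow test scripts. The screenshot URL and query below are placeholders, not files shipped with the repository.

```python
# Minimal sketch (not part of the original scripts): query the vLLM
# OpenAI-compatible endpoint started by pre.sh, the same way the chat()
# helpers in this repository do. The screenshot URL is a hypothetical
# placeholder served by a local static file server.
import json
import requests

URL = "http://localhost:8000/v1/chat/completions"  # vLLM default port, as used in main.py
HEADERS = {"Content-Type": "application/json"}

def chat(img_url_list, query):
    # Multimodal message: all screenshots first, then the text query.
    content = [{"type": "image_url", "image_url": {"url": u}} for u in img_url_list]
    content.append({"type": "text", "text": query})
    data = {
        "model": "Qwen2.5-VL-72B-Instruct",  # must match --served-model-name in pre.sh
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": content},
        ],
        "temperature": 0,  # deterministic outputs for page graph construction
    }
    response = requests.post(URL, headers=HEADERS, data=json.dumps(data))
    return response.json()["choices"][0]["message"]["content"]

if __name__ == "__main__":
    # Placeholder screenshot URL; replace with one served by your own image host.
    print(chat(["http://localhost:6666/aitw_images/example.png"],
               "Describe this screen in one sentence."))
```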
-------------------------------------------------------------------------------- /document_construction/odyssey_document/prompts.py: -------------------------------------------------------------------------------- 1 | 2 | page_summary = 'Please describe this screen containing following content with one full sentence, including \ 3 | the type of page, the function of page and the key components of the screen.' 4 | 5 | click_action_summary = 'The user clicks the item at coordinates {bbox}. You are required to summarize this operation with a verb phrase that begins with \"click\". Do not mention original coordinates.' 6 | 7 | 8 | redirection_judge = 'You will receive the images of screens before and after operation \'{action}\'. \ 9 | You need to determine whether this operation leads to a new page, or it is just an in-page operation. \ 10 | You are required to output with the following format:\n\ 11 | ### Thought: \n\ 12 | ### Conclusion: <\'Yes\' or \'No\'>\n\ 13 | Do not output anything else.' 14 | 15 | 16 | 17 | 18 | check_repeat = 'You are a professional GUI agent. You will be given a screen and some descriptions. \ 19 | Your task is to find one description that best fits the current page.\n\ 20 | Here are the descriptions:\n\ 21 | {old_description}\ 22 | You should answer with the following format:\n\ 23 | ### Thought: \n\ 24 | ### Index: \n\ 25 | Do not output anything else.' 26 | 27 | check_repeat_2 = 'Are these two screens similar? You should consider the type, layout, and content of the pages comprehensively.\n\ 28 | You are required to output with the following format:\n\ 29 | ### Thought: \n\ 30 | ### Conclusion: <\'Yes\' or \'No\'>\n\ 31 | Do not output anything else.' 32 | 33 | -------------------------------------------------------------------------------- /document_construction/aitw_document/prompts.py: -------------------------------------------------------------------------------- 1 | 2 | page_summary = 'Please describe this screen containing following content with one full sentence, including \ 3 | the type of page, the function of page and the key components of the screen.' 4 | 5 | action_summary = 'An operation has now been performed on the screen. \ 6 | Here is the type of the operation and relevant parameters:\n\ 7 | {action_description}\n\ 8 | You are required to summarize this operation with a verb phrase that begins with the given operation type.' 9 | 10 | redirection_judge = 'You will receive the images of screens before and after operation \'{action}\'. \ 11 | You need to determine whether this operation leads to a new page, or it is just an in-page operation. \ 12 | You are required to output with the following format:\n\ 13 | ### Thought: \n\ 14 | ### Conclusion: <\'Yes\' or \'No\'>\n\ 15 | Do not output anything else.' 16 | 17 | 18 | check_repeat = 'You are a professional GUI agent. You will be given a screen and some descriptions. \ 19 | Your task is to find one description that best fits the current page.\n\ 20 | Here are the descriptions:\n\ 21 | {old_description}\ 22 | You should answer with the following format:\n\ 23 | ### Thought: \n\ 24 | ### Index: \n\ 25 | Do not output anything else.' 26 | 27 | check_repeat_2 = 'Are these two screens similar? You should consider the type, layout, and content of the pages comprehensively.\n\ 28 | You are required to output with the following format:\n\ 29 | ### Thought: \n\ 30 | ### Conclusion: <\'Yes\' or \'No\'>\n\ 31 | Do not output anything else.' 
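The templates above are plain Python format strings, and the construction scripts parse the model's answers by splitting on the '### ...:' markers they request. Below is a minimal usage sketch for redirection_judge, mirroring how document_construction/aitw_document/main.py uses it; the chat() helper is the one sketched after pre.sh, the screenshot URLs are placeholders, and the wrapper function name is hypothetical.

```python
# Minimal usage sketch (assumed, not part of the original files): fill the
# redirection_judge template and parse the requested '### Conclusion:' field,
# as done in document_construction/aitw_document/main.py. Run from inside
# aitw_document/ so that `import prompts` resolves, as main.py does.
import prompts

def leads_to_new_page(chat, before_img_url, after_img_url, action_summary):
    query = prompts.redirection_judge.format(action=action_summary)
    response = chat([before_img_url, after_img_url], query)
    # The template asks for '### Thought: ...' followed by '### Conclusion: Yes/No'.
    conclusion = response.split('### Conclusion: ')[1].strip()
    assert conclusion in ('Yes', 'No')
    return conclusion == 'Yes'
```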
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## PG-Agent: An Agent Powered by Page Graph 2 | 3 | [![Paper](http://img.shields.io/badge/Paper-arxiv.2509.03536-99D4C8.svg)](https://arxiv.org/abs/2509.03536) 4 | 5 | This is the source code for **page graph construction** and **multi-agent workflow**. 6 | 7 | ### Data Preparation 8 | 9 | ------ 10 | 11 | The open-source datasets we use are from the following repositories: 12 | 13 | - AITW & Mind2Web: [here](https://github.com/njucckevin/SeeClick/blob/main/agent_tasks/readme_agent.md) 14 | - GUI Odyssey: [here](https://github.com/OpenGVLab/GUI-Odyssey/blob/master/README.md) 15 | 16 | ### Page Graph Construction 17 | 18 | ------ 19 | 20 | You can run the following commands to construct the corresponding page graph. 21 | 22 | ``` 23 | cd document_construction 24 | sh pre.sh 25 | ``` 26 | 27 | - AITW 28 | 29 | ``` 30 | python aitw_document/main.py 31 | ``` 32 | 33 | - Mind2Web 34 | 35 | ``` 36 | python mind2web_document/main.py 37 | ``` 38 | 39 | - GUI Odyssey 40 | 41 | ``` 42 | python odyssey_document/main.py 43 | ``` 44 | 45 | ### Multi-agent Workflow 46 | 47 | ------ 48 | 49 | You can run the following commands to evaluate the agent on the following benchmarks with the corresponding page graphs. 50 | 51 | ``` 52 | cd workflow 53 | sh pre.sh 54 | ``` 55 | 56 | - AITW 57 | 58 | ``` 59 | python aitw/aitw_test.py 60 | ``` 61 | 62 | - Mind2Web 63 | 64 | ``` 65 | python mind2web/mind2web_test.py 66 | ``` 67 | 68 | - GUI Odyssey 69 | 70 | ``` 71 | python odyssey/odyssey_test.py 72 | ``` 73 | 74 | ### Citation 75 | 76 | ------ 77 | 78 | ``` 79 | @misc{chen2025pgagentagentpoweredpage, 80 | title={PG-Agent: An Agent Powered by Page Graph}, 81 | author={Weizhi Chen and Ziwei Wang and Leyang Yang and Sheng Zhou and Xiaoxuan Tang and Jiajun Bu and Yong Li and Wei Jiang}, 82 | year={2025}, 83 | eprint={2509.03536}, 84 | archivePrefix={arXiv}, 85 | primaryClass={cs.AI}, 86 | url={https://arxiv.org/abs/2509.03536}, 87 | } 88 | ``` -------------------------------------------------------------------------------- /document_construction/mind2web_document/prompts.py: -------------------------------------------------------------------------------- 1 | 2 | page_summary = 'Please describe this screen containing following content with one full sentence, including \ 3 | the type of page, the function of page and the key components of the screen.' 4 | 5 | click_action_summary = 'This is a page of website. The user clicks the item at coordinates {bbox}. You are required to summarize this operation beginning with \"click\". Do not mention original coordinates.' 6 | 7 | type_action_summary = 'This is a page of website. The user types the content \"{content}\" at coordinates {bbox}. You are required to summarize this operation beginning with \"type\". Do not mention original coordinates.' 8 | 9 | select_action_summary = 'This is a page of website. The user opens a \"Select Menu\" or \"Dropdown List\" at coordinates {bbox}, and select the option \"{content}\". You are required to summarize this operation beginning with \"select\". Do not mention original coordinates.' 10 | 11 | 12 | redirection_judge = 'You will receive the images of screens before and after operation \'{action}\'. \ 13 | You need to determine whether this operation leads to a new page, or it is just an in-page operation. 
\ 14 | You are required to output with the following format:\n\ 15 | ### Thought: \n\ 16 | ### Conclusion: <\'Yes\' or \'No\'>\n\ 17 | Do not output anything else.' 18 | 19 | 20 | 21 | 22 | check_repeat = 'You are a professional GUI agent. You will be given a webpage and some descriptions. \ 23 | Your task is to find one description that best fits the current webpage.\n\ 24 | Here are the descriptions:\n\ 25 | {old_description}\ 26 | You should answer with the following format:\n\ 27 | ### Thought: \n\ 28 | ### Index: \n\ 29 | Do not output anything else.' 30 | 31 | 32 | check_repeat_2 = 'Are these two screens similar? You should consider the type, layout, and content of the pages comprehensively.\n\ 33 | You are required to output with the following format:\n\ 34 | ### Thought: \n\ 35 | ### Conclusion: <\'Yes\' or \'No\'>\n\ 36 | Do not output anything else.' 37 | 38 | -------------------------------------------------------------------------------- /workflow/aitw/prompts.py: -------------------------------------------------------------------------------- 1 | AITW_ACTION_SPACE = ''' 2 | 1. Click(x, y): An action of clicking a coordinate point on the smartphone screen and x,y is the position of the coordinate point on the screen. 3 | Your click location should be a UI element or text on the screen. 4 | A simple use case could be Click(100,238), which means you click the UI element at (100,238) on the current screen. 5 | 6 | 2. Type("typed_text"): An action of typing a piece of text. 7 | A simple use case could be Type("Hello, world!"), which inserts the string "Hello, world!" into the input area on the smartphone screen. 8 | 9 | 3. Scroll("direction"): This function is used to scroll the screen in a specific direction. 10 | "direction" is a string that represents one of the four directions: "up", "down", "left", "right". 11 | A simple use case could be Scroll("up"), which means you take a scroll up action on the current screen. 12 | 13 | 4. Back(): The action for returning to the previous step. 14 | 15 | 5. Home(): The action for returning to the homepage. 16 | 17 | 6. Enter(): The action of pressing the ENTER key to submit input content. 18 | 19 | 7. Complete: It means you think the task is complete. 20 | ''' 21 | 22 | 23 | AITW_OBSERVATION_PROMT = f""" 24 | You are a smart GUI agent, capable of comprehensively understanding the GUI interface as well as the user's intentions. 25 | You will be given user's ultimate purpose and the previous actions that you have taken. 26 | Your task is to carefully observe the screen, describe it and conclude some useful clues in one sentence. 27 | 28 | Now you can start to observe: 29 | 30 | ### User's purpose ### 31 | 32 | 33 | ### History trajectory ### 34 | History trajectory can remind you of the operations that have been executed before, thus avoiding repetitive actions. 35 | 36 | 37 | ### Observation ### 38 | """ 39 | 40 | AITW_PLANNING_PROMT = f""" 41 | You are a smart GUI agent, capable of comprehensively understanding the GUI interface as well as the user's intentions. 42 | Your task is to plan the next action to complete user's purpose with the help of references. 
43 | 44 | I will give you several important information: 45 | ### User's purpose ### 46 | This is the user's global purpose, and your goal is to complete it: 47 | 48 | 49 | ### Observation ### 50 | This is the observation of the screen and some useful clues that help you plan: 51 | 52 | 53 | ### Global Plan ### 54 | This is the global plan for completing user's purpose: 55 | 56 | 57 | ### History trajectory ### 58 | History trajectory can remind you of the operations that have been executed before, thus avoiding repetitive actions. 59 | 60 | 61 | ### Reference ### 62 | There are some reference actions that you can follow: 63 | 64 | 65 | Based on given information, you are required to output with following format: 66 | 1. 67 | 2. 68 | 3. 69 | """ 70 | 71 | AITW_EXECUTION_PROMT = f""" 72 | You are a smart GUI agent, capable of comprehensively understanding the GUI interface. 73 | You will be given a smartphone screenshot and a plan that you decide to take. 74 | 75 | Before you start, I will explain the data format: 76 | ### Plan ### 77 | This is your plan: 78 | 79 | 80 | ### Action Space ### 81 | These are the functions to interact with the phone: 82 | {AITW_ACTION_SPACE} 83 | 84 | ### Reference ### 85 | There are some reference actions that you can follow: 86 | 87 | 88 | Now please choose one action in \"### Action Space ###\" for the current screen state based on \"### Plan ###\" and \"### Reference ###\". 89 | You should output with following format: 90 | 91 | ### Thought ### 92 | According to \"### Plan ###\", you should first determine weather the purpose has been completed. If not, think step-by-step and output the action that should be taken currently. 93 | 94 | ### Action ### 95 | The action you finally choose from \"### Action Space ###\". Do not output anything else. 96 | """ 97 | 98 | AITW_GLOBAL_PLANNING_PROMT = f''' 99 | You are an agent that is trained to complete certain tasks on a smartphone. You will be given a screenshot of a smartphone app. 100 | 101 | The global task you should complete is: 102 | \"\" 103 | 104 | Now, carefully analyze all the above content and provide your output in the following format: 105 | 106 | ### Global Plan ### 107 | Please break down the overall task into 2~3 simple sub-goals. 108 | Note that since you can’t see future phone screenshots, each sub-goal should be abstract, high-level, and not involve interacting with specific UI elements. 109 | ''' 110 | 111 | 112 | 113 | PAGE_SUMMARY_PROMPT = 'Please describe this screen containing following content with one full sentence: \ 114 | the type of page, the function of page and a few key components of the screen.' 115 | 116 | 117 | REFERENCE_FORMAT = '''{idx}. 118 | You can take following action: {actions}. 119 | This can help you achieve goals like: {goals}. 120 | ''' 121 | 122 | ACTION_SUMMARY_PROMPT = 'A click operation has now been performed at coordinates {coordinates}. \ 123 | You are required to summarize this operation with a verb phrase.' 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /workflow/mind2web/prompts.py: -------------------------------------------------------------------------------- 1 | 2 | MIND2WEB_ACTION_SPACE=''' 3 | 1. Click(x,y): An action of clicking a coordinate point on the web screen and x,y is the position of the coordinate point on the screen. 4 | Your click location should be a UI element or text on the screen. 
5 | A simple use case could be Click(100,238), which means you click the UI element at (100,238) on the current screen. 6 | 7 | 2. Type(x,y,"typed_text"): An action of typing a piece of text at the position with coordinates x and y. 8 | A simple use case could be Type(340,212,"Where was Obama born?"), which inputs the string "Where was Obama born?" into the input area at the coordinates (340,212) on the web screen. 9 | 10 | 3. Select(x,y,"option"): An action of opening a \"Select Menu\" or \"Dropdown List\" located at coordinates (x, y) and choosing an option you specify. 11 | A simple use case could be Select(679,437,"female"), which opens the list at the coordinates (679,437) and selects the option "female" from the list. 12 | ''' 13 | 14 | MIND2WEB_OBSERVATION_PROMT = f""" 15 | You are a smart GUI agent, capable of comprehensively understanding the GUI interface as well as the user's intentions. 16 | You will be given user's ultimate purpose and the previous actions that you have taken. 17 | Your task is to carefully observe the screen, describe it and conclude some useful clues in one sentence. 18 | 19 | Now you can start to observe: 20 | 21 | ### User's purpose ### 22 | 23 | 24 | ### History trajectory ### 25 | History trajectory can remind you of the operations that have been executed before, thus avoiding repetitive actions. 26 | 27 | 28 | ### Observation ### 29 | """ 30 | 31 | MIND2WEB_GLOBAL_PLANNING_PROMT = f''' 32 | You are an agent that is trained to complete certain tasks on the webpage. You will be given a screenshot of a website. 33 | 34 | The global task you should complete is: 35 | \"\" 36 | 37 | Now, carefully analyze all the above content and provide your output in the following format: 38 | 39 | ### Global Plan ### 40 | Please break down the overall task into 2~3 simple sub-goals. 41 | Note that since you can’t see future webpages, each sub-goal should be abstract, high-level, and not involve interacting with specific UI elements. 42 | ''' 43 | 44 | MIND2WEB_PLANNING_PROMT = f""" 45 | You are a smart GUI agent, capable of comprehensively understanding the GUI interface as well as the user's intentions. 46 | Your task is to plan the next action to complete user's purpose with the help of references. 47 | 48 | I will give you several important pieces of information: 49 | ### User's purpose ### 50 | This is the user's global purpose, and your goal is to complete it: 51 | 52 | 53 | ### Observation ### 54 | This is the observation of the screen and some useful clues that help you plan: 55 | 56 | 57 | ### Global Plan ### 58 | This is the global plan for completing user's purpose: 59 | 60 | 61 | ### History trajectory ### 62 | History trajectory can remind you of the operations that have been executed before, thus avoiding repetitive actions. 63 | 64 | 65 | ### Reference ### 66 | There are some reference actions that you can follow: 67 | 68 | 69 | Based on the given information, you are required to output with the following format: 70 | 1. 71 | 2. 72 | """ 73 | 74 | MIND2WEB_EXECUTION_PROMT = f""" 75 | You are a smart GUI agent, capable of comprehensively understanding the GUI interface. 76 | You will be given a screenshot of a website and a plan that you decide to take. 
77 | 78 | Before you start, I will explain the data format: 79 | ### Plan ### 80 | This is your plan: 81 | 82 | 83 | ### Reference ### 84 | There are some reference actions that you can follow: 85 | 86 | 87 | ### Action Space ### 88 | These are the functions to interact with the webpage: 89 | {MIND2WEB_ACTION_SPACE} 90 | 91 | Now please choose one action in \"### Action Space ###\" for the current webpage based on \"### Plan ###\" and \"### Reference ###\". 92 | You should output with following format: 93 | 94 | ### Thought ### 95 | Think step-by-step and output the action that should be taken currently. 96 | 97 | ### Action ### 98 | Output only one action you finally choose from \"### Action Space ###\". Do not output anything else. 99 | """ 100 | 101 | 102 | ACTION_SUMMARY_PROMPT = { 103 | 'click_action_summary' : 'This is a page of website. The user clicks the item at coordinates {bbox}. You are required to summarize this operation beginning with \"click\". Do not mention original coordinates.', 104 | 'type_action_summary' : 'This is a page of website. The user types the content \"{content}\" at coordinates {bbox}. You are required to summarize this operation beginning with \"type\". Do not mention original coordinates.', 105 | 'select_action_summary' : 'This is a page of website. The user opens a \"Select Menu\" or \"Dropdown List\" at coordinates {bbox}, and select the option \"{content}\". You are required to summarize this operation beginning with \"select\". Do not mention original coordinates.' 106 | } 107 | 108 | PAGE_SUMMARY_PROMPT = 'Please describe this screen containing following content with one full sentence, including \ 109 | the type of page, the function of page and the key components of the screen.' 110 | 111 | REFERENCE_FORMAT = '''{idx}. 112 | You can take following action: {actions}. 113 | This can help you achieve goals like: {goals}. 114 | ''' -------------------------------------------------------------------------------- /workflow/odyssey/prompts.py: -------------------------------------------------------------------------------- 1 | 2 | ODYSSEY_ACTION_SPACE = ''' 3 | 1. 'CLICK: (x,y)': An action of clicking a coordinate point on the smartphone screen and x,y is the position of the coordinate point on the screen. 4 | Your click location should be a UI element or text on the screen. 5 | A simple use case could be 'CLICK: (100,238)', which means you click the UI element at (100,238) on the current screen. 6 | 7 | 2. 'TYPE: typed_text': An action of typing a piece of text. 8 | A simple use case can be 'TYPE: Hello, world!', which inserts the string "Hello, world!" into the input area on the smartphone screen. 9 | 10 | 3. 'SCROLL: direction': This function is used to scroll an UI element shown on the smartphone screen, usually a scroll view or a slide bar. 11 | "direction" is a string that represents one of the four directions: UP, DOWN, LEFT, RIGHT. 12 | A simple use case could be 'SCROLL: UP', which means you take a scroll up action on the current screen. 13 | 14 | 4. 'PRESS_BACK': The action for returning to the previous screen. 15 | 16 | 5. 'PRESS_HOME': The action for returning to the homepage. 17 | 18 | 6. 'PRESS_RECENT': The action to go to the previous App. 19 | 20 | 7. 'COMPLETE': It means you think the task has been completed based on current screen. 21 | 22 | 8. 'IMPOSSIBLE': It means you think the task cannot be completed based on current screen. 23 | 24 | 9. 
'LONG_PRESS: (x,y)': An action of pressing a coordinate point on the smartphone screen for a long time to copy texts or download images, where x and y is the position of the coordinate point on the screen. 25 | ''' 26 | 27 | ODYSSEY_OBSERVATION_PROMT = f""" 28 | You are a smart GUI agent, capable of comprehensively understanding the GUI interface as well as the user's intentions. 29 | You will be given user's ultimate purpose and the previous actions that you have taken. 30 | Your task is to carefully observe the screen, descripe it and conclude some useful clues in one sentence. 31 | 32 | Now you can start to observe: 33 | 34 | ### User's purpose ### 35 | 36 | 37 | ### History trajectory ### 38 | History trajectory can remind you of the operations that have been executed before, thus avoiding repetitive actions. 39 | 40 | 41 | ### Observation ### 42 | """ 43 | 44 | ODYSSEY_GLOBAL_PLANNING_PROMT = f''' 45 | You are an agent that is trained to complete certain tasks on a smartphone. You will be given a screenshot of a smartphone app. 46 | 47 | The global task you should complete is: 48 | \"\" 49 | 50 | Now, carefully analyze all the above content and provide your output in the following format: 51 | 52 | ### Global Plan ### 53 | Please break down the overall task into 2~3 simple sub-goals. 54 | Note that since you can’t see future phone screenshots, each sub-goal should be abstract, high-level, and not involve interacting with specific UI elements. 55 | ''' 56 | 57 | ODYSSEY_PLANNING_PROMT = f""" 58 | You are a smart GUI agent, capable of comprehensively understanding the GUI interface as well as the user's intentions. 59 | Your task is to plan the next action to complete user's purpose with the help of references. 60 | 61 | I will give you several important information: 62 | ### User's purpose ### 63 | This is the user's global purpose, and your goal is to complete it: 64 | 65 | 66 | ### Observation ### 67 | This is the observation of the screen and some useful clues that help you plan: 68 | 69 | 70 | ### Global Plan ### 71 | This is the global plan for completing user's purpose: 72 | 73 | 74 | ### History trajectory ### 75 | History trajectory can remind you of the operations that have been executed before, thus avoiding repetitive actions. 76 | 77 | 78 | ### Reference ### 79 | There are some reference actions that you can follow: 80 | 81 | 82 | Based on given information, you are required to output with following format: 83 | 1. 84 | 2. 85 | 3. 86 | """ 87 | 88 | ODYSSEY_EXECUTION_PROMT = f""" 89 | You are a smart GUI agent, capable of comprehensively understanding the GUI interface. 90 | You will be given a smartphone screenshot and a plan that you decide to take. 91 | 92 | Before you start, I will explain the data format: 93 | ### Plan ### 94 | This is your plan: 95 | 96 | 97 | ### Action Space ### 98 | These are the functions to interact with the phone: 99 | {ODYSSEY_ACTION_SPACE} 100 | 101 | ### Reference ### 102 | There are some reference actions that you can follow: 103 | 104 | 105 | Now please choose one action in \"### Action Space ###\" for the current screen state based on \"### Plan ###\" and \"### Reference ###\". 106 | You should output with following format: 107 | 108 | ### Thought ### 109 | According to \"### Plan ###\", you should first determine weather the purpose has been completed. If not, think step-by-step and output the action that should be taken currently. 110 | 111 | ### Action ### 112 | The action you finally choose from \"### Action Space ###\". 
Do not output anything else. 113 | """ 114 | 115 | REFERENCE_FORMAT = '''{idx}. 116 | You can take following action: {actions}. 117 | This can help you achieve goals like: {goals}. 118 | ''' 119 | 120 | PAGE_SUMMARY_PROMPT = 'Please describe this screen containing following content with one full sentence, including \ 121 | the type of page, the function of page and the key components of the screen.' 122 | 123 | 124 | 125 | ACTION_SUMMARY_PROMPT = 'The user clicks the item at coordinates {bbox}. You are required to summarize this operation with a verb phrase that begins with \"click\". Do not mention original coordinates.' 126 | -------------------------------------------------------------------------------- /document_construction/aitw_document/main.py: -------------------------------------------------------------------------------- 1 | # from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor 2 | # from qwen_vl_utils import process_vision_info 3 | import random 4 | 5 | # import cv2 6 | import copy 7 | from pathlib import Path 8 | from tqdm import tqdm 9 | import requests 10 | from urllib.parse import quote 11 | import json 12 | from tqdm import tqdm 13 | from langchain_community.vectorstores import FAISS 14 | from langchain.schema import Document 15 | from langchain_huggingface import HuggingFaceEmbeddings 16 | from PIL import Image 17 | 18 | import prompts 19 | 20 | 21 | url = "http://localhost:8000/v1/chat/completions" 22 | 23 | headers = { 24 | "Content-Type": "application/json" 25 | } 26 | 27 | def chat(img_url_list: str = '', query: str = '') -> dict: 28 | 29 | content = [] 30 | for img_url in img_url_list: 31 | img_url = quote(img_url, safe='/:') 32 | content.append({"type": "image_url", "image_url": {"url": img_url}}) 33 | content.append({"type": "text", "text": query}) 34 | data = { 35 | "model": "Qwen2.5-VL-72B-Instruct", 36 | "messages": [ 37 | {"role": "system", "content": "You are a helpful assistant."}, 38 | {"role": "user", "content": content} 39 | ], 40 | 'temperature':0 41 | } 42 | 43 | response = requests.post(url, headers=headers, data=json.dumps(data)) 44 | response = response.json() 45 | response = response['choices'][0]['message']['content'] 46 | return response 47 | 48 | def action2description(step): 49 | 50 | w, h = Image.open('AITW_simplified/aitw_images/' + step['img_filename']+'.png' ).size 51 | 52 | intention = step['goal'] 53 | action = step['action_type_text'] 54 | coord1_x, coord1_y = step['touch'][0]*w, step['touch'][1]*h 55 | coord2_x, coord2_y = step['lift'][0]*w, step['lift'][1]*h 56 | text = step['type_text'] 57 | 58 | 59 | if action == 'click': 60 | descrpition = f'### Action type: {action}\n\ 61 | ### Coordinates: ({coord1_x},{coord1_y})' 62 | elif action == 'type': 63 | descrpition = f'### Action type: {action}\n\ 64 | ### Content: {text}' 65 | else: 66 | descrpition = action 67 | 68 | return descrpition 69 | 70 | def check_repeat_item(img_path, page_summary, search_document, embedding_model): 71 | if len(search_document) == 0: 72 | return None, None 73 | vectorstore = FAISS.from_documents(search_document, embedding_model) 74 | search_res = vectorstore.similarity_search(page_summary) 75 | 76 | old_description = "" 77 | for i, res in enumerate(search_res): 78 | old_description += f'{i+1}. 
' + res.page_content + '\n' 79 | 80 | check_repeat_prompt = prompts.check_repeat.format(old_description=old_description) 81 | check_repeat_res = chat([img_path], check_repeat_prompt) 82 | sample_index = check_repeat_res.split('### Index: ')[1] 83 | 84 | if sample_index == 'None': 85 | return None, None 86 | else: 87 | sample_index = int(sample_index) - 1 88 | old_img_path = search_res[sample_index].metadata['img_path'] 89 | double_check_res = chat([old_img_path, img_path], prompts.check_repeat_2) 90 | double_check_res = double_check_res.split('### Conclusion: ')[1].strip() 91 | assert double_check_res in ['Yes','No'] 92 | if double_check_res == 'No': 93 | return None, None 94 | repeat_index = search_res[sample_index].metadata['index'] 95 | new_summary = search_res[sample_index].page_content#check_repeat_res.split('### New Summary: ')[1] 96 | return new_summary, repeat_index 97 | 98 | 99 | 100 | def create_new_item(img_path, knowledge_library, search_document, embedding_model): 101 | page_summary = chat([img_path], prompts.page_summary) 102 | new_summary, repeat_index = check_repeat_item(img_path, page_summary, search_document, embedding_model) 103 | if repeat_index is None: 104 | knowledge_item = {} 105 | knowledge_item['index'] = len(knowledge_library) 106 | knowledge_item['page_summary'] = page_summary#.split('### Page Summary: ')[1] 107 | knowledge_item['original_image'] = [] 108 | knowledge_item['next_page_list'] = [{'actions':[],'page_index':None}] 109 | knowledge_library[knowledge_item['index']] = knowledge_item 110 | search_document.append(Document(page_content = page_summary, metadata = {"index": knowledge_item['index'], "img_path": img_path})) 111 | else: 112 | knowledge_library[repeat_index]['page_summary'] = new_summary 113 | search_document[repeat_index].page_content = new_summary 114 | knowledge_item = knowledge_library[repeat_index] 115 | 116 | return knowledge_item 117 | 118 | def get_item(img_path, last_img_path, last_action_summary, last_page_idx, knowledge_library, search_document, embedding_model): 119 | if last_page_idx is None: 120 | knowledge_item = create_new_item(img_path, knowledge_library, search_document, embedding_model) 121 | redirection_flag = True 122 | else: 123 | redirection_res = chat([last_img_path, img_path], prompts.redirection_judge.format(action=last_action_summary)) 124 | redirection_res = redirection_res.split('### Conclusion: ')[1].strip() 125 | assert redirection_res in ['Yes','No'] 126 | if redirection_res == 'Yes': 127 | knowledge_item = create_new_item(img_path, knowledge_library, search_document, embedding_model) 128 | redirection_flag = True 129 | elif redirection_res == 'No': 130 | knowledge_item = knowledge_library[last_page_idx] 131 | redirection_flag = False 132 | knowledge_item['original_image'].append(img_path.split('http://localhost:6666/aitw_images/')[1]) 133 | return knowledge_item, redirection_flag 134 | 135 | 136 | 137 | aitw_train_data = json.load(open('aitw_annots/aitw_data_train.json','r')) 138 | aitw_data_type_list = [ 'install','googleapps','general','single','webshopping'] 139 | embedding_model_name = "bge-m3" 140 | embedding_model = HuggingFaceEmbeddings(model_name = embedding_model_name,model_kwargs={'device': 'cuda:0'}) 141 | 142 | 143 | for aitw_data_type in aitw_data_type_list: 144 | knowledge_library = {} 145 | search_document = [] 146 | selected_episode = random.sample(aitw_train_data[aitw_data_type], len(aitw_train_data[aitw_data_type]) // 10) 147 | for episode in tqdm(selected_episode): 148 | last_page_idx = None 149 
| last_img_path = None 150 | last_action_summary = None 151 | for i in range(len(episode)): 152 | img_path = 'http://localhost:6666/aitw_images/'+episode[i]['img_filename']+'.png' 153 | 154 | if last_page_idx is not None: 155 | action_description = action2description(episode[i-1]) 156 | if action_description[:10] == '### Action': 157 | last_action_summary = chat([last_img_path], prompts.action_summary.format(action_description=action_description)) 158 | else: 159 | last_action_summary = action_description 160 | 161 | knowledge_item, redirection_flag = get_item(img_path, last_img_path, last_action_summary, last_page_idx, knowledge_library, search_document, embedding_model) 162 | 163 | if last_page_idx is not None: 164 | knowledge_library[last_page_idx]['next_page_list'][-1]['actions'].append(last_action_summary) 165 | knowledge_library[last_page_idx]['next_page_list'][-1]['goal'] = episode[i]['goal'] 166 | if redirection_flag: 167 | knowledge_library[last_page_idx]['next_page_list'][-1]['page_index'] = knowledge_item['index'] 168 | knowledge_library[last_page_idx]['next_page_list'].append({'actions':[],'page_index':None}) 169 | 170 | last_page_idx = knowledge_item['index'] 171 | last_img_path = img_path 172 | 173 | f_json = open(f'{aitw_data_type}_library.json', 'w') 174 | json.dump(knowledge_library, f_json, ensure_ascii=False, indent=4) 175 | f_json.close() 176 | 177 | -------------------------------------------------------------------------------- /document_construction/odyssey_document/main.py: -------------------------------------------------------------------------------- 1 | # from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor 2 | # from qwen_vl_utils import process_vision_info 3 | import random 4 | 5 | # import cv2 6 | import copy 7 | import os 8 | from pathlib import Path 9 | from tqdm import tqdm 10 | import requests 11 | from urllib.parse import quote 12 | import json 13 | from tqdm import tqdm 14 | from langchain_community.vectorstores import FAISS 15 | from langchain.schema import Document 16 | from langchain_huggingface import HuggingFaceEmbeddings 17 | from PIL import Image 18 | import numpy as np 19 | 20 | import prompts 21 | 22 | 23 | url = "http://localhost:8000/v1/chat/completions" 24 | 25 | headers = { 26 | "Content-Type": "application/json" 27 | } 28 | 29 | def chat(img_url_list: str = '', query: str = '') -> dict: 30 | 31 | content = [] 32 | for img_url in img_url_list: 33 | img_url = quote(img_url, safe='/:') 34 | content.append({"type": "image_url", "image_url": {"url": img_url}}) 35 | content.append({"type": "text", "text": query}) 36 | data = { 37 | "model": "Qwen2.5-VL-72B-Instruct", 38 | "messages": [ 39 | {"role": "system", "content": "You are a helpful assistant."}, 40 | {"role": "user", "content": content} 41 | ], 42 | 'temperature':0 43 | } 44 | 45 | response = requests.post(url, headers=headers, data=json.dumps(data)) 46 | response = response.json() 47 | response = response['choices'][0]['message']['content'] 48 | 49 | return response 50 | 51 | 52 | 53 | def get_action_summary(img_path, step): 54 | action = step['action'] 55 | info = step['info'] 56 | assert action in ['CLICK', 'TEXT', 'SCROLL', 'LONG_PRESS'] 57 | 58 | if action == 'CLICK' or action == "LONG_PRESS": 59 | if info == 'KEY_HOME': 60 | gt = 'press home to go to the home screen' 61 | elif info == 'KEY_BACK': 62 | gt = 'press back to go to the previous screen' 63 | elif info == 'KEY_APPSELECT': 64 | gt = 'go to the previous App' 65 | elif type(info) == list: 66 
| 67 | w, h = Image.open('GUI-Odyssey-master/data/screenshots/' + step['screenshot']).size 68 | bbox_str = f'[{int(info[0][0]/1000*w)}, {int(info[0][1]/1000*h)}]' 69 | query = prompts.click_action_summary.format(bbox=bbox_str) 70 | gt = chat([img_path], query) 71 | if gt[-1] == '.': 72 | gt = gt[:-1] 73 | 74 | else: 75 | raise ValueError(f'Unknown click action {info}') 76 | 77 | elif action == 'SCROLL': 78 | start = np.array(info[0]) 79 | end = np.array(info[1]) 80 | delta = end - start 81 | delta_abs = np.abs(delta) 82 | lr = 'left' if delta[0] < 0 else 'right' 83 | ud = 'up' if delta[1] < 0 else 'down' 84 | if delta_abs[0] > delta_abs[1]: 85 | gt = f"scroll {lr}" 86 | else: 87 | gt = f"scroll {ud}" 88 | 89 | elif action == 'TEXT': 90 | gt = f'type {info}' 91 | 92 | return gt 93 | 94 | def check_repeat_item(domain, img_path, page_summary, search_document, embedding_model): 95 | if len(search_document[domain]) == 0: 96 | return None, None 97 | vectorstore = FAISS.from_documents(search_document[domain], embedding_model) 98 | search_res = vectorstore.similarity_search(page_summary) 99 | 100 | old_description = "" 101 | for i, res in enumerate(search_res): 102 | old_description += f'{i+1}. ' + res.page_content + '\n' 103 | 104 | check_repeat_prompt = prompts.check_repeat.format(old_description=old_description) 105 | check_repeat_res = chat([img_path], check_repeat_prompt) 106 | sample_index = check_repeat_res.split('### Index: ')[1].strip()#.split('\n')[0] 107 | 108 | if sample_index == 'None': 109 | return None, None 110 | else: 111 | sample_index = int(sample_index) - 1 112 | old_img_path = search_res[sample_index].metadata['img_path'] 113 | double_check_res = chat([old_img_path, img_path], prompts.check_repeat_2) 114 | double_check_res = double_check_res.split('### Conclusion: ')[1].strip() 115 | assert double_check_res in ['Yes','No'] 116 | if double_check_res == 'No': 117 | return None, None 118 | repeat_index = search_res[sample_index].metadata['index'] 119 | new_summary = search_res[sample_index].page_content#check_repeat_res.split('### New Summary: ')[1] 120 | return new_summary, repeat_index 121 | 122 | 123 | 124 | def create_new_item(domain, img_path, knowledge_library, search_document, embedding_model): 125 | page_summary = chat([img_path], prompts.page_summary) 126 | new_summary, repeat_index = check_repeat_item(domain, img_path, page_summary, search_document, embedding_model) 127 | if repeat_index is None: 128 | knowledge_item = {} 129 | knowledge_item['index'] = len(knowledge_library[domain]) 130 | knowledge_item['page_summary'] = page_summary#.split('### Page Summary: ')[1] 131 | knowledge_item['original_image'] = [] 132 | knowledge_item['next_page_list'] = [{'actions':[],'page_index':None}] 133 | knowledge_library[domain][knowledge_item['index']] = knowledge_item 134 | search_document[domain].append(Document(page_content = page_summary, metadata = {"index": knowledge_item['index'], "img_path": img_path})) 135 | else: 136 | knowledge_library[domain][repeat_index]['page_summary'] = new_summary 137 | search_document[domain][repeat_index].page_content = new_summary 138 | knowledge_item = knowledge_library[domain][repeat_index] 139 | 140 | return knowledge_item 141 | 142 | def get_item(domain, img_path, last_img_path, last_action_summary, last_page_idx, knowledge_library, search_document, embedding_model): 143 | if last_page_idx is None: 144 | knowledge_item = create_new_item(domain, img_path, knowledge_library, search_document, embedding_model) 145 | redirection_flag = True 146 | 
else: 147 | redirection_res = chat([last_img_path, img_path], prompts.redirection_judge.format(action=last_action_summary)) 148 | redirection_res = redirection_res.split('### Conclusion: ')[1].strip() 149 | assert redirection_res in ['Yes','No'] 150 | if redirection_res == 'Yes': 151 | knowledge_item = create_new_item(domain, img_path, knowledge_library, search_document, embedding_model) 152 | redirection_flag = True 153 | elif redirection_res == 'No': 154 | knowledge_item = knowledge_library[domain][last_page_idx] 155 | redirection_flag = False 156 | knowledge_item['original_image'].append(img_path.split('http://localhost:6668/')[1]) 157 | return knowledge_item, redirection_flag 158 | 159 | 160 | odyssey_data = json.load(open('data/splits/splits_random_split.json','r')) 161 | annotations_path = 'data/annotations/' 162 | imgs_path = 'data/screenshots/' 163 | embedding_model_name = "bge-m3" 164 | embedding_model = HuggingFaceEmbeddings(model_name = embedding_model_name,model_kwargs={'device': 'cuda:0'}) 165 | 166 | knowledge_library = {} 167 | search_document = {} 168 | selected_episode_idx = random.sample(odyssey_data['train'], len(odyssey_data['train']) // 50) 169 | for train_idx in tqdm(selected_episode_idx): 170 | episode = json.load(open(annotations_path + train_idx,'r')) 171 | last_page_idx = None 172 | last_img_path = None 173 | last_action_summary = None 174 | domain = episode['task_info']['category'] 175 | if domain not in list(knowledge_library.keys()): 176 | knowledge_library[domain] = {} 177 | search_document[domain] = [] 178 | goal = episode['task_info']['instruction'] 179 | action_list = episode['steps'] 180 | for i in range(len(action_list)): 181 | img_path = 'http://localhost:6668/'+action_list[i]['screenshot'] 182 | 183 | if last_page_idx is not None: 184 | last_action_summary = get_action_summary(last_img_path, action_list[i-1]) 185 | 186 | knowledge_item, redirection_flag = get_item(domain, img_path, last_img_path, last_action_summary, last_page_idx, knowledge_library, search_document, embedding_model) 187 | 188 | if last_page_idx is not None: 189 | knowledge_library[domain][last_page_idx]['next_page_list'][-1]['actions'].append(last_action_summary) 190 | knowledge_library[domain][last_page_idx]['next_page_list'][-1]['goal'] = goal 191 | if redirection_flag: 192 | knowledge_library[domain][last_page_idx]['next_page_list'][-1]['page_index'] = knowledge_item['index'] 193 | knowledge_library[domain][last_page_idx]['next_page_list'].append({'actions':[],'page_index':None}) 194 | 195 | last_page_idx = knowledge_item['index'] 196 | last_img_path = img_path 197 | 198 | 199 | f_json = open(f'odyssey_library.json', 'w') 200 | json.dump(knowledge_library, f_json, ensure_ascii=False, indent=4) 201 | f_json.close() 202 | 203 | -------------------------------------------------------------------------------- /document_construction/mind2web_document/main.py: -------------------------------------------------------------------------------- 1 | # from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor 2 | # from qwen_vl_utils import process_vision_info 3 | import random 4 | 5 | # import cv2 6 | import copy 7 | import os 8 | from pathlib import Path 9 | from tqdm import tqdm 10 | import requests 11 | from urllib.parse import quote 12 | import json 13 | from tqdm import tqdm 14 | from langchain_community.vectorstores import FAISS 15 | from langchain.schema import Document 16 | from langchain_huggingface import HuggingFaceEmbeddings 17 | 18 | import prompts 19 | 
20 | 21 | url = "http://localhost:8000/v1/chat/completions" 22 | 23 | headers = { 24 | "Content-Type": "application/json" 25 | } 26 | 27 | def chat(img_url_list: str = '', query: str = '') -> dict: 28 | 29 | content = [] 30 | for img_url in img_url_list: 31 | img_url = quote(img_url, safe='/:') 32 | content.append({"type": "image_url", "image_url": {"url": img_url}}) 33 | content.append({"type": "text", "text": query}) 34 | data = { 35 | "model": "Qwen2.5-VL-72B-Instruct", 36 | "messages": [ 37 | {"role": "system", "content": "You are a helpful assistant."}, 38 | {"role": "user", "content": content} 39 | ], 40 | 'temperature':0 41 | } 42 | 43 | response = requests.post(url, headers=headers, data=json.dumps(data)) 44 | response = response.json() 45 | response = response['choices'][0]['message']['content'] 46 | 47 | return response 48 | 49 | def get_action_summary(img_path, action): 50 | action_type = action['operation']['op'] 51 | assert action_type in ['CLICK', 'TYPE', 'SELECT'] 52 | 53 | bbox = [int(action["bbox"]["x"]), int(action["bbox"]["y"]), int(action["bbox"]["x"] + action["bbox"]["width"]), 54 | int(action["bbox"]["y"] + action["bbox"]["height"])] 55 | bbox_str = f'[{bbox[0]}, {bbox[1]}, {bbox[2]}, {bbox[3]}]' 56 | if action_type == 'CLICK': 57 | query = prompts.click_action_summary.format(bbox=bbox_str) 58 | elif action_type == 'TYPE': 59 | query = prompts.type_action_summary.format(content=action['operation']['value'],bbox=bbox_str) 60 | elif action_type == 'SELECT': 61 | query = prompts.select_action_summary.format(content=action['operation']['value'],bbox=bbox_str) 62 | 63 | action_summary = chat([img_path], query) 64 | if action_summary[-1] == '.': 65 | action_summary = action_summary[:-1] 66 | # if len(action['pos_candidates']) > 0: 67 | # print(action['pos_candidates'][0]['choice']) 68 | return action_summary 69 | 70 | def check_repeat_item(domain, img_path, page_summary, search_document, embedding_model): 71 | if len(search_document[domain]) == 0: 72 | return None, None 73 | vectorstore = FAISS.from_documents(search_document[domain], embedding_model) 74 | search_res = vectorstore.similarity_search(page_summary) 75 | 76 | old_description = "" 77 | for i, res in enumerate(search_res): 78 | old_description += f'{i+1}. 
' + res.page_content + '\n' 79 | 80 | check_repeat_prompt = prompts.check_repeat.format(old_description=old_description) 81 | check_repeat_res = chat([img_path], check_repeat_prompt) 82 | sample_index = check_repeat_res.split('### Index: ')[1].strip()#.split('\n')[0] 83 | 84 | if sample_index == 'None': 85 | return None, None 86 | else: 87 | sample_index = int(sample_index) - 1 88 | old_img_path = search_res[sample_index].metadata['img_path'] 89 | double_check_res = chat([old_img_path, img_path], prompts.check_repeat_2) 90 | double_check_res = double_check_res.split('### Conclusion: ')[1].strip() 91 | assert double_check_res in ['Yes','No'] 92 | if double_check_res == 'No': 93 | return None, None 94 | repeat_index = search_res[sample_index].metadata['index'] 95 | new_summary = search_res[sample_index].page_content#check_repeat_res.split('### New Summary: ')[1] 96 | return new_summary, repeat_index 97 | 98 | 99 | 100 | def create_new_item(domain, img_path, knowledge_library, search_document, embedding_model): 101 | page_summary = chat([img_path], prompts.page_summary) 102 | new_summary, repeat_index = check_repeat_item(domain, img_path, page_summary, search_document, embedding_model) 103 | if repeat_index is None: 104 | knowledge_item = {} 105 | knowledge_item['index'] = len(knowledge_library[domain]) 106 | knowledge_item['page_summary'] = page_summary#.split('### Page Summary: ')[1] 107 | knowledge_item['original_image'] = [] 108 | knowledge_item['next_page_list'] = [{'actions':[],'page_index':None}] 109 | knowledge_library[domain][knowledge_item['index']] = knowledge_item 110 | search_document[domain].append(Document(page_content = page_summary, metadata = {"index": knowledge_item['index'], "img_path": img_path})) 111 | else: 112 | knowledge_library[domain][repeat_index]['page_summary'] = new_summary 113 | search_document[domain][repeat_index].page_content = new_summary 114 | knowledge_item = knowledge_library[domain][repeat_index] 115 | 116 | return knowledge_item 117 | 118 | def get_item(domain, img_path, last_img_path, last_action_summary, last_page_idx, knowledge_library, search_document, embedding_model): 119 | if last_page_idx is None: 120 | knowledge_item = create_new_item(domain, img_path, knowledge_library, search_document, embedding_model) 121 | redirection_flag = True 122 | else: 123 | redirection_res = chat([last_img_path, img_path], prompts.redirection_judge.format(action=last_action_summary)) 124 | redirection_res = redirection_res.split('### Conclusion: ')[1].strip() 125 | assert redirection_res in ['Yes','No'] 126 | if redirection_res == 'Yes': 127 | knowledge_item = create_new_item(domain, img_path, knowledge_library, search_document, embedding_model) 128 | redirection_flag = True 129 | elif redirection_res == 'No': 130 | knowledge_item = knowledge_library[domain][last_page_idx] 131 | redirection_flag = False 132 | knowledge_item['original_image'].append(img_path.split('http://localhost:6667/mind2web_images/')[1]) 133 | return knowledge_item, redirection_flag 134 | 135 | 136 | 137 | mind2web_train_data = json.load(open('mind2web_annots/mind2web_data_train.json','r')) 138 | embedding_model_name = "bge-m3" 139 | embedding_model = HuggingFaceEmbeddings(model_name = embedding_model_name,model_kwargs={'device': 'cuda:0'}) 140 | 141 | knowledge_library = {} 142 | search_document = {} 143 | selected_episode = random.sample(mind2web_train_data, len(mind2web_train_data) // 10) 144 | for episode in tqdm(selected_episode): 145 | last_page_idx = None 146 | last_img_path = None 147 | 
last_action_summary = None 148 | domain = episode['domain'] 149 | if domain not in list(knowledge_library.keys()): 150 | knowledge_library[domain] = {} 151 | search_document[domain] = [] 152 | goal = episode['confirmed_task'] 153 | episode_id = episode['annotation_id'] 154 | action_list = episode['actions'] 155 | terminate_flag = False 156 | for i in range(len(action_list)): 157 | img_path = 'http://localhost:6667/mind2web_images/'+episode_id+'-'+action_list[i]['action_uid']+'.jpg' 158 | 159 | if not os.path.exists('mind2web_images/'+episode_id+'-'+action_list[i]['action_uid']+'.jpg'): 160 | terminate_flag = True 161 | print('IMAGE NOT FOUND') 162 | print(episode_id+'-'+action_list[i]['action_uid']) 163 | break 164 | 165 | if last_page_idx is not None: 166 | last_action_summary = get_action_summary(last_img_path, action_list[i-1]) 167 | 168 | knowledge_item, redirection_flag = get_item(domain, img_path, last_img_path, last_action_summary, last_page_idx, knowledge_library, search_document, embedding_model) 169 | 170 | if last_page_idx is not None: 171 | knowledge_library[domain][last_page_idx]['next_page_list'][-1]['actions'].append(last_action_summary) 172 | knowledge_library[domain][last_page_idx]['next_page_list'][-1]['goal'] = goal 173 | if redirection_flag: 174 | knowledge_library[domain][last_page_idx]['next_page_list'][-1]['page_index'] = knowledge_item['index'] 175 | knowledge_library[domain][last_page_idx]['next_page_list'].append({'actions':[],'page_index':None}) 176 | 177 | last_page_idx = knowledge_item['index'] 178 | last_img_path = img_path 179 | 180 | if terminate_flag: 181 | continue 182 | 183 | if len(action_list) > 1: 184 | last_action_summary = get_action_summary(last_img_path, action_list[-1]) 185 | knowledge_library[domain][last_page_idx]['next_page_list'][-1]['actions'].append(last_action_summary) 186 | knowledge_library[domain][last_page_idx]['next_page_list'][-1]['goal'] = goal 187 | 188 | f_json = open(f'mind2web_library.json', 'w') 189 | json.dump(knowledge_library, f_json, ensure_ascii=False, indent=4) 190 | f_json.close() 191 | -------------------------------------------------------------------------------- /workflow/odyssey/odyssey_test.py: -------------------------------------------------------------------------------- 1 | # evaluation on odyssey 2 | import os 3 | import random 4 | import torch 5 | import json 6 | from tqdm import tqdm 7 | from transformers import AutoModelForCausalLM, AutoTokenizer 8 | from peft import AutoPeftModelForCausalLM 9 | from transformers.generation import GenerationConfig 10 | import re 11 | import logging 12 | import ast 13 | import argparse 14 | from PIL import Image 15 | import numpy as np 16 | from langchain_community.vectorstores import FAISS 17 | from langchain.schema import Document 18 | from langchain_huggingface import HuggingFaceEmbeddings 19 | from collections import deque 20 | import requests 21 | 22 | from prompts import ODYSSEY_GLOBAL_PLANNING_PROMT, ODYSSEY_OBSERVATION_PROMT, ODYSSEY_PLANNING_PROMT, ODYSSEY_EXECUTION_PROMT, PAGE_SUMMARY_PROMPT, REFERENCE_FORMAT, ACTION_SUMMARY_PROMPT 23 | import action_matching 24 | 25 | torch.manual_seed(0) 26 | torch.cuda.manual_seed_all(0) 27 | torch.backends.cudnn.deterministic = True 28 | torch.backends.cudnn.benchmark = False 29 | random.seed(0) 30 | torch.manual_seed(0) 31 | torch.cuda.manual_seed_all(0) 32 | torch.backends.cudnn.deterministic = True 33 | torch.backends.cudnn.benchmark = False 34 | random.seed(0) 35 | 36 | url = "http://localhost:8000/v1/chat/completions" 37 | 
headers = { 38 | "Content-Type": "application/json" 39 | } 40 | 41 | 42 | def get_global_plan(img_path, goal): 43 | global_plan_prompt = ODYSSEY_GLOBAL_PLANNING_PROMT.replace('',goal) 44 | global_plan = chat([img_path], global_plan_prompt) 45 | return global_plan.split('### Global Plan ###')[-1].strip() 46 | 47 | def get_execution(img_path, action_plan, reference_actions): 48 | exec_prompt = ODYSSEY_EXECUTION_PROMT.replace('',action_plan).replace('',reference_actions) 49 | response = chat([img_path], exec_prompt) 50 | execution = response.split('### Action ###')[-1].strip() 51 | thought = response.split('### Thought ###')[-1].split('### Action ###')[0].strip() 52 | return thought, execution 53 | 54 | def get_observation(img_path, goal, previous_step): 55 | if previous_step == '': 56 | previous_step = '' 57 | obs_prompt = ODYSSEY_OBSERVATION_PROMT.replace('',goal).replace('',previous_step) 58 | observation = chat([img_path], obs_prompt) 59 | return observation 60 | 61 | def get_plan_action(img_path, goal, observations, global_plan, reference_actions, previous_step): 62 | if previous_step == '': 63 | previous_step = '' 64 | plan_prompt = ODYSSEY_PLANNING_PROMT.replace('',goal) 65 | plan_prompt = plan_prompt.replace('',observations) 66 | plan_prompt = plan_prompt.replace('',global_plan) 67 | plan_prompt = plan_prompt.replace('',reference_actions) 68 | plan_prompt = plan_prompt.replace('',previous_step) 69 | plan_action = chat([img_path], plan_prompt) 70 | return plan_action 71 | 72 | def bfs_goals(goal_list, idx, search_document): 73 | if idx is None: return 74 | queue = deque([(idx, 0)]) 75 | visited = set() 76 | visited.add(idx) 77 | 78 | while queue: 79 | cur_node, cur_depth = queue.popleft() 80 | if cur_depth >= 3: continue 81 | 82 | nxt_node_list = search_document[cur_node].metadata['next_page_list'] 83 | for nxt_node in nxt_node_list: 84 | if nxt_node['actions'] == []: continue 85 | if nxt_node['goal'] not in goal_list: 86 | goal_list.append(nxt_node['goal']) 87 | node_idx = nxt_node['page_index'] 88 | if node_idx is not None and node_idx not in visited: 89 | visited.add(node_idx) 90 | queue.append((node_idx, cur_depth+1)) 91 | 92 | 93 | def get_reference_actions(img_path, search_key, goal, search_document, embedding_model): 94 | reference_actions = '' 95 | 96 | page_summary = chat([img_path], PAGE_SUMMARY_PROMPT) 97 | 98 | count = 0 99 | max_count = 10 100 | 101 | vectorstore = FAISS.from_documents(search_document[search_key], embedding_model) 102 | search_res = vectorstore.similarity_search(page_summary) 103 | 104 | for res in search_res: 105 | for actions_chain in res.metadata['next_page_list']: 106 | if len(actions_chain['actions'])==0: continue 107 | count=count+1 108 | 109 | action_string = '' 110 | for one_action in actions_chain['actions']: 111 | action_string += ', ' + one_action 112 | action_string = action_string[2:] 113 | 114 | goal_list = [actions_chain['goal']] 115 | bfs_goals(goal_list, actions_chain['page_index'], search_document[search_key]) 116 | goals_string = '' 117 | for one_goal in goal_list: 118 | if one_goal[-1] == '.': 119 | one_goal = one_goal[:-1] 120 | goals_string += '; ' + one_goal 121 | goals_string = goals_string[2:] 122 | 123 | one_reference = REFERENCE_FORMAT.format(idx = count, actions = action_string, goals = goals_string) 124 | reference_actions += one_reference 125 | if count == max_count: break 126 | if count == max_count: break 127 | 128 | return reference_actions 129 | 130 | 131 | def get_action_summary(step, img_path): 132 | action = 
step['action'] 133 | info = step['info'] 134 | assert action in ['CLICK', 'TEXT', 'SCROLL', 'LONG_PRESS', 'COMPLETE', 'INCOMPLETE'] 135 | 136 | if action == 'CLICK' or action == "LONG_PRESS": 137 | if info == 'KEY_HOME': 138 | gt = 'press home to go to the home screen' 139 | elif info == 'KEY_BACK': 140 | gt = 'press back to go to the previous screen' 141 | elif info == 'KEY_APPSELECT': 142 | gt = 'go to the previous App' 143 | elif type(info) == list: 144 | 145 | w, h = Image.open('data/screenshots/' + step['screenshot']).size 146 | bbox_str = f'[{int(info[0][0]/1000*w)}, {int(info[0][1]/1000*h)}]' 147 | query = ACTION_SUMMARY_PROMPT.format(bbox=bbox_str) 148 | gt = chat([img_path], query) 149 | if gt[-1] == '.': 150 | gt = gt[:-1] 151 | 152 | else: 153 | raise ValueError(f'Unknown click action {info}') 154 | 155 | elif action == 'SCROLL': 156 | start = np.array(info[0]) 157 | end = np.array(info[1]) 158 | delta = end - start 159 | delta_abs = np.abs(delta) 160 | lr = 'left' if delta[0] < 0 else 'right' 161 | ud = 'up' if delta[1] < 0 else 'down' 162 | if delta_abs[0] > delta_abs[1]: 163 | gt = f"scroll {lr}" 164 | else: 165 | gt = f"scroll {ud}" 166 | 167 | elif action == 'TEXT': 168 | gt = f'type {info}' 169 | elif action == 'COMPLETE': 170 | gt = action 171 | elif action == 'INCOMPLETE': 172 | gt = 'IMPOSSIBLE' 173 | else: 174 | raise ValueError(f'Unknown action {action}') 175 | 176 | return gt 177 | 178 | def decode_action(action, info): 179 | if action == 'CLICK' or action == "LONG_PRESS": 180 | if info == 'KEY_HOME': 181 | gt = 'PRESS_HOME' 182 | elif info == 'KEY_BACK': 183 | gt = 'PRESS_BACK' 184 | elif info == 'KEY_APPSELECT': 185 | gt = 'PRESS_RECENT' 186 | elif type(info) == list: 187 | gt = f"{action}: {tuple(info[0])}" 188 | else: 189 | raise ValueError(f'Unknown click action {info}') 190 | 191 | elif action == 'SCROLL': 192 | start = np.array(info[0]) 193 | end = np.array(info[1]) 194 | delta = end - start 195 | delta_abs = np.abs(delta) 196 | lr = 'LEFT' if delta[0] < 0 else 'RIGHT' 197 | ud = 'UP' if delta[1] < 0 else 'DOWN' 198 | if delta_abs[0] > delta_abs[1]: 199 | gt = f"SCROLL: {lr}" 200 | else: 201 | gt = f"SCROLL: {ud}" 202 | 203 | elif action == 'TEXT': 204 | gt = f'TYPE: {info}' 205 | elif action == 'COMPLETE': 206 | gt = action 207 | elif action == 'INCOMPLETE': 208 | gt = 'IMPOSSIBLE' 209 | else: 210 | raise ValueError(f'Unknown action {action}') 211 | return gt 212 | 213 | def document_transform(raw_document): 214 | search_document = {} 215 | for type_name, pages in raw_document.items(): 216 | document = [] 217 | for idx in pages: 218 | item = pages[idx] 219 | document.append(Document(page_content = item['page_summary'], metadata = item)) 220 | search_document[type_name] = document 221 | return search_document 222 | 223 | def chat(img_url_list: str = '', query: str = '') -> dict: 224 | 225 | content = [] 226 | for img_url in img_url_list: 227 | content.append({"type": "image_url", "image_url": {"url": img_url}}) 228 | content.append({"type": "text", "text": query}) 229 | data = { 230 | "model": "Qwen2.5-VL-72B-Instruct", 231 | "messages": [ 232 | {"role": "system", "content": "You are a powerful agent that is trained to perform some basic tasks on the web page."}, 233 | {"role": "user", "content": content} 234 | ], 235 | "temperature":0} 236 | response = requests.post(url, headers=headers, data=json.dumps(data)) 237 | response = response.json() 238 | response = response['choices'][0]['message']['content'] 239 | 240 | return response 241 | 242 | if __name__ == 
'__main__': 243 | odyssey_data = json.load(open('data/splits/splits_random_split.json','r')) 244 | annotations_path = 'data/annotations/' 245 | imgs_path = 'data/screenshots/' 246 | embedding_model_name = "bge-m3" 247 | embedding_model = HuggingFaceEmbeddings(model_name = embedding_model_name,model_kwargs={'device': 'cuda:0'}) 248 | raw_document = json.load(open('odyssey_library.json', 'r')) 249 | search_document = document_transform(raw_document) 250 | 251 | outputs = [] 252 | for test_idx in tqdm(odyssey_data['test']): 253 | episode = json.load(open(annotations_path + test_idx,'r')) 254 | domain = episode['task_info']['category'] 255 | goal = episode['task_info']['instruction'] 256 | previous_actions = [] 257 | global_plan = '' 258 | flag = 0 259 | 260 | for step in episode["steps"]: 261 | img_path = 'http://localhost:6668/'+step['screenshot'] 262 | gt = decode_action(step['action'],step['info']) 263 | 264 | previous_step = "" 265 | for i, action in enumerate(previous_actions[-4:]): 266 | previous_step += 'Step' + str(i+1) + ': ' + action + ". \n" 267 | 268 | action_step = get_action_summary(step, img_path) 269 | previous_actions.append(action_step) 270 | 271 | observations = get_observation(img_path, goal, previous_step) 272 | reference_actions = get_reference_actions(img_path, domain, goal, search_document, embedding_model) 273 | if flag == 0: 274 | global_plan = get_global_plan(img_path, goal) 275 | flag = 1 276 | plan_action = get_plan_action(img_path, goal, observations, global_plan, reference_actions, previous_step) 277 | thought, pred = get_execution(img_path, plan_action, reference_actions) 278 | 279 | more_info = {'category': domain, 'step_length': episode['step_length']} 280 | outputs.append({ 281 | 'question': goal, 282 | 'pred': str(pred), 283 | 'gt': gt, 284 | 'more_info': more_info, 285 | 'img': img_path.split('/')[-1] 286 | }) 287 | print('-------step:{}----------'.format(step['step'])) 288 | print('Goal: ', goal) 289 | print('Img: ', img_path) 290 | print('History: ', previous_step) 291 | print('gt: ', gt) 292 | print('Observation: ', observations) 293 | print('Global Planning: \n', global_plan) 294 | print('References: \n',reference_actions) 295 | print('Loacl Planning: \n', plan_action) 296 | print('Thought: ',thought) 297 | print('Decision: \n', pred) 298 | print('---------------------------------------------') 299 | 300 | savefile = 'odyssey_record.json' 301 | json.dump(outputs, open(savefile, 'w'), indent=4, ensure_ascii=False) 302 | 303 | 304 | print(f"Saving predict result ...") 305 | 306 | savefile = 'odyssey_record.json' 307 | json.dump(outputs, open(savefile, 'w'), indent=4, ensure_ascii=False) 308 | 309 | print(f"Evaluating ...") 310 | metrics = action_matching.odyssey_action_matching_evaluation(outputs, metric='micro') 311 | 312 | savefile2 = 'odyssey_eval.json' 313 | json.dump(metrics, open(savefile2, 'w'), indent=4, ensure_ascii=False) 314 | 315 | -------------------------------------------------------------------------------- /workflow/aitw/aitw_test.py: -------------------------------------------------------------------------------- 1 | # evaluation on aitw 2 | import requests 3 | import os 4 | import random 5 | import torch 6 | import json 7 | from collections import deque 8 | from tqdm import tqdm 9 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel 10 | from peft import AutoPeftModelForCausalLM 11 | from transformers.generation import GenerationConfig 12 | import re 13 | import logging 14 | import ast 15 | import argparse 16 | 
from PIL import Image 17 | import numpy as np 18 | from langchain_community.vectorstores import FAISS 19 | from langchain.schema import Document 20 | from langchain_huggingface import HuggingFaceEmbeddings 21 | import time 22 | 23 | from prompts import AITW_GLOBAL_PLANNING_PROMT, AITW_OBSERVATION_PROMT, AITW_PLANNING_PROMT, AITW_EXECUTION_PROMT, PAGE_SUMMARY_PROMPT, REFERENCE_FORMAT, ACTION_SUMMARY_PROMPT 24 | import action_matching 25 | 26 | logging.basicConfig(level=logging.INFO) 27 | 28 | def get_global_plan(img_path, global_plan, previous_actions, goal): 29 | if global_plan == '': global_plan = 'No old global plan.' 30 | if len(previous_actions) == 1: 31 | last_action = 'No action taken before.' 32 | else: 33 | last_action = previous_actions[-2] 34 | global_plan_prompt = AITW_GLOBAL_PLANNING_PROMT.replace('',goal) 35 | global_plan = chat([img_path], global_plan_prompt) 36 | return global_plan.split('### Global Plan ###')[-1].strip() 37 | 38 | def get_execution(img_path, action_plan, reference_actions): 39 | exec_prompt = AITW_EXECUTION_PROMT.replace('',action_plan) 40 | 41 | exec_prompt = exec_prompt.replace('',"") 42 | execution = chat([img_path], exec_prompt) 43 | 44 | execution = execution.split('### Action ###')[-1].strip() 45 | return execution 46 | 47 | def get_observation(img_path, goal, previous_step): 48 | if previous_step == '': 49 | previous_step = 'No previous step has been taken.' 50 | obs_prompt = AITW_OBSERVATION_PROMT.replace('',goal).replace('',previous_step) 51 | observation = chat([img_path], obs_prompt) 52 | return observation 53 | 54 | def get_plan_action(img_path, goal, observations, global_plan, reference_actions,previous_step): 55 | plan_prompt = AITW_PLANNING_PROMT.replace('',goal) 56 | plan_prompt = plan_prompt.replace('',observations) 57 | plan_prompt = plan_prompt.replace('',global_plan) 58 | plan_prompt = plan_prompt.replace('',reference_actions) 59 | plan_prompt = plan_prompt.replace('',previous_step) 60 | plan_action = chat([img_path], plan_prompt) 61 | return plan_action 62 | 63 | def bfs_goals(goal_list, idx, search_document): 64 | if idx is None: return 65 | queue = deque([(idx, 0)]) 66 | visited = set() 67 | visited.add(idx) 68 | 69 | while queue: 70 | cur_node, cur_depth = queue.popleft() 71 | if cur_depth >= 3: continue 72 | 73 | nxt_node_list = search_document[cur_node].metadata['next_page_list'] 74 | for nxt_node in nxt_node_list: 75 | if nxt_node['actions'] == []: continue 76 | if nxt_node['goal'] not in goal_list: 77 | goal_list.append(nxt_node['goal']) 78 | node_idx = nxt_node['page_index'] 79 | if node_idx is not None and node_idx not in visited: 80 | visited.add(node_idx) 81 | queue.append((node_idx, cur_depth+1)) 82 | 83 | 84 | def get_reference_actions(img_path, goal, search_document, embedding_model): 85 | reference_actions = '' 86 | 87 | page_summary = chat([img_path], PAGE_SUMMARY_PROMPT) 88 | 89 | 90 | vectorstore = FAISS.from_documents(search_document, embedding_model) 91 | search_res = vectorstore.similarity_search(page_summary) 92 | 93 | max_count = 10 94 | count = 0 95 | for res in search_res: 96 | for actions_chain in res.metadata['next_page_list']: 97 | if len(actions_chain['actions'])==0: continue 98 | count=count+1 99 | 100 | action_string = '' 101 | for one_action in actions_chain['actions']: 102 | action_string += ', ' + one_action 103 | action_string = action_string[2:] 104 | 105 | goal_list = [actions_chain['goal']] 106 | bfs_goals(goal_list, actions_chain['page_index'], search_document) 107 | goals_string = '' 108 | for 
one_goal in goal_list: 109 | goals_string += ', ' + one_goal 110 | goals_string = goals_string[2:] 111 | 112 | one_reference = REFERENCE_FORMAT.format(idx = count, actions = action_string, goals = goals_string) 113 | reference_actions += one_reference 114 | if count == max_count: break 115 | if count == max_count: break 116 | 117 | return reference_actions 118 | 119 | def document_transform(raw_document): 120 | search_document = [] 121 | for idx in raw_document: 122 | item = raw_document[idx] 123 | search_document.append(Document(page_content = item['page_summary'], metadata = item)) 124 | return search_document 125 | 126 | 127 | def action2step_for_Qwen(step_data, img_path, img_filename): 128 | action_type = step_data["action_type_id"] 129 | if action_type == 4: 130 | if step_data["action_type_text"] == 'click': # for click action, we calculate midpoint of touch and lift as the click point 131 | touch_point = step_data["touch"] 132 | lift_point = step_data["lift"] 133 | action_type_new = 4 134 | click_point = [(touch_point[0] + lift_point[0]) / 2, (touch_point[1] + lift_point[1]) / 2] 135 | click_point = [f"{item:.2f}" for item in click_point] 136 | w, h = Image.open('aitw_images/' + img_filename).size 137 | click_point = "({},{})".format(int(float(click_point[0])*w), int(float(click_point[1])*h)) 138 | action_des = ACTION_SUMMARY_PROMPT.format(coordinates = click_point) 139 | action = chat([img_path], action_des) 140 | else: 141 | action = step_data["action_type_text"] 142 | elif action_type == 3: 143 | typed_text = step_data["type_text"] 144 | action = f'type {typed_text}' 145 | else: 146 | action = step_data["action_type_text"] 147 | return action 148 | 149 | torch.manual_seed(0) 150 | torch.cuda.manual_seed_all(0) 151 | torch.backends.cudnn.deterministic = True 152 | torch.backends.cudnn.benchmark = False 153 | random.seed(0) 154 | 155 | url = "http://localhost:8000/v1/chat/completions" 156 | headers = { 157 | "Content-Type": "application/json" 158 | } 159 | 160 | def chat(img_url_list: str = '', query: str = '') -> dict: 161 | content = [] 162 | for img_url in img_url_list: 163 | content.append({"type": "image_url", "image_url": {"url": img_url}}) 164 | content.append({"type": "text", "text": query}) 165 | data = { 166 | "model": "Qwen2.5-VL-72B-Instruct", 167 | "messages": [ 168 | {"role": "system", "content": "You are a powerful agent that is trained to perform some basic tasks on a smartphone."}, 169 | {"role": "user", "content": content} 170 | ], 171 | "temperature":0} 172 | 173 | response = requests.post(url, headers=headers, data=json.dumps(data)) 174 | response = response.json() 175 | response = response['choices'][0]['message']['content'] 176 | 177 | return response 178 | 179 | embedding_model_name = "bge-m3" 180 | embedding_model = HuggingFaceEmbeddings(model_name = embedding_model_name,model_kwargs={'device': 'cuda:0'}) 181 | aitw_imgs_dir = "aitw_images" 182 | aitw_test = json.load(open('aitw_annots/aitw_data_test.json', 'r')) 183 | score_average = 0 184 | 185 | test_record = {} 186 | 187 | time_count = 0 188 | time_total = 0 189 | for task, episodes in aitw_test.items(): 190 | 191 | print("Task: " + task) 192 | testing_type = task 193 | raw_document = json.load(open(f'{testing_type}_library.json', 'r')) 194 | search_document = document_transform(raw_document) 195 | 196 | test_record[task] = [] 197 | corr_action = 0 198 | corr_type = 0 199 | num_text = 0 200 | corr_text = 0 201 | num_scroll = 0 202 | corr_scroll = 0 203 | num_click = 0 204 | corr_click = 0 205 | 
num_both_click = 0 206 | corr_both_click = 0 207 | num_wrong_format = 0 208 | num = 0 209 | for j, episode in enumerate(episodes): 210 | test_record[task].append([]) 211 | previous_actions = [] 212 | global_plan = '' 213 | tag = 0 214 | for st, step in enumerate(episode): 215 | one_step_record = {} 216 | one_step_record['ep_id'] = step['ep_id'] 217 | one_step_record['step'] = step['step'] 218 | one_step_record['img_filename'] = step['img_filename'] 219 | one_step_record['goal'] = step['goal'] 220 | 221 | img_filename = step["img_filename"] + '.png' 222 | img_path = os.path.join(aitw_imgs_dir, img_filename) 223 | img_path = 'http://localhost:6666/aitw_images/' + img_filename 224 | 225 | goal = step["goal"] 226 | 227 | previous_step = "" 228 | for i, action in enumerate(previous_actions[-4:]): 229 | previous_step += 'Step' + str(i+1) + ': ' + action + ". \n" 230 | 231 | action_step = action2step_for_Qwen(step, img_path, img_filename) 232 | previous_actions.append(action_step) 233 | 234 | action_ref = action_matching.action_2_format(step) 235 | 236 | t_start = time.time() 237 | observations = get_observation(img_path, goal, previous_step) 238 | reference_actions = get_reference_actions(img_path, goal, search_document, embedding_model) 239 | 240 | if tag == 0: 241 | global_plan = get_global_plan(img_path, global_plan, previous_actions, goal) 242 | tag = 1 243 | try: 244 | plan_action = get_plan_action(img_path, goal, observations, global_plan, reference_actions,previous_step) 245 | response_ = get_execution(img_path, plan_action, reference_actions) 246 | except: 247 | print('==== ERROR ====') 248 | continue 249 | response = action_matching.convert_qwen_format(response_) 250 | time_total += time.time() - t_start 251 | time_count += 1 252 | print('average inference time: ',time_total,time_count,time_total / time_count) 253 | 254 | num += 1 255 | 256 | try: 257 | action_pred = action_matching.pred_2_format_4_mpgui(response,img_filename) 258 | annot_position = np.array( 259 | [step["annot_position"][i:i + 4] for i in range(0, len(step["annot_position"]), 4)]) 260 | check_match = action_matching.check_actions_match(action_pred["touch_point"], action_pred["lift_point"], 261 | action_pred["action_type"], action_ref["touch_point"], 262 | action_ref["lift_point"], action_ref["action_type"], 263 | annot_position) 264 | print(f'-------eposide:{j+1}/step:{st+1}----------') 265 | print('Goal: ', goal) 266 | print('Correct: ', check_match) 267 | print('Img: ', img_filename) 268 | print('History: ', previous_step) 269 | print('gt: ', action_ref,step['action_addition']) 270 | print('Observation: ', observations) 271 | print('pred: ', action_pred) 272 | print('Global Planning: \n', global_plan) 273 | print('Loacl Planning: \n', plan_action) 274 | print('Decision: \n', response_) 275 | print('---------------------------------------------') 276 | # step accuracy 277 | if check_match == True: 278 | corr_action += 1 279 | match_label = 1 280 | # logging.info("Step: " + str(j) + " right") 281 | else: 282 | match_label = 0 283 | # logging.info("Step: " + str(j) + " wrong") 284 | 285 | # type accuracy 286 | if action_pred["action_type"] == action_ref["action_type"]: 287 | corr_type += 1 288 | 289 | # text accuracy 290 | if action_ref["action_type"] == 3: 291 | num_text += 1 292 | if (action_pred["typed_text"] == action_ref["typed_text"]) or ( 293 | action_pred["typed_text"] in action_ref["typed_text"]) or ( 294 | action_ref["typed_text"] in action_pred["typed_text"]): 295 | corr_text += 1 296 | 297 | if 
action_ref["action_type"] == 4: 298 | # click accuracy 299 | if action_matching.is_tap_action(action_ref["touch_point"], action_ref["lift_point"]): 300 | num_click += 1 301 | if match_label: 302 | corr_click += 1 303 | # scroll accuracy 304 | else: 305 | num_scroll += 1 306 | if match_label: 307 | corr_scroll += 1 308 | if (action_pred["action_type"] == 4) and action_matching.is_tap_action(action_ref["touch_point"], 309 | action_ref[ 310 | "lift_point"]) and action_matching.is_tap_action( 311 | action_pred["touch_point"], action_pred["lift_point"]): 312 | num_both_click += 1 313 | if match_label: 314 | corr_both_click += 1 315 | one_step_record['action_label'] = action_ref 316 | one_step_record['action_predict'] = action_pred 317 | one_step_record['is_match'] = match_label 318 | test_record[task][-1].append(one_step_record) 319 | 320 | f_json = open(f'aitw_record.json', 'w') 321 | json.dump(test_record, f_json, ensure_ascii=False, indent=4) 322 | f_json.close() 323 | except: 324 | num_wrong_format += 1 325 | print("Step: " + str(j) + " wrong format") 326 | 327 | score_average += corr_action / num 328 | 329 | print("Action Acc: " + str(corr_action / num)) 330 | print("Type Acc: " + str(corr_type / num)) 331 | print("Text Acc: " + str(corr_text / num_text)) 332 | print("Click Acc: " + str(corr_click / num_click)) 333 | print("Scroll Acc: " + str(corr_scroll / num_scroll)) 334 | print("Both Click Acc: " + str(corr_both_click / num_both_click)) 335 | print("Num Both Click: " + str(num_both_click)) 336 | print("Num wrong format: " + str(num_wrong_format)) 337 | 338 | print("Average score: " + str(score_average / 5)) 339 | 340 | 341 | -------------------------------------------------------------------------------- /workflow/mind2web/mind2web_test.py: -------------------------------------------------------------------------------- 1 | # evaluation on mind2web 2 | import os 3 | import random 4 | import torch 5 | import json 6 | from tqdm import tqdm 7 | from transformers import AutoModelForCausalLM, AutoTokenizer 8 | from peft import AutoPeftModelForCausalLM 9 | from transformers.generation import GenerationConfig 10 | import re 11 | import logging 12 | import ast 13 | import argparse 14 | from PIL import Image 15 | import numpy as np 16 | from langchain_community.vectorstores import FAISS 17 | from langchain.schema import Document 18 | from langchain_huggingface import HuggingFaceEmbeddings 19 | from collections import deque 20 | 21 | from prompts import MIND2WEB_GLOBAL_PLANNING_PROMT, MIND2WEB_OBSERVATION_PROMT, MIND2WEB_PLANNING_PROMT, MIND2WEB_EXECUTION_PROMT, PAGE_SUMMARY_PROMPT, REFERENCE_FORMAT, ACTION_SUMMARY_PROMPT 22 | import action_matching 23 | import requests 24 | # logging.basicConfig(level=print) 25 | 26 | def get_global_plan(img_path, goal): 27 | global_plan_prompt = MIND2WEB_GLOBAL_PLANNING_PROMT.replace('',goal) 28 | global_plan = chat([img_path], global_plan_prompt) 29 | return global_plan.split('### Global Plan ###')[-1].strip() 30 | 31 | def get_execution(img_path, action_plan, reference_actions): 32 | exec_prompt = MIND2WEB_EXECUTION_PROMT.replace('',action_plan).replace('',reference_actions) 33 | response = chat([img_path], exec_prompt) 34 | execution = response.split('### Action ###')[-1].strip() 35 | thought = response.split('### Thought ###')[-1].split('### Action ###')[0].strip() 36 | return thought, execution 37 | 38 | def get_observation(img_path, goal, previous_step): 39 | if previous_step == '': 40 | previous_step = '' 41 | obs_prompt = 
MIND2WEB_OBSERVATION_PROMT.replace('',goal).replace('',previous_step) 42 | observation = chat([img_path], obs_prompt) 43 | return observation 44 | 45 | def get_plan_action(img_path, goal, observations, global_plan, reference_actions, previous_step): 46 | if previous_step == '': 47 | previous_step = '' 48 | plan_prompt = MIND2WEB_PLANNING_PROMT.replace('',goal) 49 | plan_prompt = plan_prompt.replace('',observations) 50 | plan_prompt = plan_prompt.replace('',global_plan) 51 | plan_prompt = plan_prompt.replace('',reference_actions) 52 | plan_prompt = plan_prompt.replace('',previous_step) 53 | plan_action = chat([img_path], plan_prompt) 54 | return plan_action 55 | 56 | def bfs_goals(goal_list, idx, search_document): 57 | if idx is None: return 58 | queue = deque([(idx, 0)]) 59 | visited = set() 60 | visited.add(idx) 61 | 62 | while queue: 63 | cur_node, cur_depth = queue.popleft() 64 | if cur_depth >= 3: continue 65 | 66 | nxt_node_list = search_document[cur_node].metadata['next_page_list'] 67 | for nxt_node in nxt_node_list: 68 | if nxt_node['actions'] == []: continue 69 | if nxt_node['goal'] not in goal_list: 70 | goal_list.append(nxt_node['goal']) 71 | node_idx = nxt_node['page_index'] 72 | if node_idx is not None and node_idx not in visited: 73 | visited.add(node_idx) 74 | queue.append((node_idx, cur_depth+1)) 75 | 76 | 77 | def get_reference_actions(img_path, domain, goal, search_document, embedding_model): 78 | reference_actions = '' 79 | 80 | page_summary = chat([img_path], PAGE_SUMMARY_PROMPT) 81 | if domain in list(search_document.keys()): 82 | search_keys = [domain] 83 | else: 84 | search_keys = search_document.keys() 85 | 86 | max_count_final = 10 87 | count = 0 88 | max_count = 0 89 | for search_key in search_keys: 90 | max_count += int(max_count_final // len(search_keys)) 91 | 92 | 93 | vectorstore = FAISS.from_documents(search_document[search_key], embedding_model) 94 | search_res = vectorstore.similarity_search(page_summary) 95 | 96 | for res in search_res: 97 | for actions_chain in res.metadata['next_page_list']: 98 | if len(actions_chain['actions'])==0: continue 99 | count=count+1 100 | 101 | action_string = '' 102 | for one_action in actions_chain['actions']: 103 | if one_action[-1] == '.': 104 | one_action = one_action[:-1] 105 | action_string += ', ' + one_action 106 | action_string = action_string[2:] 107 | 108 | goal_list = [actions_chain['goal']] 109 | bfs_goals(goal_list, actions_chain['page_index'], search_document[search_key]) 110 | goals_string = '' 111 | for one_goal in goal_list:#random.sample(goal_list, min(10,len(goal_list))):# 112 | if one_goal[-1] == '.': 113 | one_goal = one_goal[:-1] 114 | goals_string += '; ' + one_goal 115 | goals_string = goals_string[2:] 116 | 117 | one_reference = REFERENCE_FORMAT.format(idx = count, actions = action_string, goals = goals_string) 118 | reference_actions += one_reference 119 | if count == max_count: break 120 | if count == max_count: break 121 | 122 | return reference_actions 123 | 124 | def document_transform(raw_document): 125 | search_document = {} 126 | for type_name, pages in raw_document.items(): 127 | document = [] 128 | for idx in pages: 129 | item = pages[idx] 130 | document.append(Document(page_content = item['page_summary'], metadata = item)) 131 | search_document[type_name] = document 132 | return search_document 133 | 134 | # convert action to prediction format (and return the groundtruth bbox) 135 | 136 | def action2description(action, img_path): 137 | action_type = action['operation']['op'] 138 | assert 
action_type in ['CLICK', 'TYPE', 'SELECT'] 139 | 140 | bbox = [int(action["bbox"]["x"]), int(action["bbox"]["y"]), int(action["bbox"]["x"] + action["bbox"]["width"]), 141 | int(action["bbox"]["y"] + action["bbox"]["height"])] 142 | bbox_str = f'[{bbox[0]}, {bbox[1]}, {bbox[2]}, {bbox[3]}]' 143 | if action_type == 'CLICK': 144 | query = ACTION_SUMMARY_PROMPT['click_action_summary'].format(bbox=bbox_str) 145 | elif action_type == 'TYPE': 146 | query = ACTION_SUMMARY_PROMPT['type_action_summary'].format(content=action['operation']['value'],bbox=bbox_str) 147 | elif action_type == 'SELECT': 148 | query = ACTION_SUMMARY_PROMPT['select_action_summary'].format(content=action['operation']['value'],bbox=bbox_str) 149 | 150 | action_summary = chat([img_path], query) 151 | if action_summary[-1] == '.': 152 | action_summary = action_summary[:-1] 153 | 154 | return action_summary 155 | 156 | def action2step(action, image_size, return_bbox=False): 157 | action_type = action["operation"]["original_op"] 158 | assert action_type in ['CLICK', 'TYPE', 'SELECT', 'HOVER', 'ENTER'] 159 | 160 | point_x = action["bbox"]["x"] + (action["bbox"]["width"] / 2) 161 | point_y = action["bbox"]["y"] + (action["bbox"]["height"] / 2) 162 | click_point = [point_x / image_size[0], point_y / image_size[1]] 163 | click_point = [round(item, 3) for item in click_point] 164 | click_point = [f"{item:.2f}" for item in click_point] 165 | 166 | click_point = "({},{})".format(int(float(click_point[0])*1000), int(float(click_point[1])*1000)) 167 | 168 | if return_bbox: 169 | bbox = [action["bbox"]["x"], action["bbox"]["y"], action["bbox"]["x"] + action["bbox"]["width"], 170 | action["bbox"]["y"] + action["bbox"]["height"]] 171 | bbox = [bbox[0] / image_size[0], bbox[1] / image_size[1], bbox[2] / image_size[0], bbox[3] / image_size[1]] 172 | bbox = [round(item, 3)*1000 for item in bbox] 173 | 174 | if action_type in ['CLICK', 'HOVER', 'ENTER']: 175 | action_step = "{{\"action_type\": {}, \"click_point\": {}}}".format(4, click_point) 176 | elif action_type == 'SELECT': 177 | select_value = action["operation"]["value"] 178 | action_step = "{{\"action_type\": {}, \"click_point\": {}, \"value\": \"{}\"}}".format(2, click_point, 179 | select_value) 180 | elif action_type == 'TYPE': 181 | typed_text = action["operation"]["value"] 182 | action_step = "{{\"action_type\": {}, \"click_point\": {}, \"value\": \"{}\"}}".format(3, click_point, 183 | typed_text) 184 | 185 | if return_bbox: 186 | return action_step, bbox 187 | else: 188 | return action_step 189 | 190 | 191 | # calculate action f1 following mind2web 192 | def calculate_f1(pred, label): 193 | pred = set(pred.strip().split()) 194 | label = set(label.strip().split()) 195 | if len(pred) == 0 and len(label) == 0: 196 | return 1 197 | if len(pred) == 0 or len(label) == 0: 198 | return 0 199 | 200 | tp = len(pred & label) 201 | fp = len(pred - label) 202 | fn = len(label - pred) 203 | precision = tp / (tp + fp) 204 | recall = tp / (tp + fn) 205 | if precision == 0 or recall == 0: 206 | return 0 207 | f1 = 2 * precision * recall / (precision + recall) 208 | return f1 209 | 210 | 211 | torch.manual_seed(0) 212 | torch.cuda.manual_seed_all(0) 213 | torch.backends.cudnn.deterministic = True 214 | torch.backends.cudnn.benchmark = False 215 | random.seed(0) 216 | torch.manual_seed(0) 217 | torch.cuda.manual_seed_all(0) 218 | torch.backends.cudnn.deterministic = True 219 | torch.backends.cudnn.benchmark = False 220 | random.seed(0) 221 | 222 | url = "http://localhost:8000/v1/chat/completions" 223 | 
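# Illustrative helper (hypothetical, not used by the evaluation below): shows how
# the token-set F1 from calculate_f1() above scores the "<action_type> <value>"
# strings built later in the loop. The example strings are made up.
def _op_f1_example():
    # Both are TYPE actions (type id 3) with partially overlapping text:
    # pred tokens {'3','john','smith'} vs ref tokens {'3','jane','smith'}
    # -> tp=2, fp=1, fn=1 -> precision = recall = 2/3 -> F1 = 2/3.
    return calculate_f1('3 john smith', '3 jane smith')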
headers = { 224 | "Content-Type": "application/json" 225 | } 226 | 227 | 228 | 229 | def chat(img_url_list: str = '', query: str = '') -> dict: 230 | 231 | content = [] 232 | for img_url in img_url_list: 233 | content.append({"type": "image_url", "image_url": {"url": img_url}}) 234 | content.append({"type": "text", "text": query}) 235 | data = { 236 | "model": "Qwen2.5-VL-72B-Instruct", 237 | "messages": [ 238 | {"role": "system", "content": "You are a powerful agent that is trained to perform some basic tasks on the web page."}, 239 | {"role": "user", "content": content} 240 | ], 241 | "temperature":0} 242 | response = requests.post(url, headers=headers, data=json.dumps(data)) 243 | response = response.json() 244 | response = response['choices'][0]['message']['content'] 245 | 246 | return response 247 | 248 | parser = argparse.ArgumentParser() 249 | parser.add_argument('--task', type=str, required=True) 250 | args = parser.parse_args() 251 | 252 | embedding_model_name = "bge-m3" 253 | embedding_model = HuggingFaceEmbeddings(model_name = embedding_model_name,model_kwargs={'device': 'cuda:0'}) 254 | mind2web_imgs_dir = 'mind2web_images/' 255 | mind2web_test = json.load(open('mind2web_annots/mind2web_data_test_' + args.task + '.json', 'r')) 256 | 257 | raw_document = json.load(open('mind2web_library.json', 'r')) 258 | search_document = document_transform(raw_document) 259 | 260 | 261 | results = [] 262 | for episode in tqdm(mind2web_test): 263 | domain = episode['domain'] 264 | goal = episode["confirmed_task"] 265 | annot_id = episode["annotation_id"] 266 | previous_actions = [] 267 | results_actions = [] 268 | global_plan = '' 269 | flag = 0 270 | 271 | for j, step in enumerate(episode["actions"]): 272 | if "bbox" not in step: 273 | print("action not found") 274 | continue 275 | 276 | filename = annot_id + '-' + step["action_uid"] + '.jpg' 277 | img_path = os.path.join(mind2web_imgs_dir, filename) 278 | img_path_server = 'http://localhost:6667/mind2web_images/' + filename 279 | if not os.path.exists(img_path): 280 | print("img not found") 281 | continue 282 | image = Image.open(img_path) 283 | 284 | previous_step = "" 285 | for i, action in enumerate(previous_actions[-4:]): 286 | previous_step += 'Step' + str(i+1) + ': ' + action + ". 
\n" 287 | 288 | action_step = action2description(step, img_path_server) 289 | previous_actions.append(action_step) 290 | 291 | action_step_ref, bbox_ref = action2step(step, [1000,1000], return_bbox=True) 292 | try: 293 | action_step_ref = ast.literal_eval(action_step_ref) 294 | except: 295 | print('# error action_step_ref') 296 | continue 297 | 298 | observations = get_observation(img_path_server, goal, previous_step) 299 | reference_actions = get_reference_actions(img_path_server, domain, goal, search_document, embedding_model) 300 | if flag == 0: 301 | global_plan = get_global_plan(img_path_server, goal) 302 | flag = 1 303 | plan_action = get_plan_action(img_path_server, goal, observations, global_plan, reference_actions, previous_step) 304 | thought, response = get_execution(img_path_server, plan_action, reference_actions) 305 | 306 | step_result = {"annot_id": annot_id, "step" : j+1, "img_path": img_path, "instruction": goal, "sentence": response, 307 | "Op_match": False, "Ele_match": False, "Op_F1": [0, action_step_ref["action_type"]]} 308 | 309 | if 0 < 1: 310 | action_pred = action_matching.convert_qwen_format_mind2web(response) 311 | 312 | if action_pred["action_type"] == action_step_ref["action_type"]: 313 | step_result["Op_match"] = True 314 | 315 | click_point = action_pred["click_point"] 316 | 317 | if (bbox_ref[0] <= click_point[0] <= bbox_ref[2]) and (bbox_ref[1] <= click_point[1] <= bbox_ref[3]): 318 | step_result["Ele_match"] = True 319 | 320 | 321 | pred_str = str(action_pred["action_type"]) 322 | if action_pred["action_type"] == 3 or action_pred["action_type"] == 2: 323 | pred_str += ' ' 324 | pred_str += action_pred["value"].lower() 325 | ref_str = str(action_step_ref["action_type"]) 326 | if action_step_ref["action_type"] == 3 or action_step_ref["action_type"] == 2: 327 | ref_str += ' ' 328 | ref_str += action_step_ref["value"].lower() 329 | 330 | op_f1 = calculate_f1(pred_str, ref_str) 331 | step_result["Op_F1"][0] = op_f1 332 | 333 | print(f'-------step:{j+1}----------') 334 | print('Goal: ', goal) 335 | print('Img: ', img_path) 336 | print('History: ', previous_step) 337 | print('gt: ', step['operation']['op'],step['operation']['value'],bbox_ref) 338 | print('Observation: ', observations) 339 | print('pred: ', action_pred) 340 | print('Global Planning: \n', global_plan) 341 | print('References: \n',reference_actions) 342 | print('Loacl Planning: \n', plan_action) 343 | print('Thought: ',thought) 344 | print('Decision: \n', response) 345 | print('---------------------------------------------') 346 | 347 | 348 | results_actions.append(step_result) 349 | 350 | results.append(results_actions) 351 | 352 | f_json = open(f'mind2web_record.json', 'w') 353 | json.dump(results, f_json, ensure_ascii=False, indent=4) 354 | f_json.close() 355 | 356 | 357 | # calculate metrics 358 | num_step = 0 359 | num_episode = 0 360 | num_op = 0 361 | num_ele = 0 362 | op_f1 = {4: [], 2: [], 3: []} 363 | macro_ele_acc = {} 364 | macro_step_acc = {} 365 | macro_action_f1 = {} 366 | num_step_success = 0 367 | num_episode_success = 0 368 | for i, item in enumerate(results): 369 | macro_ele_acc[i] = [] 370 | macro_step_acc[i] = [] 371 | macro_action_f1[i] = [] 372 | num_episode += 1 373 | episode_success = True 374 | for step_result in item: 375 | num_step += 1 376 | 377 | if step_result["Op_match"]: 378 | num_op += 1 379 | 380 | if step_result["Ele_match"]: 381 | num_ele += 1 382 | macro_ele_acc[i].append(1) 383 | else: 384 | macro_ele_acc[i].append(0) 385 | 386 | if step_result["Op_F1"][1] in 
op_f1: 387 | op_f1[step_result["Op_F1"][1]].append(step_result["Op_F1"][0]) 388 | macro_action_f1[i].append(step_result["Op_F1"][0]) 389 | 390 | if step_result["Op_F1"][0] == 1.0 and step_result["Ele_match"]: 391 | num_step_success += 1 392 | macro_step_acc[i].append(1) 393 | else: 394 | macro_step_acc[i].append(0) 395 | episode_success = False 396 | 397 | if episode_success: 398 | num_episode_success += 1 399 | 400 | marco_op_f1 = np.mean([np.mean(x) for x in op_f1.values()]) 401 | 402 | print("Operation F1: " + str(marco_op_f1)) 403 | print("Element Acc: " + str(num_ele / num_step)) 404 | print("Step Success: " + str(num_step_success / num_step)) 405 | print("Episode Success: " + str(num_episode_success / num_episode)) 406 | print("Operation F1 cate: " + str([np.mean(x) for x in op_f1.values()])) 407 | 408 | macro_ele_acc = np.mean([np.mean(x) for x in macro_ele_acc.values()]) 409 | macro_step_acc = np.mean([np.mean(x) for x in macro_step_acc.values()]) 410 | macro_action_f1 = np.mean([np.mean(x) for x in macro_action_f1.values()]) 411 | print("Macro Ele Acc: " + str(macro_ele_acc)) 412 | print("Macro Op F1: " + str(macro_action_f1)) 413 | print("Macro Step SR: " + str(macro_step_acc)) 414 | 415 | -------------------------------------------------------------------------------- /workflow/aitw/action_matching.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from https://github.com/google-research/google-research/tree/master/android_in_the_wild 3 | ''' 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | import os 9 | 10 | import action_type as action_type_lib 11 | from PIL import Image 12 | 13 | _TAP_DISTANCE_THRESHOLD = 0.14 # Fraction of the screen 14 | ANNOTATION_WIDTH_AUGMENT_FRACTION = 1.4 15 | ANNOTATION_HEIGHT_AUGMENT_FRACTION = 1.4 16 | 17 | # Interval determining if an action is a tap or a swipe. 18 | _SWIPE_DISTANCE_THRESHOLD = 0.04 19 | 20 | 21 | def _yx_in_bounding_boxes( 22 | yx, bounding_boxes 23 | ): 24 | """Check if the (y,x) point is contained in each bounding box. 25 | 26 | Args: 27 | yx: The (y, x) coordinate in pixels of the point. 28 | bounding_boxes: A 2D int array of shape (num_bboxes, 4), where each row 29 | represents a bounding box: (y_top_left, x_top_left, box_height, 30 | box_width). Note: containment is inclusive of the bounding box edges. 31 | 32 | Returns: 33 | is_inside: A 1D bool array where each element specifies if the point is 34 | contained within the respective box. 35 | """ 36 | y, x = yx 37 | 38 | # `bounding_boxes` has shape (n_elements, 4); we extract each array along the 39 | # last axis into shape (n_elements, 1), then squeeze unneeded dimension. 40 | top, left, height, width = [ 41 | jnp.squeeze(v, axis=-1) for v in jnp.split(bounding_boxes, 4, axis=-1) 42 | ] 43 | 44 | # The y-axis is inverted for AndroidEnv, so bottom = top + height. 45 | bottom, right = top + height, left + width 46 | 47 | return jnp.logical_and(y >= top, y <= bottom) & jnp.logical_and( 48 | x >= left, x <= right) 49 | 50 | 51 | def _resize_annotation_bounding_boxes( 52 | annotation_positions, annotation_width_augment_fraction, 53 | annotation_height_augment_fraction): 54 | """Resize the bounding boxes by the given fractions. 55 | 56 | Args: 57 | annotation_positions: Array of shape (N, 4), where each row represents the 58 | (y, x, height, width) of the bounding boxes. 59 | annotation_width_augment_fraction: The fraction to augment the box widths, 60 | E.g., 1.4 == 240% total increase. 
61 | annotation_height_augment_fraction: Same as described for width, but for box 62 | height. 63 | 64 | Returns: 65 | Resized bounding box. 66 | 67 | """ 68 | height_change = ( 69 | annotation_height_augment_fraction * annotation_positions[:, 2]) 70 | width_change = ( 71 | annotation_width_augment_fraction * annotation_positions[:, 3]) 72 | 73 | # Limit bounding box positions to the screen. 74 | resized_annotations = jnp.stack([ 75 | jnp.maximum(0, annotation_positions[:, 0] - (height_change / 2)), 76 | jnp.maximum(0, annotation_positions[:, 1] - (width_change / 2)), 77 | jnp.minimum(1, annotation_positions[:, 2] + height_change), 78 | jnp.minimum(1, annotation_positions[:, 3] + width_change), 79 | ], 80 | axis=1) 81 | return resized_annotations 82 | 83 | 84 | def is_tap_action(normalized_start_yx, 85 | normalized_end_yx): 86 | distance = jnp.linalg.norm( 87 | jnp.array(normalized_start_yx) - jnp.array(normalized_end_yx)) 88 | return distance <= _SWIPE_DISTANCE_THRESHOLD 89 | 90 | 91 | def _is_non_dual_point_action(action_type): 92 | return jnp.not_equal(action_type, action_type_lib.ActionType.DUAL_POINT) 93 | 94 | 95 | def _check_tap_actions_match( 96 | tap_1_yx, 97 | tap_2_yx, 98 | annotation_positions, 99 | matching_tap_distance_threshold_screen_percentage, 100 | annotation_width_augment_fraction, 101 | annotation_height_augment_fraction, 102 | ): 103 | """Determines if two tap actions are the same.""" 104 | resized_annotation_positions = _resize_annotation_bounding_boxes( 105 | annotation_positions, 106 | annotation_width_augment_fraction, 107 | annotation_height_augment_fraction, 108 | ) 109 | 110 | # Check if the ground truth tap action falls in an annotation's bounding box. 111 | tap1_in_box = _yx_in_bounding_boxes(tap_1_yx, resized_annotation_positions) 112 | tap2_in_box = _yx_in_bounding_boxes(tap_2_yx, resized_annotation_positions) 113 | both_in_box = jnp.max(tap1_in_box & tap2_in_box) 114 | 115 | # If the ground-truth tap action falls outside any of the annotation 116 | # bounding boxes or one of the actions is inside a bounding box and the other 117 | # is outside bounding box or vice versa, compare the points using Euclidean 118 | # distance. 119 | within_threshold = ( 120 | jnp.linalg.norm(jnp.array(tap_1_yx) - jnp.array(tap_2_yx)) 121 | <= matching_tap_distance_threshold_screen_percentage 122 | ) 123 | return jnp.logical_or(both_in_box, within_threshold) 124 | 125 | 126 | def _check_drag_actions_match( 127 | drag_1_touch_yx, 128 | drag_1_lift_yx, 129 | drag_2_touch_yx, 130 | drag_2_lift_yx, 131 | ): 132 | """Determines if two drag actions are the same.""" 133 | # Store drag deltas (the change in the y and x coordinates from touch to 134 | # lift), magnitudes, and the index of the main axis, which is the axis with 135 | # the greatest change in coordinate value (e.g. a drag starting at (0, 0) and 136 | # ending at (0.3, 0.5) has a main axis index of 1). 
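  # NOTE: only the dominant axis of each drag is compared below; the sign of the
  # movement is ignored. E.g. the fixed 'scroll down' encoding from
  # action_2_format, touch (y, x) = (0.8, 0.5) -> lift (0.2, 0.5), and the
  # 'scroll up' encoding (0.2, 0.5) -> (0.8, 0.5) both have main axis 0 and are
  # therefore treated as matching drags.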
137 | drag_1_deltas = drag_1_lift_yx - drag_1_touch_yx 138 | drag_1_magnitudes = jnp.abs(drag_1_deltas) 139 | drag_1_main_axis = np.argmax(drag_1_magnitudes) 140 | drag_2_deltas = drag_2_lift_yx - drag_2_touch_yx 141 | drag_2_magnitudes = jnp.abs(drag_2_deltas) 142 | drag_2_main_axis = np.argmax(drag_2_magnitudes) 143 | 144 | return jnp.equal(drag_1_main_axis, drag_2_main_axis) #只判断滑动的方向 145 | 146 | 147 | def check_actions_match( 148 | action_1_touch_yx, 149 | action_1_lift_yx, 150 | action_1_action_type, 151 | action_2_touch_yx, 152 | action_2_lift_yx, 153 | action_2_action_type, 154 | annotation_positions, 155 | tap_distance_threshold = _TAP_DISTANCE_THRESHOLD, 156 | annotation_width_augment_fraction = ANNOTATION_WIDTH_AUGMENT_FRACTION, 157 | annotation_height_augment_fraction = ANNOTATION_HEIGHT_AUGMENT_FRACTION, 158 | ): 159 | """Determines if two actions are considered to be the same. 160 | 161 | Two actions being "the same" is defined here as two actions that would result 162 | in a similar screen state. 163 | 164 | Args: 165 | action_1_touch_yx: The (y, x) coordinates of the first action's touch. 166 | action_1_lift_yx: The (y, x) coordinates of the first action's lift. 167 | action_1_action_type: The action type of the first action. 168 | action_2_touch_yx: The (y, x) coordinates of the second action's touch. 169 | action_2_lift_yx: The (y, x) coordinates of the second action's lift. 170 | action_2_action_type: The action type of the second action. 171 | annotation_positions: The positions of the UI annotations for the screen. It 172 | is A 2D int array of shape (num_bboxes, 4), where each row represents a 173 | bounding box: (y_top_left, x_top_left, box_height, box_width). Note that 174 | containment is inclusive of the bounding box edges. 175 | tap_distance_threshold: The threshold that determines if two taps result in 176 | a matching screen state if they don't fall the same bounding boxes. 177 | annotation_width_augment_fraction: The fraction to increase the width of the 178 | bounding box by. 179 | annotation_height_augment_fraction: The fraction to increase the height of 180 | of the bounding box by. 181 | 182 | Returns: 183 | A boolean representing whether the two given actions are the same or not. 184 | """ 185 | action_1_touch_yx = jnp.asarray(action_1_touch_yx) 186 | action_1_lift_yx = jnp.asarray(action_1_lift_yx) 187 | action_2_touch_yx = jnp.asarray(action_2_touch_yx) 188 | action_2_lift_yx = jnp.asarray(action_2_lift_yx) 189 | 190 | # Checks if at least one of the actions is global (i.e. not DUAL_POINT), 191 | # because if that is the case, only the actions' types need to be compared. 
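  # In this repo these global actions are e.g. TYPE (3), PRESS_BACK (5),
  # PRESS_HOME (6), ENTER (7) and COMPLETE (10), as produced by
  # convert_qwen_format below; for them only the action type ids are compared.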
192 | has_non_dual_point_action = jnp.logical_or( 193 | _is_non_dual_point_action(action_1_action_type), 194 | _is_non_dual_point_action(action_2_action_type), 195 | ) 196 | #print("non dual point: "+str(has_non_dual_point_action)) 197 | 198 | different_dual_point_types = jnp.logical_xor( 199 | is_tap_action(action_1_touch_yx, action_1_lift_yx), 200 | is_tap_action(action_2_touch_yx, action_2_lift_yx), 201 | ) 202 | #print("different dual type: "+str(different_dual_point_types)) 203 | 204 | is_tap = jnp.logical_and( 205 | is_tap_action(action_1_touch_yx, action_1_lift_yx), 206 | is_tap_action(action_2_touch_yx, action_2_lift_yx), 207 | ) 208 | #print("is tap: "+str(is_tap)) 209 | 210 | taps_match = _check_tap_actions_match( 211 | action_1_touch_yx, 212 | action_2_touch_yx, 213 | annotation_positions, 214 | tap_distance_threshold, 215 | annotation_width_augment_fraction, 216 | annotation_height_augment_fraction, 217 | ) 218 | #print("tap match: "+str(taps_match)) 219 | 220 | taps_match = jnp.logical_and(is_tap, taps_match) 221 | #print("tap match: "+str(taps_match)) 222 | 223 | drags_match = _check_drag_actions_match( 224 | action_1_touch_yx, action_1_lift_yx, action_2_touch_yx, action_2_lift_yx 225 | ) 226 | drags_match = jnp.where(is_tap, False, drags_match) 227 | #print("drag match: "+str(drags_match)) 228 | 229 | return jnp.where( 230 | has_non_dual_point_action, 231 | jnp.equal(action_1_action_type, action_2_action_type), 232 | jnp.where( 233 | different_dual_point_types, 234 | False, 235 | jnp.logical_or(taps_match, drags_match), 236 | ), 237 | ) 238 | 239 | 240 | def action_2_format(step_data): 241 | # 把test数据集中的动作格式转换为计算matching score的格式 242 | action_type = step_data["action_type_id"] 243 | 244 | if action_type == 4: 245 | if step_data["action_type_text"] == 'click': # 点击 246 | touch_point = step_data["touch"] 247 | lift_point = step_data["lift"] 248 | else: # 上下左右滑动 249 | if step_data["action_type_text"] == 'scroll down': 250 | touch_point = [0.5, 0.8] 251 | lift_point = [0.5, 0.2] 252 | elif step_data["action_type_text"] == 'scroll up': 253 | touch_point = [0.5, 0.2] 254 | lift_point = [0.5, 0.8] 255 | elif step_data["action_type_text"] == 'scroll left': 256 | touch_point = [0.2, 0.5] 257 | lift_point = [0.8, 0.5] 258 | elif step_data["action_type_text"] == 'scroll right': 259 | touch_point = [0.8, 0.5] 260 | lift_point = [0.2, 0.5] 261 | else: 262 | touch_point = [-1.0, -1.0] 263 | lift_point = [-1.0, -1.0] 264 | 265 | if action_type == 3: 266 | typed_text = step_data["type_text"] 267 | else: 268 | typed_text = "" 269 | 270 | action = {"action_type": action_type, "touch_point": touch_point, "lift_point": lift_point, 271 | "typed_text": typed_text} 272 | 273 | action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]] 274 | action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]] 275 | action["typed_text"] = action["typed_text"].lower() 276 | 277 | return action 278 | 279 | 280 | def pred_2_format(step_data): 281 | # 把模型输出的内容转换为计算action_matching的格式 282 | action_type = step_data["action_type"] 283 | 284 | if action_type == 4: # 点击 285 | action_type_new = 4 286 | touch_point = step_data["click_point"] 287 | lift_point = step_data["click_point"] 288 | typed_text = "" 289 | elif action_type == 0: 290 | action_type_new = 4 291 | touch_point = [0.5, 0.8] 292 | lift_point = [0.5, 0.2] 293 | typed_text = "" 294 | elif action_type == 1: 295 | action_type_new = 4 296 | touch_point = [0.5, 0.2] 297 | lift_point = [0.5, 0.8] 298 | typed_text = "" 299 | elif 
action_type == 8: 300 | action_type_new = 4 301 | touch_point = [0.2, 0.5] 302 | lift_point = [0.8, 0.5] 303 | typed_text = "" 304 | elif action_type == 9: 305 | action_type_new = 4 306 | touch_point = [0.8, 0.5] 307 | lift_point = [0.2, 0.5] 308 | typed_text = "" 309 | else: 310 | action_type_new = action_type 311 | touch_point = [-1.0, -1.0] 312 | lift_point = [-1.0, -1.0] 313 | typed_text = "" 314 | if action_type_new == 3: 315 | typed_text = step_data["typed_text"] 316 | 317 | action = {"action_type": action_type_new, "touch_point": touch_point, "lift_point": lift_point, 318 | "typed_text": typed_text} 319 | 320 | action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]] 321 | action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]] 322 | action["typed_text"] = action["typed_text"].lower() 323 | 324 | return action 325 | 326 | def pred_2_format_4_mpgui(step_data,img_filename=''): 327 | # 把模型输出的内容转换为计算action_matching的格式 328 | action_type = step_data["action_type"] 329 | 330 | if img_filename != '': 331 | img_path = 'AITW_simplified/aitw_images/' + img_filename 332 | w, h = Image.open(img_path).size 333 | else: 334 | w, h = 1000, 1000 335 | 336 | 337 | if action_type == 4: # 点击 338 | action_type_new = 4 339 | 340 | if 'click_point' in step_data: 341 | touch_point = step_data["click_point"] 342 | lift_point = step_data["click_point"] 343 | # for MP-GUI 344 | if touch_point[0] > 1.: 345 | touch_point = [touch_point[0]/w, touch_point[1]/h] 346 | if lift_point[0] > 1: 347 | lift_point = [lift_point[0]/w, lift_point[1]/h] 348 | else: 349 | print(f'$$ error pred step: {step_data}') 350 | touch_point = [0., 0.] 351 | lift_point = [0., 0.] 352 | typed_text = "" 353 | elif action_type == 0: 354 | action_type_new = 4 355 | touch_point = [0.5, 0.8] 356 | lift_point = [0.5, 0.2] 357 | typed_text = "" 358 | elif action_type == 1: 359 | action_type_new = 4 360 | touch_point = [0.5, 0.2] 361 | lift_point = [0.5, 0.8] 362 | typed_text = "" 363 | elif action_type == 8: 364 | action_type_new = 4 365 | touch_point = [0.2, 0.5] 366 | lift_point = [0.8, 0.5] 367 | typed_text = "" 368 | elif action_type == 9: 369 | action_type_new = 4 370 | touch_point = [0.8, 0.5] 371 | lift_point = [0.2, 0.5] 372 | typed_text = "" 373 | else: 374 | action_type_new = action_type 375 | touch_point = [-1.0, -1.0] 376 | lift_point = [-1.0, -1.0] 377 | typed_text = "" 378 | if action_type_new == 3: 379 | typed_text = step_data["typed_text"] 380 | 381 | action = {"action_type": action_type_new, "touch_point": touch_point, "lift_point": lift_point, 382 | "typed_text": typed_text} 383 | 384 | action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]] 385 | action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]] 386 | action["typed_text"] = action["typed_text"].lower() 387 | 388 | return action 389 | 390 | def convert_qwen_format(response): 391 | pred_action = response 392 | # pred_action = response.split('### Action ###')[-1].strip() 393 | # print(pred_action) 394 | item = {} 395 | if 'Click' in pred_action: 396 | action_id = 4 397 | try: 398 | x, y = pred_action.split('(')[-1].split(')')[0].split(',') 399 | x, y = int(x), int(y) 400 | except: 401 | x,y = 0, 0 402 | item = { 403 | 'action_type': action_id, 404 | 'click_point': (x,y) 405 | } 406 | elif 'Scroll("up")' in pred_action: 407 | item = { 408 | 'action_type': 1 409 | } 410 | elif 'Scroll("down")' in pred_action: 411 | item = { 412 | 'action_type': 0 413 | } 414 | elif 'Scroll("left")' in pred_action: 
415 | item = { 416 | 'action_type': 8 417 | } 418 | elif 'Scroll("right")' in pred_action: 419 | item = { 420 | 'action_type': 9 421 | } 422 | elif 'Type' in pred_action: 423 | text = pred_action.split('("')[-1].split('")')[0] 424 | item = { 425 | 'action_type': 3, 426 | 'typed_text': text 427 | } 428 | elif 'Complete' in pred_action: 429 | item ={ 430 | 'action_type': 10 431 | } 432 | elif 'Back' in pred_action: 433 | item ={ 434 | 'action_type': 5 435 | } 436 | elif 'Home' in pred_action: 437 | item ={ 438 | 'action_type': 6 439 | } 440 | elif 'Enter' in pred_action: 441 | item ={ 442 | 'action_type': 7 443 | } 444 | else: 445 | item ={ 446 | 'action_type': 2 #error 447 | } 448 | return item 449 | 450 | # def convert_qwen_format_mind2web(response): 451 | # pred_action = response#.split('### Action')[-1].strip() 452 | 453 | # item = {} 454 | 455 | # if 'Click' in pred_action: 456 | # try: 457 | # x, y = pred_action.split('(')[-1].split(')')[0].split(',') 458 | # x, y = int(x), int(y) 459 | # click_point = (x, y) 460 | # except: 461 | # x,y = 0, 0 462 | # click_point = (x, y) 463 | # item = {"action_type": 4, "click_point": click_point} 464 | 465 | # elif 'Type' in pred_action: 466 | # try: 467 | # # Type(x,y,"typed_text") 468 | # s = pred_action.split('(')[-1] 469 | # x, y, tp_txt = s.split(',') 470 | # x, y = int(x), int(y) 471 | # click_point = (x, y) 472 | # select_value = tp_txt.replace('"','').replace(')', '') 473 | # except: 474 | # click_point = (0,0) 475 | # select_value = '' 476 | # item = {"action_type": 3, "click_point": click_point, "value": select_value} 477 | 478 | # elif 'Select' in pred_action: 479 | # try: 480 | # s = pred_action.split('(')[-1] 481 | # x, y, tp_txt = s.split(',') 482 | # x, y = int(x), int(y) 483 | # click_point = (x, y) 484 | # select_value = tp_txt.replace('"','').replace(')', '') 485 | # except: 486 | # click_point = (0,0) 487 | # select_value = '' 488 | # item = {"action_type": 3, "click_point": click_point, "value": select_value} 489 | # else: 490 | # item = {"action_type": 0, "click_point": (0,0)} 491 | 492 | # return item 493 | def convert_qwen_format_mind2web(response): 494 | pred_action = response#.split('### Action')[-1].strip() 495 | 496 | item = {} 497 | 498 | if 'Click' in pred_action: 499 | try: 500 | # print(pred_action) 501 | s = pred_action[pred_action.find('(')+1:pred_action.rfind(')')] 502 | x, y = s.split(',')[0], s.split(',')[1] 503 | x, y = int(x), int(y) 504 | click_point = (x, y) 505 | except: 506 | x,y = 0, 0 507 | click_point = (x, y) 508 | item = {"action_type": 4, "click_point": click_point} 509 | 510 | elif 'Type' in pred_action: 511 | try: 512 | # print(pred_action) 513 | s = pred_action[pred_action.find('(')+1:pred_action.rfind(')')] 514 | x, y, tp_txt = s.split(',')[0], s.split(',')[1], ','.join(s.split(',')[2:]) 515 | x, y = int(x), int(y) 516 | click_point = (x, y) 517 | typed_text = tp_txt[tp_txt.find('"')+1:tp_txt.rfind('"')] 518 | # print(select_value) 519 | except: 520 | click_point = (0,0) 521 | typed_text = '' 522 | item = {"action_type": 3, "click_point": click_point, "value": typed_text} 523 | 524 | elif 'Select' in pred_action: 525 | try: 526 | # print(pred_action) 527 | s = pred_action[pred_action.find('(')+1:pred_action.rfind(')')] 528 | x, y, tp_txt = s.split(',')[0], s.split(',')[1], ','.join(s.split(',')[2:]) 529 | x, y = int(x), int(y) 530 | click_point = (x, y) 531 | select_value = tp_txt[tp_txt.find('"')+1:tp_txt.rfind('"')] 532 | # print(select_value) 533 | except: 534 | click_point = (0,0) 535 | 
select_value = '' 536 | item = {"action_type": 2, "click_point": click_point, "value": select_value} 537 | else: 538 | item = {"action_type": 0, "click_point": (0,0)} 539 | 540 | return item 541 | 542 | def convert_qwen_format_mind2web_InternVL(response): 543 | pred_action = response#.split('### Action')[-1].strip() 544 | 545 | item = {} 546 | 547 | if 'Click' in pred_action: 548 | try: 549 | # print(pred_action) 550 | s = pred_action[pred_action.find('(')+1:pred_action.rfind(')')] 551 | [x1, y1, x2, y2] = s.split(',') 552 | x, y = (int(x1)+int(x2))/2, (int(y1)+int(y2))/2 553 | click_point = (x, y) 554 | except: 555 | x,y = 0, 0 556 | click_point = (x, y) 557 | item = {"action_type": 4, "click_point": click_point} 558 | 559 | elif 'Type' in pred_action: 560 | try: 561 | # print(pred_action) 562 | s = pred_action[pred_action.find('(')+1:pred_action.rfind(')')] 563 | x1, y1, x2, y2, tp_txt = s.split(',')[0], s.split(',')[1], s.split(',')[2], s.split(',')[3],','.join(s.split(',')[4:]) 564 | x, y = (int(x1)+int(x2))/2, (int(y1)+int(y2))/2 565 | click_point = (x, y) 566 | typed_text = tp_txt[tp_txt.find('"')+1:tp_txt.rfind('"')] 567 | # print(select_value) 568 | except: 569 | click_point = (0,0) 570 | typed_text = '' 571 | item = {"action_type": 3, "click_point": click_point, "value": typed_text} 572 | 573 | elif 'Select' in pred_action: 574 | try: 575 | # print(pred_action) 576 | s = pred_action[pred_action.find('(')+1:pred_action.rfind(')')] 577 | x1, y1, x2, y2, tp_txt = s.split(',')[0], s.split(',')[1], s.split(',')[2], s.split(',')[3],','.join(s.split(',')[4:]) 578 | x, y = (int(x1)+int(x2))/2, (int(y1)+int(y2))/2 579 | click_point = (x, y) 580 | select_value = tp_txt[tp_txt.find('"')+1:tp_txt.rfind('"')] 581 | # print(select_value) 582 | except: 583 | click_point = (0,0) 584 | select_value = '' 585 | item = {"action_type": 2, "click_point": click_point, "value": select_value} 586 | else: 587 | item = {"action_type": 0, "click_point": (0,0)} 588 | 589 | return item 590 | 591 | def simple_decode(gt, img_path=None): 592 | idx = gt.find(':') 593 | if idx == -1: 594 | action = gt 595 | info = "" 596 | else: 597 | action = gt[:idx].strip() 598 | info = gt[idx+1:].strip() 599 | if action in ['CLICK', "LONG_PRESS"]: 600 | info = eval(info) 601 | if img_path is not None: 602 | img_path = 'GUI-Odyssey-master/data/screenshots/' + img_path 603 | w, h = Image.open(img_path).size 604 | info = (info[0] / w * 1000, info[1] / h * 1000) 605 | 606 | return {"action": action, "info": info} 607 | 608 | 609 | TEXT_ANLS_THRESHOLD = 0.5 610 | CLICK_COORD_THRESHOLD = 0.14 611 | 612 | def levenshtein_distance(s1, s2): 613 | if len(s1) > len(s2): 614 | s1, s2 = s2, s1 615 | 616 | distances = range(len(s1) + 1) 617 | for i2, c2 in enumerate(s2): 618 | distances_ = [i2+1] 619 | for i1, c1 in enumerate(s1): 620 | if c1 == c2: 621 | distances_.append(distances[i1]) 622 | else: 623 | distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1]))) 624 | distances = distances_ 625 | return distances[-1] 626 | 627 | 628 | def text_matching(gt, pred): 629 | gt = gt.strip() 630 | pred = pred.strip() 631 | if gt in pred or pred in gt: 632 | return True 633 | 634 | dist = levenshtein_distance(gt, pred) 635 | length = max(len(gt), len(pred)) 636 | value = 0.0 if length == 0 else float(dist) / float(length) 637 | value = 1 - value 638 | return value >= TEXT_ANLS_THRESHOLD 639 | 640 | 641 | def click_matching(gt_info, pred_info): 642 | if type(pred_info) == str: 643 | pred_info = eval(pred_info) 644 | if 
type(gt_info) == str: 645 | gt_info = eval(gt_info) 646 | 647 | pred = np.asarray(pred_info) / 1000 648 | gt = np.asarray(gt_info) / 1000 649 | 650 | return np.linalg.norm(pred - gt) <= CLICK_COORD_THRESHOLD 651 | 652 | 653 | 654 | def action_matching(pred_action, pred_info, gt_action, gt_info): 655 | pred_action = pred_action.strip() 656 | if type(pred_info) == str: 657 | pred_info = pred_info.strip() 658 | gt_action = gt_action.strip() 659 | if type(gt_info) == str: 660 | gt_info = gt_info.strip() 661 | 662 | if pred_action != gt_action: 663 | return {'is_correct': 'no', 'info': 'action_fail'} 664 | 665 | if gt_action not in ['SCROLL', 'CLICK', 'TYPE', 'LONG_PRESS']: 666 | return {'is_correct': 'yes', 'info': 'action_correct'} 667 | 668 | elif gt_action == 'TYPE': 669 | text_flag = text_matching(gt_info, pred_info) 670 | 671 | if text_flag: 672 | return {'is_correct': 'yes', 'info': 'type_correct'} 673 | else: 674 | return {'is_correct': 'no', 'info': 'type_fail'} 675 | 676 | elif gt_action == 'SCROLL': 677 | if gt_info.lower() == pred_info.lower(): 678 | return {'is_correct': 'yes', 'info': 'scroll_correct'} 679 | else: 680 | return {'is_correct': 'no', 'info': 'scroll_fail'} 681 | 682 | elif gt_action == 'CLICK' or gt_action == 'LONG_PRESS': 683 | click_flag = click_matching(gt_info, pred_info) 684 | 685 | if click_flag: 686 | return {'is_correct': 'yes', 'info': 'click_correct'} 687 | else: 688 | return {'is_correct': 'no', 'info': 'click_fail'} 689 | 690 | else: 691 | raise ValueError('Invalid action type') 692 | 693 | def stat_result(eval_dict, metric): 694 | text_correct = sum([1 for _ in eval_dict if _['info'] == 'type_correct']) 695 | type_correct = sum([1 for _ in eval_dict if _['info'] != 'action_fail']) 696 | text_total = sum([1 for _ in eval_dict if _['info'].startswith('type_')]) 697 | 698 | if metric == 'macro': 699 | action_correct = sum([1 for _ in eval_dict if _['is_correct'] == 'yes']) 700 | AMS = round(action_correct / len(eval_dict) * 100, 2) 701 | SR_cnt, SR_tot, SR = check_SR(eval_dict) 702 | elif metric == 'micro': 703 | task_cate_dict = {} 704 | acc_list = [] 705 | SR_list = [] 706 | # print(eval_dict) 707 | for sample in eval_dict: 708 | cat = sample['more_info']['category'] 709 | if cat not in task_cate_dict: 710 | task_cate_dict[cat] = [] 711 | task_cate_dict[cat].append(sample) 712 | 713 | 714 | # assert len(task_cate_dict) == 6 #总共6个类别的数据,跑部分数据可以注释掉 715 | for k, v in task_cate_dict.items(): 716 | SR_cnt, SR_tot, SR = check_SR(v) 717 | SR_list.append((SR)) 718 | acc = round(sum([1 for x in v if x['is_correct'] == 'yes']) / len(v) * 100, 2) 719 | acc_list.append(acc) 720 | print(f'category: {k}, AMS: {acc}, SR: {SR}') 721 | 722 | AMS = np.round(np.mean(acc_list), 2) 723 | SR = np.round(np.mean(SR_list), 2) 724 | 725 | else: 726 | raise ValueError(f'No metric {metric} found.') 727 | 728 | info = { 729 | 'AMS': AMS, 730 | 'SR': SR, 731 | 'total': len(eval_dict), 732 | 'action_type': '{} / {} = {:.2f}'.format(type_correct, len(eval_dict), type_correct / len(eval_dict) * 100), 733 | 'text': '{} / {} = {:.2f}'.format(text_correct, text_total, text_correct / text_total * 100), 734 | } 735 | 736 | return info 737 | 738 | def check_SR(eval_dict): 739 | episode_dict = {} 740 | steps_map = {} 741 | for data in eval_dict: 742 | if 'img' in data: img = data['img'] 743 | elif 'image' in data: img = data['image'] 744 | else: img = data['question'].split('')[0].split('')[1] 745 | img = os.path.basename(img) 746 | tail = img.split('_')[-1] 747 | episode = 
img.replace(f'_{tail}', '') 748 | if episode not in episode_dict: 749 | episode_dict[episode] = [] 750 | else: 751 | assert steps_map[episode] == data['more_info']['step_length'] 752 | 753 | info = data['is_correct'] 754 | episode_dict[episode].append(info) 755 | steps_map[episode] = data['more_info']['step_length'] 756 | 757 | cnt, tot = 0, 0 758 | # print('=== ',episode_dict) 759 | for k, v in episode_dict.items(): 760 | if len(v) != steps_map[k]: 761 | print(f'step length of {k} does not match.') 762 | continue 763 | tot += 1 764 | v = list(set(v)) 765 | if len(v) == 1 and v[0] == 'yes': 766 | cnt += 1 767 | 768 | SR = round(cnt / tot * 100, 2) 769 | print(f'total episode: {tot}, successful episode: {cnt}, SR: {SR}') 770 | return cnt, tot, SR 771 | 772 | def odyssey_action_matching_evaluation(pred_output, metric='macro'): 773 | eval_dict = [] 774 | for idx, sample in enumerate(pred_output): 775 | 776 | question, pred, gt, more_info = sample['question'], sample['pred'], sample['gt'], sample['more_info'] 777 | # sample_eval_dict = {'question': question, 'pred': str(pred), 'gt': str(gt), 'more_info': more_info} 778 | sample_eval_dict = {'question': question, 'pred': str(pred), 'gt': str(gt), 'more_info': more_info, 'img':sample['img']} 779 | 780 | gt_simple_info = simple_decode(gt) 781 | gt_action = gt_simple_info['action'] 782 | gt_info = gt_simple_info['info'] 783 | 784 | try: 785 | pred_simple_info = simple_decode(pred, sample['img']) 786 | # print('pred_simple_info:', pred_simple_info) 787 | pred_action = pred_simple_info['action'] 788 | pred_info = pred_simple_info['info'] 789 | except: 790 | # print('### eval err:', idx, pred) 791 | log_info = {'is_correct': 'no', 'info': 'decode invalid'} 792 | sample_eval_dict.update(log_info) 793 | eval_dict.append(sample_eval_dict) 794 | continue 795 | 796 | try: 797 | check_match = action_matching(pred_action, pred_info, gt_action, gt_info) 798 | except Exception as exc: 799 | print('$$$ eval err:', gt, pred, exc) 800 | check_match = {'is_correct': 'no', 'info': 'match invalid'} 801 | 802 | sample_eval_dict.update(check_match) 803 | eval_dict.append(sample_eval_dict) 804 | 805 | # print('===== ',eval_dict) 806 | info = stat_result(eval_dict, metric) 807 | metrics = {"info": info, "pred": eval_dict} 808 | return metrics -------------------------------------------------------------------------------- /workflow/mind2web/action_matching.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from https://github.com/google-research/google-research/tree/master/android_in_the_wild 3 | ''' 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | import os 9 | 10 | import action_type as action_type_lib 11 | from PIL import Image 12 | 13 | _TAP_DISTANCE_THRESHOLD = 0.14 # Fraction of the screen 14 | ANNOTATION_WIDTH_AUGMENT_FRACTION = 1.4 15 | ANNOTATION_HEIGHT_AUGMENT_FRACTION = 1.4 16 | 17 | # Interval determining if an action is a tap or a swipe. 18 | _SWIPE_DISTANCE_THRESHOLD = 0.04 19 | 20 | 21 | def _yx_in_bounding_boxes( 22 | yx, bounding_boxes 23 | ): 24 | """Check if the (y,x) point is contained in each bounding box. 25 | 26 | Args: 27 | yx: The (y, x) coordinate in pixels of the point. 28 | bounding_boxes: A 2D int array of shape (num_bboxes, 4), where each row 29 | represents a bounding box: (y_top_left, x_top_left, box_height, 30 | box_width). Note: containment is inclusive of the bounding box edges. 
31 | 32 | Returns: 33 | is_inside: A 1D bool array where each element specifies if the point is 34 | contained within the respective box. 35 | """ 36 | y, x = yx 37 | 38 | # `bounding_boxes` has shape (n_elements, 4); we extract each array along the 39 | # last axis into shape (n_elements, 1), then squeeze unneeded dimension. 40 | top, left, height, width = [ 41 | jnp.squeeze(v, axis=-1) for v in jnp.split(bounding_boxes, 4, axis=-1) 42 | ] 43 | 44 | # The y-axis is inverted for AndroidEnv, so bottom = top + height. 45 | bottom, right = top + height, left + width 46 | 47 | return jnp.logical_and(y >= top, y <= bottom) & jnp.logical_and( 48 | x >= left, x <= right) 49 | 50 | 51 | def _resize_annotation_bounding_boxes( 52 | annotation_positions, annotation_width_augment_fraction, 53 | annotation_height_augment_fraction): 54 | """Resize the bounding boxes by the given fractions. 55 | 56 | Args: 57 | annotation_positions: Array of shape (N, 4), where each row represents the 58 | (y, x, height, width) of the bounding boxes. 59 | annotation_width_augment_fraction: The fraction to augment the box widths, 60 | E.g., 1.4 == 240% total increase. 61 | annotation_height_augment_fraction: Same as described for width, but for box 62 | height. 63 | 64 | Returns: 65 | Resized bounding box. 66 | 67 | """ 68 | height_change = ( 69 | annotation_height_augment_fraction * annotation_positions[:, 2]) 70 | width_change = ( 71 | annotation_width_augment_fraction * annotation_positions[:, 3]) 72 | 73 | # Limit bounding box positions to the screen. 74 | resized_annotations = jnp.stack([ 75 | jnp.maximum(0, annotation_positions[:, 0] - (height_change / 2)), 76 | jnp.maximum(0, annotation_positions[:, 1] - (width_change / 2)), 77 | jnp.minimum(1, annotation_positions[:, 2] + height_change), 78 | jnp.minimum(1, annotation_positions[:, 3] + width_change), 79 | ], 80 | axis=1) 81 | return resized_annotations 82 | 83 | 84 | def is_tap_action(normalized_start_yx, 85 | normalized_end_yx): 86 | distance = jnp.linalg.norm( 87 | jnp.array(normalized_start_yx) - jnp.array(normalized_end_yx)) 88 | return distance <= _SWIPE_DISTANCE_THRESHOLD 89 | 90 | 91 | def _is_non_dual_point_action(action_type): 92 | return jnp.not_equal(action_type, action_type_lib.ActionType.DUAL_POINT) 93 | 94 | 95 | def _check_tap_actions_match( 96 | tap_1_yx, 97 | tap_2_yx, 98 | annotation_positions, 99 | matching_tap_distance_threshold_screen_percentage, 100 | annotation_width_augment_fraction, 101 | annotation_height_augment_fraction, 102 | ): 103 | """Determines if two tap actions are the same.""" 104 | resized_annotation_positions = _resize_annotation_bounding_boxes( 105 | annotation_positions, 106 | annotation_width_augment_fraction, 107 | annotation_height_augment_fraction, 108 | ) 109 | 110 | # Check if the ground truth tap action falls in an annotation's bounding box. 111 | tap1_in_box = _yx_in_bounding_boxes(tap_1_yx, resized_annotation_positions) 112 | tap2_in_box = _yx_in_bounding_boxes(tap_2_yx, resized_annotation_positions) 113 | both_in_box = jnp.max(tap1_in_box & tap2_in_box) 114 | 115 | # If the ground-truth tap action falls outside any of the annotation 116 | # bounding boxes or one of the actions is inside a bounding box and the other 117 | # is outside bounding box or vice versa, compare the points using Euclidean 118 | # distance. 
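  # For instance, with the default 0.14 screen-fraction threshold, taps at
  # (0.50, 0.50) and (0.60, 0.60) are ~0.141 apart, so they only match if both
  # fall inside the same augmented bounding box.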
119 |   within_threshold = (
120 |       jnp.linalg.norm(jnp.array(tap_1_yx) - jnp.array(tap_2_yx))
121 |       <= matching_tap_distance_threshold_screen_percentage
122 |   )
123 |   return jnp.logical_or(both_in_box, within_threshold)
124 | 
125 | 
126 | def _check_drag_actions_match(
127 |     drag_1_touch_yx,
128 |     drag_1_lift_yx,
129 |     drag_2_touch_yx,
130 |     drag_2_lift_yx,
131 | ):
132 |   """Determines if two drag actions are the same."""
133 |   # Store drag deltas (the change in the y and x coordinates from touch to
134 |   # lift), magnitudes, and the index of the main axis, which is the axis with
135 |   # the greatest change in coordinate value (e.g. a drag starting at (0, 0) and
136 |   # ending at (0.3, 0.5) has a main axis index of 1).
137 |   drag_1_deltas = drag_1_lift_yx - drag_1_touch_yx
138 |   drag_1_magnitudes = jnp.abs(drag_1_deltas)
139 |   drag_1_main_axis = np.argmax(drag_1_magnitudes)
140 |   drag_2_deltas = drag_2_lift_yx - drag_2_touch_yx
141 |   drag_2_magnitudes = jnp.abs(drag_2_deltas)
142 |   drag_2_main_axis = np.argmax(drag_2_magnitudes)
143 | 
144 |   return jnp.equal(drag_1_main_axis, drag_2_main_axis)  # only the dominant swipe axis is compared
145 | 
146 | 
147 | def check_actions_match(
148 |     action_1_touch_yx,
149 |     action_1_lift_yx,
150 |     action_1_action_type,
151 |     action_2_touch_yx,
152 |     action_2_lift_yx,
153 |     action_2_action_type,
154 |     annotation_positions,
155 |     tap_distance_threshold = _TAP_DISTANCE_THRESHOLD,
156 |     annotation_width_augment_fraction = ANNOTATION_WIDTH_AUGMENT_FRACTION,
157 |     annotation_height_augment_fraction = ANNOTATION_HEIGHT_AUGMENT_FRACTION,
158 | ):
159 |   """Determines if two actions are considered to be the same.
160 | 
161 |   Two actions being "the same" is defined here as two actions that would result
162 |   in a similar screen state.
163 | 
164 |   Args:
165 |     action_1_touch_yx: The (y, x) coordinates of the first action's touch.
166 |     action_1_lift_yx: The (y, x) coordinates of the first action's lift.
167 |     action_1_action_type: The action type of the first action.
168 |     action_2_touch_yx: The (y, x) coordinates of the second action's touch.
169 |     action_2_lift_yx: The (y, x) coordinates of the second action's lift.
170 |     action_2_action_type: The action type of the second action.
171 |     annotation_positions: The positions of the UI annotations for the screen. It
172 |       is a 2D int array of shape (num_bboxes, 4), where each row represents a
173 |       bounding box: (y_top_left, x_top_left, box_height, box_width). Note that
174 |       containment is inclusive of the bounding box edges.
175 |     tap_distance_threshold: The threshold that determines if two taps result in
176 |       a matching screen state if they don't fall in the same bounding boxes.
177 |     annotation_width_augment_fraction: The fraction to increase the width of the
178 |       bounding box by.
179 |     annotation_height_augment_fraction: The fraction to increase the height of
180 |       the bounding box by.
181 | 
182 |   Returns:
183 |     A boolean representing whether the two given actions are the same or not.
184 |   """
185 |   action_1_touch_yx = jnp.asarray(action_1_touch_yx)
186 |   action_1_lift_yx = jnp.asarray(action_1_lift_yx)
187 |   action_2_touch_yx = jnp.asarray(action_2_touch_yx)
188 |   action_2_lift_yx = jnp.asarray(action_2_lift_yx)
189 | 
190 |   # Checks if at least one of the actions is global (i.e. not DUAL_POINT),
191 |   # because if that is the case, only the actions' types need to be compared.
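  # For example, comparing a predicted 'press back' against a ground-truth tap
  # reduces to an equality check on the two action-type ids, which fails
  # regardless of the coordinates involved.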
192 | has_non_dual_point_action = jnp.logical_or( 193 | _is_non_dual_point_action(action_1_action_type), 194 | _is_non_dual_point_action(action_2_action_type), 195 | ) 196 | #print("non dual point: "+str(has_non_dual_point_action)) 197 | 198 | different_dual_point_types = jnp.logical_xor( 199 | is_tap_action(action_1_touch_yx, action_1_lift_yx), 200 | is_tap_action(action_2_touch_yx, action_2_lift_yx), 201 | ) 202 | #print("different dual type: "+str(different_dual_point_types)) 203 | 204 | is_tap = jnp.logical_and( 205 | is_tap_action(action_1_touch_yx, action_1_lift_yx), 206 | is_tap_action(action_2_touch_yx, action_2_lift_yx), 207 | ) 208 | #print("is tap: "+str(is_tap)) 209 | 210 | taps_match = _check_tap_actions_match( 211 | action_1_touch_yx, 212 | action_2_touch_yx, 213 | annotation_positions, 214 | tap_distance_threshold, 215 | annotation_width_augment_fraction, 216 | annotation_height_augment_fraction, 217 | ) 218 | #print("tap match: "+str(taps_match)) 219 | 220 | taps_match = jnp.logical_and(is_tap, taps_match) 221 | #print("tap match: "+str(taps_match)) 222 | 223 | drags_match = _check_drag_actions_match( 224 | action_1_touch_yx, action_1_lift_yx, action_2_touch_yx, action_2_lift_yx 225 | ) 226 | drags_match = jnp.where(is_tap, False, drags_match) 227 | #print("drag match: "+str(drags_match)) 228 | 229 | return jnp.where( 230 | has_non_dual_point_action, 231 | jnp.equal(action_1_action_type, action_2_action_type), 232 | jnp.where( 233 | different_dual_point_types, 234 | False, 235 | jnp.logical_or(taps_match, drags_match), 236 | ), 237 | ) 238 | 239 | 240 | def action_2_format(step_data): 241 | # 把test数据集中的动作格式转换为计算matching score的格式 242 | action_type = step_data["action_type_id"] 243 | 244 | if action_type == 4: 245 | if step_data["action_type_text"] == 'click': # 点击 246 | touch_point = step_data["touch"] 247 | lift_point = step_data["lift"] 248 | else: # 上下左右滑动 249 | if step_data["action_type_text"] == 'scroll down': 250 | touch_point = [0.5, 0.8] 251 | lift_point = [0.5, 0.2] 252 | elif step_data["action_type_text"] == 'scroll up': 253 | touch_point = [0.5, 0.2] 254 | lift_point = [0.5, 0.8] 255 | elif step_data["action_type_text"] == 'scroll left': 256 | touch_point = [0.2, 0.5] 257 | lift_point = [0.8, 0.5] 258 | elif step_data["action_type_text"] == 'scroll right': 259 | touch_point = [0.8, 0.5] 260 | lift_point = [0.2, 0.5] 261 | else: 262 | touch_point = [-1.0, -1.0] 263 | lift_point = [-1.0, -1.0] 264 | 265 | if action_type == 3: 266 | typed_text = step_data["type_text"] 267 | else: 268 | typed_text = "" 269 | 270 | action = {"action_type": action_type, "touch_point": touch_point, "lift_point": lift_point, 271 | "typed_text": typed_text} 272 | 273 | action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]] 274 | action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]] 275 | action["typed_text"] = action["typed_text"].lower() 276 | 277 | return action 278 | 279 | 280 | def pred_2_format(step_data): 281 | # 把模型输出的内容转换为计算action_matching的格式 282 | action_type = step_data["action_type"] 283 | 284 | if action_type == 4: # 点击 285 | action_type_new = 4 286 | touch_point = step_data["click_point"] 287 | lift_point = step_data["click_point"] 288 | typed_text = "" 289 | elif action_type == 0: 290 | action_type_new = 4 291 | touch_point = [0.5, 0.8] 292 | lift_point = [0.5, 0.2] 293 | typed_text = "" 294 | elif action_type == 1: 295 | action_type_new = 4 296 | touch_point = [0.5, 0.2] 297 | lift_point = [0.5, 0.8] 298 | typed_text = "" 299 | elif 
action_type == 8: 300 | action_type_new = 4 301 | touch_point = [0.2, 0.5] 302 | lift_point = [0.8, 0.5] 303 | typed_text = "" 304 | elif action_type == 9: 305 | action_type_new = 4 306 | touch_point = [0.8, 0.5] 307 | lift_point = [0.2, 0.5] 308 | typed_text = "" 309 | else: 310 | action_type_new = action_type 311 | touch_point = [-1.0, -1.0] 312 | lift_point = [-1.0, -1.0] 313 | typed_text = "" 314 | if action_type_new == 3: 315 | typed_text = step_data["typed_text"] 316 | 317 | action = {"action_type": action_type_new, "touch_point": touch_point, "lift_point": lift_point, 318 | "typed_text": typed_text} 319 | 320 | action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]] 321 | action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]] 322 | action["typed_text"] = action["typed_text"].lower() 323 | 324 | return action 325 | 326 | def pred_2_format_4_mpgui(step_data,img_filename=''): 327 | # 把模型输出的内容转换为计算action_matching的格式 328 | action_type = step_data["action_type"] 329 | 330 | if img_filename != '': 331 | img_path = 'AITW_simplified/aitw_images/' + img_filename 332 | w, h = Image.open(img_path).size 333 | else: 334 | w, h = 1000, 1000 335 | 336 | 337 | if action_type == 4: # 点击 338 | action_type_new = 4 339 | 340 | if 'click_point' in step_data: 341 | touch_point = step_data["click_point"] 342 | lift_point = step_data["click_point"] 343 | # for MP-GUI 344 | if touch_point[0] > 1.: 345 | touch_point = [touch_point[0]/w, touch_point[1]/h] 346 | if lift_point[0] > 1: 347 | lift_point = [lift_point[0]/w, lift_point[1]/h] 348 | else: 349 | print(f'$$ error pred step: {step_data}') 350 | touch_point = [0., 0.] 351 | lift_point = [0., 0.] 352 | typed_text = "" 353 | elif action_type == 0: 354 | action_type_new = 4 355 | touch_point = [0.5, 0.8] 356 | lift_point = [0.5, 0.2] 357 | typed_text = "" 358 | elif action_type == 1: 359 | action_type_new = 4 360 | touch_point = [0.5, 0.2] 361 | lift_point = [0.5, 0.8] 362 | typed_text = "" 363 | elif action_type == 8: 364 | action_type_new = 4 365 | touch_point = [0.2, 0.5] 366 | lift_point = [0.8, 0.5] 367 | typed_text = "" 368 | elif action_type == 9: 369 | action_type_new = 4 370 | touch_point = [0.8, 0.5] 371 | lift_point = [0.2, 0.5] 372 | typed_text = "" 373 | else: 374 | action_type_new = action_type 375 | touch_point = [-1.0, -1.0] 376 | lift_point = [-1.0, -1.0] 377 | typed_text = "" 378 | if action_type_new == 3: 379 | typed_text = step_data["typed_text"] 380 | 381 | action = {"action_type": action_type_new, "touch_point": touch_point, "lift_point": lift_point, 382 | "typed_text": typed_text} 383 | 384 | action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]] 385 | action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]] 386 | action["typed_text"] = action["typed_text"].lower() 387 | 388 | return action 389 | 390 | def convert_qwen_format(response): 391 | pred_action = response 392 | # pred_action = response.split('### Action ###')[-1].strip() 393 | # print(pred_action) 394 | item = {} 395 | if 'Click' in pred_action: 396 | action_id = 4 397 | try: 398 | x, y = pred_action.split('(')[-1].split(')')[0].split(',') 399 | x, y = int(x), int(y) 400 | except: 401 | x,y = 0, 0 402 | item = { 403 | 'action_type': action_id, 404 | 'click_point': (x,y) 405 | } 406 | elif 'Scroll("up")' in pred_action: 407 | item = { 408 | 'action_type': 1 409 | } 410 | elif 'Scroll("down")' in pred_action: 411 | item = { 412 | 'action_type': 0 413 | } 414 | elif 'Scroll("left")' in pred_action: 
415 | item = { 416 | 'action_type': 8 417 | } 418 | elif 'Scroll("right")' in pred_action: 419 | item = { 420 | 'action_type': 9 421 | } 422 | elif 'Type' in pred_action: 423 | text = pred_action.split('("')[-1].split('")')[0] 424 | item = { 425 | 'action_type': 3, 426 | 'typed_text': text 427 | } 428 | elif 'Complete' in pred_action: 429 | item ={ 430 | 'action_type': 10 431 | } 432 | elif 'Back' in pred_action: 433 | item ={ 434 | 'action_type': 5 435 | } 436 | elif 'Home' in pred_action: 437 | item ={ 438 | 'action_type': 6 439 | } 440 | elif 'Enter' in pred_action: 441 | item ={ 442 | 'action_type': 7 443 | } 444 | else: 445 | item ={ 446 | 'action_type': 2 #error 447 | } 448 | return item 449 | 450 | # def convert_qwen_format_mind2web(response): 451 | # pred_action = response#.split('### Action')[-1].strip() 452 | 453 | # item = {} 454 | 455 | # if 'Click' in pred_action: 456 | # try: 457 | # x, y = pred_action.split('(')[-1].split(')')[0].split(',') 458 | # x, y = int(x), int(y) 459 | # click_point = (x, y) 460 | # except: 461 | # x,y = 0, 0 462 | # click_point = (x, y) 463 | # item = {"action_type": 4, "click_point": click_point} 464 | 465 | # elif 'Type' in pred_action: 466 | # try: 467 | # # Type(x,y,"typed_text") 468 | # s = pred_action.split('(')[-1] 469 | # x, y, tp_txt = s.split(',') 470 | # x, y = int(x), int(y) 471 | # click_point = (x, y) 472 | # select_value = tp_txt.replace('"','').replace(')', '') 473 | # except: 474 | # click_point = (0,0) 475 | # select_value = '' 476 | # item = {"action_type": 3, "click_point": click_point, "value": select_value} 477 | 478 | # elif 'Select' in pred_action: 479 | # try: 480 | # s = pred_action.split('(')[-1] 481 | # x, y, tp_txt = s.split(',') 482 | # x, y = int(x), int(y) 483 | # click_point = (x, y) 484 | # select_value = tp_txt.replace('"','').replace(')', '') 485 | # except: 486 | # click_point = (0,0) 487 | # select_value = '' 488 | # item = {"action_type": 3, "click_point": click_point, "value": select_value} 489 | # else: 490 | # item = {"action_type": 0, "click_point": (0,0)} 491 | 492 | # return item 493 | def convert_qwen_format_mind2web(response): 494 | pred_action = response#.split('### Action')[-1].strip() 495 | 496 | item = {} 497 | 498 | if 'Click' in pred_action: 499 | try: 500 | # print(pred_action) 501 | s = pred_action[pred_action.find('(')+1:pred_action.rfind(')')] 502 | x, y = s.split(',')[0], s.split(',')[1] 503 | x, y = int(x), int(y) 504 | click_point = (x, y) 505 | except: 506 | x,y = 0, 0 507 | click_point = (x, y) 508 | item = {"action_type": 4, "click_point": click_point} 509 | 510 | elif 'Type' in pred_action: 511 | try: 512 | # print(pred_action) 513 | s = pred_action[pred_action.find('(')+1:pred_action.rfind(')')] 514 | x, y, tp_txt = s.split(',')[0], s.split(',')[1], ','.join(s.split(',')[2:]) 515 | x, y = int(x), int(y) 516 | click_point = (x, y) 517 | typed_text = tp_txt[tp_txt.find('"')+1:tp_txt.rfind('"')] 518 | # print(select_value) 519 | except: 520 | click_point = (0,0) 521 | typed_text = '' 522 | item = {"action_type": 3, "click_point": click_point, "value": typed_text} 523 | 524 | elif 'Select' in pred_action: 525 | try: 526 | # print(pred_action) 527 | s = pred_action[pred_action.find('(')+1:pred_action.rfind(')')] 528 | x, y, tp_txt = s.split(',')[0], s.split(',')[1], ','.join(s.split(',')[2:]) 529 | x, y = int(x), int(y) 530 | click_point = (x, y) 531 | select_value = tp_txt[tp_txt.find('"')+1:tp_txt.rfind('"')] 532 | # print(select_value) 533 | except: 534 | click_point = (0,0) 535 | 
select_value = '' 536 | item = {"action_type": 2, "click_point": click_point, "value": select_value} 537 | else: 538 | item = {"action_type": 0, "click_point": (0,0)} 539 | 540 | return item 541 | 542 | def convert_qwen_format_mind2web_InternVL(response): 543 | pred_action = response#.split('### Action')[-1].strip() 544 | 545 | item = {} 546 | 547 | if 'Click' in pred_action: 548 | try: 549 | # print(pred_action) 550 | s = pred_action[pred_action.find('(')+1:pred_action.rfind(')')] 551 | [x1, y1, x2, y2] = s.split(',') 552 | x, y = (int(x1)+int(x2))/2, (int(y1)+int(y2))/2 553 | click_point = (x, y) 554 | except: 555 | x,y = 0, 0 556 | click_point = (x, y) 557 | item = {"action_type": 4, "click_point": click_point} 558 | 559 | elif 'Type' in pred_action: 560 | try: 561 | # print(pred_action) 562 | s = pred_action[pred_action.find('(')+1:pred_action.rfind(')')] 563 | x1, y1, x2, y2, tp_txt = s.split(',')[0], s.split(',')[1], s.split(',')[2], s.split(',')[3],','.join(s.split(',')[4:]) 564 | x, y = (int(x1)+int(x2))/2, (int(y1)+int(y2))/2 565 | click_point = (x, y) 566 | typed_text = tp_txt[tp_txt.find('"')+1:tp_txt.rfind('"')] 567 | # print(select_value) 568 | except: 569 | click_point = (0,0) 570 | typed_text = '' 571 | item = {"action_type": 3, "click_point": click_point, "value": typed_text} 572 | 573 | elif 'Select' in pred_action: 574 | try: 575 | # print(pred_action) 576 | s = pred_action[pred_action.find('(')+1:pred_action.rfind(')')] 577 | x1, y1, x2, y2, tp_txt = s.split(',')[0], s.split(',')[1], s.split(',')[2], s.split(',')[3],','.join(s.split(',')[4:]) 578 | x, y = (int(x1)+int(x2))/2, (int(y1)+int(y2))/2 579 | click_point = (x, y) 580 | select_value = tp_txt[tp_txt.find('"')+1:tp_txt.rfind('"')] 581 | # print(select_value) 582 | except: 583 | click_point = (0,0) 584 | select_value = '' 585 | item = {"action_type": 2, "click_point": click_point, "value": select_value} 586 | else: 587 | item = {"action_type": 0, "click_point": (0,0)} 588 | 589 | return item 590 | 591 | def simple_decode(gt, img_path=None): 592 | idx = gt.find(':') 593 | if idx == -1: 594 | action = gt 595 | info = "" 596 | else: 597 | action = gt[:idx].strip() 598 | info = gt[idx+1:].strip() 599 | if action in ['CLICK', "LONG_PRESS"]: 600 | info = eval(info) 601 | if img_path is not None: 602 | img_path = 'GUI-Odyssey-master/data/screenshots/' + img_path 603 | w, h = Image.open(img_path).size 604 | info = (info[0] / w * 1000, info[1] / h * 1000) 605 | 606 | return {"action": action, "info": info} 607 | 608 | 609 | TEXT_ANLS_THRESHOLD = 0.5 610 | CLICK_COORD_THRESHOLD = 0.14 611 | 612 | def levenshtein_distance(s1, s2): 613 | if len(s1) > len(s2): 614 | s1, s2 = s2, s1 615 | 616 | distances = range(len(s1) + 1) 617 | for i2, c2 in enumerate(s2): 618 | distances_ = [i2+1] 619 | for i1, c1 in enumerate(s1): 620 | if c1 == c2: 621 | distances_.append(distances[i1]) 622 | else: 623 | distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1]))) 624 | distances = distances_ 625 | return distances[-1] 626 | 627 | 628 | def text_matching(gt, pred): 629 | gt = gt.strip() 630 | pred = pred.strip() 631 | if gt in pred or pred in gt: 632 | return True 633 | 634 | dist = levenshtein_distance(gt, pred) 635 | length = max(len(gt), len(pred)) 636 | value = 0.0 if length == 0 else float(dist) / float(length) 637 | value = 1 - value 638 | return value >= TEXT_ANLS_THRESHOLD 639 | 640 | 641 | def click_matching(gt_info, pred_info): 642 | if type(pred_info) == str: 643 | pred_info = eval(pred_info) 644 | if 
type(gt_info) == str: 645 | gt_info = eval(gt_info) 646 | 647 | pred = np.asarray(pred_info) / 1000 648 | gt = np.asarray(gt_info) / 1000 649 | 650 | return np.linalg.norm(pred - gt) <= CLICK_COORD_THRESHOLD 651 | 652 | 653 | 654 | def action_matching(pred_action, pred_info, gt_action, gt_info): 655 | pred_action = pred_action.strip() 656 | if type(pred_info) == str: 657 | pred_info = pred_info.strip() 658 | gt_action = gt_action.strip() 659 | if type(gt_info) == str: 660 | gt_info = gt_info.strip() 661 | 662 | if pred_action != gt_action: 663 | return {'is_correct': 'no', 'info': 'action_fail'} 664 | 665 | if gt_action not in ['SCROLL', 'CLICK', 'TYPE', 'LONG_PRESS']: 666 | return {'is_correct': 'yes', 'info': 'action_correct'} 667 | 668 | elif gt_action == 'TYPE': 669 | text_flag = text_matching(gt_info, pred_info) 670 | 671 | if text_flag: 672 | return {'is_correct': 'yes', 'info': 'type_correct'} 673 | else: 674 | return {'is_correct': 'no', 'info': 'type_fail'} 675 | 676 | elif gt_action == 'SCROLL': 677 | if gt_info.lower() == pred_info.lower(): 678 | return {'is_correct': 'yes', 'info': 'scroll_correct'} 679 | else: 680 | return {'is_correct': 'no', 'info': 'scroll_fail'} 681 | 682 | elif gt_action == 'CLICK' or gt_action == 'LONG_PRESS': 683 | click_flag = click_matching(gt_info, pred_info) 684 | 685 | if click_flag: 686 | return {'is_correct': 'yes', 'info': 'click_correct'} 687 | else: 688 | return {'is_correct': 'no', 'info': 'click_fail'} 689 | 690 | else: 691 | raise ValueError('Invalid action type') 692 | 693 | def stat_result(eval_dict, metric): 694 | text_correct = sum([1 for _ in eval_dict if _['info'] == 'type_correct']) 695 | type_correct = sum([1 for _ in eval_dict if _['info'] != 'action_fail']) 696 | text_total = sum([1 for _ in eval_dict if _['info'].startswith('type_')]) 697 | 698 | if metric == 'macro': 699 | action_correct = sum([1 for _ in eval_dict if _['is_correct'] == 'yes']) 700 | AMS = round(action_correct / len(eval_dict) * 100, 2) 701 | SR_cnt, SR_tot, SR = check_SR(eval_dict) 702 | elif metric == 'micro': 703 | task_cate_dict = {} 704 | acc_list = [] 705 | SR_list = [] 706 | # print(eval_dict) 707 | for sample in eval_dict: 708 | cat = sample['more_info']['category'] 709 | if cat not in task_cate_dict: 710 | task_cate_dict[cat] = [] 711 | task_cate_dict[cat].append(sample) 712 | 713 | 714 | # assert len(task_cate_dict) == 6 #总共6个类别的数据,跑部分数据可以注释掉 715 | for k, v in task_cate_dict.items(): 716 | SR_cnt, SR_tot, SR = check_SR(v) 717 | SR_list.append((SR)) 718 | acc = round(sum([1 for x in v if x['is_correct'] == 'yes']) / len(v) * 100, 2) 719 | acc_list.append(acc) 720 | print(f'category: {k}, AMS: {acc}, SR: {SR}') 721 | 722 | AMS = np.round(np.mean(acc_list), 2) 723 | SR = np.round(np.mean(SR_list), 2) 724 | 725 | else: 726 | raise ValueError(f'No metric {metric} found.') 727 | 728 | info = { 729 | 'AMS': AMS, 730 | 'SR': SR, 731 | 'total': len(eval_dict), 732 | 'action_type': '{} / {} = {:.2f}'.format(type_correct, len(eval_dict), type_correct / len(eval_dict) * 100), 733 | 'text': '{} / {} = {:.2f}'.format(text_correct, text_total, text_correct / text_total * 100), 734 | } 735 | 736 | return info 737 | 738 | def check_SR(eval_dict): 739 | episode_dict = {} 740 | steps_map = {} 741 | for data in eval_dict: 742 | if 'img' in data: img = data['img'] 743 | elif 'image' in data: img = data['image'] 744 | else: img = data['question'].split('')[0].split('')[1] 745 | img = os.path.basename(img) 746 | tail = img.split('_')[-1] 747 | episode = 
img.replace(f'_{tail}', '') 748 | if episode not in episode_dict: 749 | episode_dict[episode] = [] 750 | else: 751 | assert steps_map[episode] == data['more_info']['step_length'] 752 | 753 | info = data['is_correct'] 754 | episode_dict[episode].append(info) 755 | steps_map[episode] = data['more_info']['step_length'] 756 | 757 | cnt, tot = 0, 0 758 | # print('=== ',episode_dict) 759 | for k, v in episode_dict.items(): 760 | if len(v) != steps_map[k]: 761 | print(f'step length of {k} does not match.') 762 | continue 763 | tot += 1 764 | v = list(set(v)) 765 | if len(v) == 1 and v[0] == 'yes': 766 | cnt += 1 767 | 768 | SR = round(cnt / tot * 100, 2) 769 | print(f'total episode: {tot}, successful episode: {cnt}, SR: {SR}') 770 | return cnt, tot, SR 771 | 772 | def odyssey_action_matching_evaluation(pred_output, metric='macro'): 773 | eval_dict = [] 774 | for idx, sample in enumerate(pred_output): 775 | 776 | question, pred, gt, more_info = sample['question'], sample['pred'], sample['gt'], sample['more_info'] 777 | # sample_eval_dict = {'question': question, 'pred': str(pred), 'gt': str(gt), 'more_info': more_info} 778 | sample_eval_dict = {'question': question, 'pred': str(pred), 'gt': str(gt), 'more_info': more_info, 'img':sample['img']} 779 | 780 | gt_simple_info = simple_decode(gt) 781 | gt_action = gt_simple_info['action'] 782 | gt_info = gt_simple_info['info'] 783 | 784 | try: 785 | pred_simple_info = simple_decode(pred, sample['img']) 786 | # print('pred_simple_info:', pred_simple_info) 787 | pred_action = pred_simple_info['action'] 788 | pred_info = pred_simple_info['info'] 789 | except: 790 | # print('### eval err:', idx, pred) 791 | log_info = {'is_correct': 'no', 'info': 'decode invalid'} 792 | sample_eval_dict.update(log_info) 793 | eval_dict.append(sample_eval_dict) 794 | continue 795 | 796 | try: 797 | check_match = action_matching(pred_action, pred_info, gt_action, gt_info) 798 | except Exception as exc: 799 | print('$$$ eval err:', gt, pred, exc) 800 | check_match = {'is_correct': 'no', 'info': 'match invalid'} 801 | 802 | sample_eval_dict.update(check_match) 803 | eval_dict.append(sample_eval_dict) 804 | 805 | # print('===== ',eval_dict) 806 | info = stat_result(eval_dict, metric) 807 | metrics = {"info": info, "pred": eval_dict} 808 | return metrics -------------------------------------------------------------------------------- /workflow/odyssey/action_matching.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from https://github.com/google-research/google-research/tree/master/android_in_the_wild 3 | ''' 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | import os 9 | 10 | import action_type as action_type_lib 11 | from PIL import Image 12 | 13 | _TAP_DISTANCE_THRESHOLD = 0.14 # Fraction of the screen 14 | ANNOTATION_WIDTH_AUGMENT_FRACTION = 1.4 15 | ANNOTATION_HEIGHT_AUGMENT_FRACTION = 1.4 16 | 17 | # Interval determining if an action is a tap or a swipe. 18 | _SWIPE_DISTANCE_THRESHOLD = 0.04 19 | 20 | 21 | def _yx_in_bounding_boxes( 22 | yx, bounding_boxes 23 | ): 24 | """Check if the (y,x) point is contained in each bounding box. 25 | 26 | Args: 27 | yx: The (y, x) coordinate in pixels of the point. 28 | bounding_boxes: A 2D int array of shape (num_bboxes, 4), where each row 29 | represents a bounding box: (y_top_left, x_top_left, box_height, 30 | box_width). Note: containment is inclusive of the bounding box edges. 
31 | 32 | Returns: 33 | is_inside: A 1D bool array where each element specifies if the point is 34 | contained within the respective box. 35 | """ 36 | y, x = yx 37 | 38 | # `bounding_boxes` has shape (n_elements, 4); we extract each array along the 39 | # last axis into shape (n_elements, 1), then squeeze unneeded dimension. 40 | top, left, height, width = [ 41 | jnp.squeeze(v, axis=-1) for v in jnp.split(bounding_boxes, 4, axis=-1) 42 | ] 43 | 44 | # The y-axis is inverted for AndroidEnv, so bottom = top + height. 45 | bottom, right = top + height, left + width 46 | 47 | return jnp.logical_and(y >= top, y <= bottom) & jnp.logical_and( 48 | x >= left, x <= right) 49 | 50 | 51 | def _resize_annotation_bounding_boxes( 52 | annotation_positions, annotation_width_augment_fraction, 53 | annotation_height_augment_fraction): 54 | """Resize the bounding boxes by the given fractions. 55 | 56 | Args: 57 | annotation_positions: Array of shape (N, 4), where each row represents the 58 | (y, x, height, width) of the bounding boxes. 59 | annotation_width_augment_fraction: The fraction to augment the box widths, 60 | E.g., 1.4 == 240% total increase. 61 | annotation_height_augment_fraction: Same as described for width, but for box 62 | height. 63 | 64 | Returns: 65 | Resized bounding box. 66 | 67 | """ 68 | height_change = ( 69 | annotation_height_augment_fraction * annotation_positions[:, 2]) 70 | width_change = ( 71 | annotation_width_augment_fraction * annotation_positions[:, 3]) 72 | 73 | # Limit bounding box positions to the screen. 74 | resized_annotations = jnp.stack([ 75 | jnp.maximum(0, annotation_positions[:, 0] - (height_change / 2)), 76 | jnp.maximum(0, annotation_positions[:, 1] - (width_change / 2)), 77 | jnp.minimum(1, annotation_positions[:, 2] + height_change), 78 | jnp.minimum(1, annotation_positions[:, 3] + width_change), 79 | ], 80 | axis=1) 81 | return resized_annotations 82 | 83 | 84 | def is_tap_action(normalized_start_yx, 85 | normalized_end_yx): 86 | distance = jnp.linalg.norm( 87 | jnp.array(normalized_start_yx) - jnp.array(normalized_end_yx)) 88 | return distance <= _SWIPE_DISTANCE_THRESHOLD 89 | 90 | 91 | def _is_non_dual_point_action(action_type): 92 | return jnp.not_equal(action_type, action_type_lib.ActionType.DUAL_POINT) 93 | 94 | 95 | def _check_tap_actions_match( 96 | tap_1_yx, 97 | tap_2_yx, 98 | annotation_positions, 99 | matching_tap_distance_threshold_screen_percentage, 100 | annotation_width_augment_fraction, 101 | annotation_height_augment_fraction, 102 | ): 103 | """Determines if two tap actions are the same.""" 104 | resized_annotation_positions = _resize_annotation_bounding_boxes( 105 | annotation_positions, 106 | annotation_width_augment_fraction, 107 | annotation_height_augment_fraction, 108 | ) 109 | 110 | # Check if the ground truth tap action falls in an annotation's bounding box. 111 | tap1_in_box = _yx_in_bounding_boxes(tap_1_yx, resized_annotation_positions) 112 | tap2_in_box = _yx_in_bounding_boxes(tap_2_yx, resized_annotation_positions) 113 | both_in_box = jnp.max(tap1_in_box & tap2_in_box) 114 | 115 | # If the ground-truth tap action falls outside any of the annotation 116 | # bounding boxes or one of the actions is inside a bounding box and the other 117 | # is outside bounding box or vice versa, compare the points using Euclidean 118 | # distance. 
119 |   within_threshold = (
120 |       jnp.linalg.norm(jnp.array(tap_1_yx) - jnp.array(tap_2_yx))
121 |       <= matching_tap_distance_threshold_screen_percentage
122 |   )
123 |   return jnp.logical_or(both_in_box, within_threshold)
124 | 
125 | 
126 | def _check_drag_actions_match(
127 |     drag_1_touch_yx,
128 |     drag_1_lift_yx,
129 |     drag_2_touch_yx,
130 |     drag_2_lift_yx,
131 | ):
132 |   """Determines if two drag actions are the same."""
133 |   # Store drag deltas (the change in the y and x coordinates from touch to
134 |   # lift), magnitudes, and the index of the main axis, which is the axis with
135 |   # the greatest change in coordinate value (e.g. a drag starting at (0, 0) and
136 |   # ending at (0.3, 0.5) has a main axis index of 1).
137 |   drag_1_deltas = drag_1_lift_yx - drag_1_touch_yx
138 |   drag_1_magnitudes = jnp.abs(drag_1_deltas)
139 |   drag_1_main_axis = np.argmax(drag_1_magnitudes)
140 |   drag_2_deltas = drag_2_lift_yx - drag_2_touch_yx
141 |   drag_2_magnitudes = jnp.abs(drag_2_deltas)
142 |   drag_2_main_axis = np.argmax(drag_2_magnitudes)
143 | 
144 |   return jnp.equal(drag_1_main_axis, drag_2_main_axis)  # only the dominant swipe axis is compared
145 | 
146 | 
147 | def check_actions_match(
148 |     action_1_touch_yx,
149 |     action_1_lift_yx,
150 |     action_1_action_type,
151 |     action_2_touch_yx,
152 |     action_2_lift_yx,
153 |     action_2_action_type,
154 |     annotation_positions,
155 |     tap_distance_threshold = _TAP_DISTANCE_THRESHOLD,
156 |     annotation_width_augment_fraction = ANNOTATION_WIDTH_AUGMENT_FRACTION,
157 |     annotation_height_augment_fraction = ANNOTATION_HEIGHT_AUGMENT_FRACTION,
158 | ):
159 |   """Determines if two actions are considered to be the same.
160 | 
161 |   Two actions being "the same" is defined here as two actions that would result
162 |   in a similar screen state.
163 | 
164 |   Args:
165 |     action_1_touch_yx: The (y, x) coordinates of the first action's touch.
166 |     action_1_lift_yx: The (y, x) coordinates of the first action's lift.
167 |     action_1_action_type: The action type of the first action.
168 |     action_2_touch_yx: The (y, x) coordinates of the second action's touch.
169 |     action_2_lift_yx: The (y, x) coordinates of the second action's lift.
170 |     action_2_action_type: The action type of the second action.
171 |     annotation_positions: The positions of the UI annotations for the screen. It
172 |       is a 2D int array of shape (num_bboxes, 4), where each row represents a
173 |       bounding box: (y_top_left, x_top_left, box_height, box_width). Note that
174 |       containment is inclusive of the bounding box edges.
175 |     tap_distance_threshold: The threshold that determines if two taps result in
176 |       a matching screen state if they don't fall in the same bounding boxes.
177 |     annotation_width_augment_fraction: The fraction to increase the width of the
178 |       bounding box by.
179 |     annotation_height_augment_fraction: The fraction to increase the height of
180 |       the bounding box by.
181 | 
182 |   Returns:
183 |     A boolean representing whether the two given actions are the same or not.
184 |   """
185 |   action_1_touch_yx = jnp.asarray(action_1_touch_yx)
186 |   action_1_lift_yx = jnp.asarray(action_1_lift_yx)
187 |   action_2_touch_yx = jnp.asarray(action_2_touch_yx)
188 |   action_2_lift_yx = jnp.asarray(action_2_lift_yx)
189 | 
190 |   # Checks if at least one of the actions is global (i.e. not DUAL_POINT),
191 |   # because if that is the case, only the actions' types need to be compared.
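  # When both actions are DUAL_POINT, a tap (touch and lift within 0.04 of each
  # other) can never match a swipe; two taps are compared against the augmented
  # bounding boxes above, and two swipes only by their dominant axis.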
192 | has_non_dual_point_action = jnp.logical_or( 193 | _is_non_dual_point_action(action_1_action_type), 194 | _is_non_dual_point_action(action_2_action_type), 195 | ) 196 | #print("non dual point: "+str(has_non_dual_point_action)) 197 | 198 | different_dual_point_types = jnp.logical_xor( 199 | is_tap_action(action_1_touch_yx, action_1_lift_yx), 200 | is_tap_action(action_2_touch_yx, action_2_lift_yx), 201 | ) 202 | #print("different dual type: "+str(different_dual_point_types)) 203 | 204 | is_tap = jnp.logical_and( 205 | is_tap_action(action_1_touch_yx, action_1_lift_yx), 206 | is_tap_action(action_2_touch_yx, action_2_lift_yx), 207 | ) 208 | #print("is tap: "+str(is_tap)) 209 | 210 | taps_match = _check_tap_actions_match( 211 | action_1_touch_yx, 212 | action_2_touch_yx, 213 | annotation_positions, 214 | tap_distance_threshold, 215 | annotation_width_augment_fraction, 216 | annotation_height_augment_fraction, 217 | ) 218 | #print("tap match: "+str(taps_match)) 219 | 220 | taps_match = jnp.logical_and(is_tap, taps_match) 221 | #print("tap match: "+str(taps_match)) 222 | 223 | drags_match = _check_drag_actions_match( 224 | action_1_touch_yx, action_1_lift_yx, action_2_touch_yx, action_2_lift_yx 225 | ) 226 | drags_match = jnp.where(is_tap, False, drags_match) 227 | #print("drag match: "+str(drags_match)) 228 | 229 | return jnp.where( 230 | has_non_dual_point_action, 231 | jnp.equal(action_1_action_type, action_2_action_type), 232 | jnp.where( 233 | different_dual_point_types, 234 | False, 235 | jnp.logical_or(taps_match, drags_match), 236 | ), 237 | ) 238 | 239 | 240 | def action_2_format(step_data): 241 | # 把test数据集中的动作格式转换为计算matching score的格式 242 | action_type = step_data["action_type_id"] 243 | 244 | if action_type == 4: 245 | if step_data["action_type_text"] == 'click': # 点击 246 | touch_point = step_data["touch"] 247 | lift_point = step_data["lift"] 248 | else: # 上下左右滑动 249 | if step_data["action_type_text"] == 'scroll down': 250 | touch_point = [0.5, 0.8] 251 | lift_point = [0.5, 0.2] 252 | elif step_data["action_type_text"] == 'scroll up': 253 | touch_point = [0.5, 0.2] 254 | lift_point = [0.5, 0.8] 255 | elif step_data["action_type_text"] == 'scroll left': 256 | touch_point = [0.2, 0.5] 257 | lift_point = [0.8, 0.5] 258 | elif step_data["action_type_text"] == 'scroll right': 259 | touch_point = [0.8, 0.5] 260 | lift_point = [0.2, 0.5] 261 | else: 262 | touch_point = [-1.0, -1.0] 263 | lift_point = [-1.0, -1.0] 264 | 265 | if action_type == 3: 266 | typed_text = step_data["type_text"] 267 | else: 268 | typed_text = "" 269 | 270 | action = {"action_type": action_type, "touch_point": touch_point, "lift_point": lift_point, 271 | "typed_text": typed_text} 272 | 273 | action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]] 274 | action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]] 275 | action["typed_text"] = action["typed_text"].lower() 276 | 277 | return action 278 | 279 | 280 | def pred_2_format(step_data): 281 | # 把模型输出的内容转换为计算action_matching的格式 282 | action_type = step_data["action_type"] 283 | 284 | if action_type == 4: # 点击 285 | action_type_new = 4 286 | touch_point = step_data["click_point"] 287 | lift_point = step_data["click_point"] 288 | typed_text = "" 289 | elif action_type == 0: 290 | action_type_new = 4 291 | touch_point = [0.5, 0.8] 292 | lift_point = [0.5, 0.2] 293 | typed_text = "" 294 | elif action_type == 1: 295 | action_type_new = 4 296 | touch_point = [0.5, 0.2] 297 | lift_point = [0.5, 0.8] 298 | typed_text = "" 299 | elif 
action_type == 8: 300 | action_type_new = 4 301 | touch_point = [0.2, 0.5] 302 | lift_point = [0.8, 0.5] 303 | typed_text = "" 304 | elif action_type == 9: 305 | action_type_new = 4 306 | touch_point = [0.8, 0.5] 307 | lift_point = [0.2, 0.5] 308 | typed_text = "" 309 | else: 310 | action_type_new = action_type 311 | touch_point = [-1.0, -1.0] 312 | lift_point = [-1.0, -1.0] 313 | typed_text = "" 314 | if action_type_new == 3: 315 | typed_text = step_data["typed_text"] 316 | 317 | action = {"action_type": action_type_new, "touch_point": touch_point, "lift_point": lift_point, 318 | "typed_text": typed_text} 319 | 320 | action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]] 321 | action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]] 322 | action["typed_text"] = action["typed_text"].lower() 323 | 324 | return action 325 | 326 | def pred_2_format_4_mpgui(step_data,img_filename=''): 327 | # 把模型输出的内容转换为计算action_matching的格式 328 | action_type = step_data["action_type"] 329 | 330 | if img_filename != '': 331 | img_path = 'AITW_simplified/aitw_images/' + img_filename 332 | w, h = Image.open(img_path).size 333 | else: 334 | w, h = 1000, 1000 335 | 336 | 337 | if action_type == 4: # 点击 338 | action_type_new = 4 339 | 340 | if 'click_point' in step_data: 341 | touch_point = step_data["click_point"] 342 | lift_point = step_data["click_point"] 343 | # for MP-GUI 344 | if touch_point[0] > 1.: 345 | touch_point = [touch_point[0]/w, touch_point[1]/h] 346 | if lift_point[0] > 1: 347 | lift_point = [lift_point[0]/w, lift_point[1]/h] 348 | else: 349 | print(f'$$ error pred step: {step_data}') 350 | touch_point = [0., 0.] 351 | lift_point = [0., 0.] 352 | typed_text = "" 353 | elif action_type == 0: 354 | action_type_new = 4 355 | touch_point = [0.5, 0.8] 356 | lift_point = [0.5, 0.2] 357 | typed_text = "" 358 | elif action_type == 1: 359 | action_type_new = 4 360 | touch_point = [0.5, 0.2] 361 | lift_point = [0.5, 0.8] 362 | typed_text = "" 363 | elif action_type == 8: 364 | action_type_new = 4 365 | touch_point = [0.2, 0.5] 366 | lift_point = [0.8, 0.5] 367 | typed_text = "" 368 | elif action_type == 9: 369 | action_type_new = 4 370 | touch_point = [0.8, 0.5] 371 | lift_point = [0.2, 0.5] 372 | typed_text = "" 373 | else: 374 | action_type_new = action_type 375 | touch_point = [-1.0, -1.0] 376 | lift_point = [-1.0, -1.0] 377 | typed_text = "" 378 | if action_type_new == 3: 379 | typed_text = step_data["typed_text"] 380 | 381 | action = {"action_type": action_type_new, "touch_point": touch_point, "lift_point": lift_point, 382 | "typed_text": typed_text} 383 | 384 | action["touch_point"] = [action["touch_point"][1], action["touch_point"][0]] 385 | action["lift_point"] = [action["lift_point"][1], action["lift_point"][0]] 386 | action["typed_text"] = action["typed_text"].lower() 387 | 388 | return action 389 | 390 | def convert_qwen_format(response): 391 | pred_action = response 392 | # pred_action = response.split('### Action ###')[-1].strip() 393 | # print(pred_action) 394 | item = {} 395 | if 'Click' in pred_action: 396 | action_id = 4 397 | try: 398 | x, y = pred_action.split('(')[-1].split(')')[0].split(',') 399 | x, y = int(x), int(y) 400 | except: 401 | x,y = 0, 0 402 | item = { 403 | 'action_type': action_id, 404 | 'click_point': (x,y) 405 | } 406 | elif 'Scroll("up")' in pred_action: 407 | item = { 408 | 'action_type': 1 409 | } 410 | elif 'Scroll("down")' in pred_action: 411 | item = { 412 | 'action_type': 0 413 | } 414 | elif 'Scroll("left")' in pred_action: 
415 | item = { 416 | 'action_type': 8 417 | } 418 | elif 'Scroll("right")' in pred_action: 419 | item = { 420 | 'action_type': 9 421 | } 422 | elif 'Type' in pred_action: 423 | text = pred_action.split('("')[-1].split('")')[0] 424 | item = { 425 | 'action_type': 3, 426 | 'typed_text': text 427 | } 428 | elif 'Complete' in pred_action: 429 | item ={ 430 | 'action_type': 10 431 | } 432 | elif 'Back' in pred_action: 433 | item ={ 434 | 'action_type': 5 435 | } 436 | elif 'Home' in pred_action: 437 | item ={ 438 | 'action_type': 6 439 | } 440 | elif 'Enter' in pred_action: 441 | item ={ 442 | 'action_type': 7 443 | } 444 | else: 445 | item ={ 446 | 'action_type': 2 #error 447 | } 448 | return item 449 | 450 | # def convert_qwen_format_mind2web(response): 451 | # pred_action = response#.split('### Action')[-1].strip() 452 | 453 | # item = {} 454 | 455 | # if 'Click' in pred_action: 456 | # try: 457 | # x, y = pred_action.split('(')[-1].split(')')[0].split(',') 458 | # x, y = int(x), int(y) 459 | # click_point = (x, y) 460 | # except: 461 | # x,y = 0, 0 462 | # click_point = (x, y) 463 | # item = {"action_type": 4, "click_point": click_point} 464 | 465 | # elif 'Type' in pred_action: 466 | # try: 467 | # # Type(x,y,"typed_text") 468 | # s = pred_action.split('(')[-1] 469 | # x, y, tp_txt = s.split(',') 470 | # x, y = int(x), int(y) 471 | # click_point = (x, y) 472 | # select_value = tp_txt.replace('"','').replace(')', '') 473 | # except: 474 | # click_point = (0,0) 475 | # select_value = '' 476 | # item = {"action_type": 3, "click_point": click_point, "value": select_value} 477 | 478 | # elif 'Select' in pred_action: 479 | # try: 480 | # s = pred_action.split('(')[-1] 481 | # x, y, tp_txt = s.split(',') 482 | # x, y = int(x), int(y) 483 | # click_point = (x, y) 484 | # select_value = tp_txt.replace('"','').replace(')', '') 485 | # except: 486 | # click_point = (0,0) 487 | # select_value = '' 488 | # item = {"action_type": 3, "click_point": click_point, "value": select_value} 489 | # else: 490 | # item = {"action_type": 0, "click_point": (0,0)} 491 | 492 | # return item 493 | def convert_qwen_format_mind2web(response): 494 | pred_action = response#.split('### Action')[-1].strip() 495 | 496 | item = {} 497 | 498 | if 'Click' in pred_action: 499 | try: 500 | # print(pred_action) 501 | s = pred_action[pred_action.find('(')+1:pred_action.rfind(')')] 502 | x, y = s.split(',')[0], s.split(',')[1] 503 | x, y = int(x), int(y) 504 | click_point = (x, y) 505 | except: 506 | x,y = 0, 0 507 | click_point = (x, y) 508 | item = {"action_type": 4, "click_point": click_point} 509 | 510 | elif 'Type' in pred_action: 511 | try: 512 | # print(pred_action) 513 | s = pred_action[pred_action.find('(')+1:pred_action.rfind(')')] 514 | x, y, tp_txt = s.split(',')[0], s.split(',')[1], ','.join(s.split(',')[2:]) 515 | x, y = int(x), int(y) 516 | click_point = (x, y) 517 | typed_text = tp_txt[tp_txt.find('"')+1:tp_txt.rfind('"')] 518 | # print(select_value) 519 | except: 520 | click_point = (0,0) 521 | typed_text = '' 522 | item = {"action_type": 3, "click_point": click_point, "value": typed_text} 523 | 524 | elif 'Select' in pred_action: 525 | try: 526 | # print(pred_action) 527 | s = pred_action[pred_action.find('(')+1:pred_action.rfind(')')] 528 | x, y, tp_txt = s.split(',')[0], s.split(',')[1], ','.join(s.split(',')[2:]) 529 | x, y = int(x), int(y) 530 | click_point = (x, y) 531 | select_value = tp_txt[tp_txt.find('"')+1:tp_txt.rfind('"')] 532 | # print(select_value) 533 | except: 534 | click_point = (0,0) 535 | 
select_value = '' 536 | item = {"action_type": 2, "click_point": click_point, "value": select_value} 537 | else: 538 | item = {"action_type": 0, "click_point": (0,0)} 539 | 540 | return item 541 | 542 | def convert_qwen_format_mind2web_InternVL(response): 543 | pred_action = response#.split('### Action')[-1].strip() 544 | 545 | item = {} 546 | 547 | if 'Click' in pred_action: 548 | try: 549 | # print(pred_action) 550 | s = pred_action[pred_action.find('(')+1:pred_action.rfind(')')] 551 | [x1, y1, x2, y2] = s.split(',') 552 | x, y = (int(x1)+int(x2))/2, (int(y1)+int(y2))/2 553 | click_point = (x, y) 554 | except: 555 | x,y = 0, 0 556 | click_point = (x, y) 557 | item = {"action_type": 4, "click_point": click_point} 558 | 559 | elif 'Type' in pred_action: 560 | try: 561 | # print(pred_action) 562 | s = pred_action[pred_action.find('(')+1:pred_action.rfind(')')] 563 | x1, y1, x2, y2, tp_txt = s.split(',')[0], s.split(',')[1], s.split(',')[2], s.split(',')[3],','.join(s.split(',')[4:]) 564 | x, y = (int(x1)+int(x2))/2, (int(y1)+int(y2))/2 565 | click_point = (x, y) 566 | typed_text = tp_txt[tp_txt.find('"')+1:tp_txt.rfind('"')] 567 | # print(select_value) 568 | except: 569 | click_point = (0,0) 570 | typed_text = '' 571 | item = {"action_type": 3, "click_point": click_point, "value": typed_text} 572 | 573 | elif 'Select' in pred_action: 574 | try: 575 | # print(pred_action) 576 | s = pred_action[pred_action.find('(')+1:pred_action.rfind(')')] 577 | x1, y1, x2, y2, tp_txt = s.split(',')[0], s.split(',')[1], s.split(',')[2], s.split(',')[3],','.join(s.split(',')[4:]) 578 | x, y = (int(x1)+int(x2))/2, (int(y1)+int(y2))/2 579 | click_point = (x, y) 580 | select_value = tp_txt[tp_txt.find('"')+1:tp_txt.rfind('"')] 581 | # print(select_value) 582 | except: 583 | click_point = (0,0) 584 | select_value = '' 585 | item = {"action_type": 2, "click_point": click_point, "value": select_value} 586 | else: 587 | item = {"action_type": 0, "click_point": (0,0)} 588 | 589 | return item 590 | 591 | def simple_decode(gt, img_path=None): 592 | idx = gt.find(':') 593 | if idx == -1: 594 | action = gt 595 | info = "" 596 | else: 597 | action = gt[:idx].strip() 598 | info = gt[idx+1:].strip() 599 | if action in ['CLICK', "LONG_PRESS"]: 600 | info = eval(info) 601 | if img_path is not None: 602 | img_path = 'GUI-Odyssey-master/data/screenshots/' + img_path 603 | w, h = Image.open(img_path).size 604 | info = (info[0] / w * 1000, info[1] / h * 1000) 605 | 606 | return {"action": action, "info": info} 607 | 608 | 609 | TEXT_ANLS_THRESHOLD = 0.5 610 | CLICK_COORD_THRESHOLD = 0.14 611 | 612 | def levenshtein_distance(s1, s2): 613 | if len(s1) > len(s2): 614 | s1, s2 = s2, s1 615 | 616 | distances = range(len(s1) + 1) 617 | for i2, c2 in enumerate(s2): 618 | distances_ = [i2+1] 619 | for i1, c1 in enumerate(s1): 620 | if c1 == c2: 621 | distances_.append(distances[i1]) 622 | else: 623 | distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1]))) 624 | distances = distances_ 625 | return distances[-1] 626 | 627 | 628 | def text_matching(gt, pred): 629 | gt = gt.strip() 630 | pred = pred.strip() 631 | if gt in pred or pred in gt: 632 | return True 633 | 634 | dist = levenshtein_distance(gt, pred) 635 | length = max(len(gt), len(pred)) 636 | value = 0.0 if length == 0 else float(dist) / float(length) 637 | value = 1 - value 638 | return value >= TEXT_ANLS_THRESHOLD 639 | 640 | 641 | def click_matching(gt_info, pred_info): 642 | if type(pred_info) == str: 643 | pred_info = eval(pred_info) 644 | if 
type(gt_info) == str: 645 | gt_info = eval(gt_info) 646 | 647 | pred = np.asarray(pred_info) / 1000 648 | gt = np.asarray(gt_info) / 1000 649 | 650 | return np.linalg.norm(pred - gt) <= CLICK_COORD_THRESHOLD 651 | 652 | 653 | 654 | def action_matching(pred_action, pred_info, gt_action, gt_info): 655 | pred_action = pred_action.strip() 656 | if type(pred_info) == str: 657 | pred_info = pred_info.strip() 658 | gt_action = gt_action.strip() 659 | if type(gt_info) == str: 660 | gt_info = gt_info.strip() 661 | 662 | if pred_action != gt_action: 663 | return {'is_correct': 'no', 'info': 'action_fail'} 664 | 665 | if gt_action not in ['SCROLL', 'CLICK', 'TYPE', 'LONG_PRESS']: 666 | return {'is_correct': 'yes', 'info': 'action_correct'} 667 | 668 | elif gt_action == 'TYPE': 669 | text_flag = text_matching(gt_info, pred_info) 670 | 671 | if text_flag: 672 | return {'is_correct': 'yes', 'info': 'type_correct'} 673 | else: 674 | return {'is_correct': 'no', 'info': 'type_fail'} 675 | 676 | elif gt_action == 'SCROLL': 677 | if gt_info.lower() == pred_info.lower(): 678 | return {'is_correct': 'yes', 'info': 'scroll_correct'} 679 | else: 680 | return {'is_correct': 'no', 'info': 'scroll_fail'} 681 | 682 | elif gt_action == 'CLICK' or gt_action == 'LONG_PRESS': 683 | click_flag = click_matching(gt_info, pred_info) 684 | 685 | if click_flag: 686 | return {'is_correct': 'yes', 'info': 'click_correct'} 687 | else: 688 | return {'is_correct': 'no', 'info': 'click_fail'} 689 | 690 | else: 691 | raise ValueError('Invalid action type') 692 | 693 | def stat_result(eval_dict, metric): 694 | text_correct = sum([1 for _ in eval_dict if _['info'] == 'type_correct']) 695 | type_correct = sum([1 for _ in eval_dict if _['info'] != 'action_fail']) 696 | text_total = sum([1 for _ in eval_dict if _['info'].startswith('type_')]) 697 | 698 | if metric == 'macro': 699 | action_correct = sum([1 for _ in eval_dict if _['is_correct'] == 'yes']) 700 | AMS = round(action_correct / len(eval_dict) * 100, 2) 701 | SR_cnt, SR_tot, SR = check_SR(eval_dict) 702 | elif metric == 'micro': 703 | task_cate_dict = {} 704 | acc_list = [] 705 | SR_list = [] 706 | # print(eval_dict) 707 | for sample in eval_dict: 708 | cat = sample['more_info']['category'] 709 | if cat not in task_cate_dict: 710 | task_cate_dict[cat] = [] 711 | task_cate_dict[cat].append(sample) 712 | 713 | 714 | # assert len(task_cate_dict) == 6 #总共6个类别的数据,跑部分数据可以注释掉 715 | for k, v in task_cate_dict.items(): 716 | SR_cnt, SR_tot, SR = check_SR(v) 717 | SR_list.append((SR)) 718 | acc = round(sum([1 for x in v if x['is_correct'] == 'yes']) / len(v) * 100, 2) 719 | acc_list.append(acc) 720 | print(f'category: {k}, AMS: {acc}, SR: {SR}') 721 | 722 | AMS = np.round(np.mean(acc_list), 2) 723 | SR = np.round(np.mean(SR_list), 2) 724 | 725 | else: 726 | raise ValueError(f'No metric {metric} found.') 727 | 728 | info = { 729 | 'AMS': AMS, 730 | 'SR': SR, 731 | 'total': len(eval_dict), 732 | 'action_type': '{} / {} = {:.2f}'.format(type_correct, len(eval_dict), type_correct / len(eval_dict) * 100), 733 | 'text': '{} / {} = {:.2f}'.format(text_correct, text_total, text_correct / text_total * 100), 734 | } 735 | 736 | return info 737 | 738 | def check_SR(eval_dict): 739 | episode_dict = {} 740 | steps_map = {} 741 | for data in eval_dict: 742 | if 'img' in data: img = data['img'] 743 | elif 'image' in data: img = data['image'] 744 | else: img = data['question'].split('')[0].split('')[1] 745 | img = os.path.basename(img) 746 | tail = img.split('_')[-1] 747 | episode = 
img.replace(f'_{tail}', '') 748 | if episode not in episode_dict: 749 | episode_dict[episode] = [] 750 | else: 751 | assert steps_map[episode] == data['more_info']['step_length'] 752 | 753 | info = data['is_correct'] 754 | episode_dict[episode].append(info) 755 | steps_map[episode] = data['more_info']['step_length'] 756 | 757 | cnt, tot = 0, 0 758 | # print('=== ',episode_dict) 759 | for k, v in episode_dict.items(): 760 | if len(v) != steps_map[k]: 761 | print(f'step length of {k} does not match.') 762 | continue 763 | tot += 1 764 | v = list(set(v)) 765 | if len(v) == 1 and v[0] == 'yes': 766 | cnt += 1 767 | 768 | SR = round(cnt / tot * 100, 2) 769 | print(f'total episode: {tot}, successful episode: {cnt}, SR: {SR}') 770 | return cnt, tot, SR 771 | 772 | def odyssey_action_matching_evaluation(pred_output, metric='macro'): 773 | eval_dict = [] 774 | for idx, sample in enumerate(pred_output): 775 | 776 | question, pred, gt, more_info = sample['question'], sample['pred'], sample['gt'], sample['more_info'] 777 | # sample_eval_dict = {'question': question, 'pred': str(pred), 'gt': str(gt), 'more_info': more_info} 778 | sample_eval_dict = {'question': question, 'pred': str(pred), 'gt': str(gt), 'more_info': more_info, 'img':sample['img']} 779 | 780 | gt_simple_info = simple_decode(gt) 781 | gt_action = gt_simple_info['action'] 782 | gt_info = gt_simple_info['info'] 783 | 784 | try: 785 | pred_simple_info = simple_decode(pred, sample['img']) 786 | # print('pred_simple_info:', pred_simple_info) 787 | pred_action = pred_simple_info['action'] 788 | pred_info = pred_simple_info['info'] 789 | except: 790 | # print('### eval err:', idx, pred) 791 | log_info = {'is_correct': 'no', 'info': 'decode invalid'} 792 | sample_eval_dict.update(log_info) 793 | eval_dict.append(sample_eval_dict) 794 | continue 795 | 796 | try: 797 | check_match = action_matching(pred_action, pred_info, gt_action, gt_info) 798 | except Exception as exc: 799 | print('$$$ eval err:', gt, pred, exc) 800 | check_match = {'is_correct': 'no', 'info': 'match invalid'} 801 | 802 | sample_eval_dict.update(check_match) 803 | eval_dict.append(sample_eval_dict) 804 | 805 | # print('===== ',eval_dict) 806 | info = stat_result(eval_dict, metric) 807 | metrics = {"info": info, "pred": eval_dict} 808 | return metrics --------------------------------------------------------------------------------
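Usage note: the snippet below is a minimal, illustrative sketch (not part of the repository) of how the Odyssey-format evaluator defined above can be driven. It assumes the file is importable as `action_matching` and that its own dependencies (the `action_type` module from the android_in_the_wild reference code, jax, numpy, and Pillow) are installed; the sample predictions use TYPE and SCROLL steps so that no screenshot files need to be opened during decoding, and all file names and categories are placeholders.

from action_matching import odyssey_action_matching_evaluation

# Two steps of one hypothetical episode; 'step_length' must equal the number
# of steps recorded for that episode so that check_SR can verify completeness.
pred_output = [
    {'question': 'Search for the weather',
     'pred': 'TYPE: weather today',
     'gt': 'TYPE: weather today',
     'more_info': {'step_length': 2, 'category': 'demo'},
     'img': 'demo_episode_0.png'},
    {'question': 'Scroll the results',
     'pred': 'SCROLL: down',
     'gt': 'SCROLL: up',
     'more_info': {'step_length': 2, 'category': 'demo'},
     'img': 'demo_episode_1.png'},
]

metrics = odyssey_action_matching_evaluation(pred_output, metric='macro')
print(metrics['info'])  # AMS = 50.0 for these two steps; the single episode fails, so SR = 0.0

With metric='micro', the same entry point groups samples by more_info['category'] and averages the per-category AMS and SR instead of scoring all steps as one pool.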