├── .github ├── scripts │ └── python │ │ └── update_version.py └── workflows │ ├── publish-python.yaml │ └── run-tests.yml ├── .gitignore ├── LICENSE ├── README.md ├── app ├── Results.py ├── data │ └── latex │ │ ├── column_name_map.json │ │ ├── custom │ │ └── appendix │ │ │ ├── column_name_map.json │ │ │ └── index_name_map.json │ │ ├── hide_list.json │ │ ├── index_name_map.json │ │ ├── project_name_map.json │ │ └── shortcut_maps.json └── utils.py ├── assets ├── LlamaAndGPT.png ├── LlamaAndGPTAndMindAct.png ├── WebLINXTestSplits.png ├── WebLlamaLogo.png └── llama-3.jpg ├── docs ├── CONTRIBUTING.md └── README.md ├── examples ├── README.md ├── browsergym │ ├── agent.py │ └── run_bg.py ├── complete │ └── run_all.py └── web_api │ ├── run_client.py │ └── run_http.py ├── modeling ├── README.md ├── dmr │ ├── conf │ │ └── config.yaml │ ├── eval.py │ ├── processing.py │ └── train.py ├── llama │ ├── accelerate │ │ ├── fsdp_2gpus.yaml │ │ ├── fsdp_4gpus.yaml │ │ ├── fsdp_6gpus.yaml │ │ └── fsdp_8gpus.yaml │ ├── conf │ │ └── config.yaml │ ├── eval.py │ ├── processing.py │ └── train.py └── requirements.txt ├── requirements-basic.txt ├── requirements-extra.txt ├── setup.py ├── tests ├── requirements.txt └── test_web_turn_processor.py └── webllama ├── __init__.py ├── experimental ├── __init__.py ├── classes.py ├── formatting.py ├── functions.py ├── integrations │ ├── __init__.py │ └── browsergym │ │ ├── __init__.py │ │ └── functions.py ├── processing.py ├── templates │ ├── __init__.py │ └── weblinx.py └── web │ ├── __init__.py │ ├── client.py │ └── server.py └── version.py /.github/scripts/python/update_version.py: -------------------------------------------------------------------------------- 1 | """ 2 | This CLI script is used to update the version of the package. It is used by the 3 | CI/CD pipeline to update the version of the package when a new release is made. 4 | 5 | It uses argparse to parse the command line arguments, which are the new version 6 | and the path to the package's __init__.py file. 7 | """ 8 | 9 | import argparse 10 | from pathlib import Path 11 | 12 | def main(): 13 | parser = argparse.ArgumentParser( 14 | description="Update the version of the package." 
15 | ) 16 | parser.add_argument( 17 | "--version", 18 | type=str, 19 | help="The new version of the package.", 20 | required=True, 21 | ) 22 | parser.add_argument( 23 | "--path", 24 | type=Path, 25 | help="The path to the package's version file.", 26 | ) 27 | args = parser.parse_args() 28 | 29 | with open(args.path, "w") as f: 30 | f.write(f"__version__ = \"{args.version}\"") 31 | 32 | 33 | if __name__ == "__main__": 34 | main() -------------------------------------------------------------------------------- /.github/workflows/publish-python.yaml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Publish Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | bump-version-and-publish: 12 | name: Bump version and upload release to PyPI 13 | 14 | runs-on: ubuntu-latest 15 | permissions: 16 | # IMPORTANT: this permission is mandatory for trusted publishing 17 | id-token: write 18 | 19 | environment: 20 | name: pypi 21 | url: https://pypi.org/p/webllama 22 | 23 | steps: 24 | - uses: actions/checkout@v2 25 | - name: Set up Python 26 | uses: actions/setup-python@v2 27 | with: 28 | python-version: '3.10' 29 | 30 | - name: Update version.py with release tag 31 | env: 32 | RELEASE_TAG: ${{ github.event.release.tag_name }} 33 | run: | 34 | python .github/scripts/python/update_version.py --version $RELEASE_TAG --path "webllama/version.py" 35 | 36 | - name: Install dependencies 37 | run: | 38 | python -m pip install --upgrade pip 39 | pip install setuptools wheel twine 40 | 41 | - name: Build package 42 | run: | 43 | python setup.py sdist bdist_wheel 44 | 45 | - name: Publish package distributions to PyPI 46 | uses: pypa/gh-action-pypi-publish@release/v1 -------------------------------------------------------------------------------- /.github/workflows/run-tests.yml: -------------------------------------------------------------------------------- 1 | name: Run Tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: Check out repository 15 | uses: actions/checkout@v2 16 | 17 | - name: Set up Python 18 | uses: actions/setup-python@v3 19 | with: 20 | python-version: '3.9' # Specify your required Python version 21 | 22 | - name: Cache Python dependencies 23 | uses: actions/cache@v2 24 | with: 25 | path: ~/.cache/pip 26 | key: ${{ runner.os }}-pip-${{ hashFiles('tests/requirements.txt') }} 27 | restore-keys: | 28 | ${{ runner.os }}-pip- 29 | 30 | - name: Install dependencies 31 | run: | 32 | python -m pip install --upgrade pip 33 | pip install -r tests/requirements.txt # Assumes you have a requirements.txt file 34 | 35 | - name: Cache test assets 36 | uses: actions/cache@v2 37 | with: 38 | path: tests/demonstrations 39 | key: assets-${{ github.sha }} 40 | restore-keys: | 41 | assets- 42 | 43 | - name: Download test demos from release URL into `tests/demonstrations` 44 | run: | 45 | mkdir -p tests/demonstrations 46 | curl -L -o tests/demonstrations/aaabtsd.zip https://github.com/McGill-NLP/weblinx/releases/download/tests-assets/aaabtsd.zip 47 | unzip -u tests/demonstrations/aaabtsd.zip -d tests/demonstrations 48 | curl -L -o tests/demonstrations/aajfwoq.zip 
https://github.com/McGill-NLP/weblinx/releases/download/tests-assets/aajfwoq.zip 49 | unzip -u tests/demonstrations/aajfwoq.zip -d tests/demonstrations 50 | 51 | - name: Run tests 52 | run: | 53 | python -m unittest discover -s tests -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # CUSTOM 163 | modeling/checkpoints 164 | modeling/results/ 165 | modeling/results/**/hydra_path.txt 166 | modeling/results/**/hashes.json 167 | modeling/results/**/scores-fta-1.csv 168 | modeling/results/**/results.json 169 | modeling/results/**/eval_scores.csv 170 | modeling/results/dmr/**/scores.jsonl 171 | modeling/wl_data 172 | app/data/inputs.json 173 | modeling/logs/ 174 | venv*/ 175 | .python-version 176 | 177 | # TESTS 178 | tests/demonstrations -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 McGill NLP 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |

🖥️ WebLlama🦙

4 | 5 | Building agents that can browse the web by following instructions and talking to you 6 | 7 | | 💻 [**GitHub**](https://github.com/McGill-NLP/webllama) | 🏠 [**Homepage**](https://webllama.github.io) | 🤗 [**`Llama-3-8B-Web`**](https://huggingface.co/McGill-NLP/Llama-3-8B-Web) | 8 | | :--: | :--: | :--: | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | | `WebLlama` helps you build powerful agents, powered by Meta Llama 3, for browsing the web on your behalf | Our first model, [`Llama-3-8B-Web`](https://huggingface.co/McGill-NLP/Llama-3-8B-Web), surpasses GPT-4V (`*`zero-shot) by 18% on [`WebLINX 1.0`](https://mcgill-nlp.github.io/weblinx/) | 17 | |:---: | :---: | 18 | | ![Built with Meta Llama 3](assets/llama-3.jpg) | ![Comparison with GPT-4V](assets/LlamaAndGPT.png) | 19 | 20 | ## About the project 21 | 22 | | `WebLlama` | The goal of our project is to build effective human-centric agents for browsing the web. We don't want to replace users, but to equip them with powerful assistants. | 23 | |:---: | :---| 24 | | Modeling | We build on top of cutting-edge libraries for training Llama agents on web navigation tasks. We will provide training scripts, optimized configs, and instructions for training cutting-edge Llamas. | 25 | | Evaluation | Benchmarks for testing Llama models on real-world web browsing. This includes *human-centric* browsing through dialogue ([`WebLINX 1.0`](https://mcgill-nlp.github.io/weblinx/)), and we will soon add more benchmarks for automatic web navigation (e.g. Mind2Web). | 26 | | Data | Our first model is finetuned on over 24K instances of web interactions, including `click`, `textinput`, `submit`, and dialogue acts. We want to continuously curate, compile and release datasets for training better agents. | 27 | | Deployment | We want to make it easy to integrate Llama models with existing deployment platforms, including Playwright, Selenium, and BrowserGym. We are currently focusing on making this a reality. | 28 | 29 | 
31 | Click to show citation
32 | 33 | If you use `WebLlama` in your research, you can cite the ICML 2024 paper on which the training and evaluation are originally based, by adding the following to your BibTeX file: 34 | 35 | ``` 36 | @misc{lu_2024_weblinx, 37 | title={WebLINX: Real-World Website Navigation with Multi-Turn Dialogue}, 38 | author={Xing Han Lù and Zdeněk Kasner and Siva Reddy}, 39 | year={2024}, 40 | eprint={2402.05930}, 41 | archivePrefix={arXiv}, 42 | primaryClass={cs.CL} 43 | } 44 | ``` 45 | 46 | Example usage (in LaTeX): 47 | 48 | ``` 49 | We use the WebLlama library, which builds on top of WebLINX \citep{lu_2024_weblinx}. 50 | ``` 51 | 52 | ``` 53 | We use Llama-3-8B-Web, a model finetuned on WebLINX demonstrations \citep{lu_2024_weblinx}. 54 | ``` 55 | 56 | 
57 | 58 | ## Modeling 59 | 60 | > [!NOTE] 61 | > The model is available on the 🤗 Hugging Face Model Hub as [`McGill-NLP/Llama-3-8B-Web`](https://huggingface.co/McGill-NLP/Llama-3-8B-Web). The training and evaluation data is available on [Hugging Face Hub as `McGill-NLP/WebLINX`](https://huggingface.co/datasets/McGill-NLP/WebLINX). 62 | 63 | Our first agent is a finetuned [`Meta-Llama-3-8B-Instruct`](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) model, which was recently released by the Meta GenAI team. We have finetuned this model on the [`WebLINX 1.0`](https://mcgill-nlp.github.io/weblinx/) dataset, which contains over 100K instances of web navigation and dialogue, each collected and verified by expert annotators. We use a curated subset of 24K instances for training. 64 | 65 | ![Comparison of Llama-3-Web, GPT-4V, GPT-3.5 and MindAct](assets/LlamaAndGPTAndMindAct.png) 66 | 67 | **It surpasses GPT-4V (zero-shot `*`) by over 18% on the [`WebLINX 1.0`](https://mcgill-nlp.github.io/weblinx/) benchmark**, achieving an overall score of 28.8% on the out-of-domain test splits (compared to 10.5% for GPT-4V). It chooses more useful links (34.1% vs 18.9% *seg-F1*), clicks on more relevant elements (27.1% vs 13.6% *IoU*) and formulates more aligned responses (37.5% vs 3.1% *chr-F1*). 68 | 69 | It's extremely straightforward to use the model via Hugging Face's `transformers`, `datasets` and `huggingface_hub` libraries: 70 | 71 | ```python 72 | from datasets import load_dataset 73 | from huggingface_hub import snapshot_download 74 | from transformers import pipeline 75 | 76 | # We use validation data, but you can use your own data here 77 | valid = load_dataset("McGill-NLP/WebLINX", split="validation") 78 | snapshot_download("McGill-NLP/WebLINX", repo_type="dataset", allow_patterns="templates/*", local_dir="./")  # download templates into the working directory 79 | template = open('templates/llama.txt').read() 80 | 81 | # Run the agent on a single state (text representation) and get the action 82 | state = template.format(**valid[0]) 83 | agent = pipeline(model="McGill-NLP/Llama-3-8B-Web") 84 | out = agent(state, return_full_text=False)[0] 85 | print("Action:", out['generated_text']) 86 | 87 | # Here, you can use the predictions on platforms like Playwright or BrowserGym 88 | action = process_pred(out['generated_text']) # implement based on your platform 89 | env.step(action) # execute the action in your environment 90 | ``` 91 | 92 | ## Evaluation 93 | 94 | We believe short demo videos showing how well an agent performs are NOT enough to judge an agent. Simply put, **we do not know if we have a good agent if we do not have good benchmarks.** We need to systematically evaluate agents on a wide range of tasks, ranging from simple instruction-following web navigation to complex dialogue-guided browsing. 95 | 96 | 97 | 98 | This is why we chose [`WebLINX`](https://mcgill-nlp.github.io/weblinx/) as our first benchmark. In addition to the training split, the benchmark has 4 real-world splits, with the goal of testing multiple dimensions of generalization: new websites, new domains, unseen geographic locations, and scenarios where the *user cannot see the screen and relies on dialogue*. It also covers 150 websites, including booking, shopping, writing, knowledge lookup, and even complex tasks like manipulating spreadsheets. 
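If you would like to peek at the benchmark data before running the full evaluation, here is a minimal sketch using the Hugging Face `datasets` library and the same `McGill-NLP/WebLINX` dataset repository shown above (the exact split names are listed on the dataset page):

```python
from datasets import get_dataset_split_names, load_dataset

# List the available splits (training, validation, and the out-of-domain test splits)
print(get_dataset_split_names("McGill-NLP/WebLINX"))

# Load one split and inspect the fields available for each turn;
# these are the fields consumed by the prompt template in the Modeling section
valid = load_dataset("McGill-NLP/WebLINX", split="validation")
print(valid[0].keys())
```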
Evaluating on this benchmark is very straightforward: 99 | 100 | ```bash 101 | cd modeling/ 102 | 103 | # After installing dependencies, downloading the dataset, and training/evaluating your model, you can evaluate: 104 | python -m weblinx.eval # automatically find all `results.jsonl` and generate an `aggregated_results.json` file 105 | 106 | # Visualize your results with our app: 107 | cd .. 108 | streamlit run app/Results.py 109 | ``` 110 | 111 | > 👷‍♀️ **Next steps**\ 112 | > We are planning to evaluate our models on more benchmarks, including Mind2Web, a benchmark for automatic web navigation. We believe that a good agent should be able to navigate the web both through dialogue and autonomously, and potentially attain even broader ranges of capabilities useful for real-world web browsing. 113 | 114 | 115 | ## Data 116 | 117 | Although the 24K training examples from [`WebLINX 1.0`](https://mcgill-nlp.github.io/weblinx/) provide a good starting point for training a capable agent, we believe that more data is needed to train agents that can generalize to a wide range of web navigation tasks. While the model has been trained and evaluated on 150 websites, there are millions of websites that have never been seen by the model, with new ones being created every day. 118 | 119 | **This motivates us to continuously curate, compile and release datasets for training better agents.** As an immediate next step, we will be incorporating `Mind2Web`'s training data into the equation, which also covers over 100 websites. 120 | 121 | > [!NOTE] 122 | > WebLINX is now available as a benchmark through [BrowserGym](https://github.com/ServiceNow/BrowserGym), allowing you to access demonstration steps in the same way you would access a web agent environment like [WebArena](https://webarena.dev/) or [MiniWoB](https://miniwob.farama.org/index.html). This also allows you to run agents from the [AgentLab](https://github.com/ServiceNow/AgentLab) library, including agents that achieve SOTA performance with Claude-3.5-Sonnet. To enable this integration, we are releasing the `weblinx-browsergym` extension for BrowserGym on PyPI, as well as a [new dataset, WebLINX 1.1, derived from WebLINX on Hugging Face](https://huggingface.co/datasets/McGill-NLP/weblinx-browsergym). In WebLINX 1.1, a small number of demonstrations were removed after processing, but no new demonstrations were added. There are substantial changes to the steps being evaluated, with the inclusion of tab actions. Please report your results as "WebLINX-1.1", "WebLINX-BrowserGym" or "WebLINX-BG" in your work, to differentiate from the [initial release of WebLINX (1.0)](https://huggingface.co/datasets/McGill-NLP/WebLINX/tree/v1.0). 123 | 124 | 125 | ## Deployment 126 | 127 | We are working hard to make it easy for you to deploy Llama web agents to the web. We want to integrate `WebLlama` with existing deployment platforms, including Microsoft's Playwright, ServiceNow Research's BrowserGym, and other partners. 128 | 129 | At the moment, we offer the following integrations: 130 | * `Browsergym`: Please find more information in [`examples/README.md`](examples/README.md) and [`docs/README.md`](docs/README.md). 131 | 132 | ## Code 133 | 134 | The code for finetuning the model and evaluating it on the [`WebLINX` 1.0](https://mcgill-nlp.github.io/weblinx/) benchmark is available now. 135 | * **Modeling**: You can find the detailed instructions in [modeling](modeling/README.md) for training `Llama-3-8B-Web` on the `WebLINX` 1.0 dataset. 
136 | * **Examples**: We provide a few examples of using the `webllama` API and models, including the web API, end-to-end usage, and BrowserGym integration. You can find them in [examples](examples/README.md). 137 | * **App**: We provide a simple Streamlit app for visualizing the results of your model on the `WebLINX` 1.0 benchmark. You can find the code in [app](app/Results.py). 138 | * **Docs**: We provide detailed documentation for the code in [docs](docs/README.md). 139 | 140 | 141 | > 👷‍♀️ **Next steps**\ 142 | > We are actively working on new data and evaluation at the moment! If you want to help, please create an issue describing what you would like to contribute, and we will be happy to help you get started. 143 | 144 | 145 | ## License 146 | 147 | The code in this repository is licensed under the MIT license, unless otherwise specified in the header of the file. Other materials (models, data, images) have their own licenses, which are specified on their original pages. 148 | 149 | ## FAQ 150 | 151 | ### How can I contribute to the project? 152 | 153 | We are actively looking for collaborators to help us build the best Llama-3 web agents! To get started, open an issue about what you would like to contribute, and once it has been discussed, you can submit a pull request. 154 | 155 | 156 | ## Citation 157 | 158 | If you use `WebLlama` in your research, you can cite the ICML 2024 paper on which the training and evaluation are originally based, by adding the following to your BibTeX file: 159 | 160 | ``` 161 | @misc{lu_2024_weblinx, 162 | title={WebLINX: Real-World Website Navigation with Multi-Turn Dialogue}, 163 | author={Xing Han Lù and Zdeněk Kasner and Siva Reddy}, 164 | year={2024}, 165 | eprint={2402.05930}, 166 | archivePrefix={arXiv}, 167 | primaryClass={cs.CL} 168 | } 169 | ``` 170 | 171 | Example usage (in LaTeX): 172 | 173 | ``` 174 | We use the WebLlama library, which builds on top of WebLINX \citep{lu_2024_weblinx}. 175 | ``` 176 | 177 | ``` 178 | We use Llama-3-8B-Web, a model finetuned on WebLINX demonstrations \citep{lu_2024_weblinx}. 
179 | ``` 180 | -------------------------------------------------------------------------------- /app/Results.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from collections import defaultdict 3 | import os 4 | import time 5 | from datetime import datetime 6 | import json 7 | import random 8 | import string 9 | import shutil 10 | import traceback 11 | import sys 12 | from pathlib import Path 13 | import textwrap as tw 14 | 15 | import streamlit as st 16 | from PIL import Image, ImageDraw 17 | import pandas as pd 18 | 19 | import weblinx as wt 20 | 21 | 22 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 23 | 24 | from app.utils import show_overlay 25 | 26 | def remove_latex(x): 27 | if isinstance(x, tuple): 28 | return tuple(map(remove_latex, x)) 29 | 30 | if not isinstance(x, str): 31 | return x 32 | 33 | if x.startswith("Unnamed"): 34 | return "" 35 | 36 | if "}" not in x: 37 | return x 38 | 39 | return x.rpartition("}")[0].partition("{")[2] 40 | 41 | 42 | def load_and_clean_df(path): 43 | df = pd.read_csv(path, index_col=0, header=[0, 1]) 44 | 45 | df.index.name = "model" 46 | df.columns = df.columns.map(remove_latex) 47 | df.index = df.index.map(remove_latex) 48 | 49 | df = df.reset_index().set_index(["model", "intent"]) 50 | 51 | return df 52 | 53 | 54 | def build_cond_df(base_dir="analysis/data/tables/"): 55 | base_dir = Path(base_dir).resolve() 56 | tables_dir = base_dir / "results" 57 | 58 | cond_df = pd.concat( 59 | axis=1, 60 | objs=[ 61 | load_and_clean_df(tables_dir / "results_general_grouped_intents.csv"), 62 | load_and_clean_df(tables_dir / "results_text_grouped_intents.csv"), 63 | load_and_clean_df(tables_dir / "results_elem_grouped_intents.csv"), 64 | ], 65 | ) 66 | 67 | return cond_df 68 | 69 | 70 | def build_uncond_df(base_dir="analysis/data/tables/"): 71 | base_dir = Path(base_dir).resolve() 72 | tables_dir = base_dir / "results_unconditional" 73 | 74 | uncond_df = pd.concat( 75 | axis=1, 76 | objs=[ 77 | load_and_clean_df( 78 | base_dir / "results" / "results_general_grouped_intents.csv" 79 | ), 80 | load_and_clean_df(tables_dir / "results_text_grouped_intents.csv"), 81 | load_and_clean_df(tables_dir / "results_elem_grouped_intents.csv"), 82 | ], 83 | ) 84 | 85 | return uncond_df 86 | 87 | 88 | @st.cache_data(ttl=60 * 2) 89 | def build_dataframe(score_path, choice): 90 | with open(score_path) as f: 91 | scores = json.load(f) 92 | 93 | for score in scores: 94 | replacements = [ 95 | ("website", "test-web"), 96 | ("blind", "test-vis"), 97 | ("subcategory", "test-cat"), 98 | ("geography", "test-geo"), 99 | ("dev", "dev-deprecated"), 100 | ("indomain", "test-indomain"), 101 | ] 102 | for original, new in replacements: 103 | score["split"] = score["split"].replace(original, new) 104 | df = pd.DataFrame(scores) 105 | 106 | if choice != "Conditional": 107 | df["score"] = df["unconditional_score"] 108 | df.pop("unconditional_score") 109 | 110 | dff = df.pivot( 111 | index=["intent", "project_name", "model_name"], 112 | columns=["split", "metric"], 113 | values="score", 114 | ) 115 | 116 | return dff 117 | 118 | 119 | def add_test_avg_inplace(df, fillna_with=0): 120 | splits = df.columns.get_level_values(0).unique().tolist() 121 | test_splits = [x for x in splits if x.startswith("test") and not x.endswith("iid")] 122 | metrics = df.columns.get_level_values(1).unique().tolist() 123 | # We need to take the mean of the test splits for each metric in metrics, we call this test-avg 124 | for metric in 
metrics: 125 | test_scores = [df[(split, metric)] for split in test_splits] 126 | test_df = pd.concat(test_scores, axis=1) 127 | if fillna_with is not None: 128 | test_df = test_df.fillna(fillna_with) 129 | 130 | df[("test-avg", metric)] = test_df.mean(axis=1) 131 | 132 | 133 | def preset_to_values(): 134 | return { 135 | "All approximate": { 136 | "intent": [ 137 | "overall", 138 | # "change", 139 | "click", 140 | "load", 141 | "say", 142 | # "scroll", 143 | "submit", 144 | "textinput", 145 | ], 146 | "metric": ["overall", "iou", "chrf", "urlf"], 147 | }, 148 | "Group Approximate": { 149 | "intent": [ 150 | "overall", 151 | "text-group", 152 | "element-group", 153 | ], 154 | "metric": ["overall", "intent-match", "iou", "chrf-urlf"], 155 | }, 156 | "All intent-match": { 157 | "intent": [ 158 | "change", 159 | "click", 160 | "load", 161 | "say", 162 | "scroll", 163 | "submit", 164 | "textinput", 165 | ], 166 | "metric": ["intent-match"], 167 | }, 168 | "change": { 169 | "intent": ["change"], 170 | "metric": ["intent-match", "iou"], 171 | }, 172 | "click": { 173 | "intent": ["click"], 174 | "metric": ["intent-match", "iou"], 175 | }, 176 | "say": { 177 | "intent": ["say"], 178 | "metric": ["intent-match", "chrf"], 179 | }, 180 | "textinput": { 181 | "intent": ["textinput"], 182 | "metric": ["intent-match", "iou", "chrf"], 183 | }, 184 | "load": { 185 | "intent": ["load"], 186 | "metric": ["intent-match", "urlf"], 187 | }, 188 | "submit": { 189 | "intent": ["submit"], 190 | "metric": ["intent-match", "iou"], 191 | }, 192 | 193 | } 194 | 195 | def latex_sort_func(name): 196 | if not (name.endswith('B') or name.endswith("M")): 197 | return name, 0 198 | 199 | left, sep, right = name.rpartition("-") 200 | 201 | num = float(right[:-1]) 202 | 203 | if right.endswith("B"): 204 | rest = 1e9 205 | elif right.endswith("M"): 206 | rest = 1e6 207 | else: 208 | rest = 1 209 | 210 | num = num * rest 211 | 212 | if left.startswith("MindAct"): 213 | left = 0 214 | elif left.startswith("Flan"): 215 | left = 1 216 | elif left.startswith("Pix2Struct"): 217 | left = 2 218 | elif left.startswith("Fuyu"): 219 | left = 3 220 | elif left.startswith("Sheared"): 221 | left = 4 222 | elif left.startswith("Llama"): 223 | left = 5 224 | elif left.startswith("GPT"): 225 | left = 6 226 | 227 | return left, num 228 | 229 | 230 | 231 | 232 | @st.cache_data(ttl=60 * 2) 233 | def filter_models_by_project(projects, df): 234 | # reset all indices except project_name 235 | df = df.copy().reset_index().set_index("project_name") 236 | # filter by project 237 | rem_models = df.loc[projects]["model_name"].unique().tolist() 238 | return rem_models 239 | 240 | 241 | def run(score_path="modeling/results/aggregated_scores.json"): 242 | st.title("Results Table Viewer") 243 | 244 | presets = preset_to_values() 245 | 246 | # Either choose cond or uncond 247 | with st.sidebar: 248 | use_two_cols = st.checkbox( 249 | "Use two columns", value=True, help="Use two columns for the dropdowns" 250 | ) 251 | 252 | pivot_intent_index = st.checkbox( 253 | "Show intent as column", 254 | value=True, 255 | help="Whether to show intent as column, or keep it as index", 256 | ) 257 | 258 | choice = st.radio( 259 | "Results wrt matched intent", 260 | ["Conditional", "Unconditional"], 261 | help=( 262 | "Conditional: only count samples where the predicted intent matches the reference " 263 | "intent (when there is no match, the sample is discarded)" 264 | "Unconditional: counts all samples (when there is no match, the score is set to 0)" 265 | ), 266 | 
index=1, 267 | ) 268 | 269 | preset_choice = st.selectbox("Metric/Intent Preset", list(presets.keys()), index=0) 270 | 271 | remove_na = st.checkbox( 272 | "Drop cols with only NaN", value=True 273 | ) 274 | 275 | remove_zero = st.checkbox( 276 | "Drop cols with only 0", value=True 277 | ) 278 | 279 | if use_two_cols: 280 | col1, col2 = st.columns(2) 281 | else: 282 | col1 = col2 = st.columns(1)[0] 283 | 284 | 285 | df = build_dataframe(score_path, choice) 286 | 287 | add_test_avg_inplace(df) 288 | 289 | splits = df.columns.get_level_values("split").unique().tolist().copy() 290 | metrics = df.columns.get_level_values("metric").unique().tolist() 291 | models = df.index.get_level_values("model_name").unique().tolist() 292 | intents = df.index.get_level_values("intent").unique().tolist() 293 | projects = df.index.get_level_values("project_name").unique().tolist() 294 | 295 | default_splits = ["valid"] 296 | default_intents = presets[preset_choice]["intent"] 297 | default_metrics = presets[preset_choice]["metric"] 298 | 299 | default_projects = ["llama_ft"] 300 | 301 | splits = col1.multiselect("Split", splits, default=default_splits) 302 | metrics = col1.multiselect("Metric", metrics, default=default_metrics) 303 | intents = col2.multiselect("Intent", intents, default=default_intents) 304 | sort_by_container = col2.container() 305 | projects = col1.multiselect("Project", projects, default=default_projects) 306 | 307 | remaining_models = filter_models_by_project(projects=projects, df=df) 308 | models = col2.multiselect("Model", remaining_models, default=remaining_models) 309 | 310 | if len(projects) == 0: 311 | st.error("Please select at least one project") 312 | st.stop() 313 | 314 | if len(models) == 0: 315 | st.error("Please select at least one model") 316 | st.stop() 317 | 318 | cols = pd.MultiIndex.from_product([splits, metrics], names=["split", "metric"]) 319 | # remove all cols not in dff 320 | cols = cols.intersection(df.columns) 321 | 322 | idx = pd.MultiIndex.from_product( 323 | [intents, projects, models], names=["intent", "project_name", "model_name"] 324 | ) 325 | # remove all idx not in dff 326 | idx = idx.intersection(df.index) 327 | 328 | dff = df.loc[idx, cols] 329 | 330 | if pivot_intent_index: 331 | dff = dff.reset_index("intent").pivot(columns="intent") 332 | 333 | if remove_na: 334 | dff = dff.dropna(axis=1, how="all") 335 | if remove_zero: 336 | dff = dff.loc[:, (dff != 0).any(axis=0)] 337 | 338 | with sort_by_container: 339 | sort_by = st.selectbox("Sort by", dff.columns.tolist()) 340 | 341 | # Sort by 342 | if sort_by: 343 | dff = dff.sort_values(sort_by, ascending=False) 344 | 345 | # swap order of column indices so that we have, in order, split, intent, metric 346 | dff = dff.swaplevel(1,2, axis=1) 347 | 348 | 349 | with st.expander("Latex Table"): 350 | dropped_col_indices = st.multiselect( 351 | "Drop columns levels", dff.columns.names, default=['split'] 352 | ) 353 | use_shorthand = st.checkbox("Use shorthand", value=False) 354 | use_custom_sorting = st.checkbox("Use custom sorting", value=True) 355 | remove_column_names = st.checkbox("Remove column names", value=True) 356 | merge_index = st.checkbox("Merge index", value=False) 357 | 358 | # dropdown 359 | custom_index_names = st.selectbox( 360 | "Custom index names", ["None", "Appendix"], index=0 361 | ) 362 | 363 | custom_column_names = st.selectbox( 364 | "Custom column names", ["None", "Appendix"], index=0 365 | ) 366 | 367 | # Rename metrics to symbols for latex 368 | with 
open("app/data/latex/column_name_map.json") as f: 369 | column_name_map = json.load(f) 370 | 371 | with open("app/data/latex/index_name_map.json") as f: 372 | index_name_map = json.load(f) 373 | 374 | with open("app/data/latex/hide_list.json") as f: 375 | hide_list = json.load(f) 376 | with open("app/data/latex/shortcut_maps.json") as f: 377 | shortcut_maps = json.load(f) 378 | 379 | if custom_index_names == "Appendix": 380 | with open("app/data/latex/custom/appendix/index_name_map.json") as f: 381 | custom_index_name_map = json.load(f) 382 | # update index name map with custom names 383 | index_name_map.update(custom_index_name_map) 384 | 385 | if custom_column_names == "Appendix": 386 | with open("app/data/latex/custom/appendix/column_name_map.json") as f: 387 | custom_column_name_map = json.load(f) 388 | # update column name map with custom names 389 | column_name_map.update(custom_column_name_map) 390 | 391 | # Remove rows from dff_latex if the project_name index and model_name index are in hide_list 392 | dff_latex = dff.copy() 393 | dff_latex = dff_latex.reset_index() 394 | for project_name, model_name in hide_list: 395 | dff_latex = dff_latex[~((dff_latex["project_name"] == project_name) & (dff_latex["model_name"] == model_name))] 396 | dff_latex = dff_latex.set_index(["project_name", "model_name"]) 397 | dff_latex = dff_latex.rename(columns=column_name_map) 398 | dff_latex = dff_latex.rename(index=index_name_map) 399 | 400 | 401 | for i in dropped_col_indices: 402 | dff_latex.columns = dff_latex.columns.droplevel(i) 403 | 404 | # Convert all column index names to Capitalized 405 | dff_latex.columns.names = [x.capitalize() for x in dff_latex.columns.names] 406 | 407 | # Sort by index level 0 408 | if use_custom_sorting: 409 | dff_latex = dff_latex.sort_index( 410 | key=lambda index: index.map(latex_sort_func), ascending=True, 411 | ) 412 | else: 413 | dff_latex = dff_latex.sort_index(ascending=True) 414 | 415 | if merge_index: 416 | # Join the multiindex into a single index separated by - 417 | dff_latex.index = dff_latex.index.map(lambda x: " - ".join(x)) 418 | 419 | if remove_column_names: 420 | dff_latex.columns.names = [None] * len(dff_latex.columns.names) 421 | 422 | # should be at 4 decimal places 423 | # Multiply by 100 to get percentage 424 | dff_latex = dff_latex * 100 425 | dff_latex = dff_latex.to_latex(float_format="{:0.2f}".format) 426 | 427 | if use_shorthand: 428 | for full_value, shortcut in shortcut_maps.items(): 429 | dff_latex = dff_latex.replace(full_value, shortcut) 430 | 431 | st.code(dff_latex, language="latex") 432 | 433 | with st.expander("Markdown Table"): 434 | st.code(dff.round(4).to_markdown(), language="markdown") 435 | 436 | with st.expander("Results Table", expanded=True): 437 | st.table(dff) 438 | 439 | # Best models 440 | if not pivot_intent_index: 441 | # We need to pivot the table to get the best model per intent 442 | dff = dff.reset_index("intent").pivot(columns="intent") 443 | 444 | best_df = pd.concat([dff.idxmax(), dff.max()], axis=1) 445 | # Set name of best_df.multiindex indexes 446 | best_df.columns = ["best_model", "best_score"] 447 | 448 | best_df["project_name"] = best_df["best_model"].apply(lambda x: x[0] if isinstance(x, tuple) else x) 449 | best_df["best_model"] = best_df["best_model"].apply(lambda x: x[1] if isinstance(x, tuple) else x) 450 | # Reorder columns 451 | best_df = best_df[["project_name", "best_model", "best_score"]] 452 | 453 | with st.expander("Best Models", expanded=True): 454 | st.table(best_df) 455 | 456 | 457 | if 
__name__ == "__main__": 458 | try: 459 | st.set_page_config(layout="wide") 460 | except: 461 | pass 462 | # run = protect_with_authentication(run) 463 | run() 464 | -------------------------------------------------------------------------------- /app/data/latex/column_name_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "chrf": "chrF", 3 | "intent-match": "IM", 4 | "click": "\\texttt{click}", 5 | "submit": "\\texttt{submit}", 6 | "change": "\\texttt{change}", 7 | "textinput": "\\texttt{textinput}", 8 | "load": "\\texttt{load}", 9 | "say": "\\texttt{say}", 10 | "chrf-urlf": "SeqF", 11 | "urlf": "URLF", 12 | "iou": "IoU", 13 | "element-group": "Element Group", 14 | "text-group": "Text Group", 15 | "overall": "Overall" 16 | } -------------------------------------------------------------------------------- /app/data/latex/custom/appendix/column_name_map.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /app/data/latex/custom/appendix/index_name_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "flan_m2w": "SFT", 3 | "flan_mht_v2": "SFT" 4 | } -------------------------------------------------------------------------------- /app/data/latex/hide_list.json: -------------------------------------------------------------------------------- 1 | [ 2 | ["openai", "HuggingFaceH4/zephyr-7b-beta"], 3 | ["llama_fft_mht", "mistralai/Mistral-7B-Instruct-v0.1"], 4 | ["flan_m2w", "google/flan-t5-large"], 5 | ["flan_m2w", "google/flan-t5-base"], 6 | ["flan_mht_v2", "osunlp/MindAct_ActionPrediction_flan-t5-large"], 7 | ["flan_mht_v2", "osunlp/MindAct_ActionPrediction_flan-t5-base"] 8 | ] -------------------------------------------------------------------------------- /app/data/latex/index_name_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "flan_m2w": "M2W", 3 | "flan_mht_v2": "OTR", 4 | "openai": "0S", 5 | "llama_fft_mht": "SFT", 6 | "google/flan-t5-base": "Flan-T5-250M", 7 | "google/flan-t5-large": "Flan-T5-780M", 8 | "google/flan-t5-xl": "Flan-T5-3B", 9 | "osunlp/MindAct_ActionPrediction_flan-t5-base": "MindAct-T5-250M", 10 | "osunlp/MindAct_ActionPrediction_flan-t5-large": "MindAct-T5-780M", 11 | "osunlp/MindAct_ActionPrediction_flan-t5-xl": "MindAct-T5-3B", 12 | "princeton-nlp/Sheared-LLaMA-1.3B": "Sheared-LLaMA-1.3B", 13 | "princeton-nlp/Sheared-LLaMA-2.7B": "Sheared-LLaMA-2.7B", 14 | "meta-llama/Llama-2-7b-chat-hf": "Llama-2-7B", 15 | "meta-llama/Llama-2-13b-chat-hf": "Llama-2-13B", 16 | "gpt-3.5-turbo-1106": "GPT-3.5T", 17 | "gpt-4-1106-preview": "GPT-4T", 18 | "gpt-4-vision-preview": "GPT-4V", 19 | "google/pix2struct-base": "Pix2Struct-282M", 20 | "google/pix2struct-large": "Pix2Struct-1.3B", 21 | "adept/fuyu-8b": "Fuyu-8B", 22 | "fuyu": "SFT", 23 | "pix2struct": "SFT", 24 | "ft:gpt-3.5-turbo-1106:mcgill-nlp:webtasks-mht:8XWKFM3a": "GPT-3.5F" 25 | } -------------------------------------------------------------------------------- /app/data/latex/project_name_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "flan_m2w": "MindAct" 3 | } -------------------------------------------------------------------------------- /app/data/latex/shortcut_maps.json: -------------------------------------------------------------------------------- 1 | { 2 | "Sheared-LLaMA": "S-LLaMA", 3 | 
"MindAct-T5": "MindAct", 4 | "submit": "sbmt", 5 | "textinput": "input", 6 | "Overall": "All", 7 | "URLF": "urlF", 8 | "\\multicolumn{2}{r}{All}": "All & All", 9 | "Text Group": "TG", 10 | "Element Group": "EG" 11 | } -------------------------------------------------------------------------------- /app/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from datetime import datetime 4 | import os 5 | import json 6 | from pathlib import Path 7 | import sys 8 | import shutil 9 | import time 10 | import traceback 11 | 12 | import pandas as pd 13 | import streamlit as st 14 | 15 | import json 16 | from PIL import Image, ImageDraw 17 | 18 | 19 | CACHE_TTL = 60 * 60 * 24 * 14 20 | 21 | """ 22 | Streamlit app utilities 23 | """ 24 | 25 | 26 | @st.cache_data(ttl=CACHE_TTL) 27 | def load_json(basedir, name): 28 | if not os.path.exists(f"{basedir}/{name}.json"): 29 | return None 30 | 31 | with open(f"{basedir}/{name}.json", "r") as f: 32 | j = json.load(f) 33 | 34 | return j 35 | 36 | 37 | def load_json_no_cache(basedir, name): 38 | if not os.path.exists(f"{basedir}/{name}.json"): 39 | return None 40 | 41 | with open(f"{basedir}/{name}.json", "r") as f: 42 | j = json.load(f) 43 | 44 | return j 45 | 46 | 47 | def save_json(basedir, name, data): 48 | with open(f"{basedir}/{name}.json", "w") as f: 49 | json.dump(data, f, indent=4) 50 | 51 | 52 | @st.cache_data 53 | def load_image(image_file): 54 | img = Image.open(image_file) 55 | return img 56 | 57 | 58 | @st.cache_resource 59 | def load_page(page_path): 60 | return open(page_path, "rb") 61 | 62 | 63 | def shorten(s): 64 | # shorten to 100 characters 65 | if len(s) > 100: 66 | s = s[:100] + "..." 67 | 68 | return s 69 | 70 | 71 | @st.cache_data 72 | def parse_arguments(action): 73 | s = [] 74 | event_type = action["intent"] 75 | args = action["arguments"] 76 | 77 | if event_type == "textInput": 78 | txt = args["text"] 79 | 80 | txt = txt.strip() 81 | 82 | # escape markdown characters 83 | txt = txt.replace("_", "\\_") 84 | txt = txt.replace("*", "\\*") 85 | txt = txt.replace("`", "\\`") 86 | txt = txt.replace("$", "\\$") 87 | 88 | txt = shorten(txt) 89 | 90 | s.append(f'"{txt}"') 91 | elif event_type == "change": 92 | s.append(f'{args["value"]}') 93 | elif event_type == "load": 94 | url = args["properties"].get("url") or args.get("url") 95 | short_url = shorten(url) 96 | s.append(f'"[{short_url}]({url})"') 97 | 98 | if args["properties"].get("transitionType"): 99 | s.append(f'*{args["properties"]["transitionType"]}*') 100 | s.append(f'*{" ".join(args["properties"]["transitionQualifiers"])}*') 101 | elif event_type == "scroll": 102 | s.append(f'{args["scrollX"]}, {args["scrollY"]}') 103 | elif event_type == "say": 104 | s.append(f'"{args["text"]}"') 105 | elif event_type == "copy": 106 | selected = shorten(args["selected"]) 107 | s.append(f'"{selected}"') 108 | elif event_type == "paste": 109 | pasted = shorten(args["pasted"]) 110 | s.append(f'"{pasted}"') 111 | elif event_type == "tabcreate": 112 | s.append(f'{args["properties"]["tabId"]}') 113 | elif event_type == "tabremove": 114 | s.append(f'{args["properties"]["tabId"]}') 115 | elif event_type == "tabswitch": 116 | s.append( 117 | f'{args["properties"]["tabIdOrigin"]} -> {args["properties"]["tabId"]}' 118 | ) 119 | 120 | if args.get("element"): 121 | 122 | if event_type == 'click': 123 | x = round(args['metadata']['mouseX'], 1) 124 | y = round(args['metadata']['mouseY'], 1) 125 | uid = args.get('element', 
{}).get('attributes', {}).get("data-webtasks-id") 126 | s.append(f"*x =* {x}, *y =* {y}, *uid =* {uid}") 127 | else: 128 | top = round(args["element"]["bbox"]["top"], 1) 129 | left = round(args["element"]["bbox"]["left"], 1) 130 | right = round(args["element"]["bbox"]["right"], 1) 131 | bottom = round(args["element"]["bbox"]["bottom"], 1) 132 | 133 | s.append(f"*top =* {top}, *left =* {left}, *right =* {right}, *bottom =* {bottom}") 134 | 135 | return ", ".join(s) 136 | 137 | 138 | @st.cache_resource(max_entries=50_000, ttl=CACHE_TTL) 139 | def create_visualization(_img, event_type, bbox, x, y, screenshot_path): 140 | # screenshot_path is not used, but we need it for caching since we can't cache 141 | # PIL images (hence the leading underscore in the variable name to indicate 142 | # that it's not hashed) 143 | _img = _img.convert("RGBA") 144 | draw = ImageDraw.Draw(_img) 145 | 146 | # draw a bounding box around the element 147 | color = { 148 | "click": "red", 149 | "hover": "orange", 150 | "textInput": "blue", 151 | "change": "green", 152 | "submit": "purple", 153 | }[event_type] 154 | 155 | left = bbox["left"] 156 | top = bbox["top"] 157 | w = bbox["width"] 158 | h = bbox["height"] 159 | draw.rectangle((left, top, left + w, top + h), outline=color, width=2) 160 | 161 | if event_type in ["click", "hover"]: 162 | r = 15 163 | for i in range(1, 5): 164 | rx = r * i 165 | draw.ellipse((x - rx, y - rx, x + rx, y + rx), outline=color, width=3) 166 | draw.ellipse((x - r, y - r, x + r, y + r), fill=color) 167 | 168 | return _img 169 | 170 | 171 | @st.cache_data(max_entries=50_000, ttl=CACHE_TTL) 172 | def get_screenshot_minimal(screenshot_path, event_type, bbox, x, y, new_width=None, overlay=True): 173 | img = load_image(screenshot_path) 174 | # vis = None 175 | 176 | if event_type in ["click", "textInput", "change", "hover", "submit"] and overlay: 177 | img = create_visualization(img, event_type, bbox, x, y, screenshot_path) 178 | 179 | if new_width is not None: 180 | # Resize to 800px wide 181 | w, h = img.size 182 | new_w = new_width 183 | new_h = int(new_w * h / w) 184 | img = img.resize((new_w, new_h)) 185 | print(f"Resized '{screenshot_path}' to", new_w, new_h) 186 | 187 | return img 188 | 189 | 190 | def get_event_info(d): 191 | event_type = d["action"]["intent"] 192 | 193 | try: 194 | bbox = d["action"]["arguments"]["element"]["bbox"] 195 | except KeyError: 196 | bbox = None 197 | 198 | try: 199 | x = d["action"]["arguments"]["properties"]["x"] 200 | y = d["action"]["arguments"]["properties"]["y"] 201 | except KeyError: 202 | x = None 203 | y = None 204 | 205 | return event_type, bbox, x, y 206 | 207 | 208 | def get_screenshot(d, basedir, new_width=None, overlay=True): 209 | screenshot_filename = d["state"]["screenshot"] 210 | 211 | if not screenshot_filename: 212 | return None 213 | 214 | event_type, bbox, x, y = get_event_info(d) 215 | screenshot_path = f"{basedir}/screenshots/{screenshot_filename}" 216 | 217 | return get_screenshot_minimal( 218 | screenshot_path, event_type, bbox, x, y, new_width=new_width, overlay=overlay 219 | ) 220 | 221 | 222 | def text_bubble(text, color): 223 | text = text.replace("\n", "
").replace("\t", " " * 8) 224 | return f'
{text}
' 225 | 226 | 227 | def gather_chat_history(data, example_index): 228 | chat = [] 229 | for i, d in enumerate(data): 230 | if d["type"] == "chat": 231 | if i >= example_index: 232 | break 233 | chat.append(d) 234 | 235 | # # leave out just 5 last messages 236 | # if len(chat) > 5: 237 | # chat = chat[-5:] 238 | 239 | return reversed(chat) 240 | 241 | 242 | def format_chat_message(d): 243 | if d["speaker"] == "instructor": 244 | return text_bubble("🧑 " + d["utterance"], "rgba(63, 111, 255, 0.35)") 245 | else: 246 | return text_bubble("🤖 " + d["utterance"], "rgba(185,185,185,0.35)") 247 | 248 | 249 | def find_screenshot(data, example_index, basedir, overlay=True): 250 | # keep looking at previous screenshots until we find one 251 | # if there is none, return None 252 | 253 | for i in range(example_index, -1, -1): 254 | d = data[i] 255 | if d["type"] == "chat": 256 | continue 257 | 258 | screenshot = get_screenshot(d, basedir, overlay=overlay) 259 | if screenshot: 260 | return screenshot 261 | 262 | return None 263 | 264 | 265 | def create_visualization_2(_img, bbox, color, width, x, y): 266 | _img = _img.convert("RGBA") 267 | draw = ImageDraw.Draw(_img) 268 | 269 | if bbox: 270 | left = bbox["left"] 271 | top = bbox["top"] 272 | w = bbox["width"] 273 | h = bbox["height"] 274 | draw.rectangle((left, top, left + w, top + h), outline=color, width=width) 275 | 276 | if x and y: 277 | r = 8 278 | for i in range(1, 4): 279 | rx = r * i 280 | draw.ellipse((x - rx, y - rx, x + rx, y + rx), outline=color, width=2) 281 | draw.ellipse((x - r, y - r, x + r, y + r), fill=color) 282 | 283 | return _img 284 | 285 | 286 | def rescale_bbox(bbox, scaling_factor): 287 | return { 288 | k: bbox[k] * scaling_factor 289 | for k in ["top", "left", "width", "height", "right", "bottom"] 290 | if k in bbox 291 | } 292 | 293 | 294 | def show_overlay( 295 | _img, 296 | pred, 297 | ref, 298 | turn_args, 299 | turn_metadata, 300 | scale_pred=True, 301 | show=("pred_coords", "ref", "pred_elem"), 302 | ): 303 | scaling_factor = turn_metadata.get("zoomLevel", 1.0) 304 | 305 | if "pred_elem" in show: 306 | # First, draw red box around predicted element 307 | if pred.get("element") and pred["element"].get("bbox"): 308 | # rescale the bbox by scaling_factor 309 | bbox = rescale_bbox(pred["element"]["bbox"], scaling_factor) 310 | _img = create_visualization_2( 311 | _img, bbox, color="red", width=9, x=None, y=None 312 | ) 313 | 314 | if "ref" in show: 315 | # Finally, draw a blue box around the reference element (if it exists) 316 | if ref.get("element") and ref["element"].get("bbox"): 317 | # rescale the bbox 318 | bbox = rescale_bbox(ref["element"]["bbox"], scaling_factor) 319 | x = turn_args.get("properties", {}).get("x") 320 | y = turn_args.get("properties", {}).get("y") 321 | _img = create_visualization_2(_img, bbox, color="blue", width=6, x=x, y=y) 322 | 323 | if "pred_coords" in show: 324 | # Second draw a green box and x/y coordinate based on predicted coordinates 325 | # The predicted coordinates are the raw output of the model, 326 | # Whereas the predicted element is the inferred element from the predicted coordinates 327 | if pred["args"].get("x") and pred["args"].get("y"): 328 | x = pred["args"]["x"] 329 | y = pred["args"]["y"] 330 | 331 | if scale_pred: 332 | x = x * scaling_factor 333 | y = y * scaling_factor 334 | else: 335 | x = None 336 | y = None 337 | 338 | # If the predicted element is a bounding box, draw a green box around it 339 | if all(c in pred["args"] for c in ["top", "left", "right", "bottom"]): 340 
| bbox = { 341 | "top": pred["args"]["top"], 342 | "left": pred["args"]["left"], 343 | "width": (pred["args"]["right"] - pred["args"]["left"]), 344 | "height": (pred["args"]["bottom"] - pred["args"]["top"]), 345 | "right": pred["args"]["right"], 346 | "bottom": pred["args"]["bottom"], 347 | } 348 | 349 | if scale_pred: 350 | bbox = rescale_bbox(bbox, scaling_factor) 351 | else: 352 | # Otherwise, do nothing 353 | bbox = None 354 | 355 | _img = create_visualization_2(_img, bbox=bbox, color="green", width=3, x=x, y=y) 356 | 357 | return _img 358 | 359 | 360 | 361 | def get_zoom_level(d): 362 | """ 363 | Get the zoom level of the page 364 | """ 365 | 366 | # If it's type chat, we just set the zoom level to 1 and ignore 367 | if d["type"] == "chat": 368 | return 100 369 | 370 | # the zoom level is in the state of the turn 371 | # d is the turn 372 | # the zoom level is in d['state']['zoom'] 373 | # if it is not present, return 100 374 | option1 = ( 375 | d.get("action", {}) 376 | .get("arguments", {}) 377 | .get("properties", {}) 378 | .get("zoomLevel") 379 | ) 380 | option2 = ( 381 | d.get("action", {}) 382 | .get("arguments", {}) 383 | .get('metadata', {}) 384 | .get("zoomLevel") 385 | ) 386 | 387 | if option1 is not None: 388 | return option1 389 | elif option2 is not None: 390 | return option2 391 | else: 392 | raise ValueError("Zoom level not found in the turn.") -------------------------------------------------------------------------------- /assets/LlamaAndGPT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGill-NLP/webllama/696a7c3664fe6610b411a16d27010055874e2714/assets/LlamaAndGPT.png -------------------------------------------------------------------------------- /assets/LlamaAndGPTAndMindAct.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGill-NLP/webllama/696a7c3664fe6610b411a16d27010055874e2714/assets/LlamaAndGPTAndMindAct.png -------------------------------------------------------------------------------- /assets/WebLINXTestSplits.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGill-NLP/webllama/696a7c3664fe6610b411a16d27010055874e2714/assets/WebLINXTestSplits.png -------------------------------------------------------------------------------- /assets/WebLlamaLogo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGill-NLP/webllama/696a7c3664fe6610b411a16d27010055874e2714/assets/WebLlamaLogo.png -------------------------------------------------------------------------------- /assets/llama-3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/McGill-NLP/webllama/696a7c3664fe6610b411a16d27010055874e2714/assets/llama-3.jpg -------------------------------------------------------------------------------- /docs/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Instructions 2 | 3 | ## Running tests 4 | 5 | To run the unit tests, run: 6 | 7 | ```bash 8 | python -m unittest discover -s tests 9 | ``` -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # `webllama.experimental` API 2 | 3 | `webllama.experimental` is the new experimental API 
for working with webllama models. It will eventually be moved to `webllama` directly (once the API is deemed stable). 4 | 5 | 6 | ## Setup 7 | 8 | ```bash 9 | # Please choose the proper version to ensure you do not break the code 10 | # if there are breaking changes in the future. 11 | # e.g. 0.1.0 12 | pip install webllama=="" 13 | ``` 14 | 15 | You will need to download test demonstrations if you want to run the subsequent scripts that use existing weblinx demonstrations. 16 | 17 | ```bash 18 | mkdir -p tests/demonstrations 19 | curl -L -o tests/demonstrations/aaabtsd.zip https://github.com/McGill-NLP/weblinx/releases/download/tests-assets/aaabtsd.zip 20 | unzip -u tests/demonstrations/aaabtsd.zip -d tests/demonstrations 21 | curl -L -o tests/demonstrations/aajfwoq.zip https://github.com/McGill-NLP/weblinx/releases/download/tests-assets/aajfwoq.zip 22 | unzip -u tests/demonstrations/aajfwoq.zip -d tests/demonstrations 23 | ``` 24 | 25 | ## Quickstart with `webllama.experimental.processing` 26 | 27 | To install: 28 | ```bash 29 | pip install webllama 30 | # if you want to install transformers, pytorch and sentence-transformers, run: 31 | pip install webllama[modeling] 32 | ``` 33 | 34 | First, you will need to construct your own `action_history` and `state` using `webllama.experimental.classes`: 35 | ```python 36 | import webllama.experimental as wa 37 | 38 | # Create your action history and state! 39 | action_history = [ 40 | wa.classes.Action(...), # ... 41 | ] 42 | state = wa.classes.State(...) 43 | ``` 44 | 45 | You will also need to load your `dmr` and `act_model` models. For example, you can use `transformers` and `sentence-transformers` to load them: 46 | ```python 47 | from sentence_transformers import SentenceTransformer 48 | from transformers import AutoTokenizer, pipeline 49 | 50 | # You can choose your own DMR model, and action model 51 | act_model = pipeline(model=action_model_name, device=0, torch_dtype="auto") 52 | dmr = SentenceTransformer(dmr_name, device="cuda") 53 | ``` 54 | 55 | Now, inside a Python script, you can use the `webllama.experimental.processing` to seamlessly use `Action` and `State` with action model and DMR, and also process the output: 56 | 57 | ```python 58 | import webllama.experimental as wa 59 | 60 | # We will initialize our processor, which helps us prepare the input for action model 61 | proc = wa.processing.WebTurnProcessor(tokenizer=act_model.tokenizer) 62 | 63 | # Step 1: prepare query, run DMR and prepare retrieved candidates 64 | query_dmr = proc.prepare_dmr_query(action_history, state) 65 | elems = proc.prepare_dmr_elements(state=state) 66 | scores = wa.functions.compute_dmr_scores(dmr, query_dmr, elems) 67 | top_cands, cands_uids = wa.functions.get_top_dmr_candidates(elems, scores, proc.top_k) 68 | cands_str = proc.prepare_candidates(top_cands) 69 | 70 | # Step 2: format candidates, utterances, state, and previous actions 71 | html = proc.prepare_state_html(state.html, cands_uids=cands_uids) 72 | utterances = proc.prepare_instructor_chat(action_history, state) 73 | prev_actions = proc.prepare_prev_actions(action_history, state) 74 | 75 | # Let's use the default system prompt template, but you can also use your own 76 | sys_prompt_template: str = proc.default_system_prompt_template 77 | sys_prompt = sys_prompt_template.format( 78 | html=html, 79 | utterances=utterances, 80 | candidates=cands_str, 81 | # ... 
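    # (The remaining template fields elided above, such as num_utterances, height,
    # width and num_prev_actions, are filled in explicitly in the end-to-end example below.)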
82 | ) 83 | input_chat = proc.convert_to_chat_list(sys_prompt, prev_actions) 84 | 85 | # Use your tokenizer to convert the input to string and pass it to the action model 86 | input_str = act_model.tokenizer.apply_chat_template(input_chat, tokenize=False) 87 | output = act_model(input_str, ...) 88 | pred_action = proc.process_action_model_output(output, state.index, elems) 89 | a = wa.classes.Action.from_dict(pred_action) 90 | ``` 91 | 92 | 93 | ## End-to-end example 94 | 95 | Here's a full, self-contained example of how to use `webllama.experimental` to interact with a web page using a DMR model and an action model: 96 | 97 | ```python 98 | from functools import partial 99 | import time 100 | import logging 101 | 102 | from sentence_transformers import SentenceTransformer 103 | from transformers import AutoTokenizer, pipeline 104 | import weblinx as wl 105 | import webllama.experimental as wa 106 | 107 | logging.getLogger("urllib3").setLevel(logging.WARNING) 108 | 109 | # Prerequisite: We need a `wa.classes.State` object and a list of `wa.classes.Action` objects 110 | # To get that, we will use an example from weblinx, but it's easy to do manually (see below). 111 | 112 | demos = wl.list_demonstrations("tests/demonstrations") 113 | replay = wl.Replay.from_demonstration(demos[0]) 114 | turn = replay[26] 115 | 116 | format_intent_am = partial( 117 | wa.formatting.build_formatters_action_model(), return_as=dict 118 | ) 119 | format_intent_input_dmr, format_intent_out_dmr = wa.formatting.build_formatters_dmr() 120 | action_history = wa.functions.create_action_history_from_replay( 121 | replay, format_intent_input_dmr, format_intent_am, stop_at_idx=turn.index 122 | ) 123 | state = wa.classes.State( 124 | index=turn.index, 125 | html=turn.html, 126 | bboxes=turn.bboxes, 127 | viewport_height=turn.viewport_height, 128 | viewport_width=turn.viewport_width, 129 | type=turn.type, 130 | ) 131 | 132 | # Now, we can start! 
133 | # First, load the DMR model we will use to select candidate elements 134 | dmr_name = "McGill-NLP/MiniLM-L6-dmr" 135 | action_model_name = "McGill-NLP/Sheared-LLaMA-2.7B-weblinx" 136 | tokenizer_chat_name = "McGill-NLP/Llama-2-7b-chat-weblinx" 137 | 138 | tokenizer_chat = AutoTokenizer.from_pretrained(tokenizer_chat_name) 139 | act_model = pipeline(model=action_model_name, device=0, torch_dtype="auto") 140 | dmr = SentenceTransformer(dmr_name, device="cuda") 141 | 142 | # We will initialize our processor, which helps us prepare the input for action model 143 | proc = wa.processing.WebTurnProcessor(tokenizer=act_model.tokenizer, start_time=time.time()) 144 | 145 | # Step 1: prepare query, run DMR and prepare retrieved candidates 146 | query_dmr = proc.prepare_dmr_query(action_history, state) 147 | elems = proc.prepare_dmr_elements(state=state) 148 | scores = wa.functions.compute_dmr_scores(dmr, query_dmr, elems) 149 | top_cands, cands_uids = wa.functions.get_top_dmr_candidates(elems, scores, proc.top_k) 150 | cands_str = proc.prepare_candidates(top_cands) 151 | 152 | # Step 2: format candidates, utterances, state, and previous actions 153 | html = proc.prepare_state_html(state.html, cands_uids=cands_uids) 154 | utterances = proc.prepare_instructor_chat(action_history, state) 155 | prev_actions = proc.prepare_prev_actions(action_history, state) 156 | 157 | # Let's use the default system prompt template, but you can also use your own 158 | sys_prompt_template: str = proc.default_system_prompt_template 159 | sys_prompt = sys_prompt_template.format( 160 | html=html, 161 | num_utterances=proc.num_utterances - 1, 162 | utterances=utterances, 163 | height=state.viewport_height, 164 | width=state.viewport_width, 165 | num_prev_actions=proc.num_prev_actions, 166 | candidates=cands_str, 167 | ) 168 | input_chat = proc.convert_to_chat_list(sys_prompt, prev_actions) 169 | 170 | # We can now use the tokenizer's apply_chat_template method to convert it to a format 171 | # that can be used by the action model 172 | input_str = tokenizer_chat.apply_chat_template(input_chat, tokenize=False) 173 | 174 | # Let's now pass our input to the action model 175 | output = act_model( 176 | input_str, 177 | max_new_tokens=256, 178 | return_full_text=False, 179 | batch_size=1, 180 | pad_token_id=act_model.tokenizer.eos_token_id, 181 | ) 182 | pred_action = proc.process_action_model_output( 183 | output=output, index=state.index, elems=elems 184 | ) 185 | # optional: For certain platforms you may need to postprocess the action 186 | pred_action = wa.integrations.browsergym.postprocess_for_browsergym(pred_action) 187 | print(pred_action) 188 | # You can now convert this to an Action object and add it to the action history 189 | a = wa.classes.Action.from_dict(pred_action) 190 | action_history.append(a) 191 | ``` 192 | 193 | ## Tests 194 | 195 | To run the tests: 196 | 197 | ```bash 198 | python -m unittest discover -s tests 199 | ``` 200 | 201 | ## Web API 202 | 203 | ### Running Server 204 | 205 | To launch the default server: 206 | ```bash 207 | # If you do not want to save logs, omit `--save_logs` 208 | python -m webllama.experimental.web.server --save_logs 209 | ``` 210 | 211 | To create your own server, simply inherit from `Server`: 212 | ```python 213 | import json 214 | import logging 215 | 216 | from webllama.experimental.web.server import Server 217 | from webllama.experimental.classes import Action, State 218 | 
219 | 220 | # Initialize logging 221 | logging.basicConfig(level=logging.INFO) 222 | 223 | class MyServer(Server): 224 | # override initialize and run 225 | def initialize(self, dmr_name, action_model_name, device, dmr_device, am_device, torch_dtype): 226 | # initialize your model here 227 | ... 228 | def run(self, action_history_json, state_json): 229 | # ... 230 | pred_action = { 231 | # ... 232 | } 233 | return json.dumps(pred_action) 234 | ``` 235 | 236 | ### Connecting via SSH 237 | 238 | To connect to the server via SSH, you can use the following command: 239 | ```bash 240 | ssh -N -L 8450:localhost:8450 user@server 241 | 242 | # Example: 243 | ssh -N -L 8450:localhost:8450 nlp-gpu-2 244 | ``` 245 | 246 | ### Using API 247 | 248 | You can directly send an HTTP request to the web server, or use the client. 249 | 250 | Example of an HTTP request in Python: 251 | 252 | ```python 253 | from functools import partial 254 | import http.client 255 | import json 256 | 257 | import webllama.experimental as wa 258 | import weblinx as wl 259 | 260 | 261 | # Prerequisite: We need a `wa.classes.State` object and a list of `wa.classes.Action` objects 262 | demos = wl.list_demonstrations("tests/demonstrations") 263 | replay = wl.Replay.from_demonstration(demos[0]) 264 | turn = replay[26] 265 | 266 | format_intent_am = partial( 267 | wa.formatting.build_formatters_action_model(), return_as=dict 268 | ) 269 | format_intent_input_dmr, format_intent_out_dmr = wa.formatting.build_formatters_dmr() 270 | action_history = wa.functions.create_action_history_from_replay( 271 | replay, format_intent_input_dmr, format_intent_am, stop_at_idx=turn.index 272 | ) 273 | state = wa.classes.State( 274 | index=turn.index, 275 | html=turn.html, 276 | bboxes=turn.bboxes, 277 | viewport_height=turn.viewport_height, 278 | viewport_width=turn.viewport_width, 279 | type=turn.type, 280 | ) 281 | action_history_dict = [action.to_dict() for action in action_history] 282 | state_dict = state.to_dict() 283 | 284 | # Create a connection to localhost on the port where your server is running 285 | conn = http.client.HTTPConnection('localhost', 8450) 286 | 287 | # Prepare the POST request data 288 | post_data = json.dumps({ 289 | 'action_history': action_history_dict, 290 | 'state': state_dict 291 | }) 292 | headers = {'Content-Type': 'application/json'} 293 | # Send a POST request with JSON data 294 | conn.request("POST", "/", body=post_data, headers=headers) 295 | response = conn.getresponse() 296 | print(f"Status: {response.status}") 297 | print(f"Reason: {response.reason}") 298 | print(f"Body: {response.read().decode()}") 299 | response.close() 300 | # Close the connection 301 | conn.close() 302 | ``` 303 | 304 | ### Client 305 | 306 | A high level client is provided in `webllama.experimental.web.client`. 
You can use it as follows: 307 | 308 | ```python 309 | from functools import partial 310 | import webllama.experimental as wa 311 | import weblinx as wl 312 | 313 | # Prerequisite: We need a `wa.classes.State` object and a list of `wa.classes.Action` objects 314 | demos = wl.list_demonstrations("tests/demonstrations") 315 | replay = wl.Replay.from_demonstration(demos[0]) 316 | turn = replay[26] 317 | 318 | format_intent_am = partial( 319 | wa.formatting.build_formatters_action_model(), return_as=dict 320 | ) 321 | format_intent_input_dmr, format_intent_out_dmr = wa.formatting.build_formatters_dmr() 322 | action_history = wa.functions.create_action_history_from_replay( 323 | replay, format_intent_input_dmr, format_intent_am, stop_at_idx=turn.index 324 | ) 325 | state = wa.classes.State( 326 | index=turn.index, 327 | html=turn.html, 328 | bboxes=turn.bboxes, 329 | viewport_height=turn.viewport_height, 330 | viewport_width=turn.viewport_width, 331 | type=turn.type, 332 | ) 333 | 334 | # Now, we can start! 335 | pred_action = wa.web.client.get_prediction( 336 | action_history, state, address="localhost", port=8450, max_new_tokens=128 337 | ) 338 | pred_action = wa.integrations.browsergym.postprocess_for_browsergym(pred_action) 339 | print(pred_action) 340 | a = wa.classes.Action.from_dict(pred_action) 341 | print(a) 342 | ``` 343 | 344 | ## Building objects 345 | 346 | > Note: This section is a work in progress. 347 | 348 | ### Build `webllama.experimental.classes.Action` 349 | 350 | #### `say` action 351 | 352 | ```python 353 | utterance_instructor = wa.classes.Action( 354 | type="chat", 355 | intent="say", 356 | index=2, 357 | args=dict( 358 | speaker="instructor", utterance="Open independent ie Website.", x=None, y=None 359 | ), 360 | timestamp=13.234, 361 | tag=None, 362 | attrs=None, 363 | ) 364 | ``` 365 | 366 | #### `click` action 367 | 368 | To be added. 369 | 370 | #### `load` action 371 | 372 | To be added. 373 | 374 | #### `textinput` action 375 | 376 | To be added. 377 | 378 | #### `submit` action 379 | 380 | To be added. 381 | 382 | ### Build `webllama.experimental.classes.Bbox` 383 | 384 | To be added. 385 | 386 | ### Build `webllama.experimental.classes.State` 387 | 388 | To be added. 389 | 390 | ## Contributing 391 | 392 | For more information on contributing, please check out the [contributing docs](CONTRIBUTING.md). -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | ### Web API and client 4 | 5 | You can find examples of how to use the server directly with `http.client.HTTPConnection` and through our client in [`examples/web_api/`](/examples/web_api/), respectively with `run_http.py` and `run_client.py`. You should keep the server running for both examples. For more information, please read the Web API section of the docs. 6 | 7 | ### End-to-end 8 | 9 | You can find an end-to-end example of using `webllama.experimental` in [`examples/complete/run_all.py`](/examples/complete): 10 | 11 | ```bash 12 | python examples/complete/run_all.py 13 | ``` 14 | 15 | 16 | ### BrowserGym integration 17 | 18 | We provide direct integration with BrowserGym, along with examples of how to use it. You can find an example at [`examples/browsergym/run_bg.py`](/examples/browsergym). 19 | 20 | 21 | On the remote server (with a GPU, hosting the webllama model), run: 22 | ```bash 23 | # transformers, sentence-transformers, pytorch, etc. 
24 | pip install -e .[modeling] 25 | ``` 26 | 27 | First, remotely, run: 28 | 29 | ```bash 30 | # change if needed: 31 | export CUDA_VISIBLE_DEVICES=0 32 | 33 | python -m webllama.experimental.web.server --save_logs 34 | ``` 35 | 36 | Then, connect to your remote server via SSH: 37 | 38 | ```bash 39 | # 8450 is the default port for our server 40 | ssh -N -L 8450:localhost:8450 "@" 41 | ``` 42 | 43 | Now, on your local machine, run: 44 | 45 | ```bash 46 | pip install -e . 47 | # browsergym integration 48 | pip install "browsergym==0.3.*" 49 | # install playwright 50 | playwright install 51 | ``` 52 | 53 | ```bash 54 | python examples/browsergym/run_bg.py 55 | ``` 56 | -------------------------------------------------------------------------------- /examples/browsergym/agent.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from copy import deepcopy 3 | from functools import partial 4 | import time 5 | 6 | 7 | from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str 8 | from browsergym.core.action.highlevel import HighLevelActionSet 9 | import weblinx as wl 10 | 11 | import webllama.experimental as wa 12 | 13 | from webllama.experimental.integrations.browsergym.functions import ( 14 | say, 15 | click, 16 | textinput, 17 | load, 18 | scroll, 19 | wait, 20 | ) 21 | from webllama.experimental.integrations.browsergym import replace_bid_with_wl_uid, reverse_dict, postprocess_for_browsergym 22 | 23 | def remap_bboxes(bboxes, attrs_map): 24 | """ 25 | Cleans the bboxes dictionary by replacing the keys with the new unique ids. 26 | """ 27 | return {attrs_map[k]: v for k, v in bboxes.items()} 28 | 29 | class AgentBase(ABC): 30 | """ 31 | A template class that defines the required signature of an agent interacting with a browsergym environment. 32 | """ 33 | 34 | @abstractmethod 35 | def reset(self, seed=None) -> None: 36 | """ 37 | Resets the agent. 38 | 39 | """ 40 | pass 41 | 42 | @abstractmethod 43 | def get_action(self, obs: dict) -> str: 44 | """ 45 | Updates the agent with the current observation, and returns its next action (plus an info dict, optional). 46 | 47 | Parameters: 48 | ----------- 49 | obs: dict 50 | The current observation of the environment. 51 | """ 52 | pass 53 | 54 | def preprocess_obs(self, obs: dict) -> dict: 55 | """Default preprocessing of the observation.""" 56 | pass 57 | 58 | def get_action_mapping(self) -> callable: 59 | """ 60 | Returns a callable that can be used to map the agent actions to executable python code. 
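The base implementation returns None; subclasses such as `WebLinxAgent` below override this to return a mapping built from a browsergym `HighLevelActionSet`.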
61 | """ 62 | return None 63 | 64 | 65 | class WebLinxAgent(AgentBase): 66 | action_history = None 67 | 68 | def reset(self, seed=None) -> None: 69 | self.action_history = [] 70 | self.messages = [] 71 | self.start_time = time.time() 72 | self.has_user_message = False 73 | 74 | @property 75 | def num_messages(self): 76 | return len(self.messages) 77 | 78 | @staticmethod 79 | def get_bboxes(xprops): 80 | bboxes = {} 81 | for k in xprops: 82 | if xprops[k]["visibility"] == 1.0: 83 | bbox = dict(zip(["x", "y", "width", "height"], xprops[k]["bbox"])) 84 | # add top, left, bottom, right 85 | bbox["top"] = bbox["y"] 86 | bbox["left"] = bbox["x"] 87 | bbox["bottom"] = bbox["y"] + bbox["height"] 88 | bbox["right"] = bbox["x"] + bbox["width"] 89 | bboxes[k] = bbox 90 | 91 | return bboxes 92 | 93 | @staticmethod 94 | def infer_viewport_from_bboxes(bboxes): 95 | """ 96 | DO NOT USE THIS, THIS FUNCTION IS NOT WORKING PROPERLY 97 | """ 98 | if not bboxes: 99 | return 0, 0 100 | 101 | x = [bboxes[k]["right"] for k in bboxes] 102 | y = [bboxes[k]["bottom"] for k in bboxes] 103 | 104 | return max(x), max(y) 105 | 106 | def infer_from_screenshot(self, screenshot): 107 | h, w, _ = screenshot.shape 108 | return w, h 109 | 110 | @staticmethod 111 | def get_visible(xprops): 112 | return {k: xprops[k]["visibility"] == 1.0 for k in xprops} 113 | 114 | @staticmethod 115 | def rename_uid_attributes(dom_str, new_name="data-webtasks-id", old_name="bid"): 116 | return dom_str.replace(f"{old_name}=", f"{new_name}=") 117 | 118 | def get_action(self, obs: dict) -> str: 119 | # preprocessing 120 | obs["dom_str"] = flatten_dom_to_str(obs["dom_object"]) 121 | obs["bboxes"] = self.get_bboxes(obs["extra_element_properties"]) 122 | # obs["axtree_txt"] = flatten_axtree_to_str(obs["axtree_object"]) 123 | # obs["visible"] = self.get_visible(obs["extra_element_properties"]) 124 | 125 | vw, vh = self.infer_from_screenshot(obs["screenshot"]) 126 | obs['html_str_orig'] = self.rename_uid_attributes(obs['dom_str']) 127 | 128 | obs["html_str"], attrs_map = replace_bid_with_wl_uid(obs["dom_str"], return_mapping=True) 129 | obs["remapped_bboxes"] = remap_bboxes(obs["bboxes"], attrs_map=attrs_map) 130 | reverse_attrs_map = reverse_dict(attrs_map) 131 | 132 | # check if we have new messages in the chat (+1 will skip first default message) 133 | new_messages = obs["chat_messages"][self.num_messages + 1 :] 134 | self.messages.extend(new_messages) 135 | 136 | # update action history with new messages 137 | for message in new_messages: 138 | role = "instructor" if message["role"] == "user" else "navigator" 139 | if role == "instructor": 140 | self.has_user_message = True 141 | 142 | self.action_history.append( 143 | wa.classes.Action( 144 | type="chat", 145 | index=len(self.action_history), 146 | intent="say", 147 | args={"utterance": message["message"], "speaker": role}, 148 | timestamp=time.time() - self.start_time, 149 | tag=None, 150 | attrs=None, 151 | ) 152 | ) 153 | print(f"New message by '{role}': {message['message']}") 154 | 155 | if not self.has_user_message: 156 | # sleep and do nothing if no user message has been received 157 | return "wait(2)" 158 | 159 | state = wa.classes.State( 160 | index=len(self.action_history), 161 | html=obs["html_str"], 162 | bboxes=obs["remapped_bboxes"], 163 | viewport_height=vh, 164 | viewport_width=vw, 165 | type="browser", 166 | ) 167 | pred_action = wa.web.client.get_prediction( 168 | self.action_history, 169 | state, 170 | address="localhost", 171 | port=8450, 172 | max_new_tokens=128, 173 | ) 174 
| # breakpoint() 175 | pred_action = postprocess_for_browsergym(pred_action, uid_map=reverse_attrs_map) 176 | # pred_action = postprocess_for_browsergym(pred_action) 177 | 178 | a = wa.classes.Action.from_dict(pred_action) 179 | 180 | # add action to action history 181 | self.action_history.append(a) 182 | 183 | action_str = a.to_str() 184 | print("Action String:", action_str) 185 | 186 | return action_str 187 | 188 | def get_action_mapping(self) -> callable: 189 | """ 190 | Returns a callable that can be used to map the agent actions to executable python code. 191 | """ 192 | action_set = HighLevelActionSet( 193 | subsets="custom", 194 | custom_actions=[say, click, textinput, load, scroll, wait], 195 | multiaction=False, 196 | strict=True, 197 | ) 198 | return action_set.to_python_code 199 | -------------------------------------------------------------------------------- /examples/browsergym/run_bg.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | import browsergym.core # register the openended task as a gym environment 3 | from examples.browsergym.agent import WebLinxAgent 4 | 5 | agent = WebLinxAgent() 6 | 7 | env = gym.make( 8 | "browsergym/openended", 9 | headless=False, 10 | wait_for_user_message=False, 11 | action_mapping=agent.get_action_mapping(), 12 | task_kwargs={"start_url": "chrome://newtab"}, 13 | # task_kwargs={"start_url": "https://en.wikipedia.org"}, 14 | ) 15 | 16 | agent.reset() 17 | obs, info = env.reset() 18 | 19 | done = False 20 | while not done: 21 | action = agent.get_action(obs) 22 | obs, reward, terminated, truncated, info = env.step(action) 23 | -------------------------------------------------------------------------------- /examples/complete/run_all.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import time 3 | import logging 4 | 5 | from sentence_transformers import SentenceTransformer 6 | from transformers import AutoTokenizer, pipeline 7 | import weblinx as wl 8 | import webllama.experimental as wa 9 | 10 | logging.getLogger("urllib3").setLevel(logging.WARNING) 11 | 12 | # Prerequisite: We need a `wa.classes.State` object and a list of `wa.classes.Action` objects 13 | # To get that, we will use an example from weblinx, but it's easy to do manually (see below). 14 | 15 | demos = wl.list_demonstrations("tests/demonstrations") 16 | replay = wl.Replay.from_demonstration(demos[0]) 17 | turn = replay[26] 18 | 19 | format_intent_am = partial( 20 | wa.formatting.build_formatters_action_model(), return_as=dict 21 | ) 22 | format_intent_input_dmr, format_intent_out_dmr = wa.formatting.build_formatters_dmr() 23 | action_history = wa.functions.create_action_history_from_replay( 24 | replay, format_intent_input_dmr, format_intent_am, stop_at_idx=turn.index 25 | ) 26 | 27 | state = wa.classes.State( 28 | index=turn.index, 29 | html=turn.html, 30 | bboxes=turn.bboxes, 31 | viewport_height=turn.viewport_height, 32 | viewport_width=turn.viewport_width, 33 | type=turn.type, 34 | ) 35 | 36 | 37 | # Verifying that the to_dict and from_dict methods work as expected 38 | act = action_history[0] 39 | d = act.to_dict() 40 | act2 = wa.classes.Action.from_dict(d) 41 | assert act == act2 42 | 43 | d = state.to_dict() 44 | state2 = wa.classes.State.from_dict(d) 45 | assert state == state2 46 | 47 | 48 | # Now, we can start! 
49 | # First, load the DMR model we will use to select candidate elements 50 | dmr_name = "McGill-NLP/MiniLM-L6-dmr" 51 | action_model_name = "McGill-NLP/Sheared-LLaMA-2.7B-weblinx" 52 | tokenizer_chat_name = "McGill-NLP/Llama-2-7b-chat-weblinx" 53 | 54 | tokenizer = AutoTokenizer.from_pretrained(action_model_name) 55 | tokenizer_chat = AutoTokenizer.from_pretrained(tokenizer_chat_name) 56 | dmr = SentenceTransformer(dmr_name, device="cuda") 57 | action_model = pipeline(model=action_model_name, device=0, torch_dtype="auto") 58 | 59 | # We will initialize our processor, which helps us prepare the input for action model 60 | proc = wa.processing.WebTurnProcessor(tokenizer=tokenizer, start_time=time.time()) 61 | 62 | # Step 1: prepare query, run DMR and prepare retrieved candidates 63 | query_dmr = proc.prepare_dmr_query(action_history, state) 64 | elems = proc.prepare_dmr_elements(state=state) 65 | scores = wa.functions.compute_dmr_scores(dmr, query_dmr, elems) 66 | top_cands, cands_uids = wa.functions.get_top_dmr_candidates(elems, scores, proc.top_k) 67 | cands_str = proc.prepare_candidates(top_cands) 68 | 69 | # Step 2: format candidates, utterances, state, and previous actions 70 | html = proc.prepare_state_html(state.html, cands_uids=cands_uids) 71 | utterances = proc.prepare_instructor_chat(action_history, state) 72 | prev_actions = proc.prepare_prev_actions(action_history, state) 73 | 74 | # Let's use the default system prompt template, but you can also use your own 75 | sys_prompt_template: str = proc.default_system_prompt_template 76 | sys_prompt = sys_prompt_template.format( 77 | html=html, 78 | num_utterances=proc.num_utterances - 1, 79 | utterances=utterances, 80 | height=state.viewport_height, 81 | width=state.viewport_width, 82 | num_prev_actions=proc.num_prev_actions, 83 | candidates=cands_str, 84 | ) 85 | input_chat = proc.convert_to_chat_list(sys_prompt, prev_actions) 86 | 87 | # We can now use the tokenizer's apply_chat_template method to convert it to a format 88 | # that can be used by the action model 89 | input_str = tokenizer_chat.apply_chat_template(input_chat, tokenize=False) 90 | 91 | # Let's now pass our input to the action model 92 | output = action_model( 93 | input_str, 94 | max_new_tokens=256, 95 | return_full_text=False, 96 | batch_size=1, 97 | pad_token_id=tokenizer.eos_token_id, 98 | ) 99 | pred_action = proc.process_action_model_output( 100 | output=output, index=state.index, elems=elems 101 | ) 102 | # optional: For certain platforms you may need to postprocess the action 103 | pred_action = wa.integrations.browsergym.postprocess_for_browsergym(pred_action) 104 | print(pred_action) 105 | # You can now convert this an Action object and add it to the action history 106 | a = wa.classes.Action.from_dict(pred_action) 107 | action_history.append(a) 108 | -------------------------------------------------------------------------------- /examples/web_api/run_client.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import webllama.experimental as wa 3 | import weblinx as wl 4 | 5 | demos = wl.list_demonstrations("tests/demonstrations") 6 | replay = wl.Replay.from_demonstration(demos[0]) 7 | turn = replay[26] 8 | 9 | format_intent_am = partial( 10 | wa.formatting.build_formatters_action_model(), return_as=dict 11 | ) 12 | format_intent_input_dmr, format_intent_out_dmr = wa.formatting.build_formatters_dmr() 13 | action_history = wa.functions.create_action_history_from_replay( 14 | replay, 
format_intent_input_dmr, format_intent_am, stop_at_idx=turn.index 15 | ) 16 | state = wa.classes.State( 17 | index=turn.index, 18 | html=turn.html, 19 | bboxes=turn.bboxes, 20 | viewport_height=turn.viewport_height, 21 | viewport_width=turn.viewport_width, 22 | type=turn.type, 23 | ) 24 | 25 | pred_action = wa.web.client.get_prediction( 26 | action_history, state, address="localhost", port=8450, max_new_tokens=128 27 | ) 28 | pred_action = wa.integrations.browsergym.postprocess_for_browsergym(pred_action) 29 | print(pred_action) 30 | a = wa.classes.Action.from_dict(pred_action) 31 | print(a) 32 | -------------------------------------------------------------------------------- /examples/web_api/run_http.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import http.client 3 | import json 4 | 5 | import weblinx as wl 6 | import webllama.experimental as wa 7 | 8 | def run_http(): 9 | demos = wl.list_demonstrations("tests/demonstrations") 10 | replay = wl.Replay.from_demonstration(demos[0]) 11 | turn = replay[26] 12 | 13 | format_intent_am = partial( 14 | wa.formatting.build_formatters_action_model(), return_as=dict 15 | ) 16 | format_intent_input_dmr, format_intent_out_dmr = wa.formatting.build_formatters_dmr() 17 | action_history = wa.functions.create_action_history_from_replay( 18 | replay, format_intent_input_dmr, format_intent_am, stop_at_idx=turn.index 19 | ) 20 | state = wa.classes.State( 21 | index=turn.index, 22 | html=turn.html, 23 | bboxes=turn.bboxes, 24 | viewport_height=turn.viewport_height, 25 | viewport_width=turn.viewport_width, 26 | type=turn.type, 27 | ) 28 | action_history_dict = [action.to_dict() for action in action_history] 29 | state_dict = state.to_dict() 30 | 31 | # Create a connection to the localhost on the port where your server is running 32 | conn = http.client.HTTPConnection('localhost', 8450) 33 | 34 | # Send a request without parameters to test server response 35 | conn.request("POST", "/", body=json.dumps({}), headers={'Content-Type': 'application/json'}) 36 | response = conn.getresponse() 37 | print("Test 1 - Server Initialization Check:") 38 | print(f"Status: {response.status}") 39 | print(f"Reason: {response.reason}") 40 | print(f"Body: {response.read().decode()}\n") 41 | response.close() 42 | 43 | # Prepare the POST request data 44 | post_data = json.dumps({ 45 | 'action_history': action_history_dict, 46 | 'state': state_dict 47 | }) 48 | headers = {'Content-Type': 'application/json'} 49 | 50 | # Send a POST request with JSON data 51 | conn.request("POST", "/", body=post_data, headers=headers) 52 | response = conn.getresponse() 53 | print("Test 2 - Functionality Check:") 54 | print(f"Status: {response.status}") 55 | print(f"Reason: {response.reason}") 56 | print(f"Body: {response.read().decode()}") 57 | response.close() 58 | 59 | # Close the connection 60 | conn.close() 61 | 62 | if __name__ == "__main__": 63 | run_http() 64 | -------------------------------------------------------------------------------- /modeling/README.md: -------------------------------------------------------------------------------- 1 | ## Training 2 | 3 | First, you need to be in the `modeling` directory: 4 | 5 | ```bash 6 | cd modeling 7 | ``` 8 | 9 | ### Download Data 10 | 11 | Download the full dataset (warning: this will take a while): 12 | 13 | ```python 14 | from huggingface_hub import snapshot_download 15 | 16 | snapshot_download(repo_id="McGill-NLP/WebLINX-full", repo_type="dataset", 
local_dir="./wl_data/") 17 | ``` 18 | 19 | The default configs (`llama/conf/config.yaml`) assume that the `train.jsonl` is located at `./wl_data/candidates/train.jsonl`. If you want to change the path, you need to modify the `config.yaml` accordingly. 20 | 21 | #### Optional: Symbolic linking to `WebLINX-full` 22 | 23 | If you downloaded the `WebLINX-full` data to a different location (e.g. a different disk) from your `weblinx/modeling` directory, you might consider using a symbolic link to avoid having to change the `config.yaml` files. You can do something like: 24 | 25 | ```bash 26 | ln -s /location/of/your/full/data /location/of/project/weblinx/modeling/wl_data 27 | ``` 28 | 29 | For example, if your data is located at `/mnt/research/scratch/users/jdoe/WebLINX-full` but your cloned `weblinx` repository is at `~/dev/weblinx`, then you'd run: 30 | 31 | ```bash 32 | ln -s /mnt/research/scratch/users/jdoe/WebLINX-full ~/dev/weblinx/modeling/wl_data 33 | ``` 34 | 35 | This corresponds to the `data.base_dir` specified in `config.yaml`, which is `"${project_dir}/wl_data/demonstrations/"`. 36 | 37 | ### Set `WEBLLAMA_PROJECT_DIR` 38 | 39 | You need to set the `WEBLLAMA_PROJECT_DIR` environment variable to the path of the `modeling` directory. For example: 40 | 41 | ```bash 42 | export WEBLLAMA_PROJECT_DIR=/path/to/the/modeling/directory/ 43 | 44 | # For example, if you are in the modeling directory, you can run: 45 | export WEBLLAMA_PROJECT_DIR=$(pwd) 46 | ``` 47 | 48 | ### Install Dependencies 49 | 50 | You need to install the dependencies by running the following commands: 51 | 52 | ```bash 53 | pip install -e .[extra] 54 | pip install -r modeling/requirements.txt 55 | ``` 56 | 57 | However, due to `flash-attention` requiring `torch` to be pre-installed, it has to be installed right after everything else has been installed: 58 | ```bash 59 | # Regular install 60 | pip install "flash-attn>=2.3.0" 61 | # If you have limited RAM, you can try this: 62 | MAX_JOBS=4 pip install "flash-attn>=2.3.0" --no-build-isolation 63 | # If you have issues with nvcc, try this: 64 | FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE pip install "flash-attn>=2.3.0" --no-build-isolation 65 | ``` 66 | 67 | ### Action Model 68 | 69 | #### Train LLaMA 70 | 71 | You can train the model by running the following command (it will automatically use the hydra config from `conf/`): 72 | 73 | ```bash 74 | export CUDA_VISIBLE_DEVICES="0,1,2,3" 75 | 76 | # Train Llama-3-8B-Instruct on WebLINX 77 | accelerate launch --use_fsdp --config_file llama/accelerate/fsdp_4gpus.yaml -m llama.train 78 | 79 | # Fancy a different model? You can create your own variant (e.g. llama/conf/variant/8b_base.yaml) 80 | accelerate launch --use_fsdp --config_file llama/accelerate/fsdp_4gpus.yaml -m llama.train +variant="8b_base" 81 | ``` 82 | 83 | Results will be saved in `./results` and checkpoints in `./checkpoints`. 84 | 85 | #### Run LLaMA on Evaluation Splits 86 | 87 | You need to specify which `eval.split` you want to evaluate on. 
For example, to evaluate on the `iid` split, you can run the following command: 88 | 89 | ```bash 90 | export CUDA_VISIBLE_DEVICES="0" # Set the GPU device you want to use 91 | 92 | # Evaluating llama-3-8b-instruct on a split 93 | python -m llama.eval -m eval.split=valid 94 | 95 | # Or other datasets (using multiple splits) 96 | python -m llama.eval -m eval.split=test_iid,test_web,test_geo,test_cat,test_vis 97 | ``` 98 | 99 | #### Optional: running with screen 100 | 101 | You can run this (inside the `modeling` dir): 102 | ```bash 103 | # Choose the variant you want to evaluate 104 | var="8b" 105 | 106 | # Launch the screen in detached mode 107 | iid="CUDA_VISIBLE_DEVICES=0 ../venv/bin/python -m llama.eval -m +variant="$var" eval.split=test_iid" 108 | screen -dmS eval-llama-$var-iid bash -c "$iid; exec bash" 109 | # ... 110 | vis="CUDA_VISIBLE_DEVICES=4 ../venv/bin/python -m llama.eval -m +variant="$var" eval.split=test_vis" 111 | screen -dmS eval-llama-$var-vis bash -c "$vis; exec bash" 112 | ``` 113 | 114 | ### Evaluation 115 | 116 | To run the evaluation metrics, you can use the following command (from `modeling/`): 117 | 118 | ```bash 119 | python -m weblinx.eval -d ./results -b ./wl_data/demonstrations 120 | ``` 121 | 122 | In this case, `-b` is the base directory for the demonstrations, and `-d` is the directory containing the results (generated above by the `llama.eval` script). This will automatically run the evaluation metrics and save the results in the `results/aggregated_scores.json` file. If you are only interested in the overall score for a split (e.g. `valid`), you can look for the following entry in the aggregated score file (as an example): 123 | 124 | ```json 125 | // ... 126 | { 127 | "split": "valid", 128 | "intent": "overall", 129 | "metric": "overall", 130 | "model_name": "meta-llama/Meta-Llama-3-8B-Instruct", 131 | "project_name": "llama_ft", 132 | "score": 0.21667765869744438, 133 | "unconditional_score": 0.15307513104251605 134 | }, 135 | // ... 136 | ``` 137 | 138 | Behind the scenes, this uses the `weblinx.eval.auto_eval_and_save` function to run the evaluation metrics. If you want more control, you can also call `weblinx.eval.auto_eval_and_save` directly; for an example, check out `weblinx/eval/__main__.py`. 139 | 140 | Note that it might be slow the first time you run it, because it reads a lot of demonstrations and loads millions of files. However, a demo-level cache is automatically created (see `./.cache/demonstrations`), so the next time you run it, it should be much faster. 141 | 142 | ### Dense Markup Ranking (DMR) 143 | 144 | #### Train DMR 145 | 146 | You can train the model by running the following command (it will automatically use the hydra config from `conf/`): 147 | 148 | ```bash 149 | export CUDA_VISIBLE_DEVICES="0" # Set the GPU device you want to use 150 | 151 | # Finetune MiniLM-L6-DMR (Default) 152 | python -m dmr.train 153 | ``` 154 | 155 | Results will be saved in `./results` and checkpoints in `./checkpoints`. 156 | 157 | #### Inference for DMR 158 | 159 | You need to specify which `eval.split` you want to evaluate on. For example, to evaluate on the `iid` split, you can run the following command: 160 | 161 | ```bash 162 | export CUDA_VISIBLE_DEVICES="0" # Set the GPU device you want to use 163 | 164 | # On just one split 165 | python -m dmr.eval eval.split=valid 166 | 167 | # On multiple splits (e.g. 
test_iid, test_vis) 168 | python -m dmr.eval eval.split=test_iid,test_web,test_geo,test_cat,test_vis 169 | ``` 170 | 171 | #### Moving generated DMR results to `wl_data/candidates` 172 | 173 | The `scores.jsonl` and `results.json` files will be saved at the path given by the `cfg.eval.result_dir` variable in `modeling/dmr/conf/config.yaml`, which is by default `${project_dir}/results/${project_name}/${model.name}/${eval.split}`. This should resolve to `/path/to/weblinx/modeling/results/dmr/sentence-transformers/all-MiniLM-L6-v2/train` for the `train` split, `.../valid` for the `valid` split, etc. However, since the next steps assume you have files like `wl_data/candidates/<split>.jsonl`, you need to move them manually. For example, you could run: 174 | 175 | ```bash 176 | # Change the following paths to match your setup 177 | orig_dir="/path/to/weblinx/modeling/results/dmr/sentence-transformers/all-MiniLM-L6-v2" 178 | # This is the directory where the candidates are stored 179 | new_dir="/path/to/wl_data/candidates" 180 | 181 | # You need to move the train split if you plan to use it for training the action model 182 | mv $orig_dir/train/scores.jsonl $new_dir/train.jsonl 183 | # You can move the valid and test splits as well 184 | mv $orig_dir/valid/scores.jsonl $new_dir/valid.jsonl 185 | mv $orig_dir/test_iid/scores.jsonl $new_dir/test_iid.jsonl 186 | mv $orig_dir/test_web/scores.jsonl $new_dir/test_web.jsonl 187 | mv $orig_dir/test_geo/scores.jsonl $new_dir/test_geo.jsonl 188 | mv $orig_dir/test_cat/scores.jsonl $new_dir/test_cat.jsonl 189 | mv $orig_dir/test_vis/scores.jsonl $new_dir/test_vis.jsonl 190 | ``` 191 | 192 | Alternatively, you can also update `config.yaml` to save the results in the correct directory, by overriding `candidates`: 193 | ```yaml 194 | # ... 195 | candidates: 196 | # ... 
197 | model: "sentence-transformers/all-MiniLM-L6-v2" 198 | path: ${project_dir}/results/${project_name}/${model.name}/${eval.split} 199 | ``` 200 | 201 | -------------------------------------------------------------------------------- /modeling/dmr/conf/config.yaml: -------------------------------------------------------------------------------- 1 | project_dir: ${oc.env:WEBLINX_PROJECT_DIR} 2 | seed: 123 3 | project_name: dmr 4 | 5 | data: 6 | split_path: ${project_dir}/wl_data/splits.json 7 | base_dir: ${project_dir}/wl_data/demonstrations 8 | 9 | model: 10 | name: sentence-transformers/all-MiniLM-L6-v2 11 | max_seq_length: 512 12 | use_bf16: True 13 | similarity: cos_sim 14 | save_dir: ${project_dir}/checkpoints/${project_name}/${model.name} 15 | 16 | train: 17 | split: train 18 | num_epochs: 10 19 | max_neg_per_turn: 9 20 | batch_size_per_device: 64 21 | dataloader_num_workers: 8 22 | optim: adamw 23 | gradient_checkpointing: True 24 | learning_rate: 0.00003 25 | warmup_steps: 500 26 | # Available schedulers: 27 | # constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts 28 | scheduler: warmuplinear 29 | 30 | eval: 31 | split: dev 32 | mrr_k: 50 33 | batch_size_per_device: 64 34 | result_dir: ${project_dir}/results/${project_name}/${model.name}/${eval.split} 35 | 36 | hydra: 37 | run: 38 | dir: ${project_dir}/logs/${project_name}/${hydra.job.name}/${now:%Y-%m-%d-%H:%M:%S} 39 | # Use the same for sweep's subdir 40 | sweep: 41 | dir: ${hydra.run.dir} 42 | job: 43 | chdir: False 44 | verbose: INFO -------------------------------------------------------------------------------- /modeling/dmr/eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from pathlib import Path 4 | from typing import List, Dict, Any 5 | 6 | import hydra 7 | import numpy as np 8 | import torch 9 | from tqdm import tqdm 10 | from sentence_transformers import SentenceTransformer 11 | from sentence_transformers.util import cos_sim, dot_score 12 | import weblinx as wl 13 | from weblinx.processing import group_record_to_dict 14 | from weblinx.utils.recs import ungroup_dict_to_records 15 | from weblinx.utils.hydra import save_path_to_hydra_logs 16 | 17 | from .processing import build_records_for_single_demo, build_formatters 18 | 19 | 20 | def recall_at_k(input_records, k, label_key="label", rank_key="rank"): 21 | num_correct = 0 22 | num_total = 0 23 | 24 | for r in input_records: 25 | if r[label_key] == 1: 26 | num_total += 1 27 | if r[rank_key] <= k: 28 | num_correct += 1 29 | 30 | score = num_correct / num_total 31 | return score 32 | 33 | 34 | def mean_reciprocal_rank(input_records, label_key="label", rank_key="rank", k=None): 35 | if k is None or len(input_records) < k or k < 1: 36 | k = len(input_records) 37 | 38 | mrr = 0 39 | num_total = 0 40 | 41 | for r in input_records: 42 | if r[label_key] == 1: 43 | if r[rank_key] <= k: 44 | mrr += 1 / r[rank_key] 45 | num_total += 1 46 | 47 | mrr /= num_total 48 | 49 | return mrr 50 | 51 | 52 | def verify_queries_are_all_the_same(grouped_records: dict) -> bool: 53 | """ 54 | Given a dictionary of grouped records, this function verifies that all 55 | queries are the same within each group. 
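Returns True if every record within each group has the same query as the first record of that group, and False otherwise.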
56 | """ 57 | for k, v in grouped_records.items(): 58 | first_query = v[0]["query"] 59 | if not all(r["query"] == first_query for r in v): 60 | return False 61 | return True 62 | 63 | 64 | def run_model_and_update_groups( 65 | model, input_grouped: Dict[Any, List[dict]], batch_size, sim_method="cos_sim" 66 | ): 67 | if sim_method == "cos_sim": 68 | sim_func = cos_sim 69 | elif sim_method == "dot_product": 70 | sim_func = dot_score 71 | else: 72 | raise ValueError(f"Unknown similarity function: {sim_method}") 73 | 74 | for k, group in tqdm(input_grouped.items(), desc="Computing scores"): 75 | group = input_grouped[k] 76 | query = group[0]["query"] 77 | docs = [r["doc"] for r in group] 78 | 79 | encoded = model.encode( 80 | [query] + docs, batch_size=batch_size, show_progress_bar=False 81 | ) 82 | query_vector, doc_vectors = encoded[0], encoded[1:] 83 | scores = sim_func(query_vector, doc_vectors).cpu().squeeze().tolist() 84 | if isinstance(scores, float): 85 | scores = [scores] 86 | 87 | for i, r in enumerate(group): 88 | r["score"] = scores[i] 89 | 90 | 91 | def build_target_uids_dict(demos, uid_key="data-webtasks-id"): 92 | """ 93 | Given a list of demonstrations, build a dictionary mapping 94 | `(demo_name, turn_index) -> uid`. This is used to determine the 95 | target element for a given demo turn, which labels the element 96 | as positive or negative. 97 | """ 98 | target_uids_dict = {} 99 | for demo in tqdm(demos, desc="Creating dict of target uids"): 100 | for turn in wl.Replay.from_demonstration(demo): 101 | if turn.element is None or "attributes" not in turn.element: 102 | continue 103 | if uid_key not in turn.element["attributes"]: 104 | continue 105 | 106 | uid = turn.element["attributes"][uid_key] 107 | target_uids_dict[(demo.name, turn.index)] = uid 108 | 109 | return target_uids_dict 110 | 111 | 112 | def get_ranks_from_scores(scores: Dict[Any, float], starts_at=1) -> Dict[Any, int]: 113 | """ 114 | Given a dictionary of key -> scores, return a dictionary of key -> ranks. 115 | """ 116 | # Get sorted keys 117 | keys = sorted(scores.keys(), key=lambda k: scores[k], reverse=True) 118 | ranks = {k: i + starts_at for i, k in enumerate(keys)} 119 | 120 | return ranks 121 | 122 | 123 | @hydra.main(version_base=None, config_path="conf", config_name="config") 124 | def main(cfg): 125 | torch.manual_seed(cfg.seed) 126 | 127 | use_bf16 = cfg.model.use_bf16 128 | split = cfg.eval.split 129 | bsize = cfg.eval.batch_size_per_device 130 | 131 | split_path = Path(cfg.data.split_path).expanduser() 132 | model_save_dir = Path(cfg.model.save_dir).expanduser() 133 | result_dir = Path(cfg.eval.result_dir).expanduser() 134 | 135 | result_dir.mkdir(parents=True, exist_ok=True) 136 | 137 | if use_bf16: 138 | torch_dtype = torch.bfloat16 139 | use_amp = False 140 | else: 141 | torch_dtype = torch.float32 142 | use_amp = True 143 | 144 | # Data loading 145 | demo_names = wl.utils.load_demo_names_in_split(split_path, split=split) 146 | demos = [wl.Demonstration(demo_name, base_dir=cfg.data.base_dir) for demo_name in demo_names] 147 | 148 | format_intent_input, _ = build_formatters() 149 | input_records: List[dict] = [] 150 | logging.info(f"Number of demos: {len(demos)}. 
Starting building records.") 151 | for demo in tqdm(demos, desc="Building input records"): 152 | demo_records = build_records_for_single_demo( 153 | demo=demo, 154 | format_intent_input=format_intent_input, 155 | max_neg_per_turn=None, 156 | # For eval, we want to include all elements in the demo 157 | # not just the ones with valid uids 158 | only_allow_valid_uid=False, 159 | ) 160 | input_records.extend(demo_records) 161 | logging.info(f"Completed. Number of input records: {len(input_records)}") 162 | 163 | # Group records by (demo_name, turn_index) pairs 164 | input_grouped = group_record_to_dict( 165 | input_records, keys=["demo_name", "turn_index"], remove_keys=False 166 | ) 167 | 168 | # Verify that queries are all the same within each group 169 | error_msg = "Queries are not all the same within each group" 170 | assert verify_queries_are_all_the_same(input_grouped), error_msg 171 | 172 | # Run the model and update the scores and ranks in place 173 | logging.info("Running model and computing scores") 174 | 175 | # Run the model 176 | model = SentenceTransformer(str(model_save_dir)) 177 | sim_method = cfg.model.get("similarity", "cos_sim") 178 | 179 | logging.info(f"Using the following similarity method: {sim_method}") 180 | 181 | with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch_dtype): 182 | run_model_and_update_groups( 183 | model, input_grouped=input_grouped, batch_size=bsize, sim_method=sim_method 184 | ) 185 | logging.info("Completed") 186 | 187 | for group in input_grouped.values(): 188 | scores = {r["uid"]: r["score"] for r in group} 189 | ranks = get_ranks_from_scores(scores) 190 | for r in group: 191 | r["rank"] = ranks[r["uid"]] 192 | 193 | # Revert back to original records 194 | input_records = ungroup_dict_to_records(input_grouped) 195 | 196 | # Metrics 197 | lengths = np.array([len(v) for v in input_grouped.values()]) 198 | results = { 199 | "split": split, 200 | "num_turns": len(input_grouped), 201 | "num_demos": len(demos), 202 | "avg_elements_per_turn": lengths.mean(), 203 | "std_elements_per_turn": lengths.std(), 204 | "mrr": mean_reciprocal_rank(input_records, k=cfg.eval.mrr_k), 205 | } 206 | 207 | for k in [1, 5, 10, 20, 50, 100, 200]: 208 | results[f"recall@{k}"] = recall_at_k(input_records, k=k) 209 | 210 | for k, v in results.items(): 211 | print(f"{k}: {v}") 212 | 213 | # Save results 214 | with open(result_dir.joinpath("results.json"), "w") as f: 215 | json.dump(results, f, indent=2) 216 | 217 | # Save records and scores 218 | with open(result_dir.joinpath("scores.jsonl"), "w") as f: 219 | for r in input_records: 220 | f.write(json.dumps(r) + "\n") 221 | 222 | save_path_to_hydra_logs(save_dir=model_save_dir) 223 | 224 | 225 | if __name__ == "__main__": 226 | main() 227 | -------------------------------------------------------------------------------- /modeling/dmr/processing.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from copy import deepcopy 3 | import random 4 | from typing import Any, Dict, List 5 | from functools import partial 6 | 7 | import lxml.html 8 | import weblinx as wl 9 | import weblinx.utils.html as wh 10 | import weblinx.utils.format as wlf 11 | from weblinx.processing.prompt import ( 12 | format_prev_turns, 13 | find_turns_with_instructor_chat, 14 | format_utterances, 15 | ) 16 | 17 | 18 | def format_turn_for_input( 19 | replay, 20 | turn, 21 | format_intent, 22 | turn_sep=" ; ", 23 | num_prev_turns=5, 24 | num_utterances=5, 25 | return_str=True, 26 | 
): 27 | """ 28 | This function formats a turn for input to the model. It does so by combining the following: 29 | 1. The first and last `num_utterances-1` utterances from the instructor 30 | 2. The previous turns (up to `num_prev_turns` turns) 31 | 32 | If return_str is True, then the output is a string. Otherwise, it returns two strings: the utterance context and the previous turns. 33 | """ 34 | prev_turns_text = format_prev_turns( 35 | replay=replay, 36 | turn=turn, 37 | format_intent=format_intent, 38 | turn_sep=turn_sep, 39 | num_prev_turns=num_prev_turns, 40 | ) 41 | instructor_chat_turns = find_turns_with_instructor_chat( 42 | replay, turn, num_prev_turns=num_prev_turns 43 | ) 44 | utterance_context = format_utterances( 45 | instructor_chat_turns, num_utterances=num_utterances 46 | ) 47 | 48 | if not return_str: 49 | return utterance_context, prev_turns_text 50 | 51 | # Now, let's combine the text from the previous turns with the utterance context 52 | # and the current turn's utterance 53 | text = ( 54 | f"Viewport(height={turn.viewport_height}, width={turn.viewport_width}) ---- " 55 | f"Instructor Utterances: {utterance_context} ---- " 56 | f"Previous Turns:{prev_turns_text}" 57 | ) 58 | 59 | return text 60 | 61 | 62 | def build_formatters(): 63 | format_element_input = partial( 64 | wlf.format_element, 65 | include_text=False, 66 | include_attrs=("class", "title", "href", "aria-label", "d", "src"), 67 | ) 68 | format_click_input = partial( 69 | wlf.format_click, 70 | formatters=(wlf.format_mouse_xy, format_element_input, wlf.format_timestamp), 71 | ) 72 | format_change_input = partial( 73 | wlf.format_change, 74 | formatters=( 75 | partial(wlf.format_arg_item, name="value"), 76 | format_element_input, 77 | wlf.format_timestamp, 78 | ), 79 | ) 80 | format_hover_input = partial( 81 | wlf.format_hover, 82 | formatters=(wlf.format_mouse_xy, format_element_input, wlf.format_timestamp), 83 | ) 84 | 85 | format_submit_input = partial( 86 | wlf.format_submit, formatters=(format_element_input, wlf.format_timestamp) 87 | ) 88 | 89 | format_text_input_input = partial( 90 | wlf.format_text_input, 91 | formatters=( 92 | partial(wlf.format_arg_item, name="text"), 93 | partial(format_element_input), 94 | wlf.format_timestamp, 95 | ), 96 | ) 97 | 98 | format_intent_input = partial( 99 | wlf.format_intent_automatically, 100 | format_click=format_click_input, 101 | format_change=format_change_input, 102 | format_hover=format_hover_input, 103 | format_submit=format_submit_input, 104 | format_text_input=format_text_input_input, 105 | return_as=str, 106 | ) 107 | 108 | # second, for the output (prediction text) 109 | format_element_out = partial( 110 | wlf.format_element, 111 | # Only want the tag 112 | include_text=False, 113 | include_attrs=False, 114 | ) 115 | 116 | format_click_out = partial(wlf.format_click, formatters=(wlf.format_mouse_xy,)) 117 | format_text_input_out = partial( 118 | wlf.format_text_input, 119 | formatters=( 120 | partial(wlf.format_arg_item, name="text", max_length=200), 121 | format_element_out, 122 | wlf.format_target_bbox, 123 | ), 124 | ) 125 | format_change_out = partial( 126 | wlf.format_change, 127 | formatters=( 128 | partial(wlf.format_arg_item, name="value", max_length=200), 129 | format_element_out, 130 | wlf.format_target_bbox, 131 | ), 132 | ) 133 | format_submit_out = partial( 134 | wlf.format_submit, formatters=(format_element_out, wlf.format_target_bbox) 135 | ) 136 | format_load_out = partial( 137 | wlf.format_load, 138 | include_transition=False, 139 | 
include_timestamp=False, 140 | max_length=200, 141 | ) 142 | format_scroll_out = partial(wlf.format_scroll, include_timestamp=False) 143 | 144 | format_say_out = partial(wlf.format_say, include_timestamp=False) 145 | 146 | format_intent_out = partial( 147 | wlf.format_intent_automatically, 148 | format_change=format_change_out, 149 | format_click=format_click_out, 150 | format_load=format_load_out, 151 | format_say=format_say_out, 152 | format_scroll=format_scroll_out, 153 | format_submit=format_submit_out, 154 | format_text_input=format_text_input_out, 155 | ) 156 | 157 | return format_intent_input, format_intent_out 158 | 159 | 160 | def turn_has_valid_uid(turn, paths, uid_key="data-webtasks-id"): 161 | """ 162 | Given a turn an lxml tree, return True if the turn's uid is in the tree. 163 | """ 164 | uids = [p.attrib[uid_key] for p in paths] 165 | if turn.element is None or uid_key not in turn.element["attributes"]: 166 | return False 167 | 168 | if turn.element["attributes"][uid_key] not in uids: 169 | return False 170 | 171 | return True 172 | 173 | 174 | def format_attrs(attrs): 175 | return " ".join([f"{k!s}={v!r}" for k, v in attrs.items()]) 176 | 177 | 178 | def shorten(s, max_length=100, side="center", ellipsis="..."): 179 | if max_length is None: 180 | return s 181 | 182 | if len(s) <= max_length: 183 | return s 184 | 185 | max_length = max_length - len(ellipsis) 186 | 187 | if side == "right": 188 | s = s[:max_length] + ellipsis 189 | elif side == "left": 190 | s = ellipsis + s[-max_length:] 191 | elif side == "center": 192 | s = s[: max_length // 2] + ellipsis + s[-max_length // 2 :] 193 | else: 194 | raise ValueError(f"Invalid side: {side}") 195 | 196 | return s 197 | 198 | 199 | def format_children(parent, depth=1): 200 | """ 201 | Use the concise parentheses notation to format the children of an element. 202 | For example, for depth 1, we only have: (child1 child2 child3) 203 | For depth 2, we have: (child1 (grandchild1 grandchild2) child2 child3) 204 | """ 205 | children = parent.getchildren() 206 | if len(children) == 0: 207 | return "" 208 | 209 | if depth == 1: 210 | return " ".join([c.tag for c in children]) 211 | 212 | out_str = "" 213 | for c in children: 214 | out_str += f"{c.tag}" 215 | children_str = format_children(c, depth=depth - 1) 216 | if children_str != "": 217 | out_str += f" ( {children_str} )" 218 | out_str += " " 219 | 220 | return out_str.strip() 221 | 222 | 223 | def represent_element_as_dict( 224 | element, 225 | bbox, 226 | root_tree, 227 | max_text_length=200, 228 | max_attr_length=100, 229 | max_child_depth=2, 230 | ): 231 | """ 232 | Format an lxml element into a dictionary of strings. 
The keys are: 233 | - tag: the tag name of the element 234 | - xpath: the xpath of the element 235 | - text: the text of the element, truncated to `max_text_length` 236 | - bbox: the bounding box of the element 237 | - attributes: the attributes of the element, truncated to `max_attr_length` 238 | - children: the children of the element, truncated to `max_attr_length` 239 | """ 240 | # Get the tag name 241 | tag = element.tag 242 | xpath = root_tree.getpath(element) 243 | children = element.getchildren() 244 | text = element.text if element.text is not None else "" 245 | 246 | # Shorten the text and attributes 247 | text = shorten(text, max_text_length) 248 | attrs = {k: shorten(v, max_attr_length) for k, v in element.attrib.items()} 249 | 250 | # Sort the attributes by length 251 | attrs = dict(sorted(attrs.items(), key=lambda x: len(x[1]))) 252 | 253 | # Truncate the children 254 | children = children[:max_child_depth] 255 | 256 | # Format the children 257 | children_str = " ".join([c.tag for c in children if isinstance(c.tag, str)]) 258 | children_str = shorten(children_str, max_attr_length) 259 | 260 | # Format the attributes 261 | attrs_str = format_attrs(attrs) 262 | 263 | # Format the bounding box 264 | bbox_str = " ".join( 265 | [f"{k}={round(bbox[k], 1)}" for k in ["x", "y", "width", "height"]] 266 | ) 267 | 268 | # format as a dict 269 | element_dict = { 270 | "tag": tag, 271 | "xpath": xpath, 272 | "text": text, 273 | "bbox": bbox_str, 274 | "attributes": attrs_str, 275 | "children": children_str, 276 | } 277 | 278 | return element_dict 279 | 280 | 281 | def convert_elem_dict_to_str_legacy(elem_dict: dict): 282 | """ 283 | Convert an element dictionary to a string. 284 | """ 285 | elem_dict = deepcopy(elem_dict) 286 | 287 | element_str = f"[[tag]] {elem_dict.pop('tag')}\n" 288 | element_str += f"[[xpath]] {elem_dict.pop('xpath')}\n" 289 | element_str += f"[[text]] {elem_dict.pop('text')}\n" 290 | element_str += f"[[bbox]] {elem_dict.pop('bbox')}\n" 291 | element_str += f"[[attributes]] {elem_dict.pop('attributes')}\n" 292 | element_str += f"[[children]] {elem_dict.pop('children')}" 293 | 294 | # for other keys, we just add them to the end 295 | 296 | for k, v in elem_dict.items(): 297 | element_str += f"\n[[{k}]] {v}" 298 | 299 | return element_str 300 | 301 | 302 | def build_records_for_single_turn( 303 | turn, replay, format_intent_input, uid_key, max_neg=None, only_allow_valid_uid=True 304 | ) -> List[dict]: 305 | """ 306 | This function will build a list of dictionaries, each of which is a record 307 | for a single turn. Each record has the following keys: 308 | - query: the dialogue history, used as a query for training the model 309 | - doc: concise representation of HTML element used as doc for training 310 | - label: either 0 or 1, indicating whether the document is the target element 311 | - uid: the unique identifier for an element, must be in the element attributes 312 | - turn_index: the index of the turn in the replay 313 | - demo_name: the name of the demonstration 314 | 315 | If `only_allow_valid_uid` is True, then only turns that have a valid uid 316 | will be included in the output. Otherwise, all turns will be included. 
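Returns the list of positive records followed by the negative records; if `max_neg` is set, at most `max_neg` negative records are kept (sampled at random).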
317 | """ 318 | bboxes_filt = wh.filter_bboxes( 319 | turn.bboxes, 320 | viewport_height=turn.viewport_height, 321 | viewport_width=turn.viewport_width, 322 | ) 323 | root = lxml.html.fromstring(turn.html) 324 | root_tree = root.getroottree() 325 | elements = root.xpath(f"//*[@{uid_key}]") 326 | elements_filt = [p for p in elements if p.attrib[uid_key] in bboxes_filt] 327 | 328 | has_valid_uid = turn_has_valid_uid(turn, paths=elements, uid_key=uid_key) 329 | if only_allow_valid_uid and not has_valid_uid: 330 | return [] 331 | 332 | # Now, we can format each of the elements in paths_filt into string 333 | # and use them as negative samples 334 | query = format_turn_for_input(replay, turn, format_intent=format_intent_input) 335 | target_uid = turn.element["attributes"][uid_key] if has_valid_uid else -1 336 | 337 | records_positive = [] 338 | records_negative = [] 339 | 340 | for elem in elements_filt: 341 | bbox = turn.bboxes[elem.attrib[uid_key]] 342 | elem_dict = represent_element_as_dict(elem, bbox, root_tree) 343 | elem_str = convert_elem_dict_to_str_legacy(elem_dict) 344 | 345 | record = { 346 | "query": query, 347 | "doc": elem_str, 348 | "uid": elem.attrib[uid_key], 349 | "demo_name": turn.demo_name, 350 | "turn_index": turn.index, 351 | "elem_dict": elem_dict, 352 | } 353 | 354 | if elem.attrib[uid_key] == target_uid: 355 | record["label"] = 1 356 | records_positive.append(record) 357 | else: 358 | record["label"] = 0 359 | records_negative.append(record) 360 | 361 | if max_neg is not None and 0 < max_neg < len(records_negative): 362 | records_negative = random.sample(records_negative, max_neg) 363 | 364 | return records_positive + records_negative 365 | 366 | 367 | def build_records_for_single_demo( 368 | demo, 369 | format_intent_input, 370 | max_neg_per_turn=None, 371 | random_state=None, 372 | uid_key="data-webtasks-id", 373 | only_allow_valid_uid=True, 374 | group_by_turn=False, 375 | ) -> List[dict]: 376 | """ 377 | This runs `build_records_for_single_turn` for each turn in the demonstration. 378 | First, the demonstration is converted into a replay, and then we filter the 379 | turns to only those that have HTML and bounding boxes, and that are of the 380 | following intents: 381 | - click 382 | - change 383 | - textInput 384 | - scroll 385 | - load 386 | - submit 387 | 388 | Any turn that does not have a valid uid is discarded. 
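Returns a flat list of records, or a list of per-turn record lists if `group_by_turn` is True.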
389 | """ 390 | if random_state is not None: 391 | random.seed(random_state) 392 | 393 | replay = wl.Replay.from_demonstration(demo) 394 | turns = replay.filter_by_intents( 395 | "click", "change", "textInput", "scroll", "load", "submit" 396 | ) 397 | turns = wl.filter_turns(turns, lambda t: t.has_html() and t.has_bboxes()) 398 | 399 | records_for_demo = [] 400 | for turn in turns: 401 | recs = build_records_for_single_turn( 402 | turn=turn, 403 | replay=replay, 404 | format_intent_input=format_intent_input, 405 | uid_key=uid_key, 406 | max_neg=max_neg_per_turn, 407 | only_allow_valid_uid=only_allow_valid_uid, 408 | ) 409 | if group_by_turn: 410 | records_for_demo.append(recs) 411 | else: 412 | records_for_demo.extend(recs) 413 | 414 | return records_for_demo 415 | -------------------------------------------------------------------------------- /modeling/dmr/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | 4 | import hydra 5 | from tqdm import tqdm 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from sentence_transformers import SentenceTransformer, InputExample 9 | import sentence_transformers.models as st_models 10 | from sentence_transformers.losses import CosineSimilarityLoss 11 | import transformers 12 | from weblinx.utils.hydra import save_path_to_hydra_logs 13 | import weblinx as wl 14 | 15 | from .processing import build_records_for_single_demo, build_formatters 16 | 17 | 18 | def infer_optimizer(name): 19 | name = name.lower() 20 | 21 | if name == "adamw": 22 | return torch.optim.AdamW 23 | elif name == "adam": 24 | return torch.optim.Adam 25 | elif name == "adafactor": 26 | return transformers.Adafactor 27 | elif name == "sgd": 28 | return torch.optim.SGD 29 | else: 30 | raise ValueError(f"Unknown optimizer name: {name}") 31 | 32 | 33 | @hydra.main(version_base=None, config_path="conf", config_name="config") 34 | def main(cfg): 35 | torch.manual_seed(cfg.seed) 36 | 37 | model_name = cfg.model.name 38 | use_bf16 = cfg.model.use_bf16 39 | max_seq_length = cfg.model.max_seq_length 40 | optim = cfg.train.optim 41 | split = cfg.train.split 42 | learning_rate = cfg.train.learning_rate 43 | warmup_steps = cfg.train.warmup_steps 44 | batch_size = cfg.train.batch_size_per_device 45 | num_epochs = cfg.train.num_epochs 46 | scheduler = cfg.train.scheduler 47 | 48 | split_path = split_path = Path(cfg.data.split_path).expanduser() 49 | model_save_dir = Path(cfg.model.save_dir).expanduser() 50 | 51 | if use_bf16: 52 | torch_dtype = torch.bfloat16 53 | use_amp = False 54 | else: 55 | torch_dtype = torch.float32 56 | use_amp = True 57 | 58 | # Data loading 59 | demo_names = wl.utils.load_demo_names_in_split(split_path, split=split) 60 | demos = [wl.Demonstration(demo_name, base_dir=cfg.data.base_dir) for demo_name in demo_names] 61 | 62 | if cfg.project_name.endswith("testing"): 63 | demos = demos[:10] 64 | 65 | format_intent_input, _ = build_formatters() 66 | input_records = [] 67 | logging.info(f"Number of demos: {len(demos)}. 
Starting building records.") 68 | for demo in tqdm(demos, desc="Building input records"): 69 | input_records.extend( 70 | build_records_for_single_demo( 71 | demo=demo, 72 | format_intent_input=format_intent_input, 73 | max_neg_per_turn=cfg.train.max_neg_per_turn, 74 | random_state=cfg.seed, 75 | # For training, we only want to include elements with valid uids 76 | # otherwise, we will be training on a lot of negative examples 77 | only_allow_valid_uid=True, 78 | ) 79 | ) 80 | 81 | logging.info(f"Number of input records: {len(input_records)}") 82 | 83 | train_examples = [ 84 | InputExample(texts=[r["query"], r["doc"]], label=float(r["label"])) 85 | for r in tqdm( 86 | input_records, desc="Converting records to sentence-transformers input" 87 | ) 88 | ] 89 | 90 | train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size) 91 | 92 | # Model loading 93 | word_embedding_model = st_models.Transformer( 94 | model_name, max_seq_length=max_seq_length 95 | ) 96 | if cfg.train.gradient_checkpointing and hasattr( 97 | word_embedding_model.auto_model, "gradient_checkpointing_enable" 98 | ): 99 | word_embedding_model.auto_model.gradient_checkpointing_enable() 100 | 101 | pooling_model = st_models.Pooling( 102 | word_embedding_model.get_word_embedding_dimension() 103 | ) 104 | model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) 105 | train_loss = CosineSimilarityLoss(model=model) 106 | 107 | logging.info(f"Starting training for {num_epochs} epochs.") 108 | with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch_dtype): 109 | model.fit( 110 | train_objectives=[(train_dataloader, train_loss)], 111 | epochs=num_epochs, 112 | optimizer_class=infer_optimizer(optim), 113 | warmup_steps=warmup_steps, 114 | output_path=str(model_save_dir), 115 | weight_decay=0.0, 116 | scheduler=scheduler, 117 | optimizer_params={"lr": learning_rate}, 118 | ) 119 | logging.info("Training complete.") 120 | 121 | save_path_to_hydra_logs(save_dir=model_save_dir) 122 | 123 | return model_save_dir 124 | 125 | 126 | if __name__ == "__main__": 127 | main() 128 | -------------------------------------------------------------------------------- /modeling/llama/accelerate/fsdp_2gpus.yaml: -------------------------------------------------------------------------------- 1 | # Useful: https://huggingface.co/docs/accelerate/main/en/usage_guides/fsdp 2 | compute_environment: LOCAL_MACHINE 3 | debug: false 4 | distributed_type: FSDP 5 | downcast_bf16: 'no' 6 | fsdp_config: 7 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 8 | fsdp_backward_prefetch_policy: BACKWARD_PRE 9 | fsdp_cpu_ram_efficient_loading: true 10 | fsdp_forward_prefetch: false 11 | fsdp_offload_params: false 12 | fsdp_sharding_strategy: 1 13 | fsdp_state_dict_type: FULL_STATE_DICT 14 | fsdp_sync_module_states: true 15 | # Set fsdp_use_orig_params=true if using peft: 16 | # https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019 17 | fsdp_use_orig_params: false 18 | machine_rank: 0 19 | main_training_function: main 20 | mixed_precision: bf16 21 | num_machines: 1 22 | num_processes: 2 23 | rdzv_backend: static 24 | same_network: true 25 | tpu_env: [] 26 | tpu_use_cluster: false 27 | tpu_use_sudo: false 28 | use_cpu: false -------------------------------------------------------------------------------- /modeling/llama/accelerate/fsdp_4gpus.yaml: -------------------------------------------------------------------------------- 1 | # Useful: 
https://huggingface.co/docs/accelerate/main/en/usage_guides/fsdp 2 | compute_environment: LOCAL_MACHINE 3 | debug: false 4 | distributed_type: FSDP 5 | downcast_bf16: 'no' 6 | fsdp_config: 7 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 8 | fsdp_backward_prefetch_policy: BACKWARD_PRE 9 | fsdp_cpu_ram_efficient_loading: true 10 | fsdp_forward_prefetch: false 11 | fsdp_offload_params: false 12 | fsdp_sharding_strategy: 1 13 | fsdp_state_dict_type: FULL_STATE_DICT 14 | fsdp_sync_module_states: true 15 | # Set fsdp_use_orig_params=true if using peft: 16 | # https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019 17 | fsdp_use_orig_params: false 18 | machine_rank: 0 19 | main_training_function: main 20 | mixed_precision: bf16 21 | num_machines: 1 22 | num_processes: 4 23 | rdzv_backend: static 24 | same_network: true 25 | tpu_env: [] 26 | tpu_use_cluster: false 27 | tpu_use_sudo: false 28 | use_cpu: false -------------------------------------------------------------------------------- /modeling/llama/accelerate/fsdp_6gpus.yaml: -------------------------------------------------------------------------------- 1 | # Useful: https://huggingface.co/docs/accelerate/main/en/usage_guides/fsdp 2 | compute_environment: LOCAL_MACHINE 3 | debug: false 4 | distributed_type: FSDP 5 | downcast_bf16: 'no' 6 | fsdp_config: 7 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 8 | fsdp_backward_prefetch_policy: BACKWARD_PRE 9 | fsdp_cpu_ram_efficient_loading: true 10 | fsdp_forward_prefetch: false 11 | fsdp_offload_params: false 12 | fsdp_sharding_strategy: 1 13 | fsdp_state_dict_type: FULL_STATE_DICT 14 | fsdp_sync_module_states: true 15 | # Set fsdp_use_orig_params=true if using peft: 16 | # https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019 17 | fsdp_use_orig_params: false 18 | machine_rank: 0 19 | main_training_function: main 20 | mixed_precision: bf16 21 | num_machines: 1 22 | num_processes: 6 23 | rdzv_backend: static 24 | same_network: true 25 | tpu_env: [] 26 | tpu_use_cluster: false 27 | tpu_use_sudo: false 28 | use_cpu: false -------------------------------------------------------------------------------- /modeling/llama/accelerate/fsdp_8gpus.yaml: -------------------------------------------------------------------------------- 1 | # Useful: https://huggingface.co/docs/accelerate/main/en/usage_guides/fsdp 2 | compute_environment: LOCAL_MACHINE 3 | debug: false 4 | distributed_type: FSDP 5 | downcast_bf16: 'no' 6 | fsdp_config: 7 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 8 | fsdp_backward_prefetch_policy: BACKWARD_PRE 9 | fsdp_cpu_ram_efficient_loading: true 10 | fsdp_forward_prefetch: false 11 | fsdp_offload_params: false 12 | fsdp_sharding_strategy: 1 13 | fsdp_state_dict_type: FULL_STATE_DICT 14 | fsdp_sync_module_states: true 15 | # Set fsdp_use_orig_params=true if using peft: 16 | # https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019 17 | fsdp_use_orig_params: false 18 | machine_rank: 0 19 | main_training_function: main 20 | mixed_precision: bf16 21 | num_machines: 1 22 | num_processes: 8 23 | rdzv_backend: static 24 | same_network: true 25 | tpu_env: [] 26 | tpu_use_cluster: false 27 | tpu_use_sudo: false 28 | use_cpu: false -------------------------------------------------------------------------------- /modeling/llama/conf/config.yaml: -------------------------------------------------------------------------------- 
1 | project_dir: ${oc.env:WEBLLAMA_PROJECT_DIR} 2 | seed: 123 3 | project_name: llama_ft 4 | 5 | data: 6 | num_proc: 8 7 | split_path: ${project_dir}/wl_data/splits.json 8 | base_dir: ${project_dir}/wl_data/demonstrations/ 9 | 10 | train: 11 | split: train 12 | num_epochs: 3 13 | learning_rate: 3e-5 14 | batch_size_per_device: 4 15 | gradient_accumulation_steps: 1 16 | dataloader_num_workers: 8 17 | gradient_checkpointing: True 18 | use_accelerator_device_map: True # Set to true if using `accelerate` 19 | use_auto_device_map: False # Set to false if using `accelerate` 20 | warmup_ratio: 0 21 | scheduler: linear 22 | optim: adamw_torch 23 | 24 | eval: 25 | split: valid 26 | batch_size_per_device: 8 27 | result_dir: ${project_dir}/results/${project_name}/${eval.split}/${model.name} 28 | load_from_save_dir: True 29 | test_run: False 30 | 31 | model: 32 | name: meta-llama/Meta-Llama-3-8B-Instruct 33 | use_flash_attention_2: True 34 | tokenizer: ${model.name} 35 | template_tokenizer: ${model.tokenizer} 36 | max_inp_len: null 37 | max_out_len: 256 38 | use_rope: True 39 | save_dir: ${project_dir}/checkpoints/${project_name}/${model.name} 40 | 41 | candidates: 42 | k: 10 43 | model: McGill-NLP/MiniLM-L6-dmr # unused but potentially useful 44 | project_name: dmr # unused but potentially useful 45 | split: ${eval.split} 46 | train_path: ${project_dir}/wl_data/candidates/train.jsonl 47 | path: ${project_dir}/wl_data/candidates/${candidates.split}.jsonl 48 | 49 | hydra: 50 | run: 51 | dir: ${project_dir}/logs/${project_name}/${hydra.job.name}/${now:%Y-%m-%d-%H:%M:%S} 52 | # Use the same for sweep's subdir 53 | sweep: 54 | dir: ${hydra.run.dir} 55 | job: 56 | chdir: False 57 | verbose: INFO -------------------------------------------------------------------------------- /modeling/llama/eval.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import logging 3 | import json 4 | from pathlib import Path 5 | 6 | import hydra 7 | from hydra.core.hydra_config import HydraConfig 8 | from omegaconf import OmegaConf 9 | import torch 10 | from tqdm import tqdm 11 | from transformers import ( 12 | AutoTokenizer, 13 | AutoModelForCausalLM, 14 | pipeline, 15 | ) 16 | from transformers.pipelines.pt_utils import KeyDataset 17 | 18 | import weblinx as wl 19 | from weblinx.processing import load_candidate_elements 20 | from weblinx.processing.prompt import build_input_records_from_selected_turns, select_turns_and_candidates_for_prompts 21 | from weblinx.utils.hydra import save_path_to_hydra_logs 22 | 23 | from .processing import ( 24 | build_prompt_records_for_llama_truncated, 25 | build_formatter_for_multichoice, 26 | insert_formatted_chat_into_records 27 | ) 28 | 29 | 30 | @hydra.main(version_base=None, config_path="conf", config_name="config") 31 | def main(cfg): 32 | logger = logging.getLogger(__name__) 33 | 34 | split_path = Path(cfg.data.split_path).expanduser() 35 | result_dir = Path(cfg.eval.result_dir).expanduser() 36 | model_save_dir = Path(cfg.model.save_dir).expanduser() 37 | 38 | max_out_len = cfg.model.max_out_len 39 | split = cfg.eval.split 40 | 41 | result_dir.mkdir(parents=True, exist_ok=True) 42 | 43 | logger.info(OmegaConf.to_yaml(cfg)) 44 | 45 | candidates = load_candidate_elements(path=cfg.candidates.path) 46 | 47 | tokenizer = AutoTokenizer.from_pretrained(cfg.model.tokenizer, padding_side="left") 48 | tokenizer.pad_token = tokenizer.eos_token 49 | 50 | # Data loading 51 | demo_names = 
wl.utils.load_demo_names_in_split(split_path, split=split) 52 | demos = [wl.Demonstration(name, base_dir=cfg.data.base_dir) for name in demo_names] 53 | 54 | format_intent = build_formatter_for_multichoice() 55 | build_prompt_records_fn = partial( 56 | build_prompt_records_for_llama_truncated, 57 | format_intent=format_intent, 58 | tokenizer=tokenizer, 59 | ) 60 | 61 | selected_turns = select_turns_and_candidates_for_prompts( 62 | demos=demos, 63 | candidates=candidates, 64 | num_candidates=cfg.candidates.k, 65 | ) 66 | 67 | input_records = build_input_records_from_selected_turns( 68 | selected_turns=selected_turns, 69 | format_intent=format_intent, 70 | build_prompt_records_fn=build_prompt_records_fn, 71 | format_prompt_records_fn=None, 72 | ) 73 | 74 | template_tokenizer = AutoTokenizer.from_pretrained(cfg.model.template_tokenizer) 75 | insert_formatted_chat_into_records( 76 | records=input_records, 77 | tokenizer=template_tokenizer, 78 | include_output_target=False, 79 | ) 80 | 81 | model_kwargs = dict(device_map="auto", torch_dtype=torch.bfloat16) 82 | model_kwargs['trust_remote_code'] = cfg.model.get('trust_remote_code', False) 83 | 84 | if cfg.model.use_rope: 85 | model_kwargs["rope_scaling"] = {"type": "dynamic", "factor": 2.0} 86 | 87 | if cfg.model.use_flash_attention_2: 88 | model_kwargs["use_flash_attention_2"] = True 89 | 90 | if cfg.eval.get("load_from_save_dir", False) is True: 91 | model_load_name = str(model_save_dir) 92 | else: 93 | model_load_name = cfg.model.name 94 | 95 | model = AutoModelForCausalLM.from_pretrained(model_load_name, **model_kwargs) 96 | 97 | dset = KeyDataset(input_records, key="text") 98 | pipe = pipeline( 99 | "text-generation", model=model, tokenizer=tokenizer, torch_dtype=torch.bfloat16 100 | ) 101 | pipe_kwargs = dict( 102 | max_new_tokens=max_out_len, 103 | return_full_text=False, 104 | batch_size=cfg.eval.batch_size_per_device, 105 | pad_token_id=tokenizer.eos_token_id, 106 | ) 107 | 108 | results = [] 109 | 110 | with torch.cuda.amp.autocast(dtype=torch.bfloat16): 111 | pbar = tqdm( 112 | pipe(dset, **pipe_kwargs), desc="Generating outputs", total=len(dset) 113 | ) 114 | for i, out in enumerate(pbar): 115 | rec = input_records[i] 116 | generated_text = out[0]["generated_text"] 117 | result = { 118 | "demo_name": rec["demo_name"], 119 | "turn_index": rec["turn_index"], 120 | "prompt": rec["prompt"], 121 | "text": rec["text"], 122 | "output_predicted": generated_text, 123 | "output_target": rec["output_target"], 124 | "output_target_dict": rec["output_target_dict"], 125 | } 126 | 127 | results.append(result) 128 | 129 | # Save results 130 | with open(result_dir / "results.json", "w") as f: 131 | json.dump(results, f, indent=2) 132 | 133 | # Save the path to hydra_path into the model directory 134 | save_path_to_hydra_logs(save_dir=result_dir) 135 | 136 | 137 | if __name__ == "__main__": 138 | main() 139 | -------------------------------------------------------------------------------- /modeling/llama/processing.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Callable 3 | 4 | import lxml.html 5 | 6 | import weblinx.utils.format as wlf 7 | from weblinx.processing.dom import clean_and_prune_tree 8 | from weblinx.processing.prompt import ( 9 | find_turns_with_instructor_chat, 10 | format_candidates, 11 | format_utterances, 12 | format_utterances_truncated, 13 | get_speaker, 14 | multi_attempt_format_prev_turns_truncated, 15 | ) 16 | from 
weblinx.processing.truncation import ( 17 | multi_attempt_truncate_cands_turn, 18 | multi_attempt_truncate_dom_tree, 19 | ) 20 | 21 | 22 | def build_formatter_for_multichoice(): 23 | format_click = partial(wlf.format_click, formatters=(wlf.format_uid,)) 24 | format_text_input = partial( 25 | wlf.format_text_input, 26 | formatters=( 27 | partial(wlf.format_arg_item, name="text", max_length=200), 28 | wlf.format_uid, 29 | ), 30 | ) 31 | format_change = partial( 32 | wlf.format_change, 33 | formatters=( 34 | partial(wlf.format_arg_item, name="value", max_length=200), 35 | wlf.format_uid, 36 | ), 37 | ) 38 | format_submit = partial(wlf.format_submit, formatters=(wlf.format_uid,)) 39 | format_load = partial( 40 | wlf.format_load, 41 | include_transition=False, 42 | include_timestamp=False, 43 | max_length=200, 44 | ) 45 | format_scroll = partial(wlf.format_scroll, include_timestamp=False) 46 | 47 | format_say = partial(wlf.format_say, include_timestamp=False) 48 | 49 | format_intent_auto = partial( 50 | wlf.format_intent_automatically, 51 | format_change=format_change, 52 | format_click=format_click, 53 | format_load=format_load, 54 | format_say=format_say, 55 | format_scroll=format_scroll, 56 | format_submit=format_submit, 57 | format_text_input=format_text_input, 58 | ) 59 | 60 | return format_intent_auto 61 | 62 | 63 | def get_system_prompt_template_for_llama_mc_concise(): 64 | sys_prompt_template = ( 65 | "You are an AI assistant with a deep understanding of HTML " 66 | "and you must predict actions based on a user request, which will be executed. " 67 | "Use one of the following, replacing [] with an appropriate value: " 68 | "change(value=[str], uid=[str]) ; " 69 | "click(uid=[str]) ; " 70 | "load(url=[str]) ; " 71 | 'say(speaker="navigator", utterance=[str]) ; ' 72 | "scroll(x=[int], y=[int]) ; " 73 | "submit(uid=[str]) ;" 74 | "text_input(text=[str], uid=[str]) ;\n" 75 | "The user's first and last {num_utterances} utterances are: " 76 | "{utterance_context} ;\n" 77 | "Viewport size: {height}h x {width}w ;\n" 78 | "Only the last {num_prev_turns} turns are provided." 79 | ) 80 | 81 | return sys_prompt_template 82 | 83 | 84 | def get_candidate_prompt_template_for_llama(): 85 | return "Here are the top candidates for this turn: {candidate_str}\n" 86 | 87 | 88 | def get_final_user_message(): 89 | return "Please select the best action using the correct format, do not provide any other information or explanation." 
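# Illustrative sketch (not part of the original module): shows how the three
# templates above could be combined into chat-style messages for the action
# model. The viewport size, utterance context, and candidate string below are
# hypothetical placeholder values; the real pipeline assembles them through
# `build_prompt_records_for_llama_truncated` further down in this file.
def _example_compose_prompt_messages():
    sys_prompt = get_system_prompt_template_for_llama_mc_concise().format(
        num_utterances=4,
        utterance_context="[00:07] Hello [00:13] Open the news section.",
        height=746,
        width=1536,
        num_prev_turns=5,
    )
    cand_prompt = get_candidate_prompt_template_for_llama().format(
        candidate_str="(uid = abc123) a: 'Life' section link",
    )
    return [
        {"role": "system", "content": sys_prompt + "\n" + cand_prompt},
        {"role": "user", "content": get_final_user_message()},
    ]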
90 | 91 | 92 | def merge_prev_turns(prev_turns_text_list, final_user_message): 93 | prev_turns_merged = [] 94 | 95 | # Merge turns from the same role 96 | for i, turn_text in enumerate(prev_turns_text_list): 97 | role = get_speaker( 98 | turn_text, 99 | instructor_name="user", 100 | navigator_name="assistant", 101 | default_name="unknown", 102 | ) 103 | 104 | if i > 0 and prev_turns_merged[-1]["role"] == role: 105 | prev_turns_merged[-1]["content"] += " " + turn_text 106 | else: 107 | prev_turns_merged.append({"role": role, "content": turn_text}) 108 | 109 | if len(prev_turns_merged) > 0 and prev_turns_merged[-1]["role"] == "user": 110 | prev_turns_merged[-1]["content"] += " " + final_user_message 111 | else: 112 | prev_turns_merged.append({"role": "user", "content": final_user_message}) 113 | 114 | return prev_turns_merged 115 | 116 | 117 | def build_prompt_records_for_llama_truncated( 118 | replay, 119 | turn, 120 | format_intent, 121 | tokenizer, 122 | cands_turn=None, 123 | num_utterances=5, 124 | num_prev_turns=5, 125 | system_prompt_template=None, 126 | candidate_prompt_template=None, 127 | final_user_message=None, 128 | include_html=True, 129 | format_candidates_fn=partial( 130 | format_candidates, max_char_len=None, use_uid_as_rank=True 131 | ), 132 | merge_prev_turns_fn=merge_prev_turns, 133 | format_output_dict_fn: Callable = partial( 134 | wlf.format_output_dictionary, function_key="intent" 135 | ), 136 | max_html_tokens=700, 137 | max_utterance_tokens=40 * 5, 138 | max_prev_turns_tokens=50 * 5, 139 | max_candidates_tokens=65 * 10, 140 | add_unused_len_to_cands=True, 141 | allow_iterative_reduction=False, 142 | parser=None, 143 | ): 144 | """ 145 | Parameters 146 | ---------- 147 | ... 148 | allow_iterative_reduction : bool 149 | This arg is only relevant when truncate_at_center is used behind the scene (e.g. for 150 | multi_attempt_format_prev_turns_truncated or multi_attempt_truncate_dom_tree). If True, 151 | then we will allow the iterative reduction to continue until the max_tokens is reached. 152 | This is useful when the tokenizer output does not necessarily decrease when we remove 153 | tokens from the input. For example, if we remove a token that is part of a word, but 154 | the updated text is retokenized to the same number of tokens, then we will continue 155 | to remove tokens until we reach the max_tokens limit. 
156 | """ 157 | if system_prompt_template is None: 158 | system_prompt_template = get_system_prompt_template_for_llama_mc_concise() 159 | 160 | if candidate_prompt_template is None: 161 | candidate_prompt_template = get_candidate_prompt_template_for_llama() 162 | 163 | if final_user_message is None: 164 | final_user_message = get_final_user_message() 165 | 166 | instructor_chat_turns = find_turns_with_instructor_chat( 167 | replay, turn, num_prev_turns=num_prev_turns 168 | ) 169 | utterance_context = format_utterances_truncated( 170 | instructor_chat_turns, 171 | tokenizer=tokenizer, 172 | max_tokens=max_utterance_tokens, 173 | num_utterances=num_utterances, 174 | format_utterances_fn=format_utterances, 175 | allow_iterative_reduction=allow_iterative_reduction, 176 | ) 177 | 178 | prev_turns_text_list = multi_attempt_format_prev_turns_truncated( 179 | replay=replay, 180 | turn=turn, 181 | format_intent=partial(format_intent, return_as=dict), 182 | tokenizer=tokenizer, 183 | num_prev_turns=num_prev_turns, 184 | turn_sep=None, # output list 185 | max_tokens=max_prev_turns_tokens, 186 | max_attempts=5, 187 | format_output_dict_fn=format_output_dict_fn, 188 | warn_after_attempts=False, 189 | allow_iterative_reduction=allow_iterative_reduction, 190 | ) 191 | 192 | prev_turns_merged = merge_prev_turns_fn( 193 | prev_turns_text_list=prev_turns_text_list, final_user_message=final_user_message 194 | ) 195 | 196 | sys_prompt = system_prompt_template.format( 197 | num_utterances=num_utterances - 1, # 1 less since we add the first utterance 198 | utterance_context=utterance_context, 199 | height=turn.viewport_height, 200 | width=turn.viewport_width, 201 | num_prev_turns=num_prev_turns, 202 | ) 203 | 204 | if include_html and turn.html not in ["", None] and cands_turn is not None: 205 | dom_tree_raw = lxml.html.fromstring(turn.html, parser=parser) 206 | dom_tree_pruned = clean_and_prune_tree(dom_tree_raw, cands_turn=cands_turn) 207 | trunc = multi_attempt_truncate_dom_tree( 208 | dom_tree=dom_tree_pruned, 209 | tokenizer=tokenizer, 210 | max_tokens=max_html_tokens, 211 | warn_after_attempts=False, 212 | allow_iterative_reduction=allow_iterative_reduction, 213 | ) 214 | html = trunc["tree_repr"] 215 | sys_prompt = html + sys_prompt 216 | else: 217 | html = "" 218 | 219 | if cands_turn is not None: 220 | if add_unused_len_to_cands: 221 | # Add the unused length to the candidates 222 | num_html_tokens = len(tokenizer.tokenize(html)) 223 | num_utter_tokens = len(tokenizer.tokenize(utterance_context)) 224 | num_prev_turns_tokens = len( 225 | tokenizer.tokenize(" ".join(prev_turns_text_list)) 226 | ) 227 | remain_html_tokens = max_html_tokens - num_html_tokens 228 | remain_utter_tokens = max_utterance_tokens - num_utter_tokens 229 | remain_prev_turns_tokens = max_prev_turns_tokens - num_prev_turns_tokens 230 | remain_tokens = ( 231 | remain_html_tokens + remain_utter_tokens + remain_prev_turns_tokens 232 | ) 233 | # Add the unused length to the max_candidates_tokens 234 | max_candidates_tokens += remain_tokens 235 | 236 | cands_turn_trunc = multi_attempt_truncate_cands_turn( 237 | cands_turn=cands_turn, 238 | tokenizer=tokenizer, 239 | max_tokens=max_candidates_tokens, 240 | format_candidates_fn=format_candidates_fn, 241 | warn_after_attempts=False, 242 | allow_iterative_reduction=allow_iterative_reduction, 243 | ) 244 | cand_str = format_candidates_fn(cands_turn_trunc, max_char_len=None) 245 | cand_prompt = candidate_prompt_template.format(candidate_str=cand_str) 246 | sys_prompt += "\n" + cand_prompt 247 | 
248 | return [{"role": "system", "content": sys_prompt}, *prev_turns_merged] 249 | 250 | 251 | def format_prompt_llama(prompt_records): 252 | """ 253 | DEPRECATED: Use `insert_formatted_chat_into_records` instead 254 | """ 255 | for i, rec in enumerate(prompt_records): 256 | if i != 0 and rec["role"] == "system": 257 | raise ValueError( 258 | f"System prompt should be the first record. Found it at index {i}." 259 | ) 260 | if i == 0 and rec["role"] != "system": 261 | raise ValueError( 262 | f"System prompt should be the first record. Found a {rec['role']} prompt at index {i}." 263 | ) 264 | 265 | sys_prompt = prompt_records[0]["content"] 266 | remain_turns = prompt_records[1:] 267 | 268 | prompt = f"[INST] <>\n{sys_prompt}\n<>\n\n" 269 | 270 | for i, turn in enumerate(remain_turns): 271 | # If there's 1 turn remaining and it is not the user, then there was an issue 272 | if i == len(remain_turns) - 1 and turn["role"] != "user": 273 | raise ValueError( 274 | f"Last turn should be the user. Found a {turn['role']} turn at index {i}." 275 | ) 276 | 277 | if turn["role"] == "user": 278 | # If the previous turn was system, we do not add the [INST] tag 279 | if i == 0: 280 | text = f"{turn['content']}" 281 | else: 282 | text = f"[INST] {turn['content'].strip()}" 283 | 284 | prompt += text 285 | 286 | elif turn["role"] == "assistant": 287 | prompt += f"[/INST] {turn['content'].strip()}" 288 | 289 | else: 290 | raise ValueError( 291 | f"Unknown role {turn['role']} at index {i}. Should be either 'user' or 'assistant'." 292 | ) 293 | 294 | # Add [/INST] tag if the last turn was the user 295 | if remain_turns[-1]["role"] == "user": 296 | prompt += "[/INST]" 297 | 298 | return prompt 299 | 300 | 301 | def __insert_empty_user_content_at_first(prompt: list): 302 | """ 303 | Given a list of dictionary representing the input prompt, insert an empty user content at the first position 304 | after system content, only if it is not already a user content. This is done in place. 305 | """ 306 | if prompt[0]["role"] != "system": 307 | raise ValueError( 308 | f"First prompt must be a system prompt. Got {prompt[0]['role']} instead." 309 | ) 310 | 311 | if prompt[1]["role"] != "user": 312 | prompt.insert(1, {"role": "user", "content": ""}) 313 | 314 | 315 | def insert_formatted_chat_into_records( 316 | records, 317 | tokenizer, 318 | include_output_target=True, 319 | origin_key="prompt", 320 | text_key="text", 321 | ): 322 | """ 323 | Given a list of records, insert the formatted chat into the records. This is done in place. 324 | Note that we need a tokenizer's `apply_chat_template` method to be available. 
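For illustration, a hypothetical record processed by this function (assuming
    the tokenizer defines a chat template, e.g. a Llama-3 Instruct tokenizer):

        records = [
            {
                "prompt": [
                    {"role": "system", "content": "You are an AI assistant ..."},
                    {"role": "user", "content": "Please select the best action ..."},
                ],
                "output_target": 'click(uid="abc123")',
            }
        ]
        insert_formatted_chat_into_records(records, tokenizer, include_output_target=True)
        # records[0]["text"] now contains the fully templated chat string.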
325 | """ 326 | for i, record in enumerate(records): 327 | __insert_empty_user_content_at_first(record[origin_key]) 328 | 329 | if include_output_target: 330 | target = [{"role": "assistant", "content": record["output_target"]}] 331 | combined = record[origin_key] + target 332 | else: 333 | combined = record[origin_key] 334 | 335 | text = tokenizer.apply_chat_template( 336 | combined, tokenize=False, add_generation_prompt=False 337 | ) 338 | records[i][text_key] = text 339 | -------------------------------------------------------------------------------- /modeling/llama/train.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import json 3 | import logging 4 | from pathlib import Path 5 | 6 | from accelerate import Accelerator 7 | import datasets 8 | from omegaconf import OmegaConf 9 | import hydra 10 | import torch 11 | from transformers import ( 12 | AutoTokenizer, 13 | TrainingArguments, 14 | AutoModelForCausalLM, 15 | TrainingArguments, 16 | ) 17 | from trl import SFTTrainer 18 | 19 | import weblinx as wl 20 | from weblinx.processing import load_candidate_elements 21 | from weblinx.processing.prompt import ( 22 | build_input_records_from_selected_turns, 23 | select_turns_and_candidates_for_prompts, 24 | ) 25 | from weblinx.utils.hydra import save_path_to_hydra_logs 26 | from weblinx.utils import set_seed 27 | 28 | from .processing import ( 29 | build_formatter_for_multichoice, 30 | build_prompt_records_for_llama_truncated, 31 | insert_formatted_chat_into_records, 32 | ) 33 | 34 | 35 | @hydra.main(config_path="conf", config_name="config", version_base=None) 36 | def main(cfg): 37 | set_seed(cfg.seed) 38 | split_path = Path(cfg.data.split_path).expanduser() 39 | model_save_dir = Path(cfg.model.save_dir).expanduser() 40 | model_save_dir.mkdir(exist_ok=True, parents=True) 41 | logging.info(OmegaConf.to_yaml(cfg)) 42 | 43 | demo_names = wl.utils.load_demo_names_in_split(split_path, split=cfg.train.split) 44 | demos = [wl.Demonstration(demo_name, base_dir=cfg.data.base_dir) for demo_name in demo_names] 45 | candidates = load_candidate_elements(path=cfg.candidates.train_path) 46 | 47 | tokenizer = AutoTokenizer.from_pretrained(cfg.model.tokenizer, padding_side="right") 48 | tokenizer.pad_token = tokenizer.eos_token 49 | 50 | model_kwargs = dict(torch_dtype=torch.bfloat16) 51 | model_kwargs['trust_remote_code'] = cfg.model.get('trust_remote_code', False) 52 | 53 | if cfg.train.use_accelerator_device_map: 54 | accelerator = Accelerator() 55 | model_kwargs["device_map"] = {"": accelerator.process_index} 56 | 57 | elif cfg.train.use_auto_device_map: 58 | model_kwargs["device_map"] = "auto" 59 | 60 | if cfg.model.use_flash_attention_2: 61 | model_kwargs["use_flash_attention_2"] = True 62 | 63 | model = AutoModelForCausalLM.from_pretrained(cfg.model.name, **model_kwargs) 64 | 65 | format_intent = build_formatter_for_multichoice() 66 | input_records_fname = "input_records_trunc.json" 67 | build_prompt_records_fn = partial( 68 | build_prompt_records_for_llama_truncated, 69 | format_intent=format_intent, 70 | tokenizer=tokenizer, 71 | ) 72 | 73 | selected_turns = select_turns_and_candidates_for_prompts( 74 | demos=demos, 75 | candidates=candidates, 76 | num_candidates=cfg.candidates.k, 77 | ) 78 | 79 | input_records = build_input_records_from_selected_turns( 80 | selected_turns=selected_turns, 81 | format_intent=format_intent, 82 | build_prompt_records_fn=build_prompt_records_fn, 83 | format_prompt_records_fn=None, 84 | ) 85 | 86 
| template_tokenizer = AutoTokenizer.from_pretrained(cfg.model.template_tokenizer) 87 | insert_formatted_chat_into_records( 88 | input_records, template_tokenizer, include_output_target=True 89 | ) 90 | 91 | with open(model_save_dir.joinpath(input_records_fname), "w") as f: 92 | json.dump(input_records, f, indent=2) 93 | 94 | input_records_texts = [{"text": record["text"]} for record in input_records] 95 | 96 | training_args = TrainingArguments( 97 | output_dir=model_save_dir, 98 | optim=cfg.train.optim, 99 | learning_rate=cfg.train.learning_rate, 100 | num_train_epochs=cfg.train.num_epochs, 101 | per_device_train_batch_size=cfg.train.batch_size_per_device, 102 | gradient_accumulation_steps=cfg.train.gradient_accumulation_steps, 103 | gradient_checkpointing=cfg.train.gradient_checkpointing, 104 | warmup_ratio=cfg.train.warmup_ratio, 105 | lr_scheduler_type=cfg.train.scheduler, 106 | save_strategy="no", 107 | evaluation_strategy="no", 108 | logging_strategy="epoch", 109 | logging_first_step=True, 110 | prediction_loss_only=True, 111 | bf16=True, 112 | bf16_full_eval=True, 113 | ) 114 | 115 | trainer = SFTTrainer( 116 | model=model, 117 | tokenizer=tokenizer, 118 | args=training_args, 119 | train_dataset=datasets.Dataset.from_list(input_records_texts), 120 | max_seq_length=model.config.max_position_embeddings, 121 | dataset_text_field="text", 122 | ) 123 | 124 | trainer.train() 125 | 126 | # Save model, tokenizer, trainer state, and path to hydra logs 127 | trainer.save_model(model_save_dir) 128 | tokenizer.save_pretrained(model_save_dir) 129 | trainer.state.save_to_json(model_save_dir / "trainer_state.json") 130 | save_path_to_hydra_logs(save_dir=model_save_dir) 131 | 132 | # if the model is saved as pytorch_model_fsdp.bin, rename it to pytorch_model.bin 133 | fsdp_model_path = model_save_dir / "pytorch_model_fsdp.bin" 134 | if fsdp_model_path.exists(): 135 | fsdp_model_path.rename(model_save_dir / "pytorch_model.bin") 136 | 137 | 138 | if __name__ == "__main__": 139 | main() 140 | -------------------------------------------------------------------------------- /modeling/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.35.0 # Future version may break the code, upgrade with caution 2 | lxml 3 | numpy 4 | datasets 5 | torch 6 | sentence-transformers 7 | peft 8 | backoff 9 | tensorboardX 10 | hydra-core 11 | peft 12 | accelerate 13 | optimum 14 | openai 15 | tiktoken 16 | trl 17 | bitsandbytes 18 | coloredlogs 19 | sacrebleu 20 | bert-score 21 | packaging 22 | ninja 23 | wheel -------------------------------------------------------------------------------- /requirements-basic.txt: -------------------------------------------------------------------------------- 1 | weblinx>=0.3.0rc1 2 | lxml 3 | numpy -------------------------------------------------------------------------------- /requirements-extra.txt: -------------------------------------------------------------------------------- 1 | weblinx[eval]>=0.3.0.rc1 2 | streamlit 3 | sentence-transformers 4 | transformers 5 | playwright 6 | browsergym 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | package_name = "webllama" 4 | version = {} 5 | with open(f"{package_name}/version.py") as fp: 6 | exec(fp.read(), version) 7 | 8 | with open("README.md") as fp: 9 | long_description = fp.read() 10 | 11 
| with open('requirements-extra.txt') as f: 12 | extras = f.read().splitlines() 13 | 14 | with open('requirements-basic.txt') as f: 15 | install_requires = f.read().splitlines() 16 | 17 | extras_require = { 18 | "dev": ["black"], 19 | "extra": extras, 20 | } 21 | # Dynamically create the 'all' extra by combining all other extras 22 | extras_require["all"] = sum(extras_require.values(), []) 23 | 24 | setup( 25 | name=package_name, 26 | version=version["__version__"], 27 | author="Xing Han Lù", 28 | author_email=f"{package_name}@googlegroups.com", 29 | url=f"https://github.com/McGill-NLP/{package_name}", 30 | description="Llama-powered agents for automatic web browsing", 31 | long_description=long_description, 32 | packages=find_packages(include=[f"{package_name}*"]), 33 | package_data={}, 34 | install_requires=install_requires, 35 | extras_require=extras_require, 36 | classifiers=[ 37 | "Programming Language :: Python :: 3", 38 | "License :: OSI Approved :: MIT License", 39 | "Operating System :: OS Independent", 40 | ], 41 | python_requires=">=3.8", 42 | # Cast long description to markdown 43 | long_description_content_type="text/markdown", 44 | ) -------------------------------------------------------------------------------- /tests/requirements.txt: -------------------------------------------------------------------------------- 1 | -e .[all] -------------------------------------------------------------------------------- /tests/test_web_turn_processor.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import time 3 | import logging 4 | import unittest 5 | 6 | from sentence_transformers import SentenceTransformer 7 | from transformers import AutoTokenizer, pipeline 8 | import weblinx as wl 9 | import webllama.experimental as wa 10 | 11 | logging.getLogger("urllib3").setLevel(logging.WARNING) 12 | 13 | 14 | class TestWebTurnProcessor(unittest.TestCase): 15 | def setUp(self): 16 | 17 | demos = wl.list_demonstrations("tests/demonstrations") 18 | replay = wl.Replay.from_demonstration(demos[0]) 19 | turn = replay[26] 20 | 21 | self.turn = turn 22 | self.replay = replay 23 | self.action_model_name = "McGill-NLP/Sheared-LLaMA-2.7B-weblinx" 24 | 25 | self.tokenizer = AutoTokenizer.from_pretrained(self.action_model_name) 26 | 27 | format_intent_input_dmr, format_intent_out_dmr = ( 28 | wa.formatting.build_formatters_dmr() 29 | ) 30 | format_intent_am = partial( 31 | wa.formatting.build_formatters_action_model(), return_as=dict 32 | ) 33 | self.action_history = wa.functions.create_action_history_from_replay( 34 | replay, format_intent_input_dmr, format_intent_am, stop_at_idx=turn.index 35 | ) 36 | self.state = wa.classes.State( 37 | index=turn.index, 38 | html=turn.html, 39 | bboxes=turn.bboxes, 40 | viewport_height=turn.viewport_height, 41 | viewport_width=turn.viewport_width, 42 | type=turn.type, 43 | ) 44 | 45 | def test_prepare_dmr_query(self): 46 | # We will initialize our processor, which helps us prepare the input for action model 47 | proc = wa.processing.WebTurnProcessor(tokenizer=self.tokenizer) 48 | 49 | # Step 1: prepare query, run DMR and prepare retrieved candidates 50 | query_dmr = proc.prepare_dmr_query(self.action_history, self.state) 51 | 52 | CORRECT_RESULT = 'Viewport(height=746, width=1536) ---- Instructor Utterances: [00:07] Hello [00:13] Open independent ie Website. 
[01:30] Go to life and send me some life related news [04:00] Open second one and Summarize the first three paragraphs in a few words ---- Previous Turns:tabswitch(origin=102465633, target=102465635, timestamp="04:19") ; load(url="https://search.yahoo.com/search?fr=mcafee&type=E211US714G0&p=chatgpt", timestamp="04:23") ; click(x=268, y=201, tag="a", attrs={}, timestamp="04:24") ; tabcreate(target=102465636, timestamp="04:25") ; tabswitch(origin=102465635, target=102465636, timestamp="04:25")' 53 | self.assertIsInstance(query_dmr, str) 54 | self.assertEqual(query_dmr, CORRECT_RESULT) 55 | -------------------------------------------------------------------------------- /webllama/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | from . import experimental -------------------------------------------------------------------------------- /webllama/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | from . import classes, functions, integrations, formatting, processing, templates, web -------------------------------------------------------------------------------- /webllama/experimental/classes.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from dataclasses import dataclass 3 | from typing import Callable, Dict, List, Tuple, TypedDict 4 | import typing 5 | 6 | from weblinx.utils.format import format_output_dictionary 7 | 8 | # Custom types 9 | UID = typing.NewType("UID", str) 10 | AttrsCore = TypedDict( 11 | "AttrsCore", 12 | {"class": str, "title": str, "href": str, "aria-label": str, "d": str, "src": str}, 13 | ) 14 | 15 | 16 | class BBox(TypedDict): 17 | """ 18 | A class to represent the bounding box of an element. 19 | 20 | Attributes 21 | ---------- 22 | x : int 23 | The x-coordinate of the bounding box. 24 | y : int 25 | The y-coordinate of the bounding box. 26 | width : float 27 | The width of the bounding box. 28 | height : float 29 | The height of the bounding box. 30 | top : float, optional 31 | The top position of the bounding box, calculated from `y` if not provided. 32 | bottom : float, optional 33 | The bottom position of the bounding box, calculated from `y` and `height` if not provided. 34 | left : float, optional 35 | The left position of the bounding box, calculated from `x` if not provided. 36 | right : float, optional 37 | The right position of the bounding box, calculated from `x` and `width` if not provided. 38 | """ 39 | x: int 40 | y: int 41 | width: float 42 | height: float 43 | top: float = None 44 | bottom: float = None 45 | left: float = None 46 | right: float = None 47 | 48 | def __post_init__(self): 49 | """ 50 | Ensures required attributes are provided and calculates optional attributes if not given. 51 | For example, if `top` is not provided, it is calculated from `y`. 52 | """ 53 | if any(x is None for x in [self.x, self.y, self.width, self.height]): 54 | raise ValueError("x, y, width, and height must be provided.") 55 | 56 | if self.top is None: 57 | self.top = self.y 58 | 59 | if self.bottom is None: 60 | self.bottom = self.y + self.height 61 | 62 | if self.left is None: 63 | self.left = self.x 64 | 65 | if self.right is None: 66 | self.right = self.x + self.width 67 | 68 | 69 | @dataclass 70 | class State: 71 | """ 72 | A class to represent the state during navigation. 
73 | 74 | Attributes 75 | ---------- 76 | index : int 77 | The index of the state in the sequence of states. 78 | html : str 79 | The DOM tree represented using HTML. 80 | bboxes : Dict[UID, BBox] 81 | A dictionary mapping unique IDs to bounding boxes. 82 | viewport_height : int 83 | The height of the viewport of the browser. 84 | viewport_width : int 85 | The width of the viewport of the browser. 86 | type : str 87 | The type of the state, either "browser" or "chat". 88 | 89 | Methods 90 | ------- 91 | from_dict(cls, dictionary): 92 | Creates a `State` instance from a dictionary. 93 | to_dict(): 94 | Converts the `State` instance to a dictionary. 95 | """ 96 | index: int 97 | html: str 98 | bboxes: Dict[UID, BBox] 99 | viewport_height: int 100 | viewport_width: int 101 | type: str # either "browser" or "chat" 102 | 103 | # check type 104 | def __post_init__(self): 105 | if self.type not in ["browser", "chat"]: 106 | raise ValueError("type must be either 'browser' or 'chat'.") 107 | 108 | @classmethod 109 | def from_dict(cls, dictionary): 110 | """ 111 | Creates a `State` instance from a dictionary. 112 | 113 | Parameters 114 | ---------- 115 | dictionary : dict 116 | The dictionary to create the `State` instance from. 117 | 118 | Returns 119 | ------- 120 | State 121 | The created `State` instance. 122 | """ 123 | return cls( 124 | index=dictionary["index"], 125 | html=dictionary["html"], 126 | bboxes=dictionary["bboxes"], 127 | viewport_height=dictionary["viewport_height"], 128 | viewport_width=dictionary["viewport_width"], 129 | type=dictionary["type"], 130 | ) 131 | 132 | def to_dict(self): 133 | """ 134 | Converts the `State` instance to a dictionary. 135 | 136 | Returns 137 | ------- 138 | dict 139 | A dictionary representation of the `State` instance. 140 | """ 141 | return { 142 | "index": self.index, 143 | "html": self.html, 144 | "bboxes": self.bboxes, 145 | "viewport_height": self.viewport_height, 146 | "viewport_width": self.viewport_width, 147 | "type": self.type, 148 | } 149 | 150 | @dataclass 151 | class Action: 152 | """ 153 | A class to represent an action taken by the user. 154 | 155 | Attributes 156 | ---------- 157 | type : str 158 | The type of the action, either "chat" or "browser". 159 | index : int 160 | The index of the action in the sequence of state/actions. 161 | intent : str 162 | The intent of the action (e.g., "click", "type", "scroll", "say"). 163 | args : Dict[str, str] 164 | A dictionary of arguments associated with the action, such as the unique 165 | ID of the element clicked, the text typed, or the message said. 166 | timestamp : float 167 | The timestamp of the action in seconds, relative to the start time. 168 | tag : str, optional 169 | The HTML tag associated with the action (e.g., "button", "input"). 170 | attrs : AttrsCore, optional 171 | The attributes associated with the action (e.g., "class", "title", "href", "aria-label", "d", "src"). 172 | """ 173 | type: str 174 | index: int 175 | intent: str 176 | args: Dict[str, str] 177 | timestamp: float 178 | tag: str = None 179 | attrs: AttrsCore = None 180 | 181 | def get(self, key): 182 | """ 183 | Retrieves the value of the specified argument key. 184 | 185 | Parameters 186 | ---------- 187 | key : str 188 | The key of the argument to retrieve. 189 | 190 | Returns 191 | ------- 192 | str 193 | The value of the specified argument key. 
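For illustration, with hypothetical values:

            action = Action(
                type="browser", index=3, intent="click",
                args={"uid": "abc123", "x": 210, "y": 88}, timestamp=12.5,
            )
            action.get("uid")   # -> "abc123"
            action.get("text")  # -> None (key not present)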
194 | """ 195 | return self.args.get(key, None) 196 | 197 | @classmethod 198 | def from_dict( 199 | cls, 200 | dictionary: Dict, 201 | included_attrs: Tuple[str] = ("class", "title", "href", "aria-label", "d", "src"), 202 | ) -> "Action": 203 | """ 204 | Creates an `Action` instance from a dictionary. 205 | 206 | Parameters 207 | ---------- 208 | dictionary : dict 209 | The dictionary to create the `Action` instance from. It should have the following 210 | keys: "intent", "index", "timestamp", "attrs" (optional), "tag" (optional), and 211 | any other keys as arguments. Moreover, the type of the action is inferred from 212 | the "intent" key. 213 | included_attrs : tuple of str, optional 214 | A tuple of attribute keys to include in the `attrs` dictionary. 215 | 216 | Returns 217 | ------- 218 | Action 219 | The created `Action` instance. 220 | """ 221 | di = deepcopy(dictionary) 222 | intent = di.pop("intent") 223 | index = di.pop("index") 224 | timestamp = di.pop("timestamp") 225 | attrs = di.pop("attrs", None) 226 | if attrs is not None: 227 | attrs = {k: v for k, v in attrs.items() if k in included_attrs} 228 | 229 | args = di 230 | type_ = "chat" if intent == "say" else "browser" 231 | tag = di.pop("tag") if "tag" in di else None 232 | 233 | return cls( 234 | index=index, 235 | intent=intent, 236 | args=args, 237 | type=type_, 238 | timestamp=timestamp, 239 | attrs=attrs, 240 | tag=tag, 241 | ) 242 | 243 | def to_dict( 244 | self, 245 | include_timestamp=True, 246 | include_attrs=True, 247 | include_tag=True, 248 | include_index=True, 249 | drop_none_coords=False, 250 | format_timestamp_fn=None, 251 | ignore_args=None, 252 | ): 253 | """ 254 | Convert the action to a dictionary, given specific options. 255 | 256 | Parameters 257 | ---------- 258 | include_timestamp: bool 259 | Whether to include the timestamp in the output dictionary, as "timestamp" 260 | include_attrs: bool 261 | Whether to include the attributes in the output dictionary, as "attrs" 262 | include_tag: bool 263 | Whether to include the tag in the output dictionary, as "tag" 264 | include_index: bool 265 | Whether to include the index in the output dictionary, as "index" 266 | ignore_args: list 267 | A list of keys to ignore in the args dictionary, if None, then all keys are included 268 | format_timestamp_fn: callable 269 | A function to format the timestamp, if None, then the raw timestamp is used 270 | start_time: float 271 | The start time of the action, used to calculate the timestamp 272 | 273 | Returns 274 | ------- 275 | dict 276 | A dictionary representation of the action. 277 | """ 278 | if ignore_args is not None: 279 | args = {k: v for k, v in self.args.items() if k not in ignore_args} 280 | else: 281 | args = self.args 282 | 283 | out = {"intent": self.intent, **args} 284 | 285 | if include_tag and self.tag is not None: 286 | out["tag"] = self.tag 287 | 288 | if include_attrs and self.attrs is not None: 289 | out["attrs"] = self.attrs 290 | 291 | if include_timestamp: 292 | if format_timestamp_fn is not None: 293 | out["timestamp"] = format_timestamp_fn(self)["timestamp"] 294 | else: 295 | out["timestamp"] = self.timestamp 296 | 297 | if include_index: 298 | out["index"] = self.index 299 | 300 | if drop_none_coords: 301 | if "x" in out and out["x"] is None: 302 | del out["x"] 303 | if "y" in out and out["y"] is None: 304 | del out["y"] 305 | 306 | return out 307 | 308 | def to_str(self, **kwargs): 309 | """ 310 | Converts the `Action` instance to a formatted string. 
311 | 312 | Parameters 313 | ---------- 314 | kwargs : dict 315 | Keyword arguments to pass to the `to_dict` method. 316 | 317 | Returns 318 | ------- 319 | str 320 | A formatted string representation of the action. 321 | 322 | Notes 323 | ----- 324 | 325 | This runs the `to_dict` method and then formats the output dictionary as a string, using 326 | `weblinx.utils.format.format_output_dictionary` with the intent as the "function" key. 327 | """ 328 | di = self.to_dict(**kwargs) 329 | return format_output_dictionary(di, function_key="intent", return_as=str) 330 | 331 | def items(self): 332 | """ 333 | Mimics `weblinx.Turn.items()` to retrieve dictionary items of the action. 334 | 335 | Returns 336 | ------- 337 | ItemsView 338 | A view object that displays a list of a dictionary's key-value tuple pairs. 339 | 340 | Notes 341 | ----- 342 | 343 | This method is aimed to mimic `weblinx.Turn.items()` 344 | """ 345 | di = self.to_dict( 346 | include_timestamp=True, 347 | include_attrs=False, 348 | include_tag=False, 349 | include_index=False, 350 | drop_none_coords=True, 351 | ) 352 | 353 | return di.items() 354 | -------------------------------------------------------------------------------- /webllama/experimental/formatting.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import weblinx.utils.format as wlf 3 | 4 | def build_formatters_action_model() -> callable: 5 | """ 6 | Builds and returns a dictionary of formatters for action model events. 7 | 8 | This function uses partial functions from the `weblinx.utils.format` module to create 9 | formatters for various user actions, such as clicks, text inputs, changes, etc. These 10 | formatters are then combined into a single formatter function for automatically formatting 11 | intents based on user actions. 12 | 13 | Returns 14 | ------- 15 | function 16 | A function that formats intents automatically using the defined formatters. 
17 | 18 | Notes 19 | ----- 20 | 21 | Slightly improved over original implementation from weblinx: 22 | https://github.com/McGill-NLP/weblinx/blob/7f151eaf819a9665b9b0b2232a99db6d4c4d2738/modeling/llama/processing.py#L23 23 | """ 24 | format_click = partial(wlf.format_click, formatters=(wlf.format_uid,)) 25 | format_text_input = partial( 26 | wlf.format_text_input, 27 | formatters=( 28 | partial(wlf.format_arg_item, name="text", max_length=200), 29 | wlf.format_uid, 30 | ), 31 | ) 32 | format_change = partial( 33 | wlf.format_change, 34 | formatters=( 35 | partial(wlf.format_arg_item, name="value", max_length=200), 36 | wlf.format_uid, 37 | ), 38 | ) 39 | format_copy = partial(wlf.format_copy, include_timestamp=False) 40 | format_submit = partial(wlf.format_submit, formatters=(wlf.format_uid,)) 41 | format_load = partial( 42 | wlf.format_load, 43 | include_transition=False, 44 | include_timestamp=False, 45 | max_length=200, 46 | ) 47 | format_hover = partial(wlf.format_hover, formatters=(wlf.format_uid,)) 48 | format_paste = partial(wlf.format_paste, include_timestamp=False) 49 | format_scroll = partial(wlf.format_scroll, include_timestamp=False) 50 | format_say = partial(wlf.format_say, include_timestamp=False) 51 | format_tab = wlf.format_tab 52 | 53 | format_intent_auto = partial( 54 | wlf.format_intent_automatically, 55 | format_change=format_change, 56 | format_click=format_click, 57 | format_copy=format_copy, 58 | format_hover=format_hover, 59 | format_load=format_load, 60 | format_paste=format_paste, 61 | format_say=format_say, 62 | format_scroll=format_scroll, 63 | format_submit=format_submit, 64 | format_tab=format_tab, 65 | format_text_input=format_text_input, 66 | ) 67 | 68 | return format_intent_auto 69 | 70 | 71 | def build_formatters_dmr(): 72 | """ 73 | Builds and returns two dictionaries of formatters for DMR (Document Model Retrieval) events. 74 | 75 | This function creates formatters for both input and output events using partial functions 76 | from the `weblinx.utils.format` module. For inputs, it formats elements, clicks, changes, 77 | hovers, submits, and text inputs. For outputs, it formats elements, clicks, changes, loads, 78 | scrolls, and text inputs. 79 | 80 | Returns 81 | ------- 82 | tuple of functions 83 | A tuple containing two functions: one for formatting input intents and one for formatting 84 | output intents. 
85 | 86 | 87 | Examples 88 | ----- 89 | 90 | ```python 91 | format_intent_input, format_intent_out = build_formatters_dmr() 92 | ``` 93 | 94 | """ 95 | format_element_input = partial( 96 | wlf.format_element, 97 | include_text=False, 98 | include_attrs=("class", "title", "href", "aria-label", "d", "src"), 99 | ) 100 | format_click_input = partial( 101 | wlf.format_click, 102 | formatters=(wlf.format_mouse_xy, format_element_input, wlf.format_timestamp), 103 | ) 104 | format_change_input = partial( 105 | wlf.format_change, 106 | formatters=( 107 | partial(wlf.format_arg_item, name="value"), 108 | format_element_input, 109 | wlf.format_timestamp, 110 | ), 111 | ) 112 | format_hover_input = partial( 113 | wlf.format_hover, 114 | formatters=(wlf.format_mouse_xy, format_element_input, wlf.format_timestamp), 115 | ) 116 | 117 | format_submit_input = partial( 118 | wlf.format_submit, formatters=(format_element_input, wlf.format_timestamp) 119 | ) 120 | 121 | format_text_input_input = partial( 122 | wlf.format_text_input, 123 | formatters=( 124 | partial(wlf.format_arg_item, name="text"), 125 | partial(format_element_input), 126 | wlf.format_timestamp, 127 | ), 128 | ) 129 | 130 | format_intent_input = partial( 131 | wlf.format_intent_automatically, 132 | format_click=format_click_input, 133 | format_change=format_change_input, 134 | format_hover=format_hover_input, 135 | format_submit=format_submit_input, 136 | format_text_input=format_text_input_input, 137 | format_tab=wlf.format_tab, 138 | return_as=str, 139 | ) 140 | 141 | # second, for the output (prediction text) 142 | format_element_out = partial( 143 | wlf.format_element, 144 | # Only want the tag 145 | include_text=False, 146 | include_attrs=False, 147 | ) 148 | 149 | format_click_out = partial(wlf.format_click, formatters=(wlf.format_mouse_xy,)) 150 | format_text_input_out = partial( 151 | wlf.format_text_input, 152 | formatters=( 153 | partial(wlf.format_arg_item, name="text", max_length=200), 154 | format_element_out, 155 | wlf.format_target_bbox, 156 | ), 157 | ) 158 | format_change_out = partial( 159 | wlf.format_change, 160 | formatters=( 161 | partial(wlf.format_arg_item, name="value", max_length=200), 162 | format_element_out, 163 | wlf.format_target_bbox, 164 | ), 165 | ) 166 | format_submit_out = partial( 167 | wlf.format_submit, formatters=(format_element_out, wlf.format_target_bbox) 168 | ) 169 | format_load_out = partial( 170 | wlf.format_load, 171 | include_transition=False, 172 | include_timestamp=False, 173 | max_length=200, 174 | ) 175 | format_scroll_out = partial(wlf.format_scroll, include_timestamp=False) 176 | 177 | format_say_out = partial(wlf.format_say, include_timestamp=False) 178 | 179 | format_intent_out = partial( 180 | wlf.format_intent_automatically, 181 | format_change=format_change_out, 182 | format_click=format_click_out, 183 | format_load=format_load_out, 184 | format_say=format_say_out, 185 | format_scroll=format_scroll_out, 186 | format_submit=format_submit_out, 187 | format_text_input=format_text_input_out, 188 | ) 189 | 190 | return format_intent_input, format_intent_out 191 | 192 | -------------------------------------------------------------------------------- /webllama/experimental/integrations/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import browsergym -------------------------------------------------------------------------------- /webllama/experimental/integrations/browsergym/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import random 3 | import lxml.html 4 | 5 | 6 | def postprocess_for_browsergym(action, uid_map=None): 7 | # if uid is a int, we need to convert it to a string 8 | uid_map = {} if uid_map is None else uid_map 9 | 10 | if "uid" in action: 11 | action["uid"] = str(action["uid"]) 12 | if action["uid"] in uid_map: 13 | action["uid"] = uid_map[action["uid"]] 14 | 15 | action = deepcopy(action) 16 | if action["intent"] == "scroll": 17 | if not "x" in action: 18 | action["x"] = 0 19 | if not "y" in action: 20 | action["y"] = 0 21 | 22 | return action 23 | 24 | 25 | def generate_uuid(old_attr_name): 26 | # We do not use old_attr_name here, but it is required by the signature of the function. 27 | def replace_char(c): 28 | r = random.randint(0, 15) 29 | v = r if c == "x" else (r & 0x3 | 0x8) 30 | return format(v, "x") 31 | 32 | uuid_template = "xxxxxxxx-xxxx-4xxx" 33 | return "".join(replace_char(c) if c in "xy" else c for c in uuid_template) 34 | 35 | 36 | def reverse_dict(mapping): 37 | return {v: k for k, v in mapping.items()} 38 | 39 | def replace_bid_with_wl_uid( 40 | dom_str, 41 | new_attr_name="data-webtasks-id", 42 | old_attr_name="bid", 43 | generate_fn=generate_uuid, 44 | return_mapping=False, 45 | ): 46 | """ 47 | Replaces the bid attributes in the dom string with a new attribute name and a new unique id. 48 | 49 | generate_fn must be a function that takes the old attribute name and returns a new unique id. 50 | """ 51 | html_parsed = lxml.html.fromstring(dom_str) 52 | 53 | new_attr_mapping = { 54 | str(elem.get(old_attr_name)): generate_fn(old_attr_name) 55 | for elem in html_parsed.xpath(f"//*[@{old_attr_name}]") 56 | if elem.get(old_attr_name) is not None 57 | } 58 | 59 | # remap the attributes from bid="key" to data-webtasks-id="value" 60 | for elem in html_parsed.xpath("//*[@bid]"): 61 | elem.set(new_attr_name, new_attr_mapping[elem.get(old_attr_name)]) 62 | elem.attrib.pop(old_attr_name) 63 | 64 | html_processed_str = lxml.html.tostring(html_parsed).decode("utf-8") 65 | 66 | if return_mapping: 67 | return html_processed_str, new_attr_mapping 68 | else: 69 | return html_processed_str -------------------------------------------------------------------------------- /webllama/experimental/integrations/browsergym/functions.py: -------------------------------------------------------------------------------- 1 | from browsergym.core.action.utils import get_elem_by_bid 2 | import playwright.sync_api 3 | 4 | page: playwright.sync_api.Page = None 5 | send_message_to_user: callable = None 6 | 7 | # Define your actions here 8 | 9 | def say(utterance: str, *args, **kwargs): 10 | """ 11 | Sends a message to the user. 12 | 13 | Examples: 14 | say("Based on the results of my search, the city was built in 1751.") 15 | """ 16 | send_message_to_user(utterance) 17 | 18 | 19 | def click(uid: str, *args,**kwargs): 20 | """ 21 | Click an element. 22 | 23 | Examples: 24 | click('51') 25 | """ 26 | elem = get_elem_by_bid(page, uid) 27 | elem.click() 28 | 29 | def textinput(uid: str, value: str, *args,**kwargs): 30 | """ 31 | Fill out a form field. It focuses the element and triggers an input event with the entered text. 32 | It works for ,