├── .github
│   ├── scripts
│   │   └── python
│   │       └── update_version.py
│   └── workflows
│       ├── publish-python.yaml
│       └── run-tests.yml
├── .gitignore
├── LICENSE
├── README.md
├── app
│   ├── Results.py
│   ├── data
│   │   └── latex
│   │       ├── column_name_map.json
│   │       ├── custom
│   │       │   └── appendix
│   │       │       ├── column_name_map.json
│   │       │       └── index_name_map.json
│   │       ├── hide_list.json
│   │       ├── index_name_map.json
│   │       ├── project_name_map.json
│   │       └── shortcut_maps.json
│   └── utils.py
├── assets
│   ├── LlamaAndGPT.png
│   ├── LlamaAndGPTAndMindAct.png
│   ├── WebLINXTestSplits.png
│   ├── WebLlamaLogo.png
│   └── llama-3.jpg
├── docs
│   ├── CONTRIBUTING.md
│   └── README.md
├── examples
│   ├── README.md
│   ├── browsergym
│   │   ├── agent.py
│   │   └── run_bg.py
│   ├── complete
│   │   └── run_all.py
│   └── web_api
│       ├── run_client.py
│       └── run_http.py
├── modeling
│   ├── README.md
│   ├── dmr
│   │   ├── conf
│   │   │   └── config.yaml
│   │   ├── eval.py
│   │   ├── processing.py
│   │   └── train.py
│   ├── llama
│   │   ├── accelerate
│   │   │   ├── fsdp_2gpus.yaml
│   │   │   ├── fsdp_4gpus.yaml
│   │   │   ├── fsdp_6gpus.yaml
│   │   │   └── fsdp_8gpus.yaml
│   │   ├── conf
│   │   │   └── config.yaml
│   │   ├── eval.py
│   │   ├── processing.py
│   │   └── train.py
│   └── requirements.txt
├── requirements-basic.txt
├── requirements-extra.txt
├── setup.py
├── tests
│   ├── requirements.txt
│   └── test_web_turn_processor.py
└── webllama
    ├── __init__.py
    ├── experimental
    │   ├── __init__.py
    │   ├── classes.py
    │   ├── formatting.py
    │   ├── functions.py
    │   ├── integrations
    │   │   ├── __init__.py
    │   │   └── browsergym
    │   │       ├── __init__.py
    │   │       └── functions.py
    │   ├── processing.py
    │   ├── templates
    │   │   ├── __init__.py
    │   │   └── weblinx.py
    │   └── web
    │       ├── __init__.py
    │       ├── client.py
    │       └── server.py
    └── version.py
/.github/scripts/python/update_version.py:
--------------------------------------------------------------------------------
1 | """
2 | This CLI script is used to update the version of the package. It is used by the
3 | CI/CD pipeline to update the version of the package when a new release is made.
4 |
5 | It uses argparse to parse the command line arguments, which are the new version
6 | and the path to the package's version file (e.g. webllama/version.py).
7 | """
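# Example invocation (mirroring how the publish workflow calls this script; the
# version number below is just a placeholder):
#   python .github/scripts/python/update_version.py --version 1.2.3 --path webllama/version.py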
8 |
9 | import argparse
10 | from pathlib import Path
11 |
12 | def main():
13 | parser = argparse.ArgumentParser(
14 | description="Update the version of the package."
15 | )
16 | parser.add_argument(
17 | "--version",
18 | type=str,
19 | help="The new version of the package.",
20 | required=True,
21 | )
22 | parser.add_argument(
23 | "--path",
24 | type=Path,
25 | help="The path to the package's version file.",
26 | )
27 | args = parser.parse_args()
28 |
29 | with open(args.path, "w") as f:
30 | f.write(f"__version__ = \"{args.version}\"")
31 |
32 |
33 | if __name__ == "__main__":
34 | main()
--------------------------------------------------------------------------------
/.github/workflows/publish-python.yaml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Publish Python Package
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | bump-version-and-publish:
12 | name: Bump version and upload release to PyPI
13 |
14 | runs-on: ubuntu-latest
15 | permissions:
16 | # IMPORTANT: this permission is mandatory for trusted publishing
17 | id-token: write
18 |
19 | environment:
20 | name: pypi
21 | url: https://pypi.org/p/webllama
22 |
23 | steps:
24 | - uses: actions/checkout@v2
25 | - name: Set up Python
26 | uses: actions/setup-python@v2
27 | with:
28 | python-version: '3.10'
29 |
30 | - name: Update version.py with release tag
31 | env:
32 | RELEASE_TAG: ${{ github.event.release.tag_name }}
33 | run: |
34 | python .github/scripts/python/update_version.py --version $RELEASE_TAG --path "webllama/version.py"
35 |
36 | - name: Install dependencies
37 | run: |
38 | python -m pip install --upgrade pip
39 | pip install setuptools wheel twine
40 |
41 | - name: Build package
42 | run: |
43 | python setup.py sdist bdist_wheel
44 |
45 | - name: Publish package distributions to PyPI
46 | uses: pypa/gh-action-pypi-publish@release/v1
--------------------------------------------------------------------------------
/.github/workflows/run-tests.yml:
--------------------------------------------------------------------------------
1 | name: Run Tests
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 |
9 | jobs:
10 | test:
11 | runs-on: ubuntu-latest
12 |
13 | steps:
14 | - name: Check out repository
15 | uses: actions/checkout@v2
16 |
17 | - name: Set up Python
18 | uses: actions/setup-python@v3
19 | with:
20 | python-version: '3.9' # Specify your required Python version
21 |
22 | - name: Cache Python dependencies
23 | uses: actions/cache@v2
24 | with:
25 | path: ~/.cache/pip
26 | key: ${{ runner.os }}-pip-${{ hashFiles('tests/requirements.txt') }}
27 | restore-keys: |
28 | ${{ runner.os }}-pip-
29 |
30 | - name: Install dependencies
31 | run: |
32 | python -m pip install --upgrade pip
33 | pip install -r tests/requirements.txt # Assumes you have a requirements.txt file
34 |
35 | - name: Cache test assets
36 | uses: actions/cache@v2
37 | with:
38 | path: tests/demonstrations
39 | key: assets-${{ github.sha }}
40 | restore-keys: |
41 | assets-
42 |
43 | - name: Download test demos from release URL into `tests/demonstrations`
44 | run: |
45 | mkdir -p tests/demonstrations
46 | curl -L -o tests/demonstrations/aaabtsd.zip https://github.com/McGill-NLP/weblinx/releases/download/tests-assets/aaabtsd.zip
47 | unzip -u tests/demonstrations/aaabtsd.zip -d tests/demonstrations
48 | curl -L -o tests/demonstrations/aajfwoq.zip https://github.com/McGill-NLP/weblinx/releases/download/tests-assets/aajfwoq.zip
49 | unzip -u tests/demonstrations/aajfwoq.zip -d tests/demonstrations
50 |
51 | - name: Run tests
52 | run: |
53 | python -m unittest discover -s tests
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | # CUSTOM
163 | modeling/checkpoints
164 | modeling/results/
165 | modeling/results/**/hydra_path.txt
166 | modeling/results/**/hashes.json
167 | modeling/results/**/scores-fta-1.csv
168 | modeling/results/**/results.json
169 | modeling/results/**/eval_scores.csv
170 | modeling/results/dmr/**/scores.jsonl
171 | modeling/wl_data
172 | app/data/inputs.json
173 | modeling/logs/
174 | venv*/
175 | .python-version
176 |
177 | # TESTS
178 | tests/demonstrations
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 McGill NLP
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # 🖥️ WebLlama 🦙
4 |
5 | Building agents that can browse the web by following instructions and talking to you
6 |
7 | | 💻 [**GitHub**](https://github.com/McGill-NLP/webllama) | 🏠 [**Homepage**](https://webllama.github.io) | 🤗 [**`Llama-3-8B-Web`**](https://huggingface.co/McGill-NLP/Llama-3-8B-Web) |
8 | | :--: | :--: | :--: |
9 |
10 |
11 |

12 |
13 |
14 |
15 |
16 | | `WebLlama` helps you build powerful agents, powered by Meta Llama 3, for browsing the web on your behalf | Our first model, [`Llama-3-8B-Web`](https://huggingface.co/McGill-NLP/Llama-3-8B-Web), surpasses GPT-4V (`*`zero-shot) by 18% on [`WebLINX 1.0`](https://mcgill-nlp.github.io/weblinx/) |
17 | |:---: | :---: |
18 | |  |  |
19 |
20 | ## About the project
21 |
22 | | `WebLlama` | The goal of our project is to build effective human-centric agents for browsing the web. We don't want to replace users, but equip them with powerful assistants. |
23 | |:---: | :---|
24 | | Modeling | We build on top of cutting-edge libraries for training Llama agents on web navigation tasks. We will provide training scripts, optimized configs, and instructions for training cutting-edge Llamas. |
25 | | Evaluation | Benchmarks for testing Llama models on real-world web browsing. This includes *human-centric* browsing through dialogue ([`WebLINX 1.0`](https://mcgill-nlp.github.io/weblinx/)), and we will soon add more benchmarks for automatic web navigation (e.g. Mind2Web). |
26 | | Data | Our first model is finetuned on over 24K instances of web interactions, including `click`, `textinput`, `submit`, and dialogue acts. We want to continuously curate, compile and release datasets for training better agents. |
27 | | Deployment | We want to make it easy to integrate Llama models with existing deployment platforms, including Playwright, Selenium, and BrowserGym. We are currently focusing on making this a reality. |
28 |
29 |
30 |
31 | <details><summary>Click to show citation</summary>
32 |
33 | If you use `WebLlama` in your research, you can cite the ICML 2024 paper upon which the training and evaluation are originally based, by adding the following to your BibTeX file:
34 |
35 | ```
36 | @misc{lu_2024_weblinx,
37 | title={WebLINX: Real-World Website Navigation with Multi-Turn Dialogue},
38 | author={Xing Han Lù and Zdeněk Kasner and Siva Reddy},
39 | year={2024},
40 | eprint={2402.05930},
41 | archivePrefix={arXiv},
42 | primaryClass={cs.CL}
43 | }
44 | ```
45 |
46 | Example usage (in latex):
47 |
48 | ```
49 | We use the WebLlama library, which builds on top of WebLINX \citep{lu_2024_weblinx}.
50 | ```
51 |
52 | ```
53 | We use Llama-3-8B-Web, a model finetuned on WebLINX demonstrations \citep{lu_2024_weblinx}.
54 | ```
55 |
56 | </details>
57 |
58 | ## Modeling
59 |
60 | > [!NOTE]
61 | > The model is available on the 🤗 Hugging Face Model Hub as [`McGill-NLP/Llama-3-8B-Web`](https://huggingface.co/McGill-NLP/Llama-3-8B-Web). The training and evaluation data is available on [Hugging Face Hub as `McGill-NLP/WebLINX`](https://huggingface.co/datasets/McGill-NLP/WebLINX).
62 |
63 | Our first agent is a finetuned [`Meta-Llama-3-8B-Instruct`](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) model, which was recently released by the Meta GenAI team. We have finetuned this model on the [`WebLINX 1.0`](https://mcgill-nlp.github.io/weblinx/) dataset, which contains over 100K instances of web navigation and dialogue, each collected and verified by expert annotators. We use a curated subset of 24K instances for training.
64 |
65 | 
66 |
67 | **It surpasses GPT-4V (zero-shot `*`) by over 18% on the [`WebLINX 1.0`](https://mcgill-nlp.github.io/weblinx/) benchmark**, achieving an overall score of 28.8% on the out-of-domain test splits (compared to 10.5% for GPT-4V). It chooses more useful links (34.1% vs 18.9% *seg-F1*), clicks on more relevant elements (27.1% vs 13.6% *IoU*) and formulates more aligned responses (37.5% vs 3.1% *chr-F1*).
68 |
69 | It's extremely straightforward to use the model via Hugging Face's `transformers`, `datasets` and `huggingface_hub` libraries:
70 |
71 | ```python
72 | from datasets import load_dataset
73 | from huggingface_hub import snapshot_download
74 | from transformers import pipeline
75 |
76 | # We use validation data, but you can use your own data here
77 | valid = load_dataset("McGill-NLP/WebLINX", split="validation")
78 | snapshot_download("McGill-NLP/WebLINX", repo_type="dataset", allow_patterns="templates/*", local_dir="./")
79 | template = open('templates/llama.txt').read()
80 |
81 | # Run the agent on a single state (text representation) and get the action
82 | state = template.format(**valid[0])
83 | agent = pipeline(model="McGill-NLP/Llama-3-8B-Web")
84 | out = agent(state, return_full_text=False)[0]
85 | print("Action:", out['generated_text'])
86 |
87 | # Here, you can use the predictions on platforms like playwright or browsergym
88 | action = process_pred(out['generated_text']) # implement based on your platform
89 | env.step(action) # execute the action in your environment
90 | ```
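Since `process_pred` is left up to you, here is a minimal illustrative sketch (not part of `webllama`) of what it could look like, assuming the model emits function-call-style actions such as `click(uid="...")` or `say(speaker="navigator", utterance="...")`; mapping the resulting dict to Playwright, Selenium, or BrowserGym commands is up to your integration:

```python
import re

def process_pred(generated_text: str) -> dict:
    """Illustrative only: parse an action string like 'click(uid="abc")' into a dict."""
    match = re.match(r'\s*(\w+)\((.*)\)\s*$', generated_text, flags=re.DOTALL)
    if match is None:
        # Fall back to the raw text so the caller can decide how to handle it
        return {"intent": "unknown", "args": {}, "raw": generated_text}
    intent, arg_str = match.groups()
    # Collect keyword arguments of the form key="value"
    args = dict(re.findall(r'(\w+)="((?:[^"\\]|\\.)*)"', arg_str))
    return {"intent": intent, "args": args, "raw": generated_text}
```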
91 |
92 | ## Evaluation
93 |
94 | We believe that short demo videos showing how well an agent performs are NOT enough to judge an agent. Simply put, **we do not know if we have a good agent if we do not have good benchmarks.** We need to systematically evaluate agents on a wide range of tasks, spanning from simple instruction-following web navigation to complex dialogue-guided browsing.
95 |
96 |
97 |
98 | This is why we chose [`WebLINX`](https://mcgill-nlp.github.io/weblinx/) as our first benchmark. In addition to the training split, the benchmark has 4 real-world splits, with the goal of testing multiple dimensions of generalization: new websites, new domains, unseen geographic locations, and scenarios where the *user cannot see the screen and relies on dialogue*. It also covers 150 websites, including booking, shopping, writing, knowledge lookup, and even complex tasks like manipulating spreadsheets. Evaluating on this benchmark is very straightforward:
99 |
100 | ```bash
101 | cd modeling/
102 |
103 | # After installing dependencies, downloading the dataset, and training/evaluating your model, you can evaluate:
104 | python -m weblinx.eval # automatically find all `results.jsonl` and generate an `aggregated_results.json` file
105 |
106 | # Visualize your results with our app:
107 | cd ..
108 | streamlit run app/Results.py
109 | ```
110 |
111 | > 👷♀️ **Next steps**\
112 | > We are planning to evaluate our models on more benchmarks, including Mind2Web, a benchmark for automatic web navigation. We believe that a good agent should be able to navigate the web both through dialogue and autonomously, and potentially attain even broader ranges of capabilities useful for real-world web browsing.
113 |
114 |
115 | ## Data
116 |
117 | Although the 24K training examples from [`WebLINX 1.0`](https://mcgill-nlp.github.io/weblinx/) provide a good starting point for training a capable agent, we believe that more data is needed to train agents that can generalize to a wide range of web navigation tasks. Although the model has been trained and evaluated on 150 websites, there are millions of websites that have never been seen by the model, with new ones being created every day.
118 |
119 | **This motivates us to continuously curate, compile and release datasets for training better agents.** As an immediate next step, we will be incorporating `Mind2Web`'s training data, which also covers over 100 websites.
120 |
121 | > [!NOTE]
122 | > WebLINX is now available as a benchmark through [BrowserGym](https://github.com/ServiceNow/BrowserGym), allowing you to access demonstration steps in the same way you would access a web agent environment like [WebArena](https://webarena.dev/) or [MiniWoB](https://miniwob.farama.org/index.html). This also allows you to run agents from the [AgentLab](https://github.com/ServiceNow/AgentLab) library, including agents that achieve SOTA performance through Claude-3.5-Sonnet. To enable this integration, we are releasing the `weblinx-browsergym` extension for BrowserGym on PyPI, as well as a [new dataset, WebLINX 1.1, derived from WebLINX, on Hugging Face](https://huggingface.co/datasets/McGill-NLP/weblinx-browsergym). In WebLINX 1.1, a small number of demonstrations were removed after processing, but no new demonstrations were added. There are substantial changes to the steps being evaluated, with the inclusion of tab actions. Please report your results as "WebLINX-1.1", "WebLINX-BrowserGym" or "WebLINX-BG" in your work, to differentiate from the [initial release of weblinx (1.0)](https://huggingface.co/datasets/McGill-NLP/WebLINX/tree/v1.0).
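For illustration, accessing a WebLINX task through BrowserGym could look like the sketch below. This is only a sketch under stated assumptions: the import name `weblinx_browsergym` and the task id format are placeholders, so please check the `weblinx-browsergym` documentation for the exact package name and registered task ids.

```python
import gymnasium as gym
import weblinx_browsergym  # assumed import name; importing is expected to register the WebLINX tasks

# The task id below is a placeholder, not a real id
env = gym.make("browsergym/weblinx.<demo_id>.<step>")
obs, info = env.reset()
# An agent inspects `obs` (the page and dialogue state) and produces an action, then:
# obs, reward, terminated, truncated, info = env.step(action)
```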
123 |
124 |
125 | ## Deployment
126 |
127 | We are working hard to make it easy for you to deploy Llama web agents to the web. We want to integrate `WebLlama` with existing deployment platforms, including Microsoft's Playwright, ServiceNow Research's BrowserGym, and other partners.
128 |
129 | At the moment, we offer the following integrations:
130 | * `Browsergym`: Please find more information in [`examples/README.md`](examples/README.md) and [`docs/README.md`](docs/README.md).
131 |
132 | ## Code
133 |
134 | The code for finetuning the model and evaluating it on the [`WebLINX` 1.0](https://mcgill-nlp.github.io/weblinx/) benchmark is available now.
135 | * **Modeling**: You can find the detailed instructions in [modeling](modeling/README.md) for training `Llama-3-8B-Web` on the `WebLINX` 1.0 dataset.
136 | * **Examples**: We provide a few examples of using the `webllama` API and models, including web API, end-to-end, and BrowserGym integration. You can find them in [examples](examples/README.md).
137 | * **App**: We provide a simple Streamlit app for visualizing the results of your model on the `WebLINX` 1.0 benchmark. You can find the code in [app](app/Results.py).
138 | * **Docs**: We provide detailed documentation for the code in [docs](docs/README.md).
139 |
140 |
141 | > 👷♀️ **Next steps**\
142 | > We are actively working on new data and evaluation at the moment! If you want to help, please create an issue describing what you would like to contribute, and we will be happy to help you get started.
143 |
144 |
145 | ## License
146 |
147 | The code in this repository is licensed under the MIT license, unless otherwise specified in the header of the file. Other materials (models, data, images) have their own licenses, which are specified in the original pages.
148 |
149 | ## FAQ
150 |
151 | ### How can I contribute to the project?
152 |
153 | We are actively looking for collaborators to help us build the best Llama-3 web agents! To get started, open an issue about what you would like to contribute, and once it has been discussed, you can submit a pull request.
154 |
155 |
156 | ## Citation
157 |
158 | If you use `WebLlama` in your research, you can cite the ICML 2024 paper upon which the training and evaluation are originally based, by adding the following to your BibTeX file:
159 |
160 | ```
161 | @misc{lu_2024_weblinx,
162 | title={WebLINX: Real-World Website Navigation with Multi-Turn Dialogue},
163 | author={Xing Han Lù and Zdeněk Kasner and Siva Reddy},
164 | year={2024},
165 | eprint={2402.05930},
166 | archivePrefix={arXiv},
167 | primaryClass={cs.CL}
168 | }
169 | ```
170 |
171 | Example usage (in latex):
172 |
173 | ```
174 | We use the WebLlama library, which builds on top of WebLINX \citep{lu_2024_weblinx}.
175 | ```
176 |
177 | ```
178 | We use Llama-3-8B-Web, a model finetuned on WebLINX demonstrations \citep{lu_2024_weblinx}.
179 | ```
180 |
--------------------------------------------------------------------------------
/app/Results.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | from collections import defaultdict
3 | import os
4 | import time
5 | from datetime import datetime
6 | import json
7 | import random
8 | import string
9 | import shutil
10 | import traceback
11 | import sys
12 | from pathlib import Path
13 | import textwrap as tw
14 |
15 | import streamlit as st
16 | from PIL import Image, ImageDraw
17 | import pandas as pd
18 |
19 | import weblinx as wt
20 |
21 |
22 | sys.path.append(str(Path(__file__).resolve().parent.parent))
23 |
24 | from app.utils import show_overlay
25 |
26 | def remove_latex(x):
27 | if isinstance(x, tuple):
28 | return tuple(map(remove_latex, x))
29 |
30 | if not isinstance(x, str):
31 | return x
32 |
33 | if x.startswith("Unnamed"):
34 | return ""
35 |
36 | if "}" not in x:
37 | return x
38 |
39 | return x.rpartition("}")[0].partition("{")[2]
40 |
41 |
42 | def load_and_clean_df(path):
43 | df = pd.read_csv(path, index_col=0, header=[0, 1])
44 |
45 | df.index.name = "model"
46 | df.columns = df.columns.map(remove_latex)
47 | df.index = df.index.map(remove_latex)
48 |
49 | df = df.reset_index().set_index(["model", "intent"])
50 |
51 | return df
52 |
53 |
54 | def build_cond_df(base_dir="analysis/data/tables/"):
55 | base_dir = Path(base_dir).resolve()
56 | tables_dir = base_dir / "results"
57 |
58 | cond_df = pd.concat(
59 | axis=1,
60 | objs=[
61 | load_and_clean_df(tables_dir / "results_general_grouped_intents.csv"),
62 | load_and_clean_df(tables_dir / "results_text_grouped_intents.csv"),
63 | load_and_clean_df(tables_dir / "results_elem_grouped_intents.csv"),
64 | ],
65 | )
66 |
67 | return cond_df
68 |
69 |
70 | def build_uncond_df(base_dir="analysis/data/tables/"):
71 | base_dir = Path(base_dir).resolve()
72 | tables_dir = base_dir / "results_unconditional"
73 |
74 | uncond_df = pd.concat(
75 | axis=1,
76 | objs=[
77 | load_and_clean_df(
78 | base_dir / "results" / "results_general_grouped_intents.csv"
79 | ),
80 | load_and_clean_df(tables_dir / "results_text_grouped_intents.csv"),
81 | load_and_clean_df(tables_dir / "results_elem_grouped_intents.csv"),
82 | ],
83 | )
84 |
85 | return uncond_df
86 |
87 |
88 | @st.cache_data(ttl=60 * 2)
89 | def build_dataframe(score_path, choice):
90 | with open(score_path) as f:
91 | scores = json.load(f)
92 |
93 | for score in scores:
94 | replacements = [
95 | ("website", "test-web"),
96 | ("blind", "test-vis"),
97 | ("subcategory", "test-cat"),
98 | ("geography", "test-geo"),
99 | ("dev", "dev-deprecated"),
100 | ("indomain", "test-indomain"),
101 | ]
102 | for original, new in replacements:
103 | score["split"] = score["split"].replace(original, new)
104 | df = pd.DataFrame(scores)
105 |
106 | if choice != "Conditional":
107 | df["score"] = df["unconditional_score"]
108 | df.pop("unconditional_score")
109 |
110 | dff = df.pivot(
111 | index=["intent", "project_name", "model_name"],
112 | columns=["split", "metric"],
113 | values="score",
114 | )
115 |
116 | return dff
117 |
118 |
119 | def add_test_avg_inplace(df, fillna_with=0):
120 | splits = df.columns.get_level_values(0).unique().tolist()
121 | test_splits = [x for x in splits if x.startswith("test") and not x.endswith("iid")]
122 | metrics = df.columns.get_level_values(1).unique().tolist()
123 | # We need to take the mean of the test splits for each metric in metrics, we call this test-avg
124 | for metric in metrics:
125 | test_scores = [df[(split, metric)] for split in test_splits]
126 | test_df = pd.concat(test_scores, axis=1)
127 | if fillna_with is not None:
128 | test_df = test_df.fillna(fillna_with)
129 |
130 | df[("test-avg", metric)] = test_df.mean(axis=1)
131 |
132 |
133 | def preset_to_values():
134 | return {
135 | "All approximate": {
136 | "intent": [
137 | "overall",
138 | # "change",
139 | "click",
140 | "load",
141 | "say",
142 | # "scroll",
143 | "submit",
144 | "textinput",
145 | ],
146 | "metric": ["overall", "iou", "chrf", "urlf"],
147 | },
148 | "Group Approximate": {
149 | "intent": [
150 | "overall",
151 | "text-group",
152 | "element-group",
153 | ],
154 | "metric": ["overall", "intent-match", "iou", "chrf-urlf"],
155 | },
156 | "All intent-match": {
157 | "intent": [
158 | "change",
159 | "click",
160 | "load",
161 | "say",
162 | "scroll",
163 | "submit",
164 | "textinput",
165 | ],
166 | "metric": ["intent-match"],
167 | },
168 | "change": {
169 | "intent": ["change"],
170 | "metric": ["intent-match", "iou"],
171 | },
172 | "click": {
173 | "intent": ["click"],
174 | "metric": ["intent-match", "iou"],
175 | },
176 | "say": {
177 | "intent": ["say"],
178 | "metric": ["intent-match", "chrf"],
179 | },
180 | "textinput": {
181 | "intent": ["textinput"],
182 | "metric": ["intent-match", "iou", "chrf"],
183 | },
184 | "load": {
185 | "intent": ["load"],
186 | "metric": ["intent-match", "urlf"],
187 | },
188 | "submit": {
189 | "intent": ["submit"],
190 | "metric": ["intent-match", "iou"],
191 | },
192 |
193 | }
194 |
195 | def latex_sort_func(name):
196 | if not (name.endswith('B') or name.endswith("M")):
197 | return name, 0
198 |
199 | left, sep, right = name.rpartition("-")
200 |
201 | num = float(right[:-1])
202 |
203 | if right.endswith("B"):
204 | rest = 1e9
205 | elif right.endswith("M"):
206 | rest = 1e6
207 | else:
208 | rest = 1
209 |
210 | num = num * rest
211 |
212 | if left.startswith("MindAct"):
213 | left = 0
214 | elif left.startswith("Flan"):
215 | left = 1
216 | elif left.startswith("Pix2Struct"):
217 | left = 2
218 | elif left.startswith("Fuyu"):
219 | left = 3
220 | elif left.startswith("Sheared"):
221 | left = 4
222 | elif left.startswith("Llama"):
223 | left = 5
224 | elif left.startswith("GPT"):
225 | left = 6
226 |
227 | return left, num
228 |
229 |
230 |
231 |
232 | @st.cache_data(ttl=60 * 2)
233 | def filter_models_by_project(projects, df):
234 | # reset all indices except project_name
235 | df = df.copy().reset_index().set_index("project_name")
236 | # filter by project
237 | rem_models = df.loc[projects]["model_name"].unique().tolist()
238 | return rem_models
239 |
240 |
241 | def run(score_path="modeling/results/aggregated_scores.json"):
242 | st.title("Results Table Viewer")
243 |
244 | presets = preset_to_values()
245 |
246 | # Either choose cond or uncond
247 | with st.sidebar:
248 | use_two_cols = st.checkbox(
249 | "Use two columns", value=True, help="Use two columns for the dropdowns"
250 | )
251 |
252 | pivot_intent_index = st.checkbox(
253 | "Show intent as column",
254 | value=True,
255 | help="Whether to show intent as column, or keep it as index",
256 | )
257 |
258 | choice = st.radio(
259 | "Results wrt matched intent",
260 | ["Conditional", "Unconditional"],
261 | help=(
262 | "Conditional: only count samples where the predicted intent matches the reference "
263 |                 "intent (when there is no match, the sample is discarded). "
264 | "Unconditional: counts all samples (when there is no match, the score is set to 0)"
265 | ),
266 | index=1,
267 | )
268 |
269 | preset_choice = st.selectbox("Metric/Intent Preset", list(presets.keys()), index=0)
270 |
271 | remove_na = st.checkbox(
272 | "Drop cols with only NaN", value=True
273 | )
274 |
275 | remove_zero = st.checkbox(
276 | "Drop cols with only 0", value=True
277 | )
278 |
279 | if use_two_cols:
280 | col1, col2 = st.columns(2)
281 | else:
282 | col1 = col2 = st.columns(1)[0]
283 |
284 |
285 | df = build_dataframe(score_path, choice)
286 |
287 | add_test_avg_inplace(df)
288 |
289 | splits = df.columns.get_level_values("split").unique().tolist().copy()
290 | metrics = df.columns.get_level_values("metric").unique().tolist()
291 | models = df.index.get_level_values("model_name").unique().tolist()
292 | intents = df.index.get_level_values("intent").unique().tolist()
293 | projects = df.index.get_level_values("project_name").unique().tolist()
294 |
295 | default_splits = ["valid"]
296 | default_intents = presets[preset_choice]["intent"]
297 | default_metrics = presets[preset_choice]["metric"]
298 |
299 | default_projects = ["llama_ft"]
300 |
301 | splits = col1.multiselect("Split", splits, default=default_splits)
302 | metrics = col1.multiselect("Metric", metrics, default=default_metrics)
303 | intents = col2.multiselect("Intent", intents, default=default_intents)
304 | sort_by_container = col2.container()
305 | projects = col1.multiselect("Project", projects, default=default_projects)
306 |
307 | remaining_models = filter_models_by_project(projects=projects, df=df)
308 | models = col2.multiselect("Model", remaining_models, default=remaining_models)
309 |
310 | if len(projects) == 0:
311 | st.error("Please select at least one project")
312 | st.stop()
313 |
314 | if len(models) == 0:
315 | st.error("Please select at least one model")
316 | st.stop()
317 |
318 | cols = pd.MultiIndex.from_product([splits, metrics], names=["split", "metric"])
319 | # remove all cols not in dff
320 | cols = cols.intersection(df.columns)
321 |
322 | idx = pd.MultiIndex.from_product(
323 | [intents, projects, models], names=["intent", "project_name", "model_name"]
324 | )
325 | # remove all idx not in dff
326 | idx = idx.intersection(df.index)
327 |
328 | dff = df.loc[idx, cols]
329 |
330 | if pivot_intent_index:
331 | dff = dff.reset_index("intent").pivot(columns="intent")
332 |
333 | if remove_na:
334 | dff = dff.dropna(axis=1, how="all")
335 | if remove_zero:
336 | dff = dff.loc[:, (dff != 0).any(axis=0)]
337 |
338 | with sort_by_container:
339 | sort_by = st.selectbox("Sort by", dff.columns.tolist())
340 |
341 | # Sort by
342 | if sort_by:
343 | dff = dff.sort_values(sort_by, ascending=False)
344 |
345 | # swap order of column indices so that we have, in order, split, intent, metric
346 | dff = dff.swaplevel(1,2, axis=1)
347 |
348 |
349 | with st.expander("Latex Table"):
350 | dropped_col_indices = st.multiselect(
351 | "Drop columns levels", dff.columns.names, default=['split']
352 | )
353 | use_shorthand = st.checkbox("Use shorthand", value=False)
354 | use_custom_sorting = st.checkbox("Use custom sorting", value=True)
355 | remove_column_names = st.checkbox("Remove column names", value=True)
356 | merge_index = st.checkbox("Merge index", value=False)
357 |
358 | # dropdown
359 | custom_index_names = st.selectbox(
360 | "Custom index names", ["None", "Appendix"], index=0
361 | )
362 |
363 | custom_column_names = st.selectbox(
364 | "Custom column names", ["None", "Appendix"], index=0
365 | )
366 |
367 | # Rename metrics to symbols for latex
368 | with open("app/data/latex/column_name_map.json") as f:
369 | column_name_map = json.load(f)
370 |
371 | with open("app/data/latex/index_name_map.json") as f:
372 | index_name_map = json.load(f)
373 |
374 | with open("app/data/latex/hide_list.json") as f:
375 | hide_list = json.load(f)
376 | with open("app/data/latex/shortcut_maps.json") as f:
377 | shortcut_maps = json.load(f)
378 |
379 | if custom_index_names == "Appendix":
380 | with open("app/data/latex/custom/appendix/index_name_map.json") as f:
381 | custom_index_name_map = json.load(f)
382 | # update index name map with custom names
383 | index_name_map.update(custom_index_name_map)
384 |
385 | if custom_column_names == "Appendix":
386 | with open("app/data/latex/custom/appendix/column_name_map.json") as f:
387 | custom_column_name_map = json.load(f)
388 | # update column name map with custom names
389 | column_name_map.update(custom_column_name_map)
390 |
391 | # Remove rows from dff_latex if the project_name index and model_name index are in hide_list
392 | dff_latex = dff.copy()
393 | dff_latex = dff_latex.reset_index()
394 | for project_name, model_name in hide_list:
395 | dff_latex = dff_latex[~((dff_latex["project_name"] == project_name) & (dff_latex["model_name"] == model_name))]
396 | dff_latex = dff_latex.set_index(["project_name", "model_name"])
397 | dff_latex = dff_latex.rename(columns=column_name_map)
398 | dff_latex = dff_latex.rename(index=index_name_map)
399 |
400 |
401 | for i in dropped_col_indices:
402 | dff_latex.columns = dff_latex.columns.droplevel(i)
403 |
404 | # Convert all column index names to Capitalized
405 | dff_latex.columns.names = [x.capitalize() for x in dff_latex.columns.names]
406 |
407 | # Sort by index level 0
408 | if use_custom_sorting:
409 | dff_latex = dff_latex.sort_index(
410 | key=lambda index: index.map(latex_sort_func), ascending=True,
411 | )
412 | else:
413 | dff_latex = dff_latex.sort_index(ascending=True)
414 |
415 | if merge_index:
416 | # Join the multiindex into a single index separated by -
417 | dff_latex.index = dff_latex.index.map(lambda x: " - ".join(x))
418 |
419 | if remove_column_names:
420 | dff_latex.columns.names = [None] * len(dff_latex.columns.names)
421 |
422 |         # Scores are shown as percentages, formatted to 2 decimal places below
423 | # Multiply by 100 to get percentage
424 | dff_latex = dff_latex * 100
425 | dff_latex = dff_latex.to_latex(float_format="{:0.2f}".format)
426 |
427 | if use_shorthand:
428 | for full_value, shortcut in shortcut_maps.items():
429 | dff_latex = dff_latex.replace(full_value, shortcut)
430 |
431 | st.code(dff_latex, language="latex")
432 |
433 | with st.expander("Markdown Table"):
434 | st.code(dff.round(4).to_markdown(), language="markdown")
435 |
436 | with st.expander("Results Table", expanded=True):
437 | st.table(dff)
438 |
439 | # Best models
440 | if not pivot_intent_index:
441 | # We need to pivot the table to get the best model per intent
442 | dff = dff.reset_index("intent").pivot(columns="intent")
443 |
444 | best_df = pd.concat([dff.idxmax(), dff.max()], axis=1)
445 | # Set name of best_df.multiindex indexes
446 | best_df.columns = ["best_model", "best_score"]
447 |
448 | best_df["project_name"] = best_df["best_model"].apply(lambda x: x[0] if isinstance(x, tuple) else x)
449 | best_df["best_model"] = best_df["best_model"].apply(lambda x: x[1] if isinstance(x, tuple) else x)
450 | # Reorder columns
451 | best_df = best_df[["project_name", "best_model", "best_score"]]
452 |
453 | with st.expander("Best Models", expanded=True):
454 | st.table(best_df)
455 |
456 |
457 | if __name__ == "__main__":
458 | try:
459 | st.set_page_config(layout="wide")
460 |     except Exception:
461 | pass
462 | # run = protect_with_authentication(run)
463 | run()
464 |
--------------------------------------------------------------------------------
/app/data/latex/column_name_map.json:
--------------------------------------------------------------------------------
1 | {
2 | "chrf": "chrF",
3 | "intent-match": "IM",
4 | "click": "\\texttt{click}",
5 | "submit": "\\texttt{submit}",
6 | "change": "\\texttt{change}",
7 | "textinput": "\\texttt{textinput}",
8 | "load": "\\texttt{load}",
9 | "say": "\\texttt{say}",
10 | "chrf-urlf": "SeqF",
11 | "urlf": "URLF",
12 | "iou": "IoU",
13 | "element-group": "Element Group",
14 | "text-group": "Text Group",
15 | "overall": "Overall"
16 | }
--------------------------------------------------------------------------------
/app/data/latex/custom/appendix/column_name_map.json:
--------------------------------------------------------------------------------
1 | {}
--------------------------------------------------------------------------------
/app/data/latex/custom/appendix/index_name_map.json:
--------------------------------------------------------------------------------
1 | {
2 | "flan_m2w": "SFT",
3 | "flan_mht_v2": "SFT"
4 | }
--------------------------------------------------------------------------------
/app/data/latex/hide_list.json:
--------------------------------------------------------------------------------
1 | [
2 | ["openai", "HuggingFaceH4/zephyr-7b-beta"],
3 | ["llama_fft_mht", "mistralai/Mistral-7B-Instruct-v0.1"],
4 | ["flan_m2w", "google/flan-t5-large"],
5 | ["flan_m2w", "google/flan-t5-base"],
6 | ["flan_mht_v2", "osunlp/MindAct_ActionPrediction_flan-t5-large"],
7 | ["flan_mht_v2", "osunlp/MindAct_ActionPrediction_flan-t5-base"]
8 | ]
--------------------------------------------------------------------------------
/app/data/latex/index_name_map.json:
--------------------------------------------------------------------------------
1 | {
2 | "flan_m2w": "M2W",
3 | "flan_mht_v2": "OTR",
4 | "openai": "0S",
5 | "llama_fft_mht": "SFT",
6 | "google/flan-t5-base": "Flan-T5-250M",
7 | "google/flan-t5-large": "Flan-T5-780M",
8 | "google/flan-t5-xl": "Flan-T5-3B",
9 | "osunlp/MindAct_ActionPrediction_flan-t5-base": "MindAct-T5-250M",
10 | "osunlp/MindAct_ActionPrediction_flan-t5-large": "MindAct-T5-780M",
11 | "osunlp/MindAct_ActionPrediction_flan-t5-xl": "MindAct-T5-3B",
12 | "princeton-nlp/Sheared-LLaMA-1.3B": "Sheared-LLaMA-1.3B",
13 | "princeton-nlp/Sheared-LLaMA-2.7B": "Sheared-LLaMA-2.7B",
14 | "meta-llama/Llama-2-7b-chat-hf": "Llama-2-7B",
15 | "meta-llama/Llama-2-13b-chat-hf": "Llama-2-13B",
16 | "gpt-3.5-turbo-1106": "GPT-3.5T",
17 | "gpt-4-1106-preview": "GPT-4T",
18 | "gpt-4-vision-preview": "GPT-4V",
19 | "google/pix2struct-base": "Pix2Struct-282M",
20 | "google/pix2struct-large": "Pix2Struct-1.3B",
21 | "adept/fuyu-8b": "Fuyu-8B",
22 | "fuyu": "SFT",
23 | "pix2struct": "SFT",
24 | "ft:gpt-3.5-turbo-1106:mcgill-nlp:webtasks-mht:8XWKFM3a": "GPT-3.5F"
25 | }
--------------------------------------------------------------------------------
/app/data/latex/project_name_map.json:
--------------------------------------------------------------------------------
1 | {
2 | "flan_m2w": "MindAct"
3 | }
--------------------------------------------------------------------------------
/app/data/latex/shortcut_maps.json:
--------------------------------------------------------------------------------
1 | {
2 | "Sheared-LLaMA": "S-LLaMA",
3 | "MindAct-T5": "MindAct",
4 | "submit": "sbmt",
5 | "textinput": "input",
6 | "Overall": "All",
7 | "URLF": "urlF",
8 | "\\multicolumn{2}{r}{All}": "All & All",
9 | "Text Group": "TG",
10 | "Element Group": "EG"
11 | }
--------------------------------------------------------------------------------
/app/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from datetime import datetime
4 | import os
5 | import json
6 | from pathlib import Path
7 | import sys
8 | import shutil
9 | import time
10 | import traceback
11 |
12 | import pandas as pd
13 | import streamlit as st
14 |
15 | import json
16 | from PIL import Image, ImageDraw
17 |
18 |
19 | CACHE_TTL = 60 * 60 * 24 * 14
20 |
21 | """
22 | Streamlit app utilities
23 | """
24 |
25 |
26 | @st.cache_data(ttl=CACHE_TTL)
27 | def load_json(basedir, name):
28 | if not os.path.exists(f"{basedir}/{name}.json"):
29 | return None
30 |
31 | with open(f"{basedir}/{name}.json", "r") as f:
32 | j = json.load(f)
33 |
34 | return j
35 |
36 |
37 | def load_json_no_cache(basedir, name):
38 | if not os.path.exists(f"{basedir}/{name}.json"):
39 | return None
40 |
41 | with open(f"{basedir}/{name}.json", "r") as f:
42 | j = json.load(f)
43 |
44 | return j
45 |
46 |
47 | def save_json(basedir, name, data):
48 | with open(f"{basedir}/{name}.json", "w") as f:
49 | json.dump(data, f, indent=4)
50 |
51 |
52 | @st.cache_data
53 | def load_image(image_file):
54 | img = Image.open(image_file)
55 | return img
56 |
57 |
58 | @st.cache_resource
59 | def load_page(page_path):
60 | return open(page_path, "rb")
61 |
62 |
63 | def shorten(s):
64 | # shorten to 100 characters
65 | if len(s) > 100:
66 | s = s[:100] + "..."
67 |
68 | return s
69 |
70 |
71 | @st.cache_data
72 | def parse_arguments(action):
73 | s = []
74 | event_type = action["intent"]
75 | args = action["arguments"]
76 |
77 | if event_type == "textInput":
78 | txt = args["text"]
79 |
80 | txt = txt.strip()
81 |
82 | # escape markdown characters
83 | txt = txt.replace("_", "\\_")
84 | txt = txt.replace("*", "\\*")
85 | txt = txt.replace("`", "\\`")
86 | txt = txt.replace("$", "\\$")
87 |
88 | txt = shorten(txt)
89 |
90 | s.append(f'"{txt}"')
91 | elif event_type == "change":
92 | s.append(f'{args["value"]}')
93 | elif event_type == "load":
94 | url = args["properties"].get("url") or args.get("url")
95 | short_url = shorten(url)
96 | s.append(f'"[{short_url}]({url})"')
97 |
98 | if args["properties"].get("transitionType"):
99 | s.append(f'*{args["properties"]["transitionType"]}*')
100 | s.append(f'*{" ".join(args["properties"]["transitionQualifiers"])}*')
101 | elif event_type == "scroll":
102 | s.append(f'{args["scrollX"]}, {args["scrollY"]}')
103 | elif event_type == "say":
104 | s.append(f'"{args["text"]}"')
105 | elif event_type == "copy":
106 | selected = shorten(args["selected"])
107 | s.append(f'"{selected}"')
108 | elif event_type == "paste":
109 | pasted = shorten(args["pasted"])
110 | s.append(f'"{pasted}"')
111 | elif event_type == "tabcreate":
112 | s.append(f'{args["properties"]["tabId"]}')
113 | elif event_type == "tabremove":
114 | s.append(f'{args["properties"]["tabId"]}')
115 | elif event_type == "tabswitch":
116 | s.append(
117 | f'{args["properties"]["tabIdOrigin"]} -> {args["properties"]["tabId"]}'
118 | )
119 |
120 | if args.get("element"):
121 |
122 | if event_type == 'click':
123 | x = round(args['metadata']['mouseX'], 1)
124 | y = round(args['metadata']['mouseY'], 1)
125 | uid = args.get('element', {}).get('attributes', {}).get("data-webtasks-id")
126 | s.append(f"*x =* {x}, *y =* {y}, *uid =* {uid}")
127 | else:
128 | top = round(args["element"]["bbox"]["top"], 1)
129 | left = round(args["element"]["bbox"]["left"], 1)
130 | right = round(args["element"]["bbox"]["right"], 1)
131 | bottom = round(args["element"]["bbox"]["bottom"], 1)
132 |
133 | s.append(f"*top =* {top}, *left =* {left}, *right =* {right}, *bottom =* {bottom}")
134 |
135 | return ", ".join(s)
136 |
137 |
138 | @st.cache_resource(max_entries=50_000, ttl=CACHE_TTL)
139 | def create_visualization(_img, event_type, bbox, x, y, screenshot_path):
140 | # screenshot_path is not used, but we need it for caching since we can't cache
141 | # PIL images (hence the leading underscore in the variable name to indicate
142 | # that it's not hashed)
143 | _img = _img.convert("RGBA")
144 | draw = ImageDraw.Draw(_img)
145 |
146 | # draw a bounding box around the element
147 | color = {
148 | "click": "red",
149 | "hover": "orange",
150 | "textInput": "blue",
151 | "change": "green",
152 | "submit": "purple",
153 | }[event_type]
154 |
155 | left = bbox["left"]
156 | top = bbox["top"]
157 | w = bbox["width"]
158 | h = bbox["height"]
159 | draw.rectangle((left, top, left + w, top + h), outline=color, width=2)
160 |
161 | if event_type in ["click", "hover"]:
162 | r = 15
163 | for i in range(1, 5):
164 | rx = r * i
165 | draw.ellipse((x - rx, y - rx, x + rx, y + rx), outline=color, width=3)
166 | draw.ellipse((x - r, y - r, x + r, y + r), fill=color)
167 |
168 | return _img
169 |
170 |
171 | @st.cache_data(max_entries=50_000, ttl=CACHE_TTL)
172 | def get_screenshot_minimal(screenshot_path, event_type, bbox, x, y, new_width=None, overlay=True):
173 | img = load_image(screenshot_path)
174 | # vis = None
175 |
176 | if event_type in ["click", "textInput", "change", "hover", "submit"] and overlay:
177 | img = create_visualization(img, event_type, bbox, x, y, screenshot_path)
178 |
179 | if new_width is not None:
180 | # Resize to 800px wide
181 | w, h = img.size
182 | new_w = new_width
183 | new_h = int(new_w * h / w)
184 | img = img.resize((new_w, new_h))
185 | print(f"Resized '{screenshot_path}' to", new_w, new_h)
186 |
187 | return img
188 |
189 |
190 | def get_event_info(d):
191 | event_type = d["action"]["intent"]
192 |
193 | try:
194 | bbox = d["action"]["arguments"]["element"]["bbox"]
195 | except KeyError:
196 | bbox = None
197 |
198 | try:
199 | x = d["action"]["arguments"]["properties"]["x"]
200 | y = d["action"]["arguments"]["properties"]["y"]
201 | except KeyError:
202 | x = None
203 | y = None
204 |
205 | return event_type, bbox, x, y
206 |
207 |
208 | def get_screenshot(d, basedir, new_width=None, overlay=True):
209 | screenshot_filename = d["state"]["screenshot"]
210 |
211 | if not screenshot_filename:
212 | return None
213 |
214 | event_type, bbox, x, y = get_event_info(d)
215 | screenshot_path = f"{basedir}/screenshots/{screenshot_filename}"
216 |
217 | return get_screenshot_minimal(
218 | screenshot_path, event_type, bbox, x, y, new_width=new_width, overlay=overlay
219 | )
220 |
221 |
222 | def text_bubble(text, color):
223 |     text = text.replace("\n", "<br>").replace("\t", " " * 8)
224 |     return f'<span style="background-color: {color}; padding: 3px; border-radius: 5px">{text}</span>'
225 |
226 |
227 | def gather_chat_history(data, example_index):
228 | chat = []
229 | for i, d in enumerate(data):
230 | if d["type"] == "chat":
231 | if i >= example_index:
232 | break
233 | chat.append(d)
234 |
235 | # # leave out just 5 last messages
236 | # if len(chat) > 5:
237 | # chat = chat[-5:]
238 |
239 | return reversed(chat)
240 |
241 |
242 | def format_chat_message(d):
243 | if d["speaker"] == "instructor":
244 | return text_bubble("🧑 " + d["utterance"], "rgba(63, 111, 255, 0.35)")
245 | else:
246 | return text_bubble("🤖 " + d["utterance"], "rgba(185,185,185,0.35)")
247 |
248 |
249 | def find_screenshot(data, example_index, basedir, overlay=True):
250 | # keep looking at previous screenshots until we find one
251 | # if there is none, return None
252 |
253 | for i in range(example_index, -1, -1):
254 | d = data[i]
255 | if d["type"] == "chat":
256 | continue
257 |
258 | screenshot = get_screenshot(d, basedir, overlay=overlay)
259 | if screenshot:
260 | return screenshot
261 |
262 | return None
263 |
264 |
265 | def create_visualization_2(_img, bbox, color, width, x, y):
266 | _img = _img.convert("RGBA")
267 | draw = ImageDraw.Draw(_img)
268 |
269 | if bbox:
270 | left = bbox["left"]
271 | top = bbox["top"]
272 | w = bbox["width"]
273 | h = bbox["height"]
274 | draw.rectangle((left, top, left + w, top + h), outline=color, width=width)
275 |
276 | if x and y:
277 | r = 8
278 | for i in range(1, 4):
279 | rx = r * i
280 | draw.ellipse((x - rx, y - rx, x + rx, y + rx), outline=color, width=2)
281 | draw.ellipse((x - r, y - r, x + r, y + r), fill=color)
282 |
283 | return _img
284 |
285 |
286 | def rescale_bbox(bbox, scaling_factor):
287 | return {
288 | k: bbox[k] * scaling_factor
289 | for k in ["top", "left", "width", "height", "right", "bottom"]
290 | if k in bbox
291 | }
292 |
293 |
294 | def show_overlay(
295 | _img,
296 | pred,
297 | ref,
298 | turn_args,
299 | turn_metadata,
300 | scale_pred=True,
301 | show=("pred_coords", "ref", "pred_elem"),
302 | ):
303 | scaling_factor = turn_metadata.get("zoomLevel", 1.0)
304 |
305 | if "pred_elem" in show:
306 | # First, draw red box around predicted element
307 | if pred.get("element") and pred["element"].get("bbox"):
308 | # rescale the bbox by scaling_factor
309 | bbox = rescale_bbox(pred["element"]["bbox"], scaling_factor)
310 | _img = create_visualization_2(
311 | _img, bbox, color="red", width=9, x=None, y=None
312 | )
313 |
314 | if "ref" in show:
315 | # Finally, draw a blue box around the reference element (if it exists)
316 | if ref.get("element") and ref["element"].get("bbox"):
317 | # rescale the bbox
318 | bbox = rescale_bbox(ref["element"]["bbox"], scaling_factor)
319 | x = turn_args.get("properties", {}).get("x")
320 | y = turn_args.get("properties", {}).get("y")
321 | _img = create_visualization_2(_img, bbox, color="blue", width=6, x=x, y=y)
322 |
323 | if "pred_coords" in show:
324 | # Second draw a green box and x/y coordinate based on predicted coordinates
325 | # The predicted coordinates are the raw output of the model,
326 | # Whereas the predicted element is the inferred element from the predicted coordinates
327 | if pred["args"].get("x") and pred["args"].get("y"):
328 | x = pred["args"]["x"]
329 | y = pred["args"]["y"]
330 |
331 | if scale_pred:
332 | x = x * scaling_factor
333 | y = y * scaling_factor
334 | else:
335 | x = None
336 | y = None
337 |
338 | # If the predicted element is a bounding box, draw a green box around it
339 | if all(c in pred["args"] for c in ["top", "left", "right", "bottom"]):
340 | bbox = {
341 | "top": pred["args"]["top"],
342 | "left": pred["args"]["left"],
343 | "width": (pred["args"]["right"] - pred["args"]["left"]),
344 | "height": (pred["args"]["bottom"] - pred["args"]["top"]),
345 | "right": pred["args"]["right"],
346 | "bottom": pred["args"]["bottom"],
347 | }
348 |
349 | if scale_pred:
350 | bbox = rescale_bbox(bbox, scaling_factor)
351 | else:
352 | # Otherwise, do nothing
353 | bbox = None
354 |
355 | _img = create_visualization_2(_img, bbox=bbox, color="green", width=3, x=x, y=y)
356 |
357 | return _img
358 |
359 |
360 |
361 | def get_zoom_level(d):
362 | """
363 | Get the zoom level of the page
364 | """
365 |
366 |     # If it's a chat turn, there is no page, so return the default zoom level of 100 (i.e. 100%)
367 | if d["type"] == "chat":
368 | return 100
369 |
370 |     # d is the turn; the zoom level is stored on the turn's action arguments,
371 |     # either under properties.zoomLevel or metadata.zoomLevel.
372 |     # We check both locations below.
373 |     # If neither is present, raise an error.
374 | option1 = (
375 | d.get("action", {})
376 | .get("arguments", {})
377 | .get("properties", {})
378 | .get("zoomLevel")
379 | )
380 | option2 = (
381 | d.get("action", {})
382 | .get("arguments", {})
383 | .get('metadata', {})
384 | .get("zoomLevel")
385 | )
386 |
387 | if option1 is not None:
388 | return option1
389 | elif option2 is not None:
390 | return option2
391 | else:
392 | raise ValueError("Zoom level not found in the turn.")
--------------------------------------------------------------------------------
/assets/LlamaAndGPT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/McGill-NLP/webllama/696a7c3664fe6610b411a16d27010055874e2714/assets/LlamaAndGPT.png
--------------------------------------------------------------------------------
/assets/LlamaAndGPTAndMindAct.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/McGill-NLP/webllama/696a7c3664fe6610b411a16d27010055874e2714/assets/LlamaAndGPTAndMindAct.png
--------------------------------------------------------------------------------
/assets/WebLINXTestSplits.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/McGill-NLP/webllama/696a7c3664fe6610b411a16d27010055874e2714/assets/WebLINXTestSplits.png
--------------------------------------------------------------------------------
/assets/WebLlamaLogo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/McGill-NLP/webllama/696a7c3664fe6610b411a16d27010055874e2714/assets/WebLlamaLogo.png
--------------------------------------------------------------------------------
/assets/llama-3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/McGill-NLP/webllama/696a7c3664fe6610b411a16d27010055874e2714/assets/llama-3.jpg
--------------------------------------------------------------------------------
/docs/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Instructions
2 |
3 | ## Running tests
4 |
5 | To run the unit tests, run:
6 |
7 | ```bash
8 | python -m unittest discover -s tests
9 | ```
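
To run a single test module instead of the whole suite (for example, the turn-processor tests included in this repository), you can narrow the discovery pattern:

```bash
python -m unittest discover -s tests -p "test_web_turn_processor.py"
```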
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # `webllama.experimental` API
2 |
3 | `webllama.experimental` is the new experimental API for working with webllama models. It will eventually be moved to `webllama` directly (once the API is deemed stable).
4 |
5 |
6 | ## Setup
7 |
8 | ```bash
9 | # Please choose the proper version to ensure you do not break the code
10 | # if there are breaking changes in the future.
11 | # e.g. 0.1.0
12 | pip install webllama==""
13 | ```
14 |
15 | You will need to download test demonstrations if you want to run the subsequent scripts that use existing weblinx demonstrations.
16 |
17 | ```bash
18 | mkdir -p tests/demonstrations
19 | curl -L -o tests/demonstrations/aaabtsd.zip https://github.com/McGill-NLP/weblinx/releases/download/tests-assets/aaabtsd.zip
20 | unzip -u tests/demonstrations/aaabtsd.zip -d tests/demonstrations
21 | curl -L -o tests/demonstrations/aajfwoq.zip https://github.com/McGill-NLP/weblinx/releases/download/tests-assets/aajfwoq.zip
22 | unzip -u tests/demonstrations/aajfwoq.zip -d tests/demonstrations
23 | ```
24 |
25 | ## Quickstart with `webllama.experimental.processing`
26 |
27 | To install:
28 | ```bash
29 | pip install webllama
30 | # if you want to install transformers, pytorch and sentence-transformers, run:
31 | pip install webllama[modeling]
32 | ```
33 |
34 | First, you will need to construct your own `action_history` and `state` using `webllama.experimental.classes`:
35 | ```python
36 | import webllama.experimental as wa
37 |
38 | # Create your action history and state!
39 | action_history = [
40 | wa.classes.Action(...), # ...
41 | ]
42 | state = wa.classes.State(...)
43 | ```
44 |
45 | You will also need to load your `dmr` and `act_model` models. For example, you can use `transformers` and `sentence-transformers` to load them:
46 | ```python
47 | from sentence_transformers import SentenceTransformer
48 | from transformers import AutoTokenizer, pipeline
49 |
50 | # You can choose your own DMR model and action model, e.g.:
51 | dmr = SentenceTransformer("McGill-NLP/MiniLM-L6-dmr", device="cuda")
52 | act_model = pipeline(model="McGill-NLP/Sheared-LLaMA-2.7B-weblinx", device=0, torch_dtype="auto")
53 | ```
54 |
55 | Now, inside a Python script, you can use the `webllama.experimental.processing` module to seamlessly use `Action` and `State` with the action model and DMR, and also process the output:
56 |
57 | ```python
58 | import webllama.experimental as wa
59 |
60 | # We will initialize our processor, which helps us prepare the input for the action model
61 | proc = wa.processing.WebTurnProcessor(tokenizer=act_model.tokenizer)
62 |
63 | # Step 1: prepare query, run DMR and prepare retrieved candidates
64 | query_dmr = proc.prepare_dmr_query(action_history, state)
65 | elems = proc.prepare_dmr_elements(state=state)
66 | scores = wa.functions.compute_dmr_scores(dmr, query_dmr, elems)
67 | top_cands, cands_uids = wa.functions.get_top_dmr_candidates(elems, scores, proc.top_k)
68 | cands_str = proc.prepare_candidates(top_cands)
69 |
70 | # Step 2: format candidates, utterances, state, and previous actions
71 | html = proc.prepare_state_html(state.html, cands_uids=cands_uids)
72 | utterances = proc.prepare_instructor_chat(action_history, state)
73 | prev_actions = proc.prepare_prev_actions(action_history, state)
74 |
75 | # Let's use the default system prompt template, but you can also use your own
76 | sys_prompt_template: str = proc.default_system_prompt_template
77 | sys_prompt = sys_prompt_template.format(
78 | html=html,
79 | utterances=utterances,
80 | candidates=cands_str,
81 | # ...
82 | )
83 | input_chat = proc.convert_to_chat_list(sys_prompt, prev_actions)
84 |
85 | # Use your tokenizer to convert the input to string and pass it to the action model
86 | input_str = act_model.tokenizer.apply_chat_template(input_chat, tokenize=False)
87 | output = act_model(input_str, ...)
88 | pred_action = proc.process_action_model_output(output, state.index, elems)
89 | a = wa.classes.Action.from_dict(pred_action)
90 | ```
91 |
92 |
93 | ## End-to-end example
94 |
95 | Here's a full, self-contained example of how to use `webllama.experimental` to interact with a web page using a DMR model and an action model:
96 |
97 | ```python
98 | from functools import partial
99 | import time
100 | import logging
101 |
102 | from sentence_transformers import SentenceTransformer
103 | from transformers import AutoTokenizer, pipeline
104 | import weblinx as wl
105 | import webllama.experimental as wa
106 |
107 | logging.getLogger("urllib3").setLevel(logging.WARNING)
108 |
109 | # Prerequisite: We need a `wa.classes.State` object and a list of `wa.classes.Action` objects
110 | # To get that, we will use an example from weblinx, but it's easy to do manually (see below).
111 |
112 | demos = wl.list_demonstrations("tests/demonstrations")
113 | replay = wl.Replay.from_demonstration(demos[0])
114 | turn = replay[26]
115 |
116 | format_intent_am = partial(
117 | wa.formatting.build_formatters_action_model(), return_as=dict
118 | )
119 | format_intent_input_dmr, format_intent_out_dmr = wa.formatting.build_formatters_dmr()
120 | action_history = wa.functions.create_action_history_from_replay(
121 | replay, format_intent_input_dmr, format_intent_am, stop_at_idx=turn.index
122 | )
123 | state = wa.classes.State(
124 | index=turn.index,
125 | html=turn.html,
126 | bboxes=turn.bboxes,
127 | viewport_height=turn.viewport_height,
128 | viewport_width=turn.viewport_width,
129 | type=turn.type,
130 | )
131 |
132 | # Now, we can start!
133 | # First, load the DMR model we will use to select candidate elements
134 | dmr_name = "McGill-NLP/MiniLM-L6-dmr"
135 | action_model_name = "McGill-NLP/Sheared-LLaMA-2.7B-weblinx"
136 | tokenizer_chat_name = "McGill-NLP/Llama-2-7b-chat-weblinx"
137 |
138 | tokenizer_chat = AutoTokenizer.from_pretrained(tokenizer_chat_name)
139 | act_model = pipeline(model=action_model_name, device=0, torch_dtype="auto")
140 | dmr = SentenceTransformer(dmr_name, device="cuda")
141 |
142 | # We will initialize our processor, which helps us prepare the input for the action model
143 | proc = wa.processing.WebTurnProcessor(tokenizer=act_model.tokenizer, start_time=time.time())
144 |
145 | # Step 1: prepare query, run DMR and prepare retrieved candidates
146 | query_dmr = proc.prepare_dmr_query(action_history, state)
147 | elems = proc.prepare_dmr_elements(state=state)
148 | scores = wa.functions.compute_dmr_scores(dmr, query_dmr, elems)
149 | top_cands, cands_uids = wa.functions.get_top_dmr_candidates(elems, scores, proc.top_k)
150 | cands_str = proc.prepare_candidates(top_cands)
151 |
152 | # Step 2: format candidates, utterances, state, and previous actions
153 | html = proc.prepare_state_html(state.html, cands_uids=cands_uids)
154 | utterances = proc.prepare_instructor_chat(action_history, state)
155 | prev_actions = proc.prepare_prev_actions(action_history, state)
156 |
157 | # Let's use the default system prompt template, but you can also use your own
158 | sys_prompt_template: str = proc.default_system_prompt_template
159 | sys_prompt = sys_prompt_template.format(
160 | html=html,
161 | num_utterances=proc.num_utterances - 1,
162 | utterances=utterances,
163 | height=state.viewport_height,
164 | width=state.viewport_width,
165 | num_prev_actions=proc.num_prev_actions,
166 | candidates=cands_str,
167 | )
168 | input_chat = proc.convert_to_chat_list(sys_prompt, prev_actions)
169 |
170 | # We can now use the tokenizer's apply_chat_template method to convert it to a format
171 | # that can be used by the action model
172 | input_str = tokenizer_chat.apply_chat_template(input_chat, tokenize=False)
173 |
174 | # Let's now pass our input to the action model
175 | output = act_model(
176 | input_str,
177 | max_new_tokens=256,
178 | return_full_text=False,
179 | batch_size=1,
180 |     pad_token_id=act_model.tokenizer.eos_token_id,
181 | )
182 | pred_action = proc.process_action_model_output(
183 | output=output, index=state.index, elems=elems
184 | )
185 | # optional: For certain platforms you may need to postprocess the action
186 | pred_action = wa.integrations.browsergym.postprocess_for_browsergym(pred_action)
187 | print(pred_action)
188 | # You can now convert this to an Action object and add it to the action history
189 | a = wa.classes.Action.from_dict(pred_action)
190 | action_history.append(a)
191 | ```
192 |
193 | ## Tests
194 |
195 | To run the tests:
196 |
197 | ```bash
198 | python -m unittest discover -s tests
199 | ```
200 |
201 | ## Web API
202 |
203 | ### Running Server
204 |
205 | To launch the default server:
206 | ```bash
207 | # If you do not want to save logs, omit `--save_logs`
208 | python -m webllama.experimental.web.server --save_logs
209 | ```
210 |
211 | To create your own server, simply inherit from `Server` and override `initialize` and `run`:
212 | ```python
213 | import json
214 | import logging
215 | 
216 | from webllama.experimental.web.server import Server
217 | from webllama.experimental.classes import Action, State
218 | 
219 | # Initialize logging
220 | logging.basicConfig(level=logging.INFO)
221 | 
222 | class MyServer(Server):
223 |     # override initialize and run
224 |     def initialize(self, dmr_name, action_model_name, device, dmr_device, am_device, torch_dtype):
225 |         # initialize your models here
226 |         ...
227 | 
228 |     def run(self, action_history_json, state_json):
229 |         # ...
230 |         pred_action = {
231 |             # ...
232 |         }
233 |         return json.dumps(pred_action)
234 | ```
235 |
236 | ### Connecting via SSH
237 |
238 | To connect to the server via SSH, you can use the following command:
239 | ```bash
240 | ssh -N -L 8450:localhost:8450 user@server
241 |
242 | # Example:
243 | ssh -N -L 8450:localhost:8450 nlp-gpu-2
244 | ```
245 |
246 | ### Using API
247 |
248 | You can send HTTP requests directly to the web server, or use the client.
249 | 
250 | Example of an HTTP request in Python:
251 |
252 | ```python
253 | from functools import partial
254 | import http.client
255 | import json
256 |
258 | import webllama.experimental as wa
259 | import weblinx as wl
260 |
261 | # Prerequisite: We need a `wa.classes.State` object and a list of `wa.classes.Action` objects
262 | demos = wl.list_demonstrations("tests/demonstrations")
263 | replay = wl.Replay.from_demonstration(demos[0])
264 | turn = replay[26]
265 |
266 | format_intent_am = partial(
267 | wa.formatting.build_formatters_action_model(), return_as=dict
268 | )
269 | format_intent_input_dmr, format_intent_out_dmr = wa.formatting.build_formatters_dmr()
270 | action_history = wa.functions.create_action_history_from_replay(
271 | replay, format_intent_input_dmr, format_intent_am, stop_at_idx=turn.index
272 | )
273 | state = wa.classes.State(
274 | index=turn.index,
275 | html=turn.html,
276 | bboxes=turn.bboxes,
277 | viewport_height=turn.viewport_height,
278 | viewport_width=turn.viewport_width,
279 | type=turn.type,
280 | )
281 |
282 | # Create a connection to the localhost on the port where your server is running
283 | conn = http.client.HTTPConnection('localhost', 8450)
284 |
285 | # Prepare the POST request data
286 | post_data = json.dumps({
287 |     'action_history': [action.to_dict() for action in action_history],
288 |     'state': state.to_dict()
289 | })
290 | headers = {'Content-Type': 'application/json'}
291 |
292 | # Send a POST request with JSON data
293 | conn.request("POST", "/", body=post_data, headers=headers)
294 | response = conn.getresponse()
295 | print(f"Status: {response.status}")
296 | print(f"Reason: {response.reason}")
297 | print(f"Body: {response.read().decode()}")
298 | response.close()
299 |
300 | # Close the connection
301 | conn.close()
302 | ```
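303 | 
304 | Since the default server returns the predicted action as JSON (see `json.dumps(pred_action)` in the server example above), you can parse the response body back into an `Action`. A minimal sketch, assuming the decoded body holds the predicted action dictionary:
305 | 
306 | ```python
307 | # `body` is the decoded response body from the request above,
308 | # e.g. body = response.read().decode() (read it once, before closing the response)
309 | pred_action = json.loads(body)
310 | a = wa.classes.Action.from_dict(pred_action)
311 | action_history.append(a)
312 | ```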
303 |
304 | ### Client
305 |
306 | A high-level client is provided in `webllama.experimental.web.client`. You can use it as follows:
307 |
308 | ```python
309 | from functools import partial
310 | import webllama.experimental as wa
311 | import weblinx as wl
312 |
313 | # Prerequisite: We need a `wa.classes.State` object and a list of `wa.classes.Action` objects
314 | demos = wl.list_demonstrations("tests/demonstrations")
315 | replay = wl.Replay.from_demonstration(demos[0])
316 | turn = replay[26]
317 |
318 | format_intent_am = partial(
319 | wa.formatting.build_formatters_action_model(), return_as=dict
320 | )
321 | format_intent_input_dmr, format_intent_out_dmr = wa.formatting.build_formatters_dmr()
322 | action_history = wa.functions.create_action_history_from_replay(
323 | replay, format_intent_input_dmr, format_intent_am, stop_at_idx=turn.index
324 | )
325 | state = wa.classes.State(
326 | index=turn.index,
327 | html=turn.html,
328 | bboxes=turn.bboxes,
329 | viewport_height=turn.viewport_height,
330 | viewport_width=turn.viewport_width,
331 | type=turn.type,
332 | )
333 |
334 | # Now, we can start!
335 | pred_action = wa.web.client.get_prediction(
336 | action_history, state, address="localhost", port=8450, max_new_tokens=128
337 | )
338 | pred_action = wa.integrations.browsergym.postprocess_for_browsergym(pred_action)
339 | print(pred_action)
340 | a = wa.classes.Action.from_dict(pred_action)
341 | print(a)
342 | ```
343 |
344 | ## Building objects
345 |
346 | > Note: This section is a work in progress.
347 |
348 | ### Build `webllama.experimental.classes.Action`
349 |
350 | #### `say` action
351 |
352 | ```python
353 | utterance_instructor = wa.classes.Action(
354 | type="chat",
355 | intent="say",
356 | index=2,
357 | args=dict(
358 | speaker="instructor", utterance="Open independent ie Website.", x=None, y=None
359 | ),
360 | timestamp=13.234,
361 | tag=None,
362 | attrs=None,
363 | )
364 | ```
365 |
366 | #### `click` action
367 |
368 | To be added.
369 |
370 | #### `load` action
371 |
372 | To be added.
373 |
374 | #### `textinput` action
375 |
376 | To be added.
377 |
378 | #### `submit` action
379 |
380 | To be added.
381 |
382 | ### Build `webllama.experimental.classes.Bbox`
383 |
384 | To be added.
385 |
386 | ### Build `webllama.experimental.classes.State`
387 |
388 | To be added.
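389 | 
390 | In the meantime, here is a minimal sketch based on how `State` is constructed elsewhere in this document (from a weblinx turn, or in the BrowserGym agent). The HTML and bounding-box values below are purely illustrative, and the bboxes are plain dictionaries keyed by the element's `data-webtasks-id`, as in the BrowserGym example:
391 | 
392 | ```python
393 | import webllama.experimental as wa
394 | 
395 | html = '<html><body><button data-webtasks-id="abc123">Submit</button></body></html>'
396 | bboxes = {
397 |     "abc123": {
398 |         "x": 10.0, "y": 20.0, "width": 80.0, "height": 24.0,
399 |         "top": 20.0, "left": 10.0, "bottom": 44.0, "right": 90.0,
400 |     }
401 | }
402 | 
403 | state = wa.classes.State(
404 |     index=3,  # index of the current turn, e.g. len(action_history)
405 |     html=html,
406 |     bboxes=bboxes,
407 |     viewport_height=720,
408 |     viewport_width=1280,
409 |     type="browser",
410 | )
411 | ```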
389 |
390 | ## Contributing
391 |
392 | For more information on contributing, please check out the [contributing docs](docs/CONTRIBUTING.md).
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # Examples
2 |
3 | ### Web API and client
4 |
5 | You can find examples of how to use the server directly with `http.client.HTTPConnection` (`run_http.py`) and through our client (`run_client.py`) in [`examples/web_api/`](/examples/web_api/). The server must stay up while you run either example. For more information, please read the Web API section in [`docs/README.md`](/docs/README.md).
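6 | 
7 | For example, once the server is running (see the BrowserGym section below for how to launch it remotely, or `docs/README.md` for the local command), you can run either script from the repository root. Note that these scripts load the test demonstrations, so download them first as described in `docs/README.md`:
8 | 
9 | ```bash
10 | python examples/web_api/run_http.py
11 | python examples/web_api/run_client.py
12 | ```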
6 |
7 | ### End-to-end
8 |
9 | You can find an end-to-end example of using `webllama.experimental` in [`examples/complete/run_all.py`](/examples/complete):
10 |
11 | ```bash
12 | python examples/complete/run_all.py
13 | ```
14 |
15 |
16 | ### BrowserGym integration
17 |
18 | We provide direct integration with BrowserGym, along with examples of how to use it. You can find an example at [`examples/browsergym/run_bg.py`](/examples/browsergym).
19 |
20 |
21 | On the remote server (which has a GPU and will host the webllama model), first install the modeling dependencies:
22 | ```bash
23 | # transformers, sentence-transformers, pytorch, etc.
24 | pip install -e .[modeling]
25 | ```
26 |
27 | Then, still on the remote server, launch the web server:
28 |
29 | ```bash
30 | # change if needed:
31 | export CUDA_VISIBLE_DEVICES=0
32 |
33 | python -m webllama.experimental.web.server --save_logs
34 | ```
35 |
36 | Then, connect to your remote server via SSH:
37 |
38 | ```bash
39 | # 8450 is the default port for our server
40 | ssh -N -L 8450:localhost:8450 user@server
41 | ```
42 |
43 | Now, on your local machine, run:
44 |
45 | ```bash
46 | pip install -e .
47 | # browsergym integration
48 | pip install "browsergym==0.3.*"
49 | # install playwright
50 | playwright install
51 | ```
52 |
53 | ```bash
54 | python examples/browsergym/run_bg.py
55 | ```
56 |
--------------------------------------------------------------------------------
/examples/browsergym/agent.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from copy import deepcopy
3 | from functools import partial
4 | import time
5 |
6 |
7 | from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str
8 | from browsergym.core.action.highlevel import HighLevelActionSet
9 | import weblinx as wl
10 |
11 | import webllama.experimental as wa
12 |
13 | from webllama.experimental.integrations.browsergym.functions import (
14 | say,
15 | click,
16 | textinput,
17 | load,
18 | scroll,
19 | wait,
20 | )
21 | from webllama.experimental.integrations.browsergym import replace_bid_with_wl_uid, reverse_dict, postprocess_for_browsergym
22 |
23 | def remap_bboxes(bboxes, attrs_map):
24 | """
25 | Cleans the bboxes dictionary by replacing the keys with the new unique ids.
26 | """
27 | return {attrs_map[k]: v for k, v in bboxes.items()}
28 |
29 | class AgentBase(ABC):
30 | """
31 | A template class that defines the required signature of an agent interacting with a browsergym environment.
32 | """
33 |
34 | @abstractmethod
35 | def reset(self, seed=None) -> None:
36 | """
37 | Resets the agent.
38 |
39 | """
40 | pass
41 |
42 | @abstractmethod
43 | def get_action(self, obs: dict) -> str:
44 | """
45 | Updates the agent with the current observation, and returns its next action (plus an info dict, optional).
46 |
47 | Parameters:
48 | -----------
49 | obs: dict
50 | The current observation of the environment.
51 | """
52 | pass
53 |
54 | def preprocess_obs(self, obs: dict) -> dict:
55 | """Default preprocessing of the observation."""
56 | pass
57 |
58 | def get_action_mapping(self) -> callable:
59 | """
60 | Returns a callable that can be used to map the agent actions to executable python code.
61 | """
62 | return None
63 |
64 |
65 | class WebLinxAgent(AgentBase):
66 | action_history = None
67 |
68 | def reset(self, seed=None) -> None:
69 | self.action_history = []
70 | self.messages = []
71 | self.start_time = time.time()
72 | self.has_user_message = False
73 |
74 | @property
75 | def num_messages(self):
76 | return len(self.messages)
77 |
78 | @staticmethod
79 | def get_bboxes(xprops):
80 | bboxes = {}
81 | for k in xprops:
82 | if xprops[k]["visibility"] == 1.0:
83 | bbox = dict(zip(["x", "y", "width", "height"], xprops[k]["bbox"]))
84 | # add top, left, bottom, right
85 | bbox["top"] = bbox["y"]
86 | bbox["left"] = bbox["x"]
87 | bbox["bottom"] = bbox["y"] + bbox["height"]
88 | bbox["right"] = bbox["x"] + bbox["width"]
89 | bboxes[k] = bbox
90 |
91 | return bboxes
92 |
93 | @staticmethod
94 | def infer_viewport_from_bboxes(bboxes):
95 | """
96 | DO NOT USE THIS, THIS FUNCTION IS NOT WORKING PROPERLY
97 | """
98 | if not bboxes:
99 | return 0, 0
100 |
101 | x = [bboxes[k]["right"] for k in bboxes]
102 | y = [bboxes[k]["bottom"] for k in bboxes]
103 |
104 | return max(x), max(y)
105 |
106 | def infer_from_screenshot(self, screenshot):
107 | h, w, _ = screenshot.shape
108 | return w, h
109 |
110 | @staticmethod
111 | def get_visible(xprops):
112 | return {k: xprops[k]["visibility"] == 1.0 for k in xprops}
113 |
114 | @staticmethod
115 | def rename_uid_attributes(dom_str, new_name="data-webtasks-id", old_name="bid"):
116 | return dom_str.replace(f"{old_name}=", f"{new_name}=")
117 |
118 | def get_action(self, obs: dict) -> str:
119 | # preprocessing
120 | obs["dom_str"] = flatten_dom_to_str(obs["dom_object"])
121 | obs["bboxes"] = self.get_bboxes(obs["extra_element_properties"])
122 | # obs["axtree_txt"] = flatten_axtree_to_str(obs["axtree_object"])
123 | # obs["visible"] = self.get_visible(obs["extra_element_properties"])
124 |
125 | vw, vh = self.infer_from_screenshot(obs["screenshot"])
126 | obs['html_str_orig'] = self.rename_uid_attributes(obs['dom_str'])
127 |
128 | obs["html_str"], attrs_map = replace_bid_with_wl_uid(obs["dom_str"], return_mapping=True)
129 | obs["remapped_bboxes"] = remap_bboxes(obs["bboxes"], attrs_map=attrs_map)
130 | reverse_attrs_map = reverse_dict(attrs_map)
131 |
132 | # check if we have new messages in the chat (+1 will skip first default message)
133 | new_messages = obs["chat_messages"][self.num_messages + 1 :]
134 | self.messages.extend(new_messages)
135 |
136 | # update action history with new messages
137 | for message in new_messages:
138 | role = "instructor" if message["role"] == "user" else "navigator"
139 | if role == "instructor":
140 | self.has_user_message = True
141 |
142 | self.action_history.append(
143 | wa.classes.Action(
144 | type="chat",
145 | index=len(self.action_history),
146 | intent="say",
147 | args={"utterance": message["message"], "speaker": role},
148 | timestamp=time.time() - self.start_time,
149 | tag=None,
150 | attrs=None,
151 | )
152 | )
153 | print(f"New message by '{role}': {message['message']}")
154 |
155 | if not self.has_user_message:
156 | # sleep and do nothing if no user message has been received
157 | return "wait(2)"
158 |
159 | state = wa.classes.State(
160 | index=len(self.action_history),
161 | html=obs["html_str"],
162 | bboxes=obs["remapped_bboxes"],
163 | viewport_height=vh,
164 | viewport_width=vw,
165 | type="browser",
166 | )
167 | pred_action = wa.web.client.get_prediction(
168 | self.action_history,
169 | state,
170 | address="localhost",
171 | port=8450,
172 | max_new_tokens=128,
173 | )
174 | # breakpoint()
175 | pred_action = postprocess_for_browsergym(pred_action, uid_map=reverse_attrs_map)
176 | # pred_action = postprocess_for_browsergym(pred_action)
177 |
178 | a = wa.classes.Action.from_dict(pred_action)
179 |
180 | # add action to action history
181 | self.action_history.append(a)
182 |
183 | action_str = a.to_str()
184 | print("Action String:", action_str)
185 |
186 | return action_str
187 |
188 | def get_action_mapping(self) -> callable:
189 | """
190 | Returns a callable that can be used to map the agent actions to executable python code.
191 | """
192 | action_set = HighLevelActionSet(
193 | subsets="custom",
194 | custom_actions=[say, click, textinput, load, scroll, wait],
195 | multiaction=False,
196 | strict=True,
197 | )
198 | return action_set.to_python_code
199 |
--------------------------------------------------------------------------------
/examples/browsergym/run_bg.py:
--------------------------------------------------------------------------------
1 | import gymnasium as gym
2 | import browsergym.core # register the openended task as a gym environment
3 | from examples.browsergym.agent import WebLinxAgent
4 |
5 | agent = WebLinxAgent()
6 |
7 | env = gym.make(
8 | "browsergym/openended",
9 | headless=False,
10 | wait_for_user_message=False,
11 | action_mapping=agent.get_action_mapping(),
12 | task_kwargs={"start_url": "chrome://newtab"},
13 | # task_kwargs={"start_url": "https://en.wikipedia.org"},
14 | )
15 |
16 | agent.reset()
17 | obs, info = env.reset()
18 |
19 | done = False
20 | while not done:
21 | action = agent.get_action(obs)
22 | obs, reward, terminated, truncated, info = env.step(action)
23 |
--------------------------------------------------------------------------------
/examples/complete/run_all.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import time
3 | import logging
4 |
5 | from sentence_transformers import SentenceTransformer
6 | from transformers import AutoTokenizer, pipeline
7 | import weblinx as wl
8 | import webllama.experimental as wa
9 |
10 | logging.getLogger("urllib3").setLevel(logging.WARNING)
11 |
12 | # Prerequisite: We need a `wa.classes.State` object and a list of `wa.classes.Action` objects
13 | # To get that, we will use an example from weblinx, but it's easy to do manually (see below).
14 |
15 | demos = wl.list_demonstrations("tests/demonstrations")
16 | replay = wl.Replay.from_demonstration(demos[0])
17 | turn = replay[26]
18 |
19 | format_intent_am = partial(
20 | wa.formatting.build_formatters_action_model(), return_as=dict
21 | )
22 | format_intent_input_dmr, format_intent_out_dmr = wa.formatting.build_formatters_dmr()
23 | action_history = wa.functions.create_action_history_from_replay(
24 | replay, format_intent_input_dmr, format_intent_am, stop_at_idx=turn.index
25 | )
26 |
27 | state = wa.classes.State(
28 | index=turn.index,
29 | html=turn.html,
30 | bboxes=turn.bboxes,
31 | viewport_height=turn.viewport_height,
32 | viewport_width=turn.viewport_width,
33 | type=turn.type,
34 | )
35 |
36 |
37 | # Verifying that the to_dict and from_dict methods work as expected
38 | act = action_history[0]
39 | d = act.to_dict()
40 | act2 = wa.classes.Action.from_dict(d)
41 | assert act == act2
42 |
43 | d = state.to_dict()
44 | state2 = wa.classes.State.from_dict(d)
45 | assert state == state2
46 |
47 |
48 | # Now, we can start!
49 | # First, load the DMR model we will use to select candidate elements
50 | dmr_name = "McGill-NLP/MiniLM-L6-dmr"
51 | action_model_name = "McGill-NLP/Sheared-LLaMA-2.7B-weblinx"
52 | tokenizer_chat_name = "McGill-NLP/Llama-2-7b-chat-weblinx"
53 |
54 | tokenizer = AutoTokenizer.from_pretrained(action_model_name)
55 | tokenizer_chat = AutoTokenizer.from_pretrained(tokenizer_chat_name)
56 | dmr = SentenceTransformer(dmr_name, device="cuda")
57 | action_model = pipeline(model=action_model_name, device=0, torch_dtype="auto")
58 |
59 | # We will initialize our processor, which helps us prepare the input for the action model
60 | proc = wa.processing.WebTurnProcessor(tokenizer=tokenizer, start_time=time.time())
61 |
62 | # Step 1: prepare query, run DMR and prepare retrieved candidates
63 | query_dmr = proc.prepare_dmr_query(action_history, state)
64 | elems = proc.prepare_dmr_elements(state=state)
65 | scores = wa.functions.compute_dmr_scores(dmr, query_dmr, elems)
66 | top_cands, cands_uids = wa.functions.get_top_dmr_candidates(elems, scores, proc.top_k)
67 | cands_str = proc.prepare_candidates(top_cands)
68 |
69 | # Step 2: format candidates, utterances, state, and previous actions
70 | html = proc.prepare_state_html(state.html, cands_uids=cands_uids)
71 | utterances = proc.prepare_instructor_chat(action_history, state)
72 | prev_actions = proc.prepare_prev_actions(action_history, state)
73 |
74 | # Let's use the default system prompt template, but you can also use your own
75 | sys_prompt_template: str = proc.default_system_prompt_template
76 | sys_prompt = sys_prompt_template.format(
77 | html=html,
78 | num_utterances=proc.num_utterances - 1,
79 | utterances=utterances,
80 | height=state.viewport_height,
81 | width=state.viewport_width,
82 | num_prev_actions=proc.num_prev_actions,
83 | candidates=cands_str,
84 | )
85 | input_chat = proc.convert_to_chat_list(sys_prompt, prev_actions)
86 |
87 | # We can now use the tokenizer's apply_chat_template method to convert it to a format
88 | # that can be used by the action model
89 | input_str = tokenizer_chat.apply_chat_template(input_chat, tokenize=False)
90 |
91 | # Let's now pass our input to the action model
92 | output = action_model(
93 | input_str,
94 | max_new_tokens=256,
95 | return_full_text=False,
96 | batch_size=1,
97 | pad_token_id=tokenizer.eos_token_id,
98 | )
99 | pred_action = proc.process_action_model_output(
100 | output=output, index=state.index, elems=elems
101 | )
102 | # optional: For certain platforms you may need to postprocess the action
103 | pred_action = wa.integrations.browsergym.postprocess_for_browsergym(pred_action)
104 | print(pred_action)
105 | # You can now convert this to an Action object and add it to the action history
106 | a = wa.classes.Action.from_dict(pred_action)
107 | action_history.append(a)
108 |
--------------------------------------------------------------------------------
/examples/web_api/run_client.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import webllama.experimental as wa
3 | import weblinx as wl
4 |
5 | demos = wl.list_demonstrations("tests/demonstrations")
6 | replay = wl.Replay.from_demonstration(demos[0])
7 | turn = replay[26]
8 |
9 | format_intent_am = partial(
10 | wa.formatting.build_formatters_action_model(), return_as=dict
11 | )
12 | format_intent_input_dmr, format_intent_out_dmr = wa.formatting.build_formatters_dmr()
13 | action_history = wa.functions.create_action_history_from_replay(
14 | replay, format_intent_input_dmr, format_intent_am, stop_at_idx=turn.index
15 | )
16 | state = wa.classes.State(
17 | index=turn.index,
18 | html=turn.html,
19 | bboxes=turn.bboxes,
20 | viewport_height=turn.viewport_height,
21 | viewport_width=turn.viewport_width,
22 | type=turn.type,
23 | )
24 |
25 | pred_action = wa.web.client.get_prediction(
26 | action_history, state, address="localhost", port=8450, max_new_tokens=128
27 | )
28 | pred_action = wa.integrations.browsergym.postprocess_for_browsergym(pred_action)
29 | print(pred_action)
30 | a = wa.classes.Action.from_dict(pred_action)
31 | print(a)
32 |
--------------------------------------------------------------------------------
/examples/web_api/run_http.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import http.client
3 | import json
4 |
5 | import weblinx as wl
6 | import webllama.experimental as wa
7 |
8 | def run_http():
9 | demos = wl.list_demonstrations("tests/demonstrations")
10 | replay = wl.Replay.from_demonstration(demos[0])
11 | turn = replay[26]
12 |
13 | format_intent_am = partial(
14 | wa.formatting.build_formatters_action_model(), return_as=dict
15 | )
16 | format_intent_input_dmr, format_intent_out_dmr = wa.formatting.build_formatters_dmr()
17 | action_history = wa.functions.create_action_history_from_replay(
18 | replay, format_intent_input_dmr, format_intent_am, stop_at_idx=turn.index
19 | )
20 | state = wa.classes.State(
21 | index=turn.index,
22 | html=turn.html,
23 | bboxes=turn.bboxes,
24 | viewport_height=turn.viewport_height,
25 | viewport_width=turn.viewport_width,
26 | type=turn.type,
27 | )
28 | action_history_dict = [action.to_dict() for action in action_history]
29 | state_dict = state.to_dict()
30 |
31 | # Create a connection to the localhost on the port where your server is running
32 | conn = http.client.HTTPConnection('localhost', 8450)
33 |
34 | # Send a request without parameters to test server response
35 | conn.request("POST", "/", body=json.dumps({}), headers={'Content-Type': 'application/json'})
36 | response = conn.getresponse()
37 | print("Test 1 - Server Initialization Check:")
38 | print(f"Status: {response.status}")
39 | print(f"Reason: {response.reason}")
40 | print(f"Body: {response.read().decode()}\n")
41 | response.close()
42 |
43 | # Prepare the POST request data
44 | post_data = json.dumps({
45 | 'action_history': action_history_dict,
46 | 'state': state_dict
47 | })
48 | headers = {'Content-Type': 'application/json'}
49 |
50 | # Send a POST request with JSON data
51 | conn.request("POST", "/", body=post_data, headers=headers)
52 | response = conn.getresponse()
53 | print("Test 2 - Functionality Check:")
54 | print(f"Status: {response.status}")
55 | print(f"Reason: {response.reason}")
56 | print(f"Body: {response.read().decode()}")
57 | response.close()
58 |
59 | # Close the connection
60 | conn.close()
61 |
62 | if __name__ == "__main__":
63 | run_http()
64 |
--------------------------------------------------------------------------------
/modeling/README.md:
--------------------------------------------------------------------------------
1 | ## Training
2 |
3 | First, you need to be in the `modeling` directory:
4 |
5 | ```bash
6 | cd modeling
7 | ```
8 |
9 | ### Download Data
10 |
11 | Download the full dataset (warning: this will take a while):
12 |
13 | ```python
14 | from huggingface_hub import snapshot_download
15 |
16 | snapshot_download(repo_id="McGill-NLP/WebLINX-full", repo_type="dataset", local_dir="./wl_data/")
17 | ```
18 |
19 | The default configs (`llama/conf/config.yaml`) assume that `train.jsonl` is located at `./wl_data/candidates/train.jsonl`. If you want to change the path, you need to modify `config.yaml` accordingly.
20 |
21 | #### Optional: Symbolic linking to `WebLINX-full`
22 |
23 | If you downloaded the `WebLINX-full` data in a different location (e.g. a different disk) from your `weblinx/modeling` directory, you might consider using a symbolic link to avoid having to change the `config.yaml` files. You can do something like:
24 |
25 | ```bash
26 | ln -s /location/of/your/full/data /location/of/project/weblinx/modeling/wl_data
27 | ```
28 |
29 | For example, if your data is located at `/mnt/research/scratch/users/jdoe/WebLINX-full` but your cloned `weblinx` repository is at `~/dev/weblinx`, then you'd run:
30 |
31 | ```bash
32 | ln -s /mnt/research/scratch/users/jdoe/WebLINX-full ~/dev/weblinx/modeling/wl_data
33 | ```
34 |
35 | This corresponds to the `data.base_dir` specified in `config.yaml`, which is `"${project_dir}/wl_data/demonstrations/"`.
36 |
37 | ### Set `WEBLLAMA_PROJECT_DIR`
38 |
39 | You need to set the `WEBLLAMA_PROJECT_DIR` environment variable to the path of the `modeling` directory (note that the DMR config in `dmr/conf/config.yaml` reads `WEBLINX_PROJECT_DIR`, so you may want to export both). For example:
40 |
41 | ```bash
42 | export WEBLLAMA_PROJECT_DIR=/path/to/the/modeling/directory/
43 |
44 | # For example, if you are in the modeling directory, you can run:
45 | export WEBLLAMA_PROJECT_DIR=$(pwd)
46 | ```
47 |
48 | ### Install Dependencies
49 |
50 | You need to install the dependencies by running the following commands (from the repository root):
51 |
52 | ```bash
53 | pip install -e .[extra]
54 | pip install -r modeling/requirements.txt
55 | ```
56 |
57 | However, because `flash-attention` requires `torch` to be pre-installed, it has to be installed after everything else:
58 | ```bash
59 | # Regular install
60 | pip install "flash-attn>=2.3.0"
61 | # If you have limited RAM, you can try this:
62 | MAX_JOBS=4 pip install "flash-attn>=2.3.0" --no-build-isolation
63 | # If you have issues with nvcc, try this:
64 | FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE pip install "flash-attn>=2.3.0" --no-build-isolation
65 | ```
66 |
67 | ### Action Model
68 |
69 | #### Train LLaMA
70 |
71 | You can train the model by running the following command (it will automatically use the hydra config from `conf/`):
72 |
73 | ```bash
74 | export CUDA_VISIBLE_DEVICES="0,1,2,3"
75 |
76 | # Train Llama-3-8B-Instruct on WebLINX
77 | accelerate launch --use_fsdp --config_file llama/accelerate/fsdp_4gpus.yaml -m llama.train
78 |
79 | # Fancy a different model? You can create your own variant (e.g. llama/conf/variant/8b_base.yaml)
80 | accelerate launch --use_fsdp --config_file llama/accelerate/fsdp_4gpus.yaml -m llama.train +variant="8b_base"
81 | ```
82 |
83 | Results will be saved in `./results` and checkpoints in `./checkpoints`.
84 |
85 | #### Run LLaMA on Evaluation Splits
86 |
87 | You need to specify which `eval.split` you want to evaluate on. For example, to evaluate on the `valid` split, you can run the following command:
88 |
89 | ```bash
90 | export CUDA_VISIBLE_DEVICES="0" # Set the GPU device you want to use
91 |
92 | # Evaluating llama-3-8b-instruct on a split
93 | python -m llama.eval -m eval.split=valid
94 |
95 | # Or other datasets (using multiple splits)
96 | python -m llama.eval -m eval.split=test_iid,test_web,test_geo,test_cat,test_vis
97 | ```
98 |
99 | #### Optional: running with screen
100 |
101 | You can run this (inside `modeling` dir):
102 | ```bash
103 | # Choose the variant you want to evaluate
104 | var="8b"
105 |
106 | # Launch the screen in detached mode
107 | iid="CUDA_VISIBLE_DEVICES=0 ../venv/bin/python -m llama.eval -m +variant="$var" eval.split=test_iid"
108 | screen -dmS eval-llama-$var-iid bash -c "$iid; exec bash"
109 | # ...
110 | vis="CUDA_VISIBLE_DEVICES=4 ../venv/bin/python -m llama.eval -m +variant="$var" eval.split=test_vis"
111 | screen -dmS eval-llama-$var-vis bash -c "$vis; exec bash"
112 | ```
113 |
114 | ### Evaluation
115 |
116 | To run the evaluation metrics, you can use the following command (from `modeling/`):
117 |
118 | ```bash
119 | python -m weblinx.eval -d ./results -b ./wl_data/demonstrations
120 | ```
121 |
122 | In this case, `-b` is the base directory for the demonstrations, and `-d` is the directory containing the results (generated above by the `llama.eval` script). This will automatically run the evaluation metrics and save the results to `results/aggregated_scores.json`. If you are only interested in the overall score for a split (e.g. `valid`), you can look for the following entry in the aggregated score file (as an example):
123 |
124 | ```json
125 | // ...
126 | {
127 | "split": "valid",
128 | "intent": "overall",
129 | "metric": "overall",
130 | "model_name": "meta-llama/Meta-Llama-3-8B-Instruct",
131 | "project_name": "llama_ft",
132 | "score": 0.21667765869744438,
133 | "unconditional_score": 0.15307513104251605
134 | },
135 | // ...
136 | ```
137 |
138 | Behind the scenes, this uses the `weblinx.eval.auto_eval_and_save` function to run the evaluation metrics. If you want more control, you can also call that function directly; for an example, check out `weblinx/eval/__main__.py`.
139 |
140 | Note that it might be slow the first time you run it, because it reads a lot of demonstrations and loads millions of files. However, a demo-level cache is automatically created (see `./.cache/demonstrations`), so subsequent runs should be much faster.
141 |
142 | ### Dense Markup Ranking (DMR)
143 |
144 | #### Train DMR
145 |
146 | You can train the model by running the following command (it will automatically use the hydra config from `conf/`):
147 |
148 | ```bash
149 | export CUDA_VISIBLE_DEVICES="0" # Set the GPU device you want to use
150 |
151 | # Finetune MiniLM-L6-DMR (Default)
152 | python -m dmr.train
153 | ```
154 |
155 | Results will be saved in `./results` and checkpoints in `./checkpoints`.
156 |
157 | #### Inference for DMR
158 |
159 | You need to specify which `eval.split` you want to evaluate on. For example, to evaluate on the `valid` split, you can run the following command:
160 |
161 | ```bash
162 | export CUDA_VISIBLE_DEVICES="0" # Set the GPU device you want to use
163 |
164 | # On just one
165 | python -m dmr.eval eval.split=valid
166 |
167 | # On multiple splits (e.g. test_iid, test_vis)
168 | python -m dmr.eval eval.split=test_iid,test_web,test_geo,test_cat,test_vis
169 | ```
170 |
171 | #### Moving generated DMR results to `wl_data/candidates`
172 |
173 | The `scores.jsonl` and `results.json` files will be saved at the path given by `cfg.eval.result_dir` in `modeling/dmr/conf/config.yaml`, which defaults to `${project_dir}/results/${project_name}/${model.name}/${eval.split}`. By default, this resolves to `/path/to/weblinx/modeling/results/dmr/sentence-transformers/all-MiniLM-L6-v2/train` for the `train` split, `.../valid` for the `valid` split, etc. However, since the next steps assume the candidates live in files like `wl_data/candidates/<split>.jsonl`, you need to move them manually. For example, you could run:
174 |
175 | ```bash
176 | # Change the following paths to match your setup
177 | orig_dir="/path/to/weblinx/modeling/results/dmr/sentence-transformers/all-MiniLM-L6-v2"
178 | # This is the directory where the candidates are stored
179 | new_dir="/path/to/wl_data/candidates"
180 |
181 | # You need to move the train split if you plan to use it for training the action model
182 | mv $orig_dir/train/scores.jsonl $new_dir/train.jsonl
183 | # You can move valid and test IID splits as well
184 | mv $orig_dir/valid/scores.jsonl $new_dir/valid.jsonl
185 | mv $orig_dir/test_iid/scores.jsonl $new_dir/test_iid.jsonl
186 | mv $orig_dir/test_web/scores.jsonl $new_dir/test_web.jsonl
187 | mv $orig_dir/test_geo/scores.jsonl $new_dir/test_geo.jsonl
188 | mv $orig_dir/test_cat/scores.jsonl $new_dir/test_cat.jsonl
189 | mv $orig_dir/test_vis/scores.jsonl $new_dir/test_vis.jsonl
190 | ```
191 |
192 | Alternatively, you can also update `config.yaml` to point the candidates to the correct directory, by overriding `candidates`:
193 | ```yaml
194 | # ...
195 | candidates:
196 | # ...
197 | model: "sentence-transformers/all-MiniLM-L6-v2"
198 | path: ${project_dir}/results/${project_name}/${model.name}/${eval.split}
199 | ```
200 |
201 |
--------------------------------------------------------------------------------
/modeling/dmr/conf/config.yaml:
--------------------------------------------------------------------------------
1 | project_dir: ${oc.env:WEBLINX_PROJECT_DIR}
2 | seed: 123
3 | project_name: dmr
4 |
5 | data:
6 | split_path: ${project_dir}/wl_data/splits.json
7 | base_dir: ${project_dir}/wl_data/demonstrations
8 |
9 | model:
10 | name: sentence-transformers/all-MiniLM-L6-v2
11 | max_seq_length: 512
12 | use_bf16: True
13 | similarity: cos_sim
14 | save_dir: ${project_dir}/checkpoints/${project_name}/${model.name}
15 |
16 | train:
17 | split: train
18 | num_epochs: 10
19 | max_neg_per_turn: 9
20 | batch_size_per_device: 64
21 | dataloader_num_workers: 8
22 | optim: adamw
23 | gradient_checkpointing: True
24 | learning_rate: 0.00003
25 | warmup_steps: 500
26 | # Available schedulers:
27 | # constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts
28 | scheduler: warmuplinear
29 |
30 | eval:
31 | split: dev
32 | mrr_k: 50
33 | batch_size_per_device: 64
34 | result_dir: ${project_dir}/results/${project_name}/${model.name}/${eval.split}
35 |
36 | hydra:
37 | run:
38 | dir: ${project_dir}/logs/${project_name}/${hydra.job.name}/${now:%Y-%m-%d-%H:%M:%S}
39 | # Use the same for sweep's subdir
40 | sweep:
41 | dir: ${hydra.run.dir}
42 | job:
43 | chdir: False
44 | verbose: INFO
--------------------------------------------------------------------------------
/modeling/dmr/eval.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | from pathlib import Path
4 | from typing import List, Dict, Any
5 |
6 | import hydra
7 | import numpy as np
8 | import torch
9 | from tqdm import tqdm
10 | from sentence_transformers import SentenceTransformer
11 | from sentence_transformers.util import cos_sim, dot_score
12 | import weblinx as wl
13 | from weblinx.processing import group_record_to_dict
14 | from weblinx.utils.recs import ungroup_dict_to_records
15 | from weblinx.utils.hydra import save_path_to_hydra_logs
16 |
17 | from .processing import build_records_for_single_demo, build_formatters
18 |
19 |
20 | def recall_at_k(input_records, k, label_key="label", rank_key="rank"):
21 | num_correct = 0
22 | num_total = 0
23 |
24 | for r in input_records:
25 | if r[label_key] == 1:
26 | num_total += 1
27 | if r[rank_key] <= k:
28 | num_correct += 1
29 |
30 | score = num_correct / num_total
31 | return score
32 |
33 |
34 | def mean_reciprocal_rank(input_records, label_key="label", rank_key="rank", k=None):
35 | if k is None or len(input_records) < k or k < 1:
36 | k = len(input_records)
37 |
38 | mrr = 0
39 | num_total = 0
40 |
41 | for r in input_records:
42 | if r[label_key] == 1:
43 | if r[rank_key] <= k:
44 | mrr += 1 / r[rank_key]
45 | num_total += 1
46 |
47 | mrr /= num_total
48 |
49 | return mrr
50 |
51 |
52 | def verify_queries_are_all_the_same(grouped_records: dict) -> bool:
53 | """
54 | Given a dictionary of grouped records, this function verifies that all
55 | queries are the same within each group.
56 | """
57 | for k, v in grouped_records.items():
58 | first_query = v[0]["query"]
59 | if not all(r["query"] == first_query for r in v):
60 | return False
61 | return True
62 |
63 |
64 | def run_model_and_update_groups(
65 | model, input_grouped: Dict[Any, List[dict]], batch_size, sim_method="cos_sim"
66 | ):
67 | if sim_method == "cos_sim":
68 | sim_func = cos_sim
69 | elif sim_method == "dot_product":
70 | sim_func = dot_score
71 | else:
72 | raise ValueError(f"Unknown similarity function: {sim_method}")
73 |
74 | for k, group in tqdm(input_grouped.items(), desc="Computing scores"):
75 | group = input_grouped[k]
76 | query = group[0]["query"]
77 | docs = [r["doc"] for r in group]
78 |
79 | encoded = model.encode(
80 | [query] + docs, batch_size=batch_size, show_progress_bar=False
81 | )
82 | query_vector, doc_vectors = encoded[0], encoded[1:]
83 | scores = sim_func(query_vector, doc_vectors).cpu().squeeze().tolist()
84 | if isinstance(scores, float):
85 | scores = [scores]
86 |
87 | for i, r in enumerate(group):
88 | r["score"] = scores[i]
89 |
90 |
91 | def build_target_uids_dict(demos, uid_key="data-webtasks-id"):
92 | """
93 | Given a list of demonstrations, build a dictionary mapping
94 | `(demo_name, turn_index) -> uid`. This is used to determine the
95 | target element for a given demo turn, which labels the element
96 | as positive or negative.
97 | """
98 | target_uids_dict = {}
99 | for demo in tqdm(demos, desc="Creating dict of target uids"):
100 | for turn in wl.Replay.from_demonstration(demo):
101 | if turn.element is None or "attributes" not in turn.element:
102 | continue
103 | if uid_key not in turn.element["attributes"]:
104 | continue
105 |
106 | uid = turn.element["attributes"][uid_key]
107 | target_uids_dict[(demo.name, turn.index)] = uid
108 |
109 | return target_uids_dict
110 |
111 |
112 | def get_ranks_from_scores(scores: Dict[Any, float], starts_at=1) -> Dict[Any, int]:
113 | """
114 | Given a dictionary of key -> scores, return a dictionary of key -> ranks.
115 | """
116 | # Get sorted keys
117 | keys = sorted(scores.keys(), key=lambda k: scores[k], reverse=True)
118 | ranks = {k: i + starts_at for i, k in enumerate(keys)}
119 |
120 | return ranks
121 |
122 |
123 | @hydra.main(version_base=None, config_path="conf", config_name="config")
124 | def main(cfg):
125 | torch.manual_seed(cfg.seed)
126 |
127 | use_bf16 = cfg.model.use_bf16
128 | split = cfg.eval.split
129 | bsize = cfg.eval.batch_size_per_device
130 |
131 | split_path = Path(cfg.data.split_path).expanduser()
132 | model_save_dir = Path(cfg.model.save_dir).expanduser()
133 | result_dir = Path(cfg.eval.result_dir).expanduser()
134 |
135 | result_dir.mkdir(parents=True, exist_ok=True)
136 |
137 | if use_bf16:
138 | torch_dtype = torch.bfloat16
139 | use_amp = False
140 | else:
141 | torch_dtype = torch.float32
142 | use_amp = True
143 |
144 | # Data loading
145 | demo_names = wl.utils.load_demo_names_in_split(split_path, split=split)
146 | demos = [wl.Demonstration(demo_name, base_dir=cfg.data.base_dir) for demo_name in demo_names]
147 |
148 | format_intent_input, _ = build_formatters()
149 | input_records: List[dict] = []
150 | logging.info(f"Number of demos: {len(demos)}. Starting building records.")
151 | for demo in tqdm(demos, desc="Building input records"):
152 | demo_records = build_records_for_single_demo(
153 | demo=demo,
154 | format_intent_input=format_intent_input,
155 | max_neg_per_turn=None,
156 | # For eval, we want to include all elements in the demo
157 | # not just the ones with valid uids
158 | only_allow_valid_uid=False,
159 | )
160 | input_records.extend(demo_records)
161 | logging.info(f"Completed. Number of input records: {len(input_records)}")
162 |
163 | # Group records by (demo_name, turn_index) pairs
164 | input_grouped = group_record_to_dict(
165 | input_records, keys=["demo_name", "turn_index"], remove_keys=False
166 | )
167 |
168 | # Verify that queries are all the same within each group
169 | error_msg = "Queries are not all the same within each group"
170 | assert verify_queries_are_all_the_same(input_grouped), error_msg
171 |
172 | # Run the model and update the scores and ranks in place
173 | logging.info("Running model and computing scores")
174 |
175 | # Run the model
176 | model = SentenceTransformer(str(model_save_dir))
177 | sim_method = cfg.model.get("similarity", "cos_sim")
178 |
179 | logging.info(f"Using the following similarity method: {sim_method}")
180 |
181 | with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch_dtype):
182 | run_model_and_update_groups(
183 | model, input_grouped=input_grouped, batch_size=bsize, sim_method=sim_method
184 | )
185 | logging.info("Completed")
186 |
187 | for group in input_grouped.values():
188 | scores = {r["uid"]: r["score"] for r in group}
189 | ranks = get_ranks_from_scores(scores)
190 | for r in group:
191 | r["rank"] = ranks[r["uid"]]
192 |
193 | # Revert back to original records
194 | input_records = ungroup_dict_to_records(input_grouped)
195 |
196 | # Metrics
197 | lengths = np.array([len(v) for v in input_grouped.values()])
198 | results = {
199 | "split": split,
200 | "num_turns": len(input_grouped),
201 | "num_demos": len(demos),
202 | "avg_elements_per_turn": lengths.mean(),
203 | "std_elements_per_turn": lengths.std(),
204 | "mrr": mean_reciprocal_rank(input_records, k=cfg.eval.mrr_k),
205 | }
206 |
207 | for k in [1, 5, 10, 20, 50, 100, 200]:
208 | results[f"recall@{k}"] = recall_at_k(input_records, k=k)
209 |
210 | for k, v in results.items():
211 | print(f"{k}: {v}")
212 |
213 | # Save results
214 | with open(result_dir.joinpath("results.json"), "w") as f:
215 | json.dump(results, f, indent=2)
216 |
217 | # Save records and scores
218 | with open(result_dir.joinpath("scores.jsonl"), "w") as f:
219 | for r in input_records:
220 | f.write(json.dumps(r) + "\n")
221 |
222 | save_path_to_hydra_logs(save_dir=model_save_dir)
223 |
224 |
225 | if __name__ == "__main__":
226 | main()
227 |
--------------------------------------------------------------------------------
/modeling/dmr/processing.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from copy import deepcopy
3 | import random
4 | from typing import Any, Dict, List
5 | from functools import partial
6 |
7 | import lxml.html
8 | import weblinx as wl
9 | import weblinx.utils.html as wh
10 | import weblinx.utils.format as wlf
11 | from weblinx.processing.prompt import (
12 | format_prev_turns,
13 | find_turns_with_instructor_chat,
14 | format_utterances,
15 | )
16 |
17 |
18 | def format_turn_for_input(
19 | replay,
20 | turn,
21 | format_intent,
22 | turn_sep=" ; ",
23 | num_prev_turns=5,
24 | num_utterances=5,
25 | return_str=True,
26 | ):
27 | """
28 | This function formats a turn for input to the model. It does so by combining the following:
29 | 1. The first and last `num_utterances-1` utterances from the instructor
30 | 2. The previous turns (up to `num_prev_turns` turns)
31 |
32 | If return_str is True, then the output is a string. Otherwise, it returns two strings: the utterance context and the previous turns.
33 | """
34 | prev_turns_text = format_prev_turns(
35 | replay=replay,
36 | turn=turn,
37 | format_intent=format_intent,
38 | turn_sep=turn_sep,
39 | num_prev_turns=num_prev_turns,
40 | )
41 | instructor_chat_turns = find_turns_with_instructor_chat(
42 | replay, turn, num_prev_turns=num_prev_turns
43 | )
44 | utterance_context = format_utterances(
45 | instructor_chat_turns, num_utterances=num_utterances
46 | )
47 |
48 | if not return_str:
49 | return utterance_context, prev_turns_text
50 |
51 | # Now, let's combine the text from the previous turns with the utterance context
52 | # and the current turn's utterance
53 | text = (
54 | f"Viewport(height={turn.viewport_height}, width={turn.viewport_width}) ---- "
55 | f"Instructor Utterances: {utterance_context} ---- "
56 | f"Previous Turns:{prev_turns_text}"
57 | )
58 |
59 | return text
60 |
61 |
62 | def build_formatters():
63 | format_element_input = partial(
64 | wlf.format_element,
65 | include_text=False,
66 | include_attrs=("class", "title", "href", "aria-label", "d", "src"),
67 | )
68 | format_click_input = partial(
69 | wlf.format_click,
70 | formatters=(wlf.format_mouse_xy, format_element_input, wlf.format_timestamp),
71 | )
72 | format_change_input = partial(
73 | wlf.format_change,
74 | formatters=(
75 | partial(wlf.format_arg_item, name="value"),
76 | format_element_input,
77 | wlf.format_timestamp,
78 | ),
79 | )
80 | format_hover_input = partial(
81 | wlf.format_hover,
82 | formatters=(wlf.format_mouse_xy, format_element_input, wlf.format_timestamp),
83 | )
84 |
85 | format_submit_input = partial(
86 | wlf.format_submit, formatters=(format_element_input, wlf.format_timestamp)
87 | )
88 |
89 | format_text_input_input = partial(
90 | wlf.format_text_input,
91 | formatters=(
92 | partial(wlf.format_arg_item, name="text"),
93 | partial(format_element_input),
94 | wlf.format_timestamp,
95 | ),
96 | )
97 |
98 | format_intent_input = partial(
99 | wlf.format_intent_automatically,
100 | format_click=format_click_input,
101 | format_change=format_change_input,
102 | format_hover=format_hover_input,
103 | format_submit=format_submit_input,
104 | format_text_input=format_text_input_input,
105 | return_as=str,
106 | )
107 |
108 | # second, for the output (prediction text)
109 | format_element_out = partial(
110 | wlf.format_element,
111 | # Only want the tag
112 | include_text=False,
113 | include_attrs=False,
114 | )
115 |
116 | format_click_out = partial(wlf.format_click, formatters=(wlf.format_mouse_xy,))
117 | format_text_input_out = partial(
118 | wlf.format_text_input,
119 | formatters=(
120 | partial(wlf.format_arg_item, name="text", max_length=200),
121 | format_element_out,
122 | wlf.format_target_bbox,
123 | ),
124 | )
125 | format_change_out = partial(
126 | wlf.format_change,
127 | formatters=(
128 | partial(wlf.format_arg_item, name="value", max_length=200),
129 | format_element_out,
130 | wlf.format_target_bbox,
131 | ),
132 | )
133 | format_submit_out = partial(
134 | wlf.format_submit, formatters=(format_element_out, wlf.format_target_bbox)
135 | )
136 | format_load_out = partial(
137 | wlf.format_load,
138 | include_transition=False,
139 | include_timestamp=False,
140 | max_length=200,
141 | )
142 | format_scroll_out = partial(wlf.format_scroll, include_timestamp=False)
143 |
144 | format_say_out = partial(wlf.format_say, include_timestamp=False)
145 |
146 | format_intent_out = partial(
147 | wlf.format_intent_automatically,
148 | format_change=format_change_out,
149 | format_click=format_click_out,
150 | format_load=format_load_out,
151 | format_say=format_say_out,
152 | format_scroll=format_scroll_out,
153 | format_submit=format_submit_out,
154 | format_text_input=format_text_input_out,
155 | )
156 |
157 | return format_intent_input, format_intent_out
158 |
159 |
160 | def turn_has_valid_uid(turn, paths, uid_key="data-webtasks-id"):
161 | """
162 |     Given a turn and the list of lxml elements in `paths`, return True if the turn's uid is among them.
163 | """
164 | uids = [p.attrib[uid_key] for p in paths]
165 | if turn.element is None or uid_key not in turn.element["attributes"]:
166 | return False
167 |
168 | if turn.element["attributes"][uid_key] not in uids:
169 | return False
170 |
171 | return True
172 |
173 |
174 | def format_attrs(attrs):
175 | return " ".join([f"{k!s}={v!r}" for k, v in attrs.items()])
176 |
177 |
178 | def shorten(s, max_length=100, side="center", ellipsis="..."):
179 | if max_length is None:
180 | return s
181 |
182 | if len(s) <= max_length:
183 | return s
184 |
185 | max_length = max_length - len(ellipsis)
186 |
187 | if side == "right":
188 | s = s[:max_length] + ellipsis
189 | elif side == "left":
190 | s = ellipsis + s[-max_length:]
191 | elif side == "center":
192 | s = s[: max_length // 2] + ellipsis + s[-max_length // 2 :]
193 | else:
194 | raise ValueError(f"Invalid side: {side}")
195 |
196 | return s
197 |
198 |
199 | def format_children(parent, depth=1):
200 | """
201 | Use the concise parentheses notation to format the children of an element.
202 | For example, for depth 1, we only have: (child1 child2 child3)
203 | For depth 2, we have: (child1 (grandchild1 grandchild2) child2 child3)
204 | """
205 | children = parent.getchildren()
206 | if len(children) == 0:
207 | return ""
208 |
209 | if depth == 1:
210 | return " ".join([c.tag for c in children])
211 |
212 | out_str = ""
213 | for c in children:
214 | out_str += f"{c.tag}"
215 | children_str = format_children(c, depth=depth - 1)
216 | if children_str != "":
217 | out_str += f" ( {children_str} )"
218 | out_str += " "
219 |
220 | return out_str.strip()
221 |
222 |
223 | def represent_element_as_dict(
224 | element,
225 | bbox,
226 | root_tree,
227 | max_text_length=200,
228 | max_attr_length=100,
229 | max_child_depth=2,
230 | ):
231 | """
232 | Format an lxml element into a dictionary of strings. The keys are:
233 | - tag: the tag name of the element
234 | - xpath: the xpath of the element
235 | - text: the text of the element, truncated to `max_text_length`
236 | - bbox: the bounding box of the element
237 | - attributes: the attributes of the element, truncated to `max_attr_length`
238 | - children: the children of the element, truncated to `max_attr_length`
239 | """
240 | # Get the tag name
241 | tag = element.tag
242 | xpath = root_tree.getpath(element)
243 | children = element.getchildren()
244 | text = element.text if element.text is not None else ""
245 |
246 | # Shorten the text and attributes
247 | text = shorten(text, max_text_length)
248 | attrs = {k: shorten(v, max_attr_length) for k, v in element.attrib.items()}
249 |
250 | # Sort the attributes by length
251 | attrs = dict(sorted(attrs.items(), key=lambda x: len(x[1])))
252 |
253 | # Truncate the children
254 | children = children[:max_child_depth]
255 |
256 | # Format the children
257 | children_str = " ".join([c.tag for c in children if isinstance(c.tag, str)])
258 | children_str = shorten(children_str, max_attr_length)
259 |
260 | # Format the attributes
261 | attrs_str = format_attrs(attrs)
262 |
263 | # Format the bounding box
264 | bbox_str = " ".join(
265 | [f"{k}={round(bbox[k], 1)}" for k in ["x", "y", "width", "height"]]
266 | )
267 |
268 | # format as a dict
269 | element_dict = {
270 | "tag": tag,
271 | "xpath": xpath,
272 | "text": text,
273 | "bbox": bbox_str,
274 | "attributes": attrs_str,
275 | "children": children_str,
276 | }
277 |
278 | return element_dict
279 |
280 |
281 | def convert_elem_dict_to_str_legacy(elem_dict: dict):
282 | """
283 | Convert an element dictionary to a string.
284 | """
285 | elem_dict = deepcopy(elem_dict)
286 |
287 | element_str = f"[[tag]] {elem_dict.pop('tag')}\n"
288 | element_str += f"[[xpath]] {elem_dict.pop('xpath')}\n"
289 | element_str += f"[[text]] {elem_dict.pop('text')}\n"
290 | element_str += f"[[bbox]] {elem_dict.pop('bbox')}\n"
291 | element_str += f"[[attributes]] {elem_dict.pop('attributes')}\n"
292 | element_str += f"[[children]] {elem_dict.pop('children')}"
293 |
294 | # for other keys, we just add them to the end
295 |
296 | for k, v in elem_dict.items():
297 | element_str += f"\n[[{k}]] {v}"
298 |
299 | return element_str
300 |
301 |
302 | def build_records_for_single_turn(
303 | turn, replay, format_intent_input, uid_key, max_neg=None, only_allow_valid_uid=True
304 | ) -> List[dict]:
305 | """
306 |     This function builds a list of dictionaries for a single turn, with one record
307 |     per candidate element in that turn. Each record has the following keys:
308 | - query: the dialogue history, used as a query for training the model
309 | - doc: concise representation of HTML element used as doc for training
310 | - label: either 0 or 1, indicating whether the document is the target element
311 | - uid: the unique identifier for an element, must be in the element attributes
312 | - turn_index: the index of the turn in the replay
313 | - demo_name: the name of the demonstration
314 |
315 | If `only_allow_valid_uid` is True, then only turns that have a valid uid
316 | will be included in the output. Otherwise, all turns will be included.
317 | """
318 | bboxes_filt = wh.filter_bboxes(
319 | turn.bboxes,
320 | viewport_height=turn.viewport_height,
321 | viewport_width=turn.viewport_width,
322 | )
323 | root = lxml.html.fromstring(turn.html)
324 | root_tree = root.getroottree()
325 | elements = root.xpath(f"//*[@{uid_key}]")
326 | elements_filt = [p for p in elements if p.attrib[uid_key] in bboxes_filt]
327 |
328 | has_valid_uid = turn_has_valid_uid(turn, paths=elements, uid_key=uid_key)
329 | if only_allow_valid_uid and not has_valid_uid:
330 | return []
331 |
332 | # Now, we can format each of the elements in paths_filt into string
333 | # and use them as negative samples
334 | query = format_turn_for_input(replay, turn, format_intent=format_intent_input)
335 | target_uid = turn.element["attributes"][uid_key] if has_valid_uid else -1
336 |
337 | records_positive = []
338 | records_negative = []
339 |
340 | for elem in elements_filt:
341 | bbox = turn.bboxes[elem.attrib[uid_key]]
342 | elem_dict = represent_element_as_dict(elem, bbox, root_tree)
343 | elem_str = convert_elem_dict_to_str_legacy(elem_dict)
344 |
345 | record = {
346 | "query": query,
347 | "doc": elem_str,
348 | "uid": elem.attrib[uid_key],
349 | "demo_name": turn.demo_name,
350 | "turn_index": turn.index,
351 | "elem_dict": elem_dict,
352 | }
353 |
354 | if elem.attrib[uid_key] == target_uid:
355 | record["label"] = 1
356 | records_positive.append(record)
357 | else:
358 | record["label"] = 0
359 | records_negative.append(record)
360 |
361 | if max_neg is not None and 0 < max_neg < len(records_negative):
362 | records_negative = random.sample(records_negative, max_neg)
363 |
364 | return records_positive + records_negative
365 |
366 |
367 | def build_records_for_single_demo(
368 | demo,
369 | format_intent_input,
370 | max_neg_per_turn=None,
371 | random_state=None,
372 | uid_key="data-webtasks-id",
373 | only_allow_valid_uid=True,
374 | group_by_turn=False,
375 | ) -> List[dict]:
376 | """
377 | This runs `build_records_for_single_turn` for each turn in the demonstration.
378 | First, the demonstration is converted into a replay, and then we filter the
379 | turns to only those that have HTML and bounding boxes, and that are of the
380 | following intents:
381 | - click
382 | - change
383 | - textInput
384 | - scroll
385 | - load
386 | - submit
387 |
388 |     If `only_allow_valid_uid` is True, any turn that does not have a valid uid is discarded.
389 | """
390 | if random_state is not None:
391 | random.seed(random_state)
392 |
393 | replay = wl.Replay.from_demonstration(demo)
394 | turns = replay.filter_by_intents(
395 | "click", "change", "textInput", "scroll", "load", "submit"
396 | )
397 | turns = wl.filter_turns(turns, lambda t: t.has_html() and t.has_bboxes())
398 |
399 | records_for_demo = []
400 | for turn in turns:
401 | recs = build_records_for_single_turn(
402 | turn=turn,
403 | replay=replay,
404 | format_intent_input=format_intent_input,
405 | uid_key=uid_key,
406 | max_neg=max_neg_per_turn,
407 | only_allow_valid_uid=only_allow_valid_uid,
408 | )
409 | if group_by_turn:
410 | records_for_demo.append(recs)
411 | else:
412 | records_for_demo.extend(recs)
413 |
414 | return records_for_demo
415 |
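# Editor's note -- a minimal usage sketch, mirroring how train.py calls this function.
# The demo name and base_dir below are made up; `build_formatters` is defined earlier
# in this module and returns the input formatter first.
#
#     import weblinx as wl
#     format_intent_input, _ = build_formatters()
#     demo = wl.Demonstration("example_demo_name", base_dir="./wl_data/demonstrations")
#     records = build_records_for_single_demo(
#         demo, format_intent_input, max_neg_per_turn=8, random_state=123
#     )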
--------------------------------------------------------------------------------
/modeling/dmr/train.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from pathlib import Path
3 |
4 | import hydra
5 | from tqdm import tqdm
6 | import torch
7 | from torch.utils.data import DataLoader
8 | from sentence_transformers import SentenceTransformer, InputExample
9 | import sentence_transformers.models as st_models
10 | from sentence_transformers.losses import CosineSimilarityLoss
11 | import transformers
12 | from weblinx.utils.hydra import save_path_to_hydra_logs
13 | import weblinx as wl
14 |
15 | from .processing import build_records_for_single_demo, build_formatters
16 |
17 |
18 | def infer_optimizer(name):
19 | name = name.lower()
20 |
21 | if name == "adamw":
22 | return torch.optim.AdamW
23 | elif name == "adam":
24 | return torch.optim.Adam
25 | elif name == "adafactor":
26 | return transformers.Adafactor
27 | elif name == "sgd":
28 | return torch.optim.SGD
29 | else:
30 | raise ValueError(f"Unknown optimizer name: {name}")
31 |
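# Editor's note -- illustrative behaviour of the helper above: the lookup is
# case-insensitive and returns the optimizer class itself, not an instance.
#
#     >>> infer_optimizer("AdamW") is torch.optim.AdamW
#     True
#     >>> infer_optimizer("rmsprop")  # unsupported names raise ValueError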
32 |
33 | @hydra.main(version_base=None, config_path="conf", config_name="config")
34 | def main(cfg):
35 | torch.manual_seed(cfg.seed)
36 |
37 | model_name = cfg.model.name
38 | use_bf16 = cfg.model.use_bf16
39 | max_seq_length = cfg.model.max_seq_length
40 | optim = cfg.train.optim
41 | split = cfg.train.split
42 | learning_rate = cfg.train.learning_rate
43 | warmup_steps = cfg.train.warmup_steps
44 | batch_size = cfg.train.batch_size_per_device
45 | num_epochs = cfg.train.num_epochs
46 | scheduler = cfg.train.scheduler
47 |
48 |     split_path = Path(cfg.data.split_path).expanduser()
49 | model_save_dir = Path(cfg.model.save_dir).expanduser()
50 |
51 | if use_bf16:
52 | torch_dtype = torch.bfloat16
53 | use_amp = False
54 | else:
55 | torch_dtype = torch.float32
56 | use_amp = True
57 |
58 | # Data loading
59 | demo_names = wl.utils.load_demo_names_in_split(split_path, split=split)
60 | demos = [wl.Demonstration(demo_name, base_dir=cfg.data.base_dir) for demo_name in demo_names]
61 |
62 | if cfg.project_name.endswith("testing"):
63 | demos = demos[:10]
64 |
65 | format_intent_input, _ = build_formatters()
66 | input_records = []
67 | logging.info(f"Number of demos: {len(demos)}. Starting building records.")
68 | for demo in tqdm(demos, desc="Building input records"):
69 | input_records.extend(
70 | build_records_for_single_demo(
71 | demo=demo,
72 | format_intent_input=format_intent_input,
73 | max_neg_per_turn=cfg.train.max_neg_per_turn,
74 | random_state=cfg.seed,
75 |                 # For training, we only keep turns with a valid target uid; otherwise,
76 |                 # every record built for such a turn would be a negative example
77 | only_allow_valid_uid=True,
78 | )
79 | )
80 |
81 | logging.info(f"Number of input records: {len(input_records)}")
82 |
83 | train_examples = [
84 | InputExample(texts=[r["query"], r["doc"]], label=float(r["label"]))
85 | for r in tqdm(
86 | input_records, desc="Converting records to sentence-transformers input"
87 | )
88 | ]
89 |
90 | train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)
91 |
92 | # Model loading
93 | word_embedding_model = st_models.Transformer(
94 | model_name, max_seq_length=max_seq_length
95 | )
96 | if cfg.train.gradient_checkpointing and hasattr(
97 | word_embedding_model.auto_model, "gradient_checkpointing_enable"
98 | ):
99 | word_embedding_model.auto_model.gradient_checkpointing_enable()
100 |
101 | pooling_model = st_models.Pooling(
102 | word_embedding_model.get_word_embedding_dimension()
103 | )
104 | model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
105 | train_loss = CosineSimilarityLoss(model=model)
106 |
107 | logging.info(f"Starting training for {num_epochs} epochs.")
108 | with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch_dtype):
109 | model.fit(
110 | train_objectives=[(train_dataloader, train_loss)],
111 | epochs=num_epochs,
112 | optimizer_class=infer_optimizer(optim),
113 | warmup_steps=warmup_steps,
114 | output_path=str(model_save_dir),
115 | weight_decay=0.0,
116 | scheduler=scheduler,
117 | optimizer_params={"lr": learning_rate},
118 | )
119 | logging.info("Training complete.")
120 |
121 | save_path_to_hydra_logs(save_dir=model_save_dir)
122 |
123 | return model_save_dir
124 |
125 |
126 | if __name__ == "__main__":
127 | main()
128 |
--------------------------------------------------------------------------------
/modeling/llama/accelerate/fsdp_2gpus.yaml:
--------------------------------------------------------------------------------
1 | # Useful: https://huggingface.co/docs/accelerate/main/en/usage_guides/fsdp
2 | compute_environment: LOCAL_MACHINE
3 | debug: false
4 | distributed_type: FSDP
5 | downcast_bf16: 'no'
6 | fsdp_config:
7 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
8 | fsdp_backward_prefetch_policy: BACKWARD_PRE
9 | fsdp_cpu_ram_efficient_loading: true
10 | fsdp_forward_prefetch: false
11 | fsdp_offload_params: false
12 | fsdp_sharding_strategy: 1
13 | fsdp_state_dict_type: FULL_STATE_DICT
14 | fsdp_sync_module_states: true
15 | # Set fsdp_use_orig_params=true if using peft:
16 | # https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019
17 | fsdp_use_orig_params: false
18 | machine_rank: 0
19 | main_training_function: main
20 | mixed_precision: bf16
21 | num_machines: 1
22 | num_processes: 2
23 | rdzv_backend: static
24 | same_network: true
25 | tpu_env: []
26 | tpu_use_cluster: false
27 | tpu_use_sudo: false
28 | use_cpu: false
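# Editor's note (illustrative, not part of the original file): this config is intended to
# be passed to Accelerate when launching training, e.g.
#   accelerate launch --config_file modeling/llama/accelerate/fsdp_2gpus.yaml <training script/module>
# The exact training entry point is not specified in this file.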
--------------------------------------------------------------------------------
/modeling/llama/accelerate/fsdp_4gpus.yaml:
--------------------------------------------------------------------------------
1 | # Useful: https://huggingface.co/docs/accelerate/main/en/usage_guides/fsdp
2 | compute_environment: LOCAL_MACHINE
3 | debug: false
4 | distributed_type: FSDP
5 | downcast_bf16: 'no'
6 | fsdp_config:
7 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
8 | fsdp_backward_prefetch_policy: BACKWARD_PRE
9 | fsdp_cpu_ram_efficient_loading: true
10 | fsdp_forward_prefetch: false
11 | fsdp_offload_params: false
12 | fsdp_sharding_strategy: 1
13 | fsdp_state_dict_type: FULL_STATE_DICT
14 | fsdp_sync_module_states: true
15 | # Set fsdp_use_orig_params=true if using peft:
16 | # https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019
17 | fsdp_use_orig_params: false
18 | machine_rank: 0
19 | main_training_function: main
20 | mixed_precision: bf16
21 | num_machines: 1
22 | num_processes: 4
23 | rdzv_backend: static
24 | same_network: true
25 | tpu_env: []
26 | tpu_use_cluster: false
27 | tpu_use_sudo: false
28 | use_cpu: false
--------------------------------------------------------------------------------
/modeling/llama/accelerate/fsdp_6gpus.yaml:
--------------------------------------------------------------------------------
1 | # Useful: https://huggingface.co/docs/accelerate/main/en/usage_guides/fsdp
2 | compute_environment: LOCAL_MACHINE
3 | debug: false
4 | distributed_type: FSDP
5 | downcast_bf16: 'no'
6 | fsdp_config:
7 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
8 | fsdp_backward_prefetch_policy: BACKWARD_PRE
9 | fsdp_cpu_ram_efficient_loading: true
10 | fsdp_forward_prefetch: false
11 | fsdp_offload_params: false
12 | fsdp_sharding_strategy: 1
13 | fsdp_state_dict_type: FULL_STATE_DICT
14 | fsdp_sync_module_states: true
15 | # Set fsdp_use_orig_params=true if using peft:
16 | # https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019
17 | fsdp_use_orig_params: false
18 | machine_rank: 0
19 | main_training_function: main
20 | mixed_precision: bf16
21 | num_machines: 1
22 | num_processes: 6
23 | rdzv_backend: static
24 | same_network: true
25 | tpu_env: []
26 | tpu_use_cluster: false
27 | tpu_use_sudo: false
28 | use_cpu: false
--------------------------------------------------------------------------------
/modeling/llama/accelerate/fsdp_8gpus.yaml:
--------------------------------------------------------------------------------
1 | # Useful: https://huggingface.co/docs/accelerate/main/en/usage_guides/fsdp
2 | compute_environment: LOCAL_MACHINE
3 | debug: false
4 | distributed_type: FSDP
5 | downcast_bf16: 'no'
6 | fsdp_config:
7 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
8 | fsdp_backward_prefetch_policy: BACKWARD_PRE
9 | fsdp_cpu_ram_efficient_loading: true
10 | fsdp_forward_prefetch: false
11 | fsdp_offload_params: false
12 | fsdp_sharding_strategy: 1
13 | fsdp_state_dict_type: FULL_STATE_DICT
14 | fsdp_sync_module_states: true
15 | # Set fsdp_use_orig_params=true if using peft:
16 | # https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019
17 | fsdp_use_orig_params: false
18 | machine_rank: 0
19 | main_training_function: main
20 | mixed_precision: bf16
21 | num_machines: 1
22 | num_processes: 8
23 | rdzv_backend: static
24 | same_network: true
25 | tpu_env: []
26 | tpu_use_cluster: false
27 | tpu_use_sudo: false
28 | use_cpu: false
--------------------------------------------------------------------------------
/modeling/llama/conf/config.yaml:
--------------------------------------------------------------------------------
1 | project_dir: ${oc.env:WEBLLAMA_PROJECT_DIR}
2 | seed: 123
3 | project_name: llama_ft
4 |
5 | data:
6 | num_proc: 8
7 | split_path: ${project_dir}/wl_data/splits.json
8 | base_dir: ${project_dir}/wl_data/demonstrations/
9 |
10 | train:
11 | split: train
12 | num_epochs: 3
13 | learning_rate: 3e-5
14 | batch_size_per_device: 4
15 | gradient_accumulation_steps: 1
16 | dataloader_num_workers: 8
17 | gradient_checkpointing: True
18 | use_accelerator_device_map: True # Set to true if using `accelerate`
19 | use_auto_device_map: False # Set to false if using `accelerate`
20 | warmup_ratio: 0
21 | scheduler: linear
22 | optim: adamw_torch
23 |
24 | eval:
25 | split: valid
26 | batch_size_per_device: 8
27 | result_dir: ${project_dir}/results/${project_name}/${eval.split}/${model.name}
28 | load_from_save_dir: True
29 | test_run: False
30 |
31 | model:
32 | name: meta-llama/Meta-Llama-3-8B-Instruct
33 | use_flash_attention_2: True
34 | tokenizer: ${model.name}
35 | template_tokenizer: ${model.tokenizer}
36 | max_inp_len: null
37 | max_out_len: 256
38 | use_rope: True
39 | save_dir: ${project_dir}/checkpoints/${project_name}/${model.name}
40 |
41 | candidates:
42 | k: 10
43 | model: McGill-NLP/MiniLM-L6-dmr # unused but potentially useful
44 | project_name: dmr # unused but potentially useful
45 | split: ${eval.split}
46 | train_path: ${project_dir}/wl_data/candidates/train.jsonl
47 | path: ${project_dir}/wl_data/candidates/${candidates.split}.jsonl
48 |
49 | hydra:
50 | run:
51 | dir: ${project_dir}/logs/${project_name}/${hydra.job.name}/${now:%Y-%m-%d-%H:%M:%S}
52 | # Use the same for sweep's subdir
53 | sweep:
54 | dir: ${hydra.run.dir}
55 | job:
56 | chdir: False
57 | verbose: INFO
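# Editor's note (illustrative, not part of the original config): with Hydra, any value
# above can be overridden from the command line using dotted keys appended to the
# train/eval invocation, e.g.
#   train.num_epochs=1 eval.split=valid model.name=meta-llama/Meta-Llama-3-8B-Instruct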
--------------------------------------------------------------------------------
/modeling/llama/eval.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import logging
3 | import json
4 | from pathlib import Path
5 |
6 | import hydra
7 | from hydra.core.hydra_config import HydraConfig
8 | from omegaconf import OmegaConf
9 | import torch
10 | from tqdm import tqdm
11 | from transformers import (
12 | AutoTokenizer,
13 | AutoModelForCausalLM,
14 | pipeline,
15 | )
16 | from transformers.pipelines.pt_utils import KeyDataset
17 |
18 | import weblinx as wl
19 | from weblinx.processing import load_candidate_elements
20 | from weblinx.processing.prompt import build_input_records_from_selected_turns, select_turns_and_candidates_for_prompts
21 | from weblinx.utils.hydra import save_path_to_hydra_logs
22 |
23 | from .processing import (
24 | build_prompt_records_for_llama_truncated,
25 | build_formatter_for_multichoice,
26 | insert_formatted_chat_into_records
27 | )
28 |
29 |
30 | @hydra.main(version_base=None, config_path="conf", config_name="config")
31 | def main(cfg):
32 | logger = logging.getLogger(__name__)
33 |
34 | split_path = Path(cfg.data.split_path).expanduser()
35 | result_dir = Path(cfg.eval.result_dir).expanduser()
36 | model_save_dir = Path(cfg.model.save_dir).expanduser()
37 |
38 | max_out_len = cfg.model.max_out_len
39 | split = cfg.eval.split
40 |
41 | result_dir.mkdir(parents=True, exist_ok=True)
42 |
43 | logger.info(OmegaConf.to_yaml(cfg))
44 |
45 | candidates = load_candidate_elements(path=cfg.candidates.path)
46 |
47 | tokenizer = AutoTokenizer.from_pretrained(cfg.model.tokenizer, padding_side="left")
48 | tokenizer.pad_token = tokenizer.eos_token
49 |
50 | # Data loading
51 | demo_names = wl.utils.load_demo_names_in_split(split_path, split=split)
52 | demos = [wl.Demonstration(name, base_dir=cfg.data.base_dir) for name in demo_names]
53 |
54 | format_intent = build_formatter_for_multichoice()
55 | build_prompt_records_fn = partial(
56 | build_prompt_records_for_llama_truncated,
57 | format_intent=format_intent,
58 | tokenizer=tokenizer,
59 | )
60 |
61 | selected_turns = select_turns_and_candidates_for_prompts(
62 | demos=demos,
63 | candidates=candidates,
64 | num_candidates=cfg.candidates.k,
65 | )
66 |
67 | input_records = build_input_records_from_selected_turns(
68 | selected_turns=selected_turns,
69 | format_intent=format_intent,
70 | build_prompt_records_fn=build_prompt_records_fn,
71 | format_prompt_records_fn=None,
72 | )
73 |
74 | template_tokenizer = AutoTokenizer.from_pretrained(cfg.model.template_tokenizer)
75 | insert_formatted_chat_into_records(
76 | records=input_records,
77 | tokenizer=template_tokenizer,
78 | include_output_target=False,
79 | )
80 |
81 | model_kwargs = dict(device_map="auto", torch_dtype=torch.bfloat16)
82 | model_kwargs['trust_remote_code'] = cfg.model.get('trust_remote_code', False)
83 |
84 | if cfg.model.use_rope:
85 | model_kwargs["rope_scaling"] = {"type": "dynamic", "factor": 2.0}
86 |
87 | if cfg.model.use_flash_attention_2:
88 | model_kwargs["use_flash_attention_2"] = True
89 |
90 | if cfg.eval.get("load_from_save_dir", False) is True:
91 | model_load_name = str(model_save_dir)
92 | else:
93 | model_load_name = cfg.model.name
94 |
95 | model = AutoModelForCausalLM.from_pretrained(model_load_name, **model_kwargs)
96 |
97 | dset = KeyDataset(input_records, key="text")
98 | pipe = pipeline(
99 | "text-generation", model=model, tokenizer=tokenizer, torch_dtype=torch.bfloat16
100 | )
101 | pipe_kwargs = dict(
102 | max_new_tokens=max_out_len,
103 | return_full_text=False,
104 | batch_size=cfg.eval.batch_size_per_device,
105 | pad_token_id=tokenizer.eos_token_id,
106 | )
107 |
108 | results = []
109 |
110 | with torch.cuda.amp.autocast(dtype=torch.bfloat16):
111 | pbar = tqdm(
112 | pipe(dset, **pipe_kwargs), desc="Generating outputs", total=len(dset)
113 | )
114 | for i, out in enumerate(pbar):
115 | rec = input_records[i]
116 | generated_text = out[0]["generated_text"]
117 | result = {
118 | "demo_name": rec["demo_name"],
119 | "turn_index": rec["turn_index"],
120 | "prompt": rec["prompt"],
121 | "text": rec["text"],
122 | "output_predicted": generated_text,
123 | "output_target": rec["output_target"],
124 | "output_target_dict": rec["output_target_dict"],
125 | }
126 |
127 | results.append(result)
128 |
129 | # Save results
130 | with open(result_dir / "results.json", "w") as f:
131 | json.dump(results, f, indent=2)
132 |
133 | # Save the path to hydra_path into the model directory
134 | save_path_to_hydra_logs(save_dir=result_dir)
135 |
136 |
137 | if __name__ == "__main__":
138 | main()
139 |
--------------------------------------------------------------------------------
/modeling/llama/processing.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Callable
3 |
4 | import lxml.html
5 |
6 | import weblinx.utils.format as wlf
7 | from weblinx.processing.dom import clean_and_prune_tree
8 | from weblinx.processing.prompt import (
9 | find_turns_with_instructor_chat,
10 | format_candidates,
11 | format_utterances,
12 | format_utterances_truncated,
13 | get_speaker,
14 | multi_attempt_format_prev_turns_truncated,
15 | )
16 | from weblinx.processing.truncation import (
17 | multi_attempt_truncate_cands_turn,
18 | multi_attempt_truncate_dom_tree,
19 | )
20 |
21 |
22 | def build_formatter_for_multichoice():
23 | format_click = partial(wlf.format_click, formatters=(wlf.format_uid,))
24 | format_text_input = partial(
25 | wlf.format_text_input,
26 | formatters=(
27 | partial(wlf.format_arg_item, name="text", max_length=200),
28 | wlf.format_uid,
29 | ),
30 | )
31 | format_change = partial(
32 | wlf.format_change,
33 | formatters=(
34 | partial(wlf.format_arg_item, name="value", max_length=200),
35 | wlf.format_uid,
36 | ),
37 | )
38 | format_submit = partial(wlf.format_submit, formatters=(wlf.format_uid,))
39 | format_load = partial(
40 | wlf.format_load,
41 | include_transition=False,
42 | include_timestamp=False,
43 | max_length=200,
44 | )
45 | format_scroll = partial(wlf.format_scroll, include_timestamp=False)
46 |
47 | format_say = partial(wlf.format_say, include_timestamp=False)
48 |
49 | format_intent_auto = partial(
50 | wlf.format_intent_automatically,
51 | format_change=format_change,
52 | format_click=format_click,
53 | format_load=format_load,
54 | format_say=format_say,
55 | format_scroll=format_scroll,
56 | format_submit=format_submit,
57 | format_text_input=format_text_input,
58 | )
59 |
60 | return format_intent_auto
61 |
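# Editor's note -- an illustrative sketch of the action strings the formatter above is
# meant to produce, mirroring the formats listed in the system prompt template below
# (the uid and other values are made up):
#
#     click(uid="a1b2c3")
#     text_input(text="hello world", uid="a1b2c3")
#     change(value="option-2", uid="a1b2c3")
#     load(url="https://example.com")
#     scroll(x=0, y=500)
#     say(speaker="navigator", utterance="Done.")
#     submit(uid="a1b2c3")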
62 |
63 | def get_system_prompt_template_for_llama_mc_concise():
64 | sys_prompt_template = (
65 | "You are an AI assistant with a deep understanding of HTML "
66 | "and you must predict actions based on a user request, which will be executed. "
67 | "Use one of the following, replacing [] with an appropriate value: "
68 | "change(value=[str], uid=[str]) ; "
69 | "click(uid=[str]) ; "
70 | "load(url=[str]) ; "
71 | 'say(speaker="navigator", utterance=[str]) ; '
72 | "scroll(x=[int], y=[int]) ; "
73 | "submit(uid=[str]) ;"
74 | "text_input(text=[str], uid=[str]) ;\n"
75 | "The user's first and last {num_utterances} utterances are: "
76 | "{utterance_context} ;\n"
77 | "Viewport size: {height}h x {width}w ;\n"
78 | "Only the last {num_prev_turns} turns are provided."
79 | )
80 |
81 | return sys_prompt_template
82 |
83 |
84 | def get_candidate_prompt_template_for_llama():
85 | return "Here are the top candidates for this turn: {candidate_str}\n"
86 |
87 |
88 | def get_final_user_message():
89 | return "Please select the best action using the correct format, do not provide any other information or explanation."
90 |
91 |
92 | def merge_prev_turns(prev_turns_text_list, final_user_message):
93 | prev_turns_merged = []
94 |
95 | # Merge turns from the same role
96 | for i, turn_text in enumerate(prev_turns_text_list):
97 | role = get_speaker(
98 | turn_text,
99 | instructor_name="user",
100 | navigator_name="assistant",
101 | default_name="unknown",
102 | )
103 |
104 | if i > 0 and prev_turns_merged[-1]["role"] == role:
105 | prev_turns_merged[-1]["content"] += " " + turn_text
106 | else:
107 | prev_turns_merged.append({"role": role, "content": turn_text})
108 |
109 | if len(prev_turns_merged) > 0 and prev_turns_merged[-1]["role"] == "user":
110 | prev_turns_merged[-1]["content"] += " " + final_user_message
111 | else:
112 | prev_turns_merged.append({"role": "user", "content": final_user_message})
113 |
114 | return prev_turns_merged
115 |
116 |
117 | def build_prompt_records_for_llama_truncated(
118 | replay,
119 | turn,
120 | format_intent,
121 | tokenizer,
122 | cands_turn=None,
123 | num_utterances=5,
124 | num_prev_turns=5,
125 | system_prompt_template=None,
126 | candidate_prompt_template=None,
127 | final_user_message=None,
128 | include_html=True,
129 | format_candidates_fn=partial(
130 | format_candidates, max_char_len=None, use_uid_as_rank=True
131 | ),
132 | merge_prev_turns_fn=merge_prev_turns,
133 | format_output_dict_fn: Callable = partial(
134 | wlf.format_output_dictionary, function_key="intent"
135 | ),
136 | max_html_tokens=700,
137 | max_utterance_tokens=40 * 5,
138 | max_prev_turns_tokens=50 * 5,
139 | max_candidates_tokens=65 * 10,
140 | add_unused_len_to_cands=True,
141 | allow_iterative_reduction=False,
142 | parser=None,
143 | ):
144 | """
145 | Parameters
146 | ----------
147 | ...
148 | allow_iterative_reduction : bool
149 |         This arg is only relevant when truncate_at_center is used behind the scenes (e.g. for
150 | multi_attempt_format_prev_turns_truncated or multi_attempt_truncate_dom_tree). If True,
151 | then we will allow the iterative reduction to continue until the max_tokens is reached.
152 | This is useful when the tokenizer output does not necessarily decrease when we remove
153 | tokens from the input. For example, if we remove a token that is part of a word, but
154 | the updated text is retokenized to the same number of tokens, then we will continue
155 | to remove tokens until we reach the max_tokens limit.
156 | """
157 | if system_prompt_template is None:
158 | system_prompt_template = get_system_prompt_template_for_llama_mc_concise()
159 |
160 | if candidate_prompt_template is None:
161 | candidate_prompt_template = get_candidate_prompt_template_for_llama()
162 |
163 | if final_user_message is None:
164 | final_user_message = get_final_user_message()
165 |
166 | instructor_chat_turns = find_turns_with_instructor_chat(
167 | replay, turn, num_prev_turns=num_prev_turns
168 | )
169 | utterance_context = format_utterances_truncated(
170 | instructor_chat_turns,
171 | tokenizer=tokenizer,
172 | max_tokens=max_utterance_tokens,
173 | num_utterances=num_utterances,
174 | format_utterances_fn=format_utterances,
175 | allow_iterative_reduction=allow_iterative_reduction,
176 | )
177 |
178 | prev_turns_text_list = multi_attempt_format_prev_turns_truncated(
179 | replay=replay,
180 | turn=turn,
181 | format_intent=partial(format_intent, return_as=dict),
182 | tokenizer=tokenizer,
183 | num_prev_turns=num_prev_turns,
184 | turn_sep=None, # output list
185 | max_tokens=max_prev_turns_tokens,
186 | max_attempts=5,
187 | format_output_dict_fn=format_output_dict_fn,
188 | warn_after_attempts=False,
189 | allow_iterative_reduction=allow_iterative_reduction,
190 | )
191 |
192 | prev_turns_merged = merge_prev_turns_fn(
193 | prev_turns_text_list=prev_turns_text_list, final_user_message=final_user_message
194 | )
195 |
196 | sys_prompt = system_prompt_template.format(
197 | num_utterances=num_utterances - 1, # 1 less since we add the first utterance
198 | utterance_context=utterance_context,
199 | height=turn.viewport_height,
200 | width=turn.viewport_width,
201 | num_prev_turns=num_prev_turns,
202 | )
203 |
204 | if include_html and turn.html not in ["", None] and cands_turn is not None:
205 | dom_tree_raw = lxml.html.fromstring(turn.html, parser=parser)
206 | dom_tree_pruned = clean_and_prune_tree(dom_tree_raw, cands_turn=cands_turn)
207 | trunc = multi_attempt_truncate_dom_tree(
208 | dom_tree=dom_tree_pruned,
209 | tokenizer=tokenizer,
210 | max_tokens=max_html_tokens,
211 | warn_after_attempts=False,
212 | allow_iterative_reduction=allow_iterative_reduction,
213 | )
214 | html = trunc["tree_repr"]
215 | sys_prompt = html + sys_prompt
216 | else:
217 | html = ""
218 |
219 | if cands_turn is not None:
220 | if add_unused_len_to_cands:
221 | # Add the unused length to the candidates
222 | num_html_tokens = len(tokenizer.tokenize(html))
223 | num_utter_tokens = len(tokenizer.tokenize(utterance_context))
224 | num_prev_turns_tokens = len(
225 | tokenizer.tokenize(" ".join(prev_turns_text_list))
226 | )
227 | remain_html_tokens = max_html_tokens - num_html_tokens
228 | remain_utter_tokens = max_utterance_tokens - num_utter_tokens
229 | remain_prev_turns_tokens = max_prev_turns_tokens - num_prev_turns_tokens
230 | remain_tokens = (
231 | remain_html_tokens + remain_utter_tokens + remain_prev_turns_tokens
232 | )
233 | # Add the unused length to the max_candidates_tokens
234 | max_candidates_tokens += remain_tokens
235 |
236 | cands_turn_trunc = multi_attempt_truncate_cands_turn(
237 | cands_turn=cands_turn,
238 | tokenizer=tokenizer,
239 | max_tokens=max_candidates_tokens,
240 | format_candidates_fn=format_candidates_fn,
241 | warn_after_attempts=False,
242 | allow_iterative_reduction=allow_iterative_reduction,
243 | )
244 | cand_str = format_candidates_fn(cands_turn_trunc, max_char_len=None)
245 | cand_prompt = candidate_prompt_template.format(candidate_str=cand_str)
246 | sys_prompt += "\n" + cand_prompt
247 |
248 | return [{"role": "system", "content": sys_prompt}, *prev_turns_merged]
249 |
250 |
251 | def format_prompt_llama(prompt_records):
252 | """
253 | DEPRECATED: Use `insert_formatted_chat_into_records` instead
254 | """
255 | for i, rec in enumerate(prompt_records):
256 | if i != 0 and rec["role"] == "system":
257 | raise ValueError(
258 | f"System prompt should be the first record. Found it at index {i}."
259 | )
260 | if i == 0 and rec["role"] != "system":
261 | raise ValueError(
262 | f"System prompt should be the first record. Found a {rec['role']} prompt at index {i}."
263 | )
264 |
265 | sys_prompt = prompt_records[0]["content"]
266 | remain_turns = prompt_records[1:]
267 |
268 |     prompt = f"[INST] <<SYS>>\n{sys_prompt}\n<</SYS>>\n\n"
269 |
270 | for i, turn in enumerate(remain_turns):
271 | # If there's 1 turn remaining and it is not the user, then there was an issue
272 | if i == len(remain_turns) - 1 and turn["role"] != "user":
273 | raise ValueError(
274 | f"Last turn should be the user. Found a {turn['role']} turn at index {i}."
275 | )
276 |
277 | if turn["role"] == "user":
278 | # If the previous turn was system, we do not add the [INST] tag
279 | if i == 0:
280 | text = f"{turn['content']}"
281 | else:
282 | text = f"[INST] {turn['content'].strip()}"
283 |
284 | prompt += text
285 |
286 | elif turn["role"] == "assistant":
287 | prompt += f"[/INST] {turn['content'].strip()}"
288 |
289 | else:
290 | raise ValueError(
291 | f"Unknown role {turn['role']} at index {i}. Should be either 'user' or 'assistant'."
292 | )
293 |
294 | # Add [/INST] tag if the last turn was the user
295 | if remain_turns[-1]["role"] == "user":
296 | prompt += "[/INST]"
297 |
298 | return prompt
299 |
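# Editor's note -- an illustrative sketch of the deprecated Llama-2 style format
# produced above, for a minimal (system, user) prompt:
#
#     >>> format_prompt_llama([
#     ...     {"role": "system", "content": "SYS"},
#     ...     {"role": "user", "content": "Click the login button"},
#     ... ])
#     '[INST] <<SYS>>\nSYS\n<</SYS>>\n\nClick the login button[/INST]'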
300 |
301 | def __insert_empty_user_content_at_first(prompt: list):
302 | """
303 |     Given a list of dictionaries representing the input prompt, insert an empty user message at the first position
304 |     after the system message, unless that position already holds a user message. This is done in place.
305 | """
306 | if prompt[0]["role"] != "system":
307 | raise ValueError(
308 | f"First prompt must be a system prompt. Got {prompt[0]['role']} instead."
309 | )
310 |
311 | if prompt[1]["role"] != "user":
312 | prompt.insert(1, {"role": "user", "content": ""})
313 |
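# Editor's note -- illustrative in-place behaviour of the helper above:
#
#     >>> prompt = [{"role": "system", "content": "s"}, {"role": "assistant", "content": "a"}]
#     >>> __insert_empty_user_content_at_first(prompt)
#     >>> [p["role"] for p in prompt]
#     ['system', 'user', 'assistant']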
314 |
315 | def insert_formatted_chat_into_records(
316 | records,
317 | tokenizer,
318 | include_output_target=True,
319 | origin_key="prompt",
320 | text_key="text",
321 | ):
322 | """
323 | Given a list of records, insert the formatted chat into the records. This is done in place.
324 | Note that we need a tokenizer's `apply_chat_template` method to be available.
325 | """
326 | for i, record in enumerate(records):
327 | __insert_empty_user_content_at_first(record[origin_key])
328 |
329 | if include_output_target:
330 | target = [{"role": "assistant", "content": record["output_target"]}]
331 | combined = record[origin_key] + target
332 | else:
333 | combined = record[origin_key]
334 |
335 | text = tokenizer.apply_chat_template(
336 | combined, tokenize=False, add_generation_prompt=False
337 | )
338 | records[i][text_key] = text
339 |
--------------------------------------------------------------------------------
/modeling/llama/train.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import json
3 | import logging
4 | from pathlib import Path
5 |
6 | from accelerate import Accelerator
7 | import datasets
8 | from omegaconf import OmegaConf
9 | import hydra
10 | import torch
11 | from transformers import (
12 | AutoTokenizer,
13 | TrainingArguments,
14 | AutoModelForCausalLM,
15 | )
16 |
17 | from trl import SFTTrainer
18 |
19 | import weblinx as wl
20 | from weblinx.processing import load_candidate_elements
21 | from weblinx.processing.prompt import (
22 | build_input_records_from_selected_turns,
23 | select_turns_and_candidates_for_prompts,
24 | )
25 | from weblinx.utils.hydra import save_path_to_hydra_logs
26 | from weblinx.utils import set_seed
27 |
28 | from .processing import (
29 | build_formatter_for_multichoice,
30 | build_prompt_records_for_llama_truncated,
31 | insert_formatted_chat_into_records,
32 | )
33 |
34 |
35 | @hydra.main(config_path="conf", config_name="config", version_base=None)
36 | def main(cfg):
37 | set_seed(cfg.seed)
38 | split_path = Path(cfg.data.split_path).expanduser()
39 | model_save_dir = Path(cfg.model.save_dir).expanduser()
40 | model_save_dir.mkdir(exist_ok=True, parents=True)
41 | logging.info(OmegaConf.to_yaml(cfg))
42 |
43 | demo_names = wl.utils.load_demo_names_in_split(split_path, split=cfg.train.split)
44 | demos = [wl.Demonstration(demo_name, base_dir=cfg.data.base_dir) for demo_name in demo_names]
45 | candidates = load_candidate_elements(path=cfg.candidates.train_path)
46 |
47 | tokenizer = AutoTokenizer.from_pretrained(cfg.model.tokenizer, padding_side="right")
48 | tokenizer.pad_token = tokenizer.eos_token
49 |
50 | model_kwargs = dict(torch_dtype=torch.bfloat16)
51 | model_kwargs['trust_remote_code'] = cfg.model.get('trust_remote_code', False)
52 |
53 | if cfg.train.use_accelerator_device_map:
54 | accelerator = Accelerator()
55 | model_kwargs["device_map"] = {"": accelerator.process_index}
56 |
57 | elif cfg.train.use_auto_device_map:
58 | model_kwargs["device_map"] = "auto"
59 |
60 | if cfg.model.use_flash_attention_2:
61 | model_kwargs["use_flash_attention_2"] = True
62 |
63 | model = AutoModelForCausalLM.from_pretrained(cfg.model.name, **model_kwargs)
64 |
65 | format_intent = build_formatter_for_multichoice()
66 | input_records_fname = "input_records_trunc.json"
67 | build_prompt_records_fn = partial(
68 | build_prompt_records_for_llama_truncated,
69 | format_intent=format_intent,
70 | tokenizer=tokenizer,
71 | )
72 |
73 | selected_turns = select_turns_and_candidates_for_prompts(
74 | demos=demos,
75 | candidates=candidates,
76 | num_candidates=cfg.candidates.k,
77 | )
78 |
79 | input_records = build_input_records_from_selected_turns(
80 | selected_turns=selected_turns,
81 | format_intent=format_intent,
82 | build_prompt_records_fn=build_prompt_records_fn,
83 | format_prompt_records_fn=None,
84 | )
85 |
86 | template_tokenizer = AutoTokenizer.from_pretrained(cfg.model.template_tokenizer)
87 | insert_formatted_chat_into_records(
88 | input_records, template_tokenizer, include_output_target=True
89 | )
90 |
91 | with open(model_save_dir.joinpath(input_records_fname), "w") as f:
92 | json.dump(input_records, f, indent=2)
93 |
94 | input_records_texts = [{"text": record["text"]} for record in input_records]
95 |
96 | training_args = TrainingArguments(
97 | output_dir=model_save_dir,
98 | optim=cfg.train.optim,
99 | learning_rate=cfg.train.learning_rate,
100 | num_train_epochs=cfg.train.num_epochs,
101 | per_device_train_batch_size=cfg.train.batch_size_per_device,
102 | gradient_accumulation_steps=cfg.train.gradient_accumulation_steps,
103 | gradient_checkpointing=cfg.train.gradient_checkpointing,
104 | warmup_ratio=cfg.train.warmup_ratio,
105 | lr_scheduler_type=cfg.train.scheduler,
106 | save_strategy="no",
107 | evaluation_strategy="no",
108 | logging_strategy="epoch",
109 | logging_first_step=True,
110 | prediction_loss_only=True,
111 | bf16=True,
112 | bf16_full_eval=True,
113 | )
114 |
115 | trainer = SFTTrainer(
116 | model=model,
117 | tokenizer=tokenizer,
118 | args=training_args,
119 | train_dataset=datasets.Dataset.from_list(input_records_texts),
120 | max_seq_length=model.config.max_position_embeddings,
121 | dataset_text_field="text",
122 | )
123 |
124 | trainer.train()
125 |
126 | # Save model, tokenizer, trainer state, and path to hydra logs
127 | trainer.save_model(model_save_dir)
128 | tokenizer.save_pretrained(model_save_dir)
129 | trainer.state.save_to_json(model_save_dir / "trainer_state.json")
130 | save_path_to_hydra_logs(save_dir=model_save_dir)
131 |
132 | # if the model is saved as pytorch_model_fsdp.bin, rename it to pytorch_model.bin
133 | fsdp_model_path = model_save_dir / "pytorch_model_fsdp.bin"
134 | if fsdp_model_path.exists():
135 | fsdp_model_path.rename(model_save_dir / "pytorch_model.bin")
136 |
137 |
138 | if __name__ == "__main__":
139 | main()
140 |
--------------------------------------------------------------------------------
/modeling/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.35.0 # Future version may break the code, upgrade with caution
2 | lxml
3 | numpy
4 | datasets
5 | torch
6 | sentence-transformers
7 | peft
8 | backoff
9 | tensorboardX
10 | hydra-core
11 | accelerate
12 | optimum
13 | openai
14 | tiktoken
15 | trl
16 | bitsandbytes
17 | coloredlogs
18 | sacrebleu
19 | bert-score
20 | packaging
21 | ninja
22 | wheel
--------------------------------------------------------------------------------
/requirements-basic.txt:
--------------------------------------------------------------------------------
1 | weblinx>=0.3.0rc1
2 | lxml
3 | numpy
--------------------------------------------------------------------------------
/requirements-extra.txt:
--------------------------------------------------------------------------------
1 | weblinx[eval]>=0.3.0rc1
2 | streamlit
3 | sentence-transformers
4 | transformers
5 | playwright
6 | browsergym
7 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | package_name = "webllama"
4 | version = {}
5 | with open(f"{package_name}/version.py") as fp:
6 | exec(fp.read(), version)
7 |
8 | with open("README.md") as fp:
9 | long_description = fp.read()
10 |
11 | with open('requirements-extra.txt') as f:
12 | extras = f.read().splitlines()
13 |
14 | with open('requirements-basic.txt') as f:
15 | install_requires = f.read().splitlines()
16 |
17 | extras_require = {
18 | "dev": ["black"],
19 | "extra": extras,
20 | }
21 | # Dynamically create the 'all' extra by combining all other extras
22 | extras_require["all"] = sum(extras_require.values(), [])
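# e.g. `pip install -e .[all]` (as used in tests/requirements.txt) installs the package
# together with every optional dependency group defined above.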
23 |
24 | setup(
25 | name=package_name,
26 | version=version["__version__"],
27 | author="Xing Han Lù",
28 | author_email=f"{package_name}@googlegroups.com",
29 | url=f"https://github.com/McGill-NLP/{package_name}",
30 | description="Llama-powered agents for automatic web browsing",
31 | long_description=long_description,
32 | packages=find_packages(include=[f"{package_name}*"]),
33 | package_data={},
34 | install_requires=install_requires,
35 | extras_require=extras_require,
36 | classifiers=[
37 | "Programming Language :: Python :: 3",
38 | "License :: OSI Approved :: MIT License",
39 | "Operating System :: OS Independent",
40 | ],
41 | python_requires=">=3.8",
42 |     # Render the long description as Markdown on PyPI
43 | long_description_content_type="text/markdown",
44 | )
--------------------------------------------------------------------------------
/tests/requirements.txt:
--------------------------------------------------------------------------------
1 | -e .[all]
--------------------------------------------------------------------------------
/tests/test_web_turn_processor.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import time
3 | import logging
4 | import unittest
5 |
6 | from sentence_transformers import SentenceTransformer
7 | from transformers import AutoTokenizer, pipeline
8 | import weblinx as wl
9 | import webllama.experimental as wa
10 |
11 | logging.getLogger("urllib3").setLevel(logging.WARNING)
12 |
13 |
14 | class TestWebTurnProcessor(unittest.TestCase):
15 | def setUp(self):
16 |
17 | demos = wl.list_demonstrations("tests/demonstrations")
18 | replay = wl.Replay.from_demonstration(demos[0])
19 | turn = replay[26]
20 |
21 | self.turn = turn
22 | self.replay = replay
23 | self.action_model_name = "McGill-NLP/Sheared-LLaMA-2.7B-weblinx"
24 |
25 | self.tokenizer = AutoTokenizer.from_pretrained(self.action_model_name)
26 |
27 | format_intent_input_dmr, format_intent_out_dmr = (
28 | wa.formatting.build_formatters_dmr()
29 | )
30 | format_intent_am = partial(
31 | wa.formatting.build_formatters_action_model(), return_as=dict
32 | )
33 | self.action_history = wa.functions.create_action_history_from_replay(
34 | replay, format_intent_input_dmr, format_intent_am, stop_at_idx=turn.index
35 | )
36 | self.state = wa.classes.State(
37 | index=turn.index,
38 | html=turn.html,
39 | bboxes=turn.bboxes,
40 | viewport_height=turn.viewport_height,
41 | viewport_width=turn.viewport_width,
42 | type=turn.type,
43 | )
44 |
45 | def test_prepare_dmr_query(self):
46 | # We will initialize our processor, which helps us prepare the input for action model
47 | proc = wa.processing.WebTurnProcessor(tokenizer=self.tokenizer)
48 |
49 | # Step 1: prepare query, run DMR and prepare retrieved candidates
50 | query_dmr = proc.prepare_dmr_query(self.action_history, self.state)
51 |
52 | CORRECT_RESULT = 'Viewport(height=746, width=1536) ---- Instructor Utterances: [00:07] Hello [00:13] Open independent ie Website. [01:30] Go to life and send me some life related news [04:00] Open second one and Summarize the first three paragraphs in a few words ---- Previous Turns:tabswitch(origin=102465633, target=102465635, timestamp="04:19") ; load(url="https://search.yahoo.com/search?fr=mcafee&type=E211US714G0&p=chatgpt", timestamp="04:23") ; click(x=268, y=201, tag="a", attrs={}, timestamp="04:24") ; tabcreate(target=102465636, timestamp="04:25") ; tabswitch(origin=102465635, target=102465636, timestamp="04:25")'
53 | self.assertIsInstance(query_dmr, str)
54 | self.assertEqual(query_dmr, CORRECT_RESULT)
55 |
--------------------------------------------------------------------------------
/webllama/__init__.py:
--------------------------------------------------------------------------------
1 | from .version import __version__
2 | from . import experimental
--------------------------------------------------------------------------------
/webllama/experimental/__init__.py:
--------------------------------------------------------------------------------
1 | from . import classes, functions, integrations, formatting, processing, templates, web
--------------------------------------------------------------------------------
/webllama/experimental/classes.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from dataclasses import dataclass
3 | from typing import Callable, Dict, List, Tuple, TypedDict
4 | import typing
5 |
6 | from weblinx.utils.format import format_output_dictionary
7 |
8 | # Custom types
9 | UID = typing.NewType("UID", str)
10 | AttrsCore = TypedDict(
11 | "AttrsCore",
12 | {"class": str, "title": str, "href": str, "aria-label": str, "d": str, "src": str},
13 | )
14 |
15 |
16 | class BBox(TypedDict):
17 | """
18 | A class to represent the bounding box of an element.
19 |
20 | Attributes
21 | ----------
22 | x : int
23 | The x-coordinate of the bounding box.
24 | y : int
25 | The y-coordinate of the bounding box.
26 | width : float
27 | The width of the bounding box.
28 | height : float
29 | The height of the bounding box.
30 | top : float, optional
31 | The top position of the bounding box, calculated from `y` if not provided.
32 | bottom : float, optional
33 | The bottom position of the bounding box, calculated from `y` and `height` if not provided.
34 | left : float, optional
35 | The left position of the bounding box, calculated from `x` if not provided.
36 | right : float, optional
37 | The right position of the bounding box, calculated from `x` and `width` if not provided.
38 | """
39 | x: int
40 | y: int
41 | width: float
42 | height: float
43 | top: float = None
44 | bottom: float = None
45 | left: float = None
46 | right: float = None
47 |
48 | def __post_init__(self):
49 | """
50 | Ensures required attributes are provided and calculates optional attributes if not given.
51 | For example, if `top` is not provided, it is calculated from `y`.
52 | """
53 | if any(x is None for x in [self.x, self.y, self.width, self.height]):
54 | raise ValueError("x, y, width, and height must be provided.")
55 |
56 | if self.top is None:
57 | self.top = self.y
58 |
59 | if self.bottom is None:
60 | self.bottom = self.y + self.height
61 |
62 | if self.left is None:
63 | self.left = self.x
64 |
65 | if self.right is None:
66 | self.right = self.x + self.width
67 |
68 |
69 | @dataclass
70 | class State:
71 | """
72 | A class to represent the state during navigation.
73 |
74 | Attributes
75 | ----------
76 | index : int
77 | The index of the state in the sequence of states.
78 | html : str
79 | The DOM tree represented using HTML.
80 | bboxes : Dict[UID, BBox]
81 | A dictionary mapping unique IDs to bounding boxes.
82 | viewport_height : int
83 | The height of the viewport of the browser.
84 | viewport_width : int
85 | The width of the viewport of the browser.
86 | type : str
87 | The type of the state, either "browser" or "chat".
88 |
89 | Methods
90 | -------
91 | from_dict(cls, dictionary):
92 | Creates a `State` instance from a dictionary.
93 | to_dict():
94 | Converts the `State` instance to a dictionary.
95 | """
96 | index: int
97 | html: str
98 | bboxes: Dict[UID, BBox]
99 | viewport_height: int
100 | viewport_width: int
101 | type: str # either "browser" or "chat"
102 |
103 | # check type
104 | def __post_init__(self):
105 | if self.type not in ["browser", "chat"]:
106 | raise ValueError("type must be either 'browser' or 'chat'.")
107 |
108 | @classmethod
109 | def from_dict(cls, dictionary):
110 | """
111 | Creates a `State` instance from a dictionary.
112 |
113 | Parameters
114 | ----------
115 | dictionary : dict
116 | The dictionary to create the `State` instance from.
117 |
118 | Returns
119 | -------
120 | State
121 | The created `State` instance.
122 | """
123 | return cls(
124 | index=dictionary["index"],
125 | html=dictionary["html"],
126 | bboxes=dictionary["bboxes"],
127 | viewport_height=dictionary["viewport_height"],
128 | viewport_width=dictionary["viewport_width"],
129 | type=dictionary["type"],
130 | )
131 |
132 | def to_dict(self):
133 | """
134 | Converts the `State` instance to a dictionary.
135 |
136 | Returns
137 | -------
138 | dict
139 | A dictionary representation of the `State` instance.
140 | """
141 | return {
142 | "index": self.index,
143 | "html": self.html,
144 | "bboxes": self.bboxes,
145 | "viewport_height": self.viewport_height,
146 | "viewport_width": self.viewport_width,
147 | "type": self.type,
148 | }
149 |
150 | @dataclass
151 | class Action:
152 | """
153 | A class to represent an action taken by the user.
154 |
155 | Attributes
156 | ----------
157 | type : str
158 | The type of the action, either "chat" or "browser".
159 | index : int
160 | The index of the action in the sequence of state/actions.
161 | intent : str
162 | The intent of the action (e.g., "click", "type", "scroll", "say").
163 | args : Dict[str, str]
164 | A dictionary of arguments associated with the action, such as the unique
165 | ID of the element clicked, the text typed, or the message said.
166 | timestamp : float
167 | The timestamp of the action in seconds, relative to the start time.
168 | tag : str, optional
169 | The HTML tag associated with the action (e.g., "button", "input").
170 | attrs : AttrsCore, optional
171 | The attributes associated with the action (e.g., "class", "title", "href", "aria-label", "d", "src").
172 | """
173 | type: str
174 | index: int
175 | intent: str
176 | args: Dict[str, str]
177 | timestamp: float
178 | tag: str = None
179 | attrs: AttrsCore = None
180 |
181 | def get(self, key):
182 | """
183 | Retrieves the value of the specified argument key.
184 |
185 | Parameters
186 | ----------
187 | key : str
188 | The key of the argument to retrieve.
189 |
190 | Returns
191 | -------
192 | str
193 | The value of the specified argument key.
194 | """
195 | return self.args.get(key, None)
196 |
197 | @classmethod
198 | def from_dict(
199 | cls,
200 | dictionary: Dict,
201 | included_attrs: Tuple[str] = ("class", "title", "href", "aria-label", "d", "src"),
202 | ) -> "Action":
203 | """
204 | Creates an `Action` instance from a dictionary.
205 |
206 | Parameters
207 | ----------
208 | dictionary : dict
209 | The dictionary to create the `Action` instance from. It should have the following
210 | keys: "intent", "index", "timestamp", "attrs" (optional), "tag" (optional), and
211 | any other keys as arguments. Moreover, the type of the action is inferred from
212 | the "intent" key.
213 | included_attrs : tuple of str, optional
214 | A tuple of attribute keys to include in the `attrs` dictionary.
215 |
216 | Returns
217 | -------
218 | Action
219 | The created `Action` instance.
220 | """
221 | di = deepcopy(dictionary)
222 | intent = di.pop("intent")
223 | index = di.pop("index")
224 | timestamp = di.pop("timestamp")
225 | attrs = di.pop("attrs", None)
226 | if attrs is not None:
227 | attrs = {k: v for k, v in attrs.items() if k in included_attrs}
228 |
229 | args = di
230 | type_ = "chat" if intent == "say" else "browser"
231 | tag = di.pop("tag") if "tag" in di else None
232 |
233 | return cls(
234 | index=index,
235 | intent=intent,
236 | args=args,
237 | type=type_,
238 | timestamp=timestamp,
239 | attrs=attrs,
240 | tag=tag,
241 | )
242 |
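    # Editor's note -- illustrative behaviour of `from_dict` (the values are made up):
    # "intent", "index", "timestamp", "tag" and "attrs" are pulled out, every remaining
    # key becomes an entry of `args`, and attrs outside `included_attrs` are dropped.
    #
    #     >>> a = Action.from_dict({
    #     ...     "intent": "click", "index": 3, "timestamp": 12.5,
    #     ...     "x": 100, "y": 200, "uid": "abc",
    #     ...     "tag": "button", "attrs": {"class": "btn", "data-foo": "1"},
    #     ... })
    #     >>> a.type, a.intent, a.args, a.tag, a.attrs
    #     ('browser', 'click', {'x': 100, 'y': 200, 'uid': 'abc'}, 'button', {'class': 'btn'})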
243 | def to_dict(
244 | self,
245 | include_timestamp=True,
246 | include_attrs=True,
247 | include_tag=True,
248 | include_index=True,
249 | drop_none_coords=False,
250 | format_timestamp_fn=None,
251 | ignore_args=None,
252 | ):
253 | """
254 | Convert the action to a dictionary, given specific options.
255 |
256 | Parameters
257 | ----------
258 | include_timestamp: bool
259 | Whether to include the timestamp in the output dictionary, as "timestamp"
260 | include_attrs: bool
261 | Whether to include the attributes in the output dictionary, as "attrs"
262 | include_tag: bool
263 | Whether to include the tag in the output dictionary, as "tag"
264 | include_index: bool
265 | Whether to include the index in the output dictionary, as "index"
266 |         drop_none_coords: bool
267 |             Whether to drop the "x" and "y" keys from the output dictionary when their values are None
268 |         format_timestamp_fn: callable
269 |             A function used to format the timestamp; if None, the raw timestamp is used
270 |         ignore_args: list
271 |             A list of keys to ignore in the args dictionary; if None, all keys are included
272 |
273 | Returns
274 | -------
275 | dict
276 | A dictionary representation of the action.
277 | """
278 | if ignore_args is not None:
279 | args = {k: v for k, v in self.args.items() if k not in ignore_args}
280 | else:
281 | args = self.args
282 |
283 | out = {"intent": self.intent, **args}
284 |
285 | if include_tag and self.tag is not None:
286 | out["tag"] = self.tag
287 |
288 | if include_attrs and self.attrs is not None:
289 | out["attrs"] = self.attrs
290 |
291 | if include_timestamp:
292 | if format_timestamp_fn is not None:
293 | out["timestamp"] = format_timestamp_fn(self)["timestamp"]
294 | else:
295 | out["timestamp"] = self.timestamp
296 |
297 | if include_index:
298 | out["index"] = self.index
299 |
300 | if drop_none_coords:
301 | if "x" in out and out["x"] is None:
302 | del out["x"]
303 | if "y" in out and out["y"] is None:
304 | del out["y"]
305 |
306 | return out
307 |
308 | def to_str(self, **kwargs):
309 | """
310 | Converts the `Action` instance to a formatted string.
311 |
312 | Parameters
313 | ----------
314 | kwargs : dict
315 | Keyword arguments to pass to the `to_dict` method.
316 |
317 | Returns
318 | -------
319 | str
320 | A formatted string representation of the action.
321 |
322 | Notes
323 | -----
324 |
325 | This runs the `to_dict` method and then formats the output dictionary as a string, using
326 | `weblinx.utils.format.format_output_dictionary` with the intent as the "function" key.
327 | """
328 | di = self.to_dict(**kwargs)
329 | return format_output_dictionary(di, function_key="intent", return_as=str)
330 |
331 | def items(self):
332 | """
333 | Mimics `weblinx.Turn.items()` to retrieve dictionary items of the action.
334 |
335 | Returns
336 | -------
337 | ItemsView
338 | A view object that displays a list of a dictionary's key-value tuple pairs.
339 |
340 | Notes
341 | -----
342 |
343 | This method is aimed to mimic `weblinx.Turn.items()`
344 | """
345 | di = self.to_dict(
346 | include_timestamp=True,
347 | include_attrs=False,
348 | include_tag=False,
349 | include_index=False,
350 | drop_none_coords=True,
351 | )
352 |
353 | return di.items()
354 |
--------------------------------------------------------------------------------
/webllama/experimental/formatting.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import weblinx.utils.format as wlf
3 |
4 | def build_formatters_action_model() -> callable:
5 | """
6 |     Builds and returns a formatter function for action model events.
7 |
8 | This function uses partial functions from the `weblinx.utils.format` module to create
9 | formatters for various user actions, such as clicks, text inputs, changes, etc. These
10 | formatters are then combined into a single formatter function for automatically formatting
11 | intents based on user actions.
12 |
13 | Returns
14 | -------
15 | function
16 | A function that formats intents automatically using the defined formatters.
17 |
18 | Notes
19 | -----
20 |
21 | Slightly improved over original implementation from weblinx:
22 | https://github.com/McGill-NLP/weblinx/blob/7f151eaf819a9665b9b0b2232a99db6d4c4d2738/modeling/llama/processing.py#L23
23 | """
24 | format_click = partial(wlf.format_click, formatters=(wlf.format_uid,))
25 | format_text_input = partial(
26 | wlf.format_text_input,
27 | formatters=(
28 | partial(wlf.format_arg_item, name="text", max_length=200),
29 | wlf.format_uid,
30 | ),
31 | )
32 | format_change = partial(
33 | wlf.format_change,
34 | formatters=(
35 | partial(wlf.format_arg_item, name="value", max_length=200),
36 | wlf.format_uid,
37 | ),
38 | )
39 | format_copy = partial(wlf.format_copy, include_timestamp=False)
40 | format_submit = partial(wlf.format_submit, formatters=(wlf.format_uid,))
41 | format_load = partial(
42 | wlf.format_load,
43 | include_transition=False,
44 | include_timestamp=False,
45 | max_length=200,
46 | )
47 | format_hover = partial(wlf.format_hover, formatters=(wlf.format_uid,))
48 | format_paste = partial(wlf.format_paste, include_timestamp=False)
49 | format_scroll = partial(wlf.format_scroll, include_timestamp=False)
50 | format_say = partial(wlf.format_say, include_timestamp=False)
51 | format_tab = wlf.format_tab
52 |
53 | format_intent_auto = partial(
54 | wlf.format_intent_automatically,
55 | format_change=format_change,
56 | format_click=format_click,
57 | format_copy=format_copy,
58 | format_hover=format_hover,
59 | format_load=format_load,
60 | format_paste=format_paste,
61 | format_say=format_say,
62 | format_scroll=format_scroll,
63 | format_submit=format_submit,
64 | format_tab=format_tab,
65 | format_text_input=format_text_input,
66 | )
67 |
68 | return format_intent_auto
69 |
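# Editor's note -- typical usage, mirroring tests/test_web_turn_processor.py: the returned
# formatter is wrapped with `functools.partial` to choose the output type, e.g.
#
#     format_intent_am = partial(build_formatters_action_model(), return_as=dict)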
70 |
71 | def build_formatters_dmr():
72 | """
73 |     Builds and returns two formatter functions for DMR (Dense Markup Ranking) events.
74 |
75 | This function creates formatters for both input and output events using partial functions
76 | from the `weblinx.utils.format` module. For inputs, it formats elements, clicks, changes,
77 | hovers, submits, and text inputs. For outputs, it formats elements, clicks, changes, loads,
78 | scrolls, and text inputs.
79 |
80 | Returns
81 | -------
82 | tuple of functions
83 | A tuple containing two functions: one for formatting input intents and one for formatting
84 | output intents.
85 |
86 |
87 | Examples
88 | -----
89 |
90 | ```python
91 | format_intent_input, format_intent_out = build_formatters_dmr()
92 | ```
93 |
94 | """
95 | format_element_input = partial(
96 | wlf.format_element,
97 | include_text=False,
98 | include_attrs=("class", "title", "href", "aria-label", "d", "src"),
99 | )
100 | format_click_input = partial(
101 | wlf.format_click,
102 | formatters=(wlf.format_mouse_xy, format_element_input, wlf.format_timestamp),
103 | )
104 | format_change_input = partial(
105 | wlf.format_change,
106 | formatters=(
107 | partial(wlf.format_arg_item, name="value"),
108 | format_element_input,
109 | wlf.format_timestamp,
110 | ),
111 | )
112 | format_hover_input = partial(
113 | wlf.format_hover,
114 | formatters=(wlf.format_mouse_xy, format_element_input, wlf.format_timestamp),
115 | )
116 |
117 | format_submit_input = partial(
118 | wlf.format_submit, formatters=(format_element_input, wlf.format_timestamp)
119 | )
120 |
121 | format_text_input_input = partial(
122 | wlf.format_text_input,
123 | formatters=(
124 | partial(wlf.format_arg_item, name="text"),
125 | partial(format_element_input),
126 | wlf.format_timestamp,
127 | ),
128 | )
129 |
130 | format_intent_input = partial(
131 | wlf.format_intent_automatically,
132 | format_click=format_click_input,
133 | format_change=format_change_input,
134 | format_hover=format_hover_input,
135 | format_submit=format_submit_input,
136 | format_text_input=format_text_input_input,
137 | format_tab=wlf.format_tab,
138 | return_as=str,
139 | )
140 |
141 | # second, for the output (prediction text)
142 | format_element_out = partial(
143 | wlf.format_element,
144 | # Only want the tag
145 | include_text=False,
146 | include_attrs=False,
147 | )
148 |
149 | format_click_out = partial(wlf.format_click, formatters=(wlf.format_mouse_xy,))
150 | format_text_input_out = partial(
151 | wlf.format_text_input,
152 | formatters=(
153 | partial(wlf.format_arg_item, name="text", max_length=200),
154 | format_element_out,
155 | wlf.format_target_bbox,
156 | ),
157 | )
158 | format_change_out = partial(
159 | wlf.format_change,
160 | formatters=(
161 | partial(wlf.format_arg_item, name="value", max_length=200),
162 | format_element_out,
163 | wlf.format_target_bbox,
164 | ),
165 | )
166 | format_submit_out = partial(
167 | wlf.format_submit, formatters=(format_element_out, wlf.format_target_bbox)
168 | )
169 | format_load_out = partial(
170 | wlf.format_load,
171 | include_transition=False,
172 | include_timestamp=False,
173 | max_length=200,
174 | )
175 | format_scroll_out = partial(wlf.format_scroll, include_timestamp=False)
176 |
177 | format_say_out = partial(wlf.format_say, include_timestamp=False)
178 |
179 | format_intent_out = partial(
180 | wlf.format_intent_automatically,
181 | format_change=format_change_out,
182 | format_click=format_click_out,
183 | format_load=format_load_out,
184 | format_say=format_say_out,
185 | format_scroll=format_scroll_out,
186 | format_submit=format_submit_out,
187 | format_text_input=format_text_input_out,
188 | )
189 |
190 | return format_intent_input, format_intent_out
191 |
192 |
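193 | # The block below is an illustrative usage sketch added for clarity; it is not part
194 | # of the original module. It only builds the formatter callables described in the
195 | # docstrings above, which downstream code is assumed to apply to weblinx turns.
196 | if __name__ == "__main__":
197 |     format_intent = build_formatters_action_model()
198 |     format_intent_input, format_intent_out = build_formatters_dmr()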
--------------------------------------------------------------------------------
/webllama/experimental/integrations/__init__.py:
--------------------------------------------------------------------------------
1 | from . import browsergym
--------------------------------------------------------------------------------
/webllama/experimental/integrations/browsergym/__init__.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import random
3 | import lxml.html
4 |
5 |
6 | def postprocess_for_browsergym(action, uid_map=None):
7 |     # Work on a copy so the caller's dict is never mutated.
8 |     action = deepcopy(action)
9 |     uid_map = {} if uid_map is None else uid_map
10 |
11 |     # If uid is an int, convert it to a string before looking it up in uid_map.
12 |     if "uid" in action:
13 |         action["uid"] = str(action["uid"])
14 |         if action["uid"] in uid_map:
15 |             action["uid"] = uid_map[action["uid"]]
16 |
17 |     # Scroll actions may omit x or y; default missing offsets to 0.
18 |     if action["intent"] == "scroll":
19 |         action.setdefault("x", 0)
20 |         action.setdefault("y", 0)
21 |
22 |     return action
23 |
24 |
25 | def generate_uuid(old_attr_name):
26 | # We do not use old_attr_name here, but it is required by the signature of the function.
27 | def replace_char(c):
28 | r = random.randint(0, 15)
29 | v = r if c == "x" else (r & 0x3 | 0x8)
30 | return format(v, "x")
31 |
32 | uuid_template = "xxxxxxxx-xxxx-4xxx"
33 | return "".join(replace_char(c) if c in "xy" else c for c in uuid_template)
34 |
35 |
36 | def reverse_dict(mapping):
37 | return {v: k for k, v in mapping.items()}
38 |
39 | def replace_bid_with_wl_uid(
40 | dom_str,
41 | new_attr_name="data-webtasks-id",
42 | old_attr_name="bid",
43 | generate_fn=generate_uuid,
44 | return_mapping=False,
45 | ):
46 | """
47 |     Replaces the `bid` attributes in the DOM string with a new attribute name and a new unique id.
48 |
49 |     generate_fn must be a function that takes the old attribute name and returns a new unique id. If return_mapping is True, the old-to-new id mapping is also returned.
50 | """
51 | html_parsed = lxml.html.fromstring(dom_str)
52 |
53 | new_attr_mapping = {
54 | str(elem.get(old_attr_name)): generate_fn(old_attr_name)
55 | for elem in html_parsed.xpath(f"//*[@{old_attr_name}]")
56 | if elem.get(old_attr_name) is not None
57 | }
58 |
59 |     # remap the attributes from old_attr_name="key" to new_attr_name="value"
60 |     for elem in html_parsed.xpath(f"//*[@{old_attr_name}]"):
61 | elem.set(new_attr_name, new_attr_mapping[elem.get(old_attr_name)])
62 | elem.attrib.pop(old_attr_name)
63 |
64 | html_processed_str = lxml.html.tostring(html_parsed).decode("utf-8")
65 |
66 | if return_mapping:
67 | return html_processed_str, new_attr_mapping
68 | else:
69 | return html_processed_str
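70 |
71 | # The block below is an illustrative round-trip sketch added for clarity; it is not
72 | # part of the original module and only reflects an assumed usage: swap `bid`
73 | # attributes for weblinx-style ids before calling the model, then map a predicted
74 | # uid back to a BrowserGym bid when post-processing the action.
75 | if __name__ == "__main__":
76 |     dom_str = '<html><body><button bid="12">Search</button></body></html>'
77 |     html, mapping = replace_bid_with_wl_uid(dom_str, return_mapping=True)
78 |     # `mapping` goes from old bids to new ids, so reverse it before post-processing.
79 |     predicted = {"intent": "click", "uid": next(iter(mapping.values()))}
80 |     action = postprocess_for_browsergym(predicted, uid_map=reverse_dict(mapping))
81 |     assert action["uid"] == "12"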
--------------------------------------------------------------------------------
/webllama/experimental/integrations/browsergym/functions.py:
--------------------------------------------------------------------------------
1 | from browsergym.core.action.utils import get_elem_by_bid
2 | import playwright.sync_api
3 |
4 | page: playwright.sync_api.Page = None
5 | send_message_to_user: callable = None
6 |
7 | # Define your actions here. The `page` and `send_message_to_user` placeholders above are expected to be set by the caller (e.g., BrowserGym) before these functions run.
8 |
9 | def say(utterance: str, *args, **kwargs):
10 | """
11 | Sends a message to the user.
12 |
13 | Examples:
14 | say("Based on the results of my search, the city was built in 1751.")
15 | """
16 | send_message_to_user(utterance)
17 |
18 |
19 | def click(uid: str, *args, **kwargs):
20 | """
21 | Click an element.
22 |
23 | Examples:
24 | click('51')
25 | """
26 | elem = get_elem_by_bid(page, uid)
27 | elem.click()
28 |
29 | def textinput(uid: str, value: str, *args, **kwargs):
30 | """
31 | Fill out a form field. It focuses the element and triggers an input event with the entered text.
32 | It works for ,