--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
5 | 2D Positional Embeddings for Webpage Structural Understanding 🦙👀
6 |
11 | # llama2d
12 | How can we get LLM-based agents to understand the *visual structure* of a webpage? We fine-tune Llama on OCR'd screenshots of webpages, augmented with 2D positional embeddings, so it can "see" the layout of a page rather than just a flat sequence of tokens.
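The embedding change itself lives in the patched `transformers` submodule installed during setup. As a rough, illustrative sketch (the names and details below are not the fork's actual code), the idea is to add a learned projection of each token's normalized (x, y) coordinate to its token embedding:

```py
import torch
import torch.nn as nn

class Pos2DEmbedding(nn.Module):
    """Illustrative 2D positional embedding: project normalized (x, y) into the hidden size."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(2, hidden_size)

    def forward(self, token_embeds: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
        # token_embeds: (batch, seq, hidden); coords: (batch, seq, 2) in [0, 1].
        # Tokens with no on-screen position (prompt/completion) carry negative coords.
        on_screen = (coords[..., :1] >= 0).to(token_embeds.dtype)
        return token_embeds + on_screen * self.proj(coords.clamp(min=0.0).to(token_embeds.dtype))
```

Prompt and completion tokens have no screen position, so they pass through unchanged.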
13 |
14 | To construct the dataset, we:
15 | - took each MHTML provided by Mind2Web
16 | - rendered it in Playwright
17 | - tagged interactable elements
18 | - ran OCR to get (x, y) coordinates of words on the page
19 |
20 | We then calculate 2D positional embeddings for each word and fine-tune Llama!
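For reference, here is a minimal sketch of the OCR step using Google Cloud Vision (the real pipeline lives in `llama2d.vision`; the credential path and the center-point normalization below are assumptions, not the exact repo code):

```py
import os
from google.cloud import vision

# assumes the credential file described in the Secrets section below
os.environ.setdefault("GOOGLE_APPLICATION_CREDENTIALS", "secrets/gcp-vision.json")

client = vision.ImageAnnotatorClient()
with open("screenshot.png", "rb") as f:
    image = vision.Image(content=f.read())
response = client.document_text_detection(image=image)

words = []  # (text, (x, y)) with coordinates normalized to [0, 1]
for page in response.full_text_annotation.pages:
    for block in page.blocks:
        for paragraph in block.paragraphs:
            for word in paragraph.words:
                text = "".join(s.text for s in word.symbols)
                xs = [v.x for v in word.bounding_box.vertices]
                ys = [v.y for v in word.bounding_box.vertices]
                center = (
                    (min(xs) + max(xs)) / 2 / page.width,
                    (min(ys) + max(ys)) / 2 / page.height,
                )
                words.append((text, center))
```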
21 |
22 | Note: this repo is still a bit disorganized and a work in progress, but we encourage community contributions & forks to explore this direction in LLM web interaction.
23 |
24 | ## Setup
25 |
26 | ```bash
27 | git clone https://github.com/llama2d/llama2d.git --recursive
28 | cd transformers && pip install -e . && cd ..
29 | pip install -r requirements.txt
30 | playwright install
31 | pre-commit install
32 | ```
33 |
34 | ## Secrets
35 |
36 | 1. Create a Google Cloud Vision credential file and put it at `secrets/gcp-vision.json`.
37 |
38 | 2. Run the Modal login command posted in the Slack channel. It looks like this: `modal token set --token-id <token-id> --token-secret <token-secret>`
39 |
40 | ## Datasets
41 |
42 | Datasets are defined in the `src/llama2d/datasets/` directory.
43 |
44 | Every row of a dataset is defined by a prompt, a 2D "screen", and an output.
45 |
46 | However, a row is converted into pure tokens before being fed into Llama - see [this dataset]() for an example.
47 |
48 | You can visualize a dataset on Huggingface by copying all the numbers in a row and pasting them into [this webpage]().
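Concretely, a published row is four aligned sequences. The sketch below shows their shape (the values are made up, and the exact masking/padding conventions are inferred from the visualizer in `docs/`):

```py
row = {
    "input_ids":      [1, 887, 526, 263, ...],            # Llama token ids: prompt + screen + completion
    "coords":         [[-1.0, -1.0], [0.31, 0.05], ...],  # normalized (x, y) per token; negative for off-screen tokens
    "labels":         [-100, -100, ..., 24340],           # masked (non-positive) everywhere except the completion
    "attention_mask": [1, 1, 1, ..., 0],                  # 1 for real tokens, 0 for padding
}
```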
49 |
50 | ### Synthetic datasets
51 |
52 | We will have lots of synthetic datasets, e.g. the Zoo Compass dataset defined in `src/llama2d/datasets/synthetic/zoo_compass.py`.
53 |
54 | These datasets are simple. They each spit out a bunch of rows with `prompt: str`, `screen: Llama2dScreen`, and `output: str`.
55 |
56 | It is easy to create a `Llama2dScreen`:
57 |
58 | ```py
59 | from llama2d.vision import Llama2dScreen
60 |
61 | screen = Llama2dScreen()
62 |
63 | screen.push_word(word="north",xy=(0.5,0))
64 | screen.push_word(word="south",xy=(0.5,1))
65 | screen.push_word(word="east",xy=(1,0.5))
66 | screen.push_word(word="west",xy=(0,0.5))
67 | ```
68 |
69 | To create this dataset, look at it in your console, and publish it to Huggingface, run the following:
70 |
71 | ```bash
72 | python -m llama2d.datasets.synthetic.zoo_compass
73 | ```
74 |
75 | I recommend reading the Zoo Compass dataset code for reference.
76 |
77 | ### Pretraining dataset
78 |
79 | This dataset contains over 600 retail websites. The task is next-token prediction.
80 |
81 | Here, the prompt and output are empty. The website text is all in the screen.
82 |
83 | The model is trained to predict the next token of the website text. It is NOT trained to predict the position of the next token.
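In other words, coordinates only enter the model on the input side; the training objective is ordinary next-token cross-entropy over the text. A sketch (not the fork's exact code; the `-100` ignore value is an assumption):

```py
import torch
import torch.nn.functional as F

def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Standard next-token loss: coords are extra inputs, never targets."""
    shift_logits = logits[:, :-1, :].contiguous()  # predictions for positions 0..n-2
    shift_labels = labels[:, 1:].contiguous()      # targets are the next tokens
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,  # assumed mask value for non-target positions
    )
```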
84 |
85 | This dataset is implemented in [`src/llama2d/datasets/pretraining.py`](https://github.com/Llama2D/llama2d/blob/main/src/llama2d/datasets/pretraining.py).
86 |
87 | To collect this dataset and upload it to Huggingface, run the file:
88 |
89 | ```bash
90 | python -m src.llama2d.datasets.pretraining
91 | ```
92 |
93 | ### Mind2Web dataset
94 |
95 | This dataset contains ~1000 tasks from the Mind2Web dataset.
96 |
97 | Given an intention and a screenshot of a webpage, the task is to choose the correct action to take.
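For intuition, here is roughly what one training pair looks like. The framing mirrors `src/llama2d/datasets/mind2web.py`, but the goal text and tag number here are invented:

```py
prompt = (
    'You are a bot using a website. Your goal is: "Find one-way flights from New York to Toronto."\n'
    "The website looks like so:"
)
# ...the OCR'd page text, with interactable elements tagged [0], [1], ..., is supplied as the 2D screen...
completion = "CLICK [42]"
```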
98 |
99 | To download this dataset, first download the Mind2Web `mhtml` files generated by Andrew Stelmach.
100 |
101 | The zip with the files is [here](https://drive.google.com/file/d/1RGNcNTlQrZhF1KuGBcGenkON1u74_IYx/view). Download it and unzip it into `src/data/mind2web-mhtml`. Your `src/data/mind2web-mhtml` directory should look like this:
102 |
103 | ```
104 | src/data/mind2web-mhtml
105 | ├── 0004f2a7-90d6-4f96-902a-b1d25d39a93d_before.mhtml
106 | ├── 00068a1e-b6a3-4c53-a60c-3ed777d4b05d_before.mhtml
107 | ├── 00146964-4b74-4e28-8292-5810a604639a_before.mhtml
108 | ├── 0018120a-8da1-4a36-a1c4-b4642c97211b_before.mhtml
109 | ```
110 |
111 | To process and cache the Mind2Web dataset, run the following:
112 |
113 | ```bash
114 | python -m llama2d.datasets.mind2web
115 | ```
116 |
117 | ## Modal training
118 |
119 | To train a model with Modal, change your directory to `src/llama2d/modal/` and run, e.g.:
120 |
121 | ```bash
122 | modal run train.py --dataset hf_dataset.py --repo llama2d/llama2d-mind2web --no-peft --num-epochs 4
123 | ```
124 |
125 | Here, `peft` means LoRA, so `--no-peft` runs full fine-tuning without LoRA adapters. `hf_dataset` means we are using a dataset uploaded to Huggingface (thanks Matthew!). [`llama2d/llama2d-mind2web`](https://huggingface.co/datasets/llama2d/llama2d-mind2web/viewer/default/train?row=0) is the Huggingface repo containing the dataset.
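Once a dataset is published, it can be pulled back down as a PyTorch dataset via `HuggingFaceDataset` from `src/llama2d/datasets/huggingface.py`:

```py
from llama2d.datasets.huggingface import HuggingFaceDataset

train = HuggingFaceDataset(repo="llama2d/llama2d-mind2web", split="train")
row = train[0]  # dict of tensors: input_ids, coords, labels, attention_mask
```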
126 |
127 | ## In the Repo
128 |
129 | To add a requirement, add it to `requirements.in`, run `pip-compile`, and run `pip-sync`.
130 |
131 | Run `black . --exclude '/transformers/|/venv/'` to format the code.
132 |
133 | Pre-commit hooks are used to maintain code quality.
134 |
135 | ## Citations
136 |
137 | ```bibtex
139 | @misc{llama2d2024,
140 | title = {Llama2D: Two Dimensional Positional Embeddings for Webpage Structural Understanding},
141 | author = {Houjun Liu and Andrew Healey and Andrew Stelmach and Christopher Settles and Sarma Tangirala and Rohan Pandey},
142 | year = {2024},
143 | howpublished = {GitHub},
144 | url = {https://github.com/llama2d/llama2d}
145 | }
146 | ```
147 |
--------------------------------------------------------------------------------
/docs/index.js:
--------------------------------------------------------------------------------
1 | document.body.onload = ()=>{
2 |
3 | console.log("hey")
4 |
5 | // listen for updates to the textarea
6 | // when it updates, extract [input_ids,coords,labels,attention_mask] from the textarea
7 | const textarea = document.querySelector('textarea');
8 |
9 | textarea.addEventListener('input', function () {
10 | render();
11 | });
12 |
13 | const canvas = document.getElementById('rendered-output');
14 | const ctx = canvas.getContext('2d');
15 |
16 | window.render = ()=>{
17 | const text = textarea.value;
18 | // split text into newlines, parse each as JSON
19 | const lines = text.split('\n').filter(line=>line.trim().length>0);
20 | const [tokenIds, coords, labels, attentionMask] = lines.map(JSON.parse);
21 |
22 | const lastIdx = tokenIds.findLastIndex(i=>i>0)
23 | const firstIdxLastChunk = tokenIds.slice(0,lastIdx).findLastIndex(i=>i<=0)+1
24 |
25 | // console.log(lastIdx,firstIdxLastChunk)
26 | // console.log(labelIds.slice(firstIdxLastChunk,lastIdx+1))
27 | // console.log(llamaTokenizer.decode([0,...tokenIds.slice(0,firstIdxLastChunk)]))
28 |
29 | const prompt = llamaTokenizer.decode([0,...tokenIds.slice(0,firstIdxLastChunk)])
30 |
31 | const completion = llamaTokenizer.decode([0,...tokenIds.slice(firstIdxLastChunk,lastIdx+1)])
32 |
33 | const coordTokens = coords.map(([x,y],i)=>[x,y,tokenIds[i]]).filter(([x,y,tokenid])=>x>=0);
34 |
35 | /*
36 | python impl:
37 | # graph tokens with coords in a matplotlib figure
38 | # print the tokens without coords
39 |
40 | # every word has a few tokens with the same coord.
41 | # we should generate the word, turn it into a string, then plot it at the coord
42 |
43 | without_coords = [input_ids[i] for i in range(len(input_ids)) if coords[i][0] == -1 and attention_mask[i] == 1]
44 |
45 | with_coords = [(input_ids[i],coords[i]) for i in range(len(input_ids)) if coords[i][0] != -1 and attention_mask[i] == 1]
46 | # split with_coords into words - where a word is a list of tokens with the same coord
47 | words = []
48 | current_word = []
49 | current_coord = None
50 | for token in with_coords:
51 | if current_coord is None or (token[1] != current_coord).any():
52 | if len(current_word) > 0:
53 | words.append(current_word)
54 | current_word = []
55 | current_coord = token[1]
56 | current_word.append(token)
57 | words.append(current_word)
58 |
59 |
60 | # plot with_coords as text on a matplotlib figure
61 |
62 | fig = plt.figure()
63 | # make fig very big
64 | fig.set_size_inches(20,20)
65 |
66 | ax = fig.add_subplot(111)
67 | ax.set_xlim([0,1])
68 | ax.set_ylim([0,1])
69 | ax.set_aspect('equal')
70 |
71 | for word in words:
72 | word_str = "".join(tokenizer.convert_ids_to_tokens([i[0] for i in word]))
73 | word_coord = word[0][1]
74 | # very small text
75 | ax.text(word_coord[0],-word_coord[1],word_str,fontsize=10)
76 |
77 | # save the figure
78 | fig.savefig("tokens_with_coords.png")
79 |
80 | */
81 |
82 | const words = coordTokens.reduce((acc,[x,y,tokenid])=>{
83 | if(acc.length === 0 || acc[acc.length-1].length === 0 || acc[acc.length-1][0][0] !== x || acc[acc.length-1][0][1] !== y){
84 | acc.push([])
85 | }
86 | acc[acc.length-1].push([x,y,tokenid])
87 | return acc
88 | },[])
89 |
90 | const wordStrings = words.map(word=>llamaTokenizer.decode([0,...word.map(([x,y,tokenid])=>tokenid)]))
91 |
92 | const wordCoords = words.map(word=>word[0].slice(0,2))
93 |
94 | // clear canvas, map onto canvas
95 | ctx.clearRect(0, 0, canvas.width, canvas.height);
96 | ctx.textAlign = "center";
97 | ctx.font = '10px monospace';
98 |
99 | const canvasCoords = wordCoords.map(([x,y])=>[x*canvas.width,(1-y)*canvas.height])
100 | wordCoords.forEach(([x,y],i)=>{
101 | const wordString = wordStrings[i];
102 | ctx.fillStyle = wordString.match(/^\[\d+\]/) ? 'red' : 'black';
103 | ctx.fillText(wordStrings[i],canvasCoords[i][0],canvasCoords[i][1])
104 | })
105 |
106 | // paste non-coord tokens into the pre
107 | // the first line is the prompt
108 | // the second line is the completion
109 | // find prompt vs. completion using firstIdxLastChunk
110 |
111 |         const promptTokens = tokenIds.map((tokenId,i)=>[tokenId,coords[i][0],i<firstIdxLastChunk && coords[i][0]<0]).filter(([_,__,b])=>b).map(([tokenId,x])=>tokenId)
112 | const completionTokens = [0,...tokenIds.slice(firstIdxLastChunk,lastIdx+1).filter(i=>i>0)];
113 |
114 | const promptString = llamaTokenizer.decode(promptTokens);
115 | const completionString = llamaTokenizer.decode(completionTokens);
116 |
117 | const output = document.getElementById('output');
118 | output.innerText = promptString + '\n' + completionString;
119 |
120 | console.log(llamaTokenizer.decode(tokenIds))
121 | }
122 |
123 | setTimeout(render, 500);
124 | }
--------------------------------------------------------------------------------
/llama2d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Llama2D/llama2d/e28b97255d396c717fe183b96b802ff39ffd7e6d/llama2d.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 88
3 | target-version = ['py311']
4 | extend-exclude = '''
5 | /(
6 |     \.venv
7 |   | venv
8 |   | \.git
9 |   | build
10 |   | alembic
11 |   | transformers
12 | )/'''
13 |
14 | [build-system]
15 | requires = ["setuptools"]
16 | build-backend = "setuptools.build_meta"
17 |
18 | [tool.isort]
19 | profile = "black"
20 | multi_line_output = 3
21 | include_trailing_comma = true
22 |
--------------------------------------------------------------------------------
/requirements.in:
--------------------------------------------------------------------------------
1 | huggingface_hub[cli,torch]
2 | datasets
4 | langchain
5 |
6 | wandb
7 | matplotlib
8 | playwright
9 | selenium
10 |
11 | google-cloud-vision
12 | Pillow
13 | modal
14 |
15 | faiss-cpu
16 | sentencepiece
17 |
18 | torch
19 | nest-asyncio
20 | gdown
21 | peft
22 | pre-commit
23 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | #
2 | # This file is autogenerated by pip-compile with Python 3.11
3 | # by the following command:
4 | #
5 | # pip-compile
6 | #
7 | accelerate==0.23.0
8 | # via peft
9 | aiohttp==3.8.5
10 | # via
11 | # datasets
12 | # fsspec
13 | # langchain
14 | # modal
15 | aiosignal==1.3.1
16 | # via aiohttp
17 | aiostream==0.5.0
18 | # via modal
19 | annotated-types==0.5.0
20 | # via pydantic
21 | anyio==3.7.1
22 | # via
23 | # fastapi
24 | # langchain
25 | # starlette
26 | # watchfiles
27 | appdirs==1.4.4
28 | # via wandb
29 | asgiref==3.7.2
30 | # via modal
31 | async-timeout==4.0.3
32 | # via aiohttp
33 | attrs==23.1.0
34 | # via
35 | # aiohttp
36 | # outcome
37 | # sigtools
38 | # trio
39 | beautifulsoup4==4.12.2
40 | # via gdown
41 | cachetools==5.3.1
42 | # via google-auth
43 | certifi==2023.7.22
44 | # via
45 | # modal
46 | # requests
47 | # selenium
48 | # sentry-sdk
49 | cfgv==3.4.0
50 | # via pre-commit
51 | charset-normalizer==3.2.0
52 | # via
53 | # aiohttp
54 | # requests
55 | click==8.1.7
56 | # via
57 | # modal
58 | # typer
59 | # wandb
60 | cloudpickle==2.2.1
61 | # via modal
62 | contourpy==1.1.1
63 | # via matplotlib
64 | cycler==0.11.0
65 | # via matplotlib
66 | dataclasses-json==0.6.1
67 | # via langchain
68 | datasets==2.14.5
69 | # via -r requirements.in
70 | dill==0.3.7
71 | # via
72 | # datasets
73 | # multiprocess
74 | distlib==0.3.7
75 | # via virtualenv
76 | docker-pycreds==0.4.0
77 | # via wandb
78 | faiss-cpu==1.7.4
79 | # via -r requirements.in
80 | fastapi==0.103.1
81 | # via modal
82 | filelock==3.12.4
83 | # via
84 | # gdown
85 | # huggingface-hub
86 | # torch
87 | # transformers
88 | # virtualenv
89 | fonttools==4.42.1
90 | # via matplotlib
91 | frozenlist==1.4.0
92 | # via
93 | # aiohttp
94 | # aiosignal
95 | fsspec[http]==2023.6.0
96 | # via
97 | # datasets
98 | # huggingface-hub
99 | gdown==4.7.1
100 | # via -r requirements.in
101 | gitdb==4.0.10
102 | # via gitpython
103 | gitpython==3.1.37
104 | # via wandb
105 | google-api-core[grpc]==2.12.0
106 | # via google-cloud-vision
107 | google-auth==2.23.1
108 | # via google-api-core
109 | google-cloud-vision==3.4.4
110 | # via -r requirements.in
111 | googleapis-common-protos==1.60.0
112 | # via
113 | # google-api-core
114 | # grpcio-status
115 | greenlet==2.0.2
116 | # via playwright
117 | grpcio==1.58.0
118 | # via
119 | # google-api-core
120 | # grpcio-status
121 | grpcio-status==1.58.0
122 | # via google-api-core
123 | grpclib==0.4.3
124 | # via modal
125 | h11==0.14.0
126 | # via wsproto
127 | h2==4.1.0
128 | # via grpclib
129 | hpack==4.0.0
130 | # via h2
131 | huggingface-hub[cli,torch]==0.17.3
132 | # via
133 | # -r requirements.in
134 | # accelerate
135 | # datasets
136 | # transformers
137 | hyperframe==6.0.1
138 | # via h2
139 | identify==2.5.29
140 | # via pre-commit
141 | idna==3.4
142 | # via
143 | # anyio
144 | # requests
145 | # trio
146 | # yarl
147 | importlib-metadata==6.8.0
148 | # via modal
149 | inquirerpy==0.3.4
150 | # via huggingface-hub
151 | jinja2==3.1.2
152 | # via torch
153 | jsonpatch==1.33
154 | # via langchain
155 | jsonpointer==2.4
156 | # via jsonpatch
157 | kiwisolver==1.4.5
158 | # via matplotlib
159 | langchain==0.0.304
160 | # via -r requirements.in
161 | langsmith==0.0.41
162 | # via langchain
163 | markdown-it-py==3.0.0
164 | # via rich
165 | markupsafe==2.1.3
166 | # via jinja2
167 | marshmallow==3.20.1
168 | # via dataclasses-json
169 | matplotlib==3.8.0
170 | # via -r requirements.in
171 | mdurl==0.1.2
172 | # via markdown-it-py
173 | modal==0.53.3665
174 | # via -r requirements.in
175 | mpmath==1.3.0
176 | # via sympy
177 | multidict==6.0.4
178 | # via
179 | # aiohttp
180 | # grpclib
181 | # yarl
182 | multiprocess==0.70.15
183 | # via datasets
184 | mypy-extensions==1.0.0
185 | # via typing-inspect
186 | nest-asyncio==1.5.8
187 | # via -r requirements.in
188 | networkx==3.1
189 | # via torch
190 | nodeenv==1.8.0
191 | # via pre-commit
192 | numexpr==2.8.7
193 | # via langchain
194 | numpy==1.26.0
195 | # via
196 | # accelerate
197 | # contourpy
198 | # datasets
199 | # langchain
200 | # matplotlib
201 | # numexpr
202 | # pandas
203 | # peft
204 | # pyarrow
205 | # transformers
206 | outcome==1.2.0
207 | # via trio
208 | packaging==23.1
209 | # via
210 | # accelerate
211 | # datasets
212 | # huggingface-hub
213 | # marshmallow
214 | # matplotlib
215 | # peft
216 | # transformers
217 | pandas==2.1.1
218 | # via datasets
219 | pathtools==0.1.2
220 | # via wandb
221 | peft==0.5.0
222 | # via -r requirements.in
223 | pfzy==0.3.4
224 | # via inquirerpy
225 | pillow==10.0.1
226 | # via
227 | # -r requirements.in
228 | # matplotlib
229 | platformdirs==3.10.0
230 | # via virtualenv
231 | playwright==1.38.0
232 | # via -r requirements.in
233 | pre-commit==3.4.0
234 | # via -r requirements.in
235 | prompt-toolkit==3.0.39
236 | # via inquirerpy
237 | proto-plus==1.22.3
238 | # via google-cloud-vision
239 | protobuf==4.24.3
240 | # via
241 | # google-api-core
242 | # google-cloud-vision
243 | # googleapis-common-protos
244 | # grpcio-status
245 | # modal
246 | # proto-plus
247 | # wandb
248 | psutil==5.9.5
249 | # via
250 | # accelerate
251 | # peft
252 | # wandb
253 | pyarrow==13.0.0
254 | # via datasets
255 | pyasn1==0.5.0
256 | # via
257 | # pyasn1-modules
258 | # rsa
259 | pyasn1-modules==0.3.0
260 | # via google-auth
261 | pydantic==2.4.1
262 | # via
263 | # fastapi
264 | # langchain
265 | # langsmith
266 | pydantic-core==2.10.1
267 | # via pydantic
268 | pyee==9.0.4
269 | # via playwright
270 | pygments==2.16.1
271 | # via rich
272 | pyparsing==3.1.1
273 | # via matplotlib
274 | pysocks==1.7.1
275 | # via
276 | # requests
277 | # urllib3
278 | python-dateutil==2.8.2
279 | # via
280 | # matplotlib
281 | # pandas
282 | pytz==2023.3.post1
283 | # via pandas
284 | pyyaml==6.0.1
285 | # via
286 | # accelerate
287 | # datasets
288 | # huggingface-hub
289 | # langchain
290 | # peft
291 | # pre-commit
292 | # transformers
293 | # wandb
294 | regex==2023.8.8
295 | # via transformers
296 | requests[socks]==2.31.0
297 | # via
298 | # datasets
299 | # fsspec
300 | # gdown
301 | # google-api-core
302 | # huggingface-hub
303 | # langchain
304 | # langsmith
305 | # transformers
306 | # wandb
307 | rich==13.5.3
308 | # via modal
309 | rsa==4.9
310 | # via google-auth
311 | safetensors==0.3.3
312 | # via
313 | # peft
314 | # transformers
315 | selenium==4.13.0
316 | # via -r requirements.in
317 | sentencepiece==0.1.99
318 | # via -r requirements.in
319 | sentry-sdk==1.31.0
320 | # via wandb
321 | setproctitle==1.3.2
322 | # via wandb
323 | sigtools==4.0.1
324 | # via synchronicity
325 | six==1.16.0
326 | # via
327 | # docker-pycreds
328 | # gdown
329 | # python-dateutil
330 | smmap==5.0.1
331 | # via gitdb
332 | sniffio==1.3.0
333 | # via
334 | # anyio
335 | # trio
336 | sortedcontainers==2.4.0
337 | # via trio
338 | soupsieve==2.5
339 | # via beautifulsoup4
340 | sqlalchemy==2.0.21
341 | # via langchain
342 | starlette==0.27.0
343 | # via fastapi
344 | sympy==1.12
345 | # via torch
346 | synchronicity==0.5.3
347 | # via modal
348 | tblib==2.0.0
349 | # via modal
350 | tenacity==8.2.3
351 | # via langchain
352 | tokenizers==0.13.3
353 | # via transformers
354 | toml==0.10.2
355 | # via modal
356 | torch==2.0.1
357 | # via
358 | # -r requirements.in
359 | # accelerate
360 | # huggingface-hub
361 | # peft
362 | tqdm==4.66.1
363 | # via
364 | # datasets
365 | # gdown
366 | # huggingface-hub
367 | # peft
368 | # transformers
369 | transformers==4.33.3
370 | # via peft
371 | trio==0.22.2
372 | # via
373 | # selenium
374 | # trio-websocket
375 | trio-websocket==0.11.1
376 | # via selenium
377 | typer==0.9.0
378 | # via modal
379 | types-certifi==2021.10.8.3
380 | # via modal
381 | types-toml==0.10.8.7
382 | # via modal
383 | typing-extensions==4.8.0
384 | # via
385 | # aiostream
386 | # fastapi
387 | # huggingface-hub
388 | # modal
389 | # pydantic
390 | # pydantic-core
391 | # pyee
392 | # sqlalchemy
393 | # torch
394 | # typer
395 | # typing-inspect
396 | typing-inspect==0.9.0
397 | # via dataclasses-json
398 | tzdata==2023.3
399 | # via pandas
400 | urllib3[socks]==2.0.5
401 | # via
402 | # google-auth
403 | # requests
404 | # selenium
405 | # sentry-sdk
406 | virtualenv==20.24.5
407 | # via pre-commit
408 | wandb==0.15.11
409 | # via -r requirements.in
410 | watchfiles==0.20.0
411 | # via modal
412 | wcwidth==0.2.6
413 | # via prompt-toolkit
414 | wsproto==1.2.0
415 | # via trio-websocket
416 | xxhash==3.3.0
417 | # via datasets
418 | yarl==1.9.2
419 | # via aiohttp
420 | zipp==3.17.0
421 | # via importlib-metadata
422 |
423 | # The following packages are considered to be unsafe in a requirements file:
424 | # setuptools
425 |
--------------------------------------------------------------------------------
/screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Llama2D/llama2d/e28b97255d396c717fe183b96b802ff39ffd7e6d/screenshot.png
--------------------------------------------------------------------------------
/script.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Check if ImageMagick is installed
4 | if ! command -v convert &> /dev/null; then
5 | echo "ImageMagick is not installed. Please install it before running this script."
6 | exit 1
7 | fi
8 |
9 | # Directory containing the images
10 | input_dir="."
11 |
12 | # Output GIF file name
13 | output_gif="output.gif"
14 |
15 | # Check if the input directory exists
16 | if [ ! -d "$input_dir" ]; then
17 | echo "Input directory not found: $input_dir"
18 | exit 1
19 | fi
20 |
21 | # Change to the input directory
22 | cd "$input_dir" || exit
23 |
24 | # Create the GIF from images 0.png through 8.png
25 | convert -delay 100 -loop 0 {0..8}.png "$output_gif"
26 |
27 | # Verify if the GIF creation was successful
28 | if [ $? -eq 0 ]; then
29 | echo "GIF file created successfully: $output_gif"
30 | else
31 | echo "Failed to create the GIF."
32 | fi
33 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | ROOT_DIR = Path(__file__).parent.parent.resolve()
4 | SRC_DIR = ROOT_DIR / "src"
5 |
--------------------------------------------------------------------------------
/src/data/.gitignore:
--------------------------------------------------------------------------------
1 | # hide all files in subdirectories
2 | */**/*
3 |
4 | # allow everything in the root directory
5 | !./*
--------------------------------------------------------------------------------
/src/data/mind2web_example.json:
--------------------------------------------------------------------------------
1 | {
2 | "pos_candidates": [
3 | {
4 | "attributes": "{\"backend_node_id\": \"136\", \"bounding_box_rect\": \"110,607.390625,264,78\", \"class\": \"MuiSelect-root MuiSelect-select jss31 MuiSelect-filled jss32 MuiInputBase-input MuiFilledInput-input jss22 MuiInputBase-inputAdornedStart MuiFilledInput-inputAdornedStart\", \"id\": \"reservations-city-search-type\", \"name\": \"type\", \"data_pw_testid_buckeye_candidate\": \"1\"}",
5 | "backend_node_id": "136",
6 | "is_original_target": true,
7 | "is_top_level_target": true,
8 | "tag": "select"
9 | }
10 | ]
11 | }
--------------------------------------------------------------------------------
/src/data/pretraining-cache/.gitignore:
--------------------------------------------------------------------------------
1 | **/*
2 | !.gitignore
--------------------------------------------------------------------------------
/src/data/pretraining_urls.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | current_dir = Path(__file__).parent
4 |
5 | with open(current_dir / "urls.txt", "r") as f:
6 | urls = f.read().splitlines()
7 |
--------------------------------------------------------------------------------
/src/data/urls.txt:
--------------------------------------------------------------------------------
1 | http://dannysellstampabay.com
2 | http://floridahomeswithcarrie.com
3 | https://www.tarynsellshouses.com/
4 | https://robin.homes/signup/alex.colley
5 | https://www.demirealestates.com/
6 | https://johngarcia.lpthomesearch.com/
7 | https://keepingfloridamoving.com/
8 | https://yeseniaalicea.lpthomesearch.com/
9 | https://diannephillips.lpthomesearch.com/
10 | https://angelwilkinson.lpthomesearch.com/
11 | https://loreanarealestate.com/
12 | https://www.judy-cortez.net/
13 | https://victorialegrow.lpthomesearch.com/
14 | https://tanyaveitch.lpthomesearch.com/
15 | http://therealgatorrealty.com
16 | https://neldagregory.lpthomesearch.com/
17 | http://jennysuncoast.com
18 | https://www.cendonrealtor.net/
19 | https://appletreerealty.com
20 | https://makersolutionsinc.com/
21 | http://www.thealsbrooksteam.com/
22 | https://www.jandjrealtygroup.com/
23 | https://www.beachpropertymanagerandsales.com/
24 | http://tampabayhomessearch.com
25 | http://www.sarasotasandy.com
26 | http://www.getrichhome.com
27 | https://robin.homes/signup/amy.tejeda
28 | https://robin.homes/signup/lisa.spencer2
29 | https://robin.homes/signup/azalia.vasquez
30 | https://www.hoppersellshomes.com/
31 | https://estradahomesales.com
32 | https://pier21realty.myrealestateplatform.com/
33 | http://wyserhomespreferred.com
34 | https://florencezimmerman.com/
35 | http://tampabayhomessearch.com
36 | http://therealrachelgoldman.com/
37 | https://robin.homes/signup/william.vergara
38 | https://clermontrealestate.rezora.com/listing/demomls/155497461/6
39 | https://www.lauraschulerrealtor.com/
40 | http://suewoodsellsleesburg.com
41 | https://joshuniversityrealty.com
42 | https://chelseacooper.lpthomesearch.com/
43 | https://larissafloridarealtor.com
44 | https://robin.homes/signup/stavrula.crafa
45 | https://www.themillsgroupkw.com/
46 | http://viewverobeachhomes.com
47 | http://www.sarasotasandy.com
48 | http://lauralynchristianhomes.com
49 | http://agnesrosehomes.com
50 | http://joycesellspalmcoast.com/
51 | http://venicefloridahomes.com
52 | http://jcsunrisehomes.com
53 | http://livinparadiserealty.com
54 | https://ritapellens.com/
55 | http://therealbrianwalsh.com
56 | http://www.everydayintampabay.com
57 | https://stefaniewargo.lpthomesearch.com/
58 | https://mariatroncoso.lpthomesearch.com
59 | https://axelrodriguez.lpthomesearch.com/
60 | https://buyandsellrealestateinflorida.com/
61 | https://angelacardona.lpthomesearch.com/
62 | https://trystanfoglia.lpthomesearch.com/
63 | https://www.mariacastrellonrealtor.net
64 | https://yaecolon.tampabayhome.com/
65 | https://zidniaayala.lpthomesearch.com/
66 | https://mattsellsbrevard.com/
67 | https://mirianriera1.lpthomesearch.com/
68 | https://lilianaoviedo.lpthomesearch.com/
69 | https://sandravargas.lpthomesearch.com/
70 | https://josephbarnes.lpthomesearch.com/
71 | https://loriwilson.lpthomesearch.com/
72 | https://sierrarealtyfl.com/
73 | http://katiuskaquinterorealtor.com/
74 | https://yourkeytoflorida.com/
75 | https://andreabishop.lpthomesearch.com/
76 | https://davianmedina.unrealty.com/
77 | https://closewithkhalid.net/
78 | https://edeliosanchez.expchime.com/
79 | https://matthewlester.lpthomesearch.com/
80 | https://swflhomesales.com/
81 | https://homesbyyessy.com/
82 | http://www.cevhomes.com/
83 | https://robin.homes/signup/daniel.hut
84 | https://robin.homes/signup/irene.guy
85 | http://brendaefffect.com
86 | https://robin.homes/signup/rachael.corry
87 | http://www.eddierealty.com/
88 | https://destination.myrealtyonegroup.com/
89 | http://www.nazarbasrealtor.com
90 | http://www.theresilienthomegroup.com
91 | https://robin.homes/signup/jeevan.hanuman
92 | http://www.thepensalagroup.com/
93 | http://www.gulfcoastintegritygroup.com/
94 | https://robin.homes/signup/alexis.willims
95 | https://robin.homes/signup/valentino.sanchez
96 | http://pgpcrealty.com
97 | http://buyandsellpolkhomes.com
98 | https://omarandreasen.com/
99 | https://rebeccareadusre.com
100 | http://www.nazarbasrealtor.com
101 | https://robin.homes/signup/chodry.andre1
102 | http://realty.com/siesta-key-fl
103 | http://goldenclassrealty.com
104 | http://www.homeasap.com/1564211
105 | https://robin.homes/signup/valentina.cappetta
106 | https://robin.homes/signup/justin.owens
107 | https://robin.homes/signup/tyler.beasley
108 | http://buymultifamilyinvestments.com
109 | https://robin.homes/signup/karla.joneswilson2
110 | https://robin.homes/signup/sara.schneider
111 | http://www.homeasap.com/354522
112 | https://www.jandjrealtygroup.com
113 | http://thadismyrealtor.com
114 | https://karladeleon.lpthomesearch.com/
115 | https://gomezhomegroup.com/
116 | https://robin.homes/signup/krystal.crichlow
117 | https://robin.homes/signup/ashley.vanpelt
118 | https://blakeesekie.lpthomesearch.com
119 | http://www.homeasap.com/1630578
120 | http://www.homeasap.com/94340
121 | https://robin.homes/signup/justin.owens
122 | https://www.jesvalrealestate.com/
123 | https://www.kaizen-realty.com/
124 | https://www.yourrealtyspecialist.com/
125 | https://www.livelocalre.com
126 | http://www.tjcosgrove.com/
127 | http://danjoproperties.com
128 | https://robin.homes/signup/anabely.delatorre1
129 | http://www.homeasap.com/161009
130 | https://robin.homes/signup/elena.sherstikova
131 | https://taquishamccluster.lpthomesearch.com
132 | https://ahsguarantee.com
133 | http://www.homeasap.com/1647338
134 | https://argeliavidal.com/
135 | https://tamiamirealtyllc.com
136 | https://robin.homes/signup/amy.brocco
137 | https://robin.homes/signup/pharah.dutrevil
138 | https://taylor-smalley.com/
139 | http://mermaizinghomes.com
140 | http://properties4saleinflorida.com
141 | http://hillsboroughcountyhomes4sale.co
142 | http://tampabayareahomesforsales.com
143 | https://robin.homes/signup/suzanne.dickson2
144 | https://robin.homes/signup/david.ponte
145 | https://robin.homes/signup/emily.kirshaw
146 | https://robin.homes/signup/azalia.vasquez
147 | http://www.homeasap.com/1047411
148 | https://hirethepirate.com
149 | http://lsteuberrealty.com
150 | https://www.srqareahomefinder.com
151 | http://www.bluesunrealty.com/
152 | https://robin.homes/signup/david.brown12
153 | http://jessicalipprealty.com
154 | https://robin.homes/signup/lori.moses
155 | https://robin.homes/signup/stacy.bracewell
156 | https://robin.homes/signup/dana.lincolnpa
157 | https://robin.homes/signup/maurice.johnson3
158 | https://robin.homes/signup/gracemary.guastella
159 | https://robin.homes/signup/angelo.marcello
160 | http://www.homeasap.com/1645736
161 | https://robin.homes/signup/julie.sbrocco
162 | https://robin.homes/signup/missy.mcamis
163 | https://robin.homes/signup/lissette.sanchez
164 | https://robin.homes/signup/lisa.kelly2
165 | https://robin.homes/signup/bobbie.robinson
166 | https://robin.homes/signup/denise.becker
167 | https://robin.homes/signup/ashley.cooley
168 | https://robin.homes/signup/brandon.johnson
169 | https://robin.homes/signup/khalid.inshan
170 | https://robin.homes/signup/milton.figueroa
171 | https://robin.homes/signup/laura.rodrigueztello
172 | https://robin.homes/signup/ha.benacquisto
173 | https://robin.homes/signup/don.latimer
174 | https://robin.homes/signup/francesca.wilson
175 | https://robin.homes/signup/lsalma.abdelaal
176 | http://viewveniceflhomes.com
177 | https://www.yourwayhome.net
178 | https://www.lakebrantleyhomes.com/
179 | https://danesacolon.lpthomesearch.com
180 | http://viewbradentonflhomes.com
181 | https://issaygonzalez.lpthomesearch.com/
182 | https://robin.homes/signup/elizabethdesiree.morales1
183 | https://yanirasuarez.lpthomesearch.com/
184 | http://tampabayareahomesforsale.com
185 | http://www.homeasap.com/1331654
186 | http://topleesburgrealestate.com
187 | http://jillwillsell.com/
188 | http://jennifer-sims.elevatesite.com
189 | https://unrealty.com/
190 | https://robin.homes/signup/norma.gonsalves
191 | https://robin.homes/signup/ted.moseley
192 | https://robin.homes/signup/kellie.birmingham
193 | http://www.homeasap.com/881410
194 | https://cynthiaporpora.lpthomesearch.com
195 | https://binghamrealtyinc.com/
196 | http://www.themeadowsteam.com/
197 | https://sage-chaja-eedcbf.netlify.app/
198 | https://sage-chaja-eedcbf.netlify.app/
199 | https://www.bhhsfloridarealty.com/
200 | https://searchpalmharbor.com
201 | https://tkc-platinum-properties.findme.homes/
202 | http://projectmyhomeflorida.com
203 | http://www.earlsellstampa.com
204 | http://paigeboothrealty.com
205 | http://www.homesweettampabay.com
206 | https://ruthiearchie.myhomehq.biz
207 | http://mykwgb.com
208 | https://championsgate.realtor
209 | http://saintpetersburghomesfl.com
210 | http://www.homeasap.com/640516
211 | http://tampa-homesforsale.com
212 | http://genevievesproperties.com
213 | https://robin.homes/signup/abraham.mendez1
214 | https://seminoleheightsliving.com/
215 | http://www.homeasap.com/632713
216 | http://www.blakerealestate.com/
217 | http://www.homeasap.com/984715
218 | http://www.homeasap.com/619950
219 | http://lizcarvalho.propertyportalmarketing.com
220 | http://thebucketlistteam.com
221 | https://scottbryant.lpthomesearch.com
222 | http://flynnsellsflorida.com
223 | http://www.homeasap.com/1645802
224 | http://www.homeasap.com/1645492
225 | http://www.homeasap.com/1643639
226 | http://www.thetampapropertyfinder.com/
227 | https://robin.homes/signup/chodry.andre1
228 | https://robin.homes/signup/william.dibernardo1
229 | https://robin.homes/signup/lisa.eichenblatt
230 | https://robin.homes/signup/jeannette.mcintosh
231 | http://www.homeasap.com/1646297
232 | http://southtampasweethome.com
233 | http://kathycongdonhomessold.com
234 | https://bricksfolios.inbestments.com/
235 | https://www.inspiredpropertiessrq.com
236 | https://robin.homes/signup/liliana.lassalle
237 | http://stpetetropical.net
238 | https://peoplearemypassion.com/
239 | https://www.lpsantos.com
240 | http://www.briannacapuano.com/
241 | https://neighborhood-professionals-101944491.remax.com
242 | https://local-expert-101937672.remax.com
243 | https://capital-realty-100430055.remax.com
244 | https://legacy-100430027.remax.com
245 | https://tropical-sands-100429845.remax.com
246 | https://domenicaaraguache.lpthomesearch.com/
247 | http://kbhomesrealty1.com
248 | http://drivingfloridahome.net
249 | https://robin.homes/signup/anabely.delatorre1
250 | https://robin.homes/signup/kristal.saladin
251 | http://davidhgoodii.com
252 | http://robyncavallaro.com/
253 | https://www.ezhomesearch.com
254 | https://bursonhomeadvisors.com
255 | http://www.residethrivetampa.com/
256 | https://www.TurnerPropertyMgmt.com
257 | https://mensnyoreste1.lpthomesearch.com/
258 | https://victordeleon.rogtampabay.com
259 | http://www.valeriayafferealtor.com
260 | https://doreenlandi.info/
261 | http://AmandaAlligoodsellsfl.com
262 | http://homesbydonnawilliams.com
263 | http://realty.com/lakewood-ranch-fl
264 | https://www.joenewstreet.com/
265 | https://robin.homes/signup/cecilia.cabrales
266 | http://franciscoromerorealtor.com/
267 | https://robin.homes/signup/danielle.kielpikowski1
268 | http://www.susanbenante.com
269 | https://hidalisnunez.lpthomesearch.com/
270 | https://greaterlakelandhomes.com/
271 | https://orlandoandbeyond.com/
272 | https://www.sarasotadreamlifestyle.com/
273 | https://elizabethcolon.tampabayhome.com
274 | https://davidnpacheco.lpthomesearch.com/
275 | https://charliesantos.expchime.com
276 | https://ernestoperez1.lpthomesearch.com/ComplianceCheck/active/586
277 | https://robin.homes/signup/lily.aymat
278 | https://robin.homes/signup/jeevan.hanuman
279 | https://robin.homes/signup/cheryl.burcham
280 | https://chrislarue.lpthomesearch.com/
281 | https://danielpaz.lpthomesearch.com/
282 | http://darbiepfeifferrealestate.com
283 | http://johnkelleyflhomes.com
284 | http://merlybuysandsells.com
285 | http://justinbrandonhomes.com
286 | http://www.homeasap.com/659994
287 | http://remaxassured.comandzinnoteam.com
288 | http://www.isellbabcockranch.com
289 | https://floridarealtortony.com
290 | https://staugustine.evrealestate.com/
291 | https://robin.homes/signup/michael.bellamy1
292 | https://robin.homes/signup/morgan.porter1
293 | https://robin.homes/signup/maria.tilton
294 | https://robin.homes/signup/tennille.moore1
295 | https://robin.homes/signup/bianca.pineda
296 | https://p-33d82e42-351e-4188-b36b-11ae45b6ac8c.presencepreview.site/
297 | https://www.cathyrunningrealtor.com/
298 | https://valerusre.com/
299 | http://buyorsellsouthwestfloridahomes.com
300 | http://buyeorselleastoralndohomes.com
301 | https://www.mvprealty.com/
302 | http://tourtampabayhomes.com
303 | http://ltrhomes.com
304 | https://vanderleelie.com
305 | https://janellepruitt.realtor
306 | https://karuna.realestate
307 | http://helensfloridahomes.com
308 | http://www.fivestarflorida.com/
309 | http://DeannaBradley.com
310 | http://integrity1stgroup.com/
311 | https://dallascrider.lpthomesearch.com/
312 | http://monopolygre.com
313 | https://www.mysahomes.net
314 | http://vanderleelie.com
315 | https://robin.homes/signup/gemma.peterson
316 | https://robin.homes/signup/alex.estevez
317 | https://bursonhomeadvisors.com
318 | https://www.elliman.com
319 | http://williamroganrealtor.com/
320 | http://www.homeasap.com/14523
321 | https://valeriemcinerney.sarasotarealestatehub.com/
322 | https://pinpointrealtyfl.com
323 | https://www.mcsellsmanatee.com/
324 | https://robin.homes/signup/paul.mcdonald
325 | http://www.homeasap.com/1637509
326 | https://danrojas.lpthomesearch.com
327 | https://reganpappas.com
328 | https://priscillaarzivian.lpthomesearch.com/
329 | http://www.sunwestrealtyflorida.com/
330 | https://www.gibbsgrouptampa.com/
331 | http://www.homeasap.com/1401384
332 | https://robin.homes/signup/nan.robinson
333 | http://thelaygroup.com/
334 | http://www.ivanaldearealty.com
335 | https://usa.premmedia.com/better_homes_and_gardens_flagler_beach/
336 | https://www.baywestrealtygroup.com
337 | https://robin.homes/signup/christine.nargi
338 | https://www.garyberkson.com/
339 | http://BillandGinger.com
340 | https://robin.homes/signup/evan.devorace
341 | https://gabrielhoyos.lpthomesearch.com.lpthomesearch.com
342 | https://robin.homes/signup/mariaelena.martinez
343 | https://robin.homes/signup/nicole.musgrave1
344 | https://robin.homes/signup/janet.mansfield
345 | https://robin.homes/signup/myriah.schifley
346 | https://robin.homes/signup/muneera.mohamed
347 | https://brysonwalters.findhomesintampaflorida.com/
348 | http://jeanjannrealty.com
349 | https://johson.com
350 | http://amysellsbrevard.com
351 | http://seacowobsessedagent.com
352 | https://valriggshomes.com/
353 | https://www.rbfloridahomes.com
354 | https://searchswflhomesforsale.com
355 | https://robin.homes/signup/ann.osullivan
356 | https://robin.homes/signup/mary.blinkhorn
357 | https://robin.homes/signup/debbie.snowden
358 | https://robin.homes/signup/brent.canevari
359 | https://jillanayas.com/
360 | https://robin.homes/signup/william.dibernardo
361 | http://www.foxxteam.com/
362 | http://www.homeasap.com/768207
363 | http://soldbyjosh.net
364 | https://westerberggroup.com/
365 | http://myfloridarealestateforyou.com
366 | https://stephanieeisenbach.lpthomesearch.com/
367 | https://davidcardona.lpthomesearch.com/
368 | https://samdiasrealestate.com
369 | https://monicadiazquiroz.lpthomesearch.com
370 | https://anthonyrussell.lpthomesearch.com/
371 | https://michaelatate.myrealestateplatform.com/
372 | https://robin.homes/signup/jeanna.jackson1
373 | https://robin.homes/signup/alana.ohanlan1
374 | https://marlolaney.lpthomesearch.com
375 | https://www.hoppersellshomes.com/
376 | https://www.anchorrealtorgroup.com/
377 | https://lisbethbetizagatti.lpthomesearch.com/
378 | https://victoriatejeda1.lpthomesearch.com/
379 | http://www.teambelmonte.com
380 | https://robin.homes/signup/dinorat.querales
381 | http://buysellliveorlando.com
382 | https://mandjphamdevelopment.com/
383 | https://gabriellechavez.lpthomesearch.com/
384 | https://shannonhartrealtor.com/
385 | http://www.jbricksrealty.com
386 | http://wrarealestate.com
387 | https://www.livingdilife.com
388 | https://robin.homes/signup/olga.sexson
389 | https://robin.homes/signup/lucia.yang
390 | https://paulapalomino1.lpthomesearch.com/
391 | https://www.succesrealtyco.net/
392 | https://janicerodriguez.lpthomesearch.com/
393 | https://johnarroyo.lpthomesearch.com/
394 | https://teamchristie.lpthomesearch.com
395 | https://robin.homes/signup/fanny.horn
396 | https://robin.homes/signup/kayla.durias
397 | http://searchvenicehomesfl.com
398 | https://michaelatate.myrealestateplatform.com/
399 | https://www.yournaturecoasthomesearch.com/
400 | https://www.yourhomegirlbeth.net/
401 | http://www.homeasap.com/1638771
402 | https://unrealtyflorida.unrealty.com
403 | http://jazzysellsflorida.com
404 | https://amandatheiler.chime.me
405 | https://timeisoftheessencewithtosha.estate
406 | https://jessicaalopaeus.lpthomesearch.com/
407 | https://theyinglingteam.com
408 | https://www.paradisegrpfl.com
409 | https://adamrobinson.lpthomesearch.com
410 | https://sarahsmith.lpthomesearch.com/
411 | https://robertcruz.lpthomesearch.com
412 | https://zerahcruz.lpthomesearch.com
413 | http://www.iselldelandflorida.com/
414 | https://www.jbsellshomes.com/
415 | http://www.trusshomessarasota.com/
416 | http://www.sarasotarealtor.com/
417 | https://fcrg.backagent.net/
418 | https://johnlstrauss.lpthomesearch.com/
419 | http://matikglobalrealty.com
420 | https://usa.premmedia.com/better_homes_and_gardens_new_smyrna_beach/
421 | http://www.1percentlistfl.com
422 | https://www.terraexcelsior.com/
423 | https://robin.homes/signup/silvia.mozer
424 | http://www.homeasap.com/1636913
425 | https://robin.homes/signup/daisy.gonzalez
426 | http://www.homeasap.com/1636606
427 | http://veronicawhittingtonhomes.com
428 | http://www.livingdilife.com
429 | https://themarkrameygroup.com
430 | https://robin.homes/signup/jeevan.hanuman
431 | http://www.nelsoncruzteam.com/
432 | https://thewilchergroup.com
433 | https://homeasap.com/856425
434 | https://robin.homes/signup/nathan.jacoby
435 | https://sandranaumovski.lpthomesearch.com/
436 | https://yetseniamtorres.lpthomesearch.com/
437 | https://robin.homes/signup/mark.langley
438 | https://www.caseytranrealestate.com/
439 | https://robin.homes/signup/leidy.lara
440 | http://emeraldrealtycofl.com
441 | http://sagegainesville.com/
442 | https://helloreeve.com
443 | https://windermereintrealty.com
444 | http://listwithpeteandcheryl.com
445 | http://janicesellsorlando.com
446 | http://zoyasellsflorida.com
447 | http://livingabundantlygroup.com
448 |
--------------------------------------------------------------------------------
/src/llama2d/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Llama2D/llama2d/e28b97255d396c717fe183b96b802ff39ffd7e6d/src/llama2d/__init__.py
--------------------------------------------------------------------------------
/src/llama2d/constants.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | ROOT_DIR = Path(__file__).parent.parent.parent.resolve()
4 |
5 | # viewport for screenshots: 1280 px wide, 3x the height of a 1080p screen
6 | SCREEN_RESOLUTION = (1280, 1080 * 3)
7 |
8 | DATA_DIR = ROOT_DIR / "data"
9 |
10 | MIND2WEB_MHTML_DIR = DATA_DIR / "mind2web-mhtml"
11 | MIND2WEB_HHTML_DIR = DATA_DIR / "mind2web-hhtml"
12 |
13 | MIND2WEB_OUT_DIR = DATA_DIR / "mind2web-out"
14 | MIND2WEB_IN_DIR = DATA_DIR / "mind2web-in"
15 | MIND2WEB_VIZ_DIR = DATA_DIR / "mind2web-viz"
16 |
17 | MIND2WEB_CACHE_DIR = DATA_DIR / "mind2web-cache"
18 | PRETRAINING_CACHE_DIR = DATA_DIR / "pretraining-cache"
19 |
20 | # path to the Google Cloud credentials file
21 | SECRETS_FILE = ROOT_DIR / "secrets" / "gcp-vision.json"
22 |
23 | # max number of tokens allowed in a page screenshot
24 | # we will remove all page tokens after this number
25 | MAX_PAGE_LEN = 1000
26 |
27 | # max number of tokens inputted to Llama2d - between prompt, page, and completion
28 | # we will truncate big inputs to this number
29 | # we will also pad small inputs to this number
30 | MAX_SEQ_LEN = 300
31 |
32 | MAX_TAGS_LEN = 150
33 |
--------------------------------------------------------------------------------
/src/llama2d/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Llama2D/llama2d/e28b97255d396c717fe183b96b802ff39ffd7e6d/src/llama2d/datasets/__init__.py
--------------------------------------------------------------------------------
/src/llama2d/datasets/cached.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 | from pathlib import Path
3 |
4 | import torch
5 | from torch.utils.data import Dataset
6 |
7 |
8 | def save_dataset(dataset, save_dir: Path):
9 | # make the directory if it doesn't exist
10 | save_dir.mkdir(parents=True, exist_ok=True)
11 |
12 | for i in range(len(dataset)):
13 | torch.save(dataset[i], save_dir / f"{i}.pt")
14 |
15 |
16 | class CachedDataset(Dataset):
17 | def __init__(self, load_dir, use_2d=True, keep_fraction=1.0):
18 | self.load_dir = load_dir
19 | self.files = sorted(glob(f"{load_dir}/*.pt"))
20 | self.use_2d = use_2d
21 | self.keep_fraction = keep_fraction
22 |
23 | def __getitem__(self, i):
24 | ret = torch.load(self.files[i])
25 | # if not self.use_2d:
26 | # return {k: v for k, v in ret.items() if k != "coords"}
27 | return {k: v.to(torch.bfloat16) if k == "coords" else v for k, v in ret.items()}
28 |
29 | def __len__(self):
30 | return int(len(self.files) * self.keep_fraction)
31 |
--------------------------------------------------------------------------------
/src/llama2d/datasets/huggingface.py:
--------------------------------------------------------------------------------
1 | import types
2 | from dataclasses import dataclass
3 | from time import time
4 |
5 | import numpy as np
6 | import torch
7 | from datasets import Dataset
8 | from torch.utils import data
9 |
10 | #
11 | from llama2d.datasets.cached import CachedDataset
12 |
13 |
14 | @dataclass
15 | class DatasetInfo:
16 | repo: str
17 | desc: str
18 |
19 |
20 | def dataset_dict_to_list(dataset_dict):
21 | """
22 | Converts a Torch dataset stored as a dictionary to a list of dictionaries.
23 |
24 | Args:
25 | dataset_dict (dict): The input dataset dictionary with keys 'input_ids', 'coords', 'labels', and 'attention_mask'.
26 |
27 | Returns:
28 | list: A list of dictionaries where each dictionary contains values for the keys at each index.
29 | """
30 | keys = dataset_dict.keys()
31 | num_samples = len(dataset_dict[list(keys)[0]])
32 | # Assuming all keys have the same length
33 | dataset_list = []
34 | for i in range(num_samples):
35 | sample_dict = dict.fromkeys(keys)
36 | for key in keys:
37 | sample_dict[key] = dataset_dict[key][i]
38 | dataset_list.append(sample_dict)
39 | return dataset_list
40 |
41 |
42 | def to(a, device: torch.device):
43 | if torch.is_tensor(a):
44 | return a.to(device)
45 | elif isinstance(a, dict):
46 | return {k: to(v, device) for k, v in a.items()}
47 | elif isinstance(a, (list, tuple)):
48 | return type(a)(to(v, device) for v in a)
49 | else:
50 | return a
51 |
52 |
53 | from tqdm import tqdm
54 |
55 |
56 | def pt2hf(torch_dataset: data.Dataset, convert_type: torch.dtype = torch.float32):
57 | torch_dataset = [el for el in tqdm(torch_dataset) if el is not None]
58 | if convert_type is not None:
59 | torch_dataset = to(torch_dataset, convert_type)
60 | # import pdb; pdb.set_trace()
61 | try:
62 | dset_hf = Dataset.from_list(torch_dataset)
63 | except Exception as e:
64 | print(f"Exception while converting to hf dataset: {e}")
65 | import pdb
66 |
67 | pdb.set_trace()
68 | return dset_hf
69 |
70 |
71 | def publish_pt_dataset(ds_pt, dataset_info):
72 | try:
73 | ds = pt2hf(ds_pt) # may require setting: convert_type=np.float32
74 | print(f"Dataset type:{ds}")
75 | ds.info.description = dataset_info.desc
76 | ds.set_format(type="torch", columns=list(ds[0].keys()))
77 | ds.push_to_hub(dataset_info.repo)
78 | print(f"Push succeeded.")
79 | except Exception as e:
80 | print(f"Exception while publishing: {e}")
81 | raise e
82 |
83 |
84 | import torch
85 | from datasets import load_dataset
86 |
87 | dtypes = {
88 | "coords": torch.float16,
89 | "input_ids": torch.int64,
90 | "labels": torch.int64,
91 | "attention_mask": torch.int64,
92 | }
93 |
94 |
95 | class HuggingFaceDataset(torch.utils.data.Dataset):
96 | def __init__(
97 | self, repo: str, split: str, keep_fraction: float = 1.0, use_2d: bool = True
98 | ):
99 | print("Loading dataset...")
100 | start_time = time()
101 |
102 | hf_dataset = load_dataset(repo)
103 |
104 | print(f"Loaded dataset in {time()-start_time} seconds.")
105 | # dataset = [d for d in dataset if d is not None and sum([1 for i in d["labels"] if i>0])>0]
106 | df = hf_dataset["train"].to_pandas()
107 | df_filtered = df[df.labels.apply(lambda x: np.sum(np.array(x[::-1]) > 0) > 0)]
108 |
109 | dataset = Dataset.from_pandas(df_filtered)
110 |
111 | # split into train/val
112 | train_percent = 95
113 | train_size = int(len(dataset) * train_percent / 100)
114 | val_size = len(dataset) - train_size
115 | train_dataset, val_dataset = torch.utils.data.random_split(
116 | dataset, [train_size, val_size]
117 | )
118 |
119 | self.dataset = train_dataset if split == "train" else val_dataset
120 |
121 | # keep only a fraction of the dataset
122 | if keep_fraction < 1.0:
123 | self.dataset = torch.utils.data.Subset(
124 | self.dataset, range(int(len(self.dataset) * keep_fraction))
125 | )
126 |
127 | self.use_2d = use_2d
128 |
129 | def __getitem__(self, index):
130 | hf_dict = self.dataset[index]
131 |
132 | # convert to torch tensors
133 | ret = {k: torch.tensor(v, dtype=dtypes[k]) for k, v in hf_dict.items()}
134 |
135 | # if not self.use_2d:
136 | # del ret["coords"]
137 |
138 | return ret
139 |
140 | def __len__(self):
141 | return len(self.dataset)
142 |
143 |
144 | if __name__ == "__main__":
145 | import argparse
146 |
147 | from ..constants import PRETRAINING_CACHE_DIR
148 |
149 | parser = argparse.ArgumentParser(description="Description of your script")
150 | # Argument 1: First argument (e.g., input file)
151 | parser.add_argument(
152 | "-C",
153 | "--cache_dir",
154 | type=str,
155 | default=PRETRAINING_CACHE_DIR,
156 | help="Cache directory",
157 | )
158 | # Argument 2: Second argument (e.g., output file)
159 | parser.add_argument(
160 | "-R",
161 | "--repo",
162 | default="supermomo668/Llama2D-Pretrain",
163 | type=str,
164 | help="Name of Repo",
165 | )
166 | # Argument 2: Second argument (e.g., output file)
167 | parser.add_argument(
168 | "-D",
169 | "--desc",
170 | default="Llama2D is a project from AGI UI/UX Hackathon. Check our main Git Repo at : https://github.com/Llama2D/llama2d/tree/main",
171 | type=str,
172 | help="Name of Repo",
173 | )
174 |
175 | args = parser.parse_args()
176 | ds_pt = CachedDataset(args.cache_dir)
177 | publish_pt_dataset(ds_pt, args)
178 |
--------------------------------------------------------------------------------
/src/llama2d/datasets/mhtml_to_hhtml.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 |
3 | from llama2d.constants import MIND2WEB_HHTML_DIR, MIND2WEB_MHTML_DIR
4 |
5 | mhtml_files = [f for f in MIND2WEB_MHTML_DIR.iterdir() if f.suffix == ".mhtml"]
6 |
7 | for mhtml_filename in tqdm(mhtml_files):
8 | # print(mhtml_filename)
9 |     mhtml_path = mhtml_filename  # iterdir() already yields full paths
10 |     html_path = MIND2WEB_HHTML_DIR / mhtml_filename.with_suffix(".html").name
11 |
12 | if html_path.exists():
13 | html_path.unlink()
14 |
15 | mhtml_content = open(mhtml_path, "r").read()
16 | hhtml_content = mhtml_content.replace(":hover", ".hvvvr")
17 |
18 | with open(html_path, "w") as f:
19 | f.write(hhtml_content)
20 |
21 | print("Done!")
22 |
--------------------------------------------------------------------------------
/src/llama2d/datasets/mind2web.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from glob import glob
4 | from random import random
5 | from time import sleep
6 | from typing import Dict
7 |
8 | import torch
9 | from datasets import load_dataset
10 | from playwright.sync_api import sync_playwright
11 | from torch.utils.data import Dataset
12 |
13 | from llama2d.constants import MIND2WEB_MHTML_DIR, SCREEN_RESOLUTION
14 | from llama2d.datasets.huggingface import DatasetInfo, publish_pt_dataset
15 | from llama2d.tagging.add_tags_to_page import add_tags_to_webpage
16 | from llama2d.vision.take_screenshot import take_screenshot
17 | from llama2d.vision.url_to_llama_input import Llama2dWebsiteFeatureExtractor
18 | from llama2d.vision.viz_pt_input import debug_dataset
19 |
20 | should_debug = False
21 |
22 |
23 | class Mind2webDataset(Dataset):
24 | def __init__(
25 |         self, model="decapoda-research/llama-7b-hf", playwright=None, headless=False, show_errors=False
26 | ):
27 | assert playwright is not None, "Please pass in playwright"
28 | self.__extractor = Llama2dWebsiteFeatureExtractor(mask_out_body=True)
29 |
30 | self.uid_to_mhtml = self.get_uid_to_mhtml_map()
31 |
32 | dataset = load_dataset("osunlp/Mind2Web")
33 | self.dataset = dataset["train"]
34 |
35 | self.actions = [
36 | (i, j)
37 | for i in range(len(self.dataset))
38 | for j in range(len(self.dataset[i]["actions"]))
39 | ]
40 |
41 | self.browser = playwright.chromium.launch(
42 | headless=headless, args=["--disable-web-security"]
43 | )
44 | self.page = self.browser.new_page()
45 |
46 | width, height = SCREEN_RESOLUTION
47 | self.page.set_viewport_size({"width": width, "height": height})
48 |
49 | self.page.set_extra_http_headers(
50 | {
51 | "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
52 | "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/116.0.0.0 "
53 | "Safari/537.36"
54 | }
55 | )
56 | self.page.set_default_navigation_timeout(1000 * 10)
57 | self.show_errors = show_errors
58 |
59 | def __len__(self):
60 | return len(self.actions)
61 |
62 | def __getitem__(self, index):
63 | screenshot_path = None
64 | try:
65 | task_idx, action_idx = self.actions[index]
66 | task = self.dataset[task_idx]
67 | action = task["actions"][action_idx]
68 |
69 | pos_candidates = action["pos_candidates"]
70 | if len(pos_candidates) == 0:
71 | raise Exception("No positive candidates in dataset!")
72 |
73 | uid = action["action_uid"]
74 | mhtml_file = self.uid_to_mhtml[uid]
75 |
76 | mhtml_file_name = mhtml_file.split("/")[-1]
77 | mhtml_file = "http://localhost:5002/" + mhtml_file_name
78 | self.page.goto(mhtml_file)
79 | sleep(1)
80 |
81 | gt_tag, tags_and_boxes = add_tags_to_webpage(self.page, action)
82 |
83 | rand_num = random()
84 | screenshot_path = f"screenshot_{rand_num}.png"
85 | take_screenshot(self.page, None, screenshot_path)
86 |
87 | self.page.evaluate("window.demo()")
88 | take_screenshot(self.page, None, "screenshot.png")
89 |
90 | intention = task["confirmed_task"]
91 |
92 | actions_str = "\n".join(task["action_reprs"])
93 | prompt = f"""
94 | You are a bot using a website. Your goal is: "{intention}"
95 | {"So far, you have done the following actions: "
96 | +actions_str if len(actions_str) > 0 else ""}
97 | The website looks like so:"""
98 |
99 | operation = action["operation"]
100 | op = operation["op"]
101 | value = operation["value"]
102 |
103 | completion = None
104 | if op == "CLICK":
105 | completion = f"CLICK [{gt_tag}]"
106 | elif op == "TYPE":
107 | completion = f"TYPE [{gt_tag}] {json.dumps(value)}"
108 | elif op == "SELECT":
109 | completion = f"SELECT [{gt_tag}]"
110 | else:
111 | raise NotImplementedError(f"Don't understand operation {op}")
112 |
113 | ret = self.__extractor.process(
114 | prompt, screenshot_path, completion, tags_and_boxes=tags_and_boxes
115 | )
116 |
117 | # delete the screenshot
118 | os.remove(screenshot_path)
119 |
120 | return ret
121 | except Exception as e:
122 | # raise e
123 | if self.show_errors:
124 | print("Error in dataset:", str(e)[:100] + "...")
125 |
126 | if "ImageAnnotation" in str(e):
127 | raise e
128 |
129 | if screenshot_path is not None:
130 | if os.path.exists(screenshot_path):
131 | os.remove(screenshot_path)
132 | return None
133 |
134 | def get_uid_to_mhtml_map(self) -> Dict[str, str]:
135 | all_mhtmls = glob(f"{MIND2WEB_MHTML_DIR}/*_before.mhtml")
136 | print("mhtml count:", len(all_mhtmls))
137 |
138 | # extract the uid from *_before.mhtml
139 | def get_uid(path):
140 | return path.split("/")[-1].split("_")[0]
141 |
142 | return {get_uid(path): path for path in all_mhtmls}
143 |
144 |
145 | mind2web_repo = "llama2d/llama2d-mind2web"
146 |
147 | if __name__ == "__main__":
148 | ds_info = DatasetInfo(
149 | repo=mind2web_repo,
150 | desc="Llama2d Mind2Web dataset - SFT dataset for"
151 | " tag interaction on diverse websites",
152 | )
153 |
154 | with sync_playwright() as playwright:
155 | dataset = Mind2webDataset(playwright=playwright, headless=True)
156 |
157 | # debug_dataset(dataset)
158 |
159 | # publish a subset
160 | num_samples = 2_000
161 |
162 | if num_samples is not None:
163 | dataset, _ = torch.utils.data.random_split(
164 | dataset, [num_samples, len(dataset) - num_samples]
165 | )
166 |
167 | publish_pt_dataset(dataset, ds_info)
168 |
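For orientation, `__getitem__` above pairs the prompt with a tagged-action completion string. A minimal sketch of the completion formats it emits (the tag number and typed value below are made up, not taken from the dataset):

```py
# Illustrative only: the completion strings built in __getitem__ above,
# where gt_tag is the [N] marker assigned to the ground-truth element.
gt_tag = 12
click_completion = f"CLICK [{gt_tag}]"                # -> 'CLICK [12]'
type_completion = f'TYPE [{gt_tag}] "san francisco"'  # TYPE also carries the typed value (JSON-quoted)
select_completion = f"SELECT [{gt_tag}]"              # -> 'SELECT [12]'
```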
--------------------------------------------------------------------------------
/src/llama2d/datasets/mind2web_convert.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import time
4 | from glob import glob
5 |
6 | from ..constants import MIND2WEB_CACHE_DIR, MIND2WEB_OUT_DIR
7 |
8 | files = glob(f"{MIND2WEB_OUT_DIR}/*/input.pt")
9 |
10 | # copy <uid>/input.pt to MIND2WEB_CACHE_DIR/<uid>.pt,
11 | # but only for input.pt files modified within the last 15 hours (see the check below)
12 |
13 | for f in files:
14 | # get date modified
15 | date_modified = os.path.getmtime(f)
16 | # get current time
17 | current_time = time.time()
18 | # get difference
19 | diff = current_time - date_modified
20 |     # if modified within the last 15 hours
21 | if diff < 60 * 60 * 15:
22 | # get uid
23 | uid = f.split("/")[-2]
24 | # copy file
25 | shutil.copy(f, f"{MIND2WEB_CACHE_DIR}/{uid}.pt")
26 | print(f"Copied {f} to {MIND2WEB_CACHE_DIR}/{uid}.pt")
27 | else:
28 | print(f"Skipping {f} because it is {diff//(60*60)} hrs old")
29 |
--------------------------------------------------------------------------------
/src/llama2d/datasets/pretraining.py:
--------------------------------------------------------------------------------
1 | from playwright.sync_api import sync_playwright
2 | from torch.utils.data import Dataset
3 |
4 | from llama2d.datasets.huggingface import DatasetInfo, publish_pt_dataset
5 | from llama2d.vision.url_to_llama_input import Llama2dWebsiteFeatureExtractor
6 | from src.data.pretraining_urls import urls
7 |
8 |
9 | class Llama2dPretrainingDataset(Dataset):
10 | def __init__(
11 | self, model="decapoda-research/llama-7b-hf", urls=[], include_coords=True
12 | ):
13 | self.__extractor = Llama2dWebsiteFeatureExtractor(model, mask_out_body=False)
14 | self.__urls = urls
15 |
16 | self.__include_coords = include_coords
17 |
18 | with sync_playwright() as p:
19 | # Using the Chromium browser but you can also use 'firefox' or 'webkit'
20 | browser = p.chromium.launch()
21 | page = browser.new_page()
22 |
23 | page.set_extra_http_headers(
24 | {
25 | "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
26 | "AppleWebKit/537.36 (KHTML, like Gecko)"
27 | " Chrome/116.0.0.0 Safari/537.36"
28 | }
29 | )
30 |             # exceptional() is a helper that calls a function and returns
31 |             # None if the call raises an exception.
32 |             # Extract features for every URL...
33 |             self.extractions = [
34 |                 exceptional(self.__extractor.create_inference_data, args=(page, "", url))
35 |                 for url in self.__urls
36 |             ]
37 |             # ...and drop the URLs whose extraction failed (returned None).
38 |             self.extractions = [i for i in self.extractions if i]
39 |
40 | def __getitem__(self, index):
41 | ret = self.extractions[index]
42 | if not self.__include_coords:
43 | return {k: v for k, v in ret.items() if k != "coords"}
44 | return ret
45 |
46 | def __len__(self):
47 | return len(self.extractions)
48 |
49 |
50 | def exceptional(call, args):
51 | """Wrapper function to return None for a function if it errors.
52 |
53 | Parameters
54 | ----------
55 | call : callable
56 | The function to call
57 | args : List[Any]
58 | The arguments to call it with
59 |
60 | Returns
61 | -------
62 | Any
63 |         The output of the function, or None if it raised an exception.
64 | """
65 |
66 | try:
67 | return call(*args)
68 | except Exception as e:
69 | print("your call to", call, "errored! Returning None")
70 | print(e)
71 |
72 | return None
73 |
74 |
75 | pretraining_repo = "llama2d/llama2d-pretraining"
76 |
77 | if __name__ == "__main__":
78 | print("Downloading pretraining dataset with Playwright...")
79 |
80 | ds_info = DatasetInfo(
81 | repo=pretraining_repo,
82 | desc="Llama2d pretraining dataset - next-token prediction "
83 | "on real estate websites",
84 | )
85 |
86 | dataset = Llama2dPretrainingDataset(
87 | model="decapoda-research/llama-7b-hf", urls=urls, include_coords=True
88 | )
89 |
90 | publish_pt_dataset(dataset, ds_info)
91 |
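The `exceptional` wrapper above is a generic guard around flaky per-URL extraction. A minimal usage sketch (with a hypothetical `might_fail` function), mirroring how the constructor filters out failed extractions:

```py
# Hypothetical example of the exceptional()-then-filter pattern used above.
def might_fail(url: str) -> str:
    if not url.startswith("http"):
        raise ValueError(f"bad url: {url}")
    return url.upper()

results = [exceptional(might_fail, args=(u,)) for u in ["https://a.com", "oops"]]
results = [r for r in results if r]  # drop the None from the failed call
assert results == ["HTTPS://A.COM"]
```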
--------------------------------------------------------------------------------
/src/llama2d/datasets/synthetic/top_or_bottom.py:
--------------------------------------------------------------------------------
1 | from math import inf
2 | from random import choice, random
3 |
4 | from torch.utils.data import Dataset
5 |
6 | from llama2d.datasets.huggingface import DatasetInfo, publish_pt_dataset
7 | from llama2d.vision import Llama2dScreen, Llama2dTokenizer, debug_dataset
8 |
9 | directions = {
10 |     "t": (0.5, 0),  # word placed at the top of the screen (y = 0)
11 |     "b": (0.5, 1),  # word placed at the bottom of the screen (y = 1)
12 | }
13 |
14 | rand_words = "bob,jane,alice,carol,ted,lisa,barry,frank,george,harold,henry,ian,john,james,kevin,mark,neil,oliver,peter,quinn,robert,steve,thomas,william".split(
15 | ","
16 | )
17 |
18 |
19 | class TopBottomDataset(Dataset):
20 | def __init__(self, num_screens: int, tokenizer: Llama2dTokenizer = None):
21 | self.num_screens = num_screens
22 |
23 | if tokenizer is None:
24 | tokenizer = Llama2dTokenizer()
25 | self.tokenizer = tokenizer
26 |
27 | self.screens = []
28 | for i in range(num_screens):
29 | screen = Llama2dScreen()
30 | direction, vector = choice(list(directions.items()))
31 |
32 | screen.push_word(word=choice(rand_words), xy=vector)
33 |
34 |             prompt = "Top or bottom? (t/b)"
35 | output = direction
36 |
37 | self.screens.append(self.tokenizer.process(prompt, screen, output))
38 |
39 | def __len__(self):
40 | return self.num_screens
41 |
42 | def __getitem__(self, i: int):
43 | return self.screens[i]
44 |
45 |
46 | if __name__ == "__main__":
47 | dataset = TopBottomDataset(num_screens=500)
48 |
49 | debug_dataset(dataset)
50 |
51 | info = DatasetInfo(
52 |         repo="llama2d/llama2d-top-or-bottom", desc="Identify whether a name is at the top or bottom of the screen."
53 | )
54 | publish_pt_dataset(dataset, info)
55 |
--------------------------------------------------------------------------------
/src/llama2d/datasets/synthetic/unscramble_words.py:
--------------------------------------------------------------------------------
1 |
2 | from llama2d.vision import Llama2dScreen, Llama2dTokenizer, debug_dataset
3 | from llama2d.datasets.huggingface import DatasetInfo, publish_pt_dataset
4 | from torch.utils.data import Dataset
5 |
6 | from random import choice, shuffle
7 | rand_words = "bob,jane,alice,carol,ted,lisa,barry,frank,george,harold,henry,ian,john,james,kevin,mark,neil,oliver,peter,quinn,robert,steve,thomas,william".split(",")
8 |
9 | class UnscrambleDataset(Dataset):
10 |     def __init__(
11 |         self,
12 |         num_screens: int,
13 |         words_per_screen: int,
14 |         words_per_line: int = 20,
15 |         lines_per_screen: int = 5,
16 |         tokenizer: Llama2dTokenizer = None,
17 |     ):
18 | self.num_screens = num_screens
19 | self.words_per_screen = words_per_screen
20 |
21 | if tokenizer is None:
22 | tokenizer = Llama2dTokenizer()
23 | self.tokenizer = tokenizer
24 |
25 | self.screens = []
26 | for i in range(num_screens):
27 | screen = Llama2dScreen()
28 |
29 | words = [choice(rand_words) for _ in range(words_per_screen)]
30 |
31 | # render in a grid of lines
32 |             for k, word in enumerate(words):
33 |                 col, row = k % words_per_line, k // words_per_line
34 |                 # convert (col, row) to (x, y), where x is horizontal and y is vertical;
35 |                 # both x and y are in [0, 1]
36 |
37 |                 x = (col + 0.5) / words_per_line
38 |                 y = (row + 0.5) / lines_per_screen
39 |
40 |                 assert y < 1, "Too many words for the screen"
41 |
42 |                 screen.push_word(word=word, xy=(x, y))
43 |
44 |             # shuffle the stored word order so position, not token order, carries the signal
45 |             shuffle(screen.words)
46 |
47 |             prompt = "Read out the words in the order they appear."
48 |             response = " ".join(words)
49 |
50 |             self.screens.append(self.tokenizer.process(prompt, screen, response))
51 |
52 | def __len__(self):
53 | return self.num_screens
54 |     def __getitem__(self, i: int):
55 | return self.screens[i]
56 |
57 | if __name__ == "__main__":
58 |
59 | dataset = UnscrambleDataset(
60 | num_screens=5000,
61 | words_per_screen=50,
62 | words_per_line=15,
63 | lines_per_screen=5
64 | )
65 |
66 | debug_dataset(dataset)
67 |
68 |     info = DatasetInfo(
69 |         repo="llama2d/llama2d-unscramble", desc="Unscramble the words displayed on the screen."
70 |     )
71 |     publish_pt_dataset(dataset, info)
--------------------------------------------------------------------------------
/src/llama2d/datasets/synthetic/zoo_compass.py:
--------------------------------------------------------------------------------
1 | from math import inf
2 | from random import choice, random
3 |
4 | from torch.utils.data import Dataset
5 |
6 | from llama2d.datasets.huggingface import DatasetInfo, publish_pt_dataset
7 | from llama2d.vision import Llama2dScreen, Llama2dTokenizer, debug_dataset
8 |
9 | animals = "frog,cat,bear,big lion,eagle,elephant,tiger,baboon,archerfish,gorilla,gerbil,ant colony".split(
10 | ","
11 | )
12 | directions = {
13 | "northernmost": (0, -1), # in -y direction
14 | "farthest west": (-1, 0), # in -x direction
15 | "southernmost": (0, 1), # in +y direction
16 | "farthest east": (1, 0), # in +x direction
17 | }
18 |
19 |
20 | class Llama2dZooCompassDataset(Dataset):
21 | def __init__(
22 | self,
23 | num_screens: int,
24 | words_per_screen: int,
25 | tokenizer: Llama2dTokenizer = None,
26 | ):
27 | self.num_screens = num_screens
28 |
29 | if tokenizer is None:
30 | tokenizer = Llama2dTokenizer()
31 | self.tokenizer = tokenizer
32 |
33 | self.screens = []
34 | for i in range(num_screens):
35 | screen = Llama2dScreen()
36 | direction, vector = choice(list(directions.items()))
37 |
38 | farthest_animal = None
39 | farthest_distance = -inf
40 | for j in range(words_per_screen):
41 | animal = choice(animals)
42 | coords = (random(), random())
43 | screen.push_word(word=animal, xy=coords)
44 |
45 | distance = coords[0] * vector[0] + coords[1] * vector[1]
46 | if distance > farthest_distance:
47 | farthest_animal = animal
48 | farthest_distance = distance
49 |
50 | assert farthest_animal is not None, "No animal is farthest"
51 |
52 | prompt = (
53 | f"Here is a map of the zoo. Find the {direction} animal in the zoo."
54 | )
55 | output = farthest_animal
56 |
57 | self.screens.append(self.tokenizer.process(prompt, screen, output))
58 |
59 | def __len__(self):
60 | return self.num_screens
61 |
62 | def __getitem__(self, i: int):
63 | return self.screens[i]
64 |
65 |
66 | if __name__ == "__main__":
67 | tokenizer = Llama2dTokenizer()
68 | dataset = Llama2dZooCompassDataset(
69 | tokenizer=tokenizer, num_screens=10_000, words_per_screen=20
70 | )
71 |
72 | debug_dataset(dataset)
73 |
74 | info = DatasetInfo(
75 | repo="llama2d/llama2d-zoo-compass",
76 | desc="Identify the animal farthest north/west/east/south in the zoo.",
77 | )
78 | publish_pt_dataset(dataset, info)
79 |
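The label above is the word whose coordinates maximize the dot product with the chosen direction vector. A minimal sanity check of that rule (illustrative coordinates, not a real screen):

```py
# For "farthest east" the directions dict above gives vector (1, 0),
# so the label is simply the word with the largest x coordinate.
coords = {"frog": (0.2, 0.9), "cat": (0.8, 0.1)}
vector = (1, 0)
label = max(coords, key=lambda a: coords[a][0] * vector[0] + coords[a][1] * vector[1])
assert label == "cat"
```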
--------------------------------------------------------------------------------
/src/llama2d/find_pos_given_attr/download_mind2web.py:
--------------------------------------------------------------------------------
1 | from pprint import pprint
2 |
3 | from datasets import load_dataset
4 |
5 | # Load the Mind2Web dataset
6 | dataset = load_dataset("osunlp/Mind2Web")
7 |
8 | # Print the first sample for verification
9 |
10 |
11 | example = dataset["train"][0]
12 |
13 | pprint(example)
14 | # breakpoint()
15 |
16 | actions = example["actions"]
17 |
18 | print(actions[0].keys())
19 |
20 | print(example["action_reprs"])
21 |
--------------------------------------------------------------------------------
/src/llama2d/find_pos_given_attr/find_pos_given_attr.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import requests
4 | from bs4 import BeautifulSoup
5 | from datasets import load_dataset
6 |
7 | # Load the Mind2Web dataset
8 | dataset = load_dataset("osunlp/Mind2Web")
9 | example = dataset["train"][0]
10 |
11 |
12 | attrs = example["actions"][0]["pos_candidates"][0]["attributes"]
13 |
14 |
15 | # URL of the webpage you want to scrape
16 | url = "http://example.com"
17 |
18 | print(example["domain"])
19 | print(example["subdomain"])
20 |
21 | # We might be able to assume we can append ".com" to the website name
22 | print(example["website"])
23 | print(len(dataset["train"]))
24 |
25 | print("Attempting to find all tags that contain the attrs:")
26 | print(type(attrs))
27 | print(attrs)
28 |
29 | attributes = json.loads(attrs)
30 |
31 |
32 | # Send a GET request to the webpage
33 | response = requests.get(url)
34 | soup = BeautifulSoup(response.content, "html.parser")
35 |
36 | # Find all tags that match the attributes
37 | matching_tags = soup.find_all(attrs=attributes)
38 |
39 | # Check if there are matching tags
40 | if matching_tags:
41 | print(f"Found {len(matching_tags)} matching tag(s)!")
42 | for tag in matching_tags:
43 | print(tag)
44 | else:
45 | print("No matching tags found!")
46 |
--------------------------------------------------------------------------------
/src/llama2d/modal/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Llama2D/llama2d/e28b97255d396c717fe183b96b802ff39ffd7e6d/src/llama2d/modal/__init__.py
--------------------------------------------------------------------------------
/src/llama2d/modal/common.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | from modal import Image, Secret, Stub, Volume
3 |
4 | N_GPUS = 2
5 | GPU_MEM = 80
6 | BASE_MODELS = {
7 | "base7": "meta-llama/Llama-2-7b-hf",
8 | "chat7": "meta-llama/Llama-2-7b-chat-hf",
9 | "chat13": "meta-llama/Llama-2-13b-chat-hf",
10 | "code7": "codellama/CodeLlama-7b-hf",
11 | "code34": "codellama/CodeLlama-34b-hf",
12 | "instruct7": "codellama/CodeLlama-7b-Instruct-hf",
13 | "instruct13": "codellama/CodeLlama-13b-Instruct-hf",
14 | "instruct34": "codellama/CodeLlama-34b-Instruct-hf",
15 | # Training 70B requires experimental flag fsdp_peft_cpu_offload_for_save.
16 | "chat70": "meta-llama/Llama-2-70b-chat-hf",
17 | }
18 |
19 | import os
20 | import random
21 |
22 | own_dir = os.path.dirname(os.path.realpath(__file__))
23 | root_dir = f"{own_dir}/../../.."
24 |
25 | secrets_dir = f"{root_dir}/secrets/"
26 | data_dir = f"{root_dir}/data/"
27 | dataset_dir = f"{own_dir}/datasets/"
28 |
29 | transformers_dir = f"{root_dir}/transformers"
30 | llama_recipes_dir = f"{root_dir}/llama-recipes"
31 |
32 | if os.path.exists(transformers_dir) and os.path.exists(llama_recipes_dir):
33 | import os
34 |
35 | transformers_commit = (
36 | os.popen(f"cd {transformers_dir} && git rev-parse HEAD").read().strip()
37 | )
38 | llama_recipes_commit = (
39 | os.popen(f"cd {llama_recipes_dir} && git rev-parse HEAD").read().strip()
40 | )
41 |
42 | assert transformers_commit != "", "Could not get transformers commit."
43 | assert llama_recipes_commit != "", "Could not get llama-recipes commit."
44 | else:
45 | transformers_commit = "overwriting-llama"
46 | llama_recipes_commit = "andrew-dev"
47 |
48 | print(
49 | f"Transformers commit: {transformers_commit}, llama-recipes commit: {llama_recipes_commit}"
50 | )
51 |
52 | import random
53 |
54 | image = (
55 | Image.micromamba()
56 | .micromamba_install(
57 | "cudatoolkit=11.8",
58 | "cudnn=8.1.0",
59 | "cuda-nvcc",
60 | channels=["conda-forge", "nvidia"],
61 | )
62 | .apt_install("git", "unzip")
63 | .pip_install(
64 | "huggingface_hub==0.17.1",
65 | "hf-transfer==0.1.3",
66 | "scipy",
67 | "gdown",
68 | "google-cloud-vision",
69 | "sentencepiece",
70 | "playwright",
71 | "wandb",
72 | "transformers",
73 | "matplotlib",
74 | )
75 | .pip_install(
76 | f"llama-recipes @ git+https://github.com/modal-labs/llama-recipes.git",
77 | extra_index_url="https://download.pytorch.org/whl/nightly/cu118",
78 | pre=True,
79 | )
80 | .run_commands(
81 | f"pip install 'llama-recipes @ git+https://github.com/llama2d/llama-recipes.git@{llama_recipes_commit}' git+https://github.com/llama2d/transformers.git@{transformers_commit} --no-deps"
82 | )
83 | .env(dict(HUGGINGFACE_HUB_CACHE="/pretrained", HF_HUB_ENABLE_HF_TRANSFER="1"))
84 | .copy_local_dir(secrets_dir, "/root/secrets")
85 | .copy_local_file(
86 | f"{os.path.dirname(os.path.realpath(__file__))}/finetuning.py",
87 | "/root/finetuning.py",
88 | )
89 | )
90 |
91 | stub = Stub(
92 | "llama-finetuning",
93 | image=image,
94 | secrets=[Secret.from_name("huggingface"), Secret.from_name("wandb")],
95 | )
96 |
97 | stub.hf_cache_volume = Volume.persisted("hf-cache")
98 |
99 | # Download pre-trained models into this volume.
100 | stub.pretrained_volume = Volume.persisted("example-pretrained-vol")
101 |
102 | # Save trained models into this volume.
103 | stub.results_volume = Volume.persisted("example-results-vol")
104 |
105 | VOLUME_CONFIG = {
106 | "/pretrained": stub.pretrained_volume,
107 | "/results": stub.results_volume,
108 | "/hf_cache": stub.hf_cache_volume,
109 | }
110 |
--------------------------------------------------------------------------------
/src/llama2d/modal/datasets/cached_dataset.py:
--------------------------------------------------------------------------------
1 | import gdown
2 | import torch
3 |
4 | from llama2d.datasets.cached import CachedDataset
5 |
6 |
7 | def get_custom_dataset(dataset_config, tokenizer, split):
8 | dataset_folder = dataset_config.dataset_folder
9 | print(f"Using dataset folder {dataset_folder}")
10 |
11 | use_2d = dataset_config.use_2d
12 |
13 | gdown.download(id="1bgbnuVQjhRku60gCLrFfqfM66bp0Z4sI")
14 | gdown.download(id="1LBT_gMNntS0mj-S8oTEWQE8pcJOIAXLA")
15 | # unzip the dataset
16 | import os
17 |
18 | os.system("unzip -qo cached-pretrain.zip")
19 | os.system("unzip -qo mind2web-cache.zip")
20 |
21 | train_percent = 80
22 |
23 | full_dataset = CachedDataset(
24 | dataset_folder, use_2d=use_2d, keep_fraction=dataset_config.keep_fraction
25 | )
26 |
27 | train_size = int(len(full_dataset) * train_percent / 100)
28 | val_size = len(full_dataset) - train_size
29 |
30 | train_dataset, val_dataset = torch.utils.data.random_split(
31 | full_dataset, [train_size, val_size]
32 | )
33 |
34 | return train_dataset if split == "train" else val_dataset
35 |
--------------------------------------------------------------------------------
/src/llama2d/modal/datasets/hf_dataset.py:
--------------------------------------------------------------------------------
1 | from llama2d.datasets.huggingface import HuggingFaceDataset
2 |
3 |
4 | def get_custom_dataset(dataset_config, tokenizer, split):
5 | repo = dataset_config.repo
6 | use_2d = dataset_config.use_2d
7 | print("get_custom_dataset, use_2d:", use_2d)
8 | return HuggingFaceDataset(
9 | repo, split, keep_fraction=dataset_config.keep_fraction, use_2d=use_2d
10 | )
11 |
--------------------------------------------------------------------------------
/src/llama2d/modal/datasets/new_dataset.py:
--------------------------------------------------------------------------------
1 | from llama2d.datasets.pretraining import Llama2dPretrainingDataset
2 |
3 |
4 | def format_text(row, tokenizer):
5 | return tokenizer(row)
6 |
7 |
8 | def get_custom_dataset():
9 | urls = [
10 | "https://github.com/OSU-NLP-Group/Mind2Web",
11 | "https://stackoverflow.com/questions/60352003/how-to-download-webpage-as-mhtml",
12 | ]
13 | dataset = Llama2dPretrainingDataset(
14 | model="decapoda-research/llama-7b-hf", urls=urls
15 | )
16 |
17 | return dataset
18 |
--------------------------------------------------------------------------------
/src/llama2d/modal/datasets/sql_dataset.py:
--------------------------------------------------------------------------------
1 | import datasets
2 | from llama_recipes.datasets.utils import Concatenator
3 |
4 | B_INST, E_INST = "[INST] ", " [/INST]"
5 | B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
6 |
7 |
8 | def format_text(row, tokenizer):
9 | text = (
10 | B_INST
11 | + B_SYS
12 | + "You are an advanced SQL assistant that uses this SQL table schema "
13 | "to generate"
14 | " a SQL query which answers the user question.\n"
15 | + row["context"]
16 | + E_SYS
17 | + row["question"]
18 | + E_INST
19 | + "\n[SQL]\n"
20 | + row["answer"]
21 | + "\n[/SQL]"
22 | + ""
23 | )
24 |
25 | return tokenizer(text)
26 |
27 |
28 | def get_custom_dataset(dataset_config, tokenizer, split):
29 | full_dataset = datasets.load_dataset("b-mc2/sql-create-context", split="train")
30 |
31 | # Since the dataset has no train/test split, we create one and select it
32 | dataset = full_dataset.train_test_split(
33 | train_size=10000,
34 | test_size=200,
35 | seed=42,
36 | )["train" if split == dataset_config.train_split else "test"]
37 |
38 | dataset = dataset.map(
39 | lambda x: format_text(x, tokenizer), remove_columns=list(dataset.features)
40 | )
41 |
42 | dataset = dataset.map(Concatenator(), batched=True, batch_size=None)
43 |
44 | return dataset
45 |
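For reference, `format_text` above assembles a Llama-2 style instruction string. A sketch of the text it builds for a hypothetical row (assuming the standard `<<SYS>>` delimiters used above):

```py
# Illustrative only: a made-up row and (roughly) the string format_text builds from it.
row = {
    "context": "CREATE TABLE head (age INTEGER)",
    "question": "How many heads of departments are older than 56?",
    "answer": "SELECT COUNT(*) FROM head WHERE age > 56",
}
# format_text(row, tokenizer) tokenizes approximately:
# [INST] <<SYS>>
# You are an advanced SQL assistant that uses this SQL table schema to generate a SQL query which answers the user question.
# CREATE TABLE head (age INTEGER)
# <</SYS>>
#
# How many heads of departments are older than 56? [/INST]
# [SQL]
# SELECT COUNT(*) FROM head WHERE age > 56
# [/SQL]
```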
--------------------------------------------------------------------------------
/src/llama2d/modal/datasets/zoo_dataset.py:
--------------------------------------------------------------------------------
1 | from llama2d.datasets.synthetic.zoo_compass import Llama2dZooCompassDataset
2 |
3 | dataset_registry = {}
4 |
5 |
6 | def get_custom_dataset(dataset_config, tokenizer, split):
7 | keep_fraction = dataset_config.keep_fraction
8 | train_size = int(5000 * keep_fraction)
9 | val_size = int(200 * keep_fraction) # make val_size very small - we're short on GPU time
10 | return Llama2dZooCompassDataset(
11 | num_screens=train_size if split == "train" else val_size,
12 | words_per_screen=20,
13 | )
14 |
--------------------------------------------------------------------------------
/src/llama2d/modal/finetuning.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed according to the
3 | # terms of the Llama 2 Community License Agreement.
4 |
5 | import os
6 |
7 | import fire
8 | import torch
9 | import torch.distributed as dist
10 | import torch.optim as optim
11 | from llama_recipes.configs import fsdp_config, train_config
12 | from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
13 | from llama_recipes.utils import fsdp_auto_wrap_policy
14 | from llama_recipes.utils.config_utils import (
15 | generate_dataset_config,
16 | generate_peft_config,
17 | update_config,
18 | )
19 | from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
20 | from llama_recipes.utils.train_utils import (
21 | clear_gpu_cache,
22 | freeze_transformer_layers,
23 | get_policies,
24 | print_model_size,
25 | setup,
26 | setup_environ_flags,
27 | train,
28 | )
29 | from peft import get_peft_model, prepare_model_for_int8_training
30 | from pkg_resources import packaging
31 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
32 | from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
33 | from torch.optim.lr_scheduler import StepLR
34 | from torch.utils.data import DistributedSampler
35 |
36 | from transformers import AutoTokenizer, default_data_collator
37 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer
38 | from transformers.models.llama.sam_embed import PositionEmbeddingRandom
39 |
40 | # dataclass serialization
41 | import dataclasses, json
42 |
43 | class EnhancedJSONEncoder(json.JSONEncoder):
44 | def default(self, o):
45 | if dataclasses.is_dataclass(o):
46 | return dataclasses.asdict(o)
47 | return super().default(o)
48 | def json_dumps(obj, *args, **kwargs):
49 |     return json.dumps(obj, *args, cls=EnhancedJSONEncoder, **kwargs)
50 |
51 | def main(Llama, LlamaCfg, **kwargs):
52 | # Update the configuration for the training and sharding process
53 | update_config((train_config, fsdp_config), **kwargs)
54 |
55 | print(f"Full config: {train_config=},{kwargs=}")
56 | dataset_config = generate_dataset_config(train_config, kwargs)
57 | print(f"Dataset config: {dataset_config=}")
58 |
59 | use_2d = train_config.use_2d
60 | # Set the seeds for reproducibility
61 | torch.cuda.manual_seed(train_config.seed)
62 | torch.manual_seed(train_config.seed)
63 | import random
64 | random.seed(train_config.seed)
65 | import numpy as np
66 | np.random.seed(train_config.seed)
67 |
68 | if train_config.enable_fsdp:
69 | setup()
70 | # torchrun specific
71 | local_rank = int(os.environ["LOCAL_RANK"])
72 | rank = int(os.environ["RANK"])
73 | # world_size = int(os.environ["WORLD_SIZE"])
74 |
75 | if torch.distributed.is_initialized():
76 | torch.cuda.set_device(local_rank)
77 | clear_gpu_cache(local_rank)
78 | setup_environ_flags(rank)
79 |
80 | # Load the tokenizer and add special tokens
81 | tokenizer = AutoTokenizer.from_pretrained(train_config.model_name)
82 | tokenizer.add_special_tokens(
83 | {
84 |             "pad_token": "<PAD>",
85 | }
86 | )
87 |
88 | # Load and preprocess the dataset for training and validation
89 | dataset_train = get_preprocessed_dataset(
90 | tokenizer,
91 | dataset_config,
92 | split="train",
93 | )
94 |
95 | if not train_config.enable_fsdp or rank == 0:
96 | print(f"--> Training Set Length = {len(dataset_train)}")
97 |
98 | dataset_val = get_preprocessed_dataset(
99 | tokenizer,
100 | dataset_config,
101 | split="test",
102 | )
103 | if not train_config.enable_fsdp or rank == 0:
104 | print(f"--> Validation Set Length = {len(dataset_val)}")
105 |
106 | kwargs = {
107 | "use_2d": use_2d,
108 | "lbd_start_value": train_config.lbd_start_value,
109 | "use_point_embed": train_config.use_point_embed,
110 | "separate_point_embed": train_config.separate_point_embed,
111 | }
112 |
113 | # Load the pre-trained model and setup its configuration
114 | use_cache = False if train_config.enable_fsdp else None
115 | if train_config.enable_fsdp and train_config.low_cpu_fsdp:
116 | """
117 | for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
118 | this avoids cpu oom when loading large models like llama 70B, in which case
119 | model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
120 | overhead and currently requires latest nightly.
121 | """
122 | v = packaging.version.parse(torch.__version__)
123 | verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
124 | if not verify_latest_nightly:
125 | raise Exception(
126 | "latest pytorch nightly build is required to "
127 | "run with low_cpu_fsdp config, "
128 | "please install latest nightly."
129 | )
130 | if rank == 0:
131 | model = Llama.from_pretrained(
132 | train_config.model_name,
133 | load_in_8bit=True if train_config.quantization else None,
134 | device_map="auto" if train_config.quantization else None,
135 | use_cache=use_cache,
136 | **kwargs,
137 | )
138 | else:
139 | llama_config = LlamaCfg.from_pretrained(train_config.model_name)
140 | llama_config.use_cache = use_cache
141 |
142 | llama_config.use_2d = use_2d
143 | llama_config.lbd_start_value = train_config.lbd_start_value
144 | llama_config.use_point_embed = train_config.use_point_embed
145 | llama_config.separate_point_embed = train_config.separate_point_embed
146 |
147 | with torch.device("meta"):
148 | model = Llama(llama_config)
149 |
150 | else:
151 | model = Llama.from_pretrained(
152 | train_config.model_name,
153 | load_in_8bit=True if train_config.quantization else None,
154 | device_map="auto" if train_config.quantization else None,
155 | use_cache=use_cache,
156 | **kwargs,
157 | )
158 |
159 | print(f"Using model type: {type(model)}")
160 |
161 | if train_config.enable_fsdp and train_config.use_fast_kernels:
162 | """
163 | For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
164 | using of Flash Attention or Xformer memory-efficient kernels
165 | based on the hardware being used. This would speed up fine-tuning.
166 | """
167 | try:
168 | from optimum.bettertransformer import BetterTransformer
169 |
170 | model = BetterTransformer.transform(model)
171 | except ImportError:
172 | print(
173 | "Module 'optimum' not found."
174 |                 " Please install 'optimum' before proceeding."
175 | )
176 | print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
177 |
178 | # Prepare the model for int8 training if quantization is enabled
179 | if train_config.quantization:
180 | model = prepare_model_for_int8_training(model)
181 |
182 | # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
183 | if train_config.enable_fsdp and fsdp_config.pure_bf16:
184 | print("Converting to bfloat16")
185 | model.to(torch.bfloat16)
186 |
187 | if train_config.use_peft:
188 | peft_config = generate_peft_config(train_config, kwargs)
189 | print(f"PEFT config: {peft_config=}")
190 | model = get_peft_model(model, peft_config)
191 |
192 | # Llama2D weight initialization code
193 |
194 | trainable_params_before, _ = model.get_nb_trainable_parameters()
195 |
196 | print("--------IGNORE POS EMBEDS IS FALSE--------")
197 | for k, v in model.named_parameters():
198 | if k.endswith(".lbd"):
199 | v.requires_grad = True
200 | print(k, "requires_grad=", v.requires_grad, v)
201 |
202 | trainable_params_after, _ = model.get_nb_trainable_parameters()
203 | assert trainable_params_after > trainable_params_before, (
204 | "Looks like lambda gating parameter isn't marked as trainable."
205 | f" Before: {trainable_params_before}, after: {trainable_params_after}"
206 | )
207 |
208 | model.print_trainable_parameters()
209 | else:
210 | for k, v in model.named_parameters():
211 | if k.endswith(".lbd"):
212 | v.requires_grad = v.data.requires_grad = True
213 | print(k, "requires_grad=", v.requires_grad, v.data)
214 |
215 | # setting up FSDP if enable_fsdp is enabled
216 | if train_config.enable_fsdp:
217 | if not train_config.use_peft and train_config.freeze_layers:
218 | freeze_transformer_layers(train_config.num_freeze_layers)
219 |
220 | mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
221 | my_auto_wrapping_policy = fsdp_auto_wrap_policy(
222 | model, LlamaDecoderLayer, PositionEmbeddingRandom
223 | )
224 |
225 | model = FSDP(
226 | model,
227 | auto_wrap_policy=my_auto_wrapping_policy
228 | if train_config.use_peft
229 | else wrapping_policy,
230 | cpu_offload=CPUOffload(offload_params=True)
231 | if fsdp_config.fsdp_cpu_offload
232 | else None,
233 | mixed_precision=mixed_precision_policy
234 | if not fsdp_config.pure_bf16
235 | else None,
236 | sharding_strategy=fsdp_config.sharding_strategy,
237 | device_id=torch.cuda.current_device(),
238 | limit_all_gathers=True,
239 | sync_module_states=train_config.low_cpu_fsdp,
240 | param_init_fn=lambda module: module.to_empty(
241 | device=torch.device("cuda"), recurse=False
242 | )
243 | if train_config.low_cpu_fsdp and rank != 0
244 | else None,
245 | )
246 | if fsdp_config.fsdp_activation_checkpointing:
247 | apply_fsdp_checkpointing(model)
248 | elif not train_config.quantization and not train_config.enable_fsdp:
249 | model.to("cuda")
250 |
251 | train_sampler = None
252 | val_sampler = None
253 | if train_config.enable_fsdp:
254 | train_sampler = DistributedSampler(
255 | dataset_train,
256 | rank=dist.get_rank(),
257 | num_replicas=dist.get_world_size(),
258 | shuffle=True,
259 | )
260 | if train_config.run_validation:
261 | val_sampler = DistributedSampler(
262 | dataset_val,
263 | rank=dist.get_rank(),
264 | num_replicas=dist.get_world_size(),
265 | )
266 |
267 | # Create DataLoaders for the training and validation dataset
268 | train_dataloader = torch.utils.data.DataLoader(
269 | dataset_train,
270 | batch_size=train_config.batch_size_training,
271 | num_workers=train_config.num_workers_dataloader,
272 | pin_memory=True,
273 | sampler=train_sampler if train_sampler else None,
274 | drop_last=True,
275 | collate_fn=default_data_collator,
276 | )
277 |
278 | eval_dataloader = None
279 | if train_config.run_validation:
280 | eval_dataloader = torch.utils.data.DataLoader(
281 | dataset_val,
282 | batch_size=train_config.val_batch_size,
283 | num_workers=train_config.num_workers_dataloader,
284 | pin_memory=True,
285 | sampler=val_sampler if val_sampler else None,
286 | drop_last=True,
287 | collate_fn=default_data_collator,
288 | )
289 |
290 | # Initialize the optimizer and learning rate scheduler
291 |
292 | # make custom param groups
293 |     group_substrs = {
294 |         "lambda": [train_config.lambda_lr, "lbd"],
295 |         "point_embed": [train_config.point_embed_lr, "is_a_point_embed"],
296 |     }
297 |     param_groups = []
298 |     for n, p in model.named_parameters():
299 |         for group_name, (lr, substr) in group_substrs.items():
300 |             if substr in n:
301 |                 param_groups.append({"params": [p], "lr": lr})
302 |                 break
303 |         else:  # for/else: no group substring matched, fall back to the default lr
304 |             param_groups.append({"params": [p], "lr": train_config.lr})
305 |
306 |
307 | if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
308 | optimizer = AnyPrecisionAdamW(
309 | param_groups,
310 | momentum_dtype=torch.bfloat16,
311 | variance_dtype=torch.bfloat16,
312 | use_kahan_summation=False,
313 | weight_decay=train_config.weight_decay,
314 | )
315 | else:
316 | optimizer = optim.AdamW(
317 | param_groups,
318 | weight_decay=train_config.weight_decay,
319 | )
320 | scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
321 |
322 | if train_config.num_epochs > 0:
323 | # Start the training process
324 | results = train(
325 | model,
326 | train_dataloader,
327 | eval_dataloader,
328 | tokenizer,
329 | optimizer,
330 | scheduler,
331 | train_config.gradient_accumulation_steps,
332 | train_config,
333 | fsdp_config if train_config.enable_fsdp else None,
334 | local_rank if train_config.enable_fsdp else None,
335 | rank if train_config.enable_fsdp else None,
336 | kwargs,
337 | )
338 | if not train_config.enable_fsdp or rank == 0:
339 | [print(f"Key: {k}, Value: {v}") for k, v in results.items()]
340 | else:
341 | print("Skipping training")
342 |
343 | # print lambda values
344 | print("-----Lambda gating values-------")
345 | with FSDP.summon_full_params(
346 | model, rank0_only=True, writeback=False, with_grads=False
347 | ):
348 | print("-----full-params Lambda gating values-------")
349 | for k, v in model.named_parameters():
350 | if k.endswith(".lbd"):
351 | print(k, v.data)
352 | print("--------------------------------")
353 |
354 |
355 | if __name__ == "__main__":
356 | fire.Fire(main)
357 |
--------------------------------------------------------------------------------
/src/llama2d/modal/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 |
4 | from common import BASE_MODELS, VOLUME_CONFIG, stub
5 | from modal import Image, gpu, method
6 |
7 | tgi_image = (
8 | Image.from_registry("ghcr.io/huggingface/text-generation-inference:1.0.3")
9 | .dockerfile_commands("ENTRYPOINT []")
10 | .pip_install("text-generation", "transformers>=4.33.0")
11 | # .pip_install("git+https://github.com/Llama2D/transformers")
12 | .env(dict(HUGGINGFACE_HUB_CACHE="/pretrained"))
13 | )
14 |
15 |
16 | @stub.function(image=tgi_image, volumes=VOLUME_CONFIG, timeout=60 * 20)
17 | def merge(run_id: str, commit: bool = False):
18 | from text_generation_server.utils.peft import download_and_unload_peft
19 |
20 | os.mkdir(f"/results/{run_id}/merged")
21 | subprocess.call(f"cp /results/{run_id}/*.* /results/{run_id}/merged", shell=True)
22 |
23 | print(f"Merging weights for fine-tuned {run_id=}.")
24 | download_and_unload_peft(f"/results/{run_id}/merged", None, False)
25 |
26 | if commit:
27 | print("Committing merged model permanently (can take a few minutes).")
28 | stub.results_volume.commit()
29 |
30 |
31 | @stub.cls(
32 | image=tgi_image,
33 | gpu=gpu.A100(count=1, memory=40),
34 | allow_concurrent_inputs=100,
35 | volumes=VOLUME_CONFIG,
36 | )
37 | class Model:
38 | def __init__(self, base: str = "", run_id: str = ""):
39 | import socket
40 | import time
41 |
42 | from text_generation import AsyncClient
43 |
44 | model = f"/results/{run_id}/merged" if run_id else BASE_MODELS[base]
45 |
46 | if run_id and not os.path.isdir(model):
47 | merge.local(run_id) # local = run in the same container
48 |
49 | print(f"Loading {model} into GPU ... ")
50 | launch_cmd = ["text-generation-launcher", "--model-id", model, "--port", "8000"]
51 | self.launcher = subprocess.Popen(launch_cmd, stdout=subprocess.DEVNULL)
52 |
53 | self.client = None
54 | while not self.client and self.launcher.returncode is None:
55 | try:
56 | socket.create_connection(("127.0.0.1", 8000), timeout=1).close()
57 | self.client = AsyncClient("http://127.0.0.1:8000", timeout=60)
58 | except (socket.timeout, ConnectionRefusedError):
59 | time.sleep(1.0)
60 |
61 | assert self.launcher.returncode is None
62 |
63 | def __exit__(self, _exc_type, _exc_value, _traceback):
64 | self.launcher.terminate()
65 |
66 | @method()
67 | async def generate(self, prompt: str):
68 | result = await self.client.generate(prompt, max_new_tokens=512)
69 |
70 | return result.generated_text
71 |
72 |
73 | @stub.local_entrypoint()
74 | def main(prompt: str, base: str, run_id: str = "", batch: int = 1):
75 | print(f"Running completion for prompt:\n{prompt}")
76 |
77 | print("=" * 20 + "Generating without adapter" + "=" * 20)
78 | for output in Model(base).generate.map([prompt] * batch):
79 | print(output)
80 |
81 | if run_id:
82 | print("=" * 20 + "Generating with adapter" + "=" * 20)
83 | for output in Model(base, run_id).generate.map([prompt] * batch):
84 | print(output)
85 |
--------------------------------------------------------------------------------
/src/llama2d/modal/repro.py:
--------------------------------------------------------------------------------
1 | from common import transformers_dir, llama_recipes_dir, root_dir
2 | import os
3 | import sys
4 |
5 | def check_all_code_committed(dir):
6 |
7 | old_dir = os.getcwd()
8 | os.chdir(dir)
9 |
10 | # assert that all code in current directory is committed
11 |     git_diff = os.popen("git diff").read()
12 | git_diff_cached = os.popen("git diff --cached").read()
13 |
14 | dir_name = os.path.basename(dir)
15 | assert (
16 | git_diff == "" and git_diff_cached == ""
17 | ), f"Please commit all code in {dir_name} before running this script."
18 |
19 |     git_commit_hash = os.popen("git rev-parse HEAD").read().strip()
20 |
21 |     # restore the original working directory
22 |     os.chdir(old_dir)
23 |
24 | return git_commit_hash
25 |
26 | def check_llama2d_code():
27 | llama2d = check_all_code_committed(root_dir)
28 | transformers = check_all_code_committed(transformers_dir)
29 | llama_recipes = check_all_code_committed(llama_recipes_dir)
30 |
31 | return {
32 | "llama2d": llama2d,
33 | "transformers": transformers,
34 | "llama_recipes": llama_recipes,
35 | }
36 |
37 | def make_repro_command():
38 | commits = check_llama2d_code()
39 |
40 | # get full command line command
41 | command = " ".join(sys.argv)
42 |
43 | # TODO: fill in HF dataset name if it's not there
44 |
45 | return f"""
46 | # run in llama2d
47 | git checkout {commits["llama2d"]}
48 | cd transformers && git checkout {commits["transformers"]}
49 | cd ../llama-recipes && git checkout {commits["llama_recipes"]}
50 | cd src/llama2d/modal
51 | {command}
52 | """
--------------------------------------------------------------------------------
/src/llama2d/modal/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohttp==3.8.5
2 | aiosignal==1.3.1
3 | aiostream==0.4.5
4 | annotated-types==0.5.0
5 | anyio==3.7.1
6 | asgiref==3.7.2
7 | async-timeout==4.0.3
8 | attrs==23.1.0
9 | certifi==2023.7.22
10 | charset-normalizer==3.2.0
11 | click==8.1.7
12 | cloudpickle==2.0.0
13 | datasets==2.14.5
14 | dill==0.3.7
15 | exceptiongroup==1.1.3
16 | fastapi==0.103.1
17 | filelock==3.12.4
18 | frozenlist==1.4.0
19 | fsspec==2023.6.0
20 | grpclib==0.4.3
21 | h2==4.1.0
22 | hpack==4.0.0
23 | huggingface-hub==0.17.1
24 | hyperframe==6.0.1
25 | idna==3.4
26 | importlib-metadata==6.8.0
27 | markdown-it-py==3.0.0
28 | mdurl==0.1.2
29 | modal==0.52.3439
30 | multidict==6.0.4
31 | multiprocess==0.70.15
32 | numpy==1.24.4
33 | packaging==23.1
34 | pandas==2.0.3
35 | protobuf==4.24.3
36 | pyarrow==13.0.0
37 | pydantic==2.3.0
38 | pydantic_core==2.6.3
39 | Pygments==2.16.1
40 | python-dateutil==2.8.2
41 | pytz==2023.3.post1
42 | PyYAML==6.0.1
43 | regex==2023.8.8
44 | requests==2.31.0
45 | rich==13.5.2
46 | safetensors==0.3.3
47 | sigtools==4.0.1
48 | six==1.16.0
49 | sniffio==1.3.0
50 | starlette==0.27.0
51 | synchronicity==0.5.3
52 | tblib==2.0.0
53 | tokenizers==0.13.3
54 | toml==0.10.2
55 | tqdm==4.66.1
56 | transformers @ git+https://github.com/Llama2D/transformers@cdffed967e6941bf72f333b33b599da601cb21d8
57 | typer==0.9.0
58 | types-certifi==2021.10.8.3
59 | types-toml==0.10.8.7
60 | typing_extensions==4.7.1
61 | tzdata==2023.3
62 | urllib3==2.0.4
63 | watchfiles==0.20.0
64 | xxhash==3.3.0
65 | yarl==1.9.2
66 | zipp==3.16.2
67 |
--------------------------------------------------------------------------------
/src/llama2d/modal/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | from repro import make_repro_command
5 | from common import BASE_MODELS, GPU_MEM, N_GPUS, VOLUME_CONFIG, stub
6 | from modal import Mount, Secret, gpu
7 |
8 | # add llama2d to path
9 | sys.path.append(f"{os.path.dirname(os.path.realpath(__file__))}/../../.")
10 | import llama2d
11 |
12 |
13 | @stub.function(
14 | volumes=VOLUME_CONFIG,
15 | memory=1024 * 100,
16 | timeout=3600 * 4,
17 | secrets=[Secret.from_name("huggingface")],
18 | )
19 | def download(model_name: str):
20 | assert (
21 | "HUGGINGFACE_TOKEN" in os.environ
22 | ), "Please set the HUGGINGFACE_TOKEN environment variable."
23 | from huggingface_hub.hf_api import HfFolder
24 |
25 | HfFolder.save_token(os.environ["HUGGINGFACE_TOKEN"])
26 |
27 | from huggingface_hub import snapshot_download
28 |
29 | from transformers.utils import move_cache
30 |
31 | try:
32 | snapshot_download(model_name, local_files_only=True)
33 | print(f"Volume contains {model_name}.")
34 | except FileNotFoundError:
35 | print(f"Downloading {model_name} (no progress bar) ...")
36 | snapshot_download(model_name)
37 | move_cache()
38 |
39 | print("Committing /pretrained directory (no progress bar) ...")
40 | stub.pretrained_volume.commit()
41 |
42 |
43 | def library_entrypoint(config):
44 | import os
45 |
46 | print(os.getcwd(), os.listdir())
47 | assert (
48 | "HUGGINGFACE_TOKEN" in os.environ
49 | ), "Please set the HUGGINGFACE_TOKEN environment variable."
50 | from huggingface_hub.hf_api import HfFolder
51 |
52 | HfFolder.save_token(os.environ["HUGGINGFACE_TOKEN"])
53 |
54 | print(config)
55 | from finetuning import main
56 |
57 | from transformers import LlamaConfig, LlamaForCausalLM
58 |
59 | # from llama2d.model.modeling_llama import Llama2DForCausalLM
60 | # from llama2d.model.configuration_llama import Llama2DConfig
61 | # from llama2d.model.modeling_llama_old import LlamaForCausalLM
62 | # from llama2d.model.configuration_llama_old import LlamaConfig
63 |
64 | Llama = LlamaForCausalLM
65 | # LlamaConfig = Llama2DConfig
66 |
67 | main(Llama, LlamaConfig, **config)
68 |
69 |
70 | @stub.function(
71 | volumes=VOLUME_CONFIG,
72 | mounts=[
73 | Mount.from_local_dir("./datasets", remote_path="/root"),
74 | ],
75 | gpu=gpu.A100(count=N_GPUS, memory=GPU_MEM),
76 | timeout=3600 * 12,
77 | )
78 | def train(train_kwargs):
79 | from torch.distributed.run import config_from_args, elastic_launch, parse_args
80 |
81 | torch_args = parse_args(["--nnodes", "1", "--nproc_per_node", str(N_GPUS), ""])
82 | print(f"{torch_args=}\n{train_kwargs=}")
83 |
84 | elastic_launch(
85 | config=config_from_args(torch_args)[0],
86 | entrypoint=library_entrypoint,
87 | )(train_kwargs)
88 |
89 | print("Committing results volume (no progress bar) ...")
90 | stub.results_volume.commit()
91 |
92 | @stub.local_entrypoint() # Runs locally to kick off remote training job.
93 | def main(
94 | dataset: str,
95 | base: str = "base7",
96 | run_id: str = "",
97 | num_epochs: int = 1,
98 | batch_size: int = 16,
99 | repo: str = "llama2d/llama2d-mind2web",
100 | keep_fraction: float = 1.0,
101 | seed: int = 0,
102 |
103 | peft: bool = False,
104 | use_2d: bool = True,
105 | use_point_embed: bool = True,
106 | lbd_start_value: float = 0.0,
107 | lr: float = 3e-5,
108 | lambda_lr: float = 3e-2,
109 | point_embed_lr: float = 3e-5,
110 | separate_point_embed: bool = False,
111 |
112 | # wandb args
113 | group: str = None,
114 | name: str = None,
115 | ):
116 | print("Welcome to Modal Llama fine-tuning.")
117 | print(f"Dataset is {dataset}.")
118 |
119 | model_name = BASE_MODELS[base]
120 | print(f"Syncing base model {model_name} to volume.")
121 | download.remote(model_name)
122 |
123 |     cmd = make_repro_command()
124 |     print(cmd)
125 |     # raise Exception("Done")  # debug guard: uncomment to stop after printing the repro command
126 |
127 | if not run_id:
128 | import secrets
129 |
130 | run_id = f"{base}-{secrets.token_hex(3)}"
131 | elif not run_id.startswith(base):
132 | run_id = f"{base}-{run_id}"
133 |
134 | print(f"Beginning run {run_id=}.")
135 | train.remote(
136 | {
137 | "model_name": BASE_MODELS[base],
138 | "output_dir": f"/results/{run_id}",
139 | "batch_size_training": batch_size,
140 | "lr": lr,
141 | "lambda_lr": lambda_lr,
142 | "num_epochs": num_epochs,
143 | "val_batch_size": 1,
144 | # --- Dataset options ---
145 | "dataset": "custom_dataset",
146 | "custom_dataset.file": dataset,
147 | # --- FSDP options ---
148 | "enable_fsdp": True,
149 | "low_cpu_fsdp": True, # Optimization for FSDP model loading (RAM won't scale with num GPUs) # noqa
150 | "fsdp_config.use_fast_kernels": True, # Only works when FSDP is on
151 | "fsdp_config.fsdp_activation_checkpointing": True, # Activation checkpointing for fsdp # noqa
152 | "pure_bf16": True,
153 | # --- Required for 70B ---
154 | "fsdp_config.fsdp_cpu_offload": True,
155 | "fsdp_peft_cpu_offload_for_save": True, # Experimental
156 | # --- PEFT options ---
157 | "use_peft": peft,
158 | "peft_method": "lora",
159 | "lora_config.r": 8,
160 | "lora_config.lora_alpha": 16,
161 | # --- Llama2D options ---
162 | "label_names": ["coords"],
163 | "dataset_folder": "mind2web-cache",
164 | "use_2d": use_2d,
165 | "keep_fraction": keep_fraction,
166 | "repo": repo,
167 | "lbd_start_value": lbd_start_value,
168 | "seed": seed,
169 | "use_point_embed": use_point_embed,
170 | "point_embed_lr": point_embed_lr,
171 | "separate_point_embed": separate_point_embed,
172 |
173 | "group": group,
174 | "name": name,
175 | }
176 | )
177 |
178 | print(f"Training completed {run_id=}.")
179 | print(
180 | f"Test: `modal run compare.py --base {base} --run-id {run_id} --prompt '...'`."
181 | )
182 |
--------------------------------------------------------------------------------
/src/llama2d/modal/validate_dataset.py:
--------------------------------------------------------------------------------
1 | from common import BASE_MODELS, stub
2 | from llama_recipes.configs.datasets import custom_dataset
3 | from llama_recipes.utils.config_utils import update_config
4 | from llama_recipes.utils.dataset_utils import get_custom_dataset
5 | from modal import Mount
6 |
7 |
8 | @stub.function(
9 | volumes={
10 | "/pretrained": stub.pretrained_volume,
11 | "/results": stub.results_volume,
12 | },
13 | mounts=[
14 | Mount.from_local_dir("./datasets", remote_path="/root"),
15 | ],
16 | )
17 | def dataset(base: str = "chat7", dataset: str = "local_dataset.py"):
18 | from transformers import AutoTokenizer
19 |
20 | tokenizer = AutoTokenizer.from_pretrained(BASE_MODELS[base])
21 |     tokenizer.add_special_tokens({"pad_token": "<PAD>"})
22 |
23 | config = custom_dataset()
24 | update_config(config, file=dataset)
25 |
26 | BLOCK = "=" * 20
27 |
28 | for split in [config.train_split, config.test_split]:
29 | dataset = get_custom_dataset(config, tokenizer, split)
30 | print(f"{split}: {len(dataset)} sequences")
31 |
32 | sample = tokenizer.decode(dataset[0]["input_ids"])[:500]
33 | print(f"{BLOCK} Sample {BLOCK}\n{sample} ...")
34 | print(f"{BLOCK} Tokens {BLOCK}\n{dataset[0]['input_ids'][:25]} ...\n")
35 |
--------------------------------------------------------------------------------
/src/llama2d/tagging/add_tags_to_page.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from dataclasses import dataclass
4 | from typing import List, Tuple
5 |
6 |
7 | @dataclass
8 | class TagAndBox:
9 | word: str
10 | coords: Tuple[int, int]
11 |
12 |
13 | def add_tags_to_webpage(page, mind2web_action) -> Tuple[int, List[TagAndBox]]:
14 | """
15 | Add visual tags to a webpage, and find the tag # of the desired Mind2Web action.
16 | A visual tag looks like [12] and is superimposed on buttons, textboxes, links, etc.
17 | """
18 |
19 | attrss = [
20 | json.loads(pos_candidate["attributes"])
21 | for pos_candidate in mind2web_action["pos_candidates"]
22 | ]
23 |
24 | els = []
25 | for attrs in attrss:
26 | cls = attrs.get("class", None)
27 | tag_id = attrs.get("id", None)
28 | bbox_rect = [float(i) for i in attrs["bounding_box_rect"].split(",")]
29 | els.append({"cls": cls, "tag_id": tag_id, "bbox_rect": bbox_rect})
30 |
31 | raw_html = mind2web_action["raw_html"]
32 |
33 | # print(f"Looking for element with class {cls}
34 | # and id {tag_id} and bbox {bbox_rect}")
35 |
36 | curr_dir = os.path.dirname(os.path.realpath(__file__))
37 | with open(f"{curr_dir}/tagUtils.js", "r") as f:
38 | page.evaluate(f.read())
39 |
40 | try:
41 | to_eval = f"tagifyWebpage({json.dumps(els)},true,{json.dumps(raw_html)})"
42 | gt_tag_id, el_tags = page.evaluate(to_eval)
43 | except Exception as e:
44 |         # re-raise with the failing JS snippet included for easier debugging
45 |         raise Exception(f"Error evaluating:\n{to_eval}\n{e}")
46 |
47 | assert isinstance(gt_tag_id, int), f"gt_tag_id is {json.dumps(gt_tag_id)}!"
48 |
49 | return gt_tag_id, [TagAndBox(**i) for i in el_tags]
50 |
51 |
52 | if __name__ == "__main__":
53 | from playwright.sync_api import sync_playwright
54 |
55 | with sync_playwright() as p:
56 | browser = p.chromium.launch(headless=False)
57 | page = browser.new_page()
58 |
59 | # get path to current file
60 | curr_dir = os.path.dirname(os.path.realpath(__file__))
61 | example_mhtml_path = f"{curr_dir}/../../data/mind2web_example.mhtml"
62 | example_json_path = f"{curr_dir}/../../data/mind2web_example.json"
63 | page.goto(f"file://{example_mhtml_path}")
64 | with open(example_json_path, "r") as f:
65 | dummy_action = json.load(f)
66 |
67 | try:
68 | print(add_tags_to_webpage(page, dummy_action))
69 | except Exception as e:
70 | print(e)
71 |
72 | input("Press enter to stop the program")
73 |
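As a sketch, `add_tags_to_webpage` returns the ground-truth tag number plus one `TagAndBox` per tagged interactable element; the values below are hypothetical:

```py
# Hypothetical return value of add_tags_to_webpage (coords are element centers in px):
example = (
    3,  # gt_tag_id: the [N] marker assigned to the Mind2Web ground-truth element
    [
        TagAndBox(word="[0] ", coords=(412, 88)),
        TagAndBox(word="[1] ", coords=(120, 301)),
        TagAndBox(word="[3] ", coords=(640, 512)),
    ],
)
```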
--------------------------------------------------------------------------------
/src/llama2d/tagging/tagUtils.js:
--------------------------------------------------------------------------------
1 | const assert = (condition, message) => {
2 | if(!condition) throw new Error(message)
3 | }
4 |
5 | const elIsClean = (el) => {
6 | if(el.style && el.style.display === 'none') return false
7 | if(el.hidden) return false
8 | if(el.disabled) return false
9 |
10 | const rect = el.getBoundingClientRect()
11 | if(rect.width === 0 || rect.height === 0) return false
12 |
13 | if(el.tagName === 'SCRIPT') return false
14 | if(el.tagName === 'STYLE') return false
15 |
16 | return true;
17 | }
18 |
19 | const isNotCovered = (el) => {
20 | const rect = el.getBoundingClientRect()
21 | const elCenter = [rect.left + rect.width/2, rect.top + rect.height/2];
22 |
23 | const elAtPoint = document.elementFromPoint(...elCenter)
24 |
25 | return el.contains(elAtPoint)
26 | }
27 | const isInteractiveCursor = (el) => ["pointer","text"].includes(String(el.computedStyleMap().get("cursor"))) // computedStyleMap() returns a CSSStyleValue, not a string
28 |
29 | const inputs = ['a', 'button', 'textarea', 'select', 'details', 'label']
30 | const _isInteractible = (el) => (inputs.includes(el.tagName.toLowerCase()) ||
31 | (el.tagName.toLowerCase() === 'input' && el.type !== 'hidden') ||
32 | el.role === 'button' ||
33 | isInteractiveCursor(el) && !(el.parentElement && isInteractiveCursor(el.parentElement))) && isNotCovered(el)
34 |
35 | const isInteractible = (el) => _isInteractible(el) || el.parentElement && isInteractible(el.parentElement);
36 |
37 | const emptyTagWhitelist = ["input","textarea","select","button","a"]
38 | const isEmpty = (el) => {
39 |
40 | const bbox = el.getBoundingClientRect()
41 | // check if center of element is offscreen
42 | const center = [bbox.left + bbox.width/2, bbox.top + bbox.height/2]
43 | if(center[0] < 0 || center[0] > window.innerWidth || center[1] < 0 || center[1] > window.innerHeight) return true
44 |
45 | const tagName = el.tagName.toLowerCase()
46 | if(emptyTagWhitelist.includes(tagName)) return false
47 | if("innerText" in el && el.innerText.trim().length === 0) {
48 | // look for svg or img in the element
49 | const svg = el.querySelector("svg")
50 | const img = el.querySelector("img")
51 |
52 | if(svg || img) return false
53 |
54 | return true
55 | }
56 |
57 | return false
58 | }
59 |
60 | window.tagifyWebpageOneEl = (gtCls, gtId, gtBbox) => tagifyWebpage([{
61 | cls: gtCls,
62 | tag_id: gtId,
63 | bbox_rect: gtBbox
64 | }])
65 |
66 | const convertHoverToCls = () => {
67 | [...document.styleSheets].forEach(sheet=>{
68 | try{
69 | [...sheet.cssRules].forEach(rule=>{
70 | if(rule.selectorText) rule.selectorText = rule.selectorText.replace(/:hover/g,".mind2web-hover")
71 | })
72 | } catch(err){
73 | if(!(err+"").includes("Cannot access rules")) throw err;
74 | }
75 | })
76 | }
77 |
78 | window.tagifyWebpage = (gtEls,useGt=true,rawHtml="") =>{
79 |
80 | // Populate mHTML input values with raw_html from action JSON
81 | if(rawHtml.length>0){
82 | // parse html
83 | const parser = new DOMParser();
84 | const htmlDoc = parser.parseFromString(rawHtml, 'text/html');
85 |
86 | [...htmlDoc.querySelectorAll("[input_value], [input_checked]")].forEach(el=>{
87 | if(el.attributes.bounding_box_rect.value==="-1,-1,-1,-1") return;
88 |
89 | // get the position of the input on the page
90 | const classNames = [...el.classList].map(cls=>"."+cls).join("");
91 |
92 | const id = [el.id].filter(e=>e).map(id=>"#"+id)
93 | console.log(el.id,el.attributes.id)
94 | const tag = el.tagName.toLowerCase();
95 |
96 | const selector = `${tag}${classNames}${id}`;
97 |
98 | const fragmentMatches = htmlDoc.querySelectorAll(selector)
99 | const numMatchesInFragment = fragmentMatches.length;
100 | const fragmentIdx = [...fragmentMatches].indexOf(el);
101 |
102 | if(fragmentIdx<0) throw new Error("Could not find element with its own selector");
103 |
104 | const docMatches = document.querySelectorAll(selector);
105 | if(docMatches.length != fragmentMatches.length) throw new Error(`Mismatched lengths: ${docMatches.length} vs. ${fragmentMatches.length}: ${selector}`);
106 | const docEl = docMatches[fragmentIdx];
107 |
108 | // if has input_value, set docEl.value
109 | if("input_value" in el.attributes) {
110 | docEl.value = el.attributes.input_value.value;
111 | }
112 | else if("input_checked" in el.attributes) docEl.checked = el.attributes.input_checked.value;
113 | else {
114 | throw new Error("didn't find things");
115 | }
116 |
117 | })
118 | }
119 |
120 | convertHoverToCls();
121 |
122 | let numTagsSoFar = 0;
123 |
124 | let gtCandidates = [];
125 |
126 | let elTags = [];
127 |
128 | const validEls = new Set();
129 | const hasValidParent = el => validEls.has(el) || (el.parentElement && hasValidParent(el.parentElement));
130 |
131 | for(let el of document.body.querySelectorAll("*")){
132 |
133 | const stringifiedClasses = el.classList.toString();
134 |
135 | const gtMatches = gtEls.filter(({cls,tag_id,bbox_rect})=>(cls===null || stringifiedClasses===cls) && (tag_id===null || el.id === tag_id));
136 | const isGt = gtMatches.length > 0;
137 |
138 | el.classList.add("mind2web-hover")
139 |
140 | const empty = isEmpty(el);
141 | const dirty = !elIsClean(el);
142 | const uninteractible = !isInteractible(el);
143 | const validParent = hasValidParent(el)
144 |
145 | el.classList.remove("mind2web-hover")
146 |
147 | if(logElements.includes(el)) {
148 | console.log(`Logging ${el.innerText}, ${empty},${dirty},${uninteractible},${validParent}`)
149 | }
150 |
151 | const isGood = !(empty || dirty || uninteractible) || validParent;
152 | if(isGood) validEls.add(el);
153 |
154 | if(!isGood){
155 | if(isGt) console.log("Skipping!", el,`empty: ${empty}, dirty: ${dirty}, uninteractible: ${uninteractible}, validParent: ${validParent}`);
156 | continue;
157 | }
158 |
159 | const elBbox = el.getBoundingClientRect();
160 | const elCenter = [elBbox.left + elBbox.width/2, elBbox.top + elBbox.height/2];
161 |
162 | // get closest el in elTags
163 | const [closestDist,closestEl] = elTags.map(({coords})=>coords).map(([x,y])=>Math.sqrt((x-elCenter[0])*(x-elCenter[0]) + (y-elCenter[1])*(y-elCenter[1]))).reduce((acc,cur,i)=>cur<acc[0] ? [cur,i] : acc, [Infinity,-1]);
164 | const useNewTag = closestDist > 5;
165 |
166 | if(isGt){
167 | const gtTagId = useNewTag ? numTagsSoFar : closestEl;
168 | console.log("Tagging GT!", el);
169 | gtCandidates.push({
170 | el,
171 | tagId: gtTagId,
172 | stats:{empty, dirty, uninteractible, validParent},
173 | gtEls: gtMatches
174 | });
175 | }
176 |
177 | if(useNewTag){
178 |
179 | const tagStr = `[${numTagsSoFar}] `
180 |
181 | elTags.push({
182 | word:tagStr,
183 | coords:elCenter,
184 | })
185 | validEls.add(el);
186 |
187 | numTagsSoFar++;
188 | }
189 | }
190 | console.log(validEls)
191 |
192 | if(!useGt) return [null, elTags];
193 |
194 |
195 |
196 | const validGtCandidates = gtCandidates.filter(({el, stats}) => {
197 | const {empty, dirty, uninteractible, validParent} = stats
198 | return !empty && !dirty && !uninteractible || validParent
199 | })
200 |
201 | if(validGtCandidates.length === 0){
202 | console.log("No GT found!")
203 | // show stats for all candidates
204 | console.log(gtCandidates.map(({stats})=>`empty: ${stats.empty}, dirty: ${stats.dirty}, uninteractible: ${stats.uninteractible}`).join("\n"));
205 | throw new Error(`No GT found!\n${gtCandidates.map(({el})=>el.innerText).join("\n")}`)
206 | }
207 |
208 | if(validGtCandidates.length > 1){
209 | console.log("Multiple GTs found!")
210 | }
211 |
212 | const elementDistancesDeep = validGtCandidates.map(({el,gtEls}) => gtEls.map(({bbox_rect})=>bbox_rect).map((gtBbox)=>{
213 | const rect = el.getBoundingClientRect()
214 | const [x,y,w,h] = gtBbox;
215 | const gtCenter = [x+w/2, y+h/2];
216 | const elCenter = [rect.left + rect.width/2, rect.top + rect.height/2];
217 |
218 | const dx = gtCenter[0] - elCenter[0];
219 | const dy = gtCenter[1] - elCenter[1];
220 | return Math.sqrt(dx*dx + dy*dy)
221 | }))
222 |
223 | const elementDistances = elementDistancesDeep.map((distances)=>Math.min(...distances));
224 |
225 | const closestDistance = Math.min(...elementDistances);
226 | const closestElement = validGtCandidates[elementDistances.indexOf(closestDistance)];
227 |
228 | if(closestDistance > 20) {
229 | throw new Error(`Closest element is ${closestDistance}px away! Bboxes are ${validGtCandidates.map(({el})=>el.getBoundingClientRect()).map(({left, top, width, height})=>[left, top, width, height])}`);
230 | }
231 |
232 |
233 | return [closestElement.tagId, elTags];
234 | }
235 | window.logElements = []; // some elements where you can check your classification performance. useful for debugging.
236 |
237 | window.showTag = coords => {
238 | const myBox = document.createElement("div");
239 | myBox.style.width = "10px";
240 | myBox.style.height = "10px";
241 | myBox.style.background = "red";
242 | myBox.style.position = "absolute";
243 | myBox.style.top = coords[1]-5+"px";
244 | myBox.style.left = coords[0]-5+"px";
245 | myBox.textContent = "";
246 | myBox.style.zIndex = 2000
247 | document.body.appendChild(myBox)
248 | }
249 |
250 | window.demo = () => tagifyWebpage([],false)[1].forEach(({coords})=>showTag(coords))
251 | 1;
252 |
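For reference, here is a minimal sketch of driving this script from Python with Playwright. The JS file path and the injection mechanism are assumptions (the repo's own tagging pipeline in `llama2d.tagging` may wire this up differently); only `tagifyWebpage` and `demo` come from the script above:

```py
# Sketch: inject the tagging script into a live page and tag interactable elements.
# Assumption: the script above is saved at src/llama2d/tagging/tagify_webpage.js.
from playwright.sync_api import sync_playwright

with sync_playwright() as p:
    browser = p.chromium.launch()
    page = browser.new_page()
    page.goto("https://example.com")

    # make window.tagifyWebpage / window.demo available in the page
    page.add_script_tag(path="src/llama2d/tagging/tagify_webpage.js")

    # tag without ground truth: returns [None, elTags]
    _, el_tags = page.evaluate("window.tagifyWebpage([], false)")
    print(f"tagged {len(el_tags)} elements")

    # draw a red marker at each tag location and save a screenshot
    page.evaluate("window.demo()")
    page.screenshot(path="tagged.png")
    browser.close()
```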
253 |
--------------------------------------------------------------------------------
/src/llama2d/vision/__init__.py:
--------------------------------------------------------------------------------
1 | from .ocr import Llama2dScreen
2 | from .url_to_llama_input import Llama2dTokenizer
3 | from .viz_pt_input import debug_dataset
4 |
--------------------------------------------------------------------------------
/src/llama2d/vision/learn_mlp_on_embeds.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from tqdm import tqdm
5 |
6 | from transformers.models.llama.sam_embed import PositionEmbeddingRandom
7 |
8 |
9 | class CoordMlp(nn.Module):
10 | def __init__(self, n: int, hidden: int):
11 | super().__init__()
12 | self.embed = PositionEmbeddingRandom(n, torch_dtype=torch.float32)
13 | self.a = nn.Linear(n * 2, hidden)
14 | self.b = nn.Linear(hidden, 1)
15 |
16 | self.n = n
17 | self.hidden = hidden
18 |
19 | def forward(self, x):
20 | b, c, d = x.shape
21 | assert d == 2, "Coords are not 2d"
22 |
23 | max_y_el = torch.argmax(x[:, :, 1], dim=1)
24 |
25 | pos_embeds = self.embed(x).squeeze(1)
26 | assert pos_embeds.shape == (
27 | b,
28 | c,
29 | self.n * 2,
30 | ), f"Pos_embeds are {pos_embeds.shape}. vs. {(b,c,self.n*2)}"
31 |
32 | logits = self.b(F.relu(self.a(pos_embeds)))
33 |
34 | preds = logits.squeeze(dim=2)
35 | loss = F.cross_entropy(preds, F.one_hot(max_y_el, num_classes=c).to(torch.float32))
36 |
37 | return loss
38 |
39 |
40 | def learn_mlp_for_top_point():
41 | rand_points = torch.rand((100, 50, 2))
42 |
43 | model = CoordMlp(100, 100)
44 | params = model.parameters()
45 | lr = 3e-2
46 | optimizer = torch.optim.SGD(params, lr=lr)
47 |
48 | epochs = 500
49 | for epoch in tqdm(range(epochs)):
50 | loss = model(rand_points)
51 |
52 | print(loss.item(), loss.shape)
53 |
54 | optimizer.zero_grad()
55 | loss.backward()
56 | optimizer.step()
57 |
58 |
59 | if __name__ == "__main__":
60 | learn_mlp_for_top_point()
61 |
--------------------------------------------------------------------------------
/src/llama2d/vision/ocr.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field, replace
2 | from typing import List, Optional, Tuple
3 |
4 | from google.cloud import vision
5 |
6 | from llama2d.constants import SCREEN_RESOLUTION, SECRETS_FILE
7 |
8 |
9 | @dataclass
10 | class ImageAnnotation:
11 | text: str # the word
12 | midpoint: Tuple[float, float] # the UNNORMALIZED midpoint of the word, (X,Y)
13 | midpoint_normalized: Tuple[
14 | float, float
15 | ] # the normalized midpoint between 0 - 1 (X,Y)
16 |
17 |
18 | @dataclass
19 | class Llama2dScreen:
20 | full_text: str = "" # full text
21 | orig_text_dims: Tuple[float, float] = (
22 | 1.0,
23 | 1.0,
24 | ) # the dimension of the *TEXT PORTION* of the image
25 |
26 | words: List[ImageAnnotation] = field(
27 | default_factory=list
28 | ) # a list of words and their midpoints
29 |
30 | def __add__(self, other):
31 | assert self.orig_text_dims == other.orig_text_dims
32 | return replace(self, words=self.words + other.words)
33 |
34 | def push_word(
35 | self,
36 | word: str,
37 | # must use exactly one
38 | # all 4 corners
39 | xyxy: Optional[Tuple[float, float, float, float]] = None,
40 | # midpoint
41 | xy: Optional[Tuple[float, float]] = None,
42 | ):
43 | new = self.concat_word(word=word, xyxy=xyxy, xy=xy)
44 |
45 | self.words = new.words
46 | self.full_text = new.full_text
47 |
48 | def concat_word(
49 | self,
50 | word: str,
51 | # must use exactly one
52 | # all 4 corners
53 | xyxy: Optional[Tuple[float, float, float, float]] = None,
54 | # midpoint
55 | xy: Optional[Tuple[float, float]] = None,
56 | ):
57 | full_text = self.full_text
58 | words = self.words
59 |
60 | if len(words) > 0:
61 | full_text += " "
62 | full_text += word
63 |
64 | assert (xyxy is None) != (
65 | xy is None
66 | ), "You should specify xy (midpoint) xor xyxy (corners)."
67 | if xy is None:
68 | x = (xyxy[0] + xyxy[2]) / 2
69 | y = (xyxy[1] + xyxy[3]) / 2
70 | xy = (x, y)
71 |
72 | x, y = xy
73 | w, h = self.orig_text_dims
74 | xy_norm = (x / w, y / h)
75 |
76 | new_ann = ImageAnnotation(text=word, midpoint=xy, midpoint_normalized=xy_norm)
77 | words = words + [new_ann]
78 |
79 | return replace(self, words=words, full_text=full_text)
80 |
81 | def __getitem__(self, key: slice):
82 | assert type(key) == slice, "__getitem__ only supports slice right now"
83 | words = self.words[key]
84 |
85 | full_text = " ".join([word.text for word in words])
86 |
87 | return replace(self, words=words, full_text=full_text)
88 |
89 |
90 | width, height = SCREEN_RESOLUTION
91 |
92 |
93 | class ImageAnnotator:
94 | def __init__(self, credentials=SECRETS_FILE):
95 | if not credentials.exists():
96 | raise ValueError(
97 | f"Place the Google Cloud credentials file in {credentials}"
98 | )
99 |
100 | self.client = vision.ImageAnnotatorClient.from_service_account_file(credentials)
101 | self.__features = [vision.Feature(type_=vision.Feature.Type.TEXT_DETECTION)]
102 |
103 | def __call__(self, path):
104 | with open(path, "rb") as image_file:
105 | content = image_file.read()
106 |
107 | image = vision.Image(content=content)
108 | request = vision.AnnotateImageRequest(image=image, features=self.__features)
109 | res = self.client.annotate_image(request)
110 |
111 | full_text = res.full_text_annotation.text
112 |
113 | annotations = res.text_annotations
114 |
115 | annotations_normed = Llama2dScreen(
116 | full_text=full_text,
117 | orig_text_dims=SCREEN_RESOLUTION,
118 | )
119 | for text in annotations[1:]:
120 | xs = [vertex.x for vertex in text.bounding_poly.vertices]
121 | ys = [vertex.y for vertex in text.bounding_poly.vertices]
122 |
123 | prev_len = len(annotations_normed.words)
124 | annotations_normed.push_word(
125 | word=text.description, xyxy=[min(xs), min(ys), max(xs), max(ys)]
126 | )
127 | assert len(annotations_normed.words) == prev_len + 1
128 |
129 | # optionally, sort the words by midpoint
130 | annotations_normed.words = list(
131 | sorted(
132 | annotations_normed.words,
133 | key=lambda x: (x.midpoint_normalized[1], x.midpoint_normalized[0]),
134 | )
135 | )
136 |
137 | return annotations_normed
138 |
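A minimal usage sketch for the annotator above; the screenshot filename is only an example, and the GCP service-account file must already be at `SECRETS_FILE`:

```py
# Sketch: OCR a screenshot into a Llama2dScreen with per-word midpoints.
from llama2d.vision.ocr import ImageAnnotator

annotator = ImageAnnotator()  # reads the service-account file at SECRETS_FILE
screen = annotator("image_of_website.png")  # any screenshot path works here

print(screen.full_text[:200])
for word in screen.words[:5]:
    # midpoint is in pixels; midpoint_normalized is scaled by SCREEN_RESOLUTION
    print(word.text, word.midpoint, word.midpoint_normalized)
```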
--------------------------------------------------------------------------------
/src/llama2d/vision/render_dataset.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Llama2D/llama2d/e28b97255d396c717fe183b96b802ff39ffd7e6d/src/llama2d/vision/render_dataset.py
--------------------------------------------------------------------------------
/src/llama2d/vision/take_screenshot.py:
--------------------------------------------------------------------------------
1 | from urllib.parse import urlparse
2 |
3 | from playwright.sync_api import sync_playwright
4 |
5 | from llama2d.constants import SCREEN_RESOLUTION
6 |
7 | width, height = SCREEN_RESOLUTION
8 |
9 |
10 | def take_screenshot(page=None, url=None, save_path="image_of_website.png"):
11 | if page is None:
12 | with sync_playwright() as p:
13 | # Using the Chromium browser but you can also use 'firefox' or 'webkit'
14 | browser = p.chromium.launch()
15 | page = browser.new_page()
16 |
17 | page.set_extra_http_headers(
18 | {
19 | "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)"
20 | " AppleWebKit/537.36 (KHTML, like Gecko)"
21 | " Chrome/116.0.0.0 Safari/537.36"
22 | }
23 | )
24 |
25 | return take_screenshot(page, url, save_path)
26 |
27 | if url is not None:
28 | print("going to " + url)
29 | page.goto(url)
30 |
31 | # Set the viewport height to be the height of the content
32 | content_height = page.evaluate("document.documentElement.scrollHeight")
33 | thresholded_height = min(content_height, height)
34 |
35 | page.set_viewport_size({"width": width, "height": thresholded_height})
36 |
37 | page.screenshot(path=save_path)
38 |
39 |
40 | def extract_domain(url):
41 | parsed_uri = urlparse(url)
42 | domain = "{uri.netloc}".format(uri=parsed_uri)
43 | domain = domain.replace(".", "_")
44 | return domain
45 |
46 |
47 | if __name__ == "__main__":
48 | target_url = "https://www.mytampahomeagent.com/"
49 | # target_url = "https://www.reddit.com"
50 | path = "./extracted/" + extract_domain(target_url) + ".png"
51 | print(path)
52 |
53 | take_screenshot(url=target_url, save_path=path)
54 |
--------------------------------------------------------------------------------
/src/llama2d/vision/url_to_llama_input.py:
--------------------------------------------------------------------------------
1 | """
2 | url_to_llama_input.py
3 | Extract features using the tokenizer, including text tokens and their 2D coordinates
4 | """
5 |
6 | import tempfile
7 | from pathlib import Path
8 | from typing import Dict, List, Optional
9 |
10 | import torch
11 |
12 | from llama2d.constants import MAX_PAGE_LEN, MAX_SEQ_LEN, MAX_TAGS_LEN
13 | from llama2d.tagging.add_tags_to_page import TagAndBox
14 | from llama2d.vision.ocr import ImageAnnotator, Llama2dScreen
15 | from llama2d.vision.take_screenshot import extract_domain, take_screenshot
16 | from transformers import LlamaTokenizer
17 |
18 |
19 | class Llama2dTokenizer(object):
20 | def __init__(
21 | self,
22 | model_path: str = "decapoda-research/llama-7b-hf",
23 | separator_id=None,
24 | label_mask_id=-100,
25 | mask_out_body=True,
26 | ):
27 | self.tokenizer = LlamaTokenizer.from_pretrained(model_path)
28 |
29 | if not separator_id:
30 | self.__separator_id = (
31 | self.tokenizer.unk_token_id
32 | ) # this should be kept at 0 for most uses, as it is a special token
33 | else:
34 | self.__separator_id = separator_id
35 |
36 | self.__label_mask_id = label_mask_id
37 | self.__mask_out_body = mask_out_body
38 |
39 | def process(
40 | self, prompt: str, screen: Llama2dScreen, output: str
41 | ) -> Dict[str, torch.Tensor]:
42 | # output tokens
43 | output_tokens = self.tokenizer.tokenize(output)
44 | # and use (-1,-1) for the 2d embeddings of the output tokens
45 | output_tokens_locs = [(-1, -1) for _ in range(len(output_tokens))]
46 |
47 | # extract tokens
48 | image_tokens = [self.tokenizer.tokenize(i.text) for i in screen.words]
49 | # and, correspondingly, get their midpoints. If a word is broken up into
50 | # multiple pieces by the BPE, we return multiple of the word's location
51 | image_token_locs = [
52 | [annot.midpoint_normalized for j in range(len(i))]
53 | for i, annot in zip(image_tokens, screen.words)
54 | ]
55 |
56 | # extract tokens from the prompt
57 | prompt_tokens = self.tokenizer.tokenize(prompt)
58 | # and use (-1,-1) for the 2d embeddings for the prompt
59 | prompt_tokens_locs = [(-1, -1) for _ in range(len(prompt_tokens))]
60 |
61 | # and now we stitch it together
62 | input_ids = (
63 | [self.tokenizer.bos_token_id]
64 | + self.tokenizer.convert_tokens_to_ids(prompt_tokens) # bos token
65 | + [self.__separator_id]
66 | + self.tokenizer.convert_tokens_to_ids( # separating prompt from context
67 | [j for i in image_tokens for j in i]
68 | )
69 | + [self.__separator_id]
70 | + self.tokenizer.convert_tokens_to_ids( # separating context from answer
71 | output_tokens
72 | )
73 | )
74 |
75 | # mask out the prompt
76 | label_ids = (
77 | [self.tokenizer.bos_token_id]
78 | + [-100 for _ in range(len(prompt_tokens))] # bos token
79 | + [-100] # we do not want to predict the prompt
80 | + [ # separating prompt from context
81 | -100 if self.__mask_out_body else k
82 | for k in self.tokenizer.convert_tokens_to_ids(
83 | [j for i in image_tokens for j in i]
84 | )
85 | ]
86 | + [-100]
87 | + self.tokenizer.convert_tokens_to_ids( # separating context from answer
88 | output_tokens
89 | )
90 | )
91 |
92 | # and now we stitch together the image locs
93 | input_coords = (
94 | [(-1, -1)]
95 | + prompt_tokens_locs # bos token
96 | + [(-1, -1)]
97 | + [j for i in image_token_locs for j in i] # for the separator
98 | + [(-1, -1)]
99 | + output_tokens_locs # for the separator
100 | )
101 | input_coords = torch.tensor(input_coords)
102 | input_ids = torch.tensor(input_ids)
103 | label_ids = torch.tensor(label_ids)
104 |
105 | attention_mask = torch.ones_like(input_ids)
106 |
107 | assert (
108 | len(input_ids) == len(label_ids) == len(input_coords) == len(attention_mask)
109 | ), (
110 | f"len(input_ids) = {len(input_ids)}, len(label_ids) = {len(label_ids)},"
111 | f" len(input_coords) = {len(input_coords)},"
112 | f" len(attention_mask) = {len(attention_mask)}"
113 | )
114 |
115 | # pad or truncate
116 | if len(input_ids) > MAX_SEQ_LEN:
117 | input_ids = input_ids[:MAX_SEQ_LEN]
118 | label_ids = label_ids[:MAX_SEQ_LEN]
119 | input_coords = input_coords[:MAX_SEQ_LEN]
120 | attention_mask = attention_mask[:MAX_SEQ_LEN]
121 | elif len(input_ids) < MAX_SEQ_LEN:
122 | # right-pad label_ids with -100,
123 | # input_coords with (-1,-1), and input_ids with 0
124 | input_ids = torch.cat(
125 | [input_ids, torch.zeros(MAX_SEQ_LEN - len(input_ids), dtype=torch.long)]
126 | )
127 | label_ids = torch.cat(
128 | [
129 | label_ids,
130 | torch.ones(MAX_SEQ_LEN - len(label_ids), dtype=torch.long)
131 | * self.__label_mask_id,
132 | ]
133 | )
134 | input_coords = torch.cat(
135 | [input_coords, torch.ones(MAX_SEQ_LEN - len(input_coords), 2) * -1]
136 | ).to(torch.float16)
137 | attention_mask = torch.cat(
138 | [
139 | attention_mask,
140 | torch.zeros(MAX_SEQ_LEN - len(attention_mask), dtype=torch.long),
141 | ]
142 | )
143 |
144 | # assert all tensors are the desired length
145 | assert len(input_ids) == MAX_SEQ_LEN, f"len(input_ids) = {len(input_ids)}"
146 | assert len(label_ids) == MAX_SEQ_LEN, f"len(label_ids) = {len(label_ids)}"
147 | assert (
148 | len(input_coords) == MAX_SEQ_LEN
149 | ), f"len(input_coords) = {len(input_coords)}"
150 | assert (
151 | len(attention_mask) == MAX_SEQ_LEN
152 | ), f"len(attention_mask) = {len(attention_mask)}"
153 |
154 | # return output
155 | return {
156 | "input_ids": input_ids.to(torch.long),
157 | "coords": input_coords.to(torch.float16),
158 | "labels": label_ids.to(torch.long),
159 | "attention_mask": attention_mask.to(torch.long),
160 | }
161 |
162 |
163 | class Llama2dWebsiteFeatureExtractor(object):
164 | def __init__(
165 | self,
166 | **kwargs,
167 | ): # -100 is default
168 | self.tokenizer = Llama2dTokenizer(**kwargs)
169 | self.__annotator = ImageAnnotator()
170 |
171 | def process(
172 | self, prompt, page, output, tags_and_boxes: Optional[List[TagAndBox]] = None
173 | ):
174 | # run OCR
175 | annotations = self.__annotator(page)
176 | annotations = annotations[:MAX_PAGE_LEN]
177 |
178 | if tags_and_boxes is not None:
179 | for tag in tags_and_boxes[:MAX_TAGS_LEN]:
180 | annotations = annotations.concat_word(word=tag.word, xy=tag.coords)
181 |
182 | return self.tokenizer.process(prompt, annotations, output)
183 |
184 | def create_inference_data(self, page, prompt, uri):
185 | with tempfile.TemporaryDirectory() as tmpdir:
186 | path = Path(tmpdir) / (extract_domain(uri) + ".png")
187 | # html = os.path.join(tmpdir, extract_domain(uri)+".mhtml")
188 |
189 | # driver = webdriver.Chrome()
190 | # driver.get(uri)
191 |
192 | # # Execute Chrome dev tool command to obtain the mhtml file
193 | # res = driver.execute_cdp_cmd('Page.captureSnapshot', {})
194 |
195 | take_screenshot(page=page, url=uri, save_path=path)
196 | return self.process(prompt, path, "")
197 |
198 | def from_training_data(self, page, html, uri):
199 | with tempfile.TemporaryDirectory() as tmpdir:
200 | path = Path(tmpdir) / (extract_domain(uri) + ".png")
201 | prompt, label = take_screenshot(page=page, url=html, save_path=path)
202 | return self.process(prompt, path, label)
203 |
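A minimal sketch of pushing a prompt/screen/output triple through `Llama2dTokenizer.process`; the prompt, on-screen word, and target string are made up for illustration:

```py
# Sketch: tokenize a (prompt, screen, output) example into padded Llama inputs.
from llama2d.vision import Llama2dScreen, Llama2dTokenizer

screen = Llama2dScreen()
screen.push_word(word="Submit", xy=(0.4, 0.9))  # made-up on-screen word

tokenizer = Llama2dTokenizer()  # downloads the Llama tokenizer on first use
features = tokenizer.process("Click the submit button.", screen, "CLICK [0]")

# every tensor is padded/truncated to MAX_SEQ_LEN; tokens without a position get coords (-1, -1)
print(features["input_ids"].shape, features["coords"].shape)
print(features["labels"].shape, features["attention_mask"].shape)
```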
--------------------------------------------------------------------------------
/src/llama2d/vision/viz_pt_input.py:
--------------------------------------------------------------------------------
1 | # use headless
2 | import matplotlib
3 | from matplotlib import pyplot as plt
4 | from playwright.sync_api import sync_playwright
5 |
6 | from transformers import LlamaTokenizer
7 |
8 | matplotlib.use("Agg")
9 |
10 | # noqa
11 | """
12 | pytorch input is a dictionary of the form
13 | {
14 | "input_ids": [ids of the tokens, from 0 to vocab_size-1],
15 | "attention_mask": [0 for padding, 1 for non-padding],
16 | "coords": [x,y] for each token - normalized to [0,1] for tokens with coords, and (-1,-1) for tokens without coords
17 | "labels": [ids of the tokens, from 0 to vocab_size-1] - same as input_ids, but with -100 for tokens that should not be predicted # noqa
18 | }
19 | """
20 |
21 |
22 | model_path = "decapoda-research/llama-7b-hf"
23 | tokenizer = LlamaTokenizer.from_pretrained(model_path)
24 |
25 | # print(tokenizer.convert_ids_to_tokens([0,1,2,3,4,5,6,7,8,9,10,11,12,13,14]))
26 |
27 |
28 | def viz_pt_input(pt_input):
29 | input_ids = pt_input["input_ids"]
30 | attention_mask = pt_input["attention_mask"]
31 | coords = pt_input["coords"]
32 | # labels = pt_input["labels"]
33 |
34 | # graph tokens with coords in a matplotlib figure
35 | # print the tokens without coords
36 |
37 | # every word has a few tokens with the same coord.
38 | # we should generate the word, turn it into a string, then plot it at the coord
39 |
40 | without_coords = [
41 | input_ids[i]
42 | for i in range(len(input_ids))
43 | if coords[i][0] == -1 and attention_mask[i] == 1
44 | ]
45 |
46 | with_coords = [
47 | (input_ids[i], coords[i])
48 | for i in range(len(input_ids))
49 | if coords[i][0] != -1 and attention_mask[i] == 1
50 | ]
51 | # split with_coords into words -
52 | # where a word is a list of tokens with the same coord
53 | words = []
54 | current_word = []
55 | current_coord = None
56 | for token in with_coords:
57 | if current_coord is None or (token[1] != current_coord).any():
58 | if len(current_word) > 0:
59 | words.append(current_word)
60 | current_word = []
61 | current_coord = token[1]
62 | current_word.append(token)
63 | if len(current_word) > 0:
64 | words.append(current_word)
65 |
66 | # plot with_coords as text on a matplotlib figure
67 |
68 | fig = plt.figure()
69 | # make fig very big
70 | fig.set_size_inches(20, 20)
71 |
72 | ax = fig.add_subplot(111)
73 | ax.set_xlim([0, 1])
74 | ax.set_ylim([0, 1])
75 | ax.set_aspect("equal")
76 |
77 | for word in words:
78 | word_str = "".join(tokenizer.convert_ids_to_tokens([i[0] for i in word]))
79 | word_coord = word[0][1]
80 | # very small text
81 | ax.text(
82 | word_coord[0],
83 | 1 - word_coord[1],
84 | word_str,
85 | fontsize=10,
86 | horizontalalignment="center",
87 | verticalalignment="center",
88 | )
89 |
90 | # save the figure
91 | fig.savefig("tokens_with_coords.png")
92 |
93 | normal_str = "".join(tokenizer.convert_ids_to_tokens(input_ids))
94 | print(normal_str)
95 | print()
96 |
97 | # as a str:
98 | without_coords_str = "".join(tokenizer.convert_ids_to_tokens(without_coords))
99 | print(without_coords_str)
100 |
101 | print("")
102 |
103 |
104 | from torch.utils.data import Dataset
105 |
106 |
107 | def debug_dataset(dataset: Dataset):
108 | pt_input = None
109 |
110 | action = None
111 | i = 0
112 | while i < len(dataset):
113 | pt_input = dataset[i]
114 | if pt_input is not None:
115 | viz_pt_input(pt_input)
116 | action = input("Continue? [y/n/debug/<number to skip>]")
117 | if action == "n":
118 | break
119 | if action.startswith("d"):
120 | import pdb
121 |
122 | pdb.set_trace()
123 | # check if action is an integer - then skip that many
124 | if action.isdigit():
125 | print(f"Skipping {action}...")
126 | i += int(action)
127 | continue
128 | i += 1
129 |
130 | assert pt_input is not None, "Didn't find any valid dataset entries!"
131 | if action != "n":
132 | input("Dataset has ended. Press enter to continue program.")
133 |
134 |
135 | if __name__ == "__main__":
136 | from llama2d.datasets.mind2web import Mind2webDataset
137 | from llama2d.datasets.huggingface import HuggingFaceDataset  # NOTE: assumed import path for HuggingFaceDataset
138 | with sync_playwright() as playwright:
139 | dataset = HuggingFaceDataset("llama2d/llama2d-mind2web", split="train")
140 | for entry in dataset:
141 | assert (
142 | entry["labels"] > 0
143 | ).any(), f"No labels in entry! {entry['labels'].tolist()}"
144 |
145 | dataset = Mind2webDataset(playwright=playwright, headless=False)
146 |
147 | debug_dataset(dataset)
148 |
--------------------------------------------------------------------------------
/src/llama2d/vision/webutils/chromedriver:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Llama2D/llama2d/e28b97255d396c717fe183b96b802ff39ffd7e6d/src/llama2d/vision/webutils/chromedriver
--------------------------------------------------------------------------------
/src/llama2d/vision/webutils/playwright_browser.py:
--------------------------------------------------------------------------------
1 | import nest_asyncio
2 | from langchain.agents import AgentType, initialize_agent
3 | from langchain.agents.agent_toolkits import PlayWrightBrowserToolkit
4 | from langchain.chat_models import ChatAnthropic
5 | from langchain.tools.playwright.utils import create_async_playwright_browser
6 |
7 | nest_asyncio.apply()
8 | DEFAULT_STARTER_URL = {
9 | "url": "https://web.archive.org/web/20230428131116/https://www.cnn.com/world"
10 | }
11 |
12 |
13 | async def init_agent_chain(starter_url, llm):
14 | # tools
15 | toolkit = PlayWrightBrowserToolkit.from_browser(async_browser=async_browser)
16 | tools = toolkit.get_tools()
17 | tools_by_name = {tool.name: tool for tool in tools}
18 | navigate_tool = tools_by_name["navigate_browser"]
19 | get_elements_tool = tools_by_name["get_elements"] #
20 |
21 | await navigate_tool.arun(starter_url)
22 | # action
23 | # The browser is shared across tools, so the agent can interact in a stateful manner
24 | await get_elements_tool.arun(
25 | {"selector": ".container__headline", "attributes": ["innerText"]}
26 | )
27 |
28 | agent_chain = initialize_agent(
29 | tools,
30 | llm,
31 | agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
32 | verbose=True,
33 | )
34 | return agent_chain
35 |
36 |
37 | async def run(agent_chain, prompt):
38 | result = await agent_chain.arun(prompt)
39 | return result
40 |
41 |
42 | if __name__ == "__main__":
43 | async_browser = create_async_playwright_browser()
44 | llm = ChatAnthropic(temperature=0) # or any other LLM, e.g., ChatOpenAI(), OpenAI()
45 | import asyncio
46 | asyncio.get_event_loop().run_until_complete(init_agent_chain(DEFAULT_STARTER_URL, llm))
--------------------------------------------------------------------------------
/src/llama2d/vision/webutils/selenium_action_chain.py:
--------------------------------------------------------------------------------
1 | # import webdriver
2 | from selenium import webdriver
3 |
4 | # import Action chains
5 | from selenium.webdriver.common.action_chains import ActionChains
6 |
7 |
8 | def run(driver):
9 | menu = driver.find_element_by_css_selector(".nav")
10 | hidden_submenu = driver.find_element_by_css_selector(".nav #submenu1")
11 |
12 | ActionChains(driver).move_to_element(menu).click(hidden_submenu).perform()
13 | # Or actions can be queued up one by one, then performed.:
14 |
15 | menu = driver.find_element_by_css_selector(".nav")
16 | hidden_submenu = driver.find_element_by_css_selector(".nav #submenu1")
17 |
18 | actions = ActionChains(driver)
19 | actions.move_to_element(menu)
20 | actions.click(hidden_submenu)
21 | actions.perform()
22 |
23 |
24 | # Project Example –
25 | # create webdriver object
26 | # get geeksforgeeks.org
27 | driver = webdriver.Chrome()
28 | driver.get("https://www.geeksforgeeks.org/")
29 | # get element
30 | element = driver.find_element_by_link_text("Courses")
31 | # create action chain object
32 | action = ActionChains(driver)
33 |
34 | # click the item
35 | action.click(on_element=element)
36 |
37 | # perform the operation
38 | action.perform()
39 |
40 |
41 | if "__main__" == __name__:
42 | # create webdriver object
43 | driver = webdriver.Firefox()
44 | # create action chain object
45 | action = ActionChains(driver)
46 | run(driver)
47 |
--------------------------------------------------------------------------------
/src/llama2d/vision/webutils/stacked_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Llama2D/llama2d/e28b97255d396c717fe183b96b802ff39ffd7e6d/src/llama2d/vision/webutils/stacked_image.png
--------------------------------------------------------------------------------
/src/llama2d/vision/webutils/stitch_webpage.py:
--------------------------------------------------------------------------------
1 | import io
2 |
3 | import numpy as np
4 | from PIL import Image
5 | from selenium import webdriver
6 | from selenium.webdriver.common.by import By
7 |
8 | user_agent = (
9 | "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
10 | " (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
11 | )
12 |
13 | options = webdriver.ChromeOptions()
14 | options.add_argument(f"user-agent={user_agent}")
15 | options.add_argument("--disable-blink-features=AutomationControlled")
16 | options.add_argument("--disable-extensions")
17 | options.add_experimental_option("useAutomationExtension", False)
18 | options.add_experimental_option("excludeSwitches", ["enable-automation"])
19 |
20 | options.add_argument("--headless") # Optional: Run the browser in headless mode
21 | # DEFAULT_CHROMEDRIVER_PATH = 'chromedriver'
22 |
23 |
24 | def stitch(images):
25 | stacked_array = np.vstack(images)
26 | # Convert the NumPy array to a Pillow image
27 | image = Image.fromarray(stacked_array)
28 |
29 | # Save the image to a file
30 | image.save("stacked_image.png")
31 |
32 |
33 | def scrape_scroll(url):
34 | driver = webdriver.Chrome(options=options)  # Make sure the chromedriver binary is on your PATH
35 | # driver = uc.Chrome(headless=True, use_subprocess=False, option)
36 |
37 | driver.get(url)
38 | # Replace with the URL of the webpage you want to screenshot
39 | # Set the initial scroll height
40 | screenshots = []
41 | scroll_height = 0
42 | try:
43 | while True:
44 | total_height = driver.execute_script("return document.body.scrollHeight")
45 |
46 | driver.set_window_size(
47 | 1920, total_height
48 | ) # Adjust the window size to your liking
49 | screenshot = driver.find_element(By.TAG_NAME, "body").screenshot_as_png
50 |
51 | # print(type(screenshot))
52 | image = np.array(Image.open(io.BytesIO(screenshot)))
53 | print(image.shape)
54 | # with open('screenshot.png', 'wb') as f:
55 | # f.write(screenshot)
56 | screenshots.append(image)
57 | # Scroll down to the bottom of the page
58 | # Increment the scroll height
59 | scroll_height += 1
60 | driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
61 | # determine if this is end of page
62 | # Break the loop if we have reached the end of the page
63 | if scroll_height > 10: # You can adjust the number of scrolls as needed
64 | break
65 | except Exception:
66 | pass
67 |
68 | finally:
69 | print(f"Length of screenshots:{len(screenshots)}")
70 | stitch(screenshots)
71 | # Close the WebDriver
72 | driver.quit()
73 |
74 |
75 | if __name__ == "__main__":
76 | scrape_scroll("https://www.mytampahomeagent.com/")
77 |
--------------------------------------------------------------------------------
/src/llama2d/vision/webutils/web_to_action.py:
--------------------------------------------------------------------------------
1 | import faiss
2 | from langchain.agents import Tool
3 | from langchain.docstore import InMemoryDocstore
4 | from langchain.embeddings import OpenAIEmbeddings
5 | from langchain.tools.file_management.read import ReadFileTool
6 | from langchain.tools.file_management.write import WriteFileTool
7 | from langchain.utilities import SerpAPIWrapper
8 |
9 | # setup memory
10 | from langchain.vectorstores import FAISS
11 |
12 | # search agent
13 | search = SerpAPIWrapper()
14 | tools = [
15 | Tool(
16 | name="search",
17 | func=search.run,
18 | description="useful for when you need to answer questions about current events."
19 | " You should ask targeted questions",
20 | ),
21 | WriteFileTool(),
22 | ReadFileTool(),
23 | ]
24 |
25 |
26 | # Define your embedding model
27 | embeddings_model = OpenAIEmbeddings()
28 | # Initialize the vectorstore as empty
29 |
30 | embedding_size = 1536
31 | index = faiss.IndexFlatL2(embedding_size)
32 | vectorstore = FAISS(embeddings_model.embed_query, index, InMemoryDocstore({}), {})
33 |
--------------------------------------------------------------------------------
/src/mhtml/download.js:
--------------------------------------------------------------------------------
1 | // playwright
2 |
3 | const playwright = require('playwright');
4 | const { setTimeout } = require('timers/promises');
5 |
6 | (async () => {
7 | const browser = await playwright.chromium.launch({headless:false});
8 | const context = await browser.newContext();
9 | const page = await context.newPage();
10 |
11 | await page.goto("https://google.com/");
12 | await setTimeout(20_000);
13 |
14 | const session = await page.context().newCDPSession(page)
15 | const doc = await session.send('Page.captureSnapshot', { format: 'mhtml' });
16 | console.log(doc.data);
17 |
18 | // save
19 | const {writeFileSync} = require('fs');
20 | writeFileSync('./finance.mhtml', doc.data);
21 |
22 | })();
--------------------------------------------------------------------------------
/src/mhtml/index.js:
--------------------------------------------------------------------------------
1 | const { Parser } = require("fast-mhtml");
2 | const p = new Parser({
3 | rewriteFn: (url)=>{
4 | console.log(url)
5 | return url
6 | // set base url to localhost:8080
7 | }, // default, urls are rewritten with this function
8 | });
9 |
10 |
11 | const {readFileSync,writeFileSync} = require('fs');
12 |
13 | const mhtmlFileContents = readFileSync('./finance.mhtml'); // read file
14 | const files = p.parse(mhtmlFileContents) // parse file
15 | .rewrite() // rewrite all links
16 | .spit(); // return all content
17 |
18 | console.log(files)
19 |
20 | writeFileSync('./finance.json', JSON.stringify(files,null,2)); // write file
21 |
22 |
23 | // mkdir -p ./finance
24 | // const {join} = require('path');
25 | // const {mkdirSync} = require('fs');
26 | // mkdirSync('./finance',{recursive:true});
27 |
28 | // files.forEach(({filename,content})=>{
--------------------------------------------------------------------------------
/src/mhtml/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "mhtml",
3 | "version": "1.0.0",
4 | "description": "",
5 | "main": "index.js",
6 | "scripts": {
7 | "test": "echo \"Error: no test specified\" && exit 1"
8 | },
9 | "keywords": [],
10 | "author": "",
11 | "license": "ISC",
12 | "dependencies": {
13 | "fast-mhtml": "^2.1.0",
14 | "playwright": "^1.38.1"
15 | }
16 | }
17 |
--------------------------------------------------------------------------------
/src/mhtml/serve.js:
--------------------------------------------------------------------------------
1 | const {Processor} = require('fast-mhtml');
2 |
3 | Processor.serve(8080)
--------------------------------------------------------------------------------
/src/mhtml/serve_local_data.js:
--------------------------------------------------------------------------------
1 | const express = require('express');
2 | const { Parser } = require("fast-mhtml");
3 | const fs = require('fs');
4 |
5 | const filenamify = require('filenamify');
6 |
7 | const { join } = require('path');
8 | const mhtmlDir = join(__dirname, '../../data/mind2web-mhtml');
9 | // const mhtmlDir = join(__dirname, 'demos');
10 |
11 | const sentinel = 'mind2web_local_serve:'
12 |
13 | const app = express();
14 | const fileCache = new Map();
15 | app.get('/:path', (req, res) => {
16 |
17 | const file = req.params.path;
18 |
19 | if (file.endsWith('mhtml')) { // main file
20 | fileCache.clear(); // empty cache
21 |
22 | let base = null;
23 |
24 | const parser = new Parser({
25 | rewriteFn: (url) => {
26 | if(new URL(url,`http://localhost:${port}/`).protocol.startsWith('http')) {
27 | return url;
28 | }
29 | return sentinel+filenamify(url);
30 | }
31 | });
32 | // const fp = promised(fs.readFile, `${mhtmlDir}/${file}`);
33 | const fp = fs.promises.readFile(`${mhtmlDir}/${file}`);
34 | fp.then((data) => parser.parse(data).rewrite().spit()).then((spitFiles) => {
35 | for (const result of spitFiles) {
36 | fileCache.set(result.filename.replace(/#.*/, ''), result); // remove hash and set in cache
37 | }
38 | res.setHeader('Content-Type', spitFiles[0].type);
39 | res.send(spitFiles[0].content);
40 | res.end();
41 | }).catch((err) => {
42 | res.status(500);
43 |         res.send(`Error: ${err}<br/>${err.stack.replace(/\n/, '<br/>')}`);
44 | res.end();
45 | });
46 | return;
47 | }
48 |
49 | // redirect to URL in path
50 | if(!file.startsWith(sentinel) && (file.includes(".css") || file.includes(".js"))){
51 | return res.redirect(file);
52 | }
53 |
54 | const result = fileCache.get(file);
55 | if (!result) {
56 | res.status(404);
57 | res.send(`MISS ${file} FROM ${JSON.stringify([...fileCache.keys()])}`);
58 | res.end();
59 | return;
60 | }
61 | res.setHeader('Content-Type', result.type);
62 | res.send(result.content);
63 | res.end();
64 | });
65 |
66 | const port = 5002;
67 | app.listen(port,() => console.log('Listening on port '+port));
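With the server running (`node serve_local_data.js`), any snapshot in `data/mind2web-mhtml` can be rendered in a browser by filename. A small sketch, reusing the snapshot name that appears in `tests/testing.py`:

```py
# Sketch: render a Mind2Web MHTML snapshot through serve_local_data.js.
from playwright.sync_api import sync_playwright

with sync_playwright() as p:
    browser = p.chromium.launch()
    page = browser.new_page()
    # port 5002 and the /:path route come from serve_local_data.js
    page.goto("http://localhost:5002/961c3a5e-f8ce-4c71-a917-aa546dcea7fb_before.mhtml")
    page.screenshot(path="snapshot.png")
    browser.close()
```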
--------------------------------------------------------------------------------
/src/models/.gitignore:
--------------------------------------------------------------------------------
1 | **/*
2 |
3 | !.gitignore
4 |
--------------------------------------------------------------------------------
/src/secrets/.gitignore:
--------------------------------------------------------------------------------
1 | **/*
2 | !.gitignore
--------------------------------------------------------------------------------
/tests/testing.py:
--------------------------------------------------------------------------------
1 | from playwright.sync_api import Playwright, sync_playwright
2 |
3 | with sync_playwright() as p:
4 | browser = p.chromium.launch()
5 | page = browser.new_page()
6 | page.goto(
7 | "file:///Users/andrewstelmach/Desktop/llama2d/data/mind2web-mhtml/961c3a5e-f8ce-4c71-a917-aa546dcea7fb_before.mhtml"
8 | )
9 | # do something with the page...
10 | browser.close()
11 |
--------------------------------------------------------------------------------