├── .gitignore
├── README.md
├── dataset_files
│   ├── README.md
│   ├── abstractive
│   │   ├── ag_news.json
│   │   ├── antonym.json
│   │   ├── capitalize.json
│   │   ├── capitalize_first_letter.json
│   │   ├── capitalize_last_letter.json
│   │   ├── capitalize_second_letter.json
│   │   ├── commonsense_qa.json
│   │   ├── country-capital.json
│   │   ├── country-currency.json
│   │   ├── english-french.json
│   │   ├── english-german.json
│   │   ├── english-spanish.json
│   │   ├── landmark-country.json
│   │   ├── lowercase_first_letter.json
│   │   ├── lowercase_last_letter.json
│   │   ├── national_parks.json
│   │   ├── next_capital_letter.json
│   │   ├── next_item.json
│   │   ├── park-country.json
│   │   ├── person-instrument.json
│   │   ├── person-occupation.json
│   │   ├── person-sport.json
│   │   ├── present-past.json
│   │   ├── prev_item.json
│   │   ├── product-company.json
│   │   ├── sentiment.json
│   │   ├── singular-plural.json
│   │   ├── synonym.json
│   │   └── word_length.json
│   ├── extractive
│   │   ├── adjective_v_verb_3.json
│   │   ├── adjective_v_verb_5.json
│   │   ├── alphabetically_first_3.json
│   │   ├── alphabetically_first_5.json
│   │   ├── alphabetically_last_3.json
│   │   ├── alphabetically_last_5.json
│   │   ├── animal_v_object_3.json
│   │   ├── animal_v_object_5.json
│   │   ├── choose_first_of_3.json
│   │   ├── choose_first_of_5.json
│   │   ├── choose_last_of_3.json
│   │   ├── choose_last_of_5.json
│   │   ├── choose_middle_of_3.json
│   │   ├── choose_middle_of_5.json
│   │   ├── color_v_animal_3.json
│   │   ├── color_v_animal_5.json
│   │   ├── concept_v_object_3.json
│   │   ├── concept_v_object_5.json
│   │   ├── conll2003_location.json
│   │   ├── conll2003_organization.json
│   │   ├── conll2003_person.json
│   │   ├── fruit_v_animal_3.json
│   │   ├── fruit_v_animal_5.json
│   │   ├── object_v_concept_3.json
│   │   ├── object_v_concept_5.json
│   │   ├── squad_val.json
│   │   ├── verb_v_adjective_3.json
│   │   └── verb_v_adjective_5.json
│   └── generate
│       ├── categories.json
│       ├── create_antonym_synonym_datasets.py
│       ├── create_translation_datasets.py
│       ├── task_data_generation.ipynb
│       └── translation
│           ├── en-de.0-5000.txt
│           ├── en-de.5000-6500.txt
│           ├── en-es.0-5000.txt
│           ├── en-es.5000-6500.txt
│           ├── en-fr.0-5000.txt
│           └── en-fr.5000-6500.txt
├── fv_environment.yml
├── fv_overview.png
├── notebooks
│   └── fv_demo.ipynb
└── src
    ├── __init__.py
    ├── compute_average_activations.py
    ├── compute_avg_hidden_state.py
    ├── compute_indirect_effect.py
    ├── eval_scripts
    │   ├── eval_avg_hs.sh
    │   ├── eval_fv.sh
    │   ├── eval_numheads.sh
    │   ├── eval_template_portability.sh
    │   ├── fv_eval_sweep.py
    │   └── template.sh
    ├── evaluate_function_vector.py
    ├── natural_text_eval.py
    ├── portability_eval.py
    ├── test_numheads.py
    ├── utils
    │   ├── __init__.py
    │   ├── eval_utils.py
    │   ├── extract_utils.py
    │   ├── intervention_utils.py
    │   ├── model_utils.py
    │   └── prompt_utils.py
    └── vocab_reconstruction.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Cache and Checkpoints
2 | .ipynb_checkpoints
3 | *.ipynb_checkpoints
4 | __pycache__
5 | *__pycache__
6 |
7 | # Results and Figures
8 | *results
9 | results/*
10 | *figures
11 |
12 | # External Data/Repositories
13 | */generate/AntSynNET
14 |
15 | # Weights
16 | *.pth
17 | *.npz
18 | *.npy
19 | *.pt
20 |
21 | # Environments
22 | .env
23 | .venv
24 | .vscode/
25 | env/
26 | venv/
27 | ENV/
28 | env.bak/
29 | venv.bak/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Function Vectors in Large Language Models
2 | ### [Project Website](https://functions.baulab.info) | [arXiv Preprint](https://arxiv.org/abs/2310.15213) | [OpenReview](https://openreview.net/forum?id=AwyxtyMwaG)
3 |
4 | This repository contains data and code for the paper: [Function Vectors in Large Language Models](https://arxiv.org/pdf/2310.15213).
5 |
6 |
7 |
8 |
9 |
10 | ## Setup
11 |
12 | We recommend using conda as a package manager.
13 | The environment used for this project can be found in the `fv_environment.yml` file.
14 | To install, you can run:
15 | ```bash
16 | conda env create -f fv_environment.yml
17 | conda activate fv
18 | ```
19 |
20 | ## Demo Notebook
21 | Check out `notebooks/fv_demo.ipynb` for a Jupyter notebook demonstrating how to create a function vector and use it in different contexts.
22 |
23 | ## Data
24 | The datasets used in our project can be found in the `dataset_files` folder.
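Each task file is a flat JSON list of `{"output": ..., "input": ...}` pairs. Below is a minimal sketch of reading one; the pairs and file name are illustrative, mirroring the schema of `dataset_files/abstractive/country-capital.json`:

```python
import json

# Illustrative pairs mirroring the task-file schema:
# a JSON list of {"output": ..., "input": ...} objects.
pairs = [
    {"output": "Paris", "input": "France"},
    {"output": "Tokyo", "input": "Japan"},
]

# Round-trip through a file, as you would with a real task file.
path = "country-capital-sample.json"
with open(path, "w") as f:
    json.dump(pairs, f, indent=4)

with open(path) as f:
    dataset = json.load(f)

print(len(dataset), dataset[0]["input"], "->", dataset[0]["output"])
# prints: 2 France -> Paris
```

To load a real task, point `path` at any file under `dataset_files/abstractive` or `dataset_files/extractive`.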
25 |
26 | ## Code
27 | Our main evaluation scripts are contained in the `src` directory with sample script wrappers in `src/eval_scripts`.
28 |
29 | The remaining core code is split across several utility files:
30 | - `eval_utils.py` contains code for evaluating function vectors in a variety of contexts.
31 | - `extract_utils.py` contains functions for extracting function vectors and other relevant model activations.
32 | - `intervention_utils.py` contains the main functionality for intervening with function vectors during inference.
33 | - `model_utils.py` contains helper functions for loading models and tokenizers from Hugging Face.
34 | - `prompt_utils.py` contains data-loading and prompt-creation functionality.
35 |
36 | ## Citing our work
37 | This work appeared at ICLR 2024. The paper can be cited as follows:
38 |
39 | ```bibtex
40 | @inproceedings{todd2024function,
41 | title={Function Vectors in Large Language Models},
42 | author={Eric Todd and Millicent L. Li and Arnab Sen Sharma and Aaron Mueller and Byron C. Wallace and David Bau},
43 | booktitle={Proceedings of the 2024 International Conference on Learning Representations},
44 | url={https://openreview.net/forum?id=AwyxtyMwaG},
45 | note={arXiv:2310.15213},
46 | year={2024},
47 | }
48 | ```
--------------------------------------------------------------------------------
/dataset_files/README.md:
--------------------------------------------------------------------------------
1 | # Datasets
2 |
3 | This directory contains two main directories of task datasets all in `.json` format.
4 | * (1) The `abstractive` directory contains tasks that require information not present in the prompt to answer.
5 | * (2) The `extractive` directory contains tasks where the answer is present somewhere in the prompt, and the model's
6 | task is to retrieve it.
7 |
8 | The `generate` directory contains scripts we used to filter existing datasets, as well as a notebook we used to create new datasets and to clean and filter additional pre-existing datasets.
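In both settings, the `(input, output)` pairs are composed into few-shot in-context prompts at evaluation time. The `Q:`/`A:` template below is an illustrative sketch, not necessarily the exact template used by `src/utils/prompt_utils.py`:

```python
# Illustrative antonym-style pairs in the same {"input", "output"} schema
# as the task files in this directory.
pairs = [
    {"input": "hot", "output": "cold"},
    {"input": "up", "output": "down"},
    {"input": "big", "output": "small"},
]

# Use all but the last pair as demonstrations; query the last input.
demos, query = pairs[:-1], pairs[-1]
prompt = "\n".join(f"Q: {p['input']}\nA: {p['output']}" for p in demos)
prompt += f"\nQ: {query['input']}\nA:"
print(prompt)
```

The model is then expected to complete the final `A:` line with the held-out output (here, "small").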
--------------------------------------------------------------------------------
/dataset_files/abstractive/country-capital.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "output": "Kabul",
4 | "input": "Afghanistan"
5 | },
6 | {
7 | "output": "Tirana",
8 | "input": "Albania"
9 | },
10 | {
11 | "output": "Algiers",
12 | "input": "Algeria"
13 | },
14 | {
15 | "output": "Andorra la Vella",
16 | "input": "Andorra"
17 | },
18 | {
19 | "output": "Luanda",
20 | "input": "Angola"
21 | },
22 | {
23 | "output": "St. John's",
24 | "input": "Antigua and Barbuda"
25 | },
26 | {
27 | "output": "Buenos Aires",
28 | "input": "Argentina"
29 | },
30 | {
31 | "output": "Yerevan",
32 | "input": "Armenia"
33 | },
34 | {
35 | "output": "Canberra",
36 | "input": "Australia"
37 | },
38 | {
39 | "output": "Vienna",
40 | "input": "Austria"
41 | },
42 | {
43 | "output": "Baku",
44 | "input": "Azerbaijan"
45 | },
46 | {
47 | "output": "Nassau",
48 | "input": "Bahamas"
49 | },
50 | {
51 | "output": "Manama",
52 | "input": "Bahrain"
53 | },
54 | {
55 | "output": "Dhaka",
56 | "input": "Bangladesh"
57 | },
58 | {
59 | "output": "Bridgetown",
60 | "input": "Barbados"
61 | },
62 | {
63 | "output": "Minsk",
64 | "input": "Belarus"
65 | },
66 | {
67 | "output": "Brussels",
68 | "input": "Belgium"
69 | },
70 | {
71 | "output": "Belmopan",
72 | "input": "Belize"
73 | },
74 | {
75 | "output": "Porto-Novo",
76 | "input": "Benin"
77 | },
78 | {
79 | "output": "Thimphu",
80 | "input": "Bhutan"
81 | },
82 | {
83 | "output": "La Paz",
84 | "input": "Bolivia"
85 | },
86 | {
87 | "output": "Sarajevo",
88 | "input": "Bosnia and Herzegovina"
89 | },
90 | {
91 | "output": "Gaborone",
92 | "input": "Botswana"
93 | },
94 | {
95 | "output": "Brasilia",
96 | "input": "Brazil"
97 | },
98 | {
99 | "output": "Bandar Seri Begawan",
100 | "input": "Brunei"
101 | },
102 | {
103 | "output": "Sofia",
104 | "input": "Bulgaria"
105 | },
106 | {
107 | "output": "Ouagadougou",
108 | "input": "Burkina Faso"
109 | },
110 | {
111 | "output": "Bujumbura",
112 | "input": "Burundi"
113 | },
114 | {
115 | "output": "Praia",
116 | "input": "Cabo Verde"
117 | },
118 | {
119 | "output": "Phnom Penh",
120 | "input": "Cambodia"
121 | },
122 | {
123 | "output": "Yaounde",
124 | "input": "Cameroon"
125 | },
126 | {
127 | "output": "Ottawa",
128 | "input": "Canada"
129 | },
130 | {
131 | "output": "Bangui",
132 | "input": "Central African Republic"
133 | },
134 | {
135 | "output": "N'Djamena",
136 | "input": "Chad"
137 | },
138 | {
139 | "output": "Santiago",
140 | "input": "Chile"
141 | },
142 | {
143 | "output": "Beijing",
144 | "input": "China"
145 | },
146 | {
147 | "output": "Bogotá",
148 | "input": "Colombia"
149 | },
150 | {
151 | "output": "Moroni",
152 | "input": "Comoros"
153 | },
154 | {
155 | "output": "Brazzaville",
156 | "input": "Congo"
157 | },
158 | {
159 | "output": "San José",
160 | "input": "Costa Rica"
161 | },
162 | {
163 | "output": "Yamoussoukro",
164 | "input": "Cote d'Ivoire"
165 | },
166 | {
167 | "output": "Zagreb",
168 | "input": "Croatia"
169 | },
170 | {
171 | "output": "Havana",
172 | "input": "Cuba"
173 | },
174 | {
175 | "output": "Nicosia",
176 | "input": "Cyprus"
177 | },
178 | {
179 | "output": "Prague",
180 | "input": "Czech Republic"
181 | },
182 | {
183 | "output": "Kinshasa",
184 | "input": "Democratic Republic of the Congo"
185 | },
186 | {
187 | "output": "Copenhagen",
188 | "input": "Denmark"
189 | },
190 | {
191 | "output": "Djibouti City",
192 | "input": "Djibouti"
193 | },
194 | {
195 | "output": "Roseau",
196 | "input": "Dominica"
197 | },
198 | {
199 | "output": "Santo Domingo",
200 | "input": "Dominican Republic"
201 | },
202 | {
203 | "output": "Quito",
204 | "input": "Ecuador"
205 | },
206 | {
207 | "output": "Cairo",
208 | "input": "Egypt"
209 | },
210 | {
211 | "output": "San Salvador",
212 | "input": "El Salvador"
213 | },
214 | {
215 | "output": "Malabo",
216 | "input": "Equatorial Guinea"
217 | },
218 | {
219 | "output": "Asmara",
220 | "input": "Eritrea"
221 | },
222 | {
223 | "output": "Tallinn",
224 | "input": "Estonia"
225 | },
226 | {
227 | "output": "Mbabane",
228 | "input": "Eswatini"
229 | },
230 | {
231 | "output": "Addis Ababa",
232 | "input": "Ethiopia"
233 | },
234 | {
235 | "output": "Suva",
236 | "input": "Fiji"
237 | },
238 | {
239 | "output": "Helsinki",
240 | "input": "Finland"
241 | },
242 | {
243 | "output": "Paris",
244 | "input": "France"
245 | },
246 | {
247 | "output": "Libreville",
248 | "input": "Gabon"
249 | },
250 | {
251 | "output": "Banjul",
252 | "input": "Gambia"
253 | },
254 | {
255 | "output": "Tbilisi",
256 | "input": "Georgia"
257 | },
258 | {
259 | "output": "Berlin",
260 | "input": "Germany"
261 | },
262 | {
263 | "output": "Accra",
264 | "input": "Ghana"
265 | },
266 | {
267 | "output": "Athens",
268 | "input": "Greece"
269 | },
270 | {
271 | "output": "St. George's",
272 | "input": "Grenada"
273 | },
274 | {
275 | "output": "Guatemala City",
276 | "input": "Guatemala"
277 | },
278 | {
279 | "output": "Conakry",
280 | "input": "Guinea"
281 | },
282 | {
283 | "output": "Bissau",
284 | "input": "Guinea-Bissau"
285 | },
286 | {
287 | "output": "Georgetown",
288 | "input": "Guyana"
289 | },
290 | {
291 | "output": "Port-au-Prince",
292 | "input": "Haiti"
293 | },
294 | {
295 | "output": "Tegucigalpa",
296 | "input": "Honduras"
297 | },
298 | {
299 | "output": "Budapest",
300 | "input": "Hungary"
301 | },
302 | {
303 | "output": "Reykjavik",
304 | "input": "Iceland"
305 | },
306 | {
307 | "output": "New Delhi",
308 | "input": "India"
309 | },
310 | {
311 | "output": "Jakarta",
312 | "input": "Indonesia"
313 | },
314 | {
315 | "output": "Tehran",
316 | "input": "Iran"
317 | },
318 | {
319 | "output": "Baghdad",
320 | "input": "Iraq"
321 | },
322 | {
323 | "output": "Dublin",
324 | "input": "Ireland"
325 | },
326 | {
327 | "output": "Jerusalem",
328 | "input": "Israel"
329 | },
330 | {
331 | "output": "Rome",
332 | "input": "Italy"
333 | },
334 | {
335 | "output": "Kingston",
336 | "input": "Jamaica"
337 | },
338 | {
339 | "output": "Tokyo",
340 | "input": "Japan"
341 | },
342 | {
343 | "output": "Amman",
344 | "input": "Jordan"
345 | },
346 | {
347 | "output": "Astana",
348 | "input": "Kazakhstan"
349 | },
350 | {
351 | "output": "Nairobi",
352 | "input": "Kenya"
353 | },
354 | {
355 | "output": "South Tarawa",
356 | "input": "Kiribati"
357 | },
358 | {
359 | "output": "Pristina",
360 | "input": "Kosovo"
361 | },
362 | {
363 | "output": "Kuwait City",
364 | "input": "Kuwait"
365 | },
366 | {
367 | "output": "Bishkek",
368 | "input": "Kyrgyzstan"
369 | },
370 | {
371 | "output": "Vientiane",
372 | "input": "Laos"
373 | },
374 | {
375 | "output": "Riga",
376 | "input": "Latvia"
377 | },
378 | {
379 | "output": "Beirut",
380 | "input": "Lebanon"
381 | },
382 | {
383 | "output": "Maseru",
384 | "input": "Lesotho"
385 | },
386 | {
387 | "output": "Monrovia",
388 | "input": "Liberia"
389 | },
390 | {
391 | "output": "Tripoli",
392 | "input": "Libya"
393 | },
394 | {
395 | "output": "Vaduz",
396 | "input": "Liechtenstein"
397 | },
398 | {
399 | "output": "Vilnius",
400 | "input": "Lithuania"
401 | },
402 | {
403 | "output": "Luxembourg City",
404 | "input": "Luxembourg"
405 | },
406 | {
407 | "output": "Antananarivo",
408 | "input": "Madagascar"
409 | },
410 | {
411 | "output": "Lilongwe",
412 | "input": "Malawi"
413 | },
414 | {
415 | "output": "Kuala Lumpur",
416 | "input": "Malaysia"
417 | },
418 | {
419 | "output": "Malé",
420 | "input": "Maldives"
421 | },
422 | {
423 | "output": "Bamako",
424 | "input": "Mali"
425 | },
426 | {
427 | "output": "Valletta",
428 | "input": "Malta"
429 | },
430 | {
431 | "output": "Majuro",
432 | "input": "Marshall Islands"
433 | },
434 | {
435 | "output": "Nouakchott",
436 | "input": "Mauritania"
437 | },
438 | {
439 | "output": "Port Louis",
440 | "input": "Mauritius"
441 | },
442 | {
443 | "output": "Mexico City",
444 | "input": "Mexico"
445 | },
446 | {
447 | "output": "Palikir",
448 | "input": "Micronesia"
449 | },
450 | {
451 | "output": "Chisinau",
452 | "input": "Moldova"
453 | },
454 | {
455 | "output": "Monaco-Ville",
456 | "input": "Monaco"
457 | },
458 | {
459 | "output": "Ulaanbaatar",
460 | "input": "Mongolia"
461 | },
462 | {
463 | "output": "Podgorica",
464 | "input": "Montenegro"
465 | },
466 | {
467 | "output": "Rabat",
468 | "input": "Morocco"
469 | },
470 | {
471 | "output": "Maputo",
472 | "input": "Mozambique"
473 | },
474 | {
475 | "output": "Naypyidaw",
476 | "input": "Myanmar"
477 | },
478 | {
479 | "output": "Windhoek",
480 | "input": "Namibia"
481 | },
482 | {
483 | "output": "Yaren District",
484 | "input": "Nauru"
485 | },
486 | {
487 | "output": "Kathmandu",
488 | "input": "Nepal"
489 | },
490 | {
491 | "output": "Amsterdam",
492 | "input": "Netherlands"
493 | },
494 | {
495 | "output": "Wellington",
496 | "input": "New Zealand"
497 | },
498 | {
499 | "output": "Managua",
500 | "input": "Nicaragua"
501 | },
502 | {
503 | "output": "Niamey",
504 | "input": "Niger"
505 | },
506 | {
507 | "output": "Abuja",
508 | "input": "Nigeria"
509 | },
510 | {
511 | "output": "Pyongyang",
512 | "input": "North Korea"
513 | },
514 | {
515 | "output": "Skopje",
516 | "input": "North Macedonia"
517 | },
518 | {
519 | "output": "Oslo",
520 | "input": "Norway"
521 | },
522 | {
523 | "output": "Muscat",
524 | "input": "Oman"
525 | },
526 | {
527 | "output": "Islamabad",
528 | "input": "Pakistan"
529 | },
530 | {
531 | "output": "Ngerulmud",
532 | "input": "Palau"
533 | },
534 | {
535 | "output": "Ramallah",
536 | "input": "Palestine"
537 | },
538 | {
539 | "output": "Panama City",
540 | "input": "Panama"
541 | },
542 | {
543 | "output": "Port Moresby",
544 | "input": "Papua New Guinea"
545 | },
546 | {
547 | "output": "Asunción",
548 | "input": "Paraguay"
549 | },
550 | {
551 | "output": "Lima",
552 | "input": "Peru"
553 | },
554 | {
555 | "output": "Manila",
556 | "input": "Philippines"
557 | },
558 | {
559 | "output": "Warsaw",
560 | "input": "Poland"
561 | },
562 | {
563 | "output": "Lisbon",
564 | "input": "Portugal"
565 | },
566 | {
567 | "output": "Doha",
568 | "input": "Qatar"
569 | },
570 | {
571 | "output": "Bucharest",
572 | "input": "Romania"
573 | },
574 | {
575 | "output": "Moscow",
576 | "input": "Russia"
577 | },
578 | {
579 | "output": "Kigali",
580 | "input": "Rwanda"
581 | },
582 | {
583 | "output": "Basseterre",
584 | "input": "Saint Kitts and Nevis"
585 | },
586 | {
587 | "output": "Castries",
588 | "input": "Saint Lucia"
589 | },
590 | {
591 | "output": "Kingstown",
592 | "input": "Saint Vincent and the Grenadines"
593 | },
594 | {
595 | "output": "Apia",
596 | "input": "Samoa"
597 | },
598 | {
599 | "output": "San Marino",
600 | "input": "San Marino"
601 | },
602 | {
603 | "output": "Sao Tome",
604 | "input": "Sao Tome and Principe"
605 | },
606 | {
607 | "output": "Riyadh",
608 | "input": "Saudi Arabia"
609 | },
610 | {
611 | "output": "Dakar",
612 | "input": "Senegal"
613 | },
614 | {
615 | "output": "Belgrade",
616 | "input": "Serbia"
617 | },
618 | {
619 | "output": "Victoria",
620 | "input": "Seychelles"
621 | },
622 | {
623 | "output": "Freetown",
624 | "input": "Sierra Leone"
625 | },
626 | {
627 | "output": "Singapore",
628 | "input": "Singapore"
629 | },
630 | {
631 | "output": "Bratislava",
632 | "input": "Slovakia"
633 | },
634 | {
635 | "output": "Ljubljana",
636 | "input": "Slovenia"
637 | },
638 | {
639 | "output": "Honiara",
640 | "input": "Solomon Islands"
641 | },
642 | {
643 | "output": "Mogadishu",
644 | "input": "Somalia"
645 | },
646 | {
647 | "output": "Pretoria",
648 | "input": "South Africa"
649 | },
650 | {
651 | "output": "Seoul",
652 | "input": "South Korea"
653 | },
654 | {
655 | "output": "Juba",
656 | "input": "South Sudan"
657 | },
658 | {
659 | "output": "Madrid",
660 | "input": "Spain"
661 | },
662 | {
663 | "output": "Colombo",
664 | "input": "Sri Lanka"
665 | },
666 | {
667 | "output": "Khartoum",
668 | "input": "Sudan"
669 | },
670 | {
671 | "output": "Paramaribo",
672 | "input": "Suriname"
673 | },
674 | {
675 | "output": "Stockholm",
676 | "input": "Sweden"
677 | },
678 | {
679 | "output": "Bern",
680 | "input": "Switzerland"
681 | },
682 | {
683 | "output": "Damascus",
684 | "input": "Syria"
685 | },
686 | {
687 | "output": "Taipei",
688 | "input": "Taiwan"
689 | },
690 | {
691 | "output": "Dushanbe",
692 | "input": "Tajikistan"
693 | },
694 | {
695 | "output": "Dodoma",
696 | "input": "Tanzania"
697 | },
698 | {
699 | "output": "Bangkok",
700 | "input": "Thailand"
701 | },
702 | {
703 | "output": "Dili",
704 | "input": "Timor-Leste"
705 | },
706 | {
707 | "output": "Lome",
708 | "input": "Togo"
709 | },
710 | {
711 | "output": "Nukuʻalofa",
712 | "input": "Tonga"
713 | },
714 | {
715 | "output": "Port of Spain",
716 | "input": "Trinidad and Tobago"
717 | },
718 | {
719 | "output": "Tunis",
720 | "input": "Tunisia"
721 | },
722 | {
723 | "output": "Ankara",
724 | "input": "Turkey"
725 | },
726 | {
727 | "output": "Ashgabat",
728 | "input": "Turkmenistan"
729 | },
730 | {
731 | "output": "Funafuti",
732 | "input": "Tuvalu"
733 | },
734 | {
735 | "output": "Kampala",
736 | "input": "Uganda"
737 | },
738 | {
739 | "output": "Kiev",
740 | "input": "Ukraine"
741 | },
742 | {
743 | "output": "Abu Dhabi",
744 | "input": "United Arab Emirates"
745 | },
746 | {
747 | "output": "London",
748 | "input": "United Kingdom"
749 | },
750 | {
751 | "output": "Washington, D.C.",
752 | "input": "United States of America"
753 | },
754 | {
755 | "output": "Montevideo",
756 | "input": "Uruguay"
757 | },
758 | {
759 | "output": "Tashkent",
760 | "input": "Uzbekistan"
761 | },
762 | {
763 | "output": "Port Vila",
764 | "input": "Vanuatu"
765 | },
766 | {
767 | "output": "Vatican City",
768 | "input": "Vatican City"
769 | },
770 | {
771 | "output": "Caracas",
772 | "input": "Venezuela"
773 | },
774 | {
775 | "output": "Hanoi",
776 | "input": "Vietnam"
777 | },
778 | {
779 | "output": "Sana'a",
780 | "input": "Yemen"
781 | },
782 | {
783 | "output": "Lusaka",
784 | "input": "Zambia"
785 | },
786 | {
787 | "output": "Harare",
788 | "input": "Zimbabwe"
789 | }
790 | ]
--------------------------------------------------------------------------------
/dataset_files/abstractive/country-currency.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "output": "Afghani (AFN)",
4 | "input": "Afghanistan"
5 | },
6 | {
7 | "output": "Albanian Lek (ALL)",
8 | "input": "Albania"
9 | },
10 | {
11 | "output": "Algerian Dinar",
12 | "input": "Algeria"
13 | },
14 | {
15 | "output": "Euro (EUR)",
16 | "input": "Andorra"
17 | },
18 | {
19 | "output": "Kwanza (AOA)",
20 | "input": "Angola"
21 | },
22 | {
23 | "output": "East Caribbean Dollar (XCD)",
24 | "input": "Antigua and Barbuda"
25 | },
26 | {
27 | "output": "Argentine Peso",
28 | "input": "Argentina"
29 | },
30 | {
31 | "output": "Dram (AMD)",
32 | "input": "Armenia"
33 | },
34 | {
35 | "output": "Australian Dollar (AUD)",
36 | "input": "Australia"
37 | },
38 | {
39 | "output": "Euro (EUR)",
40 | "input": "Austria"
41 | },
42 | {
43 | "output": "Manat",
44 | "input": "Azerbaijan"
45 | },
46 | {
47 | "output": "Bahamian Dollar",
48 | "input": "Bahamas"
49 | },
50 | {
51 | "output": "Bahraini Dinar (BHD)",
52 | "input": "Bahrain"
53 | },
54 | {
55 | "output": "Taka",
56 | "input": "Bangladesh"
57 | },
58 | {
59 | "output": "Barbadian Dollar (BBD)",
60 | "input": "Barbados"
61 | },
62 | {
63 | "output": "Belarusian Ruble (BYN)",
64 | "input": "Belarus"
65 | },
66 | {
67 | "output": "Euro (EUR)",
68 | "input": "Belgium"
69 | },
70 | {
71 | "output": "Belize Dollar (BZD)",
72 | "input": "Belize"
73 | },
74 | {
75 | "output": "CFA Franc (XOF)",
76 | "input": "Benin"
77 | },
78 | {
79 | "output": "Ngultrum (BTN)",
80 | "input": "Bhutan"
81 | },
82 | {
83 | "output": "Bolivian Boliviano (BOB)",
84 | "input": "Bolivia"
85 | },
86 | {
87 | "output": "Convertible Mark (KM or BAM)",
88 | "input": "Bosnia and Herzegovina"
89 | },
90 | {
91 | "output": "Pula",
92 | "input": "Botswana"
93 | },
94 | {
95 | "output": "Brazilian Real (BRL)",
96 | "input": "Brazil"
97 | },
98 | {
99 | "output": "Brunei Dollar (BND)",
100 | "input": "Brunei"
101 | },
102 | {
103 | "output": "Bulgarian Lev (BGN)",
104 | "input": "Bulgaria"
105 | },
106 | {
107 | "output": "West African CFA franc",
108 | "input": "Burkina Faso"
109 | },
110 | {
111 | "output": "Burundian Franc (BIF)",
112 | "input": "Burundi"
113 | },
114 | {
115 | "output": "Escudo (CVE)",
116 | "input": "Cabo Verde"
117 | },
118 | {
119 | "output": "Riel",
120 | "input": "Cambodia"
121 | },
122 | {
123 | "output": "Central African CFA franc",
124 | "input": "Cameroon"
125 | },
126 | {
127 | "output": "Canadian Dollar (CAD)",
128 | "input": "Canada"
129 | },
130 | {
131 | "output": "Central African CFA franc",
132 | "input": "Central African Republic"
133 | },
134 | {
135 | "output": "Central African CFA franc",
136 | "input": "Chad"
137 | },
138 | {
139 | "output": "Chilean Peso",
140 | "input": "Chile"
141 | },
142 | {
143 | "output": "Renminbi (RMB)",
144 | "input": "China"
145 | },
146 | {
147 | "output": "Colombian Peso",
148 | "input": "Colombia"
149 | },
150 | {
151 | "output": "Comorian Franc",
152 | "input": "Comoros"
153 | },
154 | {
155 | "output": "Central African CFA franc",
156 | "input": "Congo"
157 | },
158 | {
159 | "output": "Colón (CRC)",
160 | "input": "Costa Rica"
161 | },
162 | {
163 | "output": "CFA Franc (XOF)",
164 | "input": "Cote d'Ivoire"
165 | },
166 | {
167 | "output": "Kuna (HRK)",
168 | "input": "Croatia"
169 | },
170 | {
171 | "output": "Cuban Peso (CUP)",
172 | "input": "Cuba"
173 | },
174 | {
175 | "output": "Euro (EUR)",
176 | "input": "Cyprus"
177 | },
178 | {
179 | "output": "Czech Koruna (CZK)",
180 | "input": "Czech Republic"
181 | },
182 | {
183 | "output": "Congolese Franc (CDF)",
184 | "input": "Democratic Republic of the Congo"
185 | },
186 | {
187 | "output": "Danish Krone",
188 | "input": "Denmark"
189 | },
190 | {
191 | "output": "Djiboutian Franc (DJF)",
192 | "input": "Djibouti"
193 | },
194 | {
195 | "output": "East Caribbean Dollar (XCD)",
196 | "input": "Dominica"
197 | },
198 | {
199 | "output": "Dominican Peso",
200 | "input": "Dominican Republic"
201 | },
202 | {
203 | "output": "US Dollar (USD)",
204 | "input": "Ecuador"
205 | },
206 | {
207 | "output": "Egyptian Pound (EGP)",
208 | "input": "Egypt"
209 | },
210 | {
211 | "output": "US Dollar (USD)",
212 | "input": "El Salvador"
213 | },
214 | {
215 | "output": "Central African CFA franc",
216 | "input": "Equatorial Guinea"
217 | },
218 | {
219 | "output": "Nakfa",
220 | "input": "Eritrea"
221 | },
222 | {
223 | "output": "Euro (EUR)",
224 | "input": "Estonia"
225 | },
226 | {
227 | "output": "Lilangeni",
228 | "input": "Eswatini"
229 | },
230 | {
231 | "output": "Ethiopian Birr",
232 | "input": "Ethiopia"
233 | },
234 | {
235 | "output": "Fijian Dollar (FJD)",
236 | "input": "Fiji"
237 | },
238 | {
239 | "output": "Euro (EUR)",
240 | "input": "Finland"
241 | },
242 | {
243 | "output": "Euro (EUR)",
244 | "input": "France"
245 | },
246 | {
247 | "output": "Central African CFA franc",
248 | "input": "Gabon"
249 | },
250 | {
251 | "output": "Dalasi (GMD)",
252 | "input": "Gambia"
253 | },
254 | {
255 | "output": "Lari",
256 | "input": "Georgia"
257 | },
258 | {
259 | "output": "Euro (EUR)",
260 | "input": "Germany"
261 | },
262 | {
263 | "output": "Ghana Cedi (GHS)",
264 | "input": "Ghana"
265 | },
266 | {
267 | "output": "Euro (EUR)",
268 | "input": "Greece"
269 | },
270 | {
271 | "output": "East Caribbean Dollar (XCD)",
272 | "input": "Grenada"
273 | },
274 | {
275 | "output": "Quetzal",
276 | "input": "Guatemala"
277 | },
278 | {
279 | "output": "Guinean Franc",
280 | "input": "Guinea"
281 | },
282 | {
283 | "output": "West African CFA franc (XOF)",
284 | "input": "Guinea-Bissau"
285 | },
286 | {
287 | "output": "Guyanese Dollar",
288 | "input": "Guyana"
289 | },
290 | {
291 | "output": "Gourde",
292 | "input": "Haiti"
293 | },
294 | {
295 | "output": "Lempira",
296 | "input": "Honduras"
297 | },
298 | {
299 | "output": "Forint (HUF)",
300 | "input": "Hungary"
301 | },
302 | {
303 | "output": "Icelandic Króna (ISK)",
304 | "input": "Iceland"
305 | },
306 | {
307 | "output": "Indian Rupee",
308 | "input": "India"
309 | },
310 | {
311 | "output": "Indonesian Rupiah",
312 | "input": "Indonesia"
313 | },
314 | {
315 | "output": "Iranian Rial",
316 | "input": "Iran"
317 | },
318 | {
319 | "output": "Iraqi Dinar (IQD)",
320 | "input": "Iraq"
321 | },
322 | {
323 | "output": "Euro (EUR)",
324 | "input": "Ireland"
325 | },
326 | {
327 | "output": "Israeli Shekel",
328 | "input": "Israel"
329 | },
330 | {
331 | "output": "Euro (EUR)",
332 | "input": "Italy"
333 | },
334 | {
335 | "output": "Jamaican Dollar (JMD)",
336 | "input": "Jamaica"
337 | },
338 | {
339 | "output": "Japanese Yen",
340 | "input": "Japan"
341 | },
342 | {
343 | "output": "Jordanian Dinar (JOD)",
344 | "input": "Jordan"
345 | },
346 | {
347 | "output": "Tenge",
348 | "input": "Kazakhstan"
349 | },
350 | {
351 | "output": "Kenyan Shilling (KES)",
352 | "input": "Kenya"
353 | },
354 | {
355 | "output": "Australian Dollar (AUD)",
356 | "input": "Kiribati"
357 | },
358 | {
359 | "output": "Euro (EUR)",
360 | "input": "Kosovo"
361 | },
362 | {
363 | "output": "Kuwaiti Dinar (KWD)",
364 | "input": "Kuwait"
365 | },
366 | {
367 | "output": "Som (KGS)",
368 | "input": "Kyrgyzstan"
369 | },
370 | {
371 | "output": "Lao Kip (LAK)",
372 | "input": "Laos"
373 | },
374 | {
375 | "output": "Euro (EUR)",
376 | "input": "Latvia"
377 | },
378 | {
379 | "output": "Lebanese Pound (LBP)",
380 | "input": "Lebanon"
381 | },
382 | {
383 | "output": "Loti (LSL)",
384 | "input": "Lesotho"
385 | },
386 | {
387 | "output": "Liberian Dollar",
388 | "input": "Liberia"
389 | },
390 | {
391 | "output": "Libyan Dinar (LYD)",
392 | "input": "Libya"
393 | },
394 | {
395 | "output": "Swiss Franc (CHF)",
396 | "input": "Liechtenstein"
397 | },
398 | {
399 | "output": "Euro (EUR)",
400 | "input": "Lithuania"
401 | },
402 | {
403 | "output": "Euro (EUR)",
404 | "input": "Luxembourg"
405 | },
406 | {
407 | "output": "Ariary",
408 | "input": "Madagascar"
409 | },
410 | {
411 | "output": "Malawian Kwacha (MWK)",
412 | "input": "Malawi"
413 | },
414 | {
415 | "output": "Malaysian Ringgit (MYR)",
416 | "input": "Malaysia"
417 | },
418 | {
419 | "output": "Maldivian Rufiyaa (MVR)",
420 | "input": "Maldives"
421 | },
422 | {
423 | "output": "CFA Franc (XOF)",
424 | "input": "Mali"
425 | },
426 | {
427 | "output": "Euro (EUR)",
428 | "input": "Malta"
429 | },
430 | {
431 | "output": "US Dollar (USD)",
432 | "input": "Marshall Islands"
433 | },
434 | {
435 | "output": "Ouguiya (MRO)",
436 | "input": "Mauritania"
437 | },
438 | {
439 | "output": "Mauritian Rupee",
440 | "input": "Mauritius"
441 | },
442 | {
443 | "output": "Mexican Peso",
444 | "input": "Mexico"
445 | },
446 | {
447 | "output": "US Dollar (USD)",
448 | "input": "Micronesia"
449 | },
450 | {
451 | "output": "Moldovan Leu (MDL)",
452 | "input": "Moldova"
453 | },
454 | {
455 | "output": "Euro (EUR)",
456 | "input": "Monaco"
457 | },
458 | {
459 | "output": "Tugrik (MNT)",
460 | "input": "Mongolia"
461 | },
462 | {
463 | "output": "Euro (EUR)",
464 | "input": "Montenegro"
465 | },
466 | {
467 | "output": "Moroccan Dirham (MAD)",
468 | "input": "Morocco"
469 | },
470 | {
471 | "output": "Metical (MZN)",
472 | "input": "Mozambique"
473 | },
474 | {
475 | "output": "Kyat",
476 | "input": "Myanmar"
477 | },
478 | {
479 | "output": "Namibian Dollar (NAD)",
480 | "input": "Namibia"
481 | },
482 | {
483 | "output": "Australian Dollar (AUD)",
484 | "input": "Nauru"
485 | },
486 | {
487 | "output": "Nepalese Rupee",
488 | "input": "Nepal"
489 | },
490 | {
491 | "output": "Euro (EUR)",
492 | "input": "Netherlands"
493 | },
494 | {
495 | "output": "New Zealand Dollar (NZD)",
496 | "input": "New Zealand"
497 | },
498 | {
499 | "output": "Córdoba (NIO)",
500 | "input": "Nicaragua"
501 | },
502 | {
503 | "output": "West African CFA franc",
504 | "input": "Niger"
505 | },
506 | {
507 | "output": "Naira",
508 | "input": "Nigeria"
509 | },
510 | {
511 | "output": "North Korean Won (KPW)",
512 | "input": "North Korea"
513 | },
514 | {
515 | "output": "Macedonian Denar (MKD)",
516 | "input": "North Macedonia"
517 | },
518 | {
519 | "output": "Norwegian Krone (NOK)",
520 | "input": "Norway"
521 | },
522 | {
523 | "output": "Omani Rial",
524 | "input": "Oman"
525 | },
526 | {
527 | "output": "Pakistani Rupee",
528 | "input": "Pakistan"
529 | },
530 | {
531 | "output": "US Dollar (USD)",
532 | "input": "Palau"
533 | },
534 | {
535 | "output": "Israeli New Shekel (ILS)",
536 | "input": "Palestine"
537 | },
538 | {
539 | "output": "Balboa (PAB)",
540 | "input": "Panama"
541 | },
542 | {
543 | "output": "Kina (PGK)",
544 | "input": "Papua New Guinea"
545 | },
546 | {
547 | "output": "Guarani (PYG)",
548 | "input": "Paraguay"
549 | },
550 | {
551 | "output": "Sol (PEN)",
552 | "input": "Peru"
553 | },
554 | {
555 | "output": "Philippine Peso",
556 | "input": "Philippines"
557 | },
558 | {
559 | "output": "Polish Zloty (PLN)",
560 | "input": "Poland"
561 | },
562 | {
563 | "output": "Euro (EUR)",
564 | "input": "Portugal"
565 | },
566 | {
567 | "output": "Qatari Riyal",
568 | "input": "Qatar"
569 | },
570 | {
571 | "output": "Romanian Leu (RON)",
572 | "input": "Romania"
573 | },
574 | {
575 | "output": "Russian Ruble",
576 | "input": "Russia"
577 | },
578 | {
579 | "output": "Rwandan Franc (RWF)",
580 | "input": "Rwanda"
581 | },
582 | {
583 | "output": "East Caribbean Dollar (XCD)",
584 | "input": "Saint Kitts and Nevis"
585 | },
586 | {
587 | "output": "East Caribbean Dollar (XCD)",
588 | "input": "Saint Lucia"
589 | },
590 | {
591 | "output": "East Caribbean Dollar (XCD)",
592 | "input": "Saint Vincent and the Grenadines"
593 | },
594 | {
595 | "output": "Tala",
596 | "input": "Samoa"
597 | },
598 | {
599 | "output": "Euro (EUR)",
600 | "input": "San Marino"
601 | },
602 | {
603 | "output": "Dobra (STN)",
604 | "input": "Sao Tome and Principe"
605 | },
606 | {
607 | "output": "Saudi Riyal",
608 | "input": "Saudi Arabia"
609 | },
610 | {
611 | "output": "West African CFA franc",
612 | "input": "Senegal"
613 | },
614 | {
615 | "output": "Serbian Dinar (RSD)",
616 | "input": "Serbia"
617 | },
618 | {
619 | "output": "Seychellois Rupee (SCR)",
620 | "input": "Seychelles"
621 | },
622 | {
623 | "output": "Leone (SLL)",
624 | "input": "Sierra Leone"
625 | },
626 | {
627 | "output": "Singapore Dollar (SGD)",
628 | "input": "Singapore"
629 | },
630 | {
631 | "output": "Euro (EUR)",
632 | "input": "Slovakia"
633 | },
634 | {
635 | "output": "Euro (EUR)",
636 | "input": "Slovenia"
637 | },
638 | {
639 | "output": "Solomon Islands Dollar (SBD)",
640 | "input": "Solomon Islands"
641 | },
642 | {
643 | "output": "Somali Shilling (SOS)",
644 | "input": "Somalia"
645 | },
646 | {
647 | "output": "South African Rand (ZAR)",
648 | "input": "South Africa"
649 | },
650 | {
651 | "output": "South Korean Won (KRW)",
652 | "input": "South Korea"
653 | },
654 | {
655 | "output": "South Sudanese Pound (SSP)",
656 | "input": "South Sudan"
657 | },
658 | {
659 | "output": "Euro (EUR)",
660 | "input": "Spain"
661 | },
662 | {
663 | "output": "Sri Lankan Rupee",
664 | "input": "Sri Lanka"
665 | },
666 | {
667 | "output": "Sudanese Pound (SDG)",
668 | "input": "Sudan"
669 | },
670 | {
671 | "output": "Surinamese Dollar (SRD)",
672 | "input": "Suriname"
673 | },
674 | {
675 | "output": "Swedish Krona (SEK)",
676 | "input": "Sweden"
677 | },
678 | {
679 | "output": "Swiss Franc (CHF)",
680 | "input": "Switzerland"
681 | },
682 | {
683 | "output": "Syrian Pound",
684 | "input": "Syria"
685 | },
686 | {
687 | "output": "New Taiwan Dollar (TWD)",
688 | "input": "Taiwan"
689 | },
690 | {
691 | "output": "Tajikistani Somoni",
692 | "input": "Tajikistan"
693 | },
694 | {
695 | "output": "Tanzanian Shilling",
696 | "input": "Tanzania"
697 | },
698 | {
699 | "output": "Thai Baht",
700 | "input": "Thailand"
701 | },
702 | {
703 | "output": "US Dollar (USD)",
704 | "input": "Timor-Leste"
705 | },
706 | {
707 | "output": "CFA Franc",
708 | "input": "Togo"
709 | },
710 | {
711 | "output": "Pa'anga (TOP)",
712 | "input": "Tonga"
713 | },
714 | {
715 | "output": "Trinidad and Tobago Dollar (TTD)",
716 | "input": "Trinidad and Tobago"
717 | },
718 | {
719 | "output": "Tunisian Dinar (TND)",
720 | "input": "Tunisia"
721 | },
722 | {
723 | "output": "Turkish Lira (TRY)",
724 | "input": "Turkey"
725 | },
726 | {
727 | "output": "Turkmenistani Manat",
728 | "input": "Turkmenistan"
729 | },
730 | {
731 | "output": "Tuvaluan Dollar (TVD)",
732 | "input": "Tuvalu"
733 | },
734 | {
735 | "output": "Ugandan Shilling",
736 | "input": "Uganda"
737 | },
738 | {
739 | "output": "Ukrainian Hryvnia (UAH)",
740 | "input": "Ukraine"
741 | },
742 | {
743 | "output": "UAE Dirham",
744 | "input": "United Arab Emirates"
745 | },
746 | {
747 | "output": "British Pound (GBP)",
748 | "input": "United Kingdom"
749 | },
750 | {
751 | "output": "US Dollar (USD)",
752 | "input": "United States of America"
753 | },
754 | {
755 | "output": "Uruguayan Peso",
756 | "input": "Uruguay"
757 | },
758 | {
759 | "output": "Uzbekistani Som (UZS)",
760 | "input": "Uzbekistan"
761 | },
762 | {
763 | "output": "Vatu (VUV)",
764 | "input": "Vanuatu"
765 | },
766 | {
767 | "output": "Euro (EUR)",
768 | "input": "Vatican City"
769 | },
770 | {
771 | "output": "Bolívar Soberano (VES)",
772 | "input": "Venezuela"
773 | },
774 | {
775 | "output": "Vietnamese Dong (VND)",
776 | "input": "Vietnam"
777 | },
778 | {
779 | "output": "Yemeni Rial",
780 | "input": "Yemen"
781 | },
782 | {
783 | "output": "Kwacha",
784 | "input": "Zambia"
785 | },
786 | {
787 | "output": "Zimbabwean Dollar",
788 | "input": "Zimbabwe"
789 | }
790 | ]
--------------------------------------------------------------------------------
/dataset_files/abstractive/next_item.json:
--------------------------------------------------------------------------------
1 | [{"input": "zero", "output": "one"}, {"input": "one", "output": "two"}, {"input": "two", "output": "three"}, {"input": "three", "output": "four"}, {"input": "four", "output": "five"}, {"input": "five", "output": "six"}, {"input": "six", "output": "seven"}, {"input": "seven", "output": "eight"}, {"input": "eight", "output": "nine"}, {"input": "nine", "output": "ten"}, {"input": "ten", "output": "eleven"}, {"input": "eleven", "output": "twelve"}, {"input": "twelve", "output": "thirteen"}, {"input": "thirteen", "output": "fourteen"}, {"input": "fourteen", "output": "fifteen"}, {"input": "fifteen", "output": "sixteen"}, {"input": "sixteen", "output": "seventeen"}, {"input": "seventeen", "output": "eighteen"}, {"input": "eighteen", "output": "nineteen"}, {"input": "nineteen", "output": "twenty"}, {"input": "0", "output": "1"}, {"input": "1", "output": "2"}, {"input": "2", "output": "3"}, {"input": "3", "output": "4"}, {"input": "4", "output": "5"}, {"input": "5", "output": "6"}, {"input": "6", "output": "7"}, {"input": "7", "output": "8"}, {"input": "8", "output": "9"}, {"input": "9", "output": "10"}, {"input": "10", "output": "11"}, {"input": "11", "output": "12"}, {"input": "12", "output": "13"}, {"input": "13", "output": "14"}, {"input": "14", "output": "15"}, {"input": "15", "output": "16"}, {"input": "16", "output": "17"}, {"input": "17", "output": "18"}, {"input": "18", "output": "19"}, {"input": "19", "output": "20"}, {"input": "20", "output": "21"}, {"input": "21", "output": "22"}, {"input": "22", "output": "23"}, {"input": "23", "output": "24"}, {"input": "24", "output": "25"}, {"input": "25", "output": "26"}, {"input": "26", "output": "27"}, {"input": "27", "output": "28"}, {"input": "28", "output": "29"}, {"input": "a", "output": "b"}, {"input": "b", "output": "c"}, {"input": "c", "output": "d"}, {"input": "d", "output": "e"}, {"input": "e", "output": "f"}, {"input": "f", "output": "g"}, {"input": "g", "output": "h"}, {"input": "h", "output": "i"}, 
{"input": "i", "output": "j"}, {"input": "j", "output": "k"}, {"input": "k", "output": "l"}, {"input": "l", "output": "m"}, {"input": "m", "output": "n"}, {"input": "n", "output": "o"}, {"input": "o", "output": "p"}, {"input": "p", "output": "q"}, {"input": "q", "output": "r"}, {"input": "r", "output": "s"}, {"input": "s", "output": "t"}, {"input": "t", "output": "u"}, {"input": "u", "output": "v"}, {"input": "v", "output": "w"}, {"input": "w", "output": "x"}, {"input": "x", "output": "y"}, {"input": "y", "output": "z"}, {"input": "A", "output": "B"}, {"input": "B", "output": "C"}, {"input": "C", "output": "D"}, {"input": "D", "output": "E"}, {"input": "E", "output": "F"}, {"input": "F", "output": "G"}, {"input": "G", "output": "H"}, {"input": "H", "output": "I"}, {"input": "I", "output": "J"}, {"input": "J", "output": "K"}, {"input": "K", "output": "L"}, {"input": "L", "output": "M"}, {"input": "M", "output": "N"}, {"input": "N", "output": "O"}, {"input": "O", "output": "P"}, {"input": "P", "output": "Q"}, {"input": "Q", "output": "R"}, {"input": "R", "output": "S"}, {"input": "S", "output": "T"}, {"input": "T", "output": "U"}, {"input": "U", "output": "V"}, {"input": "V", "output": "W"}, {"input": "W", "output": "X"}, {"input": "X", "output": "Y"}, {"input": "Y", "output": "Z"}, {"input": "AA", "output": "BB"}, {"input": "BB", "output": "CC"}, {"input": "CC", "output": "DD"}, {"input": "DD", "output": "EE"}, {"input": "EE", "output": "FF"}, {"input": "FF", "output": "GG"}, {"input": "GG", "output": "HH"}, {"input": "HH", "output": "II"}, {"input": "II", "output": "JJ"}, {"input": "JJ", "output": "KK"}, {"input": "KK", "output": "LL"}, {"input": "LL", "output": "MM"}, {"input": "MM", "output": "NN"}, {"input": "NN", "output": "OO"}, {"input": "OO", "output": "PP"}, {"input": "PP", "output": "QQ"}, {"input": "QQ", "output": "RR"}, {"input": "RR", "output": "SS"}, {"input": "SS", "output": "TT"}, {"input": "TT", "output": "UU"}, {"input": "UU", "output": "VV"}, 
{"input": "VV", "output": "WW"}, {"input": "WW", "output": "XX"}, {"input": "XX", "output": "YY"}, {"input": "YY", "output": "ZZ"}, {"input": "aa", "output": "bb"}, {"input": "bb", "output": "cc"}, {"input": "cc", "output": "dd"}, {"input": "dd", "output": "ee"}, {"input": "ee", "output": "ff"}, {"input": "ff", "output": "gg"}, {"input": "gg", "output": "hh"}, {"input": "hh", "output": "ii"}, {"input": "ii", "output": "jj"}, {"input": "jj", "output": "kk"}, {"input": "kk", "output": "ll"}, {"input": "ll", "output": "mm"}, {"input": "mm", "output": "nn"}, {"input": "nn", "output": "oo"}, {"input": "oo", "output": "pp"}, {"input": "pp", "output": "qq"}, {"input": "qq", "output": "rr"}, {"input": "rr", "output": "ss"}, {"input": "ss", "output": "tt"}, {"input": "tt", "output": "uu"}, {"input": "uu", "output": "vv"}, {"input": "vv", "output": "ww"}, {"input": "ww", "output": "xx"}, {"input": "xx", "output": "yy"}, {"input": "yy", "output": "zz"}, {"input": "I", "output": "II"}, {"input": "II", "output": "III"}, {"input": "III", "output": "IV"}, {"input": "IV", "output": "V"}, {"input": "V", "output": "VI"}, {"input": "VI", "output": "VII"}, {"input": "VII", "output": "VIII"}, {"input": "VIII", "output": "IX"}, {"input": "IX", "output": "X"}, {"input": "X", "output": "XI"}, {"input": "XI", "output": "XII"}, {"input": "XII", "output": "XIII"}, {"input": "XIII", "output": "XIV"}, {"input": "XIV", "output": "XV"}, {"input": "XV", "output": "XVI"}, {"input": "XVI", "output": "XVII"}, {"input": "XVII", "output": "XVIII"}, {"input": "XVIII", "output": "XIX"}, {"input": "XIX", "output": "XX"}, {"input": "i", "output": "ii"}, {"input": "ii", "output": "iii"}, {"input": "iii", "output": "iv"}, {"input": "iv", "output": "v"}, {"input": "v", "output": "vi"}, {"input": "vi", "output": "vii"}, {"input": "vii", "output": "viii"}, {"input": "viii", "output": "ix"}, {"input": "ix", "output": "x"}, {"input": "x", "output": "xi"}, {"input": "xi", "output": "xii"}, {"input": "xii", 
"output": "xiii"}, {"input": "xiii", "output": "xiv"}, {"input": "xiv", "output": "xv"}, {"input": "xv", "output": "xvi"}, {"input": "xvi", "output": "xvii"}, {"input": "xvii", "output": "xviii"}, {"input": "xviii", "output": "xix"}, {"input": "xix", "output": "xx"}, {"input": "monday", "output": "tuesday"}, {"input": "tuesday", "output": "wednesday"}, {"input": "wednesday", "output": "thursday"}, {"input": "thursday", "output": "friday"}, {"input": "friday", "output": "saturday"}, {"input": "saturday", "output": "sunday"}, {"input": "january", "output": "february"}, {"input": "february", "output": "march"}, {"input": "march", "output": "april"}, {"input": "april", "output": "may"}, {"input": "may", "output": "june"}, {"input": "june", "output": "july"}, {"input": "july", "output": "august"}, {"input": "august", "output": "september"}, {"input": "september", "output": "october"}, {"input": "october", "output": "november"}, {"input": "november", "output": "december"}, {"input": "Monday", "output": "Tuesday"}, {"input": "Tuesday", "output": "Wednesday"}, {"input": "Wednesday", "output": "Thursday"}, {"input": "Thursday", "output": "Friday"}, {"input": "Friday", "output": "Saturday"}, {"input": "Saturday", "output": "Sunday"}, {"input": "January", "output": "February"}, {"input": "February", "output": "March"}, {"input": "March", "output": "April"}, {"input": "April", "output": "May"}, {"input": "May", "output": "June"}, {"input": "June", "output": "July"}, {"input": "July", "output": "August"}, {"input": "August", "output": "September"}, {"input": "September", "output": "October"}, {"input": "October", "output": "November"}, {"input": "November", "output": "December"}, {"input": "sunday", "output": "monday"}, {"input": "december", "output": "january"}, {"input": "Sunday", "output": "Monday"}, {"input": "December", "output": "January"}]
--------------------------------------------------------------------------------
/dataset_files/abstractive/prev_item.json:
--------------------------------------------------------------------------------
1 | [{"input": "one", "output": "zero"}, {"input": "two", "output": "one"}, {"input": "three", "output": "two"}, {"input": "four", "output": "three"}, {"input": "five", "output": "four"}, {"input": "six", "output": "five"}, {"input": "seven", "output": "six"}, {"input": "eight", "output": "seven"}, {"input": "nine", "output": "eight"}, {"input": "ten", "output": "nine"}, {"input": "eleven", "output": "ten"}, {"input": "twelve", "output": "eleven"}, {"input": "thirteen", "output": "twelve"}, {"input": "fourteen", "output": "thirteen"}, {"input": "fifteen", "output": "fourteen"}, {"input": "sixteen", "output": "fifteen"}, {"input": "seventeen", "output": "sixteen"}, {"input": "eighteen", "output": "seventeen"}, {"input": "nineteen", "output": "eighteen"}, {"input": "twenty", "output": "nineteen"}, {"input": "1", "output": "0"}, {"input": "2", "output": "1"}, {"input": "3", "output": "2"}, {"input": "4", "output": "3"}, {"input": "5", "output": "4"}, {"input": "6", "output": "5"}, {"input": "7", "output": "6"}, {"input": "8", "output": "7"}, {"input": "9", "output": "8"}, {"input": "10", "output": "9"}, {"input": "11", "output": "10"}, {"input": "12", "output": "11"}, {"input": "13", "output": "12"}, {"input": "14", "output": "13"}, {"input": "15", "output": "14"}, {"input": "16", "output": "15"}, {"input": "17", "output": "16"}, {"input": "18", "output": "17"}, {"input": "19", "output": "18"}, {"input": "20", "output": "19"}, {"input": "21", "output": "20"}, {"input": "22", "output": "21"}, {"input": "23", "output": "22"}, {"input": "24", "output": "23"}, {"input": "25", "output": "24"}, {"input": "26", "output": "25"}, {"input": "27", "output": "26"}, {"input": "28", "output": "27"}, {"input": "29", "output": "28"}, {"input": "b", "output": "a"}, {"input": "c", "output": "b"}, {"input": "d", "output": "c"}, {"input": "e", "output": "d"}, {"input": "f", "output": "e"}, {"input": "g", "output": "f"}, {"input": "h", "output": "g"}, {"input": "i", "output": "h"}, 
{"input": "j", "output": "i"}, {"input": "k", "output": "j"}, {"input": "l", "output": "k"}, {"input": "m", "output": "l"}, {"input": "n", "output": "m"}, {"input": "o", "output": "n"}, {"input": "p", "output": "o"}, {"input": "q", "output": "p"}, {"input": "r", "output": "q"}, {"input": "s", "output": "r"}, {"input": "t", "output": "s"}, {"input": "u", "output": "t"}, {"input": "v", "output": "u"}, {"input": "w", "output": "v"}, {"input": "x", "output": "w"}, {"input": "y", "output": "x"}, {"input": "z", "output": "y"}, {"input": "B", "output": "A"}, {"input": "C", "output": "B"}, {"input": "D", "output": "C"}, {"input": "E", "output": "D"}, {"input": "F", "output": "E"}, {"input": "G", "output": "F"}, {"input": "H", "output": "G"}, {"input": "I", "output": "H"}, {"input": "J", "output": "I"}, {"input": "K", "output": "J"}, {"input": "L", "output": "K"}, {"input": "M", "output": "L"}, {"input": "N", "output": "M"}, {"input": "O", "output": "N"}, {"input": "P", "output": "O"}, {"input": "Q", "output": "P"}, {"input": "R", "output": "Q"}, {"input": "S", "output": "R"}, {"input": "T", "output": "S"}, {"input": "U", "output": "T"}, {"input": "V", "output": "U"}, {"input": "W", "output": "V"}, {"input": "X", "output": "W"}, {"input": "Y", "output": "X"}, {"input": "Z", "output": "Y"}, {"input": "BB", "output": "AA"}, {"input": "CC", "output": "BB"}, {"input": "DD", "output": "CC"}, {"input": "EE", "output": "DD"}, {"input": "FF", "output": "EE"}, {"input": "GG", "output": "FF"}, {"input": "HH", "output": "GG"}, {"input": "II", "output": "HH"}, {"input": "JJ", "output": "II"}, {"input": "KK", "output": "JJ"}, {"input": "LL", "output": "KK"}, {"input": "MM", "output": "LL"}, {"input": "NN", "output": "MM"}, {"input": "OO", "output": "NN"}, {"input": "PP", "output": "OO"}, {"input": "QQ", "output": "PP"}, {"input": "RR", "output": "QQ"}, {"input": "SS", "output": "RR"}, {"input": "TT", "output": "SS"}, {"input": "UU", "output": "TT"}, {"input": "VV", "output": "UU"}, 
{"input": "WW", "output": "VV"}, {"input": "XX", "output": "WW"}, {"input": "YY", "output": "XX"}, {"input": "ZZ", "output": "YY"}, {"input": "bb", "output": "aa"}, {"input": "cc", "output": "bb"}, {"input": "dd", "output": "cc"}, {"input": "ee", "output": "dd"}, {"input": "ff", "output": "ee"}, {"input": "gg", "output": "ff"}, {"input": "hh", "output": "gg"}, {"input": "ii", "output": "hh"}, {"input": "jj", "output": "ii"}, {"input": "kk", "output": "jj"}, {"input": "ll", "output": "kk"}, {"input": "mm", "output": "ll"}, {"input": "nn", "output": "mm"}, {"input": "oo", "output": "nn"}, {"input": "pp", "output": "oo"}, {"input": "qq", "output": "pp"}, {"input": "rr", "output": "qq"}, {"input": "ss", "output": "rr"}, {"input": "tt", "output": "ss"}, {"input": "uu", "output": "tt"}, {"input": "vv", "output": "uu"}, {"input": "ww", "output": "vv"}, {"input": "xx", "output": "ww"}, {"input": "yy", "output": "xx"}, {"input": "zz", "output": "yy"}, {"input": "II", "output": "I"}, {"input": "III", "output": "II"}, {"input": "IV", "output": "III"}, {"input": "V", "output": "IV"}, {"input": "VI", "output": "V"}, {"input": "VII", "output": "VI"}, {"input": "VIII", "output": "VII"}, {"input": "IX", "output": "VIII"}, {"input": "X", "output": "IX"}, {"input": "XI", "output": "X"}, {"input": "XII", "output": "XI"}, {"input": "XIII", "output": "XII"}, {"input": "XIV", "output": "XIII"}, {"input": "XV", "output": "XIV"}, {"input": "XVI", "output": "XV"}, {"input": "XVII", "output": "XVI"}, {"input": "XVIII", "output": "XVII"}, {"input": "XIX", "output": "XVIII"}, {"input": "XX", "output": "XIX"}, {"input": "ii", "output": "i"}, {"input": "iii", "output": "ii"}, {"input": "iv", "output": "iii"}, {"input": "v", "output": "iv"}, {"input": "vi", "output": "v"}, {"input": "vii", "output": "vi"}, {"input": "viii", "output": "vii"}, {"input": "ix", "output": "viii"}, {"input": "x", "output": "ix"}, {"input": "xi", "output": "x"}, {"input": "xii", "output": "xi"}, {"input": "xiii", 
"output": "xii"}, {"input": "xiv", "output": "xiii"}, {"input": "xv", "output": "xiv"}, {"input": "xvi", "output": "xv"}, {"input": "xvii", "output": "xvi"}, {"input": "xviii", "output": "xvii"}, {"input": "xix", "output": "xviii"}, {"input": "xx", "output": "xix"}, {"input": "tuesday", "output": "monday"}, {"input": "wednesday", "output": "tuesday"}, {"input": "thursday", "output": "wednesday"}, {"input": "friday", "output": "thursday"}, {"input": "saturday", "output": "friday"}, {"input": "sunday", "output": "saturday"}, {"input": "february", "output": "january"}, {"input": "march", "output": "february"}, {"input": "april", "output": "march"}, {"input": "may", "output": "april"}, {"input": "june", "output": "may"}, {"input": "july", "output": "june"}, {"input": "august", "output": "july"}, {"input": "september", "output": "august"}, {"input": "october", "output": "september"}, {"input": "november", "output": "october"}, {"input": "december", "output": "november"}, {"input": "Tuesday", "output": "Monday"}, {"input": "Wednesday", "output": "Tuesday"}, {"input": "Thursday", "output": "Wednesday"}, {"input": "Friday", "output": "Thursday"}, {"input": "Saturday", "output": "Friday"}, {"input": "Sunday", "output": "Saturday"}, {"input": "February", "output": "January"}, {"input": "March", "output": "February"}, {"input": "April", "output": "March"}, {"input": "May", "output": "April"}, {"input": "June", "output": "May"}, {"input": "July", "output": "June"}, {"input": "August", "output": "July"}, {"input": "September", "output": "August"}, {"input": "October", "output": "September"}, {"input": "November", "output": "October"}, {"input": "December", "output": "November"}, {"input": "monday", "output": "sunday"}, {"input": "january", "output": "december"}, {"input": "Monday", "output": "Sunday"}, {"input": "January", "output": "December"}]
--------------------------------------------------------------------------------
/dataset_files/abstractive/singular-plural.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "input": "wallet",
4 | "output": "wallets"
5 | },
6 | {
7 | "input": "keychain",
8 | "output": "keychains"
9 | },
10 | {
11 | "input": "mountain",
12 | "output": "mountains"
13 | },
14 | {
15 | "input": "comb",
16 | "output": "combs"
17 | },
18 | {
19 | "input": "monitor",
20 | "output": "monitors"
21 | },
22 | {
23 | "input": "island",
24 | "output": "islands"
25 | },
26 | {
27 | "input": "rake",
28 | "output": "rakes"
29 | },
30 | {
31 | "input": "needle",
32 | "output": "needles"
33 | },
34 | {
35 | "input": "lighter",
36 | "output": "lighters"
37 | },
38 | {
39 | "input": "slipper",
40 | "output": "slippers"
41 | },
42 | {
43 | "input": "fireplace",
44 | "output": "fireplaces"
45 | },
46 | {
47 | "input": "ladder",
48 | "output": "ladders"
49 | },
50 | {
51 | "input": "jacket",
52 | "output": "jackets"
53 | },
54 | {
55 | "input": "helicopter",
56 | "output": "helicopters"
57 | },
58 | {
59 | "input": "paintbrush",
60 | "output": "paintbrushes"
61 | },
62 | {
63 | "input": "dustpan",
64 | "output": "dustpans"
65 | },
66 | {
67 | "input": "wrench",
68 | "output": "wrenches"
69 | },
70 | {
71 | "input": "tablet",
72 | "output": "tablets"
73 | },
74 | {
75 | "input": "hoe",
76 | "output": "hoes"
77 | },
78 | {
79 | "input": "tie",
80 | "output": "ties"
81 | },
82 | {
83 | "input": "toy",
84 | "output": "toys"
85 | },
86 | {
87 | "input": "glass",
88 | "output": "glasses"
89 | },
90 | {
91 | "input": "hairdryer",
92 | "output": "hairdryers"
93 | },
94 | {
95 | "input": "axe",
96 | "output": "axes"
97 | },
98 | {
99 | "input": "vacuum",
100 | "output": "vacuums"
101 | },
102 | {
103 | "input": "blush",
104 | "output": "blushes"
105 | },
106 | {
107 | "input": "stove",
108 | "output": "stoves"
109 | },
110 | {
111 | "input": "ladle",
112 | "output": "ladles"
113 | },
114 | {
115 | "input": "poster",
116 | "output": "posters"
117 | },
118 | {
119 | "input": "hat",
120 | "output": "hats"
121 | },
122 | {
123 | "input": "lake",
124 | "output": "lakes"
125 | },
126 | {
127 | "input": "razor",
128 | "output": "razors"
129 | },
130 | {
131 | "input": "bottle",
132 | "output": "bottles"
133 | },
134 | {
135 | "input": "glove",
136 | "output": "gloves"
137 | },
138 | {
139 | "input": "grater",
140 | "output": "graters"
141 | },
142 | {
143 | "input": "dishwasher",
144 | "output": "dishwashers"
145 | },
146 | {
147 | "input": "sofa",
148 | "output": "sofas"
149 | },
150 | {
151 | "input": "bag",
152 | "output": "bags"
153 | },
154 | {
155 | "input": "keyboard",
156 | "output": "keyboards"
157 | },
158 | {
159 | "input": "clock",
160 | "output": "clocks"
161 | },
162 | {
163 | "input": "book",
164 | "output": "books"
165 | },
166 | {
167 | "input": "scarf",
168 | "output": "scarves"
169 | },
170 | {
171 | "input": "pants",
172 | "output": "pants"
173 | },
174 | {
175 | "input": "window",
176 | "output": "windows"
177 | },
178 | {
179 | "input": "house",
180 | "output": "houses"
181 | },
182 | {
183 | "input": "freezer",
184 | "output": "freezers"
185 | },
186 | {
187 | "input": "rag",
188 | "output": "rags"
189 | },
190 | {
191 | "input": "racquet",
192 | "output": "racquets"
193 | },
194 | {
195 | "input": "hair gel",
196 | "output": "hair gels"
197 | },
198 | {
199 | "input": "door",
200 | "output": "doors"
201 | },
202 | {
203 | "input": "pillow",
204 | "output": "pillows"
205 | },
206 | {
207 | "input": "ruler",
208 | "output": "rulers"
209 | },
210 | {
211 | "input": "washer",
212 | "output": "washers"
213 | },
214 | {
215 | "input": "ocean",
216 | "output": "oceans"
217 | },
218 | {
219 | "input": "plate",
220 | "output": "plates"
221 | },
222 | {
223 | "input": "eyeshadow",
224 | "output": "eyeshadows"
225 | },
226 | {
227 | "input": "zipper",
228 | "output": "zippers"
229 | },
230 | {
231 | "input": "radio",
232 | "output": "radios"
233 | },
234 | {
235 | "input": "flower",
236 | "output": "flowers"
237 | },
238 | {
239 | "input": "laptop",
240 | "output": "laptops"
241 | },
242 | {
243 | "input": "eraser",
244 | "output": "erasers"
245 | },
246 | {
247 | "input": "corkscrew",
248 | "output": "corkscrews"
249 | },
250 | {
251 | "input": "eyeliner",
252 | "output": "eyeliners"
253 | },
254 | {
255 | "input": "desk",
256 | "output": "desks"
257 | },
258 | {
259 | "input": "knife",
260 | "output": "knives"
261 | },
262 | {
263 | "input": "helmet",
264 | "output": "helmets"
265 | },
266 | {
267 | "input": "mixer",
268 | "output": "mixers"
269 | },
270 | {
271 | "input": "microwave",
272 | "output": "microwaves"
273 | },
274 | {
275 | "input": "button",
276 | "output": "buttons"
277 | },
278 | {
279 | "input": "jar",
280 | "output": "jars"
281 | },
282 | {
283 | "input": "pan",
284 | "output": "pans"
285 | },
286 | {
287 | "input": "key",
288 | "output": "keys"
289 | },
290 | {
291 | "input": "perfume",
292 | "output": "perfumes"
293 | },
294 | {
295 | "input": "tape",
296 | "output": "tapes"
297 | },
298 | {
299 | "input": "shoes",
300 | "output": "shoes"
301 | },
302 | {
303 | "input": "shirt",
304 | "output": "shirts"
305 | },
306 | {
307 | "input": "candle",
308 | "output": "candles"
309 | },
310 | {
311 | "input": "juicer",
312 | "output": "juicers"
313 | },
314 | {
315 | "input": "peeler",
316 | "output": "peelers"
317 | },
318 | {
319 | "input": "mirror",
320 | "output": "mirrors"
321 | },
322 | {
323 | "input": "mascara",
324 | "output": "mascaras"
325 | },
326 | {
327 | "input": "whisk",
328 | "output": "whisks"
329 | },
330 | {
331 | "input": "shovel",
332 | "output": "shovels"
333 | },
334 | {
335 | "input": "marker",
336 | "output": "markers"
337 | },
338 | {
339 | "input": "lotion",
340 | "output": "lotions"
341 | },
342 | {
343 | "input": "matches",
344 | "output": "matches"
345 | },
346 | {
347 | "input": "moon",
348 | "output": "moons"
349 | },
350 | {
351 | "input": "pot",
352 | "output": "pots"
353 | },
354 | {
355 | "input": "mop",
356 | "output": "mops"
357 | },
358 | {
359 | "input": "sprinkler",
360 | "output": "sprinklers"
361 | },
362 | {
363 | "input": "can",
364 | "output": "cans"
365 | },
366 | {
367 | "input": "notebook",
368 | "output": "notebooks"
369 | },
370 | {
371 | "input": "airplane",
372 | "output": "airplanes"
373 | },
374 | {
375 | "input": "tongs",
376 | "output": "tongs"
377 | },
378 | {
379 | "input": "phone",
380 | "output": "phones"
381 | },
382 | {
383 | "input": "paint",
384 | "output": "paints"
385 | },
386 | {
387 | "input": "conditioner",
388 | "output": "conditioners"
389 | },
390 | {
391 | "input": "purse",
392 | "output": "purses"
393 | },
394 | {
395 | "input": "broom",
396 | "output": "brooms"
397 | },
398 | {
399 | "input": "rug",
400 | "output": "rugs"
401 | },
402 | {
403 | "input": "toaster",
404 | "output": "toasters"
405 | },
406 | {
407 | "input": "kettle",
408 | "output": "kettles"
409 | },
410 | {
411 | "input": "blender",
412 | "output": "blenders"
413 | },
414 | {
415 | "input": "colander",
416 | "output": "colanders"
417 | },
418 | {
419 | "input": "toothpaste",
420 | "output": "toothpastes"
421 | },
422 | {
423 | "input": "bed",
424 | "output": "beds"
425 | },
426 | {
427 | "input": "refrigerator",
428 | "output": "refrigerators"
429 | },
430 | {
431 | "input": "stapler",
432 | "output": "staplers"
433 | },
434 | {
435 | "input": "backpack",
436 | "output": "backpacks"
437 | },
438 | {
439 | "input": "fabric",
440 | "output": "fabrics"
441 | },
442 | {
443 | "input": "nut",
444 | "output": "nuts"
445 | },
446 | {
447 | "input": "bowl",
448 | "output": "bowls"
449 | },
450 | {
451 | "input": "bolt",
452 | "output": "bolts"
453 | },
454 | {
455 | "input": "chopsticks",
456 | "output": "chopsticks"
457 | },
458 | {
459 | "input": "computer",
460 | "output": "computers"
461 | },
462 | {
463 | "input": "valley",
464 | "output": "valleys"
465 | },
466 | {
467 | "input": "lantern",
468 | "output": "lanterns"
469 | },
470 | {
471 | "input": "boot",
472 | "output": "boots"
473 | },
474 | {
475 | "input": "bucket",
476 | "output": "buckets"
477 | },
478 | {
479 | "input": "sandal",
480 | "output": "sandals"
481 | },
482 | {
483 | "input": "lock",
484 | "output": "locks"
485 | },
486 | {
487 | "input": "drill",
488 | "output": "drills"
489 | },
490 | {
491 | "input": "toothbrush",
492 | "output": "toothbrushes"
493 | },
494 | {
495 | "input": "lamp",
496 | "output": "lamps"
497 | },
498 | {
499 | "input": "star",
500 | "output": "stars"
501 | },
502 | {
503 | "input": "bandana",
504 | "output": "bandanas"
505 | },
506 | {
507 | "input": "stereo",
508 | "output": "stereos"
509 | },
510 | {
511 | "input": "teapot",
512 | "output": "teapots"
513 | },
514 | {
515 | "input": "thread",
516 | "output": "threads"
517 | },
518 | {
519 | "input": "lawnmower",
520 | "output": "lawnmowers"
521 | },
522 | {
523 | "input": "mug",
524 | "output": "mugs"
525 | },
526 | {
527 | "input": "game",
528 | "output": "games"
529 | },
530 | {
531 | "input": "shoe",
532 | "output": "shoes"
533 | },
534 | {
535 | "input": "mattress",
536 | "output": "mattresses"
537 | },
538 | {
539 | "input": "sunglasses",
540 | "output": "sunglasses"
541 | },
542 | {
543 | "input": "river",
544 | "output": "rivers"
545 | },
546 | {
547 | "input": "beach",
548 | "output": "beaches"
549 | },
550 | {
551 | "input": "sponge",
552 | "output": "sponges"
553 | },
554 | {
555 | "input": "blanket",
556 | "output": "blankets"
557 | },
558 | {
559 | "input": "headphones",
560 | "output": "headphones"
561 | },
562 | {
563 | "input": "mouse",
564 | "output": "mice"
565 | },
566 | {
567 | "input": "pen",
568 | "output": "pens"
569 | },
570 | {
571 | "input": "tree",
572 | "output": "trees"
573 | },
574 | {
575 | "input": "car",
576 | "output": "cars"
577 | },
578 | {
579 | "input": "belt",
580 | "output": "belts"
581 | },
582 | {
583 | "input": "sky",
584 | "output": "skies"
585 | },
586 | {
587 | "input": "socks",
588 | "output": "socks"
589 | },
590 | {
591 | "input": "briefcase",
592 | "output": "briefcases"
593 | },
594 | {
595 | "input": "crayon",
596 | "output": "crayons"
597 | },
598 | {
599 | "input": "screw",
600 | "output": "screws"
601 | },
602 | {
603 | "input": "bicycle",
604 | "output": "bicycles"
605 | },
606 | {
607 | "input": "grill",
608 | "output": "grills"
609 | },
610 | {
611 | "input": "sun",
612 | "output": "suns"
613 | },
614 | {
615 | "input": "coffee maker",
616 | "output": "coffee makers"
617 | },
618 | {
619 | "input": "forest",
620 | "output": "forests"
621 | },
622 | {
623 | "input": "spatula",
624 | "output": "spatulas"
625 | },
626 | {
627 | "input": "brush",
628 | "output": "brushes"
629 | },
630 | {
631 | "input": "cup",
632 | "output": "cups"
633 | },
634 | {
635 | "input": "picture",
636 | "output": "pictures"
637 | },
638 | {
639 | "input": "jewelry",
640 | "output": "jewelry"
641 | },
642 | {
643 | "input": "shampoo",
644 | "output": "shampoos"
645 | },
646 | {
647 | "input": "speaker",
648 | "output": "speakers"
649 | },
650 | {
651 | "input": "boat",
652 | "output": "boats"
653 | },
654 | {
655 | "input": "bus",
656 | "output": "buses"
657 | },
658 | {
659 | "input": "sock",
660 | "output": "socks"
661 | },
662 | {
663 | "input": "pliers",
664 | "output": "pliers"
665 | },
666 | {
667 | "input": "suitcase",
668 | "output": "suitcases"
669 | },
670 | {
671 | "input": "saw",
672 | "output": "saws"
673 | },
674 | {
675 | "input": "fork",
676 | "output": "forks"
677 | },
678 | {
679 | "input": "calendar",
680 | "output": "calendars"
681 | },
682 | {
683 | "input": "umbrella",
684 | "output": "umbrellas"
685 | },
686 | {
687 | "input": "printer",
688 | "output": "printers"
689 | },
690 | {
691 | "input": "hose",
692 | "output": "hoses"
693 | },
694 | {
695 | "input": "paper",
696 | "output": "papers"
697 | },
698 | {
699 | "input": "television",
700 | "output": "televisions"
701 | },
702 | {
703 | "input": "newspaper",
704 | "output": "newspapers"
705 | },
706 | {
707 | "input": "hammer",
708 | "output": "hammers"
709 | },
710 | {
711 | "input": "chair",
712 | "output": "chairs"
713 | },
714 | {
715 | "input": "ball",
716 | "output": "balls"
717 | },
718 | {
719 | "input": "watch",
720 | "output": "watches"
721 | },
722 | {
723 | "input": "screwdriver",
724 | "output": "screwdrivers"
725 | },
726 | {
727 | "input": "cloud",
728 | "output": "clouds"
729 | },
730 | {
731 | "input": "dryer",
732 | "output": "dryers"
733 | },
734 | {
735 | "input": "train",
736 | "output": "trains"
737 | },
738 | {
739 | "input": "lipstick",
740 | "output": "lipsticks"
741 | },
742 | {
743 | "input": "hairbrush",
744 | "output": "hairbrushes"
745 | },
746 | {
747 | "input": "glue",
748 | "output": "glues"
749 | },
750 | {
751 | "input": "bat",
752 | "output": "bats"
753 | },
754 | {
755 | "input": "roller",
756 | "output": "rollers"
757 | },
758 | {
759 | "input": "ship",
760 | "output": "ships"
761 | },
762 | {
763 | "input": "spoon",
764 | "output": "spoons"
765 | },
766 | {
767 | "input": "motorcycle",
768 | "output": "motorcycles"
769 | },
770 | {
771 | "input": "oven",
772 | "output": "ovens"
773 | },
774 | {
775 | "input": "soap",
776 | "output": "soaps"
777 | },
778 | {
779 | "input": "table",
780 | "output": "tables"
781 | },
782 | {
783 | "input": "doll",
784 | "output": "dolls"
785 | },
786 | {
787 | "input": "flashlight",
788 | "output": "flashlights"
789 | },
790 | {
791 | "input": "scissors",
792 | "output": "scissors"
793 | },
794 | {
795 | "input": "pencil",
796 | "output": "pencils"
797 | },
798 | {
799 | "input": "magazine",
800 | "output": "magazines"
801 | },
802 | {
803 | "input": "iron",
804 | "output": "irons"
805 | },
806 | {
807 | "input": "camera",
808 | "output": "cameras"
809 | },
810 | {
811 | "input": "trash can",
812 | "output": "trash cans"
813 | },
814 | {
815 | "input": "nail",
816 | "output": "nails"
817 | },
818 | {
819 | "input": "box",
820 | "output": "boxes"
821 | }
822 | ]
--------------------------------------------------------------------------------
/dataset_files/generate/categories.json:
--------------------------------------------------------------------------------
1 | {
2 | "animal": [
3 | "alpaca",
4 | "ant",
5 | "anteater",
6 | "bat",
7 | "bear",
8 | "bee",
9 | "beaver",
10 | "bird",
11 | "buffalo",
12 | "bunny",
13 | "butterfly",
14 | "camel",
15 | "cat",
16 | "caterpillar",
17 | "chicken",
18 | "cheetah",
19 | "cow",
20 | "coyote",
21 | "dog",
22 | "dolphin",
23 | "donkey",
24 | "dove",
25 | "duck",
26 | "eagle",
27 | "eel",
28 | "elephant",
29 | "ferret",
30 | "finch",
31 | "fish",
32 | "fox",
33 | "frog",
34 | "goat",
35 | "goose",
36 | "gorilla",
37 | "hamster",
38 | "hedgehog",
39 | "hippopotamus",
40 | "horse",
41 | "jaguar",
42 | "kangaroo",
43 | "koala",
44 | "llama",
45 | "lion",
46 | "lizard",
47 | "monkey",
48 | "moth",
49 | "mouse",
50 | "octopus",
51 | "otter",
52 | "owl",
53 | "pig",
54 | "pigeon",
55 | "rabbit",
56 | "rat",
57 | "seal",
58 | "scorpion",
59 | "shark",
60 | "sheep",
61 | "skunk",
62 | "snail",
63 | "snake",
64 | "spider",
65 | "squirrel",
66 | "swan",
67 | "tiger",
68 | "turkey",
69 | "turtle",
70 | "vulture",
71 | "weasel",
72 | "whale",
73 | "wolf",
74 | "wombat",
75 | "worm",
76 | "wolverine",
77 | "zebra",
78 | "aardvark",
79 | "alligator",
80 | "armadillo",
81 | "baboon",
82 | "badger",
83 | "barracuda",
84 | "bison",
85 | "boar",
86 | "capybara",
87 | "caribou",
88 | "chimpanzee",
89 | "chinchilla",
90 | "chipmunk",
91 | "cobra",
92 | "cockroach",
93 | "cougar",
94 | "crab",
95 | "crane",
96 | "crocodile",
97 | "crow",
98 | "deer",
99 | "dingo",
100 | "dragonfly",
101 | "elk",
102 | "emu",
103 | "falcon",
104 | "firefly",
105 | "flamingo",
106 | "fossa",
107 | "gazelle",
108 | "gecko",
109 | "giraffe",
110 | "gnu",
111 | "grizzly",
112 | "hornet",
113 | "hyena",
114 | "ibex",
115 | "iguana",
116 | "impala",
117 | "jackal",
118 | "jellyfish",
119 | "komodo",
120 | "lemur",
121 | "leopard",
122 | "lobster",
123 | "lynx",
124 | "macaw",
125 | "magpie",
126 | "mandrill",
127 | "manatee",
128 | "mantis",
129 | "meerkat",
130 | "moose",
131 | "moray",
132 | "narwhal",
133 | "newt",
134 | "ocelot",
135 | "okapi",
136 | "opossum",
137 | "orangutan",
138 | "oryx",
139 | "ostrich",
140 | "panda",
141 | "panther",
142 | "parrot",
143 | "peacock",
144 | "pelican",
145 | "penguin",
146 | "puma",
147 | "python",
148 | "quokka",
149 | "raccoon",
150 | "reindeer",
151 | "rhinoceros",
152 | "salamander",
153 | "seahorse",
154 | "sloth",
155 | "toucan",
156 | "walrus",
157 | "woodpecker",
158 | "yak"
159 | ],
160 | "object": [
161 | "accordion",
162 | "airplane",
163 | "alarm",
164 | "anchor",
165 | "apron",
166 | "bag",
167 | "ball",
168 | "basket",
169 | "basketball",
170 | "beef",
171 | "bicycle",
172 | "blanket",
173 | "boat",
174 | "book",
175 | "boomerang",
176 | "bottle",
177 | "bowl",
178 | "cactus",
179 | "cake",
180 | "camera",
181 | "candle",
182 | "candlestick",
183 | "candy",
184 | "car",
185 | "carrot",
186 | "chair",
187 | "chocolate",
188 | "clock",
189 | "computer",
190 | "cookie",
191 | "cream",
192 | "cube",
193 | "cup",
194 | "curtain",
195 | "desk",
196 | "dice",
197 | "donut",
198 | "door",
199 | "dress",
200 | "drum",
201 | "dumbbell",
202 | "duster",
203 | "earmuffs",
204 | "earring",
205 | "easel",
206 | "egg",
207 | "envelope",
208 | "eraser",
209 | "fan",
210 | "feather",
211 | "fishing pole",
212 | "flower",
213 | "fork",
214 | "fountain",
215 | "garlic",
216 | "glass",
217 | "glasses",
218 | "globe",
219 | "gloves",
220 | "guitar",
221 | "gumball",
222 | "hairbrush",
223 | "hammer",
224 | "hammock",
225 | "hat",
226 | "hoop",
227 | "house",
228 | "ice",
229 | "igloo",
230 | "incense",
231 | "ink",
232 | "jacket",
233 | "jar",
234 | "jeans",
235 | "jigsaw",
236 | "juice",
237 | "kayak",
238 | "kettle",
239 | "key",
240 | "kite",
241 | "knife",
242 | "ladder",
243 | "lamp",
244 | "lantern",
245 | "laptop",
246 | "lettuce",
247 | "map",
248 | "maracas",
249 | "marker",
250 | "match",
251 | "microphone",
252 | "mirror",
253 | "motorcycle",
254 | "necklace",
255 | "net",
256 | "newspaper",
257 | "notebook",
258 | "olive",
259 | "onion",
260 | "oven",
261 | "paintbrush",
262 | "painting",
263 | "paper",
264 | "pasta",
265 | "pen",
266 | "pencil",
267 | "pepper",
268 | "phone",
269 | "piano",
270 | "picture",
271 | "pillow",
272 | "pizza",
273 | "plant",
274 | "plate",
275 | "pork",
276 | "potato",
277 | "puzzle",
278 | "quill",
279 | "quilt",
280 | "radio",
281 | "rake",
282 | "remote",
283 | "rice",
284 | "rifle",
285 | "robot",
286 | "rock",
287 | "rug",
288 | "ruler",
289 | "scissors",
290 | "sculpture",
291 | "shirt",
292 | "shoe",
293 | "skates",
294 | "snorkel",
295 | "socks",
296 | "soda",
297 | "sofa",
298 | "spoon",
299 | "stapler",
300 | "table",
301 | "tambourine",
302 | "tape",
303 | "teapot",
304 | "television",
305 | "tennis racket",
306 | "toilet",
307 | "tomato",
308 | "towel",
309 | "ukulele",
310 | "umbrella",
311 | "vacuum",
312 | "vase",
313 | "violin",
314 | "volleyball",
315 | "wallet",
316 | "watermelon",
317 | "whistle",
318 | "window",
319 | "wristwatch",
320 | "x-ray",
321 | "xylophone",
322 | "yacht",
323 | "yarn",
324 | "yo-yo",
325 | "yogurt",
326 | "zeppelin",
327 | "zipper",
328 | "zucchini"
329 | ],
330 | "verb": [
331 | "achieve",
332 | "analyze",
333 | "approve",
334 | "argue",
335 | "arrive",
336 | "attack",
337 | "believe",
338 | "breathe",
339 | "build",
340 | "calculate",
341 | "celebrate",
342 | "change",
343 | "choose",
344 | "climb",
345 | "collect",
346 | "compete",
347 | "complete",
348 | "consider",
349 | "consult",
350 | "copy",
351 | "create",
352 | "cry",
353 | "dance",
354 | "decide",
355 | "define",
356 | "deliver",
357 | "design",
358 | "destroy",
359 | "develop",
360 | "discuss",
361 | "discover",
362 | "dislike",
363 | "divide",
364 | "doubt",
365 | "enjoy",
366 | "examine",
367 | "exchange",
368 | "exist",
369 | "explore",
370 | "fear",
371 | "fight",
372 | "finish",
373 | "focus",
374 | "forgive",
375 | "gather",
376 | "give",
377 | "grow",
378 | "handle",
379 | "hate",
380 | "hear",
381 | "help",
382 | "jump",
383 | "juggle",
384 | "jog",
385 | "join",
386 | "judge",
387 | "jolt",
388 | "justify",
389 | "kick",
390 | "keep",
391 | "kill",
392 | "kindle",
393 | "kiss",
394 | "knit",
395 | "knock",
396 | "knot",
397 | "know",
398 | "kneel",
399 | "label",
400 | "land",
401 | "laugh",
402 | "launch",
403 | "learn",
404 | "lecture",
405 | "lift",
406 | "like",
407 | "listen",
408 | "live",
409 | "make",
410 | "manage",
411 | "manipulate",
412 | "mark",
413 | "master",
414 | "maximize",
415 | "measure",
416 | "memorize",
417 | "merge",
418 | "minimize",
419 | "navigate",
420 | "need",
421 | "negotiate",
422 | "notice",
423 | "nourish",
424 | "nurture",
425 | "observe",
426 | "obtain",
427 | "open",
428 | "operate",
429 | "organize",
430 | "overcome",
431 | "oversee",
432 | "paint",
433 | "participate",
434 | "perform",
435 | "persuade",
436 | "plan",
437 | "play",
438 | "practice",
439 | "predict",
440 | "prepare",
441 | "produce",
442 | "qualify",
443 | "question",
444 | "query",
445 | "quiet",
446 | "race",
447 | "reach",
448 | "read",
449 | "realize",
450 | "recruit",
451 | "reflect",
452 | "release",
453 | "relax",
454 | "remember",
455 | "remove",
456 | "sail",
457 | "sample",
458 | "save",
459 | "schedule",
460 | "search",
461 | "select",
462 | "serve",
463 | "solve",
464 | "speak",
465 | "study",
466 | "talk",
467 | "target",
468 | "teach",
469 | "test",
470 | "think",
471 | "train",
472 | "transform",
473 | "travel",
474 | "treat",
475 | "try",
476 | "uncover",
477 | "understand",
478 | "unite",
479 | "update",
480 | "use",
481 | "validate",
482 | "value",
483 | "verify",
484 | "view",
485 | "visit",
486 | "visualize",
487 | "volunteer",
488 | "walk",
489 | "watch",
490 | "win",
491 | "work",
492 | "write",
493 | "xerox",
494 | "yearn",
495 | "yield",
496 | "zoom",
497 | "zap"
498 | ],
499 | "color": [
500 | "red",
501 | "blue",
502 | "green",
503 | "yellow",
504 | "orange",
505 | "purple",
506 | "violet",
507 | "pink",
508 | "brown",
509 | "black",
510 | "white",
511 | "gray",
512 | "silver",
513 | "gold",
514 | "coral",
515 | "cream",
516 | "olive",
517 | "salmon",
518 | "navy",
519 | "mint",
520 | "mustard",
521 | "indigo"
522 | ],
523 | "fruit": [
524 | "apple",
525 | "apricot",
526 | "avocado",
527 | "banana",
528 | "blackberry",
529 | "cherry",
530 | "clementine",
531 | "coconut",
532 | "cranberry",
533 | "date",
534 | "dragonfruit",
535 | "durian",
536 | "fig",
537 | "gooseberry",
538 | "guava",
539 | "grape",
540 | "grapefruit",
541 | "huckleberry",
542 | "jackfruit",
543 | "kiwifruit",
544 | "kumquat",
545 | "lemon",
546 | "lime",
547 | "mango",
548 | "mandarine",
549 | "nectarine",
550 | "orange",
551 | "papaya",
552 | "passionfruit",
553 | "peach",
554 | "pear",
555 | "persimmon",
556 | "pineapple",
557 | "plantain",
558 | "plum",
559 | "pomegranate",
560 | "prune",
561 | "raspberry",
562 | "strawberry",
563 | "tangerine"
564 | ],
565 | "adjective": [
566 | "agile",
567 | "adorable",
568 | "adoring",
569 | "adventurous",
570 | "affable",
571 | "affectionate",
572 | "agreeable",
573 | "altruistic",
574 | "amazing",
575 | "amiable",
576 | "bad",
577 | "benevolent",
578 | "big",
579 | "bitter",
580 | "blissful",
581 | "blithe",
582 | "bold",
583 | "bountiful",
584 | "brave",
585 | "bright",
586 | "calm",
587 | "carefree",
588 | "caring",
589 | "charismatic",
590 | "charming",
591 | "cheap",
592 | "cheerful",
593 | "clean",
594 | "clever",
595 | "cold",
596 | "courageous",
597 | "cowardly",
598 | "daring",
599 | "dark",
600 | "dazzling",
601 | "delightful",
602 | "determined",
603 | "devoted",
604 | "diligent",
605 | "dirty",
606 | "dry",
607 | "dynamic",
608 | "eager",
609 | "ecstatic",
610 | "eloquent",
611 | "enchanting",
612 | "energetic",
613 | "enthusiastic",
614 | "expensive",
615 | "exquisite",
616 | "exuberant",
617 | "faithful",
618 | "fascinating",
619 | "fast",
620 | "fearless",
621 | "fierce",
622 | "fresh",
623 | "friendly",
624 | "funny",
625 | "generous",
626 | "gentle",
627 | "genuine",
628 | "good",
629 | "graceful",
630 | "gracious",
631 | "grateful",
632 | "happy",
633 | "hard",
634 | "harmonious",
635 | "heavy",
636 | "honest",
637 | "hopeful",
638 | "hot",
639 | "humble",
640 | "hungry",
641 | "idealistic",
642 | "innocent",
643 | "inquisitive",
644 | "insightful",
645 | "intelligent",
646 | "intrepid",
647 | "intuitive",
648 | "inventive",
649 | "jolly",
650 | "jovial",
651 | "joyful",
652 | "joyous",
653 | "jubilant",
654 | "keen",
655 | "kind",
656 | "kind-hearted",
657 | "kindhearted",
658 | "kindred",
659 | "knowledgeable",
660 | "laughing",
661 | "light",
662 | "lively",
663 | "long",
664 | "loud",
665 | "lovable",
666 | "lovely",
667 | "loving",
668 | "lucky",
669 | "luminous",
670 | "magnificent",
671 | "mellow",
672 | "mild",
673 | "mirthful",
674 | "modern",
675 | "modest",
676 | "naive",
677 | "natural",
678 | "naughty",
679 | "new",
680 | "noble",
681 | "nurturing",
682 | "observant",
683 | "old",
684 | "optimistic",
685 | "passionate",
686 | "patient",
687 | "peaceful",
688 | "pensive",
689 | "playful",
690 | "quick",
691 | "quick-witted",
692 | "quiet",
693 | "quirky",
694 | "quizzical",
695 | "radiant",
696 | "reliable",
697 | "resilient",
698 | "resolute",
699 | "resourceful",
700 | "rotten",
701 | "sad",
702 | "salty",
703 | "sensible",
704 | "sensitive",
705 | "serene",
706 | "serious",
707 | "short",
708 | "silly",
709 | "sincere",
710 | "slow",
711 | "small",
712 | "smart",
713 | "soft",
714 | "sour",
715 | "spicy",
716 | "strong",
717 | "stupid",
718 | "sweet",
719 | "talented",
720 | "tall",
721 | "tenacious",
722 | "tender",
723 | "thick",
724 | "thin",
725 | "thirsty",
726 | "thoughtful",
727 | "tranquil",
728 | "trustworthy",
729 | "unique",
730 | "unselfish",
731 | "unwavering",
732 | "upbeat",
733 | "uplifting",
734 | "versatile",
735 | "vibrant",
736 | "vivacious",
737 | "warm",
738 | "warmhearted",
739 | "weak",
740 | "wet",
741 | "whimsical",
742 | "wise",
743 | "witty",
744 | "wonderful",
745 | "young",
746 | "youthful",
747 | "zany",
748 | "zealous",
749 | "zesty"
750 | ],
751 | "pronoun": [
752 | "I",
753 | "you",
754 | "he",
755 | "she",
756 | "it",
757 | "we",
758 | "they",
759 | "me",
760 | "him",
761 | "her",
762 | "us",
763 | "them",
764 | "myself",
765 | "yourself",
766 | "himself",
767 | "herself",
768 | "itself",
769 | "ourselves",
770 | "themselves",
771 | "who",
772 | "whom",
773 | "whose",
774 | "whoever",
775 | "which",
776 | "that",
777 | "these",
778 | "those"
779 | ],
780 | "preposition": [
781 | "about",
782 | "above",
783 | "across",
784 | "after",
785 | "against",
786 | "along",
787 | "among",
788 | "around",
789 | "as",
790 | "at",
791 | "before",
792 | "behind",
793 | "below",
794 | "beneath",
795 | "beside",
796 | "between",
797 | "beyond",
798 | "but",
799 | "by",
800 | "concerning",
801 | "considering",
802 | "despite",
803 | "down",
804 | "during",
805 | "except",
806 | "for",
807 | "from",
808 | "in",
809 | "inside",
810 | "into",
811 | "like",
812 | "near",
813 | "of",
814 | "off",
815 | "on",
816 | "onto",
817 | "out",
818 | "outside",
819 | "over",
820 | "past",
821 | "regarding",
822 | "round",
823 | "since",
824 | "through",
825 | "throughout",
826 | "to",
827 | "toward",
828 | "under",
829 | "underneath",
830 | "until",
831 | "up",
832 | "upon",
833 | "with",
834 | "within",
835 | "without"
836 | ]
837 | }
--------------------------------------------------------------------------------
/dataset_files/generate/create_antonym_synonym_datasets.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer
2 | from typing import *
3 | from typing import TextIO
4 |
5 | import json
6 | import os
7 | import random
8 |
9 | def verify_word_length(word: str, tokenizer: object) -> bool:
10 | """
11 | Verifies whether a word can be tokenized into a single token
12 |
13 | Parameters:
14 | word: the word that we're checking
15 | tokenizer: the tokenizer we use to tokenize the word
16 |
17 | Return: True if the word tokenizes to exactly one token, False otherwise
18 | """
19 |
20 | return len(tokenizer(word)['input_ids']) == 1
21 |
22 | def parse_file(
23 | f_in: TextIO,
24 | ant_list: List[Dict],
25 | syn_list: List[Dict],
26 | seen: Set,
27 | tokenizer: object
28 | ):
29 | """
30 | Parses the input file into synonym and antonym categories
31 |
32 | Parameters:
33 | f_in: the input file from where the data will be taken from
34 | ant_list: the list of antonyms
35 | syn_list: the list of synonyms
36 | seen: the seen set of tuples to check against for duplicates
37 | tokenizer: the tokenizer we use to tokenize the word
38 | """
39 | for line in f_in:
40 | word1, word2, t = line.split()
41 | t = int(t)
42 |
43 | word1_bool = verify_word_length(" " + word1, tokenizer)
44 | word2_bool = verify_word_length(" " + word2, tokenizer)
45 |
46 | if word1_bool and word2_bool:
47 | d = {"input": word1, "output": word2}
48 | words = (word1, word2)
49 | if words not in seen:
50 | seen.add(words)
51 | else:
52 | continue
53 | # Synonym
54 | if t == 0:
55 | syn_list.append(d)
56 | # Antonym
57 | else:
58 | ant_list.append(d)
59 | else:
60 | continue
61 |
62 |
63 | if __name__ == "__main__":
64 | # Seed for dataset generation
65 | random.seed(42)
66 | model_name = r"EleutherAI/gpt-j-6B"
67 |
68 | # Load Tokenizer
69 | tokenizer = AutoTokenizer.from_pretrained(model_name)
70 | tokenizer.pad_token = tokenizer.eos_token
71 |
72 | assert os.path.exists('./AntSynNET/dataset'), "Original dataset missing! Please first clone https://github.com/nguyenkh/AntSynNET into this folder in order to re-generate antonym and synonym datasets."
73 |
74 | out_dir = "../abstractive"
75 |
76 | if not os.path.exists(out_dir):
77 | os.makedirs(out_dir)
78 |
79 | input_data_dir = "./AntSynNET/dataset"
80 | splits = ["train", "val", "test"]
81 | types = ["adjective", "noun", "verb"]
82 |
83 | ant_list = []
84 | syn_list = []
85 | filename_ant = "antonym.json"
86 | filename_syn = "synonym.json"
87 | seen = set()
88 |
89 | ant_path = os.path.join(out_dir, filename_ant)
90 | syn_path = os.path.join(out_dir, filename_syn)
91 | f_ant = open(ant_path, "w")
92 | f_syn = open(syn_path, "w")
93 | for s in splits:
94 | for t in types:
95 | path = t + "-pairs." + s
96 | full_path = os.path.join(input_data_dir, path)
97 | with open(full_path, "r") as input_file:
98 | parse_file(input_file, ant_list, syn_list, seen, tokenizer)
99 |
100 | json.dump(ant_list, f_ant)
101 | json.dump(syn_list, f_syn)
102 | f_ant.close()
103 | f_syn.close()
104 |
--------------------------------------------------------------------------------
/dataset_files/generate/create_translation_datasets.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer
2 | from typing import *
3 | from typing import TextIO
4 |
5 | import json
6 | import os
7 |
8 | def verify_word_length(word: str, tokenizer: object) -> bool:
9 | """
10 | Verifies whether a word can be tokenized into a single token
11 |
12 | Parameters:
13 | word: the word that we're checking
14 | tokenizer: the tokenizer we use to tokenize the word
15 |
16 | Return: True if the word tokenizes to exactly one token, False otherwise
17 | """
18 |
19 | return len(tokenizer(word)['input_ids']) == 1
20 |
21 | if __name__ == "__main__":
22 |
23 | d_names = {'en-de':"english-german", 'en-es':"english-spanish",'en-fr':"english-french"}
24 | path_exists = [os.path.exists(f'./translation/{lang_id}.0-5000.txt') for lang_id in d_names.keys()] + [os.path.exists(f'./translation/{lang_id}.5000-6500.txt') for lang_id in d_names.keys()]
25 |
26 | assert all(path_exists), "Original data missing! Please download corresponding 'train' and 'test' files from https://github.com/facebookresearch/MUSE#ground-truth-bilingual-dictionaries in order to re-generate the translation datasets."
27 |
28 | model_name = r"EleutherAI/gpt-j-6b"
29 |
30 | # Load Tokenizer
31 | tokenizer = AutoTokenizer.from_pretrained(model_name)
32 | tokenizer.pad_token = tokenizer.eos_token
33 |
34 | out_dir = '../abstractive'
35 |
36 | if not os.path.exists(out_dir):
37 | os.makedirs(out_dir)
38 |
39 | for lang_id in d_names.keys():
40 | valid = []
41 | for d_base in [f'./translation/{lang_id}.0-5000.txt', f'./translation/{lang_id}.5000-6500.txt']:
42 | with open(d_base, 'r', encoding="utf-8") as f:
43 | lines = f.read()
44 |
45 | word_pairs = list(set([tuple(x.split()) for x in lines.splitlines()]))
46 | word_pairs = [{'input':w1, 'output':w2} for (w1,w2) in word_pairs]
47 |
48 | for i, x in enumerate(word_pairs):
49 | if (x['input'] != x['output']): # Filter pairs that are exact copies
50 | valid.append(word_pairs[i])
51 |
52 | json.dump(valid, open(os.path.join(out_dir, f'{d_names[lang_id]}.json'), 'w'))
53 |
54 |
55 |
--------------------------------------------------------------------------------
/fv_environment.yml:
--------------------------------------------------------------------------------
1 | # all packages used:
2 | name: fv
3 | channels:
4 | - pytorch
5 | - huggingface
6 | - nvidia
7 | - defaults
8 | dependencies:
9 | - python=3.10
10 | - cudatoolkit=11.7.0
11 | - datasets=2.14.3
12 | - jupyter=1.0.0
13 | - matplotlib=3.7.1
14 | - numpy=1.25.0
15 | - pandas=1.5.3
16 | - plotly=5.9.0
17 | - pytorch=1.13.0
18 | - pip=23.2.1
19 | - scikit-learn=1.3.0
20 | - seaborn=0.12.2
21 | - sentencepiece=0.1.99
22 | - transformers=4.49.0
23 | - tqdm=4.65.0
24 | - pip:
25 | - git+https://github.com/davidbau/baukit@main#egg=baukit
26 | - bitsandbytes==0.45.3
27 | - huggingface-hub==0.29.3
28 | - accelerate==0.21.0
--------------------------------------------------------------------------------
/fv_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ericwtodd/function_vectors/751e2219d304eba471cffcacc9efd89a4f8ef3c4/fv_overview.png
--------------------------------------------------------------------------------
/notebooks/fv_demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "import os, re, json\n",
20 | "import torch, numpy as np\n",
21 | "\n",
22 | "import sys\n",
23 | "sys.path.append('..')\n",
24 | "torch.set_grad_enabled(False)\n",
25 | "\n",
26 | "from src.utils.extract_utils import get_mean_head_activations, compute_universal_function_vector\n",
27 | "from src.utils.intervention_utils import fv_intervention_natural_text, function_vector_intervention\n",
28 | "from src.utils.model_utils import load_gpt_model_and_tokenizer\n",
29 | "from src.utils.prompt_utils import load_dataset, word_pairs_to_prompt_data, create_prompt\n",
30 | "from src.utils.eval_utils import decode_to_vocab, sentence_eval"
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "metadata": {},
36 | "source": [
37 | "## Load model & tokenizer"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": null,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "model_name = 'EleutherAI/gpt-j-6b'\n",
47 | "model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name)\n",
48 | "EDIT_LAYER = 9"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "metadata": {},
54 | "source": [
55 | "## Load dataset and Compute task-conditioned mean activations"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": null,
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "dataset = load_dataset('antonym', seed=0)\n",
65 | "mean_activations = get_mean_head_activations(dataset, model, model_config, tokenizer)"
66 | ]
67 | },
68 | {
69 | "cell_type": "markdown",
70 | "metadata": {},
71 | "source": [
72 | "## Compute function vector (FV)"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": null,
78 | "metadata": {},
79 | "outputs": [],
80 | "source": [
81 | "FV, top_heads = compute_universal_function_vector(mean_activations, model, model_config, n_top_heads=10)"
82 | ]
83 | },
84 | {
85 | "cell_type": "markdown",
86 | "metadata": {},
87 | "source": [
88 | "## Prompt Creation - ICL, Shuffled-Label, Zero-Shot, and Natural Text"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": null,
94 | "metadata": {},
95 | "outputs": [],
96 | "source": [
97 | "# Sample ICL example pairs, and a test word\n",
98 | "dataset = load_dataset('antonym')\n",
99 | "word_pairs = dataset['train'][:5]\n",
100 | "test_pair = dataset['test'][21]\n",
101 | "\n",
102 | "prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=test_pair, prepend_bos_token=True)\n",
103 | "sentence = create_prompt(prompt_data)\n",
104 | "print(\"ICL prompt:\\n\", repr(sentence), '\\n\\n')\n",
105 | "\n",
106 | "shuffled_prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=test_pair, prepend_bos_token=True, shuffle_labels=True)\n",
107 | "shuffled_sentence = create_prompt(shuffled_prompt_data)\n",
108 | "print(\"Shuffled ICL Prompt:\\n\", repr(shuffled_sentence), '\\n\\n')\n",
109 | "\n",
110 | "zeroshot_prompt_data = word_pairs_to_prompt_data({'input':[], 'output':[]}, query_target_pair=test_pair, prepend_bos_token=True, shuffle_labels=True)\n",
111 | "zeroshot_sentence = create_prompt(zeroshot_prompt_data)\n",
112 | "print(\"Zero-Shot Prompt:\\n\", repr(zeroshot_sentence))"
113 | ]
114 | },
115 | {
116 | "cell_type": "markdown",
117 | "metadata": {},
118 | "source": [
119 | "## Evaluation"
120 | ]
121 | },
122 | {
123 | "cell_type": "markdown",
124 | "metadata": {},
125 | "source": [
126 | "### Clean ICL Prompt"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": null,
132 | "metadata": {},
133 | "outputs": [],
134 | "source": [
135 | "# Check model's ICL answer\n",
136 | "clean_logits = sentence_eval(sentence, [test_pair['output']], model, tokenizer, compute_nll=False)\n",
137 | "\n",
138 | "print(\"Input Sentence:\", repr(sentence), '\\n')\n",
139 | "print(f\"Input Query: {repr(test_pair['input'])}, Target: {repr(test_pair['output'])}\\n\")\n",
140 | "print(\"ICL Prompt Top K Vocab Probs:\\n\", decode_to_vocab(clean_logits, tokenizer, k=5), '\\n')"
141 | ]
142 | },
143 | {
144 | "cell_type": "markdown",
145 | "metadata": {},
146 | "source": [
147 | "### Corrupted ICL Prompt"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": null,
153 | "metadata": {},
154 | "outputs": [],
155 | "source": [
156 | "# Perform an intervention on the shuffled setting\n",
157 | "clean_logits, interv_logits = function_vector_intervention(shuffled_sentence, [test_pair['output']], EDIT_LAYER, FV, model, model_config, tokenizer)\n",
158 | "\n",
159 | "print(\"Input Sentence:\", repr(shuffled_sentence), '\\n')\n",
160 | "print(f\"Input Query: {repr(test_pair['input'])}, Target: {repr(test_pair['output'])}\\n\")\n",
161 | "print(\"Few-Shot-Shuffled Prompt Top K Vocab Probs:\\n\", decode_to_vocab(clean_logits, tokenizer, k=5), '\\n')\n",
162 | "print(\"Shuffled Prompt+FV Top K Vocab Probs:\\n\", decode_to_vocab(interv_logits, tokenizer, k=5))"
163 | ]
164 | },
165 | {
166 | "cell_type": "markdown",
167 | "metadata": {},
168 | "source": [
169 | "### Zero-Shot Prompt"
170 | ]
171 | },
172 | {
173 | "cell_type": "code",
174 | "execution_count": null,
175 | "metadata": {},
176 | "outputs": [],
177 | "source": [
178 | "# Intervention on the zero-shot prompt\n",
179 | "clean_logits, interv_logits = function_vector_intervention(zeroshot_sentence, [test_pair['output']], EDIT_LAYER, FV, model, model_config, tokenizer)\n",
180 | "\n",
181 | "print(\"Input Sentence:\", repr(zeroshot_sentence), '\\n')\n",
182 | "print(f\"Input Query: {repr(test_pair['input'])}, Target: {repr(test_pair['output'])}\\n\")\n",
183 | "print(\"Zero-Shot Top K Vocab Probs:\\n\", decode_to_vocab(clean_logits, tokenizer, k=5), '\\n')\n",
184 | "print(\"Zero-Shot+FV Vocab Top K Vocab Probs:\\n\", decode_to_vocab(interv_logits, tokenizer, k=5))"
185 | ]
186 | },
187 | {
188 | "cell_type": "markdown",
189 | "metadata": {},
190 | "source": [
191 | "### Natural Text Prompt"
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": null,
197 | "metadata": {},
198 | "outputs": [],
199 | "source": [
200 | "sentence = f\"The word \\\"{test_pair['input']}\\\" means\"\n",
201 | "co, io = fv_intervention_natural_text(sentence, EDIT_LAYER, FV, model, model_config, tokenizer, max_new_tokens=10)\n",
202 | "\n",
203 | "\n",
204 | "print(\"Input Sentence: \", repr(sentence))\n",
205 | "print(\"GPT-J:\" , repr(tokenizer.decode(co.squeeze())))\n",
206 | "print(\"GPT-J+FV:\", repr(tokenizer.decode(io.squeeze())), '\\n')"
207 | ]
208 | }
209 | ],
210 | "metadata": {
211 | "kernelspec": {
212 | "display_name": "Python 3 (ipykernel)",
213 | "language": "python",
214 | "name": "python3"
215 | },
216 | "language_info": {
217 | "codemirror_mode": {
218 | "name": "ipython",
219 | "version": 3
220 | },
221 | "file_extension": ".py",
222 | "mimetype": "text/x-python",
223 | "name": "python",
224 | "nbconvert_exporter": "python",
225 | "pygments_lexer": "ipython3",
226 | "version": "3.10.12"
227 | }
228 | },
229 | "nbformat": 4,
230 | "nbformat_minor": 2
231 | }
232 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ericwtodd/function_vectors/751e2219d304eba471cffcacc9efd89a4f8ef3c4/src/__init__.py
--------------------------------------------------------------------------------
/src/compute_average_activations.py:
--------------------------------------------------------------------------------
1 | import os, json
2 | import torch, numpy as np
3 | import argparse
4 |
5 | # Include prompt creation helper functions
6 | from utils.prompt_utils import *
7 | from utils.intervention_utils import *
8 | from utils.model_utils import *
9 | from utils.extract_utils import *
10 |
11 |
12 |
13 | if __name__ == "__main__":
14 | parser = argparse.ArgumentParser()
15 |
16 | parser.add_argument('--dataset_name', help='Name of the dataset to be loaded', type=str, required=True)
17 | parser.add_argument('--model_name', help='Name of model to be loaded', type=str, required=False, default='EleutherAI/gpt-j-6b')
18 | parser.add_argument('--root_data_dir', help='Root directory of data files', type=str, required=False, default='../dataset_files')
19 | parser.add_argument('--save_path_root', help='File path to save mean activations to', type=str, required=False, default='../results')
20 | parser.add_argument('--seed', help='Randomized seed', type=int, required=False, default=42)
21 | parser.add_argument('--n_shots', help="Number of shots in each in-context prompt", type=int, required=False, default=10)
22 | parser.add_argument('--n_trials', help="Number of in-context prompts to average over", type=int, required=False, default=100)
23 | parser.add_argument('--test_split', help="Percentage corresponding to test set split size", type=float, required=False, default=0.3)
24 | parser.add_argument('--device', help='Device to run on', required=False, default='cuda' if torch.cuda.is_available() else 'cpu')
25 | parser.add_argument('--prefixes', help='Prompt template prefixes to be used', type=json.loads, required=False, default={"input":"Q:", "output":"A:", "instructions":""})
26 | parser.add_argument('--separators', help='Prompt template separators to be used', type=json.loads, required=False, default={"input":"\n", "output":"\n\n", "instructions":""})
27 | parser.add_argument('--revision', help='Specify model checkpoints for pythia or olmo models', type=str, required=False, default=None)
28 |
29 |
30 | args = parser.parse_args()
31 |
32 | dataset_name = args.dataset_name
33 | model_name = args.model_name
34 | root_data_dir = args.root_data_dir
35 | save_path_root = f"{args.save_path_root}/{dataset_name}"
36 | seed = args.seed
37 | n_shots = args.n_shots
38 | n_trials = args.n_trials
39 | test_split = args.test_split
40 | device = args.device
41 | prefixes = args.prefixes
42 | separators = args.separators
43 |
44 |
45 | # Load Model & Tokenizer
46 | torch.set_grad_enabled(False)
47 | print("Loading Model")
48 | model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name, device=device, revision=args.revision)
49 |
50 | set_seed(seed)
51 |
52 | # Load the dataset
53 | print("Loading Dataset")
54 | dataset = load_dataset(dataset_name, root_data_dir=root_data_dir, test_size=test_split, seed=seed)
55 |
56 | print("Computing Mean Activations")
57 | mean_activations = get_mean_head_activations(dataset, model=model, model_config=model_config, tokenizer=tokenizer,
58 | n_icl_examples=n_shots, N_TRIALS=n_trials, prefixes=prefixes, separators=separators)
59 |
60 | if not os.path.exists(save_path_root):
61 | os.makedirs(save_path_root)
62 |
63 | # Write args to file
64 | args.save_path_root = save_path_root # update for logging
65 | with open(f'{save_path_root}/mean_head_activation_args.txt', 'w') as arg_file:
66 | json.dump(args.__dict__, arg_file, indent=2)
67 |
68 | torch.save(mean_activations, f'{save_path_root}/{dataset_name}_mean_head_activations.pt')
69 |
70 |
--------------------------------------------------------------------------------
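The script above delegates the averaging to `get_mean_head_activations` from `utils.extract_utils`. As a minimal sketch of the underlying idea (toy shapes and a hypothetical helper name, not the repo's actual implementation): record per-head activations for each of the `n_trials` ICL prompts and average over the trial dimension.

```python
import numpy as np

def mean_head_activations_sketch(per_trial_acts: np.ndarray) -> np.ndarray:
    """Average per-head activations over trials.

    per_trial_acts: hypothetical layout (n_trials, n_layers, n_heads, head_dim);
    the repo's get_mean_head_activations returns an analogous averaged tensor.
    """
    return per_trial_acts.mean(axis=0)

# Toy example: 100 trials, 2 layers, 4 heads, head_dim 8
rng = np.random.default_rng(0)
acts = rng.normal(size=(100, 2, 4, 8))
mean_acts = mean_head_activations_sketch(acts)
```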
/src/compute_avg_hidden_state.py:
--------------------------------------------------------------------------------
1 | import os, json
2 | import torch, numpy as np
3 | import argparse
4 |
5 | # Include prompt creation helper functions
6 | from utils.prompt_utils import *
7 | from utils.intervention_utils import *
8 | from utils.model_utils import *
9 | from utils.eval_utils import *
10 | from utils.extract_utils import *
11 |
12 |
13 | if __name__ == "__main__":
14 | parser = argparse.ArgumentParser()
15 |
16 | parser.add_argument('--dataset_name', help='Name of the dataset to be loaded', type=str, required=True)
17 | parser.add_argument('--model_name', help='Name of model to be loaded', type=str, required=False, default='EleutherAI/gpt-j-6b')
18 | parser.add_argument('--root_data_dir', help='Root directory of data files', type=str, required=False, default='../dataset_files')
19 | parser.add_argument('--save_path_root', help='File path to save mean activations to', type=str, required=False, default='../results')
20 | parser.add_argument('--n_seeds', help='Number of seeds', type=int, required=False, default=5)
21 | parser.add_argument('--n_shots', help="Number of shots in each in-context prompt", type=int, required=False, default=10)
22 | parser.add_argument('--n_trials', help="Number of in-context prompts to average over", type=int, required=False, default=100)
23 | parser.add_argument('--test_split', help="Percentage corresponding to test set split size", type=float, required=False, default=0.3)
24 | parser.add_argument('--device', help='Device to run on', required=False, default='cuda' if torch.cuda.is_available() else 'cpu')
25 | parser.add_argument('--prefixes', help='Prompt template prefixes to be used', type=json.loads, required=False, default={"input":"Q:", "output":"A:", "instructions":""})
26 | parser.add_argument('--separators', help='Prompt template separators to be used', type=json.loads, required=False, default={"input":"\n", "output":"\n\n", "instructions":""})
27 | parser.add_argument('--revision', help='Specify model checkpoints for pythia or olmo models', type=str, required=False, default=None)
28 |
29 |
30 | args = parser.parse_args()
31 |
32 | dataset_name = args.dataset_name
33 | model_name = args.model_name
34 | root_data_dir = args.root_data_dir
35 | save_path_root = f"{args.save_path_root}/{dataset_name}"
36 | n_seeds = args.n_seeds
37 | n_shots = args.n_shots
38 | n_trials = args.n_trials
39 | test_split = args.test_split
40 | device = args.device
41 | prefixes = args.prefixes
42 | separators = args.separators
43 |
44 |
45 | # Load Model & Tokenizer
46 | torch.set_grad_enabled(False)
47 | print("Loading Model")
48 | model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name, device=device, revision=args.revision)
49 |
50 | seeds = np.random.choice(100000, size=n_seeds)
51 |
52 | for seed in seeds:
53 | set_seed(seed)
54 |
55 | # Load the dataset
56 | print("Loading Dataset")
57 | dataset = load_dataset(dataset_name, root_data_dir=root_data_dir, test_size=test_split, seed=seed)
58 |
59 | print("Computing Mean Activations")
60 |
61 | mean_activations = get_mean_layer_activations(dataset, model=model, model_config=model_config, tokenizer=tokenizer,
62 | n_icl_examples=n_shots, N_TRIALS=n_trials)
63 |
64 |
65 | print("Saving mean layer activations")
66 | if not os.path.exists(save_path_root):
67 | os.makedirs(save_path_root)
68 |
69 | # Write args to file
70 | args.save_path_root = save_path_root # update for logging
71 | with open(f'{save_path_root}/mean_layer_activation_args.txt', 'w') as arg_file:
72 | json.dump(args.__dict__, arg_file, indent=2)
73 |
74 | torch.save(mean_activations, f'{save_path_root}/{dataset_name}_mean_layer_activations.pt')
75 |
76 | print("Evaluating Layer Avgs. Baseline")
77 | fs_results = n_shot_eval_no_intervention(dataset, n_shots, model, model_config, tokenizer)
78 | filter_set = np.where(np.array(fs_results['clean_rank_list']) == 0)[0]
79 |
80 | zs_res = {}
81 | fss_res = {}
82 | for i in range(model_config['n_layers']):
83 | zs_res[i] = n_shot_eval(dataset, mean_activations[i].unsqueeze(0), i, 0, model, model_config, tokenizer, filter_set=filter_set)
84 | fss_res[i] = n_shot_eval(dataset, mean_activations[i].unsqueeze(0), i, 10, model, model_config, tokenizer, filter_set=filter_set, shuffle_labels=True)
85 |
86 | with open(f'{save_path_root}/mean_layer_intervention_zs_results_sweep_{seed}.json', 'w') as interv_zsres_file:
87 | json.dump(zs_res, interv_zsres_file, indent=2)
88 | with open(f'{save_path_root}/mean_layer_intervention_fss_results_sweep_{seed}.json', 'w') as interv_fssres_file:
89 | json.dump(fss_res, interv_fssres_file, indent=2)
90 |
--------------------------------------------------------------------------------
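Both this script and `evaluate_function_vector.py` build their `filter_set` the same way: keep only the validation items the clean (un-intervened) model already gets right, i.e. those whose correct answer has rank 0. A self-contained sketch with toy ranks:

```python
import numpy as np

# fs_results['clean_rank_list'] holds, for each validation item, the rank of
# the correct answer under the un-intervened model; rank 0 means the model's
# top prediction is correct.
clean_rank_list = [0, 3, 0, 1, 0]  # toy values
filter_set = np.where(np.array(clean_rank_list) == 0)[0]
# filter_set now holds the indices of the correctly-answered items
```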
/src/compute_indirect_effect.py:
--------------------------------------------------------------------------------
1 | import os, re, json
2 | from tqdm import tqdm
3 | import torch, numpy as np
4 | import argparse
5 | from baukit import TraceDict
6 |
7 | # Include prompt creation helper functions
8 | from utils.prompt_utils import *
9 | from utils.intervention_utils import *
10 | from utils.model_utils import *
11 | from utils.extract_utils import *
12 |
13 |
14 | def activation_replacement_per_class_intervention(prompt_data, avg_activations, dummy_labels, model, model_config, tokenizer, last_token_only=True):
15 | """
16 | Experiment to determine top intervention locations through avg activation replacement.
17 | Performs a systematic sweep over attention heads (layer, head) to track their causal influence on probs of key tokens.
18 |
19 | Parameters:
20 | prompt_data: dict containing ICL prompt examples, and template information
21 | avg_activations: avg activation of each attention head in the model taken across n_trials ICL prompts
22 | dummy_labels: labels and indices for a baseline prompt with the same number of example pairs
23 | model: huggingface model
24 | model_config: contains model config information (n layers, n heads, etc.)
25 | tokenizer: huggingface tokenizer
26 | last_token_only: If True, only computes indirect effect for heads at the final token position. If False, computes indirect_effect for heads for all token classes
27 |
28 | Returns:
29 | indirect_effect_storage: torch tensor containing the indirect_effect of each head for each token class.
30 | """
31 | device = model.device
32 |
33 | # Get sentence and token labels
34 | query_target_pair = prompt_data['query_target']
35 |
36 | query = query_target_pair['input']
37 | token_labels, prompt_string = get_token_meta_labels(prompt_data, tokenizer, query=query, prepend_bos=model_config['prepend_bos'])
38 |
39 | idx_map, idx_avg = compute_duplicated_labels(token_labels, dummy_labels)
40 | idx_map = update_idx_map(idx_map, idx_avg)
41 |
42 | sentences = [prompt_string]  # * model.config.n_head  # batch things by head
43 |
44 | # Figure out tokens of interest
45 | target = [query_target_pair['output']]
46 | token_id_of_interest = get_answer_id(sentences[0], target[0], tokenizer)
47 | if isinstance(token_id_of_interest, list):
48 | token_id_of_interest = token_id_of_interest[:1]
49 |
50 | inputs = tokenizer(sentences, return_tensors='pt').to(device)
51 |
52 | # Speed up computation by only computing causal effect at last token
53 | if last_token_only:
54 | token_classes = ['query_predictive']
55 | token_classes_regex = ['query_predictive_token']
56 | # Compute causal effect for all token classes (instead of just last token)
57 | else:
58 | token_classes = ['demonstration', 'label', 'separator', 'predictive', 'structural','end_of_example',
59 | 'query_demonstration', 'query_structural', 'query_separator', 'query_predictive']
60 | token_classes_regex = [r'demonstration_[\d]{1,}_token', r'demonstration_[\d]{1,}_label_token', r'separator_token', r'predictive_token', r'structural_token', r'end_of_example_token',
61 | r'query_demonstration_token', r'query_structural_token', r'query_separator_token', r'query_predictive_token']
62 |
63 |
64 | indirect_effect_storage = torch.zeros(model_config['n_layers'], model_config['n_heads'],len(token_classes))
65 |
66 | # Clean Run of Baseline:
67 | clean_output = model(**inputs).logits[:,-1,:]
68 | clean_probs = torch.softmax(clean_output[0], dim=-1)
69 |
70 | # For every layer, head, token combination perform the replacement & track the change in meaningful tokens
71 | for layer in range(model_config['n_layers']):
72 | head_hook_layer = [model_config['attn_hook_names'][layer]]
73 |
74 | for head_n in range(model_config['n_heads']):
75 | for i,(token_class, class_regex) in enumerate(zip(token_classes, token_classes_regex)):
76 | reg_class_match = re.compile(f"^{class_regex}$")
77 | class_token_inds = [x[0] for x in token_labels if reg_class_match.match(x[2])]
78 |
79 | intervention_locations = [(layer, head_n, token_n) for token_n in class_token_inds]
80 | intervention_fn = replace_activation_w_avg(layer_head_token_pairs=intervention_locations, avg_activations=avg_activations,
81 | model=model, model_config=model_config,
82 | batched_input=False, idx_map=idx_map, last_token_only=last_token_only)
83 | with TraceDict(model, layers=head_hook_layer, edit_output=intervention_fn) as td:
84 | output = model(**inputs).logits[:,-1,:] # batch_size x n_tokens x vocab_size, only want last token prediction
85 |
86 | # TRACK probs of tokens of interest
87 | intervention_probs = torch.softmax(output, dim=-1) # convert to probability distribution
88 | indirect_effect_storage[layer,head_n,i] = (intervention_probs-clean_probs).index_select(1, torch.LongTensor(token_id_of_interest).to(device).squeeze()).squeeze()
89 |
90 | return indirect_effect_storage
91 |
92 |
93 | def compute_indirect_effect(dataset, mean_activations, model, model_config, tokenizer, n_shots=10, n_trials=25, last_token_only=True, prefixes=None, separators=None, filter_set=None):
94 | """
95 | Computes Indirect Effect of each head in the model
96 |
97 | Parameters:
98 | dataset: ICL dataset
99 | mean_activations: mean head activations used as the replacement values during the intervention
100 | model: huggingface model
101 | model_config: contains model config information (n layers, n heads, etc.)
102 | tokenizer: huggingface tokenizer
103 | n_shots: Number of shots in each in-context prompt
104 | n_trials: Number of in-context prompts to average over
105 | last_token_only: If True, only computes Indirect Effect for heads at the final token position. If False, computes Indirect Effect for heads for all token classes
106 |
107 |
108 | Returns:
109 | indirect_effect: torch tensor of the indirect effect for each attention head in the model, size n_trials * n_layers * n_heads
110 | """
111 | n_test_examples = 1
112 |
113 | if prefixes is not None and separators is not None:
114 | dummy_gt_labels = get_dummy_token_labels(n_shots, tokenizer=tokenizer, prefixes=prefixes, separators=separators, model_config=model_config)
115 | else:
116 | dummy_gt_labels = get_dummy_token_labels(n_shots, tokenizer=tokenizer, model_config=model_config)
117 |
118 | # If the model already prepends a bos token by default, we don't want to add one
119 | prepend_bos = not model_config['prepend_bos']
120 |
121 | if last_token_only:
122 | indirect_effect = torch.zeros(n_trials,model_config['n_layers'], model_config['n_heads'])
123 | else:
124 | indirect_effect = torch.zeros(n_trials,model_config['n_layers'], model_config['n_heads'],10) # have 10 classes of tokens
125 |
126 | if filter_set is None:
127 | filter_set = np.arange(len(dataset['valid']))
128 |
129 | for i in tqdm(range(n_trials), total=n_trials):
130 | word_pairs = dataset['train'][np.random.choice(len(dataset['train']),n_shots, replace=False)]
131 | word_pairs_test = dataset['valid'][np.random.choice(filter_set,n_test_examples, replace=False)]
132 | if prefixes is not None and separators is not None:
133 | prompt_data_random = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, shuffle_labels=True,
134 | prepend_bos_token=prepend_bos, prefixes=prefixes, separators=separators)
135 | else:
136 | prompt_data_random = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test,
137 | shuffle_labels=True, prepend_bos_token=prepend_bos)
138 |
139 | ind_effects = activation_replacement_per_class_intervention(prompt_data=prompt_data_random,
140 | avg_activations = mean_activations,
141 | dummy_labels=dummy_gt_labels,
142 | model=model, model_config=model_config, tokenizer=tokenizer,
143 | last_token_only=last_token_only)
144 | indirect_effect[i] = ind_effects.squeeze()
145 |
146 | return indirect_effect
147 |
148 |
149 | if __name__ == "__main__":
150 |
151 | parser = argparse.ArgumentParser()
152 |
153 | parser.add_argument('--dataset_name', help='Name of the dataset to be loaded', type=str, required=True)
154 | parser.add_argument('--model_name', help='Name of model to be loaded', type=str, required=False, default='EleutherAI/gpt-j-6b')
155 | parser.add_argument('--root_data_dir', help='Root directory of data files', type=str, required=False, default='../dataset_files')
156 | parser.add_argument('--save_path_root', help='File path to save indirect effect to', type=str, required=False, default='../results')
157 | parser.add_argument('--seed', help='Randomized seed', type=int, required=False, default=42)
158 | parser.add_argument('--n_shots', help="Number of shots in each in-context prompt", type=int, required=False, default=10)
159 | parser.add_argument('--n_trials', help="Number of in-context prompts to average over", type=int, required=False, default=25)
160 | parser.add_argument('--test_split', help="Percentage corresponding to test set split size", type=float, required=False, default=0.3)
161 | parser.add_argument('--device', help='Device to run on',type=str, required=False, default='cuda' if torch.cuda.is_available() else 'cpu')
162 | parser.add_argument('--mean_activations_path', help='Path to mean activations file used for intervention', required=False, type=str, default=None)
163 | parser.add_argument('--last_token_only', help='Whether to compute indirect effect for heads at only the final token position, or for all token classes', required=False, type=lambda x: str(x).lower() == 'true', default=True)  # type=bool would treat any non-empty string as True
164 | parser.add_argument('--prefixes', help='Prompt template prefixes to be used', type=json.loads, required=False, default={"input":"Q:", "output":"A:", "instructions":""})
165 | parser.add_argument('--separators', help='Prompt template separators to be used', type=json.loads, required=False, default={"input":"\n", "output":"\n\n", "instructions":""})
166 | parser.add_argument('--revision', help='Specify model checkpoints for pythia or olmo models', type=str, required=False, default=None)
167 |
168 | args = parser.parse_args()
169 |
170 | dataset_name = args.dataset_name
171 | model_name = args.model_name
172 | root_data_dir = args.root_data_dir
173 | save_path_root = f"{args.save_path_root}/{dataset_name}"
174 | seed = args.seed
175 | n_shots = args.n_shots
176 | n_trials = args.n_trials
177 | test_split = args.test_split
178 | device = args.device
179 | mean_activations_path = args.mean_activations_path
180 | last_token_only = args.last_token_only
181 | prefixes = args.prefixes
182 | separators = args.separators
183 |
184 |
185 | # Load Model & Tokenizer
186 | torch.set_grad_enabled(False)
187 | print("Loading Model")
188 | model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name, device=device, revision=args.revision)
189 |
190 | set_seed(seed)
191 |
192 | # Load the dataset
193 | print("Loading Dataset")
194 | dataset = load_dataset(dataset_name, root_data_dir=root_data_dir, test_size=test_split, seed=seed)
195 |
196 |
197 | if not os.path.exists(save_path_root):
198 | os.makedirs(save_path_root)
199 |
200 | # Load or Re-Compute Mean Activations
201 | if mean_activations_path is not None and os.path.exists(mean_activations_path):
202 | mean_activations = torch.load(mean_activations_path)
203 | elif mean_activations_path is None and os.path.exists(f'{save_path_root}/{dataset_name}_mean_head_activations.pt'):
204 | mean_activations_path = f'{save_path_root}/{dataset_name}_mean_head_activations.pt'
205 | mean_activations = torch.load(mean_activations_path)
206 | else:
207 | print("Computing Mean Activations")
208 | mean_activations = get_mean_head_activations(dataset, model=model, model_config=model_config, tokenizer=tokenizer,
209 | n_icl_examples=n_shots, N_TRIALS=n_trials, prefixes=prefixes, separators=separators)
210 | torch.save(mean_activations, f'{save_path_root}/{dataset_name}_mean_head_activations.pt')
211 |
212 | print("Computing Indirect Effect")
213 | indirect_effect = compute_indirect_effect(dataset, mean_activations, model=model, model_config=model_config, tokenizer=tokenizer,
214 | n_shots=n_shots, n_trials=n_trials, last_token_only=last_token_only, prefixes=prefixes, separators=separators)
215 |
216 | # Write args to file
217 | args.save_path_root = save_path_root
218 | args.mean_activations_path = mean_activations_path
219 | with open(f'{save_path_root}/indirect_effect_args.txt', 'w') as arg_file:
220 | json.dump(args.__dict__, arg_file, indent=2)
221 |
222 | torch.save(indirect_effect, f'{save_path_root}/{dataset_name}_indirect_effect.pt')
223 |
224 |
--------------------------------------------------------------------------------
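The quantity stored in `indirect_effect_storage` above is the change in the target token's probability caused by patching a head's output with its task-mean activation: softmax(intervened logits)[target] minus softmax(clean logits)[target]. A toy numeric sketch (made-up logits over a 4-token vocabulary, not real model outputs):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# Hypothetical last-token logits; the correct answer is token id 2.
clean_logits = np.array([1.0, 0.5, 2.0, 0.1])
interv_logits = np.array([1.0, 0.5, 3.0, 0.1])  # after replacing one head's output

token_id_of_interest = 2
indirect_effect = (softmax(interv_logits)[token_id_of_interest]
                   - softmax(clean_logits)[token_id_of_interest])
# Positive: this intervention pushed probability toward the correct answer.
```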
/src/eval_scripts/eval_avg_hs.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | datasets=('antonym' 'capitalize' 'country-capital' 'english-french' 'present-past' 'singular-plural')
3 | cd ../
4 |
5 | for d_name in "${datasets[@]}"
6 | do
7 | echo "Running Script for: ${d_name}"
8 | python compute_avg_hidden_state.py --dataset_name="${d_name}" --save_path_root="results/gptj_avg_hs" --model_name='EleutherAI/gpt-j-6b'
9 | done
10 |
--------------------------------------------------------------------------------
/src/eval_scripts/eval_fv.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | datasets=('antonym')
3 | # datasets=('antonym' 'capitalize' 'country-capital' 'english-french' 'present-past' 'singular-plural')
4 | cd ../
5 |
6 | for d_name in "${datasets[@]}"
7 | do
8 | echo "Running Script for: ${d_name}"
9 | python evaluate_function_vector.py --dataset_name="${d_name}" --save_path_root="results/gptj" --model_name='EleutherAI/gpt-j-6b'
10 | done
--------------------------------------------------------------------------------
/src/eval_scripts/eval_numheads.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | datasets=('antonym' 'capitalize' 'country-capital' 'english-french' 'present-past' 'singular-plural')
3 | cd ../
4 |
5 | for d_name in "${datasets[@]}"
6 | do
7 | echo "Running Script for: ${d_name}"
8 | python test_numheads.py --dataset_name="${d_name}" --model_name='EleutherAI/gpt-j-6b'
9 | done
--------------------------------------------------------------------------------
/src/eval_scripts/eval_template_portability.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | datasets=('antonym' 'capitalize' 'country-capital' 'english-french' 'present-past' 'singular-plural')
3 | cd ../
4 |
5 | for d_name in "${datasets[@]}"
6 | do
7 | echo "Running Script for: ${d_name}"
8 | python portability_eval.py --dataset_name="${d_name}" --save_path_root="results/gptj" --model_name='EleutherAI/gpt-j-6b'
9 | done
--------------------------------------------------------------------------------
/src/eval_scripts/fv_eval_sweep.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import numpy as np
4 |
5 | # Submit slurm jobs for many tasks
6 |
7 | dataset_names = ['antonym', 'capitalize', 'country-capital', 'english-french', 'present-past', 'singular-plural']
8 | MODEL_NAMES = ["EleutherAI/gpt-j-6b"]
9 | MODEL_NICKNAMES = ['gptj']
10 |
11 |
12 | job_path = str(time.ctime()).replace(" ", "_")
13 | print(job_path)
14 | os.makedirs(job_path, exist_ok=True)
15 |
16 | d_name_to_cmd = {}
17 |
18 | ## creating the jobs
19 | for model_name,model_nickname in zip(MODEL_NAMES, MODEL_NICKNAMES):
20 | current_seed = np.random.randint(1000000)
21 | for idx, d_name in enumerate(dataset_names):
22 | results_path = os.path.join('results', f'{model_nickname}')
23 | n_fv_heads = 10
24 |
25 | cmd = f"python evaluate_function_vector.py --dataset_name='{d_name}' --save_path_root='{results_path}' --model_name='{model_name}' --n_top_heads={n_fv_heads} --seed={current_seed}"
26 | if 'squad' in d_name:
27 | cmd += " --n_shots=5 --generate_str --metric='f1_score'"
28 | elif 'ag_news' in d_name:
29 | cmd += " --n_shots=10 --generate_str --metric='first_word_score'"
30 |
31 | key = model_nickname + '_' + d_name
32 | d_name_to_cmd[key] = cmd
33 |
34 |
35 | for key in d_name_to_cmd:
36 | with open("template.sh", "r") as f:
37 | bash_template = f.readlines()
38 | bash_template.append(d_name_to_cmd[key])
39 |
40 | with open(f"{job_path}/{key}.sh", "w") as f:
41 | f.writelines(bash_template)
42 |
43 |
44 | ## running the jobs
45 | for job in os.listdir(job_path):
46 | job_script = f"{job_path}/{job}"
47 | cmd = f"sbatch --gpus=1 --time=48:00:00 {job_script}"
48 | print("submitting job: ", job)
49 | print(cmd)
50 | os.system(cmd)
51 | print("\n\n")
52 |
53 | print("------------------------------------------------------------------")
54 | print(f"submitted {len(os.listdir(job_path))} jobs!")
55 |
--------------------------------------------------------------------------------
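The sweep script writes one sbatch file per (model, dataset) pair by copying the lines of `template.sh` and appending the task command. The core of that construction, isolated as a sketch (toy template lines in place of reading the real file):

```python
# Toy stand-in for the lines read from template.sh
bash_template = ["#!/bin/bash\n", "source ~/.bashrc\n", "cd ../\n", "conda activate fv\n"]

# One command per task, appended to a fresh copy of the template
cmd = "python evaluate_function_vector.py --dataset_name='antonym' --n_top_heads=10"
job_script_lines = bash_template + [cmd]
```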
/src/eval_scripts/template.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source ~/.bashrc
3 | cd ../
4 | conda activate fv
5 |
--------------------------------------------------------------------------------
/src/evaluate_function_vector.py:
--------------------------------------------------------------------------------
1 | import os, json
2 | import torch, numpy as np
3 | import argparse
4 |
5 | # Include prompt creation helper functions
6 | from utils.prompt_utils import *
7 | from utils.intervention_utils import *
8 | from utils.model_utils import *
9 | from utils.eval_utils import *
10 | from utils.extract_utils import *
11 | from compute_indirect_effect import compute_indirect_effect
12 |
13 | if __name__ == "__main__":
14 |
15 | parser = argparse.ArgumentParser()
16 |
17 | parser.add_argument('--dataset_name', help='Name of the dataset to be loaded', type=str, required=True)
18 | parser.add_argument('--n_top_heads', help='Number of attention head outputs used to compute function vector', required=False, type=int, default=10)
19 | parser.add_argument('--edit_layer', help='Layer for intervention. If -1, sweep over all layers', type=int, required=False, default=-1) #
20 | parser.add_argument('--model_name', help='Name of model to be loaded', type=str, required=False, default='EleutherAI/gpt-j-6b')
21 | parser.add_argument('--root_data_dir', help='Root directory of data files', type=str, required=False, default='../dataset_files')
22 | parser.add_argument('--save_path_root', help='File path to save to', type=str, required=False, default='../results')
23 | parser.add_argument('--ie_path_root', help='File path to load indirect effects from', type=str, required=False, default=None)
24 | parser.add_argument('--seed', help='Randomized seed', type=int, required=False, default=42)
25 | parser.add_argument('--device', help='Device to run on',type=str, required=False, default='cuda' if torch.cuda.is_available() else 'cpu')
26 | parser.add_argument('--mean_activations_path', help='Path to file containing mean_head_activations for the specified task', required=False, type=str, default=None)
27 | parser.add_argument('--indirect_effect_path', help='Path to file containing indirect_effect scores for the specified task', required=False, type=str, default=None)
28 | parser.add_argument('--test_split', help="Percentage corresponding to test set split size", required=False, default=0.3)
29 | parser.add_argument('--n_shots', help="Number of shots in each in-context prompt", type=int, required=False, default=10)
30 | parser.add_argument('--n_mean_activations_trials', help="Number of in-context prompts to average over for mean_activations", type=int, required=False, default=100)
31 | parser.add_argument('--n_indirect_effect_trials', help="Number of in-context prompts to average over for indirect_effect", type=int, required=False, default=25)
32 | parser.add_argument('--prefixes', help='Prompt template prefixes to be used', type=json.loads, required=False, default={"input":"Q:", "output":"A:", "instructions":""})
33 | parser.add_argument('--separators', help='Prompt template separators to be used', type=json.loads, required=False, default={"input":"\n", "output":"\n\n", "instructions":""})
34 | parser.add_argument('--compute_baseline', help='Whether to compute the model baseline 0-shot -> n-shot performance', type=lambda x: str(x).lower() == 'true', required=False, default=True)  # type=bool would treat any non-empty string as True
35 | parser.add_argument('--generate_str', help='Whether to generate long-form completions for the task', action='store_true', required=False)
36 | parser.add_argument("--metric", help="Metric to use when evaluating generated strings", type=str, required=False, default="f1_score")
37 | parser.add_argument("--universal_set", help="Flag for whether to evaluate using the universal set of heads", action="store_true", required=False)
38 | parser.add_argument('--revision', help='Specify model checkpoints for pythia or olmo models', type=str, required=False, default=None)
39 |
40 | args = parser.parse_args()
41 |
42 | dataset_name = args.dataset_name
43 | model_name = args.model_name
44 | root_data_dir = args.root_data_dir
45 | save_path_root = f"{args.save_path_root}/{dataset_name}"
46 | ie_path_root = f"{args.ie_path_root}/{dataset_name}" if args.ie_path_root else save_path_root
47 | seed = args.seed
48 | device = args.device
49 | mean_activations_path = args.mean_activations_path
50 | indirect_effect_path = args.indirect_effect_path
51 | n_top_heads = args.n_top_heads
52 | eval_edit_layer = args.edit_layer
53 |
54 | test_split = float(args.test_split)
55 | n_shots = args.n_shots
56 | n_mean_activations_trials = args.n_mean_activations_trials
57 | n_indirect_effect_trials = args.n_indirect_effect_trials
58 |
59 | prefixes = args.prefixes
60 | separators = args.separators
61 | compute_baseline = args.compute_baseline
62 |
63 | generate_str = args.generate_str
64 | metric = args.metric
65 | universal_set = args.universal_set
66 |
67 | print(args)
68 |
69 | # Load Model & Tokenizer
70 | torch.set_grad_enabled(False)
71 | print("Loading Model")
72 | model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name, device=device, revision=args.revision)
73 |
74 | if args.edit_layer == -1: # sweep over all layers if edit_layer=-1
75 | eval_edit_layer = [0, model_config['n_layers']]
76 |
77 | # Load the dataset
78 | print("Loading Dataset")
79 | set_seed(seed)
80 | dataset = load_dataset(dataset_name, root_data_dir=root_data_dir, test_size=test_split, seed=seed)
81 |
82 | if not os.path.exists(save_path_root):
83 | os.makedirs(save_path_root)
84 |
85 | print(f"Filtering Dataset via {n_shots}-shot Eval")
86 | # 1. Compute Model 10-shot Baseline & 2. Filter test set to cases where model gets it correct
87 |
88 | fs_results_file_name = f'{save_path_root}/fs_results_layer_sweep.json'
89 | print(fs_results_file_name)
90 | if os.path.exists(fs_results_file_name):
91 | with open(fs_results_file_name, 'r') as indata:
92 | fs_results = json.load(indata)
93 | key = 'score' if generate_str else 'clean_rank_list'
94 | target_val = 1 if generate_str else 0
95 | filter_set = np.where(np.array(fs_results[key]) == target_val)[0]
96 | filter_set_validation = None
97 | elif generate_str:
98 | set_seed(seed+42)
99 | fs_results_validation = n_shot_eval_no_intervention(dataset=dataset, n_shots=n_shots, model=model, model_config=model_config, tokenizer=tokenizer, compute_ppl=False,
100 | generate_str=True, metric=metric, test_split='valid', prefixes=prefixes, separators=separators)
101 | filter_set_validation = np.where(np.array(fs_results_validation['score']) == 1)[0]
102 | set_seed(seed)
103 | fs_results = n_shot_eval_no_intervention(dataset=dataset, n_shots=n_shots, model=model, model_config=model_config, tokenizer=tokenizer, compute_ppl=False,
104 | generate_str=True, metric=metric, prefixes=prefixes, separators=separators)
105 | filter_set = np.where(np.array(fs_results['score']) == 1)[0]
106 | else:
107 | set_seed(seed+42)
108 | fs_results_validation = n_shot_eval_no_intervention(dataset=dataset, n_shots=n_shots, model=model, model_config=model_config, tokenizer=tokenizer, compute_ppl=True, test_split='valid', prefixes=prefixes, separators=separators)
109 | filter_set_validation = np.where(np.array(fs_results_validation['clean_rank_list']) == 0)[0]
110 | set_seed(seed)
111 | fs_results = n_shot_eval_no_intervention(dataset=dataset, n_shots=n_shots, model=model, model_config=model_config, tokenizer=tokenizer, compute_ppl=True, prefixes=prefixes, separators=separators)
112 | filter_set = np.where(np.array(fs_results['clean_rank_list']) == 0)[0]
113 |
114 | args.fs_results_file_name = fs_results_file_name
115 | with open(fs_results_file_name, 'w') as results_file:
116 | json.dump(fs_results, results_file, indent=2)
117 |
118 | set_seed(seed)
119 | # Load or Re-Compute mean_head_activations
120 | if mean_activations_path is not None and os.path.exists(mean_activations_path):
121 | mean_activations = torch.load(mean_activations_path)
122 | elif mean_activations_path is None and os.path.exists(f'{ie_path_root}/{dataset_name}_mean_head_activations.pt'):
123 | mean_activations_path = f'{ie_path_root}/{dataset_name}_mean_head_activations.pt'
124 | mean_activations = torch.load(mean_activations_path)
125 | else:
126 | print("Computing Mean Activations")
127 | set_seed(seed)
128 | mean_activations = get_mean_head_activations(dataset, model=model, model_config=model_config, tokenizer=tokenizer, n_icl_examples=n_shots,
129 | N_TRIALS=n_mean_activations_trials, prefixes=prefixes, separators=separators, filter_set=filter_set_validation)
130 | args.mean_activations_path = f'{save_path_root}/{dataset_name}_mean_head_activations.pt'
131 | torch.save(mean_activations, args.mean_activations_path)
132 |
133 | # Load or Re-Compute indirect_effect values
134 | if indirect_effect_path is not None and os.path.exists(indirect_effect_path):
135 | indirect_effect = torch.load(indirect_effect_path)
136 | elif indirect_effect_path is None and os.path.exists(f'{ie_path_root}/{dataset_name}_indirect_effect.pt'):
137 | indirect_effect_path = f'{ie_path_root}/{dataset_name}_indirect_effect.pt'
138 | indirect_effect = torch.load(indirect_effect_path)
139 | elif not universal_set: # Only compute indirect effects if we need to
140 | print("Computing Indirect Effects")
141 | set_seed(seed)
142 | indirect_effect = compute_indirect_effect(dataset, mean_activations, model=model, model_config=model_config, tokenizer=tokenizer, n_shots=n_shots,
143 | n_trials=n_indirect_effect_trials, last_token_only=True, prefixes=prefixes, separators=separators, filter_set=filter_set_validation)
144 | args.indirect_effect_path = f'{save_path_root}/{dataset_name}_indirect_effect.pt'
145 | torch.save(indirect_effect, args.indirect_effect_path)
146 |
147 | # Compute Function Vector
148 | if universal_set:
149 | fv, top_heads = compute_universal_function_vector(mean_activations, model, model_config=model_config, n_top_heads=n_top_heads)
150 | else:
151 | fv, top_heads = compute_function_vector(mean_activations, indirect_effect, model, model_config=model_config, n_top_heads=n_top_heads)
152 |
153 | # Run Evaluation
154 | if isinstance(eval_edit_layer, int):
155 | print(f"Running ZS Eval with edit_layer={eval_edit_layer}")
156 | set_seed(seed)
157 | if generate_str:
158 | pred_filepath = f"{save_path_root}/preds/{model_config['name_or_path'].replace('/', '_')}_ZS_intervention_layer{eval_edit_layer}.txt"
159 | zs_results = n_shot_eval(dataset=dataset, fv_vector=fv, edit_layer=eval_edit_layer, n_shots=0,
160 | model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set,
161 | generate_str=generate_str, metric=metric, pred_filepath=pred_filepath, prefixes=prefixes, separators=separators)
162 | else:
163 | zs_results = n_shot_eval(dataset=dataset, fv_vector=fv, edit_layer=eval_edit_layer, n_shots=0,
164 | model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set, prefixes=prefixes, separators=separators)
165 | zs_results_file_suffix = f'_editlayer_{eval_edit_layer}.json'
166 |
167 |
168 | print(f"Running {n_shots}-Shot Shuffled Eval")
169 | set_seed(seed)
170 | if generate_str:
171 | pred_filepath = f"{save_path_root}/preds/{model_config['name_or_path'].replace('/', '_')}_{n_shots}shots_shuffled_intervention_layer{eval_edit_layer}.txt"
172 | fs_shuffled_results = n_shot_eval(dataset=dataset, fv_vector=fv, edit_layer=eval_edit_layer, n_shots=n_shots,
173 | model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set, shuffle_labels=True,
174 | generate_str=generate_str, metric=metric, pred_filepath=pred_filepath, prefixes=prefixes, separators=separators)
175 | else:
176 | fs_shuffled_results = n_shot_eval(dataset=dataset, fv_vector=fv, edit_layer=eval_edit_layer, n_shots=n_shots,
177 | model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set, shuffle_labels=True, prefixes=prefixes, separators=separators)
178 | fs_shuffled_results_file_suffix = f'_editlayer_{eval_edit_layer}.json'
179 |
180 | else:
181 | print(f"Running sweep over layers {eval_edit_layer}")
182 | zs_results = {}
183 | fs_shuffled_results = {}
184 | for edit_layer in range(eval_edit_layer[0], eval_edit_layer[1]):
185 | set_seed(seed)
186 | if generate_str:
187 | zs_results[edit_layer] = n_shot_eval(dataset=dataset, fv_vector=fv, edit_layer=edit_layer, n_shots=0,
188 | model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set,
189 | generate_str=generate_str, metric=metric, prefixes=prefixes, separators=separators)
190 | else:
191 | zs_results[edit_layer] = n_shot_eval(dataset=dataset, fv_vector=fv, edit_layer=edit_layer, n_shots=0,
192 | model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set, prefixes=prefixes, separators=separators)
193 | set_seed(seed)
194 | if generate_str:
195 | fs_shuffled_results[edit_layer] = n_shot_eval(dataset=dataset, fv_vector=fv, edit_layer=edit_layer, n_shots=n_shots,
196 | model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set,
197 | generate_str=generate_str, metric=metric, shuffle_labels=True, prefixes=prefixes, separators=separators)
198 | else:
199 | fs_shuffled_results[edit_layer] = n_shot_eval(dataset=dataset, fv_vector=fv, edit_layer=edit_layer, n_shots=n_shots,
200 | model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set, shuffle_labels=True, prefixes=prefixes, separators=separators)
201 | zs_results_file_suffix = '_layer_sweep.json'
202 | fs_shuffled_results_file_suffix = '_layer_sweep.json'
203 |
204 |
205 | # Save results to files
206 | zs_results_file_name = make_valid_path_name(f'{save_path_root}/zs_results' + zs_results_file_suffix)
207 | args.zs_results_file_name = zs_results_file_name
208 | with open(zs_results_file_name, 'w') as results_file:
209 | json.dump(zs_results, results_file, indent=2)
210 |
211 | fs_shuffled_results_file_name = make_valid_path_name(f'{save_path_root}/fs_shuffled_results' + fs_shuffled_results_file_suffix)
212 | args.fs_shuffled_results_file_name = fs_shuffled_results_file_name
213 | with open(fs_shuffled_results_file_name, 'w') as results_file:
214 | json.dump(fs_shuffled_results, results_file, indent=2)
215 |
216 | if compute_baseline:
217 | print(f"Computing model baseline results for {n_shots}-shots")
218 | baseline_results = compute_dataset_baseline(dataset, model, model_config, tokenizer, n_shots=n_shots, seed=seed, prefixes=prefixes, separators=separators)
219 |
220 | baseline_file_name = make_valid_path_name(f'{save_path_root}/model_baseline.json')
221 | args.baseline_file_name = baseline_file_name
222 | with open(baseline_file_name, 'w') as results_file:
223 | json.dump(baseline_results, results_file, indent=2)
224 |
225 | # Write args to file
226 | args_file_name = make_valid_path_name(f'{save_path_root}/fv_eval_args.txt')
227 | with open(args_file_name, 'w') as arg_file:
228 | json.dump(args.__dict__, arg_file, indent=2)
229 |
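Both ICL-filtering branches above build their filter sets with the same `np.where` idiom: keep only the test indices the model already answers correctly without intervention (exact-match `score == 1` for string generation, or `clean_rank_list == 0` when ranking a single target token). A minimal self-contained sketch with invented result dicts (the real ones come from `n_shot_eval_no_intervention`):

```python
import numpy as np

# Hypothetical result dicts mimicking n_shot_eval_no_intervention's outputs:
# 'score' holds 0/1 exact-match results; 'clean_rank_list' holds the rank of
# the correct token under the clean (no-intervention) model.
fs_results = {"score": [1, 0, 1, 1, 0]}
ppl_results = {"clean_rank_list": [0, 3, 0, 1, 0]}

# Keep only indices the model already gets right, as in the script above.
filter_set_str = np.where(np.array(fs_results["score"]) == 1)[0]
filter_set_rank = np.where(np.array(ppl_results["clean_rank_list"]) == 0)[0]

print(filter_set_str.tolist())   # indices with exact-match score 1
print(filter_set_rank.tolist())  # indices where the correct token ranks first
```

Downstream, these index arrays restrict both the mean-activation averaging and the intervention evals to examples where the task is demonstrably solvable by the model.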
--------------------------------------------------------------------------------
/src/natural_text_eval.py:
--------------------------------------------------------------------------------
1 | import os, re, json
2 | import torch, numpy as np
3 | from tqdm import tqdm
4 | import argparse
5 |
6 | # Include prompt creation helper functions
7 | from utils.eval_utils import n_shot_eval_no_intervention
8 | from utils.extract_utils import get_mean_head_activations, compute_universal_function_vector
9 | from utils.intervention_utils import fv_intervention_natural_text
10 | from utils.model_utils import load_gpt_model_and_tokenizer
11 | from utils.prompt_utils import load_dataset
12 |
13 |
14 | def natural_text_eval(dataset, fv_vector, model, model_config, tokenizer, filter_set, edit_layer=9, MNT=5, verbose=False, loc_data=False):
15 | """
16 | Evaluates the causal effects of a function vector on natural text templates for the given dataset.
17 |
18 | Parameters:
19 | dataset: ICL dataset with pairs of words
20 | fv_vector: function vector to use for intervention
21 | model: huggingface model
22 | model_config: contains model config information (n layers, n heads, etc.)
23 | tokenizer: huggingface tokenizer
24 | filter_set: list of samples to filter to, used to include samples the model gets correct via ICL
25 | edit_layer: the layer to add the function vector to
26 | MNT: max number of tokens to generate
27 | verbose: whether to print outputs of clean & +FV generations.
28 | loc_data: whether the dataset is locations (e.g. country-capital, national parks, etc.)
29 |
30 | Returns:
31 | all_scores: scores for model when adding the FV during generation
32 | all_clean_scores: scores for base model (no intervention)
33 | sentences: sentence templates used during eval
34 | """
35 | all_scores = {}
36 | all_clean_scores = {}
37 |
38 | if loc_data: # country-capital & similar datasets
39 | sentences = ["A couple years ago I visited {X}, and",
40 | "If you ever travel to {X}, you have to visit",
41 | "When you think of {X},"]
42 | else:
43 | sentences = ["The word \"{X}\", means",
44 | "When I think of the word \"{X}\", it usually means",
45 | "When I think of \"{X}\", I usually",
46 | "While reading a book, I came across the word \"{X}\". I looked it up in a dictionary and it turns out that it means",
47 | "The word \"{X}\" can be understood as a synonym for"]
48 |
49 | for j in range(len(sentences)):
50 | scores = []
51 | clean_scores = []
52 | for i in tqdm(range(len(filter_set)), total=len(filter_set)):
53 | ind = int(filter_set[i])
54 | q_pair = dataset['test'][ind]
55 | if isinstance(q_pair['input'], list):
56 | q_pair['input'] = q_pair['input'][0]
57 | if isinstance(q_pair['output'], list):
58 | q_pair['output'] = q_pair['output'][0]
59 |
60 | sentence = sentences[j]
61 | sentence = sentence.replace('{X}', f"{q_pair['input']}")
62 |
63 | clean_output, fv_output = fv_intervention_natural_text(sentence, edit_layer, fv_vector, model, model_config, tokenizer, max_new_tokens=MNT)
64 | clean_out_str = repr(tokenizer.decode(clean_output.squeeze()[-MNT:]))
65 | fv_out_str = repr(tokenizer.decode(fv_output.squeeze()[-MNT:]))
66 |
67 | if verbose:
68 | print("\nQuery/Target: ", q_pair)
69 | print("Prompt: ", repr(sentence))
70 | print("clean completion:" , clean_out_str)
71 | print("+FV completion:", fv_out_str, '\n')
72 |
73 | scores.append(int(q_pair['output'] in fv_out_str))
74 | clean_scores.append(int(q_pair['output'] in clean_out_str))
75 |
76 | all_scores[j] = scores
77 | all_clean_scores[j] = clean_scores
78 |
79 | return all_scores, all_clean_scores, sentences
80 |
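The scoring loop in `natural_text_eval` reduces to two operations: substitute the query word for the `{X}` placeholder in a template, then score a generated continuation by substring containment of the target. A minimal sketch with a template from the list above and made-up completions standing in for the model's outputs:

```python
# Sketch of the per-example scoring in natural_text_eval; the completions
# below are invented stand-ins for real model generations.
template = 'The word "{X}", means'
q_pair = {"input": "happy", "output": "glad"}

sentence = template.replace("{X}", f"{q_pair['input']}")
fv_out_str = " glad, or joyful"    # stand-in for the +FV completion
clean_out_str = " a feeling of"    # stand-in for the clean completion

score = int(q_pair["output"] in fv_out_str)
clean_score = int(q_pair["output"] in clean_out_str)
print(sentence)            # The word "happy", means
print(score, clean_score)  # 1 0
```

Containment scoring is deliberately lenient: any generation mentioning the target string anywhere in the first `MNT` tokens counts as correct.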
81 | def nattext_main(datasets, model, model_config, tokenizer, root_data_dir='../dataset_files', edit_layer=9, n_shots=10, n_trials=100, n_seeds=5):
82 | """
83 | Main function that evaluates causal effects of function vectors on natural text templates.
84 |
85 | Parameters:
86 | datasets: list of dataset names to evaluate
87 | model: huggingface model
88 | model_config: contains model config information (n layers, n heads, etc.)
89 | tokenizer: huggingface tokenizer
90 | root_data_dir: directory data is contained in
91 | edit_layer: layer to add the function vector to during intervention
92 | n_shots: number of shots for prompts used when computing task-conditioned mean head activations
93 | n_trials: number of prompts to include when computing task-conditioned mean head activations
94 | n_seeds: number of seeds to average results over
95 |
96 | Returns:
97 | clean_results_dict: dict containing results for base model (no intervention)
98 | interv_results_dict: results for model when adding the function vector at edit_layer during generation
99 | seeds_dict: dict containing the seeds used during evaluation
100 | """
101 | interv_results_dict = {k:[] for k in datasets}
102 | clean_results_dict = {k:[] for k in datasets}
103 | seeds_dict = {k:[] for k in datasets}
104 |
105 | # Test Loop:
106 | for dataset_name in datasets:
107 | if dataset_name == 'country-capital':
108 | loc_data = True
109 | max_new_tokens = 10
110 | else:
111 | loc_data = False
112 | max_new_tokens = 5
113 |
114 | for _ in range(n_seeds):
115 | seed = np.random.randint(100000)
116 | seeds_dict[dataset_name].append(seed)
117 | dataset = load_dataset(dataset_name, seed=seed, root_data_dir=root_data_dir)
118 |
119 | fs_results_validation = n_shot_eval_no_intervention(dataset=dataset, n_shots=n_shots, model=model, model_config=model_config, tokenizer=tokenizer, compute_ppl=False, test_split='valid')
120 | filter_set_validation = np.where(np.array(fs_results_validation['clean_rank_list']) == 0)[0]
121 |
122 | mean_activations = get_mean_head_activations(dataset, model=model, model_config=model_config, tokenizer=tokenizer, n_icl_examples=n_shots,
123 | N_TRIALS=n_trials, filter_set=filter_set_validation)
124 | fv, _ = compute_universal_function_vector(mean_activations, model, model_config)
125 |
126 | fs_results = n_shot_eval_no_intervention(dataset=dataset, n_shots=n_shots, model=model, model_config=model_config, tokenizer=tokenizer, compute_ppl=False, test_split='test')
127 | filter_set = np.where(np.array(fs_results['clean_rank_list']) == 0)[0]
128 |
129 | results, clean_results, _ = natural_text_eval(dataset, fv, model, model_config, tokenizer, filter_set, MNT=max_new_tokens, edit_layer=edit_layer, verbose=False, loc_data=loc_data)
130 |
131 | clean_results_dict[dataset_name].append([np.mean(clean_results[i]) for i in clean_results.keys()])
132 | interv_results_dict[dataset_name].append([np.mean(results[i]) for i in results.keys()])
133 |
134 | return clean_results_dict, interv_results_dict, seeds_dict
135 |
136 |
137 | if __name__ == "__main__":
138 |
139 | parser = argparse.ArgumentParser()
140 | parser.add_argument('--model_name', help='Name of model to be loaded', type=str, required=False, default='EleutherAI/gpt-j-6b')
141 | parser.add_argument('--root_data_dir', help='Root directory of data files', type=str, required=False, default='../dataset_files')
142 | parser.add_argument('--save_path_root', help='File path to save mean activations to', type=str, required=False, default='../results')
143 | parser.add_argument('--n_seeds', help='Number of seeds', type=int, required=False, default=5)
144 | parser.add_argument('--n_trials', help='Number of trials to use for computing task-conditioned mean head activations', type=int, required=False, default=100)
145 | parser.add_argument('--n_shots', help='Number of shots to use for prompts when computing task-conditioned mean head activations', type=int, required=False, default=10)
146 | parser.add_argument('--edit_layer', help='Layer to add function vector to', type=int, required=False, default=9)
147 |
148 | args = parser.parse_args()
149 |
150 | # Gather inputs
151 | model_name = args.model_name
152 | root_data_dir = args.root_data_dir
153 | save_path_root = args.save_path_root
154 | n_seeds = args.n_seeds
155 | n_trials = args.n_trials
156 | n_shots = args.n_shots
157 | edit_layer = args.edit_layer
158 |
159 |
160 | # Load Model & Tokenizer
161 | torch.set_grad_enabled(False)
162 | model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name)
163 |
164 | datasets = ['antonym', 'capitalize', 'country-capital', 'english-french', 'present-past', 'singular-plural']
165 | args.datasets = datasets
166 |
167 | # Run Natural Text Eval
168 | clean_results_dict, interv_results_dict, seeds_dict = nattext_main(datasets, model, model_config, tokenizer,
169 | root_data_dir=root_data_dir, edit_layer=edit_layer,
170 | n_shots=n_shots, n_trials=n_trials, n_seeds=n_seeds)
171 |
172 | # Extract Summary Results:
173 | os.makedirs(os.path.join(save_path_root), exist_ok=True)
174 | with open(os.path.join(save_path_root, 'nattext_eval_results.txt'), 'w') as out_file:
175 | for d in datasets:
176 | print(f"{d.title()}:", file=out_file)
177 | clean_acc = np.array(clean_results_dict[d]).mean(axis=0)
178 | clean_std = np.array(clean_results_dict[d]).std(axis=0)
179 | fv_acc = np.array(interv_results_dict[d]).mean(axis=0)
180 | fv_std = np.array(interv_results_dict[d]).std(axis=0)
181 |
182 | print("clean results:", (clean_acc * 100).round(1), '% +/-', (clean_std * 100).round(1), file=out_file)
183 | print("fv results:", (fv_acc * 100).round(1), '% +/-', (fv_std * 100).round(1), file=out_file)
184 |
185 | # Write args to a file
186 | args.seeds_dict = seeds_dict
187 | with open(os.path.join(save_path_root, 'nattext_eval_args.txt'), 'w') as arg_file:
188 | print(args.__dict__, file=arg_file)
189 |
190 |
191 |
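The summary step above averages over axis 0 of a (seeds x templates) array, so each sentence template gets its own mean and standard deviation across seeds. A small sketch with invented per-seed accuracies:

```python
import numpy as np

# Sketch of the summary aggregation: results_dict[dataset] is a list of
# per-seed rows, one accuracy per sentence template; values are invented.
per_seed = [[0.8, 0.6, 0.7],
            [0.9, 0.5, 0.7]]

acc = np.array(per_seed).mean(axis=0)   # per-template mean over seeds
std = np.array(per_seed).std(axis=0)    # per-template std over seeds
print((acc * 100).round(1))             # [85. 55. 70.]
```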
--------------------------------------------------------------------------------
/src/portability_eval.py:
--------------------------------------------------------------------------------
1 | import os, json
2 | import torch, numpy as np
3 | import argparse
4 |
5 | # Include prompt creation helper functions
6 | from utils.prompt_utils import *
7 | from utils.intervention_utils import *
8 | from utils.model_utils import *
9 | from utils.eval_utils import *
10 | from utils.extract_utils import *
11 |
12 |
13 | if __name__ == "__main__":
14 |
15 | parser = argparse.ArgumentParser()
16 |
17 | parser.add_argument('--dataset_name', help='Name of the dataset to be loaded', type=str, required=True)
18 | parser.add_argument('--n_eval_templates', help='Number of templates to evaluate with', required=True, type=int)
19 | parser.add_argument('--edit_layer', help='Layer for intervention. If -1, sweep over all layers', type=int, required=False, default=9)
20 |
21 | parser.add_argument('--n_top_heads', help='Number of attention head outputs used to compute function vector', required=False, type=int, default=10)
22 | parser.add_argument('--model_name', help='Name of model to be loaded', type=str, required=False, default='EleutherAI/gpt-j-6b')
23 | parser.add_argument('--root_data_dir', help='Root directory of data files', type=str, required=False, default='../dataset_files')
24 | parser.add_argument('--save_path_root', help='File path to save to', type=str, required=False, default='../results')
25 | parser.add_argument('--seed', help='Randomized seed', type=int, required=False, default=5678)
26 | parser.add_argument('--device', help='Device to run on',type=str, required=False, default='cuda' if torch.cuda.is_available() else 'cpu')
27 | parser.add_argument('--mean_activations_path', help='Path to file containing mean_head_activations for the specified task', required=False, type=str, default=None)
28 | parser.add_argument('--test_split', help="Percentage corresponding to test set split size", type=float, required=False, default=0.3)
29 | parser.add_argument('--n_shots', help="Number of shots in each in-context prompt", type=int, required=False, default=10)
30 | parser.add_argument('--n_trials', help="Number of in-context prompts to average over for indirect_effect", type=int, required=False, default=25)
31 | parser.add_argument('--prefixes', help='Prompt template prefixes to be used', type=json.loads, required=False, default={"input":"Q:", "output":"A:", "instructions":""})
32 | parser.add_argument('--separators', help='Prompt template separators to be used', type=json.loads, required=False, default={"input":"\n", "output":"\n\n", "instructions":""})
33 |
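The `--prefixes` and `--separators` flags above use `type=json.loads`, so argparse passes the raw CLI string through the JSON parser and the script receives a dict. A small sketch of how that parsing behaves (the override values here are illustrative):

```python
import argparse
import json

# Sketch of argparse's type=json.loads behavior for the flags above.
parser = argparse.ArgumentParser()
parser.add_argument('--prefixes', type=json.loads,
                    default={"input": "Q:", "output": "A:", "instructions": ""})

# On a real command line this JSON string would need shell quoting.
args = parser.parse_args(
    ['--prefixes', '{"input": "Question:", "output": "Answer:", "instructions": ""}'])
print(args.prefixes["input"])  # Question:
```

Note that the default is already a dict (it bypasses `type=`), while CLI-supplied values must be valid JSON or argparse exits with an error.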
34 | args = parser.parse_args()
35 |
36 | dataset_name = args.dataset_name
37 | model_name = args.model_name
38 | root_data_dir = args.root_data_dir
39 | save_path_root = f"{args.save_path_root}/{dataset_name}"
40 | seed = args.seed
41 | device = args.device
42 | mean_activations_path = args.mean_activations_path
43 | n_top_heads = args.n_top_heads
44 | eval_edit_layer = args.edit_layer
45 |
46 | test_split = args.test_split
47 | n_shots = args.n_shots
48 | n_trials = args.n_trials
49 |
50 | prefixes = args.prefixes
51 | separators = args.separators
52 |
53 | n_eval_templates = args.n_eval_templates
54 |
55 | print(args)
56 |
57 | # Load Model & Tokenizer
58 | torch.set_grad_enabled(False)
59 | print("Loading Model")
60 | model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name, device=device)
61 |
62 | if args.edit_layer == -1: # sweep over all layers if edit_layer=-1
63 | eval_edit_layer = [0, model_config['n_layers']]
64 |
65 | # Load the dataset
66 | print("Loading Dataset")
67 | set_seed(seed)
68 | dataset = load_dataset(dataset_name, root_data_dir=root_data_dir, test_size=test_split, seed=seed)
69 |
70 | if not os.path.exists(save_path_root):
71 | os.makedirs(save_path_root)
72 |
73 | # Load or Re-Compute mean_head_activations
74 | if mean_activations_path is not None and os.path.exists(mean_activations_path):
75 | mean_activations = torch.load(mean_activations_path)
76 | elif mean_activations_path is None and os.path.exists(f'{save_path_root}/{dataset_name}_mean_head_activations.pt'):
77 | mean_activations_path = f'{save_path_root}/{dataset_name}_mean_head_activations.pt'
78 | mean_activations = torch.load(mean_activations_path)
79 | else:
80 | print("Computing Mean Activations")
81 | set_seed(seed)
82 | mean_activations = get_mean_head_activations(dataset, model=model, model_config=model_config, tokenizer=tokenizer,
83 | n_icl_examples=n_shots, N_TRIALS=n_trials, prefixes=prefixes, separators=separators)
84 | args.mean_activations_path = f'{save_path_root}/{dataset_name}_mean_head_activations.pt'
85 | torch.save(mean_activations, args.mean_activations_path)
86 |
87 | # Compute Function Vector
88 | fv, top_heads = compute_universal_function_vector(mean_activations, model, model_config=model_config, n_top_heads=n_top_heads)
89 |
90 | print("Computing Portability")
91 | fs_res_dict, zs_res_dict, fs_shuffled_res_dict, templates = portability_eval(dataset, fv, eval_edit_layer, model, model_config, tokenizer, n_eval_templates=n_eval_templates)
92 |
93 | args.templates = templates
94 |
95 | save_path_root = f"{args.save_path_root}_port/{dataset_name}"
96 | if not os.path.exists(save_path_root):
97 | os.makedirs(save_path_root)
98 |
99 | fs_results_file_name = make_valid_path_name(f'{save_path_root}/fs_port_eval.json')
100 | args.fs_results_file_name = fs_results_file_name
101 | with open(fs_results_file_name, 'w') as fs_results_file:
102 | json.dump(fs_res_dict, fs_results_file, indent=2)
103 |
104 | fs_shuffled_results_file_name = make_valid_path_name(f'{save_path_root}/fs_shuffled_port_eval.json')
105 | args.fs_shuffled_results_file_name = fs_shuffled_results_file_name
106 | with open(fs_shuffled_results_file_name, 'w') as fs_shuffled_results_file:
107 | json.dump(fs_shuffled_res_dict, fs_shuffled_results_file, indent=2)
108 |
109 | zs_results_file_name = make_valid_path_name(f'{save_path_root}/zs_port_eval.json')
110 | args.zs_results_file_name = zs_results_file_name
111 | with open(zs_results_file_name, 'w') as zs_results_file:
112 | json.dump(zs_res_dict, zs_results_file, indent=2)
113 |
114 | args_file_name = make_valid_path_name(f'{save_path_root}/port_eval_args.txt')
115 | with open(args_file_name, 'w') as arg_file:
116 | json.dump(args.__dict__, arg_file, indent=2)
--------------------------------------------------------------------------------
/src/test_numheads.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import numpy as np
5 | import torch
6 |
7 | from src.utils.eval_utils import n_shot_eval, n_shot_eval_no_intervention
8 | from src.utils.model_utils import load_gpt_model_and_tokenizer, set_seed
9 | from src.utils.prompt_utils import load_dataset
10 | from src.evaluate_function_vector import compute_universal_function_vector
11 |
12 | # Evaluates how performance changes as the number of heads used to create a Function Vector increases
13 | if __name__ == "__main__":
14 | parser = argparse.ArgumentParser()
15 |
16 | parser.add_argument('--dataset_name', help="dataset to be evaluated", type=str, required=True)
17 | parser.add_argument('--mean_act_root', help="root path to mean activations", type=str, required=False, default='IE_template_QA/gptj')
18 | parser.add_argument('--model_name', type=str, required=True)
19 | parser.add_argument('--model_nickname', type=str, required=False, default='gptj')
20 | parser.add_argument('--n_heads', type=int, help="upper bound of the number of heads to create the FV", required=True)
21 | parser.add_argument('--edit_layer', type=int, help="layer at which to add the function vector", required=True)
22 | parser.add_argument('--seed', required=False, type=int, default=42)
23 | parser.add_argument('--save_path_root', required=True, type=str)
24 |
25 |
26 | args = parser.parse_args()
27 | mean_act_root = args.mean_act_root
28 | model_name = args.model_name
29 | model_nickname = args.model_nickname
30 | dataset_name = args.dataset_name
31 | n_heads = args.n_heads
32 | edit_layer = args.edit_layer
33 | seed = args.seed
34 | save_path_root = args.save_path_root
35 |
36 |
37 | # Load Model & Tokenizer, doing inference so don't need gradients
38 | torch.set_grad_enabled(False)
39 | model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name)
40 | dataset = load_dataset(dataset_name)
41 | mean_activations = torch.load(f'{save_path_root}/{mean_act_root}/{dataset_name}/{dataset_name}_mean_head_activations.pt')
42 |
43 |
44 | set_seed(seed)
45 | fs_results = n_shot_eval_no_intervention(dataset, n_shots=10, model=model, model_config=model_config, tokenizer=tokenizer)
46 | filter_set = np.where(np.array(fs_results['clean_rank_list']) == 0)[0]
47 | print("Sanity Check, cleantopk: ", fs_results['clean_topk'])
48 | zs_results = {}
49 |
50 | for i in range(n_heads+1):
51 | fv, _ = compute_universal_function_vector(mean_activations, model, model_config, i)
52 | zs_results[i] = n_shot_eval(dataset, fv, edit_layer, 0, model, model_config, tokenizer, filter_set=filter_set)
53 |
54 |
55 | os.makedirs(f'{save_path_root}/{model_nickname}_test_numheads', exist_ok=True)
56 | with open(f'{save_path_root}/{model_nickname}_test_numheads/{dataset_name}_perf_v_heads.json', 'w') as results_file:
57 | json.dump(zs_results, results_file, indent=2)
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ericwtodd/function_vectors/751e2219d304eba471cffcacc9efd89a4f8ef3c4/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/intervention_utils.py:
--------------------------------------------------------------------------------
1 | from baukit import TraceDict
2 | import torch
3 | import re
4 | import bitsandbytes as bnb
5 |
6 | def get_module(model, name):
7 | """
8 | Finds the named module within the given model.
9 | """
10 | for n, m in model.named_modules():
11 | if n == name:
12 | return m
13 | raise LookupError(name)
14 |
15 |
16 | def replace_activation_w_avg(layer_head_token_pairs, avg_activations, model, model_config, idx_map, batched_input=False, last_token_only=False):
17 | """
18 | An intervention function for replacing activations with a computed average value.
19 | This function replaces the output of one (or several) attention head(s) with a pre-computed average value
20 | (usually taken from another set of runs with a particular property).
21 | The batched_input flag is used for systematic interventions where we are sweeping over all attention heads for a given (layer,token)
22 | The last_token_only flag is used for interventions where we only intervene on the last token (such as zero-shot or concept-naming)
23 |
24 | Parameters:
25 | layer_head_token_pairs: list of tuple triplets each containing a layer index, head index, and token index [(L,H,T), ...]
26 | avg_activations: torch tensor of the average activations (across ICL prompts) for each attention head of the model.
27 | model: huggingface model
28 | model_config: contains model config information (n layers, n heads, etc.)
29 | idx_map: dict mapping prompt label indices to ground truth label indices
30 | batched_input: whether or not to batch the intervention across all heads
31 | last_token_only: whether our intervention is only at the last token
32 |
33 | Returns:
34 | rep_act: A function that specifies how to replace activations with an average when given a hooked pytorch module.
35 | """
36 | edit_layers = [x[0] for x in layer_head_token_pairs]
37 |
38 | def rep_act(output, layer_name, inputs):
39 | current_layer = int(layer_name.split('.')[2])
40 | if current_layer in edit_layers:
41 | if isinstance(inputs, tuple):
42 | inputs = inputs[0]
43 |
44 | # Determine shapes for intervention
45 | original_shape = inputs.shape
46 | new_shape = inputs.size()[:-1] + (model_config['n_heads'], model_config['resid_dim']//model_config['n_heads']) # split by head: + (n_attn_heads, hidden_size/n_attn_heads)
47 | inputs = inputs.view(*new_shape) # inputs shape: (batch_size , tokens (n), heads, hidden_dim)
48 |
49 | # Perform Intervention:
50 | if batched_input:
51 | # Patch activations from avg activations into baseline sentences (i.e. n_head baseline sentences being modified in this case)
52 | for i in range(model_config['n_heads']):
53 | layer, head_n, token_n = layer_head_token_pairs[i]
54 | inputs[i, token_n, head_n] = avg_activations[layer, head_n, idx_map[token_n]]
55 | elif last_token_only:
56 | # Patch activations only at the last token (for interventions such as zero-shot or concept-naming)
57 | for (layer,head_n,token_n) in layer_head_token_pairs:
58 | if layer == current_layer:
59 | inputs[-1,-1,head_n] = avg_activations[layer,head_n,idx_map[token_n]]
60 | else:
61 | # Patch activations into baseline sentence found at index, -1 of the batch (targeted & multi-token patching)
62 | for (layer, head_n, token_n) in layer_head_token_pairs:
63 | if layer == current_layer:
64 | inputs[-1, token_n, head_n] = avg_activations[layer,head_n,idx_map[token_n]]
65 |
66 | inputs = inputs.view(*original_shape)
67 | proj_module = get_module(model, layer_name)
68 | out_proj = proj_module.weight
69 |
70 | if 'gpt2-xl' in model_config['name_or_path']: # GPT2-XL uses Conv1D (not nn.Linear) & has a bias term, GPTJ does not
71 | out_proj_bias = proj_module.bias
72 | new_output = torch.addmm(out_proj_bias, inputs.squeeze(), out_proj)
73 |
74 | elif 'gpt-j' in model_config['name_or_path'] or 'gemma' in model_config['name_or_path']:
75 | new_output = torch.matmul(inputs, out_proj.T)
76 |
77 | elif 'gpt-neox' in model_config['name_or_path'] or 'pythia' in model_config['name_or_path']:
78 | out_proj_bias = proj_module.bias
79 | new_output = torch.addmm(out_proj_bias, inputs.squeeze(), out_proj.T)
80 |
81 | elif 'llama' in model_config['name_or_path']:
82 | if '70b' in model_config['name_or_path']:
83 | # need to dequantize weights
84 | out_proj_dequant = bnb.functional.dequantize_4bit(out_proj.data, out_proj.quant_state)
85 | new_output = torch.matmul(inputs, out_proj_dequant.T)
86 | else:
87 | new_output = torch.matmul(inputs, out_proj.T)
88 |
89 | elif 'olmo' in model_config['name_or_path'].lower():
90 | new_output = torch.matmul(inputs, out_proj.T)
91 |
92 | return new_output
93 | else:
94 | return output
95 |
96 | return rep_act
97 |
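The `rep_act` hook above depends on a view/reshape round trip: the `o_proj` input of shape `(batch, tokens, resid_dim)` is viewed as `(batch, tokens, n_heads, head_dim)` so one head's slice can be overwritten in place, then viewed back before re-projecting. A toy sketch with made-up dimensions (not any particular model's config):

```python
import torch

# Sketch of the head-splitting reshape used by rep_act; sizes are illustrative.
n_heads, resid_dim = 4, 16
head_dim = resid_dim // n_heads
inputs = torch.zeros(1, 3, resid_dim)            # (batch, tokens, resid_dim)

per_head = inputs.view(1, 3, n_heads, head_dim)  # (batch, tokens, heads, head_dim)
per_head[-1, -1, 2] = torch.ones(head_dim)       # patch head 2 at the last token
restored = per_head.view(1, 3, resid_dim)        # back to the original shape

print(restored[0, -1, 2*head_dim:3*head_dim])    # the patched slice is all ones
```

Because `view` shares storage with the original tensor, the patch written through `per_head` is visible in `restored` without any copying.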
98 | def add_function_vector(edit_layer, fv_vector, device, idx=-1):
99 | """
100 | Adds a vector to the output of a specified layer in the model
101 |
102 | Parameters:
103 | edit_layer: the layer to perform the FV intervention
104 | fv_vector: the function vector to add as an intervention
105 | device: device of the model (cuda gpu or cpu)
106 | idx: the token index to add the function vector at
107 |
108 | Returns:
109 | add_act: a function specifying how to add a function vector to a layer's output hidden state
110 | """
111 | def add_act(output, layer_name):
112 | current_layer = int(layer_name.split(".")[2])
113 | if current_layer == edit_layer:
114 | if isinstance(output, tuple):
115 | output[0][:, idx] += fv_vector.to(device)
116 | return output
117 | else:
118 | return output
119 | else:
120 | return output
121 |
122 | return add_act
123 |
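The hook returned by `add_function_vector` performs a single in-place addition: the function vector is added to the layer's output hidden state at one token position (the last, by default). A toy sketch of that tensor operation with invented shapes:

```python
import torch

# Sketch of what add_act does to a transformer layer's output; a layer's
# forward returns a tuple whose first element is the hidden-state tensor.
# Dimensions here are toy values, not a real model's.
resid_dim = 8
hidden_states = torch.zeros(1, 5, resid_dim)   # (batch, tokens, resid_dim)
fv_vector = torch.ones(1, resid_dim)

output = (hidden_states,)
output[0][:, -1] += fv_vector                  # intervene at the last token only

print(output[0][0, -1].sum().item())  # 8.0 -- only the last token moved
print(output[0][0, 0].sum().item())   # 0.0
```

Since the addition mutates the tensor inside the tuple, returning `output` unchanged is enough for the edited hidden state to propagate to later layers.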
124 | def function_vector_intervention(sentence, target, edit_layer, function_vector, model, model_config, tokenizer, compute_nll=False,
125 | generate_str=False):
126 | """
127 | Runs the model on the sentence and adds the function_vector to the output of edit_layer as a model intervention, predicting a single token.
128 | Returns the output of the model with and without intervention.
129 |
130 | Parameters:
131 | sentence: the sentence to be run through the model
132 | target: expected response of the model (str, or [str])
133 | edit_layer: layer at which to add the function vector
134 | function_vector: torch vector that triggers execution of a task
135 | model: huggingface model
136 | model_config: contains model config information (n layers, n heads, etc.)
137 | tokenizer: huggingface tokenizer
138 | compute_nll: whether to compute the negative log likelihood of a teacher-forced completion (used to compute perplexity (PPL))
139 | generate_str: whether to generate a string of tokens or predict a single token
140 |
141 | Returns:
142 | fvi_output: a tuple containing output results of a clean run and intervened run of the model
143 | """
144 | # Clean Run, No Intervention:
145 | device = model.device
146 | inputs = tokenizer(sentence, return_tensors='pt').to(device)
147 | original_pred_idx = len(inputs.input_ids.squeeze()) - 1
148 |
149 | if compute_nll:
150 |         target_completion = sentence + (target[0] if isinstance(target, list) else target)
151 | nll_inputs = tokenizer(target_completion, return_tensors='pt').to(device)
152 | nll_targets = nll_inputs.input_ids.clone()
153 | target_len = len(nll_targets.squeeze()) - len(inputs.input_ids.squeeze())
154 | nll_targets[:,:-target_len] = -100 # This is the accepted value to skip indices when computing loss (see nn.CrossEntropyLoss default)
155 | output = model(**nll_inputs, labels=nll_targets)
156 | clean_nll = output.loss.item()
157 | clean_output = output.logits[:,original_pred_idx,:]
158 | intervention_idx = -1 - target_len
159 | elif generate_str:
160 | MAX_NEW_TOKENS = 16
161 | output = model.generate(inputs.input_ids, top_p=0.9, temperature=0.1,
162 | max_new_tokens=MAX_NEW_TOKENS)
163 | clean_output = tokenizer.decode(output.squeeze()[-MAX_NEW_TOKENS:])
164 | intervention_idx = -1
165 | else:
166 | clean_output = model(**inputs).logits[:,-1,:]
167 | intervention_idx = -1
168 |
169 | # Perform Intervention
170 | intervention_fn = add_function_vector(edit_layer, function_vector.reshape(1, model_config['resid_dim']), model.device, idx=intervention_idx)
171 | with TraceDict(model, layers=model_config['layer_hook_names'], edit_output=intervention_fn):
172 | if compute_nll:
173 | output = model(**nll_inputs, labels=nll_targets)
174 | intervention_nll = output.loss.item()
175 | intervention_output = output.logits[:,original_pred_idx,:]
176 | elif generate_str:
177 | output = model.generate(inputs.input_ids, top_p=0.9, temperature=0.1,
178 | max_new_tokens=MAX_NEW_TOKENS)
179 | intervention_output = tokenizer.decode(output.squeeze()[-MAX_NEW_TOKENS:])
180 | else:
181 | intervention_output = model(**inputs).logits[:,-1,:] # batch_size x n_tokens x vocab_size, only want last token prediction
182 |
183 | fvi_output = (clean_output, intervention_output)
184 | if compute_nll:
185 | fvi_output += (clean_nll, intervention_nll)
186 |
187 | return fvi_output
188 |
189 |
190 | def fv_intervention_natural_text(sentence, edit_layer, function_vector, model, model_config, tokenizer, max_new_tokens=16, num_interv_tokens=None, do_sample=False):
191 | """
192 | Allows for intervention in natural text where we generate and intervene on several tokens in a row.
193 |
194 | Parameters:
195 | sentence: sentence to intervene on with the FV
196 | edit_layer: layer at which to add the function vector
197 | function_vector: vector to add to the model that triggers execution of a task
198 | model: huggingface model
199 | model_config: dict with model config parameters (n_layers, n_heads, etc.)
200 | tokenizer: huggingface tokenizer
201 | max_new_tokens: number of tokens to generate
202 | num_interv_tokens: number of tokens to apply the intervention for (defaults to all subsequent generations)
203 | do_sample: whether to sample from top p tokens (True) or have deterministic greedy decoding (False)
204 |
205 | Returns:
206 | clean_output: tokens of clean output
207 | intervention_output: tokens of intervention output
208 |
209 | """
210 | # Clean Run, No Intervention:
211 | device = model.device
212 | inputs = tokenizer(sentence, return_tensors='pt').to(device)
213 | clean_output = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, pad_token_id=tokenizer.eos_token_id)
214 |
215 | # Perform Intervention
216 | intervention_fn = add_function_vector(edit_layer, function_vector, model.device)
217 |
218 | if num_interv_tokens is not None and num_interv_tokens < max_new_tokens: # Intervene only for a certain number of tokens
219 | num_extra_tokens = max_new_tokens - num_interv_tokens
220 | with TraceDict(model, layers=model_config['layer_hook_names'], edit_output=intervention_fn):
221 | intervention_output = model.generate(**inputs, max_new_tokens = num_interv_tokens, do_sample=do_sample, pad_token_id=tokenizer.eos_token_id)
222 | intervention_output = model.generate(intervention_output, max_new_tokens=num_extra_tokens, pad_token_id=tokenizer.eos_token_id, do_sample=do_sample)
223 | else:
224 | with TraceDict(model, layers=model_config['layer_hook_names'], edit_output=intervention_fn):
225 | intervention_output = model.generate(**inputs, max_new_tokens = max_new_tokens, do_sample=do_sample, pad_token_id=tokenizer.eos_token_id)
226 |
227 | return clean_output, intervention_output
228 |
229 |
230 | def add_avg_to_activation(layer_head_token_pairs, avg_activations, model, model_config, batched_input=False, last_token_only=False):
231 | """
232 | An intervention function for adding a computed average value to activations.
233 | This function adds a pre-computed average value to the output of one (or several) attention head(s)
234 | (usually taken from another set of runs with a particular property).
235 |     The batched_input flag is used for systematic interventions where we are sweeping over all attention heads for a given (layer, token).
236 |     The last_token_only flag is used for interventions where we only intervene on the last token (such as zero-shot or concept-naming).
237 |
238 | Parameters:
239 | layer_head_token_pairs: list of tuple triplets each containing a layer index, head index, and token index [(L,H,T), ...]
240 | avg_activations: torch tensor of the average activations (across ICL prompts) for each attention head of the model.
241 | model: huggingface model
242 | model_config: contains model config information (n layers, n heads, etc.)
243 | batched_input: whether or not to batch the intervention across all heads
244 | last_token_only: whether our intervention is only at the last token
245 |
246 | Returns:
247 | add_act: A function that specifies how to replace activations with an average when given a hooked pytorch module.
248 | """
249 | edit_layers = [x[0] for x in layer_head_token_pairs]
250 | device = model.device
251 |
252 | def add_act(output, layer_name, inputs):
253 | current_layer = int(layer_name.split('.')[2])
254 | if current_layer in edit_layers:
255 | if isinstance(inputs, tuple):
256 | inputs = inputs[0]
257 |
258 | # Determine shapes for intervention
259 | original_shape = inputs.shape
260 | new_shape = inputs.size()[:-1] + (model_config['n_heads'], model_config['resid_dim']//model_config['n_heads']) # split by head: + (n_attn_heads, hidden_size/n_attn_heads)
261 | inputs = inputs.view(*new_shape) # inputs shape: (batch_size , tokens (n), heads, hidden_dim)
262 |
263 | # Perform Intervention:
264 | if batched_input:
265 | # Patch activations from avg activations into baseline sentences (i.e. n_head baseline sentences being modified in this case)
266 | for i in range(model_config['n_heads']):
267 | layer, head_n, token_n = layer_head_token_pairs[i]
268 | inputs[i, token_n, head_n] += avg_activations[layer, head_n, token_n].to(device)
269 | elif last_token_only:
270 | # Patch activations only at the last token for interventions like: (zero-shot, concept-naming, etc.)
271 | for (layer,head_n,token_n) in layer_head_token_pairs:
272 | if layer == current_layer:
273 | inputs[-1,-1,head_n] += avg_activations[layer,head_n,token_n].to(device)
274 | else:
275 |                 # Patch activations into baseline sentence found at index -1 of the batch (targeted & multi-token patching)
276 | for (layer, head_n, token_n) in layer_head_token_pairs:
277 | if layer == current_layer:
278 | inputs[-1, token_n, head_n] += avg_activations[layer,head_n,token_n].to(device)
279 |
280 | inputs = inputs.view(*original_shape)
281 | proj_module = get_module(model, layer_name)
282 | out_proj = proj_module.weight
283 |
284 | if 'gpt2-xl' in model_config['name_or_path']: # GPT2-XL uses Conv1D (not nn.Linear) & has a bias term, GPTJ does not
285 | out_proj_bias = proj_module.bias
286 | new_output = torch.addmm(out_proj_bias, inputs.squeeze(), out_proj)
287 |
288 | elif 'gpt-j' in model_config['name_or_path'] or 'gemma' in model_config['name_or_path']:
289 | new_output = torch.matmul(inputs, out_proj.T)
290 |
291 | elif 'gpt-neox' in model_config['name_or_path'] or 'pythia' in model_config['name_or_path']:
292 | out_proj_bias = proj_module.bias
293 | new_output = torch.addmm(out_proj_bias, inputs.squeeze(), out_proj.T)
294 |
295 | elif 'llama' in model_config['name_or_path']:
296 | if '70b' in model_config['name_or_path']:
297 | # need to dequantize weights
298 | out_proj_dequant = bnb.functional.dequantize_4bit(out_proj.data, out_proj.quant_state)
299 | new_output = torch.matmul(inputs, out_proj_dequant.T)
300 | else:
301 | new_output = torch.matmul(inputs, out_proj.T)
302 |
303 | elif 'olmo' in model_config['name_or_path'].lower():
304 | new_output = torch.matmul(inputs, out_proj.T)
305 |
306 | return new_output
307 | else:
308 | return output
309 |
310 | return add_act
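The `.view(*new_shape)` reshape above means each attention head owns a contiguous slice of the concatenated activation vector, so a per-head edit is just contiguous-slice indexing. A sketch with plain lists and toy dimensions (dims chosen only for illustration):

```python
# Toy dimensions (assumed for illustration; real models are far larger).
resid_dim, n_heads = 8, 2
head_dim = resid_dim // n_heads

# One token's concatenated-head activation vector.
flat = list(range(resid_dim))  # [0, 1, ..., 7]

# Equivalent of .view(..., n_heads, head_dim): head h owns a contiguous slice.
per_head = [flat[h * head_dim:(h + 1) * head_dim] for h in range(n_heads)]
print(per_head)  # [[0, 1, 2, 3], [4, 5, 6, 7]]

# Editing one head and flattening back recovers the original layout,
# matching inputs.view(*original_shape) before the output projection.
per_head[1] = [x + 10 for x in per_head[1]]
flat_again = [x for head in per_head for x in head]
print(flat_again)  # [0, 1, 2, 3, 14, 15, 16, 17]
```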
--------------------------------------------------------------------------------
/src/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LlamaForCausalLM
4 | import os
5 | import random
6 | from typing import *
7 |
8 |
9 | def load_gpt_model_and_tokenizer(model_name:str, device='cuda', revision=None):
10 | """
11 | Loads a huggingface model and its tokenizer
12 |
13 | Parameters:
14 | model_name: huggingface name of the model to load (e.g. GPTJ: "EleutherAI/gpt-j-6B", or "EleutherAI/gpt-j-6b")
15 | device: 'cuda' or 'cpu'
16 |
17 | Returns:
18 | model: huggingface model
19 | tokenizer: huggingface tokenizer
20 | MODEL_CONFIG: config variables w/ standardized names
21 |
22 | """
23 | assert model_name is not None
24 |
25 | print("Loading: ", model_name)
26 |
27 | if model_name == 'gpt2-xl':
28 | tokenizer = AutoTokenizer.from_pretrained(model_name)
29 | tokenizer.pad_token = tokenizer.eos_token
30 | model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
31 |
32 | MODEL_CONFIG={"n_heads":model.config.n_head,
33 | "n_layers":model.config.n_layer,
34 | "resid_dim":model.config.n_embd,
35 | "name_or_path":model.config.name_or_path,
36 | "attn_hook_names":[f'transformer.h.{layer}.attn.c_proj' for layer in range(model.config.n_layer)],
37 | "layer_hook_names":[f'transformer.h.{layer}' for layer in range(model.config.n_layer)],
38 | "prepend_bos":False}
39 |
40 | elif 'gpt-j' in model_name.lower():
41 | tokenizer = AutoTokenizer.from_pretrained(model_name)
42 | tokenizer.pad_token = tokenizer.eos_token
43 | model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True).to(device)
44 |
45 | MODEL_CONFIG={"n_heads":model.config.n_head,
46 | "n_layers":model.config.n_layer,
47 | "resid_dim":model.config.n_embd,
48 | "name_or_path":model.config.name_or_path,
49 | "attn_hook_names":[f'transformer.h.{layer}.attn.out_proj' for layer in range(model.config.n_layer)],
50 | "layer_hook_names":[f'transformer.h.{layer}' for layer in range(model.config.n_layer)],
51 | "prepend_bos":False}
52 |
53 | elif 'gpt-neox' in model_name.lower() or 'pythia' in model_name.lower():
54 | tokenizer = AutoTokenizer.from_pretrained(model_name)
55 | tokenizer.pad_token = tokenizer.eos_token
56 | if revision is not None and 'pythia' in model_name.lower():
57 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, revision=revision).to(device)
58 | else:
59 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)
60 |
61 | MODEL_CONFIG={"n_heads":model.config.num_attention_heads,
62 | "n_layers":model.config.num_hidden_layers,
63 | "resid_dim": model.config.hidden_size,
64 | "name_or_path":model.config.name_or_path,
65 | "attn_hook_names":[f'gpt_neox.layers.{layer}.attention.dense' for layer in range(model.config.num_hidden_layers)],
66 | "layer_hook_names":[f'gpt_neox.layers.{layer}' for layer in range(model.config.num_hidden_layers)],
67 | "prepend_bos":False}
68 |
69 | elif 'gemma' in model_name.lower():
70 | tokenizer = AutoTokenizer.from_pretrained(model_name)
71 | tokenizer.pad_token = tokenizer.eos_token
72 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)
73 |
74 | MODEL_CONFIG={"n_heads":model.config.num_attention_heads,
75 | "n_layers":model.config.num_hidden_layers,
76 | "resid_dim":model.config.hidden_size,
77 | "name_or_path":model.config._name_or_path,
78 | "attn_hook_names":[f'model.layers.{layer}.self_attn.o_proj' for layer in range(model.config.num_hidden_layers)],
79 | "layer_hook_names":[f'model.layers.{layer}' for layer in range(model.config.num_hidden_layers)],
80 | "prepend_bos":True}
81 |
82 | elif 'llama' in model_name.lower():
83 | if '70b' in model_name.lower():
84 | # use quantization. requires `bitsandbytes` library
85 | from transformers import BitsAndBytesConfig
86 | bnb_config = BitsAndBytesConfig(
87 | load_in_4bit=True,
88 | bnb_4bit_quant_type='nf4',
89 | bnb_4bit_use_double_quant=True,
90 | bnb_4bit_compute_dtype=torch.float16
91 | )
92 | tokenizer = LlamaTokenizer.from_pretrained(model_name)
93 | model = LlamaForCausalLM.from_pretrained(
94 | model_name,
95 | trust_remote_code=True,
96 | quantization_config=bnb_config
97 | )
98 | else:
99 | if '7b' in model_name.lower() or '8b' in model_name.lower():
100 | model_dtype = torch.float32
101 | else: #half precision for bigger llama models
102 | model_dtype = torch.float16
103 |
104 | # If transformers version is < 4.31 use LlamaLoaders
105 | # tokenizer = LlamaTokenizer.from_pretrained(model_name)
106 | # model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype).to(device)
107 |
108 | # If transformers version is >= 4.31, use AutoLoaders
109 | tokenizer = AutoTokenizer.from_pretrained(model_name)
110 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype).to(device)
111 |
112 | MODEL_CONFIG={"n_heads":model.config.num_attention_heads,
113 | "n_layers":model.config.num_hidden_layers,
114 | "resid_dim":model.config.hidden_size,
115 | "name_or_path":model.config._name_or_path,
116 | "attn_hook_names":[f'model.layers.{layer}.self_attn.o_proj' for layer in range(model.config.num_hidden_layers)],
117 | "layer_hook_names":[f'model.layers.{layer}' for layer in range(model.config.num_hidden_layers)],
118 | "prepend_bos":True}
119 | elif "olmo" in model_name.lower():
120 |
121 | model_dtype = torch.float32
122 | tokenizer = AutoTokenizer.from_pretrained(model_name)
123 | if revision is not None:
124 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype, revision=revision).to(device)
125 | else:
126 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype).to(device)
127 |
128 | MODEL_CONFIG={"n_heads":model.config.num_attention_heads,
129 | "n_layers":model.config.num_hidden_layers,
130 | "resid_dim":model.config.hidden_size,
131 | "name_or_path":model.config._name_or_path,
132 | "attn_hook_names":[f'model.layers.{layer}.self_attn.o_proj' for layer in range(model.config.num_hidden_layers)],
133 | "layer_hook_names":[f'model.layers.{layer}' for layer in range(model.config.num_hidden_layers)],
134 | "prepend_bos":False}
135 | else:
136 |         raise NotImplementedError(f"Model '{model_name}' is not yet supported")
137 |
138 |
139 | return model, tokenizer, MODEL_CONFIG
140 |
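Every `MODEL_CONFIG` above encodes the layer index at a fixed position in its hook names, which is exactly what the intervention hooks recover with `int(layer_name.split('.')[2])`. A small sketch of that round trip (Llama-style names, toy layer count):

```python
n_layers = 4  # toy value for illustration

# Llama/OLMo-style hook names as built in MODEL_CONFIG above.
attn_hook_names = [f"model.layers.{layer}.self_attn.o_proj" for layer in range(n_layers)]
layer_hook_names = [f"model.layers.{layer}" for layer in range(n_layers)]

# The hooks parse the layer index back out of position 2 of the dotted path.
parsed = [int(name.split(".")[2]) for name in attn_hook_names]
print(parsed)  # [0, 1, 2, 3]
```

This is why the hook-name templates differ per architecture (`transformer.h.{layer}` for GPT-2/GPT-J, `gpt_neox.layers.{layer}` for Pythia) but always keep the layer index at the same dotted-path position.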
141 | def set_seed(seed: int) -> None:
142 | """
143 | Sets the seed to make everything deterministic, for reproducibility of experiments
144 |
145 | Parameters:
146 | seed: the number to set the seed to
147 |
148 | Return: None
149 | """
150 |
151 | # Random seed
152 | random.seed(seed)
153 |
154 | # Numpy seed
155 | np.random.seed(seed)
156 |
157 | # Torch seed
158 | torch.manual_seed(seed)
159 | torch.cuda.manual_seed(seed)
160 | torch.backends.cudnn.deterministic = True
161 |     torch.backends.cudnn.benchmark = False  # benchmark=True enables nondeterministic autotuning
162 |
163 | # os seed
164 | os.environ['PYTHONHASHSEED'] = str(seed)
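The stdlib-`random` portion of `set_seed` is enough to show the reproducibility contract (the numpy and torch calls behave analogously for their own RNG streams). A minimal self-contained check:

```python
import random

def draw(seed, n=3):
    # Re-seed, then draw: same seed must yield the same stream.
    random.seed(seed)
    return [random.random() for _ in range(n)]

assert draw(0) == draw(0)  # identical seeds -> identical draws
assert draw(0) != draw(1)  # different seeds -> (almost surely) different draws
print("reproducible")
```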
--------------------------------------------------------------------------------
/src/utils/prompt_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from pathlib import Path
4 | import os
5 | from typing import *
6 | from sklearn.model_selection import train_test_split
7 |
8 |
9 |
10 | def create_fewshot_primer(prompt_data) -> str:
11 | """Creates the primer string for GPT in-context learning
12 |
13 | Parameters:
14 | prompt_data: dict containing ICL prompt examples, and template information
15 |
16 | Returns:
17 | prompt: the constructed ICL prompt primer as a string
18 | """
19 | prompt = ''
20 | prompt += prompt_data['prefixes']['instructions'] + prompt_data['instructions'] + prompt_data['separators']['instructions']
21 |
22 | for example in prompt_data['examples']:
23 |
24 | prompt += prompt_data['prefixes']['input'] + example['input'] + prompt_data['separators']['input']
25 | prompt += prompt_data['prefixes']['output'] + example['output'] + prompt_data['separators']['output']
26 |
27 | return prompt
28 |
29 | def create_prompt(prompt_data, sentence=None) -> str:
30 | """Creates a prompt using the specified sentence for GPT in-context learning
31 |
32 | Parameters:
33 | prompt_data: dict containing ICL prompt examples, and template information
34 | sentence: a query string (sentence/word) to include in the ICL prompt
35 |
36 | Returns:
37 | prompt: the constructed ICL prompt as a string
38 | """
39 | if sentence is None and prompt_data['query_target'] is not None:
40 | sentence = prompt_data['query_target']['input']
41 |
42 | if isinstance(sentence, list):
43 | sentence = sentence[0]
44 |
45 | prompt_init = create_fewshot_primer(prompt_data)
46 | prompt = prompt_init + prompt_data['prefixes']['input'] + sentence + prompt_data['separators']['input']
47 | prompt += prompt_data['prefixes']['output']
48 |
49 | return prompt
50 |
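The primer/prompt assembly above is pure string concatenation over the `prompt_data` template fields. A self-contained sketch that mirrors `create_fewshot_primer` + `create_prompt` on a toy antonym `prompt_data` (the example words are made up for illustration):

```python
prompt_data = {
    "instructions": "",
    "prefixes": {"input": "Q:", "output": "A:", "instructions": ""},
    "separators": {"input": "\n", "output": "\n\n", "instructions": ""},
    "examples": [{"input": " hot", "output": " cold"},
                 {"input": " up", "output": " down"}],
    "query_target": {"input": " big"},
}

# Mirror create_fewshot_primer: instructions, then each (input, output) pair.
p = (prompt_data["prefixes"]["instructions"] + prompt_data["instructions"]
     + prompt_data["separators"]["instructions"])
for ex in prompt_data["examples"]:
    p += prompt_data["prefixes"]["input"] + ex["input"] + prompt_data["separators"]["input"]
    p += prompt_data["prefixes"]["output"] + ex["output"] + prompt_data["separators"]["output"]

# Mirror create_prompt: append the query input and the bare output prefix,
# leaving the model to complete the answer.
p += (prompt_data["prefixes"]["input"] + prompt_data["query_target"]["input"]
      + prompt_data["separators"]["input"])
p += prompt_data["prefixes"]["output"]
print(repr(p))  # 'Q: hot\nA: cold\n\nQ: up\nA: down\n\nQ: big\nA:'
```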
51 | # Partial primer & prompt functions
52 | def create_partial_fewshot_primer(prompt_data, include = np.arange(8)) -> str:
53 | """Creates the primer string for GPT in-context learning, filtering to include a subset of specified priming strings
54 |
55 | Parameters:
56 | prompt_data: dict containing ICL prompt examples, and template information
57 | include: an iterable of ints indicating which examples to include in the ICL prompt
58 |
59 | Returns:
60 | prompt: the constructed ICL prompt primer as a string
61 | """
62 | prompt = ''
63 | prompt += prompt_data['prefixes']['instructions'] + prompt_data['instructions'] + prompt_data['separators']['instructions']
64 |
65 | # Grab each priming example in the specified order.
66 | for i in include:
67 | example = prompt_data['examples'][i]
68 | prompt += prompt_data['prefixes']['input'] + example['input'] + prompt_data['separators']['input']
69 | prompt += prompt_data['prefixes']['output'] + example['output'] + prompt_data['separators']['output']
70 |
71 | return prompt
72 |
73 | def create_partial_prompt(prompt_data, sentence=None, include=np.arange(8)) -> str:
74 | """Creates a prompt using the specified sentence and partial list of in-context primer sentences
75 |
76 | Parameters:
77 | prompt_data: dict containing ICL prompt examples, and template information
78 |         sentence: a query string (sentence/word) to include in the ICL prompt
79 | include: an iterable of ints indicating which examples to include in the ICL prompt
80 |
81 | Returns:
82 | prompt: the prompt as a string
83 | """
84 | if sentence is None and prompt_data['query_target'] is not None:
85 | sentence = prompt_data['query_target']['input']
86 | if isinstance(sentence, list):
87 | sentence = sentence[0]
88 |
89 | prompt_init = create_partial_fewshot_primer(prompt_data, include)
90 |
91 | prompt = prompt_init + prompt_data['prefixes']['input'] + sentence + prompt_data['separators']['input']
92 | prompt += prompt_data['prefixes']['output']
93 |
94 | return prompt
95 |
96 |
97 | # UTILS FOR GENERATING PROMPT META LABELS
98 | def get_prompt_parts_and_labels(prompt_data, query_sentence=None):
99 | """
100 |     Generates high-level labels for the parts of an ICL prompt according to their ICL role, such as demonstration, label, separator, structural, etc.
101 | The JSON prompt format should include 'instructions', 'examples' with ('input', 'output') pairs,
102 | 'prefixes', and 'separators' for 'input', 'output', and 'instructions'.
103 | Used in conjunction with tokenize_labels
104 |
105 | Parameters:
106 | prompt_data: dict containing ICL prompt examples, and template information
107 | query_sentence: optional (if contained in prompt_data) str containing a query for an ICL prompt
108 |
109 | Returns:
110 | prompt_parts: structured list of words to be flattened and tokenized
111 | prompt_part_labels: structured list of labels to be flattened & extended over tokenization
112 | """
113 | if query_sentence is None and prompt_data['query_target'] is not None:
114 | query_sentence = prompt_data['query_target']['input']
115 | if isinstance(query_sentence, list):
116 | query_sentence = query_sentence[0]
117 | n_examples = len(prompt_data['examples'])
118 | assemble_icl_example = lambda example, prompt_data: [prompt_data['prefixes']['input'], example['input'], prompt_data['separators']['input'], prompt_data['prefixes']['output'], example['output'], prompt_data['separators']['output']]
119 | assemble_icl_query = lambda query, prompt_data: [prompt_data['prefixes']['input'], query, prompt_data['separators']['input'], prompt_data['prefixes']['output']]
120 |
121 | prompt_instructions = [prompt_data['prefixes']['instructions'], prompt_data['instructions'], prompt_data['separators']['instructions']]
122 | prompt_icl_examples = [assemble_icl_example(prompt_data['examples'][i], prompt_data) for i in range(n_examples)]
123 | prompt_icl_query = [assemble_icl_query(query_sentence, prompt_data)]
124 |
125 | prompt_instructions_labels = ['bos_token', 'instructions_token', 'separator_token']
126 | prompt_icl_examples_labels = [['structural_token', f'demonstration_{i+1}_token', 'separator_token', 'structural_token', f'demonstration_{i+1}_label_token', 'separator_token'] for i in range(n_examples)]
127 | prompt_icl_query_labels = [['query_structural_token', 'query_demonstration_token', 'query_separator_token', 'query_structural_token']]
128 |
129 | prompt_parts = prompt_instructions + prompt_icl_examples + prompt_icl_query
130 |
131 | prompt_part_labels = prompt_instructions_labels + prompt_icl_examples_labels + prompt_icl_query_labels
132 |
133 | return prompt_parts, prompt_part_labels
134 |
135 | def extend_labels(sentence_parts, text_labels, tokenizer, label_init=[]):
136 | """
137 | Extends ICL component labels across words that are tokenized into multiple tokens
138 |
139 | Parameters:
140 | sentence_parts: list, where each element is either a token (str), phrase (str), or list of tokens/phrases
141 | text_labels: list with the same structure as 'sentence_parts', with a corresponding label for that level of the input sentence.
142 | tokenizer: huggingface tokenizer
143 |
144 | Returns:
145 | final_labels: flattened/extended list of token labels for an ICL prompt (split into parts, contained in sentence_parts and text_labels)
146 | """
147 | zipped_up = [list(zip(x,y)) if isinstance(x, list) else [(x,y)] for x,y in list(zip(sentence_parts,text_labels)) ]
148 |
149 | prompt_builder = ''
150 | final_labels = label_init
151 | for element in zipped_up:
152 |
153 | for j, (word,label) in enumerate(element):
154 | if len(word) == 0:
155 | continue
156 | pre = len(tokenizer.tokenize(prompt_builder))
157 | prompt_builder += word
158 | post = len(tokenizer.tokenize(prompt_builder))
159 |
160 | actual_tokens = post-pre
161 |
162 | if actual_tokens == 0:
163 | # if tokenization gobbles up a previous label, then we overwrite the last previous label w/ label that should've been added
164 | final_labels[-1] = label
165 |
166 | final_labels.extend([label] * (actual_tokens))
167 |
168 |             if j==3 or (j==2 and len(element[3])==0):
169 | final_labels[-1] = final_labels[-1].replace('structural', 'predictive').replace('separator', 'predictive')
170 | if j==5:
171 | final_labels[-actual_tokens] = final_labels[-actual_tokens].replace('separator', 'end_of_example')
172 |
173 | return final_labels
174 |
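The core mechanism of `extend_labels` is token counting: tokenize the prompt built so far, append the next part, tokenize again, and repeat that part's label once per newly produced token. A stripped-down sketch with a toy whitespace "tokenizer" (the real function also handles label overwrites and the predictive/end-of-example relabeling):

```python
def extend(parts, labels, tokenize):
    # Count tokens before/after appending each part; repeat its label
    # once per new token so labels stay aligned with the token stream.
    built, out = "", []
    for word, lab in zip(parts, labels):
        pre = len(tokenize(built))
        built += word
        post = len(tokenize(built))
        out.extend([lab] * (post - pre))
    return out

toks = str.split  # toy whitespace "tokenizer" standing in for a HF tokenizer
labels = extend(["Q: ", "big red dog", "\nA:"],
                ["structural_token", "demonstration_1_token", "structural_token"],
                toks)
print(labels)
# ['structural_token', 'demonstration_1_token', 'demonstration_1_token',
#  'demonstration_1_token', 'structural_token']
```

This is why a multi-token word like "big red dog" ends up with its demonstration label repeated three times in the flattened label list.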
175 | def tokenize_labels(sentence_parts, text_labels, tokenizer, prepend_bos=False):
176 | """
177 | Extends phrase-level labels across tokenization for in-context learning prompts. Tested with GPT-2's tokenizer from huggingface.
178 | Parameters:
179 | sentence_parts: list, where each element is either a token (str), phrase (str), or list of tokens/phrases
180 | text_labels: list with the same structure as 'sentence_parts', with a corresponding label for that level of the input sentence.
181 | tokenizer: huggingface tokenizer
182 |
183 | Returns:
184 | labels: flattened/extended list of token labels for an ICL prompt (split into parts, contained in sentence_parts and text_labels)
185 |
186 | based on the tokenize_and_preserve_labels function from:
187 | https://www.depends-on-the-definition.com/named-entity-recognition-with-bert/
188 | """
189 |
190 | # If the model typically prepends a bos, we add a bos label to label init
191 | if prepend_bos:
192 | labels = extend_labels(sentence_parts, text_labels, tokenizer, label_init=['bos_token'])
193 | else:
194 | labels = extend_labels(sentence_parts, text_labels, tokenizer, label_init=[])
195 |
196 | return labels
197 |
198 | def get_token_meta_labels(prompt_data, tokenizer, query=None, prepend_bos=False):
199 | """
200 | Computes the ICL meta-labels for every token in a prompt.
201 |
202 | Parameters:
203 | prompt_data: dict containing ICL prompt examples, and template information
204 | tokenizer: huggingface tokenizer
205 | query: str of the query input
206 |
207 | Return:
208 | token_labels: list of tuples (prompt token index, token, label)
209 | prompt_string: full prompt as a string
210 | """
211 | if query is None and prompt_data['query_target'] is not None:
212 | query = prompt_data['query_target']['input']
213 | if isinstance(query, list):
214 | query = query[0]
215 |
216 | prompt_parts, prompt_part_labels = get_prompt_parts_and_labels(prompt_data, query_sentence=query)
217 | token_meta_labels = tokenize_labels(prompt_parts, prompt_part_labels, tokenizer, prepend_bos)
218 | prompt_string = create_prompt(prompt_data=prompt_data, sentence=query)
219 | tokens = [tokenizer.decode(x) for x in tokenizer(prompt_string).input_ids]
220 | token_labels = list(zip(np.arange(len(tokens)), tokens, token_meta_labels))
221 |
222 | return token_labels, prompt_string
223 |
224 | def get_dummy_token_labels(n_icl_examples, tokenizer, model_config, prefixes=None, separators=None):
225 | """
226 | Computes the ground-truth meta labels & indices for an ICL prompt with the specified number of example pairs
227 | These GT labels assume each word gets a single token
228 |
229 | Parameters:
230 | n_icl_examples: number of ICL example pairs
231 | tokenizer: huggingface tokenizer
232 | prefixes: ICL template prefixes
233 | separators: ICL template separators
234 |
235 | Return:
236 | final_token_labels: list of tuples containing a token's index and label name [(int, str), ... ]
237 | """
238 | # If the model already prepends a bos token by default, we don't want to add one to our prompts
239 |     prepend_bos = not model_config['prepend_bos']
240 |
241 | if prefixes is not None and separators is not None:
242 | dummy_prompt_data = word_pairs_to_prompt_data({'input': ['a']*n_icl_examples, 'output':['a']*n_icl_examples},
243 | query_target_pair={'input':['a'], 'output':['a']}, prepend_bos_token=prepend_bos,
244 | prefixes=prefixes, separators=separators)
245 | else:
246 | dummy_prompt_data = word_pairs_to_prompt_data({'input': ['a']*n_icl_examples, 'output':['a']*n_icl_examples},
247 | query_target_pair={'input':['a'], 'output':['a']}, prepend_bos_token=prepend_bos)
248 | final_token_labels, _ = get_token_meta_labels(dummy_prompt_data,tokenizer, prepend_bos=model_config['prepend_bos'])
249 | final_token_labels = [(x[0],x[-1]) for x in final_token_labels]
250 | return final_token_labels
251 |
252 | def compute_duplicated_labels(token_labels, gt_labels):
253 | """
254 | Computes a map between duplicated labels and ground truth label positions for localized averaging
255 |
256 | Parameters:
257 | token_labels: token labels of actual prompt being used
258 | gt_labels: token labels for a "ground truth" prompt that assumes each input & output is a single token
259 |
260 | Returns:
261 | index_map: a dict mapping prompt label indices to ground truth label indices
262 | dup_label_ranges: indices where labels should be duplicated
263 | """
264 | check_inds = list(filter(lambda x: 'demo' in x[2], token_labels))
265 | dup_ranges = pd.DataFrame(check_inds).groupby(2)[0].aggregate(lambda x: (x.min(), x.max()))
266 | dup_labels = [v for v,x in dup_ranges.items() if (x[1] - x[0]) > 0]
267 |
268 | dup_label_ranges = dup_ranges[dup_labels].to_dict()
269 | dup_inds = pd.DataFrame(check_inds)[pd.DataFrame(check_inds)[2].duplicated()][0].values
270 |
271 | index_map = {k:v[0] for (k,v) in zip([x[0] for x in token_labels if x[0] not in dup_inds], gt_labels)}
272 |
273 | return index_map, dup_label_ranges
274 |
275 | def update_idx_map(idx_map, idx_avg) -> dict:
276 | """
277 |     Updates the idx_map to map each duplicated token to its gt token position
278 | """
279 | update_map = {}
280 | for (i,j) in idx_avg.values():
281 | for k in range(i,j+1):
282 | if k not in idx_map.keys():
283 | update_map[k] = idx_map[i]
284 |
285 | update_map = {**idx_map, **update_map}
286 | return update_map
287 |
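Concretely, `update_idx_map` fills every token index inside a duplicated range with the ground-truth position of the range's first token. A self-contained sketch using the same logic on made-up indices (a word at gt position 1 that tokenized into prompt tokens 1-4):

```python
def update_idx_map(idx_map, idx_avg):
    # Map every token inside a duplicated range to the gt position of the
    # range's first token (same logic as the function above).
    update_map = {}
    for (i, j) in idx_avg.values():
        for k in range(i, j + 1):
            if k not in idx_map:
                update_map[k] = idx_map[i]
    return {**idx_map, **update_map}

idx_map = {0: 0, 1: 1, 5: 2}                     # prompt idx -> gt idx
idx_avg = {"demonstration_1_token": (1, 4)}      # tokens 1-4 are one gt word
print(update_idx_map(idx_map, idx_avg))
# {0: 0, 1: 1, 5: 2, 2: 1, 3: 1, 4: 1}
```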
288 |
289 | def word_pairs_to_prompt_data(word_pairs : dict,
290 | instructions: str = "",
291 | prefixes: dict = {"input":"Q:", "output":"A:","instructions":""},
292 | separators: dict = {"input":"\n", "output":"\n\n", "instructions":""},
293 | query_target_pair: dict = None, prepend_bos_token=False,
294 | shuffle_labels=False, prepend_space=True) -> dict:
295 | """Takes a dataset of word pairs, and constructs a prompt_data dict with additional information to construct an ICL prompt.
296 | Parameters:
297 | word_pairs: dict of the form {'word1':['a', 'b', ...], 'word2':['c', 'd', ...]}
298 | instructions: prefix instructions for an ICL prompt
299 | prefixes: dict of ICL prefixes that are prepended to inputs, outputs and instructions
300 | separators: dict of ICL separators that are appended to inputs, outputs and instructions
301 | query_target_pair: dict with a single input-output pair acting as the query for the prompt
302 | prepend_bos_token: whether or not to prepend a BOS token to the prompt
303 | shuffle_labels: whether to shuffle the ICL labels
304 | prepend_space: whether to prepend a space to every input and output token
305 |
306 | Returns:
307 | prompt_data: dict containing ICL prompt examples, and template information
308 | """
309 | prompt_data = {}
310 | prompt_data['instructions'] = instructions
311 | prompt_data['separators'] = separators
312 | if prepend_bos_token:
313 | prefixes = {k:(v if k !='instructions' else '<|endoftext|>' + v) for (k,v) in prefixes.items()}
314 | prompt_data['prefixes'] = prefixes
315 |
316 | if query_target_pair is not None:
317 | query_target_pair = {k:(v[0] if isinstance(v, list) else v) for k,v in query_target_pair.items()}
318 | prompt_data['query_target'] = query_target_pair
319 |
320 | if shuffle_labels:
321 | randomized_pairs = [np.random.permutation(x).tolist() if i==1 else x for (i,x) in enumerate(list(word_pairs.values()))] # shuffle labels only
322 | if prepend_space:
323 | prompt_data['examples'] = [{'input':' ' + str(w1), 'output':' ' + str(w2)} for (w1,w2) in list(zip(*randomized_pairs))]
324 | prompt_data['query_target'] = {k:' ' + str(v) for k,v in query_target_pair.items()} if query_target_pair is not None else None
325 | else:
326 | prompt_data['examples'] = [{'input':w1, 'output':w2} for (w1,w2) in list(zip(*randomized_pairs))]
327 | else:
328 | if prepend_space:
329 | prompt_data['examples'] = [{'input':' ' + str(w1), 'output':' ' + str(w2)} for (w1,w2) in list(zip(*word_pairs.values()))]
330 | prompt_data['query_target'] = {k:' ' + str(v) for k,v in query_target_pair.items()} if query_target_pair is not None else None
331 | else:
332 | prompt_data['examples'] = [{'input':w1, 'output':w2} for (w1,w2) in list(zip(*word_pairs.values()))]
333 |
334 | return prompt_data
335 |
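For illustration, here is a hedged sketch of how the fields of `prompt_data` combine into a full ICL prompt string. The `assemble_prompt` helper is hypothetical, written only for this example (the repo's own prompt-construction helpers live elsewhere in `prompt_utils`):

```python
# Illustrative helper (not part of the repo's API): combines prefixes,
# separators, examples, and the query into one ICL prompt string.
def assemble_prompt(prompt_data):
    p = prompt_data['prefixes']
    s = prompt_data['separators']
    text = p['instructions'] + prompt_data['instructions'] + s['instructions']
    for ex in prompt_data['examples']:
        text += p['input'] + ex['input'] + s['input']
        text += p['output'] + ex['output'] + s['output']
    # The query input is appended without its target, ending at the answer prefix
    text += p['input'] + prompt_data['query_target']['input'] + s['input'] + p['output']
    return text

prompt_data = {
    'instructions': '',
    'prefixes': {'input': 'Q:', 'output': 'A:', 'instructions': ''},
    'separators': {'input': '\n', 'output': '\n\n', 'instructions': ''},
    'examples': [{'input': ' hot', 'output': ' cold'},
                 {'input': ' big', 'output': ' small'}],
    'query_target': {'input': ' fast', 'output': ' slow'},
}
print(assemble_prompt(prompt_data))
# Q: hot
# A: cold
#
# Q: big
# A: small
#
# Q: fast
# A:
```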
336 |
337 | # DATASET UTILS
338 | class ICLDataset:
339 | """
340 | A simple dataset class containing input-output pairs, used for ICL prompt construction.
341 | """
342 | def __init__(self, dataset):
343 | if isinstance(dataset, str):
344 | self.raw_data = pd.read_json(dataset)
345 | elif isinstance(dataset, dict):
346 |             self.raw_data = pd.DataFrame(dataset)
347 |         else:
348 |             raise ValueError(f"`dataset` must be a path to a JSON file (str) or a dict, got {type(dataset)}")
347 | self.raw_data = self.raw_data[['input', 'output']]
348 |
349 | def __getitem__(self,i):
350 | if isinstance(i, int):
351 | return self.raw_data.iloc[i].to_dict()
352 | elif isinstance(i, slice):
353 | return self.raw_data.iloc[i].to_dict(orient='list')
354 | elif isinstance(i, list) or isinstance(i, np.ndarray):
355 | return self.raw_data.iloc[i].to_dict(orient='list')
356 | elif isinstance(i, str):
357 | if i not in self.raw_data.columns:
358 | raise KeyError(f"Column '{i}' not in the dataset. Current columns in the dataset: {self.raw_data.columns.to_list()}")
359 | else:
360 | return self.raw_data[i].to_list()
361 | else:
362 | raise ValueError(f"{i} is not a valid index type. Expected one of: [int, list, np.ndarray, slice, str]")
363 |
364 | def __len__(self):
365 | return len(self.raw_data)
366 |
367 | def __repr__(self):
368 | s = "ICLDataset" + "({\n\tfeatures: " + f"{self.raw_data.columns.to_list()},\n\tnum_rows: {self.__len__()}" + "\n})"
369 | return s
370 |
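`ICLDataset.__getitem__` delegates each index type to pandas; the standalone snippet below illustrates those behaviors, with the DataFrame standing in for `ICLDataset.raw_data` (toy values, for illustration only):

```python
import pandas as pd

# Stand-in for ICLDataset.raw_data
raw_data = pd.DataFrame({'input': ['hot', 'big'], 'output': ['cold', 'small']})

single = raw_data.iloc[0].to_dict()                   # int -> one input/output pair
batch = raw_data.iloc[[0, 1]].to_dict(orient='list')  # list/ndarray/slice -> column lists
column = raw_data['output'].to_list()                 # str -> one whole column

print(single)  # {'input': 'hot', 'output': 'cold'}
print(batch)   # {'input': ['hot', 'big'], 'output': ['cold', 'small']}
print(column)  # ['cold', 'small']
```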
371 | def split_icl_dataset(dataset, train_size=None, test_size=0.3, seed=42) -> Dict[str,ICLDataset]:
372 | """
373 |     Uses scikit-learn's train_test_split to create train/valid/test splits from the provided dataset.
374 |
375 | Parameters:
376 | dataset: ICL dataset
377 |         train_size: fraction of data (float between 0 and 1) to put in the training data split
378 |         test_size: fraction of data (float between 0 and 1) to put into the test data split
379 | seed: seed used for splitting the data
380 |
381 | Returns:
382 | dict containing train, valid, test ICL datasets
383 | """
384 | if train_size is None and test_size is None:
385 | train_size = 0.7
386 | test_size = 0.3
387 |
388 | elif train_size is not None and test_size is None:
389 | test_size = 1-train_size
390 |
391 | elif train_size is None and test_size is not None:
392 | train_size = 1-test_size
393 |
394 | elif train_size is not None and test_size is not None:
395 |         assert abs(train_size + test_size - 1.0) < 1e-9, "train_size and test_size must sum to 1"
396 |
397 | train, valid = train_test_split(dataset.raw_data, test_size=test_size, random_state=seed)
398 |     test, valid = train_test_split(valid, test_size=test_size, random_state=seed)  # re-splits the held-out chunk into test/valid
399 |
400 | train = ICLDataset(train.to_dict(orient='list'))
401 | valid = ICLDataset(valid.to_dict(orient='list'))
402 | test = ICLDataset(test.to_dict(orient='list'))
403 |
404 | return {'train':train, 'valid':valid, 'test':test}
405 |
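Because the held-out chunk is split a second time with the same `test_size`, the final proportions are not the 70/30 that the defaults might suggest. A small arithmetic sketch (assuming scikit-learn's behavior of holding out `ceil(n * test_size)` rows for a float `test_size`):

```python
import math

# Effective split sizes produced by the two successive train_test_split calls
# in split_icl_dataset. With the default test_size=0.3 the final proportions
# are roughly 70% train / 21% test / 9% valid.
def effective_split_sizes(n, test_size=0.3):
    held_out = math.ceil(n * test_size)      # first split: held-out chunk
    train = n - held_out                     # 70 for n=100
    valid = math.ceil(held_out * test_size)  # second split re-applies test_size
    test = held_out - valid                  # remainder of the held-out chunk
    return train, test, valid

print(effective_split_sizes(100))  # -> (70, 21, 9)
```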
406 |
407 | def load_dataset(task_name: str,
408 | root_data_dir: str = '../dataset_files',
409 | test_size = 0.3,
410 | seed=32
411 | ) -> Dict[str,ICLDataset]:
412 | """
413 | Loads a dataset with input/output pairs
414 |
415 | Parameters:
416 | task_name: the name of the task dataset
417 | root_data_dir: the root directory where the data comes from
418 |         test_size: fraction of data (float between 0 and 1) to put into the test data split
419 |         seed: seed used for splitting the data
420 |     Returns:
421 |         dataset: dict containing the train/valid/test dataset splits
422 | """
423 |
424 | data_folders = ['abstractive', 'extractive']
425 | assert test_size <= 1.0
426 |
427 | path = Path(root_data_dir)
428 | d_group_map = [(dataset_type, os.path.exists(os.path.join(root_data_dir, dataset_type, task_name+'.json'))) for dataset_type in data_folders]
429 |
430 | d_group = list(filter(lambda x: x[1], d_group_map))
431 |
432 |     assert len(d_group) == 1, f"Error! 'task_name'={task_name}.json must be uniquely contained in one of these directories: {data_folders}. Please check the root_data_dir"
433 | dataset_folder = d_group[0][0]
434 |
435 | d_path = os.path.join(path, dataset_folder, f'{task_name}.json')
436 |
437 | dataset = ICLDataset(d_path)
438 | dataset = split_icl_dataset(dataset, test_size=test_size, seed=seed)
439 |
440 | return dataset
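The directory-resolution step of `load_dataset` can be exercised in isolation. `resolve_task_path` below is an illustrative re-implementation (not a function from the repo), demonstrated against a temporary directory tree:

```python
import os
import tempfile

# Illustrative sketch of load_dataset's folder resolution: the task file must
# exist in exactly one of the candidate dataset folders.
def resolve_task_path(root, task_name, folders=('abstractive', 'extractive')):
    hits = [f for f in folders
            if os.path.exists(os.path.join(root, f, task_name + '.json'))]
    assert len(hits) == 1, f"'{task_name}.json' must appear in exactly one of {folders}"
    return os.path.join(root, hits[0], task_name + '.json')

# Build a throwaway directory tree mirroring dataset_files/
root = tempfile.mkdtemp()
for folder in ('abstractive', 'extractive'):
    os.makedirs(os.path.join(root, folder))
open(os.path.join(root, 'abstractive', 'antonym.json'), 'w').close()

print(resolve_task_path(root, 'antonym'))  # ends with abstractive/antonym.json
```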
--------------------------------------------------------------------------------
/src/vocab_reconstruction.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch, numpy as np
3 | import argparse
4 |
5 | # Include prompt creation helper functions
6 | from utils.prompt_utils import load_dataset
7 | from utils.extract_utils import get_mean_head_activations, compute_universal_function_vector
8 | from utils.eval_utils import n_shot_eval_no_intervention, n_shot_eval
9 | from utils.model_utils import load_gpt_model_and_tokenizer, set_seed
10 |
11 | def optim_loop(v_n, target, decoder, loss_fn, optimizer, n_steps:int=1000, verbose:bool=False, restrict_vocab:int=50400):
12 |     """Optimizes v_n so that decoder(v_n) matches the target logits over the top `restrict_vocab` entries."""
12 | if target.shape[-1] != restrict_vocab:
13 | inds = torch.topk(target, restrict_vocab).indices[0]
14 | Z = torch.zeros(target.size()).cuda()
15 | Z[:,inds] = target[:,inds]
16 | else:
17 | Z = target
18 |
19 | for i in range(n_steps):
20 | loss = loss_fn(decoder(v_n),Z)
21 | loss.backward()
22 | if verbose:
23 | print(f"Loss:{loss.item()}, iter:{i}")
24 | optimizer.step()
25 | optimizer.zero_grad()
26 | return v_n
27 |
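A NumPy sketch of the restricted-target masking in `optim_loop` (an assumption written for illustration, mirroring the torch code above: entries outside the top-k are zeroed before the cross-entropy match):

```python
import numpy as np

# Sketch of the restricted target Z built in optim_loop: only the top-k
# entries of the target are kept; everything else is zeroed.
def restrict_target(target, k):
    inds = np.argsort(target[0])[::-1][:k]  # analogous to torch.topk(target, k).indices[0]
    Z = np.zeros_like(target)
    Z[:, inds] = target[:, inds]
    return Z

t = np.array([[0.10, 0.50, 0.15, 0.25]])
print(restrict_target(t, 2))  # only the 0.50 and 0.25 entries survive
```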
28 | def vocab_reconstruction(datasets, n_steps:int=1000, lr:float=0.5, n_seeds:int=5, n_trials:int=100, n_shots:int=10, restrict_vocab_list=[100,50400], return_vecs:bool=False):
29 | """
30 | Computes and evaluates a function vector reconstruction which matches its output vocabulary distribution.
31 |
32 | Parameters:
33 |         datasets: names of the task datasets to evaluate
34 |         n_steps: number of optimization steps
34 | lr: adam learning rate
35 | n_seeds: number of seeds to run
36 | n_trials: number of prompts to compute task-conditioned mean head activations over
37 | n_shots: number of shots for task-conditioned mean prompts
38 | restrict_vocab_list: list of ints determining how many vocab words to match. Defaults to 100 & full-vocab (which is 50400 for GPT-J)
39 | return_vecs: whether to return the function vectors and their corresponding vocab-optimized reconstruction vectors
40 |
41 | Returns:
42 | orig_results: FV results
43 |         zs_results: zero-shot evaluation results for the scaled reconstruction vectors
44 | kl_divs: kl divergences between the distribution of the FV and its reconstruction
45 | fvs: (optional) the function vectors used
46 | vns: (optional) the vocab-optimized reconstruction vectors
47 | 
48 |     Note: model, model_config, tokenizer, and root_data_dir are read from globals set in __main__.
49 |     """
48 |
49 | seeds = {k:[] for k in datasets}
50 | orig_results = {k:[] for k in datasets}
51 | fvs = {k:[] for k in datasets}
52 | vns = {k:{j:[] for j in range(len(restrict_vocab_list))} for k in datasets}
53 | zs_results = {k:{j:[] for j in range(len(restrict_vocab_list))} for k in datasets}
54 | kl_divs = {k:{j:[] for j in range(len(restrict_vocab_list))} for k in datasets}
55 |
56 |
57 | for dataset_name in datasets:
58 | print(f"Dataset: {dataset_name}")
59 |
60 | for i in range(n_seeds):
61 | seed = np.random.randint(100000)
62 | print(f"seed:{seed}")
63 | seeds[dataset_name].append(seed)
64 | set_seed(seed)
65 |
66 | # Disable gradients when extracting activations & computing FV
67 | torch.set_grad_enabled(False)
68 |
69 | dataset = load_dataset(dataset_name, seed=seed, root_data_dir=root_data_dir)
70 |
71 | fs_results_validation = n_shot_eval_no_intervention(dataset=dataset, n_shots=n_shots, model=model, model_config=model_config, tokenizer=tokenizer, compute_ppl=False, test_split='valid')
72 | filter_set_validation = np.where(np.array(fs_results_validation['clean_rank_list']) == 0)[0]
73 |
74 | mean_activations = get_mean_head_activations(dataset, model=model, model_config=model_config, tokenizer=tokenizer, n_icl_examples=n_shots,
75 | N_TRIALS=n_trials, filter_set=filter_set_validation)
76 |
77 | fv, _ = compute_universal_function_vector(mean_activations, model, model_config)
78 |
79 | fs_results = n_shot_eval_no_intervention(dataset=dataset, n_shots=n_shots, model=model, model_config=model_config, tokenizer=tokenizer, compute_ppl=False, test_split='test')
80 | filter_set = np.where(np.array(fs_results['clean_rank_list']) == 0)[0]
81 |
82 | fv_results = n_shot_eval(dataset=dataset, fv_vector=fv, edit_layer=9, n_shots=0,
83 | model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set)
84 |
85 | orig_results[dataset_name].append(fv_results)
86 | fvs[dataset_name].append(fv)
87 |
88 | for j, vocab_size in enumerate(restrict_vocab_list):
89 | # Enable Gradients for Optimization
90 | torch.set_grad_enabled(True)
91 | v_n = torch.randn(fv.size()).cuda()
92 | v_n.requires_grad=True
93 |
94 | # Optim setup
95 | loss_fn = torch.nn.CrossEntropyLoss()
96 | optimizer = torch.optim.Adam([v_n], lr=lr)
97 | decoder = torch.nn.Sequential(model.transformer.ln_f, model.lm_head).to(model.device)
98 |
99 |                 # Enable grads on the decoder's parameters (Adam still updates only v_n)
100 | for p in decoder.parameters():
101 | p.requires_grad = True
102 |
103 | target = torch.nn.functional.softmax(decoder(fv), dim=-1).detach()
104 |
105 | computed_vn = optim_loop(v_n, target, decoder, loss_fn, optimizer, verbose=False, n_steps=n_steps, restrict_vocab=vocab_size)
106 |
107 | scaled_vn = computed_vn / torch.linalg.norm(computed_vn) * torch.linalg.norm(fv)
108 |
109 | zs_reconstruction_results = n_shot_eval(dataset=dataset, fv_vector=scaled_vn, edit_layer=9, n_shots=0,
110 | model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set)
111 |
112 | zs_results[dataset_name][j].append(zs_reconstruction_results)
113 | vns[dataset_name][j].append(scaled_vn.detach())
114 |
115 | # Compute kl divergence between two distributions
116 | if vocab_size != 50400:
117 | tp = torch.softmax(decoder(fvs[dataset_name][i]), dim=-1)
118 | inds = torch.topk(tp, vocab_size).indices[0]
119 | vn_ps = torch.softmax(decoder(vns[dataset_name][j][i]), dim=-1)[:,inds]
120 |
121 | log_probs = torch.log(vn_ps / vn_ps.sum())
122 | target_probs = tp[:,inds] / tp[:,inds].sum()
123 | else:
124 | log_probs = torch.log(torch.softmax(decoder(vns[dataset_name][j][i]), dim=-1))
125 | target_probs = torch.softmax(decoder(fvs[dataset_name][i]), dim=-1)
126 |
127 | kl_divs[dataset_name][j].append(torch.nn.functional.kl_div(log_probs, target_probs, reduction='batchmean').item())
128 |
129 | if return_vecs:
130 | return orig_results, zs_results, kl_divs, fvs, vns
131 | else:
132 | return orig_results, zs_results, kl_divs
133 |
134 |
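A NumPy sketch (an assumption, mirroring the torch code above) of the restricted-vocabulary KL divergence: both distributions are renormalized over the FV's top-k tokens, matching `torch.nn.functional.kl_div`'s convention of log-probabilities as input and probabilities as target:

```python
import numpy as np

# KL(target || reconstruction) over the FV's top-k vocabulary entries,
# with both distributions renormalized over those k entries.
def restricted_kl(p_fv, p_vn, k):
    inds = np.argsort(p_fv)[::-1][:k]
    p = p_fv[inds] / p_fv[inds].sum()  # FV target distribution
    q = p_vn[inds] / p_vn[inds].sum()  # reconstruction distribution
    return float(np.sum(p * (np.log(p) - np.log(q))))

p = np.array([0.6, 0.25, 0.1, 0.05])
print(restricted_kl(p, p, 2))  # identical distributions -> 0.0
```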
135 | if __name__ == "__main__":
136 |
137 | parser = argparse.ArgumentParser()
138 | parser.add_argument('--model_name', help='Name of model to be loaded', type=str, required=False, default='EleutherAI/gpt-j-6b')
139 | parser.add_argument('--root_data_dir', help='Root directory of data files', type=str, required=False, default='../dataset_files')
140 |     parser.add_argument('--save_path_root', help='Directory path to save reconstruction results to', type=str, required=False, default='../results')
141 | parser.add_argument('--n_seeds', help='Number of seeds', type=int, required=False, default=5)
142 | parser.add_argument('--n_trials', help='Number of trials to use for computing task-conditioned mean head activations', type=int, required=False, default=100)
143 | parser.add_argument('--n_shots', help='Number of shots to use for prompts when computing task-conditioned mean head activations', type=int, required=False, default=10)
144 |     parser.add_argument('--lr', help="Learning rate for the Adam optimizer", type=float, required=False, default=0.5)
145 |     parser.add_argument('--n_steps', help="Number of optimization steps", type=int, required=False, default=1000)
146 |
147 | args = parser.parse_args()
148 |
149 | # Gather inputs
150 | model_name = args.model_name
151 | root_data_dir = args.root_data_dir
152 | save_path_root = args.save_path_root
153 | n_seeds = args.n_seeds
154 | n_trials = args.n_trials
155 | n_shots = args.n_shots
156 | lr = args.lr
157 | n_steps = args.n_steps
158 |
159 |
160 | # Load Model & Tokenizer
161 | torch.set_grad_enabled(False)
162 | model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name)
163 |
164 | datasets = ['antonym', 'english-french', 'capitalize', 'present-past', 'singular-plural', 'country-capital']
165 | args.datasets = datasets
166 |
167 | # Test Loop:
168 | orig_results, zs_results, kl_divs = vocab_reconstruction(datasets, n_steps=n_steps, lr=lr, n_seeds=n_seeds, n_trials=n_trials, n_shots=n_shots, restrict_vocab_list=[100,50400])
169 |
170 |
171 | # Extract Summary Results:
172 | os.makedirs(os.path.join(save_path_root), exist_ok=True)
173 | with open(os.path.join(save_path_root, 'reconstruction_results.txt'), 'w') as out_file:
174 |
175 | for dataset_name in datasets:
176 | print(f"{dataset_name.title()}:", file=out_file)
177 | fv_acc = [orig_results[dataset_name][i]['intervention_topk'][0][1] for i in range(n_seeds)]
178 | v100_acc = [zs_results[dataset_name][0][i]['intervention_topk'][0][1] for i in range(n_seeds)]
179 | kl100_val = kl_divs[dataset_name][0]
180 |
181 | vfull_acc = [zs_results[dataset_name][1][i]['intervention_topk'][0][1] for i in range(n_seeds)]
182 | klfull_val = kl_divs[dataset_name][1]
183 |
184 | print("fv results:", np.mean(fv_acc).round(3)*100, '% +/-', np.std(fv_acc).round(3)*100, file=out_file)
185 | print("v_100 results:", np.mean(v100_acc).round(3)*100, '% +/-', np.std(v100_acc).round(3)*100, file=out_file)
186 | print("KL100:", np.mean(kl100_val).round(5), '+/-', np.std(kl100_val).round(5), file=out_file)
187 | print("v_full results:", np.mean(vfull_acc).round(3)*100, '% +/-', np.std(vfull_acc).round(3)*100, file=out_file)
188 | print("KLFull:", np.mean(klfull_val).round(5), '+/-', np.std(klfull_val).round(5), '\n', file=out_file)
189 |
190 | with open(os.path.join(save_path_root, 'reconstruction_args.txt'), 'w') as arg_file:
191 | print(args.__dict__, file=arg_file)
--------------------------------------------------------------------------------