├── .gitignore
├── README.md
├── dataset_files
│   ├── README.md
│   ├── abstractive
│   │   ├── ag_news.json
│   │   ├── antonym.json
│   │   ├── capitalize.json
│   │   ├── capitalize_first_letter.json
│   │   ├── capitalize_last_letter.json
│   │   ├── capitalize_second_letter.json
│   │   ├── commonsense_qa.json
│   │   ├── country-capital.json
│   │   ├── country-currency.json
│   │   ├── english-french.json
│   │   ├── english-german.json
│   │   ├── english-spanish.json
│   │   ├── landmark-country.json
│   │   ├── lowercase_first_letter.json
│   │   ├── lowercase_last_letter.json
│   │   ├── national_parks.json
│   │   ├── next_capital_letter.json
│   │   ├── next_item.json
│   │   ├── park-country.json
│   │   ├── person-instrument.json
│   │   ├── person-occupation.json
│   │   ├── person-sport.json
│   │   ├── present-past.json
│   │   ├── prev_item.json
│   │   ├── product-company.json
│   │   ├── sentiment.json
│   │   ├── singular-plural.json
│   │   ├── synonym.json
│   │   └── word_length.json
│   ├── extractive
│   │   ├── adjective_v_verb_3.json
│   │   ├── adjective_v_verb_5.json
│   │   ├── alphabetically_first_3.json
│   │   ├── alphabetically_first_5.json
│   │   ├── alphabetically_last_3.json
│   │   ├── alphabetically_last_5.json
│   │   ├── animal_v_object_3.json
│   │   ├── animal_v_object_5.json
│   │   ├── choose_first_of_3.json
│   │   ├── choose_first_of_5.json
│   │   ├── choose_last_of_3.json
│   │   ├── choose_last_of_5.json
│   │   ├── choose_middle_of_3.json
│   │   ├── choose_middle_of_5.json
│   │   ├── color_v_animal_3.json
│   │   ├── color_v_animal_5.json
│   │   ├── concept_v_object_3.json
│   │   ├── concept_v_object_5.json
│   │   ├── conll2003_location.json
│   │   ├── conll2003_organization.json
│   │   ├── conll2003_person.json
│   │   ├── fruit_v_animal_3.json
│   │   ├── fruit_v_animal_5.json
│   │   ├── object_v_concept_3.json
│   │   ├── object_v_concept_5.json
│   │   ├── squad_val.json
│   │   ├── verb_v_adjective_3.json
│   │   └── verb_v_adjective_5.json
│   └── generate
│       ├── categories.json
│       ├── create_antonym_synonym_datasets.py
│       ├── create_translation_datasets.py
│       ├── task_data_generation.ipynb
│       └── translation
│           ├── en-de.0-5000.txt
│           ├── en-de.5000-6500.txt
│           ├── en-es.0-5000.txt
│           ├── en-es.5000-6500.txt
│           ├── en-fr.0-5000.txt
│           └── en-fr.5000-6500.txt
├── fv_environment.yml
├── fv_overview.png
├── notebooks
│   └── fv_demo.ipynb
└── src
    ├── __init__.py
    ├── compute_average_activations.py
    ├── compute_avg_hidden_state.py
    ├── compute_indirect_effect.py
    ├── eval_scripts
    │   ├── eval_avg_hs.sh
    │   ├── eval_fv.sh
    │   ├── eval_numheads.sh
    │   ├── eval_template_portability.sh
    │   ├── fv_eval_sweep.py
    │   └── template.sh
    ├── evaluate_function_vector.py
    ├── natural_text_eval.py
    ├── portability_eval.py
    ├── test_numheads.py
    ├── utils
    │   ├── __init__.py
    │   ├── eval_utils.py
    │   ├── extract_utils.py
    │   ├── intervention_utils.py
    │   ├── model_utils.py
    │   └── prompt_utils.py
    └── vocab_reconstruction.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Cache and Checkpoints
2 | .ipynb_checkpoints
3 | *.ipynb_checkpoints
4 | __pycache__
5 | *__pycache__
6 | 
7 | # Results and Figures
8 | *results
9 | results/*
10 | *figures
11 | 
12 | # External Data/Repositories
13 | */generate/AntSynNET
14 | 
15 | # Weights
16 | *.pth
17 | *.npz
18 | *.npy
19 | *.pt
20 | 
21 | # Environments
22 | .env
23 | .venv
24 | .vscode/
25 | env/
26 | venv/
27 | ENV/
28 | env.bak/
29 | venv.bak/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Function Vectors in Large Language Models
2 | ### [Project Website](https://functions.baulab.info) | [Arxiv Preprint](https://arxiv.org/abs/2310.15213) | [OpenReview](https://openreview.net/forum?id=AwyxtyMwaG)
3 | 
4 | This repository contains data and code for the paper: [Function Vectors in Large Language Models](https://arxiv.org/pdf/2310.15213).
5 | 
6 | 
7 | 
8 | 
9 | 
10 | ## Setup
11 | 
12 | We recommend using conda as a package manager.
13 | The environment used for this project can be found in the `fv_environment.yml` file.
14 | To install it, you can run:
15 | ```
16 | conda env create -f fv_environment.yml
17 | conda activate fv
18 | ```
19 | 
20 | ## Demo Notebook
21 | Check out `notebooks/fv_demo.ipynb` for a Jupyter notebook with a demo of how to create a function vector and use it in different contexts.
22 | 
23 | ## Data
24 | The datasets used in our project can be found in the `dataset_files` folder.
25 | 
26 | ## Code
27 | Our main evaluation scripts are contained in the `src` directory, with sample script wrappers in `src/eval_scripts`.
28 | 
29 | Other main code is split into various util files:
30 | - `eval_utils.py` contains code for evaluating function vectors in a variety of contexts.
31 | - `extract_utils.py` contains functions for extracting function vectors and other relevant model activations.
32 | - `intervention_utils.py` contains the main functionality for intervening with function vectors during inference.
33 | - `model_utils.py` contains helpful functions for loading models & tokenizers from Hugging Face.
34 | - `prompt_utils.py` contains data loading and prompt creation functionality.
35 | 
36 | ## Citing our work
37 | This work appeared at ICLR 2024. The paper can be cited as follows:
38 | 
39 | ```bibtex
40 | @inproceedings{todd2024function,
41 |     title={Function Vectors in Large Language Models},
42 |     author={Eric Todd and Millicent L. Li and Arnab Sen Sharma and Aaron Mueller and Byron C. Wallace and David Bau},
43 |     booktitle={Proceedings of the 2024 International Conference on Learning Representations},
44 |     url={https://openreview.net/forum?id=AwyxtyMwaG},
45 |     note={arXiv:2310.15213},
46 |     year={2024},
47 | }
48 | ```
--------------------------------------------------------------------------------
/dataset_files/README.md:
--------------------------------------------------------------------------------
1 | # Datasets
2 | 
3 | This directory contains two main directories of task datasets, all in `.json` format:
4 | * (1) The `abstractive` directory contains tasks that require information not present in the prompt to answer.
5 | * (2) The `extractive` directory contains tasks where the answer is present somewhere in the prompt, and the task of the model
6 | is to retrieve it.
7 | 
8 | The `generate` directory contains scripts we used to filter pre-existing datasets, as well as a notebook we used to create new datasets and to clean and filter additional pre-existing ones.
--------------------------------------------------------------------------------
/dataset_files/abstractive/country-capital.json:
--------------------------------------------------------------------------------
1 | [ 2 | { 3 | "output": "Kabul", 4 | "input": "Afghanistan" 5 | }, 6 | { 7 | "output": "Tirana", 8 | "input": "Albania" 9 | }, 10 | { 11 | "output": "Algiers", 12 | "input": "Algeria" 13 | }, 14 | { 15 | "output": "Andorra la Vella", 16 | "input": "Andorra" 17 | }, 18 | { 19 | "output": "Luanda", 20 | "input": "Angola" 21 | }, 22 | { 23 | "output": "St.
John's", 24 | "input": "Antigua and Barbuda" 25 | }, 26 | { 27 | "output": "Buenos Aires", 28 | "input": "Argentina" 29 | }, 30 | { 31 | "output": "Yerevan", 32 | "input": "Armenia" 33 | }, 34 | { 35 | "output": "Canberra", 36 | "input": "Australia" 37 | }, 38 | { 39 | "output": "Vienna", 40 | "input": "Austria" 41 | }, 42 | { 43 | "output": "Baku", 44 | "input": "Azerbaijan" 45 | }, 46 | { 47 | "output": "Nassau", 48 | "input": "Bahamas" 49 | }, 50 | { 51 | "output": "Manama", 52 | "input": "Bahrain" 53 | }, 54 | { 55 | "output": "Dhaka", 56 | "input": "Bangladesh" 57 | }, 58 | { 59 | "output": "Bridgetown", 60 | "input": "Barbados" 61 | }, 62 | { 63 | "output": "Minsk", 64 | "input": "Belarus" 65 | }, 66 | { 67 | "output": "Brussels", 68 | "input": "Belgium" 69 | }, 70 | { 71 | "output": "Belmopan", 72 | "input": "Belize" 73 | }, 74 | { 75 | "output": "Porto-Novo", 76 | "input": "Benin" 77 | }, 78 | { 79 | "output": "Thimphu", 80 | "input": "Bhutan" 81 | }, 82 | { 83 | "output": "La Paz", 84 | "input": "Bolivia" 85 | }, 86 | { 87 | "output": "Sarajevo", 88 | "input": "Bosnia and Herzegovina" 89 | }, 90 | { 91 | "output": "Gaborone", 92 | "input": "Botswana" 93 | }, 94 | { 95 | "output": "Brasilia", 96 | "input": "Brazil" 97 | }, 98 | { 99 | "output": "Bandar Seri Begawan", 100 | "input": "Brunei" 101 | }, 102 | { 103 | "output": "Sofia", 104 | "input": "Bulgaria" 105 | }, 106 | { 107 | "output": "Ouagadougou", 108 | "input": "Burkina Faso" 109 | }, 110 | { 111 | "output": "Bujumbura", 112 | "input": "Burundi" 113 | }, 114 | { 115 | "output": "Praia", 116 | "input": "Cabo Verde" 117 | }, 118 | { 119 | "output": "Phnom Penh", 120 | "input": "Cambodia" 121 | }, 122 | { 123 | "output": "Yaounde", 124 | "input": "Cameroon" 125 | }, 126 | { 127 | "output": "Ottawa", 128 | "input": "Canada" 129 | }, 130 | { 131 | "output": "Bangui", 132 | "input": "Central African Republic" 133 | }, 134 | { 135 | "output": "N'Djamena", 136 | "input": "Chad" 137 | }, 138 | { 139 | 
"output": "Santiago", 140 | "input": "Chile" 141 | }, 142 | { 143 | "output": "Beijing", 144 | "input": "China" 145 | }, 146 | { 147 | "output": "Bogotá", 148 | "input": "Colombia" 149 | }, 150 | { 151 | "output": "Moroni", 152 | "input": "Comoros" 153 | }, 154 | { 155 | "output": "Kinshasa", 156 | "input": "Congo" 157 | }, 158 | { 159 | "output": "San José", 160 | "input": "Costa Rica" 161 | }, 162 | { 163 | "output": "Yamoussoukro", 164 | "input": "Cote d'Ivoire" 165 | }, 166 | { 167 | "output": "Zagreb", 168 | "input": "Croatia" 169 | }, 170 | { 171 | "output": "Havana", 172 | "input": "Cuba" 173 | }, 174 | { 175 | "output": "Nicosia", 176 | "input": "Cyprus" 177 | }, 178 | { 179 | "output": "Prague", 180 | "input": "Czech Republic" 181 | }, 182 | { 183 | "output": "Kinshasa", 184 | "input": "Democratic Republic of the Congo" 185 | }, 186 | { 187 | "output": "Copenhagen", 188 | "input": "Denmark" 189 | }, 190 | { 191 | "output": "Djibouti City", 192 | "input": "Djibouti" 193 | }, 194 | { 195 | "output": "Roseau", 196 | "input": "Dominica" 197 | }, 198 | { 199 | "output": "Santo Domingo", 200 | "input": "Dominican Republic" 201 | }, 202 | { 203 | "output": "Quito", 204 | "input": "Ecuador" 205 | }, 206 | { 207 | "output": "Cairo", 208 | "input": "Egypt" 209 | }, 210 | { 211 | "output": "San Salvador", 212 | "input": "El Salvador" 213 | }, 214 | { 215 | "output": "Malabo", 216 | "input": "Equatorial Guinea" 217 | }, 218 | { 219 | "output": "Asmara", 220 | "input": "Eritrea" 221 | }, 222 | { 223 | "output": "Tallinn", 224 | "input": "Estonia" 225 | }, 226 | { 227 | "output": "Mbabane", 228 | "input": "Eswatini" 229 | }, 230 | { 231 | "output": "Addis Ababa", 232 | "input": "Ethiopia" 233 | }, 234 | { 235 | "output": "Suva", 236 | "input": "Fiji" 237 | }, 238 | { 239 | "output": "Helsinki", 240 | "input": "Finland" 241 | }, 242 | { 243 | "output": "Paris", 244 | "input": "France" 245 | }, 246 | { 247 | "output": "Libreville", 248 | "input": "Gabon" 249 | }, 250 | { 
251 | "output": "Banjul", 252 | "input": "Gambia" 253 | }, 254 | { 255 | "output": "Tbilisi", 256 | "input": "Georgia" 257 | }, 258 | { 259 | "output": "Berlin", 260 | "input": "Germany" 261 | }, 262 | { 263 | "output": "Accra", 264 | "input": "Ghana" 265 | }, 266 | { 267 | "output": "Athens", 268 | "input": "Greece" 269 | }, 270 | { 271 | "output": "St. George's", 272 | "input": "Grenada" 273 | }, 274 | { 275 | "output": "Guatemala City", 276 | "input": "Guatemala" 277 | }, 278 | { 279 | "output": "Conakry", 280 | "input": "Guinea" 281 | }, 282 | { 283 | "output": "Bissau", 284 | "input": "Guinea-Bissau" 285 | }, 286 | { 287 | "output": "Georgetown", 288 | "input": "Guyana" 289 | }, 290 | { 291 | "output": "Port-au-Prince", 292 | "input": "Haiti" 293 | }, 294 | { 295 | "output": "Tegucigalpa", 296 | "input": "Honduras" 297 | }, 298 | { 299 | "output": "Budapest", 300 | "input": "Hungary" 301 | }, 302 | { 303 | "output": "Reykjavik", 304 | "input": "Iceland" 305 | }, 306 | { 307 | "output": "New Delhi", 308 | "input": "India" 309 | }, 310 | { 311 | "output": "Jakarta", 312 | "input": "Indonesia" 313 | }, 314 | { 315 | "output": "Tehran", 316 | "input": "Iran" 317 | }, 318 | { 319 | "output": "Baghdad", 320 | "input": "Iraq" 321 | }, 322 | { 323 | "output": "Dublin", 324 | "input": "Ireland" 325 | }, 326 | { 327 | "output": "Jerusalem", 328 | "input": "Israel" 329 | }, 330 | { 331 | "output": "Rome", 332 | "input": "Italy" 333 | }, 334 | { 335 | "output": "Kingston", 336 | "input": "Jamaica" 337 | }, 338 | { 339 | "output": "Tokyo", 340 | "input": "Japan" 341 | }, 342 | { 343 | "output": "Amman", 344 | "input": "Jordan" 345 | }, 346 | { 347 | "output": "Astana", 348 | "input": "Kazakhstan" 349 | }, 350 | { 351 | "output": "Nairobi", 352 | "input": "Kenya" 353 | }, 354 | { 355 | "output": "South Tarawa", 356 | "input": "Kiribati" 357 | }, 358 | { 359 | "output": "Pristina", 360 | "input": "Kosovo" 361 | }, 362 | { 363 | "output": "Kuwait City", 364 | "input": 
"Kuwait" 365 | }, 366 | { 367 | "output": "Bishkek", 368 | "input": "Kyrgyzstan" 369 | }, 370 | { 371 | "output": "Vientiane", 372 | "input": "Laos" 373 | }, 374 | { 375 | "output": "Riga", 376 | "input": "Latvia" 377 | }, 378 | { 379 | "output": "Beirut", 380 | "input": "Lebanon" 381 | }, 382 | { 383 | "output": "Maseru", 384 | "input": "Lesotho" 385 | }, 386 | { 387 | "output": "Monrovia", 388 | "input": "Liberia" 389 | }, 390 | { 391 | "output": "Tripoli", 392 | "input": "Libya" 393 | }, 394 | { 395 | "output": "Vaduz", 396 | "input": "Liechtenstein" 397 | }, 398 | { 399 | "output": "Vilnius", 400 | "input": "Lithuania" 401 | }, 402 | { 403 | "output": "Luxembourg City", 404 | "input": "Luxembourg" 405 | }, 406 | { 407 | "output": "Antananarivo", 408 | "input": "Madagascar" 409 | }, 410 | { 411 | "output": "Lilongwe", 412 | "input": "Malawi" 413 | }, 414 | { 415 | "output": "Kuala Lumpur", 416 | "input": "Malaysia" 417 | }, 418 | { 419 | "output": "Malé", 420 | "input": "Maldives" 421 | }, 422 | { 423 | "output": "Bamako", 424 | "input": "Mali" 425 | }, 426 | { 427 | "output": "Valletta", 428 | "input": "Malta" 429 | }, 430 | { 431 | "output": "Majuro", 432 | "input": "Marshall Islands" 433 | }, 434 | { 435 | "output": "Nouakchott", 436 | "input": "Mauritania" 437 | }, 438 | { 439 | "output": "Port Louis", 440 | "input": "Mauritius" 441 | }, 442 | { 443 | "output": "Mexico City", 444 | "input": "Mexico" 445 | }, 446 | { 447 | "output": "Palikir", 448 | "input": "Micronesia" 449 | }, 450 | { 451 | "output": "Chisinau", 452 | "input": "Moldova" 453 | }, 454 | { 455 | "output": "Monaco-Ville", 456 | "input": "Monaco" 457 | }, 458 | { 459 | "output": "Ulaanbaatar", 460 | "input": "Mongolia" 461 | }, 462 | { 463 | "output": "Podgorica", 464 | "input": "Montenegro" 465 | }, 466 | { 467 | "output": "Rabat", 468 | "input": "Morocco" 469 | }, 470 | { 471 | "output": "Maputo", 472 | "input": "Mozambique" 473 | }, 474 | { 475 | "output": "Naypyidaw", 476 | "input": 
"Myanmar" 477 | }, 478 | { 479 | "output": "Windhoek", 480 | "input": "Namibia" 481 | }, 482 | { 483 | "output": "Yaren District", 484 | "input": "Nauru" 485 | }, 486 | { 487 | "output": "Kathmandu", 488 | "input": "Nepal" 489 | }, 490 | { 491 | "output": "Amsterdam", 492 | "input": "Netherlands" 493 | }, 494 | { 495 | "output": "Wellington", 496 | "input": "New Zealand" 497 | }, 498 | { 499 | "output": "Managua", 500 | "input": "Nicaragua" 501 | }, 502 | { 503 | "output": "Niamey", 504 | "input": "Niger" 505 | }, 506 | { 507 | "output": "Abuja", 508 | "input": "Nigeria" 509 | }, 510 | { 511 | "output": "Pyongyang", 512 | "input": "North Korea" 513 | }, 514 | { 515 | "output": "Skopje", 516 | "input": "North Macedonia" 517 | }, 518 | { 519 | "output": "Oslo", 520 | "input": "Norway" 521 | }, 522 | { 523 | "output": "Muscat", 524 | "input": "Oman" 525 | }, 526 | { 527 | "output": "Islamabad", 528 | "input": "Pakistan" 529 | }, 530 | { 531 | "output": "Ngerulmud", 532 | "input": "Palau" 533 | }, 534 | { 535 | "output": "Ramallah", 536 | "input": "Palestine" 537 | }, 538 | { 539 | "output": "Panama City", 540 | "input": "Panama" 541 | }, 542 | { 543 | "output": "Port Moresby", 544 | "input": "Papua New Guinea" 545 | }, 546 | { 547 | "output": "Asunción", 548 | "input": "Paraguay" 549 | }, 550 | { 551 | "output": "Lima", 552 | "input": "Peru" 553 | }, 554 | { 555 | "output": "Manila", 556 | "input": "Philippines" 557 | }, 558 | { 559 | "output": "Warsaw", 560 | "input": "Poland" 561 | }, 562 | { 563 | "output": "Lisbon", 564 | "input": "Portugal" 565 | }, 566 | { 567 | "output": "Doha", 568 | "input": "Qatar" 569 | }, 570 | { 571 | "output": "Bucharest", 572 | "input": "Romania" 573 | }, 574 | { 575 | "output": "Moscow", 576 | "input": "Russia" 577 | }, 578 | { 579 | "output": "Kigali", 580 | "input": "Rwanda" 581 | }, 582 | { 583 | "output": "Basseterre", 584 | "input": "Saint Kitts and Nevis" 585 | }, 586 | { 587 | "output": "Castries", 588 | "input": "Saint Lucia" 
589 | }, 590 | { 591 | "output": "Kingstown", 592 | "input": "Saint Vincent and the Grenadines" 593 | }, 594 | { 595 | "output": "Apia", 596 | "input": "Samoa" 597 | }, 598 | { 599 | "output": "San Marino", 600 | "input": "San Marino" 601 | }, 602 | { 603 | "output": "Sao Tome", 604 | "input": "Sao Tome and Principe" 605 | }, 606 | { 607 | "output": "Riyadh", 608 | "input": "Saudi Arabia" 609 | }, 610 | { 611 | "output": "Dakar", 612 | "input": "Senegal" 613 | }, 614 | { 615 | "output": "Belgrade", 616 | "input": "Serbia" 617 | }, 618 | { 619 | "output": "Victoria", 620 | "input": "Seychelles" 621 | }, 622 | { 623 | "output": "Freetown", 624 | "input": "Sierra Leone" 625 | }, 626 | { 627 | "output": "Singapore", 628 | "input": "Singapore" 629 | }, 630 | { 631 | "output": "Bratislava", 632 | "input": "Slovakia" 633 | }, 634 | { 635 | "output": "Ljubljana", 636 | "input": "Slovenia" 637 | }, 638 | { 639 | "output": "Honiara", 640 | "input": "Solomon Islands" 641 | }, 642 | { 643 | "output": "Mogadishu", 644 | "input": "Somalia" 645 | }, 646 | { 647 | "output": "Pretoria", 648 | "input": "South Africa" 649 | }, 650 | { 651 | "output": "Seoul", 652 | "input": "South Korea" 653 | }, 654 | { 655 | "output": "Juba", 656 | "input": "South Sudan" 657 | }, 658 | { 659 | "output": "Madrid", 660 | "input": "Spain" 661 | }, 662 | { 663 | "output": "Colombo", 664 | "input": "Sri Lanka" 665 | }, 666 | { 667 | "output": "Khartoum", 668 | "input": "Sudan" 669 | }, 670 | { 671 | "output": "Paramaribo", 672 | "input": "Suriname" 673 | }, 674 | { 675 | "output": "Stockholm", 676 | "input": "Sweden" 677 | }, 678 | { 679 | "output": "Bern", 680 | "input": "Switzerland" 681 | }, 682 | { 683 | "output": "Damascus", 684 | "input": "Syria" 685 | }, 686 | { 687 | "output": "Taipei", 688 | "input": "Taiwan" 689 | }, 690 | { 691 | "output": "Dushanbe", 692 | "input": "Tajikistan" 693 | }, 694 | { 695 | "output": "Dodoma", 696 | "input": "Tanzania" 697 | }, 698 | { 699 | "output": "Bangkok", 
700 | "input": "Thailand" 701 | }, 702 | { 703 | "output": "Dili", 704 | "input": "Timor-Leste" 705 | }, 706 | { 707 | "output": "Lome", 708 | "input": "Togo" 709 | }, 710 | { 711 | "output": "Nukuʻalofa", 712 | "input": "Tonga" 713 | }, 714 | { 715 | "output": "Port of Spain", 716 | "input": "Trinidad and Tobago" 717 | }, 718 | { 719 | "output": "Tunis", 720 | "input": "Tunisia" 721 | }, 722 | { 723 | "output": "Ankara", 724 | "input": "Turkey" 725 | }, 726 | { 727 | "output": "Ashgabat", 728 | "input": "Turkmenistan" 729 | }, 730 | { 731 | "output": "Funafuti", 732 | "input": "Tuvalu" 733 | }, 734 | { 735 | "output": "Kampala", 736 | "input": "Uganda" 737 | }, 738 | { 739 | "output": "Kiev", 740 | "input": "Ukraine" 741 | }, 742 | { 743 | "output": "Abu Dhabi", 744 | "input": "United Arab Emirates" 745 | }, 746 | { 747 | "output": "London", 748 | "input": "United Kingdom" 749 | }, 750 | { 751 | "output": "Washington, D.C.", 752 | "input": "United States of America" 753 | }, 754 | { 755 | "output": "Montevideo", 756 | "input": "Uruguay" 757 | }, 758 | { 759 | "output": "Tashkent", 760 | "input": "Uzbekistan" 761 | }, 762 | { 763 | "output": "Port Vila", 764 | "input": "Vanuatu" 765 | }, 766 | { 767 | "output": "Vatican City", 768 | "input": "Vatican City" 769 | }, 770 | { 771 | "output": "Caracas", 772 | "input": "Venezuela" 773 | }, 774 | { 775 | "output": "Hanoi", 776 | "input": "Vietnam" 777 | }, 778 | { 779 | "output": "Sana'a", 780 | "input": "Yemen" 781 | }, 782 | { 783 | "output": "Lusaka", 784 | "input": "Zambia" 785 | }, 786 | { 787 | "output": "Harare", 788 | "input": "Zimbabwe" 789 | } 790 | ] -------------------------------------------------------------------------------- /dataset_files/abstractive/country-currency.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "output": "Afghani (AFN)", 4 | "input": "Afghanistan" 5 | }, 6 | { 7 | "output": "Albanian Lek (ALL)", 8 | "input": "Albania" 9 | }, 10 | 
{ 11 | "output": "Algerian Dinar", 12 | "input": "Algeria" 13 | }, 14 | { 15 | "output": "Euro (EUR)", 16 | "input": "Andorra" 17 | }, 18 | { 19 | "output": "Kwanza (AOA)", 20 | "input": "Angola" 21 | }, 22 | { 23 | "output": "East Caribbean Dollar (XCD)", 24 | "input": "Antigua and Barbuda" 25 | }, 26 | { 27 | "output": "Argentine Peso", 28 | "input": "Argentina" 29 | }, 30 | { 31 | "output": "Dram (AMD)", 32 | "input": "Armenia" 33 | }, 34 | { 35 | "output": "Australian Dollar (AUD)", 36 | "input": "Australia" 37 | }, 38 | { 39 | "output": "Euro (EUR)", 40 | "input": "Austria" 41 | }, 42 | { 43 | "output": "Manat", 44 | "input": "Azerbaijan" 45 | }, 46 | { 47 | "output": "Bahamian Dollar", 48 | "input": "Bahamas" 49 | }, 50 | { 51 | "output": "Bahraini Dinar (BHD)", 52 | "input": "Bahrain" 53 | }, 54 | { 55 | "output": "Taka", 56 | "input": "Bangladesh" 57 | }, 58 | { 59 | "output": "Barbadian Dollar (BBD)", 60 | "input": "Barbados" 61 | }, 62 | { 63 | "output": "Belarusian Ruble (BYN)", 64 | "input": "Belarus" 65 | }, 66 | { 67 | "output": "Euro (EUR)", 68 | "input": "Belgium" 69 | }, 70 | { 71 | "output": "Belize Dollar (BZD)", 72 | "input": "Belize" 73 | }, 74 | { 75 | "output": "CFA Franc (XOF)", 76 | "input": "Benin" 77 | }, 78 | { 79 | "output": "Ngultrum (BTN)", 80 | "input": "Bhutan" 81 | }, 82 | { 83 | "output": "Bolivian Boliviano (BOB)", 84 | "input": "Bolivia" 85 | }, 86 | { 87 | "output": "Convertible Mark (KM or BAM)", 88 | "input": "Bosnia and Herzegovina" 89 | }, 90 | { 91 | "output": "Pula", 92 | "input": "Botswana" 93 | }, 94 | { 95 | "output": "Brazilian Real (BRL)", 96 | "input": "Brazil" 97 | }, 98 | { 99 | "output": "Brunei Dollar (BND)", 100 | "input": "Brunei" 101 | }, 102 | { 103 | "output": "Bulgarian Lev (BGN)", 104 | "input": "Bulgaria" 105 | }, 106 | { 107 | "output": "West African CFA franc", 108 | "input": "Burkina Faso" 109 | }, 110 | { 111 | "output": "Burundian Franc (BIF)", 112 | "input": "Burundi" 113 | }, 114 | { 115 | 
"output": "Escudo (CVE)", 116 | "input": "Cabo Verde" 117 | }, 118 | { 119 | "output": "Riel", 120 | "input": "Cambodia" 121 | }, 122 | { 123 | "output": "Central African CFA franc", 124 | "input": "Cameroon" 125 | }, 126 | { 127 | "output": "Canadian Dollar (CAD)", 128 | "input": "Canada" 129 | }, 130 | { 131 | "output": "Central African CFA franc", 132 | "input": "Central African Republic" 133 | }, 134 | { 135 | "output": "Central African CFA franc", 136 | "input": "Chad" 137 | }, 138 | { 139 | "output": "Chilean Peso", 140 | "input": "Chile" 141 | }, 142 | { 143 | "output": "Renminbi (RMB)", 144 | "input": "China" 145 | }, 146 | { 147 | "output": "Colombian Peso", 148 | "input": "Colombia" 149 | }, 150 | { 151 | "output": "Comorian Franc", 152 | "input": "Comoros" 153 | }, 154 | { 155 | "output": "Congolese Franc (CDF)", 156 | "input": "Congo" 157 | }, 158 | { 159 | "output": "Colón (CRC)", 160 | "input": "Costa Rica" 161 | }, 162 | { 163 | "output": "CFA Franc (XOF)", 164 | "input": "Cote d'Ivoire" 165 | }, 166 | { 167 | "output": "Kuna (HRK)", 168 | "input": "Croatia" 169 | }, 170 | { 171 | "output": "Cuban Peso (CUP)", 172 | "input": "Cuba" 173 | }, 174 | { 175 | "output": "Euro (EUR)", 176 | "input": "Cyprus" 177 | }, 178 | { 179 | "output": "Czech Koruna (CZK)", 180 | "input": "Czech Republic" 181 | }, 182 | { 183 | "output": "Congolese Franc (CDF)", 184 | "input": "Democratic Republic of the Congo" 185 | }, 186 | { 187 | "output": "Danish Krone", 188 | "input": "Denmark" 189 | }, 190 | { 191 | "output": "Djiboutian Franc (DJF)", 192 | "input": "Djibouti" 193 | }, 194 | { 195 | "output": "East Caribbean Dollar (XCD)", 196 | "input": "Dominica" 197 | }, 198 | { 199 | "output": "Dominican Peso", 200 | "input": "Dominican Republic" 201 | }, 202 | { 203 | "output": "US Dollar (USD)", 204 | "input": "Ecuador" 205 | }, 206 | { 207 | "output": "Egyptian Pound (EGP)", 208 | "input": "Egypt" 209 | }, 210 | { 211 | "output": "US Dollar (USD)", 212 | "input": "El 
Salvador" 213 | }, 214 | { 215 | "output": "Central African CFA franc", 216 | "input": "Equatorial Guinea" 217 | }, 218 | { 219 | "output": "Nakfa", 220 | "input": "Eritrea" 221 | }, 222 | { 223 | "output": "Euro (EUR)", 224 | "input": "Estonia" 225 | }, 226 | { 227 | "output": "Lilangeni", 228 | "input": "Eswatini" 229 | }, 230 | { 231 | "output": "Ethiopian Birr", 232 | "input": "Ethiopia" 233 | }, 234 | { 235 | "output": "Fijian Dollar (FJD)", 236 | "input": "Fiji" 237 | }, 238 | { 239 | "output": "Euro (EUR)", 240 | "input": "Finland" 241 | }, 242 | { 243 | "output": "Euro (EUR)", 244 | "input": "France" 245 | }, 246 | { 247 | "output": "Central African CFA franc", 248 | "input": "Gabon" 249 | }, 250 | { 251 | "output": "Dalasi (GMD)", 252 | "input": "Gambia" 253 | }, 254 | { 255 | "output": "Lari", 256 | "input": "Georgia" 257 | }, 258 | { 259 | "output": "Euro (EUR)", 260 | "input": "Germany" 261 | }, 262 | { 263 | "output": "Ghana Cedi (GHS)", 264 | "input": "Ghana" 265 | }, 266 | { 267 | "output": "Euro (EUR)", 268 | "input": "Greece" 269 | }, 270 | { 271 | "output": "East Caribbean Dollar (XCD)", 272 | "input": "Grenada" 273 | }, 274 | { 275 | "output": "Quetzal", 276 | "input": "Guatemala" 277 | }, 278 | { 279 | "output": "Guinean Franc", 280 | "input": "Guinea" 281 | }, 282 | { 283 | "output": "West African CFA franc (XOF)", 284 | "input": "Guinea-Bissau" 285 | }, 286 | { 287 | "output": "Guyanese Dollar", 288 | "input": "Guyana" 289 | }, 290 | { 291 | "output": "Gourde", 292 | "input": "Haiti" 293 | }, 294 | { 295 | "output": "Lempira", 296 | "input": "Honduras" 297 | }, 298 | { 299 | "output": "Forint (HUF)", 300 | "input": "Hungary" 301 | }, 302 | { 303 | "output": "Icelandic Króna (ISK)", 304 | "input": "Iceland" 305 | }, 306 | { 307 | "output": "Indian Rupee", 308 | "input": "India" 309 | }, 310 | { 311 | "output": "Indonesian Rupiah", 312 | "input": "Indonesia" 313 | }, 314 | { 315 | "output": "Iranian Rial", 316 | "input": "Iran" 317 | }, 318 | { 
319 | "output": "Iraqi Dinar (IQD)", 320 | "input": "Iraq" 321 | }, 322 | { 323 | "output": "Euro (EUR)", 324 | "input": "Ireland" 325 | }, 326 | { 327 | "output": "Israeli Shekel", 328 | "input": "Israel" 329 | }, 330 | { 331 | "output": "Euro (EUR)", 332 | "input": "Italy" 333 | }, 334 | { 335 | "output": "Jamaican Dollar (JMD)", 336 | "input": "Jamaica" 337 | }, 338 | { 339 | "output": "Japanese Yen", 340 | "input": "Japan" 341 | }, 342 | { 343 | "output": "Jordanian Dinar (JOD)", 344 | "input": "Jordan" 345 | }, 346 | { 347 | "output": "Tenge", 348 | "input": "Kazakhstan" 349 | }, 350 | { 351 | "output": "Kenyan Shilling (KES)", 352 | "input": "Kenya" 353 | }, 354 | { 355 | "output": "Australian Dollar (AUD)", 356 | "input": "Kiribati" 357 | }, 358 | { 359 | "output": "Euro (EUR)", 360 | "input": "Kosovo" 361 | }, 362 | { 363 | "output": "Kuwaiti Dinar (KWD)", 364 | "input": "Kuwait" 365 | }, 366 | { 367 | "output": "Som (KGS)", 368 | "input": "Kyrgyzstan" 369 | }, 370 | { 371 | "output": "Lao Kip (LAK)", 372 | "input": "Laos" 373 | }, 374 | { 375 | "output": "Latvian Lats (LVL)", 376 | "input": "Latvia" 377 | }, 378 | { 379 | "output": "Lebanese Pound (LBP)", 380 | "input": "Lebanon" 381 | }, 382 | { 383 | "output": "Loti (LSL)", 384 | "input": "Lesotho" 385 | }, 386 | { 387 | "output": "Liberian Dollar", 388 | "input": "Liberia" 389 | }, 390 | { 391 | "output": "Libyan Dinar (LYD)", 392 | "input": "Libya" 393 | }, 394 | { 395 | "output": "Swiss Franc (CHF)", 396 | "input": "Liechtenstein" 397 | }, 398 | { 399 | "output": "Lithuanian Litas (LTL)", 400 | "input": "Lithuania" 401 | }, 402 | { 403 | "output": "Euro (EUR)", 404 | "input": "Luxembourg" 405 | }, 406 | { 407 | "output": "Ariary", 408 | "input": "Madagascar" 409 | }, 410 | { 411 | "output": "Malawian Kwacha (MWK)", 412 | "input": "Malawi" 413 | }, 414 | { 415 | "output": "Malaysian Ringgit (MYR)", 416 | "input": "Malaysia" 417 | }, 418 | { 419 | "output": "Maldivian Rufiyaa (MVR)", 420 | "input": 
"Maldives" 421 | }, 422 | { 423 | "output": "CFA Franc (XOF)", 424 | "input": "Mali" 425 | }, 426 | { 427 | "output": "Euro (EUR)", 428 | "input": "Malta" 429 | }, 430 | { 431 | "output": "US Dollar (USD)", 432 | "input": "Marshall Islands" 433 | }, 434 | { 435 | "output": "Ouguiya (MRO)", 436 | "input": "Mauritania" 437 | }, 438 | { 439 | "output": "Mauritian Rupee", 440 | "input": "Mauritius" 441 | }, 442 | { 443 | "output": "Mexican Peso", 444 | "input": "Mexico" 445 | }, 446 | { 447 | "output": "US Dollar (USD)", 448 | "input": "Micronesia" 449 | }, 450 | { 451 | "output": "Moldovan Leu (MDL)", 452 | "input": "Moldova" 453 | }, 454 | { 455 | "output": "Euro (EUR)", 456 | "input": "Monaco" 457 | }, 458 | { 459 | "output": "Tugrik (MNT)", 460 | "input": "Mongolia" 461 | }, 462 | { 463 | "output": "Euro (EUR)", 464 | "input": "Montenegro" 465 | }, 466 | { 467 | "output": "Moroccan Dirham (MAD)", 468 | "input": "Morocco" 469 | }, 470 | { 471 | "output": "Metical (MZN)", 472 | "input": "Mozambique" 473 | }, 474 | { 475 | "output": "Kyat", 476 | "input": "Myanmar" 477 | }, 478 | { 479 | "output": "Namibian Dollar (NAD)", 480 | "input": "Namibia" 481 | }, 482 | { 483 | "output": "Australian Dollar (AUD)", 484 | "input": "Nauru" 485 | }, 486 | { 487 | "output": "Nepalese Rupee", 488 | "input": "Nepal" 489 | }, 490 | { 491 | "output": "Euro (EUR)", 492 | "input": "Netherlands" 493 | }, 494 | { 495 | "output": "New Zealand Dollar (NZD)", 496 | "input": "New Zealand" 497 | }, 498 | { 499 | "output": "Córdoba (NIO)", 500 | "input": "Nicaragua" 501 | }, 502 | { 503 | "output": "Naira", 504 | "input": "Niger" 505 | }, 506 | { 507 | "output": "Naira", 508 | "input": "Nigeria" 509 | }, 510 | { 511 | "output": "North Korean Won (KPW)", 512 | "input": "North Korea" 513 | }, 514 | { 515 | "output": "Macedonian Denar (MKD)", 516 | "input": "North Macedonia" 517 | }, 518 | { 519 | "output": "Norwegian Krone (NOK)", 520 | "input": "Norway" 521 | }, 522 | { 523 | "output": "Omani 
Rial", 524 | "input": "Oman" 525 | }, 526 | { 527 | "output": "Pakistani Rupee", 528 | "input": "Pakistan" 529 | }, 530 | { 531 | "output": "US Dollar (USD)", 532 | "input": "Palau" 533 | }, 534 | { 535 | "output": "Israeli New Shekel (ILS)", 536 | "input": "Palestine" 537 | }, 538 | { 539 | "output": "Balboa (PAB)", 540 | "input": "Panama" 541 | }, 542 | { 543 | "output": "Kina (PGK)", 544 | "input": "Papua New Guinea" 545 | }, 546 | { 547 | "output": "Guarani (PYG)", 548 | "input": "Paraguay" 549 | }, 550 | { 551 | "output": "Sol (PEN)", 552 | "input": "Peru" 553 | }, 554 | { 555 | "output": "Philippine Peso", 556 | "input": "Philippines" 557 | }, 558 | { 559 | "output": "Polish Zloty (PLN)", 560 | "input": "Poland" 561 | }, 562 | { 563 | "output": "Euro (EUR)", 564 | "input": "Portugal" 565 | }, 566 | { 567 | "output": "Qatari Riyal", 568 | "input": "Qatar" 569 | }, 570 | { 571 | "output": "Romanian Leu (RON)", 572 | "input": "Romania" 573 | }, 574 | { 575 | "output": "Russian Ruble", 576 | "input": "Russia" 577 | }, 578 | { 579 | "output": "Rwandan Franc (RWF)", 580 | "input": "Rwanda" 581 | }, 582 | { 583 | "output": "East Caribbean Dollar (XCD)", 584 | "input": "Saint Kitts and Nevis" 585 | }, 586 | { 587 | "output": "East Caribbean Dollar (XCD)", 588 | "input": "Saint Lucia" 589 | }, 590 | { 591 | "output": "East Caribbean Dollar (XCD)", 592 | "input": "Saint Vincent and the Grenadines" 593 | }, 594 | { 595 | "output": "Tala", 596 | "input": "Samoa" 597 | }, 598 | { 599 | "output": "Euro (EUR)", 600 | "input": "San Marino" 601 | }, 602 | { 603 | "output": "Dobra (STD)", 604 | "input": "Sao Tome and Principe" 605 | }, 606 | { 607 | "output": "Saudi Riyal", 608 | "input": "Saudi Arabia" 609 | }, 610 | { 611 | "output": "West African CFA franc", 612 | "input": "Senegal" 613 | }, 614 | { 615 | "output": "Serbian Dinar (RSD)", 616 | "input": "Serbia" 617 | }, 618 | { 619 | "output": "Seychellois Rupee (SCR)", 620 | "input": "Seychelles" 621 | }, 622 | { 623 | 
"output": "Leone (SLL)", 624 | "input": "Sierra Leone" 625 | }, 626 | { 627 | "output": "Singapore Dollar (SGD)", 628 | "input": "Singapore" 629 | }, 630 | { 631 | "output": "Euro (EUR)", 632 | "input": "Slovakia" 633 | }, 634 | { 635 | "output": "Euro (EUR)", 636 | "input": "Slovenia" 637 | }, 638 | { 639 | "output": "Solomon Islands Dollar (SBD)", 640 | "input": "Solomon Islands" 641 | }, 642 | { 643 | "output": "Somali Shilling (SOS)", 644 | "input": "Somalia" 645 | }, 646 | { 647 | "output": "South African Rand (ZAR)", 648 | "input": "South Africa" 649 | }, 650 | { 651 | "output": "South Korean Won (KRW)", 652 | "input": "South Korea" 653 | }, 654 | { 655 | "output": "South Sudanese Pound (SSP)", 656 | "input": "South Sudan" 657 | }, 658 | { 659 | "output": "Euro (EUR)", 660 | "input": "Spain" 661 | }, 662 | { 663 | "output": "Sri Lankan Rupee", 664 | "input": "Sri Lanka" 665 | }, 666 | { 667 | "output": "Sudanese Pound (SDG)", 668 | "input": "Sudan" 669 | }, 670 | { 671 | "output": "Surinamese Dollar (SRD)", 672 | "input": "Suriname" 673 | }, 674 | { 675 | "output": "Swedish Krona (SEK)", 676 | "input": "Sweden" 677 | }, 678 | { 679 | "output": "Swiss Franc (CHF)", 680 | "input": "Switzerland" 681 | }, 682 | { 683 | "output": "Syrian Pound", 684 | "input": "Syria" 685 | }, 686 | { 687 | "output": "New Taiwan Dollar (TWD)", 688 | "input": "Taiwan" 689 | }, 690 | { 691 | "output": "Tajikistani Somoni", 692 | "input": "Tajikistan" 693 | }, 694 | { 695 | "output": "Tanzanian Shilling", 696 | "input": "Tanzania" 697 | }, 698 | { 699 | "output": "Thai Baht", 700 | "input": "Thailand" 701 | }, 702 | { 703 | "output": "US Dollar (USD)", 704 | "input": "Timor-Leste" 705 | }, 706 | { 707 | "output": "CFA Franc", 708 | "input": "Togo" 709 | }, 710 | { 711 | "output": "Pa'anga (TOP)", 712 | "input": "Tonga" 713 | }, 714 | { 715 | "output": "Trinidad and Tobago Dollar (TTD)", 716 | "input": "Trinidad and Tobago" 717 | }, 718 | { 719 | "output": "Tunisian Dinar (TND)", 720 
| "input": "Tunisia" 721 | }, 722 | { 723 | "output": "Turkish Lira (TRY)", 724 | "input": "Turkey" 725 | }, 726 | { 727 | "output": "Turkmenistani Manat", 728 | "input": "Turkmenistan" 729 | }, 730 | { 731 | "output": "Tuvaluan Dollar (TVD)", 732 | "input": "Tuvalu" 733 | }, 734 | { 735 | "output": "Ugandan Shilling", 736 | "input": "Uganda" 737 | }, 738 | { 739 | "output": "Ukrainian Hryvnia (UAH)", 740 | "input": "Ukraine" 741 | }, 742 | { 743 | "output": "UAE Dirham", 744 | "input": "United Arab Emirates" 745 | }, 746 | { 747 | "output": "British Pound (GBP)", 748 | "input": "United Kingdom" 749 | }, 750 | { 751 | "output": "US Dollar (USD)", 752 | "input": "United States of America" 753 | }, 754 | { 755 | "output": "Uruguayan Peso", 756 | "input": "Uruguay" 757 | }, 758 | { 759 | "output": "Uzbekistani Som (UZS)", 760 | "input": "Uzbekistan" 761 | }, 762 | { 763 | "output": "Vatu (VUV)", 764 | "input": "Vanuatu" 765 | }, 766 | { 767 | "output": "Euro (EUR)", 768 | "input": "Vatican City" 769 | }, 770 | { 771 | "output": "Bolívar Soberano (VES)", 772 | "input": "Venezuela" 773 | }, 774 | { 775 | "output": "Vietnamese Dong (VND)", 776 | "input": "Vietnam" 777 | }, 778 | { 779 | "output": "Yemeni Rial", 780 | "input": "Yemen" 781 | }, 782 | { 783 | "output": "Kwacha", 784 | "input": "Zambia" 785 | }, 786 | { 787 | "output": "Zimbabwean Dollar", 788 | "input": "Zimbabwe" 789 | } 790 | ] -------------------------------------------------------------------------------- /dataset_files/abstractive/next_item.json: -------------------------------------------------------------------------------- 1 | [{"input": "zero", "output": "one"}, {"input": "one", "output": "two"}, {"input": "two", "output": "three"}, {"input": "three", "output": "four"}, {"input": "four", "output": "five"}, {"input": "five", "output": "six"}, {"input": "six", "output": "seven"}, {"input": "seven", "output": "eight"}, {"input": "eight", "output": "nine"}, {"input": "nine", "output": "ten"}, {"input": 
"ten", "output": "eleven"}, {"input": "eleven", "output": "twelve"}, {"input": "twelve", "output": "thirteen"}, {"input": "thirteen", "output": "fourteen"}, {"input": "fourteen", "output": "fifteen"}, {"input": "fifteen", "output": "sixteen"}, {"input": "sixteen", "output": "seventeen"}, {"input": "seventeen", "output": "eighteen"}, {"input": "eighteen", "output": "nineteen"}, {"input": "nineteen", "output": "twenty"}, {"input": "0", "output": "1"}, {"input": "1", "output": "2"}, {"input": "2", "output": "3"}, {"input": "3", "output": "4"}, {"input": "4", "output": "5"}, {"input": "5", "output": "6"}, {"input": "6", "output": "7"}, {"input": "7", "output": "8"}, {"input": "8", "output": "9"}, {"input": "9", "output": "10"}, {"input": "10", "output": "11"}, {"input": "11", "output": "12"}, {"input": "12", "output": "13"}, {"input": "13", "output": "14"}, {"input": "14", "output": "15"}, {"input": "15", "output": "16"}, {"input": "16", "output": "17"}, {"input": "17", "output": "18"}, {"input": "18", "output": "19"}, {"input": "19", "output": "20"}, {"input": "20", "output": "21"}, {"input": "21", "output": "22"}, {"input": "22", "output": "23"}, {"input": "23", "output": "24"}, {"input": "24", "output": "25"}, {"input": "25", "output": "26"}, {"input": "26", "output": "27"}, {"input": "27", "output": "28"}, {"input": "28", "output": "29"}, {"input": "a", "output": "b"}, {"input": "b", "output": "c"}, {"input": "c", "output": "d"}, {"input": "d", "output": "e"}, {"input": "e", "output": "f"}, {"input": "f", "output": "g"}, {"input": "g", "output": "h"}, {"input": "h", "output": "i"}, {"input": "i", "output": "j"}, {"input": "j", "output": "k"}, {"input": "k", "output": "l"}, {"input": "l", "output": "m"}, {"input": "m", "output": "n"}, {"input": "n", "output": "o"}, {"input": "o", "output": "p"}, {"input": "p", "output": "q"}, {"input": "q", "output": "r"}, {"input": "r", "output": "s"}, {"input": "s", "output": "t"}, {"input": "t", "output": "u"}, {"input": "u", 
"output": "v"}, {"input": "v", "output": "w"}, {"input": "w", "output": "x"}, {"input": "x", "output": "y"}, {"input": "y", "output": "z"}, {"input": "A", "output": "B"}, {"input": "B", "output": "C"}, {"input": "C", "output": "D"}, {"input": "D", "output": "E"}, {"input": "E", "output": "F"}, {"input": "F", "output": "G"}, {"input": "G", "output": "H"}, {"input": "H", "output": "I"}, {"input": "I", "output": "J"}, {"input": "J", "output": "K"}, {"input": "K", "output": "L"}, {"input": "L", "output": "M"}, {"input": "M", "output": "N"}, {"input": "N", "output": "O"}, {"input": "O", "output": "P"}, {"input": "P", "output": "Q"}, {"input": "Q", "output": "R"}, {"input": "R", "output": "S"}, {"input": "S", "output": "T"}, {"input": "T", "output": "U"}, {"input": "U", "output": "V"}, {"input": "V", "output": "W"}, {"input": "W", "output": "X"}, {"input": "X", "output": "Y"}, {"input": "Y", "output": "Z"}, {"input": "AA", "output": "BB"}, {"input": "BB", "output": "CC"}, {"input": "CC", "output": "DD"}, {"input": "DD", "output": "EE"}, {"input": "EE", "output": "FF"}, {"input": "FF", "output": "GG"}, {"input": "GG", "output": "HH"}, {"input": "HH", "output": "II"}, {"input": "II", "output": "JJ"}, {"input": "JJ", "output": "KK"}, {"input": "KK", "output": "LL"}, {"input": "LL", "output": "MM"}, {"input": "MM", "output": "NN"}, {"input": "NN", "output": "OO"}, {"input": "OO", "output": "PP"}, {"input": "PP", "output": "QQ"}, {"input": "QQ", "output": "RR"}, {"input": "RR", "output": "SS"}, {"input": "SS", "output": "TT"}, {"input": "TT", "output": "UU"}, {"input": "UU", "output": "VV"}, {"input": "VV", "output": "WW"}, {"input": "WW", "output": "XX"}, {"input": "XX", "output": "YY"}, {"input": "YY", "output": "ZZ"}, {"input": "aa", "output": "bb"}, {"input": "bb", "output": "cc"}, {"input": "cc", "output": "dd"}, {"input": "dd", "output": "ee"}, {"input": "ee", "output": "ff"}, {"input": "ff", "output": "gg"}, {"input": "gg", "output": "hh"}, {"input": "hh", "output": 
"ii"}, {"input": "ii", "output": "jj"}, {"input": "jj", "output": "kk"}, {"input": "kk", "output": "ll"}, {"input": "ll", "output": "mm"}, {"input": "mm", "output": "nn"}, {"input": "nn", "output": "oo"}, {"input": "oo", "output": "pp"}, {"input": "pp", "output": "qq"}, {"input": "qq", "output": "rr"}, {"input": "rr", "output": "ss"}, {"input": "ss", "output": "tt"}, {"input": "tt", "output": "uu"}, {"input": "uu", "output": "vv"}, {"input": "vv", "output": "ww"}, {"input": "ww", "output": "xx"}, {"input": "xx", "output": "yy"}, {"input": "yy", "output": "zz"}, {"input": "I", "output": "II"}, {"input": "II", "output": "III"}, {"input": "III", "output": "IV"}, {"input": "IV", "output": "V"}, {"input": "V", "output": "VI"}, {"input": "VI", "output": "VII"}, {"input": "VII", "output": "VIII"}, {"input": "VIII", "output": "IX"}, {"input": "IX", "output": "X"}, {"input": "X", "output": "XI"}, {"input": "XI", "output": "XII"}, {"input": "XII", "output": "XIII"}, {"input": "XIII", "output": "XIV"}, {"input": "XIV", "output": "XV"}, {"input": "XV", "output": "XVI"}, {"input": "XVI", "output": "XVII"}, {"input": "XVII", "output": "XVIII"}, {"input": "XVIII", "output": "XIX"}, {"input": "XIX", "output": "XX"}, {"input": "i", "output": "ii"}, {"input": "ii", "output": "iii"}, {"input": "iii", "output": "iv"}, {"input": "iv", "output": "v"}, {"input": "v", "output": "vi"}, {"input": "vi", "output": "vii"}, {"input": "vii", "output": "viii"}, {"input": "viii", "output": "ix"}, {"input": "ix", "output": "x"}, {"input": "x", "output": "xi"}, {"input": "xi", "output": "xii"}, {"input": "xii", "output": "xiii"}, {"input": "xiii", "output": "xiv"}, {"input": "xiv", "output": "xv"}, {"input": "xv", "output": "xvi"}, {"input": "xvi", "output": "xvii"}, {"input": "xvii", "output": "xviii"}, {"input": "xviii", "output": "xix"}, {"input": "xix", "output": "xx"}, {"input": "monday", "output": "tuesday"}, {"input": "tuesday", "output": "wednesday"}, {"input": "wednesday", "output": 
"thursday"}, {"input": "thursday", "output": "friday"}, {"input": "friday", "output": "saturday"}, {"input": "saturday", "output": "sunday"}, {"input": "january", "output": "february"}, {"input": "february", "output": "march"}, {"input": "march", "output": "april"}, {"input": "april", "output": "may"}, {"input": "may", "output": "june"}, {"input": "june", "output": "july"}, {"input": "july", "output": "august"}, {"input": "august", "output": "september"}, {"input": "september", "output": "october"}, {"input": "october", "output": "november"}, {"input": "november", "output": "december"}, {"input": "Monday", "output": "Tuesday"}, {"input": "Tuesday", "output": "Wednesday"}, {"input": "Wednesday", "output": "Thursday"}, {"input": "Thursday", "output": "Friday"}, {"input": "Friday", "output": "Saturday"}, {"input": "Saturday", "output": "Sunday"}, {"input": "January", "output": "February"}, {"input": "February", "output": "March"}, {"input": "March", "output": "April"}, {"input": "April", "output": "May"}, {"input": "May", "output": "June"}, {"input": "June", "output": "July"}, {"input": "July", "output": "August"}, {"input": "August", "output": "September"}, {"input": "September", "output": "October"}, {"input": "October", "output": "November"}, {"input": "November", "output": "December"}, {"input": "sunday", "output": "monday"}, {"input": "december", "output": "january"}, {"input": "Sunday", "output": "Monday"}, {"input": "December", "output": "January"}] -------------------------------------------------------------------------------- /dataset_files/abstractive/prev_item.json: -------------------------------------------------------------------------------- 1 | [{"input": "one", "output": "zero"}, {"input": "two", "output": "one"}, {"input": "three", "output": "two"}, {"input": "four", "output": "three"}, {"input": "five", "output": "four"}, {"input": "six", "output": "five"}, {"input": "seven", "output": "six"}, {"input": "eight", "output": "seven"}, {"input": 
"nine", "output": "eight"}, {"input": "ten", "output": "nine"}, {"input": "eleven", "output": "ten"}, {"input": "twelve", "output": "eleven"}, {"input": "thirteen", "output": "twelve"}, {"input": "fourteen", "output": "thirteen"}, {"input": "fifteen", "output": "fourteen"}, {"input": "sixteen", "output": "fifteen"}, {"input": "seventeen", "output": "sixteen"}, {"input": "eighteen", "output": "seventeen"}, {"input": "nineteen", "output": "eighteen"}, {"input": "twenty", "output": "nineteen"}, {"input": "1", "output": "0"}, {"input": "2", "output": "1"}, {"input": "3", "output": "2"}, {"input": "4", "output": "3"}, {"input": "5", "output": "4"}, {"input": "6", "output": "5"}, {"input": "7", "output": "6"}, {"input": "8", "output": "7"}, {"input": "9", "output": "8"}, {"input": "10", "output": "9"}, {"input": "11", "output": "10"}, {"input": "12", "output": "11"}, {"input": "13", "output": "12"}, {"input": "14", "output": "13"}, {"input": "15", "output": "14"}, {"input": "16", "output": "15"}, {"input": "17", "output": "16"}, {"input": "18", "output": "17"}, {"input": "19", "output": "18"}, {"input": "20", "output": "19"}, {"input": "21", "output": "20"}, {"input": "22", "output": "21"}, {"input": "23", "output": "22"}, {"input": "24", "output": "23"}, {"input": "25", "output": "24"}, {"input": "26", "output": "25"}, {"input": "27", "output": "26"}, {"input": "28", "output": "27"}, {"input": "29", "output": "28"}, {"input": "b", "output": "a"}, {"input": "c", "output": "b"}, {"input": "d", "output": "c"}, {"input": "e", "output": "d"}, {"input": "f", "output": "e"}, {"input": "g", "output": "f"}, {"input": "h", "output": "g"}, {"input": "i", "output": "h"}, {"input": "j", "output": "i"}, {"input": "k", "output": "j"}, {"input": "l", "output": "k"}, {"input": "m", "output": "l"}, {"input": "n", "output": "m"}, {"input": "o", "output": "n"}, {"input": "p", "output": "o"}, {"input": "q", "output": "p"}, {"input": "r", "output": "q"}, {"input": "s", "output": "r"}, 
{"input": "t", "output": "s"}, {"input": "u", "output": "t"}, {"input": "v", "output": "u"}, {"input": "w", "output": "v"}, {"input": "x", "output": "w"}, {"input": "y", "output": "x"}, {"input": "z", "output": "y"}, {"input": "B", "output": "A"}, {"input": "C", "output": "B"}, {"input": "D", "output": "C"}, {"input": "E", "output": "D"}, {"input": "F", "output": "E"}, {"input": "G", "output": "F"}, {"input": "H", "output": "G"}, {"input": "I", "output": "H"}, {"input": "J", "output": "I"}, {"input": "K", "output": "J"}, {"input": "L", "output": "K"}, {"input": "M", "output": "L"}, {"input": "N", "output": "M"}, {"input": "O", "output": "N"}, {"input": "P", "output": "O"}, {"input": "Q", "output": "P"}, {"input": "R", "output": "Q"}, {"input": "S", "output": "R"}, {"input": "T", "output": "S"}, {"input": "U", "output": "T"}, {"input": "V", "output": "U"}, {"input": "W", "output": "V"}, {"input": "X", "output": "W"}, {"input": "Y", "output": "X"}, {"input": "Z", "output": "Y"}, {"input": "BB", "output": "AA"}, {"input": "CC", "output": "BB"}, {"input": "DD", "output": "CC"}, {"input": "EE", "output": "DD"}, {"input": "FF", "output": "EE"}, {"input": "GG", "output": "FF"}, {"input": "HH", "output": "GG"}, {"input": "II", "output": "HH"}, {"input": "JJ", "output": "II"}, {"input": "KK", "output": "JJ"}, {"input": "LL", "output": "KK"}, {"input": "MM", "output": "LL"}, {"input": "NN", "output": "MM"}, {"input": "OO", "output": "NN"}, {"input": "PP", "output": "OO"}, {"input": "QQ", "output": "PP"}, {"input": "RR", "output": "QQ"}, {"input": "SS", "output": "RR"}, {"input": "TT", "output": "SS"}, {"input": "UU", "output": "TT"}, {"input": "VV", "output": "UU"}, {"input": "WW", "output": "VV"}, {"input": "XX", "output": "WW"}, {"input": "YY", "output": "XX"}, {"input": "ZZ", "output": "YY"}, {"input": "bb", "output": "aa"}, {"input": "cc", "output": "bb"}, {"input": "dd", "output": "cc"}, {"input": "ee", "output": "dd"}, {"input": "ff", "output": "ee"}, {"input": "gg", 
"output": "ff"}, {"input": "hh", "output": "gg"}, {"input": "ii", "output": "hh"}, {"input": "jj", "output": "ii"}, {"input": "kk", "output": "jj"}, {"input": "ll", "output": "kk"}, {"input": "mm", "output": "ll"}, {"input": "nn", "output": "mm"}, {"input": "oo", "output": "nn"}, {"input": "pp", "output": "oo"}, {"input": "qq", "output": "pp"}, {"input": "rr", "output": "qq"}, {"input": "ss", "output": "rr"}, {"input": "tt", "output": "ss"}, {"input": "uu", "output": "tt"}, {"input": "vv", "output": "uu"}, {"input": "ww", "output": "vv"}, {"input": "xx", "output": "ww"}, {"input": "yy", "output": "xx"}, {"input": "zz", "output": "yy"}, {"input": "II", "output": "I"}, {"input": "III", "output": "II"}, {"input": "IV", "output": "III"}, {"input": "V", "output": "IV"}, {"input": "VI", "output": "V"}, {"input": "VII", "output": "VI"}, {"input": "VIII", "output": "VII"}, {"input": "IX", "output": "VIII"}, {"input": "X", "output": "IX"}, {"input": "XI", "output": "X"}, {"input": "XII", "output": "XI"}, {"input": "XIII", "output": "XII"}, {"input": "XIV", "output": "XIII"}, {"input": "XV", "output": "XIV"}, {"input": "XVI", "output": "XV"}, {"input": "XVII", "output": "XVI"}, {"input": "XVIII", "output": "XVII"}, {"input": "XIX", "output": "XVIII"}, {"input": "XX", "output": "XIX"}, {"input": "ii", "output": "i"}, {"input": "iii", "output": "ii"}, {"input": "iv", "output": "iii"}, {"input": "v", "output": "iv"}, {"input": "vi", "output": "v"}, {"input": "vii", "output": "vi"}, {"input": "viii", "output": "vii"}, {"input": "ix", "output": "viii"}, {"input": "x", "output": "ix"}, {"input": "xi", "output": "x"}, {"input": "xii", "output": "xi"}, {"input": "xiii", "output": "xii"}, {"input": "xiv", "output": "xiii"}, {"input": "xv", "output": "xiv"}, {"input": "xvi", "output": "xv"}, {"input": "xvii", "output": "xvi"}, {"input": "xviii", "output": "xvii"}, {"input": "xix", "output": "xviii"}, {"input": "xx", "output": "xix"}, {"input": "tuesday", "output": "monday"}, {"input": 
"wednesday", "output": "tuesday"}, {"input": "thursday", "output": "wednesday"}, {"input": "friday", "output": "thursday"}, {"input": "saturday", "output": "friday"}, {"input": "sunday", "output": "saturday"}, {"input": "february", "output": "january"}, {"input": "march", "output": "february"}, {"input": "april", "output": "march"}, {"input": "may", "output": "april"}, {"input": "june", "output": "may"}, {"input": "july", "output": "june"}, {"input": "august", "output": "july"}, {"input": "september", "output": "august"}, {"input": "october", "output": "september"}, {"input": "november", "output": "october"}, {"input": "december", "output": "november"}, {"input": "Tuesday", "output": "Monday"}, {"input": "Wednesday", "output": "Tuesday"}, {"input": "Thursday", "output": "Wednesday"}, {"input": "Friday", "output": "Thursday"}, {"input": "Saturday", "output": "Friday"}, {"input": "Sunday", "output": "Saturday"}, {"input": "February", "output": "January"}, {"input": "March", "output": "February"}, {"input": "April", "output": "March"}, {"input": "May", "output": "April"}, {"input": "June", "output": "May"}, {"input": "July", "output": "June"}, {"input": "August", "output": "July"}, {"input": "September", "output": "August"}, {"input": "October", "output": "September"}, {"input": "November", "output": "October"}, {"input": "December", "output": "November"}, {"input": "monday", "output": "sunday"}, {"input": "january", "output": "december"}, {"input": "Monday", "output": "Sunday"}, {"input": "January", "output": "December"}] -------------------------------------------------------------------------------- /dataset_files/abstractive/singular-plural.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "input": "wallet", 4 | "output": "wallets" 5 | }, 6 | { 7 | "input": "keychain", 8 | "output": "keychains" 9 | }, 10 | { 11 | "input": "mountain", 12 | "output": "mountains" 13 | }, 14 | { 15 | "input": "comb", 16 | "output": 
"combs" 17 | }, 18 | { 19 | "input": "monitor", 20 | "output": "monitors" 21 | }, 22 | { 23 | "input": "island", 24 | "output": "islands" 25 | }, 26 | { 27 | "input": "rake", 28 | "output": "rakes" 29 | }, 30 | { 31 | "input": "needle", 32 | "output": "needles" 33 | }, 34 | { 35 | "input": "lighter", 36 | "output": "lighters" 37 | }, 38 | { 39 | "input": "slipper", 40 | "output": "slippers" 41 | }, 42 | { 43 | "input": "fireplace", 44 | "output": "fireplaces" 45 | }, 46 | { 47 | "input": "ladder", 48 | "output": "ladders" 49 | }, 50 | { 51 | "input": "jacket", 52 | "output": "jackets" 53 | }, 54 | { 55 | "input": "helicopter", 56 | "output": "helicopters" 57 | }, 58 | { 59 | "input": "paintbrush", 60 | "output": "paintbrushes" 61 | }, 62 | { 63 | "input": "dustpan", 64 | "output": "dustpans" 65 | }, 66 | { 67 | "input": "wrench", 68 | "output": "wrenches" 69 | }, 70 | { 71 | "input": "tablet", 72 | "output": "tablets" 73 | }, 74 | { 75 | "input": "hoe", 76 | "output": "hoes" 77 | }, 78 | { 79 | "input": "tie", 80 | "output": "ties" 81 | }, 82 | { 83 | "input": "toy", 84 | "output": "toys" 85 | }, 86 | { 87 | "input": "glass", 88 | "output": "glasses" 89 | }, 90 | { 91 | "input": "hairdryer", 92 | "output": "hairdryers" 93 | }, 94 | { 95 | "input": "axe", 96 | "output": "axes" 97 | }, 98 | { 99 | "input": "vacuum", 100 | "output": "vacuums" 101 | }, 102 | { 103 | "input": "blush", 104 | "output": "blushes" 105 | }, 106 | { 107 | "input": "stove", 108 | "output": "stoves" 109 | }, 110 | { 111 | "input": "ladle", 112 | "output": "ladles" 113 | }, 114 | { 115 | "input": "poster", 116 | "output": "posters" 117 | }, 118 | { 119 | "input": "hat", 120 | "output": "hats" 121 | }, 122 | { 123 | "input": "lake", 124 | "output": "lakes" 125 | }, 126 | { 127 | "input": "razor", 128 | "output": "razors" 129 | }, 130 | { 131 | "input": "bottle", 132 | "output": "bottles" 133 | }, 134 | { 135 | "input": "glove", 136 | "output": "gloves" 137 | }, 138 | { 139 | "input": "grater", 
140 | "output": "graters" 141 | }, 142 | { 143 | "input": "dishwasher", 144 | "output": "dishwashers" 145 | }, 146 | { 147 | "input": "sofa", 148 | "output": "sofas" 149 | }, 150 | { 151 | "input": "bag", 152 | "output": "bags" 153 | }, 154 | { 155 | "input": "keyboard", 156 | "output": "keyboards" 157 | }, 158 | { 159 | "input": "clock", 160 | "output": "clocks" 161 | }, 162 | { 163 | "input": "book", 164 | "output": "books" 165 | }, 166 | { 167 | "input": "scarf", 168 | "output": "scarves" 169 | }, 170 | { 171 | "input": "pants", 172 | "output": "pants" 173 | }, 174 | { 175 | "input": "window", 176 | "output": "windows" 177 | }, 178 | { 179 | "input": "house", 180 | "output": "houses" 181 | }, 182 | { 183 | "input": "freezer", 184 | "output": "freezers" 185 | }, 186 | { 187 | "input": "rag", 188 | "output": "rags" 189 | }, 190 | { 191 | "input": "racquet", 192 | "output": "racquets" 193 | }, 194 | { 195 | "input": "hair gel", 196 | "output": "hair gels" 197 | }, 198 | { 199 | "input": "door", 200 | "output": "doors" 201 | }, 202 | { 203 | "input": "pillow", 204 | "output": "pillows" 205 | }, 206 | { 207 | "input": "ruler", 208 | "output": "rulers" 209 | }, 210 | { 211 | "input": "washer", 212 | "output": "washers" 213 | }, 214 | { 215 | "input": "ocean", 216 | "output": "oceans" 217 | }, 218 | { 219 | "input": "plate", 220 | "output": "plates" 221 | }, 222 | { 223 | "input": "eyeshadow", 224 | "output": "eyeshadows" 225 | }, 226 | { 227 | "input": "zipper", 228 | "output": "zippers" 229 | }, 230 | { 231 | "input": "radio", 232 | "output": "radios" 233 | }, 234 | { 235 | "input": "flower", 236 | "output": "flowers" 237 | }, 238 | { 239 | "input": "laptop", 240 | "output": "laptops" 241 | }, 242 | { 243 | "input": "eraser", 244 | "output": "erasers" 245 | }, 246 | { 247 | "input": "corkscrew", 248 | "output": "corkscrews" 249 | }, 250 | { 251 | "input": "eyeliner", 252 | "output": "eyeliners" 253 | }, 254 | { 255 | "input": "desk", 256 | "output": "desks" 257 | }, 
258 | { 259 | "input": "knife", 260 | "output": "knives" 261 | }, 262 | { 263 | "input": "helmet", 264 | "output": "helmets" 265 | }, 266 | { 267 | "input": "mixer", 268 | "output": "mixers" 269 | }, 270 | { 271 | "input": "microwave", 272 | "output": "microwaves" 273 | }, 274 | { 275 | "input": "button", 276 | "output": "buttons" 277 | }, 278 | { 279 | "input": "jar", 280 | "output": "jars" 281 | }, 282 | { 283 | "input": "pan", 284 | "output": "pans" 285 | }, 286 | { 287 | "input": "key", 288 | "output": "keys" 289 | }, 290 | { 291 | "input": "perfume", 292 | "output": "perfumes" 293 | }, 294 | { 295 | "input": "tape", 296 | "output": "tapes" 297 | }, 298 | { 299 | "input": "shoes", 300 | "output": "shoes" 301 | }, 302 | { 303 | "input": "shirt", 304 | "output": "shirts" 305 | }, 306 | { 307 | "input": "candle", 308 | "output": "candles" 309 | }, 310 | { 311 | "input": "juicer", 312 | "output": "juicers" 313 | }, 314 | { 315 | "input": "peeler", 316 | "output": "peelers" 317 | }, 318 | { 319 | "input": "mirror", 320 | "output": "mirrors" 321 | }, 322 | { 323 | "input": "mascara", 324 | "output": "mascaras" 325 | }, 326 | { 327 | "input": "whisk", 328 | "output": "whisks" 329 | }, 330 | { 331 | "input": "shovel", 332 | "output": "shovels" 333 | }, 334 | { 335 | "input": "marker", 336 | "output": "markers" 337 | }, 338 | { 339 | "input": "lotion", 340 | "output": "lotions" 341 | }, 342 | { 343 | "input": "matches", 344 | "output": "matches" 345 | }, 346 | { 347 | "input": "moon", 348 | "output": "moons" 349 | }, 350 | { 351 | "input": "pot", 352 | "output": "pots" 353 | }, 354 | { 355 | "input": "mop", 356 | "output": "mops" 357 | }, 358 | { 359 | "input": "sprinkler", 360 | "output": "sprinklers" 361 | }, 362 | { 363 | "input": "can", 364 | "output": "cans" 365 | }, 366 | { 367 | "input": "notebook", 368 | "output": "notebooks" 369 | }, 370 | { 371 | "input": "airplane", 372 | "output": "airplanes" 373 | }, 374 | { 375 | "input": "tongs", 376 | "output": "tongs" 
377 | }, 378 | { 379 | "input": "phone", 380 | "output": "phones" 381 | }, 382 | { 383 | "input": "paint", 384 | "output": "paints" 385 | }, 386 | { 387 | "input": "conditioner", 388 | "output": "conditioners" 389 | }, 390 | { 391 | "input": "purse", 392 | "output": "purses" 393 | }, 394 | { 395 | "input": "broom", 396 | "output": "brooms" 397 | }, 398 | { 399 | "input": "rug", 400 | "output": "rugs" 401 | }, 402 | { 403 | "input": "toaster", 404 | "output": "toasters" 405 | }, 406 | { 407 | "input": "kettle", 408 | "output": "kettles" 409 | }, 410 | { 411 | "input": "blender", 412 | "output": "blenders" 413 | }, 414 | { 415 | "input": "colander", 416 | "output": "colanders" 417 | }, 418 | { 419 | "input": "toothpaste", 420 | "output": "toothpastes" 421 | }, 422 | { 423 | "input": "bed", 424 | "output": "beds" 425 | }, 426 | { 427 | "input": "refrigerator", 428 | "output": "refrigerators" 429 | }, 430 | { 431 | "input": "stapler", 432 | "output": "staplers" 433 | }, 434 | { 435 | "input": "backpack", 436 | "output": "backpacks" 437 | }, 438 | { 439 | "input": "fabric", 440 | "output": "fabrics" 441 | }, 442 | { 443 | "input": "nut", 444 | "output": "nuts" 445 | }, 446 | { 447 | "input": "bowl", 448 | "output": "bowls" 449 | }, 450 | { 451 | "input": "bolt", 452 | "output": "bolts" 453 | }, 454 | { 455 | "input": "chopsticks", 456 | "output": "chopsticks" 457 | }, 458 | { 459 | "input": "computer", 460 | "output": "computers" 461 | }, 462 | { 463 | "input": "valley", 464 | "output": "valleys" 465 | }, 466 | { 467 | "input": "lantern", 468 | "output": "lanterns" 469 | }, 470 | { 471 | "input": "boot", 472 | "output": "boots" 473 | }, 474 | { 475 | "input": "bucket", 476 | "output": "buckets" 477 | }, 478 | { 479 | "input": "sandal", 480 | "output": "sandals" 481 | }, 482 | { 483 | "input": "lock", 484 | "output": "locks" 485 | }, 486 | { 487 | "input": "drill", 488 | "output": "drills" 489 | }, 490 | { 491 | "input": "toothbrush", 492 | "output": "toothbrushes" 493 | 
}, 494 | { 495 | "input": "lamp", 496 | "output": "lamps" 497 | }, 498 | { 499 | "input": "star", 500 | "output": "stars" 501 | }, 502 | { 503 | "input": "bandana", 504 | "output": "bandanas" 505 | }, 506 | { 507 | "input": "stereo", 508 | "output": "stereos" 509 | }, 510 | { 511 | "input": "teapot", 512 | "output": "teapots" 513 | }, 514 | { 515 | "input": "thread", 516 | "output": "threads" 517 | }, 518 | { 519 | "input": "lawnmower", 520 | "output": "lawnmowers" 521 | }, 522 | { 523 | "input": "mug", 524 | "output": "mugs" 525 | }, 526 | { 527 | "input": "game", 528 | "output": "games" 529 | }, 530 | { 531 | "input": "shoe", 532 | "output": "shoes" 533 | }, 534 | { 535 | "input": "mattress", 536 | "output": "mattresses" 537 | }, 538 | { 539 | "input": "sunglasses", 540 | "output": "sunglasses" 541 | }, 542 | { 543 | "input": "river", 544 | "output": "rivers" 545 | }, 546 | { 547 | "input": "beach", 548 | "output": "beaches" 549 | }, 550 | { 551 | "input": "sponge", 552 | "output": "sponges" 553 | }, 554 | { 555 | "input": "blanket", 556 | "output": "blankets" 557 | }, 558 | { 559 | "input": "headphones", 560 | "output": "headphones" 561 | }, 562 | { 563 | "input": "mouse", 564 | "output": "mice" 565 | }, 566 | { 567 | "input": "pen", 568 | "output": "pens" 569 | }, 570 | { 571 | "input": "tree", 572 | "output": "trees" 573 | }, 574 | { 575 | "input": "car", 576 | "output": "cars" 577 | }, 578 | { 579 | "input": "belt", 580 | "output": "belts" 581 | }, 582 | { 583 | "input": "sky", 584 | "output": "skies" 585 | }, 586 | { 587 | "input": "socks", 588 | "output": "socks" 589 | }, 590 | { 591 | "input": "briefcase", 592 | "output": "briefcases" 593 | }, 594 | { 595 | "input": "crayon", 596 | "output": "crayons" 597 | }, 598 | { 599 | "input": "screw", 600 | "output": "screws" 601 | }, 602 | { 603 | "input": "bicycle", 604 | "output": "bicycles" 605 | }, 606 | { 607 | "input": "grill", 608 | "output": "grills" 609 | }, 610 | { 611 | "input": "sun", 612 | "output":
"suns" 613 | }, 614 | { 615 | "input": "coffee maker", 616 | "output": "coffee makers" 617 | }, 618 | { 619 | "input": "forest", 620 | "output": "forests" 621 | }, 622 | { 623 | "input": "spatula", 624 | "output": "spatulas" 625 | }, 626 | { 627 | "input": "brush", 628 | "output": "brushes" 629 | }, 630 | { 631 | "input": "cup", 632 | "output": "cups" 633 | }, 634 | { 635 | "input": "picture", 636 | "output": "pictures" 637 | }, 638 | { 639 | "input": "jewelry", 640 | "output": "jewelries" 641 | }, 642 | { 643 | "input": "shampoo", 644 | "output": "shampoos" 645 | }, 646 | { 647 | "input": "speaker", 648 | "output": "speakers" 649 | }, 650 | { 651 | "input": "boat", 652 | "output": "boats" 653 | }, 654 | { 655 | "input": "bus", 656 | "output": "buses" 657 | }, 658 | { 659 | "input": "sock", 660 | "output": "socks" 661 | }, 662 | { 663 | "input": "pliers", 664 | "output": "pliers" 665 | }, 666 | { 667 | "input": "suitcase", 668 | "output": "suitcases" 669 | }, 670 | { 671 | "input": "saw", 672 | "output": "saws" 673 | }, 674 | { 675 | "input": "fork", 676 | "output": "forks" 677 | }, 678 | { 679 | "input": "calendar", 680 | "output": "calendars" 681 | }, 682 | { 683 | "input": "umbrella", 684 | "output": "umbrellas" 685 | }, 686 | { 687 | "input": "printer", 688 | "output": "printers" 689 | }, 690 | { 691 | "input": "hose", 692 | "output": "hoses" 693 | }, 694 | { 695 | "input": "paper", 696 | "output": "papers" 697 | }, 698 | { 699 | "input": "television", 700 | "output": "televisions" 701 | }, 702 | { 703 | "input": "newspaper", 704 | "output": "newspapers" 705 | }, 706 | { 707 | "input": "hammer", 708 | "output": "hammers" 709 | }, 710 | { 711 | "input": "chair", 712 | "output": "chairs" 713 | }, 714 | { 715 | "input": "ball", 716 | "output": "balls" 717 | }, 718 | { 719 | "input": "watch", 720 | "output": "watches" 721 | }, 722 | { 723 | "input": "screwdriver", 724 | "output": "screwdrivers" 725 | }, 726 | { 727 | "input": "cloud", 728 | "output": "clouds" 729 | 
}, 730 | { 731 | "input": "dryer", 732 | "output": "dryers" 733 | }, 734 | { 735 | "input": "train", 736 | "output": "trains" 737 | }, 738 | { 739 | "input": "lipstick", 740 | "output": "lipsticks" 741 | }, 742 | { 743 | "input": "hairbrush", 744 | "output": "hairbrushes" 745 | }, 746 | { 747 | "input": "glue", 748 | "output": "glues" 749 | }, 750 | { 751 | "input": "bat", 752 | "output": "bats" 753 | }, 754 | { 755 | "input": "roller", 756 | "output": "rollers" 757 | }, 758 | { 759 | "input": "ship", 760 | "output": "ships" 761 | }, 762 | { 763 | "input": "spoon", 764 | "output": "spoons" 765 | }, 766 | { 767 | "input": "motorcycle", 768 | "output": "motorcycles" 769 | }, 770 | { 771 | "input": "oven", 772 | "output": "ovens" 773 | }, 774 | { 775 | "input": "soap", 776 | "output": "soaps" 777 | }, 778 | { 779 | "input": "table", 780 | "output": "tables" 781 | }, 782 | { 783 | "input": "doll", 784 | "output": "dolls" 785 | }, 786 | { 787 | "input": "flashlight", 788 | "output": "flashlights" 789 | }, 790 | { 791 | "input": "scissors", 792 | "output": "scissors" 793 | }, 794 | { 795 | "input": "pencil", 796 | "output": "pencils" 797 | }, 798 | { 799 | "input": "magazine", 800 | "output": "magazines" 801 | }, 802 | { 803 | "input": "iron", 804 | "output": "irons" 805 | }, 806 | { 807 | "input": "camera", 808 | "output": "cameras" 809 | }, 810 | { 811 | "input": "trash can", 812 | "output": "trash cans" 813 | }, 814 | { 815 | "input": "nail", 816 | "output": "nails" 817 | }, 818 | { 819 | "input": "box", 820 | "output": "boxes" 821 | } 822 | ] -------------------------------------------------------------------------------- /dataset_files/generate/categories.json: -------------------------------------------------------------------------------- 1 | { 2 | "animal": [ 3 | "alpaca", 4 | "ant", 5 | "anteater", 6 | "bat", 7 | "bear", 8 | "bee", 9 | "beaver", 10 | "bird", 11 | "buffalo", 12 | "bunny", 13 | "butterfly", 14 | "camel", 15 | "cat", 16 | "caterpillar", 17 | 
"chicken", 18 | "cheetah", 19 | "cow", 20 | "coyote", 21 | "dog", 22 | "dolphin", 23 | "donkey", 24 | "dove", 25 | "duck", 26 | "eagle", 27 | "eel", 28 | "elephant", 29 | "ferret", 30 | "finch", 31 | "fish", 32 | "fox", 33 | "frog", 34 | "goat", 35 | "goose", 36 | "gorilla", 37 | "hamster", 38 | "hedgehog", 39 | "hippopotamus", 40 | "horse", 41 | "jaguar", 42 | "kangaroo", 43 | "koala", 44 | "llama", 45 | "lion", 46 | "lizard", 47 | "monkey", 48 | "moth", 49 | "mouse", 50 | "octopus", 51 | "otter", 52 | "owl", 53 | "pig", 54 | "pigeon", 55 | "rabbit", 56 | "rat", 57 | "seal", 58 | "scorpion", 59 | "shark", 60 | "sheep", 61 | "skunk", 62 | "snail", 63 | "snake", 64 | "spider", 65 | "squirrel", 66 | "swan", 67 | "tiger", 68 | "turkey", 69 | "turtle", 70 | "vulture", 71 | "weasel", 72 | "whale", 73 | "wolf", 74 | "wombat", 75 | "worm", 76 | "wolverine", 77 | "zebra", 78 | "aardvark", 79 | "alligator", 80 | "armadillo", 81 | "baboon", 82 | "badger", 83 | "barracuda", 84 | "bison", 85 | "boar", 86 | "capybara", 87 | "caribou", 88 | "chimpanzee", 89 | "chinchilla", 90 | "chipmunk", 91 | "cobra", 92 | "cockroach", 93 | "cougar", 94 | "crab", 95 | "crane", 96 | "crocodile", 97 | "crow", 98 | "deer", 99 | "dingo", 100 | "dragonfly", 101 | "elk", 102 | "emu", 103 | "falcon", 104 | "firefly", 105 | "flamingo", 106 | "fossa", 107 | "gazelle", 108 | "gecko", 109 | "giraffe", 110 | "gnu", 111 | "grizzly", 112 | "hornet", 113 | "hyena", 114 | "ibex", 115 | "iguana", 116 | "impala", 117 | "jackal", 118 | "jellyfish", 119 | "komodo", 120 | "lemur", 121 | "leopard", 122 | "lobster", 123 | "lynx", 124 | "macaw", 125 | "magpie", 126 | "mandrill", 127 | "manatee", 128 | "mantis", 129 | "meerkat", 130 | "moose", 131 | "moray", 132 | "narwhal", 133 | "newt", 134 | "ocelot", 135 | "okapi", 136 | "opossum", 137 | "orangutan", 138 | "oryx", 139 | "ostrich", 140 | "panda", 141 | "panther", 142 | "parrot", 143 | "peacock", 144 | "pelican", 145 | "penguin", 146 | "puma", 147 | "python", 148 | 
"quokka", 149 | "raccoon", 150 | "reindeer", 151 | "rhinoceros", 152 | "salamander", 153 | "seahorse", 154 | "sloth", 155 | "toucan", 156 | "walrus", 157 | "woodpecker", 158 | "yak" 159 | ], 160 | "object": [ 161 | "accordion", 162 | "airplane", 163 | "alarm", 164 | "anchor", 165 | "apron", 166 | "bag", 167 | "ball", 168 | "basket", 169 | "basketball", 170 | "beef", 171 | "bicycle", 172 | "blanket", 173 | "boat", 174 | "book", 175 | "boomerang", 176 | "bottle", 177 | "bowl", 178 | "cactus", 179 | "cake", 180 | "camera", 181 | "candle", 182 | "candlestick", 183 | "candy", 184 | "car", 185 | "carrot", 186 | "chair", 187 | "chocolate", 188 | "clock", 189 | "computer", 190 | "cookie", 191 | "cream", 192 | "cube", 193 | "cup", 194 | "curtain", 195 | "desk", 196 | "dice", 197 | "donut", 198 | "door", 199 | "dress", 200 | "drum", 201 | "dumbbell", 202 | "duster", 203 | "earmuffs", 204 | "earring", 205 | "easel", 206 | "egg", 207 | "envelope", 208 | "eraser", 209 | "fan", 210 | "feather", 211 | "fishing pole", 212 | "flower", 213 | "fork", 214 | "fountain", 215 | "garlic", 216 | "glass", 217 | "glasses", 218 | "globe", 219 | "gloves", 220 | "guitar", 221 | "gumball", 222 | "hairbrush", 223 | "hammer", 224 | "hammock", 225 | "hat", 226 | "hoop", 227 | "house", 228 | "ice", 229 | "igloo", 230 | "incense", 231 | "ink", 232 | "jacket", 233 | "jar", 234 | "jeans", 235 | "jigsaw", 236 | "juice", 237 | "kayak", 238 | "kettle", 239 | "key", 240 | "kite", 241 | "knife", 242 | "ladder", 243 | "lamp", 244 | "lantern", 245 | "laptop", 246 | "lettuce", 247 | "map", 248 | "maracas", 249 | "marker", 250 | "match", 251 | "microphone", 252 | "mirror", 253 | "motorcycle", 254 | "necklace", 255 | "net", 256 | "newspaper", 257 | "notebook", 258 | "olive", 259 | "onion", 260 | "oven", 261 | "paintbrush", 262 | "painting", 263 | "paper", 264 | "pasta", 265 | "pen", 266 | "pencil", 267 | "pepper", 268 | "phone", 269 | "piano", 270 | "picture", 271 | "pillow", 272 | "pizza", 273 | "plant", 274 | 
"plate", 275 | "pork", 276 | "potato", 277 | "puzzle", 278 | "quill", 279 | "quilt", 280 | "radio", 281 | "rake", 282 | "remote", 283 | "rice", 284 | "rifle", 285 | "robot", 286 | "rock", 287 | "rug", 288 | "ruler", 289 | "scissors", 290 | "sculpture", 291 | "shirt", 292 | "shoe", 293 | "skates", 294 | "snorkel", 295 | "socks", 296 | "soda", 297 | "sofa", 298 | "spoon", 299 | "stapler", 300 | "table", 301 | "tambourine", 302 | "tape", 303 | "teapot", 304 | "television", 305 | "tennis racket", 306 | "toilet", 307 | "tomato", 308 | "towel", 309 | "ukulele", 310 | "umbrella", 311 | "vacuum", 312 | "vase", 313 | "violin", 314 | "volleyball", 315 | "wallet", 316 | "watermelon", 317 | "whistle", 318 | "window", 319 | "wristwatch", 320 | "x-ray", 321 | "xylophone", 322 | "yacht", 323 | "yarn", 324 | "yo-yo", 325 | "yogurt", 326 | "zeppelin", 327 | "zipper", 328 | "zucchini" 329 | ], 330 | "verb": [ 331 | "achieve", 332 | "analyze", 333 | "approve", 334 | "argue", 335 | "arrive", 336 | "attack", 337 | "believe", 338 | "breathe", 339 | "build", 340 | "calculate", 341 | "celebrate", 342 | "change", 343 | "choose", 344 | "climb", 345 | "collect", 346 | "compete", 347 | "complete", 348 | "consider", 349 | "consult", 350 | "copy", 351 | "create", 352 | "cry", 353 | "dance", 354 | "decide", 355 | "define", 356 | "deliver", 357 | "design", 358 | "destroy", 359 | "develop", 360 | "discuss", 361 | "discover", 362 | "dislike", 363 | "divide", 364 | "doubt", 365 | "enjoy", 366 | "examine", 367 | "exchange", 368 | "exist", 369 | "explore", 370 | "fear", 371 | "fight", 372 | "finish", 373 | "focus", 374 | "forgive", 375 | "gather", 376 | "give", 377 | "grow", 378 | "handle", 379 | "hate", 380 | "hear", 381 | "help", 382 | "jump", 383 | "juggle", 384 | "jog", 385 | "join", 386 | "judge", 387 | "jolt", 388 | "justify", 389 | "kick", 390 | "keep", 391 | "kill", 392 | "kindle", 393 | "kiss", 394 | "knit", 395 | "knock", 396 | "knot", 397 | "know", 398 | "kneel", 399 | "label", 400 | 
"land", 401 | "laugh", 402 | "launch", 403 | "learn", 404 | "lecture", 405 | "lift", 406 | "like", 407 | "listen", 408 | "live", 409 | "make", 410 | "manage", 411 | "manipulate", 412 | "mark", 413 | "master", 414 | "maximize", 415 | "measure", 416 | "memorize", 417 | "merge", 418 | "minimize", 419 | "navigate", 420 | "need", 421 | "negotiate", 422 | "notice", 423 | "nourish", 424 | "nurture", 425 | "observe", 426 | "obtain", 427 | "open", 428 | "operate", 429 | "organize", 430 | "overcome", 431 | "oversee", 432 | "paint", 433 | "participate", 434 | "perform", 435 | "persuade", 436 | "plan", 437 | "play", 438 | "practice", 439 | "predict", 440 | "prepare", 441 | "produce", 442 | "qualify", 443 | "question", 444 | "query", 445 | "quiet", 446 | "race", 447 | "reach", 448 | "read", 449 | "realize", 450 | "recruit", 451 | "reflect", 452 | "release", 453 | "relax", 454 | "remember", 455 | "remove", 456 | "sail", 457 | "sample", 458 | "save", 459 | "schedule", 460 | "search", 461 | "select", 462 | "serve", 463 | "solve", 464 | "speak", 465 | "study", 466 | "talk", 467 | "target", 468 | "teach", 469 | "test", 470 | "think", 471 | "train", 472 | "transform", 473 | "travel", 474 | "treat", 475 | "try", 476 | "uncover", 477 | "understand", 478 | "unite", 479 | "update", 480 | "use", 481 | "validate", 482 | "value", 483 | "verify", 484 | "view", 485 | "visit", 486 | "visualize", 487 | "volunteer", 488 | "walk", 489 | "watch", 490 | "win", 491 | "work", 492 | "write", 493 | "xerox", 494 | "yearn", 495 | "yield", 496 | "zoom", 497 | "zap" 498 | ], 499 | "color": [ 500 | "red", 501 | "blue", 502 | "green", 503 | "yellow", 504 | "orange", 505 | "purple", 506 | "violet", 507 | "pink", 508 | "brown", 509 | "black", 510 | "white", 511 | "gray", 512 | "silver", 513 | "gold", 514 | "coral", 515 | "cream", 516 | "olive", 517 | "salmon", 518 | "navy", 519 | "mint", 520 | "mustard", 521 | "indigo" 522 | ], 523 | "fruit": [ 524 | "apple", 525 | "apricot", 526 | "avocado", 527 | "banana", 
528 | "blackberry", 529 | "cherry", 530 | "clementine", 531 | "coconut", 532 | "cranberry", 533 | "date", 534 | "dragonfruit", 535 | "durian", 536 | "fig", 537 | "gooseberry", 538 | "guava", 539 | "grape", 540 | "grapefruit", 541 | "huckleberry", 542 | "jackfruit", 543 | "kiwifruit", 544 | "kumquat", 545 | "lemon", 546 | "lime", 547 | "mango", 548 | "mandarine", 549 | "nectarine", 550 | "orange", 551 | "papaya", 552 | "passionfruit", 553 | "peach", 554 | "pear", 555 | "persimmon", 556 | "pineapple", 557 | "plantain", 558 | "plum", 559 | "pomegranate", 560 | "prune", 561 | "raspberry", 562 | "strawberry", 563 | "tangerine" 564 | ], 565 | "adjective": [ 566 | "agile", 567 | "adorable", 568 | "adoring", 569 | "adventurous", 570 | "affable", 571 | "affectionate", 572 | "agreeable", 573 | "altruistic", 574 | "amazing", 575 | "amiable", 576 | "bad", 577 | "benevolent", 578 | "big", 579 | "bitter", 580 | "blissful", 581 | "blithe", 582 | "bold", 583 | "bountiful", 584 | "brave", 585 | "bright", 586 | "calm", 587 | "carefree", 588 | "caring", 589 | "charismatic", 590 | "charming", 591 | "cheap", 592 | "cheerful", 593 | "clean", 594 | "clever", 595 | "cold", 596 | "courageous", 597 | "cowardly", 598 | "daring", 599 | "dark", 600 | "dazzling", 601 | "delightful", 602 | "determined", 603 | "devoted", 604 | "diligent", 605 | "dirty", 606 | "dry", 607 | "dynamic", 608 | "eager", 609 | "ecstatic", 610 | "eloquent", 611 | "enchanting", 612 | "energetic", 613 | "enthusiastic", 614 | "expensive", 615 | "exquisite", 616 | "exuberant", 617 | "faithful", 618 | "fascinating", 619 | "fast", 620 | "fearless", 621 | "fierce", 622 | "fresh", 623 | "friendly", 624 | "funny", 625 | "generous", 626 | "gentle", 627 | "genuine", 628 | "good", 629 | "graceful", 630 | "gracious", 631 | "grateful", 632 | "happy", 633 | "hard", 634 | "harmonious", 635 | "heavy", 636 | "honest", 637 | "hopeful", 638 | "hot", 639 | "humble", 640 | "hungry", 641 | "idealistic", 642 | "innocent", 643 | "inquisitive", 
644 | "insightful", 645 | "intelligent", 646 | "intrepid", 647 | "intuitive", 648 | "inventive", 649 | "jolly", 650 | "jovial", 651 | "joyful", 652 | "joyous", 653 | "jubilant", 654 | "keen", 655 | "kind", 656 | "kind-hearted", 657 | "kindhearted", 658 | "kindred", 659 | "knowledgeable", 660 | "laughing", 661 | "light", 662 | "lively", 663 | "long", 664 | "loud", 665 | "lovable", 666 | "lovely", 667 | "loving", 668 | "lucky", 669 | "luminous", 670 | "magnificent", 671 | "mellow", 672 | "mild", 673 | "mirthful", 674 | "modern", 675 | "modest", 676 | "naive", 677 | "natural", 678 | "naughty", 679 | "new", 680 | "noble", 681 | "nurturing", 682 | "observant", 683 | "old", 684 | "optimistic", 685 | "passionate", 686 | "patient", 687 | "peaceful", 688 | "pensive", 689 | "playful", 690 | "quick", 691 | "quick-witted", 692 | "quiet", 693 | "quirky", 694 | "quizzical", 695 | "radiant", 696 | "reliable", 697 | "resilient", 698 | "resolute", 699 | "resourceful", 700 | "rotten", 701 | "sad", 702 | "salty", 703 | "sensible", 704 | "sensitive", 705 | "serene", 706 | "serious", 707 | "short", 708 | "silly", 709 | "sincere", 710 | "slow", 711 | "small", 712 | "smart", 713 | "soft", 714 | "sour", 715 | "spicy", 716 | "strong", 717 | "stupid", 718 | "sweet", 719 | "talented", 720 | "tall", 721 | "tenacious", 722 | "tender", 723 | "thick", 724 | "thin", 725 | "thirsty", 726 | "thoughtful", 727 | "tranquil", 728 | "trustworthy", 729 | "unique", 730 | "unselfish", 731 | "unwavering", 732 | "upbeat", 733 | "uplifting", 734 | "versatile", 735 | "vibrant", 736 | "vivacious", 737 | "warm", 738 | "warmhearted", 739 | "weak", 740 | "wet", 741 | "whimsical", 742 | "wise", 743 | "witty", 744 | "wonderful", 745 | "young", 746 | "youthful", 747 | "zany", 748 | "zealous", 749 | "zesty" 750 | ], 751 | "pronoun": [ 752 | "I", 753 | "you", 754 | "he", 755 | "she", 756 | "it", 757 | "we", 758 | "they", 759 | "me", 760 | "him", 761 | "her", 762 | "us", 763 | "them", 764 | "myself", 765 | "yourself", 
766 | "himself", 767 | "herself", 768 | "itself", 769 | "ourselves", 770 | "themselves", 771 | "who", 772 | "whom", 773 | "whose", 774 | "whoever", 775 | "which", 776 | "that", 777 | "these", 778 | "those" 779 | ], 780 | "preposition": [ 781 | "about", 782 | "above", 783 | "across", 784 | "after", 785 | "against", 786 | "along", 787 | "among", 788 | "around", 789 | "as", 790 | "at", 791 | "before", 792 | "behind", 793 | "below", 794 | "beneath", 795 | "beside", 796 | "between", 797 | "beyond", 798 | "but", 799 | "by", 800 | "concerning", 801 | "considering", 802 | "despite", 803 | "down", 804 | "during", 805 | "except", 806 | "for", 807 | "from", 808 | "in", 809 | "inside", 810 | "into", 811 | "like", 812 | "near", 813 | "of", 814 | "off", 815 | "on", 816 | "onto", 817 | "out", 818 | "outside", 819 | "over", 820 | "past", 821 | "regarding", 822 | "round", 823 | "since", 824 | "through", 825 | "throughout", 826 | "to", 827 | "toward", 828 | "under", 829 | "underneath", 830 | "until", 831 | "up", 832 | "upon", 833 | "with", 834 | "within", 835 | "without" 836 | ] 837 | } -------------------------------------------------------------------------------- /dataset_files/generate/create_antonym_synonym_datasets.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | from typing import * 3 | from typing import TextIO 4 | 5 | import json 6 | import os 7 | import random 8 | 9 | def verify_word_length(word: str, tokenizer: object) -> bool: 10 | """ 11 | Verifies whether a word can be tokenized into a single token 12 | 13 | Parameters: 14 | word: the word that we're checking 15 | tokenizer: the tokenizer we use to tokenize the word 16 | 17 | Return: a boolean denoting whether the word is within the required 1 or not 18 | """ 19 | 20 | return len(tokenizer(word)['input_ids']) == 1 21 | 22 | def parse_file( 23 | f_in: TextIO, 24 | ant_list: List[Dict], 25 | syn_list: List[Dict], 26 | seen: Set, 27 | 
tokenizer: object 28 | ): 29 | """ 30 | Parses the input file into synonym and antonym categories 31 | 32 | Parameters: 33 | f_in: the input file from where the data will be taken from 34 | ant_list: the list of antonyms 35 | syn_list: the list of synonyms 36 | seen: the seen set of tuples to check against for duplicates 37 | tokenizer: the tokenizer we use to tokenize the word 38 | """ 39 | for line in f_in: 40 | word1, word2, t = line.split() 41 | t = int(t) 42 | 43 | word1_bool = verify_word_length(" " + word1, tokenizer) 44 | word2_bool = verify_word_length(" " + word2, tokenizer) 45 | 46 | if word1_bool and word2_bool: 47 | d = {"input": word1, "output": word2} 48 | words = (word1, word2) 49 | if words not in seen: 50 | seen.add(words) 51 | else: 52 | continue 53 | # Synonym 54 | if t == 0: 55 | syn_list.append(d) 56 | # Antonym 57 | else: 58 | ant_list.append(d) 59 | else: 60 | continue 61 | 62 | 63 | if __name__ == "__main__": 64 | # Seed for dataset generation 65 | random.seed(42) 66 | model_name = r"EleutherAI/gpt-j-6B" 67 | 68 | # Load Tokenizer 69 | tokenizer = AutoTokenizer.from_pretrained(model_name) 70 | tokenizer.pad_token = tokenizer.eos_token 71 | 72 | assert os.path.exists('./AntSynNET/dataset'), "Original dataset missing! Please first clone https://github.com/nguyenkh/AntSynNET into this folder in order to re-generate antonym and synonym datasets." 
73 | 74 | out_dir = "../abstractive" 75 | 76 | if not os.path.exists(out_dir): 77 | os.makedirs(out_dir) 78 | 79 | input_data_dir = "./AntSynNET/dataset" 80 | splits = ["train", "val", "test"] 81 | types = ["adjective", "noun", "verb"] 82 | 83 | ant_list = [] 84 | syn_list = [] 85 | filename_ant = "antonym.json" 86 | filename_syn = "synonym.json" 87 | seen = set() 88 | 89 | ant_path = os.path.join(out_dir, filename_ant) 90 | syn_path = os.path.join(out_dir, filename_syn) 91 | f_ant = open(ant_path, "w") 92 | f_syn = open(syn_path, "w") 93 | for s in splits: 94 | for t in types: 95 | path = t + "-pairs." + s 96 | full_path = os.path.join(input_data_dir, path) 97 | input_file = open(full_path, "r") 98 | parse_file(input_file, ant_list, syn_list, seen, tokenizer) 99 | 100 | json.dump(ant_list, f_ant) 101 | json.dump(syn_list, f_syn) 102 | f_ant.close() 103 | f_syn.close() 104 | -------------------------------------------------------------------------------- /dataset_files/generate/create_translation_datasets.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | from typing import * 3 | from typing import TextIO 4 | 5 | import json 6 | import os 7 | 8 | def verify_word_length(word: str, tokenizer: object) -> bool: 9 | """ 10 | Verifies whether a word can be tokenized into a single word or not 11 | 12 | Parameters: 13 | word: the word that we're checking 14 | tokenizer: the tokenizer we use to tokenize the word 15 | 16 | Return: a boolean denoting whether the word is within the required 1 or not 17 | """ 18 | 19 | return len(tokenizer(word)['input_ids']) == 1 20 | 21 | if __name__ == "__main__": 22 | 23 | d_names = {'en-de':"english-german", 'en-es':"english-spanish",'en-fr':"english-french"} 24 | path_exists = [os.path.exists(f'./translation/{lang_id}.0-5000.txt') for lang_id in d_names.keys()] + [os.path.exists(f'./translation/{lang_id}.5000-6500.txt') for lang_id in d_names.keys()] 25 | 26 
| assert all(path_exists), "Original data missing! Please download corresponding 'train' and 'test' files from https://github.com/facebookresearch/MUSE#ground-truth-bilingual-dictionaries in order to re-generate translation_datasets." 27 | 28 | model_name = r"EleutherAI/gpt-j-6b" 29 | 30 | # Load Tokenizer 31 | tokenizer = AutoTokenizer.from_pretrained(model_name) 32 | tokenizer.pad_token = tokenizer.eos_token 33 | 34 | out_dir = '../abstractive' 35 | 36 | if not os.path.exists(out_dir): 37 | os.makedirs(out_dir) 38 | 39 | for lang_id in d_names.keys(): 40 | valid = [] 41 | for d_base in [f'./translation/{lang_id}.0-5000.txt', f'./translation/{lang_id}.5000-6500.txt']: 42 | with open(d_base, 'r', encoding="utf-8") as f: 43 | lines = f.read() 44 | 45 | word_pairs = list(set([tuple(x.split()) for x in lines.splitlines()])) 46 | word_pairs = [{'input':w1, 'output':w2} for (w1,w2) in word_pairs] 47 | 48 | for i, x in enumerate(word_pairs): 49 | if (x['input'] != x['output']): # Filter pairs that are exact copies 50 | valid.append(word_pairs[i]) 51 | 52 | json.dump(valid, open(os.path.join(out_dir, f'{d_names[lang_id]}.json'), 'w')) 53 | 54 | 55 | -------------------------------------------------------------------------------- /fv_environment.yml: -------------------------------------------------------------------------------- 1 | # all packages used: 2 | name: fv 3 | channels: 4 | - pytorch 5 | - huggingface 6 | - nvidia 7 | - defaults 8 | dependencies: 9 | - python=3.10 10 | - cudatoolkit=11.7.0 11 | - datasets=2.14.3 12 | - jupyter=1.0.0 13 | - matplotlib=3.7.1 14 | - numpy=1.25.0 15 | - pandas=1.5.3 16 | - plotly=5.9.0 17 | - pytorch=1.13.0 18 | - pip=23.2.1 19 | - scikit-learn=1.3.0 20 | - seaborn=0.12.2 21 | - sentencepiece=0.1.99 22 | - transformers=4.49.0 23 | - tqdm=4.65.0 24 | - pip: 25 | - git+https://github.com/davidbau/baukit@main#egg=baukit 26 | - bitsandbytes==0.45.3 27 | - huggingface-hub==0.29.3 28 | - accelerate==0.21.0 
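The two dataset-generation scripts above share the same filtering pattern: parse whitespace-separated word pairs, deduplicate with a `seen` set, and drop degenerate pairs (the translation script skips exact input/output copies). A minimal pure-Python sketch of that pattern, with toy data standing in for the AntSynNET/MUSE files and the single-token tokenizer check omitted so no model download is needed:

```python
def filter_word_pairs(lines):
    """Parse 'src tgt' lines into unique {'input', 'output'} records,
    skipping duplicates and identity pairs (input == output)."""
    seen = set()
    valid = []
    for line in lines:
        parts = line.split()
        if len(parts) != 2:
            continue  # skip malformed rows
        pair = tuple(parts)
        if pair in seen or parts[0] == parts[1]:
            continue  # duplicate or exact copy
        seen.add(pair)
        valid.append({"input": parts[0], "output": parts[1]})
    return valid


if __name__ == "__main__":
    sample = ["cat chat", "dog chien", "cat chat", "taxi taxi"]
    print(filter_word_pairs(sample))
    # [{'input': 'cat', 'output': 'chat'}, {'input': 'dog', 'output': 'chien'}]
```

In the actual scripts each surviving pair is additionally required to tokenize to a single GPT-J token via `verify_word_length` before it is written out.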
-------------------------------------------------------------------------------- /fv_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ericwtodd/function_vectors/751e2219d304eba471cffcacc9efd89a4f8ef3c4/fv_overview.png -------------------------------------------------------------------------------- /notebooks/fv_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import os, re, json\n", 20 | "import torch, numpy as np\n", 21 | "\n", 22 | "import sys\n", 23 | "sys.path.append('..')\n", 24 | "torch.set_grad_enabled(False)\n", 25 | "\n", 26 | "from src.utils.extract_utils import get_mean_head_activations, compute_universal_function_vector\n", 27 | "from src.utils.intervention_utils import fv_intervention_natural_text, function_vector_intervention\n", 28 | "from src.utils.model_utils import load_gpt_model_and_tokenizer\n", 29 | "from src.utils.prompt_utils import load_dataset, word_pairs_to_prompt_data, create_prompt\n", 30 | "from src.utils.eval_utils import decode_to_vocab, sentence_eval" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "## Load model & tokenizer" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "model_name = 'EleutherAI/gpt-j-6b'\n", 47 | "model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name)\n", 48 | "EDIT_LAYER = 9" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "## Load dataset and Compute 
task-conditioned mean activations" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "dataset = load_dataset('antonym', seed=0)\n", 65 | "mean_activations = get_mean_head_activations(dataset, model, model_config, tokenizer)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "## Compute function vector (FV)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "FV, top_heads = compute_universal_function_vector(mean_activations, model, model_config, n_top_heads=10)" 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "metadata": {}, 87 | "source": [ 88 | "## Prompt Creation - ICL, Shuffled-Label, Zero-Shot, and Natural Text" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "# Sample ICL example pairs, and a test word\n", 98 | "dataset = load_dataset('antonym')\n", 99 | "word_pairs = dataset['train'][:5]\n", 100 | "test_pair = dataset['test'][21]\n", 101 | "\n", 102 | "prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=test_pair, prepend_bos_token=True)\n", 103 | "sentence = create_prompt(prompt_data)\n", 104 | "print(\"ICL prompt:\\n\", repr(sentence), '\\n\\n')\n", 105 | "\n", 106 | "shuffled_prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=test_pair, prepend_bos_token=True, shuffle_labels=True)\n", 107 | "shuffled_sentence = create_prompt(shuffled_prompt_data)\n", 108 | "print(\"Shuffled ICL Prompt:\\n\", repr(shuffled_sentence), '\\n\\n')\n", 109 | "\n", 110 | "zeroshot_prompt_data = word_pairs_to_prompt_data({'input':[], 'output':[]}, query_target_pair=test_pair, prepend_bos_token=True, shuffle_labels=True)\n", 111 | "zeroshot_sentence = create_prompt(zeroshot_prompt_data)\n", 112 | "print(\"Zero-Shot 
Prompt:\\n\", repr(zeroshot_sentence))" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": {}, 118 | "source": [ 119 | "## Evaluation" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "### Clean ICL Prompt" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "# Check model's ICL answer\n", 136 | "clean_logits = sentence_eval(sentence, [test_pair['output']], model, tokenizer, compute_nll=False)\n", 137 | "\n", 138 | "print(\"Input Sentence:\", repr(sentence), '\\n')\n", 139 | "print(f\"Input Query: {repr(test_pair['input'])}, Target: {repr(test_pair['output'])}\\n\")\n", 140 | "print(\"ICL Prompt Top K Vocab Probs:\\n\", decode_to_vocab(clean_logits, tokenizer, k=5), '\\n')" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "metadata": {}, 146 | "source": [ 147 | "### Corrupted ICL Prompt" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "# Perform an intervention on the shuffled setting\n", 157 | "clean_logits, interv_logits = function_vector_intervention(shuffled_sentence, [test_pair['output']], EDIT_LAYER, FV, model, model_config, tokenizer)\n", 158 | "\n", 159 | "print(\"Input Sentence:\", repr(shuffled_sentence), '\\n')\n", 160 | "print(f\"Input Query: {repr(test_pair['input'])}, Target: {repr(test_pair['output'])}\\n\")\n", 161 | "print(\"Few-Shot-Shuffled Prompt Top K Vocab Probs:\\n\", decode_to_vocab(clean_logits, tokenizer, k=5), '\\n')\n", 162 | "print(\"Shuffled Prompt+FV Top K Vocab Probs:\\n\", decode_to_vocab(interv_logits, tokenizer, k=5))" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": {}, 168 | "source": [ 169 | "### Zero-Shot Prompt" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "metadata": {}, 
176 | "outputs": [], 177 | "source": [ 178 | "# Intervention on the zero-shot prompt\n", 179 | "clean_logits, interv_logits = function_vector_intervention(zeroshot_sentence, [test_pair['output']], EDIT_LAYER, FV, model, model_config, tokenizer)\n", 180 | "\n", 181 | "print(\"Input Sentence:\", repr(zeroshot_sentence), '\\n')\n", 182 | "print(f\"Input Query: {repr(test_pair['input'])}, Target: {repr(test_pair['output'])}\\n\")\n", 183 | "print(\"Zero-Shot Top K Vocab Probs:\\n\", decode_to_vocab(clean_logits, tokenizer, k=5), '\\n')\n", 184 | "print(\"Zero-Shot+FV Vocab Top K Vocab Probs:\\n\", decode_to_vocab(interv_logits, tokenizer, k=5))" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": {}, 190 | "source": [ 191 | "### Natural Text Prompt" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "sentence = f\"The word \\\"{test_pair['input']}\\\" means\"\n", 201 | "co, io = fv_intervention_natural_text(sentence, EDIT_LAYER, FV, model, model_config, tokenizer, max_new_tokens=10)\n", 202 | "\n", 203 | "\n", 204 | "print(\"Input Sentence: \", repr(sentence))\n", 205 | "print(\"GPT-J:\" , repr(tokenizer.decode(co.squeeze())))\n", 206 | "print(\"GPT-J+FV:\", repr(tokenizer.decode(io.squeeze())), '\\n')" 207 | ] 208 | } 209 | ], 210 | "metadata": { 211 | "kernelspec": { 212 | "display_name": "Python 3 (ipykernel)", 213 | "language": "python", 214 | "name": "python3" 215 | }, 216 | "language_info": { 217 | "codemirror_mode": { 218 | "name": "ipython", 219 | "version": 3 220 | }, 221 | "file_extension": ".py", 222 | "mimetype": "text/x-python", 223 | "name": "python", 224 | "nbconvert_exporter": "python", 225 | "pygments_lexer": "ipython3", 226 | "version": "3.10.12" 227 | } 228 | }, 229 | "nbformat": 4, 230 | "nbformat_minor": 2 231 | } 232 | -------------------------------------------------------------------------------- /src/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ericwtodd/function_vectors/751e2219d304eba471cffcacc9efd89a4f8ef3c4/src/__init__.py -------------------------------------------------------------------------------- /src/compute_average_activations.py: -------------------------------------------------------------------------------- 1 | import os, json 2 | import torch, numpy as np 3 | import argparse 4 | 5 | # Include prompt creation helper functions 6 | from utils.prompt_utils import * 7 | from utils.intervention_utils import * 8 | from utils.model_utils import * 9 | from utils.extract_utils import * 10 | 11 | 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser() 15 | 16 | parser.add_argument('--dataset_name', help='Name of the dataset to be loaded', type=str, required=True) 17 | parser.add_argument('--model_name', help='Name of model to be loaded', type=str, required=False, default='EleutherAI/gpt-j-6b') 18 | parser.add_argument('--root_data_dir', help='Root directory of data files', type=str, required=False, default='../dataset_files') 19 | parser.add_argument('--save_path_root', help='File path to save mean activations to', type=str, required=False, default='../results') 20 | parser.add_argument('--seed', help='Randomized seed', type=int, required=False, default=42) 21 | parser.add_argument('--n_shots', help="Number of shots in each in-context prompt", required=False, default=10) 22 | parser.add_argument('--n_trials', help="Number of in-context prompts to average over", required=False, default=100) 23 | parser.add_argument('--test_split', help="Percentage corresponding to test set split size", required=False, default=0.3) 24 | parser.add_argument('--device', help='Device to run on', required=False, default='cuda' if torch.cuda.is_available() else 'cpu') 25 | parser.add_argument('--prefixes', help='Prompt template prefixes to be used', type=json.loads, required=False, 
default={"input":"Q:", "output":"A:", "instructions":""}) 26 | parser.add_argument('--separators', help='Prompt template separators to be used', type=json.loads, required=False, default={"input":"\n", "output":"\n\n", "instructions":""}) 27 | parser.add_argument('--revision', help='Specify model checkpoints for pythia or olmo models', type=str, required=False, default=None) 28 | 29 | 30 | args = parser.parse_args() 31 | 32 | dataset_name = args.dataset_name 33 | model_name = args.model_name 34 | root_data_dir = args.root_data_dir 35 | save_path_root = f"{args.save_path_root}/{dataset_name}" 36 | seed = args.seed 37 | n_shots = args.n_shots 38 | n_trials = args.n_trials 39 | test_split = args.test_split 40 | device = args.device 41 | prefixes = args.prefixes 42 | separators = args.separators 43 | 44 | 45 | # Load Model & Tokenizer 46 | torch.set_grad_enabled(False) 47 | print("Loading Model") 48 | model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name, device=device, revision=args.revision) 49 | 50 | set_seed(seed) 51 | 52 | # Load the dataset 53 | print("Loading Dataset") 54 | dataset = load_dataset(dataset_name, root_data_dir=root_data_dir, test_size=test_split, seed=seed) 55 | 56 | print("Computing Mean Activations") 57 | mean_activations = get_mean_head_activations(dataset, model=model, model_config=model_config, tokenizer=tokenizer, 58 | n_icl_examples=n_shots, N_TRIALS=n_trials, prefixes=prefixes, separators=separators) 59 | 60 | if not os.path.exists(save_path_root): 61 | os.makedirs(save_path_root) 62 | 63 | # Write args to file 64 | args.save_path_root = save_path_root # update for logging 65 | with open(f'{save_path_root}/mean_head_activation_args.txt', 'w') as arg_file: 66 | json.dump(args.__dict__, arg_file, indent=2) 67 | 68 | torch.save(mean_activations, f'{save_path_root}/{dataset_name}_mean_head_activations.pt') 69 | 70 | -------------------------------------------------------------------------------- 
/src/compute_avg_hidden_state.py: -------------------------------------------------------------------------------- 1 | import os, json 2 | import torch, numpy as np 3 | import argparse 4 | 5 | # Include prompt creation helper functions 6 | from utils.prompt_utils import * 7 | from utils.intervention_utils import * 8 | from utils.model_utils import * 9 | from utils.eval_utils import * 10 | from utils.extract_utils import * 11 | 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser() 15 | 16 | parser.add_argument('--dataset_name', help='Name of the dataset to be loaded', type=str, required=True) 17 | parser.add_argument('--model_name', help='Name of model to be loaded', type=str, required=False, default='EleutherAI/gpt-j-6b') 18 | parser.add_argument('--root_data_dir', help='Root directory of data files', type=str, required=False, default='../dataset_files') 19 | parser.add_argument('--save_path_root', help='File path to save mean activations to', type=str, required=False, default='../results') 20 | parser.add_argument('--n_seeds', help='Number of seeds', type=int, required=False, default=5) 21 | parser.add_argument('--n_shots', help="Number of shots in each in-context prompt", required=False, default=10) 22 | parser.add_argument('--n_trials', help="Number of in-context prompts to average over", required=False, default=100) 23 | parser.add_argument('--test_split', help="Percentage corresponding to test set split size", required=False, default=0.3) 24 | parser.add_argument('--device', help='Device to run on', required=False, default='cuda' if torch.cuda.is_available() else 'cpu') 25 | parser.add_argument('--prefixes', help='Prompt template prefixes to be used', type=json.loads, required=False, default={"input":"Q:", "output":"A:", "instructions":""}) 26 | parser.add_argument('--separators', help='Prompt template separators to be used', type=json.loads, required=False, default={"input":"\n", "output":"\n\n", "instructions":""}) 27 | 
parser.add_argument('--revision', help='Specify model checkpoints for pythia or olmo models', type=str, required=False, default=None) 28 | 29 | 30 | args = parser.parse_args() 31 | 32 | dataset_name = args.dataset_name 33 | model_name = args.model_name 34 | root_data_dir = args.root_data_dir 35 | save_path_root = f"{args.save_path_root}/{dataset_name}" 36 | n_seeds = args.n_seeds 37 | n_shots = args.n_shots 38 | n_trials = args.n_trials 39 | test_split = args.test_split 40 | device = args.device 41 | prefixes = args.prefixes 42 | separators = args.separators 43 | 44 | 45 | # Load Model & Tokenizer 46 | torch.set_grad_enabled(False) 47 | print("Loading Model") 48 | model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name, device=device, revision=args.revision) 49 | 50 | seeds = np.random.choice(100000, size=n_seeds) 51 | 52 | for seed in seeds: 53 | set_seed(seed) 54 | 55 | # Load the dataset 56 | print("Loading Dataset") 57 | dataset = load_dataset(dataset_name, root_data_dir=root_data_dir, test_size=test_split, seed=seed) 58 | 59 | print("Computing Mean Activations") 60 | 61 | mean_activations = get_mean_layer_activations(dataset, model=model, model_config=model_config, tokenizer=tokenizer, 62 | n_icl_examples=n_shots, N_TRIALS=n_trials) 63 | 64 | 65 | print("Saving mean layer activations") 66 | if not os.path.exists(save_path_root): 67 | os.makedirs(save_path_root) 68 | 69 | # Write args to file 70 | args.save_path_root = save_path_root # update for logging 71 | with open(f'{save_path_root}/mean_layer_activation_args.txt', 'w') as arg_file: 72 | json.dump(args.__dict__, arg_file, indent=2) 73 | 74 | torch.save(mean_activations, f'{save_path_root}/{dataset_name}_mean_layer_activations.pt') 75 | 76 | print("Evaluating Layer Avgs. 
Baseline") 77 | fs_results = n_shot_eval_no_intervention(dataset, n_shots, model, model_config, tokenizer) 78 | filter_set = np.where(np.array(fs_results['clean_rank_list']) == 0)[0] 79 | 80 | zs_res = {} 81 | fss_res = {} 82 | for i in range(model_config['n_layers']): 83 | zs_res[i] = n_shot_eval(dataset, mean_activations[i].unsqueeze(0), i, 0, model, model_config, tokenizer, filter_set=filter_set) 84 | fss_res[i] = n_shot_eval(dataset, mean_activations[i].unsqueeze(0), i, 10, model, model_config, tokenizer, filter_set=filter_set, shuffle_labels=True) 85 | 86 | with open(f'{save_path_root}/mean_layer_intervention_zs_results_sweep_{seed}.json', 'w') as interv_zsres_file: 87 | json.dump(zs_res, interv_zsres_file, indent=2) 88 | with open(f'{save_path_root}/mean_layer_intervention_fss_results_sweep_{seed}.json', 'w') as interv_fssres_file: 89 | json.dump(fss_res, interv_fssres_file, indent=2) 90 | -------------------------------------------------------------------------------- /src/compute_indirect_effect.py: -------------------------------------------------------------------------------- 1 | import os, re, json 2 | from tqdm import tqdm 3 | import torch, numpy as np 4 | import argparse 5 | from baukit import TraceDict 6 | 7 | # Include prompt creation helper functions 8 | from utils.prompt_utils import * 9 | from utils.intervention_utils import * 10 | from utils.model_utils import * 11 | from utils.extract_utils import * 12 | 13 | 14 | def activation_replacement_per_class_intervention(prompt_data, avg_activations, dummy_labels, model, model_config, tokenizer, last_token_only=True): 15 | """ 16 | Experiment to determine top intervention locations through avg activation replacement. 17 | Performs a systematic sweep over attention heads (layer, head) to track their causal influence on probs of key tokens. 
18 | 19 | Parameters: 20 | prompt_data: dict containing ICL prompt examples, and template information 21 | avg_activations: avg activation of each attention head in the model taken across n_trials ICL prompts 22 | dummy_labels: labels and indices for a baseline prompt with the same number of example pairs 23 | model: huggingface model 24 | model_config: contains model config information (n layers, n heads, etc.) 25 | tokenizer: huggingface tokenizer 26 | last_token_only: If True, only computes indirect effect for heads at the final token position. If False, computes indirect_effect for heads for all token classes 27 | 28 | Returns: 29 | indirect_effect_storage: torch tensor containing the indirect_effect of each head for each token class. 30 | """ 31 | device = model.device 32 | 33 | # Get sentence and token labels 34 | query_target_pair = prompt_data['query_target'] 35 | 36 | query = query_target_pair['input'] 37 | token_labels, prompt_string = get_token_meta_labels(prompt_data, tokenizer, query=query, prepend_bos=model_config['prepend_bos']) 38 | 39 | idx_map, idx_avg = compute_duplicated_labels(token_labels, dummy_labels) 40 | idx_map = update_idx_map(idx_map, idx_avg) 41 | 42 | sentences = [prompt_string]# * model.config.n_head # batch things by head 43 | 44 | # Figure out tokens of interest 45 | target = [query_target_pair['output']] 46 | token_id_of_interest = get_answer_id(sentences[0], target[0], tokenizer) 47 | if isinstance(token_id_of_interest, list): 48 | token_id_of_interest = token_id_of_interest[:1] 49 | 50 | inputs = tokenizer(sentences, return_tensors='pt').to(device) 51 | 52 | # Speed up computation by only computing causal effect at last token 53 | if last_token_only: 54 | token_classes = ['query_predictive'] 55 | token_classes_regex = ['query_predictive_token'] 56 | # Compute causal effect for all token classes (instead of just last token) 57 | else: 58 | token_classes = ['demonstration', 'label', 'separator', 'predictive', 
'structural','end_of_example', 59 | 'query_demonstration', 'query_structural', 'query_separator', 'query_predictive'] 60 | token_classes_regex = [r'demonstration_[\d]{1,}_token', r'demonstration_[\d]{1,}_label_token', 'separator_token', 'predictive_token', 'structural_token','end_of_example_token', 61 | 'query_demonstration_token', 'query_structural_token', 'query_separator_token', 'query_predictive_token'] 62 | 63 | 64 | indirect_effect_storage = torch.zeros(model_config['n_layers'], model_config['n_heads'],len(token_classes)) 65 | 66 | # Clean Run of Baseline: 67 | clean_output = model(**inputs).logits[:,-1,:] 68 | clean_probs = torch.softmax(clean_output[0], dim=-1) 69 | 70 | # For every layer, head, token combination perform the replacement & track the change in meaningful tokens 71 | for layer in range(model_config['n_layers']): 72 | head_hook_layer = [model_config['attn_hook_names'][layer]] 73 | 74 | for head_n in range(model_config['n_heads']): 75 | for i,(token_class, class_regex) in enumerate(zip(token_classes, token_classes_regex)): 76 | reg_class_match = re.compile(f"^{class_regex}$") 77 | class_token_inds = [x[0] for x in token_labels if reg_class_match.match(x[2])] 78 | 79 | intervention_locations = [(layer, head_n, token_n) for token_n in class_token_inds] 80 | intervention_fn = replace_activation_w_avg(layer_head_token_pairs=intervention_locations, avg_activations=avg_activations, 81 | model=model, model_config=model_config, 82 | batched_input=False, idx_map=idx_map, last_token_only=last_token_only) 83 | with TraceDict(model, layers=head_hook_layer, edit_output=intervention_fn) as td: 84 | output = model(**inputs).logits[:,-1,:] # batch_size x n_tokens x vocab_size, only want last token prediction 85 | 86 | # TRACK probs of tokens of interest 87 | intervention_probs = torch.softmax(output, dim=-1) # convert to probability distribution 88 | indirect_effect_storage[layer,head_n,i] = (intervention_probs-clean_probs).index_select(1, 
torch.LongTensor(token_id_of_interest).to(device).squeeze()).squeeze() 89 | 90 | return indirect_effect_storage 91 | 92 | 93 | def compute_indirect_effect(dataset, mean_activations, model, model_config, tokenizer, n_shots=10, n_trials=25, last_token_only=True, prefixes=None, separators=None, filter_set=None): 94 | """ 95 | Computes Indirect Effect of each head in the model 96 | 97 | Parameters: 98 | dataset: ICL dataset 99 | mean_activations: avg activation of each attention head in the model taken across n_trials ICL prompts 100 | model: huggingface model 101 | model_config: contains model config information (n layers, n heads, etc.) 102 | tokenizer: huggingface tokenizer 103 | n_shots: Number of shots in each in-context prompt 104 | n_trials: Number of in-context prompts to average over 105 | last_token_only: If True, only computes Indirect Effect for heads at the final token position. If False, computes Indirect Effect for heads for all token classes 106 | 107 | 108 | Returns: 109 | indirect_effect: torch tensor of the indirect effect for each attention head in the model, size n_trials * n_layers * n_heads 110 | """ 111 | n_test_examples = 1 112 | 113 | if prefixes is not None and separators is not None: 114 | dummy_gt_labels = get_dummy_token_labels(n_shots, tokenizer=tokenizer, prefixes=prefixes, separators=separators, model_config=model_config) 115 | else: 116 | dummy_gt_labels = get_dummy_token_labels(n_shots, tokenizer=tokenizer, model_config=model_config) 117 | 118 | # If the model already prepends a bos token by default, we don't want to add one 119 | prepend_bos = not model_config['prepend_bos'] 120 | 121 | if last_token_only: 122 | indirect_effect = torch.zeros(n_trials,model_config['n_layers'], model_config['n_heads']) 123 | else: 124 | indirect_effect = torch.zeros(n_trials,model_config['n_layers'], model_config['n_heads'],10) # have 10 classes of tokens 125 | 126 | if filter_set is None: 127 | filter_set = np.arange(len(dataset['valid'])) 128 | 129 | for i in tqdm(range(n_trials), total=n_trials): 130 | word_pairs = 
dataset['train'][np.random.choice(len(dataset['train']),n_shots, replace=False)] 131 | word_pairs_test = dataset['valid'][np.random.choice(filter_set,n_test_examples, replace=False)] 132 | if prefixes is not None and separators is not None: 133 | prompt_data_random = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, shuffle_labels=True, 134 | prepend_bos_token=prepend_bos, prefixes=prefixes, separators=separators) 135 | else: 136 | prompt_data_random = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, 137 | shuffle_labels=True, prepend_bos_token=prepend_bos) 138 | 139 | ind_effects = activation_replacement_per_class_intervention(prompt_data=prompt_data_random, 140 | avg_activations = mean_activations, 141 | dummy_labels=dummy_gt_labels, 142 | model=model, model_config=model_config, tokenizer=tokenizer, 143 | last_token_only=last_token_only) 144 | indirect_effect[i] = ind_effects.squeeze() 145 | 146 | return indirect_effect 147 | 148 | 149 | if __name__ == "__main__": 150 | 151 | parser = argparse.ArgumentParser() 152 | 153 | parser.add_argument('--dataset_name', help='Name of the dataset to be loaded', type=str, required=True) 154 | parser.add_argument('--model_name', help='Name of model to be loaded', type=str, required=False, default='EleutherAI/gpt-j-6b') 155 | parser.add_argument('--root_data_dir', help='Root directory of data files', type=str, required=False, default='../dataset_files') 156 | parser.add_argument('--save_path_root', help='File path to save indirect effect to', type=str, required=False, default='../results') 157 | parser.add_argument('--seed', help='Randomized seed', type=int, required=False, default=42) 158 | parser.add_argument('--n_shots', help="Number of shots in each in-context prompt", type=int, required=False, default=10) 159 | parser.add_argument('--n_trials', help="Number of in-context prompts to average over", type=int, required=False, default=25) 160 | parser.add_argument('--test_split', 
help="Percentage corresponding to test set split size", type=float, required=False, default=0.3) 161 | parser.add_argument('--device', help='Device to run on', type=str, required=False, default='cuda' if torch.cuda.is_available() else 'cpu') 162 | parser.add_argument('--mean_activations_path', help='Path to mean activations file used for intervention', required=False, type=str, default=None) 163 | parser.add_argument('--last_token_only', help='Whether to compute indirect effect for heads at only the final token position, or for all token classes', required=False, type=lambda x: x.lower() in ('true', '1'), default=True) 164 | parser.add_argument('--prefixes', help='Prompt template prefixes to be used', type=json.loads, required=False, default={"input":"Q:", "output":"A:", "instructions":""}) 165 | parser.add_argument('--separators', help='Prompt template separators to be used', type=json.loads, required=False, default={"input":"\n", "output":"\n\n", "instructions":""}) 166 | parser.add_argument('--revision', help='Specify model checkpoints for pythia or olmo models', type=str, required=False, default=None) 167 | 168 | args = parser.parse_args() 169 | 170 | dataset_name = args.dataset_name 171 | model_name = args.model_name 172 | root_data_dir = args.root_data_dir 173 | save_path_root = f"{args.save_path_root}/{dataset_name}" 174 | seed = args.seed 175 | n_shots = args.n_shots 176 | n_trials = args.n_trials 177 | test_split = args.test_split 178 | device = args.device 179 | mean_activations_path = args.mean_activations_path 180 | last_token_only = args.last_token_only 181 | prefixes = args.prefixes 182 | separators = args.separators 183 | 184 | 185 | # Load Model & Tokenizer 186 | torch.set_grad_enabled(False) 187 | print("Loading Model") 188 | model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name, device=device, revision=args.revision) 189 | 190 | set_seed(seed) 191 | 192 | # Load the dataset 193 | print("Loading Dataset") 194 | dataset = load_dataset(dataset_name, 
root_data_dir=root_data_dir, test_size=test_split, seed=seed) 195 | 196 | 197 | if not os.path.exists(save_path_root): 198 | os.makedirs(save_path_root) 199 | 200 | # Load or Re-Compute Mean Activations 201 | if mean_activations_path is not None and os.path.exists(mean_activations_path): 202 | mean_activations = torch.load(mean_activations_path) 203 | elif mean_activations_path is None and os.path.exists(f'{save_path_root}/{dataset_name}_mean_head_activations.pt'): 204 | mean_activations_path = f'{save_path_root}/{dataset_name}_mean_head_activations.pt' 205 | mean_activations = torch.load(mean_activations_path) 206 | else: 207 | print("Computing Mean Activations") 208 | mean_activations = get_mean_head_activations(dataset, model=model, model_config=model_config, tokenizer=tokenizer, 209 | n_icl_examples=n_shots, N_TRIALS=n_trials, prefixes=prefixes, separators=separators) 210 | torch.save(mean_activations, f'{save_path_root}/{dataset_name}_mean_head_activations.pt') 211 | 212 | print("Computing Indirect Effect") 213 | indirect_effect = compute_indirect_effect(dataset, mean_activations, model=model, model_config=model_config, tokenizer=tokenizer, 214 | n_shots=n_shots, n_trials=n_trials, last_token_only=last_token_only, prefixes=prefixes, separators=separators) 215 | 216 | # Write args to file 217 | args.save_path_root = save_path_root 218 | args.mean_activations_path = mean_activations_path 219 | with open(f'{save_path_root}/indirect_effect_args.txt', 'w') as arg_file: 220 | json.dump(args.__dict__, arg_file, indent=2) 221 | 222 | torch.save(indirect_effect, f'{save_path_root}/{dataset_name}_indirect_effect.pt') 223 | 224 | -------------------------------------------------------------------------------- /src/eval_scripts/eval_avg_hs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | datasets=('antonym' 'capitalize' 'country-capital' 'english-french' 'present-past' 'singular-plural') 3 | cd ../ 4 | 5 | for d_name 
in "${datasets[@]}" 6 | do 7 | echo "Running Script for: ${d_name}" 8 | python compute_avg_hidden_state.py --dataset_name="${d_name}" --save_path_root="results/gptj_avg_hs" --model_name='EleutherAI/gpt-j-6b' 9 | done 10 | -------------------------------------------------------------------------------- /src/eval_scripts/eval_fv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | datasets=('antonym') 3 | # datasets=('antonym' 'capitalize' 'country-capital' 'english-french' 'present-past' 'singular-plural') 4 | cd ../ 5 | 6 | for d_name in "${datasets[@]}" 7 | do 8 | echo "Running Script for: ${d_name}" 9 | python evaluate_function_vector.py --dataset_name="${d_name}" --save_path_root="results/gptj" --model_name='EleutherAI/gpt-j-6b' 10 | done -------------------------------------------------------------------------------- /src/eval_scripts/eval_numheads.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | datasets=('antonym' 'capitalize' 'country-capital' 'english-french' 'present-past' 'singular-plural') 3 | cd ../ 4 | 5 | for d_name in "${datasets[@]}" 6 | do 7 | echo "Running Script for: ${d_name}" 8 | python test_numheads.py --dataset_name="${d_name}" --model_name='EleutherAI/gpt-j-6b' 9 | done -------------------------------------------------------------------------------- /src/eval_scripts/eval_template_portability.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | datasets=('antonym' 'capitalize' 'country-capital' 'english-french' 'present-past' 'singular-plural') 3 | cd ../ 4 | 5 | for d_name in "${datasets[@]}" 6 | do 7 | echo "Running Script for: ${d_name}" 8 | python portability_eval.py --dataset_name="${d_name}" --save_path_root="results/gptj" --model_name='EleutherAI/gpt-j-6b' 9 | done -------------------------------------------------------------------------------- 
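The quantity stored by `activation_replacement_per_class_intervention` in `compute_indirect_effect.py` above is a probability shift: softmax the clean and patched last-token logits, subtract, and read off the entry for the answer token. A toy, model-free sketch of that arithmetic (all tensor values here are made up for illustration):

```python
import torch

torch.manual_seed(0)
vocab_size = 50
token_id_of_interest = 7  # hypothetical id of the answer's first token

clean_logits = torch.randn(vocab_size)
patched_logits = clean_logits.clone()
patched_logits[token_id_of_interest] += 2.0  # pretend the patched head boosts the answer

clean_probs = torch.softmax(clean_logits, dim=-1)
intervention_probs = torch.softmax(patched_logits, dim=-1)

# Indirect effect: shift in the answer token's probability after the patch
cie = (intervention_probs - clean_probs)[token_id_of_interest].item()
```

Since only the answer token's logit was raised, its probability must increase, so `cie` is positive here; in the real sweep the sign and magnitude per (layer, head) are what get stored.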
/src/eval_scripts/fv_eval_sweep.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | 5 | # Submit slurm jobs for many tasks 6 | 7 | dataset_names = ['antonym', 'capitalize', 'country-capital', 'english-french', 'present-past', 'singular-plural'] 8 | MODEL_NAMES = ["EleutherAI/gpt-j-6b"] 9 | MODEL_NICKNAMES = ['gptj'] 10 | 11 | 12 | job_path = str(time.ctime()).replace(" ", "_") 13 | print(job_path) 14 | os.makedirs(job_path, exist_ok=True) 15 | 16 | d_name_to_cmd = {} 17 | 18 | ## creating the jobs 19 | for model_name,model_nickname in zip(MODEL_NAMES, MODEL_NICKNAMES): 20 | current_seed = np.random.randint(1000000) 21 | for idx, d_name in enumerate(dataset_names): 22 | results_path = os.path.join('results', f'{model_nickname}') 23 | n_fv_heads = 10 24 | 25 | cmd = f"python evaluate_function_vector.py --dataset_name='{d_name}' --save_path_root='{results_path}' --model_name='{model_name}' --n_top_heads={n_fv_heads} --seed={current_seed}" 26 | if 'squad' in d_name: 27 | cmd += " --n_shots=5 --generate_str --metric='f1_score'" 28 | elif 'ag_news' in d_name: 29 | cmd += " --n_shots=10 --generate_str --metric='first_word_score'" 30 | 31 | key = model_nickname + '_' + d_name 32 | d_name_to_cmd[key] = cmd 33 | 34 | 35 | for key in d_name_to_cmd: 36 | with open("template.sh", "r") as f: 37 | bash_template = f.readlines() 38 | bash_template.append(d_name_to_cmd[key]) 39 | 40 | with open(f"{job_path}/{key}.sh", "w") as f: 41 | f.writelines(bash_template) 42 | 43 | 44 | ## running the jobs 45 | for job in os.listdir(job_path): 46 | job_script = f"{job_path}/{job}" 47 | cmd = f"sbatch --gpus=1 --time=48:00:00 {job_script}" 48 | print("submitting job: ", job) 49 | print(cmd) 50 | os.system(cmd) 51 | print("\n\n") 52 | 53 | print("------------------------------------------------------------------") 54 | print(f"submitted {len(os.listdir(job_path))} jobs!") 55 | 
-------------------------------------------------------------------------------- /src/eval_scripts/template.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source ~/.bashrc 3 | cd ../ 4 | conda activate fv 5 | -------------------------------------------------------------------------------- /src/evaluate_function_vector.py: -------------------------------------------------------------------------------- 1 | import os, json 2 | import torch, numpy as np 3 | import argparse 4 | 5 | # Include prompt creation helper functions 6 | from utils.prompt_utils import * 7 | from utils.intervention_utils import * 8 | from utils.model_utils import * 9 | from utils.eval_utils import * 10 | from utils.extract_utils import * 11 | from compute_indirect_effect import compute_indirect_effect 12 | 13 | if __name__ == "__main__": 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument('--dataset_name', help='Name of the dataset to be loaded', type=str, required=True) 18 | parser.add_argument('--n_top_heads', help='Number of attention head outputs used to compute function vector', required=False, type=int, default=10) 19 | parser.add_argument('--edit_layer', help='Layer for intervention. 
If -1, sweep over all layers', type=int, required=False, default=-1) 20 | parser.add_argument('--model_name', help='Name of model to be loaded', type=str, required=False, default='EleutherAI/gpt-j-6b') 21 | parser.add_argument('--root_data_dir', help='Root directory of data files', type=str, required=False, default='../dataset_files') 22 | parser.add_argument('--save_path_root', help='File path to save to', type=str, required=False, default='../results') 23 | parser.add_argument('--ie_path_root', help='File path to load indirect effects from', type=str, required=False, default=None) 24 | parser.add_argument('--seed', help='Randomized seed', type=int, required=False, default=42) 25 | parser.add_argument('--device', help='Device to run on', type=str, required=False, default='cuda' if torch.cuda.is_available() else 'cpu') 26 | parser.add_argument('--mean_activations_path', help='Path to file containing mean_head_activations for the specified task', required=False, type=str, default=None) 27 | parser.add_argument('--indirect_effect_path', help='Path to file containing indirect_effect scores for the specified task', required=False, type=str, default=None) 28 | parser.add_argument('--test_split', help="Percentage corresponding to test set split size", required=False, default=0.3) 29 | parser.add_argument('--n_shots', help="Number of shots in each in-context prompt", type=int, required=False, default=10) 30 | parser.add_argument('--n_mean_activations_trials', help="Number of in-context prompts to average over for mean_activations", type=int, required=False, default=100) 31 | parser.add_argument('--n_indirect_effect_trials', help="Number of in-context prompts to average over for indirect_effect", type=int, required=False, default=25) 32 | parser.add_argument('--prefixes', help='Prompt template prefixes to be used', type=json.loads, required=False, default={"input":"Q:", "output":"A:", "instructions":""}) 33 | parser.add_argument('--separators', help='Prompt template 
separators to be used', type=json.loads, required=False, default={"input":"\n", "output":"\n\n", "instructions":""}) 34 | parser.add_argument('--compute_baseline', help='Whether to compute the model baseline 0-shot -> n-shot performance', type=lambda x: x.lower() in ('true', '1'), required=False, default=True) 35 | parser.add_argument('--generate_str', help='Whether to generate long-form completions for the task', action='store_true', required=False) 36 | parser.add_argument("--metric", help="Metric to use when evaluating generated strings", type=str, required=False, default="f1_score") 37 | parser.add_argument("--universal_set", help="Flag for whether to evaluate using the universal set of heads", action="store_true", required=False) 38 | parser.add_argument('--revision', help='Specify model checkpoints for pythia or olmo models', type=str, required=False, default=None) 39 | 40 | args = parser.parse_args() 41 | 42 | dataset_name = args.dataset_name 43 | model_name = args.model_name 44 | root_data_dir = args.root_data_dir 45 | save_path_root = f"{args.save_path_root}/{dataset_name}" 46 | ie_path_root = f"{args.ie_path_root}/{dataset_name}" if args.ie_path_root else save_path_root 47 | seed = args.seed 48 | device = args.device 49 | mean_activations_path = args.mean_activations_path 50 | indirect_effect_path = args.indirect_effect_path 51 | n_top_heads = args.n_top_heads 52 | eval_edit_layer = args.edit_layer 53 | 54 | test_split = float(args.test_split) 55 | n_shots = args.n_shots 56 | n_mean_activations_trials = args.n_mean_activations_trials 57 | n_indirect_effect_trials = args.n_indirect_effect_trials 58 | 59 | prefixes = args.prefixes 60 | separators = args.separators 61 | compute_baseline = args.compute_baseline 62 | 63 | generate_str = args.generate_str 64 | metric = args.metric 65 | universal_set = args.universal_set 66 | 67 | print(args) 68 | 69 | # Load Model & Tokenizer 70 | torch.set_grad_enabled(False) 71 | print("Loading Model") 72 | model, tokenizer, model_config = 
load_gpt_model_and_tokenizer(model_name, device=device, revision=args.revision) 73 | 74 | if args.edit_layer == -1: # sweep over all layers if edit_layer=-1 75 | eval_edit_layer = [0, model_config['n_layers']] 76 | 77 | # Load the dataset 78 | print("Loading Dataset") 79 | set_seed(seed) 80 | dataset = load_dataset(dataset_name, root_data_dir=root_data_dir, test_size=test_split, seed=seed) 81 | 82 | if not os.path.exists(save_path_root): 83 | os.makedirs(save_path_root) 84 | 85 | print(f"Filtering Dataset via {n_shots}-shot Eval") 86 | # 1. Compute Model 10-shot Baseline & 2. Filter test set to cases where model gets it correct 87 | 88 | fs_results_file_name = f'{save_path_root}/fs_results_layer_sweep.json' 89 | print(fs_results_file_name) 90 | if os.path.exists(fs_results_file_name): 91 | with open(fs_results_file_name, 'r') as indata: 92 | fs_results = json.load(indata) 93 | key = 'score' if generate_str else 'clean_rank_list' 94 | target_val = 1 if generate_str else 0 95 | filter_set = np.where(np.array(fs_results[key]) == target_val)[0] 96 | filter_set_validation = None 97 | elif generate_str: 98 | set_seed(seed+42) 99 | fs_results_validation = n_shot_eval_no_intervention(dataset=dataset, n_shots=n_shots, model=model, model_config=model_config, tokenizer=tokenizer, compute_ppl=False, 100 | generate_str=True, metric=metric, test_split='valid', prefixes=prefixes, separators=separators) 101 | filter_set_validation = np.where(np.array(fs_results_validation['score']) == 1)[0] 102 | set_seed(seed) 103 | fs_results = n_shot_eval_no_intervention(dataset=dataset, n_shots=n_shots, model=model, model_config=model_config, tokenizer=tokenizer, compute_ppl=False, 104 | generate_str=True, metric=metric, prefixes=prefixes, separators=separators) 105 | filter_set = np.where(np.array(fs_results['score']) == 1)[0] 106 | else: 107 | set_seed(seed+42) 108 | fs_results_validation = n_shot_eval_no_intervention(dataset=dataset, n_shots=n_shots, model=model, model_config=model_config, 
tokenizer=tokenizer, compute_ppl=True, test_split='valid', prefixes=prefixes, separators=separators) 109 | filter_set_validation = np.where(np.array(fs_results_validation['clean_rank_list']) == 0)[0] 110 | set_seed(seed) 111 | fs_results = n_shot_eval_no_intervention(dataset=dataset, n_shots=n_shots, model=model, model_config=model_config, tokenizer=tokenizer, compute_ppl=True, prefixes=prefixes, separators=separators) 112 | filter_set = np.where(np.array(fs_results['clean_rank_list']) == 0)[0] 113 | 114 | args.fs_results_file_name = fs_results_file_name 115 | with open(fs_results_file_name, 'w') as results_file: 116 | json.dump(fs_results, results_file, indent=2) 117 | 118 | set_seed(seed) 119 | # Load or Re-Compute mean_head_activations 120 | if mean_activations_path is not None and os.path.exists(mean_activations_path): 121 | mean_activations = torch.load(mean_activations_path) 122 | elif mean_activations_path is None and os.path.exists(f'{ie_path_root}/{dataset_name}_mean_head_activations.pt'): 123 | mean_activations_path = f'{ie_path_root}/{dataset_name}_mean_head_activations.pt' 124 | mean_activations = torch.load(mean_activations_path) 125 | else: 126 | print("Computing Mean Activations") 127 | set_seed(seed) 128 | mean_activations = get_mean_head_activations(dataset, model=model, model_config=model_config, tokenizer=tokenizer, n_icl_examples=n_shots, 129 | N_TRIALS=n_mean_activations_trials, prefixes=prefixes, separators=separators, filter_set=filter_set_validation) 130 | args.mean_activations_path = f'{save_path_root}/{dataset_name}_mean_head_activations.pt' 131 | torch.save(mean_activations, args.mean_activations_path) 132 | 133 | # Load or Re-Compute indirect_effect values 134 | if indirect_effect_path is not None and os.path.exists(indirect_effect_path): 135 | indirect_effect = torch.load(indirect_effect_path) 136 | elif indirect_effect_path is None and os.path.exists(f'{ie_path_root}/{dataset_name}_indirect_effect.pt'): 137 | indirect_effect_path = 
f'{ie_path_root}/{dataset_name}_indirect_effect.pt' 138 | indirect_effect = torch.load(indirect_effect_path) 139 | elif not universal_set: # Only compute indirect effects if we need to 140 | print("Computing Indirect Effects") 141 | set_seed(seed) 142 | indirect_effect = compute_indirect_effect(dataset, mean_activations, model=model, model_config=model_config, tokenizer=tokenizer, n_shots=n_shots, 143 | n_trials=n_indirect_effect_trials, last_token_only=True, prefixes=prefixes, separators=separators, filter_set=filter_set_validation) 144 | args.indirect_effect_path = f'{save_path_root}/{dataset_name}_indirect_effect.pt' 145 | torch.save(indirect_effect, args.indirect_effect_path) 146 | 147 | # Compute Function Vector 148 | if universal_set: 149 | fv, top_heads = compute_universal_function_vector(mean_activations, model, model_config=model_config, n_top_heads=n_top_heads) 150 | else: 151 | fv, top_heads = compute_function_vector(mean_activations, indirect_effect, model, model_config=model_config, n_top_heads=n_top_heads) 152 | 153 | # Run Evaluation 154 | if isinstance(eval_edit_layer, int): 155 | print(f"Running ZS Eval with edit_layer={eval_edit_layer}") 156 | set_seed(seed) 157 | if generate_str: 158 | pred_filepath = f"{save_path_root}/preds/{model_config['name_or_path'].replace('/', '_')}_ZS_intervention_layer{eval_edit_layer}.txt" 159 | zs_results = n_shot_eval(dataset=dataset, fv_vector=fv, edit_layer=eval_edit_layer, n_shots=0, 160 | model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set, 161 | generate_str=generate_str, metric=metric, pred_filepath=pred_filepath, prefixes=prefixes, separators=separators) 162 | else: 163 | zs_results = n_shot_eval(dataset=dataset, fv_vector=fv, edit_layer=eval_edit_layer, n_shots=0, 164 | model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set, prefixes=prefixes, separators=separators) 165 | zs_results_file_suffix = f'_editlayer_{eval_edit_layer}.json' 166 | 167 | 168 | 
print(f"Running {n_shots}-Shot Shuffled Eval") 169 | set_seed(seed) 170 | if generate_str: 171 | pred_filepath = f"{save_path_root}/preds/{model_config['name_or_path'].replace('/', '_')}_{n_shots}shots_shuffled_intervention_layer{eval_edit_layer}.txt" 172 | fs_shuffled_results = n_shot_eval(dataset=dataset, fv_vector=fv, edit_layer=eval_edit_layer, n_shots=n_shots, 173 | model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set, shuffle_labels=True, 174 | generate_str=generate_str, metric=metric, pred_filepath=pred_filepath, prefixes=prefixes, separators=separators) 175 | else: 176 | fs_shuffled_results = n_shot_eval(dataset=dataset, fv_vector=fv, edit_layer=eval_edit_layer, n_shots=n_shots, 177 | model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set, shuffle_labels=True, prefixes=prefixes, separators=separators) 178 | fs_shuffled_results_file_suffix = f'_editlayer_{eval_edit_layer}.json' 179 | 180 | else: 181 | print(f"Running sweep over layers {eval_edit_layer}") 182 | zs_results = {} 183 | fs_shuffled_results = {} 184 | for edit_layer in range(eval_edit_layer[0], eval_edit_layer[1]): 185 | set_seed(seed) 186 | if generate_str: 187 | zs_results[edit_layer] = n_shot_eval(dataset=dataset, fv_vector=fv, edit_layer=edit_layer, n_shots=0, 188 | model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set, 189 | generate_str=generate_str, metric=metric, prefixes=prefixes, separators=separators) 190 | else: 191 | zs_results[edit_layer] = n_shot_eval(dataset=dataset, fv_vector=fv, edit_layer=edit_layer, n_shots=0, prefixes=prefixes, separators=separators, 192 | model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set) 193 | set_seed(seed) 194 | if generate_str: 195 | fs_shuffled_results[edit_layer] = n_shot_eval(dataset=dataset, fv_vector=fv, edit_layer=edit_layer, n_shots=n_shots, 196 | model=model, model_config=model_config, tokenizer=tokenizer, filter_set = 
filter_set, 197 | generate_str=generate_str, metric=metric, shuffle_labels=True, prefixes=prefixes, separators=separators) 198 | else: 199 | fs_shuffled_results[edit_layer] = n_shot_eval(dataset=dataset, fv_vector=fv, edit_layer=edit_layer, n_shots=n_shots, 200 | model=model, model_config=model_config, tokenizer=tokenizer, filter_set = filter_set, shuffle_labels=True, prefixes=prefixes, separators=separators) 201 | zs_results_file_suffix = '_layer_sweep.json' 202 | fs_shuffled_results_file_suffix = '_layer_sweep.json' 203 | 204 | 205 | # Save results to files 206 | zs_results_file_name = make_valid_path_name(f'{save_path_root}/zs_results' + zs_results_file_suffix) 207 | args.zs_results_file_name = zs_results_file_name 208 | with open(zs_results_file_name, 'w') as results_file: 209 | json.dump(zs_results, results_file, indent=2) 210 | 211 | fs_shuffled_results_file_name = make_valid_path_name(f'{save_path_root}/fs_shuffled_results' + fs_shuffled_results_file_suffix) 212 | args.fs_shuffled_results_file_name = fs_shuffled_results_file_name 213 | with open(fs_shuffled_results_file_name, 'w') as results_file: 214 | json.dump(fs_shuffled_results, results_file, indent=2) 215 | 216 | if compute_baseline: 217 | print(f"Computing model baseline results for {n_shots}-shots") 218 | baseline_results = compute_dataset_baseline(dataset, model, model_config, tokenizer, n_shots=n_shots, seed=seed, prefixes=prefixes, separators=separators) 219 | 220 | baseline_file_name = make_valid_path_name(f'{save_path_root}/model_baseline.json') 221 | args.baseline_file_name = baseline_file_name 222 | with open(baseline_file_name, 'w') as results_file: 223 | json.dump(baseline_results, results_file, indent=2) 224 | 225 | # Write args to file 226 | args_file_name = make_valid_path_name(f'{save_path_root}/fv_eval_args.txt') 227 | with open(args_file_name, 'w') as arg_file: 228 | json.dump(args.__dict__, arg_file, indent=2) 229 | 
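A note on the layer-sweep results saved above: `zs_results` and `fs_shuffled_results` are keyed by integer edit layer, but `json.dump` stringifies dictionary keys, so any script that reloads a `_layer_sweep.json` file needs to cast the keys back to integers. A minimal self-contained sketch with toy data (not the repo's actual result schema):

```python
import json

# Toy stand-in for a layer-sweep results dict keyed by integer edit layer.
zs_results = {0: [0.1, 0.0], 9: [0.8, 1.0]}

serialized = json.dumps(zs_results, indent=2)

# JSON object keys are always strings, so restore the integer layer indices on load.
reloaded = {int(layer): scores for layer, scores in json.loads(serialized).items()}
assert reloaded == zs_results
```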
-------------------------------------------------------------------------------- /src/natural_text_eval.py: -------------------------------------------------------------------------------- 1 | import os, re, json 2 | import torch, numpy as np 3 | from tqdm import tqdm 4 | import argparse 5 | 6 | # Include prompt creation helper functions 7 | from utils.eval_utils import n_shot_eval_no_intervention 8 | from utils.extract_utils import get_mean_head_activations, compute_universal_function_vector 9 | from utils.intervention_utils import fv_intervention_natural_text 10 | from utils.model_utils import load_gpt_model_and_tokenizer 11 | from utils.prompt_utils import load_dataset 12 | 13 | 14 | def natural_text_eval(dataset, fv_vector, model, model_config, tokenizer, filter_set, edit_layer=9, MNT=5, verbose=False, loc_data=False): 15 | """ 16 | Evaluates the causal effects of a function vector on natural text templates for the given dataset. 17 | 18 | Parameters: 19 | dataset: ICL dataset with pairs of words 20 | fv_vector: function vector to use for intervention 21 | model: huggingface model 22 | model_config: contains model config information (n layers, n heads, etc.) 23 | tokenizer: huggingface tokenizer 24 | filter_set: list of samples to filter to, used to include samples the model gets correct via ICL 25 | edit_layer: the layer to add the function vector to 26 | MNT: max number of tokens to generate 27 | verbose: whether to print outputs of clean & +FV generations. 28 | loc_data: whether the dataset is locations (e.g. country-capital, national parks, etc.) 
29 | 30 | Returns: 31 | all_scores: scores for model when adding the FV during generation 32 | all_clean_scores: scores for base model (no intervention) 33 | sentences: sentence templates used during eval 34 | """ 35 | all_scores = {} 36 | all_clean_scores = {} 37 | 38 | if loc_data: # country-capital & similar datasets 39 | sentences = ["A couple years ago I visited {X}, and", 40 | "If you ever travel to {X}, you have to visit", 41 | "When you think of {X},"] 42 | else: 43 | sentences = ["The word \"{X}\", means", 44 | "When I think of the word \"{X}\", it usually means", 45 | "When I think of \"{X}\", I usually", 46 | "While reading a book, I came across the word \"{X}\". I looked it up in a dictionary and it turns out that it means", 47 | "The word \"{X}\" can be understood as a synonym for"] 48 | 49 | for j in range(len(sentences)): 50 | scores = [] 51 | clean_scores = [] 52 | for i in tqdm(range(len(filter_set)), total=len(filter_set)): 53 | ind = int(filter_set[i]) 54 | q_pair = dataset['test'][ind] 55 | if isinstance(q_pair['input'], list): 56 | q_pair['input'] = q_pair['input'][0] 57 | if isinstance(q_pair['output'], list): 58 | q_pair['output'] = q_pair['output'][0] 59 | 60 | sentence = sentences[j] 61 | sentence = sentence.replace('{X}', f"{q_pair['input']}") 62 | 63 | clean_output, fv_output = fv_intervention_natural_text(sentence, edit_layer, fv_vector, model, model_config, tokenizer, max_new_tokens=MNT) 64 | clean_out_str = repr(tokenizer.decode(clean_output.squeeze()[-MNT:])) 65 | fv_out_str = repr(tokenizer.decode(fv_output.squeeze()[-MNT:])) 66 | 67 | if verbose: 68 | print("\nQuery/Target: ", q_pair) 69 | print("Prompt: ", repr(sentence)) 70 | print("clean completion:" , clean_out_str) 71 | print("+FV completion:", fv_out_str, '\n') 72 | 73 | scores.append(int(q_pair['output'] in fv_out_str)) 74 | clean_scores.append(int(q_pair['output'] in clean_out_str)) 75 | 76 | all_scores[j] = scores 77 | all_clean_scores[j] = clean_scores 78 | 79 | return 
all_scores, all_clean_scores, sentences 80 | 81 | def nattext_main(datasets, model, model_config, tokenizer, root_data_dir='../dataset_files', edit_layer=9, n_shots=10, n_trials=100, n_seeds=5): 82 | """ 83 | Main function that evaluates causal effects of function vectors on natural text templates. 84 | 85 | Parameters: 86 | datasets: list of dataset names to evaluate 87 | model: huggingface model 88 | model_config: contains model config information (n layers, n heads, etc.) 89 | tokenizer: huggingface tokenizer 90 | root_data_dir: directory data is contained in 91 | edit_layer: layer to add the function vector to during intervention 92 | n_shots: number of shots for prompts used when computing task-conditioned mean head activations 93 | n_trials: number of prompts to include when computing task-conditioned mean head activations 94 | n_seeds: number of seeds to average results over 95 | 96 | Returns: 97 | clean_results_dict: dict containing results for base model (no intervention) 98 | interv_results_dict: results for model when adding the function vector at edit_layer during generation 99 | seeds_dict: dict containing the seeds used during evaluation 100 | """ 101 | interv_results_dict = {k:[] for k in datasets} 102 | clean_results_dict = {k:[] for k in datasets} 103 | seeds_dict = {k:[] for k in datasets} 104 | 105 | # Test Loop: 106 | for dataset_name in datasets: 107 | if dataset_name == 'country-capital': 108 | loc_data = True 109 | max_new_tokens = 10 110 | else: 111 | loc_data = False 112 | max_new_tokens = 5 113 | 114 | for _ in range(n_seeds): 115 | seed = np.random.randint(100000) 116 | seeds_dict[dataset_name].append(seed) 117 | dataset = load_dataset(dataset_name, seed=seed, root_data_dir=root_data_dir) 118 | 119 | fs_results_validation = n_shot_eval_no_intervention(dataset=dataset, n_shots=n_shots, model=model, model_config=model_config, tokenizer=tokenizer, compute_ppl=False, test_split='valid') 120 | filter_set_validation = 
np.where(np.array(fs_results_validation['clean_rank_list']) == 0)[0] 121 | 122 | mean_activations = get_mean_head_activations(dataset, model=model, model_config=model_config, tokenizer=tokenizer, n_icl_examples=n_shots, 123 | N_TRIALS=n_trials, filter_set=filter_set_validation) 124 | fv, _ = compute_universal_function_vector(mean_activations, model, model_config) 125 | 126 | fs_results = n_shot_eval_no_intervention(dataset=dataset, n_shots=n_shots, model=model, model_config=model_config, tokenizer=tokenizer, compute_ppl=False, test_split='test') 127 | filter_set = np.where(np.array(fs_results['clean_rank_list']) == 0)[0] 128 | 129 | results, clean_results, _ = natural_text_eval(dataset, fv, model, model_config, tokenizer, filter_set, MNT=max_new_tokens, edit_layer=edit_layer, verbose=False, loc_data=loc_data) 130 | 131 | clean_results_dict[dataset_name].append([np.mean(clean_results[i]) for i in clean_results.keys()]) 132 | interv_results_dict[dataset_name].append([np.mean(results[i]) for i in results.keys()]) 133 | 134 | return clean_results_dict, interv_results_dict, seeds_dict 135 | 136 | 137 | if __name__ == "__main__": 138 | 139 | parser = argparse.ArgumentParser() 140 | parser.add_argument('--model_name', help='Name of model to be loaded', type=str, required=False, default='EleutherAI/gpt-j-6b') 141 | parser.add_argument('--root_data_dir', help='Root directory of data files', type=str, required=False, default='../dataset_files') 142 | parser.add_argument('--save_path_root', help='File path to save mean activations to', type=str, required=False, default='../results') 143 | parser.add_argument('--n_seeds', help='Number of seeds', type=int, required=False, default=5) 144 | parser.add_argument('--n_trials', help='Number of trials to use for computing task-conditioned mean head activations', type=int, required=False, default=100) 145 | parser.add_argument('--n_shots', help='Number of shots to use for prompts when computing task-conditioned mean head activations', 
type=int, required=False, default=10) 146 | parser.add_argument('--edit_layer', help='Layer to add function vector to', type=int, required=False, default=9) 147 | 148 | args = parser.parse_args() 149 | 150 | # Gather inputs 151 | model_name = args.model_name 152 | root_data_dir = args.root_data_dir 153 | save_path_root = args.save_path_root 154 | n_seeds = args.n_seeds 155 | n_trials = args.n_trials 156 | n_shots = args.n_shots 157 | edit_layer = args.edit_layer 158 | 159 | 160 | # Load Model & Tokenizer 161 | torch.set_grad_enabled(False) 162 | model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name) 163 | 164 | datasets = ['antonym', 'capitalize', 'country-capital', 'english-french', 'present-past', 'singular-plural'] 165 | args.datasets = datasets 166 | 167 | # Run Natural Text Eval 168 | clean_results_dict, interv_results_dict, seeds_dict = nattext_main(datasets, model, model_config, tokenizer, 169 | root_data_dir=root_data_dir, edit_layer=edit_layer, 170 | n_shots=n_shots, n_trials=n_trials, n_seeds=n_seeds) 171 | 172 | # Extract Summary Results: 173 | os.makedirs(os.path.join(save_path_root), exist_ok=True) 174 | with open(os.path.join(save_path_root, 'nattext_eval_results.txt'), 'w') as out_file: 175 | for d in datasets: 176 | print(f"{d.title()}:", file=out_file) 177 | clean_acc = np.array(clean_results_dict[d]).mean(axis=0) 178 | clean_std = np.array(clean_results_dict[d]).std(axis=0) 179 | fv_acc = np.array(interv_results_dict[d]).mean(axis=0) 180 | fv_std = np.array(interv_results_dict[d]).std(axis=0) 181 | 182 | print("clean results:", clean_acc.round(3)*100, '% +/-', clean_std.round(3)*100, file=out_file) 183 | print("fv results:", fv_acc.round(3)*100, '% +/-', fv_std.round(3)*100, file=out_file) 184 | 185 | # Write args to a file 186 | args.seeds_dict = seeds_dict 187 | with open(os.path.join(save_path_root, 'nattext_eval_args.txt'), 'w') as arg_file: 188 | print(args.__dict__, file=arg_file) 189 | 190 | 191 | 
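To make the scoring in `natural_text_eval` above concrete: each template's `{X}` slot is filled with the query word via `str.replace`, and a completion counts as correct when the target string occurs anywhere in the generated text. A self-contained sketch using a hypothetical antonym pair and canned strings standing in for the model generations:

```python
# Hypothetical query/target pair; real runs draw these from the dataset's test split.
q_pair = {"input": "hot", "output": "cold"}

template = 'The word "{X}", means'
sentence = template.replace("{X}", q_pair["input"])  # fill the {X} slot

# Canned completions stand in for the clean and +FV model generations.
clean_out_str = " warm to the touch"
fv_out_str = " the opposite of cold"

# Substring-membership scoring, as in natural_text_eval.
clean_score = int(q_pair["output"] in clean_out_str)  # target absent -> 0
fv_score = int(q_pair["output"] in fv_out_str)        # target present -> 1
```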
-------------------------------------------------------------------------------- /src/portability_eval.py: -------------------------------------------------------------------------------- 1 | import os, json 2 | import torch, numpy as np 3 | import argparse 4 | 5 | # Include prompt creation helper functions 6 | from utils.prompt_utils import * 7 | from utils.intervention_utils import * 8 | from utils.model_utils import * 9 | from utils.eval_utils import * 10 | from utils.extract_utils import * 11 | 12 | 13 | if __name__ == "__main__": 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument('--dataset_name', help='Name of the dataset to be loaded', type=str, required=True) 18 | parser.add_argument('--n_eval_templates', help='Number of templates to evaluate with', required=True, type=int, default=15) 19 | parser.add_argument('--edit_layer', help='Layer for intervention. If -1, sweep over all layers', type=int, required=False, default=9) # 20 | 21 | parser.add_argument('--n_top_heads', help='Number of attention head outputs used to compute function vector', required=False, type=int, default=10) 22 | parser.add_argument('--model_name', help='Name of model to be loaded', type=str, required=False, default='EleutherAI/gpt-j-6b') 23 | parser.add_argument('--root_data_dir', help='Root directory of data files', type=str, required=False, default='../dataset_files') 24 | parser.add_argument('--save_path_root', help='File path to save to', type=str, required=False, default='../results') 25 | parser.add_argument('--seed', help='Randomized seed', type=int, required=False, default=5678) 26 | parser.add_argument('--device', help='Device to run on',type=str, required=False, default='cuda' if torch.cuda.is_available() else 'cpu') 27 | parser.add_argument('--mean_activations_path', help='Path to file containing mean_head_activations for the specified task', required=False, type=str, default=None) 28 | parser.add_argument('--test_split', help="Percentage corresponding
to test set split size", type=float, required=False, default=0.3) 29 | parser.add_argument('--n_shots', help="Number of shots in each in-context prompt", type=int, required=False, default=10) 30 | parser.add_argument('--n_trials', help="Number of in-context prompts to average over for indirect_effect", type=int, required=False, default=25) 31 | parser.add_argument('--prefixes', help='Prompt template prefixes to be used', type=json.loads, required=False, default={"input":"Q:", "output":"A:", "instructions":""}) 32 | parser.add_argument('--separators', help='Prompt template separators to be used', type=json.loads, required=False, default={"input":"\n", "output":"\n\n", "instructions":""}) 33 | 34 | args = parser.parse_args() 35 | 36 | dataset_name = args.dataset_name 37 | model_name = args.model_name 38 | root_data_dir = args.root_data_dir 39 | save_path_root = f"{args.save_path_root}/{dataset_name}" 40 | seed = args.seed 41 | device = args.device 42 | mean_activations_path = args.mean_activations_path 43 | n_top_heads = args.n_top_heads 44 | eval_edit_layer = args.edit_layer 45 | 46 | test_split = args.test_split 47 | n_shots = args.n_shots 48 | n_trials = args.n_trials 49 | 50 | prefixes = args.prefixes 51 | separators = args.separators 52 | 53 | n_eval_templates = args.n_eval_templates 54 | 55 | print(args) 56 | 57 | # Load Model & Tokenizer 58 | torch.set_grad_enabled(False) 59 | print("Loading Model") 60 | model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name, device=device) 61 | 62 | if args.edit_layer == -1: # sweep over all layers if edit_layer=-1 63 | eval_edit_layer = [0, model_config['n_layers']] 64 | 65 | # Load the dataset 66 | print("Loading Dataset") 67 | set_seed(seed) 68 | dataset = load_dataset(dataset_name, root_data_dir=root_data_dir, test_size=test_split, seed=seed) 69 | 70 | if not os.path.exists(save_path_root): 71 | os.makedirs(save_path_root) 72 | 73 | # Load or Re-Compute mean_head_activations 74 | if mean_activations_path is not None
and os.path.exists(mean_activations_path): 75 | mean_activations = torch.load(mean_activations_path) 76 | elif mean_activations_path is None and os.path.exists(f'{save_path_root}/{dataset_name}_mean_head_activations.pt'): 77 | mean_activations_path = f'{save_path_root}/{dataset_name}_mean_head_activations.pt' 78 | mean_activations = torch.load(mean_activations_path) 79 | else: 80 | print("Computing Mean Activations") 81 | set_seed(seed) 82 | mean_activations = get_mean_head_activations(dataset, model=model, model_config=model_config, tokenizer=tokenizer, 83 | n_icl_examples=n_shots, N_TRIALS=n_trials, prefixes=prefixes, separators=separators) 84 | args.mean_activations_path = f'{save_path_root}/{dataset_name}_mean_head_activations.pt' 85 | torch.save(mean_activations, args.mean_activations_path) 86 | 87 | # Compute Function Vector 88 | fv, top_heads = compute_universal_function_vector(mean_activations, model, model_config=model_config, n_top_heads=n_top_heads) 89 | 90 | print("Computing Portability") 91 | fs_res_dict, zs_res_dict,fs_shuffled_res_dict, templates = portability_eval(dataset, fv, eval_edit_layer, model, model_config, tokenizer, n_eval_templates=n_eval_templates) 92 | 93 | args.templates = templates 94 | 95 | save_path_root = f"{args.save_path_root}_port/{dataset_name}" 96 | if not os.path.exists(save_path_root): 97 | os.makedirs(save_path_root) 98 | 99 | fs_results_file_name = make_valid_path_name(f'{save_path_root}/fs_port_eval.json') 100 | args.fs_results_file_name = fs_results_file_name 101 | with open(fs_results_file_name,'w') as fs_results_file: 102 | json.dump(fs_res_dict, fs_results_file,indent=2) 103 | 104 | fs_shuffled_results_file_name = make_valid_path_name(f'{save_path_root}/fs_shuffled_port_eval.json') 105 | args.fs_shuffled_results_file_name = fs_shuffled_results_file_name 106 | with open(fs_shuffled_results_file_name,'w') as fs_shuffled_results_file: 107 | json.dump(fs_shuffled_res_dict, fs_shuffled_results_file,indent=2) 108 | 109 | 
zs_results_file_name = make_valid_path_name(f'{save_path_root}/zs_port_eval.json') 110 | args.zs_results_file_name = zs_results_file_name 111 | with open(zs_results_file_name,'w') as zs_results_file: 112 | json.dump(zs_res_dict, zs_results_file,indent=2) 113 | 114 | args_file_name = make_valid_path_name(f'{save_path_root}/port_eval_args.txt') 115 | with open(args_file_name, 'w') as arg_file: 116 | json.dump(args.__dict__, arg_file, indent=2) -------------------------------------------------------------------------------- /src/test_numheads.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import numpy as np 5 | import torch 6 | 7 | from src.utils.eval_utils import n_shot_eval, n_shot_eval_no_intervention 8 | from src.utils.model_utils import load_gpt_model_and_tokenizer, set_seed 9 | from src.utils.prompt_utils import load_dataset 10 | from src.evaluate_function_vector import compute_universal_function_vector 11 | 12 | # Evaluates how performance changes as the number of heads used to create a Function Vector increases 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser() 15 | 16 | parser.add_argument('--dataset_name', help="dataset to be evaluated", type=str, required=True) 17 | parser.add_argument('--mean_act_root', help="root path to mean activations", type=str, required=False, default='IE_template_QA/gptj') 18 | parser.add_argument('--model_name', type=str, required=True, default='EleutherAI/gpt-j-6b') 19 | parser.add_argument('--model_nickname', type=str, required=False, default='gptj') 20 | parser.add_argument('--n_heads', type=int, help="upper bound of the number of heads to create the FV", required=True, default=40) 21 | parser.add_argument('--edit_layer', type=int, help="layer at which to add the function vector", required=True, default=9) 22 | parser.add_argument('--seed', required=False, type=int, default=42) 23 | 
parser.add_argument('--save_path_root', required=True, type=str, default='../results') 24 | 25 | 26 | args = parser.parse_args() 27 | mean_act_root = args.mean_act_root 28 | model_name = args.model_name 29 | model_nickname = args.model_nickname 30 | dataset_name = args.dataset_name 31 | n_heads = args.n_heads 32 | edit_layer = args.edit_layer 33 | seed = args.seed 34 | save_path_root = args.save_path_root 35 | 36 | 37 | # Load Model & Tokenizer, doing inference so don't need gradients 38 | torch.set_grad_enabled(False) 39 | model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name) 40 | dataset = load_dataset(dataset_name) 41 | mean_activations = torch.load(f'{save_path_root}/{mean_act_root}/{dataset_name}/{dataset_name}_mean_head_activations.pt') 42 | 43 | 44 | set_seed(seed) 45 | fs_results = n_shot_eval_no_intervention(dataset, n_shots=10, model=model, model_config=model_config, tokenizer=tokenizer) 46 | filter_set = np.where(np.array(fs_results['clean_rank_list']) == 0)[0] 47 | print("Sanity Check, cleantopk: ", fs_results['clean_topk']) 48 | zs_results = {} 49 | 50 | for i in range(n_heads+1): 51 | fv, _ = compute_universal_function_vector(mean_activations, model, model_config, i) 52 | zs_results[i] = n_shot_eval(dataset, fv, edit_layer, 0, model, model_config, tokenizer, filter_set=filter_set) 53 | 54 | 55 | os.makedirs(f'{save_path_root}/{model_nickname}_test_numheads', exist_ok=True) 56 | json.dump(zs_results, open(f'{save_path_root}/{model_nickname}_test_numheads/{dataset_name}_perf_v_heads.json', 'w')) -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ericwtodd/function_vectors/751e2219d304eba471cffcacc9efd89a4f8ef3c4/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/intervention_utils.py: 
-------------------------------------------------------------------------------- 1 | from baukit import TraceDict 2 | import torch 3 | import re 4 | import bitsandbytes as bnb 5 | 6 | def get_module(model, name): 7 | """ 8 | Finds the named module within the given model. 9 | """ 10 | for n, m in model.named_modules(): 11 | if n == name: 12 | return m 13 | raise LookupError(name) 14 | 15 | 16 | def replace_activation_w_avg(layer_head_token_pairs, avg_activations, model, model_config, idx_map, batched_input=False, last_token_only=False): 17 | """ 18 | An intervention function for replacing activations with a computed average value. 19 | This function replaces the output of one (or several) attention head(s) with a pre-computed average value 20 | (usually taken from another set of runs with a particular property). 21 | The batched_input flag is used for systematic interventions where we are sweeping over all attention heads for a given (layer,token) 22 | The last_token_only flag is used for interventions where we only intervene on the last token (such as zero-shot or concept-naming) 23 | 24 | Parameters: 25 | layer_head_token_pairs: list of tuple triplets each containing a layer index, head index, and token index [(L,H,T), ...] 26 | avg_activations: torch tensor of the average activations (across ICL prompts) for each attention head of the model. 27 | model: huggingface model 28 | model_config: contains model config information (n layers, n heads, etc.) 29 | idx_map: dict mapping prompt label indices to ground truth label indices 30 | batched_input: whether or not to batch the intervention across all heads 31 | last_token_only: whether our intervention is only at the last token 32 | 33 | Returns: 34 | rep_act: A function that specifies how to replace activations with an average when given a hooked pytorch module.
35 | """ 36 | edit_layers = [x[0] for x in layer_head_token_pairs] 37 | 38 | def rep_act(output, layer_name, inputs): 39 | current_layer = int(layer_name.split('.')[2]) 40 | if current_layer in edit_layers: 41 | if isinstance(inputs, tuple): 42 | inputs = inputs[0] 43 | 44 | # Determine shapes for intervention 45 | original_shape = inputs.shape 46 | new_shape = inputs.size()[:-1] + (model_config['n_heads'], model_config['resid_dim']//model_config['n_heads']) # split by head: + (n_attn_heads, hidden_size/n_attn_heads) 47 | inputs = inputs.view(*new_shape) # inputs shape: (batch_size , tokens (n), heads, hidden_dim) 48 | 49 | # Perform Intervention: 50 | if batched_input: 51 | # Patch activations from avg activations into baseline sentences (i.e. n_head baseline sentences being modified in this case) 52 | for i in range(model_config['n_heads']): 53 | layer, head_n, token_n = layer_head_token_pairs[i] 54 | inputs[i, token_n, head_n] = avg_activations[layer, head_n, idx_map[token_n]] 55 | elif last_token_only: 56 | # Patch activations only at the last token for interventions like: (zero-shot, concept-naming, etc.) 57 | for (layer,head_n,token_n) in layer_head_token_pairs: 58 | if layer == current_layer: 59 | inputs[-1,-1,head_n] = avg_activations[layer,head_n,idx_map[token_n]] 60 | else: 61 | # Patch activations into baseline sentence found at index, -1 of the batch (targeted & multi-token patching) 62 | for (layer, head_n, token_n) in layer_head_token_pairs: 63 | if layer == current_layer: 64 | inputs[-1, token_n, head_n] = avg_activations[layer,head_n,idx_map[token_n]] 65 | 66 | inputs = inputs.view(*original_shape) 67 | proj_module = get_module(model, layer_name) 68 | out_proj = proj_module.weight 69 | 70 | if 'gpt2-xl' in model_config['name_or_path']: # GPT2-XL uses Conv1D (not nn.Linear) & has a bias term, GPTJ does not 71 | out_proj_bias = proj_module.bias 72 | new_output = torch.addmm(out_proj_bias, inputs.squeeze(), out_proj) 73 | 74 | elif 'gpt-j' in model_config['name_or_path'] or 'gemma' in
model_config['name_or_path']: 75 | new_output = torch.matmul(inputs, out_proj.T) 76 | 77 | elif 'gpt-neox' in model_config['name_or_path'] or 'pythia' in model_config['name_or_path']: 78 | out_proj_bias = proj_module.bias 79 | new_output = torch.addmm(out_proj_bias, inputs.squeeze(), out_proj.T) 80 | 81 | elif 'llama' in model_config['name_or_path']: 82 | if '70b' in model_config['name_or_path']: 83 | # need to dequantize weights 84 | out_proj_dequant = bnb.functional.dequantize_4bit(out_proj.data, out_proj.quant_state) 85 | new_output = torch.matmul(inputs, out_proj_dequant.T) 86 | else: 87 | new_output = torch.matmul(inputs, out_proj.T) 88 | 89 | elif 'olmo' in model_config['name_or_path'].lower(): 90 | new_output = torch.matmul(inputs, out_proj.T) 91 | 92 | return new_output 93 | else: 94 | return output 95 | 96 | return rep_act 97 | 98 | def add_function_vector(edit_layer, fv_vector, device, idx=-1): 99 | """ 100 | Adds a vector to the output of a specified layer in the model 101 | 102 | Parameters: 103 | edit_layer: the layer to perform the FV intervention 104 | fv_vector: the function vector to add as an intervention 105 | device: device of the model (cuda gpu or cpu) 106 | idx: the token index to add the function vector at 107 | 108 | Returns: 109 | add_act: a function specifying how to add a function vector to a layer's output hidden state 110 | """ 111 | def add_act(output, layer_name): 112 | current_layer = int(layer_name.split(".")[2]) 113 | if current_layer == edit_layer: 114 | if isinstance(output, tuple): 115 | output[0][:, idx] += fv_vector.to(device) 116 | return output 117 | else: 118 | return output 119 | else: 120 | return output 121 | 122 | return add_act 123 | 124 | def function_vector_intervention(sentence, target, edit_layer, function_vector, model, model_config, tokenizer, compute_nll=False, 125 | generate_str=False): 126 | """ 127 | Runs the model on the sentence and adds the function_vector to the output of edit_layer as a model
intervention, predicting a single token. 128 | Returns the output of the model with and without intervention. 129 | 130 | Parameters: 131 | sentence: the sentence to be run through the model 132 | target: expected response of the model (str, or [str]) 133 | edit_layer: layer at which to add the function vector 134 | function_vector: torch vector that triggers execution of a task 135 | model: huggingface model 136 | model_config: contains model config information (n layers, n heads, etc.) 137 | tokenizer: huggingface tokenizer 138 | compute_nll: whether to compute the negative log likelihood of a teacher-forced completion (used to compute perplexity (PPL)) 139 | generate_str: whether to generate a string of tokens or predict a single token 140 | 141 | Returns: 142 | fvi_output: a tuple containing output results of a clean run and intervened run of the model 143 | """ 144 | # Clean Run, No Intervention: 145 | device = model.device 146 | inputs = tokenizer(sentence, return_tensors='pt').to(device) 147 | original_pred_idx = len(inputs.input_ids.squeeze()) - 1 148 | 149 | if compute_nll: 150 | target_completion = "".join(sentence + target) 151 | nll_inputs = tokenizer(target_completion, return_tensors='pt').to(device) 152 | nll_targets = nll_inputs.input_ids.clone() 153 | target_len = len(nll_targets.squeeze()) - len(inputs.input_ids.squeeze()) 154 | nll_targets[:,:-target_len] = -100 # This is the accepted value to skip indices when computing loss (see nn.CrossEntropyLoss default) 155 | output = model(**nll_inputs, labels=nll_targets) 156 | clean_nll = output.loss.item() 157 | clean_output = output.logits[:,original_pred_idx,:] 158 | intervention_idx = -1 - target_len 159 | elif generate_str: 160 | MAX_NEW_TOKENS = 16 161 | output = model.generate(inputs.input_ids, top_p=0.9, temperature=0.1, 162 | max_new_tokens=MAX_NEW_TOKENS) 163 | clean_output = tokenizer.decode(output.squeeze()[-MAX_NEW_TOKENS:]) 164 | intervention_idx = -1 165 | else: 166 | clean_output = 
model(**inputs).logits[:,-1,:] 167 | intervention_idx = -1 168 | 169 | # Perform Intervention 170 | intervention_fn = add_function_vector(edit_layer, function_vector.reshape(1, model_config['resid_dim']), model.device, idx=intervention_idx) 171 | with TraceDict(model, layers=model_config['layer_hook_names'], edit_output=intervention_fn): 172 | if compute_nll: 173 | output = model(**nll_inputs, labels=nll_targets) 174 | intervention_nll = output.loss.item() 175 | intervention_output = output.logits[:,original_pred_idx,:] 176 | elif generate_str: 177 | output = model.generate(inputs.input_ids, top_p=0.9, temperature=0.1, 178 | max_new_tokens=MAX_NEW_TOKENS) 179 | intervention_output = tokenizer.decode(output.squeeze()[-MAX_NEW_TOKENS:]) 180 | else: 181 | intervention_output = model(**inputs).logits[:,-1,:] # batch_size x n_tokens x vocab_size, only want last token prediction 182 | 183 | fvi_output = (clean_output, intervention_output) 184 | if compute_nll: 185 | fvi_output += (clean_nll, intervention_nll) 186 | 187 | return fvi_output 188 | 189 | 190 | def fv_intervention_natural_text(sentence, edit_layer, function_vector, model, model_config, tokenizer, max_new_tokens=16, num_interv_tokens=None, do_sample=False): 191 | """ 192 | Allows for intervention in natural text where we generate and intervene on several tokens in a row. 193 | 194 | Parameters: 195 | sentence: sentence to intervene on with the FV 196 | edit_layer: layer at which to add the function vector 197 | function_vector: vector to add to the model that triggers execution of a task 198 | model: huggingface model 199 | model_config: dict with model config parameters (n_layers, n_heads, etc.) 
200 | tokenizer: huggingface tokenizer 201 | max_new_tokens: number of tokens to generate 202 | num_interv_tokens: number of tokens to apply the intervention for (defaults to all subsequent generations) 203 | do_sample: whether to sample from top p tokens (True) or have deterministic greedy decoding (False) 204 | 205 | Returns: 206 | clean_output: tokens of clean output 207 | intervention_output: tokens of intervention output 208 | 209 | """ 210 | # Clean Run, No Intervention: 211 | device = model.device 212 | inputs = tokenizer(sentence, return_tensors='pt').to(device) 213 | clean_output = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, pad_token_id=tokenizer.eos_token_id) 214 | 215 | # Perform Intervention 216 | intervention_fn = add_function_vector(edit_layer, function_vector, model.device) 217 | 218 | if num_interv_tokens is not None and num_interv_tokens < max_new_tokens: # Intervene only for a certain number of tokens 219 | num_extra_tokens = max_new_tokens - num_interv_tokens 220 | with TraceDict(model, layers=model_config['layer_hook_names'], edit_output=intervention_fn): 221 | intervention_output = model.generate(**inputs, max_new_tokens = num_interv_tokens, do_sample=do_sample, pad_token_id=tokenizer.eos_token_id) 222 | intervention_output = model.generate(intervention_output, max_new_tokens=num_extra_tokens, pad_token_id=tokenizer.eos_token_id, do_sample=do_sample) 223 | else: 224 | with TraceDict(model, layers=model_config['layer_hook_names'], edit_output=intervention_fn): 225 | intervention_output = model.generate(**inputs, max_new_tokens = max_new_tokens, do_sample=do_sample, pad_token_id=tokenizer.eos_token_id) 226 | 227 | return clean_output, intervention_output 228 | 229 | 230 | def add_avg_to_activation(layer_head_token_pairs, avg_activations, model, model_config, batched_input=False, last_token_only=False): 231 | """ 232 | An intervention function for adding a computed average value to activations. 
233 | This function adds a pre-computed average value to the output of one (or several) attention head(s) 234 | (usually taken from another set of runs with a particular property). 235 | The batched_input flag is used for systematic interventions where we are sweeping over all attention heads for a given (layer,token) 236 | The last_token_only flag is used for interventions where we only intervene on the last token (such as zero-shot or concept-naming) 237 | 238 | Parameters: 239 | layer_head_token_pairs: list of tuple triplets each containing a layer index, head index, and token index [(L,H,T), ...] 240 | avg_activations: torch tensor of the average activations (across ICL prompts) for each attention head of the model. 241 | model: huggingface model 242 | model_config: contains model config information (n layers, n heads, etc.) 243 | batched_input: whether or not to batch the intervention across all heads 244 | last_token_only: whether our intervention is only at the last token 245 | 246 | Returns: 247 | add_act: A function that specifies how to replace activations with an average when given a hooked pytorch module. 248 | """ 249 | edit_layers = [x[0] for x in layer_head_token_pairs] 250 | device = model.device 251 | 252 | def add_act(output, layer_name, inputs): 253 | current_layer = int(layer_name.split('.')[2]) 254 | if current_layer in edit_layers: 255 | if isinstance(inputs, tuple): 256 | inputs = inputs[0] 257 | 258 | # Determine shapes for intervention 259 | original_shape = inputs.shape 260 | new_shape = inputs.size()[:-1] + (model_config['n_heads'], model_config['resid_dim']//model_config['n_heads']) # split by head: + (n_attn_heads, hidden_size/n_attn_heads) 261 | inputs = inputs.view(*new_shape) # inputs shape: (batch_size , tokens (n), heads, hidden_dim) 262 | 263 | # Perform Intervention: 264 | if batched_input: 265 | # Patch activations from avg activations into baseline sentences (i.e. 
n_head baseline sentences being modified in this case) 266 | for i in range(model_config['n_heads']): 267 | layer, head_n, token_n = layer_head_token_pairs[i] 268 | inputs[i, token_n, head_n] += avg_activations[layer, head_n, token_n].to(device) 269 | elif last_token_only: 270 | # Patch activations only at the last token for interventions like: (zero-shot, concept-naming, etc.) 271 | for (layer,head_n,token_n) in layer_head_token_pairs: 272 | if layer == current_layer: 273 | inputs[-1,-1,head_n] += avg_activations[layer,head_n,token_n].to(device) 274 | else: 275 | # Patch activations into baseline sentence found at index, -1 of the batch (targeted & multi-token patching) 276 | for (layer, head_n, token_n) in layer_head_token_pairs: 277 | if layer == current_layer: 278 | inputs[-1, token_n, head_n] += avg_activations[layer,head_n,token_n].to(device) 279 | 280 | inputs = inputs.view(*original_shape) 281 | proj_module = get_module(model, layer_name) 282 | out_proj = proj_module.weight 283 | 284 | if 'gpt2-xl' in model_config['name_or_path']: # GPT2-XL uses Conv1D (not nn.Linear) & has a bias term, GPTJ does not 285 | out_proj_bias = proj_module.bias 286 | new_output = torch.addmm(out_proj_bias, inputs.squeeze(), out_proj) 287 | 288 | elif 'gpt-j' in model_config['name_or_path'] or 'gemma' in model_config['name_or_path']: 289 | new_output = torch.matmul(inputs, out_proj.T) 290 | 291 | elif 'gpt-neox' in model_config['name_or_path'] or 'pythia' in model_config['name_or_path']: 292 | out_proj_bias = proj_module.bias 293 | new_output = torch.addmm(out_proj_bias, inputs.squeeze(), out_proj.T) 294 | 295 | elif 'llama' in model_config['name_or_path']: 296 | if '70b' in model_config['name_or_path']: 297 | # need to dequantize weights 298 | out_proj_dequant = bnb.functional.dequantize_4bit(out_proj.data, out_proj.quant_state) 299 | new_output = torch.matmul(inputs, out_proj_dequant.T) 300 | else: 301 | new_output = torch.matmul(inputs, out_proj.T) 302 | 303 | elif 'olmo' in 
model_config['name_or_path'].lower(): 304 | new_output = torch.matmul(inputs, out_proj.T) 305 | 306 | return new_output 307 | else: 308 | return output 309 | 310 | return add_act -------------------------------------------------------------------------------- /src/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LlamaForCausalLM 4 | import os 5 | import random 6 | from typing import * 7 | 8 | 9 | def load_gpt_model_and_tokenizer(model_name:str, device='cuda', revision=None): 10 | """ 11 | Loads a huggingface model and its tokenizer 12 | 13 | Parameters: 14 | model_name: huggingface name of the model to load (e.g. GPTJ: "EleutherAI/gpt-j-6B", or "EleutherAI/gpt-j-6b") 15 | device: 'cuda' or 'cpu' 16 | 17 | Returns: 18 | model: huggingface model 19 | tokenizer: huggingface tokenizer 20 | MODEL_CONFIG: config variables w/ standardized names 21 | 22 | """ 23 | assert model_name is not None 24 | 25 | print("Loading: ", model_name) 26 | 27 | if model_name == 'gpt2-xl': 28 | tokenizer = AutoTokenizer.from_pretrained(model_name) 29 | tokenizer.pad_token = tokenizer.eos_token 30 | model = AutoModelForCausalLM.from_pretrained(model_name).to(device) 31 | 32 | MODEL_CONFIG={"n_heads":model.config.n_head, 33 | "n_layers":model.config.n_layer, 34 | "resid_dim":model.config.n_embd, 35 | "name_or_path":model.config.name_or_path, 36 | "attn_hook_names":[f'transformer.h.{layer}.attn.c_proj' for layer in range(model.config.n_layer)], 37 | "layer_hook_names":[f'transformer.h.{layer}' for layer in range(model.config.n_layer)], 38 | "prepend_bos":False} 39 | 40 | elif 'gpt-j' in model_name.lower(): 41 | tokenizer = AutoTokenizer.from_pretrained(model_name) 42 | tokenizer.pad_token = tokenizer.eos_token 43 | model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True).to(device) 44 | 45 | 
MODEL_CONFIG={"n_heads":model.config.n_head, 46 | "n_layers":model.config.n_layer, 47 | "resid_dim":model.config.n_embd, 48 | "name_or_path":model.config.name_or_path, 49 | "attn_hook_names":[f'transformer.h.{layer}.attn.out_proj' for layer in range(model.config.n_layer)], 50 | "layer_hook_names":[f'transformer.h.{layer}' for layer in range(model.config.n_layer)], 51 | "prepend_bos":False} 52 | 53 | elif 'gpt-neox' in model_name.lower() or 'pythia' in model_name.lower(): 54 | tokenizer = AutoTokenizer.from_pretrained(model_name) 55 | tokenizer.pad_token = tokenizer.eos_token 56 | if revision is not None and 'pythia' in model_name.lower(): 57 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, revision=revision).to(device) 58 | else: 59 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device) 60 | 61 | MODEL_CONFIG={"n_heads":model.config.num_attention_heads, 62 | "n_layers":model.config.num_hidden_layers, 63 | "resid_dim": model.config.hidden_size, 64 | "name_or_path":model.config.name_or_path, 65 | "attn_hook_names":[f'gpt_neox.layers.{layer}.attention.dense' for layer in range(model.config.num_hidden_layers)], 66 | "layer_hook_names":[f'gpt_neox.layers.{layer}' for layer in range(model.config.num_hidden_layers)], 67 | "prepend_bos":False} 68 | 69 | elif 'gemma' in model_name.lower(): 70 | tokenizer = AutoTokenizer.from_pretrained(model_name) 71 | tokenizer.pad_token = tokenizer.eos_token 72 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device) 73 | 74 | MODEL_CONFIG={"n_heads":model.config.num_attention_heads, 75 | "n_layers":model.config.num_hidden_layers, 76 | "resid_dim":model.config.hidden_size, 77 | "name_or_path":model.config._name_or_path, 78 | "attn_hook_names":[f'model.layers.{layer}.self_attn.o_proj' for layer in range(model.config.num_hidden_layers)], 79 | "layer_hook_names":[f'model.layers.{layer}' for layer in 
range(model.config.num_hidden_layers)], 80 | "prepend_bos":True} 81 | 82 | elif 'llama' in model_name.lower(): 83 | if '70b' in model_name.lower(): 84 | # use quantization. requires `bitsandbytes` library 85 | from transformers import BitsAndBytesConfig 86 | bnb_config = BitsAndBytesConfig( 87 | load_in_4bit=True, 88 | bnb_4bit_quant_type='nf4', 89 | bnb_4bit_use_double_quant=True, 90 | bnb_4bit_compute_dtype=torch.float16 91 | ) 92 | tokenizer = LlamaTokenizer.from_pretrained(model_name) 93 | model = LlamaForCausalLM.from_pretrained( 94 | model_name, 95 | trust_remote_code=True, 96 | quantization_config=bnb_config 97 | ) 98 | else: 99 | if '7b' in model_name.lower() or '8b' in model_name.lower(): 100 | model_dtype = torch.float32 101 | else: #half precision for bigger llama models 102 | model_dtype = torch.float16 103 | 104 | # If transformers version is < 4.31 use LlamaLoaders 105 | # tokenizer = LlamaTokenizer.from_pretrained(model_name) 106 | # model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype).to(device) 107 | 108 | # If transformers version is >= 4.31, use AutoLoaders 109 | tokenizer = AutoTokenizer.from_pretrained(model_name) 110 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype).to(device) 111 | 112 | MODEL_CONFIG={"n_heads":model.config.num_attention_heads, 113 | "n_layers":model.config.num_hidden_layers, 114 | "resid_dim":model.config.hidden_size, 115 | "name_or_path":model.config._name_or_path, 116 | "attn_hook_names":[f'model.layers.{layer}.self_attn.o_proj' for layer in range(model.config.num_hidden_layers)], 117 | "layer_hook_names":[f'model.layers.{layer}' for layer in range(model.config.num_hidden_layers)], 118 | "prepend_bos":True} 119 | elif "olmo" in model_name.lower(): 120 | 121 | model_dtype = torch.float32 122 | tokenizer = AutoTokenizer.from_pretrained(model_name) 123 | if revision is not None: 124 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype, 
revision=revision).to(device) 125 | else: 126 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype).to(device) 127 | 128 | MODEL_CONFIG={"n_heads":model.config.num_attention_heads, 129 | "n_layers":model.config.num_hidden_layers, 130 | "resid_dim":model.config.hidden_size, 131 | "name_or_path":model.config._name_or_path, 132 | "attn_hook_names":[f'model.layers.{layer}.self_attn.o_proj' for layer in range(model.config.num_hidden_layers)], 133 | "layer_hook_names":[f'model.layers.{layer}' for layer in range(model.config.num_hidden_layers)], 134 | "prepend_bos":False} 135 | else: 136 | raise NotImplementedError("Still working to get this model available!") 137 | 138 | 139 | return model, tokenizer, MODEL_CONFIG 140 | 141 | def set_seed(seed: int) -> None: 142 | """ 143 | Sets the seed to make everything deterministic, for reproducibility of experiments 144 | 145 | Parameters: 146 | seed: the number to set the seed to 147 | 148 | Return: None 149 | """ 150 | 151 | # Random seed 152 | random.seed(seed) 153 | 154 | # Numpy seed 155 | np.random.seed(seed) 156 | 157 | # Torch seed 158 | torch.manual_seed(seed) 159 | torch.cuda.manual_seed(seed) 160 | torch.backends.cudnn.deterministic = True 161 | torch.backends.cudnn.benchmark = False # benchmark autotunes kernels and is non-deterministic 162 | 163 | # os seed 164 | os.environ['PYTHONHASHSEED'] = str(seed) -------------------------------------------------------------------------------- /src/utils/prompt_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from pathlib import Path 4 | import os 5 | from typing import * 6 | from sklearn.model_selection import train_test_split 7 | 8 | 9 | 10 | def create_fewshot_primer(prompt_data) -> str: 11 | """Creates the primer string for GPT in-context learning 12 | 13 | Parameters: 14 | prompt_data: dict containing ICL prompt examples, and template information 15 | 16 | Returns: 17 | prompt: the constructed ICL prompt
primer as a string 18 | """ 19 | prompt = '' 20 | prompt += prompt_data['prefixes']['instructions'] + prompt_data['instructions'] + prompt_data['separators']['instructions'] 21 | 22 | for example in prompt_data['examples']: 23 | 24 | prompt += prompt_data['prefixes']['input'] + example['input'] + prompt_data['separators']['input'] 25 | prompt += prompt_data['prefixes']['output'] + example['output'] + prompt_data['separators']['output'] 26 | 27 | return prompt 28 | 29 | def create_prompt(prompt_data, sentence=None) -> str: 30 | """Creates a prompt using the specified sentence for GPT in-context learning 31 | 32 | Parameters: 33 | prompt_data: dict containing ICL prompt examples, and template information 34 | sentence: a query string (sentence/word) to include in the ICL prompt 35 | 36 | Returns: 37 | prompt: the constructed ICL prompt as a string 38 | """ 39 | if sentence is None and prompt_data['query_target'] is not None: 40 | sentence = prompt_data['query_target']['input'] 41 | 42 | if isinstance(sentence, list): 43 | sentence = sentence[0] 44 | 45 | prompt_init = create_fewshot_primer(prompt_data) 46 | prompt = prompt_init + prompt_data['prefixes']['input'] + sentence + prompt_data['separators']['input'] 47 | prompt += prompt_data['prefixes']['output'] 48 | 49 | return prompt 50 | 51 | # Partial primer & prompt functions 52 | def create_partial_fewshot_primer(prompt_data, include = np.arange(8)) -> str: 53 | """Creates the primer string for GPT in-context learning, filtering to include a subset of specified priming strings 54 | 55 | Parameters: 56 | prompt_data: dict containing ICL prompt examples, and template information 57 | include: an iterable of ints indicating which examples to include in the ICL prompt 58 | 59 | Returns: 60 | prompt: the constructed ICL prompt primer as a string 61 | """ 62 | prompt = '' 63 | prompt += prompt_data['prefixes']['instructions'] + prompt_data['instructions'] + prompt_data['separators']['instructions'] 64 | 65 | # Grab each 
priming example in the specified order. 66 | for i in include: 67 | example = prompt_data['examples'][i] 68 | prompt += prompt_data['prefixes']['input'] + example['input'] + prompt_data['separators']['input'] 69 | prompt += prompt_data['prefixes']['output'] + example['output'] + prompt_data['separators']['output'] 70 | 71 | return prompt 72 | 73 | def create_partial_prompt(prompt_data, sentence=None, include=np.arange(8)) -> str: 74 | """Creates a prompt using the specified sentence and partial list of in-context primer sentences 75 | 76 | Parameters: 77 | prompt_data: dict containing ICL prompt examples, and template information 78 | sentence: a query string (sentence/word) to include in the ICL prompt 79 | include: an iterable of ints indicating which examples to include in the ICL prompt 80 | 81 | Returns: 82 | prompt: the prompt as a string 83 | """ 84 | if sentence is None and prompt_data['query_target'] is not None: 85 | sentence = prompt_data['query_target']['input'] 86 | if isinstance(sentence, list): 87 | sentence = sentence[0] 88 | 89 | prompt_init = create_partial_fewshot_primer(prompt_data, include) 90 | 91 | prompt = prompt_init + prompt_data['prefixes']['input'] + sentence + prompt_data['separators']['input'] 92 | prompt += prompt_data['prefixes']['output'] 93 | 94 | return prompt 95 | 96 | 97 | # UTILS FOR GENERATING PROMPT META LABELS 98 | def get_prompt_parts_and_labels(prompt_data, query_sentence=None): 99 | """ 100 | Generates high-level labels for each part of an ICL prompt according to its ICL role, such as demonstration, label, separator, structural, etc. 101 | The JSON prompt format should include 'instructions', 'examples' with ('input', 'output') pairs, 102 | 'prefixes', and 'separators' for 'input', 'output', and 'instructions'.
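The subset selection in `create_partial_fewshot_primer` can be illustrated standalone. The snippet below is a toy reproduction, not the repo's API: `partial_primer` and the hand-built `prompt_data` dict are illustrative stand-ins that apply the same prefix/separator template to only the demonstrations named in `include`.

```python
# Toy prompt_data using the default Q:/A: template (inputs/outputs carry a leading space).
prompt_data = {
    'instructions': '',
    'prefixes': {'input': 'Q:', 'output': 'A:', 'instructions': ''},
    'separators': {'input': '\n', 'output': '\n\n', 'instructions': ''},
    'examples': [
        {'input': ' hot', 'output': ' cold'},
        {'input': ' up', 'output': ' down'},
        {'input': ' big', 'output': ' small'},
    ],
}

def partial_primer(prompt_data, include):
    # Keep only the demonstrations whose indices appear in `include`, in that order.
    prompt = (prompt_data['prefixes']['instructions'] + prompt_data['instructions']
              + prompt_data['separators']['instructions'])
    for i in include:
        ex = prompt_data['examples'][i]
        prompt += prompt_data['prefixes']['input'] + ex['input'] + prompt_data['separators']['input']
        prompt += prompt_data['prefixes']['output'] + ex['output'] + prompt_data['separators']['output']
    return prompt

primer = partial_primer(prompt_data, include=[0, 2])  # skips the middle demonstration
```

With `include=[0, 2]` the middle pair is dropped, yielding `"Q: hot\nA: cold\n\nQ: big\nA: small\n\n"`.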
103 | Used in conjunction with tokenize_labels 104 | 105 | Parameters: 106 | prompt_data: dict containing ICL prompt examples, and template information 107 | query_sentence: optional (if contained in prompt_data) str containing a query for an ICL prompt 108 | 109 | Returns: 110 | prompt_parts: structured list of words to be flattened and tokenized 111 | prompt_part_labels: structured list of labels to be flattened & extended over tokenization 112 | """ 113 | if query_sentence is None and prompt_data['query_target'] is not None: 114 | query_sentence = prompt_data['query_target']['input'] 115 | if isinstance(query_sentence, list): 116 | query_sentence = query_sentence[0] 117 | n_examples = len(prompt_data['examples']) 118 | assemble_icl_example = lambda example, prompt_data: [prompt_data['prefixes']['input'], example['input'], prompt_data['separators']['input'], prompt_data['prefixes']['output'], example['output'], prompt_data['separators']['output']] 119 | assemble_icl_query = lambda query, prompt_data: [prompt_data['prefixes']['input'], query, prompt_data['separators']['input'], prompt_data['prefixes']['output']] 120 | 121 | prompt_instructions = [prompt_data['prefixes']['instructions'], prompt_data['instructions'], prompt_data['separators']['instructions']] 122 | prompt_icl_examples = [assemble_icl_example(prompt_data['examples'][i], prompt_data) for i in range(n_examples)] 123 | prompt_icl_query = [assemble_icl_query(query_sentence, prompt_data)] 124 | 125 | prompt_instructions_labels = ['bos_token', 'instructions_token', 'separator_token'] 126 | prompt_icl_examples_labels = [['structural_token', f'demonstration_{i+1}_token', 'separator_token', 'structural_token', f'demonstration_{i+1}_label_token', 'separator_token'] for i in range(n_examples)] 127 | prompt_icl_query_labels = [['query_structural_token', 'query_demonstration_token', 'query_separator_token', 'query_structural_token']] 128 | 129 | prompt_parts = prompt_instructions + prompt_icl_examples + 
prompt_icl_query 130 | 131 | prompt_part_labels = prompt_instructions_labels + prompt_icl_examples_labels + prompt_icl_query_labels 132 | 133 | return prompt_parts, prompt_part_labels 134 | 135 | def extend_labels(sentence_parts, text_labels, tokenizer, label_init=[]): 136 | """ 137 | Extends ICL component labels across words that are tokenized into multiple tokens 138 | 139 | Parameters: 140 | sentence_parts: list, where each element is either a token (str), phrase (str), or list of tokens/phrases 141 | text_labels: list with the same structure as 'sentence_parts', with a corresponding label for that level of the input sentence. 142 | tokenizer: huggingface tokenizer 143 | 144 | Returns: 145 | final_labels: flattened/extended list of token labels for an ICL prompt (split into parts, contained in sentence_parts and text_labels) 146 | """ 147 | zipped_up = [list(zip(x,y)) if isinstance(x, list) else [(x,y)] for x,y in list(zip(sentence_parts,text_labels)) ] 148 | 149 | prompt_builder = '' 150 | final_labels = label_init 151 | for element in zipped_up: 152 | 153 | for j, (word,label) in enumerate(element): 154 | if len(word) == 0: 155 | continue 156 | pre = len(tokenizer.tokenize(prompt_builder)) 157 | prompt_builder += word 158 | post = len(tokenizer.tokenize(prompt_builder)) 159 | 160 | actual_tokens = post-pre 161 | 162 | if actual_tokens == 0: 163 | # if tokenization gobbles up a previous label, then we overwrite the last previous label w/ label that should've been added 164 | final_labels[-1] = label 165 | 166 | final_labels.extend([label] * (actual_tokens)) 167 | 168 | if j==3 or j==2 and len(element[3])==0: 169 | final_labels[-1] = final_labels[-1].replace('structural', 'predictive').replace('separator', 'predictive') 170 | if j==5: 171 | final_labels[-actual_tokens] = final_labels[-actual_tokens].replace('separator', 'end_of_example') 172 | 173 | return final_labels 174 | 175 | def tokenize_labels(sentence_parts, text_labels, tokenizer, prepend_bos=False): 
176 | """ 177 | Extends phrase-level labels across tokenization for in-context learning prompts. Tested with GPT-2's tokenizer from huggingface. 178 | Parameters: 179 | sentence_parts: list, where each element is either a token (str), phrase (str), or list of tokens/phrases 180 | text_labels: list with the same structure as 'sentence_parts', with a corresponding label for that level of the input sentence. 181 | tokenizer: huggingface tokenizer 182 | 183 | Returns: 184 | labels: flattened/extended list of token labels for an ICL prompt (split into parts, contained in sentence_parts and text_labels) 185 | 186 | based on the tokenize_and_preserve_labels function from: 187 | https://www.depends-on-the-definition.com/named-entity-recognition-with-bert/ 188 | """ 189 | 190 | # If the model typically prepends a bos, we add a bos label to label init 191 | if prepend_bos: 192 | labels = extend_labels(sentence_parts, text_labels, tokenizer, label_init=['bos_token']) 193 | else: 194 | labels = extend_labels(sentence_parts, text_labels, tokenizer, label_init=[]) 195 | 196 | return labels 197 | 198 | def get_token_meta_labels(prompt_data, tokenizer, query=None, prepend_bos=False): 199 | """ 200 | Computes the ICL meta-labels for every token in a prompt. 
201 | 202 | Parameters: 203 | prompt_data: dict containing ICL prompt examples, and template information 204 | tokenizer: huggingface tokenizer 205 | query: str of the query input 206 | 207 | Return: 208 | token_labels: list of tuples (prompt token index, token, label) 209 | prompt_string: full prompt as a string 210 | """ 211 | if query is None and prompt_data['query_target'] is not None: 212 | query = prompt_data['query_target']['input'] 213 | if isinstance(query, list): 214 | query = query[0] 215 | 216 | prompt_parts, prompt_part_labels = get_prompt_parts_and_labels(prompt_data, query_sentence=query) 217 | token_meta_labels = tokenize_labels(prompt_parts, prompt_part_labels, tokenizer, prepend_bos) 218 | prompt_string = create_prompt(prompt_data=prompt_data, sentence=query) 219 | tokens = [tokenizer.decode(x) for x in tokenizer(prompt_string).input_ids] 220 | token_labels = list(zip(np.arange(len(tokens)), tokens, token_meta_labels)) 221 | 222 | return token_labels, prompt_string 223 | 224 | def get_dummy_token_labels(n_icl_examples, tokenizer, model_config, prefixes=None, separators=None): 225 | """ 226 | Computes the ground-truth meta labels & indices for an ICL prompt with the specified number of example pairs 227 | These GT labels assume each word gets a single token 228 | 229 | Parameters: 230 | n_icl_examples: number of ICL example pairs 231 | tokenizer: huggingface tokenizer 232 | prefixes: ICL template prefixes 233 | separators: ICL template separators 234 | 235 | Return: 236 | final_token_labels: list of tuples containing a token's index and label name [(int, str), ... 
] 237 | """ 238 | # If the model already prepends a bos token by default, we don't want to add one to our prompts 239 | prepend_bos = False if model_config['prepend_bos'] else True 240 | 241 | if prefixes is not None and separators is not None: 242 | dummy_prompt_data = word_pairs_to_prompt_data({'input': ['a']*n_icl_examples, 'output':['a']*n_icl_examples}, 243 | query_target_pair={'input':['a'], 'output':['a']}, prepend_bos_token=prepend_bos, 244 | prefixes=prefixes, separators=separators) 245 | else: 246 | dummy_prompt_data = word_pairs_to_prompt_data({'input': ['a']*n_icl_examples, 'output':['a']*n_icl_examples}, 247 | query_target_pair={'input':['a'], 'output':['a']}, prepend_bos_token=prepend_bos) 248 | final_token_labels, _ = get_token_meta_labels(dummy_prompt_data,tokenizer, prepend_bos=model_config['prepend_bos']) 249 | final_token_labels = [(x[0],x[-1]) for x in final_token_labels] 250 | return final_token_labels 251 | 252 | def compute_duplicated_labels(token_labels, gt_labels): 253 | """ 254 | Computes a map between duplicated labels and ground truth label positions for localized averaging 255 | 256 | Parameters: 257 | token_labels: token labels of actual prompt being used 258 | gt_labels: token labels for a "ground truth" prompt that assumes each input & output is a single token 259 | 260 | Returns: 261 | index_map: a dict mapping prompt label indices to ground truth label indices 262 | dup_label_ranges: indices where labels should be duplicated 263 | """ 264 | check_inds = list(filter(lambda x: 'demo' in x[2], token_labels)) 265 | dup_ranges = pd.DataFrame(check_inds).groupby(2)[0].aggregate(lambda x: (x.min(), x.max())) 266 | dup_labels = [v for v,x in dup_ranges.items() if (x[1] - x[0]) > 0] 267 | 268 | dup_label_ranges = dup_ranges[dup_labels].to_dict() 269 | dup_inds = pd.DataFrame(check_inds)[pd.DataFrame(check_inds)[2].duplicated()][0].values 270 | 271 | index_map = {k:v[0] for (k,v) in zip([x[0] for x in token_labels if x[0] not in dup_inds], 
gt_labels)} 272 | 273 | return index_map, dup_label_ranges 274 | 275 | def update_idx_map(idx_map, idx_avg) -> dict: 276 | """ 277 | Updates the idx_map to map duplicate tokens to its gt token position 278 | """ 279 | update_map = {} 280 | for (i,j) in idx_avg.values(): 281 | for k in range(i,j+1): 282 | if k not in idx_map.keys(): 283 | update_map[k] = idx_map[i] 284 | 285 | update_map = {**idx_map, **update_map} 286 | return update_map 287 | 288 | 289 | def word_pairs_to_prompt_data(word_pairs : dict, 290 | instructions: str = "", 291 | prefixes: dict = {"input":"Q:", "output":"A:","instructions":""}, 292 | separators: dict = {"input":"\n", "output":"\n\n", "instructions":""}, 293 | query_target_pair: dict = None, prepend_bos_token=False, 294 | shuffle_labels=False, prepend_space=True) -> dict: 295 | """Takes a dataset of word pairs, and constructs a prompt_data dict with additional information to construct an ICL prompt. 296 | Parameters: 297 | word_pairs: dict of the form {'word1':['a', 'b', ...], 'word2':['c', 'd', ...]} 298 | instructions: prefix instructions for an ICL prompt 299 | prefixes: dict of ICL prefixes that are prepended to inputs, outputs and instructions 300 | separators: dict of ICL separators that are appended to inputs, outputs and instructions 301 | query_target_pair: dict with a single input-output pair acting as the query for the prompt 302 | prepend_bos_token: whether or not to prepend a BOS token to the prompt 303 | shuffle_labels: whether to shuffle the ICL labels 304 | prepend_space: whether to prepend a space to every input and output token 305 | 306 | Returns: 307 | prompt_data: dict containing ICL prompt examples, and template information 308 | """ 309 | prompt_data = {} 310 | prompt_data['instructions'] = instructions 311 | prompt_data['separators'] = separators 312 | if prepend_bos_token: 313 | prefixes = {k:(v if k !='instructions' else '<|endoftext|>' + v) for (k,v) in prefixes.items()} 314 | prompt_data['prefixes'] = prefixes 
315 | 316 | if query_target_pair is not None: 317 | query_target_pair = {k:(v[0] if isinstance(v, list) else v) for k,v in query_target_pair.items()} 318 | prompt_data['query_target'] = query_target_pair 319 | 320 | if shuffle_labels: 321 | randomized_pairs = [np.random.permutation(x).tolist() if i==1 else x for (i,x) in enumerate(list(word_pairs.values()))] # shuffle labels only 322 | if prepend_space: 323 | prompt_data['examples'] = [{'input':' ' + str(w1), 'output':' ' + str(w2)} for (w1,w2) in list(zip(*randomized_pairs))] 324 | prompt_data['query_target'] = {k:' ' + str(v) for k,v in query_target_pair.items()} if query_target_pair is not None else None 325 | else: 326 | prompt_data['examples'] = [{'input':w1, 'output':w2} for (w1,w2) in list(zip(*randomized_pairs))] 327 | else: 328 | if prepend_space: 329 | prompt_data['examples'] = [{'input':' ' + str(w1), 'output':' ' + str(w2)} for (w1,w2) in list(zip(*word_pairs.values()))] 330 | prompt_data['query_target'] = {k:' ' + str(v) for k,v in query_target_pair.items()} if query_target_pair is not None else None 331 | else: 332 | prompt_data['examples'] = [{'input':w1, 'output':w2} for (w1,w2) in list(zip(*word_pairs.values()))] 333 | 334 | return prompt_data 335 | 336 | 337 | # DATASET UTILS 338 | class ICLDataset: 339 | """ 340 | A simple dataset class containing input-output pairs, used for ICL prompt construction. 
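Under the default Q:/A: template, `word_pairs_to_prompt_data` followed by `create_prompt` yields a string that ends immediately after the final `A:`, so the model's next predicted token is the answer. A minimal self-contained sketch (the helpers below are toy reductions of the repo functions, with instructions and shuffling omitted):

```python
def word_pairs_to_prompt_data_sketch(word_pairs, query):
    # Mirror of the prepend_space=True path: a leading space on every input/output.
    return {
        'prefixes': {'input': 'Q:', 'output': 'A:'},
        'separators': {'input': '\n', 'output': '\n\n'},
        'examples': [{'input': ' ' + a, 'output': ' ' + b}
                     for a, b in zip(*word_pairs.values())],
        'query_target': {'input': ' ' + query},
    }

def build_prompt(pd_):
    # Few-shot primer, then the query input, ending right after the output prefix.
    prompt = ''
    for ex in pd_['examples']:
        prompt += pd_['prefixes']['input'] + ex['input'] + pd_['separators']['input']
        prompt += pd_['prefixes']['output'] + ex['output'] + pd_['separators']['output']
    prompt += pd_['prefixes']['input'] + pd_['query_target']['input'] + pd_['separators']['input']
    prompt += pd_['prefixes']['output']
    return prompt

pd_ = word_pairs_to_prompt_data_sketch({'input': ['hot'], 'output': ['cold']}, 'big')
prompt = build_prompt(pd_)  # "Q: hot\nA: cold\n\nQ: big\nA:"
```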
341 | """ 342 | def __init__(self, dataset): 343 | if isinstance(dataset, str): 344 | self.raw_data = pd.read_json(dataset) 345 | elif isinstance(dataset, dict): 346 | self.raw_data = pd.DataFrame(dataset) 347 | self.raw_data = self.raw_data[['input', 'output']] 348 | 349 | def __getitem__(self,i): 350 | if isinstance(i, int): 351 | return self.raw_data.iloc[i].to_dict() 352 | elif isinstance(i, slice): 353 | return self.raw_data.iloc[i].to_dict(orient='list') 354 | elif isinstance(i, list) or isinstance(i, np.ndarray): 355 | return self.raw_data.iloc[i].to_dict(orient='list') 356 | elif isinstance(i, str): 357 | if i not in self.raw_data.columns: 358 | raise KeyError(f"Column '{i}' not in the dataset. Current columns in the dataset: {self.raw_data.columns.to_list()}") 359 | else: 360 | return self.raw_data[i].to_list() 361 | else: 362 | raise ValueError(f"{i} is not a valid index type. Expected one of: [int, list, np.ndarray, slice, str]") 363 | 364 | def __len__(self): 365 | return len(self.raw_data) 366 | 367 | def __repr__(self): 368 | s = "ICLDataset" + "({\n\tfeatures: " + f"{self.raw_data.columns.to_list()},\n\tnum_rows: {self.__len__()}" + "\n})" 369 | return s 370 | 371 | def split_icl_dataset(dataset, train_size=None, test_size=0.3, seed=42) -> Dict[str,ICLDataset]: 372 | """ 373 | Uses scikit-learn's train_test split to create train, valid, test dataset from provided dataset. 
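One subtlety in `split_icl_dataset`: the held-out portion is split a second time with the same `test_size`, so with `test_size=0.3` the final proportions come out to roughly 70% train / 21% test / 9% valid, not 70/15/15. A sketch of the arithmetic with deterministic slicing standing in for `train_test_split`:

```python
def simple_split(data, test_size):
    # Deterministic stand-in for sklearn's train_test_split (no shuffling).
    cut = int(round(len(data) * (1 - test_size)))
    return data[:cut], data[cut:]

data = list(range(100))
train, valid = simple_split(data, test_size=0.3)   # 70 / 30
test, valid = simple_split(valid, test_size=0.3)   # 30 -> 21 / 9
sizes = (len(train), len(test), len(valid))        # (70, 21, 9)
```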
374 | 375 | Parameters: 376 | dataset: ICL dataset 377 | train_size: percentage of data (float between 0 and 1) to put in the training data split 378 | test_size: percentage of data (float between 0 and 1) to put into the test data split 379 | seed: seed used for splitting the data 380 | 381 | Returns: 382 | dict containing train, valid, test ICL datasets 383 | """ 384 | if train_size is None and test_size is None: 385 | train_size = 0.7 386 | test_size = 0.3 387 | 388 | elif train_size is not None and test_size is None: 389 | test_size = 1-train_size 390 | 391 | elif train_size is None and test_size is not None: 392 | train_size = 1-test_size 393 | 394 | elif train_size is not None and test_size is not None: 395 | assert train_size + test_size == 1 396 | 397 | train, valid = train_test_split(dataset.raw_data, test_size=test_size, random_state=seed) 398 | test, valid = train_test_split(valid, test_size=test_size, random_state=seed) 399 | 400 | train = ICLDataset(train.to_dict(orient='list')) 401 | valid = ICLDataset(valid.to_dict(orient='list')) 402 | test = ICLDataset(test.to_dict(orient='list')) 403 | 404 | return {'train':train, 'valid':valid, 'test':test} 405 | 406 | 407 | def load_dataset(task_name: str, 408 | root_data_dir: str = '../dataset_files', 409 | test_size = 0.3, 410 | seed=32 411 | ) -> Dict[str,ICLDataset]: 412 | """ 413 | Loads a dataset with input/output pairs 414 | 415 | Parameters: 416 | task_name: the name of the task dataset 417 | root_data_dir: the root directory where the data comes from 418 | test_size: fraction used in train/test split 419 | 420 | Return: 421 | dataset: the dict contain the train/valid/test dataset splits 422 | """ 423 | 424 | data_folders = ['abstractive', 'extractive'] 425 | assert test_size <= 1.0 426 | 427 | path = Path(root_data_dir) 428 | d_group_map = [(dataset_type, os.path.exists(os.path.join(root_data_dir, dataset_type, task_name+'.json'))) for dataset_type in data_folders] 429 | 430 | d_group = 
list(filter(lambda x: x[1], d_group_map)) 431 | 432 | assert len(d_group) !=0 and len(d_group) == 1, f"Error! 'task_name'={task_name}.json must be uniquely contained in one of these directories:{data_folders}. Please check the root_data_dir" 433 | dataset_folder = d_group[0][0] 434 | 435 | d_path = os.path.join(path, dataset_folder, f'{task_name}.json') 436 | 437 | dataset = ICLDataset(d_path) 438 | dataset = split_icl_dataset(dataset, test_size=test_size, seed=seed) 439 | 440 | return dataset -------------------------------------------------------------------------------- /src/vocab_reconstruction.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch, numpy as np 3 | import argparse 4 | 5 | # Include prompt creation helper functions 6 | from utils.prompt_utils import load_dataset 7 | from utils.extract_utils import get_mean_head_activations, compute_universal_function_vector 8 | from utils.eval_utils import n_shot_eval_no_intervention, n_shot_eval 9 | from utils.model_utils import load_gpt_model_and_tokenizer, set_seed 10 | 11 | def optim_loop(v_n, target, decoder, loss_fn, optimizer, n_steps:int=1000, verbose:bool=False, restrict_vocab:int=50400): 12 | if target.shape[-1] != restrict_vocab: 13 | inds = torch.topk(target, restrict_vocab).indices[0] 14 | Z = torch.zeros(target.size()).cuda() 15 | Z[:,inds] = target[:,inds] 16 | else: 17 | Z = target 18 | 19 | for i in range(n_steps): 20 | loss = loss_fn(decoder(v_n),Z) 21 | loss.backward() 22 | if verbose: 23 | print(f"Loss:{loss.item()}, iter:{i}") 24 | optimizer.step() 25 | optimizer.zero_grad() 26 | return v_n 27 | 28 | def vocab_reconstruction(datasets, n_steps:int=1000, lr:float=0.5, n_seeds:int=5, n_trials:int=100, n_shots:int=10, restrict_vocab_list=[100,50400], return_vecs:bool=False): 29 | """ 30 | Computes and evaluates a function vector reconstruction which matches its output vocabulary distribution. 
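The `restrict_vocab` branch of `optim_loop` keeps only the top-k entries of the target logit vector and zeroes the rest before fitting the reconstruction. A pure-Python sketch of that masking step (`restrict_topk` is an illustrative name; the real code uses `torch.topk` on the GPU):

```python
def restrict_topk(target, k):
    # Indices of the k largest entries; everything else is zeroed out.
    inds = sorted(range(len(target)), key=lambda i: target[i], reverse=True)[:k]
    z = [0.0] * len(target)
    for i in inds:
        z[i] = target[i]
    return z

z = restrict_topk([0.1, 0.7, 0.05, 0.15], k=2)  # keeps 0.7 and 0.15
```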

    Parameters:
    datasets: list of task dataset names to evaluate
    n_steps: number of optimization steps
    lr: Adam learning rate
    n_seeds: number of seeds to run
    n_trials: number of prompts to compute task-conditioned mean head activations over
    n_shots: number of shots for task-conditioned mean prompts
    restrict_vocab_list: list of ints determining how many vocab words to match. Defaults to 100 & full-vocab (which is 50400 for GPT-J)
    return_vecs: whether to return the function vectors and their corresponding vocab-optimized reconstruction vectors

    Returns:
    orig_results: FV results
    zs_results: zero-shot results when patching in the vocab-optimized reconstruction vectors
    kl_divs: KL divergences between the distribution of the FV and its reconstruction
    fvs: (optional) the function vectors used
    vns: (optional) the vocab-optimized reconstruction vectors
    """
    seeds = {k: [] for k in datasets}
    orig_results = {k: [] for k in datasets}
    fvs = {k: [] for k in datasets}
    vns = {k: {j: [] for j in range(len(restrict_vocab_list))} for k in datasets}
    zs_results = {k: {j: [] for j in range(len(restrict_vocab_list))} for k in datasets}
    kl_divs = {k: {j: [] for j in range(len(restrict_vocab_list))} for k in datasets}

    for dataset_name in datasets:
        print(f"Dataset: {dataset_name}")

        for i in range(n_seeds):
            seed = np.random.randint(100000)
            print(f"seed: {seed}")
            seeds[dataset_name].append(seed)
            set_seed(seed)

            # Disable gradients when extracting activations & computing the FV
            torch.set_grad_enabled(False)

            dataset = load_dataset(dataset_name, seed=seed, root_data_dir=root_data_dir)

            fs_results_validation = n_shot_eval_no_intervention(dataset=dataset, n_shots=n_shots, model=model, model_config=model_config, tokenizer=tokenizer, compute_ppl=False, test_split='valid')
            filter_set_validation = np.where(np.array(fs_results_validation['clean_rank_list']) == 0)[0]

            mean_activations = get_mean_head_activations(dataset, model=model, model_config=model_config, tokenizer=tokenizer,
                                                         n_icl_examples=n_shots, N_TRIALS=n_trials, filter_set=filter_set_validation)

            fv, _ = compute_universal_function_vector(mean_activations, model, model_config)

            fs_results = n_shot_eval_no_intervention(dataset=dataset, n_shots=n_shots, model=model, model_config=model_config, tokenizer=tokenizer, compute_ppl=False, test_split='test')
            filter_set = np.where(np.array(fs_results['clean_rank_list']) == 0)[0]

            fv_results = n_shot_eval(dataset=dataset, fv_vector=fv, edit_layer=9, n_shots=0,
                                     model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set)

            orig_results[dataset_name].append(fv_results)
            fvs[dataset_name].append(fv)

            for j, vocab_size in enumerate(restrict_vocab_list):
                # Enable gradients for optimization
                torch.set_grad_enabled(True)
                v_n = torch.randn(fv.size()).cuda()
                v_n.requires_grad = True

                # Optim setup
                loss_fn = torch.nn.CrossEntropyLoss()
                optimizer = torch.optim.Adam([v_n], lr=lr)
                decoder = torch.nn.Sequential(model.transformer.ln_f, model.lm_head).to(model.device)

                # Ensure gradients can flow through the decoder back to v_n
                for p in decoder.parameters():
                    p.requires_grad = True

                target = torch.nn.functional.softmax(decoder(fv), dim=-1).detach()

                computed_vn = optim_loop(v_n, target, decoder, loss_fn, optimizer, verbose=False, n_steps=n_steps, restrict_vocab=vocab_size)

                # Rescale the reconstruction to the norm of the original FV
                scaled_vn = computed_vn / torch.linalg.norm(computed_vn) * torch.linalg.norm(fv)

                zs_reconstruction_results = n_shot_eval(dataset=dataset, fv_vector=scaled_vn, edit_layer=9, n_shots=0,
                                                        model=model, model_config=model_config, tokenizer=tokenizer, filter_set=filter_set)

                zs_results[dataset_name][j].append(zs_reconstruction_results)
                vns[dataset_name][j].append(scaled_vn.detach())

                # Compute the KL divergence between the two decoded distributions
                if vocab_size != 50400:
                    tp = torch.softmax(decoder(fvs[dataset_name][i]), dim=-1)
                    inds = torch.topk(tp, vocab_size).indices[0]
                    vn_ps = torch.softmax(decoder(vns[dataset_name][j][i]), dim=-1)[:, inds]

                    log_probs = torch.log(vn_ps / vn_ps.sum())
                    target_probs = tp[:, inds] / tp[:, inds].sum()
                else:
                    log_probs = torch.log(torch.softmax(decoder(vns[dataset_name][j][i]), dim=-1))
                    target_probs = torch.softmax(decoder(fvs[dataset_name][i]), dim=-1)

                kl_divs[dataset_name][j].append(torch.nn.functional.kl_div(log_probs, target_probs, reduction='batchmean').item())

    if return_vecs:
        return orig_results, zs_results, kl_divs, fvs, vns
    else:
        return orig_results, zs_results, kl_divs


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', help='Name of model to be loaded', type=str, required=False, default='EleutherAI/gpt-j-6b')
    parser.add_argument('--root_data_dir', help='Root directory of data files', type=str, required=False, default='../dataset_files')
    parser.add_argument('--save_path_root', help='File path to save results to', type=str, required=False, default='../results')
    parser.add_argument('--n_seeds', help='Number of seeds', type=int, required=False, default=5)
    parser.add_argument('--n_trials', help='Number of trials to use for computing task-conditioned mean head activations', type=int, required=False, default=100)
    parser.add_argument('--n_shots', help='Number of shots to use for prompts when computing task-conditioned mean head activations', type=int, required=False, default=10)
    parser.add_argument('--lr', help='Learning rate for the Adam optimizer', type=float, required=False, default=0.5)
    parser.add_argument('--n_steps', help='Number of optimization steps', type=int, required=False, default=1000)

    args = parser.parse_args()

    # Gather inputs
    model_name = args.model_name
    root_data_dir = args.root_data_dir
    save_path_root = args.save_path_root
    n_seeds = args.n_seeds
    n_trials = args.n_trials
    n_shots = args.n_shots
    lr = args.lr
    n_steps = args.n_steps

    # Load Model & Tokenizer
    torch.set_grad_enabled(False)
    model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name)

    datasets = ['antonym', 'english-french', 'capitalize', 'present-past', 'singular-plural', 'country-capital']
    args.datasets = datasets

    # Test Loop:
    orig_results, zs_results, kl_divs = vocab_reconstruction(datasets, n_steps=n_steps, lr=lr, n_seeds=n_seeds, n_trials=n_trials, n_shots=n_shots, restrict_vocab_list=[100, 50400])

    # Extract Summary Results:
    os.makedirs(save_path_root, exist_ok=True)
    with open(os.path.join(save_path_root, 'reconstruction_results.txt'), 'w') as out_file:
        for dataset_name in datasets:
            print(f"{dataset_name.title()}:", file=out_file)
            fv_acc = [orig_results[dataset_name][i]['intervention_topk'][0][1] for i in range(n_seeds)]
            v100_acc = [zs_results[dataset_name][0][i]['intervention_topk'][0][1] for i in range(n_seeds)]
            kl100_val = kl_divs[dataset_name][0]

            vfull_acc = [zs_results[dataset_name][1][i]['intervention_topk'][0][1] for i in range(n_seeds)]
            klfull_val = kl_divs[dataset_name][1]

            print("fv results:", np.mean(fv_acc).round(3) * 100, '% +/-', np.std(fv_acc).round(3) * 100, file=out_file)
            print("v_100 results:", np.mean(v100_acc).round(3) * 100, '% +/-', np.std(v100_acc).round(3) * 100, file=out_file)
            print("KL100:", np.mean(kl100_val).round(5), '+/-', np.std(kl100_val).round(5), file=out_file)
            print("v_full results:", np.mean(vfull_acc).round(3) * 100, '% +/-', np.std(vfull_acc).round(3) * 100, file=out_file)
            print("KLFull:", np.mean(klfull_val).round(5), '+/-', np.std(klfull_val).round(5), '\n', file=out_file)

    with open(os.path.join(save_path_root, 'reconstruction_args.txt'), 'w') as arg_file:
        print(args.__dict__, file=arg_file)
--------------------------------------------------------------------------------
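The core idea of `optim_loop` — optimizing a vector so that a fixed decoder maps it to a target output distribution — can be demonstrated in isolation, without GPT-J. Below is a minimal, self-contained sketch: the toy dimensions, random linear decoder (standing in for `ln_f` + `lm_head`), learning rate, and step count are all illustrative choices, not the script's defaults.

```python
import torch

torch.manual_seed(0)

d_model, vocab = 16, 50  # toy sizes; the script works with GPT-J's 4096 / 50400
decoder = torch.nn.Linear(d_model, vocab, bias=False)

# Target: the decoded distribution of a "function vector"
fv = torch.randn(1, d_model)
target = torch.softmax(decoder(fv), dim=-1).detach()

# Optimize a fresh vector so its decoded distribution matches the target.
# CrossEntropyLoss accepts soft (probability) targets, as in optim_loop.
v_n = torch.randn(1, d_model, requires_grad=True)
optimizer = torch.optim.Adam([v_n], lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()

for _ in range(1000):
    loss = loss_fn(decoder(v_n), target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# Measure how closely the reconstruction's distribution matches the target
log_probs = torch.log_softmax(decoder(v_n), dim=-1)
kl = torch.nn.functional.kl_div(log_probs, target, reduction='batchmean')
print(f"KL(target || reconstruction): {kl.item():.6f}")
```

The `restrict_vocab` branch of `optim_loop` is the same loop with the target zeroed outside its top-k entries, so only the head of the distribution is matched.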