├── .gitignore ├── GPT2 ├── .ipynb_checkpoints │ ├── model-checkpoint.py │ ├── model_l2loss-checkpoint.py │ ├── train_adam_l2loss-checkpoint.py │ └── train_adam_l2loss-checkpoint.sh ├── abstractive │ ├── antonym.json │ ├── capitalize.json │ ├── country-capital.json │ ├── english-french.json │ ├── english-german.json │ ├── english-spanish.json │ ├── next_item.json │ ├── present-past.json │ ├── prev_item.json │ ├── singular-plural.json │ ├── synonym.json │ └── word_length.json ├── config │ ├── .ipynb_checkpoints │ │ └── train_gpt2_small_adam_l2loss-checkpoint.py │ └── train_gpt2_small_adam_l2loss.py ├── configurator.py ├── function_vectors.ipynb ├── function_vectors_parallelogram.ipynb ├── model.py ├── model_l2loss.py ├── prediction.ipynb ├── readme.md ├── train_adam_l2loss.py └── train_adam_l2loss.sh ├── README.md ├── environment.yaml ├── figures ├── circle_case_study.pdf ├── data_eff_plot.pdf ├── eq_harmax.png ├── ev_plot.pdf ├── grokking_plot.pdf ├── mnist_harmonic_weights.pdf ├── mnist_standard_weights.pdf ├── modadd_weights_evolution.gif ├── rep_plots.pdf ├── rep_plots_appendix.pdf └── weights_evolution.gif ├── notebooks ├── case_study_circle.ipynb ├── final_figures.ipynb ├── lattice.ipynb ├── mnist.ipynb ├── mnist_video.ipynb ├── modadd.ipynb ├── modadd_video.ipynb ├── n_and_loss_experiments.ipynb ├── permutation_group.ipynb └── plot_results.ipynb ├── perm_figs ├── final_figs │ ├── H_MLP_ev_0.5912.png │ ├── H_transformer_ev_0.4036.png │ ├── fig_eff.png │ ├── fig_fvu.png │ ├── fig_grok_mlp.png │ ├── fig_grok_transformer.png │ ├── standard_MLP_ev_0.3940.png │ └── standard_transformer_ev_0.2944.png ├── n=1 │ ├── 0_H_MLP_ev_33.4.png │ ├── 105_H_MLP_ev_38.7.png │ ├── 157_H_MLP_ev_32.7.png │ ├── 210_H_MLP_ev_43.2.png │ └── 52_H_MLP_ev_27.3.png ├── n=1_high_weight_decay │ ├── emb_H_MLP_ev_0.5169.png │ ├── emb_H_transformer_ev_0.3840.png │ ├── emb_standard_MLP_ev_0.3582.png │ ├── emb_standard_transformer_ev_0.3366.png │ ├── fig_eff.png │ ├── fig_fvu.png │ ├── 
fig_grok_mlp.png │ └── fig_grok_transformer.png ├── n=embd_dim │ ├── H_MLP_ev_34.4.png │ ├── H_transformer_ev_33.4.png │ ├── fig_eff.png │ ├── fig_fvu.png │ ├── fig_grok_mlp.png │ ├── fig_grok_transformer.png │ ├── standard_MLP_ev_39.0.png │ └── standard_transformer_ev_28.7.png └── n=sqrt_embd_dim │ ├── H_MLP_ev_59.1.png │ ├── H_transformer_ev_41.1.png │ ├── fig_eff.png │ ├── fig_fvu.png │ ├── fig_grok_mlp.png │ ├── fig_grok_transformer.png │ ├── standard_MLP_ev_39.1.png │ └── standard_transformer_ev_27.2.png ├── scripts ├── HM_circle.sh ├── HM_equiv.sh ├── HM_family.sh ├── HM_lattice.sh ├── HM_permutation.sh ├── HT_circle.sh ├── HT_equiv.sh ├── HT_family.sh ├── HT_lattice.sh ├── HT_permutation.sh ├── M_circle.sh ├── M_equiv.sh ├── M_family.sh ├── M_lattice.sh ├── M_permutation.sh ├── T_circle.sh ├── T_equiv.sh ├── T_family.sh ├── T_lattice.sh ├── T_permutation.sh ├── circle_run.sh ├── data_size_sweep.sh ├── equiv_run.sh ├── family_tree_run.sh ├── greater_run.sh ├── lattice.py ├── lattice.sh ├── lattice_run.sh ├── loss_exp.sh ├── modadd.py ├── modadd.sh ├── n_exp.sh ├── u_circle.sh ├── u_circle_new.sh ├── u_equiv.sh ├── u_family.sh ├── u_family_new.sh ├── u_greater.sh └── u_lattice.sh ├── src ├── README.md ├── run_exp.py ├── unit_exp.py └── utils │ ├── FamilyTreeGenerator.py │ ├── crystal_metric.py │ ├── dataset.py │ ├── driver.py │ ├── model.py │ └── visualization.py └── toy_points.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | results 3 | 4 | scratch*.ipynb 5 | */slurm*.out 6 | 7 | harmonic_archive.zip 8 | data 9 | dropbox* 10 | permutation_results -------------------------------------------------------------------------------- /GPT2/.ipynb_checkpoints/train_adam_l2loss-checkpoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -o train_adam_l2loss.log-%j 3 | # SBATCH --job-name=AdamTrain 4 | #SBATCH --nodes=4 5 | #SBATCH 
--ntasks-per-node=1 6 | #SBATCH --gres=gpu:volta:2 7 | #SBATCH --cpus-per-task=40 8 | # SBATCH --mem=250G 9 | # SBATCH --time=23:59:00 10 | 11 | source /etc/profile 12 | module load anaconda/2023a-pytorch 13 | module load cuda/11.4 14 | module load nccl/2.10.3-cuda11.4 15 | 16 | export NCCL_DEBUG=INFO 17 | export PYTHONFAULTHANDLER=1 18 | 19 | # Set up rendezvous parameters 20 | export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 21 | export MASTER_PORT=$(shuf -i 29000-29999 -n 1) 22 | 23 | echo "MASTER_ADDR: $MASTER_ADDR" 24 | echo "MASTER_PORT: $MASTER_PORT" 25 | echo "SLURM_JOB_ID: $SLURM_JOB_ID" 26 | echo "SLURM_NTASKS: $SLURM_NTASKS" 27 | echo "SLURM_NODELIST: $SLURM_NODELIST" 28 | 29 | # Use srun to ensure the job is distributed 30 | srun --nodes=$SLURM_JOB_NUM_NODES --ntasks=$SLURM_JOB_NUM_NODES \ 31 | torchrun \ 32 | --nnodes=$SLURM_JOB_NUM_NODES \ 33 | --nproc_per_node=2 \ 34 | --rdzv_id=$SLURM_JOB_ID \ 35 | --rdzv_backend=c10d \ 36 | --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ 37 | train_adam_l2loss.py \ 38 | config/train_gpt2_small_adam_l2loss.py \ 39 | --batch_size=6 \ 40 | --gradient_accumulation_steps=10 41 | -------------------------------------------------------------------------------- /GPT2/abstractive/country-capital.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "output": "Kabul", 4 | "input": "Afghanistan" 5 | }, 6 | { 7 | "output": "Tirana", 8 | "input": "Albania" 9 | }, 10 | { 11 | "output": "Algiers", 12 | "input": "Algeria" 13 | }, 14 | { 15 | "output": "Andorra la Vella", 16 | "input": "Andorra" 17 | }, 18 | { 19 | "output": "Luanda", 20 | "input": "Angola" 21 | }, 22 | { 23 | "output": "St. 
John's", 24 | "input": "Antigua and Barbuda" 25 | }, 26 | { 27 | "output": "Buenos Aires", 28 | "input": "Argentina" 29 | }, 30 | { 31 | "output": "Yerevan", 32 | "input": "Armenia" 33 | }, 34 | { 35 | "output": "Canberra", 36 | "input": "Australia" 37 | }, 38 | { 39 | "output": "Vienna", 40 | "input": "Austria" 41 | }, 42 | { 43 | "output": "Baku", 44 | "input": "Azerbaijan" 45 | }, 46 | { 47 | "output": "Nassau", 48 | "input": "Bahamas" 49 | }, 50 | { 51 | "output": "Manama", 52 | "input": "Bahrain" 53 | }, 54 | { 55 | "output": "Dhaka", 56 | "input": "Bangladesh" 57 | }, 58 | { 59 | "output": "Bridgetown", 60 | "input": "Barbados" 61 | }, 62 | { 63 | "output": "Minsk", 64 | "input": "Belarus" 65 | }, 66 | { 67 | "output": "Brussels", 68 | "input": "Belgium" 69 | }, 70 | { 71 | "output": "Belmopan", 72 | "input": "Belize" 73 | }, 74 | { 75 | "output": "Porto-Novo", 76 | "input": "Benin" 77 | }, 78 | { 79 | "output": "Thimphu", 80 | "input": "Bhutan" 81 | }, 82 | { 83 | "output": "La Paz", 84 | "input": "Bolivia" 85 | }, 86 | { 87 | "output": "Sarajevo", 88 | "input": "Bosnia and Herzegovina" 89 | }, 90 | { 91 | "output": "Gaborone", 92 | "input": "Botswana" 93 | }, 94 | { 95 | "output": "Brasilia", 96 | "input": "Brazil" 97 | }, 98 | { 99 | "output": "Bandar Seri Begawan", 100 | "input": "Brunei" 101 | }, 102 | { 103 | "output": "Sofia", 104 | "input": "Bulgaria" 105 | }, 106 | { 107 | "output": "Ouagadougou", 108 | "input": "Burkina Faso" 109 | }, 110 | { 111 | "output": "Bujumbura", 112 | "input": "Burundi" 113 | }, 114 | { 115 | "output": "Praia", 116 | "input": "Cabo Verde" 117 | }, 118 | { 119 | "output": "Phnom Penh", 120 | "input": "Cambodia" 121 | }, 122 | { 123 | "output": "Yaounde", 124 | "input": "Cameroon" 125 | }, 126 | { 127 | "output": "Ottawa", 128 | "input": "Canada" 129 | }, 130 | { 131 | "output": "Bangui", 132 | "input": "Central African Republic" 133 | }, 134 | { 135 | "output": "N'Djamena", 136 | "input": "Chad" 137 | }, 138 | { 139 | 
"output": "Santiago", 140 | "input": "Chile" 141 | }, 142 | { 143 | "output": "Beijing", 144 | "input": "China" 145 | }, 146 | { 147 | "output": "Bogotá", 148 | "input": "Colombia" 149 | }, 150 | { 151 | "output": "Moroni", 152 | "input": "Comoros" 153 | }, 154 | { 155 | "output": "Kinshasa", 156 | "input": "Congo" 157 | }, 158 | { 159 | "output": "San José", 160 | "input": "Costa Rica" 161 | }, 162 | { 163 | "output": "Yamoussoukro", 164 | "input": "Cote d'Ivoire" 165 | }, 166 | { 167 | "output": "Zagreb", 168 | "input": "Croatia" 169 | }, 170 | { 171 | "output": "Havana", 172 | "input": "Cuba" 173 | }, 174 | { 175 | "output": "Nicosia", 176 | "input": "Cyprus" 177 | }, 178 | { 179 | "output": "Prague", 180 | "input": "Czech Republic" 181 | }, 182 | { 183 | "output": "Kinshasa", 184 | "input": "Democratic Republic of the Congo" 185 | }, 186 | { 187 | "output": "Copenhagen", 188 | "input": "Denmark" 189 | }, 190 | { 191 | "output": "Djibouti City", 192 | "input": "Djibouti" 193 | }, 194 | { 195 | "output": "Roseau", 196 | "input": "Dominica" 197 | }, 198 | { 199 | "output": "Santo Domingo", 200 | "input": "Dominican Republic" 201 | }, 202 | { 203 | "output": "Quito", 204 | "input": "Ecuador" 205 | }, 206 | { 207 | "output": "Cairo", 208 | "input": "Egypt" 209 | }, 210 | { 211 | "output": "San Salvador", 212 | "input": "El Salvador" 213 | }, 214 | { 215 | "output": "Malabo", 216 | "input": "Equatorial Guinea" 217 | }, 218 | { 219 | "output": "Asmara", 220 | "input": "Eritrea" 221 | }, 222 | { 223 | "output": "Tallinn", 224 | "input": "Estonia" 225 | }, 226 | { 227 | "output": "Mbabane", 228 | "input": "Eswatini" 229 | }, 230 | { 231 | "output": "Addis Ababa", 232 | "input": "Ethiopia" 233 | }, 234 | { 235 | "output": "Suva", 236 | "input": "Fiji" 237 | }, 238 | { 239 | "output": "Helsinki", 240 | "input": "Finland" 241 | }, 242 | { 243 | "output": "Paris", 244 | "input": "France" 245 | }, 246 | { 247 | "output": "Libreville", 248 | "input": "Gabon" 249 | }, 250 | { 
251 | "output": "Banjul", 252 | "input": "Gambia" 253 | }, 254 | { 255 | "output": "Tbilisi", 256 | "input": "Georgia" 257 | }, 258 | { 259 | "output": "Berlin", 260 | "input": "Germany" 261 | }, 262 | { 263 | "output": "Accra", 264 | "input": "Ghana" 265 | }, 266 | { 267 | "output": "Athens", 268 | "input": "Greece" 269 | }, 270 | { 271 | "output": "St. George's", 272 | "input": "Grenada" 273 | }, 274 | { 275 | "output": "Guatemala City", 276 | "input": "Guatemala" 277 | }, 278 | { 279 | "output": "Conakry", 280 | "input": "Guinea" 281 | }, 282 | { 283 | "output": "Bissau", 284 | "input": "Guinea-Bissau" 285 | }, 286 | { 287 | "output": "Georgetown", 288 | "input": "Guyana" 289 | }, 290 | { 291 | "output": "Port-au-Prince", 292 | "input": "Haiti" 293 | }, 294 | { 295 | "output": "Tegucigalpa", 296 | "input": "Honduras" 297 | }, 298 | { 299 | "output": "Budapest", 300 | "input": "Hungary" 301 | }, 302 | { 303 | "output": "Reykjavik", 304 | "input": "Iceland" 305 | }, 306 | { 307 | "output": "New Delhi", 308 | "input": "India" 309 | }, 310 | { 311 | "output": "Jakarta", 312 | "input": "Indonesia" 313 | }, 314 | { 315 | "output": "Tehran", 316 | "input": "Iran" 317 | }, 318 | { 319 | "output": "Baghdad", 320 | "input": "Iraq" 321 | }, 322 | { 323 | "output": "Dublin", 324 | "input": "Ireland" 325 | }, 326 | { 327 | "output": "Jerusalem", 328 | "input": "Israel" 329 | }, 330 | { 331 | "output": "Rome", 332 | "input": "Italy" 333 | }, 334 | { 335 | "output": "Kingston", 336 | "input": "Jamaica" 337 | }, 338 | { 339 | "output": "Tokyo", 340 | "input": "Japan" 341 | }, 342 | { 343 | "output": "Amman", 344 | "input": "Jordan" 345 | }, 346 | { 347 | "output": "Astana", 348 | "input": "Kazakhstan" 349 | }, 350 | { 351 | "output": "Nairobi", 352 | "input": "Kenya" 353 | }, 354 | { 355 | "output": "South Tarawa", 356 | "input": "Kiribati" 357 | }, 358 | { 359 | "output": "Pristina", 360 | "input": "Kosovo" 361 | }, 362 | { 363 | "output": "Kuwait City", 364 | "input": 
"Kuwait" 365 | }, 366 | { 367 | "output": "Bishkek", 368 | "input": "Kyrgyzstan" 369 | }, 370 | { 371 | "output": "Vientiane", 372 | "input": "Laos" 373 | }, 374 | { 375 | "output": "Riga", 376 | "input": "Latvia" 377 | }, 378 | { 379 | "output": "Beirut", 380 | "input": "Lebanon" 381 | }, 382 | { 383 | "output": "Maseru", 384 | "input": "Lesotho" 385 | }, 386 | { 387 | "output": "Monrovia", 388 | "input": "Liberia" 389 | }, 390 | { 391 | "output": "Tripoli", 392 | "input": "Libya" 393 | }, 394 | { 395 | "output": "Vaduz", 396 | "input": "Liechtenstein" 397 | }, 398 | { 399 | "output": "Vilnius", 400 | "input": "Lithuania" 401 | }, 402 | { 403 | "output": "Luxembourg City", 404 | "input": "Luxembourg" 405 | }, 406 | { 407 | "output": "Antananarivo", 408 | "input": "Madagascar" 409 | }, 410 | { 411 | "output": "Lilongwe", 412 | "input": "Malawi" 413 | }, 414 | { 415 | "output": "Kuala Lumpur", 416 | "input": "Malaysia" 417 | }, 418 | { 419 | "output": "Malé", 420 | "input": "Maldives" 421 | }, 422 | { 423 | "output": "Bamako", 424 | "input": "Mali" 425 | }, 426 | { 427 | "output": "Valletta", 428 | "input": "Malta" 429 | }, 430 | { 431 | "output": "Majuro", 432 | "input": "Marshall Islands" 433 | }, 434 | { 435 | "output": "Nouakchott", 436 | "input": "Mauritania" 437 | }, 438 | { 439 | "output": "Port Louis", 440 | "input": "Mauritius" 441 | }, 442 | { 443 | "output": "Mexico City", 444 | "input": "Mexico" 445 | }, 446 | { 447 | "output": "Palikir", 448 | "input": "Micronesia" 449 | }, 450 | { 451 | "output": "Chisinau", 452 | "input": "Moldova" 453 | }, 454 | { 455 | "output": "Monaco-Ville", 456 | "input": "Monaco" 457 | }, 458 | { 459 | "output": "Ulaanbaatar", 460 | "input": "Mongolia" 461 | }, 462 | { 463 | "output": "Podgorica", 464 | "input": "Montenegro" 465 | }, 466 | { 467 | "output": "Rabat", 468 | "input": "Morocco" 469 | }, 470 | { 471 | "output": "Maputo", 472 | "input": "Mozambique" 473 | }, 474 | { 475 | "output": "Naypyidaw", 476 | "input": 
"Myanmar" 477 | }, 478 | { 479 | "output": "Windhoek", 480 | "input": "Namibia" 481 | }, 482 | { 483 | "output": "Yaren District", 484 | "input": "Nauru" 485 | }, 486 | { 487 | "output": "Kathmandu", 488 | "input": "Nepal" 489 | }, 490 | { 491 | "output": "Amsterdam", 492 | "input": "Netherlands" 493 | }, 494 | { 495 | "output": "Wellington", 496 | "input": "New Zealand" 497 | }, 498 | { 499 | "output": "Managua", 500 | "input": "Nicaragua" 501 | }, 502 | { 503 | "output": "Niamey", 504 | "input": "Niger" 505 | }, 506 | { 507 | "output": "Abuja", 508 | "input": "Nigeria" 509 | }, 510 | { 511 | "output": "Pyongyang", 512 | "input": "North Korea" 513 | }, 514 | { 515 | "output": "Skopje", 516 | "input": "North Macedonia" 517 | }, 518 | { 519 | "output": "Oslo", 520 | "input": "Norway" 521 | }, 522 | { 523 | "output": "Muscat", 524 | "input": "Oman" 525 | }, 526 | { 527 | "output": "Islamabad", 528 | "input": "Pakistan" 529 | }, 530 | { 531 | "output": "Ngerulmud", 532 | "input": "Palau" 533 | }, 534 | { 535 | "output": "Ramallah", 536 | "input": "Palestine" 537 | }, 538 | { 539 | "output": "Panama City", 540 | "input": "Panama" 541 | }, 542 | { 543 | "output": "Port Moresby", 544 | "input": "Papua New Guinea" 545 | }, 546 | { 547 | "output": "Asunción", 548 | "input": "Paraguay" 549 | }, 550 | { 551 | "output": "Lima", 552 | "input": "Peru" 553 | }, 554 | { 555 | "output": "Manila", 556 | "input": "Philippines" 557 | }, 558 | { 559 | "output": "Warsaw", 560 | "input": "Poland" 561 | }, 562 | { 563 | "output": "Lisbon", 564 | "input": "Portugal" 565 | }, 566 | { 567 | "output": "Doha", 568 | "input": "Qatar" 569 | }, 570 | { 571 | "output": "Bucharest", 572 | "input": "Romania" 573 | }, 574 | { 575 | "output": "Moscow", 576 | "input": "Russia" 577 | }, 578 | { 579 | "output": "Kigali", 580 | "input": "Rwanda" 581 | }, 582 | { 583 | "output": "Basseterre", 584 | "input": "Saint Kitts and Nevis" 585 | }, 586 | { 587 | "output": "Castries", 588 | "input": "Saint Lucia" 
589 | }, 590 | { 591 | "output": "Kingstown", 592 | "input": "Saint Vincent and the Grenadines" 593 | }, 594 | { 595 | "output": "Apia", 596 | "input": "Samoa" 597 | }, 598 | { 599 | "output": "San Marino", 600 | "input": "San Marino" 601 | }, 602 | { 603 | "output": "Sao Tome", 604 | "input": "Sao Tome and Principe" 605 | }, 606 | { 607 | "output": "Riyadh", 608 | "input": "Saudi Arabia" 609 | }, 610 | { 611 | "output": "Dakar", 612 | "input": "Senegal" 613 | }, 614 | { 615 | "output": "Belgrade", 616 | "input": "Serbia" 617 | }, 618 | { 619 | "output": "Victoria", 620 | "input": "Seychelles" 621 | }, 622 | { 623 | "output": "Freetown", 624 | "input": "Sierra Leone" 625 | }, 626 | { 627 | "output": "Singapore", 628 | "input": "Singapore" 629 | }, 630 | { 631 | "output": "Bratislava", 632 | "input": "Slovakia" 633 | }, 634 | { 635 | "output": "Ljubljana", 636 | "input": "Slovenia" 637 | }, 638 | { 639 | "output": "Honiara", 640 | "input": "Solomon Islands" 641 | }, 642 | { 643 | "output": "Mogadishu", 644 | "input": "Somalia" 645 | }, 646 | { 647 | "output": "Pretoria", 648 | "input": "South Africa" 649 | }, 650 | { 651 | "output": "Seoul", 652 | "input": "South Korea" 653 | }, 654 | { 655 | "output": "Juba", 656 | "input": "South Sudan" 657 | }, 658 | { 659 | "output": "Madrid", 660 | "input": "Spain" 661 | }, 662 | { 663 | "output": "Colombo", 664 | "input": "Sri Lanka" 665 | }, 666 | { 667 | "output": "Khartoum", 668 | "input": "Sudan" 669 | }, 670 | { 671 | "output": "Paramaribo", 672 | "input": "Suriname" 673 | }, 674 | { 675 | "output": "Stockholm", 676 | "input": "Sweden" 677 | }, 678 | { 679 | "output": "Bern", 680 | "input": "Switzerland" 681 | }, 682 | { 683 | "output": "Damascus", 684 | "input": "Syria" 685 | }, 686 | { 687 | "output": "Taipei", 688 | "input": "Taiwan" 689 | }, 690 | { 691 | "output": "Dushanbe", 692 | "input": "Tajikistan" 693 | }, 694 | { 695 | "output": "Dodoma", 696 | "input": "Tanzania" 697 | }, 698 | { 699 | "output": "Bangkok", 
700 | "input": "Thailand" 701 | }, 702 | { 703 | "output": "Dili", 704 | "input": "Timor-Leste" 705 | }, 706 | { 707 | "output": "Lome", 708 | "input": "Togo" 709 | }, 710 | { 711 | "output": "Nukuʻalofa", 712 | "input": "Tonga" 713 | }, 714 | { 715 | "output": "Port of Spain", 716 | "input": "Trinidad and Tobago" 717 | }, 718 | { 719 | "output": "Tunis", 720 | "input": "Tunisia" 721 | }, 722 | { 723 | "output": "Ankara", 724 | "input": "Turkey" 725 | }, 726 | { 727 | "output": "Ashgabat", 728 | "input": "Turkmenistan" 729 | }, 730 | { 731 | "output": "Funafuti", 732 | "input": "Tuvalu" 733 | }, 734 | { 735 | "output": "Kampala", 736 | "input": "Uganda" 737 | }, 738 | { 739 | "output": "Kiev", 740 | "input": "Ukraine" 741 | }, 742 | { 743 | "output": "Abu Dhabi", 744 | "input": "United Arab Emirates" 745 | }, 746 | { 747 | "output": "London", 748 | "input": "United Kingdom" 749 | }, 750 | { 751 | "output": "Washington, D.C.", 752 | "input": "United States of America" 753 | }, 754 | { 755 | "output": "Montevideo", 756 | "input": "Uruguay" 757 | }, 758 | { 759 | "output": "Tashkent", 760 | "input": "Uzbekistan" 761 | }, 762 | { 763 | "output": "Port Vila", 764 | "input": "Vanuatu" 765 | }, 766 | { 767 | "output": "Vatican City", 768 | "input": "Vatican City" 769 | }, 770 | { 771 | "output": "Caracas", 772 | "input": "Venezuela" 773 | }, 774 | { 775 | "output": "Hanoi", 776 | "input": "Vietnam" 777 | }, 778 | { 779 | "output": "Sana'a", 780 | "input": "Yemen" 781 | }, 782 | { 783 | "output": "Lusaka", 784 | "input": "Zambia" 785 | }, 786 | { 787 | "output": "Harare", 788 | "input": "Zimbabwe" 789 | } 790 | ] -------------------------------------------------------------------------------- /GPT2/abstractive/next_item.json: -------------------------------------------------------------------------------- 1 | [{"input": "zero", "output": "one"}, {"input": "one", "output": "two"}, {"input": "two", "output": "three"}, {"input": "three", "output": "four"}, {"input": "four", 
"output": "five"}, {"input": "five", "output": "six"}, {"input": "six", "output": "seven"}, {"input": "seven", "output": "eight"}, {"input": "eight", "output": "nine"}, {"input": "nine", "output": "ten"}, {"input": "ten", "output": "eleven"}, {"input": "eleven", "output": "twelve"}, {"input": "twelve", "output": "thirteen"}, {"input": "thirteen", "output": "fourteen"}, {"input": "fourteen", "output": "fifteen"}, {"input": "fifteen", "output": "sixteen"}, {"input": "sixteen", "output": "seventeen"}, {"input": "seventeen", "output": "eighteen"}, {"input": "eighteen", "output": "nineteen"}, {"input": "nineteen", "output": "twenty"}, {"input": "0", "output": "1"}, {"input": "1", "output": "2"}, {"input": "2", "output": "3"}, {"input": "3", "output": "4"}, {"input": "4", "output": "5"}, {"input": "5", "output": "6"}, {"input": "6", "output": "7"}, {"input": "7", "output": "8"}, {"input": "8", "output": "9"}, {"input": "9", "output": "10"}, {"input": "10", "output": "11"}, {"input": "11", "output": "12"}, {"input": "12", "output": "13"}, {"input": "13", "output": "14"}, {"input": "14", "output": "15"}, {"input": "15", "output": "16"}, {"input": "16", "output": "17"}, {"input": "17", "output": "18"}, {"input": "18", "output": "19"}, {"input": "19", "output": "20"}, {"input": "20", "output": "21"}, {"input": "21", "output": "22"}, {"input": "22", "output": "23"}, {"input": "23", "output": "24"}, {"input": "24", "output": "25"}, {"input": "25", "output": "26"}, {"input": "26", "output": "27"}, {"input": "27", "output": "28"}, {"input": "28", "output": "29"}, {"input": "a", "output": "b"}, {"input": "b", "output": "c"}, {"input": "c", "output": "d"}, {"input": "d", "output": "e"}, {"input": "e", "output": "f"}, {"input": "f", "output": "g"}, {"input": "g", "output": "h"}, {"input": "h", "output": "i"}, {"input": "i", "output": "j"}, {"input": "j", "output": "k"}, {"input": "k", "output": "l"}, {"input": "l", "output": "m"}, {"input": "m", "output": "n"}, {"input": "n", 
"output": "o"}, {"input": "o", "output": "p"}, {"input": "p", "output": "q"}, {"input": "q", "output": "r"}, {"input": "r", "output": "s"}, {"input": "s", "output": "t"}, {"input": "t", "output": "u"}, {"input": "u", "output": "v"}, {"input": "v", "output": "w"}, {"input": "w", "output": "x"}, {"input": "x", "output": "y"}, {"input": "y", "output": "z"}, {"input": "A", "output": "B"}, {"input": "B", "output": "C"}, {"input": "C", "output": "D"}, {"input": "D", "output": "E"}, {"input": "E", "output": "F"}, {"input": "F", "output": "G"}, {"input": "G", "output": "H"}, {"input": "H", "output": "I"}, {"input": "I", "output": "J"}, {"input": "J", "output": "K"}, {"input": "K", "output": "L"}, {"input": "L", "output": "M"}, {"input": "M", "output": "N"}, {"input": "N", "output": "O"}, {"input": "O", "output": "P"}, {"input": "P", "output": "Q"}, {"input": "Q", "output": "R"}, {"input": "R", "output": "S"}, {"input": "S", "output": "T"}, {"input": "T", "output": "U"}, {"input": "U", "output": "V"}, {"input": "V", "output": "W"}, {"input": "W", "output": "X"}, {"input": "X", "output": "Y"}, {"input": "Y", "output": "Z"}, {"input": "AA", "output": "BB"}, {"input": "BB", "output": "CC"}, {"input": "CC", "output": "DD"}, {"input": "DD", "output": "EE"}, {"input": "EE", "output": "FF"}, {"input": "FF", "output": "GG"}, {"input": "GG", "output": "HH"}, {"input": "HH", "output": "II"}, {"input": "II", "output": "JJ"}, {"input": "JJ", "output": "KK"}, {"input": "KK", "output": "LL"}, {"input": "LL", "output": "MM"}, {"input": "MM", "output": "NN"}, {"input": "NN", "output": "OO"}, {"input": "OO", "output": "PP"}, {"input": "PP", "output": "QQ"}, {"input": "QQ", "output": "RR"}, {"input": "RR", "output": "SS"}, {"input": "SS", "output": "TT"}, {"input": "TT", "output": "UU"}, {"input": "UU", "output": "VV"}, {"input": "VV", "output": "WW"}, {"input": "WW", "output": "XX"}, {"input": "XX", "output": "YY"}, {"input": "YY", "output": "ZZ"}, {"input": "aa", "output": "bb"}, {"input": 
"bb", "output": "cc"}, {"input": "cc", "output": "dd"}, {"input": "dd", "output": "ee"}, {"input": "ee", "output": "ff"}, {"input": "ff", "output": "gg"}, {"input": "gg", "output": "hh"}, {"input": "hh", "output": "ii"}, {"input": "ii", "output": "jj"}, {"input": "jj", "output": "kk"}, {"input": "kk", "output": "ll"}, {"input": "ll", "output": "mm"}, {"input": "mm", "output": "nn"}, {"input": "nn", "output": "oo"}, {"input": "oo", "output": "pp"}, {"input": "pp", "output": "qq"}, {"input": "qq", "output": "rr"}, {"input": "rr", "output": "ss"}, {"input": "ss", "output": "tt"}, {"input": "tt", "output": "uu"}, {"input": "uu", "output": "vv"}, {"input": "vv", "output": "ww"}, {"input": "ww", "output": "xx"}, {"input": "xx", "output": "yy"}, {"input": "yy", "output": "zz"}, {"input": "I", "output": "II"}, {"input": "II", "output": "III"}, {"input": "III", "output": "IV"}, {"input": "IV", "output": "V"}, {"input": "V", "output": "VI"}, {"input": "VI", "output": "VII"}, {"input": "VII", "output": "VIII"}, {"input": "VIII", "output": "IX"}, {"input": "IX", "output": "X"}, {"input": "X", "output": "XI"}, {"input": "XI", "output": "XII"}, {"input": "XII", "output": "XIII"}, {"input": "XIII", "output": "XIV"}, {"input": "XIV", "output": "XV"}, {"input": "XV", "output": "XVI"}, {"input": "XVI", "output": "XVII"}, {"input": "XVII", "output": "XVIII"}, {"input": "XVIII", "output": "XIX"}, {"input": "XIX", "output": "XX"}, {"input": "i", "output": "ii"}, {"input": "ii", "output": "iii"}, {"input": "iii", "output": "iv"}, {"input": "iv", "output": "v"}, {"input": "v", "output": "vi"}, {"input": "vi", "output": "vii"}, {"input": "vii", "output": "viii"}, {"input": "viii", "output": "ix"}, {"input": "ix", "output": "x"}, {"input": "x", "output": "xi"}, {"input": "xi", "output": "xii"}, {"input": "xii", "output": "xiii"}, {"input": "xiii", "output": "xiv"}, {"input": "xiv", "output": "xv"}, {"input": "xv", "output": "xvi"}, {"input": "xvi", "output": "xvii"}, {"input": "xvii", 
"output": "xviii"}, {"input": "xviii", "output": "xix"}, {"input": "xix", "output": "xx"}, {"input": "monday", "output": "tuesday"}, {"input": "tuesday", "output": "wednesday"}, {"input": "wednesday", "output": "thursday"}, {"input": "thursday", "output": "friday"}, {"input": "friday", "output": "saturday"}, {"input": "saturday", "output": "sunday"}, {"input": "january", "output": "february"}, {"input": "february", "output": "march"}, {"input": "march", "output": "april"}, {"input": "april", "output": "may"}, {"input": "may", "output": "june"}, {"input": "june", "output": "july"}, {"input": "july", "output": "august"}, {"input": "august", "output": "september"}, {"input": "september", "output": "october"}, {"input": "october", "output": "november"}, {"input": "november", "output": "december"}, {"input": "Monday", "output": "Tuesday"}, {"input": "Tuesday", "output": "Wednesday"}, {"input": "Wednesday", "output": "Thursday"}, {"input": "Thursday", "output": "Friday"}, {"input": "Friday", "output": "Saturday"}, {"input": "Saturday", "output": "Sunday"}, {"input": "January", "output": "February"}, {"input": "February", "output": "March"}, {"input": "March", "output": "April"}, {"input": "April", "output": "May"}, {"input": "May", "output": "June"}, {"input": "June", "output": "July"}, {"input": "July", "output": "August"}, {"input": "August", "output": "September"}, {"input": "September", "output": "October"}, {"input": "October", "output": "November"}, {"input": "November", "output": "December"}, {"input": "sunday", "output": "monday"}, {"input": "december", "output": "january"}, {"input": "Sunday", "output": "Monday"}, {"input": "December", "output": "January"}] -------------------------------------------------------------------------------- /GPT2/abstractive/prev_item.json: -------------------------------------------------------------------------------- 1 | [{"input": "one", "output": "zero"}, {"input": "two", "output": "one"}, {"input": "three", "output": "two"}, 
{"input": "four", "output": "three"}, {"input": "five", "output": "four"}, {"input": "six", "output": "five"}, {"input": "seven", "output": "six"}, {"input": "eight", "output": "seven"}, {"input": "nine", "output": "eight"}, {"input": "ten", "output": "nine"}, {"input": "eleven", "output": "ten"}, {"input": "twelve", "output": "eleven"}, {"input": "thirteen", "output": "twelve"}, {"input": "fourteen", "output": "thirteen"}, {"input": "fifteen", "output": "fourteen"}, {"input": "sixteen", "output": "fifteen"}, {"input": "seventeen", "output": "sixteen"}, {"input": "eighteen", "output": "seventeen"}, {"input": "nineteen", "output": "eighteen"}, {"input": "twenty", "output": "nineteen"}, {"input": "1", "output": "0"}, {"input": "2", "output": "1"}, {"input": "3", "output": "2"}, {"input": "4", "output": "3"}, {"input": "5", "output": "4"}, {"input": "6", "output": "5"}, {"input": "7", "output": "6"}, {"input": "8", "output": "7"}, {"input": "9", "output": "8"}, {"input": "10", "output": "9"}, {"input": "11", "output": "10"}, {"input": "12", "output": "11"}, {"input": "13", "output": "12"}, {"input": "14", "output": "13"}, {"input": "15", "output": "14"}, {"input": "16", "output": "15"}, {"input": "17", "output": "16"}, {"input": "18", "output": "17"}, {"input": "19", "output": "18"}, {"input": "20", "output": "19"}, {"input": "21", "output": "20"}, {"input": "22", "output": "21"}, {"input": "23", "output": "22"}, {"input": "24", "output": "23"}, {"input": "25", "output": "24"}, {"input": "26", "output": "25"}, {"input": "27", "output": "26"}, {"input": "28", "output": "27"}, {"input": "29", "output": "28"}, {"input": "b", "output": "a"}, {"input": "c", "output": "b"}, {"input": "d", "output": "c"}, {"input": "e", "output": "d"}, {"input": "f", "output": "e"}, {"input": "g", "output": "f"}, {"input": "h", "output": "g"}, {"input": "i", "output": "h"}, {"input": "j", "output": "i"}, {"input": "k", "output": "j"}, {"input": "l", "output": "k"}, {"input": "m", "output": 
"l"}, {"input": "n", "output": "m"}, {"input": "o", "output": "n"}, {"input": "p", "output": "o"}, {"input": "q", "output": "p"}, {"input": "r", "output": "q"}, {"input": "s", "output": "r"}, {"input": "t", "output": "s"}, {"input": "u", "output": "t"}, {"input": "v", "output": "u"}, {"input": "w", "output": "v"}, {"input": "x", "output": "w"}, {"input": "y", "output": "x"}, {"input": "z", "output": "y"}, {"input": "B", "output": "A"}, {"input": "C", "output": "B"}, {"input": "D", "output": "C"}, {"input": "E", "output": "D"}, {"input": "F", "output": "E"}, {"input": "G", "output": "F"}, {"input": "H", "output": "G"}, {"input": "I", "output": "H"}, {"input": "J", "output": "I"}, {"input": "K", "output": "J"}, {"input": "L", "output": "K"}, {"input": "M", "output": "L"}, {"input": "N", "output": "M"}, {"input": "O", "output": "N"}, {"input": "P", "output": "O"}, {"input": "Q", "output": "P"}, {"input": "R", "output": "Q"}, {"input": "S", "output": "R"}, {"input": "T", "output": "S"}, {"input": "U", "output": "T"}, {"input": "V", "output": "U"}, {"input": "W", "output": "V"}, {"input": "X", "output": "W"}, {"input": "Y", "output": "X"}, {"input": "Z", "output": "Y"}, {"input": "BB", "output": "AA"}, {"input": "CC", "output": "BB"}, {"input": "DD", "output": "CC"}, {"input": "EE", "output": "DD"}, {"input": "FF", "output": "EE"}, {"input": "GG", "output": "FF"}, {"input": "HH", "output": "GG"}, {"input": "II", "output": "HH"}, {"input": "JJ", "output": "II"}, {"input": "KK", "output": "JJ"}, {"input": "LL", "output": "KK"}, {"input": "MM", "output": "LL"}, {"input": "NN", "output": "MM"}, {"input": "OO", "output": "NN"}, {"input": "PP", "output": "OO"}, {"input": "QQ", "output": "PP"}, {"input": "RR", "output": "QQ"}, {"input": "SS", "output": "RR"}, {"input": "TT", "output": "SS"}, {"input": "UU", "output": "TT"}, {"input": "VV", "output": "UU"}, {"input": "WW", "output": "VV"}, {"input": "XX", "output": "WW"}, {"input": "YY", "output": "XX"}, {"input": "ZZ", 
"output": "YY"}, {"input": "bb", "output": "aa"}, {"input": "cc", "output": "bb"}, {"input": "dd", "output": "cc"}, {"input": "ee", "output": "dd"}, {"input": "ff", "output": "ee"}, {"input": "gg", "output": "ff"}, {"input": "hh", "output": "gg"}, {"input": "ii", "output": "hh"}, {"input": "jj", "output": "ii"}, {"input": "kk", "output": "jj"}, {"input": "ll", "output": "kk"}, {"input": "mm", "output": "ll"}, {"input": "nn", "output": "mm"}, {"input": "oo", "output": "nn"}, {"input": "pp", "output": "oo"}, {"input": "qq", "output": "pp"}, {"input": "rr", "output": "qq"}, {"input": "ss", "output": "rr"}, {"input": "tt", "output": "ss"}, {"input": "uu", "output": "tt"}, {"input": "vv", "output": "uu"}, {"input": "ww", "output": "vv"}, {"input": "xx", "output": "ww"}, {"input": "yy", "output": "xx"}, {"input": "zz", "output": "yy"}, {"input": "II", "output": "I"}, {"input": "III", "output": "II"}, {"input": "IV", "output": "III"}, {"input": "V", "output": "IV"}, {"input": "VI", "output": "V"}, {"input": "VII", "output": "VI"}, {"input": "VIII", "output": "VII"}, {"input": "IX", "output": "VIII"}, {"input": "X", "output": "IX"}, {"input": "XI", "output": "X"}, {"input": "XII", "output": "XI"}, {"input": "XIII", "output": "XII"}, {"input": "XIV", "output": "XIII"}, {"input": "XV", "output": "XIV"}, {"input": "XVI", "output": "XV"}, {"input": "XVII", "output": "XVI"}, {"input": "XVIII", "output": "XVII"}, {"input": "XIX", "output": "XVIII"}, {"input": "XX", "output": "XIX"}, {"input": "ii", "output": "i"}, {"input": "iii", "output": "ii"}, {"input": "iv", "output": "iii"}, {"input": "v", "output": "iv"}, {"input": "vi", "output": "v"}, {"input": "vii", "output": "vi"}, {"input": "viii", "output": "vii"}, {"input": "ix", "output": "viii"}, {"input": "x", "output": "ix"}, {"input": "xi", "output": "x"}, {"input": "xii", "output": "xi"}, {"input": "xiii", "output": "xii"}, {"input": "xiv", "output": "xiii"}, {"input": "xv", "output": "xiv"}, {"input": "xvi", "output": 
"xv"}, {"input": "xvii", "output": "xvi"}, {"input": "xviii", "output": "xvii"}, {"input": "xix", "output": "xviii"}, {"input": "xx", "output": "xix"}, {"input": "tuesday", "output": "monday"}, {"input": "wednesday", "output": "tuesday"}, {"input": "thursday", "output": "wednesday"}, {"input": "friday", "output": "thursday"}, {"input": "saturday", "output": "friday"}, {"input": "sunday", "output": "saturday"}, {"input": "february", "output": "january"}, {"input": "march", "output": "february"}, {"input": "april", "output": "march"}, {"input": "may", "output": "april"}, {"input": "june", "output": "may"}, {"input": "july", "output": "june"}, {"input": "august", "output": "july"}, {"input": "september", "output": "august"}, {"input": "october", "output": "september"}, {"input": "november", "output": "october"}, {"input": "december", "output": "november"}, {"input": "Tuesday", "output": "Monday"}, {"input": "Wednesday", "output": "Tuesday"}, {"input": "Thursday", "output": "Wednesday"}, {"input": "Friday", "output": "Thursday"}, {"input": "Saturday", "output": "Friday"}, {"input": "Sunday", "output": "Saturday"}, {"input": "February", "output": "January"}, {"input": "March", "output": "February"}, {"input": "April", "output": "March"}, {"input": "May", "output": "April"}, {"input": "June", "output": "May"}, {"input": "July", "output": "June"}, {"input": "August", "output": "July"}, {"input": "September", "output": "August"}, {"input": "October", "output": "September"}, {"input": "November", "output": "October"}, {"input": "December", "output": "November"}, {"input": "monday", "output": "sunday"}, {"input": "january", "output": "december"}, {"input": "Monday", "output": "Sunday"}, {"input": "January", "output": "December"}] -------------------------------------------------------------------------------- /GPT2/abstractive/singular-plural.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "input": "wallet", 4 | 
"output": "wallets" 5 | }, 6 | { 7 | "input": "keychain", 8 | "output": "keychains" 9 | }, 10 | { 11 | "input": "mountain", 12 | "output": "mountains" 13 | }, 14 | { 15 | "input": "comb", 16 | "output": "combs" 17 | }, 18 | { 19 | "input": "monitor", 20 | "output": "monitors" 21 | }, 22 | { 23 | "input": "island", 24 | "output": "islands" 25 | }, 26 | { 27 | "input": "rake", 28 | "output": "rakes" 29 | }, 30 | { 31 | "input": "needle", 32 | "output": "needles" 33 | }, 34 | { 35 | "input": "lighter", 36 | "output": "lighters" 37 | }, 38 | { 39 | "input": "slipper", 40 | "output": "slippers" 41 | }, 42 | { 43 | "input": "fireplace", 44 | "output": "fireplaces" 45 | }, 46 | { 47 | "input": "ladder", 48 | "output": "ladders" 49 | }, 50 | { 51 | "input": "jacket", 52 | "output": "jackets" 53 | }, 54 | { 55 | "input": "helicopter", 56 | "output": "helicopters" 57 | }, 58 | { 59 | "input": "paintbrush", 60 | "output": "paintbrushes" 61 | }, 62 | { 63 | "input": "dustpan", 64 | "output": "dustpans" 65 | }, 66 | { 67 | "input": "wrench", 68 | "output": "wrenches" 69 | }, 70 | { 71 | "input": "tablet", 72 | "output": "tablets" 73 | }, 74 | { 75 | "input": "hoe", 76 | "output": "hoes" 77 | }, 78 | { 79 | "input": "tie", 80 | "output": "ties" 81 | }, 82 | { 83 | "input": "toy", 84 | "output": "toys" 85 | }, 86 | { 87 | "input": "glass", 88 | "output": "glasses" 89 | }, 90 | { 91 | "input": "hairdryer", 92 | "output": "hairdryers" 93 | }, 94 | { 95 | "input": "axe", 96 | "output": "axes" 97 | }, 98 | { 99 | "input": "vacuum", 100 | "output": "vacuums" 101 | }, 102 | { 103 | "input": "blush", 104 | "output": "blushes" 105 | }, 106 | { 107 | "input": "stove", 108 | "output": "stoves" 109 | }, 110 | { 111 | "input": "ladle", 112 | "output": "ladles" 113 | }, 114 | { 115 | "input": "poster", 116 | "output": "posters" 117 | }, 118 | { 119 | "input": "hat", 120 | "output": "hats" 121 | }, 122 | { 123 | "input": "lake", 124 | "output": "lakes" 125 | }, 126 | { 127 | "input": "razor", 
128 | "output": "razors" 129 | }, 130 | { 131 | "input": "bottle", 132 | "output": "bottles" 133 | }, 134 | { 135 | "input": "glove", 136 | "output": "gloves" 137 | }, 138 | { 139 | "input": "grater", 140 | "output": "graters" 141 | }, 142 | { 143 | "input": "dishwasher", 144 | "output": "dishwashers" 145 | }, 146 | { 147 | "input": "sofa", 148 | "output": "sofas" 149 | }, 150 | { 151 | "input": "bag", 152 | "output": "bags" 153 | }, 154 | { 155 | "input": "keyboard", 156 | "output": "keyboards" 157 | }, 158 | { 159 | "input": "clock", 160 | "output": "clocks" 161 | }, 162 | { 163 | "input": "book", 164 | "output": "books" 165 | }, 166 | { 167 | "input": "scarf", 168 | "output": "scarves" 169 | }, 170 | { 171 | "input": "pants", 172 | "output": "pants" 173 | }, 174 | { 175 | "input": "window", 176 | "output": "windows" 177 | }, 178 | { 179 | "input": "house", 180 | "output": "houses" 181 | }, 182 | { 183 | "input": "freezer", 184 | "output": "freezers" 185 | }, 186 | { 187 | "input": "rag", 188 | "output": "rags" 189 | }, 190 | { 191 | "input": "racquet", 192 | "output": "racquets" 193 | }, 194 | { 195 | "input": "hair gel", 196 | "output": "hair gels" 197 | }, 198 | { 199 | "input": "door", 200 | "output": "doors" 201 | }, 202 | { 203 | "input": "pillow", 204 | "output": "pillows" 205 | }, 206 | { 207 | "input": "ruler", 208 | "output": "rulers" 209 | }, 210 | { 211 | "input": "washer", 212 | "output": "washers" 213 | }, 214 | { 215 | "input": "ocean", 216 | "output": "oceans" 217 | }, 218 | { 219 | "input": "plate", 220 | "output": "plates" 221 | }, 222 | { 223 | "input": "eyeshadow", 224 | "output": "eyeshadows" 225 | }, 226 | { 227 | "input": "zipper", 228 | "output": "zippers" 229 | }, 230 | { 231 | "input": "radio", 232 | "output": "radios" 233 | }, 234 | { 235 | "input": "flower", 236 | "output": "flowers" 237 | }, 238 | { 239 | "input": "laptop", 240 | "output": "laptops" 241 | }, 242 | { 243 | "input": "eraser", 244 | "output": "erasers" 245 | }, 246 | { 
247 | "input": "corkscrew", 248 | "output": "corkscrews" 249 | }, 250 | { 251 | "input": "eyeliner", 252 | "output": "eyeliners" 253 | }, 254 | { 255 | "input": "desk", 256 | "output": "desks" 257 | }, 258 | { 259 | "input": "knife", 260 | "output": "knives" 261 | }, 262 | { 263 | "input": "helmet", 264 | "output": "helmets" 265 | }, 266 | { 267 | "input": "mixer", 268 | "output": "mixers" 269 | }, 270 | { 271 | "input": "microwave", 272 | "output": "microwaves" 273 | }, 274 | { 275 | "input": "button", 276 | "output": "buttons" 277 | }, 278 | { 279 | "input": "jar", 280 | "output": "jars" 281 | }, 282 | { 283 | "input": "pan", 284 | "output": "pans" 285 | }, 286 | { 287 | "input": "key", 288 | "output": "keys" 289 | }, 290 | { 291 | "input": "perfume", 292 | "output": "perfumes" 293 | }, 294 | { 295 | "input": "tape", 296 | "output": "tapes" 297 | }, 298 | { 299 | "input": "shoes", 300 | "output": "shoes" 301 | }, 302 | { 303 | "input": "shirt", 304 | "output": "shirts" 305 | }, 306 | { 307 | "input": "candle", 308 | "output": "candles" 309 | }, 310 | { 311 | "input": "juicer", 312 | "output": "juicers" 313 | }, 314 | { 315 | "input": "peeler", 316 | "output": "peelers" 317 | }, 318 | { 319 | "input": "mirror", 320 | "output": "mirrors" 321 | }, 322 | { 323 | "input": "mascara", 324 | "output": "mascaras" 325 | }, 326 | { 327 | "input": "whisk", 328 | "output": "whisks" 329 | }, 330 | { 331 | "input": "shovel", 332 | "output": "shovels" 333 | }, 334 | { 335 | "input": "marker", 336 | "output": "markers" 337 | }, 338 | { 339 | "input": "lotion", 340 | "output": "lotions" 341 | }, 342 | { 343 | "input": "matches", 344 | "output": "matches" 345 | }, 346 | { 347 | "input": "moon", 348 | "output": "moons" 349 | }, 350 | { 351 | "input": "pot", 352 | "output": "pots" 353 | }, 354 | { 355 | "input": "mop", 356 | "output": "mops" 357 | }, 358 | { 359 | "input": "sprinkler", 360 | "output": "sprinklers" 361 | }, 362 | { 363 | "input": "can", 364 | "output": "cans" 365 | }, 
366 | { 367 | "input": "notebook", 368 | "output": "notebooks" 369 | }, 370 | { 371 | "input": "airplane", 372 | "output": "airplanes" 373 | }, 374 | { 375 | "input": "tongs", 376 | "output": "tongs" 377 | }, 378 | { 379 | "input": "phone", 380 | "output": "phones" 381 | }, 382 | { 383 | "input": "paint", 384 | "output": "paints" 385 | }, 386 | { 387 | "input": "conditioner", 388 | "output": "conditioners" 389 | }, 390 | { 391 | "input": "purse", 392 | "output": "purses" 393 | }, 394 | { 395 | "input": "broom", 396 | "output": "brooms" 397 | }, 398 | { 399 | "input": "rug", 400 | "output": "rugs" 401 | }, 402 | { 403 | "input": "toaster", 404 | "output": "toasters" 405 | }, 406 | { 407 | "input": "kettle", 408 | "output": "kettles" 409 | }, 410 | { 411 | "input": "blender", 412 | "output": "blenders" 413 | }, 414 | { 415 | "input": "colander", 416 | "output": "colanders" 417 | }, 418 | { 419 | "input": "toothpaste", 420 | "output": "toothpastes" 421 | }, 422 | { 423 | "input": "bed", 424 | "output": "beds" 425 | }, 426 | { 427 | "input": "refrigerator", 428 | "output": "refrigerators" 429 | }, 430 | { 431 | "input": "stapler", 432 | "output": "staplers" 433 | }, 434 | { 435 | "input": "backpack", 436 | "output": "backpacks" 437 | }, 438 | { 439 | "input": "fabric", 440 | "output": "fabrics" 441 | }, 442 | { 443 | "input": "nut", 444 | "output": "nuts" 445 | }, 446 | { 447 | "input": "bowl", 448 | "output": "bowls" 449 | }, 450 | { 451 | "input": "bolt", 452 | "output": "bolts" 453 | }, 454 | { 455 | "input": "chopsticks", 456 | "output": "chopsticks" 457 | }, 458 | { 459 | "input": "computer", 460 | "output": "computers" 461 | }, 462 | { 463 | "input": "valley", 464 | "output": "valleys" 465 | }, 466 | { 467 | "input": "lantern", 468 | "output": "lanterns" 469 | }, 470 | { 471 | "input": "boot", 472 | "output": "boots" 473 | }, 474 | { 475 | "input": "bucket", 476 | "output": "buckets" 477 | }, 478 | { 479 | "input": "sandal", 480 | "output": "sandals" 481 | }, 482 
| { 483 | "input": "lock", 484 | "output": "locks" 485 | }, 486 | { 487 | "input": "drill", 488 | "output": "drills" 489 | }, 490 | { 491 | "input": "toothbrush", 492 | "output": "toothbrushes" 493 | }, 494 | { 495 | "input": "lamp", 496 | "output": "lamps" 497 | }, 498 | { 499 | "input": "star", 500 | "output": "stars" 501 | }, 502 | { 503 | "input": "bandana", 504 | "output": "bandanas" 505 | }, 506 | { 507 | "input": "stereo", 508 | "output": "stereos" 509 | }, 510 | { 511 | "input": "teapot", 512 | "output": "teapots" 513 | }, 514 | { 515 | "input": "thread", 516 | "output": "threads" 517 | }, 518 | { 519 | "input": "lawnmower", 520 | "output": "lawnmowers" 521 | }, 522 | { 523 | "input": "mug", 524 | "output": "mugs" 525 | }, 526 | { 527 | "input": "game", 528 | "output": "games" 529 | }, 530 | { 531 | "input": "shoe", 532 | "output": "shoes" 533 | }, 534 | { 535 | "input": "mattress", 536 | "output": "mattresses" 537 | }, 538 | { 539 | "input": "sunglasses", 540 | "output": "sunglasses" 541 | }, 542 | { 543 | "input": "river", 544 | "output": "rivers" 545 | }, 546 | { 547 | "input": "beach", 548 | "output": "beaches" 549 | }, 550 | { 551 | "input": "sponge", 552 | "output": "sponges" 553 | }, 554 | { 555 | "input": "blanket", 556 | "output": "blankets" 557 | }, 558 | { 559 | "input": "headphones", 560 | "output": "headphones" 561 | }, 562 | { 563 | "input": "mouse", 564 | "output": "mice" 565 | }, 566 | { 567 | "input": "pen", 568 | "output": "pens" 569 | }, 570 | { 571 | "input": "tree", 572 | "output": "trees" 573 | }, 574 | { 575 | "input": "car", 576 | "output": "cars" 577 | }, 578 | { 579 | "input": "belt", 580 | "output": "belts" 581 | }, 582 | { 583 | "input": "sky", 584 | "output": "skies" 585 | }, 586 | { 587 | "input": "socks", 588 | "output": "socks" 589 | }, 590 | { 591 | "input": "briefcase", 592 | "output": "briefcases" 593 | }, 594 | { 595 | "input": "crayon", 596 | "output": "crayons" 597 | }, 598 | { 599 | "input": "screw", 600 | "output": 
"screws" 601 | }, 602 | { 603 | "input": "bicycle", 604 | "output": "bicycles" 605 | }, 606 | { 607 | "input": "grill", 608 | "output": "grills" 609 | }, 610 | { 611 | "input": "sun", 612 | "output": "suns" 613 | }, 614 | { 615 | "input": "coffee maker", 616 | "output": "coffee makers" 617 | }, 618 | { 619 | "input": "forest", 620 | "output": "forests" 621 | }, 622 | { 623 | "input": "spatula", 624 | "output": "spatulas" 625 | }, 626 | { 627 | "input": "brush", 628 | "output": "brushes" 629 | }, 630 | { 631 | "input": "cup", 632 | "output": "cups" 633 | }, 634 | { 635 | "input": "picture", 636 | "output": "pictures" 637 | }, 638 | { 639 | "input": "jewelry", 640 | "output": "jewelries" 641 | }, 642 | { 643 | "input": "shampoo", 644 | "output": "shampoos" 645 | }, 646 | { 647 | "input": "speaker", 648 | "output": "speakers" 649 | }, 650 | { 651 | "input": "boat", 652 | "output": "boats" 653 | }, 654 | { 655 | "input": "bus", 656 | "output": "buses" 657 | }, 658 | { 659 | "input": "sock", 660 | "output": "socks" 661 | }, 662 | { 663 | "input": "pliers", 664 | "output": "pliers" 665 | }, 666 | { 667 | "input": "suitcase", 668 | "output": "suitcases" 669 | }, 670 | { 671 | "input": "saw", 672 | "output": "saws" 673 | }, 674 | { 675 | "input": "fork", 676 | "output": "forks" 677 | }, 678 | { 679 | "input": "calendar", 680 | "output": "calendars" 681 | }, 682 | { 683 | "input": "umbrella", 684 | "output": "umbrellas" 685 | }, 686 | { 687 | "input": "printer", 688 | "output": "printers" 689 | }, 690 | { 691 | "input": "hose", 692 | "output": "hoses" 693 | }, 694 | { 695 | "input": "paper", 696 | "output": "papers" 697 | }, 698 | { 699 | "input": "television", 700 | "output": "televisions" 701 | }, 702 | { 703 | "input": "newspaper", 704 | "output": "newspapers" 705 | }, 706 | { 707 | "input": "hammer", 708 | "output": "hammers" 709 | }, 710 | { 711 | "input": "chair", 712 | "output": "chairs" 713 | }, 714 | { 715 | "input": "ball", 716 | "output": "balls" 717 | }, 718 | { 
719 | "input": "watch", 720 | "output": "watches" 721 | }, 722 | { 723 | "input": "screwdriver", 724 | "output": "screwdrivers" 725 | }, 726 | { 727 | "input": "cloud", 728 | "output": "clouds" 729 | }, 730 | { 731 | "input": "dryer", 732 | "output": "dryers" 733 | }, 734 | { 735 | "input": "train", 736 | "output": "trains" 737 | }, 738 | { 739 | "input": "lipstick", 740 | "output": "lipsticks" 741 | }, 742 | { 743 | "input": "hairbrush", 744 | "output": "hairbrushes" 745 | }, 746 | { 747 | "input": "glue", 748 | "output": "glues" 749 | }, 750 | { 751 | "input": "bat", 752 | "output": "bats" 753 | }, 754 | { 755 | "input": "roller", 756 | "output": "rollers" 757 | }, 758 | { 759 | "input": "ship", 760 | "output": "ships" 761 | }, 762 | { 763 | "input": "spoon", 764 | "output": "spoons" 765 | }, 766 | { 767 | "input": "motorcycle", 768 | "output": "motorcycles" 769 | }, 770 | { 771 | "input": "oven", 772 | "output": "ovens" 773 | }, 774 | { 775 | "input": "soap", 776 | "output": "soaps" 777 | }, 778 | { 779 | "input": "table", 780 | "output": "tables" 781 | }, 782 | { 783 | "input": "doll", 784 | "output": "dolls" 785 | }, 786 | { 787 | "input": "flashlight", 788 | "output": "flashlights" 789 | }, 790 | { 791 | "input": "scissors", 792 | "output": "scissors" 793 | }, 794 | { 795 | "input": "pencil", 796 | "output": "pencils" 797 | }, 798 | { 799 | "input": "magazine", 800 | "output": "magazines" 801 | }, 802 | { 803 | "input": "iron", 804 | "output": "irons" 805 | }, 806 | { 807 | "input": "camera", 808 | "output": "cameras" 809 | }, 810 | { 811 | "input": "trash can", 812 | "output": "trash cans" 813 | }, 814 | { 815 | "input": "nail", 816 | "output": "nails" 817 | }, 818 | { 819 | "input": "box", 820 | "output": "boxes" 821 | } 822 | ] -------------------------------------------------------------------------------- /GPT2/config/.ipynb_checkpoints/train_gpt2_small_adam_l2loss-checkpoint.py: 
-------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'l2loss' 3 | wandb_run_name = 'gpt2-small-adamw' 4 | 5 | # these make the total batch size be ~0.5M 6 | # 8 batch size * 1024 block size * 6 gradaccum * 10 GPUs = 491,520 7 | batch_size = 8 8 | block_size = 1024 9 | gradient_accumulation_steps = 6 10 | 11 | n_layer = 12 12 | n_head = 12 13 | n_embd = 768 14 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 15 | bias = False 16 | 17 | # this makes total number of tokens be 300B 18 | #max_iters = 100000 19 | max_iters = 10000 20 | lr_decay_iters = 10000 21 | 22 | # eval stuff 23 | eval_interval = 1000 24 | eval_iters = 200 25 | log_interval = 10 26 | 27 | # optimizer 28 | optimizer_name = 'adamw' 29 | learning_rate = 6e-3 # max learning rate 30 | weight_decay = 0. #1e-1 31 | beta1 = 0.9 32 | beta2 = 0.95 33 | beta3 = 0. 34 | gamma = 1. 35 | lr_max = 6e-4 36 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 37 | # learning rate decay settings 38 | decay_lr = True # whether to decay the learning rate 39 | warmup_iters = 1000 # how many steps to warm up for 40 | #warmup_iters = 0 # how many steps to warm up for 41 | #min_lr = 3e-5 42 | min_lr = 3e-4 43 | 44 | compile = True 45 | 46 | #out_dir = 'out_small_adam_1k' 47 | out_dir = 'out_small_adam_l2loss' 48 | -------------------------------------------------------------------------------- /GPT2/config/train_gpt2_small_adam_l2loss.py: -------------------------------------------------------------------------------- 1 | wandb_log = True 2 | wandb_project = 'l2loss' 3 | wandb_run_name = 'gpt2-small-adamw' 4 | 5 | # these make the total batch size be ~0.5M 6 | # 8 batch size * 1024 block size * 6 gradaccum * 10 GPUs = 491,520 7 | batch_size = 8 8 | block_size = 1024 9 | gradient_accumulation_steps = 6 10 | 11 | n_layer = 12 12 | n_head = 12 13 | n_embd = 768 14 | dropout = 0.0 # for pretraining 0 is good, for 
finetuning try 0.1+ 15 | bias = False 16 | 17 | # this makes total number of tokens be 300B 18 | #max_iters = 100000 19 | max_iters = 10000 20 | lr_decay_iters = 10000 21 | 22 | # eval stuff 23 | eval_interval = 1000 24 | eval_iters = 200 25 | log_interval = 10 26 | 27 | # optimizer 28 | optimizer_name = 'adamw' 29 | learning_rate = 6e-3 # max learning rate 30 | weight_decay = 0. #1e-1 31 | beta1 = 0.9 32 | beta2 = 0.95 33 | beta3 = 0. 34 | gamma = 1. 35 | lr_max = 6e-4 36 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 37 | # learning rate decay settings 38 | decay_lr = True # whether to decay the learning rate 39 | warmup_iters = 1000 # how many steps to warm up for 40 | #warmup_iters = 0 # how many steps to warm up for 41 | #min_lr = 3e-5 42 | min_lr = 3e-4 43 | 44 | compile = True 45 | 46 | #out_dir = 'out_small_adam_1k' 47 | out_dir = 'out_small_adam_l2loss' 48 | -------------------------------------------------------------------------------- /GPT2/configurator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Poor Man's Configurator. Probably a terrible idea. Example usage: 3 | $ python train.py config/override_file.py --batch_size=32 4 | this will first run config/override_file.py, then override batch_size to 32 5 | 6 | The code in this file will be run as follows from e.g. train.py: 7 | >>> exec(open('configurator.py').read()) 8 | 9 | So it's not a Python module, it's just shuttling this code away from train.py 10 | The code in this script then overrides the globals() 11 | 12 | I know people are not going to love this, I just really dislike configuration 13 | complexity and having to prepend config. to every single variable. If someone 14 | comes up with a better simple Python solution I am all ears. 
15 | """ 16 | 17 | import sys 18 | from ast import literal_eval 19 | 20 | for arg in sys.argv[1:]: 21 | if '=' not in arg: 22 | # assume it's the name of a config file 23 | assert not arg.startswith('--') 24 | config_file = arg 25 | print(f"Overriding config with {config_file}:") 26 | with open(config_file) as f: 27 | print(f.read()) 28 | exec(open(config_file).read()) 29 | else: 30 | # assume it's a --key=value argument 31 | assert arg.startswith('--') 32 | key, val = arg.split('=') 33 | key = key[2:] 34 | if key in globals(): 35 | try: 36 | # attempt to eval it it (e.g. if bool, number, or etc) 37 | attempt = literal_eval(val) 38 | except (SyntaxError, ValueError): 39 | # if that goes wrong, just use the string 40 | attempt = val 41 | # ensure the types match ok 42 | assert type(attempt) == type(globals()[key]) 43 | # cross fingers 44 | print(f"Overriding: {key} = {attempt}") 45 | globals()[key] = attempt 46 | else: 47 | raise ValueError(f"Unknown config key: {key}") 48 | -------------------------------------------------------------------------------- /GPT2/prediction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 7, 6 | "id": "e5d5ac79", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "10000\n", 14 | "Resuming training from out_small_adam\n" 15 | ] 16 | }, 17 | { 18 | "name": "stderr", 19 | "output_type": "stream", 20 | "text": [ 21 | "/Users/ziming/opt/anaconda3/lib/python3.9/site-packages/torch/amp/autocast_mode.py:250: UserWarning: User provided device_type of 'cuda', but CUDA is not available. 
Disabling\n", 22 | " warnings.warn(\n" 23 | ] 24 | }, 25 | { 26 | "name": "stdout", 27 | "output_type": "stream", 28 | "text": [ 29 | "number of parameters: 123.59M\n" 30 | ] 31 | } 32 | ], 33 | "source": [ 34 | "import os\n", 35 | "import time\n", 36 | "import math\n", 37 | "import pickle\n", 38 | "from contextlib import nullcontext\n", 39 | "import socket\n", 40 | "import numpy as np\n", 41 | "import torch\n", 42 | "from torch.nn.parallel import DistributedDataParallel as DDP\n", 43 | "from torch.distributed import init_process_group, destroy_process_group\n", 44 | "\n", 45 | "from model_tmax import GPTConfig, GPT\n", 46 | "\n", 47 | "import warnings\n", 48 | "from typing import Union, Iterable, List, Dict, Tuple, Optional\n", 49 | "\n", 50 | "import torch\n", 51 | "from torch import Tensor, inf\n", 52 | "from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype, _has_foreach_support\n", 53 | "\n", 54 | "model_name = 'standard'\n", 55 | "#model_name = 'harmonic'\n", 56 | "\n", 57 | "ckpt_step = 10000\n", 58 | "ppp = []\n", 59 | "\n", 60 | "\n", 61 | "hostname = socket.gethostname()\n", 62 | "\n", 63 | "print(ckpt_step)\n", 64 | "# -----------------------------------------------------------------------------\n", 65 | "# default config values designed to train a gpt2 (124M) on OpenWebText\n", 66 | "# I/O\n", 67 | "out_dir = 'out'\n", 68 | "eval_interval = 2000\n", 69 | "log_interval = 1\n", 70 | "eval_iters = 200\n", 71 | "eval_only = False # if True, script exits right after the first eval\n", 72 | "always_save_checkpoint = True # if True, always save a checkpoint after each eval\n", 73 | "init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'\n", 74 | "# wandb logging\n", 75 | "wandb_log = False # disabled by default\n", 76 | "wandb_project = 'owt'\n", 77 | "wandb_run_name = 'gpt2' # 'run' + str(time.time())\n", 78 | "os.environ[\"WANDB_MODE\"]=\"offline\" # run wandb offline\n", 79 | "# data\n", 80 | "dataset = 'openwebtext'\n", 81 | 
"gradient_accumulation_steps = 5 # used to simulate larger batch sizes\n", 82 | "batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size\n", 83 | "block_size = 1024\n", 84 | "# model\n", 85 | "n_layer = 12\n", 86 | "n_head = 12\n", 87 | "n_embd = 768\n", 88 | "dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+\n", 89 | "bias = False # do we use bias inside LayerNorm and Linear layers?\n", 90 | "# optimizer\n", 91 | "optimizer_name = 'adamw' \n", 92 | "learning_rate = 6e-4 # max learning rate\n", 93 | "max_iters = 600000 # total number of training iterations\n", 94 | "weight_decay = 1e-1\n", 95 | "beta1 = 0.9\n", 96 | "beta2 = 0.95\n", 97 | "grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0\n", 98 | "rho = 0.1\n", 99 | "interval = 10\n", 100 | "variant = 4 \n", 101 | "# learning rate decay settings\n", 102 | "decay_lr = True # whether to decay the learning rate\n", 103 | "warmup_iters = 2000 # how many steps to warm up for\n", 104 | "lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla\n", 105 | "min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla\n", 106 | "# DDP settings\n", 107 | "backend = 'nccl' # 'nccl', 'gloo', etc.\n", 108 | "# system\n", 109 | "device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks\n", 110 | "dtype = 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler\n", 111 | "compile = True # use PyTorch 2.0 to compile the model to be faster\n", 112 | "scale_attn_by_inverse_layer_idx = True\n", 113 | "# -----------------------------------------------------------------------------\n", 114 | "config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]\n", 115 | "\n", 116 | "wandb_log = True\n", 117 | "wandb_project = 'anneal'\n", 118 | "wandb_run_name='gpt2-small-weightdecay-0-gradclipvalue-1'\n", 119 | "\n", 120 | "# 
these make the total batch size be ~0.5M\n", 121 | "# 8 batch size * 1024 block size * 6 gradaccum * 10 GPUs = 491,520\n", 122 | "batch_size = 8\n", 123 | "block_size = 1024\n", 124 | "gradient_accumulation_steps = 6\n", 125 | "\n", 126 | "n_layer = 12\n", 127 | "n_head = 12\n", 128 | "n_embd = 768\n", 129 | "dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+\n", 130 | "bias = False\n", 131 | "\n", 132 | "# this makes total number of tokens be 300B\n", 133 | "#max_iters = 100000\n", 134 | "max_iters = 10000\n", 135 | "lr_decay_iters = 10000\n", 136 | "\n", 137 | "# eval stuff\n", 138 | "eval_interval = 1000\n", 139 | "eval_iters = 200\n", 140 | "log_interval = 10\n", 141 | "\n", 142 | "# optimizer\n", 143 | "optimizer_name = 'adamw'\n", 144 | "learning_rate = 6e-4 # max learning rate\n", 145 | "weight_decay = 0 #1e-1\n", 146 | "beta1 = 0.9\n", 147 | "beta2 = 0.95\n", 148 | "beta3 = 0.\n", 149 | "gamma = 1.\n", 150 | "lr_max = 6e-4\n", 151 | "grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0\n", 152 | "# learning rate decay settings\n", 153 | "decay_lr = True # whether to decay the learning rate\n", 154 | "warmup_iters = 2000 # how many steps to warm up for\n", 155 | "#warmup_iters = 0 # how many steps to warm up for\n", 156 | "min_lr = 3e-5 \n", 157 | "\n", 158 | "compile = True\n", 159 | "\n", 160 | "if model_name == 'standard':\n", 161 | " out_dir = 'out_small_adam'\n", 162 | "else:\n", 163 | " out_dir = 'out_small_adam_hm'\n", 164 | "#out_dir = 'out_small_n_28_scale_28_6e-3_3e-4'\n", 165 | "device = 'cpu'\n", 166 | "\n", 167 | "config = {k: globals()[k] for k in config_keys} # will be useful for logging\n", 168 | "# -----------------------------------------------------------------------------\n", 169 | "\n", 170 | "# various inits, derived attributes, I/O setup\n", 171 | "ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?\n", 172 | "if ddp:\n", 173 | " init_process_group(backend=backend)\n", 174 | " ddp_rank = 
int(os.environ['RANK'])\n", 175 | " ddp_local_rank = int(os.environ['LOCAL_RANK'])\n", 176 | " world_size = int(os.environ[\"WORLD_SIZE\"])\n", 177 | " torch.cuda.set_device(ddp_local_rank)\n", 178 | " print(f\"Rank {ddp_rank}: world_size={world_size}, local_rank={ddp_local_rank}, hostname={hostname}\")\n", 179 | " master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.\n", 180 | " seed_offset = ddp_rank # each process gets a different seed\n", 181 | " device = torch.device(\"cuda\", ddp_local_rank)\n", 182 | "else:\n", 183 | " # if not ddp, we are running on a single gpu, and one process\n", 184 | " master_process = True\n", 185 | " seed_offset = 0\n", 186 | " gradient_accumulation_steps *= 8 # simulate 8 gpus\n", 187 | "\n", 188 | "if master_process:\n", 189 | " os.makedirs(out_dir, exist_ok=True)\n", 190 | "torch.manual_seed(5000 + seed_offset)\n", 191 | "torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul\n", 192 | "torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn\n", 193 | "#device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast\n", 194 | "device_type = 'cuda'\n", 195 | "# note: float16 data type will automatically use a GradScaler\n", 196 | "ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]\n", 197 | "ctx = nullcontext() if device_type == 'cpu' else torch.autocast(device_type=device_type, dtype=ptdtype)\n", 198 | "\n", 199 | "# poor man's data loader\n", 200 | "#data_dir = os.path.join('data', dataset)\n", 201 | "data_dir = os.path.join('./')\n", 202 | "#data_dir = './data'\n", 203 | "train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')\n", 204 | "val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')\n", 205 | "def get_batch(split):\n", 206 | " data = train_data if split == 'train' else val_data\n", 207 | " ix = torch.randint(len(data) - block_size, 
(batch_size,))\n", 208 | " x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])\n", 209 | " y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])\n", 210 | " if device_type == 'cuda':\n", 211 | " # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)\n", 212 | " x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)\n", 213 | " else:\n", 214 | " x, y = x.to(device), y.to(device)\n", 215 | " return x, y\n", 216 | "\n", 217 | "# init these up here, can override if init_from='resume' (i.e. from a checkpoint)\n", 218 | "iter_num = 0\n", 219 | "best_val_loss = 1e9\n", 220 | "\n", 221 | "# attempt to derive vocab_size from the dataset\n", 222 | "meta_path = os.path.join(data_dir, 'meta.pkl')\n", 223 | "meta_vocab_size = None\n", 224 | "if os.path.exists(meta_path):\n", 225 | " with open(meta_path, 'rb') as f:\n", 226 | " meta = pickle.load(f)\n", 227 | " meta_vocab_size = meta['vocab_size']\n", 228 | " print(f\"found vocab_size = {meta_vocab_size} (inside {meta_path})\")\n", 229 | "\n", 230 | "# model init\n", 231 | "model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,\n", 232 | " bias=bias, vocab_size=None, dropout=dropout, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx) # start with model_args from command line\n", 233 | "\n", 234 | "\n", 235 | "print(f\"Resuming training from {out_dir}\")\n", 236 | "# resume training from a checkpoint.\n", 237 | "ckpt_path = os.path.join(out_dir, f'ckpt_{ckpt_step}.pt')\n", 238 | "checkpoint = torch.load(ckpt_path, map_location=device)\n", 239 | "checkpoint_model_args = checkpoint['model_args']\n", 240 | "# force these config attributes to be equal otherwise we can't even resume training\n", 241 | "# the rest of the attributes (e.g. 
dropout) can stay as desired from command line\n", 242 | "for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:\n", 243 | " model_args[k] = checkpoint_model_args[k]\n", 244 | "# create the model\n", 245 | "gptconf = GPTConfig(**model_args)\n", 246 | "model = GPT(gptconf)\n", 247 | "state_dict = checkpoint['model']\n", 248 | "# fix the keys of the state dictionary :(\n", 249 | "# honestly no idea how checkpoints sometimes get this prefix, have to debug more\n", 250 | "unwanted_prefix = '_orig_mod.'\n", 251 | "for k,v in list(state_dict.items()):\n", 252 | " if k.startswith(unwanted_prefix):\n", 253 | " state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)\n", 254 | "model.load_state_dict(state_dict)\n", 255 | "iter_num = checkpoint['iter_num']\n", 256 | "best_val_loss = checkpoint['best_val_loss']\n" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 8, 262 | "id": "5e60eb06", 263 | "metadata": {}, 264 | "outputs": [ 265 | { 266 | "name": "stdout", 267 | "output_type": "stream", 268 | "text": [ 269 | "total tokens in training dataset 9035582489\n", 270 | "1 epoch = 18382.939634195962 steps\n" 271 | ] 272 | } 273 | ], 274 | "source": [ 275 | "import os\n", 276 | "import numpy as np\n", 277 | "\n", 278 | "data_dir = os.path.join('./')\n", 279 | "train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')\n", 280 | "val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')\n", 281 | "print('total tokens in training dataset', train_data.shape[0])\n", 282 | "print('1 epoch =', train_data.shape[0]/(480*1024), 'steps')" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": 9, 288 | "id": "8eeff0c1", 289 | "metadata": {}, 290 | "outputs": [ 291 | { 292 | "data": { 293 | "text/plain": [ 294 | "memmap([ 8585, 262, 1772, ..., 13815, 13, 50256], dtype=uint16)" 295 | ] 296 | }, 297 | "execution_count": 9, 298 | "metadata": {}, 299 | "output_type": 
"execute_result" 300 | } 301 | ], 302 | "source": [ 303 | "train_data" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": 10, 309 | "id": "df34b167", 310 | "metadata": {}, 311 | "outputs": [], 312 | "source": [ 313 | "import tiktoken\n", 314 | "enc = tiktoken.get_encoding(\"gpt2\")\n", 315 | "prompt = enc.decode(train_data[10000:10010])\n", 316 | "idx = enc.encode(prompt, disallowed_special=())\n", 317 | "\n", 318 | "generate_idx = model.generate(torch.tensor([idx], dtype=torch.long), max_new_tokens=10)\n", 319 | "generate_content = enc.decode(list(generate_idx[0]))" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": 11, 325 | "id": "7a97f959", 326 | "metadata": {}, 327 | "outputs": [ 328 | { 329 | "data": { 330 | "text/plain": [ 331 | "'chickpea’), which the P P Involved· away.42 P Trade Bee'" 332 | ] 333 | }, 334 | "execution_count": 11, 335 | "metadata": {}, 336 | "output_type": "execute_result" 337 | } 338 | ], 339 | "source": [ 340 | "generate_content" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": null, 346 | "id": "ffb766d3", 347 | "metadata": {}, 348 | "outputs": [], 349 | "source": [] 350 | } 351 | ], 352 | "metadata": { 353 | "kernelspec": { 354 | "display_name": "Python 3 (ipykernel)", 355 | "language": "python", 356 | "name": "python3" 357 | }, 358 | "language_info": { 359 | "codemirror_mode": { 360 | "name": "ipython", 361 | "version": 3 362 | }, 363 | "file_extension": ".py", 364 | "mimetype": "text/x-python", 365 | "name": "python", 366 | "nbconvert_exporter": "python", 367 | "pygments_lexer": "ipython3", 368 | "version": "3.9.7" 369 | } 370 | }, 371 | "nbformat": 4, 372 | "nbformat_minor": 5 373 | } 374 | -------------------------------------------------------------------------------- /GPT2/readme.md: -------------------------------------------------------------------------------- 1 | # Start 2 | 3 | * Step 1: create a folder data/openwebtext, and put both binary files 
train.bin and val.bin into the folder. Binary files can be downloaded from [here](https://www.dropbox.com/scl/fo/v24k2eltevgiszdfvean6/AF0j1Pu9ladYpDZbqSVKHGI?rlkey=jwa73nxrwt5bj13a6c9q0z20w&st=090g6v8w&dl=0). 4 | 5 | * Step 2: in terminal type `sbatch train_adam_l2loss.sh`. That's it! This should immediately work on supercloud (except that perhaps you need to pip install wandb etc., I don't remember exactly). If you don't want to train the model by yourself, pre-trained models can be found [here](https://www.dropbox.com/scl/fo/v24k2eltevgiszdfvean6/AF0j1Pu9ladYpDZbqSVKHGI?rlkey=jwa73nxrwt5bj13a6c9q0z20w&st=090g6v8w&dl=0) in folders `out_small_adam` (standard) and `out_small_adam_hm` (harmonic). Place both folders in the current folder. 6 | 7 | # Notice 8 | * The code is based on [sophia repo](https://github.com/Liuhong99/Sophia/tree/main), which in turn is based on [nanogpt](https://github.com/karpathy/nanoGPT/). The training pipeline might be unnecessarily complicated for our purposes (a lot of parallelization etc.). 9 | * My major changes (relevant to harmonic losses) are in `model_l2loss.py` and highlighted with comments "Ziming's note". The standard transformer is in `model.py`. The line in `train_adam_l2loss.py`, which is `from model_l2loss import GPT, GPTConfig`, specifies that we're using GPT with harmonic similarity. To use standard GPT, change the line to `from model import GPT, GPTConfig`. 10 | * To change configurations, e.g., the size of the network, go to `config/train_gpt2_small_adam_l2loss.py`. Although there are some hyperparameters being set up at the beginning of `train_adam_l2loss.py`, these hyperparameters are later overwritten by `config/train_gpt2_small_adam_l2loss.py`. 11 | * Given the complexity of the training code, I suspect a faster way to kickstart is playing with the `GPT` model in `model_l2loss.py` and `model.py`, writing training loops by oneself without caring to read other files. 
12 | 13 | # Parallelogram experiments 14 | * Function vector data are taken from [function vectors](https://github.com/ericwtodd/function_vectors/tree/main/dataset_files/abstractive) and are stored in ./abstractive. We only keep datasets that are consistent with our experimental setup. 15 | * function_vectors.ipynb: parallelogram loss 16 | * function_vectors_parallelogram.ipynb: parallelogram 17 | * How to generate text output: prediction.ipynb 18 | -------------------------------------------------------------------------------- /GPT2/train_adam_l2loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import pickle 5 | from contextlib import nullcontext 6 | import socket 7 | import numpy as np 8 | import torch 9 | from torch.nn.parallel import DistributedDataParallel as DDP 10 | from torch.distributed import init_process_group, destroy_process_group 11 | 12 | from model_l2loss import GPTConfig, GPT 13 | 14 | import warnings 15 | from typing import Union, Iterable, List, Dict, Tuple, Optional 16 | 17 | import torch 18 | from torch import Tensor, inf 19 | from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype, _has_foreach_support 20 | 21 | hostname = socket.gethostname() 22 | 23 | 24 | # ----------------------------------------------------------------------------- 25 | # default config values designed to train a gpt2 (124M) on OpenWebText 26 | # I/O 27 | out_dir = 'out' 28 | eval_interval = 2000 29 | log_interval = 1 30 | eval_iters = 200 31 | eval_only = False # if True, script exits right after the first eval 32 | always_save_checkpoint = True # if True, always save a checkpoint after each eval 33 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*' 34 | # wandb logging 35 | wandb_log = False # disabled by default 36 | wandb_project = 'owt' 37 | wandb_run_name = 'gpt2' # 'run' + str(time.time()) 38 | os.environ["WANDB_MODE"]="offline" # run wandb offline 
39 | # data 40 | dataset = 'openwebtext' 41 | gradient_accumulation_steps = 5 # used to simulate larger batch sizes 42 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size 43 | block_size = 1024 44 | # model 45 | n_layer = 12 46 | n_head = 12 47 | n_embd = 768 48 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 49 | bias = False # do we use bias inside LayerNorm and Linear layers? 50 | # optimizer 51 | optimizer_name = 'adamw' 52 | learning_rate = 6e-4 # max learning rate 53 | max_iters = 600000 # total number of training iterations 54 | weight_decay = 1e-1 55 | beta1 = 0.9 56 | beta2 = 0.95 57 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 58 | rho = 0.1 # forwarded to model.configure_optimizers below; NOTE(review): presumably a Sophia-style hyperparameter -- confirm against model_l2loss.configure_optimizers 59 | interval = 10 # NOTE(review): appears unused in this file -- confirm before removing 60 | variant = 4 # NOTE(review): appears unused in this file -- confirm before removing 61 | # learning rate decay settings 62 | decay_lr = True # whether to decay the learning rate 63 | warmup_iters = 2000 # how many steps to warm up for 64 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla 65 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 66 | # DDP settings 67 | backend = 'nccl' # 'nccl', 'gloo', etc.
68 | # system 69 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 70 | dtype = 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 71 | compile = True # use PyTorch 2.0 to compile the model to be faster; note: intentionally shadows the builtin compile() 72 | scale_attn_by_inverse_layer_idx = True 73 | # ----------------------------------------------------------------------------- 74 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] # snapshot of the simple-typed config globals above, used for logging 75 | exec(open('configurator.py').read()) # overrides from command line or config file; NOTE(review): this exec may also define extra globals (e.g. gamma, lr_max referenced later) -- confirm in config/train_gpt2_small_adam_l2loss.py 76 | config = {k: globals()[k] for k in config_keys} # will be useful for logging 77 | # ----------------------------------------------------------------------------- 78 | 79 | # various inits, derived attributes, I/O setup 80 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? (torchrun sets RANK) 81 | if ddp: 82 | init_process_group(backend=backend) 83 | ddp_rank = int(os.environ['RANK']) 84 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 85 | world_size = int(os.environ["WORLD_SIZE"]) 86 | torch.cuda.set_device(ddp_local_rank) 87 | print(f"Rank {ddp_rank}: world_size={world_size}, local_rank={ddp_local_rank}, hostname={hostname}") 88 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
89 | seed_offset = ddp_rank # each process gets a different seed 90 | device = torch.device("cuda", ddp_local_rank) 91 | else: 92 | # if not ddp, we are running on a single gpu, and one process 93 | master_process = True 94 | seed_offset = 0 95 | gradient_accumulation_steps *= 8 # simulate 8 gpus 96 | 97 | if master_process: 98 | os.makedirs(out_dir, exist_ok=True) 99 | torch.manual_seed(5000 + seed_offset) 100 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 101 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 102 | #device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 103 | device_type = 'cuda' # NOTE(review): hard-coded to 'cuda' (original conditional commented out above), so get_batch always pin_memory()s and ctx always autocasts cuda -- CPU/mps runs presumably unsupported, confirm 104 | # note: float16 data type will automatically use a GradScaler 105 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 106 | ctx = nullcontext() if device_type == 'cpu' else torch.autocast(device_type=device_type, dtype=ptdtype) 107 | 108 | # poor man's data loader 109 | #data_dir = os.path.join('data', dataset) 110 | data_dir = os.path.join('./data', dataset) 111 | #data_dir = './data' 112 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 113 | val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') 114 | def get_batch(split): # sample batch_size random windows of block_size tokens; y is x shifted by one (next-token targets) 115 | data = train_data if split == 'train' else val_data 116 | ix = torch.randint(len(data) - block_size, (batch_size,)) 117 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) 118 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) 119 | if device_type == 'cuda': 120 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 121 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 122 | else: 123 | x, y = x.to(device), y.to(device) 124 | return x, y 125 | 126 | # init these up here, can override if
init_from='resume' (i.e. from a checkpoint) 127 | iter_num = 0 128 | best_val_loss = 1e9 # sentinel: any real eval loss will beat this 129 | 130 | # attempt to derive vocab_size from the dataset 131 | meta_path = os.path.join(data_dir, 'meta.pkl') 132 | meta_vocab_size = None 133 | if os.path.exists(meta_path): 134 | with open(meta_path, 'rb') as f: 135 | meta = pickle.load(f) 136 | meta_vocab_size = meta['vocab_size'] 137 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})") 138 | 139 | # model init 140 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, 141 | bias=bias, vocab_size=None, dropout=dropout, scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx) # start with model_args from command line 142 | if init_from == 'scratch': 143 | # init a new model from scratch 144 | print("Initializing a new model from scratch") 145 | # determine the vocab size we'll use for from-scratch training 146 | if meta_vocab_size is None: 147 | print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)") 148 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304 149 | gptconf = GPTConfig(**model_args) 150 | model = GPT(gptconf) 151 | elif init_from == 'resume': 152 | print(f"Resuming training from {out_dir}") 153 | # resume training from a checkpoint. 154 | ckpt_path = os.path.join(out_dir, 'ckpt.pt') 155 | checkpoint = torch.load(ckpt_path, map_location=device) 156 | checkpoint_model_args = checkpoint['model_args'] 157 | # force these config attributes to be equal otherwise we can't even resume training 158 | # the rest of the attributes (e.g. dropout) can stay as desired from command line 159 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 160 | model_args[k] = checkpoint_model_args[k] 161 | # create the model 162 | gptconf = GPTConfig(**model_args) 163 | model = GPT(gptconf) 164 | state_dict = checkpoint['model'] 165 | # fix the keys of the state dictionary :( 166 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 167 | unwanted_prefix = '_orig_mod.' # NOTE(review): this prefix is what torch.compile()'s wrapper adds to state-dict keys -- stripping it lets the raw model load compiled-model checkpoints 168 | for k,v in list(state_dict.items()): 169 | if k.startswith(unwanted_prefix): 170 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 171 | model.load_state_dict(state_dict) 172 | iter_num = checkpoint['iter_num'] 173 | best_val_loss = checkpoint['best_val_loss'] 174 | elif init_from.startswith('gpt2'): 175 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}") 176 | # initialize from OpenAI GPT-2 weights 177 | override_args = dict(dropout=dropout) 178 | model = GPT.from_pretrained(init_from, override_args) 179 | # read off the created config params, so we can store them into checkpoint correctly 180 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 181 | model_args[k] = getattr(model.config, k) 182 | # crop down the model block size if desired, using model surgery 183 | if block_size < model.config.block_size: 184 | model.crop_block_size(block_size) 185 | model_args['block_size'] = block_size # so that the checkpoint will have the right value 186 | 187 | 188 | model.to(device) 189 | 190 | # initialize a GradScaler.
If enabled=False scaler is a no-op 191 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16')) # default growth_factor=2.0 may lead to overflow; growth_factor=1.1 is a safer alternative 192 | 193 | # optimizer 194 | optimizer = model.configure_optimizers(optimizer_name, weight_decay, learning_rate, (beta1, beta2), rho, gamma, lr_max, device_type) # NOTE(review): gamma and lr_max are not defined anywhere in this file -- they must be injected by the configurator exec above, otherwise this line raises NameError; confirm against the config file 195 | if init_from == 'resume': 196 | optimizer.load_state_dict(checkpoint['optimizer']) 197 | del state_dict 198 | del checkpoint 199 | # compile the model 200 | if compile: 201 | print("compiling the model... (takes a ~minute)") 202 | unoptimized_model = model 203 | model = torch.compile(model) # requires PyTorch 2.0 204 | 205 | # wrap model into DDP container 206 | if ddp: 207 | model = DDP(model, device_ids=[ddp_local_rank]) 208 | 209 | # helps estimate an arbitrarily accurate loss over either split using many batches 210 | @torch.no_grad() 211 | def estimate_loss(): # returns {'train': mean_loss, 'val': mean_loss} over eval_iters random batches; toggles eval()/train() around the measurement 212 | out = {} 213 | model.eval() 214 | for split in ['train', 'val']: 215 | losses = torch.zeros(eval_iters) 216 | for k in range(eval_iters): 217 | X, Y = get_batch(split) 218 | with ctx: 219 | logits, loss, acc = model(X, Y) 220 | losses[k] = loss.item() 221 | out[split] = losses.mean() 222 | model.train() 223 | return out 224 | 225 | # learning rate decay scheduler (cosine with warmup) 226 | 227 | 228 | def get_lr(it): # piecewise schedule: linear warmup -> cosine decay -> constant min_lr 229 | # 1) linear warmup for warmup_iters steps 230 | if it < warmup_iters: 231 | return learning_rate * it / warmup_iters 232 | # 2) if it > lr_decay_iters, return min learning rate 233 | if it > lr_decay_iters: 234 | return min_lr 235 | # 3) in between, use cosine decay down to min learning rate 236 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 237 | assert 0 <= decay_ratio <= 1 238 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 239 | return min_lr + coeff * (learning_rate - min_lr) 240 | 241 | 242 | 243 | _tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]] # NOTE(review): appears unused below -- confirm before removing 244 | 245 | # logging 246 | if wandb_log
and master_process: 247 | import wandb 248 | wandb.init(project=wandb_project, name=wandb_run_name, config=config) 249 | 250 | # training loop 251 | X, Y = get_batch('train') # fetch the very first batch 252 | t0 = time.time() 253 | local_iter_num = 0 # number of iterations in the lifetime of this process 254 | raw_model = model.module if ddp else model # unwrap DDP container if needed 255 | running_mfu = -1.0 256 | clip_time = 0 257 | 258 | 259 | def fillnan(x, nan_value=0.): # replace NaN/+inf/-inf entries of x with nan_value; NOTE(review): appears unused in this file -- confirm before removing 260 | x = torch.nan_to_num(x, nan=nan_value, posinf=nan_value, neginf=nan_value) 261 | return x 262 | 263 | while True: 264 | 265 | # determine and set the learning rate for this iteration 266 | lr = get_lr(iter_num) if decay_lr else learning_rate 267 | for param_group in optimizer.param_groups: 268 | param_group['lr'] = lr 269 | 270 | # evaluate the loss on train/val sets and write checkpoints 271 | if iter_num % eval_interval == 0 and master_process: 272 | losses = estimate_loss() 273 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 274 | if wandb_log: 275 | wandb.log({ 276 | "iter": iter_num, 277 | "train/loss": losses['train'], 278 | "val/loss": losses['val'], 279 | "lr": lr, 280 | "mfu": running_mfu*100, # convert to percentage 281 | }, step=iter_num) 282 | if losses['val'] < best_val_loss or always_save_checkpoint: 283 | best_val_loss = losses['val'] # NOTE(review): with always_save_checkpoint=True this tracks the *latest* eval loss, not the best -- confirm intended 284 | if iter_num > 0: 285 | checkpoint = { 286 | 'model': raw_model.state_dict(), 287 | 'optimizer': optimizer.state_dict(), 288 | 'model_args': model_args, 289 | 'iter_num': iter_num, 290 | 'best_val_loss': best_val_loss, 291 | 'config': config, 292 | } 293 | print(f"saving checkpoint to {out_dir}") 294 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) 295 | 296 | if iter_num % (int(eval_interval)) == 0 and master_process: # additionally keep a per-step snapshot (ckpt_<iter>.pt) alongside the rolling ckpt.pt above 297 | checkpoint = { 298 | 'model': raw_model.state_dict(), 299 | 'optimizer': optimizer.state_dict(), 300 | 'model_args': model_args, 301 | 'iter_num': iter_num, 302
| 'best_val_loss': best_val_loss, 303 | 'config': config, 304 | } 305 | print(f"saving checkpoint to {out_dir}") 306 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt_'+str(iter_num)+'.pt')) 307 | 308 | #torch.save(state_dict, path) 309 | 310 | if iter_num == 0 and eval_only: 311 | break 312 | 313 | # forward backward update, with optional gradient accumulation to simulate larger batch size 314 | # and using the GradScaler if data type is float16 315 | for micro_step in range(gradient_accumulation_steps): 316 | if ddp: 317 | # in DDP training we only need to sync gradients at the last micro step. 318 | # the official way to do this is with model.no_sync() context manager, but 319 | # I really dislike that this bloats the code and forces us to repeat code 320 | # looking at the source of that context manager, it just toggles this variable 321 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) 322 | with ctx: 323 | logits, loss, acc = model(X, Y) 324 | 325 | loss = loss / gradient_accumulation_steps # scale so the accumulated gradient matches the large-batch gradient 326 | # immediately async prefetch next batch while model is doing the forward pass on the GPU 327 | X, Y = get_batch('train') 328 | # backward pass, with gradient scaling if training in fp16 329 | scaler.scale(loss).backward() 330 | # clip the gradient 331 | if grad_clip != 0.0: 332 | scaler.unscale_(optimizer) 333 | 334 | total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 335 | if total_norm.item() > grad_clip: 336 | clip_time += 1 # count iterations where clipping actually fired (logged as train/clip_rate) 337 | # step the optimizer and scaler if training in fp16 338 | scaler.step(optimizer) 339 | scaler.update() 340 | 341 | # flush the gradients as soon as we can, no need for this memory anymore 342 | optimizer.zero_grad(set_to_none=True) 343 | 344 | # timing and logging 345 | t1 = time.time() 346 | dt = t1 - t0 347 | t0 = t1 348 | if iter_num % log_interval == 0 and master_process: 349 | lossf = loss.item() * gradient_accumulation_steps # loss as float.
note: this is a CPU-GPU sync point 350 | accf = acc.item() 351 | if local_iter_num >= 5: # let the training loop settle a bit 352 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt) 353 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu 354 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%, scale factor: {scaler.get_scale()}") 355 | params = [] 356 | for (name, p) in model.named_parameters(): 357 | params.append(p) 358 | total_param_norm = 0 359 | for p in params: 360 | param_norm = p.data.norm(2) 361 | total_param_norm += param_norm.item() ** 2 362 | total_param_norm = total_param_norm ** 0.5 363 | momentum_norm = 0. 364 | v_norm = 0. 365 | move_norm = 0. 366 | LL = len(optimizer.state_dict()['state']) # NOTE(review): optimizer.state_dict() is rebuilt on every access here and in the loop below -- consider caching it once per log step; also assumes Adam-style state keys (exp_avg/exp_avg_sq) 367 | for jj in range(LL): 368 | momentum_step = optimizer.state_dict()['state'][jj]['exp_avg'] 369 | v_step = optimizer.state_dict()['state'][jj]['exp_avg_sq'] 370 | move = momentum_step/(torch.sqrt(v_step) + 1e-8) 371 | 372 | momentum_norm += (momentum_step.detach().norm(2)) ** 2 373 | v_norm += (v_step.detach().norm(2)) ** 2 374 | move_norm += (move.detach().norm(2)) ** 2 375 | 376 | momentum_norm = torch.sqrt(momentum_norm).item() 377 | v_norm = torch.sqrt(v_norm).item() 378 | move_norm = torch.sqrt(move_norm).item() 379 | #alpha = get_alpha(iter_num) 380 | if wandb_log: 381 | wandb.log({ 382 | "iter": iter_num, 383 | "train/loss": lossf, 384 | "train/acc": accf, 385 | "lr": lr, 386 | #"alpha": alpha, 387 | "param_norm": total_param_norm, 388 | "momentum_norm" : momentum_norm, 389 | "v_norm" : v_norm, 390 | "move_norm" : move_norm, 391 | "train/clip_rate": clip_time / (iter_num + 1) 392 | }, step=iter_num) 393 | 394 | iter_num += 1 395 | local_iter_num += 1 396 | 397 | # termination conditions 398 | if iter_num > max_iters: 399 | break 400 | 401 | if ddp: 402 | destroy_process_group() 403 | --------------------------------------------------------------------------------
/GPT2/train_adam_l2loss.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -o train_adam_l2loss.log-%j 3 | # SBATCH --job-name=AdamTrain 4 | #SBATCH --nodes=4 5 | #SBATCH --ntasks-per-node=1 6 | #SBATCH --gres=gpu:volta:2 7 | #SBATCH --cpus-per-task=40 8 | # SBATCH --mem=250G 9 | # SBATCH --time=23:59:00 10 | 11 | source /etc/profile 12 | module load anaconda/2023a-pytorch 13 | module load cuda/11.4 14 | module load nccl/2.10.3-cuda11.4 15 | 16 | export NCCL_DEBUG=INFO 17 | export PYTHONFAULTHANDLER=1 18 | 19 | # Set up rendezvous parameters 20 | export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 21 | export MASTER_PORT=$(shuf -i 29000-29999 -n 1) 22 | 23 | echo "MASTER_ADDR: $MASTER_ADDR" 24 | echo "MASTER_PORT: $MASTER_PORT" 25 | echo "SLURM_JOB_ID: $SLURM_JOB_ID" 26 | echo "SLURM_NTASKS: $SLURM_NTASKS" 27 | echo "SLURM_NODELIST: $SLURM_NODELIST" 28 | 29 | # Use srun to ensure the job is distributed 30 | srun --nodes=$SLURM_JOB_NUM_NODES --ntasks=$SLURM_JOB_NUM_NODES \ 31 | torchrun \ 32 | --nnodes=$SLURM_JOB_NUM_NODES \ 33 | --nproc_per_node=2 \ 34 | --rdzv_id=$SLURM_JOB_ID \ 35 | --rdzv_backend=c10d \ 36 | --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ 37 | train_adam_l2loss.py \ 38 | config/train_gpt2_small_adam_l2loss.py \ 39 | --batch_size=6 \ 40 | --gradient_accumulation_steps=10 41 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Harmonic Loss Trains Interpretable AI Models 2 | 3 | This is the GitHub repository for the paper "Harmonic Loss Trains Interpretable AI Models" [[arXiv]](https://arxiv.org/abs/2502.01628) [[Twitter]](https://x.com/dbaek__/status/1886781418115862544) [[Github]](https://github.com/KindXiaoming/grow-crystals). 4 | 5 | ![Harmonic Demo](./figures/weights_evolution.gif) 6 | 7 | ## What is Harmonic Loss? 
8 | - Harmonic logit $d_i$ is defined as the $l_2$ distance between the weight vector $\mathbf{w}_i$ and the input (query) $\mathbf{x}$:  $d_i = \|\mathbf{w}_i - \mathbf{x}\|_2$. 9 | 10 | - The probability $p_i$ is computed using the harmonic max function: 11 | 12 | ![Harmonic Max](./figures/eq_harmax.png) 13 | 14 | 15 | where $n$ is the **harmonic exponent**—a hyperparameter that controls the heavy-tailedness of the probability distribution. 16 | 17 | - Harmonic Loss achieves (1) **nonlinear separability**, (2) **fast convergence**, (3) **scale invariance**, (4) **interpretability by design**, properties that are not available in cross-entropy loss. 18 | 19 | 20 | ## Reproducing results 21 | 22 | Download the results from the following link: [Link](https://www.dropbox.com/scl/fi/9kj9aw1ymgsw0qya7sh8h/harmonic-data.zip?rlkey=6oc804x2r3ocmx3jidow4uqcp&st=e7i81esq&dl=0) 23 | 24 | Figure 1: ``toy_points.ipynb`` 25 | 26 | Figure 2,3,7: ``notebooks/final_figures.ipynb`` 27 | 28 | Figure 4. ``notebooks/case_study_circle.ipynb`` 29 | 30 | Figure 5. ``notebooks/mnist.ipynb`` 31 | 32 | Figure 6. 
``GPT2/function_vectors.ipynb`` 33 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: crystal 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - _libgcc_mutex=0.1=conda_forge 6 | - _openmp_mutex=4.5=2_gnu 7 | - asttokens=3.0.0=pyhd8ed1ab_1 8 | - bzip2=1.0.8=h4bc722e_7 9 | - ca-certificates=2024.12.14=hbcca054_0 10 | - comm=0.2.2=pyhd8ed1ab_1 11 | - debugpy=1.8.12=py312h2ec8cdc_0 12 | - decorator=5.1.1=pyhd8ed1ab_1 13 | - exceptiongroup=1.2.2=pyhd8ed1ab_1 14 | - executing=2.1.0=pyhd8ed1ab_1 15 | - importlib-metadata=8.6.1=pyha770c72_0 16 | - ipykernel=6.29.5=pyh3099207_0 17 | - ipython=8.31.0=pyh707e725_0 18 | - jedi=0.19.2=pyhd8ed1ab_1 19 | - jupyter_client=8.6.3=pyhd8ed1ab_1 20 | - jupyter_core=5.7.2=pyh31011fe_1 21 | - keyutils=1.6.1=h166bdaf_0 22 | - krb5=1.21.3=h659f571_0 23 | - ld_impl_linux-64=2.43=h712a8e2_2 24 | - libedit=3.1.20240808=pl5321h7949ede_0 25 | - libexpat=2.6.4=h5888daf_0 26 | - libffi=3.4.2=h7f98852_5 27 | - libgcc=14.2.0=h77fa898_1 28 | - libgcc-ng=14.2.0=h69a702a_1 29 | - libgomp=14.2.0=h77fa898_1 30 | - liblzma=5.6.3=hb9d3cd8_1 31 | - libnsl=2.0.1=hd590300_0 32 | - libsodium=1.0.20=h4ab18f5_0 33 | - libsqlite=3.48.0=hee588c1_0 34 | - libstdcxx=14.2.0=hc0a3c3a_1 35 | - libstdcxx-ng=14.2.0=h4852527_1 36 | - libuuid=2.38.1=h0b41bf4_0 37 | - libxcrypt=4.4.36=hd590300_1 38 | - libzlib=1.3.1=hb9d3cd8_2 39 | - matplotlib-inline=0.1.7=pyhd8ed1ab_1 40 | - ncurses=6.5=h2d0b736_2 41 | - nest-asyncio=1.6.0=pyhd8ed1ab_1 42 | - openssl=3.4.0=h7b32b05_1 43 | - packaging=24.2=pyhd8ed1ab_2 44 | - parso=0.8.4=pyhd8ed1ab_1 45 | - pexpect=4.9.0=pyhd8ed1ab_1 46 | - pickleshare=0.7.5=pyhd8ed1ab_1004 47 | - pip=24.3.1=pyh8b19718_2 48 | - platformdirs=4.3.6=pyhd8ed1ab_1 49 | - prompt-toolkit=3.0.50=pyha770c72_0 50 | - psutil=6.1.1=py312h66e93f0_0 51 | - ptyprocess=0.7.0=pyhd8ed1ab_1 52 | - pure_eval=0.2.3=pyhd8ed1ab_1 53 
| - pygments=2.19.1=pyhd8ed1ab_0 54 | - python=3.12.8=h9e4cc4f_1_cpython 55 | - python-dateutil=2.9.0.post0=pyhff2d567_1 56 | - python_abi=3.12=5_cp312 57 | - pyzmq=26.2.0=py312hbf22597_3 58 | - readline=8.2=h8228510_1 59 | - setuptools=75.8.0=pyhff2d567_0 60 | - six=1.17.0=pyhd8ed1ab_0 61 | - stack_data=0.6.3=pyhd8ed1ab_1 62 | - tk=8.6.13=noxft_h4845f30_101 63 | - tornado=6.4.2=py312h66e93f0_0 64 | - traitlets=5.14.3=pyhd8ed1ab_1 65 | - typing_extensions=4.12.2=pyha770c72_1 66 | - wcwidth=0.2.13=pyhd8ed1ab_1 67 | - wheel=0.45.1=pyhd8ed1ab_1 68 | - zeromq=4.3.5=h3b0a872_7 69 | - zipp=3.21.0=pyhd8ed1ab_1 70 | - pip: 71 | - adjusttext==1.3.0 72 | - contourpy==1.3.1 73 | - cycler==0.12.1 74 | - filelock==3.16.1 75 | - fonttools==4.55.3 76 | - fsspec==2024.12.0 77 | - h5py==3.12.1 78 | - jinja2==3.1.5 79 | - joblib==1.4.2 80 | - kiwisolver==1.4.8 81 | - markupsafe==3.0.2 82 | - matplotlib==3.10.0 83 | - mpmath==1.3.0 84 | - networkx==3.4.2 85 | - numpy==2.2.2 86 | - nvidia-cublas-cu12==12.4.5.8 87 | - nvidia-cuda-cupti-cu12==12.4.127 88 | - nvidia-cuda-nvrtc-cu12==12.4.127 89 | - nvidia-cuda-runtime-cu12==12.4.127 90 | - nvidia-cudnn-cu12==9.1.0.70 91 | - nvidia-cufft-cu12==11.2.1.3 92 | - nvidia-curand-cu12==10.3.5.147 93 | - nvidia-cusolver-cu12==11.6.1.9 94 | - nvidia-cusparse-cu12==12.3.1.170 95 | - nvidia-nccl-cu12==2.21.5 96 | - nvidia-nvjitlink-cu12==12.4.127 97 | - nvidia-nvtx-cu12==12.4.127 98 | - pandas==2.2.3 99 | - pillow==11.1.0 100 | - pyparsing==3.2.1 101 | - pytz==2024.2 102 | - scikit-learn==1.6.1 103 | - scipy==1.15.1 104 | - sympy==1.13.1 105 | - threadpoolctl==3.5.0 106 | - torch==2.5.1 107 | - tqdm==4.67.1 108 | - triton==3.1.0 109 | - tzdata==2024.2 110 | prefix: /om/user/dbaek/.conda/envs/crystal 111 | -------------------------------------------------------------------------------- /figures/circle_case_study.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/figures/circle_case_study.pdf -------------------------------------------------------------------------------- /figures/data_eff_plot.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/figures/data_eff_plot.pdf -------------------------------------------------------------------------------- /figures/eq_harmax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/figures/eq_harmax.png -------------------------------------------------------------------------------- /figures/ev_plot.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/figures/ev_plot.pdf -------------------------------------------------------------------------------- /figures/grokking_plot.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/figures/grokking_plot.pdf -------------------------------------------------------------------------------- /figures/mnist_harmonic_weights.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/figures/mnist_harmonic_weights.pdf -------------------------------------------------------------------------------- /figures/mnist_standard_weights.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/figures/mnist_standard_weights.pdf -------------------------------------------------------------------------------- /figures/modadd_weights_evolution.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/figures/modadd_weights_evolution.gif -------------------------------------------------------------------------------- /figures/rep_plots.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/figures/rep_plots.pdf -------------------------------------------------------------------------------- /figures/rep_plots_appendix.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/figures/rep_plots_appendix.pdf -------------------------------------------------------------------------------- /figures/weights_evolution.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/figures/weights_evolution.gif -------------------------------------------------------------------------------- /notebooks/mnist_video.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "sys.path.append(\"../\") \n", 11 | "\n", 12 | "from src.utils.driver import set_seed\n", 13 | "\n", 14 | "set_seed(57)" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": 
[ 21 | "### Model and Dataset" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import torch\n", 31 | "import torch.nn as nn\n", 32 | "import torch.optim as optim\n", 33 | "from torchvision import datasets, transforms\n", 34 | "from torch.utils.data import DataLoader\n", 35 | "\n", 36 | "import matplotlib.pyplot as plt\n", 37 | "import numpy as np\n", 38 | "import os\n", 39 | "import imageio\n", 40 | "\n", 41 | "from src.utils.model import DistLayer\n", 42 | "\n", 43 | "# Define the model class\n", 44 | "class SimpleNN(nn.Module):\n", 45 | " def __init__(self, harmonic=False):\n", 46 | " super(SimpleNN, self).__init__()\n", 47 | " self.harmonic = harmonic\n", 48 | " if harmonic:\n", 49 | " self.fc1 = DistLayer(28 * 28, 10, n=1.)\n", 50 | " else:\n", 51 | " self.fc1 = nn.Linear(28 * 28, 10)\n", 52 | " nn.init.normal_(self.fc1.weight, mean=0, std=1/28.)\n", 53 | "\n", 54 | " def forward(self, x):\n", 55 | " x = x.view(-1, 28 * 28) # Flatten the input\n", 56 | " x = self.fc1(x)\n", 57 | " if self.harmonic:\n", 58 | " prob = x/torch.sum(x, dim=1, keepdim=True)\n", 59 | " logits = (-1)*torch.log(prob)\n", 60 | " return logits\n", 61 | " return x\n", 62 | "\n", 63 | "# Set device\n", 64 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 65 | "\n", 66 | "# Hyperparameters\n", 67 | "batch_size = 64\n", 68 | "learning_rate = 0.001\n", 69 | "max_epochs = 100\n", 70 | "\n", 71 | "# Load MNIST dataset\n", 72 | "transform = transforms.Compose([\n", 73 | " transforms.ToTensor()\n", 74 | "])\n", 75 | "\n", 76 | "train_dataset = datasets.MNIST(root=\"./data\", train=True, transform=transform, download=True)\n", 77 | "test_dataset = datasets.MNIST(root=\"./data\", train=False, transform=transform, download=True)\n", 78 | "\n", 79 | "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n", 80 | "test_loader = DataLoader(test_dataset, 
batch_size=batch_size, shuffle=False)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 3, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 3, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "import torch\n", 97 | "import torch.nn as nn\n", 98 | "import torch.optim as optim\n", 99 | "import matplotlib.pyplot as plt\n", 100 | "import numpy as np\n", 101 | "import os\n", 102 | "import imageio\n", 103 | "\n", 104 | "def save_weight_visualization_gif(model, train_loader, test_loader, \n", 105 | " max_epochs=100, \n", 106 | " learning_rate=0.001, \n", 107 | " device='cuda', \n", 108 | " output_dir='../results/mnist_vis',\n", 109 | " save_prefix='',\n", 110 | " selected_classes=[3, 5, 7, 9]):\n", 111 | " \"\"\"\n", 112 | " Train the model and save weight visualizations as a GIF\n", 113 | " \n", 114 | " Args:\n", 115 | " model (nn.Module): Neural network model\n", 116 | " train_loader (DataLoader): Training data loader\n", 117 | " test_loader (DataLoader): Test data loader\n", 118 | " max_epochs (int): Maximum number of training epochs\n", 119 | " learning_rate (float): Learning rate for optimizer\n", 120 | " device (str): Training device (cuda/cpu)\n", 121 | " output_dir (str): Directory to save visualizations\n", 122 | " selected_classes (list): Classes to visualize\n", 123 | " \"\"\"\n", 124 | " # Ensure output directory exists\n", 125 | " os.makedirs(output_dir, exist_ok=True)\n", 126 | " \n", 127 | " # Move model to device\n", 128 | " model = model.to(device)\n", 129 | " \n", 130 | " # Optimizer and loss\n", 131 | " optimizer = optim.AdamW(model.parameters(), lr=learning_rate)\n", 132 | " \n", 133 | " # Visualization storage\n", 134 | " weight_images = []\n", 135 | " \n", 136 | " # Training loop\n", 137 | " for epoch in [1]:\n", 138 | " # Training phase\n", 139 | " model.train()\n", 140 | " running_loss = 0.0\n", 141 | " \n", 142 | " for batch_idx, (data, 
targets) in enumerate(train_loader):\n", 143 | " data, targets = data.to(device), targets.to(device)\n", 144 | " \n", 145 | " # Forward pass\n", 146 | " outputs = model(data)\n", 147 | " loss = outputs[range(targets.size(0)), targets].mean()\n", 148 | " \n", 149 | " # Backward pass\n", 150 | " optimizer.zero_grad()\n", 151 | " loss.backward()\n", 152 | " optimizer.step()\n", 153 | " \n", 154 | " running_loss += loss.item()\n", 155 | " \n", 156 | " # Periodically save weight visualization\n", 157 | " if (batch_idx+1) % 15 == 0:\n", 158 | " \n", 159 | " plt.figure(figsize=(16, 12))\n", 160 | " plt.suptitle(f'Weight Visualization - Epoch {epoch}')\n", 161 | " \n", 162 | " for i, cls in enumerate(selected_classes, 1):\n", 163 | " # Extract weights for specific class\n", 164 | " weights = model.fc1.weight.detach().cpu().numpy()[cls].reshape(28, 28)\n", 165 | " weights = np.where(weights < 0.01, 1, 0)\n", 166 | " \n", 167 | " plt.subplot(2, 2, i)\n", 168 | " plt.title(f'Class {cls}')\n", 169 | " plt.imshow(weights, cmap='viridis')\n", 170 | "# plt.colorbar()\n", 171 | " plt.axis('off')\n", 172 | " \n", 173 | " plt.tight_layout()\n", 174 | " \n", 175 | " # Save plot to a temporary file\n", 176 | " temp_plot_path = os.path.join(output_dir, f'{save_prefix}_mnist_{(batch_idx+1)}.png')\n", 177 | " torch.save(model.state_dict(), temp_plot_path.replace('.png', '.pt'))\n", 178 | " plt.savefig(temp_plot_path)\n", 179 | " plt.close()\n", 180 | " \n", 181 | " # Read the image and append to list\n", 182 | " weight_images.append(imageio.imread(temp_plot_path))\n", 183 | " print(batch_idx)\n", 184 | " if batch_idx > 900:\n", 185 | " break\n", 186 | " \n", 187 | " # Save as GIF\n", 188 | "# imageio.mimsave('../figures/mnist_weights_evolution.gif', weight_images, duration=0.5)\n", 189 | " \n", 190 | " # Evaluation\n", 191 | " model.eval()\n", 192 | " correct = 0\n", 193 | " with torch.no_grad():\n", 194 | " for data, targets in test_loader:\n", 195 | " data, targets = data.to(device), 
targets.to(device)\n", 196 | " outputs = (-1)*model(data)\n", 197 | " _, predicted = torch.max(outputs, 1)\n", 198 | " correct += (predicted == targets).sum().item()\n", 199 | " \n", 200 | " accuracy = correct / len(test_loader.dataset) * 100\n", 201 | " print(f\"Test Accuracy: {accuracy:.2f}%\")\n", 202 | "\n" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 4, 208 | "metadata": {}, 209 | "outputs": [ 210 | { 211 | "name": "stderr", 212 | "output_type": "stream", 213 | "text": [ 214 | "/tmp/ipykernel_777127/1809309761.py:87: DeprecationWarning: Starting with ImageIO v3 the behavior of this function will switch to that of iio.v3.imread. To keep the current behavior (and make this warning disappear) use `import imageio.v2 as imageio` or call `imageio.v2.imread` directly.\n", 215 | " weight_images.append(imageio.imread(temp_plot_path))\n" 216 | ] 217 | }, 218 | { 219 | "name": "stdout", 220 | "output_type": "stream", 221 | "text": [ 222 | "14\n", 223 | "29\n", 224 | "44\n", 225 | "59\n", 226 | "74\n", 227 | "89\n", 228 | "104\n", 229 | "119\n", 230 | "134\n", 231 | "149\n", 232 | "164\n", 233 | "179\n", 234 | "194\n", 235 | "209\n", 236 | "224\n", 237 | "239\n", 238 | "254\n", 239 | "269\n", 240 | "284\n", 241 | "299\n", 242 | "314\n", 243 | "329\n", 244 | "344\n", 245 | "359\n", 246 | "374\n", 247 | "389\n", 248 | "404\n", 249 | "419\n", 250 | "434\n", 251 | "449\n", 252 | "464\n", 253 | "479\n", 254 | "494\n", 255 | "509\n", 256 | "524\n", 257 | "539\n", 258 | "554\n", 259 | "569\n", 260 | "584\n", 261 | "599\n", 262 | "614\n", 263 | "629\n", 264 | "644\n", 265 | "659\n", 266 | "674\n", 267 | "689\n", 268 | "704\n", 269 | "719\n", 270 | "734\n", 271 | "749\n", 272 | "764\n", 273 | "779\n", 274 | "794\n", 275 | "809\n", 276 | "824\n", 277 | "839\n", 278 | "854\n", 279 | "869\n", 280 | "884\n", 281 | "899\n", 282 | "914\n", 283 | "Test Accuracy: 75.66%\n" 284 | ] 285 | } 286 | ], 287 | "source": [ 288 | "model = 
SimpleNN(harmonic=True).to(device)\n", 289 | "save_weight_visualization_gif(model, train_loader, test_loader, save_prefix='harmonic')" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 5, 295 | "metadata": {}, 296 | "outputs": [ 297 | { 298 | "name": "stderr", 299 | "output_type": "stream", 300 | "text": [ 301 | "/tmp/ipykernel_777127/1809309761.py:87: DeprecationWarning: Starting with ImageIO v3 the behavior of this function will switch to that of iio.v3.imread. To keep the current behavior (and make this warning disappear) use `import imageio.v2 as imageio` or call `imageio.v2.imread` directly.\n", 302 | " weight_images.append(imageio.imread(temp_plot_path))\n" 303 | ] 304 | }, 305 | { 306 | "name": "stdout", 307 | "output_type": "stream", 308 | "text": [ 309 | "14\n", 310 | "29\n", 311 | "44\n", 312 | "59\n", 313 | "74\n", 314 | "89\n", 315 | "104\n", 316 | "119\n", 317 | "134\n", 318 | "149\n", 319 | "164\n", 320 | "179\n", 321 | "194\n", 322 | "209\n", 323 | "224\n", 324 | "239\n", 325 | "254\n", 326 | "269\n", 327 | "284\n", 328 | "299\n", 329 | "314\n", 330 | "329\n", 331 | "344\n", 332 | "359\n", 333 | "374\n", 334 | "389\n", 335 | "404\n", 336 | "419\n", 337 | "434\n", 338 | "449\n", 339 | "464\n", 340 | "479\n", 341 | "494\n", 342 | "509\n", 343 | "524\n", 344 | "539\n", 345 | "554\n", 346 | "569\n", 347 | "584\n", 348 | "599\n", 349 | "614\n", 350 | "629\n", 351 | "644\n", 352 | "659\n", 353 | "674\n", 354 | "689\n", 355 | "704\n", 356 | "719\n", 357 | "734\n", 358 | "749\n", 359 | "764\n", 360 | "779\n", 361 | "794\n", 362 | "809\n", 363 | "824\n", 364 | "839\n", 365 | "854\n", 366 | "869\n", 367 | "884\n", 368 | "899\n", 369 | "914\n", 370 | "Test Accuracy: 72.01%\n" 371 | ] 372 | } 373 | ], 374 | "source": [ 375 | "model = SimpleNN(harmonic=False).to(device)\n", 376 | "save_weight_visualization_gif(model, train_loader, test_loader, save_prefix='standard')" 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | 
"execution_count": null, 382 | "metadata": {}, 383 | "outputs": [], 384 | "source": [] 385 | } 386 | ], 387 | "metadata": { 388 | "kernelspec": { 389 | "display_name": "crystal", 390 | "language": "python", 391 | "name": "python3" 392 | }, 393 | "language_info": { 394 | "codemirror_mode": { 395 | "name": "ipython", 396 | "version": 3 397 | }, 398 | "file_extension": ".py", 399 | "mimetype": "text/x-python", 400 | "name": "python", 401 | "nbconvert_exporter": "python", 402 | "pygments_lexer": "ipython3", 403 | "version": "3.12.8" 404 | } 405 | }, 406 | "nbformat": 4, 407 | "nbformat_minor": 2 408 | } 409 | -------------------------------------------------------------------------------- /perm_figs/final_figs/H_MLP_ev_0.5912.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/final_figs/H_MLP_ev_0.5912.png -------------------------------------------------------------------------------- /perm_figs/final_figs/H_transformer_ev_0.4036.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/final_figs/H_transformer_ev_0.4036.png -------------------------------------------------------------------------------- /perm_figs/final_figs/fig_eff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/final_figs/fig_eff.png -------------------------------------------------------------------------------- /perm_figs/final_figs/fig_fvu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/final_figs/fig_fvu.png 
-------------------------------------------------------------------------------- /perm_figs/final_figs/fig_grok_mlp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/final_figs/fig_grok_mlp.png -------------------------------------------------------------------------------- /perm_figs/final_figs/fig_grok_transformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/final_figs/fig_grok_transformer.png -------------------------------------------------------------------------------- /perm_figs/final_figs/standard_MLP_ev_0.3940.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/final_figs/standard_MLP_ev_0.3940.png -------------------------------------------------------------------------------- /perm_figs/final_figs/standard_transformer_ev_0.2944.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/final_figs/standard_transformer_ev_0.2944.png -------------------------------------------------------------------------------- /perm_figs/n=1/0_H_MLP_ev_33.4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=1/0_H_MLP_ev_33.4.png -------------------------------------------------------------------------------- /perm_figs/n=1/105_H_MLP_ev_38.7.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=1/105_H_MLP_ev_38.7.png -------------------------------------------------------------------------------- /perm_figs/n=1/157_H_MLP_ev_32.7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=1/157_H_MLP_ev_32.7.png -------------------------------------------------------------------------------- /perm_figs/n=1/210_H_MLP_ev_43.2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=1/210_H_MLP_ev_43.2.png -------------------------------------------------------------------------------- /perm_figs/n=1/52_H_MLP_ev_27.3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=1/52_H_MLP_ev_27.3.png -------------------------------------------------------------------------------- /perm_figs/n=1_high_weight_decay/emb_H_MLP_ev_0.5169.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=1_high_weight_decay/emb_H_MLP_ev_0.5169.png -------------------------------------------------------------------------------- /perm_figs/n=1_high_weight_decay/emb_H_transformer_ev_0.3840.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=1_high_weight_decay/emb_H_transformer_ev_0.3840.png 
-------------------------------------------------------------------------------- /perm_figs/n=1_high_weight_decay/emb_standard_MLP_ev_0.3582.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=1_high_weight_decay/emb_standard_MLP_ev_0.3582.png -------------------------------------------------------------------------------- /perm_figs/n=1_high_weight_decay/emb_standard_transformer_ev_0.3366.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=1_high_weight_decay/emb_standard_transformer_ev_0.3366.png -------------------------------------------------------------------------------- /perm_figs/n=1_high_weight_decay/fig_eff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=1_high_weight_decay/fig_eff.png -------------------------------------------------------------------------------- /perm_figs/n=1_high_weight_decay/fig_fvu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=1_high_weight_decay/fig_fvu.png -------------------------------------------------------------------------------- /perm_figs/n=1_high_weight_decay/fig_grok_mlp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=1_high_weight_decay/fig_grok_mlp.png -------------------------------------------------------------------------------- 
/perm_figs/n=1_high_weight_decay/fig_grok_transformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=1_high_weight_decay/fig_grok_transformer.png -------------------------------------------------------------------------------- /perm_figs/n=embd_dim/H_MLP_ev_34.4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=embd_dim/H_MLP_ev_34.4.png -------------------------------------------------------------------------------- /perm_figs/n=embd_dim/H_transformer_ev_33.4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=embd_dim/H_transformer_ev_33.4.png -------------------------------------------------------------------------------- /perm_figs/n=embd_dim/fig_eff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=embd_dim/fig_eff.png -------------------------------------------------------------------------------- /perm_figs/n=embd_dim/fig_fvu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=embd_dim/fig_fvu.png -------------------------------------------------------------------------------- /perm_figs/n=embd_dim/fig_grok_mlp.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=embd_dim/fig_grok_mlp.png -------------------------------------------------------------------------------- /perm_figs/n=embd_dim/fig_grok_transformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=embd_dim/fig_grok_transformer.png -------------------------------------------------------------------------------- /perm_figs/n=embd_dim/standard_MLP_ev_39.0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=embd_dim/standard_MLP_ev_39.0.png -------------------------------------------------------------------------------- /perm_figs/n=embd_dim/standard_transformer_ev_28.7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=embd_dim/standard_transformer_ev_28.7.png -------------------------------------------------------------------------------- /perm_figs/n=sqrt_embd_dim/H_MLP_ev_59.1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=sqrt_embd_dim/H_MLP_ev_59.1.png -------------------------------------------------------------------------------- /perm_figs/n=sqrt_embd_dim/H_transformer_ev_41.1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=sqrt_embd_dim/H_transformer_ev_41.1.png 
-------------------------------------------------------------------------------- /perm_figs/n=sqrt_embd_dim/fig_eff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=sqrt_embd_dim/fig_eff.png -------------------------------------------------------------------------------- /perm_figs/n=sqrt_embd_dim/fig_fvu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=sqrt_embd_dim/fig_fvu.png -------------------------------------------------------------------------------- /perm_figs/n=sqrt_embd_dim/fig_grok_mlp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=sqrt_embd_dim/fig_grok_mlp.png -------------------------------------------------------------------------------- /perm_figs/n=sqrt_embd_dim/fig_grok_transformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=sqrt_embd_dim/fig_grok_transformer.png -------------------------------------------------------------------------------- /perm_figs/n=sqrt_embd_dim/standard_MLP_ev_39.1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=sqrt_embd_dim/standard_MLP_ev_39.1.png -------------------------------------------------------------------------------- /perm_figs/n=sqrt_embd_dim/standard_transformer_ev_27.2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KindXiaoming/grow-crystals/0f2e6cd03d3211b5c93e325b6c2083afb35083c2/perm_figs/n=sqrt_embd_dim/standard_transformer_ev_27.2.png -------------------------------------------------------------------------------- /scripts/HM_circle.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -t 16:00:00 3 | #SBATCH --gres=gpu:a100:1 4 | 5 | python ../src/run_exp.py --data_id circle --model_id H_MLP 6 | 7 | -------------------------------------------------------------------------------- /scripts/HM_equiv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -t 23:59:00 3 | #SBATCH --gres=gpu:a100:1 4 | 5 | python ../src/run_exp.py --data_id equivalence --model_id H_MLP 6 | python ../src/run_exp.py --data_id circle --model_id H_MLP 7 | -------------------------------------------------------------------------------- /scripts/HM_family.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -t 16:00:00 3 | #SBATCH --gres=gpu:a100:1 4 | 5 | python ../src/run_exp.py --data_id family_tree --model_id H_MLP 6 | 7 | -------------------------------------------------------------------------------- /scripts/HM_lattice.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -t 23:59:00 3 | #SBATCH --gres=gpu:a100:1 4 | 5 | python ../src/run_exp.py --data_id lattice --model_id H_MLP 6 | python ../src/run_exp.py --data_id family_tree --model_id H_MLP 7 | python ../src/run_exp.py --data_id equivalence --model_id H_MLP 8 | python ../src/run_exp.py --data_id circle --model_id H_MLP 9 | -------------------------------------------------------------------------------- /scripts/HM_permutation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -t 23:59:00 3 | #SBATCH 
--gres=gpu:a100:1 4 | 5 | python ../src/run_exp.py --data_id permutation --model_id H_MLP 6 | -------------------------------------------------------------------------------- /scripts/HT_circle.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -t 16:00:00 3 | #SBATCH --gres=gpu:a100:1 4 | 5 | python ../src/run_exp.py --data_id circle --model_id H_transformer 6 | 7 | -------------------------------------------------------------------------------- /scripts/HT_equiv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -t 16:00:00 3 | #SBATCH --gres=gpu:a100:1 4 | 5 | python ../src/run_exp.py --data_id equivalence --model_id H_transformer 6 | 7 | -------------------------------------------------------------------------------- /scripts/HT_family.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -t 16:00:00 3 | #SBATCH --gres=gpu:a100:1 4 | 5 | python ../src/run_exp.py --data_id family_tree --model_id H_transformer 6 | 7 | -------------------------------------------------------------------------------- /scripts/HT_lattice.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -t 16:00:00 3 | #SBATCH --gres=gpu:a100:1 4 | 5 | python ../src/run_exp.py --data_id lattice --model_id H_transformer 6 | 7 | -------------------------------------------------------------------------------- /scripts/HT_permutation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -t 16:00:00 3 | #SBATCH --gres=gpu:a100:1 4 | 5 | python ../src/run_exp.py --data_id permutation --model_id H_transformer 6 | 7 | -------------------------------------------------------------------------------- /scripts/M_circle.sh: -------------------------------------------------------------------------------- 1 | 
#!/bin/bash
#SBATCH -t 16:00:00
#SBATCH --gres=gpu:a100:1

python ../src/run_exp.py --data_id circle --model_id standard_MLP

# ------------------------------------------------------------ scripts/M_equiv.sh
#!/bin/bash
#SBATCH -t 23:59:00
#SBATCH --gres=gpu:a100:1

python ../src/run_exp.py --data_id equivalence --model_id standard_MLP
python ../src/run_exp.py --data_id circle --model_id standard_MLP

# ------------------------------------------------------------ scripts/M_family.sh
#!/bin/bash
#SBATCH -t 16:00:00
#SBATCH --gres=gpu:a100:1

python ../src/run_exp.py --data_id family_tree --model_id standard_MLP

# ------------------------------------------------------------ scripts/M_lattice.sh
#!/bin/bash
#SBATCH -t 16:00:00
#SBATCH --gres=gpu:a100:1

python ../src/run_exp.py --data_id lattice --model_id standard_MLP
python ../src/run_exp.py --data_id family_tree --model_id standard_MLP
python ../src/run_exp.py --data_id equivalence --model_id standard_MLP
python ../src/run_exp.py --data_id circle --model_id standard_MLP

# ------------------------------------------------------------ scripts/M_permutation.sh
#!/bin/bash
#SBATCH -t 16:00:00
#SBATCH --gres=gpu:a100:1

python ../src/run_exp.py --data_id permutation --model_id standard_MLP

# ------------------------------------------------------------ scripts/T_circle.sh
#!/bin/bash
#SBATCH -t 16:00:00
#SBATCH --gres=gpu:a100:1

python ../src/run_exp.py --data_id circle --model_id standard_transformer

# ------------------------------------------------------------ scripts/T_equiv.sh
#!/bin/bash
#SBATCH -t 16:00:00
#SBATCH --gres=gpu:a100:1

python ../src/run_exp.py --data_id equivalence --model_id standard_transformer

# ------------------------------------------------------------ scripts/T_family.sh
#!/bin/bash
#SBATCH -t 16:00:00
#SBATCH --gres=gpu:a100:1

python ../src/run_exp.py --data_id family_tree --model_id standard_transformer

# ------------------------------------------------------------ scripts/T_lattice.sh
#!/bin/bash
#SBATCH -t 16:00:00
#SBATCH --gres=gpu:a100:1

python ../src/run_exp.py --data_id lattice --model_id standard_transformer

# ------------------------------------------------------------ scripts/T_permutation.sh
#!/bin/bash
#SBATCH -t 16:00:00
#SBATCH --gres=gpu:a100:1

python ../src/run_exp.py --data_id permutation --model_id standard_transformer

# ------------------------------------------------------------ scripts/circle_run.sh
#!/bin/bash
#SBATCH -t 16:00:00
#SBATCH --gres=gpu:1
#SBATCH -n 16

python ../src/run_exp.py --data_id circle --model_id standard_MLP
python ../src/run_exp.py --data_id circle --model_id H_MLP
python ../src/run_exp.py --data_id circle --model_id standard_transformer
python ../src/run_exp.py --data_id circle --model_id H_transformer

# ------------------------------------------------------------ scripts/data_size_sweep.sh
#!/bin/bash
#SBATCH -t 2:00:00
#SBATCH --gres=gpu:1
#SBATCH -n 32

# Logarithmically spaced dataset sizes from 10 to 10000.
sizes=$(python3 -c "import numpy as np; print(' '.join(map(str, np.logspace(1, 4, num=10, dtype=int))))")

for size in $sizes
do
    python3 ../sweep_transformers.py --data_size $size --use_harmonic 0
done

# ------------------------------------------------------------ scripts/equiv_run.sh
#!/bin/bash
#SBATCH -t 16:00:00
#SBATCH --gres=gpu:1
#SBATCH -n 16

python ../src/run_exp.py --data_id equivalence --model_id standard_MLP
python ../src/run_exp.py --data_id equivalence --model_id H_MLP
python ../src/run_exp.py --data_id equivalence --model_id standard_transformer
python ../src/run_exp.py --data_id equivalence --model_id H_transformer

# ------------------------------------------------------------ scripts/family_tree_run.sh
#!/bin/bash
#SBATCH -t 16:00:00
#SBATCH --gres=gpu:1
#SBATCH -n 16

python ../src/run_exp.py --data_id family_tree --model_id standard_MLP
python ../src/run_exp.py --data_id family_tree --model_id H_MLP
python ../src/run_exp.py --data_id family_tree --model_id standard_transformer
python ../src/run_exp.py --data_id family_tree --model_id H_transformer

# ------------------------------------------------------------ scripts/greater_run.sh (continues below)
#!/bin/bash
# scripts/lattice.py -- hyper-parameter sweep over MLP variants on the 2-D
# lattice (parallelogram-completion) task.  One (task_id, num_tasks) pair
# selects this worker's round-robin slice of the full parameter grid.
import time
import os
import sys

sys.path.append('..')

from src.utils.model import *
from src.utils.dataset import *
import numpy as np
from sklearn.decomposition import PCA


def run():
    """Train one model per assigned hyper-parameter combination and save
    accuracy / embedding summaries under ./results/lattice/.

    CLI: ``python lattice.py <task_id> <num_tasks>``.
    """
    my_task_id = int(sys.argv[1])
    num_tasks = int(sys.argv[2])

    # Hyper-parameter grid.
    model_modes = ["standard", "ip", "hs1", "hs2"]
    embd_dims = [1, 2, 3, 4, 5, 10, 20]
    data_nums = [10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000]
    lambs = [0., 0.1, 1., 10.]
    seeds = [0]

    xx, yy, zz, uu, vv = np.meshgrid(model_modes, embd_dims, data_nums, lambs, seeds)
    params_ = np.transpose(np.array([xx.reshape(-1,), yy.reshape(-1,),
                                     zz.reshape(-1,), uu.reshape(-1,), vv.reshape(-1,)]))

    indices = np.arange(params_.shape[0])
    # Round-robin: worker k takes configs k, k + num_tasks, k + 2*num_tasks, ...
    my_indices = indices[my_task_id:indices.shape[0]:num_tasks]

    for i in my_indices:

        steps = 10001

        # np.meshgrid over mixed types yields a string array; cast per column.
        model_mode = params_[i][0].astype('str')
        embd_dim = params_[i][1].astype('int')
        data_num = params_[i][2].astype('int')
        lamb = params_[i][3].astype('float')
        seed = params_[i][4].astype('int')

        np.random.seed(seed)
        torch.manual_seed(seed)
        # Double precision everywhere (replaces the deprecated
        # torch.set_default_tensor_type(torch.DoubleTensor)).
        torch.set_default_dtype(torch.float64)

        device = 'cpu'

        p = 10
        input_token = 3
        lattice_dim = 2
        vocab_size = p ** lattice_dim

        hidden_size = 100
        if model_mode == 'ip':
            # "Inner product" MLP: unembedding head with tied weights.
            shp = [input_token * embd_dim, hidden_size, embd_dim, vocab_size]
            model = MLP(shp=shp, vocab_size=vocab_size, embd_dim=embd_dim,
                        input_token=input_token, unembd=True, weight_tied=True,
                        seed=seed).to(device)
        elif model_mode == 'hs2':
            # Harmonic-similarity MLP without a hidden layer.
            shp = [input_token * embd_dim, embd_dim, vocab_size]
            model = MLP_HS(shp=shp, vocab_size=vocab_size, embd_dim=embd_dim,
                           input_token=input_token, weight_tied=True,
                           seed=seed).to(device)
        elif model_mode == 'hs1':
            # Harmonic-similarity MLP with one hidden layer.
            shp = [input_token * embd_dim, hidden_size, embd_dim, vocab_size]
            model = MLP_HS(shp=shp, vocab_size=vocab_size, embd_dim=embd_dim,
                           input_token=input_token, weight_tied=True,
                           seed=seed).to(device)
        elif model_mode == 'standard':
            shp = [input_token * embd_dim, hidden_size, vocab_size]
            model = MLP(shp=shp, vocab_size=vocab_size, embd_dim=embd_dim,
                        input_token=input_token, unembd=False, weight_tied=False,
                        seed=seed).to(device)
        else:
            # BUG FIX: previously only printed a message and then crashed
            # later with NameError on the undefined `model`.
            raise ValueError('model_mode not recognized: %s' % model_mode)

        # Data: parallelogram completion on a p x p lattice.
        dataset = parallelogram_dataset(p=p, dim=lattice_dim, num=data_num, seed=seed)
        dataset = repeat_dataset(dataset)

        ### train ###
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.0)
        log = 200

        train_losses, test_losses = [], []
        train_accs, test_accs = [], []
        embds = []

        for step in range(steps):

            optimizer.zero_grad()

            logits = model.pred_logit(dataset['train_data_id'])
            loss = torch.nn.functional.cross_entropy(logits, dataset['train_label'])

            # Penalty on the per-dimension RMS scale of the embedding matrix.
            embd_reg = torch.mean(torch.sqrt(torch.mean(model.embedding**2, dim=0)))
            total_loss = loss + lamb * embd_reg

            acc = torch.mean((torch.argmax(logits, dim=1) == dataset['train_label']).float())
            train_losses.append(loss.item())
            train_accs.append(acc.item())

            logits_test = model.pred_logit(dataset['test_data_id'])
            loss_test = torch.nn.functional.cross_entropy(logits_test, dataset['test_label'])
            acc_test = torch.mean((torch.argmax(logits_test, dim=1) == dataset['test_label']).float())
            test_losses.append(loss_test.item())
            test_accs.append(acc_test.item())

            total_loss.backward()
            optimizer.step()

            if step % log == 0:
                print("step = %d | total loss: %.2e | train loss: %.2e | test loss %.2e | train acc: %.2e | test acc: %.2e "
                      % (step, total_loss.item(), loss.item(), loss_test.item(),
                         acc.item(), acc_test.item()))

            if step % 100 == 0:
                embds.append(model.embedding.cpu().detach().numpy())

        embd = model.embedding.cpu().detach().numpy()

        # BUG FIX: PCA was fit twice (fit followed by fit_transform) and the
        # transformed coordinates were never used; fit once.
        pca = PCA(n_components=embd_dim)
        pca.fit(embd)

        active_pca_dim = np.sum(pca.explained_variance_ratio_ > 1e-4)
        active_embd_dim = torch.sum(torch.mean(model.embedding**2, dim=0) > 1e-4).item()

        # Parallelogram test: does -a + b + c land nearest the target token?
        inputs = embd[dataset['train_data_id']]
        output = (- inputs[:, 0, :] + inputs[:, 1, :] + inputs[:, 2, :])

        # Squared distances via ||w||^2 + ||x||^2 - 2 w.x (renamed from `xx`
        # to avoid shadowing the meshgrid output above).
        out_sq = np.linalg.norm(output, axis=1)[:, None]**2
        ww = np.linalg.norm(embd, axis=1)[None, :]**2
        wx = output @ embd.T
        distsq = ww + out_sq - 2 * wx
        parallelogram_acc = np.mean(np.argmin(distsq, axis=1)
                                    == dataset['train_label'].cpu().detach().numpy())

        # BUG FIX: np.savetxt raises if the output directory is missing.
        out_dir = './results/lattice'
        os.makedirs(out_dir, exist_ok=True)

        tag = 'model_%s_embddim_%d_data_%d_lamb_%.2f_seed_%d_p_10' % (
            model_mode, embd_dim, data_num, lamb, seed)
        # train_acc, test_acc, parallelogram_acc, active_pca_dim, active_embd_dim
        np.savetxt('%s/%s_performance.txt' % (out_dir, tag),
                   [train_accs[-1], test_accs[-1], parallelogram_acc,
                    active_pca_dim, active_embd_dim])
        np.savetxt('%s/%s_embedding.txt' % (out_dir, tag), embd)


run()
# scripts/modadd.py -- hyper-parameter sweep over MLP variants on modular
# addition (p = 59).  One (task_id, num_tasks) pair selects this worker's
# round-robin slice of the full parameter grid.
import time
import os
import sys

sys.path.append('..')

from src.utils.model import *
from src.utils.dataset import *
import numpy as np
from sklearn.decomposition import PCA
import math


def run():
    """Train one model per assigned (model_mode, embd_dim, lamb, seed) combo
    and save accuracy / embedding summaries under ./results/modadd_largeinit/.

    CLI: ``python modadd.py <task_id> <num_tasks>``.
    """
    my_task_id = int(sys.argv[1])
    num_tasks = int(sys.argv[2])

    # Hyper-parameter grid.
    model_modes = ["standard", "ip", "hs1", "hs2", "hs3"]
    embd_dims = [1, 2, 3, 4, 5, 10, 20, 50]
    lambs = [0., 0.1, 0.3, 1.]
    seeds = [0]

    xx, yy, zz, uu = np.meshgrid(model_modes, embd_dims, lambs, seeds)
    params_ = np.transpose(np.array([xx.reshape(-1,), yy.reshape(-1,),
                                     zz.reshape(-1,), uu.reshape(-1,)]))

    indices = np.arange(params_.shape[0])
    # Round-robin: worker k takes configs k, k + num_tasks, k + 2*num_tasks, ...
    my_indices = indices[my_task_id:indices.shape[0]:num_tasks]

    for i in my_indices:

        # np.meshgrid over mixed types yields a string array; cast per column.
        model_mode = params_[i][0].astype('str')
        print('model_mode:', model_mode)
        embd_dim = params_[i][1].astype('int')
        lamb = params_[i][2].astype('float')
        seed = params_[i][3].astype('int')

        np.random.seed(seed)
        torch.manual_seed(seed)
        # Double precision everywhere (replaces the deprecated
        # torch.set_default_tensor_type(torch.DoubleTensor)).
        torch.set_default_dtype(torch.float64)

        device = 'cpu'

        p = 59
        input_token = 2
        vocab_size = p

        hidden_size = 100
        if model_mode == 'ip':
            # "Inner product" MLP: unembedding head with tied weights.
            shp = [input_token * embd_dim, hidden_size, embd_dim, vocab_size]
            model = MLP(shp=shp, vocab_size=vocab_size, embd_dim=embd_dim,
                        input_token=input_token, unembd=True, weight_tied=True,
                        seed=seed).to(device)
        elif model_mode == 'hs3':
            # Harmonic-similarity MLP with exponent n = 1.
            shp = [input_token * embd_dim, hidden_size, embd_dim, vocab_size]
            model = MLP_HS(shp=shp, vocab_size=vocab_size, embd_dim=embd_dim,
                           input_token=input_token, weight_tied=True, seed=seed,
                           n=1).to(device)
        elif model_mode == 'hs2':
            # Harmonic-similarity MLP with exponent n = sqrt(embd_dim).
            shp = [input_token * embd_dim, hidden_size, embd_dim, vocab_size]
            model = MLP_HS(shp=shp, vocab_size=vocab_size, embd_dim=embd_dim,
                           input_token=input_token, weight_tied=True, seed=seed,
                           n=math.sqrt(embd_dim)).to(device)
        elif model_mode == 'hs1':
            # Harmonic-similarity MLP with exponent n = embd_dim.
            shp = [input_token * embd_dim, hidden_size, embd_dim, vocab_size]
            model = MLP_HS(shp=shp, vocab_size=vocab_size, embd_dim=embd_dim,
                           input_token=input_token, weight_tied=True, seed=seed,
                           n=embd_dim).to(device)
        elif model_mode == 'standard':
            shp = [input_token * embd_dim, hidden_size, vocab_size]
            model = MLP(shp=shp, vocab_size=vocab_size, embd_dim=embd_dim,
                        input_token=input_token, unembd=False, weight_tied=False,
                        seed=seed).to(device)
        else:
            # BUG FIX: previously only printed a message and then crashed
            # later with NameError on the undefined `model`.
            raise ValueError('model_mode not recognized: %s' % model_mode)

        # Data: modular addition mod p, 80/20 train/test split.
        dataset = modular_addition_dataset(p=p)
        dataset = split_dataset(dataset, train_ratio=0.8, seed=seed)

        ### train ###
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.0)
        # NOTE: the original assigned steps = 10001 early in the loop and then
        # silently overwrote it with 4001 here; 4001 is the value that took
        # effect, so the dead assignment was removed.
        steps = 4001
        log = 200

        train_losses, test_losses = [], []
        train_accs, test_accs = [], []
        embds = []

        for step in range(steps):

            optimizer.zero_grad()

            logits = model.pred_logit(dataset['train_data_id'])
            loss = torch.nn.functional.cross_entropy(logits, dataset['train_label'])

            # Penalty on the per-dimension RMS scale of the embedding matrix.
            embd_reg = torch.mean(torch.sqrt(torch.mean(model.embedding**2, dim=0)))
            total_loss = loss + lamb * embd_reg

            acc = torch.mean((torch.argmax(logits, dim=1) == dataset['train_label']).float())
            train_losses.append(loss.item())
            train_accs.append(acc.item())

            logits_test = model.pred_logit(dataset['test_data_id'])
            loss_test = torch.nn.functional.cross_entropy(logits_test, dataset['test_label'])
            acc_test = torch.mean((torch.argmax(logits_test, dim=1) == dataset['test_label']).float())
            test_losses.append(loss_test.item())
            test_accs.append(acc_test.item())

            total_loss.backward()
            optimizer.step()

            if step % log == 0:
                print("step = %d | total loss: %.2e | train loss: %.2e | test loss %.2e | train acc: %.2e | test acc: %.2e "
                      % (step, total_loss.item(), loss.item(), loss_test.item(),
                         acc.item(), acc_test.item()))

            if step % 100 == 0:
                embds.append(model.embedding.cpu().detach().numpy())

        embd = model.embedding.cpu().detach().numpy()

        # BUG FIX: PCA was fit twice (fit followed by fit_transform) and the
        # transformed coordinates were never used; fit once.
        pca = PCA(n_components=embd_dim)
        pca.fit(embd)

        active_pca_dim = np.sum(pca.explained_variance_ratio_ > 1e-4)
        active_embd_dim = torch.sum(torch.mean(model.embedding**2, dim=0) > 1e-4).item()

        # BUG FIX: np.savetxt raises if the output directory is missing.
        out_dir = './results/modadd_largeinit'
        os.makedirs(out_dir, exist_ok=True)

        tag = 'model_%s_embddim_%d_lamb_%.2f_seed_%d_p_59' % (
            model_mode, embd_dim, lamb, seed)
        # train_acc, test_acc, active_pca_dim, active_embd_dim
        np.savetxt('%s/%s_performance.txt' % (out_dir, tag),
                   [train_accs[-1], test_accs[-1], active_pca_dim, active_embd_dim])
        np.savetxt('%s/%s_embedding.txt' % (out_dir, tag), embd)


run()
PYTHONPATH=$(pwd) python src/run_exp.py --data_id "$data_id" --model_id "$model_id" --split "$split" --n "$n" > output_"$data_id"_"$n".txt 2>&1 14 | done 15 | done 16 | done 17 | done 18 | -------------------------------------------------------------------------------- /scripts/u_circle.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -t 16:00:00 3 | #SBATCH --gres=gpu:a100:1 4 | #SBATCH -n 16 5 | 6 | for ARG in $(python -c "import numpy as np; print(' '.join(map(str, np.linspace(51, 100, 20, dtype=int))))"); do 7 | echo "Running with seed $ARG:" 8 | python ../src/unit_exp.py --data_id circle --model_id H_MLP --seed $ARG 9 | echo 10 | done -------------------------------------------------------------------------------- /scripts/u_circle_new.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -t 16:00:00 3 | #SBATCH -p tegmark 4 | #SBATCH --gres=gpu:a100:1 5 | 6 | for ARG in $(python -c "import numpy as np; print(' '.join(map(str, np.linspace(0, 1000, 20, dtype=int))))"); do 7 | echo "Running with seed $ARG:" 8 | python ../src/unit_exp.py --data_id circle --model_id H_transformer --seed $ARG 9 | echo 10 | done -------------------------------------------------------------------------------- /scripts/u_equiv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -t 16:00:00 3 | #SBATCH --gres=gpu:1 4 | #SBATCH -n 16 5 | 6 | python ../src/unit_exp.py --data_id equivalence --model_id standard_transformer 7 | python ../src/unit_exp.py --data_id equivalence --model_id H_transformer 8 | python ../src/unit_exp.py --data_id equivalence --model_id standard_MLP 9 | python ../src/unit_exp.py --data_id equivalence --model_id H_MLP 10 | 11 | -------------------------------------------------------------------------------- /scripts/u_family.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -t 16:00:00 3 | #SBATCH --gres=gpu:a100:1 4 | #SBATCH -n 16 5 | 6 | for ARG in $(python -c "import numpy as np; print(' '.join(map(str, np.linspace(51, 100, 20, dtype=int))))"); do 7 | echo "Running with seed $ARG:" 8 | python ../src/unit_exp.py --data_id family_tree --model_id H_MLP --seed $ARG 9 | echo 10 | done -------------------------------------------------------------------------------- /scripts/u_family_new.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -t 16:00:00 3 | #SBATCH --gres=gpu:a100:1 4 | 5 | for ARG in $(python -c "import numpy as np; print(' '.join(map(str, np.linspace(0, 1000, 20, dtype=int))))"); do 6 | echo "Running with seed $ARG:" 7 | python ../src/unit_exp.py --data_id family_tree --model_id H_transformer --seed $ARG 8 | echo 9 | done -------------------------------------------------------------------------------- /scripts/u_greater.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -t 16:00:00 3 | #SBATCH --gres=gpu:1 4 | #SBATCH -n 16 5 | 6 | python ../src/unit_exp.py --data_id greater --model_id standard_transformer 7 | python ../src/unit_exp.py --data_id greater --model_id H_transformer 8 | python ../src/unit_exp.py --data_id greater --model_id standard_MLP 9 | python ../src/unit_exp.py --data_id greater --model_id H_MLP 10 | 11 | -------------------------------------------------------------------------------- /scripts/u_lattice.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -t 16:00:00 3 | #SBATCH --gres=gpu:1 4 | #SBATCH -n 16 5 | 6 | python ../src/unit_exp.py --data_id lattice --model_id standard_transformer 7 | python ../src/unit_exp.py --data_id lattice --model_id H_transformer 8 | python ../src/unit_exp.py --data_id lattice 
--model_id standard_MLP 9 | python ../src/unit_exp.py --data_id lattice --model_id H_MLP 10 | 11 | -------------------------------------------------------------------------------- /src/README.md: -------------------------------------------------------------------------------- 1 | ## How to add new dataset for experiments 2 | 3 | 1. Implement a function which returns the dataset dictionary in `utils/dataset.py`. 4 | 2. Choose a unique id for the new dataset. Implement a function which evaluates the quality of representation in `utils/crystal_metric.py`. Modify the function `crystal_metric` to support the new data_id. 5 | 3. Add the new data_id to the array `data_id_choices` in `run_exp.py`. 6 | 4. If any auxiliary information is required to evaluate the representations, add them to the dictionary `aux_info` in `run_exp.py`. Sometimes, these information may depend on the specific dataset; In such cases, make any necessary modifications within each of the three experiment for loops in `run_exp.py`. 7 | 5. Now, you're ready to test the new dataset! Command format is: 8 | `python run_exp.py --data_id new_data_id --model_id H_MLP`. 
-------------------------------------------------------------------------------- /src/run_exp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import numpy as np 5 | # import random 6 | # import optuna 7 | # import joblib 8 | 9 | from tqdm import tqdm 10 | 11 | import sys 12 | sys.path.append("..") 13 | 14 | import argparse 15 | from src.utils.driver import train_single_model 16 | from src.utils.visualization import visualize_embedding 17 | from src.utils.crystal_metric import crystal_metric 18 | import json 19 | 20 | import os 21 | from datetime import datetime 22 | 23 | data_id_choices = ["lattice", "greater", "family_tree", "equivalence", "circle", "permutation"] 24 | model_id_choices = ["H_MLP", "standard_MLP", "H_transformer", "standard_transformer"] 25 | split_choices = [1,2,3,4,5,6,7, 8] 26 | wd_choices = [0.0005, 0.001, 0.003, 0.005, 0.007, 0.01, 0.012, 0.015, 0.02, 0.03, 0.05, 0.07, 0.1] 27 | if __name__ == '__main__': 28 | parser = argparse.ArgumentParser(description='Experiment') 29 | parser.add_argument('--seed', type=int, default=66, help='random seed') 30 | parser.add_argument('--data_id', type=str, required=True, choices=data_id_choices, help='Data ID') 31 | parser.add_argument('--model_id', type=str, required=True, choices=model_id_choices, help='Model ID') 32 | parser.add_argument('--split', type=int, required=False, choices=split_choices, help='To split running experiments') 33 | parser.add_argument('--wd', type=float, required=False, choices=wd_choices, help='weight decay') 34 | parser.add_argument('--n', type=int, required=False, default=1, help='n exponent value') 35 | 36 | 37 | args = parser.parse_args() 38 | seed = args.seed 39 | data_id = args.data_id 40 | model_id = args.model_id 41 | split=args.split 42 | n_exp = args.n 43 | 44 | ## ------------------------ CONFIG -------------------------- ## 45 | 46 | data_size = 1000 47 | 
train_ratio = 0.8 48 | embd_dim = 16 49 | 50 | lr = 0.002 51 | weight_decay = 0.01 if "MLP" in model_id else 0.005 52 | 53 | param_dict = { 54 | 'seed': seed, 55 | 'data_id': data_id, 56 | 'data_size': data_size, 57 | 'train_ratio': train_ratio, 58 | 'model_id': model_id, 59 | 'device': torch.device('cuda:1' if torch.cuda.is_available() else 'cpu'), 60 | 'embd_dim': embd_dim, 61 | 'n_exp': n_exp, 62 | 'lr': lr, 63 | 'weight_decay':weight_decay, 64 | 'custom_loss': "softnn" 65 | } 66 | 67 | results_root = f"results_loss_exps/softnn" 68 | 69 | current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 70 | results_root = f"{results_root}/{current_datetime}-{seed}-{data_id}-{model_id}" 71 | os.makedirs(results_root, exist_ok=True) 72 | 73 | param_dict_json = {k: v for k, v in param_dict.items() if k != 'device'} # since torch.device is not JSON serializable 74 | 75 | 76 | with open(f"{results_root}/config.json", "w") as f: 77 | json.dump(param_dict_json, f, indent=4) 78 | 79 | aux_info = {} 80 | if data_id == "lattice": 81 | aux_info["lattice_size"] = 5 82 | elif data_id == "greater": 83 | aux_info["p"] = 30 84 | elif data_id == "equivalence": 85 | aux_info["mod"] = 5 86 | elif data_id == "circle": 87 | aux_info["p"] = 31 88 | elif data_id == "family_tree": 89 | aux_info["dict_level"] = 2 90 | elif data_id == "permutation": 91 | aux_info["p"] = 4 92 | else: 93 | raise ValueError(f"Unknown data_id: {data_id}") 94 | 95 | # # Optuna study for lr/wd 96 | # def loss_objective(trial): 97 | # weight_decay = trial.suggest_float('wd', 0, 0.01) 98 | # lr = trial.suggest_float('lr', 0.002, 0.005) 99 | 100 | # param_dict = { 101 | # 'seed': seed, 102 | # 'data_id': data_id, 103 | # 'data_size': data_size, 104 | # 'train_ratio': train_ratio, 105 | # 'model_id': model_id, 106 | # 'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'), 107 | # 'embd_dim': embd_dim, 108 | # 'n_exp': n_exp, 109 | # 'lr': lr, 110 | # 'weight_decay':weight_decay 111 | # } 112 | 
113 | # ret_dic = train_single_model(param_dict) 114 | 115 | # test_loss = np.mean(ret_dic["results"]["test_losses"][-10:]) 116 | 117 | # return test_loss 118 | 119 | # study = optuna.create_study() 120 | # study.optimize(loss_objective, n_trials = 15) 121 | # joblib.dump(study, "wd_lr_study.pkl") 122 | 123 | # print(study.best_params) 124 | 125 | # # Train the model 126 | # print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id}, n_exp {n_exp}, embd_dim {embd_dim}, weight decay {weight_decay}") 127 | # ret_dic = train_single_model(param_dict) 128 | 129 | # ## Exp1: Visualize Embeddings 130 | # print(f"Experiment 1: Visualize Embeddings") 131 | # model = ret_dic['model'] 132 | # dataset = ret_dic['dataset'] 133 | # torch.save(model.state_dict(), f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.pt") 134 | 135 | # if hasattr(model.embedding, 'weight'): 136 | # visualize_embedding(model.embedding.weight.cpu(), title=f"{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}", save_path=f"{results_root}/emb_{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.png", dict_level = dataset['dict_level'] if 'dict_level' in dataset else None, color_dict = False if data_id == "permutation" else True, adjust_overlapping_text = False) 137 | # else: 138 | # visualize_embedding(model.embedding.data.cpu(), title=f"{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}", save_path=f"{results_root}/emb_{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.png", dict_level = dataset['dict_level'] if 'dict_level' in dataset else None, color_dict = False if data_id == "permutation" else True, adjust_overlapping_text = False) 139 | 140 | 141 | # ## Exp2: Metric vs Overall Dataset Size (fixed train-test split) 142 | # print(f"Experiment 2: Metric vs Overall Dataset Size (fixed train-test split)") 143 | # data_size_list = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] 144 | # for i in 
tqdm(range(len(data_size_list))): 145 | # data_size = data_size_list[i] 146 | # param_dict = { 147 | # 'seed': seed, 148 | # 'data_id': data_id, 149 | # 'data_size': data_size, 150 | # 'train_ratio': train_ratio, 151 | # 'model_id': model_id, 152 | # 'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'), 153 | # 'embd_dim': embd_dim, 154 | # 'n_exp': n_exp, 155 | # 'lr': lr, 156 | # 'weight_decay':weight_decay 157 | # } 158 | 159 | # print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id}, n_exp {n_exp}, embd_dim {embd_dim}") 160 | # ret_dic = train_single_model(param_dict) 161 | # model = ret_dic['model'] 162 | # dataset = ret_dic['dataset'] 163 | 164 | # torch.save(model.state_dict(), f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.pt") 165 | # with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}_train_results.json", "w") as f: 166 | # json.dump(ret_dic["results"], f, indent=4) 167 | 168 | # if data_id == "family_tree": 169 | # aux_info["dict_level"] = dataset['dict_level'] 170 | 171 | # if hasattr(model.embedding, 'weight'): 172 | # metric_dict = crystal_metric(model.embedding.weight.cpu().detach(), data_id, aux_info) 173 | # else: 174 | # metric_dict = crystal_metric(model.embedding.data.cpu(), data_id, aux_info) 175 | 176 | # with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.json", "w") as f: 177 | # json.dump(metric_dict, f, indent=4) 178 | 179 | # # ## Exp3: Metric vs Train Fraction (fixed dataset size) 180 | # print(f"Experiment 3: Metric vs Train Fraction (fixed dataset size)") 181 | # train_ratio_list = [] 182 | # if split == 1: 183 | # train_ratio_list = np.arange(1, 5) / 10 184 | # if split == 2: 185 | # train_ratio_list = np.arange(5,10) / 10 186 | 187 | # data_size = 1000 188 | # for i in tqdm(range(len(train_ratio_list))): 189 | # train_ratio = train_ratio_list[i] 190 | # param_dict = { 191 | # 'seed': 
seed, 192 | # 'data_id': data_id, 193 | # 'data_size': data_size, 194 | # 'train_ratio': train_ratio, 195 | # 'model_id': model_id, 196 | # 'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'), 197 | # 'embd_dim': embd_dim, 198 | # 'n_exp': n_exp, 199 | # 'lr': lr, 200 | # 'weight_decay':weight_decay 201 | # } 202 | # print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id}, n_exp {n_exp}, embd_dim {embd_dim}") 203 | # ret_dic = train_single_model(param_dict) 204 | # model = ret_dic['model'] 205 | # dataset = ret_dic['dataset'] 206 | 207 | # torch.save(model.state_dict(), f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.pt") 208 | # with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}_train_results.json", "w") as f: 209 | # json.dump(ret_dic["results"], f, indent=4) 210 | # torch.save(model.state_dict(), f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.pt") 211 | # with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}_train_results.json", "w") as f: 212 | # json.dump(ret_dic["results"], f, indent=4) 213 | 214 | # if data_id == "family_tree": 215 | # aux_info["dict_level"] = dataset['dict_level'] 216 | # if data_id == "family_tree": 217 | # aux_info["dict_level"] = dataset['dict_level'] 218 | 219 | # if hasattr(model.embedding, 'weight'): 220 | # metric_dict = crystal_metric(model.embedding.weight.cpu().detach(), data_id, aux_info) 221 | # else: 222 | # metric_dict = crystal_metric(model.embedding.data.cpu(), data_id, aux_info) 223 | # if hasattr(model.embedding, 'weight'): 224 | # metric_dict = crystal_metric(model.embedding.weight.cpu().detach(), data_id, aux_info) 225 | # else: 226 | # metric_dict = crystal_metric(model.embedding.data.cpu(), data_id, aux_info) 227 | 228 | # with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}_metric.json", "w") as f: 229 | # 
json.dump(metric_dict, f, indent=4) 230 | # with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}_metric.json", "w") as f: 231 | # json.dump(metric_dict, f, indent=4) 232 | 233 | ## Exp4: Grokking plot: Run with different seeds 234 | print(f"Experiment 4: Train with different seeds") 235 | 236 | seed_list = [] 237 | if split == 3: 238 | seed_list = np.linspace(0, 1000, 20, dtype=int)[:4] 239 | if split == 4: 240 | seed_list = np.linspace(0, 1000, 20, dtype=int)[4:7] 241 | if split == 5: 242 | seed_list = np.linspace(0, 1000, 20, dtype=int)[7:10] 243 | if split == 6: 244 | seed_list = np.linspace(0, 1000, 20, dtype=int)[10:13] 245 | if split == 7: 246 | seed_list = np.linspace(0, 1000, 20, dtype=int)[13:17] 247 | if split == 8: 248 | seed_list = np.linspace(0, 1000, 20, dtype=int)[17:] 249 | 250 | 251 | 252 | for i in tqdm(range(len(seed_list))): 253 | seed = seed_list[i] 254 | data_size = 1000 255 | train_ratio = 0.8 256 | 257 | param_dict = { 258 | 'seed': int(seed), 259 | 'data_id': data_id, 260 | 'data_size': data_size, 261 | 'train_ratio': train_ratio, 262 | 'model_id': model_id, 263 | 'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'), 264 | 'embd_dim': embd_dim, 265 | 'n_exp': n_exp, 266 | 'lr': lr, 267 | 'weight_decay':weight_decay, 268 | 'custom_loss': "softnn" 269 | } 270 | print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id}, n_exp {n_exp}, embd_dim {embd_dim}") 271 | ret_dic = train_single_model(param_dict) 272 | model = ret_dic['model'] 273 | dataset = ret_dic['dataset'] 274 | torch.save(model.state_dict(), f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.pt") 275 | with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}_train_results.json", "w") as f: 276 | json.dump(ret_dic["results"], f, indent=4) 277 | 278 | if data_id == "family_tree": 279 | aux_info["dict_level"] = dataset['dict_level'] 280 | 281 | if 
hasattr(model.embedding, 'weight'): 282 | metric_dict = crystal_metric(model.embedding.weight.cpu().detach(), data_id, aux_info) 283 | else: 284 | metric_dict = crystal_metric(model.embedding.data.cpu(), data_id, aux_info) 285 | 286 | with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.json", "w") as f: 287 | json.dump(metric_dict, f, indent=4) 288 | 289 | # #Exp5: N Exponent value plot: Run with different n values, plot test accuracy vs. and explained variance vs. 290 | 291 | # print(f"Experiment 5: Train with different exponent values") 292 | # n_list = np.arange(1, 17, dtype=int) 293 | 294 | # for i in tqdm(range(len(n_list))): 295 | # n_exp = n_list[i] 296 | # data_size = 1000 297 | # train_ratio = 0.8 298 | 299 | # param_dict = { 300 | # 'seed': seed, 301 | # 'data_id': data_id, 302 | # 'data_size': data_size, 303 | # 'train_ratio': train_ratio, 304 | # 'model_id': model_id, 305 | # 'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'), 306 | # 'embd_dim': embd_dim, 307 | # 'n_exp': n_exp 308 | # } 309 | # print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id}, n_exp {n_exp}, embd_dim {embd_dim}") 310 | 311 | # ret_dic = train_single_model(param_dict) 312 | # model = ret_dic['model'] 313 | # dataset = ret_dic['dataset'] 314 | # torch.save(model.state_dict(), f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}.pt") 315 | # with open(f"{results_root}/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_{n_exp}_train_results.json", "w") as f: 316 | # json.dump(ret_dic["results"], f, indent=4) 317 | 318 | # if data_id == "family_tree": 319 | # aux_info["dict_level"] = dataset['dict_level'] 320 | 321 | # if hasattr(model.embedding, 'weight'): 322 | # metric_dict = crystal_metric(model.embedding.weight.cpu().detach(), data_id, aux_info) 323 | # else: 324 | # metric_dict = crystal_metric(model.embedding.data.cpu(), data_id, aux_info) 325 | 326 | # with 
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random

from tqdm import tqdm

import sys
sys.path.append("..")

import argparse
from src.utils.driver import train_single_model
from src.utils.visualization import visualize_embedding
from src.utils.crystal_metric import crystal_metric
import json

data_id_choices = ["lattice", "greater", "family_tree", "equivalence", "circle"]
model_id_choices = ["H_MLP", "standard_MLP", "H_transformer", "standard_transformer"]

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Experiment')
    parser.add_argument('--seed', type=int, default=29, help='random seed')
    parser.add_argument('--data_id', type=str, required=True, choices=data_id_choices, help='Data ID')
    parser.add_argument('--model_id', type=str, required=True, choices=model_id_choices, help='Model ID')
    # BUG FIX: train_single_model() raises ValueError("n_exp must be provided
    # in param_dict") when the key is absent, so this script could never run.
    # Expose the exponent as an optional flag (default 1) so existing command
    # lines keep working.
    parser.add_argument('--n_exp', type=int, default=1, help='exponent forwarded to train_single_model')

    args = parser.parse_args()
    seed = args.seed
    data_id = args.data_id
    model_id = args.model_id

    data_size = 1000
    train_ratio = 0.8

    param_dict = {
        'seed': seed,
        'data_id': data_id,
        'data_size': data_size,
        'train_ratio': train_ratio,
        'model_id': model_id,
        'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        'embd_dim': 16,
        'n_exp': args.n_exp,
    }

    # Train the model
    print(f"Training model with seed {seed}, data_id {data_id}, model_id {model_id}")
    ret_dic = train_single_model(param_dict)

    ## Exp1: Visualize Embeddings
    print(f"Experiment 1: Visualize Embeddings")
    model = ret_dic['model']
    dataset = ret_dic['dataset']
    torch.save(model.state_dict(), f"../results/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_d=sqrtembed_1.pt")

    # The embedding is either an nn.Embedding (exposes .weight) or a bare
    # parameter tensor (exposes .data); grab the matrix once instead of
    # duplicating every downstream call in both branches.
    if hasattr(model.embedding, 'weight'):
        emb = model.embedding.weight.cpu()
    else:
        emb = model.embedding.data.cpu()

    visualize_embedding(emb, title=f"{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}", save_path=f"../results/unit_tests/emb_{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_new.png", dict_level = dataset['dict_level'] if 'dict_level' in dataset else None)

    with open(f"../results/unit_tests/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_train_results_new.json", "w") as f:
        json.dump(ret_dic["results"], f, indent=4)

    # Per-task parameters required by crystal_metric(); these must mirror the
    # constants hard-coded in the corresponding dataset builders.
    aux_info = {}
    if data_id == "lattice":
        aux_info["lattice_size"] = 5
    elif data_id == "greater":
        aux_info["p"] = 30
    elif data_id == "family_tree":
        aux_info["dict_level"] = dataset['dict_level']
    elif data_id == "equivalence":
        aux_info["mod"] = 5
    elif data_id == "circle":
        aux_info["p"] = 17
    else:
        raise ValueError(f"Unknown data_id: {data_id}")

    metric_dict = crystal_metric(emb.detach(), data_id, aux_info)

    with open(f"../results/unit_tests/{seed}_{data_id}_{model_id}_{data_size}_{train_ratio}_new.json", "w") as f:
        json.dump(metric_dict, f, indent=4)
import numpy as np
import torch

from itertools import combinations
from sklearn.decomposition import PCA

def crystal_metric(reps, data_id, aux_info):
    """
    Compute the crystal metric for the given representations and data_id.

    Args:
        reps: (vocab, dim) embedding matrix.
        data_id: one of "lattice", "greater", "family_tree", "equivalence",
            "circle", "permutation".
        aux_info: dict of per-task parameters ('lattice_size', 'p', 'mod',
            'dict_level', ...).

    Returns:
        dict with 'metric' (float) and 'variances' (PCA explained-variance
        ratios of ``reps``).

    Raises:
        ValueError: if ``data_id`` is not a known task.
    """
    # Table dispatch instead of an if/elif chain; same behavior.
    dispatch = {
        "lattice": lattice_metric,
        "greater": greater_metric,
        "family_tree": family_tree_metric,
        "equivalence": equivalence_metric,
        "circle": circle_metric,
        "permutation": permutation_metric,
    }
    if data_id not in dispatch:
        raise ValueError(f"Unknown data_id: {data_id}")
    return dispatch[data_id](reps, aux_info)

def lattice_metric(reps, aux_info):
    """
    Mean relative side-length deviation over all implied parallelograms.

    For every non-degenerate triple (a, b, c) of lattice points whose fourth
    parallelogram vertex d = b + c - a also lies on the lattice, measure how
    far the four embedded points deviate from a perfect parallelogram.
    Lower is better (0 = perfect lattice embedding).
    """
    lattice_size = aux_info['lattice_size']
    deviation_arr = []
    points = [(i, j) for i in range(lattice_size) for j in range(lattice_size)]

    def side_length_deviation(a, b, c, d):
        a, b, c, d = np.array(a), np.array(b), np.array(c), np.array(d)

        # Lengths of the two pairs of opposite sides (ab/cd and ac/bd);
        # bc and ad only enter the scale-normalisation term below.
        length_ab = np.linalg.norm(b - a)
        length_cd = np.linalg.norm(d - c)
        length_ac = np.linalg.norm(c - a)
        length_bd = np.linalg.norm(b - d)
        length_bc = np.linalg.norm(c - b)
        length_ad = np.linalg.norm(d - a)

        # Mismatch of opposite sides, normalised by the overall size of the
        # quadrilateral so the score is invariant to embedding scale.
        side_deviation = np.sqrt((length_ab - length_cd)**2 + (length_ac - length_bd)**2) / np.sqrt((length_ab ** 2 + length_bc ** 2 + length_cd ** 2 + length_ad ** 2)/2)

        return side_deviation

    # IMPROVED: combinations(points, 3) yields triples; unpack directly
    # instead of a misleadingly named `quad` tuple.
    for a, b, c in combinations(points, 3):
        d = (c[0] + b[0] - a[0], c[1] + b[1] - a[1])
        # The implied fourth vertex must stay on the lattice.
        if d[0] < 0 or d[0] >= lattice_size or d[1] < 0 or d[1] >= lattice_size:
            continue

        # Skip degenerate triples collinear along a row or a column.
        if a[0] == b[0] and b[0] == c[0]:
            continue
        if a[1] == b[1] and b[1] == c[1]:
            continue

        # Row-major flattening: (i, j) -> i * lattice_size + j token index.
        a = lattice_size * a[0] + a[1]
        b = lattice_size * b[0] + b[1]
        c = lattice_size * c[0] + c[1]
        d = lattice_size * d[0] + d[1]

        deviation = side_length_deviation(reps[a], reps[b], reps[c], reps[d])
        deviation_arr.append(deviation)

    # Obtain explained variance ratios (typo "Obtatin" fixed).  IMPROVED:
    # fit() suffices — the transformed coordinates were computed but unused.
    pca = PCA(n_components=min(reps.shape[0], reps.shape[1]))
    pca.fit(reps)
    variances = pca.explained_variance_ratio_

    metric_dict = {
        'metric': float(np.mean(deviation_arr)),
        'variances': variances.tolist(),
    }

    return metric_dict


def greater_metric(reps, aux_info):
    """
    Coefficient of variation (std/mean) of consecutive-token distances.

    A perfect "greater-than" (number-line) embedding places consecutive
    integers equidistantly, so the ideal score is 0.
    """
    # Distances between consecutive representations.
    diff_arr = [np.linalg.norm(reps[i] - reps[i + 1]) for i in range(reps.shape[0] - 1)]

    # IMPROVED: fit() only — the projection itself was never used.
    pca = PCA(n_components=min(reps.shape[0], reps.shape[1]))
    pca.fit(reps)
    variances = pca.explained_variance_ratio_

    metric_dict = {
        'metric': float(np.std(diff_arr) / np.mean(diff_arr)),
        'variances': variances.tolist(),
    }
    return metric_dict

def family_tree_metric(reps, aux_info):
    """
    One-dimensionality of each generation in a family-tree embedding.

    Projects the person embeddings onto their top-2 principal components,
    groups people by generation, and scores how close each generation with
    at least 5 members is to lying on a line (variance ratio of its first
    PC).  Returns 1 - mean(one-dimensionality), so lower is better.
    """
    dict_level = aux_info['dict_level']
    # Person tokens start at 1 (cf. the reps[node - 1] lookup below), so
    # drop row 0 and anything past the largest node id.
    reps = reps[1:(max(dict_level.keys()) + 1)]

    pca = PCA(n_components=min(reps.shape[0], reps.shape[1]))
    reps = pca.fit_transform(reps)
    reps = reps[:, :2]

    # Group the projected embeddings by generation.
    levels = {}
    for node, generation in dict_level.items():
        levels.setdefault(generation, []).append(reps[node - 1])

    # Per-generation one-dimensionality score.
    level_scores = {}
    for generation, points in levels.items():
        # Too few points make the first-PC ratio meaningless.
        if len(points) < 5:
            continue

        points_array = np.stack(points)
        pca_sub = PCA(n_components=min(points_array.shape[0], points_array.shape[1]))
        pca_sub.fit(points_array)
        one_dimensionality = pca_sub.explained_variance_ratio_[0]  # variance on the first PC
        level_scores[generation] = one_dimensionality

    variances = pca.explained_variance_ratio_

    metric_dict = {
        'metric': float(1 - np.mean(list(level_scores.values()))),
        'variances': variances.tolist(),
    }
    return metric_dict
def equivalence_metric(reps, aux_info):
    """
    Ratio of within-class to cross-class embedding distances.

    Tokens i and j are in the same equivalence class when i % mod == j % mod.
    A perfect embedding collapses each class to a point, giving a ratio of 0.
    """
    mod = aux_info['mod']
    n = reps.shape[0]

    # Pairwise distances, split into same-class and cross-class pairs.
    # NOTE(review): the i == j pairs (distance 0) are counted in diff_arr,
    # which biases the within-class mean downwards — confirm intended.
    diff_arr = []
    cross_diff_arr = []
    for i in range(n):
        for j in range(n):
            if i % mod != j % mod:
                cross_diff_arr.append(np.linalg.norm(reps[i] - reps[j]))
            else:
                diff_arr.append(np.linalg.norm(reps[i] - reps[j]))

    # Filter outliers: drop within-class distances larger than the mean
    # cross-class distance.
    diff_arr = np.array(diff_arr)
    diff_arr = diff_arr[diff_arr < np.mean(cross_diff_arr)]

    # IMPROVED: fit() only — the transformed coordinates were never used.
    pca = PCA(n_components=min(reps.shape[0], reps.shape[1]))
    pca.fit(reps)
    variances = pca.explained_variance_ratio_

    # IMPROVED: removed a leftover debug print of the two distance means.
    metric_dict = {
        'metric': float(np.mean(diff_arr) / np.mean(cross_diff_arr)),
        'variances': variances.tolist(),
    }
    return metric_dict


def circle_metric(reps, aux_info):
    """
    Circularity score of the top-2 PCA projection.

    Normalises the projected points to the unit square and measures the
    relative spread (std/mean) of their distances from the centroid;
    0 means the points lie on a perfect circle.
    """
    pca = PCA(n_components=min(reps.shape[0], reps.shape[1]))
    emb_pca = pca.fit_transform(reps)
    variances = pca.explained_variance_ratio_

    points = emb_pca[:, :2]

    min_x, min_y = points.min(axis=0)
    max_x, max_y = points.max(axis=0)
    width = max_x - min_x
    height = max_y - min_y

    # Normalize points to [0, 1] in both dimensions.
    # NOTE(review): a degenerate (zero-width or zero-height) projection
    # divides by zero here — confirm inputs always have 2-D spread.
    normalized_points = (points - [min_x, min_y]) / [width, height]

    # Compute the centroid of the points
    centroid = np.mean(normalized_points, axis=0)

    # Compute distances of points from the centroid
    distances = np.linalg.norm(normalized_points - centroid, axis=1)

    # Mean and standard deviation of distances
    mean_distance = np.mean(distances)
    std_distance = np.std(distances)

    # Circularity score: relative spread of radii around the centroid.
    circularity_score = (std_distance / mean_distance)

    metric_dict = {
        'metric': float(circularity_score),
        'variances': variances.tolist(),
    }
    return metric_dict


def permutation_metric(reps, aux_info):
    """
    Rotational-symmetry score of the top-2 PCA projection.

    Rotates the centred 2-D projection by each non-zero multiple of
    2*pi / aux_info['p'] and accumulates the mean nearest-neighbour distance
    between the original and rotated point clouds.  The score
    1 / (1 + mean distance) approaches 1 for a perfectly p-fold-symmetric
    embedding.
    """
    pca = PCA(n_components=min(reps.shape[0], reps.shape[1]))
    emb_pca = pca.fit_transform(reps)
    variances = pca.explained_variance_ratio_

    points = emb_pca[:, :2]

    min_x, min_y = points.min(axis=0)
    max_x, max_y = points.max(axis=0)
    width = max_x - min_x
    height = max_y - min_y

    # Normalize points to [0, 1] in both dimensions
    normalized_points = (points - [min_x, min_y]) / [width, height]

    scatter = np.array(normalized_points)

    # Centre the cloud so rotations are about its mean.
    scatter -= scatter.mean(axis=0)

    angles = np.linspace(0, 2 * np.pi, aux_info['p'], endpoint=False)

    distances = []
    for angle in angles[1:]:  # skip the identity rotation at angle 0
        # Create rotation matrix
        rotation_matrix = np.array([
            [np.cos(angle), -np.sin(angle)],
            [np.sin(angle), np.cos(angle)]
        ])
        # rotate scatterplot
        rotated_scatter = scatter @ rotation_matrix.T

        # Mean nearest-neighbour distance from each original point to the
        # rotated cloud.
        total_distance = 0
        for point in scatter:
            distances_to_rotated = np.linalg.norm(rotated_scatter - point, axis=1)
            total_distance += np.min(distances_to_rotated)
        distances.append(total_distance / len(scatter))

    # Symmetry score (inverse of average distance)
    symmetry_score = 1 / (1 + np.mean(distances))

    metric_dict = {
        'metric': float(symmetry_score),
        'variances': variances.tolist(),
    }
    return metric_dict
import numpy as np
import torch
import math
import itertools

import sys
sys.path.append("..")
from src.utils.FamilyTreeGenerator import GenerateFamilyTree

def parallelogram_dataset(p, dim, num, seed=0, device='cpu'):
    """Triples (a, b, c) of dim-dimensional points labelled d = -a + b + c.

    Oversamples by 5x, then keeps the first `num` triples whose implied
    fourth parallelogram vertex stays inside [0, p)^dim.  Coordinates are
    flattened into single token ids in base p.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    n_draw = 5 * num
    x = np.random.choice(p, n_draw*dim*3).reshape(n_draw, 3, dim)
    target = -x[:,0,:] + x[:,1,:] + x[:,2,:]
    keep = np.where(np.prod((target >= 0) * (target < p), axis=1)==1)[0][:num]
    target = target[keep]
    x = x[keep]

    # Base-p flattening of dim-dimensional coordinates into token ids.
    data_id = sum(x[:, :, k] * p ** (dim - k - 1) for k in range(dim))
    labels = sum(target[:, k] * p ** (dim - k - 1) for k in range(dim))

    return {
        'data_id': torch.from_numpy(data_id).to(device),
        'label': torch.from_numpy(labels).to(device),
        'vocab_size': p ** dim,
    }


def modular_addition_dataset(p, num, seed=0, device='cpu'):
    """Pairs (x, y) sampled from the full p x p grid, labelled (x + y) % p."""
    torch.manual_seed(seed)
    np.random.seed(seed)

    grid_x, grid_y = np.meshgrid(np.arange(p), np.arange(p))
    pairs = np.transpose([grid_x.reshape(-1,), grid_y.reshape(-1,)])

    picked = np.random.choice(len(pairs), size=num, replace=True)
    pairs = pairs[picked]
    labels = torch.tensor((pairs[:,0] + pairs[:,1]) % p, dtype=torch.long)

    # NOTE(review): unlike the other builders, data_id stays a numpy array
    # and `device` is never applied — confirm downstream converts it.
    return {
        'data_id': pairs,
        'label': labels,
        'vocab_size': p,
    }

def permutation_group_dataset(p, num, seed=0, device='cpu'):
    """All ordered pairs of S_p permutations labelled by their composition.

    Each permutation of range(p) gets an integer key; inputs are key pairs
    (sigma, tau) and the label is the key of the composed permutation.
    `dict_level` maps each key to its permutation's digit string.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    perms = list(itertools.permutations(range(p)))
    num_perms = len(perms)

    perm_dict = dict(enumerate(perms))
    swapped_dict = {v:k for k,v in perm_dict.items()}

    idx = torch.arange(num_perms)

    data_id = [[perms[int(i)], perms[int(j)]] for i, j in torch.cartesian_prod(idx, idx)]
    keyed_data_id = np.array([[swapped_dict[pair[0]], swapped_dict[pair[1]]] for pair in data_id])

    labels = [tuple(np.array(perms[int(i)])[np.array(perms[int(j)])]) for i, j in torch.cartesian_prod(idx, idx)]
    keyed_labels = np.array([swapped_dict[lab] for lab in labels])
    # NOTE(review): this tensor is built but never returned (the keyed numpy
    # labels above are what the dataset uses) — dead code kept for fidelity.
    labels = torch.tensor(labels, dtype=torch.long, device=device)

    # Human-readable digit string for each permutation key, e.g. (0,2,1) -> "021".
    digit_strings = ["".join(np.array(perm_dict[k]).astype(str)) for k in range(len(perm_dict))]
    key_to_string = dict(zip(perm_dict.keys(), digit_strings))

    return {
        'data_id': keyed_data_id,
        'label': keyed_labels,
        'vocab_size': num_perms,
        'dict_level': key_to_string,
    }


def split_dataset(dataset, train_ratio, seed=0):
    """Randomly split a dataset dict into train/test parts by `train_ratio`."""
    torch.manual_seed(seed)
    np.random.seed(seed)

    num = dataset['data_id'].shape[0]
    train_num = int(num*train_ratio)

    train_id = np.random.choice(num,train_num,replace=False)
    test_id = np.array(list(set(np.arange(num)) - set(train_id)))

    out = {
        'train_data_id': dataset['data_id'][train_id],
        'test_data_id': dataset['data_id'][test_id],
        'train_label': dataset['label'][train_id],
        'test_label': dataset['label'][test_id],
        'vocab_size': dataset['vocab_size'],
    }
    if 'dict_level' in dataset:
        out['dict_level'] = dataset['dict_level']
    return out

def repeat_dataset(dataset):
    """Use the full dataset as both the train and the test split."""
    out = {
        'train_data_id': dataset['data_id'],
        'test_data_id': dataset['data_id'],
        'train_label': dataset['label'],
        'test_label': dataset['label'],
        'vocab_size': dataset['vocab_size'],
    }
    if 'dict_level' in dataset:
        out['dict_level'] = dataset['dict_level']
    return out


def combine_dataset(train_dataset, test_dataset):
    """Merge two pre-built datasets into one train/test dictionary."""
    dataset_c = {}

    dataset_c['train_data_id'] = train_dataset['data_id']
    dataset_c['test_data_id'] = test_dataset['data_id']
    dataset_c['train_label'] = train_dataset['label']
    dataset_c['test_label'] = test_dataset['label']

    # Both halves must come from the same vocabulary.
    assert train_dataset['vocab_size'] == test_dataset['vocab_size']
    dataset_c['vocab_size'] = train_dataset['vocab_size']

    return dataset_c
# Dataset and DataLoader
class ToyDataset(torch.utils.data.Dataset):
    """Minimal (input, target) pair dataset for torch DataLoaders."""

    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

def descendant_dataset(p, num, seed=0, device='cpu'):
    """Pairs (a, b) labelled 1 iff b is a descendant of a in a binary tree.

    Nodes follow the complete-binary-tree convention where node x has
    children 2x and 2x+1; inputs are drawn from [2, p) and shifted by -1
    before the ancestry test.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    N_sample = num
    x = np.random.choice(range(2, p), N_sample * 2).reshape(N_sample, 2)

    # Check if b is a descendant of a by walking b up through its parents.
    def is_desc(a, b):
        while b > 1:
            if b == a:
                return True
            b //= 2  # Move up to the parent node
        return b == a

    target = np.array([1 if is_desc(x[i, 0] - 1, x[i, 1] - 1) else 0 for i in range(N_sample)])

    dataset = {
        'data_id': torch.from_numpy(x).to(device),
        'label': torch.from_numpy(target).to(device),
        'vocab_size': p,
    }
    return dataset

def descendant_dataset_2(p, num, seed=0, device='cpu'):
    """Analogy-style quadruples (a, rel(a), b) -> rel(b) over a binary tree.

    The first two quarters pair each node with its left/right child; the
    last two pair a right child with its parent.

    NOTE(review): the third and fourth quarters are byte-identical — the
    fourth was probably meant to use a different relation (e.g. left child
    2*x -> parent).  Left unchanged pending confirmation of intent.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    N_sample = num * 4
    x = np.random.choice(range(1, (p - 1) // 2), num * 2).reshape(num, 2)

    data = np.zeros((N_sample, 4), dtype=np.int32)
    # Quarter 1: (x, left child of x, y) -> left child of y.
    data[:num, 0] = x[:, 0]
    data[:num, 1] = 2 * x[:, 0]
    data[:num, 2] = x[:, 1]
    data[:num, 3] = 2 * x[:, 1]

    # Quarter 2: right children.
    data[num:(2 * num), 0] = x[:, 0]
    data[num:(2 * num), 1] = 2 * x[:, 0] + 1
    data[num:(2 * num), 2] = x[:, 1]
    data[num:(2 * num), 3] = 2 * x[:, 1] + 1

    # Quarters 3 and 4: (right child, parent) analogies (duplicated, see note).
    data[2 * num:(3 * num), 0] = 2 * x[:, 0] + 1
    data[2 * num:(3 * num), 1] = x[:, 0]
    data[2 * num:(3 * num), 2] = 2 * x[:, 1] + 1
    data[2 * num:(3 * num), 3] = x[:, 1]

    data[3 * num:(4 * num), 0] = 2 * x[:, 0] + 1
    data[3 * num:(4 * num), 1] = x[:, 0]
    data[3 * num:(4 * num), 2] = 2 * x[:, 1] + 1
    data[3 * num:(4 * num), 3] = x[:, 1]

    np.random.shuffle(data)

    dataset = {
        'data_id': torch.from_numpy(data[:, :3]).to(device),
        'label': torch.from_numpy(data[:, 3]).to(device),
        'vocab_size': p + 1,
    }
    return dataset


def greater_than_dataset(p, num, seed=0, device='cpu'):
    """Pairs (x, y) labelled with token p+1 when x > y, else token p."""
    torch.manual_seed(seed)
    np.random.seed(seed)

    N_sample = num
    x = np.random.choice(range(p), N_sample * 2).reshape(N_sample, 2)

    target = np.array([p + 1 if x[i, 0] > x[i, 1] else p for i in range(N_sample)])

    dataset = {
        'data_id': torch.from_numpy(x).to(device),
        'label': torch.from_numpy(target).to(device),
        'vocab_size': p + 2,  # p inputs + the two answer tokens p, p+1
    }
    return dataset


def xor_dataset(p, num, seed=0, device='cpu'):
    """Pairs (x, y) drawn from [0, p) labelled with their bitwise XOR."""
    torch.manual_seed(seed)
    np.random.seed(seed)

    N_sample = num
    x = np.random.choice(range(p), N_sample * 2).reshape(N_sample, 2)

    target = np.array([x[i, 0] ^ x[i, 1] for i in range(N_sample)])

    # BUG FIX: was p + 2.  x ^ y can exceed p + 1 whenever p is not a power
    # of two (e.g. p = 10 gives 9 ^ 6 = 15), which overflows an embedding
    # table of size p + 2.  Cover every possible XOR value while never
    # shrinking below the old size.
    vocab_size = max(p + 2, 1 << (p - 1).bit_length()) if p > 1 else p + 2

    dataset = {
        'data_id': torch.from_numpy(x).to(device),
        'label': torch.from_numpy(target).to(device),
        'vocab_size': vocab_size,
    }
    return dataset

def multi_step_dataset(p, num, seed=0, device='cpu'):
    """Triples (a, b, c) labelled with (a * b + c) % p."""
    torch.manual_seed(seed)
    np.random.seed(seed)

    N_sample = num
    x = np.random.choice(range(p), N_sample * 3).reshape(N_sample, 3)

    target = np.array([(x[i, 0] * x[i, 1] + x[i, 2]) % p for i in range(N_sample)])

    dataset = {
        'data_id': torch.from_numpy(x).to(device),
        'label': torch.from_numpy(target).to(device),
        'vocab_size': p,
    }
    return dataset
def mod_classification_dataset(p, num, seed=0, device='cpu', mod=5):
    """Single tokens x in [0, p) labelled with x % mod.

    GENERALIZED: `mod` (default 5, the previously hard-coded value) is now a
    keyword parameter, so existing callers are unaffected while other residue
    counts become possible.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    N_sample = num
    x = np.random.choice(range(p), N_sample).reshape(N_sample, 1)

    target = np.array([x[i, 0] % mod for i in range(N_sample)])

    dataset = {
        'data_id': torch.from_numpy(x).to(device),
        'label': torch.from_numpy(target).to(device),
        'vocab_size': p,
    }
    return dataset


def mod_equiv_dataset(p, num, seed=0, device='cpu', mod=5):
    """Pairs (x, y) labelled token p when x == y (mod `mod`), else token p+1.

    GENERALIZED: `mod` defaults to the previously hard-coded 5, keeping old
    call sites working (consistent with mod_classification_dataset).
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    N_sample = num
    x = np.random.choice(range(p), N_sample * 2).reshape(N_sample, 2)

    target = np.array([p if (x[i, 0] - x[i, 1]) % mod == 0 else p + 1 for i in range(N_sample)])

    dataset = {
        'data_id': torch.from_numpy(x).to(device),
        'label': torch.from_numpy(target).to(device),
        'vocab_size': p + 2,  # p inputs + the two answer tokens p, p+1
    }
    return dataset

def family_tree_dataset(p, num, seed=0, device='cpu'):
    """(subject, object) -> relation triples from a generated family tree.

    Relation strings are remapped to token ids >= p + 1 so they share one
    vocabulary with the person tokens.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    N_sample = num
    ret_dic = GenerateFamilyTree(nodes_MAX=p, max_child_per_gen=3, seed=seed)

    # Assign each distinct relation an integer id.
    # NOTE(review): if relations are strings, set() iteration order varies
    # across interpreter runs (hash randomization), so relation ids are only
    # stable within one process — confirm this is acceptable.
    unique_mapping = {element: index for index, element in enumerate(set(ret_dic["col_relation"]))}

    # Map the list to numbers in the range [p+1, ...).
    mapped_list = [unique_mapping[element] + p + 1 for element in ret_dic["col_relation"]]
    x = np.zeros((len(ret_dic["col_subject"]), 3), dtype=np.int32)
    x[:, 0] = ret_dic["col_subject"]
    x[:, 1] = ret_dic["col_object"]
    x[:, 2] = mapped_list

    random_indices = np.random.choice(x.shape[0], size=N_sample)
    data = x[random_indices]

    dataset = {
        'data_id': torch.from_numpy(data[:, :2]).to(device),
        'label': torch.from_numpy(data[:, 2]).to(device),
        'vocab_size': max(mapped_list) + 1,
        'dict_level': ret_dic['dict_level'],
    }
    return dataset

def family_tree_dataset_2(p, num, seed=0, device='cpu'):
    """Parent / grandparent / sibling queries on an implicit binary tree.

    Relation tokens:
        p   : parent
        p+1 : grandparent
        p+2 : sibling
    Also returns `dict_level` mapping each node to its generation depth.

    NOTE(review): for odd p the largest even node p-1 gets sibling p, which
    is not a person token — confirm callers always pass even p.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    total_data = []
    for i in range(1, p):
        if i >= 2:
            total_data.append([i, p, i // 2])       # parent
        if i >= 4:
            total_data.append([i, p + 1, i // 4])   # grandparent
        if i > 1:
            # Sibling: the other child of the same parent.
            if i % 2 == 0:
                total_data.append([i, p + 2, i + 1])
            else:
                total_data.append([i, p + 2, i - 1])

    data = np.array(total_data)

    data_id = np.random.choice(len(data), size=num, replace=True)
    x = data[data_id]

    # Generation depth of each node: root = 0, children at depth + 1.
    dict_level = dict()
    dict_level[1] = 0
    for i in range(1, p):
        if i * 2 < p:
            dict_level[i * 2] = dict_level[i] + 1
        if i * 2 + 1 < p:
            dict_level[i * 2 + 1] = dict_level[i] + 1

    dataset = {
        'data_id': torch.from_numpy(x[:, :2]).to(device),
        'label': torch.from_numpy(x[:, 2]).to(device),
        'vocab_size': p + 3,
        'dict_level': dict_level,
    }
    return dataset
import sys
sys.path.append("..")
from src.utils.dataset import *
from src.utils.model import *
import os
import random

import numpy as np
import torch


def set_seed(seed: int) -> None:
    """
    Sets the seed to make everything deterministic, for reproducibility of experiments

    Parameters:
        seed: the number to set the seed to

    Return: None
    """
    # Python, numpy and torch RNGs (CPU and current CUDA device).
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Force cuDNN to pick deterministic kernels.
    torch.backends.cudnn.deterministic = True
    # BUGFIX: benchmark mode autotunes convolution algorithms, which can vary
    # between runs and defeats determinism; it was previously set to True.
    torch.backends.cudnn.benchmark = False

    # Fix Python's string-hash seed as well.
    os.environ['PYTHONHASHSEED'] = str(seed)


def train_single_model(param_dict: dict):
    """
    Build the dataset and model described by ``param_dict``, run training,
    and return the results.

    Required keys: seed, data_id, data_size, train_ratio, model_id, device,
    embd_dim, n_exp.
    Optional keys (defaults): video (False), lr (0.002), weight_decay (0.01),
    verbose (False), lamb_reg (0.01), custom_loss (None), num_epochs (7000;
    forced to 10000 for transformers on the permutation task).

    Returns a dict with 'results' (training history), 'model' and 'dataset'.
    Raises ValueError for a missing required key or unknown data_id/model_id.
    """
    for key in ("seed", "data_id", "data_size", "train_ratio",
                "model_id", "device", "embd_dim", "n_exp"):
        if key not in param_dict:
            raise ValueError(f"{key} must be provided in param_dict")

    seed = param_dict['seed']
    data_id = param_dict['data_id']
    data_size = param_dict['data_size']
    train_ratio = param_dict['train_ratio']
    model_id = param_dict['model_id']
    device = param_dict['device']
    embd_dim = param_dict['embd_dim']
    n_exp = param_dict['n_exp']

    # Optional hyper-parameters with their defaults.
    video = param_dict.get('video', False)
    lr = param_dict.get('lr', 0.002)
    weight_decay = param_dict.get('weight_decay', 0.01)
    verbose = param_dict.get('verbose', False)
    lamb_reg = param_dict.get('lamb_reg', 0.01)
    custom_loss = param_dict.get('custom_loss', None)
    use_custom_loss = custom_loss is not None

    set_seed(seed)

    # ---- dataset ----
    input_token = 2  # tokens per input; the lattice task overrides this below
    num_epochs = param_dict.get('num_epochs', 7000)
    if data_id == "lattice":
        dataset = parallelogram_dataset(p=5, dim=2, num=data_size, seed=seed, device=device)
        input_token = 3
    elif data_id == "greater":
        dataset = greater_than_dataset(p=30, num=data_size, seed=seed, device=device)
    elif data_id == "family_tree":
        dataset = family_tree_dataset_2(p=127, num=data_size, seed=seed, device=device)
    elif data_id == "equivalence":
        dataset = mod_equiv_dataset(p=40, num=data_size, seed=seed, device=device)
    elif data_id == "circle":
        dataset = modular_addition_dataset(p=31, num=data_size, seed=seed, device=device)
    elif data_id == "permutation":
        dataset = permutation_group_dataset(p=4, num=data_size, seed=seed, device=device)
        if model_id in ("H_transformer", "standard_transformer"):
            num_epochs = 10000  # extra epochs to train fully
    else:
        raise ValueError(f"Unknown data_id: {data_id}")

    dataset = split_dataset(dataset, train_ratio=train_ratio, seed=seed)
    vocab_size = dataset['vocab_size']

    # ---- model ----
    if model_id == "H_MLP":
        hidden_size = 100
        shp = [input_token * embd_dim, hidden_size, embd_dim, vocab_size]
        loss_type = custom_loss if use_custom_loss else 'harmonic'
        print(loss_type)
        model = MLP_HS(shp=shp, vocab_size=vocab_size, embd_dim=embd_dim,
                       input_token=input_token, weight_tied=True, seed=seed,
                       n=n_exp, init_scale=1, loss_type=loss_type).to(device)
    elif model_id == "standard_MLP":
        hidden_size = 100
        shp = [input_token * embd_dim, hidden_size, embd_dim, vocab_size]
        model = MLP(shp=shp, vocab_size=vocab_size, embd_dim=embd_dim,
                    input_token=input_token, unembd=True, weight_tied=True,
                    seed=seed, init_scale=1, loss_type="cross_entropy").to(device)
    elif model_id == "H_transformer":
        loss_type = custom_loss if use_custom_loss else 'harmonic'
        print(loss_type)
        model = ToyTransformer(vocab_size=vocab_size, d_model=embd_dim, nhead=2,
                               num_layers=2, n_dist=n_exp, seq_len=input_token,
                               seed=seed, use_dist_layer=True, init_scale=1,
                               loss_type=loss_type).to(device)
    elif model_id == "standard_transformer":
        model = ToyTransformer(vocab_size=vocab_size, d_model=embd_dim, nhead=2,
                               num_layers=2, seq_len=input_token, seed=seed,
                               use_dist_layer=False, init_scale=1,
                               loss_type="cross_entropy").to(device)
    else:
        raise ValueError(f"Unknown model_id: {model_id}")

    # ---- dataloaders ----
    batch_size = 32
    train_dataset = ToyDataset(dataset['train_data_id'], dataset['train_label'])
    test_dataset = ToyDataset(dataset['test_data_id'], dataset['test_label'])
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # ---- training ----
    # NOTE(review): `model.train` here is the project's custom training loop
    # (it shadows nn.Module.train) -- confirm in src/utils/model.py.
    ret_dic = {}
    ret_dic["results"] = model.train(param_dict={
        'model_id': model_id,
        'num_epochs': num_epochs,
        'learning_rate': lr,
        'weight_decay': weight_decay,
        'train_dataloader': train_dataloader,
        'test_dataloader': test_dataloader,
        'device': device,
        'video': video,
        'verbose': verbose,
        'lambda': lamb_reg,
    })
    ret_dic["model"] = model
    ret_dic["dataset"] = dataset
    return ret_dic
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from adjustText import adjust_text
import random
import itertools
import numpy as np
from collections import defaultdict

# NOTE: itertools / numpy / defaultdict were previously imported twice; the
# duplicates have been removed.

# 1000 random hex colors, indexed by cluster / tree-level id in the plots below.
colors = ["#{:06x}".format(random.randint(0, 0xFFFFFF)) for i in range(1000)]


def visualize_embedding(emb, title="", save_path=None, dict_level=None, color_dict=True, adjust_overlapping_text=False):
    """Scatter a 2D PCA projection of an embedding matrix onto the current figure.

    Parameters:
        emb: torch tensor of shape (num_tokens, dim); detached and converted to numpy.
        title: figure title.
        save_path: if given, save the figure there (tight layout, tight bbox).
        dict_level: optional {token index -> label}; when given, only indices
            present in it are plotted, annotated with their label.
        color_dict: when True (and dict_level is given), color points with
            colors[dict_level[i]]; otherwise plot them black.
        adjust_overlapping_text: use adjustText to de-overlap the annotations,
            which is helpful for the permutation plots.
    """
    pca = PCA(n_components=2)
    emb_pca = pca.fit_transform(emb.detach().numpy())
    print("Explained Variance Ratio", pca.explained_variance_ratio_)

    dim1, dim2 = 0, 1
    plt.rcParams.update({'font.size': 12})
    plt.title(title)

    if adjust_overlapping_text:
        texts = []
        x = []
        y = []

    for i in range(len(emb_pca)):
        if dict_level:
            if i in dict_level:
                plt.scatter(emb_pca[i, dim1], emb_pca[i, dim2],
                            c=colors[dict_level[i]] if color_dict else 'k')
                if adjust_overlapping_text:
                    texts.append(plt.text(emb_pca[i, dim1], emb_pca[i, dim2], str(dict_level[i]), fontsize=12))
                else:
                    plt.text(emb_pca[i, dim1], emb_pca[i, dim2], str(dict_level[i]), fontsize=12)
        else:
            plt.scatter(emb_pca[i, dim1], emb_pca[i, dim2], c='k')
            if adjust_overlapping_text:
                texts.append(plt.text(emb_pca[i, dim1], emb_pca[i, dim2], str(i), fontsize=12))
            else:
                plt.text(emb_pca[i, dim1], emb_pca[i, dim2], str(i), fontsize=12)

        if adjust_overlapping_text:
            x.append(emb_pca[i, dim1])
            y.append(emb_pca[i, dim2])

    if adjust_overlapping_text:
        print("Adjusting text")
        adjust_text(texts, x=x, y=y, autoalign='xy', force_points=0.5, only_move={'text': 'xy'})

    if save_path:
        plt.tight_layout()
        plt.savefig(save_path, bbox_inches='tight')
def get_right_and_left_coset():
    """Compute the right- and left-coset partitions of S4 for each of its subgroups.

    Each subgroup below is written in one-line notation packed into 4-digit
    integers (e.g. 2143 maps 1->2, 2->1, 3->4, 4->3).

    Returns:
        (right_coset_list, left_coset_list): for subgroup k, entry k is a numpy
        array of shape (num_cosets, |subgroup|, 4) holding the coset members as
        0-indexed permutations; cosets of a subgroup partition all 24 elements.
    """
    s_12 = [1234, 2143, 3412, 4321, 3124, 2314, 4132, 2431, 4213, 3241, 1423, 1342]
    s_8_1 = [1234, 4312, 3421, 2143, 4321, 3412, 2134, 1243]
    s_8_2 = [1234, 4123, 2341, 3412, 2143, 4321, 3214, 1432]
    # BUGFIX: this dihedral subgroup of order 8 was listed with only 7
    # elements; the missing double transposition 3412 is restored (every
    # order-8 subgroup of S4 contains the Klein four-group
    # {1234, 2143, 3412, 4321}).
    s_8_3 = [1234, 2413, 3142, 4321, 2143, 4231, 1324, 3412]
    s_6_1 = [1234, 2134, 3214, 1324, 3124, 2314]
    s_6_2 = [1234, 3214, 4231, 1243, 4213, 3241]
    s_6_3 = [1234, 1324, 1432, 1243, 1423, 1342]
    s_6_4 = [1234, 2134, 4231, 1432, 4132, 2431]
    s_4_1 = [1234, 2134, 1243, 2143]
    s_4_2 = [1234, 3214, 1432, 3412]
    s_4_3 = [1234, 4231, 1324, 4321]
    s_4_4 = [1234, 2143, 3412, 4321]
    s_4_5 = [1234, 4312, 2143, 3421]
    s_4_6 = [1234, 4123, 3412, 2341]
    s_4_7 = [1234, 3142, 4321, 2413]
    s_3_1 = [1234, 3124, 2314]
    s_3_2 = [1234, 4132, 2431]
    s_3_3 = [1234, 4213, 3241]
    s_3_4 = [1234, 1423, 1342]
    s_2_1 = [1234, 2134]
    s_2_2 = [1234, 3214]
    s_2_3 = [1234, 4231]
    s_2_4 = [1234, 1324]
    s_2_5 = [1234, 1432]
    s_2_6 = [1234, 1243]
    s_2_7 = [1234, 2143]
    s_2_8 = [1234, 3412]
    s_2_9 = [1234, 4321]

    subgroup_list = [s_12, s_8_1, s_8_2, s_8_3, s_6_1, s_6_2, s_6_3, s_6_4,
                     s_4_1, s_4_2, s_4_3, s_4_4, s_4_5, s_4_6, s_4_7,
                     s_3_1, s_3_2, s_3_3, s_3_4,
                     s_2_1, s_2_2, s_2_3, s_2_4, s_2_5, s_2_6, s_2_7, s_2_8, s_2_9]

    # Decode each packed integer into a 0-indexed permutation array.
    decoded_subgroups = []
    for subgroup in subgroup_list:
        decoded_subgroups.append(
            [np.array([int(c) for c in str(code)]) - 1 for code in subgroup])

    perms = list(itertools.permutations(range(4)))
    s_4_group = [np.array(perm) for perm in perms]

    right_coset_list = []
    left_coset_list = []
    for subgroup in decoded_subgroups:
        rights = set()
        lefts = set()
        for g in s_4_group:
            # g[s] composes on one side, s[g] on the other; sorting each coset
            # gives a canonical, order-independent key for deduplication.
            rights.add(tuple(sorted(tuple(g[s]) for s in subgroup)))
            lefts.add(tuple(sorted(tuple(s[g]) for s in subgroup)))
        right_coset_list.append(np.array(list(rights)))
        left_coset_list.append(np.array(list(lefts)))
    return right_coset_list, left_coset_list
def visualize_embedding_permutations(emb, right_coset_list, left_coset_list, title="", save_path=None, dict_level=None, adjust_overlapping_text=False, text=False):
    """Plot the 2D PCA of an S4 embedding once per coset partition.

    Builds a 14x4 grid of subplots: the 28 right-coset partitions first, then
    the 28 left-coset partitions, coloring each point by its coset membership.

    Parameters:
        emb: torch tensor (24, dim) of S4 element embeddings.
        right_coset_list / left_coset_list: output of get_right_and_left_coset().
        title: accepted for interface compatibility; subplot titles are
            generated per coset instead.
        save_path: if given, save the full grid figure there.
        dict_level: optional {element index -> label} used for annotations.
        adjust_overlapping_text: collect annotations and de-overlap them with
            adjustText at the end.
        text: only annotate points when True.
    """
    pca = PCA(n_components=2)
    emb_pca = pca.fit_transform(emb.detach().numpy())
    print("Explained Variance Ratio", pca.explained_variance_ratio_)

    # Map each permutation of [0, 1, 2, 3] to its row in emb_pca.
    all_permutations = list(itertools.permutations(range(4)))
    perm_to_index = {perm: idx for idx, perm in enumerate(all_permutations)}

    fig, axs = plt.subplots(14, 4, figsize=(10, 28))
    plt.subplots_adjust(hspace=0.2, wspace=0.2)

    texts = []

    # Plot right cosets, one subplot per partition.
    for idx, coset in enumerate(right_coset_list):
        ax = axs[idx // 4, idx % 4]
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(f"Right Coset {idx}")

        coset_cnt = len(coset)
        coset_colors = plt.cm.viridis(np.linspace(0, 1, coset_cnt))

        for pidx, perm_array in enumerate(coset):
            color = coset_colors[pidx]
            for perm in perm_array:
                perm_index = perm_to_index[tuple(perm)]
                ax.scatter(emb_pca[perm_index, 0], emb_pca[perm_index, 1], c=[color])

                if dict_level and perm_index in dict_level:
                    label = str(dict_level[perm_index])
                    if text:
                        if adjust_overlapping_text:
                            texts.append(ax.text(emb_pca[perm_index, 0], emb_pca[perm_index, 1], label, fontsize=12))
                        else:
                            ax.text(emb_pca[perm_index, 0], emb_pca[perm_index, 1], label, fontsize=8)

    # Plot left cosets, continuing in the same grid.
    for idx, coset in enumerate(left_coset_list):
        ax = axs[(idx + len(right_coset_list)) // 4, (idx + len(right_coset_list)) % 4]
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(f"Left Coset {idx}")

        coset_cnt = len(coset)
        coset_colors = plt.cm.viridis(np.linspace(0, 1, coset_cnt))

        for pidx, perm_array in enumerate(coset):
            color = coset_colors[pidx]
            for perm in perm_array:
                perm_index = perm_to_index[tuple(perm)]
                ax.scatter(emb_pca[perm_index, 0], emb_pca[perm_index, 1], c=[color])

                if dict_level and perm_index in dict_level:
                    label = str(dict_level[perm_index])
                    if text:
                        if adjust_overlapping_text:
                            texts.append(ax.text(emb_pca[perm_index, 0], emb_pca[perm_index, 1], label, fontsize=12))
                        else:
                            ax.text(emb_pca[perm_index, 0], emb_pca[perm_index, 1], label, fontsize=8)

    if adjust_overlapping_text:
        adjust_text(texts, autoalign='xy', force_points=0.5, only_move={'text': 'xy'})

    if save_path:
        plt.tight_layout()
        plt.savefig(save_path, bbox_inches='tight')


def silhouette_score(points, labels, penalty_weight=0):
    """Mean silhouette coefficient of a labelled point set, minus a cluster-count penalty.

    Parameters:
        points: sequence of coordinate vectors.
        labels: cluster label per point.
        penalty_weight: the returned score is reduced by
            penalty_weight * number_of_clusters (defaults to no penalty).

    Returns the adjusted average silhouette score; singleton clusters
    contribute a score of 0 by convention.
    """
    points = np.array(points)
    labels = np.array(labels)

    clusters = defaultdict(list)
    for point, label in zip(points, labels):
        clusters[label].append(point)

    silhouette_scores = []
    for point, label in zip(points, labels):
        same_cluster = np.array(clusters[label])

        # BUGFIX(order): check the singleton edge case before computing a(i);
        # previously np.mean ran on an empty list first, emitting a NaN and a
        # RuntimeWarning that were then discarded.
        if len(same_cluster) == 1:
            silhouette_scores.append(0)
            continue

        # a(i): mean distance within the same cluster, excluding the point itself.
        a = np.mean([np.linalg.norm(point - other) for other in same_cluster if not np.array_equal(point, other)])

        # b(i): mean distance to the nearest different cluster.
        b = float('inf')
        for other_label, other_cluster in clusters.items():
            if other_label != label:
                other_cluster = np.array(other_cluster)
                mean_distance = np.mean([np.linalg.norm(point - other) for other in other_cluster])
                b = min(b, mean_distance)

        silhouette_scores.append((b - a) / max(a, b))

    avg_silhouette_score = np.mean(silhouette_scores)

    # Penalize many clusters (inactive at the default penalty_weight of 0).
    num_clusters = len(clusters)
    return avg_silhouette_score - penalty_weight * num_clusters
def plot_single_coset(array_list, emb_pca, perm_to_index, title=None, save_path=None, dict_level=None, adjust_overlapping_text=False):
    """Scatter one coset partition onto a precomputed 2D PCA projection.

    Parameters:
        array_list: one coset family (array of shape (num_cosets, k, 4)); each
            coset is drawn in its own viridis color.
        emb_pca: (24, 2) PCA-projected embedding of the S4 elements.
        perm_to_index: {permutation tuple -> row index into emb_pca}.
        title: optional plot title.
        save_path: if given, the figure is saved there.
        dict_level: optional {element index -> label} used for annotations.
        adjust_overlapping_text: de-overlap annotations with adjustText.
    """
    if title:
        plt.title(title)

    if adjust_overlapping_text:
        texts = []

    array_count = len(array_list)
    coset_colors = plt.cm.viridis(np.linspace(0, 1, array_count))

    for array_idx, perm_array in enumerate(array_list):
        color = coset_colors[array_idx]
        for perm in perm_array:
            perm_tuple = tuple(perm)
            # Permutations outside the index map are silently skipped.
            if perm_tuple in perm_to_index:
                perm_index = perm_to_index[perm_tuple]
                plt.scatter(emb_pca[perm_index, 0], emb_pca[perm_index, 1], c=[color])

                if dict_level and perm_index in dict_level:
                    label = str(dict_level[perm_index])
                    if adjust_overlapping_text:
                        texts.append(plt.text(emb_pca[perm_index, 0], emb_pca[perm_index, 1], label, fontsize=8))
                    else:
                        plt.text(emb_pca[perm_index, 0], emb_pca[perm_index, 1], label, fontsize=8)

    if adjust_overlapping_text:
        adjust_text(texts, autoalign='xy', force_points=0.5, only_move={'text': 'xy'})

    if save_path:
        plt.savefig(save_path)

# NOTE(review): a first, dead definition of visualize_best_embedding used to
# follow here. It was immediately shadowed by the later definition of the same
# name (the only difference being its print statements were commented out), so
# it has been removed; runtime behavior is unchanged.
def visualize_best_embedding(emb, right_coset_list, left_coset_list, title="", save_name=None, dict_level=None, adjust_overlapping_text=False, penalty_weight=0, input_best=None):
    """Plot the coset partition that best clusters the PCA-projected embedding.

    Scores every right and left coset partition with silhouette_score on the
    2D PCA projection and plots the highest-scoring one via plot_single_coset.

    Parameters:
        emb: torch tensor (24, dim) of S4 element embeddings.
        right_coset_list / left_coset_list: output of get_right_and_left_coset().
        title: passed through to plot_single_coset.
        save_name: if given, the plot is saved as
            "{save_name}_ev_{total explained variance:.4f}.png".
        dict_level: optional {element index -> label} for annotations.
        adjust_overlapping_text: de-overlap annotations with adjustText.
        penalty_weight: cluster-count penalty forwarded to silhouette_score.
        input_best: optional explicit index into right+left coset families,
            bypassing the silhouette search.
    """
    pca = PCA(n_components=2)
    emb_pca = pca.fit_transform(emb.detach().numpy())
    print("Explained Variance Ratio", pca.explained_variance_ratio_)

    total_ev = np.sum(pca.explained_variance_ratio_)

    save_path = None
    if save_name:
        save_path = f"{save_name}_ev_{total_ev:.4f}.png"

    # Map each permutation of [0, 1, 2, 3] to its row in emb_pca.
    all_permutations = list(itertools.permutations(range(4)))
    perm_to_index = {perm: idx for idx, perm in enumerate(all_permutations)}

    full_coset_list = right_coset_list + left_coset_list

    # Score every partition: points labelled by which coset they fall in.
    coset_scores = []
    for coset in full_coset_list:
        X_arr = []
        label_arr = []
        for pidx, perm_array in enumerate(coset):
            for perm in perm_array:
                perm_index = perm_to_index[tuple(perm)]
                X_arr.append([emb_pca[perm_index, 0], emb_pca[perm_index, 1]])
                label_arr.append(pidx)
        coset_scores.append(silhouette_score(X_arr, label_arr, penalty_weight))

    best_coset = np.argmax(coset_scores)

    # BUGFIX: `if input_best:` silently ignored a requested index of 0; test
    # against None so every valid index (including 0) can be forced.
    if input_best is not None:
        best_coset = input_best

    coset_name = 'Right'
    if best_coset > len(right_coset_list) - 1:
        best_coset = best_coset - len(right_coset_list)
        coset_name = 'Left'
        coset = left_coset_list[best_coset]
    else:
        coset = right_coset_list[best_coset]

    print(f"Best coset found, {coset_name} {best_coset}")

    plot_single_coset(coset, emb_pca, perm_to_index, title=title, save_path=save_path, dict_level=dict_level, adjust_overlapping_text=adjust_overlapping_text)


def visualize_embedding_3d(emb, title="", save_path=None, dict_level=None, color_dict=True):
    """3D PCA scatter of an embedding matrix, annotating points by dict_level (or index).

    Parameters:
        emb: torch tensor of shape (num_tokens, dim).
        title: axes title.
        save_path: if given, save the figure; otherwise plt.show() it.
        dict_level: optional {token index -> label}; only indices present in it
            are plotted, annotated with their label.
        color_dict: color points with colors[dict_level[i]] when True,
            otherwise black.
    """
    pca = PCA(n_components=3)
    emb_pca = pca.fit_transform(emb.detach().numpy())
    print("Explained Variance Ratio:", pca.explained_variance_ratio_)

    plt.rcParams.update({'font.size': 12})
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    for i in range(len(emb_pca)):
        if dict_level:
            if i in dict_level:
                ax.scatter(emb_pca[i, 0], emb_pca[i, 1], emb_pca[i, 2],
                           c=colors[dict_level[i]] if color_dict else 'k')
                ax.text(emb_pca[i, 0], emb_pca[i, 1], emb_pca[i, 2],
                        str(dict_level[i]), fontsize=12)
        else:
            ax.scatter(emb_pca[i, 0], emb_pca[i, 1], emb_pca[i, 2], c='k')
            ax.text(emb_pca[i, 0], emb_pca[i, 1], emb_pca[i, 2], str(i), fontsize=12)

    ax.set_title(title)
    ax.set_xlabel('PC1')
    ax.set_ylabel('PC2')
    ax.set_zlabel('PC3')
    if save_path:
        plt.savefig(save_path)
    else:
        plt.show()