├── alignment
    └── __init__.py
├── neural
    ├── fits
    │   └── __init__.py
    ├── results
    │   ├── __init__.py
    │   └── parcels.csv
    ├── feher2023rethinking
    │   └── prompts.jsonl
    ├── combine.py
    └── extract.py
├── results
    ├── __init__.py
    ├── entropy_regression_coef.npy
    ├── all_data_random.csv
    ├── all_data_baseline.csv
    ├── tuckute2024driving
    │   ├── llama.pth
    │   ├── random.pth
    │   ├── centaur2000.pth
    │   ├── llama_sem.pth
    │   ├── random_sem.pth
    │   └── centaur2000_sem.pth
    ├── feher2023rethinking
    │   ├── tst_llama_alignment.csv
    │   ├── tst_centaur_alignment.csv
    │   ├── schaefer_tst_llama_alignment.csv
    │   ├── schaefer_tst_llama_alignment.pth
    │   ├── schaefer_tst_centaur_alignment.csv
    │   ├── schaefer_tst_centaur_alignment.pth
    │   ├── schaefer_tst_cognitive_alignment.csv
    │   ├── schaefer_tst_cognitive_alignment.pth
    │   ├── schaefer_tst_random_alignment.csv
    │   └── schaefer_tst_random_alignment.pth
    ├── all_data_unsloth-Meta-Llama-3.1-70B-bnb-4bit.csv
    ├── all_data_unsloth-Meta-Llama-3.1-8B-bnb-4bit.csv
    ├── all_data_marcelbinz-Llama-3.1-Centaur-70B-adapter.csv
    ├── all_data_marcelbinz-Llama-3.1-Centaur-8B-adapter.csv
    ├── custom_metrics_full_log_likelihoods_baselines.pth
    ├── custom_metrics_full_log_likelihoods_unsloth-Meta-Llama-3.1-70B-bnb-4bit.pth
    ├── custom_metrics_full_log_likelihoods_unsloth-Meta-Llama-3.1-8B-bnb-4bit.pth
    ├── custom_metrics_full_log_likelihoods_marcelbinz-Llama-3.1-Centaur-70B-adapter.pth
    ├── custom_metrics_full_log_likelihoods_marcelbinz-Llama-3.1-Centaur-8B-adapter.pth
    ├── custom_metrics_full_log_likelihoods_unsloth-Hermes-3-Llama-3.1-70B-bnb-4bit.pth
    ├── custom_metrics_full_log_likelihoods_unsloth-Reflection-Llama-3.1-70B-bnb-4bit.pth
    ├── custom_metrics_full_log_likelihoods_unsloth-llama-3-70b-Instruct-bnb-4bit.pth
    ├── custom_metrics_full_log_likelihoods_unsloth-Llama-3.1-Nemotron-70B-Instruct-bnb-4bit.pth
    ├── metabench
    │   ├── centaur-2000-results.json
    │   └── base-llama-3_1-70B-results.json
    └── CogBench
    │   ├── performance.csv
    │   └── behaviour.csv
├── openloop
    ├── results
    │   ├── __init__.py
    │   ├── baselines_openloop_centaur_horizon1.csv
    │   ├── baselines_openloop_centaur_horizon2.csv
    │   ├── baselines_openloop_centaur_horizon3.csv
    │   ├── baselines_openloop_centaur_horizon4.csv
    │   ├── baselines_openloop_centaur_twostep1.csv
    │   ├── baselines_openloop_centaur_twostep2.csv
    │   ├── baselines_openloop_human_horizon1.csv
    │   ├── baselines_openloop_human_horizon2.csv
    │   ├── baselines_openloop_human_horizon3.csv
    │   ├── baselines_openloop_human_horizon4.csv
    │   ├── baselines_openloop_human_twostep1.csv
    │   └── baselines_openloop_human_twostep2.csv
    ├── kool2016when
    │   ├── exp2.csv
    │   ├── simulation.csv
    │   └── simulate.py
    ├── kool2017cost
    │   ├── exp2.csv
    │   └── simulation.csv
    ├── wilson2014humans
    │   ├── exp1.csv
    │   ├── exp2.csv
    │   ├── exp3.csv
    │   ├── exp4.csv
    │   ├── exp5.csv
    │   ├── simulation0.csv
    │   ├── simulation1.csv
    │   ├── simulation2.csv
    │   ├── simulation3.csv
    │   ├── simulation4.csv
    │   └── simulate.py
    ├── baar2021latent
    │   ├── gameDat.csv
    │   ├── simulation_marcelbinz-Llama-3.1-Centaur-70B-adapter.csv
    │   ├── stats.py
    │   └── simulate.py
    ├── jansen2021dunningkruger
    │   ├── exp1.csv
    │   └── simulation.csv
    ├── trainers.py
    └── openloop.py
├── generalization
    ├── results
    │   ├── __init__.py
    │   ├── privileged.csv
    │   ├── generalization_full_log_likelihoods_unsloth-Meta-Llama-3.1-70B-bnb-4bit.pth
    │   ├── generalization_full_log_likelihoods_unsloth-Meta-Llama-3.1-8B-bnb-4bit.pth
    │   ├── generalization_full_log_likelihoods_marcelbinz-Llama-3.1-Centaur-70B-adapter.pth
    │   ├── generalization_full_log_likelihoods_marcelbinz-Llama-3.1-Centaur-8B-adapter.pth
    │   ├── additional_generalization_full_log_likelihoods_unsloth-Meta-Llama-3.1-70B-bnb-4bit.pth
    │   ├── additional_generalization_full_log_likelihoods_unsloth-Meta-Llama-3.1-8B-bnb-4bit.pth
    │   ├── additional_generalization_full_log_likelihoods_marcelbinz-Llama-3.1-Centaur-70B-adapter.pth
    │   └── additional_generalization_full_log_likelihoods_marcelbinz-Llama-3.1-Centaur-8B-adapter.pth
    ├── feher2020humans
    │   ├── exp1.csv
    │   └── prompts.jsonl
    ├── jansen2021logic
    │   ├── exp1.csv
    │   └── prompts.jsonl
    ├── dubois2022value
    │   ├── exp1.csv
    │   └── prompts.jsonl
    ├── additional_experiments
    │   ├── awad2018moral.jsonl
    │   ├── xu2021novelty.jsonl
    │   ├── akata2023repeatedgames.jsonl
    │   ├── singh2022representing.jsonl
    │   ├── demircan2024evaluatingcategory.jsonl
    │   └── demircan2024evaluatingreward.jsonl
    ├── trainers.py
    ├── privileged.py
    ├── generalization.py
    ├── generalization_custom_metrics.py
    └── additional_generalization.py
├── plots
    ├── test.png
    ├── figures
    │   ├── fig4.pdf
    │   ├── fig5.pdf
    │   ├── fig6.pdf
    │   ├── fig7.pdf
    │   ├── fig8.pdf
    │   ├── fig9.pdf
    │   ├── fig10.pdf
    │   ├── fig11.pdf
    │   ├── fig12.pdf
    │   ├── fig2_new.pdf
    │   ├── fig3_new.pdf
    │   ├── fig4_new.pdf
    │   ├── fig2_8b=True.pdf
    │   ├── fig3_8b=True.pdf
    │   ├── fig2_8b=False.pdf
    │   └── fig3_8b=False.pdf
    ├── fig9.py
    ├── tab1_new.py
    ├── fig6.py
    ├── fig3_new.py
    ├── fig12.py
    ├── fig11.py
    ├── tab1.py
    ├── fig10.py
    ├── fig5.py
    ├── fig3.py
    ├── fig14.py
    ├── fig7.py
    ├── fig8.py
    ├── fig4.py
    └── fig4_new.py
├── camera_ready
    ├── 1.jpg
    ├── 2.pdf
    ├── 2.png
    ├── 3.pdf
    ├── 3.png
    ├── bandit.png
    ├── logical.png
    ├── tstcover.png
    ├── figures
    │   ├── fig1.pdf
    │   ├── fig2.pdf
    │   ├── fig4.pdf
    │   ├── fig3_new.pdf
    │   └── fig5_new.pdf
    ├── data
    │   ├── cognitive_nlls.pth
    │   └── log_likelihoods_marcelbinz-Llama-3.1-Centaur-70B-adapter.pth
    ├── fig5.py
    ├── fig3.py
    └── fig4.py
├── extended_data
    ├── test.png
    ├── overview.pdf
    ├── overview.png
    ├── wordcloud.pdf
    ├── wordcloud.png
    ├── embeddings.npy
    ├── figures
    │   ├── fig1.jpg
    │   ├── fig2.jpg
    │   ├── fig3.jpg
    │   ├── fig4.jpg
    │   ├── fig5.jpg
    │   ├── fig6.jpg
    │   ├── fig7.png
    │   └── tab1.jpg
    ├── ed_fig1.py
    ├── ed_fig4.py
    ├── ed_fig3.py
    ├── ed_fig5.py
    └── ed_fig7.py
├── experiments.csv
├── ceiling
    ├── results
    │   ├── ceiling.csv
    │   ├── unsloth-Meta-Llama-3.1-70B-bnb-4bit.csv
    │   └── marcelbinz-Llama-3.1-Centaur-70B-adapter.csv
    ├── peterson2021using
    │   ├── exp1.csv
    │   └── prompts_zeroshot.jsonl
    ├── ruggeri2022globalizability
    │   ├── exp1.csv
    │   └── prompts_zeroshot.jsonl
    ├── trainers.py
    ├── ceiling.py
    ├── no_history.py
    └── models.py
├── contamination
    ├── results
    │   ├── As.pth
    │   └── Bs.pth
    └── contamination.py
├── .gitattributes
├── scripts
    ├── cluster_privileged.sh
    ├── cluster_llama_8b.sh
    ├── cluster_llama_70b.sh
    ├── cluster_8b.sh
    ├── cluster_70b.sh
    └── cluster_train.sh
├── metabench
    └── metabench.py
├── merge.py
├── run_minimal.py
├── test_adapter_custom_metrics.py
├── test_adapter.py
├── test.py
├── test_adapter_full_log_likelihoods.py
└── finetune.py
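The CSV, JSONL, and PTH files in this tree are tracked with Git LFS (see the /.gitattributes dump below), so the per-file listings that follow contain LFS pointer stubs rather than the actual data. A minimal sketch for materializing the data locally, assuming git and git-lfs are installed (the repository slug is taken from the raw.githubusercontent.com URLs below):

git clone https://github.com/marcelbinz/Llama-3.1-Centaur-70B.git
cd Llama-3.1-Centaur-70B
git lfs pull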
/alignment/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/neural/fits/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/results/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/neural/results/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/openloop/results/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/generalization/results/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/plots/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/plots/test.png
--------------------------------------------------------------------------------
/camera_ready/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/camera_ready/1.jpg
--------------------------------------------------------------------------------
/camera_ready/2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/camera_ready/2.pdf
--------------------------------------------------------------------------------
/camera_ready/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/camera_ready/2.png
--------------------------------------------------------------------------------
/camera_ready/3.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/camera_ready/3.pdf
--------------------------------------------------------------------------------
/camera_ready/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/camera_ready/3.png
--------------------------------------------------------------------------------
/extended_data/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/extended_data/test.png
--------------------------------------------------------------------------------
/plots/figures/fig4.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/plots/figures/fig4.pdf
--------------------------------------------------------------------------------
/plots/figures/fig5.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/plots/figures/fig5.pdf
--------------------------------------------------------------------------------
/plots/figures/fig6.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/plots/figures/fig6.pdf
--------------------------------------------------------------------------------
/plots/figures/fig7.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/plots/figures/fig7.pdf
--------------------------------------------------------------------------------
/plots/figures/fig8.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/plots/figures/fig8.pdf
--------------------------------------------------------------------------------
/plots/figures/fig9.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/plots/figures/fig9.pdf
--------------------------------------------------------------------------------
/camera_ready/bandit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/camera_ready/bandit.png
--------------------------------------------------------------------------------
/camera_ready/logical.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/camera_ready/logical.png
--------------------------------------------------------------------------------
/camera_ready/tstcover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/camera_ready/tstcover.png
--------------------------------------------------------------------------------
/plots/figures/fig10.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/plots/figures/fig10.pdf
--------------------------------------------------------------------------------
/plots/figures/fig11.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/plots/figures/fig11.pdf
--------------------------------------------------------------------------------
/plots/figures/fig12.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/plots/figures/fig12.pdf
--------------------------------------------------------------------------------
/extended_data/overview.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/extended_data/overview.pdf
--------------------------------------------------------------------------------
/extended_data/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/extended_data/overview.png
--------------------------------------------------------------------------------
/extended_data/wordcloud.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/extended_data/wordcloud.pdf
--------------------------------------------------------------------------------
/extended_data/wordcloud.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/extended_data/wordcloud.png
--------------------------------------------------------------------------------
/plots/figures/fig2_new.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/plots/figures/fig2_new.pdf
--------------------------------------------------------------------------------
/plots/figures/fig3_new.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/plots/figures/fig3_new.pdf
--------------------------------------------------------------------------------
/plots/figures/fig4_new.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/plots/figures/fig4_new.pdf
--------------------------------------------------------------------------------
/camera_ready/figures/fig1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/camera_ready/figures/fig1.pdf
--------------------------------------------------------------------------------
/camera_ready/figures/fig2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/camera_ready/figures/fig2.pdf
--------------------------------------------------------------------------------
/camera_ready/figures/fig4.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/camera_ready/figures/fig4.pdf
--------------------------------------------------------------------------------
/extended_data/embeddings.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/extended_data/embeddings.npy
--------------------------------------------------------------------------------
/extended_data/figures/fig1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/extended_data/figures/fig1.jpg
--------------------------------------------------------------------------------
/extended_data/figures/fig2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/extended_data/figures/fig2.jpg
--------------------------------------------------------------------------------
/extended_data/figures/fig3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/extended_data/figures/fig3.jpg
--------------------------------------------------------------------------------
/extended_data/figures/fig4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/extended_data/figures/fig4.jpg
--------------------------------------------------------------------------------
/extended_data/figures/fig5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/extended_data/figures/fig5.jpg
--------------------------------------------------------------------------------
/extended_data/figures/fig6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/extended_data/figures/fig6.jpg
--------------------------------------------------------------------------------
/extended_data/figures/fig7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/extended_data/figures/fig7.png
--------------------------------------------------------------------------------
/extended_data/figures/tab1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/extended_data/figures/tab1.jpg
--------------------------------------------------------------------------------
/plots/figures/fig2_8b=True.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/plots/figures/fig2_8b=True.pdf
--------------------------------------------------------------------------------
/plots/figures/fig3_8b=True.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/plots/figures/fig3_8b=True.pdf
--------------------------------------------------------------------------------
/plots/figures/fig2_8b=False.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/plots/figures/fig2_8b=False.pdf
--------------------------------------------------------------------------------
/plots/figures/fig3_8b=False.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/plots/figures/fig3_8b=False.pdf
--------------------------------------------------------------------------------
/camera_ready/figures/fig3_new.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/camera_ready/figures/fig3_new.pdf
--------------------------------------------------------------------------------
/camera_ready/figures/fig5_new.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/camera_ready/figures/fig5_new.pdf
--------------------------------------------------------------------------------
/results/entropy_regression_coef.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marcelbinz/Llama-3.1-Centaur-70B/HEAD/results/entropy_regression_coef.npy
--------------------------------------------------------------------------------
/experiments.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:8b6a93f8a18ddcd4c73dbdba1a91dcc2d99d8d6ef877c1c48974f90b3f76bdb8
3 | size 7338
4 | 
--------------------------------------------------------------------------------
/ceiling/results/ceiling.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:581206ccc925135d172fcd3ac2b179db9e4cbf5a7a88bb735f6381999d45c258
3 | size 172
4 | 
--------------------------------------------------------------------------------
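The two pointer stubs above show the Git LFS layout used for every data file in this dump: a spec-version line, the sha256 oid of the real blob, and its size in bytes. A small illustrative parser (a hypothetical helper, not part of the repository):

def parse_lfs_pointer(path):
    # LFS pointers are tiny "key value" text files
    with open(path) as f:
        fields = dict(line.strip().split(" ", 1) for line in f if " " in line)
    return fields["oid"], int(fields["size"])

print(parse_lfs_pointer("experiments.csv"))  # ('sha256:8b6a...', 7338) before `git lfs pull`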
/results/all_data_random.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:a8eca00f2a9bbf556b71df64122ea03367449c9dce11d6e93d930f1e76d4745a
3 | size 2654
4 | 
--------------------------------------------------------------------------------
/contamination/results/As.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:58a3ec3f4306dd9c54dae8e48e41be33f85e9612d7d7b9006e7ebe0c46980f1f
3 | size 1411
4 | 
--------------------------------------------------------------------------------
/contamination/results/Bs.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:d1bf46321e0ed4dd5a67c354e135b9ca205d75c60c969cb47d6220ba264f3604
3 | size 1411
4 | 
--------------------------------------------------------------------------------
/neural/results/parcels.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:648b13b2d05f7faa1663a799891d0b97a1f13e606b0b04f0cf1bed525237727d
3 | size 56284344
4 | 
--------------------------------------------------------------------------------
/results/all_data_baseline.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:45b302a4e809812dc3a283a00c3417e56a54f642f3efda4e52a3640d5525a776
3 | size 2109
4 | 
--------------------------------------------------------------------------------
/generalization/results/privileged.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:428b345be470b9140353b0bbd5fba3a2317a97eae4bb81023e437158d272f37a
3 | size 151
4 | 
--------------------------------------------------------------------------------
/openloop/kool2016when/exp2.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:09ec944770f91091e2a90588d204639c1272929a0ade66fd7fcf992734f23797
3 | size 5110462
4 | 
--------------------------------------------------------------------------------
/openloop/kool2017cost/exp2.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:c0207de3f0b5f149cbc886da38c2d166c2f8ae0b83b4a0f9fa7641f7e5f59d59
3 | size 3002825
4 | 
--------------------------------------------------------------------------------
/openloop/kool2017cost/simulation.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:aea859d4ab0fa9d0073157836d7cfad4a997ed48d057ddd396875367919bbee9
3 | size 97622
4 | 
--------------------------------------------------------------------------------
/openloop/wilson2014humans/exp1.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:adb1044c6b5e80932cb1987e93158eeb6e6fe4c7845f92c9095dda015ff09966
3 | size 3333522
4 | 
--------------------------------------------------------------------------------
/openloop/wilson2014humans/exp2.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:6229e080e9d358f684f601ac2b1e4d921aeb406f9c695e94ee79a4d36e7ebaa9
3 | size 360624
4 | 
--------------------------------------------------------------------------------
/openloop/wilson2014humans/exp3.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:2677fbe4f967041d46215e6b3c16f1034b9cb0cf365b8a9decbbd20242c7039b
3 | size 4017966
4 | 
--------------------------------------------------------------------------------
/openloop/wilson2014humans/exp4.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:65507c2c22cf5ab0275307fb3c3319265e9315f937247cc1c4c22f4c90cdb3db
3 | size 4895934
4 | 
--------------------------------------------------------------------------------
/openloop/wilson2014humans/exp5.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:878a026ee036d9bb74e86d27e74f9c030e94bd83f471dd587559be688473e879
3 | size 2409238
4 | 
--------------------------------------------------------------------------------
/results/tuckute2024driving/llama.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:2ab4f7e7eb4154a02bf3cbbe4f9922da6f2a2e6d63fc86cb3919bf77f6915a8d
3 | size 1490
4 | 
--------------------------------------------------------------------------------
/results/tuckute2024driving/random.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:a1311ae372f284f693f3665dd57b4ea391db9093b897794f4b9a1b534e4c5085
3 | size 1495
4 | 
--------------------------------------------------------------------------------
/camera_ready/data/cognitive_nlls.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:2ad1d45ec5c888ad69b21a2a5d329444af35097d9b27ec559970a88f0f92de8f
3 | size 304575
4 | 
--------------------------------------------------------------------------------
/ceiling/peterson2021using/exp1.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:a1590ea486c9e974dfc1daa4d94df26afef4eaafe2a0747e869280770f37c19a
3 | size 432757246
4 | 
--------------------------------------------------------------------------------
/generalization/feher2020humans/exp1.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:06acc5c6554602af8dd9a5ea0d1a6654228793c54b09cc016b0b73e4bbb58f67
3 | size 1187935
4 | 
--------------------------------------------------------------------------------
/generalization/jansen2021logic/exp1.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:7effc10d215015c875d701401dc7af5d90af1465c116225af46e13ef6c1299a7
3 | size 6228906
4 | 
--------------------------------------------------------------------------------
/openloop/baar2021latent/gameDat.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:2606ea4a25734d545860c08fe4a25b3647aecab803b4c86166dfd57f7ce2ee26
3 | size 1381719
4 | 
--------------------------------------------------------------------------------
/openloop/kool2016when/simulation.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:606eb033d7cad6dfd8401f1150ffc241385309b2ff7d825e550bbc5d3e541bab
3 | size 102462
4 | 
--------------------------------------------------------------------------------
/openloop/wilson2014humans/simulation0.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:91fd8a3c107516996031ccd7f3a4696fda7493b479217018f66632485fddc05a
3 | size 77372
4 | 
--------------------------------------------------------------------------------
/openloop/wilson2014humans/simulation1.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:f6a61b32f1b9946ab0304e59ea2e766d015cb63efa5e20f3b81333476411968a
3 | size 69
4 | 
--------------------------------------------------------------------------------
/results/tuckute2024driving/centaur2000.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:1ec1ea792cd08b5d5eee49a0e37ed8b4798067d618f1dc98eb4b7d4b27e5f611
3 | size 1520
4 | 
--------------------------------------------------------------------------------
/results/tuckute2024driving/llama_sem.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:1670911a1c8a6627442a0f065ee45cb91ea3a0902e2abc12724c8e28fbdb8f8b
3 | size 1510
4 | 
--------------------------------------------------------------------------------
/results/tuckute2024driving/random_sem.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:ec1a5e86de322e9e4dd86ac3093183cc72c264de78bf06fd5716f3bb5bea4cff
3 | size 1515
4 | 
--------------------------------------------------------------------------------
/generalization/dubois2022value/exp1.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:674c16a260aa079b6a778c3aedc464752e4d828af9adfa48541fb423292b6345
3 | size 77499443
4 | 
--------------------------------------------------------------------------------
/generalization/feher2020humans/prompts.jsonl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:2e182b3bcd35160d5316e4e63f60b321bf08932e37b69791bda5281841a00931
3 | size 826426
4 | 
--------------------------------------------------------------------------------
/neural/feher2023rethinking/prompts.jsonl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:e9f7e138e02745e9842ff8631a70e26f80ccd9be2646939464fb131a907f8a20
3 | size 2489314
4 | 
--------------------------------------------------------------------------------
/openloop/jansen2021dunningkruger/exp1.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:810bbaf12eb067dbca08febcafa1e48c4630fbe2c98ffa263c2d3e135042be24
3 | size 5988544
4 | 
--------------------------------------------------------------------------------
/openloop/wilson2014humans/simulation2.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:cf00cc724e59d07edd7492e3036a74f5ce6c3065d8cec5ae2e5f8e407c13133d
3 | size 133627
4 | 
--------------------------------------------------------------------------------
/openloop/wilson2014humans/simulation3.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:6d4335df6c31d32812c1394fe3626cacaf79ec81b398ebb916d379513f3d464f
3 | size 214419
4 | 
--------------------------------------------------------------------------------
/openloop/wilson2014humans/simulation4.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:52f681c50d001b8d38d686ee069509373241b35d111cc6b80c97d34bcec89e04
3 | size 117035
4 | 
--------------------------------------------------------------------------------
/results/tuckute2024driving/centaur2000_sem.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:53395340767c383790017efe90dadbbdf0d204688ff2283da39fb902b2c7f451
3 | size 1540
4 | 
--------------------------------------------------------------------------------
/ceiling/peterson2021using/prompts_zeroshot.jsonl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:e9617a79f7637c29335f11c505a34ad2a67034d27d357386800ba14194694209
3 | size 97114055
4 | 
--------------------------------------------------------------------------------
/ceiling/ruggeri2022globalizability/exp1.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:9d24bf090b0805828195080360adc1352495c4c8d292608c42d37f3d5971758f
3 | size 112909611
4 | 
--------------------------------------------------------------------------------
/generalization/dubois2022value/prompts.jsonl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:9b333c096e4836254fc2b2255b92c396f00f05c502281a576f9a843178838cd6
3 | size 142889957
4 | 
--------------------------------------------------------------------------------
/generalization/jansen2021logic/prompts.jsonl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:c8dca3ad24c9ab575e6e08acacb19ab7b4a24c2f86f0b624174b065f694f6d7e
3 | size 49114620
4 | 
--------------------------------------------------------------------------------
/openloop/jansen2021dunningkruger/simulation.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:04b50e4a72c555947ff20188c1d75893ac98decef79cd4c69c3ef958ded3fd9f
3 | size 1505118
4 | 
--------------------------------------------------------------------------------
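The exp*.csv files above appear to hold the human data for each study and the simulation*.csv files the corresponding model runs. After the LFS objects have been fetched they load as ordinary tables; a hypothetical spot-check with pandas:

import pandas as pd

df = pd.read_csv("openloop/wilson2014humans/exp1.csv")
print(df.shape, list(df.columns)[:5])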
/results/feher2023rethinking/tst_llama_alignment.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:eeda28b45a476904d298b4d91623a834b1eda04635aee7726821aef338e61259
3 | size 7596
4 | 
--------------------------------------------------------------------------------
/ceiling/results/unsloth-Meta-Llama-3.1-70B-bnb-4bit.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:ab8705cc91b38db88a9e57770781d9ea55d58a3a3e4f8a79b0398550889b224e
3 | size 174
4 | 
--------------------------------------------------------------------------------
/openloop/results/baselines_openloop_centaur_horizon1.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:faced0bc3777821f3d6c34da8e23ac9ab35a5825939faff3eed63c91778d47ed
3 | size 195
4 | 
--------------------------------------------------------------------------------
/openloop/results/baselines_openloop_centaur_horizon2.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:e1a34c826926d0605424e3e80c6d9e9f62770f6744fd26f0779198c464ec37be
3 | size 323
4 | 
--------------------------------------------------------------------------------
/openloop/results/baselines_openloop_centaur_horizon3.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:dcffb353eeebaa5cf2a44b0623d8f1821ffcca02821c30062f0b7f27ff88f01a
3 | size 481
4 | 
--------------------------------------------------------------------------------
/openloop/results/baselines_openloop_centaur_horizon4.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:48263ed596124600b991daa636e7806f26200e42280ae7802dd7f9854c449c01
3 | size 284
4 | 
--------------------------------------------------------------------------------
/openloop/results/baselines_openloop_centaur_twostep1.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:c0ad83c5c8b1e45d26cdcbc0b8c1699341e9ccd31805cd28128e5b411d4b2e26
3 | size 636
4 | 
--------------------------------------------------------------------------------
/openloop/results/baselines_openloop_centaur_twostep2.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:96fc1e516879b98fe0bf81b4c4a06edc699ab757c92103e39d37796a2c51930d
3 | size 393
4 | 
--------------------------------------------------------------------------------
/openloop/results/baselines_openloop_human_horizon1.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:7f86670b6978c859c71584f209f447ce98367f0545cd5df33b995027b7212f83
3 | size 195
4 | 
--------------------------------------------------------------------------------
/openloop/results/baselines_openloop_human_horizon2.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:2ad3170d1c5fa7981baec2b6cde94e1d253a6cba997ea5e4ee3f3fadcf417bf9
3 | size 324
4 | 
--------------------------------------------------------------------------------
/openloop/results/baselines_openloop_human_horizon3.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:2911b4db5ecb04da51ddaebc4263ef4cba159395c1b8a5b59c92c1dacb1d871d
3 | size 485
4 | 
--------------------------------------------------------------------------------
/openloop/results/baselines_openloop_human_horizon4.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:a2b3b5d4ffdad46be261e3a01b0876d9fe44d40c742a852d4b3161df66a972e0
3 | size 284
4 | 
--------------------------------------------------------------------------------
/openloop/results/baselines_openloop_human_twostep1.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:f5158c0cd840a39564b518c8b16267fc20484fe1bbea09e7e4840c67e69ffdf9
3 | size 634
4 | 
--------------------------------------------------------------------------------
/openloop/results/baselines_openloop_human_twostep2.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:03139fdacb71282842bf45b83c93d47cef7d8e73c80188faf15e997d2f5af341
3 | size 397
4 | 
--------------------------------------------------------------------------------
/results/all_data_unsloth-Meta-Llama-3.1-70B-bnb-4bit.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:92e4f0d8da485d85a4477ec0ec91cceeeb9c0021e6d0b175cee2860a33969011
3 | size 2972
4 | 
--------------------------------------------------------------------------------
/results/all_data_unsloth-Meta-Llama-3.1-8B-bnb-4bit.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:2ed63bd233a153fa0f4d60349190927c876bf1cc9b63d2406d0e4466ace52c09
3 | size 2971
4 | 
--------------------------------------------------------------------------------
/results/feher2023rethinking/tst_centaur_alignment.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:66b0fca5d3e4a26cb86b30c613294b5be94f0806079e23a7230e63c29c4768d4
3 | size 7523
4 | 
--------------------------------------------------------------------------------
/ceiling/results/marcelbinz-Llama-3.1-Centaur-70B-adapter.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:f749c52916928e5be545082b24cd4eec7c3962f517ce8716ae5a8da2b7200ddc
3 | size 180
4 | 
--------------------------------------------------------------------------------
/ceiling/ruggeri2022globalizability/prompts_zeroshot.jsonl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:50311d24914369a3bc0a8b57a301dd1ef84d50922f43446a92533c274730ce0d
3 | size 5869304
4 | 
--------------------------------------------------------------------------------
/generalization/additional_experiments/awad2018moral.jsonl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:5188ee3a16db90ac91643189096080be4ccb4211af045b1fb7130942c28e8de7
3 | size 52407938
4 | 
--------------------------------------------------------------------------------
/generalization/additional_experiments/xu2021novelty.jsonl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:94d1039d82f5d2bb3865f74ee307f5846ec7ecb04aca2fafb2ef32f11aa994b3
3 | size 243078
4 | 
--------------------------------------------------------------------------------
/results/all_data_marcelbinz-Llama-3.1-Centaur-70B-adapter.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:305c94eeddca68bf278387ccf3cc1986777d0f9d2b69ed306a761f825c1f13da
3 | size 2980
4 | 
--------------------------------------------------------------------------------
/results/all_data_marcelbinz-Llama-3.1-Centaur-8B-adapter.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:2b00e52ed24e2751351a4e71347fe84380c0136547d2a85aa45c01c8033b2ce6
3 | size 2979
4 | 
--------------------------------------------------------------------------------
/results/custom_metrics_full_log_likelihoods_baselines.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:7c237c07331878adb00956917c815416b02741f1d700a42472f0b997352b63a2
3 | size 7331448
4 | 
--------------------------------------------------------------------------------
/results/feher2023rethinking/schaefer_tst_llama_alignment.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:309e7df1feb4913380b2d7590c091b0025e71a72b76aed3d39320cf01cbf2da4
3 | size 706383
4 | 
--------------------------------------------------------------------------------
/results/feher2023rethinking/schaefer_tst_llama_alignment.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:869ccc8aa6e2b46aca59ed84d87c67b1b5403be5285ac780c52dca26e6664a19
3 | size 790196
4 | 
--------------------------------------------------------------------------------
/results/feher2023rethinking/schaefer_tst_centaur_alignment.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:2e4f3dd7d51a78f83946b1c8d003c44392ae3478bcabf91b57fdba041976cb26
3 | size 712586
4 | 
--------------------------------------------------------------------------------
/results/feher2023rethinking/schaefer_tst_centaur_alignment.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:b9a6df6d5d4fb3446b86b1c3c3b92874ec1e60f1bacaa4e4e27d7cae010f4777
3 | size 789820
4 | 
--------------------------------------------------------------------------------
/results/feher2023rethinking/schaefer_tst_cognitive_alignment.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:61de1fe35c305694dc9b59d28ec4427baa2bc3701c15da413179a2ec74c923dc
3 | size 147658
4 | 
--------------------------------------------------------------------------------
/results/feher2023rethinking/schaefer_tst_cognitive_alignment.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:1f0ab5b62076da3b8ec683e3befe1a77c7b285179ca7cae9165979cab533b959
3 | size 158852
4 | 
--------------------------------------------------------------------------------
/results/feher2023rethinking/schaefer_tst_random_alignment.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:8d8ed5717fddb198b9910fe78767739f27cbc561e8d1b2964f8c7c533f4854bd
3 | size 710310
4 | 
--------------------------------------------------------------------------------
/results/feher2023rethinking/schaefer_tst_random_alignment.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:282bb577b6bbaa95e5a2bffd3cdd761c9b4b4d05cd4a7053f0c83f90c63492bc
3 | size 791736
4 | 
--------------------------------------------------------------------------------
/generalization/additional_experiments/akata2023repeatedgames.jsonl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:b358101ef870cac6a6767872524b499a11d542fa2f5d723f47d048a891aa47ea
3 | size 885715
4 | 
--------------------------------------------------------------------------------
/generalization/additional_experiments/singh2022representing.jsonl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:33601eed9bade3ddb79d626d44d1491c8d7820b840b0e044550c0a7691700caa
3 | size 6637835
4 | 
--------------------------------------------------------------------------------
/generalization/additional_experiments/demircan2024evaluatingcategory.jsonl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:1ad7521657abb33680c05d40fe1b543326d791c51769d6d96a4f646ac7e0f7be
3 | size 1209802
4 | 
--------------------------------------------------------------------------------
/generalization/additional_experiments/demircan2024evaluatingreward.jsonl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:35ec20699152307cce2aa088f7ffbcb6fd0d6fce43a583aad878ee9c66bbe3fe
3 | size 792361
4 | 
--------------------------------------------------------------------------------
/camera_ready/data/log_likelihoods_marcelbinz-Llama-3.1-Centaur-70B-adapter.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:b3fd01fac583c640dae6247d3199ce4e562694328a2453295f3d4560c8636c0d
3 | size 40140
4 | 
--------------------------------------------------------------------------------
/openloop/baar2021latent/simulation_marcelbinz-Llama-3.1-Centaur-70B-adapter.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:b90a230e8c5980243d66d35024fec5c2d820b2fa5668342792a841bb64fb8209
3 | size 540010
4 | 
--------------------------------------------------------------------------------
/results/custom_metrics_full_log_likelihoods_unsloth-Meta-Llama-3.1-70B-bnb-4bit.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:fae90f0538ebbdbf4334bdb9e8242bc6cf835f1f143ba9fd1abfeaaf6889c235
3 | size 6291488
4 | 
--------------------------------------------------------------------------------
/results/custom_metrics_full_log_likelihoods_unsloth-Meta-Llama-3.1-8B-bnb-4bit.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:d235f02dc7cc21c6f72133e670d88b361d5df36eb21dbc8e78751d7874f7dd99
3 | size 6329820
4 | 
--------------------------------------------------------------------------------
/results/custom_metrics_full_log_likelihoods_marcelbinz-Llama-3.1-Centaur-70B-adapter.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:4173f43132c5fd543e402ab4213beece92f36860667761fdba877bffb7eee3fd
3 | size 6281204
4 | 
--------------------------------------------------------------------------------
/results/custom_metrics_full_log_likelihoods_marcelbinz-Llama-3.1-Centaur-8B-adapter.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:270917620da8ca9482eeedcb97ef025b3f82ac8a240d891e3f885ab69bad3e7f
3 | size 6303600
4 | 
--------------------------------------------------------------------------------
/results/custom_metrics_full_log_likelihoods_unsloth-Hermes-3-Llama-3.1-70B-bnb-4bit.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:c2a10acb56d830db8ebece31d1e4f0c2a7a3a7d5c25304a41a06a5008f94f986
3 | size 6302128
4 | 
--------------------------------------------------------------------------------
/results/custom_metrics_full_log_likelihoods_unsloth-Reflection-Llama-3.1-70B-bnb-4bit.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:4b0d71370131c432e1bb1f605f81e53c5b8c66df60ba580cfc051f12f5a0a4eb
3 | size 6310008
4 | 
--------------------------------------------------------------------------------
/results/custom_metrics_full_log_likelihoods_unsloth-llama-3-70b-Instruct-bnb-4bit.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:d47dfae847d4569031ca64b4cc83cf6ec301dcf4f265147a1b8b8b0f75f060ec
3 | size 6210024
4 | 
--------------------------------------------------------------------------------
/generalization/results/generalization_full_log_likelihoods_unsloth-Meta-Llama-3.1-70B-bnb-4bit.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:743f2ddfccb6da8b7280f1a9f453a39bfc40b8b5fc0020f733c24dc3ec84408b
3 | size 3426208
4 | 
--------------------------------------------------------------------------------
/generalization/results/generalization_full_log_likelihoods_unsloth-Meta-Llama-3.1-8B-bnb-4bit.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:1841f96498d4754fa6f06d6a1285a0844b1fc2c9b1dac04586596e65098943a0
3 | size 3424924
4 | 
--------------------------------------------------------------------------------
/results/custom_metrics_full_log_likelihoods_unsloth-Llama-3.1-Nemotron-70B-Instruct-bnb-4bit.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:a0b1fe348c4048946a9783b4458eda1d9b971650a1526563bec2161f751190ce
3 | size 6310676
4 | 
--------------------------------------------------------------------------------
/generalization/results/generalization_full_log_likelihoods_marcelbinz-Llama-3.1-Centaur-70B-adapter.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:f6e7d625d2ce1104a5c7dbfc55630680a15490f74ac10d6843ce81a008ce303e
3 | size 3417204
4 | 
--------------------------------------------------------------------------------
/generalization/results/generalization_full_log_likelihoods_marcelbinz-Llama-3.1-Centaur-8B-adapter.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:a185b393f0a449609a4a50b6dc392035e166a08b75c8454212f7a99fb04573e4
3 | size 3437360
4 | 
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.safetensors filter=lfs diff=lfs merge=lfs -text
2 | *.pth filter=lfs diff=lfs merge=lfs -text
3 | *.pt filter=lfs diff=lfs merge=lfs -text
4 | *.csv filter=lfs diff=lfs merge=lfs -text
5 | *.jsonl filter=lfs diff=lfs merge=lfs -text
6 | 
--------------------------------------------------------------------------------
/generalization/results/additional_generalization_full_log_likelihoods_unsloth-Meta-Llama-3.1-70B-bnb-4bit.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:0d192406b91a7af9794fb02f8e09db6396fd086ab76066c01e2d3c653fc8b82d
3 | size 1028876
4 | 
--------------------------------------------------------------------------------
/generalization/results/additional_generalization_full_log_likelihoods_unsloth-Meta-Llama-3.1-8B-bnb-4bit.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:91fa9ac9974b6e40c61c66ca566d64d111c87420a59330a38297ec33a740b942
3 | size 1046152
4 | 
--------------------------------------------------------------------------------
/generalization/results/additional_generalization_full_log_likelihoods_marcelbinz-Llama-3.1-Centaur-70B-adapter.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:13c4a5a60dc83a37c6769c02b4ad22eebaeee33fd13ee146dc726cd58a49ab60
3 | size 1049760
4 | 
--------------------------------------------------------------------------------
/generalization/results/additional_generalization_full_log_likelihoods_marcelbinz-Llama-3.1-Centaur-8B-adapter.pth:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:3c07067a7b3cdd94d4c60a8f2edc943378b4984e8d0e6a249cc2bc9d80cc5104
3 | size 1034396
4 | 
--------------------------------------------------------------------------------
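The *.pth result files above and under results/ are PyTorch serializations; plots/fig9.py below loads one of them with torch.load. A quick inspection sketch, assuming PyTorch is installed and the LFS objects have been pulled (the chosen file is just an example):

import torch

lls = torch.load("results/custom_metrics_full_log_likelihoods_baselines.pth")
print(type(lls))  # the exact structure varies by file, e.g. a tensor or a dict keyed by task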
| #!/bin/bash 2 | 3 | #SBATCH -J privileged 4 | #SBATCH -p cpu_p 5 | #SBATCH --qos cpu_normal 6 | #SBATCH --mem=32G 7 | #SBATCH -t 48:00:00 8 | #SBATCH --nice=1000 9 | #SBATCH --cpus-per-task=32 10 | 11 | source activate new_python 12 | 13 | cd ../generalization/ 14 | python privileged.py 15 | -------------------------------------------------------------------------------- /scripts/cluster_llama_8b.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J CENTaUR2 4 | #SBATCH -p gpu_p 5 | #SBATCH --qos gpu_normal 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --mem=80G 8 | #SBATCH -t 48:00:00 9 | #SBATCH --constraint=a100_80gb 10 | #SBATCH --nice=10000 11 | #SBATCH --cpus-per-task=20 12 | 13 | source activate unsloth_env2 14 | 15 | cd .. 16 | python test_adapter.py --model unsloth/Meta-Llama-3.1-8B-bnb-4bit 17 | python test_adapter_custom_metrics.py --model unsloth/Meta-Llama-3.1-8B-bnb-4bit 18 | 19 | cd generalization/ 20 | python generalization.py --model unsloth/Meta-Llama-3.1-8B-bnb-4bit 21 | python generalization_custom_metrics.py --model unsloth/Meta-Llama-3.1-8B-bnb-4bit 22 | 23 | cd .. 24 | python merge.py --model unsloth-Meta-Llama-3.1-8B-bnb-4bit 25 | -------------------------------------------------------------------------------- /scripts/cluster_llama_70b.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J CENTaUR2 4 | #SBATCH -p gpu_p 5 | #SBATCH --qos gpu_normal 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --mem=80G 8 | #SBATCH -t 48:00:00 9 | #SBATCH --constraint=a100_80gb 10 | #SBATCH --nice=10000 11 | #SBATCH --cpus-per-task=20 12 | 13 | source activate unsloth_env2 14 | 15 | cd .. 16 | python test_adapter.py --model unsloth/Meta-Llama-3.1-70B-bnb-4bit 17 | python test_adapter_custom_metrics.py --model unsloth/Meta-Llama-3.1-70B-bnb-4bit 18 | 19 | cd generalization/ 20 | python generalization.py --model unsloth/Meta-Llama-3.1-70B-bnb-4bit 21 | python generalization_custom_metrics.py --model unsloth/Meta-Llama-3.1-70B-bnb-4bit 22 | 23 | cd .. 
24 | python merge.py --model unsloth-Meta-Llama-3.1-70B-bnb-4bit 25 | -------------------------------------------------------------------------------- /metabench/metabench.py: -------------------------------------------------------------------------------- 1 | import json 2 | from lm_eval.evaluator import simple_evaluate 3 | 4 | if __name__ == "__main__": 5 | results = simple_evaluate(model="hf", model_args="pretrained=Centaur-3.1/1_finetuning/centaur2-final-llama/checkpoint-2000", tasks="metabench", num_fewshot=0, batch_size=8) 6 | 7 | with open("Centaur-3.1/4_benchmarks/metabench/centaur-2000-results.json", "w") as outfile: 8 | json.dump(results["results"], outfile) 9 | 10 | results = simple_evaluate(model="hf", model_args="pretrained=unsloth/Meta-Llama-3.1-70B-bnb-4bit", tasks="metabench", num_fewshot=0, batch_size=8) 11 | 12 | with open("Centaur-3.1/4_benchmarks/metabench/base-llama-3_1-70B-results.json", "w") as outfile: 13 | json.dump(results["results"], outfile) -------------------------------------------------------------------------------- /scripts/cluster_8b.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J CENTaUR2 4 | #SBATCH -p gpu_p 5 | #SBATCH --qos gpu_normal 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --mem=80G 8 | #SBATCH -t 48:00:00 9 | #SBATCH --constraint=a100_80gb 10 | #SBATCH --nice=10000 11 | #SBATCH --cpus-per-task=20 12 | 13 | source activate unsloth_env2 14 | 15 | cd .. 16 | python test_adapter.py --model marcelbinz/Llama-3.1-Centaur-8B-adapter 17 | python test_adapter_custom_metrics.py --model marcelbinz/Llama-3.1-Centaur-8B-adapter 18 | 19 | cd generalization/ 20 | python generalization.py --model marcelbinz/Llama-3.1-Centaur-8B-adapter 21 | python generalization_custom_metrics.py --model marcelbinz/Llama-3.1-Centaur-8B-adapter 22 | 23 | cd .. 24 | python merge.py --model marcelbinz-Llama-3.1-Centaur-8B-adapter 25 | -------------------------------------------------------------------------------- /scripts/cluster_70b.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J CENTaUR2 4 | #SBATCH -p gpu_p 5 | #SBATCH --qos gpu_normal 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --mem=80G 8 | #SBATCH -t 48:00:00 9 | #SBATCH --constraint=a100_80gb 10 | #SBATCH --nice=10000 11 | #SBATCH --cpus-per-task=20 12 | 13 | source activate unsloth_env2 14 | 15 | cd .. 16 | python test_adapter.py --model marcelbinz/Llama-3.1-Centaur-70B-adapter 17 | python test_adapter_custom_metrics.py --model marcelbinz/Llama-3.1-Centaur-70B-adapter 18 | 19 | cd generalization/ 20 | python generalization.py --model marcelbinz/Llama-3.1-Centaur-70B-adapter 21 | python generalization_custom_metrics.py --model marcelbinz/Llama-3.1-Centaur-70B-adapter 22 | 23 | cd .. 
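# merge held-out-participant and held-out-experiment results into results/all_data_<model>.csv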
24 | python merge.py --model marcelbinz-Llama-3.1-Centaur-70B-adapter 25 | -------------------------------------------------------------------------------- /plots/fig9.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | import scienceplots 4 | import seaborn as sns 5 | 6 | 7 | Bs = torch.load('../contamination/results/Bs.pth') 8 | 9 | log_Bs = torch.log(Bs) 10 | 11 | plt.figure(figsize=(7.08661, 1.9)) 12 | plt.style.use(['nature']) 13 | plt.scatter(torch.arange(len(log_Bs)), log_Bs, color='#69005f') 14 | plt.axhline(y=1, color='grey', linestyle='--', linewidth=1.0) 15 | 16 | plt.text(len(log_Bs), 1.1, 'potentially contaminated', fontsize=6, color='red', horizontalalignment='right') 17 | plt.text(len(log_Bs), 0.8, 'not contaminated', fontsize=6, color='green', horizontalalignment='right') 18 | plt.ylabel(r'$\log B$') 19 | plt.xlabel('Experiment') 20 | plt.ylim(-1.6, 1.1) 21 | plt.xlim(-0.5, len(log_Bs)+0.1) 22 | sns.despine() 23 | plt.tight_layout() 24 | plt.savefig('figures/fig9.pdf', bbox_inches='tight') 25 | plt.show() 26 | -------------------------------------------------------------------------------- /scripts/cluster_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -J CENTaUR2 4 | #SBATCH -p gpu_p 5 | #SBATCH --qos gpu_long 6 | #SBATCH --gres=gpu:1 7 | #SBATCH --mem=160G 8 | #SBATCH -t 96:00:00 9 | #SBATCH --constraint=a100_80gb 10 | #SBATCH --nice=10000 11 | #SBATCH --cpus-per-task=20 12 | 13 | source activate unsloth_env2 14 | 15 | python finetune.py \ 16 | --seed 100 \ 17 | --model_name_or_path "unsloth/Meta-Llama-3.1-70B-bnb-4bit" \ 18 | --max_seq_len 32768 \ 19 | --num_train_epochs 5 \ 20 | --log_level "info" \ 21 | --logging_strategy "steps" \ 22 | --logging_steps 1 \ 23 | --evaluation_strategy "steps" \ 24 | --eval_steps 999999 \ 25 | --save_strategy "steps" \ 26 | --save_steps 100 \ 27 | --learning_rate 5e-5 \ 28 | --optim "adamw_8bit" \ 29 | --lr_scheduler_type "cosine" \ 30 | --weight_decay 0.01 \ 31 | --warmup_steps 100 \ 32 | --output_dir "centaur2-final-llama" \ 33 | --per_device_train_batch_size 1 \ 34 | --per_device_eval_batch_size 1 \ 35 | --gradient_accumulation_steps 32 36 | -------------------------------------------------------------------------------- /results/metabench/centaur-2000-results.json: -------------------------------------------------------------------------------- 1 | {"metabench": {"acc,none": 0.6386027185594687, "acc_stderr,none": 0.01728429558982042, "alias": "metabench"}, "metabench_arc": {"alias": " - metabench_arc", "acc,none": 0.6827586206896552, "acc_stderr,none": 0.03878352372138622, "acc_norm,none": 0.7379310344827587, "acc_norm_stderr,none": 0.03664666337225258}, "metabench_gsm8k": {"alias": " - metabench_gsm8k", "exact_match,strict-match": 0.7383966244725738, "exact_match_stderr,strict-match": 0.028609516716994934, "exact_match,flexible-extract": 0.7426160337552743, "exact_match_stderr,flexible-extract": 0.02845882099146029}, "metabench_hellaswag": {"alias": " - metabench_hellaswag", "acc,none": 0.5161290322580645, "acc_stderr,none": 0.05210147439272567, "acc_norm,none": 0.7956989247311828, "acc_norm_stderr,none": 0.04203545939892303}, "metabench_mmlu": {"alias": " - metabench_mmlu", "acc,none": 0.84375, "acc_stderr,none": 0.03725247254245437}, "metabench_truthfulqa": {"alias": " - metabench_truthfulqa", "acc,none": 0.2857142857142857, "acc_stderr,none": 0.03652214232606523}, 
"metabench_winogrande": {"alias": " - metabench_winogrande", "acc,none": 0.8646616541353384, "acc_stderr,none": 0.029774643218898812}} -------------------------------------------------------------------------------- /results/metabench/base-llama-3_1-70B-results.json: -------------------------------------------------------------------------------- 1 | {"metabench": {"acc,none": 0.6224153850218629, "acc_stderr,none": 0.016368421562007154, "alias": "metabench"}, "metabench_arc": {"alias": " - metabench_arc", "acc,none": 0.6896551724137931, "acc_stderr,none": 0.03855289616378949, "acc_norm,none": 0.7448275862068966, "acc_norm_stderr,none": 0.03632984052707842}, "metabench_gsm8k": {"alias": " - metabench_gsm8k", "exact_match,strict-match": 0.759493670886076, "exact_match_stderr,strict-match": 0.027820781981149675, "exact_match,flexible-extract": 0.759493670886076, "exact_match_stderr,flexible-extract": 0.027820781981149675}, "metabench_hellaswag": {"alias": " - metabench_hellaswag", "acc,none": 0.5161290322580645, "acc_stderr,none": 0.05210147439272567, "acc_norm,none": 0.7849462365591398, "acc_norm_stderr,none": 0.042835078355547535}, "metabench_mmlu": {"alias": " - metabench_mmlu", "acc,none": 0.84375, "acc_stderr,none": 0.03725247254245437}, "metabench_truthfulqa": {"alias": " - metabench_truthfulqa", "acc,none": 0.17532467532467533, "acc_stderr,none": 0.030740951540481013}, "metabench_winogrande": {"alias": " - metabench_winogrande", "acc,none": 0.8872180451127819, "acc_stderr,none": 0.027532650801261535}} -------------------------------------------------------------------------------- /ceiling/trainers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.nn.functional as F 4 | import schedulefree 5 | from tqdm import tqdm 6 | 7 | class Trainer: 8 | def __init__(self, model, num_iter=1000): 9 | self.model = model 10 | self.num_iter = num_iter 11 | self.optimizer = schedulefree.AdamWScheduleFree(self.model.parameters(), lr=0.1) 12 | 13 | def fit_and_evaluate(self, train_df, eval_df): 14 | ### PREPROCESS DATA ### 15 | train_data, eval_data = self.model.preprocess_data(train_df, eval_df) 16 | 17 | ### FITTING ### 18 | self.model.train() 19 | self.optimizer.train() 20 | for _ in tqdm(range(self.num_iter)): 21 | self.optimizer.zero_grad() 22 | logits = self.model(train_data) 23 | loss = F.cross_entropy(logits.flatten(0, -2), train_data['choice'].flatten().long()) 24 | loss.backward() 25 | print(loss.item(), flush=True) 26 | self.optimizer.step() 27 | 28 | ### EVALUATION ### 29 | self.model.eval() 30 | self.optimizer.eval() 31 | logits = self.model(eval_data) 32 | return F.cross_entropy(logits.flatten(0, -2), eval_data['choice'].flatten().long()) 33 | -------------------------------------------------------------------------------- /openloop/trainers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.nn.functional as F 4 | import schedulefree 5 | from tqdm import tqdm 6 | 7 | class Trainer: 8 | def __init__(self, model, num_iter=1000): 9 | self.model = model 10 | self.num_iter = num_iter 11 | self.optimizer = schedulefree.AdamWScheduleFree(self.model.parameters(), lr=0.1) 12 | 13 | def fit_and_evaluate(self, train_df, eval_df): 14 | ### PREPROCESS DATA ### 15 | train_data, eval_data = self.model.preprocess_data(train_df, eval_df) 16 | 17 | ### FITTING ### 18 | self.model.train() 19 | 
self.optimizer.train() 20 | for _ in tqdm(range(self.num_iter)): 21 | self.optimizer.zero_grad() 22 | logits = self.model(train_data) 23 | loss = F.cross_entropy(logits.flatten(0, -2), train_data['choice'].flatten().long()) 24 | loss.backward() 25 | print(loss.item(), flush=True) 26 | self.optimizer.step() 27 | 28 | ### EVALUATION ### 29 | self.model.eval() 30 | self.optimizer.eval() 31 | logits = self.model(eval_data) 32 | return F.cross_entropy(logits.flatten(0, -2), eval_data['choice'].flatten().long()) 33 | -------------------------------------------------------------------------------- /generalization/trainers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.nn.functional as F 4 | import schedulefree 5 | from tqdm import tqdm 6 | 7 | class Trainer: 8 | def __init__(self, model, num_iter=1000): 9 | self.model = model 10 | self.num_iter = num_iter 11 | self.optimizer = schedulefree.AdamWScheduleFree(self.model.parameters(), lr=0.1) 12 | 13 | def fit_and_evaluate(self, train_df, eval_df): 14 | ### PREPROCESS DATA ### 15 | train_data, eval_data = self.model.preprocess_data(train_df, eval_df) 16 | 17 | ### FITTING ### 18 | self.model.train() 19 | self.optimizer.train() 20 | for _ in tqdm(range(self.num_iter)): 21 | self.optimizer.zero_grad() 22 | logits = self.model(train_data) 23 | loss = F.cross_entropy(logits.flatten(0, -2), train_data['choice'].flatten().long()) 24 | loss.backward() 25 | print(loss.item(), flush=True) 26 | self.optimizer.step() 27 | 28 | ### EVALUATION ### 29 | self.model.eval() 30 | self.optimizer.eval() 31 | logits = self.model(eval_data) 32 | return F.cross_entropy(logits.flatten(0, -2), eval_data['choice'].flatten().long()) 33 | -------------------------------------------------------------------------------- /neural/combine.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import glob 4 | from natsort import natsorted 5 | import pandas as pd 6 | import numpy as np 7 | 8 | ds = [] 9 | for i in [0, 1, 2, 3, 4]: 10 | d = {} 11 | files = natsorted(glob.glob('fits/model=Llama-3.1-Centaur-70B_layer=' + str(i) + '_roi=*')) 12 | for file in files: 13 | r2_scores = torch.load(file) 14 | d[file.removeprefix('fits/model=Llama-3.1-Centaur-70B_layer=' + str(i) + '_roi=').removesuffix('.pth')] = r2_scores[:, :, :, 1].mean() 15 | print(len(d)) 16 | ds.append(d) 17 | df = pd.DataFrame(ds) 18 | df.to_csv('../results/feher2023rethinking/tst_centaur_alignment.csv', index=False) 19 | twostep_centaur = df.values.mean(1) 20 | print(twostep_centaur) 21 | 22 | ds = [] 23 | for i in [0, 1, 2, 3, 4]: 24 | d = {} 25 | files = natsorted(glob.glob('fits/model=Meta-Llama-3.1-70B_layer=' + str(i) + '_roi=*')) 26 | for file in files: 27 | r2_scores = torch.load(file) 28 | d[file.removeprefix('fits/model=Meta-Llama-3.1-70B_layer=' + str(i) + '_roi=').removesuffix('.pth')] = r2_scores[:, :, :, 1].mean() 29 | print(len(d)) 30 | ds.append(d) 31 | df = pd.DataFrame(ds) 32 | df.to_csv('../results/feher2023rethinking/tst_llama_alignment.csv', index=False) 33 | twostep_llama = df.values.mean(1) 34 | print(twostep_llama) 35 | -------------------------------------------------------------------------------- /merge.py: -------------------------------------------------------------------------------- 1 | 2 | import pandas as pd 3 | import argparse 4 | import torch 5 | 6 | if __name__ == '__main__': 7 | parser = argparse.ArgumentParser() 8 | 
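# dash-separated model id matching the results filenames, e.g. --model marcelbinz-Llama-3.1-Centaur-8B-adapter (see scripts/cluster_8b.sh)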
parser.add_argument("--model", type=str, required=True) 9 | args = parser.parse_args() 10 | 11 | df_test = pd.read_csv('results/' + args.model + '.csv', index_col=0) 12 | df_test['custom_metric'] = False 13 | df_test['unseen'] = 'participants' 14 | 15 | df_test_custom_metrics = pd.read_csv('results/custom_metrics_' + args.model + '.csv', index_col=0) 16 | df_test_custom_metrics['custom_metric'] = True 17 | df_test_custom_metrics['unseen'] = 'participants' 18 | 19 | df_generalization = pd.read_csv('generalization/results/' + args.model + '.csv', index_col=0) 20 | df_generalization['custom_metric'] = False 21 | df_generalization['unseen'] = 'experiments' 22 | 23 | df_generalization_custom_metrics = pd.read_csv('generalization/results/custom_metrics_' + args.model + '.csv', index_col=0) 24 | df_generalization_custom_metrics['custom_metric'] = True 25 | df_generalization_custom_metrics['unseen'] = 'experiments' 26 | 27 | df = pd.concat([df_test, df_test_custom_metrics, df_generalization, df_generalization_custom_metrics]) 28 | print(df) 29 | 30 | df.to_csv('results/all_data_' + args.model.replace('/', '-') + '.csv') 31 | -------------------------------------------------------------------------------- /generalization/privileged.py: -------------------------------------------------------------------------------- 1 | from models import DunningKruger, RescorlaWagnerModel, DualSystemsModel 2 | from trainers import Trainer 3 | import pandas as pd 4 | import torch 5 | import os 6 | import numpy as np 7 | 8 | experiments = [ 9 | {'path': 'feher2020humans/exp1.csv', 'model': DualSystemsModel}, 10 | {'path': 'dubois2022value/exp1.csv', 'model': RescorlaWagnerModel}, 11 | {'path': 'jansen2021logic/exp1.csv', 'model': DunningKruger}, 12 | ] 13 | 14 | data = [] 15 | for index in range(len(experiments)): 16 | exp_name = experiments[index]['path'] 17 | print(exp_name) 18 | 19 | df = pd.read_csv(exp_name) 20 | 21 | num_splits = 10 22 | splits = np.array_split(df['participant'].unique(), num_splits) 23 | 24 | predictive_nll = 0 25 | for split in splits: 26 | train_df = df[~df['participant'].isin(split.tolist())] 27 | eval_df = df[df['participant'].isin(split.tolist())] 28 | 29 | trainer = Trainer(experiments[index]['model']()) 30 | predictive_nll += trainer.fit_and_evaluate(train_df, eval_df).item() 31 | 32 | predictive_nll = predictive_nll / num_splits 33 | 34 | print(predictive_nll) 35 | 36 | x = exp_name.split("/") 37 | data.append([x[-2], x[-1], predictive_nll]) 38 | 39 | df = pd.DataFrame(data, columns=['task', 'exp', 'nll']) 40 | print(df) 41 | df.to_csv('results/privileged.csv') 42 | -------------------------------------------------------------------------------- /run_minimal.py: -------------------------------------------------------------------------------- 1 | from unsloth import FastLanguageModel 2 | import transformers 3 | 4 | model, tokenizer = FastLanguageModel.from_pretrained( 5 | model_name = "marcelbinz/Llama-3.1-Centaur-70B-adapter", 6 | max_seq_length = 32768, 7 | dtype = None, 8 | load_in_4bit = True, 9 | ) 10 | FastLanguageModel.for_inference(model) 11 | 12 | pipe = transformers.pipeline( 13 | "text-generation", 14 | model=model, 15 | tokenizer=tokenizer, 16 | trust_remote_code=True, 17 | pad_token_id=0, 18 | do_sample=True, 19 | temperature=1.0, 20 | max_new_tokens=1, 21 | ) 22 | 23 | prompt = "You will be presented with triplets of objects, which will be assigned to the keys H, Y, and E.\n" \ 24 | "In each trial, please indicate which object you think is the odd one out by pressing the 
corresponding key.\n" \ 25 | "In other words, please choose the object that is the least similar to the other two.\n\n" \ 26 | "H: plant, Y: chainsaw, and E: periscope. You press <>.\n" \ 27 | "H: tostada, Y: leaf, and E: sail. You press <>.\n" \ 28 | "H: clock, Y: crystal, and E: grate. You press <>.\n" \ 29 | "H: barbed wire, Y: kale, and E: sweater. You press <>.\n" \ 30 | "H: raccoon, Y: toothbrush, and E: ice. You press <<" 31 | 32 | print(prompt) 33 | 34 | choice = pipe(prompt)[0]['generated_text'][len(prompt):] 35 | print(choice) 36 | -------------------------------------------------------------------------------- /openloop/baar2021latent/stats.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import seaborn as sns 3 | import matplotlib.pyplot as plt 4 | 5 | df = pd.read_csv('gameDat.csv') 6 | df['correct'] = df['CorrAns'] == df['GivenAns'] 7 | 8 | df_nat = df[df['Variant'] == 'nat'] 9 | df_inv = df[df['Variant'] == 'inv'] 10 | 11 | print(df_nat['correct'].mean()) 12 | print(df_inv['correct'].mean()) 13 | 14 | print(df_nat.groupby('subID')['correct'].mean().values) 15 | print(df_inv.groupby('subID')['correct'].mean().values) 16 | 17 | df = pd.read_csv('simulation_marcelbinz-Llama-3.1-Centaur-70B-adapter.csv') 18 | df['correct'] = df['CorrAns'] == df['GivenAns'] 19 | df = df.groupby('subID').filter(lambda x: ~((x.GivenAns == 'coop').all())) 20 | df = df.groupby('subID').filter(lambda x: ~((x.GivenAns == 'def').all())) 21 | 22 | df_nat = df[df['Variant'] == 'nat'] 23 | df_inv = df[df['Variant'] == 'inv'] 24 | 25 | print(df_nat['correct'].mean()) 26 | print(df_inv['correct'].mean()) 27 | 28 | print(df_nat.groupby('subID')['correct'].mean().values) 29 | print(df_inv.groupby('subID')['correct'].mean().values) 30 | 31 | sns.swarmplot(data=[df_nat.groupby('subID')['correct'].mean().values * 100, df_inv.groupby('subID')['correct'].mean().values * 100]) 32 | ax = sns.violinplot(data=[df_nat.groupby('subID')['correct'].mean().values * 100, df_inv.groupby('subID')['correct'].mean().values * 100]) 33 | plt.setp(ax.collections, alpha=.5) 34 | plt.xticks([0, 1], ['Human\nstrategies', 'Artificial\nstrategies']) 35 | plt.axhline(y=50, color='grey', linestyle='--', linewidth=1.0) 36 | plt.ylabel('Accuracy (%)') 37 | sns.despine() 38 | plt.show() 39 | -------------------------------------------------------------------------------- /ceiling/ceiling.py: -------------------------------------------------------------------------------- 1 | from models import NoiseCeiling, DunningKruger 2 | from trainers import Trainer 3 | import pandas as pd 4 | import torch 5 | from datasets import load_dataset 6 | import os 7 | import numpy as np 8 | 9 | 10 | experiments = [ 11 | {'prefix': '', 'path': 'ruggeri2022globalizability/exp1.csv', 'model': NoiseCeiling(UID='question')}, 12 | {'prefix': '', 'path': 'peterson2021using/exp1.csv', 'model': NoiseCeiling(UID='uniqueID')}, 13 | {'prefix': '../openloop/', 'path': 'jansen2021dunningkruger/exp1.csv', 'model': DunningKruger()}, 14 | 15 | ] 16 | 17 | data = [] 18 | for index in range(len(experiments)): 19 | exp_name = experiments[index]['path'] 20 | prefix = experiments[index]['prefix'] 21 | print(prefix + exp_name) 22 | 23 | df = pd.read_csv(prefix + exp_name) 24 | 25 | train_dataset = load_dataset("marcelbinz/Psych-101")['train'].filter(lambda example: example['experiment'].startswith(exp_name)) 26 | eval_dataset = load_dataset("marcelbinz/Psych-101-test")['test'].filter(lambda example: 
example['experiment'].startswith(exp_name)) 27 | 28 | train_participants = list(map(int, train_dataset['participant'])) 29 | eval_participants = list(map(int, eval_dataset['participant'])) 30 | 31 | train_df = df[df['participant'].isin(train_participants)] 32 | eval_df = df[df['participant'].isin(eval_participants)] 33 | 34 | trainer = Trainer(experiments[index]['model']) 35 | predictive_nll = trainer.fit_and_evaluate(train_df, eval_df).item() 36 | 37 | print(predictive_nll) 38 | 39 | x = exp_name.split("/") 40 | data.append([x[-2], x[-1], predictive_nll]) 41 | 42 | df = pd.DataFrame(data, columns=['task', 'exp', 'nll']) 43 | print(df) 44 | df.to_csv('results/ceiling.csv') 45 | -------------------------------------------------------------------------------- /plots/tab1_new.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import math 3 | from functools import reduce 4 | import torch 5 | import numpy as np 6 | 7 | df_exp = pd.read_csv('../experiments.csv', sep=';') 8 | print(df_exp) 9 | centaur_70b = torch.load('../results/custom_metrics_full_log_likelihoods_marcelbinz-Llama-3.1-Centaur-70B-adapter.pth') 10 | llama_70b = torch.load('../results/custom_metrics_full_log_likelihoods_unsloth-Meta-Llama-3.1-70B-bnb-4bit.pth') 11 | baselines_full = torch.load('../results/custom_metrics_full_log_likelihoods_baselines.pth') 12 | 13 | papers = [] 14 | ll_centaur = {} 15 | ll_llama = {} 16 | ll_baseline = {} 17 | 18 | for key in centaur_70b.keys(): 19 | print(key) 20 | exp_name = df_exp[df_exp['path'] == key + '/']['task_name'].item() 21 | if exp_name in ll_centaur.keys(): 22 | ll_centaur[exp_name] = np.concatenate((ll_centaur[exp_name], centaur_70b[key])) 23 | else: 24 | ll_centaur[exp_name] = centaur_70b[key] 25 | 26 | if exp_name in ll_llama.keys(): 27 | ll_llama[exp_name] = np.concatenate((ll_llama[exp_name], llama_70b[key])) 28 | else: 29 | ll_llama[exp_name] = llama_70b[key] 30 | 31 | if key in baselines_full.keys(): 32 | if exp_name in ll_baseline.keys(): 33 | ll_baseline[exp_name] = np.concatenate((ll_baseline[exp_name], baselines_full[key])) 34 | else: 35 | ll_baseline[exp_name] = baselines_full[key] 36 | 37 | print(papers) 38 | 39 | prompt = '\\begin{table}[]\n' 40 | prompt += '\\centering \n' 41 | prompt += '\\begin{tabular}{@{}lccc@{}} \n' 42 | prompt += '\\toprule \n' 43 | prompt += "\\textbf{Experiment} & \\textbf{Centaur} & \\textbf{Llama} & \\textbf{Cognitive model} \\\\ \n" 44 | prompt += '\\midrule \n' 45 | for key in ll_centaur.keys(): 46 | baseline_to_nan = ll_baseline[key].mean() if key in ll_baseline else np.nan 47 | prompt += str(key) + ' & ' + str(format(ll_centaur[key].mean(), '.4f')) + ' & ' + str(format(ll_llama[key].mean(), '.4f')) + ' & ' + str(format(baseline_to_nan, '.4f')) + ' \\\\ \n' 48 | prompt += '\\bottomrule \\\\ \n' 49 | prompt += '\\end{tabular} \n' 50 | prompt += '\\caption{Full negative log-likelihoods results on held-out participants.}\n' 51 | prompt += '\\label{tab:tab2} \n' 52 | prompt += '\\end{table}' 53 | print(prompt) 54 | -------------------------------------------------------------------------------- /neural/extract.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from unsloth import FastLanguageModel 3 | from datasets import load_dataset 4 | import sys 5 | import argparse 6 | 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--model", type=str, required=True) 11 | args = 
parser.parse_args() 12 | 13 | layers = [0, 10, 20, 30, 40] # change as needed 14 | 15 | dataset = load_dataset( 16 | 'json', 17 | data_files="feher2023rethinking/prompts.jsonl" 18 | ) 19 | 20 | model, tokenizer = FastLanguageModel.from_pretrained( 21 | model_name = args.model, 22 | max_seq_length = 32768, 23 | dtype = None, 24 | load_in_4bit = True, 25 | ) 26 | FastLanguageModel.for_inference(model) 27 | 28 | left_token = " <<" 29 | l_id = tokenizer(left_token).input_ids[-1] 30 | print(l_id) 31 | 32 | zero_token = "0" 33 | zero_id = tokenizer(zero_token).input_ids[-1] 34 | print(zero_id) 35 | 36 | one_token = "1" 37 | one_id = tokenizer(one_token).input_ids[-1] 38 | print(one_id) 39 | 40 | # loop over prompts 41 | for i, prompt in enumerate(dataset['train']): 42 | with torch.no_grad(): 43 | tokenized_prompt = tokenizer(prompt['text'], return_tensors='pt') 44 | tokenized_prompt.input_ids = tokenized_prompt.input_ids[:, :32768] # truncate to the 32768-token context window 45 | relevant_tokens_l = (tokenized_prompt.input_ids == l_id).squeeze() 46 | relevant_tokens_0 = (tokenized_prompt.input_ids == zero_id).squeeze() 47 | relevant_tokens_1 = (tokenized_prompt.input_ids == one_id).squeeze() 48 | 49 | relevant_tokens = torch.logical_or(relevant_tokens_l, torch.logical_or(relevant_tokens_0, relevant_tokens_1)) 50 | print(relevant_tokens.float().sum()) 51 | print(tokenized_prompt.input_ids.shape) 52 | result_model = model(tokenized_prompt.input_ids, output_hidden_states=True, return_dict=True) # forward pass; hidden states are collected below 53 | 54 | representations = [] 55 | for j, layer in enumerate(result_model['hidden_states']): 56 | if j in layers: 57 | representations.append(layer[0, relevant_tokens, :].to("cpu").half()) 58 | torch.save(representations, 'results/model=' + args.model.replace('/', '-') + '_participant=' + str(i) + '.pth') 59 | -------------------------------------------------------------------------------- /results/CogBench/performance.csv: -------------------------------------------------------------------------------- 1 | ,Agent,Probabilistic Reasoning,Horizon Task,Restless Bandit,Instrumental Learning,Two Step Task,BART,Agent_ci,Probabilistic Reasoning_ci,Horizon Task_ci,Restless Bandit_ci,Instrumental Learning_ci,Two Step Task_ci,BART_ci 2 | 0,GPT-4,0.9023843178003348,2.5616718729423607,1.5067004730172278,1.5798778527804593,1.5133905675103176,0.38500339097999314,GPT-4,0.1321855921175766,0.32601478190668737,0.1206790299903343,0.723622051657155,0.9762550781596039,0.1060708390872204 3 | 1,GPT-3.5,0.7486550712024966,1.9910866496392539,0.4809112512555779,0.7232401157184193,1.4910692384881277,0.014111693664084301,GPT-3.5,0.13626983634408002,0.3414222650768884,0.22915531981455028,0.7230613989438085,0.976321279257046,0.09402914156603344 4 | 2,GPT-3,0.7828931391829572,1.4893719913869494,0.12204043540423505,0.4130504660880744,1.100445980599848,0.07426246185147506,GPT-3,0.12787059500438888,0.19260381551125277,0.1400160081566184,1.2532543553807673,0.6909835373914166,0.0 5 | 3,Claude 1,0.8661789836861621,1.6367517006983248,0.9192817031169805,0.892633920599206,1.7366038577321938,0.022414377755171244,Claude 1,0.014800731479486887,0.10640961234045043,0.21571788662028693,0.36331575933748755,0.9755388385371356,0.05373127597250343 6 | 4,Claude 2,0.9837727962387371,1.7282884502769695,0.20324762694341997,1.1589216989409838,0.4862928178936293,0.04908443540183113,Claude 2,0.009612390842191017,0.10581158386752876,0.3014194551520603,0.5179519787171498,0.8238034253139583,0.009893450901144017 7 | 5,Claude 
3,0.9380825048806724,2.078410140789585,1.2350136125713265,1.1025393764063014,0.3526814583565615,0.006188538487622924,Claude 3,0.09733502897075721,0.3282577343256385,0.152514091177587,0.7402724170130903,0.9783947792322534,0.016398794848788242 8 | 36,Llama,0.5909309834901845,0.6093284001136539,0.009119937333717767,0.22741883638701077,0.13838778013156613,0.012716174974567654,Llama,0.038107189777225596,0.0761897452617687,0.05108973723541939,0.14882898134451691,0.1786061039386288,0.009247705168291653 9 | 45,Centaur,0.7615643281005462,2.0909918585453404,0.5409607725515271,0.24509803921568601,0.4040205151075923,0.024398101051203783,Centaur,0.03610349179665707,0.10651043893190425,0.13297999521616086,0.23236684298167481,0.30931376526719956,0.012194726359166009 10 | 47,Human,1.0,1.0,1.0,1.0,1.0,1.0,Human,0.013262564207610785,0.11999477579135502,0.08992596737914589,0.2342977192238481,1.5461889379959568,0.19695222552040062 11 | -------------------------------------------------------------------------------- /contamination/contamination.py: -------------------------------------------------------------------------------- 1 | from unsloth import FastLanguageModel 2 | import transformers 3 | from datasets import load_dataset 4 | import numpy as np 5 | from tqdm import tqdm 6 | import torch 7 | from scipy.optimize import minimize 8 | 9 | # code from https://colab.research.google.com/drive/1GDbmEMmCVEOwhYk6-1AothdXeAlnqZ_j?usp=copy#scrollTo=L_mG7OJuumrZ 10 | def get_logp(model, tokenizer, text): 11 | # get the log-probability of each token in text 12 | logp = [] 13 | input_ids = tokenizer.encode(text) 14 | logits = model(torch.tensor(input_ids).unsqueeze(0)).logits 15 | logps = torch.nn.functional.log_softmax(logits[0, -len(input_ids):], dim=-1) 16 | for i in range(len(input_ids)-1): 17 | logp.append(logps[i, input_ids[i+1]].item()) 18 | return [0]+logp 19 | 20 | def fit_model(logp): 21 | # fit an exponential model, y = -A * (1 - exp(-B * x)), to a series of log-probabilities (the function works on the cumulative log-probability) 22 | logp = np.array([0]+np.cumsum(np.array(logp)[1:]).tolist()) # compute the cumulative log-probability 23 | def loss(logp, params): 24 | # compute an MSE loss 25 | n = len(logp) 26 | A, B = params 27 | x = np.arange(len(logp))/n # normalize x 28 | y = -A*(1-np.exp(-B*x)) 29 | l = ((logp/n-y)**2).mean() 30 | return l 31 | # fit the model 32 | A, B = minimize( 33 | lambda params: loss(logp, params), 34 | np.array([1, 1]), # arbitrary initialization 35 | method='BFGS', 36 | tol=10**-20).x 37 | return A, B 38 | 39 | model, tokenizer = FastLanguageModel.from_pretrained( 40 | model_name = "unsloth/Meta-Llama-3.1-70B-bnb-4bit", 41 | max_seq_length = 32768, 42 | dtype = None, 43 | load_in_4bit = True, 44 | ) 45 | FastLanguageModel.for_inference(model) 46 | 47 | dataset = load_dataset("marcelbinz/Psych-101") 48 | unique_experiment_names = dataset.unique('experiment')['train'] 49 | 50 | As = [] 51 | Bs = [] 52 | with torch.no_grad(): 53 | for experiment_name in unique_experiment_names: 54 | print(experiment_name) 55 | subset = dataset.filter(lambda example: example["experiment"].startswith(experiment_name)) 56 | text = subset['train'][0]['text'].split('<<')[0] 57 | logp = get_logp(model, tokenizer, text) 58 | A, B = fit_model(logp) 59 | print('A:', A) 60 | print('B:', B) 61 | As.append(A) 62 | Bs.append(B) 63 | 64 | torch.save(torch.Tensor(As), 'results/As.pth') 65 | torch.save(torch.Tensor(Bs), 'results/Bs.pth') 66 | -------------------------------------------------------------------------------- /extended_data/ed_fig1.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | import scienceplots 4 | import seaborn as sns 5 | import matplotlib.gridspec as gridspec 6 | from datasets import load_dataset 7 | import numpy as np 8 | from transformers import pipeline 9 | from sklearn.manifold import MDS 10 | 11 | gs = gridspec.GridSpec(2, 2, width_ratios=[0.5, 0.5]) 12 | 13 | Bs = torch.load('../contamination/results/Bs.pth') 14 | 15 | log_Bs = torch.log(Bs) 16 | 17 | fig = plt.figure(figsize=(7.20472, 4)) 18 | plt.style.use(['nature']) 19 | 20 | ax = fig.add_subplot(gs[0, 0]) 21 | 22 | image = plt.imread('overview.png') 23 | cax = ax.imshow(image) 24 | ax.axis('off') 25 | 26 | ax = fig.add_subplot(gs[0, 1]) 27 | 28 | image = plt.imread('wordcloud.png') 29 | cax = ax.imshow(image) 30 | ax.axis('off') 31 | 32 | ax1 = fig.add_subplot(gs[1, 0]) 33 | 34 | ax1.scatter(torch.arange(len(log_Bs)), log_Bs, color='#69005f') 35 | ax1.axhline(y=1, color='grey', linestyle='--', linewidth=1.0) 36 | 37 | ax1.text(len(log_Bs), 1.07, 'potentially contaminated', fontsize=6, color='red', horizontalalignment='right') 38 | ax1.text(len(log_Bs), 0.83, 'not contaminated', fontsize=6, color='green', horizontalalignment='right') 39 | ax1.set_ylabel(r'$\log B$') 40 | ax1.set_xlabel('Experiment') 41 | ax1.set_ylim(-1.6, 1.1) 42 | ax1.set_xlim(-0.5, len(log_Bs)+0.1) 43 | 44 | ax2 = fig.add_subplot(gs[1, 1]) 45 | 46 | eval_experiments_names = [ 47 | 'Modified problem structure (Figure 3b)', 48 | 'Modified cover story (Figure 3a)', 49 | 'Entirely novel domain (Figure 3c)', 50 | 'Moral decision-making', 51 | 'Naturalistic category learning', 52 | 'Naturalistic reward learning', 53 | 'Economic games', 54 | 'Behavioral propensities', 55 | 'Deep sequential decision task', 56 | ] 57 | 58 | 59 | embeddings = np.load('embeddings.npy') 60 | colors = 76 * ['#69005f'] 61 | colors.extend(['C0', 'C1', 'C2', '#ff506e', '#ff506e', '#ff506e', '#ff506e', '#ff506e', '#ff506e']) 62 | 63 | ax2.scatter(embeddings[:, 0], embeddings[:, 1], c=colors, s=25, alpha=0.8) 64 | ax2.set_xlabel('Embedding dimension 1') 65 | ax2.set_ylabel('Embedding dimension 2') 66 | 67 | for i in range(1, len(eval_experiments_names) + 1): 68 | plt.annotate(eval_experiments_names[-i], (0.5 + embeddings[-i, 0], embeddings[-i, 1]-0.2), size=5) 69 | 70 | fig.text(0.015, 0.955, 'a', fontsize=8, weight='bold') 71 | fig.text(0.015, 0.52, 'c', fontsize=8, weight='bold') 72 | fig.text(0.478, 0.955, 'b', fontsize=8, weight='bold') 73 | fig.text(0.478, 0.52, 'd', fontsize=8, weight='bold') 74 | 75 | sns.despine() 76 | plt.tight_layout() 77 | plt.savefig('figures/fig1.jpg', bbox_inches='tight', dpi=300) 78 | plt.show() 79 | -------------------------------------------------------------------------------- /ceiling/no_history.py: -------------------------------------------------------------------------------- 1 | from transformers import TrainingArguments 2 | from trl import SFTTrainer, DataCollatorForCompletionOnlyLM 3 | from unsloth import FastLanguageModel 4 | from datasets import load_dataset 5 | import pandas as pd 6 | import argparse 7 | import torch 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--model", type=str, required=True) 12 | args = parser.parse_args() 13 | 14 | task_names = [ 15 | "peterson2021using/prompts_zeroshot.jsonl", 16 | "ruggeri2022globalizability/prompts_zeroshot.jsonl", 17 | ] 18 | 19 | model, tokenizer = FastLanguageModel.from_pretrained( 
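# loads the base weights 4-bit quantized; dtype = None lets unsloth pick the compute dtype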
20 | model_name = args.model, 21 | max_seq_length = 32768, 22 | dtype = None, 23 | load_in_4bit = True, 24 | ) 25 | l_id = tokenizer(" <<").input_ids[1:] 26 | r_id = tokenizer(">>").input_ids[1:] 27 | collator = DataCollatorForCompletionOnlyLM(response_template=l_id, instruction_template=r_id, tokenizer=tokenizer) 28 | is_quantized = model.is_quantized 29 | 30 | data = [] 31 | with torch.no_grad(): 32 | for i, task_name in enumerate(task_names): 33 | dataset = load_dataset( 34 | 'json', 35 | data_files={ 36 | 'test': [task_name], 37 | } 38 | ) 39 | 40 | model.is_quantized = False 41 | training_args = TrainingArguments( 42 | output_dir="eval_"+str(i), 43 | per_device_eval_batch_size=1, 44 | report_to="none" 45 | ) 46 | trainer = SFTTrainer( 47 | model=model, 48 | tokenizer=tokenizer, 49 | args=training_args, 50 | train_dataset=dataset['test'], 51 | eval_dataset=dataset['test'], 52 | dataset_text_field="text", 53 | max_seq_length=32768, 54 | data_collator=collator, 55 | ) 56 | model.is_quantized = is_quantized 57 | result = trainer.evaluate() 58 | 59 | print(task_name, flush=True) 60 | print(result, flush=True) 61 | data.append([task_name, result['eval_loss']]) 62 | df = pd.DataFrame(data, columns=['task', str(args.model)]) 63 | print(df, flush=True) 64 | df.to_csv('results/' + args.model.replace('/', '-') + '.csv') 65 | -------------------------------------------------------------------------------- /generalization/generalization.py: -------------------------------------------------------------------------------- 1 | from transformers import TrainingArguments 2 | from trl import SFTTrainer, DataCollatorForCompletionOnlyLM 3 | from unsloth import FastLanguageModel 4 | from datasets import load_dataset 5 | import pandas as pd 6 | import argparse 7 | import torch 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--model", type=str, required=True) 12 | args = parser.parse_args() 13 | 14 | task_names = [ 15 | "feher2020humans/prompts.jsonl", 16 | "dubois2022value/prompts.jsonl", 17 | ] 18 | 19 | model, tokenizer = FastLanguageModel.from_pretrained( 20 | model_name = args.model, 21 | max_seq_length = 32768, 22 | dtype = None, 23 | load_in_4bit = True, 24 | ) 25 | l_id = tokenizer(" <<").input_ids[1:] 26 | r_id = tokenizer(">>").input_ids[1:] 27 | collator = DataCollatorForCompletionOnlyLM(response_template=l_id, instruction_template=r_id, tokenizer=tokenizer) 28 | is_quantized = model.is_quantized 29 | 30 | data = [] 31 | with torch.no_grad(): 32 | for i, task_name in enumerate(task_names): 33 | dataset = load_dataset( 34 | 'json', 35 | data_files={ 36 | 'test': [task_name], 37 | } 38 | ) 39 | 40 | model.is_quantized = False 41 | training_args = TrainingArguments( 42 | output_dir="eval_"+str(i), 43 | per_device_eval_batch_size=1, 44 | report_to="none" 45 | ) 46 | trainer = SFTTrainer( 47 | model=model, 48 | tokenizer=tokenizer, 49 | args=training_args, 50 | train_dataset=dataset['test'], 51 | eval_dataset=dataset['test'], 52 | dataset_text_field="text", 53 | max_seq_length=32768, 54 | data_collator=collator, 55 | ) 56 | model.is_quantized = is_quantized 57 | result = trainer.evaluate() 58 | 59 | print(task_name, flush=True) 60 | print(result, flush=True) 61 | data.append([task_name.removesuffix('/prompts.jsonl'), result['eval_loss']]) 62 | df = pd.DataFrame(data, columns=['task', str(args.model)]) 63 | print(df, flush=True) 64 | df.to_csv('results/' + args.model.replace('/', '-') + '.csv') 65 | 
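# A minimal sketch of the evaluation masking used above: DataCollatorForCompletionOnlyLM
# sets labels to -100 for every token outside the " <<" ... ">>" responses, so
# trainer.evaluate() scores only the human choices. The tokenizer checkpoint below is
# an assumption made purely for illustration.
#
#   from transformers import AutoTokenizer
#   from trl import DataCollatorForCompletionOnlyLM
#
#   tok = AutoTokenizer.from_pretrained("unsloth/Meta-Llama-3.1-8B-bnb-4bit")
#   collator = DataCollatorForCompletionOnlyLM(
#       response_template=tok(" <<").input_ids[1:],    # loss switches on after " <<"
#       instruction_template=tok(">>").input_ids[1:],  # and off again at ">>"
#       tokenizer=tok,
#   )
#   batch = collator([tok("You press <<E>>. You press <<H>>.").input_ids])
#   print(batch["labels"])  # -100 everywhere except the choice tokens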
-------------------------------------------------------------------------------- /plots/fig6.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import scienceplots 6 | import matplotlib.gridspec as gridspec 7 | import seaborn as sns 8 | from matplotlib.lines import Line2D 9 | 10 | def get_performance(file_name): 11 | with open(file_name) as f: 12 | d = json.load(f) 13 | k = d.keys() 14 | v = [d[key]['exact_match,strict-match'] if (key == 'metabench_gsm8k') else d[key]['acc,none'] for key in k] 15 | verr = [d[key]['exact_match_stderr,strict-match'] if (key == 'metabench_gsm8k') else d[key]['acc_stderr,none'] for key in k] 16 | return k, v, verr 17 | 18 | k, v_llama, verr_llama = get_performance('../results/metabench/base-llama-3_1-70B-results.json') 19 | _, v_centaur, verr_centaur = get_performance('../results/metabench/centaur-2000-results.json') 20 | 21 | df = pd.DataFrame( 22 | {'Task': k, 23 | 'Llama': v_llama, 24 | 'Centaur': v_centaur 25 | }).set_index('Task') 26 | 27 | 28 | df_ci = pd.DataFrame( 29 | {'Task': k, 30 | 'Llama': verr_llama, 31 | 'Centaur': verr_centaur 32 | }).set_index('Task') 33 | 34 | df.index = df.index.str.replace('metabench_arc', 'ARC') 35 | df.index = df.index.str.replace('metabench_gsm8k', 'GSM8K') 36 | df.index = df.index.str.replace('metabench_hellaswag', 'HellaSwag') 37 | df.index = df.index.str.replace('metabench_mmlu', 'MMLU') 38 | df.index = df.index.str.replace('metabench_truthfulqa', 'TruthfulQA') 39 | df.index = df.index.str.replace('metabench_winogrande', 'Winogrande') 40 | df.index = df.index.str.replace('metabench', 'Mean') 41 | df_ci.index = df_ci.index.str.replace('metabench_arc', 'ARC') 42 | df_ci.index = df_ci.index.str.replace('metabench_gsm8k', 'GSM8K') 43 | df_ci.index = df_ci.index.str.replace('metabench_hellaswag', 'HellaSwag') 44 | df_ci.index = df_ci.index.str.replace('metabench_mmlu', 'MMLU') 45 | df_ci.index = df_ci.index.str.replace('metabench_truthfulqa', 'TruthfulQA') 46 | df_ci.index = df_ci.index.str.replace('metabench_winogrande', 'Winogrande') 47 | df_ci.index = df_ci.index.str.replace('metabench', 'Mean') 48 | df = df[['Centaur', 'Llama']] 49 | df_ci = df_ci[['Centaur', 'Llama']] 50 | 51 | plt.style.use(['nature']) 52 | plt.rcParams["figure.figsize"] = (7.08661, 3) 53 | 54 | df.plot(kind='bar', yerr=df_ci, legend=False, color=['#69005f', '#ff506e'], alpha=0.8) 55 | color_1 = '#69005f' 56 | color_2 = '#ff506e' 57 | custom_lines_r2 = [Line2D([0], [0], color=color_1, alpha=0.8, marker="o", linestyle='None', markersize=5), Line2D([0], [0], color=color_2, alpha=0.8, marker="o", linestyle='None', markersize=5)] 58 | plt.legend(custom_lines_r2, ['Centaur', 'Llama'], frameon=False, ncols=3, bbox_to_anchor=(0.5, 1.3), loc='upper center') 59 | plt.ylabel('Performance') 60 | plt.xlabel('') 61 | plt.ylim(-0.0, 1.1) 62 | 63 | plt.tight_layout() 64 | sns.despine() 65 | plt.savefig('figures/fig6.pdf', bbox_inches='tight') 66 | plt.show() 67 | -------------------------------------------------------------------------------- /generalization/generalization_custom_metrics.py: -------------------------------------------------------------------------------- 1 | from transformers import TrainingArguments 2 | from trl import SFTTrainer, DataCollatorForCompletionOnlyLM 3 | from unsloth import FastLanguageModel 4 | from datasets import load_dataset 5 | import pandas as pd 6 | import argparse 7 | import torch 8 | 9 | def 
preprocess_logits_for_metrics(logits, labels): 10 | with torch.no_grad(): 11 | logits = logits.cpu() 12 | labels = labels.cpu() 13 | labels = torch.cat((labels[0, 1:], -100 * torch.ones(1).long()), 0) 14 | logits = logits[0] 15 | ce = torch.nn.functional.cross_entropy(logits, labels, reduction='none') 16 | total_loss = [] 17 | item_loss = 0 18 | item_counter = 0 19 | for i in range(ce.shape[0]): 20 | if labels[i] != -100: 21 | item_loss += ce[i] 22 | item_counter += 1 23 | else: 24 | if item_counter != 0: 25 | total_loss.append(item_loss) 26 | item_loss = 0 27 | item_counter = 0 28 | return torch.Tensor(total_loss) 29 | 30 | 31 | def compute_metrics(pred): 32 | print(pred.predictions.shape, flush=True) 33 | return {'custom_loss': pred.predictions.mean()} 34 | 35 | if __name__ == '__main__': 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--model", type=str, required=True) 38 | args = parser.parse_args() 39 | 40 | task_names = [ 41 | "jansen2021logic/prompts.jsonl", 42 | ] 43 | 44 | model, tokenizer = FastLanguageModel.from_pretrained( 45 | model_name = args.model, 46 | max_seq_length = 32768, 47 | dtype = None, 48 | load_in_4bit = True, 49 | ) 50 | l_id = tokenizer(" <<").input_ids[1:] 51 | r_id = tokenizer(">>").input_ids[1:] 52 | collator = DataCollatorForCompletionOnlyLM(response_template=l_id, instruction_template=r_id, tokenizer=tokenizer) 53 | is_quantized = model.is_quantized 54 | 55 | data = [] 56 | with torch.no_grad(): 57 | for i, task_name in enumerate(task_names): 58 | dataset = load_dataset( 59 | 'json', 60 | data_files={ 61 | 'test': [task_name], 62 | } 63 | ) 64 | 65 | model.is_quantized = False 66 | training_args = TrainingArguments( 67 | output_dir="eval_"+str(i), 68 | per_device_eval_batch_size=1, 69 | eval_accumulation_steps=1, 70 | report_to="none" 71 | ) 72 | trainer = SFTTrainer( 73 | model=model, 74 | tokenizer=tokenizer, 75 | args=training_args, 76 | train_dataset=dataset['test'], 77 | eval_dataset=dataset['test'], 78 | dataset_text_field="text", 79 | max_seq_length=32768, 80 | data_collator=collator, 81 | compute_metrics=compute_metrics, 82 | preprocess_logits_for_metrics=preprocess_logits_for_metrics, 83 | ) 84 | model.is_quantized = is_quantized 85 | result = trainer.evaluate() 86 | print(task_name, flush=True) 87 | print(result, flush=True) 88 | data.append([task_name.removesuffix('/prompts.jsonl'), result['eval_custom_loss']]) 89 | df = pd.DataFrame(data, columns=['task', str(args.model)]) 90 | print(df, flush=True) 91 | df.to_csv('results/custom_metrics_' + args.model.replace('/', '-') + '.csv') 92 | -------------------------------------------------------------------------------- /test_adapter_custom_metrics.py: -------------------------------------------------------------------------------- 1 | from transformers import TrainingArguments 2 | from trl import SFTTrainer, DataCollatorForCompletionOnlyLM 3 | from unsloth import FastLanguageModel 4 | from datasets import load_dataset 5 | import pandas as pd 6 | import argparse 7 | import torch 8 | 9 | def preprocess_logits_for_metrics(logits, labels): 10 | with torch.no_grad(): 11 | logits = logits.cpu() 12 | labels = labels.cpu() 13 | labels = torch.cat((labels[0, 1:], -100 * torch.ones(1).long()), 0) 14 | logits = logits[0] 15 | ce = torch.nn.functional.cross_entropy(logits, labels, reduction='none') 16 | total_loss = [] 17 | item_loss = 0 18 | item_counter = 0 19 | for i in range(ce.shape[0]): 20 | if labels[i] != -100: 21 | item_loss += ce[i] 22 | item_counter += 1 23 | else: 24 | if 
item_counter != 0: 25 | total_loss.append(item_loss) 26 | item_loss = 0 27 | item_counter = 0 28 | return torch.Tensor(total_loss) 29 | 30 | 31 | def compute_metrics(pred): 32 | print(pred.predictions.shape, flush=True) 33 | return {'custom_loss': pred.predictions.mean()} 34 | 35 | if __name__ == '__main__': 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--model", type=str, required=True) 38 | args = parser.parse_args() 39 | 40 | task_names = [ 41 | "collsiöö2023MCPL", 42 | "cox2017information", 43 | "garcia2023experiential", 44 | "jansen2021dunningkruger", 45 | "krueger2022identifying", 46 | "kumar2023disentangling", 47 | "popov2023intent", 48 | "wise2019acomputational", 49 | "wu2018generalisation", 50 | "zhu2020bayesian", 51 | ] 52 | 53 | model, tokenizer = FastLanguageModel.from_pretrained( 54 | model_name = args.model, 55 | max_seq_length = 32768, 56 | dtype = None, 57 | load_in_4bit = True, 58 | ) 59 | l_id = tokenizer(" <<").input_ids[1:] 60 | r_id = tokenizer(">>").input_ids[1:] 61 | collator = DataCollatorForCompletionOnlyLM(response_template=l_id, instruction_template=r_id, tokenizer=tokenizer) 62 | dataset = load_dataset("marcelbinz/Psych-101-test") 63 | is_quantized = model.is_quantized 64 | 65 | data = [] 66 | with torch.no_grad(): 67 | for task_name in task_names: 68 | eval_dataset = dataset['test'].filter(lambda example: example['experiment'].startswith(task_name)) 69 | 70 | model.is_quantized = False 71 | training_args = TrainingArguments( 72 | output_dir="eval", 73 | per_device_eval_batch_size=1, 74 | eval_accumulation_steps=1 75 | ) 76 | trainer = SFTTrainer( 77 | model=model, 78 | tokenizer=tokenizer, 79 | args=training_args, 80 | train_dataset=eval_dataset, 81 | eval_dataset=eval_dataset, 82 | dataset_text_field="text", 83 | max_seq_length=32768, 84 | data_collator=collator, 85 | compute_metrics=compute_metrics, 86 | preprocess_logits_for_metrics=preprocess_logits_for_metrics, 87 | ) 88 | model.is_quantized = is_quantized 89 | result = trainer.evaluate() 90 | 91 | print(task_name, flush=True) 92 | print(result, flush=True) 93 | data.append([task_name, result['eval_custom_loss']]) 94 | df = pd.DataFrame(data, columns=['task', str(args.model)]) 95 | print(df, flush=True) 96 | df.to_csv('results/custom_metrics_' + args.model.replace('/', '-') + '.csv') 97 | -------------------------------------------------------------------------------- /plots/fig3_new.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | import numpy as np 5 | import math 6 | import scienceplots 7 | import matplotlib.gridspec as gridspec 8 | from functools import reduce 9 | import torch 10 | import math 11 | from scipy import stats 12 | 13 | centaur_70b = torch.load('../generalization/results/generalization_full_log_likelihoods_marcelbinz-Llama-3.1-Centaur-70B-adapter.pth') 14 | llama_70b = torch.load('../generalization/results/generalization_full_log_likelihoods_unsloth-Meta-Llama-3.1-70B-bnb-4bit.pth') 15 | 16 | df_baseline = pd.read_csv('../results/all_data_baseline.csv') 17 | df_baseline = df_baseline[df_baseline['unseen'] == 'experiments'][['task', 'baseline']] 18 | 19 | df_random = pd.read_csv('../results/all_data_random.csv') 20 | df_random = df_random[df_random['unseen'] == 'experiments'][['task', 'random']] 21 | 22 | means = {} 23 | sems = {} 24 | for key in centaur_70b.keys(): 25 | print(key) 26 | print(centaur_70b[key].shape) 27 | baseline = 
df_baseline[df_baseline['task'] == key] 28 | random = df_random[df_random['task'] == key] 29 | means[key] = [] 30 | sems[key] = [] 31 | means[key].append(centaur_70b[key].mean()) 32 | means[key].append(llama_70b[key].mean()) 33 | sems[key].append(centaur_70b[key].std() / math.sqrt(len(centaur_70b[key]))) 34 | sems[key].append(llama_70b[key].std() / math.sqrt(len(llama_70b[key]))) 35 | 36 | print(stats.ttest_ind(centaur_70b[key], llama_70b[key], alternative='less')) 37 | 38 | 39 | if len(baseline) > 0: 40 | means[key].append(baseline.baseline.item()) 41 | print(stats.ttest_1samp(centaur_70b[key], baseline.baseline.item(), alternative='two-sided')) 42 | print(stats.ttest_1samp(llama_70b[key], baseline.baseline.item(), alternative='two-sided')) 43 | else: 44 | means[key].append(0) 45 | sems[key].append(0) 46 | means[key].append(random.random.item()) 47 | sems[key].append(0) 48 | print() 49 | 50 | 51 | 52 | 53 | 54 | gs = gridspec.GridSpec(1, 3, width_ratios=[0.3333, 0.3333, 0.3333]) 55 | offsets = [0.009, 0.026, 0.024] 56 | plt.style.use(['nature']) 57 | fig = plt.figure(figsize=(7.08661, 1.9)) 58 | for task_index, task in enumerate(means.keys()): 59 | print(task) 60 | ax = fig.add_subplot(gs[:, task_index]) 61 | ax.bar(np.arange(3), means[task][:-1], yerr=sems[task][:-1], color=['#69005f', '#ff506e', '#cbc9e2']) 62 | ax.set_xticks(np.arange(3), ['Centaur', 'Llama', 'Cognitive\nmodel']) 63 | ax.axhline(y=means[task][-1], color='grey', linestyle='--', linewidth=1.0) 64 | ax.text(2.5, means[task][-1] + offsets[task_index], 'Random guessing', fontsize=6, color='grey', horizontalalignment='right') 65 | 66 | if task_index == 2: 67 | ax.text(0.775, 0.125, 'N/A', transform=ax.transAxes, va='top') 68 | if task_index == 0: 69 | ax.set_ylabel('Negative log-likelihood') 70 | ax.containers[1][0].set_alpha(0.8) 71 | ax.containers[1][1].set_alpha(0.8) 72 | ax.containers[1][2].set_alpha(1) 73 | ax.set_ylim(0.9 * means[task][0], 1.1 * means[task][-1]) 74 | 75 | 76 | sns.despine() 77 | plt.tight_layout() 78 | plt.savefig('figures/fig3_new.pdf', bbox_inches='tight') 79 | plt.show() 80 | -------------------------------------------------------------------------------- /plots/fig12.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from datasets import load_dataset 3 | import numpy as np 4 | from transformers import pipeline 5 | import torch 6 | from sklearn.manifold import MDS 7 | import seaborn as sns 8 | import scienceplots 9 | 10 | 11 | embeddings = [] 12 | colors = [] 13 | 14 | feature_extractor = pipeline("feature-extraction", framework="pt", model="answerdotai/ModernBERT-base") 15 | 16 | # train data 17 | dataset = load_dataset("marcelbinz/Psych-101") 18 | unique_experiment_names = dataset.unique('experiment')['train'] 19 | for experiment_name in unique_experiment_names: 20 | print(experiment_name) 21 | subset = dataset.filter(lambda example: example["experiment"].startswith(experiment_name)) 22 | text = subset['train'][0]['text'].split('<<')[0] 23 | features = feature_extractor(text, return_tensors = "pt").mean(axis=(0, 1)) 24 | embeddings.append(features) 25 | colors.append('#69005f') 26 | 27 | # eval data 28 | eval_experiments = [ 29 | "../generalization/dubois2022value/prompts.jsonl", 30 | "../generalization/feher2020humans/prompts.jsonl", 31 | "../generalization/jansen2021logic/prompts.jsonl", 32 | "../generalization/additional_experiments/awad2018moral.jsonl", 33 | 
"../generalization/additional_experiments/demircan2024evaluatingcategory.jsonl", 34 | "../generalization/additional_experiments/demircan2024evaluatingreward.jsonl", 35 | "../generalization/additional_experiments/akata2023repeatedgames.jsonl", 36 | "../generalization/additional_experiments/singh2022representing.jsonl", 37 | "../generalization/additional_experiments/xu2021novelty.jsonl", 38 | ] 39 | eval_experiments_names = [ 40 | 'Modified problem structure (Figure 3b)', 41 | 'Modified cover story (Figure 3a)', 42 | 'Entirely novel domain (Figure 3c)', 43 | 'Moral decision-making', 44 | 'Naturalistic category learning', 45 | 'Naturalistic reward learning', 46 | 'Economic games', 47 | 'Behavioral propensities', 48 | 'Deep sequential decision task', 49 | ] 50 | 51 | colors.extend(['C0', 'C1', 'C2', '#ff506e', '#ff506e', '#ff506e', '#ff506e', '#ff506e', '#ff506e']) 52 | for eval_experiment_name in eval_experiments: 53 | subset = load_dataset('json', 54 | data_files={ 55 | 'train': [eval_experiment_name], 56 | } 57 | ) 58 | text = subset['train'][0]['text'].split('<<')[0] 59 | features = feature_extractor(text, return_tensors = "pt").mean(axis=(0, 1)) 60 | embeddings.append(features) 61 | 62 | embeddings = torch.stack(embeddings, dim=0).numpy() 63 | print(embeddings.shape) 64 | reducer = MDS(n_components=2) 65 | embeddings = reducer.fit_transform(embeddings) 66 | print(embeddings.shape) 67 | 68 | plt.style.use(['nature']) 69 | fig = plt.figure(figsize=(7.08661, 7.08661/2)) 70 | plt.scatter(embeddings[:, 0], embeddings[:, 1], c=colors, s=25, alpha=0.8) 71 | plt.xlabel('Embedding dimension 1') 72 | plt.ylabel('Embedding dimension 2') 73 | 74 | for i in range(1, len(eval_experiments_names) + 1): 75 | plt.annotate(eval_experiments_names[-i], (0.2 + embeddings[-i, 0], embeddings[-i, 1]-0.2)) 76 | 77 | sns.despine() 78 | plt.tight_layout() 79 | plt.savefig('figures/fig12.pdf', bbox_inches='tight') 80 | plt.show() 81 | -------------------------------------------------------------------------------- /plots/fig11.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import matplotlib.pyplot as plt 4 | import seaborn as sns 5 | import scienceplots 6 | import matplotlib.gridspec as gridspec 7 | import numpy as np 8 | 9 | hermes = torch.load('../results/custom_metrics_full_log_likelihoods_unsloth-Hermes-3-Llama-3.1-70B-bnb-4bit.pth') 10 | nemotron = torch.load('../results/custom_metrics_full_log_likelihoods_unsloth-Llama-3.1-Nemotron-70B-Instruct-bnb-4bit.pth') 11 | reflection = torch.load('../results/custom_metrics_full_log_likelihoods_unsloth-Reflection-Llama-3.1-70B-bnb-4bit.pth') 12 | instruct = torch.load('../results/custom_metrics_full_log_likelihoods_unsloth-llama-3-70b-Instruct-bnb-4bit.pth') 13 | centaur = torch.load('../results/custom_metrics_full_log_likelihoods_marcelbinz-Llama-3.1-Centaur-70B-adapter.pth') 14 | llama = torch.load('../results/custom_metrics_full_log_likelihoods_unsloth-Meta-Llama-3.1-70B-bnb-4bit.pth') 15 | 16 | results_centaur = [] 17 | results_llama = [] 18 | results_instruct = [] 19 | results_nemotron = [] 20 | results_hermes = [] 21 | results_reflection = [] 22 | for key in centaur.keys(): 23 | print(key) 24 | results_centaur.append(centaur[key]) 25 | results_llama.append(llama[key]) 26 | results_instruct.append(instruct[key]) 27 | results_nemotron.append(nemotron[key]) 28 | results_hermes.append(hermes[key]) 29 | results_reflection.append(reflection[key]) 30 | 31 | results_centaur = 
np.concatenate(results_centaur) 32 | results_llama = np.concatenate(results_llama) 33 | results_instruct = np.concatenate(results_instruct) 34 | results_nemotron = np.concatenate(results_nemotron) 35 | results_hermes = np.concatenate(results_hermes) 36 | results_reflection = np.concatenate(results_reflection) 37 | 38 | gs = gridspec.GridSpec(1, 1, width_ratios=[1]) 39 | plt.style.use(['nature']) 40 | fig = plt.figure(figsize=(7.08661/2, 1.9)) 41 | 42 | 43 | ax = fig.add_subplot(gs[0, :]) 44 | means = [ 45 | np.array(results_centaur).mean(), 46 | np.array(results_llama).mean(), 47 | #np.array(results_instruct).mean(), 48 | np.array(results_nemotron).mean(), 49 | np.array(results_hermes).mean(), 50 | np.array(results_reflection).mean() 51 | ] 52 | 53 | sems = [ 54 | np.array(results_centaur).std() / math.sqrt(len(results_centaur)), 55 | np.array(results_llama).std() / math.sqrt(len(results_centaur)), 56 | #np.array(results_instruct).std() / math.sqrt(len(results_centaur)), 57 | np.array(results_nemotron).std() / math.sqrt(len(results_centaur)), 58 | np.array(results_hermes).std() / math.sqrt(len(results_centaur)), 59 | np.array(results_reflection).std() / math.sqrt(len(results_centaur)) 60 | ] 61 | ax.bar(np.arange(5), means, yerr=sems, color=['#69005f', '#ff506e', 'C0', 'C1', 'C2']) 62 | ax.set_xticks(np.arange(5), ['Centaur', 'Llama', 'Nemotron', 'Hermes', 'Reflection']) 63 | 64 | ax.set_ylabel('Negative log-likelihood') 65 | ax.containers[1][0].set_alpha(0.8) 66 | ax.containers[1][1].set_alpha(0.8) 67 | ax.containers[1][2].set_alpha(0.8) 68 | ax.containers[1][3].set_alpha(0.8) 69 | ax.containers[1][4].set_alpha(0.8) 70 | ax.set_ylim(0.9 * min(means), 1.1 * max(means)) 71 | 72 | sns.despine() 73 | plt.tight_layout() 74 | plt.savefig('figures/fig11.pdf', bbox_inches='tight') 75 | plt.show() 76 | -------------------------------------------------------------------------------- /plots/tab1.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import math 3 | from functools import reduce 4 | 5 | df_exp = pd.read_csv('../experiments.csv', sep=';') 6 | 7 | df_llama_70b = pd.read_csv('../results/all_data_unsloth-Meta-Llama-3.1-70B-bnb-4bit.csv') 8 | df_llama_70b = df_llama_70b[df_llama_70b['unseen'] == 'participants'][['task', 'unsloth/Meta-Llama-3.1-70B-bnb-4bit']] 9 | 10 | df_centaur_70b = pd.read_csv('../results/all_data_marcelbinz-Llama-3.1-Centaur-70B-adapter.csv') 11 | df_centaur_70b = df_centaur_70b[df_centaur_70b['unseen'] == 'participants'][['task', 'marcelbinz/Llama-3.1-Centaur-70B-adapter']] 12 | 13 | df_baseline = pd.read_csv('../results/all_data_baseline.csv') 14 | df_baseline = df_baseline[df_baseline['unseen'] == 'participants'][['task', 'baseline']] 15 | 16 | df_random = pd.read_csv('../results/all_data_random.csv') 17 | df_random = df_random[df_random['unseen'] == 'participants'][['task', 'random']] 18 | 19 | df = reduce(lambda left,right: pd.merge(left,right,on=['task'], how='outer'), [df_llama_70b, df_centaur_70b, df_baseline, df_random]) 20 | 21 | for index, row in df.iterrows(): 22 | task_name = df_exp[df_exp['path'] == df.iloc[index]['task'] + '/']['task_name'].item() 23 | df.loc[index, 'task'] = task_name 24 | df = df.groupby('task', as_index=False, sort=False).mean() 25 | 26 | df = df.rename(columns={"baseline": "Cognitive model"}) 27 | df = df.rename(columns={"marcelbinz/Llama-3.1-Centaur-70B-adapter": "Centaur"}) 28 | df = df.rename(columns={"unsloth/Meta-Llama-3.1-70B-bnb-4bit": "Llama"}) 29 | df = 
df.rename(columns={"random": "Random"}) 30 | df = df.rename(columns={"task": "task_name"}) 31 | 32 | print(df) 33 | 34 | prompt = '\\begin{table}[]\n' 35 | prompt += '\\centering \n' 36 | prompt += '\\begin{tabular}{@{}lccc@{}} \n' 37 | prompt += '\\toprule \n' 38 | prompt += "\\textbf{Experiment} & \\textbf{Centaur} & \\textbf{Llama} & \\textbf{Cognitive model} \\\\ \n" 39 | prompt += '\\midrule \n' 40 | for i, row in df.iterrows(): 41 | prompt += str(row['task_name']) + ' & ' + str(format(row['Centaur'], '.4f')) + ' & ' + str(format(row['Llama'], '.4f')) + ' & ' + str(format(row['Cognitive model'], '.4f')) + ' \\\\ \n' 42 | prompt += '\\bottomrule \\\\ \n' 43 | prompt += '\\end{tabular} \n' 44 | prompt += '\\caption{Full negative log-likelihoods results on held-out participants.}\n' 45 | prompt += '\\label{tab:tab2} \n' 46 | prompt += '\\end{table}' 47 | print(prompt) 48 | 49 | df['Centaur'] = 1 - (-df['Centaur']/-df['Random']) 50 | df['Llama'] = 1 - (-df['Llama']/-df['Random']) 51 | df['Cognitive model'] = 1 - (-df['Cognitive model']/-df['Random']) 52 | 53 | prompt = '\\begin{table}[]\n' 54 | prompt += '\\centering \n' 55 | prompt += '\\begin{tabular}{@{}lccc@{}} \n' 56 | prompt += '\\toprule \n' 57 | prompt += "\\textbf{Experiment} & \\textbf{Centaur} & \\textbf{Llama} & \\textbf{Cognitive model} \\\\ \n" 58 | prompt += '\\midrule \n' 59 | for i, row in df.iterrows(): 60 | prompt += str(row['task_name']) + ' & ' + str(format(row['Centaur'], '.4f')) + ' & ' + str(format(row['Llama'], '.4f')) + ' & ' + str(format(row['Cognitive model'], '.4f')) + ' \\\\ \n' 61 | prompt += '\\bottomrule \\\\ \n' 62 | prompt += '\\end{tabular} \n' 63 | prompt += '\\caption{Full pseudo-R$^2$ results on held-out participants.}\n' 64 | prompt += '\\label{tab:tab3} \n' 65 | prompt += '\\end{table}' 66 | 67 | print(prompt) 68 | -------------------------------------------------------------------------------- /extended_data/ed_fig4.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import matplotlib.pyplot as plt 4 | import seaborn as sns 5 | import scienceplots 6 | import matplotlib.gridspec as gridspec 7 | import numpy as np 8 | 9 | hermes = torch.load('../results/custom_metrics_full_log_likelihoods_unsloth-Hermes-3-Llama-3.1-70B-bnb-4bit.pth') 10 | nemotron = torch.load('../results/custom_metrics_full_log_likelihoods_unsloth-Llama-3.1-Nemotron-70B-Instruct-bnb-4bit.pth') 11 | reflection = torch.load('../results/custom_metrics_full_log_likelihoods_unsloth-Reflection-Llama-3.1-70B-bnb-4bit.pth') 12 | instruct = torch.load('../results/custom_metrics_full_log_likelihoods_unsloth-llama-3-70b-Instruct-bnb-4bit.pth') 13 | centaur = torch.load('../results/custom_metrics_full_log_likelihoods_marcelbinz-Llama-3.1-Centaur-70B-adapter.pth') 14 | llama = torch.load('../results/custom_metrics_full_log_likelihoods_unsloth-Meta-Llama-3.1-70B-bnb-4bit.pth') 15 | 16 | results_centaur = [] 17 | results_llama = [] 18 | results_instruct = [] 19 | results_nemotron = [] 20 | results_hermes = [] 21 | results_reflection = [] 22 | for key in centaur.keys(): 23 | print(key) 24 | results_centaur.append(centaur[key]) 25 | results_llama.append(llama[key]) 26 | results_instruct.append(instruct[key]) 27 | results_nemotron.append(nemotron[key]) 28 | results_hermes.append(hermes[key]) 29 | results_reflection.append(reflection[key]) 30 | 31 | results_centaur = np.concatenate(results_centaur) 32 | results_llama = np.concatenate(results_llama) 33 | results_instruct = 
34 | results_nemotron = np.concatenate(results_nemotron)
35 | results_hermes = np.concatenate(results_hermes)
36 | results_reflection = np.concatenate(results_reflection)
37 | 
38 | gs = gridspec.GridSpec(1, 1, width_ratios=[1])
39 | plt.style.use(['nature'])
40 | fig = plt.figure(figsize=(7.20472/2, 1.9))
41 | 
42 | 
43 | ax = fig.add_subplot(gs[0, :])
44 | means = [
45 |     np.array(results_centaur).mean(),
46 |     np.array(results_llama).mean(),
47 |     #np.array(results_instruct).mean(),
48 |     np.array(results_nemotron).mean(),
49 |     np.array(results_hermes).mean(),
50 |     np.array(results_reflection).mean()
51 | ]
52 | 
53 | sems = [
54 |     np.array(results_centaur).std() / math.sqrt(len(results_centaur)),
55 |     np.array(results_llama).std() / math.sqrt(len(results_llama)),
56 |     #np.array(results_instruct).std() / math.sqrt(len(results_instruct)),
57 |     np.array(results_nemotron).std() / math.sqrt(len(results_nemotron)),
58 |     np.array(results_hermes).std() / math.sqrt(len(results_hermes)),
59 |     np.array(results_reflection).std() / math.sqrt(len(results_reflection))
60 | ]
61 | ax.bar(np.arange(5), means, yerr=sems, color=['#69005f', '#ff506e', 'C0', 'C1', 'C2'])
62 | ax.set_xticks(np.arange(5), ['Centaur', 'Llama', 'Nemotron', 'Hermes', 'Reflection'])
63 | 
64 | ax.set_ylabel('Negative log-likelihood')
65 | ax.containers[1][0].set_alpha(0.8)
66 | ax.containers[1][1].set_alpha(0.8)
67 | ax.containers[1][2].set_alpha(0.8)
68 | ax.containers[1][3].set_alpha(0.8)
69 | ax.containers[1][4].set_alpha(0.8)
70 | ax.set_ylim(0.9 * min(means), 1.15 * max(means))
71 | ax.set_yticks([0.5, 0.6, 0.7, 0.8])
72 | 
73 | sns.despine()
74 | plt.tight_layout()
75 | plt.savefig('figures/fig4.jpg', bbox_inches='tight', dpi=300)
76 | plt.show()
77 | 
--------------------------------------------------------------------------------
/generalization/additional_generalization.py:
--------------------------------------------------------------------------------
1 | from transformers import TrainingArguments
2 | from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
3 | from unsloth import FastLanguageModel
4 | from datasets import load_dataset
5 | import pandas as pd
6 | import argparse
7 | import torch
8 | 
9 | def full_log_likelihoods(logits, labels):
10 |     with torch.no_grad():
11 |         logits = logits.float().cpu()
12 |         labels = labels.cpu()
13 |         labels = torch.cat((labels[0, 1:], -100 * torch.ones(1).long()), 0)
14 |         logits = logits[0]
15 |         ce = torch.nn.functional.cross_entropy(logits, labels, reduction='none')
16 |         total_loss = []
17 |         item_loss = 0
18 |         item_counter = 0
19 |         for i in range(ce.shape[0]):
20 |             if labels[i] != -100:
21 |                 item_loss += ce[i]
22 |                 item_counter += 1
23 |             else:
24 |                 if item_counter != 0:
25 |                     total_loss.append(item_loss)
26 |                 item_loss = 0
27 |                 item_counter = 0
28 |     return torch.Tensor(total_loss)
29 | 
30 | 
31 | def compute_metrics(pred):
32 |     return {'custom_loss': pred.predictions}
33 | 
34 | if __name__ == '__main__':
35 |     parser = argparse.ArgumentParser()
36 |     parser.add_argument("--model", type=str, required=True)
37 |     args = parser.parse_args()
38 | 
39 |     task_names = [
40 |         "additional_experiments/awad2018moral.jsonl",
41 |         "additional_experiments/demircan2024evaluatingcategory.jsonl",
42 |         "additional_experiments/demircan2024evaluatingreward.jsonl",
43 |         "additional_experiments/akata2023repeatedgames.jsonl",
44 |         "additional_experiments/singh2022representing.jsonl",
45 |         "additional_experiments/xu2021novelty.jsonl",
46 |     ]
47 | 
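# Note on the evaluation below: prompts in these .jsonl files wrap each human response
# as " <<response>>". The DataCollatorForCompletionOnlyLM configured below with
# response_template=" <<" and instruction_template=">>" masks all non-response tokens
# with label -100, so full_log_likelihoods() above sums the per-token cross-entropy
# over each unmasked response span and returns one negative log-likelihood per choice.
48 | model, tokenizer = 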
FastLanguageModel.from_pretrained( 49 | model_name = args.model, 50 | max_seq_length = 32768, 51 | dtype = None, 52 | load_in_4bit = True, 53 | ) 54 | l_id = tokenizer(" <<").input_ids[1:] 55 | r_id = tokenizer(">>").input_ids[1:] 56 | collator = DataCollatorForCompletionOnlyLM(response_template=l_id, instruction_template=r_id, tokenizer=tokenizer) 57 | is_quantized = model.is_quantized 58 | 59 | data = {} 60 | with torch.no_grad(): 61 | for i, task_name in enumerate(task_names): 62 | dataset = load_dataset('json', 63 | data_files={ 64 | 'test': [task_name], 65 | } 66 | ) 67 | 68 | model.is_quantized = False 69 | training_args = TrainingArguments( 70 | output_dir="eval_"+str(i), 71 | per_device_eval_batch_size=1, 72 | report_to="none", 73 | eval_accumulation_steps=1 74 | ) 75 | trainer = SFTTrainer( 76 | model=model, 77 | tokenizer=tokenizer, 78 | args=training_args, 79 | train_dataset=dataset['test'], 80 | eval_dataset=dataset['test'], 81 | dataset_text_field="text", 82 | max_seq_length=32768, 83 | data_collator=collator, 84 | compute_metrics=compute_metrics, 85 | preprocess_logits_for_metrics=full_log_likelihoods, 86 | ) 87 | model.is_quantized = is_quantized 88 | result = trainer.evaluate() 89 | 90 | print(task_name, flush=True) 91 | print(result, flush=True) 92 | data[task_name] = result['eval_custom_loss'] 93 | 94 | torch.save(data, 'results/additional_generalization_full_log_likelihoods_' + args.model.replace('/', '-') + '.pth') -------------------------------------------------------------------------------- /openloop/openloop.py: -------------------------------------------------------------------------------- 1 | from models import RescorlaWagnerModel, DualSystems 2 | from trainers import Trainer 3 | import pandas as pd 4 | import torch 5 | from datasets import load_dataset 6 | import os 7 | 8 | experiments = [ 9 | #{'name': 'horizon1', 'agent': 'centaur', 'path': 'wilson2014humans/simulation0.csv', 'model': RescorlaWagnerModel(num_options=2)}, 10 | #{'name': 'horizon1', 'agent': 'human', 'path': 'wilson2014humans/exp1.csv', 'model': RescorlaWagnerModel(num_options=2)}, 11 | #{'name': 'horizon2', 'agent': 'centaur', 'path': 'wilson2014humans/simulation2.csv', 'model': RescorlaWagnerModel(num_options=2)}, 12 | #{'name': 'horizon2', 'agent': 'human', 'path': 'wilson2014humans/exp3.csv', 'model': RescorlaWagnerModel(num_options=2)}, 13 | #{'name': 'horizon3', 'agent': 'centaur', 'path': 'wilson2014humans/simulation3.csv', 'model': RescorlaWagnerModel(num_options=2)}, 14 | #{'name': 'horizon3', 'agent': 'human', 'path': 'wilson2014humans/exp4.csv', 'model': RescorlaWagnerModel(num_options=2)}, 15 | #{'name': 'horizon4', 'agent': 'centaur', 'path': 'wilson2014humans/simulation4.csv', 'model': RescorlaWagnerModel(num_options=2)}, 16 | #{'name': 'horizon4', 'agent': 'human', 'path': 'wilson2014humans/exp5.csv', 'model': RescorlaWagnerModel(num_options=2)}, 17 | {'name': 'twostep1', 'agent': 'centaur', 'path': 'kool2016when/simulation.csv', 'model': DualSystems(variant='two_step')}, 18 | {'name': 'twostep1', 'agent': 'human', 'path': 'kool2016when/exp2.csv', 'model': DualSystems(variant='two_step')}, 19 | {'name': 'twostep2', 'agent': 'centaur', 'path': 'kool2017cost/simulation.csv', 'model': DualSystems(variant='two_step')}, 20 | {'name': 'twostep2', 'agent': 'human', 'path': 'kool2017cost/exp2.csv', 'model': DualSystems(variant='two_step')}, 21 | ] 22 | 23 | for index in range(len(experiments)): 24 | data = [] 25 | 26 | df = pd.read_csv(experiments[index]['path']) 27 | 28 | # select human 
participants 29 | if (('twostep' in experiments[index]['name']) and (experiments[index]['agent'] == 'human')) or (('horizon' in experiments[index]['name']) and (experiments[index]['agent'] == 'human')): 30 | dataset = load_dataset("marcelbinz/Psych-101-test") 31 | eval_dataset = dataset['test'].filter(lambda example: example['experiment'].startswith(experiments[index]['path'])) 32 | eval_participants = list(map(int, eval_dataset['participant'])) 33 | df = df[df['participant'].isin(eval_participants)] 34 | print(eval_participants) 35 | 36 | # match simulated data 37 | if ('horizon' in experiments[index]['name']): 38 | df = df[df['participant'] < 100] 39 | df = df[df['task'] < 100] 40 | 41 | for participant in df['participant'].unique(): 42 | df_participant = df[df['participant'] == participant] 43 | 44 | trainer = Trainer(experiments[index]['model']) 45 | nll = trainer.fit_and_evaluate(df_participant, df_participant).item() 46 | 47 | if ('horizon' in experiments[index]['name']): 48 | params = trainer.model.information_logits.beta.item() 49 | data.append([participant, params, df_participant[df_participant['forced'] == 0]['reward'].mean()]) 50 | elif ('twostep' in experiments[index]['name']): 51 | params = torch.sigmoid(trainer.model.tau).item() 52 | data.append([participant, params, df_participant['reward'].mean()]) 53 | 54 | df = pd.DataFrame(data, columns=['participant', 'param', 'reward']) 55 | print(df) 56 | df.to_csv('results/baselines_openloop_' + experiments[index]['agent'] + '_' + experiments[index]['name'] + '.csv') 57 | -------------------------------------------------------------------------------- /ceiling/models.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import torch.nn.functional as F 6 | 7 | def pd_to_pth(df, values, keys=['participant', 'task', 'trial']): 8 | column_names_list = [keys + [value] for value in values] 9 | wide_arrs = {} 10 | for column_names in column_names_list: 11 | arr = df[column_names].values 12 | dims = [np.unique(arr[:, i], return_inverse=True) for i in range(len(column_names)-1)] 13 | wide_arr = np.full([len(dims[i][0]) for i in range(len(column_names)-1)], np.nan) 14 | wide_arr[*[dims[i][1] for i in range(len(column_names)-1)]] = arr[:, -1] 15 | wide_arrs[column_names[-1]] = torch.from_numpy(wide_arr).reshape(-1, wide_arr.shape[-1]) 16 | return wide_arrs 17 | 18 | class DunningKruger(nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | self.param_tensor = nn.Parameter(torch.randn(28, 11)) 22 | 23 | def preprocess_data(self, train_df, eval_df): 24 | for i in range(4, 24): 25 | train_df.loc[train_df['trial'] == i, 'choice'] = train_df[train_df['trial'] == i]['choice'].astype('category').cat.codes 26 | eval_df.loc[eval_df['trial'] == i, 'choice'] = eval_df[eval_df['trial'] == i]['choice'].astype('category').cat.codes 27 | 28 | normalizer = torch.Tensor([2, 10, 1, 1, 29 | 1, 1, 1, 1, 1, 30 | 1, 1, 1, 1, 1, 31 | 1, 1, 1, 1, 1, 32 | 1, 1, 1, 1, 1, 33 | 2, 10, 1, 1]) 34 | 35 | train_data = {} 36 | num_train_participants = len(train_df.participant.unique()) 37 | train_data['choice'] = torch.from_numpy(train_df[(train_df['trial'] != 24)]['choice'].values.astype('float')) 38 | train_data['choice'] = (train_data['choice'] // normalizer.repeat(num_train_participants)).long() 39 | 40 | eval_data = {} 41 | num_eval_participants = len(eval_df.participant.unique()) 42 | eval_data['choice'] = 
torch.from_numpy(eval_df[(eval_df['trial'] != 24)]['choice'].values.astype('float')) 43 | eval_data['choice'] = (eval_data['choice'] // normalizer.repeat(num_eval_participants)).long() 44 | 45 | return train_data, eval_data 46 | 47 | def forward(self, data): 48 | num_participants = int(data['choice'].shape[0] / 28) 49 | params = self.param_tensor.repeat(num_participants, 1) 50 | 51 | return params 52 | 53 | class NoiseCeiling(nn.Module): 54 | def __init__(self, UID, num_questions=14568, num_options=2, ): 55 | super().__init__() 56 | self.param_tensor = nn.Parameter(torch.randn(num_questions, num_options)) 57 | self.UID = UID 58 | 59 | def preprocess_data(self, train_df, eval_df): 60 | train_data = {} 61 | 62 | mapping_dict = {k: v for k, v in zip(train_df[self.UID], train_df[self.UID].astype('category').cat.codes)} 63 | print(mapping_dict) 64 | 65 | train_data['choice'] = torch.from_numpy(train_df['choice'].values) 66 | train_data[self.UID] = torch.from_numpy(train_df[self.UID].map(mapping_dict).values).long() 67 | 68 | 69 | eval_data = {} 70 | num_eval_participants = len(eval_df.participant.unique()) 71 | eval_data['choice'] = torch.from_numpy(eval_df['choice'].values) 72 | eval_data[self.UID] = torch.from_numpy(eval_df[self.UID].map(mapping_dict).values).long() 73 | 74 | return train_data, eval_data 75 | 76 | def forward(self, data): 77 | params = self.param_tensor[data[self.UID]] 78 | return params 79 | -------------------------------------------------------------------------------- /results/CogBench/behaviour.csv: -------------------------------------------------------------------------------- 1 | ,Agent,Prior weighting,Likelihood weighting,Directed exploration,Random exploration,Meta-cognition,Learning rate,Optimism bias,Model-basedness,Temporal discounting,Risk taking,Agent_ci,Prior weighting_ci,Likelihood weighting_ci,Directed exploration_ci,Random exploration_ci,Meta-cognition_ci,Learning rate_ci,Optimism bias_ci,Model-basedness_ci,Temporal discounting_ci,Risk taking_ci 2 | 0,GPT-4,0.6918502348688533,0.29972814131681497,0.7261821167136815,0.11432670783542695,0.7063745877990604,0.1195298614767119,2.7199230583083627,4.530319709185224,0.7487437185929644,0.5220730823900763,GPT-4,0.11669288488327707,0.16898431412042575,1.3665838456105215,1.6980979186418086,0.321091650172654,0.6243566952498483,1.1814623930676011,0.5357208219046923,0.0,0.11227249227267833 3 | 1,GPT-3.5,0.45198493054745376,0.24874372612046447,0.006946083421069069,0.422308285276341,0.40561279599549815,0.2505084753946049,3.531964351247797,0.5143381654947679,2.6013400335008354,1.390178511417412,GPT-3.5,0.10189801192019035,0.16173115668228852,1.0995449189186144,1.288057305792493,0.3483722643046461,0.6332701635558806,0.5507643382094584,0.40220312468567637,0.0,0.30103829320513914 4 | 2,GPT-3,0.3983162533267074,0.4787687057791692,0.30179384903157386,0.6897716283587736,0.6718447898501536,0.4264339313045442,2.537114923647825,1.2552117395833804,2.6013400335008354,1.567208764509809,GPT-3,0.118152146750305,0.17679131284500033,0.747388222832115,1.061678554381139,0.6571367407261278,0.9423183821030914,1.3815856352698743,0.43097212979072025,0.0,0.09953044249431713 5 | 3,Claude 1,0.8446903242903696,0.22506455679414872,0.5731222111973278,0.07849776680314616,0.9675754852320094,0.003732420006771874,3.54681444737725,0.9494136165667084,1.7638190954773856,0.08155116489217369,Claude 
1,0.013288920393635162,0.013220137214679697,0.38401799832944516,0.625197923381597,0.34224497500983114,0.36066002419973303,0.4112736557363793,0.3043706750491891,0.0,0.15756168747171415 6 | 4,Claude 2,0.8305908609729911,0.37388756433276327,0.2214048771437066,0.19040041721137077,1.2677097121981915,0.3784109401488164,2.731943987778644,1.525563700915961,5.951423785594636,0.0371535755804546,Claude 2,0.009668306539946029,0.00950888404428697,0.3054974065447095,0.5983806682078128,0.7458646023387326,0.5147002825845868,0.6194046552025605,0.29735694599207846,0.0,0.010648204156515783 7 | 5,Claude 3,0.8501087419025741,0.37081587486908346,0.8618762245820757,0.822754061867183,1.674040910876299,1.1998759625390634,0.2281936714618803,3.92476665297887,4.936348408710215,0.009223048410424413,Claude 3,0.10596655207450523,0.10681140040396382,0.8694966509714979,1.8389935655313085,0.3910876535156263,0.7484568966637302,1.5739354063202282,0.5532691746662123,0.0,0.011681488469090479 8 | 36,Llama,0.5207501596711092,0.06828420961234041,0.03422759568895414,0.259937573744107,0.14472819554209818,0.0007261683951087816,0.12114846954922624,0.09706040158495188,0.39698492462311546,0.012235164085623746,Llama,0.13248621778122796,0.1311837078422043,0.10828092113471542,0.1385443410556201,0.08262626173358908,0.35433742945783436,0.8006207002172121,0.12652249213517522,0.2657674712794893,0.0070699453853747835 9 | 45,Centaur,0.6786120888542513,0.14967184610871387,0.09467564777737215,0.6315928477750391,0.5094482296733448,0.15052676468438095,0.7312746454876076,1.3240095153343965,0.7487437185929644,0.022267251830295896,Centaur,0.05162566533880502,0.05182844886633703,0.1981409570620531,0.3340462211430205,0.3133483375111889,0.5040981168845569,1.3063751537208785,0.21595852708578847,0.0,0.009935916313346886 10 | 47,Human,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,Human,0.0006823050187188417,0.0006685383948290158,0.19552745429465865,0.3601984229417105,0.18277476400588627,0.32521819686741804,0.5477074196879459,0.9588584626649467,0.0,0.14458456777625245 11 | -------------------------------------------------------------------------------- /test_adapter.py: -------------------------------------------------------------------------------- 1 | from transformers import TrainingArguments 2 | from trl import SFTTrainer, DataCollatorForCompletionOnlyLM 3 | from unsloth import FastLanguageModel 4 | from datasets import load_dataset 5 | import pandas as pd 6 | import argparse 7 | import torch 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--model", type=str, required=True) 12 | args = parser.parse_args() 13 | 14 | task_names = [ 15 | "badham2017deficits", 16 | "bahrami2020four", 17 | "enkavi2019adaptivenback", 18 | "enkavi2019digitspan", 19 | "enkavi2019gonogo", 20 | "enkavi2019recentprobes", 21 | "feng2021dynamics", 22 | "flesch2018comparing", 23 | "frey2017cct", 24 | "frey2017risk", 25 | "gershman2018deconstructing", 26 | "gershman2020reward", 27 | "hebart2023things", 28 | "hilbig2014generalized", 29 | "kool2016when", 30 | "kool2017cost", 31 | "lefebvre2017behavioural", 32 | "levering2020revisiting", 33 | "ludwig2023human", 34 | "peterson2021using", 35 | "plonsky2018when", 36 | "ruggeri2022globalizability", 37 | "sadeghiyeh2020temporal", 38 | "schulz2020finding", 39 | "somerville2017charting", 40 | "speekenbrink2008learning", 41 | "steingroever2015data", 42 | "tomov2020discovery", 43 | "tomov2021multitask", 44 | "waltz2020differential", 45 | "wilson2014humans", 46 | "wu2023chunking", 47 | 
"wulff2018description", 48 | "wulff2018sampling", 49 | "xiong2023neural", 50 | "zorowitz2023data", 51 | ] 52 | 53 | model, tokenizer = FastLanguageModel.from_pretrained( 54 | model_name = args.model, 55 | max_seq_length = 32768, 56 | dtype = None, 57 | load_in_4bit = True, 58 | ) 59 | l_id = tokenizer(" <<").input_ids[1:] 60 | r_id = tokenizer(">>").input_ids[1:] 61 | collator = DataCollatorForCompletionOnlyLM(response_template=l_id, instruction_template=r_id, tokenizer=tokenizer) 62 | dataset = load_dataset("marcelbinz/Psych-101-test") 63 | is_quantized = model.is_quantized 64 | 65 | data = [] 66 | with torch.no_grad(): 67 | for task_name in task_names: 68 | eval_dataset = dataset['test'].filter(lambda example: example['experiment'].startswith(task_name)) 69 | 70 | model.is_quantized = False 71 | training_args = TrainingArguments( 72 | output_dir="eval", 73 | per_device_eval_batch_size=1 74 | ) 75 | trainer = SFTTrainer( 76 | model=model, 77 | tokenizer=tokenizer, 78 | args=training_args, 79 | train_dataset=eval_dataset, 80 | eval_dataset=eval_dataset, 81 | dataset_text_field="text", 82 | max_seq_length=32768, 83 | data_collator=collator, 84 | ) 85 | model.is_quantized = is_quantized 86 | result = trainer.evaluate() 87 | 88 | print(task_name, flush=True) 89 | print(result, flush=True) 90 | data.append([task_name, result['eval_loss']]) 91 | df = pd.DataFrame(data, columns=['task', str(args.model)]) 92 | print(df, flush=True) 93 | df.to_csv('results/' + args.model.replace('/', '-') + '.csv') 94 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from transformers import TrainingArguments, AutoModelForCausalLM, AutoTokenizer 2 | from trl import SFTTrainer, DataCollatorForCompletionOnlyLM 3 | from datasets import load_dataset 4 | import pandas as pd 5 | import argparse 6 | import torch 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--model", type=str, required=True) 11 | args = parser.parse_args() 12 | 13 | task_names = [ 14 | "badham2017deficits", 15 | "bahrami2020four", 16 | "enkavi2019adaptivenback", 17 | "enkavi2019digitspan", 18 | "enkavi2019gonogo", 19 | "enkavi2019recentprobes", 20 | "feng2021dynamics", 21 | "flesch2018comparing", 22 | "frey2017cct", 23 | "frey2017risk", 24 | "gershman2018deconstructing", 25 | "gershman2020reward", 26 | "hebart2023things", 27 | "hilbig2014generalized", 28 | "kool2016when", 29 | "kool2017cost", 30 | "lefebvre2017behavioural", 31 | "levering2020revisiting", 32 | "ludwig2023human", 33 | "peterson2021using", 34 | "plonsky2018when", 35 | "ruggeri2022globalizability", 36 | "sadeghiyeh2020temporal", 37 | "schulz2020finding", 38 | "somerville2017charting", 39 | "speekenbrink2008learning", 40 | "steingroever2015data", 41 | "tomov2020discovery", 42 | "tomov2021multitask", 43 | "waltz2020differential", 44 | "wilson2014humans", 45 | "wu2023chunking", 46 | "wulff2018description", 47 | "wulff2018sampling", 48 | "xiong2023neural", 49 | "zorowitz2023data", 50 | ] 51 | 52 | model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.bfloat16, attn_implementation='flash_attention_2', device_map="auto") 53 | tokenizer = AutoTokenizer.from_pretrained(args.model) 54 | l_id = tokenizer(" <<").input_ids[1:] 55 | r_id = tokenizer(">>").input_ids[1:] 56 | collator = DataCollatorForCompletionOnlyLM(response_template=l_id, instruction_template=r_id, tokenizer=tokenizer) 57 | dataset = 
load_dataset("marcelbinz/Psych-101-test") 58 | is_quantized = model.is_quantized 59 | 60 | data = [] 61 | with torch.no_grad(): 62 | for task_name in task_names: 63 | eval_dataset = dataset['test'].filter(lambda example: example['experiment'].startswith(task_name)) 64 | 65 | model.is_quantized = False 66 | training_args = TrainingArguments( 67 | output_dir="eval", 68 | per_device_eval_batch_size=1 69 | ) 70 | trainer = SFTTrainer( 71 | model=model, 72 | tokenizer=tokenizer, 73 | args=training_args, 74 | train_dataset=eval_dataset, 75 | eval_dataset=eval_dataset, 76 | dataset_text_field="text", 77 | max_seq_length=32768, 78 | data_collator=collator, 79 | ) 80 | model.is_quantized = is_quantized 81 | result = trainer.evaluate() 82 | 83 | print(task_name, flush=True) 84 | print(result, flush=True) 85 | data.append([task_name, result['eval_loss']]) 86 | df = pd.DataFrame(data, columns=['task', str(args.model)]) 87 | print(df, flush=True) 88 | df.to_csv('results/' + args.model.replace('/', '-') + '.csv') 89 | -------------------------------------------------------------------------------- /plots/fig10.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import matplotlib.pyplot as plt 4 | import seaborn as sns 5 | import scienceplots 6 | import matplotlib.gridspec as gridspec 7 | import numpy as np 8 | 9 | centaur_70b = torch.load('../generalization/results/additional_generalization_full_log_likelihoods_marcelbinz-Llama-3.1-Centaur-70B-adapter.pth') 10 | centaur_8b = torch.load('../generalization/results/additional_generalization_full_log_likelihoods_marcelbinz-Llama-3.1-Centaur-8B-adapter.pth') 11 | llama_70b = torch.load('../generalization/results/additional_generalization_full_log_likelihoods_unsloth-Meta-Llama-3.1-70B-bnb-4bit.pth') 12 | llama_8b = torch.load('../generalization/results/additional_generalization_full_log_likelihoods_unsloth-Meta-Llama-3.1-8B-bnb-4bit.pth') 13 | 14 | nll_random = { 15 | 'additional_experiments/awad2018moral.jsonl': -math.log(1/2), 16 | 'additional_experiments/demircan2024evaluatingcategory.jsonl': -math.log(1/2), 17 | 'additional_experiments/demircan2024evaluatingreward.jsonl': -math.log(1/2), 18 | 'additional_experiments/akata2023repeatedgames.jsonl': -math.log(1/2), 19 | 'additional_experiments/singh2022representing.jsonl': -math.log(1/7), 20 | 'additional_experiments/xu2021novelty.jsonl': -math.log(1/3), 21 | } 22 | 23 | task_names = { 24 | 'additional_experiments/awad2018moral.jsonl': 'Moral decision-making', 25 | 'additional_experiments/demircan2024evaluatingcategory.jsonl': 'Naturalistic category learning', 26 | 'additional_experiments/demircan2024evaluatingreward.jsonl': 'Naturalistic reward learning', 27 | 'additional_experiments/akata2023repeatedgames.jsonl': 'Economic games', 28 | 'additional_experiments/singh2022representing.jsonl': 'Behavioral propensities', 29 | 'additional_experiments/xu2021novelty.jsonl': 'Deep sequential decision task', 30 | } 31 | 32 | gs = gridspec.GridSpec(2, 3, width_ratios=[0.3333, 0.3333, 0.3333]) 33 | plt.style.use(['nature']) 34 | fig = plt.figure(figsize=(7.08661, 3.8)) 35 | 36 | offsets = [0.01, 0.01, 0.005, 0.01, 0.02, 0.01] 37 | for i, key in enumerate(centaur_70b.keys()): 38 | print(key) 39 | centaur_70b_r2 = centaur_70b[key].mean().item() 40 | centaur_8b_r2 = centaur_8b[key].mean().item() 41 | llama_70b_r2 = llama_70b[key].mean().item() 42 | llama_8b_r2 = llama_8b[key].mean().item() 43 | centaur_70b_r2_se = centaur_70b[key].std().item() / 
math.sqrt(len(centaur_70b[key]))
44 |     centaur_8b_r2_se = centaur_8b[key].std().item() / math.sqrt(len(centaur_8b[key]))
45 |     llama_70b_r2_se = llama_70b[key].std().item() / math.sqrt(len(llama_70b[key]))
46 |     llama_8b_r2_se = llama_8b[key].std().item() / math.sqrt(len(llama_8b[key]))
47 | 
48 |     ax = fig.add_subplot(gs[0 if i < 3 else 1, i % 3])
49 |     values = np.array([centaur_70b_r2, centaur_8b_r2, llama_70b_r2, llama_8b_r2])
50 |     ax.bar(np.arange(4), values, yerr=[centaur_70b_r2_se, centaur_8b_r2_se, llama_70b_r2_se, llama_8b_r2_se], color=['#69005f', '#69005f', '#ff506e', '#ff506e'])# 'C0', 'C1'
51 |     ax.set_xticks(np.arange(4), ['Centaur', 'Minitaur', 'Llama\n(70B)', 'Llama\n(8B)'])
52 |     ax.axhline(y=nll_random[key], color='grey', linestyle='--', linewidth=1.0)
53 |     ax.text(3.5, nll_random[key] + offsets[i], 'Random guessing', fontsize=6, color='grey', horizontalalignment='right')
54 |     ax.set_ylim(0.9 * min(nll_random[key], min(values)), 1.1 * max(max(values), nll_random[key]))
55 | 
56 |     if i == 0 or i == 3:
57 |         ax.set_ylabel('Negative log-likelihood')
58 |     ax.containers[1][0].set_alpha(0.8)
59 |     ax.containers[1][1].set_alpha(0.5)
60 |     ax.containers[1][2].set_alpha(0.8)
61 |     ax.containers[1][3].set_alpha(0.5)
62 |     ax.set_title(task_names[key], fontsize=8)
63 | sns.despine()
64 | plt.tight_layout()
65 | plt.savefig('figures/fig10.pdf', bbox_inches='tight')
66 | plt.show()
67 | 
--------------------------------------------------------------------------------
/plots/fig5.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import jsonlines
3 | from glob import glob
4 | import numpy as np
5 | import statsmodels.formula.api as sm
6 | import matplotlib.pyplot as plt
7 | import scienceplots
8 | import matplotlib.gridspec as gridspec
9 | import seaborn as sns
10 | from functools import reduce
11 | 
12 | df = pd.read_csv('../experiments.csv', sep=';')
13 | df['path'] = df['path'].str.replace('/','')
14 | df = df.rename(columns={"path": "task"})
15 | 
16 | df_llama_70b = pd.read_csv('../results/all_data_unsloth-Meta-Llama-3.1-70B-bnb-4bit.csv')
17 | df_llama_70b = df_llama_70b[df_llama_70b['unseen'] == 'participants'][['task', 'unsloth/Meta-Llama-3.1-70B-bnb-4bit']]
18 | 
19 | df_centaur_70b = pd.read_csv('../results/all_data_marcelbinz-Llama-3.1-Centaur-70B-adapter.csv')
20 | df_centaur_70b = df_centaur_70b[df_centaur_70b['unseen'] == 'participants'][['task', 'marcelbinz/Llama-3.1-Centaur-70B-adapter']]
21 | 
22 | df_random = pd.read_csv('../results/all_data_random.csv')
23 | df_random = df_random[df_random['unseen'] == 'participants'][['task', 'random']]
24 | 
25 | df_llms = reduce(lambda left,right: pd.merge(left,right,on=['task'], how='outer'), [df_llama_70b, df_centaur_70b, df_random])
26 | 
27 | for index, row in df_llms.iterrows():
28 |     df_llms.loc[index, 'num_actions'] = df[df['task'] == row['task']]['num_actions'].item()
29 |     df_llms.loc[index, 'num_participants'] = df[df['task'] == row['task']]['num_participants'].item()
30 |     df_llms.loc[index, 'num_choices'] = df[df['task'] == row['task']]['num_choices'].item()
31 |     df_llms.loc[index, 'num_characters'] = df[df['task'] == row['task']]['num_characters'].item()
32 |     df_llms.loc[index, 'task_type'] = df[df['task'] == row['task']]['task_type'].item()
33 |     df_llms.loc[index, 'split'] = df[df['task'] == row['task']]['split'].item()
34 | 
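# McFadden's pseudo-R^2 relative to random guessing: given negative log-likelihoods,
# pseudo-R^2 = 1 - nll_model / nll_random (equivalently 1 - ll_model / ll_random).
# It is 1 for a perfect model and 0 for a model no better than chance. For example,
# with hypothetical per-choice values nll_model = 0.60 and nll_random = 0.69,
# pseudo-R^2 = 1 - 0.60 / 0.69 ≈ 0.13.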
35 | ll_centaur = -df_llms['marcelbinz/Llama-3.1-Centaur-70B-adapter']
36 | ll_llama = -df_llms['unsloth/Meta-Llama-3.1-70B-bnb-4bit']
37 | ll_random = -df_llms['random']
38 | df_llms['r2_centaur'] = 1 - (ll_centaur/ll_random)
39 | df_llms['r2_llama'] = 1 - (ll_llama/ll_random)
40 | df_llms['r2_delta'] = df_llms['r2_centaur'] - df_llms['r2_llama']
41 | 
42 | print((df_llms['r2_delta'].values < 0).sum())
43 | print(df_llms['r2_delta'].values.mean())
44 | print(df_llms['r2_delta'].values.std())
45 | 
46 | result = sm.ols(formula="r2_delta ~ task_type + num_participants + num_choices + num_characters - 1", data=df_llms).fit()
47 | print(result.summary())
48 | 
49 | result.params = result.params.set_axis(['Decision-making', 'Markov decision processes', 'Memory', 'Miscellaneous', 'Multi-armed bandits', 'Supervised learning', 'Participants', 'Choices', 'Characters'])
50 | result.bse = result.bse.set_axis(['Decision-making', 'Markov decision processes', 'Memory', 'Miscellaneous', 'Multi-armed bandits', 'Supervised learning', 'Participants', 'Choices', 'Characters'])
51 | 
52 | 
53 | plt.style.use(['nature'])
54 | fig = plt.figure(figsize=(7.08661, 3))
55 | 
56 | gs = gridspec.GridSpec(1, 2, width_ratios=[0.5, 0.5])
57 | 
58 | ax = fig.add_subplot(gs[:, 0])
59 | result.params[:6].plot(kind='bar', yerr=result.bse[:6], ax=ax, legend=False, color='grey', alpha=0.8)
60 | ax.text(-0.11, 1.11, 'a', transform=ax.transAxes, fontsize=8, fontweight='bold', va='top') # Add label (a)
61 | ax.set_ylim(0, 0.21)
62 | 
63 | ax = fig.add_subplot(gs[:, 1])
64 | result.params[6:].plot(kind='bar', yerr=result.bse[6:], ax=ax, legend=False, color='grey', alpha=0.8)
65 | ax.text(-0.11, 1.11, 'b', transform=ax.transAxes, fontsize=8, fontweight='bold', va='top') # Add label (b)
66 | ax.set_ylim(0, 3e-5)
67 | 
68 | plt.tight_layout()
69 | sns.despine()
70 | plt.savefig('figures/fig5.pdf', bbox_inches='tight')
71 | plt.show()
72 | 
--------------------------------------------------------------------------------
/camera_ready/fig5.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import matplotlib.pyplot as plt
3 | import seaborn as sns
4 | import numpy as np
5 | import math
6 | import scienceplots
7 | import matplotlib.gridspec as gridspec
8 | from functools import reduce
9 | import torch
10 | from scipy import stats
11 | from datasets import load_dataset
12 | import matplotlib as mpl
13 | 
14 | test_data = load_dataset("marcelbinz/Psych-101-test")['test']
15 | test_participants = test_data.filter(lambda example: example['experiment'] == "hilbig2014generalized/exp1.csv")['participant']
16 | test_participants = [int(a) for a in test_participants]
17 | 
18 | nll_centaur = torch.stack(torch.load('data/log_likelihoods_marcelbinz-Llama-3.1-Centaur-70B-adapter.pth', weights_only=True))
19 | AIC_centaur_test = 2 * nll_centaur[test_participants].sum().item()
20 | print('Centaur AIC (test):', AIC_centaur_test)
21 | 
22 | 
23 | nll_cog = torch.load('data/cognitive_nlls.pth', weights_only=True)
24 | print(nll_cog.shape)
25 | AIC_cog = 2 * torch.Tensor([1, 1, 1, 1, 2]).unsqueeze(-1) + 2 * nll_cog[:5].sum(-1)
26 | print(AIC_cog[:, test_participants])
27 | AIC_cog_test = AIC_cog[:, test_participants].sum(-1)
28 | print('Cognitive model AIC (test):', AIC_cog_test)
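# AIC = 2k - 2 ln L = 2k + 2 * NLL, where k is the number of free parameters. The
# torch.Tensor([1, 1, 1, 1, 2]) term above supplies k for the five cognitive models;
# Centaur's test AIC is 2 * NLL alone, i.e. apparently with no parameters fit to the
# held-out participants.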
29 | 
30 | #AIC_cog_test = torch.cat([AIC_cog_test[:3], torch.zeros(3), AIC_cog_test[:4], torch.zeros(3), AIC_cog_test])
31 | labels = [
32 |     'Weighted-additive strategy',
33 |     'Equal weighting',
34 |     'Take-the-best heuristic',
35 |     'DeepSeek-R1 discovered',
36 |     'Scientific regret minimization'
37 | ]
38 | print(AIC_cog_test)
39 | 
40 | 
41 | plt.style.use(['nature'])
42 | mpl.rcParams.update({
43 |     "pdf.fonttype": 42,
44 |     "text.usetex": False,
45 | })
46 | gs = gridspec.GridSpec(1, 5, width_ratios=[0.175, 0.1, 0.25, 0.1, 0.325])
47 | fig = plt.figure(figsize=(7.20472, 2.8))
48 | 
49 | ax = fig.add_subplot(gs[0, 0])
50 | ax.bar(np.arange(AIC_cog_test[:3].shape[0]), AIC_cog_test[:3], color='#cbc9e2', width=0.75)
51 | ax.set_xticks(np.arange(AIC_cog_test[:3].shape[0]), labels[:3], rotation=90)
52 | ax.axhline(y=AIC_centaur_test, color='#69005f', linestyle='--', linewidth=1.0, alpha=0.8)
53 | ax.set_ylabel('AIC')
54 | ax.set_ylim(0, 420)
55 | 
56 | 
57 | ax = fig.add_subplot(gs[0, 2])
58 | ax.bar(np.arange(AIC_cog_test[:4].shape[0]), AIC_cog_test[:4], color='#cbc9e2', width=0.75)
59 | ax.set_xticks(np.arange(AIC_cog_test[:4].shape[0]), labels[:4], rotation=90)
60 | ax.axhline(y=AIC_centaur_test, color='#69005f', linestyle='--', linewidth=1.0, alpha=0.8)
61 | ax.set_ylabel('AIC')
62 | ax.set_ylim(0, 420)
63 | 
64 | ax = fig.add_subplot(gs[0, 4])
65 | ax.bar(np.arange(AIC_cog_test[:5].shape[0]), AIC_cog_test[:5], color='#cbc9e2', width=0.75)
66 | ax.set_xticks(np.arange(AIC_cog_test[:5].shape[0]), labels[:5], rotation=90)
67 | ax.axhline(y=AIC_centaur_test, color='#69005f', linestyle='--', linewidth=1.0, alpha=0.8)
68 | ax.set_ylabel('AIC')
69 | ax.set_ylim(0, 420)
70 | 
71 | fig.text(0.012, 0.955, 'a', fontsize=8, weight='bold')
72 | fig.text(0.338, 0.955, 'b', fontsize=8, weight='bold')
73 | fig.text(0.714, 0.955, 'c', fontsize=8, weight='bold')
74 | fig.text(0.988, 0.661, 'Centaur', fontsize=6, color='#69005f', horizontalalignment='right', alpha=0.8)
75 | 
76 | '''
77 | plt.ylabel('AIC')
78 | plt.axhline(y=AIC_centaur_test, xmin=0, xmax=0.168, color='#69005f', linestyle='--', linewidth=1.0, alpha=0.8)
79 | plt.axhline(y=AIC_centaur_test, xmin=0.332, xmax=0.555, color='#69005f', linestyle='--', linewidth=1.0, alpha=0.8)
80 | plt.axhline(y=AIC_centaur_test, xmin=0.72, xmax=1, color='#69005f', linestyle='--', linewidth=1.0, alpha=0.8)
81 | 
82 | plt.xticks([0, 1, 2, 6, 7, 8, 9, 13, 14, 15, 16, 17], labels, rotation=90)
83 | plt.ylim(0)
84 | plt.xlim(-0.5, 17.5)'''
85 | 
86 | 
87 | sns.despine()
88 | plt.tight_layout()
89 | plt.savefig('figures/fig5_part1.pdf', bbox_inches='tight')
90 | plt.show()
91 | 
--------------------------------------------------------------------------------
/plots/fig3.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import matplotlib.pyplot as plt
3 | import seaborn as sns
4 | import numpy as np
5 | import math
6 | import scienceplots
7 | import matplotlib.gridspec as gridspec
8 | from functools import reduce
9 | 
10 | plot_8b = False
11 | 
12 | df_llama_70b = pd.read_csv('../results/all_data_unsloth-Meta-Llama-3.1-70B-bnb-4bit.csv')
13 | df_llama_70b = df_llama_70b[df_llama_70b['unseen'] == 'experiments'][['task', 'unsloth/Meta-Llama-3.1-70B-bnb-4bit']]
14 | 
15 | df_centaur_70b = pd.read_csv('../results/all_data_marcelbinz-Llama-3.1-Centaur-70B-adapter.csv')
16 | df_centaur_70b = df_centaur_70b[df_centaur_70b['unseen'] == 'experiments'][['task', 'marcelbinz/Llama-3.1-Centaur-70B-adapter']]
17 | 
18 | df_baseline = pd.read_csv('../results/all_data_baseline.csv')
19 | df_baseline = df_baseline[df_baseline['unseen'] == 'experiments'][['task', 'baseline']]
20 | 
21 | df_random = pd.read_csv('../results/all_data_random.csv')
22 | df_random = df_random[df_random['unseen'] == 'experiments'][['task', 'random']]
23 | 
24 | if plot_8b:
25 |     df_llama_8b = pd.read_csv('../results/all_data_unsloth-Meta-Llama-3.1-8B-bnb-4bit.csv')
26 |     df_llama_8b = df_llama_8b[df_llama_8b['unseen'] == 'experiments'][['task', 'unsloth/Meta-Llama-3.1-8B-bnb-4bit']]
27 | 
28 |     df_centaur_8b = pd.read_csv('../results/all_data_marcelbinz-Llama-3.1-Centaur-8B-adapter.csv')
29 |     df_centaur_8b = df_centaur_8b[df_centaur_8b['unseen'] == 'experiments'][['task', 'marcelbinz/Llama-3.1-Centaur-8B-adapter']]
30 | 
31 |     df = reduce(lambda left,right: pd.merge(left,right,on=['task'], how='outer'), [df_llama_70b, df_centaur_70b, df_baseline, df_llama_8b, df_centaur_8b, df_random])
32 | else:
33 |     df = reduce(lambda left,right: pd.merge(left,right,on=['task'], how='outer'), [df_llama_70b, df_centaur_70b, df_baseline, df_random])
34 | 
35 | df2 = pd.read_csv('../results/fig3_data.csv')
36 | print(df2)
37 | print(df)
38 | gs = gridspec.GridSpec(1, 3, width_ratios=[0.3333, 0.3333, 0.3333])
39 | if plot_8b:
40 |     model_names = ['random', 'marcelbinz/Llama-3.1-Centaur-70B-adapter', 'unsloth/Meta-Llama-3.1-70B-bnb-4bit', 'baseline', 'marcelbinz/Llama-3.1-Centaur-8B-adapter', 'unsloth/Meta-Llama-3.1-8B-bnb-4bit']
41 | else:
42 |     model_names = ['random', 'marcelbinz/Llama-3.1-Centaur-70B-adapter', 'unsloth/Meta-Llama-3.1-70B-bnb-4bit', 'baseline']
43 | 
44 | plt.style.use(['nature'])
45 | fig = plt.figure(figsize=(7.08661, 1.9))
46 | for task_index, task in enumerate(df['task']):
47 |     print(task)
48 |     scale = 1 if task_index == 1 else 0.5
49 |     df_task = df[df['task'] == task][model_names].values.flatten()
50 |     ll_random = -df_task[0]
51 |     df_task = 1 - (-df_task[1:]/ll_random)
52 |     df_task[df_task != df_task] = 0  # NaN != NaN, so this replaces missing pseudo-R2 values with 0
53 |     print(df_task)
54 |     ax = fig.add_subplot(gs[:, task_index])
55 | 
56 |     if plot_8b:
57 |         ax.bar(np.arange(5), df_task, color=['#69005f', '#ff506e', '#cbc9e2', 'C0', 'C1'])
58 |         ax.set_xticks(np.arange(5), ['Centaur\n(70B)', 'Llama\n(70B)', 'Cognitive\nmodel', 'Centaur\n(8B)', 'Llama\n(8B)'], size=5)
59 |     else:
60 |         ax.bar(np.arange(3), df_task, color=['#69005f', '#ff506e', '#cbc9e2'])
61 |         ax.set_xticks(np.arange(3), ['Centaur', 'Llama', 'Cognitive\nmodel'])
62 | 
63 |     if task_index == 2:
64 |         if plot_8b:
65 |             ax.text(0.45, 0.15, 'N/A', transform=ax.transAxes, va='top')
66 |         else:
67 |             ax.text(0.775, 0.15, 'N/A', transform=ax.transAxes, va='top')
68 |     if task_index == 0:
69 |         ax.set_ylabel(r'Pseudo-R$^2$')
70 |     ax.containers[0][0].set_alpha(0.8)
71 |     ax.containers[0][1].set_alpha(0.8)
72 |     ax.containers[0][2].set_alpha(1)
73 |     if plot_8b:
74 |         ax.containers[0][3].set_alpha(0.8)
75 |         ax.containers[0][4].set_alpha(0.8)
76 | 
77 | sns.despine()
78 | plt.tight_layout()
79 | plt.savefig('figures/fig3_8b=' + str(plot_8b) + '.pdf', bbox_inches='tight')
80 | plt.show()
81 | 
--------------------------------------------------------------------------------
/plots/fig14.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import matplotlib.pyplot as plt
3 | import torch
4 | import seaborn as sns
5 | import scienceplots
6 | import matplotlib.gridspec as gridspec
7 | import matplotlib.colors as colors
8 | import numpy as np
9 | import nilearn
10 | import nibabel as nib
11 | from nilearn import plotting, datasets, surface, image
12 | from nilearn.plotting import plot_stat_map
13 | from matplotlib.lines import Line2D
14 | import matplotlib.cm as cm
15 | import matplotlib.colors as mcolors
16 | 
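# Layout of this figure, as far as it can be read off the code below: panel (a) shows a
# pre-rendered brain map ('test.png') together with a standalone Pearson-correlation
# colorbar; panel (b) compares Centaur, Llama, a cognitive model, and a random control
# on their alignment (Pearson correlation) with four region groups of the Schaefer
# parcellation, using the feher2023rethinking two-step task results.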
17 | gs = gridspec.GridSpec(1, 2, width_ratios=[0.33333, 0.66666])
18 | plt.style.use(['nature'])
19 | fig = plt.figure(figsize=(7.08661, 7.08661/3))
20 | 
21 | # subplot 1
22 | ax = fig.add_subplot(gs[:, 0])
23 | img = plt.imread('test.png')  # named 'img' so that nilearn's image module is not shadowed
24 | cax = ax.imshow(img)
25 | ax.axis('off')
26 | ax.set_title('Pearson correlation', y=1.19)
27 | 
28 | # Define a colormap
29 | cmap = sns.color_palette("rocket_r", as_cmap=True)
30 | 
31 | # Create a norm (normalization of values)
32 | norm = mcolors.Normalize(vmin=0, vmax=0.8)
33 | 
34 | # Create a ScalarMappable with colormap and norm
35 | sm = cm.ScalarMappable(cmap=cmap, norm=norm)
36 | sm.set_array([]) # Required for the colorbar
37 | 
38 | cbar = plt.colorbar(sm, ax=ax, orientation="vertical")
39 | ax.text(-0.13, 1.32, 'a', transform=ax.transAxes, fontsize=8, fontweight='bold', va='top')
40 | 
41 | # subplot 2
42 | ax = fig.add_subplot(gs[:, 1])
43 | 
44 | layer = 2
45 | 
46 | df_centaur_tst = pd.read_csv('../results/feher2023rethinking/schaefer_tst_centaur_alignment.csv')
47 | df_llama_tst = pd.read_csv('../results/feher2023rethinking/schaefer_tst_llama_alignment.csv')
48 | df_random_tst = pd.read_csv('../results/feher2023rethinking/schaefer_tst_random_alignment.csv')
49 | df_cognitive_tst = pd.read_csv('../results/feher2023rethinking/schaefer_tst_cognitive_alignment.csv')
50 | 
51 | print(df_cognitive_tst)
52 | 
53 | columns_to_keep = [
54 |     ['Left Accumbens', 'Right Accumbens'],
55 |     ["b'7Networks_LH_Limbic_OFC_1'", "b'7Networks_LH_Default_pCunPCC_1'", "b'7Networks_RH_Limbic_OFC_1'", "b'7Networks_RH_Default_PFCdPFCm_1'"],
56 |     ["b'7Networks_LH_SomMot_1'", "b'7Networks_LH_SomMot_2'", "b'7Networks_LH_SomMot_3'", "b'7Networks_LH_SomMot_4'", "b'7Networks_LH_SomMot_5'", "b'7Networks_LH_SomMot_6'"],
57 |     ["b'7Networks_LH_Vis_1'", "b'7Networks_LH_Vis_2'", "b'7Networks_LH_Vis_3'", "b'7Networks_LH_Vis_4'", "b'7Networks_LH_Vis_5'", "b'7Networks_LH_Vis_6'", "b'7Networks_LH_Vis_7'", "b'7Networks_LH_Vis_8'", "b'7Networks_LH_Vis_9'", "b'7Networks_RH_Vis_1'", "b'7Networks_RH_Vis_2'", "b'7Networks_RH_Vis_3'", "b'7Networks_RH_Vis_4'", "b'7Networks_RH_Vis_5'", "b'7Networks_RH_Vis_6'", "b'7Networks_RH_Vis_7'", "b'7Networks_RH_Vis_8'"],
58 | ]
59 | 
60 | df_centaur_tst = df_centaur_tst.iloc[layer]
61 | df_llama_tst = df_llama_tst.iloc[layer]
62 | df_random_tst = df_random_tst.iloc[layer]
63 | df_cognitive_tst = df_cognitive_tst.iloc[0]
64 | print(df_cognitive_tst)
65 | for i, column in enumerate(columns_to_keep):
66 |     ax.bar(np.array([-1/5, 0, 1/5, 0.4]) + i, [df_centaur_tst[column].values.mean(), df_llama_tst[column].values.mean(), df_cognitive_tst[column].values.mean(), df_random_tst[column].values.mean()], alpha=0.8, color=['#69005f', '#ff506e', '#cbc9e2', 'grey'], width=1/5)
67 | ax.set_ylabel('Pearson correlation')
68 | ax.set_xticks([0.1, 1.1, 2.1, 3.1], ['Accumbens', 'Medial PFC', 'Motor Cortex', 'Visual Cortex'])
69 | 
70 | custom_lines_r2 = [Line2D([0], [0], color='#69005f', alpha=0.8, marker="o", linestyle='None', markersize=5),
71 |                    Line2D([0], [0], color='#ff506e', alpha=0.8, marker="o", linestyle='None', markersize=5),
72 |                    Line2D([0], [0], color='#cbc9e2', alpha=0.8, marker="o", linestyle='None', markersize=5),
73 |                    Line2D([0], [0], color='grey', alpha=0.8, marker="o", linestyle='None', markersize=5),]
74 | ax.legend(custom_lines_r2, ['Centaur', 'Llama', 'Cognitive model', 'Control'], frameon=False, ncols=4, bbox_to_anchor=(0.5, 1.3), loc='upper center')
75 | ax.text(-0.06, 1.24, 'b', transform=ax.transAxes, fontsize=8, fontweight='bold', 
va='top') 76 | 77 | sns.despine() 78 | 79 | plt.tight_layout() 80 | plt.savefig('figures/fig14.pdf', bbox_inches='tight') 81 | 82 | plt.show() 83 | -------------------------------------------------------------------------------- /test_adapter_full_log_likelihoods.py: -------------------------------------------------------------------------------- 1 | from transformers import TrainingArguments 2 | from trl import SFTTrainer, DataCollatorForCompletionOnlyLM 3 | from unsloth import FastLanguageModel 4 | from datasets import load_dataset 5 | import pandas as pd 6 | import argparse 7 | import torch 8 | 9 | def full_log_likelihoods(logits, labels): 10 | with torch.no_grad(): 11 | logits = logits.float().cpu() 12 | labels = labels.cpu() 13 | labels = torch.cat((labels[0, 1:], -100 * torch.ones(1).long()), 0) 14 | logits = logits[0] 15 | ce = torch.nn.functional.cross_entropy(logits, labels, reduction='none') 16 | total_loss = [] 17 | item_loss = 0 18 | item_counter = 0 19 | for i in range(ce.shape[0]): 20 | if labels[i] != -100: 21 | item_loss += ce[i] 22 | item_counter += 1 23 | else: 24 | if item_counter != 0: 25 | total_loss.append(item_loss) 26 | item_loss = 0 27 | item_counter = 0 28 | return torch.Tensor(total_loss) 29 | 30 | 31 | def compute_metrics(pred): 32 | return {'custom_loss': pred.predictions} 33 | 34 | if __name__ == '__main__': 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument("--model", type=str, required=True) 37 | args = parser.parse_args() 38 | 39 | task_names = [ 40 | "badham2017deficits", 41 | "bahrami2020four", 42 | "enkavi2019adaptivenback", 43 | "enkavi2019digitspan", 44 | "enkavi2019gonogo", 45 | "enkavi2019recentprobes", 46 | "feng2021dynamics", 47 | "flesch2018comparing", 48 | "frey2017cct", 49 | "frey2017risk", 50 | "gershman2018deconstructing", 51 | "gershman2020reward", 52 | "hebart2023things", 53 | "hilbig2014generalized", 54 | "kool2016when", 55 | "kool2017cost", 56 | "lefebvre2017behavioural", 57 | "levering2020revisiting", 58 | "ludwig2023human", 59 | "peterson2021using", 60 | "plonsky2018when", 61 | "ruggeri2022globalizability", 62 | "sadeghiyeh2020temporal", 63 | "schulz2020finding", 64 | "somerville2017charting", 65 | "speekenbrink2008learning", 66 | "steingroever2015data", 67 | "tomov2020discovery", 68 | "tomov2021multitask", 69 | "waltz2020differential", 70 | "wilson2014humans", 71 | "wu2023chunking", 72 | "wulff2018description", 73 | "wulff2018sampling", 74 | "xiong2023neural", 75 | "zorowitz2023data", 76 | "collsiöö2023MCPL", 77 | "cox2017information", 78 | "garcia2023experiential", 79 | "jansen2021dunningkruger", 80 | "krueger2022identifying", 81 | "kumar2023disentangling", 82 | "popov2023intent", 83 | "wise2019acomputational", 84 | "wu2018generalisation", 85 | "zhu2020bayesian", 86 | ] 87 | 88 | model, tokenizer = FastLanguageModel.from_pretrained( 89 | model_name = args.model, 90 | max_seq_length = 32768, 91 | dtype = None, 92 | load_in_4bit = True, 93 | ) 94 | l_id = tokenizer(" <<").input_ids[1:] 95 | r_id = tokenizer(">>").input_ids[1:] 96 | collator = DataCollatorForCompletionOnlyLM(response_template=l_id, instruction_template=r_id, tokenizer=tokenizer) 97 | dataset = load_dataset("marcelbinz/Psych-101-test") 98 | is_quantized = model.is_quantized 99 | 100 | data = {} 101 | with torch.no_grad(): 102 | for task_name in task_names: 103 | eval_dataset = dataset['test'].filter(lambda example: example['experiment'].startswith(task_name)) 104 | 105 | model.is_quantized = False 106 | training_args = TrainingArguments( 107 | 
output_dir="eval", 108 | per_device_eval_batch_size=1, 109 | eval_accumulation_steps=1 110 | ) 111 | trainer = SFTTrainer( 112 | model=model, 113 | tokenizer=tokenizer, 114 | args=training_args, 115 | train_dataset=eval_dataset, 116 | eval_dataset=eval_dataset, 117 | dataset_text_field="text", 118 | max_seq_length=32768, 119 | data_collator=collator, 120 | compute_metrics=compute_metrics, 121 | preprocess_logits_for_metrics=full_log_likelihoods, 122 | ) 123 | model.is_quantized = is_quantized 124 | result = trainer.evaluate() 125 | 126 | print(task_name, flush=True) 127 | print(result, flush=True) 128 | data[task_name] = result['eval_custom_loss'] 129 | 130 | torch.save(data, 'results/custom_metrics_full_log_likelihoods_' + args.model.replace('/', '-') + '.pth') -------------------------------------------------------------------------------- /plots/fig7.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import scienceplots 5 | import matplotlib.gridspec as gridspec 6 | import seaborn as sns 7 | from matplotlib.lines import Line2D 8 | 9 | gs = gridspec.GridSpec(1, 2, width_ratios=[0.5, 0.5]) 10 | 11 | agents = [ 12 | 'Centaur', 13 | 'Llama', 14 | 'Human', 15 | ] 16 | 17 | plt.style.use(['nature']) 18 | fig = plt.figure(figsize=(7.08661, 3)) 19 | 20 | color_1 = '#69005f' 21 | color_2 = '#ff506e' 22 | custom_lines_r2 = [Line2D([0], [0], color=color_1, alpha=0.8, marker="o", linestyle='None', markersize=5), Line2D([0], [0], color=color_2, alpha=0.8, marker="o", linestyle='None', markersize=5)] 23 | 24 | ax = fig.add_subplot(gs[:, 0]) 25 | df = pd.read_csv('../results/CogBench/performance.csv') 26 | df = df[df['Agent'].isin(agents)] 27 | ci = [col for col in df.columns if '_ci' in col] 28 | not_ci = [col for col in df.columns if not '_ci' in col] 29 | df_T = df[not_ci].drop(['Unnamed: 0'], axis=1).set_index('Agent').T 30 | df.columns = df.columns.str.rstrip('_x') 31 | df_T_ci = df[ci] 32 | df_T_ci.columns = df_T_ci.columns.str.rstrip('_ci') 33 | df_T_ci = df_T_ci.set_index('Agent').T 34 | df_T = df_T.drop(columns=['Human']) 35 | df_T.index = df_T.index.str.replace('Probabilistic Reasoning', 'Probabilistic reasoning') 36 | df_T.index = df_T.index.str.replace('Horizon Task', 'Horizon task') 37 | df_T.index = df_T.index.str.replace('Restless Bandit', 'Restless bandit') 38 | df_T.index = df_T.index.str.replace('Instrumental Learning', 'Instrumental learning') 39 | df_T.index = df_T.index.str.replace('BART', 'Balloon analog risk task') 40 | df_T.index = df_T.index.str.replace('Two Step Task', 'Two-step task') 41 | df_T_ci.index = df_T.index.str.replace('Probabilistic Reasoning', 'Probabilistic reasoning') 42 | df_T_ci.index = df_T.index.str.replace('Horizon Task', 'Horizon task') 43 | df_T_ci.index = df_T.index.str.replace('Restless Bandit', 'Restless bandit') 44 | df_T_ci.index = df_T.index.str.replace('Instrumental Learning', 'Instrumental learning') 45 | df_T_ci.index = df_T.index.str.replace('BART', 'Balloon analog risk task') 46 | df_T_ci.index = df_T.index.str.replace('Two Step Task', 'Two-step task') 47 | df_T = df_T[['Centaur', 'Llama']] 48 | df_T_ci = df_T_ci[['Centaur', 'Llama']] 49 | ax.text(-0.06, 1.2, 'a', transform=ax.transAxes, fontsize=8, fontweight='bold', va='top') # Add label (b) 50 | ax.text(5.5, 1.1, 'Humans', fontsize=6, color='grey', horizontalalignment='right') 51 | ax.text(5.5, -0.3, 'Random', fontsize=6, color='grey', horizontalalignment='right') 52 | 
52 | df_T.plot(kind='bar', yerr=df_T_ci, ax=ax, legend=False, color=['#69005f', '#ff506e'], alpha=0.8)
53 | ax.legend(custom_lines_r2, ['Centaur', 'Llama'], frameon=False, ncols=3, bbox_to_anchor=(0.5, 1.3), loc='upper center')
54 | ax.set_ylim(-0.6, 2.2)
55 | ax.set_ylabel('Performance')
56 | ax.hlines(y=1, xmin=-1, xmax=20, color='grey', linestyle='--', linewidth=1.0)
57 | ax.hlines(y=0, xmin=-1, xmax=20, color='grey', linestyle='--', linewidth=1.0)
58 | 
59 | ax = fig.add_subplot(gs[:, 1])
60 | df = pd.read_csv('../results/CogBench/behaviour.csv')
61 | df = df[df['Agent'].isin(agents)]
62 | ci = [col for col in df.columns if '_ci' in col]
63 | not_ci = [col for col in df.columns if not '_ci' in col]
64 | df_T = df[not_ci].drop(['Unnamed: 0'], axis=1).set_index('Agent').T
65 | df.columns = df.columns.str.replace('_x$', '', regex=True)
66 | df_T_ci = df[ci]
67 | df_T_ci.columns = df_T_ci.columns.str.replace('_ci$', '', regex=True)
68 | df_T_ci = df_T_ci.set_index('Agent').T
69 | df_T = df_T.drop(columns=['Human'])
70 | df_T = df_T[['Centaur', 'Llama']]
71 | df_T_ci = df_T_ci[['Centaur', 'Llama']]
72 | ax.text(-0.06, 1.2, 'b', transform=ax.transAxes, fontsize=8, fontweight='bold', va='top') # Add label (b)
73 | ax.text(9.5, 1.1, 'Humans', fontsize=6, color='grey', horizontalalignment='right')
74 | ax.text(9.5, -0.3, 'Random', fontsize=6, color='grey', horizontalalignment='right')
75 | df_T.plot(kind='bar', yerr=df_T_ci, ax=ax, legend=False, color=['#69005f', '#ff506e'], alpha=0.8)
76 | ax.set_ylim(-0.6, 2.1)
77 | ax.legend(custom_lines_r2, ['Centaur', 'Llama'], frameon=False, ncols=3, bbox_to_anchor=(0.5, 1.3), loc='upper center')
78 | ax.set_ylabel('Parameter value')
79 | ax.hlines(y=1, xmin=-1, xmax=20, color='grey', linestyle='--', linewidth=1.0)
80 | ax.hlines(y=0, xmin=-1, xmax=20, color='grey', linestyle='--', linewidth=1.0)
81 | 
82 | plt.tight_layout()
83 | sns.despine()
84 | plt.savefig('figures/fig7.pdf', bbox_inches='tight')
85 | plt.show()
86 | 
--------------------------------------------------------------------------------
/plots/fig8.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import matplotlib.pyplot as plt
3 | import seaborn as sns
4 | import numpy as np
5 | import math
6 | import scienceplots
7 | import matplotlib.gridspec as gridspec
8 | from functools import reduce
9 | from brokenaxes import brokenaxes
10 | 
11 | selected_tasks = ['peterson2021using', 'ruggeri2022globalizability']
12 | 
13 | df_random = pd.read_csv('../results/all_data_random.csv')
14 | df_random = df_random[df_random['task'].isin(selected_tasks)][['task', 'random']]
15 | 
16 | df_llama_70b = pd.read_csv('../results/all_data_unsloth-Meta-Llama-3.1-70B-bnb-4bit.csv')
17 | df_llama_70b = df_llama_70b[df_llama_70b['task'].isin(selected_tasks)][['task', 'unsloth/Meta-Llama-3.1-70B-bnb-4bit']]
18 | 
19 | df_centaur_70b = pd.read_csv('../results/all_data_marcelbinz-Llama-3.1-Centaur-70B-adapter.csv')
20 | df_centaur_70b = df_centaur_70b[df_centaur_70b['task'].isin(selected_tasks)][['task', 'marcelbinz/Llama-3.1-Centaur-70B-adapter']]
21 | 
22 | df_baseline = pd.read_csv('../results/all_data_baseline.csv')
23 | df_baseline = df_baseline[df_baseline['task'].isin(selected_tasks)][['task', 'baseline']]
24 | 
25 | df_centaur_70b_no_history = pd.read_csv('../ceiling/results/marcelbinz-Llama-3.1-Centaur-70B-adapter.csv', index_col=0)
26 | df_centaur_70b_no_history = df_centaur_70b_no_history.rename(columns={'marcelbinz/Llama-3.1-Centaur-70B-adapter': 
'marcelbinz/Llama-3.1-Centaur-70B-adapter-no-history'}) 27 | df_centaur_70b_no_history['task'] = df_centaur_70b_no_history['task'].str.replace('/prompts_zeroshot.jsonl','') 28 | 29 | df_llama_70b_no_history = pd.read_csv('../ceiling/results/unsloth-Meta-Llama-3.1-70B-bnb-4bit.csv', index_col=0) 30 | df_llama_70b_no_history = df_llama_70b_no_history.rename(columns={'unsloth/Meta-Llama-3.1-70B-bnb-4bit': 'unsloth/Meta-Llama-3.1-70B-bnb-4bit-no-history'}) 31 | df_llama_70b_no_history['task'] = df_llama_70b_no_history['task'].str.replace('/prompts_zeroshot.jsonl','') 32 | 33 | df_ceiling = pd.read_csv('../ceiling/results/ceiling.csv') 34 | df_ceiling = df_ceiling.rename(columns={'nll': 'ceiling'}) 35 | df_ceiling = df_ceiling[df_ceiling['task'].isin(selected_tasks)][['task', 'ceiling']] 36 | 37 | df = reduce(lambda left,right: pd.merge(left,right,on=['task'], how='outer'), [df_random, df_llama_70b, df_centaur_70b, df_centaur_70b_no_history, df_llama_70b_no_history, df_ceiling, df_baseline]) 38 | print(df) 39 | 40 | model_names = ['random', 'marcelbinz/Llama-3.1-Centaur-70B-adapter', 'unsloth/Meta-Llama-3.1-70B-bnb-4bit', 'baseline', 'ceiling', 'marcelbinz/Llama-3.1-Centaur-70B-adapter-no-history', 'unsloth/Meta-Llama-3.1-70B-bnb-4bit-no-history'] 41 | offsets = [0.01, 0.01] 42 | 43 | gs = gridspec.GridSpec(1, 2, width_ratios=[0.5, 0.5]) 44 | plt.style.use(['nature']) 45 | fig = plt.figure(figsize=(7.08661, 1.9)) 46 | for task_index, task in enumerate(df['task']): 47 | print(task) 48 | df_task = df[df['task'] == task][model_names].values.flatten() 49 | df_task = df_task[[1, 2, 3, 4, 5, 6, 0]] 50 | df_task[df_task != df_task] = 0 51 | print(df_task) 52 | cutoff = (2.67, 2.72) if task_index == 0 else (2.0, 2.05) 53 | ax = brokenaxes(ylims=((.4, .74), cutoff), subplot_spec=gs[task_index]) 54 | 55 | ax.bar(np.arange(6), df_task[:-1], color=['#69005f', '#ff506e', '#cbc9e2', 'white', '#69005f', '#ff506e']) 56 | ax.set_xticks(np.arange(6), ['Centaur', 'Llama', 'Cog.\nmodel', 'Noise\nceiling', 'Centaur\n(ind.)', 'Llama\n(ind.)',]) 57 | ax.axhline(y=df_task[-1], color='grey', linestyle='--', linewidth=1.0) 58 | ax.axs[-1].text(5.65, df_task[-1] + offsets[task_index], 'Random guessing', fontsize=6, color='grey', horizontalalignment='right') 59 | ax.set_title('choices13k' if task_index == 0 else 'Intertemporal choice', fontsize=8) 60 | fig.axes[-1].text(-0.13, 1.12, 'a' if task_index == 0 else 'b', fontsize=8, fontweight='bold', va='top') 61 | if task_index == 0: 62 | ax.set_ylabel('Negative log-likelihood') 63 | 64 | print(ax.containers) 65 | ax.axs[-1].containers[0][0].set_alpha(0.8) 66 | ax.axs[-1].containers[0][1].set_alpha(0.8) 67 | ax.axs[-1].containers[0][2].set_alpha(1) 68 | ax.axs[-1].containers[0][3].set_alpha(0.5) 69 | ax.axs[-1].containers[0][4].set_alpha(0.8) 70 | ax.axs[-1].containers[0][5].set_alpha(0.8) 71 | ax.axs[-1].containers[0][3].set_edgecolor('black') 72 | ax.axs[-1].containers[0][4].set_hatch('///') 73 | ax.axs[-1].containers[0][5].set_hatch('///') 74 | 75 | #sns.despine() 76 | #plt.tight_layout() 77 | plt.savefig('figures/fig8.pdf', bbox_inches='tight') 78 | plt.show() 79 | -------------------------------------------------------------------------------- /extended_data/ed_fig3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import matplotlib.pyplot as plt 4 | import seaborn as sns 5 | import scienceplots 6 | import matplotlib.gridspec as gridspec 7 | import numpy as np 8 | from scipy import stats 9 | 10 | 
centaur_70b = torch.load('../generalization/results/additional_generalization_full_log_likelihoods_marcelbinz-Llama-3.1-Centaur-70B-adapter.pth') 11 | centaur_8b = torch.load('../generalization/results/additional_generalization_full_log_likelihoods_marcelbinz-Llama-3.1-Centaur-8B-adapter.pth') 12 | llama_70b = torch.load('../generalization/results/additional_generalization_full_log_likelihoods_unsloth-Meta-Llama-3.1-70B-bnb-4bit.pth') 13 | llama_8b = torch.load('../generalization/results/additional_generalization_full_log_likelihoods_unsloth-Meta-Llama-3.1-8B-bnb-4bit.pth') 14 | 15 | nll_random = { 16 | 'additional_experiments/awad2018moral.jsonl': -math.log(1/2), 17 | 'additional_experiments/demircan2024evaluatingcategory.jsonl': -math.log(1/2), 18 | 'additional_experiments/demircan2024evaluatingreward.jsonl': -math.log(1/2), 19 | 'additional_experiments/akata2023repeatedgames.jsonl': -math.log(1/2), 20 | 'additional_experiments/singh2022representing.jsonl': -math.log(1/7), 21 | 'additional_experiments/xu2021novelty.jsonl': -math.log(1/3), 22 | } 23 | 24 | task_names = { 25 | 'additional_experiments/awad2018moral.jsonl': 'Moral decision-making', 26 | 'additional_experiments/demircan2024evaluatingcategory.jsonl': 'Naturalistic category learning', 27 | 'additional_experiments/demircan2024evaluatingreward.jsonl': 'Naturalistic reward learning', 28 | 'additional_experiments/akata2023repeatedgames.jsonl': 'Economic games', 29 | 'additional_experiments/singh2022representing.jsonl': 'Behavioral propensities', 30 | 'additional_experiments/xu2021novelty.jsonl': 'Deep sequential decision task', 31 | } 32 | 33 | gs = gridspec.GridSpec(2, 3, width_ratios=[0.3333, 0.3333, 0.3333]) 34 | plt.style.use(['nature']) 35 | fig = plt.figure(figsize=(7.20472, 3.8)) 36 | 37 | offsets = [0.01, 0.01, 0.005, 0.01, 0.02, 0.01] 38 | for i, key in enumerate(centaur_70b.keys()): 39 | print(key) 40 | centaur_70b_r2 = centaur_70b[key].mean().item() 41 | centaur_8b_r2 = centaur_8b[key].mean().item() 42 | llama_70b_r2 = llama_70b[key].mean().item() 43 | llama_8b_r2 = llama_8b[key].mean().item() 44 | centaur_70b_r2_se = centaur_70b[key].std().item() / math.sqrt(len(centaur_70b[key])) 45 | centaur_8b_r2_se = centaur_8b[key].std().item() / math.sqrt(len(centaur_8b[key])) 46 | llama_70b_r2_se = llama_70b[key].std().item() / math.sqrt(len(llama_70b[key])) 47 | llama_8b_r2_se = llama_8b[key].std().item() / math.sqrt(len(llama_8b[key])) 48 | res = stats.ttest_ind(centaur_70b[key], llama_70b[key], alternative='less') 49 | 50 | print('t(' + str(int(res.df)) + ') = ' + str(np.round(res.statistic, 2)) + ', p = ' + str(np.round(res.pvalue, 6))) 51 | 52 | ax = fig.add_subplot(gs[0 if i < 3 else 1, i % 3]) 53 | values = np.array([centaur_70b_r2, centaur_8b_r2, llama_70b_r2, llama_8b_r2]) 54 | ax.bar(np.arange(4), values, yerr=[centaur_70b_r2_se, centaur_8b_r2_se, llama_70b_r2_se, llama_8b_r2_se], color=['#69005f', '#69005f', '#ff506e', '#ff506e'])# 'C0', 'C1' 55 | ax.set_xticks(np.arange(4), ['Centaur', 'Minitaur', 'Llama\n(70B)', 'Llama\n(8B)']) 56 | ax.axhline(y=nll_random[key], color='grey', linestyle='--', linewidth=1.0) 57 | ax.text(3.5, nll_random[key] + offsets[i], 'Random guessing', fontsize=6, color='grey', horizontalalignment='right') 58 | 59 | 60 | 61 | if i == 2: 62 | ax.set_ylim(0.58, 0.82) 63 | ax.set_yticks([0.6, 0.7, 0.8]) 64 | else: 65 | ax.set_ylim(0.9 * min(nll_random[key], min(values)), 1.1 * max(max(values), nll_random[key])) 66 | 67 | ax.set_ylabel('Negative log-likelihood') 68 | ax.containers[1][0].set_alpha(0.8) 
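    # with yerr supplied, ax.containers[0] is the errorbar group and ax.containers[1] the bar patches; the 8B bars are faded relative to the 70B ones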
69 | ax.containers[1][1].set_alpha(0.5) 70 | ax.containers[1][2].set_alpha(0.8) 71 | ax.containers[1][3].set_alpha(0.5) 72 | ax.set_title(task_names[key], fontsize=7) 73 | 74 | fig.text(0.012, 0.955, 'a', fontsize=8, weight='bold') 75 | fig.text(0.012, 0.465, 'b', fontsize=8, weight='bold') 76 | fig.text(0.344, 0.955, 'c', fontsize=8, weight='bold') 77 | fig.text(0.344, 0.465, 'd', fontsize=8, weight='bold') 78 | fig.text(0.67, 0.955, 'e', fontsize=8, weight='bold') 79 | fig.text(0.67, 0.465, 'f', fontsize=8, weight='bold') 80 | 81 | sns.despine() 82 | plt.tight_layout() 83 | plt.savefig('figures/fig3.jpg', bbox_inches='tight', dpi=300) 84 | plt.show() 85 | -------------------------------------------------------------------------------- /camera_ready/fig3.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | import numpy as np 5 | import math 6 | import scienceplots 7 | import matplotlib.gridspec as gridspec 8 | from functools import reduce 9 | import torch 10 | 11 | from scipy import stats 12 | import matplotlib as mpl 13 | 14 | def cohen_d(x, y): 15 | return (np.mean(x) - np.mean(y)) / math.sqrt((np.std(x, ddof=1) ** 2 + np.std(y, ddof=1) ** 2) / 2.0) 16 | 17 | centaur_70b = torch.load('../generalization/results/generalization_full_log_likelihoods_marcelbinz-Llama-3.1-Centaur-70B-adapter.pth') 18 | llama_70b = torch.load('../generalization/results/generalization_full_log_likelihoods_unsloth-Meta-Llama-3.1-70B-bnb-4bit.pth') 19 | 20 | df_baseline = pd.read_csv('../results/all_data_baseline.csv') 21 | df_baseline = df_baseline[df_baseline['unseen'] == 'experiments'][['task', 'baseline']] 22 | 23 | df_random = pd.read_csv('../results/all_data_random.csv') 24 | df_random = df_random[df_random['unseen'] == 'experiments'][['task', 'random']] 25 | 26 | means = {} 27 | sems = {} 28 | for key in centaur_70b.keys(): 29 | print(key) 30 | print(centaur_70b[key].shape) 31 | baseline = df_baseline[df_baseline['task'] == key] 32 | random = df_random[df_random['task'] == key] 33 | means[key] = [] 34 | sems[key] = [] 35 | means[key].append(centaur_70b[key].mean()) 36 | means[key].append(llama_70b[key].mean()) 37 | sems[key].append(centaur_70b[key].std() / math.sqrt(len(centaur_70b[key]))) 38 | sems[key].append(llama_70b[key].std() / math.sqrt(len(llama_70b[key]))) 39 | 40 | 41 | print(stats.ttest_ind(centaur_70b[key], llama_70b[key], alternative='less')) 42 | print("Cohen's d:", cohen_d(centaur_70b[key], llama_70b[key])) 43 | 44 | if len(baseline) > 0: 45 | means[key].append(baseline.baseline.item()) 46 | print(centaur_70b[key].shape) 47 | print(stats.ttest_1samp(centaur_70b[key], llama_70b[key].mean().item(), alternative='less')) 48 | print(stats.ttest_1samp(centaur_70b[key], baseline.baseline.item(), alternative='less')) 49 | print(stats.ttest_1samp(llama_70b[key], baseline.baseline.item(), alternative='less')) 50 | else: 51 | means[key].append(0) 52 | print(means) 53 | sems[key].append(0) 54 | means[key].append(random.random.item()) 55 | sems[key].append(0) 56 | print() 57 | 58 | plt.style.use(['nature']) 59 | mpl.rcParams.update({ 60 | "pdf.fonttype": 42, 61 | # fonttype 42 embeds fonts as TrueType so text in the exported PDF stays editable 62 | "text.usetex": False, 63 | }) 64 | fig = plt.figure(figsize=(7.20472, 4.5)) 65 | gs = gridspec.GridSpec(2, 3, width_ratios=[0.3333, 0.3333, 0.3333], height_ratios=[1.25, 1]) 66 | 67 | ax = fig.add_subplot(gs[0, 0]) 68 | #image = plt.imread('tstcover.png') 69 | #cax = ax.imshow(image) 70 |
ax.axis('off') 71 | ax.set_title('Modified cover story', fontsize=7) 72 | 73 | ax = fig.add_subplot(gs[0, 1]) 74 | #image = plt.imread('bandit.png') 75 | #cax = ax.imshow(image) 76 | ax.axis('off') 77 | ax.set_title('Modified problem structure', fontsize=7) 78 | 79 | ax = fig.add_subplot(gs[0, 2]) 80 | #image = plt.imread('logical2.png') 81 | #cax = ax.imshow(image) 82 | ax.axis('off') 83 | ax.set_title('Entirely novel domain', fontsize=7) 84 | 85 | 86 | #print(dfgdfgfd) 87 | 88 | offsets = [0.007, 0.021, 0.020] 89 | for task_index, task in enumerate(means.keys()): 90 | print(task) 91 | ax = fig.add_subplot(gs[1, task_index]) 92 | ax.bar(np.arange(3), means[task][:-1], yerr=sems[task][:-1], color=['#69005f', '#ff506e', '#cbc9e2']) 93 | ax.set_xticks(np.arange(3), ['Centaur', 'Llama', 'Cognitive\nmodel']) 94 | ax.axhline(y=means[task][-1], color='grey', linestyle='--', linewidth=1.0) 95 | ax.text(2.5, means[task][-1] + offsets[task_index], 'Random guessing', fontsize=6, color='grey', horizontalalignment='right') 96 | 97 | if task_index == 2: 98 | ax.text(0.775, 0.125, 'N/A', transform=ax.transAxes, va='top') 99 | 100 | ax.set_ylabel('Negative log-likelihood') 101 | ax.containers[1][0].set_alpha(0.8) 102 | ax.containers[1][1].set_alpha(0.8) 103 | ax.containers[1][2].set_alpha(1) 104 | ax.set_ylim(0.9 * means[task][0], 1.1 * means[task][-1]) 105 | 106 | fig.text(0.005, 0.961, 'a', fontsize=8, weight='bold') 107 | fig.text(0.34, 0.961, 'b', fontsize=8, weight='bold') 108 | fig.text(0.67, 0.961, 'c', fontsize=8, weight='bold') 109 | 110 | sns.despine() 111 | plt.tight_layout() 112 | plt.savefig('figures/fig3_part1.pdf', bbox_inches='tight') 113 | plt.show() 114 | -------------------------------------------------------------------------------- /extended_data/ed_fig5.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | import numpy as np 5 | import math 6 | import scienceplots 7 | import matplotlib.gridspec as gridspec 8 | from functools import reduce 9 | from brokenaxes import brokenaxes 10 | 11 | selected_tasks = ['peterson2021using', 'ruggeri2022globalizability'] 12 | 13 | df_random = pd.read_csv('../results/all_data_random.csv') 14 | df_random = df_random[df_random['task'].isin(selected_tasks)][['task', 'random']] 15 | 16 | df_llama_70b = pd.read_csv('../results/all_data_unsloth-Meta-Llama-3.1-70B-bnb-4bit.csv') 17 | df_llama_70b = df_llama_70b[df_llama_70b['task'].isin(selected_tasks)][['task', 'unsloth/Meta-Llama-3.1-70B-bnb-4bit']] 18 | 19 | df_centaur_70b = pd.read_csv('../results/all_data_marcelbinz-Llama-3.1-Centaur-70B-adapter.csv') 20 | df_centaur_70b = df_centaur_70b[df_centaur_70b['task'].isin(selected_tasks)][['task', 'marcelbinz/Llama-3.1-Centaur-70B-adapter']] 21 | 22 | df_baseline = pd.read_csv('../results/all_data_baseline.csv') 23 | df_baseline = df_baseline[df_baseline['task'].isin(selected_tasks)][['task', 'baseline']] 24 | 25 | df_centaur_70b_no_history = pd.read_csv('../ceiling/results/marcelbinz-Llama-3.1-Centaur-70B-adapter.csv', index_col=0) 26 | df_centaur_70b_no_history = df_centaur_70b_no_history.rename(columns={'marcelbinz/Llama-3.1-Centaur-70B-adapter': 'marcelbinz/Llama-3.1-Centaur-70B-adapter-no-history'}) 27 | df_centaur_70b_no_history['task'] = df_centaur_70b_no_history['task'].str.replace('/prompts_zeroshot.jsonl','') 28 | 29 | df_llama_70b_no_history = pd.read_csv('../ceiling/results/unsloth-Meta-Llama-3.1-70B-bnb-4bit.csv', index_col=0) 
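# suffix the column name so the individual-level (no-history) scores stay distinct through the outer merge on 'task' below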
30 | df_llama_70b_no_history = df_llama_70b_no_history.rename(columns={'unsloth/Meta-Llama-3.1-70B-bnb-4bit': 'unsloth/Meta-Llama-3.1-70B-bnb-4bit-no-history'}) 31 | df_llama_70b_no_history['task'] = df_llama_70b_no_history['task'].str.replace('/prompts_zeroshot.jsonl','') 32 | 33 | df_ceiling = pd.read_csv('../ceiling/results/ceiling.csv') 34 | df_ceiling = df_ceiling.rename(columns={'nll': 'ceiling'}) 35 | df_ceiling = df_ceiling[df_ceiling['task'].isin(selected_tasks)][['task', 'ceiling']] 36 | 37 | df = reduce(lambda left,right: pd.merge(left,right,on=['task'], how='outer'), [df_random, df_llama_70b, df_centaur_70b, df_centaur_70b_no_history, df_llama_70b_no_history, df_ceiling, df_baseline]) 38 | print(df) 39 | 40 | model_names = ['random', 'marcelbinz/Llama-3.1-Centaur-70B-adapter', 'unsloth/Meta-Llama-3.1-70B-bnb-4bit', 'baseline', 'ceiling', 'marcelbinz/Llama-3.1-Centaur-70B-adapter-no-history', 'unsloth/Meta-Llama-3.1-70B-bnb-4bit-no-history'] 41 | offsets = [0.01, 0.01] 42 | 43 | gs = gridspec.GridSpec(1, 2, width_ratios=[0.5, 0.5]) 44 | plt.style.use(['nature']) 45 | fig = plt.figure(figsize=(7.20472, 1.9)) 46 | for task_index, task in enumerate(df['task']): 47 | print(task) 48 | df_task = df[df['task'] == task][model_names].values.flatten() 49 | df_task = df_task[[1, 2, 3, 4, 5, 6, 0]] 50 | df_task[df_task != df_task] = 0 51 | print(df_task) 52 | cutoff = (2.67, 2.72) if task_index == 0 else (2.0, 2.05) 53 | ax = brokenaxes(ylims=((.4, .74), cutoff), subplot_spec=gs[task_index]) 54 | 55 | ax.bar(np.arange(6), df_task[:-1], color=['#69005f', '#ff506e', '#cbc9e2', 'white', '#69005f', '#ff506e']) 56 | ax.set_xticks(np.arange(6), ['Centaur', 'Llama', 'Cog.\nmodel', 'Noise\nceiling', 'Centaur\n(ind.)', 'Llama\n(ind.)',]) 57 | ax.axhline(y=df_task[-1], color='grey', linestyle='--', linewidth=1.0) 58 | ax.axs[-1].text(5.65, df_task[-1] + offsets[task_index], 'Random guessing', fontsize=6, color='grey', horizontalalignment='right') 59 | ax.set_title('choices13k' if task_index == 0 else 'Intertemporal choice', fontsize=7) 60 | 61 | ax.set_ylabel('Negative log-likelihood', labelpad=20) 62 | 63 | print(ax.containers) 64 | ax.axs[-1].containers[0][0].set_alpha(0.8) 65 | ax.axs[-1].containers[0][1].set_alpha(0.8) 66 | ax.axs[-1].containers[0][2].set_alpha(1) 67 | ax.axs[-1].containers[0][3].set_alpha(0.5) 68 | ax.axs[-1].containers[0][4].set_alpha(0.8) 69 | ax.axs[-1].containers[0][5].set_alpha(0.8) 70 | ax.axs[-1].containers[0][3].set_edgecolor('black') 71 | ax.axs[-1].containers[0][4].set_hatch('///') 72 | ax.axs[-1].containers[0][5].set_hatch('///') 73 | ax.axs[-2].containers[0][4].set_alpha(0.8) 74 | ax.axs[-2].containers[0][5].set_alpha(0.8) 75 | ax.axs[-2].containers[0][4].set_hatch('///') 76 | ax.axs[-2].containers[0][5].set_hatch('///') 77 | 78 | fig.text(0.07, 0.925, 'a', fontsize=8, weight='bold') 79 | fig.text(0.493, 0.925, 'b', fontsize=8, weight='bold') 80 | 81 | #sns.despine() 82 | #plt.tight_layout() 83 | plt.savefig('figures/fig5.jpg', bbox_inches='tight', dpi=300) 84 | plt.show() 85 | -------------------------------------------------------------------------------- /plots/fig4.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | import seaborn as sns 4 | import scienceplots 5 | import matplotlib.gridspec as gridspec 6 | import matplotlib.colors as colors 7 | import glob 8 | from natsort import natsorted 9 | import pandas as pd 10 | import numpy as np 11 | from sklearn.manifold import MDS 
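# fig4 layout: (a) MDS embedding of CogBench behaviour profiles, (b) two-step-task alignment across layers, (c) alignment with the Tuckute et al. (2024) reading data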
12 | 13 | gs = gridspec.GridSpec(1, 3, width_ratios=[0.33333, 0.33333, 0.33333]) 14 | plt.style.use(['nature']) 15 | fig = plt.figure(figsize=(7.08661, 7.08661/3)) 16 | 17 | # plot MDS 18 | metrics_df = pd.read_csv('../results/CogBench/behaviour.csv') 19 | metrics_df = metrics_df.loc[:, ~metrics_df.columns.str.contains('^Unnamed')] 20 | colors = ['black' for _ in metrics_df.Agent] 21 | colors = ['grey' if engine in ['Human'] else color for engine, color in zip(metrics_df.Agent, colors)] 22 | colors = ['#69005f' if engine == 'Centaur' else color for engine, color in zip(metrics_df.Agent, colors)] 23 | colors = ['#ff506e' if engine == 'Llama' else color for engine, color in zip(metrics_df.Agent, colors)] 24 | 25 | reducer = MDS(n_components=2, random_state=1) 26 | metrics_scores = metrics_df.iloc[:, 1:metrics_df.shape[1]//2].values 27 | agent_names = metrics_df.iloc[:, 0].values 28 | embedding = reducer.fit_transform(metrics_scores) 29 | 30 | ax = fig.add_subplot(gs[:, 0]) 31 | ax.scatter(embedding[:, 0], embedding[:, 1], c=colors, s=25, alpha=0.8) 32 | ax.set_xlabel('Embedding dimension 1') 33 | ax.set_ylabel('Embedding dimension 2') 34 | ax.set_xlim(-4, 6) 35 | ax.set_ylim(-4, 6) 36 | 37 | for i in range(embedding.shape[0]): 38 | if agent_names[i] == 'GPT-3.5': 39 | ax.annotate(agent_names[i], (-0.5 + embedding[i, 0], embedding[i, 1]+0.5)) 40 | else: 41 | ax.annotate(agent_names[i], (0.4 + embedding[i, 0], embedding[i, 1]-0.25)) 42 | 43 | red_point = embedding[[engine == 'Llama' for engine in metrics_df.Agent]] 44 | green_point = embedding[[engine == 'Centaur' for engine in metrics_df.Agent]] 45 | 46 | ax.text(-0.2, 1.09, 'a', transform=ax.transAxes, fontsize=8, fontweight='bold', va='top') # Add label (b) 47 | 48 | if red_point.size > 0 and green_point.size > 0: 49 | plt.arrow( 50 | red_point[0, 0], red_point[0, 1], 51 | green_point[0, 0] - red_point[0, 0], green_point[0, 1] - red_point[0, 1], 52 | head_width=0.4, head_length=0.4, overhang=0, fc='k', length_includes_head=True 53 | ) 54 | 55 | #plot feher da silva 56 | twostep_centaur = pd.read_csv('../results/feher2023rethinking/tst_centaur_alignment.csv').values.mean(1) 57 | twostep_llama = pd.read_csv('../results/feher2023rethinking/tst_llama_alignment.csv').values.mean(1) 58 | baseline_model = 0.023144136642747695 59 | print(twostep_llama) 60 | print(twostep_centaur) 61 | 62 | ax = fig.add_subplot(gs[:, 1]) 63 | ax.plot([0, 10, 20, 30, 40], twostep_centaur, color='#69005f', alpha=0.8, linewidth=1) 64 | ax.plot([0, 10, 20, 30, 40], twostep_llama, color='#ff506e', alpha=0.8, linewidth=1) 65 | ax.axhline(y=baseline_model, color='grey', linestyle='--', linewidth=1.0) 66 | ax.text(41, baseline_model - 0.018, 'Cognitive model', fontsize=6, color='grey', horizontalalignment='right') 67 | ax.text(-0.2, 1.09, 'b', transform=ax.transAxes, fontsize=8, fontweight='bold', va='top') # Add label (b) 68 | ax.set_ylabel(r'R$^2$') 69 | ax.set_xlabel('Layer') 70 | ax.set_xlim(1, 41) 71 | ax.set_ylim(-0.02, 0.24) 72 | ax.legend(['Centaur', 'Llama'], frameon=False, ncols=2, borderaxespad=0, handlelength=1, columnspacing=0.7, handletextpad=0.5, bbox_to_anchor=(0.51, 1.125), loc='upper center') 73 | 74 | # plot tuckute 75 | reading_llama = torch.load('../results/tuckute2024driving/llama.pth') 76 | reading_centaur = torch.load('../results/tuckute2024driving/centaur2000.pth') 77 | 78 | ax = fig.add_subplot(gs[:, 2]) 79 | ax.plot(torch.arange(1, reading_centaur.shape[0] + 1), reading_centaur, color='#69005f', alpha=0.8, linewidth=1) 80 | ax.plot(torch.arange(1, 
reading_llama.shape[0] + 1), reading_llama, color='#ff506e', alpha=0.8, linewidth=1) 81 | ax.axhline(y=0.38, color='grey', linestyle='--', linewidth=1.0) 82 | ax.axhline(y=0.56, color='black', linestyle='--', linewidth=1.0) 83 | ax.text(41, 0.357, 'Tuckute et al. (2024)', fontsize=6, color='grey', horizontalalignment='right') 84 | ax.text(41, 0.57, 'Noise ceiling', fontsize=6, color='black', horizontalalignment='right') 85 | ax.text(-0.2, 1.09, 'c', transform=ax.transAxes, fontsize=8, fontweight='bold', va='top') # Add label (c) 86 | ax.set_ylabel(r'R$^2$') 87 | ax.set_xlabel('Layer') 88 | ax.set_xlim(1, 41) 89 | ax.set_ylim(0.3, 0.63) 90 | ax.legend(['Centaur', 'Llama'], frameon=False, ncols=2, borderaxespad=0, handlelength=1, columnspacing=0.7, handletextpad=0.5, bbox_to_anchor=(0.51, 1.125), loc='upper center') 91 | 92 | sns.despine() 93 | plt.tight_layout() 94 | plt.savefig('figures/fig4.pdf', bbox_inches='tight') 95 | 96 | plt.show() 97 | -------------------------------------------------------------------------------- /extended_data/ed_fig7.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import seaborn as sns 5 | import scienceplots 6 | import matplotlib.gridspec as gridspec 7 | import matplotlib.colors as colors 8 | import numpy as np 9 | import nilearn 10 | import nibabel as nib 11 | from nilearn import plotting, datasets, surface, image 12 | from nilearn.plotting import plot_stat_map 13 | from matplotlib.lines import Line2D 14 | import matplotlib.cm as cm 15 | import matplotlib.colors as mcolors 16 | import ast 17 | 18 | gs = gridspec.GridSpec(1, 2, width_ratios=[0.33333, 0.66666]) 19 | plt.style.use(['nature']) 20 | fig = plt.figure(figsize=(7.20472, 7.08661/3)) 21 | 22 | # subplot 1 23 | ax = fig.add_subplot(gs[:, 0]) 24 | image = plt.imread('test.png') 25 | cax = ax.imshow(image) 26 | ax.axis('off') 27 | ax.set_title('Pearson correlation', y=1.13, fontsize=7) 28 | 29 | # Define a colormap 30 | cmap = sns.color_palette("rocket_r", as_cmap=True) 31 | 32 | # Create a norm (normalization of values) 33 | norm = mcolors.Normalize(vmin=0, vmax=0.8) 34 | 35 | # Create a ScalarMappable with colormap and norm 36 | sm = cm.ScalarMappable(cmap=cmap, norm=norm) 37 | sm.set_array([]) # Required for the colorbar 38 | 39 | cbar = plt.colorbar(sm, ax=ax, orientation="vertical") 40 | 41 | # subplot 2 42 | ax = fig.add_subplot(gs[:, 1]) 43 | 44 | layer = 2 45 | 46 | def str_to_array(cell): 47 | cell = cell.strip('[]').split() 48 | return np.array(cell, dtype=float) 49 | 50 | df_centaur_tst = pd.read_csv('../results/feher2023rethinking/schaefer_tst_centaur_alignment.csv') 51 | df_llama_tst = pd.read_csv('../results/feher2023rethinking/schaefer_tst_llama_alignment.csv') 52 | df_random_tst = pd.read_csv('../results/feher2023rethinking/schaefer_tst_random_alignment.csv') 53 | df_cognitive_tst = pd.read_csv('../results/feher2023rethinking/schaefer_tst_cognitive_alignment.csv') 54 | 55 | df_centaur_tst = df_centaur_tst.applymap(str_to_array) 56 | df_llama_tst = df_llama_tst.applymap(str_to_array) 57 | df_random_tst = df_random_tst.applymap(str_to_array) 58 | df_cognitive_tst = df_cognitive_tst.applymap(str_to_array) 59 | 60 | print(df_cognitive_tst) 61 | 62 | columns_to_keep = [ 63 | ['Left Accumbens', 'Right Accumbens'], 64 | ["b'7Networks_LH_Limbic_OFC_1'", "b'7Networks_LH_Default_pCunPCC_1'", "b'7Networks_RH_Limbic_OFC_1'", "b'7Networks_RH_Default_PFCdPFCm_1'"], 65 |
["b'7Networks_LH_SomMot_1'", "b'7Networks_LH_SomMot_2'", "b'7Networks_LH_SomMot_3'", "b'7Networks_LH_SomMot_4'", "b'7Networks_LH_SomMot_5'", "b'7Networks_LH_SomMot_6'"], 66 | ["b'7Networks_LH_Vis_1'", "b'7Networks_LH_Vis_2'", "b'7Networks_LH_Vis_3'", "b'7Networks_LH_Vis_4'", "b'7Networks_LH_Vis_5'", "b'7Networks_LH_Vis_6'", "b'7Networks_LH_Vis_7'", "b'7Networks_LH_Vis_8'", "b'7Networks_LH_Vis_9'", "b'7Networks_RH_Vis_1'", "b'7Networks_RH_Vis_2'", "b'7Networks_RH_Vis_3'", "b'7Networks_RH_Vis_4'", "b'7Networks_RH_Vis_5'", "b'7Networks_RH_Vis_6'", "b'7Networks_RH_Vis_7'", "b'7Networks_RH_Vis_8'"], 67 | ] 68 | 69 | df_centaur_tst = df_centaur_tst.iloc[layer] 70 | df_llama_tst = df_llama_tst.iloc[layer] 71 | df_random_tst = df_random_tst.iloc[layer] 72 | df_cognitive_tst = df_cognitive_tst.iloc[0] 73 | print(df_cognitive_tst) 74 | for i, column in enumerate(columns_to_keep): 75 | print('here') 76 | print(np.stack(df_centaur_tst[column].values).shape) 77 | centaur_mean = np.stack(df_centaur_tst[column].values).mean() 78 | llama_mean = np.stack(df_llama_tst[column].values).mean() 79 | cognitive_mean = np.stack(df_cognitive_tst[column].values).mean() 80 | random_mean = np.stack(df_random_tst[column].values).mean() 81 | 82 | ax.bar(np.array([-1/5, 0, 1/5, 0.4]) + i, [centaur_mean, llama_mean, cognitive_mean, random_mean], alpha=0.8, color=['#69005f', '#ff506e', '#cbc9e2', 'grey'], width=1/5) 83 | ax.set_ylabel('Pearson correlation') 84 | ax.set_xticks([0.1, 1.1, 2.1, 3.1], ['Accumbens', 'Medial PFC', 'Motor Cortex', 'Visual Cortex']) 85 | 86 | custom_lines_r2 = [Line2D([0], [0], color='#69005f', alpha=0.8, linewidth=5, markersize=3), 87 | Line2D([0], [0], color='#ff506e', alpha=0.8, linewidth=5, markersize=3), 88 | Line2D([0], [0], color='#cbc9e2', alpha=0.8, linewidth=5, markersize=3), 89 | Line2D([0], [0], color='grey', alpha=0.8, linewidth=5, markersize=3),] 90 | ax.legend(custom_lines_r2, ['Centaur', 'Llama', 'Cognitive model', 'Control'], frameon=False, ncols=4, bbox_to_anchor=(0.5, 1.3), loc='upper center') 91 | 92 | fig.text(0.012, 0.852, 'a', fontsize=8, weight='bold') 93 | fig.text(0.33, 0.852, 'b', fontsize=8, weight='bold') 94 | 95 | sns.despine() 96 | plt.tight_layout() 97 | plt.savefig('figures/fig7.png', bbox_inches='tight', dpi=300) 98 | 99 | plt.show() 100 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from dataclasses import dataclass, field 4 | from typing import Optional 5 | from transformers import HfArgumentParser, TrainingArguments, set_seed 6 | from unsloth import is_bfloat16_supported 7 | from unsloth import UnslothTrainer, UnslothTrainingArguments 8 | from unsloth import FastLanguageModel 9 | from trl import DataCollatorForCompletionOnlyLM 10 | from datasets import load_dataset 11 | 12 | @dataclass 13 | class ModelArguments: 14 | model_name_or_path: str = field(metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}) 15 | lora_r: Optional[int] = field(default=8) 16 | lora_alpha: Optional[int] = field(default=8) 17 | lora_dropout: Optional[float] = field(default=0) 18 | 19 | @dataclass 20 | class DataTrainingArguments: 21 | dataset_text_field: str = field(default="text") 22 | max_seq_length: Optional[int] = field(default=32768) 23 | 24 | def main(model_args, data_args, training_args): 25 | # set seed for reproducibility 26 | set_seed(training_args.seed) 27 | 28 | # datasets 29 
| train_dataset = load_dataset("marcelbinz/Psych-101")['train'].shuffle() 30 | eval_datasets = load_dataset("marcelbinz/Psych-101-test")['test'] 31 | 32 | # model 33 | model, tokenizer = FastLanguageModel.from_pretrained( 34 | model_name = model_args.model_name_or_path, 35 | max_seq_length = data_args.max_seq_length, 36 | dtype = None, 37 | load_in_4bit = True, 38 | ) 39 | 40 | model = FastLanguageModel.get_peft_model( 41 | model, 42 | r = model_args.lora_r, 43 | target_modules = [ 44 | "q_proj", "k_proj", "v_proj", "o_proj", 45 | "gate_proj", 46 | "up_proj", "down_proj", 47 | ], 48 | lora_alpha = model_args.lora_alpha, 49 | lora_dropout = model_args.lora_dropout, 50 | bias = "none", 51 | use_gradient_checkpointing = "unsloth", 52 | random_state = training_args.seed, 53 | use_rslora = True, 54 | loftq_config = None, 55 | ) 56 | 57 | tokenizer.pad_token_id = 0 58 | tokenizer.padding_side = "right" 59 | 60 | l_id = tokenizer(" <<").input_ids[1:] 61 | r_id = tokenizer(">>").input_ids[1:] 62 | collator = DataCollatorForCompletionOnlyLM(response_template=l_id, instruction_template=r_id, tokenizer=tokenizer) 63 | 64 | # trainer 65 | trainer = UnslothTrainer( 66 | model = model, 67 | tokenizer = tokenizer, 68 | train_dataset = train_dataset, 69 | eval_dataset=eval_datasets, 70 | dataset_text_field = data_args.dataset_text_field, 71 | max_seq_length = data_args.max_seq_length, 72 | dataset_num_proc = 8, 73 | data_collator=collator, 74 | args = UnslothTrainingArguments( 75 | per_device_train_batch_size = training_args.per_device_train_batch_size, 76 | per_device_eval_batch_size = training_args.per_device_eval_batch_size, 77 | gradient_accumulation_steps = training_args.gradient_accumulation_steps, 78 | warmup_steps = training_args.warmup_steps, 79 | num_train_epochs = training_args.num_train_epochs, 80 | learning_rate = training_args.learning_rate, 81 | embedding_learning_rate = training_args.learning_rate / 10, 82 | fp16 = not is_bfloat16_supported(), 83 | bf16 = is_bfloat16_supported(), 84 | log_level = training_args.log_level, 85 | logging_strategy = training_args.logging_strategy, 86 | logging_steps = training_args.logging_steps, 87 | evaluation_strategy = training_args.evaluation_strategy, 88 | eval_steps = training_args.eval_steps, 89 | save_strategy = training_args.save_strategy, 90 | save_steps = training_args.save_steps, 91 | optim = training_args.optim, 92 | weight_decay = training_args.weight_decay, 93 | lr_scheduler_type = training_args.lr_scheduler_type, 94 | seed = training_args.seed, 95 | output_dir = training_args.output_dir, 96 | ), 97 | ) 98 | 99 | trainer.accelerator.print(f"{trainer.model}") 100 | trainer.model.print_trainable_parameters() 101 | 102 | # train 103 | trainer.train(resume_from_checkpoint=None) 104 | 105 | # saving final model 106 | trainer.save_model() 107 | 108 | 109 | if __name__ == "__main__": 110 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 111 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 112 | # If we pass only one argument to the script and it's the path to a json file, 113 | # let's parse it to get our arguments. 
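        # e.g.: python finetune.py train_config.json   (the config-file name here is illustrative)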
114 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 115 | else: 116 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 117 | main(model_args, data_args, training_args) 118 | -------------------------------------------------------------------------------- /openloop/baar2021latent/simulate.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import jsonlines 4 | from unsloth import FastLanguageModel 5 | import transformers 6 | import argparse 7 | 8 | def randomized_choice_options(num_choices): 9 | choice_options = list(map(chr, range(65, 91))) 10 | return np.random.choice(choice_options, num_choices, replace=False) 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--model", type=str, required=True) 15 | args = parser.parse_args() 16 | 17 | model, tokenizer = FastLanguageModel.from_pretrained( 18 | model_name = args.model, 19 | max_seq_length = 32768, 20 | dtype = None, 21 | load_in_4bit = True, 22 | ) 23 | FastLanguageModel.for_inference(model) 24 | 25 | pipe = transformers.pipeline( 26 | "text-generation", 27 | model=model, 28 | tokenizer=tokenizer, 29 | trust_remote_code=True, 30 | pad_token_id=0, 31 | do_sample=True, 32 | temperature=1.0, 33 | max_new_tokens=1, 34 | ) 35 | 36 | df = pd.read_csv('gameDat.csv') 37 | 38 | data = [] 39 | for subject in df.subID.unique(): 40 | choice_options = randomized_choice_options(2) 41 | 42 | prompt = "You will take part in a Social Prediction Game.\n"\ 43 | "You will observe a Player playing against an Opponent in different games.\n"\ 44 | "In each game, the Player and the Opponent simultaneously choose between option " + choice_options[0] + " and option " + choice_options[1] + ".\n"\ 45 | "The Player and the Opponent win points based on their choices.\n"\ 46 | "The rules change between games, and you will be informed about them before each game.\n"\ 47 | "The Player varies between blocks but is consistent across games within a block.\n"\ 48 | "The Opponent switches in each game.\n"\ 49 | "Your task is to predict the choices made by the Player and rate your confidence in this prediction on an 11-point scale from 0 to 100 (in increments of 10).\n"\ 50 | "You get feedback after each game on whether your prediction was correct or not.\n\n"\ 51 | 52 | df_sub = df[df['subID'] == subject] 53 | for block in range(4): 54 | prompt += 'Block ' + str(block + 1) + ' starts now.\n\n' 55 | df_block = df_sub[df_sub['Block'] == block] 56 | for trial in range(16): 57 | df_trial = df_block[df_block['Trial'] == trial] 58 | 59 | # 0 co-operate, 1 defect 60 | T = df_trial['T'].item() 61 | S = df_trial['S'].item() 62 | 63 | prompt += "The rules of the game are as follows:\n"\ 64 | "If Player chooses option " + choice_options[0] + " and Opponent chooses option " + choice_options[0] + ", then Player wins 10 points and Opponent wins 10 points.\n"\ 65 | "If Player chooses option " + choice_options[0] + " and Opponent chooses option " + choice_options[1] + ", then Player wins " + str(S) + " points and Opponent wins " + str(T) + " points.\n"\ 66 | "If Player chooses option " + choice_options[1] + " and Opponent chooses option " + choice_options[0] + ", then Player wins " + str(T) + " points and Opponent wins " + str(S) + " points.\n"\ 67 | "If Player chooses option " + choice_options[1] + " and Opponent chooses option " + choice_options[1] + ", then Player wins 5 
points and Opponent wins 5 points.\n"\ 68 | "You predict that Player will choose option <<" 69 | 70 | # simulate choice 71 | choice = pipe(prompt)[0]['generated_text'][len(prompt):] 72 | print(choice) 73 | if choice == choice_options[0]: 74 | response = 'coop' 75 | elif choice == choice_options[1]: 76 | response = 'def' 77 | else: 78 | response = 'NaN' 79 | print('something went wrong') 80 | correct = 'correct' if response == df_trial['CorrAns'].item() else 'incorrect' 81 | prompt += str(choice) + ">>. You indicate a confidence of <<" 82 | 83 | # simulate confidence 84 | confidence = pipe(prompt)[0]['generated_text'][len(prompt):] 85 | prompt += str(confidence) + ">>. Your prediction was " + correct + ".\n\n" 86 | print(prompt) 87 | 88 | row = [subject, block, df_trial['Player'].item(), df_trial['Type'].item(), df_trial['Variant'].item(), df_trial['Type_Total'].item(), trial, S, T, df_trial['GameType'].item(), df_trial['CorrAns'].item(), response, confidence] 89 | data.append(row) 90 | 91 | df = pd.DataFrame(data, columns=['subID', 'Block', 'Player', 'Type', 'Variant', 'Type_Total', 'Trial', 'S', 'T', 'GameType', 'CorrAns', 'GivenAns', 'ConfidenceNum']) 92 | print(df) 93 | df.to_csv('simulation_' + args.model.replace('/', '-') + '.csv') 94 | -------------------------------------------------------------------------------- /openloop/wilson2014humans/simulate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | from tqdm import tqdm 5 | import sys 6 | from unsloth import FastLanguageModel 7 | import transformers 8 | from datasets import load_dataset 9 | import math 10 | import argparse 11 | 12 | def randomized_choice_options(num_choices): 13 | choice_options = list(map(chr, range(65, 91))) 14 | return np.random.choice(choice_options, num_choices, replace=False) 15 | 16 | def generate_rewards(values): 17 | return np.round(np.clip(np.random.normal(values, 8.0), 1.0, 100.0)).astype('int') 18 | 19 | def generate_prompts_horizon(datasets, model): 20 | 21 | model, tokenizer = FastLanguageModel.from_pretrained( 22 | model_name = model, 23 | max_seq_length = 32768, 24 | dtype = None, 25 | load_in_4bit = True, 26 | ) 27 | FastLanguageModel.for_inference(model) 28 | 29 | pipe = transformers.pipeline( 30 | "text-generation", 31 | model=model, 32 | tokenizer=tokenizer, 33 | trust_remote_code=True, 34 | pad_token_id=0, 35 | do_sample=True, 36 | temperature=1.0, 37 | max_new_tokens=1, 38 | ) 39 | 40 | for dataset_idx, dataset in enumerate(datasets): 41 | print(dataset) 42 | df = pd.read_csv(dataset) 43 | prompts = load_dataset("marcelbinz/Psych-101-test") 44 | eval_prompts = prompts['test'].filter(lambda example: example['experiment'].startswith('wilson2014humans/' + dataset)) 45 | eval_participants = list(map(int, eval_prompts['participant'])) 46 | df = df[df['participant'].isin(eval_participants)] 47 | print(len(df.participant.unique()), flush=True) 48 | 49 | data = [] 50 | 51 | for participant in tqdm(df.participant.unique()): 52 | choice_options = randomized_choice_options(num_choices=2) 53 | 54 | prompt = \ 55 | "You are participating in multiple games involving two slot machines, labeled " + choice_options[0] + " and " + choice_options[1] + ".\n" \ 56 | "The two slot machines are different across different games.\nEach time you choose a slot machine, you get some points.\n" \ 57 | "You choose a slot machine by pressing the corresponding key.\n" \ 58 | "Each slot machine tends to pay out about the same 
amount of points on average.\n" \ 59 | "Your goal is to choose the slot machines that will give you the most points across the experiment.\n" \ 60 | "The first 4 trials in each game are instructed trials where you will be told which slot machine to choose.\n" \ 61 | "After these instructed trials, you will have the freedom to choose for either 1 or 6 trials.\n\n" 62 | 63 | df_participant = df[(df['participant'] == participant)] 64 | num_tasks = min(100, df_participant.task.max() + 1) # 65 | 66 | for task in range(num_tasks): 67 | df_task = df_participant[(df_participant['task'] == task)] 68 | num_trials = int(df_task.trial.max() + 1) 69 | prompt += f"Game {task + 1}. There are {num_trials} trials in this game.\n" 70 | 71 | rewards = generate_rewards(df_task[['slot1_value', 'slot2_value']].values) 72 | 73 | for trial in range(num_trials): 74 | df_trial = df_task[(df_task['trial'] == trial)] 75 | if trial < 4: 76 | c_idx = df_trial.choice.item() 77 | c = choice_options[c_idx].item() 78 | r = rewards[trial, c_idx] 79 | prompt += f"You are instructed to press {c} and get {r} points.\n" 80 | forced_choice_trial = 1 81 | else: 82 | prompt += f"You press <<" 83 | c = pipe(prompt)[0]['generated_text'][len(prompt):] 84 | if c not in [choice_options[0], choice_options[1]]: 85 | c = np.random.choice([choice_options[0], choice_options[1]]) 86 | print('should not happen!', flush=True) 87 | c_idx = list(choice_options).index(c) 88 | r = rewards[trial, c_idx] 89 | prompt += f"{c}>> and get {r} points.\n" 90 | forced_choice_trial = 0 91 | 92 | #print(prompt) 93 | #print() 94 | row = [participant, task, trial, c_idx, r, forced_choice_trial, df_trial['slot1_value'].item(), df_trial['slot2_value'].item()] 95 | data.append(row) 96 | 97 | prompt += '\n' 98 | 99 | df = pd.DataFrame(data, columns=['participant', 'task', 'trial', 'choice', 'reward', 'forced', 'slot1_value', 'slot2_value']) 100 | print(df) 101 | df.to_csv('simulation' + str(dataset_idx) + '.csv') 102 | 103 | if __name__ == '__main__': 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument("--model", type=str, required=True) 106 | args = parser.parse_args() 107 | 108 | files = os.listdir(".") 109 | datasets = sorted([f for f in files if (f.startswith("exp") and f.endswith(".csv"))]) 110 | print(datasets) 111 | generate_prompts_horizon(datasets, args.model) 112 | -------------------------------------------------------------------------------- /plots/fig4_new.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | import seaborn as sns 4 | import scienceplots 5 | import matplotlib.gridspec as gridspec 6 | import matplotlib.colors as colors 7 | import glob 8 | from natsort import natsorted 9 | import pandas as pd 10 | import numpy as np 11 | from sklearn.manifold import MDS 12 | import math 13 | from scipy import stats 14 | 15 | gs = gridspec.GridSpec(1, 3, width_ratios=[0.33333, 0.33333, 0.33333]) 16 | plt.style.use(['nature']) 17 | fig = plt.figure(figsize=(7.08661, 7.08661/3)) 18 | 19 | # plot MDS 20 | metrics_df = pd.read_csv('../results/CogBench/behaviour.csv') 21 | metrics_df = metrics_df.loc[:, ~metrics_df.columns.str.contains('^Unnamed')] 22 | colors = ['black' for _ in metrics_df.Agent] 23 | colors = ['grey' if engine in ['Human'] else color for engine, color in zip(metrics_df.Agent, colors)] 24 | colors = ['#69005f' if engine == 'Centaur' else color for engine, color in zip(metrics_df.Agent, colors)] 25 | colors = ['#ff506e' if engine == 'Llama' 
else color for engine, color in zip(metrics_df.Agent, colors)] 26 | 27 | reducer = MDS(n_components=2, random_state=1) 28 | metrics_scores = metrics_df.iloc[:, 1:metrics_df.shape[1]//2].values 29 | agent_names = metrics_df.iloc[:, 0].values 30 | embedding = reducer.fit_transform(metrics_scores) 31 | 32 | ax = fig.add_subplot(gs[0, 0]) 33 | ax.scatter(embedding[:, 0], embedding[:, 1], c=colors, s=25, alpha=0.8) 34 | ax.set_xlabel('Embedding dimension 1') 35 | ax.set_ylabel('Embedding dimension 2') 36 | ax.set_xlim(-4, 6) 37 | ax.set_ylim(-4, 6) 38 | 39 | for i in range(embedding.shape[0]): 40 | if agent_names[i] == 'GPT-3.5': 41 | ax.annotate(agent_names[i], (-0.6 + embedding[i, 0], embedding[i, 1]+0.5), size=6) 42 | else: 43 | ax.annotate(agent_names[i], (0.45 + embedding[i, 0], embedding[i, 1]-0.25), size=6) 44 | 45 | red_point = embedding[[engine == 'Llama' for engine in metrics_df.Agent]] 46 | green_point = embedding[[engine == 'Centaur' for engine in metrics_df.Agent]] 47 | 48 | ax.text(-0.22, 1.09, 'a', transform=ax.transAxes, fontsize=8, fontweight='bold', va='top') # Add label (b) 49 | 50 | if red_point.size > 0 and green_point.size > 0: 51 | plt.arrow( 52 | red_point[0, 0], red_point[0, 1], 53 | green_point[0, 0] - red_point[0, 0], green_point[0, 1] - red_point[0, 1], 54 | head_width=0.4, head_length=0.4, overhang=0, fc='k', length_includes_head=True 55 | ) 56 | 57 | #plot feher da silva 58 | df_centaur_tst = torch.load('../results/feher2023rethinking/schaefer_tst_centaur_alignment.pth') 59 | df_llama_tst = torch.load('../results/feher2023rethinking/schaefer_tst_llama_alignment.pth') 60 | df_random_tst = torch.load('../results/feher2023rethinking/schaefer_tst_random_alignment.pth') 61 | 62 | df_centaur_tst = np.array([list(layer.values()) for layer in df_centaur_tst]) 63 | df_llama_tst = np.array([list(layer.values()) for layer in df_llama_tst]) 64 | df_random_tst = np.array([list(layer.values()) for layer in df_random_tst]) 65 | 66 | twostep_centaur = df_centaur_tst.mean((1, 2)) 67 | twostep_llama = df_llama_tst.mean((1, 2)) 68 | twostep_random = df_random_tst.mean((1, 2)) 69 | twostep_centaur_se = df_centaur_tst.std((1, 2)) / math.sqrt(df_centaur_tst.shape[1] * df_centaur_tst.shape[2]) 70 | twostep_llama_se = df_llama_tst.std((1, 2)) / math.sqrt(df_centaur_tst.shape[1] * df_centaur_tst.shape[2]) 71 | twostep_random_se = df_random_tst.std((1, 2)) / math.sqrt(df_random_tst.shape[1] * df_random_tst.shape[2]) 72 | 73 | print('Two-step task:') 74 | for i in range(5): 75 | print(stats.ttest_ind(df_centaur_tst[i].flatten(), df_llama_tst[i].flatten(), alternative='greater')) 76 | 77 | baseline_model = 0.20065425519568694 78 | print(twostep_llama) 79 | print(twostep_centaur) 80 | 81 | ax = fig.add_subplot(gs[0, 1]) 82 | ax.errorbar([0, 10, 20, 30, 40], twostep_centaur, yerr=twostep_centaur_se, color='#69005f', alpha=0.8, linewidth=1) 83 | ax.errorbar([0, 10, 20, 30, 40], twostep_llama, yerr=twostep_llama_se, color='#ff506e', alpha=0.8, linewidth=1) 84 | ax.errorbar([0, 10, 20, 30, 40], twostep_random, yerr=twostep_random_se, color='grey', alpha=0.8, linewidth=1) 85 | ax.legend(['Centaur', 'Llama', 'Control'], frameon=False, ncols=3, borderaxespad=0, handlelength=1, columnspacing=0.7, handletextpad=0.5, bbox_to_anchor=(0.51, 1.125), loc='upper center') 86 | ax.axhline(y=baseline_model, color='#cbc9e2', linestyle='--', linewidth=1.0) 87 | ax.text(41, baseline_model - 0.0185, 'Cognitive model', fontsize=6, color='#aeabcc', horizontalalignment='right') 88 | ax.text(-0.2, 1.09, 'b', 
transform=ax.transAxes, fontsize=8, fontweight='bold', va='top') # Add label (b) 89 | ax.set_ylabel('Pearson correlation') 90 | ax.set_xlabel('Layer') 91 | ax.set_xlim(1, 41) 92 | ax.set_ylim(0.16, 0.44) 93 | 94 | 95 | # plot tuckute 96 | reading_llama = torch.load('../results/tuckute2024driving/llama.pth') 97 | reading_centaur = torch.load('../results/tuckute2024driving/centaur2000.pth') 98 | reading_random = torch.load('../results/tuckute2024driving/random.pth') 99 | reading_llama_sem = torch.load('../results/tuckute2024driving/llama_sem.pth') 100 | reading_centaur_sem = torch.load('../results/tuckute2024driving/centaur2000_sem.pth') 101 | reading_random_sem = torch.load('../results/tuckute2024driving/random_sem.pth') 102 | 103 | best_centaur = reading_centaur.argmax() 104 | best_llama = reading_llama.argmax() 105 | 106 | ax = fig.add_subplot(gs[0, 2]) 107 | ax.errorbar(torch.arange(1, reading_centaur.shape[0] + 1), reading_centaur, yerr=reading_centaur_sem, color='#69005f', alpha=0.8, linewidth=1) 108 | ax.errorbar(torch.arange(1, reading_llama.shape[0] + 1), reading_llama, yerr=reading_llama_sem, color='#ff506e', alpha=0.8, linewidth=1) 109 | ax.errorbar(torch.arange(1, reading_random.shape[0] + 1), reading_random, yerr=reading_random_sem, color='grey', alpha=0.8, linewidth=1) 110 | ax.legend(['Centaur', 'Llama', 'Control'], frameon=False, ncols=3, borderaxespad=0, handlelength=1, columnspacing=0.7, handletextpad=0.5, bbox_to_anchor=(0.51, 1.125), loc='upper center') 111 | ax.axhline(y=0.38, color='#cbc9e2', linestyle='--', linewidth=1.0) 112 | ax.axhline(y=0.56, color='black', linestyle='--', linewidth=1.0) 113 | ax.text(41, 0.34, 'Tuckute et al. (2024)', fontsize=6, color='#aeabcc', horizontalalignment='right') 114 | ax.text(41, 0.575, 'Noise ceiling', fontsize=6, color='black', horizontalalignment='right') 115 | ax.text(-0.2, 1.09, 'c', transform=ax.transAxes, fontsize=8, fontweight='bold', va='top') # Add label (c) 116 | ax.set_ylabel('Pearson correlation') 117 | ax.set_xlabel('Layer') 118 | ax.set_xlim(1, 41) 119 | ax.set_ylim(0.08, 0.64) 120 | 121 | 122 | sns.despine() 123 | plt.tight_layout() 124 | plt.savefig('figures/fig4_new.pdf', bbox_inches='tight') 125 | 126 | plt.show() 127 | -------------------------------------------------------------------------------- /camera_ready/fig4.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | import seaborn as sns 4 | import scienceplots 5 | import matplotlib.gridspec as gridspec 6 | import matplotlib.colors as colors 7 | import glob 8 | from natsort import natsorted 9 | import pandas as pd 10 | import numpy as np 11 | from sklearn.manifold import MDS 12 | import math 13 | from scipy import stats 14 | import matplotlib as mpl 15 | 16 | gs = gridspec.GridSpec(1, 3, width_ratios=[0.33333, 0.33333, 0.33333]) 17 | plt.style.use(['nature']) 18 | mpl.rcParams.update({ 19 | "pdf.fonttype": 42, 20 | # fonttype 42 embeds fonts as TrueType so text in the exported PDF stays editable 21 | "text.usetex": False, 22 | }) 23 | fig = plt.figure(figsize=(7.20472, 7.20472/3)) 24 | 25 | # plot MDS 26 | metrics_df = pd.read_csv('../results/CogBench/behaviour.csv') 27 | metrics_df = metrics_df.loc[:, ~metrics_df.columns.str.contains('^Unnamed')] 28 | colors = ['black' for _ in metrics_df.Agent] 29 | colors = ['grey' if engine in ['Human'] else color for engine, color in zip(metrics_df.Agent, colors)] 30 | colors = ['#69005f' if engine == 'Centaur' else color for engine, color in zip(metrics_df.Agent, colors)] 31 | colors =
['#ff506e' if engine == 'Llama' else color for engine, color in zip(metrics_df.Agent, colors)] 32 | 33 | reducer = MDS(n_components=2, random_state=1) 34 | metrics_scores = metrics_df.iloc[:, 1:metrics_df.shape[1]//2].values 35 | agent_names = metrics_df.iloc[:, 0].values 36 | embedding = reducer.fit_transform(metrics_scores) 37 | 38 | ax = fig.add_subplot(gs[0, 0]) 39 | ax.scatter(embedding[:, 0], embedding[:, 1], c=colors, s=25, alpha=0.8) 40 | ax.set_xlabel('Embedding dimension 1') 41 | ax.set_ylabel('Embedding dimension 2') 42 | ax.set_xlim(-4, 6) 43 | ax.set_ylim(-4, 6) 44 | 45 | for i in range(embedding.shape[0]): 46 | if agent_names[i] == 'GPT-3.5': 47 | ax.annotate(agent_names[i], (-0.6 + embedding[i, 0], embedding[i, 1]+0.5), size=6) 48 | else: 49 | ax.annotate(agent_names[i], (0.45 + embedding[i, 0], embedding[i, 1]-0.25), size=6) 50 | 51 | red_point = embedding[[engine == 'Llama' for engine in metrics_df.Agent]] 52 | green_point = embedding[[engine == 'Centaur' for engine in metrics_df.Agent]] 53 | 54 | 55 | 56 | if red_point.size > 0 and green_point.size > 0: 57 | plt.arrow( 58 | red_point[0, 0], red_point[0, 1], 59 | green_point[0, 0] - red_point[0, 0], green_point[0, 1] - red_point[0, 1], 60 | head_width=0.4, head_length=0.4, overhang=0, fc='k', length_includes_head=True 61 | ) 62 | 63 | #plot feher da silva 64 | df_centaur_tst = torch.load('../results/feher2023rethinking/schaefer_tst_centaur_alignment.pth') 65 | df_llama_tst = torch.load('../results/feher2023rethinking/schaefer_tst_llama_alignment.pth') 66 | df_random_tst = torch.load('../results/feher2023rethinking/schaefer_tst_random_alignment.pth') 67 | 68 | df_centaur_tst = np.array([list(layer.values()) for layer in df_centaur_tst]) 69 | df_llama_tst = np.array([list(layer.values()) for layer in df_llama_tst]) 70 | df_random_tst = np.array([list(layer.values()) for layer in df_random_tst]) 71 | 72 | twostep_centaur = df_centaur_tst.mean((1, 2)) 73 | twostep_llama = df_llama_tst.mean((1, 2)) 74 | twostep_random = df_random_tst.mean((1, 2)) 75 | twostep_centaur_se = df_centaur_tst.std((1, 2)) / math.sqrt(df_centaur_tst.shape[1] * df_centaur_tst.shape[2]) 76 | twostep_llama_se = df_llama_tst.std((1, 2)) / math.sqrt(df_centaur_tst.shape[1] * df_centaur_tst.shape[2]) 77 | twostep_random_se = df_random_tst.std((1, 2)) / math.sqrt(df_random_tst.shape[1] * df_random_tst.shape[2]) 78 | 79 | print('Two-step task:') 80 | for i in range(5): 81 | print(df_centaur_tst[i].flatten().shape) 82 | print(stats.ttest_ind(df_centaur_tst[i].flatten(), df_llama_tst[i].flatten(), alternative='greater')) 83 | 84 | baseline_model = 0.20065425519568694 85 | print(twostep_llama) 86 | print(twostep_centaur) 87 | 88 | ax = fig.add_subplot(gs[0, 1]) 89 | ax.errorbar([0, 10, 20, 30, 40], twostep_centaur, yerr=twostep_centaur_se, color='#69005f', alpha=0.8, linewidth=1) 90 | ax.errorbar([0, 10, 20, 30, 40], twostep_llama, yerr=twostep_llama_se, color='#ff506e', alpha=0.8, linewidth=1) 91 | ax.errorbar([0, 10, 20, 30, 40], twostep_random, yerr=twostep_random_se, color='grey', alpha=0.8, linewidth=1) 92 | ax.legend(['Centaur', 'Llama', 'Control'], frameon=False, ncols=3, borderaxespad=0, handlelength=1, columnspacing=0.7, handletextpad=0.5, bbox_to_anchor=(0.51, 1.125), loc='upper center') 93 | ax.axhline(y=baseline_model, color='#cbc9e2', linestyle='--', linewidth=1.0) 94 | ax.text(41, baseline_model - 0.0185, 'Cognitive model', fontsize=6, color='#aeabcc', horizontalalignment='right') 95 | ax.set_ylabel('Pearson correlation') 96 | ax.set_xlabel('Layer') 
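# limit the view to the probed layer range (x) and the correlation band of interest (y)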
97 | ax.set_xlim(1, 41) 98 | ax.set_ylim(0.16, 0.44) 99 | 100 | 101 | # plot tuckute 102 | reading_llama = torch.load('../results/tuckute2024driving/llama.pth') 103 | reading_centaur = torch.load('../results/tuckute2024driving/centaur2000.pth') 104 | reading_random = torch.load('../results/tuckute2024driving/random.pth') 105 | reading_llama_sem = torch.load('../results/tuckute2024driving/llama_sem.pth') 106 | reading_centaur_sem = torch.load('../results/tuckute2024driving/centaur2000_sem.pth') 107 | reading_random_sem = torch.load('../results/tuckute2024driving/random_sem.pth') 108 | 109 | best_centaur = reading_centaur.argmax() 110 | best_llama = reading_llama.argmax() 111 | 112 | ax = fig.add_subplot(gs[0, 2]) 113 | ax.errorbar(torch.arange(1, reading_centaur.shape[0] + 1), reading_centaur, yerr=reading_centaur_sem, color='#69005f', alpha=0.8, linewidth=1) 114 | ax.errorbar(torch.arange(1, reading_llama.shape[0] + 1), reading_llama, yerr=reading_llama_sem, color='#ff506e', alpha=0.8, linewidth=1) 115 | ax.errorbar(torch.arange(1, reading_random.shape[0] + 1), reading_random, yerr=reading_random_sem, color='grey', alpha=0.8, linewidth=1) 116 | ax.legend(['Centaur', 'Llama', 'Control'], frameon=False, ncols=3, borderaxespad=0, handlelength=1, columnspacing=0.7, handletextpad=0.5, bbox_to_anchor=(0.51, 1.125), loc='upper center') 117 | ax.axhline(y=0.38, color='#cbc9e2', linestyle='--', linewidth=1.0) 118 | ax.axhline(y=0.56, color='black', linestyle='--', linewidth=1.0) 119 | ax.text(41, 0.34, 'Tuckute et al. (2024)', fontsize=6, color='#aeabcc', horizontalalignment='right') 120 | ax.text(41, 0.575, 'Noise ceiling', fontsize=6, color='black', horizontalalignment='right') 121 | ax.set_ylabel('Pearson correlation') 122 | ax.set_xlabel('Layer',) 123 | ax.set_xlim(1, 41) 124 | ax.set_ylim(0.08, 0.64) 125 | 126 | fig.text(0.012, 0.9, 'a', fontsize=8, weight='bold') 127 | fig.text(0.335, 0.9, 'b', fontsize=8, weight='bold') 128 | fig.text(0.675, 0.9, 'c', fontsize=8, weight='bold') 129 | 130 | sns.despine() 131 | plt.tight_layout() 132 | plt.savefig('figures/fig4.pdf', bbox_inches='tight') 133 | 134 | plt.show() 135 | -------------------------------------------------------------------------------- /openloop/kool2016when/simulate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import pandas as pd 4 | from tqdm import tqdm 5 | from unsloth import FastLanguageModel 6 | import transformers 7 | from datasets import load_dataset 8 | import os 9 | import argparse 10 | 11 | def randomized_choice_options(num_choices): 12 | choice_options = list(map(chr, range(65, 91))) 13 | return np.random.choice(choice_options, num_choices, replace=False) 14 | 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--model", type=str, required=True) 18 | args = parser.parse_args() 19 | 20 | df = pd.read_csv("exp2.csv") 21 | 22 | dataset = load_dataset("marcelbinz/Psych-101-test") 23 | 24 | kool_eval_dataset = dataset['test'].filter(lambda example: example['experiment'].startswith('kool2016when/exp2.csv')) 25 | kool_eval_participants = list(map(int, kool_eval_dataset['participant'])) 26 | df = df[df['participant'].isin(kool_eval_participants)] 27 | print(len(df.participant.unique())) 28 | 29 | 30 | model, tokenizer = FastLanguageModel.from_pretrained( 31 | model_name = args.model, 32 | max_seq_length = 32768, 33 | dtype = None, 34 | load_in_4bit = True, 35 | ) 36 | 
FastLanguageModel.for_inference(model) 37 | 38 | pipe = transformers.pipeline( 39 | "text-generation", 40 | model=model, 41 | tokenizer=tokenizer, 42 | trust_remote_code=True, 43 | pad_token_id=0, 44 | do_sample=True, 45 | temperature=1.0, 46 | max_new_tokens=1, 47 | ) 48 | 49 | instructions = ( 50 | "You will be taking one of the spaceships {spaceship_0} or {spaceship_1} to one of the planets {planet_0} or {planet_1}.\n" 51 | "The spaceships can fly to either planet, but one will mostly fly to planet {planet_0}, and the other will mostly fly to planet {planet_1}.\n" 52 | "The planet a spaceship goes to most won't change during the game.\n" 53 | "Planet {planet_0} has aliens {alien_0} and {alien_1}, and planet {planet_1} has aliens {alien_2} and {alien_3} on it.\n" 54 | "Each alien has its own space treasure mine.\n" 55 | "When you arrive at each planet, you will ask one of the aliens for space treasure from their mines.\n" 56 | "The treasure an alien can give will change slowly during the game.\n" 57 | "You can take a spaceship or ask an alien for space treasure by pressing the corresponding key.\n" 58 | "Your goal is to get as much treasure as possible over the next {n_trials} days.\n\n" 59 | ) 60 | 61 | data = [] 62 | 63 | for participant in tqdm(df.participant.unique()): 64 | ( 65 | spaceship_0, 66 | spaceship_1, 67 | planet_0, 68 | planet_1, 69 | alien_0, 70 | alien_1, 71 | alien_2, 72 | alien_3, 73 | ) = randomized_choice_options(8) 74 | 75 | par_text = instructions.format( 76 | spaceship_0=spaceship_0, 77 | spaceship_1=spaceship_1, 78 | planet_0=planet_0, 79 | planet_1=planet_1, 80 | alien_0=alien_0, 81 | alien_1=alien_1, 82 | alien_2=alien_2, 83 | alien_3=alien_3, 84 | n_trials=int( 85 | df[df.participant == participant].trial.nunique() / 2 86 | ), # because two step 87 | ) 88 | #print(par_text) 89 | 90 | # iterate every two trials 91 | par_df = df[df.participant == participant].reset_index(drop=True) 92 | 93 | for trial in range(0, par_df.trial.nunique(), 2): 94 | # select the current two trials 95 | # by row number 96 | first_step_df = par_df.iloc[trial, :] 97 | 98 | par_text += f"You are presented with spaceships {spaceship_0} and {spaceship_1}." 99 | par_text += " You press <<" 100 | 101 | choice_1 = pipe(par_text)[0]['generated_text'][len(par_text):] 102 | print(choice_1) 103 | if choice_1 not in [spaceship_0, spaceship_1]: 104 | choice_1 = np.random.choice([spaceship_0, spaceship_1]) 105 | print('should not happen!', flush=True) 106 | 107 | if choice_1 == spaceship_0: 108 | planet_landed = np.random.choice([planet_0, planet_1], p=[0.7, 0.3]) 109 | choice_1_idx = 0 110 | elif choice_1 == spaceship_1: 111 | planet_landed = np.random.choice([planet_1, planet_0], p=[0.7, 0.3]) 112 | choice_1_idx = 1 113 | 114 | if planet_landed == planet_0: 115 | par_text += ( 116 | f"{choice_1}>>." 117 | f" You end up on planet {planet_landed}." 118 | f" You see alien {alien_0} and alien {alien_1}." 119 | ) 120 | elif planet_landed == planet_1: 121 | par_text += ( 122 | f"{choice_1}>>." 123 | f" You end up on planet {planet_landed}." 124 | f" You see alien {alien_2} and alien {alien_3}." 
125 | ) 126 | 127 | par_text += " You press <<" 128 | 129 | choice_2 = pipe(par_text)[0]['generated_text'][len(par_text):] 130 | print(choice_2) 131 | avail_options = [alien_0, alien_1] if planet_landed == planet_0 else [alien_2, alien_3] 132 | if choice_2 not in avail_options: 133 | choice_2 = np.random.choice(avail_options) 134 | print('should not happen!', flush=True) 135 | 136 | if choice_2 == alien_0: 137 | reward = np.random.choice([1, 0], p=[first_step_df['reward.0.0'].item(), 1 - first_step_df['reward.0.0'].item()]) 138 | state_idx = 0 139 | choice_2_idx = 0 140 | elif choice_2 == alien_1: 141 | reward = np.random.choice([1, 0], p=[first_step_df['reward.0.1'].item(), 1 - first_step_df['reward.0.1'].item()]) 142 | state_idx = 0 143 | choice_2_idx = 1 144 | elif choice_2 == alien_2: 145 | reward = np.random.choice([1, 0], p=[first_step_df['reward.1.0'].item(), 1 - first_step_df['reward.1.0'].item()]) 146 | state_idx = 1 147 | choice_2_idx = 0 148 | elif choice_2 == alien_3: 149 | reward = np.random.choice([1, 0], p=[first_step_df['reward.1.1'].item(), 1 - first_step_df['reward.1.1'].item()]) 150 | state_idx = 1 151 | choice_2_idx = 1 152 | 153 | par_text += ( 154 | f"{choice_2}>>." 155 | f" You find {int(reward)} pieces of space treasure.\n" 156 | ) 157 | 158 | row1 = [participant, 0, trial, 999, choice_1_idx, 0] 159 | row2 = [participant, 0, trial+1, state_idx, choice_2_idx, reward] 160 | 161 | data.append(row1) 162 | data.append(row2) 163 | #print(par_text) 164 | #print() 165 | df = pd.DataFrame(data, columns=['participant', 'task', 'trial', 'current_state', 'choice', 'reward']) 166 | print(df) 167 | df.to_csv('simulation.csv') 168 | --------------------------------------------------------------------------------
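Usage sketch (assumed invocations, inferred from the argument parsers above; the config-file name is illustrative, the model id follows the repository's naming):

    python finetune.py train_config.json
    python openloop/kool2016when/simulate.py --model marcelbinz/Llama-3.1-Centaur-70B-adapter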