├── UCLA_VAT
│   ├── UCLA_VAT.xlsx
│   ├── UCLA_VAT_results.npz
│   ├── UCLA_VAT_ind_subj_data.xlsx
│   ├── README.md
│   ├── eval_gpt_UCLA_VAT.py
│   └── analyze_UCLA_VAT.py
├── digit_mat
│   ├── all_problems.npz
│   ├── all_problems_1thru5.npz
│   ├── gpt_matprob_results.npz
│   ├── all_4_5_rule_problems.npz
│   ├── N_unique_rules_2rule_prob.npz
│   ├── N_unique_rules_3rule_prob.npz
│   ├── exp1_GPT3_data
│   │   ├── all_prob.npz
│   │   ├── all_onerule_MC_acc.npz
│   │   ├── all_onerule_gen_acc.npz
│   │   ├── all_probcat_MC_acc.npz
│   │   ├── all_probcat_gen_acc.npz
│   │   ├── aligned_vs_permuted_MC_acc.npz
│   │   ├── aligned_vs_permuted_gen_acc.npz
│   │   ├── tworule_prob_N_unique_rules_MC.npz
│   │   ├── tworule_prog_vs_noprog_gen_acc.npz
│   │   ├── threerule_prob_N_unique_rules_MC.npz
│   │   ├── tworule_prob_N_unique_rules_gen.npz
│   │   └── threerule_prob_N_unique_rules_gen.npz
│   ├── gpt_matprob_results_1thru5.npz
│   ├── exp2_GPT3_data
│   │   ├── all_onerule_MC_acc.npz
│   │   ├── all_onerule_gen_acc.npz
│   │   ├── all_probcat_MC_acc.npz
│   │   └── all_probcat_gen_acc.npz
│   ├── exp1_behavioral_data
│   │   ├── ind_subj_results.npz
│   │   ├── probcat_MC_acc_behavior.npz
│   │   ├── probcat_gen_acc_behavior.npz
│   │   ├── probcat_MC_acc_behavior_onerule.npz
│   │   ├── tworule_prob_N_unique_rules_MC.npz
│   │   ├── tworule_prob_N_unique_rules_gen.npz
│   │   ├── probcat_gen_acc_behavior_onerule.npz
│   │   ├── threerule_prob_N_unique_rules_MC.npz
│   │   ├── threerule_prob_N_unique_rules_gen.npz
│   │   ├── aligned_vs_permuted_MC_acc_behavior.npz
│   │   ├── aligned_vs_permuted_gen_acc_behavior.npz
│   │   └── probcat_gen_acc_behavior_prog_tworule.npz
│   ├── exp2_behavioral_data
│   │   ├── ind_subj_results.npz
│   │   ├── probcat_MC_acc_behavior.npz
│   │   ├── probcat_gen_acc_behavior.npz
│   │   ├── probcat_MC_acc_behavior_onerule.npz
│   │   └── probcat_gen_acc_behavior_onerule.npz
│   ├── exp1_vs_exp2_stats.r
│   ├── exp1_corr_analysis.py
│   ├── README.md
│   ├── exp1_stats.r
│   ├── exp2_plot_GPT3_vs_human.py
│   ├── analyze_gpt3_exp2.py
│   ├── combine_problems_1thru5.py
│   ├── exp1_vs_exp2_create_stats_dset.py
│   ├── exp1_create_stats_dset.py
│   ├── eval_gpt_matprob.py
│   ├── eval_gpt_matprob_prog_1thru5.py
│   ├── analyze_gpt3_exp1.py
│   ├── exp1_plot_GPT3_vs_human.py
│   └── gen_4_5_rule_problems.py
├── letter_string
│   ├── all_prob.npz
│   ├── GPT3_results
│   │   ├── onegen_acc.npz
│   │   ├── all_gen_acc.npz
│   │   ├── realworld_acc.npz
│   │   ├── zerogen_acc.npz
│   │   ├── ind_trial_results.npz
│   │   └── prob_subtype_acc.npz
│   ├── gpt3_letterstring_results.npz
│   ├── behavioral_results
│   │   ├── all_gen_acc.npz
│   │   ├── onegen_acc.npz
│   │   ├── zerogen_acc.npz
│   │   ├── realworld_acc.npz
│   │   ├── ind_subj_results.npz
│   │   └── prob_subtype_acc.npz
│   ├── GPT3_results_noprompt
│   │   ├── all_gen_acc.npz
│   │   ├── onegen_acc.npz
│   │   ├── zerogen_acc.npz
│   │   ├── realworld_acc.npz
│   │   ├── prob_subtype_acc.npz
│   │   └── ind_trial_results.npz
│   ├── GPT3_results_sentence
│   │   ├── all_gen_acc.npz
│   │   ├── onegen_acc.npz
│   │   ├── zerogen_acc.npz
│   │   ├── realworld_acc.npz
│   │   ├── prob_subtype_acc.npz
│   │   └── ind_trial_results.npz
│   ├── gpt3_letterstring_results_noprompt.npz
│   ├── gpt3_letterstring_results_sentence.npz
│   ├── letterstring_analysis.R
│   ├── README.md
│   ├── corr_analysis.py
│   ├── create_regression_dsets.py
│   ├── eval_GPT3_letterstring_prob.py
│   ├── compare_behavior_gpt3.py
│   ├── gen_problems.py
│   └── analyze_gpt3_letterstring.py
├── story_analogies
│   ├── human_vs_gpt3_analysis.R
│   ├── README.md
│   ├── gpt4_data.csv
│   ├── analyze_GPT3_story_analogies.py
│   ├── ind_cond_analyses.py
│   ├── eval_GPT3_story_analogies.py
│   └── human_vs_gpt3_data.csv
├── LICENSE
└── README.md
/UCLA_VAT/UCLA_VAT.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/UCLA_VAT/UCLA_VAT.xlsx
--------------------------------------------------------------------------------
/digit_mat/all_problems.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/all_problems.npz
--------------------------------------------------------------------------------
/letter_string/all_prob.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/all_prob.npz
--------------------------------------------------------------------------------
/UCLA_VAT/UCLA_VAT_results.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/UCLA_VAT/UCLA_VAT_results.npz
--------------------------------------------------------------------------------
/digit_mat/all_problems_1thru5.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/all_problems_1thru5.npz
--------------------------------------------------------------------------------
/digit_mat/gpt_matprob_results.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/gpt_matprob_results.npz
--------------------------------------------------------------------------------
/UCLA_VAT/UCLA_VAT_ind_subj_data.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/UCLA_VAT/UCLA_VAT_ind_subj_data.xlsx
--------------------------------------------------------------------------------
/digit_mat/all_4_5_rule_problems.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/all_4_5_rule_problems.npz
--------------------------------------------------------------------------------
/digit_mat/N_unique_rules_2rule_prob.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/N_unique_rules_2rule_prob.npz
--------------------------------------------------------------------------------
/digit_mat/N_unique_rules_3rule_prob.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/N_unique_rules_3rule_prob.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_GPT3_data/all_prob.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/all_prob.npz
--------------------------------------------------------------------------------
/digit_mat/gpt_matprob_results_1thru5.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/gpt_matprob_results_1thru5.npz
--------------------------------------------------------------------------------
/letter_string/GPT3_results/onegen_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results/onegen_acc.npz
--------------------------------------------------------------------------------
/letter_string/GPT3_results/all_gen_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results/all_gen_acc.npz
--------------------------------------------------------------------------------
/letter_string/GPT3_results/realworld_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results/realworld_acc.npz
--------------------------------------------------------------------------------
/letter_string/GPT3_results/zerogen_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results/zerogen_acc.npz
--------------------------------------------------------------------------------
/letter_string/gpt3_letterstring_results.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/gpt3_letterstring_results.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_GPT3_data/all_onerule_MC_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/all_onerule_MC_acc.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_GPT3_data/all_onerule_gen_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/all_onerule_gen_acc.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_GPT3_data/all_probcat_MC_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/all_probcat_MC_acc.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_GPT3_data/all_probcat_gen_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/all_probcat_gen_acc.npz
--------------------------------------------------------------------------------
/digit_mat/exp2_GPT3_data/all_onerule_MC_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp2_GPT3_data/all_onerule_MC_acc.npz
--------------------------------------------------------------------------------
/digit_mat/exp2_GPT3_data/all_onerule_gen_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp2_GPT3_data/all_onerule_gen_acc.npz
--------------------------------------------------------------------------------
/digit_mat/exp2_GPT3_data/all_probcat_MC_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp2_GPT3_data/all_probcat_MC_acc.npz
--------------------------------------------------------------------------------
/digit_mat/exp2_GPT3_data/all_probcat_gen_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp2_GPT3_data/all_probcat_gen_acc.npz
--------------------------------------------------------------------------------
/letter_string/GPT3_results/ind_trial_results.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results/ind_trial_results.npz
--------------------------------------------------------------------------------
/letter_string/GPT3_results/prob_subtype_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results/prob_subtype_acc.npz
--------------------------------------------------------------------------------
/letter_string/behavioral_results/all_gen_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/behavioral_results/all_gen_acc.npz
--------------------------------------------------------------------------------
/letter_string/behavioral_results/onegen_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/behavioral_results/onegen_acc.npz
--------------------------------------------------------------------------------
/letter_string/behavioral_results/zerogen_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/behavioral_results/zerogen_acc.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_behavioral_data/ind_subj_results.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/ind_subj_results.npz
--------------------------------------------------------------------------------
/digit_mat/exp2_behavioral_data/ind_subj_results.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp2_behavioral_data/ind_subj_results.npz
--------------------------------------------------------------------------------
/letter_string/GPT3_results_noprompt/all_gen_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_noprompt/all_gen_acc.npz
--------------------------------------------------------------------------------
/letter_string/GPT3_results_noprompt/onegen_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_noprompt/onegen_acc.npz
--------------------------------------------------------------------------------
/letter_string/GPT3_results_noprompt/zerogen_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_noprompt/zerogen_acc.npz
--------------------------------------------------------------------------------
/letter_string/GPT3_results_sentence/all_gen_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_sentence/all_gen_acc.npz
--------------------------------------------------------------------------------
/letter_string/GPT3_results_sentence/onegen_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_sentence/onegen_acc.npz
--------------------------------------------------------------------------------
/letter_string/GPT3_results_sentence/zerogen_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_sentence/zerogen_acc.npz
--------------------------------------------------------------------------------
/letter_string/behavioral_results/realworld_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/behavioral_results/realworld_acc.npz
--------------------------------------------------------------------------------
/letter_string/GPT3_results_noprompt/realworld_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_noprompt/realworld_acc.npz
--------------------------------------------------------------------------------
/letter_string/GPT3_results_sentence/realworld_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_sentence/realworld_acc.npz
--------------------------------------------------------------------------------
/letter_string/behavioral_results/ind_subj_results.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/behavioral_results/ind_subj_results.npz
--------------------------------------------------------------------------------
/letter_string/behavioral_results/prob_subtype_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/behavioral_results/prob_subtype_acc.npz
--------------------------------------------------------------------------------
/letter_string/gpt3_letterstring_results_noprompt.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/gpt3_letterstring_results_noprompt.npz
--------------------------------------------------------------------------------
/letter_string/gpt3_letterstring_results_sentence.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/gpt3_letterstring_results_sentence.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_GPT3_data/aligned_vs_permuted_MC_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/aligned_vs_permuted_MC_acc.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_GPT3_data/aligned_vs_permuted_gen_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/aligned_vs_permuted_gen_acc.npz
--------------------------------------------------------------------------------
/letter_string/GPT3_results_noprompt/prob_subtype_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_noprompt/prob_subtype_acc.npz
--------------------------------------------------------------------------------
/letter_string/GPT3_results_sentence/prob_subtype_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_sentence/prob_subtype_acc.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_GPT3_data/tworule_prob_N_unique_rules_MC.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/tworule_prob_N_unique_rules_MC.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_GPT3_data/tworule_prog_vs_noprog_gen_acc.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/tworule_prog_vs_noprog_gen_acc.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_behavioral_data/probcat_MC_acc_behavior.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/probcat_MC_acc_behavior.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_behavioral_data/probcat_gen_acc_behavior.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/probcat_gen_acc_behavior.npz
--------------------------------------------------------------------------------
/digit_mat/exp2_behavioral_data/probcat_MC_acc_behavior.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp2_behavioral_data/probcat_MC_acc_behavior.npz
--------------------------------------------------------------------------------
/digit_mat/exp2_behavioral_data/probcat_gen_acc_behavior.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp2_behavioral_data/probcat_gen_acc_behavior.npz
--------------------------------------------------------------------------------
/letter_string/GPT3_results_noprompt/ind_trial_results.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_noprompt/ind_trial_results.npz
--------------------------------------------------------------------------------
/letter_string/GPT3_results_sentence/ind_trial_results.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_sentence/ind_trial_results.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_GPT3_data/threerule_prob_N_unique_rules_MC.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/threerule_prob_N_unique_rules_MC.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_GPT3_data/tworule_prob_N_unique_rules_gen.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/tworule_prob_N_unique_rules_gen.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_GPT3_data/threerule_prob_N_unique_rules_gen.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/threerule_prob_N_unique_rules_gen.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_behavioral_data/probcat_MC_acc_behavior_onerule.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/probcat_MC_acc_behavior_onerule.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_behavioral_data/tworule_prob_N_unique_rules_MC.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/tworule_prob_N_unique_rules_MC.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_behavioral_data/tworule_prob_N_unique_rules_gen.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/tworule_prob_N_unique_rules_gen.npz
--------------------------------------------------------------------------------
/digit_mat/exp2_behavioral_data/probcat_MC_acc_behavior_onerule.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp2_behavioral_data/probcat_MC_acc_behavior_onerule.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_behavioral_data/probcat_gen_acc_behavior_onerule.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/probcat_gen_acc_behavior_onerule.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_behavioral_data/threerule_prob_N_unique_rules_MC.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/threerule_prob_N_unique_rules_MC.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_behavioral_data/threerule_prob_N_unique_rules_gen.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/threerule_prob_N_unique_rules_gen.npz
--------------------------------------------------------------------------------
/digit_mat/exp2_behavioral_data/probcat_gen_acc_behavior_onerule.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp2_behavioral_data/probcat_gen_acc_behavior_onerule.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_behavioral_data/aligned_vs_permuted_MC_acc_behavior.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/aligned_vs_permuted_MC_acc_behavior.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_behavioral_data/aligned_vs_permuted_gen_acc_behavior.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/aligned_vs_permuted_gen_acc_behavior.npz
--------------------------------------------------------------------------------
/digit_mat/exp1_behavioral_data/probcat_gen_acc_behavior_prog_tworule.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/probcat_gen_acc_behavior_prog_tworule.npz
--------------------------------------------------------------------------------
/story_analogies/human_vs_gpt3_analysis.R:
--------------------------------------------------------------------------------
1 | setwd("./")
2 |
3 | # Human vs. GPT-3
4 | data <-read.csv("./human_vs_gpt3_data.csv")
5 | human_vs_gpt3_model <- glm(correct_pred ~ human_vs_gpt, data=data, family="binomial")
6 | summary(human_vs_gpt3_model)
7 | human_vs_gpt3_OR <- exp(cbind(OR = coef(human_vs_gpt3_model), confint(human_vs_gpt3_model)))
8 | summary(human_vs_gpt3_OR)
9 |
10 |
11 |
--------------------------------------------------------------------------------
/UCLA_VAT/README.md:
--------------------------------------------------------------------------------
1 | ## UCLA Verbal Analogy Test
2 |
3 | To evaluate GPT-3 on UCLA VAT, run:
4 | ```
5 | python3 ./eval_gpt_UCLA_VAT.py
6 | ```
7 | Note that you will need to enter your OpenAI API key (line 8).
8 |
9 | To analyze GPT-3's responses and compare with human behavior, run:
10 | ```
11 | python3 ./analyze_UCLA_VAT.py
12 | ```
13 | Note that results for human participants and GPT-3 are already included in this repository.
14 |
--------------------------------------------------------------------------------
/digit_mat/exp1_vs_exp2_stats.r:
--------------------------------------------------------------------------------
1 | setwd("./")
2 | data <-read.csv("./exp1_vs_exp2_all_data.csv")
3 |
4 | # Generative task - problem type X experiment (progressive vs. random order) interaction
5 | # Human only
6 | human_gen <- glm(gen_correct_pred ~ prob_type + exp1_vs_exp2 + prob_type:exp1_vs_exp2, data=subset(data, human_vs_gpt==0), family="binomial")
7 | summary(human_gen)
8 | # GPT-3 only
9 | GPT3_gen <- glm(gen_correct_pred ~ prob_type + exp1_vs_exp2 + prob_type:exp1_vs_exp2, data=subset(data, human_vs_gpt==1), family="binomial")
10 | summary(GPT3_gen)
11 |
--------------------------------------------------------------------------------
/digit_mat/exp1_corr_analysis.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.stats
3 |
4 | # Load human data
5 | human_prob_acc = np.load('./exp1_behavioral_data/ind_subj_results.npz')['all_subj_gen_correct_pred'].mean(0)
6 | gpt3_prob_acc = np.load('./exp1_GPT3_data/all_prob.npz')['all_gen'].reshape((-1,32)).astype(float).mean(0)
7 |
8 | # Correlation analysis
9 | corr_results = scipy.stats.pearsonr(gpt3_prob_acc, human_prob_acc)
10 | print('correlation analysis:')
11 | print('r = ' + str(np.around(corr_results[0],4)))
12 | print('p = ' + str(np.around(corr_results[1],4)))
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Taylor Webb licenses this file to You under the Apache License, Version 2.0 (the "License");
2 | you may not use this file except in compliance with the License. You may obtain a copy of the License at:
3 |
4 | http://www.apache.org/licenses/LICENSE-2.0
5 |
6 | Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
7 | on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8 | See the License for the specific language governing permissions and limitations under the License.
9 |
--------------------------------------------------------------------------------
/story_analogies/README.md:
--------------------------------------------------------------------------------
1 | ## Story Analogies
2 |
3 | To perform individual analyses for text-davinci-003, GPT-4, and human participants, run:
4 | ```
5 | python3 ./ind_cond_analyses.py
6 | ```
7 | To perform analysis comparing human and GPT-3 performance, run the following R script:
8 | ```
9 | ./human_vs_gpt3_analysis.R
10 | ```
11 |
12 | Original story materials can be downloaded from [http://cvl.psych.ucla.edu/resources/AnalogyInventory.zip](http://cvl.psych.ucla.edu/resources/AnalogyInventory.zip) (file name: ```Cognitive Psychology.xlsx```, sheet name: ```Rattermann```).
13 |
--------------------------------------------------------------------------------
/letter_string/letterstring_analysis.R:
--------------------------------------------------------------------------------
1 | setwd("./")
2 | library(dplyr)
3 | library(Matrix)
4 | library(lme4)
5 | data <-read.csv("./letterstring_data.csv")
6 |
7 | # Omnibus model
8 | model <- glm(correct_pred ~ human_vs_gpt + N_gen + human_vs_gpt:N_gen, data=data, family="binomial")
9 | summary(model)
10 |
11 | # Zero-generalization problems
12 | zerogen_model <- glm(correct_pred ~ human_vs_gpt, data=subset(data, N_gen==0), family="binomial")
13 | summary(zerogen_model)
14 |
15 | # Effect of number of generalizations
16 | # Human
17 | model <- glm(correct_pred ~ N_gen, data=subset(data, human_vs_gpt==0), family="binomial")
18 | summary(model)
19 | # GPT-3
20 | model <- glm(correct_pred ~ N_gen, data=subset(data, human_vs_gpt==1), family="binomial")
21 | summary(model)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Emergent Analogical Reasoning in Large Language Models
2 |
3 | Code for the paper [Emergent Analogical Reasoning in Large Language Models](https://arxiv.org/abs/2212.09196).
4 |
5 | Code for each set of experiments, along with detailed instructions, is contained within separate directories.
6 |
7 | ## Prerequisites
8 |
9 | - Python 3
10 | - [OpenAI Python Library](https://github.com/openai/openai-python)
11 | - [NumPy](https://numpy.org/)
12 | - [SciPy](https://scipy.org/)
13 | - [statsmodels](https://www.statsmodels.org/stable/index.html)
14 | - [Matplotlib](https://matplotlib.org/)
15 | - [pandas](https://pandas.pydata.org/)
16 | - [R](https://www.r-project.org/)
17 |
18 |
19 | ## Authorship
20 |
21 | All code was written by [Taylor Webb](https://github.com/taylorwwebb).
22 |
--------------------------------------------------------------------------------
/story_analogies/gpt4_data.csv:
--------------------------------------------------------------------------------
1 | prob_ID,correct_pred,correct_AB,analogy_vs_similarity
2 | 1,1,0,0
3 | 2,1,0,0
4 | 3,1,0,0
5 | 4,1,0,0
6 | 5,1,0,0
7 | 6,1,0,0
8 | 7,1,0,0
9 | 8,0,0,0
10 | 9,1,0,0
11 | 10,0,0,0
12 | 11,1,0,0
13 | 12,1,0,0
14 | 13,1,0,0
15 | 14,1,0,0
16 | 15,1,0,0
17 | 16,1,0,0
18 | 17,1,0,0
19 | 18,1,0,0
20 | 1,0,1,0
21 | 2,1,1,0
22 | 3,0,1,0
23 | 4,1,1,0
24 | 5,1,1,0
25 | 6,1,1,0
26 | 7,0,1,0
27 | 8,0,1,0
28 | 9,1,1,0
29 | 10,0,1,0
30 | 11,1,1,0
31 | 12,1,1,0
32 | 13,0,1,0
33 | 14,1,1,0
34 | 15,0,1,0
35 | 16,1,1,0
36 | 17,1,1,0
37 | 18,1,1,0
38 | 1,1,0,1
39 | 2,1,0,1
40 | 3,1,0,1
41 | 4,1,0,1
42 | 5,1,0,1
43 | 6,1,0,1
44 | 7,1,0,1
45 | 8,1,0,1
46 | 9,1,0,1
47 | 10,1,0,1
48 | 11,1,0,1
49 | 12,1,0,1
50 | 13,1,0,1
51 | 14,1,0,1
52 | 15,1,0,1
53 | 16,1,0,1
54 | 17,1,0,1
55 | 18,1,0,1
56 | 1,1,1,1
57 | 2,1,1,1
58 | 3,1,1,1
59 | 4,1,1,1
60 | 5,1,1,1
61 | 6,1,1,1
62 | 7,1,1,1
63 | 8,0,1,1
64 | 9,1,1,1
65 | 10,1,1,1
66 | 11,1,1,1
67 | 12,1,1,1
68 | 13,1,1,1
69 | 14,1,1,1
70 | 15,1,1,1
71 | 16,1,1,1
72 | 17,1,1,1
73 | 18,1,1,1
--------------------------------------------------------------------------------
/letter_string/README.md:
--------------------------------------------------------------------------------
1 | ## Letter String Analogies
2 |
3 | To create new letter string problems, run:
4 | ```
5 | python3 ./gen_problems.py
6 | ```
7 | Problems are contained in:
8 | ```
9 | ./all_prob.npz
10 | ```
11 | To evaluate GPT-3 on letter string problems, run:
12 | ```
13 | python3 ./eval_GPT3_letterstring_prob.py
14 | ```
15 | Note that you will need to enter your OpenAI API key (line 15).
16 |
17 | To analyze GPT-3's responses, run:
18 | ```
19 | python3 ./analyze_gpt3_letterstring.py
20 | ```
21 | To plot figures comparing GPT-3 with human performance, run:
22 | ```
23 | python3 ./compare_behavior_gpt3.py
24 | ```
25 | To perform regression analyses, run:
26 | ```
27 | python3 ./create_regression_dsets.py
28 | ```
29 | and run the following R script:
30 | ```
31 | ./letterstring_analysis.R
32 | ```
33 | To perform correlation analyses, run:
34 | ```
35 | python3 ./corr_analysis.py
36 | ```
37 | To evaluate GPT-3 on letter string problems presented in alternative formats (without a prompt, or in the form of a sentence), use the ```--noprompt``` or ```--sentence``` arguments for these scripts.
38 |
39 | Note that results for human participants and GPT-3 are already included in this repository.
40 |
--------------------------------------------------------------------------------
/letter_string/corr_analysis.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import argparse
4 | import scipy.stats
5 |
6 | # Settings
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('--sentence', action='store_true', help="Present problem in sentence format.")
9 | parser.add_argument('--noprompt', action='store_true', help="Present problem without prompt.")
10 | args = parser.parse_args()
11 |
12 | # Load data
13 | human_data = np.load('./behavioral_results/prob_subtype_acc.npz')
14 | human_subtype_acc = human_data['subtype_acc']
15 | human_subtype_counts = human_data['subtype_counts']
16 | if args.sentence:
17 | gpt3_data = np.load('./GPT3_results_sentence/prob_subtype_acc.npz')
18 | elif args.noprompt:
19 | gpt3_data = np.load('./GPT3_results_noprompt/prob_subtype_acc.npz')
20 | else:
21 | gpt3_data = np.load('./GPT3_results/prob_subtype_acc.npz')
22 | gpt3_subtype_acc = gpt3_data['subtype_acc']
23 | gpt3_subtype_counts = gpt3_data['subtype_counts']
24 |
25 | # Minimum number of trials per subtype
26 | min_trials = 5
27 | include = np.all(np.stack([(human_subtype_counts > min_trials),(gpt3_subtype_counts > min_trials)]),0)
28 | print('minimum number of trials per subtype = ' + str(min_trials))
29 | print(str(include.sum()) + ' out of ' + str(len(include)) + ' subtypes included')
30 |
31 | # Correlation analysis
32 | corr_results = scipy.stats.pearsonr(gpt3_subtype_acc[include], human_subtype_acc[include])
33 | print('correlation analysis:')
34 | print('r = ' + str(np.around(corr_results[0],4)))
35 | print('p = ' + str(np.around(corr_results[1],10)))
--------------------------------------------------------------------------------
/digit_mat/README.md:
--------------------------------------------------------------------------------
1 | ## Digit Matrices
2 |
3 | To create new digit matrix problems, run:
4 | ```
5 | python3 ./gen_problems.py
6 | python3 ./gen_4_5_rule_problems.py
7 | python3 ./combine_problems_1thru5.py
8 | ```
9 | Problems for experiment #1 (1-3 rule and logic problems, results in Figure 3 of paper) are contained in:
10 | ```
11 | ./all_problems.npz
12 | ```
13 | and problems for experiment #2 (1-5 rule problems, results in Supplementary Figure S4) are contained in:
14 | ```
15 | ./all_problems_1thru5.npz
16 | ```
17 |
18 | To evaluate GPT-3 on experiment #1, run:
19 | ```
20 | python3 ./eval_gpt_matprob.py
21 | ```
22 | Note that you will need to enter your OpenAI API key (line 14).
23 |
24 | To evaluate GPT-3 on experiment #2, run:
25 | ```
26 | python3 ./eval_gpt_matprob_prog_1thru5.py
27 | ```
28 | To analyze GPT-3's responses on these experiments, run:
29 | ```
30 | python3 ./analyze_gpt3_exp1.py
31 | python3 ./analyze_gpt3_exp2.py
32 | ```
33 | To plot figures comparing GPT-3 with human performance, run:
34 | ```
35 | python3 ./exp1_plot_GPT3_vs_human.py
36 | python3 ./exp2_plot_GPT3_vs_human.py
37 | ```
38 | To perform statistical analysis for experiment #1, run:
39 | ```
40 | python3 ./exp1_create_stats_dset.py
41 | ```
42 | and run the following R script:
43 | ```
44 | ./exp1_stats.r
45 | ```
46 | To perform analysis comparing experiments #1 and #2, run:
47 | ```
48 | python3 ./exp1_vs_exp2_create_stats_dset.py
49 | ```
50 | and run the following R script:
51 | ```
52 | ./exp1_vs_exp2_stats.r
53 | ```
54 | Note that results for human participants and GPT-3 are already included in this repository.
55 |
--------------------------------------------------------------------------------
/story_analogies/analyze_GPT3_story_analogies.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | # Load data
4 | results = np.load('./gpt3_rattermann_results.npz')
5 | all_analogy_results = results['all_analogy_results']
6 | all_analogy_correct = results['all_analogy_correct']
7 | all_similarity_results = results['all_similarity_results']
8 | all_similarity_correct = results['all_similarity_correct']
9 |
10 | # Score
11 | all_analogy_correct_pred = []
12 | for i in range(len(all_analogy_results)):
13 | print('response: ' + all_analogy_results[i])
14 | print('correct_answer: ' + all_analogy_correct[i])
15 | correct_pred = int(input("Correct? (correct=1, incorrect=0): "))
16 | print(' ')
17 | all_analogy_correct_pred.append(correct_pred)
18 | all_similarity_correct_pred = []
19 | for i in range(len(all_similarity_results)):
20 | print('response: ' + all_similarity_results[i])
21 | print('correct_answer: ' + all_similarity_correct[i])
22 | correct_pred = int(input("Correct? (correct=1, incorrect=0): "))
23 | print('')
24 | all_similarity_correct_pred.append(correct_pred)
25 |
26 | # Report accuracy
27 | analogy_acc = np.mean(all_analogy_correct_pred)
28 | print('analogy acc. = ' + str(np.around(analogy_acc,4)))
29 | similarity_acc = np.mean(all_similarity_correct_pred)
30 | print('similarity acc. = ' + str(np.around(similarity_acc,4)))
31 |
32 | # Save
33 | np.savez('./gpt3_rattermann_results.npz', all_analogy_results=all_analogy_results, all_analogy_correct=all_analogy_correct, all_analogy_correct_pred=all_analogy_correct_pred,
34 | all_similarity_results=all_similarity_results, all_similarity_correct=all_similarity_correct, all_similarity_correct_pred=all_similarity_correct_pred)
35 |
36 |
--------------------------------------------------------------------------------
/digit_mat/exp1_stats.r:
--------------------------------------------------------------------------------
1 | # Load data
2 | setwd("./")
3 | data <-read.csv("./exp1_all_data.csv")
4 |
5 | # Overall generative performance
6 | all_gen <- glm(gen_correct_pred ~ prob_type + human_vs_gpt + prob_type:human_vs_gpt, data=data, family="binomial")
7 | summary(all_gen)
8 | # Overall multiple-choice performance
9 | all_MC <- glm(MC_correct_pred ~ prob_type + human_vs_gpt + prob_type:human_vs_gpt, data=data, family="binomial")
10 | summary(all_MC)
11 |
12 | # Two-rule problems, generative performance, progression rule vs. no progression rule
13 | # Human only
14 | human_prog_vs_noprog <- glm(gen_correct_pred ~ tworule_prog_noprog, data=subset(subset(data, prob_type==1), human_vs_gpt==0), family="binomial")
15 | summary(human_prog_vs_noprog)
16 | # GPT-3 only
17 | GPT3_prog_vs_noprog <- glm(gen_correct_pred ~ tworule_prog_noprog, data=subset(subset(data, prob_type==1), human_vs_gpt==1), family="binomial")
18 | summary(GPT3_prog_vs_noprog)
19 |
20 | # Analysis of relational complexity
21 | # Human only
22 | human_relcompl <- glm(gen_correct_pred ~ N_unique_rules, data=subset(subset(data, prob_type==2), human_vs_gpt==0), family="binomial")
23 | summary(human_relcompl)
24 | # GPT-3 only
25 | GPT3_relcompl <- glm(gen_correct_pred ~ N_unique_rules, data=subset(subset(data, prob_type==2), human_vs_gpt==1), family="binomial")
26 | summary(GPT3_relcompl)
27 |
28 | # Aligned vs. permuted logic problems
29 | # Human only
30 | human_aligned_permute <- glm(gen_correct_pred ~ aligned_permuted, data=subset(subset(data, prob_type==3), human_vs_gpt==0), family="binomial")
31 | summary(human_aligned_permute)
32 | # GPT-3 only
33 | GPT3_aligned_permute <- glm(gen_correct_pred ~ aligned_permuted, data=subset(subset(data, prob_type==3), human_vs_gpt==1), family="binomial")
34 | summary(GPT3_aligned_permute)
35 |
36 |
37 |
--------------------------------------------------------------------------------
/letter_string/create_regression_dsets.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import csv
3 | import builtins
4 |
5 | # Load human behavioral data
6 | human_data = np.load('./behavioral_results/ind_subj_results.npz')
7 | human_correct_pred = human_data['all_subj_correct_pred']
8 | human_prob_subtype = human_data['all_subj_prob_subtype']
9 | human_N_gen = human_data['all_subj_N_gen']
10 | human_realworld = human_data['all_subj_realworld']
11 |
12 | # Load GPT-3 data
13 | gpt3_data = np.load('./GPT3_results/ind_trial_results.npz')
14 | gpt3_correct_pred = gpt3_data['all_prob_type_correct_pred']
15 | gpt3_prob_subtype = gpt3_data['all_prob_type_subtype']
16 | gpt3_N_gen = gpt3_data['all_prob_N_gen']
17 | gpt3_realworld = gpt3_data['all_prob_realworld']
18 |
19 | # Create dataset
20 | subjID = []
21 | correct_pred = []
22 | human_vs_gpt = []
23 | prob_subtype = []
24 | N_gen = []
25 | realworld = []
26 |
27 | # Add human data
28 | N_subj = human_correct_pred.shape[0]
29 | N_prob = human_correct_pred.shape[1]
30 | for s in range(N_subj):
31 | for p in range(N_prob):
32 | subjID.append(s)
33 | correct_pred.append(int(human_correct_pred[s,p]))
34 | human_vs_gpt.append(0)
35 | N_gen.append(human_N_gen[s,p])
36 | realworld.append(human_realworld[s,p])
37 | if p < human_prob_subtype.shape[1]:
38 | prob_subtype.append(human_prob_subtype[s,p])
39 | else:
40 | prob_subtype.append(-1)
41 |
42 | # Add GPT-3 data
43 | N_trials_per_prob = gpt3_correct_pred.shape[1]
44 | for t in range(N_trials_per_prob):
45 | for p in range(N_prob):
46 | subjID.append(s+1)
47 | correct_pred.append(int(gpt3_correct_pred[p,t]))
48 | human_vs_gpt.append(1)
49 | N_gen.append(gpt3_N_gen[p,t])
50 | realworld.append(gpt3_realworld[p,t])
51 | if p < gpt3_prob_subtype.shape[0]:
52 | prob_subtype.append(gpt3_prob_subtype[p,t])
53 | else:
54 | prob_subtype.append(-1)
55 |
56 | # Write csv files
57 | # Create file
58 | f = open('./letterstring_data.csv', 'w')
59 | writer = csv.writer(f)
60 | # Header
61 | header = ['subjID', 'correct_pred', 'human_vs_gpt', 'prob_subtype', 'N_gen']
62 | writer.writerow(header)
63 | # Write data
64 | for i in range(len(subjID)):
65 | if realworld[i] == 0:
66 | data_row = [subjID[i], correct_pred[i], human_vs_gpt[i], prob_subtype[i], N_gen[i]]
67 | writer.writerow(data_row)
68 | # Close file
69 | f.close()
70 |
--------------------------------------------------------------------------------
/letter_string/eval_GPT3_letterstring_prob.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import numpy as np
3 | import builtins
4 | import argparse
5 | import os
6 | import time
7 |
8 | # Settings
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--sentence', action='store_true', help="Present problem in sentence format.")
11 | parser.add_argument('--noprompt', action='store_true', help="Present problem without prompt.")
12 | args = parser.parse_args()
13 |
14 | # GPT-3 settings
15 | openai.api_key = "FILL_IN_API_KEY_HERE"
16 | if args.sentence:
17 | kwargs = { "engine":"text-davinci-003", "temperature":0, "max_tokens":40, "echo":False, "logprobs":1, }
18 | else:
19 | kwargs = { "engine":"text-davinci-003", "temperature":0, "max_tokens":40, "stop":"\n", "echo":False, "logprobs":1, }
20 |
21 | # Load all problems
22 | all_prob = np.load('./all_prob.npz', allow_pickle=True)['all_prob']
23 | prob_types = builtins.list(all_prob.item().keys())
24 | N_prob_types = len(prob_types)
25 |
26 | # Evaluate
27 | N_trials_per_prob_type = 50
28 | all_prob_type_responses = []
29 | for p in range(N_prob_types):
30 | 	print('problem type ' + str(p+1) + ' of ' + str(N_prob_types) + '...')
31 | prob_type_responses = []
32 | for t in range(N_trials_per_prob_type):
33 | print('trial ' + str(t+1) + ' of ' + str(N_trials_per_prob_type) + '...')
34 | # Generate prompt
35 | prob = all_prob.item()[prob_types[p]]['prob'][t]
36 | prompt = ''
37 | if not args.noprompt:
38 | prompt += "Let's try to complete the pattern:\n\n"
39 | if args.sentence:
40 | prompt += 'If '
41 | for i in range(len(prob[0][0])):
42 | prompt += str(prob[0][0][i])
43 | if i < len(prob[0][0]) - 1:
44 | prompt += ' '
45 | prompt += ' changes to '
46 | for i in range(len(prob[0][1])):
47 | prompt += str(prob[0][1][i])
48 | if i < len(prob[0][1]) - 1:
49 | prompt += ' '
50 | prompt += ', then '
51 | for i in range(len(prob[1][0])):
52 | prompt += str(prob[1][0][i])
53 | if i < len(prob[1][0]) - 1:
54 | prompt += ' '
55 | prompt += ' should change to '
56 | else:
57 | prompt += '['
58 | for i in range(len(prob[0][0])):
59 | prompt += str(prob[0][0][i])
60 | if i < len(prob[0][0]) - 1:
61 | prompt += ' '
62 | prompt += '] ['
63 | for i in range(len(prob[0][1])):
64 | prompt += str(prob[0][1][i])
65 | if i < len(prob[0][1]) - 1:
66 | prompt += ' '
67 | prompt += ']\n['
68 | for i in range(len(prob[1][0])):
69 | prompt += str(prob[1][0][i])
70 | if i < len(prob[1][0]) - 1:
71 | prompt += ' '
72 | prompt += '] ['
73 | # Get response
74 | response = []
75 | while len(response) == 0:
76 | try:
77 | response = openai.Completion.create(prompt=prompt, **kwargs)
78 | 			except Exception:
79 | print('trying again...')
80 | time.sleep(5)
81 | prob_type_responses.append(response['choices'][0]['text'])
82 | all_prob_type_responses.append(prob_type_responses)
83 | # Save
84 | save_fname = './gpt3_letterstring_results'
85 | if args.sentence:
86 | save_fname += '_sentence'
87 | if args.noprompt:
88 | save_fname += '_noprompt'
89 | save_fname += '.npz'
90 | np.savez(save_fname, all_prob_type_responses=all_prob_type_responses)  # np.savez has no allow_pickle kwarg; passing it stored a spurious 'allow_pickle' array
91 |
92 |
93 |
94 |
--------------------------------------------------------------------------------
/story_analogies/ind_cond_analyses.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import numpy as np
3 | from scipy.stats import binomtest, ttest_1samp
4 |
5 | # Human vs. GPT-3 data
6 | human_data_file = open('./human_vs_gpt3_data.csv')
7 | csvreader = csv.reader(human_data_file)
8 | header = []
9 | header = np.array(next(csvreader))
10 | rows = []
11 | for row in csvreader:
12 | rows.append(row)
13 | rows = np.array(rows).astype(int)
14 |
15 | # Human data analyses
16 | human_vs_gpt = np.where(header == 'human_vs_gpt')[0][0]
17 | human_data = rows[rows[:,human_vs_gpt]==0, :]
18 | human_acc = []
19 | subjID = np.where(header == 'subjID')[0][0]
20 | N_subj = np.max(np.unique(human_data[:,subjID]))
21 | for s in range(N_subj):
22 | subj_data = human_data[human_data[:,subjID] == s+1, :]
23 | correct_pred = np.where(header == 'correct_pred')[0][0]
24 | subj_acc = np.mean(subj_data[:,correct_pred])
25 | human_acc.append(subj_acc)
26 | print('\nHuman:')
27 | print('Combined:')
28 | print(ttest_1samp(human_acc, 0.5))
29 | # Near analogy
30 | analogy_vs_similarity = np.where(header == 'analogy_vs_similarity')[0][0]
31 | human_similarity_data = human_data[human_data[:,analogy_vs_similarity]==1, :]
32 | human_similarity_acc = []
33 | for s in range(N_subj):
34 | subj_data = human_similarity_data[human_similarity_data[:,subjID] == s+1, :]
35 | subj_acc = np.mean(subj_data[:,correct_pred])
36 | human_similarity_acc.append(subj_acc)
37 | print('Near analogy:')
38 | print(ttest_1samp(human_similarity_acc, 0.5))
39 | # Far analogy
40 | human_analogy_data = human_data[human_data[:,analogy_vs_similarity]==0, :]
41 | human_analogy_acc = []
42 | for s in range(N_subj):
43 | subj_data = human_analogy_data[human_analogy_data[:,subjID] == s+1, :]
44 | subj_acc = np.mean(subj_data[:,correct_pred])
45 | human_analogy_acc.append(subj_acc)
46 | print('Far analogy:')
47 | print(ttest_1samp(human_analogy_acc, 0.5))
48 |
49 | # GPT-3 data analyses
50 | gpt3_data = rows[rows[:,human_vs_gpt]==1, :]
51 | print('\nGPT-3:')
52 | print('Combined:')
53 | print(binomtest(gpt3_data[:,correct_pred].sum(), n=gpt3_data.shape[0], p=0.5))
54 | # Near analogy
55 | gpt3_similarity_data = gpt3_data[gpt3_data[:,analogy_vs_similarity]==1, :]
56 | print('Near analogy:')
57 | print(binomtest(gpt3_similarity_data[:,correct_pred].sum(), n=gpt3_similarity_data.shape[0], p=0.5))
58 | # Far analogy
59 | gpt3_analogy_data = gpt3_data[gpt3_data[:,analogy_vs_similarity]==0, :]
60 | print('Far analogy:')
61 | print(binomtest(gpt3_analogy_data[:,correct_pred].sum(), n=gpt3_analogy_data.shape[0], p=0.5))
62 |
63 | # GPT-4 data
64 | human_data_file = open('./gpt4_data.csv')
65 | csvreader = csv.reader(human_data_file)
66 | header = []
67 | header = np.array(next(csvreader))
68 | rows = []
69 | for row in csvreader:
70 | rows.append(row)
71 | rows = np.array(rows).astype(int)
72 |
73 | # GPT-4 data analyses
74 | gpt4_data = rows
75 | correct_pred = np.where(header == 'correct_pred')[0][0]
76 | print('\nGPT-4:')
77 | print('Combined:')
78 | print(binomtest(gpt4_data[:,correct_pred].sum(), n=gpt4_data.shape[0], p=0.5))
79 | # Near analogy
80 | analogy_vs_similarity = np.where(header == 'analogy_vs_similarity')[0][0]
81 | gpt4_similarity_data = gpt4_data[gpt4_data[:,analogy_vs_similarity]==1, :]
82 | print('Near analogy:')
83 | print(binomtest(gpt4_similarity_data[:,correct_pred].sum(), n=gpt4_similarity_data.shape[0], p=0.5))
84 | # Far analogy
85 | gpt4_analogy_data = gpt4_data[gpt4_data[:,analogy_vs_similarity]==0, :]
86 | print('Far analogy:')
87 | print(binomtest(gpt4_analogy_data[:,correct_pred].sum(), n=gpt4_analogy_data.shape[0], p=0.5))
88 | print(' ')
89 |
90 |
--------------------------------------------------------------------------------
/UCLA_VAT/eval_gpt_UCLA_VAT.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import numpy as np
3 | import pandas as pd
4 | import builtins
5 | import time
6 |
7 | # GPT-3 settings
8 | openai.api_key = "FILL_IN_API_KEY_HERE"
9 | kwargs = { "engine":"text-davinci-003", "temperature":0, "max_tokens":10, "stop":"\n", "echo":True, "logprobs":1, }
10 |
11 | # Load problems
12 | df = pd.read_excel (r'./UCLA_VAT.xlsx', sheet_name='UCLA_VAT')
13 | # Extract data
14 | A = builtins.list(df['A'])
15 | B = builtins.list(df['B'])
16 | C = builtins.list(df['C'])
17 | D = builtins.list(df['D'])
18 | D_prime = builtins.list(df["D'"])
19 |
20 | # Initialize storage for results
21 | all_synonym_correct_pred = []
22 | all_opposite_correct_pred = []
23 | all_function_correct_pred = []
24 | all_category_correct_pred = []
25 | results_fname = './UCLA_VAT_results.npz'
26 | # Evaluate
27 | N_prob = len(A)
28 | prob_order = np.arange(N_prob)
29 | np.random.shuffle(prob_order)
30 | for p in range(N_prob):
31 | print(str(p+1) + ' of ' + str(N_prob) + '...')
32 | if p == 0:
33 | prompt = A[prob_order[p]] + ' : ' + B[prob_order[p]] + ' :: ' + C[prob_order[p]] + ' : '
34 | else:
35 | prompt = context + '\n\n' + A[prob_order[p]] + ' : ' + B[prob_order[p]] + ' :: ' + C[prob_order[p]] + ' : '
36 | # Correct answer
37 | d_prompt = prompt + D[prob_order[p]]
38 | response = []
39 | while len(response) == 0:
40 | try:
41 | response = openai.Completion.create(prompt=d_prompt, **kwargs)
42 | 		except Exception:
43 | print('trying again...')
44 | time.sleep(5)
45 | first_token_ind = np.where(np.array(response['choices'][0]['logprobs']['text_offset']) <= len(prompt))[0][-1]
46 | 	if len(d_prompt) > np.array(response['choices'][0]['logprobs']['text_offset'])[0]:  # NOTE(review): with echo=True, text_offset[0] == 0, so this branch is always taken; was text_offset[-1] intended?
47 | d_avg_logprob = np.mean(response['choices'][0]['logprobs']['token_logprobs'][first_token_ind:])
48 | else:
49 | last_token_ind = np.where(np.array(response['choices'][0]['logprobs']['text_offset']) == len(d_prompt))[0][0]
50 | d_avg_logprob = np.mean(response['choices'][0]['logprobs']['token_logprobs'][first_token_ind:last_token_ind])
51 | # Foil
52 | d_prime_prompt = prompt + D_prime[prob_order[p]]
53 | response = []
54 | while len(response) == 0:
55 | try:
56 | response = openai.Completion.create(prompt=d_prime_prompt, **kwargs)
57 | 		except Exception:
58 | print('trying again...')
59 | time.sleep(5)
60 | first_token_ind = np.where(np.array(response['choices'][0]['logprobs']['text_offset']) <= len(prompt))[0][-1]
61 | if len(d_prime_prompt) > np.array(response['choices'][0]['logprobs']['text_offset'])[0]:
62 | d_prime_avg_logprob = np.mean(response['choices'][0]['logprobs']['token_logprobs'][first_token_ind:])
63 | else:
64 | last_token_ind = np.where(np.array(response['choices'][0]['logprobs']['text_offset']) == len(d_prime_prompt))[0][0]
65 | d_prime_avg_logprob = np.mean(response['choices'][0]['logprobs']['token_logprobs'][first_token_ind:last_token_ind])
66 | # Correct
67 | correct_pred = d_avg_logprob > d_prime_avg_logprob
68 | if prob_order[p] < 20:
69 | all_synonym_correct_pred.append(correct_pred)
70 | elif prob_order[p] >= 20 and prob_order[p] < 40:
71 | all_opposite_correct_pred.append(correct_pred)
72 | elif prob_order[p] >= 40 and prob_order[p] < 60:
73 | all_function_correct_pred.append(correct_pred)
74 | elif prob_order[p] >= 60:
75 | all_category_correct_pred.append(correct_pred)
76 | # Add problem to context
77 | if correct_pred:
78 | context = d_prompt
79 | else:
80 | context = d_prime_prompt
81 | # Save results
82 | 	np.savez(results_fname, synonym=all_synonym_correct_pred, opposite=all_opposite_correct_pred, function=all_function_correct_pred, category=all_category_correct_pred, context=context, prob_order=prob_order)
83 |
--------------------------------------------------------------------------------
/UCLA_VAT/analyze_UCLA_VAT.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from statsmodels.stats.proportion import proportion_confint
4 | from scipy.stats import sem
5 | import pandas as pd
6 | import builtins
7 |
8 | def hide_top_right(ax):
9 | ax.spines['right'].set_visible(False)
10 | ax.spines['top'].set_visible(False)
11 | ax.yaxis.set_ticks_position('left')
12 | ax.xaxis.set_ticks_position('bottom')
13 |
14 | def plot_ind_data(ax, x_points, data, ind_bar_width):
15 | max_count = 30
16 | point_unit = ind_bar_width / max_count
17 | # Plot
18 | for i in range(len(x_points)):
19 | unique_vals = np.unique(data[i])
20 | for v in unique_vals:
21 | count = (data[i]==v).sum()
22 | span = count * point_unit
23 | x_min = x_points[i] - (span/2)
24 | x_max = x_points[i] + (span/2)
25 | x_vals = np.linspace(x_min,x_max,count)
26 | if count == 1:
27 | x_vals = np.mean([x_min,x_max])
28 | if v == 0:
29 | y_vals = np.ones(count) * 0.005
30 | else:
31 | y_vals = np.ones(count) * v
32 | plt.scatter(x_vals, y_vals, color='black', s=0.4, clip_on=False, marker='_')
33 | return ax
34 |
35 | # Load data
36 | all_data = np.load('./UCLA_VAT_results.npz', allow_pickle=True)
37 |
38 | # Calculate accuracy and confidence intervals
39 | conditions = ['category', 'function', 'opposite', 'synonym']
40 | N_conditions = len(conditions)
41 | all_acc = []
42 | all_correct_pred = []
43 | all_ci_lower = []
44 | all_ci_upper = []
45 | for c in range(N_conditions):
46 | correct_pred = all_data[conditions[c]]
47 | all_correct_pred.append(correct_pred)
48 | acc = correct_pred.mean()
49 | all_acc.append(acc)
50 | ci_lower, ci_upper = proportion_confint(correct_pred.sum(), correct_pred.shape[0])
51 | all_ci_lower.append(ci_lower)
52 | all_ci_upper.append(ci_upper)
53 | # Combine conditions
54 | all_acc = np.array(all_acc)
55 | all_ci_lower = np.array(all_ci_lower)
56 | all_ci_upper = np.array(all_ci_upper)
57 | all_lower_err = all_acc - all_ci_lower
58 | all_upper_err = all_ci_upper - all_acc
59 | all_err = np.array([all_lower_err, all_upper_err])
60 | # Overall accuracy
61 | all_correct_pred = np.concatenate(all_correct_pred)
62 | overall_gpt3_acc = all_correct_pred.astype(float).mean()
63 | overall_gpt3_ci_lower, overall_gpt3_ci_upper = proportion_confint(all_correct_pred.sum(), all_correct_pred.shape[0])
64 | overall_gpt3_err = [overall_gpt3_acc - overall_gpt3_ci_lower, overall_gpt3_ci_upper - overall_gpt3_acc]
65 |
66 | # Human data
67 | df = pd.read_excel (r'./UCLA_VAT_ind_subj_data.xlsx', sheet_name='ind_subj')
68 | category_ind_subj_acc = np.array(builtins.list(df['category'])[1:]) / 100.
69 | function_ind_subj_acc = np.array(builtins.list(df['function'])[1:]) / 100.
70 | opposite_ind_subj_acc = np.array(builtins.list(df['opposite'])[1:]) / 100.
71 | synonym_ind_subj_acc = np.array(builtins.list(df['synonym'])[1:]) / 100.
72 | human_ind_subj = np.array([category_ind_subj_acc, function_ind_subj_acc, opposite_ind_subj_acc, synonym_ind_subj_acc])
73 | human_acc = human_ind_subj.mean(1)
74 | human_err = sem(human_ind_subj,1)
75 |
76 | # Plot
77 | bar_width = 0.8
78 | ind_bar_width = bar_width / 2
79 | x_points = np.arange(N_conditions)
80 | gpt3_color = 'darkslateblue'
81 | human_color = 'powderblue'
82 | plot_fontsize = 14
83 | title_fontsize = 16
84 | ax = plt.subplot(111)
85 | plt.bar(x_points - (ind_bar_width/2), all_acc, yerr=all_err, color=gpt3_color, edgecolor='black', width=ind_bar_width, ecolor='gray')
86 | plt.bar(x_points + (ind_bar_width/2), human_acc, yerr=human_err, color=human_color, edgecolor='black', width=ind_bar_width)
87 | plt.ylim([0,1])
88 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize)
89 | plt.ylabel('Accuracy', fontsize=plot_fontsize)
90 | plt.xticks(x_points, ['Categorical', 'Function', 'Antonym', 'Synonym'], fontsize=12)
91 | hide_top_right(ax)
92 | plt.legend(['GPT-3','Human'],fontsize=plot_fontsize,frameon=False, bbox_to_anchor=(0.9,1))
93 | plot_ind_data(ax, x_points + (ind_bar_width/2), human_ind_subj, ind_bar_width)
94 | current_xlim = ax.get_xlim()
95 | plt.plot([current_xlim[0], current_xlim[1]],[0.5,0.5],color='black',alpha=0.4)
96 | plt.xlim([current_xlim[0], current_xlim[1]])
97 | plt.title('UCLA VAT', fontsize=title_fontsize, pad=20)
98 | ax.set_aspect(4)
99 | plt.tight_layout()
100 | plt.savefig('./UCLA_VAT_results.pdf', dpi=300, bbox_inches="tight")
101 | plt.close()
102 |
--------------------------------------------------------------------------------
/story_analogies/eval_GPT3_story_analogies.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import numpy as np
3 | import pandas as pd
4 | import builtins
5 | import time
6 | # GPT-3 settings
7 | openai.api_key = "FILL_IN_API_KEY_HERE"
8 | kwargs = { "engine":"text-davinci-003", "temperature":0, "max_tokens":256, "echo":False, "logprobs":1, }
9 |
10 | # Load problems
11 | df = pd.read_excel (r'./Rattermann.xlsx', sheet_name='Rattermann')
12 | source_story = builtins.list(df['Base'])[1:19]
13 | true_analogy = builtins.list(df['True Analogy Story'])[1:19]
14 | false_analogy = builtins.list(df['False Analogy Story'])[1:19]
15 | literal_similarity = builtins.list(df['Literally similar story'])[1:19]
16 | mere_appearance = builtins.list(df['Mere-Appearance Match'])[1:19]
17 |
18 | # Initialize results
19 | all_analogy_results = []
20 | all_analogy_correct = []
21 | all_similarity_results = []
22 | all_similarity_correct = []
23 | N_source_stories = 18
24 | for s in range(N_source_stories):
25 | print('Source story ' + str(s+1) + ' of ' + str(N_source_stories) + '...')
26 | # A. True analogy B. False analogy
27 | print('True analogy vs. false analogy')
28 | print(' ')
29 | prompt = 'Consider the following story:\n\nStory 1: ' + source_story[s] + '\n\nNow consider two more stories:\n\nStory A: ' + true_analogy[s] + '\n\nStory B: ' + false_analogy[s] + '\n\n'
30 | prompt += 'Which of Story A and Story B is a better analogy to Story 1? Is the best answer Story A, Story B, or both are equally analogous?'
31 | print(prompt)
32 | response = []
33 | while len(response) == 0:
34 | try:
35 | response = openai.Completion.create(prompt=prompt, **kwargs)
36 | 		except Exception:
37 | print('trying again...')
38 | time.sleep(5)
39 | all_analogy_results.append(response['choices'][0]['text'])
40 | print('response: ' + response['choices'][0]['text'])
41 | all_analogy_correct.append('A')
42 | print('correct_answer: A')
43 | print(' ')
44 | # A. False analogy B. True analogy
45 | print('False analogy vs. true analogy')
46 | print(' ')
47 | prompt = 'Consider the following story:\n\nStory 1: ' + source_story[s] + '\n\nNow consider two more stories:\n\nStory A: ' + false_analogy[s] + '\n\nStory B: ' + true_analogy[s] + '\n\n'
48 | prompt += 'Which of Story A and Story B is a better analogy to Story 1? Is the best answer Story A, Story B, or both are equally analogous?'
49 | print(prompt)
50 | response = []
51 | while len(response) == 0:
52 | try:
53 | response = openai.Completion.create(prompt=prompt, **kwargs)
54 | except:
55 | print('trying again...')
56 | time.sleep(5)
57 | all_analogy_results.append(response['choices'][0]['text'])
58 | print('response: ' + response['choices'][0]['text'])
59 | all_analogy_correct.append('B')
60 | print('correct_answer: B')
61 | print(' ')
62 | # A. Literal similarity B. Mere appearance
63 | print('Literal similarity vs. mere appearance')
64 | print(' ')
65 | prompt = 'Consider the following story:\n\nStory 1: ' + source_story[s] + '\n\nNow consider two more stories:\n\nStory A: ' + literal_similarity[s] + '\n\nStory B: ' + mere_appearance[s] + '\n\n'
66 | prompt += 'Which of Story A and Story B is a better analogy to Story 1? Is the best answer Story A, Story B, or both are equally analogous?'
67 | print(prompt)
68 | response = []
69 | while len(response) == 0:
70 | try:
71 | response = openai.Completion.create(prompt=prompt, **kwargs)
72 | except:
73 | print('trying again...')
74 | time.sleep(5)
75 | all_similarity_results.append(response['choices'][0]['text'])
76 | print('response: ' + response['choices'][0]['text'])
77 | all_similarity_correct.append('A')
78 | print('correct_answer: A')
79 | print(' ')
80 | # A. Mere appearance B. Literal similarity
81 | print('Mere appearance vs. Literal similarity')
82 | print(' ')
83 | prompt = 'Consider the following story:\n\nStory 1: ' + source_story[s] + '\n\nNow consider two more stories:\n\nStory A: ' + mere_appearance[s] + '\n\nStory B: ' + literal_similarity[s] + '\n\n'
84 | prompt += 'Which of Story A and Story B is a better analogy to Story 1? Is the best answer Story A, Story B, or both are equally analogous?'
85 | print(prompt)
86 | response = []
87 | while len(response) == 0:
88 | try:
89 | response = openai.Completion.create(prompt=prompt, **kwargs)
90 | except:
91 | print('trying again...')
92 | time.sleep(5)
93 | all_similarity_results.append(response['choices'][0]['text'])
94 | print('response: ' + response['choices'][0]['text'])
95 | all_similarity_correct.append('B')
96 | print('correct_answer: B')
97 | print(' ')
98 | # Save results
99 | np.savez('./gpt3_rattermann_results.npz', all_analogy_results=all_analogy_results, all_analogy_correct=all_analogy_correct, all_similarity_results=all_similarity_results, all_similarity_correct=all_similarity_correct)
100 |
101 |
--------------------------------------------------------------------------------
/digit_mat/exp2_plot_GPT3_vs_human.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import os
4 |
def check_path(path):
    """Create directory `path` (including parents) if it does not exist.

    Uses os.makedirs with exist_ok=True, which avoids the check-then-create
    race of the exists()/mkdir() pattern and also handles nested paths.
    """
    os.makedirs(path, exist_ok=True)
8 |
def hide_top_right(ax):
    """Strip the top/right spines of `ax` and keep ticks on the left/bottom."""
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
14 |
# Digit matrices
# 1-rule, 2-rule, 3-rule, 4-rule, and 5-rule problems
# GPT-3: progressive
# Humans: progressive

# Load GPT-3 data
# Each npz holds 'acc' (mean accuracy per condition) and 'err', a (2, N)
# array of lower/upper confidence-interval offsets (see analyze_gpt3_exp2.py)
gpt3_gen_acc = np.load('./exp2_GPT3_data/all_probcat_gen_acc.npz')
gpt3_MC_acc = np.load('./exp2_GPT3_data/all_probcat_MC_acc.npz')
gpt3_gen_acc_mn = gpt3_gen_acc['acc']
gpt3_gen_acc_err = gpt3_gen_acc['err']
gpt3_MC_acc_mn = gpt3_MC_acc['acc']
gpt3_MC_acc_err = gpt3_MC_acc['err']

# Load human data (same 'acc'/'err' layout as the GPT-3 files)
human_gen_acc = np.load('./exp2_behavioral_data/probcat_gen_acc_behavior.npz')
human_MC_acc = np.load('./exp2_behavioral_data/probcat_MC_acc_behavior.npz')
human_gen_acc_mn = human_gen_acc['acc']
human_gen_acc_err = human_gen_acc['err']
human_MC_acc_mn = human_MC_acc['acc']
human_MC_acc_err = human_MC_acc['err']

# Plot settings
N_conds = gpt3_MC_acc_mn.shape[0]
total_bar_width = 0.8
x_points = np.arange(N_conds)
gpt3_color = 'darkslateblue'
human_color = 'powderblue'
plot_fontsize = 14
title_fontsize = 16

# Directory
plot_dir = './exp2_GPT3_vs_human/'
check_path(plot_dir)

# Plot - generative
# GPT-3 and human bars are offset by +/-0.2 (= ind_bar_width / 2) around
# each condition's x position so the pair is centered on the tick
ind_bar_width = total_bar_width / 2
ax = plt.subplot(111)
plt.bar(x_points - 0.2, gpt3_gen_acc_mn, yerr=gpt3_gen_acc_err, color=gpt3_color, edgecolor='black', width=ind_bar_width)
plt.bar(x_points + 0.2, human_gen_acc_mn, yerr=human_gen_acc_err, color=human_color, edgecolor='black', width=ind_bar_width)
plt.ylim([0,1])
plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize)
plt.ylabel('Generative accuracy', fontsize=plot_fontsize)
plt.xticks(x_points, ['1-rule', '2-rule', '3-rule', '4-rule', '5-rule'], fontsize=plot_fontsize)
plt.xlabel('Problem type', fontsize=plot_fontsize)
hide_top_right(ax)
plt.legend(['GPT-3','Human'],fontsize=plot_fontsize,frameon=False,bbox_to_anchor=(0.95,1))
results_fname = plot_dir + 'gen_gpt3_vs_human.png'
ax.set_aspect(3)
plt.tight_layout()
plt.savefig(results_fname, dpi=300, bbox_inches="tight")
plt.close()

# Plot - multiple-choice (same layout as the generative plot)
ax = plt.subplot(111)
plt.bar(x_points - 0.2, gpt3_MC_acc_mn, yerr=gpt3_MC_acc_err, color=gpt3_color, edgecolor='black', width=ind_bar_width)
plt.bar(x_points + 0.2, human_MC_acc_mn, yerr=human_MC_acc_err, color=human_color, edgecolor='black', width=ind_bar_width)
plt.ylim([0,1])
plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize)
plt.ylabel('Multiple choice accuracy', fontsize=plot_fontsize)
plt.xticks(x_points, ['1-rule', '2-rule', '3-rule', '4-rule', '5-rule'], fontsize=plot_fontsize)
plt.xlabel('Problem type', fontsize=plot_fontsize)
hide_top_right(ax)
plt.legend(['GPT-3','Human'],fontsize=plot_fontsize,frameon=False,bbox_to_anchor=(0.95,1))
results_fname = plot_dir + 'MC_gpt3_vs_human.png'
ax.set_aspect(3)
plt.tight_layout()
plt.savefig(results_fname, dpi=300, bbox_inches="tight")
plt.close()

## Rule type analysis for one-rule problems

# Load GPT-3 one-rule data
gpt3_gen_acc_onerule = np.load('./exp2_GPT3_data/all_onerule_gen_acc.npz')
gpt3_gen_acc_onerule_mn = gpt3_gen_acc_onerule['acc']
gpt3_gen_acc_onerule_err = gpt3_gen_acc_onerule['err']

# Load human one-rule data
human_gen_acc_onerule = np.load('./exp2_behavioral_data/probcat_gen_acc_behavior_onerule.npz')
human_gen_acc_onerule_mn = human_gen_acc_onerule['acc']
human_gen_acc_onerule_err = human_gen_acc_onerule['err']

# Plot settings (conditions are now the three one-rule rule types)
N_conds = gpt3_gen_acc_onerule_mn.shape[0]
x_points = np.arange(N_conds)
ind_bar_width = total_bar_width / 2

# Plot - generative
ax = plt.subplot(111)
plt.bar(x_points - 0.2, gpt3_gen_acc_onerule_mn, yerr=gpt3_gen_acc_onerule_err, color=gpt3_color, edgecolor='black', width=ind_bar_width)
plt.bar(x_points + 0.2, human_gen_acc_onerule_mn, yerr=human_gen_acc_onerule_err, color=human_color, edgecolor='black', width=ind_bar_width)
plt.ylim([0,1])
plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize)
plt.ylabel('Generative accuracy', fontsize=plot_fontsize)
plt.xticks(x_points, ['Constant', 'Distribution', 'Progression'], fontsize=plot_fontsize)
plt.xlabel('Rule type', fontsize=plot_fontsize)
hide_top_right(ax)
plt.legend(['GPT-3','Human'],fontsize=plot_fontsize,frameon=False,bbox_to_anchor=(0.65,1))
plt.title('One-rule problems', fontsize=title_fontsize)
results_fname = plot_dir + 'onerule_gen_gpt3_vs_human.png'
ax.set_aspect(3)
plt.tight_layout()
plt.savefig(results_fname, dpi=300, bbox_inches="tight")
plt.close()
118 |
119 |
120 |
--------------------------------------------------------------------------------
/digit_mat/analyze_gpt3_exp2.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import builtins
4 | from statsmodels.stats.proportion import proportion_confint
5 | import os
6 |
def check_path(path):
    """Create directory `path` (including parents) if it does not exist.

    Uses os.makedirs with exist_ok=True, which avoids the check-then-create
    race of the exists()/mkdir() pattern and also handles nested paths.
    """
    os.makedirs(path, exist_ok=True)
10 |
def hide_top_right(ax):
    """Strip the top/right spines of `ax` and keep ticks on the left/bottom."""
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
16 |
# Load data
# Results are stored as 0-d object arrays holding dicts keyed by problem
# type; .item() recovers the dict (hence allow_pickle=True)
all_data = np.load('./gpt_matprob_results_1thru5.npz', allow_pickle=True)
MC_correct_pred = all_data['all_MC_correct_pred']
gen_correct_pred = all_data['all_gen_correct_pred']
all_prob_types = builtins.list(MC_correct_pred.item().keys())

## Analyze by major problem type
# Bucket per-trial correctness by number of rules, separately for the
# generative and multiple-choice (MC) tasks
correct_pred = {'combined_gen': [],
                'combined_MC': [],
                'one_rule_gen': [],
                'one_rule_MC': [],
                'two_rule_gen': [],
                'two_rule_MC': [],
                'three_rule_gen': [],
                'three_rule_MC': [],
                'four_rule_gen': [],
                'four_rule_MC': [],
                'five_rule_gen': [],
                'five_rule_MC': []}
for prob_type in all_prob_types:
    correct_pred['combined_gen'].append(gen_correct_pred.item()[prob_type])
    correct_pred['combined_MC'].append(MC_correct_pred.item()[prob_type])
    # One-rule problem types are named after their rule (constant / dist3 /
    # prog); multi-rule types carry a 'two_rule'/'three_rule'/... substring
    if 'constant' in prob_type or 'dist3' in prob_type or 'prog' in prob_type:
        correct_pred['one_rule_gen'].append(gen_correct_pred.item()[prob_type])
        correct_pred['one_rule_MC'].append(MC_correct_pred.item()[prob_type])
    elif 'two_rule' in prob_type:
        correct_pred['two_rule_gen'].append(gen_correct_pred.item()[prob_type])
        correct_pred['two_rule_MC'].append(MC_correct_pred.item()[prob_type])
    elif 'three_rule' in prob_type:
        correct_pred['three_rule_gen'].append(gen_correct_pred.item()[prob_type])
        correct_pred['three_rule_MC'].append(MC_correct_pred.item()[prob_type])
    elif 'four_rule' in prob_type:
        correct_pred['four_rule_gen'].append(gen_correct_pred.item()[prob_type])
        correct_pred['four_rule_MC'].append(MC_correct_pred.item()[prob_type])
    elif 'five_rule' in prob_type:
        correct_pred['five_rule_gen'].append(gen_correct_pred.item()[prob_type])
        correct_pred['five_rule_MC'].append(MC_correct_pred.item()[prob_type])
# Convert to arrays (flatten the per-problem-type lists into one trial vector)
for key in correct_pred.keys():
    correct_pred[key] = np.concatenate(correct_pred[key])
# Calculate accuracy and confidence intervals
all_acc = {}
all_ci_lower = {}
all_ci_upper = {}
for key in correct_pred.keys():
    all_acc[key] = correct_pred[key].mean()
    # proportion_confint defaults to a 95% normal-approximation interval
    all_ci_lower[key], all_ci_upper[key] = proportion_confint(correct_pred[key].sum(), correct_pred[key].shape[0])

# Directory for saving results
results_dir = './exp2_GPT3_data/'
check_path(results_dir)

# Save results
# Generative
all_gen_acc = np.array([all_acc['one_rule_gen'], all_acc['two_rule_gen'], all_acc['three_rule_gen'], all_acc['four_rule_gen'], all_acc['five_rule_gen']])
all_gen_lower_ci = np.array([all_ci_lower['one_rule_gen'], all_ci_lower['two_rule_gen'], all_ci_lower['three_rule_gen'], all_ci_lower['four_rule_gen'], all_ci_lower['five_rule_gen']])
all_gen_upper_ci = np.array([all_ci_upper['one_rule_gen'], all_ci_upper['two_rule_gen'], all_ci_upper['three_rule_gen'], all_ci_upper['four_rule_gen'], all_ci_upper['five_rule_gen']])
# 'err' holds (lower, upper) offsets from the mean, the shape matplotlib's
# yerr expects for asymmetric error bars
all_gen_lower_err = all_gen_acc - all_gen_lower_ci
all_gen_upper_err = all_gen_upper_ci - all_gen_acc
all_gen_err = np.array([all_gen_lower_err, all_gen_upper_err])
np.savez(results_dir + 'all_probcat_gen_acc.npz', acc=all_gen_acc, err=all_gen_err)
# Multiple-choice
all_MC_acc = np.array([all_acc['one_rule_MC'], all_acc['two_rule_MC'], all_acc['three_rule_MC'], all_acc['four_rule_MC'], all_acc['five_rule_MC']])
all_MC_lower_ci = np.array([all_ci_lower['one_rule_MC'], all_ci_lower['two_rule_MC'], all_ci_lower['three_rule_MC'], all_ci_lower['four_rule_MC'], all_ci_lower['five_rule_MC']])
all_MC_upper_ci = np.array([all_ci_upper['one_rule_MC'], all_ci_upper['two_rule_MC'], all_ci_upper['three_rule_MC'], all_ci_upper['four_rule_MC'], all_ci_upper['five_rule_MC']])
all_MC_lower_err = all_MC_acc - all_MC_lower_ci
all_MC_upper_err = all_MC_upper_ci - all_MC_acc
all_MC_err = np.array([all_MC_lower_err, all_MC_upper_err])
np.savez(results_dir + 'all_probcat_MC_acc.npz', acc=all_MC_acc, err=all_MC_err)
86 |
## Three major one-rule problem types
# Same analysis as above, but splitting the one-rule problems by rule type
correct_pred = {'constant_gen': [],
                'constant_MC': [],
                'dist3_gen': [],
                'dist3_MC': [],
                'prog_gen': [],
                'prog_MC': []}
for prob_type in all_prob_types:
    if 'constant' in prob_type:
        correct_pred['constant_gen'].append(gen_correct_pred.item()[prob_type])
        correct_pred['constant_MC'].append(MC_correct_pred.item()[prob_type])
    elif 'dist3' in prob_type:
        correct_pred['dist3_gen'].append(gen_correct_pred.item()[prob_type])
        correct_pred['dist3_MC'].append(MC_correct_pred.item()[prob_type])
    elif 'prog' in prob_type:
        correct_pred['prog_gen'].append(gen_correct_pred.item()[prob_type])
        correct_pred['prog_MC'].append(MC_correct_pred.item()[prob_type])
# Convert to arrays
for key in correct_pred.keys():
    correct_pred[key] = np.concatenate(correct_pred[key])
# Calculate accuracy and confidence intervals
all_acc = {}
all_ci_lower = {}
all_ci_upper = {}
for key in correct_pred.keys():
    all_acc[key] = correct_pred[key].mean()
    # proportion_confint defaults to a 95% normal-approximation interval
    all_ci_lower[key], all_ci_upper[key] = proportion_confint(correct_pred[key].sum(), correct_pred[key].shape[0])

# Save results
# Generative
all_gen_acc = np.array([all_acc['constant_gen'], all_acc['dist3_gen'], all_acc['prog_gen']])
all_gen_lower_ci = np.array([all_ci_lower['constant_gen'], all_ci_lower['dist3_gen'], all_ci_lower['prog_gen']])
all_gen_upper_ci = np.array([all_ci_upper['constant_gen'], all_ci_upper['dist3_gen'], all_ci_upper['prog_gen']])
# (lower, upper) offsets from the mean, for matplotlib's yerr
all_gen_lower_err = all_gen_acc - all_gen_lower_ci
all_gen_upper_err = all_gen_upper_ci - all_gen_acc
all_gen_err = np.array([all_gen_lower_err, all_gen_upper_err])
np.savez(results_dir + 'all_onerule_gen_acc.npz', acc=all_gen_acc, err=all_gen_err)
# Multiple-choice
all_MC_acc = np.array([all_acc['constant_MC'], all_acc['dist3_MC'], all_acc['prog_MC']])
all_MC_lower_ci = np.array([all_ci_lower['constant_MC'], all_ci_lower['dist3_MC'], all_ci_lower['prog_MC']])
all_MC_upper_ci = np.array([all_ci_upper['constant_MC'], all_ci_upper['dist3_MC'], all_ci_upper['prog_MC']])
all_MC_lower_err = all_MC_acc - all_MC_lower_ci
all_MC_upper_err = all_MC_upper_ci - all_MC_acc
all_MC_err = np.array([all_MC_lower_err, all_MC_upper_err])
np.savez(results_dir + 'all_onerule_MC_acc.npz', acc=all_MC_acc, err=all_MC_err)
132 |
133 |
--------------------------------------------------------------------------------
/digit_mat/combine_problems_1thru5.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 |
# Save problems as json and numpy files
def save_prob(all_prob, all_answer_choices, all_correct_ind, prob_type_name, all_problems_np, all_problems_js, perm_invariant=False):
    """Add one problem type to the numpy dict and the json/javascript dict.

    all_prob: problems indexed [problem][row][col][element]; an element
        value of -1 renders as a blank slot.
    all_answer_choices: 8 answer options per problem.
    all_correct_ind: index of the correct option for each problem.
    prob_type_name: key under which this problem type is stored.
    all_problems_np / all_problems_js: dicts updated in place and returned.
    perm_invariant: flag stored alongside the raw problem data.
    """
    # Raw arrays go into the numpy dict unchanged
    all_problems_np[prob_type_name] = {'prob': all_prob, 'answer_choices': all_answer_choices, 'correct_ind': all_correct_ind, 'perm_invariant': perm_invariant}
    # Convert to strings and save as json
    all_data = []
    for p in range(all_prob.shape[0]):
        # Render the 3x3 matrix as text, one matrix row per line
        prompt = ''
        for r in range(3):
            for c in range(3):
                prompt += '['
                if r == 2 and c == 2:
                    # Bottom-right cell is the blank to be filled in; its
                    # width is taken from the corresponding cell of row 1
                    # (assumes all rows share cell sizes -- TODO confirm)
                    n_elem = len(all_prob[p][1][c])
                    for i in range(n_elem):
                        prompt += '  '
                        if i < n_elem - 1:
                            prompt += ' '
                else:
                    n_elem = len(all_prob[p][r][c])
                    for i in range(n_elem):
                        if all_prob[p][r][c][i] == -1:
                            # -1 encodes an empty slot within a cell
                            prompt += '  '
                        else:
                            prompt += str(all_prob[p][r][c][i])
                        if i < n_elem - 1:
                            prompt += ' '
                prompt += ']'
                if c < 2:
                    prompt += '   '
                if r < 2 and c == 2:
                    # End of a matrix row (fixed: was a broken string literal)
                    prompt += '\n'
        # Convert choices to strings
        options = []
        for a in range(8):
            option = '['
            n_elem = len(all_answer_choices[p][a])
            for i in range(n_elem):
                option += str(all_answer_choices[p][a][i])
                if i < n_elem - 1:
                    option += ' '
            # An empty choice renders as a blank cell
            if n_elem == 0:
                option += '  '
            option += ']'
            options.append(option)
        # Add to dataset (int() so the index is json-serializable)
        all_data.append({'prompt': prompt, 'options': options, 'correct': int(all_correct_ind[p]), 'prob_ind': p})
    # Add to javascript data
    all_problems_js[prob_type_name] = all_data
    return all_problems_np, all_problems_js
52 |
# Number of problems per category
max_N_prob_per_cat = 10
N_instances_per_prob = 100
# All categories
# NOTE(review): all_cat is not referenced below -- kept for reference
all_cat = ['one_rule', 'two_rule', 'three_rule', 'four_rule', 'five_rule']

# Load 1-3-rule problems
# (0-d object arrays holding dicts; .item() recovers the dict)
one_3_rule_prob_fname = './all_problems.npz'
one_3_rule_prob = np.load(one_3_rule_prob_fname, allow_pickle=True)['all_problems']

# Load 4-5-rule problems
four_5_rule_prob_fname = './all_4_5_rule_problems.npz'
four_5_rule_prob = np.load(four_5_rule_prob_fname, allow_pickle=True)['all_problems']

# Subsample problems and save as js script, also as numpy file
all_problems_np = {}
all_problems_js = {}
# 1-rule problems: keep all six named problem types as-is
one_rule_prob_names = ['row_constant', 'col_constant', 'dist3_diag1', 'dist3_diag2', 'prog_size1', 'prog_size2']
for prob_name in one_rule_prob_names:
    all_problems_np, all_problems_js = save_prob(one_3_rule_prob.item()[prob_name]['prob'], one_3_rule_prob.item()[prob_name]['answer_choices'], one_3_rule_prob.item()[prob_name]['correct_ind'], prob_name, all_problems_np, all_problems_js)
# 2-rule problems
for prob_name in one_3_rule_prob.item().keys():
    if 'two_rule' in prob_name:
        all_problems_np, all_problems_js = save_prob(one_3_rule_prob.item()[prob_name]['prob'], one_3_rule_prob.item()[prob_name]['answer_choices'], one_3_rule_prob.item()[prob_name]['correct_ind'], prob_name, all_problems_np, all_problems_js)
# 3-rule problems
for prob_name in one_3_rule_prob.item().keys():
    if 'three_rule' in prob_name:
        all_problems_np, all_problems_js = save_prob(one_3_rule_prob.item()[prob_name]['prob'], one_3_rule_prob.item()[prob_name]['answer_choices'], one_3_rule_prob.item()[prob_name]['correct_ind'], prob_name, all_problems_np, all_problems_js)
# 4-rule problems: build max_N_prob_per_cat mixed problem sets by sampling
# N_instances_per_prob (type, instance) pairs without replacement across
# all 4-rule problem types
all_four_rule_prob_names = []
for prob_name in four_5_rule_prob.item().keys():
    if 'four_rule' in prob_name:
        all_four_rule_prob_names.append(prob_name)
all_sampled_probs = []
for prob in range(max_N_prob_per_cat):
    # Sample problem instances
    prob_instances = []
    answer_choices = []
    correct_ind = []
    for i in range(N_instances_per_prob):
        # Rejection-sample until a (type, instance) pair not already used
        # anywhere in this category is found
        duplicate = True
        while duplicate:
            prob_type = all_four_rule_prob_names[np.floor(np.random.rand() * len(all_four_rule_prob_names)).astype(int)]
            prob_ind = np.floor(np.random.rand() * N_instances_per_prob).astype(int)
            combined_name = prob_type + '_' + str(prob_ind)
            if np.logical_not(np.any(combined_name == np.array(all_sampled_probs))):
                duplicate = False
                all_sampled_probs.append(combined_name)
                prob_instances.append(four_5_rule_prob.item()[prob_type]['prob'][prob_ind])
                answer_choices.append(four_5_rule_prob.item()[prob_type]['answer_choices'][prob_ind])
                correct_ind.append(four_5_rule_prob.item()[prob_type]['correct_ind'][prob_ind])
    prob_instances = np.array(prob_instances)
    answer_choices = np.array(answer_choices)
    correct_ind = np.array(correct_ind)
    all_problems_np, all_problems_js = save_prob(prob_instances, answer_choices, correct_ind, 'four_rule_prob' + str(prob), all_problems_np, all_problems_js)
# 5-rule problems (same sampling scheme as the 4-rule problems)
all_five_rule_prob_names = []
for prob_name in four_5_rule_prob.item().keys():
    if 'five_rule' in prob_name:
        all_five_rule_prob_names.append(prob_name)
all_sampled_probs = []
for prob in range(max_N_prob_per_cat):
    # Sample problem instances
    prob_instances = []
    answer_choices = []
    correct_ind = []
    for i in range(N_instances_per_prob):
        duplicate = True
        while duplicate:
            prob_type = all_five_rule_prob_names[np.floor(np.random.rand() * len(all_five_rule_prob_names)).astype(int)]
            prob_ind = np.floor(np.random.rand() * N_instances_per_prob).astype(int)
            combined_name = prob_type + '_' + str(prob_ind)
            if np.logical_not(np.any(combined_name == np.array(all_sampled_probs))):
                duplicate = False
                all_sampled_probs.append(combined_name)
                prob_instances.append(four_5_rule_prob.item()[prob_type]['prob'][prob_ind])
                answer_choices.append(four_5_rule_prob.item()[prob_type]['answer_choices'][prob_ind])
                correct_ind.append(four_5_rule_prob.item()[prob_type]['correct_ind'][prob_ind])
    prob_instances = np.array(prob_instances)
    answer_choices = np.array(answer_choices)
    correct_ind = np.array(correct_ind)
    all_problems_np, all_problems_js = save_prob(prob_instances, answer_choices, correct_ind, 'five_rule_prob' + str(prob), all_problems_np, all_problems_js)
# Save numpy file
np_fname = './all_problems_1thru5.npz'
np.savez(np_fname, all_problems=all_problems_np)
# Convert to json string
json_string = json.dumps(all_problems_js)
# Write to js script (data embedded as a global for the browser experiment)
js_fname = './all_problems_1thru5.js'
js_fid = open(js_fname, 'w')
js_fid.write('var all_problems = ' + json_string)
js_fid.close()
146 |
147 |
--------------------------------------------------------------------------------
/digit_mat/exp1_vs_exp2_create_stats_dset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import csv
3 | import builtins
4 |
# Load human behavioral data for experiment 1
human_data = np.load('./exp1_behavioral_data/ind_subj_results.npz')
all_human_gen_correct_pred = human_data['all_subj_gen_correct_pred']
all_human_MC_correct_pred = human_data['all_subj_MC_correct_pred']
N_subj_exp1 = all_human_gen_correct_pred.shape[0]
N_prob = all_human_gen_correct_pred.shape[1]
# Create dataset: one row per subject x problem observation
subjID = []
gen_correct_pred = []
MC_correct_pred = []
human_vs_gpt = []
prob_type = []
exp1_vs_exp2 = []
onerule_subtype = []
for s in range(N_subj_exp1):
    for p in range(N_prob):
        # Keep only problems 0-21 (1-3-rule problems shared across
        # experiments); later problems are excluded from this dataset
        if p < 22:
            # Subject ID
            subjID.append(s)
            # Correct prediction (generative and multiple-choice)
            gen_correct_pred.append(all_human_gen_correct_pred[s,p])
            MC_correct_pred.append(all_human_MC_correct_pred[s,p])
            # Human vs. GPT-3 (0 = human)
            human_vs_gpt.append(0)
            # Experiment 1 vs. experiment 2 (0 = experiment 1)
            exp1_vs_exp2.append(0)
            # Problem-type specific variables
            # One-rule problems
            if p < 6:
                # Problem type
                prob_type.append(0)
                # One-rule subtypes (0 = constant, 1 = distribution-of-3,
                # 2 = progression)
                if p < 2:
                    onerule_subtype.append(0)
                elif p == 2 or p == 3:
                    onerule_subtype.append(1)
                elif p == 4 or p == 5:
                    onerule_subtype.append(2)
            # Two-rule problems
            elif p >= 6 and p < 12:
                # Problem type
                prob_type.append(1)
                # Dummy code one-rule subtypes
                onerule_subtype.append(-1)
            # Three-rule problems
            elif p >= 12 and p < 22:
                # Problem type
                prob_type.append(2)
                # Dummy code one-rule subtypes
                onerule_subtype.append(-1)
# Load GPT-3 data for experiment 1
# (0-d object arrays holding dicts keyed by problem type name)
gpt3_data = np.load('./gpt_matprob_results.npz', allow_pickle=True)
prob_type_names = builtins.list(gpt3_data['all_gen_correct_pred'].item().keys())
# Add to dataset
for p in range(N_prob):
    # Problem type name
    prob_type_name = prob_type_names[p]
    # Loop through all trials for this problem type
    N_trials = len(gpt3_data['all_gen_correct_pred'].item()[prob_type_name])
    for t in range(N_trials):
        if p < 22:
            # Subject ID: GPT-3 is coded as a single extra "subject" with
            # ID N_subj_exp1 (one past the last exp-1 human subject)
            subjID.append(N_subj_exp1)
            # Correct prediction
            gen_correct_pred.append(int(gpt3_data['all_gen_correct_pred'].item()[prob_type_name][t]))
            MC_correct_pred.append(int(gpt3_data['all_MC_correct_pred'].item()[prob_type_name][t]))
            # Human vs. GPT-3 (1 = GPT-3)
            human_vs_gpt.append(1)
            # Experiment 1 vs. experiment 2 (0 = experiment 1)
            exp1_vs_exp2.append(0)
            # Problem-type specific variables (same coding as the human loop)
            # One-rule problems
            if p < 6:
                # Problem type
                prob_type.append(0)
                # One-rule subtypes
                if p < 2:
                    onerule_subtype.append(0)
                elif p == 2 or p == 3:
                    onerule_subtype.append(1)
                elif p == 4 or p == 5:
                    onerule_subtype.append(2)
            # Two-rule problems
            elif p >= 6 and p < 12:
                # Problem type
                prob_type.append(1)
                # Dummy code one-rule subtypes
                onerule_subtype.append(-1)
            # Three-rule problems
            elif p >= 12 and p < 22:
                # Problem type
                prob_type.append(2)
                # Dummy code one-rule subtypes
                onerule_subtype.append(-1)
99 |
# Load human behavioral data for experiment 2
human_data = np.load('./exp2_behavioral_data/ind_subj_results.npz')
all_human_gen_correct_pred = human_data['all_subj_gen_correct_pred']
all_human_MC_correct_pred = human_data['all_subj_MC_correct_pred']
N_subj_exp2 = all_human_gen_correct_pred.shape[0]
N_prob = all_human_gen_correct_pred.shape[1]
for s in range(N_subj_exp2):
    for p in range(N_prob):
        if p < 22:
            # Subject ID: exp-2 human IDs continue after the exp-1 humans
            # and the GPT-3 "subject" (N_subj_exp1)
            subjID.append(N_subj_exp1 + 1 + s)
            # Correct prediction
            gen_correct_pred.append(all_human_gen_correct_pred[s,p])
            MC_correct_pred.append(all_human_MC_correct_pred[s,p])
            # Human vs. GPT-3 (0 = human)
            human_vs_gpt.append(0)
            # Experiment 1 vs. experiment 2 (1 = experiment 2)
            exp1_vs_exp2.append(1)
            # Problem-type specific variables
            # One-rule problems
            if p < 6:
                # Problem type
                prob_type.append(0)
                # One-rule subtypes
                if p < 2:
                    onerule_subtype.append(0)
                elif p == 2 or p == 3:
                    onerule_subtype.append(1)
                elif p == 4 or p == 5:
                    onerule_subtype.append(2)
            # Two-rule problems
            elif p >= 6 and p < 12:
                # Problem type
                prob_type.append(1)
                # Dummy code one-rule subtypes
                onerule_subtype.append(-1)
            # Three-rule problems
            elif p >= 12 and p < 22:
                # Problem type
                prob_type.append(2)
                # Dummy code one-rule subtypes
                onerule_subtype.append(-1)
142 |
# Load GPT-3 data (experiment 2 run, problems 1-5 rules)
gpt3_data = np.load('./gpt_matprob_results_1thru5.npz', allow_pickle=True)
prob_type_names = builtins.list(gpt3_data['all_gen_correct_pred'].item().keys())
# Add to dataset
for p in range(N_prob):
    # Problem type name
    prob_type_name = prob_type_names[p]
    # Loop through all trials for this problem type
    N_trials = len(gpt3_data['all_gen_correct_pred'].item()[prob_type_name])
    for t in range(N_trials):
        if p < 22:
            # Subject ID
            # NOTE(review): reuses ID N_subj_exp1, the same ID as the
            # exp-1 GPT-3 run -- verify this is intended (GPT-3 treated
            # as one "subject" across experiments)
            subjID.append(N_subj_exp1)
            # Correct prediction
            gen_correct_pred.append(int(gpt3_data['all_gen_correct_pred'].item()[prob_type_name][t]))
            MC_correct_pred.append(int(gpt3_data['all_MC_correct_pred'].item()[prob_type_name][t]))
            # Human vs. GPT-3 (1 = GPT-3)
            human_vs_gpt.append(1)
            # Experiment 1 vs. experiment 2 (1 = experiment 2)
            exp1_vs_exp2.append(1)
            # Problem-type specific variables
            # One-rule problems
            if p < 6:
                # Problem type
                prob_type.append(0)
                # One-rule subtypes
                if p < 2:
                    onerule_subtype.append(0)
                elif p == 2 or p == 3:
                    onerule_subtype.append(1)
                elif p == 4 or p == 5:
                    onerule_subtype.append(2)
            # Two-rule problems
            elif p >= 6 and p < 12:
                # Problem type
                prob_type.append(1)
                # Dummy code one-rule subtypes
                onerule_subtype.append(-1)
            # Three-rule problems
            elif p >= 12 and p < 22:
                # Problem type
                prob_type.append(2)
                # Dummy code one-rule subtypes
                onerule_subtype.append(-1)
187 |
# Write the assembled dataset out as a csv file
with open('./exp1_vs_exp2_all_data.csv', 'w') as csv_file:
    csv_writer = csv.writer(csv_file)
    # Column names
    csv_writer.writerow(['subjID', 'gen_correct_pred', 'MC_correct_pred', 'human_vs_gpt', 'exp1_vs_exp2', 'prob_type', 'onerule_subtype'])
    # One row per observation; columns zipped in the same order as the header
    for row in zip(subjID, gen_correct_pred, MC_correct_pred, human_vs_gpt, exp1_vs_exp2, prob_type, onerule_subtype):
        csv_writer.writerow(list(row))
201 |
202 |
203 |
204 |
205 |
206 |
207 |
--------------------------------------------------------------------------------
/digit_mat/exp1_create_stats_dset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import csv
3 | import builtins
4 |
# Load human behavioral data
human_data = np.load('./exp1_behavioral_data/ind_subj_results.npz')
all_human_gen_correct_pred = human_data['all_subj_gen_correct_pred']
all_human_MC_correct_pred = human_data['all_subj_MC_correct_pred']
N_subj = all_human_gen_correct_pred.shape[0]
N_prob = all_human_gen_correct_pred.shape[1]
# Load data about # of unique relations in each problem
N_unique_rules_2rule_prob = np.load('./N_unique_rules_2rule_prob.npz')['N_unique_rules']
N_unique_rules_3rule_prob = np.load('./N_unique_rules_3rule_prob.npz')['N_unique_rules']
# Create dataset: one row per subject x problem observation
subjID = []
gen_correct_pred = []
MC_correct_pred = []
human_vs_gpt = []
N_unique_rules = []
prob_type = []
onerule_rule_type = []
tworule_prog_noprog = []
aligned_permuted = []
for s in range(N_subj):
    for p in range(N_prob):
        # Subject ID
        subjID.append(s)
        # Correct prediction (generative and multiple-choice)
        gen_correct_pred.append(all_human_gen_correct_pred[s,p])
        MC_correct_pred.append(all_human_MC_correct_pred[s,p])
        # Human vs. GPT-3 (0 = human)
        human_vs_gpt.append(0)
        # Problem-type specific variables
        # One-rule problems
        if p < 6:
            # Problem type
            prob_type.append(0)
            # Number of unique rules
            N_unique_rules.append(0)
            # Rule type
            # Constant rule
            if p == 0 or p == 1:
                onerule_rule_type.append(0)
            # Distribution-of-3 rule
            elif p == 2 or p == 3:
                onerule_rule_type.append(1)
            # Progression rule
            elif p == 4 or p == 5:
                onerule_rule_type.append(2)
            # Dummy code for aligned vs. permuted (logic rules only)
            aligned_permuted.append(-1)
            # Dummy code for two-rule progression vs. no progression
            tworule_prog_noprog.append(-1)
        # Two-rule problems
        elif p >= 6 and p < 12:
            # Problem type
            prob_type.append(1)
            # Dummy code for one-rule rule type
            onerule_rule_type.append(-1)
            # Dummy code for aligned vs. permuted (logic rules only)
            aligned_permuted.append(-1)
            # Number of unique rules, shifted to start at 0
            N_unique_rules.append(N_unique_rules_2rule_prob[p-6] - 1)
            # Progression rule present
            if p == 6 or p == 7 or p == 9:
                tworule_prog_noprog.append(0)
            elif p == 8 or p == 10 or p == 11:
                tworule_prog_noprog.append(1)
        # Three-rule problems
        elif p >= 12 and p < 22:
            # Problem type
            prob_type.append(2)
            # Dummy code for one-rule rule type
            onerule_rule_type.append(-1)
            # Dummy code for aligned vs. permuted (logic rules only)
            aligned_permuted.append(-1)
            # Number of unique rules, shifted to start at 0
            N_unique_rules.append(N_unique_rules_3rule_prob[p-12] - 1)
            # Dummy code for two-rule progression vs. no progression
            tworule_prog_noprog.append(-1)
        # Logic problems
        elif p >= 22:
            # Problem type
            prob_type.append(3)
            # Dummy code for one-rule rule type
            onerule_rule_type.append(-1)
            # Dummy code for two-rule progression vs. no progression
            tworule_prog_noprog.append(-1)
            # Number of unique rules
            N_unique_rules.append(0)
            # Spatially aligned elements (problems 22-26) vs. permuted
            if p < 27:
                aligned_permuted.append(0)
            else:
                aligned_permuted.append(1)
# Load GPT-3 data
gpt3_data = np.load('./gpt_matprob_results.npz', allow_pickle=True)
# NOTE(review): assumes the saved dict's key order matches the p < 6 /
# < 12 / < 22 category boundaries used below -- confirm against the npz
prob_type_names = builtins.list(gpt3_data['all_gen_correct_pred'].item().keys())
# Add to dataset: one row per GPT-3 trial, using the same coding scheme
# as the human loop above
for p in range(N_prob):
    # Problem type name
    prob_type_name = prob_type_names[p]
    # Loop through all trials for this problem type
    N_trials = len(gpt3_data['all_gen_correct_pred'].item()[prob_type_name])
    for t in range(N_trials):
        # Subject ID: GPT-3 is coded as one extra "subject" (ID == N_subj)
        subjID.append(N_subj)
        # Correct prediction (cast to int for the stats dataset)
        gen_correct_pred.append(int(gpt3_data['all_gen_correct_pred'].item()[prob_type_name][t]))
        MC_correct_pred.append(int(gpt3_data['all_MC_correct_pred'].item()[prob_type_name][t]))
        # Human vs. GPT-3 (1 = GPT-3)
        human_vs_gpt.append(1)
        # Problem-type specific variables
        # One-rule problems
        if p < 6:
            # Problem type
            prob_type.append(0)
            # Number of unique rules (dummy code for one-rule problems)
            N_unique_rules.append(0)
            # Rule type
            # Constant rule
            if p == 0 or p == 1:
                onerule_rule_type.append(0)
            # Distribution-of-3 rule
            elif p == 2 or p == 3:
                onerule_rule_type.append(1)
            # Progression rule
            elif p == 4 or p == 5:
                onerule_rule_type.append(2)
            # Dummy code for aligned vs. permuted (logic rules only)
            aligned_permuted.append(-1)
            # Dummy code for two-rule progression vs. no progression
            tworule_prog_noprog.append(-1)
        # Two-rule problems
        elif p >= 6 and p < 12:
            # Problem type
            prob_type.append(1)
            # Dummy code for one-rule rule type
            onerule_rule_type.append(-1)
            # Number of unique rules (stored count minus 1)
            N_unique_rules.append(N_unique_rules_2rule_prob[p-6] - 1)
            # Dummy code for aligned vs. permuted (logic rules only)
            aligned_permuted.append(-1)
            # Progression rule present (same 0/1 coding as the human loop)
            if p == 6 or p == 7 or p == 9:
                tworule_prog_noprog.append(0)
            elif p == 8 or p == 10 or p == 11:
                tworule_prog_noprog.append(1)
        # Three-rule problems
        elif p >= 12 and p < 22:
            # Problem type
            prob_type.append(2)
            # Dummy code for one-rule rule type
            onerule_rule_type.append(-1)
            # Number of unique rules (stored count minus 1)
            N_unique_rules.append(N_unique_rules_3rule_prob[p-12] - 1)
            # Dummy code for aligned vs. permuted (logic rules only)
            aligned_permuted.append(-1)
            # Dummy code for two-rule progression vs. no progression
            tworule_prog_noprog.append(-1)
        # Logic problems
        elif p >= 22:
            # Problem type
            prob_type.append(3)
            # Dummy code for one-rule rule type
            onerule_rule_type.append(-1)
            # Dummy code for two-rule progression vs. no progression
            tworule_prog_noprog.append(-1)
            # Number of unique rules (dummy code for logic problems)
            N_unique_rules.append(0)
            # Spatially aligned elements (0 = aligned, 1 = permuted)
            if p < 27:
                aligned_permuted.append(0)
            else:
                aligned_permuted.append(1)
176 |
# Write csv file with one row per observation.
# Column order must match the order of lists in the zip() below.
header = ['subjID', 'gen_correct_pred', 'MC_correct_pred', 'human_vs_gpt', 'N_unique_rules', 'prob_type', 'onerule_rule_type', 'aligned_permuted', 'tworule_prog_noprog']
# newline='' keeps the csv module from emitting blank lines on Windows;
# the with-block guarantees the file is closed even if a write fails.
with open('./exp1_all_data.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(header)
    # zip keeps the parallel lists in lockstep (all built to equal length above)
    for data_row in zip(subjID, gen_correct_pred, MC_correct_pred, human_vs_gpt,
                        N_unique_rules, prob_type, onerule_rule_type,
                        aligned_permuted, tworule_prog_noprog):
        writer.writerow(data_row)
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
--------------------------------------------------------------------------------
/digit_mat/eval_gpt_matprob.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import numpy as np
3 | import builtins
4 | import os
5 |
# Split word into characters
def split(word):
    """Return the characters of `word` as a list.

    Equivalent to list(word); kept as a named helper because the script
    calls it in several places.
    """
    return list(word)
9 |
# Load all problems (dict of problem types -> problem arrays, pickled)
all_prob = np.load('./all_problems.npz', allow_pickle=True)

# GPT-3 settings
# NOTE(review): placeholder key -- prefer reading it from an environment
# variable rather than hard-coding a real key here.
openai.api_key = "FILL_IN_API_KEY_HERE"
# Legacy Completions API: temperature 0 for deterministic output; echo=True
# plus logprobs=1 returns prompt tokens/offsets too, which the scoring code
# below relies on.
kwargs = { "engine":"text-davinci-003", "temperature":0, "max_tokens":10, "stop":"\n", "echo":True, "logprobs":1, }
16 |
# Enumerate the problem types present in the problem file
all_prob_types = builtins.list(all_prob['all_problems'].item().keys())
# Resume from previously saved results when available
all_data_fname = './gpt_matprob_results.npz'
data_exists = os.path.exists(all_data_fname)
if data_exists:
    all_data = np.load('./gpt_matprob_results.npz', allow_pickle=True)
# Result containers, keyed by problem type
all_gen_pred = {}
all_gen_correct_pred = {}
all_MC_pred = {}
all_MC_correct_pred = {}
all_alt_MC_correct_pred = {}
result_stores = (('all_gen_pred', all_gen_pred),
                 ('all_gen_correct_pred', all_gen_correct_pred),
                 ('all_MC_pred', all_MC_pred),
                 ('all_MC_correct_pred', all_MC_correct_pred),
                 ('all_alt_MC_correct_pred', all_alt_MC_correct_pred))
for prob_type in all_prob_types:
    if data_exists:
        # Carry over results from the earlier (partial) run
        for result_name, store in result_stores:
            store[prob_type] = all_data[result_name].item()[prob_type]
    else:
        # Start from scratch
        for result_name, store in result_stores:
            store[prob_type] = []
# Loop over all problem indices. Evaluation is resumable: results already
# present (loaded above) are skipped, and results are re-saved after every
# problem so an interrupted run can be restarted without losing work.
N_prob = 40
for prob_ind in range(N_prob):
    print(str(prob_ind + 1) + ' of ' + str(N_prob) + '...')
    # Loop over all problem types
    for p in range(len(all_prob_types)):
        # Problem type
        prob_type = all_prob_types[p]
        print('Problem type: ' + prob_type + '...')
        perm_invariant = all_prob['all_problems'].item()[prob_type]['perm_invariant']
        prob_type_N_prob = all_prob['all_problems'].item()[prob_type]['prob'].shape[0]
        # Only evaluate if this index exists for this problem type and has
        # not already been evaluated on a previous (resumed) run
        if prob_ind < prob_type_N_prob and len(all_gen_correct_pred[prob_type]) <= prob_ind:

            # Problem: 3x3 matrix of digit cells, 8 answer choices, and the
            # index of the correct choice
            prob = all_prob['all_problems'].item()[prob_type]['prob'][prob_ind]
            answer_choices = all_prob['all_problems'].item()[prob_type]['answer_choices'][prob_ind]
            correct_ind = all_prob['all_problems'].item()[prob_type]['correct_ind'][prob_ind]
            correct_answer = answer_choices[correct_ind]

            # Generate prompt: render the matrix as bracketed rows of digits
            # (-1 entries render as blanks), leaving the final bottom-right
            # cell as an open '[' for the model to complete
            prompt = ''
            for r in range(3):
                for c in range(3):
                    prompt += '['
                    if not (r == 2 and c == 2):
                        for i in range(len(prob[r][c])):
                            if prob[r][c][i] == -1:
                                prompt += ' '
                            else:
                                prompt += str(prob[r][c][i])
                            if i < len(prob[r][c]) - 1:
                                prompt += ' '
                        prompt += ']'
                        if c < 2:
                            prompt += ' '
                        else:
                            prompt += '\n'

            # Get response (echo=True means response text includes the prompt)
            response = openai.Completion.create(prompt=prompt, **kwargs)
            response_text = response['choices'][0]['text']
            # Find portion of response corresponding to prediction
            prediction = response_text[len(prompt):]
            all_gen_pred[prob_type].append(prediction)
            # Get prediction set: collect single-digit characters up to the
            # first closing bracket; any other character marks it invalid
            pred_set = []
            invalid_char = False
            closing_bracket = False  # NOTE(review): unused after init
            for i in range(len(split(prediction))):
                if prediction[i] != ' ':
                    if prediction[i].isdigit():
                        pred_set.append(int(prediction[i]))
                    elif prediction[i] == ']':
                        break
                    else:
                        invalid_char = True
                        break
            # Sort answer if problem type is permutation invariant
            if perm_invariant:
                correct_answer = np.sort(correct_answer)
                pred_set = np.sort(pred_set)
            # Determine whether prediction is correct (exact multiset match)
            correct_pred = False
            if not invalid_char and len(pred_set) == len(correct_answer):
                if np.all(pred_set == correct_answer):
                    correct_pred = True
            all_gen_correct_pred[prob_type].append(correct_pred)

            # Get score for generated response: mean token logprob of the
            # generated cell, excluding a stand-alone ']' token.
            # NOTE(review): if the completion never contains ']' this loop
            # walks past the end of the token list (IndexError)
            first_token_ind = np.where(np.array(response['choices'][0]['logprobs']['text_offset']) <= len(prompt))[0][-1]
            response_complete = False
            token_ind = first_token_ind
            gen_completion = ''
            while not response_complete:
                token = response['choices'][0]['logprobs']['tokens'][token_ind]
                gen_completion += token
                contains_closed_bracket = False
                for i in range(len(token)):
                    if token[i] == ']':
                        contains_closed_bracket = True
                if contains_closed_bracket:
                    response_complete = True
                    if token == ']':
                        last_token_ind = token_ind - 1
                    else:
                        last_token_ind = token_ind
                token_ind += 1
            gen_score = np.mean(response['choices'][0]['logprobs']['token_logprobs'][first_token_ind:last_token_ind+1])

            # Evaluate answer choices: score each of the 8 candidates by its
            # average token logprob when appended to the prompt
            all_choice_logprob = []
            for a in range(8):
                # Convert choice to string and remove ','
                choice_str = np.array(split(str(answer_choices[a])))
                choice_str = ''.join(builtins.list(choice_str[choice_str != ',']))
                # Add answer choice to prompt (choice_str[1:] drops the
                # leading '[', which the prompt already ends with)
                prompt_choice = prompt + choice_str[1:]
                # Get average log probability of response
                response = openai.Completion.create(prompt=prompt_choice, **kwargs)
                first_token_ind = np.where(np.array(response['choices'][0]['logprobs']['text_offset']) <= len(prompt))[0][-1]
                last_token_ind = np.where(np.array(response['choices'][0]['logprobs']['text_offset']) == len(prompt_choice))[0][0]
                choice_avg_logprob = np.mean(response['choices'][0]['logprobs']['token_logprobs'][first_token_ind:last_token_ind])
                all_choice_logprob.append(choice_avg_logprob)
            # Select answer: highest-scoring choice
            model_choice = np.argmax(all_choice_logprob)
            all_MC_pred[prob_type].append(model_choice)
            # Determine whether multiple choice selection is correct
            MC_correct = model_choice == correct_ind
            all_MC_correct_pred[prob_type].append(MC_correct)

            # Alternative multiple-choice evaluation: count as correct if the
            # free generation was correct; otherwise the correct choice must
            # also outscore the freely generated answer
            if correct_pred:
                alt_MC_correct = True
            else:
                if MC_correct:
                    all_choice_logprob.append(gen_score)
                    alt_model_choice = np.argmax(all_choice_logprob)
                    alt_MC_correct = alt_model_choice == correct_ind
                else:
                    alt_MC_correct = False
            all_alt_MC_correct_pred[prob_type].append(alt_MC_correct)

            # Save data after every problem (supports resuming)
            # NOTE(review): np.savez has no allow_pickle parameter -- this
            # keyword is saved as an extra array named 'allow_pickle'
            eval_fname = './gpt_matprob_results.npz'
            np.savez(eval_fname,
                all_gen_pred=all_gen_pred, all_gen_correct_pred=all_gen_correct_pred, all_MC_pred=all_MC_pred, all_MC_correct_pred=all_MC_correct_pred, all_alt_MC_correct_pred=all_alt_MC_correct_pred,
                allow_pickle=True)
176 |
--------------------------------------------------------------------------------
/letter_string/compare_behavior_gpt3.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import argparse
4 | import os
5 |
def check_path(path):
    """Create directory `path` if it does not already exist.

    Uses os.makedirs with exist_ok=True so the check-and-create is not
    racy, repeated calls are harmless, and missing intermediate
    directories are created as needed (os.mkdir would fail on those).
    """
    os.makedirs(path, exist_ok=True)
9 |
def hide_top_right(ax):
    """Strip the top and right spines from `ax`, keeping ticks only on
    the bottom and left axes."""
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
15 |
def plot_ind_data(ax, x_points, data, ind_bar_width):
    """Overlay individual data values as short horizontal tick marks.

    For each x position, equal values are spread symmetrically around the
    bar center, with horizontal extent proportional to how many entries
    share that value (scaled so `max_count` entries span the bar width).
    Zero values are drawn slightly above the axis so they stay visible.
    Returns the axis.
    """
    max_count = 24
    point_unit = ind_bar_width / max_count
    # One scatter call per (position, distinct value) group
    for idx, center in enumerate(x_points):
        for val in np.unique(data[idx]):
            n_same = (data[idx] == val).sum()
            half_span = (n_same * point_unit) / 2
            x_vals = np.linspace(center - half_span, center + half_span, n_same)
            y_level = 0.005 if val == 0 else val
            y_vals = np.ones(n_same) * y_level
            plt.scatter(x_vals, y_vals, color='black', s=0.4, clip_on=False, marker='_')
    return ax
34 |
# Settings
parser = argparse.ArgumentParser()
parser.add_argument('--sentence', action='store_true', help="Present problem in sentence format.")
parser.add_argument('--noprompt', action='store_true', help="Present problem without prompt.")
args = parser.parse_args()

# Results directories (one pair per presentation condition)
# NOTE(review): --sentence silently takes precedence if both flags are given
if args.sentence:
    results_dir = './human_vs_GPT3_sentence/'
    GPT3_results_dir = './GPT3_results_sentence/'
elif args.noprompt:
    results_dir = './human_vs_GPT3_noprompt/'
    GPT3_results_dir = './GPT3_results_noprompt/'
else:
    results_dir = './human_vs_GPT3/'
    GPT3_results_dir = './GPT3_results/'
check_path(results_dir)

# Plot settings
gpt3_color = 'darkslateblue'
human_color = 'powderblue'
plot_fontsize = 10
title_fontsize = 12
axis_label_fontsize = 12
bar_width = 0.8
# Each GPT-3/human bar pair shares one x slot; each bar gets half the width
ind_bar_width = bar_width / 2
61 |
## Zero-generalization problems, grouped by transformation type
# Load results
human_zerogen_results = np.load('./behavioral_results/zerogen_acc.npz')
human_zerogen_acc = human_zerogen_results['all_acc']
# NOTE(review): err arrays are indexed [:, rank_order] below, so they are
# presumably 2 x N (lower/upper error bounds) -- confirm against the files
human_zerogen_err = human_zerogen_results['all_err']
GPT3_zerogen_results = np.load(GPT3_results_dir + 'zerogen_acc.npz')
GPT3_zerogen_acc = GPT3_zerogen_results['all_acc']
GPT3_zerogen_err = GPT3_zerogen_results['all_err']
# Sort based on accuracy (descending human accuracy)
rank_order = np.flip(np.argsort(human_zerogen_acc))
# Plot: paired GPT-3 (left) / human (right) bars per transformation type
all_zerogen_prob_type_names = ['Successor', 'Predecessor', 'Extend\nsequence', 'Remove\nredundant\nletter', 'Fix\nalphabetic\nsequence', 'Sort']
x_points = np.arange(len(all_zerogen_prob_type_names))
ax = plt.subplot(111)
plt.bar(x_points - (ind_bar_width/2), GPT3_zerogen_acc[rank_order], yerr=GPT3_zerogen_err[:,rank_order], color=gpt3_color, edgecolor='black', width=ind_bar_width, ecolor='gray')
plt.bar(x_points + (ind_bar_width/2), human_zerogen_acc[rank_order], yerr=human_zerogen_err[:,rank_order], color=human_color, edgecolor='black', width=ind_bar_width, ecolor='gray')
plt.ylim([0,1])
plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize)
plt.ylabel('Generative accuracy', fontsize=axis_label_fontsize)
plt.xticks(x_points, np.array(all_zerogen_prob_type_names)[rank_order], fontsize=plot_fontsize)
plt.xlabel('Transformation type', fontsize=axis_label_fontsize)
plt.title('Zero-generalization problems')
plt.legend(['GPT-3', 'Human'],fontsize=12,frameon=False)
hide_top_right(ax)
plt.tight_layout()
plt.savefig(results_dir + 'zerogen_acc.pdf', dpi=300, bbox_inches="tight")
plt.close()
89 |
## One-generalization problems, grouped by generalization type
# Load results
human_onegen_results = np.load('./behavioral_results/onegen_acc.npz')
human_onegen_acc = human_onegen_results['all_acc']
human_onegen_err = human_onegen_results['all_err']
GPT3_onegen_results = np.load(GPT3_results_dir + 'onegen_acc.npz')
GPT3_onegen_acc = GPT3_onegen_results['all_acc']
GPT3_onegen_err = GPT3_onegen_results['all_err']
# Sort based on accuracy (descending human accuracy)
rank_order = np.flip(np.argsort(human_onegen_acc))
# Plot: paired GPT-3 / human bars per generalization type
all_onegen_prob_type_names = ['Larger\ninterval', 'Longer\ntarget', 'Grouping', 'Interleaved\ndistractor', 'Letter-to-\nnumber', 'Reverse\norder']
x_points = np.arange(len(all_onegen_prob_type_names))
ax = plt.subplot(111)
plt.bar(x_points - (ind_bar_width/2), GPT3_onegen_acc[rank_order], yerr=GPT3_onegen_err[:,rank_order], color=gpt3_color, edgecolor='black', width=ind_bar_width, ecolor='gray')
plt.bar(x_points + (ind_bar_width/2), human_onegen_acc[rank_order], yerr=human_onegen_err[:,rank_order], color=human_color, edgecolor='black', width=ind_bar_width, ecolor='gray')
plt.ylim([0,1])
plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize)
plt.ylabel('Generative accuracy', fontsize=axis_label_fontsize)
plt.xticks(x_points, np.array(all_onegen_prob_type_names)[rank_order], fontsize=plot_fontsize)
plt.xlabel('Generalization type', fontsize=axis_label_fontsize)
plt.title('One-generalization problems')
plt.legend(['GPT-3','Human'],fontsize=plot_fontsize,frameon=False)
hide_top_right(ax)
plt.tight_layout()
plt.savefig(results_dir + 'onegen_acc.pdf', dpi=300, bbox_inches="tight")
plt.close()
117 |
## All problems, grouped by number of generalizations
# Load results
human_all_gen_results = np.load('./behavioral_results/all_gen_acc.npz')
human_all_gen_acc_ind_results = human_all_gen_results['all_ind_results']
human_all_gen_acc = human_all_gen_results['all_acc']
# NOTE(review): unlike the other panels, human err here is indexed 1-D
# below ([rank_order], not [:,rank_order]) -- presumably symmetric error
human_all_gen_err = human_all_gen_results['all_err']
human_all_gen_std = human_all_gen_results['all_std']
GPT3_all_gen_results = np.load(GPT3_results_dir + 'all_gen_acc.npz')
GPT3_all_gen_acc = GPT3_all_gen_results['all_acc']
GPT3_all_gen_err = GPT3_all_gen_results['all_err']
# Sort based on accuracy (descending human accuracy)
# NOTE(review): bar heights are permuted by rank_order below but the x tick
# labels and the individual-subject overlay are NOT -- this is only correct
# if accuracy happens to decrease monotonically with generalizations
# (making rank_order the identity); confirm
rank_order = np.flip(np.argsort(human_all_gen_acc))
# Plot: labels are just the generalization counts 0..N-1
all_gen_prob_type_names = np.arange(len(human_all_gen_acc)).astype(str)
x_points = np.arange(len(all_gen_prob_type_names))
ax = plt.subplot(111)
plt.bar(x_points - (ind_bar_width/2), GPT3_all_gen_acc[rank_order], yerr=GPT3_all_gen_err[:,rank_order], color=gpt3_color, edgecolor='black', width=ind_bar_width, ecolor='gray')
# NOTE(review): missing ecolor='gray' here, unlike every other bar call
plt.bar(x_points + (ind_bar_width/2), human_all_gen_acc[rank_order], yerr=human_all_gen_err[rank_order], color=human_color, edgecolor='black', width=ind_bar_width)
plt.ylim([0,1])
plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize)
plt.ylabel('Generative accuracy', fontsize=axis_label_fontsize)
plt.xticks(x_points, all_gen_prob_type_names, fontsize=plot_fontsize)
plt.xlabel('Number of generalizations', fontsize=axis_label_fontsize)
plt.title(' ')
# Overlay per-subject accuracies on the human bars
plot_ind_data(ax, x_points + (ind_bar_width/2), human_all_gen_acc_ind_results, ind_bar_width)
hide_top_right(ax)
ax.set_aspect(4)
plt.tight_layout()
plt.savefig(results_dir + 'all_gen_acc.pdf', dpi=300, bbox_inches="tight")
plt.close()
148 |
## Generalization to real-world concepts
# Load results
human_realworld_results = np.load('./behavioral_results/realworld_acc.npz')
human_realworld_acc = human_realworld_results['all_acc']
human_realworld_err = human_realworld_results['all_err']
GPT3_realworld_results = np.load(GPT3_results_dir + 'realworld_acc.npz')
GPT3_realworld_acc = GPT3_realworld_results['all_acc']
GPT3_realworld_err = GPT3_realworld_results['all_err']
# Sort based on accuracy
# NOTE(review): this panel ranks by GPT-3 accuracy, unlike the other
# panels which rank by human accuracy -- confirm this is intentional
rank_order = np.flip(np.argsort(GPT3_realworld_acc))
# Plot: paired GPT-3 / human bars per transformation type
all_realworld_prob_type_names = ['Successor', 'Predecessor', 'Extend\nsequence', 'Sort']
x_points = np.arange(len(all_realworld_prob_type_names))
ax = plt.subplot(111)
plt.bar(x_points - (ind_bar_width/2), GPT3_realworld_acc[rank_order], yerr=GPT3_realworld_err[:,rank_order], color=gpt3_color, edgecolor='black', width=ind_bar_width, ecolor='gray')
plt.bar(x_points + (ind_bar_width/2), human_realworld_acc[rank_order], yerr=human_realworld_err[:,rank_order], color=human_color, edgecolor='black', width=ind_bar_width, ecolor='gray')
plt.ylim([0,1])
plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize)
plt.ylabel('Generative accuracy', fontsize=axis_label_fontsize)
plt.xticks(x_points, np.array(all_realworld_prob_type_names)[rank_order], fontsize=plot_fontsize)
plt.xlabel('Transformation type', fontsize=axis_label_fontsize)
plt.title('Real-world concept problems')
plt.legend(['GPT-3','Human'],fontsize=plot_fontsize,frameon=False)
hide_top_right(ax)
ax.set_aspect(4)
plt.tight_layout()
plt.savefig(results_dir + 'realworld_acc.pdf', dpi=300, bbox_inches="tight")
plt.close()
177 |
--------------------------------------------------------------------------------
/digit_mat/eval_gpt_matprob_prog_1thru5.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import numpy as np
3 | import builtins
4 | import os
5 |
def check_path(path):
    """Create directory `path` if it does not already exist.

    Uses os.makedirs with exist_ok=True so the check-and-create is not
    racy, repeated calls are harmless, and missing intermediate
    directories are created as needed (os.mkdir would fail on those).
    """
    os.makedirs(path, exist_ok=True)
9 |
# Split word into characters
def split(word):
    """Return the characters of `word` as a list.

    Equivalent to list(word); kept as a named helper because the script
    calls it in several places.
    """
    return list(word)
13 |
# Load all problems (dict of problem types -> problem arrays, pickled)
all_prob = np.load('./all_problems_1thru5.npz', allow_pickle=True)

# GPT-3 settings
# NOTE(review): placeholder key -- prefer reading it from an environment
# variable rather than hard-coding a real key here.
openai.api_key = "FILL_IN_API_KEY_HERE"
# Legacy Completions API: temperature 0 for deterministic output; echo=True
# plus logprobs=1 returns prompt tokens/offsets too, which the scoring code
# below relies on.
kwargs = { "engine":"text-davinci-003", "temperature":0, "max_tokens":10, "stop":"\n", "echo":True, "logprobs":1, }
20 |
# Enumerate the problem types present in the problem file
all_prob_types = builtins.list(all_prob['all_problems'].item().keys())
# Resume from previously saved results when available
all_data_fname = './gpt_matprob_results_1thru5.npz'
data_exists = os.path.exists(all_data_fname)
if data_exists:
    all_data = np.load('./gpt_matprob_results_1thru5.npz', allow_pickle=True)
# Result containers, keyed by problem type
all_gen_pred = {}
all_gen_correct_pred = {}
all_MC_pred = {}
all_MC_correct_pred = {}
all_alt_MC_correct_pred = {}
result_stores = (('all_gen_pred', all_gen_pred),
                 ('all_gen_correct_pred', all_gen_correct_pred),
                 ('all_MC_pred', all_MC_pred),
                 ('all_MC_correct_pred', all_MC_correct_pred),
                 ('all_alt_MC_correct_pred', all_alt_MC_correct_pred))
for prob_type in all_prob_types:
    if data_exists:
        # Carry over results from the earlier (partial) run
        for result_name, store in result_stores:
            store[prob_type] = all_data[result_name].item()[prob_type]
    else:
        # Start from scratch
        for result_name, store in result_stores:
            store[prob_type] = []
# Loop over runs. Each run evaluates one randomly sampled problem per
# problem type, accumulating all completed problems into a growing
# in-context "history" (context) that is prepended to every prompt.
N_runs = 20
for run in range(N_runs):
    print(str(run + 1) + ' of ' + str(N_runs) + '...')
    # Initialize context with task instructions (a solved example matrix)
    context = '[1] [1] [1]\n[2] [2] [2]\n[3] [3] [3]\n\n'
    # Loop over all problem types
    for p in range(len(all_prob_types)):
        # Problem type
        prob_type = all_prob_types[p]
        print('Problem type: ' + prob_type + '...')
        perm_invariant = all_prob['all_problems'].item()[prob_type]['perm_invariant']
        prob_type_N_prob = all_prob['all_problems'].item()[prob_type]['prob'].shape[0]
        # Skip problem types already evaluated for this run (resume support)
        if len(all_gen_correct_pred[prob_type]) <= run:

            # Sample problem index uniformly from this problem type
            prob_ind = int(np.floor(np.random.rand() * prob_type_N_prob))

            # Problem: 3x3 matrix, 8 answer choices, correct-choice index
            prob = all_prob['all_problems'].item()[prob_type]['prob'][prob_ind]
            answer_choices = all_prob['all_problems'].item()[prob_type]['answer_choices'][prob_ind]
            correct_ind = all_prob['all_problems'].item()[prob_type]['correct_ind'][prob_ind]
            correct_answer = answer_choices[correct_ind]

            # Generate prompt: render the matrix as bracketed rows of digits
            # (-1 entries render as blanks), leaving the final bottom-right
            # cell as an open '[' for the model to complete
            prompt = ''
            for r in range(3):
                for c in range(3):
                    prompt += '['
                    if not (r == 2 and c == 2):
                        for i in range(len(prob[r][c])):
                            if prob[r][c][i] == -1:
                                prompt += ' '
                            else:
                                prompt += str(prob[r][c][i])
                            if i < len(prob[r][c]) - 1:
                                prompt += ' '
                        prompt += ']'
                        if c < 2:
                            prompt += ' '
                        else:
                            prompt += '\n'
            # Add context
            context_prompt = context + prompt

            # Get response; on API failure (assumed to be a context-window
            # overflow), drop the oldest problem from the context and retry.
            # NOTE(review): bare except also swallows unrelated errors
            # (including KeyboardInterrupt) and can loop forever; the
            # trimmed context is also not written back to `context` itself
            fits_window = False
            response = []
            while not fits_window:
                try:
                    response = openai.Completion.create(prompt=context_prompt, **kwargs)
                except:
                    print('deleting problem from context...')
                    context_prompt = context_prompt.split('\n\n')[1:]
                    new_context_prompt = ''
                    for i in range(len(context_prompt)):
                        new_context_prompt += context_prompt[i]
                        if i < (len(context_prompt) - 1):
                            new_context_prompt += '\n\n'
                    context_prompt = new_context_prompt
                if len(response) > 0:
                    fits_window = True
            response_text = response['choices'][0]['text']
            # Find portion of response corresponding to prediction
            prediction = response_text[len(context_prompt):]
            all_gen_pred[prob_type].append(prediction)
            # Get prediction set: collect single-digit characters up to the
            # first closing bracket; any other character marks it invalid
            pred_set = []
            invalid_char = False
            closing_bracket = False  # NOTE(review): unused after init
            for i in range(len(split(prediction))):
                if prediction[i] != ' ':
                    if prediction[i].isdigit():
                        pred_set.append(int(prediction[i]))
                    elif prediction[i] == ']':
                        break
                    else:
                        invalid_char = True
                        break
            # Sort answer if problem type is permutation invariant
            if perm_invariant:
                correct_answer = np.sort(correct_answer)
                pred_set = np.sort(pred_set)
            # Determine whether prediction is correct (exact multiset match)
            correct_pred = False
            if not invalid_char and len(pred_set) == len(correct_answer):
                if np.all(pred_set == correct_answer):
                    correct_pred = True
            all_gen_correct_pred[prob_type].append(correct_pred)

            # Get score for generated response: mean token logprob of the
            # generated cell, excluding a stand-alone ']' token.
            # NOTE(review): if the completion never contains ']' this loop
            # walks past the end of the token list (IndexError)
            first_token_ind = np.where(np.array(response['choices'][0]['logprobs']['text_offset']) <= len(context_prompt))[0][-1]
            response_complete = False
            token_ind = first_token_ind
            gen_completion = ''
            while not response_complete:
                token = response['choices'][0]['logprobs']['tokens'][token_ind]
                gen_completion += token
                contains_closed_bracket = False
                for i in range(len(token)):
                    if token[i] == ']':
                        contains_closed_bracket = True
                if contains_closed_bracket:
                    response_complete = True
                    if token == ']':
                        last_token_ind = token_ind - 1
                    else:
                        last_token_ind = token_ind
                token_ind += 1
            gen_score = np.mean(response['choices'][0]['logprobs']['token_logprobs'][first_token_ind:last_token_ind+1])

            # Evaluate answer choices: score each of the 8 candidates by its
            # average token logprob when appended to the context + prompt
            all_choice_logprob = []
            for a in range(8):
                # Convert choice to string and remove ','
                choice_str = np.array(split(str(answer_choices[a])))
                choice_str = ''.join(builtins.list(choice_str[choice_str != ',']))
                # Add answer choice to prompt (choice_str[1:] drops the
                # leading '[', which the prompt already ends with)
                context_prompt_choice = context_prompt + choice_str[1:]
                # Get average log probability of response, with the same
                # context-trimming retry as above (same bare-except caveat)
                fits_window = False
                response = []
                while not fits_window:
                    try:
                        response = openai.Completion.create(prompt=context_prompt_choice, **kwargs)
                    except:
                        print('deleting problem from context...')
                        context_prompt = context_prompt.split('\n\n')[1:]
                        new_context_prompt = ''
                        for i in range(len(context_prompt)):
                            new_context_prompt += context_prompt[i]
                            if i < (len(context_prompt) - 1):
                                new_context_prompt += '\n\n'
                        context_prompt = new_context_prompt
                        context_prompt_choice = context_prompt + choice_str[1:]
                    if len(response) > 0:
                        fits_window = True
                first_token_ind = np.where(np.array(response['choices'][0]['logprobs']['text_offset']) <= len(context_prompt))[0][-1]
                last_token_ind = np.where(np.array(response['choices'][0]['logprobs']['text_offset']) == len(context_prompt_choice))[0][0]
                choice_avg_logprob = np.mean(response['choices'][0]['logprobs']['token_logprobs'][first_token_ind:last_token_ind])
                all_choice_logprob.append(choice_avg_logprob)
            # Select answer: highest-scoring choice
            model_choice = np.argmax(all_choice_logprob)
            all_MC_pred[prob_type].append(model_choice)
            # Determine whether multiple choice selection is correct
            MC_correct = model_choice == correct_ind
            all_MC_correct_pred[prob_type].append(MC_correct)

            # Alternative multiple-choice evaluation: count as correct if the
            # free generation was correct; otherwise the correct choice must
            # also outscore the freely generated answer
            if correct_pred:
                alt_MC_correct = True
            else:
                if MC_correct:
                    all_choice_logprob.append(gen_score)
                    alt_model_choice = np.argmax(all_choice_logprob)
                    alt_MC_correct = alt_model_choice == correct_ind
                else:
                    alt_MC_correct = False
            all_alt_MC_correct_pred[prob_type].append(alt_MC_correct)

            # Add problem (completed with the model's chosen answer) to the
            # running context for subsequent problems
            model_choice_str = np.array(split(str(answer_choices[model_choice])))
            model_choice_str = ''.join(builtins.list(model_choice_str[model_choice_str != ',']))
            completed_prob = context + prompt + model_choice_str[1:]
            completed_prob += '\n\n'
            context = completed_prob

            # Save data after every problem (supports resuming)
            # NOTE(review): np.savez has no allow_pickle parameter -- this
            # keyword is saved as an extra array named 'allow_pickle'
            eval_fname = './gpt_matprob_results_1thru5.npz'
            np.savez(eval_fname,
                all_gen_pred=all_gen_pred, all_gen_correct_pred=all_gen_correct_pred, all_MC_pred=all_MC_pred, all_MC_correct_pred=all_MC_correct_pred, all_alt_MC_correct_pred=all_alt_MC_correct_pred,
                allow_pickle=True)
            # Raw output: the full context (all completed problems) per run
            gen_data_dir = './gpt_matprob_results_1thru5/'
            check_path(gen_data_dir)
            gen_data_fname = gen_data_dir + str(run) + '.txt'
            gen_data_fid = open(gen_data_fname, 'w')
            gen_data_fid.write(context)
            gen_data_fid.close()

        else:

            # Load previously generated context for this run (resume path)
            gen_data_dir = './gpt_matprob_results_1thru5/'
            gen_data_fname = gen_data_dir + str(run) + '.txt'
            gen_data_fid = open(gen_data_fname, 'r')
            lines = gen_data_fid.readlines()
            context = ' '.join(lines)
            # Remove spaces introduced by the ' '.join above (each line
            # after the first gained a leading space)
            context = context.split('\n')
            new_context = context[0]
            for c in range(1,len(context)):
                new_context += '\n'
                new_context += context[c][1:]
            context = new_context
248 |
--------------------------------------------------------------------------------
/letter_string/gen_problems.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from copy import deepcopy
3 | import random
4 | import json
5 |
6 | ## Problems generated using one of the following 6 transformations:
7 | #
8 | # Successorship
9 | # [a b c d] [a b c e]
10 | # [i j k l] [i j k m]
11 | #
12 | # Predecessorship
13 | # [b c d e] [a c d e]
14 | # [j k l m] [i k l m]
15 | #
16 | # Add letter to sequence
17 | # [a b c d] [a b c d e]
18 | # [i j k l] [i j k l m]
19 | #
20 | # Remove redundant character
21 | # [a b b c d e] [a b c d e]
22 | # [i j k l l m] [i j k l m]
23 | #
24 | # Fix alphabetic sequence
25 | # [a w c d e] [a b c d e]
26 | # [i j k p m] [i j k l m]
27 | #
28 | # Sort characters
29 | # [a b e d c] [a b c d e]
30 | # [i k j l m] [i j k l m]
31 | #
32 | #
33 | ## and between 0 and 3 generalizations, out of the following 6:
34 | #
35 | # Larger interval
36 | # [a b c d] [a b c e]
37 | # [i k m o] [i k m q]
38 | #
39 | # Longer target
40 | # [a b c d] [a b c e]
41 | # [i j k l m n o p] [i j k l m n o q]
42 | #
43 | # Grouping
44 | # [a b c d] [a b c e]
45 | # [i i j j k k l l] [i i j j k k m m]
46 | #
47 | # Interleaved X's
48 | # [a b c d] [a b c e]
49 | # [i x j x k x l x] [i x j x k x m x]
50 | #
51 | # Letter-to-number
52 | # [a b c d] [a b c e]
53 | # [1 2 3 4] [1 2 3 5]
54 | #
55 | # Reversal
56 | # [a b c d] [a b c e]
57 | # [m l k j] [m l k i]
58 | #
59 | ##
60 |
# Alphabet: the 26 lowercase letters, built from character codes
letters = [chr(code) for code in range(ord('a'), ord('z') + 1)]
N_letters = len(letters)
# Numbers 1..26, aligned index-for-index with the alphabet
numbers = np.arange(1, N_letters + 1)
# Linearly ordered real-world concepts (each row is an ordered 4-item scale)
realworld_linear = [['cold', 'cool', 'warm', 'hot'],
                    ['love', 'like', 'dislike', 'hate'],
                    ['jack', 'queen', 'king', 'ace'],
                    ['penny', 'nickel', 'dime', 'quarter'],
                    ['second', 'minute', 'hour', 'day']]
72 |
# Successor transformation
def apply_succ(prob_letters):
    """Build the [A, B] analogy pair: A drops the final element, B replaces
    A's last element with the run's final one (for a contiguous run this is
    the successor, e.g. [a b c d e] -> [a b c d], [a b c e])."""
    left = prob_letters[:-1]
    right = prob_letters[:-2] + [prob_letters[-1]]
    return [left, right]
76 |
# Predecessor transformation
def apply_pred(prob_letters):
    """Build the [A, B] analogy pair: A drops the first element, B replaces
    A's first element with the run's initial one (its predecessor for a
    contiguous run, e.g. [a b c d e] -> [b c d e], [a c d e])."""
    first = prob_letters[0]
    tail = prob_letters[2:]
    return [prob_letters[1:], [first] + tail]
80 |
# Add letter to sequence
def apply_add_letter(prob_letters):
    """Build the [A, B] analogy pair where B extends A by one element
    (A is the run minus its last element, B is the full run)."""
    shortened = prob_letters[:-1]
    return [shortened, prob_letters]
84 |
# Remove redundant letter
def apply_remove_redundant(prob_letters):
    """Build the [A, B] analogy pair where A contains one randomly chosen
    element duplicated in place (adjacent copy) and B is the clean sequence."""
    dup_idx = np.random.permutation(len(prob_letters))[0]
    with_duplicate = list(prob_letters)
    with_duplicate.insert(dup_idx, prob_letters[dup_idx])
    return [with_duplicate, prob_letters]
93 |
# Replace one position with an out-of-place character (the "fix alphabetic
# sequence" transformation: A is the corrupted run, B is the clean run)
def apply_fix_alphabet(prob_letters):
    """Build the [A, B] analogy pair where A has one randomly chosen position
    overwritten with a random letter not present in the sequence."""
    alphabet = np.array(letters)
    # Letters of the alphabet that do not appear anywhere in prob_letters
    outside = alphabet[np.all(np.expand_dims(alphabet, 1) != np.expand_dims(np.array(prob_letters), 0), 1)]
    np.random.shuffle(outside)
    corrupt_letter = outside[0]
    corrupt_pos = np.random.permutation(len(prob_letters))[0]
    corrupted = deepcopy(prob_letters)
    corrupted[corrupt_pos] = corrupt_letter
    return [corrupted, prob_letters]
106 |
# Sort letters
def apply_sort(prob_letters):
    """Build the [A, B] analogy pair where A is the sequence with two randomly
    chosen (distinct) positions swapped and B is the untouched sequence."""
    perm = np.random.permutation(len(prob_letters))
    i, j = perm[0], perm[1]
    scrambled = list(prob_letters)
    scrambled[i], scrambled[j] = scrambled[j], scrambled[i]
    return [scrambled, prob_letters]
119 |
# Method for generating subset of problems
def gen_prob_subset(N_generalize=0, N_prob=100, standard_len=5, longer_targ_len=9, larger_int_size=2,
        trans_allowed=['succ', 'pred', 'add_letter', 'remove_redundant', 'fix_alphabet', 'sort'],
        gen_allowed=['larger_int', 'longer_targ', 'group', 'interleaved', 'letter2num', 'reverse', 'realworld']):
    """Generate N_prob unique letter-string analogy problems.

    Each problem pairs a source analogy (built from a contiguous alphabet run
    of standard_len letters) with a target analogy built from a second run
    (or a real-world ordered concept list), optionally modified by
    N_generalize generalizations sampled from gen_allowed. One transformation
    from trans_allowed is applied to both source and target.

    Parameters:
        N_generalize: number of generalizations applied to the target (0-3).
        N_prob: number of unique problems to generate.
        standard_len: length of the source (and default target) letter run.
        longer_targ_len: target length when 'longer_targ' is active.
        larger_int_size: letter stride for the target when 'larger_int' is active.
        trans_allowed: transformation names to sample from.
        gen_allowed: generalization names to sample from.

    Returns:
        dict with keys 'prob' ([src, tgt] analogy pairs), 'trans', 'gen',
        'src_letters', 'tgt_letters' (parallel lists, one entry per problem).
    """
    # Shuffle private copies, not the argument lists: the defaults are shared
    # mutable lists, and shuffling them in place leaked ordering state across
    # calls and reordered any list the caller passed in.
    trans_pool = list(trans_allowed)
    gen_pool = list(gen_allowed)
    # Initialize storage for problems
    all_prob = []
    all_trans = []
    all_gen = []
    all_src_letters = []
    all_tgt_letters = []
    while len(all_prob) < N_prob:
        # Sample source letters (a contiguous alphabet run)
        src_start = np.floor(np.random.rand() * (len(letters)-(standard_len-1))).astype(int)
        src_letters = letters[src_start:src_start+standard_len]
        # Sample generalizations
        random.shuffle(gen_pool)
        generalize = gen_pool[:N_generalize]
        # Sample target letters
        if 'realworld' in generalize:
            # Copy so later in-place edits (e.g. reverse) cannot corrupt the
            # module-level realworld_linear entries.
            tgt_letters = list(random.choice(realworld_linear))
        else:
            # Target length and stride depend on the active generalizations.
            tgt_len = longer_targ_len if 'longer_targ' in generalize else standard_len
            if 'larger_int' in generalize:
                # Alphabet span needed to fit tgt_len letters at the given
                # stride (stride generalized from the previously hard-coded 2).
                tgt_span = (tgt_len * larger_int_size) - 1
                stride = larger_int_size
            else:
                tgt_span = tgt_len
                stride = 1
            # Resample until the target run starts at a different letter than the source
            while True:
                tgt_start = np.floor(np.random.rand() * (len(letters)-(tgt_span-1))).astype(int)
                if tgt_start != src_start:
                    break
            tgt_letters = letters[tgt_start:tgt_start+tgt_span][::stride]
        # Reverse target letters
        if 'reverse' in generalize:
            tgt_letters.reverse()
        # Sample transformation and apply it to both source and target
        random.shuffle(trans_pool)
        trans = trans_pool[0]
        trans_fn = {'succ': apply_succ,
                    'pred': apply_pred,
                    'add_letter': apply_add_letter,
                    'remove_redundant': apply_remove_redundant,
                    'fix_alphabet': apply_fix_alphabet,
                    'sort': apply_sort}[trans]
        src = trans_fn(src_letters)
        tgt = trans_fn(tgt_letters)
        # Generalization from letters to numbers (position in the alphabet)
        if 'letter2num' in generalize:
            tgt = [[numbers[np.where(np.array(letters) == item)[0][0]] for item in seq] for seq in tgt]
        # Interleaved X's (or 0's, if target composed of numbers)
        if 'interleaved' in generalize:
            filler = '0' if 'letter2num' in generalize else 'x'
            tgt = [[elem for item in seq for elem in (item, filler)] for seq in tgt]
        # Grouping (every target element doubled)
        if 'group' in generalize:
            tgt = [[elem for item in seq for elem in (item, item)] for seq in tgt]
        # Check that problem doesn't already exist
        prob = [src, tgt]
        duplicate = False
        for p_prev in range(len(all_prob)):
            if np.array(prob).shape == np.array(all_prob[p_prev]).shape:
                if np.all(np.array(prob) == np.array(all_prob[p_prev])):
                    duplicate = True
                    break
        # Add to problem subset
        if not duplicate:
            all_prob.append(prob)
            all_trans.append(trans)
            all_gen.append(generalize)
            all_src_letters.append(src_letters)
            all_tgt_letters.append(tgt_letters)
    return {'prob': all_prob, 'trans': all_trans, 'gen': all_gen, 'src_letters': all_src_letters, 'tgt_letters': all_tgt_letters}
244 |
# Add problems to json and numpy file
def save_prob(all_prob, prob_type_name, all_prob_js):
    """Render each problem in all_prob as a text prompt and store the list
    under prob_type_name in all_prob_js (later serialized to JSON/JS).

    Prompt layout is "[A]   [B]\n[C]   [  ?  ]" where A:B :: C:? is the
    analogy (A = source left, B = source right, C = target left).

    NOTE(review): the original built this string with a raw, unescaped line
    break inside the literal (a syntax error, likely extraction damage); an
    escaped '\n' is the evident intent.
    """
    all_data = []
    for p in range(len(all_prob['prob'])):
        src, tgt = all_prob['prob'][p]
        a_str = ' '.join(str(item) for item in src[0])
        b_str = ' '.join(str(item) for item in src[1])
        c_str = ' '.join(str(item) for item in tgt[0])
        prompt = '[' + a_str + ']   [' + b_str + ']\n[' + c_str + ']   [  ?  ]'
        # Add to dataset
        all_data.append({'prompt': prompt, 'prob_ind': p})
    # Add to javascript data
    all_prob_js[prob_type_name] = all_data
    return all_prob_js
274 |
# Split subset
def split_subset(all_prob, N_split):
    """Partition a problem-subset dict into N_split contiguous, equal-size
    chunks; every key's list is sliced with the same boundaries."""
    chunk_len = len(all_prob['prob']) // N_split
    all_prob_split = []
    for s in range(N_split):
        lo = chunk_len * s
        subset = {key: vals[lo:lo + chunk_len] for key, vals in all_prob.items()}
        all_prob_split.append(subset)
    return all_prob_split
288 |
# Generate all basic analogies (zero generalizations)
all_succ = gen_prob_subset(trans_allowed=['succ'])
all_pred = gen_prob_subset(trans_allowed=['pred'])
all_add_letter = gen_prob_subset(trans_allowed=['add_letter'])
all_remove_redundant = gen_prob_subset(trans_allowed=['remove_redundant'])
all_fix_alphabet = gen_prob_subset(trans_allowed=['fix_alphabet'])
all_sort = gen_prob_subset(trans_allowed=['sort'])

# Generate all problems with one generalization (randomly sample transformations)
all_larger_int = gen_prob_subset(N_generalize=1, gen_allowed=['larger_int'])
all_longer_targ = gen_prob_subset(N_generalize=1, gen_allowed=['longer_targ'])
all_group = gen_prob_subset(N_generalize=1, gen_allowed=['group'])
all_interleaved = gen_prob_subset(N_generalize=1, gen_allowed=['interleaved'])
all_letter2num = gen_prob_subset(N_generalize=1, gen_allowed=['letter2num'])
all_reverse = gen_prob_subset(N_generalize=1, gen_allowed=['reverse'])

# Generate all problems with 2 and 3 generalizations (600 each, split into 6
# subsets of 100 so they match the 100-problem size of the other categories)
all_2gen = gen_prob_subset(N_generalize=2, N_prob=600, gen_allowed=['larger_int', 'longer_targ', 'group', 'interleaved', 'letter2num', 'reverse'])
all_2gen_split = split_subset(all_2gen, 6)
all_3gen = gen_prob_subset(N_generalize=3, N_prob=600, gen_allowed=['larger_int', 'longer_targ', 'group', 'interleaved', 'letter2num', 'reverse'])
all_3gen_split = split_subset(all_3gen, 6)

# Generate problems involving generalization to real-world concepts
# (standard_len=4 to match the 4-item real-world concept scales)
all_realworld_succ = gen_prob_subset(standard_len=4, N_generalize=1, trans_allowed=['succ'], gen_allowed=['realworld'])
all_realworld_pred = gen_prob_subset(standard_len=4, N_generalize=1, trans_allowed=['pred'], gen_allowed=['realworld'])
all_realworld_add_letter = gen_prob_subset(standard_len=4, N_generalize=1, trans_allowed=['add_letter'], gen_allowed=['realworld'])
all_realworld_sort = gen_prob_subset(standard_len=4, N_generalize=1, trans_allowed=['sort'], gen_allowed=['realworld'])

# Combine problems (order here fixes the problem-type order in the outputs)
all_prob_types = [all_succ, all_pred, all_add_letter, all_remove_redundant, all_fix_alphabet, all_sort,
        all_larger_int, all_longer_targ, all_group, all_interleaved, all_letter2num, all_reverse,
        all_2gen_split[0], all_2gen_split[1], all_2gen_split[2], all_2gen_split[3], all_2gen_split[4], all_2gen_split[5],
        all_3gen_split[0], all_3gen_split[1], all_3gen_split[2], all_3gen_split[3], all_3gen_split[4], all_3gen_split[5],
        all_realworld_succ, all_realworld_pred, all_realworld_add_letter, all_realworld_sort]
all_prob_type_names = ['succ', 'pred', 'add_letter', 'remove_redundant', 'fix_alphabet', 'sort',
        'larger_int', 'longer_targ', 'group', 'interleaved', 'letter2num', 'reverse',
        '2gen_split1', '2gen_split2', '2gen_split3', '2gen_split4', '2gen_split5', '2gen_split6',
        '3gen_split1', '3gen_split2', '3gen_split3', '3gen_split4', '3gen_split5', '3gen_split6',
        'realworld_succ', 'realworld_pred', 'realworld_add_letter', 'realworld_sort']

# Create js variable for all_problems
all_prob_js = {}
all_prob_np = {}
for prob_subset, name in zip(all_prob_types, all_prob_type_names):
    all_prob_js = save_prob(prob_subset, name, all_prob_js)
    all_prob_np[name] = prob_subset
# Write numpy file
np.savez('./all_prob.npz', all_prob=all_prob_np)
# Convert to json strings
all_prob_json_string = json.dumps(all_prob_js)
# Write to js script (context manager guarantees the handle is closed,
# even if the write fails)
with open('./all_prob.js', 'w') as js_fid:
    js_fid.write('var all_problems = ' + all_prob_json_string)
--------------------------------------------------------------------------------
/letter_string/analyze_gpt3_letterstring.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from statsmodels.stats.proportion import proportion_confint
4 | from itertools import combinations
5 | import builtins
6 | import argparse
7 | import os
8 |
def check_path(path):
    """Create directory `path` (and any missing parents) if it doesn't exist.

    os.makedirs(..., exist_ok=True) replaces the previous exists()/mkdir()
    pair, which raced (TOCTOU) and failed for nested paths.
    """
    os.makedirs(path, exist_ok=True)
12 |
def hide_top_right(ax):
    """Hide the top/right spines of a matplotlib Axes and keep tick marks on
    the left/bottom sides only."""
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
18 |
# Settings
# --sentence / --noprompt select which results file (and output directory,
# below) this run analyzes; the flags mirror the eval script's conditions.
parser = argparse.ArgumentParser()
parser.add_argument('--sentence', action='store_true', help="Present problem in sentence format.")
parser.add_argument('--noprompt', action='store_true', help="Present problem without prompt.")
args = parser.parse_args()

# Load data
# all_prob_type_responses is indexed [problem_type][trial] -> raw response text
if args.sentence:
    all_responses = np.load('./gpt3_letterstring_results_sentence.npz')['all_prob_type_responses']
elif args.noprompt:
    all_responses = np.load('./gpt3_letterstring_results_noprompt.npz')['all_prob_type_responses']
else:
    all_responses = np.load('./gpt3_letterstring_results.npz')['all_prob_type_responses']
N_prob_types = all_responses.shape[0]
N_trials_per_prob_type = all_responses.shape[1]
# Load problems
# all_prob is a 0-d object array wrapping a dict: problem-type name ->
# {'prob', 'trans', 'gen', ...} (as written by gen_problems.py)
all_prob = np.load('./all_prob.npz', allow_pickle=True)['all_prob']
prob_types = builtins.list(all_prob.item().keys())
37 |
# All possible combinations of transformations and generalizations
trans = ['succ', 'pred', 'add_letter', 'remove_redundant', 'fix_alphabet', 'sort']
gen = ['larger_int', 'longer_targ', 'group', 'interleaved', 'letter2num', 'reverse']
# Enumerate every problem subtype: each transformation crossed with every
# generalization set of size 0, 1, 2, and 3. Ordering matches the original
# enumeration: sizes in ascending order, transformation outer, combo inner.
all_trans = []
all_gen = []
for comb_size in range(4):
    gen_combos = builtins.list(combinations(np.arange(len(gen)), comb_size))
    for trans_name in trans:
        for combo in gen_combos:
            all_trans.append(trans_name)
            all_gen.append([gen[g_ind] for g_ind in combo])
N_prob_subtype = len(all_trans)
subtype_counts = np.zeros(N_prob_subtype)
62 |
# Calculate performance
# For each problem type and trial: clean the raw response, split it into
# tokens, and score it against the correct completion; also record the
# problem's subtype index, number of generalizations, and realworld flag.
all_prob_type_correct_pred = []
all_prob_type_subtype = []
all_prob_N_gen = []
all_prob_realworld = []
for p in range(N_prob_types):
    all_correct_pred = []
    all_subtype = []
    all_N_gen = []
    all_realworld = []
    for t in range(N_trials_per_prob_type):
        # NOTE(review): response[-1]/response[0] assume a non-empty response
        # string — an empty response would raise IndexError; verify upstream.
        response = all_responses[p][t]
        if args.sentence:
            # Sentence format: strip a trailing period, a single leading
            # space, and any leading newlines before tokenizing
            if response[-1] == '.':
                response = response[:-1]
            if response[0] == ' ':
                response = response[1:]
            linebreak = True
            while linebreak:
                if response[:1] == '\n':
                    response = response[1:]
                else:
                    linebreak = False
            response_parsed = response.split(' ')
            # Correct completion: target analogy's right-hand side ([1][1])
            correct_answer = all_prob.item()[prob_types[p]]['prob'][t][1][1]
            # Exact match required: same token count and same tokens (as strings)
            if np.array(correct_answer).astype(str).shape[0] == np.array(response_parsed).astype(str).shape[0]:
                correct_pred = np.all(np.array(correct_answer).astype(str) == np.array(response_parsed))
            else:
                correct_pred = False
            all_correct_pred.append(correct_pred)
        else:
            # Bracket format: strip a trailing ']' before tokenizing
            if response[-1] == ']':
                response = response[:-1]
            response_parsed = response.split(' ')
            correct_answer = all_prob.item()[prob_types[p]]['prob'][t][1][1]
            if np.array(correct_answer).astype(str).shape[0] == np.array(response_parsed).astype(str).shape[0]:
                correct_pred = np.all(np.array(correct_answer).astype(str) == np.array(response_parsed))
            else:
                correct_pred = False
            all_correct_pred.append(correct_pred)
        # Classify problem subtype (matching transformation + same
        # generalization set, order-insensitive)
        for sp in range(N_prob_subtype):
            if all_prob.item()[prob_types[p]]['trans'][t] == all_trans[sp]:
                if len(all_prob.item()[prob_types[p]]['gen'][t]) == len(all_gen[sp]):
                    if np.all(np.sort(all_prob.item()[prob_types[p]]['gen'][t]) == np.sort(np.array(all_gen[sp]))):
                        all_subtype.append(sp)
                        subtype_counts[sp] += 1
        # Number of generalizations
        N_gen = len(all_prob.item()[prob_types[p]]['gen'][t])
        all_N_gen.append(N_gen)
        # Real-world problems
        if 'realworld' in all_prob.item()[prob_types[p]]['gen'][t]:
            all_realworld.append(1)
        else:
            all_realworld.append(0)
    all_prob_type_correct_pred.append(all_correct_pred)
    all_prob_N_gen.append(all_N_gen)
    all_prob_realworld.append(all_realworld)
    # Problem types whose generalizations include 'realworld' match no subtype
    # (it is not in the gen list above), so they contribute no subtype row
    if len(all_subtype) > 0:
        all_prob_type_subtype.append(all_subtype)
# Convert to arrays
all_prob_type_correct_pred = np.array(all_prob_type_correct_pred)
all_prob_type_subtype = np.array(all_prob_type_subtype)
all_prob_N_gen = np.array(all_prob_N_gen)
all_prob_realworld = np.array(all_prob_realworld)
128 |
# Create directory for results (one output directory per prompt-format
# condition, mirroring the input-file selection above)
if args.sentence:
    results_dir = './GPT3_results_sentence/'
elif args.noprompt:
    results_dir = './GPT3_results_noprompt/'
else:
    results_dir = './GPT3_results/'
check_path(results_dir)

# Save individual trial results
np.savez(results_dir + 'ind_trial_results.npz', all_prob_type_correct_pred=all_prob_type_correct_pred, all_prob_type_subtype=all_prob_type_subtype, all_prob_N_gen=all_prob_N_gen, all_prob_realworld=all_prob_realworld)
140 |
# Correlation analysis
# Mean accuracy per problem subtype. all_prob_type_subtype only has rows for
# problem types that matched a subtype, so correct_pred is truncated to the
# same number of leading rows before masking.
# NOTE(review): this assumes the subtype-less (realworld) problem types are
# the trailing rows of all_prob_type_correct_pred — confirm against the
# problem-type ordering in gen_problems.py.
all_subtype_acc = []
for sp in range(N_prob_subtype):
    all_subtype_acc.append(all_prob_type_correct_pred[:all_prob_type_subtype.shape[0],:][all_prob_type_subtype == sp].mean())
np.savez(results_dir + 'prob_subtype_acc.npz', subtype_acc=all_subtype_acc, subtype_counts=subtype_counts)
146 |
# Plot settings (shared by every figure below)
gpt3_color = 'darkslateblue'   # bar color for GPT-3 results
bar_width = 0.8                # bar width passed to plt.bar
plot_fontsize = 10             # tick-label and legend font size
axis_label_fontsize = 12       # axis-label font size
title_fontsize = 12            # title font size (not referenced in this file)
153 |
# Calculate accuracy for all zero-generalization problems
all_zerogen_prob_types = ['succ', 'pred', 'add_letter', 'remove_redundant', 'fix_alphabet', 'sort']
all_zerogen_acc = []
all_zerogen_ci_lower = []
all_zerogen_ci_upper = []
for p in range(len(all_zerogen_prob_types)):
    # Per-trial correctness for this problem type (row looked up by name)
    correct_pred = np.array(all_prob_type_correct_pred[np.where(np.array(prob_types)==all_zerogen_prob_types[p])[0][0]]).astype(float)
    all_zerogen_acc.append(correct_pred.mean())
    # Binomial proportion confidence interval on the accuracy
    ci_lower, ci_upper = proportion_confint(correct_pred.sum(), correct_pred.shape[0])
    all_zerogen_ci_lower.append(ci_lower)
    all_zerogen_ci_upper.append(ci_upper)
all_zerogen_acc = np.array(all_zerogen_acc)
all_zerogen_ci_lower = np.array(all_zerogen_ci_lower)
all_zerogen_ci_upper = np.array(all_zerogen_ci_upper)
# Convert CI bounds into asymmetric error-bar lengths for plt.bar
all_zerogen_lower_err = all_zerogen_acc - all_zerogen_ci_lower
all_zerogen_upper_err = all_zerogen_ci_upper - all_zerogen_acc
all_zerogen_err = np.array([all_zerogen_lower_err, all_zerogen_upper_err])
# Sort based on accuracy (bars ordered best-to-worst)
rank_order = np.flip(np.argsort(all_zerogen_acc))
# Plot
all_zerogen_prob_type_names = ['Successor', 'Predecessor', 'Extend\nsequence', 'Remove\nredundant\nletter', 'Fix\nalphabetic\nsequence', 'Sort']
x_points = np.arange(len(all_zerogen_prob_types))
ax = plt.subplot(111)
plt.bar(x_points, all_zerogen_acc[rank_order], yerr=all_zerogen_err[:,rank_order], color=gpt3_color, edgecolor='black', width=bar_width)
plt.ylim([0,1])
plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize)
plt.ylabel('Generative accuracy', fontsize=axis_label_fontsize)
plt.xticks(x_points, np.array(all_zerogen_prob_type_names)[rank_order], fontsize=plot_fontsize)
plt.xlabel('Transformation type', fontsize=axis_label_fontsize)
plt.title('Zero-generalization problems')
plt.legend(['GPT-3'],fontsize=plot_fontsize,frameon=False)
hide_top_right(ax)
plt.tight_layout()
plt.savefig(results_dir + 'zerogen_acc.png', dpi=300, bbox_inches="tight")
plt.close()
# Save results (accuracies are in the original, unsorted type order)
np.savez(results_dir + 'zerogen_acc.npz', all_acc=all_zerogen_acc, all_err=all_zerogen_err)
191 |
# Calculate all accuracy for one-generalization problems
# (same pipeline as the zero-generalization analysis above, over the
# one-generalization problem types)
all_onegen_prob_types = ['larger_int', 'longer_targ', 'group', 'interleaved', 'letter2num', 'reverse']
all_onegen_acc = []
all_onegen_ci_lower = []
all_onegen_ci_upper = []
for p in range(len(all_onegen_prob_types)):
    correct_pred = np.array(all_prob_type_correct_pred[np.where(np.array(prob_types)==all_onegen_prob_types[p])[0][0]]).astype(float)
    all_onegen_acc.append(correct_pred.mean())
    # Binomial proportion confidence interval on the accuracy
    ci_lower, ci_upper = proportion_confint(correct_pred.sum(), correct_pred.shape[0])
    all_onegen_ci_lower.append(ci_lower)
    all_onegen_ci_upper.append(ci_upper)
all_onegen_acc = np.array(all_onegen_acc)
all_onegen_ci_lower = np.array(all_onegen_ci_lower)
all_onegen_ci_upper = np.array(all_onegen_ci_upper)
# Asymmetric error-bar lengths from the CI bounds
all_onegen_lower_err = all_onegen_acc - all_onegen_ci_lower
all_onegen_upper_err = all_onegen_ci_upper - all_onegen_acc
all_onegen_err = np.array([all_onegen_lower_err, all_onegen_upper_err])
# Sort based on accuracy (bars ordered best-to-worst)
rank_order = np.flip(np.argsort(all_onegen_acc))
# Plot
all_onegen_prob_type_names = ['Larger\ninterval', 'Longer\ntarget', 'Grouping', 'Interleaved\ndistractor', 'Letter-to-\nnumber', 'Reverse\norder']
x_points = np.arange(len(all_onegen_prob_types))
ax = plt.subplot(111)
plt.bar(x_points, all_onegen_acc[rank_order], yerr=all_onegen_err[:,rank_order], color=gpt3_color, edgecolor='black', width=bar_width)
plt.ylim([0,1])
plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize)
plt.ylabel('Generative accuracy', fontsize=axis_label_fontsize)
plt.xticks(x_points, np.array(all_onegen_prob_type_names)[rank_order], fontsize=plot_fontsize)
plt.xlabel('Generalization type', fontsize=axis_label_fontsize)
plt.title('One-generalization problems')
plt.legend(['GPT-3'],fontsize=plot_fontsize,frameon=False)
hide_top_right(ax)
plt.tight_layout()
plt.savefig(results_dir + 'onegen_acc.png', dpi=300, bbox_inches="tight")
plt.close()
# Save results (accuracies are in the original, unsorted type order)
np.savez(results_dir + 'onegen_acc.npz', all_acc=all_onegen_acc, all_err=all_onegen_err)
229 |
# Calculate accuracy by number of generalizations
# Problem-type rows are ordered so [0:6] are the zero-generalization types,
# [6:12] one, [12:18] two, [18:24] three generalizations.
gen_ind = [[0,6], [6,12], [12,18], [18,24]]
all_gen_acc = []
all_gen_ci_lower = []
all_gen_ci_upper = []
for p in range(len(gen_ind)):
    # Pool all trials across the problem types in this generalization bucket
    correct_pred = np.array(all_prob_type_correct_pred[gen_ind[p][0]:gen_ind[p][1]]).flatten().astype(float)
    acc = correct_pred.mean()
    all_gen_acc.append(acc)
    # Binomial proportion confidence interval on the accuracy
    ci_lower, ci_upper = proportion_confint(correct_pred.sum(), correct_pred.shape[0])
    all_gen_ci_lower.append(ci_lower)
    all_gen_ci_upper.append(ci_upper)
# Bug fix: this line previously assigned to a misspelled dead name
# ('all_gen_ac'), leaving all_gen_acc as a plain list (numpy happened to
# coerce it in the arithmetic below, masking the typo).
all_gen_acc = np.array(all_gen_acc)
all_gen_ci_lower = np.array(all_gen_ci_lower)
all_gen_ci_upper = np.array(all_gen_ci_upper)
# Asymmetric error-bar lengths from the CI bounds
all_gen_lower_err = all_gen_acc - all_gen_ci_lower
all_gen_upper_err = all_gen_ci_upper - all_gen_acc
all_gen_err = np.array([all_gen_lower_err, all_gen_upper_err])
# Plot (x labels are simply '0'..'3' generalizations)
all_gen_prob_type_names = np.arange(len(gen_ind)).astype(str)
x_points = np.arange(len(gen_ind))
ax = plt.subplot(111)
plt.bar(x_points, all_gen_acc, yerr=all_gen_err, color=gpt3_color, edgecolor='black', width=bar_width)
plt.ylim([0,1])
plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize)
plt.ylabel('Generative accuracy', fontsize=axis_label_fontsize)
plt.xticks(x_points, all_gen_prob_type_names, fontsize=plot_fontsize)
plt.xlabel('Number of generalizations', fontsize=axis_label_fontsize)
plt.legend(['GPT-3'],fontsize=plot_fontsize,frameon=False)
hide_top_right(ax)
plt.tight_layout()
plt.savefig(results_dir + 'all_gen_acc.png', dpi=300, bbox_inches="tight")
plt.close()
# Save results
np.savez(results_dir + 'all_gen_acc.npz', all_acc=all_gen_acc, all_err=all_gen_err)
265 |
# Calculate accuracy for all real-world concept problems
# (same pipeline as the zero-generalization analysis, over the realworld types)
all_realworld_prob_types = ['realworld_succ', 'realworld_pred', 'realworld_add_letter', 'realworld_sort']
all_realworld_acc = []
all_realworld_ci_lower = []
all_realworld_ci_upper = []
for p in range(len(all_realworld_prob_types)):
    correct_pred = np.array(all_prob_type_correct_pred[np.where(np.array(prob_types)==all_realworld_prob_types[p])[0][0]]).astype(float)
    all_realworld_acc.append(correct_pred.mean())
    # Binomial proportion confidence interval on the accuracy
    ci_lower, ci_upper = proportion_confint(correct_pred.sum(), correct_pred.shape[0])
    all_realworld_ci_lower.append(ci_lower)
    all_realworld_ci_upper.append(ci_upper)
all_realworld_acc = np.array(all_realworld_acc)
all_realworld_ci_lower = np.array(all_realworld_ci_lower)
all_realworld_ci_upper = np.array(all_realworld_ci_upper)
# Asymmetric error-bar lengths from the CI bounds
all_realworld_lower_err = all_realworld_acc - all_realworld_ci_lower
all_realworld_upper_err = all_realworld_ci_upper - all_realworld_acc
all_realworld_err = np.array([all_realworld_lower_err, all_realworld_upper_err])
# Sort based on accuracy (bars ordered best-to-worst)
rank_order = np.flip(np.argsort(all_realworld_acc))
# Plot
all_realworld_prob_type_names = ['Successor', 'Predecessor', 'Extend\nsequence', 'Sort']
x_points = np.arange(len(all_realworld_prob_types))
ax = plt.subplot(111)
plt.bar(x_points, all_realworld_acc[rank_order], yerr=all_realworld_err[:,rank_order], color=gpt3_color, edgecolor='black', width=bar_width)
plt.ylim([0,1])
plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize)
plt.ylabel('Generative accuracy', fontsize=axis_label_fontsize)
plt.xticks(x_points, np.array(all_realworld_prob_type_names)[rank_order], fontsize=plot_fontsize)
plt.xlabel('Transformation type', fontsize=axis_label_fontsize)
plt.title('Real-world concept problems')
plt.legend(['GPT-3'],fontsize=plot_fontsize,frameon=False)
hide_top_right(ax)
plt.tight_layout()
plt.savefig(results_dir + 'realworld_acc.png', dpi=300, bbox_inches="tight")
plt.close()
# Save results (accuracies are in the original, unsorted type order)
np.savez(results_dir + 'realworld_acc.npz', all_acc=all_realworld_acc, all_err=all_realworld_err)
303 |
--------------------------------------------------------------------------------
/digit_mat/analyze_gpt3_exp1.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import builtins
3 | from statsmodels.stats.proportion import proportion_confint
4 | import os
5 |
def check_path(path):
    """Create directory `path` (and any missing parents) if it doesn't exist.

    os.makedirs(..., exist_ok=True) replaces the previous exists()/mkdir()
    pair, which raced (TOCTOU) and failed for nested paths.
    """
    os.makedirs(path, exist_ok=True)
9 |
def hide_top_right(ax):
    """Hide the top/right spines of a matplotlib Axes and keep tick marks on
    the left/bottom sides only. (Not referenced elsewhere in this script.)"""
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
15 |
# Load data
# The results file stores two 0-d object arrays, each wrapping a dict that
# maps a problem-type name to per-problem correctness for the multiple-choice
# (MC) and generative (gen) versions of the task; .item() unwraps the dict.
all_data = np.load('./gpt_matprob_results.npz', allow_pickle=True)
MC_correct_pred = all_data['all_MC_correct_pred']
gen_correct_pred = all_data['all_gen_correct_pred']
all_prob_types = builtins.list(MC_correct_pred.item().keys())
21 |
## Analyze by major problem type
# Bucket per-problem correctness into broad categories (one-rule, two-rule,
# three-rule, logic-rule, plus a combined bucket over everything), separately
# for the generative ('gen') and multiple-choice ('MC') task versions.
correct_pred = {'combined_gen': [],
                'combined_MC': [],
                'one_rule_gen': [],
                'one_rule_MC': [],
                'two_rule_gen': [],
                'two_rule_MC': [],
                'three_rule_gen': [],
                'three_rule_MC': [],
                'logic_rule_gen': [],
                'logic_rule_MC': []}
for prob_type in all_prob_types:
    correct_pred['combined_gen'].append(gen_correct_pred.item()[prob_type])
    correct_pred['combined_MC'].append(MC_correct_pred.item()[prob_type])
    # Category is inferred from substrings of the problem-type name
    if 'constant' in prob_type or 'dist3' in prob_type or 'prog' in prob_type:
        correct_pred['one_rule_gen'].append(gen_correct_pred.item()[prob_type])
        correct_pred['one_rule_MC'].append(MC_correct_pred.item()[prob_type])
    elif 'two_rule' in prob_type:
        correct_pred['two_rule_gen'].append(gen_correct_pred.item()[prob_type])
        correct_pred['two_rule_MC'].append(MC_correct_pred.item()[prob_type])
    elif 'three_rule' in prob_type:
        correct_pred['three_rule_gen'].append(gen_correct_pred.item()[prob_type])
        correct_pred['three_rule_MC'].append(MC_correct_pred.item()[prob_type])
    elif 'union' in prob_type or 'AND' in prob_type or 'XOR' in prob_type:
        correct_pred['logic_rule_gen'].append(gen_correct_pred.item()[prob_type])
        correct_pred['logic_rule_MC'].append(MC_correct_pred.item()[prob_type])
# Convert to arrays (flatten each bucket's per-type lists into one vector)
for key in correct_pred.keys():
    correct_pred[key] = np.concatenate(correct_pred[key])
# Calculate accuracy and confidence intervals (binomial proportion CI)
all_acc = {}
all_ci_lower = {}
all_ci_upper = {}
for key in correct_pred.keys():
    all_acc[key] = correct_pred[key].mean()
    all_ci_lower[key], all_ci_upper[key] = proportion_confint(correct_pred[key].sum(), correct_pred[key].shape[0])
58 |
# Directory for saving results
results_dir = './exp1_GPT3_data/'
check_path(results_dir)

# All problems (raw per-problem correctness vectors)
np.savez(results_dir + 'all_prob.npz', all_gen=correct_pred['combined_gen'], all_MC=correct_pred['combined_MC'])

# Major problem types
# Accuracy per category (ordered one/two/three/logic rule) plus asymmetric
# error-bar lengths derived from the CI bounds (rows: [lower_err, upper_err])
# Generative
all_gen_acc = np.array([all_acc['one_rule_gen'], all_acc['two_rule_gen'], all_acc['three_rule_gen'], all_acc['logic_rule_gen']])
all_gen_lower_ci = np.array([all_ci_lower['one_rule_gen'], all_ci_lower['two_rule_gen'], all_ci_lower['three_rule_gen'], all_ci_lower['logic_rule_gen']])
all_gen_upper_ci = np.array([all_ci_upper['one_rule_gen'], all_ci_upper['two_rule_gen'], all_ci_upper['three_rule_gen'], all_ci_upper['logic_rule_gen']])
all_gen_lower_err = all_gen_acc - all_gen_lower_ci
all_gen_upper_err = all_gen_upper_ci - all_gen_acc
all_gen_err = np.array([all_gen_lower_err, all_gen_upper_err])
np.savez(results_dir + 'all_probcat_gen_acc.npz', acc=all_gen_acc, err=all_gen_err)
# Multiple-choice
all_MC_acc = np.array([all_acc['one_rule_MC'], all_acc['two_rule_MC'], all_acc['three_rule_MC'], all_acc['logic_rule_MC']])
all_MC_lower_ci = np.array([all_ci_lower['one_rule_MC'], all_ci_lower['two_rule_MC'], all_ci_lower['three_rule_MC'], all_ci_lower['logic_rule_MC']])
all_MC_upper_ci = np.array([all_ci_upper['one_rule_MC'], all_ci_upper['two_rule_MC'], all_ci_upper['three_rule_MC'], all_ci_upper['logic_rule_MC']])
all_MC_lower_err = all_MC_acc - all_MC_lower_ci
all_MC_upper_err = all_MC_upper_ci - all_MC_acc
all_MC_err = np.array([all_MC_lower_err, all_MC_upper_err])
np.savez(results_dir + 'all_probcat_MC_acc.npz', acc=all_MC_acc, err=all_MC_err)
83 |
## Relational complexity analysis (controlling for number of unique rules)
N_unique_rules_2rule_prob = np.load('./N_unique_rules_2rule_prob.npz')['N_unique_rules']
N_unique_rules_3rule_prob = np.load('./N_unique_rules_3rule_prob.npz')['N_unique_rules']
# One bucket per (task, rule count, unique-rule count) condition
correct_pred = {'gen_2rule_1unique': [],
                'MC_2rule_1unique': [],
                'gen_2rule_2unique': [],
                'MC_2rule_2unique': [],
                'gen_3rule_1unique': [],
                'MC_3rule_1unique': [],
                'gen_3rule_2unique': [],
                'MC_3rule_2unique': [],
                'gen_3rule_3unique': [],
                'MC_3rule_3unique': []}
for prob_type in all_prob_types:
    # The trailing digit of the problem-type name indexes the rule combination
    if 'two_rule' in prob_type:
        n_unique = int(N_unique_rules_2rule_prob[int(prob_type[-1])])
        if n_unique in (1, 2):
            correct_pred['gen_2rule_%dunique' % n_unique].append(gen_correct_pred.item()[prob_type])
            correct_pred['MC_2rule_%dunique' % n_unique].append(MC_correct_pred.item()[prob_type])
    elif 'three_rule' in prob_type:
        n_unique = int(N_unique_rules_3rule_prob[int(prob_type[-1])])
        if n_unique in (1, 2, 3):
            correct_pred['gen_3rule_%dunique' % n_unique].append(gen_correct_pred.item()[prob_type])
            correct_pred['MC_3rule_%dunique' % n_unique].append(MC_correct_pred.item()[prob_type])
# Stack the trial lists into flat arrays
correct_pred = {key: np.concatenate(vals) for key, vals in correct_pred.items()}
# Mean accuracy and binomial confidence interval per condition
all_acc = {}
all_ci_lower = {}
all_ci_upper = {}
for key, preds in correct_pred.items():
    all_acc[key] = preds.mean()
    all_ci_lower[key], all_ci_upper[key] = proportion_confint(preds.sum(), preds.shape[0])
125 |
# Save results: two- and three-rule problems, split by number of unique rules,
# one file per (rule count, task) pair
for n_rules, rule_tag, fname_tag in [(2, '2rule', 'tworule'),
                                     (3, '3rule', 'threerule')]:
    for task in ('gen', 'MC'):
        keys = ['%s_%s_%dunique' % (task, rule_tag, n + 1) for n in range(n_rules)]
        acc = np.array([all_acc[k] for k in keys])
        ci_lower = np.array([all_ci_lower[k] for k in keys])
        ci_upper = np.array([all_ci_upper[k] for k in keys])
        # err is stacked as [lower-offset, upper-offset] for matplotlib's yerr
        err = np.array([acc - ci_lower, ci_upper - acc])
        np.savez(results_dir + '%s_prob_N_unique_rules_%s.npz' % (fname_tag, task),
                 acc=acc, err=err)
161 |
## Compare problems with vs. without progression rule (two-rule problems)
prog_combs = {'two_rule_comb2', 'two_rule_comb4', 'two_rule_comb5'}
noprog_combs = {'two_rule_comb0', 'two_rule_comb1', 'two_rule_comb3'}
correct_pred = {'tworule_prog_gen': [],
                'tworule_noprog_gen': []}
for prob_type in all_prob_types:
    if prob_type in prog_combs:
        correct_pred['tworule_prog_gen'].append(gen_correct_pred.item()[prob_type])
    elif prob_type in noprog_combs:
        correct_pred['tworule_noprog_gen'].append(gen_correct_pred.item()[prob_type])
# Stack the trial lists into flat arrays
correct_pred = {key: np.concatenate(vals) for key, vals in correct_pred.items()}
# Mean accuracy and binomial confidence interval per condition
all_acc = {}
all_ci_lower = {}
all_ci_upper = {}
for key, preds in correct_pred.items():
    all_acc[key] = preds.mean()
    all_ci_lower[key], all_ci_upper[key] = proportion_confint(preds.sum(), preds.shape[0])

# Save results (no-progression condition listed first)
cond_keys = ['tworule_noprog_gen', 'tworule_prog_gen']
acc = np.array([all_acc[k] for k in cond_keys])
ci_lower = np.array([all_ci_lower[k] for k in cond_keys])
ci_upper = np.array([all_ci_upper[k] for k in cond_keys])
err = np.array([acc - ci_lower, ci_upper - acc])
np.savez(results_dir + 'tworule_prog_vs_noprog_gen_acc.npz', acc=acc, err=err)
189 |
# Three major one-rule problem types (checked in this priority order)
onerule_types = ['constant', 'dist3', 'prog']
correct_pred = {}
for name in onerule_types:
    correct_pred[name + '_gen'] = []
    correct_pred[name + '_MC'] = []
for prob_type in all_prob_types:
    for name in onerule_types:
        if name in prob_type:
            correct_pred[name + '_gen'].append(gen_correct_pred.item()[prob_type])
            correct_pred[name + '_MC'].append(MC_correct_pred.item()[prob_type])
            break  # first matching rule type wins, mirroring an if/elif chain
# Stack the trial lists into flat arrays
correct_pred = {key: np.concatenate(vals) for key, vals in correct_pred.items()}
# Mean accuracy and binomial confidence interval per condition
all_acc = {}
all_ci_lower = {}
all_ci_upper = {}
for key, preds in correct_pred.items():
    all_acc[key] = preds.mean()
    all_ci_lower[key], all_ci_upper[key] = proportion_confint(preds.sum(), preds.shape[0])

# Save results, one file per task
for task, fname in [('gen', 'all_onerule_gen_acc.npz'),
                    ('MC', 'all_onerule_MC_acc.npz')]:
    keys = ['%s_%s' % (name, task) for name in onerule_types]
    acc = np.array([all_acc[k] for k in keys])
    ci_lower = np.array([all_ci_lower[k] for k in keys])
    ci_upper = np.array([all_ci_upper[k] for k in keys])
    err = np.array([acc - ci_lower, ci_upper - acc])
    np.savez(results_dir + fname, acc=acc, err=err)
235 |
# Permuted vs. non-permuted logic problems
correct_pred = {'aligned_gen': [],
                'aligned_MC': [],
                'permuted_gen': [],
                'permuted_MC': []}
logic_tags = ('union', 'AND', 'XOR')
for prob_type in all_prob_types:
    # Only logic problems take part in this comparison
    if any(tag in prob_type for tag in logic_tags):
        cond = 'permuted' if 'permuted' in prob_type else 'aligned'
        correct_pred[cond + '_gen'].append(gen_correct_pred.item()[prob_type])
        correct_pred[cond + '_MC'].append(MC_correct_pred.item()[prob_type])
# Stack the trial lists into flat arrays
correct_pred = {key: np.concatenate(vals) for key, vals in correct_pred.items()}
# Mean accuracy and binomial confidence interval per condition
all_acc = {}
all_ci_lower = {}
all_ci_upper = {}
for key, preds in correct_pred.items():
    all_acc[key] = preds.mean()
    all_ci_lower[key], all_ci_upper[key] = proportion_confint(preds.sum(), preds.shape[0])

# Save results, one file per task (aligned condition listed first)
for task, fname in [('gen', 'aligned_vs_permuted_gen_acc.npz'),
                    ('MC', 'aligned_vs_permuted_MC_acc.npz')]:
    keys = ['aligned_' + task, 'permuted_' + task]
    acc = np.array([all_acc[k] for k in keys])
    ci_lower = np.array([all_ci_lower[k] for k in keys])
    ci_upper = np.array([all_ci_upper[k] for k in keys])
    err = np.array([acc - ci_lower, ci_upper - acc])
    np.savez(results_dir + fname, acc=acc, err=err)
277 |
278 |
279 |
280 |
--------------------------------------------------------------------------------
/digit_mat/exp1_plot_GPT3_vs_human.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import os
4 |
def check_path(path):
    """Ensure the directory *path* exists, creating it (and parents) if needed.

    The original exists-then-mkdir pattern raced with concurrent creation and
    failed for nested paths; os.makedirs with exist_ok=True handles both.
    """
    os.makedirs(path, exist_ok=True)
8 |
def hide_top_right(ax):
    """Strip the top/right spines and keep ticks on the bottom/left axes only."""
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
14 |
# Digit matrices
# 1-rule, 2-rule, 3-rule, and logic-rule problems
# GPT-3: zero-shot
# Humans: shuffled

## Analysis of major problem types

# Load GPT-3 accuracy (mean + asymmetric error bars)
gpt3_gen = np.load('./exp1_GPT3_data/all_probcat_gen_acc.npz')
gpt3_MC = np.load('./exp1_GPT3_data/all_probcat_MC_acc.npz')
# Load human accuracy
human_gen = np.load('./exp1_behavioral_data/probcat_gen_acc_behavior.npz')
human_MC = np.load('./exp1_behavioral_data/probcat_MC_acc_behavior.npz')

# Shared plot settings (colors/fonts/bar width are reused by later sections)
N_conds = gpt3_MC['acc'].shape[0]
total_bar_width = 0.8
ind_bar_width = total_bar_width / 2
x_points = np.arange(N_conds)
gpt3_color = 'darkslateblue'
human_color = 'powderblue'
plot_fontsize = 14
title_fontsize = 16

# Output directory
plot_dir = './exp1_GPT3_vs_human/'
check_path(plot_dir)

# One grouped-bar panel per task (generative, then multiple-choice)
probcat_labels = ['1-rule', '2-rule', '3-rule', 'Logic']
for gpt3_dat, human_dat, ylab, fname in [
        (gpt3_gen, human_gen, 'Generative accuracy', 'gen_gpt3_vs_human.png'),
        (gpt3_MC, human_MC, 'Multiple choice accuracy', 'MC_gpt3_vs_human.png')]:
    ax = plt.subplot(111)
    plt.bar(x_points - 0.2, gpt3_dat['acc'], yerr=gpt3_dat['err'], color=gpt3_color, edgecolor='black', width=ind_bar_width)
    plt.bar(x_points + 0.2, human_dat['acc'], yerr=human_dat['err'], color=human_color, edgecolor='black', width=ind_bar_width)
    plt.ylim([0, 1])
    plt.yticks([0, 0.2, 0.4, 0.6, 0.8, 1], ['0', '0.2', '0.4', '0.6', '0.8', '1'], fontsize=plot_fontsize)
    plt.ylabel(ylab, fontsize=plot_fontsize)
    plt.xticks(x_points, probcat_labels, fontsize=plot_fontsize)
    plt.xlabel('Problem type', fontsize=plot_fontsize)
    hide_top_right(ax)
    plt.legend(['GPT-3', 'Human'], fontsize=plot_fontsize, frameon=False, bbox_to_anchor=(1.2, 1))
    ax.set_aspect(2.5)
    plt.tight_layout()
    plt.savefig(plot_dir + fname, dpi=300, bbox_inches="tight")
    plt.close()
85 |
## Rule type analysis for one-rule problems

# Load GPT-3 one-rule accuracy
gpt3_gen_onerule = np.load('./exp1_GPT3_data/all_onerule_gen_acc.npz')
gpt3_MC_onerule = np.load('./exp1_GPT3_data/all_onerule_MC_acc.npz')
# Load human one-rule accuracy
human_gen_onerule = np.load('./exp1_behavioral_data/probcat_gen_acc_behavior_onerule.npz')
human_MC_onerule = np.load('./exp1_behavioral_data/probcat_MC_acc_behavior_onerule.npz')

# Plot settings: one bar group per rule type
N_conds = gpt3_MC_onerule['acc'].shape[0]
x_points = np.arange(N_conds)
onerule_labels = ['Constant', 'Distribution', 'Progression']

# One grouped-bar panel per task (generative, then multiple-choice)
for gpt3_dat, human_dat, ylab, fname in [
        (gpt3_gen_onerule, human_gen_onerule, 'Generative accuracy', 'onerule_gen_gpt3_vs_human.png'),
        (gpt3_MC_onerule, human_MC_onerule, 'Multiple choice accuracy', 'onerule_MC_gpt3_vs_human.png')]:
    ax = plt.subplot(111)
    plt.bar(x_points - 0.2, gpt3_dat['acc'], yerr=gpt3_dat['err'], color=gpt3_color, edgecolor='black', width=ind_bar_width)
    plt.bar(x_points + 0.2, human_dat['acc'], yerr=human_dat['err'], color=human_color, edgecolor='black', width=ind_bar_width)
    plt.ylim([0, 1])
    plt.yticks([0, 0.2, 0.4, 0.6, 0.8, 1], ['0', '0.2', '0.4', '0.6', '0.8', '1'], fontsize=plot_fontsize)
    plt.ylabel(ylab, fontsize=plot_fontsize)
    plt.xticks(x_points, onerule_labels, fontsize=plot_fontsize)
    plt.xlabel('Rule type', fontsize=plot_fontsize)
    hide_top_right(ax)
    plt.legend(['GPT-3', 'Human'], fontsize=plot_fontsize, frameon=False, bbox_to_anchor=(1.2, 1))
    plt.title('One-rule problems', fontsize=title_fontsize)
    ax.set_aspect(2.5)
    plt.tight_layout()
    plt.savefig(plot_dir + fname, dpi=300, bbox_inches="tight")
    plt.close()
143 |
## Progression vs. no progression two-rule problems

# Load GPT-3 two-rule data (generative task only)
gpt3_tworule_prog = np.load('./exp1_GPT3_data/tworule_prog_vs_noprog_gen_acc.npz')
# Load human two-rule data
human_tworule_prog = np.load('./exp1_behavioral_data/probcat_gen_acc_behavior_prog_tworule.npz')

# Plot settings
N_conds = gpt3_tworule_prog['acc'].shape[0]
x_points = np.arange(N_conds)

# Grouped-bar panel: generative accuracy, no-progression vs. progression
ax = plt.subplot(111)
plt.bar(x_points - 0.2, gpt3_tworule_prog['acc'], yerr=gpt3_tworule_prog['err'], color=gpt3_color, edgecolor='black', width=ind_bar_width)
plt.bar(x_points + 0.2, human_tworule_prog['acc'], yerr=human_tworule_prog['err'], color=human_color, edgecolor='black', width=ind_bar_width)
plt.ylim([0, 1])
plt.yticks([0, 0.2, 0.4, 0.6, 0.8, 1], ['0', '0.2', '0.4', '0.6', '0.8', '1'], fontsize=plot_fontsize)
plt.ylabel('Generative accuracy', fontsize=plot_fontsize)
plt.xticks(x_points, ['No progression', 'Progression'], fontsize=plot_fontsize)
plt.xlabel(' ', fontsize=plot_fontsize)
hide_top_right(ax)
plt.legend(['GPT-3', 'Human'], fontsize=plot_fontsize, frameon=False, bbox_to_anchor=(0.85, 1))
plt.title('Two-rule problems', fontsize=title_fontsize)
ax.set_aspect(2.5)
plt.tight_layout()
plt.savefig(plot_dir + 'tworule_prog_vs_noprog_gen_gpt3_vs_human.png', dpi=300, bbox_inches="tight")
plt.close()
177 |
## Aligned vs. permuted logic problems

# Load GPT-3 logic-problem data
gpt3_gen_logic = np.load('./exp1_GPT3_data/aligned_vs_permuted_gen_acc.npz')
gpt3_MC_logic = np.load('./exp1_GPT3_data/aligned_vs_permuted_MC_acc.npz')
# Load human logic-problem data
human_gen_logic = np.load('./exp1_behavioral_data/aligned_vs_permuted_gen_acc_behavior.npz')
human_MC_logic = np.load('./exp1_behavioral_data/aligned_vs_permuted_MC_acc_behavior.npz')

# Plot settings
N_conds = gpt3_MC_logic['acc'].shape[0]
x_points = np.arange(N_conds)

# One grouped-bar panel per task (generative, then multiple-choice)
for gpt3_dat, human_dat, ylab, fname in [
        (gpt3_gen_logic, human_gen_logic, 'Generative accuracy', 'logic_aligned_vs_permuted_gen_gpt3_vs_human.png'),
        (gpt3_MC_logic, human_MC_logic, 'Multiple choice accuracy', 'logic_aligned_vs_permuted_MC_gpt3_vs_human.png')]:
    ax = plt.subplot(111)
    plt.bar(x_points - 0.2, gpt3_dat['acc'], yerr=gpt3_dat['err'], color=gpt3_color, edgecolor='black', width=ind_bar_width)
    plt.bar(x_points + 0.2, human_dat['acc'], yerr=human_dat['err'], color=human_color, edgecolor='black', width=ind_bar_width)
    plt.ylim([0, 1])
    plt.yticks([0, 0.2, 0.4, 0.6, 0.8, 1], ['0', '0.2', '0.4', '0.6', '0.8', '1'], fontsize=plot_fontsize)
    plt.ylabel(ylab, fontsize=plot_fontsize)
    plt.xticks(x_points, ['Aligned', 'Permuted'], fontsize=plot_fontsize)
    plt.xlabel(' ', fontsize=plot_fontsize)
    hide_top_right(ax)
    plt.legend(['GPT-3', 'Human'], fontsize=plot_fontsize, frameon=False, bbox_to_anchor=(1.2, 1))
    plt.title('Logic problems', fontsize=title_fontsize)
    ax.set_aspect(2.5)
    plt.tight_layout()
    plt.savefig(plot_dir + fname, dpi=300, bbox_inches="tight")
    plt.close()
235 |
### Relational complexity analysis

## Two-rule problems

# Load GPT-3 two-rule data, split by number of unique rules
gpt3_gen_tworule = np.load('./exp1_GPT3_data/tworule_prob_N_unique_rules_gen.npz')
gpt3_MC_tworule = np.load('./exp1_GPT3_data/tworule_prob_N_unique_rules_MC.npz')
# Load human two-rule data
human_gen_tworule = np.load('./exp1_behavioral_data/tworule_prob_N_unique_rules_gen.npz')
human_MC_tworule = np.load('./exp1_behavioral_data/tworule_prob_N_unique_rules_MC.npz')

# Plot settings
N_conds = gpt3_MC_tworule['acc'].shape[0]
x_points = np.arange(N_conds)

# One grouped-bar panel per task; legend anchors differ between panels
for gpt3_dat, human_dat, ylab, legend_anchor, fname in [
        (gpt3_gen_tworule, human_gen_tworule, 'Generative accuracy', (1.2, 1),
         'tworule_rel_complexity_gen_gpt3_vs_human.png'),
        (gpt3_MC_tworule, human_MC_tworule, 'Multiple choice accuracy', (0.75, 1),
         'tworule_rel_complexity_MC_gpt3_vs_human.png')]:
    ax = plt.subplot(111)
    plt.bar(x_points - 0.2, gpt3_dat['acc'], yerr=gpt3_dat['err'], color=gpt3_color, edgecolor='black', width=ind_bar_width)
    plt.bar(x_points + 0.2, human_dat['acc'], yerr=human_dat['err'], color=human_color, edgecolor='black', width=ind_bar_width)
    plt.ylim([0, 1])
    plt.yticks([0, 0.2, 0.4, 0.6, 0.8, 1], ['0', '0.2', '0.4', '0.6', '0.8', '1'], fontsize=plot_fontsize)
    plt.ylabel(ylab, fontsize=plot_fontsize)
    plt.xticks(x_points, ['1', '2'], fontsize=plot_fontsize)
    plt.xlabel('Number of unique rules', fontsize=plot_fontsize)
    hide_top_right(ax)
    plt.legend(['GPT-3', 'Human'], fontsize=plot_fontsize, frameon=False, bbox_to_anchor=legend_anchor)
    plt.title('Two-rule problems', fontsize=title_fontsize)
    ax.set_aspect(2.5)
    plt.tight_layout()
    plt.savefig(plot_dir + fname, dpi=300, bbox_inches="tight")
    plt.close()
295 |
## Three-rule problems

# Load GPT-3 three-rule data, split by number of unique rules
gpt3_gen_threerule = np.load('./exp1_GPT3_data/threerule_prob_N_unique_rules_gen.npz')
gpt3_MC_threerule = np.load('./exp1_GPT3_data/threerule_prob_N_unique_rules_MC.npz')
# Load human three-rule data
human_gen_threerule = np.load('./exp1_behavioral_data/threerule_prob_N_unique_rules_gen.npz')
human_MC_threerule = np.load('./exp1_behavioral_data/threerule_prob_N_unique_rules_MC.npz')

# Plot settings
N_conds = gpt3_MC_threerule['acc'].shape[0]
x_points = np.arange(N_conds)

# One grouped-bar panel per task (generative, then multiple-choice)
for gpt3_dat, human_dat, ylab, fname in [
        (gpt3_gen_threerule, human_gen_threerule, 'Generative accuracy',
         'threerule_rel_complexity_gen_gpt3_vs_human.png'),
        (gpt3_MC_threerule, human_MC_threerule, 'Multiple choice accuracy',
         'threerule_rel_complexity_MC_gpt3_vs_human.png')]:
    ax = plt.subplot(111)
    plt.bar(x_points - 0.2, gpt3_dat['acc'], yerr=gpt3_dat['err'], color=gpt3_color, edgecolor='black', width=ind_bar_width)
    plt.bar(x_points + 0.2, human_dat['acc'], yerr=human_dat['err'], color=human_color, edgecolor='black', width=ind_bar_width)
    plt.ylim([0, 1])
    plt.yticks([0, 0.2, 0.4, 0.6, 0.8, 1], ['0', '0.2', '0.4', '0.6', '0.8', '1'], fontsize=plot_fontsize)
    plt.ylabel(ylab, fontsize=plot_fontsize)
    plt.xticks(x_points, ['1', '2', '3'], fontsize=plot_fontsize)
    plt.xlabel('Number of unique rules', fontsize=plot_fontsize)
    hide_top_right(ax)
    plt.legend(['GPT-3', 'Human'], fontsize=plot_fontsize, frameon=False, bbox_to_anchor=(1.2, 1))
    plt.title('Three-rule problems', fontsize=title_fontsize)
    ax.set_aspect(2.5)
    plt.tight_layout()
    plt.savefig(plot_dir + fname, dpi=300, bbox_inches="tight")
    plt.close()
353 |
354 |
--------------------------------------------------------------------------------
/story_analogies/human_vs_gpt3_data.csv:
--------------------------------------------------------------------------------
1 | probID,subjID,correct_pred,analogy_vs_similarity,human_vs_gpt
2 | 1,1,1,0,0
3 | 2,1,1,0,0
4 | 3,1,1,0,0
5 | 4,1,1,1,0
6 | 5,1,1,1,0
7 | 6,1,1,1,0
8 | 7,1,1,0,0
9 | 8,1,1,1,0
10 | 9,1,1,1,0
11 | 10,1,1,1,0
12 | 11,1,1,1,0
13 | 12,1,1,0,0
14 | 13,1,0,0,0
15 | 14,1,1,0,0
16 | 15,1,1,1,0
17 | 16,1,1,0,0
18 | 17,1,1,1,0
19 | 18,1,1,0,0
20 | 1,2,1,0,0
21 | 2,2,1,0,0
22 | 3,2,1,1,0
23 | 4,2,1,0,0
24 | 5,2,1,1,0
25 | 6,2,1,0,0
26 | 7,2,1,1,0
27 | 8,2,1,0,0
28 | 9,2,1,1,0
29 | 10,2,1,1,0
30 | 11,2,1,0,0
31 | 12,2,1,0,0
32 | 13,2,1,1,0
33 | 14,2,1,1,0
34 | 15,2,1,1,0
35 | 16,2,1,0,0
36 | 17,2,1,0,0
37 | 18,2,1,1,0
38 | 1,3,0,0,0
39 | 2,3,1,0,0
40 | 3,3,1,0,0
41 | 4,3,1,0,0
42 | 5,3,1,1,0
43 | 6,3,0,1,0
44 | 7,3,1,0,0
45 | 8,3,1,0,0
46 | 9,3,1,1,0
47 | 10,3,0,1,0
48 | 11,3,0,0,0
49 | 12,3,1,1,0
50 | 13,3,1,0,0
51 | 14,3,0,1,0
52 | 15,3,1,1,0
53 | 16,3,0,0,0
54 | 17,3,1,1,0
55 | 18,3,1,1,0
56 | 1,4,1,0,0
57 | 2,4,1,0,0
58 | 3,4,1,0,0
59 | 4,4,1,1,0
60 | 5,4,1,1,0
61 | 6,4,1,0,0
62 | 7,4,1,0,0
63 | 8,4,1,1,0
64 | 9,4,1,0,0
65 | 10,4,1,1,0
66 | 11,4,1,1,0
67 | 12,4,1,0,0
68 | 13,4,1,1,0
69 | 14,4,1,1,0
70 | 15,4,1,0,0
71 | 16,4,1,1,0
72 | 17,4,1,1,0
73 | 18,4,1,0,0
74 | 1,5,1,0,0
75 | 2,5,1,0,0
76 | 3,5,1,0,0
77 | 4,5,1,1,0
78 | 5,5,1,0,0
79 | 6,5,1,0,0
80 | 7,5,1,0,0
81 | 8,5,1,0,0
82 | 9,5,1,1,0
83 | 10,5,1,1,0
84 | 11,5,1,1,0
85 | 12,5,1,1,0
86 | 13,5,1,1,0
87 | 14,5,1,1,0
88 | 15,5,1,1,0
89 | 16,5,1,0,0
90 | 17,5,1,1,0
91 | 18,5,1,0,0
92 | 1,6,1,0,0
93 | 2,6,1,1,0
94 | 3,6,1,1,0
95 | 4,6,1,1,0
96 | 5,6,1,0,0
97 | 6,6,1,1,0
98 | 7,6,1,1,0
99 | 8,6,1,1,0
100 | 9,6,1,1,0
101 | 10,6,1,0,0
102 | 11,6,1,0,0
103 | 12,6,1,0,0
104 | 13,6,1,0,0
105 | 14,6,1,0,0
106 | 15,6,1,0,0
107 | 16,6,1,1,0
108 | 17,6,1,0,0
109 | 18,6,1,1,0
110 | 1,7,1,0,0
111 | 2,7,1,0,0
112 | 3,7,1,1,0
113 | 4,7,1,0,0
114 | 5,7,1,1,0
115 | 6,7,1,1,0
116 | 7,7,1,1,0
117 | 8,7,1,0,0
118 | 9,7,1,0,0
119 | 10,7,1,1,0
120 | 11,7,1,1,0
121 | 12,7,1,0,0
122 | 13,7,1,0,0
123 | 14,7,1,1,0
124 | 15,7,1,1,0
125 | 16,7,1,0,0
126 | 17,7,1,1,0
127 | 18,7,1,0,0
128 | 1,8,1,0,0
129 | 2,8,1,0,0
130 | 3,8,1,0,0
131 | 4,8,1,1,0
132 | 5,8,1,0,0
133 | 6,8,1,1,0
134 | 7,8,1,1,0
135 | 8,8,1,1,0
136 | 9,8,1,0,0
137 | 10,8,1,1,0
138 | 11,8,1,1,0
139 | 12,8,1,1,0
140 | 13,8,1,0,0
141 | 14,8,1,0,0
142 | 15,8,1,0,0
143 | 16,8,1,1,0
144 | 17,8,1,1,0
145 | 18,8,1,0,0
146 | 1,9,0,0,0
147 | 2,9,0,1,0
148 | 3,9,0,1,0
149 | 4,9,1,0,0
150 | 5,9,1,1,0
151 | 6,9,1,1,0
152 | 7,9,0,0,0
153 | 8,9,1,0,0
154 | 9,9,0,1,0
155 | 10,9,0,1,0
156 | 11,9,1,1,0
157 | 12,9,1,0,0
158 | 13,9,1,1,0
159 | 14,9,0,0,0
160 | 15,9,1,1,0
161 | 16,9,1,0,0
162 | 17,9,0,0,0
163 | 18,9,1,0,0
164 | 1,10,1,0,0
165 | 2,10,1,1,0
166 | 3,10,1,0,0
167 | 4,10,1,1,0
168 | 5,10,1,0,0
169 | 6,10,0,0,0
170 | 7,10,0,0,0
171 | 8,10,1,1,0
172 | 9,10,1,0,0
173 | 10,10,0,1,0
174 | 11,10,1,1,0
175 | 12,10,1,1,0
176 | 13,10,1,1,0
177 | 14,10,1,0,0
178 | 15,10,1,0,0
179 | 16,10,1,0,0
180 | 17,10,1,1,0
181 | 18,10,1,1,0
182 | 1,11,1,1,0
183 | 2,11,1,1,0
184 | 3,11,1,1,0
185 | 4,11,1,0,0
186 | 5,11,1,0,0
187 | 6,11,1,1,0
188 | 7,11,1,0,0
189 | 8,11,1,0,0
190 | 9,11,1,1,0
191 | 10,11,1,0,0
192 | 11,11,1,1,0
193 | 12,11,1,1,0
194 | 13,11,0,0,0
195 | 14,11,0,0,0
196 | 15,11,1,0,0
197 | 16,11,1,1,0
198 | 17,11,1,0,0
199 | 18,11,1,1,0
200 | 1,12,0,1,0
201 | 2,12,1,0,0
202 | 3,12,1,0,0
203 | 4,12,1,0,0
204 | 5,12,1,1,0
205 | 6,12,1,0,0
206 | 7,12,1,1,0
207 | 8,12,1,1,0
208 | 9,12,1,1,0
209 | 10,12,1,0,0
210 | 11,12,1,0,0
211 | 12,12,1,1,0
212 | 13,12,1,0,0
213 | 14,12,1,1,0
214 | 15,12,1,1,0
215 | 16,12,1,0,0
216 | 17,12,1,0,0
217 | 18,12,1,1,0
218 | 1,13,1,1,0
219 | 2,13,0,0,0
220 | 3,13,1,1,0
221 | 4,13,1,1,0
222 | 5,13,1,0,0
223 | 6,13,1,0,0
224 | 7,13,1,1,0
225 | 8,13,0,0,0
226 | 9,13,1,1,0
227 | 10,13,1,0,0
228 | 11,13,1,1,0
229 | 12,13,1,1,0
230 | 13,13,0,0,0
231 | 14,13,1,0,0
232 | 15,13,1,0,0
233 | 16,13,1,1,0
234 | 17,13,1,0,0
235 | 18,13,1,1,0
236 | 1,14,1,1,0
237 | 2,14,1,1,0
238 | 3,14,1,0,0
239 | 4,14,1,0,0
240 | 5,14,1,0,0
241 | 6,14,1,1,0
242 | 7,14,1,0,0
243 | 8,14,1,0,0
244 | 9,14,1,0,0
245 | 10,14,0,1,0
246 | 11,14,1,1,0
247 | 12,14,0,1,0
248 | 13,14,0,1,0
249 | 14,14,1,0,0
250 | 15,14,1,1,0
251 | 16,14,1,1,0
252 | 17,14,1,0,0
253 | 18,14,1,0,0
254 | 1,15,1,1,0
255 | 2,15,1,1,0
256 | 3,15,1,1,0
257 | 4,15,1,0,0
258 | 5,15,1,1,0
259 | 6,15,1,1,0
260 | 7,15,1,0,0
261 | 8,15,1,1,0
262 | 9,15,1,0,0
263 | 10,15,1,1,0
264 | 11,15,1,0,0
265 | 12,15,1,0,0
266 | 13,15,1,0,0
267 | 14,15,1,1,0
268 | 15,15,1,0,0
269 | 16,15,1,0,0
270 | 17,15,1,0,0
271 | 18,15,1,1,0
272 | 1,16,1,0,0
273 | 2,16,1,1,0
274 | 3,16,1,0,0
275 | 4,16,1,1,0
276 | 5,16,1,1,0
277 | 6,16,1,1,0
278 | 7,16,1,1,0
279 | 8,16,0,0,0
280 | 9,16,1,0,0
281 | 10,16,1,0,0
282 | 11,16,1,1,0
283 | 12,16,1,1,0
284 | 13,16,1,1,0
285 | 14,16,1,0,0
286 | 15,16,1,0,0
287 | 16,16,1,1,0
288 | 17,16,1,0,0
289 | 18,16,0,0,0
290 | 1,17,0,0,0
291 | 2,17,0,0,0
292 | 3,17,1,0,0
293 | 4,17,1,1,0
294 | 5,17,0,1,0
295 | 6,17,1,0,0
296 | 7,17,0,0,0
297 | 8,17,1,0,0
298 | 9,17,1,0,0
299 | 10,17,1,1,0
300 | 11,17,0,0,0
301 | 12,17,1,0,0
302 | 13,17,0,1,0
303 | 14,17,1,1,0
304 | 15,17,1,1,0
305 | 16,17,0,1,0
306 | 17,17,1,1,0
307 | 18,17,1,1,0
308 | 1,18,0,0,0
309 | 2,18,1,1,0
310 | 3,18,1,1,0
311 | 4,18,1,1,0
312 | 5,18,1,1,0
313 | 6,18,0,0,0
314 | 7,18,1,1,0
315 | 8,18,1,1,0
316 | 9,18,0,1,0
317 | 10,18,1,0,0
318 | 11,18,1,0,0
319 | 12,18,1,0,0
320 | 13,18,1,0,0
321 | 14,18,1,0,0
322 | 15,18,1,1,0
323 | 16,18,1,1,0
324 | 17,18,1,0,0
325 | 18,18,1,0,0
326 | 1,19,1,1,0
327 | 2,19,1,0,0
328 | 3,19,1,1,0
329 | 4,19,1,1,0
330 | 5,19,1,0,0
331 | 6,19,0,0,0
332 | 7,19,1,1,0
333 | 8,19,1,1,0
334 | 9,19,1,1,0
335 | 10,19,1,0,0
336 | 11,19,1,0,0
337 | 12,19,1,0,0
338 | 13,19,0,0,0
339 | 14,19,1,0,0
340 | 15,19,1,0,0
341 | 16,19,1,1,0
342 | 17,19,1,1,0
343 | 18,19,1,1,0
344 | 1,20,1,0,0
345 | 2,20,1,1,0
346 | 3,20,1,0,0
347 | 4,20,1,0,0
348 | 5,20,1,1,0
349 | 6,20,1,0,0
350 | 7,20,1,1,0
351 | 8,20,1,1,0
352 | 9,20,1,1,0
353 | 10,20,0,0,0
354 | 11,20,1,1,0
355 | 12,20,1,1,0
356 | 13,20,1,1,0
357 | 14,20,1,0,0
358 | 15,20,1,0,0
359 | 16,20,1,0,0
360 | 17,20,1,1,0
361 | 18,20,1,0,0
362 | 1,21,1,0,0
363 | 2,21,0,1,0
364 | 3,21,1,0,0
365 | 4,21,1,0,0
366 | 5,21,1,0,0
367 | 6,21,1,1,0
368 | 7,21,1,0,0
369 | 8,21,0,0,0
370 | 9,21,1,1,0
371 | 10,21,0,0,0
372 | 11,21,1,1,0
373 | 12,21,1,1,0
374 | 13,21,1,1,0
375 | 14,21,0,1,0
376 | 15,21,1,1,0
377 | 16,21,1,0,0
378 | 17,21,1,1,0
379 | 18,21,0,0,0
380 | 1,22,1,1,0
381 | 2,22,1,0,0
382 | 3,22,1,1,0
383 | 4,22,1,1,0
384 | 5,22,1,0,0
385 | 6,22,1,1,0
386 | 7,22,1,1,0
387 | 8,22,1,0,0
388 | 9,22,1,1,0
389 | 10,22,1,0,0
390 | 11,22,0,1,0
391 | 12,22,1,0,0
392 | 13,22,1,1,0
393 | 14,22,1,1,0
394 | 15,22,1,0,0
395 | 16,22,1,0,0
396 | 17,22,1,0,0
397 | 18,22,1,0,0
398 | 1,23,0,1,0
399 | 2,23,0,1,0
400 | 3,23,0,1,0
401 | 4,23,1,0,0
402 | 5,23,1,0,0
403 | 6,23,1,0,0
404 | 7,23,1,1,0
405 | 8,23,0,0,0
406 | 9,23,1,0,0
407 | 10,23,1,0,0
408 | 11,23,1,1,0
409 | 12,23,1,1,0
410 | 13,23,1,1,0
411 | 14,23,1,0,0
412 | 15,23,1,1,0
413 | 16,23,1,0,0
414 | 17,23,1,1,0
415 | 18,23,0,0,0
416 | 1,24,1,0,0
417 | 2,24,1,1,0
418 | 3,24,1,0,0
419 | 4,24,1,1,0
420 | 5,24,1,1,0
421 | 6,24,1,0,0
422 | 7,24,1,0,0
423 | 8,24,1,1,0
424 | 9,24,1,0,0
425 | 10,24,0,0,0
426 | 11,24,1,1,0
427 | 12,24,1,1,0
428 | 13,24,0,0,0
429 | 14,24,1,0,0
430 | 15,24,1,1,0
431 | 16,24,1,1,0
432 | 17,24,1,1,0
433 | 18,24,1,0,0
434 | 1,25,0,0,0
435 | 2,25,1,0,0
436 | 3,25,1,1,0
437 | 4,25,1,1,0
438 | 5,25,1,0,0
439 | 6,25,1,1,0
440 | 7,25,1,0,0
441 | 8,25,1,0,0
442 | 9,25,1,1,0
443 | 10,25,1,0,0
444 | 11,25,1,1,0
445 | 12,25,1,1,0
446 | 13,25,1,0,0
447 | 14,25,1,0,0
448 | 15,25,1,1,0
449 | 16,25,1,1,0
450 | 17,25,1,1,0
451 | 18,25,1,0,0
452 | 1,26,0,1,0
453 | 2,26,0,0,0
454 | 3,26,1,0,0
455 | 4,26,0,0,0
456 | 5,26,1,0,0
457 | 6,26,1,0,0
458 | 7,26,1,1,0
459 | 8,26,1,1,0
460 | 9,26,0,1,0
461 | 10,26,1,0,0
462 | 11,26,1,1,0
463 | 12,26,1,1,0
464 | 13,26,1,0,0
465 | 14,26,1,1,0
466 | 15,26,1,1,0
467 | 16,26,1,1,0
468 | 17,26,1,0,0
469 | 18,26,0,0,0
470 | 1,27,0,0,0
471 | 2,27,0,1,0
472 | 3,27,1,0,0
473 | 4,27,1,1,0
474 | 5,27,1,0,0
475 | 6,27,0,0,0
476 | 7,27,1,1,0
477 | 8,27,1,1,0
478 | 9,27,1,1,0
479 | 10,27,1,0,0
480 | 11,27,1,0,0
481 | 12,27,1,1,0
482 | 13,27,1,0,0
483 | 14,27,0,1,0
484 | 15,27,0,1,0
485 | 16,27,1,1,0
486 | 17,27,1,0,0
487 | 18,27,1,0,0
488 | 1,28,1,0,0
489 | 2,28,1,1,0
490 | 3,28,1,0,0
491 | 4,28,1,0,0
492 | 5,28,1,1,0
493 | 6,28,1,0,0
494 | 7,28,1,1,0
495 | 8,28,1,1,0
496 | 9,28,1,1,0
497 | 10,28,1,1,0
498 | 11,28,1,0,0
499 | 12,28,1,0,0
500 | 13,28,1,0,0
501 | 14,28,1,1,0
502 | 15,28,1,0,0
503 | 16,28,1,1,0
504 | 17,28,1,1,0
505 | 18,28,1,0,0
506 | 1,29,1,0,0
507 | 2,29,1,0,0
508 | 3,29,1,1,0
509 | 4,29,1,1,0
510 | 5,29,1,0,0
511 | 6,29,1,1,0
512 | 7,29,1,1,0
513 | 8,29,1,1,0
514 | 9,29,1,1,0
515 | 10,29,1,0,0
516 | 11,29,1,1,0
517 | 12,29,1,0,0
518 | 13,29,1,1,0
519 | 14,29,1,0,0
520 | 15,29,1,0,0
521 | 16,29,1,0,0
522 | 17,29,1,0,0
523 | 18,29,1,1,0
524 | 1,30,1,0,0
525 | 2,30,1,0,0
526 | 3,30,1,1,0
527 | 4,30,1,1,0
528 | 5,30,1,1,0
529 | 6,30,1,0,0
530 | 7,30,0,0,0
531 | 8,30,1,0,0
532 | 9,30,1,1,0
533 | 10,30,1,0,0
534 | 11,30,1,0,0
535 | 12,30,1,0,0
536 | 13,30,1,1,0
537 | 14,30,1,1,0
538 | 15,30,1,1,0
539 | 16,30,1,1,0
540 | 17,30,1,0,0
541 | 18,30,1,1,0
542 | 1,31,1,1,0
543 | 2,31,1,0,0
544 | 3,31,1,0,0
545 | 4,31,1,1,0
546 | 5,31,1,0,0
547 | 6,31,1,1,0
548 | 7,31,1,0,0
549 | 8,31,1,0,0
550 | 9,31,1,1,0
551 | 10,31,1,0,0
552 | 11,31,1,1,0
553 | 12,31,1,0,0
554 | 13,31,1,1,0
555 | 14,31,1,0,0
556 | 15,31,1,1,0
557 | 16,31,1,1,0
558 | 17,31,1,0,0
559 | 18,31,0,1,0
560 | 1,32,1,1,0
561 | 2,32,1,1,0
562 | 3,32,1,0,0
563 | 4,32,1,0,0
564 | 5,32,1,0,0
565 | 6,32,0,0,0
566 | 7,32,1,1,0
567 | 8,32,1,1,0
568 | 9,32,1,0,0
569 | 10,32,1,0,0
570 | 11,32,1,0,0
571 | 12,32,1,1,0
572 | 13,32,1,1,0
573 | 14,32,1,1,0
574 | 15,32,0,0,0
575 | 16,32,1,1,0
576 | 17,32,1,0,0
577 | 18,32,1,1,0
578 | 1,33,1,0,0
579 | 2,33,1,0,0
580 | 3,33,1,1,0
581 | 4,33,1,0,0
582 | 5,33,1,1,0
583 | 6,33,1,1,0
584 | 7,33,1,0,0
585 | 8,33,1,1,0
586 | 9,33,1,1,0
587 | 10,33,1,0,0
588 | 11,33,1,0,0
589 | 12,33,1,1,0
590 | 13,33,1,0,0
591 | 14,33,1,1,0
592 | 15,33,1,1,0
593 | 16,33,1,0,0
594 | 17,33,1,1,0
595 | 18,33,1,0,0
596 | 1,34,1,0,0
597 | 2,34,1,0,0
598 | 3,34,1,0,0
599 | 4,34,1,0,0
600 | 5,34,1,1,0
601 | 6,34,1,1,0
602 | 7,34,1,1,0
603 | 8,34,1,0,0
604 | 9,34,1,0,0
605 | 10,34,1,1,0
606 | 11,34,1,1,0
607 | 12,34,1,1,0
608 | 13,34,1,0,0
609 | 14,34,1,1,0
610 | 15,34,1,0,0
611 | 16,34,1,1,0
612 | 17,34,1,0,0
613 | 18,34,1,1,0
614 | 1,35,1,1,0
615 | 2,35,1,1,0
616 | 3,35,1,1,0
617 | 4,35,1,1,0
618 | 5,35,1,1,0
619 | 6,35,1,1,0
620 | 7,35,0,0,0
621 | 8,35,1,0,0
622 | 9,35,0,0,0
623 | 10,35,0,0,0
624 | 11,35,1,0,0
625 | 12,35,0,0,0
626 | 13,35,1,1,0
627 | 14,35,1,0,0
628 | 15,35,1,0,0
629 | 16,35,1,1,0
630 | 17,35,1,1,0
631 | 18,35,1,0,0
632 | 1,36,1,1,0
633 | 2,36,1,1,0
634 | 3,36,1,0,0
635 | 4,36,1,1,0
636 | 5,36,1,0,0
637 | 6,36,1,0,0
638 | 7,36,1,1,0
639 | 8,36,1,0,0
640 | 9,36,1,0,0
641 | 10,36,0,1,0
642 | 11,36,1,1,0
643 | 12,36,1,0,0
644 | 13,36,0,0,0
645 | 14,36,1,1,0
646 | 15,36,1,0,0
647 | 16,36,1,1,0
648 | 17,36,1,0,0
649 | 18,36,0,1,0
650 | 1,37,1,1,0
651 | 2,37,0,0,0
652 | 3,37,1,0,0
653 | 4,37,1,1,0
654 | 5,37,1,1,0
655 | 6,37,1,1,0
656 | 7,37,1,1,0
657 | 8,37,1,0,0
658 | 9,37,1,0,0
659 | 10,37,0,1,0
660 | 11,37,1,1,0
661 | 12,37,1,0,0
662 | 13,37,1,0,0
663 | 14,37,1,0,0
664 | 15,37,1,1,0
665 | 16,37,1,0,0
666 | 17,37,1,0,0
667 | 18,37,1,1,0
668 | 1,38,1,1,0
669 | 2,38,1,1,0
670 | 3,38,1,1,0
671 | 4,38,1,0,0
672 | 5,38,1,1,0
673 | 6,38,1,0,0
674 | 7,38,1,0,0
675 | 8,38,1,1,0
676 | 9,38,1,0,0
677 | 10,38,1,1,0
678 | 11,38,1,0,0
679 | 12,38,1,0,0
680 | 13,38,1,0,0
681 | 14,38,1,1,0
682 | 15,38,1,0,0
683 | 16,38,1,0,0
684 | 17,38,1,1,0
685 | 18,38,1,1,0
686 | 1,39,1,0,0
687 | 2,39,1,1,0
688 | 3,39,0,1,0
689 | 4,39,1,0,0
690 | 5,39,1,1,0
691 | 6,39,1,0,0
692 | 7,39,0,1,0
693 | 8,39,0,0,0
694 | 9,39,1,1,0
695 | 10,39,1,1,0
696 | 11,39,1,1,0
697 | 12,39,1,0,0
698 | 13,39,0,0,0
699 | 14,39,1,0,0
700 | 15,39,1,0,0
701 | 16,39,1,1,0
702 | 17,39,1,1,0
703 | 18,39,0,0,0
704 | 1,40,1,0,0
705 | 2,40,1,0,0
706 | 3,40,1,0,0
707 | 4,40,1,1,0
708 | 5,40,1,1,0
709 | 6,40,1,0,0
710 | 7,40,0,0,0
711 | 8,40,1,1,0
712 | 9,40,1,0,0
713 | 10,40,1,1,0
714 | 11,40,1,0,0
715 | 12,40,1,0,0
716 | 13,40,0,1,0
717 | 14,40,1,1,0
718 | 15,40,1,1,0
719 | 16,40,1,1,0
720 | 17,40,1,1,0
721 | 18,40,1,0,0
722 | 1,41,1,0,0
723 | 2,41,1,0,0
724 | 3,41,1,1,0
725 | 4,41,1,0,0
726 | 5,41,1,0,0
727 | 6,41,1,0,0
728 | 7,41,1,1,0
729 | 8,41,1,1,0
730 | 9,41,1,0,0
731 | 10,41,1,1,0
732 | 11,41,1,0,0
733 | 12,41,1,0,0
734 | 13,41,1,1,0
735 | 14,41,1,1,0
736 | 15,41,1,1,0
737 | 16,41,1,0,0
738 | 17,41,1,1,0
739 | 18,41,1,1,0
740 | 1,42,1,1,0
741 | 2,42,1,0,0
742 | 3,42,1,0,0
743 | 4,42,1,0,0
744 | 5,42,1,1,0
745 | 6,42,1,0,0
746 | 7,42,1,0,0
747 | 8,42,1,1,0
748 | 9,42,1,1,0
749 | 10,42,1,0,0
750 | 11,42,1,1,0
751 | 12,42,1,1,0
752 | 13,42,1,0,0
753 | 14,42,1,0,0
754 | 15,42,1,1,0
755 | 16,42,1,0,0
756 | 17,42,1,1,0
757 | 18,42,1,1,0
758 | 1,43,0,1,0
759 | 2,43,1,1,0
760 | 3,43,0,1,0
761 | 4,43,1,0,0
762 | 5,43,1,0,0
763 | 6,43,1,0,0
764 | 7,43,1,0,0
765 | 8,43,1,1,0
766 | 9,43,0,0,0
767 | 10,43,1,1,0
768 | 11,43,1,0,0
769 | 12,43,1,0,0
770 | 13,43,1,1,0
771 | 14,43,1,1,0
772 | 15,43,1,1,0
773 | 16,43,1,1,0
774 | 17,43,1,0,0
775 | 18,43,1,0,0
776 | 1,44,1,1,0
777 | 2,44,1,0,0
778 | 3,44,0,0,0
779 | 4,44,1,0,0
780 | 5,44,0,0,0
781 | 6,44,0,0,0
782 | 7,44,0,1,0
783 | 8,44,0,1,0
784 | 9,44,0,0,0
785 | 10,44,0,0,0
786 | 11,44,1,1,0
787 | 12,44,1,1,0
788 | 13,44,0,1,0
789 | 14,44,1,0,0
790 | 15,44,1,1,0
791 | 16,44,1,1,0
792 | 17,44,1,1,0
793 | 18,44,1,0,0
794 | 1,45,1,1,0
795 | 2,45,1,0,0
796 | 3,45,1,0,0
797 | 4,45,1,1,0
798 | 5,45,1,1,0
799 | 6,45,1,1,0
800 | 7,45,1,1,0
801 | 8,45,1,0,0
802 | 9,45,1,0,0
803 | 10,45,0,1,0
804 | 11,45,1,0,0
805 | 12,45,1,0,0
806 | 13,45,1,1,0
807 | 14,45,1,1,0
808 | 15,45,1,1,0
809 | 16,45,1,0,0
810 | 17,45,1,0,0
811 | 18,45,1,0,0
812 | 1,46,1,1,0
813 | 2,46,0,1,0
814 | 3,46,1,0,0
815 | 4,46,1,0,0
816 | 5,46,1,1,0
817 | 6,46,1,1,0
818 | 7,46,1,0,0
819 | 8,46,0,0,0
820 | 9,46,1,0,0
821 | 10,46,1,1,0
822 | 11,46,1,0,0
823 | 12,46,1,1,0
824 | 13,46,1,1,0
825 | 14,46,1,0,0
826 | 15,46,1,0,0
827 | 16,46,1,0,0
828 | 17,46,1,1,0
829 | 18,46,1,1,0
830 | 1,47,1,1,0
831 | 2,47,1,0,0
832 | 3,47,1,0,0
833 | 4,47,1,1,0
834 | 5,47,1,1,0
835 | 6,47,1,0,0
836 | 7,47,0,1,0
837 | 8,47,1,0,0
838 | 9,47,1,1,0
839 | 10,47,0,0,0
840 | 11,47,1,0,0
841 | 12,47,1,1,0
842 | 13,47,1,0,0
843 | 14,47,1,1,0
844 | 15,47,0,1,0
845 | 16,47,1,0,0
846 | 17,47,1,1,0
847 | 18,47,1,0,0
848 | 1,48,0,1,0
849 | 2,48,0,0,0
850 | 3,48,0,1,0
851 | 4,48,1,0,0
852 | 5,48,1,0,0
853 | 6,48,0,1,0
854 | 7,48,1,0,0
855 | 8,48,1,1,0
856 | 9,48,1,0,0
857 | 10,48,1,0,0
858 | 11,48,1,1,0
859 | 12,48,1,1,0
860 | 13,48,0,0,0
861 | 14,48,1,1,0
862 | 15,48,0,0,0
863 | 16,48,1,1,0
864 | 17,48,1,1,0
865 | 18,48,0,0,0
866 | 1,49,0,0,0
867 | 2,49,0,0,0
868 | 3,49,1,0,0
869 | 4,49,1,1,0
870 | 5,49,1,1,0
871 | 6,49,1,0,0
872 | 7,49,1,0,0
873 | 8,49,1,1,0
874 | 9,49,1,0,0
875 | 10,49,1,1,0
876 | 11,49,0,0,0
877 | 12,49,0,1,0
878 | 13,49,1,1,0
879 | 14,49,1,0,0
880 | 15,49,0,1,0
881 | 16,49,1,1,0
882 | 17,49,1,0,0
883 | 18,49,1,1,0
884 | 1,50,1,0,0
885 | 2,50,1,1,0
886 | 3,50,1,0,0
887 | 4,50,1,0,0
888 | 5,50,1,1,0
889 | 6,50,0,1,0
890 | 7,50,1,0,0
891 | 8,50,1,0,0
892 | 9,50,1,0,0
893 | 10,50,0,0,0
894 | 11,50,1,1,0
895 | 12,50,1,0,0
896 | 13,50,1,1,0
897 | 14,50,0,1,0
898 | 15,50,1,1,0
899 | 16,50,1,1,0
900 | 17,50,1,1,0
901 | 18,50,0,0,0
902 | 1,51,1,1,0
903 | 2,51,1,0,0
904 | 3,51,1,1,0
905 | 4,51,1,0,0
906 | 5,51,1,1,0
907 | 6,51,1,1,0
908 | 7,51,1,1,0
909 | 8,51,1,1,0
910 | 9,51,0,0,0
911 | 10,51,0,0,0
912 | 11,51,1,1,0
913 | 12,51,1,0,0
914 | 13,51,1,0,0
915 | 14,51,1,1,0
916 | 15,51,1,0,0
917 | 16,51,1,0,0
918 | 17,51,1,0,0
919 | 18,51,1,1,0
920 | 1,52,1,1,0
921 | 2,52,0,0,0
922 | 3,52,1,0,0
923 | 4,52,1,0,0
924 | 5,52,1,0,0
925 | 6,52,0,0,0
926 | 7,52,1,1,0
927 | 8,52,1,1,0
928 | 9,52,0,1,0
929 | 10,52,1,0,0
930 | 11,52,1,1,0
931 | 12,52,1,0,0
932 | 13,52,0,1,0
933 | 14,52,1,0,0
934 | 15,52,1,1,0
935 | 16,52,0,1,0
936 | 17,52,1,1,0
937 | 18,52,1,0,0
938 | 1,53,1,1,0
939 | 2,53,1,1,0
940 | 3,53,1,0,0
941 | 4,53,1,1,0
942 | 5,53,1,0,0
943 | 6,53,1,1,0
944 | 7,53,1,1,0
945 | 8,53,1,1,0
946 | 9,53,1,0,0
947 | 10,53,0,0,0
948 | 11,53,1,1,0
949 | 12,53,1,0,0
950 | 13,53,1,0,0
951 | 14,53,1,0,0
952 | 15,53,1,0,0
953 | 16,53,1,1,0
954 | 17,53,1,1,0
955 | 18,53,1,0,0
956 | 1,54,1,1,0
957 | 2,54,1,0,0
958 | 3,54,1,1,0
959 | 4,54,1,0,0
960 | 5,54,1,0,0
961 | 6,54,1,1,0
962 | 7,54,1,0,0
963 | 8,54,1,1,0
964 | 9,54,1,1,0
965 | 10,54,1,0,0
966 | 11,54,1,1,0
967 | 12,54,1,1,0
968 | 13,54,1,0,0
969 | 14,54,1,0,0
970 | 15,54,1,0,0
971 | 16,54,1,0,0
972 | 17,54,1,1,0
973 | 18,54,1,1,0
974 | 1,55,0,0,1
975 | 2,55,1,0,1
976 | 3,55,0,0,1
977 | 4,55,1,0,1
978 | 5,55,0,0,1
979 | 6,55,1,0,1
980 | 7,55,0,0,1
981 | 8,55,0,0,1
982 | 9,55,1,0,1
983 | 10,55,0,0,1
984 | 11,55,1,0,1
985 | 12,55,1,0,1
986 | 13,55,0,0,1
987 | 14,55,1,0,1
988 | 15,55,0,0,1
989 | 16,55,1,0,1
990 | 17,55,1,0,1
991 | 18,55,0,0,1
992 | 1,55,1,0,1
993 | 2,55,1,0,1
994 | 3,55,0,0,1
995 | 4,55,1,0,1
996 | 5,55,1,0,1
997 | 6,55,1,0,1
998 | 7,55,1,0,1
999 | 8,55,0,0,1
1000 | 9,55,1,0,1
1001 | 10,55,1,0,1
1002 | 11,55,1,0,1
1003 | 12,55,1,0,1
1004 | 13,55,0,0,1
1005 | 14,55,1,0,1
1006 | 15,55,1,0,1
1007 | 16,55,1,0,1
1008 | 17,55,1,0,1
1009 | 18,55,1,0,1
1010 | 1,55,0,1,1
1011 | 2,55,1,1,1
1012 | 3,55,1,1,1
1013 | 4,55,1,1,1
1014 | 5,55,0,1,1
1015 | 6,55,1,1,1
1016 | 7,55,1,1,1
1017 | 8,55,1,1,1
1018 | 9,55,1,1,1
1019 | 10,55,1,1,1
1020 | 11,55,1,1,1
1021 | 12,55,1,1,1
1022 | 13,55,1,1,1
1023 | 14,55,1,1,1
1024 | 15,55,0,1,1
1025 | 16,55,1,1,1
1026 | 17,55,1,1,1
1027 | 18,55,0,1,1
1028 | 1,55,1,1,1
1029 | 2,55,1,1,1
1030 | 3,55,0,1,1
1031 | 4,55,1,1,1
1032 | 5,55,0,1,1
1033 | 6,55,1,1,1
1034 | 7,55,1,1,1
1035 | 8,55,1,1,1
1036 | 9,55,1,1,1
1037 | 10,55,1,1,1
1038 | 11,55,1,1,1
1039 | 12,55,1,1,1
1040 | 13,55,0,1,1
1041 | 14,55,1,1,1
1042 | 15,55,0,1,1
1043 | 16,55,1,1,1
1044 | 17,55,0,1,1
1045 | 18,55,1,1,1
--------------------------------------------------------------------------------
/digit_mat/gen_4_5_rule_problems.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from itertools import permutations, combinations_with_replacement
3 | import builtins
4 | import random
5 | from copy import deepcopy
6 | import os
7 | import json
8 |
# Method for generating distractors
def gen_distractor(all_prob):
    """Generate 7 distractors per problem, then append and shuffle in the correct answer.

    Parameters
    ----------
    all_prob : np.ndarray
        Problems organized as (N_prob_types, N_probs, 3, 3, N_rules); the
        correct answer for each problem is its bottom-right cell prob[2, 2, :].

    Returns
    -------
    all_probtype_answer_choices : np.ndarray
        (N_prob_types, N_probs, 8, N_rules) — 7 distractors plus the correct
        answer, in shuffled order.
    all_probtype_correct_ind : np.ndarray
        (N_prob_types, N_probs) — index of the correct answer within the
        8 shuffled choices.
    """
    # Methods and transformations
    distractor_methods = ['other_element',
                          'other_element_transformed',
                          'correct_answer_transformed',
                          'previous_distractor_transformed',
                          'random_new_element']
    # Additive perturbations applied to a single dimension of an element
    transformations = np.array([1,2,-1,-2])
    # Additional methods for multi-rule problems
    # (permutation/recombination only makes sense with >1 dimension per cell)
    if all_prob.shape[-1] > 1:
        distractor_methods.append('correct_answer_permuted')
        distractor_methods.append('other_element_permuted')
        distractor_methods.append('previous_distractor_permuted')
        distractor_methods.append('combination_other_elements')
        distractor_methods.append('combination_previous_distractors')
    # Loop through all problems
    all_probtype_answer_choices = []
    all_probtype_correct_ind = []
    for t in range(all_prob.shape[0]):
        all_answer_choices = []
        all_correct_ind = []
        for p in range(all_prob.shape[1]):
            # Problem
            prob = all_prob[t][p]
            # Extract correct answer (bottom-right cell)
            correct_answer = prob[2,2,:]
            # Other elements (besides correct answer) in problem
            prob_flat = prob.reshape(prob.shape[0]*prob.shape[1],prob.shape[2])
            other_prob_elements = prob_flat[:-1,:]
            # Drop any remaining cells equal to the correct answer
            other_prob_elements = other_prob_elements[np.logical_not(np.all(other_prob_elements == np.expand_dims(correct_answer,0),1))]
            # Generate each distractor (7 per problem)
            answer_choices = []
            for d in range(7):
                valid_distractor = False
                # Rejection-sample until a valid, novel distractor is produced
                while not valid_distractor:
                    # Sample method uniformly at random
                    method = distractor_methods[int(np.floor(np.random.rand() * len(distractor_methods)))]
                    # Other element from problem
                    if method == 'other_element':
                        np.random.shuffle(other_prob_elements)
                        distractor = deepcopy(other_prob_elements[0])
                    # Other element transformed
                    elif method == 'other_element_transformed':
                        np.random.shuffle(other_prob_elements)
                        distractor = deepcopy(other_prob_elements[0])
                        transform_dim = int(np.floor(np.random.rand() * all_prob.shape[-1]))
                        np.random.shuffle(transformations)
                        distractor[transform_dim] += transformations[0]
                    # Correct answer transformed
                    elif method == 'correct_answer_transformed':
                        transform_dim = int(np.floor(np.random.rand() * all_prob.shape[-1]))
                        np.random.shuffle(transformations)
                        distractor = deepcopy(correct_answer)
                        distractor[transform_dim] += transformations[0]
                    # Previous distractor transformed
                    elif method == 'previous_distractor_transformed':
                        random.shuffle(answer_choices)
                        if len(answer_choices) > 0:
                            distractor = deepcopy(answer_choices[0])
                            transform_dim = int(np.floor(np.random.rand() * all_prob.shape[-1]))
                            np.random.shuffle(transformations)
                            distractor[transform_dim] += transformations[0]
                        else:
                            # No previous distractors yet: emit an invalid value
                            # (all -1) so the validity check below rejects it
                            distractor = np.ones(all_prob.shape[-1]).astype(int) * -1
                    # Random new element (digits 0-9 per dimension)
                    elif method == 'random_new_element':
                        distractor = np.floor(np.random.rand(all_prob.shape[-1]) * 10).astype(int)
                    # Correct answer permuted
                    elif method == 'correct_answer_permuted':
                        # [1:] skips the identity permutation
                        valid_perm = builtins.list(permutations(np.arange(all_prob.shape[-1]),all_prob.shape[-1]))[1:]
                        random.shuffle(valid_perm)
                        distractor = deepcopy(correct_answer)[builtins.list(valid_perm[0])]
                    # Other element permuted
                    elif method == 'other_element_permuted':
                        valid_perm = builtins.list(permutations(np.arange(all_prob.shape[-1]),all_prob.shape[-1]))[1:]
                        random.shuffle(valid_perm)
                        np.random.shuffle(other_prob_elements)
                        other_element = other_prob_elements[0]
                        distractor = deepcopy(other_element)[builtins.list(valid_perm[0])]
                    # Previous distractor permuted
                    elif method == 'previous_distractor_permuted':
                        random.shuffle(answer_choices)
                        if len(answer_choices) > 0:
                            valid_perm = builtins.list(permutations(np.arange(all_prob.shape[-1]),all_prob.shape[-1]))[1:]
                            random.shuffle(valid_perm)
                            previous_distractor = answer_choices[0]
                            distractor = deepcopy(previous_distractor)[builtins.list(valid_perm[0])]
                        else:
                            # No previous distractors yet: force rejection
                            distractor = np.ones(all_prob.shape[-1]).astype(int) * -1
                    # Combination of other elements (each dimension drawn from a
                    # randomly chosen element)
                    elif method == 'combination_other_elements':
                        distractor = []
                        for dim in range(all_prob.shape[-1]):
                            other_prob_elements_copy = deepcopy(other_prob_elements)
                            np.random.shuffle(other_prob_elements_copy)
                            distractor.append(other_prob_elements_copy[0,dim])
                        distractor = np.array(distractor)
                    # Combination of previous distractors
                    elif method == 'combination_previous_distractors':
                        random.shuffle(answer_choices)
                        if len(answer_choices) > 1:
                            distractor = []
                            for dim in range(all_prob.shape[-1]):
                                answer_choices_copy = deepcopy(answer_choices)
                                np.random.shuffle(answer_choices_copy)
                                distractor.append(answer_choices_copy[0][dim])
                            distractor = np.array(distractor)
                        else:
                            # Need at least 2 previous distractors: force rejection
                            distractor = np.ones(all_prob.shape[-1]).astype(int) * -1
                    # Check to make sure distractor isn't invalid or doesn't already exist
                    # (must be in-range digits 0-9 and distinct from the correct answer)
                    if np.all(distractor >= 0) and np.all(distractor <= 9) and not np.all(distractor == correct_answer):
                        if len(answer_choices) == 0:
                            answer_choices.append(distractor)
                            valid_distractor = True
                        else:
                            # Also reject duplicates of already-accepted distractors
                            if np.logical_not(np.any(np.all(np.array(answer_choices) == np.expand_dims(distractor,0),1))):
                                answer_choices.append(distractor)
                                valid_distractor = True
            # Add correct answer and shuffle; correct answer starts at index 7,
            # so its post-shuffle position is where shuffled_order == 7
            answer_choices.append(correct_answer)
            answer_choices = np.array(answer_choices)
            shuffled_order = np.arange(8)
            np.random.shuffle(shuffled_order)
            correct_ind = np.where(shuffled_order == 7)[0][0]
            answer_choices = answer_choices[shuffled_order]
            # Add to list
            all_answer_choices.append(answer_choices)
            all_correct_ind.append(correct_ind)
        # Combine across problem types
        all_probtype_answer_choices.append(np.array(all_answer_choices))
        all_probtype_correct_ind.append(np.array(all_correct_ind))
    # Convert to arrays
    all_probtype_answer_choices = np.array(all_probtype_answer_choices)
    all_probtype_correct_ind = np.array(all_probtype_correct_ind)
    return all_probtype_answer_choices, all_probtype_correct_ind
145 |
# Save problems as json and numpy files
def save_prob(all_prob, all_answer_choices, all_correct_ind, prob_type_name, all_problems_np, all_problems_js, perm_invariant=False):
    """Store one problem set in both the numpy dict and the javascript/json dict.

    Fix: the newline appended at the end of each matrix row was a broken
    (line-split) string literal; reconstructed as the intended '\\n'.

    Parameters
    ----------
    all_prob : np.ndarray
        Problems of shape (N_probs, 3, 3, N_rules); a cell value of -1 is
        rendered as blank.
    all_answer_choices : np.ndarray
        Answer choices of shape (N_probs, 8, N_rules).
    all_correct_ind : array-like
        Index (0-7) of the correct choice for each problem.
    prob_type_name : str
        Key under which this problem set is stored in both dicts.
    all_problems_np, all_problems_js : dict
        Accumulator dicts, mutated in place and also returned.
    perm_invariant : bool, optional
        Flag stored alongside the numpy data (default False).

    Returns
    -------
    (all_problems_np, all_problems_js) : tuple of dict
    """
    # Add problems to numpy dict
    all_problems_np[prob_type_name] = {'prob': all_prob, 'answer_choices': all_answer_choices, 'correct_ind': all_correct_ind, 'perm_invariant': perm_invariant}
    # Convert to strings and save as json
    all_data = []
    for p in range(all_prob.shape[0]):
        # Convert problem to string: 3x3 grid of bracketed cells, 3 spaces
        # between columns, newline between rows
        prompt = ''
        for r in range(3):
            for c in range(3):
                prompt += '['
                if r == 2 and c == 2:
                    # Bottom-right answer cell is rendered blank; uses row 1's
                    # cell only for its length (all cells share N_rules dims)
                    for i in range(len(all_prob[p][1][c])):
                        prompt += '  '
                        if i < (len(all_prob[p][1][c]) - 1):
                            prompt += ' '
                else:
                    for i in range(len(all_prob[p][r][c])):
                        if all_prob[p][r][c][i] == -1:
                            # -1 marks a blank slot
                            prompt += '  '
                        else:
                            prompt += str(all_prob[p][r][c][i])

                        if i < (len(all_prob[p][r][c]) - 1):
                            prompt += ' '
                prompt += ']'
                if c < 2:
                    prompt += '   '
                if r < 2 and c == 2:
                    prompt += '\n'
        # Convert choices to strings
        options = []
        for a in range(8):
            option = '['
            for i in range(len(all_answer_choices[p][a])):
                option += str(all_answer_choices[p][a][i])
                if i < (len(all_answer_choices[p][a]) - 1):
                    option += ' '
            if len(all_answer_choices[p][a]) == 0:
                option += '  '
            option += ']'
            options.append(option)
        # Add to dataset
        all_data.append({'prompt': prompt, 'options': options, 'correct': int(all_correct_ind[p]), 'prob_ind': p})
    # Add to javascript data
    all_problems_js[prob_type_name] = all_data
    return all_problems_np, all_problems_js
194 |
#### 4-rule and 5-rule problems:
# - 1 for each of 15 4-rule combinations
# - 1 for each of 21 5-rule combinations

# Number of problems-per-category will either be N (below) or maximum number possible
N_probs = 100

# All 10choose3 permutations (ordered triples of distinct digits 0-9)
all_10c3_perm = np.array(builtins.list(permutations(np.arange(10),3)))

# Constant rule: each row (or each column) holds a single repeated digit
all_constant = []
all_row_constant = []
all_col_constant = []
for p in range(all_10c3_perm.shape[0]):
    # Row-constant variant: digit repeated across each row
    row_prob = np.array([[all_10c3_perm[p][0], all_10c3_perm[p][0], all_10c3_perm[p][0]],
                         [all_10c3_perm[p][1], all_10c3_perm[p][1], all_10c3_perm[p][1]],
                         [all_10c3_perm[p][2], all_10c3_perm[p][2], all_10c3_perm[p][2]]])
    all_row_constant.append(row_prob)
    all_constant.append(row_prob)
    # Column-constant variant: digit repeated down each column
    col_prob = np.array([[all_10c3_perm[p][0], all_10c3_perm[p][1], all_10c3_perm[p][2]],
                         [all_10c3_perm[p][0], all_10c3_perm[p][1], all_10c3_perm[p][2]],
                         [all_10c3_perm[p][0], all_10c3_perm[p][1], all_10c3_perm[p][2]]])
    all_constant.append(col_prob)
    all_col_constant.append(col_prob)
all_constant = np.array(all_constant)

# Distribution-of-3 rule: the same 3 digits appear in every row/column,
# cyclically shifted along one of the two diagonal directions
all_dist3 = []
all_dist3_diag1 = []
all_dist3_diag2 = []
for p in range(all_10c3_perm.shape[0]):
    # Shift left by one each row (diagonal direction 1)
    diag1_prob = np.array([[all_10c3_perm[p][0], all_10c3_perm[p][1], all_10c3_perm[p][2]],
                           [all_10c3_perm[p][1], all_10c3_perm[p][2], all_10c3_perm[p][0]],
                           [all_10c3_perm[p][2], all_10c3_perm[p][0], all_10c3_perm[p][1]]])
    all_dist3_diag1.append(diag1_prob)
    all_dist3.append(diag1_prob)
    # Shift right by one each row (diagonal direction 2)
    diag2_prob = np.array([[all_10c3_perm[p][0], all_10c3_perm[p][1], all_10c3_perm[p][2]],
                           [all_10c3_perm[p][2], all_10c3_perm[p][0], all_10c3_perm[p][1]],
                           [all_10c3_perm[p][1], all_10c3_perm[p][2], all_10c3_perm[p][0]]])
    all_dist3_diag2.append(diag2_prob)
    all_dist3.append(diag2_prob)
all_dist3 = np.array(all_dist3)
all_dist3_diag1 = np.array(all_dist3_diag1)
all_dist3_diag2 = np.array(all_dist3_diag2)
# Select subset of distribution-of-3 problems (in-place shuffle; selection
# presumably happens downstream -- only the shuffle is visible here)
np.random.shuffle(all_dist3_diag1)

# Progression rule: 5 consecutive values (step 1 or 2, ascending or
# descending) arranged so each row advances by one step
prog_size1 = np.array([np.arange(0,5), np.arange(1,6), np.arange(2,7), np.arange(3,8), np.arange(4,9), np.arange(5,10)])
prog_size1_reversed = np.fliplr(prog_size1)
prog_size2 = np.array([np.arange(0,10,2), np.arange(1,11,2)])
prog_size2_reversed = np.fliplr(prog_size2)
all_prog_range = np.concatenate([prog_size1, prog_size1_reversed, prog_size2, prog_size2_reversed], 0)
# Tag each range with its step size (0 -> step of 1, 1 -> step of 2)
size1_size2 = np.concatenate([np.zeros(prog_size1.shape[0] + prog_size1_reversed.shape[0]), np.ones(prog_size2.shape[0] + prog_size2_reversed.shape[0])])
all_prog = []
all_prog_size1 = []
all_prog_size2 = []
for p in range(all_prog_range.shape[0]):
    # Rows are overlapping windows of the 5-value range
    prog_prob = np.array([[all_prog_range[p][0], all_prog_range[p][1], all_prog_range[p][2]],
                          [all_prog_range[p][1], all_prog_range[p][2], all_prog_range[p][3]],
                          [all_prog_range[p][2], all_prog_range[p][3], all_prog_range[p][4]]])
    all_prog.append(prog_prob)
    if size1_size2[p] == 0:
        all_prog_size1.append(prog_prob)
    elif size1_size2[p] == 1:
        all_prog_size2.append(prog_prob)
all_prog = np.array(all_prog)
263 |
# All 4-rule and 5-rule sets (combinations with replacement of the 3 rule
# types: 0=constant, 1=distribution-of-3, 2=progression)
all_4rule_comb = builtins.list(combinations_with_replacement(np.arange(3), 4))
all_5rule_comb = builtins.list(combinations_with_replacement(np.arange(3), 5))
# All 4-rule and 5-rule permutations (with replacement): every ordered
# assignment of a rule type to each of the 4 (or 5) problem dimensions
# 4 rules
all_4rule_perm = []
for r1 in range(3):
    for r2 in range(3):
        for r3 in range(3):
            for r4 in range(3):
                all_4rule_perm.append([r1, r2, r3, r4])
all_4rule_perm = np.array(all_4rule_perm)
#5 rules
all_5rule_perm = []
for r1 in range(3):
    for r2 in range(3):
        for r3 in range(3):
            for r4 in range(3):
                for r5 in range(3):
                    all_5rule_perm.append([r1, r2, r3, r4, r5])
all_5rule_perm = np.array(all_5rule_perm)
# Sort permutations by combination: group the ordered permutations under the
# unordered combination they reduce to when sorted
# 4 rules
all_4rule_perm_sorted = []
for c in range(len(all_4rule_comb)):
    all_4rule_perm_sorted.append(all_4rule_perm[np.all(np.expand_dims(np.array(all_4rule_comb[c]),0) == np.sort(all_4rule_perm,1), 1)])
# 5 rules
all_5rule_perm_sorted = []
for c in range(len(all_5rule_comb)):
    all_5rule_perm_sorted.append(all_5rule_perm[np.all(np.expand_dims(np.array(all_5rule_comb[c]),0) == np.sort(all_5rule_perm,1), 1)])
294 |
# Combine problem types: index here matches the rule-type codes used in
# all_4rule_comb / all_5rule_comb (0=constant, 1=dist-of-3, 2=progression)
prob_types = [all_constant, all_dist3, all_prog]

# Generate 4-rule problems: for each rule-type combination, sample N_probs
# problems, each stacking 4 mutually-distinct rule instances along dim 2
all_4rule_prob = []
for c in range(len(all_4rule_comb)):
    all_comb_prob = []
    for p in range(N_probs):
        duplicate_prob = True
        # Resample until a problem not already in this combination's set is found
        while duplicate_prob:
            # Randomly sample permutation (ordered assignment of rule types)
            all_perm = all_4rule_perm_sorted[c]
            np.random.shuffle(all_perm)
            perm = all_perm[0]
            # Sample rule instances
            # Rule 1
            r1_ind = np.floor(np.random.rand() * prob_types[perm[0]].shape[0]).astype(int)
            r1 = prob_types[perm[0]][r1_ind]
            # Rule 2: resample until distinct from rule 1 (pairwise flattened
            # comparison, off-diagonal entries only)
            duplicate_rule = True
            while duplicate_rule:
                r2_ind = np.floor(np.random.rand() * prob_types[perm[1]].shape[0]).astype(int)
                r2 = prob_types[perm[1]][r2_ind]
                duplicate_rule = np.any(np.all(np.expand_dims(np.stack([r1.flatten(), r2.flatten()], 0), 0) == np.expand_dims(np.stack([r1.flatten(), r2.flatten()], 0), 1), 2).flatten()[np.logical_not(np.eye(2).astype(bool).flatten())])
            # Rule 3: resample until distinct from rules 1-2
            duplicate_rule = True
            while duplicate_rule:
                r3_ind = np.floor(np.random.rand() * prob_types[perm[2]].shape[0]).astype(int)
                r3 = prob_types[perm[2]][r3_ind]
                duplicate_rule = np.any(np.all(np.expand_dims(np.stack([r1.flatten(), r2.flatten(), r3.flatten()], 0), 0) == np.expand_dims(np.stack([r1.flatten(), r2.flatten(), r3.flatten()], 0), 1), 2).flatten()[np.logical_not(np.eye(3).astype(bool).flatten())])
            # Rule 4: resample until distinct from rules 1-3
            duplicate_rule = True
            while duplicate_rule:
                r4_ind = np.floor(np.random.rand() * prob_types[perm[3]].shape[0]).astype(int)
                r4 = prob_types[perm[3]][r4_ind]
                duplicate_rule = np.any(np.all(np.expand_dims(np.stack([r1.flatten(), r2.flatten(), r3.flatten(), r4.flatten()], 0), 0) == np.expand_dims(np.stack([r1.flatten(), r2.flatten(), r3.flatten(), r4.flatten()], 0), 1), 2).flatten()[np.logical_not(np.eye(4).astype(bool).flatten())])
            # Combine rules 1-4 into a (3, 3, 4) problem
            prob = np.stack([r1,r2,r3,r4],2)
            # Check if duplicate of a previously generated problem
            duplicate_detected = False
            for i in range(len(all_comb_prob)):
                if np.all(prob == all_comb_prob[i]):
                    duplicate_detected = True
            if not duplicate_detected:
                duplicate_prob = False
                all_comb_prob.append(prob)
    all_comb_prob = np.array(all_comb_prob)
    all_4rule_prob.append(all_comb_prob)
all_4rule_prob = np.array(all_4rule_prob)
# Generate distractors
all_4rule_answer_choices, all_4rule_correct_ind = gen_distractor(all_4rule_prob)
346 |
# Generate 5-rule problems.
# For each 5-rule combination, sample N_probs unique problems; each problem
# stacks five mutually distinct rule instances (one per permuted rule type)
# along a new trailing axis.

def _sample_distinct_rule(pool, chosen):
    """Draw a random rule instance from `pool` (axis 0 indexes instances)
    that does not exactly duplicate any array already in `chosen`."""
    while True:
        cand = pool[np.random.randint(pool.shape[0])]
        if not any(np.array_equal(cand, prev) for prev in chosen):
            return cand

all_5rule_prob = []
for c in range(len(all_5rule_comb)):
    all_comb_prob = []
    for p in range(N_probs):
        while True:
            # Randomly sample a permutation of the rule types for this
            # combination (shuffle is in place, as in the original scheme)
            all_perm = all_5rule_perm_sorted[c]
            np.random.shuffle(all_perm)
            perm = all_perm[0]
            # Sample five rule instances; each must differ from all earlier
            # ones (rule 1 is unconstrained since `rules` starts empty)
            rules = []
            for r in range(5):
                rules.append(_sample_distinct_rule(prob_types[perm[r]], rules))
            # Combine rules 1-5 along a new trailing axis
            prob = np.stack(rules, 2)
            # Accept only if this exact problem hasn't been generated already
            if not any(np.array_equal(prob, prev) for prev in all_comb_prob):
                all_comb_prob.append(prob)
                break
    all_5rule_prob.append(np.array(all_comb_prob))
all_5rule_prob = np.array(all_5rule_prob)
# Generate distractors
all_5rule_answer_choices, all_5rule_correct_ind = gen_distractor(all_5rule_prob)
401 |
# Convert every generated problem into string form, accumulating both a
# numpy-friendly dict and a JSON-friendly dict keyed by combination name
all_problems_np = {}
all_problems_js = {}
# 4-rule problems
for comb in range(all_4rule_prob.shape[0]):
    all_problems_np, all_problems_js = save_prob(
        all_4rule_prob[comb],
        all_4rule_answer_choices[comb],
        all_4rule_correct_ind[comb],
        'four_rule_comb' + str(comb),
        all_problems_np,
        all_problems_js)
# 5-rule problems
for comb in range(all_5rule_prob.shape[0]):
    all_problems_np, all_problems_js = save_prob(
        all_5rule_prob[comb],
        all_5rule_answer_choices[comb],
        all_5rule_correct_ind[comb],
        'five_rule_comb' + str(comb),
        all_problems_np,
        all_problems_js)
# Save numpy file
np_fname = './all_4_5_rule_problems.npz'
np.savez(np_fname, all_problems=all_problems_np)
# Convert to json string
json_string = json.dumps(all_problems_js)
# Write to js script; 'with' guarantees the handle is closed even if the
# write raises (the original open/close pair could leak it)
js_fname = './all_4_5_rule_problems.js'
with open(js_fname, 'w') as js_fid:
    js_fid.write('var all_problems = ' + json_string)
421 |
--------------------------------------------------------------------------------