├── UCLA_VAT ├── UCLA_VAT.xlsx ├── UCLA_VAT_results.npz ├── UCLA_VAT_ind_subj_data.xlsx ├── README.md ├── eval_gpt_UCLA_VAT.py └── analyze_UCLA_VAT.py ├── digit_mat ├── all_problems.npz ├── all_problems_1thru5.npz ├── gpt_matprob_results.npz ├── all_4_5_rule_problems.npz ├── N_unique_rules_2rule_prob.npz ├── N_unique_rules_3rule_prob.npz ├── exp1_GPT3_data │ ├── all_prob.npz │ ├── all_onerule_MC_acc.npz │ ├── all_onerule_gen_acc.npz │ ├── all_probcat_MC_acc.npz │ ├── all_probcat_gen_acc.npz │ ├── aligned_vs_permuted_MC_acc.npz │ ├── aligned_vs_permuted_gen_acc.npz │ ├── tworule_prob_N_unique_rules_MC.npz │ ├── tworule_prog_vs_noprog_gen_acc.npz │ ├── threerule_prob_N_unique_rules_MC.npz │ ├── tworule_prob_N_unique_rules_gen.npz │ └── threerule_prob_N_unique_rules_gen.npz ├── gpt_matprob_results_1thru5.npz ├── exp2_GPT3_data │ ├── all_onerule_MC_acc.npz │ ├── all_onerule_gen_acc.npz │ ├── all_probcat_MC_acc.npz │ └── all_probcat_gen_acc.npz ├── exp1_behavioral_data │ ├── ind_subj_results.npz │ ├── probcat_MC_acc_behavior.npz │ ├── probcat_gen_acc_behavior.npz │ ├── probcat_MC_acc_behavior_onerule.npz │ ├── tworule_prob_N_unique_rules_MC.npz │ ├── tworule_prob_N_unique_rules_gen.npz │ ├── probcat_gen_acc_behavior_onerule.npz │ ├── threerule_prob_N_unique_rules_MC.npz │ ├── threerule_prob_N_unique_rules_gen.npz │ ├── aligned_vs_permuted_MC_acc_behavior.npz │ ├── aligned_vs_permuted_gen_acc_behavior.npz │ └── probcat_gen_acc_behavior_prog_tworule.npz ├── exp2_behavioral_data │ ├── ind_subj_results.npz │ ├── probcat_MC_acc_behavior.npz │ ├── probcat_gen_acc_behavior.npz │ ├── probcat_MC_acc_behavior_onerule.npz │ └── probcat_gen_acc_behavior_onerule.npz ├── exp1_vs_exp2_stats.r ├── exp1_corr_analysis.py ├── README.md ├── exp1_stats.r ├── exp2_plot_GPT3_vs_human.py ├── analyze_gpt3_exp2.py ├── combine_problems_1thru5.py ├── exp1_vs_exp2_create_stats_dset.py ├── exp1_create_stats_dset.py ├── eval_gpt_matprob.py ├── eval_gpt_matprob_prog_1thru5.py ├── analyze_gpt3_exp1.py ├── exp1_plot_GPT3_vs_human.py └── gen_4_5_rule_problems.py ├── letter_string ├── all_prob.npz ├── GPT3_results │ ├── onegen_acc.npz │ ├── all_gen_acc.npz │ ├── realworld_acc.npz │ ├── zerogen_acc.npz │ ├── ind_trial_results.npz │ └── prob_subtype_acc.npz ├── gpt3_letterstring_results.npz ├── behavioral_results │ ├── all_gen_acc.npz │ ├── onegen_acc.npz │ ├── zerogen_acc.npz │ ├── realworld_acc.npz │ ├── ind_subj_results.npz │ └── prob_subtype_acc.npz ├── GPT3_results_noprompt │ ├── all_gen_acc.npz │ ├── onegen_acc.npz │ ├── zerogen_acc.npz │ ├── realworld_acc.npz │ ├── prob_subtype_acc.npz │ └── ind_trial_results.npz ├── GPT3_results_sentence │ ├── all_gen_acc.npz │ ├── onegen_acc.npz │ ├── zerogen_acc.npz │ ├── realworld_acc.npz │ ├── prob_subtype_acc.npz │ └── ind_trial_results.npz ├── gpt3_letterstring_results_noprompt.npz ├── gpt3_letterstring_results_sentence.npz ├── letterstring_analysis.R ├── README.md ├── corr_analysis.py ├── create_regression_dsets.py ├── eval_GPT3_letterstring_prob.py ├── compare_behavior_gpt3.py ├── gen_problems.py └── analyze_gpt3_letterstring.py ├── story_analogies ├── human_vs_gpt3_analysis.R ├── README.md ├── gpt4_data.csv ├── analyze_GPT3_story_analogies.py ├── ind_cond_analyses.py ├── eval_GPT3_story_analogies.py └── human_vs_gpt3_data.csv ├── LICENSE └── README.md /UCLA_VAT/UCLA_VAT.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/UCLA_VAT/UCLA_VAT.xlsx -------------------------------------------------------------------------------- /digit_mat/all_problems.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/all_problems.npz -------------------------------------------------------------------------------- /letter_string/all_prob.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/all_prob.npz -------------------------------------------------------------------------------- /UCLA_VAT/UCLA_VAT_results.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/UCLA_VAT/UCLA_VAT_results.npz -------------------------------------------------------------------------------- /digit_mat/all_problems_1thru5.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/all_problems_1thru5.npz -------------------------------------------------------------------------------- /digit_mat/gpt_matprob_results.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/gpt_matprob_results.npz -------------------------------------------------------------------------------- /UCLA_VAT/UCLA_VAT_ind_subj_data.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/UCLA_VAT/UCLA_VAT_ind_subj_data.xlsx -------------------------------------------------------------------------------- /digit_mat/all_4_5_rule_problems.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/all_4_5_rule_problems.npz -------------------------------------------------------------------------------- /digit_mat/N_unique_rules_2rule_prob.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/N_unique_rules_2rule_prob.npz -------------------------------------------------------------------------------- /digit_mat/N_unique_rules_3rule_prob.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/N_unique_rules_3rule_prob.npz -------------------------------------------------------------------------------- /digit_mat/exp1_GPT3_data/all_prob.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/all_prob.npz -------------------------------------------------------------------------------- /digit_mat/gpt_matprob_results_1thru5.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/gpt_matprob_results_1thru5.npz -------------------------------------------------------------------------------- /letter_string/GPT3_results/onegen_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results/onegen_acc.npz -------------------------------------------------------------------------------- /letter_string/GPT3_results/all_gen_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results/all_gen_acc.npz -------------------------------------------------------------------------------- /letter_string/GPT3_results/realworld_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results/realworld_acc.npz -------------------------------------------------------------------------------- /letter_string/GPT3_results/zerogen_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results/zerogen_acc.npz -------------------------------------------------------------------------------- /letter_string/gpt3_letterstring_results.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/gpt3_letterstring_results.npz -------------------------------------------------------------------------------- /digit_mat/exp1_GPT3_data/all_onerule_MC_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/all_onerule_MC_acc.npz -------------------------------------------------------------------------------- /digit_mat/exp1_GPT3_data/all_onerule_gen_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/all_onerule_gen_acc.npz -------------------------------------------------------------------------------- /digit_mat/exp1_GPT3_data/all_probcat_MC_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/all_probcat_MC_acc.npz -------------------------------------------------------------------------------- /digit_mat/exp1_GPT3_data/all_probcat_gen_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/all_probcat_gen_acc.npz -------------------------------------------------------------------------------- /digit_mat/exp2_GPT3_data/all_onerule_MC_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp2_GPT3_data/all_onerule_MC_acc.npz -------------------------------------------------------------------------------- /digit_mat/exp2_GPT3_data/all_onerule_gen_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp2_GPT3_data/all_onerule_gen_acc.npz -------------------------------------------------------------------------------- /digit_mat/exp2_GPT3_data/all_probcat_MC_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp2_GPT3_data/all_probcat_MC_acc.npz -------------------------------------------------------------------------------- /digit_mat/exp2_GPT3_data/all_probcat_gen_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp2_GPT3_data/all_probcat_gen_acc.npz -------------------------------------------------------------------------------- /letter_string/GPT3_results/ind_trial_results.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results/ind_trial_results.npz -------------------------------------------------------------------------------- /letter_string/GPT3_results/prob_subtype_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results/prob_subtype_acc.npz -------------------------------------------------------------------------------- /letter_string/behavioral_results/all_gen_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/behavioral_results/all_gen_acc.npz -------------------------------------------------------------------------------- /letter_string/behavioral_results/onegen_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/behavioral_results/onegen_acc.npz -------------------------------------------------------------------------------- /letter_string/behavioral_results/zerogen_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/behavioral_results/zerogen_acc.npz -------------------------------------------------------------------------------- /digit_mat/exp1_behavioral_data/ind_subj_results.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/ind_subj_results.npz -------------------------------------------------------------------------------- /digit_mat/exp2_behavioral_data/ind_subj_results.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp2_behavioral_data/ind_subj_results.npz -------------------------------------------------------------------------------- /letter_string/GPT3_results_noprompt/all_gen_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_noprompt/all_gen_acc.npz -------------------------------------------------------------------------------- /letter_string/GPT3_results_noprompt/onegen_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_noprompt/onegen_acc.npz -------------------------------------------------------------------------------- /letter_string/GPT3_results_noprompt/zerogen_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_noprompt/zerogen_acc.npz -------------------------------------------------------------------------------- /letter_string/GPT3_results_sentence/all_gen_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_sentence/all_gen_acc.npz -------------------------------------------------------------------------------- /letter_string/GPT3_results_sentence/onegen_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_sentence/onegen_acc.npz -------------------------------------------------------------------------------- /letter_string/GPT3_results_sentence/zerogen_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_sentence/zerogen_acc.npz -------------------------------------------------------------------------------- /letter_string/behavioral_results/realworld_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/behavioral_results/realworld_acc.npz -------------------------------------------------------------------------------- /letter_string/GPT3_results_noprompt/realworld_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_noprompt/realworld_acc.npz -------------------------------------------------------------------------------- /letter_string/GPT3_results_sentence/realworld_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_sentence/realworld_acc.npz -------------------------------------------------------------------------------- /letter_string/behavioral_results/ind_subj_results.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/behavioral_results/ind_subj_results.npz -------------------------------------------------------------------------------- /letter_string/behavioral_results/prob_subtype_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/behavioral_results/prob_subtype_acc.npz -------------------------------------------------------------------------------- /letter_string/gpt3_letterstring_results_noprompt.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/gpt3_letterstring_results_noprompt.npz -------------------------------------------------------------------------------- /letter_string/gpt3_letterstring_results_sentence.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/gpt3_letterstring_results_sentence.npz -------------------------------------------------------------------------------- /digit_mat/exp1_GPT3_data/aligned_vs_permuted_MC_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/aligned_vs_permuted_MC_acc.npz -------------------------------------------------------------------------------- /digit_mat/exp1_GPT3_data/aligned_vs_permuted_gen_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/aligned_vs_permuted_gen_acc.npz -------------------------------------------------------------------------------- /letter_string/GPT3_results_noprompt/prob_subtype_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_noprompt/prob_subtype_acc.npz -------------------------------------------------------------------------------- /letter_string/GPT3_results_sentence/prob_subtype_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_sentence/prob_subtype_acc.npz -------------------------------------------------------------------------------- /digit_mat/exp1_GPT3_data/tworule_prob_N_unique_rules_MC.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/tworule_prob_N_unique_rules_MC.npz -------------------------------------------------------------------------------- /digit_mat/exp1_GPT3_data/tworule_prog_vs_noprog_gen_acc.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/tworule_prog_vs_noprog_gen_acc.npz -------------------------------------------------------------------------------- /digit_mat/exp1_behavioral_data/probcat_MC_acc_behavior.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/probcat_MC_acc_behavior.npz -------------------------------------------------------------------------------- /digit_mat/exp1_behavioral_data/probcat_gen_acc_behavior.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/probcat_gen_acc_behavior.npz -------------------------------------------------------------------------------- /digit_mat/exp2_behavioral_data/probcat_MC_acc_behavior.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp2_behavioral_data/probcat_MC_acc_behavior.npz -------------------------------------------------------------------------------- /digit_mat/exp2_behavioral_data/probcat_gen_acc_behavior.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp2_behavioral_data/probcat_gen_acc_behavior.npz -------------------------------------------------------------------------------- /letter_string/GPT3_results_noprompt/ind_trial_results.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_noprompt/ind_trial_results.npz -------------------------------------------------------------------------------- /letter_string/GPT3_results_sentence/ind_trial_results.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/letter_string/GPT3_results_sentence/ind_trial_results.npz -------------------------------------------------------------------------------- /digit_mat/exp1_GPT3_data/threerule_prob_N_unique_rules_MC.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/threerule_prob_N_unique_rules_MC.npz -------------------------------------------------------------------------------- /digit_mat/exp1_GPT3_data/tworule_prob_N_unique_rules_gen.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/tworule_prob_N_unique_rules_gen.npz -------------------------------------------------------------------------------- /digit_mat/exp1_GPT3_data/threerule_prob_N_unique_rules_gen.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_GPT3_data/threerule_prob_N_unique_rules_gen.npz -------------------------------------------------------------------------------- /digit_mat/exp1_behavioral_data/probcat_MC_acc_behavior_onerule.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/probcat_MC_acc_behavior_onerule.npz -------------------------------------------------------------------------------- /digit_mat/exp1_behavioral_data/tworule_prob_N_unique_rules_MC.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/tworule_prob_N_unique_rules_MC.npz -------------------------------------------------------------------------------- /digit_mat/exp1_behavioral_data/tworule_prob_N_unique_rules_gen.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/tworule_prob_N_unique_rules_gen.npz -------------------------------------------------------------------------------- /digit_mat/exp2_behavioral_data/probcat_MC_acc_behavior_onerule.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp2_behavioral_data/probcat_MC_acc_behavior_onerule.npz -------------------------------------------------------------------------------- /digit_mat/exp1_behavioral_data/probcat_gen_acc_behavior_onerule.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/probcat_gen_acc_behavior_onerule.npz -------------------------------------------------------------------------------- /digit_mat/exp1_behavioral_data/threerule_prob_N_unique_rules_MC.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/threerule_prob_N_unique_rules_MC.npz -------------------------------------------------------------------------------- /digit_mat/exp1_behavioral_data/threerule_prob_N_unique_rules_gen.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/threerule_prob_N_unique_rules_gen.npz -------------------------------------------------------------------------------- /digit_mat/exp2_behavioral_data/probcat_gen_acc_behavior_onerule.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp2_behavioral_data/probcat_gen_acc_behavior_onerule.npz -------------------------------------------------------------------------------- /digit_mat/exp1_behavioral_data/aligned_vs_permuted_MC_acc_behavior.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/aligned_vs_permuted_MC_acc_behavior.npz -------------------------------------------------------------------------------- /digit_mat/exp1_behavioral_data/aligned_vs_permuted_gen_acc_behavior.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/aligned_vs_permuted_gen_acc_behavior.npz -------------------------------------------------------------------------------- /digit_mat/exp1_behavioral_data/probcat_gen_acc_behavior_prog_tworule.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taylorwwebb/emergent_analogies_LLM/HEAD/digit_mat/exp1_behavioral_data/probcat_gen_acc_behavior_prog_tworule.npz -------------------------------------------------------------------------------- /story_analogies/human_vs_gpt3_analysis.R: -------------------------------------------------------------------------------- 1 | setwd("./") 2 | 3 | # Human vs. GPT-3 4 | data <-read.csv("./human_vs_gpt3_data.csv") 5 | human_vs_gpt3_model <- glm(correct_pred ~ human_vs_gpt, data=data, family="binomial") 6 | summary(human_vs_gpt3_model) 7 | human_vs_gpt3_OR <- exp(cbind(OR = coef(human_vs_gpt3_model), confint(human_vs_gpt3_model))) 8 | summary(human_vs_gpt3_OR) 9 | 10 | 11 | -------------------------------------------------------------------------------- /UCLA_VAT/README.md: -------------------------------------------------------------------------------- 1 | ## UCLA Verbal Analogy Test 2 | 3 | To evaluate GPT-3 on UCLA VAT, run: 4 | ``` 5 | python3 ./eval_gpt_UCLA_VAT.py 6 | ``` 7 | Note that you will need to enter your OpenAI API key (line 8). 8 | 9 | To analyze GPT-3's responses and compare with human behavior, run: 10 | ``` 11 | python3 ./analyze_UCLA_VAT.py 12 | ``` 13 | Note that results for human participants and GPT-3 are already included in this repository. 14 | -------------------------------------------------------------------------------- /digit_mat/exp1_vs_exp2_stats.r: -------------------------------------------------------------------------------- 1 | setwd("./") 2 | data <-read.csv("./exp1_vs_exp2_all_data.csv") 3 | 4 | # Generative task - problem type X experiment (progressive vs. random order) interaction 5 | # Human only 6 | human_gen <- glm(gen_correct_pred ~ prob_type + exp1_vs_exp2 + prob_type:exp1_vs_exp2, data=subset(data, human_vs_gpt==0), family="binomial") 7 | summary(human_gen) 8 | # GPT-3 only 9 | GPT3_gen <- glm(gen_correct_pred ~ prob_type + exp1_vs_exp2 + prob_type:exp1_vs_exp2, data=subset(data, human_vs_gpt==1), family="binomial") 10 | summary(GPT3_gen) 11 | -------------------------------------------------------------------------------- /digit_mat/exp1_corr_analysis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.stats 3 | 4 | # Load human data 5 | human_prob_acc = np.load('./exp1_behavioral_data/ind_subj_results.npz')['all_subj_gen_correct_pred'].mean(0) 6 | gpt3_prob_acc = np.load('./exp1_GPT3_data/all_prob.npz')['all_gen'].reshape((-1,32)).astype(float).mean(0) 7 | 8 | # Correlation analysis 9 | corr_results = scipy.stats.pearsonr(gpt3_prob_acc, human_prob_acc) 10 | print('correlation analysis:') 11 | print('r = ' + str(np.around(corr_results[0],4))) 12 | print('p = ' + str(np.around(corr_results[1],4))) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Taylor Webb licenses this file to You under the Apache License, Version 2.0 (the "License"); 2 | you may not use this file except in compliance with the License. You may obtain a copy of the License at: 3 | 4 | http://www.apache.org/licenses/LICENSE-2.0 5 | 6 | Unless required by applicable law or agreed to in writing, software distributed under the License is distributed 7 | on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8 | See the License for the specific language governing permissions and limitations under the License. 9 | -------------------------------------------------------------------------------- /story_analogies/README.md: -------------------------------------------------------------------------------- 1 | ## Story Analogies 2 | 3 | To perform individual analyses for text-davinci-003, GPT-4, and human participants, run: 4 | ``` 5 | python3 ./ind_cond_analyses.py 6 | ``` 7 | To perform analysis comparing human and GPT-3 performance, run the following R script: 8 | ``` 9 | ./human_vs_gpt3_analysis.R 10 | ``` 11 | 12 | Original story materials can be downloaded from [http://cvl.psych.ucla.edu/resources/AnalogyInventory.zip](http://cvl.psych.ucla.edu/resources/AnalogyInventory.zip) (file name: ```Cognitive Psychology.xlsx```, sheet name: ```Rattermann```). 13 | -------------------------------------------------------------------------------- /letter_string/letterstring_analysis.R: -------------------------------------------------------------------------------- 1 | setwd("./") 2 | library(dplyr) 3 | library(Matrix) 4 | library(lme4) 5 | data <-read.csv("./letterstring_data.csv") 6 | 7 | # Omnibus model 8 | model <- glm(correct_pred ~ human_vs_gpt + N_gen + human_vs_gpt:N_gen, data=data, family="binomial") 9 | summary(model) 10 | 11 | # Zero-generalization problems 12 | zerogen_model <- glm(correct_pred ~ human_vs_gpt, data=subset(data, N_gen==0), family="binomial") 13 | summary(zerogen_model) 14 | 15 | # Effect of number of generalizations 16 | # Human 17 | model <- glm(correct_pred ~ N_gen, data=subset(data, human_vs_gpt==0), family="binomial") 18 | summary(model) 19 | # GPT-3 20 | model <- glm(correct_pred ~ N_gen, data=subset(data, human_vs_gpt==1), family="binomial") 21 | summary(model) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Emergent Analogical Reasoning in Large Language Models 2 | 3 | Code for the paper [Emergent Analogical Reasoning in Large Language Models](https://arxiv.org/abs/2212.09196). 4 | 5 | Code for each set of experiments, along with detailed instructions, are contained within separate directories. 6 | 7 | ## Prerequisites 8 | 9 | - Python 3 10 | - [OpenAI Python Library](https://github.com/openai/openai-python) 11 | - [NumPy](https://numpy.org/) 12 | - [SciPy](https://scipy.org/) 13 | - [statsmodels](https://www.statsmodels.org/stable/index.html) 14 | - [Matplotlib](https://matplotlib.org/) 15 | - [pandas](https://pandas.pydata.org/) 16 | - [R](https://www.r-project.org/) 17 | 18 | 19 | ## Authorship 20 | 21 | All code was written by [Taylor Webb](https://github.com/taylorwwebb). 22 | -------------------------------------------------------------------------------- /story_analogies/gpt4_data.csv: -------------------------------------------------------------------------------- 1 | prob_ID,correct_pred,correct_AB,analogy_vs_similarity 2 | 1,1,0,0 3 | 2,1,0,0 4 | 3,1,0,0 5 | 4,1,0,0 6 | 5,1,0,0 7 | 6,1,0,0 8 | 7,1,0,0 9 | 8,0,0,0 10 | 9,1,0,0 11 | 10,0,0,0 12 | 11,1,0,0 13 | 12,1,0,0 14 | 13,1,0,0 15 | 14,1,0,0 16 | 15,1,0,0 17 | 16,1,0,0 18 | 17,1,0,0 19 | 18,1,0,0 20 | 1,0,1,0 21 | 2,1,1,0 22 | 3,0,1,0 23 | 4,1,1,0 24 | 5,1,1,0 25 | 6,1,1,0 26 | 7,0,1,0 27 | 8,0,1,0 28 | 9,1,1,0 29 | 10,0,1,0 30 | 11,1,1,0 31 | 12,1,1,0 32 | 13,0,1,0 33 | 14,1,1,0 34 | 15,0,1,0 35 | 16,1,1,0 36 | 17,1,1,0 37 | 18,1,1,0 38 | 1,1,0,1 39 | 2,1,0,1 40 | 3,1,0,1 41 | 4,1,0,1 42 | 5,1,0,1 43 | 6,1,0,1 44 | 7,1,0,1 45 | 8,1,0,1 46 | 9,1,0,1 47 | 10,1,0,1 48 | 11,1,0,1 49 | 12,1,0,1 50 | 13,1,0,1 51 | 14,1,0,1 52 | 15,1,0,1 53 | 16,1,0,1 54 | 17,1,0,1 55 | 18,1,0,1 56 | 1,1,1,1 57 | 2,1,1,1 58 | 3,1,1,1 59 | 4,1,1,1 60 | 5,1,1,1 61 | 6,1,1,1 62 | 7,1,1,1 63 | 8,0,1,1 64 | 9,1,1,1 65 | 10,1,1,1 66 | 11,1,1,1 67 | 12,1,1,1 68 | 13,1,1,1 69 | 14,1,1,1 70 | 15,1,1,1 71 | 16,1,1,1 72 | 17,1,1,1 73 | 18,1,1,1 -------------------------------------------------------------------------------- /letter_string/README.md: -------------------------------------------------------------------------------- 1 | ## Letter String Analogies 2 | 3 | To create new letter string problems, run: 4 | ``` 5 | python3 ./gen_problems.py 6 | ``` 7 | Problems are contained in: 8 | ``` 9 | ./all_prob.npz 10 | ``` 11 | To evaluate GPT-3 on letter string problems, run: 12 | ``` 13 | python3 ./eval_GPT3_letterstring_prob.py 14 | ``` 15 | Note that you will need to enter your OpenAI API key (line 15). 16 | 17 | To analyze GPT-3's responses, run: 18 | ``` 19 | python3 ./analyze_gpt3_letterstring.py 20 | ``` 21 | To plot figures comparing GPT-3 with human performance, run: 22 | ``` 23 | python3 ./compare_behavior_gpt3.py 24 | ``` 25 | To perform regression analyses, run: 26 | ``` 27 | python3 ./create_regression_dsets.py 28 | ``` 29 | and run the following R script: 30 | ``` 31 | ./letterstring_analysis.R 32 | ``` 33 | To perform correlation analyses, run: 34 | ``` 35 | python3 ./corr_analysis.py 36 | ``` 37 | To evaluate GPT-3 on letter string problems presented in alternative formats (without a prompt, or in the form of a setence), use the ```--noprompt``` or ```--sentence``` arguments for these scripts. 38 | 39 | Note that results for human participants and GPT-3 are already included in this repository. 40 | -------------------------------------------------------------------------------- /letter_string/corr_analysis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import argparse 4 | import scipy.stats 5 | 6 | # Settings 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--sentence', action='store_true', help="Present problem in sentence format.") 9 | parser.add_argument('--noprompt', action='store_true', help="Present problem without prompt.") 10 | args = parser.parse_args() 11 | 12 | # Load data 13 | human_data = np.load('./behavioral_results/prob_subtype_acc.npz') 14 | human_subtype_acc = human_data['subtype_acc'] 15 | human_subtype_counts = human_data['subtype_counts'] 16 | if args.sentence: 17 | gpt3_data = np.load('./GPT3_results_sentence/prob_subtype_acc.npz') 18 | elif args.noprompt: 19 | gpt3_data = np.load('./GPT3_results_noprompt/prob_subtype_acc.npz') 20 | else: 21 | gpt3_data = np.load('./GPT3_results/prob_subtype_acc.npz') 22 | gpt3_subtype_acc = gpt3_data['subtype_acc'] 23 | gpt3_subtype_counts = gpt3_data['subtype_counts'] 24 | 25 | # Minimum number of trials per subtype 26 | min_trials = 5 27 | include = np.all(np.stack([(human_subtype_counts > min_trials),(gpt3_subtype_counts > min_trials)]),0) 28 | print('minimum number of trials per subtype = ' + str(min_trials)) 29 | print(str(include.sum()) + ' out of ' + str(len(include)) + ' subtypes included') 30 | 31 | # Correlation analysis 32 | corr_results = scipy.stats.pearsonr(gpt3_subtype_acc[include], human_subtype_acc[include]) 33 | print('correlation analysis:') 34 | print('r = ' + str(np.around(corr_results[0],4))) 35 | print('p = ' + str(np.around(corr_results[1],10))) -------------------------------------------------------------------------------- /digit_mat/README.md: -------------------------------------------------------------------------------- 1 | ## Digit Matrices 2 | 3 | To create new digit matrix problems, run: 4 | ``` 5 | python3 ./gen_problems.py 6 | python3 ./gen_4_5_rule_problems.py 7 | python3 ./combine_problems_1thru5.py 8 | ``` 9 | Problems for experiment #1 (1-3 rule and logic problems, results in Figure 3 of paper) are contained in: 10 | ``` 11 | ./all_problems.npz 12 | ``` 13 | and problems for experiment #2 (1-5 rule problems, results in Supplementary Figure S4) are contained in: 14 | ``` 15 | ./all_problems_1thru5.npz 16 | ``` 17 | 18 | To evaluate GPT-3 on experiment #1, run: 19 | ``` 20 | python3 ./eval_gpt_matprob.py 21 | ``` 22 | Note that you will need to enter your OpenAI API key (line 14). 23 | 24 | To evaluate GPT-3 on experiment #2, run: 25 | ``` 26 | python3 ./eval_gpt_matprob_prog_1thru5.py 27 | ``` 28 | To analyze GPT-3's responses on these experiments, run: 29 | ``` 30 | python3 ./analyze_gpt3_exp1.py 31 | python3 ./analyze_gpt3_exp2.py 32 | ``` 33 | To plot figures comparing GPT-3 with human performance, run: 34 | ``` 35 | python3 ./exp1_plot_GPT3_vs_human.py 36 | python3 ./exp2_plot_GPT3_vs_human.py 37 | ``` 38 | To perform statistical analysis for experiment #1, run: 39 | ``` 40 | python3 ./exp1_create_stats_dset.py 41 | ``` 42 | and run the following R script: 43 | ``` 44 | ./exp1_stats.r 45 | ``` 46 | To perform analysis comparing experiments #1 and #2, run: 47 | ``` 48 | python3 ./exp1_vs_exp2_create_stats_dset.py 49 | ``` 50 | and run the following R script: 51 | ``` 52 | ./exp1_vs_exp2_stats.r 53 | ``` 54 | Note that results for human participants and GPT-3 are already included in this repository. 55 | -------------------------------------------------------------------------------- /story_analogies/analyze_GPT3_story_analogies.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Load data 4 | results = np.load('./gpt3_rattermann_results.npz') 5 | all_analogy_results = results['all_analogy_results'] 6 | all_analogy_correct = results['all_analogy_correct'] 7 | all_similarity_results = results['all_similarity_results'] 8 | all_similarity_correct = results['all_similarity_correct'] 9 | 10 | # Score 11 | all_analogy_correct_pred = [] 12 | for i in range(len(all_analogy_results)): 13 | print('response: ' + all_analogy_results[i]) 14 | print('correct_answer: ' + all_analogy_correct[i]) 15 | correct_pred = int(input("Correct? (correct=1, incorrect=0): ")) 16 | print(' ') 17 | all_analogy_correct_pred.append(correct_pred) 18 | all_similarity_correct_pred = [] 19 | for i in range(len(all_similarity_results)): 20 | print('response: ' + all_similarity_results[i]) 21 | print('correct_answer: ' + all_similarity_correct[i]) 22 | correct_pred = int(input("Correct? (correct=1, incorrect=0): ")) 23 | print('') 24 | all_similarity_correct_pred.append(correct_pred) 25 | 26 | # Report accuracy 27 | analogy_acc = np.mean(all_analogy_correct_pred) 28 | print('analogy acc. = ' + str(np.around(analogy_acc,4))) 29 | similarity_acc = np.mean(all_similarity_correct_pred) 30 | print('similarity acc. = ' + str(np.around(similarity_acc,4))) 31 | 32 | # Save 33 | np.savez('./gpt3_rattermann_results.npz', all_analogy_results=all_analogy_results, all_analogy_correct=all_analogy_correct, all_analogy_correct_pred=all_analogy_correct_pred, 34 | all_similarity_results=all_similarity_results, all_similarity_correct=all_similarity_correct, all_similarity_correct_pred=all_similarity_correct_pred) 35 | 36 | -------------------------------------------------------------------------------- /digit_mat/exp1_stats.r: -------------------------------------------------------------------------------- 1 | # Load data 2 | setwd("./") 3 | data <-read.csv("./exp1_all_data.csv") 4 | 5 | # Overall generative performance 6 | all_gen <- glm(gen_correct_pred ~ prob_type + human_vs_gpt + prob_type:human_vs_gpt, data=data, family="binomial") 7 | summary(all_gen) 8 | # Overall multiple-choice performance 9 | all_MC <- glm(MC_correct_pred ~ prob_type + human_vs_gpt + prob_type:human_vs_gpt, data=data, family="binomial") 10 | summary(all_MC) 11 | 12 | # Two-rule problems, generative performance, progression rule vs. no progression rule 13 | # Human only 14 | human_prog_vs_noprog <- glm(gen_correct_pred ~ tworule_prog_noprog, data=subset(subset(data, prob_type==1), human_vs_gpt==0), family="binomial") 15 | summary(human_prog_vs_noprog) 16 | # GPT-3 only 17 | GPT3_prog_vs_noprog <- glm(gen_correct_pred ~ tworule_prog_noprog, data=subset(subset(data, prob_type==1), human_vs_gpt==1), family="binomial") 18 | summary(GPT3_prog_vs_noprog) 19 | 20 | # Analysis of relational complexity 21 | # Human only 22 | human_relcompl <- glm(gen_correct_pred ~ N_unique_rules, data=subset(subset(data, prob_type==2), human_vs_gpt==0), family="binomial") 23 | summary(human_relcompl) 24 | # GPT-3 only 25 | GPT3_relcompl <- glm(gen_correct_pred ~ N_unique_rules, data=subset(subset(data, prob_type==2), human_vs_gpt==1), family="binomial") 26 | summary(GPT3_relcompl) 27 | 28 | # Aligned vs. permuted logic problems 29 | # Human only 30 | human_aligned_permute <- glm(gen_correct_pred ~ aligned_permuted, data=subset(subset(data, prob_type==3), human_vs_gpt==0), family="binomial") 31 | summary(human_aligned_permute) 32 | # GPT-3 only 33 | GPT3_aligned_permute <- glm(gen_correct_pred ~ aligned_permuted, data=subset(subset(data, prob_type==3), human_vs_gpt==1), family="binomial") 34 | summary(GPT3_aligned_permute) 35 | 36 | 37 | -------------------------------------------------------------------------------- /letter_string/create_regression_dsets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import csv 3 | import builtins 4 | 5 | # Load human behavioral data 6 | human_data = np.load('./behavioral_results/ind_subj_results.npz') 7 | human_correct_pred = human_data['all_subj_correct_pred'] 8 | human_prob_subtype = human_data['all_subj_prob_subtype'] 9 | human_N_gen = human_data['all_subj_N_gen'] 10 | human_realworld = human_data['all_subj_realworld'] 11 | 12 | # Load GPT-3 data 13 | gpt3_data = np.load('./GPT3_results/ind_trial_results.npz') 14 | gpt3_correct_pred = gpt3_data['all_prob_type_correct_pred'] 15 | gpt3_prob_subtype = gpt3_data['all_prob_type_subtype'] 16 | gpt3_N_gen = gpt3_data['all_prob_N_gen'] 17 | gpt3_realworld = gpt3_data['all_prob_realworld'] 18 | 19 | # Create dataset 20 | subjID = [] 21 | correct_pred = [] 22 | human_vs_gpt = [] 23 | prob_subtype = [] 24 | N_gen = [] 25 | realworld = [] 26 | 27 | # Add human data 28 | N_subj = human_correct_pred.shape[0] 29 | N_prob = human_correct_pred.shape[1] 30 | for s in range(N_subj): 31 | for p in range(N_prob): 32 | subjID.append(s) 33 | correct_pred.append(int(human_correct_pred[s,p])) 34 | human_vs_gpt.append(0) 35 | N_gen.append(human_N_gen[s,p]) 36 | realworld.append(human_realworld[s,p]) 37 | if p < human_prob_subtype.shape[1]: 38 | prob_subtype.append(human_prob_subtype[s,p]) 39 | else: 40 | prob_subtype.append(-1) 41 | 42 | # Add GPT-3 data 43 | N_trials_per_prob = gpt3_correct_pred.shape[1] 44 | for t in range(N_trials_per_prob): 45 | for p in range(N_prob): 46 | subjID.append(s+1) 47 | correct_pred.append(int(gpt3_correct_pred[p,t])) 48 | human_vs_gpt.append(1) 49 | N_gen.append(gpt3_N_gen[p,t]) 50 | realworld.append(gpt3_realworld[p,t]) 51 | if p < gpt3_prob_subtype.shape[0]: 52 | prob_subtype.append(gpt3_prob_subtype[p,t]) 53 | else: 54 | prob_subtype.append(-1) 55 | 56 | # Write csv files 57 | # Create file 58 | f = open('./letterstring_data.csv', 'w') 59 | writer = csv.writer(f) 60 | # Header 61 | header = ['subjID', 'correct_pred', 'human_vs_gpt', 'prob_subtype', 'N_gen'] 62 | writer.writerow(header) 63 | # Write data 64 | for i in range(len(subjID)): 65 | if realworld[i] == 0: 66 | data_row = [subjID[i], correct_pred[i], human_vs_gpt[i], prob_subtype[i], N_gen[i]] 67 | writer.writerow(data_row) 68 | # Close file 69 | f.close() 70 | -------------------------------------------------------------------------------- /letter_string/eval_GPT3_letterstring_prob.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import numpy as np 3 | import builtins 4 | import argparse 5 | import os 6 | import time 7 | 8 | # Settings 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--sentence', action='store_true', help="Present problem in sentence format.") 11 | parser.add_argument('--noprompt', action='store_true', help="Present problem without prompt.") 12 | args = parser.parse_args() 13 | 14 | # GPT-3 settings 15 | openai.api_key = "FILL_IN_API_KEY_HERE" 16 | if args.sentence: 17 | kwargs = { "engine":"text-davinci-003", "temperature":0, "max_tokens":40, "echo":False, "logprobs":1, } 18 | else: 19 | kwargs = { "engine":"text-davinci-003", "temperature":0, "max_tokens":40, "stop":"\n", "echo":False, "logprobs":1, } 20 | 21 | # Load all problems 22 | all_prob = np.load('./all_prob.npz', allow_pickle=True)['all_prob'] 23 | prob_types = builtins.list(all_prob.item().keys()) 24 | N_prob_types = len(prob_types) 25 | 26 | # Evaluate 27 | N_trials_per_prob_type = 50 28 | all_prob_type_responses = [] 29 | for p in range(N_prob_types): 30 | print('problem type' + str(p+1) + ' of ' + str(N_prob_types) + '...') 31 | prob_type_responses = [] 32 | for t in range(N_trials_per_prob_type): 33 | print('trial ' + str(t+1) + ' of ' + str(N_trials_per_prob_type) + '...') 34 | # Generate prompt 35 | prob = all_prob.item()[prob_types[p]]['prob'][t] 36 | prompt = '' 37 | if not args.noprompt: 38 | prompt += "Let's try to complete the pattern:\n\n" 39 | if args.sentence: 40 | prompt += 'If ' 41 | for i in range(len(prob[0][0])): 42 | prompt += str(prob[0][0][i]) 43 | if i < len(prob[0][0]) - 1: 44 | prompt += ' ' 45 | prompt += ' changes to ' 46 | for i in range(len(prob[0][1])): 47 | prompt += str(prob[0][1][i]) 48 | if i < len(prob[0][1]) - 1: 49 | prompt += ' ' 50 | prompt += ', then ' 51 | for i in range(len(prob[1][0])): 52 | prompt += str(prob[1][0][i]) 53 | if i < len(prob[1][0]) - 1: 54 | prompt += ' ' 55 | prompt += ' should change to ' 56 | else: 57 | prompt += '[' 58 | for i in range(len(prob[0][0])): 59 | prompt += str(prob[0][0][i]) 60 | if i < len(prob[0][0]) - 1: 61 | prompt += ' ' 62 | prompt += '] [' 63 | for i in range(len(prob[0][1])): 64 | prompt += str(prob[0][1][i]) 65 | if i < len(prob[0][1]) - 1: 66 | prompt += ' ' 67 | prompt += ']\n[' 68 | for i in range(len(prob[1][0])): 69 | prompt += str(prob[1][0][i]) 70 | if i < len(prob[1][0]) - 1: 71 | prompt += ' ' 72 | prompt += '] [' 73 | # Get response 74 | response = [] 75 | while len(response) == 0: 76 | try: 77 | response = openai.Completion.create(prompt=prompt, **kwargs) 78 | except: 79 | print('trying again...') 80 | time.sleep(5) 81 | prob_type_responses.append(response['choices'][0]['text']) 82 | all_prob_type_responses.append(prob_type_responses) 83 | # Save 84 | save_fname = './gpt3_letterstring_results' 85 | if args.sentence: 86 | save_fname += '_sentence' 87 | if args.noprompt: 88 | save_fname += '_noprompt' 89 | save_fname += '.npz' 90 | np.savez(save_fname, all_prob_type_responses=all_prob_type_responses, allow_pickle=True) 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /story_analogies/ind_cond_analyses.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import numpy as np 3 | from scipy.stats import binomtest, ttest_1samp 4 | 5 | # Human vs. GPT-3 data 6 | human_data_file = open('./human_vs_gpt3_data.csv') 7 | csvreader = csv.reader(human_data_file) 8 | header = [] 9 | header = np.array(next(csvreader)) 10 | rows = [] 11 | for row in csvreader: 12 | rows.append(row) 13 | rows = np.array(rows).astype(int) 14 | 15 | # Human data analyses 16 | human_vs_gpt = np.where(header == 'human_vs_gpt')[0][0] 17 | human_data = rows[rows[:,human_vs_gpt]==0, :] 18 | human_acc = [] 19 | subjID = np.where(header == 'subjID')[0][0] 20 | N_subj = np.max(np.unique(human_data[:,subjID])) 21 | for s in range(N_subj): 22 | subj_data = human_data[human_data[:,subjID] == s+1, :] 23 | correct_pred = np.where(header == 'correct_pred')[0][0] 24 | subj_acc = np.mean(subj_data[:,correct_pred]) 25 | human_acc.append(subj_acc) 26 | print('\nHuman:') 27 | print('Combined:') 28 | print(ttest_1samp(human_acc, 0.5)) 29 | # Near analogy 30 | analogy_vs_similarity = np.where(header == 'analogy_vs_similarity')[0][0] 31 | human_similarity_data = human_data[human_data[:,analogy_vs_similarity]==1, :] 32 | human_similarity_acc = [] 33 | for s in range(N_subj): 34 | subj_data = human_similarity_data[human_similarity_data[:,subjID] == s+1, :] 35 | subj_acc = np.mean(subj_data[:,correct_pred]) 36 | human_similarity_acc.append(subj_acc) 37 | print('Near analogy:') 38 | print(ttest_1samp(human_similarity_acc, 0.5)) 39 | # Far analogy 40 | human_analogy_data = human_data[human_data[:,analogy_vs_similarity]==0, :] 41 | human_analogy_acc = [] 42 | for s in range(N_subj): 43 | subj_data = human_analogy_data[human_analogy_data[:,subjID] == s+1, :] 44 | subj_acc = np.mean(subj_data[:,correct_pred]) 45 | human_analogy_acc.append(subj_acc) 46 | print('Far analogy:') 47 | print(ttest_1samp(human_analogy_acc, 0.5)) 48 | 49 | # GPT-3 data analyses 50 | gpt3_data = rows[rows[:,human_vs_gpt]==1, :] 51 | print('\nGPT-3:') 52 | print('Combined:') 53 | print(binomtest(gpt3_data[:,correct_pred].sum(), n=gpt3_data.shape[0], p=0.5)) 54 | # Near analogy 55 | gpt3_similarity_data = gpt3_data[gpt3_data[:,analogy_vs_similarity]==1, :] 56 | print('Near analogy:') 57 | print(binomtest(gpt3_similarity_data[:,correct_pred].sum(), n=gpt3_similarity_data.shape[0], p=0.5)) 58 | # Far analogy 59 | gpt3_analogy_data = gpt3_data[gpt3_data[:,analogy_vs_similarity]==0, :] 60 | print('Far analogy:') 61 | print(binomtest(gpt3_analogy_data[:,correct_pred].sum(), n=gpt3_analogy_data.shape[0], p=0.5)) 62 | 63 | # GPT-4 data 64 | human_data_file = open('./gpt4_data.csv') 65 | csvreader = csv.reader(human_data_file) 66 | header = [] 67 | header = np.array(next(csvreader)) 68 | rows = [] 69 | for row in csvreader: 70 | rows.append(row) 71 | rows = np.array(rows).astype(int) 72 | 73 | # GPT-4 data analyses 74 | gpt4_data = rows 75 | correct_pred = np.where(header == 'correct_pred')[0][0] 76 | print('\nGPT-4:') 77 | print('Combined:') 78 | print(binomtest(gpt4_data[:,correct_pred].sum(), n=gpt4_data.shape[0], p=0.5)) 79 | # Near analogy 80 | analogy_vs_similarity = np.where(header == 'analogy_vs_similarity')[0][0] 81 | gpt4_similarity_data = gpt4_data[gpt4_data[:,analogy_vs_similarity]==1, :] 82 | print('Near analogy:') 83 | print(binomtest(gpt4_similarity_data[:,correct_pred].sum(), n=gpt4_similarity_data.shape[0], p=0.5)) 84 | # Far analogy 85 | gpt4_analogy_data = gpt4_data[gpt4_data[:,analogy_vs_similarity]==0, :] 86 | print('Far analogy:') 87 | print(binomtest(gpt4_analogy_data[:,correct_pred].sum(), n=gpt4_analogy_data.shape[0], p=0.5)) 88 | print(' ') 89 | 90 | -------------------------------------------------------------------------------- /UCLA_VAT/eval_gpt_UCLA_VAT.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import numpy as np 3 | import pandas as pd 4 | import builtins 5 | import time 6 | 7 | # GPT-3 settings 8 | openai.api_key = "FILL_IN_API_KEY_HERE" 9 | kwargs = { "engine":"text-davinci-003", "temperature":0, "max_tokens":10, "stop":"\n", "echo":True, "logprobs":1, } 10 | 11 | # Load problems 12 | df = pd.read_excel (r'./UCLA_VAT.xlsx', sheet_name='UCLA_VAT') 13 | # Extract data 14 | A = builtins.list(df['A']) 15 | B = builtins.list(df['B']) 16 | C = builtins.list(df['C']) 17 | D = builtins.list(df['D']) 18 | D_prime = builtins.list(df["D'"]) 19 | 20 | # Initialize storage for results 21 | all_synonym_correct_pred = [] 22 | all_opposite_correct_pred = [] 23 | all_function_correct_pred = [] 24 | all_category_correct_pred = [] 25 | results_fname = './UCLA_VAT_results.npz' 26 | # Evaluate 27 | N_prob = len(A) 28 | prob_order = np.arange(N_prob) 29 | np.random.shuffle(prob_order) 30 | for p in range(N_prob): 31 | print(str(p+1) + ' of ' + str(N_prob) + '...') 32 | if p == 0: 33 | prompt = A[prob_order[p]] + ' : ' + B[prob_order[p]] + ' :: ' + C[prob_order[p]] + ' : ' 34 | else: 35 | prompt = context + '\n\n' + A[prob_order[p]] + ' : ' + B[prob_order[p]] + ' :: ' + C[prob_order[p]] + ' : ' 36 | # Correct answer 37 | d_prompt = prompt + D[prob_order[p]] 38 | response = [] 39 | while len(response) == 0: 40 | try: 41 | response = openai.Completion.create(prompt=d_prompt, **kwargs) 42 | except: 43 | print('trying again...') 44 | time.sleep(5) 45 | first_token_ind = np.where(np.array(response['choices'][0]['logprobs']['text_offset']) <= len(prompt))[0][-1] 46 | if len(d_prompt) > np.array(response['choices'][0]['logprobs']['text_offset'])[0]: 47 | d_avg_logprob = np.mean(response['choices'][0]['logprobs']['token_logprobs'][first_token_ind:]) 48 | else: 49 | last_token_ind = np.where(np.array(response['choices'][0]['logprobs']['text_offset']) == len(d_prompt))[0][0] 50 | d_avg_logprob = np.mean(response['choices'][0]['logprobs']['token_logprobs'][first_token_ind:last_token_ind]) 51 | # Foil 52 | d_prime_prompt = prompt + D_prime[prob_order[p]] 53 | response = [] 54 | while len(response) == 0: 55 | try: 56 | response = openai.Completion.create(prompt=d_prime_prompt, **kwargs) 57 | except: 58 | print('trying again...') 59 | time.sleep(5) 60 | first_token_ind = np.where(np.array(response['choices'][0]['logprobs']['text_offset']) <= len(prompt))[0][-1] 61 | if len(d_prime_prompt) > np.array(response['choices'][0]['logprobs']['text_offset'])[0]: 62 | d_prime_avg_logprob = np.mean(response['choices'][0]['logprobs']['token_logprobs'][first_token_ind:]) 63 | else: 64 | last_token_ind = np.where(np.array(response['choices'][0]['logprobs']['text_offset']) == len(d_prime_prompt))[0][0] 65 | d_prime_avg_logprob = np.mean(response['choices'][0]['logprobs']['token_logprobs'][first_token_ind:last_token_ind]) 66 | # Correct 67 | correct_pred = d_avg_logprob > d_prime_avg_logprob 68 | if prob_order[p] < 20: 69 | all_synonym_correct_pred.append(correct_pred) 70 | elif prob_order[p] >= 20 and prob_order[p] < 40: 71 | all_opposite_correct_pred.append(correct_pred) 72 | elif prob_order[p] >= 40 and prob_order[p] < 60: 73 | all_function_correct_pred.append(correct_pred) 74 | elif prob_order[p] >= 60: 75 | all_category_correct_pred.append(correct_pred) 76 | # Add problem to context 77 | if correct_pred: 78 | context = d_prompt 79 | else: 80 | context = d_prime_prompt 81 | # Save results 82 | np.savez(results_fname, synonym=all_synonym_correct_pred, opposite=all_opposite_correct_pred, function=all_function_correct_pred, category=all_category_correct_pred, context=context, prob_order=prob_order, allow_pickle=True) 83 | -------------------------------------------------------------------------------- /UCLA_VAT/analyze_UCLA_VAT.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from statsmodels.stats.proportion import proportion_confint 4 | from scipy.stats import sem 5 | import pandas as pd 6 | import builtins 7 | 8 | def hide_top_right(ax): 9 | ax.spines['right'].set_visible(False) 10 | ax.spines['top'].set_visible(False) 11 | ax.yaxis.set_ticks_position('left') 12 | ax.xaxis.set_ticks_position('bottom') 13 | 14 | def plot_ind_data(ax, x_points, data, ind_bar_width): 15 | max_count = 30 16 | point_unit = ind_bar_width / max_count 17 | # Plot 18 | for i in range(len(x_points)): 19 | unique_vals = np.unique(data[i]) 20 | for v in unique_vals: 21 | count = (data[i]==v).sum() 22 | span = count * point_unit 23 | x_min = x_points[i] - (span/2) 24 | x_max = x_points[i] + (span/2) 25 | x_vals = np.linspace(x_min,x_max,count) 26 | if count == 1: 27 | x_vals = np.mean([x_min,x_max]) 28 | if v == 0: 29 | y_vals = np.ones(count) * 0.005 30 | else: 31 | y_vals = np.ones(count) * v 32 | plt.scatter(x_vals, y_vals, color='black', s=0.4, clip_on=False, marker='_') 33 | return ax 34 | 35 | # Load data 36 | all_data = np.load('./UCLA_VAT_results.npz', allow_pickle=True) 37 | 38 | # Calculate accuracy and confidence intervals 39 | conditions = ['category', 'function', 'opposite', 'synonym'] 40 | N_conditions = len(conditions) 41 | all_acc = [] 42 | all_correct_pred = [] 43 | all_ci_lower = [] 44 | all_ci_upper = [] 45 | for c in range(N_conditions): 46 | correct_pred = all_data[conditions[c]] 47 | all_correct_pred.append(correct_pred) 48 | acc = correct_pred.mean() 49 | all_acc.append(acc) 50 | ci_lower, ci_upper = proportion_confint(correct_pred.sum(), correct_pred.shape[0]) 51 | all_ci_lower.append(ci_lower) 52 | all_ci_upper.append(ci_upper) 53 | # Combine conditions 54 | all_acc = np.array(all_acc) 55 | all_ci_lower = np.array(all_ci_lower) 56 | all_ci_upper = np.array(all_ci_upper) 57 | all_lower_err = all_acc - all_ci_lower 58 | all_upper_err = all_ci_upper - all_acc 59 | all_err = np.array([all_lower_err, all_upper_err]) 60 | # Overall accuracy 61 | all_correct_pred = np.concatenate(all_correct_pred) 62 | overall_gpt3_acc = all_correct_pred.astype(float).mean() 63 | overall_gpt3_ci_lower, overall_gpt3_ci_upper = proportion_confint(all_correct_pred.sum(), all_correct_pred.shape[0]) 64 | overall_gpt3_err = [overall_gpt3_acc - overall_gpt3_ci_lower, overall_gpt3_ci_upper - overall_gpt3_acc] 65 | 66 | # Human data 67 | df = pd.read_excel (r'./UCLA_VAT_ind_subj_data.xlsx', sheet_name='ind_subj') 68 | category_ind_subj_acc = np.array(builtins.list(df['category'])[1:]) / 100. 69 | function_ind_subj_acc = np.array(builtins.list(df['function'])[1:]) / 100. 70 | opposite_ind_subj_acc = np.array(builtins.list(df['opposite'])[1:]) / 100. 71 | synonym_ind_subj_acc = np.array(builtins.list(df['synonym'])[1:]) / 100. 72 | human_ind_subj = np.array([category_ind_subj_acc, function_ind_subj_acc, opposite_ind_subj_acc, synonym_ind_subj_acc]) 73 | human_acc = human_ind_subj.mean(1) 74 | human_err = sem(human_ind_subj,1) 75 | 76 | # Plot 77 | bar_width = 0.8 78 | ind_bar_width = bar_width / 2 79 | x_points = np.arange(N_conditions) 80 | gpt3_color = 'darkslateblue' 81 | human_color = 'powderblue' 82 | plot_fontsize = 14 83 | title_fontsize = 16 84 | ax = plt.subplot(111) 85 | plt.bar(x_points - (ind_bar_width/2), all_acc, yerr=all_err, color=gpt3_color, edgecolor='black', width=ind_bar_width, ecolor='gray') 86 | plt.bar(x_points + (ind_bar_width/2), human_acc, yerr=human_err, color=human_color, edgecolor='black', width=ind_bar_width) 87 | plt.ylim([0,1]) 88 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize) 89 | plt.ylabel('Accuracy', fontsize=plot_fontsize) 90 | plt.xticks(x_points, ['Categorical', 'Function', 'Antonym', 'Synonym'], fontsize=12) 91 | hide_top_right(ax) 92 | plt.legend(['GPT-3','Human'],fontsize=plot_fontsize,frameon=False, bbox_to_anchor=(0.9,1)) 93 | plot_ind_data(ax, x_points + (ind_bar_width/2), human_ind_subj, ind_bar_width) 94 | current_xlim = ax.get_xlim() 95 | plt.plot([current_xlim[0], current_xlim[1]],[0.5,0.5],color='black',alpha=0.4) 96 | plt.xlim([current_xlim[0], current_xlim[1]]) 97 | plt.title('UCLA VAT', fontsize=title_fontsize, pad=20) 98 | ax.set_aspect(4) 99 | plt.tight_layout() 100 | plt.savefig('./UCLA_VAT_results.pdf', dpi=300, bbox_inches="tight") 101 | plt.close() 102 | -------------------------------------------------------------------------------- /story_analogies/eval_GPT3_story_analogies.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import numpy as np 3 | import builtins 4 | import time 5 | 6 | # GPT-3 settings 7 | openai.api_key = "FILL_IN_API_KEY_HERE" 8 | kwargs = { "engine":"text-davinci-003", "temperature":0, "max_tokens":256, "echo":False, "logprobs":1, } 9 | 10 | # Load problems 11 | df = pd.read_excel (r'./Rattermann.xlsx', sheet_name='Rattermann') 12 | source_story = builtins.list(df['Base'])[1:19] 13 | true_analogy = builtins.list(df['True Analogy Story'])[1:19] 14 | false_analogy = builtins.list(df['False Analogy Story'])[1:19] 15 | literal_similarity = builtins.list(df['Literally similar story'])[1:19] 16 | mere_appearance = builtins.list(df['Mere-Appearance Match'])[1:19] 17 | 18 | # Initialize results 19 | all_analogy_results = [] 20 | all_analogy_correct = [] 21 | all_similarity_results = [] 22 | all_similarity_correct = [] 23 | N_source_stories = 18 24 | for s in range(N_source_stories): 25 | print('Source story ' + str(s+1) + ' of ' + str(N_source_stories) + '...') 26 | # A. True analogy B. False analogy 27 | print('True analogy vs. false analogy') 28 | print(' ') 29 | prompt = 'Consider the following story:\n\nStory 1: ' + source_story[s] + '\n\nNow consider two more stories:\n\nStory A: ' + true_analogy[s] + '\n\nStory B: ' + false_analogy[s] + '\n\n' 30 | prompt += 'Which of Story A and Story B is a better analogy to Story 1? Is the best answer Story A, Story B, or both are equally analogous?' 31 | print(prompt) 32 | response = [] 33 | while len(response) == 0: 34 | try: 35 | response = openai.Completion.create(prompt=prompt, **kwargs) 36 | except: 37 | print('trying again...') 38 | time.sleep(5) 39 | all_analogy_results.append(response['choices'][0]['text']) 40 | print('response: ' + response['choices'][0]['text']) 41 | all_analogy_correct.append('A') 42 | print('correct_answer: A') 43 | print(' ') 44 | # A. False analogy B. True analogy 45 | print('False analogy vs. true analogy') 46 | print(' ') 47 | prompt = 'Consider the following story:\n\nStory 1: ' + source_story[s] + '\n\nNow consider two more stories:\n\nStory A: ' + false_analogy[s] + '\n\nStory B: ' + true_analogy[s] + '\n\n' 48 | prompt += 'Which of Story A and Story B is a better analogy to Story 1? Is the best answer Story A, Story B, or both are equally analogous?' 49 | print(prompt) 50 | response = [] 51 | while len(response) == 0: 52 | try: 53 | response = openai.Completion.create(prompt=prompt, **kwargs) 54 | except: 55 | print('trying again...') 56 | time.sleep(5) 57 | all_analogy_results.append(response['choices'][0]['text']) 58 | print('response: ' + response['choices'][0]['text']) 59 | all_analogy_correct.append('B') 60 | print('correct_answer: B') 61 | print(' ') 62 | # A. Literal similarity B. Mere appearance 63 | print('Literal similarity vs. mere appearance') 64 | print(' ') 65 | prompt = 'Consider the following story:\n\nStory 1: ' + source_story[s] + '\n\nNow consider two more stories:\n\nStory A: ' + literal_similarity[s] + '\n\nStory B: ' + mere_appearance[s] + '\n\n' 66 | prompt += 'Which of Story A and Story B is a better analogy to Story 1? Is the best answer Story A, Story B, or both are equally analogous?' 67 | print(prompt) 68 | response = [] 69 | while len(response) == 0: 70 | try: 71 | response = openai.Completion.create(prompt=prompt, **kwargs) 72 | except: 73 | print('trying again...') 74 | time.sleep(5) 75 | all_similarity_results.append(response['choices'][0]['text']) 76 | print('response: ' + response['choices'][0]['text']) 77 | all_similarity_correct.append('A') 78 | print('correct_answer: A') 79 | print(' ') 80 | # A. Mere appearance B. Literal similarity 81 | print('Mere appearance vs. Literal similarity') 82 | print(' ') 83 | prompt = 'Consider the following story:\n\nStory 1: ' + source_story[s] + '\n\nNow consider two more stories:\n\nStory A: ' + mere_appearance[s] + '\n\nStory B: ' + literal_similarity[s] + '\n\n' 84 | prompt += 'Which of Story A and Story B is a better analogy to Story 1? Is the best answer Story A, Story B, or both are equally analogous?' 85 | print(prompt) 86 | response = [] 87 | while len(response) == 0: 88 | try: 89 | response = openai.Completion.create(prompt=prompt, **kwargs) 90 | except: 91 | print('trying again...') 92 | time.sleep(5) 93 | all_similarity_results.append(response['choices'][0]['text']) 94 | print('response: ' + response['choices'][0]['text']) 95 | all_similarity_correct.append('B') 96 | print('correct_answer: B') 97 | print(' ') 98 | # Save results 99 | np.savez('./gpt3_rattermann_results.npz', all_analogy_results=all_analogy_results, all_analogy_correct=all_analogy_correct, all_similarity_results=all_similarity_results, all_similarity_correct=all_similarity_correct) 100 | 101 | -------------------------------------------------------------------------------- /digit_mat/exp2_plot_GPT3_vs_human.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import os 4 | 5 | def check_path(path): 6 | if not os.path.exists(path): 7 | os.mkdir(path) 8 | 9 | def hide_top_right(ax): 10 | ax.spines['right'].set_visible(False) 11 | ax.spines['top'].set_visible(False) 12 | ax.yaxis.set_ticks_position('left') 13 | ax.xaxis.set_ticks_position('bottom') 14 | 15 | # Digit matrices 16 | # 1-rule, 2-rule, 3-rule, 4-rule, and 5-rule problems 17 | # GPT-3: progressive 18 | # Humans: progressive 19 | 20 | # Load GPT-3 data 21 | gpt3_gen_acc = np.load('./exp2_GPT3_data/all_probcat_gen_acc.npz') 22 | gpt3_MC_acc = np.load('./exp2_GPT3_data/all_probcat_MC_acc.npz') 23 | gpt3_gen_acc_mn = gpt3_gen_acc['acc'] 24 | gpt3_gen_acc_err = gpt3_gen_acc['err'] 25 | gpt3_MC_acc_mn = gpt3_MC_acc['acc'] 26 | gpt3_MC_acc_err = gpt3_MC_acc['err'] 27 | 28 | # Load human data 29 | human_gen_acc = np.load('./exp2_behavioral_data/probcat_gen_acc_behavior.npz') 30 | human_MC_acc = np.load('./exp2_behavioral_data/probcat_MC_acc_behavior.npz') 31 | human_gen_acc_mn = human_gen_acc['acc'] 32 | human_gen_acc_err = human_gen_acc['err'] 33 | human_MC_acc_mn = human_MC_acc['acc'] 34 | human_MC_acc_err = human_MC_acc['err'] 35 | 36 | # Plot settings 37 | N_conds = gpt3_MC_acc_mn.shape[0] 38 | total_bar_width = 0.8 39 | x_points = np.arange(N_conds) 40 | gpt3_color = 'darkslateblue' 41 | human_color = 'powderblue' 42 | plot_fontsize = 14 43 | title_fontsize = 16 44 | 45 | # Directory 46 | plot_dir = './exp2_GPT3_vs_human/' 47 | check_path(plot_dir) 48 | 49 | # Plot - generative 50 | ind_bar_width = total_bar_width / 2 51 | ax = plt.subplot(111) 52 | plt.bar(x_points - 0.2, gpt3_gen_acc_mn, yerr=gpt3_gen_acc_err, color=gpt3_color, edgecolor='black', width=ind_bar_width) 53 | plt.bar(x_points + 0.2, human_gen_acc_mn, yerr=human_gen_acc_err, color=human_color, edgecolor='black', width=ind_bar_width) 54 | plt.ylim([0,1]) 55 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize) 56 | plt.ylabel('Generative accuracy', fontsize=plot_fontsize) 57 | plt.xticks(x_points, ['1-rule', '2-rule', '3-rule', '4-rule', '5-rule'], fontsize=plot_fontsize) 58 | plt.xlabel('Problem type', fontsize=plot_fontsize) 59 | hide_top_right(ax) 60 | plt.legend(['GPT-3','Human'],fontsize=plot_fontsize,frameon=False,bbox_to_anchor=(0.95,1)) 61 | results_fname = plot_dir + 'gen_gpt3_vs_human.png' 62 | ax.set_aspect(3) 63 | plt.tight_layout() 64 | plt.savefig(results_fname, dpi=300, bbox_inches="tight") 65 | plt.close() 66 | 67 | # Plot - multiple-choice 68 | ax = plt.subplot(111) 69 | plt.bar(x_points - 0.2, gpt3_MC_acc_mn, yerr=gpt3_MC_acc_err, color=gpt3_color, edgecolor='black', width=ind_bar_width) 70 | plt.bar(x_points + 0.2, human_MC_acc_mn, yerr=human_MC_acc_err, color=human_color, edgecolor='black', width=ind_bar_width) 71 | plt.ylim([0,1]) 72 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize) 73 | plt.ylabel('Multiple choice accuracy', fontsize=plot_fontsize) 74 | plt.xticks(x_points, ['1-rule', '2-rule', '3-rule', '4-rule', '5-rule'], fontsize=plot_fontsize) 75 | plt.xlabel('Problem type', fontsize=plot_fontsize) 76 | hide_top_right(ax) 77 | plt.legend(['GPT-3','Human'],fontsize=plot_fontsize,frameon=False,bbox_to_anchor=(0.95,1)) 78 | results_fname = plot_dir + 'MC_gpt3_vs_human.png' 79 | ax.set_aspect(3) 80 | plt.tight_layout() 81 | plt.savefig(results_fname, dpi=300, bbox_inches="tight") 82 | plt.close() 83 | 84 | ## Rule type analysis for one-rule problems 85 | 86 | # Load GPT-3 one-rule data 87 | gpt3_gen_acc_onerule = np.load('./exp2_GPT3_data/all_onerule_gen_acc.npz') 88 | gpt3_gen_acc_onerule_mn = gpt3_gen_acc_onerule['acc'] 89 | gpt3_gen_acc_onerule_err = gpt3_gen_acc_onerule['err'] 90 | 91 | # Load human one-rule data 92 | human_gen_acc_onerule = np.load('./exp2_behavioral_data/probcat_gen_acc_behavior_onerule.npz') 93 | human_gen_acc_onerule_mn = human_gen_acc_onerule['acc'] 94 | human_gen_acc_onerule_err = human_gen_acc_onerule['err'] 95 | 96 | # Plot settings 97 | N_conds = gpt3_gen_acc_onerule_mn.shape[0] 98 | x_points = np.arange(N_conds) 99 | ind_bar_width = total_bar_width / 2 100 | 101 | # Plot - generative 102 | ax = plt.subplot(111) 103 | plt.bar(x_points - 0.2, gpt3_gen_acc_onerule_mn, yerr=gpt3_gen_acc_onerule_err, color=gpt3_color, edgecolor='black', width=ind_bar_width) 104 | plt.bar(x_points + 0.2, human_gen_acc_onerule_mn, yerr=human_gen_acc_onerule_err, color=human_color, edgecolor='black', width=ind_bar_width) 105 | plt.ylim([0,1]) 106 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize) 107 | plt.ylabel('Generative accuracy', fontsize=plot_fontsize) 108 | plt.xticks(x_points, ['Constant', 'Distribution', 'Progression'], fontsize=plot_fontsize) 109 | plt.xlabel('Rule type', fontsize=plot_fontsize) 110 | hide_top_right(ax) 111 | plt.legend(['GPT-3','Human'],fontsize=plot_fontsize,frameon=False,bbox_to_anchor=(0.65,1)) 112 | plt.title('One-rule problems', fontsize=title_fontsize) 113 | results_fname = plot_dir + 'onerule_gen_gpt3_vs_human.png' 114 | ax.set_aspect(3) 115 | plt.tight_layout() 116 | plt.savefig(results_fname, dpi=300, bbox_inches="tight") 117 | plt.close() 118 | 119 | 120 | -------------------------------------------------------------------------------- /digit_mat/analyze_gpt3_exp2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import builtins 4 | from statsmodels.stats.proportion import proportion_confint 5 | import os 6 | 7 | def check_path(path): 8 | if not os.path.exists(path): 9 | os.mkdir(path) 10 | 11 | def hide_top_right(ax): 12 | ax.spines['right'].set_visible(False) 13 | ax.spines['top'].set_visible(False) 14 | ax.yaxis.set_ticks_position('left') 15 | ax.xaxis.set_ticks_position('bottom') 16 | 17 | # Load data 18 | all_data = np.load('./gpt_matprob_results_1thru5.npz', allow_pickle=True) 19 | MC_correct_pred = all_data['all_MC_correct_pred'] 20 | gen_correct_pred = all_data['all_gen_correct_pred'] 21 | all_prob_types = builtins.list(MC_correct_pred.item().keys()) 22 | 23 | ## Analyze by major problem type 24 | correct_pred = {'combined_gen': [], 25 | 'combined_MC': [], 26 | 'one_rule_gen': [], 27 | 'one_rule_MC': [], 28 | 'two_rule_gen': [], 29 | 'two_rule_MC': [], 30 | 'three_rule_gen': [], 31 | 'three_rule_MC': [], 32 | 'four_rule_gen': [], 33 | 'four_rule_MC': [], 34 | 'five_rule_gen': [], 35 | 'five_rule_MC': []} 36 | for prob_type in all_prob_types: 37 | correct_pred['combined_gen'].append(gen_correct_pred.item()[prob_type]) 38 | correct_pred['combined_MC'].append(MC_correct_pred.item()[prob_type]) 39 | if 'constant' in prob_type or 'dist3' in prob_type or 'prog' in prob_type: 40 | correct_pred['one_rule_gen'].append(gen_correct_pred.item()[prob_type]) 41 | correct_pred['one_rule_MC'].append(MC_correct_pred.item()[prob_type]) 42 | elif 'two_rule' in prob_type: 43 | correct_pred['two_rule_gen'].append(gen_correct_pred.item()[prob_type]) 44 | correct_pred['two_rule_MC'].append(MC_correct_pred.item()[prob_type]) 45 | elif 'three_rule' in prob_type: 46 | correct_pred['three_rule_gen'].append(gen_correct_pred.item()[prob_type]) 47 | correct_pred['three_rule_MC'].append(MC_correct_pred.item()[prob_type]) 48 | elif 'four_rule' in prob_type: 49 | correct_pred['four_rule_gen'].append(gen_correct_pred.item()[prob_type]) 50 | correct_pred['four_rule_MC'].append(MC_correct_pred.item()[prob_type]) 51 | elif 'five_rule' in prob_type: 52 | correct_pred['five_rule_gen'].append(gen_correct_pred.item()[prob_type]) 53 | correct_pred['five_rule_MC'].append(MC_correct_pred.item()[prob_type]) 54 | # Convert to arrays 55 | for key in correct_pred.keys(): 56 | correct_pred[key] = np.concatenate(correct_pred[key]) 57 | # Calculate accuracy and confidence intervals 58 | all_acc = {} 59 | all_ci_lower = {} 60 | all_ci_upper = {} 61 | for key in correct_pred.keys(): 62 | all_acc[key] = correct_pred[key].mean() 63 | all_ci_lower[key], all_ci_upper[key] = proportion_confint(correct_pred[key].sum(), correct_pred[key].shape[0]) 64 | 65 | # Directory for saving results 66 | results_dir = './exp2_GPT3_data/' 67 | check_path(results_dir) 68 | 69 | # Save results 70 | # Generative 71 | all_gen_acc = np.array([all_acc['one_rule_gen'], all_acc['two_rule_gen'], all_acc['three_rule_gen'], all_acc['four_rule_gen'], all_acc['five_rule_gen']]) 72 | all_gen_lower_ci = np.array([all_ci_lower['one_rule_gen'], all_ci_lower['two_rule_gen'], all_ci_lower['three_rule_gen'], all_ci_lower['four_rule_gen'], all_ci_lower['five_rule_gen']]) 73 | all_gen_upper_ci = np.array([all_ci_upper['one_rule_gen'], all_ci_upper['two_rule_gen'], all_ci_upper['three_rule_gen'], all_ci_upper['four_rule_gen'], all_ci_upper['five_rule_gen']]) 74 | all_gen_lower_err = all_gen_acc - all_gen_lower_ci 75 | all_gen_upper_err = all_gen_upper_ci - all_gen_acc 76 | all_gen_err = np.array([all_gen_lower_err, all_gen_upper_err]) 77 | np.savez(results_dir + 'all_probcat_gen_acc.npz', acc=all_gen_acc, err=all_gen_err) 78 | # Multiple-choice 79 | all_MC_acc = np.array([all_acc['one_rule_MC'], all_acc['two_rule_MC'], all_acc['three_rule_MC'], all_acc['four_rule_MC'], all_acc['five_rule_MC']]) 80 | all_MC_lower_ci = np.array([all_ci_lower['one_rule_MC'], all_ci_lower['two_rule_MC'], all_ci_lower['three_rule_MC'], all_ci_lower['four_rule_MC'], all_ci_lower['five_rule_MC']]) 81 | all_MC_upper_ci = np.array([all_ci_upper['one_rule_MC'], all_ci_upper['two_rule_MC'], all_ci_upper['three_rule_MC'], all_ci_upper['four_rule_MC'], all_ci_upper['five_rule_MC']]) 82 | all_MC_lower_err = all_MC_acc - all_MC_lower_ci 83 | all_MC_upper_err = all_MC_upper_ci - all_MC_acc 84 | all_MC_err = np.array([all_MC_lower_err, all_MC_upper_err]) 85 | np.savez(results_dir + 'all_probcat_MC_acc.npz', acc=all_MC_acc, err=all_MC_err) 86 | 87 | ## Three major one-rule problem types 88 | correct_pred = {'constant_gen': [], 89 | 'constant_MC': [], 90 | 'dist3_gen': [], 91 | 'dist3_MC': [], 92 | 'prog_gen': [], 93 | 'prog_MC': []} 94 | for prob_type in all_prob_types: 95 | if 'constant' in prob_type: 96 | correct_pred['constant_gen'].append(gen_correct_pred.item()[prob_type]) 97 | correct_pred['constant_MC'].append(MC_correct_pred.item()[prob_type]) 98 | elif 'dist3' in prob_type: 99 | correct_pred['dist3_gen'].append(gen_correct_pred.item()[prob_type]) 100 | correct_pred['dist3_MC'].append(MC_correct_pred.item()[prob_type]) 101 | elif 'prog' in prob_type: 102 | correct_pred['prog_gen'].append(gen_correct_pred.item()[prob_type]) 103 | correct_pred['prog_MC'].append(MC_correct_pred.item()[prob_type]) 104 | # Convert to arrays 105 | for key in correct_pred.keys(): 106 | correct_pred[key] = np.concatenate(correct_pred[key]) 107 | # Calculate accuracy and confidence intervals 108 | all_acc = {} 109 | all_ci_lower = {} 110 | all_ci_upper = {} 111 | for key in correct_pred.keys(): 112 | all_acc[key] = correct_pred[key].mean() 113 | all_ci_lower[key], all_ci_upper[key] = proportion_confint(correct_pred[key].sum(), correct_pred[key].shape[0]) 114 | 115 | # Save results 116 | # Generative 117 | all_gen_acc = np.array([all_acc['constant_gen'], all_acc['dist3_gen'], all_acc['prog_gen']]) 118 | all_gen_lower_ci = np.array([all_ci_lower['constant_gen'], all_ci_lower['dist3_gen'], all_ci_lower['prog_gen']]) 119 | all_gen_upper_ci = np.array([all_ci_upper['constant_gen'], all_ci_upper['dist3_gen'], all_ci_upper['prog_gen']]) 120 | all_gen_lower_err = all_gen_acc - all_gen_lower_ci 121 | all_gen_upper_err = all_gen_upper_ci - all_gen_acc 122 | all_gen_err = np.array([all_gen_lower_err, all_gen_upper_err]) 123 | np.savez(results_dir + 'all_onerule_gen_acc.npz', acc=all_gen_acc, err=all_gen_err) 124 | # Multiple-choice 125 | all_MC_acc = np.array([all_acc['constant_MC'], all_acc['dist3_MC'], all_acc['prog_MC']]) 126 | all_MC_lower_ci = np.array([all_ci_lower['constant_MC'], all_ci_lower['dist3_MC'], all_ci_lower['prog_MC']]) 127 | all_MC_upper_ci = np.array([all_ci_upper['constant_MC'], all_ci_upper['dist3_MC'], all_ci_upper['prog_MC']]) 128 | all_MC_lower_err = all_MC_acc - all_MC_lower_ci 129 | all_MC_upper_err = all_MC_upper_ci - all_MC_acc 130 | all_MC_err = np.array([all_MC_lower_err, all_MC_upper_err]) 131 | np.savez(results_dir + 'all_onerule_MC_acc.npz', acc=all_MC_acc, err=all_MC_err) 132 | 133 | -------------------------------------------------------------------------------- /digit_mat/combine_problems_1thru5.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | 4 | # Save problems as json and numpy files 5 | def save_prob(all_prob, all_answer_choices, all_correct_ind, prob_type_name, all_problems_np, all_problems_js, perm_invariant=False): 6 | # Add problems to numpy dict 7 | all_problems_np[prob_type_name] = {'prob': all_prob, 'answer_choices': all_answer_choices, 'correct_ind': all_correct_ind, 'perm_invariant': perm_invariant} 8 | # Convert to strings and save as json 9 | all_data = [] 10 | for p in range(all_prob.shape[0]): 11 | # Convert problem to string 12 | prompt = '' 13 | for r in range(3): 14 | for c in range(3): 15 | prompt += '[' 16 | if r == 2 and c == 2: 17 | for i in range(len(all_prob[p][1][c])): 18 | prompt += '  ' 19 | if i < (len(all_prob[p][1][c]) - 1): 20 | prompt += ' ' 21 | else: 22 | for i in range(len(all_prob[p][r][c])): 23 | if all_prob[p][r][c][i] == -1: 24 | prompt += '  ' 25 | else: 26 | prompt += str(all_prob[p][r][c][i]) 27 | 28 | if i < (len(all_prob[p][r][c]) - 1): 29 | prompt += ' ' 30 | prompt += ']' 31 | if c < 2: 32 | prompt += '   ' 33 | if r < 2 and c == 2: 34 | prompt += '
' 35 | # Convert choices to strings 36 | options = [] 37 | for a in range(8): 38 | option = '[' 39 | for i in range(len(all_answer_choices[p][a])): 40 | option += str(all_answer_choices[p][a][i]) 41 | if i < (len(all_answer_choices[p][a]) - 1): 42 | option += ' ' 43 | if len(all_answer_choices[p][a]) == 0: 44 | option += '  ' 45 | option += ']' 46 | options.append(option) 47 | # Add to dataset 48 | all_data.append({'prompt': prompt, 'options': options, 'correct': int(all_correct_ind[p]), 'prob_ind': p}) 49 | # Add to javascript data 50 | all_problems_js[prob_type_name] = all_data 51 | return all_problems_np, all_problems_js 52 | 53 | # Number of problems per category 54 | max_N_prob_per_cat = 10 55 | N_instances_per_prob = 100 56 | # All categories 57 | all_cat = ['one_rule', 'two_rule', 'three_rule', 'four_rule', 'five_rule'] 58 | 59 | # Load 1-3-rule problems 60 | one_3_rule_prob_fname = './all_problems.npz' 61 | one_3_rule_prob = np.load(one_3_rule_prob_fname, allow_pickle=True)['all_problems'] 62 | 63 | # Load 4-5-rule problems 64 | four_5_rule_prob_fname = './all_4_5_rule_problems.npz' 65 | four_5_rule_prob = np.load(four_5_rule_prob_fname, allow_pickle=True)['all_problems'] 66 | 67 | # Subsample problems and save as js script, also as numpy file 68 | all_problems_np = {} 69 | all_problems_js = {} 70 | # 1-rule problems 71 | one_rule_prob_names = ['row_constant', 'col_constant', 'dist3_diag1', 'dist3_diag2', 'prog_size1', 'prog_size2'] 72 | for prob_name in one_rule_prob_names: 73 | all_problems_np, all_problems_js = save_prob(one_3_rule_prob.item()[prob_name]['prob'], one_3_rule_prob.item()[prob_name]['answer_choices'], one_3_rule_prob.item()[prob_name]['correct_ind'], prob_name, all_problems_np, all_problems_js) 74 | # 2-rule problems 75 | for prob_name in one_3_rule_prob.item().keys(): 76 | if 'two_rule' in prob_name: 77 | all_problems_np, all_problems_js = save_prob(one_3_rule_prob.item()[prob_name]['prob'], one_3_rule_prob.item()[prob_name]['answer_choices'], one_3_rule_prob.item()[prob_name]['correct_ind'], prob_name, all_problems_np, all_problems_js) 78 | # 3-rule problems 79 | for prob_name in one_3_rule_prob.item().keys(): 80 | if 'three_rule' in prob_name: 81 | all_problems_np, all_problems_js = save_prob(one_3_rule_prob.item()[prob_name]['prob'], one_3_rule_prob.item()[prob_name]['answer_choices'], one_3_rule_prob.item()[prob_name]['correct_ind'], prob_name, all_problems_np, all_problems_js) 82 | # 4-rule problems 83 | all_four_rule_prob_names = [] 84 | for prob_name in four_5_rule_prob.item().keys(): 85 | if 'four_rule' in prob_name: 86 | all_four_rule_prob_names.append(prob_name) 87 | all_sampled_probs = [] 88 | for prob in range(max_N_prob_per_cat): 89 | # Sample problem instances 90 | prob_instances = [] 91 | answer_choices = [] 92 | correct_ind = [] 93 | for i in range(N_instances_per_prob): 94 | duplicate = True 95 | while duplicate: 96 | prob_type = all_four_rule_prob_names[np.floor(np.random.rand() * len(all_four_rule_prob_names)).astype(int)] 97 | prob_ind = np.floor(np.random.rand() * N_instances_per_prob).astype(int) 98 | combined_name = prob_type + '_' + str(prob_ind) 99 | if np.logical_not(np.any(combined_name == np.array(all_sampled_probs))): 100 | duplicate = False 101 | all_sampled_probs.append(combined_name) 102 | prob_instances.append(four_5_rule_prob.item()[prob_type]['prob'][prob_ind]) 103 | answer_choices.append(four_5_rule_prob.item()[prob_type]['answer_choices'][prob_ind]) 104 | correct_ind.append(four_5_rule_prob.item()[prob_type]['correct_ind'][prob_ind]) 105 | prob_instances = np.array(prob_instances) 106 | answer_choices = np.array(answer_choices) 107 | correct_ind = np.array(correct_ind) 108 | all_problems_np, all_problems_js = save_prob(prob_instances, answer_choices, correct_ind, 'four_rule_prob' + str(prob), all_problems_np, all_problems_js) 109 | # 5-rule problems 110 | all_five_rule_prob_names = [] 111 | for prob_name in four_5_rule_prob.item().keys(): 112 | if 'five_rule' in prob_name: 113 | all_five_rule_prob_names.append(prob_name) 114 | all_sampled_probs = [] 115 | for prob in range(max_N_prob_per_cat): 116 | # Sample problem instances 117 | prob_instances = [] 118 | answer_choices = [] 119 | correct_ind = [] 120 | for i in range(N_instances_per_prob): 121 | duplicate = True 122 | while duplicate: 123 | prob_type = all_five_rule_prob_names[np.floor(np.random.rand() * len(all_five_rule_prob_names)).astype(int)] 124 | prob_ind = np.floor(np.random.rand() * N_instances_per_prob).astype(int) 125 | combined_name = prob_type + '_' + str(prob_ind) 126 | if np.logical_not(np.any(combined_name == np.array(all_sampled_probs))): 127 | duplicate = False 128 | all_sampled_probs.append(combined_name) 129 | prob_instances.append(four_5_rule_prob.item()[prob_type]['prob'][prob_ind]) 130 | answer_choices.append(four_5_rule_prob.item()[prob_type]['answer_choices'][prob_ind]) 131 | correct_ind.append(four_5_rule_prob.item()[prob_type]['correct_ind'][prob_ind]) 132 | prob_instances = np.array(prob_instances) 133 | answer_choices = np.array(answer_choices) 134 | correct_ind = np.array(correct_ind) 135 | all_problems_np, all_problems_js = save_prob(prob_instances, answer_choices, correct_ind, 'five_rule_prob' + str(prob), all_problems_np, all_problems_js) 136 | # Save numpy file 137 | np_fname = './all_problems_1thru5.npz' 138 | np.savez(np_fname, all_problems=all_problems_np) 139 | # Convert to json string 140 | json_string = json.dumps(all_problems_js) 141 | # Write to js script 142 | js_fname = './all_problems_1thru5.js' 143 | js_fid = open(js_fname, 'w') 144 | js_fid.write('var all_problems = ' + json_string) 145 | js_fid.close() 146 | 147 | -------------------------------------------------------------------------------- /digit_mat/exp1_vs_exp2_create_stats_dset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import csv 3 | import builtins 4 | 5 | # Load human behavioral data for experiment 1 6 | human_data = np.load('./exp1_behavioral_data/ind_subj_results.npz') 7 | all_human_gen_correct_pred = human_data['all_subj_gen_correct_pred'] 8 | all_human_MC_correct_pred = human_data['all_subj_MC_correct_pred'] 9 | N_subj_exp1 = all_human_gen_correct_pred.shape[0] 10 | N_prob = all_human_gen_correct_pred.shape[1] 11 | # Create dataset 12 | subjID = [] 13 | gen_correct_pred = [] 14 | MC_correct_pred = [] 15 | human_vs_gpt = [] 16 | prob_type = [] 17 | exp1_vs_exp2 = [] 18 | onerule_subtype = [] 19 | for s in range(N_subj_exp1): 20 | for p in range(N_prob): 21 | if p < 22: 22 | # Subject ID 23 | subjID.append(s) 24 | # Correct prediction 25 | gen_correct_pred.append(all_human_gen_correct_pred[s,p]) 26 | MC_correct_pred.append(all_human_MC_correct_pred[s,p]) 27 | # Human vs. GPT-3 28 | human_vs_gpt.append(0) 29 | # Experiment 1 vs. experiment 4 30 | exp1_vs_exp2.append(0) 31 | # Problem-type specific variables 32 | # One-rule problems 33 | if p < 6: 34 | # Problem type 35 | prob_type.append(0) 36 | # One-rule subtypes 37 | if p < 2: 38 | onerule_subtype.append(0) 39 | elif p == 2 or p == 3: 40 | onerule_subtype.append(1) 41 | elif p == 4 or p == 5: 42 | onerule_subtype.append(2) 43 | # Two-rule problems 44 | elif p >= 6 and p < 12: 45 | # Problem type 46 | prob_type.append(1) 47 | # Dummy code one-rule subtypes 48 | onerule_subtype.append(-1) 49 | # Three-rule problems 50 | elif p >= 12 and p < 22: 51 | # Problem type 52 | prob_type.append(2) 53 | # Dummy code one-rule subtypes 54 | onerule_subtype.append(-1) 55 | # Load GPT-3 data for experiment 1 56 | gpt3_data = np.load('./gpt_matprob_results.npz', allow_pickle=True) 57 | prob_type_names = builtins.list(gpt3_data['all_gen_correct_pred'].item().keys()) 58 | # Add to dataset 59 | for p in range(N_prob): 60 | # Problem type name 61 | prob_type_name = prob_type_names[p] 62 | # Loop through all trials for this problem type 63 | N_trials = len(gpt3_data['all_gen_correct_pred'].item()[prob_type_name]) 64 | for t in range(N_trials): 65 | if p < 22: 66 | # Subject ID 67 | subjID.append(N_subj_exp1) 68 | # Correct prediction 69 | gen_correct_pred.append(int(gpt3_data['all_gen_correct_pred'].item()[prob_type_name][t])) 70 | MC_correct_pred.append(int(gpt3_data['all_MC_correct_pred'].item()[prob_type_name][t])) 71 | # Human vs. GPT-3 72 | human_vs_gpt.append(1) 73 | # Experiment 1 vs. experiment 4 74 | exp1_vs_exp2.append(0) 75 | # Problem-type specific variables 76 | # One-rule problems 77 | if p < 6: 78 | # Problem type 79 | prob_type.append(0) 80 | # One-rule subtypes 81 | if p < 2: 82 | onerule_subtype.append(0) 83 | elif p == 2 or p == 3: 84 | onerule_subtype.append(1) 85 | elif p == 4 or p == 5: 86 | onerule_subtype.append(2) 87 | # Two-rule problems 88 | elif p >= 6 and p < 12: 89 | # Problem type 90 | prob_type.append(1) 91 | # Dummy code one-rule subtypes 92 | onerule_subtype.append(-1) 93 | # Three-rule problems 94 | elif p >= 12 and p < 22: 95 | # Problem type 96 | prob_type.append(2) 97 | # Dummy code one-rule subtypes 98 | onerule_subtype.append(-1) 99 | 100 | # Load human behavioral data for experiment 2 101 | human_data = np.load('./exp2_behavioral_data/ind_subj_results.npz') 102 | all_human_gen_correct_pred = human_data['all_subj_gen_correct_pred'] 103 | all_human_MC_correct_pred = human_data['all_subj_MC_correct_pred'] 104 | N_subj_exp2 = all_human_gen_correct_pred.shape[0] 105 | N_prob = all_human_gen_correct_pred.shape[1] 106 | for s in range(N_subj_exp2): 107 | for p in range(N_prob): 108 | if p < 22: 109 | # Subject ID 110 | subjID.append(N_subj_exp1 + 1 + s) 111 | # Correct prediction 112 | gen_correct_pred.append(all_human_gen_correct_pred[s,p]) 113 | MC_correct_pred.append(all_human_MC_correct_pred[s,p]) 114 | # Human vs. GPT-3 115 | human_vs_gpt.append(0) 116 | # Experiment 1 vs. experiment 4 117 | exp1_vs_exp2.append(1) 118 | # Problem-type specific variables 119 | # One-rule problems 120 | if p < 6: 121 | # Problem type 122 | prob_type.append(0) 123 | # One-rule subtypes 124 | if p < 2: 125 | onerule_subtype.append(0) 126 | elif p == 2 or p == 3: 127 | onerule_subtype.append(1) 128 | elif p == 4 or p == 5: 129 | onerule_subtype.append(2) 130 | # Two-rule problems 131 | elif p >= 6 and p < 12: 132 | # Problem type 133 | prob_type.append(1) 134 | # Dummy code one-rule subtypes 135 | onerule_subtype.append(-1) 136 | # Three-rule problems 137 | elif p >= 12 and p < 22: 138 | # Problem type 139 | prob_type.append(2) 140 | # Dummy code one-rule subtypes 141 | onerule_subtype.append(-1) 142 | 143 | # Load GPT-3 data 144 | gpt3_data = np.load('./gpt_matprob_results_1thru5.npz', allow_pickle=True) 145 | prob_type_names = builtins.list(gpt3_data['all_gen_correct_pred'].item().keys()) 146 | # Add to dataset 147 | for p in range(N_prob): 148 | # Problem type name 149 | prob_type_name = prob_type_names[p] 150 | # Loop through all trials for this problem type 151 | N_trials = len(gpt3_data['all_gen_correct_pred'].item()[prob_type_name]) 152 | for t in range(N_trials): 153 | if p < 22: 154 | # Subject ID 155 | subjID.append(N_subj_exp1) 156 | # Correct prediction 157 | gen_correct_pred.append(int(gpt3_data['all_gen_correct_pred'].item()[prob_type_name][t])) 158 | MC_correct_pred.append(int(gpt3_data['all_MC_correct_pred'].item()[prob_type_name][t])) 159 | # Human vs. GPT-3 160 | human_vs_gpt.append(1) 161 | # Experiment 1 vs. experiment 2 162 | exp1_vs_exp2.append(1) 163 | # Problem-type specific variables 164 | # One-rule problems 165 | if p < 6: 166 | # Problem type 167 | prob_type.append(0) 168 | # One-rule subtypes 169 | if p < 2: 170 | onerule_subtype.append(0) 171 | elif p == 2 or p == 3: 172 | onerule_subtype.append(1) 173 | elif p == 4 or p == 5: 174 | onerule_subtype.append(2) 175 | # Two-rule problems 176 | elif p >= 6 and p < 12: 177 | # Problem type 178 | prob_type.append(1) 179 | # Dummy code one-rule subtypes 180 | onerule_subtype.append(-1) 181 | # Three-rule problems 182 | elif p >= 12 and p < 22: 183 | # Problem type 184 | prob_type.append(2) 185 | # Dummy code one-rule subtypes 186 | onerule_subtype.append(-1) 187 | 188 | # Write csv file 189 | # Create file 190 | f = open('./exp1_vs_exp2_all_data.csv', 'w') 191 | writer = csv.writer(f) 192 | # Header 193 | header = ['subjID', 'gen_correct_pred', 'MC_correct_pred', 'human_vs_gpt', 'exp1_vs_exp2', 'prob_type', 'onerule_subtype'] 194 | writer.writerow(header) 195 | # Write data 196 | for i in range(len(subjID)): 197 | data_row = [subjID[i], gen_correct_pred[i], MC_correct_pred[i], human_vs_gpt[i], exp1_vs_exp2[i], prob_type[i], onerule_subtype[i]] 198 | writer.writerow(data_row) 199 | # Close file 200 | f.close() 201 | 202 | 203 | 204 | 205 | 206 | 207 | -------------------------------------------------------------------------------- /digit_mat/exp1_create_stats_dset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import csv 3 | import builtins 4 | 5 | # Load human behavioral data 6 | human_data = np.load('./exp1_behavioral_data/ind_subj_results.npz') 7 | all_human_gen_correct_pred = human_data['all_subj_gen_correct_pred'] 8 | all_human_MC_correct_pred = human_data['all_subj_MC_correct_pred'] 9 | N_subj = all_human_gen_correct_pred.shape[0] 10 | N_prob = all_human_gen_correct_pred.shape[1] 11 | # Load data about # of unique relations in each problem 12 | N_unique_rules_2rule_prob = np.load('./N_unique_rules_2rule_prob.npz')['N_unique_rules'] 13 | N_unique_rules_3rule_prob = np.load('./N_unique_rules_3rule_prob.npz')['N_unique_rules'] 14 | # Create dataset 15 | subjID = [] 16 | gen_correct_pred = [] 17 | MC_correct_pred = [] 18 | human_vs_gpt = [] 19 | N_unique_rules = [] 20 | prob_type = [] 21 | onerule_rule_type = [] 22 | tworule_prog_noprog = [] 23 | aligned_permuted = [] 24 | for s in range(N_subj): 25 | for p in range(N_prob): 26 | # Subject ID 27 | subjID.append(s) 28 | # Correct prediction 29 | gen_correct_pred.append(all_human_gen_correct_pred[s,p]) 30 | MC_correct_pred.append(all_human_MC_correct_pred[s,p]) 31 | # Human vs. GPT-3 32 | human_vs_gpt.append(0) 33 | # Problem-type specific variables 34 | # One-rule problems 35 | if p < 6: 36 | # Problem type 37 | prob_type.append(0) 38 | # Number of unique rules 39 | N_unique_rules.append(0) 40 | # Rule type 41 | # Constant rule 42 | if p == 0 or p == 1: 43 | onerule_rule_type.append(0) 44 | # Distribution-of-3 rule 45 | elif p == 2 or p == 3: 46 | onerule_rule_type.append(1) 47 | # Progression rule 48 | elif p == 4 or p == 5: 49 | onerule_rule_type.append(2) 50 | # Dummy code for aligned vs. permuted (logic rules only) 51 | aligned_permuted.append(-1) 52 | # Dummy code for two-rule progression vs. no progression 53 | tworule_prog_noprog.append(-1) 54 | # Two-rule problems 55 | elif p >= 6 and p < 12: 56 | # Problem type 57 | prob_type.append(1) 58 | # Dummy code for one-rule rule type 59 | onerule_rule_type.append(-1) 60 | # Dummy code for aligned vs. permuted (logic rules only) 61 | aligned_permuted.append(-1) 62 | # Number of unique rules 63 | N_unique_rules.append(N_unique_rules_2rule_prob[p-6] - 1) 64 | # Progression rule present 65 | if p == 6 or p == 7 or p == 9: 66 | tworule_prog_noprog.append(0) 67 | elif p == 8 or p == 10 or p == 11: 68 | tworule_prog_noprog.append(1) 69 | # Three-rule problems 70 | elif p >= 12 and p < 22: 71 | # Problem type 72 | prob_type.append(2) 73 | # Dummy code for one-rule rule type 74 | onerule_rule_type.append(-1) 75 | # Dummy code for aligned vs. permuted (logic rules only) 76 | aligned_permuted.append(-1) 77 | # Number of unique rules 78 | N_unique_rules.append(N_unique_rules_3rule_prob[p-12] - 1) 79 | # Dummy code for two-rule progression vs. no progression 80 | tworule_prog_noprog.append(-1) 81 | # Logic problems 82 | elif p >= 22: 83 | # Problem type 84 | prob_type.append(3) 85 | # Dummy code for one-rule rule type 86 | onerule_rule_type.append(-1) 87 | # Dummy code for two-rule progression vs. no progression 88 | tworule_prog_noprog.append(-1) 89 | # Number of unique rules 90 | N_unique_rules.append(0) 91 | # Spatially aligned elements 92 | if p < 27: 93 | aligned_permuted.append(0) 94 | else: 95 | aligned_permuted.append(1) 96 | # Load GPT-3 data 97 | gpt3_data = np.load('./gpt_matprob_results.npz', allow_pickle=True) 98 | prob_type_names = builtins.list(gpt3_data['all_gen_correct_pred'].item().keys()) 99 | # Add to dataset 100 | for p in range(N_prob): 101 | # Problem type name 102 | prob_type_name = prob_type_names[p] 103 | # Loop through all trials for this problem type 104 | N_trials = len(gpt3_data['all_gen_correct_pred'].item()[prob_type_name]) 105 | for t in range(N_trials): 106 | # Subject ID 107 | subjID.append(N_subj) 108 | # Correct prediction 109 | gen_correct_pred.append(int(gpt3_data['all_gen_correct_pred'].item()[prob_type_name][t])) 110 | MC_correct_pred.append(int(gpt3_data['all_MC_correct_pred'].item()[prob_type_name][t])) 111 | # Human vs. GPT-3 112 | human_vs_gpt.append(1) 113 | # Problem-type specific variables 114 | # One-rule problems 115 | if p < 6: 116 | # Problem type 117 | prob_type.append(0) 118 | # Number of unique rules 119 | N_unique_rules.append(0) 120 | # Rule type 121 | # Constant rule 122 | if p == 0 or p == 1: 123 | onerule_rule_type.append(0) 124 | # Distribution-of-3 rule 125 | elif p == 2 or p == 3: 126 | onerule_rule_type.append(1) 127 | # Progression rule 128 | elif p == 4 or p == 5: 129 | onerule_rule_type.append(2) 130 | # Dummy code for aligned vs. permuted (logic rules only) 131 | aligned_permuted.append(-1) 132 | # Dummy code for two-rule progression vs. no progression 133 | tworule_prog_noprog.append(-1) 134 | # Two-rule problems 135 | elif p >= 6 and p < 12: 136 | # Problem type 137 | prob_type.append(1) 138 | # Dummy code for one-rule rule type 139 | onerule_rule_type.append(-1) 140 | # Number of unique rules 141 | N_unique_rules.append(N_unique_rules_2rule_prob[p-6] - 1) 142 | # Dummy code for aligned vs. permuted (logic rules only) 143 | aligned_permuted.append(-1) 144 | # Progression rule present 145 | if p == 6 or p == 7 or p == 9: 146 | tworule_prog_noprog.append(0) 147 | elif p == 8 or p == 10 or p == 11: 148 | tworule_prog_noprog.append(1) 149 | # Three-rule problems 150 | elif p >= 12 and p < 22: 151 | # Problem type 152 | prob_type.append(2) 153 | # Dummy code for one-rule rule type 154 | onerule_rule_type.append(-1) 155 | # Number of unique rules 156 | N_unique_rules.append(N_unique_rules_3rule_prob[p-12] - 1) 157 | # Dummy code for aligned vs. permuted (logic rules only) 158 | aligned_permuted.append(-1) 159 | # Dummy code for two-rule progression vs. no progression 160 | tworule_prog_noprog.append(-1) 161 | # Logic problems 162 | elif p >= 22: 163 | # Problem type 164 | prob_type.append(3) 165 | # Dummy code for one-rule rule type 166 | onerule_rule_type.append(-1) 167 | # Dummy code for two-rule progression vs. no progression 168 | tworule_prog_noprog.append(-1) 169 | # Number of unique rules 170 | N_unique_rules.append(0) 171 | # Spatially aligned elements 172 | if p < 27: 173 | aligned_permuted.append(0) 174 | else: 175 | aligned_permuted.append(1) 176 | 177 | # Write csv file 178 | # Create file 179 | f = open('./exp1_all_data.csv', 'w') 180 | writer = csv.writer(f) 181 | # Header 182 | header = ['subjID', 'gen_correct_pred', 'MC_correct_pred', 'human_vs_gpt', 'N_unique_rules', 'prob_type', 'onerule_rule_type', 'aligned_permuted', 'tworule_prog_noprog'] 183 | writer.writerow(header) 184 | # Write data 185 | for i in range(len(subjID)): 186 | data_row = [subjID[i], gen_correct_pred[i], MC_correct_pred[i], human_vs_gpt[i], N_unique_rules[i], prob_type[i], onerule_rule_type[i], aligned_permuted[i], tworule_prog_noprog[i]] 187 | writer.writerow(data_row) 188 | # Close file 189 | f.close() 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | -------------------------------------------------------------------------------- /digit_mat/eval_gpt_matprob.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import numpy as np 3 | import builtins 4 | import os 5 | 6 | # Split word into characters 7 | def split(word): 8 | return [char for char in word] 9 | 10 | # Load all problems 11 | all_prob = np.load('./all_problems.npz', allow_pickle=True) 12 | 13 | # GPT-3 settings 14 | openai.api_key = "FILL_IN_API_KEY_HERE" 15 | kwargs = { "engine":"text-davinci-003", "temperature":0, "max_tokens":10, "stop":"\n", "echo":True, "logprobs":1, } 16 | 17 | # Loop through all problem types 18 | all_prob_types = builtins.list(all_prob['all_problems'].item().keys()) 19 | # Load data if it already exists 20 | all_data_fname = './gpt_matprob_results.npz' 21 | if os.path.exists(all_data_fname): 22 | data_exists = True 23 | all_data = np.load('./gpt_matprob_results.npz', allow_pickle=True) 24 | else: 25 | data_exists = False 26 | # Create data structure for storing results 27 | all_gen_pred = {} 28 | all_gen_correct_pred = {} 29 | all_MC_pred = {} 30 | all_MC_correct_pred = {} 31 | all_alt_MC_correct_pred = {} 32 | for p in range(len(all_prob_types)): 33 | # Problem type 34 | prob_type = all_prob_types[p] 35 | # Load data 36 | if data_exists: 37 | all_gen_pred[prob_type] = all_data['all_gen_pred'].item()[prob_type] 38 | all_gen_correct_pred[prob_type] = all_data['all_gen_correct_pred'].item()[prob_type] 39 | all_MC_pred[prob_type] = all_data['all_MC_pred'].item()[prob_type] 40 | all_MC_correct_pred[prob_type] = all_data['all_MC_correct_pred'].item()[prob_type] 41 | all_alt_MC_correct_pred[prob_type] = all_data['all_alt_MC_correct_pred'].item()[prob_type] 42 | # Create data structure 43 | else: 44 | all_gen_pred[prob_type] = [] 45 | all_gen_correct_pred[prob_type] = [] 46 | all_MC_pred[prob_type] = [] 47 | all_MC_correct_pred[prob_type] = [] 48 | all_alt_MC_correct_pred[prob_type] = [] 49 | # Loop over all problem indices 50 | N_prob = 40 51 | for prob_ind in range(N_prob): 52 | print(str(prob_ind + 1) + ' of ' + str(N_prob) + '...') 53 | # Loop over all problem types 54 | for p in range(len(all_prob_types)): 55 | # Problem type 56 | prob_type = all_prob_types[p] 57 | print('Problem type: ' + prob_type + '...') 58 | perm_invariant = all_prob['all_problems'].item()[prob_type]['perm_invariant'] 59 | prob_type_N_prob = all_prob['all_problems'].item()[prob_type]['prob'].shape[0] 60 | if prob_ind < prob_type_N_prob and len(all_gen_correct_pred[prob_type]) <= prob_ind: 61 | 62 | # Problem 63 | prob = all_prob['all_problems'].item()[prob_type]['prob'][prob_ind] 64 | answer_choices = all_prob['all_problems'].item()[prob_type]['answer_choices'][prob_ind] 65 | correct_ind = all_prob['all_problems'].item()[prob_type]['correct_ind'][prob_ind] 66 | correct_answer = answer_choices[correct_ind] 67 | 68 | # Generate prompt 69 | prompt = '' 70 | for r in range(3): 71 | for c in range(3): 72 | prompt += '[' 73 | if not (r == 2 and c == 2): 74 | for i in range(len(prob[r][c])): 75 | if prob[r][c][i] == -1: 76 | prompt += ' ' 77 | else: 78 | prompt += str(prob[r][c][i]) 79 | if i < len(prob[r][c]) - 1: 80 | prompt += ' ' 81 | prompt += ']' 82 | if c < 2: 83 | prompt += ' ' 84 | else: 85 | prompt += '\n' 86 | 87 | # Get response 88 | response = openai.Completion.create(prompt=prompt, **kwargs) 89 | response_text = response['choices'][0]['text'] 90 | # Find portion of response corresponding to prediction 91 | prediction = response_text[len(prompt):] 92 | all_gen_pred[prob_type].append(prediction) 93 | # Get prediction set 94 | pred_set = [] 95 | invalid_char = False 96 | closing_bracket = False 97 | for i in range(len(split(prediction))): 98 | if prediction[i] != ' ': 99 | if prediction[i].isdigit(): 100 | pred_set.append(int(prediction[i])) 101 | elif prediction[i] == ']': 102 | break 103 | else: 104 | invalid_char = True 105 | break 106 | # Sort answer if problem type is permutation invariant 107 | if perm_invariant: 108 | correct_answer = np.sort(correct_answer) 109 | pred_set = np.sort(pred_set) 110 | # Determine whether prediction is correct 111 | correct_pred = False 112 | if not invalid_char and len(pred_set) == len(correct_answer): 113 | if np.all(pred_set == correct_answer): 114 | correct_pred = True 115 | all_gen_correct_pred[prob_type].append(correct_pred) 116 | 117 | # Get score for generated response 118 | first_token_ind = np.where(np.array(response['choices'][0]['logprobs']['text_offset']) <= len(prompt))[0][-1] 119 | response_complete = False 120 | token_ind = first_token_ind 121 | gen_completion = '' 122 | while not response_complete: 123 | token = response['choices'][0]['logprobs']['tokens'][token_ind] 124 | gen_completion += token 125 | contains_closed_bracket = False 126 | for i in range(len(token)): 127 | if token[i] == ']': 128 | contains_closed_bracket = True 129 | if contains_closed_bracket: 130 | response_complete = True 131 | if token == ']': 132 | last_token_ind = token_ind - 1 133 | else: 134 | last_token_ind = token_ind 135 | token_ind += 1 136 | gen_score = np.mean(response['choices'][0]['logprobs']['token_logprobs'][first_token_ind:last_token_ind+1]) 137 | 138 | # Evaluate answer choices 139 | all_choice_logprob = [] 140 | for a in range(8): 141 | # Convert choice to string and remove ',' 142 | choice_str = np.array(split(str(answer_choices[a]))) 143 | choice_str = ''.join(builtins.list(choice_str[choice_str != ','])) 144 | # Add answer choice to prompt 145 | prompt_choice = prompt + choice_str[1:] 146 | # Get average log probability of response 147 | response = openai.Completion.create(prompt=prompt_choice, **kwargs) 148 | first_token_ind = np.where(np.array(response['choices'][0]['logprobs']['text_offset']) <= len(prompt))[0][-1] 149 | last_token_ind = np.where(np.array(response['choices'][0]['logprobs']['text_offset']) == len(prompt_choice))[0][0] 150 | choice_avg_logprob = np.mean(response['choices'][0]['logprobs']['token_logprobs'][first_token_ind:last_token_ind]) 151 | all_choice_logprob.append(choice_avg_logprob) 152 | # Select answer 153 | model_choice = np.argmax(all_choice_logprob) 154 | all_MC_pred[prob_type].append(model_choice) 155 | # Determine whether multiple choice selection is correct 156 | MC_correct = model_choice == correct_ind 157 | all_MC_correct_pred[prob_type].append(MC_correct) 158 | 159 | # Alternative multiple-choice evaluation 160 | if correct_pred: 161 | alt_MC_correct = True 162 | else: 163 | if MC_correct: 164 | all_choice_logprob.append(gen_score) 165 | alt_model_choice = np.argmax(all_choice_logprob) 166 | alt_MC_correct = alt_model_choice == correct_ind 167 | else: 168 | alt_MC_correct = False 169 | all_alt_MC_correct_pred[prob_type].append(alt_MC_correct) 170 | 171 | # Save data 172 | eval_fname = './gpt_matprob_results.npz' 173 | np.savez(eval_fname, 174 | all_gen_pred=all_gen_pred, all_gen_correct_pred=all_gen_correct_pred, all_MC_pred=all_MC_pred, all_MC_correct_pred=all_MC_correct_pred, all_alt_MC_correct_pred=all_alt_MC_correct_pred, 175 | allow_pickle=True) 176 | -------------------------------------------------------------------------------- /letter_string/compare_behavior_gpt3.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import argparse 4 | import os 5 | 6 | def check_path(path): 7 | if not os.path.exists(path): 8 | os.mkdir(path) 9 | 10 | def hide_top_right(ax): 11 | ax.spines['right'].set_visible(False) 12 | ax.spines['top'].set_visible(False) 13 | ax.yaxis.set_ticks_position('left') 14 | ax.xaxis.set_ticks_position('bottom') 15 | 16 | def plot_ind_data(ax, x_points, data, ind_bar_width): 17 | max_count = 24 18 | point_unit = ind_bar_width / max_count 19 | # Plot 20 | for i in range(len(x_points)): 21 | unique_vals = np.unique(data[i]) 22 | for v in unique_vals: 23 | count = (data[i]==v).sum() 24 | span = count * point_unit 25 | x_min = x_points[i] - (span/2) 26 | x_max = x_points[i] + (span/2) 27 | x_vals = np.linspace(x_min,x_max,count) 28 | if v == 0: 29 | y_vals = np.ones(count) * 0.005 30 | else: 31 | y_vals = np.ones(count) * v 32 | plt.scatter(x_vals, y_vals, color='black', s=0.4, clip_on=False, marker='_') 33 | return ax 34 | 35 | # Settings 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('--sentence', action='store_true', help="Present problem in sentence format.") 38 | parser.add_argument('--noprompt', action='store_true', help="Present problem without prompt.") 39 | args = parser.parse_args() 40 | 41 | # Results directories 42 | if args.sentence: 43 | results_dir = './human_vs_GPT3_sentence/' 44 | GPT3_results_dir = './GPT3_results_sentence/' 45 | elif args.noprompt: 46 | results_dir = './human_vs_GPT3_noprompt/' 47 | GPT3_results_dir = './GPT3_results_noprompt/' 48 | else: 49 | results_dir = './human_vs_GPT3/' 50 | GPT3_results_dir = './GPT3_results/' 51 | check_path(results_dir) 52 | 53 | # Plot settings 54 | gpt3_color = 'darkslateblue' 55 | human_color = 'powderblue' 56 | plot_fontsize = 10 57 | title_fontsize = 12 58 | axis_label_fontsize = 12 59 | bar_width = 0.8 60 | ind_bar_width = bar_width / 2 61 | 62 | ## Zero-generalization problems, grouped by transformation type 63 | # Load results 64 | human_zerogen_results = np.load('./behavioral_results/zerogen_acc.npz') 65 | human_zerogen_acc = human_zerogen_results['all_acc'] 66 | human_zerogen_err = human_zerogen_results['all_err'] 67 | GPT3_zerogen_results = np.load(GPT3_results_dir + 'zerogen_acc.npz') 68 | GPT3_zerogen_acc = GPT3_zerogen_results['all_acc'] 69 | GPT3_zerogen_err = GPT3_zerogen_results['all_err'] 70 | # Sort based on accuracy 71 | rank_order = np.flip(np.argsort(human_zerogen_acc)) 72 | # Plot 73 | all_zerogen_prob_type_names = ['Successor', 'Predecessor', 'Extend\nsequence', 'Remove\nredundant\nletter', 'Fix\nalphabetic\nsequence', 'Sort'] 74 | x_points = np.arange(len(all_zerogen_prob_type_names)) 75 | ax = plt.subplot(111) 76 | plt.bar(x_points - (ind_bar_width/2), GPT3_zerogen_acc[rank_order], yerr=GPT3_zerogen_err[:,rank_order], color=gpt3_color, edgecolor='black', width=ind_bar_width, ecolor='gray') 77 | plt.bar(x_points + (ind_bar_width/2), human_zerogen_acc[rank_order], yerr=human_zerogen_err[:,rank_order], color=human_color, edgecolor='black', width=ind_bar_width, ecolor='gray') 78 | plt.ylim([0,1]) 79 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize) 80 | plt.ylabel('Generative accuracy', fontsize=axis_label_fontsize) 81 | plt.xticks(x_points, np.array(all_zerogen_prob_type_names)[rank_order], fontsize=plot_fontsize) 82 | plt.xlabel('Transformation type', fontsize=axis_label_fontsize) 83 | plt.title('Zero-generalization problems') 84 | plt.legend(['GPT-3', 'Human'],fontsize=12,frameon=False) 85 | hide_top_right(ax) 86 | plt.tight_layout() 87 | plt.savefig(results_dir + 'zerogen_acc.pdf', dpi=300, bbox_inches="tight") 88 | plt.close() 89 | 90 | ## One-generalization problems, grouped by generalization type 91 | # Load results 92 | human_onegen_results = np.load('./behavioral_results/onegen_acc.npz') 93 | human_onegen_acc = human_onegen_results['all_acc'] 94 | human_onegen_err = human_onegen_results['all_err'] 95 | GPT3_onegen_results = np.load(GPT3_results_dir + 'onegen_acc.npz') 96 | GPT3_onegen_acc = GPT3_onegen_results['all_acc'] 97 | GPT3_onegen_err = GPT3_onegen_results['all_err'] 98 | # Sort based on accuracy 99 | rank_order = np.flip(np.argsort(human_onegen_acc)) 100 | # Plot 101 | all_onegen_prob_type_names = ['Larger\ninterval', 'Longer\ntarget', 'Grouping', 'Interleaved\ndistractor', 'Letter-to-\nnumber', 'Reverse\norder'] 102 | x_points = np.arange(len(all_onegen_prob_type_names)) 103 | ax = plt.subplot(111) 104 | plt.bar(x_points - (ind_bar_width/2), GPT3_onegen_acc[rank_order], yerr=GPT3_onegen_err[:,rank_order], color=gpt3_color, edgecolor='black', width=ind_bar_width, ecolor='gray') 105 | plt.bar(x_points + (ind_bar_width/2), human_onegen_acc[rank_order], yerr=human_onegen_err[:,rank_order], color=human_color, edgecolor='black', width=ind_bar_width, ecolor='gray') 106 | plt.ylim([0,1]) 107 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize) 108 | plt.ylabel('Generative accuracy', fontsize=axis_label_fontsize) 109 | plt.xticks(x_points, np.array(all_onegen_prob_type_names)[rank_order], fontsize=plot_fontsize) 110 | plt.xlabel('Generalization type', fontsize=axis_label_fontsize) 111 | plt.title('One-generalization problems') 112 | plt.legend(['GPT-3','Human'],fontsize=plot_fontsize,frameon=False) 113 | hide_top_right(ax) 114 | plt.tight_layout() 115 | plt.savefig(results_dir + 'onegen_acc.pdf', dpi=300, bbox_inches="tight") 116 | plt.close() 117 | 118 | ## All problems, grouped by number of generalizations 119 | # Load results 120 | human_all_gen_results = np.load('./behavioral_results/all_gen_acc.npz') 121 | human_all_gen_acc_ind_results = human_all_gen_results['all_ind_results'] 122 | human_all_gen_acc = human_all_gen_results['all_acc'] 123 | human_all_gen_err = human_all_gen_results['all_err'] 124 | human_all_gen_std = human_all_gen_results['all_std'] 125 | GPT3_all_gen_results = np.load(GPT3_results_dir + 'all_gen_acc.npz') 126 | GPT3_all_gen_acc = GPT3_all_gen_results['all_acc'] 127 | GPT3_all_gen_err = GPT3_all_gen_results['all_err'] 128 | # Sort based on accuracy 129 | rank_order = np.flip(np.argsort(human_all_gen_acc)) 130 | # Plot 131 | all_gen_prob_type_names = np.arange(len(human_all_gen_acc)).astype(str) 132 | x_points = np.arange(len(all_gen_prob_type_names)) 133 | ax = plt.subplot(111) 134 | plt.bar(x_points - (ind_bar_width/2), GPT3_all_gen_acc[rank_order], yerr=GPT3_all_gen_err[:,rank_order], color=gpt3_color, edgecolor='black', width=ind_bar_width, ecolor='gray') 135 | plt.bar(x_points + (ind_bar_width/2), human_all_gen_acc[rank_order], yerr=human_all_gen_err[rank_order], color=human_color, edgecolor='black', width=ind_bar_width) 136 | plt.ylim([0,1]) 137 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize) 138 | plt.ylabel('Generative accuracy', fontsize=axis_label_fontsize) 139 | plt.xticks(x_points, all_gen_prob_type_names, fontsize=plot_fontsize) 140 | plt.xlabel('Number of generalizations', fontsize=axis_label_fontsize) 141 | plt.title(' ') 142 | plot_ind_data(ax, x_points + (ind_bar_width/2), human_all_gen_acc_ind_results, ind_bar_width) 143 | hide_top_right(ax) 144 | ax.set_aspect(4) 145 | plt.tight_layout() 146 | plt.savefig(results_dir + 'all_gen_acc.pdf', dpi=300, bbox_inches="tight") 147 | plt.close() 148 | 149 | ## Generalization to real-world concepts 150 | # Load results 151 | human_realworld_results = np.load('./behavioral_results/realworld_acc.npz') 152 | human_realworld_acc = human_realworld_results['all_acc'] 153 | human_realworld_err = human_realworld_results['all_err'] 154 | GPT3_realworld_results = np.load(GPT3_results_dir + 'realworld_acc.npz') 155 | GPT3_realworld_acc = GPT3_realworld_results['all_acc'] 156 | GPT3_realworld_err = GPT3_realworld_results['all_err'] 157 | # Sort based on accuracy 158 | rank_order = np.flip(np.argsort(GPT3_realworld_acc)) 159 | # Plot 160 | all_realworld_prob_type_names = ['Successor', 'Predecessor', 'Extend\nsequence', 'Sort'] 161 | x_points = np.arange(len(all_realworld_prob_type_names)) 162 | ax = plt.subplot(111) 163 | plt.bar(x_points - (ind_bar_width/2), GPT3_realworld_acc[rank_order], yerr=GPT3_realworld_err[:,rank_order], color=gpt3_color, edgecolor='black', width=ind_bar_width, ecolor='gray') 164 | plt.bar(x_points + (ind_bar_width/2), human_realworld_acc[rank_order], yerr=human_realworld_err[:,rank_order], color=human_color, edgecolor='black', width=ind_bar_width, ecolor='gray') 165 | plt.ylim([0,1]) 166 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize) 167 | plt.ylabel('Generative accuracy', fontsize=axis_label_fontsize) 168 | plt.xticks(x_points, np.array(all_realworld_prob_type_names)[rank_order], fontsize=plot_fontsize) 169 | plt.xlabel('Transformation type', fontsize=axis_label_fontsize) 170 | plt.title('Real-world concept problems') 171 | plt.legend(['GPT-3','Human'],fontsize=plot_fontsize,frameon=False) 172 | hide_top_right(ax) 173 | ax.set_aspect(4) 174 | plt.tight_layout() 175 | plt.savefig(results_dir + 'realworld_acc.pdf', dpi=300, bbox_inches="tight") 176 | plt.close() 177 | -------------------------------------------------------------------------------- /digit_mat/eval_gpt_matprob_prog_1thru5.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import numpy as np 3 | import builtins 4 | import os 5 | 6 | def check_path(path): 7 | if not os.path.exists(path): 8 | os.mkdir(path) 9 | 10 | # Split word into characters 11 | def split(word): 12 | return [char for char in word] 13 | 14 | # Load all problems 15 | all_prob = np.load('./all_problems_1thru5.npz', allow_pickle=True) 16 | 17 | # GPT-3 settings 18 | openai.api_key = "FILL_IN_API_KEY_HERE" 19 | kwargs = { "engine":"text-davinci-003", "temperature":0, "max_tokens":10, "stop":"\n", "echo":True, "logprobs":1, } 20 | 21 | # Loop through all problem types 22 | all_prob_types = builtins.list(all_prob['all_problems'].item().keys()) 23 | # Load data if it already exists 24 | all_data_fname = './gpt_matprob_results_1thru5.npz' 25 | if os.path.exists(all_data_fname): 26 | data_exists = True 27 | all_data = np.load('./gpt_matprob_results_1thru5.npz', allow_pickle=True) 28 | else: 29 | data_exists = False 30 | # Create data structure for storing results 31 | all_gen_pred = {} 32 | all_gen_correct_pred = {} 33 | all_MC_pred = {} 34 | all_MC_correct_pred = {} 35 | all_alt_MC_correct_pred = {} 36 | for p in range(len(all_prob_types)): 37 | # Problem type 38 | prob_type = all_prob_types[p] 39 | # Load data 40 | if data_exists: 41 | all_gen_pred[prob_type] = all_data['all_gen_pred'].item()[prob_type] 42 | all_gen_correct_pred[prob_type] = all_data['all_gen_correct_pred'].item()[prob_type] 43 | all_MC_pred[prob_type] = all_data['all_MC_pred'].item()[prob_type] 44 | all_MC_correct_pred[prob_type] = all_data['all_MC_correct_pred'].item()[prob_type] 45 | all_alt_MC_correct_pred[prob_type] = all_data['all_alt_MC_correct_pred'].item()[prob_type] 46 | # Create data structure 47 | else: 48 | all_gen_pred[prob_type] = [] 49 | all_gen_correct_pred[prob_type] = [] 50 | all_MC_pred[prob_type] = [] 51 | all_MC_correct_pred[prob_type] = [] 52 | all_alt_MC_correct_pred[prob_type] = [] 53 | # Loop over all problem indices 54 | N_runs = 20 55 | for run in range(N_runs): 56 | print(str(run + 1) + ' of ' + str(N_runs) + '...') 57 | # Initialize context with task instructions 58 | context = '[1] [1] [1]\n[2] [2] [2]\n[3] [3] [3]\n\n' 59 | # Loop over all problem types 60 | for p in range(len(all_prob_types)): 61 | # Problem type 62 | prob_type = all_prob_types[p] 63 | print('Problem type: ' + prob_type + '...') 64 | perm_invariant = all_prob['all_problems'].item()[prob_type]['perm_invariant'] 65 | prob_type_N_prob = all_prob['all_problems'].item()[prob_type]['prob'].shape[0] 66 | if len(all_gen_correct_pred[prob_type]) <= run: 67 | 68 | # Sample problem index 69 | prob_ind = int(np.floor(np.random.rand() * prob_type_N_prob)) 70 | 71 | # Problem 72 | prob = all_prob['all_problems'].item()[prob_type]['prob'][prob_ind] 73 | answer_choices = all_prob['all_problems'].item()[prob_type]['answer_choices'][prob_ind] 74 | correct_ind = all_prob['all_problems'].item()[prob_type]['correct_ind'][prob_ind] 75 | correct_answer = answer_choices[correct_ind] 76 | 77 | # Generate prompt 78 | prompt = '' 79 | for r in range(3): 80 | for c in range(3): 81 | prompt += '[' 82 | if not (r == 2 and c == 2): 83 | for i in range(len(prob[r][c])): 84 | if prob[r][c][i] == -1: 85 | prompt += ' ' 86 | else: 87 | prompt += str(prob[r][c][i]) 88 | if i < len(prob[r][c]) - 1: 89 | prompt += ' ' 90 | prompt += ']' 91 | if c < 2: 92 | prompt += ' ' 93 | else: 94 | prompt += '\n' 95 | # Add context 96 | context_prompt = context + prompt 97 | 98 | # Get response 99 | fits_window = False 100 | response = [] 101 | while not fits_window: 102 | try: 103 | response = openai.Completion.create(prompt=context_prompt, **kwargs) 104 | except: 105 | print('deleting problem from context...') 106 | context_prompt = context_prompt.split('\n\n')[1:] 107 | new_context_prompt = '' 108 | for i in range(len(context_prompt)): 109 | new_context_prompt += context_prompt[i] 110 | if i < (len(context_prompt) - 1): 111 | new_context_prompt += '\n\n' 112 | context_prompt = new_context_prompt 113 | if len(response) > 0: 114 | fits_window = True 115 | response_text = response['choices'][0]['text'] 116 | # Find portion of response corresponding to prediction 117 | prediction = response_text[len(context_prompt):] 118 | all_gen_pred[prob_type].append(prediction) 119 | # Get prediction set 120 | pred_set = [] 121 | invalid_char = False 122 | closing_bracket = False 123 | for i in range(len(split(prediction))): 124 | if prediction[i] != ' ': 125 | if prediction[i].isdigit(): 126 | pred_set.append(int(prediction[i])) 127 | elif prediction[i] == ']': 128 | break 129 | else: 130 | invalid_char = True 131 | break 132 | # Sort answer if problem type is permutation invariant 133 | if perm_invariant: 134 | correct_answer = np.sort(correct_answer) 135 | pred_set = np.sort(pred_set) 136 | # Determine whether prediction is correct 137 | correct_pred = False 138 | if not invalid_char and len(pred_set) == len(correct_answer): 139 | if np.all(pred_set == correct_answer): 140 | correct_pred = True 141 | all_gen_correct_pred[prob_type].append(correct_pred) 142 | 143 | # Get score for generated response 144 | first_token_ind = np.where(np.array(response['choices'][0]['logprobs']['text_offset']) <= len(context_prompt))[0][-1] 145 | response_complete = False 146 | token_ind = first_token_ind 147 | gen_completion = '' 148 | while not response_complete: 149 | token = response['choices'][0]['logprobs']['tokens'][token_ind] 150 | gen_completion += token 151 | contains_closed_bracket = False 152 | for i in range(len(token)): 153 | if token[i] == ']': 154 | contains_closed_bracket = True 155 | if contains_closed_bracket: 156 | response_complete = True 157 | if token == ']': 158 | last_token_ind = token_ind - 1 159 | else: 160 | last_token_ind = token_ind 161 | token_ind += 1 162 | gen_score = np.mean(response['choices'][0]['logprobs']['token_logprobs'][first_token_ind:last_token_ind+1]) 163 | 164 | # Evaluate answer choices 165 | all_choice_logprob = [] 166 | for a in range(8): 167 | # Convert choice to string and remove ',' 168 | choice_str = np.array(split(str(answer_choices[a]))) 169 | choice_str = ''.join(builtins.list(choice_str[choice_str != ','])) 170 | # Add answer choice to prompt 171 | context_prompt_choice = context_prompt + choice_str[1:] 172 | # Get average log probability of response 173 | fits_window = False 174 | response = [] 175 | while not fits_window: 176 | try: 177 | response = openai.Completion.create(prompt=context_prompt_choice, **kwargs) 178 | except: 179 | print('deleting problem from context...') 180 | context_prompt = context_prompt.split('\n\n')[1:] 181 | new_context_prompt = '' 182 | for i in range(len(context_prompt)): 183 | new_context_prompt += context_prompt[i] 184 | if i < (len(context_prompt) - 1): 185 | new_context_prompt += '\n\n' 186 | context_prompt = new_context_prompt 187 | context_prompt_choice = context_prompt + choice_str[1:] 188 | if len(response) > 0: 189 | fits_window = True 190 | first_token_ind = np.where(np.array(response['choices'][0]['logprobs']['text_offset']) <= len(context_prompt))[0][-1] 191 | last_token_ind = np.where(np.array(response['choices'][0]['logprobs']['text_offset']) == len(context_prompt_choice))[0][0] 192 | choice_avg_logprob = np.mean(response['choices'][0]['logprobs']['token_logprobs'][first_token_ind:last_token_ind]) 193 | all_choice_logprob.append(choice_avg_logprob) 194 | # Select answer 195 | model_choice = np.argmax(all_choice_logprob) 196 | all_MC_pred[prob_type].append(model_choice) 197 | # Determine whether multiple choice selection is correct 198 | MC_correct = model_choice == correct_ind 199 | all_MC_correct_pred[prob_type].append(MC_correct) 200 | 201 | # Alternative multiple-choice evaluation 202 | if correct_pred: 203 | alt_MC_correct = True 204 | else: 205 | if MC_correct: 206 | all_choice_logprob.append(gen_score) 207 | alt_model_choice = np.argmax(all_choice_logprob) 208 | alt_MC_correct = alt_model_choice == correct_ind 209 | else: 210 | alt_MC_correct = False 211 | all_alt_MC_correct_pred[prob_type].append(alt_MC_correct) 212 | 213 | # Add problem to context 214 | model_choice_str = np.array(split(str(answer_choices[model_choice]))) 215 | model_choice_str = ''.join(builtins.list(model_choice_str[model_choice_str != ','])) 216 | completed_prob = context + prompt + model_choice_str[1:] 217 | completed_prob += '\n\n' 218 | context = completed_prob 219 | 220 | # Save data 221 | eval_fname = './gpt_matprob_results_1thru5.npz' 222 | np.savez(eval_fname, 223 | all_gen_pred=all_gen_pred, all_gen_correct_pred=all_gen_correct_pred, all_MC_pred=all_MC_pred, all_MC_correct_pred=all_MC_correct_pred, all_alt_MC_correct_pred=all_alt_MC_correct_pred, 224 | allow_pickle=True) 225 | # Raw output 226 | gen_data_dir = './gpt_matprob_results_1thru5/' 227 | check_path(gen_data_dir) 228 | gen_data_fname = gen_data_dir + str(run) + '.txt' 229 | gen_data_fid = open(gen_data_fname, 'w') 230 | gen_data_fid.write(context) 231 | gen_data_fid.close() 232 | 233 | else: 234 | 235 | # Load previously generated context 236 | gen_data_dir = './gpt_matprob_results_1thru5/' 237 | gen_data_fname = gen_data_dir + str(run) + '.txt' 238 | gen_data_fid = open(gen_data_fname, 'r') 239 | lines = gen_data_fid.readlines() 240 | context = ' '.join(lines) 241 | # Remove spaces 242 | context = context.split('\n') 243 | new_context = context[0] 244 | for c in range(1,len(context)): 245 | new_context += '\n' 246 | new_context += context[c][1:] 247 | context = new_context 248 | -------------------------------------------------------------------------------- /letter_string/gen_problems.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import deepcopy 3 | import random 4 | import json 5 | 6 | ## Problems generated using one of the following 6 transformations: 7 | # 8 | # Successorship 9 | # [a b c d] [a b c e] 10 | # [i j k l] [i j k m] 11 | # 12 | # Predecessorship 13 | # [b c d e] [a c d e] 14 | # [j k l m] [i k l m] 15 | # 16 | # Add letter to sequence 17 | # [a b c d] [a b c d e] 18 | # [i j k l] [i j k l m] 19 | # 20 | # Remove redundant character 21 | # [a b b c d e] [a b c d e] 22 | # [i j k l l m] [i j k l m] 23 | # 24 | # Fix alphabetic sequence 25 | # [a w c d e] [a b c d e] 26 | # [i j k p m] [i j k l m] 27 | # 28 | # Sort characters 29 | # [a b e d c] [a b c d e] 30 | # [i k j l m] [i j k l m] 31 | # 32 | # 33 | ## and between 0 and 3 generalizations, out of the following 6: 34 | # 35 | # Larger interval 36 | # [a b c d] [a b c e] 37 | # [i k m o] [i k m q] 38 | # 39 | # Longer target 40 | # [a b c d] [a b c e] 41 | # [i j k l m n o p] [i j k l m n o q] 42 | # 43 | # Grouping 44 | # [a b c d] [a b c e] 45 | # [i i j j k k l l] [i i j j k k m m] 46 | # 47 | # Interleaved X's 48 | # [a b c d] [a b c e] 49 | # [i x j x k x l x] [i x j x k x m x] 50 | # 51 | # Letter-to-number 52 | # [a b c d] [a b c e] 53 | # [1 2 3 4] [1 2 3 5] 54 | # 55 | # Reversal 56 | # [a b c d] [a b c e] 57 | # [m l k j] [m l k i] 58 | # 59 | ## 60 | 61 | # Alphabet 62 | letters = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] 63 | N_letters = len(letters) 64 | # Numbers 65 | numbers = np.arange(N_letters) + 1 66 | # Linearly ordered real-world concepts 67 | realworld_linear = [['cold', 'cool', 'warm', 'hot'], 68 | ['love', 'like', 'dislike', 'hate'], 69 | ['jack', 'queen', 'king', 'ace'], 70 | ['penny', 'nickel', 'dime', 'quarter'], 71 | ['second', 'minute', 'hour', 'day']] 72 | 73 | # Successor transformation 74 | def apply_succ(prob_letters): 75 | return [prob_letters[:-1], prob_letters[:-2] + [prob_letters[-1]]] 76 | 77 | # Predecessor transformation 78 | def apply_pred(prob_letters): 79 | return [prob_letters[1:], [prob_letters[0]] + prob_letters[2:]] 80 | 81 | # Add letter to sequence 82 | def apply_add_letter(prob_letters): 83 | return [prob_letters[:-1], prob_letters] 84 | 85 | # Remove redundant letter 86 | def apply_remove_redundant(prob_letters): 87 | redundant_loc = np.arange(len(prob_letters)) 88 | np.random.shuffle(redundant_loc) 89 | redundant_loc = redundant_loc[0] 90 | prob_redundant = deepcopy(prob_letters) 91 | prob_redundant.insert(redundant_loc, prob_letters[redundant_loc]) 92 | return [prob_redundant, prob_letters] 93 | 94 | # Remove out-of-place character 95 | def apply_fix_alphabet(prob_letters): 96 | remaining_letters = np.array(deepcopy(letters)) 97 | remaining_letters = remaining_letters[np.all(np.expand_dims(np.array(remaining_letters),1) != np.expand_dims(np.array(prob_letters),0), 1)] 98 | np.random.shuffle(remaining_letters) 99 | insert_letter = remaining_letters[0] 100 | insert_loc = np.arange(len(prob_letters)) 101 | np.random.shuffle(insert_loc) 102 | insert_loc = insert_loc[0] 103 | prob_letters_insert = deepcopy(prob_letters) 104 | prob_letters_insert[insert_loc] = insert_letter 105 | return [prob_letters_insert, prob_letters] 106 | 107 | # Sort letters 108 | def apply_sort(prob_letters): 109 | swap_loc = np.arange(len(prob_letters)) 110 | np.random.shuffle(swap_loc) 111 | i_loc = swap_loc[0] 112 | j_loc = swap_loc[1] 113 | i_letter = prob_letters[i_loc] 114 | j_letter = prob_letters[j_loc] 115 | prob_swapped = deepcopy(prob_letters) 116 | prob_swapped[i_loc] = j_letter 117 | prob_swapped[j_loc] = i_letter 118 | return [prob_swapped, prob_letters] 119 | 120 | # Method for generating subset of problems 121 | def gen_prob_subset(N_generalize=0, N_prob=100, standard_len=5, longer_targ_len=9, larger_int_size=2, 122 | trans_allowed=['succ', 'pred', 'add_letter', 'remove_redundant', 'fix_alphabet', 'sort'], 123 | gen_allowed=['larger_int', 'longer_targ', 'group', 'interleaved', 'letter2num', 'reverse', 'realworld']): 124 | # Initialize storage for problems 125 | all_prob = [] 126 | all_trans = [] 127 | all_gen = [] 128 | all_src_letters = [] 129 | all_tgt_letters = [] 130 | while len(all_prob) < N_prob: 131 | # Sample source letters 132 | src_start = np.floor(np.random.rand() * (len(letters)-(standard_len-1))).astype(int) 133 | src_letters = letters[src_start:src_start+standard_len] 134 | # Sample generalizations 135 | random.shuffle(gen_allowed) 136 | generalize = gen_allowed[:N_generalize] 137 | # Sample target letters 138 | if 'realworld' in generalize: 139 | random.shuffle(realworld_linear) 140 | tgt_letters = realworld_linear[0] 141 | else: 142 | if 'longer_targ' in generalize and 'larger_int' in generalize: 143 | tgt_span = (longer_targ_len * larger_int_size) - 1 144 | src_duplicate = True 145 | while src_duplicate: 146 | tgt_start = np.floor(np.random.rand() * (len(letters)-(tgt_span-1))).astype(int) 147 | if src_start != tgt_start: 148 | src_duplicate = False 149 | tgt_letters = letters[tgt_start:tgt_start+tgt_span][::2] 150 | elif 'longer_targ' in generalize and 'larger_int' not in generalize: 151 | src_duplicate = True 152 | while src_duplicate: 153 | tgt_start = np.floor(np.random.rand() * (len(letters)-(longer_targ_len-1))).astype(int) 154 | if src_start != tgt_start: 155 | src_duplicate = False 156 | tgt_letters = letters[tgt_start:tgt_start+longer_targ_len] 157 | elif 'longer_targ' not in generalize and 'larger_int' in generalize: 158 | tgt_span = (standard_len * larger_int_size) - 1 159 | src_duplicate = True 160 | while src_duplicate: 161 | tgt_start = np.floor(np.random.rand() * (len(letters)-(tgt_span-1))).astype(int) 162 | if src_start != tgt_start: 163 | src_duplicate = False 164 | tgt_letters = letters[tgt_start:tgt_start+tgt_span][::2] 165 | elif 'longer_targ' not in generalize and 'larger_int' not in generalize: 166 | src_duplicate = True 167 | while src_duplicate: 168 | tgt_start = np.floor(np.random.rand() * (len(letters)-(standard_len-1))).astype(int) 169 | if src_start != tgt_start: 170 | src_duplicate = False 171 | tgt_letters = letters[tgt_start:tgt_start+standard_len] 172 | # Reverse target letters 173 | if 'reverse' in generalize: 174 | tgt_letters.reverse() 175 | # Sample transformation 176 | random.shuffle(trans_allowed) 177 | trans = trans_allowed[0] 178 | # Apply transformation 179 | if trans == 'succ': 180 | src = apply_succ(src_letters) 181 | tgt = apply_succ(tgt_letters) 182 | elif trans == 'pred': 183 | src = apply_pred(src_letters) 184 | tgt = apply_pred(tgt_letters) 185 | elif trans == 'add_letter': 186 | src = apply_add_letter(src_letters) 187 | tgt = apply_add_letter(tgt_letters) 188 | elif trans == 'remove_redundant': 189 | src = apply_remove_redundant(src_letters) 190 | tgt = apply_remove_redundant(tgt_letters) 191 | elif trans == 'fix_alphabet': 192 | src = apply_fix_alphabet(src_letters) 193 | tgt = apply_fix_alphabet(tgt_letters) 194 | elif trans == 'sort': 195 | src = apply_sort(src_letters) 196 | tgt = apply_sort(tgt_letters) 197 | # Generalization from letters to numbers 198 | if 'letter2num' in generalize: 199 | new_tgt = [] 200 | for i in range(len(tgt)): 201 | new_tgt_i = [] 202 | for j in range(len(tgt[i])): 203 | new_tgt_i.append(numbers[np.where(np.array(letters) == tgt[i][j])[0][0]]) 204 | new_tgt.append(new_tgt_i) 205 | tgt = new_tgt 206 | # Interleaved X's (or 0's, if target composed of numbers) 207 | if 'interleaved' in generalize: 208 | new_tgt = [] 209 | for i in range(len(tgt)): 210 | new_tgt_i = [] 211 | for j in range(len(tgt[i])): 212 | new_tgt_i.append(tgt[i][j]) 213 | if 'letter2num' in generalize: 214 | new_tgt_i.append('0') 215 | else: 216 | new_tgt_i.append('x') 217 | new_tgt.append(new_tgt_i) 218 | tgt = new_tgt 219 | # Grouping 220 | if 'group' in generalize: 221 | new_tgt = [] 222 | for i in range(len(tgt)): 223 | new_tgt_i = [] 224 | for j in range(len(tgt[i])): 225 | new_tgt_i.append(tgt[i][j]) 226 | new_tgt_i.append(tgt[i][j]) 227 | new_tgt.append(new_tgt_i) 228 | tgt = new_tgt 229 | # Check that problem doesn't already exist 230 | prob = [src, tgt] 231 | duplicate = False 232 | for p_prev in range(len(all_prob)): 233 | if np.array(prob).shape == np.array(all_prob[p_prev]).shape: 234 | if np.all(np.array(prob) == np.array(all_prob[p_prev])): 235 | duplicate = True 236 | # Add to problem subset 237 | if not duplicate: 238 | all_prob.append(prob) 239 | all_trans.append(trans) 240 | all_gen.append(generalize) 241 | all_src_letters.append(src_letters) 242 | all_tgt_letters.append(tgt_letters) 243 | return {'prob': all_prob, 'trans': all_trans, 'gen': all_gen, 'src_letters': all_src_letters, 'tgt_letters': all_tgt_letters} 244 | 245 | # Add problems to json and numpy file 246 | def save_prob(all_prob, prob_type_name, all_prob_js): 247 | # Convert to strings and save as json 248 | all_data = [] 249 | for p in range(len(all_prob['prob'])): 250 | # A 251 | prompt = '[' 252 | for i in range(len(all_prob['prob'][p][0][0])): 253 | prompt += str(all_prob['prob'][p][0][0][i]) 254 | if i < len(all_prob['prob'][p][0][0]) - 1: 255 | prompt += ' ' 256 | prompt += ']   [' 257 | # B 258 | for i in range(len(all_prob['prob'][p][0][1])): 259 | prompt += str(all_prob['prob'][p][0][1][i]) 260 | if i < len(all_prob['prob'][p][0][1]) - 1: 261 | prompt += ' ' 262 | prompt += ']
[' 263 | # C 264 | for i in range(len(all_prob['prob'][p][1][0])): 265 | prompt += str(all_prob['prob'][p][1][0][i]) 266 | if i < len(all_prob['prob'][p][1][0]) - 1: 267 | prompt += ' ' 268 | prompt += ']   [  ?  ]' 269 | # Add to dataset 270 | all_data.append({'prompt': prompt, 'prob_ind': p}) 271 | # Add to javascript data 272 | all_prob_js[prob_type_name] = all_data 273 | return all_prob_js 274 | 275 | # Split subset 276 | def split_subset(all_prob, N_split): 277 | all_prob_split = [] 278 | N_subset = int(len(all_prob['prob']) / N_split) 279 | for s in range(N_split): 280 | subset = {} 281 | for key in all_prob.keys(): 282 | subset[key] = [] 283 | for p in range(N_subset*s,N_subset*(s+1)): 284 | for key in all_prob.keys(): 285 | subset[key].append(all_prob[key][p]) 286 | all_prob_split.append(subset) 287 | return all_prob_split 288 | 289 | # Generate all basic analogies (zero generalizations) 290 | all_succ = gen_prob_subset(trans_allowed=['succ']) 291 | all_pred = gen_prob_subset(trans_allowed=['pred']) 292 | all_add_letter = gen_prob_subset(trans_allowed=['add_letter']) 293 | all_remove_redundant = gen_prob_subset(trans_allowed=['remove_redundant']) 294 | all_fix_alphabet = gen_prob_subset(trans_allowed=['fix_alphabet']) 295 | all_sort = gen_prob_subset(trans_allowed=['sort']) 296 | 297 | # Generate all problems with one generalization (randomly sample transformations) 298 | all_larger_int = gen_prob_subset(N_generalize=1, gen_allowed=['larger_int']) 299 | all_longer_targ = gen_prob_subset(N_generalize=1, gen_allowed=['longer_targ']) 300 | all_group = gen_prob_subset(N_generalize=1, gen_allowed=['group']) 301 | all_interleaved = gen_prob_subset(N_generalize=1, gen_allowed=['interleaved']) 302 | all_letter2num = gen_prob_subset(N_generalize=1, gen_allowed=['letter2num']) 303 | all_reverse = gen_prob_subset(N_generalize=1, gen_allowed=['reverse']) 304 | 305 | # Generate all problems with 2 and 3 generalizations 306 | all_2gen = gen_prob_subset(N_generalize=2, N_prob=600, gen_allowed=['larger_int', 'longer_targ', 'group', 'interleaved', 'letter2num', 'reverse']) 307 | all_2gen_split = split_subset(all_2gen, 6) 308 | all_3gen = gen_prob_subset(N_generalize=3, N_prob=600, gen_allowed=['larger_int', 'longer_targ', 'group', 'interleaved', 'letter2num', 'reverse']) 309 | all_3gen_split = split_subset(all_3gen, 6) 310 | 311 | # Generate problems involving generalization to real-world concepts 312 | all_realworld_succ = gen_prob_subset(standard_len=4, N_generalize=1, trans_allowed=['succ'], gen_allowed=['realworld']) 313 | all_realworld_pred = gen_prob_subset(standard_len=4, N_generalize=1, trans_allowed=['pred'], gen_allowed=['realworld']) 314 | all_realworld_add_letter = gen_prob_subset(standard_len=4, N_generalize=1, trans_allowed=['add_letter'], gen_allowed=['realworld']) 315 | all_realworld_sort = gen_prob_subset(standard_len=4, N_generalize=1, trans_allowed=['sort'], gen_allowed=['realworld']) 316 | 317 | # Combine problems 318 | all_prob_types = [all_succ, all_pred, all_add_letter, all_remove_redundant, all_fix_alphabet, all_sort, 319 | all_larger_int, all_longer_targ, all_group, all_interleaved, all_letter2num, all_reverse, 320 | all_2gen_split[0], all_2gen_split[1], all_2gen_split[2], all_2gen_split[3], all_2gen_split[4], all_2gen_split[5], 321 | all_3gen_split[0], all_3gen_split[1], all_3gen_split[2], all_3gen_split[3], all_3gen_split[4], all_3gen_split[5], 322 | all_realworld_succ, all_realworld_pred, all_realworld_add_letter, all_realworld_sort] 323 | all_prob_type_names = ['succ', 'pred', 'add_letter', 'remove_redundant', 'fix_alphabet', 'sort', 324 | 'larger_int', 'longer_targ', 'group', 'interleaved', 'letter2num', 'reverse', 325 | '2gen_split1', '2gen_split2', '2gen_split3', '2gen_split4', '2gen_split5', '2gen_split6', 326 | '3gen_split1', '3gen_split2', '3gen_split3', '3gen_split4', '3gen_split5', '3gen_split6', 327 | 'realworld_succ', 'realworld_pred', 'realworld_add_letter', 'realworld_sort'] 328 | 329 | # Create js variable for all_problems 330 | all_prob_js = {} 331 | all_prob_np = {} 332 | for p in range(len(all_prob_types)): 333 | all_prob_js = save_prob(all_prob_types[p], all_prob_type_names[p], all_prob_js) 334 | all_prob_np[all_prob_type_names[p]] = all_prob_types[p] 335 | # Write numpy file 336 | np.savez('./all_prob.npz', all_prob=all_prob_np) 337 | # Convert to json strings 338 | all_prob_json_string = json.dumps(all_prob_js) 339 | # Write to js script 340 | js_fname = './all_prob.js' 341 | js_fid = open(js_fname, 'w') 342 | js_fid.write('var all_problems = ' + all_prob_json_string) 343 | js_fid.close() 344 | 345 | 346 | 347 | 348 | 349 | 350 | 351 | -------------------------------------------------------------------------------- /letter_string/analyze_gpt3_letterstring.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from statsmodels.stats.proportion import proportion_confint 4 | from itertools import combinations 5 | import builtins 6 | import argparse 7 | import os 8 | 9 | def check_path(path): 10 | if not os.path.exists(path): 11 | os.mkdir(path) 12 | 13 | def hide_top_right(ax): 14 | ax.spines['right'].set_visible(False) 15 | ax.spines['top'].set_visible(False) 16 | ax.yaxis.set_ticks_position('left') 17 | ax.xaxis.set_ticks_position('bottom') 18 | 19 | # Settings 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--sentence', action='store_true', help="Present problem in sentence format.") 22 | parser.add_argument('--noprompt', action='store_true', help="Present problem without prompt.") 23 | args = parser.parse_args() 24 | 25 | # Load data 26 | if args.sentence: 27 | all_responses = np.load('./gpt3_letterstring_results_sentence.npz')['all_prob_type_responses'] 28 | elif args.noprompt: 29 | all_responses = np.load('./gpt3_letterstring_results_noprompt.npz')['all_prob_type_responses'] 30 | else: 31 | all_responses = np.load('./gpt3_letterstring_results.npz')['all_prob_type_responses'] 32 | N_prob_types = all_responses.shape[0] 33 | N_trials_per_prob_type = all_responses.shape[1] 34 | # Load problems 35 | all_prob = np.load('./all_prob.npz', allow_pickle=True)['all_prob'] 36 | prob_types = builtins.list(all_prob.item().keys()) 37 | 38 | # All possible combinations of transformations and generalizations 39 | trans = ['succ', 'pred', 'add_letter', 'remove_redundant', 'fix_alphabet', 'sort'] 40 | gen = ['larger_int', 'longer_targ', 'group', 'interleaved', 'letter2num', 'reverse'] 41 | all_trans = [] 42 | all_gen = [] 43 | for t in range(len(trans)): 44 | all_trans.append(trans[t]) 45 | all_gen.append([]) 46 | for t in range(len(trans)): 47 | for g in range(len(gen)): 48 | all_trans.append(trans[t]) 49 | all_gen.append([gen[g]]) 50 | all_2comb = builtins.list(combinations(np.arange(len(gen)),2)) 51 | for t in range(len(trans)): 52 | for c in range(len(all_2comb)): 53 | all_trans.append(trans[t]) 54 | all_gen.append([gen[all_2comb[c][0]], gen[all_2comb[c][1]]]) 55 | all_3comb = builtins.list(combinations(np.arange(len(gen)),3)) 56 | for t in range(len(trans)): 57 | for c in range(len(all_3comb)): 58 | all_trans.append(trans[t]) 59 | all_gen.append([gen[all_3comb[c][0]], gen[all_3comb[c][1]], gen[all_3comb[c][2]]]) 60 | N_prob_subtype = len(all_trans) 61 | subtype_counts = np.zeros(N_prob_subtype) 62 | 63 | # Calculate performance 64 | all_prob_type_correct_pred = [] 65 | all_prob_type_subtype = [] 66 | all_prob_N_gen = [] 67 | all_prob_realworld = [] 68 | for p in range(N_prob_types): 69 | all_correct_pred = [] 70 | all_subtype = [] 71 | all_N_gen = [] 72 | all_realworld = [] 73 | for t in range(N_trials_per_prob_type): 74 | response = all_responses[p][t] 75 | if args.sentence: 76 | if response[-1] == '.': 77 | response = response[:-1] 78 | if response[0] == ' ': 79 | response = response[1:] 80 | linebreak = True 81 | while linebreak: 82 | if response[:1] == '\n': 83 | response = response[1:] 84 | else: 85 | linebreak = False 86 | response_parsed = response.split(' ') 87 | correct_answer = all_prob.item()[prob_types[p]]['prob'][t][1][1] 88 | if np.array(correct_answer).astype(str).shape[0] == np.array(response_parsed).astype(str).shape[0]: 89 | correct_pred = np.all(np.array(correct_answer).astype(str) == np.array(response_parsed)) 90 | else: 91 | correct_pred = False 92 | all_correct_pred.append(correct_pred) 93 | else: 94 | if response[-1] == ']': 95 | response = response[:-1] 96 | response_parsed = response.split(' ') 97 | correct_answer = all_prob.item()[prob_types[p]]['prob'][t][1][1] 98 | if np.array(correct_answer).astype(str).shape[0] == np.array(response_parsed).astype(str).shape[0]: 99 | correct_pred = np.all(np.array(correct_answer).astype(str) == np.array(response_parsed)) 100 | else: 101 | correct_pred = False 102 | all_correct_pred.append(correct_pred) 103 | # Classify problem subtype 104 | for sp in range(N_prob_subtype): 105 | if all_prob.item()[prob_types[p]]['trans'][t] == all_trans[sp]: 106 | if len(all_prob.item()[prob_types[p]]['gen'][t]) == len(all_gen[sp]): 107 | if np.all(np.sort(all_prob.item()[prob_types[p]]['gen'][t]) == np.sort(np.array(all_gen[sp]))): 108 | all_subtype.append(sp) 109 | subtype_counts[sp] += 1 110 | # Number of generalizations 111 | N_gen = len(all_prob.item()[prob_types[p]]['gen'][t]) 112 | all_N_gen.append(N_gen) 113 | # Real-world problems 114 | if 'realworld' in all_prob.item()[prob_types[p]]['gen'][t]: 115 | all_realworld.append(1) 116 | else: 117 | all_realworld.append(0) 118 | all_prob_type_correct_pred.append(all_correct_pred) 119 | all_prob_N_gen.append(all_N_gen) 120 | all_prob_realworld.append(all_realworld) 121 | if len(all_subtype) > 0: 122 | all_prob_type_subtype.append(all_subtype) 123 | # Convert to arrays 124 | all_prob_type_correct_pred = np.array(all_prob_type_correct_pred) 125 | all_prob_type_subtype = np.array(all_prob_type_subtype) 126 | all_prob_N_gen = np.array(all_prob_N_gen) 127 | all_prob_realworld = np.array(all_prob_realworld) 128 | 129 | # Create directory for results 130 | if args.sentence: 131 | results_dir = './GPT3_results_sentence/' 132 | elif args.noprompt: 133 | results_dir = './GPT3_results_noprompt/' 134 | else: 135 | results_dir = './GPT3_results/' 136 | check_path(results_dir) 137 | 138 | # Save individual trial results 139 | np.savez(results_dir + 'ind_trial_results.npz', all_prob_type_correct_pred=all_prob_type_correct_pred, all_prob_type_subtype=all_prob_type_subtype, all_prob_N_gen=all_prob_N_gen, all_prob_realworld=all_prob_realworld) 140 | 141 | # Correlation analysis 142 | all_subtype_acc = [] 143 | for sp in range(N_prob_subtype): 144 | all_subtype_acc.append(all_prob_type_correct_pred[:all_prob_type_subtype.shape[0],:][all_prob_type_subtype == sp].mean()) 145 | np.savez(results_dir + 'prob_subtype_acc.npz', subtype_acc=all_subtype_acc, subtype_counts=subtype_counts) 146 | 147 | # Plot settings 148 | gpt3_color = 'darkslateblue' 149 | plot_fontsize = 10 150 | title_fontsize = 12 151 | axis_label_fontsize = 12 152 | bar_width = 0.8 153 | 154 | # Calculate accuracy for all zero-generalization problems 155 | all_zerogen_prob_types = ['succ', 'pred', 'add_letter', 'remove_redundant', 'fix_alphabet', 'sort'] 156 | all_zerogen_acc = [] 157 | all_zerogen_ci_lower = [] 158 | all_zerogen_ci_upper = [] 159 | for p in range(len(all_zerogen_prob_types)): 160 | correct_pred = np.array(all_prob_type_correct_pred[np.where(np.array(prob_types)==all_zerogen_prob_types[p])[0][0]]).astype(float) 161 | all_zerogen_acc.append(correct_pred.mean()) 162 | ci_lower, ci_upper = proportion_confint(correct_pred.sum(), correct_pred.shape[0]) 163 | all_zerogen_ci_lower.append(ci_lower) 164 | all_zerogen_ci_upper.append(ci_upper) 165 | all_zerogen_acc = np.array(all_zerogen_acc) 166 | all_zerogen_ci_lower = np.array(all_zerogen_ci_lower) 167 | all_zerogen_ci_upper = np.array(all_zerogen_ci_upper) 168 | all_zerogen_lower_err = all_zerogen_acc - all_zerogen_ci_lower 169 | all_zerogen_upper_err = all_zerogen_ci_upper - all_zerogen_acc 170 | all_zerogen_err = np.array([all_zerogen_lower_err, all_zerogen_upper_err]) 171 | # Sort based on accuracy 172 | rank_order = np.flip(np.argsort(all_zerogen_acc)) 173 | # Plot 174 | all_zerogen_prob_type_names = ['Successor', 'Predecessor', 'Extend\nsequence', 'Remove\nredundant\nletter', 'Fix\nalphabetic\nsequence', 'Sort'] 175 | x_points = np.arange(len(all_zerogen_prob_types)) 176 | ax = plt.subplot(111) 177 | plt.bar(x_points, all_zerogen_acc[rank_order], yerr=all_zerogen_err[:,rank_order], color=gpt3_color, edgecolor='black', width=bar_width) 178 | plt.ylim([0,1]) 179 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize) 180 | plt.ylabel('Generative accuracy', fontsize=axis_label_fontsize) 181 | plt.xticks(x_points, np.array(all_zerogen_prob_type_names)[rank_order], fontsize=plot_fontsize) 182 | plt.xlabel('Transformation type', fontsize=axis_label_fontsize) 183 | plt.title('Zero-generalization problems') 184 | plt.legend(['GPT-3'],fontsize=plot_fontsize,frameon=False) 185 | hide_top_right(ax) 186 | plt.tight_layout() 187 | plt.savefig(results_dir + 'zerogen_acc.png', dpi=300, bbox_inches="tight") 188 | plt.close() 189 | # Save results 190 | np.savez(results_dir + 'zerogen_acc.npz', all_acc=all_zerogen_acc, all_err=all_zerogen_err) 191 | 192 | # Calculate all accuracy for one-generalization problems 193 | all_onegen_prob_types = ['larger_int', 'longer_targ', 'group', 'interleaved', 'letter2num', 'reverse'] 194 | all_onegen_acc = [] 195 | all_onegen_ci_lower = [] 196 | all_onegen_ci_upper = [] 197 | for p in range(len(all_onegen_prob_types)): 198 | correct_pred = np.array(all_prob_type_correct_pred[np.where(np.array(prob_types)==all_onegen_prob_types[p])[0][0]]).astype(float) 199 | all_onegen_acc.append(correct_pred.mean()) 200 | ci_lower, ci_upper = proportion_confint(correct_pred.sum(), correct_pred.shape[0]) 201 | all_onegen_ci_lower.append(ci_lower) 202 | all_onegen_ci_upper.append(ci_upper) 203 | all_onegen_acc = np.array(all_onegen_acc) 204 | all_onegen_ci_lower = np.array(all_onegen_ci_lower) 205 | all_onegen_ci_upper = np.array(all_onegen_ci_upper) 206 | all_onegen_lower_err = all_onegen_acc - all_onegen_ci_lower 207 | all_onegen_upper_err = all_onegen_ci_upper - all_onegen_acc 208 | all_onegen_err = np.array([all_onegen_lower_err, all_onegen_upper_err]) 209 | # Sort based on accuracy 210 | rank_order = np.flip(np.argsort(all_onegen_acc)) 211 | # Plot 212 | all_onegen_prob_type_names = ['Larger\ninterval', 'Longer\ntarget', 'Grouping', 'Interleaved\ndistractor', 'Letter-to-\nnumber', 'Reverse\norder'] 213 | x_points = np.arange(len(all_onegen_prob_types)) 214 | ax = plt.subplot(111) 215 | plt.bar(x_points, all_onegen_acc[rank_order], yerr=all_onegen_err[:,rank_order], color=gpt3_color, edgecolor='black', width=bar_width) 216 | plt.ylim([0,1]) 217 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize) 218 | plt.ylabel('Generative accuracy', fontsize=axis_label_fontsize) 219 | plt.xticks(x_points, np.array(all_onegen_prob_type_names)[rank_order], fontsize=plot_fontsize) 220 | plt.xlabel('Generalization type', fontsize=axis_label_fontsize) 221 | plt.title('One-generalization problems') 222 | plt.legend(['GPT-3'],fontsize=plot_fontsize,frameon=False) 223 | hide_top_right(ax) 224 | plt.tight_layout() 225 | plt.savefig(results_dir + 'onegen_acc.png', dpi=300, bbox_inches="tight") 226 | plt.close() 227 | # Save results 228 | np.savez(results_dir + 'onegen_acc.npz', all_acc=all_onegen_acc, all_err=all_onegen_err) 229 | 230 | # Calculate accuracy by number of generalizations 231 | gen_ind = [[0,6], [6,12], [12,18], [18,24]] 232 | all_gen_acc = [] 233 | all_gen_ci_lower = [] 234 | all_gen_ci_upper = [] 235 | for p in range(len(gen_ind)): 236 | correct_pred = np.array(all_prob_type_correct_pred[gen_ind[p][0]:gen_ind[p][1]]).flatten().astype(float) 237 | acc = correct_pred.mean() 238 | all_gen_acc.append(acc) 239 | ci_lower, ci_upper = proportion_confint(correct_pred.sum(), correct_pred.shape[0]) 240 | all_gen_ci_lower.append(ci_lower) 241 | all_gen_ci_upper.append(ci_upper) 242 | all_gen_ac = np.array(all_gen_acc) 243 | all_gen_ci_lower = np.array(all_gen_ci_lower) 244 | all_gen_ci_upper = np.array(all_gen_ci_upper) 245 | all_gen_lower_err = all_gen_acc - all_gen_ci_lower 246 | all_gen_upper_err = all_gen_ci_upper - all_gen_acc 247 | all_gen_err = np.array([all_gen_lower_err, all_gen_upper_err]) 248 | # Plot 249 | all_gen_prob_type_names = np.arange(len(gen_ind)).astype(str) 250 | x_points = np.arange(len(gen_ind)) 251 | ax = plt.subplot(111) 252 | plt.bar(x_points, all_gen_acc, yerr=all_gen_err, color=gpt3_color, edgecolor='black', width=bar_width) 253 | plt.ylim([0,1]) 254 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize) 255 | plt.ylabel('Generative accuracy', fontsize=axis_label_fontsize) 256 | plt.xticks(x_points, all_gen_prob_type_names, fontsize=plot_fontsize) 257 | plt.xlabel('Number of generalizations', fontsize=axis_label_fontsize) 258 | plt.legend(['GPT-3'],fontsize=plot_fontsize,frameon=False) 259 | hide_top_right(ax) 260 | plt.tight_layout() 261 | plt.savefig(results_dir + 'all_gen_acc.png', dpi=300, bbox_inches="tight") 262 | plt.close() 263 | # Save results 264 | np.savez(results_dir + 'all_gen_acc.npz', all_acc=all_gen_acc, all_err=all_gen_err) 265 | 266 | # Calculate accuracy for all real-world concept problems 267 | all_realworld_prob_types = ['realworld_succ', 'realworld_pred', 'realworld_add_letter', 'realworld_sort'] 268 | all_realworld_acc = [] 269 | all_realworld_ci_lower = [] 270 | all_realworld_ci_upper = [] 271 | for p in range(len(all_realworld_prob_types)): 272 | correct_pred = np.array(all_prob_type_correct_pred[np.where(np.array(prob_types)==all_realworld_prob_types[p])[0][0]]).astype(float) 273 | all_realworld_acc.append(correct_pred.mean()) 274 | ci_lower, ci_upper = proportion_confint(correct_pred.sum(), correct_pred.shape[0]) 275 | all_realworld_ci_lower.append(ci_lower) 276 | all_realworld_ci_upper.append(ci_upper) 277 | all_realworld_acc = np.array(all_realworld_acc) 278 | all_realworld_ci_lower = np.array(all_realworld_ci_lower) 279 | all_realworld_ci_upper = np.array(all_realworld_ci_upper) 280 | all_realworld_lower_err = all_realworld_acc - all_realworld_ci_lower 281 | all_realworld_upper_err = all_realworld_ci_upper - all_realworld_acc 282 | all_realworld_err = np.array([all_realworld_lower_err, all_realworld_upper_err]) 283 | # Sort based on accuracy 284 | rank_order = np.flip(np.argsort(all_realworld_acc)) 285 | # Plot 286 | all_realworld_prob_type_names = ['Successor', 'Predecessor', 'Extend\nsequence', 'Sort'] 287 | x_points = np.arange(len(all_realworld_prob_types)) 288 | ax = plt.subplot(111) 289 | plt.bar(x_points, all_realworld_acc[rank_order], yerr=all_realworld_err[:,rank_order], color=gpt3_color, edgecolor='black', width=bar_width) 290 | plt.ylim([0,1]) 291 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize) 292 | plt.ylabel('Generative accuracy', fontsize=axis_label_fontsize) 293 | plt.xticks(x_points, np.array(all_realworld_prob_type_names)[rank_order], fontsize=plot_fontsize) 294 | plt.xlabel('Transformation type', fontsize=axis_label_fontsize) 295 | plt.title('Real-world concept problems') 296 | plt.legend(['GPT-3'],fontsize=plot_fontsize,frameon=False) 297 | hide_top_right(ax) 298 | plt.tight_layout() 299 | plt.savefig(results_dir + 'realworld_acc.png', dpi=300, bbox_inches="tight") 300 | plt.close() 301 | # Save results 302 | np.savez(results_dir + 'realworld_acc.npz', all_acc=all_realworld_acc, all_err=all_realworld_err) 303 | -------------------------------------------------------------------------------- /digit_mat/analyze_gpt3_exp1.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import builtins 3 | from statsmodels.stats.proportion import proportion_confint 4 | import os 5 | 6 | def check_path(path): 7 | if not os.path.exists(path): 8 | os.mkdir(path) 9 | 10 | def hide_top_right(ax): 11 | ax.spines['right'].set_visible(False) 12 | ax.spines['top'].set_visible(False) 13 | ax.yaxis.set_ticks_position('left') 14 | ax.xaxis.set_ticks_position('bottom') 15 | 16 | # Load data 17 | all_data = np.load('./gpt_matprob_results.npz', allow_pickle=True) 18 | MC_correct_pred = all_data['all_MC_correct_pred'] 19 | gen_correct_pred = all_data['all_gen_correct_pred'] 20 | all_prob_types = builtins.list(MC_correct_pred.item().keys()) 21 | 22 | ## Analyze by major problem type 23 | correct_pred = {'combined_gen': [], 24 | 'combined_MC': [], 25 | 'one_rule_gen': [], 26 | 'one_rule_MC': [], 27 | 'two_rule_gen': [], 28 | 'two_rule_MC': [], 29 | 'three_rule_gen': [], 30 | 'three_rule_MC': [], 31 | 'logic_rule_gen': [], 32 | 'logic_rule_MC': []} 33 | for prob_type in all_prob_types: 34 | correct_pred['combined_gen'].append(gen_correct_pred.item()[prob_type]) 35 | correct_pred['combined_MC'].append(MC_correct_pred.item()[prob_type]) 36 | if 'constant' in prob_type or 'dist3' in prob_type or 'prog' in prob_type: 37 | correct_pred['one_rule_gen'].append(gen_correct_pred.item()[prob_type]) 38 | correct_pred['one_rule_MC'].append(MC_correct_pred.item()[prob_type]) 39 | elif 'two_rule' in prob_type: 40 | correct_pred['two_rule_gen'].append(gen_correct_pred.item()[prob_type]) 41 | correct_pred['two_rule_MC'].append(MC_correct_pred.item()[prob_type]) 42 | elif 'three_rule' in prob_type: 43 | correct_pred['three_rule_gen'].append(gen_correct_pred.item()[prob_type]) 44 | correct_pred['three_rule_MC'].append(MC_correct_pred.item()[prob_type]) 45 | elif 'union' in prob_type or 'AND' in prob_type or 'XOR' in prob_type: 46 | correct_pred['logic_rule_gen'].append(gen_correct_pred.item()[prob_type]) 47 | correct_pred['logic_rule_MC'].append(MC_correct_pred.item()[prob_type]) 48 | # Convert to arrays 49 | for key in correct_pred.keys(): 50 | correct_pred[key] = np.concatenate(correct_pred[key]) 51 | # Calculate accuracy and confidence intervals 52 | all_acc = {} 53 | all_ci_lower = {} 54 | all_ci_upper = {} 55 | for key in correct_pred.keys(): 56 | all_acc[key] = correct_pred[key].mean() 57 | all_ci_lower[key], all_ci_upper[key] = proportion_confint(correct_pred[key].sum(), correct_pred[key].shape[0]) 58 | 59 | # Directory for saving results 60 | results_dir = './exp1_GPT3_data/' 61 | check_path(results_dir) 62 | 63 | # All problems 64 | np.savez(results_dir + 'all_prob.npz', all_gen=correct_pred['combined_gen'], all_MC=correct_pred['combined_MC']) 65 | 66 | # Major problem types 67 | # Generative 68 | all_gen_acc = np.array([all_acc['one_rule_gen'], all_acc['two_rule_gen'], all_acc['three_rule_gen'], all_acc['logic_rule_gen']]) 69 | all_gen_lower_ci = np.array([all_ci_lower['one_rule_gen'], all_ci_lower['two_rule_gen'], all_ci_lower['three_rule_gen'], all_ci_lower['logic_rule_gen']]) 70 | all_gen_upper_ci = np.array([all_ci_upper['one_rule_gen'], all_ci_upper['two_rule_gen'], all_ci_upper['three_rule_gen'], all_ci_upper['logic_rule_gen']]) 71 | all_gen_lower_err = all_gen_acc - all_gen_lower_ci 72 | all_gen_upper_err = all_gen_upper_ci - all_gen_acc 73 | all_gen_err = np.array([all_gen_lower_err, all_gen_upper_err]) 74 | np.savez(results_dir + 'all_probcat_gen_acc.npz', acc=all_gen_acc, err=all_gen_err) 75 | # Multiple-choice 76 | all_MC_acc = np.array([all_acc['one_rule_MC'], all_acc['two_rule_MC'], all_acc['three_rule_MC'], all_acc['logic_rule_MC']]) 77 | all_MC_lower_ci = np.array([all_ci_lower['one_rule_MC'], all_ci_lower['two_rule_MC'], all_ci_lower['three_rule_MC'], all_ci_lower['logic_rule_MC']]) 78 | all_MC_upper_ci = np.array([all_ci_upper['one_rule_MC'], all_ci_upper['two_rule_MC'], all_ci_upper['three_rule_MC'], all_ci_upper['logic_rule_MC']]) 79 | all_MC_lower_err = all_MC_acc - all_MC_lower_ci 80 | all_MC_upper_err = all_MC_upper_ci - all_MC_acc 81 | all_MC_err = np.array([all_MC_lower_err, all_MC_upper_err]) 82 | np.savez(results_dir + 'all_probcat_MC_acc.npz', acc=all_MC_acc, err=all_MC_err) 83 | 84 | ## Relational complexity analysis (controlling for number of unique rules) 85 | N_unique_rules_2rule_prob = np.load('./N_unique_rules_2rule_prob.npz')['N_unique_rules'] 86 | N_unique_rules_3rule_prob = np.load('./N_unique_rules_3rule_prob.npz')['N_unique_rules'] 87 | correct_pred = {'gen_2rule_1unique': [], 88 | 'MC_2rule_1unique': [], 89 | 'gen_2rule_2unique': [], 90 | 'MC_2rule_2unique': [], 91 | 'gen_3rule_1unique': [], 92 | 'MC_3rule_1unique': [], 93 | 'gen_3rule_2unique': [], 94 | 'MC_3rule_2unique': [], 95 | 'gen_3rule_3unique': [], 96 | 'MC_3rule_3unique': []} 97 | for prob_type in all_prob_types: 98 | if 'two_rule' in prob_type: 99 | if N_unique_rules_2rule_prob[int(prob_type[-1])] == 1: 100 | correct_pred['gen_2rule_1unique'].append(gen_correct_pred.item()[prob_type]) 101 | correct_pred['MC_2rule_1unique'].append(MC_correct_pred.item()[prob_type]) 102 | elif N_unique_rules_2rule_prob[int(prob_type[-1])] == 2: 103 | correct_pred['gen_2rule_2unique'].append(gen_correct_pred.item()[prob_type]) 104 | correct_pred['MC_2rule_2unique'].append(MC_correct_pred.item()[prob_type]) 105 | elif 'three_rule' in prob_type: 106 | if N_unique_rules_3rule_prob[int(prob_type[-1])] == 1: 107 | correct_pred['gen_3rule_1unique'].append(gen_correct_pred.item()[prob_type]) 108 | correct_pred['MC_3rule_1unique'].append(MC_correct_pred.item()[prob_type]) 109 | elif N_unique_rules_3rule_prob[int(prob_type[-1])] == 2: 110 | correct_pred['gen_3rule_2unique'].append(gen_correct_pred.item()[prob_type]) 111 | correct_pred['MC_3rule_2unique'].append(MC_correct_pred.item()[prob_type]) 112 | elif N_unique_rules_3rule_prob[int(prob_type[-1])] == 3: 113 | correct_pred['gen_3rule_3unique'].append(gen_correct_pred.item()[prob_type]) 114 | correct_pred['MC_3rule_3unique'].append(MC_correct_pred.item()[prob_type]) 115 | # Convert to arrays 116 | for key in correct_pred.keys(): 117 | correct_pred[key] = np.concatenate(correct_pred[key]) 118 | # Calculate accuracy and confidence intervals 119 | all_acc = {} 120 | all_ci_lower = {} 121 | all_ci_upper = {} 122 | for key in correct_pred.keys(): 123 | all_acc[key] = correct_pred[key].mean() 124 | all_ci_lower[key], all_ci_upper[key] = proportion_confint(correct_pred[key].sum(), correct_pred[key].shape[0]) 125 | 126 | # Save results 127 | # Two-rule problems 128 | # Generative 129 | all_gen_acc = np.array([all_acc['gen_2rule_1unique'], all_acc['gen_2rule_2unique']]) 130 | all_gen_ci_lower = np.array([all_ci_lower['gen_2rule_1unique'], all_ci_lower['gen_2rule_2unique']]) 131 | all_gen_ci_upper = np.array([all_ci_upper['gen_2rule_1unique'], all_ci_upper['gen_2rule_2unique']]) 132 | all_gen_lower_err = all_gen_acc - all_gen_ci_lower 133 | all_gen_upper_err = all_gen_ci_upper - all_gen_acc 134 | all_gen_err = np.array([all_gen_lower_err, all_gen_upper_err]) 135 | np.savez(results_dir + 'tworule_prob_N_unique_rules_gen.npz', acc=all_gen_acc, err=all_gen_err) 136 | # Multiple-choice 137 | all_MC_acc = np.array([all_acc['MC_2rule_1unique'], all_acc['MC_2rule_2unique']]) 138 | all_MC_ci_lower = np.array([all_ci_lower['MC_2rule_1unique'], all_ci_lower['MC_2rule_2unique']]) 139 | all_MC_ci_upper = np.array([all_ci_upper['MC_2rule_1unique'], all_ci_upper['MC_2rule_2unique']]) 140 | all_MC_lower_err = all_MC_acc - all_MC_ci_lower 141 | all_MC_upper_err = all_MC_ci_upper - all_MC_acc 142 | all_MC_err = np.array([all_MC_lower_err, all_MC_upper_err]) 143 | np.savez(results_dir + 'tworule_prob_N_unique_rules_MC.npz', acc=all_MC_acc, err=all_MC_err) 144 | # Three-rule problems 145 | # Generative 146 | all_gen_acc = np.array([all_acc['gen_3rule_1unique'], all_acc['gen_3rule_2unique'], all_acc['gen_3rule_3unique']]) 147 | all_gen_ci_lower = np.array([all_ci_lower['gen_3rule_1unique'], all_ci_lower['gen_3rule_2unique'], all_ci_lower['gen_3rule_3unique']]) 148 | all_gen_ci_upper = np.array([all_ci_upper['gen_3rule_1unique'], all_ci_upper['gen_3rule_2unique'], all_ci_upper['gen_3rule_3unique']]) 149 | all_gen_lower_err = all_gen_acc - all_gen_ci_lower 150 | all_gen_upper_err = all_gen_ci_upper - all_gen_acc 151 | all_gen_err = np.array([all_gen_lower_err, all_gen_upper_err]) 152 | np.savez(results_dir + 'threerule_prob_N_unique_rules_gen.npz', acc=all_gen_acc, err=all_gen_err) 153 | # Multiple-choice 154 | all_MC_acc = np.array([all_acc['MC_3rule_1unique'], all_acc['MC_3rule_2unique'], all_acc['MC_3rule_3unique']]) 155 | all_MC_ci_lower = np.array([all_ci_lower['MC_3rule_1unique'], all_ci_lower['MC_3rule_2unique'], all_ci_lower['MC_3rule_3unique']]) 156 | all_MC_ci_upper = np.array([all_ci_upper['MC_3rule_1unique'], all_ci_upper['MC_3rule_2unique'], all_ci_upper['MC_3rule_3unique']]) 157 | all_MC_lower_err = all_MC_acc - all_MC_ci_lower 158 | all_MC_upper_err = all_MC_ci_upper - all_MC_acc 159 | all_MC_err = np.array([all_MC_lower_err, all_MC_upper_err]) 160 | np.savez(results_dir + 'threerule_prob_N_unique_rules_MC.npz', acc=all_MC_acc, err=all_MC_err) 161 | 162 | ## Compare problems with vs. without progression rule (two-rule problems) 163 | correct_pred = {'tworule_prog_gen': [], 164 | 'tworule_noprog_gen': []} 165 | for prob_type in all_prob_types: 166 | if prob_type == 'two_rule_comb2' or prob_type == 'two_rule_comb4' or prob_type == 'two_rule_comb5': 167 | correct_pred['tworule_prog_gen'].append(gen_correct_pred.item()[prob_type]) 168 | elif prob_type == 'two_rule_comb0' or prob_type == 'two_rule_comb1' or prob_type == 'two_rule_comb3': 169 | correct_pred['tworule_noprog_gen'].append(gen_correct_pred.item()[prob_type]) 170 | # Convert to arrays 171 | for key in correct_pred.keys(): 172 | correct_pred[key] = np.concatenate(correct_pred[key]) 173 | # Calculate accuracy and confidence intervals 174 | all_acc = {} 175 | all_ci_lower = {} 176 | all_ci_upper = {} 177 | for key in correct_pred.keys(): 178 | all_acc[key] = correct_pred[key].mean() 179 | all_ci_lower[key], all_ci_upper[key] = proportion_confint(correct_pred[key].sum(), correct_pred[key].shape[0]) 180 | 181 | # Save results 182 | all_gen_acc = np.array([all_acc['tworule_noprog_gen'], all_acc['tworule_prog_gen']]) 183 | all_gen_lower_ci = np.array([all_ci_lower['tworule_noprog_gen'], all_ci_lower['tworule_prog_gen']]) 184 | all_gen_upper_ci = np.array([all_ci_upper['tworule_noprog_gen'], all_ci_upper['tworule_prog_gen']]) 185 | all_gen_lower_err = all_gen_acc - all_gen_lower_ci 186 | all_gen_upper_err = all_gen_upper_ci - all_gen_acc 187 | all_gen_err = np.array([all_gen_lower_err, all_gen_upper_err]) 188 | np.savez(results_dir + 'tworule_prog_vs_noprog_gen_acc.npz', acc=all_gen_acc, err=all_gen_err) 189 | 190 | # Three major one-rule problem types 191 | correct_pred = {'constant_gen': [], 192 | 'constant_MC': [], 193 | 'dist3_gen': [], 194 | 'dist3_MC': [], 195 | 'prog_gen': [], 196 | 'prog_MC': []} 197 | for prob_type in all_prob_types: 198 | if 'constant' in prob_type: 199 | correct_pred['constant_gen'].append(gen_correct_pred.item()[prob_type]) 200 | correct_pred['constant_MC'].append(MC_correct_pred.item()[prob_type]) 201 | elif 'dist3' in prob_type: 202 | correct_pred['dist3_gen'].append(gen_correct_pred.item()[prob_type]) 203 | correct_pred['dist3_MC'].append(MC_correct_pred.item()[prob_type]) 204 | elif 'prog' in prob_type: 205 | correct_pred['prog_gen'].append(gen_correct_pred.item()[prob_type]) 206 | correct_pred['prog_MC'].append(MC_correct_pred.item()[prob_type]) 207 | # Convert to arrays 208 | for key in correct_pred.keys(): 209 | correct_pred[key] = np.concatenate(correct_pred[key]) 210 | # Calculate accuracy and confidence intervals 211 | all_acc = {} 212 | all_ci_lower = {} 213 | all_ci_upper = {} 214 | for key in correct_pred.keys(): 215 | all_acc[key] = correct_pred[key].mean() 216 | all_ci_lower[key], all_ci_upper[key] = proportion_confint(correct_pred[key].sum(), correct_pred[key].shape[0]) 217 | 218 | # Save results 219 | # Generative 220 | all_gen_acc = np.array([all_acc['constant_gen'], all_acc['dist3_gen'], all_acc['prog_gen']]) 221 | all_gen_lower_ci = np.array([all_ci_lower['constant_gen'], all_ci_lower['dist3_gen'], all_ci_lower['prog_gen']]) 222 | all_gen_upper_ci = np.array([all_ci_upper['constant_gen'], all_ci_upper['dist3_gen'], all_ci_upper['prog_gen']]) 223 | all_gen_lower_err = all_gen_acc - all_gen_lower_ci 224 | all_gen_upper_err = all_gen_upper_ci - all_gen_acc 225 | all_gen_err = np.array([all_gen_lower_err, all_gen_upper_err]) 226 | np.savez(results_dir + 'all_onerule_gen_acc.npz', acc=all_gen_acc, err=all_gen_err) 227 | # Multiple-choice 228 | all_MC_acc = np.array([all_acc['constant_MC'], all_acc['dist3_MC'], all_acc['prog_MC']]) 229 | all_MC_lower_ci = np.array([all_ci_lower['constant_MC'], all_ci_lower['dist3_MC'], all_ci_lower['prog_MC']]) 230 | all_MC_upper_ci = np.array([all_ci_upper['constant_MC'], all_ci_upper['dist3_MC'], all_ci_upper['prog_MC']]) 231 | all_MC_lower_err = all_MC_acc - all_MC_lower_ci 232 | all_MC_upper_err = all_MC_upper_ci - all_MC_acc 233 | all_MC_err = np.array([all_MC_lower_err, all_MC_upper_err]) 234 | np.savez(results_dir + 'all_onerule_MC_acc.npz', acc=all_MC_acc, err=all_MC_err) 235 | 236 | # Permuted vs. non-permuted logic problems 237 | correct_pred = {'aligned_gen': [], 238 | 'aligned_MC': [], 239 | 'permuted_gen': [], 240 | 'permuted_MC': []} 241 | for prob_type in all_prob_types: 242 | if 'union' in prob_type or 'AND' in prob_type or 'XOR' in prob_type: 243 | if 'permuted' in prob_type: 244 | correct_pred['permuted_gen'].append(gen_correct_pred.item()[prob_type]) 245 | correct_pred['permuted_MC'].append(MC_correct_pred.item()[prob_type]) 246 | else: 247 | correct_pred['aligned_gen'].append(gen_correct_pred.item()[prob_type]) 248 | correct_pred['aligned_MC'].append(MC_correct_pred.item()[prob_type]) 249 | # Convert to arrays 250 | for key in correct_pred.keys(): 251 | correct_pred[key] = np.concatenate(correct_pred[key]) 252 | # Calculate accuracy and confidence intervals 253 | all_acc = {} 254 | all_ci_lower = {} 255 | all_ci_upper = {} 256 | for key in correct_pred.keys(): 257 | all_acc[key] = correct_pred[key].mean() 258 | all_ci_lower[key], all_ci_upper[key] = proportion_confint(correct_pred[key].sum(), correct_pred[key].shape[0]) 259 | 260 | # Save results 261 | # Generative 262 | all_gen_acc = np.array([all_acc['aligned_gen'], all_acc['permuted_gen']]) 263 | all_gen_lower_ci = np.array([all_ci_lower['aligned_gen'], all_ci_lower['permuted_gen']]) 264 | all_gen_upper_ci = np.array([all_ci_upper['aligned_gen'], all_ci_upper['permuted_gen']]) 265 | all_gen_lower_err = all_gen_acc - all_gen_lower_ci 266 | all_gen_upper_err = all_gen_upper_ci - all_gen_acc 267 | all_gen_err = np.array([all_gen_lower_err, all_gen_upper_err]) 268 | np.savez(results_dir + 'aligned_vs_permuted_gen_acc.npz', acc=all_gen_acc, err=all_gen_err) 269 | # Multiple-choice 270 | all_MC_acc = np.array([all_acc['aligned_MC'], all_acc['permuted_MC']]) 271 | all_MC_lower_ci = np.array([all_ci_lower['aligned_MC'], all_ci_lower['permuted_MC']]) 272 | all_MC_upper_ci = np.array([all_ci_upper['aligned_MC'], all_ci_upper['permuted_MC']]) 273 | all_MC_lower_err = all_MC_acc - all_MC_lower_ci 274 | all_MC_upper_err = all_MC_upper_ci - all_MC_acc 275 | all_MC_err = np.array([all_MC_lower_err, all_MC_upper_err]) 276 | np.savez(results_dir + 'aligned_vs_permuted_MC_acc.npz', acc=all_MC_acc, err=all_MC_err) 277 | 278 | 279 | 280 | -------------------------------------------------------------------------------- /digit_mat/exp1_plot_GPT3_vs_human.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import os 4 | 5 | def check_path(path): 6 | if not os.path.exists(path): 7 | os.mkdir(path) 8 | 9 | def hide_top_right(ax): 10 | ax.spines['right'].set_visible(False) 11 | ax.spines['top'].set_visible(False) 12 | ax.yaxis.set_ticks_position('left') 13 | ax.xaxis.set_ticks_position('bottom') 14 | 15 | # Digit matrices 16 | # 1-rule, 2-rule, 3-rule, and logic-rule problems 17 | # GPT-3: zero-shot 18 | # Humans: shuffled 19 | 20 | ## Analysis of major problem types 21 | 22 | # Load GPT-3 data 23 | gpt3_gen_acc = np.load('./exp1_GPT3_data/all_probcat_gen_acc.npz') 24 | gpt3_MC_acc = np.load('./exp1_GPT3_data/all_probcat_MC_acc.npz') 25 | gpt3_gen_acc_mn = gpt3_gen_acc['acc'] 26 | gpt3_gen_acc_err = gpt3_gen_acc['err'] 27 | gpt3_MC_acc_mn = gpt3_MC_acc['acc'] 28 | gpt3_MC_acc_err = gpt3_MC_acc['err'] 29 | 30 | # Load human data 31 | human_gen_acc = np.load('./exp1_behavioral_data/probcat_gen_acc_behavior.npz') 32 | human_MC_acc = np.load('./exp1_behavioral_data/probcat_MC_acc_behavior.npz') 33 | human_gen_acc_mn = human_gen_acc['acc'] 34 | human_gen_acc_err = human_gen_acc['err'] 35 | human_MC_acc_mn = human_MC_acc['acc'] 36 | human_MC_acc_err = human_MC_acc['err'] 37 | 38 | # Plot settings 39 | N_conds = gpt3_MC_acc_mn.shape[0] 40 | total_bar_width = 0.8 41 | ind_bar_width = total_bar_width / 2 42 | x_points = np.arange(N_conds) 43 | gpt3_color = 'darkslateblue' 44 | human_color = 'powderblue' 45 | plot_fontsize = 14 46 | title_fontsize = 16 47 | 48 | # Directory 49 | plot_dir = './exp1_GPT3_vs_human/' 50 | check_path(plot_dir) 51 | 52 | # Plot - generative 53 | ax = plt.subplot(111) 54 | plt.bar(x_points - 0.2, gpt3_gen_acc_mn, yerr=gpt3_gen_acc_err, color=gpt3_color, edgecolor='black', width=ind_bar_width) 55 | plt.bar(x_points + 0.2, human_gen_acc_mn, yerr=human_gen_acc_err, color=human_color, edgecolor='black', width=ind_bar_width) 56 | plt.ylim([0,1]) 57 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize) 58 | plt.ylabel('Generative accuracy', fontsize=plot_fontsize) 59 | plt.xticks(x_points, ['1-rule', '2-rule', '3-rule', 'Logic'], fontsize=plot_fontsize) 60 | plt.xlabel('Problem type', fontsize=plot_fontsize) 61 | hide_top_right(ax) 62 | plt.legend(['GPT-3','Human'],fontsize=plot_fontsize,frameon=False,bbox_to_anchor=(1.2,1)) 63 | results_fname = plot_dir + 'gen_gpt3_vs_human.png' 64 | ax.set_aspect(2.5) 65 | plt.tight_layout() 66 | plt.savefig(results_fname, dpi=300, bbox_inches="tight") 67 | plt.close() 68 | 69 | # Plot - multiple-choice 70 | ax = plt.subplot(111) 71 | plt.bar(x_points - 0.2, gpt3_MC_acc_mn, yerr=gpt3_MC_acc_err, color=gpt3_color, edgecolor='black', width=ind_bar_width) 72 | plt.bar(x_points + 0.2, human_MC_acc_mn, yerr=human_MC_acc_err, color=human_color, edgecolor='black', width=ind_bar_width) 73 | plt.ylim([0,1]) 74 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize) 75 | plt.ylabel('Multiple choice accuracy', fontsize=plot_fontsize) 76 | plt.xticks(x_points, ['1-rule', '2-rule', '3-rule', 'Logic'], fontsize=plot_fontsize) 77 | plt.xlabel('Problem type', fontsize=plot_fontsize) 78 | hide_top_right(ax) 79 | plt.legend(['GPT-3','Human'],fontsize=plot_fontsize,frameon=False,bbox_to_anchor=(1.2,1)) 80 | results_fname = plot_dir + 'MC_gpt3_vs_human.png' 81 | ax.set_aspect(2.5) 82 | plt.tight_layout() 83 | plt.savefig(results_fname, dpi=300, bbox_inches="tight") 84 | plt.close() 85 | 86 | ## Rule type analysis for one-rule problems 87 | 88 | # Load GPT-3 one-rule data 89 | gpt3_gen_acc_onerule = np.load('./exp1_GPT3_data/all_onerule_gen_acc.npz') 90 | gpt3_MC_acc_onerule = np.load('./exp1_GPT3_data/all_onerule_MC_acc.npz') 91 | gpt3_gen_acc_onerule_mn = gpt3_gen_acc_onerule['acc'] 92 | gpt3_gen_acc_onerule_err = gpt3_gen_acc_onerule['err'] 93 | gpt3_MC_acc_onerule_mn = gpt3_MC_acc_onerule['acc'] 94 | gpt3_MC_acc_onerule_err = gpt3_MC_acc_onerule['err'] 95 | 96 | # Load human one-rule data 97 | human_gen_acc_onerule = np.load('./exp1_behavioral_data/probcat_gen_acc_behavior_onerule.npz') 98 | human_MC_acc_onerule = np.load('./exp1_behavioral_data/probcat_MC_acc_behavior_onerule.npz') 99 | human_gen_acc_onerule_mn = human_gen_acc_onerule['acc'] 100 | human_gen_acc_onerule_err = human_gen_acc_onerule['err'] 101 | human_MC_acc_onerule_mn = human_MC_acc_onerule['acc'] 102 | human_MC_acc_onerule_err = human_MC_acc_onerule['err'] 103 | 104 | # Plot settings 105 | N_conds = gpt3_MC_acc_onerule_mn.shape[0] 106 | x_points = np.arange(N_conds) 107 | 108 | # Plot - generative 109 | ax = plt.subplot(111) 110 | plt.bar(x_points - 0.2, gpt3_gen_acc_onerule_mn, yerr=gpt3_gen_acc_onerule_err, color=gpt3_color, edgecolor='black', width=ind_bar_width) 111 | plt.bar(x_points + 0.2, human_gen_acc_onerule_mn, yerr=human_gen_acc_onerule_err, color=human_color, edgecolor='black', width=ind_bar_width) 112 | plt.ylim([0,1]) 113 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize) 114 | plt.ylabel('Generative accuracy', fontsize=plot_fontsize) 115 | plt.xticks(x_points, ['Constant', 'Distribution', 'Progression'], fontsize=plot_fontsize) 116 | plt.xlabel('Rule type', fontsize=plot_fontsize) 117 | hide_top_right(ax) 118 | plt.legend(['GPT-3','Human'],fontsize=plot_fontsize,frameon=False,bbox_to_anchor=(1.2,1)) 119 | plt.title('One-rule problems', fontsize=title_fontsize) 120 | results_fname = plot_dir + 'onerule_gen_gpt3_vs_human.png' 121 | ax.set_aspect(2.5) 122 | plt.tight_layout() 123 | plt.savefig(results_fname, dpi=300, bbox_inches="tight") 124 | plt.close() 125 | 126 | # Plot - multiple-choice 127 | ax = plt.subplot(111) 128 | plt.bar(x_points - 0.2, gpt3_MC_acc_onerule_mn, yerr=gpt3_MC_acc_onerule_err, color=gpt3_color, edgecolor='black', width=ind_bar_width) 129 | plt.bar(x_points + 0.2, human_MC_acc_onerule_mn, yerr=human_MC_acc_onerule_err, color=human_color, edgecolor='black', width=ind_bar_width) 130 | plt.ylim([0,1]) 131 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize) 132 | plt.ylabel('Multiple choice accuracy', fontsize=plot_fontsize) 133 | plt.xticks(x_points, ['Constant', 'Distribution', 'Progression'], fontsize=plot_fontsize) 134 | plt.xlabel('Rule type', fontsize=plot_fontsize) 135 | hide_top_right(ax) 136 | plt.legend(['GPT-3','Human'],fontsize=plot_fontsize,frameon=False,bbox_to_anchor=(1.2,1)) 137 | plt.title('One-rule problems', fontsize=title_fontsize) 138 | results_fname = plot_dir + 'onerule_MC_gpt3_vs_human.png' 139 | ax.set_aspect(2.5) 140 | plt.tight_layout() 141 | plt.savefig(results_fname, dpi=300, bbox_inches="tight") 142 | plt.close() 143 | 144 | ## Progression vs. no progression two-rule problems 145 | 146 | # Load GPT-3 one-rule data 147 | gpt3_gen_acc_tworule = np.load('./exp1_GPT3_data/tworule_prog_vs_noprog_gen_acc.npz') 148 | gpt3_gen_acc_tworule_mn = gpt3_gen_acc_tworule['acc'] 149 | gpt3_gen_acc_tworule_err = gpt3_gen_acc_tworule['err'] 150 | 151 | # Load human one-rule data 152 | human_gen_acc_tworule = np.load('./exp1_behavioral_data/probcat_gen_acc_behavior_prog_tworule.npz') 153 | human_gen_acc_tworule_mn = human_gen_acc_tworule['acc'] 154 | human_gen_acc_tworule_err = human_gen_acc_tworule['err'] 155 | 156 | # Plot settings 157 | N_conds = gpt3_gen_acc_tworule_mn.shape[0] 158 | x_points = np.arange(N_conds) 159 | 160 | # Plot - generative 161 | ax = plt.subplot(111) 162 | plt.bar(x_points - 0.2, gpt3_gen_acc_tworule_mn, yerr=gpt3_gen_acc_tworule_err, color=gpt3_color, edgecolor='black', width=ind_bar_width) 163 | plt.bar(x_points + 0.2, human_gen_acc_tworule_mn, yerr=human_gen_acc_tworule_err, color=human_color, edgecolor='black', width=ind_bar_width) 164 | plt.ylim([0,1]) 165 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize) 166 | plt.ylabel('Generative accuracy', fontsize=plot_fontsize) 167 | plt.xticks(x_points, ['No progression', 'Progression'], fontsize=plot_fontsize) 168 | plt.xlabel(' ', fontsize=plot_fontsize) 169 | hide_top_right(ax) 170 | plt.legend(['GPT-3','Human'],fontsize=plot_fontsize,frameon=False,bbox_to_anchor=(0.85,1)) 171 | plt.title('Two-rule problems', fontsize=title_fontsize) 172 | results_fname = plot_dir + 'tworule_prog_vs_noprog_gen_gpt3_vs_human.png' 173 | ax.set_aspect(2.5) 174 | plt.tight_layout() 175 | plt.savefig(results_fname, dpi=300, bbox_inches="tight") 176 | plt.close() 177 | 178 | ## Aligned vs. permuted logic problems 179 | 180 | # Load GPT-3 one-rule data 181 | gpt3_gen_acc_logic = np.load('./exp1_GPT3_data/aligned_vs_permuted_gen_acc.npz') 182 | gpt3_MC_acc_logic = np.load('./exp1_GPT3_data/aligned_vs_permuted_MC_acc.npz') 183 | gpt3_gen_acc_logic_mn = gpt3_gen_acc_logic['acc'] 184 | gpt3_gen_acc_logic_err = gpt3_gen_acc_logic['err'] 185 | gpt3_MC_acc_logic_mn = gpt3_MC_acc_logic['acc'] 186 | gpt3_MC_acc_logic_err = gpt3_MC_acc_logic['err'] 187 | 188 | # Load human one-rule data 189 | human_gen_acc_logic = np.load('./exp1_behavioral_data/aligned_vs_permuted_gen_acc_behavior.npz') 190 | human_MC_acc_logic = np.load('./exp1_behavioral_data/aligned_vs_permuted_MC_acc_behavior.npz') 191 | human_gen_acc_logic_mn = human_gen_acc_logic['acc'] 192 | human_gen_acc_logic_err = human_gen_acc_logic['err'] 193 | human_MC_acc_logic_mn = human_MC_acc_logic['acc'] 194 | human_MC_acc_logic_err = human_MC_acc_logic['err'] 195 | 196 | # Plot settings 197 | N_conds = gpt3_MC_acc_logic_mn.shape[0] 198 | x_points = np.arange(N_conds) 199 | 200 | # Plot - generative 201 | ax = plt.subplot(111) 202 | plt.bar(x_points - 0.2, gpt3_gen_acc_logic_mn, yerr=gpt3_gen_acc_logic_err, color=gpt3_color, edgecolor='black', width=ind_bar_width) 203 | plt.bar(x_points + 0.2, human_gen_acc_logic_mn, yerr=human_gen_acc_logic_err, color=human_color, edgecolor='black', width=ind_bar_width) 204 | plt.ylim([0,1]) 205 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize) 206 | plt.ylabel('Generative accuracy', fontsize=plot_fontsize) 207 | plt.xticks(x_points, ['Aligned', 'Permuted'], fontsize=plot_fontsize) 208 | plt.xlabel(' ', fontsize=plot_fontsize) 209 | hide_top_right(ax) 210 | plt.legend(['GPT-3','Human'],fontsize=plot_fontsize,frameon=False,bbox_to_anchor=(1.2,1)) 211 | plt.title('Logic problems', fontsize=title_fontsize) 212 | results_fname = plot_dir + 'logic_aligned_vs_permuted_gen_gpt3_vs_human.png' 213 | ax.set_aspect(2.5) 214 | plt.tight_layout() 215 | plt.savefig(results_fname, dpi=300, bbox_inches="tight") 216 | plt.close() 217 | 218 | # Plot - multiple-choice 219 | ax = plt.subplot(111) 220 | plt.bar(x_points - 0.2, gpt3_MC_acc_logic_mn, yerr=gpt3_MC_acc_logic_err, color=gpt3_color, edgecolor='black', width=ind_bar_width) 221 | plt.bar(x_points + 0.2, human_MC_acc_logic_mn, yerr=human_MC_acc_logic_err, color=human_color, edgecolor='black', width=ind_bar_width) 222 | plt.ylim([0,1]) 223 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize) 224 | plt.ylabel('Multiple choice accuracy', fontsize=plot_fontsize) 225 | plt.xticks(x_points, ['Aligned', 'Permuted'], fontsize=plot_fontsize) 226 | plt.xlabel(' ', fontsize=plot_fontsize) 227 | hide_top_right(ax) 228 | plt.legend(['GPT-3','Human'],fontsize=plot_fontsize,frameon=False,bbox_to_anchor=(1.2,1)) 229 | plt.title('Logic problems', fontsize=title_fontsize) 230 | results_fname = plot_dir + 'logic_aligned_vs_permuted_MC_gpt3_vs_human.png' 231 | ax.set_aspect(2.5) 232 | plt.tight_layout() 233 | plt.savefig(results_fname, dpi=300, bbox_inches="tight") 234 | plt.close() 235 | 236 | ### Relational complexity analysis 237 | 238 | ## Two-rule problems 239 | 240 | # Load GPT-3 one-rule data 241 | gpt3_gen_acc_tworule = np.load('./exp1_GPT3_data/tworule_prob_N_unique_rules_gen.npz') 242 | gpt3_MC_acc_tworule = np.load('./exp1_GPT3_data/tworule_prob_N_unique_rules_MC.npz') 243 | gpt3_gen_acc_tworule_mn = gpt3_gen_acc_tworule['acc'] 244 | gpt3_gen_acc_tworule_err = gpt3_gen_acc_tworule['err'] 245 | gpt3_MC_acc_tworule_mn = gpt3_MC_acc_tworule['acc'] 246 | gpt3_MC_acc_tworule_err = gpt3_MC_acc_tworule['err'] 247 | 248 | # Load human one-rule data 249 | human_gen_acc_tworule = np.load('./exp1_behavioral_data/tworule_prob_N_unique_rules_gen.npz') 250 | human_MC_acc_tworule = np.load('./exp1_behavioral_data/tworule_prob_N_unique_rules_MC.npz') 251 | human_gen_acc_tworule_mn = human_gen_acc_tworule['acc'] 252 | human_gen_acc_tworule_err = human_gen_acc_tworule['err'] 253 | human_MC_acc_tworule_mn = human_MC_acc_tworule['acc'] 254 | human_MC_acc_tworule_err = human_MC_acc_tworule['err'] 255 | 256 | # Plot settings 257 | N_conds = gpt3_MC_acc_tworule_mn.shape[0] 258 | x_points = np.arange(N_conds) 259 | 260 | # Plot - generative 261 | ax = plt.subplot(111) 262 | plt.bar(x_points - 0.2, gpt3_gen_acc_tworule_mn, yerr=gpt3_gen_acc_tworule_err, color=gpt3_color, edgecolor='black', width=ind_bar_width) 263 | plt.bar(x_points + 0.2, human_gen_acc_tworule_mn, yerr=human_gen_acc_tworule_err, color=human_color, edgecolor='black', width=ind_bar_width) 264 | plt.ylim([0,1]) 265 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize) 266 | plt.ylabel('Generative accuracy', fontsize=plot_fontsize) 267 | plt.xticks(x_points, ['1','2'], fontsize=plot_fontsize) 268 | plt.xlabel('Number of unique rules', fontsize=plot_fontsize) 269 | hide_top_right(ax) 270 | plt.legend(['GPT-3','Human'],fontsize=plot_fontsize,frameon=False,bbox_to_anchor=(1.2,1)) 271 | plt.title('Two-rule problems', fontsize=title_fontsize) 272 | results_fname = plot_dir + 'tworule_rel_complexity_gen_gpt3_vs_human.png' 273 | ax.set_aspect(2.5) 274 | plt.tight_layout() 275 | plt.savefig(results_fname, dpi=300, bbox_inches="tight") 276 | plt.close() 277 | 278 | # Plot - multiple-choice 279 | ax = plt.subplot(111) 280 | plt.bar(x_points - 0.2, gpt3_MC_acc_tworule_mn, yerr=gpt3_MC_acc_tworule_err, color=gpt3_color, edgecolor='black', width=ind_bar_width) 281 | plt.bar(x_points + 0.2, human_MC_acc_tworule_mn, yerr=human_MC_acc_tworule_err, color=human_color, edgecolor='black', width=ind_bar_width) 282 | plt.ylim([0,1]) 283 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize) 284 | plt.ylabel('Multiple choice accuracy', fontsize=plot_fontsize) 285 | plt.xticks(x_points, ['1','2'], fontsize=plot_fontsize) 286 | plt.xlabel('Number of unique rules', fontsize=plot_fontsize) 287 | hide_top_right(ax) 288 | plt.legend(['GPT-3','Human'],fontsize=plot_fontsize,frameon=False,bbox_to_anchor=(0.75,1)) 289 | plt.title('Two-rule problems', fontsize=title_fontsize) 290 | results_fname = plot_dir + 'tworule_rel_complexity_MC_gpt3_vs_human.png' 291 | ax.set_aspect(2.5) 292 | plt.tight_layout() 293 | plt.savefig(results_fname, dpi=300, bbox_inches="tight") 294 | plt.close() 295 | 296 | ## Three-rule problems 297 | 298 | # Load GPT-3 one-rule data 299 | gpt3_gen_acc_threerule = np.load('./exp1_GPT3_data/threerule_prob_N_unique_rules_gen.npz') 300 | gpt3_MC_acc_threerule = np.load('./exp1_GPT3_data/threerule_prob_N_unique_rules_MC.npz') 301 | gpt3_gen_acc_threerule_mn = gpt3_gen_acc_threerule['acc'] 302 | gpt3_gen_acc_threerule_err = gpt3_gen_acc_threerule['err'] 303 | gpt3_MC_acc_threerule_mn = gpt3_MC_acc_threerule['acc'] 304 | gpt3_MC_acc_threerule_err = gpt3_MC_acc_threerule['err'] 305 | 306 | # Load human one-rule data 307 | human_gen_acc_threerule = np.load('./exp1_behavioral_data/threerule_prob_N_unique_rules_gen.npz') 308 | human_MC_acc_threerule = np.load('./exp1_behavioral_data/threerule_prob_N_unique_rules_MC.npz') 309 | human_gen_acc_threerule_mn = human_gen_acc_threerule['acc'] 310 | human_gen_acc_threerule_err = human_gen_acc_threerule['err'] 311 | human_MC_acc_threerule_mn = human_MC_acc_threerule['acc'] 312 | human_MC_acc_threerule_err = human_MC_acc_threerule['err'] 313 | 314 | # Plot settings 315 | N_conds = gpt3_MC_acc_threerule_mn.shape[0] 316 | x_points = np.arange(N_conds) 317 | 318 | # Plot - generative 319 | ax = plt.subplot(111) 320 | plt.bar(x_points - 0.2, gpt3_gen_acc_threerule_mn, yerr=gpt3_gen_acc_threerule_err, color=gpt3_color, edgecolor='black', width=ind_bar_width) 321 | plt.bar(x_points + 0.2, human_gen_acc_threerule_mn, yerr=human_gen_acc_threerule_err, color=human_color, edgecolor='black', width=ind_bar_width) 322 | plt.ylim([0,1]) 323 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize) 324 | plt.ylabel('Generative accuracy', fontsize=plot_fontsize) 325 | plt.xticks(x_points, ['1','2','3'], fontsize=plot_fontsize) 326 | plt.xlabel('Number of unique rules', fontsize=plot_fontsize) 327 | hide_top_right(ax) 328 | plt.legend(['GPT-3','Human'],fontsize=plot_fontsize,frameon=False,bbox_to_anchor=(1.2,1)) 329 | plt.title('Three-rule problems', fontsize=title_fontsize) 330 | results_fname = plot_dir + 'threerule_rel_complexity_gen_gpt3_vs_human.png' 331 | ax.set_aspect(2.5) 332 | plt.tight_layout() 333 | plt.savefig(results_fname, dpi=300, bbox_inches="tight") 334 | plt.close() 335 | 336 | # Plot - multiple-choice 337 | ax = plt.subplot(111) 338 | plt.bar(x_points - 0.2, gpt3_MC_acc_threerule_mn, yerr=gpt3_MC_acc_threerule_err, color=gpt3_color, edgecolor='black', width=ind_bar_width) 339 | plt.bar(x_points + 0.2, human_MC_acc_threerule_mn, yerr=human_MC_acc_threerule_err, color=human_color, edgecolor='black', width=ind_bar_width) 340 | plt.ylim([0,1]) 341 | plt.yticks([0,0.2,0.4,0.6,0.8,1],['0','0.2','0.4','0.6','0.8','1'], fontsize=plot_fontsize) 342 | plt.ylabel('Multiple choice accuracy', fontsize=plot_fontsize) 343 | plt.xticks(x_points, ['1','2','3'], fontsize=plot_fontsize) 344 | plt.xlabel('Number of unique rules', fontsize=plot_fontsize) 345 | hide_top_right(ax) 346 | plt.legend(['GPT-3','Human'],fontsize=plot_fontsize,frameon=False,bbox_to_anchor=(1.2,1)) 347 | plt.title('Three-rule problems', fontsize=title_fontsize) 348 | results_fname = plot_dir + 'threerule_rel_complexity_MC_gpt3_vs_human.png' 349 | ax.set_aspect(2.5) 350 | plt.tight_layout() 351 | plt.savefig(results_fname, dpi=300, bbox_inches="tight") 352 | plt.close() 353 | 354 | -------------------------------------------------------------------------------- /story_analogies/human_vs_gpt3_data.csv: -------------------------------------------------------------------------------- 1 | probID,subjID,correct_pred,analogy_vs_similarity,human_vs_gpt 2 | 1,1,1,0,0 3 | 2,1,1,0,0 4 | 3,1,1,0,0 5 | 4,1,1,1,0 6 | 5,1,1,1,0 7 | 6,1,1,1,0 8 | 7,1,1,0,0 9 | 8,1,1,1,0 10 | 9,1,1,1,0 11 | 10,1,1,1,0 12 | 11,1,1,1,0 13 | 12,1,1,0,0 14 | 13,1,0,0,0 15 | 14,1,1,0,0 16 | 15,1,1,1,0 17 | 16,1,1,0,0 18 | 17,1,1,1,0 19 | 18,1,1,0,0 20 | 1,2,1,0,0 21 | 2,2,1,0,0 22 | 3,2,1,1,0 23 | 4,2,1,0,0 24 | 5,2,1,1,0 25 | 6,2,1,0,0 26 | 7,2,1,1,0 27 | 8,2,1,0,0 28 | 9,2,1,1,0 29 | 10,2,1,1,0 30 | 11,2,1,0,0 31 | 12,2,1,0,0 32 | 13,2,1,1,0 33 | 14,2,1,1,0 34 | 15,2,1,1,0 35 | 16,2,1,0,0 36 | 17,2,1,0,0 37 | 18,2,1,1,0 38 | 1,3,0,0,0 39 | 2,3,1,0,0 40 | 3,3,1,0,0 41 | 4,3,1,0,0 42 | 5,3,1,1,0 43 | 6,3,0,1,0 44 | 7,3,1,0,0 45 | 8,3,1,0,0 46 | 9,3,1,1,0 47 | 10,3,0,1,0 48 | 11,3,0,0,0 49 | 12,3,1,1,0 50 | 13,3,1,0,0 51 | 14,3,0,1,0 52 | 15,3,1,1,0 53 | 16,3,0,0,0 54 | 17,3,1,1,0 55 | 18,3,1,1,0 56 | 1,4,1,0,0 57 | 2,4,1,0,0 58 | 3,4,1,0,0 59 | 4,4,1,1,0 60 | 5,4,1,1,0 61 | 6,4,1,0,0 62 | 7,4,1,0,0 63 | 8,4,1,1,0 64 | 9,4,1,0,0 65 | 10,4,1,1,0 66 | 11,4,1,1,0 67 | 12,4,1,0,0 68 | 13,4,1,1,0 69 | 14,4,1,1,0 70 | 15,4,1,0,0 71 | 16,4,1,1,0 72 | 17,4,1,1,0 73 | 18,4,1,0,0 74 | 1,5,1,0,0 75 | 2,5,1,0,0 76 | 3,5,1,0,0 77 | 4,5,1,1,0 78 | 5,5,1,0,0 79 | 6,5,1,0,0 80 | 7,5,1,0,0 81 | 8,5,1,0,0 82 | 9,5,1,1,0 83 | 10,5,1,1,0 84 | 11,5,1,1,0 85 | 12,5,1,1,0 86 | 13,5,1,1,0 87 | 14,5,1,1,0 88 | 15,5,1,1,0 89 | 16,5,1,0,0 90 | 17,5,1,1,0 91 | 18,5,1,0,0 92 | 1,6,1,0,0 93 | 2,6,1,1,0 94 | 3,6,1,1,0 95 | 4,6,1,1,0 96 | 5,6,1,0,0 97 | 6,6,1,1,0 98 | 7,6,1,1,0 99 | 8,6,1,1,0 100 | 9,6,1,1,0 101 | 10,6,1,0,0 102 | 11,6,1,0,0 103 | 12,6,1,0,0 104 | 13,6,1,0,0 105 | 14,6,1,0,0 106 | 15,6,1,0,0 107 | 16,6,1,1,0 108 | 17,6,1,0,0 109 | 18,6,1,1,0 110 | 1,7,1,0,0 111 | 2,7,1,0,0 112 | 3,7,1,1,0 113 | 4,7,1,0,0 114 | 5,7,1,1,0 115 | 6,7,1,1,0 116 | 7,7,1,1,0 117 | 8,7,1,0,0 118 | 9,7,1,0,0 119 | 10,7,1,1,0 120 | 11,7,1,1,0 121 | 12,7,1,0,0 122 | 13,7,1,0,0 123 | 14,7,1,1,0 124 | 15,7,1,1,0 125 | 16,7,1,0,0 126 | 17,7,1,1,0 127 | 18,7,1,0,0 128 | 1,8,1,0,0 129 | 2,8,1,0,0 130 | 3,8,1,0,0 131 | 4,8,1,1,0 132 | 5,8,1,0,0 133 | 6,8,1,1,0 134 | 7,8,1,1,0 135 | 8,8,1,1,0 136 | 9,8,1,0,0 137 | 10,8,1,1,0 138 | 11,8,1,1,0 139 | 12,8,1,1,0 140 | 13,8,1,0,0 141 | 14,8,1,0,0 142 | 15,8,1,0,0 143 | 16,8,1,1,0 144 | 17,8,1,1,0 145 | 18,8,1,0,0 146 | 1,9,0,0,0 147 | 2,9,0,1,0 148 | 3,9,0,1,0 149 | 4,9,1,0,0 150 | 5,9,1,1,0 151 | 6,9,1,1,0 152 | 7,9,0,0,0 153 | 8,9,1,0,0 154 | 9,9,0,1,0 155 | 10,9,0,1,0 156 | 11,9,1,1,0 157 | 12,9,1,0,0 158 | 13,9,1,1,0 159 | 14,9,0,0,0 160 | 15,9,1,1,0 161 | 16,9,1,0,0 162 | 17,9,0,0,0 163 | 18,9,1,0,0 164 | 1,10,1,0,0 165 | 2,10,1,1,0 166 | 3,10,1,0,0 167 | 4,10,1,1,0 168 | 5,10,1,0,0 169 | 6,10,0,0,0 170 | 7,10,0,0,0 171 | 8,10,1,1,0 172 | 9,10,1,0,0 173 | 10,10,0,1,0 174 | 11,10,1,1,0 175 | 12,10,1,1,0 176 | 13,10,1,1,0 177 | 14,10,1,0,0 178 | 15,10,1,0,0 179 | 16,10,1,0,0 180 | 17,10,1,1,0 181 | 18,10,1,1,0 182 | 1,11,1,1,0 183 | 2,11,1,1,0 184 | 3,11,1,1,0 185 | 4,11,1,0,0 186 | 5,11,1,0,0 187 | 6,11,1,1,0 188 | 7,11,1,0,0 189 | 8,11,1,0,0 190 | 9,11,1,1,0 191 | 10,11,1,0,0 192 | 11,11,1,1,0 193 | 12,11,1,1,0 194 | 13,11,0,0,0 195 | 14,11,0,0,0 196 | 15,11,1,0,0 197 | 16,11,1,1,0 198 | 17,11,1,0,0 199 | 18,11,1,1,0 200 | 1,12,0,1,0 201 | 2,12,1,0,0 202 | 3,12,1,0,0 203 | 4,12,1,0,0 204 | 5,12,1,1,0 205 | 6,12,1,0,0 206 | 7,12,1,1,0 207 | 8,12,1,1,0 208 | 9,12,1,1,0 209 | 10,12,1,0,0 210 | 11,12,1,0,0 211 | 12,12,1,1,0 212 | 13,12,1,0,0 213 | 14,12,1,1,0 214 | 15,12,1,1,0 215 | 16,12,1,0,0 216 | 17,12,1,0,0 217 | 18,12,1,1,0 218 | 1,13,1,1,0 219 | 2,13,0,0,0 220 | 3,13,1,1,0 221 | 4,13,1,1,0 222 | 5,13,1,0,0 223 | 6,13,1,0,0 224 | 7,13,1,1,0 225 | 8,13,0,0,0 226 | 9,13,1,1,0 227 | 10,13,1,0,0 228 | 11,13,1,1,0 229 | 12,13,1,1,0 230 | 13,13,0,0,0 231 | 14,13,1,0,0 232 | 15,13,1,0,0 233 | 16,13,1,1,0 234 | 17,13,1,0,0 235 | 18,13,1,1,0 236 | 1,14,1,1,0 237 | 2,14,1,1,0 238 | 3,14,1,0,0 239 | 4,14,1,0,0 240 | 5,14,1,0,0 241 | 6,14,1,1,0 242 | 7,14,1,0,0 243 | 8,14,1,0,0 244 | 9,14,1,0,0 245 | 10,14,0,1,0 246 | 11,14,1,1,0 247 | 12,14,0,1,0 248 | 13,14,0,1,0 249 | 14,14,1,0,0 250 | 15,14,1,1,0 251 | 16,14,1,1,0 252 | 17,14,1,0,0 253 | 18,14,1,0,0 254 | 1,15,1,1,0 255 | 2,15,1,1,0 256 | 3,15,1,1,0 257 | 4,15,1,0,0 258 | 5,15,1,1,0 259 | 6,15,1,1,0 260 | 7,15,1,0,0 261 | 8,15,1,1,0 262 | 9,15,1,0,0 263 | 10,15,1,1,0 264 | 11,15,1,0,0 265 | 12,15,1,0,0 266 | 13,15,1,0,0 267 | 14,15,1,1,0 268 | 15,15,1,0,0 269 | 16,15,1,0,0 270 | 17,15,1,0,0 271 | 18,15,1,1,0 272 | 1,16,1,0,0 273 | 2,16,1,1,0 274 | 3,16,1,0,0 275 | 4,16,1,1,0 276 | 5,16,1,1,0 277 | 6,16,1,1,0 278 | 7,16,1,1,0 279 | 8,16,0,0,0 280 | 9,16,1,0,0 281 | 10,16,1,0,0 282 | 11,16,1,1,0 283 | 12,16,1,1,0 284 | 13,16,1,1,0 285 | 14,16,1,0,0 286 | 15,16,1,0,0 287 | 16,16,1,1,0 288 | 17,16,1,0,0 289 | 18,16,0,0,0 290 | 1,17,0,0,0 291 | 2,17,0,0,0 292 | 3,17,1,0,0 293 | 4,17,1,1,0 294 | 5,17,0,1,0 295 | 6,17,1,0,0 296 | 7,17,0,0,0 297 | 8,17,1,0,0 298 | 9,17,1,0,0 299 | 10,17,1,1,0 300 | 11,17,0,0,0 301 | 12,17,1,0,0 302 | 13,17,0,1,0 303 | 14,17,1,1,0 304 | 15,17,1,1,0 305 | 16,17,0,1,0 306 | 17,17,1,1,0 307 | 18,17,1,1,0 308 | 1,18,0,0,0 309 | 2,18,1,1,0 310 | 3,18,1,1,0 311 | 4,18,1,1,0 312 | 5,18,1,1,0 313 | 6,18,0,0,0 314 | 7,18,1,1,0 315 | 8,18,1,1,0 316 | 9,18,0,1,0 317 | 10,18,1,0,0 318 | 11,18,1,0,0 319 | 12,18,1,0,0 320 | 13,18,1,0,0 321 | 14,18,1,0,0 322 | 15,18,1,1,0 323 | 16,18,1,1,0 324 | 17,18,1,0,0 325 | 18,18,1,0,0 326 | 1,19,1,1,0 327 | 2,19,1,0,0 328 | 3,19,1,1,0 329 | 4,19,1,1,0 330 | 5,19,1,0,0 331 | 6,19,0,0,0 332 | 7,19,1,1,0 333 | 8,19,1,1,0 334 | 9,19,1,1,0 335 | 10,19,1,0,0 336 | 11,19,1,0,0 337 | 12,19,1,0,0 338 | 13,19,0,0,0 339 | 14,19,1,0,0 340 | 15,19,1,0,0 341 | 16,19,1,1,0 342 | 17,19,1,1,0 343 | 18,19,1,1,0 344 | 1,20,1,0,0 345 | 2,20,1,1,0 346 | 3,20,1,0,0 347 | 4,20,1,0,0 348 | 5,20,1,1,0 349 | 6,20,1,0,0 350 | 7,20,1,1,0 351 | 8,20,1,1,0 352 | 9,20,1,1,0 353 | 10,20,0,0,0 354 | 11,20,1,1,0 355 | 12,20,1,1,0 356 | 13,20,1,1,0 357 | 14,20,1,0,0 358 | 15,20,1,0,0 359 | 16,20,1,0,0 360 | 17,20,1,1,0 361 | 18,20,1,0,0 362 | 1,21,1,0,0 363 | 2,21,0,1,0 364 | 3,21,1,0,0 365 | 4,21,1,0,0 366 | 5,21,1,0,0 367 | 6,21,1,1,0 368 | 7,21,1,0,0 369 | 8,21,0,0,0 370 | 9,21,1,1,0 371 | 10,21,0,0,0 372 | 11,21,1,1,0 373 | 12,21,1,1,0 374 | 13,21,1,1,0 375 | 14,21,0,1,0 376 | 15,21,1,1,0 377 | 16,21,1,0,0 378 | 17,21,1,1,0 379 | 18,21,0,0,0 380 | 1,22,1,1,0 381 | 2,22,1,0,0 382 | 3,22,1,1,0 383 | 4,22,1,1,0 384 | 5,22,1,0,0 385 | 6,22,1,1,0 386 | 7,22,1,1,0 387 | 8,22,1,0,0 388 | 9,22,1,1,0 389 | 10,22,1,0,0 390 | 11,22,0,1,0 391 | 12,22,1,0,0 392 | 13,22,1,1,0 393 | 14,22,1,1,0 394 | 15,22,1,0,0 395 | 16,22,1,0,0 396 | 17,22,1,0,0 397 | 18,22,1,0,0 398 | 1,23,0,1,0 399 | 2,23,0,1,0 400 | 3,23,0,1,0 401 | 4,23,1,0,0 402 | 5,23,1,0,0 403 | 6,23,1,0,0 404 | 7,23,1,1,0 405 | 8,23,0,0,0 406 | 9,23,1,0,0 407 | 10,23,1,0,0 408 | 11,23,1,1,0 409 | 12,23,1,1,0 410 | 13,23,1,1,0 411 | 14,23,1,0,0 412 | 15,23,1,1,0 413 | 16,23,1,0,0 414 | 17,23,1,1,0 415 | 18,23,0,0,0 416 | 1,24,1,0,0 417 | 2,24,1,1,0 418 | 3,24,1,0,0 419 | 4,24,1,1,0 420 | 5,24,1,1,0 421 | 6,24,1,0,0 422 | 7,24,1,0,0 423 | 8,24,1,1,0 424 | 9,24,1,0,0 425 | 10,24,0,0,0 426 | 11,24,1,1,0 427 | 12,24,1,1,0 428 | 13,24,0,0,0 429 | 14,24,1,0,0 430 | 15,24,1,1,0 431 | 16,24,1,1,0 432 | 17,24,1,1,0 433 | 18,24,1,0,0 434 | 1,25,0,0,0 435 | 2,25,1,0,0 436 | 3,25,1,1,0 437 | 4,25,1,1,0 438 | 5,25,1,0,0 439 | 6,25,1,1,0 440 | 7,25,1,0,0 441 | 8,25,1,0,0 442 | 9,25,1,1,0 443 | 10,25,1,0,0 444 | 11,25,1,1,0 445 | 12,25,1,1,0 446 | 13,25,1,0,0 447 | 14,25,1,0,0 448 | 15,25,1,1,0 449 | 16,25,1,1,0 450 | 17,25,1,1,0 451 | 18,25,1,0,0 452 | 1,26,0,1,0 453 | 2,26,0,0,0 454 | 3,26,1,0,0 455 | 4,26,0,0,0 456 | 5,26,1,0,0 457 | 6,26,1,0,0 458 | 7,26,1,1,0 459 | 8,26,1,1,0 460 | 9,26,0,1,0 461 | 10,26,1,0,0 462 | 11,26,1,1,0 463 | 12,26,1,1,0 464 | 13,26,1,0,0 465 | 14,26,1,1,0 466 | 15,26,1,1,0 467 | 16,26,1,1,0 468 | 17,26,1,0,0 469 | 18,26,0,0,0 470 | 1,27,0,0,0 471 | 2,27,0,1,0 472 | 3,27,1,0,0 473 | 4,27,1,1,0 474 | 5,27,1,0,0 475 | 6,27,0,0,0 476 | 7,27,1,1,0 477 | 8,27,1,1,0 478 | 9,27,1,1,0 479 | 10,27,1,0,0 480 | 11,27,1,0,0 481 | 12,27,1,1,0 482 | 13,27,1,0,0 483 | 14,27,0,1,0 484 | 15,27,0,1,0 485 | 16,27,1,1,0 486 | 17,27,1,0,0 487 | 18,27,1,0,0 488 | 1,28,1,0,0 489 | 2,28,1,1,0 490 | 3,28,1,0,0 491 | 4,28,1,0,0 492 | 5,28,1,1,0 493 | 6,28,1,0,0 494 | 7,28,1,1,0 495 | 8,28,1,1,0 496 | 9,28,1,1,0 497 | 10,28,1,1,0 498 | 11,28,1,0,0 499 | 12,28,1,0,0 500 | 13,28,1,0,0 501 | 14,28,1,1,0 502 | 15,28,1,0,0 503 | 16,28,1,1,0 504 | 17,28,1,1,0 505 | 18,28,1,0,0 506 | 1,29,1,0,0 507 | 2,29,1,0,0 508 | 3,29,1,1,0 509 | 4,29,1,1,0 510 | 5,29,1,0,0 511 | 6,29,1,1,0 512 | 7,29,1,1,0 513 | 8,29,1,1,0 514 | 9,29,1,1,0 515 | 10,29,1,0,0 516 | 11,29,1,1,0 517 | 12,29,1,0,0 518 | 13,29,1,1,0 519 | 14,29,1,0,0 520 | 15,29,1,0,0 521 | 16,29,1,0,0 522 | 17,29,1,0,0 523 | 18,29,1,1,0 524 | 1,30,1,0,0 525 | 2,30,1,0,0 526 | 3,30,1,1,0 527 | 4,30,1,1,0 528 | 5,30,1,1,0 529 | 6,30,1,0,0 530 | 7,30,0,0,0 531 | 8,30,1,0,0 532 | 9,30,1,1,0 533 | 10,30,1,0,0 534 | 11,30,1,0,0 535 | 12,30,1,0,0 536 | 13,30,1,1,0 537 | 14,30,1,1,0 538 | 15,30,1,1,0 539 | 16,30,1,1,0 540 | 17,30,1,0,0 541 | 18,30,1,1,0 542 | 1,31,1,1,0 543 | 2,31,1,0,0 544 | 3,31,1,0,0 545 | 4,31,1,1,0 546 | 5,31,1,0,0 547 | 6,31,1,1,0 548 | 7,31,1,0,0 549 | 8,31,1,0,0 550 | 9,31,1,1,0 551 | 10,31,1,0,0 552 | 11,31,1,1,0 553 | 12,31,1,0,0 554 | 13,31,1,1,0 555 | 14,31,1,0,0 556 | 15,31,1,1,0 557 | 16,31,1,1,0 558 | 17,31,1,0,0 559 | 18,31,0,1,0 560 | 1,32,1,1,0 561 | 2,32,1,1,0 562 | 3,32,1,0,0 563 | 4,32,1,0,0 564 | 5,32,1,0,0 565 | 6,32,0,0,0 566 | 7,32,1,1,0 567 | 8,32,1,1,0 568 | 9,32,1,0,0 569 | 10,32,1,0,0 570 | 11,32,1,0,0 571 | 12,32,1,1,0 572 | 13,32,1,1,0 573 | 14,32,1,1,0 574 | 15,32,0,0,0 575 | 16,32,1,1,0 576 | 17,32,1,0,0 577 | 18,32,1,1,0 578 | 1,33,1,0,0 579 | 2,33,1,0,0 580 | 3,33,1,1,0 581 | 4,33,1,0,0 582 | 5,33,1,1,0 583 | 6,33,1,1,0 584 | 7,33,1,0,0 585 | 8,33,1,1,0 586 | 9,33,1,1,0 587 | 10,33,1,0,0 588 | 11,33,1,0,0 589 | 12,33,1,1,0 590 | 13,33,1,0,0 591 | 14,33,1,1,0 592 | 15,33,1,1,0 593 | 16,33,1,0,0 594 | 17,33,1,1,0 595 | 18,33,1,0,0 596 | 1,34,1,0,0 597 | 2,34,1,0,0 598 | 3,34,1,0,0 599 | 4,34,1,0,0 600 | 5,34,1,1,0 601 | 6,34,1,1,0 602 | 7,34,1,1,0 603 | 8,34,1,0,0 604 | 9,34,1,0,0 605 | 10,34,1,1,0 606 | 11,34,1,1,0 607 | 12,34,1,1,0 608 | 13,34,1,0,0 609 | 14,34,1,1,0 610 | 15,34,1,0,0 611 | 16,34,1,1,0 612 | 17,34,1,0,0 613 | 18,34,1,1,0 614 | 1,35,1,1,0 615 | 2,35,1,1,0 616 | 3,35,1,1,0 617 | 4,35,1,1,0 618 | 5,35,1,1,0 619 | 6,35,1,1,0 620 | 7,35,0,0,0 621 | 8,35,1,0,0 622 | 9,35,0,0,0 623 | 10,35,0,0,0 624 | 11,35,1,0,0 625 | 12,35,0,0,0 626 | 13,35,1,1,0 627 | 14,35,1,0,0 628 | 15,35,1,0,0 629 | 16,35,1,1,0 630 | 17,35,1,1,0 631 | 18,35,1,0,0 632 | 1,36,1,1,0 633 | 2,36,1,1,0 634 | 3,36,1,0,0 635 | 4,36,1,1,0 636 | 5,36,1,0,0 637 | 6,36,1,0,0 638 | 7,36,1,1,0 639 | 8,36,1,0,0 640 | 9,36,1,0,0 641 | 10,36,0,1,0 642 | 11,36,1,1,0 643 | 12,36,1,0,0 644 | 13,36,0,0,0 645 | 14,36,1,1,0 646 | 15,36,1,0,0 647 | 16,36,1,1,0 648 | 17,36,1,0,0 649 | 18,36,0,1,0 650 | 1,37,1,1,0 651 | 2,37,0,0,0 652 | 3,37,1,0,0 653 | 4,37,1,1,0 654 | 5,37,1,1,0 655 | 6,37,1,1,0 656 | 7,37,1,1,0 657 | 8,37,1,0,0 658 | 9,37,1,0,0 659 | 10,37,0,1,0 660 | 11,37,1,1,0 661 | 12,37,1,0,0 662 | 13,37,1,0,0 663 | 14,37,1,0,0 664 | 15,37,1,1,0 665 | 16,37,1,0,0 666 | 17,37,1,0,0 667 | 18,37,1,1,0 668 | 1,38,1,1,0 669 | 2,38,1,1,0 670 | 3,38,1,1,0 671 | 4,38,1,0,0 672 | 5,38,1,1,0 673 | 6,38,1,0,0 674 | 7,38,1,0,0 675 | 8,38,1,1,0 676 | 9,38,1,0,0 677 | 10,38,1,1,0 678 | 11,38,1,0,0 679 | 12,38,1,0,0 680 | 13,38,1,0,0 681 | 14,38,1,1,0 682 | 15,38,1,0,0 683 | 16,38,1,0,0 684 | 17,38,1,1,0 685 | 18,38,1,1,0 686 | 1,39,1,0,0 687 | 2,39,1,1,0 688 | 3,39,0,1,0 689 | 4,39,1,0,0 690 | 5,39,1,1,0 691 | 6,39,1,0,0 692 | 7,39,0,1,0 693 | 8,39,0,0,0 694 | 9,39,1,1,0 695 | 10,39,1,1,0 696 | 11,39,1,1,0 697 | 12,39,1,0,0 698 | 13,39,0,0,0 699 | 14,39,1,0,0 700 | 15,39,1,0,0 701 | 16,39,1,1,0 702 | 17,39,1,1,0 703 | 18,39,0,0,0 704 | 1,40,1,0,0 705 | 2,40,1,0,0 706 | 3,40,1,0,0 707 | 4,40,1,1,0 708 | 5,40,1,1,0 709 | 6,40,1,0,0 710 | 7,40,0,0,0 711 | 8,40,1,1,0 712 | 9,40,1,0,0 713 | 10,40,1,1,0 714 | 11,40,1,0,0 715 | 12,40,1,0,0 716 | 13,40,0,1,0 717 | 14,40,1,1,0 718 | 15,40,1,1,0 719 | 16,40,1,1,0 720 | 17,40,1,1,0 721 | 18,40,1,0,0 722 | 1,41,1,0,0 723 | 2,41,1,0,0 724 | 3,41,1,1,0 725 | 4,41,1,0,0 726 | 5,41,1,0,0 727 | 6,41,1,0,0 728 | 7,41,1,1,0 729 | 8,41,1,1,0 730 | 9,41,1,0,0 731 | 10,41,1,1,0 732 | 11,41,1,0,0 733 | 12,41,1,0,0 734 | 13,41,1,1,0 735 | 14,41,1,1,0 736 | 15,41,1,1,0 737 | 16,41,1,0,0 738 | 17,41,1,1,0 739 | 18,41,1,1,0 740 | 1,42,1,1,0 741 | 2,42,1,0,0 742 | 3,42,1,0,0 743 | 4,42,1,0,0 744 | 5,42,1,1,0 745 | 6,42,1,0,0 746 | 7,42,1,0,0 747 | 8,42,1,1,0 748 | 9,42,1,1,0 749 | 10,42,1,0,0 750 | 11,42,1,1,0 751 | 12,42,1,1,0 752 | 13,42,1,0,0 753 | 14,42,1,0,0 754 | 15,42,1,1,0 755 | 16,42,1,0,0 756 | 17,42,1,1,0 757 | 18,42,1,1,0 758 | 1,43,0,1,0 759 | 2,43,1,1,0 760 | 3,43,0,1,0 761 | 4,43,1,0,0 762 | 5,43,1,0,0 763 | 6,43,1,0,0 764 | 7,43,1,0,0 765 | 8,43,1,1,0 766 | 9,43,0,0,0 767 | 10,43,1,1,0 768 | 11,43,1,0,0 769 | 12,43,1,0,0 770 | 13,43,1,1,0 771 | 14,43,1,1,0 772 | 15,43,1,1,0 773 | 16,43,1,1,0 774 | 17,43,1,0,0 775 | 18,43,1,0,0 776 | 1,44,1,1,0 777 | 2,44,1,0,0 778 | 3,44,0,0,0 779 | 4,44,1,0,0 780 | 5,44,0,0,0 781 | 6,44,0,0,0 782 | 7,44,0,1,0 783 | 8,44,0,1,0 784 | 9,44,0,0,0 785 | 10,44,0,0,0 786 | 11,44,1,1,0 787 | 12,44,1,1,0 788 | 13,44,0,1,0 789 | 14,44,1,0,0 790 | 15,44,1,1,0 791 | 16,44,1,1,0 792 | 17,44,1,1,0 793 | 18,44,1,0,0 794 | 1,45,1,1,0 795 | 2,45,1,0,0 796 | 3,45,1,0,0 797 | 4,45,1,1,0 798 | 5,45,1,1,0 799 | 6,45,1,1,0 800 | 7,45,1,1,0 801 | 8,45,1,0,0 802 | 9,45,1,0,0 803 | 10,45,0,1,0 804 | 11,45,1,0,0 805 | 12,45,1,0,0 806 | 13,45,1,1,0 807 | 14,45,1,1,0 808 | 15,45,1,1,0 809 | 16,45,1,0,0 810 | 17,45,1,0,0 811 | 18,45,1,0,0 812 | 1,46,1,1,0 813 | 2,46,0,1,0 814 | 3,46,1,0,0 815 | 4,46,1,0,0 816 | 5,46,1,1,0 817 | 6,46,1,1,0 818 | 7,46,1,0,0 819 | 8,46,0,0,0 820 | 9,46,1,0,0 821 | 10,46,1,1,0 822 | 11,46,1,0,0 823 | 12,46,1,1,0 824 | 13,46,1,1,0 825 | 14,46,1,0,0 826 | 15,46,1,0,0 827 | 16,46,1,0,0 828 | 17,46,1,1,0 829 | 18,46,1,1,0 830 | 1,47,1,1,0 831 | 2,47,1,0,0 832 | 3,47,1,0,0 833 | 4,47,1,1,0 834 | 5,47,1,1,0 835 | 6,47,1,0,0 836 | 7,47,0,1,0 837 | 8,47,1,0,0 838 | 9,47,1,1,0 839 | 10,47,0,0,0 840 | 11,47,1,0,0 841 | 12,47,1,1,0 842 | 13,47,1,0,0 843 | 14,47,1,1,0 844 | 15,47,0,1,0 845 | 16,47,1,0,0 846 | 17,47,1,1,0 847 | 18,47,1,0,0 848 | 1,48,0,1,0 849 | 2,48,0,0,0 850 | 3,48,0,1,0 851 | 4,48,1,0,0 852 | 5,48,1,0,0 853 | 6,48,0,1,0 854 | 7,48,1,0,0 855 | 8,48,1,1,0 856 | 9,48,1,0,0 857 | 10,48,1,0,0 858 | 11,48,1,1,0 859 | 12,48,1,1,0 860 | 13,48,0,0,0 861 | 14,48,1,1,0 862 | 15,48,0,0,0 863 | 16,48,1,1,0 864 | 17,48,1,1,0 865 | 18,48,0,0,0 866 | 1,49,0,0,0 867 | 2,49,0,0,0 868 | 3,49,1,0,0 869 | 4,49,1,1,0 870 | 5,49,1,1,0 871 | 6,49,1,0,0 872 | 7,49,1,0,0 873 | 8,49,1,1,0 874 | 9,49,1,0,0 875 | 10,49,1,1,0 876 | 11,49,0,0,0 877 | 12,49,0,1,0 878 | 13,49,1,1,0 879 | 14,49,1,0,0 880 | 15,49,0,1,0 881 | 16,49,1,1,0 882 | 17,49,1,0,0 883 | 18,49,1,1,0 884 | 1,50,1,0,0 885 | 2,50,1,1,0 886 | 3,50,1,0,0 887 | 4,50,1,0,0 888 | 5,50,1,1,0 889 | 6,50,0,1,0 890 | 7,50,1,0,0 891 | 8,50,1,0,0 892 | 9,50,1,0,0 893 | 10,50,0,0,0 894 | 11,50,1,1,0 895 | 12,50,1,0,0 896 | 13,50,1,1,0 897 | 14,50,0,1,0 898 | 15,50,1,1,0 899 | 16,50,1,1,0 900 | 17,50,1,1,0 901 | 18,50,0,0,0 902 | 1,51,1,1,0 903 | 2,51,1,0,0 904 | 3,51,1,1,0 905 | 4,51,1,0,0 906 | 5,51,1,1,0 907 | 6,51,1,1,0 908 | 7,51,1,1,0 909 | 8,51,1,1,0 910 | 9,51,0,0,0 911 | 10,51,0,0,0 912 | 11,51,1,1,0 913 | 12,51,1,0,0 914 | 13,51,1,0,0 915 | 14,51,1,1,0 916 | 15,51,1,0,0 917 | 16,51,1,0,0 918 | 17,51,1,0,0 919 | 18,51,1,1,0 920 | 1,52,1,1,0 921 | 2,52,0,0,0 922 | 3,52,1,0,0 923 | 4,52,1,0,0 924 | 5,52,1,0,0 925 | 6,52,0,0,0 926 | 7,52,1,1,0 927 | 8,52,1,1,0 928 | 9,52,0,1,0 929 | 10,52,1,0,0 930 | 11,52,1,1,0 931 | 12,52,1,0,0 932 | 13,52,0,1,0 933 | 14,52,1,0,0 934 | 15,52,1,1,0 935 | 16,52,0,1,0 936 | 17,52,1,1,0 937 | 18,52,1,0,0 938 | 1,53,1,1,0 939 | 2,53,1,1,0 940 | 3,53,1,0,0 941 | 4,53,1,1,0 942 | 5,53,1,0,0 943 | 6,53,1,1,0 944 | 7,53,1,1,0 945 | 8,53,1,1,0 946 | 9,53,1,0,0 947 | 10,53,0,0,0 948 | 11,53,1,1,0 949 | 12,53,1,0,0 950 | 13,53,1,0,0 951 | 14,53,1,0,0 952 | 15,53,1,0,0 953 | 16,53,1,1,0 954 | 17,53,1,1,0 955 | 18,53,1,0,0 956 | 1,54,1,1,0 957 | 2,54,1,0,0 958 | 3,54,1,1,0 959 | 4,54,1,0,0 960 | 5,54,1,0,0 961 | 6,54,1,1,0 962 | 7,54,1,0,0 963 | 8,54,1,1,0 964 | 9,54,1,1,0 965 | 10,54,1,0,0 966 | 11,54,1,1,0 967 | 12,54,1,1,0 968 | 13,54,1,0,0 969 | 14,54,1,0,0 970 | 15,54,1,0,0 971 | 16,54,1,0,0 972 | 17,54,1,1,0 973 | 18,54,1,1,0 974 | 1,55,0,0,1 975 | 2,55,1,0,1 976 | 3,55,0,0,1 977 | 4,55,1,0,1 978 | 5,55,0,0,1 979 | 6,55,1,0,1 980 | 7,55,0,0,1 981 | 8,55,0,0,1 982 | 9,55,1,0,1 983 | 10,55,0,0,1 984 | 11,55,1,0,1 985 | 12,55,1,0,1 986 | 13,55,0,0,1 987 | 14,55,1,0,1 988 | 15,55,0,0,1 989 | 16,55,1,0,1 990 | 17,55,1,0,1 991 | 18,55,0,0,1 992 | 1,55,1,0,1 993 | 2,55,1,0,1 994 | 3,55,0,0,1 995 | 4,55,1,0,1 996 | 5,55,1,0,1 997 | 6,55,1,0,1 998 | 7,55,1,0,1 999 | 8,55,0,0,1 1000 | 9,55,1,0,1 1001 | 10,55,1,0,1 1002 | 11,55,1,0,1 1003 | 12,55,1,0,1 1004 | 13,55,0,0,1 1005 | 14,55,1,0,1 1006 | 15,55,1,0,1 1007 | 16,55,1,0,1 1008 | 17,55,1,0,1 1009 | 18,55,1,0,1 1010 | 1,55,0,1,1 1011 | 2,55,1,1,1 1012 | 3,55,1,1,1 1013 | 4,55,1,1,1 1014 | 5,55,0,1,1 1015 | 6,55,1,1,1 1016 | 7,55,1,1,1 1017 | 8,55,1,1,1 1018 | 9,55,1,1,1 1019 | 10,55,1,1,1 1020 | 11,55,1,1,1 1021 | 12,55,1,1,1 1022 | 13,55,1,1,1 1023 | 14,55,1,1,1 1024 | 15,55,0,1,1 1025 | 16,55,1,1,1 1026 | 17,55,1,1,1 1027 | 18,55,0,1,1 1028 | 1,55,1,1,1 1029 | 2,55,1,1,1 1030 | 3,55,0,1,1 1031 | 4,55,1,1,1 1032 | 5,55,0,1,1 1033 | 6,55,1,1,1 1034 | 7,55,1,1,1 1035 | 8,55,1,1,1 1036 | 9,55,1,1,1 1037 | 10,55,1,1,1 1038 | 11,55,1,1,1 1039 | 12,55,1,1,1 1040 | 13,55,0,1,1 1041 | 14,55,1,1,1 1042 | 15,55,0,1,1 1043 | 16,55,1,1,1 1044 | 17,55,0,1,1 1045 | 18,55,1,1,1 -------------------------------------------------------------------------------- /digit_mat/gen_4_5_rule_problems.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from itertools import permutations, combinations_with_replacement 3 | import builtins 4 | import random 5 | from copy import deepcopy 6 | import os 7 | import json 8 | 9 | # Method for generating distractors 10 | def gen_distractor(all_prob): 11 | # Methods and transformations 12 | distractor_methods = ['other_element', 13 | 'other_element_transformed', 14 | 'correct_answer_transformed', 15 | 'previous_distractor_transformed', 16 | 'random_new_element'] 17 | transformations = np.array([1,2,-1,-2]) 18 | # Additional methods for multi-rule problems 19 | if all_prob.shape[-1] > 1: 20 | distractor_methods.append('correct_answer_permuted') 21 | distractor_methods.append('other_element_permuted') 22 | distractor_methods.append('previous_distractor_permuted') 23 | distractor_methods.append('combination_other_elements') 24 | distractor_methods.append('combination_previous_distractors') 25 | # Loop through all problems 26 | all_probtype_answer_choices = [] 27 | all_probtype_correct_ind = [] 28 | for t in range(all_prob.shape[0]): 29 | all_answer_choices = [] 30 | all_correct_ind = [] 31 | for p in range(all_prob.shape[1]): 32 | # Problem 33 | prob = all_prob[t][p] 34 | # Extract correct answer 35 | correct_answer = prob[2,2,:] 36 | # Other elements (besides correct answer) in problem 37 | prob_flat = prob.reshape(prob.shape[0]*prob.shape[1],prob.shape[2]) 38 | other_prob_elements = prob_flat[:-1,:] 39 | other_prob_elements = other_prob_elements[np.logical_not(np.all(other_prob_elements == np.expand_dims(correct_answer,0),1))] 40 | # Generate each distractor 41 | answer_choices = [] 42 | for d in range(7): 43 | valid_distractor = False 44 | while not valid_distractor: 45 | # Sample method 46 | method = distractor_methods[int(np.floor(np.random.rand() * len(distractor_methods)))] 47 | # Other element from problem 48 | if method == 'other_element': 49 | np.random.shuffle(other_prob_elements) 50 | distractor = deepcopy(other_prob_elements[0]) 51 | # Other element transformed 52 | elif method == 'other_element_transformed': 53 | np.random.shuffle(other_prob_elements) 54 | distractor = deepcopy(other_prob_elements[0]) 55 | transform_dim = int(np.floor(np.random.rand() * all_prob.shape[-1])) 56 | np.random.shuffle(transformations) 57 | distractor[transform_dim] += transformations[0] 58 | # Correct answer transformed 59 | elif method == 'correct_answer_transformed': 60 | transform_dim = int(np.floor(np.random.rand() * all_prob.shape[-1])) 61 | np.random.shuffle(transformations) 62 | distractor = deepcopy(correct_answer) 63 | distractor[transform_dim] += transformations[0] 64 | # Previous distractor transformed 65 | elif method == 'previous_distractor_transformed': 66 | random.shuffle(answer_choices) 67 | if len(answer_choices) > 0: 68 | distractor = deepcopy(answer_choices[0]) 69 | transform_dim = int(np.floor(np.random.rand() * all_prob.shape[-1])) 70 | np.random.shuffle(transformations) 71 | distractor[transform_dim] += transformations[0] 72 | else: 73 | distractor = np.ones(all_prob.shape[-1]).astype(int) * -1 74 | # Random new element 75 | elif method == 'random_new_element': 76 | distractor = np.floor(np.random.rand(all_prob.shape[-1]) * 10).astype(int) 77 | # Correct answer permuted 78 | elif method == 'correct_answer_permuted': 79 | valid_perm = builtins.list(permutations(np.arange(all_prob.shape[-1]),all_prob.shape[-1]))[1:] 80 | random.shuffle(valid_perm) 81 | distractor = deepcopy(correct_answer)[builtins.list(valid_perm[0])] 82 | # Other element permuted 83 | elif method == 'other_element_permuted': 84 | valid_perm = builtins.list(permutations(np.arange(all_prob.shape[-1]),all_prob.shape[-1]))[1:] 85 | random.shuffle(valid_perm) 86 | np.random.shuffle(other_prob_elements) 87 | other_element = other_prob_elements[0] 88 | distractor = deepcopy(other_element)[builtins.list(valid_perm[0])] 89 | # Previous distractor permuted 90 | elif method == 'previous_distractor_permuted': 91 | random.shuffle(answer_choices) 92 | if len(answer_choices) > 0: 93 | valid_perm = builtins.list(permutations(np.arange(all_prob.shape[-1]),all_prob.shape[-1]))[1:] 94 | random.shuffle(valid_perm) 95 | previous_distractor = answer_choices[0] 96 | distractor = deepcopy(previous_distractor)[builtins.list(valid_perm[0])] 97 | else: 98 | distractor = np.ones(all_prob.shape[-1]).astype(int) * -1 99 | # Combination of other elements 100 | elif method == 'combination_other_elements': 101 | distractor = [] 102 | for dim in range(all_prob.shape[-1]): 103 | other_prob_elements_copy = deepcopy(other_prob_elements) 104 | np.random.shuffle(other_prob_elements_copy) 105 | distractor.append(other_prob_elements_copy[0,dim]) 106 | distractor = np.array(distractor) 107 | # Combination of previous distractors 108 | elif method == 'combination_previous_distractors': 109 | random.shuffle(answer_choices) 110 | if len(answer_choices) > 1: 111 | distractor = [] 112 | for dim in range(all_prob.shape[-1]): 113 | answer_choices_copy = deepcopy(answer_choices) 114 | np.random.shuffle(answer_choices_copy) 115 | distractor.append(answer_choices_copy[0][dim]) 116 | distractor = np.array(distractor) 117 | else: 118 | distractor = np.ones(all_prob.shape[-1]).astype(int) * -1 119 | # Check to make sure distractor isn't invalid or doesn't already exist 120 | if np.all(distractor >= 0) and np.all(distractor <= 9) and not np.all(distractor == correct_answer): 121 | if len(answer_choices) == 0: 122 | answer_choices.append(distractor) 123 | valid_distractor = True 124 | else: 125 | if np.logical_not(np.any(np.all(np.array(answer_choices) == np.expand_dims(distractor,0),1))): 126 | answer_choices.append(distractor) 127 | valid_distractor = True 128 | # Add correct answer and shuffle 129 | answer_choices.append(correct_answer) 130 | answer_choices = np.array(answer_choices) 131 | shuffled_order = np.arange(8) 132 | np.random.shuffle(shuffled_order) 133 | correct_ind = np.where(shuffled_order == 7)[0][0] 134 | answer_choices = answer_choices[shuffled_order] 135 | # Add to list 136 | all_answer_choices.append(answer_choices) 137 | all_correct_ind.append(correct_ind) 138 | # Combine across problem types 139 | all_probtype_answer_choices.append(np.array(all_answer_choices)) 140 | all_probtype_correct_ind.append(np.array(all_correct_ind)) 141 | # Convert to arrays 142 | all_probtype_answer_choices = np.array(all_probtype_answer_choices) 143 | all_probtype_correct_ind = np.array(all_probtype_correct_ind) 144 | return all_probtype_answer_choices, all_probtype_correct_ind 145 | 146 | # Save problems as json and numpy files 147 | def save_prob(all_prob, all_answer_choices, all_correct_ind, prob_type_name, all_problems_np, all_problems_js, perm_invariant=False): 148 | # Add problems to numpy dict 149 | all_problems_np[prob_type_name] = {'prob': all_prob, 'answer_choices': all_answer_choices, 'correct_ind': all_correct_ind, 'perm_invariant': perm_invariant} 150 | # Convert to strings and save as json 151 | all_data = [] 152 | for p in range(all_prob.shape[0]): 153 | # Convert problem to string 154 | prompt = '' 155 | for r in range(3): 156 | for c in range(3): 157 | prompt += '[' 158 | if r == 2 and c == 2: 159 | for i in range(len(all_prob[p][1][c])): 160 | prompt += '  ' 161 | if i < (len(all_prob[p][1][c]) - 1): 162 | prompt += ' ' 163 | else: 164 | for i in range(len(all_prob[p][r][c])): 165 | if all_prob[p][r][c][i] == -1: 166 | prompt += '  ' 167 | else: 168 | prompt += str(all_prob[p][r][c][i]) 169 | 170 | if i < (len(all_prob[p][r][c]) - 1): 171 | prompt += ' ' 172 | prompt += ']' 173 | if c < 2: 174 | prompt += '   ' 175 | if r < 2 and c == 2: 176 | prompt += '
' 177 | # Convert choices to strings 178 | options = [] 179 | for a in range(8): 180 | option = '[' 181 | for i in range(len(all_answer_choices[p][a])): 182 | option += str(all_answer_choices[p][a][i]) 183 | if i < (len(all_answer_choices[p][a]) - 1): 184 | option += ' ' 185 | if len(all_answer_choices[p][a]) == 0: 186 | option += '  ' 187 | option += ']' 188 | options.append(option) 189 | # Add to dataset 190 | all_data.append({'prompt': prompt, 'options': options, 'correct': int(all_correct_ind[p]), 'prob_ind': p}) 191 | # Add to javascript data 192 | all_problems_js[prob_type_name] = all_data 193 | return all_problems_np, all_problems_js 194 | 195 | #### 4-rule and 5-rule problems: 196 | # - 1 for each of 15 4-rule combinations 197 | # - 1 for each of 21 5-rule combinations 198 | 199 | # Number of problems-per-category will either be N (below) or maximum number possible 200 | N_probs = 100 201 | 202 | # All 10choose3 permutations 203 | all_10c3_perm = np.array(builtins.list(permutations(np.arange(10),3))) 204 | 205 | # Constant 206 | all_constant = [] 207 | all_row_constant = [] 208 | all_col_constant = [] 209 | for p in range(all_10c3_perm.shape[0]): 210 | row_prob = np.array([[all_10c3_perm[p][0], all_10c3_perm[p][0], all_10c3_perm[p][0]], 211 | [all_10c3_perm[p][1], all_10c3_perm[p][1], all_10c3_perm[p][1]], 212 | [all_10c3_perm[p][2], all_10c3_perm[p][2], all_10c3_perm[p][2]]]) 213 | all_row_constant.append(row_prob) 214 | all_constant.append(row_prob) 215 | col_prob = np.array([[all_10c3_perm[p][0], all_10c3_perm[p][1], all_10c3_perm[p][2]], 216 | [all_10c3_perm[p][0], all_10c3_perm[p][1], all_10c3_perm[p][2]], 217 | [all_10c3_perm[p][0], all_10c3_perm[p][1], all_10c3_perm[p][2]]]) 218 | all_constant.append(col_prob) 219 | all_col_constant.append(col_prob) 220 | all_constant = np.array(all_constant) 221 | 222 | # Distribution-of-3 223 | all_dist3 = [] 224 | all_dist3_diag1 = [] 225 | all_dist3_diag2 = [] 226 | for p in range(all_10c3_perm.shape[0]): 227 | diag1_prob = np.array([[all_10c3_perm[p][0], all_10c3_perm[p][1], all_10c3_perm[p][2]], 228 | [all_10c3_perm[p][1], all_10c3_perm[p][2], all_10c3_perm[p][0]], 229 | [all_10c3_perm[p][2], all_10c3_perm[p][0], all_10c3_perm[p][1]]]) 230 | all_dist3_diag1.append(diag1_prob) 231 | all_dist3.append(diag1_prob) 232 | diag2_prob = np.array([[all_10c3_perm[p][0], all_10c3_perm[p][1], all_10c3_perm[p][2]], 233 | [all_10c3_perm[p][2], all_10c3_perm[p][0], all_10c3_perm[p][1]], 234 | [all_10c3_perm[p][1], all_10c3_perm[p][2], all_10c3_perm[p][0]]]) 235 | all_dist3_diag2.append(diag2_prob) 236 | all_dist3.append(diag2_prob) 237 | all_dist3 = np.array(all_dist3) 238 | all_dist3_diag1 = np.array(all_dist3_diag1) 239 | all_dist3_diag2 = np.array(all_dist3_diag2) 240 | # Select subset of distribution-of-3 problems 241 | np.random.shuffle(all_dist3_diag1) 242 | 243 | # Progression 244 | prog_size1 = np.array([np.arange(0,5), np.arange(1,6), np.arange(2,7), np.arange(3,8), np.arange(4,9), np.arange(5,10)]) 245 | prog_size1_reversed = np.fliplr(prog_size1) 246 | prog_size2 = np.array([np.arange(0,10,2), np.arange(1,11,2)]) 247 | prog_size2_reversed = np.fliplr(prog_size2) 248 | all_prog_range = np.concatenate([prog_size1, prog_size1_reversed, prog_size2, prog_size2_reversed], 0) 249 | size1_size2 = np.concatenate([np.zeros(prog_size1.shape[0] + prog_size1_reversed.shape[0]), np.ones(prog_size2.shape[0] + prog_size2_reversed.shape[0])]) 250 | all_prog = [] 251 | all_prog_size1 = [] 252 | all_prog_size2 = [] 253 | for p in range(all_prog_range.shape[0]): 254 | prog_prob = np.array([[all_prog_range[p][0], all_prog_range[p][1], all_prog_range[p][2]], 255 | [all_prog_range[p][1], all_prog_range[p][2], all_prog_range[p][3]], 256 | [all_prog_range[p][2], all_prog_range[p][3], all_prog_range[p][4]]]) 257 | all_prog.append(prog_prob) 258 | if size1_size2[p] == 0: 259 | all_prog_size1.append(prog_prob) 260 | elif size1_size2[p] == 1: 261 | all_prog_size2.append(prog_prob) 262 | all_prog = np.array(all_prog) 263 | 264 | # All 4-rule and 5-rule sets (combinations with replacement) 265 | all_4rule_comb = builtins.list(combinations_with_replacement(np.arange(3), 4)) 266 | all_5rule_comb = builtins.list(combinations_with_replacement(np.arange(3), 5)) 267 | # All 4-rule and 5-rule permutations (with replacement) 268 | # 4 rules 269 | all_4rule_perm = [] 270 | for r1 in range(3): 271 | for r2 in range(3): 272 | for r3 in range(3): 273 | for r4 in range(3): 274 | all_4rule_perm.append([r1, r2, r3, r4]) 275 | all_4rule_perm = np.array(all_4rule_perm) 276 | #5 rules 277 | all_5rule_perm = [] 278 | for r1 in range(3): 279 | for r2 in range(3): 280 | for r3 in range(3): 281 | for r4 in range(3): 282 | for r5 in range(3): 283 | all_5rule_perm.append([r1, r2, r3, r4, r5]) 284 | all_5rule_perm = np.array(all_5rule_perm) 285 | # Sort permutations by combination 286 | # 4 rules 287 | all_4rule_perm_sorted = [] 288 | for c in range(len(all_4rule_comb)): 289 | all_4rule_perm_sorted.append(all_4rule_perm[np.all(np.expand_dims(np.array(all_4rule_comb[c]),0) == np.sort(all_4rule_perm,1), 1)]) 290 | # 5 rules 291 | all_5rule_perm_sorted = [] 292 | for c in range(len(all_5rule_comb)): 293 | all_5rule_perm_sorted.append(all_5rule_perm[np.all(np.expand_dims(np.array(all_5rule_comb[c]),0) == np.sort(all_5rule_perm,1), 1)]) 294 | 295 | # Combine problem types 296 | prob_types = [all_constant, all_dist3, all_prog] 297 | 298 | # Generate 4-rule problems 299 | all_4rule_prob = [] 300 | for c in range(len(all_4rule_comb)): 301 | all_comb_prob = [] 302 | for p in range(N_probs): 303 | duplicate_prob = True 304 | while duplicate_prob: 305 | # Randomly sample permutation 306 | all_perm = all_4rule_perm_sorted[c] 307 | np.random.shuffle(all_perm) 308 | perm = all_perm[0] 309 | # Sample rule instances 310 | # Rule 1 311 | r1_ind = np.floor(np.random.rand() * prob_types[perm[0]].shape[0]).astype(int) 312 | r1 = prob_types[perm[0]][r1_ind] 313 | # Rule 2 314 | duplicate_rule = True 315 | while duplicate_rule: 316 | r2_ind = np.floor(np.random.rand() * prob_types[perm[1]].shape[0]).astype(int) 317 | r2 = prob_types[perm[1]][r2_ind] 318 | duplicate_rule = np.any(np.all(np.expand_dims(np.stack([r1.flatten(), r2.flatten()], 0), 0) == np.expand_dims(np.stack([r1.flatten(), r2.flatten()], 0), 1), 2).flatten()[np.logical_not(np.eye(2).astype(bool).flatten())]) 319 | # Rule 3 320 | duplicate_rule = True 321 | while duplicate_rule: 322 | r3_ind = np.floor(np.random.rand() * prob_types[perm[2]].shape[0]).astype(int) 323 | r3 = prob_types[perm[2]][r3_ind] 324 | duplicate_rule = np.any(np.all(np.expand_dims(np.stack([r1.flatten(), r2.flatten(), r3.flatten()], 0), 0) == np.expand_dims(np.stack([r1.flatten(), r2.flatten(), r3.flatten()], 0), 1), 2).flatten()[np.logical_not(np.eye(3).astype(bool).flatten())]) 325 | # Rule 4 326 | duplicate_rule = True 327 | while duplicate_rule: 328 | r4_ind = np.floor(np.random.rand() * prob_types[perm[3]].shape[0]).astype(int) 329 | r4 = prob_types[perm[3]][r4_ind] 330 | duplicate_rule = np.any(np.all(np.expand_dims(np.stack([r1.flatten(), r2.flatten(), r3.flatten(), r4.flatten()], 0), 0) == np.expand_dims(np.stack([r1.flatten(), r2.flatten(), r3.flatten(), r4.flatten()], 0), 1), 2).flatten()[np.logical_not(np.eye(4).astype(bool).flatten())]) 331 | # Combine rules 1-3 332 | prob = np.stack([r1,r2,r3,r4],2) 333 | # Check if duplicate 334 | duplicate_detected = False 335 | for i in range(len(all_comb_prob)): 336 | if np.all(prob == all_comb_prob[i]): 337 | duplicate_detected = True 338 | if not duplicate_detected: 339 | duplicate_prob = False 340 | all_comb_prob.append(prob) 341 | all_comb_prob = np.array(all_comb_prob) 342 | all_4rule_prob.append(all_comb_prob) 343 | all_4rule_prob = np.array(all_4rule_prob) 344 | # Generate distractors 345 | all_4rule_answer_choices, all_4rule_correct_ind = gen_distractor(all_4rule_prob) 346 | 347 | # Generate 5-rule problems 348 | all_5rule_prob = [] 349 | for c in range(len(all_5rule_comb)): 350 | all_comb_prob = [] 351 | for p in range(N_probs): 352 | duplicate_prob = True 353 | while duplicate_prob: 354 | # Randomly sample permutation 355 | all_perm = all_5rule_perm_sorted[c] 356 | np.random.shuffle(all_perm) 357 | perm = all_perm[0] 358 | # Sample rule instances 359 | # Rule 1 360 | r1_ind = np.floor(np.random.rand() * prob_types[perm[0]].shape[0]).astype(int) 361 | r1 = prob_types[perm[0]][r1_ind] 362 | # Rule 2 363 | duplicate_rule = True 364 | while duplicate_rule: 365 | r2_ind = np.floor(np.random.rand() * prob_types[perm[1]].shape[0]).astype(int) 366 | r2 = prob_types[perm[1]][r2_ind] 367 | duplicate_rule = np.any(np.all(np.expand_dims(np.stack([r1.flatten(), r2.flatten()], 0), 0) == np.expand_dims(np.stack([r1.flatten(), r2.flatten()], 0), 1), 2).flatten()[np.logical_not(np.eye(2).astype(bool).flatten())]) 368 | # Rule 3 369 | duplicate_rule = True 370 | while duplicate_rule: 371 | r3_ind = np.floor(np.random.rand() * prob_types[perm[2]].shape[0]).astype(int) 372 | r3 = prob_types[perm[2]][r3_ind] 373 | duplicate_rule = np.any(np.all(np.expand_dims(np.stack([r1.flatten(), r2.flatten(), r3.flatten()], 0), 0) == np.expand_dims(np.stack([r1.flatten(), r2.flatten(), r3.flatten()], 0), 1), 2).flatten()[np.logical_not(np.eye(3).astype(bool).flatten())]) 374 | # Rule 4 375 | duplicate_rule = True 376 | while duplicate_rule: 377 | r4_ind = np.floor(np.random.rand() * prob_types[perm[3]].shape[0]).astype(int) 378 | r4 = prob_types[perm[3]][r4_ind] 379 | duplicate_rule = np.any(np.all(np.expand_dims(np.stack([r1.flatten(), r2.flatten(), r3.flatten(), r4.flatten()], 0), 0) == np.expand_dims(np.stack([r1.flatten(), r2.flatten(), r3.flatten(), r4.flatten()], 0), 1), 2).flatten()[np.logical_not(np.eye(4).astype(bool).flatten())]) 380 | # Rule 5 381 | duplicate_rule = True 382 | while duplicate_rule: 383 | r5_ind = np.floor(np.random.rand() * prob_types[perm[4]].shape[0]).astype(int) 384 | r5 = prob_types[perm[4]][r5_ind] 385 | duplicate_rule = np.any(np.all(np.expand_dims(np.stack([r1.flatten(), r2.flatten(), r3.flatten(), r4.flatten(), r5.flatten()], 0), 0) == np.expand_dims(np.stack([r1.flatten(), r2.flatten(), r3.flatten(), r4.flatten(), r5.flatten()], 0), 1), 2).flatten()[np.logical_not(np.eye(5).astype(bool).flatten())]) 386 | # Combine rules 1-3 387 | prob = np.stack([r1,r2,r3,r4,r5],2) 388 | # Check if duplicate 389 | duplicate_detected = False 390 | for i in range(len(all_comb_prob)): 391 | if np.all(prob == all_comb_prob[i]): 392 | duplicate_detected = True 393 | if not duplicate_detected: 394 | duplicate_prob = False 395 | all_comb_prob.append(prob) 396 | all_comb_prob = np.array(all_comb_prob) 397 | all_5rule_prob.append(all_comb_prob) 398 | all_5rule_prob = np.array(all_5rule_prob) 399 | # Generate distractors 400 | all_5rule_answer_choices, all_5rule_correct_ind = gen_distractor(all_5rule_prob) 401 | 402 | # Convert problems to strings and save as js script, also as numpy file 403 | all_problems_np = {} 404 | all_problems_js = {} 405 | # 4-rule problems 406 | for c in range(all_4rule_prob.shape[0]): 407 | all_problems_np, all_problems_js = save_prob(all_4rule_prob[c], all_4rule_answer_choices[c], all_4rule_correct_ind[c], 'four_rule_comb' + str(c), all_problems_np, all_problems_js) 408 | # 5-rule problems 409 | for c in range(all_5rule_prob.shape[0]): 410 | all_problems_np, all_problems_js = save_prob(all_5rule_prob[c], all_5rule_answer_choices[c], all_5rule_correct_ind[c], 'five_rule_comb' + str(c), all_problems_np, all_problems_js) 411 | # Save numpy file 412 | np_fname = './all_4_5_rule_problems.npz' 413 | np.savez(np_fname, all_problems=all_problems_np) 414 | # Convert to json string 415 | json_string = json.dumps(all_problems_js) 416 | # Write to js script 417 | js_fname = './all_4_5_rule_problems.js' 418 | js_fid = open(js_fname, 'w') 419 | js_fid.write('var all_problems = ' + json_string) 420 | js_fid.close() 421 | --------------------------------------------------------------------------------