├── models ├── __init__.py ├── README.md └── bayesian_IRT.py ├── src ├── __init__.py ├── kaggle_submision_to_csv.py ├── utils.py ├── generate_IRT_model_results.py └── IRT_model_plots_analysis.py ├── analysis ├── __init__.py ├── 5-learning.ipynb ├── 3-misc.ipynb ├── 8-action-traces.ipynb ├── 6-incomplete-data.ipynb └── 0-arc-dataset.ipynb ├── .cursorignore ├── figures ├── arc-preprint-figure.png ├── arc-preprint-figure2.png ├── arc-preprint-figure2-background.png ├── bayes_IRT_model_burn2000_N10000_imputed_4_trace_plot.png └── bayes_IRT_model_burn2000_N10000_imputed_4_irt_parameters.png ├── .gitignore ├── .vscode └── launch.json ├── results └── bayes_IRT_model_burn2000_N10000_imputed_4_stats.md ├── polars_cfg.json ├── requirements.txt ├── survey └── readme.md ├── data └── readme.md ├── arc_data ├── ARC_training_tasks_ordered.json └── ARC_evaluation_tasks_ordered.json └── README.md /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /analysis/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.cursorignore: -------------------------------------------------------------------------------- 1 | # Add directories or file patterns to ignore during indexing (e.g. 
foo/ or *.csv) 2 | .venv/ -------------------------------------------------------------------------------- /figures/arc-preprint-figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Le-Gris/h-arc/HEAD/figures/arc-preprint-figure.png -------------------------------------------------------------------------------- /figures/arc-preprint-figure2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Le-Gris/h-arc/HEAD/figures/arc-preprint-figure2.png -------------------------------------------------------------------------------- /figures/arc-preprint-figure2-background.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Le-Gris/h-arc/HEAD/figures/arc-preprint-figure2-background.png -------------------------------------------------------------------------------- /figures/bayes_IRT_model_burn2000_N10000_imputed_4_trace_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Le-Gris/h-arc/HEAD/figures/bayes_IRT_model_burn2000_N10000_imputed_4_trace_plot.png -------------------------------------------------------------------------------- /figures/bayes_IRT_model_burn2000_N10000_imputed_4_irt_parameters.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Le-Gris/h-arc/HEAD/figures/bayes_IRT_model_burn2000_N10000_imputed_4_irt_parameters.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # virtual environment 2 | .venv/ 3 | penv 4 | 5 | # cache 6 | __pycache__ 7 | .DS_Store 8 | 9 | # data 10 | data/ 11 | data_/ 12 | survey/ 13 | *.pkl 14 | *tar.gz 15 | 16 | # excluded analyses 17 | analysis/excluded/ 18 | 
19 | # some source code 20 | src/make_dataset_csv.py 21 | 22 | figures/ 23 | 24 | models/*.pkl 25 | # ignore 26 | steps.txt 27 | *.out 28 | *.sbatch 29 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python Debugger: Current File with Arguments", 9 | "type": "debugpy", 10 | "request": "launch", 11 | "program": "${file}", 12 | "args": "${command:pickArgs}" 13 | } 14 | ] 15 | } 16 | -------------------------------------------------------------------------------- /results/bayes_IRT_model_burn2000_N10000_imputed_4_stats.md: -------------------------------------------------------------------------------- 1 | === IRT Model Statistics: bayes_IRT_model_burn2000_N10000_imputed_4 === 2 | 3 | Feedback Effects 4 | -------------------------------------------------- 5 | First additional attempt ($\epsilon_1 =1.25$):\n 6 | - Probability increase: 27.7% 7 | (94% HDI: [26.1%, 29.3%]) 8 | 9 | 10 | Second additional attempt ($\epsilon_2 =1.69$):\n 11 | - Probability increase: 34.4% 12 | (94% HDI: [33.1%, 35.6%]) 13 | 14 | 15 | Task Type Differences 16 | -------------------------------------------------- 17 | Training Tasks: 18 | 19 | - Mean difficulty ($\beta = -0.29$)\n 20 | - Success probability increase: 7.1% 21 | (94% HDI: [3.4%, 10.6%]) 22 | 23 | 24 | Evaluation Tasks: 25 | 26 | - Mean difficulty ($\beta = 0.24$)\n 27 | - Success probability increase: -6.1% 28 | (94% HDI: [-9.5%, -2.6%]) 29 | 30 | 31 | Mean Task Accuracy by Shot 32 | 33 | -------------------------------------------------- 34 | 35 | 1-shot: 36 | 37 | - Training Tasks: 54.6% 38 | (94% HDI: [53.3%, 55.8%]) 39 | 
40 | - Evaluation Tasks: 49.2% 41 | (94% HDI: [47.9%, 50.4%]) 42 | 43 | 44 | 2-shots: 45 | 46 | - Training Tasks: 66.6% 47 | (94% HDI: [65.4%, 67.8%]) 48 | 49 | - Evaluation Tasks: 61.6% 50 | (94% HDI: [60.5%, 62.8%]) 51 | 52 | 53 | 3-shots: 54 | 55 | - Training Tasks: 70.5% 56 | (94% HDI: [69.3%, 71.6%]) 57 | 58 | - Evaluation Tasks: 65.7% 59 | (94% HDI: [64.6%, 66.8%]) 60 | -------------------------------------------------------------------------------- /polars_cfg.json: -------------------------------------------------------------------------------- 1 | { 2 | "environment": { 3 | "POLARS_AUTO_STRUCTIFY": null, 4 | "POLARS_FMT_MAX_COLS": null, 5 | "POLARS_FMT_MAX_ROWS": "15", 6 | "POLARS_FMT_NUM_DECIMAL": null, 7 | "POLARS_FMT_NUM_GROUP_SEPARATOR": null, 8 | "POLARS_FMT_NUM_LEN": null, 9 | "POLARS_FMT_STR_LEN": "45", 10 | "POLARS_FMT_TABLE_CELL_ALIGNMENT": null, 11 | "POLARS_FMT_TABLE_CELL_LIST_LEN": null, 12 | "POLARS_FMT_TABLE_CELL_NUMERIC_ALIGNMENT": null, 13 | "POLARS_FMT_TABLE_DATAFRAME_SHAPE_BELOW": null, 14 | "POLARS_FMT_TABLE_FORMATTING": null, 15 | "POLARS_FMT_TABLE_HIDE_COLUMN_DATA_TYPES": null, 16 | "POLARS_FMT_TABLE_HIDE_COLUMN_NAMES": null, 17 | "POLARS_FMT_TABLE_HIDE_COLUMN_SEPARATOR": null, 18 | "POLARS_FMT_TABLE_HIDE_DATAFRAME_SHAPE_INFORMATION": null, 19 | "POLARS_FMT_TABLE_INLINE_COLUMN_DATA_TYPE": null, 20 | "POLARS_FMT_TABLE_ROUNDED_CORNERS": null, 21 | "POLARS_MAX_EXPR_DEPTH": null, 22 | "POLARS_STREAMING_CHUNK_SIZE": null, 23 | "POLARS_TABLE_WIDTH": null, 24 | "POLARS_VERBOSE": null, 25 | "POLARS_WARN_UNSTABLE": null 26 | }, 27 | "direct": { 28 | "set_fmt_float": "mixed", 29 | "set_float_precision": null, 30 | "set_thousands_separator": "", 31 | "set_decimal_separator": ".", 32 | "set_trim_decimal_zeros": false 33 | } 34 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | appnope==0.1.4 2 | arviz==0.19.0 3 
| asttokens==2.4.1 4 | cachetools==5.5.0 5 | cloudpickle==3.0.0 6 | comm==0.2.2 7 | cons==0.4.6 8 | contourpy==1.2.1 9 | cycler==0.12.1 10 | debugpy==1.8.5 11 | decorator==5.1.1 12 | dm-tree==0.1.8 13 | etuples==0.3.9 14 | executing==2.0.1 15 | filelock==3.16.1 16 | fonttools==4.53.1 17 | h5netcdf==1.3.0 18 | h5py==3.11.0 19 | ipykernel==6.29.5 20 | ipython==8.26.0 21 | ipywidgets==8.1.5 22 | jedi==0.19.1 23 | jupyter-client==8.6.2 24 | jupyter-core==5.7.2 25 | jupyterlab-widgets==3.0.13 26 | kiwisolver==1.4.5 27 | levenshtein==0.25.1 28 | llvmlite==0.43.0 29 | logical-unification==0.4.6 30 | markdown-it-py==3.0.0 31 | matplotlib==3.9.2 32 | matplotlib-inline==0.1.7 33 | mdurl==0.1.2 34 | minikanren==1.0.3 35 | multipledispatch==1.0.0 36 | nest-asyncio==1.6.0 37 | numba==0.60.0 38 | numpy==1.26.4 39 | packaging==24.1 40 | pandas==2.2.2 41 | parso==0.8.4 42 | pexpect==4.9.0 43 | pillow==10.4.0 44 | pip==24.2 45 | platformdirs==4.2.2 46 | polars==1.5.0 47 | preliz==0.9.1 48 | prompt-toolkit==3.0.47 49 | psutil==6.0.0 50 | ptyprocess==0.7.0 51 | pure-eval==0.2.3 52 | pyarrow==17.0.0 53 | pygments==2.18.0 54 | pymc==5.16.2 55 | pyparsing==3.1.2 56 | pytensor==2.25.4 57 | python-dateutil==2.9.0.post0 58 | pytz==2024.1 59 | pyzmq==26.1.0 60 | rapidfuzz==3.9.6 61 | rich==13.8.1 62 | scipy==1.12.0 63 | seaborn==0.13.2 64 | setuptools==75.1.0 65 | six==1.16.0 66 | stack-data==0.6.3 67 | threadpoolctl==3.5.0 68 | toolz==0.12.1 69 | tornado==6.4.1 70 | traitlets==5.14.3 71 | typing-extensions==4.12.2 72 | tzdata==2024.1 73 | wcwidth==0.2.13 74 | widgetsnbextension==4.0.13 75 | xarray==2024.9.0 76 | xarray-einstats==0.8.0 77 | -------------------------------------------------------------------------------- /src/kaggle_submision_to_csv.py: -------------------------------------------------------------------------------- 1 | import json 2 | import polars as pl 3 | from argparse import ArgumentParser 4 | from utils import grid2str 5 | import os 6 | from pathlib import Path 7 | 8 | 
# Repository root: this file lives in <root>/src/, so go two levels up.
basepath = Path(__file__).parent.parent


def get_args():
    """Parse the command line.

    :return: argparse namespace with a single required string attribute,
        ``submission_id``
    """
    parser = ArgumentParser()
    parser.add_argument("--submission_id", type=str, required=True)
    return parser.parse_args()


# NOTE: "submision" is misspelled, but the name is kept so existing callers
# and the module filename keep working.
def kaggle_submision_to_csv(input_json, submission_id):
    """Flatten a submission JSON into a long-format CSV.

    Writes ``data/kaggle_solutions/<submission_id>/submission.csv`` under the
    repository root, one row per (task, test, attempt). The target directory
    is created if it does not exist yet.

    :param input_json: dict mapping a task id to a list (one entry per test)
        of dicts, each mapping an attempt key that ends in ``_<int>``
        (e.g. ``attempt_1``) to an output grid
    :param submission_id: str, identifier stored in every row and used as the
        output sub-directory name
    :return: None
    """
    columns = {
        "task_name": [],
        "test_number": [],
        "submission_id": [],
        "test_output_grid": [],
        "attempt_number": [],
    }
    for task_id, test_list in input_json.items():
        # test numbers are 1-based in the output
        for test_number, test_submissions in enumerate(test_list, start=1):
            for attempt, grid in test_submissions.items():
                columns["task_name"].append(task_id + ".json")
                columns["test_number"].append(test_number)
                columns["submission_id"].append(submission_id)
                columns["test_output_grid"].append(grid2str(grid))
                # the attempt number is the integer after the last underscore
                columns["attempt_number"].append(int(attempt.split("_")[-1]))

    output_dir = basepath / "data" / "kaggle_solutions" / submission_id
    # Library callers may pass a submission id whose folder does not exist yet.
    output_dir.mkdir(parents=True, exist_ok=True)
    pl.DataFrame(columns).write_csv(output_dir / "submission.csv")


if __name__ == "__main__":
    args = get_args()
    input_json_path = (
        basepath / "data" / "kaggle_solutions" / args.submission_id / "submission.json"
    )
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(input_json_path, "r", encoding="utf-8") as f:
        input_json = json.load(f)
    kaggle_submision_to_csv(input_json, args.submission_id)
10 | 11 | | Column name | Description | 12 | | ----------- | ----------------------------------------- | 13 | | exp_name | Experiment name | 14 | | task_type | Type of task (training or evaluation set) | 15 | | hashed_id | Anonymized participant identifier | 16 | | feedback | Feedback provided by the participant | 17 | 18 | ## demographics_data.csv 19 | 20 | This file contains demographic information about the participants. 21 | 22 | | Column name | Description | 23 | | --------------- | ---------------------------------------------------------- | 24 | | exp_name | Experiment name (internal identifier) | 25 | | task_type | Type of task (training or evaluation set) | 26 | | hashed_id | Anonymized participant identifier | 27 | | age | Age of the participant | 28 | | gender | Gender of the participant | 29 | | race | Race of the participant | 30 | | education_level | Education level of the participant | 31 | | normal_vision | Boolean indicating if the participant has normal vision | 32 | | color_blind | Boolean indicating if the participant is color blind | 33 | | fluent_english | Boolean indicating if the participant is fluent in English | 34 | 35 | ## withdraw_data.csv 36 | 37 | This file contains information about participants who withdrew from the experiment. 
38 | 39 | | Column name | Description | 40 | | ---------------- | ---------------------------------------------- | 41 | | exp_name | Experiment name | 42 | | task_type | Type of task (training or evaluation set) | 43 | | hashed_id | Anonymized participant identifier | 44 | | withdraw | Boolean indicating if the participant withdrew | 45 | | withdraw_reason | Reason for withdrawal | 46 | | withdraw_comment | Comment on withdrawal | 47 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | # Bayesian Item Response Theory (IRT) Model 2 | 3 | This directory contains the implementation of a Bayesian Item Response Theory (IRT) model used to analyze participant performance on the Abstraction and Reasoning Corpus (ARC) tasks. The model is implemented in Python using the `pymc` library. 4 | 5 | ## Model Overview 6 | 7 | Intuitively, IRT models disambiguate between different latent variables that are hypothesized to drive probability of success on tasks within a test. Due to the limitations of random task-participant pairings in a finite experimental setup, certain tasks/participants can potentially be under- or overestimated with respect to ground-truth difficulty/ability when simply considering empirical success rate. Fitting an IRT model using a Bayesian framework allowed us to extract credible intervals for each parameter, which is useful for dealing with the inherent uncertainty of empirical data in a principled way. Through Bayesian data imputation, missing values were treated as additional parameters to be estimated and were sampled, conditioned on values of the model parameters during inference. Additionally, the inferred item difficulties allow us to examine difficulty distributions across ARC tasks and datasets, as well as lay out a task difficulty ordering to get a better sense of which kinds of tasks are easier or harder for people. 
Participant abilities and item difficulties were given $\mathcal{N}(0, \sigma_{\alpha})$ and $\mathcal{N}(0, \sigma_{\beta})$ priors, respectively, with $\sigma_{\alpha}, \sigma_{\beta} \sim \mathcal{N}^+(1)$ hyperpriors, where $\mathcal{N}^+(1)$ denotes a half-normal distribution with scale 1. The feedback effect was modeled as follows: $\gamma_0 = 0$, $\gamma_1 \sim \mathcal{N}^+(1)$ and $\gamma_2 = \gamma_1 + \delta$, where $\delta \sim \mathcal{N}^+(1)$. 8 | 9 | ## Mean Probability of Success 10 | 11 | The model calculates the mean probability of success for both the training and evaluation sets for each attempt. This is computed by averaging the predicted probabilities over all participants and tasks within a given set. 12 | 13 | The formula for the mean probability of success is: 14 | 15 | $$P_{\text{set}}(k) = \frac{1}{N_p N_t} \sum_{i=1}^{N_p} \sum_{j \in \mathcal{T}_{\text{set}}} \text{logit}^{-1}(\hat{\alpha}_i - \hat{\beta}_j + \hat{\gamma}_k)$$ 16 | 17 | where $`\mathcal{T}_{\text{set}}`$ is the set of either training or evaluation tasks, $`N_t = |\mathcal{T}_{\text{set}}|`$ is the number of tasks in that set, $`N_p`$ is the number of participants, $`k \in \{0,1,2\}`$ is the attempt number and $`P_{\text{set}}(k)`$ is the mean probability of success for the training or evaluation set at attempt $k$. 18 | 19 | In this equation: 20 | 21 | - $\hat{\alpha}_i$ is the estimated ability for participant $i$. 22 | - $\hat{\beta}_j$ is the estimated difficulty for task $j$. 23 | - $\hat{\gamma}_k$ is the estimated effect for attempt $k$ (named `epsilon` in the code). 24 | 25 | ## Implementation 26 | 27 | The model is defined in `bayesian_IRT.py`. The script takes a DataFrame of participant responses and fits the IRT model using PyMC, returning the model object and the inference trace.
def parse_mixed_datetime(datetime_str):
    """Parse a timestamp that may be written in one of three known formats.

    The formats are tried in order: US date with 12-hour clock and AM/PM,
    US date with 24-hour clock, and ``YYYY-MM-DD HH:MM:SS.ffffff``.

    :param datetime_str: str, the timestamp text
    :return: ``datetime`` for the first format that matches, or ``None`` when
        none of them does
    """
    candidate_formats = (
        "%m/%d/%Y, %I:%M:%S %p",
        "%m/%d/%Y, %H:%M:%S",
        "%Y-%m-%d %H:%M:%S.%f",
    )
    for candidate in candidate_formats:
        try:
            parsed = dt.strptime(datetime_str, candidate)
        except ValueError:
            # not this format; fall through to the next candidate
            pass
        else:
            return parsed
    return None


def get_summary(df, verbose=False):
    """Build the ARC summary frame: one row per (participant-task, attempt).

    Only ``joint_id_task`` groups whose very last row (highest
    ``attempt_number``, then highest ``num_actions``) carries one of the
    terminal description actions are kept. For each kept group, the row with
    the highest ``num_actions`` of every attempt is returned.

    :param df: polars DataFrame of action-level rows
    :param verbose: bool, print how many ``joint_id_task`` groups were dropped
    :return: polars DataFrame restricted to the summary columns
    """
    # actions that mark a properly finished trace
    terminal_actions = [
        "no_last_description",
        "write_description",
        "write_last_description",
    ]
    summary_columns = [
        "hashed_id",
        "task_name",
        "joint_id_task",
        "task_number",
        "attempt_number",
        "action",
        "solved",
        "test_output_grid",
        "first_written_solution",
        "last_written_solution",
        "num_actions",
        "exp_name",
        "task_type",
        "complete",
    ]

    # attempt number 4 is folded into attempt 3
    capped_attempt = (
        pl.when(pl.col("attempt_number") == 4)
        .then(3)
        .otherwise(pl.col("attempt_number"))
    )
    df = df.with_columns(attempt_number=capped_attempt)

    # last row of every participant-task trace
    last_row_per_task = df.select(
        pl.all()
        .top_k_by(["attempt_number", "num_actions"], k=1)
        .over(["joint_id_task"], mapping_strategy="explode")
    )
    complete_task_joint_ids = last_row_per_task.filter(
        pl.col("action").is_in(terminal_actions)
    ).select("joint_id_task")

    # last row of every attempt, restricted to the traces kept above
    last_row_per_attempt = (
        pl.all()
        .top_k_by("num_actions", k=1)
        .over(["joint_id_task", "attempt_number"], mapping_strategy="explode")
    )
    df_summary = df.join(complete_task_joint_ids, on="joint_id_task").select(
        last_row_per_attempt
    )

    if verbose:
        print(
            f"Filtered out {df.n_unique('joint_id_task') - df_summary.n_unique('joint_id_task')}/{df.n_unique('joint_id_task')} participant task attempts"
        )
    return df_summary[summary_columns]
def get_errors(df):
    """Take ARC summary data frame and count every distinct incorrect submission.

    Rows whose ``solved`` is False are grouped by
    ``(task_name, test_output_grid)``.

    :param df: polars DataFrame with at least the columns ``solved``,
        ``task_name``, ``test_output_grid`` and ``task_type``
    :return: polars DataFrame with columns ``task_name``,
        ``test_output_grid``, ``count`` (rows in the group), ``task_type``
        (first value in the group) and ``hashed_output_grid`` (md5 hex digest
        of ``test_output_grid``)
    """
    # polars expression comparison, not a Python truth test
    df_errors = df.filter(pl.col("solved") == False)  # noqa: E712

    # get frequency of errors; pl.len() replaces the deprecated pl.count()
    df_errors = df_errors.group_by(["task_name", "test_output_grid"]).agg(
        pl.len().alias("count"),
        pl.first("task_type").alias("task_type"),
    )

    # get hashed output grid, reusing this module's md5 helper
    df_errors = df_errors.with_columns(
        pl.col("test_output_grid")
        .map_elements(md5, return_dtype=pl.Utf8)
        .alias("hashed_output_grid")
    )

    return df_errors


def grid2str(grid):
    """Convert an ARC grid to its string representation.

    Each row is written as its cells' ``str()`` values concatenated and
    followed by ``|``; the whole string starts with ``|``. For example
    ``[[1, 2], [3, 4]]`` becomes ``"|12|34|"`` and an empty grid becomes
    ``"|"``.

    :param grid: iterable of rows, each an iterable of cell values (nested
        lists or a 2-D numpy array)
    :return: str
    """
    rows = ("".join(str(num) for num in row) + "|" for row in grid)
    return "|" + "".join(rows)


def include_incomplete(df_summary, df_incomplete, verbose=False):
    """Append summarized incomplete-participant data to a summary frame.

    The ``condition`` column is dropped from ``df_summary``, ``df_incomplete``
    is summarized with ``get_summary`` and reduced to the same columns, and
    both frames get a boolean ``incomplete`` column (False for the original
    rows, True for the added ones) before being stacked.

    :param df_summary: polars DataFrame of summary rows; must contain a
        ``condition`` column
    :param df_incomplete: polars DataFrame of action-level rows to summarize
    :param verbose: bool, forwarded to ``get_summary`` and also prints how
        many incomplete ``joint_id_task`` groups were included
    :return: tuple ``(combined summary, incomplete summary)``
    """
    df_summary = df_summary.drop("condition")
    df_incomplete_summary = get_summary(df_incomplete, verbose=verbose)
    df_incomplete_summary = df_incomplete_summary.select(df_summary.columns)
    df_summary = df_summary.with_columns(pl.lit(False).alias("incomplete"))
    df_incomplete_summary = df_incomplete_summary.with_columns(
        pl.lit(True).alias("incomplete")
    )
    df_summary = df_summary.vstack(df_incomplete_summary)
    if verbose:
        print(
            f"Included {df_incomplete_summary.n_unique('joint_id_task')}/{df_incomplete.n_unique('joint_id_task')} incomplete participant task attempts"
        )
    return df_summary, df_incomplete_summary


def md5(grid):
    """Return the md5 hex digest of a grid string (UTF-8 encoded), for indexing.

    :param grid: str, string representation of a grid
    :return: str, 32-character hexadecimal digest
    """
    return hashlib.md5(grid.encode("utf-8")).hexdigest()
def bayes_irt(df, n_samples=5000, tune=1000, seed=0):
    """
    Bayesian Item Response Theory model

    The success probability of a row at each shot is
    ``invlogit(alpha[participant] - beta[task] + epsilon[shot])`` with
    ``epsilon = [0, epsilon_one, epsilon_one + epsilon_delta]``. Two derived
    quantities, ``mean_task_acc_training`` and ``mean_task_acc_eval``, hold
    the per-shot mean success probability over all participants and all
    training (resp. evaluation) tasks.

    :param df: pandas DataFrame; each row is one observation and must provide
        the columns ``hashed_id``, ``task_name``, ``task_type``, ``1-shot``,
        ``2-shots`` and ``3-shots`` (the three observed outcomes)
    :param n_samples: int, number of samples
    :param tune: int, number of burn in samples
    :param seed: int, seed for random number generator
    :return: model, trace
    """

    np.random.seed(seed)

    # integer codes (labels sorted) for participants and tasks
    participants_idx, participants = pd.factorize(df["hashed_id"], sort=True)
    task_idx, tasks = pd.factorize(df["task_name"], sort=True)

    # task type codes; sorted means evaluation=0, training=1
    # NOTE(review): this relies on both task types being present in df
    task_type_idx, _ = pd.factorize(df["task_type"], sort=True)

    # unique task codes belonging to each set
    training_task_idx = np.unique(task_idx[task_type_idx == 1])
    eval_task_idx = np.unique(task_idx[task_type_idx == 0])

    # coords
    coords = {
        "participants": participants,
        "tasks": tasks,
        "shots": ["1-shot", "2-shots", "3-shots"],
        "obs": np.arange(len(df)),
        "training_tasks": tasks[training_task_idx],
        "eval_tasks": tasks[eval_task_idx],
    }

    with pm.Model(coords=coords) as model:

        # hyperpriors
        mu_alpha = 0
        mu_beta = 0

        sigma_alpha = pm.HalfNormal("sigma_alpha", sigma=1)
        sigma_beta = pm.HalfNormal("sigma_beta", sigma=1)

        # Ability (alpha) for each participant
        alpha = pm.Normal("alpha", mu=mu_alpha, sigma=sigma_alpha, dims="participants")

        # Difficulty (beta) for each task
        beta = pm.Normal("beta", mu=mu_beta, sigma=sigma_beta, dims="tasks")

        # Learning rate (epsilon) for each shot; the second increment is
        # non-negative, so epsilon_two >= epsilon_one >= 0
        epsilon_zero = 0
        epsilon_one = pm.HalfNormal("epsilon_one", sigma=1)
        delta = pm.HalfNormal("epsilon_delta", sigma=1)
        epsilon_two = epsilon_one + delta
        pm.Deterministic("epsilon_two", epsilon_two)

        # Stack epsilons as a vector
        epsilon = pm.math.stack([epsilon_zero, epsilon_one, epsilon_two])

        # Likelihood: (N_obs, 1) - (N_obs, 1) + (1, N_shots)
        p = pm.math.invlogit(
            alpha[participants_idx, None] - beta[task_idx, None] + epsilon[None, :]
        )
        observed = df[["1-shot", "2-shots", "3-shots"]].values
        pm.Bernoulli("outcomes", p=p, observed=observed, dims=("obs", "shots"))

        # Mean accuracy per shot for every participant on every task of a set
        for set_name, set_task_idx in (
            ("training", training_task_idx),
            ("eval", eval_task_idx),
        ):
            logits_all = (
                alpha[:, None, None]  # Participant abilities (N_participants, 1, 1)
                - beta[None, set_task_idx, None]  # Task difficulties (1, N_set_tasks, 1)
                + epsilon[None, None, :]  # Learning rates (1, 1, N_shots)
            )
            pm.Deterministic(
                f"mean_task_acc_{set_name}",
                # Average across participants and tasks for each attempt
                pm.math.invlogit(logits_all).mean(axis=(0, 1)),
                dims="shots",
            )

        # Sampling
        trace = pm.sample(
            n_samples,
            tune=tune,
            return_inferencedata=True,
            random_seed=seed,
            progressbar=True,
        )

    return model, trace
read.csv(\"../data/task_number_outcomes.csv\")\n" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 3, 64 | "metadata": { 65 | "vscode": { 66 | "languageId": "r" 67 | } 68 | }, 69 | "outputs": [ 70 | { 71 | "data": { 72 | "text/plain": [ 73 | "Generalized linear mixed model fit by maximum likelihood (Laplace\n", 74 | " Approximation) [glmerMod]\n", 75 | " Family: binomial ( logit )\n", 76 | "Formula: success ~ task_number + (1 | hashed_id)\n", 77 | " Data: df_task_num_success\n", 78 | "Control: glmerControl(optimizer = \"bobyqa\")\n", 79 | "\n", 80 | " AIC BIC logLik deviance df.resid \n", 81 | " 7368.7 7389.3 -3681.4 7362.7 7143 \n", 82 | "\n", 83 | "Scaled residuals: \n", 84 | " Min 1Q Median 3Q Max \n", 85 | "-2.7843 -0.4753 0.2765 0.4093 2.0394 \n", 86 | "\n", 87 | "Random effects:\n", 88 | " Groups Name Variance Std.Dev.\n", 89 | " hashed_id (Intercept) 3.913 1.978 \n", 90 | "Number of obs: 7146, groups: hashed_id, 1632\n", 91 | "\n", 92 | "Fixed effects:\n", 93 | " Estimate Std. Error z value Pr(>|z|) \n", 94 | "(Intercept) 1.37745 0.09073 15.181 <2e-16 ***\n", 95 | "task_number 0.02333 0.02044 1.141 0.254 \n", 96 | "---\n", 97 | "Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 
0.1 ' ' 1\n", 98 | "\n", 99 | "Correlation of Fixed Effects:\n", 100 | " (Intr)\n", 101 | "task_number -0.649" 102 | ] 103 | }, 104 | "metadata": {}, 105 | "output_type": "display_data" 106 | } 107 | ], 108 | "source": [ 109 | "# fit mixed model\n", 110 | "mixed_model <- glmer(\n", 111 | " success ~ task_number +\n", 112 | " (1 | hashed_id),\n", 113 | " data = df_task_num_success,\n", 114 | " family = \"binomial\",\n", 115 | " control = glmerControl(optimizer = \"bobyqa\")\n", 116 | ")\n", 117 | "summary(mixed_model)\n" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 4, 123 | "metadata": { 124 | "vscode": { 125 | "languageId": "r" 126 | } 127 | }, 128 | "outputs": [ 129 | { 130 | "name": "stderr", 131 | "output_type": "stream", 132 | "text": [ 133 | "Contrasts set to contr.sum for the following variables: hashed_id\n", 134 | "\n", 135 | "Numerical variables NOT centered on 0: task_number\n", 136 | "If in interactions, interpretation of lower order (e.g., main) effects difficult.\n", 137 | "\n" 138 | ] 139 | } 140 | ], 141 | "source": [ 142 | "mixed_model_lrt <- mixed(\n", 143 | " success ~ task_number +\n", 144 | " (1 | hashed_id),\n", 145 | " data = df_task_num_success,\n", 146 | " family = \"binomial\",\n", 147 | " method = \"LRT\",\n", 148 | " control = glmerControl(optimizer = \"bobyqa\")\n", 149 | ")\n" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 5, 155 | "metadata": { 156 | "vscode": { 157 | "languageId": "r" 158 | } 159 | }, 160 | "outputs": [ 161 | { 162 | "data": { 163 | "text/plain": [ 164 | "Mixed Model Anova Table (Type 3 tests, LRT-method)\n", 165 | "\n", 166 | "Model: success ~ task_number + (1 | hashed_id)\n", 167 | "Data: df_task_num_success\n", 168 | "Df full model: 3\n", 169 | " Effect df Chisq p.value\n", 170 | "1 task_number 1 1.30 .254\n", 171 | "---\n", 172 | "Signif. 
codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '+' 0.1 ' ' 1" 173 | ] 174 | }, 175 | "metadata": {}, 176 | "output_type": "display_data" 177 | } 178 | ], 179 | "source": [ 180 | "mixed_model_lrt\n" 181 | ] 182 | } 183 | ], 184 | "metadata": { 185 | "kernelspec": { 186 | "display_name": "R", 187 | "language": "R", 188 | "name": "ir" 189 | }, 190 | "language_info": { 191 | "codemirror_mode": "r", 192 | "file_extension": ".r", 193 | "mimetype": "text/x-r-source", 194 | "name": "R", 195 | "pygments_lexer": "r", 196 | "version": "4.3.2" 197 | } 198 | }, 199 | "nbformat": 4, 200 | "nbformat_minor": 2 201 | } 202 | -------------------------------------------------------------------------------- /data/readme.md: -------------------------------------------------------------------------------- 1 | # Data directory description 2 | 3 | This document describes the CSV files in the `/data` directory. 4 | 5 | ## data.csv 6 | 7 | This file contains all collected task data for complete and incomplete participant data, respectively. Participant data was deemed incomplete if the data file had an insufficient number of attempted tasks (5 or 10 depending on the experiment). 8 | 9 | Each row represents a single, numbered action taken by a unique participant on a given task and attempt with all relevant experiment, participant, task and action information. 
10 | 11 | | Column name | Description | 12 | | ------------------ | ----------------------------------------------------------------------------------------- | 13 | | exp_name | Experiment name (internal identifier) | 14 | | task_type | Type of task (training or evaluation set) | 15 | | hashed_id | Anonymized participant identifier | 16 | | joint_id_task | Combined identifier for participant and task | 17 | | task_name | Name of the task | 18 | | task_number | Number of the task (i.e., 3 is the third task completed) | 19 | | time | Timestamp of the action | 20 | | attempt_number | Number of the attempt | 21 | | action_id | Number of the action taken | 22 | | solved | Boolean indicating if the task was solved at this action | 23 | | done | Boolean indicating if the attempt is complete (last action) | 24 | | test_input_grid | Input grid for the task | 25 | | test_input_size_x | X-dimension of the input grid | 26 | | test_input_size_y | Y-dimension of the input grid | 27 | | test_output_grid | Output grid for the task in string format | 28 | | test_output_size_x | X-dimension of the output grid | 29 | | test_output_size_y | Y-dimension of the output grid | 30 | | action | Action taken by the participant | 31 | | action_x | X-coordinate of the action | 32 | | action_y | Y-coordinate of the action | 33 | | select_loc | Selected location | 34 | | selected_data | Data selected by the participant | 35 | | selected_symbol | Symbol selected by the participant (i.e., color in the experiment interface) | 36 | | selected_tool | Tool selected by the participant | 37 | | copy_paste_data | Data used in copy-paste actions | 38 | | complete | Boolean indicating if the data is from a participant that completed the experiment or not | 39 | 40 | ## summary_data.csv 41 | 42 | This file contains summary data for complete and incomplete participant data, respectively. Each row represents a summary of an attempt by a unique participant at a given task. 
43 | 44 | | Column name | Description | 45 | | ---------------------- | ----------------------------------------------------------------------------------------- | 46 | | exp_name | Experiment name (internal identifier) | 47 | | task_type | Type of task (training or evaluation set) | 48 | | hashed_id | Anonymized participant identifier | 49 | | joint_id_task | Combined identifier for participant and task | 50 | | task_name | Name of the task | 51 | | task_number | Number of the task (i.e., 3 is the third completed task) | 52 | | attempt_number | Number of the attempt | 53 | | num_actions | Number of actions taken until submission | 54 | | solved | Boolean indicating if the task was solved | 55 | | test_output_grid | Output grid for the task | 56 | | first_written_solution | First solution written by the participant | 57 | | last_written_solution | Last solution written by the participant | 58 | | complete | Boolean indicating if the data is from a participant that completed the experiment or not | 59 | 60 | ## incorrect_submissions.csv 61 | 62 | This file contains error information for both complete and incomplete participant data. 
63 | 64 | | Column name | Description | 65 | | ---------------- | ----------------------------------------- | 66 | | task_name | Name of the task | 67 | | task_type | Type of task (training or evaluation set) | 68 | | test_output_grid | Output grid for the task | 69 | | count | Number of occurrences of this error | 70 | -------------------------------------------------------------------------------- /src/generate_IRT_model_results.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import argparse as ap 4 | import os 5 | import json 6 | from pathlib import Path 7 | import cloudpickle as cpkl 8 | import sys 9 | import polars as pl 10 | 11 | basepath = Path(__file__).parent.parent 12 | sys.path.append(str(basepath)) 13 | 14 | from models.bayesian_IRT import bayes_irt 15 | 16 | 17 | def get_args(): 18 | parser = ap.ArgumentParser(description="Generate IRT model results") 19 | parser.add_argument("--n_samples", type=int, default=5000, help="number of samples") 20 | parser.add_argument( 21 | "--n_burn", type=int, default=1000, help="number of burn in samples" 22 | ) 23 | parser.add_argument( 24 | "--seed", type=int, default=0, help="seed for random number generator" 25 | ) 26 | parser.add_argument("--impute", action="store_true", help="impute missing data") 27 | return parser.parse_args() 28 | 29 | 30 | def load_data(): 31 | # load data 32 | df_summary = pl.read_csv(os.path.join(basepath, "data", "clean_summary_data.csv")) 33 | df_summary_incomplete = pl.read_csv( 34 | os.path.join(basepath, "data", "clean_summary_data_incomplete.csv") 35 | ) 36 | 37 | # combine data 38 | df_summary = pl.concat([df_summary, df_summary_incomplete]) 39 | 40 | # filter data 41 | columns = [ 42 | "hashed_id", 43 | "task_name", 44 | "attempt_number", 45 | "solved", 46 | "task_type", 47 | "complete", 48 | ] 49 | df = ( 50 | df_summary.select(columns) 51 | .with_columns( 52 | 
pl.col("attempt_number").cast(pl.Int32), 53 | pl.col("solved").cast(pl.Int32), 54 | ) 55 | .to_pandas() 56 | ) 57 | 58 | # reshape for bernouilli 59 | df = df.pivot_table( 60 | values="solved", 61 | index=["hashed_id", "task_name", "task_type", "complete"], 62 | columns="attempt_number", 63 | aggfunc="first", 64 | ).reset_index() 65 | df.columns.name = None 66 | df = df.rename(columns={1: "1-shot", 2: "2-shots", 3: "3-shots"}) 67 | df = df.fillna(1) 68 | return df 69 | 70 | 71 | def get_incomplete_tasks(df, ordered_tasks_list): 72 | incomplete = df.groupby("hashed_id").size() 73 | incomplete = incomplete.reset_index() 74 | incomplete = incomplete.rename(columns={0: "n_complete"}) 75 | incomplete = incomplete[incomplete["n_complete"] < 5] 76 | repeats = 5 - incomplete["n_complete"].to_numpy().flatten() 77 | 78 | tasks = [] 79 | for h, m, r in zip(incomplete["hashed_id"], incomplete["n_complete"], repeats): 80 | # Get completed tasks for this participant 81 | completed_tasks = df[df["hashed_id"] == h]["task_name"].to_numpy() 82 | # Find indices of completed tasks in ordered list 83 | completed_indices = np.where(np.isin(ordered_tasks_list, completed_tasks))[0] 84 | 85 | if len(completed_indices) == 0: 86 | # If no tasks completed, start from beginning 87 | next_tasks = [] 88 | idx = 0 89 | while len(next_tasks) < r: 90 | if ordered_tasks_list[idx] not in completed_tasks: 91 | next_tasks.append(ordered_tasks_list[idx]) 92 | idx = (idx + 1) % len(ordered_tasks_list) 93 | else: 94 | # Start from the highest completed index 95 | next_tasks = [] 96 | idx = (completed_indices.max() + 1) % len(ordered_tasks_list) 97 | while len(next_tasks) < r: 98 | if ordered_tasks_list[idx] not in completed_tasks: 99 | next_tasks.append(ordered_tasks_list[idx]) 100 | idx = (idx + 1) % len(ordered_tasks_list) 101 | 102 | tasks.extend(next_tasks) 103 | 104 | hashed_ids = incomplete["hashed_id"].repeat(repeats) 105 | return hashed_ids, np.array(tasks) 106 | 107 | 108 | def 
load_ordered_tasks(): 109 | df_training_tasks_ordered = json.load( 110 | open(os.path.join(basepath, "data", "ARC_training_tasks_ordered.json")) 111 | ) 112 | df_eval_tasks_ordered = json.load( 113 | open(os.path.join(basepath, "data", "ARC_evaluation_tasks_ordered.json")) 114 | ) 115 | return df_training_tasks_ordered, df_eval_tasks_ordered 116 | 117 | 118 | def fill_na(df, hashed_ids, tasks, task_type="training"): 119 | nan_rows = pd.DataFrame( 120 | { 121 | "hashed_id": hashed_ids, 122 | "task_name": tasks, 123 | "task_type": task_type, 124 | "complete": False, 125 | "1-shot": np.nan, 126 | "2-shots": np.nan, 127 | "3-shots": np.nan, 128 | } 129 | ) 130 | df = pd.concat([df, nan_rows]) 131 | return df 132 | 133 | 134 | def include_missing_data(df): 135 | # separate training and evaluation data 136 | df_training = df[df["task_type"] == "training"] 137 | df_eval = df[df["task_type"] == "evaluation"] 138 | # load ordered tasks 139 | df_training_tasks_ordered, df_eval_tasks_ordered = load_ordered_tasks() 140 | # get incomplete tasks 141 | hashed_ids_training, tasks_training = get_incomplete_tasks( 142 | df_training, np.array(df_training_tasks_ordered) 143 | ) 144 | hashed_ids_eval, tasks_eval = get_incomplete_tasks( 145 | df_eval, np.array(df_eval_tasks_ordered) 146 | ) 147 | # fill missing data 148 | df_training = fill_na( 149 | df_training, hashed_ids_training, tasks_training, task_type="training" 150 | ) 151 | df_eval = fill_na(df_eval, hashed_ids_eval, tasks_eval, task_type="evaluation") 152 | # combine data 153 | df = pd.concat([df_training, df_eval]) 154 | return df 155 | 156 | 157 | if __name__ == "__main__": 158 | args = get_args() 159 | # load data 160 | df = load_data() 161 | if args.impute: 162 | df = include_missing_data(df) 163 | # run IRT model 164 | model, trace = bayes_irt( 165 | df, n_samples=args.n_samples, tune=args.n_burn, seed=args.seed 166 | ) 167 | # save model 168 | imputed = "_imputed" if args.impute else "" 169 | with open( 170 | 
os.path.join( 171 | basepath, 172 | "models", 173 | f"bayes_IRT_model_burn{args.n_burn}_N{args.n_samples}{imputed}_{args.seed}.pkl", 174 | ), 175 | "wb", 176 | ) as f: 177 | cpkl.dump((model, trace), f) 178 | -------------------------------------------------------------------------------- /arc_data/ARC_training_tasks_ordered.json: -------------------------------------------------------------------------------- 1 | [ 2 | "007bbfb7.json", 3 | "00d62c1b.json", 4 | "017c7c7b.json", 5 | "045e512c.json", 6 | "0520fde7.json", 7 | "05269061.json", 8 | "05f2a901.json", 9 | "06df4c85.json", 10 | "08ed6ac7.json", 11 | "09629e4f.json", 12 | "0962bcdd.json", 13 | "0a938d79.json", 14 | "0b148d64.json", 15 | "0ca9ddb6.json", 16 | "0d3d703e.json", 17 | "0e206a2e.json", 18 | "10fcaaa3.json", 19 | "11852cab.json", 20 | "1190e5a7.json", 21 | "137eaa0f.json", 22 | "178fcbfb.json", 23 | "1a07d186.json", 24 | "1b2d62fb.json", 25 | "1b60fb0c.json", 26 | "1bfc4729.json", 27 | "1c786137.json", 28 | "1cf80156.json", 29 | "1e32b0e9.json", 30 | "1f0c79e5.json", 31 | "1f642eb9.json", 32 | "1f85a75f.json", 33 | "2013d3e2.json", 34 | "2204b7a8.json", 35 | "22168020.json", 36 | "22233c11.json", 37 | "22eb0ac0.json", 38 | "234bbc79.json", 39 | "23581191.json", 40 | "239be575.json", 41 | "23b5c85d.json", 42 | "253bf280.json", 43 | "25d487eb.json", 44 | "25d8a9c8.json", 45 | "25ff71a9.json", 46 | "264363fd.json", 47 | "272f95fa.json", 48 | "27a28665.json", 49 | "28bf18c6.json", 50 | "29623171.json", 51 | "29c11459.json", 52 | "2bcee788.json", 53 | "2bee17df.json", 54 | "2c608aff.json", 55 | "2dc579da.json", 56 | "2dd70a9a.json", 57 | "2dee498d.json", 58 | "321b1fc6.json", 59 | "32597951.json", 60 | "3345333e.json", 61 | "3618c87e.json", 62 | "363442ee.json", 63 | "36d67576.json", 64 | "36fdfd69.json", 65 | "3906de3d.json", 66 | "3ac3eb23.json", 67 | "3bd67248.json", 68 | "3bdb4ada.json", 69 | "3befdf3e.json", 70 | "3c9b0459.json", 71 | "3de23699.json", 72 | "3e980e27.json", 73 | "3eda0437.json", 74 
| "3f7978a0.json", 75 | "40853293.json", 76 | "4093f84a.json", 77 | "41e4d17e.json", 78 | "4258a5f9.json", 79 | "42a50994.json", 80 | "4347f46a.json", 81 | "445eab21.json", 82 | "447fd412.json", 83 | "44d8ac46.json", 84 | "44f52bb0.json", 85 | "4522001f.json", 86 | "4612dd53.json", 87 | "46442a0e.json", 88 | "46f33fce.json", 89 | "48d8fb45.json", 90 | "4938f0c2.json", 91 | "496994bd.json", 92 | "49d1d64f.json", 93 | "4be741c5.json", 94 | "4c4377d9.json", 95 | "4c5c2cf0.json", 96 | "50846271.json", 97 | "508bd3b6.json", 98 | "50cb2852.json", 99 | "5117e062.json", 100 | "5168d44c.json", 101 | "539a4f51.json", 102 | "53b68214.json", 103 | "543a7ed5.json", 104 | "54d82841.json", 105 | "54d9e175.json", 106 | "5521c0d9.json", 107 | "5582e5ca.json", 108 | "5614dbcf.json", 109 | "56dc2b01.json", 110 | "56ff96f3.json", 111 | "57aa92db.json", 112 | "5ad4f10b.json", 113 | "5bd6f4ac.json", 114 | "5daaa586.json", 115 | "60b61512.json", 116 | "623ea044.json", 117 | "63613498.json", 118 | "6430c8c4.json", 119 | "6455b5f5.json", 120 | "662c240a.json", 121 | "67385a82.json", 122 | "673ef223.json", 123 | "67a3c6ac.json", 124 | "67a423a3.json", 125 | "67e8384a.json", 126 | "681b3aeb.json", 127 | "6855a6e4.json", 128 | "694f12f3.json", 129 | "6a1e5592.json", 130 | "6aa20dc0.json", 131 | "6b9890af.json", 132 | "6cdd2623.json", 133 | "6cf79266.json", 134 | "6d0160f0.json", 135 | "6d0aefbc.json", 136 | "6d58a25d.json", 137 | "6d75e8bb.json", 138 | "6e02f1e3.json", 139 | "6e19193c.json", 140 | "6ecd11f4.json", 141 | "6f8cd79b.json", 142 | "6fa7a44f.json", 143 | "72322fa7.json", 144 | "72ca375d.json", 145 | "7447852a.json", 146 | "7468f01a.json", 147 | "746b3537.json", 148 | "74dd1130.json", 149 | "75b8110e.json", 150 | "760b3cac.json", 151 | "776ffc46.json", 152 | "77fdfe62.json", 153 | "780d0b14.json", 154 | "7837ac64.json", 155 | "794b24be.json", 156 | "7b6016b9.json", 157 | "7b7f7511.json", 158 | "7df24a62.json", 159 | "7e0986d6.json", 160 | "7f4411dc.json", 161 | "7fe24cdd.json", 162 
| "80af3007.json", 163 | "810b9b61.json", 164 | "82819916.json", 165 | "834ec97d.json", 166 | "8403a5d5.json", 167 | "846bdb03.json", 168 | "855e0971.json", 169 | "85c4e7cd.json", 170 | "868de0fa.json", 171 | "8731374e.json", 172 | "88a10436.json", 173 | "88a62173.json", 174 | "890034e9.json", 175 | "8a004b2b.json", 176 | "8be77c9e.json", 177 | "8d5021e8.json", 178 | "8d510a79.json", 179 | "8e1813be.json", 180 | "8e5a5113.json", 181 | "8eb1be9a.json", 182 | "8efcae92.json", 183 | "8f2ea7aa.json", 184 | "90c28cc7.json", 185 | "90f3ed37.json", 186 | "913fb3ed.json", 187 | "91413438.json", 188 | "91714a58.json", 189 | "9172f3a0.json", 190 | "928ad970.json", 191 | "93b581b8.json", 192 | "941d9a10.json", 193 | "94f9d214.json", 194 | "952a094c.json", 195 | "9565186b.json", 196 | "95990924.json", 197 | "963e52fc.json", 198 | "97a05b5b.json", 199 | "98cf29f8.json", 200 | "995c5fa3.json", 201 | "99fa7670.json", 202 | "9aec4887.json", 203 | "9af7a82c.json", 204 | "9d9215db.json", 205 | "9dfd6313.json", 206 | "9ecd008a.json", 207 | "9f236235.json", 208 | "a1570a43.json", 209 | "a2fd1cf0.json", 210 | "a3df8b1e.json", 211 | "a416b8f3.json", 212 | "a48eeaf7.json", 213 | "a5f85a15.json", 214 | "a61ba2ce.json", 215 | "a61f2674.json", 216 | "a65b410d.json", 217 | "a68b268e.json", 218 | "a699fb00.json", 219 | "a740d043.json", 220 | "a79310a0.json", 221 | "a85d4709.json", 222 | "a8c38be5.json", 223 | "a8d7556c.json", 224 | "a9f96cdd.json", 225 | "aabf363d.json", 226 | "aba27056.json", 227 | "ae3edfdc.json", 228 | "ae4f1146.json", 229 | "aedd82e4.json", 230 | "af902bf9.json", 231 | "b0c4d837.json", 232 | "b190f7f5.json", 233 | "b1948b0a.json", 234 | "b230c067.json", 235 | "b27ca6d3.json", 236 | "b2862040.json", 237 | "b527c5c6.json", 238 | "b548a754.json", 239 | "b60334d2.json", 240 | "b6afb2da.json", 241 | "b7249182.json", 242 | "b782dc8a.json", 243 | "b91ae062.json", 244 | "b9b7f026.json", 245 | "ba26e723.json", 246 | "ba97ae07.json", 247 | "bb43febb.json", 248 | "bbc9ae5d.json", 
249 | "bc1d5164.json", 250 | "bd4472b8.json", 251 | "bda2d7a6.json", 252 | "bdad9b1f.json", 253 | "be94b721.json", 254 | "beb8660c.json", 255 | "c0f76784.json", 256 | "c1d99e64.json", 257 | "c3e719e8.json", 258 | "c444b776.json", 259 | "c59eb873.json", 260 | "c8cbb738.json", 261 | "c909285e.json", 262 | "c9e6f938.json", 263 | "c9f8e694.json", 264 | "cbded52d.json", 265 | "cce03e0d.json", 266 | "cdecee7f.json", 267 | "ce22a75a.json", 268 | "ce4f8723.json", 269 | "ce602527.json", 270 | "ce9e57f2.json", 271 | "cf98881b.json", 272 | "d037b0a7.json", 273 | "d06dbe63.json", 274 | "d0f5fe59.json", 275 | "d10ecb37.json", 276 | "d13f3404.json", 277 | "d23f8c26.json", 278 | "d2abd087.json", 279 | "d364b489.json", 280 | "d406998b.json", 281 | "d43fd935.json", 282 | "d4469b4b.json", 283 | "d4a91cb9.json", 284 | "d4f3cd78.json", 285 | "d511f180.json", 286 | "d5d6de2d.json", 287 | "d687bc17.json", 288 | "d6ad076f.json", 289 | "d89b689b.json", 290 | "d8c310e9.json", 291 | "d90796e8.json", 292 | "d9f24cd1.json", 293 | "dae9d2b5.json", 294 | "db3e9e38.json", 295 | "dbc1a6ce.json", 296 | "dc1df850.json", 297 | "dc433765.json", 298 | "ddf7fa4f.json", 299 | "de1cd16c.json", 300 | "e21d9049.json", 301 | "e26a3af2.json", 302 | "e3497940.json", 303 | "e40b9e2f.json", 304 | "e48d4e1a.json", 305 | "e509e548.json", 306 | "e6721834.json", 307 | "e73095fd.json", 308 | "e76a88a6.json", 309 | "e8593010.json", 310 | "e8dc4411.json", 311 | "e9614598.json", 312 | "e98196ab.json", 313 | "e9afcf9a.json", 314 | "ea32f347.json", 315 | "ea786f4a.json", 316 | "eb5a1d5d.json", 317 | "ec883f72.json", 318 | "ecdecbb3.json", 319 | "ed36ccf7.json", 320 | "ef135b50.json", 321 | "f15e1fac.json", 322 | "f1cefba8.json", 323 | "f25fbde4.json", 324 | "f25ffba3.json", 325 | "f2829549.json", 326 | "f35d900a.json", 327 | "f5b8619d.json", 328 | "f8a8fe49.json", 329 | "f8b3ba0a.json", 330 | "f8c80d96.json", 331 | "f8ff0b80.json", 332 | "fafffa47.json", 333 | "fcb5c309.json", 334 | "fcc82909.json", 335 | 
"feca6190.json", 336 | "ff28f65a.json", 337 | "ff805c23.json" 338 | ] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Human Abstraction and Reasoning Corpus (H-ARC) 2 | 3 | This repository contains scripts and additional information about H-ARC to accompany our paper [published in](https://www.nature.com/articles/s41597-025-05687-1) the Scientific Data journal. 4 | 5 | ![Figure 2: Example of human action traces from H-ARC and corresponding natural language descriptions](figures/arc-preprint-figure2-background.png) 6 | 7 | The data can be downloaded from our [OSF repository](https://osf.io/bh8yq). 8 | 9 | Participant responses, natural language descriptions, errors and state space graphs can all be explored visually on our [project webpage](https://arc-visualizations.github.io/index.html). 10 | 11 | H-ARC consists of action-by-action traces of humans solving ARC tasks from both the training and evaluation sets using an interface and setup similar to François Chollet's initial proposal. The experiment platform can be viewed at [this link](https://exps.gureckislab.org/e/assumption-fast-natural/#/) and the underlying code is available in an [accompanying GitHub repo](https://github.com/Le-Gris/harc-viewer). 12 | 13 | The original ARC dataset can be found [here](https://github.com/fchollet/ARC-AGI). 14 | 15 | ## Citing our work 16 | 17 | ``` 18 | @article{legrisComprehensiveBehavioralDataset2025, 19 | title = {A {{Comprehensive Behavioral Dataset}} for the {{Abstraction}} and {{Reasoning Corpus}}}, 20 | author = {LeGris, Solim and Vong, Wai Keen and Lake, Brenden M. 
and Gureckis, Todd M.}, 21 | year = {2025}, 22 | month = aug, 23 | journal = {Scientific Data}, 24 | volume = {12}, 25 | number = {1}, 26 | pages = {1380}, 27 | issn = {2052-4463}, 28 | doi = {10.1038/s41597-025-05687-1}, 29 | abstract = {The Abstraction and Reasoning Corpus (ARC) is a visual program synthesis benchmark designed to test out-of-distribution generalization in machines. Comparing AI algorithms to human performance is essential to measure progress on these problems. In this paper, we present H-ARC (Human-ARC): a novel large-scale dataset containing solution attempts from over 1700 humans on ARC problems. The dataset spans the full set of 400 training and 400 evaluation tasks from the original ARC benchmark, and it is the largest human evaluation to date. By publishing the dataset, we contribute human responses to each problem, step-by-step behavioral action traces from the ARC user-interface, and natural-language solution descriptions of the inferred program/rule. We believe this dataset will be of value to researchers, both in cognitive science and AI, since it offers the potential to facilitate the discovery of underlying mechanisms supporting abstraction and reasoning in people. The insights to be gained from these data not only have value for cognitive science, but could in turn inform the design of more efficient, human-like AI algorithms.} 30 | } 31 | ``` 32 | 33 | ## Getting started 34 | 35 | ### Setting up the Python Environment 36 | 37 | 1. Ensure you have Python 3.10 or later installed on your system. 38 | 39 | 2. Clone this repository to your local machine: 40 | 41 | ```bash 42 | gh repo clone le-gris/h-arc 43 | cd h-arc 44 | ``` 45 | 46 | 3. Create a virtual environment: 47 | 48 | ```bash 49 | python -m venv .venv 50 | ``` 51 | 52 | 4. Activate the virtual environment: 53 | 54 | - On Windows: 55 | ```bash 56 | .venv\Scripts\activate 57 | ``` 58 | - On macOS and Linux: 59 | ```bash 60 | source .venv/bin/activate 61 | ``` 62 | 63 | 5. 
Install the required packages using pip and the requirements.txt file: 64 | ```bash 65 | pip install -r requirements.txt 66 | ``` 67 | 68 | ### Extracting the dataset 69 | 70 | The H-ARC dataset can be downloaded as a zip archive from our OSF repository. To extract it: 71 | 72 | 1. Navigate to the project root directory if you're not already there and move the zip archive there. Make sure it is named `osfstorage-archive.zip`. 73 | 74 | 2. Use the following command to extract the dataset: 75 | - On Windows: 76 | ```bash 77 | tar -xf osfstorage-archive.zip 78 | ``` 79 | - On macOS and Linux: 80 | ```bash 81 | unzip osfstorage-archive.zip 82 | ``` 83 | 84 | After extraction, you should see several CSV files in the `data` and `survey` folders. 85 | 86 | ## Dataset 87 | 88 | The H-ARC dataset consists of several CSV files containing different aspects of human performance on ARC tasks. 89 | 90 | All files are in CSV format. In the `data` folder, there are the following files: 91 | 92 | - `data.csv`: All collected data from complete / incomplete participant data 93 | - `incorrect_submissions.csv`: All unique errors on each task and their counts from complete/incomplete participant data 94 | - `summary_data.csv`: Attempt by attempt summary data for complete/incomplete participant data 95 | - [`readme.md`](data/readme.md): data directory description 96 | 97 | In the `survey` folder, there are the following files: 98 | 99 | - `feedback.csv`: Participant feedback 100 | - `demographics.csv`: Demographic information 101 | - `withdraw.csv`: Withdrawal information 102 | - [`readme.md`](survey/readme.md): survey directory description 103 | 104 | For more detailed information about the dataset, see each of the readme files. 105 | 106 | ## Bayesian IRT Model 107 | 108 | To analyze performance using a Bayesian Item Response Theory model: 109 | 110 | 1. 
Generate the model: 111 | 112 | ```bash 113 | python src/generate_IRT_model_results.py --n_samples 10000 --n_burn 2000 --seed 4 --impute 114 | ``` 115 | 116 | Remove the `--impute` flag to exclude missing data from the analysis. 117 | 118 | 2. Generate plots and statistics: 119 | ```bash 120 | python src/IRT_model_plots_analysis.py --model_path models/bayes_IRT_model_burn2000_N10000_imputed_4.pkl --verbose 121 | ``` 122 | 123 | This will create trace plots, parameter visualizations, and detailed statistics in the `figures/` and `results/` directories. 124 | 125 | For more details, see the model [`readme.md`](models/README.md) file. 126 | 127 | ## Analyses 128 | 129 | We include in this repository the main Jupyter notebooks used to compute reported results from our paper. 130 | 131 | ### Notebooks 132 | 133 | #### [0-arc-dataset.ipynb](analysis/0-arc-dataset.ipynb) 134 | 135 | This notebook looks at some aspects of the ARC dataset structure. 136 | 137 | #### [1-basic-results](analysis/1-basic-results.ipynb) 138 | 139 | This notebook computes basic performance metrics on the H-ARC dataset, including overall solve rates, action counts, and time-related statistics for both training and evaluation tasks. 140 | 141 | #### [2-demographics](analysis/2-demographics.ipynb) 142 | 143 | This notebook looks at some basic demographics data from our pool of participants. 144 | 145 | #### [3-misc](analysis/3-misc.ipynb) 146 | 147 | This notebook contains miscellaneous analyses, including participant counts for different experimental conditions and various data processing steps. 148 | 149 | #### [4-errors](analysis/4-errors.ipynb) 150 | 151 | This notebook analyzes error patterns in participant responses, including copy errors and other common mistake types across both training and evaluation tasks. 152 | 153 | #### [5-learning](analysis/5-learning.ipynb) 154 | 155 | This notebook examines learning effects across tasks using mixed-effects logistic regression models. 
It analyzes how task success rates change as participants progress through the experiment. 156 | 157 | #### [6-incomplete-data-analysis](analysis/6-incomplete-data.ipynb) 158 | 159 | This notebook focuses on analyzing incomplete task attempts, comparing performance metrics between participants who completed all tasks and those who didn't, and examining factors that might contribute to task incompletion. 160 | 161 | #### [7-human-machine](analysis/7-human-machine.ipynb) 162 | 163 | This notebook compares the performance of human participants with that of algorithmic solutions to evaluation set ARC tasks. It analyzes success rates, error patterns, and solution strategies between humans and AI systems. 164 | 165 | #### [8-action-traces](analysis/8-action-traces.ipynb) 166 | 167 | This notebook shows how to use the data to extract action traces for further analysis. 168 | 169 | ## Processing Kaggle Submission 170 | 171 | Follow these steps to process a Kaggle submission file. This will facilitate downstream human-machine comparisons. Here we use the "Claude-3.5 (Baseline)" approach from the [ARC Prize leaderboard](https://arcprize.org/leaderboard) as an example. 172 | 173 | 1. Create the necessary directories: 174 | 175 | ```bash 176 | mkdir -p data/kaggle_solutions/claude3_5-langchain 177 | ``` 178 | 179 | 2. Visit the following webpage: 180 | [Claude 3.5 Langchain ARC Submission](https://www.kaggle.com/code/gregkamradt/using-frontier-models-on-arc-agi-via-langchain/output) 181 | 182 | 3. Download the `submission.json` file from the webpage into the `data/kaggle_solutions/claude3_5-langchain` directory. 183 | 184 | 4. Run the `kaggle_submision_to_csv.py` script with the appropriate submission ID: 185 | ```bash 186 | python src/kaggle_submision_to_csv.py --submission_id claude3_5-langchain 187 | ``` 188 | 189 | This will process the JSON file and create a CSV file in the same directory with a similar format to our human data. 
190 | 191 | ## License 192 | 193 | This dataset is licensed under the [CC0 1.0 Universal](https://creativecommons.org/publicdomain/zero/1.0/) and can be used for any purposes. 194 | -------------------------------------------------------------------------------- /analysis/3-misc.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# imports\n", 10 | "import sys\n", 11 | "\n", 12 | "sys.path.append(\"..\")\n", 13 | "from src.utils import *\n", 14 | "import polars as pl\n", 15 | "from datetime import datetime as dt\n", 16 | "import seaborn as sns\n", 17 | "import numpy as np\n", 18 | "from matplotlib import pyplot as plt\n", 19 | "from scipy.stats import ttest_ind, permutation_test" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "data": { 29 | "text/plain": [ 30 | "" 31 | ] 32 | }, 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "output_type": "execute_result" 36 | } 37 | ], 38 | "source": [ 39 | "# polars config\n", 40 | "pl.Config.load_from_file(\"../polars_cfg.json\")" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 3, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "# data file paths\n", 50 | "data_path = \"/Users/solimlegris/Projets/h-arc-osf/data/data.csv\"\n", 51 | "summary_path = \"/Users/solimlegris/Projets/h-arc-osf/data/summary_data.csv\"" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 4, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "df_summary = pl.read_csv(summary_path)\n", 61 | "df_all = pl.read_csv(data_path)\n", 62 | "\n", 63 | "# parse time\n", 64 | "df_all = df_all.with_columns(pl.col(\"time\").cast(pl.Datetime))\n", 65 | "\n", 66 | "columns = [\n", 67 | " \"exp_name\",\n", 68 | " \"hashed_id\",\n", 69 | " 
\"joint_id_task\",\n", 70 | " \"task_name\",\n", 71 | " \"task_number\",\n", 72 | " \"task_type\",\n", 73 | " \"attempt_number\",\n", 74 | " \"action\",\n", 75 | " \"action_id\",\n", 76 | " \"solved\",\n", 77 | " \"time\",\n", 78 | " \"test_input_grid\",\n", 79 | " \"test_output_grid\",\n", 80 | "]\n", 81 | "df_all = df_all.select(columns)" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 5, 87 | "metadata": {}, 88 | "outputs": [ 89 | { 90 | "name": "stdout", 91 | "output_type": "stream", 92 | "text": [ 93 | "Number of participants given 5 tasks on training set: 542\n", 94 | "Number of participants given 10 tasks on training set: 241\n" 95 | ] 96 | } 97 | ], 98 | "source": [ 99 | "df_by_participant = df_summary.group_by(\"hashed_id\").agg(\n", 100 | " pl.max(\"task_number\").alias(\"tasks_completed\"),\n", 101 | " pl.min(\"task_number\").alias(\"first_task\"),\n", 102 | " pl.first(\"task_type\"),\n", 103 | " pl.first(\"exp_name\"),\n", 104 | ")\n", 105 | "\n", 106 | "# number of participants given 5 tasks on training set\n", 107 | "five = df_by_participant.filter(\n", 108 | " (pl.col(\"task_type\") == \"training\")\n", 109 | " & ~(pl.col(\"exp_name\").is_in([\"expv0\", \"expv1\"]))\n", 110 | ")\n", 111 | "ten = df_by_participant.filter(\n", 112 | " (pl.col(\"task_type\") == \"training\") & (pl.col(\"exp_name\").is_in([\"expv0\", \"expv1\"]))\n", 113 | ")\n", 114 | "print(\"Number of participants given 5 tasks on training set: \", len(five))\n", 115 | "print(\"Number of participants given 10 tasks on training set: \", len(ten))" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 6, 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "name": "stdout", 125 | "output_type": "stream", 126 | "text": [ 127 | "Number of participants not prevented from submitting the same grid after an incorrect attempt: 405/1729\n", 128 | "\n" 129 | ] 130 | } 131 | ], 132 | "source": [ 133 | "# get number of participants who were 
prevented from submitting the same task\n", 134 | "num_copy_allowed = len(\n", 135 | " df_all.group_by(\"hashed_id\")\n", 136 | " .agg(pl.max(\"time\").alias(\"last_time\"), pl.first(\"task_type\"))\n", 137 | " .filter((pl.col(\"last_time\") < dt(2023, 12, 1, 15, 43, 10)))\n", 138 | ")\n", 139 | "total = len(df_summary[\"hashed_id\"].unique())\n", 140 | "print(\n", 141 | " f\"Number of participants not prevented from submitting the same grid after an incorrect attempt: {num_copy_allowed}/{total}\\n\"\n", 142 | ")" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 7, 148 | "metadata": {}, 149 | "outputs": [ 150 | { 151 | "name": "stdout", 152 | "output_type": "stream", 153 | "text": [ 154 | "Proportion of incorrect submissions that were copies of previous submissions outputs: 0.06\n", 155 | "\n" 156 | ] 157 | } 158 | ], 159 | "source": [ 160 | "# rate of copied outputs\n", 161 | "test_outputs = (\n", 162 | " df_summary.filter(~pl.col(\"solved\"))\n", 163 | " .group_by(\"joint_id_task\")\n", 164 | " .agg(pl.col(\"test_output_grid\"), pl.len())\n", 165 | ")\n", 166 | "total = test_outputs.select(pl.sum(\"len\")).item()\n", 167 | "# apply set to remove duplicates\n", 168 | "test_outputs = test_outputs.with_columns(\n", 169 | " pl.col(\"test_output_grid\").list.n_unique().alias(\"unique_count\")\n", 170 | ")\n", 171 | "num_copied = (\n", 172 | " test_outputs.select((pl.col(\"len\") - pl.col(\"unique_count\")).alias(\"num_copies\"))\n", 173 | " .select(pl.sum(\"num_copies\"))\n", 174 | " .item()\n", 175 | ")\n", 176 | "print(\n", 177 | " f\"Proportion of incorrect submissions that were copies of previous submissions outputs: {num_copied/total:.2f}\\n\"\n", 178 | ")" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 8, 184 | "metadata": {}, 185 | "outputs": [ 186 | { 187 | "data": { 188 | "text/html": [ 189 | "
\n", 196 | "shape: (10, 3)
task_namesolvedcount
strf64u32
"b4a43f3b.json"0.010
"31d5ba1a.json"0.09
"a8610ef7.json"0.08
"8719f442.json"0.07
"79fb03f4.json"0.06
"fea12743.json"0.09090911
"e6de6e8f.json"0.110
"34b99a2b.json"0.1111119
"1c56ad9f.json"0.1111119
"0c9aba6e.json"0.1111119
" 197 | ], 198 | "text/plain": [ 199 | "shape: (10, 3)\n", 200 | "┌───────────────┬──────────┬───────┐\n", 201 | "│ task_name ┆ solved ┆ count │\n", 202 | "│ --- ┆ --- ┆ --- │\n", 203 | "│ str ┆ f64 ┆ u32 │\n", 204 | "╞═══════════════╪══════════╪═══════╡\n", 205 | "│ b4a43f3b.json ┆ 0.0 ┆ 10 │\n", 206 | "│ 31d5ba1a.json ┆ 0.0 ┆ 9 │\n", 207 | "│ a8610ef7.json ┆ 0.0 ┆ 8 │\n", 208 | "│ 8719f442.json ┆ 0.0 ┆ 7 │\n", 209 | "│ 79fb03f4.json ┆ 0.0 ┆ 6 │\n", 210 | "│ fea12743.json ┆ 0.090909 ┆ 11 │\n", 211 | "│ e6de6e8f.json ┆ 0.1 ┆ 10 │\n", 212 | "│ 34b99a2b.json ┆ 0.111111 ┆ 9 │\n", 213 | "│ 1c56ad9f.json ┆ 0.111111 ┆ 9 │\n", 214 | "│ 0c9aba6e.json ┆ 0.111111 ┆ 9 │\n", 215 | "└───────────────┴──────────┴───────┘" 216 | ] 217 | }, 218 | "execution_count": 8, 219 | "metadata": {}, 220 | "output_type": "execute_result" 221 | } 222 | ], 223 | "source": [ 224 | "# top ten hardest tasks ordered by number of participants\n", 225 | "df_summary.filter(pl.col(\"task_type\") == \"evaluation\").select(\n", 226 | " pl.col([\"joint_id_task\", \"task_name\", \"attempt_number\", \"solved\"])\n", 227 | " .top_k_by(\"attempt_number\", k=1)\n", 228 | " .over(\"joint_id_task\", mapping_strategy=\"explode\")\n", 229 | ").group_by(\"task_name\").agg(pl.sum(\"solved\") / pl.len(), pl.len().alias(\"count\")).sort(\n", 230 | " [\"solved\", \"count\"], descending=[False, True]\n", 231 | ").head(\n", 232 | " 10\n", 233 | ")" 234 | ] 235 | } 236 | ], 237 | "metadata": { 238 | "kernelspec": { 239 | "display_name": ".venv", 240 | "language": "python", 241 | "name": "python3" 242 | }, 243 | "language_info": { 244 | "codemirror_mode": { 245 | "name": "ipython", 246 | "version": 3 247 | }, 248 | "file_extension": ".py", 249 | "mimetype": "text/x-python", 250 | "name": "python", 251 | "nbconvert_exporter": "python", 252 | "pygments_lexer": "ipython3", 253 | "version": "3.12.5" 254 | } 255 | }, 256 | "nbformat": 4, 257 | "nbformat_minor": 2 258 | } 259 | 
-------------------------------------------------------------------------------- /arc_data/ARC_evaluation_tasks_ordered.json: -------------------------------------------------------------------------------- 1 | [ 2 | "f0afb749.json", 3 | "94414823.json", 4 | "dc2e9a9d.json", 5 | "f83cb3f6.json", 6 | "baf41dbf.json", 7 | "93b4f4b3.json", 8 | "ff72ca3e.json", 9 | "50f325b5.json", 10 | "da515329.json", 11 | "60a26a3e.json", 12 | "14754a24.json", 13 | "4ff4c9da.json", 14 | "f9d67f8b.json", 15 | "5ffb2104.json", 16 | "2037f2c7.json", 17 | "00dbd492.json", 18 | "9c1e755f.json", 19 | "6a11f6da.json", 20 | "e760a62e.json", 21 | "7bb29440.json", 22 | "19bb5feb.json", 23 | "6ad5bdfd.json", 24 | "891232d6.json", 25 | "292dd178.json", 26 | "67b4a34d.json", 27 | "94be5b80.json", 28 | "df8cc377.json", 29 | "ce8d95cc.json", 30 | "72a961c9.json", 31 | "6f473927.json", 32 | "18419cfa.json", 33 | "45bbe264.json", 34 | "7c8af763.json", 35 | "f8be4b64.json", 36 | "e7dd8335.json", 37 | "103eff5b.json", 38 | "a57f2f04.json", 39 | "52fd389e.json", 40 | "7d1f7ee8.json", 41 | "95a58926.json", 42 | "8dae5dfc.json", 43 | "2753e76c.json", 44 | "c6e1b8da.json", 45 | "516b51b7.json", 46 | "351d6448.json", 47 | "c48954c1.json", 48 | "dc2aa30b.json", 49 | "712bf12e.json", 50 | "cb227835.json", 51 | "cd3c21df.json", 52 | "20981f0e.json", 53 | "03560426.json", 54 | "ca8de6ea.json", 55 | "e2092e0c.json", 56 | "195ba7dc.json", 57 | "fc754716.json", 58 | "09c534e7.json", 59 | "ac0c5833.json", 60 | "27a77e38.json", 61 | "7e02026e.json", 62 | "a680ac02.json", 63 | "ac605cbb.json", 64 | "5b6cbef5.json", 65 | "17b80ad2.json", 66 | "4acc7107.json", 67 | "67c52801.json", 68 | "ce039d91.json", 69 | "506d28a5.json", 70 | "5a5a2103.json", 71 | "0c9aba6e.json", 72 | "55783887.json", 73 | "ecaa0ec1.json", 74 | "929ab4e9.json", 75 | "ae58858e.json", 76 | "c658a4bd.json", 77 | "477d2879.json", 78 | "281123b4.json", 79 | "12422b43.json", 80 | "47996f11.json", 81 | "73c3b0d8.json", 82 | "137f0df0.json", 83 | 
"94133066.json", 84 | "ed98d772.json", 85 | "fea12743.json", 86 | "e69241bd.json", 87 | "64a7c07e.json", 88 | "7d419a02.json", 89 | "9772c176.json", 90 | "b457fec5.json", 91 | "310f3251.json", 92 | "c92b942c.json", 93 | "140c817e.json", 94 | "b7999b51.json", 95 | "ac3e2b04.json", 96 | "3d31c5b3.json", 97 | "2546ccf6.json", 98 | "626c0bcc.json", 99 | "de493100.json", 100 | "90347967.json", 101 | "88207623.json", 102 | "45737921.json", 103 | "fb791726.json", 104 | "c3202e5a.json", 105 | "642d658d.json", 106 | "456873bc.json", 107 | "782b5218.json", 108 | "9b365c51.json", 109 | "b9630600.json", 110 | "c7d4e6ad.json", 111 | "c35c1b4c.json", 112 | "60c09cac.json", 113 | "d19f7514.json", 114 | "8ba14f53.json", 115 | "0c786b71.json", 116 | "a04b2602.json", 117 | "e6de6e8f.json", 118 | "7039b2d7.json", 119 | "7d18a6fb.json", 120 | "4c177718.json", 121 | "c97c0139.json", 122 | "1e81d6f9.json", 123 | "4364c1c4.json", 124 | "72207abc.json", 125 | "e4075551.json", 126 | "31d5ba1a.json", 127 | "896d5239.json", 128 | "4e45f183.json", 129 | "009d5c81.json", 130 | "a406ac07.json", 131 | "5af49b42.json", 132 | "b942fd60.json", 133 | "11e1fe23.json", 134 | "b7cb93ac.json", 135 | "cfb2ce5a.json", 136 | "62b74c02.json", 137 | "7953d61e.json", 138 | "c663677b.json", 139 | "96a8c0cd.json", 140 | "a8610ef7.json", 141 | "0a1d4ef5.json", 142 | "69889d6e.json", 143 | "a934301b.json", 144 | "97239e3d.json", 145 | "4f537728.json", 146 | "a096bf4d.json", 147 | "575b1a71.json", 148 | "13713586.json", 149 | "8719f442.json", 150 | "40f6cd08.json", 151 | "12eac192.json", 152 | "770cc55f.json", 153 | "bc4146bd.json", 154 | "0b17323b.json", 155 | "ca8f78db.json", 156 | "e9bb6954.json", 157 | "639f5a19.json", 158 | "85b81ff1.json", 159 | "551d5bf1.json", 160 | "55059096.json", 161 | "5783df64.json", 162 | "3a301edc.json", 163 | "22a4bbc2.json", 164 | "4aab4007.json", 165 | "f9a67cb5.json", 166 | "f823c43c.json", 167 | "642248e4.json", 168 | "705a3229.json", 169 | "ad7e01d0.json", 170 | 
"73182012.json", 171 | "e99362f0.json", 172 | "c64f1187.json", 173 | "4e469f39.json", 174 | "e5c44e8f.json", 175 | "ccd554ac.json", 176 | "7ee1c6ea.json", 177 | "e5790162.json", 178 | "29700607.json", 179 | "9ddd00f0.json", 180 | "3194b014.json", 181 | "aa18de87.json", 182 | "af24b4cc.json", 183 | "e1baa8a4.json", 184 | "414297c0.json", 185 | "e133d23d.json", 186 | "1d398264.json", 187 | "e88171ec.json", 188 | "0e671a1a.json", 189 | "8e2edd66.json", 190 | "15696249.json", 191 | "e7b06bea.json", 192 | "48f8583b.json", 193 | "7c9b52a0.json", 194 | "3391f8c0.json", 195 | "f5c89df1.json", 196 | "42918530.json", 197 | "c074846d.json", 198 | "5207a7b5.json", 199 | "bf32578f.json", 200 | "8b28cd80.json", 201 | "fe9372f3.json", 202 | "a59b95c0.json", 203 | "93c31fbe.json", 204 | "1c56ad9f.json", 205 | "bf89d739.json", 206 | "e78887d1.json", 207 | "bd14c3bf.json", 208 | "c87289bb.json", 209 | "2a5f8217.json", 210 | "f21745ec.json", 211 | "59341089.json", 212 | "833dafe3.json", 213 | "505fff84.json", 214 | "79369cc6.json", 215 | "af22c60d.json", 216 | "aab50785.json", 217 | "b4a43f3b.json", 218 | "b0722778.json", 219 | "85fa5666.json", 220 | "fd4b2b02.json", 221 | "b1fc8b8e.json", 222 | "d56f2372.json", 223 | "1a2e2828.json", 224 | "358ba94e.json", 225 | "b20f7c8b.json", 226 | "8ee62060.json", 227 | "bbb1b8b6.json", 228 | "9b2a60aa.json", 229 | "25094a63.json", 230 | "d5c634a2.json", 231 | "0692e18c.json", 232 | "d304284e.json", 233 | "0f63c0b9.json", 234 | "9def23fe.json", 235 | "9b4c17c4.json", 236 | "27f8ce4f.json", 237 | "05a7bcf2.json", 238 | "42a15761.json", 239 | "c62e2108.json", 240 | "817e6c09.json", 241 | "ba9d41b8.json", 242 | "ea9794b1.json", 243 | "8cb8642d.json", 244 | "845d6e51.json", 245 | "e345f17b.json", 246 | "e95e3d8e.json", 247 | "9110e3c5.json", 248 | "e9b4f6fc.json", 249 | "d2acf2cb.json", 250 | "0934a4d8.json", 251 | "e9c9d9a1.json", 252 | "070dd51e.json", 253 | "762cd429.json", 254 | "da2b0fe3.json", 255 | "5289ad53.json", 256 | "e21a174a.json", 257 
| "79fb03f4.json", 258 | "c1990cce.json", 259 | "20818e16.json", 260 | "bcb3040b.json", 261 | "2685904e.json", 262 | "3490cc26.json", 263 | "58743b76.json", 264 | "15113be4.json", 265 | "d017b73f.json", 266 | "cad67732.json", 267 | "12997ef3.json", 268 | "fd096ab6.json", 269 | "5b692c0f.json", 270 | "3f23242b.json", 271 | "992798f6.json", 272 | "1d0a4b61.json", 273 | "aa300dc3.json", 274 | "e74e1818.json", 275 | "4b6b68e5.json", 276 | "b15fca0b.json", 277 | "f5aa3634.json", 278 | "3b4c2228.json", 279 | "aa4ec2a5.json", 280 | "2b01abd0.json", 281 | "21f83797.json", 282 | "1acc24af.json", 283 | "15663ba9.json", 284 | "f3b10344.json", 285 | "6ea4a07e.json", 286 | "0bb8deee.json", 287 | "54db823b.json", 288 | "ef26cbf6.json", 289 | "f3cdc58f.json", 290 | "423a55dc.json", 291 | "2697da3f.json", 292 | "08573cc6.json", 293 | "0a2355a6.json", 294 | "256b0a75.json", 295 | "50aad11f.json", 296 | "f45f5ca7.json", 297 | "e66aafb8.json", 298 | "1da012fc.json", 299 | "1e97544e.json", 300 | "d931c21c.json", 301 | "68b67ca3.json", 302 | "58e15b12.json", 303 | "e7a25a18.json", 304 | "b0f4d537.json", 305 | "332efdb3.json", 306 | "16b78196.json", 307 | "9c56f360.json", 308 | "4cd1b7b2.json", 309 | "0607ce86.json", 310 | "5b526a93.json", 311 | "136b0064.json", 312 | "92e50de0.json", 313 | "81c0276b.json", 314 | "3979b1a8.json", 315 | "d37a1ef5.json", 316 | "bb52a14b.json", 317 | "9bebae7a.json", 318 | "66e6c45b.json", 319 | "604001fa.json", 320 | "981571dc.json", 321 | "0becf7df.json", 322 | "9356391f.json", 323 | "695367ec.json", 324 | "50a16a69.json", 325 | "ac2e8ecf.json", 326 | "a3f84088.json", 327 | "212895b5.json", 328 | "ea959feb.json", 329 | "62ab2642.json", 330 | "319f2597.json", 331 | "0d87d2a6.json", 332 | "dd2401ed.json", 333 | "c8b7cc0f.json", 334 | "5d2a5c43.json", 335 | "4852f2fa.json", 336 | "17cae0c1.json", 337 | "696d4842.json", 338 | "3ed85e70.json", 339 | "692cd3b6.json", 340 | "d47aa2ff.json", 341 | "e619ca6e.json", 342 | "1c02dbbe.json", 343 | "37d3e8b2.json", 
344 | "b7fb29bc.json", 345 | "48131b3c.json", 346 | "2c737e39.json", 347 | "f4081712.json", 348 | "67636eac.json", 349 | "e1d2900e.json", 350 | "2c0b0aff.json", 351 | "f0df5ff0.json", 352 | "d492a647.json", 353 | "d94c3b52.json", 354 | "e9ac8c9e.json", 355 | "e0fb7511.json", 356 | "2072aba6.json", 357 | "99306f82.json", 358 | "6df30ad6.json", 359 | "ed74f2f2.json", 360 | "1a6449f1.json", 361 | "e872b94a.json", 362 | "e41c6fd3.json", 363 | "31adaf00.json", 364 | "73ccf9c2.json", 365 | "903d1b4a.json", 366 | "1990f7a8.json", 367 | "8597cfd7.json", 368 | "3ee1011a.json", 369 | "917bccba.json", 370 | "9f27f097.json", 371 | "8a371977.json", 372 | "32e9702f.json", 373 | "9caba7c3.json", 374 | "e633a9e5.json", 375 | "e681b708.json", 376 | "184a9768.json", 377 | "1c0d0a4b.json", 378 | "84f2aca1.json", 379 | "00576224.json", 380 | "84db8fc4.json", 381 | "2f0c5170.json", 382 | "d4c90558.json", 383 | "33b52de3.json", 384 | "be03b35f.json", 385 | "b7f8a4d8.json", 386 | "8fbca751.json", 387 | "cf133acc.json", 388 | "aee291af.json", 389 | "fafd9572.json", 390 | "963f59bc.json", 391 | "bf699163.json", 392 | "759f3fd3.json", 393 | "d282b262.json", 394 | "5833af48.json", 395 | "34b99a2b.json", 396 | "f3e62deb.json", 397 | "9a4bb226.json", 398 | "e7639916.json", 399 | "66f2d22f.json", 400 | "d4b1c2b1.json", 401 | "e57337a4.json" 402 | ] -------------------------------------------------------------------------------- /analysis/8-action-traces.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# imports\n", 10 | "import polars as pl\n", 11 | "import seaborn as sns\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "from matplotlib.lines import Line2D\n", 14 | "from datetime import timezone\n", 15 | "from datetime import datetime as dt\n", 16 | "import numpy as np\n", 17 | "import sys\n", 18 | "from scipy.stats 
import pearsonr\n", 19 | "from scipy import stats\n", 20 | "\n", 21 | "sys.path.append(\"../\")\n", 22 | "from src.utils import *" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [ 30 | { 31 | "data": { 32 | "text/plain": [ 33 | "" 34 | ] 35 | }, 36 | "execution_count": 2, 37 | "metadata": {}, 38 | "output_type": "execute_result" 39 | } 40 | ], 41 | "source": [ 42 | "# polars config\n", 43 | "pl.Config.load_from_file(\"../polars_cfg.json\")" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "# 1. Data loading\n", 51 | "\n", 52 | "Choose whether to include participants that didn't complete all five tasks in the analyses with the True / False toggle.\n" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 3, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "# data file paths\n", 62 | "clean_data_path = \"/Users/solimlegris/Projets/h-arc-osf/data/data.csv\"" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 4, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "df_all = pl.read_csv(clean_data_path)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 5, 77 | "metadata": {}, 78 | "outputs": [ 79 | { 80 | "data": { 81 | "text/html": [ 82 | "
\n", 89 | "shape: (5, 40)
exp_nametask_typehashed_idjoint_id_tasktask_nametask_numberis_tutorialtimeattempt_numberaction_idsolveddonetest_input_gridtest_input_size_xtest_input_size_ytest_output_gridtest_output_size_xtest_output_size_yactionaction_xaction_yselect_locselected_dataselected_symbolselected_toolcopy_paste_datafirst_written_solutionlast_written_solutionwithdrawwithdraw_reasonwithdraw_commentagegenderraceeducation_levelhousehold_incomenormal_visioncolor_blindfluent_englishcomplete
strstrstrstrstri64boolstri64i64boolboolstri64i64stri64i64strf64f64strstri64strstrstrstrboolstrstrstrstrstrstrstrstrstrstrbool
"expv2""training""59e36641c3391529505bdf7af902cec0""59e36641c3391529505bdf7af902cec0_32597951.jso…"32597951.json"1false"2023-08-03T04:22:11.000000"11falsetrue"|000|000|000|"33"|000|000|000|"33"reset_grid"nullnullnull"[]"0"edit""[]""YES DONE EASILY""INTERESTING TO COMPLETE"falsenullnullnullnullnullnullnullnullnullnulltrue
"expv2""training""59e36641c3391529505bdf7af902cec0""59e36641c3391529505bdf7af902cec0_32597951.jso…"32597951.json"1false"2023-08-03T04:22:34.000000"12falsetrue"|10010010010010010|01001001001001001|00100100…1717"|000|000|000|"33"change_color"nullnullnull"[]"1"edit""[]""YES DONE EASILY""INTERESTING TO COMPLETE"falsenullnullnullnullnullnullnullnullnullnulltrue
"expv2""training""59e36641c3391529505bdf7af902cec0""59e36641c3391529505bdf7af902cec0_32597951.jso…"32597951.json"1false"2023-08-03T04:22:46.000000"13falsetrue"|10010010010010010|01001001001001001|00100100…1717"|000|000|000|000|000|000|000|000|000|000|000|…173"change_height"nullnullnull"[]"1"edit""[]""YES DONE EASILY""INTERESTING TO COMPLETE"falsenullnullnullnullnullnullnullnullnullnulltrue
"expv2""training""59e36641c3391529505bdf7af902cec0""59e36641c3391529505bdf7af902cec0_32597951.jso…"32597951.json"1false"2023-08-03T04:23:01.000000"14falsetrue"|10010010010010010|01001001001001001|00100100…1717"|00000000000000000|00000000000000000|00000000…1717"change_width"nullnullnull"[]"1"edit""[]""YES DONE EASILY""INTERESTING TO COMPLETE"falsenullnullnullnullnullnullnullnullnullnulltrue
"expv2""training""59e36641c3391529505bdf7af902cec0""59e36641c3391529505bdf7af902cec0_32597951.jso…"32597951.json"1false"2023-08-03T04:23:41.000000"15falsetrue"|10010010010010010|01001001001001001|00100100…1717"|00000000000000000|00000000000000000|00000000…1717"change_color"nullnullnull"[]"1"edit""[]""YES DONE EASILY""INTERESTING TO COMPLETE"falsenullnullnullnullnullnullnullnullnullnulltrue
" 90 | ], 91 | "text/plain": [ 92 | "shape: (5, 40)\n", 93 | "┌──────────┬───────────┬────────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐\n", 94 | "│ exp_name ┆ task_type ┆ hashed_id ┆ joint_id_ ┆ … ┆ normal_vi ┆ color_bli ┆ fluent_en ┆ complete │\n", 95 | "│ --- ┆ --- ┆ --- ┆ task ┆ ┆ sion ┆ nd ┆ glish ┆ --- │\n", 96 | "│ str ┆ str ┆ str ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ bool │\n", 97 | "│ ┆ ┆ ┆ str ┆ ┆ str ┆ str ┆ str ┆ │\n", 98 | "╞══════════╪═══════════╪════════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡\n", 99 | "│ expv2 ┆ training ┆ 59e36641c3 ┆ 59e36641c ┆ … ┆ null ┆ null ┆ null ┆ true │\n", 100 | "│ ┆ ┆ 391529505b ┆ 339152950 ┆ ┆ ┆ ┆ ┆ │\n", 101 | "│ ┆ ┆ df7af902ce ┆ 5bdf7af90 ┆ ┆ ┆ ┆ ┆ │\n", 102 | "│ ┆ ┆ c0 ┆ 2cec0_325 ┆ ┆ ┆ ┆ ┆ │\n", 103 | "│ ┆ ┆ ┆ 97951.jso ┆ ┆ ┆ ┆ ┆ │\n", 104 | "│ ┆ ┆ ┆ … ┆ ┆ ┆ ┆ ┆ │\n", 105 | "│ expv2 ┆ training ┆ 59e36641c3 ┆ 59e36641c ┆ … ┆ null ┆ null ┆ null ┆ true │\n", 106 | "│ ┆ ┆ 391529505b ┆ 339152950 ┆ ┆ ┆ ┆ ┆ │\n", 107 | "│ ┆ ┆ df7af902ce ┆ 5bdf7af90 ┆ ┆ ┆ ┆ ┆ │\n", 108 | "│ ┆ ┆ c0 ┆ 2cec0_325 ┆ ┆ ┆ ┆ ┆ │\n", 109 | "│ ┆ ┆ ┆ 97951.jso ┆ ┆ ┆ ┆ ┆ │\n", 110 | "│ ┆ ┆ ┆ … ┆ ┆ ┆ ┆ ┆ │\n", 111 | "│ expv2 ┆ training ┆ 59e36641c3 ┆ 59e36641c ┆ … ┆ null ┆ null ┆ null ┆ true │\n", 112 | "│ ┆ ┆ 391529505b ┆ 339152950 ┆ ┆ ┆ ┆ ┆ │\n", 113 | "│ ┆ ┆ df7af902ce ┆ 5bdf7af90 ┆ ┆ ┆ ┆ ┆ │\n", 114 | "│ ┆ ┆ c0 ┆ 2cec0_325 ┆ ┆ ┆ ┆ ┆ │\n", 115 | "│ ┆ ┆ ┆ 97951.jso ┆ ┆ ┆ ┆ ┆ │\n", 116 | "│ ┆ ┆ ┆ … ┆ ┆ ┆ ┆ ┆ │\n", 117 | "│ expv2 ┆ training ┆ 59e36641c3 ┆ 59e36641c ┆ … ┆ null ┆ null ┆ null ┆ true │\n", 118 | "│ ┆ ┆ 391529505b ┆ 339152950 ┆ ┆ ┆ ┆ ┆ │\n", 119 | "│ ┆ ┆ df7af902ce ┆ 5bdf7af90 ┆ ┆ ┆ ┆ ┆ │\n", 120 | "│ ┆ ┆ c0 ┆ 2cec0_325 ┆ ┆ ┆ ┆ ┆ │\n", 121 | "│ ┆ ┆ ┆ 97951.jso ┆ ┆ ┆ ┆ ┆ │\n", 122 | "│ ┆ ┆ ┆ … ┆ ┆ ┆ ┆ ┆ │\n", 123 | "│ expv2 ┆ training ┆ 59e36641c3 ┆ 59e36641c ┆ … ┆ null ┆ null ┆ null ┆ true │\n", 124 | "│ ┆ ┆ 391529505b ┆ 339152950 ┆ ┆ ┆ ┆ ┆ │\n", 125 | "│ ┆ ┆ df7af902ce ┆ 5bdf7af90 ┆ ┆ ┆ ┆ ┆ │\n", 
126 | "│ ┆ ┆ c0 ┆ 2cec0_325 ┆ ┆ ┆ ┆ ┆ │\n", 127 | "│ ┆ ┆ ┆ 97951.jso ┆ ┆ ┆ ┆ ┆ │\n", 128 | "│ ┆ ┆ ┆ … ┆ ┆ ┆ ┆ ┆ │\n", 129 | "└──────────┴───────────┴────────────┴───────────┴───┴───────────┴───────────┴───────────┴──────────┘" 130 | ] 131 | }, 132 | "execution_count": 5, 133 | "metadata": {}, 134 | "output_type": "execute_result" 135 | } 136 | ], 137 | "source": [ 138 | "df_all.head()" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 9, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "sorted_df = df_all.sort(\n", 148 | " by=[\"joint_id_task\", \"task_number\", \"attempt_number\", \"action_id\"]\n", 149 | ")\n", 150 | "aggregated_df = sorted_df.group_by(\n", 151 | " \"joint_id_task\", \"attempt_number\", maintain_order=True\n", 152 | ").agg(pl.col(\"test_input_grid\").alias(\"grid_states\"))" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 10, 158 | "metadata": {}, 159 | "outputs": [ 160 | { 161 | "data": { 162 | "text/html": [ 163 | "
\n", 170 | "shape: (5, 3)
joint_id_taskattempt_numbergrid_states
stri64list[str]
"00136a0f6142ddb48fd80ea8ff22f12d_27a77e38.jso…1["|000|000|000|", "|912849821|443127679|216978436|986342917|555555555|000000000|000000000|000000000|000000000|", … "|912849821|443127679|216978436|986342917|555555555|000000000|000000000|000000000|000000000|"]
"00136a0f6142ddb48fd80ea8ff22f12d_27a77e38.jso…2["|912849821|443127679|216978436|986342917|555555555|000000000|000000000|000000000|000000000|", "|912849821|443127679|216978436|986342917|555555555|000000000|000000000|000000000|000000000|", … "|912849821|443127679|216978436|986342917|555555555|000000000|000000000|000000000|000000000|"]
"00136a0f6142ddb48fd80ea8ff22f12d_27a77e38.jso…3["|912849821|443127679|216978436|986342917|555555555|000000000|000000000|000000000|000000000|", "|912849821|443127679|216978436|986342917|555555555|000000000|000000000|000000000|000000000|", … "|912849821|443127679|216978436|986342917|555555555|000000000|000000000|000000000|000000000|"]
"00136a0f6142ddb48fd80ea8ff22f12d_7e02026e.jso…1["|000|000|000|", "|808888808080|088800808000|888800088888|800080880080|088808088088|008880000000|808808808000|080800888888|000880080800|000080880880|000880880888|888080000888|", … "|808888808080|088800808000|888800088888|800080880080|088808088088|008880000000|808808808000|080800888888|000880080800|000080880880|000880880888|888080000888|"]
"00136a0f6142ddb48fd80ea8ff22f12d_7e02026e.jso…2["|808888808080|088800808000|888800088888|800080880080|088808088088|008880000000|808808808000|080800888888|000880080800|000080880880|000880880888|888080000888|", "|808888808080|088800808000|888800088888|800080880080|088808088088|008880000000|808808808000|080800888888|000880080800|000080880880|000880880888|888080000888|", "|808888808080|088800808000|888800088888|800080880080|088808088088|008880000000|808808808000|080800888888|000880080800|000080880880|000880880888|888080000888|"]
" 171 | ], 172 | "text/plain": [ 173 | "shape: (5, 3)\n", 174 | "┌────────────────────────────────────────┬────────────────┬────────────────────────────────────────┐\n", 175 | "│ joint_id_task ┆ attempt_number ┆ grid_states │\n", 176 | "│ --- ┆ --- ┆ --- │\n", 177 | "│ str ┆ i64 ┆ list[str] │\n", 178 | "╞════════════════════════════════════════╪════════════════╪════════════════════════════════════════╡\n", 179 | "│ 00136a0f6142ddb48fd80ea8ff22f12d_27a77 ┆ 1 ┆ [\"|000|000|000|\", │\n", 180 | "│ e38.jso… ┆ ┆ \"|912849821|443127679|21697… │\n", 181 | "│ 00136a0f6142ddb48fd80ea8ff22f12d_27a77 ┆ 2 ┆ [\"|912849821|443127679|216978436|98634 │\n", 182 | "│ e38.jso… ┆ ┆ 2917|55… │\n", 183 | "│ 00136a0f6142ddb48fd80ea8ff22f12d_27a77 ┆ 3 ┆ [\"|912849821|443127679|216978436|98634 │\n", 184 | "│ e38.jso… ┆ ┆ 2917|55… │\n", 185 | "│ 00136a0f6142ddb48fd80ea8ff22f12d_7e020 ┆ 1 ┆ [\"|000|000|000|\", │\n", 186 | "│ 26e.jso… ┆ ┆ \"|808888808080|088800808000… │\n", 187 | "│ 00136a0f6142ddb48fd80ea8ff22f12d_7e020 ┆ 2 ┆ [\"|808888808080|088800808000|888800088 │\n", 188 | "│ 26e.jso… ┆ ┆ 888|800… │\n", 189 | "└────────────────────────────────────────┴────────────────┴────────────────────────────────────────┘" 190 | ] 191 | }, 192 | "execution_count": 10, 193 | "metadata": {}, 194 | "output_type": "execute_result" 195 | } 196 | ], 197 | "source": [ 198 | "aggregated_df.head()" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 11, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "output_json_path = \"../data/grid_states.json\"\n", 208 | "\n", 209 | "# Write the aggregated DataFrame to a JSON file\n", 210 | "aggregated_df.write_json(output_json_path)" 211 | ] 212 | } 213 | ], 214 | "metadata": { 215 | "kernelspec": { 216 | "display_name": ".venv", 217 | "language": "python", 218 | "name": "python3" 219 | }, 220 | "language_info": { 221 | "codemirror_mode": { 222 | "name": "ipython", 223 | "version": 3 224 | }, 225 | "file_extension": ".py", 226 | 
"mimetype": "text/x-python", 227 | "name": "python", 228 | "nbconvert_exporter": "python", 229 | "pygments_lexer": "ipython3", 230 | "version": "3.12.9" 231 | } 232 | }, 233 | "nbformat": 4, 234 | "nbformat_minor": 2 235 | } 236 | -------------------------------------------------------------------------------- /src/IRT_model_plots_analysis.py: -------------------------------------------------------------------------------- 1 | import argparse as ap 2 | import os 3 | import cloudpickle as cpkl 4 | import arviz as az 5 | import polars as pl 6 | import matplotlib.pyplot as plt 7 | from matplotlib.patches import Patch 8 | from pathlib import Path 9 | import numpy as np 10 | import seaborn as sns 11 | 12 | basepath = Path(__file__).parent.parent 13 | 14 | LABEL_SIZE = 20 15 | TITLE_SIZE = 24 16 | TICK_SIZE = 18 17 | LEGEND_SIZE = 18 18 | 19 | plt.rcParams["text.usetex"] = True 20 | 21 | 22 | def get_args(): 23 | parser = ap.ArgumentParser(description="Generate IRT model plots and analysis") 24 | parser.add_argument("--model_path", type=str, help="path to model") 25 | parser.add_argument("--verbose", action="store_true", help="print verbose output") 26 | return parser.parse_args() 27 | 28 | 29 | def load_model(model_path): 30 | with open(os.path.join(basepath, model_path), "rb") as f: 31 | model, trace = cpkl.load(f) 32 | return model, trace 33 | 34 | 35 | def plot_trace(trace, model_name): 36 | ax = az.plot_trace( 37 | trace, 38 | var_names=[ 39 | "sigma_alpha", 40 | "sigma_beta", 41 | "epsilon_one", 42 | "epsilon_delta", 43 | "epsilon_two", 44 | ], 45 | ) 46 | fig = ax.ravel()[0].figure 47 | plt.tight_layout() 48 | fig.savefig( 49 | os.path.join(basepath, "figures", f"{model_name}_trace_plot.png"), 50 | bbox_inches="tight", 51 | dpi=300, 52 | ) 53 | plt.close() 54 | 55 | 56 | def get_participant_success_rate(): 57 | # load data 58 | df_summary = pl.read_csv(os.path.join(basepath, "data", "clean_summary_data.csv")) 59 | df_summary_incomplete = pl.read_csv( 60 | 
os.path.join(basepath, "data", "clean_summary_data_incomplete.csv") 61 | ) 62 | df_summary = pl.concat([df_summary, df_summary_incomplete]) 63 | 64 | # filter data 65 | participants_max_tasks = df_summary.group_by("hashed_id").agg( 66 | pl.max("task_number"), pl.first("task_type") 67 | ) 68 | participants_max_tasks = participants_max_tasks.rename( 69 | {"task_number": "max_task_number"} 70 | ) 71 | df_summary_ = df_summary.join(participants_max_tasks, on="hashed_id") 72 | participant_success = df_summary_.group_by("hashed_id").agg( 73 | pl.sum("solved") / pl.max("max_task_number"), pl.first("complete") 74 | ) 75 | return participant_success 76 | 77 | 78 | def generate_irt_dataframes(trace): 79 | # Load task success rate 80 | mean_task_acc_three_shot = pl.read_csv( 81 | os.path.join(basepath, "data", "mean_task_acc_three_attempts.csv") 82 | ) 83 | # Extract necessary data from the trace 84 | alpha_mean = trace.posterior["alpha"].mean(dim=["chain", "draw"]) 85 | alpha_hdi = az.hdi(trace.posterior["alpha"], hdi_prob=0.94) 86 | participants = trace.posterior["alpha"].coords["participants"].values 87 | participant_success = get_participant_success_rate() 88 | participant_success_values = ( 89 | participant_success.select("hashed_id", "solved") 90 | .to_pandas() 91 | .set_index("hashed_id") 92 | .loc[participants] 93 | .values.flatten() 94 | ) 95 | complete = ( 96 | participant_success.select("hashed_id", "complete") 97 | .to_pandas() 98 | .set_index("hashed_id") 99 | .loc[participants] 100 | .values.flatten() 101 | ) 102 | 103 | beta_mean = trace.posterior["beta"].mean(dim=["chain", "draw"]) 104 | beta_hdi = az.hdi(trace.posterior["beta"], hdi_prob=0.94) 105 | tasks = trace.posterior["beta"].coords["tasks"].values 106 | 107 | # order success rate by tasks 108 | task_success_values = ( 109 | mean_task_acc_three_shot.select("task_name", "mean_solved") 110 | .to_pandas() 111 | .set_index("task_name") 112 | .loc[tasks] 113 | .values.flatten() 114 | ) 115 | 116 | 
task_type_values = ( 117 | mean_task_acc_three_shot.select("task_name", "task_type") 118 | .to_pandas() 119 | .set_index("task_name") 120 | .loc[tasks] 121 | .values.flatten() 122 | ) 123 | epsilon_one_mean = trace.posterior["epsilon_one"].mean(dim=["chain", "draw"]) 124 | epsilon_two_mean = trace.posterior["epsilon_two"].mean(dim=["chain", "draw"]) 125 | epsilon_one_hdi = az.hdi(trace.posterior["epsilon_one"], hdi_prob=0.94) 126 | epsilon_two_hdi = az.hdi(trace.posterior["epsilon_two"], hdi_prob=0.94) 127 | 128 | epsilon_one = trace.posterior["epsilon_one"].values.flatten() 129 | epsilon_two = trace.posterior["epsilon_two"].values.flatten() 130 | 131 | # Create DataFrame for participant abilities 132 | plot_df_ability = pl.DataFrame( 133 | { 134 | "hashed_id": participants, 135 | "ability_mean": alpha_mean.values, 136 | "success_rate": participant_success_values, 137 | "ability_hdi_lower": alpha_hdi.sel(hdi="lower").alpha.values, 138 | "ability_hdi_upper": alpha_hdi.sel(hdi="higher").alpha.values, 139 | "complete": complete, 140 | } 141 | ) 142 | 143 | # Create DataFrame for task difficulties 144 | plot_df_difficulty = pl.DataFrame( 145 | { 146 | "tasks": tasks, 147 | "task_type": task_type_values, 148 | "success_rate": task_success_values, 149 | "diff_mean": beta_mean.values, 150 | "diff_hdi_lower": beta_hdi.sel(hdi="lower").beta.values, 151 | "diff_hdi_upper": beta_hdi.sel(hdi="higher").beta.values, 152 | } 153 | ) 154 | 155 | # Create DataFrame for epsilon values 156 | plot_df_epsilon = pl.DataFrame( 157 | { 158 | "epsilon": list(epsilon_one) + list(epsilon_two), 159 | "attempt": [r"$\epsilon_1$"] * len(epsilon_one) 160 | + [r"$\epsilon_2$"] * len(epsilon_two), 161 | } 162 | ) 163 | 164 | # create alternative dataframe for epsilon 165 | plot_df_epsilon_alt = pl.DataFrame( 166 | { 167 | "attempt": [1, 2], 168 | "epsilon_mean": [ 169 | epsilon_one_mean.values.item(), 170 | epsilon_two_mean.values.item(), 171 | ], 172 | "epsilon_hdi_lower": [ 173 | 
epsilon_one_hdi.sel(hdi="lower").epsilon_one.values.item(), 174 | epsilon_two_hdi.sel(hdi="lower").epsilon_two.values.item(), 175 | ], 176 | "epsilon_hdi_upper": [ 177 | epsilon_one_hdi.sel(hdi="higher").epsilon_one.values.item(), 178 | epsilon_two_hdi.sel(hdi="higher").epsilon_two.values.item(), 179 | ], 180 | } 181 | ) 182 | 183 | return plot_df_ability, plot_df_difficulty, plot_df_epsilon, plot_df_epsilon_alt 184 | 185 | 186 | def plot_irt_parameters(df_epsilon, plot_df_ability, plot_df_difficulty, model_name): 187 | # Create figure and gridspec for the main layout 188 | fig = plt.figure(figsize=(20, 12)) 189 | gs_main = plt.GridSpec( 190 | 2, 2, width_ratios=[3, 5], height_ratios=[1, 1], hspace=0.25, wspace=0.175 191 | ) 192 | 193 | eps_palette = sns.color_palette("Paired", 2) 194 | others_palette = sns.color_palette("tab10") 195 | 196 | # Create subgridspecs for the right plots 197 | gs_ability = gs_main[0, 1].subgridspec(1, 2, width_ratios=[4, 1], wspace=0.01) 198 | gs_diff = gs_main[1, 1].subgridspec(1, 2, width_ratios=[4, 1], wspace=0.01) 199 | 200 | # Left plot (epsilon distributions) 201 | ax_eps = fig.add_subplot(gs_main[0, 0]) 202 | 203 | sns.histplot( 204 | df_epsilon, 205 | x="epsilon", 206 | hue="attempt", 207 | bins=50, 208 | alpha=0.8, 209 | palette=eps_palette, 210 | stat="density", 211 | kde=True, 212 | ax=ax_eps, 213 | legend=True, 214 | ) 215 | ax_eps.set_title("(a) Feedback effect", fontsize=TITLE_SIZE) 216 | ax_eps.set_xlabel("Parameter estimates", fontsize=LABEL_SIZE) 217 | ax_eps.set_ylabel("Count", fontsize=LABEL_SIZE) 218 | ax_eps.set_xticklabels(ax_eps.get_xticks(), fontsize=TICK_SIZE) 219 | ax_eps.set_yticklabels(ax_eps.get_yticks(), fontsize=TICK_SIZE) 220 | handles = ax_eps.get_legend().legend_handles 221 | labels = [r"$\gamma_1$", r"$\gamma_2$"] 222 | ax_eps.legend(handles, labels, fontsize=LEGEND_SIZE) 223 | ax_eps.grid(True, alpha=0.1) 224 | 225 | # Top right plots (ability) 226 | ax_ability_main = fig.add_subplot(gs_ability[0, 0]) 
227 | ax_ability_kde = fig.add_subplot(gs_ability[0, 1]) 228 | 229 | # Jittered success rates for scatter plots 230 | jittered_success_rate_ability = plot_df_ability.select( 231 | "success_rate" 232 | ).to_numpy().flatten() + np.random.normal(0, 0.01, len(plot_df_ability)) 233 | 234 | # Create ability scatter plot 235 | ax_ability_main.scatter( 236 | jittered_success_rate_ability, 237 | plot_df_ability.select("ability_mean").to_numpy().flatten(), 238 | s=20, 239 | alpha=0.075, 240 | color="black", 241 | zorder=2, 242 | ) 243 | 244 | ax_ability_main.set_title("(b) Participant ability", fontsize=TITLE_SIZE) 245 | ax_ability_main.set_xlabel("Mean participant accuracy", fontsize=LABEL_SIZE) 246 | ax_ability_main.set_xticklabels([]) 247 | ax_ability_main.set_ylabel(r"$\alpha$", fontsize=LABEL_SIZE) 248 | ax_ability_main.set_yticklabels(ax_ability_main.get_yticks(), fontsize=TICK_SIZE) 249 | ax_ability_main.grid(True, alpha=0.1) 250 | ax_ability_main.axhline(0, color="black", linestyle="--", alpha=0.2) 251 | 252 | # Create ability KDE 253 | sns.kdeplot( 254 | plot_df_ability, 255 | y="ability_mean", 256 | ax=ax_ability_kde, 257 | fill=True, 258 | color=others_palette[4], 259 | alpha=0.1, 260 | ) 261 | ax_ability_kde.set_yticklabels([]) 262 | ax_ability_kde.set_xticklabels([]) 263 | ax_ability_kde.set_xlabel("") 264 | ax_ability_kde.set_ylabel("") 265 | ax_ability_kde.spines["left"].set_visible(False) 266 | ax_ability_kde.spines["top"].set_visible(False) 267 | ax_ability_kde.spines["right"].set_visible(False) 268 | ax_ability_kde.spines["bottom"].set_visible(False) 269 | ax_ability_kde.tick_params(axis="y", left=False) 270 | ax_ability_kde.tick_params(axis="x", bottom=False) 271 | ax_ability_kde.set_ylim(ax_ability_main.get_ylim()) 272 | 273 | # Bottom right plots (difficulty) 274 | ax_diff_main = fig.add_subplot(gs_diff[0, 0]) 275 | ax_diff_kde = fig.add_subplot(gs_diff[0, 1]) 276 | 277 | # Jittered success rates for scatter plots 278 | jittered_success_rate_difficulty = 
plot_df_difficulty.select( 279 | "success_rate" 280 | ).to_numpy().flatten() + np.random.normal(0, 0.01, len(plot_df_difficulty)) 281 | 282 | # Create difficulty scatter plot 283 | ax_diff_main.scatter( 284 | jittered_success_rate_difficulty, 285 | plot_df_difficulty.select("diff_mean").to_numpy().flatten(), 286 | s=20, 287 | alpha=0.2, 288 | color="black", 289 | zorder=2, 290 | ) 291 | ax_diff_main.errorbar( 292 | jittered_success_rate_difficulty, 293 | plot_df_difficulty["diff_mean"], 294 | yerr=[ 295 | plot_df_difficulty["diff_mean"] - plot_df_difficulty["diff_hdi_lower"], 296 | plot_df_difficulty["diff_hdi_upper"] - plot_df_difficulty["diff_mean"], 297 | ], 298 | fmt="none", 299 | ecolor="gray", 300 | alpha=0.1, 301 | zorder=1, 302 | ) 303 | 304 | # add mean difficulty for training and eval tasks 305 | mean_diff_training = ( 306 | plot_df_difficulty.filter(pl.col("task_type") == "training") 307 | .select(pl.mean("diff_mean")) 308 | .item() 309 | ) 310 | mean_diff_eval = ( 311 | plot_df_difficulty.filter(pl.col("task_type") == "evaluation") 312 | .select(pl.mean("diff_mean")) 313 | .item() 314 | ) 315 | 316 | ax_diff_main.axhline( 317 | mean_diff_training, color=others_palette[0], linestyle="--", alpha=0.5 318 | ) 319 | ax_diff_main.axhline( 320 | mean_diff_eval, color=others_palette[1], linestyle="--", alpha=0.5 321 | ) 322 | 323 | ax_diff_main.set_title("(c) Task difficulty", fontsize=TITLE_SIZE) 324 | ax_diff_main.set_xlabel("Mean task accuracy", fontsize=LABEL_SIZE) 325 | ax_diff_main.set_ylabel(r"$\beta$", fontsize=LABEL_SIZE) 326 | ax_diff_main.set_xticklabels(ax_diff_main.get_xticklabels(), fontsize=TICK_SIZE) 327 | ax_diff_main.set_yticklabels(ax_diff_main.get_yticklabels(), fontsize=TICK_SIZE) 328 | ax_diff_main.grid(True, alpha=0.1) 329 | ax_diff_main.axhline(0, color="black", linestyle="--", alpha=0.2) 330 | 331 | # create legend for kde 332 | legend_elements = [ 333 | Patch(edgecolor=others_palette[0], label="Training set", fill=False), 334 | 
def forest_plots(plot_df_ability, plot_df_difficulty, model_name):
    """Save side-by-side forest plots of ability and difficulty estimates.

    Left panel: one marker per row of ``plot_df_ability``, positioned on the
    y-axis by its rank after sorting on ``ability_mean``.  Rows whose
    ``complete`` column is true are drawn in black, the others in red.
    Right panel: one marker per row of ``plot_df_difficulty``, sorted on
    ``diff_mean``.  In both panels the horizontal whiskers run from the
    ``*_hdi_lower`` to the ``*_hdi_upper`` column.

    The figure is written to ``figures/<model_name>_forest_plots.png`` under
    ``basepath`` and then closed.

    Args:
        plot_df_ability: Polars frame with ``ability_mean``,
            ``ability_hdi_lower``, ``ability_hdi_upper`` and ``complete``.
        plot_df_difficulty: Polars frame with ``diff_mean``,
            ``diff_hdi_lower`` and ``diff_hdi_upper``.
        model_name: Prefix of the output file name.
    """

    def column(frame, name):
        # One column of a polars frame as a flat numpy array.
        return frame.select(name).to_numpy().flatten()

    def whiskers(frame, mean_col, lower_col, upper_col):
        # Asymmetric error-bar lengths: [mean - lower, upper - mean].
        return [
            column(frame, mean_col) - column(frame, lower_col),
            column(frame, upper_col) - column(frame, mean_col),
        ]

    fig, (ability_ax, difficulty_ax) = plt.subplots(1, 2, figsize=(20, 12))

    # --- (a) abilities -----------------------------------------------------
    # "idx" records each participant's rank so both groups share one y-axis.
    ranked = plot_df_ability.sort("ability_mean").with_row_index("idx")
    finished = pl.col("complete")
    groups = (
        (ranked.filter(finished), "black", 0.1, "Complete"),
        (ranked.filter(~finished).sort("ability_mean"), "red", 0.5, "Incomplete"),
    )
    for frame, colour, opacity, legend_label in groups:
        ability_ax.errorbar(
            x=column(frame, "ability_mean"),
            y=column(frame, "idx"),
            xerr=whiskers(
                frame, "ability_mean", "ability_hdi_lower", "ability_hdi_upper"
            ),
            fmt="o",
            color=colour,
            markersize=1,
            elinewidth=0.5,
            capsize=0,
            alpha=opacity,
            label=legend_label,
        )

    ability_ax.axvline(x=0, color="black", linestyle="--", alpha=0.2)
    ability_ax.set_xlabel(r"Ability Parameter ($\alpha$)", fontsize=LABEL_SIZE)
    ability_ax.set_yticks([])
    ability_ax.grid(True, alpha=0.1)
    ability_ax.set_title(
        "(a) Participant Abilities with 94% HDI", fontsize=TITLE_SIZE
    )
    ability_ax.tick_params(axis="x", labelsize=TICK_SIZE)
    ability_ax.legend(fontsize=LEGEND_SIZE)

    # --- (b) difficulties --------------------------------------------------
    tasks_ranked = plot_df_difficulty.sort("diff_mean")
    difficulty_ax.errorbar(
        x=column(tasks_ranked, "diff_mean"),
        y=range(len(tasks_ranked)),
        xerr=whiskers(tasks_ranked, "diff_mean", "diff_hdi_lower", "diff_hdi_upper"),
        fmt="o",
        markersize=2,
        elinewidth=0.5,
        capsize=0,
        alpha=0.3,
    )
    difficulty_ax.axvline(x=0, color="black", linestyle="--", alpha=0.2)
    difficulty_ax.set_xlabel(r"Difficulty Parameter ($\beta$)", fontsize=LABEL_SIZE)
    difficulty_ax.set_yticks([])
    difficulty_ax.grid(True, alpha=0.1)
    difficulty_ax.set_title(
        "(b) Task Difficulties with 94% HDI", fontsize=TITLE_SIZE
    )
    difficulty_ax.tick_params(axis="x", labelsize=TICK_SIZE)

    plt.tight_layout()
    target = os.path.join(basepath, "figures", f"{model_name}_forest_plots.png")
    plt.savefig(target, bbox_inches="tight", dpi=300)
    plt.close()
def logit_to_prob(logit):
    """Convert a logit (log-odds) value to a probability.

    Computes the logistic sigmoid ``1 / (1 + exp(-logit))`` with
    ``scipy.special.expit``.  The direct formula overflows inside ``np.exp``
    for strongly negative logits and emits a ``RuntimeWarning``; ``expit``
    returns the same values without the overflow.

    Args:
        logit: A scalar or array-like of log-odds.

    Returns:
        The corresponding probability (or array of probabilities) in [0, 1].
    """
    # Local import so this function stays self-contained within the module.
    from scipy.special import expit

    return expit(logit)
def get_stats(trace, model_name, df, verbose=False):
    """Calculate and output key statistics from the IRT model trace.

    The report has three sections: feedback effects (``epsilon_one`` and
    ``epsilon_two``), mean difficulty per task type (``beta``), and mean task
    accuracy per shot.  It is written to ``results/<model_name>_stats.md``
    under ``basepath``; the ``results`` directory is created if missing.

    Args:
        trace: ArviZ trace object containing posterior samples.  The posterior
            variables ``epsilon_one``, ``epsilon_two``, ``beta``,
            ``mean_task_acc_training`` and ``mean_task_acc_eval`` are read.
        model_name: Name of the model (used for output file naming).
        df: Polars frame with ``tasks`` and ``task_type`` columns; used to pick
            the training / evaluation entries of ``beta``.
        verbose: Whether to print results to console.

    Returns:
        Path of the markdown file that was written.
    """

    def prob_shift(logit):
        # Change in success probability relative to a logit of 0 (p = 0.5).
        return logit_to_prob(logit) - logit_to_prob(0)

    def task_labels(task_type):
        # Task coordinates belonging to one task type.
        return (
            df.filter(pl.col("task_type") == task_type)
            .select("tasks")
            .to_numpy()
            .flatten()
        )

    def accuracy_summary(var_name):
        # Posterior mean and 94% HDI of a per-shot accuracy variable.
        samples = trace.posterior[var_name]
        return (
            samples.mean(dim=["chain", "draw"]),
            az.hdi(samples, hdi_prob=0.94)[var_name],
        )

    output = [f"=== IRT Model Statistics: {model_name} ===\n"]

    # 1. Feedback Effects (Epsilon)
    output.append("Feedback Effects")
    output.append("-" * 50)
    for ordinal, var_name, symbol in (
        ("First", "epsilon_one", r"\epsilon_1"),
        ("Second", "epsilon_two", r"\epsilon_2"),
    ):
        samples = trace.posterior[var_name]
        epsilon = samples.mean(dim=("chain", "draw")).values.item()
        hdi = az.hdi(samples, hdi_prob=0.94)[var_name]
        lower = prob_shift(hdi.sel(hdi="lower").values)
        upper = prob_shift(hdi.sel(hdi="higher").values)

        # The LaTeX symbol is interpolated so the literal part is a normal
        # (non-raw) string: a raw f-string would emit "\n" as two literal
        # characters instead of a newline.
        output.append(f"{ordinal} additional attempt (${symbol} ={epsilon:.2f}$):\n")
        output.append(f"- Probability increase: {prob_shift(epsilon):.1%}")
        output.append(f" (94% HDI: [{lower:.1%}, {upper:.1%}])\n\n")

    # 2. Task Type Differences
    output.append("Task Type Differences")
    output.append("-" * 50)
    for heading, task_type in (("Training", "training"), ("Evaluation", "evaluation")):
        beta = trace.posterior["beta"].sel(tasks=task_labels(task_type))
        beta_mean = beta.mean().values
        # HDI of the per-draw mean difficulty across the selected tasks.
        beta_hdi = az.hdi(beta.mean(dim=["tasks"]).values.flatten(), hdi_prob=0.94)
        # Difficulty enters the logit negatively, so the upper difficulty
        # bound gives the lower probability bound and vice versa.
        p_lower = prob_shift(0 - beta_hdi[1])
        p_upper = prob_shift(0 - beta_hdi[0])

        output.append(f"{heading} Tasks:\n")
        output.append(f"- Mean difficulty ($\\beta = {beta_mean:.2f}$)\n")
        output.append(
            f"- Success probability increase: {prob_shift(0 - beta_mean):.1%}"
        )
        output.append(f" (94% HDI: [{p_lower:.1%}, {p_upper:.1%}])\n\n")

    # 3. Mean Task Accuracy by Shot
    output.append("Mean Task Accuracy by Shot\n")
    output.append("-" * 50)
    train_mean, train_hdi = accuracy_summary("mean_task_acc_training")
    eval_mean, eval_hdi = accuracy_summary("mean_task_acc_eval")
    for shot_idx, shot in enumerate(train_mean.shots.values):
        output.append(f"\n{shot}:\n")
        for heading, mean, hdi in (
            ("Training", train_mean, train_hdi),
            ("Evaluation", eval_mean, eval_hdi),
        ):
            acc = mean.values[shot_idx] * 100
            hdi_lower = hdi.sel(hdi="lower").values[shot_idx] * 100
            hdi_upper = hdi.sel(hdi="higher").values[shot_idx] * 100
            output.append(f"- {heading} Tasks: {acc:.1f}%")
            output.append(f" (94% HDI: [{hdi_lower:.1f}%, {hdi_upper:.1f}%])\n")

    report = "\n".join(output)

    # Write results to file
    output_path = os.path.join(basepath, "results", f"{model_name}_stats.md")
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(report)

    # Print to console if verbose
    if verbose:
        print(report)

    return output_path
if __name__ == "__main__":
    # Script entry point: load a fitted model, then produce its trace plot,
    # parameter CSVs, parameter figure, forest plots and statistics report.
    args = get_args()
    model_name = Path(args.model_path).stem
    model, trace = load_model(args.model_path)
    # total_params = (
    #     len(trace.posterior.participants)  # ability parameters
    #     + len(trace.posterior.tasks)  # difficulty parameters
    #     + 2  # epsilon parameters
    # )
    # print(f"Total number of parameters: {total_params}")
    plot_trace(trace, model_name)

    (
        plot_df_ability,
        plot_df_difficulty,
        plot_df_epsilon,
        plot_df_epsilon_alt,
    ) = generate_irt_dataframes(trace)

    # Save IRT model parameters, one CSV per parameter family.
    csv_exports = (
        (plot_df_ability, f"{model_name}_ability_parameters.csv"),
        (plot_df_difficulty, f"{model_name}_difficulty_parameters.csv"),
        (plot_df_epsilon_alt, f"{model_name}_epsilon_parameters.csv"),
    )
    for frame, file_name in csv_exports:
        frame.write_csv(os.path.join(basepath, "data", file_name))

    plot_irt_parameters(
        plot_df_epsilon, plot_df_ability, plot_df_difficulty, model_name
    )
    forest_plots(plot_df_ability, plot_df_difficulty, model_name)
    get_stats(trace, model_name, plot_df_difficulty, args.verbose)
{}, 7 | "outputs": [], 8 | "source": [ 9 | "# imports\n", 10 | "import seaborn as sns\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "import sys\n", 13 | "import numpy as np\n", 14 | "\n", 15 | "sys.path.append(\"..\")\n", 16 | "from src.utils import *\n", 17 | "from scipy.stats import ttest_ind, permutation_test\n", 18 | "from scipy import stats" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 3, 24 | "metadata": {}, 25 | "outputs": [ 26 | { 27 | "data": { 28 | "text/plain": [ 29 | "" 30 | ] 31 | }, 32 | "execution_count": 3, 33 | "metadata": {}, 34 | "output_type": "execute_result" 35 | } 36 | ], 37 | "source": [ 38 | "# polars config\n", 39 | "pl.Config.load_from_file(\"../polars_cfg.json\")" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 4, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "# data file paths\n", 49 | "data_path = \"/Users/solimlegris/Projets/h-arc-osf/data/data.csv\"\n", 50 | "summary_path = \"/Users/solimlegris/Projets/h-arc-osf/data/summary_data.csv\"" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 5, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "df_summary = pl.read_csv(summary_path)\n", 60 | "df_all = pl.read_csv(data_path)\n", 61 | "\n", 62 | "# parse time\n", 63 | "df_all = df_all.with_columns(pl.col(\"time\").cast(pl.Datetime))\n", 64 | "\n", 65 | "columns = [\n", 66 | " \"exp_name\",\n", 67 | " \"hashed_id\",\n", 68 | " \"joint_id_task\",\n", 69 | " \"task_name\",\n", 70 | " \"task_number\",\n", 71 | " \"task_type\",\n", 72 | " \"attempt_number\",\n", 73 | " \"action\",\n", 74 | " \"action_id\",\n", 75 | " \"solved\",\n", 76 | " \"time\",\n", 77 | " \"test_input_grid\",\n", 78 | " \"test_output_grid\",\n", 79 | "]\n", 80 | "df_all = df_all.select(columns)\n", 81 | "\n", 82 | "# load task accuracy data\n", 83 | "mean_task_acc = pl.read_csv(\"../data/mean_task_acc_three_attempts.csv\")" 84 | ] 85 | }, 86 | { 87 | "cell_type": 
"code", 88 | "execution_count": 7, 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "data": { 93 | "image/png": "", 94 | "text/plain": [ 95 | "
" 96 | ] 97 | }, 98 | "metadata": {}, 99 | "output_type": "display_data" 100 | } 101 | ], 102 | "source": [ 103 | "# plot distribution of task number where participants dropped out\n", 104 | "sns.histplot(\n", 105 | " df_summary.filter(~pl.col(\"complete\")).select(\n", 106 | " pl.all()\n", 107 | " .top_k_by(\"task_number\", k=1)\n", 108 | " .over(\"hashed_id\", mapping_strategy=\"explode\")\n", 109 | " ),\n", 110 | " x=\"task_number\",\n", 111 | " hue=\"task_type\",\n", 112 | " multiple=\"dodge\",\n", 113 | " stat=\"proportion\",\n", 114 | " common_norm=False,\n", 115 | ")\n", 116 | "plt.title(\"Distribution of task number where participants dropped out\")\n", 117 | "plt.show()" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 8, 123 | "metadata": {}, 124 | "outputs": [ 125 | { 126 | "data": { 127 | "text/html": [ 128 | "
\n", 135 | "shape: (5, 3)
task_typetask_numberlen
stri64u32
"evaluation"181
"evaluation"250
"evaluation"344
"evaluation"467
"training"127
" 136 | ], 137 | "text/plain": [ 138 | "shape: (5, 3)\n", 139 | "┌────────────┬─────────────┬─────┐\n", 140 | "│ task_type ┆ task_number ┆ len │\n", 141 | "│ --- ┆ --- ┆ --- │\n", 142 | "│ str ┆ i64 ┆ u32 │\n", 143 | "╞════════════╪═════════════╪═════╡\n", 144 | "│ evaluation ┆ 1 ┆ 81 │\n", 145 | "│ evaluation ┆ 2 ┆ 50 │\n", 146 | "│ evaluation ┆ 3 ┆ 44 │\n", 147 | "│ evaluation ┆ 4 ┆ 67 │\n", 148 | "│ training ┆ 1 ┆ 27 │\n", 149 | "└────────────┴─────────────┴─────┘" 150 | ] 151 | }, 152 | "execution_count": 8, 153 | "metadata": {}, 154 | "output_type": "execute_result" 155 | } 156 | ], 157 | "source": [ 158 | "# for each task_type, compute how many drop after 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 attempts\n", 159 | "# total_training =\n", 160 | "dropout_after = (\n", 161 | " df_summary.filter(~pl.col(\"complete\"))\n", 162 | " .group_by(\"hashed_id\")\n", 163 | " .agg(pl.max(\"task_number\"), pl.first(\"task_type\"))\n", 164 | " .group_by([\"task_type\", \"task_number\"])\n", 165 | " .agg(pl.len())\n", 166 | ")\n", 167 | "dropout_after.sort([\"task_type\", \"task_number\"]).head()" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 9, 173 | "metadata": {}, 174 | "outputs": [ 175 | { 176 | "name": "stdout", 177 | "output_type": "stream", 178 | "text": [ 179 | "Training dropout rate: 94/783 (0.120, 12.0%)\n", 180 | "Evaluation dropout rate: 242/946 (0.256, 25.6%)\n" 181 | ] 182 | } 183 | ], 184 | "source": [ 185 | "# dropout rates\n", 186 | "training_incomplete = (\n", 187 | " df_summary.filter(~pl.col(\"complete\"))\n", 188 | " .filter(pl.col(\"task_type\") == \"training\")\n", 189 | " .n_unique(\"hashed_id\")\n", 190 | ")\n", 191 | "training_total = df_summary.filter(pl.col(\"task_type\") == \"training\").n_unique(\n", 192 | " \"hashed_id\"\n", 193 | ")\n", 194 | "training_dropout_rate = training_incomplete / training_total\n", 195 | "\n", 196 | "evaluation_incomplete = (\n", 197 | " df_summary.filter(~pl.col(\"complete\"))\n", 198 | " 
.filter(pl.col(\"task_type\") == \"evaluation\")\n", 199 | " .n_unique(\"hashed_id\")\n", 200 | ")\n", 201 | "evaluation_total = df_summary.filter(pl.col(\"task_type\") == \"evaluation\").n_unique(\n", 202 | " \"hashed_id\"\n", 203 | ")\n", 204 | "evaluation_dropout_rate = evaluation_incomplete / evaluation_total\n", 205 | "\n", 206 | "print(\n", 207 | " f\"Training dropout rate: {training_incomplete}/{training_total} ({training_dropout_rate:.3f}, {round(training_dropout_rate * 100, 1)}%)\"\n", 208 | ")\n", 209 | "print(\n", 210 | " f\"Evaluation dropout rate: {evaluation_incomplete}/{evaluation_total} ({evaluation_dropout_rate:.3f}, {round(evaluation_dropout_rate * 100, 1)}%)\"\n", 211 | ")" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 10, 217 | "metadata": {}, 218 | "outputs": [ 219 | { 220 | "name": "stdout", 221 | "output_type": "stream", 222 | "text": [ 223 | "Proportion of task data missing: 10.3%\n", 224 | "Proportion of task data expected on training set: 7.5%\n", 225 | "Proportion of task data expected on evaluation set: 13.3%\n" 226 | ] 227 | } 228 | ], 229 | "source": [ 230 | "# amount of missing data\n", 231 | "df_summary_ten = df_summary.filter(pl.col(\"exp_name\").is_in([\"expv0\", \"expv1\"]))\n", 232 | "df_summary_five = df_summary.filter(~pl.col(\"exp_name\").is_in([\"expv0\", \"expv1\"]))\n", 233 | "total_expected_task_data = (\n", 234 | " df_summary_ten.n_unique(\"hashed_id\") * 10\n", 235 | " + df_summary_five.n_unique(\"hashed_id\") * 5\n", 236 | ")\n", 237 | "total_expected_task_data_training = (\n", 238 | " df_summary_ten.n_unique(\"hashed_id\") * 10\n", 239 | " + df_summary_five.filter(pl.col(\"task_type\") == \"training\").n_unique(\"hashed_id\")\n", 240 | " * 5\n", 241 | ")\n", 242 | "total_expected_task_data_evaluation = (\n", 243 | " df_summary_five.filter(pl.col(\"task_type\") == \"evaluation\").n_unique(\"hashed_id\")\n", 244 | " * 5\n", 245 | ")\n", 246 | "observed_task_data = 
df_summary.n_unique(\"joint_id_task\")\n", 247 | "observed_task_data_training = df_summary.filter(\n", 248 | " pl.col(\"task_type\") == \"training\"\n", 249 | ").n_unique(\"joint_id_task\")\n", 250 | "observed_task_data_evaluation = df_summary.filter(\n", 251 | " pl.col(\"task_type\") == \"evaluation\"\n", 252 | ").n_unique(\"joint_id_task\")\n", 253 | "print(\n", 254 | " f\"Proportion of task data missing: {round((1 - observed_task_data / total_expected_task_data) * 100, 1)}%\"\n", 255 | ")\n", 256 | "print(\n", 257 | " f\"Proportion of task data expected on training set: {round((1 - observed_task_data_training / total_expected_task_data_training) * 100, 1)}%\"\n", 258 | ")\n", 259 | "print(\n", 260 | " f\"Proportion of task data expected on evaluation set: {round((1 - observed_task_data_evaluation / total_expected_task_data_evaluation) * 100, 1)}%\"\n", 261 | ")" 262 | ] 263 | } 264 | ], 265 | "metadata": { 266 | "kernelspec": { 267 | "display_name": ".venv", 268 | "language": "python", 269 | "name": "python3" 270 | }, 271 | "language_info": { 272 | "codemirror_mode": { 273 | "name": "ipython", 274 | "version": 3 275 | }, 276 | "file_extension": ".py", 277 | "mimetype": "text/x-python", 278 | "name": "python", 279 | "nbconvert_exporter": "python", 280 | "pygments_lexer": "ipython3", 281 | "version": "3.12.5" 282 | } 283 | }, 284 | "nbformat": 4, 285 | "nbformat_minor": 2 286 | } 287 | -------------------------------------------------------------------------------- /analysis/0-arc-dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# imports\n", 10 | "import polars as pl\n", 11 | "import seaborn as sns\n", 12 | "from matplotlib import pyplot as plt\n", 13 | "from scipy.stats import ttest_ind, permutation_test" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 
19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "# load data\n", 23 | "training_set = pl.read_csv(\"../arc_data/ARC_training_tasks.csv\")\n", 24 | "evaluation_set = pl.read_csv(\"../arc_data/ARC_evaluation_tasks.csv\")" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 3, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "def get_min_max(df, example_type):\n", 34 | " return (\n", 35 | " df.filter(pl.col(\"example_type\") == example_type)\n", 36 | " .group_by(\"task_name\")\n", 37 | " .agg(pl.max(\"example_number\"))\n", 38 | " .select(\n", 39 | " pl.min(\"example_number\").alias(\"min_examples\"),\n", 40 | " pl.max(\"example_number\").alias(\"max_examples\"),\n", 41 | " )\n", 42 | " .row(0)\n", 43 | " )" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 4, 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "name": "stdout", 53 | "output_type": "stream", 54 | "text": [ 55 | "Training set:\n", 56 | "\tmin examples: 2\n", 57 | "\tmax examples: 10\n", 58 | "Evaluation set:\n", 59 | "\tmin examples: 2\n", 60 | "\tmax examples: 7\n" 61 | ] 62 | } 63 | ], 64 | "source": [ 65 | "# Get min and max number of training examples for both sets\n", 66 | "train_min, train_max = get_min_max(training_set, \"train\")\n", 67 | "eval_min, eval_max = get_min_max(evaluation_set, \"train\")\n", 68 | "\n", 69 | "print(f\"Training set:\\n\\tmin examples: {train_min}\\n\\tmax examples: {train_max}\")\n", 70 | "print(f\"Evaluation set:\\n\\tmin examples: {eval_min}\\n\\tmax examples: {eval_max}\")" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "There are between 2 and 10 inference examples for training set tasks and 2 and 7 for evaluation set tasks.\n" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 5, 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "name": "stdout", 87 | "output_type": "stream", 88 | "text": [ 89 | "Training set:\n", 90 | "\tmin 
examples: 1\n", 91 | "\tmax examples: 3\n", 92 | "Evaluation set:\n", 93 | "\tmin examples: 1\n", 94 | "\tmax examples: 2\n" 95 | ] 96 | } 97 | ], 98 | "source": [ 99 | "# get min and max number of training examples on training/eval set for test examples\n", 100 | "train_min, train_max = get_min_max(training_set, \"test\")\n", 101 | "eval_min, eval_max = get_min_max(evaluation_set, \"test\")\n", 102 | "\n", 103 | "print(f\"Training set:\\n\\tmin examples: {train_min}\\n\\tmax examples: {train_max}\")\n", 104 | "print(f\"Evaluation set:\\n\\tmin examples: {eval_min}\\n\\tmax examples: {eval_max}\")" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": {}, 110 | "source": [ 111 | "There are between 1 and 3 test examples on the training set and 1 and 2 on the evaluation set.\n" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 6, 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "name": "stdout", 121 | "output_type": "stream", 122 | "text": [ 123 | "We did not collect data for 16 tests over 14 tasks in the training set\n", 124 | "We did not collect data for 19 tests over 19 tasks in the evaluation set\n" 125 | ] 126 | } 127 | ], 128 | "source": [ 129 | "training_set_no_data_tests = (\n", 130 | " training_set.filter(pl.col(\"example_type\") == \"test\")\n", 131 | " .group_by(\"task_name\")\n", 132 | " .agg((pl.max(\"example_number\") - 1).alias(\"missed_examples\"))\n", 133 | " .filter(pl.col(\"missed_examples\") > 0)\n", 134 | ")\n", 135 | "num_missed_training = training_set_no_data_tests.select(\n", 136 | " pl.sum(\"missed_examples\")\n", 137 | ").item()\n", 138 | "evaluation_set_no_data_tests = (\n", 139 | " evaluation_set.filter(pl.col(\"example_type\") == \"test\")\n", 140 | " .group_by(\"task_name\")\n", 141 | " .agg((pl.max(\"example_number\") - 1).alias(\"missed_examples\"))\n", 142 | " .filter(pl.col(\"missed_examples\") > 0)\n", 143 | ")\n", 144 | "num_missed_eval = 
evaluation_set_no_data_tests.select(pl.sum(\"missed_examples\")).item()\n", 145 | "\n", 146 | "print(\n", 147 | " f\"We did not collect data for {num_missed_training} tests over {training_set_no_data_tests.height} tasks in the training set\"\n", 148 | ")\n", 149 | "print(\n", 150 | " f\"We did not collect data for {num_missed_eval} tests over {evaluation_set_no_data_tests.height} tasks in the evaluation set\"\n", 151 | ")" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 7, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "def get_set_tests(dataset):\n", 161 | " # grid size\n", 162 | " filtered_tests = dataset.filter((pl.col(\"example_type\") == \"test\")).with_columns(\n", 163 | " (pl.col(\"output_height\") * pl.col(\"output_width\")).alias(\"grid_size\")\n", 164 | " )\n", 165 | " # number of examples per task\n", 166 | " num_examples_per_task = (\n", 167 | " dataset.filter(pl.col(\"example_type\") == \"train\")\n", 168 | " .group_by(\"task_name\")\n", 169 | " .agg(pl.max(\"example_number\").alias(\"num_examples\"))\n", 170 | " )\n", 171 | " filtered_tests = filtered_tests.join(num_examples_per_task, on=\"task_name\")\n", 172 | " return filtered_tests" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 8, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "training_set_tests = get_set_tests(training_set)\n", 182 | "training_set_tests = training_set_tests.with_columns(pl.lit(\"training\").alias(\"set\"))\n", 183 | "evaluation_set_tests = get_set_tests(evaluation_set)\n", 184 | "evaluation_set_tests = evaluation_set_tests.with_columns(\n", 185 | " pl.lit(\"evaluation\").alias(\"set\")\n", 186 | ")\n", 187 | "dataset = pl.concat([training_set_tests, evaluation_set_tests])" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 9, 193 | "metadata": {}, 194 | "outputs": [ 195 | { 196 | "data": { 197 | "image/png": "", 198 | "text/plain": [ 199 | "
" 200 | ] 201 | }, 202 | "metadata": {}, 203 | "output_type": "display_data" 204 | } 205 | ], 206 | "source": [ 207 | "# plot grid size distribution\n", 208 | "plt.figure(figsize=(12, 6))\n", 209 | "sns.histplot(dataset, x=\"grid_size\", bins=16, hue=\"set\")\n", 210 | "plt.title(\"Grid size distribution\")\n", 211 | "training_set_mean = training_set_tests.select(pl.mean(\"grid_size\")).item()\n", 212 | "evaluation_set_mean = evaluation_set_tests.select(pl.mean(\"grid_size\")).item()\n", 213 | "training_set_sd = training_set_tests.select(pl.std(\"grid_size\")).item()\n", 214 | "evaluation_set_sd = evaluation_set_tests.select(pl.std(\"grid_size\")).item()\n", 215 | "plt.text(\n", 216 | " 0.5,\n", 217 | " 0.5,\n", 218 | " f\"Training set mean: {round(training_set_mean, 1)}, SD={round(training_set_sd, 1)}\\nEvaluation set mean: {round(evaluation_set_mean, 1)}, SD={round(evaluation_set_sd, 1)}\",\n", 219 | " horizontalalignment=\"center\",\n", 220 | " verticalalignment=\"center\",\n", 221 | " transform=plt.gca().transAxes,\n", 222 | ")\n", 223 | "plt.plot()\n", 224 | "plt.show()" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 10, 230 | "metadata": {}, 231 | "outputs": [ 232 | { 233 | "name": "stdout", 234 | "output_type": "stream", 235 | "text": [ 236 | "Test statistic: -6.81421447779164\n", 237 | "P-value: 1.9843266953504276e-11\n", 238 | "Degrees of freedom: 730.1473878533382\n" 239 | ] 240 | } 241 | ], 242 | "source": [ 243 | "# run permutation test\n", 244 | "training_set_grid_sizes = training_set_tests.select(\"grid_size\").to_numpy().flatten()\n", 245 | "evaluation_set_grid_sizes = (\n", 246 | " evaluation_set_tests.select(\"grid_size\").to_numpy().flatten()\n", 247 | ")\n", 248 | "\n", 249 | "stats = ttest_ind(\n", 250 | " training_set_grid_sizes,\n", 251 | " evaluation_set_grid_sizes,\n", 252 | " equal_var=False,\n", 253 | ")\n", 254 | "print(f\"Test statistic: {stats.statistic}\")\n", 255 | "print(f\"P-value: {stats.pvalue}\")\n", 
256 | "print(f\"Degrees of freedom: {stats.df}\")" 257 | ] 258 | }, 259 | { 260 | "cell_type": "markdown", 261 | "metadata": {}, 262 | "source": [ 263 | "A Welch's t-test shows that output grid sizes are significantly smaller in the training set than in the evaluation set (p < .001).\n" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 11, 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "dataset.write_csv(\"../data/ARC_all_tasks_tests_stats.csv\")" 273 | ] 274 | } 275 | ], 276 | "metadata": { 277 | "kernelspec": { 278 | "display_name": ".venv", 279 | "language": "python", 280 | "name": "python3" 281 | }, 282 | "language_info": { 283 | "codemirror_mode": { 284 | "name": "ipython", 285 | "version": 3 286 | }, 287 | "file_extension": ".py", 288 | "mimetype": "text/x-python", 289 | "name": "python", 290 | "nbconvert_exporter": "python", 291 | "pygments_lexer": "ipython3", 292 | "version": "3.12.5" 293 | } 294 | }, 295 | "nbformat": 4, 296 | "nbformat_minor": 2 297 | } 298 | --------------------------------------------------------------------------------