├── .gitignore ├── README.md ├── computational_model ├── Manifest.toml ├── Project.toml ├── analysis_scripts │ ├── anal_utils.jl │ ├── analyse_by_N.jl │ ├── analyse_hp_sweep.jl │ ├── analyse_human_data.jl │ ├── analyse_rollout_timing.jl │ ├── analyse_variable_rollouts.jl │ ├── behaviour_by_success.jl │ ├── calc_human_prior.jl │ ├── compare_maze_path_lengths.jl │ ├── compare_perf_without_rollout.jl │ ├── estimate_num_mazes.jl │ ├── euclidean_prolific_ids.jl │ ├── eval_value_function.jl │ ├── model_replay_analyses.jl │ ├── perf_by_rollout_number.jl │ ├── quantify_internal_model.jl │ ├── repeat_human_actions.jl │ ├── results │ │ └── example_rollout.bson │ ├── rollout_as_pg.jl │ ├── run_all_analyses.jl │ └── shuffle_rollout_times.jl ├── figs │ └── .gitkeep ├── models │ ├── N100_T50_Lplan8_seed61_1000_hps.bson │ ├── N100_T50_Lplan8_seed61_1000_mod.bson │ ├── N100_T50_Lplan8_seed61_1000_opt.bson │ ├── N100_T50_Lplan8_seed61_1000_policy.bson │ ├── N100_T50_Lplan8_seed61_1000_prediction.bson │ ├── N100_T50_Lplan8_seed61_1000_progress.bson │ ├── N100_T50_Lplan8_seed62_1000_hps.bson │ ├── N100_T50_Lplan8_seed62_1000_mod.bson │ ├── N100_T50_Lplan8_seed62_1000_opt.bson │ ├── N100_T50_Lplan8_seed62_1000_policy.bson │ ├── N100_T50_Lplan8_seed62_1000_prediction.bson │ ├── N100_T50_Lplan8_seed62_1000_progress.bson │ ├── N100_T50_Lplan8_seed63_1000_hps.bson │ ├── N100_T50_Lplan8_seed63_1000_mod.bson │ ├── N100_T50_Lplan8_seed63_1000_opt.bson │ ├── N100_T50_Lplan8_seed63_1000_policy.bson │ ├── N100_T50_Lplan8_seed63_1000_prediction.bson │ ├── N100_T50_Lplan8_seed63_1000_progress.bson │ ├── N100_T50_Lplan8_seed64_1000_hps.bson │ ├── N100_T50_Lplan8_seed64_1000_mod.bson │ ├── N100_T50_Lplan8_seed64_1000_opt.bson │ ├── N100_T50_Lplan8_seed64_1000_policy.bson │ ├── N100_T50_Lplan8_seed64_1000_prediction.bson │ ├── N100_T50_Lplan8_seed64_1000_progress.bson │ ├── N100_T50_Lplan8_seed65_1000_hps.bson │ ├── N100_T50_Lplan8_seed65_1000_mod.bson │ ├── N100_T50_Lplan8_seed65_1000_opt.bson │ ├── N100_T50_Lplan8_seed65_1000_policy.bson │ ├── N100_T50_Lplan8_seed65_1000_prediction.bson │ └── N100_T50_Lplan8_seed65_1000_progress.bson ├── plot_paper_figs │ ├── plot_all.jl │ ├── plot_fig_RTs.jl │ ├── plot_fig_mechanism_behav.jl │ ├── plot_fig_mechanism_neural.jl │ ├── plot_fig_replays.jl │ ├── plot_supp_RT_by_step.jl │ ├── plot_supp_exploration.jl │ ├── plot_supp_fig_network_size.jl │ ├── plot_supp_hp_sweep.jl │ ├── plot_supp_human_euc_comparison.jl │ ├── plot_supp_human_summary.jl │ ├── plot_supp_internal_model.jl │ ├── plot_supp_plan_probs.jl │ ├── plot_supp_values.jl │ ├── plot_supp_variable.jl │ └── plot_utils.jl ├── repeated_submission.py ├── results │ └── .gitkeep ├── src │ ├── ToPlanOrNotToPlan.jl │ ├── a2c.jl │ ├── environment.jl │ ├── exports.jl │ ├── human_utils_maze.jl │ ├── initializations.jl │ ├── io.jl │ ├── loss_hyperparameters.jl │ ├── maze.jl │ ├── model.jl │ ├── model_planner.jl │ ├── planning.jl │ ├── plotting.jl │ ├── priors.jl │ ├── train.jl │ ├── walls.jl │ ├── walls_baselines.jl │ └── walls_build.jl └── walls_train.jl ├── human_data.txt └── human_data ├── Euclidean_prolific_data.sqlite └── prolific_data.sqlite /.gitignore: -------------------------------------------------------------------------------- 1 | *bson 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## A recurrent network model of planning explains hippocampal replay and human behavior 2 | 3 | In this 
repository, we provide code for training and analysing the reinforcement learning agent described in Jensen et al. (2024): "A recurrent network model of planning explains hippocampal replay and human behavior" (https://www.nature.com/articles/s41593-024-01675-7). 4 | 5 | Human data from the behavioural experiments, both with and without periodic boundaries, can be found in the `human_data/` directory. 6 | Code for training and analysing the reinforcement learning models, as well as for generating the figures in the paper, can be found in the `computational_model/` directory. 7 | A collection of pretrained base models is provided in `./computational_model/models/`. 8 | 9 | To run the code, Julia >= 1.7 should be installed together with all packages listed in the Manifest.toml file. 10 | To install these packages:\ 11 | `cd ./computational_model`\ 12 | `julia --project=.`\ 13 | `using Pkg`\ 14 | `Pkg.instantiate()`\ 15 | To run the pretrained models, BSON 0.3.5 and Flux 0.13.5 should be installed, since later versions of these packages are not backwards compatible with the saved model files. 16 | Julia 1.8.0 was used for all analyses in the paper. 17 | 18 | The primary script used to train RL agents is './computational_model/walls_train.jl'. 19 | A useful script for getting started on downstream analyses of the computational model is './computational_model/analysis_scripts/analyse_rollout_timing.jl'. 20 | The primary script used for analyses of the human data is './computational_model/analysis_scripts/analyse_human_data.jl'. 21 | 22 | For any questions, comments, or suggestions, please reach out to Kris Jensen (kris.torp.jensen@gmail.com). 23 | -------------------------------------------------------------------------------- /computational_model/Project.toml: -------------------------------------------------------------------------------- 1 | name = "ToPlanOrNotToPlan" 2 | uuid = "900268d7-6a04-4a1d-9c18-30a86fd30022" 3 | authors = ["Kristopher T.
Jensen", "Ta-Chu Kao"] 4 | version = "0.1.0" 5 | 6 | [deps] 7 | ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63" 8 | BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" 9 | BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232" 10 | CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" 11 | CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" 12 | ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" 13 | Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" 14 | DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" 15 | DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" 16 | Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" 17 | Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" 18 | HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" 19 | ImageFiltering = "6a3955dd-da59-5b1f-98d4-e7296123deb5" 20 | JLD = "4138dd39-2aa7-5051-a626-17a0bb65d9c8" 21 | LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" 22 | LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" 23 | Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" 24 | MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411" 25 | NaNStatistics = "b946abbf-3ea7-4610-9019-9858bfdeaf2d" 26 | NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce" 27 | Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" 28 | Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" 29 | PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" 30 | PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee" 31 | Query = "1a8c2f83-1ff3-5112-b086-8aa67b057ba1" 32 | Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" 33 | RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" 34 | Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" 35 | SQLite = "0aa819cd-b072-5ff4-a722-6bc24af294d9" 36 | SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" 37 | Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" 38 | StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" 39 | StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" 40 | Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" 41 | Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" 42 | VegaDatasets = "0ae4a718-28b7-58ec-9efb-cded64d6d5b4" 43 | VegaLite = "112f6efa-9a02-5b7d-90c0-432ed331239a" 44 | Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" 45 | -------------------------------------------------------------------------------- /computational_model/analysis_scripts/anal_utils.jl: -------------------------------------------------------------------------------- 1 | #this script loads some useful libraries and sets various global defaults. 
2 | 3 | # load some libraries that we generally need 4 | import Pkg 5 | Pkg.activate("../") 6 | using Revise 7 | using PyPlot, PyCall 8 | using Distributions, Statistics, Random, StatsBase 9 | using Flux, Zygote 10 | using BSON: @save, @load 11 | @pyimport matplotlib.gridspec as gspec 12 | 13 | #set some default paths 14 | global datadir = "./results/" #directory to write results to 15 | global loaddir = "../models/" #directory to load models from 16 | 17 | # select default global models 18 | global seeds = 61:65 #random seeds 19 | global plan_epoch = 1000 #training epoch to use for evaluation (1000 is final epoch) 20 | global greedy_actions = true #sample actions greedily at test time 21 | global N = 100 #number of units 22 | global Lplan = 8 #planning horizon 23 | global Larena = 4 #arena size 24 | global prefix = "" #model name prefix 25 | global epoch = plan_epoch #alias for plan_epoch 26 | 27 | ### lognormal helper functions ### 28 | function lognorm(x; mu = 0, sig = 0, delta = 0) 29 | #pdf for shifted lognormal distribution (shift = delta) 30 | if x <= delta return 0 end 31 | return 1 / ((x-delta) * sig * sqrt(2*pi)) * exp(- (log(x-delta) - mu)^2 / (2*sig^2)) 32 | end 33 | 34 | Phi(x) = cdf(Normal(), x) #standard normal cdf 35 | function calc_post_mean(r; deltahat=0, muhat=0, sighat=0) 36 | #compute posterior mean thinking time for a given response time 'r' 37 | #deltahat, muhat, and sighat are the parameters of the lognormal prior over delays 38 | if r < deltahat+1 return 0 end #if response is faster than delta, no thinking 39 | k1, k2 = 0, r - deltahat #integration limits 40 | term1 = exp(muhat+sighat^2/2) 41 | term2 = Phi((log(k2)-muhat-sighat^2)/sighat) - Phi((log(k1)-muhat-sighat^2)/sighat) 42 | term3 = Phi((log(k2)-muhat)/sighat) - Phi((log(k1)-muhat)/sighat) 43 | post_delay = (term1*term2/term3 + deltahat) #add back delta for posterior mean delay 44 | return r - post_delay #posterior mean thinking time is response minus mean delay 45 | end 46 | -------------------------------------------------------------------------------- /computational_model/analysis_scripts/analyse_by_N.jl: -------------------------------------------------------------------------------- 1 | #in this script, we consider how rewards and rollout probabilities change over learning 2 | #for models of different sizes.
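# (added annotation) this script, like all analysis scripts, pulls its global defaults
# (seeds, plan_epoch, N, Lplan, datadir, loaddir, ...) from anal_utils.jl via include().
# as an illustrative sketch of the lognormal helpers defined in anal_utils.jl above,
# with hypothetical prior parameters (not fitted values):
# rts = [200.0, 500.0, 1000.0]  # example response times in ms
# thinking = [calc_post_mean(r; deltahat=100.0, muhat=5.5, sighat=0.6) for r in rts]
# each entry is the response time minus the posterior-mean delay under the shifted
# lognormal prior, i.e. the "thinking time" used in the human analyses.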
3 | 4 | #load some stuff 5 | include("anal_utils.jl") 6 | using ToPlanOrNotToPlan 7 | 8 | println("analyzing learning and rollouts for different model sizes") 9 | 10 | loss_hp = LossHyperparameters(0, 0, 0, 0) #not computing losses 11 | Nhiddens = [60;80;100] #number of hidden units 12 | epochs = 0:50:1000 #training epochs to consider 13 | 14 | 15 | meanrews, pfracs = [zeros(length(Nhiddens), length(seeds), length(epochs)) for _ = 1:2] #containers for storing results 16 | for (ihid, Nhidden) = enumerate(Nhiddens) #for each network size 17 | for (iseed, seed) = enumerate(seeds) #for each random seed 18 | for (iepoch, epoch) = enumerate(epochs) 19 | 20 | filename = "N$(Nhidden)_T50_Lplan8_seed$(seed)_$epoch" #model to load 21 | network, opt, store, hps, policy, prediction = recover_model(loaddir*filename) #load model parameters 22 | 23 | #instantiate environment and agent 24 | Larena = hps["Larena"] 25 | model_properties, wall_environment, model_eval = build_environment( 26 | Larena, hps["Nhidden"], hps["T"], Lplan = hps["Lplan"], greedy_actions = greedy_actions 27 | ) 28 | m = ModularModel(model_properties, network, policy, prediction, forward_modular) 29 | 30 | Random.seed!(1) #set random seed for reproducibility 31 | batch_size = 5000 #number of environments to consider 32 | tic = time() 33 | #run the experiment 34 | L, ys, rews, as, world_states, hs = run_episode( 35 | m, wall_environment, loss_hp; batch=batch_size, calc_loss = false 36 | ) 37 | plan_frac = sum(as .== 5)/sum(as .> 0.5) #fraction of actions that were rollouts 38 | mean_rew = sum(rews .> 0.5) / batch_size #average reward per episode 39 | 40 | println("N=$Nhidden, seed=$seed, epoch=$epoch, avg rew=$(mean_rew), rollout fraction=$(plan_frac)") 41 | #store results 42 | meanrews[ihid, iseed, iepoch] = mean_rew 43 | pfracs[ihid, iseed, iepoch] = plan_frac 44 | 45 | end 46 | end 47 | end 48 | 49 | #save all results 50 | res_dict = Dict("seeds" => seeds, 51 | "Nhiddens" => Nhiddens, 52 | "epochs" => epochs, 53 | "meanrews" => meanrews, 54 | "planfracs" => pfracs) 55 | @save datadir * "rew_and_plan_by_n.bson" res_dict 56 | 57 | -------------------------------------------------------------------------------- /computational_model/analysis_scripts/analyse_hp_sweep.jl: -------------------------------------------------------------------------------- 1 | #in this script, we repeat some key analyses with different network sizes and rollout lengths 2 | #we do this to assess the robustness of our results 3 | 4 | # load some scripts 5 | include("anal_utils.jl") 6 | 7 | global run_default_analyses = false # load functions without running analyses for default models 8 | include("repeat_human_actions.jl") 9 | include("perf_by_rollout_number.jl") 10 | include("behaviour_by_success.jl") 11 | global run_default_analyses = true # back to default 12 | 13 | println("repeating analyses with different hyperparameters") 14 | 15 | prefix = "" 16 | seeds = 51:55 #use a separate set of seeds 17 | sizes = [60;100;140] #model sizes to consider 18 | Lplans = [4;8;12] #planning horizons to consider 19 | 20 | for N = sizes #for each network size 21 | for Lplan = Lplans #for each planning horizon 22 | println("running N=$N, L=$Lplan") 23 | # correlation with human RT #### 24 | repeat_human_actions(;seeds, N, Lplan, epoch, prefix = "hp_") 25 | 26 | # change in performance with replay number ### 27 | run_perf_by_rollout_number(;seeds, N, Lplan, epoch, prefix = "hp_") 28 | 29 | # change in policy after successful/unsuccessful replay #### 30 | 
run_causal_rollouts(;seeds, N, Lplan, epoch, prefix = "hp_") 31 | end 32 | end 33 | 34 | -------------------------------------------------------------------------------- /computational_model/analysis_scripts/analyse_human_data.jl: -------------------------------------------------------------------------------- 1 | # in this script, we load the human behavioural data and save some useful summary statistics 2 | 3 | # load scripts and model 4 | include("anal_utils.jl") 5 | include("euclidean_prolific_ids.jl") 6 | using ToPlanOrNotToPlan 7 | using SQLite, DataFrames, ImageFiltering 8 | using NaNStatistics 9 | 10 | println("loading and processing human behavioural data") 11 | wraparound = true 12 | 13 | #perform analyses for both non-guided ("play") and guided ("follow") episodes 14 | for game_type = ["play"; "follow"] 15 | 16 | # build RL environment 17 | T = 100 18 | Larena = 4 19 | environment_dimensions = EnvironmentDimensions(Larena^2, 2, 5, T, Larena) 20 | 21 | #initialize some arrays for storing data 22 | all_RTs, all_trial_nums, all_trial_time, all_rews, all_states = [], [], [], [], [] 23 | all_wall_loc, all_ps, all_as = [], [], [] 24 | Nepisodes, tokens = [], [] 25 | 26 | if wraparound 27 | db = SQLite.DB("../../human_data/prolific_data.sqlite") 28 | wrapstr = "" 29 | else 30 | db = SQLite.DB("../../human_data/Euclidean_prolific_data.sqlite") 31 | wrapstr = "_euclidean" 32 | end 33 | 34 | users = (DBInterface.execute(db, "SELECT id FROM users") |> DataFrame)[:, "id"] 35 | if game_type == "play" nskip = 2 else nskip = 8 end #number of initial episodes to discard 36 | 37 | println("loading users for game type: $(game_type)") 38 | i_user = 0 39 | for user_id = users 40 | user_eps = DBInterface.execute(db, "SELECT * FROM episodes WHERE user_id = "*string(user_id[1])) |> DataFrame #episode data 41 | usize = size(user_eps, 1) #total number of episodes for this user 42 | info = DBInterface.execute(db, "SELECT * FROM users WHERE id = "*string(user_id)) |> DataFrame 43 | token = info[1, "token"] 44 | if (usize >= 58) && (length(token) == 24) #finished task and prolific-sized token 45 | if wraparound || (token in euclidean_ids) 46 | i_user += 1; if i_user % 10 == 0 println(i_user) end 47 | rews, as, states, wall_loc, ps, times, trial_nums, trial_time, RTs, shot = extract_maze_data(db, user_id, Larena, game_type = game_type, skip_init = nskip) 48 | append!(all_RTs, [RTs]) #reaction times 49 | append!(all_rews, [rews]) #rewards 50 | append!(all_trial_nums, [trial_nums]) #trial numbes 51 | append!(all_trial_time, [trial_time]) #time within trial 52 | append!(all_states, [states]) #subject locations 53 | append!(all_wall_loc, [wall_loc]) #wall locations 54 | append!(all_ps, [ps]) #reward locations 55 | append!(all_as, [as]) #actions taken 56 | append!(Nepisodes, size(ps, 2)) #number of episodes for this user 57 | push!(tokens, info[1, "token"]) 58 | end 59 | end 60 | end 61 | valid_users = 1:length(all_rews) 62 | 63 | println("processing data for $(length(valid_users)) users") 64 | 65 | #store all data 66 | data = [all_states, all_ps, all_as, all_wall_loc, all_rews, all_RTs, all_trial_nums, all_trial_time] 67 | @save "$(datadir)/human_all_data_$game_type$wrapstr.bson" data 68 | 69 | #store some generally useful data 70 | data = Dict("all_rews" => all_rews, "all_RTs" => all_RTs) 71 | @save "$(datadir)/human_RT_and_rews_$game_type$wrapstr.bson" data 72 | 73 | # compute steps by trial number 74 | function comp_rew_by_step(rews; Rmin = 4) 75 | keep_inds = findall( sum(rews .> 0.5, dims = 2)[:] .>= Rmin 
) #only consider episodes with at least Rmin completed trials 76 | all_durs = zeros(length(keep_inds), Rmin) #container for durations of each trial (in steps) 77 | for (ib, b) = enumerate(keep_inds) #loop through episodes 78 | sortrew = sortperm(-rews[b, :]) #find reward times 79 | rewtimes = [0; sortrew[1:Rmin]] 80 | durs = rewtimes[2:Rmin+1] - rewtimes[1:Rmin] #difference between reward times 81 | all_durs[ib, :] = durs #store 82 | end 83 | μ = mean(all_durs, dims = 1)[:] #mean 84 | s = std(all_durs, dims = 1)[:] / sqrt(length(keep_inds)) #standard error 85 | return μ, s 86 | end 87 | 88 | μs, ss = [], [] 89 | Rmin = 4 90 | for i = valid_users #compute for each user 91 | μ, s = comp_rew_by_step(all_rews[i], Rmin = Rmin) 92 | push!(μs, μ) 93 | push!(ss, s) 94 | end 95 | μs = reduce(hcat, μs) #combine data 96 | ss = reduce(hcat, ss) 97 | 98 | #save data 99 | data = [Rmin, μs, ss] 100 | @save "$(datadir)/human_by_trial_$game_type$wrapstr.bson" data 101 | 102 | # RT by distance step and distance to goal 103 | function human_RT_by_difficulty(T, rews, ps, wall_loc, Larena, trial_nums, trial_time, RTs, states) 104 | trials = 20 #maximum number of trials 105 | new_RTs = zeros(trials, size(rews, 1), T) .+ NaN #RTs 106 | new_dists = zeros(trials, size(rews, 1)) .+ NaN #distances to goal 107 | for b = 1:size(rews, 1) #for each episode 108 | rew = rews[b, :] #rewards in this episode 109 | min_dists = dist_to_rew(ps[:, b:b], wall_loc[:, :, b:b], Larena) #minimum distances to goal for each state 110 | for trial = 2:trials #consider only exploitation 111 | if sum(rew .> 0.5) .> (trial - 0.5) #finished trial 112 | inds = findall((trial_nums[b, :] .== trial) .& (trial_time[b, :] .> 0.5)) #all timepoints within trial 113 | new_RTs[trial, b, 1:length(inds)] = RTs[b, inds] #reaction times 114 | state = states[:, b, inds[1]] #initial state 115 | new_dists[trial, b] = min_dists[Int(state[1]), Int(state[2])] #distance to goal from initial state 116 | end 117 | end 118 | end 119 | return new_RTs, new_dists #return RTs and distances 120 | end 121 | 122 | # repeat by participant 123 | RTs, dists = [], [] 124 | for u = valid_users 125 | new_RTs, new_dists = human_RT_by_difficulty(T, all_rews[u], all_ps[u], all_wall_loc[u], Larena, all_trial_nums[u], all_trial_time[u], all_RTs[u], all_states[u]) 126 | push!(RTs, new_RTs); push!(dists, new_dists) #add results to container 127 | end 128 | 129 | #write result to a file 130 | data = [RTs, dists, all_trial_nums, all_trial_time] 131 | @save "$(datadir)RT_by_complexity_by_user_$game_type$wrapstr.bson" data 132 | 133 | # compute RT by unique states visited during exploration 134 | all_unique_states = [] 135 | for i = 1:length(all_RTs) #for each user 136 | states, rews, as = all_states[i], all_rews[i], all_as[i] #extract states, rewards and actions 137 | unique_states = zeros(size(all_RTs[i])) .+ NaN #how many states had been seen when the action was taken 138 | for b = 1:size(rews,1) #for each episode 139 | if sum(rews[b, :]) == 0 #if there are no finished trials 140 | tmax = sum(as[b, :] .> 0.5) #iterate until end 141 | else 142 | tmax = findall(rews[b, :] .== 1)[1] #iterate until first reward 143 | end 144 | visited = Bool.(zeros(16)) #which states have been visited 145 | for t = 1:tmax #for each action in trial 1 146 | visited[Int(state_ind_from_state(Larena, states[:,b,t])[1])] = true #visited corresponding state 147 | unique_states[b, t] = sum(visited) #number of unique states 148 | end 149 | end 150 | push!(all_unique_states, unique_states) #add to container 151 | 
end 152 | 153 | #write data to file 154 | data = [all_RTs, all_unique_states] 155 | @save "$(datadir)unique_states_$game_type$wrapstr.bson" data 156 | 157 | end 158 | 159 | -------------------------------------------------------------------------------- /computational_model/analysis_scripts/analyse_rollout_timing.jl: -------------------------------------------------------------------------------- 1 | # in this script, we analyse the timing of rollouts in the RL agent 2 | # the resulting 'response times' can then be compared with human behavioural data 3 | 4 | ## load scripts and model 5 | include("anal_utils.jl") 6 | using ToPlanOrNotToPlan 7 | 8 | println("analysing the timings of rollouts in the RL agent") 9 | 10 | loss_hp = LossHyperparameters(0, 0, 0, 0) #not computing losses 11 | epoch = plan_epoch #test epoch 12 | 13 | for seed = seeds #iterate through models trained independently 14 | 15 | # load the model parameters 16 | fname = "N100_T50_Lplan8_seed$(seed)_$epoch" 17 | println("loading ", fname) 18 | network, opt, store, hps, policy, prediction = recover_model("../models/$fname"); 19 | 20 | # instantiate model and environment 21 | Larena = hps["Larena"] 22 | model_properties, wall_environment, model_eval = build_environment( 23 | Larena, hps["Nhidden"], hps["T"], Lplan = hps["Lplan"], greedy_actions = greedy_actions 24 | ) 25 | m = ModularModel(model_properties, network, policy, prediction, forward_modular) 26 | 27 | # run a bunch of episodes 28 | Random.seed!(1) 29 | batch_size = 50000 30 | tic = time() 31 | L, ys, rews, as, world_states, hs = run_episode( 32 | m, wall_environment, loss_hp; batch=batch_size, calc_loss = false 33 | ) 34 | 35 | # extract some data we might need 36 | states = reduce((a, b) -> cat(a, b, dims = 3), [ws.agent_state for ws = world_states]) #states over time 37 | wall_loc = world_states[1].environment_state.wall_loc #wall location 38 | ps = world_states[1].environment_state.reward_location #reward location 39 | Tmax, Nstates = size(as, 2), Larena^2 #extract some dimensions 40 | rew_locs = reshape(ps, Nstates, batch_size, 1) .* ones(1, 1, Tmax) #for each time point 41 | println("average reward: ", sum(rews .> 0.5) / batch_size, " time: ", time() - tic) #average reward per episode 42 | 43 | # how many steps/actions were planned 44 | plan_steps = zeros(batch_size, Tmax); 45 | for t = 1:Tmax-1 46 | plan_steps[:,t] = sum(world_states[t+1].planning_state.plan_cache' .> 0.5, dims = 2)[:]; 47 | end 48 | 49 | #extract some trial information 50 | trial_ts = zeros(batch_size, Tmax) # network iteration within trial 51 | trial_ids = zeros(batch_size, Tmax) # trial number 52 | trial_anums = zeros(batch_size, Tmax) # action number (not counting rollouts) 53 | for b = 1:batch_size #iterate through episodes 54 | Nrew = sum(rews[b, :] .> 0.5) #total number of rewards 55 | sortrew = sortperm(-rews[b, :]) #indices or sorted array 56 | rewts = sortrew[1:Nrew] #times at which we got reward 57 | diffs = [rewts; Tmax+1] - [0; rewts] #duration of each trial 58 | trial_ids[b, :] = reduce(vcat, [ones(diffs[i]) * i for i = 1:(Nrew+1)])[1:Tmax] #trial number 59 | trial_ts[b, :] = reduce(vcat, [1:diffs[i] for i = 1:(Nrew+1)])[1:Tmax] #time within trial 60 | 61 | finished = findall(as[b, :] .== 0) #timepoints at which episode is finished 62 | #zero out finished steps 63 | trial_ids[b, finished] .= 0 64 | trial_ts[b, finished] .= 0 65 | plan_steps[b, finished] .= 0 66 | 67 | #extract the action number for each iteration 68 | ep_as = as[b, :] 69 | for id = 1:(Nrew+1) #for each trial 70 
| inds = findall(trial_ids[b, :] .== id) #indices of this trial 71 | trial_as = ep_as[inds] #actions within this trial 72 | anums = zeros(Int64, length(inds)) #list of action numbers 73 | anum = 1 #start at first action 74 | for a = 2:length(inds) #go through all network iterations 75 | anums[a] = anum #store the action number 76 | if trial_as[a] <= 4.5 anum +=1 end #increment if not a rollout 77 | end 78 | trial_anums[b, inds] = anums #store all action numbers 79 | end 80 | end 81 | 82 | ## look at performance by trial 83 | 84 | Rmin = 4 #only consider trials with >=Rmin reward (to control for correlation between performance and steps-per-trial) 85 | inds = findall(sum(rews, dims = 2)[:] .>= Rmin) #episodes with >= Rmin reward 86 | perfs = reduce(hcat, [[trial_anums[b, trial_ids[b, :] .== t][end] for t = 1:Rmin] for b = inds])' #performance for each trial number 87 | 88 | # compute optimal baseline 89 | mean_dists = zeros(batch_size) # mean goal distances from all non-goal locations 90 | for b in 1:batch_size 91 | dists = dist_to_rew(ps[:, b:b], wall_loc[:, :, b:b], Larena) #goal distances for this arena 92 | mean_dists[b] = sum(dists) / (Nstates - 1) #average across non-goal states 93 | end 94 | μ, s = mean(perfs, dims = 1)[:], std(perfs, dims = 1)[:]/sqrt(batch_size) #compute summary statistics 95 | data = [Rmin, μ, s, mean(mean_dists)] 96 | @save "$(datadir)/model_by_trial$seed.bson" data #store data 97 | 98 | ## planning by difficulty 99 | 100 | trials = 15 101 | new_RTs = zeros(trials, batch_size, hps["T"]) .+ NaN; 102 | new_alt_RTs = zeros(trials, batch_size, hps["T"]) .+ NaN; 103 | new_dists = zeros(trials, batch_size) .+ NaN; 104 | for b = 1:batch_size 105 | rew = rews[b, :] #rewards in this episode 106 | min_dists = dist_to_rew(ps[:, b:b], wall_loc[:, :, b:b], Larena) #minimum distances to goal for each state 107 | for trial = 2:trials 108 | if sum(rew .> 0.5) .> (trial - 0.5) #finish trial 109 | inds = findall((trial_ids[b, :] .== trial) .& (trial_ts[b, :] .> 1.5)) #all timepoints within trial 110 | 111 | anums = trial_anums[b, inds] 112 | RTs = [sum(anums .== anum) for anum = 1:anums[end]] 113 | 114 | plan_nums = plan_steps[b, inds] 115 | alt_RTs = [sum(plan_nums[anums .== anum]) for anum = 1:anums[end]] #count as number of simulated steps 116 | new_alt_RTs[trial, b, 1:length(alt_RTs)] = alt_RTs #reaction times 117 | 118 | for anum = 1:anums[end] 119 | ainds = findall(anums .== anum) 120 | if length(ainds) > 1.5 121 | @assert all(plan_nums[ainds[1:(length(ainds)-1)]] .> 0.5) #should all have non-zero plans 122 | end 123 | end 124 | 125 | new_RTs[trial, b, 1:length(RTs)] = RTs #reaction times 126 | state = states[:, b, inds[1]] #initial state 127 | new_dists[trial, b] = min_dists[Int(state[1]), Int(state[2])] 128 | end 129 | end 130 | end 131 | 132 | dists = 1:8 133 | dats = [new_RTs[(new_dists.==dist), :] for dist in dists] 134 | data = [dists, dats] 135 | @save "$(datadir)model_RT_by_complexity$(seed)_$epoch.bson" data 136 | alt_dats = [new_alt_RTs[(new_dists.==dist), :] for dist in dists] 137 | data = [dists, alt_dats] 138 | @save "$(datadir)model_RT_by_complexity_bystep$(seed)_$epoch.bson" data 139 | 140 | ## look at exploration 141 | 142 | RTs = zeros(size(rews)) .+ NaN; 143 | unique_states = zeros(size(rews)) .+ NaN; #how many states had been seen when the action was taken 144 | for b = 1:batch_size 145 | inds = findall(trial_ids[b, :] .== 1) 146 | anums = Int.(trial_anums[b, inds]) 147 | if sum(rews[b, :]) == 0 tmax = sum(as[b, :] .> 0.5) else tmax = findall(rews[b, :] 
.== 1)[1] end 148 | visited = Bool.(zeros(16)) #which states have been visited 149 | for anum = unique(anums) 150 | state = states[:,b,findall(anums .== anum)[1]] 151 | visited[Int(state_ind_from_state(Larena, state)[1])] = true 152 | unique_states[b, anum+1] = sum(visited) 153 | RTs[b, anum+1] = sum(anums .== anum) 154 | end 155 | end 156 | 157 | data = [RTs, unique_states] 158 | @save "$(datadir)model_unique_states_$(seed)_$epoch.bson" data 159 | 160 | ## do decoding of rew loc by unique states 161 | unums = 1:15 162 | dec_perfs = zeros(length(unums)) 163 | for unum = unums 164 | inds = findall(unique_states .== unum) 165 | ahot = zeros(Float32, 5, length(inds)) 166 | for (i, ind) = enumerate(inds) ahot[Int(as[ind]), i] = 1f0 end 167 | X = [hs[:, inds]; ahot] #Nhidden x batch x T -> Nhidden x iters 168 | Y = rew_locs[:, inds] 169 | Yhat = m.prediction(X)[17:32, :] 170 | Yhat = exp.(Yhat .- Flux.logsumexp(Yhat; dims=1)) #softmax over states 171 | perf = sum(Yhat .* Y) / size(Y, 2) 172 | dec_perfs[unum] = perf 173 | end 174 | data = [unums, dec_perfs] 175 | @save "$(datadir)model_exploration_predictions_$(seed)_$epoch.bson" data 176 | 177 | end 178 | -------------------------------------------------------------------------------- /computational_model/analysis_scripts/analyse_variable_rollouts.jl: -------------------------------------------------------------------------------- 1 | #in this script, we repeat some key analyses with different network sizes and rollout lengths 2 | #we do this to assess the robustness of our results 3 | 4 | # load some scripts 5 | include("anal_utils.jl") 6 | 7 | global run_default_analyses = false # load functions without running analyses for default models 8 | include("repeat_human_actions.jl") 9 | include("perf_by_rollout_number.jl") 10 | include("behaviour_by_success.jl") 11 | global run_default_analyses = true # back to default 12 | 13 | println("repeating analyses with different hyperparameters") 14 | 15 | prefix = "variable_" 16 | seeds = 61:65 17 | N, Lplan = 100, 8 18 | 19 | println("running N=$N, L=$Lplan") 20 | # correlation with human RT #### 21 | repeat_human_actions(;seeds, N, Lplan, epoch, prefix = prefix, model_prefix = prefix) 22 | 23 | # change in performance with replay number ### 24 | run_perf_by_rollout_number(;seeds, N, Lplan, epoch, prefix = prefix, model_prefix = prefix) 25 | 26 | # change in policy after successful/unsuccessful replay #### 27 | run_causal_rollouts(;seeds, N, Lplan, epoch, prefix = prefix, model_prefix = prefix) 28 | 29 | -------------------------------------------------------------------------------- /computational_model/analysis_scripts/calc_human_prior.jl: -------------------------------------------------------------------------------- 1 | ## load scripts and model 2 | include("anal_utils.jl") 3 | using ToPlanOrNotToPlan 4 | 5 | wrapstr = "" 6 | #wrapstr = "_euclidean" # to run the Euclidean analysis 7 | 8 | println("computing prior parameters for human response times") 9 | 10 | #start by loading our processed human data for the guided ('follow') trials 11 | @load "$(datadir)/human_all_data_follow$wrapstr.bson" data 12 | _, _, _, _, _, all_RTs_f, all_trial_nums_f, all_trial_time_f = data; 13 | Nuser = length(all_RTs_f) #number of users 14 | 15 | params = Dict(key => zeros(Nuser, 3) for key = ["initial"; "later"]) #all participants; initial/later actions; 3 parameters 16 | for (i_init, initial) = enumerate([true, false]) #are we considering the first action in each trial? 
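# (added annotation) for a fixed shift delta, the maximum-likelihood estimates of the
# shifted-lognormal parameters are simply the sample mean and standard deviation of
# log(RT - delta); the loop below computes these for every candidate delta and keeps the
# delta with the highest log likelihood. illustrative sketch with synthetic data
# (hypothetical values, not participant data):
# fake_rts = 300.0 .+ exp.(5.0 .+ 0.5 .* randn(10_000))  # shifted lognormal, delta = 300
# mu_hat  = mean(log.(fake_rts .- 300.0))                 # ≈ 5.0
# sig_hat = std(log.(fake_rts .- 300.0))                  # ≈ 0.5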
17 | if initial key = "initial" else key = "later" end 18 | for u = 1:Nuser #for each participant 19 | 20 | # datapoints during guided trials 21 | if initial 22 | inds = (all_trial_nums_f[u] .> 1.5) .& (all_trial_time_f[u] .== 1) #first action in exploitation trials 23 | else 24 | inds = (all_trial_nums_f[u] .> 1.5) .& (all_trial_time_f[u] .> 1.5) #later actions in exploitation trials 25 | end 26 | RTs_f = all_RTs_f[u][inds] #reaction times for these actions 27 | RTs_f = RTs_f[.~isnan.(RTs_f)] #remove if there is missing data 28 | if u % 10 == 0 println("user $u, $key actions, $(length(RTs_f)) datapoints") end 29 | 30 | #try different deltas in our shifted lognormal prior 31 | deltas = 0:1:(minimum(RTs_f)-1) #list of deltas to try (the ones with appropriate support) 32 | Ls, mus, sigs = [zeros(length(deltas)) for _ = 1:3] #corresponding log liks and optimal params 33 | for (i, delta) = enumerate(deltas) #compute likelihood with each delta 34 | mus[i] = mean(log.(RTs_f .- delta)) #mean of the shifted lognormal 35 | sigs[i] = std(log.(RTs_f .- delta)) #standard deviation 36 | Ls[i] = sum(log.(lognorm.(RTs_f, mu = mus[i], sig = sigs[i], delta = delta))) #log likelihood of the data 37 | end 38 | 39 | #extract maximum likelihood parameters 40 | muhat, sighat, deltahat = [arr[argmax(Ls)] for arr = [mus, sigs, deltas]] 41 | params[key][u, :] = [muhat; sighat; deltahat] #store parameters 42 | end 43 | end 44 | 45 | #write to file 46 | @save "$datadir/guided_lognormal_params_delta$wrapstr.bson" params #mu, sigma, delta 47 | 48 | -------------------------------------------------------------------------------- /computational_model/analysis_scripts/compare_maze_path_lengths.jl: -------------------------------------------------------------------------------- 1 | #in this script, we compare path lengths in the 'Euclidean' and 'non-Euclidean' arenas 2 | 3 | include("anal_utils.jl") 4 | using ToPlanOrNotToPlan 5 | 6 | N = 10000 # number of arenas to compare 7 | cat3(a, b) = cat(a, b, dims = 3) 8 | w_wraps = reduce(cat3, [maze(4, wrap = true) for _ in 1:N]) # generate toroidal mazes 9 | w_nowraps = reduce(cat3, [maze(4, wrap = false) for _ in 1:N]) # generate Euclidean mazes 10 | 11 | ps = onehot_from_loc(4, 1:16) # possible goal locations 12 | dists_wraps = zeros(N, 16, 16) # all-to-all distances 13 | dists_nowraps = zeros(N, 16, 16) 14 | for i1 = 1:N # for each maze 15 | for i2 = 1:16 # for each goal location 16 | # compute distance from all start 17 | dists_wraps[i1, i2, :] = dist_to_rew(ps[:, i2:i2], w_wraps[:, :, i1:i1], 4) 18 | dists_nowraps[i1, i2, :] = dist_to_rew(ps[:, i2:i2], w_nowraps[:, :, i1:i1], 4) 19 | end 20 | end 21 | 22 | dists_wraps = dists_wraps[dists_wraps .> 0.5] # ignore self-distances 23 | dists_nowraps = dists_nowraps[dists_nowraps .> 0.5] # ignore self-distances 24 | 25 | dists = [dists_wraps, dists_nowraps] 26 | @save "$(datadir)/wrap_and_nowrap_pairwise_dists.bson" dists 27 | 28 | -------------------------------------------------------------------------------- /computational_model/analysis_scripts/compare_perf_without_rollout.jl: -------------------------------------------------------------------------------- 1 | # in this script we compare a model with and without rollouts. 2 | # this allows us to investigate whether rollouts improve performance, taking into account the opportunity cost. 
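# (added annotation) the opportunity cost arises because each rollout consumes one of the
# T = 50 time steps of an episode, so rollouts only help if the improved policy saves more
# physical steps than the rollouts themselves cost. a minimal sketch of how the saved
# results could be summarised afterwards (assumes the `results` dictionary written at the
# end of this script; key `true` = rollouts enabled, `false` = rollouts disabled):
# @load "$(datadir)/performance_with_out_planning.bson" results
# mean_rew(r) = sum(r .> 0.5) / size(r, 1)  # average rewards per episode
# for seed in sort(collect(keys(results)))
#     println(seed, ": with rollouts ", mean_rew(results[seed][true]),
#             ", without rollouts ", mean_rew(results[seed][false]))
# end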
3 | 4 | # load scripts and model 5 | include("anal_utils.jl") 6 | using ToPlanOrNotToPlan 7 | 8 | println("comparing performance with and without rollouts") 9 | 10 | loss_hp = LossHyperparameters(0, 0, 0, 0) #we're not computing a loss 11 | 12 | greedy_actions = true 13 | epoch = plan_epoch 14 | results = Dict() #dictionary to store results 15 | batch_size = 50000 #number of environments to run 16 | 17 | for seed = seeds #for each independently trained model 18 | 19 | results[seed] = Dict() #results for this model 20 | for plan = [false; true] #no rollouts (false) or rollouts (true) 21 | Random.seed!(1) #set a seed for reproducibility 22 | 23 | filename = "N100_T50_Lplan8_seed$(seed)_$epoch" #model to load 24 | network, opt, store, hps, policy, prediction = recover_model(loaddir*filename) #load model parameters 25 | Larena = hps["Larena"] 26 | #construct environment, noting whether rollouts are allowed or now 27 | model_properties, wall_environment, model_eval = build_environment( 28 | Larena, hps["Nhidden"], hps["T"], Lplan = hps["Lplan"], greedy_actions = greedy_actions,no_planning = (~plan) 29 | ) 30 | m = ModularModel(model_properties, network, policy, prediction, forward_modular) #construct model 31 | Nstates = Larena^2 32 | 33 | tic = time() 34 | L, ys, rews, as, world_states, hs = run_episode( 35 | m, wall_environment, loss_hp; batch=batch_size, calc_loss = false 36 | ) #let the agent act in the environments (parallelized) 37 | 38 | #print a brief summary 39 | println("\n", seed, " rollouts: ", plan) 40 | println(sum(rews .> 0.5) / batch_size, " ", time() - tic) 41 | println("planning fraction: ", sum(as .> 4.5) / sum(as .> 0.5)) 42 | results[seed][plan] = rews #write result before moving on 43 | end 44 | end 45 | 46 | #save results 47 | @save "$(datadir)/performance_with_out_planning.bson" results 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /computational_model/analysis_scripts/estimate_num_mazes.jl: -------------------------------------------------------------------------------- 1 | #in this script, we estimate the space of environments spanned by our task set 2 | 3 | #load some stuff 4 | include("anal_utils.jl") 5 | using ToPlanOrNotToPlan 6 | 7 | println("estimating the total number of possible tasks") 8 | 9 | #construct environment 10 | model_properties, wall_environment, model_eval = build_environment(Larena, N, 50, Lplan = Lplan) 11 | 12 | batch = 50000 #number of environments to create 13 | Nstates = Larena^2 #number of unique states (and therefore potential reward locations) 14 | all_Npairs, all_Nids = [], [] #total pairwise comparisons and total identical comparisons 15 | Nseeds = 10 #repeat 10 times for uncertainty quantification 16 | for seed = 1:Nseeds 17 | Random.seed!(seed) #set random seed for reproducibility 18 | #create environments 19 | world_state, agent_input = wall_environment.initialize(zeros(2), zeros(2), batch, model_properties) 20 | Ws = world_state.environment_state.wall_loc #all the wall locations 21 | 22 | Npairs, Nid = 0, 0 #start from zero comparisons 23 | for b1 = 1:batch #for each environment 24 | if b1 % 10000 == 0 println("seed $seed of $Nseeds, environment $b1: $Npairs pairwise comparisons, $Nid identical") end 25 | for b2 = b1+1:batch #for each different environment 26 | Npairs += 1 #one more pairwise comparison 27 | Nid += Int(Ws[:, :, b1] == Ws[:, :, b2]) #are these two mazes identical? 
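# (added annotation) this is a collision-based estimate: if a fraction f of random maze
# pairs have identical wall layouts, the effective number of distinct layouts is ~1/f,
# and the task space is Nstates/f because any of the 16 locations can hold the reward.
# worked example with hypothetical numbers: 50,000 mazes give ~1.25e9 pairwise
# comparisons (50000*49999/2); observing 100 identical pairs would give f = 8e-8,
# i.e. ~1.25e7 distinct wall layouts and ~2e8 distinct tasks.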
28 | end 29 | end 30 | 31 | frac_id = Nid/Npairs #fraction of identical wall layouts 32 | println("fraction identical: ", frac_id) #inverse of the number of wall layouts 33 | println("effective task space: ", Nstates/frac_id) #16 rew locations * 1/f wall layouts 34 | push!(all_Npairs, Npairs); push!(all_Nids, Nid) 35 | end 36 | 37 | #save result for future reference 38 | result = Dict("Npairs" => all_Npairs, "Nids" => all_Nids) 39 | @save "$datadir/estimate_num_mazes.bson" result 40 | 41 | task_spaces = Nstates * all_Npairs ./ all_Nids #effective task space for each seed 42 | num_mazes = mean(task_spaces) #mean 43 | err = std(task_spaces)/sqrt(length(task_spaces)) #standard error 44 | println("effective task space: ", num_mazes, " sem: ", err) #16 rew locations 45 | -------------------------------------------------------------------------------- /computational_model/analysis_scripts/euclidean_prolific_ids.jl: -------------------------------------------------------------------------------- 1 | global euclidean_ids = [ 2 | "60ec1eca30581c200a01370c"; 3 | "5bacd7c4a3eb9600018dd69e"; 4 | "5f83e5fac85c275782756b34"; 5 | "5e5a690a2bd7a624fc5c5b23"; 6 | "610131dc603d6066307aee2d"; 7 | "646262b20e27617617167a44"; 8 | "5f31e44019827411ed17a882"; 9 | "61413300c9d2346cfb3d9ca5"; 10 | "5fb91837b8c8756d924f7351"; 11 | "6133d6ceda5d8d0022d8b108"; 12 | "6164a16e7e3cfad826f0fd4c"; 13 | "615c87f5fa0f37dd6988da48"; 14 | "61522cca7870d26be8d1ace7"; 15 | "6106a778cc3899ec7b5a753d"; 16 | "60f7f3873819d1655123ca3c"; 17 | "616477d6533d65db91e19733"; 18 | "61660b4af0be02375d417718"; 19 | "611cbeb6f91dd95de0263ebd"; 20 | "60da34669fd87bea96619e41"; 21 | "5f48e56c7f02363ac9920942"; 22 | "61674f47e2887da2ffd9bd55"; 23 | "6167ba597ea59c1d6bf376c0"; 24 | "5f1c55c8ea92af4d99d03137"; 25 | "60d0d3e7e4d728de23952f2a"; 26 | "611bd9d8f2dc237d183fe3f7"; 27 | "5fd8e3ee6feb8e0c98df2a2f"; 28 | "5e16612c4fa02ac47f453669"; 29 | "60d9dc4df7164f75763470d8"; 30 | "5ec3eabec8f93307a60c5d1b"; 31 | "60a595202752aa1fcd8e07e7"; 32 | "5f551b05e51b928b231a852f"; 33 | "5c49fa25c2653a00018b2d79"; 34 | "604812111a944813de0baa21"; 35 | "6150c99fa0819f39b939f866"; 36 | "61393cb48277d307aa677992"; 37 | "6166c3bf37325d7674746dcd"; 38 | "6155f36c55c59b225ed57879"; 39 | "647f32f00c9efc0c32604a71"; 40 | "5c2288f7867f660001ad65bc"; 41 | "5f096f8e95f9771fe09fe2d3"; 42 | "5f2f1a662d7fa52ee59ec87d"; 43 | "5fbbde667818f3dc301f6647"; 44 | "61327d64b2aa927d577068da"; 45 | "60ebf6da9f545cb42c88af68"; 46 | "5f58fa9f43f1b20c9761ed64"; 47 | "566dedbd57f93300112d0ee2"; 48 | "6130fbdcf112f58c0bd7fd7e"; 49 | "5a9aa66a35237b000112937b"; 50 | "61524134790f9a1cd758817c"; 51 | "60fd310fff5b2600c9d95554"; 52 | ] 53 | -------------------------------------------------------------------------------- /computational_model/analysis_scripts/eval_value_function.jl: -------------------------------------------------------------------------------- 1 | ## in this script we analyse how the value estimate of the agent changes with rollouts 2 | 3 | ## load scripts and model 4 | include("anal_utils.jl") 5 | using ToPlanOrNotToPlan 6 | 7 | println("comparing performance with and without rollouts") 8 | 9 | ## 10 | 11 | loss_hp = LossHyperparameters(0, 0, 0, 0) #we're not computing a loss 12 | 13 | greedy_actions = true 14 | epoch = plan_epoch 15 | results = Dict() #dictionary to store results 16 | batch_size = 10000 #number of environments to run 17 | data = Dict() 18 | for seed = seeds 19 | results[seed] = Dict() #results for this model 20 | Random.seed!(1) #set a seed for 
reproducibility 21 | 22 | filename = "N100_T50_Lplan8_seed$(seed)_$epoch" #model to load 23 | println("\nrunning $filename") 24 | network, opt, store, hps, policy, prediction = recover_model(loaddir*filename) #load model parameters 25 | Larena = hps["Larena"] 26 | 27 | ## construct environment, noting whether rollouts are allowed or now 28 | model_properties, wall_environment, model_eval = build_environment( 29 | Larena, hps["Nhidden"], hps["T"], Lplan = hps["Lplan"], greedy_actions = greedy_actions 30 | ) 31 | m = ModularModel(model_properties, network, policy, prediction, forward_modular) #construct model 32 | Nstates = Larena^2 33 | 34 | tic = time() 35 | L, ys, rews, as, world_states, hs = run_episode( 36 | m, wall_environment, loss_hp; batch=batch_size, calc_loss = false 37 | ) #let the agent act in the environments (parallelized) 38 | 39 | #print a brief summary 40 | println(sum(rews .> 0.5) / batch_size, " ", time() - tic) 41 | println("planning fraction: ", sum(as .> 4.5) / sum(as .> 0.5)) 42 | 43 | ## 44 | 45 | # extract time-within-episode 46 | ts = reduce(hcat, [state.environment_state.time for state = world_states]) 47 | # compute all reward-to-gos 48 | rew_to_go = sum(rews .> 0.5, dims = 2) .- [zeros(batch_size) cumsum(rews[:, 1:size(rews,2)-1], dims = 2)] 49 | 50 | # compute all value functions 51 | Naction = wall_environment.dimensions.Naction 52 | Vs = ys[Naction+1, :, :] 53 | accuracy = abs.(Vs - rew_to_go) # accuracy of the value function 54 | 55 | ## 56 | plan_nums = Int.(zeros(size(accuracy))) # rollout iteration 57 | tot_plans = Int.(zeros(size(accuracy))) # total number of rollouts 58 | suc_rolls = Int.(zeros(size(accuracy))) # is this a response to a successful rollout 59 | num_suc_rolls = Int.(zeros(size(accuracy))) # number of successful rollouts in this sequence 60 | for b = 1:batch_size # for each episode 61 | plan_num, init_plan = 0, 0 # initialize 62 | if sum(rews[b, :]) > 0.5 # if we found the goal at least once 63 | for anum = findfirst(rews[b, :] .== 1)+1:sum(as[b, :] .> 0.5) # for each iteration 64 | a = as[b, anum] 65 | 66 | if (a == 5) && (rews[b, anum-1] != 1) # planning and didn't just get reward 67 | plan_num += 1 # update rollout number within sequence 68 | if plan_num == 1 # just started planning 69 | init_plan = anum # iteration at which this rollout sequence started 70 | # didn't plan on previous iteration, should have no plan input 71 | @assert sum(world_states[anum].planning_state.plan_input[:, b]) < 0.5 72 | else 73 | # planned on previous iteration, should have planning input 74 | @assert sum(world_states[anum].planning_state.plan_input[:, b]) > 0.5 75 | end 76 | plan_nums[b, anum] = plan_num-1 # number of rollouts before generating this output 77 | suc_rolls[b, anum] = world_states[anum].planning_state.plan_input[end, b] # is this a response to a successful rollout? 78 | else 79 | if plan_num > 0 # just finished planning 80 | tot_plans[b, (init_plan):(anum)] .= plan_num # total number of rollouts in this sequence 81 | plan_nums[b, anum] = plan_num 82 | # double check that we've just planned 83 | @assert sum(world_states[anum].planning_state.plan_input[:, b]) > 0.5 84 | suc_rolls[b, anum] = world_states[anum].planning_state.plan_input[end, b] # is this a response to a successful rollout? 
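# (added annotation) worked example of the bookkeeping above, for a hypothetical trace:
# if the agent performs rollouts at iterations 10, 11 and 12 and then takes a physical
# action at iteration 13, then plan_nums[b, 10:13] = [0, 1, 2, 3] (rollouts completed
# before each output), tot_plans[b, 10:13] .= 3 (length of the rollout sequence),
# suc_rolls[b, t] = 1 at exactly those iterations whose feedback came from a rollout
# that reached the goal, and num_suc_rolls[b, 10:13] all equal the number of rollouts
# in the sequence that reached the goal.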
85 | num_suc_rolls[b, (init_plan):(anum)] .= sum(suc_rolls[b, (init_plan):(anum)]) # total number of successful rollouts in this sequence 86 | end 87 | plan_num = 0 # reset planning counter 88 | end 89 | end 90 | end 91 | end 92 | 93 | # save relevant data 94 | data[seed] = Dict("tot_plans" => tot_plans, 95 | "plan_nums" => plan_nums, 96 | "suc_rolls" => suc_rolls, 97 | "num_suc_rolls" => num_suc_rolls, 98 | "Vs" => Vs, 99 | "rew_to_go" => rew_to_go, 100 | "as" => as, 101 | "ts" => ts) 102 | 103 | end 104 | @save datadir * "value_function_eval.bson" data # store result 105 | -------------------------------------------------------------------------------- /computational_model/analysis_scripts/perf_by_rollout_number.jl: -------------------------------------------------------------------------------- 1 | # in this script, we analyse how model performance depends on the number of rollouts 2 | # this allows us to investigate whether rollouts improve the policy 3 | 4 | # load some packages 5 | include("anal_utils.jl") 6 | using ToPlanOrNotToPlan 7 | 8 | try 9 | println("running default analyses: ", run_default_analyses) 10 | catch e 11 | global run_default_analyses = true 12 | end 13 | 14 | """ 15 | run_perf_by_rollout_number(;seeds, N, Lplan, epoch, prefix = "") 16 | analyses the performance (in terms of steps to goal) on trial 2 17 | as a function of the number of enforced rollouts. 18 | """ 19 | function run_perf_by_rollout_number(;seeds, N, Lplan, epoch, prefix = "", model_prefix = "") 20 | println("quantifying trial 2 performance by number of rollouts") 21 | 22 | res_dict = Dict() #dictionary to store results 23 | 24 | for seed = seeds #iterate through random seeds 25 | res_dict[seed] = Dict() #results for this seed 26 | filename = "$(model_prefix)N$(N)_T50_Lplan$(Lplan)_seed$(seed)_$epoch" #model to load 27 | println("\nloading $filename") 28 | network, opt, store, hps, policy, prediction = recover_model(loaddir*filename) #load model parameters 29 | 30 | Larena = hps["Larena"] #size of the arena 31 | model_properties, wall_environment, model_eval = build_environment( 32 | Larena, hps["Nhidden"], hps["T"], Lplan = hps["Lplan"], greedy_actions = greedy_actions 33 | ) #construct RL environment 34 | m = ModularModel(model_properties, network, policy, prediction, forward_modular) #construct model 35 | 36 | # set some parameters 37 | batch = 1 #one batch at a time 38 | ed = wall_environment.dimensions 39 | Nstates, Naction = ed.Nstates, ed.Naction 40 | Nin_base = Naction + 1 + 1 + Nstates + 2 * Nstates #'physical' input dimensions 41 | Nhidden = m.model_properties.Nhidden 42 | tmax = 50 43 | Lplan = model_properties.Lplan 44 | nreps = 5000 #number of random environments to consider 45 | nplans = 0:15 #number of plans enforced 46 | dts = zeros(2, nreps, length(nplans)) .+ NaN; #time to goal 47 | policies = zeros(2, nreps, length(nplans), 10, 5) .+ NaN; #store policies 48 | mindists = zeros(nreps, length(nplans)); 49 | 50 | for ictrl = [1;2] #plan input or not (zeroed out) 51 | for nrep = 1:nreps #for each repetition 52 | if nrep % 1000 == 0 println(nrep) end 53 | for (iplan, nplan) = enumerate(nplans) #for each number of rollouts enforced 54 | Random.seed!(nrep) #set random seed for consistent environments across rollout numbers 55 | world_state, agent_input = wall_environment.initialize( 56 | zeros(2), zeros(2), batch, m.model_properties 57 | ) #initialize environment 58 | agent_state = world_state.agent_state #initialize agent location 59 | h_rnn = m.network[GRUind].cell.state0 .+ Float32.(zeros(Nhidden,
batch)) #expand hidden state 60 | exploit = Bool.(zeros(batch)) #keep track of exploration vs exploitation 61 | rew = zeros(batch) #keep track of reward 62 | if iplan == 1 63 | ps, ws = world_state.environment_state.reward_location, world_state.environment_state.wall_loc 64 | global dists = dist_to_rew(ps, ws, Larena) #compute distances to goal for this arena 65 | end 66 | 67 | tot_n = nplan #rollouts to go 68 | t = 0 #timestep 69 | finished = false #have we finished this environment 70 | nact = 0 #number of physical actions 71 | while ~finished #until finished 72 | t += 1 #update iteration number 73 | agent_input, world_state, rew = agent_input, world_state, rew 74 | if (ictrl == 2 && exploit[1]) agent_input[Nin_base+1:end, :] .= 0 end #no planning input if ctrl 75 | h_rnn, agent_output, a = m.forward(m, ed, agent_input, h_rnn) #RNN step 76 | 77 | plan = false #have we just performed a rollout 78 | if exploit[1] && (tot_n > 0.5) #exploitation phase and more plans to go 79 | plan, tot_n = true, tot_n-1 #we will perform a rollout; decrease counter-to-go 80 | if tot_n == 0 #if we're done 81 | state = world_state.agent_state[:] #current location 82 | mindists[nrep, iplan] = dists[state[1], state[2]] #distance to goal 83 | end 84 | end 85 | if plan #we need to do a rollout 86 | a[1] = 5 #perform a rollout 87 | elseif exploit[1] #exploitation phase 88 | nact += 1 #increment action number to goal 89 | a[1] = argmax(agent_output[1:4, 1]) #greedy action selection 90 | if nact <= 10 policies[ictrl, nrep, iplan, nact, :] = agent_output[1:5, 1] end #store policy 91 | else #exploration phase (trial 1) 92 | a[1] = argmax(agent_output[1:5, 1]) #greedy and don't store anything 93 | end 94 | 95 | exploit[rew[:] .> 0.5] .= true #we're in the exploitation phase if we have found reward 96 | #pass action to environment 97 | rew, agent_input, world_state, predictions = wall_environment.step( 98 | agent_output, a, world_state, wall_environment.dimensions, m.model_properties, m, h_rnn 99 | ) 100 | 101 | if rew[1] > 0.5 # if we just found a reward 102 | if exploit[1] #already found reward before 103 | finished = true #now finished since we only consider trial 2 104 | dts[ictrl, nrep, iplan] = t - t1 - 1 - nplan #store the number of actions to goal 105 | @assert nact == (t - t1 - 1 - nplan) 106 | @assert nact >= mindists[nrep, iplan] 107 | else #first time 108 | global t1 = t #reset timer at the end of first trial 109 | end 110 | end 111 | if t > tmax finished = true end #impose maximum time steps 112 | end 113 | end 114 | end 115 | end 116 | 117 | #store some data for this seed 118 | res_dict[seed]["dts"] = dts 119 | res_dict[seed]["mindists"] = mindists 120 | res_dict[seed]["nplans"] = nplans 121 | res_dict[seed]["policies"] = policies 122 | 123 | end 124 | 125 | #save our data 126 | savename = "$(prefix)N$(N)_Lplan$(Lplan)" 127 | @save datadir * "perf_by_n_$(savename).bson" res_dict 128 | 129 | end 130 | 131 | #run_default_analyses is a global parameter in anal_utils.jl 132 | run_default_analyses && run_perf_by_rollout_number(;seeds,N,Lplan,epoch) 133 | -------------------------------------------------------------------------------- /computational_model/analysis_scripts/quantify_internal_model.jl: -------------------------------------------------------------------------------- 1 | # in this script, we quantify the accuracy of the internal world model over training time 2 | 3 | # load scripts 4 | include("anal_utils.jl") 5 | using ToPlanOrNotToPlan, NaNStatistics 6 | 7 | println("quantifying accuracy of the 
internal world model over training") 8 | 9 | batch = 1000 #number of environments to consider 10 | results = Dict() #dictionary for storing results 11 | 12 | for seed = seeds #for each independently trained RL agent 13 | results[seed] = Dict() #results for this model 14 | for epoch = 0:50:1000 #for each training epoch 15 | 16 | # seed random seed for reproducibility 17 | Random.seed!(1) 18 | 19 | filename = "N100_T50_Lplan8_seed$(seed)_$epoch" #model to load 20 | network, opt, store, hps, policy, prediction = recover_model(loaddir*filename) #load model parameters 21 | 22 | #initialize environment and model 23 | Larena = hps["Larena"] 24 | model_properties, environment, model_eval = build_environment( 25 | Larena, hps["Nhidden"], hps["T"], Lplan = hps["Lplan"], greedy_actions = greedy_actions 26 | ) 27 | m = ModularModel(model_properties, network, policy, prediction, forward_modular) 28 | 29 | #extract some useful parameters 30 | ed = environment.dimensions 31 | Nout = m.model_properties.Nout 32 | Nhidden = m.model_properties.Nhidden 33 | T, Naction, Nstates = ed.T, ed.Naction, ed.Nstates 34 | 35 | ### initialize reward probabilities and state ### 36 | world_state, agent_input = environment.initialize(zeros(2), zeros(2), batch, m.model_properties, initial_params = []) 37 | agent_state = world_state.agent_state 38 | h_rnn = m.network[GRUind].cell.state0 .+ Float32.(zeros(Nhidden, batch)) #expand hidden state 39 | 40 | #containers for storing prediction results 41 | rew_preds, state_preds = zeros(batch, 200) .+ NaN, zeros(batch, 200) .+ NaN 42 | exploit = Bool.(zeros(batch)) #are we in the exploitation phase 43 | iter = 1 #iteration number 44 | rew, old_rew = zeros(batch), zeros(batch) #containers for storing reward information 45 | 46 | #iterate through RL agent/environment 47 | while any(world_state.environment_state.time .< (T+1 - 1e-2)) 48 | iter += 1 #update iteration number 49 | h_rnn, agent_output, a = m.forward(m, ed, agent_input, h_rnn) #RNN step 50 | active = (world_state.environment_state.time .< (T+1 - 1e-2)) #active episodes 51 | 52 | old_rew[:] = rew[:] #did I get reward on previous timestep? 
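# (added annotation) the index arithmetic used below assumes the following layout of
# agent_output, as implied by the offsets used across these analysis scripts with
# Naction = 5 and Nstates = 16: rows 1:5 are the policy logits (4 moves + rollout),
# row 6 (Naction + 1) is the value estimate, rows 7:22 (Naction + 2 : Naction + 1 + Nstates)
# are the predicted next state, and rows 23:38 are the predicted reward location;
# the argmax of each block gives the model's guess, which is compared against the true
# state and reward location to score the internal world model.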
53 | #update environment given action and current state 54 | rew, agent_input, world_state, predictions = environment.step( 55 | agent_output, a, world_state, environment.dimensions, m.model_properties, 56 | m, h_rnn 57 | ) 58 | 59 | #extract true next state and reward location 60 | strue = [coord[1] for coord = argmax(onehot_from_state(Larena, world_state.agent_state), dims = 1)][:] 61 | rtrue = [coord[1] for coord = argmax(world_state.environment_state.reward_location, dims = 1)][:] 62 | 63 | #calculate reward prediction accuracy 64 | i1, i2 = (Naction + Nstates + 2), (Naction + Nstates + 1 + Nstates) #indices of corresponding output 65 | rpred = [coord[1] for coord = argmax(agent_output[i1:i2, :], dims = 1)][:] #extract prediction output 66 | inds = findall(exploit .& active) #only consider exploitation 67 | rew_preds[inds, iter] = Float64.(rpred .== rtrue)[inds] #store binary 'success' data 68 | 69 | ### calculate state accuracy ### 70 | i1, i2 = (Naction + 1 + 1), (Naction + 1 + Nstates) #indices of corresponding output 71 | spred = [coord[1] for coord = argmax(agent_output[i1:i2, :], dims = 1)][:] #extract prediction output 72 | inds = findall((old_rew .< 0.5) .& active) #ignore teleportation step 73 | state_preds[inds, iter] = Float64.(spred .== strue)[inds] #store binary 'success' data 74 | 75 | exploit[old_rew .> 0.5] .= true #indicate which episodes are in the exploitation phase 76 | rew, agent_input, a = zeropad_data(rew, agent_input, a, active) #mask if episode is finished 77 | end 78 | 79 | println(seed, " ", epoch, " ", nanmean(rew_preds), " ", nanmean(state_preds)) 80 | results[seed][epoch] = Dict("rew" => nanmean(rew_preds), "state" => nanmean(state_preds)) #store result 81 | end 82 | end 83 | 84 | @save "$(datadir)/internal_model_accuracy.bson" results 85 | 86 | -------------------------------------------------------------------------------- /computational_model/analysis_scripts/results/example_rollout.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/analysis_scripts/results/example_rollout.bson -------------------------------------------------------------------------------- /computational_model/analysis_scripts/run_all_analyses.jl: -------------------------------------------------------------------------------- 1 | # in this script, we call all of the model analysis functions. 
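# (added usage sketch, assumed invocation) from the repository root:
#   cd computational_model/analysis_scripts
#   julia run_all_analyses.jl
# each included script below can also be run on its own in the same way, or interactively
# from a Julia REPL; they all activate the computational_model project via anal_utils.jl.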
2 | # this may take a while to run unless you have a very big computer 3 | 4 | println("running all analyses") 5 | tic = time() 6 | 7 | include("analyse_human_data.jl") 8 | include("calc_human_prior.jl") 9 | 10 | include("analyse_rollout_timing.jl") 11 | include("repeat_human_actions.jl") 12 | include("perf_by_rollout_number.jl") 13 | include("compare_perf_without_rollout.jl") 14 | include("shuffle_rollout_times.jl") 15 | include("behaviour_by_success.jl") 16 | include("model_replay_analyses.jl") 17 | include("rollout_as_pg.jl") 18 | #include("analyse_by_N.jl") [need to re-train models or include pretrained models] 19 | include("estimate_num_mazes.jl") 20 | 21 | include("compare_maze_path_lengths.jl") 22 | #include("analyse_hp_sweep.jl") [need to re-train models or include pretrained models] 23 | include("eval_value_function.jl") 24 | #include("quantify_internal_model.jl") [need to re-train models or include full training run in pretrained models] 25 | #include("analyse_variable_rollouts.jl") [need to re-train models or include pretrained models] 26 | 27 | println("\nFinished after ", (time()-tic)/60/60, " hours.") 28 | -------------------------------------------------------------------------------- /computational_model/analysis_scripts/shuffle_rollout_times.jl: -------------------------------------------------------------------------------- 1 | #in this script, we evaluate the performance of our original model 2 | #and compare this to a model where the rollout times have been shuffled 3 | #to see whether the structured rollout timings of the agent are important for performance 4 | 5 | # load scripts and model 6 | include("anal_utils.jl") 7 | using ToPlanOrNotToPlan 8 | 9 | println("comparing performance with real and shuffled rollout times") 10 | 11 | epoch = plan_epoch #model training epoch to use for evaluation (default to final epoch) 12 | results = Dict() #container for storing results 13 | batch = 50000 #number of episodes to simulate 14 | 15 | for seed = seeds #for each independently trained model 16 | 17 | plan_ts, Nact, Nplan = [], [], [] #containers for storing data 18 | results[seed] = Dict() #dict for this model 19 | for shuffle = [false; true] #run both the non-shuffled and shuffled replays 20 | Random.seed!(1) #set random seed for identical arenas across the two scenarios 21 | 22 | # load model parameters and create environment 23 | network, opt, store, hps, policy, prediction = recover_model("$(loaddir)N100_T50_Lplan8_seed$(seed)_$epoch") 24 | model_properties, environment, model_eval = build_environment( 25 | hps["Larena"], hps["Nhidden"], hps["T"], Lplan = hps["Lplan"], greedy_actions = greedy_actions, no_planning = shuffle) 26 | m = ModularModel(model_properties, network, policy, prediction, forward_modular) #construct model 27 | 28 | #extract some useful parameters 29 | ed = environment.dimensions 30 | Nhidden = m.model_properties.Nhidden 31 | T = ed.T 32 | 33 | #initialize environment 34 | world_state, agent_input = environment.initialize(zeros(2), zeros(2), batch, m.model_properties, initial_params = []) 35 | agent_state = world_state.agent_state 36 | h_rnn = m.network[GRUind].cell.state0 .+ Float32.(zeros(Nhidden, batch)); #expand hidden state 37 | 38 | rews, as = [], [] 39 | rew = zeros(batch) 40 | iter = 0 41 | while any(world_state.environment_state.time .< (T+1 - 1e-2)) #run until completion 42 | iter += 1 #count iteration 43 | h_rnn, agent_output, a = m.forward(m, ed, agent_input, h_rnn) #RNN step 44 | a[rew .> 0.5] .= 1f0 #no 'planning' at reward 45 | 46 | if 
shuffle #if we're shuffling the replay times 47 | for b = 1:batch #for each episode 48 | if iter in plan_ts[b] #this is a shuffled time 49 | if rew[b] < 0.5 #if we're not at the reward 50 | a[b] = 5f0 #perform a rollout 51 | else #if we're at the reward location, resample a new rollout iteration 52 | remaining = Set(iter+1:Nact[b]-3) #set of remaining iteration 53 | options = setdiff(remaining, plan_ts[b]) #consider the ones where we are not already planning to do a rollout 54 | if length(options) > 0 #if there are other iterations left 55 | push!(plan_ts[b], rand(options)) #sample a new rollout iteration 56 | end 57 | end 58 | end 59 | end 60 | end 61 | 62 | active = (world_state.environment_state.time .< (T+1 - 1e-2)) #active episodes 63 | #take an environment step 64 | rew, agent_input, world_state, predictions = environment.step( 65 | agent_output, a, world_state, environment.dimensions, m.model_properties, 66 | m, h_rnn 67 | ) 68 | rew, agent_input, a = zeropad_data(rew, agent_input, a, active) #mask episodes that are finished 69 | push!(rews, rew); push!(as, a) #store rewards and actions from this iteration 70 | end 71 | 72 | rews, as = [reduce(vcat, arr) for arr = [rews, as]] #combine rewards and actions into array 73 | Nact, Nplan = sum(as .> 0.5, dims = 1), sum(as .> 4.5, dims = 1) #number of actions and number of plans in each episode 74 | plan_ts = [Set(randperm(max(Nact[b]-3, Nplan[b]))[1:Nplan[b]]) for b = 1:batch] #resample the iterations at which I should plan (avoiding last iterations) 75 | 76 | #print some summary data 77 | println("\n", seed, " shuffled: ", shuffle) 78 | println("reward: ", sum(rews .> 0.5) / batch) #reward 79 | println("rollout fraction: ", sum(as .> 4.5) / sum(as .> 0.5)) #planning fraction 80 | results[seed][shuffle] = rews #store the rewards for this experiment 81 | 82 | end 83 | end 84 | 85 | #store result 86 | @save "$(datadir)/performance_shuffled_planning.bson" results 87 | -------------------------------------------------------------------------------- /computational_model/figs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/figs/.gitkeep -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed61_1000_hps.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed61_1000_hps.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed61_1000_mod.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed61_1000_mod.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed61_1000_opt.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed61_1000_opt.bson -------------------------------------------------------------------------------- 
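The model files listed in this directory follow the naming scheme used by the loading code above (recover_model("$(loaddir)N100_T50_Lplan8_seed$(seed)_$epoch")); the fields appear to encode the hidden-layer size N, episode duration T, rollout length Lplan, random seed, and training epoch. A rough sketch of the convention (the helper name is purely illustrative and not part of the repository):

model_tag(N, T, Lplan, seed, epoch) = "N$(N)_T$(T)_Lplan$(Lplan)_seed$(seed)_$(epoch)"
model_tag(100, 50, 8, 61, 1000)  # returns "N100_T50_Lplan8_seed61_1000"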
/computational_model/models/N100_T50_Lplan8_seed61_1000_policy.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed61_1000_policy.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed61_1000_prediction.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed61_1000_prediction.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed61_1000_progress.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed61_1000_progress.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed62_1000_hps.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed62_1000_hps.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed62_1000_mod.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed62_1000_mod.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed62_1000_opt.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed62_1000_opt.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed62_1000_policy.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed62_1000_policy.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed62_1000_prediction.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed62_1000_prediction.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed62_1000_progress.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed62_1000_progress.bson 
-------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed63_1000_hps.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed63_1000_hps.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed63_1000_mod.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed63_1000_mod.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed63_1000_opt.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed63_1000_opt.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed63_1000_policy.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed63_1000_policy.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed63_1000_prediction.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed63_1000_prediction.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed63_1000_progress.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed63_1000_progress.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed64_1000_hps.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed64_1000_hps.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed64_1000_mod.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed64_1000_mod.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed64_1000_opt.bson: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed64_1000_opt.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed64_1000_policy.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed64_1000_policy.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed64_1000_prediction.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed64_1000_prediction.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed64_1000_progress.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed64_1000_progress.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed65_1000_hps.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed65_1000_hps.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed65_1000_mod.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed65_1000_mod.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed65_1000_opt.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed65_1000_opt.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed65_1000_policy.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed65_1000_policy.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed65_1000_prediction.bson: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed65_1000_prediction.bson -------------------------------------------------------------------------------- /computational_model/models/N100_T50_Lplan8_seed65_1000_progress.bson: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/models/N100_T50_Lplan8_seed65_1000_progress.bson -------------------------------------------------------------------------------- /computational_model/plot_paper_figs/plot_all.jl: -------------------------------------------------------------------------------- 1 | #### main text figures #### 2 | include("plot_utils.jl") 3 | 4 | using Suppressor 5 | @suppress_err begin 6 | 7 | #Fig 2 8 | println("\nplotting response time figure") 9 | include("plot_fig_RTs.jl") 10 | 11 | #Fig 3 12 | println("\nplotting RNN behavior figure") 13 | include("plot_fig_mechanism_behav.jl") 14 | 15 | #Fig 4 16 | println("\nplotting replay figure") 17 | include("plot_fig_replays.jl") 18 | 19 | #Fig 5 20 | println("\nplotting PG figure") 21 | include("plot_fig_mechanism_neural.jl") 22 | 23 | #### supplementary figures #### 24 | 25 | #Sfig 1 26 | println("\nplotting supplementary Euclidean comparison") 27 | include("plot_supp_human_euc_comparison.jl") 28 | 29 | #Sfig 2 30 | println("\nplotting supplementary human data") 31 | include("plot_supp_human_summary.jl") 32 | 33 | #Sfig 3 34 | println("\nplotting supplementary learning analyses") 35 | include("plot_supp_fig_network_size.jl") 36 | 37 | #Sfig 4 38 | println("\nplotting supplementary RT by step within trial") 39 | include("plot_supp_RT_by_step.jl") 40 | 41 | #Sfig 5 42 | println("\nplotting supplementary hp sweep") 43 | #include("plot_supp_hp_sweep.jl") 44 | 45 | #Sfig 6 46 | println("\nplotting supplementary value function analyses") 47 | include("plot_supp_values.jl") 48 | 49 | #Sfig 7 50 | println("\nplotting supplementary internal model") 51 | include("plot_supp_internal_model.jl") 52 | 53 | #Sfig 8 54 | println("\nplotting supplementary analyses with variable rollout durations") 55 | include("plot_supp_variable.jl") 56 | 57 | #Sfig 9 58 | println("\nplotting supplementary exploration analyses") 59 | include("plot_supp_exploration.jl") 60 | 61 | #Sfig 13 62 | println("\nplotting supplementary re-plan probabilities") 63 | include("plot_supp_plan_probs.jl") 64 | 65 | 66 | 67 | end 68 | 69 | -------------------------------------------------------------------------------- /computational_model/plot_paper_figs/plot_fig_mechanism_behav.jl: -------------------------------------------------------------------------------- 1 | #This script plots Figure 3 of Jensen et al. 
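# (inferred from the @load calls below) this script expects the following result files in
# datadir: perf_by_n_N100_Lplan8.bson, performance_with_out_planning.bson,
# performance_shuffled_planning.bson, example_rollout.bson, and
# causal_N100_Lplan8_<seed>_<epoch>.bson for each seed; an example_rollout.bson is
# included in analysis_scripts/results/.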
2 | 3 | include("plot_utils.jl") 4 | using ToPlanOrNotToPlan 5 | using Flux 6 | 7 | bot, top = 0.0, 1.0 8 | fig = figure(figsize = (17*cm, 3.0*cm)) 9 | 10 | #plot performance and entropy by number of rollouts 11 | 12 | #start by loading and extracting data 13 | @load "$datadir/perf_by_n_N100_Lplan8.bson" res_dict 14 | seeds = sort([k for k = keys(res_dict)]) 15 | Nseed = length(seeds) 16 | ms1, ms2, bs, es1, es2 = [], [], [], [], [] 17 | dists = 1:6; bydist = [] 18 | for (is, seed) = enumerate(seeds) #for each model 19 | #time within trial, distance to goal, and policy 20 | dts, mindists, policies = [res_dict[seed][k] for k = ["dts"; "mindists"; "policies"]] 21 | #select episodes where the trial finished for all rollout numbers 22 | keepinds = findall((.~isnan.(sum(dts, dims = (1,3))[:])) .& (mindists[:, 2] .>= 0)) 23 | new_dts = dts[:, keepinds, :] 24 | new_mindists = mindists[keepinds, 2] 25 | policies = policies[:, keepinds, :, :, :] 26 | #mean performance across episodes with (m1) and without (m2) rollout feedback 27 | m1, m2 = mean(new_dts[1,:,:], dims = 1)[:], mean(new_dts[2,:,:], dims = 1)[:] 28 | push!(bydist, reduce(vcat, [mean(new_dts[1,new_mindists .== dist,:], dims = 1) for dist = dists])) 29 | push!(ms1, m1); push!(ms2, m2); push!(bs, mean(new_mindists)) #also store optimal (bs)) 30 | p1, p2 = policies[1, :, :, :, :], policies[2, :, :, :, :] #extract log policies 31 | p1, p2 = [p .- Flux.logsumexp(p, dims = 4) for p = [p1, p2]] #normalize 32 | e1, e2 = [-sum(exp.(p) .* p, dims = 4)[:, :, :, 1] for p = [p1, p2]] #entropy 33 | m1, m2 = [mean(e[:,:,1], dims = 1)[:] for e = [e1,e2]] #only consider entropy of first action 34 | push!(es1, m1); push!(es2, m2) #store entropies 35 | end 36 | bydist = reduce((a,b) -> cat(a, b, dims = 3), bydist) 37 | bydist = mean(bydist, dims = 3)[:, :, 1] 38 | #concatenate across seeds 39 | ms1, ms2, es1, es2 = [reduce(hcat, arr) for arr = [ms1, ms2, es1, es2]] 40 | # compute mean and std across seeds 41 | m1, s1 = mean(ms1, dims = 2)[:], std(ms1, dims = 2)[:]/sqrt(Nseed) 42 | m2, s2 = mean(ms2, dims = 2)[:], std(ms2, dims = 2)[:]/sqrt(Nseed) 43 | me1, se1 = mean(es1, dims = 2)[:], std(es1, dims = 2)[:]/sqrt(Nseed) 44 | me2, se2 = mean(es2, dims = 2)[:], std(es2, dims = 2)[:]/sqrt(Nseed) 45 | nplans = (1:length(m1)) .- 1 # 46 | 47 | # plot performance vs number of rollouts 48 | grids = fig.add_gridspec(nrows=1, ncols=2, left=0.00, right=0.33, bottom = bot, top = 1.0, wspace=0.50) 49 | ax = fig.add_subplot(grids[1,1]) 50 | ax.plot(nplans,m1, ls = "-", color = col_p, label = "agent") #mean 51 | ax.fill_between(nplans,m1-s1,m1+s1, color = col_p, alpha = 0.2) #standard error 52 | plot([nplans[1]; nplans[end]], ones(2)*mean(bs), color = col_c, ls = "-", label = "optimal") #optimal baseline 53 | ax.plot(nplans,m2, ls = ":", color = col_c, label = "ctrl") #mean 54 | ax.fill_between(nplans,m2-s2,m2+s2, color = col_c, alpha = 0.2) #standard error 55 | legend(frameon = false, loc = "upper right", fontsize = fsize_leg, handlelength=1.5, handletextpad=0.5, borderpad = 0.0, labelspacing = 0.05) 56 | xlabel("# rollouts") 57 | ylabel("steps to goal") 58 | ylim(0.9*mean(bs), maximum(m1+s1)+0.1*mean(bs)) 59 | xticks([0;5;10;15]) 60 | 61 | # plot entropy vs number of rollouts 62 | ax = fig.add_subplot(grids[1,2]) 63 | ax.plot(nplans,me1, ls = "-", color = col_p, label = "agent") #mean 64 | ax.fill_between(nplans,me1-se1,me1+se1, color = col_p, alpha = 0.2) #standard error 65 | plot([nplans[1]; nplans[end]], ones(2)*log(4), color = col_c, ls = "-", label = "uniform") 
#entropy of uniform policy 66 | legend(frameon = false, fontsize = fsize_leg, handlelength=1.5, handletextpad=0.5, borderpad = 0.0, labelspacing = 0.05) 67 | xlabel("# rollouts") 68 | ylabel("entropy (nats)", labelpad = 1) 69 | ylim(0, 1.1*log(4)) 70 | xticks([0;5;10;15]) 71 | yticks([0; 1]) 72 | 73 | # plot performance with and without rollouts 74 | 75 | # load data across different random seeds 76 | @load "$(datadir)/performance_with_out_planning.bson" results 77 | ress = zeros(length(seeds), 2) 78 | for (i, plan) = enumerate([true; false]) 79 | for (iseed, seed) = enumerate(seeds) 80 | rews = results[seed][plan] 81 | ress[iseed, i] = sum(rews) / size(rews, 1) 82 | end 83 | end 84 | m, s = mean(ress, dims = 1)[:], std(ress, dims = 1)[:]/sqrt(length(seeds)) # mean and sem across seeds 85 | println("performance with and without rollouts:") #print result 86 | println(m, " ", s) 87 | 88 | # also add shuffled rollouts 89 | # load result across random seeds 90 | @load "$(datadir)/performance_shuffled_planning.bson" results 91 | ress_shuff = zeros(length(seeds), 2) 92 | for (i, shuffle) = enumerate([true; false]) 93 | for (iseed, seed) = enumerate(seeds) 94 | rews = results[seed][shuffle] 95 | ress_shuff[iseed, i] = sum(rews) / size(rews, 2) 96 | end 97 | end 98 | m, s = mean(ress_shuff, dims = 1)[:], std(ress_shuff, dims = 1)[:]/sqrt(length(seeds)) #mean and standard error 99 | println("shuffled performance: ", m, " (", s, ")") #print result 100 | 101 | ress = [ress ress_shuff[:, 1:1]] 102 | 103 | # plot result 104 | grids = fig.add_gridspec(nrows=1, ncols=1, left=0.405, right=0.505, bottom = 0, top = 1, wspace=0.15) 105 | ax = fig.add_subplot(grids[1,1]) 106 | plot_comparison(ax, ress; xticklabs=["rollout", "no roll", "shuffled"], ylab = "avg. reward", col = col_p, col2 = 1.2*col_p, yticks = [6;7;8], rotation = 45) 107 | 108 | # plot example goal directed and non-goal directed rollouts 109 | grids = fig.add_gridspec(nrows=1, ncols=1, left=0.53, right=0.69, bottom = -0.03, top = 0.80, wspace=0.15) 110 | ax = fig.add_subplot(grids[1,1]) 111 | 112 | # load example rollouts 113 | @load "$datadir/example_rollout.bson" store 114 | plan_state, ps, ws, state = store[10] 115 | rew_loc = state_from_onehot(4, ps) 116 | 117 | # plot arena 118 | arena_lines(ps, ws, 4, rew=false, col="k", rew_col = "k", col_arena = "k", lw_arena = lw_arena, lw_wall = lw_wall) 119 | labels = ["successful"; "unsuccessful"] 120 | for i = 1:2 #for successful and unsuccessful 121 | col = [col_p1, col_p2][i] #different colours 122 | 123 | #extract the rollout paths 124 | plan_state, ps, ws, state = store[[4;5][i]] 125 | plan_state = plan_state[1:sum(plan_state .> 0)] 126 | states = [state state_from_loc(4, plan_state')] #prepend original state 127 | N = size(states, 2) 128 | 129 | for s = 1:(N-1) #plot each line segment 130 | x1, y1, x2, y2 = [states[:, s]; states[:, s+1]] 131 | if s == 1 label = labels[i] else label = nothing end #labels for legend 132 | if x1 == 4 && x2 == 1 #pass through right wall 133 | ax.plot([x1; 4.5], [y1; 0.5*(y1+y2)], color = col, lw = linewidth) 134 | ax.plot([0.5; x2], [0.5*(y1+y2); y2], color = col, lw = linewidth) 135 | elseif x1 == 1 && x2 == 4 #left wall 136 | ax.plot([x2; 4.5], [y2; 0.5*(y1+y2)], color = col, lw = linewidth) 137 | ax.plot([0.5; x1], [0.5*(y1+y2); y1], color = col, lw = linewidth) 138 | elseif y1 == 4 && y2 == 1 #top wall 139 | ax.plot([x1; 0.5*(x1+x2)], [y1; 4.5], color = col, lw = linewidth) 140 | ax.plot([0.5*(x1+x2); x2], [0.5; y2], color = col, lw = linewidth) 141 | elseif 
y1 == 1 && y2 == 4 #bottom wall 142 | ax.plot([x2; 0.5*(x1+x2)], [y2; 4.5], color = col, lw = linewidth) 143 | ax.plot([0.5*(x1+x2); x1], [0.5; y1], color = col, lw = linewidth) 144 | else #just a normal line segment 145 | ax.plot([x1; x2], [y1; y2], color = col, label = label, lw = linewidth) 146 | end 147 | end 148 | end 149 | 150 | ax.scatter([state[1]], [state[2]], color = col_p, s = 150, zorder = 1000) #original loc 151 | ax.plot([rew_loc[1]], [rew_loc[2]], color = "k", marker="x", markersize=12, ls="", mew=3) #goal 152 | ax.legend(frameon = false, fontsize = fsize_leg, loc = "upper center", bbox_to_anchor = (0.5, 1.28), handlelength=1.5, handletextpad=0.5, borderpad = 0.0, labelspacing = 0.05) 153 | 154 | # plot change in policy with successful and unsuccessful rollouts 155 | grids = fig.add_gridspec(nrows=1, ncols=2, left=0.78, right=1.00, bottom = 0, top = top, wspace=0.10) 156 | for i = 1:2 #rewarded and non-rewarded rollout 157 | ms = [] 158 | for seed = seeds #iterate through random seeds 159 | @load "$(datadir)/causal_N100_Lplan8_$(seed)_$(plan_epoch).bson" data 160 | #rollouts action probability under new and old policy 161 | p_simulated_actions, p_simulated_actions_old = data["p_simulated_actions"], data["p_simulated_actions_old"] 162 | #rollout probabilities 163 | p_initial_sim, p_continue_sim = data["p_initial_sim"], data["p_continue_sim"] 164 | p_simulated_actions ./= (1 .- p_continue_sim) #renormalize over actions 165 | p_simulated_actions_old ./= (1 .- p_initial_sim) #renormalize over actions 166 | inds = findall(.~isnan.(sum(p_simulated_actions, dims = 1)[:])) #make sure we have data for both scenarios 167 | push!(ms, [mean(p_simulated_actions_old[i, inds]); mean(p_simulated_actions[i, inds])]) #mean for new and old 168 | end 169 | ms = reduce(hcat, ms) #concatenate across seeds 170 | m3, s3 = mean(ms, dims = 2)[1:2], std(ms, dims = 2)[1:2] / sqrt(length(seeds)) #mean and sem across seeds 171 | 172 | # plot results 173 | ax = fig.add_subplot(grids[1,i]) 174 | ax.bar(1:2, m3, yerr = s3, color = [col_p1, col_p2][i], capsize = capsize) 175 | # plot individual data points 176 | shifts = 1:size(ms, 2); shifts = (shifts .- mean(shifts))/std(shifts)*0.2 177 | ax.scatter([1 .+ shifts; 2 .+ shifts], [ms[1, :]; ms[2, :]], color = col_point, marker = ".", s = 15, zorder = 100) 178 | ax.set_xticks(1:2) 179 | ax.set_xticklabels(["pre"; "post"]) 180 | if i == 1 #successful rollout 181 | ax.set_ylabel(L"$\pi(\hat{a}_1)$", labelpad = 0) 182 | ax.set_title("succ.", fontsize = fsize) 183 | ax.set_yticks([0.1;0.3;0.5;0.7]) 184 | else #unsuccessful rollout 185 | ax.set_title("unsucc.", fontsize = fsize) 186 | ax.set_yticks([]) 187 | end 188 | #set some parameters 189 | ax.set_ylim(0.0, 0.8) 190 | ax.set_xlim([0.4; 2.6]) 191 | ax.axhline(0.25, color = color = col_c, ls = "-") 192 | end 193 | 194 | # add labels and save 195 | y1, y2 = 1.16, 0.46 196 | x1, x2, x3, x4, x5 = -0.05, 0.15, 0.34, 0.51, 0.71 197 | plt.text(x1,y1,"A"; ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize_label, ) 198 | plt.text(x2,y1,"B";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize_label,) 199 | plt.text(x3,y1,"C";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize_label,) 200 | plt.text(x4,y1,"D";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize_label,) 201 | plt.text(x5,y1,"E";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize_label,) 202 | 203 | 
savefig("./figs/fig_mechanism_behav.pdf", bbox_inches = "tight") 204 | savefig("./figs/fig_mechanism_behav.png", bbox_inches = "tight") 205 | close() 206 | 207 | 208 | -------------------------------------------------------------------------------- /computational_model/plot_paper_figs/plot_fig_mechanism_neural.jl: -------------------------------------------------------------------------------- 1 | #This script plots panels C-E of Figure 5 from Jensen et al. 2 | 3 | include("plot_utils.jl") 4 | using MultivariateStats 5 | 6 | fig = figure(figsize = (17*cm, 3.0*cm)) 7 | 8 | # load data 9 | #@load "$(datadir)planning_as_pg.bson" res_dict 10 | @load "$(datadir)planning_as_pg_new.bson" res_dict; println("using new!!") 11 | #@load "$(datadir)planning_as_pg_exp.bson" res_dict; println("using exp!!") 12 | seeds = sort([k for k = keys(res_dict)]) 13 | 14 | # PCA plot of mean hidden state updates 15 | seed = 62 # example seed to plot 16 | alphas = res_dict[seed]["jacs"] # true hidden state updates 17 | actions = Int.(res_dict[seed]["sim_as"]) # rollout actions 18 | betas = res_dict[seed]["sim_gs"] # predicted hidden state updates 19 | betas = reduce(vcat, [betas[i:i, :, actions[i]] for i = 1:length(actions)]) # concatenate 20 | pca = MultivariateStats.fit(PCA, (betas .- mean(betas, dims = 1))'; maxoutdim=3) # perform PCA on the predicted changes 21 | Zb = predict(pca, (betas .- mean(betas, dims = 1))') # project into PC space 22 | Za = predict(pca, (alphas .- mean(alphas, dims = 1))') # project into PC space 23 | 24 | # plot result 25 | grids = fig.add_gridspec(nrows=1, ncols=1, left=0.19, right=0.42, bottom = -0.24, top = 0.98, wspace=0.05) 26 | ax = fig.add_subplot(grids[1,1], projection="3d") 27 | 28 | cols = [col_c, col_p, "g", "c"] # colours to use 29 | for a = 1:4 # for each action 30 | meanb = mean(Zb[:, actions .== a], dims = 2)[:] # mean predicted 31 | meanb = meanb / sqrt(sum(meanb.^2)) # normalize vector 32 | ax.plot3D([0; meanb[1]], [0; meanb[2]], [0; meanb[3]], ls = "-", color = cols[a], lw = 2) # plot predicted 33 | ax.scatter3D([meanb[1]], [meanb[2]], [meanb[3]], color = cols[a], s = 50) # plot end points 34 | meana = mean(Za[:, actions .== a], dims = 2)[:] # mean empirical 35 | meana = meana / sqrt(sum(meana.^2)) # normalize vector 36 | ax.plot3D([0; meana[1]], [0; meana[2]], [0; meana[3]], ls = ":", color = cols[a], lw = 3) # plot empirical 37 | end 38 | # add some labels 39 | ax.plot3D(zeros(2), zeros(2), zeros(2), ls = "-", color = "k", lw = 2, label = L"{\bf \alpha}^\mathrm{PG}_{1}") 40 | ax.plot3D(zeros(2), zeros(2), zeros(2), ls = ":", color = "k", lw = 3, label = L"{\bf \alpha}^\mathrm{RNN}") 41 | ax.set_xlabel("PC 1", labelpad = -16, rotation = 9); 42 | ax.set_ylabel("PC 2", labelpad = -17, rotation = 107); 43 | ax.set_zlabel("PC 3", labelpad = -17, rotation = 92); 44 | 45 | # set some plot parameters 46 | ax.view_init(elev=35., azim=75.) 
47 | ax.set_xticks([]); ax.set_yticks([]); ax.set_zticks([]) 48 | ax.legend(frameon = false, ncol = 2, bbox_to_anchor = (0.52, 1.05), loc = "upper center",columnspacing=1, 49 | fontsize = fsize_leg, borderpad = 0.0, labelspacing = 0.2, handlelength = 1.3,handletextpad=0.4, handleheight=1) 50 | 51 | # plot discrepancy between expected and empirical 52 | meanspca, meanspca2 = [[[] for j = 1:3] for _ = 1:2] 53 | # iterate through alpha_RNN, alpha_RNN_ctrl, and alpha_RNN_ctrl2 [second action switched] 54 | for (ij, jkey) = enumerate(["jacs"; "jacs_shift"; "jacs_shift2"]) 55 | for seed = seeds # iterate through models 56 | sim_as, sim_a2s = [res_dict[seed][k] for k = ["sim_as", "sim_a2s"]] # rollout actions 57 | jacs, gs, gs2 = [copy(res_dict[seed][k]) for k = [jkey, "sim_gs", "sim_gs2"]] # alpha_RNN, alpha_PG1, alpha_PG2 58 | inds, inds2 = 1:length(sim_as), findall(.~isnan.(sim_a2s)) # actions to consider (min rollout length of 2 for alpha_PG2) 59 | betas = reduce(vcat,[gs[i:i,:,Int(sim_as[i])] for i=inds]) # alpha_PG1 60 | betas2 = reduce(vcat,[gs2[i:i,:,Int(sim_a2s[i])] for i=inds2]) # alpha_PG2 61 | 62 | #pis = res_dict[seed]["all_pis"][:, [CartesianIndex()], :] # inser middle dimension 63 | #betasexp = sum(gs .* pis, dims = 3)[:, :, 1] # E_a\sim(pi) [alpha_PG1] (this is zero) 64 | 65 | jacs, betas, betas2 = [arr .- mean(arr, dims = 1) for arr = [jacs, betas, betas2]] # mean-subtract 66 | jacs, betas, betas2 = [arr ./ sqrt.(sum(arr.^2, dims = 2)) for arr = [jacs, betas, betas2]] # normalize 67 | # compute angles in PC space 68 | pca, pca2 = [MultivariateStats.fit(PCA, (beta .- mean(beta, dims = 1))'; maxoutdim=3) for beta = [betas, betas2]] 69 | Za,Zb = [predict(pca, (vals .- mean(vals, dims = 1))') for vals = [jacs, betas]] # project into PC space 70 | Za2,Zb2 = [predict(pca2, (vals .- mean(vals, dims = 1))') for vals = [jacs[inds2,:], betas2]] # project into PC space 71 | Za,Za2,Zb,Zb2 = [arr ./ sqrt.(sum(arr.^2, dims = 2)) for arr = [Za, Za2, Zb, Zb2]] # normalize 72 | push!(meanspca[ij], mean(sum(Za .* Zb, dims = 2)[:, 1, :], dims = 1)[1]) # cos \theta for alpha_PG1 73 | push!(meanspca2[ij], mean(sum(Za2 .* Zb2, dims = 2)[:, 1, :], dims = 1)[1]) # cos \theta for alpha_PG2 74 | end 75 | end 76 | meanspca, meanspca2 = [reduce(hcat, ms) for ms = [meanspca, meanspca2]] # concatenate results 77 | 78 | # plot results 79 | grids = fig.add_gridspec(nrows=1, ncols=2, left=0.49, right=0.76, bottom = 0.0, top = 1.0, wspace=0.25) 80 | labs = repeat([L"$1^\mathrm{st}$ action", L"$2^\mathrm{nd}$ action"], outer = 2) 81 | global limy = nothing # instantiate ylim parameter 82 | for (ires, res) = enumerate([meanspca, meanspca2]) # first action and seconnd action 83 | ax = fig.add_subplot(grids[1,ires]) 84 | res = res[:, [1; ires+1]] # data for this action 85 | mus, ss = mean(res, dims = 1)[:], std(res, dims = 1)[:]/sqrt(length(seeds)) # summary statistics 86 | println("action $ires: mu = ", mus, ", sem = ", ss) # print result 87 | ax.bar([1;2], mus, yerr = ss, color = col_p, capsize = capsize) # bar plot 88 | # plot individual data points 89 | shifts = 1:size(res, 1); shifts = (shifts .- mean(shifts))/std(shifts)*0.2 # add some jitter 90 | ax.scatter([1 .+ shifts; 2 .+ shifts], [res[:, 1]; res[:, 2]], color = col_point, marker = ".", s = 15, zorder = 100) 91 | ax.axhline(0, color = col_c, lw = 2) # baseline 92 | if ires == 1 # set some plotting parameters 93 | vmin, vmax = minimum(mus-ss), maximum(mus+ss) 94 | global limy = [vmin - 0.1*(vmax-vmin); vmax+0.1*(vmax-vmin)] 95 | ax.set_ylabel(L"$\cos 
\theta$", labelpad = -0.07) 96 | else 97 | global limy = limy 98 | ax.set_yticks([]) 99 | end 100 | # set some labels etc. 101 | ax.set_xticks([1;2]) 102 | ax.set_xticklabels([L"${\bf \alpha}^\mathrm{RNN}$", L"${\bf \alpha}^\mathrm{RNN}_\mathrm{ctrl}$"]) 103 | ax.set_ylim(limy) 104 | ax.set_title(labs[ires], fontsize = fsize, pad = 0.04) 105 | ax.set_xlim(0.5, 2.5) 106 | end 107 | 108 | # Plot rollout frequency by network size 109 | 110 | # load our stored data 111 | @load "$datadir/rew_and_plan_by_n.bson" res_dict 112 | # extract the relevant data 113 | meanrews, pfracs, seeds, Nhiddens, epochs = [res_dict[k] for k = ["meanrews", "planfracs", "seeds", "Nhiddens", "epochs"]]; 114 | 115 | # only consider second epoch onwards 116 | i1, N = 2, sum(epochs .<= plan_epoch) 117 | mms = mean(meanrews, dims = 2)[:, 1, i1:N] # mean reward across seeds 118 | sms = std(meanrews, dims = 2)[:, 1, i1:N] / sqrt(length(seeds)) # standard error 119 | mps = mean(pfracs, dims = 2)[:, 1, i1:N] # mean rollout frequency across seeds 120 | sps = std(pfracs, dims = 2)[:, 1, i1:N] / sqrt(length(seeds)) # standard error 121 | 122 | # plot the data 123 | grids = fig.add_gridspec(nrows=1, ncols=1, left=0.87, right=1.0, bottom = 0, top = 1.0, wspace=0.15) 124 | ax = fig.add_subplot(grids[1,1]) 125 | ax.axhline(0.2, ls = "-", color = col_c) # baseline 126 | for (ihid, Nhidden) = enumerate(Nhiddens) # for each network sizes 127 | frac = 0.45*(Nhidden - minimum(Nhiddens))/(maximum(Nhiddens) - minimum(Nhiddens)) .+ 0.76 128 | col = col_p * frac # colour 129 | # plot mean and standard error 130 | ax.plot(mms[ihid, :], mps[ihid, :], ls = "-", color = col, label = Nhidden) 131 | ax.fill_between(mms[ihid, :], mps[ihid, :]-sps[ihid, :], mps[ihid, :]+sps[ihid, :], color = col, alpha = 0.2) 132 | end 133 | # set some labels and other plotting parameters 134 | ax.set_xlabel("mean reward") 135 | ax.set_ylabel(L"$p$"*"(rollout)") 136 | ax.set_ylim(0, 0.65) 137 | ax.set_xticks([0;4;8]) 138 | ax.legend(frameon = false, fontsize = fsize_leg, handlelength=1.5, handletextpad=0.5, borderpad = 0.0, 139 | labelspacing = 0.05, loc = "lower center", bbox_to_anchor = (0.75, -0.035)) 140 | 141 | # save figure 142 | savefig("./figs/fig_mechanism_neural.pdf", bbox_inches = "tight") 143 | savefig("./figs/fig_mechanism_neural.png", bbox_inches = "tight") 144 | close() 145 | 146 | -------------------------------------------------------------------------------- /computational_model/plot_paper_figs/plot_supp_RT_by_step.jl: -------------------------------------------------------------------------------- 1 | #This script plots Figure S4 of Jensen et al. 
2 | 3 | include("plot_utils.jl") 4 | 5 | # start by loading the data for our human participants 6 | @load "$datadir/RT_predictions_N100_Lplan8_1000.bson" data 7 | RTs_by_u, pplans_by_u, dists_by_u, steps_by_u = [data[k] for k = ["RTs_by_u", "pplans_by_u", "dists_by_u", "steps_by_u"]]; 8 | 9 | fig = figure(figsize = (10*cm, 3*cm)) 10 | grids = fig.add_gridspec(nrows=1, ncols=2, left=0.00, right=1.00, bottom = 0.0, top = 1.0, wspace=0.35) 11 | 12 | xsteps = 1:3 # steps within trial to consider 13 | xdists = 1:6 # distance to goal to consider 14 | # colors for plots 15 | cols = [[[0;0;0], [0.4;0.4;0.4], [0.7;0.7;0.7]], [[0.00;0.19;0.52], [0.24;0.44;0.77], [0.49;0.69;1.0]]] 16 | 17 | for (idat, dat) = enumerate([RTs_by_u, pplans_by_u]) # for humans and model 18 | fig.add_subplot(grids[1,idat]) 19 | for xstep = xsteps # for each step within trial 20 | mus = zeros(length(dat), length(xdists)) .+ NaN # initialize data array 21 | for u = 1:length(dat) # for each model/participant 22 | for d = xdists # for each distance to goal 23 | # relevant actions 24 | inds = findall((dists_by_u[u] .== d) .& (steps_by_u[u] .== -xstep)) 25 | if length(inds) >= 1 # if we have at least one action satisfying these criteria 26 | mus[u, d] = mean(dat[u][inds]) # store the data 27 | end 28 | end 29 | end 30 | m = nanmean(mus, dims = 1)[:] # mean across models/participants 31 | s = nanstd(mus, dims = 1)[:] ./ sqrt.(sum(.~isnan.(mus), dims = 1)[:]) # standard error 32 | 33 | plot(xdists, m, label = "step = $xstep", color = cols[idat][xstep]) # plot mean 34 | fill_between(xdists, m-s, m+s, color = cols[idat][xstep], alpha = 0.2) # plot standard error 35 | end 36 | xlabel("distance to goal") 37 | if idat == 1 38 | ylabel("thinking time") # human 39 | else 40 | ylabel(L"$\pi($"*"rollout"*L"$)$") # model 41 | end 42 | legend(frameon = false, fontsize = fsize_leg) 43 | end 44 | 45 | # add labels and save figure 46 | y1 = 1.16 47 | x1, x2 = -0.13, 0.45 48 | fsize = fsize_label 49 | plt.text(x1,y1,"A"; ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize, ) 50 | plt.text(x2,y1,"B";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize,) 51 | 52 | savefig("./figs/supp_plan_by_step.pdf", bbox_inches = "tight") 53 | savefig("./figs/supp_plan_by_step.png", bbox_inches = "tight") 54 | close() 55 | 56 | 57 | # compute and print residual correlations between pi(rollout) and human RT 58 | allcors = [] # array to store correlations 59 | for u = 1:length(RTs_by_u) # for each user 60 | mean_sub_RTs, mean_sub_pplans = [], [] # mean subtracted thinking times and rollouts probs 61 | for dist = 1:20 #for each distance-to-goal 62 | for xstep = 1:100 #for each step-within-trial 63 | inds = (dists_by_u[u] .== dist) .& (steps_by_u[u] .== -xstep) # corresponding indices 64 | if sum(inds) >= 2 #require at least 2 data points 65 | new_RTs, new_pplans = RTs_by_u[u][inds], pplans_by_u[u][inds] #find the corresponding RTs and pi(rollout) 66 | mean_sub_RTs = [mean_sub_RTs; new_RTs .- mean(new_RTs)] #subtract mean of RTs and append 67 | mean_sub_pplans = [mean_sub_pplans; new_pplans .- mean(new_pplans)] #subtract mean of pi(rollout) and append 68 | end 69 | end 70 | end 71 | push!(allcors, cor(mean_sub_RTs, mean_sub_pplans)) # store residual correlation for this participant 72 | end 73 | # print result 74 | println("mean and standard error of residual correlation:") 75 | println(mean(allcors), " ", std(allcors)/sqrt(length(allcors))) 76 | 77 | 
-------------------------------------------------------------------------------- /computational_model/plot_paper_figs/plot_supp_exploration.jl: -------------------------------------------------------------------------------- 1 | #This script plots Figure S9 of Jensen et al. 2 | 3 | include("plot_utils.jl") 4 | 5 | fig = figure(figsize = (10.5*cm, 7*cm)) 6 | grids = fig.add_gridspec(nrows=2, ncols=2, left=0.00, right=1.00, bottom = 0, top = 1.0, wspace=0.61, hspace = 0.7) 7 | 8 | # we start by plotting the accuracy of the internal world model as a function of the number of states visited during exploration 9 | 10 | # load data from different model seeds 11 | global ms, unums = [], [] 12 | for seed = seeds 13 | @load "$(datadir)model_exploration_predictions_$(seed)_$plan_epoch.bson" data 14 | global unums, dec_perfs = data 15 | push!(ms, dec_perfs) 16 | end 17 | ms = reduce(hcat, ms) # concatenate across seeds 18 | m, s = mean(ms, dims = 2)[:], std(ms, dims = 2)[:]/sqrt(size(ms, 2)) # mean and standard error across seeds 19 | 20 | # now plot our results 21 | ax = fig.add_subplot(grids[1,1]) 22 | ax.plot(unums, 1 ./ (16 .- unums), color = col_c, label = "optimal", zorder = -1000) 23 | ax.plot(unums, m, color = col_p, label = "agent") 24 | ax.fill_between(unums, m-s, m+s, color = col_p, alpha = 0.2) 25 | ax.set_xlabel("states visited") 26 | ax.set_ylabel("accuracy") 27 | ax.legend(frameon = false, fontsize = fsize_leg) 28 | ax.set_ylim(0.0, 0.7) 29 | ax.set_xlim(unums[1], unums[end]) 30 | 31 | # plot human thinking time against pi(rollout) during exploration 32 | 33 | # load our results 34 | @load "$datadir/RT_predictions_N100_Lplan8_explore_1000.bson" data 35 | allsims, RTs, pplans = [data[k] for k = ["correlations"; "RTs_by_u"; "pplans_by_u"]]; 36 | RTs, pplans = [reduce(vcat, arr) for arr = [RTs, pplans]]; 37 | bins = 0.05:0.05:0.70 38 | xs = 0.5*(bins[1:length(bins)-1] + bins[2:end]) 39 | 40 | # bin data and generate shuffled control 41 | RTs_shuff = RTs[randperm(length(RTs))] 42 | dat = [RTs[(pplans .>= bins[i]) .& (pplans .< bins[i+1])] for i = 1:length(bins)-1] 43 | dat_shuff = [RTs_shuff[(pplans .>= bins[i]) .& (pplans .< bins[i+1])] for i = 1:length(bins)-1] 44 | 45 | # mean and standard error 46 | m, m_c = [[mean(d) for d = dat] for dat = [dat, dat_shuff]] 47 | s, s_c = [[std(d)/sqrt(length(d)) for d = dat] for dat = [dat, dat_shuff]] 48 | 49 | # plot result 50 | ax = fig.add_subplot(grids[1,2]) 51 | ax.bar(xs, m, color = col_p, width = 0.04, linewidth = 0, label = "data") 52 | ax.errorbar(xs, m, yerr = s, fmt = "none", color = "k", capsize = 2, lw = 1.5) 53 | ax.errorbar(xs, m_c, yerr = s_c, fmt = "-", color = col_c, capsize = 2, lw = 1.5, label = "shuffle") 54 | ax.set_xlabel(L"$\pi$"*"(rollout)") 55 | ax.set_ylabel("thinking time (ms)") 56 | ax.legend(frameon = false, fontsize = fsize_leg) 57 | 58 | # print result 59 | m = mean(allsims, dims = 1) 60 | s = std(allsims, dims = 1) / sqrt(size(allsims, 1)) 61 | println("correlations mean and sem: ", m, " ", s) 62 | 63 | # plot thinking time against the number of unique states visited during exploration for RL agent 64 | 65 | uvals = 2:15 # number of unique states to consider (ignore very first action) 66 | RTs_us = zeros(length(seeds), length(uvals)) .+ NaN # initialize array for storing thinking times 67 | 68 | for (iseed, seed) = enumerate(seeds) # for each model seed 69 | @load "$(datadir)model_unique_states_$(seed)_1000.bson" data # load result 70 | RTs, unique_states = data # extract data 71 | new_us, new_rts = [], [] 72 | for b 
= 1:size(RTs, 1) # for each episode 73 | us = unique_states[b, :] # unique state counts 74 | inds = 2:sum(.~isnan.(us)) # actions of episode 1 75 | rts = RTs[b, inds] # thinking times for these actions 76 | push!(new_us, us[inds]); push!(new_rts, rts) # store data 77 | end 78 | new_us, new_rts = reduce(vcat, new_us), reduce(vcat, new_rts) # concatenate results across episodes 79 | RTs_us[iseed, :] = [mean(new_rts[new_us .== uval]) .- 1 for uval = uvals]*120 # mean value for this model in ms 80 | end 81 | 82 | # plot result 83 | m, s = mean(RTs_us, dims = 1)[:], std(RTs_us, dims = 1)[:]/sqrt(size(RTs_us, 1)) # mean and standard error across models 84 | ax = fig.add_subplot(grids[2,1]) 85 | ax.plot(uvals, m, ls = "-", color = col_p) 86 | ax.fill_between(uvals, m-s, m+s, color = col_p, alpha = 0.2) 87 | ax.set_xlabel("states visited") 88 | ax.set_ylabel("thinking time (ms)") 89 | ax.set_xlim(uvals[1], uvals[end]) 90 | 91 | # plot thinking time against unique states for human participants 92 | 93 | # load our data 94 | @load "$(datadir)/human_RT_and_rews_follow.bson" data 95 | keep = findall([nanmean(RTs) for RTs = data["all_RTs"]] .< 690) 96 | Nkeep = length(keep) # subjects to consider 97 | @load "$datadir/guided_lognormal_params_delta.bson" params #load prior parameters 98 | 99 | @load "$(datadir)unique_states_play.bson" data 100 | all_RTs, all_unique_states = data # response time and state informationduring exploration 101 | RTs_us = zeros(Nkeep, length(uvals)) .+ NaN # initialize array to store data 102 | for (i_u, u) = enumerate(keep) # for each participant 103 | new_us, new_rts = [], [] 104 | for b = 1:size(all_RTs[u], 1) # for each episode 105 | us = all_unique_states[u][b, :] # unique states for this episode 106 | inds = 2:sum(.~isnan.(us)) # indices of first trial 107 | rts = all_RTs[u][b, inds] # response times 108 | # compute posterior mean thinking times 109 | later = params["later"][u, :] 110 | later_post_mean(r) = calc_post_mean(r, muhat=later[1], sighat=later[2], deltahat=later[3], mode = false) 111 | rts = later_post_mean.(all_RTs[u][b, inds]) #posterior mean 112 | push!(new_us, us[inds]); push!(new_rts, rts) # store our results 113 | end 114 | new_us, new_rts = reduce(vcat, new_us), reduce(vcat, new_rts) # concatenate across episodes 115 | RTs_us[i_u, :] = [mean(new_rts[new_us .== uval]) for uval = uvals] # mean values for this participant 116 | end 117 | 118 | # plot our results 119 | m, s = nanmean(RTs_us, dims = 1)[:], nanstd(RTs_us, dims = 1)[:]/sqrt(size(RTs_us, 1)) # mean and standard error across participants 120 | ax = fig.add_subplot(grids[2,2]) 121 | ax.plot(uvals, m, "k-") 122 | ax.fill_between(uvals, m-s, m+s, color = "k", alpha = 0.2) 123 | ax.set_xlabel("states visited") 124 | ax.set_ylabel("thinking time (ms)") 125 | ax.set_xlim(uvals[1], uvals[end]) 126 | 127 | # add labels and save figure 128 | y1 = 1.08 129 | y2 = 0.47 130 | x1, x2 = -0.18, 0.43 131 | fsize = fsize_label 132 | plt.text(x1,y1,"A"; ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize, ) 133 | plt.text(x2,y1,"B";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize,) 134 | plt.text(x1,y2,"C";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize,) 135 | plt.text(x2,y2,"D";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize,) 136 | savefig("./figs/supp_exploration.pdf", bbox_inches = "tight") 137 | savefig("./figs/supp_exploration.png", bbox_inches = "tight") 138 | close() 139 | 140 | 
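# note on the thinking-time estimates used above: raw response times are converted to
# posterior-mean "thinking times" via calc_post_mean with per-participant parameters
# (muhat, sighat, deltahat) loaded from guided_lognormal_params_delta.bson; judging by the
# file and parameter names these come from a lognormal fit with an offset (delta) to the
# guided episodes, with the prior itself presumably defined in src/priors.jl and
# analysis_scripts/calc_human_prior.jl rather than here.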
-------------------------------------------------------------------------------- /computational_model/plot_paper_figs/plot_supp_fig_network_size.jl: -------------------------------------------------------------------------------- 1 | ## This script plots Figure S3 of Jensen et al. 2 | 3 | include("plot_utils.jl") 4 | using NaNStatistics 5 | 6 | ## instantiate figure 7 | fig = figure(figsize = (10*cm, 7*cm)) 8 | 9 | # load data 10 | @load "$datadir/rew_and_plan_by_n.bson" res_dict 11 | meanrews, pfracs, seeds, Nhiddens, epochs = [res_dict[k] for k = ["meanrews", "planfracs", "seeds", "Nhiddens", "epochs"]] 12 | 13 | # extract reward and plan data 14 | mms = mean(meanrews, dims = 2)[:, 1, :] # mean across agents 15 | sms = std(meanrews, dims = 2)[:, 1, :] / sqrt(length(seeds)) # standard error 16 | mps = mean(pfracs, dims = 2)[:, 1, :] # mean 17 | sps = std(pfracs, dims = 2)[:, 1, :] / sqrt(length(seeds)) # standard error 18 | 19 | ## convert from epochs to episodes 20 | xs = epochs*40*200 / 1000000 21 | 22 | # for both reward and planning fraction 23 | grids = fig.add_gridspec(nrows=1, ncols=2, left=0.00, right=1.00, bottom = 0.0, top = 0.4, wspace=0.45, hspace=0.60) 24 | for (idat, dat) = enumerate([(mms, sms), (mps, sps)]) 25 | ax = fig.add_subplot(grids[1,idat]) # new subplot 26 | m, s = dat # mean and standard error 27 | for (ihid, Nhidden) = enumerate(Nhiddens) # for each network size 28 | frac = (Nhidden - minimum(Nhiddens))/(maximum(Nhiddens) - minimum(Nhiddens)) 29 | frac = (0.45*frac .+ 0.76) 30 | col = col_p * frac 31 | # plot mean and sem 32 | ax.plot(xs, m[ihid, :], ls = "-", color = col, label = Nhidden) 33 | ax.fill_between(xs, m[ihid, :]-s[ihid, :], m[ihid, :]+s[ihid, :], color = col, alpha = 0.2) 34 | end 35 | 36 | # set some axis labels etc. 
37 | ax.set_xlabel("training episodes (x"*L"$10^6$"*")") 38 | if idat == 1 39 | ax.legend(frameon = false, fontsize = fsize_leg, handlelength=1.5, handletextpad=0.5, borderpad = 0.0, labelspacing = 0.05) 40 | ax.set_ylabel("mean reward") 41 | ax.set_ylim(0, 9) 42 | else 43 | ax.set_ylabel(L"$p$"*"(rollout)") 44 | ax.set_ylim(0, 0.65) 45 | ax.axhline(0.2, ls = "--", color = "k") 46 | end 47 | ax.set_xticks(0:2:8) 48 | ax.set_xlim(0,8) 49 | end 50 | 51 | 52 | ## add human data 53 | 54 | @load "$(datadir)/human_RT_and_rews_follow.bson" data; data_follow = data # guided episodes 55 | means = [nanmean(RTs) for RTs = data_follow["all_RTs"]] 56 | keep = findall(means .< 690) # non-outlier users 57 | Nkeep = length(keep) 58 | 59 | @load "$(datadir)/human_all_data_play.bson" data; 60 | _, _, _, _, all_rews_p, all_RTs_p, all_trial_nums_p, _ = data; 61 | @load "$datadir/guided_lognormal_params_delta.bson" params # parameters of prior distributions 62 | all_TTs, all_DTs = [], [] 63 | for u = keep # for each participant 64 | rts, tnums = all_RTs_p[u], all_trial_nums_p[u] # RTs and trial numbers 65 | new_TTs, new_DTs = [zeros(size(rts)) .+ NaN for _ = 1:2] # thinking time and delay times 66 | initial, later = params["initial"][u, :], params["later"][u, :] 67 | # functions for computing posterior means 68 | initial_post_mean(r) = calc_post_mean(r, muhat=initial[1], sighat=initial[2], deltahat=initial[3], mode = false) 69 | later_post_mean(r) = calc_post_mean(r, muhat=later[1], sighat=later[2], deltahat=later[3], mode = false) 70 | tnum = 1 71 | for ep = 1:size(rts, 1) # for each episode 72 | for b = 1:sum(tnums[ep, :] .> 0.5) # for each action 73 | t, rt = tnums[ep, b], rts[ep, b] # trial number and response time 74 | if b > 1.5 # discard very first action 75 | if t == tnum # same trial as before 76 | new_TTs[ep, b] = later_post_mean(rt) 77 | else # first action of new trial 78 | new_TTs[ep, b] = initial_post_mean(rt) 79 | end 80 | new_DTs[ep, b] = rt - new_TTs[ep, b] # delays is response time minus thinking time 81 | end 82 | tnum = t 83 | end 84 | end 85 | push!(all_TTs, Float64.(new_TTs)) 86 | push!(all_DTs, Float64.(new_DTs)) 87 | end 88 | # store data 89 | all_RTs = [all_TTs[i] + all_DTs[i] for i = 1:length(all_TTs)] 90 | 91 | ## 92 | # combine data 93 | rews_by_episode = reduce(hcat, [nansum(rews, dims = 2) for rews = all_rews_p]) 94 | RTs_by_episode = reduce(hcat, [nanmedian(RTs, dims = 2) for RTs = all_RTs]) 95 | TTs_by_episode = reduce(hcat, [nanmedian(TTs, dims = 2) for TTs = all_TTs]) 96 | DTs_by_episode = reduce(hcat, [nanmedian(DTs, dims = 2) for DTs = all_DTs]) 97 | function permtest(m, label) 98 | # run permutation test for significance 99 | cval = cor(1:length(m), m) # correlation 100 | Niter = 10000; ctrl = zeros(Niter) 101 | for n = 1:Niter 102 | ctrl[n] = cor(randperm(length(m)), m) 103 | end 104 | println("$(label) correlation of mean: $cval, permutation p = $(mean(ctrl .> cval))") 105 | end 106 | 107 | ## first plot rews 108 | grids = fig.add_gridspec(nrows=1, ncols=1, left=0.00, right=0.175, bottom = 0.63, top = 1.0, wspace=0.45, hspace=0.60) 109 | 110 | dat = rews_by_episode 111 | mr = mean(dat, dims = 2)[:] # mean 112 | sr = std(dat, dims = 2)[:] / sqrt(size(dat, 2)) #standard error 113 | xs = 1:length(mr) 114 | ax = fig.add_subplot(grids[1,1]) # new subplot 115 | permtest(mr, "reward") 116 | cvals = [cor(xs, dat[:, i]) for i = 1:size(dat, 2)] 117 | println("reward mean & sem across people: $(mean(cvals)) & $(std(cvals)/sqrt(length(cvals)))") 118 | 119 | ax.plot(xs, mr, ls = "-", 
color = col_h) 120 | ax.fill_between(xs, mr-sr, mr+sr, color = col_h, alpha = 0.2) 121 | # set some axis labels etc. 122 | ax.set_xlabel("episode") 123 | ax.set_ylabel("mean reward") 124 | ax.set_ylim(6.5, 9.5) 125 | ax.set_xticks(0:12:38) 126 | ax.set_xlim(0,38) 127 | 128 | ## now plot RTs 129 | grids = fig.add_gridspec(nrows=1, ncols=3, left=0.330, right=1.00, bottom = 0.63, top = 1.0, wspace=0.55, hspace=0.60) 130 | 131 | labels = ["response"; "thinking"; "delay"] 132 | lss = ["-", "-", "-"] 133 | all_cvals = [] 134 | for (idat, dat) = enumerate([RTs_by_episode, TTs_by_episode, DTs_by_episode]) 135 | ax = fig.add_subplot(grids[1,idat]) # new subplot 136 | m = mean(dat, dims = 2)[:] # mean 137 | s = std(dat, dims = 2)[:] / sqrt(size(dat, 2)) #standard error 138 | xs = 1:length(m) 139 | 140 | # run permutation test for significance 141 | permtest(m, labels[idat]) 142 | cvals = [cor(xs, dat[:, i]) for i = 1:size(dat, 2)] 143 | push!(all_cvals, cvals) 144 | println("$(labels[idat]) mean & sem across people: $(mean(cvals)) & $(std(cvals)/sqrt(length(cvals)))") 145 | 146 | ax.plot(xs, m, ls = lss[idat], color = col_h) 147 | ax.fill_between(xs, m-s, m+s, color = col_h, alpha = 0.2) 148 | ax.set_xlabel("episode") 149 | if idat == 1 150 | ax.set_ylabel("time (ms)") 151 | end 152 | ax.set_xticks(0:12:38) 153 | ax.set_xlim(0,38) 154 | ax.set_title(labels[idat], fontsize = fsize) 155 | end 156 | diffs = all_cvals[3] - all_cvals[2] 157 | println("delay - thinking cvals: m = $(mean(diffs)), sem = $(std(diffs)/sqrt(length(diffs)))") 158 | 159 | ## add panel labels and save 160 | y1, y2 = 1.08, 0.45 161 | x1, x2 = -0.09, 0.45 162 | fsize = fsize_label 163 | plt.text(x1,y1,"A"; ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize, ) 164 | plt.text(x2-0.25,y1,"B";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize,) 165 | plt.text(x1,y2,"C"; ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize, ) 166 | plt.text(x2,y2,"D";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize,) 167 | 168 | 169 | savefig("./figs/supp_fig_by_size.pdf", bbox_inches = "tight") 170 | savefig("./figs/supp_fig_by_size.png", bbox_inches = "tight") 171 | close() 172 | 173 | 174 | ### print avg rew and RT for first/last 5 episodes ### 175 | 176 | mr1, mr2 = (mean(mr[1:5])), (mean(mr[34:38])) 177 | println("mean first five rew: $mr1") 178 | println("mean last five rew: $mr2") 179 | 180 | mt = mean(RTs_by_episode, dims = 2)[:] # mean 181 | mt1, mt2 = (mean(mt[1:5])), (mean(mt[34:38])) 182 | println("mean first five RT: $mt1") 183 | println("mean last five RT: $mt2") 184 | 185 | println("scale up: ", mr1*mt1/mt2) 186 | 187 | println("rel time: ", abs(mt2-mt1)/mt1*100) 188 | println("rel rew: ", abs(mr2-mr1)/mr1*100) 189 | -------------------------------------------------------------------------------- /computational_model/plot_paper_figs/plot_supp_hp_sweep.jl: -------------------------------------------------------------------------------- 1 | #This script plots Figure S5 of Jensen et al. 
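# (inferred from the savenames below) params is a list of (N, Lplan) hyperparameter pairs,
# where N appears to be the number of hidden units and Lplan the rollout length; each pair
# corresponds to result files named hp_sweep_N<N>_Lplan<Lplan>..., which require the
# hyperparameter-sweep models (the corresponding analyses are commented out in
# run_all_analyses.jl above).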
2 | 3 | include("plot_utils.jl") 4 | using BSON: @load 5 | using Random, NaNStatistics, Statistics 6 | cm = 1/2.54 7 | 8 | fig = figure(figsize = (17*cm, 7.5*cm)) 9 | bot, top = 0.72, 0.38 10 | params = [(60,4),(60,8),(60,12),(100,4),(100,8),(100,12),(140,4),(140,8),(140,12)] 11 | prefix = "hp_sweep_" 12 | Nseed = length(seeds) 13 | 14 | ### plot human RT corr ### 15 | all_vals, all_errs, all_ticklabels, all_corrs = [], [], [], [] 16 | for (ip, p) = enumerate(params) 17 | N, Lplan = p 18 | savename = "hp_sweep_N$(N)_Lplan$(Lplan)_1000_weiji" 19 | @load "$datadir/RT_predictions_$savename.bson" data; 20 | allsims = data["correlations"] 21 | m1, s1 = mean(allsims[:, 1]), std(allsims[:, 1])/sqrt(size(allsims, 1)) 22 | push!(all_vals, m1); push!(all_errs, s1) 23 | 24 | println(p, ": ", m1, " ", s1) 25 | corrs = allsims[:, 1] 26 | println(minimum(corrs), " ", maximum(corrs)) 27 | push!(all_corrs, corrs) 28 | push!(all_ticklabels, string(p)) 29 | end 30 | 31 | grids = fig.add_gridspec(nrows=1, ncols=1, left=0.04, right=0.47, bottom = bot, top = 1.0, wspace=0.35) 32 | ax = fig.add_subplot(grids[1,1]) 33 | ax.bar(1:length(all_vals), all_vals, yerr = all_errs, capsize = capsize,color = col_p) 34 | for (i, corrs) = enumerate(all_corrs) # plot individual data points 35 | shifts = 1:length(corrs); shifts = (shifts .- mean(shifts))/std(shifts)*0.15 36 | ax.scatter(i .+ shifts, all_corrs, color = col_point, marker = ".", s = 3, alpha = 0.5, zorder = 100) 37 | end 38 | ax.set_xticks(1:length(all_vals)) 39 | ax.set_xticklabels(all_ticklabels, rotation = 45, ha = "right") 40 | ax.set_ylim(0,0.22) 41 | ax.set_ylim(-0.05, 0.35) 42 | ylabel("correlation with\nthinking time"); ax.set_yticks(0:0.1:0.3) 43 | 44 | ### plot delta perf with rollouts ### 45 | 46 | all_vals, all_errs, all_ticklabels, all_diffs = [], [], [], [] 47 | for (ip, p) = enumerate(params) 48 | N, Lplan = p 49 | savename = "$(prefix)N$(N)_Lplan$(Lplan)" 50 | @load "$datadir/perf_by_n_$(savename).bson" res_dict 51 | seeds = sort([k for k = keys(res_dict)]) 52 | Nseed = length(seeds) 53 | ms = [] 54 | t0, t1 = 1, 6 55 | for (is, seed) = enumerate(seeds) 56 | dts, mindists, policies = [res_dict[seed][k] for k = ["dts"; "mindists"; "policies"]] 57 | keepinds = findall((.~isnan.(sum(dts[1, :, t0:t1], dims = (2))[:])) .& (mindists[:, 2] .>= 0)) 58 | new_dts = dts[:, keepinds, :] 59 | m = mean(new_dts[1,:,:], dims = 1)[:] 60 | push!(ms, m[t0]-m[t1]) 61 | end 62 | ms = reduce(hcat, ms) 63 | m1, s1 = mean(ms), std(ms)/sqrt(Nseed) 64 | push!(all_vals, m1) 65 | push!(all_errs, s1) 66 | push!(all_diffs, ms) 67 | push!(all_ticklabels, string(p)) 68 | end 69 | 70 | grids = fig.add_gridspec(nrows=1, ncols=1, left=0.57, right=1.0, bottom = bot, top = 1.0, wspace=0.35) 71 | ax = fig.add_subplot(grids[1,1]) 72 | ax.bar(1:length(all_vals), all_vals, yerr = all_errs, capsize = capsize,color = col_p) 73 | for (i, diffs) = enumerate(all_diffs) # plot individual data points 74 | shifts = 1:length(diffs); shifts = (shifts .- mean(shifts))/std(shifts)*0.15 75 | ax.scatter(i .+ shifts, all_diffs, color = col_point, marker = ".", s = 3, alpha = 0.5, zorder = 100) 76 | end 77 | ax.set_xticks(1:length(all_vals)) 78 | ax.set_xticklabels(all_ticklabels, rotation = 45, ha = "right") 79 | ylabel(L"$\Delta$"*"steps") 80 | ax.set_ylim(0, 1.5) 81 | 82 | ### plot delta pi(a1) ### 83 | 84 | grids = fig.add_gridspec(nrows=1, ncols=length(params), left=0.00, right=1.0, bottom = 0, top = top, wspace=0.35) 85 | for (ip, p) = enumerate(params) 86 | N, Lplan = p 87 | all_ms = [] 88 | for 
i = 1:2 #rewarded and non-rewarded sim 89 | ms = [] 90 | for seed = 51:55 91 | if p == (140,12) 92 | @load "$datadir/causal_N$(N)_Lplan$(Lplan)_$(seed)_1000_single.bson" data 93 | else 94 | @load "$datadir/causal_N$(N)_Lplan$(Lplan)_$(seed)1000_single.bson" data 95 | end 96 | p_simulated_actions, p_simulated_actions_old = data["p_simulated_actions"], data["p_simulated_actions_old"] 97 | p_initial_sim, p_continue_sim = data["p_initial_sim"], data["p_continue_sim"] 98 | p_simulated_actions ./= (1 .- p_continue_sim) 99 | p_simulated_actions_old ./= (1 .- p_initial_sim) 100 | inds = findall(.~isnan.(sum(p_simulated_actions, dims = 1)[:])) #data for all 101 | #old and new probabilities 102 | push!(ms, [mean(p_simulated_actions_old[i, inds]); mean(p_simulated_actions[i, inds])]) 103 | end 104 | ms = reduce(hcat, ms) 105 | push!(all_ms, ms[2, :]-ms[1,:]) 106 | end 107 | 108 | ms = [mean(ms) for ms = all_ms] 109 | ss = [std(ms)/sqrt(Nseed) for ms = all_ms] 110 | ax = fig.add_subplot(grids[1,ip]) 111 | ax.bar(1:2, ms, yerr = ss, color = [col_p1, col_p2], capsize = capsize) 112 | if plot_points 113 | shifts = 1:length(all_ms[1]); shifts = (shifts .- mean(shifts))/std(shifts)*0.2 114 | ax.scatter([1 .+ shifts; 2 .+ shifts], [all_ms[1]; all_ms[2]], color = col_point, marker = ".", s = 15, zorder = 100) 115 | end 116 | ax.set_xticks(1:2) 117 | ax.set_xticklabels(["succ"; "un"]) 118 | if ip == 1 119 | ax.set_ylabel(L"$\Delta \pi(\hat{a}_1)$", labelpad = 0) 120 | ax.set_yticks([-0.4;0;0.4]) 121 | else 122 | ax.set_yticks([]) 123 | end 124 | ax.set_title(string(p), fontsize = fsize) 125 | ax.set_ylim(-0.4,0.4) 126 | ax.set_xlim([0.4; 2.6]) 127 | ax.axhline(0.0, color = "k", lw = 1) 128 | end 129 | 130 | 131 | 132 | ### add labels and save ### 133 | 134 | add_labels = true 135 | if add_labels 136 | y1, y2 = 1.07, 0.48 137 | x1, x2 = -0.09, 0.50 138 | fsize = fsize_label 139 | plt.text(x1,y1,"A"; ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize, ) 140 | plt.text(x2,y1,"B";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize,) 141 | plt.text(x1,y2,"C";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize,) 142 | end 143 | 144 | savefig("./figs/supp_hp_sweep.pdf", bbox_inches = "tight") 145 | savefig("./figs/supp_hp_sweep.png", bbox_inches = "tight") 146 | close() 147 | -------------------------------------------------------------------------------- /computational_model/plot_paper_figs/plot_supp_human_euc_comparison.jl: -------------------------------------------------------------------------------- 1 | #This script plots Figure S1 of Jensen et al. 
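# The wrap vs no-wrap comparisons below report p-values from a permutation test on the
# difference of group means: participants are repeatedly relabelled at random and the
# observed difference is compared against the resulting null distribution. A minimal
# standalone sketch of that test, using synthetic data and an illustrative function name
# (not part of the analysis code):
using Random, Statistics

function demo_perm_test(x::Vector{Float64}, y::Vector{Float64}; niter = 1000)
    diff_obs = mean(x) - mean(y)               # observed difference of group means
    comb = [x; y]                              # pool both groups
    ctrls = zeros(niter)
    for n = 1:niter
        shuffled = comb[randperm(length(comb))]                # random relabelling
        ctrls[n] = mean(shuffled[1:length(x)]) - mean(shuffled[length(x)+1:end])
    end
    return diff_obs, mean(ctrls .> diff_obs)   # one-sided permutation p-value
end

demo_perm_test(300 .+ 50 .* randn(20), 270 .+ 50 .* randn(25))  # synthetic thinking times (ms)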
2 | 3 | include("plot_utils.jl") 4 | using ToPlanOrNotToPlan 5 | Random.seed!(1) # set random seed (for jitter in panel D) 6 | 7 | RTs_play, TTs_play, base1s, base2s, rews = [], [], [], [], [] 8 | 9 | for wrapstr = ["", "_euclidean"] 10 | # We start by loading some of our human behavioural data 11 | 12 | @load "$(datadir)/human_all_data_play$wrapstr.bson" data; 13 | _, _, _, _, all_rews_p, all_RTs_p, all_trial_nums_p, _ = data; 14 | @load "$(datadir)/human_all_data_follow$wrapstr.bson" data; 15 | _, _, _, _, all_rews_f, all_RTs_f, all_trial_nums_f, _ = data; 16 | 17 | means1, means2 = [[nanmean(RT) for RT = RTs] for RTs = [all_RTs_f, all_RTs_p]] 18 | keep = findall(means1 .< 690) # non-outlier users 19 | Nkeep = length(keep) 20 | means1, means2 = means1[keep], means2[keep] 21 | 22 | @load "$datadir/guided_lognormal_params_delta$wrapstr.bson" params # parameters of prior distributions 23 | initial_delays = (params["initial"][:, 3]+exp.(params["initial"][:, 1]+params["initial"][:, 2].^2/2))[keep] 24 | later_delays = (params["later"][:, 3]+exp.(params["later"][:, 1]+params["later"][:, 2].^2/2))[keep] 25 | push!(base1s, initial_delays) 26 | push!(base2s, later_delays) 27 | push!(RTs_play, means2) 28 | 29 | push!(rews, [nansum(rew)/size(rew, 1) for rew = all_rews_p][keep]) 30 | 31 | all_TTs = [] 32 | for u = keep 33 | new_TTs = [] 34 | rts, tnums = all_RTs_p[u], all_trial_nums_p[u] 35 | initial, later = params["initial"][u, :], params["later"][u, :] 36 | initial_post_mean(r) = calc_post_mean(r, muhat=initial[1], sighat=initial[2], deltahat=initial[3], mode = false) 37 | later_post_mean(r) = calc_post_mean(r, muhat=later[1], sighat=later[2], deltahat=later[3], mode = false) 38 | tnum = 1 39 | for ep = 1:size(rts, 1) 40 | for b = 1:sum(tnums[ep, :] .> 0.5) # for each action 41 | t, rt = tnums[ep, b], rts[ep, b] # trial number and response time 42 | if t > 1.5 # if we're in exploitation 43 | if t == tnum # same trial as before 44 | push!(new_TTs, later_post_mean(rt)) 45 | else # first action of new trial 46 | push!(new_TTs, initial_post_mean(rt)) 47 | end 48 | end 49 | tnum = t 50 | end 51 | end 52 | push!(all_TTs, nanmean(Float64.(new_TTs))) 53 | end 54 | push!(TTs_play, all_TTs) 55 | 56 | end 57 | 58 | 59 | #### plot some results #### 60 | 61 | titles = ["reaction"; "thinking"; "initial"; "later"; "rewards"] 62 | ylabs = ["time (ms)"; "thinking time (ms)"; "time (ms)"; "time (ms)"; "avg. 
reward"] 63 | datas = [RTs_play, TTs_play, base1s, base2s, rews] 64 | inds = [2;5] 65 | titles, datas, ylabs = titles[inds], datas[inds], ylabs[inds] 66 | 67 | fig = figure(figsize = (15*cm, 3*cm)) 68 | grids = fig.add_gridspec(nrows=1, ncols=length(datas), left=0.00, right=0.36, bottom = 0, top = 1.0, wspace=0.6) 69 | 70 | for (idat, data) = enumerate(datas) 71 | torus, euclid = data 72 | NT, NE = length(torus), length(euclid) 73 | mus = [mean(torus); mean(euclid)] # mean across users 74 | diff = mus[1] - mus[2] 75 | comb = [torus; euclid] 76 | ctrls = zeros(10000) 77 | for i = 1:10000 78 | newcomb = comb[randperm(length(comb))] 79 | ctrls[i] = mean(newcomb[1:NT]) - mean(newcomb[NT+1:end]) 80 | end 81 | println(titles[idat], " means: ", mus, " p = ", mean(ctrls .> diff)) 82 | 83 | ax = fig.add_subplot(grids[1,idat]) 84 | ax.bar(1:2, mus, color = col_c) # bar plot 85 | # plot individual data points 86 | ax.scatter(ones(NT)+randn(NT)*0.1, torus, marker = ".", s = 6, color = "k", zorder = 100) 87 | ax.scatter(ones(NE)*2+randn(NE)*0.1, euclid, marker = ".", s = 6, color = "k", zorder = 100) 88 | ax.set_xticks(1:2) 89 | ax.set_xticklabels(["wrap"; "no-wrap"], rotation = 45, ha = "right") 90 | ax.set_ylabel(ylabs[idat]) 91 | #ax.set_title(titles[idat]) 92 | end 93 | 94 | ### plot rew vs RT for both conditions ### 95 | grids = fig.add_gridspec(nrows=1, ncols=2, left=0.47, right=1.0, bottom = 0, top = 1.0, wspace=0.6) 96 | ax = fig.add_subplot(grids[1,1]) 97 | for i = 1:2 98 | ax.scatter(RTs_play[i], rews[i], color = ["k", col_c][i], marker = ".", s = 60) 99 | end 100 | ax.legend(["wrap"; "no-wrap"], loc = (0.4, 0.65), handletextpad=0.4, borderaxespad = 0.3, handlelength = 1.0, fontsize = fsize_leg) 101 | ax.set_xlabel("mean RT (ms)") 102 | ax.set_ylabel("mean reward") 103 | ax.set_yticks(4:2:16) 104 | 105 | ### distribution of path lengths ### 106 | 107 | @load "$(datadir)/wrap_and_nowrap_pairwise_dists.bson" dists 108 | 109 | ds = 1:12 110 | hist_wraps = [sum(dists[1] .== d) for d = ds] 111 | hist_nowraps = [sum(dists[2] .== d) for d = ds] 112 | xs = reduce(vcat, [[d-0.5; d+0.5] for d = ds]) 113 | hwraps = reduce(vcat, [[h;h] for h = hist_wraps/sum(hist_wraps)]) 114 | hnowraps = reduce(vcat, [[h;h] for h = hist_nowraps/sum(hist_wraps)]) 115 | ax = fig.add_subplot(grids[1,2]) 116 | plot(xs, hwraps, color = "k") 117 | plot(xs, hnowraps, color = col_c) 118 | axvline(mean(dists[1]), color = "k", lw = 1.5) 119 | axvline(mean(dists[2]), color = col_c, lw = 1.5) 120 | ax.set_xlabel("distance to goal") 121 | ax.set_ylabel("frequency") 122 | 123 | # add labels and save 124 | y1 = 1.16 125 | x1, x2, x3, x4 = -0.05, 0.15, 0.40, 0.70 126 | plt.text(x1,y1,"A"; ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize_label, ) 127 | plt.text(x2,y1,"B";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize_label,) 128 | plt.text(x3,y1,"C";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize_label,) 129 | plt.text(x4,y1,"D";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize_label,) 130 | savefig("./figs/supp_human_euclidean_comparison.pdf", bbox_inches = "tight") 131 | savefig("./figs/supp_human_euclidean_comparison.png", bbox_inches = "tight") 132 | close() 133 | 134 | 135 | -------------------------------------------------------------------------------- /computational_model/plot_paper_figs/plot_supp_human_summary.jl: -------------------------------------------------------------------------------- 1 | #This script 
plots Figure S2 of Jensen et al. 2 | 3 | include("plot_utils.jl") 4 | using ToPlanOrNotToPlan 5 | using StatsBase 6 | Random.seed!(1) # set random seed (for jitter in panel D) 7 | wrap = true 8 | wrap = false 9 | wrapstr = (if wrap "" else "_euclidean" end) 10 | 11 | # We start by loading some of our human behavioural data 12 | @load "$(datadir)/human_RT_and_rews_play$wrapstr.bson" data; data_play = data # non-guided episodes 13 | @load "$(datadir)/human_RT_and_rews_follow$wrapstr.bson" data; data_follow = data # guided episodes 14 | means1, means2 = [[nanmean(RTs) for RTs = data["all_RTs"]] for data = [data_follow, data_play]] 15 | keep = findall(means1 .< 690) # non-outlier users 16 | Nkeep = length(keep) 17 | 18 | 19 | # mean response times and rewards for all users 20 | mean_RTs = [[nanmean(RTs) for RTs = data["all_RTs"]] for data = [data_follow, data_play]] 21 | mean_rews = [[sum(rews)/size(rews, 1) for rews = data["all_rews"]] for data = [data_follow, data_play]] 22 | sums = [nansum(RT)/size(RT,1)/1e3 for RT = data_play["all_RTs"]] 23 | 24 | fig = figure(figsize = (15*cm, 3*cm)) 25 | grids = fig.add_gridspec(nrows=1, ncols=3, left=0.00, right=0.78, bottom = 0, top = 1.0, wspace=0.5) 26 | 27 | # plot average reward against average response time for all users 28 | 29 | for i = 1:2 # plot data for guided and non-guided episodes 30 | ax = fig.add_subplot(grids[1,i]) 31 | ax.scatter(mean_RTs[i][keep], mean_rews[i][keep], color = "k", marker = ".", s = 60) 32 | ax.set_xlabel("mean RT (ms)") 33 | ax.set_ylabel("mean reward") 34 | ax.set_title(["guided"; "non-guided"][i], fontsize = fsize) 35 | ax.set_yticks(4:2:12) 36 | end 37 | 38 | # plot mean probability of optimal action against mean response time 39 | 40 | # load the data 41 | @load "$(datadir)/human_all_data_follow$wrapstr.bson" data; 42 | @load "$(datadir)/human_all_data_play$wrapstr.bson" data; 43 | all_states, all_ps, all_as, all_wall_loc, all_rews, all_RTs, _, _ = data # extract relevant data 44 | all_opts = [] 45 | Larena = 4; ed = EnvironmentDimensions(4^2, 2, 5, 50, Larena) # Environment parameters 46 | for i = keep # for each user 47 | opts = [] # list of optimality of actions 48 | for b = 1:size(all_as[i], 1) # for each episode 49 | dists = dist_to_rew(all_ps[i][:, b:b], all_wall_loc[i][:, :, b:b], Larena) # distance to goal from each state 50 | if sum(all_rews[i][b, :]) > 0.5 # if at least 1 trial was completed 51 | for t = findall(all_rews[i][b,:] .> 0.5)[1]+1:sum(all_as[i][b,:] .> 0.5) # for each action 52 | # extract optimal policy 53 | pi_opt = optimal_policy(Int.(all_states[i][:,b,t]), all_wall_loc[i][:,:,b], dists, ed) 54 | push!(opts, Float64(pi_opt[Int(all_as[i][b,t])] > 1e-2)) # was the action taken optimal? 
55 | end 56 | end 57 | end 58 | push!(all_opts, mean(opts)) # store results for this user 59 | end 60 | RTs = [nanmean(RTs) for RTs = all_RTs[keep]] # corresponding response times 61 | inds = findall(all_opts .> 0.5) 62 | RTs, all_opts = RTs[inds], Float64.(all_opts[inds]) 63 | 64 | rcor = cor(RTs, all_opts) # correlations between response times and optimality 65 | ctrls = zeros(10000); for i = 1:10000 ctrls[i] = cor(RTs, all_opts[randperm(length(inds))]) end # permutation test 66 | println(rcor, " ", mean(ctrls .> rcor)) 67 | 68 | # plot result 69 | ax = fig.add_subplot(grids[1,3]) 70 | ax.scatter(RTs, all_opts, color = "k", marker = ".", s = 60) 71 | ax.set_xlabel("mean RT (ms)") 72 | ax.set_ylabel(L"$p$"*"(optimal)") 73 | 74 | 75 | # plot means of the prior distributions for each user 76 | 77 | @load "$datadir/guided_lognormal_params_delta$wrapstr.bson" params # parameters of prior distributions 78 | 79 | # parameters for initial action of a trial and later actions of a trial 80 | initial_delays = params["initial"][:, 3]+exp.(params["initial"][:, 1]+params["initial"][:, 2].^2/2) 81 | later_delays = params["later"][:, 3]+exp.(params["later"][:, 1]+params["later"][:, 2].^2/2) 82 | 83 | grids = fig.add_gridspec(nrows=1, ncols=1, left=0.90, right=1.0, bottom = 0, top = 1.0, wspace=0.40) 84 | 85 | ax = fig.add_subplot(grids[1,1]) 86 | mus = [mean(initial_delays[keep]); mean(later_delays[keep])] # mean across users 87 | ax.bar(1:2, mus, color = col_c) # bar plot 88 | # plot individual data points 89 | ax.scatter(ones(Nkeep)+randn(Nkeep)*0.1, initial_delays[keep], marker = ".", s = 6, color = "k", zorder = 1000) 90 | ax.scatter(ones(Nkeep)*2+randn(Nkeep)*0.1, later_delays[keep], marker = ".", s = 6, color = "k", zorder = 100) 91 | ax.set_xticks(1:2) 92 | ax.set_xticklabels(["initial"; "later"], rotation = 45, ha = "right") 93 | ax.set_ylabel("time (ms)") 94 | 95 | # print some results as well 96 | println("correlation between thinking time and optimality: ", rcor, ", p = ", mean(ctrls .> rcor)) 97 | println("mean optimality: ", mean(all_opts)) 98 | 99 | # add labels and save 100 | y1 = 1.16 101 | x1, x2, x3, x4 = -0.09, 0.21, 0.49, 0.80 102 | plt.text(x1,y1,"A"; ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize_label, ) 103 | plt.text(x2,y1,"B";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize_label,) 104 | plt.text(x3,y1,"C";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize_label,) 105 | plt.text(x4,y1,"D";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize_label,) 106 | savefig("./figs/supp_human_data$wrapstr.pdf", bbox_inches = "tight") 107 | savefig("./figs/supp_human_data$wrapstr.png", bbox_inches = "tight") 108 | close() 109 | -------------------------------------------------------------------------------- /computational_model/plot_paper_figs/plot_supp_internal_model.jl: -------------------------------------------------------------------------------- 1 | #This script plots Figure S7 of Jensen et al. 
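# The x-axes below are computed as xs = epochs*40*200/1000000, i.e. checkpoint indices are
# converted to millions of training episodes. A worked instance of that conversion (reading
# the factor as 40 batch updates of 200 episodes per checkpoint is an assumption; the script
# itself only uses the product):
episodes_seen(epoch) = epoch * 40 * 200 / 1_000_000  # millions of episodes
episodes_seen(1000)  # final checkpoint -> 8.0, matching the 0-8 x-axis range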
2 | 3 | include("plot_utils.jl") 4 | 5 | fig = figure(figsize = (10*cm, 6*cm)) 6 | grids = fig.add_gridspec(nrows=2, ncols=2, left=0.00, right=1.00, bottom = 0.0, top = 1.0, wspace=0.5, hspace = 0.30) 7 | 8 | # start by loading some data to plot 9 | @load "$(datadir)/internal_model_accuracy.bson" results 10 | 11 | global epochs = [] # training epochs 12 | rews, states = [], [] # reward prediction accuracy and state prediction accuracy 13 | for seed = seeds # for each trained model 14 | epochs = sort([k for k = keys(results[seed])]) # checkpoint epochs 15 | global epochs = epochs[epochs .<= plan_epoch] # training epochs 16 | push!(rews, [results[seed][e]["rew"] for e = epochs]) # reward prediction accuracy 17 | push!(states, [results[seed][e]["state"] for e = epochs]) # state prediction accuracy 18 | end 19 | rews, states = [reduce(hcat, arr) for arr = [rews, states]] # concatenate across models 20 | 21 | mrs, mss = mean(rews, dims = 2)[:], mean(states, dims = 2)[:] # mean accuracies 22 | srs, sss = std(rews, dims = 2)[:]/sqrt(length(seeds)), std(states, dims = 2)[:]/sqrt(length(seeds)) # standard errors 23 | xs = epochs*40*200 / 1000000 # convert to number of episodes seen 24 | 25 | # for the state prediction data and reward prediction data 26 | for (idat, dat) = enumerate([(mss, sss), (mrs, srs)]) 27 | for irange = 1:2 # for the full and zoomed in y ranges 28 | ax = fig.add_subplot(grids[irange,idat]) 29 | m, s = dat # extract mean and sem for this data 30 | ax.plot(xs, m, ls = "-", color = col_p) # plot mean 31 | ax.fill_between(xs, m-s, m+s, color = col_p, alpha = 0.2) # standard error 32 | if irange == 2 # zoomed in 33 | ax.set_xlabel("training episodes (x"*L"$10^6$"*")") 34 | ax.set_xticks(0:2:8) 35 | ax.set_ylim(0.99, 1.0002) 36 | ax.set_yticks([0.99, 1.0]) 37 | ax.set_yticklabels(["0.99"; "1.0"]) 38 | else # full range 39 | ax.set_xticks([]) 40 | ax.set_ylim(0.0, 1.02) 41 | ax.set_yticks([0.0; 0.5; 1.0]) 42 | ax.set_yticklabels(["0.00"; "0.50"; "1.00"]) 43 | end 44 | if idat == 1 45 | ax.set_ylabel("state prediction") 46 | else 47 | ax.set_ylabel("reward prediction") 48 | end 49 | ax.set_xlim(0,8) 50 | end 51 | end 52 | 53 | # add labels and save 54 | y1 = 1.12 55 | x1, x2 = -0.13, 0.46 56 | fsize = fsize_label 57 | plt.text(x1,y1,"A"; ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize, ) 58 | plt.text(x2,y1,"B";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize,) 59 | 60 | savefig("./figs/supp_internal_model.pdf", bbox_inches = "tight") 61 | savefig("./figs/supp_internal_model.png", bbox_inches = "tight") 62 | close() 63 | -------------------------------------------------------------------------------- /computational_model/plot_paper_figs/plot_supp_plan_probs.jl: -------------------------------------------------------------------------------- 1 | #This script plots Figure S13 of Jensen et al. 
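# The statistics printed at the end of this script treat the seed-wise difference between
# the two conditions' post-rollout probabilities as approximately Gaussian and report the
# probability mass below zero. A minimal standalone sketch of that computation with
# hypothetical per-seed differences (the function name and values are illustrative only):
using Statistics, Distributions

function demo_gaussian_p(deltas::Vector{Float64})
    m = mean(deltas)                            # mean difference across seeds
    sem = std(deltas) / sqrt(length(deltas))    # standard error of the mean
    return cdf(Normal(m, sem), 0.0)             # P(difference <= 0) under the Gaussian approximation
end

demo_gaussian_p([0.12, 0.08, 0.15, 0.10, 0.09])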
2 | 3 | include("plot_utils.jl") 4 | 5 | fig = figure(figsize = (5*cm, 3*cm)) 6 | grids = fig.add_gridspec(nrows=1, ncols=2, left=0.00, right=1.00, bottom = 0.0, top = 1.0, wspace=0.35) 7 | 8 | all_ms = [] # store list of pre/post values for both successful and unsuccessful rollouts 9 | for i = 1:2 # for successful and unsuccessful rollouts 10 | ms = [] 11 | for seed = seeds # for each trained model 12 | @load "$(datadir)/causal_N100_Lplan8_$(seed)_$(plan_epoch).bson" data # load data 13 | # extract pre- and post-rollout rollout probabilities 14 | p_initial_sim, p_continue_sim = data["p_initial_sim"], data["p_continue_sim"] 15 | inds = findall(.~isnan.(sum(p_continue_sim, dims = 1)[:])) #data for both successful and unsuccesful rollout 16 | # store data 17 | push!(ms, [mean(p_initial_sim[i, inds]); mean(p_continue_sim[i, inds])]) 18 | end 19 | ms = reduce(hcat, ms) # concatenate across models 20 | push!(all_ms, ms[2, :]) # store data for later stats 21 | # mean and two standard errors across seeds 22 | m3, s3 = mean(ms, dims = 2)[1:2], 2*std(ms, dims = 2)[1:2] / sqrt(length(seeds)) 23 | 24 | # plot our newly loaded data 25 | ax = fig.add_subplot(grids[1,i]) 26 | ax.bar(1:2, m3, yerr = s3, color = [col_p1, col_p2][i], capsize = capsize) # bar plot 27 | # plot individual data points 28 | shifts = 1:size(ms, 2); shifts = (shifts .- mean(shifts))/std(shifts)*0.2 # add some jitter 29 | ax.scatter([1 .+ shifts; 2 .+ shifts], [ms[1, :]; ms[2, :]], color = col_point, marker = ".", s = 15, zorder = 100) 30 | ax.set_xticks(1:2) 31 | ax.set_xticklabels(["pre"; "post"]) 32 | if i == 1 # successful 33 | ax.set_ylabel(L"$\pi($"*"rollout"*L"$)$", labelpad = -1.5) 34 | ax.set_title("succ.", fontsize = fsize) 35 | ax.set_yticks(0.0:0.2:0.8) 36 | else # unsuccessful 37 | ax.set_title("unsucc.", fontsize = fsize) 38 | ax.set_yticks([]) 39 | end 40 | ax.set_ylim(0.0, 0.8) 41 | ax.set_xlim(0.5, 2.5) 42 | end 43 | 44 | # add labels and save figure 45 | y1 = 1.16 46 | x1, x2 = -0.25, 0.45 47 | fsize = fsize_label 48 | plt.text(x1,y1,"A"; ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize, ) 49 | plt.text(x2,y1,"B";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize,) 50 | 51 | savefig("./figs/supp_plan_probs.pdf", bbox_inches = "tight") 52 | savefig("./figs/supp_plan_probs.png", bbox_inches = "tight") 53 | close() 54 | 55 | # print the difference between successful and unsuccessful rollouts 56 | delta = all_ms[2] - all_ms[1] 57 | println("post delta: ", mean(delta), " ", std(delta)/sqrt(length(delta))) 58 | println("Gaussian p = ", cdf(Normal(mean(delta), std(delta)/sqrt(length(delta))), 0)) 59 | 60 | -------------------------------------------------------------------------------- /computational_model/plot_paper_figs/plot_supp_values.jl: -------------------------------------------------------------------------------- 1 | #This script plots Figure S6 of Jensen et al. 
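# Panel A below compares value-function prediction errors (V minus reward-to-go) against a
# baseline that always predicts the average reward-to-go. A minimal sketch of the two error
# vectors being histogrammed, with an illustrative helper name and made-up numbers (the real
# values are loaded from value_function_eval.bson):
using Statistics

function demo_prediction_errors(values::Vector{Float64}, rew_to_go::Vector{Float64})
    model_errors = values .- rew_to_go              # value-function prediction errors
    baseline_errors = mean(rew_to_go) .- rew_to_go  # constant-predictor errors
    return model_errors, baseline_errors
end

demo_prediction_errors([7.5, 6.0, 8.2], [8.0, 5.0, 9.0])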
2 | 3 | include("plot_utils.jl") #various global settings 4 | 5 | # load data 6 | @load "$datadir/value_function_eval.bson" data 7 | 8 | as = [data[seed]["as"] for seed = seeds] # actions 9 | Vs = [data[seeds[i]]["Vs"][as[i] .> 0.5] for i = 1:length(seeds)] # values 10 | rtg = [data[seeds[i]]["rew_to_go"][as[i] .> 0.5] for i = 1:length(seeds)] # reward to go 11 | ts = [data[seeds[i]]["ts"][as[i] .> 0.5]/51*20 for i = 1:length(seeds)] # time within episode 12 | accs = [Vs[i] - rtg[i] for i = 1:length(seeds)] 13 | all_accs = reduce(vcat, accs) # combine across agents 14 | all_rtg = reduce(vcat, rtg) # combine across agents 15 | 16 | all_last_as = [] # last actions 17 | for a = as 18 | last_inds = sum(a .> 0.5, dims = 2)[:] 19 | push!(all_last_as, reduce(hcat, [a[i, last_inds[i]-9:last_inds[i]] for i = 1:size(a, 1)])') 20 | end 21 | all_last_as = reduce(vcat, all_last_as) 22 | 23 | ### first just plot some general statistics 24 | fig = figure(figsize = (15*cm, 7*cm)) 25 | grids = fig.add_gridspec(nrows=1, ncols=2, left=0.1, right=0.9, bottom = 0.6, top = 1.0, wspace=0.4) 26 | axs = [fig.add_subplot(grids[1,i]) for i = 1:2] 27 | 28 | # plot histogram of rew-to-go and prediction errors 29 | bins1 = -7:1:7 30 | bins2 = -7:1:7 31 | axs[1].hist(all_accs, alpha = 0.5, color = col_p, bins = bins1, label = "value function", zorder = 10000) 32 | axs[1].hist(mean(all_rtg) .- all_rtg, color = col_c, alpha = 0.5, bins = bins2, label = "constant") 33 | axs[1].set_xlabel("prediction error") 34 | axs[1].set_ylabel("frequency") 35 | axs[1].set_yticks([]) 36 | axs[1].legend(fontsize = fsize_leg, ncol = 2, loc = "upper center", bbox_to_anchor = (0.5, 1.2), frameon = false) 37 | 38 | # plot prediction errors vs. time 39 | bins = 0:1:20 40 | xs = 0.5*(bins[1:length(bins)-1]+bins[2:end]) 41 | res = zeros(length(seeds), length(xs)) 42 | for i = 1:length(seeds) 43 | res[i, :] = [mean(abs.(accs[i][(ts[i] .> bins[j]) .& (ts[i] .<= bins[j+1])])) for j = 1:length(xs)] 44 | end 45 | m, s = mean(res, dims = 1)[:], std(res, dims = 1)[:]/sqrt(length(seeds)) 46 | axs[2].plot(xs, m, color = col_p) 47 | axs[2].fill_between(xs, m-s, m+s, alpha = 0.2, color = col_p) 48 | axs[2].set_xlabel("time within episode (s)") 49 | axs[2].set_ylabel("prediction error") 50 | 51 | 52 | ### separate data by rollouts sequence length 53 | 54 | plan_lengths = 1:5 # lengths to consider 55 | plan_nums = 0:plan_lengths[end] 56 | dkeys = ["tot_plans", "plan_nums", "suc_rolls", "num_suc_rolls", "Vs", "rew_to_go"] 57 | all_accs, all_vals, all_vals0, all_accs0 = [zeros(length(seeds), length(plan_lengths), length(plan_nums)) for _ = 1:4] 58 | for (iseed, seed) = enumerate(seeds) 59 | tot_plans, plan_nums, suc_rolls, num_suc_rolls, Vs, rew_to_go = [data[seed][dkey] for dkey = dkeys] 60 | accuracy = abs.(Vs - rew_to_go) 61 | for (ilength, plan_length) = enumerate(plan_lengths) 62 | for (inum, number) = enumerate(0:plan_length) 63 | inds = ((tot_plans .== plan_length) .& (plan_nums .== number) .& (suc_rolls .< 10.5)) 64 | inds0 = (inds .& (num_suc_rolls .< 0.5)) # sequences with no successful rollouts 65 | accs = accuracy[inds] 66 | accs0 = accuracy[inds0] 67 | vals = Vs[inds] 68 | vals0 = Vs[inds0] 69 | rtg = rew_to_go[inds] 70 | # store result of this rollout 71 | all_accs[iseed, ilength, inum] = mean(accs) 72 | all_accs0[iseed, ilength, inum] = mean(accs0) 73 | all_vals[iseed, ilength, inum] = mean(vals) 74 | all_vals0[iseed, ilength, inum] = mean(vals0) 75 | end 76 | end 77 | end 78 | 79 | ## plot 80 | 81 | cols = [[0.00; 0.09; 0.32], [0.00;0.19;0.52], 
[0.19;0.39;0.72], [0.34;0.54;0.87], [0.49;0.69;1.0]] 82 | grids = fig.add_gridspec(nrows=1, ncols=4, left=0.0, right=1, bottom = 0.0, top = 0.35, wspace=0.6) 83 | axs = [fig.add_subplot(grids[1,i]) for i = 1:4] 84 | for (ilength, plan_length) = enumerate(plan_lengths) 85 | for (idat, dat) = enumerate([all_vals, all_vals0, all_accs, all_accs0]) # for each data type to plot 86 | m = mean(dat[:, ilength, :], dims = 1)[1:plan_length+1] 87 | s = std(dat[:, ilength, :], dims = 1)[1:plan_length+1]/sqrt(size(dat, 1)) 88 | xs = 0:plan_length 89 | axs[idat].plot(xs, m, label = (if idat == 1 "$plan_length rollouts" else nothing end), color = cols[ilength]) 90 | axs[idat].fill_between(xs, m-s, m+s, alpha = 0.2, color = cols[ilength]) 91 | end 92 | end 93 | #axs[1].legend() 94 | ylabels = ["value", "value [failed]", "error", "error [failed]"] 95 | for i = 1:4 axs[i].set_ylabel(ylabels[i]) end 96 | for ax = axs ax.set_xlabel("rollout number") end 97 | 98 | ## labels and save 99 | 100 | y1, y2 = 1.07, 0.42 101 | x1, x2, x3, x4 = -0.07, 0.195, 0.46, 0.745 102 | plt.text(x1+0.13,y1,"A"; ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize_label, ) 103 | plt.text(x3+0.035,y1,"B";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize_label,) 104 | plt.text(x1,y2,"C"; ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize_label, ) 105 | plt.text(x2,y2,"D";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize_label,) 106 | plt.text(x3,y2,"E";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize_label,) 107 | plt.text(x4,y2,"F";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize_label,) 108 | 109 | savefig("./figs/supp_value_function.pdf", bbox_inches = "tight") 110 | savefig("./figs/supp_value_function.png", bbox_inches = "tight") 111 | close() 112 | 113 | -------------------------------------------------------------------------------- /computational_model/plot_paper_figs/plot_supp_variable.jl: -------------------------------------------------------------------------------- 1 | #This script plots Figure S8 of Jensen et al. 
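# In panel C below, the probability of the first simulated action is compared before and
# after a rollout. Because the policy also assigns probability to performing another rollout,
# the action probabilities are first renormalized over physical actions only, i.e. divided by
# (1 - p(rollout)). A minimal sketch of that renormalization step (illustrative numbers; the
# real probabilities are loaded from the variable_causal_... files):
renormalize_action_prob(p_action, p_rollout) = p_action / (1 - p_rollout)

renormalize_action_prob(0.3, 0.5)  # 0.3 of the full policy mass -> 0.6 of the physical-action mass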
2 | 3 | include("plot_utils.jl") #various global settings 4 | using Flux 5 | 6 | fig = figure(figsize = (12*cm, 3.0*cm)) 7 | 8 | #plot thinking times against the probability of performing a rollout 9 | grids = fig.add_gridspec(nrows=1, ncols=1, left=0.0, right=0.3, bottom = 0.0, top = 1.0, wspace=0.05) 10 | ax = fig.add_subplot(grids[0,0]) 11 | #load results 12 | @load "$(datadir)RT_predictions_variable_N100_Lplan8_$(plan_epoch).bson" data 13 | allsims, RTs, pplans, dists = [data[k] for k = ["correlations"; "RTs_by_u"; "pplans_by_u"; "dists_by_u"]]; 14 | RTs, pplans, dists = [reduce(vcat, arr) for arr = [RTs, pplans, dists]] 15 | 16 | bins = 0.1:0.05:0.8 #bin edges for histogram 17 | xs = 0.5*(bins[1:length(bins)-1] + bins[2:end]) #bin centers 18 | RTs_shuff = RTs[randperm(length(RTs))] #perform a shuffle 19 | dat = [RTs[(pplans .>= bins[i]) .& (pplans .< bins[i+1])] for i = 1:length(bins)-1] #real data 20 | #create shuffle 21 | dat_shuff = [RTs_shuff[(pplans .>= bins[i]) .& (pplans .< bins[i+1])] for i = 1:length(bins)-1] 22 | #data to plot 23 | m, m_c = [[mean(d) for d = dat] for dat = [dat, dat_shuff]] 24 | s, s_c = [[std(d)/sqrt(length(d)) for d = dat] for dat = [dat, dat_shuff]] 25 | 26 | #plot the data 27 | ax.bar(xs, m, color = col_p, width = 0.04, linewidth = 0, label = "data") 28 | ax.errorbar(xs, m, yerr = s, fmt = "none", color = "k", capsize = 2, lw = 1.5) 29 | ax.errorbar(xs, m_c, yerr = s_c, fmt = "-", color = col_c, capsize = 2, lw = 1.5, label = "shuffle") 30 | ax.set_xlabel(L"$\pi$"*"(rollout)") 31 | ax.set_ylabel("thinking time (ms)") 32 | ax.set_yticks([0;50;100;150;200;250]) 33 | ax.set_ylim(0, 250) 34 | ax.legend(frameon = false, fontsize = fsize_leg) 35 | 36 | m = mean(allsims, dims = 1)[1:2] 37 | s = std(allsims, dims = 1)[1:2] / sqrt(size(allsims, 1)) 38 | println("mean and sem of correlations: ", m, " ", s) #print results 39 | 40 | 41 | # plot performance vs nplan 42 | 43 | #start by loading and extracting data 44 | @load "$datadir/perf_by_n_variable_N100_Lplan8.bson" res_dict 45 | seeds = sort([k for k = keys(res_dict)]) 46 | Nseed = length(seeds) 47 | ms1, ms2, bs, es1, es2 = [], [], [], [], [] 48 | for (is, seed) = enumerate(seeds) #for each model 49 | #time within trial, distance to goal, and policy 50 | dts, mindists, policies = [res_dict[seed][k] for k = ["dts"; "mindists"; "policies"]] 51 | #select episodes where the trial finished for all rollout numbers 52 | keepinds = findall((.~isnan.(sum(dts, dims = (1,3))[:])) .& (mindists[:, 2] .>= 0)) 53 | new_dts = dts[:, keepinds, :] 54 | new_mindists = mindists[keepinds, 2] 55 | policies = policies[:, keepinds, :, :, :] 56 | #mean performance across episodes with (m1) and without (m2) rollout feedback 57 | m1, m2 = mean(new_dts[1,:,:], dims = 1)[:], mean(new_dts[2,:,:], dims = 1)[:] 58 | push!(ms1, m1); push!(ms2, m2); push!(bs, mean(new_mindists)) #also store optimal (bs) 59 | p1, p2 = policies[1, :, :, :, :], policies[2, :, :, :, :] #extract log policies 60 | p1, p2 = [p .- Flux.logsumexp(p, dims = 4) for p = [p1, p2]] #normalize 61 | e1, e2 = [-sum(exp.(p) .* p, dims = 4)[:, :, :, 1] for p = [p1, p2]] #entropy 62 | m1, m2 = [mean(e[:,:,1], dims = 1)[:] for e = [e1,e2]] #only consider entropy of first action 63 | push!(es1, m1); push!(es2, m2) #store entropies 64 | end 65 | #concatenate across seeds 66 | ms1, ms2, es1, es2 = [reduce(hcat, arr) for arr = [ms1, ms2, es1, es2]] 67 | # compute mean and std across seeds 68 | m1, s1 = mean(ms1, dims = 2)[:], std(ms1, dims = 2)[:]/sqrt(Nseed) 69 | m2, s2 = mean(ms2, 
dims = 2)[:], std(ms2, dims = 2)[:]/sqrt(Nseed) 70 | me1, se1 = mean(es1, dims = 2)[:], std(es1, dims = 2)[:]/sqrt(Nseed) 71 | me2, se2 = mean(es2, dims = 2)[:], std(es2, dims = 2)[:]/sqrt(Nseed) 72 | nplans = (1:length(m1)) .- 1 # 73 | 74 | # plot performance vs number of rollouts 75 | grids = fig.add_gridspec(nrows=1, ncols=1, left=0.41, right=0.62, bottom = 0.0, top = 1.0, wspace=0.50) 76 | ax = fig.add_subplot(grids[1,1]) 77 | ax.plot(nplans,m1, ls = "-", color = col_p, label = "agent") #mean 78 | ax.fill_between(nplans,m1-s1,m1+s1, color = col_p, alpha = 0.2) #standard error 79 | plot([nplans[1]; nplans[end]], ones(2)*mean(bs), color = col_c, ls = "-", label = "optimal") #optimal baseline 80 | legend(frameon = false, loc = "upper right", fontsize = fsize_leg, handlelength=1.5, handletextpad=0.5, borderpad = 0.0, labelspacing = 0.05) 81 | xlabel("# rollouts") 82 | ylabel("steps to goal") 83 | ylim(0.9*mean(bs), maximum(m1+s1)+0.1*mean(bs)) 84 | xticks([0;5;10;15]) 85 | 86 | # plot change in policy with successful and unsuccessful rollouts 87 | grids = fig.add_gridspec(nrows=1, ncols=2, left=0.75, right=1.00, bottom = 0, top = 1.0, wspace=0.10) 88 | for i = 1:2 #rewarded and non-rewarded rollout 89 | ms = [] 90 | for seed = seeds #iterate through random seeds 91 | @load "$(datadir)/variable_causal_N100_Lplan8_$(seed)_$(plan_epoch).bson" data 92 | #rollouts action probability under new and old policy 93 | p_simulated_actions, p_simulated_actions_old = data["p_simulated_actions"], data["p_simulated_actions_old"] 94 | #rollout probabilities 95 | p_initial_sim, p_continue_sim = data["p_initial_sim"], data["p_continue_sim"] 96 | p_simulated_actions ./= (1 .- p_continue_sim) #renormalize over actions 97 | p_simulated_actions_old ./= (1 .- p_initial_sim) #renormalize over actions 98 | inds = findall(.~isnan.(sum(p_simulated_actions, dims = 1)[:])) #make sure we have data for both scenarios 99 | push!(ms, [mean(p_simulated_actions_old[i, inds]); mean(p_simulated_actions[i, inds])]) #mean for new and old 100 | end 101 | ms = reduce(hcat, ms) #concatenate across seeds 102 | m3, s3 = mean(ms, dims = 2)[1:2], std(ms, dims = 2)[1:2] / sqrt(length(seeds)) #mean and sem across seeds 103 | 104 | # plot results 105 | ax = fig.add_subplot(grids[1,i]) 106 | ax.bar(1:2, m3, yerr = s3, color = [col_p1, col_p2][i], capsize = capsize) 107 | # plot individual data points 108 | shifts = 1:size(ms, 2); shifts = (shifts .- mean(shifts))/std(shifts)*0.2 109 | ax.scatter([1 .+ shifts; 2 .+ shifts], [ms[1, :]; ms[2, :]], color = col_point, marker = ".", s = 15, zorder = 100) 110 | ax.set_xticks(1:2) 111 | ax.set_xticklabels(["pre"; "post"]) 112 | if i == 1 #successful rollout 113 | ax.set_ylabel(L"$\pi(\hat{a}_1)$", labelpad = 0) 114 | ax.set_title("succ.", fontsize = fsize) 115 | ax.set_yticks([0.1;0.3;0.5;0.7]) 116 | else #unsuccessful rollout 117 | ax.set_title("unsucc.", fontsize = fsize) 118 | ax.set_yticks([]) 119 | end 120 | #set some parameters 121 | ax.set_ylim(0.0, 0.8) 122 | ax.set_xlim([0.4; 2.6]) 123 | ax.axhline(0.25, color = color = col_c, ls = "-") 124 | end 125 | 126 | # add labels and save 127 | y1, y2 = 1.16, 0.46 128 | x1, x2, x3 = -0.105, 0.33, 0.675 129 | plt.text(x1,y1,"A"; ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize_label, ) 130 | plt.text(x2,y1,"B";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize_label,) 131 | plt.text(x3,y1,"C";ha="left",va="top",transform=fig.transFigure,fontweight="bold",fontsize=fsize_label,) 132 | 133 | 
savefig("./figs/supp_variable_time.pdf", bbox_inches = "tight") 134 | savefig("./figs/supp_variable_time.png", bbox_inches = "tight") 135 | close() 136 | -------------------------------------------------------------------------------- /computational_model/plot_paper_figs/plot_utils.jl: -------------------------------------------------------------------------------- 1 | import Pkg 2 | Pkg.activate("../") 3 | using Revise 4 | using PyPlot, PyCall, LaTeXStrings 5 | using Random, Statistics, NaNStatistics, Distributions 6 | using BSON: @save, @load 7 | @pyimport matplotlib.gridspec as gspec 8 | @pyimport matplotlib.patches as patch 9 | Random.seed!(1) 10 | 11 | global fsize = 10 12 | global fsize_leg = 8 13 | global fsize_label = 12 14 | global cm = 1 / 2.54 15 | global datadir = "../analysis_scripts/results/" 16 | global figdir = "./figs/" 17 | global lw_wall = 5 18 | global lw_arena = 1.3 19 | global linewidth = 3 20 | global npermute = 10000 #how many permutations for permutation tests 21 | global plot_points = true # plot individual data points for n < 10 (required by NN) 22 | global capsize = 3.5 23 | 24 | # set some plotting parameters 25 | rc("font", size = fsize) 26 | rc("pdf", fonttype = 42) 27 | rc("lines", linewidth = linewidth) 28 | rc("axes", linewidth = 1) 29 | 30 | PyCall.PyDict(matplotlib."rcParams")["axes.spines.top"] = false 31 | PyCall.PyDict(matplotlib."rcParams")["axes.spines.right"] = false 32 | 33 | rc("font", family="sans-serif") 34 | PyCall.PyDict(matplotlib."rcParams")["font.sans-serif"] = "arial" 35 | 36 | 37 | ### set global color scheme ### 38 | global col_h = [0;0;0]/255 # human data 39 | global col_p = [76;127;210]/255 # RL agent 40 | global col_p1 = col_p * 0.88 #darker 41 | global col_p2 = col_p .+ [0.45; 0.35; 0.175] #lighter 42 | global col_c = [0.6,0.6,0.6] #ctrl 43 | global col_point = 0.5*(col_c+col_h) #individual data points 44 | 45 | ### select global models 46 | 47 | global seeds = 61:65 48 | global plan_epoch = 1000 49 | 50 | function get_human_inds() 51 | #get indices of human participants to analyse 52 | @load "$(datadir)/human_RT_and_rews_follow.bson" data #load data 53 | keep = findall([nanmean(RTs) for RTs = data["all_RTs"]] .< 690) #less than 690 ms RT on guided trials 54 | return keep 55 | end 56 | 57 | 58 | function plot_comparison(ax, data; xticklabs = ["", ""], ylab = "", xlab = nothing, col = "k", col2 = nothing, ylims = nothing, plot_title = nothing, yticks = nothing, rotation = 0) 59 | if col2 == nothing col2 = col end 60 | niters = size(data, 1) 61 | m = nanmean(data, dims = 1)[:] 62 | s = nanstd(data, dims = 1)[:] / sqrt(niters) 63 | xs = 1:size(data, 2) 64 | 65 | for n = 1:niters 66 | ax.scatter(xs, data[n, :], color = col2, s = 50, alpha = 0.6, marker = ".") 67 | end 68 | for n = 1:niters 69 | ax.plot(xs, data[n, :], ls = ":", color = col2, alpha = 0.6, linewidth = linewidth*2/3) 70 | end 71 | ax.errorbar(xs, m, yerr = s, fmt = "-", color = col, capsize = capsize) 72 | 73 | ax.set_xlim(1-0.5, xs[end]+0.5) 74 | if rotation == 0 ha = "center" else ha = "right" end 75 | ax.set_xticks(xs) 76 | ax.set_xticklabels(xticklabs, rotation = rotation, ha = ha, rotation_mode = "anchor") 77 | ax.set_xlabel(xlab) 78 | ax.set_ylabel(ylab) 79 | ax.set_ylim(ylims) 80 | #println(ylims, " ", yticks) 81 | if ~isnothing(yticks) ax.set_yticks(yticks) end 82 | ax.set_title(plot_title, fontsize = fsize) 83 | end 84 | 85 | # lognormal helper function 86 | 87 | Phi(x) = cdf(Normal(), x) #standard normal pdf 88 | function calc_post_mean(r; deltahat=0, muhat=0, 
sighat=0, mode = false) 89 | #compute posterior mean thinking time for a given response time 'r' 90 | if (r < deltahat+1) return 0.0 end #if response time lower than minimum delay, return 0 91 | 92 | if mode 93 | post_delay = deltahat+exp(muhat-sighat^2) 94 | post_delay = min(r, post_delay) #can at most be response time 95 | return r - post_delay 96 | end 97 | 98 | k1, k2 = 0, r - deltahat #integration limits 99 | term1 = exp(muhat+sighat^2/2) 100 | term2 = Phi((log(k2)-muhat-sighat^2)/sighat) - Phi((log(k1)-muhat-sighat^2)/sighat) 101 | term3 = Phi((log(k2)-muhat)/sighat) - Phi((log(k1)-muhat)/sighat) 102 | post_delay = (term1*term2/term3 + deltahat) #add back delta for posterior mean delay 103 | return r - post_delay #posterior mean thinking time is response minus mean delay 104 | end 105 | 106 | # Stats helper function 107 | 108 | 109 | #permutation test 110 | function permutation_test(arr1, arr2) 111 | #test whetter arr1 is larger than arr2 112 | rands = zeros(npermute) 113 | for n = 1:npermute 114 | inds = Bool.(rand(0:1, length(arr1))) 115 | b1, b2 = [arr1[inds]; arr2[.~inds]], [arr1[.~inds]; arr2[inds]] 116 | rands[n] = nanmean(b1-b2) 117 | end 118 | trueval = nanmean(arr1-arr2) 119 | return rands, trueval 120 | end 121 | 122 | # Experimental data helper function 123 | 124 | global replaydir = "../../replay_analyses/" 125 | global plot_experimental_replays = false # False by default in case we haven't run these analyses 126 | function load_exp_data(;summary = false) 127 | # This function loads experimental replay data 128 | # If 'summary' we load summary data for our supplementary figure 129 | # Else load the full experimental dataset 130 | 131 | if summary 132 | resdir = replaydir*"results/summary_data/" 133 | else 134 | resdir = replaydir*"results/decoding/" 135 | end 136 | 137 | # Filenames to load 138 | fnames = readdir(resdir); fnames = fnames[[~occursin("succ", f) for f = fnames]] 139 | rnames = [f[10:length(f)-2] for f = fnames] # Animal names+id for each file (i.e. sessions) 140 | res_dict = Dict() # Dictionary for storing results 141 | for (i_f, f) = enumerate(fnames) # For each file 142 | res = load_pickle("$resdir$f") # Load content of file 143 | res_dict[rnames[i_f]] = res # Store in our result dict 144 | end 145 | return rnames, res_dict # Return session names and results 146 | end 147 | -------------------------------------------------------------------------------- /computational_model/repeated_submission.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is for submitting to the cambridge HPC CPU cluster and re-starting the job 3 | from the most recent model once the job times out 4 | """ 5 | 6 | import time 7 | import subprocess as sb 8 | import os 9 | import numpy as np 10 | import datetime 11 | 12 | ### set parameters ### 13 | Nhidden = "100" # number of hidden units 14 | seeds = [61,62,63,64,65] # seeds to run 15 | Lplan = "8" # planning depth 16 | T = "50" # maximum time in units of physical actions 17 | last_save = "1000" # number of model iterations 18 | lrate = "1e-3" #learning rate 19 | prefix = "" #prefix something to filename 20 | 21 | loads, load_epochs, load_fnames = [], [], ["" for seed in seeds] #some defaults for loading previous models 22 | 23 | def generate_submission_file(seed, load = "false", load_epoch = "0", load_fname = ""): 24 | 25 | load = load 26 | load_epoch = load_epoch 27 | load_fname = load_fname 28 | 29 | options = """"-t 20 --project=. 
./walls_train.jl""" 30 | options += " --load "+load 31 | options += " --load_epoch "+load_epoch 32 | options += " --Nhidden "+Nhidden 33 | options += " --seed "+str(seed) 34 | options += " --Lplan "+Lplan 35 | options += " --T "+T 36 | if load_fname != "": 37 | options += " --load_fname "+load_fname 38 | options += " --lrate "+lrate 39 | if len(prefix) > 0: options += " --prefix "+prefix 40 | options += " --n_epochs "+str(int(last_save)+1) 41 | options += """ " """ 42 | 43 | substring = """#!/bin/bash 44 | #! 45 | #! Example SLURM job script for Peta4-Skylake (Skylake CPUs, OPA) 46 | #! Last updated: Mon 13 Nov 12:25:17 GMT 2017 47 | #! 48 | 49 | #!############################################################# 50 | #!#### Modify the options in this section as appropriate ###### 51 | #!############################################################# 52 | 53 | #! sbatch directives begin here ############################### 54 | #! Name of the job: 55 | #SBATCH -J metaRL 56 | #! Which project should be charged: 57 | #SBATCH -A T2-CS156-CPU 58 | #! How many whole nodes should be allocated? 59 | #SBATCH --nodes=1 60 | #! How many (MPI) tasks will there be in total? (<= nodes*32) 61 | #! The skylake/skylake-himem nodes have 32 CPUs (cores) each. 62 | #SBATCH --ntasks=32 63 | #! How much wallclock time will be required? 64 | #SBATCH --time=36:00:00 65 | #! What types of email messages do you wish to receive? 66 | #SBATCH --mail-type=NONE 67 | #! Uncomment this to prevent the job from being requeued (e.g. if 68 | #! interrupted by node failure or system downtime): 69 | ##SBATCH --no-requeue 70 | 71 | #! For 6GB per CPU, set "-p skylake"; for 12GB per CPU, set "-p skylake-himem": 72 | #SBATCH -p skylake 73 | 74 | #!SBATCH -e slurm-%j.err 75 | #!SBATCH -o slurm-%j.out 76 | 77 | #! sbatch directives end here (put any additional directives above this line) 78 | 79 | 80 | 81 | #! Notes: 82 | #! Charging is determined by core number*walltime. 83 | #! The --ntasks value refers to the number of tasks to be launched by SLURM only. This 84 | #! usually equates to the number of MPI tasks launched. Reduce this from nodes*32 if 85 | #! demanded by memory requirements, or if OMP_NUM_THREADS>1. 86 | #! Each task is allocated 1 core by default, and each core is allocated 5980MB (skylake) 87 | #! and 12030MB (skylake-himem). If this is insufficient, also specify 88 | #! --cpus-per-task and/or --mem (the latter specifies MB per node). 89 | 90 | #! Number of nodes and tasks per node allocated by SLURM (do not change): 91 | numnodes=$SLURM_JOB_NUM_NODES 92 | numtasks=$SLURM_NTASKS 93 | mpi_tasks_per_node=$(echo "$SLURM_TASKS_PER_NODE" | sed -e 's/^\\([0-9][0-9]*\\).*$/\\1/') 94 | #! ############################################################ 95 | #! Modify the settings below to specify the application's environment, location 96 | #! and launch method: 97 | 98 | #! Optionally modify the environment seen by the application 99 | #! (note that SLURM reproduces the environment at submission irrespective of ~/.bashrc): 100 | . /etc/profile.d/modules.sh # Leave this line (enables the module command) 101 | module purge # Removes all modules still loaded 102 | module load rhel7/default-peta4 # REQUIRED - loads the basic environment 103 | 104 | #! Insert additional module load commands after this line if needed: 105 | #!source /home/ktj21/.bash_profile 106 | module load miniconda/3 107 | #!conda activate ktj21 108 | 109 | #module load openmpi 110 | 111 | #! 
Full path to application executable: 112 | application="/rds/user/ktj21/hpc-work/julia-1.7.1/bin/julia" 113 | 114 | #! Run options for the application: 115 | options="""+options+""" 116 | 117 | #! Work directory (i.e. where the job will run): 118 | workdir="$SLURM_SUBMIT_DIR" # The value of SLURM_SUBMIT_DIR sets workdir to the directory 119 | # in which sbatch is run. 120 | 121 | 122 | #! Are you using OpenMP (NB this is unrelated to OpenMPI)? If so increase this 123 | #! safe value to no more than 32: 124 | export OMP_NUM_THREADS=20 125 | 126 | #! Number of MPI tasks to be started by the application per node and in total (do not change): 127 | np=$[${numnodes}*${mpi_tasks_per_node}] 128 | 129 | #! The following variables define a sensible pinning strategy for Intel MPI tasks - 130 | #! this should be suitable for both pure MPI and hybrid MPI/OpenMP jobs: 131 | export I_MPI_PIN_DOMAIN=omp:compact # Domains are $OMP_NUM_THREADS cores in size 132 | export I_MPI_PIN_ORDER=compact # Adjacent domains have minimal sharing of caches/sockets 133 | #! Notes: 134 | #! 1. These variables influence Intel MPI only. 135 | #! 2. Domains are non-overlapping sets of cores which map 1-1 to MPI tasks. 136 | #! 3. I_MPI_PIN_PROCESSOR_LIST is ignored if I_MPI_PIN_DOMAIN is set. 137 | #! 4. If MPI tasks perform better when sharing caches/sockets, try I_MPI_PIN_ORDER=compact. 138 | 139 | 140 | #! Uncomment one choice for CMD below (add mpirun/mpiexec options if necessary): 141 | 142 | #! Choose this for a MPI code (possibly using OpenMP) using Intel MPI. 143 | #CMD="mpirun -ppn $mpi_tasks_per_node -np $np $application $options" 144 | 145 | #! Choose this for a pure shared-memory OpenMP parallel program on a single node: 146 | #! (OMP_NUM_THREADS threads will be created): 147 | CMD="$application $options" 148 | 149 | #! Choose this for a MPI code (possibly using OpenMP) using OpenMPI: 150 | #CMD="mpirun -npernode $mpi_tasks_per_node -np $np $application $options" 151 | 152 | #CMD="mpirun -ppn 1 -np 1 $application $options" 153 | 154 | ############################################################### 155 | ### You should not have to change anything below this line #### 156 | ############################################################### 157 | 158 | cd $workdir 159 | echo -e "Changed directory to `pwd`.\\n" 160 | 161 | JOBID=$SLURM_JOB_ID 162 | 163 | echo -e "JobID: $JOBID\\n======" 164 | echo "Time: `date`" 165 | echo "Running on master node: `hostname`" 166 | echo "Current directory: `pwd`" 167 | 168 | if [ "$SLURM_JOB_NODELIST" ]; then 169 | #! 
Create a machine file: 170 | export NODEFILE=`generate_pbs_nodefile` 171 | cat $NODEFILE | uniq > machine.file.$JOBID 172 | echo -e "\\nNodes allocated:\\n================" 173 | echo `cat machine.file.$JOBID | sed -e 's/\\..*$//g'` 174 | fi 175 | 176 | echo -e "\\nnumtasks=$numtasks, numnodes=$numnodes, mpi_tasks_per_node=$mpi_tasks_per_node (OMP_NUM_THREADS=$OMP_NUM_THREADS)" 177 | 178 | echo -e "\\nExecuting command:\\n==================\\n$CMD\\n" 179 | 180 | eval $CMD 181 | 182 | """ 183 | 184 | return substring 185 | 186 | def create_model_name( Nhidden, T, seed, Lplan, prefix = ""): 187 | '''somewhat ugly solution of copying this from julia''' 188 | #define some useful model name 189 | mod_name = prefix+"N"+Nhidden+"_T"+T+"_Lplan"+Lplan+"_seed"+str(seed) 190 | return mod_name 191 | 192 | def last_written_epoch(task, mod_name): 193 | '''find the last saved checkpoint''' 194 | files = os.listdir('./models/'+task+'/') #all saved files 195 | exts = [f[len(mod_name):] for f in files if f[:len(mod_name)] == mod_name] 196 | epochs = [] 197 | for ext in exts: 198 | try: 199 | epochs.append(int(ext.split('_')[1])) 200 | except ValueError: 201 | None 202 | 203 | if len(epochs) == 0: #no files exist 204 | print('something went wrong and we dont have any checkpoints!') 205 | exit() 206 | 207 | max_epoch = str(int(np.sort(epochs)[-1])) #largest epoch 208 | return max_epoch 209 | 210 | if __name__ == '__main__': 211 | finished = [False for seed in seeds] 212 | if len(loads) == 0: 213 | loads = [("false" if load_fname == "" else "true") for load_fname in load_fnames] 214 | if len(load_epochs) == 0: 215 | load_epochs = ["0" for seed in seeds] 216 | most_recent_load_epoch = [load_epoch for load_epoch in load_epochs] 217 | mod_names = [create_model_name( Nhidden, T, seed, Lplan, prefix = prefix) for seed in seeds] 218 | 219 | starttime = datetime.datetime.now() 220 | starttime = starttime.strftime("%Y_%m_%d_%H_%M_%S") 221 | 222 | while not all(finished): 223 | 224 | seeds = [seeds[i] for i in range(len(seeds)) if not finished[i]] 225 | finished = [finished[i] for i in range(len(seeds)) if not finished[i]] 226 | 227 | substrs = [generate_submission_file(seeds[i], load = loads[i], load_epoch = load_epochs[i], load_fname = load_fnames[i]) for i in range(len(seeds))] 228 | slurm_outputs, pids = [], [] 229 | for i in range(len(seeds)): 230 | slurmname = "slurm_submit_"+starttime+'_'+str(seeds[i]) 231 | with open(slurmname, "w") as f: 232 | f.write(substrs[i]) 233 | 234 | print("starting new job: ", mod_names[i], starttime, loads[i], load_epochs[i]) 235 | 236 | slurm_outputs.append(sb.check_output(["sbatch", slurmname]).decode('UTF-8')) 237 | pids.append(slurm_outputs[-1].split()[-1]) #process id 238 | 239 | time.sleep(1) #wait one second for job to be submitted 240 | squeue_out = sb.check_output(["squeue", "-u", "ktj21"]).decode('UTF-8') 241 | assert pids[-1] in squeue_out #check that it's been submitted 242 | print('submitted job:', pids[-1]) 243 | 244 | running = [True for seed in seeds] 245 | tic = time.time() 246 | while any(running): 247 | running = [] 248 | time.sleep(60*60) #wait 60 mins at a time 249 | squeue_out = sb.check_output(["squeue", "-u", "ktj21"]).decode('UTF-8') 250 | for i in range(len(seeds)): 251 | if pids[i] in squeue_out: 252 | print('still running', (time.time() - tic)/60/60) 253 | running.append(True) 254 | else: 255 | running.append(False) #finished 256 | 257 | 258 | ### check file list and identify most recent checkpoint ### 259 | load_epochs = [last_written_epoch(task, 
mod_name) for mod_name in mod_names] 260 | load_fnames = ["" for seed in seeds] #load newest model rather than old load name 261 | 262 | for i in range(len(seeds)): 263 | if load_epochs[i] >= last_save or load_epochs[i] == most_recent_load_epoch[i]: 264 | #if we've reached the end or not made progress 265 | print('finished!', load_epochs[i], last_save, most_recent_load_epoch[i]) 266 | finished[i] = True 267 | loads[i] = "false" 268 | else: 269 | loads[i] = "true" # need to load from previous state 270 | most_recent_load_epoch[i] = load_epochs[i] 271 | finished[i] = False 272 | 273 | -------------------------------------------------------------------------------- /computational_model/results/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/computational_model/results/.gitkeep -------------------------------------------------------------------------------- /computational_model/src/ToPlanOrNotToPlan.jl: -------------------------------------------------------------------------------- 1 | module ToPlanOrNotToPlan 2 | 3 | include("plotting.jl") 4 | include("train.jl") 5 | include("model.jl") 6 | include("loss_hyperparameters.jl") 7 | include("environment.jl") 8 | include("a2c.jl") 9 | include("initializations.jl") 10 | include("walls.jl") 11 | include("maze.jl") 12 | include("walls_build.jl") 13 | include("io.jl") 14 | include("model_planner.jl") 15 | include("planning.jl") 16 | include("human_utils_maze.jl") 17 | include("exports.jl") 18 | include("priors.jl") 19 | include("walls_baselines.jl") 20 | 21 | rc("font", size = 16) 22 | PyCall.PyDict(matplotlib."rcParams")["axes.spines.top"] = false 23 | PyCall.PyDict(matplotlib."rcParams")["axes.spines.right"] = false 24 | 25 | end 26 | -------------------------------------------------------------------------------- /computational_model/src/environment.jl: -------------------------------------------------------------------------------- 1 | struct EnvironmentDimensions 2 | Nstates::Int 3 | Nstate_rep::Int 4 | Naction::Int 5 | T::Int 6 | Larena::Int 7 | end 8 | 9 | struct Environment 10 | initialize::Function 11 | step::Function 12 | dimensions::EnvironmentDimensions 13 | end 14 | 15 | struct WorldState 16 | agent_state 17 | environment_state 18 | planning_state 19 | end 20 | 21 | function WorldState(; agent_state, environment_state, planning_state) 22 | return WorldState(agent_state, environment_state, planning_state) 23 | end 24 | -------------------------------------------------------------------------------- /computational_model/src/exports.jl: -------------------------------------------------------------------------------- 1 | # functions 2 | 3 | # plotting.jl 4 | export plot_progress 5 | 6 | # loss_hyperparameters.jl 7 | export LossHyperparameters 8 | 9 | # model.jl 10 | export ModelProperties, ModularModel 11 | export create_model_name 12 | export build_model 13 | 14 | # environment.jl 15 | export EnvironmentDimensions, EnvironmentProperties, Environment, WorldState 16 | 17 | # a2c.jl 18 | export GRUind 19 | export run_episode, model_loss 20 | export forward_modular, sample_actions, construct_ahot, zeropad_data 21 | 22 | # wall_build.jl 23 | export Arena 24 | export build_arena, build_environment 25 | export act_and_receive_reward 26 | export update_agent_state 27 | 28 | # walls.jl 29 | export onehot_from_state, 30 | onehot_from_loc, 31 | state_from_loc, 32 | state_ind_from_state, 33 | state_from_onehot, 
34 | gen_input, 35 | get_wall_input, 36 | comp_π 37 | export WallState 38 | 39 | # train.jl 40 | export gmap 41 | 42 | # io.jl 43 | export recover_model, save_model 44 | 45 | # initializations.jl 46 | export initialize_arena, gen_maze_walls, gen_wall_walls 47 | 48 | # maze.jl 49 | export maze 50 | 51 | # plotting 52 | export arena_lines, 53 | plot_arena, 54 | plot_weiji_gif 55 | 56 | #planners 57 | export build_planner 58 | export PlanState 59 | export model_planner 60 | 61 | #analyses of human data 62 | export get_wall_rep 63 | export extract_maze_data 64 | 65 | #priors 66 | export prior_loss 67 | 68 | #baseline policies 69 | export random_policy, dist_to_rew, optimal_policy -------------------------------------------------------------------------------- /computational_model/src/human_utils_maze.jl: -------------------------------------------------------------------------------- 1 | using SQLite, DataFrames, Statistics, HypothesisTests 2 | 3 | adict = Dict("[\"Up\"]" => 3, "[\"Down\"]" => 4, "[\"Right\"]" => 1, "[\"Left\"]" => 2) 4 | 5 | function find_seps(str) 6 | seps = [ 7 | findall("]],[[", str; overlap=true) 8 | findall("]],[]", str; overlap=true) 9 | findall("[],[[", str; overlap=true) 10 | findall("[],[]", str; overlap=true) 11 | ] 12 | return sort(reduce(hcat, seps)[3, :]) 13 | end 14 | 15 | function get_wall_rep(wallstr, arena) 16 | seps = find_seps(wallstr) 17 | columns = [ 18 | wallstr[3:(seps[4] - 2)], 19 | wallstr[(seps[4] + 2):(seps[8] - 2)], 20 | wallstr[(seps[8] + 2):(seps[12] - 2)], 21 | wallstr[(seps[12] + 2):(length(wallstr) - 2)], 22 | ] 23 | subseps = [[0; find_seps(col); length(col) + 1] for col in columns] 24 | 25 | wdict = Dict( 26 | "[\"Top\"]" => 3, "[\"Bottom\"]" => 4, "[\"Right\"]" => 1, "[\"Left\"]" => 2 27 | ) 28 | 29 | new_walls = zeros(16, 4) 30 | for (i, col) in enumerate(columns) 31 | for j in 1:4 32 | ind = state_ind_from_state(arena, [i; j])[1] 33 | s1, s2 = subseps[i][j:(j + 1)] 34 | entries = split(col[((s1 + 2):(s2 - 2))], ",") 35 | for entry in entries 36 | if length(entry) > 0.5 37 | new_walls[ind, wdict[entry]] = 1 38 | end 39 | end 40 | end 41 | end 42 | return new_walls 43 | end 44 | 45 | function extract_maze_data(db, user_id, Larena; T=100, max_RT=5000, game_type = "play", 46 | skip_init = 1, skip_finit = 0) 47 | Nstates = Larena^2 48 | 49 | epis = DataFrame(DBInterface.execute( 50 | db, "SELECT * FROM episodes WHERE user_id = $user_id AND game_type = '$game_type'" 51 | )) 52 | 53 | if "attention_problem" in names(epis) #discard episodes with a failed attention check 54 | atts = epis[:, "attention_problem"] 55 | keep = findall(atts .== "null") 56 | epis = epis[keep, :] 57 | end 58 | 59 | ids = epis[:, "id"] #episode ids 60 | 61 | #make sure each episode has at least 2 steps 62 | stepnums = [size(DataFrame(DBInterface.execute(db, "SELECT * FROM steps WHERE episode_id = " * string(id))), 1) for id = ids] 63 | ids = ids[stepnums .> 1.5] 64 | 65 | inds = (1+skip_init):(length(ids)-skip_finit) #allow for discarding the first/last few episodes 66 | ids = ids[inds] 67 | 68 | batch_size = length(ids) 69 | 70 | rews, as, times = zeros(batch_size, T), zeros(batch_size, T), zeros(batch_size, T) 71 | states = ones(2, batch_size, T) 72 | trial_nums, trial_time = zeros(batch_size, T), zeros(batch_size, T) 73 | wall_loc, ps = zeros(16, 4, batch_size), zeros(16, batch_size) 74 | for b in 1:batch_size 75 | steps = DataFrame(DBInterface.execute( 76 | db, "SELECT * FROM steps WHERE episode_id = " * string(ids[b]) 77 | )) 78 | trial_num = 1 79 | t0 = 0 80 | 81 | 
wall_loc[:, :, b] = get_wall_rep(epis[inds[b], "walls"], Larena) 82 | ps[:, b] = onehot_from_state(Larena, 83 | [parse(Int, epis[inds[b], "reward"][i]) for i in [2; 4]] .+ 1 84 | ) 85 | Tb = size(steps, 1) #steps on this trial 86 | 87 | for i in reverse(1:Tb) #steps are stored in reverse order 88 | t = steps[i, "step"] 89 | if (t > 0.5) && (i < Tb-0.5 || steps[i, "action_time"] < 20000) #last action of previous episode can carry over 90 | times[b, t] = steps[i, "action_time"] 91 | rews[b, t] = Int(steps[i, "outcome"] == "[\"Hit_reward\"]") 92 | as[b, t] = adict[steps[i, "action_type"]] 93 | states[:, b, t] = [parse(Int, steps[i, "agent"][j]) for j in [2; 4]] .+ 1 94 | 95 | trial_nums[b, t] = trial_num 96 | trial_time[b, t] = t - t0 97 | if rews[b, t] > 0.5 #found reward 98 | trial_num += 1 #next trial 99 | t0 = t #reset trial_time 100 | end 101 | end 102 | end 103 | end 104 | 105 | RTs = [times[:, 1:1] (times[:, 2:T] - times[:, 1:(T - 1)])] 106 | RTs[RTs .< 0.5] .= NaN #end of trial 107 | for b in 1:batch_size 108 | rewtimes = findall(rews[b, 1:T] .> 0.5) 109 | RTs[b, rewtimes .+ 1] .-= (8 * 50) #after update; subtract the 400 ms showing that we are at reward 110 | end 111 | 112 | shot = zeros(Nstates, size(states, 2), size(states, 3)) .+ NaN 113 | for b in 1:size(states, 2) 114 | Tb = sum(as[1, :] .> 0.5) 115 | shot[:, b, 1:Tb] = onehot_from_state(Larena, Int.(states[:, b, 1:Tb])) 116 | end 117 | 118 | return ( 119 | rews, 120 | as, 121 | states, 122 | wall_loc, 123 | ps, 124 | times, 125 | trial_nums, 126 | trial_time, 127 | RTs, 128 | shot, 129 | ) 130 | end 131 | -------------------------------------------------------------------------------- /computational_model/src/initializations.jl: -------------------------------------------------------------------------------- 1 | function reset_agent_state(Larena, reward_location, batch) 2 | Nstates = Larena^2 3 | agent_state = rand(Categorical(ones(Larena) / Larena), 2, batch) #random starting location (2xbatch) 4 | #make sure we cannot start at reward! 
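    # the loop below redraws each start location uniformly over the Nstates - 1
    # non-reward states by zeroing the probability mass of the reward location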
5 | for b in 1:batch 6 | tele_reward_location = ones(Nstates) / (Nstates - 1) 7 | tele_reward_location[Bool.(reward_location[:, b])] .= 0 8 | agent_state[:, b] = state_from_loc( 9 | Larena, rand(Categorical(tele_reward_location), 1, 1) 10 | ) 11 | end 12 | return agent_state 13 | end 14 | 15 | ### task-specific initialization ### 16 | function gen_maze_walls( 17 | Larena, batch 18 | ) 19 | wall_loc = zeros(Float32, Larena^2, 4, batch) #whether there is a wall between neighboring agent_states 20 | for b in 1:batch 21 | wall_loc[:, :, b] = maze(Larena) 22 | end 23 | return wall_loc 24 | end 25 | 26 | function initialize_arena(reward_location, agent_state, batch, model_properties, environment_dimensions, initial_plan_state;initial_params = []) 27 | Zygote.ignore() do 28 | Larena=environment_dimensions.Larena; Nstates = Larena^2 29 | rew_loc = rand(Categorical(ones(Nstates) / Nstates), batch) 30 | if maximum(reward_location) <= 0 31 | reward_location = zeros(Float32, Nstates, batch) #Nstates x batch 32 | for b in 1:batch 33 | reward_location[rew_loc[b], b] = 1.0f0 34 | end 35 | end 36 | 37 | if maximum(agent_state) <= 0 38 | agent_state = reset_agent_state(Larena, reward_location, batch) 39 | end 40 | 41 | if length(initial_params) > 0 #load environment 42 | wall_loc = initial_params 43 | else 44 | wall_loc = gen_maze_walls(Larena, batch) 45 | end 46 | 47 | #note: start at t=1 for backwards compatibility 48 | world_state = WorldState(; 49 | environment_state=WallState(; 50 | wall_loc=Int32.(wall_loc), reward_location=Float32.(reward_location), time = ones(Float32, batch), 51 | ), 52 | agent_state=Int32.(agent_state), 53 | planning_state = initial_plan_state(batch) 54 | ) 55 | 56 | ahot = zeros(Float32, environment_dimensions.Naction, batch) #should use 'Naction' from somewhere 57 | rew = zeros(Float32, 1, batch) #no reward or actions yet 58 | x = gen_input(world_state, ahot, rew, environment_dimensions, model_properties) 59 | 60 | return world_state, Float32.(x) 61 | end 62 | end 63 | -------------------------------------------------------------------------------- /computational_model/src/io.jl: -------------------------------------------------------------------------------- 1 | using BSON: @load, @save 2 | 3 | function recover_model(filename) 4 | 5 | #@load filename * "_opt.bson" opt 6 | opt = nothing 7 | @load filename * "_hps.bson" hps 8 | @load filename * "_progress.bson" store 9 | @load filename * "_mod.bson" network 10 | @load filename * "_policy.bson" policy 11 | @load filename * "_prediction.bson" prediction 12 | return network, opt, store, hps, policy, prediction 13 | end 14 | 15 | function save_model(m, store, opt, filename, environment, loss_hp; Lplan) 16 | model_properties = m.model_properties 17 | network = m.network 18 | hps = Dict( 19 | "Nhidden" => model_properties.Nhidden, 20 | "T" => environment.dimensions.T, 21 | "Larena" => environment.dimensions.Larena, 22 | "Nin" => model_properties.Nin, 23 | "Nout" => model_properties.Nout, 24 | "GRUind" => ToPlanOrNotToPlan.GRUind, 25 | "βp" => loss_hp.βp, 26 | "βe" => loss_hp.βe, 27 | "βr" => loss_hp.βr, 28 | "Lplan" => Lplan, 29 | ) 30 | @save filename * "_progress.bson" store 31 | @save filename * "_mod.bson" network 32 | @save filename * "_opt.bson" opt 33 | @save filename * "_hps.bson" hps 34 | 35 | if :policy in fieldnames(typeof(m)) 36 | policy = m.policy 37 | @save filename * "_policy.bson" policy 38 | end 39 | if :prediction in fieldnames(typeof(m)) 40 | prediction = m.prediction 41 | @save filename * "_prediction.bson" 
prediction 42 | end 43 | if :prediction_state in fieldnames(typeof(m)) 44 | prediction_state = m.prediction_state 45 | @save filename * "_prediction_state.bson" prediction_state 46 | end 47 | end 48 | -------------------------------------------------------------------------------- /computational_model/src/loss_hyperparameters.jl: -------------------------------------------------------------------------------- 1 | struct LossHyperparameters 2 | βv::Float32 3 | βe::Float32 4 | βp::Float32 5 | βr::Float32 6 | end 7 | 8 | function LossHyperparameters(; βv, βe, βp, βr) 9 | return LossHyperparameters(βv, βe, βp, βr) 10 | end 11 | 12 | -------------------------------------------------------------------------------- /computational_model/src/maze.jl: -------------------------------------------------------------------------------- 1 | function neighbor(cell, dir, msize) 2 | neigh = ((cell + 1 * dir .+ msize .- 1) .% msize) .+ 1 3 | return neigh 4 | end 5 | 6 | function neighbors(cell, msize; wrap = true) 7 | dirs = [[1, 0], [-1, 0], [0, 1], [0, -1]] 8 | Ns = [cell+ dirs[a] for a = 1:4] 9 | as = 1:4 10 | if wrap # states outside arena pushed to other side 11 | Ns = [((N .+ msize .- 1) .% msize) .+ 1 for N in Ns] 12 | else # states outside arena not considered 'neighbors' 13 | inds = findall( (minimum.(Ns) .> 0.5) .& (maximum.(Ns) .< msize+0.5) ) 14 | Ns, as = Ns[inds], as[inds] 15 | end 16 | return Ns, as 17 | end 18 | 19 | function walk(maz::Array, nxtcell::Vector, msize, visited::Vector=[]; wrap = true) 20 | dir_map = Dict(1 => 2, 2 => 1, 3 => 4, 4 => 3) 21 | push!(visited, (nxtcell[1] - 1) * msize + nxtcell[2]) #add to list of visited cells 22 | 23 | neighs, as = neighbors(nxtcell, msize, wrap = wrap) # get list of neighbors 24 | 25 | for nnum in randperm(length(neighs)) #for each neighbor in randomly shuffled list 26 | neigh, a = neighs[nnum], as[nnum] # corresponding state and action 27 | ind = (neigh[1] - 1) * msize + neigh[2] # convert from coordinates to index 28 | if ind ∉ visited #check that we haven't been there 29 | maz[nxtcell[1], nxtcell[2], a] = 0.0f0 #remove wall 30 | maz[neigh[1], neigh[2], dir_map[a]] = 0.0f0 #remove reverse wall 31 | maz, visited = walk(maz, neigh, msize, visited, wrap = wrap) # 32 | end 33 | end 34 | return maz, visited 35 | end 36 | 37 | function maze(msize; wrap = true) 38 | dirs = [[1, 0], [-1, 0], [0, 1], [0, -1]] 39 | dir_map = Dict(1 => 2, 2 => 1, 3 => 4, 4 => 3) 40 | maz = ones(Float32, msize, msize, 4) #start with walls everywhere 41 | cell = rand(1:msize, 2) #where do we start? 
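    # walk() below carves the maze with a recursive depth-first walk: it visits every
    # cell once and removes the wall (and its reverse) between consecutive cells,
    # giving a 'perfect' maze; the extra holes removed afterwards add degenerate paths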
42 | maz, visited = walk(maz, cell, msize, wrap = wrap) #walk through maze 43 | 44 | # remove a couple of additional walls to increase degeneracy 45 | if wrap 46 | holes = Int(3 * (msize - 3)) #3 for Larena=4, 6 for Larena = 5 47 | else 48 | holes = Int(4 * (msize - 3)) #4 for Larena=4, 8 for Larena = 5 49 | # note permanent walls 50 | maz[msize, :, 1] .= 0.5f0; maz[1, :, 2] .= 0.5f0 51 | maz[:, msize, 3] .= 0.5f0; maz[:, 1, 4] .= 0.5f0 52 | end 53 | for _ in 1:holes 54 | walls = findall(maz .== 1) 55 | wall = rand(walls) 56 | cell, a = [wall[1]; wall[2]], wall[3] 57 | 58 | neigh = neighbor([cell[1]; cell[2]], dirs[a], msize) 59 | maz[cell[1], cell[2], a] = 0.0f0 #remove wall 60 | maz[neigh[1], neigh[2], dir_map[a]] = 0.0f0 #remove reverse wall 61 | end 62 | maz[maz .== 0.5] .= 1f0 # reinstate permanent walls 63 | 64 | maz = reshape(permutedims(maz, [2, 1, 3]), prod(size(maz)[1:2]), 4) 65 | 66 | return Float32.(maz) 67 | end 68 | -------------------------------------------------------------------------------- /computational_model/src/model.jl: -------------------------------------------------------------------------------- 1 | using Zygote, Flux 2 | 3 | struct ModelProperties 4 | Nout::Int 5 | Nhidden::Int 6 | Nin::Int 7 | Lplan::Int 8 | greedy_actions::Bool 9 | no_planning::Any #if true, never stand still 10 | end 11 | 12 | struct ModularModel 13 | model_properties::ModelProperties 14 | network::Any 15 | policy::Any 16 | prediction::Any 17 | forward::Function 18 | end 19 | 20 | function ModelProperties(; Nout, Nhidden, Nin, Lplan, greedy_actions, no_planning = false) 21 | return ModelProperties(Nout, Nhidden, Nin, Lplan, greedy_actions, no_planning) 22 | end 23 | 24 | function Modular_model(mp::ModelProperties, Naction::Int; Nstates = nothing, neighbor = false) 25 | # define our model! 
26 | network = Chain(GRU(mp.Nin, mp.Nhidden)) 27 | policy = Chain(Dense(mp.Nhidden, Naction+1)) #policy and value function 28 | Npred_out = mp.Nout - Naction - 1 29 | prediction = Chain(Dense(mp.Nhidden+Naction, Npred_out, relu), Dense(Npred_out, Npred_out)) 30 | return ModularModel(mp, network, policy, prediction, forward_modular) 31 | end 32 | 33 | function build_model(mp::ModelProperties, Naction::Int) 34 | return Modular_model(mp, Naction) 35 | end 36 | 37 | function create_model_name( 38 | Nhidden::Int, 39 | T::Int, 40 | seed, 41 | Lplan::Int; 42 | prefix = "" 43 | ) 44 | #define some useful model name 45 | mod_name = 46 | prefix* 47 | "N$Nhidden" * 48 | "_T$T" * 49 | "_Lplan$Lplan" * 50 | "_seed$seed" 51 | 52 | return mod_name 53 | end 54 | -------------------------------------------------------------------------------- /computational_model/src/model_planner.jl: -------------------------------------------------------------------------------- 1 | using Distributions 2 | 3 | function model_tree_search(goal, world_state, model, h_rnn, plan_inds, times, ed, mp, planner; Print = false) 4 | 5 | Larena, Naction = ed.Larena, ed.Naction 6 | Nstates = Larena^2 7 | 8 | batch = size(h_rnn, 2) 9 | path = zeros(4, planner.Lplan, batch) 10 | all_Vs = zeros(Float32, batch) #value functions 11 | found_rew = zeros(Float32, batch) #did I finish planning 12 | plan_states = zeros(Int32, planner.Lplan, batch) 13 | wall_loc = world_state.environment_state.wall_loc 14 | 15 | #only consider planning states 16 | h_rnn = h_rnn[:, plan_inds] 17 | goal = goal[plan_inds] 18 | times = times[plan_inds] 19 | wall_loc = wall_loc[:, :, plan_inds] 20 | ytemp = h_rnn #same for GRU 21 | 22 | agent_input = zeros(Float32, mp.Nin) #instantiate 23 | new_world_state = world_state 24 | 25 | for n_steps = 1:planner.Lplan 26 | batch = length(goal) #number of active states 27 | 28 | if n_steps > 1.5 #start from current hidden state 29 | ### generate new output ### 30 | h_rnn, ytemp = model.network[GRUind].cell(h_rnn, agent_input) #forward pass 31 | end 32 | 33 | ### generate actions from hidden activity ### 34 | logπ_V = model.policy(ytemp) 35 | #normalize over actions 36 | logπ = logπ_V[1:4, :] .- Flux.logsumexp(logπ_V[1:4, :], dims = 1) #softmax 37 | Vs = logπ_V[6, :] / 10f0 #range ~ [0,1] 38 | 39 | πt = exp.(logπ) 40 | a = zeros(Int32, 1, batch) #sample actions 41 | a[:] = Int32.(rand.(Categorical.([πt[:, b] for b = 1:batch]))) 42 | #a[:] = Int32.(rand.(Categorical.([ones(4) / 4 for b = 1:batch]))) #random action 43 | 44 | ### record actions ### 45 | for (ib, b) = enumerate(plan_inds) 46 | path[a[1, ib], n_steps, b] = 1f0 #'a' in local coordinates, 'path' in global 47 | end 48 | 49 | ### generate predictions ### 50 | ahot = zeros(Float32, Naction, batch) #one-hot 51 | for b = 1:batch ahot[a[1, b], b] = 1f0 end 52 | prediction_input = [ytemp; ahot] #input to prediction module 53 | prediction_output = model.prediction(prediction_input) #output from prediction module 54 | 55 | ### draw new states ### 56 | spred = prediction_output[1:Nstates, :] #predicted states (Nstates x batch) 57 | spred = spred .- Flux.logsumexp(spred; dims=1) #softmax over states 58 | state_dist = exp.(spred) #state distribution 59 | new_states = Int32.(argmax.([state_dist[:, b] for b = 1:batch])) #maximum likelihood new states 60 | 61 | Print && println(n_steps, " ", batch, " ", mean(maximum(πt, dims = 1)), " ", mean(maximum(state_dist, dims = 1))) 62 | 63 | ### record information about having finished ### 64 | not_finished = findall(new_states .!= goal) 
#vector of states that have not finished! 65 | finished = findall(new_states .== goal) #found the goal location on these ones 66 | 67 | all_Vs[plan_inds] = Vs #store latest value 68 | plan_states[n_steps, plan_inds] = new_states #store states 69 | found_rew[plan_inds[finished]] .+= 1f0 #record where we found the goal location 70 | 71 | if length(not_finished) == 0 return path, all_Vs, found_rew, plan_states end #finish if all done 72 | 73 | ### only consider active states going forward ### 74 | h_rnn = h_rnn[:, not_finished] 75 | goal = goal[not_finished] 76 | plan_inds = plan_inds[not_finished] 77 | times = times[not_finished] .+ 1f0 #increment time 78 | wall_loc = wall_loc[:, :, not_finished] 79 | reward_location = onehot_from_loc(Larena, goal) #onehot 80 | xplan = zeros(Float32, planner.Nplan_in, length(goal)) #no planning input 81 | 82 | ###reward inputs ### 83 | rew = zeros(Float32, 1, length(not_finished)) #we continue with the ones that did not get reward 84 | 85 | ### update world state ### 86 | new_world_state = WorldState(; 87 | agent_state=state_from_loc(Larena, new_states[not_finished]'), 88 | environment_state=WallState(; wall_loc, reward_location, time = times), 89 | planning_state = PlanState(xplan, nothing) 90 | ) 91 | 92 | ### generate input ### 93 | agent_input = gen_input(new_world_state, ahot[:, not_finished], rew, ed, mp) 94 | 95 | end 96 | return path, all_Vs, found_rew, plan_states 97 | end 98 | 99 | function model_planner(world_state, 100 | ahot, 101 | ed, 102 | agent_output, 103 | at_rew, 104 | planner, 105 | model, 106 | h_rnn, 107 | mp; 108 | Print = false, 109 | returnall = false, 110 | true_transition = false 111 | ) 112 | 113 | Larena = ed.Larena 114 | Naction = ed.Naction 115 | Nstates = ed.Nstates 116 | batch = size(ahot, 2) 117 | times = world_state.environment_state.time 118 | 119 | plan_inds = findall(Bool.(ahot[5, :]) .& (.~at_rew)) #everywhere we stand still not at the reward 120 | #rpred = agent_output[(Naction + Nstates + 2):(Naction + Nstates + 1 + Nstates), :] 121 | rpred = agent_output[(Naction + Nstates + 3):(Naction + Nstates + 2 + Nstates), :] 122 | goal = [argmax(rpred[:, b]) for b = 1:batch] #index of ML goal location 123 | 124 | ### agent-driven planning ### 125 | path, all_Vs, found_rew, plan_states = model_tree_search(goal, world_state, model, h_rnn, plan_inds, times, ed, mp, planner, Print = Print) 126 | 127 | xplan = zeros(planner.Nplan_in, batch) 128 | for b = 1:batch 129 | xplan[:, b] = [path[:, :, b][:]; found_rew[b]] 130 | end 131 | planning_state = PlanState(xplan, plan_states) 132 | 133 | returnall && return planning_state, plan_inds, (path, all_Vs, found_rew, plan_states) 134 | return planning_state, plan_inds 135 | end 136 | -------------------------------------------------------------------------------- /computational_model/src/planning.jl: -------------------------------------------------------------------------------- 1 | struct PlanState 2 | plan_input::Array{Float32} 3 | plan_cache::Any 4 | end 5 | 6 | struct Planner 7 | Lplan::Int 8 | Nplan_in::Int 9 | Nplan_out::Int 10 | planning_time::Float32 11 | planning_cost::Float32 12 | planning_algorithm::Function 13 | constant_rollout_time::Bool 14 | end 15 | 16 | function none_planner(world_state, 17 | ahot, 18 | ep, 19 | agent_output, 20 | at_rew, 21 | planner, 22 | model, 23 | h_rnn, 24 | mp) 25 | batch = size(ahot, 2) 26 | xplan = zeros(Float32, 0, batch) #no input 27 | plan_inds = [] #no indices 28 | plan_cache = nothing #no cache 29 | planning_state = PlanState(xplan, 
plan_cache) 30 | 31 | return planning_state, plan_inds 32 | end 33 | 34 | function build_planner(Lplan, Larena; planning_time = 1f0, planning_cost = 0f0, constant_rollout_time = true) 35 | Nstates = Larena^2 36 | 37 | if Lplan <= 0.5 38 | Nplan_in, Nplan_out = 0, 0 39 | planning_algorithm = none_planner 40 | initial_plan_state = (batch -> PlanState([], [])) 41 | else 42 | Nplan_in = 4*Lplan+1 #action sequence and whether we ended at the reward location 43 | Nplan_out = Nstates #rew location 44 | planning_algorithm = model_planner 45 | planning_time = 0.3f0 #planning needs to be fairly cheap here 46 | initial_plan_state = (batch -> PlanState([], [])) #we don't use a cache 47 | end 48 | 49 | planner = Planner(Lplan, Nplan_in, Nplan_out, planning_time, planning_cost, planning_algorithm, constant_rollout_time) 50 | return planner, initial_plan_state 51 | end 52 | 53 | -------------------------------------------------------------------------------- /computational_model/src/plotting.jl: -------------------------------------------------------------------------------- 1 | using PyPlot 2 | using PyCall 3 | ### set some reasonable plotting defaults 4 | rc("font"; size=16) 5 | PyCall.PyDict(matplotlib."rcParams")["axes.spines.top"] = false 6 | PyCall.PyDict(matplotlib."rcParams")["axes.spines.right"] = false 7 | 8 | function plot_progress(rews, vals; fname="figs/progress.png") 9 | figure(; figsize=(6, 2.5)) 10 | axs = 120 .+ (1:2) 11 | data = [rews, vals] 12 | ts = 1:length(rews) 13 | labs = ["reward", "prediction"] 14 | for i in 1:2 15 | subplot(axs[i]) 16 | plot(ts, data[i], "k-") 17 | xlabel("epochs") 18 | ylabel(labs[i]) 19 | title(labs[i]) 20 | end 21 | tight_layout() 22 | savefig(fname; bbox_inches="tight") 23 | close() 24 | return nothing 25 | end 26 | 27 | 28 | ## plotting utils 29 | 30 | function arena_lines(ps, wall_loc, Larena; rew=true, col="k", rew_col = "k", lw_arena = 1., col_arena = ones(3)*0.3, lw_wall = 10) 31 | Nstates = Larena^2 32 | for i in 0:Larena 33 | axvline(i + 0.5; color=col_arena, lw = lw_arena) 34 | axhline(i + 0.5; color=col_arena, lw = lw_arena) 35 | end 36 | 37 | if rew 38 | rew_loc = state_from_onehot(Larena, ps) 39 | scatter([rew_loc[1]], [rew_loc[2]]; c=rew_col, marker="*", s=350, zorder = 50) #reward location 40 | end 41 | 42 | for s in 1:Nstates #for each state 43 | for i in 1:4 #for each neighbor 44 | if Bool(wall_loc[s, i]) 45 | state = state_from_loc(Larena, s) 46 | if i == 1 #wall to the right 47 | z1, z2 = state + [0.5; 0.5], state + [0.5; -0.5] 48 | elseif i == 2 #wall to the left 49 | z1, z2 = state + [-0.5; 0.5], state + [-0.5; -0.5] 50 | elseif i == 3 #wall above 51 | z1, z2 = state + [0.5; 0.5], state + [-0.5; 0.5] 52 | elseif i == 4 #wall below 53 | z1, z2 = state + [0.5; -0.5], state + [-0.5; -0.5] 54 | end 55 | plot([z1[1]; z2[1]], [z1[2]; z2[2]]; color=col, ls="-", lw=lw_wall) 56 | end 57 | end 58 | end 59 | xlim(0.49, Larena + 0.52) 60 | ylim(0.48, Larena + 0.51) 61 | xticks([]) 62 | yticks([]) 63 | return axis("off") 64 | end 65 | 66 | function plot_arena(ps, wall_loc, Larena; ind=1) 67 | ps = ps[:, ind] 68 | wall_loc = wall_loc[:, :, ind] 69 | figure(; figsize=(6, 6)) 70 | arena_lines(ps, wall_loc, Larena) 71 | savefig("figs/wall/test_arena.png"; bbox_inches="tight") 72 | return close() 73 | end 74 | 75 | 76 | function plot_rollout(state, rollout, wall, Larena) 77 | if Bool(rollout[end]) col = [0.5, 0.8, 0.5] else col = [0.5, 0.5, 0.8] end 78 | rollout = Int.(rollout) 79 | #new_state = state 80 | for a = rollout[1:length(rollout)-1] 81 | if a > 
0.5 82 | if wall[state_ind_from_state(Larena, state)[1], a] > 0.5 83 | new_state = state 84 | else 85 | new_state = state + [[1;0],[-1;0],[0;1],[0;-1]][a] 86 | end 87 | new_state = (new_state .+ Larena .- 1) .% Larena .+ 1 88 | x1, x2 = [f(state[1], new_state[1]) for f = [min, max]] 89 | y1, y2 = [f(state[2], new_state[2]) for f = [min, max]] 90 | 91 | lw = 5 92 | #println(a, " ", state, " ", new_state, " ", x1, " ", x2, " ", y1, " ", y2) 93 | if x2 - x1 > 1.5 94 | plot([x1,x1-0.5], [y1,y2], ls = "-", color = col, lw = lw) 95 | plot([x2,x2+0.5], [y1,y2], ls = "-", color = col, lw = lw) 96 | elseif y2 - y1 > 1.5 97 | plot([x1,x2], [y1,y1-0.5], ls = "-", color = col, lw = lw) 98 | plot([x1,x2], [y2,y2+0.5], ls = "-", color = col, lw = lw) 99 | else 100 | plot([x1,x2], [y1,y2], ls = "-", color = col, lw = lw) 101 | end 102 | state = new_state 103 | end 104 | end 105 | end 106 | 107 | function plot_weiji_gif( 108 | ps, 109 | wall_loc, 110 | states, 111 | as, 112 | rews, 113 | Larena, 114 | RTs, 115 | fname; 116 | Tplot=10, 117 | res = 60, 118 | minframe = 3, #number of movement frames 119 | figsize = (4,4), 120 | rewT = 400, #delay at reward in ms 121 | fix_first_RT = true, 122 | first_RT = 500, 123 | plot_rollouts = false, #do we explicitly plot rollouts? 124 | rollout_time = 120, #duration of rollout plotting in ms 125 | rollouts = [], #array of actual rollouts, 126 | dpi = 80, #image resolution 127 | plot_rollout_frames = false #plot frames for rollouts even if we don't plot the rollouts 128 | ) 129 | #plot gif of agent moving through each batch 130 | #ps is Nstates x batch 131 | #wall_loc is Nstates x 4 x batch 132 | #states is 2 x batch x Tmax 133 | #as is batch x T 134 | #RTs are the reaction times for each step in ms 135 | #T_act is the time taken for an action in ms 136 | #T_rew is the time taken at reward in ms 137 | #res is the resolution (ms / frame) 138 | #Tplot is number of _seconds_ to plot (in real time) 139 | 140 | ##the minimum plotted reaction time is res*minframe 141 | 142 | run(`sh -c "mkdir -p $(fname)_temp"`) 143 | Tplot = Tplot*1e3/res #now in units of frames 144 | if fix_first_RT RTs[:, 1] .= first_RT end #fix the first RT since we think this is probably quite noisy 145 | 146 | for batch in 1:size(ps, 2) 147 | bstr = lpad(batch, 2, "0") 148 | rew_loc = state_from_onehot(Larena, ps[:, batch]) 149 | ### plot arena 150 | ### plot movement 151 | Rtot = 0 152 | t = 0 # real time 153 | anum = 0 # number of actions 154 | rew_col = "lightgrey" 155 | 156 | while (anum < sum(as[batch, :] .> 0.5)) && (t < Tplot) 157 | anum += 1 158 | 159 | astr = lpad(anum, 3, "0") 160 | println(bstr, " ", astr) 161 | 162 | RT = RTs[batch, anum] #reaction time for this action 163 | nframe = max(Int(round(RT/res)), minframe) #number of frames to plot (at least three) 164 | rewframes = 0 #no reward frames 165 | 166 | if rews[batch, anum] > 0.5 167 | rewframes = Int(round(rewT/res)) #add a singular frame at reward 168 | end 169 | 170 | if (anum > 1.5) && (rews[batch, anum-1] > 0.5) 171 | #Rtot += 1 #update total reward 172 | rew_col = "k" #show that we've found the reward 173 | end 174 | 175 | R_increased = false #have we increased R for this action 176 | frames = (minframe - nframe + 1):(minframe+rewframes) 177 | frolls = Dict(f => 0 for f = frames) #dictionary pointing to rollout; 0 is None 178 | 179 | if plot_rollouts || plot_rollout_frames #either plot rollouts or the corresponding frames 180 | nroll = sum(rollouts[batch, anum, 1, :] .> 0.5) #how many rollouts? 
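            # each rollout is allotted a fixed number of frames (f_per_roll, computed
            # just below) ahead of the movement frames; frolls maps every frame index
            # to the rollout drawn during that frame, with 0 meaning no rollout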
181 | println("rolls: ", nroll) 182 | f_per_roll = Int(round(rollout_time/res)) #frames per rollout 183 | frames = min(frames[1], -nroll*f_per_roll+1):frames[end] #make sure there is enough frames for plotting rollouts 184 | frolls = Dict(f => 0 for f = frames) #dictionary pointing to rollout; 0 is None 185 | 186 | for roll = 1:nroll 187 | new_rolls = (-(f_per_roll*roll-1):-(f_per_roll*(roll-1))) # f_per_roll frame intervals 188 | if nroll == 1 frac = 0.5 else frac = (roll-1)/(nroll-1) end #[0, 1] 189 | new_roll_1 = Int(round(frames[1]*frac - (f_per_roll-1)*(1-frac))) 190 | new_rolls = new_roll_1:(new_roll_1+f_per_roll-1) 191 | for r = new_rolls frolls[r] = nroll-roll+1 end #; println(roll, " ", r) end 192 | end 193 | end 194 | 195 | 196 | for f = frames 197 | state = states[:, batch, anum] 198 | fstr = lpad(f - frames[1], 3, "0") 199 | 200 | frac = min(max(0, (f - 1) / minframe), 1) 201 | figure(; figsize=figsize) 202 | 203 | arena_lines(ps[:, batch], wall_loc[:, :, batch], Larena; col="k", rew_col = rew_col) 204 | 205 | col = "b" 206 | if (rewframes > 0) && (frac >= 1) 207 | col = "g" #colour green when at reward 208 | if ~R_increased Rtot, R_increased = Rtot+1, true end #increase R because we found the reward 209 | end 210 | 211 | if plot_rollouts #plot the rollout 212 | if frolls[f] > 0.5 plot_rollout(state, rollouts[batch, anum, :, frolls[f]], wall_loc[:, :, batch], Larena) end 213 | end 214 | 215 | a = as[batch, anum] #higher resolution 216 | state += frac * [Int(a == 1) - Int(a == 2); Int(a == 3) - Int(a == 4)] #move towards next state 217 | state = (state .+ Larena .- 0.5) .% Larena .+ 0.5 218 | scatter([state[1]], [state[2]]; marker="o", color=col, s=200, zorder = 100) 219 | 220 | tstr = lpad(t, 3, "0") 221 | t += 1 222 | realt = t*res*1e-3 223 | println(step, " ", f, " ", round(frac, digits = 2), " ", RT, " ", round(realt, digits=1)) 224 | title("t = " * string(round(realt, digits=1)) * " (R = " * string(Rtot) * ")") 225 | 226 | astr = lpad(anum, 2, "0") 227 | if t <= Tplot 228 | savefig( 229 | "$(fname)_temp/temp" * bstr * "_" * tstr * "_" * fstr * "_" * astr * ".png"; 230 | bbox_inches="tight", 231 | dpi=dpi, 232 | ) 233 | end 234 | close() 235 | end 236 | end 237 | end 238 | 239 | ###combine pngs to gif 240 | run(`convert -delay $(Int(round(res/10))) "$(fname)_temp/temp*.png" $fname.gif`) 241 | #return 242 | return run(`sh -c "rm -r $(fname)_temp"`) 243 | end 244 | 245 | 246 | 247 | -------------------------------------------------------------------------------- /computational_model/src/priors.jl: -------------------------------------------------------------------------------- 1 | 2 | 3 | function U_prior(state, Naction) 4 | #uniform prior 5 | Zygote.ignore() do 6 | batch = size(state, 2) 7 | return ones(Float32, Naction, batch) / Naction 8 | end 9 | end 10 | 11 | function prior_loss(agent_output, state, active, mod) 12 | #return -KL[q || p] 13 | #return -KL[q || p] 14 | act = Float32.(Flux.unsqueeze(active, 1)) 15 | Naction = length(mod.policy[1].bias)-1 16 | logp = log.(U_prior(state, Naction)) #KL regularization with uniform prior 17 | logπ = agent_output[1:Naction, :] 18 | if mod.model_properties.no_planning 19 | logπ = logπ[1:Naction-1, :] .- Flux.logsumexp(logπ[1:Naction-1, :]; dims=1) 20 | logp = logp[1:Naction-1, :] .- Flux.logsumexp(logp[1:Naction-1, :]; dims=1) 21 | end 22 | logp = logp .* act 23 | logπ = logπ .* act 24 | lprior = sum(exp.(logπ) .* (logp - logπ )) #-KL 25 | return lprior 26 | end 27 | 
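prior_loss above returns the negative KL divergence -KL[π || p] between the agent's policy π and the uniform prior p, masked by the active episodes and restricted to the movement actions when rollouts are disabled (no_planning). Below is a minimal numeric sketch of the quantity being computed for a single time step; the toy logits are illustrative and not taken from the repository.

using Flux: logsumexp                          # used as Flux.logsumexp throughout src/

logits = Float32[2.0, 0.5, 0.0, -1.0, -3.0]    # toy policy logits for Naction = 5
logπ = logits .- logsumexp(logits)             # log softmax of the policy
logp = fill(log(1.0f0 / 5), 5)                 # log of the uniform prior (U_prior)
neg_kl = sum(exp.(logπ) .* (logp .- logπ))     # -KL[π || uniform]; zero only for a uniform policy

In the full training loss this term acts as a regulariser that keeps the policy from drifting too far from the uniform prior (see the "KL regularization with uniform prior" comment above).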
-------------------------------------------------------------------------------- /computational_model/src/train.jl: -------------------------------------------------------------------------------- 1 | using Flux, Zygote 2 | 3 | function gmap(f, prms, gss::Zygote.ADictOrGrads...) 4 | gsout = Zygote.Grads(IdDict{Any,Any}(), prms) 5 | return gmap!(f, gsout, gss...) 6 | end 7 | 8 | function gmap!(f, gsout::Zygote.Grads, gss::Zygote.ADictOrGrads...) 9 | for (ip, p) in enumerate(gsout.params) 10 | gsout[p] = f((_getformap(gs, gs.params[ip]) for gs in gss)...) 11 | end 12 | return gsout 13 | end 14 | function _getformap(gs, p) 15 | g = gs[p] 16 | return isnothing(g) ? fill!(similar(p), 0) : g 17 | end -------------------------------------------------------------------------------- /computational_model/src/walls.jl: -------------------------------------------------------------------------------- 1 | ### functions that are shared across environments ### 2 | 3 | using Flux, Statistics, Random, Distributions, StatsFuns, Zygote, PyPlot 4 | 5 | struct WallState 6 | reward_location::Array{Float32} 7 | wall_loc::Array{Int32} 8 | time::Array{Float32} 9 | end 10 | 11 | function WallState(; reward_location, wall_loc, time = zeros(1)) 12 | return WallState(reward_location, wall_loc, time) 13 | end 14 | 15 | function state_ind_from_state(Larena, state) 16 | #input is 2 x batch 17 | #output is (batch, ) 18 | return Larena * (state[1, :] .- 1) + state[2, :] 19 | end 20 | 21 | function onehot_from_loc(Larena, loc) 22 | #input: (batch, ) 23 | #output: Nstates x batch 24 | Nstates = Larena^2 25 | batch = length(loc) 26 | shot = zeros(Nstates, batch) 27 | for b in 1:batch 28 | shot[loc[b], b] = 1 29 | end 30 | return shot # don't take gradients of this 31 | end 32 | Zygote.@nograd(onehot_from_loc) 33 | 34 | function onehot_from_state(Larena, state) 35 | #input: 2 x batch 36 | #output: Nstates x batch 37 | state_ind = state_ind_from_state(Larena, state) # (batch,) 38 | return onehot_from_loc(Larena, state_ind) # don't take gradients of this 39 | end 40 | Zygote.@nograd(onehot_from_state) 41 | 42 | function state_from_loc(Larena, loc) 43 | #input: 1 x batch 44 | #output: 2 x batch 45 | return [(loc .- 1) .÷ Larena .+ 1; (loc .- 1) .% Larena .+ 1] 46 | end 47 | 48 | function state_from_onehot(Larena, shot) 49 | #inpute: Nstates x batch 50 | #output: 2 x batch 51 | loc = [sortperm(-shot[:, b])[1] for b in 1:size(shot, 2)] 52 | loc = reduce(hcat, loc) 53 | return state_from_loc(Larena, loc) 54 | end 55 | 56 | function get_wall_input(state, wall_loc) 57 | #state is 2xB 58 | #wall_loc is Nstates x 4 x B (4 is right/left/up/down) 59 | input = [wall_loc[:, 1, :]; wall_loc[:, 3, :]] #all horizontal and all vertical walls 60 | return input # don't take gradients of this 61 | end 62 | 63 | function gen_input( 64 | world_state, ahot, rew, ed, model_properties 65 | ) 66 | batch = size(rew, 2) 67 | newstate = world_state.agent_state 68 | wall_loc = world_state.environment_state.wall_loc 69 | Naction = ed.Naction 70 | Nstates = ed.Nstates 71 | shot = onehot_from_state(ed.Larena, newstate) #one-hot encoding (Nstates x batch) 72 | wall_input = get_wall_input(newstate, wall_loc) #get input about walls 73 | Nwall_in = size(wall_input, 1) 74 | Nin = model_properties.Nin 75 | plan_input = world_state.planning_state.plan_input 76 | Nplan_in = size(plan_input, 1) 77 | 78 | ### speed this up ### 79 | x = zeros(Nin, batch) 80 | x[1:Naction, :] = ahot 81 | x[Naction + 1, :] = rew[:] 82 | x[Naction + 2, :] = world_state.environment_state.time 
/ 50f0 #smaller time input in [0,1] 83 | x[(Naction + 3):(Naction + 2 + Nstates), :] = shot 84 | x[(Naction + 2 + Nstates + 1):(Naction + 2 + Nstates + Nwall_in), :] = wall_input 85 | 86 | if length(plan_input) > 0 #set planning input 87 | x[(Naction + 2 + Nstates + Nwall_in + 1):(Naction + 2 + Nstates + Nwall_in + Nplan_in), :] = world_state.planning_state.plan_input 88 | end 89 | 90 | return Float32.(x) 91 | end 92 | 93 | 94 | function get_rew_locs(reward_location) 95 | return [argmax(reward_location[:, i]) for i in 1:size(reward_location, 2)] 96 | end 97 | Zygote.@nograd get_rew_locs #don't take gradients of this 98 | -------------------------------------------------------------------------------- /computational_model/src/walls_baselines.jl: -------------------------------------------------------------------------------- 1 | function random_policy(x, md, ed; stay=true) 2 | #if stay is false, only uniform over actual actions 3 | batch = size(x, 2) 4 | ys = Float32.(zeros(md.Nout, batch)) 5 | if stay 6 | ys[1:(ed.Naction), :] .= log(1 / ed.Naction) 7 | else 8 | ys[1:(ed.Naction - 1), :] .= log(1 / (ed.Naction - 1)) 9 | ys[ed.Naction, :] .= -Inf 10 | end 11 | return ys 12 | end 13 | 14 | function dist_to_rew(ps, wall_loc, Larena) 15 | #compute geodesic distance to reward from each state (i.e. taking walls into account) 16 | #ps is Nstates x 1 17 | #wall_loc is 16x4x1 18 | Nstates = Larena^2 19 | deltas = [[1; 0], [-1; 0], [0; 1], [0; -1]] #transitions for each action 20 | rew_loc = state_from_onehot(Larena, ps) #2x1 21 | dists = zeros(Larena, Larena) .+ NaN #distances to goal 22 | dists[rew_loc[1], rew_loc[2]] = 0 #reward has zero distance 23 | live_states = Bool.(zeros(Nstates)) 24 | live_states[state_ind_from_state(Larena, rew_loc)[1]] = true #start from rew loc and work backwards 25 | for step in 1:(Nstates - 1) #steps from reward 26 | for state_ind in findall(live_states) #all states I was at in (step-1) steps 27 | state = state_from_loc(Larena, state_ind) 28 | for a in 1:4 #for each action 29 | if ~Bool(wall_loc[state_ind, a, 1]) #if I do not hit a wall 30 | newstate = state .+ deltas[a] #where do I end up in 'step' steps 31 | newstate = Int.((newstate .+ Larena .- 1) .% Larena .+ 1) #1:L (2xbatch) 32 | if isnan(dists[newstate[1], newstate[2]]) #if I haven't gotten here in fewer steps 33 | dists[newstate[1], newstate[2]] = step #got here in step steps 34 | new_ind = state_ind_from_state(Larena, newstate)[1] 35 | live_states[new_ind] = true #need to search from here for >step steps 36 | end 37 | end 38 | end 39 | live_states[state_ind] = false #done searching for this state 40 | end 41 | end 42 | return dists #return geodesics 43 | end 44 | 45 | function optimal_policy(state, wall_loc, dists, ed) 46 | #return uniform log policy over actions that minimize the path length to goal 47 | #state is 2x1 48 | #wall_loc is Nstates x 4 x 1 49 | #dists is Larena x Larena of geodesic distance (from dist_to_rew()) 50 | Naction, Larena = ed.Naction, ed.Larena 51 | deltas = [[1; 0], [-1; 0], [0; 1], [0; -1]] #transitions for each action 52 | nn_dists = zeros(4) .+ Inf #distance to reward for each action 53 | state_ind = state_ind_from_state(Larena, state)[1] #where am I 54 | for a in 1:4 #for each action 55 | if ~Bool(wall_loc[state_ind, a, 1]) #if I do not hit a wall 56 | newstate = state .+ deltas[a] #where am I now 57 | newstate = Int.((newstate .+ Larena .- 1) .% Larena .+ 1) #1:L (2xbatch) 58 | nn_dists[a] = dists[newstate[1], newstate[2]] #how far is this from reward 59 | end 60 | end 61 | as = 
findall(nn_dists .== minimum(nn_dists)) #all optimal actions 62 | πt = zeros(Naction) 63 | πt[as] .= 1 / length(as) #uniform policy 64 | return πt #optimal policy 65 | end -------------------------------------------------------------------------------- /computational_model/src/walls_build.jl: -------------------------------------------------------------------------------- 1 | 2 | #In this script, we instantiate the RL environment which include initialize() and step() functions. 3 | 4 | using Flux, Statistics, Random, Distributions, StatsFuns, Zygote, PyPlot, Logging 5 | 6 | """function that computes things like the input and output dimensionality of the network""" 7 | function useful_dimensions(Larena, planner) 8 | Nstates = Larena^2 #number of states in arena 9 | Nstate_rep = 2 #dimensionality of the state representation (e.g. '2' for x,y-coordinates) 10 | Naction = 5 #number of actions available 11 | Nout = Naction + 1 + Nstates #actions and value function and prediction of state 12 | Nout += 1 # needed for backward compatibility (this lives between state and reward predictions) 13 | Nwall_in = 2 * Nstates #provide full info 14 | Nin = Naction + 1 + 1 + Nstates + Nwall_in #5 actions, 1 rew, 1 time, L^2 states, some walls 15 | 16 | Nin += planner.Nplan_in #additional inputs from planning 17 | Nout += planner.Nplan_out #additional outputs for planning 18 | 19 | return Nstates, Nstate_rep, Naction, Nout, Nin 20 | end 21 | 22 | """function that objects the position of the agent given an action (assuming no walls)""" 23 | function update_agent_state(agent_state, amove, Larena) 24 | new_agent_state = 25 | agent_state + [amove[1:1, :] - amove[2:2, :]; amove[3:3, :] - amove[4:4, :]] #2xbatch 26 | new_agent_state = Int32.((new_agent_state .+ Larena .- 1) .% Larena .+ 1) #1:L (2xbatch) 27 | return new_agent_state 28 | end 29 | 30 | """ 31 | act_and_receive_reward(action, world_state, planning, env_dimensions, agent_output, model, hidden_state, model_properties) 32 | output: reward, new_world_state, ground_truth_predictions, one-hot-action, agent_at_reward? 33 | This function implements the 'environment' of the RL algorithm. 34 | It takes as input the output of the agent and the state of the world and returns the new state of the world. 35 | Note that 'planning' in our formulation takes place in the environment, so the output includes the result of planning. 36 | """ 37 | function act_and_receive_reward( 38 | a, world_state, planner, environment_dimensions, agent_output, model, h_rnn, mp 39 | ) 40 | agent_state = world_state.agent_state 41 | environment_state = world_state.environment_state 42 | reward_location = environment_state.reward_location 43 | wall_loc = environment_state.wall_loc 44 | Naction = environment_dimensions.Naction 45 | Larena = environment_dimensions.Larena 46 | 47 | agent_state_ind = state_ind_from_state(Larena, agent_state) #extract index 48 | batch = size(a, 2) #batch size 49 | Nstates = Larena^2 50 | 51 | ahot = zeros(Naction, batch) #attempted action 52 | amove = zeros(Naction, batch) #actual movement 53 | rew = zeros(Float32, 1, batch) #reward collected 54 | 55 | #construct array of attempted and actual movements 56 | for b in 1:batch 57 | abatch = a[1, b] # action 58 | ahot[abatch, b] = 1f0 #attempted action 59 | if (abatch < 4.5) && Bool(wall_loc[agent_state_ind[b], abatch, b]) 60 | rew[1, b] -= 0f0 #penalty for hitting wall? 
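            # subtracting 0f0 means there is no explicit penalty for hitting a wall:
            # the attempted action is still recorded in ahot, but amove stays zero,
            # so the agent remains in place and the time step is simply spent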
61 | else 62 | amove[abatch, b] = 1 #only move if we don't hit a wall 63 | end 64 | end 65 | 66 | new_agent_state = update_agent_state(agent_state, amove, Larena) #(x,y) coordinates 67 | shot = onehot_from_state(Larena, new_agent_state) #one-hot encoding (Nstates x batch) 68 | s_index = reduce(vcat, [sortperm(-shot[:, b])[1] for b in 1:batch]) #corresponding index 69 | r_index = get_rew_locs(reward_location) #index of reward location 70 | predictions = (Int32.(s_index), Int32.(r_index)) #things to be predicted by the agent 71 | 72 | found_rew = Bool.(reward_location[Bool.(shot)]) #moved to the reward 73 | s_old_hot = onehot_from_state(Larena, agent_state) #one-hot encoding of previous agent_state 74 | at_rew = Bool.(reward_location[Bool.(s_old_hot)]) #at reward before action 75 | 76 | moved = sum(amove[1:4, :]; dims=1)[:] #did I perform a movement? (size batch) 77 | rew[1, found_rew .& (moved .> 0.5)] .= 1 #get reward if agent moved to reward location 78 | 79 | ### teleport the agents that found the reward on the previous iteration ### 80 | for b in 1:batch 81 | if at_rew[b] #at reward 82 | tele_reward_location = ones(Nstates) / (Nstates - 1) #where can I teleport to (not rew location) 83 | tele_reward_location[Bool.(reward_location[:, b])] .= 0 84 | new_state = rand(Categorical(tele_reward_location), 1, 1) #sample new state uniformly at random 85 | new_agent_state[:, b] = state_from_loc(Larena, new_state) #convert to (x,y) coordinates 86 | shot[:, b] .= 0f0; shot[new_state[1], b] = 1f0 #update onehot location 87 | end 88 | end 89 | 90 | #run planning algorithm 91 | planning_state, plan_inds = planner.planning_algorithm(world_state, 92 | ahot, 93 | environment_dimensions, 94 | agent_output, 95 | at_rew, 96 | planner, 97 | model, 98 | h_rnn, 99 | mp) 100 | 101 | planned = Bool.(zeros(batch)); planned[plan_inds] .= true #which agents within the batch engaged in planning 102 | 103 | #update the time elapsed for each episode 104 | new_time = copy(environment_state.time) 105 | new_time[ .~ planned ] .+= 1f0 #increment time for acting 106 | if planner.constant_rollout_time 107 | new_time[planned] .+= planner.planning_time #increment time for planning 108 | else 109 | plan_states = planning_state.plan_cache 110 | plan_lengths = sum(plan_states[:, planned] .> 0.5, dims = 1)[:] # number of planning steps for each batch 111 | new_time[planned] += plan_lengths*planner.planning_time/5 112 | if rand() < 1e-5 println("variable planning time! ", plan_lengths*planner.planning_time/5) end 113 | end 114 | 115 | rew[1, planned] .+= planner.planning_cost #cost of planning (in units of rewards; default 0) 116 | 117 | #update the state of the world 118 | new_world_state = WorldState(; 119 | agent_state=new_agent_state, 120 | environment_state=WallState(; wall_loc, reward_location, time = new_time), 121 | planning_state=planning_state 122 | ) 123 | 124 | return Float32.(rew), new_world_state, predictions, ahot, at_rew 125 | end 126 | 127 | """ 128 | build_environment(arena_size, N_hidden, max_time, planning_depth, greedy_actions, no_planning) 129 | This function constructs an environment object which includes 'initialize' and 'step' methods for the agent to interact with. 
130 | """ 131 | function build_environment( 132 | Larena::Int, 133 | Nhidden::Int, 134 | T::Int; 135 | Lplan::Int, 136 | greedy_actions=false, 137 | no_planning = false, 138 | constant_rollout_time = true, 139 | ) 140 | 141 | # create planner object 142 | # note that planner includes a 'plan_state' which can carry over in more general planning algorithms 143 | planner, initial_plan_state = build_planner(Lplan, Larena; constant_rollout_time) 144 | Nstates, Nstate_rep, Naction, Nout, Nin = useful_dimensions(Larena, planner) #compute some useful quantities 145 | model_properties = ModelProperties(Nout, Nhidden, Nin, Lplan, greedy_actions, no_planning) #initialize a model property object 146 | environment_dimensions = EnvironmentDimensions(Nstates, Nstate_rep, Naction, T, Larena) #initialize an environment dimension object 147 | 148 | ### define a 'step' function that updates the environment ### 149 | function step(agent_output, a, world_state, environment_dimensions, model_properties, model, h_rnn) 150 | 151 | Zygote.ignore() do #no differentiation through the environment 152 | rew, new_world_state, predictions, ahot, at_rew = act_and_receive_reward( 153 | a, world_state, planner, environment_dimensions, agent_output, model, h_rnn, model_properties 154 | ) #take a step through the environment 155 | #generate agent input 156 | agent_input = gen_input(new_world_state, ahot, rew, environment_dimensions, model_properties) 157 | #return reward, input, world state and ground truths for predictions 158 | return rew, Float32.(agent_input), new_world_state, predictions 159 | end 160 | end 161 | 162 | #create initialization function 163 | function initialize(reward_location, agent_state, batch, mp; initial_params = []) 164 | return initialize_arena(reward_location, agent_state, batch, mp, environment_dimensions, initial_plan_state, initial_params=initial_params) 165 | end 166 | 167 | #construct environment with initialize() and step() functions and a list of dimensions 168 | environment = Environment(initialize, step, environment_dimensions) 169 | 170 | ### task specific evaluation/progress function #### 171 | function model_eval(m, batch::Int64, loss_hp::LossHyperparameters) 172 | Nrep = 5 173 | means = zeros(Nrep, batch) 174 | all_actions = zeros(Nrep, batch) 175 | firstrews = zeros(Nrep, batch) 176 | preds = zeros(T - 1, Nrep, batch) 177 | Naction = environment_dimensions.Naction 178 | Nstates = environment_dimensions.Nstates 179 | 180 | for i in 1:Nrep 181 | _, agent_outputs, rews, actions, world_states, _ = run_episode( 182 | m, environment, loss_hp; hidden=true, batch=batch 183 | ) 184 | agent_states = mapreduce( 185 | (x) -> x.agent_state, (x, y) -> cat(x, y; dims=3), world_states 186 | ) 187 | means[i, :] = sum(rews .>= 0.5; dims=2) #compute total reward for each batch 188 | all_actions[i, :] = (mean(actions .== 5; dims=2) ./ mean(actions .> 0.5, dims = 2)) #fraction of standing still for each batch 189 | for b in 1:batch 190 | firstrews[i, b] = sortperm(-(rews[b, :] .> 0.5))[1] #time to first reward 191 | end 192 | for t in 1:(T - 1) 193 | for b in 1:batch 194 | pred = sortperm( 195 | -agent_outputs[(Naction + 1 + 1):(Naction + 1 + Nstates), b, t] 196 | )[1] #predicted agent_state at time t (i.e. prediction of agent_state(t+1)) 197 | agent_state = Int32.(agent_states[:, b, t + 1]) #true agent_state(t+1) 198 | preds[t, i, b] = onehot_from_state(Larena, agent_state)[Int32(pred)] #did we get it right? 
199 | end 200 | end 201 | 202 | end 203 | return mean(means), 204 | mean(preds), mean(all_actions), mean(firstrews) 205 | end 206 | 207 | return model_properties, environment, model_eval 208 | end 209 | -------------------------------------------------------------------------------- /computational_model/walls_train.jl: -------------------------------------------------------------------------------- 1 | using Pkg; Pkg.activate(".") 2 | using Revise 3 | using ToPlanOrNotToPlan 4 | using Flux, Statistics, Random, Distributions 5 | using StatsFuns, Zygote, ArgParse, NaNStatistics 6 | using Logging 7 | 8 | function parse_commandline() 9 | s = ArgParseSettings() 10 | ### set default values of command line options ### 11 | @add_arg_table s begin 12 | "--Nhidden" 13 | help = "Number of hidden units" 14 | arg_type = Int 15 | default = 100 16 | "--Larena" 17 | help = "Arena size (per side)" 18 | arg_type = Int 19 | default = 4 20 | "--T" 21 | help = "Number of timesteps per episode (in units of physical actions)" 22 | arg_type = Int 23 | default = 50 24 | "--Lplan" 25 | help = "Maximum planning horizon" 26 | arg_type = Int 27 | default = 8 28 | "--load" 29 | help = "Load previous model instead of initializing new model" 30 | arg_type = Bool 31 | default = false 32 | "--load_epoch" 33 | help = "which epoch to load" 34 | arg_type = Int 35 | default = 0 36 | "--seed" 37 | help = "Which random seed to use" 38 | arg_type = Int 39 | default = 1 40 | "--save_dir" 41 | help = "Save directory" 42 | arg_type = String 43 | default = "./" 44 | "--beta_p" 45 | help = "Relative importance of predictive loss" 46 | arg_type = Float32 47 | default = 0.5f0 48 | "--prefix" 49 | help = "Add prefix to model name" 50 | arg_type = String 51 | default = "" 52 | "--load_fname" 53 | help = "Model to load (default to default model name)" 54 | arg_type = String 55 | default = "" 56 | "--n_epochs" 57 | help = "Total number of training epochs" 58 | arg_type = Int 59 | default = 1001 60 | "--batch_size" 61 | help = "Batch size for each gradient step" 62 | arg_type = Int 63 | default = 40 64 | "--lrate" 65 | help = "Learning rate" 66 | arg_type = Float64 67 | default = 1e-3 68 | "--constant_rollout_time" 69 | help = "Do rollouts take a fixed amount of time irrespective of length" 70 | arg_type = Bool 71 | default = true 72 | end 73 | return parse_args(s) 74 | end 75 | 76 | function main() 77 | 78 | ##### global parameters ##### 79 | args = parse_commandline() 80 | println(args) 81 | 82 | # extract command line arguments 83 | Larena = args["Larena"] 84 | Lplan = args["Lplan"] 85 | Nhidden = args["Nhidden"] 86 | T = args["T"] 87 | load = args["load"] 88 | seed = args["seed"] 89 | save_dir = args["save_dir"] 90 | prefix = args["prefix"] 91 | load_fname = args["load_fname"] 92 | n_epochs = Int(args["n_epochs"]) 93 | βp = Float32(args["beta_p"]) 94 | batch_size = Int(args["batch_size"]) 95 | lrate = Float64(args["lrate"]) 96 | constant_rollout_time = Bool(args["constant_rollout_time"]) 97 | 98 | Base.Filesystem.mkpath(save_dir) 99 | Random.seed!(seed) #set random seed 100 | 101 | loss_hp = LossHyperparameters(; 102 | # predictive loss weight 103 | βp=βp, 104 | # value function loss weight 105 | βv=0.05f0, 106 | # entropy loss cost 107 | βe=0.05f0, 108 | # reward loss cost 109 | βr=1.0f0, 110 | ) 111 | 112 | # build RL environment 113 | model_properties, wall_environment, model_eval = build_environment( 114 | Larena, Nhidden, T; Lplan, constant_rollout_time 115 | ) 116 | # build RL agent 117 | m = build_model(model_properties, 5) 118 | 
# construct summary string 119 | mod_name = create_model_name( 120 | Nhidden, T, seed, Lplan, prefix = prefix 121 | ) 122 | 123 | #training parameters 124 | n_batches, save_every = 200, 50 125 | opt = ADAM(lrate) #initialize optimiser 126 | 127 | #used to keep track of progress 128 | rews, preds = [], [] 129 | epoch = 0 #start at epoch 0 130 | @info "model name" mod_name 131 | @info "training info" n_epochs n_batches batch_size 132 | 133 | if load #if we load a previous model 134 | if load_fname == "" #filename not specified; fall back to default 135 | fname = "$save_dir/models/" * mod_name * "_" * string(args["load_epoch"]) 136 | else #load specific model 137 | fname = "$save_dir/models/" * load_fname 138 | end 139 | #load the parameters and initialize model 140 | network, opt, store, hps, policy, prediction = recover_model(fname) 141 | m = ModularModel(model_properties, network, policy, prediction, forward_modular) 142 | 143 | #load the learning curve from the previous model 144 | rews, preds = store[1], store[2] 145 | epoch = length(rews) #start where we were 146 | if load_fname != "" #loaded pretrained model; reset optimiser 147 | opt = ADAM(lrate) 148 | end 149 | end 150 | 151 | prms = Params(Flux.params(m.network, m.policy, m.prediction)) #model parameters 152 | println("parameter length: ", length(prms)) 153 | for p = prms println(size(p)) end 154 | 155 | Nthread = Threads.nthreads() #number of threads available 156 | multithread = Nthread > 1 #multithread if we can 157 | @info "multithreading" Nthread 158 | thread_batch_size = Int(ceil(batch_size / Nthread)) #distribute batch evenly across threads 159 | #construct function without arguments for Flux 160 | closure = () -> model_loss(m, wall_environment, loss_hp, thread_batch_size) / Nthread 161 | 162 | gmap_grads(g1, g2) = gmap(+, prms, g1, g2) #define map function for reducing gradients 163 | 164 | #function for training on a single match 165 | function loop!(batch, closure) 166 | if multithread #distribute across threads? 
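            # each thread computes a gradient on its own thread_batch_size episodes;
            # summing the per-thread gradients with gmap(+) below recovers an update
            # over the full batch (closure already divides the loss by Nthread)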
167 | all_gs = Vector{Zygote.Grads}(undef, Nthread) #vector of gradients for each thread 168 | Threads.@threads for i in 1:Nthread #on each thread 169 | rand_roll = rand(100*batch*i) #run through some random numbers 170 | gs = gradient(closure, prms) #compute gradient 171 | all_gs[i] = gs #save gradient 172 | end 173 | gs = reduce(gmap_grads, all_gs) #sum across our gradients 174 | else 175 | gs = gradient(closure, prms) #if we're not multithreading, just compute a simple gradient 176 | end 177 | return Flux.Optimise.update!(opt, prms, gs) #update model parameters 178 | end 179 | 180 | t0 = time() #wallclock time 181 | while epoch < n_epochs 182 | epoch += 1 #count epochs 183 | flush(stdout) #flush output 184 | 185 | Rmean, pred, mean_a, first_rew = model_eval( 186 | m, batch_size, loss_hp 187 | ) #evaluate performance 188 | Flux.reset!(m) #reset model 189 | 190 | if (epoch - 1) % save_every == 0 #occasionally save our model 191 | Base.Filesystem.mkpath("$save_dir/models") 192 | filename = "$save_dir/models/" * mod_name * "_" * string(epoch - 1) 193 | store = [rews, preds] 194 | save_model(m, store, opt, filename, wall_environment, loss_hp; Lplan) 195 | end 196 | 197 | #print progress 198 | elapsed_time = round(time() - t0; digits=1) 199 | println("progress: epoch=$epoch t=$elapsed_time R=$Rmean pred=$pred plan=$mean_a first=$first_rew") 200 | push!(rews, Rmean) 201 | push!(preds, pred) 202 | plot_progress(rews, preds) #plot prorgess 203 | 204 | for batch in 1:n_batches #for each batch 205 | loop!(batch, closure) #perform an update step 206 | end 207 | end 208 | 209 | Flux.reset!(m) #reset model state 210 | # save model 211 | filename = "$save_dir/results/" * mod_name 212 | store = [rews, preds] 213 | return save_model(m, store, opt, filename, wall_environment, loss_hp; Lplan) 214 | end 215 | main() 216 | -------------------------------------------------------------------------------- /human_data.txt: -------------------------------------------------------------------------------- 1 | Our human behavioral data is stored as an .sqlite file. 2 | 3 | For all analyses, the first 8 guided trials and first 2 non-guided trials were discarded to allow subjects to get used to the task. 4 | 5 | -------------------------------------------------------------------------------- /human_data/Euclidean_prolific_data.sqlite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/human_data/Euclidean_prolific_data.sqlite -------------------------------------------------------------------------------- /human_data/prolific_data.sqlite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KrisJensen/planning_code/a9c03b4dd4c1df747dd8e9608621652b1f935cad/human_data/prolific_data.sqlite --------------------------------------------------------------------------------
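As a quick-start for the behavioural data described in human_data.txt, the sketch below shows one way to open one of the .sqlite files and extract a single participant's behaviour with extract_maze_data from computational_model/src/human_utils_maze.jl. This is not one of the repository's analysis scripts (those live in computational_model/analysis_scripts/); it assumes the computational_model project environment is activated as described in the README, that paths are given relative to the repository root, and that a real participant id from the 'episodes' table is substituted for the placeholder.

using SQLite
using ToPlanOrNotToPlan

db = SQLite.DB("human_data/prolific_data.sqlite")  # see the README for the Euclidean variant
user_id = 1  # placeholder: substitute an id that exists in the 'episodes' table

# 4x4 arena as used in the paper; rews and as are (episodes x timesteps),
# see human_utils_maze.jl for the shapes of the other outputs
rews, as, states, wall_loc, ps, times, trial_nums, trial_time, RTs, shot =
    extract_maze_data(db, user_id, 4)

println("episodes kept: ", size(rews, 1), "   rewards collected: ", sum(rews .> 0.5))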