├── AnalysisFunctions └── analysis_WSLS_v1.m ├── Figure2_simulations.m ├── Figure3B_localminima.m ├── Figure3_localminima.m ├── Figure4_ParameterRecovery.m ├── Figure5_both.m ├── Figure5_confusionMatrix.m ├── Figure6_fitUnimportantParams.m ├── Figure7_ValidateFit.m ├── Figures └── Figure2.pdf ├── FittingFunctions ├── fit_M1random_v1.m ├── fit_M2WSLS_v1.m ├── fit_M3RescorlaWagner_v1.m ├── fit_M4CK_v1.m ├── fit_M5RWCK_v1.m ├── fit_M6RescorlaWagnerBias_v1.m └── fit_all_v1.m ├── HelperFunctions ├── addABCs.m ├── addFacetLines.m ├── easy_gridOfEqualFigures.m ├── imageTextMatrix.m └── saveFigurePdf.m ├── LICENSE ├── LikelihoodFunctions ├── lik_M1random_v1.m ├── lik_M2WSLS_v1.m ├── lik_M3RescorlaWagner_v1.m ├── lik_M4CK_v1.m ├── lik_M5RWCK_v1.m ├── lik_M6RescorlaWagnerBias_v1.m └── lik_fullRL_v1.m ├── README.md ├── SimulationFunctions ├── choose.m ├── simulate_M1random_v1.m ├── simulate_M2WSLS_v1.m ├── simulate_M3RescorlaWagner_v1.m ├── simulate_M4ChoiceKernel_v1.m ├── simulate_M5RWCK_v1.m ├── simulate_M6RescorlaWagnerBias_v1.m ├── simulate_blind_v1.m ├── simulate_fullRL_v1.m └── simulate_validationModel_v1.m └── WilsonCollins_TenRulesForModelFitting.pdf /AnalysisFunctions/analysis_WSLS_v1.m: -------------------------------------------------------------------------------- 1 | function out = analysis_WSLS_v1(a, r) 2 | 3 | aLast = [nan a(1:end-1)]; 4 | stay = aLast == a; 5 | rLast = [nan r(1:end-1)]; 6 | 7 | winStay = nanmean(stay(rLast == 1)); 8 | loseStay = nanmean(stay(rLast == 0)); 9 | out = [loseStay winStay]; -------------------------------------------------------------------------------- /Figure2_simulations.m: -------------------------------------------------------------------------------- 1 | %%%%% Bob Wilson & Anne Collins 2 | %%%%% 2018 3 | %%%%% Code to produce figure 2 in submitted paper "Ten simple rules for the 4 | %%%%% computational modeling of behavioral data" 5 | 6 | 7 | clear 8 | 9 | addpath('./SimulationFunctions') 10 | addpath('./AnalysisFunctions') 11 
| addpath('./HelperFunctions') 12 | %% 13 | % set up colors 14 | global AZred AZblue AZcactus AZsky AZriver AZsand AZmesa AZbrick 15 | 16 | AZred = [171,5,32]/256; 17 | AZblue = [12,35,75]/256; 18 | AZcactus = [92, 135, 39]/256; 19 | AZsky = [132, 210, 226]/256; 20 | AZriver = [7, 104, 115]/256; 21 | AZsand = [241, 158, 31]/256; 22 | AZmesa = [183, 85, 39]/256; 23 | AZbrick = [74, 48, 39]/256; 24 | 25 | 26 | 27 | % experiment parameters 28 | T = 100; % number of trials 29 | mu = [0.2 0.8]; % mean reward of bandits 30 | 31 | % number of repetitions for simulations 32 | Nrep = 110; 33 | 34 | % Model 1: Random responding 35 | for n = 1:Nrep 36 | b = 0.5; 37 | [a, r] = simulate_M1random_v1(T, mu, b); 38 | sim(1).a(:,n) = a; 39 | sim(1).r(:,n) = r; 40 | end 41 | 42 | % Model 2: Win-stay-lose-shift 43 | for n = 1:Nrep 44 | epsilon = 0.1; 45 | [a, r] = simulate_M2WSLS_v1(T, mu, epsilon); 46 | sim(2).a(:,n) = a; 47 | sim(2).r(:,n) = r; 48 | end 49 | % Model 3: Rescorla Wagner 50 | for n = 1:Nrep 51 | alpha = 0.1; 52 | beta = 5; 53 | [a, r] = simulate_M3RescorlaWagner_v1(T, mu, alpha, beta); 54 | sim(3).a(:,n) = a; 55 | sim(3).r(:,n) = r; 56 | end 57 | 58 | % Model 4: Choice kernel 59 | for n = 1:Nrep 60 | alpha_c = 0.1; 61 | beta_c = 3; 62 | [a, r] = simulate_M4ChoiceKernel_v1(T, mu, alpha_c, beta_c); 63 | sim(4).a(:,n) = a; 64 | sim(4).r(:,n) = r; 65 | end 66 | 67 | % Model 5: Rescorla-Wagner + choice kernel 68 | for n = 1:Nrep 69 | alpha = 0.1; 70 | beta = 5; 71 | alpha_c = 0.1; 72 | beta_c = 1; 73 | [a, r] = simulate_M5RWCK_v1(T, mu, alpha, beta, alpha_c, beta_c); 74 | sim(5).a(:,n) = a; 75 | sim(5).r(:,n) = r; 76 | end 77 | 78 | 79 | %% win-stay-lose-shift analysis 80 | for i = 1:length(sim) 81 | for n = 1:Nrep 82 | sim(i).wsls(:,n) = analysis_WSLS_v1(sim(i).a(:,n)', sim(i).r(:,n)'); 83 | end 84 | wsls(:,i) = nanmean(sim(i).wsls,2); 85 | end 86 | 87 | %% Plot WSLS behavior for all models 88 | figure(1); clf; hold on; 89 | l = plot([0 1], wsls); 90 | ylim([0 1]) 91 | 
set(l, 'marker', '.', 'markersize', 50, 'linewidth', 3) 92 | 93 | 94 | legend({'M1: random' 'M2: WSLS' 'M3: RW' 'M4: CK' 'M5: RW+CK'}, ... 95 | 'location', 'southeast') 96 | xlabel('previous reward') 97 | ylabel('probability of staying') 98 | 99 | set(gca, 'xtick', [0 1], 'tickdir', 'out', 'fontsize', 18, 'xlim', [-0.1 1.1]) 100 | 101 | 102 | %% p(correct) analysis 103 | alphas = [0.02:0.02:1]; 104 | betas = [1 2 5 10 20]; 105 | 106 | for n = 1:1000 107 | n 108 | for i = 1:length(alphas) 109 | for j = 1:length(betas) 110 | [a, r] = simulate_M3RescorlaWagner_v1(T, mu, alphas(i), betas(j)); 111 | [~,imax] = max(mu); 112 | correct(i,j,n) = nanmean(a == imax); 113 | correctEarly(i,j,n) = nanmean(a(1:10) == imax); 114 | correctLate(i,j,n) = nanmean(a(end-9:end) == imax); 115 | end 116 | end 117 | end 118 | 119 | %% plot p(correct) behavior 120 | figure(1); 121 | E = nanmean(correctEarly,3); 122 | L = nanmean(correctLate,3); 123 | 124 | figure(1); clf; 125 | set(gcf, 'Position', [284 498 750 300]) 126 | ax = easy_gridOfEqualFigures([0.2 0.1], [0.08 0.14 0.05 0.03]); 127 | 128 | axes(ax(1)); hold on; 129 | l = plot([0 1], wsls); 130 | ylim([0 1]) 131 | set(l, 'marker', '.', 'markersize', 50, 'linewidth', 3) 132 | leg1 = legend({'M1: random' 'M2: WSLS' 'M3: RW' 'M4: CK' 'M5: RW+CK'}, ... 
133 | 'location', 'southeast'); 134 | xlabel('previous reward') 135 | % ylabel('probability of staying') 136 | ylabel('p(stay)') 137 | title('stay behavior', 'fontweight', 'normal') 138 | xlim([-0.1 1.1]); 139 | ylim([0 1.04]) 140 | set(ax(1), 'xtick', [0 1]) 141 | set(leg1, 'fontsize', 12) 142 | set(leg1, 'position', [0.19 0.2133 0.1440 0.2617]) 143 | set(ax(1), 'ytick', [0 0.5 1]) 144 | 145 | axes(ax(2)); hold on; 146 | l1 = plot(alphas, E); 147 | xlabel('learning rate, \alpha') 148 | ylabel('p(correct)') 149 | title('early trials', 'fontweight', 'normal') 150 | 151 | for i = 1:length(betas) 152 | leg{i} = ['\beta = ' num2str(betas(i))]; 153 | end 154 | leg2 = legend(l1(end:-1:1), {leg{end:-1:1}}); 155 | 156 | set([leg1 leg2], 'fontsize', 12) 157 | set(leg2, 'position', [0.6267 0.6453 0.1007 0.2617]); 158 | 159 | axes(ax(3)); hold on; 160 | l2 = plot(alphas, L); 161 | xlabel('learning rate, \alpha') 162 | % ylabel('p(correct)') 163 | title('late trials', 'fontweight', 'normal') 164 | for i = 1:length(l1) 165 | f = (i-1)/(length(l1)-1); 166 | set([l1(i) l2(i)], 'color', AZred*f + AZblue*(1-f)); 167 | end 168 | set([l1 l2], 'linewidth', 3) 169 | set(ax(3), 'yticklabel', []) 170 | 171 | set(ax(2:3), 'ylim', [0.5 1.02]) 172 | set(ax, 'fontsize', 18, 'tickdir', 'out') 173 | addABCs(ax(1:2), [-0.06 0.09], 32) 174 | 175 | 176 | %% save resulting figure 177 | saveFigurePdf(gcf, './Figures/Figure2') 178 | -------------------------------------------------------------------------------- /Figure3B_localminima.m: -------------------------------------------------------------------------------- 1 | function localminima 2 | %clear all 3 | alphas = [.06:.01:.5];% learning rate 4 | betas = [1 4:2:20];% inverse temperature 5 | rhos=[.5:.01:.98];% WM memory weight 6 | % for coarser parameters and faster brute force computation 7 | % [~,coarsealphas] = intersect(alphas,[.05:.05:.5]);%[0.05:.05:1]; 8 | % [~,coarsebetas] = intersect(betas,[1 4:4:20]);%[1 5:5:50]; 9 | % 
[~,coarserhos]=intersect(rhos,[.5:.05:.99]);%[0:.05:1]; 10 | Ks=2:6;% capacity 11 | 12 | % real simulation parameters 13 | realalpha=.1; 14 | realbeta=8; 15 | realrho=.9; 16 | realK=4; 17 | %% simulate one data set 18 | 19 | [stim,update,choice,rew,setsize]=simulate(realalpha,realbeta,realrho,realK); 20 | 21 | %% fmincon fitting 22 | 23 | % set fmincon options 24 | options=optimset('MaxFunEval',100000,'Display','off','algorithm','active-set');% 25 | 26 | % run optimization over 10 starting points 27 | for init=1:10 28 | % random starting point 29 | x0=rand(1,3); 30 | % optimize 31 | [pval,fval,bla,bla2] =fmincon(@(x) computellh(x,realK,stim,update,choice,rew,setsize),x0,[],[],[],[],... 32 | [0 0 0],[1 1 1],[],options); 33 | % store optimization result 34 | pars(init,:) = [pval,fval]; 35 | end 36 | % find global best 37 | [mf,i]=min(pars(:,end)); 38 | pars = pars(i,:); 39 | %% brute force fitting 40 | i1=0; 41 | for alpha=alphas 42 | i1=i1+1 43 | i2=0; 44 | for beta=betas 45 | i2=i2+1; 46 | j1=0; 47 | for rho=rhos 48 | j1=j1+1; 49 | j2=0; 50 | for K=realK 51 | j2=j2+1; 52 | p=[rho,alpha,beta/50]; 53 | % store likelihood over parameters in a mesh 54 | llh(i1,i2,j1,j2)=-computellh(p,K,stim,update,choice,rew,setsize); 55 | end 56 | end 57 | end 58 | end 59 | 60 | %% plot the results - in 2d 61 | 62 | figure; 63 | subplot(2,2,1) 64 | llh2=squeeze(max(squeeze(llh),[],2)); 65 | mi=min(llh2(:)); 66 | ma=max(llh2(:)); 67 | x=repmat(1:length(alphas),length(rhos),1)'; 68 | y=repmat(1:length(rhos),length(alphas),1); 69 | [mb,i]=max(llh2(:)); 70 | imagesc(alphas(1:end),rhos(1:end),llh2',[mi,ma]) 71 | colorbar 72 | hold on 73 | plot(alphas(x(i)),rhos(y(i)),'ok') 74 | plot(realalpha,realrho,'xr') 75 | plot(pars(2),pars(1),'*k') 76 | xlabel('alpha') 77 | ylabel('rho') 78 | set(gca,'fontsize',16) 79 | % 80 | 81 | %% iterate simulation and fitting 82 | 83 | options=optimset('MaxFunEval',100000,'Display','off','algorithm','active-set');% 84 | 85 | % number of random starting points 
for optimizer 86 | ninitialpoints=10; 87 | % for 100 simulations 88 | for iter = 1:100 89 | disp(['simulation #',num2str(iter)]) 90 | % generate data 91 | [stim,update,choice,rew,setsize]=simulate(realalpha,realbeta,realrho,realK); 92 | pars=[]; 93 | % fit simulated data with ninitialpoints random starting points 94 | for init=1:ninitialpoints 95 | x0=rand(1,3); 96 | [pval,fval,bla,bla2] =fmincon(@(x) computellh(x,realK,stim,update,choice,rew,setsize),x0,[],[],[],[],... 97 | [0 0 0],[1 1 1],[],options); 98 | pars(init,:) = [pval,fval]; 99 | [m,i]=min(pars(:,end)); 100 | bestllh(iter,init)=m; 101 | bestpars(iter,init,:)=pars(i,1:end-1); 102 | end 103 | % find global best fit 104 | [mf,i]=min(pars(:,end)); 105 | % find at which random starting point it was found 106 | when(iter,1)=i; 107 | % find at which random starting point a likelihood within .01 of the 108 | % global best was found 109 | i=find(bestllh(iter,:) thresh; 58 | 59 | for i = 1:2 60 | axes(ax(i)); 61 | plot(Xsim(i,ind), Xfit(i,ind), 'o', 'color', AZblue, 'markersize', 8, 'linewidth', 1, ... 
62 | 'markerfacecolor', [1 1 1]*0.5) 63 | end 64 | 65 | set(ax(1,2),'xscale', 'log', 'yscale' ,'log') 66 | 67 | axes(ax(1)); t = title('learning rate'); 68 | axes(ax(2)); t(2) = title('softmax temperature'); 69 | 70 | axes(ax(1)); xlabel('simulated \alpha'); ylabel('fit \alpha'); 71 | axes(ax(2)); xlabel('simulated \beta'); ylabel('fit \beta'); 72 | 73 | 74 | set(ax, 'tickdir', 'out', 'fontsize', 18) 75 | set(t, 'fontweight', 'normal') 76 | addABCs(ax(1), [-0.07 0.08], 32) 77 | addABCs(ax(2), [-0.1 0.08], 32, 'B') 78 | set(ax, 'tickdir', 'out') 79 | for i= 1:size(Xsim,1) 80 | axes(ax(i)); 81 | xl = get(gca, 'xlim'); 82 | plot(xl, xl, 'k--') 83 | end 84 | saveFigurePdf(gcf, '~/Desktop/Figure4') 85 | saveFigureEps(gcf, '~/Desktop/Figure4') 86 | saveFigurePng(gcf, '~/Desktop/Figure4') 87 | 88 | 89 | 90 | %% 91 | figure(1); clf; hold on; 92 | % plot(Xsim(1,:), Xsim(2,:),'.') 93 | plot(Xfit(2,:), Xfit(1,:),'.') 94 | set(gca, 'xscale', 'log') 95 | -------------------------------------------------------------------------------- /Figure5_both.m: -------------------------------------------------------------------------------- 1 | 2 | %%%%% Bob Wilson & Anne Collins 3 | %%%%% 2018 4 | %%%%% Code to produce figure 5 in submitted paper "Ten simple rules for the 5 | %%%%% computational modeling of behavioral data" 6 | 7 | 8 | clear 9 | 10 | 11 | CM = zeros(5); 12 | 13 | T = 1000; 14 | mu = [0.2 0.8]; 15 | 16 | 17 | %% 18 | 19 | CM1 = [1 0 0 0 0 20 | 0.01 0.99 0 0 0 21 | 0.34 0.12 0.54 0 0 22 | 0.35 0.09 0 0.54 0.01 23 | 0.14 0.04 0.26 0.26 0.3]; 24 | 25 | 26 | CM2 = [ 0.9700 0.0300 0 0 0 27 | 0.0400 0.9600 0 0 0 28 | 0.0600 0 0.9400 0 0 29 | 0.0600 0 0.0100 0.9300 0 30 | 0.0300 0 0.1000 0.1500 0.7200]; 31 | 32 | 33 | %% inverse confusion matrices 34 | for i = 1:size(CM1,2) 35 | iCM1(:,i) = CM1(:,i) / sum(CM1(:,i)); 36 | iCM2(:,i) = CM2(:,i) / sum(CM2(:,i)); 37 | end 38 | 39 | 40 | 41 | %% 42 | 43 | figure(1); clf; 44 | set(gcf, 'Position', [400 405 900 800]); 45 | ax = 
easy_gridOfEqualFigures([0.02 0.23 0.17], [0.1 0.14 0.03]); 46 | axes(ax(1)); 47 | t = imageTextMatrix(CM1); 48 | set(t(CM1'<0.3), 'color', 'w') 49 | hold on; 50 | [l1, l2] = addFacetLines(CM1); 51 | set(t, 'fontsize', 18) 52 | xlabel('fit model') 53 | ylabel('simulated model') 54 | 55 | axes(ax(2)); 56 | t = imageTextMatrix(CM2); 57 | set(t(CM2'<0.3), 'color', 'w') 58 | hold on; 59 | [l1, l2] = addFacetLines(CM2); 60 | set(t, 'fontsize', 18) 61 | xlabel('fit model') 62 | ylabel('simulated model') 63 | 64 | a = annotation('textbox', [0 0.94 1 0.06]); 65 | set(a, 'string', 'confusion matrix: p(fit model | simulated model)', 'fontsize', 38, ... 66 | 'horizontalalignment', 'center', ... 67 | 'linestyle', 'none', 'fontweight', 'normal') 68 | 69 | axes(ax(3)); 70 | t = imageTextMatrix(round(100*iCM1)/100); 71 | set(t(iCM1'<0.3), 'color', 'w') 72 | hold on; 73 | [l1, l2] = addFacetLines(CM1); 74 | set(t, 'fontsize', 18) 75 | xlabel('fit model') 76 | ylabel('simulated model') 77 | 78 | axes(ax(4)); 79 | t = imageTextMatrix(round(100*iCM2)/100); 80 | set(t(iCM2'<0.3), 'color', 'w') 81 | hold on; 82 | [l1, l2] = addFacetLines(iCM2); 83 | set(t, 'fontsize', 18) 84 | xlabel('fit model') 85 | ylabel('simulated model') 86 | 87 | a = annotation('textbox', [0 0.94-0.52 1 0.06]); 88 | set(a, 'string', 'inversion matrix: p(simulated model | fit model)', 'fontsize', 38, ... 89 | 'horizontalalignment', 'center',... 90 | 'linestyle', 'none', 'fontweight', 'normal') 91 | 92 | 93 | set(ax, 'xtick', [1:5], 'ytick', [1:5], 'fontsize', 24, ... 
94 | 'xaxislocation', 'top', 'tickdir', 'out') 95 | addABCs(ax, [-0.05 0.08], 40) 96 | saveFigurePdf(gcf, '~/Desktop/Figure5') 97 | % saveFigurePdf(gcf, './Figures/Figure5') -------------------------------------------------------------------------------- /Figure5_confusionMatrix.m: -------------------------------------------------------------------------------- 1 | %%%%% Bob Wilson & Anne Collins 2 | %%%%% 2018 3 | %%%%% Code to produce figure 5 in submitted paper "Ten simple rules for the 4 | %%%%% computational modeling of behavioral data" 5 | 6 | 7 | 8 | clear 9 | 10 | addpath('./SimulationFunctions') 11 | addpath('./AnalysisFunctions') 12 | addpath('./HelperFunctions') 13 | addpath('./FittingFunctions') 14 | addpath('./LikelihoodFunctions') 15 | 16 | %% 17 | 18 | CM = zeros(5); 19 | 20 | T = 1000; 21 | mu = [0.2 0.8]; 22 | 23 | 24 | for count = 1:100 25 | count 26 | 27 | figure(1); clf; 28 | FM = round(100*CM/sum(CM(1,:)))/100; 29 | t = imageTextMatrix(FM); 30 | set(t(FM'<0.3), 'color', 'w') 31 | hold on; 32 | [l1, l2] = addFacetLines(CM); 33 | set(t, 'fontsize', 22) 34 | title(['count = ' num2str(count)]); 35 | set(gca, 'xtick', [1:5], 'ytick', [1:5], 'fontsize', 28, ... 
36 | 'xaxislocation', 'top', 'tickdir', 'out') 37 | xlabel('fit model') 38 | ylabel('simulated model') 39 | 40 | 41 | drawnow 42 | % Model 1 43 | b = rand; 44 | [a, r] = simulate_M1random_v1(T, mu, b); 45 | [BIC, iBEST, BEST] = fit_all_v1(a, r); 46 | CM(1,:) = CM(1,:) + BEST; 47 | 48 | % Model 2 49 | epsilon = rand; 50 | [a, r] = simulate_M2WSLS_v1(T, mu, epsilon); 51 | [BIC, iBEST, BEST] = fit_all_v1(a, r); 52 | CM(2,:) = CM(2,:) + BEST; 53 | 54 | % Model 3 55 | 56 | alpha = rand; 57 | beta = 1+exprnd(1); 58 | [a, r] = simulate_M3RescorlaWagner_v1(T, mu, alpha, beta); 59 | [BIC, iBEST, BEST] = fit_all_v1(a, r); 60 | CM(3,:) = CM(3,:) + BEST; 61 | 62 | % Model 4 63 | alpha_c = rand; 64 | beta_c = 1+exprnd(1); 65 | [a, r] = simulate_M4ChoiceKernel_v1(T, mu, alpha_c, beta_c); 66 | [BIC, iBEST, BEST] = fit_all_v1(a, r); 67 | CM(4,:) = CM(4,:) + BEST; 68 | 69 | % Model 5 70 | alpha = rand; 71 | beta = 1+exprnd(1); 72 | alpha_c = rand; 73 | beta_c = 1+exprnd(1); 74 | [a, r] = simulate_M5RWCK_v1(T, mu, alpha, beta, alpha_c, beta_c); 75 | [BIC, iBEST, BEST] = fit_all_v1(a, r); 76 | CM(5,:) = CM(5,:) + BEST; 77 | 78 | end 79 | %% 80 | figure(1); 81 | title('') 82 | set(gcf, 'Position', [811 417 500 400]) 83 | set(gca, 'fontsize', 28); 84 | saveFigurePdf(gcf, '~/Figures/Figure5b') 85 | % 86 | % 87 | % [Xf, LL, BIC] = fit_M1random_v1(a, r); 88 | % [Xf, LL, BIC] = fit_M2WSLS_v1(a, r); 89 | % [Xf, LL, BIC] = fit_M3RescorlaWagner_v1(a, r); 90 | % [Xf, LL, BIC] = fit_M4CK_v1(a, r); 91 | % [Xf, LL, BIC] = fit_M5RWCK_v1(a, r); 92 | -------------------------------------------------------------------------------- /Figure6_fitUnimportantParams.m: -------------------------------------------------------------------------------- 1 | function Figure6_fitUnimportantParams 2 | 3 | % Generate and recover: simulate with RLbias, fit with RL and RL bias 4 | for rep=1:100 5 | rep 6 | % pick learning rate between .1 and .5 7 | alpha=.1+.4*rand; 8 | % pick inverse temperature between 1 and 9 9 | 
beta = 1+8*rand; 10 | % pick left-right bias between 0 and .2 11 | bias = .2*rand; 12 | % generate data 13 | D=simulate(alpha,beta,bias); 14 | % fit with both models 15 | [ps1,ps2]=fitRL(D); 16 | % save results: true params (1-3), RL fit params (4-6), RLbias params 17 | % (7-9). 18 | data(rep,:)=[alpha,beta,bias, [ps1(:,1:2) 0*ps1(:,1)],ps2(:,1:3)]; 19 | end 20 | 21 | % rescale the fit inverse temperatures 22 | data(:,5)=10*data(:,5); 23 | data(:,8)=10*data(:,8); 24 | 25 | % plot true against recovered parameters 26 | names={'\alpha','\beta','bias'}; 27 | 28 | 29 | figure(1); clf; 30 | set(gcf, 'Position', [440 378 700 450]); 31 | wg = [0.09 0.1 0.1 0.03]; 32 | hg = [0.12 0.23 0.1]; 33 | [ax, hb, wb] = easy_gridOfEqualFigures(hg, wg); 34 | 35 | AZred = [171,5,32]/256; 36 | for i=1:2% for each parameter 37 | % top line: fits with classic RL 38 | axes(ax(i)); hold on; 39 | plot(data(:,i),data(:,3+i),'o', 'color', AZred, 'markersize', 8, 'linewidth', 1); 40 | if i == 1; xlim([0 0.6]); end 41 | plot(xlim,xlim,'k--') 42 | l = lsline; 43 | set(l, 'linewidth', 3) 44 | 45 | xlabel(['simulated ',names{i}]) 46 | ylabel(['fit ',names{i}]) 47 | end 48 | 49 | for i = 1:3 50 | % bottom line: fits with RL + bias. 51 | axes(ax(i+3)); hold on; 52 | plot(data(:,i),data(:,6+i),'o', 'color', AZred, 'markersize', 8, 'linewidth', 1) 53 | plot(xlim,xlim,'k--') 54 | if i == 1; xlim([0 0.6]); end 55 | l = lsline; 56 | set(l, 'linewidth', 3) 57 | hold on 58 | 59 | xlabel(['simulated ',names{i}]) 60 | ylabel(['fit ',names{i}]) 61 | 62 | end 63 | set(ax([1 4]), 'xtick', [0:0.2:0.6], 'ylim', [0 1]) 64 | set(ax(3), 'visible', 'off') 65 | x1 = wg(1)/3; 66 | x2 = sum(wb(1:2))+wg(2); 67 | y1 = sum(hb(1:2)+hg(1:2)'); 68 | h = annotation('textbox', [x1 y1 x2 hg(end)], 'string', 'model 3 without bias', ... 69 | 'horizontalalignment', 'left', 'fontsize', 24, 'linestyle', 'none', ... 
70 | 'fontweight', 'bold') 71 | 72 | x1 = wg(1)/3; 73 | x2 = sum(wb(1:3))+sum(wg(2:3)); 74 | y1 = sum(hb(1)+hg(1)'); 75 | h = annotation('textbox', [x1 y1 x2 hg(end)], 'string', 'model 3 including bias', ... 76 | 'horizontalalignment', 'left', 'fontsize', 24, 'linestyle', 'none', ... 77 | 'fontweight', 'bold') 78 | 79 | set(ax, 'fontsize', 18, 'tickdir', 'out') 80 | 81 | saveFigurePdf(gcf, '~/Desktop/Figure6b') 82 | 83 | 84 | 85 | 86 | 87 | 88 | end 89 | 90 | function D=simulate(alpha,beta,bias) 91 | % simulate RL 92 | k=0; 93 | 94 | for s=1:20 95 | Q=[.5 .5]; 96 | % define the correct action for this block 97 | if mod(s,2) == 1 98 | corA = 1; 99 | else 100 | corA = 2; 101 | end 102 | %corA=1+(rand>.5); 103 | % initialize Q-values 104 | %Q=[.5 .5]; %BOB EDIT 105 | % run 50 trials 106 | for t=1:50 107 | k=k+1; 108 | % bias in favor of choice 1 in the softmax 109 | p2=1/(1+exp(beta*(bias+Q(1)-Q(2)))); 110 | % select an action, decide if correct 111 | a=1+(rand save as a regular pdf, 10 | % flag = 1 => save with -zbuffer and -r flags 11 | % res - optional (only if flag = 1, default 300) resolution of picture 12 | % 13 | % NOTE : Setting flag = 1 can help deal with blurry rendering of pdfs on mac 14 | 15 | % Robert Wilson 16 | % 18-Mar-2010 17 | 18 | if exist('flag') ~= 1 19 | flag = 0; 20 | end 21 | 22 | if exist('res') ~= 1 23 | res = 300; 24 | end 25 | 26 | set(figHandle, 'windowstyle', 'normal') 27 | set(figHandle, 'paperpositionmode', 'auto') 28 | 29 | pp = get(figHandle, 'paperposition'); 30 | wp = pp(3); 31 | hp = pp(4); 32 | set(figHandle, 'papersize', [wp hp]) 33 | 34 | if flag 35 | print(figHandle, '-dpdf', ['-r' num2str(res)], '-zbuffer', savename); 36 | else 37 | print(figHandle, '-dpdf', savename); 38 | end 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Anne Collins 4 | 5 | Permission 
is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/LikelihoodFunctions/lik_M1random_v1.m:
--------------------------------------------------------------------------------
1 | function NegLL = lik_M1random_v1(a, r, b)  % NegLL of choices a under Model 1: biased random responding, b = p(choose option 1)
2 | 
3 | % note r is not used here but included to fit notation better with other
4 | % likelihood functions
5 | 
6 | 
7 | 
8 | T = length(a);  % number of trials
9 | 
10 | % loop over all trial
11 | for t = 1:T
12 | 
13 | % compute choice probabilities
14 | p = [b 1-b];  % same on every trial: [p(option 1), p(option 2)]
15 | 
16 | % compute choice probability for actual choice
17 | choiceProb(t) = p(a(t));  % a(t) in {1,2}; choiceProb grows in the loop
18 | 
19 | end
20 | 
21 | % compute negative log-likelihood
22 | NegLL = -sum(log(choiceProb));
--------------------------------------------------------------------------------
/LikelihoodFunctions/lik_M2WSLS_v1.m:
--------------------------------------------------------------------------------
1 | function NegLL = lik_M2WSLS_v1(a, r, epsilon)  % NegLL under Model 2: noisy win-stay-lose-shift, epsilon = randomness
2 | 
3 | 
4 | % last reward/action (initialize as nan)
5 | rLast = nan;
6 | aLast = nan;
7 | 
8 | 
9 | T = length(a);
10 | 
11 | % loop over all trial
12 | for t = 1:T
13 | 
14 | % compute choice probabilities
15 | if isnan(rLast)
16 | 
17 | % first trial choose randomly
18 | p = [0.5 0.5];
19 | 
20 | else
21 | 
22 | % choice depends on last reward
23 | if rLast == 1
24 | 
25 | % win stay (with probability 1-epsilon)
26 | p = epsilon/2*[1 1];  % both options get the "explore" probability epsilon/2 ...
27 | p(aLast) = 1-epsilon/2;  % ... then previous action is overwritten with the "stay" probability
28 | 
29 | else
30 | 
31 | % lose shift (with probability 1-epsilon)
32 | p = (1-epsilon/2) * [1 1];
33 | p(aLast) = epsilon / 2;  % losing action is repeated only with probability epsilon/2
34 | 
35 | end
36 | end
37 | 
38 | % compute choice probability for actual choice
39 | choiceProb(t) = p(a(t));
40 | 
41 | aLast = a(t);  % remember this trial's action/reward for the next trial
42 | rLast = r(t);
43 | end
44 | 
45 | % compute negative log-likelihood
46 | NegLL = -sum(log(choiceProb));
--------------------------------------------------------------------------------
/LikelihoodFunctions/lik_M3RescorlaWagner_v1.m:
--------------------------------------------------------------------------------
1 | function NegLL = lik_M3RescorlaWagner_v1(a, r, alpha, beta)  % NegLL under Model 3: Rescorla-Wagner learning + softmax choice
2 | 
3 | 
4 | Q = [0.5 0.5];  % initial values of the two options
5 | 
6 | 
7 | T = length(a);
8 | 
9 | % loop over all trial
10 | for t = 1:T
11 | 
12 | % compute choice probabilities
13 | p = exp(beta*Q) / sum(exp(beta*Q));  % softmax with inverse temperature beta
14 | 
15 | % compute choice probability for actual choice
16 | choiceProb(t) = p(a(t));
17 | 
18 | % update values
19 | delta = r(t) - Q(a(t));  % reward prediction error
20 | Q(a(t)) = Q(a(t)) + alpha * delta;  % only the chosen option is updated
21 | 
22 | end
23 | 
24 | % compute negative log-likelihood
25 | NegLL = -sum(log(choiceProb));
--------------------------------------------------------------------------------
/LikelihoodFunctions/lik_M4CK_v1.m:
--------------------------------------------------------------------------------
1 | function NegLL = lik_M4CK_v1(a, r, alpha_c, beta_c)  % NegLL under Model 4: choice-kernel (perseveration) model; r unused
2 | 
3 | 
4 | CK = [0 0];  % choice kernel, initialized at zero
5 | 
6 | 
7 | T = length(a);
8 | 
9 | % loop over all trial
10 | for t = 1:T
11 | 
12 | % compute choice probabilities
13 | p = exp(beta_c*CK) / sum(exp(beta_c*CK));  % softmax on the choice kernel
14 | 
15 | % compute choice probability for actual choice
16 | choiceProb(t) = p(a(t));
17 | 
18 | % update choice kernel
19 | CK = (1-alpha_c) * CK;  % decay both options ...
20 | CK(a(t)) = CK(a(t)) + alpha_c * 1;  % ... and bump the chosen one towards 1
21 | 
22 | end
23 | 
24 | % compute negative log-likelihood
25 | NegLL = -sum(log(choiceProb));
--------------------------------------------------------------------------------
/LikelihoodFunctions/lik_M5RWCK_v1.m:
--------------------------------------------------------------------------------
1 | function NegLL = lik_M5RWCK_v1(a, r, alpha, beta, alpha_c, beta_c)  % NegLL under Model 5: Rescorla-Wagner values plus choice kernel
2 | 
3 | Q = [0.5 0.5];
4 | CK = [0 0];
5 | 
6 | 
7 | T = length(a);
8 | 
9 | % loop over all trial
10 | for t = 1:T
11 | 
12 | % compute choice probabilities
13 | V = beta * Q + beta_c * CK;  % net decision value combines both systems, each with its own weight
14 | p = exp(V) / sum(exp(V));
15 | 
16 | % compute choice probability for actual choice
17 | choiceProb(t) = p(a(t));
18 | 
19 | % update values
20 | delta = r(t) - Q(a(t));
21 | Q(a(t)) = Q(a(t)) + alpha * delta;
22 | 
23 | % update choice kernel
24 | CK = (1-alpha_c) * CK;
25 | CK(a(t)) = CK(a(t)) + alpha_c * 1;
26 | 
27 | end
28 | 
29 | % compute negative log-likelihood
30 | NegLL = -sum(log(choiceProb));
--------------------------------------------------------------------------------
/LikelihoodFunctions/lik_M6RescorlaWagnerBias_v1.m:
--------------------------------------------------------------------------------
1 | function NegLL = lik_M6RescorlaWagnerBias_v1(a, r, alpha, beta, Qbias)  % NegLL under Model 6: Rescorla-Wagner with additive value bias for option 1
2 | 
3 | 
4 | Q = [0.5 0.5];
5 | 
6 | 
7 | T = length(a);
8 | 
9 | % loop over all trial
10 | for t = 1:T
11 | 
12 | % compute choice probabilities
13 | V = Q;
14 | V(1) = V(1) + Qbias;  % bias shifts option 1's decision value only; learned Q is untouched
15 | p = exp(beta*V) / sum(exp(beta*V));
16 | 
17 | % compute choice probability for actual choice
18 | choiceProb(t) = p(a(t));
19 | 
20 | % update values
21 | delta = r(t) - Q(a(t));
22 | Q(a(t)) = Q(a(t)) + alpha * delta;
23 | 
24 | end
25 | 
26 | % compute negative log-likelihood
27 | NegLL = -sum(log(choiceProb));
--------------------------------------------------------------------------------
/LikelihoodFunctions/lik_fullRL_v1.m:
--------------------------------------------------------------------------------
1 | function [NegLL, choiceProb, CP] = lik_fullRL_v1(a, r, s, alpha, beta)  % NegLL of state-dependent RL; also returns per-trial probabilities
2 | 
3 | % values for each state
4 | % Q(a,s) = value of taking action a in state s
5 | Q = zeros(3);  % 3 actions x 3 states
6 | 
7 | T = length(a);
8 | for t = 1:T
9 | 
10 | % compute choice probabilities
11 | p = exp(beta * Q(:,s(t)));  % softmax over actions for the current state's column
12 | p = p / sum(p);
13 | CP(:,t) = p;  % full probability vector stored for every trial
14 | 
15 | % compute probability of chosen option
16 | choiceProb(t) = p(a(t));
17 | 
18 | % update values
19 | Q(a(t),s(t)) = Q(a(t),s(t)) + alpha * (r(t) - Q(a(t),s(t)));
20 | 
21 | end
22 | 
23 | % compute negative log-likelihood
24 | NegLL = -sum(log(choiceProb));
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TenSimpleRulesModeling
2 | This is the code for the figures in the "Ten Simple Rules for Computational Modeling of Psychological Data" paper.
3 | 
--------------------------------------------------------------------------------
/SimulationFunctions/choose.m:
--------------------------------------------------------------------------------
1 | function a = choose(p)  % sample an index according to probability vector p
2 | 
3 | a = max(find([-eps cumsum(p)] < rand));  % inverse-CDF sampling: last cumulative bin below the uniform draw; -eps guards the first bin
4 | 
--------------------------------------------------------------------------------
/SimulationFunctions/simulate_M1random_v1.m:
--------------------------------------------------------------------------------
1 | function [a, r] = simulate_M1random_v1(T, mu, b)  % simulate T trials of Model 1: biased random choice, b = p(option 1); mu = reward probabilities
2 | 
3 | for t = 1:T
4 | 
5 | % compute choice probabilities
6 | p = [b 1-b];
7 | 
8 | % make choice according to choice probabilities
9 | a(t) = choose(p);
10 | 
11 | % generate reward based on choice
12 | r(t) = rand < mu(a(t));  % Bernoulli reward with mean mu of the chosen option
13 | 
14 | end
15 | 
--------------------------------------------------------------------------------
/SimulationFunctions/simulate_M2WSLS_v1.m:
--------------------------------------------------------------------------------
1 | function [a, r] = simulate_M2WSLS_v1(T, mu, epsilon)  % simulate T trials of Model 2: noisy win-stay-lose-shift
2 | 
3 | % last reward/action (initialize as nan)
4 | rLast = nan;
5 | aLast = nan;
6 | 
7 | for t = 1:T
8 | 
9 | % compute choice probabilities
10 | if isnan(rLast)
11 | 
12 | % first trial choose randomly
13 | p = [0.5 0.5];
14 | 
15 | else
16 | 
17 | % choice depends on last reward
18 | if rLast == 1
19 | 
20 | % win stay (with probability 1-epsilon)
21 | p = epsilon/2*[1 1];  % both options get the "explore" probability epsilon/2 ...
22 | p(aLast) = 1-epsilon/2;  % ... then the previous action gets the "stay" probability
23 | 
24 | else
25 | 
26 | % lose shift (with probability 1-epsilon)
27 | p = (1-epsilon/2) * [1 1];
28 | p(aLast) = epsilon / 2;  % losing action is repeated only with probability epsilon/2
29 | 
30 | end
31 | end
32 | 
33 | % make choice according to choice probabilities
34 | a(t) = choose(p);
35 | 
36 | % generate reward based on choice
37 | r(t) = rand < mu(a(t));
38 | 
39 | 
40 | aLast = a(t);  % remember this trial's action/reward for the next trial
41 | rLast = r(t);
42 | end
43 | 
--------------------------------------------------------------------------------
/SimulationFunctions/simulate_M3RescorlaWagner_v1.m:
--------------------------------------------------------------------------------
1 | function [a, r] = simulate_M3RescorlaWagner_v1(T, mu, alpha, beta)  % simulate T trials of Model 3: Rescorla-Wagner learning + softmax
2 | 
3 | Q = [0.5 0.5];  % initial values of the two options
4 | 
5 | for t = 1:T
6 | 
7 | % compute choice probabilities
8 | p = exp(beta*Q) / sum(exp(beta*Q));  % softmax with inverse temperature beta
9 | 
10 | % make choice according to choice probabilities
11 | a(t) = choose(p);
12 | 
13 | % generate reward based on choice
14 | r(t) = rand < mu(a(t));  % Bernoulli reward with mean mu of the chosen option
15 | 
16 | % update values
17 | delta = r(t) - Q(a(t));  % reward prediction error
18 | Q(a(t)) = Q(a(t)) + alpha * delta;  % only the chosen option is updated
19 | 
20 | end
21 | 
22 | 
--------------------------------------------------------------------------------
/SimulationFunctions/simulate_M4ChoiceKernel_v1.m:
--------------------------------------------------------------------------------
1 | function [a, r] = simulate_M4ChoiceKernel_v1(T, mu, alpha_c, beta_c)  % simulate T trials of Model 4: choice-kernel (perseveration) model
2 | 
3 | CK = [0 0];  % choice kernel, initialized at zero
4 | 
5 | for t = 1:T
6 | 
7 | % compute choice probabilities
8 | p = exp(beta_c*CK) / sum(exp(beta_c*CK));  % softmax on the choice kernel
9 | 
10 | % make choice according to choice probabilities
11 | a(t) = choose(p);
12 | 
13 | % generate reward based on choice
14 | r(t) = rand < mu(a(t));
15 | 
16 | % update choice kernel
17 | CK = (1-alpha_c) * CK;  % decay both options ...
18 | CK(a(t)) = CK(a(t)) + alpha_c * 1;  % ... and bump the chosen one towards 1
19 | 
20 | 
21 | end
22 | 
--------------------------------------------------------------------------------
/SimulationFunctions/simulate_M5RWCK_v1.m:
--------------------------------------------------------------------------------
1 | function [a, r] = simulate_M5RWCK_v1(T, mu, alpha, beta, alpha_c, beta_c)  % simulate T trials of Model 5: Rescorla-Wagner values + choice kernel
2 | 
3 | Q = [0.5 0.5];
4 | CK = [0 0];
5 | 
6 | for t = 1:T
7 | 
8 | % compute choice probabilities
9 | V = beta * Q + beta_c * CK;  % net decision value combines both systems
10 | p = exp(V) / sum(exp(V));
11 | 
12 | % make choice according to choice probabilities
13 | a(t) = choose(p);
14 | 
15 | % generate reward based on choice
16 | r(t) = rand < mu(a(t));
17 | 
18 | % update values
19 | delta = r(t) - Q(a(t));
20 | Q(a(t)) = Q(a(t)) + alpha * delta;
21 | 
22 | % update choice kernel
23 | CK = (1-alpha_c) * CK;
24 | CK(a(t)) = CK(a(t)) + alpha_c * 1;
25 | 
26 | 
27 | end
28 | 
--------------------------------------------------------------------------------
/SimulationFunctions/simulate_M6RescorlaWagnerBias_v1.m:
--------------------------------------------------------------------------------
1 | function [a, r] = simulate_M6RescorlaWagnerBias_v1(T, mu, alpha, beta, Qbias)  % simulate Model 6: Rescorla-Wagner with additive bias for option 1
2 | 
3 | Q = [0.5 0.5];
4 | 
5 | for t = 1:T
6 | 
7 | % compute choice probabilities
8 | V = Q;
9 | V(1) = V(1) + Qbias;  % bias shifts option 1's decision value only; learned Q is untouched
10 | p = exp(beta*V) / sum(exp(beta*V));
11 | 
12 | % make choice according to choice probabilities
13 | a(t) = choose(p);
14 | 
15 | % generate reward based on choice
16 | r(t) = rand < mu(a(t));
17 | 
18 | % update values
19 | delta = r(t) - Q(a(t));
20 | Q(a(t)) = Q(a(t)) + alpha * delta;
21 | 
22 | end
23 | 
24 | 
--------------------------------------------------------------------------------
/SimulationFunctions/simulate_blind_v1.m:
--------------------------------------------------------------------------------
1 | function [AA, RR, SS, QQ] = simulate_blind_v1(alpha, beta, T)  % "blind" RL: values shared across states (Q does not condition on s); returns actions, rewards, states, value history
2 | Q = [0 0 0];  % one value per action, NOT per state
3 | for t = 1:T
4 | 
5 | s = randi(3);  % state drawn uniformly each trial
6 | 
7 | % compute choice probabilities
8 | p = exp(beta * Q);
9 | p = p / sum(p);
10 | 
11 | % choose
12 | a = choose(p);
13 | 
14 | % determine reward
15 | switch s
16 | case 1
17 | 
18 | if a == 1
19 | r = 1;
20 | else
21 | r = 0;
22 | end
23 | 
24 | case 2
25 | 
26 | if a == 1
27 | r = 1;
28 | else
29 | r = 0;
30 | end
31 | 
32 | case 3
33 | 
34 | if a == 3
35 | r = 1;
36 | else
37 | r = 0;
38 | end
39 | end
40 | 
41 | % update values (action 1 rewarded in states 1 and 2; action 3 in state 3)
42 | Q(a) = Q(a) + alpha * (r - Q(a));
43 | QQ(:,t) = Q;
44 | AA(t) = a;
45 | SS(t) = s;
46 | RR(t) = r;
47 | end
--------------------------------------------------------------------------------
/SimulationFunctions/simulate_fullRL_v1.m:
--------------------------------------------------------------------------------
1 | function [a, r, s] = simulate_fullRL_v1(alpha, beta, T)  % full RL: separate value for each (action, state) pair
2 | 
3 | % values for each state
4 | % Q(a,s) = value of taking action a in state s
5 | Q = zeros(3);  % 3 actions x 3 states
6 | 
7 | for t = 1:T
8 | 
9 | 
10 | s(t) = randi(3);  % state drawn uniformly each trial
11 | 
12 | % compute choice probabilities
13 | p = exp(beta * Q(:,s(t)));  % softmax over the current state's column
14 | p = p / sum(p);
15 | 
16 | % choose
17 | a(t) = choose(p');  % transpose: choose expects a row vector
18 | 
19 | % determine reward
20 | switch s(t)
21 | case 1
22 | 
23 | if a(t) == 1
24 | r(t) = 1;
25 | else
26 | r(t) = 0;
27 | end
28 | 
29 | case 2
30 | 
31 | if a(t) == 1
32 | r(t) = 1;
33 | else
34 | r(t) = 0;
35 | end
36 | 
37 | case 3
38 | 
39 | if a(t) == 3
40 | r(t) = 1;
41 | else
42 | r(t) = 0;
43 | end
44 | end
45 | 
46 | % update values
47 | Q(a(t),s(t)) = Q(a(t),s(t)) + alpha * (r(t) - Q(a(t),s(t)));
48 | 
49 | end
50 | 
--------------------------------------------------------------------------------
/SimulationFunctions/simulate_validationModel_v1.m:
--------------------------------------------------------------------------------
1 | function [AA, RR, QQ] = simulate_validationModel_v1(alpha, beta, T)  % same generative process as simulate_blind_v1 but states are not returned
2 | Q = [0 0 0];  % one value per action, NOT per state
3 | for t = 1:T
4 | 
5 | s = randi(3);
6 | 
7 | % compute choice probabilities
8 | p = exp(beta * Q);
9 | p = p / sum(p);
10 | 
11 | % choose
12 | a = choose(p);
13 | 
14 | % determine reward
15 | switch s
16 | case 1
17 | 
18 | if a == 1
19 | r = 1;
20 | else
21 | r = 0;
22 | end
23 | 
24 | case 2
25 | 
26 | if a == 1
27 | r = 1;
28 | else
29 | r = 0;
30 | end
31 | 
32 | case 3
33 | 
34 | if a == 3
35 | r = 1;
36 | else
37 | r = 0;
38 | end
39 | end
40 | 
41 | % update values
42 | Q(a) = Q(a) + alpha * (r - Q(a));
43 | QQ(:,t) = Q;
44 | AA(t) = a;
45 | RR(t) = r;
46 | end
--------------------------------------------------------------------------------
/WilsonCollins_TenRulesForModelFitting.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnneCollins/TenSimpleRulesModeling/1f2709ed595da123b441b8faff5147bb94c97795/WilsonCollins_TenRulesForModelFitting.pdf
--------------------------------------------------------------------------------