├── README.md ├── construct_CSC.m ├── construct_stimulus.m ├── RW.m ├── TD.m ├── KRW.m ├── KTD.m ├── exercises.m └── simulate_models.m /README.md: -------------------------------------------------------------------------------- 1 | RL-tutorial 2 | ==== 3 | 4 | Reinforcement learning tutorial. See wiki for more information. 5 | 6 | Questions? Contact Sam Gershman (gershman@fas.harvard.edu). -------------------------------------------------------------------------------- /construct_CSC.m: -------------------------------------------------------------------------------- 1 | function x = construct_CSC(s) 2 | 3 | % Create complete serial compound representation. 4 | % 5 | % USAGE: x = construct_CSC(s) 6 | % 7 | % INPUTS: 8 | % s - stimulus timeseries (see construct_stimulus.m) 9 | % 10 | % OUTPUTS: 11 | % x - [trial_length x trial_length*D] complete serial compound representation 12 | % 13 | % Sam Gershman, June 2017 14 | 15 | x = []; 16 | for d = 1:size(s,2) 17 | x = [x diag(s(:,d))]; 18 | end -------------------------------------------------------------------------------- /construct_stimulus.m: -------------------------------------------------------------------------------- 1 | function s = construct_stimulus(stim) 2 | 3 | % Create binary stimulus timeseries (one column per stimulus). 
4 | % 5 | % USAGE: s = construct_stimulus(stim) 6 | % 7 | % INPUTS: 8 | % stim - structure containing the following fields: 9 | % .trial_length - length of trial 10 | % .onset [1 x D] - onset for each stimulus (relative to trial start, where 1 = trial start) 11 | % .dur [1 x D] - duration for each stimulus 12 | % 13 | % OUTPUTS: 14 | % s - [trial_length x D] stimulus timeseries 15 | % 16 | % Sam Gershman, June 2017 17 | 18 | D = length(stim.onset); 19 | s = zeros(stim.trial_length,D); 20 | for d = 1:D 21 | s(stim.onset(d):(stim.onset(d)+stim.dur(d)-1),d) = 1; 22 | end -------------------------------------------------------------------------------- /RW.m: -------------------------------------------------------------------------------- 1 | function model = RW(X,r,param) 2 | 3 | % Rescorla-Wagner model 4 | % 5 | % USAGE: model = RW(X,r,[param]) 6 | % 7 | % INPUTS: 8 | % X - [N x D] matrix of stimulus features, where N is the number of 9 | % timepoints, D is the number of features. 10 | % r - [N x 1] vector of rewards 11 | % param (optional) - parameter structure with the following fields: 12 | % .alpha - learning rate (default: 0.3) 13 | % 14 | % OUTPUTS: 15 | % model - [1 x N] structure with the following fields for each timepoint: 16 | % .w - [D x 1] estimated weight vector 17 | % .dt - prediction error 18 | % .rhat - reward prediction 19 | % (note: unlike TD, RW stores no .V field) 20 | % 21 | % Sam Gershman, June 2017 22 | 23 | % initialization 24 | [N,D] = size(X); 25 | w = zeros(D,1); % weights 26 | 27 | % parameters 28 | if nargin < 3 || isempty(param); param = struct('alpha',0.3); end 29 | alpha = param.alpha; % learning rate 30 | 31 | % run Rescorla-Wagner (delta rule) updates 32 | for n = 1:N 33 | 34 | rhat = X(n,:)*w; % reward prediction 35 | dt = r(n) - rhat; % prediction error 36 | w = w + alpha*dt*X(n,:)'; % weight update 37 | 38 | % store results 39 | model(n) = struct('w',w,'dt',dt,'rhat',rhat); 40 | 41 | end -------------------------------------------------------------------------------- /TD.m: 
-------------------------------------------------------------------------------- 1 | function model = TD(X,r,param) 2 | 3 | % Temporal difference learning model 4 | % 5 | % USAGE: model = TD(X,r,[param]) 6 | % 7 | % INPUTS: 8 | % X - [N x D] matrix of stimulus features, where N is the number of 9 | % timepoints, D is the number of features. 10 | % r - [N x 1] vector of rewards 11 | % param (optional) - parameter structure with the following fields: 12 | % .alpha - learning rate (default: 0.17) 13 | % .g - discount factor (default: 0.9) 14 | % 15 | % OUTPUTS: 16 | % model - [1 x N] structure with the following fields for each timepoint: 17 | % .w - [D x 1] estimated weight vector 18 | % .dt - prediction error 19 | % .rhat - reward prediction 20 | % .V - value estimate 21 | % 22 | % Sam Gershman, June 2017 23 | 24 | % initialization 25 | [N,D] = size(X); 26 | w = zeros(D,1); % weights 27 | X = [X; zeros(1,D)]; % add buffer at end (so X(n+1,:) exists at n = N) 28 | 29 | % parameters 30 | if nargin < 3 || isempty(param); param = struct('alpha',0.17,'g',0.9); end 31 | alpha = param.alpha; % learning rate 32 | g = param.g; % discount factor 33 | 34 | % run TD learning 35 | for n = 1:N 36 | 37 | h = X(n,:) - g*X(n+1,:); % temporal difference features 38 | V = X(n,:)*w; % value estimate 39 | rhat = h*w; % reward prediction 40 | dt = r(n) - rhat; % prediction error 41 | w = w + alpha*dt*h'; % weight update 42 | 43 | % store results 44 | model(n) = struct('w',w,'dt',dt,'rhat',rhat,'V',V); 45 | 46 | end -------------------------------------------------------------------------------- /KRW.m: -------------------------------------------------------------------------------- 1 | function model = KRW(X,r,param) 2 | 3 | % Kalman Rescorla-Wagner model 4 | % 5 | % USAGE: model = KRW(X,r,[param]) 6 | % 7 | % INPUTS: 8 | % X - [N x D] matrix of stimulus features, where N is the number of 9 | % timepoints, D is the number of features. 
10 | % r - [N x 1] vector of rewards 11 | % param (optional) - parameter structure with the following fields: 12 | % .c - prior variance (default: 1) 13 | % .s - observation noise variance (default: 1) 14 | % .q - transition noise variance (default: 0.01) 15 | % 16 | % OUTPUTS: 17 | % model - [1 x N] structure with the following fields for each timepoint: 18 | % .w - [D x 1] posterior mean weight vector 19 | % .C - [D x D] posterior weight covariance 20 | % .K - [D x 1] Kalman gain (learning rates for each dimension) 21 | % .dt - prediction error 22 | % .rhat - reward prediction 23 | % 24 | % Sam Gershman, June 2017 25 | 26 | % initialization 27 | [N,D] = size(X); 28 | w = zeros(D,1); % weights 29 | X = [X; zeros(1,D)]; % add buffer at end (unused here: KRW only reads X(n,:); kept for symmetry with KTD) 30 | 31 | % parameters 32 | if nargin < 3 || isempty(param); param = struct('c',1,'s',1,'q',0.01); end 33 | C = param.c*eye(D); % prior variance 34 | s = param.s; % observation noise variance 35 | Q = param.q*eye(D); % transition noise variance 36 | 37 | % run Kalman filter 38 | for n = 1:N 39 | 40 | h = X(n,:); % stimulus features 41 | rhat = h*w; % reward prediction 42 | dt = r(n) - rhat; % prediction error 43 | C = C + Q; % a priori covariance 44 | P = h*C*h'+s; % residual covariance 45 | K = C*h'/P; % Kalman gain 46 | w = w + K*dt; % weight update 47 | C = C - K*h*C; % posterior covariance update 48 | 49 | % store results 50 | model(n) = struct('w',w,'C',C,'K',K,'dt',dt,'rhat',rhat); 51 | 52 | end -------------------------------------------------------------------------------- /KTD.m: -------------------------------------------------------------------------------- 1 | function model = KTD(X,r,param) 2 | 3 | % Kalman temporal difference learning model 4 | % 5 | % USAGE: model = KTD(X,r,[param]) 6 | % 7 | % INPUTS: 8 | % X - [N x D] matrix of stimulus features, where N is the number of 9 | % timepoints, D is the number of features. 
10 | % r - [N x 1] vector of rewards 11 | % param (optional) - parameter structure with the following fields: 12 | % .c - prior variance (default: 1) 13 | % .s - observation noise variance (default: 1) 14 | % .q - transition noise variance (default: 0.005) 15 | % .g - discount factor (default: 0.9) 16 | % 17 | % OUTPUTS: 18 | % model - [1 x N] structure with the following fields for each timepoint: 19 | % .w - [D x 1] posterior mean weight vector 20 | % .C - [D x D] posterior weight covariance 21 | % .K - [D x 1] Kalman gain (learning rates for each dimension) 22 | % .dt - prediction error 23 | % .rhat - reward prediction 24 | % .V - value estimate 25 | % 26 | % Sam Gershman, June 2017 27 | 28 | % initialization 29 | [N,D] = size(X); 30 | w = zeros(D,1); % weights 31 | X = [X; zeros(1,D)]; % add buffer at end (so X(n+1,:) exists at n = N) 32 | 33 | % parameters 34 | if nargin < 3 || isempty(param); param = struct('c',1,'s',1,'q',0.005,'g',0.9); end 35 | C = param.c*eye(D); % prior variance 36 | s = param.s; % observation noise variance 37 | Q = param.q*eye(D); % transition noise variance 38 | g = param.g; % discount factor 39 | 40 | % run Kalman filter 41 | for n = 1:N 42 | 43 | h = X(n,:) - g*X(n+1,:); % temporal difference features 44 | V = X(n,:)*w; % value estimate 45 | rhat = h*w; % reward prediction 46 | dt = r(n) - rhat; % prediction error 47 | C = C + Q; % a priori covariance 48 | P = h*C*h'+s; % residual covariance 49 | K = C*h'/P; % Kalman gain 50 | w = w + K*dt; % weight update 51 | C = C - K*h*C; % posterior covariance update 52 | 53 | % store results 54 | model(n) = struct('w',w,'C',C,'K',K,'dt',dt,'rhat',rhat,'V',V); 55 | 56 | end -------------------------------------------------------------------------------- /exercises.m: -------------------------------------------------------------------------------- 1 | function exercises(exercise) 2 | 3 | % Exercise solutions 4 | % 5 | % USAGE: exercises(exercise) 6 | % 7 | % INPUTS: 8 | % exercise - which exercise to run: 9 | % 'latent 
inhibition' (RW and KRW models) 10 | % 'delay conditioning' (TD model) 11 | % 12 | % Sam Gershman, June 2017 13 | 14 | switch exercise 15 | 16 | case 'latent inhibition' 17 | 18 | % construct stimuli 19 | nTrials = 10; % number of trials 20 | X_pre = ones(2*nTrials,1); % stimulus present in both preexposure and conditioning phases 21 | X_nopre = [zeros(nTrials,1); ones(nTrials,1)]; % stimulus present only during conditioning 22 | r = [zeros(nTrials,1); ones(nTrials,1)]; % reward only during conditioning phase 23 | X_pre_delay = [ones(nTrials,1); zeros(nTrials,1); ones(nTrials,1)]; % preexposure, delay, conditioning 24 | r_delay = [zeros(nTrials*2,1); ones(nTrials,1)]; 25 | X_nopre_delay = [zeros(2*nTrials,1); ones(nTrials,1)]; 26 | 27 | % run models 28 | model_RW_pre = RW(X_pre,r); 29 | model_KRW_pre = KRW(X_pre,r); 30 | model_RW_nopre = RW(X_nopre,r); 31 | model_KRW_nopre = KRW(X_nopre,r); 32 | model_KRW_pre_delay = KRW(X_pre_delay,r_delay); 33 | model_KRW_nopre_delay = KRW(X_nopre_delay,r_delay); 34 | 35 | % plot weights: no delay 36 | for n=1:nTrials*2 37 | w_RW(n,:) = [model_RW_pre(n).w model_RW_nopre(n).w]; 38 | w_KRW(n,:) = [model_KRW_pre(n).w model_KRW_nopre(n).w]; 39 | end 40 | figure; 41 | plot(w_RW,'LineWidth',3); hold on; 42 | plot(w_KRW,'--','LineWidth',3); 43 | xlabel('Trial','FontSize',25); 44 | ylabel('Weight','FontSize',25); 45 | set(gca,'FontSize',25); 46 | legend({'RW (pre)' 'RW (nopre)' 'KRW (pre)' 'KRW (nopre)'},'FontSize',25); 47 | 48 | % plot weights: delay 49 | for n=1:nTrials*3 50 | w_KRW_delay(n,:) = [model_KRW_pre_delay(n).w model_KRW_nopre_delay(n).w]; 51 | end 52 | figure; 53 | LI = [w_KRW(end,2)-w_KRW(end,1) w_KRW_delay(end,2)-w_KRW_delay(end,1)]; % column 1 = pre, column 2 = nopre 54 | bar(LI); colormap bone 55 | ylabel('Latent inhibition effect (nopre-pre)','FontSize',25); 56 | set(gca,'FontSize',25,'XTickLabel',{'No Delay' 'Delay'}); 57 | 58 | case 'delay conditioning' 59 | 60 | % construct stimuli 61 | nTrials = 10; % number of trials 62 | trial_length = 8; 63 | s = construct_stimulus(struct('trial_length',trial_length,'onset',1,'dur',5)); % A->+ 64 | x = construct_CSC(s); % complete serial compound representation 65 | r = 
construct_stimulus(struct('trial_length',trial_length,'onset',5,'dur',1)); 66 | X = repmat(x,nTrials+1,1); % one extra, unrewarded trial to probe the omission response 67 | r = [repmat(r,nTrials,1); zeros(trial_length,1)]; 68 | 69 | % run model 70 | model = TD(X,r); 71 | 72 | % plot TD error 73 | dt = [model.dt]; 74 | dt = reshape(dt,trial_length,nTrials+1)'; 75 | figure; imagesc(dt(1:end-1,:)); colormap hot; colorbar 76 | xlabel('Timestep','FontSize',25); 77 | ylabel('Trial','FontSize',25); 78 | set(gca,'FontSize',25,'XTick',1:trial_length,'YTick',1:nTrials); 79 | title('TD error','FontSize',25); 80 | 81 | % plot omission response 82 | figure; plot(dt(end,:),'LineWidth',3); 83 | xlabel('Timestep','FontSize',25); 84 | ylabel('TD error','FontSize',25); 85 | set(gca,'FontSize',25,'XTick',1:trial_length); 86 | title('Omission response','FontSize',25); 87 | end -------------------------------------------------------------------------------- /simulate_models.m: -------------------------------------------------------------------------------- 1 | function simulate_models(sim) 2 | 3 | % Simulate models 4 | % 5 | % USAGE: simulate_models(sim) 6 | % 7 | % INPUTS: 8 | % sim - which simulation to run: 9 | % 'second-order conditioning' (TD and KTD models) 10 | % 'forward blocking' (RW and KRW models) 11 | % 'backward blocking' (RW and KRW models) 12 | % 13 | % Sam Gershman, June 2017 14 | 15 | switch sim 16 | 17 | case 'second-order conditioning' 18 | 19 | % construct stimuli 20 | nTrials = 10; % number of trials 21 | trial_length = 5; 22 | s1 = zeros(trial_length,2); 23 | s1(:,2) = construct_stimulus(struct('trial_length',trial_length,'onset',2,'dur',1)); % B->+ 24 | x1 = construct_CSC(s1); % complete serial compound representation 25 | r1 = construct_stimulus(struct('trial_length',trial_length,'onset',2,'dur',1)); 26 | s2 = construct_stimulus(struct('trial_length',trial_length,'onset',[1 2],'dur',[1 1])); % A->B->- 27 | x2 = construct_CSC(s2); 28 | r2 = zeros(trial_length,1); 29 | x3 = x1; % B->- 30 | r3 = r2; 31 | X = 
[repmat(x1,nTrials,1); repmat(x2,nTrials,1); repmat(x3,nTrials,1)]; 32 | r = [repmat(r1,nTrials,1); repmat(r2,nTrials,1); repmat(r3,nTrials,1)]; 33 | 34 | % run models 35 | model_TD = TD(X,r); 36 | model_KTD = KTD(X,r); 37 | for n=1:length(model_TD) 38 | w_TD(n,:) = model_TD(n).w; 39 | w_KTD(n,:) = model_KTD(n).w; 40 | end 41 | 42 | % plot weights during extinction (CSC columns: A at t=1 is feature 1; B at t=2 is feature trial_length+2) 43 | plot(w_TD(2*nTrials*trial_length+1:trial_length:end,[1 trial_length+2]),'LineWidth',3); hold on; 44 | plot(w_KTD(2*nTrials*trial_length+1:trial_length:end,[1 trial_length+2]),'--','LineWidth',3); 45 | xlabel('Extinction trial','FontSize',25); 46 | ylabel('Weight','FontSize',25); 47 | set(gca,'FontSize',25); 48 | legend({'A (TD)' 'B (TD)' 'A (KTD)' 'B (KTD)'},'FontSize',25); 49 | 50 | case 'forward blocking' 51 | 52 | % construct stimuli 53 | nTrials = 10; % number of trials 54 | X = [repmat([1 0],nTrials,1); repmat([1 1],nTrials,1)]; 55 | r = ones(nTrials*2,1); 56 | 57 | % run models 58 | model_RW = RW(X,r); 59 | model_KRW = KRW(X,r); 60 | 61 | % plot weights 62 | for n=1:length(model_RW) 63 | w_RW(n,:) = model_RW(n).w; 64 | w_KRW(n,:) = model_KRW(n).w; 65 | end 66 | plot(w_RW,'LineWidth',3); hold on; 67 | plot(w_KRW,'--','LineWidth',3); 68 | xlabel('Trial','FontSize',25); 69 | ylabel('Weight','FontSize',25); 70 | set(gca,'FontSize',25); 71 | legend({'A (RW)' 'B (RW)' 'A (KRW)' 'B (KRW)'},'FontSize',25); 72 | 73 | case 'backward blocking' 74 | 75 | % construct stimuli 76 | nTrials = 10; % number of trials 77 | X = [repmat([1 1],nTrials,1); repmat([1 0],nTrials,1)]; 78 | r = ones(nTrials*2,1); 79 | 80 | % run models 81 | model_RW = RW(X,r); 82 | model_KRW = KRW(X,r); 83 | 84 | % plot weights 85 | for n=1:length(model_RW) 86 | w_RW(n,:) = model_RW(n).w; 87 | w_KRW(n,:) = model_KRW(n).w; 88 | end 89 | plot(w_RW,'LineWidth',3); hold on; 90 | plot(w_KRW,'--','LineWidth',3); 91 | xlabel('Trial','FontSize',25); 92 | ylabel('Weight','FontSize',25); 93 | set(gca,'FontSize',25); 94 | legend({'A (RW)' 'B (RW)' 'A (KRW)' 'B 
(KRW)'},'FontSize',25); 95 | 96 | end --------------------------------------------------------------------------------