├── EFE_Precision_Updating.m ├── EFE_learning_novelty_term.m ├── Estimate_parameters.m ├── Message_passing_example.m ├── Pencil_and_paper_exercise_solutions.m ├── Prediction_error_example.m ├── README.md ├── Simplified_simulation_script.m ├── Step_by_Step_AI_Guide.m ├── Step_by_Step_Hierarchical_Model.m ├── VFE_calculation_example.m ├── spm_MDP_VB_ERP_tutorial.m ├── spm_MDP_VB_X_tutorial.m └── spm_MDP_VB_game_tutorial.m /EFE_Precision_Updating.m: -------------------------------------------------------------------------------- 1 | %% Example code for simulated expected free energy precision (beta/gamma) updates 2 | % (associated with dopamine in the neural process theory) 3 | 4 | % Supplementary Code for: A Step-by-Step Tutorial on Active Inference Modelling and its 5 | % Application to Empirical Data 6 | 7 | % By: Ryan Smith, Karl J. Friston, Christopher J. Whyte 8 | 9 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 10 | 11 | clear all 12 | close all 13 | 14 | % This script will reproduce the simulation results in Figure 9 15 | 16 | % Here you can set the number of policies and the distributions that 17 | % contribute to prior and posterior policy precision 18 | 19 | E = [1 1 1 1 1]'; % Set a fixed-form prior distribution 20 | % over policies (habits) 21 | 22 | G = [12.505 9.51 12.5034 12.505 12.505]'; % Set an example expected 23 | % free energy distribution over policies 24 | 25 | F = [17.0207 1.7321 1.7321 17.0387 17.0387]'; % Set an example variational 26 | % free energy distribution over 27 | % policies after a new observation 28 | 29 | 30 | gamma_0 = 1; % Starting expected free energy precision value 31 | gamma = gamma_0; % Initial expected free energy precision to be updated 32 | beta_prior = 1/gamma; % Initial prior on expected free energy precision 33 | beta_posterior = beta_prior; % Initial posterior on expected free energy precision 34 | psi = 2; % Step size parameter (promotes stable convergence) 35 | 36 | for ni = 1:16 % 
number of variational updates (16) 37 | 38 | % calculate prior and posterior over policies (see main text for 39 | % explanation of equations) 40 | 41 | pi_0 = exp(log(E) - gamma*G)/sum(exp(log(E) - gamma*G)); % prior over policies 42 | 43 | pi_posterior = exp(log(E) - gamma*G - F)/sum(exp(log(E) - gamma*G - F)); % posterior 44 | % over policies 45 | % calculate expected free energy precision 46 | 47 | G_error = (pi_posterior - pi_0)'*-G; % expected free energy prediction error 48 | 49 | beta_update = beta_posterior - beta_prior + G_error; % change in beta: 50 | % gradient of F with respect to gamma 51 | % (recall gamma = 1/beta) 52 | 53 | beta_posterior = beta_posterior - beta_update/psi; % update posterior precision 54 | % estimate (with step size of psi = 2, which reduces 55 | % the magnitude of each update and can promote 56 | % stable convergence) 57 | 58 | gamma = 1/beta_posterior; % update expected free energy precision 59 | 60 | % simulate dopamine responses 61 | 62 | n = ni; 63 | 64 | gamma_dopamine(n,1) = gamma; % simulated neural encoding of precision 65 | % (beta_posterior^-1) at each iteration of 66 | % variational updating 67 | 68 | policies_neural(:,n) = pi_posterior; % neural encoding of posterior over policies at 69 | % each iteration of variational updating 70 | end 71 | 72 | %% Show Results 73 | 74 | disp(' '); 75 | disp('Final Policy Prior:'); 76 | disp(pi_0); 77 | disp(' '); 78 | disp('Final Policy Posterior:'); 79 | disp(pi_posterior); 80 | disp(' '); 81 | disp('Final Policy Difference Vector:'); 82 | disp(pi_posterior-pi_0); 83 | disp(' '); 84 | disp('Negative Expected Free Energy:'); 85 | disp(-G); 86 | disp(' '); 87 | disp('Prior G Precision (Prior Gamma):'); 88 | disp(gamma_0); 89 | disp(' '); 90 | disp('Posterior G Precision (Gamma):'); 91 | disp(gamma); 92 | disp(' '); 93 | 94 | gamma_dopamine_plot = [gamma_0;gamma_0;gamma_0;gamma_dopamine]; % Include prior value 95 | 96 | figure 97 | plot(gamma_dopamine_plot); 98 | 
ylim([min(gamma_dopamine_plot)-.05 max(gamma_dopamine_plot)+.05]) 99 | title('Expected Free Energy Precision (Tonic Dopamine)'); 100 | xlabel('Updates'); 101 | ylabel('\gamma'); 102 | 103 | figure 104 | plot([gradient(gamma_dopamine_plot)],'r'); 105 | ylim([min(gradient(gamma_dopamine_plot))-.01 max(gradient(gamma_dopamine_plot))+.01]) 106 | title('Rate of Change in Precision (Phasic Dopamine)'); 107 | xlabel('Updates'); 108 | ylabel('\gamma gradient'); 109 | 110 | % uncomment if you want to display/plot firing rates encoding beliefs about each 111 | % policy (columns = policies, rows = updates over time) 112 | 113 | % plot(policies_neural); 114 | % disp('Firing rates encoding beliefs over policies:'); 115 | % disp(policies_neural'); 116 | % disp(' '); 117 | -------------------------------------------------------------------------------- /EFE_learning_novelty_term.m: -------------------------------------------------------------------------------- 1 | %% Calculating novelty term in expected free energy when learning 'A' matrix concentration parameters 2 | % (which drives parameter exploration) 3 | 4 | % Supplementary Code for: A Step-by-Step Tutorial on Active Inference Modelling and its 5 | % Application to Empirical Data 6 | 7 | % By: Ryan Smith, Karl J. Friston, Christopher J. 
Whyte 8 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 9 | 10 | clear 11 | close all 12 | 13 | %-- 'a' = concentration parameters for likelihood matrix 'A' 14 | 15 | % small concentration parameter values 16 | a1 = [.25 1; 17 | .75 1]; 18 | 19 | % intermediate concentration parameter values 20 | a2 = [2.5 10; 21 | 7.5 10]; 22 | 23 | % large concentration parameter values 24 | a3 = [25 100; 25 | 75 100]; 26 | 27 | % normalize columns in 'a' to get likelihood matrix 'A' (see col_norm 28 | % function at the end of script) 29 | A1 = col_norm(a1); 30 | A2 = col_norm(a2); 31 | A3 = col_norm(a3); 32 | 33 | % calculate 'a_sum' 34 | a1_sum = [a1(1,1)+a1(2,1) a1(1,2)+a1(2,2); 35 | a1(1,1)+a1(2,1) a1(1,2)+a1(2,2)]; 36 | 37 | a2_sum = [a2(1,1)+a2(2,1) a2(1,2)+a2(2,2); 38 | a2(1,1)+a2(2,1) a2(1,2)+a2(2,2)]; 39 | 40 | a3_sum = [a3(1,1)+a3(2,1) a3(1,2)+a3(2,2); 41 | a3(1,1)+a3(2,1) a3(1,2)+a3(2,2)]; 42 | 43 | % element wise inverse for 'a' and 'a_sum' 44 | inv_a1 = [1/a1(1,1) 1/a1(1,2); 45 | 1/a1(2,1) 1/a1(2,2)]; 46 | 47 | inv_a2 = [1/a2(1,1) 1/a2(1,2); 48 | 1/a2(2,1) 1/a2(2,2)]; 49 | 50 | inv_a3 = [1/a3(1,1) 1/a3(1,2); 51 | 1/a3(2,1) 1/a3(2,2)]; 52 | 53 | inv_a1_sum = [1/a1_sum(1,1) 1/a1_sum(1,2); 54 | 1/a1_sum(2,1) 1/a1_sum(2,2)]; 55 | 56 | inv_a2_sum = [1/a2_sum(1,1) 1/a2_sum(1,2); 57 | 1/a2_sum(2,1) 1/a2_sum(2,2)]; 58 | 59 | inv_a3_sum = [1/a3_sum(1,1) 1/a3_sum(1,2); 60 | 1/a3_sum(2,1) 1/a3_sum(2,2)]; 61 | 62 | % 'W' term for 'a' matrix 63 | W1 = .5*(inv_a1-inv_a1_sum); 64 | W2 = .5*(inv_a2-inv_a2_sum); 65 | W3 = .5*(inv_a3-inv_a3_sum); 66 | 67 | % beliefs over states under a policy at a time point 68 | s_pi_tau = [.9 .1]'; 69 | 70 | % predictive posterior over outcomes (A*s_pi_tau = predicted o_pi_tau) 71 | A1s = A1*s_pi_tau; 72 | A2s = A2*s_pi_tau; 73 | A3s = A3*s_pi_tau; 74 | 75 | % W term multiplied by beliefs over states under a policy at a time point 76 | W1s = W1*s_pi_tau; 77 | W2s = W2*s_pi_tau; 78 | W3s = W3*s_pi_tau; 79 | 80 | % 
compute novelty using dot product function 81 | Novelty_smallCP = dot((A1s),(W1s)); 82 | Novelty_intermediateCP = dot((A2s),(W2s)); 83 | Novelty_largeCP = dot((A3s),(W3s)); 84 | 85 | 86 | % show results 87 | disp(' '); 88 | disp('Novelty term for small concentration parameter values:'); 89 | disp(Novelty_smallCP); 90 | disp(' '); 91 | disp('Novelty term for intermediate concentration parameter values:'); 92 | disp(Novelty_intermediateCP); 93 | disp(' '); 94 | disp('Novelty term for large concentration parameter values:'); 95 | disp(Novelty_largeCP); 96 | disp(' '); 97 | 98 | 99 | %% function for normalizing 'a' to get likelihood matrix 'A' 100 | function A_normed = col_norm(A_norm) 101 | aa = A_norm; 102 | norm_constant = sum(aa,1); % create normalizing constant from sum of columns 103 | aa = aa./norm_constant; % divide columns by constant 104 | A_normed = aa; 105 | end 106 | -------------------------------------------------------------------------------- /Estimate_parameters.m: -------------------------------------------------------------------------------- 1 | function [DCM] = Estimate_parameters(DCM) 2 | 3 | % MDP inversion using Variational Bayes 4 | % FORMAT [DCM] = spm_dcm_mdp(DCM) 5 | % 6 | % Expects: 7 | %-------------------------------------------------------------------------- 8 | % DCM.MDP % MDP structure specifying a generative model 9 | % DCM.field % parameter (field) names to optimise 10 | % DCM.U % cell array of outcomes (stimuli) 11 | % DCM.Y % cell array of responses (action) 12 | % 13 | % Returns: 14 | %-------------------------------------------------------------------------- 15 | % DCM.M % generative model (DCM) 16 | % DCM.Ep % Conditional means (structure) 17 | % DCM.Cp % Conditional covariances 18 | % DCM.F % (negative) Free-energy bound on log evidence 19 | % 20 | % This routine inverts (cell arrays of) trials specified in terms of the 21 | % stimuli or outcomes and subsequent choices or responses. 
It first 22 | % computes the prior expectations (and covariances) of the free parameters 23 | % specified by DCM.field. These parameters are log scaling parameters that 24 | % are applied to the fields of DCM.MDP. 25 | % 26 | % If there is no learning implicit in multi-trial games, only unique trials 27 | % (as specified by the stimuli), are used to generate (subjective) 28 | % posteriors over choice or action. Otherwise, all trials are used in the 29 | % order specified. The ensuing posterior probabilities over choices are 30 | % used with the specified choices or actions to evaluate their log 31 | % probability. This is used to optimise the MDP (hyper) parameters in 32 | % DCM.field using variational Laplace (with numerical evaluation of the 33 | % curvature). 34 | % 35 | %__________________________________________________________________________ 36 | % Copyright (C) 2005 Wellcome Trust Centre for Neuroimaging 37 | 38 | % Karl Friston 39 | % $Id: spm_dcm_mdp.m 7120 2017-06-20 11:30:30Z spm $ 40 | 41 | % OPTIONS 42 | %-------------------------------------------------------------------------- 43 | ALL = false; 44 | 45 | % Here we specify prior expectations (for parameter means and variances) 46 | %-------------------------------------------------------------------------- 47 | prior_variance = 1/4; % smaller values will lead to a greater complexity 48 | % penalty (posteriors will remain closer to priors) 49 | 50 | for i = 1:length(DCM.field) 51 | field = DCM.field{i}; 52 | try 53 | param = DCM.MDP.(field); 54 | param = double(~~param); 55 | catch 56 | param = 1; 57 | end 58 | if ALL 59 | pE.(field) = zeros(size(param)); 60 | pC{i,i} = diag(param); 61 | else 62 | if strcmp(field,'alpha') 63 | pE.(field) = log(16); % in log-space (to keep positive) 64 | pC{i,i} = prior_variance; 65 | elseif strcmp(field,'beta') 66 | pE.(field) = log(1); % in log-space (to keep positive) 67 | pC{i,i} = prior_variance; 68 | elseif strcmp(field,'la') 69 | pE.(field) = log(1); % in 
log-space (to keep positive) 70 | pC{i,i} = prior_variance; 71 | elseif strcmp(field,'rs') 72 | pE.(field) = log(5); % in log-space (to keep positive) 73 | pC{i,i} = prior_variance; 74 | elseif strcmp(field,'eta') 75 | pE.(field) = log(0.5/(1-0.5)); % in logit-space - bounded between 0 and 1 76 | pC{i,i} = prior_variance; 77 | elseif strcmp(field,'omega') 78 | pE.(field) = log(0.5/(1-0.5)); % in logit-space - bounded between 0 and 1 79 | pC{i,i} = prior_variance; 80 | else 81 | pE.(field) = 0; % if it can take any negative or positive value 82 | pC{i,i} = prior_variance; 83 | end 84 | end 85 | end 86 | 87 | pC = spm_cat(pC); 88 | 89 | % model specification 90 | %-------------------------------------------------------------------------- 91 | M.L = @(P,M,U,Y)spm_mdp_L(P,M,U,Y); % log-likelihood function 92 | M.pE = pE; % prior means (parameters) 93 | M.pC = pC; % prior variance (parameters) 94 | M.mdp = DCM.MDP; % MDP structure 95 | 96 | % Variational Laplace 97 | %-------------------------------------------------------------------------- 98 | [Ep,Cp,F] = spm_nlsi_Newton(M,DCM.U,DCM.Y); % This is the actual fitting routine 99 | 100 | % Store posterior distributions and log evidence (free energy) 101 | %-------------------------------------------------------------------------- 102 | DCM.M = M; % Generative model 103 | DCM.Ep = Ep; % Posterior parameter estimates 104 | DCM.Cp = Cp; % Posterior variances and covariances 105 | DCM.F = F; % Free energy of model fit 106 | 107 | return 108 | 109 | function L = spm_mdp_L(P,M,U,Y) 110 | % log-likelihood function 111 | % FORMAT L = spm_mdp_L(P,M,U,Y) 112 | % P - parameter structure 113 | % M - generative model 114 | % U - inputs 115 | % Y - observed repsonses 116 | % 117 | % This function runs the generative model with a given set of parameter 118 | % values, after adding in the observations and actions on each trial 119 | % from (real or simulated) participant data. 
It then sums the 120 | % (log-)probabilities (log-likelihood) of the participant's actions under the model when it 121 | % includes that set of parameter values. The variational Bayes fitting 122 | % routine above uses this function to find the set of parameter values that maximize 123 | % the probability of the participant's actions under the model (while also 124 | % penalizing models with parameter values that move farther away from prior 125 | % values). 126 | %__________________________________________________________________________ 127 | 128 | if ~isstruct(P); P = spm_unvec(P,M.pE); end % non-struct P is unpacked against M.pE (presumably restoring its field layout - confirm against spm_unvec) 129 | 130 | % Here we re-transform parameter values out of log- or logit-space when 131 | % inserting them into the model to compute the log-likelihood 132 | %-------------------------------------------------------------------------- 133 | mdp = M.mdp; 134 | field = fieldnames(M.pE); 135 | for i = 1:length(field) 136 | if strcmp(field{i},'alpha') 137 | mdp.(field{i}) = exp(P.(field{i})); 138 | elseif strcmp(field{i},'beta') 139 | mdp.(field{i}) = exp(P.(field{i})); 140 | elseif strcmp(field{i},'la') 141 | mdp.(field{i}) = exp(P.(field{i})); 142 | elseif strcmp(field{i},'rs') 143 | mdp.(field{i}) = exp(P.(field{i})); 144 | elseif strcmp(field{i},'eta') 145 | mdp.(field{i}) = 1/(1+exp(-P.(field{i}))); 146 | elseif strcmp(field{i},'omega') 147 | mdp.(field{i}) = 1/(1+exp(-P.(field{i}))); 148 | else 149 | mdp.(field{i}) = exp(P.(field{i})); % NOTE(review): any other field is also exponentiated (forced positive), although its prior in Estimate_parameters is described as able to take any negative or positive value - confirm the intended transform 150 | end 151 | end 152 | 153 | % place MDP in trial structure 154 | %-------------------------------------------------------------------------- 155 | la = mdp.la_true; % true level of loss aversion 156 | rs = mdp.rs_true; % true preference magnitude for winning (higher = more risk-seeking) 157 | % NOTE(review): la_true and rs_true are read unconditionally, so M.mdp must carry both fields even when 'la' and 'rs' are both being estimated 158 | if isfield(M.pE,'la')&&isfield(M.pE,'rs') 159 | mdp.C{2} = [0 0 0 ; % Null 160 | 0 -mdp.la -mdp.la ; % Loss 161 | 0 mdp.rs mdp.rs/2]; % win 162 | elseif isfield(M.pE,'la') 163 | mdp.C{2} = [0 0 0 ; % Null 164 | 0 -mdp.la -mdp.la ; % Loss 165 | 0 
rs rs/2]; % win 166 | elseif isfield(M.pE,'rs') 167 | mdp.C{2} = [0 0 0 ; % Null 168 | 0 -la -la ; % Loss 169 | 0 mdp.rs mdp.rs/2]; % win 170 | else 171 | mdp.C{2} = [0 0 0 ; % Null 172 | 0 -la -la ; % Loss 173 | 0 rs rs/2]; % win 174 | end 175 | 176 | j = 1:numel(U); % observations for each trial 177 | n = numel(j); % number of trials 178 | 179 | [MDP(1:n)] = deal(mdp); % Create MDP with number of specified trials 180 | [MDP.o] = deal(U{j}); % Add observations in each trial 181 | 182 | % solve MDP and accumulate log-likelihood 183 | %-------------------------------------------------------------------------- 184 | MDP = spm_MDP_VB_X_tutorial(MDP); % run model with possible parameter values 185 | 186 | L = 0; % start (log) probability of actions given the model at 0 187 | 188 | for i = 1:numel(Y) % Get probability of true actions for each trial 189 | for j = 1:size(Y{i},2) % one pass per action column of this trial (j indexes columns of Y{i} and the 3rd dimension of P); row 2 of Y{i} is the second (controllable) state factor 190 | 191 | L = L + log(MDP(i).P(:,Y{i}(2,j),j)+ eps); % sum the (log) probabilities of each action 192 | % given a set of possible parameter values (eps avoids log(0)) 193 | end 194 | end 195 | 196 | clear('MDP') 197 | 198 | fprintf('LL: %f \n',L) 199 | -------------------------------------------------------------------------------- /Message_passing_example.m: -------------------------------------------------------------------------------- 1 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 2 | %-- Message Passing Examples--% 3 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 4 | 5 | % Supplementary Code for: A Tutorial on Active Inference Modelling and its 6 | % Application to Empirical Data 7 | 8 | % By: Ryan Smith and Christopher J. Whyte 9 | % We also acknowledge Samuel Taylor for contributing to this example code 10 | 11 | % This script provides two examples of (marginal) message passing, based on 12 | % the steps described in the main text. Each of the two examples (sections) 13 | % need to be run separately. 
The first example fixes all observed 14 | % variables immediately and does not include variables associated with the 15 | % neural process theory. The second example provides observations 16 | % sequentially and also adds in the neural process theory variables. To 17 | % remind the reader, the message passing steps in the main text are: 18 | 19 | % 1. Initialize the values of the approximate posteriors q(s_{\pi,\tau}) 20 | % for all hidden variables (i.e., all edges) in the graph. 21 | % 2. Fix the value of observed variables (here, o_\tau). 22 | % 3. Choose an edge (V) corresponding to the hidden variable you want to 23 | % infer (here, s_{\pi,\tau}). 24 | % 4. Calculate the messages, \mu(s_{\pi,\tau}), which take on values sent by 25 | % each factor node connected to V. 26 | % 5. Pass a message from each connected factor node N to V (often written 27 | % as \mu_{N \to V}). 28 | % 6. Update the approximate posterior represented by V according to the 29 | % following rule: q(s_{\pi,\tau}) \propto \overrightarrow{\mu}(s_{\pi,\tau}) \cdot \overleftarrow{\mu}(s_{\pi,\tau}). The arrow 30 | % notation here indicates messages from two different factors arriving 31 | % at the same edge. 32 | % 6A. Normalize the product of these messages so that q(s_{\pi,\tau}) 33 | % corresponds to a proper probability distribution. 34 | % 6B. Use this new q(s_{\pi,\tau}) to update the messages sent by 35 | % connected factors (i.e., for the next round of message passing). 36 | % 7. Repeat steps 4-6 sequentially for each edge. 37 | % 8. Steps 3-7 are then repeated until the difference between updates 38 | % converges to some acceptably low value (i.e., resulting in stable 39 | % posterior beliefs for all edges). 40 | 41 | %% Example 1: Fixed observations and message passing steps 42 | 43 | % This section carries out marginal message passing on a graph with beliefs 44 | % about states at two time points. 
In this first example, both observations 45 | % are fixed from the start (i.e., there are no ts as in full active inference 46 | % models with sequentially presented observations) to provide the simplest 47 | % example possible. We also highlight where each of the message passing 48 | % steps described in the main text are carried out. 49 | 50 | % Note that some steps (7 and 8) appear out of order when they involve loops that 51 | % repeat earlier steps 52 | 53 | % Specify generative model and initialize variables 54 | 55 | rng('shuffle') 56 | 57 | clear 58 | close all 59 | 60 | % priors 61 | D = [.5 .5]'; 62 | 63 | % likelihood mapping 64 | A = [.9 .1; 65 | .1 .9]; 66 | 67 | % transitions 68 | B = [1 0; 69 | 0 1]; 70 | 71 | % number of timesteps 72 | T = 2; 73 | 74 | % number of iterations of message passing 75 | NumIterations = 16; 76 | 77 | % initialize posterior (Step 1) 78 | for t = 1:T 79 | Qs(:,t) = [.5 .5]'; 80 | end 81 | 82 | % fix observations (Step 2) 83 | o{1} = [1 0]'; 84 | o{2} = [1 0]'; 85 | 86 | % iterate a set number of times (alternatively, until convergence) (Step 8) 87 | for Ni = 1:NumIterations 88 | % For each edge (hidden state) (Step 7) 89 | for tau = 1:T 90 | % choose an edge (Step 3) 91 | q = nat_log(Qs(:,tau)); 92 | 93 | % compute messages sent by D and B (Steps 4) using the posterior 94 | % computed in Step 6B 95 | if tau == 1 % first time point 96 | lnD = nat_log(D); % Message 1 97 | lnBs = nat_log(B'*Qs(:,tau+1)); % Message 2 98 | elseif tau == T % last time point 99 | lnBs = nat_log(B*Qs(:,tau-1)); % Message 1 100 | end 101 | 102 | % likelihood (Message 3) 103 | lnAo = nat_log(A'*o{tau}); 104 | 105 | % Steps 5-6 (Pass messages and update the posterior) 106 | % Since all terms are in log space, this is addition instead of 107 | % multiplication. 
This corresponds to equation 16 in the main 108 | % text (within the softmax) 109 | if tau == 1 110 | q = .5*lnD + .5*lnBs + lnAo; 111 | elseif tau == T 112 | q = .5*lnBs + lnAo; 113 | end 114 | 115 | % normalize using a softmax function to find posterior (Step 6A) 116 | Qs(:,tau) = (exp(q)/sum(exp(q))); 117 | qs(Ni,:,tau) = Qs(:,tau); % store value for each iteration 118 | end % Repeat for remaining edges (Step 7) 119 | end % Repeat until convergence/for fixed number of iterations (Step 8) 120 | 121 | Qs; % final posterior beliefs over states 122 | 123 | disp(' '); 124 | disp('Posterior over states q(s) in example 1:'); 125 | disp(' '); 126 | disp(Qs); 127 | 128 | figure 129 | 130 | % firing rates (traces) 131 | qs_plot = [D' D';qs(:,:,1) qs(:,:,2)]; % add prior to starting value 132 | plot(qs_plot) 133 | title('Example 1: Approximate Posteriors (1 per edge per time point)') 134 | ylabel('q(s_t_a_u)','FontSize',12) 135 | xlabel('Message passing iterations','FontSize',12) 136 | 137 | 138 | %% Example 2: Sequential observations and simulation of firing rates and ERPs 139 | 140 | % This script performs state estimation using the message passing 141 | % algorithm introduced in Parr, Markovic, Kiebel, & Friston (2019). 142 | % This script can be thought of as the full message passing solution to 143 | % problem 2 in the pencil and paper exercises. It also generates 144 | % simulated firing rates and ERPs in the same manner as those shown in 145 | % figs. 8, 10, 11, 14, 15, and 16. Unlike example 1, observations are 146 | % presented sequentially (i.e., two ts and two taus). 
147 | 148 | % Specify generative model and initialise variables 149 | 150 | rng('shuffle') 151 | 152 | clear 153 | 154 | % priors 155 | D = [.5 .5]'; 156 | 157 | % likelihood mapping 158 | A = [.9 .1; 159 | .1 .9]; 160 | 161 | % transitions 162 | B = [1 0; 163 | 0 1]; 164 | 165 | % number of timesteps 166 | T = 2; 167 | 168 | % number of iterations of message passing 169 | NumIterations = 16; 170 | 171 | % initialize posterior (Step 1) 172 | for t = 1:T 173 | Qs(:,t) = [.5 .5]'; 174 | end 175 | 176 | % fix observations sequentially (Step 2) 177 | o{1,1} = [1 0]'; 178 | o{1,2} = [0 0]'; 179 | o{2,1} = [1 0]'; 180 | o{2,2} = [1 0]'; 181 | 182 | % Message Passing 183 | 184 | for t = 1:T 185 | for Ni = 1:NumIterations % (Step 8 loop of VMP) 186 | for tau = 1:T % (Step 7 loop of VMP) 187 | 188 | % initialise depolarization variable: v = ln(s) 189 | % choose an edge (Step 3 of VMP) 190 | v = nat_log(Qs(:,tau)); % log-belief of the edge being updated (tau), so that epsilon below is the prediction error for that same edge 191 | 192 | % get correct D and B for each time point (Steps 4-5 of VMP) 193 | % using the posterior computed in Step 6B 194 | if tau == 1 % first time point 195 | % past (Message 1) 196 | lnD = nat_log(D); 197 | 198 | % future (Message 2) 199 | lnBs = nat_log(B'*Qs(:,tau+1)); 200 | elseif tau == T % last time point 201 | % no contribution from future (only Message 1) 202 | lnBs = nat_log(B*Qs(:,tau-1)); 203 | end 204 | % likelihood (Message 3) 205 | lnAo = nat_log(A'*o{t,tau}); 206 | 207 | % calculate state prediction error: equation 24 208 | if tau == 1 209 | epsilon(:,Ni,t,tau) = .5*lnD + .5*lnBs + lnAo - v; 210 | elseif tau == T 211 | epsilon(:,Ni,t,tau) = .5*lnBs + lnAo - v; 212 | end 213 | 214 | % (Step 6 of VMP) 215 | % update depolarization variable: equation 25 216 | v = v + epsilon(:,Ni,t,tau); 217 | % normalize using a softmax function to find posterior: 218 | % equation 26 (Step 6A of VMP) 219 | Qs(:,tau) = (exp(v)/sum(exp(v))); 220 | % store Qs for firing rate plots 221 | xn(Ni,:,tau,t) = Qs(:,tau); 222 | end % Repeat for remaining edges (Step 7 
of VMP) 223 | end % Repeat until convergence/for number of iterations (Step 8 of VMP) 224 | end 225 | 226 | Qs; % final posterior beliefs over states 227 | 228 | disp(' '); 229 | disp('Posterior over states q(s) in example 2:'); 230 | disp(' '); 231 | disp(Qs); 232 | 233 | % plots 234 | 235 | % get firing rates into usable format 236 | num_states = 2; 237 | num_epochs = 2; 238 | time_tau = [1 2 1 2; 239 | 1 1 2 2]; 240 | for t_tau = 1:size(time_tau,2) 241 | for epoch = 1:num_epochs 242 | % firing rate 243 | firing_rate{epoch,t_tau} = xn(:,time_tau(1,t_tau),time_tau(2,t_tau),epoch); 244 | ERP{epoch,t_tau} = gradient(firing_rate{epoch,t_tau}')'; 245 | end 246 | end 247 | 248 | % convert cells to matrices 249 | firing_rate = spm_cat(firing_rate)'; 250 | firing_rate = [zeros(length(D)*T,1)+[D; D] full(firing_rate)]; % add prior for starting value 251 | ERP = spm_cat(ERP); 252 | ERP = [zeros(length(D)*T,1)'; ERP]; % add 0 for starting value 253 | 254 | figure 255 | 256 | % firing rates 257 | imagesc(t,1:(num_states*num_epochs),64*(1 - firing_rate)) 258 | cmap = gray(256); 259 | colormap(cmap) 260 | title('Example 2: Firing rates (Darker = higher value)') 261 | ylabel('Firing rate','FontSize',12) 262 | xlabel('Message passing iterations','FontSize',12) 263 | 264 | figure 265 | 266 | % firing rates (traces) 267 | plot(firing_rate') 268 | title('Example 2: Firing rates (traces)') 269 | ylabel('Firing rate','FontSize',12) 270 | xlabel('Message passing iterations','FontSize',12) 271 | 272 | figure 273 | 274 | % ERPs/LFPs 275 | plot(ERP) 276 | title('Example 2: Event-related potentials') 277 | ylabel('Response','FontSize',12) 278 | xlabel('Message passing iterations','FontSize',12) 279 | 280 | %% functions 281 | 282 | % natural log that replaces zero values with very small values for numerical reasons. 
283 | function y = nat_log(x) 284 | y = log(x+exp(-16)); 285 | end 286 | -------------------------------------------------------------------------------- /Pencil_and_paper_exercise_solutions.m: -------------------------------------------------------------------------------- 1 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 2 | %-- Code/solutions for pencil and paper exercises --% 3 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 4 | 5 | % Supplementary Code for: A Step-by-Step Tutorial on Active Inference Modelling and its 6 | % Application to Empirical Data 7 | 8 | % By: Ryan Smith, Karl J. Friston, Christopher J. Whyte 9 | 10 | % Note to readers: be sure to run sections individually 11 | 12 | 13 | %% Static perception 14 | 15 | clear 16 | close all 17 | rng('default') 18 | 19 | % priors 20 | D = [.75 .25]'; 21 | 22 | % likelihood mapping 23 | A = [.8 .2; 24 | .2 .8]; 25 | 26 | % observations 27 | o = [1 0]'; 28 | 29 | % express generative model in terms of update equations 30 | lns = nat_log(D) + nat_log(A'*o); 31 | 32 | % normalize using a softmax function to find posterior 33 | s = (exp(lns)/sum(exp(lns))); 34 | 35 | disp('Posterior over states q(s):'); 36 | disp(' '); 37 | disp(s); 38 | 39 | % Note: Because the natural log of 0 is undefined, for numerical reasons 40 | % the nat_log function here replaces zero values with very small values. This 41 | % means that the answers generated by this function will vary slightly from 42 | % the exact solutions shown in the text. 
43 | 44 | return 45 | 46 | %% Dynamic perception 47 | 48 | clear 49 | close all 50 | rng('default') 51 | 52 | % priors 53 | D = [.5 .5]'; 54 | 55 | % likelihood mapping 56 | A = [.9 .1; 57 | .1 .9]; 58 | 59 | % transitions 60 | B = [1 0; 61 | 0 1]; 62 | 63 | % observations 64 | o{1,1} = [1 0]'; 65 | o{1,2} = [0 0]'; 66 | o{2,1} = [1 0]'; 67 | o{2,2} = [1 0]'; 68 | 69 | % number of timesteps 70 | T = 2; 71 | 72 | % initialise posterior 73 | for t = 1:T 74 | Qs(:,t) = [.5 .5]'; 75 | end 76 | 77 | for t = 1:T 78 | for tau = 1:T 79 | % get correct D and B for each time point 80 | if tau == 1 % first time point 81 | lnD = nat_log(D);% past 82 | lnBs = nat_log(B'*Qs(:,tau+1));% future 83 | elseif tau == T % last time point 84 | lnBs = nat_log(B*Qs(:,tau-1));% forward message from the past uses B (not B'); no contribution from future 85 | end 86 | % likelihood 87 | lnAo = nat_log(A'*o{t,tau}); 88 | % update equation 89 | if tau == 1 90 | lns = .5*lnD + .5*lnBs + lnAo; 91 | elseif tau == T 92 | lns = .5*lnBs + lnAo; 93 | end 94 | % normalize using a softmax function to find posterior 95 | Qs(:,tau) = (exp(lns)/sum(exp(lns))) 96 | end 97 | end 98 | 99 | Qs % final posterior beliefs over states 100 | 101 | disp('Posterior over states q(s):'); 102 | disp(' '); 103 | disp(Qs); 104 | 105 | %% functions 106 | 107 | % natural log with a small offset (.01) added to the input so that zero values never produce log(0). 108 | function y = nat_log(x) 109 | y = log(x+.01); 110 | end 111 | -------------------------------------------------------------------------------- /Prediction_error_example.m: -------------------------------------------------------------------------------- 1 | %% Example code for simulating state and outcome prediction errors 2 | 3 | % Supplementary Code for: A Step-by-Step Tutorial on Active Inference Modelling and its 4 | % Application to Empirical Data 5 | 6 | % By: Ryan Smith, Karl J. Friston, Christopher J. 
Whyte 7 | 8 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 9 | clear 10 | close all 11 | %% set up model to calculate state prediction errors 12 | % This minimizes variational free energy (keeps posterior beliefs accurate 13 | % while also keeping them as close as possible to prior beliefs) 14 | 15 | A = [.8 .4; 16 | .2 .6]; % Likelihood 17 | 18 | B_t1 = [.9 .2; 19 | .1 .8]; % Transition prior from previous timestep 20 | 21 | B_t2 = [.2 .3; 22 | .8 .7]; % Transition prior from current timestep 23 | 24 | o = [1 0]'; % Observation 25 | 26 | s_pi_tau = [.5 .5]'; % Prior distribution over states. Note that we here 27 | % use the same value for s_pi_tau-1, s_pi_tau, and 28 | % s_pi_tau+1. But this need not be the case. 29 | 30 | s_pi_tau_minus_1 = [.5 .5]'; 31 | 32 | s_pi_tau_plus_1 = [.5 .5]'; 33 | 34 | v_0 = log(s_pi_tau); % Depolarization term (initial value) 35 | 36 | B_t2_cross_intermediate = B_t2'; % Transpose B_t2 37 | 38 | B_t2_cross = spm_softmax(B_t2_cross_intermediate); % Normalize columns in transposed B_t2 39 | 40 | %% Calculate state prediction error (single iteration) 41 | 42 | state_error = 1/2*(log(B_t1*s_pi_tau_minus_1)+log(B_t2_cross*s_pi_tau_plus_1))... 
43 | +log(A'*o)-log(s_pi_tau); % state prediction error 44 | 45 | v = v_0 + state_error; % Depolarization 46 | 47 | s = (exp(v)/sum(exp(v))); % Updated distribution over states 48 | 49 | 50 | disp(' '); 51 | disp('Prior Distribution over States:'); 52 | disp(s_pi_tau); 53 | disp(' '); 54 | disp('State Prediction Error:'); 55 | disp(state_error); 56 | disp(' '); 57 | disp('Depolarization:'); 58 | disp(v); 59 | disp(' '); 60 | disp('Posterior Distribution over States:'); 61 | disp(s); 62 | disp(' '); 63 | 64 | return 65 | %% set up model to calculate outcome prediction errors 66 | % This minimizes expected free energy (maximizes reward and 67 | % information-gain) 68 | 69 | clear 70 | close all 71 | 72 | % Calculate risk (reward-seeking) term under two policies 73 | 74 | A = [.9 .1; 75 | .1 .9]; % Likelihood 76 | 77 | S1 = [.9 .1]'; % States under policy 1 78 | S2 = [.5 .5]'; % States under policy 2 79 | 80 | C = [1 0]'; % Preferred outcomes 81 | 82 | o_1 = A*S1; % Predicted outcomes under policy 1 83 | o_2 = A*S2; % Predicted outcomes under policy 2 84 | z = exp(-16); % Small number added to preference distribution to avoid log(0) 85 | 86 | risk_1 = dot(o_1,log(o_1) - log(C+z)); % Risk under policy 1 87 | 88 | risk_2 = dot(o_2,log(o_2) - log(C+z)); % Risk under policy 2 89 | 90 | disp(' '); 91 | disp('Risk Under Policy 1:'); 92 | disp(risk_1); 93 | disp(' '); 94 | disp('Risk Under Policy 2:'); 95 | disp(risk_2); 96 | disp(' '); 97 | 98 | 99 | % Calculate ambiguity (information-seeking) term under two policies 100 | 101 | A = [.4 .2; 102 | .6 .8]; % Likelihood 103 | 104 | s1 = [.9 .1]'; % States under policy 1 105 | s2 = [.1 .9]'; % States under policy 2 106 | 107 | 108 | ambiguity_1 = -dot(diag(A'*log(A)),s1); % Ambiguity under policy 1 109 | 110 | ambiguity_2 = -dot(diag(A'*log(A)),s2); % Ambiguity under policy 2 111 | 112 | disp(' '); 113 | disp('Ambiguity Under Policy 1:'); 114 | disp(ambiguity_1); 115 | disp(' '); 116 | disp('Ambiguity Under Policy 2:'); 117 | 
disp(ambiguity_2); 118 | disp(' '); 119 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Active-Inference-Tutorial-Scripts 2 | 3 | Supplementary scripts for Step-by-step active inference modelling tutorial 4 | 5 | By Ryan Smith and Christopher Whyte 6 | 7 | Step_by_Step_AI_Guide.m: 8 | 9 | This is the main tutorial script. It illustrates how to build a partially observable Markov decision process (POMDP) model within the active inference framework, using a simple explore-exploit task as an example. It shows how to run single-trial and multi-trial simulations including perception, decision-making, and learning. It also shows how to generate simulated neuronal responses. It further illustrates how to fit task models to empirical data for behavioral studies and do subsequent Bayesian group analyses. 10 | NOTE: This code was updated on 8/28/24 to improve the way forgetting rates are implemented. Unlike in the original published tutorial, this updated version specifies that greater omega values promote greater forgetting. Initial values for concentration parameters also now act as a floor, preventing these parameters from evolving toward implausibly low values over time. 11 | 12 | Step_by_Step_Hierarchical_Model: 13 | 14 | Separate script illustrating how to build a hierarchical (deep temporal) model, using a commonly used oddball task paradigm as an example. This also shows how to simulate predicted neuronal responses (event-related potentials) observed using this task in empirical studies. 15 | 16 | EFE_Precision_Updating: 17 | 18 | Separate script that allows the reader to simulate updates in the expected free energy precision (gamma) through updates in its prior (beta). 
At the top of the script you can choose values for the prior over policies, expected free energy over policies, and variational free energy over policies after a new observation, as well as the initial prior on expected precision. The script will then simulate 16 iterative updates and plot the resulting changes in gamma. By changing the initial values of the priors and free energies, you can get more of an intuition about the dynamics of these updates and how they depend on the relationship between the initial values that are chosen. 19 | 20 | VFE_calculation_example: 21 | 22 | Separate script that allows the reader to calculate variational free energy for approximate posterior beliefs given a new observation. The reader can specify a generative model (priors and likelihood matrix) and an observation, and then experiment with how variational free energy is reduced as approximate posterior beliefs approach the true posteriors. 23 | 24 | Prediction_error_example: 25 | 26 | Separate script that allows the reader to calculate state and outcome prediction errors. These minimize variational and expected free energy, respectively. Minimizing state prediction errors maintains accurate beliefs (while also changing beliefs as little as possible). Minimizing outcome prediction errors maximizes reward and information gain. 27 | 28 | Message_passing_example: 29 | 30 | Separate script that allows the reader to perform (marginal) message passing. In the first example, the code follows the message passing steps described in the main text (section 2) one by one. In the second example, this is extended to also calculate firing rates and ERPs associated with message passing in the neural process theory associated with active inference. 31 | 32 | EFE_learning_novelty_term: 33 | 34 | Separate script that allows the reader to calculate the novelty term that is added to the expected free energy when learning the Dirichlet concentration parameters (a) for the likelihood matrix (A). 
Small concentration parameters lead to a larger value for the novelty term, which is subtracted from the total EFE value for a policy. Therefore, less confidence in beliefs about state-outcome mappings in the A matrix leads the agent to select policies that will increase confidence in those beliefs ('parameter exploration'). 35 | 36 | Pencil_and_paper_exercise_solutions: 37 | 38 | Solutions to 'pencil and paper' exercises provided in the tutorial paper. These are provided to aid the reader in developing intuitions for the equations used in active inference. 39 | 40 | spm_MDP_VB_X_tutorial: 41 | 42 | Tutorial version of the standard routine for running active inference (POMDP) models. 43 | NOTE: This code was updated on 8/28/24 to improve the way forgetting rates are implemented. Unlike in the original published tutorial, this updated version specifies that greater omega values promote greater forgetting. Initial values for concentration parameters also now act as a floor, preventing these parameters from evolving toward implausibly low values over time. 44 | 45 | Simplified_simulation_script: 46 | 47 | Simplified and heavily commented version of the spm_MDP_VB_X_tutorial script. This is provided to make it easier for the reader to understand how the standard simulation routines work. 48 | NOTE: This code was updated on 8/28/24 to improve the way forgetting rates are implemented. Unlike in the original published tutorial, this updated version specifies that greater omega values promote greater forgetting. Initial values for concentration parameters also now act as a floor, preventing these parameters from evolving toward implausibly low values over time. 49 | 50 | Estimate_parameters: 51 | 52 | Script called by the main tutorial script for estimating parameters on (simulated) behavioral data. 53 | 54 | NOTE: Additional scripts are secondary functions called by the main scripts for plotting simulation outputs. 
% File: Simplified_simulation_script.m
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%-- Simplified Simulation Script --%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Supplementary Code for: A Step-by-Step Tutorial on Active Inference Modelling and its
% Application to Empirical Data

% By: Ryan Smith, Karl J. Friston, Christopher J. Whyte
% UPDATED: 8/28/2024 (modified forgetting rate implementation)
rng('shuffle')
close all
clear

% Simulates one trial of the explore-exploit task from the active inference
% tutorial, using a stripped-down version of the model inversion scheme
% implemented in spm_MDP_VB_X.m.

% This implementation uses the marginal message passing scheme described
% in Parr et al. (2019), so simulation results differ very slightly
% (negligibly) from those returned by spm_MDP_VB_X.m.

% Parr, T., Markovic, D., Kiebel, S., & Friston, K. J. (2019). Neuronal
% message passing using Mean-field, Bethe, and Marginal approximations.
% Scientific Reports, 9, 1889.

%% Simulation Settings

% Gen_model = 1: prior beliefs (d) are separated from the generative
%                process.
% Gen_model = 2: priors (d), likelihoods (a), and habits (e) are all
%                separated from the generative process.

Gen_model = 1; % As in the main tutorial code, many parameters can be adjusted
               % in the model setup, inside the explore_exploit_model
               % function (cited by the original comments as starting on
               % line 810). Among others (as in the main tutorial script):

% prior beliefs about context (d): alter line 876

% beliefs about hint accuracy in the likelihood (a): alter lines 996-998

% habits (e): alter line 1155

%% Specify Generative Model

MDP = explore_exploit_model(Gen_model);

% The model specification is reproduced at the bottom of this script
% (cited as starting on line 810); see the main tutorial script for a more
% complete walk-through.

%% Model Inversion to Simulate Behavior
%==========================================================================

% Normalize generative process and generative model
%--------------------------------------------------------------------------

% Before sampling from the generative process and inverting the generative
% model, the columns of the matrices are normalized so that they can be
% treated as probability distributions.

% generative process (column-normalized where applicable)
A = col_norm(MDP.A);   % Likelihood matrices
B = col_norm(MDP.B);   % Transition matrices
D = col_norm(MDP.D);   % Priors over initial states

C     = MDP.C;         % Preferences over outcomes
T     = MDP.T;         % Time points per trial
V     = MDP.V;         % Policies
beta  = MDP.beta;      % Expected free energy precision
alpha = MDP.alpha;     % Action precision
eta   = MDP.eta;       % Learning rate
omega = MDP.omega;     % Forgetting rate

% generative model (lowercase matrices/vectors are beliefs about
% capitalized matrices/vectors)

NumPolicies = MDP.NumPolicies; % Number of policies
NumFactors  = MDP.NumFactors;  % Number of state factors

% Store initial parameter values of the generative model for free energy
% calculations after learning
%--------------------------------------------------------------------------

% "complexity" of the d vector concentration parameters
if isfield(MDP,'d')
    for factor = 1:numel(MDP.d)
        % keep a copy of the d values as they were before learning
        d_prior{factor} = MDP.d{factor};
        % Per the original author's note: lower concentration parameters
        % give smaller values, lowering expected free energy and thereby
        % encouraging 'novel' behaviour. (spm_wnorm is defined elsewhere.)
        d_complexity{factor} = spm_wnorm(d_prior{factor});
    end
end

% "complexity" of the a matrix concentration parameters; entries where the
% prior is not greater than 0 are zeroed out by the logical mask
if isfield(MDP,'a')
    for modality = 1:numel(MDP.a)
        a_prior{modality} = MDP.a{modality};
        a_complexity{modality} = spm_wnorm(a_prior{modality}).*(a_prior{modality} > 0);
    end
end

% Normalise matrices before model inversion/inference
%--------------------------------------------------------------------------

% likelihood beliefs: fall back to the generative process when no 'a' field
if ~isfield(MDP,'a')
    a = col_norm(MDP.A);
else
    a = col_norm(MDP.a);
end

% transition beliefs: fall back to the generative process when no 'b' field
if ~isfield(MDP,'b')
    b = col_norm(MDP.B);
else
    b = col_norm(MDP.b);
end

% preferences: add 1/32, then, for every time point, apply a softmax and
% take the log through nat_log (helper defined elsewhere in this script)
for ii = 1:numel(C)
    C{ii} = MDP.C{ii} + 1/32;
    for t = 1:T
        C{ii}(:,t) = nat_log(exp(C{ii}(:,t))/sum(exp(C{ii}(:,t))));
    end
end

% initial-state beliefs: fall back to the generative process when no 'd' field
if ~isfield(MDP,'d')
    d = col_norm(MDP.D);
else
    d = col_norm(MDP.d);
end

% prior over policies: take 'e' if present, otherwise 'E', otherwise a
% flat vector; in every case the result is rescaled to sum to 1
if isfield(MDP,'e')
    E = MDP.e;
elseif isfield(MDP,'E')
    E = MDP.E;
else
    E = col_norm(ones(NumPolicies,1));
end
E = E./sum(E);

% Initialize variables
%--------------------------------------------------------------------------

% numbers of transitions, policies and states
NumModalities = numel(a); % number of outcome factors
NumFactors = numel(d);    % number of hidden state factors
NumPolicies = size(V,2); % number of allowable policies (columns of V)
for factor = 1:NumFactors
    NumStates(factor) = size(b{factor},1); % number of hidden states
    NumControllable_transitions(factor) = size(b{factor},3); % number of transition matrices (third dimension of b) for each factor
end

% initialize the approximate posterior over states conditioned on policies
% for each factor as a flat distribution over states at each time point.
% Each array is NumStates x T x NumPolicies and is allocated once at its
% final size.
for factor = 1:NumFactors
    NumStates(factor) = length(D{factor}); % number of states in each hidden state factor
    state_posterior{factor} = ones(NumStates(factor),T,NumPolicies)/NumStates(factor);
end

% initialize the approximate posterior over policies as a flat distribution
% over policies at each time point
policy_posteriors = ones(NumPolicies,T)/NumPolicies;

% initialize the record of chosen actions: one row per hidden state factor,
% one column per transition between time points.
% NumFactors is used for the row count. ndims(B) would return the array
% rank of the cell container B (always at least 2), which equals the number
% of factors only by coincidence in a two-factor model.
chosen_action = zeros(NumFactors,T-1);

% factors with a single transition matrix offer no choice, so their action
% is fixed to 1 at every time step
for factors = 1:NumFactors
    if NumControllable_transitions(factors) == 1
        chosen_action(factors,:) = ones(1,T-1);
    end
end
MDP.chosen_action = chosen_action;

% initialize expected free energy precision (beta)
posterior_beta = 1;
gamma(1) = 1/posterior_beta; % expected free energy precision

% message passing variables
TimeConst = 4;      % time constant for gradient descent
NumIterations = 16; % number of message passing iterations

% Lets go!
Message passing and policy selection 196 | %-------------------------------------------------------------------------- 197 | 198 | for t = 1:T % loop over time points 199 | 200 | % sample generative process 201 | %---------------------------------------------------------------------- 202 | 203 | for factor = 1:NumFactors % number of hidden state factors 204 | % Here we sample from the prior distribution over states to obtain the 205 | % state at each time point. At T = 1 we sample from the D vector, and at 206 | % time T > 1 we sample from the B matrix. To do this we make a vector 207 | % containing the cumulative sum of the columns (which we know sum to one), 208 | % generate a random number (0-1),and then use the find function to take 209 | % the first number in the cumulative sum vector that is >= the random number. 210 | % For example if our D vector is [.5 .5] 50% of the time the element of the 211 | % vector corresponding to the state one will be >= to the random number. 212 | 213 | % sample states 214 | if t == 1 215 | prob_state = D{factor}; % sample initial state T = 1 216 | elseif t>1 217 | prob_state = B{factor}(:,true_states(factor,t-1),MDP.chosen_action(factor,t-1)); 218 | end 219 | true_states(factor,t) = find(cumsum(prob_state)>= rand,1); 220 | end 221 | 222 | % sample observations 223 | for modality = 1:NumModalities % loop over number of outcome modalities 224 | outcomes(modality,t) = find(cumsum(a{modality }(:,true_states(1,t),true_states(2,t)))>=rand,1); 225 | end 226 | 227 | % express observations as a structure containing a 1 x observations 228 | % vector for each modality with a 1 in the position corresponding to 229 | % the observation recieved on that trial 230 | for modality = 1:NumModalities 231 | vec = zeros(1,size(a{modality},1)); 232 | index = outcomes(modality,t); 233 | vec(1,index) = 1; 234 | O{modality,t} = vec; 235 | clear vec 236 | end 237 | 238 | % marginal message passing (minimize F and infer posterior over states) 239 | 
%---------------------------------------------------------------------- 240 | 241 | for policy = 1:NumPolicies 242 | for Ni = 1:NumIterations % number of iterations of message passing 243 | for factor = 1:NumFactors 244 | lnAo = zeros(size(state_posterior{factor})); % initialise matrix containing the log likelihood of observations 245 | for tau = 1:T % loop over tau 246 | v_depolarization = nat_log(state_posterior{factor}(:,tau,policy)); % convert approximate posteriors into depolarisation variable v 247 | if tau