├── .gitignore
├── EFE_Precision_Updating.ipynb
├── EFE_learning_novelty_term.ipynb
├── Estimate_parameters.py
├── Message_passing_example.ipynb
├── Pencil_and_paper_exercise_solutions.ipynb
├── Prediction_error_example.ipynb
├── Readme.md
├── Simplified_simulation_script.py
├── original_matlab_code
│   ├── EFE_Precision_Updating.m
│   ├── EFE_learning_novelty_term.m
│   ├── Estimate_parameters.m
│   ├── Message_passing_example.m
│   ├── Pencil_and_paper_exercise_solutions.m
│   ├── Prediction_error_example.m
│   ├── Simplified_simulation_script.m
│   ├── Step_by_Step_AI_Guide.m
│   ├── Step_by_Step_Hierarchical_Model.m
│   ├── VFE_calculation_example.m
│   ├── spm
│   │   ├── spm_MDP_VB_LFP.m
│   │   ├── spm_MDP_VB_trial.m
│   │   ├── spm_MDP_check.m
│   │   ├── spm_MDP_size.m
│   │   ├── spm_axis.m
│   │   ├── spm_cat.m
│   │   ├── spm_combinations.m
│   │   ├── spm_conv.m
│   │   ├── spm_dir_norm.m
│   │   ├── spm_iwft.m
│   │   ├── spm_softmax.m
│   │   ├── spm_speye.m
│   │   ├── spm_vec.m
│   │   └── spm_wft.m
│   ├── spm_MDP_VB_ERP_tutorial.m
│   ├── spm_MDP_VB_X_tutorial.m
│   ├── spm_MDP_VB_game_tutorial.m
│   └── spm_figure.m
├── spm
│   ├── spm.py
│   ├── spm_MDP_VB_LFP.py
│   ├── spm_MDP_VB_trial.py
│   ├── spm_MDP_check.py
│   ├── spm_MDP_size.py
│   ├── spm_auxillary.py
│   ├── spm_axis.py
│   ├── spm_combination.py
│   ├── spm_conv.py
│   ├── spm_dir_norm.py
│   ├── spm_iwft.py
│   ├── spm_softmax.py
│   ├── spm_speye.py
│   ├── spm_vec.py
│   └── spm_wft.py
└── utility
    └── math_utils.py
/.gitignore: -------------------------------------------------------------------------------- 1 | utility/__pycache__/ 2 | spm/__pycache__/ 3 | .ipynb_checkpoints 4 | .DS_Store 5 | __pycache__/ -------------------------------------------------------------------------------- /EFE_learning_novelty_term.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "id": "d444051d", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | " \n", 14 | "Novelty term for small concentration parameter values:\n", 15 | "0.505\n", 16 | " \n", 17 | "Novelty term for intermediate concentration parameter values:\n", 18 | "0.05050000000000002\n", 19 | " \n", 20 | "Novelty term for large concentration parameter values:\n", 21 | "0.005050000000000001\n", 22 | " \n" 23 | ] 24 | } 25 | ], 26 | "source": [ 27 | "# Calculating novelty term in expected free energy when learning 'A' matrix concentration parameters\n", 28 | "# (which drives parameter exploration)\n", 29 | "\n", 30 | "# Supplementary Code for: A Step-by-Step Tutorial on Active Inference Modelling and its \n", 31 | "# Application to Empirical Data\n", 32 | "\n", 33 | "# By: Ryan Smith, Karl J. Friston, Christopher J. 
Whyte\n", 34 | "###############################################################################\n", 35 | "\n", 36 | "# 参考《主动推理》“新颖性”:P213 式7-11,P374\n", 37 | "\n", 38 | "import numpy as np\n", 39 | "\n", 40 | "def col_norm(A_norm):\n", 41 | " norm_constant = np.sum(A_norm, axis=0) # create normalizing constant from sum of columns\n", 42 | " A_normed = A_norm / norm_constant # divide columns by constant\n", 43 | " return A_normed\n", 44 | "\n", 45 | "# small concentration parameter values \n", 46 | "a1 = np.array([[0.25, 1], \n", 47 | " [0.75, 1]])\n", 48 | "\n", 49 | "# intermediate concentration parameter values \n", 50 | "a2 = np.array([[2.5, 10],\n", 51 | " [7.5, 10]])\n", 52 | "\n", 53 | "# large concentration parameter values \n", 54 | "a3 = np.array([[25, 100],\n", 55 | " [75, 100]])\n", 56 | "\n", 57 | "# normalize columns in 'a' to get likelihood matrix 'A'\n", 58 | "A1 = col_norm(a1)\n", 59 | "A2 = col_norm(a2)\n", 60 | "A3 = col_norm(a3)\n", 61 | "\n", 62 | "# calculate 'a_sum' \n", 63 | "a1_sum = np.array([[a1[0,0]+a1[1,0], a1[0,1]+a1[1,1]],\n", 64 | " [a1[0,0]+a1[1,0], a1[0,1]+a1[1,1]]])\n", 65 | "\n", 66 | "a2_sum = np.array([[a2[0,0]+a2[1,0], a2[0,1]+a2[1,1]],\n", 67 | " [a2[0,0]+a2[1,0], a2[0,1]+a2[1,1]]])\n", 68 | "\n", 69 | "a3_sum = np.array([[a3[0,0]+a3[1,0], a3[0,1]+a3[1,1]],\n", 70 | " [a3[0,0]+a3[1,0], a3[0,1]+a3[1,1]]])\n", 71 | "\n", 72 | "# element wise inverse for 'a' and 'a_sum'\n", 73 | "inv_a1 = 1 / a1\n", 74 | "inv_a2 = 1 / a2\n", 75 | "inv_a3 = 1 / a3\n", 76 | "\n", 77 | "inv_a1_sum = 1 / a1_sum\n", 78 | "inv_a2_sum = 1 / a2_sum\n", 79 | "inv_a3_sum = 1 / a3_sum\n", 80 | "\n", 81 | "# 'W' term for 'a' matrix\n", 82 | "W1 = 0.5 * (inv_a1 - inv_a1_sum)\n", 83 | "W2 = 0.5 * (inv_a2 - inv_a2_sum)\n", 84 | "W3 = 0.5 * (inv_a3 - inv_a3_sum)\n", 85 | "\n", 86 | "# beliefs over states under a policy at a time point\n", 87 | "s_pi_tau = np.array([0.9, 0.1])\n", 88 | "\n", 89 | "# predictive posterior over outcomes (A*s_pi_tau = predicted o_pi_tau)\n", 90 | "A1s = np.dot(A1, s_pi_tau)\n", 91 | "A2s = np.dot(A2, s_pi_tau)\n", 92 | "A3s = np.dot(A3, s_pi_tau)\n", 93 | "\n", 94 | "# W term multiplied by beliefs over states under a policy at a time point\n", 95 | "W1s = np.dot(W1, s_pi_tau)\n", 96 | "W2s = np.dot(W2, s_pi_tau)\n", 97 | "W3s = np.dot(W3, s_pi_tau)\n", 98 | "\n", 99 | "# compute novelty using dot product function\n", 100 | "Novelty_smallCP = np.dot(A1s, W1s)\n", 101 | "Novelty_intermediateCP = np.dot(A2s, W2s)\n", 102 | "Novelty_largeCP = np.dot(A3s, W3s)\n", 103 | "\n", 104 | "# show results\n", 105 | "print(' ')\n", 106 | "print('Novelty term for small concentration parameter values:')\n", 107 | "print(Novelty_smallCP)\n", 108 | "print(' ')\n", 109 | "print('Novelty term for intermediate concentration parameter values:')\n", 110 | "print(Novelty_intermediateCP)\n", 111 | "print(' ')\n", 112 | "print('Novelty term for large concentration parameter values:')\n", 113 | "print(Novelty_largeCP)\n", 114 | "print(' ')" 115 | ] 116 | } 117 | ], 118 | "metadata": { 119 | "kernelspec": { 120 | "display_name": "Python 3 (ipykernel)", 121 | "language": "python", 122 | "name": "python3" 123 | }, 124 | "language_info": { 125 | "codemirror_mode": { 126 | "name": "ipython", 127 | "version": 3 128 | }, 129 | "file_extension": ".py", 130 | "mimetype": "text/x-python", 131 | "name": "python", 132 | "nbconvert_exporter": "python", 133 | "pygments_lexer": "ipython3", 134 | "version": "3.10.15" 135 | } 136 | }, 137 | "nbformat": 4, 138 | "nbformat_minor": 5 139 | } 140 | 
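The three printed values above follow a single formula: novelty = (A s_pi_tau) . (W s_pi_tau), with W = 1/2 (1/a - 1/a_sum), where a_sum holds the column sums of the concentration parameters a. A compact restatement as a reusable helper (a sketch; the function name `novelty` is illustrative and not part of the repository):

import numpy as np

def novelty(a, s_pi_tau):
    # normalize Dirichlet counts column-wise to obtain the likelihood matrix A
    A = a / a.sum(axis=0)
    # W = 0.5 * (1/a - 1/a_sum); broadcasting the column sums replaces the explicit a_sum matrix
    W = 0.5 * (1.0 / a - 1.0 / a.sum(axis=0))
    # novelty = predicted outcomes (A s) dotted with (W s)
    return (A @ s_pi_tau) @ (W @ s_pi_tau)

# reproduces the three printed values above (0.505, 0.0505, 0.00505)
for a in (np.array([[0.25, 1.0], [0.75, 1.0]]),
          np.array([[2.5, 10.0], [7.5, 10.0]]),
          np.array([[25.0, 100.0], [75.0, 100.0]])):
    print(novelty(a, np.array([0.9, 0.1])))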
-------------------------------------------------------------------------------- /Estimate_parameters.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | ############################################################################### 4 | 5 | # 参考《主动推理》“参数估测”:P257 6 | 7 | def Estimate_parameters(DCM): 8 | """ 9 | MDP inversion using Variational Bayes 10 | FORMAT [DCM] = spm_dcm_mdp(DCM) 11 | 12 | Expects: 13 | -------------------------------------------------------------------------- 14 | DCM.MDP # MDP structure specifying a generative model 15 | DCM.field # parameter (field) names to optimise 16 | DCM.U # list of outcomes (stimuli) 17 | DCM.Y # list of responses (action) 18 | 19 | Returns: 20 | -------------------------------------------------------------------------- 21 | DCM.M # generative model (DCM) 22 | DCM.Ep # Conditional means (structure) 23 | DCM.Cp # Conditional covariances 24 | DCM.F # (negative) Free-energy bound on log evidence 25 | 26 | This routine inverts (list of) trials specified in terms of the 27 | stimuli or outcomes and subsequent choices or responses. It first 28 | computes the prior expectations (and covariances) of the free parameters 29 | specified by DCM.field. These parameters are log scaling parameters that 30 | are applied to the fields of DCM.MDP. 31 | 32 | If there is no learning implicit in multi-trial games, only unique trials 33 | (as specified by the stimuli), are used to generate (subjective) 34 | posteriors over choice or action. Otherwise, all trials are used in the 35 | order specified. The ensuing posterior probabilities over choices are 36 | used with the specified choices or actions to evaluate their log 37 | probability. This is used to optimise the MDP (hyper) parameters in 38 | DCM.field using variational Laplace (with numerical evaluation of the 39 | curvature). 
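    For example, with DCM['field'] = ['alpha', 'eta'] (field names the code below recognises),
    alpha is estimated in log-space with prior mean log(16) so that it remains positive, and
    eta in logit-space with a prior mean equivalent to 0.5 so that it remains between 0 and 1.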
40 | """ 41 | 42 | # OPTIONS 43 | ALL = False 44 | 45 | # Here we specify prior expectations (for parameter means and variances) 46 | prior_variance = 1/4 # smaller values will lead to a greater complexity 47 | # penalty (posteriors will remain closer to priors) 48 | 49 | pE = {} 50 | pC = {} 51 | 52 | for i, field in enumerate(DCM['field']): 53 | try: 54 | param = DCM['MDP'][field] 55 | param = np.double(param != 0) 56 | except KeyError: 57 | param = 1 58 | if ALL: 59 | pE[field] = np.zeros_like(param) 60 | pC[(i, i)] = np.diag(param) 61 | else: 62 | if field == 'alpha': 63 | pE[field] = np.log(16) # in log-space (to keep positive) 64 | pC[(i, i)] = prior_variance 65 | elif field == 'beta': 66 | pE[field] = np.log(1) # in log-space (to keep positive) 67 | pC[(i, i)] = prior_variance 68 | elif field == 'la': 69 | pE[field] = np.log(1) # in log-space (to keep positive) 70 | pC[(i, i)] = prior_variance 71 | elif field == 'rs': 72 | pE[field] = np.log(5) # in log-space (to keep positive) 73 | pC[(i, i)] = prior_variance 74 | elif field == 'eta': 75 | pE[field] = np.log(0.5 / (1 - 0.5)) # in logit-space - bounded between 0 and 1 76 | pC[(i, i)] = prior_variance 77 | elif field == 'omega': 78 | pE[field] = np.log(0.5 / (1 - 0.5)) # in logit-space - bounded between 0 and 1 79 | pC[(i, i)] = prior_variance 80 | else: 81 | pE[field] = 0 # if it can take any negative or positive value 82 | pC[(i, i)] = prior_variance 83 | 84 | pC = spm_cat(pC) 85 | 86 | # model specification 87 | M = { 88 | 'L': lambda P, M, U, Y: spm_mdp_L(P, M, U, Y), # log-likelihood function 89 | 'pE': pE, # prior means (parameters) 90 | 'pC': pC, # prior variance (parameters) 91 | 'mdp': DCM['MDP'] # MDP structure 92 | } 93 | 94 | # Variational Laplace 95 | Ep, Cp, F = spm_nlsi_Newton(M, DCM['U'], DCM['Y']) # This is the actual fitting routine 96 | 97 | # Store posterior distributions and log evidence (free energy) 98 | DCM['M'] = M # Generative model 99 | DCM['Ep'] = Ep # Posterior parameter estimates 100 | DCM['Cp'] = Cp # Posterior variances and covariances 101 | DCM['F'] = F # Free energy of model fit 102 | 103 | return DCM 104 | 105 | def spm_mdp_L(P, M, U, Y): 106 | """ 107 | log-likelihood function 108 | FORMAT L = spm_mdp_L(P,M,U,Y) 109 | P - parameter structure 110 | M - generative model 111 | U - inputs 112 | Y - observed responses 113 | 114 | This function runs the generative model with a given set of parameter 115 | values, after adding in the observations and actions on each trial 116 | from (real or simulated) participant data. It then sums the 117 | (log-)probabilities (log-likelihood) of the participant's actions under the model when it 118 | includes that set of parameter values. The variational Bayes fitting 119 | routine above uses this function to find the set of parameter values that maximize 120 | the probability of the participant's actions under the model (while also 121 | penalizing models with parameter values that move farther away from prior 122 | values). 
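    Concretely, L = sum over trials i and decision time points t of ln p(action_{i,t} | mdp, P),
    where the action probabilities are read from MDP[i]['P'] after running
    spm_MDP_VB_X_tutorial with the candidate parameter values.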
123 | """ 124 | 125 | if not isinstance(P, dict): 126 | P = spm_unvec(P, M['pE']) 127 | 128 | # Here we re-transform parameter values out of log- or logit-space when 129 | # inserting them into the model to compute the log-likelihood 130 | mdp = M['mdp'] 131 | fields = M['pE'].keys() 132 | for field in fields: 133 | if field == 'alpha': 134 | mdp[field] = np.exp(P[field]) 135 | elif field == 'beta': 136 | mdp[field] = np.exp(P[field]) 137 | elif field == 'la': 138 | mdp[field] = np.exp(P[field]) 139 | elif field == 'rs': 140 | mdp[field] = np.exp(P[field]) 141 | elif field == 'eta': 142 | mdp[field] = 1 / (1 + np.exp(-P[field])) 143 | elif field == 'omega': 144 | mdp[field] = 1 / (1 + np.exp(-P[field])) 145 | else: 146 | mdp[field] = np.exp(P[field]) 147 | 148 | # place MDP in trial structure 149 | la = mdp['la_true'] # true level of loss aversion 150 | rs = mdp['rs_true'] # true preference magnitude for winning (higher = more risk-seeking) 151 | 152 | if 'la' in M['pE'] and 'rs' in M['pE']: 153 | mdp['C'][2] = np.array([[0, 0, 0], # Null 154 | [0, -mdp['la'], -mdp['la']], # Loss 155 | [0, mdp['rs'], mdp['rs'] / 2]]) # win 156 | elif 'la' in M['pE']: 157 | mdp['C'][2] = np.array([[0, 0, 0], # Null 158 | [0, -mdp['la'], -mdp['la']], # Loss 159 | [0, rs, rs / 2]]) # win 160 | elif 'rs' in M['pE']: 161 | mdp['C'][2] = np.array([[0, 0, 0], # Null 162 | [0, -la, -la], # Loss 163 | [0, mdp['rs'], mdp['rs'] / 2]]) # win 164 | else: 165 | mdp['C'][2] = np.array([[0, 0, 0], # Null 166 | [0, -la, -la], # Loss 167 | [0, rs, rs / 2]]) # win 168 | 169 | j = range(len(U)) # observations for each trial 170 | n = len(j) # number of trials 171 | 172 | MDP = [mdp] * n # Create MDP with number of specified trials 173 | for k in j: 174 | MDP[k]['o'] = U[k] # Add observations in each trial 175 | 176 | # solve MDP and accumulate log-likelihood 177 | MDP = spm_MDP_VB_X_tutorial(MDP) # run model with possible parameter values 178 | 179 | L = 0 # start (log) probability of actions given the model at 0 180 | 181 | for i in range(len(Y)): # Get probability of true actions for each trial 182 | for j in range(len(Y[0][1])): # Only get probability of the second (controllable) state factor 183 | L += np.log(MDP[i]['P'][:, Y[i][1][j], j] + np.finfo(float).eps) # sum the (log) probabilities of each action 184 | # given a set of possible parameter values 185 | 186 | print(f'LL: {L}') 187 | return L 188 | 189 | # def spm_cat(pC): 190 | # # This function concatenates the covariance matrices 191 | # # Placeholder implementation 192 | # LEN_KEY_I = len(set([key[0] for key in pC.keys()])) 193 | # LEN_KEY_J = len(set([key[1] for key in pC.keys()])) 194 | # return np.block([[pC.get((i, j), np.zeros((1, 1))) for j in range(LEN_KEY_J)] for i in range(LEN_KEY_I)]) 195 | 196 | def spm_nlsi_Newton(M, U, Y): 197 | # Placeholder implementation for the variational Laplace fitting routine 198 | # This should be replaced with the actual implementation 199 | Ep = M['pE'] 200 | Cp = M['pC'] 201 | F = -np.inf # Free energy (log evidence) 202 | return Ep, Cp, F 203 | 204 | def spm_unvec(P, pE): 205 | # This function converts a vector back to a structure 206 | # Placeholder implementation 207 | return pE 208 | 209 | def spm_MDP_VB_X_tutorial(MDP): 210 | # Placeholder implementation for the MDP solver 211 | # This should be replaced with the actual implementation 212 | for mdp in MDP: 213 | mdp['P'] = np.random.rand(3, 3, 3) # Random probabilities for demonstration 214 | return MDP 
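# NOTE (sketch): Estimate_parameters() above calls spm_cat(pC), but the only definition of
# spm_cat in this file is commented out, so the script as written raises a NameError when run.
# A minimal stand-in for the case actually exercised here (ALL = False, scalar variances keyed
# by (i, i)) could look like the following; this is an assumption about the intended behaviour,
# not the official SPM spm_cat.
def spm_cat(pC):
    n = 1 + max(i for i, _ in pC.keys())
    return np.block([[np.atleast_2d(pC.get((i, j), 0.0)) for j in range(n)]
                     for i in range(n)])

# A further caveat: in spm_mdp_L, `MDP = [mdp] * n` makes every trial entry reference the same
# dictionary, so `MDP[k]['o'] = U[k]` keeps overwriting one shared 'o' field; building the list
# with copy.deepcopy(mdp) per trial would give each trial its own observations, as the MATLAB
# original's struct copies do.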
-------------------------------------------------------------------------------- /Pencil_and_paper_exercise_solutions.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "60c31a4b-a309-433e-88c1-239e8644a3c6", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "from utility.math_utils import nat_log\n", 12 | "\n", 13 | "\n", 14 | "# Static perception\n", 15 | "def static_perception():\n", 16 | " # priors\n", 17 | " D = np.array([0.75, 0.25])\n", 18 | "\n", 19 | " # likelihood mapping\n", 20 | " A = np.array([[0.8, 0.2],\n", 21 | " [0.2, 0.8]])\n", 22 | "\n", 23 | " # observations\n", 24 | " o = np.array([1, 0])\n", 25 | "\n", 26 | " # express generative model in terms of update equations\n", 27 | " lns = nat_log(D) + nat_log(A.T @ o)\n", 28 | "\n", 29 | " # normalize using a softmax function to find posterior\n", 30 | " s = np.exp(lns) / np.sum(np.exp(lns))\n", 31 | "\n", 32 | " print('Posterior over states q(s):')\n", 33 | " print(s)\n", 34 | "\n", 35 | "# Dynamic perception\n", 36 | "def dynamic_perception():\n", 37 | " # priors\n", 38 | " D = np.array([0.5, 0.5])\n", 39 | "\n", 40 | " # likelihood mapping\n", 41 | " A = np.array([[0.9, 0.1],\n", 42 | " [0.1, 0.9]])\n", 43 | "\n", 44 | " # transitions\n", 45 | " B = np.array([[1, 0],\n", 46 | " [0, 1]])\n", 47 | "\n", 48 | " # observations\n", 49 | " o = {\n", 50 | " (1, 1): np.array([1, 0]),\n", 51 | " (1, 2): np.array([0, 0]),\n", 52 | " (2, 1): np.array([1, 0]),\n", 53 | " (2, 2): np.array([1, 0])\n", 54 | " }\n", 55 | "\n", 56 | " # number of timesteps\n", 57 | " T = 2\n", 58 | "\n", 59 | " # initialise posterior\n", 60 | " Qs = np.zeros((2, T))\n", 61 | " for t in range(T):\n", 62 | " Qs[:, t] = np.array([0.5, 0.5])\n", 63 | "\n", 64 | " for t in range(T):\n", 65 | " for tau in range(T):\n", 66 | " # get correct D and B for each time point\n", 67 | " if tau == 0: # first time point\n", 68 | " lnD = nat_log(D) # past\n", 69 | " lnBs = nat_log(B.T @ Qs[:, tau + 1]) # future\n", 70 | " elif tau == T - 1: # last time point\n", 71 | " lnBs = nat_log(B.T @ Qs[:, tau - 1]) # no contribution from future\n", 72 | "\n", 73 | " # likelihood\n", 74 | " lnAo = nat_log(A.T @ o[(t + 1, tau + 1)])\n", 75 | "\n", 76 | " # update equation\n", 77 | " if tau == 0:\n", 78 | " lns = 0.5 * lnD + 0.5 * lnBs + lnAo\n", 79 | " elif tau == T - 1:\n", 80 | " lns = 0.5 * lnBs + lnAo\n", 81 | "\n", 82 | " # normalize using a softmax function to find posterior\n", 83 | " Qs[:, tau] = np.exp(lns) / np.sum(np.exp(lns))\n", 84 | "\n", 85 | " print('Posterior over states q(s):')\n", 86 | " print(Qs)" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 2, 92 | "id": "45deaa05-da8b-49f8-b58c-58346367030a", 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "name": "stdout", 97 | "output_type": "stream", 98 | "text": [ 99 | "Posterior over states q(s):\n", 100 | "[0.92307687 0.07692313]\n", 101 | "Posterior over states q(s):\n", 102 | "[[0.93971703 0.97262816]\n", 103 | " [0.06028297 0.02737184]]\n" 104 | ] 105 | } 106 | ], 107 | "source": [ 108 | "static_perception()\n", 109 | "dynamic_perception()" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "id": "440b0e74-848f-476d-beb4-017fc695747f", 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [] 119 | } 120 | ], 121 | "metadata": { 122 | "kernelspec": { 123 | "display_name": "Python 3 
(ipykernel)", 124 | "language": "python", 125 | "name": "python3" 126 | }, 127 | "language_info": { 128 | "codemirror_mode": { 129 | "name": "ipython", 130 | "version": 3 131 | }, 132 | "file_extension": ".py", 133 | "mimetype": "text/x-python", 134 | "name": "python", 135 | "nbconvert_exporter": "python", 136 | "pygments_lexer": "ipython3", 137 | "version": "3.10.15" 138 | } 139 | }, 140 | "nbformat": 4, 141 | "nbformat_minor": 5 142 | } 143 | -------------------------------------------------------------------------------- /Prediction_error_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "id": "62750596-cd39-4850-bb07-5c6d329eae08", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | " \n", 14 | "Prior Distribution over States:\n", 15 | "[0.5 0.5]\n", 16 | " \n", 17 | "State Prediction Error:\n", 18 | "[-0.17548846 -0.96897099]\n", 19 | " \n", 20 | "Depolarization:\n", 21 | "[-0.86863564 -1.66211817]\n", 22 | " \n", 23 | "Posterior Distribution over States:\n", 24 | "[0.68857861 0.31142139]\n", 25 | " \n" 26 | ] 27 | } 28 | ], 29 | "source": [ 30 | "import numpy as np\n", 31 | "from spm.spm_softmax import spm_softmax\n", 32 | "\n", 33 | "\n", 34 | "# 设置模型以计算状态预测误差\n", 35 | "A = np.array([[0.8, 0.4], \n", 36 | " [0.2, 0.6]]) # 似然\n", 37 | "\n", 38 | "B_t1 = np.array([[0.9, 0.2], \n", 39 | " [0.1, 0.8]]) # 前一时间步的转移先验\n", 40 | " \n", 41 | "B_t2 = np.array([[0.2, 0.3], \n", 42 | " [0.8, 0.7]]) # 当前时间步的转移先验\n", 43 | " \n", 44 | "o = np.array([1, 0]) # 观测\n", 45 | "\n", 46 | "s_pi_tau = np.array([0.5, 0.5]) # 状态的先验分布\n", 47 | "s_pi_tau_minus_1 = np.array([0.5, 0.5])\n", 48 | "s_pi_tau_plus_1 = np.array([0.5, 0.5])\n", 49 | "\n", 50 | "v_0 = np.log(s_pi_tau) # 去极化项(初始值)\n", 51 | "\n", 52 | "B_t2_cross_intermediate = B_t2.T # 转置 B_t2\n", 53 | "\n", 54 | "B_t2_cross = spm_softmax(B_t2_cross_intermediate) # 归一化转置 B_t2 的列\n", 55 | "\n", 56 | "# 计算状态预测误差(单次迭代)\n", 57 | "state_error = 0.5 * (np.log(B_t1 @ s_pi_tau_minus_1) + np.log(B_t2_cross @ s_pi_tau_plus_1)) \\\n", 58 | " + np.log(A.T @ o) - np.log(s_pi_tau) # 状态预测误差\n", 59 | "\n", 60 | "v = v_0 + state_error # 去极化\n", 61 | "\n", 62 | "s = np.exp(v) / np.sum(np.exp(v)) # 更新后的状态分布\n", 63 | "\n", 64 | "print(' ')\n", 65 | "print('Prior Distribution over States:')\n", 66 | "print(s_pi_tau)\n", 67 | "print(' ')\n", 68 | "print('State Prediction Error:')\n", 69 | "print(state_error)\n", 70 | "print(' ')\n", 71 | "print('Depolarization:')\n", 72 | "print(v)\n", 73 | "print(' ')\n", 74 | "print('Posterior Distribution over States:')\n", 75 | "print(s)\n", 76 | "print(' ')" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 3, 82 | "id": "5ea73af9-cd15-448d-a45f-4a272b5d43c0", 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "name": "stdout", 87 | "output_type": "stream", 88 | "text": [ 89 | " \n", 90 | "Risk Under Policy 1:\n", 91 | "2.408606420911068\n", 92 | " \n", 93 | "Risk Under Policy 2:\n", 94 | "7.30685276317247\n", 95 | " \n" 96 | ] 97 | } 98 | ], 99 | "source": [ 100 | "# 设置模型以计算结果预测误差\n", 101 | "# 这最小化期望自由能(最大化奖励和信息增益)\n", 102 | "\n", 103 | "# 计算两种策略下的风险(寻求奖励)\n", 104 | "\n", 105 | "A = np.array([[0.9, 0.1],\n", 106 | " [0.1, 0.9]]) # 似然\n", 107 | " \n", 108 | "S1 = np.array([0.9, 0.1]) # 策略1下的状态\n", 109 | "S2 = np.array([0.5, 0.5]) # 策略2下的状态\n", 110 | "\n", 111 | "C = np.array([1, 0]) # 首选结果\n", 112 | "\n", 113 | "o_1 = A @ S1 # 
策略1下的预测结果\n", 114 | "o_2 = A @ S2 # 策略2下的预测结果\n", 115 | "z = np.exp(-16) # 添加到偏好分布中的小数以避免 log(0)\n", 116 | "\n", 117 | "risk_1 = np.dot(o_1, np.log(o_1) - np.log(C + z)) # 策略1下的风险\n", 118 | "\n", 119 | "risk_2 = np.dot(o_2, np.log(o_2) - np.log(C + z)) # 策略2下的风险\n", 120 | "\n", 121 | "print(' ')\n", 122 | "print('Risk Under Policy 1:')\n", 123 | "print(risk_1)\n", 124 | "print(' ')\n", 125 | "print('Risk Under Policy 2:')\n", 126 | "print(risk_2)\n", 127 | "print(' ')" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 4, 133 | "id": "5382fd7a-e9a0-4285-b4d4-7313f131e644", 134 | "metadata": {}, 135 | "outputs": [ 136 | { 137 | "name": "stdout", 138 | "output_type": "stream", 139 | "text": [ 140 | " \n", 141 | "Ambiguity Under Policy 1:\n", 142 | "0.6557507426621495\n", 143 | " \n", 144 | "Ambiguity Under Policy 2:\n", 145 | "0.5176633478852948\n", 146 | " \n" 147 | ] 148 | } 149 | ], 150 | "source": [ 151 | "# 计算两种策略下的模糊性(寻求信息)\n", 152 | "\n", 153 | "A = np.array([[0.4, 0.2],\n", 154 | " [0.6, 0.8]]) # 似然\n", 155 | " \n", 156 | "s1 = np.array([0.9, 0.1]) # 策略1下的状态\n", 157 | "s2 = np.array([0.1, 0.9]) # 策略2下的状态\n", 158 | "\n", 159 | "ambiguity_1 = -np.dot(np.diag(A.T @ np.log(A)), s1) # 策略1下的模糊性\n", 160 | "\n", 161 | "ambiguity_2 = -np.dot(np.diag(A.T @ np.log(A)), s2) # 策略2下的模糊性\n", 162 | "\n", 163 | "print(' ')\n", 164 | "print('Ambiguity Under Policy 1:')\n", 165 | "print(ambiguity_1)\n", 166 | "print(' ')\n", 167 | "print('Ambiguity Under Policy 2:')\n", 168 | "print(ambiguity_2)\n", 169 | "print(' ')" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "id": "3c10ecd8-41b3-4b65-99ed-2f3d1972c7a5", 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [] 179 | } 180 | ], 181 | "metadata": { 182 | "kernelspec": { 183 | "display_name": "Python 3 (ipykernel)", 184 | "language": "python", 185 | "name": "python3" 186 | }, 187 | "language_info": { 188 | "codemirror_mode": { 189 | "name": "ipython", 190 | "version": 3 191 | }, 192 | "file_extension": ".py", 193 | "mimetype": "text/x-python", 194 | "name": "python", 195 | "nbconvert_exporter": "python", 196 | "pygments_lexer": "ipython3", 197 | "version": "3.10.15" 198 | } 199 | }, 200 | "nbformat": 4, 201 | "nbformat_minor": 5 202 | } 203 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # 简介 2 | 3 | 本项目基于论文“A step-by-step tutorial on active inference and its application to empirical data”中的原始代码。原始代码使用Matlab语言编写,文件类型为.m。为了推广Friston自由能的思想,我们将代码转换为Python语言。 4 | 5 | This project is based on the original code from the paper "A step-by-step tutorial on active inference and its application to empirical data". The original code was written in Matlab with .m file types. To promote Friston's ideas, we have translated the code into Python. 6 | 7 | ## 论文信息 8 | 9 | Smith, R., Friston, K. J., & Whyte, C. J. (2022). A step-by-step tutorial on active inference and its application to empirical data. Journal of mathematical psychology, 107, 102632. 
Link: https://www.sciencedirect.com/science/article/pii/S0022249621000973?via%3Dihub

DOI: https://doi.org/10.1016/j.jmp.2021.102632

Original MATLAB script repository: https://github.com/rssmith33/Active-Inference-Tutorial-Scripts

## Directory Structure

`original_matlab_code/`: contains the original MATLAB code accompanying the paper.

`spm`: a Python version of the original `spm` library (developed in MATLAB, GitHub: https://github.com/spm/spm ). SPM (Statistical Parametric Mapping) refers to the construction and assessment of spatially extended statistical processes used to test hypotheses about functional imaging data. The SPM package is designed for the analysis of brain imaging data sequences. The SPM suite and the associated theory were originally developed by Karl Friston for the routine statistical analysis of functional neuroimaging data from positron emission tomography (PET), while he worked at the Medical Research Council Cyclotron Unit. That software, now known as SPMclassic, was released to the emerging functional imaging community in 1991 to promote collaboration between laboratories and a common analysis scheme.
Because this project is intended for demonstration rather than research use, the Python spm code does not correspond strictly to the original MATLAB spm implementation; if the format of the arguments you pass differs from the examples, the results may be unexpected.

`utility`: contains custom mathematical helper functions.

The remaining files are described below (translated from the README of the original MATLAB script repository):

># Active Inference Tutorial Scripts
>
>Supplementary scripts for the step-by-step active inference modelling tutorial
>
>By: Ryan Smith and Christopher Whyte
>
>Step_by_Step_AI_Guide.m:
>
>This is the main tutorial script. Using a simple explore-exploit task as an example, it illustrates how to build a partially observable Markov decision process (POMDP) model within the active inference framework. It shows how to run single-trial and multi-trial simulations, including perception, decision-making, and learning, and how to generate simulated neuronal responses. It further illustrates how to fit the task model to empirical data from a behavioural study and carry out subsequent Bayesian group analyses. Note: this code was updated on 28 August 2024 to improve how the forgetting rate is implemented. Unlike the originally released tutorial, the updated version specifies that larger omega values produce more forgetting. The initial values of the concentration parameters now also act as lower bounds, preventing these parameters from drifting to implausibly low values over time.
>
>Step_by_Step_Hierarchical_Model:
>
>A separate script showing how to build a hierarchical (deep temporal) model, using the commonly used oddball task paradigm as an example. It also shows how to simulate the predicted neuronal responses (event-related potentials) observed when this task is used in empirical studies.
>
>EFE_Precision_Updating:
>
>A separate script that lets the reader simulate updates of the expected free energy precision (gamma) via updates to its prior (beta). At the top of the script you can choose values for the prior over policies, the expected free energy over policies, the variational free energy over policies after a new observation, and the initial prior on expected precision. The script then simulates 16 iterative updates and plots the resulting change in gamma. By varying the initial values of the priors and free energies, you can build intuition for the dynamics of these updates and how they depend on the chosen starting values.
>
>VFE_calculation_example:
>
>A separate script that lets the reader compute the variational free energy of approximate posterior beliefs given a new observation. The reader can specify a generative model (prior and likelihood matrices) and an observation, and then experiment with how variational free energy decreases as the approximate posterior approaches the true posterior.
>
>Prediction_error_example:
>
>A separate script that lets the reader compute state and outcome prediction errors, which minimize variational and expected free energy, respectively. Minimizing state prediction errors maintains accurate beliefs (while changing beliefs as little as possible). Minimizing outcome prediction errors maximizes reward and information gain.
>
>Message_passing_example:
>
>A separate script that lets the reader carry out (marginal) message passing. In the first example, the code follows, one step at a time, the message passing described in the main text (Section 2). In the second example, this is extended to compute the firing rates and ERPs associated with message passing in the neural process theory linked to active inference.
>
>EFE_learning_novelty_term:
>
>A separate script that lets the reader compute the novelty term added to the expected free energy when learning the Dirichlet concentration parameters (a) of the likelihood matrix (A). Smaller concentration parameters yield a larger novelty term, which is subtracted from the total EFE of a policy. Lower confidence in beliefs about the state-outcome mappings in the A matrix therefore leads the agent to select policies that increase confidence in those beliefs ("parameter exploration").
>
>Pencil_and_paper_exercise_solutions:
>
>Solutions to the "pencil and paper" exercises provided in the tutorial paper. These are provided to help readers build intuition for the equations used in active inference.
>
>spm_MDP_VB_X_tutorial:
>
>A tutorial version of the standard routine for running active inference (POMDP) models. Note: this code was updated on 28 August 2024 to improve how the forgetting rate is implemented. Unlike the originally released tutorial, the updated version specifies that larger omega values produce more forgetting. The initial values of the concentration parameters now also act as lower bounds, preventing these parameters from drifting to implausibly low values over time.
>
>Simplified_simulation_script:
>
>A simplified and heavily commented version of the spm_MDP_VB_X_tutorial script, provided to make it easier for readers to understand how the standard simulation routine works. Note: this code was updated on 28 August 2024 to improve how the forgetting rate is implemented. Unlike the originally released tutorial, the updated version specifies that larger omega values produce more forgetting. The initial values of the concentration parameters now also act as lower bounds, preventing these parameters from drifting to implausibly low values over time.
>
>Estimate_parameters:
>
>A script called by the main tutorial script to estimate parameters from (simulated) behavioural data.
>
>Note: the additional scripts are auxiliary functions called by the main scripts to plot simulation outputs.

## Installation

Clone this repository:

`git clone git@github.com:ouyangzhiping/feppy.git`

## Usage

To make learning and visual demonstration easier, this project is written as Jupyter Notebooks. Python 3.10 or later is recommended. A minimal example of running the code outside a notebook is given below, after the Disclaimer.

## Disclaimer

The Python code in this project is an unofficial translation. Although it has been checked by hand, it may still differ from the original Matlab code. Users should assume responsibility for any risks that may arise during use.
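As a quick start (a sketch; it assumes NumPy, SciPy, and Jupyter are installed, which the code imports but the README does not list explicitly), the simplified simulation can be run directly and its model-building function reused:

```python
# Note: importing the script also executes the full simulation, because the
# simulation code in Simplified_simulation_script.py sits at module level.
from Simplified_simulation_script import explore_exploit_model

mdp = explore_exploit_model(Gen_model=1)   # 1 = fixed habits (E); 2 = learnable a and e
print(mdp['T'], mdp['NumPolicies'])        # 3 time points per trial, 5 policies
```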
95 | 96 | ## 贡献 97 | 98 | 欢迎提交问题(issues)和请求(pull requests)以改进本项目。 99 | 100 | Contributions are welcome. Please submit issues and pull requests to improve this project. -------------------------------------------------------------------------------- /Simplified_simulation_script.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.stats as stats 3 | from scipy.special import logsumexp, gammaln, psi 4 | from utility.math_utils import nat_log 5 | from spm.spm_auxillary import spm_wnorm 6 | 7 | 8 | def explore_exploit_model(Gen_model): 9 | # Number of time points or 'epochs' within a trial: T 10 | T = 3 11 | 12 | # Priors about initial states: D and d 13 | D = {} 14 | D[1] = np.array([[1], [0]]) # {'left better','right better'} 15 | D[2] = np.array([[1], [0], [0], [0]]) # {'start','hint','choose-left','choose-right'} 16 | 17 | d = {} 18 | d[1] = np.array([[0.25], [0.25]]) # {'left better','right better'} 19 | d[2] = np.array([[1], [0], [0], [0]]) # {'start','hint','choose-left','choose-right'} 20 | 21 | # State-outcome mappings and beliefs: A and a 22 | Ns = [len(D[1]), len(D[2])] # number of states in each state factor (2 and 4) 23 | 24 | A = {} 25 | A[1] = np.zeros((3, 2, 4)) 26 | for i in range(Ns[1]): 27 | A[1][:, :, i] = np.array([[1, 1], [0, 0], [0, 0]]) 28 | 29 | pHA = 1 30 | A[1][:, :, 1] = np.array([[0, 0], [pHA, 1 - pHA], [1 - pHA, pHA]]) 31 | 32 | A[2] = np.zeros((3, 2, 4)) 33 | for i in range(2): 34 | A[2][:, :, i] = np.array([[1, 1], [0, 0], [0, 0]]) 35 | 36 | pWin = 0.8 37 | A[2][:, :, 2] = np.array([[0, 0], [1 - pWin, pWin], [pWin, 1 - pWin]]) 38 | A[2][:, :, 3] = np.array([[0, 0], [pWin, 1 - pWin], [1 - pWin, pWin]]) 39 | 40 | A[3] = np.zeros((4, 2, 4)) 41 | for i in range(Ns[1]): 42 | A[3][i, :, i] = np.array([1, 1]) 43 | 44 | a = {} 45 | a[1] = A[1] * 200 46 | a[2] = A[2] * 200 47 | a[3] = A[3] * 200 48 | a[1][:, :, 1] = np.array([[0, 0], [0.25, 0.25], [0.25, 0.25]]) 49 | 50 | # Controlled transitions and transition beliefs : B{:,:,u} and b(:,:,u) 51 | B = {} 52 | B[1] = np.zeros((2, 2, 1)) 53 | B[1][:, :, 0] = np.array([[1, 0], [0, 1]]) 54 | 55 | B[2] = np.zeros((4, 4, 4)) 56 | B[2][:, :, 0] = np.array([[1, 1, 1, 1], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]) 57 | B[2][:, :, 1] = np.array([[0, 0, 0, 0], [1, 1, 1, 1], [0, 0, 0, 0], [0, 0, 0, 0]]) 58 | B[2][:, :, 2] = np.array([[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1], [0, 0, 0, 0]]) 59 | B[2][:, :, 3] = np.array([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1]]) 60 | 61 | # Preferred outcomes: C and c 62 | No = [A[1].shape[0], A[2].shape[0], A[3].shape[0]] 63 | 64 | C = {} 65 | C[1] = np.zeros((No[0], T)) 66 | C[2] = np.zeros((No[1], T)) 67 | C[3] = np.zeros((No[2], T)) 68 | 69 | la = 1 70 | rs = 4 71 | C[2][:, :] = np.array([[0, 0, 0], [0, -la, -la], [0, rs, rs / 2]]) 72 | 73 | # Allowable policies: U or V. 74 | NumPolicies = 5 75 | NumFactors = 2 76 | 77 | V = np.ones((T - 1, NumPolicies, NumFactors)) 78 | V[:, :, 0] = np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]) 79 | V[:, :, 1] = np.array([[1, 2, 2, 3, 4], [1, 3, 4, 1, 1]]) 80 | 81 | # Habits: E and e. 82 | E = np.array([[1, 1, 1, 1, 1]]).T 83 | e = np.array([[1, 1, 1, 1, 1]]).T 84 | 85 | # Additional optional parameters. 
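# (In this script: eta is the learning rate, omega the forgetting rate, beta the prior on the
# expected-free-energy precision (gamma = 1/beta), and alpha the action precision used to
# sharpen the distribution over actions -- matching the labels attached to these values where
# they are read back out of the mdp dictionary further below.)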
86 | eta = 1 87 | omega = 1 88 | beta = 1 89 | alpha = 32 90 | 91 | # Define POMDP Structure 92 | mdp = { 93 | 'T': T, 94 | 'V': V, 95 | 'A': A, 96 | 'B': B, 97 | 'C': C, 98 | 'D': D, 99 | 'd': d, 100 | 'eta': eta, 101 | 'omega': omega, 102 | 'alpha': alpha, 103 | 'beta': beta, 104 | 'NumPolicies': NumPolicies, 105 | 'NumFactors': NumFactors 106 | } 107 | 108 | if Gen_model == 1: 109 | mdp['E'] = E 110 | elif Gen_model == 2: 111 | mdp['a'] = a 112 | mdp['e'] = e 113 | 114 | # Labels for states, outcomes, and actions 115 | label = { 116 | 'factor': {1: 'contexts', 2: 'choice states'}, 117 | 'name': {1: ['left-better', 'right-better'], 2: ['start', 'hint', 'choose left', 'choose right']}, 118 | 'modality': {1: 'hint', 2: 'win/lose', 3: 'observed action'}, 119 | 'outcome': {1: ['null', 'left hint', 'right hint'], 2: ['null', 'lose', 'win'], 3: ['start', 'hint', 'choose left', 'choose right']}, 120 | 'action': {2: ['start', 'hint', 'left', 'right']} 121 | } 122 | mdp['label'] = label 123 | 124 | return mdp 125 | 126 | def col_norm(input_dict): 127 | normalized_dict = {} # Initialize a dictionary to store the normalized arrays 128 | 129 | for key, array in input_dict.items(): 130 | normalized_array = array.copy() # Make a copy of the original array 131 | z = np.sum(normalized_array, axis=0) # Create normalizing constant from the sum of columns 132 | normalized_array = normalized_array / z # Divide columns by the constant 133 | normalized_dict[key] = normalized_array # Store the normalized array in the dictionary 134 | 135 | return normalized_dict 136 | 137 | def flatten_3d_to_2d(x): 138 | if x.ndim != 3: 139 | raise ValueError("Input array must be 3-dimensional") 140 | return x.transpose(2, 1, 0).reshape(-1, x.shape[0]) 141 | 142 | def md_dot(A, s, f): 143 | if f == 0: 144 | B = np.dot(A.T, s) 145 | elif f == 1: 146 | B = np.dot(A, s) 147 | else: 148 | raise ValueError("f must be either 0 or 1.") 149 | 150 | return B 151 | 152 | def cell_md_dot(X, x): 153 | # Initialize dimensions 154 | DIM = np.arange(len(x)) + X.ndim - len(x) 155 | XNDIM = X.ndim 156 | # Compute dot product 157 | for d in range(len(x)): 158 | s = np.ones(XNDIM, dtype=int) 159 | s[DIM[d]] = len(x[d]) 160 | X = X * np.reshape(np.array(x[d]), s) 161 | X = np.sum(X, axis=DIM[d]) 162 | 163 | X = np.squeeze(X) 164 | return X 165 | 166 | def G_epistemic_value(A, s): 167 | """ 168 | Auxiliary function for Bayesian surprise or mutual information. 
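    Computes G = sum_s Q(s) * sum_o P(o|s) ln P(o|s) - sum_o Q(o) ln Q(o), i.e. the negative
    expected entropy of outcomes given states minus the negative entropy of the predictive
    distribution over outcomes (the mutual information between hidden states and outcomes
    under Q), which the loop below accumulates term by term.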
169 | 170 | Parameters: 171 | A - likelihood array (probability of outcomes given causes) 172 | s - probability density of causes 173 | 174 | Returns: 175 | G - epistemic value 176 | """ 177 | 178 | # Probability distribution over the hidden causes: i.e., Q(s) 179 | qx = spm_cross(s) # This is the outer product of the posterior over states 180 | # calculated with respect to itself 181 | 182 | # Accumulate expectation of entropy: i.e., E[lnP(o|s)] 183 | G = 0 184 | qo = np.array([0]) # Initialize qo with zeros 185 | 186 | qx = qx.T.flatten() # Transpose qx to match the original MATLAB code 187 | for i in np.where(qx > np.exp(-16))[0]: 188 | # for i in np.ndindex(np.where(qx > np.exp(-16))): 189 | # Probability over outcomes for this combination of causes 190 | po = 1 191 | for g in range(len(A)): 192 | po = spm_cross(po, flatten_3d_to_2d(A[g])[i]) 193 | po = po.flatten() 194 | # qo = qo + qx.flatten()[i] * po 195 | # G = G + qx.flatten()[i] * np.dot(po, nat_log(po)) 196 | qo = qo + qx[i] * po 197 | G = G + qx[i] * np.dot(po, nat_log(po)) 198 | 199 | # Subtract entropy of expectations: i.e., E[lnQ(o)] 200 | G = G - np.dot(qo, nat_log(qo)) 201 | 202 | return G 203 | 204 | def spm_cross(X, x=None, *args): 205 | # Handle single input 206 | if x is None: 207 | if isinstance(X, np.ndarray): 208 | Y = X 209 | else: 210 | Y = spm_cross(*X) 211 | return Y 212 | 213 | # Handle cell arrays (lists in Python) 214 | if isinstance(X, list): 215 | X = spm_cross(*X) 216 | if isinstance(x, list): 217 | x = spm_cross(*x) 218 | 219 | # Outer product of first pair of arguments 220 | if isinstance(X, int): 221 | A = X 222 | B = np.reshape(x, (1,) * 1 + x.shape) 223 | else: 224 | A = np.reshape(X, X.shape + (1,) * x.ndim) 225 | B = np.reshape(x, (1,) * X.ndim + x.shape) 226 | Y = np.squeeze(A * B) 227 | 228 | # Handle remaining arguments 229 | for arg in args: 230 | Y = spm_cross(Y, arg) 231 | 232 | return Y 233 | 234 | def spm_KL_dir(q, p): 235 | """ 236 | KL divergence between two Dirichlet distributions 237 | Calculate KL(Q||P) = where avg is wrt Q between two Dirichlet distributions Q and P 238 | 239 | Parameters: 240 | q : array-like 241 | Concentration parameter matrix of Q 242 | p : array-like 243 | Concentration parameter matrix of P 244 | 245 | Returns: 246 | d : float 247 | The KL divergence between Q and P 248 | """ 249 | # KL divergence based on log beta functions 250 | d = spm_betaln(p) - spm_betaln(q) - np.sum((p - q) * spm_psi(q + 1/32), axis=0) 251 | d = np.sum(d) 252 | 253 | return d 254 | 255 | # def spm_betaln(z): 256 | # """ 257 | # Returns the log of the multivariate beta function of a vector. 258 | 259 | # Parameters: 260 | # z (array-like): Input vector or array. 261 | 262 | # Returns: 263 | # y (float or ndarray): The natural logarithm of the beta function for corresponding elements of the vector z. 264 | # """ 265 | # if np.ndim(z) == 1: 266 | # z = z[np.nonzero(z)] 267 | # y = np.sum(gammaln(z)) - gammaln(np.sum(z)) 268 | # else: 269 | # y = np.zeros((1,) + z.shape[1:]) 270 | # for i in range(z.shape[1]): 271 | # for j in range(z.shape[2]): 272 | # for k in range(z.shape[3]): 273 | # for l in range(z.shape[4]): 274 | # for m in range(z.shape[5]): 275 | # y[0, i, j, k, l, m] = spm_betaln(z[:, i, j, k, l, m]) 276 | # return y 277 | 278 | def spm_betaln(z): 279 | """ 280 | Returns the log of the multivariate beta function of a vector. 281 | 282 | Parameters: 283 | z (array-like): Input vector or array. 
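    Note: as written, any input with more than one dimension is collapsed into a single
    Beta function over its non-zero entries, while genuinely 1-D input falls into the
    per-column recursion (which never terminates). This appears to swap the two branches
    of the MATLAB original and of the commented-out version above, but it still returns
    the correct scalar for the (N, 1) column vectors that spm_KL_dir passes in this script.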
284 | 285 | Returns: 286 | y (float or ndarray): The natural logarithm of the beta function for corresponding elements of the vector z. 287 | """ 288 | if np.ndim(z) > 1: 289 | z = z[np.nonzero(z)] 290 | y = np.sum(gammaln(z)) - gammaln(np.sum(z)) 291 | else: 292 | y = np.zeros(z.shape[1:]) 293 | it = np.nditer(y, flags=['multi_index'], op_flags=['writeonly']) 294 | while not it.finished: 295 | idx = it.multi_index 296 | y[idx] = spm_betaln(z[(slice(None),) + idx]) 297 | it.iternext() 298 | return y 299 | 300 | def spm_psi(A): 301 | """ 302 | Normalization of a probability transition rate matrix (columns) 303 | :param A: numeric array 304 | :return: normalized array 305 | """ 306 | return psi(A) - psi(np.sum(A, axis=0)) 307 | 308 | def B_norm(B): 309 | bb = B.copy() # Create a copy of B to avoid modifying the original 310 | z = np.sum(bb, axis=0) # Create normalizing constant from sum of columns 311 | bb = bb / z # Divide columns by constant 312 | bb[np.isnan(bb)] = 0 # Replace NaN with zero 313 | return bb 314 | 315 | # Random seed initialization 316 | np.random.seed() 317 | 318 | # Simulation Settings 319 | Gen_model = 1 # As in the main tutorial code 320 | 321 | # Specify Generative Model 322 | MDP = explore_exploit_model(Gen_model) # Placeholder for the model function 323 | 324 | # Normalize generative process and generative model 325 | A = MDP['A'] # Likelihood matrices 326 | B = MDP['B'] # Transition matrices 327 | C = MDP['C'] # Preferences over outcomes 328 | D = MDP['D'] # Priors over initial states 329 | T = MDP['T'] # Time points per trial 330 | V = MDP['V'] # Policies 331 | beta = MDP['beta'] # Expected free energy precision 332 | alpha = MDP['alpha'] # Action precision 333 | eta = MDP['eta'] # Learning rate 334 | omega = MDP['omega'] # Forgetting rate 335 | 336 | A = col_norm(A) 337 | B = col_norm(B) 338 | D = col_norm(D) 339 | 340 | # Generative model (lowercase matrices/vectors are beliefs about capitalized matrices/vectors) 341 | NumPolicies = MDP['NumPolicies'] # Number of policies 342 | NumFactors = MDP['NumFactors'] # Number of state factors 343 | 344 | # Store initial parameter values of generative model for free energy calculations after learning 345 | if 'd' in MDP: 346 | d_prior = {} 347 | d_complexity = {} 348 | for factor in range(len(MDP['d'])): 349 | d_prior[factor + 1] = MDP['d'][factor + 1] 350 | d_complexity[factor + 1] = spm_wnorm(d_prior[factor+1]) 351 | 352 | if 'a' in MDP: 353 | a_prior = {} 354 | a_complexity = {} 355 | for modality in range(len(MDP['a'])): 356 | a_prior[modality] = MDP['a'][modality] 357 | a_complexity[modality] = spm_wnorm(a_prior[modality]) * (a_prior[modality] > 0) 358 | 359 | # Normalize matrices before model inversion/inference 360 | if 'a' in MDP: 361 | a = col_norm(MDP['a']) 362 | else: 363 | a = col_norm(MDP['A']) 364 | 365 | if 'b' in MDP: 366 | b = col_norm(MDP['b']) 367 | else: 368 | b = col_norm(MDP['B']) 369 | 370 | for ii in range(len(C)): 371 | C[ii] = MDP['C'][ii + 1] + 1 / 32 372 | for t in range(T): 373 | C[ii][:, t] = nat_log(np.exp(C[ii][:, t]) / np.sum(np.exp(C[ii][:, t]))) 374 | 375 | if 'd' in MDP: 376 | d = col_norm(MDP['d']) 377 | else: 378 | d = col_norm(MDP['D']) 379 | 380 | if 'e' in MDP: 381 | E = MDP['e'] 382 | E = E / np.sum(E) 383 | elif 'E' in MDP: 384 | E = MDP['E'] 385 | E = E / np.sum(E) 386 | else: 387 | E = col_norm(np.ones((NumPolicies, 1))) 388 | E = E / np.sum(E) 389 | 390 | # Initialize variables 391 | NumModalities = len(a) # Number of outcome factors 392 | NumFactors = len(d) # Number of hidden 
state factors 393 | NumPolicies = V.shape[1] # Number of allowable policies 394 | NumStates = np.zeros(NumFactors, dtype=int) 395 | NumControllable_transitions = np.zeros(NumFactors, dtype=int) 396 | 397 | for factor in range(NumFactors): 398 | NumStates[factor] = b[factor + 1].shape[0] 399 | NumControllable_transitions[factor] = b[factor + 1].shape[2] 400 | 401 | # Initialize the approximate posterior over states conditioned on policies 402 | state_posterior = {} 403 | for policy in range(NumPolicies): 404 | for factor in range(NumFactors): 405 | NumStates[factor] = len(D[factor + 1]) 406 | state_posterior[factor] = np.ones((NumStates[factor], T, policy + 1)) / NumStates[factor] 407 | 408 | # Initialize the approximate posterior over policies 409 | policy_posteriors = np.ones((NumPolicies, T)) / NumPolicies 410 | 411 | # Initialize posterior over actions 412 | chosen_action = np.zeros((len(B), T - 1), dtype=int) 413 | 414 | # If there is only one policy 415 | for factors in range(NumFactors): 416 | if NumControllable_transitions[factors] == 1: 417 | chosen_action[factors, :] = np.ones(T - 1) 418 | 419 | MDP['chosen_action'] = chosen_action 420 | 421 | # Initialize expected free energy precision (beta) 422 | posterior_beta = 1 423 | gamma = [1 / posterior_beta] * np.ones(T) # Expected free energy precision 424 | 425 | # Message passing variables 426 | TimeConst = 4 # Time constant for gradient descent 427 | NumIterations = 16 # Number of message passing iterations 428 | 429 | 430 | # Lets go! Message passing and policy selection 431 | #-------------------------------------------------------------------------- 432 | # Initialize necessary variables 433 | true_states = np.zeros((NumFactors, T)) 434 | outcomes = np.zeros((NumModalities, T)) 435 | O = {} 436 | Ft = np.zeros((T, NumIterations, T, NumFactors)) 437 | F = np.zeros((NumPolicies, T)) 438 | G = np.zeros((NumPolicies, T)) 439 | policy_priors = np.zeros((NumPolicies, T)) 440 | policy_posteriors = np.zeros((NumPolicies, T)) 441 | gamma_update = np.zeros((NumIterations * T, 1)) 442 | policy_posterior_updates = np.zeros((NumPolicies, NumIterations * T)) 443 | policy_posterior = np.zeros((NumPolicies, T)) 444 | BMA_states = {} 445 | action_posterior = np.zeros((1, NumControllable_transitions[-1], T - 1)) 446 | 447 | normalized_firing_rates = {} 448 | prediction_error = {} 449 | Expected_states = {} 450 | for factor in range(NumFactors): 451 | # normalized_firing_rates = np.array([np.zeros((LEN_ITER, 2, 3, 3, 5)), np.zeros((LEN_ITER, 4, 3, 3, 5))], dtype=object) 452 | normalized_firing_rates[factor] = np.zeros((NumIterations, NumStates[factor], T, T, NumPolicies)) 453 | # prediction_error = np.array([np.zeros((16, 2, 3, 3, 5)), np.zeros((16, 4, 3, 3, 5))], dtype=object) 454 | prediction_error[factor] = np.zeros((NumIterations, NumStates[factor], T, T, NumPolicies)) 455 | # Expected_states = np.array([np.zeros((2, 1)), np.zeros((4, 1))], dtype=object) 456 | Expected_states[factor] = np.zeros((NumStates[factor])) 457 | 458 | # Main loop 459 | for t in range(T): 460 | for factor in range(NumFactors): 461 | if t == 0: 462 | # Sample initial states 463 | prob_state = D[factor + 1] 464 | else: 465 | prob_state = B[factor + 1][:, true_states[factor, t-1], MDP['chosen_action'][factor, t-1] - 1] 466 | true_states[factor, t] = np.argmax(np.cumsum(prob_state) >= np.random.rand()) 467 | 468 | # change the dtype for index calculation 469 | true_states = np.array(true_states, dtype=int) 470 | 471 | for modality in range(NumModalities): 472 | 
outcomes[modality, t] = np.argmax(np.cumsum(a[modality + 1][:, true_states[0, t], true_states[1, t]]) >= np.random.rand()) 473 | for modality in range(NumModalities): 474 | vec = np.zeros((1, a[modality + 1].shape[0])) 475 | index = int(outcomes[modality, t]) 476 | vec[0, index] = 1 477 | O[(modality, t)] = vec 478 | 479 | for policy in range(NumPolicies): 480 | for Ni in range(NumIterations): 481 | for factor in range(NumFactors): 482 | lnAo = np.zeros_like(state_posterior[factor]) 483 | for tau in range(T): 484 | v_depolarization = nat_log(state_posterior[factor][:, tau, policy]) 485 | if tau < t + 1: 486 | for modal in range(NumModalities): 487 | # TODO: different from original matlab code, because we can have redundant dimensions in Matlab, but not applicable in Python... 488 | # ...No impact on the result. 489 | lnA = nat_log(a[modal + 1][int(outcomes[modal, tau]), :, :]) 490 | for fj in range(NumFactors): 491 | if fj != factor: 492 | # TODO: there may be an issue in the original m code that the dimension "policy" of lnAs is missing. Use fixed 0 instead. 493 | lnAs = md_dot(lnA, state_posterior[fj][:, tau, 0], fj) 494 | lnA = lnAs 495 | # TODO: there may be an issue in the original m code that the dimension "policy" of lnAo is missing. Use fixed 0 instead. 496 | lnAo[:, tau, 0] += lnA 497 | if tau == 0: 498 | lnD = nat_log(d[factor + 1]) 499 | lnBs = nat_log(B_norm(b[factor + 1][:, :, int(V[tau, policy, factor] - 1)].T) @ state_posterior[factor][:, tau + 1, policy]) 500 | elif tau == T - 1: 501 | lnD = nat_log(b[factor + 1][:, :, int(V[tau - 1, policy, factor] - 1)] @ state_posterior[factor][:, tau - 1, policy]) 502 | lnBs = np.zeros_like(d[factor + 1]) 503 | else: 504 | lnD = nat_log(b[factor + 1][:, :, int(V[tau - 1, policy, factor] - 1)] @ state_posterior[factor][:, tau - 1, policy]) 505 | lnBs = nat_log(B_norm(b[factor + 1][:, :, int(V[tau, policy, factor] -1)].T) @ state_posterior[factor][:, tau + 1, policy]) 506 | # TODO: there may be an issue in the original m code that the dimension "policy" of lnAo is missing. Use fixed 0 instead. 507 | # v_depolarization += (0.5 * lnD.reshape(v_depolarization.shape) + 0.5 * lnBs.reshape(v_depolarization.shape) + lnAo[:, tau, 0] - v_depolarization) / TimeConst 508 | v_depolarization += (0.5 * lnD.reshape(v_depolarization.shape) + 0.5 * lnBs.reshape(v_depolarization.shape) + flatten_3d_to_2d(lnAo)[tau] - v_depolarization) / TimeConst 509 | # TODO: there may be an issue in the original m code that the dimension "policy" of lnAo is missing. Use fixed 0 instead. 510 | # Ft[tau, Ni, t, factor] = state_posterior[factor][:, tau, policy].T @ (0.5 * lnD.reshape(v_depolarization.shape) + 0.5 * lnBs.reshape(v_depolarization.shape) + lnAo[:, tau, 0] - nat_log(state_posterior[factor][:, tau, policy])) 511 | Ft[tau, Ni, t, factor] = state_posterior[factor][:, tau, policy].T @ (0.5 * lnD.reshape(v_depolarization.shape) + 0.5 * lnBs.reshape(v_depolarization.shape) + flatten_3d_to_2d(lnAo)[tau] - nat_log(state_posterior[factor][:, tau, policy])) 512 | state_posterior[factor][:, tau, policy] = np.exp(v_depolarization) / np.sum(np.exp(v_depolarization)) 513 | normalized_firing_rates[factor][Ni, :, tau, t, policy] = state_posterior[factor][:, tau, policy] 514 | prediction_error[factor][Ni, :, tau, t, policy] = v_depolarization 515 | Fintermediate = np.sum(Ft, axis=3) 516 | # TODO: this is a patch to adjust the size of Fintermediate. Could be optimized. 
517 | Fintermediate = Fintermediate[:,:,t] 518 | Fintermediate = np.squeeze(np.sum(Fintermediate, axis=0)) 519 | F[policy, t] = Fintermediate[-1] 520 | 521 | Gintermediate = np.zeros((NumPolicies, 1)) 522 | horizon = T 523 | 524 | for policy in range(NumPolicies): 525 | if 'd' in MDP: 526 | for factor in range(NumFactors): 527 | Gintermediate[policy] -= d_complexity[factor + 1].T @ state_posterior[factor][:, 0, policy] 528 | for timestep in range(t, horizon): 529 | for factor in range(NumFactors): 530 | Expected_states[factor] = state_posterior[factor][:, timestep, policy] 531 | Gintermediate[policy] += G_epistemic_value(list(a.values()), list(Expected_states.values())) 532 | for modality in range(NumModalities): 533 | predictive_observations_posterior = cell_md_dot(a[modality + 1], Expected_states) 534 | Gintermediate[policy] += predictive_observations_posterior.T @ C[modality][:, t] 535 | if 'a' in MDP: 536 | Gintermediate[policy] -= cell_md_dot(a_complexity[modality], [predictive_observations_posterior, *Expected_states]) 537 | G[:, t] = Gintermediate.flatten() 538 | 539 | if t > 0: 540 | gamma[t] = gamma[t - 1] 541 | # For facilitation of calculating log(E) with different shape arrays in the iteration 542 | E = E.flatten() 543 | for ni in range(NumIterations): 544 | policy_priors[:, t] = np.exp(np.log(E) + gamma[t] * G[:, t]) / np.sum(np.exp(np.log(E) + gamma[t] * G[:, t])) 545 | policy_posteriors[:, t] = np.exp(np.log(E) + gamma[t] * G[:, t] + F[:, t]) / np.sum(np.exp(np.log(E) + gamma[t] * G[:, t] + F[:, t])) 546 | G_error = (policy_posteriors[:, t] - policy_priors[:, t]).T @ G[:, t] 547 | beta_update = posterior_beta - beta + G_error 548 | posterior_beta -= beta_update / 2 549 | gamma[t] = 1 / posterior_beta 550 | n = t * NumIterations + ni 551 | gamma_update[n, 0] = gamma[t].reshape(1, -1) 552 | policy_posterior_updates[:, n] = policy_posteriors[:, t] 553 | policy_posterior[:, t] = policy_posteriors[:, t] 554 | 555 | for factor in range(NumFactors): 556 | for tau in range(T): 557 | new_col = np.reshape(state_posterior[factor][:, tau, :], (NumStates[factor], NumPolicies)) @ policy_posteriors[:, t] 558 | new_col = new_col.reshape(1,-1).T 559 | if tau == 0: 560 | BMA_states[factor] = new_col 561 | else: 562 | BMA_states[factor] = np.hstack((BMA_states[factor], new_col)) 563 | 564 | if t < T - 1: 565 | action_posterior_intermediate = np.zeros((NumControllable_transitions[-1], 1)).T 566 | for policy in range(NumPolicies): 567 | sub = tuple(V[t, policy, :].astype(int) - 1) 568 | action_posterior_intermediate[sub] += policy_posteriors[policy, t] 569 | # sub = (slice(None),) * NumFactors 570 | action_posterior_intermediate[:] = np.exp(alpha * np.log(action_posterior_intermediate[:])) / np.sum(np.exp(alpha * np.log(action_posterior_intermediate[:]))) 571 | action_posterior[..., t] = action_posterior_intermediate 572 | ControlIndex = np.where(NumControllable_transitions > 1)[0] 573 | action = np.arange(1, NumControllable_transitions[ControlIndex] + 1) 574 | for factors in range(NumFactors): 575 | if NumControllable_transitions[factors] > 2: 576 | ind = np.argmax(np.cumsum(action_posterior_intermediate.flatten()) > np.random.rand()) 577 | MDP['chosen_action'][factor, t] = action[ind] 578 | 579 | 580 | # accumulate concentration paramaters (learning) --> MATLAB code L436 581 | 582 | for t in range(T): 583 | # a matrix (likelihood) 584 | # but this part is never executed 585 | if 'a' in MDP: 586 | for modality in range(NumModalities): 587 | a_learning = O[modality, t].T 588 | for factor in 
range(NumFactors): 589 | a_learning = spm_cross(a_learning, BMA_states[factor][:, t]) 590 | a_learning = a_learning * (MDP['a'][modality] > 0) 591 | MDP['a'][modality] = MDP['a'][modality] * omega + a_learning * eta 592 | 593 | # Initial hidden states d (priors) 594 | if 'd' in MDP: 595 | for factor in range(NumFactors): 596 | MDP['d'][factor + 1] = MDP['d'][factor + 1].astype(float) 597 | i = np.array(MDP['d'][factor + 1] > 0).flatten() 598 | if len(BMA_states[factor][i, 0]) == 1: 599 | MDP['d'][factor + 1][i] = omega * MDP['d'][factor + 1][i] + eta * BMA_states[factor][i, 0] 600 | else: 601 | MDP['d'][factor + 1][i] = omega * MDP['d'][factor + 1][i] + eta * BMA_states[factor][i, 0].reshape(MDP['d'][factor + 1].shape) 602 | 603 | # Policies e (habits) 604 | # but this part is never executed 605 | if 'e' in MDP: 606 | MDP['e'] = omega * MDP['e'] + eta * policy_posteriors[:, T-1] 607 | 608 | # Free energy of concentration parameters 609 | # ---------------------------------------------------------------------- 610 | 611 | # (negative) free energy of a 612 | # but this part is never executed 613 | MDP['Fa'] = np.zeros(NumModalities) 614 | for modality in range(1, NumModalities + 1): 615 | if 'a' in MDP: 616 | # Implement spm_KL_dir function for KL divergence calculation 617 | MDP['Fa'][modality-1] = - spm_KL_dir(MDP['a'][modality], a_prior[modality]) 618 | 619 | # (negative) free energy of d 620 | MDP['Fd'] = np.zeros(NumFactors) 621 | for factor in range(1, NumFactors + 1): 622 | if 'd' in MDP: 623 | MDP['Fd'][factor-1] = - spm_KL_dir(MDP['d'][factor], d_prior[factor]) 624 | 625 | # (negative) free energy of e 626 | # but this part is never executed 627 | if 'e' in MDP: 628 | MDP['Fe'] = - spm_KL_dir(MDP['e'], E) 629 | 630 | # Simulated dopamine responses (beta updates) 631 | # ---------------------------------------------------------------------- 632 | # "deconvolution" of neural encoding of precision 633 | if NumPolicies > 1: 634 | # gamma_update = gamma # Assuming gamma_update is defined in prior code 635 | phasic_dopamine = 8 * np.gradient(gamma_update.flatten()) + gamma_update.flatten() / 8 636 | else: 637 | phasic_dopamine = [] 638 | gamma_update = [] 639 | 640 | # Bayesian model average of neuronal variables; normalized firing rate and prediction error 641 | # ---------------------------------------------------------------------- 642 | # Assuming Ni (NumIterations) is defined as 16 from prior code 643 | Ni = NumIterations 644 | BMA_normalized_firing_rates = {} 645 | BMA_prediction_error = {} 646 | 647 | for factor in range(NumFactors): 648 | num_states = NumStates[factor] # NumStates is 0-indexed in Python 649 | BMA_normalized_firing_rates[factor + 1] = np.zeros((Ni, num_states, T, T)) 650 | BMA_prediction_error[factor + 1] = np.zeros((Ni, num_states, T, T)) 651 | 652 | for t in range(T): 653 | for policy in range(NumPolicies): 654 | # Accumulate normalized firing rates 655 | BMA_normalized_firing_rates[factor + 1][:, :, :T, t] += ( 656 | normalized_firing_rates[factor][:, :, :T, t, policy] * 657 | policy_posteriors[policy, t] 658 | ) 659 | 660 | # Accumulate prediction errors 661 | BMA_prediction_error[factor + 1][:, :, :T, t] += ( 662 | prediction_error[factor][:, :, :T, t, policy] * 663 | policy_posteriors[policy, t] 664 | ) 665 | 666 | print("Calculation completed. 
To be plotted.") -------------------------------------------------------------------------------- /original_matlab_code/EFE_Precision_Updating.m: -------------------------------------------------------------------------------- 1 | %% Example code for simulated expected free energy precision (beta/gamma) updates 2 | % (associated with dopamine in the neural process theory) 3 | 4 | % Supplementary Code for: A Step-by-Step Tutorial on Active Inference Modelling and its 5 | % Application to Empirical Data 6 | 7 | % By: Ryan Smith, Karl J. Friston, Christopher J. Whyte 8 | 9 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 10 | 11 | % This script will reproduce the simulation results in Figure 9 12 | 13 | clear all 14 | close all 15 | 16 | % Here you can set the number of policies and the distributions that 17 | % contribute to prior and posterior policy precision 18 | 19 | E = [1 1 1 1 1]'; % Set a fixed-form prior distribution 20 | % over policies (habits) 21 | 22 | G = [12.505 9.51 12.5034 12.505 12.505]'; % Set an example expected 23 | % free energy distribution over policies 24 | 25 | F = [17.0207 1.7321 1.7321 17.0387 17.0387]'; % Set an example variational 26 | % free energy distribution over 27 | % policies after a new observation 28 | 29 | 30 | gamma_0 = 1; % Starting expected free energy precision value 31 | gamma = gamma_0; % Initial expected free energy precision to be updated 32 | beta_prior = 1/gamma; % Initial prior on expected free energy precision 33 | beta_posterior = beta_prior; % Initial posterior on expected free energy precision 34 | psi = 2; % Step size parameter (promotes stable convergence) 35 | 36 | for ni = 1:16 % number of variational updates (16) 37 | 38 | % calculate prior and posterior over policies (see main text for 39 | % explanation of equations) 40 | 41 | pi_0 = exp(log(E) - gamma*G)/sum(exp(log(E) - gamma*G)); % prior over policies 42 | 43 | pi_posterior = exp(log(E) - gamma*G - F)/sum(exp(log(E) - gamma*G - F)); % posterior 44 | % over policies 45 | % calculate expected free energy precision 46 | 47 | G_error = (pi_posterior - pi_0)'*-G; % expected free energy prediction error 48 | 49 | beta_update = beta_posterior - beta_prior + G_error; % change in beta: 50 | % gradient of F with respect to gamma 51 | % (recall gamma = 1/beta) 52 | 53 | beta_posterior = beta_posterior - beta_update/psi; % update posterior precision 54 | % estimate (with step size of psi = 2, which reduces 55 | % the magnitude of each update and can promote 56 | % stable convergence) 57 | 58 | gamma = 1/beta_posterior; % update expected free energy precision 59 | 60 | % simulate dopamine responses 61 | 62 | n = ni; 63 | 64 | gamma_dopamine(n,1) = gamma; % simulated neural encoding of precision 65 | % (beta_posterior^-1) at each iteration of 66 | % variational updating 67 | 68 | policies_neural(:,n) = pi_posterior; % neural encoding of posterior over policies at 69 | % each iteration of variational updating 70 | end 71 | 72 | %% Show Results 73 | 74 | disp(' '); 75 | disp('Final Policy Prior:'); 76 | disp(pi_0); 77 | disp(' '); 78 | disp('Final Policy Posterior:'); 79 | disp(pi_posterior); 80 | disp(' '); 81 | disp('Final Policy Difference Vector:'); 82 | disp(pi_posterior-pi_0); 83 | disp(' '); 84 | disp('Negative Expected Free Energy:'); 85 | disp(-G); 86 | disp(' '); 87 | disp('Prior G Precision (Prior Gamma):'); 88 | disp(gamma_0); 89 | disp(' '); 90 | disp('Posterior G Precision (Gamma):'); 91 | disp(gamma); 92 | disp(' '); 93 | 94 | gamma_dopamine_plot = 
[gamma_0;gamma_0;gamma_0;gamma_dopamine]; % Include prior value 95 | 96 | figure 97 | plot(gamma_dopamine_plot); 98 | ylim([min(gamma_dopamine_plot)-.05 max(gamma_dopamine_plot)+.05]) 99 | title('Expected Free Energy Precision (Tonic Dopamine)'); 100 | xlabel('Updates'); 101 | ylabel('\gamma'); 102 | 103 | figure 104 | plot([gradient(gamma_dopamine_plot)],'r'); 105 | ylim([min(gradient(gamma_dopamine_plot))-.01 max(gradient(gamma_dopamine_plot))+.01]) 106 | title('Rate of Change in Precision (Phasic Dopamine)'); 107 | xlabel('Updates'); 108 | ylabel('\gamma gradient'); 109 | 110 | % uncomment if you want to display/plot firing rates encoding beliefs about each 111 | % policy (columns = policies, rows = updates over time) 112 | 113 | % plot(policies_neural); 114 | % disp('Firing rates encoding beliefs over policies:'); 115 | % disp(policies_neural'); 116 | % disp(' '); 117 | -------------------------------------------------------------------------------- /original_matlab_code/EFE_learning_novelty_term.m: -------------------------------------------------------------------------------- 1 | %% Calculating novelty term in expected free energy when learning 'A' matrix concentration parameters 2 | % (which drives parameter exploration) 3 | 4 | % Supplementary Code for: A Step-by-Step Tutorial on Active Inference Modelling and its 5 | % Application to Empirical Data 6 | 7 | % By: Ryan Smith, Karl J. Friston, Christopher J. Whyte 8 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 9 | 10 | clear 11 | close all 12 | 13 | %-- 'a' = concentration parameters for likelihood matrix 'A' 14 | 15 | % small concentration parameter values 16 | a1 = [.25 1; 17 | .75 1]; 18 | 19 | % intermediate concentration parameter values 20 | a2 = [2.5 10; 21 | 7.5 10]; 22 | 23 | % large concentration parameter values 24 | a3 = [25 100; 25 | 75 100]; 26 | 27 | % normalize columns in 'a' to get likelihood matrix 'A' (see col_norm 28 | % function at the end of script) 29 | A1 = col_norm(a1); 30 | A2 = col_norm(a2); 31 | A3 = col_norm(a3); 32 | 33 | % calculate 'a_sum' 34 | a1_sum = [a1(1,1)+a1(2,1) a1(1,2)+a1(2,2); 35 | a1(1,1)+a1(2,1) a1(1,2)+a1(2,2)]; 36 | 37 | a2_sum = [a2(1,1)+a2(2,1) a2(1,2)+a2(2,2); 38 | a2(1,1)+a2(2,1) a2(1,2)+a2(2,2)]; 39 | 40 | a3_sum = [a3(1,1)+a3(2,1) a3(1,2)+a3(2,2); 41 | a3(1,1)+a3(2,1) a3(1,2)+a3(2,2)]; 42 | 43 | % element wise inverse for 'a' and 'a_sum' 44 | inv_a1 = [1/a1(1,1) 1/a1(1,2); 45 | 1/a1(2,1) 1/a1(2,2)]; 46 | 47 | inv_a2 = [1/a2(1,1) 1/a2(1,2); 48 | 1/a2(2,1) 1/a2(2,2)]; 49 | 50 | inv_a3 = [1/a3(1,1) 1/a3(1,2); 51 | 1/a3(2,1) 1/a3(2,2)]; 52 | 53 | inv_a1_sum = [1/a1_sum(1,1) 1/a1_sum(1,2); 54 | 1/a1_sum(2,1) 1/a1_sum(2,2)]; 55 | 56 | inv_a2_sum = [1/a2_sum(1,1) 1/a2_sum(1,2); 57 | 1/a2_sum(2,1) 1/a2_sum(2,2)]; 58 | 59 | inv_a3_sum = [1/a3_sum(1,1) 1/a3_sum(1,2); 60 | 1/a3_sum(2,1) 1/a3_sum(2,2)]; 61 | 62 | % 'W' term for 'a' matrix 63 | W1 = .5*(inv_a1-inv_a1_sum); 64 | W2 = .5*(inv_a2-inv_a2_sum); 65 | W3 = .5*(inv_a3-inv_a3_sum); 66 | 67 | % beliefs over states under a policy at a time point 68 | s_pi_tau = [.9 .1]'; 69 | 70 | % predictive posterior over outcomes (A*s_pi_tau = predicted o_pi_tau) 71 | A1s = A1*s_pi_tau; 72 | A2s = A2*s_pi_tau; 73 | A3s = A3*s_pi_tau; 74 | 75 | % W term multiplied by beliefs over states under a policy at a time point 76 | W1s = W1*s_pi_tau; 77 | W2s = W2*s_pi_tau; 78 | W3s = W3*s_pi_tau; 79 | 80 | % compute novelty using dot product function 81 | Novelty_smallCP = dot((A1s),(W1s)); 82 | Novelty_intermediateCP = 
dot((A2s),(W2s)); 83 | Novelty_largeCP = dot((A3s),(W3s)); 84 | 85 | 86 | % show results 87 | disp(' '); 88 | disp('Novelty term for small concentration parameter values:'); 89 | disp(Novelty_smallCP); 90 | disp(' '); 91 | disp('Novelty term for intermediate concentration parameter values:'); 92 | disp(Novelty_intermediateCP); 93 | disp(' '); 94 | disp('Novelty term for large concentration parameter values:'); 95 | disp(Novelty_largeCP); 96 | disp(' '); 97 | 98 | 99 | %% function for normalizing 'a' to get likelihood matrix 'A' 100 | function A_normed = col_norm(A_norm) 101 | aa = A_norm; 102 | norm_constant = sum(aa,1); % create normalizing constant from sum of columns 103 | aa = aa./norm_constant; % divide columns by constant 104 | A_normed = aa; 105 | end 106 | -------------------------------------------------------------------------------- /original_matlab_code/Estimate_parameters.m: -------------------------------------------------------------------------------- 1 | function [DCM] = Estimate_parameters(DCM) 2 | 3 | % MDP inversion using Variational Bayes 4 | % FORMAT [DCM] = spm_dcm_mdp(DCM) 5 | % 6 | % Expects: 7 | %-------------------------------------------------------------------------- 8 | % DCM.MDP % MDP structure specifying a generative model 9 | % DCM.field % parameter (field) names to optimise 10 | % DCM.U % cell array of outcomes (stimuli) 11 | % DCM.Y % cell array of responses (action) 12 | % 13 | % Returns: 14 | %-------------------------------------------------------------------------- 15 | % DCM.M % generative model (DCM) 16 | % DCM.Ep % Conditional means (structure) 17 | % DCM.Cp % Conditional covariances 18 | % DCM.F % (negative) Free-energy bound on log evidence 19 | % 20 | % This routine inverts (cell arrays of) trials specified in terms of the 21 | % stimuli or outcomes and subsequent choices or responses. It first 22 | % computes the prior expectations (and covariances) of the free parameters 23 | % specified by DCM.field. These parameters are log scaling parameters that 24 | % are applied to the fields of DCM.MDP. 25 | % 26 | % If there is no learning implicit in multi-trial games, only unique trials 27 | % (as specified by the stimuli), are used to generate (subjective) 28 | % posteriors over choice or action. Otherwise, all trials are used in the 29 | % order specified. The ensuing posterior probabilities over choices are 30 | % used with the specified choices or actions to evaluate their log 31 | % probability. This is used to optimise the MDP (hyper) parameters in 32 | % DCM.field using variational Laplace (with numerical evaluation of the 33 | % curvature). 
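% A minimal usage sketch (the field values below are placeholders, assuming a
% single-trial mdp structure and a simulated or empirical MDP output as in the
% main tutorial scripts, not part of the original help text):
%
%   DCM.MDP   = mdp;                 % generative model for one trial
%   DCM.field = {'alpha','eta'};     % (hyper)parameters to estimate
%   DCM.U     = {MDP.o};             % cell array of outcomes, one cell per trial
%   DCM.Y     = {MDP.u};             % cell array of responses, one cell per trial
%   DCM       = Estimate_parameters(DCM);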
34 | % 35 | %__________________________________________________________________________ 36 | % Copyright (C) 2005 Wellcome Trust Centre for Neuroimaging 37 | 38 | % Karl Friston 39 | % $Id: spm_dcm_mdp.m 7120 2017-06-20 11:30:30Z spm $ 40 | 41 | % OPTIONS 42 | %-------------------------------------------------------------------------- 43 | ALL = false; 44 | 45 | % Here we specify prior expectations (for parameter means and variances) 46 | %-------------------------------------------------------------------------- 47 | prior_variance = 1/4; % smaller values will lead to a greater complexity 48 | % penalty (posteriors will remain closer to priors) 49 | 50 | for i = 1:length(DCM.field) 51 | field = DCM.field{i}; 52 | try 53 | param = DCM.MDP.(field); 54 | param = double(~~param); 55 | catch 56 | param = 1; 57 | end 58 | if ALL 59 | pE.(field) = zeros(size(param)); 60 | pC{i,i} = diag(param); 61 | else 62 | if strcmp(field,'alpha') 63 | pE.(field) = log(16); % in log-space (to keep positive) 64 | pC{i,i} = prior_variance; 65 | elseif strcmp(field,'beta') 66 | pE.(field) = log(1); % in log-space (to keep positive) 67 | pC{i,i} = prior_variance; 68 | elseif strcmp(field,'la') 69 | pE.(field) = log(1); % in log-space (to keep positive) 70 | pC{i,i} = prior_variance; 71 | elseif strcmp(field,'rs') 72 | pE.(field) = log(5); % in log-space (to keep positive) 73 | pC{i,i} = prior_variance; 74 | elseif strcmp(field,'eta') 75 | pE.(field) = log(0.5/(1-0.5)); % in logit-space - bounded between 0 and 1 76 | pC{i,i} = prior_variance; 77 | elseif strcmp(field,'omega') 78 | pE.(field) = log(0.5/(1-0.5)); % in logit-space - bounded between 0 and 1 79 | pC{i,i} = prior_variance; 80 | else 81 | pE.(field) = 0; % if it can take any negative or positive value 82 | pC{i,i} = prior_variance; 83 | end 84 | end 85 | end 86 | 87 | pC = spm_cat(pC); 88 | 89 | % model specification 90 | %-------------------------------------------------------------------------- 91 | M.L = @(P,M,U,Y)spm_mdp_L(P,M,U,Y); % log-likelihood function 92 | M.pE = pE; % prior means (parameters) 93 | M.pC = pC; % prior variance (parameters) 94 | M.mdp = DCM.MDP; % MDP structure 95 | 96 | % Variational Laplace 97 | %-------------------------------------------------------------------------- 98 | [Ep,Cp,F] = spm_nlsi_Newton(M,DCM.U,DCM.Y); % This is the actual fitting routine 99 | 100 | % Store posterior distributions and log evidence (free energy) 101 | %-------------------------------------------------------------------------- 102 | DCM.M = M; % Generative model 103 | DCM.Ep = Ep; % Posterior parameter estimates 104 | DCM.Cp = Cp; % Posterior variances and covariances 105 | DCM.F = F; % Free energy of model fit 106 | 107 | return 108 | 109 | function L = spm_mdp_L(P,M,U,Y) 110 | % log-likelihood function 111 | % FORMAT L = spm_mdp_L(P,M,U,Y) 112 | % P - parameter structure 113 | % M - generative model 114 | % U - inputs 115 | % Y - observed repsonses 116 | % 117 | % This function runs the generative model with a given set of parameter 118 | % values, after adding in the observations and actions on each trial 119 | % from (real or simulated) participant data. It then sums the 120 | % (log-)probabilities (log-likelihood) of the participant's actions under the model when it 121 | % includes that set of parameter values. 
The variational Bayes fitting 122 | % routine above uses this function to find the set of parameter values that maximize 123 | % the probability of the participant's actions under the model (while also 124 | % penalizing models with parameter values that move farther away from prior 125 | % values). 126 | %__________________________________________________________________________ 127 | 128 | if ~isstruct(P); P = spm_unvec(P,M.pE); end 129 | 130 | % Here we re-transform parameter values out of log- or logit-space when 131 | % inserting them into the model to compute the log-likelihood 132 | %-------------------------------------------------------------------------- 133 | mdp = M.mdp; 134 | field = fieldnames(M.pE); 135 | for i = 1:length(field) 136 | if strcmp(field{i},'alpha') 137 | mdp.(field{i}) = exp(P.(field{i})); 138 | elseif strcmp(field{i},'beta') 139 | mdp.(field{i}) = exp(P.(field{i})); 140 | elseif strcmp(field{i},'la') 141 | mdp.(field{i}) = exp(P.(field{i})); 142 | elseif strcmp(field{i},'rs') 143 | mdp.(field{i}) = exp(P.(field{i})); 144 | elseif strcmp(field{i},'eta') 145 | mdp.(field{i}) = 1/(1+exp(-P.(field{i}))); 146 | elseif strcmp(field{i},'omega') 147 | mdp.(field{i}) = 1/(1+exp(-P.(field{i}))); 148 | else 149 | mdp.(field{i}) = exp(P.(field{i})); 150 | end 151 | end 152 | 153 | % place MDP in trial structure 154 | %-------------------------------------------------------------------------- 155 | la = mdp.la_true; % true level of loss aversion 156 | rs = mdp.rs_true; % true preference magnitude for winning (higher = more risk-seeking) 157 | 158 | if isfield(M.pE,'la')&&isfield(M.pE,'rs') 159 | mdp.C{2} = [0 0 0 ; % Null 160 | 0 -mdp.la -mdp.la ; % Loss 161 | 0 mdp.rs mdp.rs/2]; % win 162 | elseif isfield(M.pE,'la') 163 | mdp.C{2} = [0 0 0 ; % Null 164 | 0 -mdp.la -mdp.la ; % Loss 165 | 0 rs rs/2]; % win 166 | elseif isfield(M.pE,'rs') 167 | mdp.C{2} = [0 0 0 ; % Null 168 | 0 -la -la ; % Loss 169 | 0 mdp.rs mdp.rs/2]; % win 170 | else 171 | mdp.C{2} = [0 0 0 ; % Null 172 | 0 -la -la ; % Loss 173 | 0 rs rs/2]; % win 174 | end 175 | 176 | j = 1:numel(U); % observations for each trial 177 | n = numel(j); % number of trials 178 | 179 | [MDP(1:n)] = deal(mdp); % Create MDP with number of specified trials 180 | [MDP.o] = deal(U{j}); % Add observations in each trial 181 | 182 | % solve MDP and accumulate log-likelihood 183 | %-------------------------------------------------------------------------- 184 | MDP = spm_MDP_VB_X_tutorial(MDP); % run model with possible parameter values 185 | 186 | L = 0; % start (log) probability of actions given the model at 0 187 | 188 | for i = 1:numel(Y) % Get probability of true actions for each trial 189 | for j = 1:numel(Y{1}(:,2)) % Only get probability of the second (controllable) state factor 190 | 191 | L = L + log(MDP(i).P(:,Y{i}(2,j),j)+ eps); % sum the (log) probabilities of each action 192 | % given a set of possible parameter values 193 | end 194 | end 195 | 196 | clear('MDP') 197 | 198 | fprintf('LL: %f \n',L) -------------------------------------------------------------------------------- /original_matlab_code/Message_passing_example.m: -------------------------------------------------------------------------------- 1 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 2 | %-- Message Passing Examples--% 3 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 4 | 5 | % Supplementary Code for: A Tutorial on Active Inference Modelling and its 6 | % Application to Empirical Data 7 | 8 | % By: Ryan Smith and Christopher J. 
Whyte 9 | % We also acknowledge Samuel Taylor for contributing to this example code 10 | 11 | % This script provides two examples of (marginal) message passing, based on 12 | % the steps described in the main text. Each of the two examples (sections) 13 | % needs to be run separately. The first example fixes all observed 14 | % variables immediately and does not include variables associated with the 15 | % neural process theory. The second example provides observations 16 | % sequentially and also adds in the neural process theory variables. To 17 | % remind the reader, the message passing steps in the main text are: 18 | 19 | % 1. Initialize the values of the approximate posteriors q(s_(π,τ)) 20 | % for all hidden variables (i.e., all edges) in the graph. 21 | % 2. Fix the value of observed variables (here, o_τ). 22 | % 3. Choose an edge (V) corresponding to the hidden variable you want to 23 | % infer (here, s_(π,τ)). 24 | % 4. Calculate the messages, µ(s_(π,τ)), which take on values sent by 25 | % each factor node connected to V. 26 | % 5. Pass a message from each connected factor node N to V (often written 27 | % as µ_(N→V)). 28 | % 6. Update the approximate posterior represented by V according to the 29 | % following rule: q(s_(π,τ)) ∝ µ→(s_(π,τ)) µ←(s_(π,τ)). The arrow 30 | % notation here indicates messages from two different factors arriving 31 | % at the same edge. 32 | % 6A. Normalize the product of these messages so that q(s_(π,τ)) 33 | % corresponds to a proper probability distribution. 34 | % 6B. Use this new q(s_(π,τ)) to update the messages sent by 35 | % connected factors (i.e., for the next round of message passing). 36 | % 7. Repeat steps 4-6 sequentially for each edge. 37 | % 8. Steps 3-7 are then repeated until the difference between updates 38 | % converges to some acceptably low value (i.e., resulting in stable 39 | % posterior beliefs for all edges). 40 | 41 | %% Example 1: Fixed observations and message passing steps 42 | 43 | % This section carries out marginal message passing on a graph with beliefs 44 | % about states at two time points. In this first example, both observations 45 | % are fixed from the start (i.e., there are no ts as in full active inference 46 | % models with sequentially presented observations) to provide the simplest 47 | % example possible. We also highlight where each of the message passing 48 | % steps described in the main text is carried out. 
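% A worked first pass of Example 1 (a rough hand calculation under the D, A, B
% and o defined below, for orientation only): at the first iteration for
% tau = 1, the prior and backward messages ln(D) = ln([.5 .5]') and
% ln(B'*Qs(:,2)) = ln([.5 .5]') each contribute about -0.69 per state, and the
% likelihood message ln(A'*o{1}) = ln([.9 .1]') contributes about
% [-0.11 -2.30]'. Halving the first two messages, summing, and applying the
% softmax gives an approximate posterior of roughly [.9 .1]', which later
% iterations sharpen further.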
49 | 50 | % Note that some steps (7 and 8) appear out of order when they involve loops that 51 | % repeat earlier steps 52 | 53 | % Specify generative model and initialize variables 54 | 55 | rng('shuffle') 56 | 57 | clear 58 | close all 59 | 60 | % priors 61 | D = [.5 .5]'; 62 | 63 | % likelihood mapping 64 | A = [.9 .1; 65 | .1 .9]; 66 | 67 | % transitions 68 | B = [1 0; 69 | 0 1]; 70 | 71 | % number of timesteps 72 | T = 2; 73 | 74 | % number of iterations of message passing 75 | NumIterations = 16; 76 | 77 | % initialize posterior (Step 1) 78 | for t = 1:T 79 | Qs(:,t) = [.5 .5]'; 80 | end 81 | 82 | % fix observations (Step 2) 83 | o{1} = [1 0]'; 84 | o{2} = [1 0]'; 85 | 86 | % iterate a set number of times (alternatively, until convergence) (Step 8) 87 | for Ni = 1:NumIterations 88 | % For each edge (hidden state) (Step 7) 89 | for tau = 1:T 90 | % choose an edge (Step 3) 91 | q = nat_log(Qs(:,tau)); 92 | 93 | % compute messages sent by D and B (Steps 4) using the posterior 94 | % computed in Step 6B 95 | if tau == 1 % first time point 96 | lnD = nat_log(D); % Message 1 97 | lnBs = nat_log(B'*Qs(:,tau+1)); % Message 2 98 | elseif tau == T % last time point 99 | lnBs = nat_log(B*Qs(:,tau-1)); % Message 1 100 | end 101 | 102 | % likelihood (Message 3) 103 | lnAo = nat_log(A'*o{tau}); 104 | 105 | % Steps 5-6 (Pass messages and update the posterior) 106 | % Since all terms are in log space, this is addition instead of 107 | % multiplication. This corresponds to equation 16 in the main 108 | % text (within the softmax) 109 | if tau == 1 110 | q = .5*lnD + .5*lnBs + lnAo; 111 | elseif tau == T 112 | q = .5*lnBs + lnAo; 113 | end 114 | 115 | % normalize using a softmax function to find posterior (Step 6A) 116 | Qs(:,tau) = (exp(q)/sum(exp(q))); 117 | qs(Ni,:,tau) = Qs(:,tau); % store value for each iteration 118 | end % Repeat for remaining edges (Step 7) 119 | end % Repeat until convergence/for fixed number of iterations (Step 8) 120 | 121 | Qs; % final posterior beliefs over states 122 | 123 | disp(' '); 124 | disp('Posterior over states q(s) in example 1:'); 125 | disp(' '); 126 | disp(Qs); 127 | 128 | figure 129 | 130 | % firing rates (traces) 131 | qs_plot = [D' D';qs(:,:,1) qs(:,:,2)]; % add prior to starting value 132 | plot(qs_plot) 133 | title('Example 1: Approximate Posteriors (1 per edge per time point)') 134 | ylabel('q(s_t_a_u)','FontSize',12) 135 | xlabel('Message passing iterations','FontSize',12) 136 | 137 | 138 | %% Example 2: Sequential observations and simulation of firing rates and ERPs 139 | 140 | % This script performs state estimation using the message passing 141 | % algorithm introduced in Parr, Markovic, Kiebel, & Friston (2019). 142 | % This script can be thought of as the full message passing solution to 143 | % problem 2 in the pencil and paper exercises. It also generates 144 | % simulated firing rates and ERPs in the same manner as those shown in 145 | % figs. 8, 10, 11, 14, 15, and 16. Unlike example 1, observations are 146 | % presented sequentially (i.e., two ts and two taus). 
147 | 148 | % Specify generative model and initialise variables 149 | 150 | rng('shuffle') 151 | 152 | clear 153 | 154 | % priors 155 | D = [.5 .5]'; 156 | 157 | % likelihood mapping 158 | A = [.9 .1; 159 | .1 .9]; 160 | 161 | % transitions 162 | B = [1 0; 163 | 0 1]; 164 | 165 | % number of timesteps 166 | T = 2; 167 | 168 | % number of iterations of message passing 169 | NumIterations = 16; 170 | 171 | % initialize posterior (Step 1) 172 | for t = 1:T 173 | Qs(:,t) = [.5 .5]'; 174 | end 175 | 176 | % fix observations sequentially (Step 2) 177 | o{1,1} = [1 0]'; 178 | o{1,2} = [0 0]'; 179 | o{2,1} = [1 0]'; 180 | o{2,2} = [1 0]'; 181 | 182 | % Message Passing 183 | 184 | for t = 1:T 185 | for Ni = 1:NumIterations % (Step 8 loop of VMP) 186 | for tau = 1:T % (Step 7 loop of VMP) 187 | 188 | % initialise depolarization variable: v = ln(s) 189 | % choose an edge (Step 3 of VMP) 190 | v = nat_log(Qs(:,t)); 191 | 192 | % get correct D and B for each time point (Steps 4-5 of VMP) 193 | % using using the posterior computed in Step 6B 194 | if tau == 1 % first time point 195 | % past (Message 1) 196 | lnD = nat_log(D); 197 | 198 | % future (Message 2) 199 | lnBs = nat_log(B'*Qs(:,tau+1)); 200 | elseif tau == T % last time point 201 | % no contribution from future (only Message 1) 202 | lnBs = nat_log(B*Qs(:,tau-1)); 203 | end 204 | % likelihood (Message 3) 205 | lnAo = nat_log(A'*o{t,tau}); 206 | 207 | % calculate state prediction error: equation 24 208 | if tau == 1 209 | epsilon(:,Ni,t,tau) = .5*lnD + .5*lnBs + lnAo - v; 210 | elseif tau == T 211 | epsilon(:,Ni,t,tau) = .5*lnBs + lnAo - v; 212 | end 213 | 214 | % (Step 6 of VMP) 215 | % update depolarization variable: equation 25 216 | v = v + epsilon(:,Ni,t,tau); 217 | % normalize using a softmax function to find posterior: 218 | % equation 26 (Step 6A of VMP) 219 | Qs(:,tau) = (exp(v)/sum(exp(v))); 220 | % store Qs for firing rate plots 221 | xn(Ni,:,tau,t) = Qs(:,tau); 222 | end % Repeat for remaining edges (Step 7 of VMP) 223 | end % Repeat until convergence/for number of iterations (Step 8 of VMP) 224 | end 225 | 226 | Qs; % final posterior beliefs over states 227 | 228 | disp(' '); 229 | disp('Posterior over states q(s) in example 2:'); 230 | disp(' '); 231 | disp(Qs); 232 | 233 | % plots 234 | 235 | % get firing rates into usable format 236 | num_states = 2; 237 | num_epochs = 2; 238 | time_tau = [1 2 1 2; 239 | 1 1 2 2]; 240 | for t_tau = 1:size(time_tau,2) 241 | for epoch = 1:num_epochs 242 | % firing rate 243 | firing_rate{epoch,t_tau} = xn(:,time_tau(1,t_tau),time_tau(2,t_tau),epoch); 244 | ERP{epoch,t_tau} = gradient(firing_rate{epoch,t_tau}')'; 245 | end 246 | end 247 | 248 | % convert cells to matrices 249 | firing_rate = spm_cat(firing_rate)'; 250 | firing_rate = [zeros(length(D)*T,1)+[D; D] full(firing_rate)]; % add prior for starting value 251 | ERP = spm_cat(ERP); 252 | ERP = [zeros(length(D)*T,1)'; ERP]; % add 0 for starting value 253 | 254 | figure 255 | 256 | % firing rates 257 | imagesc(64*(1 - firing_rate)) 258 | cmap = gray(256); 259 | colormap(cmap) 260 | title('Example 2: Firing rates (Darker = higher value)') 261 | ylabel('Firing rate','FontSize',12) 262 | xlabel('Message passing iterations','FontSize',12) 263 | 264 | figure 265 | 266 | % firing rates (traces) 267 | plot(firing_rate') 268 | title('Example 2: Firing rates (traces)') 269 | ylabel('Firing rate','FontSize',12) 270 | xlabel('Message passing iterations','FontSize',12) 271 | 272 | figure 273 | 274 | % ERPs/LFPs 275 | plot(ERP) 276 | title('Example 2: 
Event-related potentials') 277 | ylabel('Response','FontSize',12) 278 | xlabel('Message passing iterations','FontSize',12) 279 | 280 | %% functions 281 | 282 | % natural log that replaces zero values with very small values for numerical reasons. 283 | function y = nat_log(x) 284 | y = log(x+exp(-16)); 285 | end 286 | -------------------------------------------------------------------------------- /original_matlab_code/Pencil_and_paper_exercise_solutions.m: -------------------------------------------------------------------------------- 1 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 2 | %-- Code/solutions for pencil and paper exercises --% 3 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 4 | 5 | % Supplementary Code for: A Step-by-Step Tutorial on Active Inference Modelling and its 6 | % Application to Empirical Data 7 | 8 | % By: Ryan Smith, Karl J. Friston, Christopher J. Whyte 9 | 10 | % Note to readers: be sure to run sections individually 11 | 12 | 13 | %% Static perception 14 | 15 | clear 16 | close all 17 | rng('default') 18 | 19 | % priors 20 | D = [.75 .25]'; 21 | 22 | % likelihood mapping 23 | A = [.8 .2; 24 | .2 .8]; 25 | 26 | % observations 27 | o = [1 0]'; 28 | 29 | % express generative model in terms of update equations 30 | lns = nat_log(D) + nat_log(A'*o); 31 | 32 | % normalize using a softmax function to find posterior 33 | s = (exp(lns)/sum(exp(lns))); 34 | 35 | disp('Posterior over states q(s):'); 36 | disp(' '); 37 | disp(s); 38 | 39 | % Note: Because the natural log of 0 is undefined, for numerical reasons 40 | % the nat_log function here replaces zero values with very small values. This 41 | % means that the answers generated by this function will vary slightly from 42 | % the exact solutions shown in the text. 43 | 44 | return 45 | 46 | %% Dynamic perception 47 | 48 | clear 49 | close all 50 | rng('default') 51 | 52 | % priors 53 | D = [.5 .5]'; 54 | 55 | % likelihood mapping 56 | A = [.9 .1; 57 | .1 .9]; 58 | 59 | % transitions 60 | B = [1 0; 61 | 0 1]; 62 | 63 | % observations 64 | o{1,1} = [1 0]'; 65 | o{1,2} = [0 0]'; 66 | o{2,1} = [1 0]'; 67 | o{2,2} = [1 0]'; 68 | 69 | % number of timesteps 70 | T = 2; 71 | 72 | % initialise posterior 73 | for t = 1:T 74 | Qs(:,t) = [.5 .5]'; 75 | end 76 | 77 | for t = 1:T 78 | for tau = 1:T 79 | % get correct D and B for each time point 80 | if tau == 1 % first time point 81 | lnD = nat_log(D);% past 82 | lnBs = nat_log(B'*Qs(:,tau+1));% future 83 | elseif tau == T % last time point 84 | lnBs = nat_log(B'*Qs(:,tau-1));% no contribution from future 85 | end 86 | % likelihood 87 | lnAo = nat_log(A'*o{t,tau}); 88 | % update equation 89 | if tau == 1 90 | lns = .5*lnD + .5*lnBs + lnAo; 91 | elseif tau == T 92 | lns = .5*lnBs + lnAo; 93 | end 94 | % normalize using a softmax function to find posterior 95 | Qs(:,tau) = (exp(lns)/sum(exp(lns))) 96 | end 97 | end 98 | 99 | Qs % final posterior beliefs over states 100 | 101 | disp('Posterior over states q(s):'); 102 | disp(' '); 103 | disp(Qs); 104 | 105 | %% functions 106 | 107 | % natural log that replaces zero values with very small values for numerical reasons. 
108 | function y = nat_log(x) 109 | y = log(x+.01); 110 | end 111 | -------------------------------------------------------------------------------- /original_matlab_code/Prediction_error_example.m: -------------------------------------------------------------------------------- 1 | %% Example code for simulating state and outcome prediction errors 2 | 3 | % Supplementary Code for: A Step-by-Step Tutorial on Active Inference Modelling and its 4 | % Application to Empirical Data 5 | 6 | % By: Ryan Smith, Karl J. Friston, Christopher J. Whyte 7 | 8 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 9 | clear 10 | close all 11 | %% set up model to calculate state prediction errors 12 | % This minimizes variational free energy (keeps posterior beliefs accurate 13 | % while also keeping them as close as possible to prior beliefs) 14 | 15 | A = [.8 .4; 16 | .2 .6]; % Likelihood 17 | 18 | B_t1 = [.9 .2; 19 | .1 .8]; % Transition prior from previous timestep 20 | 21 | B_t2 = [.2 .3; 22 | .8 .7]; % Transition prior from current timestep 23 | 24 | o = [1 0]'; % Observation 25 | 26 | s_pi_tau = [.5 .5]'; % Prior distribution over states. Note that we here 27 | % use the same value for s_pi_tau-1, s_pi_tau, and 28 | % s_pi_tau+1. But this need not be the case. 29 | 30 | s_pi_tau_minus_1 = [.5 .5]'; 31 | 32 | s_pi_tau_plus_1 = [.5 .5]'; 33 | 34 | v_0 = log(s_pi_tau); % Depolarization term (initial value) 35 | 36 | B_t2_cross_intermediate = B_t2'; % Transpose B_t2 37 | 38 | B_t2_cross = spm_softmax(B_t2_cross_intermediate); % Normalize columns in transposed B_t2 39 | 40 | %% Calculate state prediction error (single iteration) 41 | 42 | state_error = 1/2*(log(B_t1*s_pi_tau_minus_1)+log(B_t2_cross*s_pi_tau_plus_1))... 43 | +log(A'*o)-log(s_pi_tau); % state prediction error 44 | 45 | v = v_0 + state_error; % Depolarization 46 | 47 | s = (exp(v)/sum(exp(v))); % Updated distribution over states 48 | 49 | 50 | disp(' '); 51 | disp('Prior Distribution over States:'); 52 | disp(s_pi_tau); 53 | disp(' '); 54 | disp('State Prediction Error:'); 55 | disp(state_error); 56 | disp(' '); 57 | disp('Depolarization:'); 58 | disp(v); 59 | disp(' '); 60 | disp('Posterior Distribution over States:'); 61 | disp(s); 62 | disp(' '); 63 | 64 | return 65 | %% set up model to calculate outcome prediction errors 66 | % This minimizes expected free energy (maximizes reward and 67 | % information-gain) 68 | 69 | clear 70 | close all 71 | 72 | % Calculate risk (reward-seeking) term under two policies 73 | 74 | A = [.9 .1; 75 | .1 .9]; % Likelihood 76 | 77 | S1 = [.9 .1]'; % States under policy 1 78 | S2 = [.5 .5]'; % States under policy 2 79 | 80 | C = [1 0]'; % Preferred outcomes 81 | 82 | o_1 = A*S1; % Predicted outcomes under policy 1 83 | o_2 = A*S2; % Predicted outcomes under policy 2 84 | z = exp(-16); % Small number added to preference distribution to avoid log(0) 85 | 86 | risk_1 = dot(o_1,log(o_1) - log(C+z)); % Risk under policy 1 87 | 88 | risk_2 = dot(o_2,log(o_2) - log(C+z)); % Risk under policy 2 89 | 90 | disp(' '); 91 | disp('Risk Under Policy 1:'); 92 | disp(risk_1); 93 | disp(' '); 94 | disp('Risk Under Policy 2:'); 95 | disp(risk_2); 96 | disp(' '); 97 | 98 | 99 | % Calculate ambiguity (information-seeking) term under two policies 100 | 101 | A = [.4 .2; 102 | .6 .8]; % Likelihood 103 | 104 | s1 = [.9 .1]'; % States under policy 1 105 | s2 = [.1 .9]'; % States under policy 2 106 | 107 | 108 | ambiguity_1 = -dot(diag(A'*log(A)),s1); % Ambiguity under policy 1 109 | 110 | ambiguity_2 = 
-dot(diag(A'*log(A)),s2); % Ambiguity under policy 2 111 | 112 | disp(' '); 113 | disp('Ambiguity Under Policy 1:'); 114 | disp(ambiguity_1); 115 | disp(' '); 116 | disp('Ambiguity Under Policy 2:'); 117 | disp(ambiguity_2); 118 | disp(' '); 119 | -------------------------------------------------------------------------------- /original_matlab_code/Step_by_Step_Hierarchical_Model.m: -------------------------------------------------------------------------------- 1 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 2 | %-- Hierarchical Model Tutorial --% 3 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 4 | 5 | % Supplementary Code for: A Step-by-Step Tutorial on Active Inference Modelling and its 6 | % Application to Empirical Data 7 | 8 | % By: Ryan Smith, Karl J. Friston, Christopher J. Whyte 9 | 10 | % Step by step tutorial for building hierarchical POMDPs using the active 11 | % inference framework. Here we simulate the now classic "Local Global" auditory 12 | % mismatch paradigm. This will reproduce results similar to figs. 14-16. 13 | 14 | clear 15 | close all 16 | 17 | %% Level 1: Perception of individual stimuli 18 | %========================================================================== 19 | 20 | % prior beliefs about initial states 21 | %-------------------------------------------------------------------------- 22 | 23 | D{1} = [1 1]';% stimulus tone {high, low} 24 | 25 | d = D; 26 | 27 | % Here we separate the generative process (the capital D) 28 | % from the generative model (the lower case d) allowing learning to occur 29 | % (i.e. to accumulate concentration parameters) in the generative model, 30 | % independent of the generative process. 31 | 32 | % probabilistic (likelihood) mapping from hidden states to outcomes: A 33 | %-------------------------------------------------------------------------- 34 | 35 | % outcome modality 1: stimulus tone 36 | A{1}= [1 0; %high tone 37 | 0 1];%low tone 38 | 39 | % separate generative model from generative process 40 | a = A; 41 | 42 | % reduce precision 43 | pr1 = 2; % precision (inverse temperature) parameter (lower = less precise) 44 | a{1} = spm_softmax(pr1*log(A{1}+exp(-4))); 45 | 46 | a = a{1}*100; 47 | 48 | % By passing the a matrix through a softmax function with a precision parameter of 2 49 | % we slightly reduce the precision of the generative model, analogous to introducing 50 | % a degree of noise into our model of tone perception. We then multiply it 51 | % by 100 so that the level of noise stays constant across trials. 
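% For intuition (approximate values from a hand calculation, not printed by
% the script): with pr1 = 2, each column of spm_softmax(pr1*log(A{1}+exp(-4)))
% comes out at roughly [0.9997 0.0003]', so after scaling by 100 the
% concentration parameters are about [99.97 0.03]' per column: a
% near-deterministic, but not perfectly precise, mapping from tone states to
% tone outcomes.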
52 | 53 | % Transitions between states: B 54 | %-------------------------------------------------------------------------- 55 | 56 | B{1}= [1 0; %high tone 57 | 0 1];%low tone 58 | 59 | % MDP Structure 60 | %-------------------------------------------------------------------------- 61 | mdp_1.T = 1; % number of updates 62 | mdp_1.A = A; % likelihood mapping 63 | mdp_1.B = B; % transition probabilities 64 | mdp_1.D = D; % prior over initial states 65 | mdp_1.d = d; 66 | mdp_1.a = a; 67 | mdp_1.erp = 1; 68 | 69 | mdp_1.Aname = {'Stimulus'}; 70 | mdp_1.Bname = {'Stimulus'}; 71 | 72 | clear a d A B D 73 | 74 | MDP_1 = spm_MDP_check(mdp_1); 75 | 76 | clear mdp_1 77 | 78 | %% Level 2: Slower-timescale representations of perceived stimulus sequences 79 | %========================================================================== 80 | 81 | % prior beliefs about initial states in generative process (D) and 82 | % generative model (d) in terms of counts (i.e., concentration parameters) 83 | %-------------------------------------------------------------------------- 84 | D2{1} = [1 1 1 1]'; % Sequence type: {high, low, high-low, low-high} 85 | D2{2} = [1 0 0 0 0 0]'; % time in trial 86 | D2{3} = [1 0 0]'; % Report: {null, same, different} 87 | 88 | d2 = D2; 89 | d2{2} = d2{2}*100; 90 | d2{3} = d2{3}*100; 91 | 92 | % Again, we here seperate the generative model from the generative process, 93 | % and multiply d2{2} and d2{3} by 100 to prevent learning in the model's 94 | % representation of task phase (time in trial) and report state probabilities. 95 | 96 | % probabilistic (likelihood) mapping from hidden states to outcomes: A 97 | %-------------------------------------------------------------------------- 98 | 99 | % outcomes: A{1} stim (2), A{2} Report Feedback (3) 100 | 101 | %--- Stimulus 102 | for i = 1:6 103 | for j = 1:3 104 | A2{1}(:,:,i,j) = [1 0 1 0;%high 105 | 0 1 0 1];%low 106 | end 107 | end 108 | 109 | % oddball at fourth timestep 110 | for i = 4 111 | for j = 1:3 112 | A2{1}(:,:,i,j) = [1 0 0 1;%high 113 | 0 1 1 0];%low 114 | end 115 | end 116 | 117 | %--- Report 118 | for i = 1:6 119 | for j = 1:3 120 | A2{2}(:,:,i,j) = [1 1 1 1; %null 121 | 0 0 0 0; %incorrect 122 | 0 0 0 0];%correct 123 | end 124 | end 125 | 126 | % report "same" 127 | for i = 6 128 | for j = 2 129 | A2{2}(:,:,i,j) = [0 0 0 0; %null 130 | 0 0 1 1; %incorrect 131 | 1 1 0 0];%correct 132 | end 133 | end 134 | 135 | % report "different" 136 | for i = 6 137 | for j = 3 138 | A2{2}(:,:,i,j) = [0 0 0 0; %null 139 | 1 1 0 0; %incorrect 140 | 0 0 1 1];%correct 141 | end 142 | end 143 | 144 | a2 = A2; % likelihood (concentration parameters) for generative model 145 | 146 | % reduce precision 147 | pr2 = 2; % precision (inverse termperature) parameter (lower = less precise) 148 | a2{1} = spm_softmax(pr2*log(A2{1}+exp(-4))); 149 | 150 | a2{1} = a2{1}*100; 151 | a2{2} = a2{2}*100; 152 | 153 | % Transition probabilities: B 154 | %-------------------------------------------------------------------------- 155 | 156 | % Precision of sequence mapping 157 | B2{1} = eye(4,4); % maximally precise identity matrix (i.e., the true 158 | % sequence is stable within a trial) 159 | 160 | B2{2} = [0 0 0 0 0 0; 161 | 1 0 0 0 0 0; 162 | 0 1 0 0 0 0; 163 | 0 0 1 0 0 0; 164 | 0 0 0 1 0 0; 165 | 0 0 0 0 1 1]; % Deterministically transition through trial sequence 166 | 167 | % Report 168 | B2{3}(:,:,1) = [1 1 1; 169 | 0 0 0; 170 | 0 0 0]; % Pre-report 171 | B2{3}(:,:,2) = [0 0 0; 172 | 1 1 1; 173 | 0 0 0]; % Report "same" 174 | B2{3}(:,:,3) = [0 0 
0; 175 | 0 0 0; 176 | 1 1 1]; % Report "different" 177 | 178 | % Policies 179 | %-------------------------------------------------------------------------- 180 | 181 | T = 6; % number of timesteps 182 | Nf = 3; % number of factors 183 | Pi = 2; % number of policies 184 | V2 = ones(T-1,Pi,Nf); 185 | 186 | % Report: "same" (left column) or "different" (right column) 187 | V2(:,:,3) = [1 1; 188 | 1 1; 189 | 1 1; 190 | 1 1; 191 | 2 3]; 192 | 193 | % C matrices (outcome modality by timestep) 194 | %-------------------------------------------------------------------------- 195 | C2{1} = zeros(2,T); 196 | 197 | % report 198 | C2{2} = [0 0 0 0 0 0; % no feedback yet 199 | 0 0 0 0 0 -1; % preference not to be incorrect at last timestep 200 | 0 0 0 0 0 1]; % preference for being correct at last timestep 201 | 202 | % MDP Structure 203 | %-------------------------------------------------------------------------- 204 | mdp.MDP = MDP_1; 205 | mdp.link = [1 0]; % identifies lower level state factors (rows) with higher 206 | % level observation modalities (columns). Here this means the 207 | % first observation at the higher level corresponds to 208 | % the first state factor at the lower level. 209 | 210 | mdp.T = T; % number of time points 211 | mdp.A = A2; % likelihood mapping for generative process 212 | mdp.a2 = a2; % likelihood mapping for generative model 213 | mdp.B = B2; % transition probabilities 214 | mdp.C = C2; % preferred outcomes 215 | mdp.D = D2; % priors over initial states for generative process 216 | mdp.d = d2; % priors over initial states for generative model 217 | mdp.V = V2; % policies 218 | mdp.erp = 1; % reset/decay paramater 219 | 220 | mdp.Aname = {'Stimulus', 'Report Feedback'}; 221 | mdp.Bname = {'Sequence', 'Time in trial', 'Report'}; 222 | 223 | 224 | % level one labels 225 | label.factor{1} = 'Stimulus'; label.name{1} = {'High','Low'}; 226 | label.modality{1} = 'Stimulus'; label.outcome{1} = {'High','Low'}; 227 | mdp.MDP.label = label; 228 | 229 | label.factor{1} = 'Sequence type'; label.name{1} = {'High','Low','High-low','Low-high'}; 230 | label.factor{2} = 'Time in trial'; label.name{2} = {'T1', 'T2', 'T3', 'T4', 'T5', 'T6'}; 231 | label.factor{3} = 'Report'; label.name{3} = {'Null', 'Same', 'Different'}; 232 | label.modality{1} = 'Tone'; label.outcome{1} = {'High', 'Low'}; 233 | label.modality{2} = 'Feedback'; label.outcome{2} = {'Null','Incorrect','Correct'}; 234 | label.action{3} = {'Null','Same','Different'}; 235 | mdp.label = label; 236 | 237 | mdp = spm_MDP_check(mdp); 238 | MDP = spm_MDP_VB_X_tutorial(mdp); 239 | 240 | %Plot trial 241 | spm_figure('GetWin','trial'); clf 242 | spm_MDP_VB_trial(MDP); 243 | 244 | %% Simulate all conditions 245 | 246 | % Here we specify the number of trials N and use a deal function (which copies 247 | % the input to N outputs) to create 10 identical mdp structures. We can 248 | % then pass this to the spm_MDP_VB_X_tutorial() script, which sequentially updates 249 | % the concentration paramaters aquired on each trial and passes them to the 250 | % mdp structure for the next trial (allowing learning to occur). 
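% (A brief aside on the deal syntax used below, in case it is unfamiliar:
% MDP_condition1(1:N) = deal(mdp) fills all N elements of the struct array
% with copies of mdp, i.e. it is equivalent to setting MDP_condition1(i) = mdp
% in a loop over i = 1:N, after which individual trials can be overridden,
% as done for the tenth trial in each condition.)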
251 | 252 | N = 10; %number of trials 253 | 254 | % Local deviation - global standard 255 | mdp.s = 3; % first nine trials are high-low 256 | MDP_condition1(1:N) = deal(mdp); 257 | MDP_condition1(10).s = 3; % tenth trial is also high-low 258 | MDP_LDGS = spm_MDP_VB_X_tutorial(MDP_condition1); 259 | 260 | % Local standard - global deviation 261 | mdp.s = 3; % first nine trials are high-low 262 | MDP_condition2(1:N) = deal(mdp); 263 | MDP_condition2(10).s = 1; % tenth trial is a high trial 264 | MDP_LSGD = spm_MDP_VB_X_tutorial(MDP_condition2); 265 | 266 | %% Plot ERPs using standard routines for each of the four conditions 267 | 268 | % These are slightly modified versions of the standard plotting scripts 269 | % given in the SPM software. 270 | 271 | spm_figure('GetWin','ERP T1 - Local deviation - global standard'); clf 272 | spm_MDP_VB_ERP_tutorial(MDP_LDGS(1)); 273 | spm_figure('GetWin','Trial T1 - Local deviation - global standard'); clf 274 | spm_MDP_VB_trial(MDP_LDGS(1)); 275 | spm_figure('GetWin','ERP T10 - Local deviation - global standard'); clf 276 | spm_MDP_VB_ERP_tutorial(MDP_LDGS(10)); 277 | spm_figure('GetWin','Trial T10 - Local deviation - global standard'); clf 278 | spm_MDP_VB_trial(MDP_LDGS(10)); 279 | 280 | spm_figure('GetWin','ERP T1 - Local standard - global deviation'); clf 281 | spm_MDP_VB_ERP_tutorial(MDP_LSGD(1)); 282 | spm_figure('GetWin','Trial T1 - Local standard - global deviation'); clf 283 | spm_MDP_VB_trial(MDP_LSGD(1)); 284 | spm_figure('GetWin','ERP T10 - Local standard - global deviation'); clf 285 | spm_MDP_VB_ERP_tutorial(MDP_LSGD(10)); 286 | spm_figure('GetWin','Trial T10 - Local standard - global deviation'); clf 287 | spm_MDP_VB_trial(MDP_LSGD(10)); 288 | 289 | %% custom ERP plots 290 | 291 | % The ERP plotting routines give three outputs: 292 | % [level 2 ERPs, level 1 ERPs, indices] 293 | % There are 32 time indices per time step/epoch of gradient descent. Here 294 | % there are 6 timesteps so there are 32x6 = 192 individual time indexes. 295 | % The level 1 and 2 ERPs are the first derivatives at each time index. 296 | 297 | [u1_1,v1_1,ind] = spm_MDP_VB_ERP_tutorial(MDP_LDGS(1),1); 298 | [u1_10,v1_10] = spm_MDP_VB_ERP_tutorial(MDP_LDGS(10),1); 299 | 300 | [u2_1,v2_1] = spm_MDP_VB_ERP_tutorial(MDP_LSGD(1),1); 301 | [u2_10,v2_10] = spm_MDP_VB_ERP_tutorial(MDP_LSGD(10),1); 302 | 303 | % The indexes below are arbitrarily chosen to best represent the ERPs at the 304 | % 4th time step, which starts at 96ms and ends at 128ms. To do this for 305 | % yourself we recommend just plotting the ERPs and selecting the appropriate 306 | % time window. For example, the 1st level ERPs start at the beginning of 307 | % the epoch whereas the 2nd ERPs appear towards the end of the epoch. So to 308 | % include baseline periods in the plot you will likely have to select 309 | % slightly different time windows for each level as we have done here. 310 | 311 | % index into 2nd level 312 | index = (96:140); 313 | u1_1 = u1_1(index,:); % level 2 314 | u1_10 = u1_10(index,:); 315 | 316 | u2_1 = u2_1(index,:);% level 2 317 | u2_10 = u2_10(index,:); 318 | 319 | % index into 1st level 320 | index = (70:120); 321 | v1_1 = v1_1(index,:);% level 1 322 | v1_10 = v1_10(index,:); 323 | v2_1 = v2_1(index,:);% level 1 324 | v2_10 = v2_10(index,:); 325 | 326 | time_low = (1:length(v1_1)); 327 | time_high = (1:length(u1_1)); 328 | 329 | %--- Let's make the plots! 
330 | 331 | % low level plot 332 | limits = [20 45 -.5 1.2]; 333 | 334 | figure(10) 335 | hold on 336 | plot(time_low,sum(v2_10,2),'b','LineWidth',4) % local standard 337 | plot(time_low,sum(v1_10,2),'r','LineWidth',4) % local deviation 338 | axis(limits) 339 | set(gca,'FontSize',10) 340 | title('Mismatch negativity') 341 | legend('Local standard', 'Local deviation') 342 | 343 | % high level plot 344 | limits = [1 45 -.5 .5]; 345 | 346 | figure(11) 347 | hold on 348 | plot(time_high,sum(u1_10,2),'b','LineWidth',4) % Global standard 349 | plot(time_high,sum(u2_10,2),'r','LineWidth',4) % Global deviation 350 | axis(limits) 351 | set(gca,'FontSize',10) 352 | title('P300') 353 | legend('Global standard', 'Global deviation') 354 | 355 | % MMN (standard - mismatch) 356 | limits = [20 45 -1.2 .5]; 357 | 358 | figure(12) 359 | hold on 360 | plot(time_low,sum(v2_10-v1_10,2),'k','LineWidth',4) 361 | axis(limits) 362 | set(gca,'FontSize',10) 363 | title('Mismatch negativity: local standard - local deviation') 364 | 365 | % P300 (standard - mismatch) 366 | limits = [1 45 -.5 .5]; 367 | 368 | figure(13) 369 | hold on 370 | plot(time_high,sum(u1_10-u2_10,2),'k','LineWidth',4) 371 | axis(limits) 372 | set(gca,'FontSize',10) 373 | title('P300: Global standard - Global deviation') 374 | 375 | -------------------------------------------------------------------------------- /original_matlab_code/VFE_calculation_example.m: -------------------------------------------------------------------------------- 1 | %% Variational free energy calculation example 2 | 3 | % Supplementary Code for: A Step-by-Step Tutorial on Active Inference Modelling and its 4 | % Application to Empirical Data 5 | 6 | % By: Ryan Smith, Karl J. Friston, Christopher J. Whyte 7 | 8 | clear all 9 | 10 | True_observation = [1 0]'; % Set observation; Note that this could be set 11 | % to include more observations. For example, 12 | % it could be set to [0 0 1]' to present a third 13 | % observation. Note that this would require 14 | % adding a corresponding third row to the 15 | % Likelihood matrix below to specify the 16 | % probabilities of the third observation under 17 | % each state. One could similarly add a third 18 | % state by adding a third entry into the Prior 19 | % and a corresponding third column into the 20 | % likelihood. 21 | 22 | %% Generative Model 23 | 24 | % Specify Prior and likelihood 25 | 26 | Prior = [.5 .5]'; % Prior distribution p(s) 27 | 28 | Likelihood = [.8 .2; 29 | .2 .8]; % Likelihood distribution p(o|s): columns=states, 30 | % rows = observations 31 | 32 | Likelihood_of_observation = Likelihood'*True_observation; 33 | 34 | Joint_probability = Prior.*Likelihood_of_observation; % Joint probability 35 | % distribution p(o,s) 36 | 37 | Marginal_probability = sum(Joint_probability,1); % Marginal observation 38 | % probabilities p(o) 39 | %% Bayes theorem: exact posterior 40 | 41 | % This is the distribution we want to approximate using variational 42 | % inference. In many practical applications, we can not solve for this 43 | % directly. 44 | 45 | Posterior = Joint_probability... 46 | /Marginal_probability; % Posterior given true observation p(s|o) 47 | 48 | disp(' '); 49 | disp('Exact Posterior:'); 50 | disp(Posterior); 51 | disp(' '); 52 | 53 | %% Variational Free Energy 54 | 55 | % Note: q(s) = approximate posterior belief: we want to get this as close as 56 | % possible to the true posterior p(s|o) after a new observation. 57 | 58 | % Different decompisitions of Free Energy (F) 59 | 60 | % 1. 
F=E_q(s)[ln(q(s)/p(o,s))] 61 | 62 | % 2. F=E_q(s)[ln(q(s)/p(s))] - E_q(s)[ln(p(o|s))] % Complexity-accuracy 63 | % version 64 | 65 | % The first term can be interpreted as a complexity term (the KL divergence 66 | % between prior beliefs p(s) and approximate posterior beliefs q(s)). In 67 | % other words, how much beliefs have changed after a bew observation. 68 | 69 | % The second term (excluding the minus sign) is the accuracy or (including the 70 | % minus sign) the entropy (= expected surprisal) of observations given 71 | % approximate posterior beliefs q(s). Written in this way 72 | % free-energy-minimisation is equivalent to a statistical Occam's razor, 73 | % where the agent tries to find the most accurate posterior belief that also 74 | % changes its beliefs as little as possible. 75 | 76 | % 3. F=E_q(s)[ln(q(s)) - ln(p(s|o)p(o))] 77 | 78 | % 4. F=E_q(s)[ln(q(s)/p(s|o))] - ln(p(o)) 79 | 80 | % These two versions similarly show F in terms of a difference between 81 | % q(s) and the true posterior p(s|o). Here we focus on #4. 82 | 83 | % The first term is the KL divergence between the approximate posterior q(s) 84 | % and the unknown exact posterior p(s|o), also called the relative entropy. 85 | 86 | % The second term (excluding the minus sign) is the log evidence or (including 87 | % the minus sign) the surprisal of observations. Note that ln(p(o)) does 88 | % not depend on q(s), so its expectation value under q(s) is simply ln(p(o)). 89 | 90 | % Since this term does not depend on q(s), minimizing free energy means that 91 | % q(s) comes to approximate p(s|o), which is our unknown, desired quantity. 92 | 93 | % 5. F=E_q(s)[ln(q(s))-ln(p(o|s)p(s))] 94 | 95 | % We will use this decomposition for convenience when doing variational 96 | % inference below. Note how this decomposition is equivalent to the expression 97 | % shown in Figure 3 - F=E_q(s)(ln(q(s)/p(o,s)) - because ln(x)-ln(y) = ln(x/y) 98 | % and p(o|s)p(s)=p(o,s) 99 | 100 | %% Variational inference 101 | 102 | Initial_approximate_posterior = Prior; % Initial approximate posterior distribution. 103 | % Set this to match generative model prior 104 | 105 | % Calculate F 106 | Initial_F = Initial_approximate_posterior(1)*(log(Initial_approximate_posterior(1))... 107 | -log(Joint_probability(1)))+Initial_approximate_posterior(2)... 108 | *(log(Initial_approximate_posterior(2))-log(Joint_probability(2))); 109 | 110 | Optimized_approximate_posterior = Posterior; % Set approximate distribution to true posterior 111 | 112 | % Calculate F 113 | Minimized_F = Optimized_approximate_posterior(1)*(log(Optimized_approximate_posterior(1))... 114 | -log(Joint_probability(1)))+Optimized_approximate_posterior(2)... 
115 | *(log(Optimized_approximate_posterior(2))-log(Joint_probability(2))); 116 | 117 | % We see that F is lower when the approximate posterior q(s) is closer to 118 | % the true distribution p(s|o) 119 | 120 | disp(' '); 121 | disp('Initial Approximate Posterior:'); 122 | disp(Initial_approximate_posterior); 123 | disp(' '); 124 | 125 | disp(' '); 126 | disp('Initial Variational Free Energy:'); 127 | disp(Initial_F); 128 | disp(' '); 129 | 130 | disp(' '); 131 | disp('Optimized Approximate Posterior:'); 132 | disp(Optimized_approximate_posterior); 133 | disp(' '); 134 | 135 | disp(' '); 136 | disp('Minimized Variational Free Energy:'); 137 | disp(Minimized_F); 138 | disp(' '); -------------------------------------------------------------------------------- /original_matlab_code/spm/spm_MDP_VB_LFP.m: -------------------------------------------------------------------------------- 1 | function [v] = spm_MDP_VB_LFP(MDP,UNITS,f,SPECTRAL) 2 | % auxiliary routine for plotting simulated electrophysiological responses 3 | % FORMAT [v] = spm_MDP_VB_LFP(MDP,UNITS,FACTOR,SPECTRAL) 4 | % 5 | % MDP - structure (see spm_MDP_VB_X.m) 6 | % .xn - neuronal firing 7 | % .dn - phasic dopamine responses 8 | % 9 | % UNITS(1,:) - hidden state [default: all] 10 | % UNITS(2,:) - time step 11 | % 12 | % FACTOR - hidden factor to plot [default: 1] 13 | % SPECTRAL - replace raster with spectral responses [default: 0] 14 | % 15 | % v - selected unit responses {number of trials, number of units} 16 | % 17 | % This routine plots simulated electrophysiological responses. Graphics are 18 | % provided in terms of simulated spike rates (posterior expectations). 19 | % 20 | % see also: spm_MDP_VB_ERP (for hierarchical belief updating) 21 | %__________________________________________________________________________ 22 | 23 | % Karl Friston 24 | % Copyright (C) 2008-2022 Wellcome Centre for Human Neuroimaging 25 | 26 | % check for simulated neuronal responses 27 | %-------------------------------------------------------------------------- 28 | if ~isfield(MDP(1),'xn') 29 | warning ('please use another inversion scheme that simulates neuronal responses (e.g., spm_MDP_VB_XX)') 30 | return 31 | end 32 | 33 | % defaults 34 | %========================================================================== 35 | try, f; catch, f = 1; end 36 | try, UNITS; catch, UNITS = []; end 37 | try, SPECTRAL; catch, SPECTRAL = 0; end 38 | try, MDP = spm_MDP_check(MDP); end 39 | 40 | % dimensions 41 | %-------------------------------------------------------------------------- 42 | Nt = length(MDP); % number of trials 43 | try 44 | Ne = size(MDP(1).xn{f},4); % number of epochs 45 | Nx = size(MDP(1).B{f}, 1); % number of states 46 | Nb = size(MDP(1).xn{f},1); % number of time bins per epochs 47 | catch 48 | Ne = size(MDP(1).xn,4); % number of epochs 49 | Nx = size(MDP(1).A, 2); % number of states 50 | Nb = size(MDP(1).xn,1); % number of time bins per epochs 51 | end 52 | 53 | % units to plot 54 | %-------------------------------------------------------------------------- 55 | ALL = []; 56 | for i = 1:Ne 57 | for j = 1:Nx 58 | ALL(:,end + 1) = [j;i]; 59 | end 60 | end 61 | if size(ALL,2) > 512 62 | ii = round(linspace(1,size(ALL,2),512)); 63 | ALL = ALL(:,ii); 64 | end 65 | if isempty(UNITS) 66 | UNITS = ALL; 67 | end 68 | ii = 1:size(ALL,2); 69 | 70 | % summary statistics: firing rates 71 | %========================================================================== 72 | for i = 1:Nt 73 | 74 | % for all units 75 | 
%---------------------------------------------------------------------- 76 | str = {}; 77 | try 78 | xn = MDP(i).xn{f}; 79 | catch 80 | xn = MDP(i).xn; 81 | end 82 | for j = 1:size(ALL,2) 83 | for k = 1:Ne 84 | z{i,1}{k,j} = xn(:,ALL(1,j),ALL(2,j),k); 85 | end 86 | str{j} = sprintf('%s: t=%i',MDP(1).label.name{f}{ALL(1,j)},ALL(2,j)); 87 | end 88 | 89 | % for selected units 90 | %---------------------------------------------------------------------- 91 | for j = 1:size(UNITS,2) 92 | for k = 1:Ne 93 | v{i,1}{k,j} = xn(:,UNITS(1,j),UNITS(2,j),k); 94 | end 95 | end 96 | 97 | 98 | % dopamine or changes in precision 99 | %---------------------------------------------------------------------- 100 | dn(:,i) = mean(MDP(i).dn,2); 101 | 102 | end 103 | 104 | if nargout, return, end 105 | 106 | % phase amplitude coupling 107 | %========================================================================== 108 | dt = 1/64; % time bin (seconds) 109 | t = (1:(Nb*Ne*Nt))*dt; % time (seconds) 110 | Hz = 4:32; % frequency range 111 | n = 1/(4*dt); % window length 112 | w = Hz*(dt*n); % cycles per window 113 | 114 | % simulated firing rates 115 | %-------------------------------------------------------------------------- 116 | z = spm_cat(z)'; % firing rates of all units 117 | v = spm_cat(v)'; % firing rates of selected units 118 | 119 | % bandpass filter log rates between 8 and 32 Hz: local field potential 120 | %-------------------------------------------------------------------------- 121 | c = 1/32; 122 | x = log(z' + c); 123 | u = log(v' + c); 124 | x = spm_conv(x,2,0) - spm_conv(x,16,0); 125 | u = spm_conv(u,2,0) - spm_conv(u,16,0); 126 | 127 | % simulated firing rates 128 | %-------------------------------------------------------------------------- 129 | if Nt == 1, subplot(3,2,1), else, subplot(4,1,1),end 130 | image(t,ii,64*(1 - z)) 131 | title(MDP(1).label.factor{f},'FontSize',16) 132 | xlabel('time (sec)','FontSize',12) 133 | 134 | if numel(str) < 16 135 | grid on, set(gca,'YTick',1:(Ne*Nx)) 136 | set(gca,'YTickLabel',str) 137 | end 138 | grid on, set(gca,'XTick',(1:(Ne*Nt))*Nb*dt) 139 | if Ne*Nt > 32, set(gca,'XTickLabel',{}), end 140 | if Nt == 1, axis square, end 141 | 142 | % time frequency analysis and theta phase 143 | %-------------------------------------------------------------------------- 144 | wft = spm_wft(x,w,n); 145 | csd = sum(abs(wft),3); 146 | lfp = sum(x,2); 147 | phi = spm_iwft(sum(wft(1,:,:),3),w(1),n); 148 | lfp = 4*lfp/std(lfp) + 16; 149 | phi = 4*phi/std(phi) + 16; 150 | 151 | if Nt == 1, subplot(3,2,3), else, subplot(4,1,2),end 152 | imagesc(t,Hz,csd), axis xy, hold on 153 | plot(t,lfp,'w:',t,phi,'w'), hold off 154 | grid on, set(gca,'XTick',(1:(Ne*Nt))*Nb*dt) 155 | 156 | title('Time-frequency response','FontSize',16) 157 | xlabel('time (sec)'), ylabel('frequency (Hz)') 158 | if Nt == 1, axis square, end 159 | 160 | % spectral responses 161 | %-------------------------------------------------------------------------- 162 | if SPECTRAL 163 | 164 | % spectral responses (for each unit) 165 | %---------------------------------------------------------------------- 166 | if Nt == 1, subplot(3,2,1), else, subplot(4,2,1),end 167 | csd = squeeze(sum(abs(wft),2)); 168 | plot(Hz,log(squeeze(csd))) 169 | title('Spectral response','FontSize',16) 170 | xlabel('frequency (Hz)'), 171 | ylabel('log power') 172 | spm_axis tight, box off, axis square 173 | 174 | % amplitude-to-amplitude coupling (average over units) 175 | %---------------------------------------------------------------------- 
176 | if Nt == 1, subplot(3,2,2), else, subplot(4,2,2),end 177 | cfc = 0; 178 | for i = 1:size(wft,3) 179 | cfc = cfc + corr((abs(wft(:,:,i)))'); 180 | end 181 | imagesc(Hz,Hz,cfc) 182 | title('Cross-frequency coupling','FontSize',16) 183 | xlabel('frequency (Hz)'), 184 | ylabel('frequency (Hz)') 185 | box off, axis square 186 | 187 | end 188 | 189 | % local field potentials 190 | %========================================================================== 191 | if Nt == 1, subplot(3,2,4), else, subplot(4,1,3),end 192 | plot(t,u), hold off, spm_axis tight, a = axis; 193 | plot(t,x,':'), hold on 194 | plot(t,u), hold off, axis(a) 195 | grid on, set(gca,'XTick',(1:(Ne*Nt))*Nb*dt), 196 | for i = 2:2:Nt 197 | h = patch(((i - 1) + [0 0 1 1])*Ne*Nb*dt,a([3,4,4,3]),-[1 1 1 1],'w'); 198 | set(h,'LineStyle',':','FaceColor',[1 1 1] - 1/32); 199 | end 200 | title('Local field potentials','FontSize',16) 201 | xlabel('time (sec)') 202 | ylabel('response') 203 | if Nt == 1, axis square, end, box off 204 | 205 | % firing rates 206 | %========================================================================== 207 | if Nt == 1, subplot(3,2,2) 208 | plot(t,v), hold on, spm_axis tight, a = axis; 209 | plot(t,z,':'), hold off 210 | grid on, set(gca,'XTick',(1:(Ne*Nt))*Nb*dt), axis(a) 211 | title('Firing rates','FontSize',16) 212 | xlabel('time (sec)') 213 | ylabel('response') 214 | axis square 215 | end 216 | 217 | % simulated dopamine responses (if not a moving policy) 218 | %========================================================================== 219 | if Nt == 1, subplot(3,2,6), else, subplot(4,1,4),end 220 | dn = spm_vec(dn); 221 | dn = dn.*(dn > 0); 222 | dn = dn + (dn + 1/16).*rand(size(dn))/8; 223 | bar(dn,1,'k'), title('Dopamine responses','FontSize',16) 224 | xlabel('time (updates)') 225 | ylabel('change in precision'), spm_axis tight, box off 226 | YLim = get(gca,'YLim'); YLim(1) = 0; set(gca,'YLim',YLim); 227 | if Nt == 1, axis square, end 228 | 229 | % simulated rasters 230 | %========================================================================== 231 | Nr = 16; 232 | if Nt == 1 && numel(ii) < 129 233 | subplot(3,2,5) 234 | 235 | R = kron(z,ones(Nr,Nr)); 236 | R = rand(size(R)) > R*(1 - 1/16); 237 | imagesc(t,1:(Nx*Ne),R),title('Unit firing','FontSize',16) 238 | xlabel('time (sec)','FontSize',12) 239 | 240 | grid on, set(gca,'XTick',(1:(Ne*Nt))*Nb*dt) 241 | grid on, set(gca,'YTick', 1:(Ne*Nx)) 242 | set(gca,'YTickLabel',str), axis square 243 | 244 | end 245 | 246 | 247 | -------------------------------------------------------------------------------- /original_matlab_code/spm/spm_MDP_VB_trial.m: -------------------------------------------------------------------------------- 1 | function spm_MDP_VB_trial(MDP,gf,gg) 2 | % auxiliary plotting routine for spm_MDP_VB - single trial 3 | % FORMAT spm_MDP_VB_trial(MDP,[f,g]) 4 | % 5 | % MDP.P(M,T) - probability of emitting action 1,...,M at time 1,...,T 6 | % MDP.X - conditional expectations over hidden states 7 | % MDP.R - conditional expectations over policies 8 | % MDP.o - outcomes at time 1,...,T 9 | % MDP.s - states at time 1,...,T 10 | % MDP.u - action at time 1,...,T 11 | % 12 | % MDP.un = un; - simulated neuronal encoding of hidden states 13 | % MDP.xn = Xn; - simulated neuronal encoding of policies 14 | % MDP.wn = wn; - simulated neuronal encoding of precision 15 | % MDP.da = dn; - simulated dopamine responses (deconvolved) 16 | % 17 | % [f,g] - factors and outcomes to plot [Default: first 3] 18 | % 19 | % please see spm_MDP_VB. 
For multiple trials please see spm_MDP_VB_game 20 | %__________________________________________________________________________ 21 | 22 | % Karl Friston 23 | % Copyright (C) 2008-2022 Wellcome Centre for Human Neuroimaging 24 | 25 | % graphics 26 | %========================================================================== 27 | MDP = spm_MDP_check(MDP); clf 28 | 29 | % numbers of transitions, policies and states 30 | %-------------------------------------------------------------------------- 31 | if iscell(MDP.X) 32 | Nf = numel(MDP.B); % number of hidden state factors 33 | Ng = numel(MDP.A); % number of outcome factors 34 | X = MDP.X; 35 | C = MDP.C; 36 | for f = 1:Nf 37 | Nu(f) = size(MDP.B{f},3) > 1; 38 | end 39 | else 40 | Nf = 1; 41 | Ng = 1; 42 | Nu = 1; 43 | X = {MDP.X}; 44 | C = {MDP.C}; 45 | end 46 | 47 | % factors and outcomes to plot 48 | %-------------------------------------------------------------------------- 49 | maxg = 3; 50 | if nargin < 2, gf = 1:min(Nf,maxg); end 51 | if nargin < 3, gg = 1:min(Ng,maxg); end 52 | nf = numel(gf); 53 | ng = numel(gg); 54 | 55 | % posterior beliefs about hidden states 56 | %-------------------------------------------------------------------------- 57 | for f = 1:nf 58 | subplot(3*nf,2,(f - 1)*2 + 1), hold off 59 | image(64*(1 - X{gf(f)})), hold on 60 | if size(X{gf(f)},1) > 128 61 | spm_spy(X{gf(f)},12,1); 62 | end 63 | a = axis; 64 | if isfield(MDP,'s') 65 | hold on, plot(MDP.s(gf(f),:),'.r','MarkerSize',8), axis(a), hold off 66 | end 67 | if f < 2 68 | title(sprintf('Hidden states - %s',MDP.label.factor{gf(f)})); 69 | else 70 | title(MDP.label.factor{gf(f)}); 71 | end 72 | 73 | set(gca,'XTickLabel',{}); 74 | set(gca,'XTick',1:size(X{1},2)); 75 | 76 | YTickLabel = MDP.label.name{gf(f)}; 77 | if numel(YTickLabel) > 8 78 | i = linspace(1,numel(YTickLabel),8); 79 | YTickLabel = YTickLabel(round(i)); 80 | else 81 | i = 1:numel(YTickLabel); 82 | end 83 | set(gca,'YTick',i); 84 | set(gca,'YTickLabel',YTickLabel); 85 | end 86 | 87 | % posterior beliefs about control states 88 | %-------------------------------------------------------------------------- 89 | Nu = find(Nu); 90 | Np = length(Nu); 91 | for f = 1:Np 92 | subplot(3*Np,2,f*2) 93 | 94 | if iscell(MDP.P) 95 | P = MDP.P{f}; 96 | elseif Nf > 1 97 | ind = 1:Nf; 98 | P = MDP.P; 99 | for dim = 1:Nf 100 | if dim ~= ind(Nu(f)) 101 | P = sum(P,dim); 102 | end 103 | end 104 | P = squeeze(P); 105 | else 106 | P = squeeze(MDP.P); 107 | end 108 | 109 | % display 110 | %---------------------------------------------------------------------- 111 | image(64*(1 - P)) 112 | if isfield(MDP,'u') 113 | hold on, plot(MDP.u(Nu(f),:),'.c','MarkerSize',16), hold off 114 | end 115 | if f < 2 116 | title(sprintf('Action - %s',MDP.label.factor{Nu(f)})); 117 | else 118 | title(MDP.label.factor{Nu(f)}); 119 | end 120 | set(gca,'XTickLabel',{}); 121 | set(gca,'XTick',1:size(X{1},2)); 122 | 123 | YTickLabel = MDP.label.action{Nu(f)}; 124 | if numel(YTickLabel) > 8 125 | i = round(linspace(1,numel(YTickLabel),8)); 126 | YTickLabel = YTickLabel(i); 127 | else 128 | i = 1:numel(YTickLabel); 129 | end 130 | set(gca,'YTick',i); 131 | set(gca,'YTickLabel',YTickLabel); 132 | 133 | % policies 134 | %---------------------------------------------------------------------- 135 | subplot(3*Np,2,(Np + f - 1)*2 + 1) 136 | imagesc(MDP.V(:,:,Nu(f))') 137 | if f < 2 138 | title(sprintf('Allowable policies - %s',MDP.label.factor{Nu(f)})); 139 | else 140 | title(MDP.label.factor{Nu(f)}); 141 | end 142 | if f < Np 143 | 
set(gca,'XTickLabel',{}); 144 | end 145 | set(gca,'XTick',1:size(X{1},2) - 1); 146 | set(gca,'YTickLabel',{}); 147 | ylabel('policy') 148 | 149 | end 150 | 151 | % expectations over policies 152 | %-------------------------------------------------------------------------- 153 | if isfield(MDP,'un') 154 | subplot(3,2,4) 155 | image(64*(1 - MDP.un)) 156 | title('Posterior probability') 157 | ylabel('policy') 158 | xlabel('updates') 159 | end 160 | 161 | % sample (observation) and preferences 162 | %-------------------------------------------------------------------------- 163 | for g = 1:ng 164 | 165 | subplot(3*ng,2,(2*ng + g - 1)*2 + 1), hold off 166 | c = C{gg(g)}; 167 | if size(c,2) < size(MDP.o,2) 168 | c = repmat(c(:,1),1,size(MDP.o,2)); 169 | end 170 | if size(c,1) > 128 171 | spm_spy(c,16,1), hold on 172 | else 173 | imagesc(1 - c), hold on 174 | end 175 | plot(MDP.o(gg(g),:),'.c','MarkerSize',16), hold off 176 | if g < 2 177 | title(sprintf('Outcomes and preferences - %s',MDP.label.modality{gg(g)})); 178 | else 179 | title(MDP.label.modality{gg(g)}); 180 | end 181 | if g == ng 182 | xlabel('time'); 183 | else 184 | set(gca,'XTickLabel',{}); 185 | end 186 | set(gca,'XTick',1:size(X{1},2)) 187 | 188 | YTickLabel = MDP.label.outcome{gg(g)}; 189 | if numel(YTickLabel) > 8 190 | i = round(linspace(1,numel(YTickLabel),8)); 191 | YTickLabel = YTickLabel(i); 192 | else 193 | i = 1:numel(YTickLabel); 194 | end 195 | set(gca,'YTick',i); 196 | set(gca,'YTickLabel',YTickLabel); 197 | 198 | end 199 | 200 | % expected precision 201 | %-------------------------------------------------------------------------- 202 | if isfield(MDP,'dn') && isfield(MDP,'wn') 203 | if size(MDP.dn,2) > 0 204 | subplot(3,2,6) 205 | if size(MDP.dn,2) > 1 206 | plot(MDP.dn,'r:'), hold on, plot(MDP.wn,'c','LineWidth',2), hold off 207 | else 208 | bar(MDP.dn,1.1,'k'), hold on, plot(MDP.wn,'c','LineWidth',2), hold off 209 | end 210 | title('Expected precision (dopamine)') 211 | xlabel('updates'), ylabel('precision'), spm_axis tight, box off 212 | end 213 | end 214 | drawnow 215 | -------------------------------------------------------------------------------- /original_matlab_code/spm/spm_MDP_check.m: -------------------------------------------------------------------------------- 1 | function [MDP] = spm_MDP_check(MDP) 2 | % MDP structure checking 3 | % FORMAT [MDP] = spm_MDP_check(MDP) 4 | % 5 | % MDP.V(T - 1,P,F) - P allowable policies of T moves over F factors 6 | % or 7 | % MDP.U(1,P,F) - P allowable actions at each move 8 | % MDP.T - number of outcomes 9 | % 10 | % MDP.A{G}(O,N1,...,NF) - likelihood of O outcomes given hidden states 11 | % MDP.B{F}(NF,NF,MF) - transitions among hidden under MF control states 12 | % MDP.C{G}(O,T) - prior preferences over O outcomes in modality G 13 | % MDP.D{F}(NF,1) - prior probabilities over initial states 14 | % MDP.E{F}(NF,1) - prior probabilities over initial control 15 | % 16 | % MDP.a{G} - concentration parameters for A 17 | % MDP.b{F} - concentration parameters for B 18 | % MDP.c{F} - concentration parameters for C 19 | % MDP.d{F} - concentration parameters for D 20 | % MDP.e{F} - concentration parameters for E 21 | % 22 | % optional: 23 | % MDP.s(F,T) - vector of true states - for each hidden factor 24 | % MDP.o(G,T) - vector of outcome - for each outcome modality 25 | % MDP.u(F,T - 1) - vector of action - for each hidden factor 26 | % MDP.w(1,T) - vector of precisions 27 | % 28 | % if C or D are not specified, they will be set to default values (of no 29 | % preferences and 
uniform priors over initial steps). If there are no 30 | % policies, it will be assumed that I = 1 and all policies (for each 31 | % marginal hidden state) are allowed. 32 | %__________________________________________________________________________ 33 | 34 | % Karl Friston 35 | % Copyright (C) 2008-2022 Wellcome Centre for Human Neuroimaging 36 | 37 | 38 | % deal with a sequence of trials 39 | %========================================================================== 40 | 41 | % if there are multiple structures check each separately 42 | %-------------------------------------------------------------------------- 43 | if numel(MDP) > 1 44 | for m = 1:size(MDP,1) % number of trials 45 | for i = 1:size(MDP,2) % number of agents 46 | mdp(m,i) = spm_MDP_check(MDP(m,i)); 47 | end 48 | end 49 | MDP = mdp; 50 | return 51 | end 52 | 53 | % fill in (posterior or process) likelihood and priors 54 | %-------------------------------------------------------------------------- 55 | if ~isfield(MDP,'A'), MDP.A = MDP.a; end 56 | if ~isfield(MDP,'B'), MDP.B = MDP.b; end 57 | 58 | % check format of likelihood and priors 59 | %-------------------------------------------------------------------------- 60 | if ~iscell(MDP.A), MDP.A = {full(MDP.A)}; end 61 | if ~iscell(MDP.B), MDP.B = {full(MDP.B)}; end 62 | 63 | if isfield(MDP,'a'), if ~iscell(MDP.a), MDP.a = {full(MDP.a)}; end; end 64 | if isfield(MDP,'b'), if ~iscell(MDP.b), MDP.b = {full(MDP.b)}; end; end 65 | 66 | 67 | % check dimensions and orders 68 | %========================================================================== 69 | 70 | % number of transitions, policies and states 71 | %-------------------------------------------------------------------------- 72 | Nf = numel(MDP.B); % number of hidden state factors 73 | for f = 1:Nf 74 | 75 | % ensure probabilities are normalised : B 76 | %---------------------------------------------------------------------- 77 | NU(f) = size(MDP.B{f},3); % number of hidden controls 78 | NS(f) = size(MDP.B{f},1); % number of hidden states 79 | MDP.B{f} = double(MDP.B{f}); 80 | MDP.B{f} = spm_dir_norm(MDP.B{f}); 81 | 82 | end 83 | 84 | % numbber of outcome modalities and outcomes 85 | %-------------------------------------------------------------------------- 86 | Ng = numel(MDP.A); % number of outcome factors 87 | for g = 1:Ng 88 | 89 | % ensure probabilities are normalised : A 90 | %---------------------------------------------------------------------- 91 | No(g) = size(MDP.A{g},1); % number of outcomes 92 | if ~(issparse(MDP.A{g}) || islogical(MDP.A{g})) 93 | MDP.A{g} = double(MDP.A{g}); 94 | end 95 | if ~islogical(MDP.A{g}) 96 | MDP.A{g} = full(spm_dir_norm(MDP.A{g})); 97 | end 98 | end 99 | 100 | % check sizes of Dirichlet parameterisation 101 | %-------------------------------------------------------------------------- 102 | [Nf,Ns,Nu] = spm_MDP_size(MDP); 103 | 104 | 105 | % check policy specification (create default moving policy U, if necessary) 106 | %-------------------------------------------------------------------------- 107 | if isfield(MDP,'U') 108 | if size(MDP.U,1) == 1 && size(MDP.U,3) == Nf 109 | MDP.U = shiftdim(MDP.U,1); 110 | end 111 | end 112 | try 113 | V(1,:,:) = MDP.U; % allowable actions (1,Np) 114 | catch 115 | try 116 | V = MDP.V; % allowable policies (T - 1,Np) 117 | catch 118 | 119 | % allowable (moving) policies using all allowable actions 120 | %------------------------------------------------------------------ 121 | MDP.U = spm_combinations(Nu); % U = U(Np,Nf) 122 | V(1,:,:) = MDP.U; % 
V = V(Nt,Np,Nf) 123 | end 124 | end 125 | MDP.V = V; 126 | 127 | % check policy specification 128 | %-------------------------------------------------------------------------- 129 | if Nf ~= size(V,3) && size(V,3) > 1 130 | error('please ensure V(:,:,1:Nf) is consistent with MDP.B{1:Nf}') 131 | end 132 | 133 | % check preferences 134 | %-------------------------------------------------------------------------- 135 | if ~isfield(MDP,'C') 136 | for g = 1:Ng 137 | MDP.C{g} = zeros(No(g),1); 138 | end 139 | end 140 | for g = 1:Ng 141 | if iscell(MDP.C) 142 | if isvector(MDP.C{g}) 143 | MDP.C{g} = spm_vec(MDP.C{g}); 144 | end 145 | if No(g) ~= size(MDP.C{g},1) 146 | error(['please ensure A{' num2str(g) '} and C{' num2str(g) '} are consistent']) 147 | end 148 | end 149 | end 150 | 151 | 152 | % check initial states 153 | %-------------------------------------------------------------------------- 154 | if ~isfield(MDP,{'D'}) 155 | for f = 1:Nf 156 | MDP.D{f} = ones(Ns(f),1); 157 | end 158 | end 159 | if Nf ~= numel(MDP.D) 160 | error('please check MDP.D') 161 | end 162 | for f = 1:Nf 163 | MDP.D{f} = MDP.D{f}(:); 164 | end 165 | 166 | % check initial controls 167 | %-------------------------------------------------------------------------- 168 | % if ~isfield(MDP,{'E'}) 169 | % for f = 1:Nf 170 | % MDP.E{f} = ones(Nu(f),1); 171 | % end 172 | % end 173 | % if Nf ~= numel(MDP.E) 174 | % error('please check MDP.E') 175 | % end 176 | % for f = 1:Nf 177 | % MDP.E{f} = MDP.E{f}(:); 178 | % end 179 | 180 | 181 | % check initial states and internal consistency 182 | %-------------------------------------------------------------------------- 183 | for f = 1:Nf 184 | if Ns(f) ~= size(MDP.D{f},1) 185 | error(['please ensure B{' num2str(f) '} and D{' num2str(f) '} are consistent']) 186 | end 187 | if size(V,3) > 1 188 | if Nu(f) < max(spm_vec(V(:,:,f))) 189 | error(['please check V(:,:,' num2str(f) ') or U(:,:,' num2str(f) ')']) 190 | end 191 | end 192 | for g = 1:Ng 193 | try 194 | Na = size(MDP.a{g}); 195 | catch 196 | Na = size(MDP.A{g}); 197 | end 198 | if ~all(Na(2:end) == Ns) 199 | error(['please ensure A{' num2str(g) '} and D{' num2str(f) '} are consistent']) 200 | end 201 | end 202 | end 203 | 204 | % check probability matrices are properly specified 205 | %-------------------------------------------------------------------------- 206 | for f = 1:numel(MDP.B) 207 | if ~all(spm_vec(any(MDP.B{f},1))) 208 | error(['please check B{' num2str(f) '} for missing entries']) 209 | end 210 | end 211 | for g = 1:numel(MDP.A) 212 | if ~all(spm_vec(any(MDP.A{g},1))) 213 | error(['please check A{' num2str(g) '} for missing entries']) 214 | end 215 | end 216 | 217 | % check initial states 218 | %-------------------------------------------------------------------------- 219 | if isfield(MDP,'s') 220 | if size(MDP.s,1) > numel(MDP.B) 221 | error('please specify an initial state MDP.s for %i factors',Nf) 222 | end 223 | f = max(MDP.s,[],2)'; 224 | if any(f > NS(1:numel(f))) 225 | error('please ensure initial states MDP.s are consistent with MDP.B') 226 | end 227 | end 228 | 229 | % check outcomes if specified 230 | %-------------------------------------------------------------------------- 231 | if isfield(MDP,'o') 232 | if numel(MDP.o) 233 | if size(MDP.o,1) ~= Ng 234 | error('please specify an outcomes MDP.o for %i modalities',Ng) 235 | end 236 | if any(max(MDP.o,[],2) > No(:)) 237 | error('please ensure # outcomes MDP.o are consistent with MDP.A') 238 | end 239 | end 240 | end 241 | 242 | % check (primary link array 
if necessary) 243 | %-------------------------------------------------------------------------- 244 | if isfield(MDP,'link') 245 | 246 | % cardinality of subordinate level 247 | %---------------------------------------------------------------------- 248 | nf = numel(MDP.MDP(1).B); % number of hidden factors 249 | for f = 1:nf 250 | ns(f) = size(MDP.MDP(1).B{f},1); % number of hidden states 251 | end 252 | 253 | % check the size of link 254 | %---------------------------------------------------------------------- 255 | if ~all(size(MDP.link) == [nf,Ng]) 256 | error('please check the size of link {%i,%i}',nf,Ng) 257 | end 258 | 259 | % convert matrix to cell array if necessary 260 | %---------------------------------------------------------------------- 261 | if isnumeric(MDP.link) 262 | link = cell(nf,Ng); 263 | for f = 1:size(MDP.link,1) 264 | for g = 1:size(MDP.link,2) 265 | if MDP.link(f,g) 266 | link{f,g} = spm_speye(ns(f),No(g),0); 267 | end 268 | end 269 | end 270 | MDP.link = link; 271 | end 272 | 273 | % check sizes of cell array 274 | %---------------------------------------------------------------------- 275 | for f = 1:size(MDP.link,1) 276 | for g = 1:size(MDP.link,2) 277 | if ~isempty(MDP.link{f,g}) 278 | if ~all(size(MDP.link{f,g}) == [ns(f),No(g)]); 279 | error('please check link{%i,%i}',f,g) 280 | end 281 | end 282 | end 283 | end 284 | 285 | end 286 | 287 | % Empirical prior preferences 288 | %-------------------------------------------------------------------------- 289 | if isfield(MDP,'linkC') 290 | if isnumeric(MDP.linkC) 291 | linkC = cell(numel(MDP.MDP.C),Ng); 292 | for f = 1:size(MDP.linkC,1) 293 | for g = 1:size(MDP.linkC,2) 294 | if MDP.linkC(f,g) 295 | linkC{f,g} = spm_speye(size(MDP.MDP.C{f},1),No(g),0); 296 | end 297 | end 298 | end 299 | MDP.linkC = linkC; 300 | end 301 | end 302 | 303 | % Empirical priors over policies 304 | %-------------------------------------------------------------------------- 305 | if isfield(MDP,'linkE') 306 | if isnumeric(MDP.linkE) 307 | linkE = cell(1,Ng); 308 | for g = 1:size(MDP.linkE,2) 309 | if MDP.linkE(g) 310 | linkE{g} = spm_speye(size(MDP.MDP.E,1),No(g),0); 311 | end 312 | end 313 | MDP.linkE = linkE; 314 | end 315 | end 316 | 317 | % check factors and outcome modalities have proper labels 318 | %-------------------------------------------------------------------------- 319 | for i = 1:Nf 320 | 321 | % name of factors 322 | %---------------------------------------------------------------------- 323 | try 324 | MDP.label.factor(i); 325 | catch 326 | try 327 | MDP.label.factor{i} = MDP.Bname{i}; 328 | catch 329 | MDP.label.factor{i} = sprintf('factor %i',i); 330 | end 331 | end 332 | 333 | % name of levels of each factor 334 | %---------------------------------------------------------------------- 335 | for j = 1:Ns(i) 336 | try 337 | MDP.label.name{i}(j); 338 | catch 339 | try 340 | MDP.label.name{i}{j} = MDP.Sname{i}{j}; 341 | catch 342 | MDP.label.name{i}{j} = sprintf('state %i(%i)',j,i); 343 | end 344 | end 345 | end 346 | 347 | % name of actions under each factor 348 | %---------------------------------------------------------------------- 349 | for j = 1:Nu(i) 350 | try 351 | MDP.label.action{i}(j); 352 | catch 353 | MDP.label.action{i}{j} = sprintf('act %i(%i)',j,i); 354 | end 355 | end 356 | end 357 | 358 | % name of outcomes under each modality 359 | %-------------------------------------------------------------------------- 360 | for i = 1:Ng 361 | try 362 | MDP.label.modality(i); 363 | catch 364 | try 365 | 
MDP.label.modality{i} = MDP.Bname{i}; 366 | catch 367 | MDP.label.modality{i} = sprintf('modality %i',i); 368 | end 369 | end 370 | for j = 1:No(i) 371 | try 372 | MDP.label.outcome{i}(j); 373 | catch 374 | try 375 | MDP.label.outcome{i}{j} = MDP.Oname{i}{j}; 376 | catch 377 | MDP.label.outcome{i}{j} = sprintf('outcome %i(%i)',j,i); 378 | end 379 | end 380 | end 381 | end 382 | 383 | % check names are specified properly 384 | %-------------------------------------------------------------------------- 385 | if isfield(MDP,'Aname') 386 | if numel(MDP.Aname) ~= Ng 387 | error('please specify an MDP.Aname for each modality') 388 | end 389 | else 390 | % MDP.Aname = MDP.label.modality; 391 | end 392 | if isfield(MDP,'Bname') 393 | if numel(MDP.Bname) ~= Nf 394 | error('please specify an MDP.Bname for each factor') 395 | end 396 | else 397 | % MDP.Bname = MDP.label.factor; 398 | end 399 | 400 | -------------------------------------------------------------------------------- /original_matlab_code/spm/spm_MDP_size.m: -------------------------------------------------------------------------------- 1 | function [Nf,Ns,Nu,Ng,No] = spm_MDP_size(mdp) 2 | % Dimensions of MDP 3 | % FORMAT [Nf,Ns,Nu,Ng,No] = spm_MDP_size(mdp) 4 | % Nf - number of factors 5 | % Ns - states per factor 6 | % Nu - control per factors 7 | % Ng - number of modalities 8 | % No - levels per modality 9 | %__________________________________________________________________________ 10 | 11 | % Karl Friston 12 | % Copyright (C) 2022-2023 Wellcome Centre for Human Neuroimaging 13 | 14 | 15 | % checks 16 | %-------------------------------------------------------------------------- 17 | if ~isfield(mdp,'a'), mdp.a = mdp.A; end 18 | if ~isfield(mdp,'b'), mdp.b = mdp.B; end 19 | 20 | % sizes of factors and modilities 21 | %-------------------------------------------------------------------------- 22 | Nf = numel(mdp.b); % number of hidden factors 23 | Ng = numel(mdp.a); % number of outcome modalities 24 | Ns = zeros(1,Nf); 25 | Nu = zeros(1,Nf); 26 | No = zeros(1,Ng); 27 | for f = 1:Nf 28 | Ns(f) = size(mdp.b{f},1); % number of hidden states 29 | Nu(f) = size(mdp.b{f},3); % number of hidden controls 30 | end 31 | for g = 1:Ng 32 | No(g) = size(mdp.a{g},1); % number of outcomes 33 | end 34 | -------------------------------------------------------------------------------- /original_matlab_code/spm/spm_axis.m: -------------------------------------------------------------------------------- 1 | function varargout = spm_axis(varargin) 2 | % AXIS Control axis scaling and appearance. 3 | 4 | if nargout 5 | [varargout{1:nargout}] = axis(varargin{:}); 6 | else 7 | try 8 | axis(varargin{:}); 9 | end 10 | end 11 | 12 | if nargin == 1 && any(strcmpi(varargin{1},{'tight','scale'})) 13 | spm_axis(gca,varargin{1}); 14 | elseif nargin == 2 && allAxes(varargin{1}) && strcmpi(varargin{2},'tight') 15 | for i = 1:numel(varargin{1}) 16 | lm = get(varargin{1}(i),'ylim'); 17 | if diff(lm) < 1e-12 18 | set(varargin{1}(i),'ylim',lm + [-1 1]); 19 | else 20 | set(varargin{1}(i),'ylim',lm + [-1 1]*diff(lm)/16); 21 | end 22 | end 23 | elseif nargin == 2 && allAxes(varargin{1}) && strcmpi(varargin{2},'scale') 24 | for i = 1:numel(varargin{1}) 25 | lm = get(varargin{1}(i),'ylim'); 26 | set(varargin{1}(i),'ylim',[0 lm(2)*(1 + 1/16)]); 27 | end 28 | end 29 | 30 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 31 | function result = allAxes(h) 32 | 33 | result = all(ishghandle(h)) && ... 
34 | length(findobj(h,'type','axes','-depth',0)) == length(h); 35 | -------------------------------------------------------------------------------- /original_matlab_code/spm/spm_cat.m: -------------------------------------------------------------------------------- 1 | function [x] = spm_cat(x,d) 2 | % Convert a cell array into a matrix - a compiled routine 3 | % FORMAT [x] = spm_cat(x,d) 4 | % x - cell array 5 | % d - dimension over which to concatenate [default - both] 6 | %__________________________________________________________________________ 7 | % Empty array elements are replaced by sparse zero partitions and single 0 8 | % entries are expanded to conform to the non-empty non zero elements. 9 | % 10 | % e.g.: 11 | % > x = spm_cat({eye(2) []; 0 [1 1; 1 1]}) 12 | % > full(x) = 13 | % 14 | % 1 0 0 0 15 | % 0 1 0 0 16 | % 0 0 1 1 17 | % 0 0 1 1 18 | % 19 | % If called with a dimension argument, a cell array is returned. 20 | %__________________________________________________________________________ 21 | 22 | % Karl Friston 23 | % Copyright (C) 2005-2022 Wellcome Centre for Human Neuroimaging 24 | 25 | 26 | %error('spm_cat.c not compiled - see Makefile') 27 | 28 | % check x is not already a matrix 29 | %-------------------------------------------------------------------------- 30 | if ~iscell(x), return, end 31 | 32 | % if concatenation over a specific dimension 33 | %-------------------------------------------------------------------------- 34 | [n,m] = size(x); 35 | if nargin > 1 36 | 37 | % concatenate over first dimension 38 | %---------------------------------------------------------------------- 39 | if d == 1 40 | y = cell(1,m); 41 | for i = 1:m 42 | y{i} = spm_cat(x(:,i)); 43 | end 44 | 45 | % concatenate over second 46 | %---------------------------------------------------------------------- 47 | elseif d == 2 48 | 49 | y = cell(n,1); 50 | for i = 1:n 51 | y{i} = spm_cat(x(i,:)); 52 | end 53 | 54 | % only viable for 2-D arrays 55 | %---------------------------------------------------------------------- 56 | else 57 | error('uknown option') 58 | end 59 | x = y; 60 | return 61 | 62 | end 63 | 64 | % find dimensions to fill in empty partitions 65 | %-------------------------------------------------------------------------- 66 | for i = 1:n 67 | for j = 1:m 68 | if iscell(x{i,j}) 69 | x{i,j} = spm_cat(x{i,j}); 70 | end 71 | [u,v] = size(x{i,j}); 72 | I(i,j) = u; 73 | J(i,j) = v; 74 | end 75 | end 76 | I = max(I,[],2); 77 | J = max(J,[],1); 78 | 79 | % sparse and empty partitions 80 | %-------------------------------------------------------------------------- 81 | [n,m] = size(x); 82 | for i = 1:n 83 | for j = 1:m 84 | if isempty(x{i,j}) 85 | x{i,j} = sparse(I(i),J(j)); 86 | end 87 | end 88 | end 89 | 90 | % concatenate 91 | %-------------------------------------------------------------------------- 92 | for i = 1:n 93 | y{i,1} = cat(2,x{i,:}); 94 | end 95 | try 96 | x = sparse(cat(1,y{:})); 97 | catch 98 | x = cat(1,y{:}); 99 | end 100 | -------------------------------------------------------------------------------- /original_matlab_code/spm/spm_combinations.m: -------------------------------------------------------------------------------- 1 | function U = spm_combinations(Nu) 2 | % FORMAT U = spm_combinations(Nu) 3 | % Nu - vector of dimensions 4 | % U - combinations of indices 5 | % 6 | % returns a matrix of all combinations of Nu 7 | %__________________________________________________________________________ 8 | 9 | % Karl Friston 10 | % Copyright (C) 2022-2023 Wellcome 
Centre for Human Neuroimaging 11 | 12 | 13 | Nf = numel(Nu); 14 | U = zeros(prod(Nu),Nf); 15 | for f = 1:Nf 16 | for j = 1:Nf 17 | if j == f 18 | k{j} = 1:Nu(j); 19 | else 20 | k{j} = ones(1,Nu(j)); 21 | end 22 | end 23 | u = 1; 24 | for i = 1:Nf 25 | u = kron(k{i},u); 26 | end 27 | 28 | % accumulate 29 | %---------------------------------------------------------------------- 30 | U(:,f) = u(:); 31 | 32 | end 33 | -------------------------------------------------------------------------------- /original_matlab_code/spm/spm_conv.m: -------------------------------------------------------------------------------- 1 | function [X] = spm_conv(X,sx,sy) 2 | % Gaussian convolution 3 | % FORMAT [X] = spm_conv(X,sx[,sy]) 4 | % X - matrix 5 | % sx - kernel width (FWHM) in pixels 6 | % sy - optional non-isomorphic smoothing 7 | %__________________________________________________________________________ 8 | % 9 | % spm_conv is a one or two dimensional convolution of a matrix variable in 10 | % working memory. It capitalizes on the sparsity structure of the problem 11 | % and the separablity of multidimensional convolution with a Gaussian 12 | % kernel by using one-dimensional convolutions and kernels that are 13 | % restricted to non near-zero values. 14 | %__________________________________________________________________________ 15 | 16 | % Karl Friston 17 | % Copyright (C) 1999-2022 Wellcome Centre for Human Neuroimaging 18 | 19 | 20 | % assume isomorphic smoothing 21 | %-------------------------------------------------------------------------- 22 | if nargin < 3; sy = sx; end 23 | sx = abs(sx); 24 | sy = abs(sy); 25 | [lx,ly] = size(X); 26 | 27 | % FWHM -> sigma 28 | %-------------------------------------------------------------------------- 29 | sx = sx/sqrt(8*log(2)) + eps; 30 | sy = sy/sqrt(8*log(2)) + eps; 31 | 32 | % kernels 33 | %-------------------------------------------------------------------------- 34 | Ex = min([fix(6*sx) lx]); 35 | x = -Ex:Ex; 36 | kx = exp(-x.^2/(2*sx^2)); 37 | kx = kx/sum(kx); 38 | Ey = min([fix(6*sy) ly]); 39 | y = -Ey:Ey; 40 | ky = exp(-y.^2/(2*sy^2)); 41 | ky = ky/sum(ky); 42 | 43 | % convolve 44 | %-------------------------------------------------------------------------- 45 | if lx > 1 && numel(kx) > 1 46 | for i = 1:ly 47 | u = X(:,i); 48 | v = [flipud(u(1:Ex)); u; flipud(u((1:Ex) + lx - Ex))]; 49 | X(:,i) = sparse(conv(full(v),kx,'valid')); 50 | end 51 | end 52 | if ly > 1 && numel(ky) > 1 53 | for i = 1:lx 54 | u = X(i,:); 55 | v = [fliplr(u(1:Ey)) u fliplr(u((1:Ey) + ly - Ey))]; 56 | X(i,:) = sparse(conv(full(v),ky,'valid')); 57 | end 58 | end 59 | 60 | return 61 | -------------------------------------------------------------------------------- /original_matlab_code/spm/spm_dir_norm.m: -------------------------------------------------------------------------------- 1 | function A = spm_dir_norm(A) 2 | % Normalisation of a (Dirichlet) conditional probability matrix 3 | % FORMAT A = spm_dir_norm(a) 4 | % 5 | % a - (Dirichlet) parameters of a conditional probability matrix 6 | % 7 | % A - conditional probability matrix 8 | %__________________________________________________________________________ 9 | 10 | % Karl Friston 11 | % Copyright (C) 2022 Wellcome Centre for Human Neuroimaging 12 | 13 | A = rdivide(A,sum(A,1)); 14 | A(isnan(A)) = 1/size(A,1); 15 | -------------------------------------------------------------------------------- /original_matlab_code/spm/spm_iwft.m: -------------------------------------------------------------------------------- 1 | 
function [s] = spm_iwft(C,k,n) 2 | % Inverse windowed Fourier transform - continuous synthesis 3 | % FORMAT [s] = spm_iwft(C,k,n); 4 | % s - 1-D time-series 5 | % k - Frequencies (cycles per window) 6 | % n - window length 7 | % C - coefficients (complex) 8 | %__________________________________________________________________________ 9 | 10 | % Karl Friston 11 | % Copyright (C) 2007-2022 Wellcome Centre for Human Neuroimaging 12 | 13 | 14 | % window function (Hanning) 15 | %-------------------------------------------------------------------------- 16 | N = size(C,2); 17 | s = zeros(1,N); 18 | C = conj(C); 19 | 20 | % spectral density 21 | %-------------------------------------------------------------------------- 22 | for i = 1:length(k) 23 | W = exp(-sqrt(-1)*(2*pi*k(i)*[0:(N - 1)]/n)); 24 | w = W.*C(i,:); 25 | s = s + real(w); 26 | end 27 | -------------------------------------------------------------------------------- /original_matlab_code/spm/spm_softmax.m: -------------------------------------------------------------------------------- 1 | function [y] = spm_softmax(x,k) 2 | % Softmax (e.g., neural transfer) function over columns 3 | % FORMAT [y] = spm_softmax(x,k) 4 | % 5 | % x - numeric array array 6 | % k - precision, sensitivity or inverse temperature (default k = 1) 7 | % 8 | % y = exp(k*x)/sum(exp(k*x)) 9 | % 10 | % NB: If supplied with a matrix this routine will return the softmax 11 | % function over columns - so that spm_softmax([x1,x2,..]) = [1,1,...] 12 | %__________________________________________________________________________ 13 | 14 | % Karl Friston 15 | % Copyright (C) 2010-2022 Wellcome Centre for Human Neuroimaging 16 | 17 | 18 | % apply 19 | %-------------------------------------------------------------------------- 20 | if nargin > 1, x = k*x; end 21 | if size(x,1) < 2; y = ones(size(x)); return, end 22 | 23 | % exponentiate and normalise 24 | %-------------------------------------------------------------------------- 25 | x = exp(minus(x,max(x))); 26 | y = rdivide(x,sum(x)); 27 | -------------------------------------------------------------------------------- /original_matlab_code/spm/spm_speye.m: -------------------------------------------------------------------------------- 1 | function [D] = spm_speye(m,n,k,c) 2 | % Sparse leading diagonal matrix 3 | % FORMAT [D] = spm_speye(m,n,k,c) 4 | % 5 | % returns an m x n matrix with ones along the k-th leading diagonal. If 6 | % called with an optional fourth argument c = 1, a wraparound sparse matrix 7 | % is returned. If c = 2, then empty rows or columns are filled in on the 8 | % leading diagonal. 
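% e.g. (illustrative): spm_speye(3,3) is a sparse 3 x 3 identity matrix, and full(spm_speye(4,4,-1)) has ones immediately below the main diagonal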
9 | %__________________________________________________________________________ 10 | 11 | % Karl Friston 12 | % Copyright (C) 2007-2022 Wellcome Centre for Human Neuroimaging 13 | 14 | 15 | % default k = 0 16 | %-------------------------------------------------------------------------- 17 | if nargin < 4, c = 0; end 18 | if nargin < 3, k = 0; end 19 | if nargin < 2, n = m; end 20 | 21 | % leading diagonal matrix 22 | %-------------------------------------------------------------------------- 23 | D = spdiags(ones(m,1),k,m,n); 24 | 25 | % add wraparound if necessary 26 | %-------------------------------------------------------------------------- 27 | if c == 1 28 | if k < 0 29 | D = D + spm_speye(m,n,min(n,m) + k); 30 | elseif k > 0 31 | D = D + spm_speye(m,n,k - min(n,m)); 32 | end 33 | elseif c == 2 34 | i = find(~any(D)); 35 | D = D + sparse(i,i,1,n,m); 36 | 37 | end 38 | -------------------------------------------------------------------------------- /original_matlab_code/spm/spm_vec.m: -------------------------------------------------------------------------------- 1 | function [vX] = spm_vec(X,varargin) 2 | % Vectorise a numeric, cell or structure array - a compiled routine 3 | % FORMAT [vX] = spm_vec(X) 4 | % X - numeric, cell or structure array[s] 5 | % vX - vec(X) 6 | % 7 | % See spm_unvec 8 | %__________________________________________________________________________ 9 | % 10 | % e.g.: 11 | % spm_vec({eye(2) 3}) = [1 0 0 1 3]' 12 | %__________________________________________________________________________ 13 | 14 | % Karl Friston 15 | % Copyright (C) 2005-2022 Wellcome Centre for Human Neuroimaging 16 | 17 | 18 | %error('spm_vec.c not compiled - see Makefile') 19 | 20 | % initialise X and vX 21 | %-------------------------------------------------------------------------- 22 | if nargin > 1 23 | X = [{X},varargin]; 24 | end 25 | 26 | 27 | % vectorise numerical arrays 28 | %-------------------------------------------------------------------------- 29 | if isnumeric(X) 30 | vX = X(:); 31 | 32 | % vectorise logical arrays 33 | %-------------------------------------------------------------------------- 34 | elseif islogical(X) 35 | vX = X(:); 36 | 37 | % vectorise structure into cell arrays 38 | %-------------------------------------------------------------------------- 39 | elseif isstruct(X) 40 | vX = []; 41 | f = fieldnames(X); 42 | X = X(:); 43 | for i = 1:numel(f) 44 | vX = cat(1,vX,spm_vec({X.(f{i})})); 45 | end 46 | 47 | % vectorise cells into numerical arrays 48 | %-------------------------------------------------------------------------- 49 | elseif iscell(X) 50 | vX = []; 51 | for i = 1:numel(X) 52 | vX = cat(1,vX,spm_vec(X{i})); 53 | end 54 | else 55 | vX = []; 56 | end 57 | -------------------------------------------------------------------------------- /original_matlab_code/spm/spm_wft.m: -------------------------------------------------------------------------------- 1 | function [C] = spm_wft(s,k,n) 2 | % Windowed fourier wavelet transform (time-frequency analysis) 3 | % FORMAT [C] = spm_wft(s,k,n) 4 | % s - (t X n) time-series 5 | % k - Frequencies (cycles per window) 6 | % n - window length 7 | % C - (w X t X n) coefficients (complex) 8 | %__________________________________________________________________________ 9 | 10 | % Karl Friston 11 | % Copyright (C) 2006-2022 Wellcome Centre for Human Neuroimaging 12 | 13 | 14 | % window function (Hanning) 15 | %-------------------------------------------------------------------------- 16 | [T,N] = size(s); 17 | n = round(n); 
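% (the next two lines build a unit-sum Hanning window of length n, used below to smooth the complex-demodulated signal)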
18 | h = 0.5*(1 - cos(2*pi*(1:n)/(n + 1))); 19 | h = h'/sum(h); 20 | C = zeros(length(k),T,N); 21 | 22 | 23 | % spectral density 24 | %-------------------------------------------------------------------------- 25 | for i = 1:length(k) 26 | W = exp(-1j*(2*pi*k(i)*(0:(T - 1))/n))'; 27 | for j = 1:N 28 | w = conv(full(s(:,j)).*W,h); 29 | C(i,:,j) = w((1:T) + n/2); 30 | end 31 | end 32 | -------------------------------------------------------------------------------- /original_matlab_code/spm_MDP_VB_ERP_tutorial.m: -------------------------------------------------------------------------------- 1 | function [x,y,ind,xx_yy] = spm_MDP_VB_ERP_tutorial(MDP,FACTOR,T) 2 | % auxiliary routine for hierarchical electrophysiological responses 3 | % FORMAT [x,y] = spm_MDP_VB_ERP(MDP,FACTOR,T) 4 | % 5 | % MDP - structure (see spm_MDP_VB) 6 | % FACTOR - the hidden factors (at the second alevel) to plot 7 | % T - flag to return cell of expectations (at time T; usually 1) 8 | % 9 | % x - simulated ERPs (high-level) 10 | % y - simulated ERPs (low level) 11 | % ind - indices or bins at the end of each (synchronised) epoch 12 | % 13 | % This routine combines first and second level hidden expectations by 14 | % synchronising them; such that first level updating is followed by an 15 | % epoch of second level updating - during which updating is suspended 16 | % (and expectations are held constant). The ensuing spike rates can be 17 | % regarded as showing delay period activity. In this routine, simulated 18 | % local field potentials are band pass filtered spike rates (between eight 19 | % and 32 Hz). 20 | % 21 | % Graphics are provided for first and second levels, in terms of simulated 22 | % spike rates (posterior expectations), which are then combined to show 23 | % simulated local field potentials for both levels (superimposed). 24 | % 25 | % At the lower level, only expectations about hidden states in the first 26 | % epoch are returned (because the number of epochs can differ from trial 27 | % to trial). 
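% e.g. (illustrative): spm_MDP_VB_ERP_tutorial(MDP,[2 1]) plots synchronised responses using the second hidden factor at the higher level and the first factor at the lower level; requesting more than two outputs (e.g., [x,y,ind,xx_yy] = ...) returns the responses without plotting.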
28 | %__________________________________________________________________________ 29 | % Copyright (C) 2005 Wellcome Trust Centre for Neuroimaging 30 | 31 | % Karl Friston 32 | % $Id: spm_MDP_VB_ERP.m 7382 2018-07-25 13:58:04Z karl $ 33 | 34 | 35 | % defaults: assume the first factor is of interest 36 | %========================================================================== 37 | try, f1 = FACTOR(1); catch, f1 = 1; end 38 | try, f2 = FACTOR(2); catch, f2 = 1; end 39 | 40 | % and T = 1 41 | %-------------------------------------------------------------------------- 42 | if nargin < 3, T = 1; end 43 | 44 | for m = 1:numel(MDP) 45 | 46 | % dimensions 47 | %---------------------------------------------------------------------- 48 | xn = MDP(m).xn{f1}; % neuronal responses 49 | Nb = size(xn,1); % number of time bins per epochs 50 | Nx = size(xn,2); % number of states 51 | Ne = size(xn,3); % number of epochs 52 | 53 | 54 | % expected hidden states 55 | %====================================================================== 56 | x = cell(Ne,Nx); 57 | y = cell(Ne); 58 | for k = 1:Ne 59 | for j = 1:Nx 60 | x{k,j} = xn(:,j,T,k); 61 | end 62 | if isfield(MDP,'mdp') 63 | y{k} = spm_MDP_VB_ERP_tutorial(MDP(m).mdp(k),f2,1); 64 | else 65 | y{k} = []; 66 | end 67 | end 68 | 69 | if nargin > 2, return, end 70 | 71 | % synchronise responses 72 | %---------------------------------------------------------------------- 73 | u = {}; 74 | v = {}; 75 | uu = spm_cat(x(1,:)); 76 | for k = 1:Ne 77 | 78 | % low-level 79 | %------------------------------------------------------------------ 80 | v{end + 1,1} = spm_cat(y{k}); 81 | if k > 1 82 | u{end + 1,1} = ones(size(v{end,:},1),1)*u{end,1}(end,:); 83 | else 84 | u{end + 1,1} = ones(size(v{end,:},1),1)*uu(1,:); 85 | end 86 | 87 | % time bin indices 88 | %------------------------------------------------------------------ 89 | ind(k) = size(u{end},1); 90 | 91 | % high-level 92 | %------------------------------------------------------------------ 93 | u{end + 1,1} = spm_cat(x(k,:)); 94 | v{end + 1,1} = ones(size(u{end,:},1),1)*v{end,1}(end,:); 95 | 96 | % time bin indices 97 | %------------------------------------------------------------------ 98 | ind(k) = ind(k) + size(u{end},1); 99 | 100 | end 101 | 102 | % accumulate over trials 103 | %---------------------------------------------------------------------- 104 | U{m,1} = u; 105 | V{m,1} = v; 106 | 107 | end 108 | 109 | % time bin (seconds) 110 | %-------------------------------------------------------------------------- 111 | u = spm_cat(U); 112 | v = spm_cat(V); 113 | dt = 1/64; 114 | t = (1:size(u,1))*dt; 115 | 116 | % bandpass filter between 8 and 32 Hz 117 | %-------------------------------------------------------------------------- 118 | c = 1/32; 119 | x = log(u + c); 120 | y = log(v + c); 121 | x = spm_conv(x,2,0) - spm_conv(x,16,0); 122 | y = spm_conv(y,2,0) - spm_conv(y,16,0); 123 | 124 | xx = x'; 125 | xx(end+1,:) = sum(xx,1); 126 | yy = y'; 127 | yy(end+1,:) = sum(yy,1); 128 | xx_yy = xx(end,:)+yy(end,:); 129 | 130 | if nargout > 2, return, end 131 | 132 | % simulated firing rates and the local field potentials 133 | %========================================================================== 134 | 135 | % higher-level unit responses 136 | %-------------------------------------------------------------------------- 137 | factor = MDP(1).label.factor{f1}; 138 | name = MDP(1).label.name{f1}; 139 | 140 | subplot(4,1,1), image(t,1:(size(u,2)),64*(1 - u')), ylabel('Unit') 141 | title(sprintf('Unit reponses : 
%s',factor),'FontSize',16) 142 | if numel(name) < 16 143 | grid on, set(gca,'YTick',1:numel(name)) 144 | set(gca,'YTickLabel',name) 145 | end 146 | 147 | % lower-level unit responses 148 | %-------------------------------------------------------------------------- 149 | factor = MDP(1).MDP(1).label.factor{f2}; 150 | name = MDP(1).MDP(1).label.name{f2}; 151 | 152 | subplot(4,1,3), image(t,1:(size(v,2)),64*(1 - v')), ylabel('Unit') 153 | title(sprintf('Unit reponses : %s',factor),'FontSize',16) 154 | if numel(factor) < 16 155 | grid on, set(gca,'YTick',1:numel(name)) 156 | set(gca,'YTickLabel',name) 157 | end 158 | 159 | % event related responses at both levels 160 | %-------------------------------------------------------------------------- 161 | % subplot(6,1,3), plot(t,x',t,y','-.') 162 | % title('Local field potentials','FontSize',16) 163 | % ylabel('Depolarisation'),spm_axis tight 164 | % grid on, set(gca,'XTick',(1:(length(t)/Nb))*Nb*dt) 165 | 166 | % event related responses summed 167 | %-------------------------------------------------------------------------- 168 | 169 | subplot(4,1,2), plot(t,xx(end,:)) 170 | title('Local field potentials (Level 2)','FontSize',16) 171 | ylabel('Depolarisation'),ylim([-.2 1.1]) %spm_axis tight 172 | grid on, %set(gca,'XTick',(1:(length(t)/Nb))*Nb*dt) 173 | 174 | subplot(4,1,4), plot(t,yy(end,:)) 175 | title('Local field potentials (Level 1)','FontSize',16) 176 | ylabel('Depolarisation'),ylim([-1 1]) %spm_axis tight 177 | grid on, %set(gca,'XTick',(1:(length(t)/Nb))*Nb*dt) 178 | 179 | 180 | -------------------------------------------------------------------------------- /original_matlab_code/spm_MDP_VB_game_tutorial.m: -------------------------------------------------------------------------------- 1 | function Q = spm_MDP_VB_game_tutorial(MDP) 2 | % auxiliary plotting routine for spm_MDP_VB - multiple trials 3 | % FORMAT Q = spm_MDP_VB_game(MDP) 4 | % 5 | % MDP.P(M,T) - probability of emitting action 1,...,M at time 1,...,T 6 | % MDP.Q(N,T) - an array of conditional (posterior) expectations over 7 | % N hidden states and time 1,...,T 8 | % MDP.X - and Bayesian model averages over policies 9 | % MDP.R - conditional expectations over policies 10 | % MDP.O(O,T) - a sparse matrix encoding outcomes at time 1,...,T 11 | % MDP.S(N,T) - a sparse matrix encoding states at time 1,...,T 12 | % MDP.U(M,T) - a sparse matrix encoding action at time 1,...,T 13 | % MDP.W(1,T) - posterior expectations of precision 14 | % 15 | % MDP.un = un - simulated neuronal encoding of hidden states 16 | % MDP.xn = Xn - simulated neuronal encoding of policies 17 | % MDP.wn = wn - simulated neuronal encoding of precision 18 | % MDP.da = dn - simulated dopamine responses (deconvolved) 19 | % MDP.rt = rt - simulated dopamine responses (deconvolved) 20 | % 21 | % returns summary of performance: 22 | % 23 | % Q.X = x - expected hidden states 24 | % Q.R = u - final policy expectations 25 | % Q.S = s - initial hidden states 26 | % Q.O = o - final outcomes 27 | % Q.p = p - performance 28 | % Q.q = q - reaction times 29 | % 30 | % please see spm_MDP_VB 31 | %__________________________________________________________________________ 32 | % Copyright (C) 2005 Wellcome Trust Centre for Neuroimaging 33 | 34 | % Karl Friston 35 | % $Id: spm_MDP_VB_game.m 7307 2018-05-08 09:44:04Z karl $ 36 | 37 | % numbers of transitions, policies and states 38 | %-------------------------------------------------------------------------- 39 | if iscell(MDP(1).X) 40 | Nf = numel(MDP(1).B); % number of hidden 
state factors 41 | Ng = numel(MDP(1).A); % number of outcome factors 42 | else 43 | Nf = 1; 44 | Ng = 1; 45 | end 46 | 47 | % graphics 48 | %========================================================================== 49 | Nt = length(MDP); % number of trials 50 | Ne = size(MDP(1).V,1) + 1; % number of epochs per trial 51 | Np = size(MDP(1).V,2) + 1; % number of policies 52 | for i = 1:Nt 53 | 54 | % assemble expectations of hidden states and outcomes 55 | %---------------------------------------------------------------------- 56 | for j = 1:Ne 57 | for k = 1:Ne 58 | for f = 1:Nf 59 | try 60 | x{f}{i,1}{k,j} = gradient(MDP(i).xn{f}(:,:,j,k)')'; 61 | catch 62 | x{f}{i,1}{k,j} = gradient(MDP(i).xn(:,:,j,k)')'; 63 | end 64 | end 65 | end 66 | end 67 | s(:,i) = MDP(i).s(:,2); 68 | o(:,i) = MDP(i).o(2,:)'; 69 | act_prob(:,i) = MDP(i).P(:,:,1)'; 70 | act(:,i) = MDP(i).u(2,1); 71 | w(:,i) = mean(MDP(i).dn,2); 72 | 73 | 74 | % assemble context learning 75 | %---------------------------------------------------------------------- 76 | for f = 1:Nf 77 | try 78 | try 79 | D = MDP(i).d{f}; 80 | catch 81 | D = MDP(i).D{f}; 82 | end 83 | catch 84 | try 85 | D = MDP(i).d; 86 | catch 87 | D = MDP(i).D; 88 | end 89 | end 90 | d{f}(:,i) = D/sum(D); 91 | end 92 | 93 | % assemble performance 94 | %---------------------------------------------------------------------- 95 | p(i) = 0; 96 | for g = 1:Ng 97 | try 98 | U = spm_softmax(MDP(i).C{g}); 99 | catch 100 | U = spm_softmax(MDP(i).C); 101 | end 102 | for t = 1:Ne 103 | p(i) = p(i) + log(U(MDP(i).o(g,t),t))/Ne; 104 | end 105 | end 106 | q(i) = sum(MDP(i).rt(2:end)); 107 | 108 | end 109 | 110 | % assemble output structure if required 111 | %-------------------------------------------------------------------------- 112 | if nargout 113 | Q.X = x; % expected hidden states 114 | Q.R = act_prob; % final policy expectations 115 | Q.S = s; % inital hidden states 116 | Q.O = o; % final outcomes 117 | Q.p = p; % performance 118 | Q.q = q; % reaction times 119 | return 120 | end 121 | 122 | 123 | % Initial states and expected policies (habit in red) 124 | %-------------------------------------------------------------------------- 125 | col = {'r.','g.','b.','c.','m.','k.'}; 126 | t = 1:Nt; 127 | subplot(5,1,1) 128 | if Nt < 64 129 | MarkerSize = 24; 130 | else 131 | MarkerSize = 16; 132 | end 133 | 134 | image(64*(1 - act_prob)), hold on 135 | 136 | plot(act,col{3},'MarkerSize',MarkerSize) 137 | 138 | try 139 | plot(Np*(1 - act_prob(Np,:)),'r') 140 | end 141 | try 142 | E = spm_softmax(spm_cat({MDP.e})); 143 | plot(Np*(1 - E(end,:)),'r:') 144 | end 145 | title('Action selection and action probabilities') 146 | xlabel('Trial'),ylabel('Action'), hold off 147 | yticklabels({'Start','Hint','Choose Left','Choose Right'}) 148 | % Performance 149 | %-------------------------------------------------------------------------- 150 | 151 | subplot(5,1,2), bar(p,'k'), hold on 152 | 153 | for i = 1:size(o,2) 154 | % j(i,1) = max(o(:,i)); 155 | if MDP(i).o(3,2) == 2 156 | j(i,1) = MDP(i).o(2,3)-1; 157 | else 158 | j(i,1) = MDP(i).o(2,2)-1; 159 | end 160 | if j(i,1) == 1 161 | jj(i,1) = 1; 162 | else 163 | jj(i,1) = -2; 164 | end 165 | end 166 | 167 | 168 | 169 | plot((j),col{2},'MarkerSize',MarkerSize); 170 | plot((jj),col{6},'MarkerSize',MarkerSize); 171 | 172 | 173 | title('Win/Loss and Free energies') 174 | ylabel('Value and Win/Loss'), spm_axis tight, hold off, box off 175 | set(gca,'YTick',[-4:1:3]) 176 | yticklabels({'','','','Free Energy','','Loss','Win'}) 177 | 178 | % Initial 
states (context) 179 | %-------------------------------------------------------------------------- 180 | subplot(5,1,3) 181 | col = {'r','b','g','c','m','k','r','b','g','c','m','k'}; 182 | for f = 1:Nf 183 | if Nf > 1 184 | plot(spm_cat(x{f}),col{f}), hold on 185 | else 186 | plot(spm_cat(x{f})) 187 | end 188 | end 189 | title('State estimation (ERPs)'), ylabel('Response'), 190 | spm_axis tight, hold off, box off 191 | 192 | % Precision (dopamine) 193 | %-------------------------------------------------------------------------- 194 | subplot(5,1,4) 195 | w = spm_vec(w); 196 | if Nt > 8 197 | fill([1 1:length(w) length(w)],[0; w.*(w > 0); 0],'k'), hold on 198 | fill([1 1:length(w) length(w)],[0; w.*(w < 0); 0],'k'), hold off 199 | else 200 | bar(w,1.1,'k') 201 | end 202 | title('Precision (dopamine)') 203 | ylabel('Precision','FontSize',12), spm_axis tight, box off 204 | YLim = get(gca,'YLim'); YLim(1) = 0; set(gca,'YLim',YLim); 205 | set(gca,'XTickLabel',{}); 206 | 207 | % learning - D 208 | %-------------------------------------------------------------------------- 209 | for f = 1 210 | subplot(5*Nf,1,Nf*4 + f), image(64*(1 - d{f})) 211 | if f < 2 212 | title('Context Learning') 213 | end 214 | set(gca,'XTick',1:Nt); 215 | % if f < Nf 216 | % set(gca,'XTickLabel',{}); 217 | % end 218 | % set(gca,'YTick',1); 219 | % try 220 | % set(gca,'YTickLabel',MDP(1).label.factor{f}); 221 | % end 222 | % try 223 | % set(gca,'YTickLabel',MDP(1).Bname{f}); 224 | % end 225 | 226 | yticklabels({'Left-Win','Right-Win'}) 227 | 228 | end 229 | -------------------------------------------------------------------------------- /spm/spm.py: -------------------------------------------------------------------------------- 1 | def spm_cat(x): 2 | result = {} 3 | for (i, j), values in x.items(): 4 | for k, value in enumerate(values): 5 | result[(j, k + i * len(values))] = [value] 6 | return result -------------------------------------------------------------------------------- /spm/spm_MDP_VB_LFP.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from scipy.signal import convolve 4 | 5 | def spm_MDP_VB_LFP(MDP, UNITS=None, f=1, SPECTRAL=0): 6 |     # check for simulated neuronal responses 7 |     if 'xn' not in MDP[0]: 8 |         print('please use another inversion scheme that simulates neuronal responses (e.g., spm_MDP_VB_XX)') 9 |         return 10 | 11 |     # defaults 12 |     if UNITS is None: 13 |         UNITS = [] 14 | 15 |     # dimensions 16 |     Nt = len(MDP)  # number of trials 17 |     try: 18 |         Ne = MDP[0]['xn'][f].shape[3]  # number of epochs 19 |         Nx = MDP[0]['B'][f].shape[0]  # number of states 20 |         Nb = MDP[0]['xn'][f].shape[0]  # number of time bins per epoch 21 |     except: 22 |         Ne = MDP[0]['xn'].shape[3] 23 |         Nx = MDP[0]['A'].shape[1] 24 |         Nb = MDP[0]['xn'].shape[0] 25 | 26 |     # units to plot 27 |     ALL = [] 28 |     for i in range(Ne): 29 |         for j in range(Nx): 30 |             ALL.append([j, i]) 31 |     if len(ALL) > 512: 32 |         ii = np.round(np.linspace(0, len(ALL) - 1, 512)).astype(int) 33 |         ALL = [ALL[i] for i in ii] 34 |     if not UNITS: 35 |         UNITS = ALL 36 |     ii = list(range(len(ALL))) 37 | 38 |     # summary statistics: firing rates 39 |     z = [] 40 |     v = [] 41 |     dn = [] 42 |     for i in range(Nt): 43 |         str_list = [] 44 |         try: 45 |             xn = MDP[i]['xn'][f] 46 |         except: 47 |             xn = MDP[i]['xn'] 48 |         z.append([[xn[:, ALL[j][0], ALL[j][1], k] for k in range(Ne)] for j in range(len(ALL))]) 49 |         v.append([[xn[:, UNITS[j][0], UNITS[j][1], k] for k in range(Ne)] for j in range(len(UNITS))]) 50 |         dn.append(np.mean(MDP[i]['dn'], axis=1)) 51 | 52 |     if len(dn) == 0: 53 |         return 54 | 55 |     # phase amplitude coupling 56 |     dt = 1 / 64  # time bin (seconds) 57 |     t = np.arange(1, Nb * Ne * Nt + 1) * dt  # time (seconds) 58 |     Hz = 
np.arange(4, 33)  # frequency range 59 |     n = 1 / (4 * dt)  # window length 60 |     w = Hz * (dt * n)  # cycles per window 61 | 62 |     # simulated firing rates 63 |     z = np.concatenate([np.concatenate(z[i], axis=1) for i in range(len(z))], axis=1).T 64 |     v = np.concatenate([np.concatenate(v[i], axis=1) for i in range(len(v))], axis=1).T 65 | 66 |     # bandpass filter log rates between 8 and 32 Hz: local field potential 67 |     c = 1 / 32 68 |     x = np.log(z.T + c) 69 |     u = np.log(v.T + c) 70 |     x = convolve(x, np.ones((2,)) / 2, mode='same') - convolve(x, np.ones((16,)) / 16, mode='same') 71 |     u = convolve(u, np.ones((2,)) / 2, mode='same') - convolve(u, np.ones((16,)) / 16, mode='same') 72 | 73 |     # plotting 74 |     fig, axs = plt.subplots(4, 1, figsize=(10, 15)) 75 |     axs[0].imshow(64 * (1 - z), aspect='auto', extent=[t[0], t[-1], 0, len(ii)]) 76 |     axs[0].set_title(MDP[0]['label']['factor'][f]) 77 |     axs[0].set_xlabel('time (sec)') 78 |     if len(str_list) < 16: 79 |         axs[0].grid(True) 80 |         axs[0].set_yticks(range(Ne * Nx)) 81 |         axs[0].set_yticklabels(str_list) 82 |     axs[0].grid(True) 83 |     axs[0].set_xticks(np.arange(1, Ne * Nt + 1) * Nb * dt) 84 |     if Ne * Nt > 32: 85 |         axs[0].set_xticklabels([]) 86 |     if Nt == 1: 87 |         axs[0].axis('square') 88 | 89 |     # time-frequency analysis and theta phase 90 |     wft = np.abs(np.fft.fft(x, n=int(n), axis=0)) 91 |     csd = np.sum(wft, axis=2) 92 |     lfp = np.sum(x, axis=1) 93 |     phi = np.angle(np.fft.ifft(np.sum(wft[0, :, :], axis=2), n=int(n))) 94 |     lfp = 4 * lfp / np.std(lfp) + 16 95 |     phi = 4 * phi / np.std(phi) + 16 96 | 97 |     axs[1].imshow(csd, aspect='auto', extent=[t[0], t[-1], Hz[0], Hz[-1]], origin='lower') 98 |     axs[1].plot(t, lfp, 'w:', t, phi, 'w') 99 |     axs[1].grid(True) 100 |     axs[1].set_xticks(np.arange(1, Ne * Nt + 1) * Nb * dt) 101 |     axs[1].set_title('Time-frequency response') 102 |     axs[1].set_xlabel('time (sec)') 103 |     axs[1].set_ylabel('frequency (Hz)') 104 |     if Nt == 1: 105 |         axs[1].axis('square') 106 | 107 |     # spectral responses 108 |     if SPECTRAL: 109 |         fig, axs = plt.subplots(4, 2, figsize=(10, 15)) 110 |         csd = np.sum(np.abs(wft), axis=1) 111 |         axs[0, 0].plot(Hz, np.log(csd)) 112 |         axs[0, 0].set_title('Spectral response') 113 |         axs[0, 0].set_xlabel('frequency (Hz)') 114 |         axs[0, 0].set_ylabel('log power') 115 |         axs[0, 0].axis('tight') 116 |         axs[0, 0].box(False) 117 |         axs[0, 0].axis('square') 118 | 119 |         cfc = 0 120 |         for i in range(wft.shape[2]): 121 |             cfc += np.corrcoef(np.abs(wft[:, :, i]).T) 122 |         axs[0, 1].imshow(cfc, aspect='auto', extent=[Hz[0], Hz[-1], Hz[0], Hz[-1]], origin='lower') 123 |         axs[0, 1].set_title('Cross-frequency coupling') 124 |         axs[0, 1].set_xlabel('frequency (Hz)') 125 |         axs[0, 1].set_ylabel('frequency (Hz)') 126 |         axs[0, 1].box(False) 127 |         axs[0, 1].axis('square') 128 | 129 |     # local field potentials 130 |     axs[2].plot(t, u) 131 |     axs[2].plot(t, x, ':') 132 |     axs[2].grid(True) 133 |     axs[2].set_xticks(np.arange(1, Ne * Nt + 1) * Nb * dt) 134 |     for i in range(2, Nt + 1, 2): 135 |         axs[2].axvspan((i - 1) * Ne * Nb * dt, i * Ne * Nb * dt, color='w', alpha=0.1) 136 |     axs[2].set_title('Local field potentials') 137 |     axs[2].set_xlabel('time (sec)') 138 |     axs[2].set_ylabel('response') 139 |     if Nt == 1: 140 |         axs[2].axis('square') 141 |     axs[2].box(False) 142 | 143 |     # firing rates 144 |     if Nt == 1: 145 |         axs[3].plot(t, v) 146 |         axs[3].plot(t, z, ':') 147 |         axs[3].grid(True) 148 |         axs[3].set_xticks(np.arange(1, Ne * Nt + 1) * Nb * dt) 149 |         axs[3].set_title('Firing rates') 150 |         axs[3].set_xlabel('time (sec)') 151 |         axs[3].set_ylabel('response') 152 |         axs[3].axis('square') 153 | 154 |     # simulated dopamine responses (if not a moving policy) 155 |     dn = np.concatenate(dn) 156 |     dn = dn * (dn > 0) 157 |     dn = dn + (dn + 1 / 16) * 
np.random.rand(len(dn)) / 8 158 | axs[3].bar(np.arange(len(dn)), dn, color='k') 159 | axs[3].set_title('Dopamine responses') 160 | axs[3].set_xlabel('time (updates)') 161 | axs[3].set_ylabel('change in precision') 162 | axs[3].axis('tight') 163 | axs[3].box(False) 164 | axs[3].set_ylim(bottom=0) 165 | if Nt == 1: 166 | axs[3].axis('square') 167 | 168 | # 模拟光栅 169 | if Nt == 1 and len(ii) < 129: 170 | fig, ax = plt.subplots(1, 1, figsize=(10, 5)) 171 | R = np.kron(z, np.ones((16, 16))) 172 | R = np.random.rand(*R.shape) > R * (1 - 1 / 16) 173 | ax.imshow(R, aspect='auto', extent=[t[0], t[-1], 0, Nx * Ne]) 174 | ax.set_title('Unit firing') 175 | ax.set_xlabel('time (sec)') 176 | ax.grid(True) 177 | ax.set_xticks(np.arange(1, Ne * Nt + 1) * Nb * dt) 178 | ax.set_yticks(range(Ne * Nx)) 179 | ax.set_yticklabels(str_list) 180 | ax.axis('square') 181 | 182 | plt.show() 183 | 184 | # 示例调用 185 | # MDP = [{'xn': ..., 'dn': ..., 'label': {'factor': ..., 'name': ...}, 'B': ..., 'A': ...}] 186 | # spm_MDP_VB_LFP(MDP) -------------------------------------------------------------------------------- /spm/spm_MDP_VB_trial.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from spm_MDP_check import spm_MDP_check 4 | 5 | def spm_MDP_VB_trial(MDP, gf=None, gg=None): 6 | """ 7 | Auxiliary plotting routine for spm_MDP_VB - single trial 8 | """ 9 | # Check MDP structure 10 | MDP = spm_MDP_check(MDP) 11 | plt.clf() 12 | 13 | # Number of transitions, policies and states 14 | if isinstance(MDP['X'], list): 15 | Nf = len(MDP['B']) # number of hidden state factors 16 | Ng = len(MDP['A']) # number of outcome factors 17 | X = MDP['X'] 18 | C = MDP['C'] 19 | Nu = [np.size(MDP['B'][f], 2) > 1 for f in range(Nf)] 20 | else: 21 | Nf = 1 22 | Ng = 1 23 | Nu = [1] 24 | X = [MDP['X']] 25 | C = [MDP['C']] 26 | 27 | # Factors and outcomes to plot 28 | maxg = 3 29 | if gf is None: 30 | gf = list(range(1, min(Nf, maxg) + 1)) 31 | if gg is None: 32 | gg = list(range(1, min(Ng, maxg) + 1)) 33 | nf = len(gf) 34 | ng = len(gg) 35 | 36 | # Posterior beliefs about hidden states 37 | for f in range(nf): 38 | plt.subplot(3 * nf, 2, (f) * 2 + 1) 39 | plt.imshow(64 * (1 - X[gf[f] - 1]), cmap='gray') 40 | if X[gf[f] - 1].shape[0] > 128: 41 | spm_spy(X[gf[f] - 1], 12, 1) 42 | a = plt.axis() 43 | if 's' in MDP: 44 | plt.plot(MDP['s'][gf[f] - 1, :], '.r', markersize=8) 45 | plt.axis(a) 46 | if f < 1: 47 | plt.title(f'Hidden states - {MDP["label"]["factor"][gf[f] - 1]}') 48 | else: 49 | plt.title(MDP["label"]["factor"][gf[f] - 1]) 50 | plt.gca().set_xticklabels([]) 51 | plt.gca().set_xticks(range(1, X[0].shape[1] + 1)) 52 | 53 | YTickLabel = MDP["label"]["name"][gf[f] - 1] 54 | if len(YTickLabel) > 8: 55 | i = np.linspace(1, len(YTickLabel), 8) 56 | YTickLabel = [YTickLabel[int(round(idx)) - 1] for idx in i] 57 | else: 58 | i = range(1, len(YTickLabel) + 1) 59 | plt.gca().set_yticks(i) 60 | plt.gca().set_yticklabels(YTickLabel) 61 | 62 | # Posterior beliefs about control states 63 | Nu = [i for i, val in enumerate(Nu) if val] 64 | Np = len(Nu) 65 | for f in range(Np): 66 | plt.subplot(3 * Np, 2, (f + 1) * 2) 67 | if isinstance(MDP['P'], list): 68 | P = MDP['P'][f] 69 | elif Nf > 1: 70 | ind = list(range(1, Nf + 1)) 71 | P = MDP['P'] 72 | for dim in range(Nf): 73 | if dim != ind[Nu[f]]: 74 | P = np.sum(P, axis=dim) 75 | P = np.squeeze(P) 76 | else: 77 | P = np.squeeze(MDP['P']) 78 | 79 | # Display 80 | plt.imshow(64 * (1 - P), cmap='gray') 81 | if 
'u' in MDP: 82 | plt.plot(MDP['u'][Nu[f], :], '.c', markersize=16) 83 | if f < 1: 84 | plt.title(f'Action - {MDP["label"]["factor"][Nu[f]]}') 85 | else: 86 | plt.title(MDP["label"]["factor"][Nu[f]]) 87 | plt.gca().set_xticklabels([]) 88 | plt.gca().set_xticks(range(1, X[0].shape[1] + 1)) 89 | 90 | YTickLabel = MDP["label"]["action"][Nu[f]] 91 | if len(YTickLabel) > 8: 92 | i = np.round(np.linspace(1, len(YTickLabel), 8)) 93 | YTickLabel = [YTickLabel[int(idx) - 1] for idx in i] 94 | else: 95 | i = range(1, len(YTickLabel) + 1) 96 | plt.gca().set_yticks(i) 97 | plt.gca().set_yticklabels(YTickLabel) 98 | 99 | # Policies 100 | plt.subplot(3 * Np, 2, (Np + f) * 2 + 1) 101 | plt.imshow(MDP['V'][:, :, Nu[f]].T, cmap='gray') 102 | if f < 1: 103 | plt.title(f'Allowable policies - {MDP["label"]["factor"][Nu[f]]}') 104 | else: 105 | plt.title(MDP["label"]["factor"][Nu[f]]) 106 | if f < Np - 1: 107 | plt.gca().set_xticklabels([]) 108 | plt.gca().set_xticks(range(1, X[0].shape[1])) 109 | 110 | # Expectations over policies 111 | if 'un' in MDP: 112 | plt.subplot(3, 2, 4) 113 | plt.imshow(64 * (1 - MDP['un']), cmap='gray') 114 | plt.title('Posterior probability') 115 | plt.ylabel('policy') 116 | plt.xlabel('updates') 117 | 118 | # Sample (observation) and preferences 119 | for g in range(ng): 120 | plt.subplot(3 * ng, 2, (2 * ng + g) * 2 + 1) 121 | c = C[gg[g] - 1] 122 | if c.shape[1] < MDP['o'].shape[1]: 123 | c = np.tile(c[:, 0], (1, MDP['o'].shape[1])) 124 | if c.shape[0] > 128: 125 | spm_spy(c, 16, 1) 126 | else: 127 | plt.imshow(1 - c, cmap='gray') 128 | plt.plot(MDP['o'][gg[g] - 1, :], '.c', markersize=16) 129 | if g < 1: 130 | plt.title(f'Outcomes and preferences - {MDP["label"]["modality"][gg[g] - 1]}') 131 | else: 132 | plt.title(MDP["label"]["modality"][gg[g] - 1]) 133 | if g == ng - 1: 134 | plt.xlabel('time') 135 | else: 136 | plt.gca().set_xticklabels([]) 137 | plt.gca().set_xticks(range(1, X[0].shape[1])) 138 | 139 | YTickLabel = MDP["label"]["outcome"][gg[g] - 1] 140 | if len(YTickLabel) > 8: 141 | i = np.round(np.linspace(1, len(YTickLabel), 8)) 142 | YTickLabel = [YTickLabel[int(idx) - 1] for idx in i] 143 | else: 144 | i = range(1, len(YTickLabel) + 1) 145 | plt.gca().set_yticks(i) 146 | plt.gca().set_yticklabels(YTickLabel) 147 | 148 | # Expected precision 149 | if 'dn' in MDP and 'wn' in MDP: 150 | if MDP['dn'].shape[1] > 0: 151 | plt.subplot(3, 2, 6) 152 | if MDP['dn'].shape[1] > 1: 153 | plt.plot(MDP['dn'], 'r:') 154 | plt.plot(MDP['wn'], 'c', linewidth=2) 155 | else: 156 | plt.bar(range(len(MDP['dn'])), MDP['dn'], 1.1, color='k') 157 | plt.plot(MDP['wn'], 'c', linewidth=2) 158 | plt.title('Expected precision (dopamine)') 159 | plt.xlabel('updates') 160 | plt.ylabel('precision') 161 | plt.tight_layout() 162 | plt.box(False) 163 | plt.draw() 164 | 165 | def spm_spy(matrix, threshold, size): 166 | # This function should visualize the matrix 167 | # Placeholder for the actual implementation 168 | pass -------------------------------------------------------------------------------- /spm/spm_MDP_check.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from spm_combination import spm_combinations 3 | from spm_dir_norm import spm_dir_norm 4 | from spm_MDP_size import spm_MDP_size 5 | from spm_speye import spm_speye 6 | 7 | 8 | def spm_MDP_check(MDP): 9 | """ 10 | MDP structure checking 11 | """ 12 | # deal with a sequence of trials 13 | if isinstance(MDP, list) and len(MDP) > 1: 14 | for m in range(len(MDP)): 15 | for i in 
range(len(MDP[m])): 16 | MDP[m][i] = spm_MDP_check(MDP[m][i]) 17 | return MDP 18 | 19 | # fill in (posterior or process) likelihood and priors 20 | if 'A' not in MDP: 21 | MDP['A'] = MDP.get('a', None) 22 | if 'B' not in MDP: 23 | MDP['B'] = MDP.get('b', None) 24 | 25 | # check format of likelihood and priors 26 | if not isinstance(MDP['A'], list): 27 | MDP['A'] = [np.array(MDP['A'])] 28 | if not isinstance(MDP['B'], list): 29 | MDP['B'] = [np.array(MDP['B'])] 30 | 31 | if 'a' in MDP and not isinstance(MDP['a'], list): 32 | MDP['a'] = [np.array(MDP['a'])] 33 | if 'b' in MDP and not isinstance(MDP['b'], list): 34 | MDP['b'] = [np.array(MDP['b'])] 35 | 36 | # check dimensions and orders 37 | Nf = len(MDP['B']) # number of hidden state factors 38 | NU = [] 39 | NS = [] 40 | for f in range(Nf): 41 | NU.append(MDP['B'][f].shape[2]) # number of hidden controls 42 | NS.append(MDP['B'][f].shape[0]) # number of hidden states 43 | MDP['B'][f] = MDP['B'][f].astype(float) 44 | MDP['B'][f] = spm_dir_norm(MDP['B'][f]) 45 | 46 | Ng = len(MDP['A']) # number of outcome factors 47 | No = [] 48 | for g in range(Ng): 49 | No.append(MDP['A'][g].shape[0]) # number of outcomes 50 | if MDP['A'][g].dtype != bool: # leave logical (one-hot) arrays untouched; dense arrays assumed 51 | MDP['A'][g] = MDP['A'][g].astype(float) 52 | if MDP['A'][g].dtype != bool: 53 | MDP['A'][g] = spm_dir_norm(MDP['A'][g]) 54 | 55 | # check sizes of Dirichlet parameterisation 56 | Nf, Ns, Nu = spm_MDP_size(MDP)[:3] # spm_MDP_size returns (Nf, Ns, Nu, Ng, No) 57 | 58 | # check policy specification (create default moving policy U, if necessary) 59 | if 'U' in MDP: 60 | if MDP['U'].shape[0] == 1 and MDP['U'].shape[2] == Nf: 61 | MDP['U'] = np.moveaxis(MDP['U'], 0, -1) 62 | try: 63 | V = np.expand_dims(MDP['U'], axis=0) 64 | except: 65 | try: 66 | V = MDP['V'] 67 | except: 68 | MDP['U'] = spm_combinations(Nu) 69 | V = np.expand_dims(MDP['U'], axis=0) 70 | MDP['V'] = V 71 | 72 | # check policy specification 73 | if Nf != V.shape[2] and V.shape[2] > 1: 74 | raise ValueError('please ensure V[:,:,1:Nf] is consistent with MDP.B{1:Nf}') 75 | 76 | # check preferences 77 | if 'C' not in MDP: 78 | MDP['C'] = [np.zeros((No[g], 1)) for g in range(Ng)] 79 | for g in range(Ng): 80 | if isinstance(MDP['C'], list): 81 | if MDP['C'][g].ndim == 1: 82 | MDP['C'][g] = MDP['C'][g].reshape(-1, 1) 83 | if No[g] != MDP['C'][g].shape[0]: 84 | raise ValueError(f'please ensure A[{g}] and C[{g}] are consistent') 85 | 86 | # check initial states 87 | if 'D' not in MDP: 88 | MDP['D'] = [np.ones((Ns[f], 1)) for f in range(Nf)] 89 | if Nf != len(MDP['D']): 90 | raise ValueError('please check MDP.D') 91 | for f in range(Nf): 92 | MDP['D'][f] = MDP['D'][f].reshape(-1, 1) 93 | 94 | # check initial controls 95 | # if 'E' not in MDP: 96 | # MDP['E'] = [np.ones((Nu[f], 1)) for f in range(Nf)] 97 | # if Nf != len(MDP['E']): 98 | # raise ValueError('please check MDP.E') 99 | # for f in range(Nf): 100 | # MDP['E'][f] = MDP['E'][f].reshape(-1, 1) 101 | 102 | # check initial states and internal consistency 103 | for f in range(Nf): 104 | if Ns[f] != MDP['D'][f].shape[0]: 105 | raise ValueError(f'please ensure B[{f}] and D[{f}] are consistent') 106 | if V.shape[2] > 1: 107 | if Nu[f] < np.max(V[:, :, f]): 108 | raise ValueError(f'please check V[:,:,{f}] or U[:,:,{f}]') 109 | for g in range(Ng): 110 | try: 111 | Na = MDP['a'][g].shape 112 | except: 113 | Na = MDP['A'][g].shape 114 | if list(Na[1:]) != list(Ns): 115 | raise ValueError(f'please ensure A[{g}] and D[{f}] are consistent') 116 | 117 | # check probability matrices are properly specified 118 | for f in
range(len(MDP['B'])): 119 | if not np.all(np.any(MDP['B'][f], axis=0)): 120 | raise ValueError(f'please check B[{f}] for missing entries') 121 | for g in range(len(MDP['A'])): 122 | if not np.all(np.any(MDP['A'][g], axis=0)): 123 | raise ValueError(f'please check A[{g}] for missing entries') 124 | 125 | # check initial states 126 | if 's' in MDP: 127 | if MDP['s'].shape[0] > len(MDP['B']): 128 | raise ValueError(f'please specify an initial state MDP.s for {Nf} factors') 129 | f = np.max(MDP['s'], axis=1) 130 | if np.any(f > NS[:len(f)]): 131 | raise ValueError('please ensure initial states MDP.s are consistent with MDP.B') 132 | 133 | # check outcomes if specified 134 | if 'o' in MDP: 135 | if len(MDP['o']): 136 | if MDP['o'].shape[0] != Ng: 137 | raise ValueError(f'please specify an outcomes MDP.o for {Ng} modalities') 138 | if np.any(np.max(MDP['o'], axis=1) > No): 139 | raise ValueError('please ensure # outcomes MDP.o are consistent with MDP.A') 140 | 141 | # check (primary link array if necessary) 142 | if 'link' in MDP: 143 | nf = len(MDP['MDP'][0]['B']) 144 | ns = [MDP['MDP'][0]['B'][f].shape[0] for f in range(nf)] 145 | if not all(np.array(MDP['link']).shape == [nf, Ng]): 146 | raise ValueError(f'please check the size of link [{nf},{Ng}]') 147 | if isinstance(MDP['link'], np.ndarray): 148 | link = [[None for _ in range(Ng)] for _ in range(nf)] 149 | for f in range(len(MDP['link'])): 150 | for g in range(len(MDP['link'][f])): 151 | if MDP['link'][f][g]: 152 | link[f][g] = spm_speye(ns[f], No[g], 0) 153 | MDP['link'] = link 154 | for f in range(len(MDP['link'])): 155 | for g in range(len(MDP['link'][f])): 156 | if MDP['link'][f][g] is not None: 157 | if not all(np.array(MDP['link'][f][g]).shape == [ns[f], No[g]]): 158 | raise ValueError(f'please check link[{f},{g}]') 159 | 160 | # Empirical prior preferences 161 | if 'linkC' in MDP: 162 | if isinstance(MDP['linkC'], np.ndarray): 163 | linkC = [[None for _ in range(Ng)] for _ in range(len(MDP['MDP']['C']))] 164 | for f in range(len(MDP['linkC'])): 165 | for g in range(len(MDP['linkC'][f])): 166 | if MDP['linkC'][f][g]: 167 | linkC[f][g] = spm_speye(MDP['MDP']['C'][f].shape[0], No[g], 0) 168 | MDP['linkC'] = linkC 169 | 170 | # Empirical priors over policies 171 | if 'linkE' in MDP: 172 | if isinstance(MDP['linkE'], np.ndarray): 173 | linkE = [None for _ in range(Ng)] 174 | for g in range(len(MDP['linkE'][0])): 175 | if MDP['linkE'][0][g]: 176 | linkE[g] = spm_speye(MDP['MDP']['E'].shape[0], No[g], 0) 177 | MDP['linkE'] = linkE 178 | 179 | # check factors and outcome modalities have proper labels 180 | for i in range(Nf): 181 | try: 182 | MDP['label']['factor'][i] 183 | except: 184 | try: 185 | MDP['label']['factor'][i] = MDP['Bname'][i] 186 | except: 187 | MDP['label']['factor'][i] = f'factor {i}' 188 | for j in range(NS[i]): 189 | try: 190 | MDP['label']['name'][i][j] 191 | except: 192 | try: 193 | MDP['label']['name'][i][j] = MDP['Sname'][i][j] 194 | except: 195 | MDP['label']['name'][i][j] = f'state {j}({i})' 196 | for j in range(Nu[i]): 197 | try: 198 | MDP['label']['action'][i][j] 199 | except: 200 | MDP['label']['action'][i][j] = f'act {j}({i})' 201 | 202 | for i in range(Ng): 203 | try: 204 | MDP['label']['modality'][i] 205 | except: 206 | try: 207 | MDP['label']['modality'][i] = MDP['Bname'][i] 208 | except: 209 | MDP['label']['modality'][i] = f'modality {i}' 210 | for j in range(No[i]): 211 | try: 212 | MDP['label']['outcome'][i][j] 213 | except: 214 | try: 215 | MDP['label']['outcome'][i][j] = MDP['Oname'][i][j] 216 | except: 
217 | MDP['label']['outcome'][i][j] = f'outcome {j}({i})' 218 | 219 | # check names are specified properly 220 | if 'Aname' in MDP: 221 | if len(MDP['Aname']) != Ng: 222 | raise ValueError('please specify an MDP.Aname for each modality') 223 | if 'Bname' in MDP: 224 | if len(MDP['Bname']) != Nf: 225 | raise ValueError('please specify an MDP.Bname for each factor') 226 | 227 | return MDP 228 | 229 | -------------------------------------------------------------------------------- /spm/spm_MDP_size.py: -------------------------------------------------------------------------------- 1 | def spm_MDP_size(mdp): 2 | """ 3 | Dimensions of MDP 4 | :param mdp: dictionary containing MDP parameters 5 | :return: tuple (Nf, Ns, Nu, Ng, No) 6 | Nf - number of factors 7 | Ns - states per factor 8 | Nu - control per factors 9 | Ng - number of modalities 10 | No - levels per modality 11 | """ 12 | 13 | # checks 14 | if 'a' not in mdp: 15 | mdp['a'] = mdp['A'] 16 | if 'b' not in mdp: 17 | mdp['b'] = mdp['B'] 18 | 19 | # sizes of factors and modalities 20 | Nf = len(mdp['b']) # number of hidden factors 21 | Ng = len(mdp['a']) # number of outcome modalities 22 | Ns = [0] * Nf 23 | Nu = [0] * Nf 24 | No = [0] * Ng 25 | 26 | for f in range(Nf): 27 | Ns[f] = mdp['b'][f].shape[0] # number of hidden states 28 | Nu[f] = mdp['b'][f].shape[2] # number of hidden controls 29 | 30 | for g in range(Ng): 31 | No[g] = mdp['a'][g].shape[0] # number of outcomes 32 | 33 | return Nf, Ns, Nu, Ng, No -------------------------------------------------------------------------------- /spm/spm_auxillary.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def spm_log(A): 4 | """ 5 | Log of numeric array plus a small constant. 6 | """ 7 | return np.log(A + 1e-16) 8 | 9 | def spm_norm(A): 10 | """ 11 | Normalization of a probability transition matrix (columns). 12 | """ 13 | A = A / np.sum(A, axis=0, keepdims=True) 14 | A[np.isnan(A)] = 1 / A.shape[0] 15 | return A 16 | 17 | def spm_wnorm(A): 18 | """ 19 | This function normalizes the input matrix A. 20 | It adds a small constant to A, then uses broadcasting to subtract the inverse of each column 21 | entry from the inverse of the sum of the columns and then divides by 2. 22 | """ 23 | A = A + np.exp(-16) 24 | A = (1.0 / np.sum(A, axis=0) - 1.0 / A) / 2.0 25 | return A 26 | 27 | def spm_ind2sub(siz, ndx): 28 | """ 29 | Subscripts from linear index. 
30 | """ 31 | n = len(siz) 32 | k = np.cumprod([1] + list(siz[:-1])) 33 | sub = np.zeros(n, dtype=int) 34 | for i in range(n-1, -1, -1): 35 | vi = (ndx - 1) % k[i] + 1 36 | vj = (ndx - vi) // k[i] + 1 37 | sub[i] = vj 38 | ndx = vi 39 | return sub -------------------------------------------------------------------------------- /spm/spm_axis.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | def spm_axis(*args): 5 | if len(args) == 0: 6 | raise ValueError("No arguments provided") 7 | 8 | if len(args) == 1 and args[0] in ['tight', 'scale']: 9 | spm_axis(plt.gca(), args[0]) 10 | elif len(args) == 2 and all_axes(args[0]) and args[1] == 'tight': 11 | for ax in args[0]: 12 | ylim = ax.get_ylim() 13 | if np.diff(ylim) < 1e-12: 14 | ax.set_ylim(ylim[0] - 1, ylim[1] + 1) 15 | else: 16 | ax.set_ylim(ylim[0] - np.diff(ylim) / 16, ylim[1] + np.diff(ylim) / 16) 17 | elif len(args) == 2 and all_axes(args[0]) and args[1] == 'scale': 18 | for ax in args[0]: 19 | ylim = ax.get_ylim() 20 | ax.set_ylim(0, ylim[1] * (1 + 1/16)) 21 | else: 22 | plt.axis(*args) 23 | 24 | def all_axes(handles): 25 | return all(isinstance(h, plt.Axes) for h in handles) and len(handles) == len([h for h in handles if isinstance(h, plt.Axes)]) -------------------------------------------------------------------------------- /spm/spm_combination.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def spm_combinations(Nu): 4 | """ 5 | Returns a matrix of all combinations of Nu. 6 | 7 | Parameters: 8 | Nu (list or array): Vector of dimensions 9 | 10 | Returns: 11 | np.ndarray: Combinations of indices 12 | """ 13 | Nf = len(Nu) 14 | U = np.zeros((np.prod(Nu), Nf), dtype=int) 15 | 16 | for f in range(Nf): 17 | k = [] 18 | for j in range(Nf): 19 | if j == f: 20 | k.append(np.arange(1, Nu[j] + 1)) 21 | else: 22 | k.append(np.ones(Nu[j], dtype=int)) 23 | 24 | u = np.array([1]) 25 | for i in range(Nf): 26 | u = np.kron(k[i], u) 27 | 28 | # accumulate 29 | U[:, f] = u.flatten() 30 | 31 | return U -------------------------------------------------------------------------------- /spm/spm_conv.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.ndimage import convolve1d 3 | 4 | def spm_conv(X, sx, sy=None): 5 | """ 6 | Gaussian convolution 7 | Parameters: 8 | X - matrix 9 | sx - kernel width (FWHM) in pixels 10 | sy - optional non-isomorphic smoothing 11 | """ 12 | if sy is None: 13 | sy = sx 14 | sx = abs(sx) 15 | sy = abs(sy) 16 | lx, ly = X.shape 17 | 18 | # FWHM -> sigma 19 | sx = sx / np.sqrt(8 * np.log(2)) + np.finfo(float).eps 20 | sy = sy / np.sqrt(8 * np.log(2)) + np.finfo(float).eps 21 | 22 | # kernels 23 | Ex = min(int(6 * sx), lx) 24 | x = np.arange(-Ex, Ex + 1) 25 | kx = np.exp(-x**2 / (2 * sx**2)) 26 | kx = kx / np.sum(kx) 27 | Ey = min(int(6 * sy), ly) 28 | y = np.arange(-Ey, Ey + 1) 29 | ky = np.exp(-y**2 / (2 * sy**2)) 30 | ky = ky / np.sum(ky) 31 | 32 | # convolve 33 | if lx > 1 and len(kx) > 1: 34 | for i in range(ly): 35 | u = X[:, i] 36 | v = np.concatenate((np.flipud(u[:Ex]), u, np.flipud(u[-Ex:]))) 37 | X[:, i] = convolve1d(v, kx, mode='constant', cval=0.0)[Ex:-Ex] 38 | 39 | if ly > 1 and len(ky) > 1: 40 | for i in range(lx): 41 | u = X[i, :] 42 | v = np.concatenate((np.fliplr([u[:Ey]])[0], u, np.fliplr([u[-Ey:]])[0])) 43 | X[i, :] = convolve1d(v, ky, mode='constant', cval=0.0)[Ey:-Ey] 44 | 45 
| return X -------------------------------------------------------------------------------- /spm/spm_dir_norm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def spm_dir_norm(A): 4 | """ 5 | Normalisation of a (Dirichlet) conditional probability matrix 6 | Parameters: 7 | A - (Dirichlet) parameters of a conditional probability matrix 8 | 9 | Returns: 10 | A - conditional probability matrix 11 | """ 12 | A = A / np.sum(A, axis=0, keepdims=True) # normalise over the first dimension (columns), as in the MATLAB sum(A,1) 13 | A[np.isnan(A)] = 1 / A.shape[0] 14 | return A -------------------------------------------------------------------------------- /spm/spm_iwft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def spm_iwft(C, k, n): 4 | """ 5 | Inverse windowed Fourier transform - continuous synthesis 6 | :param C: coefficients (complex) 7 | :param k: Frequencies (cycles per window) 8 | :param n: window length 9 | :return: 1-D time-series 10 | """ 11 | # window function (Hanning) 12 | N = C.shape[1] 13 | s = np.zeros(N) 14 | C = np.conj(C) 15 | 16 | # spectral density 17 | for i in range(len(k)): 18 | W = np.exp(-1j * (2 * np.pi * k[i] * np.arange(N) / n)) 19 | w = W * C[i, :] 20 | s += np.real(w) 21 | 22 | return s -------------------------------------------------------------------------------- /spm/spm_softmax.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def spm_softmax(x, k=1): 4 | """ 5 | Softmax (e.g., neural transfer) function over columns 6 | 7 | Parameters: 8 | x - numeric array 9 | k - precision, sensitivity or inverse temperature (default k = 1) 10 | 11 | Returns: 12 | y - softmax values 13 | """ 14 | # Apply precision, sensitivity or inverse temperature 15 | if k != 1: 16 | x = k * x 17 | 18 | # If input has less than 2 rows, return an array of ones 19 | if x.shape[0] < 2: 20 | return np.ones_like(x) 21 | 22 | # Exponentiate and normalize 23 | x = np.exp(x - np.max(x, axis=0)) 24 | y = x / np.sum(x, axis=0) 25 | 26 | return y -------------------------------------------------------------------------------- /spm/spm_speye.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.sparse import diags, csr_matrix 3 | 4 | def spm_speye(m, n=None, k=0, c=0): 5 | """ 6 | Sparse leading diagonal matrix 7 | 8 | Returns an m x n matrix with ones along the k-th leading diagonal. If 9 | called with an optional fourth argument c = 1, a wraparound sparse matrix 10 | is returned. If c = 2, then empty rows or columns are filled in on the 11 | leading diagonal. 12 | """ 13 | if n is None: 14 | n = m 15 | 16 | # leading diagonal matrix 17 | D = diags([1] * m, k, shape=(m, n), format='csr') 18 | 19 | # add wraparound if necessary 20 | if c == 1: 21 | if k < 0: 22 | D = D + spm_speye(m, n, min(n, m) + k) 23 | elif k > 0: 24 | D = D + spm_speye(m, n, k - min(n, m)) 25 | elif c == 2: 26 | i = np.where(~D.toarray().any(axis=0))[0] 27 | D = D + csr_matrix((np.ones(len(i)), (i, i)), shape=(m, n)) 28 | 29 | return D 30 | 31 | # Example usage 32 | # D = spm_speye(5, 5, 0, 2) 33 | # print(D.toarray()) -------------------------------------------------------------------------------- /spm/spm_vec.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def spm_vec(X, *args): 4 | """ 5 | Vectorise a numeric, list or dictionary array.
6 | FORMAT [vX] = spm_vec(X) 7 | X - numeric, list or dictionary array[s] 8 | vX - vec(X) 9 | 10 | See spm_unvec 11 | 12 | e.g.: 13 | spm_vec([np.eye(2), 3]) = [1, 0, 0, 1, 3] 14 | """ 15 | 16 | # Initialise X and vX 17 | if args: 18 | X = [X] + list(args) 19 | 20 | # Vectorise numerical arrays 21 | if isinstance(X, (np.ndarray, int, float)): 22 | vX = np.array(X).flatten() 23 | 24 | # Vectorise logical arrays 25 | elif isinstance(X, (np.bool_, bool)): 26 | vX = np.array(X).flatten() 27 | 28 | # Vectorise dictionary into list arrays 29 | elif isinstance(X, dict): 30 | vX = [] 31 | for key in X: 32 | vX = np.concatenate((vX, spm_vec(X[key]))) 33 | 34 | # Vectorise lists into numerical arrays 35 | elif isinstance(X, list): 36 | vX = [] 37 | for item in X: 38 | vX = np.concatenate((vX, spm_vec(item))) 39 | 40 | else: 41 | vX = np.array([]) 42 | 43 | return vX -------------------------------------------------------------------------------- /spm/spm_wft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def spm_wft(s, k, n): 4 | """ 5 | Windowed Fourier wavelet transform (time-frequency analysis) 6 | 7 | Parameters: 8 | s (ndarray): (t X n) time-series 9 | k (ndarray): Frequencies (cycles per window) 10 | n (int): window length 11 | 12 | Returns: 13 | ndarray: (w X t X n) coefficients (complex) 14 | """ 15 | 16 | # Window function (Hanning) 17 | T, N = s.shape 18 | n = round(n) 19 | h = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, n + 1) / (n + 1))) 20 | h = h / np.sum(h) 21 | C = np.zeros((len(k), T, N), dtype=complex) 22 | 23 | # Spectral density 24 | for i in range(len(k)): 25 | W = np.exp(-1j * (2 * np.pi * k[i] * np.arange(T) / n)) 26 | for j in range(N): 27 | w = np.convolve(s[:, j] * W, h, mode='same') 28 | C[i, :, j] = w 29 | 30 | return C -------------------------------------------------------------------------------- /utility/math_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def nat_log(x): 4 | return np.log(x + np.exp(-16)) 5 | 6 | def dic2mat(dic): 7 | # get the maximum row and column indices of the matrix 8 | max_row = max(row for row, col in dic.keys()) + 1 9 | max_col = max(col for row, col in dic.keys()) + 1 10 | 11 | # initialise the full matrix 12 | matrix = np.zeros((max_row, max_col), dtype=float) 13 | 14 | # place each value at its correct position in the full matrix 15 | for (row, col), value in dic.items(): 16 | if isinstance(value[0], list): 17 | matrix[row, col] = float(value[0][0]) 18 | else: 19 | matrix[row, col] = float(value[0]) 20 | 21 | return matrix --------------------------------------------------------------------------------
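A minimal sanity check for the column-normalisation helpers above; it assumes the spm directory is on the Python path so the modules can be imported by bare name, which is the same convention the modules themselves use:

import numpy as np
from spm_dir_norm import spm_dir_norm
from spm_softmax import spm_softmax

# toy Dirichlet counts: one column of concentration parameters per hidden state
a = np.array([[0.25, 1.0],
              [0.75, 1.0]])

A = spm_dir_norm(a)                              # likelihood matrix: each column sums to one
print(A.sum(axis=0))                             # -> [1. 1.]

# column-wise softmax of the log-probabilities should recover A itself
print(np.allclose(spm_softmax(np.log(A)), A))    # -> True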
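A small illustration of spm_wnorm from spm_auxillary.py, which returns half the difference between the inverse column sums and the inverse counts of a Dirichlet matrix. The toy values below are illustrative only; they show that the term shrinks as concentration parameters grow, which is what makes a well-learned likelihood mapping less novel:

import numpy as np
from spm_auxillary import spm_wnorm

# the same likelihood proportions at two confidence levels (10x the evidence)
a_small = np.array([[0.25, 1.0],
                    [0.75, 1.0]])
a_large = 10 * a_small

print(np.abs(spm_wnorm(a_small)).max())   # larger magnitude: little evidence accumulated
print(np.abs(spm_wnorm(a_large)).max())   # roughly ten times smaller: mapping is well learned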
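A quick check of spm_wft on a synthetic one-channel signal; the window length and candidate frequencies are arbitrary choices for illustration. A sinusoid with a 16-sample period completes 4 cycles per 64-sample window, so the time-averaged power should peak at k = 4:

import numpy as np
from spm_wft import spm_wft

T, n = 512, 64                                   # number of samples and window length
t = np.arange(T)
s = np.sin(2 * np.pi * t / 16).reshape(-1, 1)    # period 16 samples = 4 cycles per window

k = np.arange(1, 9)                              # candidate frequencies (cycles per window)
C = spm_wft(s, k, n)                             # complex coefficients, shape (len(k), T, 1)
power = np.abs(C[:, :, 0]).mean(axis=1)          # time-averaged power per frequency

print(k[np.argmax(power)])                       # expected: 4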
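A usage sketch for the helpers in utility/math_utils.py, assuming that directory is on the Python path; the dictionary layout (values wrapped in lists, keyed by (row, col) tuples) follows what dic2mat expects:

import numpy as np
from math_utils import dic2mat, nat_log

d = {(0, 0): [1.0], (0, 1): [0.5], (1, 1): [[2.0]]}
M = dic2mat(d)
print(M)               # 2 x 2 matrix with zeros wherever no key was supplied
print(nat_log(0.0))    # approximately -16 (log with a small floor instead of -inf)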