├── .gitignore ├── DM.png ├── HM.png ├── .DS_Store ├── Agent.png ├── HDM.png ├── HMM.png ├── Buckley.png ├── Friston.png ├── .gitattributes ├── Bogacz Free Energy Tutorial ├── Bogacz_fig10.png ├── Bogacz_fig3.png ├── Bogacz_fig4.png ├── Bogacz_fig5.png ├── Bogacz_fig6.png ├── Bogacz_fig7.png ├── Bogacz_fig9.png ├── Bogacz_tab1.png ├── Bogacz - 2017 - A tutorial on the free-energy framework for modelling perception and learning.pdf └── Untitled.ipynb ├── README.md ├── Free Energy.ipynb ├── ELBO.ipynb ├── Generalised precision matrix.ipynb ├── Slides.ipynb └── DEM.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *checkpoint.ipynb 3 | -------------------------------------------------------------------------------- /DM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vschaik/Active-Inference/HEAD/DM.png -------------------------------------------------------------------------------- /HM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vschaik/Active-Inference/HEAD/HM.png -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vschaik/Active-Inference/HEAD/.DS_Store -------------------------------------------------------------------------------- /Agent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vschaik/Active-Inference/HEAD/Agent.png -------------------------------------------------------------------------------- /HDM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vschaik/Active-Inference/HEAD/HDM.png 
-------------------------------------------------------------------------------- /HMM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vschaik/Active-Inference/HEAD/HMM.png -------------------------------------------------------------------------------- /Buckley.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vschaik/Active-Inference/HEAD/Buckley.png -------------------------------------------------------------------------------- /Friston.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vschaik/Active-Inference/HEAD/Friston.png -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /Bogacz Free Energy Tutorial/Bogacz_fig10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vschaik/Active-Inference/HEAD/Bogacz Free Energy Tutorial/Bogacz_fig10.png -------------------------------------------------------------------------------- /Bogacz Free Energy Tutorial/Bogacz_fig3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vschaik/Active-Inference/HEAD/Bogacz Free Energy Tutorial/Bogacz_fig3.png -------------------------------------------------------------------------------- /Bogacz Free Energy Tutorial/Bogacz_fig4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vschaik/Active-Inference/HEAD/Bogacz Free Energy Tutorial/Bogacz_fig4.png 
-------------------------------------------------------------------------------- /Bogacz Free Energy Tutorial/Bogacz_fig5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vschaik/Active-Inference/HEAD/Bogacz Free Energy Tutorial/Bogacz_fig5.png -------------------------------------------------------------------------------- /Bogacz Free Energy Tutorial/Bogacz_fig6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vschaik/Active-Inference/HEAD/Bogacz Free Energy Tutorial/Bogacz_fig6.png -------------------------------------------------------------------------------- /Bogacz Free Energy Tutorial/Bogacz_fig7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vschaik/Active-Inference/HEAD/Bogacz Free Energy Tutorial/Bogacz_fig7.png -------------------------------------------------------------------------------- /Bogacz Free Energy Tutorial/Bogacz_fig9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vschaik/Active-Inference/HEAD/Bogacz Free Energy Tutorial/Bogacz_fig9.png -------------------------------------------------------------------------------- /Bogacz Free Energy Tutorial/Bogacz_tab1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vschaik/Active-Inference/HEAD/Bogacz Free Energy Tutorial/Bogacz_tab1.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Active Inference 2 | 3 | This notebook implements a minimal example of an Agent using Karl Friston's Active Inference and Free Energy concepts. The implementation only depends on Python 2.7, numpy and matplotlib. 
4 | 5 | André van Schaik -------------------------------------------------------------------------------- /Bogacz Free Energy Tutorial/Bogacz - 2017 - A tutorial on the free-energy framework for modelling perception and learning.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vschaik/Active-Inference/HEAD/Bogacz Free Energy Tutorial/Bogacz - 2017 - A tutorial on the free-energy framework for modelling perception and learning.pdf -------------------------------------------------------------------------------- /Free Energy.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# The Free Energy Principle.\n", 8 | "\n", 9 | "## By André van Schaik\n", 10 | "\n", 11 | "### __[International Centre for Neuromorphic Systems](https://westernsydney.edu.au/icns)__ 25/12/2018" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "This notebook contains some derivations in regards to Karl Friston's Free Energy Principle. It's a bit of a scratch pad." 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "To optimally estimate the hidden state of its environment, $\\theta$, from its noisy sensory input, $\\phi$, an agent would like to calculate the Bayesian posterior probability density $p(\\theta|\\phi)$. However, in all but the simplest cases, this might not be directly calculable. 
*Variational Bayes* suggests a workaround by minimising the Kullback-Leibler divergence between what the agent believes the state of its environment is (encoded in a Recognition density $q(\\theta)$) and the true Bayesian posterior.\n", 26 | "\n", 27 | "\\begin{align*}\n", 28 | "D_{KL}(\\: q(\\theta) \\; || \\; p(\\theta|\\phi) \\: ) = \\int{q(\\theta) \\: ln \\frac{q(\\theta)}{p(\\theta|\\phi)} \\: d\\theta}\n", 29 | "\\end{align*}\n", 30 | "\n", 31 | "The KL divergence is a measure of the difference between two probability distributions. It is always non-negative, and it is 0 if and only if the two distributions are the same. Thus, adapting $q(\\theta)$ to minimise this KL divergence will result in $q(\\theta)$ being a close approximation of $p(\\theta|\\phi)$. " 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "Evaluating this KL divergence directly would still require $p(\\theta|\\phi)$, so at first sight we have made no progress. However, the Free Energy Principle (FEP) uses the fact that $p(\\theta,\\phi) = p(\\theta|\\phi)p(\\phi)$ to rewrite it as:\n", 39 | "\n", 40 | "\\begin{align*}\n", 41 | "D_{KL}(\\: q(\\theta) \\; || \\; p(\\theta|\\phi) \\: ) &= \\int{q(\\theta) \\: ln \\frac{q(\\theta)}{p(\\theta,\\phi)/p(\\phi)} \\: d\\theta} \\\\\n", 42 | "&= \\int{q(\\theta) \\: \\{ ln\\:q(\\theta) - ln\\:p(\\theta,\\phi) + ln\\:p(\\phi) \\} \\: d\\theta} \\\\\n", 43 | "&= \\int{q(\\theta) \\: ln \\frac{q(\\theta)}{p(\\theta,\\phi)} \\: d\\theta} + \\int{q(\\theta) \\: ln\\:p(\\phi) \\: d\\theta} \\\\\n", 44 | "&= \\int{q(\\theta) \\: ln \\frac{q(\\theta)}{p(\\theta,\\phi)} \\: d\\theta} + ln\\:p(\\phi) \\int{q(\\theta) \\: d\\theta} \\\\\n", 45 | "&= \\int{q(\\theta) \\: ln \\frac{q(\\theta)}{p(\\theta,\\phi)} \\: d\\theta} + ln\\:p(\\phi) \\\\\n", 46 | "\\end{align*}\n", 47 | "\n", 48 | "since $\\int{q(\\theta) \\: d\\theta} = 1$ by definition of a probability density. 
We continue by writing:\n", 49 | "\n", 50 | "\\begin{align*}\n", 51 | "D_{KL}(\\: q(\\theta) \\; || \\; p(\\theta|\\phi) \\: ) = F + ln\\:p(\\phi)\\\\\n", 52 | "F = \\int{q(\\theta) \\: ln \\frac{q(\\theta)}{p(\\theta,\\phi)} \\: d\\theta} \\\\\n", 53 | "\\end{align*}\n", 54 | "\n", 55 | "The joint density $p(\\theta,\\phi)$ is called the generative density, and represents the agent's belief in how the world works. It can be factorised into $p(\\theta,\\phi) = p(\\phi,\\theta) = p(\\phi|\\theta)\\:p(\\theta)$ where a prior $p(\\theta)$ encodes the agent's beliefs about the world states before new sensory input arrives, and a likelihood $p(\\phi|\\theta)$ encodes how the agent's sensory signals relate to the world states. Thus, if we have a model for how the world states generate sensory perception (or if we can learn one), we can calculate $F$, which is called the *Variational Free Energy*. It has the form of a KL divergence between the Recognition density, $q(\\theta)$, and the Generative density, $p(\\theta, \\phi)$, although the latter is not normalised over $\\theta$. We probably don't know $p(\\phi)$, but since it does not depend on $\\theta$, it plays no role in optimising $q(\\theta)$: minimising $F$ with respect to $q(\\theta)$ also minimises the KL divergence to the posterior. Moreover, because the KL divergence is non-negative, $F \\geq -ln\\:p(\\phi)$, so $F$ is an upper bound on the surprise $-ln\\:p(\\phi)$. " 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "\\begin{align*}\n", 63 | "F &= \\int{q(\\theta) \\: ln \\frac{q(\\theta)}{p(\\theta,\\phi)} \\: d\\theta} \\\\\n", 64 | "&= \\int{q(\\theta) \\: ln \\: q(\\theta) \\: d\\theta} - \\int{q(\\theta) \\: ln \\: p(\\theta,\\phi) \\: d\\theta} \\\\\n", 65 | "&= \\int{q(\\theta) \\: ln \\: q(\\theta) \\: d\\theta} + \\int{q(\\theta) \\: E(\\theta,\\phi) \\: d\\theta} \\\\\n", 66 | "\\end{align*}\n", 67 | "\n", 68 | "with\n", 69 | "\n", 70 | "\\begin{align*}\n", 71 | "E(\\theta,\\phi) = -ln \\: p(\\theta,\\phi)\n", 72 | "\\end{align*}\n", 73 | "\n", 74 | "The first term in $F$ is the negative entropy of $q$. The second is the average of $E$ over $q$, which is called the *average energy*." 
75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "## References\n", 82 | "\n", 83 | "Dempster, A.P., Laird, N.M. and Rubin, D.B. (1977). [Maximum likelihood from incomplete data via the EM algorithm](https://www.jstor.org/stable/2984875). Journal of the Royal Statistical Society. Series B (Methodological), 39(1), 1-38.\n", 84 | "\n", 85 | "Neal, R.M. and Hinton, G.E. (1998). [A view of the EM algorithm that justifies incremental, sparse, and other variants](http://www.cs.toronto.edu/~fritz/absps/emk.pdf). In: Jordan, M.I. (ed.) Learning in Graphical Models. NATO ASI Series (Series D: Behavioural and Social Sciences), vol 89. Springer, Dordrecht.\n", 86 | "\n", 87 | "Rao, R.P. and Ballard, D.H. (1999). [Predictive coding in the visual cortex: a functional\n", 88 | "interpretation of some extra-classical receptive-field effects](https://www.researchgate.net/publication/13103385_Predictive_Coding_in_the_Visual_Cortex_a_Functional_Interpretation_of_Some_Extra-classical_Receptive-field_Effects). Nature Neuroscience,\n", 89 | "2, 79–87.\n", 90 | "\n", 91 | "Buckley, C.L., Kim, C.S., McGregor, S. and Seth, A.K. (2017). [The free energy principle for action and perception: a mathematical review](https://www.sciencedirect.com/science/article/pii/S0022249617300962). 
Journal of Mathematical Psychology, 81, 55-79.\n" 92 | ] 93 | } 94 | ], 95 | "metadata": { 96 | "hide_input": false, 97 | "kernelspec": { 98 | "display_name": "Python 2", 99 | "language": "python", 100 | "name": "python2" 101 | }, 102 | "language_info": { 103 | "codemirror_mode": { 104 | "name": "ipython", 105 | "version": 2 106 | }, 107 | "file_extension": ".py", 108 | "mimetype": "text/x-python", 109 | "name": "python", 110 | "nbconvert_exporter": "python", 111 | "pygments_lexer": "ipython2", 112 | "version": "2.7.15" 113 | } 114 | }, 115 | "nbformat": 4, 116 | "nbformat_minor": 2 117 | } 118 | -------------------------------------------------------------------------------- /Bogacz Free Energy Tutorial/Untitled.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%matplotlib notebook\n", 10 | "import numpy as np\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "import ipywidgets as ipw\n", 13 | "import matplotlib" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "from IPython.display import display\n", 23 | "\n", 24 | "def fig_2g(CU):\n", 25 | " global state, params\n", 26 | " \n", 27 | " def g(v):\n", 28 | " return v\n", 29 | "\n", 30 | " def dg_dphi(v):\n", 31 | " return 1\n", 32 | "\n", 33 | " def reset_all():\n", 34 | " global state, params\n", 35 | " phi = mu_p\n", 36 | " e_p = 0\n", 37 | " e_u = 0\n", 38 | " v_p = mu_p\n", 39 | " theta = Theta\n", 40 | " r_p = R_p\n", 41 | " r_u = R_u\n", 42 | " state = np.asarray([phi, e_p, e_u], dtype=float)\n", 43 | " params = np.asarray([v_p, theta, r_p, r_u], dtype=float)\n", 44 | "\n", 45 | " def draw_fig2g(tau, w_params, reset, cont):\n", 46 | " global state, params\n", 47 | "\n", 48 | " def dphi_dt(phi, e_p, e_u, v_p, theta, r_p, r_u, u, v):\n", 49 | 
" return theta * e_u * dg_dphi(phi) - e_p\n", 50 | "\n", 51 | " def ep(phi, e_p, e_u, v_p, theta, r_p, r_u, u, v):\n", 52 | " return (phi - v_p) * r_p\n", 53 | " \n", 54 | " def eu(phi, e_p, e_u, v_p, theta, r_p, r_u, u, v):\n", 55 | " return (u - theta * g(phi)) * r_u\n", 56 | " \n", 57 | " def dvp(phi, e_p, e_u, v_p, theta, r_p, r_u, u, v):\n", 58 | " return e_p\n", 59 | "\n", 60 | " def drp(phi, e_p, e_u, v_p, theta, r_p, r_u, u, v):\n", 61 | " return (1/r_p - (phi - v_p)**2)/2.0\n", 62 | "\n", 63 | " def dru(phi, e_p, e_u, v_p, theta, r_p, r_u, u, v):\n", 64 | " return (1/r_u - (u - theta * g(phi))**2)/2.0 # same prediction error as eu\n", 65 | " \n", 66 | " def dtheta(phi, e_p, e_u, v_p, theta, r_p, r_u, u, v):\n", 67 | " return e_u * g(phi)\n", 68 | "\n", 69 | " if reset == True:\n", 70 | " reset_all()\n", 71 | " [phi, e_p, e_u] = state\n", 72 | " [v_p, theta, r_p, r_u] = params\n", 73 | "\n", 74 | " [h.remove() for h in ax.get_children() if isinstance(h, matplotlib.lines.Line2D)]\n", 75 | " trace = np.zeros((steps, 9))\n", 76 | " trace[0, 0:3] = state\n", 77 | " trace[0, 3:7] = params\n", 78 | " \n", 79 | " for t in range(steps-1):\n", 80 | " if t % int(round(5/dt)) == 0: # new hidden cause v every 5 time units (integer test avoids float modulo)\n", 81 | " v = max(np.random.normal(mu_p, np.sqrt(1/R_p)), 0)\n", 82 | " trace[t, 8] = v\n", 83 | "\n", 84 | " if t % int(round(0.5/dt)) == 0: # new observation u every 0.5 time units\n", 85 | " u = max(np.random.normal(Theta*g(v), np.sqrt(1/R_u)), 0)\n", 86 | " trace[t, 7] = u\n", 87 | "\n", 88 | " state[0] += dt / tau * dphi_dt(*trace[t])\n", 89 | "# state[0] = v\n", 90 | " state[1] = ep(*trace[t])\n", 91 | " state[2] = eu(*trace[t])\n", 92 | " trace[t+1, 0:3] = state\n", 93 | " trace[t+1, 3:7] = params\n", 94 | " trace[t+1, 7] = u\n", 95 | " trace[t+1, 8] = v\n", 96 | "\n", 97 | " if (t + 10) % int(round(0.5/dt)) == 0: # update parameters shortly before each new observation u\n", 98 | " params += 0.1 * w_params * np.array([dvp(*trace[t+1]), 0*dtheta(*trace[t+1]), drp(*trace[t+1]), dru(*trace[t+1])])\n", 99 | "# if params[2]<1: params[2] = 1\n", 100 | "# if params[3]<1: params[3] = 1\n", 101 | "# if params[2]>1: params[2] = 1\n", 102 | "# if params[3]>1: 
params[3] = 1\n", 103 | " trace[t+1, 3:7] = params\n", 104 | "\n", 105 | " ax.plot(np.arange(steps) * dt, trace[:,0], color='C0')\n", 106 | " ax.plot(np.arange(steps) * dt, np.sqrt(trace[:,7]/trace[:,4]), color='C7')\n", 107 | " ax.plot(np.arange(steps) * dt, trace[:,8], color='C8')\n", 108 | " ax.plot(np.arange(steps) * dt, trace[:,1], color='C1')\n", 109 | " ax.plot(np.arange(steps) * dt, trace[:,2], color='C2')\n", 110 | " ax.plot(np.arange(steps) * dt, trace[:,3], color='C3')\n", 111 | " ax.plot(np.arange(steps) * dt, trace[:,4], color='C4')\n", 112 | " ax.plot(np.arange(steps) * dt, trace[:,5], color='C5')\n", 113 | " ax.plot(np.arange(steps) * dt, trace[:,6], color='C6')\n", 114 | " plt.legend([r'$\\phi$', r'$\\sqrt{u/\\theta}$', r'$v$', \n", 115 | " r'$\\varepsilon_p$', r'$\\varepsilon_u$', r'$v_p$', \n", 116 | " r'$\\theta$', r'$r_p$', r'$r_u$'], ncol=3, loc=1)\n", 117 | "\n", 118 | "\n", 119 | " # set real-world mean of food size, precisions (1/variance) of food size and light intensity, and gain Theta\n", 120 | " mu_p = 6.0 \n", 121 | " R_p = 1/2.0\n", 122 | " R_u = 1/2.0\n", 123 | " Theta = 1.0\n", 124 | "\n", 125 | " # simulation parameters\n", 126 | " dt = 0.001\n", 127 | " dur = 100\n", 128 | " steps = int(dur/dt)\n", 129 | "\n", 130 | " # initialise state and parameter estimates\n", 131 | " reset_all()\n", 132 | " \n", 133 | " tau = ipw.FloatSlider(value=1.0, min=0.1, max=20, step=0.1, continuous_update=CU) \n", 134 | " w_params = ipw.FloatSlider(value=0.01, min=0.01, max=2, step=0.01, continuous_update=CU) \n", 135 | " sliders = ipw.VBox([tau, w_params]) \n", 136 | " reset = ipw.ToggleButton(value=False, description='Reset', button_style='info')\n", 137 | " cont = ipw.ToggleButton(value=False, description='Continue', button_style='info')\n", 138 | " buttons = ipw.VBox([reset, cont])\n", 139 | " controls = ipw.HBox([sliders, buttons])\n", 140 | "\n", 141 | " fig = plt.figure(figsize=(8,4), num='Fig 2g')\n", 142 | " ax = fig.add_subplot(1, 1, 1)\n", 143 | " myplot = ipw.interactive(draw_fig2g, 
tau=tau, w_params=w_params, reset=reset, cont=cont)\n", 144 | " cont.value = True\n", 145 | " display(controls)\n", 146 | " plt.xlabel('time')\n", 147 | " plt.ylabel('Activity')\n", 148 | " plt.axis([0, dur, -2, 10])\n", 149 | "\n", 150 | " \n", 151 | "fig_2g(False)" 152 | ] 153 | } 154 | ], 155 | "metadata": { 156 | "kernelspec": { 157 | "display_name": "Python 2", 158 | "language": "python", 159 | "name": "python2" 160 | }, 161 | "language_info": { 162 | "codemirror_mode": { 163 | "name": "ipython", 164 | "version": 2 165 | }, 166 | "file_extension": ".py", 167 | "mimetype": "text/x-python", 168 | "name": "python", 169 | "nbconvert_exporter": "python", 170 | "pygments_lexer": "ipython2", 171 | "version": "2.7.15" 172 | } 173 | }, 174 | "nbformat": 4, 175 | "nbformat_minor": 2 176 | } 177 | -------------------------------------------------------------------------------- /ELBO.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Variational Inference in 5 minutes #\n", 8 | "\n", 9 | "(copied with a few changes from [2 posts](http://davmre.github.io/blog/inference/2015/11/13/elbo-in-5min) by Dave Moore)\n", 10 | "\n", 11 | "Let $p(x,z)$ be a probability model with observed variables $x$ and latent variables $z$. We want to infer the posterior $p(z|x)$, but in general this won’t have any nice form that we can write down. Instead, we’ll pick some approximating family $q(z;\\lambda)$, with parameters $\\lambda$, and then try to find the distribution within this family that best approximates the posterior. 
For example, if we model each latent variable independently (a “mean field” approximation) using a scalar Gaussian, the parameters $\\lambda$ are just the means and standard deviations of these Gaussians.\n", 12 | "\n", 13 | "A natural approach to fitting the approximation parameters $\\lambda$ is to minimize the [KL divergence](https://en.wikipedia.org/wiki/Kullback–Leibler_divergence) between our approximation $q(z;\\lambda)$ and the posterior $p(z|x)$.$^1$ Writing this out,\n", 14 | "\n", 15 | "\\begin{equation}\n", 16 | "KL[q(z ; \\lambda) \\| p(z | x)]=\\int q(z ; \\lambda) \\log \\frac{q(z ; \\lambda)}{p(z | x)} d z ,\n", 17 | "\\end{equation}\n", 18 | " \n", 19 | "we see that it depends on the posterior density $p(z|x)$ which we don’t know. However, we do have access to the joint distribution $p(x,z)$, which is proportional to the posterior, so we can just apply simple algebra to unpack the normalizing constant:\n", 20 | "\n", 21 | "\\begin{equation}\n", 22 | "\\begin{aligned}\n", 23 | "KL[q(z ; \\lambda) \\| p(z | x)] &=\\int q(z ; \\lambda) \\log \\frac{q(z ; \\lambda)}{p(z | x)} d z \\\\\n", 24 | "&=\\int q(z ; \\lambda)[\\log q(z ; \\lambda)-\\log p(z | x)] d z \\\\\n", 25 | "&=\\int q(z ; \\lambda)\\left[\\log q(z ; \\lambda)-\\log \\frac{p(x, z)}{p(x)}\\right] d z \\\\\n", 26 | "&=\\log p(x)+\\int q(z ; \\lambda)[\\log q(z ; \\lambda)-\\log p(x, z)] d z \\\\\n", 27 | "&=\\log p(x)-\\mathcal{F}(\\lambda ; x)\n", 28 | "\\end{aligned}\n", 29 | "\\end{equation} \n", 30 | "\n", 31 | "This shows that the KL divergence is equal to the log model evidence $\\log{p(x)}$, the log of the (unknown) normalizing constant, minus a term $\\mathcal{F}$ given by\n", 32 | "\n", 33 | "\\begin{equation}\n", 34 | "\\mathcal{F}(\\lambda ; x)=\\int q(z ; \\lambda)[\\log p(x, z)-\\log q(z ; \\lambda)] d z\n", 35 | "\\end{equation}\n", 36 | "\n", 37 | "This term is referred to either as the (negative) variational free energy or as the evidence lower bound (ELBO). 
It is a lower bound on $\\log{p(x)}$\n", 38 | " because we can write $\\log{p(x)}=\\mathcal{F} + KL[q(z;\\lambda)\\|p(z|x)]$\n", 39 | " and the KL divergence is nonnegative. Since the log evidence does not depend on $\\lambda$, maximizing $\\mathcal{F}$\n", 40 | " minimizes the KL divergence.\n", 41 | "\n", 42 | "This is the core of variational inference: pick an approximating family and minimize the KL divergence between your approximation and the true posterior. \n", 43 | "\n", 44 | "The practical difficulty tends to be that $\\mathcal{F}$ involves an expectation, so evaluating and optimizing it requires either model-specific math$^2$ or Monte Carlo techniques. " 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "One approach is to note that $\\mathcal{F}$ is really just an expectation with respect to our approximating distribution $q$:\n", 52 | "\n", 53 | "\\begin{equation}\n", 54 | "\\begin{aligned}\n", 55 | "\\mathcal{F}(\\lambda ; x) &=E_{z \\sim q}[\\log p(x, z)-\\log q(z ; \\lambda)] \\\\\n", 56 | "&=E_{z \\sim q}[\\log p(x, z)]+H(q ; \\lambda)\n", 57 | "\\end{aligned}\n", 58 | "\\end{equation}\n", 59 | "\n", 60 | "where we’ve made the simplifying assumption that the entropy $H(q;\\lambda)$ is available in closed form. This is true for Gaussian approximating families, but if we’re using some other weird family we can always move the entropy back into the Monte Carlo approximation. The expectation over $\\log{p(x,z)}$ might not have a closed form, but we can approximate it by drawing $n$ samples $z_i \\sim q(z;\\lambda)$ and evaluating the empirical expectation\n", 61 | "\n", 62 | "\\begin{equation}\n", 63 | "\\hat{\\mathcal{F}}(\\lambda ; x)=\\frac{1}{n} \\sum_{i=1}^{n} \\log p\\left(x, z_{i}\\right)+H(q ; \\lambda)\n", 64 | "\\end{equation}\n", 65 | "\n", 66 | "Our approach will be to do gradient ascent on this Monte Carlo approximation. 
But wait, you might object: $\\lambda$ doesn’t appear anywhere in (the Monte Carlo part of) this expression, so how can we compute a gradient? The answer is that $\\lambda$ was a parameter of the distribution that produced $z$, so we just have to differentiate through the sampling algorithm, holding fixed the random seed (this is the “reparameterization trick”$^3$). In many cases this is straightforward to do.\n", 67 | "\n", 68 | "For example, if $q$ is Gaussian parameterized by a mean and standard deviation $\\lambda=(\\mu,\\sigma)$, a typical sampling procedure would first sample a standard Gaussian variable $\\varepsilon \\sim N(0,1)$\n", 69 | " and then compute the transform $z = \\sigma \\varepsilon + \\mu$. Rewriting our Monte Carlo ELBO in terms of these “base variables” $\\varepsilon_i$,\n", 70 | "\n", 71 | "\\begin{equation}\n", 72 | "\\hat{\\mathcal{F}}(\\lambda ; x)=\\frac{1}{n} \\sum_{i=1}^{n} \\log p\\left(x, \\sigma \\varepsilon_{i}+\\mu\\right)+H(q ; \\lambda)\n", 73 | "\\end{equation}\n", 74 | "\n", 75 | "we can now easily differentiate this expression with respect to $\\mu$ and $\\sigma$ (by the chain rule, this will involve the model gradient $\\nabla_z \\log{p(x,z)}$). The result is a stochastic estimate of the gradient of the ELBO, which you can plug into your favorite stochastic optimization algorithm (SGD, Adagrad, etc.).\n", 76 | "\n", 77 | "Note that the only assumptions we’ve made about the model are that its latent variables $z$ are continuous and that we have access to gradients $\\nabla_z \\log{p(x,z)}$, which automatic differentiation provides for almost any differentiable model. This is how Stan implements variational inference for arbitrary differentiable models (more details in [their paper](https://arxiv.org/abs/1506.03431)), and many other tools now support autodiff as well, such as [autograd](https://github.com/HIPS/autograd) and [JAX](https://github.com/google/jax). 
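As a minimal sketch of the reparameterization trick (not code from Dave Moore’s posts), the snippet below fits a Gaussian `q(z) = N(mu, sigma^2)` by stochastic gradient ascent on the Monte Carlo ELBO. It assumes a toy conjugate model, `z ~ N(0, 1)` and `x | z ~ N(z, 1)`, chosen because its exact posterior `N(x/2, 1/2)` and log evidence `log N(x; 0, 2)` are known, so the result can be checked. The helper names (`log_joint`, `log_joint_grad`, `fit_elbo`), the learning rate, and the choice to optimise `log(sigma)` rather than `sigma` (to keep `sigma` positive) are illustrative assumptions.

```python
import numpy as np

# Toy model (an assumption for illustration): z ~ N(0, 1), x | z ~ N(z, 1).
# Its exact posterior is N(x/2, 1/2), so the fitted q can be checked.

def log_joint(x, z):
    # log p(x, z) = log N(z; 0, 1) + log N(x; z, 1)
    return -0.5 * z**2 - 0.5 * (x - z)**2 - np.log(2.0 * np.pi)

def log_joint_grad(x, z):
    # model gradient d/dz log p(x, z) = -z + (x - z)
    return x - 2.0 * z

def gaussian_entropy(sigma):
    # closed-form entropy H(q) of q = N(mu, sigma^2)
    return 0.5 * np.log(2.0 * np.pi * np.e * sigma**2)

def fit_elbo(x, n=10, steps=5000, lr=0.01, seed=0):
    rng = np.random.RandomState(seed)
    mu, log_sigma = 0.0, 0.0
    for _ in range(steps):
        sigma = np.exp(log_sigma)
        eps = rng.randn(n)            # base variables eps_i ~ N(0, 1)
        z = sigma * eps + mu          # reparameterized samples z_i ~ q
        g = log_joint_grad(x, z)
        grad_mu = g.mean()            # chain rule with dz/dmu = 1
        # dz/dsigma = eps and dH/dsigma = 1/sigma; the extra factor sigma
        # converts a derivative w.r.t. sigma into one w.r.t. log(sigma)
        grad_log_sigma = sigma * (np.mean(g * eps) + 1.0 / sigma)
        mu += lr * grad_mu            # gradient ascent on the ELBO
        log_sigma += lr * grad_log_sigma
    sigma = np.exp(log_sigma)
    # Monte Carlo estimate of the ELBO at the fitted parameters
    z = sigma * rng.randn(100000) + mu
    elbo = log_joint(x, z).mean() + gaussian_entropy(sigma)
    return mu, sigma, elbo

x = 1.5
mu, sigma, elbo = fit_elbo(x)
log_evidence = -0.5 * np.log(4.0 * np.pi) - x**2 / 4.0   # exact log N(x; 0, 2)
print(mu, sigma, elbo, log_evidence)
```

With these settings `mu` and `sigma` should settle near the exact posterior values (0.75 and about 0.71), and the ELBO estimate near the log evidence, up to Monte Carlo noise.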
\n", 78 | "\n", 79 | "If model gradients are not available, it’s still possible to estimate the ELBO gradient using a trick from reinforcement learning (the score-function, or REINFORCE, estimator), described in the paper [Black Box Variational Inference](https://arxiv.org/abs/1401.0118). However, this estimate has higher variance, so optimization will typically converge much more slowly than when model gradients are available." 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "$^1$There are other approaches, including the alternate divergence $KL[p\\|q]$, which leads to [expectation propagation](https://tminka.github.io/papers/ep/), or the Laplace approximation, which locally matches the curvature at the mode. \n", 87 | "\n", 88 | "$^2$For certain classes of models, e.g., exponential families with conjugate priors, the math is well understood and essentially automatable. This is the idea behind [variational message passing](http://www.jmlr.org/papers/volume6/winn05a/winn05a.pdf) as implemented in, e.g., [Infer.NET](https://dotnet.github.io/infer/). \n", 89 | "\n", 90 | "$^3$This trick was introduced by Kingma, Salimans, and Welling in the context of variational autoencoders, though also independently proposed by several others around the same time. Shakir Mohamed has a nice post that goes into more depth on the history and applicability of this trick. 
" 91 | ] 92 | } 93 | ], 94 | "metadata": { 95 | "kernelspec": { 96 | "display_name": "Python 3", 97 | "language": "python", 98 | "name": "python3" 99 | }, 100 | "language_info": { 101 | "codemirror_mode": { 102 | "name": "ipython", 103 | "version": 3 104 | }, 105 | "file_extension": ".py", 106 | "mimetype": "text/x-python", 107 | "name": "python", 108 | "nbconvert_exporter": "python", 109 | "pygments_lexer": "ipython3", 110 | "version": "3.7.6" 111 | } 112 | }, 113 | "nbformat": 4, 114 | "nbformat_minor": 2 115 | } 116 | -------------------------------------------------------------------------------- /Generalised precision matrix.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Precision\n", 15 | "\n", 16 | "To derive the correlation matrices I am assuming a discrete time series of random values. I expect that the derivation in continuous time will proceed analogously.\n", 17 | "\n", 18 | "Assume $x$ is a stationary stochastic process. 
Then we can define:\n", 19 | "\n", 20 | "\\begin{align*}\n", 21 | "\\mu &= \\frac{1}{n} \\sum_{i=1}^{n}{x_i} \\\\\n", 22 | "\\sigma^2 &= 1/\\pi = \\frac{1}{n} \\sum_{i=1}^{n}{(x_i - \\mu)(x_i - \\mu)} \\\\\n", 23 | "\\nu(h) &= \\frac{1}{n} \\sum_{i=1}^{n-h}{(x_i - \\mu)(x_{i+h} - \\mu)} \\\\\n", 24 | "\\rho(h) &= \\nu(h)/\\sigma^2 = \\pi \\nu(h) \\\\\n", 25 | "\\end{align*}\n", 26 | "\n", 27 | "where $\\mu$ is the mean, $\\sigma^2$ is the variance, $\\pi$ is the precision, $\\nu(h)$ is the autocovariance and $\\rho(h)$ is the autocorrelation function, both evaluated at delay $h$. By construction $\\nu(0) = \\sigma^2$, so $\\rho(0) = 1$, and these sample estimates approach the true values of the process as the number of samples $n \\to \\infty$.\n", 28 | "\n", 29 | "Without loss of generality, we can assume $x$ is a zero-mean process, which simplifies these to:\n", 30 | "\n", 31 | "\\begin{align*}\n", 32 | "\\mu &= 0 \\\\\n", 33 | "\\sigma^2 &= 1/\\pi = \\frac{1}{n} \\sum_{i=1}^{n}{x_i x_i} \\\\\n", 34 | "\\nu(h) &= \\frac{1}{n} \\sum_{i=1}^{n-h}{x_i x_{i+h}} \\\\\n", 35 | "\\rho(h) &= \\nu(h)/\\sigma^2 = \\pi \\nu(h) \\\\\n", 36 | "\\end{align*}\n", 37 | "\n", 38 | "If $x$ is not a single variable but a set of variables, i.e., a vector $\\mathbf{x}$, then the covariance $\\mathbf{\\Sigma}^2$ and precision $\\mathbf{\\Pi}$ of the process are matrices, as are the cross-covariance $\\mathbf{N}(h)$ and the autocorrelation function $\\mathbf{P}(h)$ evaluated at delay $h$:\n", 39 | "\n", 40 | "\\begin{align*}\n", 41 | "\\mathbf{\\Sigma}^2 &= \\mathbf{\\Pi}^{-1} = \\frac{1}{n} \\sum_{i=1}^{n}{\\mathbf{x}_i \\mathbf{x}_i^T} \\\\\n", 42 | "\\mathbf{N}(h) &= \\frac{1}{n} \\sum_{i=1}^{n-h}{\\mathbf{x}_i \\mathbf{x}_{i+h}^T} \\\\\n", 43 | "\\mathbf{P}(h) &= \\mathbf{\\Pi} \\mathbf{N}(h) \\\\\n", 44 | "\\end{align*}\n", 45 | "\n", 46 | "If the variables that make up $\\mathbf{x}$ are independent, then $\\mathbf{\\Sigma}^2$ and $\\mathbf{\\Pi}$ are diagonal matrices. 
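The sample definitions above can be checked numerically. The sketch below is an illustration, not part of the original derivation: it estimates the autocovariance, autocorrelation and covariance matrix with the same 1/n convention, using two independent AR(1) processes as an assumed test signal (for an AR(1) process with coefficient `a`, the autocorrelation at delay `h` is `a**h`). The function names `autocovariance`, `autocorrelation` and `cross_covariance` are hypothetical.

```python
import numpy as np

def autocovariance(x, h):
    # nu(h) = (1/n) sum_{i=1}^{n-h} (x_i - mu)(x_{i+h} - mu), using the 1/n convention above
    x = np.asarray(x, dtype=float)
    n = len(x)
    d = x - x.mean()
    return np.dot(d[:n - h], d[h:]) / n

def autocorrelation(x, h):
    # rho(h) = nu(h) / sigma^2 = pi * nu(h), so rho(0) = 1
    return autocovariance(x, h) / autocovariance(x, 0)

def cross_covariance(X, h):
    # X has shape (n, d): n samples of a zero-mean d-dimensional process.
    # N(h) = (1/n) sum_{i=1}^{n-h} x_i x_{i+h}^T; N(0) is the covariance matrix
    X = np.asarray(X, dtype=float)
    n = X.shape[0]
    return np.dot(X[:n - h].T, X[h:]) / n

# Assumed test signal: two independent zero-mean AR(1) processes, x_t = a x_{t-1} + e_t
rng = np.random.RandomState(0)
n, a = 100000, 0.9
e = rng.randn(n, 2)
X = np.zeros((n, 2))
for t in range(1, n):
    X[t] = a * X[t - 1] + e[t]

rho1 = autocorrelation(X[:, 0], 1)   # should be close to a = 0.9
Sigma2 = cross_covariance(X, 0)      # nearly diagonal: the components are independent
Pi = np.linalg.inv(Sigma2)           # precision matrix
print(rho1)
print(Sigma2)
```

For an AR(1) process with unit-variance noise the variance is `1/(1 - a**2)`, about 5.3 here, so the diagonal of `Sigma2` should be near that value and its off-diagonal entries small.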
\n", 47 | "\n", 48 | "In **\"DEM: A variational treatment of dynamic systems\"**, Friston et al. introduce the use of generalised coordinates containing the higher orders of motion to describe the trajectory of a time-varying system in a way that allows Expectation Maximisation for dynamical systems. If $\\mathbf{\\tilde x}$ is a vector of generalised coordinates for state $x$, we have:\n", 49 | "\n", 50 | "\\begin{align*}\n", 51 | "\\mathbf{\\tilde x} &= [x, \\dot x, \\ddot x, \\dddot x, \\cdots]^T \\\\\n", 52 | "\\end{align*}\n", 53 | "\n", 54 | "where the higher orders of motion are defined as:\n", 55 | "\n", 56 | "\\begin{align*}\n", 57 | "&\\lim_{\\Delta t \\to 0} E \\left( \\dot{x}(t) - \\frac{x(t + \\Delta t) - x(t)}{\\Delta t}\\right) = 0 \\\\\n", 58 | "&\\lim_{\\Delta t \\to 0} E \\left( \\ddot{x}(t) - \\frac{\\dot{x}(t + \\Delta t) - \\dot{x}(t)}{\\Delta t}\\right) = 0 \\\\\n", 59 | "&\\cdots \\\\\n", 60 | "\\end{align*}\n", 61 | "\n", 62 | "or, more informally, using the finite difference method:\n", 63 | "\n", 64 | "\\begin{align*}\n", 65 | "\\dot{x}(t) &= \\frac{x(t + \\Delta t) - x(t)}{\\Delta t} ; \\Delta t \\to 0 \\\\\n", 66 | "\\ddot{x}(t) &= \\frac{\\dot{x}(t + \\Delta t) - \\dot{x}(t)}{\\Delta t} \\\\\n", 67 | "&= \\frac{x(t + \\Delta t) - x(t)}{(\\Delta t)^2} - \\frac{x(t) - x(t - \\Delta t)}{(\\Delta t)^2} \\\\\n", 68 | "&= \\frac{x(t + \\Delta t) - 2x(t) + x(t - \\Delta t)}{(\\Delta t)^2}; \\Delta t \\to 0 \\\\\n", 69 | "&\\cdots\\\\\n", 70 | "\\end{align*}\n", 71 | "\n", 72 | "In the second line for $\\ddot{x}(t)$, $\\dot{x}(t + \\Delta t)$ and $\\dot{x}(t)$ are written as backward differences, e.g. $\\dot{x}(t) = (x(t) - x(t - \\Delta t))/\\Delta t$, which coincide with the forward differences above as $\\Delta t \\to 0$; this yields the symmetric second difference. Because each level in generalised coordinates is the derivative of the level above it, correlations exist between the levels, hence $\\mathbf{\\Sigma}^2$ is no longer a diagonal matrix. 
We can define:\n", 73 | "\n", 74 | "\\begin{align*}\n", 75 | "\\mathbf{\\Sigma}^2 = \\mathbf{\\Pi}^{-1} &= \\frac{1}{n} \\sum_{i=1}^n{\\mathbf{\\tilde x}_i \\mathbf{\\tilde x}_i^T} \\\\\n", 76 | "&= \n", 77 | " \\begin{bmatrix}\n", 78 | " C(x,x) & C(x,\\dot{x}) & C(x,\\ddot{x}) & \\cdots \\\\\n", 79 | " C(\\dot{x},x) & C(\\dot{x},\\dot{x}) & C(\\dot{x},\\ddot{x}) \\\\\n", 80 | " C(\\ddot{x},x) & C(\\ddot{x},\\dot{x}) & C(\\ddot{x},\\ddot{x}) \\\\\n", 81 | " \\vdots & & & \\ddots\n", 82 | " \\end{bmatrix}\n", 83 | "\\end{align*}\n", 84 | "\n", 85 | "where\n", 86 | "\n", 87 | "\\begin{align*}\n", 88 | "C(a,b) = \\frac{1}{n} \\sum_{i=1}^{n}{a_i b_i}\n", 89 | "\\end{align*}\n", 90 | "\n" 91 | ] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "metadata": {}, 96 | "source": [ 97 | "We can easily calculate the first four entries:\n", 98 | "\n", 99 | "\\begin{align*}\n", 100 | "C(x,x) &= \\sigma^2 = \\nu(0) = \\sigma^2 \\rho(0)\\\\\n", 101 | "\\\\\n", 102 | "C(x,\\dot{x}) = C(\\dot{x},x) &= \\frac{1}{n} \\sum_{i=1}^{n}{x(i) \\dot{x}(i)} \\\\\n", 103 | "&= \\frac{1}{n} \\sum_{i=1}^{n}{x(i) \\frac{x(i + \\Delta t) - x(i)}{\\Delta t}} \\\\\n", 104 | "&= \\frac{1}{n} \\sum_{i=1}^{n}{\\frac{x(i) x(i + \\Delta t) - x(i) x(i)}{\\Delta t}} \\\\\n", 105 | "&= \\frac{\\nu(\\Delta t) - \\nu(0)}{\\Delta t} \\\\\n", 106 | "&= \\dot{\\nu}(0) \\\\\n", 107 | "&= \\sigma^2\\dot{\\rho}(0) \\\\\n", 108 | "\\\\\n", 109 | "C(\\dot{x},\\dot{x}) &= \\frac{1}{n} \\sum_{i=1}^{n}{\\dot{x}(i) \\dot{x}(i)} \\\\\n", 110 | "&= \\frac{1}{n} \\sum_{i=1}^{n}{\\frac{x(i + \\Delta t) - x(i)}{\\Delta t} \\frac{x(i + \\Delta t) - x(i)}{\\Delta t}} \\\\\n", 111 | "&= \\frac{1}{n} \\sum_{i=1}^{n}{\\frac{x(i + \\Delta t) x(i + \\Delta t) - 2x(i) x(i + \\Delta t) + x(i)x(i)}{(\\Delta t)^2}} \\\\\n", 112 | "&\\text{and since, for a stationary process,} \\; \\sum_{i}{x(i + \\Delta t) x(i + \\Delta t)} = \\sum_{i}{x(i) x(i)}\\\\\n", 113 | "&= \\frac{1}{n} 
\\sum_{i=1}^{n}{\\frac{2 x(i) x(i) - 2x(i) x(i + \\Delta t)}{(\\Delta t)^2}} \\\\\n", 114 | "&= 
\\frac{2\\nu(0) - 2\\nu(\\Delta t)}{(\\Delta t)^2} \\\\\n", 115 | "&\\text{and because} \\; \\nu(\\Delta t) = \\nu(-\\Delta t)\\\\\n", 116 | "&= \\frac{-\\nu(\\Delta t) + 2\\nu(0) - \\nu(-\\Delta t)}{(\\Delta t)^2} \\\\\n", 117 | "&= -\\ddot{\\nu}(0)\\\\\n", 118 | "&= -\\sigma^2 \\ddot{\\rho}(0)\\\\\n", 119 | "\\end{align*}\n", 120 | "\n", 121 | "More generally, ignoring the sign for now (the $C(\\dot{x},\\dot{x})$ case above shows that it can be negative), we write:\n", 122 | "\n", 123 | "\\begin{align*}\n", 124 | "C(\\overset{j}{x},\\overset{k}{x}) &= \\frac{1}{n} \\sum_{i=1}^{n}{\\overset{j}{x}(i) \\overset{k}{x}(i)} \\\\\n", 125 | "&= \\overset{j+k}{\\nu}(0)\\\\\n", 126 | "&= \\sigma^2 \\overset{j+k}{\\rho}(0)\\\\\n", 127 | "\\end{align*}\n", 128 | "\n", 129 | "where we use $\\overset{j}{x}_i$ and $\\overset{k}{x}_i$ to indicate the $j$th and $k$th derivatives of $x_i$, respectively. We can then proceed by induction:\n", 130 | "\n", 131 | "\\begin{align*}\n", 132 | "C(\\overset{j+2}{x},\\overset{k}{x}) &= \\frac{1}{n} \\sum_{i=1}^{n}{\\overset{j+2}{x}(i) \\overset{k}{x}(i)} \\\\\n", 133 | "&= \\frac{1}{n} \\sum_{i=1}^{n}{\\frac{\\overset{j}{x}(i + \\Delta t) - 2\\overset{j}{x}(i) + \\overset{j}{x}(i - \\Delta t)}{(\\Delta t)^2} \\overset{k}{x}(i)} \\\\\n", 134 | "&= \\frac{1}{n} \\sum_{i=1}^{n}{\\frac{\\overset{j}{x}(i + \\Delta t) \\overset{k}{x}(i) - 2\\overset{j}{x}(i) \\overset{k}{x}(i) + \\overset{j}{x}(i - \\Delta t)\\overset{k}{x}(i)}{(\\Delta t)^2}} \\\\\n", 135 | "&= \\overset{j+k+2}{\\nu}(0)\\\\\n", 136 | "&= \\sigma^2 \\overset{j+k+2}{\\rho}(0)\\\\\n", 137 | "\\end{align*}\n", 138 | "\n", 139 | "and symmetrically:\n", 140 | "\\begin{align*}\n", 141 | "C(\\overset{j}{x},\\overset{k+2}{x}) &= \\frac{1}{n} \\sum_{i=1}^{n}{\\overset{j}{x}(i) \\overset{k+2}{x}(i)} \\\\\n", 142 | "&= \\frac{1}{n} \\sum_{i=1}^{n}{\\overset{j}{x}(i) \\frac{\\overset{k}{x}(i + \\Delta t) - 2\\overset{k}{x}(i) + \\overset{k}{x}(i - \\Delta t)}{(\\Delta t)^2}} \\\\\n", 143 | "&= \\frac{1}{n} \\sum_{i=1}^{n}{\\frac{\\overset{j}{x}(i) \\overset{k}{x}(i + \\Delta t) - 2\\overset{j}{x}(i) 
\\overset{k}{x}(i) + \\overset{j}{x}(i)\\overset{k}{x}(i - \\Delta t)}{(\\Delta t)^2}} \\\\\n", 144 | "&= \\overset{j+k+2}{\\nu}(0)\\\\\n", 145 | "&= \\sigma^2 \\overset{j+k+2}{\\rho}(0)\\\\\n", 146 | "\\end{align*}\n", 147 | "\n", 148 | "Thus, the entry at row $j+2$, column $k$ (rows and columns are numbered from $0$) is obtained by differentiating the entry at row $j$, column $k$ twice more, and the same holds for the entry at row $j$, column $k+2$. This results in the pattern that the covariance between derivatives of order $j$ and $k$ equals the $(j+k)$th derivative of the autocorrelation function at $0$, multiplied by $\\sigma^2$, and additionally multiplied by $-1$ if both $j$ and $k$ are odd. Thus we can write:\n", 149 | "\n", 150 | "\\begin{align*}\n", 151 | "\\mathbf{\\Sigma}^2 = \\mathbf{\\Pi}^{-1} =\n", 152 | " \\sigma^2 \\begin{bmatrix}\n", 153 | " \\rho(0) & \\dot{\\rho}(0) & \\ddot{\\rho}(0) & \\dddot{\\rho}(0) & \\ddot{\\ddot{\\rho}}(0) & \\ddot{\\dddot{\\rho}}(0) & \\cdots \\\\\n", 154 | " \\dot{\\rho}(0) & -\\ddot{\\rho}(0) & \\dddot{\\rho}(0) & -\\ddot{\\ddot{\\rho}}(0) & \\ddot{\\dddot{\\rho}}(0) & -\\dddot{\\dddot{\\rho}}(0)\\\\\n", 155 | " \\ddot{\\rho}(0) & \\dddot{\\rho}(0) & \\ddot{\\ddot{\\rho}}(0) & \\ddot{\\dddot{\\rho}}(0) & \\dddot{\\dddot{\\rho}}(0) & \\dot{\\dddot{\\dddot{\\rho}}}(0) \\\\\n", 156 | " \\dddot{\\rho}(0) & -\\ddot{\\ddot{\\rho}}(0) & \\ddot{\\dddot{\\rho}}(0) & -\\dddot{\\dddot{\\rho}}(0) & \\dot{\\dddot{\\dddot{\\rho}}}(0) & -\\ddot{\\dddot{\\dddot{\\rho}}}(0) \\\\\n", 157 | " \\ddot{\\ddot{\\rho}}(0) & \\ddot{\\dddot{\\rho}}(0) & \\dddot{\\dddot{\\rho}}(0) & \\dot{\\dddot{\\dddot{\\rho}}}(0) & \\ddot{\\dddot{\\dddot{\\rho}}}(0) & \\dddot{\\dddot{\\dddot{\\rho}}}(0) \\\\\n", 158 | " \\ddot{\\dddot{\\rho}}(0) & -\\dddot{\\dddot{\\rho}}(0) & \\dot{\\dddot{\\dddot{\\rho}}}(0) & -\\ddot{\\dddot{\\dddot{\\rho}}}(0) & \\dddot{\\dddot{\\dddot{\\rho}}}(0) & -\\dot{\\dddot{\\dddot{\\dddot{\\rho}}}}(0)\\\\\n", 159 | " \\vdots & & & & & & \\ddots\n", 160 | " \\end{bmatrix}\n", 
161 | "\\end{align*}\n" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "The autocorrelation function $\\rho(h)$ is by construction symmetric around $h=0$. In other words, $\\rho$ is an even function. This implies directly that $\\dot{\\rho}(h)$ is odd, $\\ddot{\\rho}(h)$ is even, etc. Odd functions evaluated at $0$ yield $0$. Also $\\rho(0) = 1$ as can be seen from $C(x,x)$ above. Thus we can simplify this as:\n", 169 | "\n", 170 | "\\begin{align*}\n", 171 | "\\mathbf{\\Sigma}^2 = \\mathbf{\\Pi}^{-1} =\n", 172 | " \\sigma^2 \\begin{bmatrix}\n", 173 | " 1 & 0 & \\ddot{\\rho}(0) & 0 & \\ddot{\\ddot{\\rho}}(0) & 0 & \\cdots \\\\\n", 174 | " 0 & -\\ddot{\\rho}(0) & 0 & -\\ddot{\\ddot{\\rho}}(0) & 0 & -\\dddot{\\dddot{\\rho}}(0)\\\\\n", 175 | " \\ddot{\\rho}(0) & 0 & \\ddot{\\ddot{\\rho}}(0) & 0 & \\dddot{\\dddot{\\rho}}(0) & 0 \\\\\n", 176 | " 0 & -\\ddot{\\ddot{\\rho}}(0) & 0 & -\\dddot{\\dddot{\\rho}}(0) & 0 & -\\ddot{\\dddot{\\dddot{\\rho}}}(0) \\\\\n", 177 | " \\ddot{\\ddot{\\rho}}(0) & 0 & \\dddot{\\dddot{\\rho}}(0) & 0 & \\ddot{\\dddot{\\dddot{\\rho}}}(0) & 0 \\\\\n", 178 | " 0 & -\\dddot{\\dddot{\\rho}}(0) & 0 & -\\ddot{\\dddot{\\dddot{\\rho}}}(0) & 0 & -\\dot{\\dddot{\\dddot{\\dddot{\\rho}}}}(0)\\\\\n", 179 | " \\vdots & & & & & & \\ddots\n", 180 | " \\end{bmatrix}\n", 181 | "\\end{align*}\n", 182 | "\n", 183 | "This matrix can be evaluated for any analytic autocorrelation function. 
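The sign pattern can be captured in a short sketch. It assumes the even derivatives of the autocorrelation at zero are supplied as a list (`rho_even[m]` holding the 2m-th derivative; the function name and this interface are hypothetical). Entry (j, k) is zero when j + k is odd, and otherwise equals sigma^2 times (-1)^k times the (j + k)-th derivative. For even j + k, the factor (-1)^k is -1 exactly when j and k are both odd:

```python
def generalised_covariance(rho_even, order, sigma2=1.0):
    # rho_even[m] is the (2m)-th derivative of rho at h = 0; odd ones vanish
    S = [[0.0] * order for _ in range(order)]
    for j in range(order):
        for k in range(order):
            if (j + k) % 2 == 0:
                # for even j + k, (-1)**k is -1 exactly when j and k are both odd
                S[j][k] = sigma2 * (-1) ** k * rho_even[(j + k) // 2]
    return S


# Example: the Gaussian case derived below, with an illustrative gamma = 1
gamma = 1.0
rho_even = [1.0, -gamma, 3 * gamma ** 2, -15 * gamma ** 3, 105 * gamma ** 4, -945 * gamma ** 5]
for row in generalised_covariance(rho_even, 6):
    print(row)
```

Note that `rho_even` must contain at least `order` entries, because the bottom-right element needs the derivative of order 2 * (order - 1).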
If we assume, for convenience, that the temporal correlations of all innovations have the same Gaussian form, centred at $h=0$ and with precision parameter $\\gamma$ (the inverse squared width of the autocorrelation as a function of $h$, not to be confused with the process precision $\\pi = 1/\\sigma^2$), we can proceed with:\n", 184 | "\n", 185 | "\\begin{align*}\n", 186 | "\\rho(h) &= e^{-\\frac{\\gamma}{2} h^2} \\\\\n", 187 | "\\dot{\\rho}(h) &= -\\gamma h \\rho(h) \\\\\n", 188 | "\\ddot{\\rho}(h) &= -\\gamma \\rho(h) + (\\gamma h)^2 \\rho(h) \\\\\n", 189 | "\\dddot{\\rho}(h) &= \\gamma^2 h \\rho(h) + 2 \\gamma^2 h \\rho(h) - (\\gamma h)^3 \\rho(h) \\\\\n", 190 | "&= 3 \\gamma^2 h \\rho(h) - (\\gamma h)^3 \\rho(h) \\\\\n", 191 | "\\ddot{\\ddot{\\rho}}(h) &= 3 \\gamma^2 \\rho(h) - 3 \\gamma^3 h^2 \\rho(h) - 3 \\gamma^3 h^2 \\rho(h) + (\\gamma h)^4 \\rho(h) \\\\\n", 192 | "&= 3 \\gamma^2 \\rho(h) - 6 \\gamma^3 h^2 \\rho(h) + (\\gamma h)^4 \\rho(h) \\\\\n", 193 | "\\ddot{\\dddot{\\rho}}(h) &= - 3 \\gamma^3 h \\rho(h) - 12 \\gamma^3 h \\rho(h) + 6 \\gamma^4 h^3 \\rho(h) + 4 \\gamma^4 h^3 \\rho(h) - (\\gamma h)^5 \\rho(h) \\\\\n", 194 | "&= - 15 \\gamma^3 h \\rho(h) + 10 \\gamma^4 h^3 \\rho(h) - (\\gamma h)^5 \\rho(h) \\\\\n", 195 | "\\dddot{\\dddot{\\rho}}(h) &= -15 \\gamma^3 \\rho(h) + 15 \\gamma^4 h^2 \\rho(h) + 30 \\gamma^4 h^2 \\rho(h) - 10 \\gamma^5 h^4 \\rho(h) - 5 \\gamma^5 h^4 \\rho(h) + (\\gamma h)^6 \\rho(h) \\\\\n", 196 | "&= -15 \\gamma^3 \\rho(h) + 45 \\gamma^4 h^2 \\rho(h) - 15 \\gamma^5 h^4 \\rho(h) + (\\gamma h)^6 \\rho(h) \\\\\n", 197 | "\\end{align*}\n", 198 | "\n", 199 | "The expressions for higher-order derivatives get longer and longer. 
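The closed forms above can be spot-checked numerically. This sketch uses only the standard library, and the step sizes and the value of gamma are arbitrary choices. It applies central finite differences at h = 0 and compares them with -gamma for the second derivative and 3 gamma^2 for the fourth:

```python
import math


def rho(h, gamma):
    # Gaussian autocorrelation, rho(h) = exp(-gamma * h**2 / 2)
    return math.exp(-0.5 * gamma * h * h)


def d2(f, h0, step):
    # central second difference of f at h0
    return (f(h0 + step) - 2 * f(h0) + f(h0 - step)) / step ** 2


def d4(f, h0, step):
    # central fourth difference of f at h0 (weights 1, -4, 6, -4, 1)
    return (f(h0 + 2 * step) - 4 * f(h0 + step) + 6 * f(h0)
            - 4 * f(h0 - step) + f(h0 - 2 * step)) / step ** 4


gamma = 2.0   # arbitrary test value


def f(h):
    return rho(h, gamma)


print(d2(f, 0.0, 1e-4), 'vs', -gamma)          # second derivative at h = 0
print(d4(f, 0.0, 1e-2), 'vs', 3 * gamma ** 2)  # fourth derivative at h = 0
```

The fourth difference uses a larger step because its rounding error scales like 1/step^4. The remaining truncation error, of order step^2, is then the dominant term.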
As a shortcut to calculate the higher-order derivatives, compare the expressions above with those of the (probabilist's) Hermite polynomials:\n", 200 | "\n", 201 | "\\begin{align*}\n", 202 | "He_n(x) &= (-1)^n e^{\\frac{x^2}{2}} \\frac{d^n}{dx^n} e^{-\\frac{x^2}{2}} \\\\\n", 203 | "\\rho(h) &= e^{-\\frac{\\gamma}{2} h^2} \\\\\n", 204 | "&= \\frac{He_0(\\gamma^{1/2} h)}{(-1)^0 e^{\\frac{\\gamma h^2}{2}}} \\\\\n", 205 | "\\overset{n}\\rho(h) &= \\gamma^{n/2} \\frac{He_n(\\gamma^{1/2} h)}{(-1)^n e^{\\frac{\\gamma h^2}{2}}} \\\\\n", 206 | "\\end{align*}\n", 207 | "\n", 208 | "where $\\overset{n}\\rho(h)$ is the $n$th derivative of $\\rho(h)$ and the factor $\\gamma^{n/2}$ comes from the chain rule. Because we are only interested in the values at $h=0$ for even values of $n$, the denominator evaluates to $1$, so this simplifies to:\n", 209 | "\n", 210 | "\\begin{align*}\n", 211 | "\\overset{n}\\rho(0) &= \\gamma^{n/2} He_n(0) \\\\\n", 212 | "\\end{align*}\n", 213 | "\n", 214 | "The explicit expression for the Hermite polynomials can be used to calculate the coefficients for all even derivatives quickly. For even $n$, they are given by\n", 215 | "\n", 216 | "\\begin{align*}\n", 217 | "He_n(x) &= n!\\sum_{m=0}^{n/2}\\frac{(-1)^m}{m! (n-2m)!}\\frac{x^{n-2m}}{2^m} \\\\\n", 218 | "He_n(0) &= n!\\sum_{m=0}^{n/2}\\frac{(-1)^m}{m! (n-2m)!}\\frac{0^{n-2m}}{2^m} \\\\\n", 219 | " &= n!\\frac{(-1)^{n/2}}{(n/2)! 
(0)!}\\frac{0^0}{2^{n/2}} \\\\\n", 220 | " &= n!\\frac{(-1)^{n/2}}{(n/2)!}\\frac{1}{2^{n/2}} \\\\\n", 221 | " &= \\left(-\\frac{1}{2}\\right)^{n/2}\\frac{n!}{(n/2)!}\\\\\n", 222 | "\\end{align*}\n", 223 | "\n", 224 | "\n", 225 | "Evaluating the above yields:\n" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 2, 231 | "metadata": {}, 232 | "outputs": [ 233 | { 234 | "name": "stdout", 235 | "output_type": "stream", 236 | "text": [ 237 | "Derivatives of 𝜌\n", 238 | "0 1.0 γ^0\n", 239 | "2 -1.0 γ^1\n", 240 | "4 3.0 γ^2\n", 241 | "6 -15.0 γ^3\n", 242 | "8 105.0 γ^4\n", 243 | "10 -945.0 γ^5\n" 244 | ] 245 | } 246 | ], 247 | "source": [ 248 | "from math import factorial as fac\n", 249 | "print(\"Derivatives of 𝜌\")\n", 250 | "for m in range(6):\n", 251 | " print(\"{:<3} {:>8} γ^{:}\".format(2*m, (-1/2)**m * fac(2*m) / fac(m), m))" 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "metadata": {}, 257 | "source": [ 258 | "\\begin{align*}\n", 259 | "\\mathbf{\\Sigma}^2 = \\mathbf{\\Pi}^{-1} &=\n", 260 | " \\sigma^2 \\begin{bmatrix}\n", 261 | " 1 & 0 & \\ddot{\\rho}(0) & 0 & \\ddot{\\ddot{\\rho}}(0) & 0 & \\cdots \\\\\n", 262 | " 0 & -\\ddot{\\rho}(0) & 0 & -\\ddot{\\ddot{\\rho}}(0) & 0 & -\\dddot{\\dddot{\\rho}}(0)\\\\\n", 263 | " \\ddot{\\rho}(0) & 0 & \\ddot{\\ddot{\\rho}}(0) & 0 & \\dddot{\\dddot{\\rho}}(0) & 0 \\\\\n", 264 | " 0 & -\\ddot{\\ddot{\\rho}}(0) & 0 & -\\dddot{\\dddot{\\rho}}(0) & 0 & -\\ddot{\\dddot{\\dddot{\\rho}}}(0) \\\\\n", 265 | " \\ddot{\\ddot{\\rho}}(0) & 0 & \\dddot{\\dddot{\\rho}}(0) & 0 & \\ddot{\\dddot{\\dddot{\\rho}}}(0) & 0 \\\\\n", 266 | " 0 & -\\dddot{\\dddot{\\rho}}(0) & 0 & -\\ddot{\\dddot{\\dddot{\\rho}}}(0) & 0 & -\\dot{\\dddot{\\dddot{\\dddot{\\rho}}}}(0)\\\\\n", 267 | " \\vdots & & & & & & \\ddots\n", 268 | " \\end{bmatrix} \\\\\n", 269 | " &= \\sigma^2 \\begin{bmatrix}\n", 270 | " 1 & 0 & -\\gamma & 0 & 3 \\gamma^2 & 0 \\\\\n", 271 | " 0 & \\gamma & 0 & -3 \\gamma^2 & 0 & 15 \\gamma^3 \\\\\n", 272 
| " -\\gamma & 0 & 3 \\gamma^2 & 0 & -15 \\gamma^3 & 0 \\\\\n", 273 | " 0 & -3 \\gamma^2 & 0 & 15 \\gamma^3 & 0 & -105 \\gamma^4 \\\\\n", 274 | " 3 \\gamma^2 & 0 & -15 \\gamma^3 & 0 & 105 \\gamma^4 & 0 \\\\\n", 275 | " 0 & 15 \\gamma^3 & 0 & -105 \\gamma^4 & 0 & 945 \\gamma^5 \\\\\n", 276 | " \\end{bmatrix} \\\\\n", 277 | "\\end{align*}" 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "metadata": {}, 283 | "source": [ 284 | "Note that to obtain Friston's version, we need to replace $\\gamma$ with $\\gamma / 2$. In his work he simply states that he assumes a \"Gaussian form\" for the autocorrelation with precision parameter $\\gamma$. The autocorrelation function isn't explicitly defined, but to reproduce the results in the paper it would have to be $\\rho(h) = \\exp(-\\frac{\\gamma}{4} h^2)$, which implies that $\\gamma$ is twice the precision of the Gaussian autocorrelation function. This seems an unusual definition, but it yields the results in his papers:\n", 285 | "\n", 286 | "\\begin{align*}\n", 287 | "\\mathbf{\\Sigma}^2 = \\mathbf{\\Pi}^{-1}\n", 288 | " &= \\sigma^2 \\begin{bmatrix}\n", 289 | " 1 & 0 & -\\frac{\\gamma}{2} & 0 & 3 \\left(\\frac{\\gamma}{2}\\right)^2 & 0 \\\\\n", 290 | " 0 & \\frac{\\gamma}{2} & 0 & -3 \\left(\\frac{\\gamma}{2}\\right)^2 & 0 & 15 \\left(\\frac{\\gamma}{2}\\right)^3 \\\\\n", 291 | " -\\frac{\\gamma}{2} & 0 & 3 \\left(\\frac{\\gamma}{2}\\right)^2 & 0 & -15 \\left(\\frac{\\gamma}{2}\\right)^3 & 0 \\\\\n", 292 | " 0 & -3 \\left(\\frac{\\gamma}{2}\\right)^2 & 0 & 15 \\left(\\frac{\\gamma}{2}\\right)^3 & 0 & -105 \\left(\\frac{\\gamma}{2}\\right)^4 \\\\\n", 293 | " 3 \\left(\\frac{\\gamma}{2}\\right)^2 & 0 & -15 \\left(\\frac{\\gamma}{2}\\right)^3 & 0 & 105 \\left(\\frac{\\gamma}{2}\\right)^4 & 0 \\\\\n", 294 | " 0 & 15 \\left(\\frac{\\gamma}{2}\\right)^3 & 0 & -105 \\left(\\frac{\\gamma}{2}\\right)^4 & 0 & 945 \\left(\\frac{\\gamma}{2}\\right)^5 \\\\\n", 295 | " \\end{bmatrix} \\\\\n", 296 | " &= \\sigma^2 
\\begin{bmatrix}\n", 297 | " 1 & 0 & -\\frac{1}{2}\\gamma & 0 & \\frac{3}{4}\\gamma^2 & 0 \\\\\n", 298 | " 0 & \\frac{1}{2}\\gamma & 0 & -\\frac{3}{4}\\gamma^2 & 0 & \\frac{15}{8}\\gamma^3 \\\\\n", 299 | " -\\frac{1}{2}\\gamma & 0 & \\frac{3}{4}\\gamma^2 & 0 & -\\frac{15}{8}\\gamma^3 & 0 \\\\\n", 300 | " 0 & -\\frac{3}{4}\\gamma^2 & 0 & \\frac{15}{8}\\gamma^3 & 0 & -\\frac{105}{16}\\gamma^4 \\\\\n", 301 | " \\frac{3}{4}\\gamma^2 & 0 & -\\frac{15}{8}\\gamma^3 & 0 & \\frac{105}{16}\\gamma^4 & 0 \\\\\n", 302 | " 0 & \\frac{15}{8}\\gamma^3 & 0 & -\\frac{105}{16}\\gamma^4 & 0 & \\frac{945}{32}\\gamma^5 \\\\\n", 303 | " \\end{bmatrix} \\\\\n", 304 | "\\end{align*}" 305 | ] 306 | }, 307 | { 308 | "cell_type": "markdown", 309 | "metadata": {}, 310 | "source": [ 311 | "Note that the literature on Hermite polynomials can cause confusion because the physicist's and probabilist's definitions differ. If we define the autocorrelation as $\\rho(h) = \\exp(-\\gamma h^2)$, we can use the physicist's Hermite polynomials instead:\n", 312 | "\n", 313 | "\\begin{align*}\n", 314 | "H_n(x) &= (-1)^n e^{x^2} \\frac{d^n}{dx^n} e^{-x^2} \\\\\n", 315 | "\\rho(h) &= e^{-\\gamma h^2} \\\\\n", 316 | "&= \\frac{H_0(\\gamma^{1/2} h)}{(-1)^0 e^{\\gamma h^2}} \\\\\n", 317 | "\\overset{n}\\rho(h) &= \\gamma^{n/2} \\frac{H_n(\\gamma^{1/2} h)}{(-1)^n e^{\\gamma h^2}} \\\\\n", 318 | "\\end{align*}\n", 319 | "\n", 320 | "Evaluating these at $h=0$ for even values of $n$, the denominator evaluates to $1$, so this simplifies to:\n", 321 | "\n", 322 | "\\begin{align*}\n", 323 | "\\overset{n}\\rho(0) &= \\gamma^{n/2} H_n(0) \\\\\n", 324 | "\\end{align*}\n", 325 | "\n", 326 | "The explicit expression for the Hermite polynomials can be used to calculate the coefficients for all even derivatives quickly. For even $n$, they are given by\n", 327 | "\n", 328 | "\\begin{align*}\n", 329 | "H_n(x) &= n!\\sum_{m=0}^{n/2}\\frac{(-1)^{n/2 - m}}{(2m)! 
(n/2-m)!} (2x)^{2m} \\\\\n", 330 | "H_n(0) &= n!\\sum_{m=0}^{n/2}\\frac{(-1)^{n/2 - m}}{(2m)! (n/2-m)!} (0)^{2m} \\\\\n", 331 | " &= n!\\frac{(-1)^{n/2}}{(0)! (n/2)!} 0^0 \\\\\n", 332 | " &= (-1)^{n/2}\\frac{n!}{(n/2)!} \\\\\n", 333 | "\\end{align*}" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 3, 339 | "metadata": {}, 340 | "outputs": [ 341 | { 342 | "name": "stdout", 343 | "output_type": "stream", 344 | "text": [ 345 | "Derivatives of 𝜌\n", 346 | "0 1.0 γ^0\n", 347 | "2 -2.0 γ^1\n", 348 | "4 12.0 γ^2\n", 349 | "6 -120.0 γ^3\n", 350 | "8 1680.0 γ^4\n", 351 | "10 -30240.0 γ^5\n" 352 | ] 353 | } 354 | ], 355 | "source": [ 356 | "print(\"Derivatives of 𝜌\")\n", 357 | "for m in range(6):\n", 358 | " print(\"{:<3} {:>8} γ^{:}\".format(2*m, (-1)**m * fac(2*m) / fac(m), m))" 359 | ] 360 | }, 361 | { 362 | "cell_type": "markdown", 363 | "metadata": {}, 364 | "source": [ 365 | "\\begin{align*}\n", 366 | "\\mathbf{\\Sigma}^2 = \\mathbf{\\Pi}^{-1} &=\n", 367 | " \\sigma^2 \\begin{bmatrix}\n", 368 | " 1 & 0 & \\ddot{\\rho}(0) & 0 & \\ddot{\\ddot{\\rho}}(0) & 0 & \\cdots \\\\\n", 369 | " 0 & -\\ddot{\\rho}(0) & 0 & -\\ddot{\\ddot{\\rho}}(0) & 0 & -\\dddot{\\dddot{\\rho}}(0)\\\\\n", 370 | " \\ddot{\\rho}(0) & 0 & \\ddot{\\ddot{\\rho}}(0) & 0 & \\dddot{\\dddot{\\rho}}(0) & 0 \\\\\n", 371 | " 0 & -\\ddot{\\ddot{\\rho}}(0) & 0 & -\\dddot{\\dddot{\\rho}}(0) & 0 & -\\ddot{\\dddot{\\dddot{\\rho}}}(0) \\\\\n", 372 | " \\ddot{\\ddot{\\rho}}(0) & 0 & \\dddot{\\dddot{\\rho}}(0) & 0 & \\ddot{\\dddot{\\dddot{\\rho}}}(0) & 0 \\\\\n", 373 | " 0 & -\\dddot{\\dddot{\\rho}}(0) & 0 & -\\ddot{\\dddot{\\dddot{\\rho}}}(0) & 0 & -\\dot{\\dddot{\\dddot{\\dddot{\\rho}}}}(0)\\\\\n", 374 | " \\vdots & & & & & & \\ddots\n", 375 | " \\end{bmatrix} \\\\\n", 376 | " &= \\sigma^2 \\begin{bmatrix}\n", 377 | " 1 & 0 & -2\\gamma & 0 & 12 \\gamma^2 & 0 \\\\\n", 378 | " 0 & 2\\gamma & 0 & -12 \\gamma^2 & 0 & 120 \\gamma^3 \\\\\n", 379 | " -2\\gamma & 0 & 12 \\gamma^2 & 0 & -120 
\\gamma^3 & 0 \\\\\n", 380 | " 0 & -12 \\gamma^2 & 0 & 120 \\gamma^3 & 0 & -1680 \\gamma^4 \\\\\n", 381 | " 12 \\gamma^2 & 0 & -120 \\gamma^3 & 0 & 1680 \\gamma^4 & 0 \\\\\n", 382 | " 0 & 120 \\gamma^3 & 0 & -1680 \\gamma^4 & 0 & 30240 \\gamma^5 \\\\\n", 383 | " \\end{bmatrix} \\\\\n", 384 | "\\end{align*}" 385 | ] 386 | } 387 | ], 388 | "metadata": { 389 | "kernelspec": { 390 | "display_name": "Python 3 (ipykernel)", 391 | "language": "python", 392 | "name": "python3" 393 | }, 394 | "language_info": { 395 | "codemirror_mode": { 396 | "name": "ipython", 397 | "version": 3 398 | }, 399 | "file_extension": ".py", 400 | "mimetype": "text/x-python", 401 | "name": "python", 402 | "nbconvert_exporter": "python", 403 | "pygments_lexer": "ipython3", 404 | "version": "3.7.11" 405 | } 406 | }, 407 | "nbformat": 4, 408 | "nbformat_minor": 4 409 | } 410 | -------------------------------------------------------------------------------- /Slides.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "slideshow": { 7 | "slide_type": "slide" 8 | } 9 | }, 10 | "source": [ 11 | "## Dynamic Expectation Maximisation\n", 12 | "\n", 13 | "

\n", 14 | "### André van Schaik\n", 15 | "\n", 16 | "

\n", 17 | "#### __[International Centre for Neuromorphic Systems](https://westernsydney.edu.au/icns)__ \n", 18 | "12/03/2019" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "slideshow": { 25 | "slide_type": "slide" 26 | } 27 | }, 28 | "source": [ 29 | "" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": { 35 | "slideshow": { 36 | "slide_type": "subslide" 37 | } 38 | }, 39 | "source": [ 40 | "" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": { 46 | "slideshow": { 47 | "slide_type": "subslide" 48 | } 49 | }, 50 | "source": [ 51 | "" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": { 57 | "slideshow": { 58 | "slide_type": "slide" 59 | } 60 | }, 61 | "source": [ 62 | "### Hidden Markov Model ###\n", 63 | "\n", 64 | "We can write a Hidden Markov Model as:\n", 65 | "\n", 66 | "\\begin{align*}\n", 67 | "y[t] &= g(x[t]) + z \\\\\n", 68 | "x[t] &= f(x[t-1]) + w\n", 69 | "\\end{align*}\n" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": { 75 | "slideshow": { 76 | "slide_type": "subslide" 77 | } 78 | }, 79 | "source": [ 80 | "" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": { 86 | "slideshow": { 87 | "slide_type": "subslide" 88 | } 89 | }, 90 | "source": [ 91 | "### Dynamic model in generalised coordinates ###\n", 92 | "\n", 93 | "We can write a dynamic input-state-output model as:\n", 94 | "\n", 95 | "\\begin{align*}\n", 96 | "\\tilde y &= \\tilde g + \\tilde z \\\\\n", 97 | "D \\tilde x &= \\tilde f + \\tilde w \\tag{1}\n", 98 | "\\end{align*}\n", 99 | "\n", 100 | "where the $\\tilde a$ notation indicates variables and functions in generalised coordinates of motion $\\tilde a = [a, a', a'', a''', ...]^T$. 
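Here is a minimal sketch of the derivative operator for a single variable. It truncates the generalised coordinates at a finite number of orders, and both that truncation and the helper names are assumptions for illustration. D has ones on the first diagonal above the main diagonal, so applying it to [a, a', a'', a'''] shifts each order up by one and leaves a zero in the last, truncated position. For several variables, the same operator can be formed as a Kronecker product with an identity matrix:

```python
def derivative_operator(order):
    # ones on the first diagonal above the main one: (D a)[i] = a[i + 1]
    return [[1.0 if k == j + 1 else 0.0 for k in range(order)] for j in range(order)]


def matvec(M, v):
    # dense matrix-vector product with lists of lists
    return [sum(mjk * vk for mjk, vk in zip(row, v)) for row in M]


def kron(A, B):
    # Kronecker product of two dense matrices given as lists of lists
    return [[a * b for a in rowA for b in rowB] for rowA in A for rowB in B]


D = derivative_operator(4)
a_tilde = [5.0, 3.0, -1.0, 0.5]    # [a, a', a'', a'''] (illustrative values)
print(matvec(D, a_tilde))          # [3.0, -1.0, 0.5, 0.0]

# two variables stored as [a1, a2, a1', a2', ...]: the operator becomes kron(D, I)
I2 = [[1.0, 0.0], [0.0, 1.0]]
D2 = kron(D, I2)
print(matvec(D2, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]))
```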
" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "metadata": { 106 | "slideshow": { 107 | "slide_type": "subslide" 108 | } 109 | }, 110 | "source": [ 111 | "$D$ is a block-matrix derivative operator whose first leading-diagonal (the block diagonal immediately above the main diagonal) contains identity matrices. This operator simply shifts the vectors of generalised motion so that $a[i]$ is replaced by $a[i+1]$.\n", 112 | "\n", 113 | "Importantly, $a'$ is the *value* of the derivative of $a$ with respect to time; in other words, it carries no unit of time, even though the time derivative obviously has a unit of $[s^{-1}]$. One way of interpreting this is that $a' = \\tau\\,da/dt$ where $\\tau = dt$ and all time is measured in units of $dt$, and similarly for the higher orders of motion." 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": { 119 | "slideshow": { 120 | "slide_type": "subslide" 121 | } 122 | }, 123 | "source": [ 124 | "The predicted sensor response $\\tilde g$ and motion $\\tilde f$ of the hidden states $\\tilde x$ in the absence of random fluctuations are:\n", 125 | "\n", 126 | "\\begin{align*}\n", 127 | "\\begin{split}\n", 128 | "g &= g(x, v) \\\\\n", 129 | "g' &= g_x x' + g_v v' \\\\\n", 130 | "g'' &= g_x x'' + g_v v'' \\\\\n", 131 | "&\\phantom{g=\\,} \\vdots \\\\\n", 132 | "\\end{split}\n", 133 | "\\:\\:\\:\n", 134 | "\\begin{split}\n", 135 | "f &= f(x, v) \\\\\n", 136 | "f' &= f_x x' + f_v v' \\\\\n", 137 | "f'' &= f_x x'' + f_v v'' \\\\\n", 138 | "&\\phantom{f=\\,} \\vdots \\\\\n", 139 | "\\end{split}\n", 140 | "\\end{align*}\n", 141 | "\n", 142 | "Here, $f$ and $g$ are continuous nonlinear functions and $\\tilde v$ are known causes or inputs, which can also result from actions by the agent. The notation $a_b$ is shorthand for $\\partial{a}/\\partial{b}$. 
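As an illustration of these expressions for a scalar g, the generalised prediction can be assembled order by order from the partial derivatives g_x and g_v evaluated at the current x and v. The function `g`, the input values and the use of central differences for the partials are all assumptions of this sketch:

```python
import math


def partials(func, x, v, eps=1e-6):
    # central-difference estimates of func_x and func_v at (x, v)
    fx = (func(x + eps, v) - func(x - eps, v)) / (2 * eps)
    fv = (func(x, v + eps) - func(x, v - eps)) / (2 * eps)
    return fx, fv


def generalised_prediction(func, x_tilde, v_tilde):
    # [g, g', g'', ...] with g^(i) = g_x x^(i) + g_v v^(i) for i >= 1
    x, v = x_tilde[0], v_tilde[0]
    gx, gv = partials(func, x, v)
    out = [func(x, v)]
    for xi, vi in zip(x_tilde[1:], v_tilde[1:]):
        out.append(gx * xi + gv * vi)
    return out


def g(x, v):
    # hypothetical sensor mapping
    return math.tanh(x) + 0.5 * v


p = generalised_prediction(g, [0.2, 1.0, -0.3], [1.0, 0.0, 2.0])
print(p)
```

The same routine applies to f, because the motion of the hidden states is expanded in exactly the same way.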
" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": { 148 | "slideshow": { 149 | "slide_type": "subslide" 150 | } 151 | }, 152 | "source": [ 153 | "We assume that the observation noise $\\tilde z$ follows a zero-mean Gaussian distribution and similarly for the state noise $\\tilde w$. The input drive is also Gaussian but with a mean that can be different from zero. We also assume Gaussian priors for the parameters $\\theta$ and the hyperparameters $\\lambda$.\n", 154 | "\n", 155 | "\\begin{align*}\n", 156 | "p(\\tilde z) &= \\mathcal{N}(0, \\tilde \\Sigma^z)\\\\\n", 157 | "p(\\tilde w) &= \\mathcal{N}(0, \\tilde \\Sigma^w)\\\\\n", 158 | "p(\\tilde v) &= \\mathcal{N}(\\tilde \\eta^v, \\tilde C^v)\\\\\n", 159 | "p(\\lambda) &= \\mathcal{N}(\\lambda : \\eta^\\lambda, C^\\lambda)\\\\\n", 160 | "p(\\theta) &= \\mathcal{N}(\\theta : \\eta^\\theta, C^\\theta)\n", 161 | "\\end{align*}\n", 162 | "\n", 163 | "Note that we use $\\tilde C^v$ instead of $\\tilde \\Sigma^v$ to indicate that this is a *prior* covariance, rather than a *conditional* covariance.\n" 164 | ] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "metadata": { 169 | "slideshow": { 170 | "slide_type": "subslide" 171 | } 172 | }, 173 | "source": [ 174 | "We can now evaluate the joint density over observations $\\tilde y$, hidden states $\\tilde x$, and inputs $\\tilde v$:\n", 175 | "\n", 176 | "\\begin{align*}\n", 177 | "p(\\tilde y, \\tilde x, \\tilde v \\,|\\, \\theta, \\lambda) &= p(\\tilde y \\,|\\, \\tilde x, \\tilde v, \\theta, \\lambda) \\; p(\\tilde x \\,|\\, \\tilde v, \\theta, \\lambda) \\; p(\\tilde v) \\\\\n", 178 | "p(\\tilde y \\,|\\, \\tilde x, \\tilde v, \\theta, \\lambda) &= \\mathcal{N}(\\tilde y : \\tilde g, \\tilde \\Sigma(\\lambda)^z) \\\\\n", 179 | "p(\\tilde x \\,|\\, \\tilde v, \\theta, \\lambda) &= \\mathcal{N}(D\\tilde x : \\tilde f, \\tilde \\Sigma(\\lambda)^w) \\\\\n", 180 | "p(\\tilde v) &= \\mathcal{N}(\\tilde v : \\tilde \\eta^v, \\tilde C^v)\n", 181 | "\\end{align*}\n", 182 | "\n", 183 | 
"where $\\theta$ contains the parameters describing $f$ and $g$, and $\\lambda$ are hyperparameters which control the amplitude and smoothness of the random fluctuations. Here we have indicated explicitly which random variable is generated by each normal distribution. According to $(1)$, the random variable for state transition is $D\\tilde x$, which therefore links different levels of motion.\n", 184 | "\n" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": { 190 | "slideshow": { 191 | "slide_type": "subslide" 192 | } 193 | }, 194 | "source": [ 195 | "This allows us to write the directed Bayesian graph for the model:\n", 196 | "\n", 197 | "" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": { 203 | "slideshow": { 204 | "slide_type": "subslide" 205 | } 206 | }, 207 | "source": [ 208 | "### Hierarchical dynamic model ###\n", 209 | "\n", 210 | "For a hierarchical dynamic model (HDM), we assume each higher level generates causes for the level below, so that the causes $v$ link levels, whereas hidden states $x$ link dynamics over time. Further it is assumed that the noise processes at each level $w^{(i)}$ and $z^{(i)}$ are conditionally independent. This leads to the following Bayesian directed graph:\n", 211 | "\n", 212 | "\n", 213 | "\n", 214 | "Here $\\vartheta^{(i)} = [\\theta^{(i)}, \\lambda^{(i)}]$ and $u^{(i)} = [\\tilde v^{(i)}, \\tilde x^{(i)}]$." 215 | ] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "metadata": { 220 | "slideshow": { 221 | "slide_type": "skip" 222 | } 223 | }, 224 | "source": [ 225 | "### Temporal smoothness ###\n", 226 | "\n", 227 | "Since the different levels of motion are linked, the covariance matrix will have off-diagonal elements with non-zero values. 
The covariance is given by the Kronecker product $\\tilde \\Sigma(\\lambda)^z = S(\\gamma)^{-1} \\otimes \\Sigma(\\lambda)^z$, where $\\Sigma(\\lambda)^z$ is a diagonal matrix specifying the variance of each component of the (often assumed independent) Gaussian noise, and $S(\\gamma)$ is the temporal precision matrix, which encodes the temporal dependencies between levels of motion as a function of their autocorrelations:\n", 228 | "\n", 229 | "\\begin{align*}\n", 230 | "S(\\gamma)^{-1} &= \n", 231 | " \\begin{bmatrix}\n", 232 | " 1 & 0 & \\ddot{\\rho}(0) & \\cdots \\\\\n", 233 | " 0 & -\\ddot{\\rho}(0) & 0 \\\\\n", 234 | " \\ddot{\\rho}(0) & 0 & \\ddot{\\ddot{\\rho}}(0) \\\\\n", 235 | " \\vdots & & & \\ddots\n", 236 | " \\end{bmatrix}\n", 237 | "\\end{align*}\n", 238 | "\n", 239 | "Here $\\ddot{\\rho}(0)$ is the second derivative of the autocorrelation function evaluated at zero. Note that because the autocorrelation function is even (symmetrical for positive and negative delays), the odd derivatives of the autocorrelation function are all odd functions, and thus are zero when evaluated at zero.\n", 240 | "\n", 241 | "While $S(\\gamma)^{-1}$ can be evaluated for any analytic autocorrelation function, we assume here that the temporal correlations all have the same Gaussian form, which gives:\n", 242 | "\n", 243 | "\\begin{align*}\n", 244 | "S(\\gamma)^{-1} &= \n", 245 | " \\begin{bmatrix}\n", 246 | " 1 & 0 & -\\frac{1}{2}\\gamma & \\cdots \\\\\n", 247 | " 0 & \\frac{1}{2}\\gamma & 0 \\\\\n", 248 | " -\\frac{1}{2}\\gamma & 0 & \\frac{3}{4}\\gamma^2 \\\\\n", 249 | " \\vdots & & & \\ddots\n", 250 | " \\end{bmatrix}\n", 251 | "\\end{align*}\n", 252 | "\n", 253 | "Here, $\\gamma$ is the precision parameter of a Gaussian autocorrelation function. Typically, $\\gamma > 1$, which ensures the precisions of high-order motion converge quickly to zero. 
This is important because it enables us to truncate the representation of an infinite number of generalised coordinates to a relatively small number, since high-order prediction errors have a vanishingly small precision. Friston states that an order of $n=6$ is sufficient in most cases.\n", 254 | "\n", 255 | "Instead of using the covariance matrix, we can use its inverse, the precision matrix, which is defined by $\\tilde \\Pi(\\lambda)^z = S(\\gamma) \\otimes \\Pi(\\lambda)^z$, where $\\Pi(\\lambda)^z$ is a diagonal matrix of the precisions of the individual Gaussian noise components (the dependencies between levels of the generalised coordinates are carried by $S(\\gamma)$).\n", 256 | "\n", 257 | "\\begin{align*}\n", 258 | "p(\\tilde y, \\tilde x, \\tilde v \\,|\\, \\theta, \\lambda) &= p(\\tilde y \\,|\\, \\tilde x, \\tilde v, \\theta, \\lambda) \\; p(\\tilde x \\,|\\, \\tilde v, \\theta, \\lambda) \\; p(\\tilde v) \\\\\n", 259 | " &= (2\\pi)^{-N_y/2} {|\\tilde\\Pi^z|}^{1/2} e^{-\\frac{1}{2}{\\tilde\\varepsilon^v}^T \\tilde\\Pi^z \\tilde\\varepsilon^v} (2\\pi)^{-N_x/2} {|\\tilde\\Pi^w|}^{1/2} e^{-\\frac{1}{2}{\\tilde\\varepsilon^x}^T \\tilde\\Pi^w \\tilde\\varepsilon^x} \\; p(\\tilde v) \\\\\n", 260 | " &= (2\\pi)^{-(N_y+N_x)/2} (|\\tilde\\Pi^z| \\, |\\tilde\\Pi^w|)^{1/2} e^{-\\frac{1}{2}{\\tilde\\varepsilon^v}^T \\tilde\\Pi^z \\tilde\\varepsilon^v} e^{-\\frac{1}{2}{\\tilde\\varepsilon^x}^T \\tilde\\Pi^w \\tilde\\varepsilon^x} \\; p(\\tilde v)\\\\\n", 261 | " &= (2\\pi)^{-N/2} |\\tilde\\Pi|^{1/2} e^{-\\frac{1}{2}{\\tilde\\varepsilon}^T \\tilde\\Pi \\tilde\\varepsilon} \\; p(\\tilde v)\\\\ \n", 262 | "\\tilde\\Pi &= \n", 263 | " \\begin{bmatrix}\n", 264 | " \\tilde\\Pi^z & \\\\\n", 265 | " & \\tilde\\Pi^w\n", 266 | " \\end{bmatrix}\\\\\n", 267 | "\\tilde\\varepsilon &= \n", 268 | " \\begin{bmatrix}\n", 269 | " \\tilde\\varepsilon^v = \\ \\ \\ \\tilde y - \\tilde g \\\\\n", 270 | " \\tilde\\varepsilon^x = D\\tilde x - \\tilde f\n", 271 | " \\end{bmatrix}\\\\\n", 272 | "N &= \\text{Rank}(\\tilde\\Pi)\n", 273 | "\\end{align*} \n", 274 | 
"\n", 275 | "Here we introduce auxiliary variables $\\tilde\\varepsilon(t)$, which are the prediction errors for the generalised responses and motion of the hidden states, with respective predictions $\\tilde g(t)$ and $\\tilde f(t)$ and their precisions encoded by $\\tilde\\Pi$.\n", 276 | "\n", 277 | "The log probability can thus be written:\n", 278 | "\\begin{align*}\n", 279 | "\\ln p(\\tilde y, \\tilde x, \\tilde v \\,|\\, \\theta, \\lambda) &= \\frac{1}{2} \\ln |\\tilde\\Pi| - \\frac{1}{2}{{\\tilde\\varepsilon}^T \\tilde\\Pi \\tilde\\varepsilon} - \\frac{N}{2} \\ln 2\\pi + \\ln p(\\tilde v)\n", 280 | "\\end{align*}\n", 281 | "\n", 282 | "where the third term is constant, and the fourth term is defined by the input causes and considered known." 283 | ] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "metadata": { 288 | "slideshow": { 289 | "slide_type": "slide" 290 | } 291 | }, 292 | "source": [ 293 | "### Model inversion ###\n", 294 | "\n", 295 | "For model inversion, we are trying to estimate the parameters $\\vartheta$ of a model $m$ given some observations $\\tilde y$ by maximising the conditional density $p(\\vartheta \\,|\\, \\tilde y, m)$. However, this density is in general not directly calculable, because normalising it requires the evidence $p(\\tilde y, m)$, an integral over all possible values of $\\vartheta$. *Variational Bayes* suggests a workaround: minimise the Kullback-Leibler divergence between a recognition density $q(\\vartheta)$, which encodes the agent's beliefs about the state of its environment, and the true Bayesian posterior." 
296 | ] 297 | }, 298 | { 299 | "cell_type": "markdown", 300 | "metadata": { 301 | "slideshow": { 302 | "slide_type": "subslide" 303 | } 304 | }, 305 | "source": [ 306 | "\\begin{align*}\n", 307 | "D_{KL}(\\: q(\\vartheta) \\; || \\; p(\\vartheta \\,|\\, \\tilde y, m) \\: ) = \\int{q(\\vartheta) \\: \\ln \\frac{q(\\vartheta)}{p(\\vartheta\\,|\\, \\tilde y, m)} \\: d\\vartheta}\n", 308 | "\\end{align*}\n", 309 | "\n", 310 | "The KL divergence is a measure of the difference between two probability distributions, is always non-negative, and is 0 if and only if the two distributions are the same. Thus adapting $q(\\vartheta)$ to minimise this KL divergence will result in $q(\\vartheta)$ being a close approximation of $p(\\vartheta\\,|\\, \\tilde y, m)$. \n", 311 | "\n", 312 | "Obviously, to evaluate this KL divergence directly, we would still need to be able to calculate $p(\\vartheta\\,|\\, \\tilde y, m)$ and we seem to have made no progress. " 313 | ] 314 | }, 315 | { 316 | "cell_type": "markdown", 317 | "metadata": { 318 | "slideshow": { 319 | "slide_type": "subslide" 320 | } 321 | }, 322 | "source": [ 323 | "However, the FEP uses the fact that $p(\\vartheta, \\tilde y, m) = p(\\vartheta\\,|\\, \\tilde y, m) p(\\tilde y, m)$, to write this as:\n", 324 | "\n", 325 | "\\begin{align*}\n", 326 | "D_{KL}(\\: q(\\vartheta) \\; || \\; p(\\vartheta\\,|\\, \\tilde y, m) \\: ) &= \\int{q(\\vartheta) \\: \\ln \\frac{q(\\vartheta)}{p(\\vartheta, \\tilde y, m)/p(\\tilde y, m)} \\: d\\vartheta} \\\\\n", 327 | "&= \\int{q(\\vartheta) \\: \\ln \\frac{q(\\vartheta)}{p(\\vartheta, \\tilde y, m)} \\: d\\vartheta} + \\int{q(\\vartheta) \\: \\ln p(\\tilde y, m) \\: d\\vartheta} \\\\\n", 328 | "&= \\int{q(\\vartheta) \\: \\ln \\frac{q(\\vartheta)}{p(\\vartheta, \\tilde y, m)} \\: d\\vartheta} + \\ln p(\\tilde y, m) \\int{q(\\vartheta) \\: d\\vartheta} \\\\\n", 329 | "&= \\int{q(\\vartheta) \\: \\ln \\frac{q(\\vartheta)}{p(\\vartheta, \\tilde y, m)} \\: d\\vartheta} + \\ln p(\\tilde y, m) 
\\\\\n", 330 | "\\end{align*}\n", 331 | "\n", 332 | "since $\\int{q(\\vartheta) \\: d\\vartheta} = 1$ by definition of a probability density. \n", 333 | "\n", 334 | "We probably don't know $p(\\tilde y, m)$, but, since this doesn't depend on $\\vartheta$, it plays no role in optimising $q(\\vartheta)$. " 335 | ] 336 | }, 337 | { 338 | "cell_type": "markdown", 339 | "metadata": { 340 | "slideshow": { 341 | "slide_type": "subslide" 342 | } 343 | }, 344 | "source": [ 345 | "\\begin{align*}\n", 346 | "D_{KL}(\\: q(\\vartheta) \\; || \\; p(\\vartheta\\,|\\, \\tilde y, m) \\: ) &= \\ln p(\\tilde y, m) - F\\\\\n", 347 | "F &= -\\int{q(\\vartheta) \\: \\ln \\frac{q(\\vartheta)}{p(\\vartheta, \\tilde y, m)} \\: d\\vartheta} \\\\\n", 348 | "&= -D_{KL}(\\: q(\\vartheta) \\; || \\; p(\\vartheta, \\tilde y, m) \\: )\\\\\n", 349 | "\\end{align*}" 350 | ] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "metadata": { 355 | "slideshow": { 356 | "slide_type": "subslide" 357 | } 358 | }, 359 | "source": [ 360 | "The joint density $p(\\vartheta,\\tilde y, m)$ is called the generative density, and represents the agent's belief in how the world works. It can be factorised into \n", 361 | "\n", 362 | "\\begin{align*}\n", 363 | "p(\\vartheta,\\tilde y, m) = p(\\tilde y, \\vartheta, m) = p(\\tilde y \\,|\\, \\vartheta, m)\\:p(\\vartheta, m)\n", 364 | "\\end{align*}\n", 365 | "\n", 366 | "where a prior $p(\\vartheta, m)$ encodes the agent's beliefs for the world states prior to new sensory input, and a likelihood $p(\\tilde y|\\vartheta, m)$ encodes how the agent's sensory signals relate to the world states. 
" 367 | ] 368 | }, 369 | { 370 | "cell_type": "markdown", 371 | "metadata": { 372 | "slideshow": { 373 | "slide_type": "subslide" 374 | } 375 | }, 376 | "source": [ 377 | "Thus, if we have a model for how the world states generate sensory perception (or if we can learn one), we can calculate $F$, which is called the *Variational Free Energy*, and is the negative of the KL divergence between the Recognition density, $q(\\vartheta)$, and the Generative density, $p(\\vartheta, \\tilde y, m)$. Since $\\ln p(\\tilde y, m)$ does not depend on $\\vartheta$, maximising $F$ is equivalent to minimising $D_{KL}(\\: q(\\vartheta) \\; || \\; p(\\vartheta\\,|\\, \\tilde y, m) \\: )$. We can therefore simply maximise $F$ to make $q(\\vartheta)$ the best possible approximation of the posterior $p(\\vartheta\\,|\\, \\tilde y, m)$." 378 | ] 379 | }, 380 | { 381 | "cell_type": "markdown", 382 | "metadata": { 383 | "slideshow": { 384 | "slide_type": "subslide" 385 | } 386 | }, 387 | "source": [ 388 | "Given a model, we can determine the probability of the observations under the model by:\n", 389 | "\n", 390 | "\\begin{align*}\n", 391 | "\\ln p(\\tilde y, m) &= \\ln p(\\tilde y \\,|\\, m) + \\ln p(m) \\\\\n", 392 | "\\ln p(\\tilde y \\,|\\, m) &= \\ln p(\\tilde y, m) - \\ln p(m) \\\\\n", 393 | "&= F + D_{KL}(\\: q(\\vartheta) \\; || \\; p(\\vartheta\\,|\\, \\tilde y, m) \\: ) - \\ln p(m)\n", 394 | "\\end{align*}" 395 | ] 396 | }, 397 | { 398 | "cell_type": "markdown", 399 | "metadata": { 400 | "slideshow": { 401 | "slide_type": "subslide" 402 | } 403 | }, 404 | "source": [ 405 | "Thus for a given model ($p(m)=1$), we can write the log-evidence (the log marginal likelihood) as:\n", 406 | "\n", 407 | "\\begin{align*}\n", 408 | "\\ln p(\\tilde y \\,|\\, m) &= F + D_{KL}(\\: q(\\vartheta) \\; || \\; p(\\vartheta\\,|\\, \\tilde y, m) \\: )\n", 409 | "\\end{align*}\n", 410 | "\n", 411 | "This indicates that $F$ is a lower bound on the log-evidence, since the KL divergence term is always non-negative and is $0$ if and only if $q(\\vartheta) = p(\\vartheta\\,|\\, \\tilde y, m)$."
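The bound can be checked numerically on a model simple enough that the log-evidence is known exactly. The sketch below assumes a hypothetical one-dimensional conjugate model that does not appear in the text: prior $\vartheta \sim \mathcal{N}(0, 1)$, likelihood $y \sim \mathcal{N}(\vartheta, \sigma^2)$, and a Gaussian recognition density $q(\vartheta) = \mathcal{N}(m, s)$. $F$ is computed in closed form as the expected log joint plus the entropy of $q$; it equals $\ln p(y)$ only when $q$ is the exact posterior and is smaller otherwise.

```python
import math

def free_energy(m, s, y, noise_var):
    """F = <ln p(y|theta)>_q + <ln p(theta)>_q + H[q] for q = N(m, s),
    prior p(theta) = N(0, 1) and likelihood p(y|theta) = N(theta, noise_var)."""
    expected_loglik = (-0.5 * math.log(2 * math.pi * noise_var)
                       - ((y - m) ** 2 + s) / (2 * noise_var))
    expected_logprior = -0.5 * math.log(2 * math.pi) - (m ** 2 + s) / 2
    entropy = 0.5 * math.log(2 * math.pi * math.e * s)
    return expected_loglik + expected_logprior + entropy

def log_evidence(y, noise_var):
    """Exact ln p(y): marginally, y ~ N(0, 1 + noise_var)."""
    v = 1.0 + noise_var
    return -0.5 * math.log(2 * math.pi * v) - y ** 2 / (2 * v)

def exact_posterior(y, noise_var):
    """Conjugate posterior N(mean, var) of theta given y."""
    var = 1.0 / (1.0 + 1.0 / noise_var)
    return var * y / noise_var, var

y, noise_var = 1.0, 1.0
m_post, s_post = exact_posterior(y, noise_var)
print(free_energy(m_post, s_post, y, noise_var), log_evidence(y, noise_var))  # equal
print(free_energy(0.0, 1.0, y, noise_var))  # any other q gives a smaller F
```

The gap between `log_evidence` and `free_energy` for any other choice of `m` and `s` is exactly the KL divergence to the posterior.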
412 | ] 413 | }, 414 | { 415 | "cell_type": "markdown", 416 | "metadata": { 417 | "slideshow": { 418 | "slide_type": "skip" 419 | } 420 | }, 421 | "source": [ 422 | "$F$ can also be expressed as:\n", 423 | "\n", 424 | "\\begin{align*}\n", 425 | "F &= -\\int{q(\\vartheta) \\: \\ln \\frac{q(\\vartheta)}{p(\\vartheta, \\tilde y, m)} \\: d\\vartheta} \\\\\n", 426 | "&= \\left< \\ln p(\\vartheta, \\tilde y, m) \\right>_q - \\left< \\ln q(\\vartheta) \\right>_q\n", 427 | "\\end{align*}\n", 428 | "\n", 429 | "which comprises the internal energy $U(\\vartheta, \\tilde y) = \\ln p(\\vartheta, \\tilde y)$ of a given model $m$ expected under $q(\\vartheta)$ and the entropy of $q(\\vartheta)$, which is a measure of its uncertainty.\n", 430 | "\n" 431 | ] 432 | }, 433 | { 434 | "cell_type": "markdown", 435 | "metadata": { 436 | "slideshow": { 437 | "slide_type": "slide" 438 | } 439 | }, 440 | "source": [ 441 | "### Optimisation ###\n", 442 | "\n", 443 | "The introduction of $q(\\vartheta)$ converts the difficult integration problem inherent in Bayesian Inference into a much simpler optimisation problem of adapting $q(\\vartheta)$ to maximise $F$. To further simplify calculation, we usually assume that the model parameters can be partitioned over the states $u = [\\tilde v, \\tilde x]^T$, the parameters $\\theta$, and the hyperparameters $\\lambda$, as:\n", 444 | "\n", 445 | "\\begin{align*}\n", 446 | "q(\\vartheta) &= q(u(t)) \\, q(\\theta) \\, q(\\lambda) \\\\\n", 447 | "&= \\prod_i q(\\vartheta^i) \\\\\n", 448 | "\\vartheta^i &= \\{u(t), \\theta, \\lambda\\}\n", 449 | "\\end{align*}\n", 450 | "\n", 451 | "This partition is called the *mean field* approximation in statistical physics. We further assume that over the timescale of inference, only the states $u$ change with time $t$, while the (hyper)parameters are assumed constant." 
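The energy-plus-entropy form $F = \langle \ln p(\vartheta, \tilde y, m) \rangle_q - \langle \ln q(\vartheta) \rangle_q$ also suggests a simple sampling estimate: draw $\vartheta$ from $q$ and average $U(\vartheta, \tilde y) - \ln q(\vartheta)$. The sketch below assumes a hypothetical one-dimensional toy model (prior $\mathcal{N}(0, 1)$, likelihood $\mathcal{N}(y; \vartheta, 1)$, Gaussian $q = \mathcal{N}(m, s)$) chosen only so that the expectation also has a closed form to compare against.

```python
import math
import random

def log_normal(x, mean, var):
    """Log density of N(mean, var) at x."""
    return -0.5 * math.log(2 * math.pi * var) - 0.5 * (x - mean) ** 2 / var

def internal_energy(theta, y):
    """U(theta, y) = ln p(y | theta) + ln p(theta) for the toy model."""
    return log_normal(y, theta, 1.0) + log_normal(theta, 0.0, 1.0)

def free_energy_mc(m, s, y, n_samples=100000, seed=0):
    """Monte Carlo estimate of <U>_q - <ln q>_q with q = N(m, s)."""
    rng = random.Random(seed)
    sd = math.sqrt(s)
    total = 0.0
    for _ in range(n_samples):
        theta = rng.gauss(m, sd)
        total += internal_energy(theta, y) - log_normal(theta, m, s)
    return total / n_samples

def free_energy_exact(m, s, y):
    """Closed form: Gaussian expectation of U plus the entropy of q."""
    expected_u = (-math.log(2 * math.pi)
                  - ((y - m) ** 2 + s) / 2
                  - (m ** 2 + s) / 2)
    entropy = 0.5 * math.log(2 * math.pi * math.e * s)
    return expected_u + entropy

print(free_energy_mc(0.2, 0.8, 1.0), free_energy_exact(0.2, 0.8, 1.0))
```

The two numbers agree up to Monte Carlo error, which shrinks roughly as one over the square root of `n_samples`.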
452 | ] 453 | }, 454 | { 455 | "cell_type": "markdown", 456 | "metadata": { 457 | "slideshow": { 458 | "slide_type": "skip" 459 | } 460 | }, 461 | "source": [ 462 | "Under this partition, optimisation is still achieved by maximising the Free Energy, but we can now do this separately for each partition, by averaging over the other partitions. To show this, we define $F$ as an integral over the parameter partitions:\n", 463 | "\n", 464 | "\\begin{align*}\n", 465 | "F &= \\int f^i \\, d\\vartheta^i \\\\\n", 466 | "&= \\int{q(\\vartheta) \\: \\ln p(\\vartheta, \\tilde y, m) \\: d\\vartheta} - \\int{q(\\vartheta) \\: \\ln q(\\vartheta) \\: d\\vartheta} \\\\\n", 467 | "&= \\iint{q(\\vartheta^i) \\: q(\\vartheta^{\\backslash i}) \\: \\ln p(\\vartheta, \\tilde y, m) \\: d\\vartheta^{\\backslash i} \\: d\\vartheta^i} - \\iint{q(\\vartheta^i) \\: q(\\vartheta^{\\backslash i}) \\: \\ln q(\\vartheta) \\: d\\vartheta^{\\backslash i} \\: d\\vartheta^i} \\\\\n", 468 | "f^i &= \\int{q(\\vartheta^i) \\: q(\\vartheta^{\\backslash i}) \\: \\ln p(\\vartheta, \\tilde y, m) \\: d\\vartheta^{\\backslash i} } - \\int{q(\\vartheta^i) \\: q(\\vartheta^{\\backslash i}) \\: \\ln q(\\vartheta) \\: d\\vartheta^{\\backslash i}} \\\\\n", 469 | "&= q(\\vartheta^i) \\: \\int{ q(\\vartheta^{\\backslash i}) \\: U(\\vartheta, \\tilde y) \\: d\\vartheta^{\\backslash i} } - q(\\vartheta^i) \\: \\int{q(\\vartheta^{\\backslash i}) \\: (\\ln q(\\vartheta^i) + \\ln q(\\vartheta^{\\backslash i})) \\: d\\vartheta^{\\backslash i}} \\\\\n", 470 | "&= q(\\vartheta^i) \\: V(\\vartheta^i) - q(\\vartheta^i) \\: \\ln q(\\vartheta^i) - \\int{q(\\vartheta^{\\backslash i}) \\: \\ln q(\\vartheta^{\\backslash i}) \\: d\\vartheta^{\\backslash i}} \\\\\n", 471 | "\\partial_{q(\\vartheta^i)} \\: f^i &= V(\\vartheta^i) - \\ln q(\\vartheta^i) - \\ln Z^i \\\\\n", 472 | "\\end{align*}\n", 473 | "\n", 474 | "Here, $\\vartheta^{\\backslash i}$ denotes all parameters not in set $i$, i.e., its Markov blanket, $Z^i$ contains all the 
terms of $f^i$ that do not depend on $\\vartheta^i$, and\n", 475 | "\n", 476 | "\\begin{align*}\n", 477 | "V(\\vartheta^i) &= \\int{q(\\vartheta^{\\backslash i}) \\: \\ln p(\\vartheta, \\tilde y, m) \\: d\\vartheta^{\\backslash i} } = \\int{ q(\\vartheta^{\\backslash i}) \\: U(\\vartheta, \\tilde y) \\: d\\vartheta^{\\backslash i} } = \\left< U(\\vartheta) \\right>_{q(\\vartheta^{\\backslash i})}\\\\\n", 478 | "\\end{align*}" 479 | ] 480 | }, 481 | { 482 | "cell_type": "markdown", 483 | "metadata": { 484 | "slideshow": { 485 | "slide_type": "skip" 486 | } 487 | }, 488 | "source": [ 489 | "The Fundamental Lemma of variational calculus states that the free energy is maximised when:\n", 490 | "\n", 491 | "\\begin{align*}\n", 492 | "\\delta_{q(\\vartheta^i)} F &= 0 \\Leftrightarrow \\partial_{q(\\vartheta^i)} \\: f^i = 0 \\\\\n", 493 | "\\ln q(\\vartheta^i) &= V(\\vartheta^i) - \\ln Z^i\\\\\n", 494 | "q(\\vartheta^i) &= \\frac{1}{Z^i} \\exp \\left(V(\\vartheta^i)\\right) = \\frac{1}{Z^i} \\exp\\left(\\left< U(\\vartheta) \\right>_{q(\\vartheta^{\\backslash i})}\\right) \n", 495 | "\\end{align*}\n", 496 | "\n", 497 | "Thus, $Z^i$ is a normalisation constant, and is also called a partition function in physics. The final equation indicates that the variational density over one parameter set is an exponential function of the internal energy averaged over all other parameters."
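The fixed-point condition $q(\vartheta^i) \propto \exp \langle U \rangle_{q(\vartheta^{\backslash i})}$ can be iterated as a coordinate-ascent scheme. The sketch below applies it to a hypothetical target whose internal energy is a two-dimensional Gaussian log-density, $U(\vartheta) = -\frac{1}{2}(\vartheta - \mu)^T \Lambda (\vartheta - \mu)$, a textbook example chosen for illustration rather than the HDM itself. Averaging $U$ over the other factor leaves a Gaussian in $\vartheta^i$ with precision $\Lambda_{ii}$ and mean $\mu_i - \Lambda_{ij}(m_j - \mu_j)/\Lambda_{ii}$, so each update is closed form.

```python
def mean_field_gaussian(mu, lam, n_iter=50):
    """Coordinate ascent for q(t1) q(t2) approximating a 2-D Gaussian with
    mean mu and precision matrix lam; returns factor means and variances."""
    m = [0.0, 0.0]  # initial guesses for the factor means
    for _ in range(n_iter):
        # q(t1) ∝ exp(<U>_{q(t2)}): depends on q(t2) only through its mean
        m[0] = mu[0] - lam[0][1] / lam[0][0] * (m[1] - mu[1])
        # q(t2) ∝ exp(<U>_{q(t1)}), using the freshly updated mean of q(t1)
        m[1] = mu[1] - lam[1][0] / lam[1][1] * (m[0] - mu[0])
    variances = [1.0 / lam[0][0], 1.0 / lam[1][1]]
    return m, variances

mu = [1.0, -2.0]
lam = [[2.0, 0.9], [0.9, 1.5]]
means, variances = mean_field_gaussian(mu, lam)
print(means, variances)
```

The factor means converge to the true mean, but the factor variances $1/\Lambda_{ii}$ are smaller than the true marginal variances whenever the target is correlated, a well-known property of the mean-field approximation.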
498 | ] 499 | }, 500 | { 501 | "cell_type": "markdown", 502 | "metadata": { 503 | "slideshow": { 504 | "slide_type": "skip" 505 | } 506 | }, 507 | "source": [ 508 | "Given our partitions above, we can then write:\n", 509 | "\n", 510 | "\\begin{align*}\n", 511 | "q(u(t)) &\\propto \\exp \\left(V(t)\\right) \\\\\n", 512 | "V(t) &= \\left< U(t) \\right>_{q(\\theta)q(\\lambda)} \\\\\n", 513 | "q(\\theta) &\\propto \\exp \\left(\\bar{V}^\\theta \\right) \\\\\n", 514 | "\\bar{V}^\\theta &= \\int \\left< U(t) \\right>_{q(u)q(\\lambda)} dt + U^\\theta \\\\\n", 515 | "q(\\lambda) &\\propto \\exp \\left(\\bar{V}^\\lambda \\right) \\\\\n", 516 | "\\bar{V}^\\lambda &= \\int \\left< U(t) \\right>_{q(u)q(\\theta)} dt + U^\\lambda\n", 517 | "\\end{align*}" 518 | ] 519 | }, 520 | { 521 | "cell_type": "markdown", 522 | "metadata": { 523 | "slideshow": { 524 | "slide_type": "skip" 525 | } 526 | }, 527 | "source": [ 528 | "In a dynamical system, the instantaneous internal energy $U(t)$ is a function of time. Because the parameters and hyperparameters are considered constant over a period of observation, their variational densities are functions of the path integral of this internal energy. $U^\\theta = \\ln p(\\theta)$ and $U^\\lambda = \\ln p(\\lambda)$ are the prior energies of the parameters and hyperparameters, respectively. \n", 529 | "\n", 530 | "From these equations we see that the variational density over states can be determined from the instantaneous internal energy averaged over parameters and hyperparameters, whereas the density over parameters and hyperparameters can only be determined when data have been observed over a certain amount of time. In the absence of data, the integrals will be zero, and the conditional density simply reduces to the prior density."
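A minimal scalar example shows both points. It assumes a hypothetical model, not taken from the text, in which a constant parameter $\theta$ is observed at discrete times through $y(t) = \theta + z(t)$ with known noise variance and a Gaussian prior. The conditional precision of $\theta$ grows with each observation (a discrete stand-in for the time integral of the internal energy), and with no data the conditional density is just the prior.

```python
def update_theta(prior_mean, prior_var, observations, noise_var):
    """Gaussian conditional density over a constant theta after observing
    y(t) = theta + noise at discrete times; returns (mean, variance)."""
    precision = 1.0 / prior_var              # contribution of the prior energy U^theta
    weighted_sum = prior_mean / prior_var
    for y in observations:                   # accumulate evidence over time
        precision += 1.0 / noise_var
        weighted_sum += y / noise_var
    return weighted_sum / precision, 1.0 / precision

print(update_theta(0.0, 4.0, [], 1.0))       # -> (0.0, 4.0), i.e. the prior
print(update_theta(0.0, 4.0, [1.2, 0.8, 1.1], 1.0))
```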
531 | ] 532 | }, 533 | { 534 | "cell_type": "markdown", 535 | "metadata": { 536 | "slideshow": { 537 | "slide_type": "skip" 538 | } 539 | }, 540 | "source": [ 541 | "*Variational Bayes* assumes the above equations are analytically tractable, which requires an appropriate (conjugate) choice of priors. The conditional distributions $q(\\vartheta^i)$ above can then be updated through iteration as new data becomes available:\n", 542 | "\n", 543 | "\\begin{align*}\n", 544 | "\\ln q(u(t)) &= \\left< U(t) \\right>_{q(\\theta)q(\\lambda)} + \\text{const.} \\\\\n", 545 | "\\ln q(\\theta) &= \\int \\left< U(t) \\right>_{q(u)q(\\lambda)} dt + \\ln p(\\theta) + \\text{const.} \\\\\n", 546 | "\\ln q(\\lambda) &= \\int \\left< U(t) \\right>_{q(u)q(\\theta)} dt + \\ln p(\\lambda) + \\text{const.}\n", 547 | "\\end{align*}" 548 | ] 549 | }, 550 | { 551 | "cell_type": "markdown", 552 | "metadata": { 553 | "slideshow": { 554 | "slide_type": "slide" 555 | } 556 | }, 557 | "source": [ 558 | "### The Laplace approximation ###\n", 559 | "\n", 560 | "The Laplace approximation assumes that the marginals of the conditional density take a Gaussian form, i.e., $q(\\vartheta^i) = \\mathcal{N}(\\vartheta^i : \\mu^i, C^i)$, where $\\mu^i$ and $C^i$ are the sufficient statistics. For notational clarity, we will use $\\mu^i$, $C^i$, and $P^i$ for the conditional mean, covariance, and precision of the $i^\\text{th}$ marginal, respectively, and $\\eta^i$, $\\Sigma^i$, and $\\Pi^i$ for their priors. This approximation simplifies the updates to the marginals of the conditional densities."
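In practice, a Laplace approximation is usually obtained by locating the mode of $U$ and reading the variance off the curvature there. The sketch below does this with Newton's method for a hypothetical one-dimensional energy $U(\vartheta) = A\vartheta - B e^{\vartheta}$, chosen because its mode, $\ln(A/B)$, and its curvature at the mode, $-A$, are known exactly. The variance is taken as $C = -1/U''(\mu)$, which is positive because $U$ curves downwards at its maximum.

```python
import math

def laplace_fit(dU, d2U, mu0, n_iter=50, tol=1e-12):
    """Gaussian (Laplace) approximation N(mu, C) to a density proportional to exp(U):
    Newton's method finds the mode of U, then C = -1 / U''(mu)."""
    mu = mu0
    for _ in range(n_iter):
        step = dU(mu) / d2U(mu)
        mu -= step                      # Newton update towards U'(mu) = 0
        if abs(step) < tol:
            break
    return mu, -1.0 / d2U(mu)

A, B = 5.0, 2.0                         # hypothetical constants of the toy energy

def dU(t):
    """U'(t) for U(t) = A t - B exp(t)."""
    return A - B * math.exp(t)

def d2U(t):
    """U''(t), negative everywhere, so U has a single maximum."""
    return -B * math.exp(t)

mu, C = laplace_fit(dU, d2U, mu0=0.0)
print(mu, math.log(A / B))   # mode versus the exact value ln(A/B)
print(C, 1.0 / A)            # variance versus -1/U''(mode) = 1/A
```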
561 | ] 562 | }, 563 | { 564 | "cell_type": "markdown", 565 | "metadata": { 566 | "slideshow": { 567 | "slide_type": "subslide" 568 | } 569 | }, 570 | "source": [ 571 | "For each partition $\\vartheta^i$, we can then write:\n", 572 | "\n", 573 | "\\begin{align*}\n", 574 | "q(\\vartheta^i) &= \\frac{1}{\\sqrt{2\\pi C^i}} \\exp \\left( \\frac{-(\\vartheta^i - \\mu^i)^2}{2C^i} \\right) \\\\\n", 575 | "&= \\frac{1}{Z^i} \\exp \\left( -\\varepsilon(\\vartheta^i) \\right) \\\\ \n", 576 | "Z^i &= \\sqrt{2\\pi C^i} \\\\\n", 577 | "\\varepsilon(\\vartheta^i) &= \\frac{(\\vartheta^i - \\mu^i)^2}{2C^i}\n", 578 | "\\end{align*}" 579 | ] 580 | }, 581 | { 582 | "cell_type": "markdown", 583 | "metadata": { 584 | "slideshow": { 585 | "slide_type": "subslide" 586 | } 587 | }, 588 | "source": [ 589 | "Recall that the Free Energy was defined as:\n", 590 | "\n", 591 | "\\begin{align*}\n", 592 | "F &= -\\int{q(\\vartheta) \\: \\ln \\frac{q(\\vartheta)}{p(\\vartheta, \\tilde y, m)} \\: d\\vartheta} \\\\\n", 593 | "&= - \\int{q(\\vartheta) \\: \\ln q(\\vartheta) \\: d\\vartheta} + \\int{q(\\vartheta) \\: \\ln p(\\vartheta, \\tilde y, m) \\: d\\vartheta} \\\\\n", 594 | "&= - \\int{q(\\vartheta) \\: \\ln \\prod_i q(\\vartheta^i) \\: d\\vartheta} + \\int{q(\\vartheta) \\: \\ln p(\\vartheta, \\tilde y, m) \\: d\\vartheta} \\\\\n", 595 | "&= \\int{q(\\vartheta) \\: \\sum_i(\\ln Z^i + \\varepsilon(\\vartheta^i)) \\: d\\vartheta} + \\int{q(\\vartheta) \\: \\ln p(\\vartheta, \\tilde y, m) \\: d\\vartheta} \\\\\n", 596 | "&= \\sum_i(\\ln Z^i)\\int{q(\\vartheta) \\: d\\vartheta} + \\int{q(\\vartheta) \\: \\sum_i(\\varepsilon(\\vartheta^i)) \\: d\\vartheta} + \\int{q(\\vartheta) \\: \\ln p(\\vartheta, \\tilde y, m) \\: d\\vartheta} \\\\\n", 597 | "&= \\sum_i(\\ln Z^i) + \\sum_i \\frac{1}{2C^i}\\int{q(\\vartheta) \\: (\\vartheta^i - \\mu^i)^2 \\: d\\vartheta} + \\int{q(\\vartheta) \\: \\ln p(\\vartheta, \\tilde y, m) \\: d\\vartheta} \\\\\n", 598 | "&= \\sum_i(\\ln Z^i) + \\frac{n}{2} + 
\\left< U \\right>_q \\\\\n", 599 | "\\end{align*}\n", 600 | "\n", 601 | "where we used $-\\ln q(\\vartheta^i) = \\ln Z^i + \\varepsilon(\\vartheta^i)$ and $\\int{q(\\vartheta) \\: (\\vartheta^i - \\mu^i)^2 \\: d\\vartheta} = C^i$, and where $n$ is the number of partitions. Now we still need to find an expression we can calculate for $\\left< U \\right>_q$. " 602 | ] 603 | }, 604 | { 605 | "cell_type": "markdown", 606 | "metadata": { 607 | "slideshow": { 608 | "slide_type": "subslide" 609 | } 610 | }, 611 | "source": [ 612 | "To do this, a further approximation assumes that $q$ is sharply peaked at its mean value $\\mu$, so that only values of $\\vartheta$ close to $\\mu$ contribute appreciably to the integral. We can then use a Taylor expansion around the mean, with $\\delta\\vartheta = \\vartheta - \\mu$, to obtain: \n", 613 | "\n", 614 | "\\begin{align*}\n", 615 | "\\left< U \\right>_q &= \\int{q(\\vartheta) \\: U(\\vartheta, \\tilde y) \\: d\\vartheta} \\\\\n", 616 | "&= \\int{q(\\vartheta) \\: \\left\\{ U(\\mu, \\tilde y) + \\left[ \\frac{dU}{d\\vartheta} \\right]_\\mu \\delta\\vartheta + \\frac{1}{2} \\left[ \\frac{d^2U}{d\\vartheta^2} \\right]_\\mu \\delta\\vartheta^2 \\right\\} \\: d\\vartheta} \\\\\n", 617 | "&= U(\\mu, \\tilde y) + \\left[ \\frac{dU}{d\\vartheta} \\right]_\\mu \\int{q(\\vartheta) \\: (\\vartheta - \\mu) \\: d\\vartheta} + \\frac{1}{2} \\left[ \\frac{d^2U}{d\\vartheta^2} \\right]_\\mu \\int{ q(\\vartheta) \\: (\\vartheta - \\mu)^2 \\: d\\vartheta} \\\\\n", 618 | "&= U(\\mu, \\tilde y) + \\left[ \\frac{dU}{d\\vartheta} \\right]_\\mu \\left\\{ \\int{\\vartheta q(\\vartheta) \\: d\\vartheta} - \\mu \\right\\} + \\frac{1}{2} \\left[ \\frac{d^2U}{d\\vartheta^2} \\right]_\\mu \\int{ q(\\vartheta) \\: (\\vartheta - \\mu)^2 \\: d\\vartheta} \\\\\n", 619 | "&= U(\\mu, \\tilde y) + \\left[ \\frac{dU}{d\\vartheta} \\right]_\\mu \\left\\{ \\mu - \\mu \\right\\} + \\frac{1}{2} \\left[ \\frac{d^2U}{d\\vartheta^2} \\right]_\\mu \\int{ q(\\vartheta) \\: (\\vartheta - 
\\mu)^2 \\: d\\vartheta} \\\\\n", 620 | "&= U(\\mu, \\tilde y) + \\frac{1}{2} \\left[ \\frac{d^2U}{d\\vartheta^2} \\right]_\\mu C \\\\\n", 621 | "\\end{align*}" 622 | ] 623 | }, 624 | { 625 | "cell_type": "markdown", 626 | "metadata": { 627 | "slideshow": 
{ 628 | "slide_type": "subslide" 629 | } 630 | }, 631 | "source": [ 632 | "This now allows us to write for the free energy:\n", 633 | "\n", 634 | "\\begin{align*}\n", 635 | "F &= U(\\mu, \\tilde y) + \\frac{n}{2} + \\frac{1}{2} \\sum_i( \\left[ \\frac{d^2U}{d{\\vartheta^i}^2} \\right]_{\\mu^i} C^i + \\ln 2\\pi C^i) \\\\\n", 636 | "\\end{align*}" 637 | ] 638 | }, 639 | { 640 | "cell_type": "markdown", 641 | "metadata": { 642 | "slideshow": { 643 | "slide_type": "subslide" 644 | } 645 | }, 646 | "source": [ 647 | "To find the optimal variances, we set the partial derivatives of the free energy with respect to each variance to zero:\n", 648 | "\n", 649 | "\\begin{align*}\n", 650 | "\\frac{\\partial F}{\\partial C^i} &= \\frac{1}{2} \\left\\{ \\left[ \\frac{d^2U}{d{\\vartheta^i}^2} \\right]_{\\mu^i} + \\frac{1}{C^i} \\right\\} = 0 \\\\\n", 651 | "C^{i*} &= -\\left[ \\frac{d^2U}{d{\\vartheta^i}^2} \\right]_{\\mu^i}^{-1} \\\\\n", 652 | "F &= U(\\mu, \\tilde y) + \\frac{1}{2} \\sum_i \\ln 2\\pi C^{i*}\n", 653 | "\\end{align*}\n", 654 | "\n", 655 | "where we use the notation $C^{i*}$ to indicate the optimal variance, which maximises the free energy. Because $U$ has negative curvature at its maximum, $C^{i*}$ is positive."
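As a numerical check on how well the Laplace free energy approximates the log-evidence, the sketch below uses a hypothetical one-dimensional energy $U(\vartheta) = a\vartheta - b e^{\vartheta}$, whose exact log-evidence $\ln \int e^{U} d\vartheta = \ln\Gamma(a) - a \ln b$ follows from the substitution $x = e^{\vartheta}$. With the mode $\mu = \ln(a/b)$ and $C^* = -1/U''(\mu) = 1/a$, the Laplace estimate $U(\mu) + \frac{1}{2}\ln 2\pi C^*$ turns out to be Stirling's approximation to $\ln\Gamma(a)$ (up to the $-a \ln b$ term), so the gap should shrink roughly like $1/(12a)$ as the density becomes more sharply peaked.

```python
import math

def laplace_log_evidence(a, b):
    """Laplace estimate U(mu) + 0.5 ln(2 pi C*) for U(t) = a t - b exp(t)."""
    mu = math.log(a / b)                    # mode: U'(mu) = a - b exp(mu) = 0
    c_star = 1.0 / a                        # C* = -1 / U''(mu) = 1 / (b exp(mu))
    u_mu = a * mu - b * math.exp(mu)
    return u_mu + 0.5 * math.log(2 * math.pi * c_star)

def exact_log_evidence(a, b):
    """ln of the integral of exp(a t - b exp(t)) dt, i.e. ln Gamma(a) - a ln b."""
    return math.lgamma(a) - a * math.log(b)

for a in (2.0, 5.0, 50.0):
    gap = exact_log_evidence(a, 2.0) - laplace_log_evidence(a, 2.0)
    print(a, gap, 1.0 / (12.0 * a))
```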
656 | ] 657 | }, 658 | { 659 | "cell_type": "markdown", 660 | "metadata": { 661 | "slideshow": { 662 | "slide_type": "subslide" 663 | } 664 | }, 665 | "source": [ 666 | "Now, I don't quite know how to get here, but Friston just states that the updates under the Laplace approximation become:" 667 | ] 668 | }, 669 | { 670 | "cell_type": "markdown", 671 | "metadata": { 672 | "slideshow": { 673 | "slide_type": "-" 674 | } 675 | }, 676 | "source": [ 677 | "\\begin{align*}\n", 678 | "\\bar{U} &= \\int U(t)dt + U^\\theta + U^\\lambda \\\\\n", 679 | "\\bar{V}^u &= \\int U(u, t|\\mu^\\theta, \\mu^\\lambda) + W(t)^\\theta + W(t)^\\lambda dt \\\\\n", 680 | "\\bar{V}^\\theta &= \\int U(\\mu^u, t|\\theta, \\mu^\\lambda) + W(t)^u + W(t)^\\lambda dt + U^\\theta \\\\\n", 681 | "\\bar{V}^\\lambda &= \\int U(\\mu^u, t|\\mu^\\theta, \\lambda) + W(t)^u + W(t)^\\theta dt + U^\\lambda \\\\\n", 682 | "W(t)^u &= \\frac{1}{2} \\text{tr}(C^u U(t)_{uu}) \\\\\n", 683 | "W(t)^\\theta &= \\frac{1}{2} \\text{tr}(C^\\theta U(t)_{\\theta\\theta}) \\\\\n", 684 | "W(t)^\\lambda &= \\frac{1}{2} \\text{tr}(C^\\lambda U(t)_{\\lambda\\lambda}) \\\\\n", 685 | "\\end{align*}\n", 686 | "\n", 687 | "where $U_{xx} = d^2U/dx^2$." 
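The $W$ terms have the same form as the second-order correction that appears when an internal energy is averaged over a Gaussian: for a quadratic energy, $\langle U \rangle_q = U(\mu) + \frac{1}{2}\mathrm{tr}(C\, U_{xx})$ holds exactly. The sketch below checks this identity by sampling. It assumes NumPy and a hypothetical quadratic energy $U(x) = -\frac{1}{2} x^T A x + b^T x$ with arbitrary positive-definite $A$ and $C$. It only illustrates where a trace term of this kind comes from and is not a derivation of Friston's update equations.

```python
import numpy as np

def energy(x, A, b):
    """Quadratic energy U(x) = -0.5 x^T A x + b^T x, evaluated row-wise."""
    return -0.5 * np.einsum("ni,ij,nj->n", x, A, x) + x.dot(b)

A = np.array([[2.0, 0.5], [0.5, 1.0]])   # so U_xx = -A everywhere
b = np.array([1.0, -1.0])
mu = np.array([0.3, 0.7])                # mean of q
C = np.array([[0.4, 0.1], [0.1, 0.2]])   # covariance of q

# Sampling estimate of <U>_q
rng = np.random.RandomState(0)
samples = rng.multivariate_normal(mu, C, size=200000)
mc_average = energy(samples, A, b).mean()

# U at the mean plus the trace correction W = 0.5 tr(C U_xx)
W = 0.5 * np.trace(C.dot(-A))
expansion = energy(mu[np.newaxis, :], A, b)[0] + W
print(mc_average, expansion)
```

For a non-quadratic $U$, the same expression is only the second-order (Laplace) approximation of the average.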
688 | ] 689 | }, 690 | { 691 | "cell_type": "markdown", 692 | "metadata": { 693 | "slideshow": { 694 | "slide_type": "subslide" 695 | } 696 | }, 697 | "source": [ 698 | "Also, the conditional precisions are equal to the negative curvatures of the internal action:\n", 699 | "\n", 700 | "\\begin{align*}\n", 701 | "P^u &= -\\bar{U}_{uu} = -U(t)_{uu} \\\\\n", 702 | "P^\\theta &= -\\bar{U}_{\\theta\\theta} = - \\int U(t)_{\\theta\\theta} \\: dt - U^\\theta_{\\theta\\theta} \\\\\n", 703 | "P^\\lambda &= -\\bar{U}_{\\lambda\\lambda} = - \\int U(t)_{\\lambda\\lambda} \\: dt- U^\\lambda_{\\lambda\\lambda} \\\\\n", 704 | "\\end{align*}" 705 | ] 706 | }, 707 | { 708 | "cell_type": "markdown", 709 | "metadata": { 710 | "slideshow": { 711 | "slide_type": "subslide" 712 | } 713 | }, 714 | "source": [ 715 | "For our HDM the gradients and curvature of the internal energy are:\n", 716 | "\n", 717 | "\\begin{align*}\n", 718 | "U(t)_u &= -\\tilde \\varepsilon_u^T \\tilde \\Pi \\tilde \\varepsilon \\\\\n", 719 | "U(t)_\\theta &= -\\tilde \\varepsilon_\\theta^T \\tilde \\Pi \\tilde \\varepsilon \\\\\n", 720 | "U(t)_{\\lambda i} &= - \\frac{1}{2} \\text{tr} (Q_i(\\tilde \\varepsilon \\tilde \\varepsilon^T - \\tilde \\Sigma)) \\\\\n", 721 | "U(t)_{uu} &= -\\tilde \\varepsilon_u^T \\tilde \\Pi \\tilde \\varepsilon_u \\\\\n", 722 | "U(t)_{\\theta\\theta} &= -\\tilde \\varepsilon_\\theta^T \\tilde \\Pi \\tilde \\varepsilon_\\theta \\\\\n", 723 | "U(t)_{\\lambda\\lambda ij} &= - \\frac{1}{2} \\text{tr} (Q_i \\tilde \\Sigma Q_j \\tilde \\Sigma)) \\\\\n", 724 | "\\end{align*}" 725 | ] 726 | }, 727 | { 728 | "cell_type": "markdown", 729 | "metadata": { 730 | "slideshow": { 731 | "slide_type": "subslide" 732 | } 733 | }, 734 | "source": [ 735 | "with:\n", 736 | "\n", 737 | "\\begin{align*}\n", 738 | "U(t)_{\\lambda i} &= \\frac{dU(t)}{d\\lambda_i} \\\\\n", 739 | "Q_{i} &= \\frac{d\\tilde\\Pi}{d\\lambda_i} \\\\\n", 740 | "\\end{align*}\n", 741 | "\n", 742 | "and where we assume that\n", 743 | 
"\n", 744 | "\\begin{align*}\n", 745 | "\\frac{d^2\\tilde\\Pi}{d\\lambda_i^2} &= 0\\\\\n", 746 | "\\end{align*}" 747 | ] 748 | }, 749 | { 750 | "cell_type": "markdown", 751 | "metadata": { 752 | "slideshow": { 753 | "slide_type": "subslide" 754 | } 755 | }, 756 | "source": [ 757 | "The derivatives with respect to each parameter $\\tilde \\varepsilon_{\\theta}(t) = \\tilde \\varepsilon_{u\\theta} \\mu^u(t)$ rest on the second derivatives of the model's functions that mediate interactions between each parameter and the states:\n", 758 | "\n", 759 | "\\begin{align*}\n", 760 | "\\tilde \\varepsilon_{\\theta u}^T &= \\tilde \\varepsilon_{u\\theta} = -\n", 761 | "\\begin{bmatrix}\n", 762 | "I \\otimes g_{v\\theta} & I \\otimes g_{x\\theta} \\\\\n", 763 | "I \\otimes f_{v\\theta} & I \\otimes f_{x\\theta}\n", 764 | "\\end{bmatrix}\n", 765 | "\\end{align*}" 766 | ] 767 | } 768 | ], 769 | "metadata": { 770 | "celltoolbar": "Slideshow", 771 | "hide_input": false, 772 | "kernelspec": { 773 | "display_name": "Python 2", 774 | "language": "python", 775 | "name": "python2" 776 | }, 777 | "language_info": { 778 | "codemirror_mode": { 779 | "name": "ipython", 780 | "version": 2 781 | }, 782 | "file_extension": ".py", 783 | "mimetype": "text/x-python", 784 | "name": "python", 785 | "nbconvert_exporter": "python", 786 | "pygments_lexer": "ipython2", 787 | "version": "2.7.15" 788 | } 789 | }, 790 | "nbformat": 4, 791 | "nbformat_minor": 2 792 | } 793 | -------------------------------------------------------------------------------- /DEM.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Dynamic Expectation Maximisation\n", 8 | "\n", 9 | "## By André van Schaik\n", 10 | "\n", 11 | "### __[International Centre for Neuromorphic Systems](https://westernsydney.edu.au/icns)__ \n", 12 | "10/01/2019 - 29/01/2023 (yes, it has taken me 4 years with lots of 
interruptions!)" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "metadata": {}, 18 | "source": [ 19 | "I made this notebook on Dynamic Expectation Maximisation (DEM) following \"Hierarchical Models in the Brain\", by Karl Friston, PLoS Computational Biology 4(11): e1000211. [doi:10.1371/journal.pcbi.1000211](https://doi.org/10.1371/journal.pcbi.1000211) also using \"DEM: A variational treatment of dynamic systems\" by Friston, Trujillo-Barreto, and Daunizeau, Neuroimage Volume 41, Issue 3, 1 July 2008, Pages 849-885 [doi:10.1016/j.neuroimage.2008.02.054](https://doi.org/10.1016/j.neuroimage.2008.02.054). I found some of the steps in those papers not obvious and couldn't find an obvious source that derived them in sufficient detail. I am hoping that I have now captured all these steps in this notebook and the ones it links to." 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "### Dynamic model in generalised coordinates ###\n", 27 | "\n", 28 | "We can write a dynamic input-state-output model as:\n", 29 | "\n", 30 | "\\begin{align*}\n", 31 | "\\tilde y &= \\tilde g + \\tilde z \\\\\n", 32 | "D \\tilde x &= \\tilde f + \\tilde w \\tag{1}\n", 33 | "\\end{align*}\n", 34 | "\n", 35 | "where $\\tilde y$ are the observations, and $\\tilde x$ are the hidden states of the system. The $\\tilde a$ notation indicates variables and functions in generalised coordinates of motion $\\tilde a = [a, \\dot{a}, \\ddot{a}, \\dddot{a}, ...]^T$. Here, $\\dot{a}$ is the *value* of the derivative of $a$ with respect to time; in other words, it is a dimensionless number, even though the time derivative obviously has a unit of $[s^{-1}]$. One way of interpreting this is that $\\dot{a} = \\tau\\,da/dt$ where $\\tau = dt$ and all time is measured in units of $dt$, and similarly for the higher orders of motion.\n", 36 | "\n", 37 | "$D$ is a block-matrix derivative operator, whose first leading-diagonal contains identity matrices. 
This operator simply shifts the vectors of generalised motion so that $a[i]$ is replaced by $a[i+1]$.\n", 38 | "\n", 39 | "The predicted sensor response $\\tilde g$ and motion $\\tilde f$ of the hidden states $\\tilde x$ in the absence of random fluctuations are:\n", 40 | "\n", 41 | "\\begin{align*}\n", 42 | "\\begin{split}\n", 43 | "g &= g(x, \\nu) \\\\\n", 44 | "\\dot{g} &= g_x \\dot{x} + g_\\nu \\dot{\\nu} \\\\\n", 45 | "\\ddot{g} &= g_x \\ddot{x} + g_\\nu \\ddot{\\nu} \\\\\n", 46 | "&\\phantom{g=\\,} \\vdots \\\\\n", 47 | "\\end{split}\n", 48 | "\\:\\:\\:\n", 49 | "\\begin{split}\n", 50 | "f &= f(x, \\nu) \\\\\n", 51 | "\\dot{f} &= f_x \\dot{x} + f_\\nu \\dot{\\nu} \\\\\n", 52 | "\\ddot{f} &= f_x \\ddot{x} + f_\\nu \\ddot{\\nu} \\\\\n", 53 | "&\\phantom{f=\\,} \\vdots \\\\\n", 54 | "\\end{split}\n", 55 | "\\end{align*}\n", 56 | "\n", 57 | "Here, $f$ and $g$ are continuous nonlinear functions and $\\tilde \\nu$ are causes or inputs, which can also result from actions by the agent. The notation $a_b$ is shorthand for $\\partial{a}/\\partial{b}$. We assume that the observation noise $\\tilde z$ follows a zero-mean Gaussian distribution $p(\\tilde z) = \\mathcal{N}(0, \\tilde \\Sigma^z)$ and similarly for the state noise $p(\\tilde w) = \\mathcal{N}(0, \\tilde \\Sigma^w)$. The input drive is also Gaussian but with a mean that can be different from zero: $p(\\tilde \\nu) = \\mathcal{N}(\\tilde \\eta^\\nu, \\tilde C^\\nu)$, where we use $\\tilde C^\\nu$ instead of $\\tilde \\Sigma^\\nu$ to indicate this is a *prior* covariance, rather than a *conditional* covariance. 
We can then evaluate the joint density over observations $\\tilde y$, hidden states $\\tilde x$, and inputs $\\tilde \\nu$:\n", 58 | "\n", 59 | "\\begin{align*}\n", 60 | "p(\\tilde y, \\tilde x, \\tilde \\nu \\,|\\, \\theta, \\lambda) &= p(\\tilde y \\,|\\, \\tilde x, \\tilde \\nu, \\theta, \\lambda) \\; p(\\tilde x \\,|\\, \\tilde \\nu, \\theta, \\lambda) \\; p(\\tilde \\nu) \\\\\n", 61 | "p(\\tilde y \\,|\\, \\tilde x, \\tilde \\nu, \\theta, \\lambda) &= \\mathcal{N}(\\tilde y : \\tilde g, \\tilde \\Sigma(\\lambda)^z) \\\\\n", 62 | "p(\\tilde x \\,|\\, \\tilde \\nu, \\theta, \\lambda) &= \\mathcal{N}(D\\tilde x : \\tilde f, \\tilde \\Sigma(\\lambda)^w) \\\\\n", 63 | "p(\\tilde \\nu) &= \\mathcal{N}(\\tilde \\nu : \\eta^\\nu, C^\\nu)\n", 64 | "\\end{align*}\n", 65 | "\n", 66 | "where $\\theta$ contains the parameters describing $f$ and $g$, and $\\lambda$ are hyperparameters which control the amplitude and smoothness of the random fluctuations. Here we have indicated explicitly which random variable is generated by each normal distribution. 
According to $(1)$, the random variable for state transition is $D\\tilde x$, which therefore links different levels of motion.\n", 67 | "\n", 68 | "Finally, we also assume Gaussian priors for the parameters $\\theta$ and hyperparameters $\\lambda$:\n", 69 | "\n", 70 | "\\begin{align*}\n", 71 | "p(\\lambda) &= \\mathcal{N}(\\lambda : \\eta^\\lambda, C^\\lambda)\\\\\n", 72 | "p(\\theta) &= \\mathcal{N}(\\theta : \\eta^\\theta, C^\\theta)\n", 73 | "\\end{align*}" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "This allows us to write the directed Bayesian graph for the model:\n", 81 | "\n", 82 | "" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "### Model inversion ###\n", 90 | "\n", 91 | "For model inversion, we are trying to estimate the parameters $\\vartheta$ of a model $m$ given some observations $\\tilde y$ by maximising the conditional density $p(\\vartheta \\,|\\, \\tilde y, m)$. However, this density is in general not directly calculable, because its normalisation constant, the evidence $p(\\tilde y \\,|\\, m)$, requires integrating over all possible values of $\\vartheta$. *Variational Bayes* suggests a workaround: minimise the Kullback-Leibler divergence between a Recognition density $q(\\vartheta)$, which encodes what the agent believes the state of its environment to be, and the true Bayesian posterior.\n", 92 | "\n", 93 | "\\begin{align*}\n", 94 | "D_{KL}(\\: q(\\vartheta) \\; || \\; p(\\vartheta \\,|\\, \\tilde y, m) \\: ) = \\int{q(\\vartheta) \\: \\ln \\frac{q(\\vartheta)}{p(\\vartheta\\,|\\, \\tilde y, m)} \\: d\\vartheta}\n", 95 | "\\end{align*}\n", 96 | "\n", 97 | "The KL divergence is a measure of the difference between two probability distributions, is always non-negative, and is 0 if and only if the two distributions are the same. Thus adapting $q(\\vartheta)$ to minimise this KL divergence will result in $q(\\vartheta)$ being a close approximation of $p(\\vartheta\\,|\\, \\tilde y, m)$. 
\n", 98 | "\n", 99 | "Obviously, to evaluate this KL divergence directly, we would still need to be able to calculate $p(\\vartheta\\,|\\, \\tilde y, m)$ and we seem to have made no progress. However, the FEP uses the fact that $p(\\vartheta, \\tilde y, m) = p(\\vartheta\\,|\\, \\tilde y, m) p(\\tilde y, m)$, to write this as:\n", 100 | "\n", 101 | "\\begin{align*}\n", 102 | "D_{KL}(\\: q(\\vartheta) \\; || \\; p(\\vartheta\\,|\\, \\tilde y, m) \\: ) &= \\int{q(\\vartheta) \\: \\ln \\frac{q(\\vartheta)}{p(\\vartheta, \\tilde y, m)/p(\\tilde y, m)} \\: d\\vartheta} \\\\\n", 103 | "&= \\int{q(\\vartheta) \\: \\{ \\ln\\:q(\\vartheta) - \\ln\\:p(\\vartheta, \\tilde y, m) + \\ln\\:p(\\tilde y, m) \\} \\: d\\vartheta} \\\\\n", 104 | "&= \\int{q(\\vartheta) \\: \\ln \\frac{q(\\vartheta)}{p(\\vartheta, \\tilde y, m)} \\: d\\vartheta} + \\int{q(\\vartheta) \\: \\ln p(\\tilde y, m) \\: d\\vartheta} \\\\\n", 105 | "&= \\int{q(\\vartheta) \\: \\ln \\frac{q(\\vartheta)}{p(\\vartheta, \\tilde y, m)} \\: d\\vartheta} + \\ln p(\\tilde y, m) \\int{q(\\vartheta) \\: d\\vartheta} \\\\\n", 106 | "&= \\int{q(\\vartheta) \\: \\ln \\frac{q(\\vartheta)}{p(\\vartheta, \\tilde y, m)} \\: d\\vartheta} + \\ln p(\\tilde y, m) \\\\\n", 107 | "\\end{align*}\n", 108 | "\n", 109 | "since $\\int{q(\\vartheta) \\: d\\vartheta} = 1$ by definition of a probability density. We continue by writing:\n", 110 | "\n", 111 | "\\begin{align*}\n", 112 | "D_{KL}(\\: q(\\vartheta) \\; || \\; p(\\vartheta\\,|\\, \\tilde y, m) \\: ) &= \\ln p(\\tilde y, m) - F\\\\\n", 113 | "F &= -\\int{q(\\vartheta) \\: \\ln \\frac{q(\\vartheta)}{p(\\vartheta, \\tilde y, m)} \\: d\\vartheta} \\\\\n", 114 | "&= -D_{KL}(\\: q(\\vartheta) \\; || \\; p(\\vartheta, \\tilde y, m) \\: )\\\\\n", 115 | "\\end{align*}\n", 116 | "\n", 117 | "The joint density $p(\\vartheta,\\tilde y, m)$ is called the generative density, and represents the agent's belief in how the world works. 
It can be factorised into $p(\\vartheta,\\tilde y, m) = p(\\tilde y, \\vartheta, m) = p(\\tilde y \\,|\\, \\vartheta, m)\\:p(\\vartheta, m)$, where the prior $p(\\vartheta, m)$ encodes the agent's beliefs about the world states before new sensory input, and the likelihood $p(\\tilde y \\,|\\, \\vartheta, m)$ encodes how the agent's sensory signals relate to the world states. Thus, if we have a model for how the world states generate sensory input (or if we can learn one), we can calculate $F$, which is called the *Variational Free Energy*, and is the negative of the KL divergence between the Recognition density, $q(\\vartheta)$, and the Generative density, $p(\\vartheta, \\tilde y, m)$. We probably don't know $\\ln p(\\tilde y, m)$, but since it doesn't depend on $\\vartheta$, it plays no role in optimising $q(\\vartheta)$: because $F + D_{KL}(\\: q(\\vartheta) \\; || \\; p(\\vartheta\\,|\\, \\tilde y, m) \\: ) = \\ln p(\\tilde y, m)$ is fixed, maximising $F$ minimises the KL divergence to the posterior and so makes $q(\\vartheta)$ the best possible approximation of $p(\\vartheta\\,|\\, \\tilde y, m)$."
118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "Given a model, we can determine the probability of the observations under the model by:\n", 125 | "\n", 126 | "\\begin{align*}\n", 127 | "\\ln p(\\tilde y, m) &= \\ln p(\\tilde y \\,|\\, m) + \\ln p(m) \\\\\n", 128 | "\\ln p(\\tilde y \\,|\\, m) &= \\ln p(\\tilde y, m) - \\ln p(m) \\\\\n", 129 | "&= F + D_{KL}(\\: q(\\vartheta) \\; || \\; p(\\vartheta\\,|\\, \\tilde y, m) \\: ) - \\ln p(m)\n", 130 | "\\end{align*}\n", 131 | "\n", 132 | "Thus for a given model ($p(m)=1$), we can write its log evidence (the log marginal likelihood) as:\n", 133 | "\n", 134 | "\\begin{align*}\n", 135 | "\\ln p(\\tilde y \\,|\\, m) &= F + D_{KL}(\\: q(\\vartheta) \\; || \\; p(\\vartheta\\,|\\, \\tilde y, m) \\: )\n", 136 | "\\end{align*}\n", 137 | "\n", 138 | "This indicates that $F$ is a lower bound on the log evidence, since the KL divergence term is always non-negative and is $0$ if and only if $q(\\vartheta) = p(\\vartheta\\,|\\, \\tilde y, m)$."
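As a numerical check of the identity above, the sketch below uses a hypothetical one-dimensional conjugate model (not part of the DEM model in this notebook), chosen because every term has a closed form: prior $\\vartheta \\sim \\mathcal{N}(\\eta, c)$, likelihood $y \\,|\\, \\vartheta \\sim \\mathcal{N}(\\vartheta, s^2)$ for a single observation, and a Gaussian recognition density $q(\\vartheta) = \\mathcal{N}(\\mu_q, v_q)$. All function and variable names are illustrative. For an arbitrary $q$, $F$ plus the KL divergence should reproduce the exact log evidence, and the bound should become tight when $q$ equals the posterior.

```python
import numpy as np

# Hypothetical 1-D conjugate model, used only to illustrate the identity:
#   prior        theta ~ N(eta, c)
#   likelihood   y | theta ~ N(theta, s2)
#   recognition  q(theta) = N(mu_q, v_q)


def log_evidence(y, eta, c, s2):
    """Exact ln p(y | m); marginally y ~ N(eta, c + s2)."""
    v = c + s2
    return -0.5 * (np.log(2 * np.pi * v) + (y - eta) ** 2 / v)


def posterior(y, eta, c, s2):
    """Exact posterior mean and variance of theta for this conjugate model."""
    var = 1.0 / (1.0 / c + 1.0 / s2)
    mean = var * (eta / c + y / s2)
    return mean, var


def free_energy(mu_q, v_q, y, eta, c, s2):
    """F = <ln p(y, theta)>_q - <ln q(theta)>_q, every term in closed form."""
    # <(y - theta)^2>_q = (y - mu_q)^2 + v_q, and similarly for the prior term
    exp_loglik = -0.5 * (np.log(2 * np.pi * s2) + ((y - mu_q) ** 2 + v_q) / s2)
    exp_logprior = -0.5 * (np.log(2 * np.pi * c) + ((mu_q - eta) ** 2 + v_q) / c)
    entropy = 0.5 * np.log(2 * np.pi * np.e * v_q)  # -<ln q>_q for a Gaussian
    return exp_loglik + exp_logprior + entropy


def kl_gauss(m1, v1, m2, v2):
    """KL( N(m1, v1) || N(m2, v2) ) for scalar Gaussians."""
    return 0.5 * (np.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)


y, eta, c, s2 = 1.3, 0.0, 2.0, 0.5
post_mean, post_var = posterior(y, eta, c, s2)

# Arbitrary (deliberately poor) recognition density
mu_q, v_q = -0.4, 0.9
F = free_energy(mu_q, v_q, y, eta, c, s2)
kl = kl_gauss(mu_q, v_q, post_mean, post_var)
print("ln p(y|m):", log_evidence(y, eta, c, s2), " F + KL:", F + kl)

# With q equal to the exact posterior the KL term vanishes and F is tight
F_opt = free_energy(post_mean, post_var, y, eta, c, s2)
print("gap at the posterior:", log_evidence(y, eta, c, s2) - F_opt)
```

The closed-form expectations rely on $\\left< (y - \\vartheta)^2 \\right>_q = (y - \\mu_q)^2 + v_q$; for non-conjugate models such expectations are generally not available in closed form, which is what motivates the Laplace approximation later in this notebook.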
139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "metadata": {}, 144 | "source": [ 145 | "$F$ can also be expressed as:\n", 146 | "\n", 147 | "\\begin{align*}\n", 148 | "F &= -\\int{q(\\vartheta) \\: \\ln \\frac{q(\\vartheta)}{p(\\vartheta, \\tilde y, m)} \\: d\\vartheta} \\\\\n", 149 | "&= \\left< \\ln p(\\vartheta, \\tilde y, m) \\right>_q - \\left< \\ln q(\\vartheta) \\right>_q\n", 150 | "\\end{align*}\n", 151 | "\n", 152 | "which comprises the internal energy $U(\\vartheta, \\tilde y) = \\ln p(\\vartheta, \\tilde y)$ of a given model $m$ expected under $q(\\vartheta)$ and the entropy of $q(\\vartheta)$, which is a measure of its uncertainty.\n", 153 | "\n" 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "metadata": {}, 159 | "source": [ 160 | "### Optimisation ###\n", 161 | "\n", 162 | "The introduction of $q(\\vartheta)$ converts the difficult integration problem inherent in Bayesian Inference into a much simpler optimisation problem of adapting $q(\\vartheta)$ to maximise $F$. To further simplify calculation, we usually assume that the model parameters can be partitioned over the states $u = [\\tilde \\nu, \\tilde x]^T$, the parameters $\\theta$, and the hyperparameters $\\lambda$, as:\n", 163 | "\n", 164 | "\\begin{align*}\n", 165 | "q(\\vartheta) &= q(u(t)) \\, q(\\theta) \\, q(\\lambda) \\\\\n", 166 | "&= \\prod_i q(\\vartheta^i) \\\\\n", 167 | "\\vartheta^i &= \\{u(t), \\theta, \\lambda\\}\n", 168 | "\\end{align*}\n", 169 | "\n", 170 | "This partition is called the *mean field* approximation in statistical physics. We further assume that over the timescale of inference, only the states $u$ change with time $t$, while the (hyper)parameters are assumed constant.\n", 171 | "\n", 172 | "Under this partition, optimisation is still achieved by maximising the Free Energy, but we can now do this separately for each partition, by averaging over the other partitions. 
To show this, we define $F$ as an integral over the parameter partitions:\n", 173 | "\n", 174 | "\\begin{align*}\n", 175 | "F &= \\int f^i \\, d\\vartheta^i \\\\\n", 176 | "&= \\int{q(\\vartheta) \\: \\ln p(\\vartheta, \\tilde y, m) \\: d\\vartheta} - \\int{q(\\vartheta) \\: \\ln q(\\vartheta) \\: d\\vartheta} \\\\\n", 177 | "&= \\iint{q(\\vartheta^i) \\: q(\\vartheta^{\\backslash i}) \\: \\ln p(\\vartheta, \\tilde y, m) \\: d\\vartheta^{\\backslash i} \\: d\\vartheta^i} - \\iint{q(\\vartheta^i) \\: q(\\vartheta^{\\backslash i}) \\: \\ln q(\\vartheta) \\: d\\vartheta^{\\backslash i} \\: d\\vartheta^i} \\\\\n", 178 | "&= \\int{\\left( \\int{q(\\vartheta^i) \\: q(\\vartheta^{\\backslash i}) \\: \\ln p(\\vartheta, \\tilde y, m) \\: d\\vartheta^{\\backslash i} } - \\int{q(\\vartheta^i) \\: q(\\vartheta^{\\backslash i}) \\: \\ln q(\\vartheta) \\: d\\vartheta^{\\backslash i} } \\right) \\: d\\vartheta^i }\\\\\n", 179 | "f^i &= \\int{q(\\vartheta^i) \\: q(\\vartheta^{\\backslash i}) \\: \\ln p(\\vartheta, \\tilde y, m) \\: d\\vartheta^{\\backslash i} } - \\int{q(\\vartheta^i) \\: q(\\vartheta^{\\backslash i}) \\: \\ln q(\\vartheta) \\: d\\vartheta^{\\backslash i}} \\\\\n", 180 | "&= q(\\vartheta^i) \\: \\int{ q(\\vartheta^{\\backslash i}) \\: U(\\vartheta, \\tilde y) \\: d\\vartheta^{\\backslash i} } - q(\\vartheta^i) \\: \\int{q(\\vartheta^{\\backslash i}) \\: (\\ln q(\\vartheta^i) + \\ln q(\\vartheta^{\\backslash i})) \\: d\\vartheta^{\\backslash i}} \\\\\n", 181 | "&= q(\\vartheta^i) \\left( V(\\vartheta^i) - \\ln q(\\vartheta^i) - \\int{q(\\vartheta^{\\backslash i}) \\: \\ln q(\\vartheta^{\\backslash i}) \\: d\\vartheta^{\\backslash i}} \\right) \\\\\n", 182 | "F &= \\int q(\\vartheta^i) \\: \\left( V(\\vartheta^i) - \\ln q(\\vartheta^i) - \\ln Z^i \\right)\\, d\\vartheta^i \\\\\n", 183 | "\\end{align*}\n", 184 | "\n", 185 | "Here, $\\vartheta^{\\backslash i}$ denotes all parameters not in set $i$, i.e., its Markov blanket, $\\ln Z^i$ contains all the terms of $f^i$ 
that do not depend on $\\vartheta^i$, and\n", 186 | "\n", 187 | "\\begin{align*}\n", 188 | "V(\\vartheta^i) &= \\int{q(\\vartheta^{\\backslash i}) \\: \\ln p(\\vartheta, \\tilde y, m) \\: d\\vartheta^{\\backslash i} } = \\int{ q(\\vartheta^{\\backslash i}) \\: U(\\vartheta, \\tilde y) \\: d\\vartheta^{\\backslash i} } = \\left< U(\\vartheta) \\right>_{q(\\vartheta^{\\backslash i})}\\\\\n", 189 | "\\end{align*}\n", 190 | "\n", 191 | "The free energy is maximised when its variation with respect to the distribution $q(\\vartheta^i)$ vanishes (any constant arising from the normalisation constraint on $q(\\vartheta^i)$ can be absorbed into $\\ln Z^i$):\n", 192 | "\n", 193 | "\\begin{align*}\n", 194 | "\\delta_{q(\\vartheta^i)} F &= q(\\vartheta^i) \\: \\left( V(\\vartheta^i) - \\ln q(\\vartheta^i) - \\ln Z^i \\right) = 0 \\\\\n", 195 | "\\end{align*}\n", 196 | "\n", 197 | "Since this has to hold at every value of $\\vartheta^i$ where $q(\\vartheta^i) > 0$, this means that:\n", 198 | "\n", 199 | "\\begin{align*}\n", 200 | "\\left( V(\\vartheta^i) - \\ln q(\\vartheta^i) - \\ln Z^i \\right) &= 0 \\\\\n", 201 | "\\ln q(\\vartheta^i) &= V(\\vartheta^i) - \\ln Z^i\\\\\n", 202 | "q(\\vartheta^i) &= \\frac{1}{Z^i} \\exp \\left(V(\\vartheta^i)\\right) = \\frac{1}{Z^i} \\exp\\left(\\left< U(\\vartheta) \\right>_{q(\\vartheta^{\\backslash i})}\\right) \n", 203 | "\\end{align*}\n", 204 | "\n", 205 | "Thus, $Z^i$ is a normalisation constant ensuring the distribution integrates to $1$, and is also called a partition function in physics. 
The final equation indicates that the variational density over one parameter set is an exponential function of the internal energy averaged over all other parameters.\n", 206 | "\n", 207 | "Given our partitions above, we can then write:\n", 208 | "\n", 209 | "\\begin{align*}\n", 210 | "q(u(t)) &\\propto \\exp \\left(V(t)\\right) \\\\\n", 211 | "V(t) &= \\left< U(t) \\right>_{q(\\theta)q(\\lambda)} \\\\\n", 212 | "q(\\theta) &\\propto \\exp \\left(\\bar{V}^\\theta \\right) \\\\\n", 213 | "\\bar{V}^\\theta &= \\int \\left< U(t) \\right>_{q(u)q(\\lambda)} \\:dt + U^\\theta \\\\\n", 214 | "q(\\lambda) &\\propto \\exp \\left(\\bar{V}^\\lambda \\right) \\\\\n", 215 | "\\bar{V}^\\lambda &= \\int \\left< U(t) \\right>_{q(u)q(\\theta)} \\:dt + U^\\lambda \\\\\n", 216 | "\\bar{U} &= \\int U(t)\\:dt + U^\\theta + U^\\lambda \\\\\n", 217 | "\\end{align*}\n", 218 | "\n", 219 | "In a dynamical system, the instantaneous internal energy $U(t)$ is a function of time. Because the parameters and hyperparameters are considered constant over a period of observation, their variational densities are functions of the path integral of this internal energy, which is called *action* in physics. We use the notation $\\bar{V}$ and $\\bar{U}$ to indicate these integrals. $U^\\theta = \\ln p(\\theta)$ and $U^\\lambda = \\ln p(\\lambda)$ are the prior energies of the parameters and hyperparameters, respectively. \n", 220 | "\n", 221 | "From these equations we see that the variational density over states can be determined from the instantaneous internal energy averaged over parameters and hyperparameters, whereas the density over parameters and hyperparameters can only be determined when data has been observed over a certain amount of time. 
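The mean-field updates above, in which each factor is set to $q(\\vartheta^i) \\propto \\exp\\left(\\left< U \\right>_{q(\\vartheta^{\\backslash i})}\\right)$, can be illustrated with a textbook static model instead of the dynamic model of this notebook. The minimal sketch below assumes i.i.d. data $x_n \\sim \\mathcal{N}(\\theta, 1/\\lambda)$ with independent priors $\\theta \\sim \\mathcal{N}(m_0, s_0^2)$ and $\\lambda \\sim \\text{Gamma}(a_0, b_0)$ (shape/rate). The Gamma prior on the noise precision is a swapped-in conjugate choice so that both updates stay analytic; it differs from the Gaussian prior on $\\lambda$ used above. With no states $u$, only $q(\\theta)$ and $q(\\lambda)$ are iterated, and because the data are static the time integrals become sums over samples. The function name and default prior values are hypothetical.

```python
import numpy as np


def mean_field_vb(x, m0=0.0, s0_sq=100.0, a0=1e-3, b0=1e-3, n_iter=50):
    """Coordinate ascent on q(theta) q(lambda) for x_n ~ N(theta, 1/lambda).

    Hypothetical illustration with priors theta ~ N(m0, s0_sq) and
    lambda ~ Gamma(a0, b0) (shape/rate). Each update sets ln q of one factor
    to the expected log joint under the other factor, up to a constant.
    """
    n = x.size
    sum_x = x.sum()
    a_n = a0 + 0.5 * n       # shape of q(lambda); does not depend on q(theta)
    e_lam = a0 / b0          # initial guess for <lambda>
    for _ in range(n_iter):
        # q(theta) = N(m, v): <U> is quadratic in theta when averaged over q(lambda)
        v = 1.0 / (1.0 / s0_sq + n * e_lam)
        m = v * (m0 / s0_sq + e_lam * sum_x)
        # q(lambda) = Gamma(a_n, b_n), using <(x_n - theta)^2> = (x_n - m)^2 + v
        b_n = b0 + 0.5 * (np.sum((x - m) ** 2) + n * v)
        e_lam = a_n / b_n
    return m, v, a_n, b_n


rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=0.5, size=1000)  # true theta = 2, lambda = 1/0.25 = 4
m, v, a_n, b_n = mean_field_vb(x)
print("E[theta] =", m, " E[lambda] =", a_n / b_n)
```

Each pass uses the current expectation under the other factor (`e_lam` inside the $q(\\theta)$ update, and the mean and variance of $q(\\theta)$ inside the $q(\\lambda)$ update), mirroring the coordinate-wise structure of the updates listed above; for this model the iteration typically settles after a few passes.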
In the absence of data, the integrals will be zero, and the conditional density simply reduces to the prior density.\n", 222 | "\n", 223 | "*Variational Bayes* assumes the above equations are analytically tractable, which requires appropriate (conjugate) priors. The conditional distributions $q(\\vartheta^i)$ above can then be updated through iteration as new data becomes available:\n", 224 | "\n", 225 | "\\begin{align*}\n", 226 | "U(t) &= \\ln p(\\vartheta, \\tilde y, m)(t) \\\\\n", 227 | "\\ln q(u(t)) &\\propto \\left< U(t) \\right>_{q(\\theta)q(\\lambda)} \\\\\n", 228 | "\\ln q(\\theta) &\\propto \\int \\left< U(t) \\right>_{q(u)q(\\lambda)} \\:dt + \\ln p(\\theta) \\\\\n", 229 | "\\ln q(\\lambda) &\\propto \\int \\left< U(t) \\right>_{q(u)q(\\theta)} \\:dt + \\ln p(\\lambda)\n", 230 | "\\end{align*}\n", 231 | "\n" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": {}, 237 | "source": [ 238 | "### The Laplace approximation ###\n", 239 | "\n", 240 | "If the above equations are not analytically tractable, or if we want to avoid those calculations, we can apply the Laplace approximation. The Laplace approximation assumes that the marginals of the conditional density take a Gaussian form, i.e., $q(\\vartheta^i) = \\mathcal{N}(\\vartheta^i : \\mu^i, \\Sigma^i)$, where $\\mu^i$ and $\\Sigma^i$ are the sufficient statistics. For notational clarity, we will use $\\mu^i$, $\\Sigma^i$, and $\\Pi^i$ for the conditional mean, covariance, and precision of the $i^\\text{th}$ marginal, respectively, and $\\eta^i$, $C^i$, and $P^i$ for their priors. 
This approximation simplifies the updates to the marginals of the conditional densities.\n", 241 | "\n", 242 | "For each partition $\\vartheta^i$, we can then write:\n", 243 | "\n", 244 | "\\begin{align*}\n", 245 | "q(\\vartheta^i) &= \\frac{1}{\\sqrt{(2\\pi)^{n^i} |\\Sigma^i|}} \\exp \\left( \\frac{-(\\vartheta^i - \\mu^i)^2}{2\\Sigma^i} \\right) \\\\\n", 246 | "&= \\frac{1}{Z^i} \\exp \\left( -\\varepsilon(\\vartheta^i) \\right) \\\\ \n", 247 | "Z^i &= \\sqrt{(2\\pi)^{n^i} |\\Sigma^i|} \\\\\n", 248 | "\\varepsilon(\\vartheta^i) &= \\frac{(\\vartheta^i - \\mu^i)^2}{2\\Sigma^i} \\\\\n", 249 | "&= \\frac{1}{2} (\\vartheta^i - \\mu^i)^T {\\Sigma^i}^{-1} (\\vartheta^i - \\mu^i)\n", 250 | "\\end{align*}\n", 251 | "\n", 252 | "Where $n^i$ is the number of parameters in partition $i$.\n", 253 | "\n", 254 | "Recall that the Free Energy was defined as:\n", 255 | "\n", 256 | "\\begin{align*}\n", 257 | "F &= -\\int{q(\\vartheta) \\: \\ln \\frac{q(\\vartheta)}{p(\\vartheta, \\tilde y, m)} \\: d\\vartheta} \\\\\n", 258 | "&= - \\int{q(\\vartheta) \\: \\ln q(\\vartheta) \\: d\\vartheta} + \\int{q(\\vartheta) \\: \\ln p(\\vartheta, \\tilde y, m) \\: d\\vartheta} \\\\\n", 259 | "&= - \\int{q(\\vartheta) \\: \\ln \\prod_i q(\\vartheta^i) \\: d\\vartheta} + \\left< U \\right>_q \\\\\n", 260 | "&= - \\int{q(\\vartheta) \\: \\ln \\prod_i \\frac{1}{Z^i} \\exp \\left( -\\varepsilon(\\vartheta^i) \\right) \\: d\\vartheta} + \\left< U \\right>_q \\\\\n", 261 | "&= \\int{q(\\vartheta) \\: \\sum_i(\\ln Z^i + \\varepsilon(\\vartheta^i)) \\: d\\vartheta} + \\left< U \\right>_q \\\\\n", 262 | "&= \\sum_i(\\ln Z^i)\\int{q(\\vartheta) \\: d\\vartheta} + \\int{q(\\vartheta) \\: \\sum_i(\\varepsilon(\\vartheta^i)) \\: d\\vartheta} + \\left< U \\right>_q \\\\\n", 263 | "&\\left(\\int{q(\\vartheta) \\: d\\vartheta} = 1\\right) \\\\\n", 264 | "&= \\sum_i(\\ln Z^i) + \\int{q(\\vartheta) \\: \\sum_i(\\frac{1}{2} (\\vartheta^i - \\mu^i)^T {\\Sigma^i}^{-1} (\\vartheta^i - \\mu^i)) \\: d\\vartheta} + 
\\left< U \\right>_q \\\\\n", 265 | "&= \\sum_i(\\ln Z^i) + \\sum_i(\\frac{1}{2} \\int{q(\\vartheta^i) \\: (\\vartheta^i - \\mu^i)^T {\\Sigma^i}^{-1} (\\vartheta^i - \\mu^i) \\: d\\vartheta^i}) + \\left< U \\right>_q \\\\\n", 266 | "&\\left( \\text{using } a^T B a = \\text{tr} \\left( a a^T B \\right) \\right)\\\\\n", 267 | "&= \\sum_i(\\ln Z^i) + \\sum_i \\frac{1}{2}\\int{q(\\vartheta^i) \\: \\text{tr} \\left((\\vartheta^i - \\mu^i) (\\vartheta^i - \\mu^i)^T {\\Sigma^i}^{-1} \\right) \\: d\\vartheta^i} + \\left< U \\right>_q \\\\\n", 268 | "&= \\sum_i(\\ln Z^i) + \\sum_i \\frac{1}{2} \\int{\\text{tr} \\left(q(\\vartheta^i) \\: (\\vartheta^i - \\mu^i) (\\vartheta^i - \\mu^i)^T {\\Sigma^i}^{-1} \\right) \\: d\\vartheta^i} + \\left< U \\right>_q \\\\\n", 269 | "&= \\sum_i(\\ln Z^i) + \\sum_i \\frac{1}{2} \\text{tr} \\left(\\int{q(\\vartheta^i) \\: (\\vartheta^i - \\mu^i) (\\vartheta^i - \\mu^i)^T {\\Sigma^i}^{-1} \\: d\\vartheta^i} \\right) + \\left< U \\right>_q \\\\\n", 270 | "&= \\sum_i(\\ln Z^i) + \\sum_i \\frac{1}{2} \\text{tr} \\left(\\Sigma^i {\\Sigma^i}^{-1} \\right) + \\left< U \\right>_q \\\\\n", 271 | "&= \\sum_i(\\ln Z^i + \\frac{n^i}{2}) + \\left< U \\right>_q \\\\\n", 272 | "\\end{align*}\n" 273 | ] 274 | }, 275 | { 276 | "cell_type": "markdown", 277 | "metadata": {}, 278 | "source": [ 279 | "Now we still need to find an expression we can calculate for $\\left< U \\right>_q$. To do this, a further approximation assumes that $q$ is sharply peaked at its mean value $\\mu$, so that the integrand is only appreciably non-zero close to $\\vartheta = \\mu$. This seems quite restrictive: not only do we assume $q$ is a Gaussian distribution, but also it has to be a narrow distribution around its mean. However, this is just another way of saying that the parameters are nearly constant over the time segment that we are doing inference over. 
With this assumption we can then use a Taylor expansion around the mean up to second order to obtain: \n", 280 | "\n", 281 | "\\begin{align*}\n", 282 | "\\left< U \\right>_q &= \\int{q(\\vartheta) \\: U(\\vartheta, \\tilde y) \\: d\\vartheta} \\\\\n", 283 | "&= \\int{q(\\vartheta) \\: \\left\\{ U(\\mu, \\tilde y) + \\left[ \\frac{dU}{d\\vartheta} \\right]_\\mu \\delta\\vartheta + \\frac{1}{2} \\left[ \\frac{d^2U}{d\\vartheta^2} \\right]_\\mu \\delta\\vartheta^2 \\right\\} \\: d\\vartheta} \\\\\n", 284 | "&= U(\\mu, \\tilde y) + \\left[ \\frac{dU}{d\\vartheta} \\right]_\\mu \\int{q(\\vartheta) \\: (\\vartheta - \\mu) \\: d\\vartheta} + \\frac{1}{2} \\int{ q(\\vartheta) \\: (\\vartheta - \\mu)^T \\left[ \\frac{d^2U}{d\\vartheta^2} \\right]_\\mu (\\vartheta - \\mu) \\: d\\vartheta} \\\\\n", 285 | "&= U(\\mu, \\tilde y) + \\left[ \\frac{dU}{d\\vartheta} \\right]_\\mu \\left\\{ \\int{\\vartheta q(\\vartheta) \\: d\\vartheta} - \\mu \\right\\} + \\frac{1}{2} \\int{ q(\\vartheta) \\: \\text{tr} \\left( (\\vartheta - \\mu) (\\vartheta - \\mu)^T \\left[ \\frac{d^2U}{d\\vartheta^2} \\right]_\\mu \\right) \\: d\\vartheta} \\\\\n", 286 | "&= U(\\mu, \\tilde y) + \\left[ \\frac{dU}{d\\vartheta} \\right]_\\mu \\left\\{ \\int{\\vartheta q(\\vartheta) \\: d\\vartheta} - \\mu \\right\\} + \\frac{1}{2} \\text{tr} \\left( \\left[ \\frac{d^2U}{d\\vartheta^2} \\right]_\\mu \\int{ q(\\vartheta) \\: (\\vartheta - \\mu) (\\vartheta - \\mu)^T \\: d\\vartheta} \\right) \\\\\n", 287 | "&= U(\\mu, \\tilde y) + \\left[ \\frac{dU}{d\\vartheta} \\right]_\\mu \\left\\{ \\mu - \\mu \\right\\} + \\frac{1}{2} \\text{tr} \\left( \\left[ \\frac{d^2U}{d\\vartheta^2} \\right]_\\mu \\Sigma \\right) \\\\\n", 288 | "&= U(\\mu, \\tilde y) + \\frac{1}{2} \\text{tr} \\left( \\left[ \\frac{d^2U}{d\\vartheta^2} \\right]_\\mu \\Sigma \\right) \\\\\n", 289 | "\\end{align*}\n", 290 | "\n", 291 | "This now allows us to write for the free energy:\n", 292 | "\n", 293 | "\\begin{align*}\n", 294 | "F &= 
\\sum_i\\left(\\ln Z^i + \\frac{n^i}{2}\\right) + \\left< U \\right>_q \\\\\n", 295 | "&= \\sum_i\\left(\\frac{1}{2} \\ln (2\\pi)^{n^i} |\\Sigma^i| + \\frac{n^i}{2}\\right) + U(\\mu, \\tilde y) + \\frac{1}{2} \\sum_i \\text{tr} \\left( \\left[ \\frac{d^2U}{d{\\vartheta^i}^2} \\right]_{\\mu^i} \\Sigma^i \\right)\\\\\n", 296 | "&= \\frac{1}{2} \\sum_i\\left(n^i \\ln (2\\pi) + \\ln |\\Sigma^i| + n^i \\right) + U(\\mu, \\tilde y) + \\frac{1}{2} \\sum_i \\left\\{ \\text{tr} \\left( \\left[ \\frac{d^2U}{d{\\vartheta^i}^2} \\right]_{\\mu^i} \\Sigma^i \\right) \\right\\} \\\\\n", 297 | "\\end{align*}\n", 298 | "\n", 299 | "To find the optimal variances, we maximise the free energy with respect to the variances, so that the partial derivatives are zero:\n", 300 | "\n", 301 | "\\begin{align*}\n", 302 | "\\frac{dF}{d\\Sigma^i} &= \\frac{1}{2} \\left\\{ \\left[ \\frac{d^2U}{d{\\vartheta^i}^2} \\right]_{\\mu^i} + {\\Sigma^i}^{-1} \\right\\}^T = 0 \\\\\n", 303 | "\\implies \\Sigma^{i*} &= - \\left[ \\frac{d^2U}{d{\\vartheta^i}^2} \\right]_{\\mu^i}^{-1} \\\\\n", 304 | "\\end{align*}\n", 305 | "\n", 306 | "where we've used the matrix derivative identities:\n", 307 | "\n", 308 | "\\begin{align*}\n", 309 | "\\frac{d \\text{tr} \\left( B A \\right)}{dA} &= B^T \\\\\n", 310 | "\\frac{d \\ln |A|}{dA} &= {{A}^{-1}}^T \\\\\n", 311 | "\\end{align*}\n", 312 | "\n", 313 | "and we use the notation $\\Sigma^{i*}$ to indicate this is the optimal variance which maximises the free energy.\n", 314 | "\n", 315 | "Finally, this allows us to write for the free energy under the Laplace approximation with sharply peaked Gaussian distributions for $q(\\vartheta^i)$:\n", 316 | "\n", 317 | "\\begin{align*}\n", 318 | "F &= \\frac{1}{2} \\sum_i\\left(n^i \\ln (2\\pi) + \\ln |\\Sigma^{i*}| + n^i\\right) + U(\\mu, \\tilde y) + \\frac{1}{2} \\sum_i \\text{tr}\\left(- \\left[ \\frac{d^2U}{d{\\vartheta^i}^2} \\right]_{\\mu^i} \\left[ \\frac{d^2U}{d{\\vartheta^i}^2} \\right]_{\\mu^i}^{-1}\\right)\\\\\n", 319 | 
"&= \\frac{1}{2} \\sum_i\\left(n^i \\ln (2\\pi) +\\ln |\\Sigma^{i*}| + n^i\\right) + U(\\mu, \\tilde y) + \\frac{1}{2} \\sum_i \\text{tr}\\left(-I\\right)\\\\\n", 320 | "&= \\frac{1}{2} \\sum_i\\left(n^i \\ln (2\\pi) + \\ln |\\Sigma^{i*}| + n^i\\right) + U(\\mu, \\tilde y) - \\frac{1}{2} \\sum_i n^i \\\\\n", 321 | "&= U(\\mu, \\tilde y) + \\frac{1}{2} \\sum_i \\left( n^i \\ln (2\\pi) + \\ln |\\Sigma^{i*}| \\right)\n", 322 | "\\end{align*}" 323 | ] 324 | }, 325 | { 326 | "cell_type": "markdown", 327 | "metadata": {}, 328 | "source": [ 329 | "We now proceed to obtain an expression for the variational distribution of each partition under the Laplace approximation.\n", 330 | "\n", 331 | "\\begin{align*}\n", 332 | "\\ln q(u(t)) &\\propto V(t) = \\left< U(t) \\right>_{q(\\theta)q(\\lambda)}\\\\\n", 333 | "\\ln q(\\theta) &\\propto \\bar{V}^\\theta = \\int \\left< U(t) \\right>_{q(u)q(\\lambda)} \\:dt + U^\\theta\\\\\n", 334 | "\\ln q(\\lambda) &\\propto \\bar{V}^\\lambda = \\int \\left< U(t) \\right>_{q(u)q(\\theta)} \\:dt + U^\\lambda\\\\\n", 335 | "\\end{align*}\n", 336 | "\n", 337 | "Because of the mean field assumption, we can also write:\n", 338 | "\n", 339 | "\\begin{align*}\n", 340 | "\\ln p(\\vartheta, \\tilde y, m) &= \\ln p(u, t, \\tilde y, m) + \\ln p(\\theta, \\tilde y, m) + \\ln p(\\lambda, \\tilde y, m)\\\\\n", 341 | "\\end{align*}\n", 342 | "\n", 343 | "and we have from before our estimates for these distributions:\n", 344 | "\n", 345 | "\\begin{align*}\n", 346 | "q(\\vartheta^i) &= \\frac{1}{Z^i} \\exp \\left( \\frac{-(\\vartheta^i - \\mu^i)^2}{2\\Sigma^i} \\right) \\\\\n", 347 | "\\end{align*}\n", 348 | "\n", 349 | "Now this allows us to proceed with:\n", 350 | "\n", 351 | "\\begin{align*}\n", 352 | "V(t) &= \\left< U(t) \\right>_{q(\\theta)q(\\lambda)} \\\\\n", 353 | "&= \\int q(\\lambda) q(\\theta) \\ln p(\\vartheta, \\tilde y, m) d\\theta d\\lambda \\\\\n", 354 | "&= \\int q(\\lambda) q(\\theta) \\ln p(u, t, \\tilde y, m) + \\ln p(\\theta, \\tilde y, m) + \\ln 
p(\\lambda, \\tilde y, m) d\\theta d\\lambda \\\\\n", 355 | "&\\approx \\int q(\\lambda) q(\\theta) \\ln p(u, t, \\tilde y, m) + \\ln q(\\theta, \\tilde y, m) + \\ln q(\\lambda, \\tilde y, m) d\\theta d\\lambda \\\\\n", 356 | "&= \\int q(\\lambda) q(\\theta) \\ln \\left( p(u,t|\\mu^\\theta, \\mu^\\lambda) \\right) d\\theta d\\lambda \\\\\n", 357 | "&+ \\int q(\\lambda) q(\\theta) \\left( -\\frac{1}{2} \\left( \\theta - \\mu^{\\theta} \\right)^T {\\Sigma^{\\theta}}^{-1} \\left( \\theta - \\mu^{\\theta} \\right) - \\ln Z^\\theta \\right) d\\theta d\\lambda \\\\\n", 358 | "&+ \\int q(\\lambda) q(\\theta) \\left( -\\frac{1}{2} \\left( \\lambda - \\mu^{\\lambda} \\right)^T {\\Sigma^{\\lambda}}^{-1} \\left( \\lambda - \\mu^{\\lambda} \\right) - \\ln Z^\\lambda \\right) d\\theta d\\lambda \\\\\n", 359 | "&= U(u,t|\\mu^\\theta, \\mu^\\lambda) \\\\\n", 360 | "&+ \\int q(\\lambda) q(\\theta) \\left( -\\frac{1}{2} \\left( \\theta - \\mu^{\\theta} \\right)^T {\\Sigma^{\\theta}}^{-1} \\left( \\theta - \\mu^{\\theta} \\right) \\right) d\\theta d\\lambda - \\int q(\\theta) \\ln Z^\\theta d\\theta\\\\\n", 361 | "&+ \\int q(\\lambda) q(\\theta) \\left( -\\frac{1}{2} \\left( \\lambda - \\mu^{\\lambda} \\right)^T {\\Sigma^{\\lambda}}^{-1} \\left( \\lambda - \\mu^{\\lambda} \\right) \\right) d\\theta d\\lambda - \\int q(\\lambda) \\ln Z^\\lambda d\\lambda\\\\\n", 362 | "&\\text{(inserting the optimal variances:)}\\\\\n", 363 | "&= U(u,t|\\mu^\\theta, \\mu^\\lambda) \\\\\n", 364 | "&+ \\int q(\\lambda) q(\\theta) \\left( \\frac{1}{2} \\left( \\theta - \\mu^{\\theta} \\right)^T \\left[ \\frac{d^2U(t)}{d{\\theta}^2} \\right]_{\\mu^\\theta} \\left( \\theta - \\mu^{\\theta} \\right) \\right) d\\theta d\\lambda - \\left< \\ln Z^\\theta \\right>_{q(\\theta)}\\\\\n", 365 | "&+ \\int q(\\lambda) q(\\theta) \\left( \\frac{1}{2} \\left( \\lambda - \\mu^{\\lambda} \\right)^T \\left[ \\frac{d^2U(t)}{d{\\lambda}^2} \\right]_{\\mu^\\lambda} \\left( \\lambda - \\mu^{\\lambda} \\right) \\right) 
d\\theta d\\lambda - \\left< \\ln Z^\\lambda \\right>_{q(\\lambda)}\\\\\n", 366 | "&\\left( \\text{and using } a^T B a = \\text{tr} \\left( a a^T B \\right) \\right)\\\\\n", 367 | "&= U(u,t|\\mu^\\theta, \\mu^\\lambda) \\\\\n", 368 | "&+ \\int q(\\theta) \\text{tr} \\left( \\frac{1}{2} \\left( \\theta - \\mu^{\\theta} \\right) \\left( \\theta - \\mu^{\\theta} \\right)^T \\left[ \\frac{d^2U(t)}{d{\\theta}^2} \\right]_{\\mu^\\theta} \\right) d\\theta - \\left< \\ln Z^\\theta \\right>_{q(\\theta)}\\\\\n", 369 | "&+ \\int q(\\lambda) \\text{tr} \\left( \\frac{1}{2} \\left( \\lambda - \\mu^{\\lambda} \\right) \\left( \\lambda - \\mu^{\\lambda} \\right)^T \\left[ \\frac{d^2U(t)}{d{\\lambda}^2} \\right]_{\\mu^\\lambda} \\right) d\\lambda - \\left< \\ln Z^\\lambda \\right>_{q(\\lambda)}\\\\\n", 370 | "&\\left( \\text{and the integral of a trace is the trace of the integral:} \\right)\\\\\n", 371 | "&= U(u,t|\\mu^\\theta, \\mu^\\lambda) \\\\\n", 372 | "&+ \\frac{1}{2} \\text{tr} \\left( \\int q(\\theta) \\left( \\theta - \\mu^{\\theta} \\right) \\left( \\theta - \\mu^{\\theta} \\right)^T \\left[ \\frac{d^2U(t)}{d{\\theta}^2} \\right]_{\\mu^\\theta} d\\theta \\right) - \\left< \\ln Z^\\theta \\right>_{q(\\theta)}\\\\\n", 373 | "&+ \\frac{1}{2} \\text{tr} \\left( \\int q(\\lambda)\\left( \\lambda - \\mu^{\\lambda} \\right) \\left( \\lambda - \\mu^{\\lambda} \\right)^T \\left[ \\frac{d^2U(t)}{d{\\lambda}^2} \\right]_{\\mu^\\lambda} d\\lambda \\right) - \\left< \\ln Z^\\lambda \\right>_{q(\\lambda)}\\\\\n", 374 | "&\\left( \\text{and because the curvatures, evaluated at the means, do not depend on the integration variable} \\right)\\\\\n", 375 | "&= U(u,t|\\mu^\\theta, \\mu^\\lambda) \\\\\n", 376 | "&+ \\frac{1}{2} \\text{tr} \\left( \\int q(\\theta) \\left( \\theta - \\mu^{\\theta} \\right) \\left( \\theta - \\mu^{\\theta} \\right)^Td\\theta \\left[ \\frac{d^2U(t)}{d{\\theta}^2} \\right]_{\\mu^\\theta} \\right) - \\left< \\ln Z^\\theta \\right>_{q(\\theta)}\\\\\n", 377 | "&+ \\frac{1}{2} 
\\text{tr} \\left( \\int q(\\lambda)\\left( \\lambda - \\mu^{\\lambda} \\right) \\left( \\lambda - \\mu^{\\lambda} \\right)^T d\\lambda \\left[ \\frac{d^2U(t)}{d{\\lambda}^2} \\right]_{\\mu^\\lambda} \\right) - \\left< \\ln Z^\\lambda \\right>_{q(\\lambda)}\\\\\n", 378 | "&= U(u,t|\\mu^\\theta, \\mu^\\lambda) + \\frac{1}{2} \\text{tr}\\left( \\Sigma^\\theta \\left[ \\frac{d^2U(t)}{d{\\theta}^2} \\right]_{\\mu^\\theta} \\right) + \\frac{1}{2} \\text{tr} \\left(\\Sigma^\\lambda \\left[ \\frac{d^2U(t)}{d{\\lambda}^2} \\right]_{\\mu^\\lambda} \\right) - \\left< \\ln Z^\\theta \\right>_{q(\\theta)} - \\left< \\ln Z^\\lambda \\right>_{q(\\lambda)}\\\\\n", 379 | "\\bar{V}^u &= \\int V(t) \\:dt \\\\\n", 380 | "&= \\int U(u, t|\\mu^\\theta, \\mu^\\lambda) + W(t)^\\theta + W(t)^\\lambda - \\left< \\ln Z^\\theta \\right>_{q(\\theta)} - \\left< \\ln Z^\\lambda \\right>_{q(\\lambda)} \\:dt \\\\\n", 381 | "\\text{Similarly, we find:}\\\\\n", 382 | "\\bar{V}^\\theta &= \\int U(\\mu^u, t|\\theta, \\mu^\\lambda) + W(t)^u + W(t)^\\lambda - \\left< \\ln Z^u \\right>_{q(u)} - \\left< \\ln Z^\\lambda \\right>_{q(\\lambda)} \\:dt + U^\\theta \\\\\n", 383 | "\\bar{V}^\\lambda &= \\int U(\\mu^u, t|\\mu^\\theta, \\lambda) + W(t)^u + W(t)^\\theta - \\left< \\ln Z^u \\right>_{q(u)} - \\left< \\ln Z^\\theta \\right>_{q(\\theta)} \\:dt + U^\\lambda \\\\\n", 384 | "W(t)^u &= \\frac{1}{2} \\text{tr}(\\Sigma^u U(t)_{uu}) \\\\\n", 385 | "W(t)^\\theta &= \\frac{1}{2} \\text{tr}(\\Sigma^\\theta U(t)_{\\theta\\theta}) \\\\\n", 386 | "W(t)^\\lambda &= \\frac{1}{2} \\text{tr}(\\Sigma^\\lambda U(t)_{\\lambda\\lambda}) \\\\\n", 387 | "\\end{align*}\n", 388 | "\n", 389 | "Also, the conditional precisions are equal to the negative curvatures of the internal action:\n", 390 | "\n", 391 | "\\begin{align*}\n", 392 | "\\bar{U} &= \\int U(t)\\:dt + U^\\theta + U^\\lambda \\\\\n", 393 | "\\Pi^u &= -\\bar{U}_{uu} = -U(t)_{uu} \\\\\n", 394 | "\\Pi^\\theta &= -\\bar{U}_{\\theta\\theta} = - \\int 
U(t)_{\\theta\\theta} \\: dt - U^\\theta_{\\theta\\theta} \\\\\n", 395 | "\\Pi^\\lambda &= -\\bar{U}_{\\lambda\\lambda} = - \\int U(t)_{\\lambda\\lambda} \\: dt - U^\\lambda_{\\lambda\\lambda} \\\\\n", 396 | "\\end{align*}\n", 397 | "\n", 398 | "To get their optimal values, these need to be evaluated at the mode of each partition." 399 | ] 400 | }, 401 | { 402 | "cell_type": "markdown", 403 | "metadata": {}, 404 | "source": [ 405 | "### Temporal smoothness ###\n", 406 | "\n", 407 | "Since the different levels of motion in the generalised coordinates are linked, the actual precision matrix will have off-diagonal elements with non-zero values. The precision is given by the Kronecker product $\\tilde \\Pi^i = S(\\gamma) \\otimes \\Pi^i$, where $\\Pi^i$ is a diagonal matrix specifying the precision of the (often assumed independent) Gaussian noise at each level as determined in the previous section, and $S(\\gamma)$ is the temporal precision matrix, which encodes the temporal dependencies between levels as a function of their autocorrelation:\n", 408 | "\n", 409 | "\\begin{align*}\n", 410 | "S =\n", 411 | " \\begin{bmatrix}\n", 412 | " 1 & 0 & \\ddot{\\rho}(0) & 0 & \\ddot{\\ddot{\\rho}}(0) & 0 \\\\\n", 413 | " 0 & -\\ddot{\\rho}(0) & 0 & -\\ddot{\\ddot{\\rho}}(0) & 0 & -\\dddot{\\dddot{\\rho}}(0)\\\\\n", 414 | " \\ddot{\\rho}(0) & 0 & \\ddot{\\ddot{\\rho}}(0) & 0 & \\dddot{\\dddot{\\rho}}(0) & 0 \\\\\n", 415 | " 0 & -\\ddot{\\ddot{\\rho}}(0) & 0 & -\\dddot{\\dddot{\\rho}}(0) & 0 & -\\ddot{\\dddot{\\dddot{\\rho}}}(0) \\\\\n", 416 | " \\ddot{\\ddot{\\rho}}(0) & 0 & \\dddot{\\dddot{\\rho}}(0) & 0 & \\ddot{\\dddot{\\dddot{\\rho}}}(0) & 0 \\\\\n", 417 | " 0 & -\\dddot{\\dddot{\\rho}}(0) & 0 & -\\ddot{\\dddot{\\dddot{\\rho}}}(0) & 0 & -\\ddot{\\ddot{\\dddot{\\dddot{\\rho}}}}(0)\\\\\n", 418 | " \\end{bmatrix}^{-1}\n", 419 | "\\end{align*}\n", 420 | "\n", 421 | "(see [here](Generalised%20precision%20matrix.ipynb) for my derivation of this step.) 
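The matrix $S$ above can be assembled directly from the even derivatives of the autocorrelation at zero: entry $(i, j)$ of the matrix being inverted (indices starting at 0) is $(-1)^i \\rho^{(i+j)}(0)$, which vanishes when $i + j$ is odd. The minimal sketch below assumes the Gaussian autocorrelation $\\rho(h) = \\exp(-\\frac{\\gamma}{2} h^2)$ used in this notebook, for which $\\rho^{(2k)}(0) = (-\\gamma)^k (2k-1)!!$; the function names are hypothetical. It should reproduce the Gaussian form of $S$ given next.

```python
import numpy as np


def gaussian_rho_derivative(order, gamma):
    """d^order rho / dh^order at h = 0, for rho(h) = exp(-gamma * h**2 / 2)."""
    if order % 2:                       # odd derivatives of an even function vanish at 0
        return 0.0
    k = order // 2
    double_fact = 1
    for m in range(2 * k - 1, 0, -2):   # (2k - 1)!! = 1 * 3 * ... * (2k - 1)
        double_fact *= m
    return (-gamma) ** k * double_fact


def temporal_covariance(n, gamma):
    """Matrix inside the inverse: entry (i, j) = (-1)**i * rho^(i+j)(0)."""
    R = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            R[i, j] = (-1) ** i * gaussian_rho_derivative(i + j, gamma)
    return R


def temporal_precision(n, gamma):
    """S(gamma): inverse of the temporal covariance for n orders of motion."""
    return np.linalg.inv(temporal_covariance(n, gamma))


R = temporal_covariance(6, gamma=1.0)
S = temporal_precision(6, gamma=1.0)
print(R)   # first row should read 1, 0, -1, 0, 3, 0 for gamma = 1
```

The generalised precision $\\tilde \\Pi^i = S(\\gamma) \\otimes \\Pi^i$ would then follow from `np.kron(S, Pi)`, with the order of motion as the outer index. Inverting explicitly is acceptable for such a small matrix, but the diagonal of the inverted matrix grows like $(2i-1)!! \\, \\gamma^i$, so its condition number rises quickly with the embedding order, which is another practical reason to truncate the generalised coordinates.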
Here $\\ddot{\\rho}(0)$ is the second derivative of the autocorrelation function with respect to the time lag, evaluated at zero lag. Note that because the autocorrelation function is even (symmetrical for positive and negative delays), its odd derivatives are all odd functions, and thus are zero when evaluated at zero.\n", 422 | "\n", 423 | "While $S$ can be evaluated for any analytical autocorrelation function, we assume here that the temporal correlations all have the same Gaussian form, which gives:\n", 424 | "\n", 425 | "\\begin{align*}\n", 426 | "S &=\n", 427 | " \\begin{bmatrix}\n", 428 | " 1 & 0 & -\\gamma & 0 & 3 \\gamma^2 & 0 \\\\\n", 429 | " 0 & \\gamma & 0 & -3 \\gamma^2 & 0 & 15 \\gamma^3 \\\\\n", 430 | " -\\gamma & 0 & 3 \\gamma^2 & 0 & -15 \\gamma^3 & 0 \\\\\n", 431 | " 0 & -3 \\gamma^2 & 0 & 15 \\gamma^3 & 0 & -105 \\gamma^4 \\\\\n", 432 | " 3 \\gamma^2 & 0 & -15 \\gamma^3 & 0 & 105 \\gamma^4 & 0 \\\\\n", 433 | " 0 & 15 \\gamma^3 & 0 & -105 \\gamma^4 & 0 & 945 \\gamma^5 \\\\\n", 434 | " \\end{bmatrix}^{-1}\n", 435 | "\\end{align*}\n", 436 | "\n", 437 | "Here, $\\gamma$ is the precision parameter of a Gaussian autocorrelation function. Typically, $\\gamma > 1$, which ensures the precisions of high-order motion converge quickly to zero. This is important because it enables us to truncate the representation of an infinite number of generalised coordinates to a relatively small number, since high-order prediction errors have a vanishingly small precision. Friston states that an order of n=6 is sufficient in most cases. Also note that my derivation of this matrix compared to Friston has $\\gamma_{AvS} = 1/2 \\gamma_{KF}$. 
I believe this is because I took the Gaussian autocorrelation function with precision parameter $\\gamma$ to be $\\rho(h) = \\exp(-\\frac{\\gamma}{2} h^2)$, whereas Friston's result implies he used $\\rho(h) = \\exp(-\\frac{\\gamma}{4} h^2)$, which is an odd definition.\n", 438 | "\n", 439 | "Putting all this together, we get:\n", 440 | "\n", 441 | "\\begin{align*}\n", 442 | "p(\\tilde y, \\tilde x, \\tilde \\nu \\,|\\, \\theta, \\lambda) &= p(\\tilde y \\,|\\, \\tilde x, \\tilde \\nu, \\theta, \\lambda) \\; p(\\tilde x \\,|\\, \\tilde \\nu, \\theta, \\lambda) \\; p(\\tilde \\nu) \\\\\n", 443 | " &= (2\\pi)^{-N_y/2} {|\\tilde\\Pi^z|}^{1/2} e^{-\\frac{1}{2}{\\tilde\\varepsilon^\\nu}^T \\tilde\\Pi^z \\tilde\\varepsilon^\\nu} (2\\pi)^{-N_x/2} {|\\tilde\\Pi^w|}^{1/2} e^{-\\frac{1}{2}{\\tilde\\varepsilon^x}^T \\tilde\\Pi^w \\tilde\\varepsilon^x} \\; p(\\tilde \\nu) \\\\\n", 444 | " &= (2\\pi)^{-(N_y + N_x)/2} (|\\tilde\\Pi^z| \\, |\\tilde\\Pi^w|)^{1/2} e^{-\\frac{1}{2}{\\tilde\\varepsilon^\\nu}^T \\tilde\\Pi^z \\tilde\\varepsilon^\\nu} e^{-\\frac{1}{2}{\\tilde\\varepsilon^x}^T \\tilde\\Pi^w \\tilde\\varepsilon^x} \\; p(\\tilde \\nu)\\\\\n", 445 | " &= (2\\pi)^{-N/2} |\\tilde\\Pi|^{1/2} e^{-\\frac{1}{2}{\\tilde\\varepsilon}^T \\tilde\\Pi \\tilde\\varepsilon} \\; p(\\tilde \\nu)\\\\ \n", 446 | "\\tilde\\Pi &= \n", 447 | " \\begin{bmatrix}\n", 448 | " \\tilde\\Pi^z & \\\\\n", 449 | " & \\tilde\\Pi^w\n", 450 | " \\end{bmatrix}\\\\\n", 451 | "\\tilde\\varepsilon &= \n", 452 | " \\begin{bmatrix}\n", 453 | " \\tilde\\varepsilon^\\nu = \\tilde y - \\tilde g \\\\\n", 454 | " \\tilde\\varepsilon^x = D\\tilde x - \\tilde f\n", 455 | " \\end{bmatrix}\\\\\n", 456 | "N &= \\text{Rank}(\\tilde\\Pi)\n", 457 | "\\end{align*} \n", 458 | "\n", 459 | "Here we introduce auxiliary variables $\\tilde\\varepsilon(t)$, which are the prediction errors for the generalised responses and motion of the hidden states, with respective predictions $\\tilde g(t)$ and $\\tilde f(t)$ and their precisions 
encoded by $\\tilde\\Pi$.\n", 460 | "\n", 461 | "The log probability can thus be written:\n", 462 | "\\begin{align*}\n", 463 | "\\ln p(\\tilde y, \\tilde x, \\tilde \\nu \\,|\\, \\theta, \\lambda) &= U(t) = - \\frac{N}{2} \\ln 2\\pi + \\frac{1}{2} \\ln |\\tilde\\Pi| - \\frac{1}{2}{{\\tilde\\varepsilon}^T \\tilde\\Pi \\tilde\\varepsilon} + \\ln p(\\tilde \\nu)\n", 464 | "\\end{align*}\n", 465 | "\n", 466 | "where the first term is constant and the fourth term is determined by the input causes, which are considered known." 467 | ] 468 | }, 469 | { 470 | "cell_type": "markdown", 471 | "metadata": {}, 472 | "source": [ 473 | "### Sequences ###\n", 474 | "Our observations usually do not come in the form of generalised coordinates, but as a sequence of samples in time. We can use Taylor's theorem to link the two, noting that the generalised coordinates contain the higher-order derivatives of the state at time $t$. Friston writes this as:\n", 475 | "\n", 476 | "\\begin{align*}\n", 477 | "y &= [y(1), \\dots, y(N)]^T \\\\\n", 478 | "y &= \\tilde E(t) \\tilde y(t) \\\\\n", 479 | "\\tilde E(t) &= E(t) \\otimes I \\\\\n", 480 | "E(t)_{ij} &= \\frac{(i-t)^{j-1}}{(j-1)!}\\\\\n", 481 | "\\tilde y(t) &= \\tilde E(t)^{-1} y\\\\\n", 482 | "\\end{align*}\n", 483 | "\n", 484 | "The relation between $i$ and $t$ is not entirely clear to me in his formulation, nor is the numbering of the levels. 
For a single state, we can generate a sequence $y = [y(t - 1\\Delta t), \\dots, y(t - N\\Delta t)]^T$ from the generalised coordinates $\\tilde y$ with elements $\\overset{j}{y}$, where $j = 0, \\dots, n$, using the Taylor expansion:\n", 485 | "\n", 486 | "\\begin{align*}\n", 487 | "y &= E(t) \\tilde y(t) \\\\\n", 488 | "\\begin{bmatrix}\n", 489 | "y(t-1\\Delta t) \\\\\n", 490 | "y(t-2\\Delta t) \\\\\n", 491 | "\\vdots \\\\\n", 492 | "y(t-N\\Delta t) \\\\\n", 493 | "\\end{bmatrix}\n", 494 | " &= \n", 495 | "\\begin{bmatrix}\\frac{(-1\\Delta t)^0}{0!}& \\frac{(-1\\Delta t)^1}{1!}& \\frac{(-1\\Delta t)^2}{2!}& \\cdots & \\frac{(-1\\Delta t)^n}{n!} \\\\\n", 496 | "\\frac{(-2\\Delta t)^0}{0!}& \\frac{(-2\\Delta t)^1}{1!}& \\frac{(-2\\Delta t)^2}{2!}& \\cdots & \\frac{(-2\\Delta t)^n}{n!} \\\\\n", 497 | "& & \\vdots \\\\\n", 498 | "\\frac{(-N\\Delta t)^0}{0!}& \\frac{(-N\\Delta t)^1}{1!}& \\frac{(-N\\Delta t)^2}{2!}& \\cdots & \\frac{(-N\\Delta t)^n}{n!}\\end{bmatrix} \n", 499 | "\\begin{bmatrix}\n", 500 | "\\overset{0}{y(t)}\\\\\n", 501 | "\\overset{1}{y(t)}\\\\\n", 502 | "\\overset{2}{y(t)}\\\\\n", 503 | "\\vdots \\\\\n", 504 | "\\overset{n}{y(t)}\\\\\n", 505 | "\\end{bmatrix}\\\\\n", 506 | "E(t)_{ij} &= \\frac{(-(i+1)\\Delta t)^{j}}{j!}\\\\\n", 507 | "\\tilde y(t) &= E(t)^{-1} y\\\\\n", 508 | "\\end{align*}\n", 509 | "\n", 510 | "For this to work, $E(t)^{-1}$ must exist. This requires $E(t)$ to be square, i.e. the number of elements in the sequence $N$ ($i = 0, \\dots, N-1$) must equal the number of levels in the generalised coordinates $n+1$ (the $j$s). Because the offsets $-(i+1)\\Delta t$ are all distinct, $E(t)$ is then a Vandermonde matrix with rescaled columns, and is therefore invertible." 511 | ] 512 | }, 513 | { 514 | "cell_type": "markdown", 515 | "metadata": {}, 516 | "source": [ 517 | "Note that $E(t)$ does not actually depend on $t$, only on $\\Delta t$, so it is a fixed matrix for a given number of levels in the generalised coordinates. 
So let's calculate this up to the sixth derivative. For simplicity, the code below takes $\\Delta t = 1$ and forward offsets $t + 1, \\dots, t + 7$, i.e. $E_{ij} = (i+1)^j / j!$ without the minus sign; using the backward sequence above instead only flips the sign of the rows of $E^{-1}$ that give the odd-order derivatives:" 518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "execution_count": 46, 523 | "metadata": {}, 524 | "outputs": [ 525 | { 526 | "name": "stdout", 527 | "output_type": "stream", 528 | "text": [ 529 | "E = \n", 530 | "[1 1 1 1 1 1 1] / 1\n", 531 | "[1 2 3 4 5 6 7] / 1\n", 532 | "[ 1 4 9 16 25 36 49] / 2\n", 533 | "[ 1 8 27 64 125 216 343] / 6\n", 534 | "[ 1 16 81 256 625 1296 2401] / 24\n", 535 | "[ 1 32 243 1024 3125 7776 16807] / 120\n", 536 | "[ 1 64 729 4096 15625 46656 117649] / 720\n" 537 | ] 538 | } 539 | ], 540 | "source": [ 541 | "from math import factorial as fac\n", 542 | "from numpy import zeros\n", 543 | "from numpy.linalg import inv\n", 544 | "\n", 545 | "E = zeros((7, 7))\n", 546 | "print(\"E = \")\n", 547 | "for j in range(7):\n", 548 | "    for i in range(7):\n", 549 | "        E[i, j] = ((i+1)**(j))/fac(j)  # Delta t = 1, forward offsets t+1, ..., t+7\n", 550 | "    print(\"{:} / {:}\".format((E[:, j] * fac(j)).astype(int), fac(j)))  # prints column j, so E is displayed transposed\n" 551 | ] 552 | }, 553 | { 554 | "cell_type": "code", 555 | "execution_count": 47, 556 | "metadata": {}, 557 | "outputs": [ 558 | { 559 | "name": "stdout", 560 | "output_type": "stream", 561 | "text": [ 562 | "E_inv = \n", 563 | "[[ 7. -21. 35. -35. 21.\n", 564 | " -7. 1. ]\n", 565 | " [ -11.15 43.95 -79.08333333 82. -50.25\n", 566 | " 16.98333333 -2.45 ]\n", 567 | " [ 14.17777778 -65.48333333 129.66666667 -141.38888889 89.33333333\n", 568 | " -30.81666667 4.51111111]\n", 569 | " [ -13.875 71. -152.375 176. -115.625\n", 570 | " 41. -6.125 ]\n", 571 | " [ 9.83333333 -54. 123.5 -150.66666667 103.5\n", 572 | " -38. 5.83333333]\n", 573 | " [ -4.5 26. -62.5 80. -57.5\n", 574 | " 22. -3.5 ]\n", 575 | " [ 1. -6. 15. -20. 15.\n", 576 | " -6. 1. 
]]\n" 577 | ] 578 | } 579 | ], 580 | "source": [ 581 | "E_inv = inv(E)\n", 582 | "print(\"E_inv = \")\n", 583 | "print(E_inv)" 584 | ] 585 | }, 586 | { 587 | "cell_type": "markdown", 588 | "metadata": {}, 589 | "source": [ 590 | "Except for the third row, it looks like each row $i$ can be turned into integers by multiplying it by the descending factorial sequence $(7-i)!$, i.e. $7!, 6!, \\dots, 1!$. " 591 | ] 592 | }, 593 | { 594 | "cell_type": "code", 595 | "execution_count": 52, 596 | "metadata": {}, 597 | "outputs": [ 598 | { 599 | "name": "stdout", 600 | "output_type": "stream", 601 | "text": [ 602 | "[ 35280. -105840. 176400.00000001 -176400.00000002\n", 603 | " 105840.00000002 -35280.00000001 5040. ] / 5040\n", 604 | "[ -8028. 31644. -56940. 59040.00000001\n", 605 | " -36180.00000001 12228. -1764. ] / 720\n", 606 | "[ 1701.33333333 -7858. 15560. -16966.66666667\n", 607 | " 10720. -3698. 541.33333333] / 120\n", 608 | "[ -333. 1704. -3657. 4224. -2775. 984. -147.] / 24\n", 609 | "[ 59. -324. 741. -904. 621. -228. 35.] / 6\n", 610 | "[ -9. 52. -125. 160. -115. 44. -7.] / 2\n", 611 | "[ 1. -6. 15. -20. 15. -6. 1.] / 1\n" 612 | ] 613 | } 614 | ], 615 | "source": [ 616 | "for i in range(7):\n", 617 | "    print(\"{:} / {:}\".format((E_inv[i] * fac(7-i) * 1), fac(7-i)* 1))" 618 | ] 619 | }, 620 | { 621 | "cell_type": "markdown", 622 | "metadata": {}, 623 | "source": [ 624 | "For the third row ($i = 2$), an additional factor of 3 is required, i.e. a denominator of 360 instead of $5! = 120$. This is peculiar, but it does seem to work." 625 | ] 626 | }, 627 | { 628 | "cell_type": "code", 629 | "execution_count": 55, 630 | "metadata": {}, 631 | "outputs": [ 632 | { 633 | "name": "stdout", 634 | "output_type": "stream", 635 | "text": [ 636 | "[ 35280. -105840. 176400.00000001 -176400.00000002\n", 637 | " 105840.00000002 -35280.00000001 5040. ] / 5040\n", 638 | "[ -8028. 31644. -56940. 59040.00000001\n", 639 | " -36180.00000001 12228. -1764. ] / 720\n", 640 | "[ 5104. -23574. 46680.00000001 -50900.00000001\n", 641 | " 32160.00000001 -11094. 1624. ] / 360\n", 642 | "[ -333. 
1704. -3657. 4224. -2775. 984. -147.] / 24\n", 643 | "[ 59. -324. 741. -904. 621. -228. 35.] / 6\n", 644 | "[ -9. 52. -125. 160. -115. 44. -7.] / 2\n", 645 | "[ 1. -6. 15. -20. 15. -6. 1.] / 1\n" 646 | ] 647 | } 648 | ], 649 | "source": [ 650 | "for i in range(7):\n", 651 | "    print(\"{:} / {:}\".format((E_inv[i] * fac(7-i) * (1 + 2*(i==2))), fac(7-i)* (1 + 2*(i==2))))" 652 | ] 653 | }, 654 | { 655 | "cell_type": "markdown", 656 | "metadata": {}, 657 | "source": [ 658 | "As a sanity check, each row should sum to zero, except for the first row, which should sum to one. Then, if the sequence $y$ is constant, the zeroth-order generalised coordinate equals that constant and all the derivatives are zero." 659 | ] 660 | }, 661 | { 662 | "cell_type": "code", 663 | "execution_count": 50, 664 | "metadata": {}, 665 | "outputs": [ 666 | { 667 | "name": "stdout", 668 | "output_type": "stream", 669 | "text": [ 670 | "[ 1.00000000e+00 -2.46469511e-13 4.44089210e-15 -3.55271368e-14\n", 671 | " 5.68434189e-14 -1.90958360e-14 4.88498131e-15]\n", 672 | "[1 0 0 0 0 0 0]\n" 673 | ] 674 | } 675 | ], 676 | "source": [ 677 | "from numpy import sum\n", 678 | "print(sum(E_inv, axis=1))\n", 679 | "print(sum(E_inv, axis=1).round().astype(int))  # round first: astype(int) alone truncates 0.999... to 0" 680 | ] 681 | }, 682 | { 683 | "cell_type": "markdown", 684 | "metadata": {}, 685 | "source": [ 686 | "### Dynamic Expectation Maximisation ###\n", 687 | "Above, we showed that under the Laplace approximation, the optimal precision of each partition can be found by calculating the negative of the second derivative (Hessian) of the internal energy and evaluating it at the mode of that partition. To find the modes, we use an optimisation scheme. As with conventional variational schemes, we can update the modes of our three partitions in three distinct steps. 
However, the step dealing with the state (**D**-step) must integrate its conditional mode $\\tilde \\mu := \\mu^u(t)$ over time to accumulate the quantities necessary for updating the parameters (**E**-step) and hyperparameters (**M**-step). We now consider optimising the modes or conditional means in each of these steps." 688 | ] 689 | }, 690 | { 691 | "cell_type": "markdown", 692 | "metadata": {}, 693 | "source": [ 694 | "#### The D-step ####\n", 695 | "\n", 696 | "$\\require{color}$In static systems, the mode of the conditional density maximises variational energy, such that $V(t)_u = 0$; this is the fixed point of the gradient ascent scheme $\\dot{\\tilde \\mu} = V(t)_u$. In dynamic systems, we also require the path of the mode to be the mode of the path: $\\dot{\\tilde \\mu} = D \\tilde \\mu$. These two conditions can be satisfied by:\n", 697 | "\n", 698 | "\\begin{align*}\n", 699 | "\\dot{\\tilde \\mu} - D \\tilde \\mu = V(t)_u \\\\\n", 700 | "\\end{align*}\n", 701 | "\n", 702 | "Here $\\dot{\\tilde \\mu} - D \\tilde \\mu $ can be regarded as motion in a frame of reference that moves along the trajectory encoded by the generalised coordinates. The stationary solution in this moving frame of reference maximises variational action. \n", 703 | "\n", 704 | "So far, all our discussion has assumed continuous time. In practice, however, the D-step is executed at discrete intervals, so we linearise the trajectory following Ozaki (1992):\n", 705 | "\n", 706 | "\\begin{align*}\n", 707 | "\\Delta \\tilde \\mu &= J(t)^{-1} \\left( \\exp(J(t) \\Delta t) - I \\right) \\dot{\\tilde \\mu} \\\\\n", 708 | "J(t) &:= \\frac{\\partial \\dot{\\tilde \\mu}}{\\partial u} \\\\\n", 709 | "\\end{align*}\n", 710 | "\n", 711 | "Note that this linearisation simply uses the Taylor series expansion around $\\tilde \\mu(t)$ to get $\\tilde \\mu(t+\\Delta t)$. 
For this linearised system, we have $\\dot {\\tilde \\mu} = J(t) \\tilde \\mu$, $\\ddot {\\tilde \\mu} = J(t) \\dot {\\tilde \\mu} = J(t)^2 \\tilde \\mu$, and so on, so that:\n", 712 | "\n", 713 | "\\begin{align*}\n", 714 | "\\exp(x) &= 1 + \\frac{x}{1!} + \\frac{x^2}{2!} + \\frac{x^3}{3!} + \\cdots \\\\\n", 715 | "\\exp(J(t) \\Delta t) - I &= J(t) \\Delta t + \\frac{J(t)^2 \\Delta t^2}{2!} + \\frac{J(t)^3 \\Delta t^3}{3!} + \\cdots \\\\\n", 716 | "J(t)^{-1} \\left(\\exp(J(t) \\Delta t) - I \\right) &= \\Delta t + \\frac{J(t) \\Delta t^2}{2!} + \\frac{J(t)^2 \\Delta t^3}{3!} + \\cdots \\\\\n", 717 | "\\tilde \\mu(t) + J(t)^{-1} \\left(\\exp(J(t) \\Delta t) - I \\right) \\dot{\\tilde \\mu} &= \\tilde \n", 718 | "\\mu(t) + (\\Delta t + \\frac{J(t) \\Delta t^2}{2!} + \\frac{J(t)^2 \\Delta t^3}{3!} + \\cdots) \\dot{\\tilde \\mu} \\\\\n", 719 | "&= \\tilde \\mu(t) + (\\Delta t + \\frac{J(t) \\Delta t^2}{2!} + \\frac{J(t)^2 \\Delta t^3}{3!} + \\cdots) J(t) \\tilde \\mu \\\\\n", 720 | "&= \\tilde \\mu(t) + \\frac{J(t) \\tilde \\mu \\Delta t}{1!} + \\frac{J(t)^2 \\tilde \\mu \\Delta t^2}{2!} + \\frac{J(t)^3 \\tilde \\mu \\Delta t^3}{3!} + \\cdots \\\\\n", 721 | "&= \\tilde \\mu(t) + \\frac{\\dot {\\tilde \\mu} \\Delta t}{1!} + \\frac{\\ddot {\\tilde \\mu} \\Delta t^2}{2!} + \\frac{\\dddot {\\tilde \\mu} \\Delta t^3}{3!} + \\cdots \\\\\n", 722 | "&= \\tilde \\mu(t + \\Delta t) \\\\\n", 723 | "\\end{align*}" 724 | ] 725 | }, 726 | { 727 | "cell_type": "markdown", 728 | "metadata": {}, 729 | "source": [ 730 | "Thus, the D-step becomes:\n", 731 | "\n", 732 | "\\begin{align*}\n", 733 | "\\begin{bmatrix}\n", 734 | "\\dot {\\tilde y} \\\\\n", 735 | "\\dot {\\tilde \\mu} \\\\\n", 736 | "\\dot {\\tilde \\eta} \\\\\n", 737 | "\\end{bmatrix}\n", 738 | "&=\n", 739 | "\\begin{bmatrix}\n", 740 | "D \\tilde y \\\\\n", 741 | "V(t)_u + D \\tilde \\mu \\\\\n", 742 | "D \\tilde \\eta \\\\\n", 743 | "\\end{bmatrix} \\\\\n", 744 | "J(t) &= \n", 745 | "\\begin{bmatrix}\n", 746 | "D & 0& 0 \\\\\n", 747 | "V(t)_{uy} & V(t)_{uu} + 
D & V(t)_{u\\eta}\\\\\n", 748 | "0 & 0 & D \\\\\n", 749 | "\\end{bmatrix} \\\\\n", 750 | "U(t) &= - \\frac{N}{2} \\ln 2\\pi + \\frac{1}{2} \\ln |\\tilde\\Pi| - \\frac{1}{2}{{\\tilde\\varepsilon}^T \\tilde\\Pi \\tilde\\varepsilon} + \\ln p(\\tilde \\nu) \\\\\n", 751 | "V(t) &= U(t) + W(t)^\\theta + W(t)^\\lambda - \\left< \\ln Z^\\theta \\right>_{q(\\theta)} - \\left< \\ln Z^\\lambda \\right>_{q(\\lambda)}\\\\\n", 752 | "V(t)_u &= U(t)_u + W(t)_u^\\theta \\\\\n", 753 | "V(t)_{uu} &= U(t)_{uu} + W(t)_{uu}^\\theta \\\\\n", 754 | "V(t)_{uy} &= U(t)_{uy}\\\\\n", 755 | "V(t)_{u\\eta} &= U(t)_{u\\eta} \\\\\n", 756 | "U(t)_u &= - \\tilde \\varepsilon_u^T \\tilde \\Pi \\tilde \\varepsilon \\\\\n", 757 | "U(t)_{uu} &= - \\tilde \\varepsilon_u^T \\tilde \\Pi \\tilde \\varepsilon_u \\\\\n", 758 | "U(t)_{uy} &= -\\tilde \\varepsilon_u^T \\tilde \\Pi \\tilde \\varepsilon_y \\\\\n", 759 | "U(t)_{u\\eta} &= -\\tilde \\varepsilon_u^T \\tilde \\Pi \\tilde \\varepsilon_\\eta \\\\\n", 760 | "U(t)_{\\theta \\theta} &= - \\tilde \\varepsilon_\\theta^T \\tilde \\Pi \\tilde \\varepsilon_\\theta \\\\\n", 761 | "U(t)_{\\lambda \\lambda} &= - \\tilde \\varepsilon_\\lambda^T \\tilde \\Pi \\tilde \\varepsilon_\\lambda \\\\\n", 762 | "W(t)^u &= \\frac{1}{2} \\text{tr}(\\Sigma^u U(t)_{uu}) = -\\frac{1}{2} \\text{tr}(\\Sigma^u \\tilde \\varepsilon_u^T \\tilde \\Pi \\tilde \\varepsilon_u) \\\\\n", 763 | "W(t)^\\theta &= \\frac{1}{2} \\text{tr}(\\Sigma^\\theta U(t)_{\\theta\\theta}) = -\\frac{1}{2} \\text{tr}(\\Sigma^\\theta \\tilde \\varepsilon_\\theta^T \\tilde \\Pi \\tilde \\varepsilon_\\theta) \\\\\n", 764 | "W(t)^\\lambda &= \\frac{1}{2} \\text{tr}(\\Sigma^\\lambda U(t)_{\\lambda\\lambda}) = -\\frac{1}{2} \\text{tr}(\\Sigma^\\lambda \\tilde \\varepsilon_\\lambda^T \\tilde \\Pi \\tilde \\varepsilon_\\lambda) \\\\\n", 765 | "W(t)_{u_i}^\\theta &= -\\frac{1}{2} \\text{tr} \\left( \\Sigma^\\theta \\tilde \\varepsilon_{\\theta u_i}^T \\tilde \\Pi \\tilde \\varepsilon_\\theta \\right) \\\\\n", 766 | 
"W(t)_{uu_{ij}}^\\theta &= -\\frac{1}{2} \\text{tr} \\left( \\Sigma^\\theta \\tilde \\varepsilon_{\\theta u_i}^T \\tilde \\Pi \\tilde \\varepsilon_{\\theta u_j} \\right) \\\\\n", 767 | "\\tilde\\varepsilon &= \n", 768 | " \\begin{bmatrix}\n", 769 | " \\tilde\\varepsilon^\\nu = \\tilde y - \\tilde g \\\\\n", 770 | " \\tilde\\varepsilon^x = D\\tilde x - \\tilde f\n", 771 | " \\end{bmatrix} \\\\\n", 772 | "\\tilde\\varepsilon_u &= \n", 773 | " \\begin{bmatrix}\n", 774 | " \\tilde\\varepsilon_\\nu^\\nu & \\tilde\\varepsilon_x^\\nu \\\\\n", 775 | " \\tilde\\varepsilon_\\nu^x & \\tilde\\varepsilon_x^x \\\\\n", 776 | " \\end{bmatrix} =\n", 777 | " \\begin{bmatrix}\n", 778 | " g_\\nu & g_x \\\\\n", 779 | " f_\\nu & f_x \\\\\n", 780 | " \\end{bmatrix} \\\\\n", 781 | "\\tilde \\varepsilon_y &=\n", 782 | "\\begin{bmatrix}\n", 783 | "\\tilde \\varepsilon_y^\\nu = I \\\\\n", 784 | "\\tilde \\varepsilon_y^x = 0 \\\\\n", 785 | "\\end{bmatrix} \\\\\n", 786 | "\\tilde \\varepsilon_\\eta &=\n", 787 | "\\begin{bmatrix}\n", 788 | "\\tilde \\varepsilon_\\eta^\\nu = I \\\\\n", 789 | "\\tilde \\varepsilon_\\eta^x = 0 \\\\\n", 790 | "\\end{bmatrix} \\\\\n", 791 | "\\tilde\\varepsilon_{u \\theta} &= \\tilde\\varepsilon_{\\theta u}^T =\n", 792 | " \\begin{bmatrix}\n", 793 | " g_{\\nu \\theta} & g_{x \\theta} \\\\\n", 794 | " f_{\\nu \\theta} & f_{x \\theta} \\\\\n", 795 | " \\end{bmatrix} \\\\\n", 796 | "\\tilde\\varepsilon_\\theta &= \\tilde\\varepsilon_{u \\theta} \\mu^u \\\\\n", 797 | "\\tilde\\varepsilon_{u \\lambda} &= \n", 798 | " \\begin{bmatrix}\n", 799 | " g_{\\nu \\lambda} & g_{x \\lambda} \\\\\n", 800 | " f_{\\nu \\lambda} & f_{x \\lambda} \\\\\n", 801 | " \\end{bmatrix} = \\mathbf{0} \\\\\n", 802 | "\\tilde\\varepsilon_\\lambda &= \\tilde\\varepsilon_{u \\lambda} \\mu^u = \\mathbf{0}\\\\\n", 803 | "\\Delta \\tilde \\mu &= \\left( \\exp(\\Delta t J(t)) - I \\right) J(t)^{-1} (V(t)_{u} + D \\tilde \\mu) \\\\\n", 804 | "\\Pi^u &= -U(t)_{uu}|_\\mu = \\{\\tilde \\varepsilon_u^T 
\\tilde \\Pi \\tilde \\varepsilon_u\\}|_\\mu \\\\\n", 805 | "\\end{align*}\n", 806 | "\n", 807 | "The mean-field term $W(t)^\\lambda$ does not contribute to the D-step because it is not a function of the states: the hyperparameters only control the amplitude and smoothness of the random fluctuations, and $g$ and $f$ do not depend on them. This means that uncertainty about the hyperparameters does not affect the update for the states. " 808 | ] 809 | }, 810 | { 811 | "cell_type": "markdown", 812 | "metadata": {}, 813 | "source": [ 814 | "#### The E- and M-steps ####\n", 815 | "\n", 816 | "Exactly the same update procedure can be used for the **E**-step (parameter update) and **M**-step (hyperparameter update). However, in this instance there are no generalised coordinates to consider, because the parameters and hyperparameters are assumed to be constant over the period of inference. Furthermore, we can set the interval between updates to be arbitrarily long, because the parameters are updated after the time series has been integrated. If $\\Delta t$ is sufficiently large ($\\Delta t \\to \\infty$), the matrix exponential in the Ozaki linearisation vanishes, because the Jacobian (the curvature of the variational action) is negative definite, giving:\n", 817 | "\n", 818 | "\\begin{align*}\n", 819 | "\\Delta \\mu^\\theta &= -J(\\theta)^{-1} \\dot\\mu^\\theta \\\\\n", 820 | "\\dot\\mu^\\theta & = \\bar V_{\\theta}^\\theta \\\\\n", 821 | "J(\\theta) &= \\bar V_{\\theta \\theta}^\\theta \\\\\n", 822 | "\\Delta \\mu^\\lambda &= -J(\\lambda)^{-1} \\dot\\mu^\\lambda \\\\\n", 823 | "\\dot\\mu^\\lambda & = \\bar V_{\\lambda}^\\lambda \\\\\n", 824 | "J(\\lambda) &= \\bar V_{\\lambda \\lambda}^\\lambda \\\\\n", 825 | "\\end{align*}\n", 826 | "\n", 827 | "This is a conventional Gauss-Newton update scheme. In this sense, the D-step can be regarded as a generalisation of classical ascent schemes to generalised coordinates that cover dynamic systems. 
For our model, the requisite gradients and curvatures of variational action for the E-step are:\n", 828 | "\n", 829 | "\\begin{align*}\n", 830 | "U(t) &= - \\frac{N}{2} \\ln 2\\pi + \\frac{1}{2} \\ln |\\tilde\\Pi| - \\frac{1}{2}{{\\tilde\\varepsilon}^T \\tilde\\Pi \\tilde\\varepsilon} + \\ln p(\\tilde \\nu) \\\\\n", 831 | "\\bar{V}^\\theta &= \\int U(t) + W(t)^u + W(t)^\\lambda - \\left< \\ln Z^u \\right>_{q(u)} - \\left< \\ln Z^\\lambda \\right>_{q(\\lambda)} \\:dt + U^\\theta \\\\\n", 832 | "&= \\int U(t) + W(t)^u + W(t)^\\lambda - \\left< \\ln Z^u \\right>_{q(u)} - \\left< \\ln Z^\\lambda \\right>_{q(\\lambda)} \\:dt - \\frac{1}{2}{\\varepsilon^\\theta}^T {C^\\theta}^{-1} \\varepsilon^\\theta - \\ln Z^\\theta \\\\\n", 833 | "\\bar V_\\theta^\\theta &= \\int -\\tilde \\varepsilon_\\theta^T \\tilde \\Pi \\tilde \\varepsilon + W(t)_\\theta^u \\:dt - {C^\\theta}^{-1} \\varepsilon^\\theta \\\\\n", 834 | "\\bar V_{\\theta \\theta}^\\theta &= \\int -\\tilde \\varepsilon_\\theta^T \\tilde \\Pi \\tilde \\varepsilon_\\theta + W(t)_{\\theta \\theta}^u \\:dt - {C^\\theta}^{-1} \\\\\n", 835 | "W(t)_{\\theta_i}^u &= -\\frac{1}{2} \\text{tr} \\left( \\Sigma_t^u \\tilde \\varepsilon_{u \\theta_i}^T \\tilde \\Pi \\tilde \\varepsilon_u \\right) \\\\\n", 836 | "W(t)_{\\theta_i \\theta_j}^u &= -\\frac{1}{2} \\text{tr} \\left( \\Sigma_t^u \\tilde \\varepsilon_{u \\theta_i}^T \\tilde \\Pi \\tilde \\varepsilon_{u \\theta_j} \\right) \\\\\n", 837 | "J(\\theta) &= \\bar V_{\\theta \\theta}^\\theta \\\\\n", 838 | "\\Delta \\mu^\\theta &= -J(\\theta)^{-1} \\dot\\mu^\\theta = -{\\bar V_{\\theta \\theta}^\\theta}^{-1} \\bar V_{\\theta}^\\theta \\\\\n", 839 | "\\Pi^\\theta &= -\\bar{U}_{\\theta\\theta}|_{\\mu^\\theta} = - \\int U(t)_{\\theta\\theta}|_{\\mu^\\theta} \\: dt - U^\\theta_{\\theta\\theta} \\\\\n", 840 | "U(t)_{\\theta \\theta} &= - \\tilde \\varepsilon_\\theta^T \\tilde \\Pi \\tilde \\varepsilon_\\theta \\\\\n", 841 | "U^\\theta_{\\theta\\theta} &= -{C^\\theta}^{-1} \\\\\n", 842 | "\\Pi^\\theta &= 
{C^\\theta}^{-1} + \\int \\tilde \\varepsilon_\\theta^T \\tilde \\Pi \\tilde \\varepsilon_\\theta \\: dt \\\\\n", 843 | "\\end{align*}\n", 844 | "\n", 845 | "where $Z^\\theta$ is a normalisation constant. Here the precision matrix is updated from the prior precision ${C^\\theta}^{-1}$. If we assume the parameters are not constant, but vary much more slowly than the states, then this could be estimated iteratively, using the last estimate of the precision as the prior for the next one, provided there are sufficient observations per update. In this case, we would use:\n", 846 | "\n", 847 | "\\begin{align*}\n", 848 | "\\Delta \\Pi^\\theta &= \\int \\tilde \\varepsilon_\\theta^T \\tilde \\Pi \\tilde \\varepsilon_\\theta \\:dt \\\\\n", 849 | "\\end{align*}\n", 850 | "\n", 851 | "The same procedure applies to the hyperparameters, except that in this case the precision matrix $\\tilde\\Pi$ is a function of the hyperparameters $\\lambda$:\n", 852 | "\n", 853 | "\\begin{align*}\n", 854 | "U(t) &= - \\frac{N}{2} \\ln 2\\pi + \\frac{1}{2} \\ln |\\tilde\\Pi| - \\frac{1}{2}{{\\tilde\\varepsilon}^T \\tilde\\Pi \\tilde\\varepsilon} + \\ln p(\\tilde \\nu) \\\\\n", 855 | "\\bar{V}^\\lambda &= \\int U(t) + W(t)^u + W(t)^\\theta - \\left< \\ln Z^u \\right>_{q(u)} - \\left< \\ln Z^\\theta \\right>_{q(\\theta)} \\:dt + U^\\lambda \\\\\n", 856 | "&= \\int U(t) + W(t)^u + W(t)^\\theta - \\left< \\ln Z^u \\right>_{q(u)} - \\left< \\ln Z^\\theta \\right>_{q(\\theta)} \\:dt - \\frac{1}{2}{\\varepsilon^\\lambda}^T {C^\\lambda}^{-1} \\varepsilon^\\lambda - \\ln Z^\\lambda \\\\\n", 857 | "\\bar V_\\lambda^\\lambda &= \\int -\\tilde \\varepsilon_\\lambda^T \\tilde \\Pi \\tilde \\varepsilon + W(t)_\\lambda^u + W(t)_\\lambda^\\theta \\:dt - {C^\\lambda}^{-1} \\varepsilon^\\lambda \\\\\n", 858 | "\\bar V_{\\lambda \\lambda}^\\lambda &= \\int -\\tilde \\varepsilon_\\lambda^T \\tilde \\Pi \\tilde \\varepsilon_\\lambda + W(t)_{\\lambda \\lambda}^u \\:dt - {C^\\lambda}^{-1} \\\\\n", 859 | "W(t)_{\\lambda_i}^u &= -\\frac{1}{2} \\text{tr} \\left( \\Sigma_i^u 
\\tilde \\varepsilon_{u}^{iT} \\tilde\\Pi_{\\lambda_i} \\tilde \\varepsilon_u^i \\right) \\\\\n", 860 | "W(t)_{\\lambda_i}^\\theta &= -\\frac{1}{2} \\text{tr} \\left( \\Sigma_i^\\theta \\tilde \\varepsilon_{\\theta}^{iT} \\tilde\\Pi_{\\lambda_i} \\tilde \\varepsilon_\\theta^i \\right) \\\\\n", 861 | "W(t)_{\\lambda \\lambda}^u &= 0 \\\\\n", 862 | "J(\\lambda) &= \\bar V_{\\lambda \\lambda}^\\lambda \\\\\n", 863 | "\\Delta \\mu^\\lambda &= -J(\\lambda)^{-1} \\dot\\mu^\\lambda = - {\\bar V_{\\lambda \\lambda}^\\lambda}^{-1} \\bar V_{\\lambda}^\\lambda\\\\\n", 864 | "\\Pi^\\lambda &= -\\bar{U}_{\\lambda\\lambda} = - \\int U(t)_{\\lambda\\lambda} \\: dt - U^\\lambda_{\\lambda\\lambda} \\\\\n", 865 | "U(t)_{\\lambda \\lambda} &= - \\tilde \\varepsilon_\\lambda^T \\tilde \\Pi \\tilde \\varepsilon_\\lambda \\\\\n", 866 | "U^\\lambda_{\\lambda\\lambda} &= -{C^\\lambda}^{-1} \\\\\n", 867 | "\\Pi^\\lambda &= {C^\\lambda}^{-1} + \\int \\tilde \\varepsilon_\\lambda^T \\tilde \\Pi \\tilde \\varepsilon_\\lambda \\:dt \\\\\n", 868 | "&\\text{or}\\\\\n", 869 | "\\Delta \\Pi^\\lambda &= \\int \\tilde \\varepsilon_\\lambda^T \\tilde \\Pi \\tilde \\varepsilon_\\lambda \\:dt \\\\\n", 870 | "\\end{align*}\n", 871 | "\n", 872 | "where $W(t)_{\\lambda \\lambda}^u = 0$ by assumption, because we take the system to be linear in the hyperparameters. Although uncertainty about the hyperparameters does not affect the states and parameters, uncertainty about both the states and the parameters does affect the hyperparameter update. \n" 873 | ] 874 | }, 875 | { 876 | "cell_type": "markdown", 877 | "metadata": {}, 878 | "source": [ 879 | "### Hierarchical dynamic model ###\n", 880 | "\n", 881 | "For a hierarchical dynamic model (HDM), we assume each higher level generates causes for the level below, so that the causes $\\nu$ link levels, whereas hidden states $x$ link dynamics over time. Furthermore, it is assumed that the noise processes at each level, $w^{(i)}$ and $z^{(i)}$, are conditionally independent. 
This leads to the following Bayesian directed graph:\n", 882 | "\n", 883 | "\n", 884 | "\n", 885 | "Here $\\vartheta^{(i)} = [\\theta^{(i)}, \\lambda^{(i)}]$ and $u^{(i)} = [\\tilde \\nu^{(i)}, \\tilde x^{(i)}]$ and:\n", 886 | "\n", 887 | "\\begin{align*}\n", 888 | "g_\\nu &= \n", 889 | "\\begin{bmatrix}\n", 890 | "g_\\nu^{(1)} & & \\\\\n", 891 | "0 & \\ddots & \\\\\n", 892 | "& \\ddots & g_\\nu^{(m)} \\\\\n", 893 | "& & 0 \\\\\n", 894 | "\\end{bmatrix}\n", 895 | "&g_x = \n", 896 | "\\begin{bmatrix}\n", 897 | "g_x^{(1)} & & \\\\\n", 898 | "0 & \\ddots & \\\\\n", 899 | "& \\ddots & g_x^{(m)} \\\\\n", 900 | "& & 0 \\\\\n", 901 | "\\end{bmatrix}\\\\\n", 902 | "f_\\nu &= \n", 903 | "\\begin{bmatrix}\n", 904 | "f_\\nu^{(1)} & & \\\\\n", 905 | "& \\ddots & \\\\\n", 906 | "& & f_\\nu^{(m)} \\\\\n", 907 | "\\end{bmatrix}\n", 908 | "&f_x = \n", 909 | "\\begin{bmatrix}\n", 910 | "f_x^{(1)} & & \\\\\n", 911 | "& \\ddots & \\\\\n", 912 | "& & f_x^{(m)} \\\\\n", 913 | "\\end{bmatrix}\\\\\n", 914 | "\\end{align*}\n", 915 | "\n", 916 | "Note that the partial derivatives of $g(x,\\nu)$ have an extra row to accommodate the highest hierarchical level." 917 | ] 918 | } 919 | ], 920 | "metadata": { 921 | "hide_input": false, 922 | "kernelspec": { 923 | "display_name": "Python 3 (ipykernel)", 924 | "language": "python", 925 | "name": "python3" 926 | }, 927 | "language_info": { 928 | "codemirror_mode": { 929 | "name": "ipython", 930 | "version": 3 931 | }, 932 | "file_extension": ".py", 933 | "mimetype": "text/x-python", 934 | "name": "python", 935 | "nbconvert_exporter": "python", 936 | "pygments_lexer": "ipython3", 937 | "version": "3.7.11" 938 | } 939 | }, 940 | "nbformat": 4, 941 | "nbformat_minor": 4 942 | } 943 | --------------------------------------------------------------------------------
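As a cross-check on the empirical factor search in the *Sequences* section of `DEM.ipynb` above, here is a minimal sketch (not part of the original notebook) that inverts the Taylor matrix $E$ exactly in rational arithmetic with Python's standard-library `fractions` module, so rounding artefacts such as `176400.00000001` cannot appear. The helper names `taylor_matrix`, `exact_inverse` and `row_denominator` are hypothetical. Like the numerical cells, it assumes $\Delta t = 1$ and 7 levels, and it avoids `math.lcm` so that it also runs on the Python 3.7 kernel recorded in the notebook metadata.

```python
from fractions import Fraction
from functools import reduce
from math import factorial, gcd


def taylor_matrix(n_levels, step=1):
    """Hypothetical helper: E[i][j] = ((i + 1) * step)**j / j! as exact fractions.

    step=1 matches the numerical cells (samples at t+1, ..., t+N with Delta t = 1);
    step=-1 gives the backward sequence y(t-1), ..., y(t-N) used in the text.
    """
    return [[Fraction(((i + 1) * step) ** j, factorial(j)) for j in range(n_levels)]
            for i in range(n_levels)]


def exact_inverse(a):
    """Gauss-Jordan elimination on the augmented matrix [a | I] in exact arithmetic."""
    n = len(a)
    m = [list(row) + [Fraction(int(i == k)) for k in range(n)] for i, row in enumerate(a)]
    for col in range(n):
        # Any non-zero pivot will do: there is no rounding error to control.
        pivot = next(r for r in range(col, n) if m[r][col] != 0)
        m[col], m[pivot] = m[pivot], m[col]
        p = m[col][col]
        m[col] = [x / p for x in m[col]]
        # Eliminate this column from every other row.
        for r in range(n):
            if r != col and m[r][col] != 0:
                factor = m[r][col]
                m[r] = [x - factor * y for x, y in zip(m[r], m[col])]
    # The right-hand half now holds the inverse.
    return [row[n:] for row in m]


def row_denominator(row):
    """Least common denominator of a list of fractions."""
    return reduce(lambda a, b: a * b // gcd(a, b), (x.denominator for x in row), 1)


E_inv = exact_inverse(taylor_matrix(7))
denominators = [row_denominator(row) for row in E_inv]
for row, d in zip(E_inv, denominators):
    # Scale each row by its least common denominator to show integer numerators.
    print([int(x * d) for x in row], "/", d)
```

Run this way, the least common denominators of the rows of $E^{-1}$ come out as 1, 60, 180, 8, 6, 2 and 1. Each factorial $(7-i)!$ is a multiple of its row's denominator except $5! = 120$ for the third row, whose denominator is 180; multiplying by 3 (giving 360) is the smallest integer factor that fixes this. Calling `taylor_matrix(7, step=-1)` instead gives the backward sequence $y(t-\Delta t), \dots, y(t-N\Delta t)$ used in the text, which flips the sign of the rows for the odd-order derivatives.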