├── .DS_Store ├── .gitignore ├── AIC.Rmd ├── Central Limit Theorem.ipynb ├── Demo_model.ipynb ├── EM algorithm.ipynb ├── GLM.ipynb ├── LKJ prior.ipynb ├── MCMC ├── MCMC-blog.ipynb ├── MCMC-poisson.ipynb ├── MCMC.ipynb └── MCMC_HMC.ipynb ├── PoissonRegression.ipynb ├── ProbabilityDistribution.ipynb ├── PyMC3_practice.ipynb ├── README.md ├── TFP_LKJ.ipynb ├── WishartDistribution.ipynb ├── approx_prob_dist.ipynb └── data.RData /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kidaufo/StatisticalModeling/68007e022846355148102f1fcbe118ffdcb99268/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | fig/ 2 | .ipynb_checkpoints/ 3 | .png 4 | -------------------------------------------------------------------------------- /AIC.Rmd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kidaufo/StatisticalModeling/68007e022846355148102f1fcbe118ffdcb99268/AIC.Rmd -------------------------------------------------------------------------------- /Central Limit Theorem.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## 中心極限定理\n", 8 | "平均$\\mu$、分散$\\sigma^2$の任意の確率分布に従う母集団から抽出した標本において、サンプルサイズ$n$が大きくなるにつれて標本平均の分布は平均$\\mu$、分散$\\sigma^2/n$の正規分布に近づく。  \n", 9 | "\n", 10 | "Below, we repeatedly draw samples from an arbitrary distribution (here a Poisson distribution) and look at the distribution of the sample mean." 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "%matplotlib inline\n", 20 | "\n", 21 | "import numpy as np\n", 22 | "import scipy as sp\n", 23 | "import pandas as pd\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "from matplotlib import cm\n", 26 | "import seaborn 
as sns\n", 27 | "from tqdm import tqdm_notebook\n", 28 | "from scipy.stats import norm, poisson, gamma, uniform\n", 29 | "\n", 30 | "sns.set_style('white')\n", 31 | "sns.set_context('talk')\n", 32 | "\n", 33 | "np.random.seed(123)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 25, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "from scipy import stats" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 26, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "mu = 3.0\n", 52 | "n_sample = 100\n", 53 | "\n", 54 | "def poisson_means(n_sample, mu=mu, n_rep=1000):\n", 55 | "    \"\"\"Draw n_rep samples of size n_sample from Poisson(mu) and plot their means.\n", 56 | "\n    The histogram is drawn as a density and overlaid with the normal\n    approximation N(mu, mu / n_sample) predicted by the central limit theorem\n    (the variance of a Poisson distribution equals its mean).\n    Returns the array of the n_rep sample means.\n    \"\"\"\n", 57 | "    # one row per replication; the row means are the sample means\n", 58 | "    means = poisson.rvs(mu=mu, size=(n_rep, n_sample)).mean(axis=1)\n", 59 | "\n", 60 | "    _, ax = plt.subplots()\n    ax.hist(means, density=True, alpha=0.6)\n", 61 | "    grid = np.linspace(means.min(), means.max(), 200)\n    ax.plot(grid, norm.pdf(grid, loc=mu, scale=np.sqrt(mu / n_sample)))\n", 62 | "    ax.set(title=f'n = {n_sample}', xlabel='sample mean', ylabel='density')\n    return means" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 28, 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "data": { 72 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAY4AAAEMCAYAAADTfFGvAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAEgBJREFUeJzt3X+sX3V9x/HnLW2Vwr0Vq0yWkErpfANOi66LDSLRuDTMiDAxUDLBbZKyiD+WDrTGbR0gEBEYCsTZ6OycSsE6cCBDBgh2Q4zELExL37OKIzRdBdTei1QsbffHOTdev7vtvZ97v/ece+/3+Uiaw/ecz7mf97dJed3P+ZzPOX379+9HkqTxmtN2AZKkmcXgkCQVMTgkSUUMDklSEYNDklTE4JAkFTE4JElFDA5JUhGDQ5JUxOCQJBUxOCRJRQwOSVIRg0OSVMTgkCQVmdt2ARMVEc9TBd9g27VI0gwyAOzLzAn//3/GBgdVaPT19/cvbLsQSZophoaGYJJXm2ZycAz29/cvfPjhh9uuQ5JmjOXLlzM0NDSpKzXOcUiSihgckqQiBockqYjBIUkqYnBIkooYHJKkIgaHJKmIwSFJKjKTFwBK09rg7j1s2dHsE3FOOGqAgUPnNdqneo/BIU2RLTsGWbX+oUb73Lh6BSuWLGq0T/UeL1VJkooYHJKkIgaHJKmIwSFJKmJwSJKKGBySpCIGhySpiMEhSSpicEiSihgckqQiBockqYjBIUkqYnBIkooYHJKkIgaHJKmI7+PoIW28WAh8uZA02xgcPaSNFwuBLxeSZhsvVUmSihgckqQiBockqYjBIUkqYnBIkooYHJKkIgaHJKnIuNZxRMQcYDXwHmAJsBP4KrAuM4fqNsuBq4HlwCCwoT6+Z8TP+R3gWuANwPPAl4EPDv8MSdL0N94FgB8EPgp8HLgXeAVwGXACcGpELK33PwicBRwPXA4MAO8FiIgjgPuAHcB5wG8BVwFHA2/tzteRJE21MYMjIvqoguPTmfnhevc9EfE0sDEiTqQKh13A6Zn5K+DOiHgWuD4irszM7cCFwBHAiZn5dP2zn6jbvi4zv931bydJ6rrxzHH0A18AvtSxf2u9PRZYCdxeh8awTcAh9THq7QPDoVG7GxgC3lJYtySpJWOOODJzEHj/KIfOqLePUl1uyo7znoyIQSDqXcdRBdDINnsj4rERbSRJ09yE7qqKiNcBa4HbgJ/Vu0d77OoQ1TwHwMJxtJEkTXPFwRERrwfuAh4Dzgf6xjhlX709WLt9BzkmSZpGioIjIs4G7gEeB95cz1cMjyL6RzllgGrSnHo7VhtJ0jQ37uCIiDXATcC3gFMycwdAZj4DbAeWdrQ/kioohuc+cpQ2hwDH0DE/IkmavsYVHBHxbuAa4Bbg1MzsHCHcDZwWEfNH7DsT2AvcP6LNmyLixSParAQOpxrFSJJmgPGs4zgS+CTwY+AG4LURv3ET1DaqhXznUK3JuI5qgeAVwPrMfLxu9yngfcC9EXEpsKg+718z88GufBtJ0pQbz4jjVGAB8HJgM9WlqpF/Ts3Mrfx69LAJWEP1aJEPDP+QzHwSeBPwNPBFqpXltwBnd+erSJKaMJ51HJ8HPj+OdpuBFWO0+R7wB+OuTpI07fh0XElSEYNDklTE4JAkFTE4JElFDA5JUhGDQ5JUxOCQJBUxOCRJRQwOSVIRg0OSVMTgkCQVMTgkSUUMDklSEYNDU27xogVtlyCpi8Z8rLo0WYfNn8vg7j1s2TE4duMuOuGoAQYOnddon1IvMDjUiC07Blm1/qFG+9y4egUrlixqtE+pF3ipSpJUxOCQJBUxOCRJRQwOSVIRg0OSVMTg0Kzl+hFpang7rmatttaPACxffETjfUpNMTg0q7WxfgTgkXUrG+9TaoqXqiRJRQwOSVIRg0OSVMTgkCQVMTgkSUUMDklSEYNDklSkeB1HRJwIfAc4JjOfGLF/G3DsKKe8NDO
fqtssB64GlgODwAZgXWbuKS9dktSGouCIiOOAOzrPi4jDgSXAWuCBjtN+XrdZCtwLPAicBRwPXA4MAO+dQO2SpBaMKzgiYi5wAXAlMNro4NVAH/DVzNx6gB+zFtgFnJ6ZvwLujIhngesj4srM3F5cvSSpceOd4zgZ+BhwDfChUY6fCOwGfnCQn7ESuL0OjWGbgEPqY5KkGWC8wfEosCQzLwGeH+X4MuCnwE0R8fOIeCYiNkbEywAiYgFwNJAjT8rMJ6nmOmKiX0CS1KxxXarKzJ1jNFkGvAz4PnA9cBxwKfCNiHgtsLBuN9pjSoeo5jkkSTNAt56O+z5gTmZ+u/68OSK2AP8OvBP42hjn7+tSHZKkKdaVdRyZ+Z0RoTG87z+oJsOX8euRRv8opw/U7SRJM8CkgyMiDouIP42IZR375wDzgacy8xlgO7C0o82RVGHyG3MfkqTpqxsjjl8C1wLrOva/DTgUuL/+fDdwWkTMH9HmTGDviDaSpGlu0nMcmbk3Ii4DromITwL/AvwucAnVuo7766ZXAedQrd+4DngFcAWwPjMfn2wdkqRmdGuO41rgfOCNVMFxEfD3VEEx3GYr1XqNw6nWb6yhGql8oBs1SJKaUTziyMwNVM+Y6tz/WeCzY5y7GVhR2qckafrw6biSpCIGhySpiMEhSSpicEiSihgckqQiBockqYjBIUkqYnBIkooYHJKkIgaHJKmIwSFJKmJwSJKKGBySpCIGhySpiMEhSSpicEiSihgckqQiBockqYjBIUkqYnBIkooYHJKkIgaHJKmIwSFJKmJwSJKKGBySpCIGhySpiMEhSSpicEiSihgckqQiBockqYjBIUkqMrf0hIg4EfgOcExmPjFi/0rgcuCVwE7ghsy8puPc5cDVwHJgENgArMvMPRP9ApKkZhWNOCLiOOAOOgInIk6q928F3g58Efh4RFw0os1S4F5gN3AWcA2wBvi7SdQvSWrYuEYcETEXuAC4EhhtdHAp8N3MPLf+fFdEzAM+EhHXZ+ZzwFpgF3B6Zv4KuDMingWuj4grM3P7ZL+MJGnqjXfEcTLwMapRwodGHoiIFwKnAF/pOGcT8CLgpPrzSuD2OjRGtjmkPiZJmgHGGxyPAksy8xLg+Y5jS4B5QHbs31ZvIyIWAEd3tsnMJ6nmOqKkaElSe8Z1qSozdx7k8MJ6O9ixf6jeDhykzXC7gfHUIUlqXzdux+0b4/i+cbaRJM0A3QiOXfW2v2P/wIjjgwdoM9xu1yj7JUnTUDeC44fAXmBpx/7hz5mZzwDbO9tExJFUYdI5PyJJmqYmHRyZ+Uvgm8DbI2LkJakzqUYSD9ef7wZOi4j5HW32AvdPtg5JUjOKV44fwEeBe4CNEbGB6hbci4G1mfls3eYq4Byq9RvXAa8ArgDWZ+bjXapDkjTFuvKsqsy8j2r0cDxwG/DHwMWZedWINlup1mscTrV+Yw1wLfCBbtQgSWpG8YgjMzdQPWOqc/+twK1jnLsZWFHapyRp+vDpuJKkIgaHJKmIwSFJKmJwSJKKGBySpCIGhySpiMEhSSpicEiSihgckqQiBockqYjBIUkqYnBIkooYHJKkIgaHJKmIwSFJKmJwSJKKGBySpCIGhySpiMEhzSKLFy1ouwT1gOJ3jkuavg6bP5fB3XvYsmOw0X5POGqAgUPnNdqn2mNwSLPMlh2DrFr/UKN9bly9ghVLFjXap9rjpSpJUhGDQ5JUxOCQJBUxOCRJRQwOSVIRg0OSVMTgkCQVMTgkSUUMDklSEYNDklSka48ciYi5wBDwwo5Dv8jMw+s2K4HLgVcCO4EbMvOabtUgSZp63XxWVVCFxruA/x6xfy9ARJwE3AHcDPw1cDLw8Yjoy8yru1iHJGkKdTM4lgH7gE2Z+ewoxy8FvpuZ59af74qIecBHIuL6zHyui7VIkqZIN+c4TgR+OFpoRMQLgVOAr3Qc2gS8CDipi3VIkqZQt0ccz0XEXVSXofYAtwAXAUcD84D
sOGdbvQ3gG12sRZI0Rbo54lgGHAvcCbwFuAw4B7gdWFi36Xy7zFC9HehiHZKkKdTNEcfZwE8z87/qz9+MiJ3AF4CVY5y7r4t1SJKmUNeCIzMfGGX31zo+93d8Hh5p7OpWHZKkqdWV4IiII4G3Afdl5o9GHDq03u6kui13acepw5875z4kzSCLFy1ouwQ1qFsjjn3Ap4FPAGtG7D+bKjDuAb4JvD0irsvM/fXxM6lGGw93qQ5JLThs/lwGd+9hy47Oacypd8JRAwwcOq/xfntZV4IjM5+KiBuB90fEILAZeD3wEarV4dsi4qNUAbIxIjZQ3YJ7MbD2AOs+JM0gW3YMsmr9Q433u3H1ClYsWdR4v72sm5Pjfwk8AfwZsBbYDqwDrgLIzPsi4kzgEuC2+vjFvfjIkTZ+M1u++IhG+5M0e3VzcnwPVUhcdZA2twK3dqvPmaqN38weWTfWjW2SND4+HVeSVMTgkCQVMTgkSUUMDklSEYNDklTE4JAkFTE4JElFDA5JUhGDQ5JUxOCQJBUxOCRJRQwOSVIRg0OSVMTgkCQVMTgkSUUMDklSEYNDklTE4JAkFTE4JElFDA5JUhGDQ9KMtnjRgrZL6Dlz2y6gLYO797Blx2Dj/S5ffETjfUqz2WHz57by7/mEowYYOHReo31OFz0bHFt2DLJq/UON9/vIupWN9ynNdm38e964egUrlixqtM/pwktVkqQiBockqYjBIUkqYnBIkooYHJI0Ab18G3DP3lUlSZPR1m3A0P6twAaHJE1QW7f1t30rsJeqJElFGh9xRMQ5wF8BS4AfA1dm5uebrkOSNDGNjjgi4izgi8DXgTOA+4F/jIh3NFmHJGnimh5xXAHckplr6s9fj4gXA5cBmxquRZI0AY2NOCJiCXAs8JWOQ5uA4yLimKZqkSRNXN/+/fsb6Sgi3gJ8DViWmY+M2P8a4LvAH2bmXQU/bx/Q19/fP6F69u+HfQ1995EOmdPXSt9t9dtm337n2d9vm323+Z3n9PXR1zexc4eGhgD2Z+aEBw5NXqpaWG87b3oeqrcDhT9vHzBnaGio+ZuoJWnmGqD6/+eENRkcY+Vj0RfJTNegSFILmryrale97by2NNBxXJI0jTUZHFlvl3bsX9pxXJI0jTUWHJm5DXgM6FyzcSbwg8x8vKlaJEkT1/Q8waXA5yLiZ8AdwOnAWcCqhuuQJE1QY7fjDouIC4CLgKOBH1E9cuSfGi1CkjRhjQeHJGlm8+m4kqQiBockqYjBIUkqYnBIkooYHJKkIj33vCffQDg1ImIOsBp4D9Xf7U7gq8C6zBw62LkqExH/DLw6MzufwqAJiIhTqN4V9Frg51SvfvhwZj7TamHTWE+NOHwD4ZT6IHAD1aPzzwCuAd4FfLnNomabiHgn8Edt1zFbRMQK4N+A/wXeRrVI+Z3AZ9qsa7rrqXUcEbENeDgzV43YdzPVb2/Ht1fZzBYRfcDTwE2ZeeGI/WcDG4HXZOZ/tlXfbBERvw18D/gF8JwjjsmLiAfq/3xjZu6v910IrAFelZnPtlbcNNYzIw7fQDil+oEvAF/q2L+13h7bbDmz1meAu4F72y5kNoiIlwBvAD41HBoAmXljZh5raBxYL81xHFdvO5/Cu63eBtVDGFUoMweB949y6Ix6+/0Gy5mVIuJ84PeAVwJXt1zObPEqqvcE/bS+8vBW4HmqX4DWZObuNoubznpmxEH330Cog4iI1wFrgdsyc+tY7XVgEbEYuBZ4T2Y+1XY9s8hL6+0G4CngNOBvgfOAT7VT0szQSyOOrr6BUAcWEa+nevrxY8D5LZczo9XzR/8A3JmZnZdZNTnz6+2DI+bm7qv/zq+OiEsz80ct1Tat9dKIwzcQNqCeEL8HeBx4c2Y+3XJJM92FwKuBv4iIuRExl/qXoPrzWL8Q6cCGrzbc2bH/61R/x69qtpyZo5eCwzcQTrGIWAPcBHwLOCUzd7Rc0mzwDuAlwA5gT/3nPKobDvZQ3fKsifl
BvX1Bx/7hkUjv3HJaqGeCwzcQTq2IeDfV2o1bgFMz0xFcd1wA/H7HnzuAJ+r/vr290ma8R4H/4f+/SG54kvxbjVc0Q/TaOo4/AT4H3Miv30D458CqzLy5xdJmtIg4kiqUfwKcS/WPbqRtTup2T0RsAE52Hcfk1ZdWb6K6k2oD1Z1rlwI3ZuaaFkub1nppcpzM3BARL6B6A+H5VG8gPM/QmLRTgQXAy4HNoxw/l2qdhzStZObNEfEc8DdUv0z+hCo4rmy1sGmup0YckqTJ65k5DklSdxgckqQiBockqYjBIUkqYnBIkooYHJKkIgaHJKmIwSFJKmJwSJKK/B+hL4QdzRWdYQAAAABJRU5ErkJggg==\n", 73 | "text/plain": [ 74 | "
" 75 | ] 76 | }, 77 | "metadata": { 78 | "needs_background": "light" 79 | }, 80 | "output_type": "display_data" 81 | }, 82 | { 83 | "data": { 84 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY4AAAEMCAYAAADTfFGvAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAEPBJREFUeJzt3X+sX3V9x/HnLS2Rwr0IXbphQiq1862YCW5drIAaI2mY04lg+DHHjL9KhiKzAwZxs1ERAgISIRobM4ta7RAnhh9DBFZhcRgJWcDVvh1aQ+w6hsz1XqVCabs/zrl6/aZyv+/b2/O9t/f5SJrD93w+53w+h7T3dT+fzznnO7Rnzx4kSerXvEF3QJI0uxgckqQSg0OSVGJwSJJKDA5JUonBIUkqMTgkSSUGhySpxOCQJJUYHJKkEoNDklRicEiSSub3UykihoALgPOAo4EfAFdm5pcm1FkJfAx4GfA4cENmXtNznuXA1cByYBRYB6zJzJ37fCWSpE70O+K4lOYH/o3AG4FvAusj4gyAiDgBuA3YDJwGrAc+HhEXjp8gIpYB9wA7gDOAa4DVwCem5UokSZ0Ymuy16hGxgGYEsT4zz5+wfyNwUGa+OiLuBg7LzBUTyq8EVgG/l5lPR8RngZXAssx8pq3zV8D1wJLM3FrpeEQ8SxN8o5XjJGmOGwF2Z2ZfM05708+Bu4DXAk/27H8GOCIinge8BvhgT/nNwMXACcC/0ITGreOhMaHOp9qyzxX7Pg8YGh4ePrx4nCTNWWNjY7CP69uTBkdm7gYegV+tdSwG3gGcDJwLLAUWANlz6KPtNiLiOzRrI79RJzOfiIhRIKbQ99Hh4eHDH3zwwSkcKklz0/LlyxkbG9unmZrqUOU0mlECwO3AF4Hj28+9HRlrtyPA4b+lzni9kWI/JEkDUh2uPEQzbXU+cCJNeEx2jt3AUB91JEmzQGnEkZlbgC3Afe0U040Tiod7qo+PIrbz65FGb53xetsr/ZAkDc6kI46IODIizomIF/QUPdRuj6FZQF/WUz7+OTPz58DW3joRsZgmTHrXRyRJM1Q/U1XzaEYW5/bsX9luvwvcB5zWLp6PO51mJDG+en0X8KaIOLinzi5gY63bkqRB6eeuqp9GxKeASyLiKZogOInmocDPZmZGxGXA3cCGiFhHcwvuRcAlmflUe6qrgLOBOyLiOuDFwOXA2sx8bJqvS5K0n/S7OP4B4O+Bd9IsiJ8DfIh2FJKZ99KMHl4K3AK8DbgoM68aP0FmbqYZpRxGc2fWauBamleZSJJmiUmfHJ+pIuL/fI5Dkmra5zi2Z+bzp3qOKT9yLs0Gozt2smlb92+lOfaoEUYOWdB5u1IXDA4d0DZtG+WstQ903u6GVStYsXRR5+1KXfD7OCRJJQaHJKnE4JAklRgckqQSg0OSVGJwSJJKDA5JUonBIUkqMTgkSSUGhySpxOCQJJUYHJKkEoNDklRicEiSSgwOSVKJwSFJKjE4JEklBockqcTgkCSVGBySpBKDQ5JUYnBIkkoMDklSicEhSSoxOCRJJQaHJKnE4JAklRgckqQSg0PaD5YsWjjoLkj7zfxBd0A6EB168HxGd+xk07bRTts99qgRRg5Z0GmbmnsMDmk/2bRtlLPWPtBpmxtWrWDF0kWdtqm5x6kqSVKJwSFJKjE4JEklBockqcTgkCSVGBySpBKDQ5JUYnBIkkoMDklSicEhSSoxOCRJJQaHJKnE4JAklRgckqQSg0OSVGJwSJJKDA5JUonBIUkqMTgkSSUGhySpxOCQJJUYHJKkEoNDklRicEiSSgwOSVLJ/H4qRcQ
8YBVwHrAUeBz4OrAmM8faOsuBq4HlwCiwri3fOeE8vw9cC7waeBb4CnDx+DkkSTNfX8EBXAxcBnwcuAd4MfBR4FjglIhY1u7/NnAG8FLgY8AI8D6AiDgCuBfYBvwl8LvAVcDRwBun53IkSfvbpMEREUM0wfGZzLy03X13RDwJbIiI42nCYTvw5sx8BrgjIp4Cro+IKzJzK/Be4Ajg+Mx8sj33T9q6r8zM70z71UmSpl0/axzDwBeBL/Xs39xuXwSsBG5tQ2PczcBBbRnt9lvjodG6CxgD3lDstyRpQCYdcWTmKPD+vRSd2m6/TzPdlD3HPRERo0C0u15CE0AT6+yKiC0T6kiSZrgp3VUVEa8ELgFuAX7W7h7dS9UxmnUOgMP7qCNJmuHKwRERJwJ3AluAdwNDkxyyu90+V73dz1EmSZpBSsEREWcCdwOPAa9v1yvGRxHDezlkhGbRnHY7WR1J0gzXd3BExGrgy8C/Aa/JzG0AmflzYCuwrKf+YpqgGF/7yL3UOQg4hp71EUnSzNVXcETEu4BrgJuAUzKzd4RwF/CmiDh4wr7TgV3Axgl1XhcRR06osxI4jGYUI0maBfp5jmMx8Engx8ANwB9G/MZNUI/SPMh3Ns0zGdfRPCB4ObA2Mx9r630aOB+4JyI+Aixqj/vnzPz2tFyNJGm/62fEcQqwEHghcD/NVNXEP6dk5mZ+PXq4GVhN82qRC8ZPkplPAK8DngTW0zxZfhNw5vRciiSpC/08x/F54PN91LsfWDFJne8BJ/fdO0nSjOPbcSVJJQaHJKnE4JAklRgckqQSg0OSVGJwSJJKDA5JUonBIUkqMTgkSSUGhySpxOCQJJUYHJKkEoNDklRicEiSSgwOSVKJwSFJKjE4JEklBockqcTgkCSVGBySpBKDQ5JUYnBIkkoMDklSicEhSSoxOCRJJQaHJKnE4JAklRgckqQSg0OSVGJwSJJKDA5JUonBIUkqMTgkSSUGhySpxOCQJJUYHJKkEoNDklRicEiSSgwOSVLJ/EF3QHPD6I6dbNo22mmby5cc0Wl70lxhcKgTm7aNctbaBzpt8+E1KzttT5ornKqSJJUYHJKkEoNDklRicEiSSgwOSVKJwSFJKjE4JEklBockqcTgkA4gSxYtHHQXNAf45Lh0ADn04PkDeb3LsUeNMHLIgk7b1OAYHNIBZhCvd9mwagUrli7qtE0NjlNVkqQSg0OSVGJwSJJKDA5JUonBIUkqMTgkSSUGhySppPwcR0QcD3wXOCYzfzJh/0rgY8DLgMeBGzLzmp5jlwNXA8uBUWAdsCYzd071AiRJ3SqNOCLiJcBt9ARORJzQ7t8MnAasBz4eERdOqLMMuAfYAZwBXAOsBj6xD/2XJHWsrxFHRMwHzgWuAPY2OvgI8FBmntN+vjMiFgAfjIjrM/Np4BJgO/DmzHwGuCMingKuj4grMnPrvl6MJGn/63fEcRJwJc0o4W8nFkTE84DXAF/tOeZm4PnACe3nlcCtbWhMrHNQWyZJmgX6DY7vA0sz88PAsz1lS4EFQPbsf7TdRkQsBI7urZOZT9CsdUSl05KkwelrqiozH3+O4sPbbe/rOMfa7chz1BmvN9JPPyRJgzcdb8cdmqR8d591tJ8N4nXbAMuXHNF5m5L2n+kIju3tdrhn/8iE8tHfUme83va97Nc0G8TrtgEeXuMSlnQgmY4HAH8I7AKW9ewf/5yZ+XNga2+diFhMEya96yOSpBlqn4MjM38J3AecFhETp6ROpxlJPNh+vgt4U0Qc3FNnF7BxX/shSerGdH0D4GXA3cCGiFhHcwvuRcAlmflUW+cq4Gya5zeuA14MXA6szczHpqkfkqT9bFreVZWZ99KMHl4K3AK8DbgoM6+aUGczzfMah9E8v7EauBa4YDr6IEnqRnnEkZnraN4x1bv/a8DXJjn2fmBFtU1J0szh23ElSSUGhySpxOCQJJUYHJKkEoNDklRicEiSSgwOSVKJwSFJKjE4JEk
lBockqcTgkCSVGBySpBKDQ5JUYnBIkkoMDklSicEhaZ8tWbRw0F1Qh6brq2MlzWGHHjyf0R072bRttPO2jz1qhJFDFnTe7lxmcEiaFpu2jXLW2gc6b3fDqhWsWLqo83bnMqeqJEklBockqcTgkCSVGBySpBKDQ5JUYnBIkkoMDklSicEhSSoxOCRJJQaHJKnE4JAklRgckqQSg0OSVGJwSJJKDA5JUonBIUkqMTgkSSUGhySpxOCQJJUYHJKkEoNDklRicEiSSgwOSVKJwSFJKjE4JEklBockqcTgkCSVGBySpBKDQ5JUYnBIkkoMDklSicEhSSoxOCTNaksWLRx0F+ac+YPugCTti0MPns/ojp1s2jbaabvHHjXCyCELOm1zpjA4JM16m7aNctbaBzptc8OqFaxYuqjTNmcKp6okSSUGhySpxOCQJJW4xjEAg1jIW77kiE7bk3TgMjgGYBALeQ+vWdlpe9KBbi7fBmxwSNIUDOo2YBj8rcCdB0dEnA38HbAU+DFwRWZ+vut+SNK+GsTsAQz+VuBOF8cj4gxgPfAN4FRgI3BjRLy1y35Ikqau6xHH5cBNmbm6/fyNiDgS+Chwc8d9kSRNQWfBERFLgRcBl/YU3QycERHHZOaWrvozqLlJ726SNNsN7dmzp5OGIuINwO3AcZn58IT9rwAeAv4kM+8snG83MDQ8PDyl/uzZA7s7uvaJDpo3NJC2B9XuINv2mg/8dgfZ9iCved7QEENDUzt2bGwMYE9mTnmposupqsPbbe+v+WPtdqR4vt3AvLGxse6HDZI0e43Q/Pycsi6DY7J8LF1IZnorsSQNQJd3VW1vt71zSyM95ZKkGazL4Mh2u6xn/7KecknSDNZZcGTmo8AWoPeZjdOB/8zMx7rqiyRp6rpeJ/gI8LmI+BlwG/Bm4AzgrI77IUmaos5uxx0XEecCFwJHAz+ieeXIFzrthCRpyjoPDknS7OYXOUmSSgwOSVKJwSFJKjE4JEklBockqcT3Pc1BETEPWAWcR/NNjI8DXwfWZObYcx2r3xQRQ8AFNP8vjwZ+AFyZmV8aaMdmsYj4J+Dlmdn7lglNIiLm07w49nk9Rb/IzMOmqx1HHHPTxcANNK+5PxW4Bng78JVBdmqWuhS4GrgReCPwTWB9+22XKoqIvwDeMuh+zGJBExpvB1414c/rprMRn+OYY9rfkJ8EvpyZ752w/0xgA/CKzPz3QfVvNomIBTSjtfWZef6E/RuBgzLz1YPq22wUES8Avgf8AnjaEUddRPw58AVgODOf2l/tOFU19wwDXwT+sWf/5nb7IsDg6M8u4LU0QTzRM4Bf9Vj3WeAu4JfASQPuy2x1PPDD/RkaYHDMOZk5Crx/L0Wnttv/6LA7s1pm7gYegV+N5BYD7wBOBs4dYNdmnYh4N/BHwMtopv40NccBT0fEnTThuxO4CbhwOtcvXeMQEfFK4BLglszcPFl97dVpwH8DVwB30Izq1IeIWAJcC5yXmT8ddH9mueNoZg3uAN4AfBQ4G7i1/eVmWjjimOMi4kSaNxVvAd494O7MZg/RTFu9nOYf6+0R8frMdBHxObQ/zP4BuCMzvzro/hwAzgT+NzMfaT/fFxGP0/wiczLNzRv7zOCYw9oF8XU0t5Cekpm9c/XqU2ZuoQnf+yJilOYuq1cB3x5ox2a+99KE7R+0t5JC+zXT7eddhm//MvNbe9l9e7s9DoND+yIiVtPMJW8E3pKZfnVvUUQcCfwpcE9m/teEoofa7Qu679Ws81bgd4BteynbSbNmtK7LDs1WEbEY+DPg3sz80YSiQ9rttE0DusYxB0XEu2ie3biJZqRhaEzNPJqRRe9C+Mp2+wiazLnAH/f8uQ34Sfvftw6ua7PObuAzwPt69p9Jcwfgv05XQz7HMce0v5VsAf4HOAd4tqfKoy5Q9i8ibgDeA3wIeJDmTpZLgS9k5nsG2bfZKiLWASf5HEddRHyS5i0GlwH3AycCHwQ+nZl/PV3tOFU
195wCLAReSPMXq9c5eEdQxQeAx4B3AR+m+U35Q3hLqQbjb2j+Dr6T5k7JrcAa4KrpbMQRhySpxDUOSVKJwSFJKjE4JEklBockqcTgkCSVGBySpBKDQ5JUYnBIkkoMDklSyf8DP90Wa0mXBoAAAAAASUVORK5CYII=\n", 85 | "text/plain": [ 86 | "
" 87 | ] 88 | }, 89 | "metadata": { 90 | "needs_background": "light" 91 | }, 92 | "output_type": "display_data" 93 | }, 94 | { 95 | "data": { 96 | "image/png": "\n", 97 | "text/plain": [ 98 | "
" 99 | ] 100 | }, 101 | "metadata": { 102 | "needs_background": "light" 103 | }, 104 | "output_type": "display_data" 105 | }, 106 | { 107 | "data": { 108 | "image/png": "\n", 109 | "text/plain": [ 110 | "
" 111 | ] 112 | }, 113 | "metadata": { 114 | "needs_background": "light" 115 | }, 116 | "output_type": "display_data" 117 | } 118 | ], 119 | "source": [ 120 | "# increasing sample sizes: the spread of the sample mean shrinks like 1/sqrt(n)\nfor sample_size in (2, 10, 100, 1000):\n", 121 | "    poisson_means(sample_size)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [] 130 | } 131 | ], 132 | "metadata": { 133 | "kernelspec": { 134 | "display_name": "Python 3", 135 | "language": "python", 136 | "name": "python3" 137 | }, 138 | "language_info": { 139 | "codemirror_mode": { 140 | "name": "ipython", 141 | "version": 3 142 | }, 143 | "file_extension": ".py", 144 | "mimetype": "text/x-python", 145 | "name": "python", 146 | "nbconvert_exporter": "python", 147 | "pygments_lexer": "ipython3", 148 | "version": "3.7.1" 149 | } 150 | }, 151 | "nbformat": 4, 152 | "nbformat_minor": 2 153 | } 154 | -------------------------------------------------------------------------------- /MCMC/MCMC_HMC.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%matplotlib inline\n", 10 | "import numpy as np\n", 11 | "import scipy as sp\n", 12 | "import pandas as pd\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "from matplotlib import cm\n", 15 | "import seaborn as sns\n", 16 | "from tqdm import tqdm_notebook\n", 17 | "from scipy.stats import norm, poisson, gamma, uniform\n", 18 | "sns.set(style=\"white\")" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "正規分布の母数（平均$\\mu$と標準偏差$\\sigma$）を推定する。ただし、事前分布はどちらも(0, 1000)の一様分布とする。" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 6, 31 | "metadata": {}, 32 | "outputs": [ 33 | { 34 | "data": { 35 | "text/plain": [ 36 | "" 37 | ] 38 | }, 39 | "execution_count": 6, 40 | "metadata": {}, 41 | "output_type": "execute_result" 42 | }, 43 | { 44 | 
"data": { 45 | "image/png": "\n", 46 | "text/plain": [ 47 | "
" 48 | ] 49 | }, 50 | "metadata": { 51 | "needs_background": "light" 52 | }, 53 | "output_type": "display_data" 54 | } 55 | ], 56 | "source": [ 57 | "mu_true = 170\n", 58 | "sigma_true = 7\n", 59 | "n_sample = 1000\n", 60 | "\n", 61 | "prior_min = 0\n", 62 | "prior_max = 1000\n", 63 | "\n", 64 | "np.random.seed(7)\n", 65 | "data = norm.rvs(mu_true, sigma_true, n_sample)\n", 66 | "sns.distplot(data)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 7, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "prior = uniform(loc=prior_min, scale=prior_max - prior_min)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": {}, 81 | "source": [ 82 | "ハミルトニアンは位置エネルギーと運動エネルギーの和である。理想状態では物体はハミルトニアンが一定になるように運動する。ハミルトニアンが保存されることを用いて、物体の運動の軌跡を予測することができる。  \n", 83 | "\n", 84 | "ハミルトニアンを偏微分することでハミルトンの運動方程式 $\\frac{d\\theta}{dt} = \\frac{\\partial H}{\\partial p}$, $\\frac{dp}{dt} = -\\frac{\\partial H}{\\partial \\theta}$ を得る。これをleap frog法により数値的に解くことで、軌跡が予測できる。leap frog法は運動量と位置を交互に更新していくことで、ハミルトニアンを高い精度で保存する方法である。  \n", 85 | "\n", 86 | "質量m=1、重力加速度g=1とすると、位置エネルギーはhで表される。HMC法では、hはパラメータ$\\theta$の関数であり、マイナスをとった対数事後分布を表す。つまり、事後確率が高いところは高さが低いところに相当し、物体が通ることが多くなる。また、運動エネルギーは$\\frac{1}{2m}p^2$である。pは標準正規分布からのサンプルである。  \n", 87 | "\n", 88 | "事後分布と標準正規分布の同時分布のカーネルの対数にマイナスをとったものがハミルトニアンになる。このとき、新しい候補の受容確率は$\\min(1, \\exp(H_{\\mathrm{cur}} - H_{\\mathrm{prop}}))$となる。ハミルトニアンの差はleap frog法の近似誤差のみによるもので0に近い値をとるので、受容確率は1に近くなる。" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 46, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "# negative log posterior kernel U(mu, sigma) for data x (flat prior, constants dropped)\n", 98 | "def log_posterior_kernel(x, mu, sigma):\n", 99 | "    return len(x) * np.log(sigma ** 2) / 2 + np.square(x - mu).sum() / (2 * sigma ** 2)\n", 100 | "\n", 101 | "# gradient of log_posterior_kernel with respect to (mu, sigma)\n", 102 | "def log_posterior_kernel_diff(x, mu, sigma):\n", 103 | "    dmu = -(x - mu).sum() / sigma ** 2\n", 104 | "    # d/dsigma [n log(sigma) + S / (2 sigma^2)] = n / sigma - S / sigma^3  (S: sum of squared residuals)\n    dsigma = len(x) / sigma - np.square(x - mu).sum() / sigma ** 3\n", 105 | "    return np.array([dmu, dsigma])" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 82, 111 | "metadata": {}, 112 | "outputs": [ 113 | { 114 | "data": { 115 | "text/latex": [ 
116 | "$\\displaystyle \\frac{n \\log{\\left(sig^{2} \\right)}}{2}$" 117 | ], 118 | "text/plain": [ 119 | "n*log(sig**2)/2" 120 | ] 121 | }, 122 | "execution_count": 82, 123 | "metadata": {}, 124 | "output_type": "execute_result" 125 | } 126 | ], 127 | "source": [ 128 | "import sympy\n", 129 | "from sympy import symbols\n", 130 | "\n", 131 | "# bind each Python name to the symbol of the same letter (previously n -> x, x -> y);\n# the _sym suffix keeps the notebook's numeric names untouched\nn_sym, x_sym, mu_sym, sig_sym = symbols('n x mu sig')\n", 132 | "n_sym * sympy.log(sig_sym ** 2) / 2" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 50, 138 | "metadata": {}, 139 | "outputs": [ 140 | { 141 | "data": { 142 | "text/plain": [ 143 | "" 144 | ] 145 | }, 146 | "execution_count": 50, 147 | "metadata": {}, 148 | "output_type": "execute_result" 149 | }, 150 | { 151 | "data": { 152 | "image/png": "\n", 153 | "text/plain": [ 154 | "
" 155 | ] 156 | }, 157 | "metadata": { 158 | "needs_background": "light" 159 | }, 160 | "output_type": "display_data" 161 | } 162 | ], 163 | "source": [ 164 | "# contour plot of the negative log posterior kernel over (mu, sigma)\n", 165 | "# evaluated point by point, since log_posterior_kernel takes scalar mu and sigma\n", 166 | "x1 = np.linspace(150, 200)\n", 167 | "x2 = np.linspace(5, 10)\n", 168 | "X1, X2 = np.meshgrid(x1, x2)\n", 169 | "Z = np.array([log_posterior_kernel(data, a, b) for a, b in zip(X1.flatten(), X2.flatten())])\n", 170 | "Z = Z.reshape(X1.shape)\n", 171 | "\n", 172 | "plt.contour(X1, X2, Z, cmap=cm.Blues)\n", 173 | "plt.xlabel(\"mu\")\n", 174 | "plt.ylabel(\"sigma\")\n", 175 | "plt.colorbar()" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 51, 181 | "metadata": {}, 182 | "outputs": [ 183 | { 184 | "data": { 185 | "text/plain": [ 186 | "array([-0.37088578, 0.16181131])" 187 | ] 188 | }, 189 | "execution_count": 51, 190 | "metadata": {}, 191 | "output_type": "execute_result" 192 | } 193 | ], 194 | "source": [ 195 | "np.random.randn(2)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "# HMC settings: iterations, leapfrog steps per proposal, step size\n", 205 | "T = 1000\n", 206 | "L = 100\n", 207 | "eps = 0.01\n", 208 | "# initial parameter [mu, sigma]\n", 209 | "theta_init = np.array([150.0, 5.0])" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 52, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "# one leapfrog step by hand\n", 219 | "p = np.random.randn(2)\n", 220 | "theta = np.array([150.0, 5.0])\n", 221 | "p -= eps * log_posterior_kernel_diff(data, theta[0], theta[1]) / 2\n", 222 | "theta += eps * p\n", 223 | "p -= eps * log_posterior_kernel_diff(data, theta[0], theta[1]) / 2" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 58, 229 | "metadata": {}, 230 | "outputs": [ 231 | { 232 | "data": { 233 | "text/plain": [ 234 | "14768.320956095447" 235 | ] 236 | }, 237 | "execution_count": 58, 238 | "metadata": {}, 239 | "output_type": "execute_result" 240 | } 241 | ], 242 | "source": [ 243 | "log_posterior_kernel(data, theta[0], theta[1]) + (p**2).sum() / 2" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 60, 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "def hamiltonian(theta, p):\n", 253 | "    \"\"\"Total energy H = U(theta) + K(p).\n", 254 | "\n", 255 | "    U is the potential energy (log_posterior_kernel, the negative log posterior\n", 256 | "    up to a constant) and K = sum(p**2) / 2 is the kinetic energy (unit mass).\n", 257 | "    \"\"\"\n", 258 | "    return log_posterior_kernel(data, theta[0], theta[1]) + (p**2).sum() / 2\n", 259 | "\n", 260 | "def leap_frog(theta, p, L, eps):\n", 261 | "    \"\"\"Simulate L leapfrog steps of size eps starting from (theta, p).\n", 262 | "\n", 263 | "    theta: parameters [mu, sigma] (position)\n", 264 | "    p: momentum, same length as theta\n", 265 | "    Returns the trajectory as a list of [theta, p, H] entries, the first one\n", 266 | "    being the initial state. The arguments are not modified.\n", 267 | "    \"\"\"\n", 268 | "    # Work on float copies so the caller's arrays are never updated in place,\n", 269 | "    # and store copies so each entry keeps its own state (otherwise every entry\n", 270 | "    # would alias the same arrays and show only the final position/momentum).\n", 271 | "    theta = np.array(theta, dtype=float)\n", 272 | "    p = np.array(p, dtype=float)\n", 273 | "    res = [[theta.copy(), p.copy(), hamiltonian(theta, p)]]\n", 274 | "    for _ in range(L):\n", 275 | "        p -= eps * log_posterior_kernel_diff(data, theta[0], theta[1]) / 2\n", 276 | "        theta += eps * p\n", 277 | "        p -= eps * log_posterior_kernel_diff(data, theta[0], theta[1]) / 2\n", 278 | "        res.append([theta.copy(), p.copy(), hamiltonian(theta, p)])\n", 279 | "    return res" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": null, 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [ 288 | "def run_hmc(theta_init, n_iter=T, n_leapfrog=L, step_size=eps):\n", 289 | "    \"\"\"Draw n_iter HMC samples of [mu, sigma] starting from theta_init.\n", 290 | "\n", 291 | "    Each iteration draws a standard normal momentum, simulates a leapfrog\n", 292 | "    trajectory and accepts its end point with probability\n", 293 | "    min(1, exp(H_current - H_proposed)); proposals outside the support of the\n", 294 | "    uniform prior (prior_min, prior_max) are rejected.\n", 295 | "    Returns (samples, acceptance_rate) with samples of shape (n_iter, len(theta_init)).\n", 296 | "    \"\"\"\n", 297 | "    theta = np.array(theta_init, dtype=float)\n", 298 | "    samples = np.empty((n_iter, theta.size))\n", 299 | "    n_accepted = 0\n", 300 | "    for i in range(n_iter):\n", 301 | "        p = np.random.randn(theta.size)\n", 302 | "        trajectory = leap_frog(theta, p, n_leapfrog, step_size)\n", 303 | "        h_current = trajectory[0][2]\n", 304 | "        theta_prop, _, h_prop = trajectory[-1]\n", 305 | "        in_support = np.all((theta_prop > prior_min) & (theta_prop < prior_max))\n", 306 | "        # compare on the log scale to avoid overflow in exp() for large energy gaps\n", 307 | "        if in_support and np.log(np.random.rand()) < h_current - h_prop:\n", 308 | "            theta = theta_prop\n", 309 | "            n_accepted += 1\n", 310 | "        samples[i] = theta\n", 311 | "    return samples, n_accepted / n_iter\n", 312 | "\n", 313 | "samples, accept_rate = run_hmc(theta_init)\n", 314 | "samples[T // 10:].mean(axis=0), accept_rate  # first 10% discarded as burn-in" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 65, 320 | "metadata": {}, 321 | "outputs": [], 322 | "source": [ 323 | "p = np.random.randn(2)\n", 324 | "theta = [150, 5]\n", 325 | "res = leap_frog(theta, p, 100, 0.01)" 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": 66, 331 | "metadata": {}, 332 | "outputs": [ 333 | { 334 | "data": { 335 | 
"text/plain": [ 336 | "[[[150, 5], array([45.78211507, 25.11116066]), 10359.26546501995],\n", 337 | " [array([220.41219201, 28.21704146]),\n", 338 | " array([45.78211507, 25.11116066]),\n", 339 | " 10342.784129513393],\n", 340 | " [array([220.41219201, 28.21704146]),\n", 341 | " array([45.78211507, 25.11116066]),\n", 342 | " 10231.562578343654],\n", 343 | " [array([220.41219201, 28.21704146]),\n", 344 | " array([45.78211507, 25.11116066]),\n", 345 | " 10038.746105992966],\n", 346 | " [array([220.41219201, 28.21704146]),\n", 347 | " array([45.78211507, 25.11116066]),\n", 348 | " 9784.734228510415],\n", 349 | " [array([220.41219201, 28.21704146]),\n", 350 | " array([45.78211507, 25.11116066]),\n", 351 | " 9492.861200732805],\n", 352 | " [array([220.41219201, 28.21704146]),\n", 353 | " array([45.78211507, 25.11116066]),\n", 354 | " 9185.407388982954],\n", 355 | " [array([220.41219201, 28.21704146]),\n", 356 | " array([45.78211507, 25.11116066]),\n", 357 | " 8880.913235514561],\n", 358 | " [array([220.41219201, 28.21704146]),\n", 359 | " array([45.78211507, 25.11116066]),\n", 360 | " 8593.01347049296],\n", 361 | " [array([220.41219201, 28.21704146]),\n", 362 | " array([45.78211507, 25.11116066]),\n", 363 | " 8330.474244009012],\n", 364 | " [array([220.41219201, 28.21704146]),\n", 365 | " array([45.78211507, 25.11116066]),\n", 366 | " 8097.930800267921],\n", 367 | " [array([220.41219201, 28.21704146]),\n", 368 | " array([45.78211507, 25.11116066]),\n", 369 | " 7896.885720508933],\n", 370 | " [array([220.41219201, 28.21704146]),\n", 371 | " array([45.78211507, 25.11116066]),\n", 372 | " 7726.68343116028],\n", 373 | " [array([220.41219201, 28.21704146]),\n", 374 | " array([45.78211507, 25.11116066]),\n", 375 | " 7585.323258417608],\n", 376 | " [array([220.41219201, 28.21704146]),\n", 377 | " array([45.78211507, 25.11116066]),\n", 378 | " 7470.072791662333],\n", 379 | " [array([220.41219201, 28.21704146]),\n", 380 | " array([45.78211507, 25.11116066]),\n", 381 | " 
7377.896379377075],\n", 382 | " [array([220.41219201, 28.21704146]),\n", 383 | " array([45.78211507, 25.11116066]),\n", 384 | " 7305.734449088767],\n", 385 | " [array([220.41219201, 28.21704146]),\n", 386 | " array([45.78211507, 25.11116066]),\n", 387 | " 7250.672269884444],\n", 388 | " [array([220.41219201, 28.21704146]),\n", 389 | " array([45.78211507, 25.11116066]),\n", 390 | " 7210.031689471114],\n", 391 | " [array([220.41219201, 28.21704146]),\n", 392 | " array([45.78211507, 25.11116066]),\n", 393 | " 7181.411921069739],\n", 394 | " [array([220.41219201, 28.21704146]),\n", 395 | " array([45.78211507, 25.11116066]),\n", 396 | " 7162.698310592624],\n", 397 | " [array([220.41219201, 28.21704146]),\n", 398 | " array([45.78211507, 25.11116066]),\n", 399 | " 7152.0521682025665],\n", 400 | " [array([220.41219201, 28.21704146]),\n", 401 | " array([45.78211507, 25.11116066]),\n", 402 | " 7147.89035391228],\n", 403 | " [array([220.41219201, 28.21704146]),\n", 404 | " array([45.78211507, 25.11116066]),\n", 405 | " 7148.860177731191],\n", 406 | " [array([220.41219201, 28.21704146]),\n", 407 | " array([45.78211507, 25.11116066]),\n", 408 | " 7153.813031375219],\n", 409 | " [array([220.41219201, 28.21704146]),\n", 410 | " array([45.78211507, 25.11116066]),\n", 411 | " 7161.778745021877],\n", 412 | " [array([220.41219201, 28.21704146]),\n", 413 | " array([45.78211507, 25.11116066]),\n", 414 | " 7171.941743250148],\n", 415 | " [array([220.41219201, 28.21704146]),\n", 416 | " array([45.78211507, 25.11116066]),\n", 417 | " 7183.619497474319],\n", 418 | " [array([220.41219201, 28.21704146]),\n", 419 | " array([45.78211507, 25.11116066]),\n", 420 | " 7196.243422476495],\n", 421 | " [array([220.41219201, 28.21704146]),\n", 422 | " array([45.78211507, 25.11116066]),\n", 423 | " 7209.342162362122],\n", 424 | " [array([220.41219201, 28.21704146]),\n", 425 | " array([45.78211507, 25.11116066]),\n", 426 | " 7222.52710264967],\n", 427 | " [array([220.41219201, 28.21704146]),\n", 428 | " 
array([45.78211507, 25.11116066]),\n", 429 | " 7235.479894681125],\n", 430 | " [array([220.41219201, 28.21704146]),\n", 431 | " array([45.78211507, 25.11116066]),\n", 432 | " 7247.94176294267],\n", 433 | " [array([220.41219201, 28.21704146]),\n", 434 | " array([45.78211507, 25.11116066]),\n", 435 | " 7259.704370378891],\n", 436 | " [array([220.41219201, 28.21704146]),\n", 437 | " array([45.78211507, 25.11116066]),\n", 438 | " 7270.60203203222],\n", 439 | " [array([220.41219201, 28.21704146]),\n", 440 | " array([45.78211507, 25.11116066]),\n", 441 | " 7280.50508759225],\n", 442 | " [array([220.41219201, 28.21704146]),\n", 443 | " array([45.78211507, 25.11116066]),\n", 444 | " 7289.314265257241],\n", 445 | " [array([220.41219201, 28.21704146]),\n", 446 | " array([45.78211507, 25.11116066]),\n", 447 | " 7296.955890721801],\n", 448 | " [array([220.41219201, 28.21704146]),\n", 449 | " array([45.78211507, 25.11116066]),\n", 450 | " 7303.377815057891],\n", 451 | " [array([220.41219201, 28.21704146]),\n", 452 | " array([45.78211507, 25.11116066]),\n", 453 | " 7308.545953263418],\n", 454 | " [array([220.41219201, 28.21704146]),\n", 455 | " array([45.78211507, 25.11116066]),\n", 456 | " 7312.441341160674],\n", 457 | " [array([220.41219201, 28.21704146]),\n", 458 | " array([45.78211507, 25.11116066]),\n", 459 | " 7315.057632176335],\n", 460 | " [array([220.41219201, 28.21704146]),\n", 461 | " array([45.78211507, 25.11116066]),\n", 462 | " 7316.398967467017],\n", 463 | " [array([220.41219201, 28.21704146]),\n", 464 | " array([45.78211507, 25.11116066]),\n", 465 | " 7316.478163059148],\n", 466 | " [array([220.41219201, 28.21704146]),\n", 467 | " array([45.78211507, 25.11116066]),\n", 468 | " 7315.315166352646],\n", 469 | " [array([220.41219201, 28.21704146]),\n", 470 | " array([45.78211507, 25.11116066]),\n", 471 | " 7312.935741694299],\n", 472 | " [array([220.41219201, 28.21704146]),\n", 473 | " array([45.78211507, 25.11116066]),\n", 474 | " 7309.3703509443785],\n", 475 | " 
[array([220.41219201, 28.21704146]),\n", 476 | " array([45.78211507, 25.11116066]),\n", 477 | " 7304.65320020603],\n", 478 | " [array([220.41219201, 28.21704146]),\n", 479 | " array([45.78211507, 25.11116066]),\n", 480 | " 7298.82142830832],\n", 481 | " [array([220.41219201, 28.21704146]),\n", 482 | " array([45.78211507, 25.11116066]),\n", 483 | " 7291.91441635837],\n", 484 | " [array([220.41219201, 28.21704146]),\n", 485 | " array([45.78211507, 25.11116066]),\n", 486 | " 7283.973200815119],\n", 487 | " [array([220.41219201, 28.21704146]),\n", 488 | " array([45.78211507, 25.11116066]),\n", 489 | " 7275.039975180385],\n", 490 | " [array([220.41219201, 28.21704146]),\n", 491 | " array([45.78211507, 25.11116066]),\n", 492 | " 7265.157667630994],\n", 493 | " [array([220.41219201, 28.21704146]),\n", 494 | " array([45.78211507, 25.11116066]),\n", 495 | " 7254.369583795529],\n", 496 | " [array([220.41219201, 28.21704146]),\n", 497 | " array([45.78211507, 25.11116066]),\n", 498 | " 7242.71910546662],\n", 499 | " [array([220.41219201, 28.21704146]),\n", 500 | " array([45.78211507, 25.11116066]),\n", 501 | " 7230.249437381737],\n", 502 | " [array([220.41219201, 28.21704146]),\n", 503 | " array([45.78211507, 25.11116066]),\n", 504 | " 7217.003395341594],\n", 505 | " [array([220.41219201, 28.21704146]),\n", 506 | " array([45.78211507, 25.11116066]),\n", 507 | " 7203.0232298982955],\n", 508 | " [array([220.41219201, 28.21704146]),\n", 509 | " array([45.78211507, 25.11116066]),\n", 510 | " 7188.350480662992],\n", 511 | " [array([220.41219201, 28.21704146]),\n", 512 | " array([45.78211507, 25.11116066]),\n", 513 | " 7173.025856978032],\n", 514 | " [array([220.41219201, 28.21704146]),\n", 515 | " array([45.78211507, 25.11116066]),\n", 516 | " 7157.089141290713],\n", 517 | " [array([220.41219201, 28.21704146]),\n", 518 | " array([45.78211507, 25.11116066]),\n", 519 | " 7140.5791120709055],\n", 520 | " [array([220.41219201, 28.21704146]),\n", 521 | " array([45.78211507, 
25.11116066]),\n", 522 | " 7123.533483546439],\n", 523 | " [array([220.41219201, 28.21704146]),\n", 524 | " array([45.78211507, 25.11116066]),\n", 525 | " 7105.988859899791],\n", 526 | " [array([220.41219201, 28.21704146]),\n", 527 | " array([45.78211507, 25.11116066]),\n", 528 | " 7087.980701886488],\n", 529 | " [array([220.41219201, 28.21704146]),\n", 530 | " array([45.78211507, 25.11116066]),\n", 531 | " 7069.5433041079195],\n", 532 | " [array([220.41219201, 28.21704146]),\n", 533 | " array([45.78211507, 25.11116066]),\n", 534 | " 7050.709781405603],\n", 535 | " [array([220.41219201, 28.21704146]),\n", 536 | " array([45.78211507, 25.11116066]),\n", 537 | " 7031.512063045906],\n", 538 | " [array([220.41219201, 28.21704146]),\n", 539 | " array([45.78211507, 25.11116066]),\n", 540 | " 7011.980893538743],\n", 541 | " [array([220.41219201, 28.21704146]),\n", 542 | " array([45.78211507, 25.11116066]),\n", 543 | " 6992.14583908463],\n", 544 | " [array([220.41219201, 28.21704146]),\n", 545 | " array([45.78211507, 25.11116066]),\n", 546 | " 6972.035298775282],\n", 547 | " [array([220.41219201, 28.21704146]),\n", 548 | " array([45.78211507, 25.11116066]),\n", 549 | " 6951.676519786365],\n", 550 | " [array([220.41219201, 28.21704146]),\n", 551 | " array([45.78211507, 25.11116066]),\n", 552 | " 6931.095615899708],\n", 553 | " [array([220.41219201, 28.21704146]),\n", 554 | " array([45.78211507, 25.11116066]),\n", 555 | " 6910.317588778097],\n", 556 | " [array([220.41219201, 28.21704146]),\n", 557 | " array([45.78211507, 25.11116066]),\n", 558 | " 6889.36635149062],\n", 559 | " [array([220.41219201, 28.21704146]),\n", 560 | " array([45.78211507, 25.11116066]),\n", 561 | " 6868.264753851827],\n", 562 | " [array([220.41219201, 28.21704146]),\n", 563 | " array([45.78211507, 25.11116066]),\n", 564 | " 6847.034609195096],\n", 565 | " [array([220.41219201, 28.21704146]),\n", 566 | " array([45.78211507, 25.11116066]),\n", 567 | " 6825.696722250436],\n", 568 | " [array([220.41219201, 
28.21704146]),\n", 569 | " array([45.78211507, 25.11116066]),\n", 570 | " 6804.270917840771],\n", 571 | " [array([220.41219201, 28.21704146]),\n", 572 | " array([45.78211507, 25.11116066]),\n", 573 | " 6782.776070148968],\n", 574 | " [array([220.41219201, 28.21704146]),\n", 575 | " array([45.78211507, 25.11116066]),\n", 576 | " 6761.230132341609],\n", 577 | " [array([220.41219201, 28.21704146]),\n", 578 | " array([45.78211507, 25.11116066]),\n", 579 | " 6739.650166364838],\n", 580 | " [array([220.41219201, 28.21704146]),\n", 581 | " array([45.78211507, 25.11116066]),\n", 582 | " 6718.052372753682],\n", 583 | " [array([220.41219201, 28.21704146]),\n", 584 | " array([45.78211507, 25.11116066]),\n", 585 | " 6696.452120318897],\n", 586 | " [array([220.41219201, 28.21704146]),\n", 587 | " array([45.78211507, 25.11116066]),\n", 588 | " 6674.863975595403],\n", 589 | " [array([220.41219201, 28.21704146]),\n", 590 | " array([45.78211507, 25.11116066]),\n", 591 | " 6653.301731953971],\n", 592 | " [array([220.41219201, 28.21704146]),\n", 593 | " array([45.78211507, 25.11116066]),\n", 594 | " 6631.7784382932505],\n", 595 | " [array([220.41219201, 28.21704146]),\n", 596 | " array([45.78211507, 25.11116066]),\n", 597 | " 6610.30642724274],\n", 598 | " [array([220.41219201, 28.21704146]),\n", 599 | " array([45.78211507, 25.11116066]),\n", 600 | " 6588.8973428193],\n", 601 | " [array([220.41219201, 28.21704146]),\n", 602 | " array([45.78211507, 25.11116066]),\n", 603 | " 6567.562167490172],\n", 604 | " [array([220.41219201, 28.21704146]),\n", 605 | " array([45.78211507, 25.11116066]),\n", 606 | " 6546.311248604663],\n", 607 | " [array([220.41219201, 28.21704146]),\n", 608 | " array([45.78211507, 25.11116066]),\n", 609 | " 6525.154324164687],\n", 610 | " [array([220.41219201, 28.21704146]),\n", 611 | " array([45.78211507, 25.11116066]),\n", 612 | " 6504.100547911346],\n", 613 | " [array([220.41219201, 28.21704146]),\n", 614 | " array([45.78211507, 25.11116066]),\n", 615 | " 
6483.158513710863],\n", 616 | " [array([220.41219201, 28.21704146]),\n", 617 | " array([45.78211507, 25.11116066]),\n", 618 | " 6462.336279228525],\n", 619 | " [array([220.41219201, 28.21704146]),\n", 620 | " array([45.78211507, 25.11116066]),\n", 621 | " 6441.641388883937],\n", 622 | " [array([220.41219201, 28.21704146]),\n", 623 | " array([45.78211507, 25.11116066]),\n", 624 | " 6421.080896084891],\n", 625 | " [array([220.41219201, 28.21704146]),\n", 626 | " array([45.78211507, 25.11116066]),\n", 627 | " 6400.661384740626],\n", 628 | " [array([220.41219201, 28.21704146]),\n", 629 | " array([45.78211507, 25.11116066]),\n", 630 | " 6380.388990058293],\n", 631 | " [array([220.41219201, 28.21704146]),\n", 632 | " array([45.78211507, 25.11116066]),\n", 633 | " 6360.269418628925],\n", 634 | " [array([220.41219201, 28.21704146]),\n", 635 | " array([45.78211507, 25.11116066]),\n", 636 | " 6340.307967811467]]" 637 | ] 638 | }, 639 | "execution_count": 66, 640 | "metadata": {}, 641 | "output_type": "execute_result" 642 | } 643 | ], 644 | "source": [ 645 | "res" 646 | ] 647 | }, 648 | { 649 | "cell_type": "code", 650 | "execution_count": null, 651 | "metadata": {}, 652 | "outputs": [], 653 | "source": [] 654 | } 655 | ], 656 | "metadata": { 657 | "kernelspec": { 658 | "display_name": "Python 3", 659 | "language": "python", 660 | "name": "python3" 661 | }, 662 | "language_info": { 663 | "codemirror_mode": { 664 | "name": "ipython", 665 | "version": 3 666 | }, 667 | "file_extension": ".py", 668 | "mimetype": "text/x-python", 669 | "name": "python", 670 | "nbconvert_exporter": "python", 671 | "pygments_lexer": "ipython3", 672 | "version": "3.7.1" 673 | } 674 | }, 675 | "nbformat": 4, 676 | "nbformat_minor": 2 677 | } 678 | -------------------------------------------------------------------------------- /PyMC3_practice.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | 
"execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%matplotlib inline\n", 10 | "\n", 11 | "from warnings import filterwarnings\n", 12 | "\n", 13 | "from matplotlib.patches import Ellipse\n", 14 | "from matplotlib import pyplot as plt\n", 15 | "import numpy as np\n", 16 | "import pymc3 as pm\n", 17 | "import seaborn as sns" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "with pm.Model() as model:\n", 27 | " mu = pm.Normal('mu', mu=0, sd=1)\n", 28 | " obs = pm.Normal('obs', mu=mu, sd=1, observed=np.random.randn(100))" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "モデルはrandom variables (RVs) への参照をもち、logpやその勾配を計算する。" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 12, 41 | "metadata": {}, 42 | "outputs": [ 43 | { 44 | "data": { 45 | "text/plain": [ 46 | "([mu, obs], [mu], [obs])" 47 | ] 48 | }, 49 | "execution_count": 12, 50 | "metadata": {}, 51 | "output_type": "execute_result" 52 | } 53 | ], 54 | "source": [ 55 | "model.basic_RVs, model.free_RVs, model.observed_RVs" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 13, 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "data": { 65 | "text/plain": [ 66 | "array(-155.78695688)" 67 | ] 68 | }, 69 | "execution_count": 13, 70 | "metadata": {}, 71 | "output_type": "execute_result" 72 | } 73 | ], 74 | "source": [ 75 | "# log probability density function\n", 76 | "model.logp({'mu': 0})" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 23, 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "data": { 86 | "text/plain": [ 87 | "array(-0.91893853)" 88 | ] 89 | }, 90 | "execution_count": 23, 91 | "metadata": {}, 92 | "output_type": "execute_result" 93 | } 94 | ], 95 | "source": [ 96 | "with pm.Model():\n", 97 | " x = pm.Normal('x', mu=0, sd=1)\n", 98 | "\n", 99 | "x.logp({'x': 0})" 100 | ] 101 
| }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "PyMC3では制約ありのRVsは制約なしになるように自動的に変換される。 \n", 107 | "include_transformed=Trueを引数に与えることで、変換後のパラメータの結果も表示されるようになる。" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 11, 113 | "metadata": {}, 114 | "outputs": [ 115 | { 116 | "data": { 117 | "text/plain": [ 118 | "[x_log__, y_interval__]" 119 | ] 120 | }, 121 | "execution_count": 11, 122 | "metadata": {}, 123 | "output_type": "execute_result" 124 | } 125 | ], 126 | "source": [ 127 | "with pm.Model() as model:\n", 128 | " x = pm.Gamma('x', mu=1, sd=1)\n", 129 | " y = pm.Uniform('y', lower=-1, upper=2)\n", 130 | " z = pm.Deterministic('z', x + y)\n", 131 | " \n", 132 | "model.free_RVs" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 15, 138 | "metadata": {}, 139 | "outputs": [ 140 | { 141 | "data": { 142 | "text/plain": [ 143 | "[x, y, z]" 144 | ] 145 | }, 146 | "execution_count": 15, 147 | "metadata": {}, 148 | "output_type": "execute_result" 149 | } 150 | ], 151 | "source": [ 152 | "model.deterministics" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "### 初期値の設定 \n", 160 | "基本的に初期値は自動的に設定されるが、手動で変更することも可能。" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 17, 166 | "metadata": {}, 167 | "outputs": [ 168 | { 169 | "data": { 170 | "text/plain": [ 171 | "array([1., 1., 1., 1., 1.])" 172 | ] 173 | }, 174 | "execution_count": 17, 175 | "metadata": {}, 176 | "output_type": "execute_result" 177 | } 178 | ], 179 | "source": [ 180 | "# initial values\n", 181 | "with pm.Model():\n", 182 | " x = pm.Normal('x', mu=1, sd=2, shape=5)\n", 183 | "\n", 184 | "x.tag.test_value" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 27, 190 | "metadata": {}, 191 | "outputs": [ 192 | { 193 | "data": { 194 | "text/plain": [ 195 | "array([ 0.89674484, 0.3252694 , -0.24739399, 0.77750987, 
0.30550593])" 196 | ] 197 | }, 198 | "execution_count": 27, 199 | "metadata": {}, 200 | "output_type": "execute_result" 201 | } 202 | ], 203 | "source": [ 204 | "with pm.Model():\n", 205 | " x = pm.Normal('x', mu=0, sd=1, shape=5, testval=np.random.randn(5))\n", 206 | "\n", 207 | "x.tag.test_value" 208 | ] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "metadata": {}, 213 | "source": [ 214 | "### サンプリング\n", 215 | "NUTSでは良い初期値を探索するために、変分推論などを用いる。 \n", 216 | "NUTSでのサンプリングが難しい場合、メトロポリス法を使っても改善することは少なく、初期値やパラメータを調整する方が良い。" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 18, 222 | "metadata": {}, 223 | "outputs": [ 224 | { 225 | "data": { 226 | "text/plain": [ 227 | "['BinaryGibbsMetropolis',\n", 228 | " 'BinaryMetropolis',\n", 229 | " 'CSG',\n", 230 | " 'CategoricalGibbsMetropolis',\n", 231 | " 'CauchyProposal',\n", 232 | " 'CompoundStep',\n", 233 | " 'DEMetropolis',\n", 234 | " 'ElemwiseCategorical',\n", 235 | " 'EllipticalSlice',\n", 236 | " 'HamiltonianMC',\n", 237 | " 'LaplaceProposal',\n", 238 | " 'Metropolis',\n", 239 | " 'MultivariateNormalProposal',\n", 240 | " 'NUTS',\n", 241 | " 'NormalProposal',\n", 242 | " 'PoissonProposal',\n", 243 | " 'SGFS',\n", 244 | " 'SMC',\n", 245 | " 'Slice']" 246 | ] 247 | }, 248 | "execution_count": 18, 249 | "metadata": {}, 250 | "output_type": "execute_result" 251 | } 252 | ], 253 | "source": [ 254 | "list(filter(lambda x: x[0].isupper(), dir(pm.step_methods)))" 255 | ] 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "metadata": {}, 260 | "source": [ 261 | "### 変分推論" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 21, 267 | "metadata": {}, 268 | "outputs": [ 269 | { 270 | "name": "stderr", 271 | "output_type": "stream", 272 | "text": [ 273 | "100%|██████████| 10000/10000 [00:54<00:00, 184.92it/s]\n" 274 | ] 275 | } 276 | ], 277 | "source": [ 278 | "w = pm.floatX([.2, .8])\n", 279 | "mu = pm.floatX([-.3, .5])\n", 280 | "sd = pm.floatX([.1, .1])\n", 281 | "with 
pm.Model() as model:\n", 282 | " pm.NormalMixture('x', w=w, mu=mu, sd=sd)\n", 283 | " approx = pm.fit(method=pm.SVGD(n_particles=200, jitter=1.))" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 22, 289 | "metadata": {}, 290 | "outputs": [ 291 | { 292 | "data": { 293 | "image/png": "\n", 294 | "text/plain": [ 295 | "
" 296 | ] 297 | }, 298 | "metadata": { 299 | "needs_background": "light" 300 | }, 301 | "output_type": "display_data" 302 | } 303 | ], 304 | "source": [ 305 | "plt.figure()\n", 306 | "trace = approx.sample(10000)\n", 307 | "sns.distplot(trace['x']);" 308 | ] 309 | }, 310 | { 311 | "cell_type": "markdown", 312 | "metadata": {}, 313 | "source": [ 314 | "### デバッグ\n", 315 | "theanoのprinting.Printを用いる。" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": 24, 321 | "metadata": {}, 322 | "outputs": [ 323 | { 324 | "name": "stderr", 325 | "output_type": "stream", 326 | "text": [ 327 | "Only 5 samples in chain.\n", 328 | "Sequential sampling (1 chains in 1 job)\n", 329 | "CompoundStep\n", 330 | ">Metropolis: [sd]\n", 331 | ">Metropolis: [mu]\n", 332 | "/Users/yuho.kida/python_DS/lib/python3.7/site-packages/pymc3/sampling.py:476: UserWarning: The number of samples is too small to check convergence reliably.\n", 333 | " warnings.warn(\"The number of samples is too small to check convergence reliably.\")\n" 334 | ] 335 | }, 336 | { 337 | "data": { 338 | "image/png": "\n", 339 | "text/plain": [ 340 | "
" 341 | ] 342 | }, 343 | "metadata": { 344 | "needs_background": "light" 345 | }, 346 | "output_type": "display_data" 347 | } 348 | ], 349 | "source": [ 350 | "from io import StringIO\n", 351 | "import theano.tensor as tt\n", 352 | "import sys\n", 353 | "\n", 354 | "x = np.random.randn(100)\n", 355 | "\n", 356 | "old_stdout = sys.stdout\n", 357 | "mystdout = sys.stdout = StringIO()\n", 358 | "\n", 359 | "with pm.Model() as model:\n", 360 | " mu = pm.Normal('mu', mu=0, sd=1)\n", 361 | " sd = pm.Normal('sd', mu=0, sd=1)\n", 362 | "\n", 363 | " mu_print = tt.printing.Print('mu')(mu)\n", 364 | " sd_print = tt.printing.Print('sd')(sd)\n", 365 | "\n", 366 | " obs = pm.Normal('obs', mu=mu_print, sd=sd_print, observed=x)\n", 367 | " step = pm.Metropolis()\n", 368 | " trace = pm.sample(5, step, tune=0, chains=1, progressbar=False) # Make sure not to draw too many samples\n", 369 | "\n", 370 | "sys.stdout = old_stdout\n", 371 | "\n", 372 | "output = mystdout.getvalue().split('\\n')\n", 373 | "mulines = [s for s in output if 'mu' in s]\n", 374 | "\n", 375 | "muvals = [line.split()[-1] for line in mulines]\n", 376 | "plt.plot(np.arange(0, len(muvals)), muvals)\n", 377 | "plt.xlabel('proposal iteration')\n", 378 | "plt.ylabel('mu value');" 379 | ] 380 | }, 381 | { 382 | "cell_type": "markdown", 383 | "metadata": {}, 384 | "source": [ 385 | "### Mass matrix\n", 386 | "\n", 387 | "HMC法のパラメータは3つ。\n", 388 | "\n", 389 | "1. discretization time ϵ\n", 390 | "2. mass matrix Σ−1\n", 391 | "3. 
number of steps taken L\n", 392 | "\n" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": null, 398 | "metadata": {}, 399 | "outputs": [], 400 | "source": [] 401 | } 402 | ], 403 | "metadata": { 404 | "kernelspec": { 405 | "display_name": "Python 3", 406 | "language": "python", 407 | "name": "python3" 408 | }, 409 | "language_info": { 410 | "codemirror_mode": { 411 | "name": "ipython", 412 | "version": 3 413 | }, 414 | "file_extension": ".py", 415 | "mimetype": "text/x-python", 416 | "name": "python", 417 | "nbconvert_exporter": "python", 418 | "pygments_lexer": "ipython3", 419 | "version": "3.7.1" 420 | } 421 | }, 422 | "nbformat": 4, 423 | "nbformat_minor": 2 424 | } 425 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StatisticalModeling 2 | This repository includes Python/R scripts for statistical modeling. This is mostly for personal use. 3 | -------------------------------------------------------------------------------- /data.RData: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kidaufo/StatisticalModeling/68007e022846355148102f1fcbe118ffdcb99268/data.RData --------------------------------------------------------------------------------