├── .gitignore ├── LICENSE ├── README.md ├── crf ├── CRF.ipynb └── crf.py ├── ctc └── CTC.ipynb ├── gmm ├── GMM-EM.ipynb ├── README.md ├── gmm.py ├── kmeans.py └── requirements.txt ├── gumbel ├── gumbel-distribution.ipynb ├── gumbel.py ├── img │ ├── i_xent.png │ └── x_xent.png └── variational_autoencoder_gumbel.py └── plda ├── PLDA.ipynb └── plda.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_Store 2 | *.pyc 3 | *.swp 4 | temp/* 5 | dist/* 6 | build/* 7 | tags 8 | 9 | # test-related 10 | .coverage 11 | .cache 12 | 13 | # developer environments 14 | .idea 15 | 16 | # jupyter notebook 17 | .ipynb_checkpoints 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Ke Ding 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Machine Learning Algorithms & Implementations 2 | 3 | * [K-Means and Gaussian Mixture Model (GMM)](./gmm) 4 | * [Conditional Random Field (CRF)](./crf) 5 | * [Connectionist Temporal Classification (CTC)](./ctc) 6 | * [Gumbel Distribution](./gumbel) 7 | * [Probabilistic Linear Discriminant Analysis (PLDA)](./plda) 8 | * [Gaussian Process Regression (GPR)](./gp) 9 | -------------------------------------------------------------------------------- /crf/CRF.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 1. 概述\n", 8 | "\n", 9 | "条件随机场(Conditional Random Field, CRF)是概率图模型(Probabilistic Graphical Model)与区分性分类( Discriminative Classification)的一种接合,能够用来对序列分类(标注)问题进行建模。\n", 10 | "\n", 11 | "如图1,论文 [1] 阐释了 CRF 与其他模型之间的关系。\n", 12 | "\n", 13 | "![](http://ww3.sinaimg.cn/large/6cbb8645jw1en4lqd5qq3j21510p1jvn.jpg)\n", 14 | "**图1. CRF 与 其他机器学习模型对比【[src](http://www.hankcs.com/nlp/segment/crf-segmentation-of-the-pure-java-implementation.html)】**\n", 15 | "\n", 16 | "\n", 17 | "\n", 18 | "![](https://ss0.bdstatic.com/70cFuHSh_Q1YnxGkpoWK1HF6hhy/it/u=1645229592,3117226322&fm=27&gp=0.jpg)\n", 19 | "**图2. 
简单的线性链 CRF**\n", 20 | "\n", 21 | "本文我们重点关注输入结点独立的“线性链条件随机场”(Linear-Chain CRF)(如图2)的原理与实现。线性链 CRF 通过与双向 LSTM(Bi-LSTM)的结合,可以用来建模更一般的线性链 CRF(图3),提高模型的建模能力。\n", 22 | "\n", 23 | "![](https://timgsa.baidu.com/timg?image&quality=80&size=b9999_10000&sec=1520657695107&di=5973a34b94de511772a9337ad7de11e0&imgtype=0&src=http%3A%2F%2Fs14.sinaimg.cn%2Fmw690%2F002R78Yfgy71Mv4EvH78d%26690)\n", 24 | "**图3. 一般性线性链 CRF**" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "# 2. CRF 算法\n", 32 | "\n", 33 | "## 2.1 模型形式化\n", 34 | "\n", 35 | "\n", 36 | "给定长度为 $m$ 的序列, 以及状态集 $S$。 对于任意状态序列 $(s_1,\cdots,s_m), s_i \in S$, 定义其“势”(potential)如下:\n", 37 | "$$\n", 38 | "\psi(s_1,\dots, s_m) = \prod_{i=1}^m\psi (s_{i−1}, s_i , i)\n", 39 | "$$\n", 40 | "我们定义 $s_0$ 为特殊的开始符号 $*$。这里对 $s, s^\prime ∈ S, i \in {1 , \dots, m}$,势函数 $\psi(s, s^\prime, i) \ge 0$。也即,势函数是非负的,它对序列第 $i$ 位置发生的 $s$ 到 $s^\prime$ 的状态转移都给出一个非负值。\n", 41 | "\n", 42 | "\n", 43 | "\n", 44 | "根据概率图模型的因子分解理论[1],我们有:\n", 45 | "$$\n", 46 | "p(s_1,\dots,s_m|x_1,\dots, x_m) = \frac{\psi(s_1,\dots, s_m) }{\sum_{s^\prime_1,\dots,s^\prime_m} \psi(s^\prime_1,\dots, s^\prime_m)}\n", 47 | "$$\n", 48 | "\n", 49 | "$Z = \sum_{s^\prime_1,\dots,s^\prime_m} \psi(s^\prime_1,\dots, s^\prime_m) $ 为归一化因子。\n", 50 | "\n" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": {}, 56 | "source": [ 57 | "同 HMM 类似,CRF 也涉及三类基本问题:评估(计算某一序列的似然值)、解码(给定输入,寻找似然最大的序列)及训练(根据数据估计 CRF 的参数)。解决这三个问题也都涉及前向算法、后向算法及 Viterbi 算法。\n", 58 | "\n", 59 | "CRF 的势函数类似于概率,只不过没有归一化,因此这里介绍的 CRF 前向算法、Viterbi 算法、后向算法,同 HMM 基本一致。\n", 60 | "\n", 61 | "## 2.2 前向算法\n", 62 | "\n", 63 | "定义:\n", 64 | "$$\n", 65 | "\alpha(i, s) = \sum_{s_1,\dots,s_{i-1}} \psi(s_1,\dots,s_{i-1}, s)\n", 66 | "$$\n", 67 | "\n", 68 | "表示以 $s$ 结尾、长度为 $i$ 的所有前缀状态序列的势之和。\n", 69 | "\n", 70 | "显然,$\alpha(1, s) = \psi(*, s, 1)$\n", 71 | "\n", 72 | "根据定义,我们有如下递归关系:\n", 73 | "$$\n", 74 | "\alpha(i, s) = \sum_{s^\prime \in S} \alpha(i-1, s^\prime) \times \psi(s^\prime, s, i)\n", 75 | "$$\n", 76 | "\n", 77 | "归一化因子可以计算如下:\n", 78 | "$$Z = \sum_{s_1,\dots,s_m} \psi(s_1,\dots s_m) = \sum_{s\in S}\sum_{s_1,\dots,s_{m-1}} \psi(s_1,\dots s_{m-1}, s)= \sum_{s\in S} \alpha(m, s)\n", 79 | "$$\n", 80 | "\n", 81 | "对于给定的序列 $(s_1,\cdots,s_m)$,其条件概率(似然)可以计算如下:\n", 82 | "$$\n", 83 | "p(s_1,\dots,s_m|x_1,\dots, x_m) = \frac{\prod_{i=1}^m\psi (s_{i−1}, s_i , i)}{\sum_{s\in S} \alpha(m, s)}\n", 84 | "$$\n", 85 | "\n", 86 | "**通过前向算法,我们解决了评估问题,计算复杂度为 $O(m\cdot|S|^2)$,空间复杂度为 $O(m\cdot|S|)$。**\n", 87 | "\n", 88 | "> 似然的计算过程中,只涉及乘法和加法,都是可导操作。因此,只需要实现前向操作,我们就可以借助具有自动梯度功能的学习库(e.g. 
pytorch、tensorflow)实现基于最大似然准则的训练。一个基于 pytorch 的 CRF 实现见 [repo](https://github.com/DingKe/ml-tutorial/blob/master/crf/crf.py#L39)。" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 1, 94 | "metadata": { 95 | "collapsed": false 96 | }, 97 | "outputs": [ 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "2.69869828108e-08\n", 103 | "[[ 1.10026295e+00 2.52187760e+00 1.40997704e+00 1.36407554e+00\n", 104 | " 1.00201186e+00]\n", 105 | " [ 1.27679086e+01 1.03890052e+01 1.44699134e+01 1.15244329e+01\n", 106 | " 1.52767179e+01]\n", 107 | " [ 9.30306192e+01 1.09450375e+02 1.26777728e+02 1.28529576e+02\n", 108 | " 1.16835669e+02]\n", 109 | " [ 9.81861108e+02 8.70384204e+02 9.35531558e+02 7.98228277e+02\n", 110 | " 9.89225754e+02]\n", 111 | " [ 6.89790063e+03 8.71016058e+03 8.84778486e+03 9.21051594e+03\n", 112 | " 6.56093883e+03]\n", 113 | " [ 7.56109978e+04 7.00773298e+04 8.60611103e+04 5.63567069e+04\n", 114 | " 5.99238226e+04]\n", 115 | " [ 6.69236243e+05 6.42107210e+05 7.81638452e+05 6.32533145e+05\n", 116 | " 5.71122492e+05]\n", 117 | " [ 6.62242340e+06 5.24446290e+06 5.54750409e+06 4.68782248e+06\n", 118 | " 4.49353155e+06]\n", 119 | " [ 4.31080734e+07 4.09579660e+07 4.62891972e+07 4.60100937e+07\n", 120 | " 4.63083098e+07]\n", 121 | " [ 2.66620185e+08 4.91942550e+08 4.48597546e+08 3.42214705e+08\n", 122 | " 4.10510463e+08]]\n" 123 | ] 124 | } 125 | ], 126 | "source": [ 127 | "import numpy as np\n", 128 | "\n", 129 | "def forward(psi):\n", 130 | " m, V, _ = psi.shape\n", 131 | " \n", 132 | " alpha = np.zeros([m, V])\n", 133 | " alpha[0] = psi[0, 0, :] # assume psi[0, 0, :] := psi(*,s,1)\n", 134 | " \n", 135 | " for t in range(1, m):\n", 136 | " for i in range(V):\n", 137 | " '''\n", 138 | " for k in range(V):\n", 139 | " alpha[t, i] += alpha[t - 1, k] * psi[t, k, i]\n", 140 | " '''\n", 141 | " alpha[t, i] = np.sum(alpha[t - 1, :] * psi[t, :, i])\n", 142 | " \n", 143 | " return alpha\n", 144 | "\n", 145 | "def pro(seq, psi):\n", 146 | " m, V, _ = psi.shape\n", 147 | " alpha = forward(psi)\n", 148 | " \n", 149 | " Z = np.sum(alpha[-1])\n", 150 | " M = psi[0, 0, seq[0]]\n", 151 | " for i in range(1, m):\n", 152 | " M *= psi[i, seq[i-1], seq[i]]\n", 153 | " \n", 154 | " p = M / Z\n", 155 | " return p\n", 156 | "\n", 157 | "np.random.seed(1111)\n", 158 | "V, m = 5, 10\n", 159 | "\n", 160 | "log_psi = np.random.random([m, V, V])\n", 161 | "psi = np.exp(log_psi) # nonnegative\n", 162 | "seq = np.random.choice(V, m)\n", 163 | "\n", 164 | "alpha = forward(psi)\n", 165 | "p = pro(seq, psi)\n", 166 | "print(p)\n", 167 | "print(alpha)" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "metadata": {}, 173 | "source": [ 174 | "## 2.3 Viterbi 解码\n", 175 | "\n", 176 | "Viterbi 利用动态规划,寻找似然最大的序列。Viterbi 与前向算法非常相似,只是将求和操作替换为最大值操作。\n", 177 | "\n", 178 | "$$\n", 179 | "\\alpha(j, s) = \\underset{s_1,\\dots,s_{j-1}}{\\mathrm{max}}\\psi(s_1,\\dots,s_{j-1}, s)\n", 180 | "$$\n", 181 | "显然,$\\alpha(1, s) = \\psi(*, s_1, 1)$\n", 182 | "\n", 183 | "根据定义,我们有如下递归关系:\n", 184 | "$$\n", 185 | "\\alpha(j, s) = \\underset{s^\\prime \\in S}{\\mathrm{max}}\\ \\alpha(j-1, s^\\prime) \\cdot \\psi(s^\\prime, s, j)\n", 186 | "$$\n", 187 | "\n", 188 | "在所有 $|s|^m$ 条可能的序列中,概率最大的路径的未归一化的值为:\n", 189 | "$$\n", 190 | "\\max \\alpha(m, s)\n", 191 | "$$\n", 192 | "沿着前向推导的反方向,可以得到最优的路径,算法复杂度是 $O(m*|S|^2)$。demo 实现如下:" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 2, 198 | "metadata": { 199 | "collapsed": false 200 | }, 201 | "outputs": [ 202 | { 
203 | "name": "stdout", 204 | "output_type": "stream", 205 | "text": [ 206 | "[1, 4, 2, 4, 3, 0, 3, 0, 3, 1]\n", 207 | "[1, 4, 2, 4, 3, 0, 3, 0, 3, 1]\n" 208 | ] 209 | } 210 | ], 211 | "source": [ 212 | "def viterbi_1(psi):\n", 213 | " m, V, _ = psi.shape\n", 214 | " \n", 215 | " alpha = np.zeros([V])\n", 216 | " trans = np.ones([m, V]).astype('int') * -1\n", 217 | "\n", 218 | " alpha[:] = psi[0, 0, :] # assume psi[0, 0, :] := psi(*,s,1)\n", 219 | " \n", 220 | " for t in range(1, m):\n", 221 | " next_alpha = np.zeros([V])\n", 222 | " for i in range(V):\n", 223 | " tmp = alpha * psi[t, :, i]\n", 224 | " next_alpha[i] = np.max(tmp)\n", 225 | " trans[t, i] = np.argmax(tmp)\n", 226 | " alpha = next_alpha\n", 227 | " \n", 228 | " end = np.argmax(alpha)\n", 229 | " path = [end]\n", 230 | " for t in range(m - 1, 0, -1):\n", 231 | " cur = path[-1]\n", 232 | " pre = trans[t, cur]\n", 233 | " path.append(pre)\n", 234 | "\n", 235 | " return path[::-1]\n", 236 | "\n", 237 | "def viterbi_2(psi):\n", 238 | " m, V, _ = psi.shape\n", 239 | " \n", 240 | " alpha = np.zeros([m, V])\n", 241 | " alpha[0] = psi[0, 0, :] # assume psi[0, 0, :] := psi(*,s,1)\n", 242 | " for t in range(1, m):\n", 243 | " for i in range(V):\n", 244 | " tmp = alpha[t - 1, :] * psi[t, :, i]\n", 245 | " alpha[t, i] = np.max(tmp)\n", 246 | " \n", 247 | " end = np.argmax(alpha[-1])\n", 248 | " path = [end]\n", 249 | " for t in range(m - 1, 0, -1):\n", 250 | " cur = path[-1]\n", 251 | " pre = np.argmax(alpha[t - 1] * psi[t, :, cur])\n", 252 | " path.append(pre)\n", 253 | "\n", 254 | " return path[::-1]\n", 255 | "\n", 256 | "np.random.seed(1111)\n", 257 | "V, m = 5, 10\n", 258 | "\n", 259 | "log_psi = np.random.random([m, V, V])\n", 260 | "psi = np.exp(log_psi) # nonnegative\n", 261 | "\n", 262 | "path_1 = viterbi_1(psi)\n", 263 | "path_2 = viterbi_2(psi)\n", 264 | "print(path_1)\n", 265 | "print(path_2)" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "metadata": {}, 271 | "source": [ 272 | "## 2.4 后向算法\n", 273 | "\n", 274 | "为了训练 CRF, 我们需要计算相应的梯度。为了手动计算梯度(这也为后续优化打开大门),需要用到后向算法。\n", 275 | "\n", 276 | "定义:\n", 277 | "\n", 278 | "$$\n", 279 | "\\beta(j, s) = \\sum_{s_{j+1},\\dots, s_m} \\psi(s_{j+1},\\dots, s_m|s_j=s)\n", 280 | "$$\n", 281 | "\n", 282 | "其中,令 $\\beta(m, s) = 1$。\n", 283 | "> 可以认为序列结尾存在特殊的符号。为简单起见,不讨论结尾边界的特殊性,可以都参考前向边界的处理及参见[实现](https://github.com/DingKe/ml-tutorial/blob/master/crf/crf.py#L154)。\n", 284 | "\n", 285 | "根据定义,我们有如下递归关系:\n", 286 | "\n", 287 | "$$\n", 288 | "\\beta(j, s) = \\sum_{s^\\prime \\in S} \\beta(j+1, s^\\prime) \\cdot \\psi(s, s^\\prime, j+1)\n", 289 | "$$" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 3, 295 | "metadata": { 296 | "collapsed": false 297 | }, 298 | "outputs": [ 299 | { 300 | "name": "stdout", 301 | "output_type": "stream", 302 | "text": [ 303 | "[[ 2.95024144e+08 2.61620644e+08 3.16953747e+08 2.02959597e+08\n", 304 | " 2.51250862e+08]\n", 305 | " [ 2.73494359e+07 2.31521489e+07 3.62404054e+07 2.84752625e+07\n", 306 | " 3.38820012e+07]\n", 307 | " [ 2.92799244e+06 3.00539203e+06 4.18174216e+06 3.30814155e+06\n", 308 | " 3.45104724e+06]\n", 309 | " [ 4.40588351e+05 4.18060894e+05 3.95721271e+05 4.50117410e+05\n", 310 | " 4.38635065e+05]\n", 311 | " [ 4.51172884e+04 5.40496888e+04 4.37931199e+04 4.98898498e+04\n", 312 | " 5.04357771e+04]\n", 313 | " [ 6.50740169e+03 5.21859026e+03 5.66773856e+03 4.73895449e+03\n", 314 | " 5.79578682e+03]\n", 315 | " [ 4.83173340e+02 5.36538120e+02 6.01820173e+02 7.07538756e+02\n", 316 | " 6.54966046e+02]\n", 317 
| " [ 7.60936291e+01 7.90609361e+01 9.08681883e+01 5.80503199e+01\n", 318 | " 5.89976569e+01]\n", 319 | " [ 8.15414542e+00 7.95904764e+00 9.64664115e+00 8.69502743e+00\n", 320 | " 9.41073532e+00]\n", 321 | " [ 1.00000000e+00 1.00000000e+00 1.00000000e+00 1.00000000e+00\n", 322 | " 1.00000000e+00]]\n" 323 | ] 324 | } 325 | ], 326 | "source": [ 327 | "def backward(psi):\n", 328 | " m, V, _ = psi.shape\n", 329 | " \n", 330 | " beta = np.zeros([m, V])\n", 331 | " beta[-1] = 1\n", 332 | " \n", 333 | " for t in range(m - 2, -1, -1):\n", 334 | " for i in range(V):\n", 335 | " '''\n", 336 | " for k in range(V):\n", 337 | " beta[t, i] += beta[t + 1, k] * psi[t + 1, i, k]\n", 338 | " '''\n", 339 | " beta[t, i] = np.sum(beta[t + 1, :] * psi[t + 1, i, :])\n", 340 | " \n", 341 | " return beta\n", 342 | "\n", 343 | "np.random.seed(1111)\n", 344 | "V, m = 5, 10\n", 345 | "\n", 346 | "log_psi = np.random.random([m, V, V])\n", 347 | "psi = np.exp(log_psi) # nonnegative\n", 348 | "seq = np.random.choice(V, m)\n", 349 | "\n", 350 | "beta = backward(psi)\n", 351 | "print(beta)" 352 | ] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": {}, 357 | "source": [ 358 | "## 2.5 梯度计算\n", 359 | "$$\n", 360 | "Z = \\sum_{s_1,\\dots,s_m} \\psi(s_1,\\dots, s_m) = \\sum_{s^\\prime_{i-1} \\in S, s^\\prime_i \\in S} \\sum_{s_{i-1}=s^\\prime_{i-1}, s_i=s^\\prime_i} \\psi(s_1,\\dots, s_m) = \\sum_{s^\\prime_{i-1} \\in S, s^\\prime_i \\in S} \\alpha(i-1, s^\\prime_{i-1}) \\cdot \\beta(i, s^\\prime_i) \\cdot \\psi (s^\\prime_{i−1}, s^\\prime_i , i) \\ \\ \\ 1 < i \\le m\n", 361 | "$$\n", 362 | "\n", 363 | "对于 $i = 1$ 的边界情况:\n", 364 | "$$\n", 365 | "Z = \\sum_{s^\\prime_1 \\in S} \\beta(1, s^\\prime_i) \\cdot \\psi (*, s^\\prime_1 , 1)\n", 366 | "$$\n", 367 | "\n", 368 | "对于路径 $(s_1, \\cdots, s_m)$,\n", 369 | "$$\n", 370 | "p(s_1,\\dots,s_m|x_1,\\dots, x_m) = \\frac{\\psi(s_1,\\dots, s_m)}{Z} = \\frac{\\prod_{i=1}^m\\psi (s_{i−1}, s_i , i)}{Z} = \\frac{\\prod_{i=1}^m\\psi^i_{s_{i−1}, s_i}}{Z}\n", 371 | "$$\n", 372 | "其中,$\\psi^i_{s^\\prime, s} = \\psi(s^\\prime, s, i),\\ s^\\prime, s \\in S$。\n", 373 | "\n", 374 | "记分子 $\\prod_{i=1}^m\\psi (s_{i−1}, s_i , i) = M$ 则:\n", 375 | "$$\n", 376 | "\\frac{\\partial p(s_1,\\dots,s_m|x_1,\\dots, x_m)}{\\partial \\psi^k_{s^\\prime, s}} = \\frac{1}{Z} [ \\frac{M}{\\psi^k_{s^\\prime, s}} \\cdot \\delta_{s^\\prime = s_{k-1} \\& s = s_k} - p \\cdot \\alpha(k-1, s^\\prime) \\cdot \\beta(k, s)]\n", 377 | "$$\n", 378 | "\n", 379 | "其中,$\\delta_{true} = 1, \\delta_{false} = 0$。\n", 380 | "\n", 381 | "$$\n", 382 | "\\frac{\\partial \\ln p(s_1,\\dots,s_m|x_1,\\dots, x_m)}{\\partial \\psi^k_{s^\\prime, s}} = \\frac{1}{p} \\cdot \\frac{\\partial p(s_1,\\dots,s_m|x_1,\\dots, x_m)}{\\partial \\psi^k_{s^\\prime, s}} = \\frac{\\delta_{s^\\prime = s_{k-1} \\& s = s_k}}{\\psi^k_{s^\\prime, s}} - \\frac{1}{Z} \\alpha(k-1, s^\\prime) \\cdot \\beta(k, s)\n", 383 | "$$\n" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": 4, 389 | "metadata": { 390 | "collapsed": false 391 | }, 392 | "outputs": [ 393 | { 394 | "name": "stdout", 395 | "output_type": "stream", 396 | "text": [ 397 | "[[ 0.75834232 -0.13348772 -0.16172055 -0.10355687 -0.12819671]\n", 398 | " [ 0. 0. 0. 0. 0. ]\n", 399 | " [ 0. 0. 0. 0. 0. ]\n", 400 | " [ 0. 0. 0. 0. 0. ]\n", 401 | " [ 0. 0. 0. 0. 0. 
]]\n" 402 | ] 403 | } 404 | ], 405 | "source": [ 406 | "def gradient(seq, psi):\n", 407 | " m, V, _ = psi.shape\n", 408 | " \n", 409 | " grad = np.zeros_like(psi)\n", 410 | " alpha = forward(psi)\n", 411 | " beta = backward(psi)\n", 412 | " \n", 413 | " Z = np.sum(alpha[-1])\n", 414 | " \n", 415 | " for t in range(1, m):\n", 416 | " for i in range(V):\n", 417 | " for j in range(V):\n", 418 | " grad[t, i, j] = -alpha[t - 1, i] * beta[t, j] / Z\n", 419 | " \n", 420 | " if i == seq[t - 1] and j == seq[t]:\n", 421 | " grad[t, i, j] += 1. / psi[t, i, j]\n", 422 | "\n", 423 | " # corner cases\n", 424 | " grad[0, 0, :] = -beta[0, :] / Z\n", 425 | " grad[0, 0, seq[0]] += 1. / psi[0, 0, seq[0]]\n", 426 | " \n", 427 | " return grad\n", 428 | " \n", 429 | "np.random.seed(1111)\n", 430 | "V, m = 5, 10\n", 431 | "\n", 432 | "log_psi = np.random.random([m, V, V])\n", 433 | "psi = np.exp(log_psi) # nonnegative\n", 434 | "seq = np.random.choice(V, m)\n", 435 | "\n", 436 | "grad = gradient(seq, psi)\n", 437 | "print(grad[0, :, :])" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": 5, 443 | "metadata": { 444 | "collapsed": false, 445 | "scrolled": true 446 | }, 447 | "outputs": [ 448 | { 449 | "name": "stdout", 450 | "output_type": "stream", 451 | "text": [ 452 | "[0 1 4 1 3 0 0 3 3 1]\n", 453 | "0.0001\n", 454 | "5e-05\n", 455 | "1.5e-05\n", 456 | "2, 1, 2, -2.22e-02, -2.22e-02, 1.55e-05\n", 457 | "4, 3, 3, -2.03e-02, -2.03e-02, 1.55e-05\n" 458 | ] 459 | } 460 | ], 461 | "source": [ 462 | "def check_grad(seq, psi, i, j, k, toleration=1e-5, delta=1e-10):\n", 463 | " m, V, _ = psi.shape\n", 464 | " \n", 465 | " grad_1 = gradient(seq, psi)[i, j, k]\n", 466 | " \n", 467 | " original = psi[i, j, k]\n", 468 | " \n", 469 | " # p1\n", 470 | " psi[i, j, k] = original - delta\n", 471 | " p1 = np.log(pro(seq, psi))\n", 472 | " \n", 473 | " # p2\n", 474 | " psi[i, j, k] = original + delta\n", 475 | " p2 = np.log(pro(seq, psi))\n", 476 | " \n", 477 | " psi[i, j, k] = original\n", 478 | " grad_2 = (p2 - p1) / (2 * delta)\n", 479 | " \n", 480 | " diff = np.abs(grad_1 - grad_2)\n", 481 | " if diff > toleration:\n", 482 | " print(\"%d, %d, %d, %.2e, %.2e, %.2e\" % (i, j, k, grad_1, grad_2, diff))\n", 483 | " \n", 484 | "np.random.seed(1111)\n", 485 | "V, m = 5, 10\n", 486 | "\n", 487 | "log_psi = np.random.random([m, V, V])\n", 488 | "psi = np.exp(log_psi) # nonnegative\n", 489 | "seq = np.random.choice(V, m)\n", 490 | "print(seq)\n", 491 | "\n", 492 | "for toleration in [1e-4, 5e-5, 1.5e-5]:\n", 493 | " print(toleration)\n", 494 | " for i in range(m):\n", 495 | " for j in range(V):\n", 496 | " for k in range(V):\n", 497 | " check_grad(seq, psi, i, j, k, toleration)" 498 | ] 499 | }, 500 | { 501 | "cell_type": "markdown", 502 | "metadata": {}, 503 | "source": [ 504 | "首先定义基本的 log 域加法操作([参见](https://github.com/DingKe/ml-tutorial/blob/master/ctc/CTC.ipynb))。" 505 | ] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "execution_count": 6, 510 | "metadata": { 511 | "collapsed": true 512 | }, 513 | "outputs": [], 514 | "source": [ 515 | "ninf = -np.float('inf')\n", 516 | "\n", 517 | "def _logsumexp(a, b):\n", 518 | " '''\n", 519 | " np.log(np.exp(a) + np.exp(b))\n", 520 | "\n", 521 | " '''\n", 522 | "\n", 523 | " if a < b:\n", 524 | " a, b = b, a\n", 525 | "\n", 526 | " if b == ninf:\n", 527 | " return a\n", 528 | " else:\n", 529 | " return a + np.log(1 + np.exp(b - a)) \n", 530 | " \n", 531 | "def logsumexp(*args):\n", 532 | " '''\n", 533 | " from scipy.special import logsumexp\n", 534 | " 
logsumexp(args)\n", 535 | " '''\n", 536 | " res = args[0]\n", 537 | " for e in args[1:]:\n", 538 | " res = _logsumexp(res, e)\n", 539 | " return res" 540 | ] 541 | }, 542 | { 543 | "cell_type": "code", 544 | "execution_count": 7, 545 | "metadata": { 546 | "collapsed": false 547 | }, 548 | "outputs": [ 549 | { 550 | "name": "stdout", 551 | "output_type": "stream", 552 | "text": [ 553 | "3.03719722983e-14\n", 554 | "0.0\n" 555 | ] 556 | } 557 | ], 558 | "source": [ 559 | "def forward_log(log_psi):\n", 560 | " m, V, _ = log_psi.shape\n", 561 | " \n", 562 | " log_alpha = np.ones([m, V]) * ninf\n", 563 | " log_alpha[0] = log_psi[0, 0, :] # assume psi[0, 0, :] := psi(*,s,1)\n", 564 | " \n", 565 | " for t in range(1, m):\n", 566 | " for i in range(V):\n", 567 | " for j in range(V):\n", 568 | " log_alpha[t, j] = logsumexp(log_alpha[t, j], log_alpha[t - 1, i] + log_psi[t, i, j])\n", 569 | " \n", 570 | " return log_alpha\n", 571 | "\n", 572 | "def pro_log(seq, log_psi):\n", 573 | " m, V, _ = log_psi.shape\n", 574 | " log_alpha = forward_log(log_psi)\n", 575 | " \n", 576 | " log_Z = logsumexp(*[e for e in log_alpha[-1]])\n", 577 | " log_M = log_psi[0, 0, seq[0]]\n", 578 | " for i in range(1, m):\n", 579 | " log_M = log_M + log_psi[i, seq[i - 1], seq[i]]\n", 580 | " \n", 581 | " log_p = log_M - log_Z\n", 582 | " return log_p\n", 583 | "\n", 584 | "np.random.seed(1111)\n", 585 | "V, m = 5, 10\n", 586 | "\n", 587 | "log_psi = np.random.random([m, V, V])\n", 588 | "psi = np.exp(log_psi) # nonnegative\n", 589 | "seq = np.random.choice(V, m)\n", 590 | "\n", 591 | "alpha = forward(psi)\n", 592 | "log_alpha = forward_log(log_psi)\n", 593 | "print(np.sum(np.abs(np.log(alpha) - log_alpha)))\n", 594 | "\n", 595 | "p = pro(seq, psi)\n", 596 | "log_p = pro_log(seq, log_psi)\n", 597 | "print(np.sum(np.abs(np.log(p) - log_p)))" 598 | ] 599 | }, 600 | { 601 | "cell_type": "code", 602 | "execution_count": 8, 603 | "metadata": { 604 | "collapsed": false 605 | }, 606 | "outputs": [ 607 | { 608 | "name": "stdout", 609 | "output_type": "stream", 610 | "text": [ 611 | "1.46851337579e-06\n", 612 | "1.86517468137e-14\n" 613 | ] 614 | } 615 | ], 616 | "source": [ 617 | "def backward_log(log_psi):\n", 618 | " m, V, _ = log_psi.shape\n", 619 | " \n", 620 | " log_beta = np.ones([m, V]) * ninf\n", 621 | " log_beta[-1] = 0\n", 622 | " \n", 623 | " for t in range(m - 2, -1, -1):\n", 624 | " for i in range(V):\n", 625 | " for j in range(V):\n", 626 | " log_beta[t, i] = logsumexp(log_beta[t, i], log_beta[t + 1, j] + log_psi[t + 1, i, j])\n", 627 | " \n", 628 | " return log_beta\n", 629 | "\n", 630 | "np.random.seed(1111)\n", 631 | "V, m = 5, 10\n", 632 | "\n", 633 | "log_psi = np.random.random([m, V, V])\n", 634 | "psi = np.exp(log_psi) # nonnegative\n", 635 | "seq = np.random.choice(V, m)\n", 636 | "\n", 637 | "beta = backward(psi)\n", 638 | "log_beta = backward_log(log_psi)\n", 639 | "\n", 640 | "print(np.sum(np.abs(beta - np.exp(log_beta))))\n", 641 | "print(np.sum(np.abs(log_beta - np.log(beta))))" 642 | ] 643 | }, 644 | { 645 | "cell_type": "code", 646 | "execution_count": 9, 647 | "metadata": { 648 | "collapsed": false 649 | }, 650 | "outputs": [ 651 | { 652 | "name": "stdout", 653 | "output_type": "stream", 654 | "text": [ 655 | "[[ 0.75834232 -0.13348772 -0.16172055 -0.10355687 -0.12819671]\n", 656 | " [ 0. 0. 0. 0. 0. ]\n", 657 | " [ 0. 0. 0. 0. 0. ]\n", 658 | " [ 0. 0. 0. 0. 0. ]\n", 659 | " [ 0. 0. 0. 0. 0. ]]\n", 660 | "[[ 0.75834232 -0.13348772 -0.16172055 -0.10355687 -0.12819671]\n", 661 | " [ 0. 0. 0. 0. 0. 
]\n", 662 | " [ 0. 0. 0. 0. 0. ]\n", 663 | " [ 0. 0. 0. 0. 0. ]\n", 664 | " [ 0. 0. 0. 0. 0. ]]\n", 665 | "1.11508025036e-14\n" 666 | ] 667 | } 668 | ], 669 | "source": [ 670 | "def gradient_log(seq, log_psi):\n", 671 | " m, V, _ = log_psi.shape\n", 672 | " \n", 673 | " grad = np.zeros_like(log_psi)\n", 674 | " log_alpha = forward_log(log_psi)\n", 675 | " log_beta = backward_log(log_psi)\n", 676 | " \n", 677 | " log_Z = logsumexp(*[e for e in log_alpha[-1]])\n", 678 | " for t in range(1, m):\n", 679 | " for i in range(V):\n", 680 | " for j in range(V):\n", 681 | " grad[t, i, j] -= np.exp(log_alpha[t - 1, i] + log_beta[t, j] - log_Z)\n", 682 | " if i == seq[t - 1] and j == seq[t]:\n", 683 | " grad[t, i, j] += np.exp(-log_psi[t, i, j])\n", 684 | "\n", 685 | " # corner cases\n", 686 | " grad[0, 0, :] -= np.exp(log_beta[0, :] - log_Z)\n", 687 | " grad[0, 0, seq[0]] += np.exp(-log_psi[0, 0, seq[0]])\n", 688 | " \n", 689 | " return grad\n", 690 | " \n", 691 | "np.random.seed(1111)\n", 692 | "V, m = 5, 10\n", 693 | "\n", 694 | "log_psi = np.random.random([m, V, V])\n", 695 | "psi = np.exp(log_psi) # nonnegative\n", 696 | "seq = np.random.choice(V, m)\n", 697 | "\n", 698 | "grad_1 = gradient(seq, psi)\n", 699 | "grad_2 = gradient_log(seq, log_psi)\n", 700 | "\n", 701 | "print(grad_1[0, :, :])\n", 702 | "print(grad_2[0, :, :])\n", 703 | "print(np.sum(np.abs(grad_1 - grad_2)))" 704 | ] 705 | }, 706 | { 707 | "cell_type": "markdown", 708 | "metadata": {}, 709 | "source": [ 710 | "在 log 域, 我们一般直接计算目标函数相对与 $\\ln\\psi$ 的梯度计算公式如下:\n", 711 | "\n", 712 | "\n", 713 | "$$\n", 714 | "\\frac{\\partial \\ln p(s_1,\\dots,s_m|x_1,\\dots, x_m)}{\\partial \\ln \\psi^k_{s^\\prime, s}} =\n", 715 | "\\frac{\\partial \\ln p(s_1,\\dots,s_m|x_1,\\dots, x_m)}{\\partial \\psi^k_{s^\\prime, s}} \\cdot \\frac{\\partial \\psi^k_{s^\\prime, s}}{\\partial \\ln \\psi^k_{s^\\prime, s}} = \\delta_{s^\\prime = s_{k-1} \\& s = s_k} - \\exp(\\ln\\alpha(k-1, s^\\prime) + \\ln \\beta(k, s) - \\ln Z + \\ln\\psi^k_{s^\\prime, s})\n", 716 | "$$\n", 717 | "\n", 718 | "只需将上面的 grad_log 稍做改动即可,不再赘述。" 719 | ] 720 | }, 721 | { 722 | "cell_type": "markdown", 723 | "metadata": {}, 724 | "source": [ 725 | "# 3. 
CRF + 人工神经网络\n", 726 | "\n", 727 | "## 3.1 势函数选择\n", 728 | "\n", 729 | "\n", 730 | "目前为止,我们都假设函数已经知道,在此基础上推导 CRF 的相关计算。理论上,除了非负性的要求 ,势函数可以灵活的选择。为也便于计算和训练,CRF 中一般选择指数的形式。假设输入为 $x_1,\\dots,x_m$,则势函数定义为:\n", 731 | "\n", 732 | "$$\n", 733 | "\\psi(s^\\prime, s, i) = \\exp(w \\cdot \\phi(x_1,\\dots,x_m, s^\\prime, s, i))\n", 734 | "$$\n", 735 | "\n", 736 | "则\n", 737 | "$$\n", 738 | "\\psi(s_1,\\dots, s_m) = \\prod_{i=1}^m\\psi (s_{i−1}, s_i , i) = \\prod_{i=1}^m\\exp(w \\cdot \\phi(x_1,\\dots,x_m, s_{i-1}, s_i,i))\n", 739 | "$$\n", 740 | "\n", 741 | "其中,$\\phi(x_1,\\dots,x_m, s^\\prime, s, i) \\in R^d$ 是特征向量,$w \\in R^d$ 是参数向量。 \n", 742 | "\n", 743 | "对于线性链模型,简化势函数为:\n", 744 | "$$\n", 745 | "\\psi(s^\\prime, s, i) = t(s|s^\\prime)e(s|x_i)\n", 746 | "$$\n", 747 | "\n", 748 | "转移势函数定义为:\n", 749 | "$$\n", 750 | "t(s|s^\\prime) = \\exp(v \\cdot g(s^\\prime, s))\n", 751 | "$$\n", 752 | "\n", 753 | "发射势函数定义为:\n", 754 | "$$\n", 755 | "e(s|x_i) = \\exp(w \\cdot f(s, x_i))\n", 756 | "$$\n", 757 | "\n", 758 | "则:\n", 759 | "$$\n", 760 | "\\psi(s_1,\\dots, s_m) = \\prod_{j=1}^m\\psi (s_{j−1}, s_j , j) = \\prod_{j=1}^m t(s_j|s_{j-1})e(s | x_j) = \\prod_{j=1}^m \\exp(v \\cdot g(s_{j-1}, s_j)) \\cdot \\exp(w \\cdot f(s_j, x_j)) \n", 761 | "$$\n", 762 | "\n", 763 | "$$\n", 764 | "\\psi(s_1,\\dots, s_m) = \\exp(\\sum_{i=1}^m v \\cdot g(s_{i-1}, s_i) + \\sum_{i=1}^m w \\cdot f(s_i, x_i))\n", 765 | "$$\n", 766 | "\n", 767 | "如果我们取对数,则我们得到一个线性模型,定义:\n", 768 | "\n", 769 | "$$\n", 770 | "score_t(s|s^\\prime) = \\log t(s|s^\\prime) = v \\cdot g(s^\\prime, s)\n", 771 | "$$\n", 772 | "\n", 773 | "$$\n", 774 | "score_e(s|x_i) = \\log e(s|x_i) = w \\cdot f(s, x_i)\n", 775 | "$$\n", 776 | "\n", 777 | "则\n", 778 | "\n", 779 | "$$\n", 780 | "\\log \\psi(s_1,\\dots, s_m) = \\sum_{i=1}^m v \\cdot g(s_{i-1}, s_i) + \\sum_{i=1}^m w \\cdot f(s_i, x_i) = \\sum_{i=1}^m score_t(s_{i-1}|s_i) + \\sum_{i=1}^m score_e(s_i|x_i)\n", 781 | "$$\n", 782 | "\n", 783 | "具体的,可以定义\n", 784 | "$$\n", 785 | "score_t(s_j|s_i) = P_{ij}\n", 786 | "$$\n", 787 | "其中,$P$ 是 $|S| \\times |S|$ 的转移矩阵。\n", 788 | "\n", 789 | "如果 $x = (x_1,\\cdots,x_m) \\in R^m$,则有: \n", 790 | "\n", 791 | "$$\n", 792 | "score_e(s_j|x_i) = W_j \\cdot x_i\n", 793 | "$$\n", 794 | "其中,$W \\in R^{|s| \\times n}$ 是权重矩阵。\n", 795 | "\n", 796 | "$$\n", 797 | "\\log \\psi(s_1,\\dots, s_m) = \\sum_{i=1}^m score_t(s|s^\\prime) + \\sum_{i=1}^m score_e(s|x_i) = \\sum_{i=1}^m P_{s_{i-1} s_{i}} + \\sum_{i=1}^m W_{s_i} \\cdot x_i\n", 798 | "$$\n", 799 | "\n", 800 | "这里,为简单起见,我们令 $x_i$ 是一个标量,实际中 $x_i$ 往往是向量。\n", 801 | "从 $x$ 到 $\\log\\psi$ 再到 $\\psi$ 都是可导的操作(四则运算和指数、对数运算),而 $\\psi$ 的梯度我们上面已经推导可以求得。因此,我们可以利用误差反传计算 $W$ 等参数的梯度,从而利用 SGD 等优化方法训练包括 CRF 在内的整个模型的参数。" 802 | ] 803 | }, 804 | { 805 | "cell_type": "code", 806 | "execution_count": 10, 807 | "metadata": { 808 | "collapsed": false 809 | }, 810 | "outputs": [ 811 | { 812 | "name": "stdout", 813 | "output_type": "stream", 814 | "text": [ 815 | "[ 0.03394788 -0.11666261 0.02592661 0.07931277 0.02549323 0.11371901\n", 816 | " 0.02198856]\n", 817 | "[-0.62291675 -0.38050215 -0.18983737 -0.65300231 1.84625859]\n", 818 | "[[-0.34655117 -0.27314013 -0.16800195 -0.28352514 0.73359469]\n", 819 | " [-0.22747135 -0.2967193 -0.27009443 -0.2664594 0.87349324]\n", 820 | " [-0.27906702 -0.27747362 -0.33689934 -0.18786182 0.82788735]\n", 821 | " [-0.2701056 -0.16940564 -0.2624276 -0.29133856 -0.25558298]\n", 822 | " [ 0.72105085 0.86080584 0.76931185 -0.2103895 -0.11362927]]\n", 823 | "[-0.17736447 -0.21489701 -0.20747999 -0.19735031 0.79709179]\n" 824 | ] 
825 | } 826 | ], 827 | "source": [ 828 | "def score(seq, x, W, P, S):\n", 829 | " m = len(seq)\n", 830 | " V = len(W)\n", 831 | " \n", 832 | " log_psi = np.zeros([m, V, V])\n", 833 | " \n", 834 | " # corner cases\n", 835 | " for i in range(V):\n", 836 | " # emit\n", 837 | " log_psi[0, 0, i] += S[i]\n", 838 | " # transmit\n", 839 | " log_psi[0, 0, i] += x[0] * W[i]\n", 840 | " \n", 841 | " for t in range(1, m):\n", 842 | " for i in range(V):\n", 843 | " for j in range(V):\n", 844 | " # emit\n", 845 | " log_psi[t, i, j] += x[t] * W[j]\n", 846 | " # transmit\n", 847 | " log_psi[t, i, j] += P[i, j]\n", 848 | " \n", 849 | " return log_psi \n", 850 | "\n", 851 | "def gradient_param(seq, x, W, P, S):\n", 852 | " m = len(seq)\n", 853 | " V = len(W)\n", 854 | " \n", 855 | " log_psi = score(seq, x, W, P, S)\n", 856 | " \n", 857 | " grad_psi = gradient_log(seq, log_psi)\n", 858 | " grad_log_psi = np.exp(log_psi) * grad_psi\n", 859 | " \n", 860 | " grad_x = np.zeros_like(x)\n", 861 | " grad_W = np.zeros_like(W)\n", 862 | " grad_P = np.zeros_like(P)\n", 863 | " grad_S = np.zeros_like(S)\n", 864 | " \n", 865 | " # corner cases\n", 866 | " for i in range(V):\n", 867 | " # emit\n", 868 | " grad_S[i] += grad_log_psi[0, 0, i]\n", 869 | " # transmit\n", 870 | " grad_W[i] += grad_log_psi[0, 0, i] * x[0]\n", 871 | " grad_x[0] += grad_log_psi[0, 0, i] * W[i]\n", 872 | " \n", 873 | " for t in range(1, m):\n", 874 | " for i in range(V):\n", 875 | " for j in range(V):\n", 876 | " # emit\n", 877 | " grad_W[j] += grad_log_psi[t, i, j] * x[t]\n", 878 | " grad_x[t] += grad_log_psi[t, i, j] * W[j]\n", 879 | " # transmit\n", 880 | " grad_P[i, j] += grad_log_psi[t, i, j]\n", 881 | " \n", 882 | " return grad_x, grad_W, grad_P, grad_S\n", 883 | " \n", 884 | "np.random.seed(1111)\n", 885 | "V, m = 5, 7\n", 886 | "\n", 887 | "seq = np.random.choice(V, m)\n", 888 | "x = np.random.random(m)\n", 889 | "W = np.random.random(V)\n", 890 | "P = np.random.random([V, V])\n", 891 | "S = np.random.random(V)\n", 892 | "\n", 893 | "grad_x, grad_W, grad_P, grad_S = gradient_param(seq, x, W, P, S)\n", 894 | "\n", 895 | "print(grad_x)\n", 896 | "print(grad_W)\n", 897 | "print(grad_P)\n", 898 | "print(grad_S)" 899 | ] 900 | }, 901 | { 902 | "cell_type": "markdown", 903 | "metadata": {}, 904 | "source": [ 905 | "梯度正确性检验如下:" 906 | ] 907 | }, 908 | { 909 | "cell_type": "code", 910 | "execution_count": 11, 911 | "metadata": { 912 | "collapsed": false 913 | }, 914 | "outputs": [ 915 | { 916 | "name": "stdout", 917 | "output_type": "stream", 918 | "text": [ 919 | "Check X\n", 920 | "1, 6.75e-02, 6.75e-02, 5.74e-05\n", 921 | "2, 5.14e-01, 5.14e-01, 1.47e-05\n", 922 | "3, -3.17e-01, -3.17e-01, 1.51e-05\n", 923 | "5, -6.42e-02, -6.42e-02, 7.82e-05\n", 924 | "8, -4.38e-02, -4.38e-02, 1.08e-04\n", 925 | "Check W\n", 926 | "0, -6.55e-01, -6.55e-01, 1.13e-05\n", 927 | "2, -1.33e-03, -1.33e-03, 3.77e-04\n", 928 | "3, 5.88e-02, 5.89e-02, 1.15e-04\n", 929 | "Check P\n", 930 | "0, -4.50e-01, -4.51e-01, 1.03e-05\n", 931 | "0, -2.70e-01, -2.70e-01, 2.53e-05\n", 932 | "1, -2.11e-01, -2.11e-01, 3.13e-05\n", 933 | "1, -2.35e-01, -2.35e-01, 1.80e-05\n", 934 | "2, -2.93e-01, -2.93e-01, 1.76e-05\n", 935 | "2, -1.50e-01, -1.50e-01, 2.15e-05\n", 936 | "2, -1.72e-01, -1.72e-01, 3.40e-05\n", 937 | "2, -3.48e-01, -3.48e-01, 1.02e-05\n", 938 | "3, -1.90e-01, -1.90e-01, 3.10e-05\n", 939 | "3, -3.60e-01, -3.60e-01, 1.78e-05\n", 940 | "4, 5.47e-01, 5.47e-01, 1.50e-05\n", 941 | "Check S\n", 942 | "0, -2.02e-01, -2.02e-01, 2.13e-05\n", 943 | "1, -1.97e-01, -1.97e-01, 
1.82e-05\n", 944 | "2, -1.05e-01, -1.05e-01, 6.22e-05\n" 945 | ] 946 | } 947 | ], 948 | "source": [ 949 | "def check_grad(seq, x, W, P, S, toleration=1e-5, delta=1e-10):\n", 950 | " m, V, _ = psi.shape\n", 951 | " \n", 952 | " grad_x, grad_W, grad_P, grad_S = gradient_param(seq, x, W, P, S)\n", 953 | "\n", 954 | " def llk(seq, x, W, P, S):\n", 955 | " log_psi = score(seq, x, W, P, S)\n", 956 | " spi = np.exp(log_psi)\n", 957 | " log_p = np.log(pro(seq, spi))\n", 958 | " return log_p\n", 959 | " \n", 960 | " # grad_x\n", 961 | " print('Check X')\n", 962 | " for i in range(len(x)):\n", 963 | " original = x[i]\n", 964 | " grad_1 = grad_x[i]\n", 965 | " \n", 966 | " # p1\n", 967 | " x[i] = original - delta\n", 968 | " p1 = llk(seq, x, W, P, S)\n", 969 | " \n", 970 | " # p2\n", 971 | " x[i] = original + delta\n", 972 | " p2 = llk(seq, x, W, P, S)\n", 973 | " \n", 974 | " x[i] = original\n", 975 | " grad_2 = (p2 - p1) / (2 * delta)\n", 976 | " \n", 977 | " diff = np.abs(grad_1 - grad_2) / np.abs(grad_2)\n", 978 | " if diff > toleration:\n", 979 | " print(\"%d, %.2e, %.2e, %.2e\" % (i, grad_1, grad_2, diff))\n", 980 | " \n", 981 | " # grad_W\n", 982 | " print('Check W')\n", 983 | " for i in range(len(W)):\n", 984 | " original = W[i]\n", 985 | " grad_1 = grad_W[i]\n", 986 | " \n", 987 | " # p1\n", 988 | " W[i] = original - delta\n", 989 | " p1 = llk(seq, x, W, P, S)\n", 990 | " \n", 991 | " # p2\n", 992 | " W[i] = original + delta\n", 993 | " p2 = llk(seq, x, W, P, S)\n", 994 | " \n", 995 | " W[i] = original\n", 996 | " grad_2 = (p2 - p1) / (2 * delta)\n", 997 | " \n", 998 | " diff = np.abs(grad_1 - grad_2) / np.abs(grad_2)\n", 999 | " if diff > toleration:\n", 1000 | " print(\"%d, %.2e, %.2e, %.2e\" % (i, grad_1, grad_2, diff))\n", 1001 | " \n", 1002 | " # grad_P\n", 1003 | " print('Check P')\n", 1004 | " for i in range(V):\n", 1005 | " for j in range(V):\n", 1006 | " original = P[i][j]\n", 1007 | " grad_1 = grad_P[i][j]\n", 1008 | " \n", 1009 | " # p1\n", 1010 | " P[i][j] = original - delta\n", 1011 | " p1 = llk(seq, x, W, P, S)\n", 1012 | " \n", 1013 | " # p2\n", 1014 | " P[i][j] = original + delta\n", 1015 | " p2 = llk(seq, x, W, P, S)\n", 1016 | " \n", 1017 | " P[i][j] = original\n", 1018 | " grad_2 = (p2 - p1) / (2 * delta)\n", 1019 | " \n", 1020 | " diff = np.abs(grad_1 - grad_2) / np.abs(grad_2)\n", 1021 | " if diff > toleration:\n", 1022 | " print(\"%d, %.2e, %.2e, %.2e\" % (i, grad_1, grad_2, diff))\n", 1023 | " \n", 1024 | " # grad_S\n", 1025 | " print('Check S')\n", 1026 | " for i in range(len(S)):\n", 1027 | " original = S[i]\n", 1028 | " grad_1 = grad_S[i]\n", 1029 | " \n", 1030 | " # p1\n", 1031 | " S[i] = original - delta\n", 1032 | " p1 = llk(seq, x, W, P, S)\n", 1033 | " \n", 1034 | " # p2\n", 1035 | " S[i] = original + delta\n", 1036 | " p2 = llk(seq, x, W, P, S)\n", 1037 | " \n", 1038 | " S[i] = original\n", 1039 | " grad_2 = (p2 - p1) / (2 * delta)\n", 1040 | " \n", 1041 | " diff = np.abs(grad_1 - grad_2) / np.abs(grad_2)\n", 1042 | " if diff > toleration:\n", 1043 | " print(\"%d, %.2e, %.2e, %.2e\" % (i, grad_1, grad_2, diff))\n", 1044 | " \n", 1045 | "np.random.seed(1111)\n", 1046 | "V, m = 5, 10\n", 1047 | "\n", 1048 | "seq = np.random.choice(V, m)\n", 1049 | "x = np.random.random(m)\n", 1050 | "W = np.random.random(V)\n", 1051 | "P = np.random.random([V, V])\n", 1052 | "S = np.random.random(V)\n", 1053 | "\n", 1054 | "check_grad(seq, x, W, P, S)\n" 1055 | ] 1056 | }, 1057 | { 1058 | "cell_type": "markdown", 1059 | "metadata": {}, 1060 | "source": [ 1061 | "## 3.2 Bi-LSTM + 
CRF\n", 1062 | "CRF 是强大的序列学习准则。配合双向循环神经网络(e.g. Bi-LSTM)的特征表征和学习能力,在许多序列学习任务上都取得了领先的结果[5~7]。\n", 1063 | "\n", 1064 | "基本模型如下:\n", 1065 | "![](https://ss1.bdstatic.com/70cFuXSh_Q1YnxGkpoWK1HF6hhy/it/u=2428423495,1906169408&fm=27&gp=0.jpg)\n", 1066 | "**图4. Bi-LSTM CRF 模型**\n", 1067 | "\n", 1068 | "Bi-LSTM 对整个输入序列进行特征提取和建模,用非线性的模型建模发射得分;转移得分用另外的 $P$ 表示,作为 CRF 自身的参数。相对于常规的用于神经网络训练的目标函数,CRF 是带参数的损失函数。\n", 1069 | "\n", 1070 | "基于 pytorch 的 CRFLoss 实现见 [repo](https://github.com/DingKe/ml-tutorial/blob/master/crf/crf.py) 以及[3, 4],BiLSTM + CRF 的实现应用见[8]。" 1071 | ] 1072 | }, 1073 | { 1074 | "cell_type": "markdown", 1075 | "metadata": {}, 1076 | "source": [ 1077 | "# References\n", 1078 | "\n", 1079 | "1. Sutton and McCuallum. [An Introduction to Conditional Random Fields](http://homepages.inf.ed.ac.uk/csutton/publications/crftut-fnt.pdf).\n", 1080 | "2. Michael Collins.[The Forward-Backward Algorithm](http://www.cs.columbia.edu/~mcollins/fb.pdf).\n", 1081 | "3. [Pytorch CRF Forward and Viterbi Implementation](http://pytorch.org/tutorials/beginner/nlp/advanced_tutorial.html).\n", 1082 | "4. [BiLSTM-CRF on PyTorch](https://github.com/kaniblu/pytorch-bilstmcrf).\n", 1083 | "5. Collobert. [Deep Learning for Efficient Discriminative Parsing](http://ronan.collobert.com/pub/matos/2011_parsing_aistats.pdf).\n", 1084 | "6. Collobert et al. [Natural Language Processing (Almost) from Scratch](http://ronan.collobert.com/pub/matos/2011_nlp_jmlr.pdf).\n", 1085 | "7. Huang et al. [Bidirectional LSTM-CRF Models for Sequence Tagging](https://arxiv.org/abs/1508.01991).\n", 1086 | "8. [Bi-LSTM-CRF for NLP](https://github.com/UKPLab/emnlp2017-bilstm-cnn-crf)." 1087 | ] 1088 | } 1089 | ], 1090 | "metadata": { 1091 | "anaconda-cloud": {}, 1092 | "kernelspec": { 1093 | "display_name": "Python [Root]", 1094 | "language": "python", 1095 | "name": "Python [Root]" 1096 | }, 1097 | "language_info": { 1098 | "codemirror_mode": { 1099 | "name": "ipython", 1100 | "version": 2 1101 | }, 1102 | "file_extension": ".py", 1103 | "mimetype": "text/x-python", 1104 | "name": "python", 1105 | "nbconvert_exporter": "python", 1106 | "pygments_lexer": "ipython2", 1107 | "version": "2.7.12" 1108 | } 1109 | }, 1110 | "nbformat": 4, 1111 | "nbformat_minor": 0 1112 | } 1113 | -------------------------------------------------------------------------------- /crf/crf.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Pytorch CRF 3 | ''' 4 | import torch 5 | from torch.autograd import Function, Variable 6 | import torch.nn as nn 7 | 8 | 9 | class CRFLoss(nn.Module): 10 | 11 | def __init__(self, vocab_size, start=False, end=False, size_average=True): 12 | super(CRFLoss, self).__init__() 13 | self.vocab_size = vocab_size 14 | self.start = start 15 | self.end = end 16 | self.size_average = size_average 17 | 18 | self.P = nn.Parameter(torch.Tensor(vocab_size, vocab_size)) 19 | 20 | if self.start: 21 | self.S = nn.Parameter(torch.Tensor(vocab_size)) 22 | else: 23 | self.register_parameter('S', None) 24 | 25 | if self.end: 26 | self.E = nn.Parameter(torch.Tensor(vocab_size)) 27 | else: 28 | self.register_parameter('E', None) 29 | 30 | def reset_parameters(self): 31 | nn.init.normal(self.P.data, 0, 1) 32 | if self.S is not None: 33 | nn.init.constant(self.S, 0) 34 | if self.E is not None: 35 | nn.init.constant(self.E, 0) 36 | 37 | def forward(self, logits, labels, auto_grad=True): 38 | if auto_grad: 39 | return self._forward(logits, labels) 40 | else: 41 | return crf(logits, labels, self.P, self.S, 42 | self.E, 
self.size_average) 43 | 44 | def _forward(self, logits, labels): 45 | batch_size, seq_len, voc = logits.size() 46 | log_alpha = Variable(logits.data.new(batch_size, voc).fill_(0)) 47 | 48 | if self.S is not None: 49 | log_alpha = log_alpha + self.S.unsqueeze(0).expand(batch_size, voc) 50 | log_alpha = log_alpha + logits[:, 0, :] 51 | 52 | for t in range(1, seq_len): 53 | trans = self.P.unsqueeze(0).expand( 54 | batch_size, voc, voc) # transmit score 55 | emit = logits[:, t, :].unsqueeze(1).expand( 56 | batch_size, voc, voc) # emit score 57 | log_alpha_tm1 = log_alpha.unsqueeze(2).expand(batch_size, voc, voc) 58 | log_alpha = reduce_logsumexp(trans + emit + log_alpha_tm1, dim=1) 59 | 60 | if self.E is not None: 61 | log_Z = reduce_logsumexp( 62 | log_alpha + self.E.unsqueeze(0).expand(batch_size, voc), dim=1) 63 | else: 64 | log_Z = reduce_logsumexp(log_alpha, dim=1) 65 | 66 | # score for y 67 | labels_l = labels[:, :-1] 68 | expanded_size = labels_l.size() + (voc,) 69 | labels_l = labels_l.unsqueeze(-1).expand(expanded_size) 70 | 71 | labels_r = labels[:, 1:] 72 | labels_r = labels_r.unsqueeze(-1) 73 | 74 | P_row = self.P.unsqueeze(0).expand( 75 | batch_size, voc, voc).gather(1, labels_l) 76 | y_transmit_score = P_row.gather(2, labels_r).squeeze(-1) 77 | y_emit_score = logits.gather(2, labels.unsqueeze(2)).squeeze(-1) 78 | 79 | log_M = torch.sum(y_emit_score, dim=1) + \ 80 | torch.sum(y_transmit_score, dim=1) 81 | 82 | if self.S is not None: 83 | log_M = log_M + self.S.gather(0, labels[:, 0]) 84 | 85 | if self.E is not None: 86 | log_M = log_M + self.E.gather(0, labels[:, -1]) 87 | 88 | # negative likelihood 89 | nll = log_Z - log_M 90 | nll = nll.sum(0).view(1) 91 | 92 | if self.size_average: 93 | nll.div_(batch_size) 94 | 95 | return nll 96 | 97 | 98 | def reduce_logsumexp(input, dim=None): 99 | if dim is None: 100 | max_val = torch.max(input) 101 | ret = max_val + torch.log(torch.sum(torch.exp(input - max_val))) 102 | return ret 103 | else: 104 | max_val, _ = torch.max(input, dim=dim, keepdim=True) 105 | ret = max_val.squeeze(dim=dim) + \ 106 | torch.log(torch.sum(torch.exp(input - max_val), dim=dim)) 107 | return ret 108 | 109 | 110 | def logsumexp(a, b): 111 | max_val = torch.max(a, b) 112 | 113 | dtype = a.type() 114 | tmp = (b - a) * (a > b).type(dtype) + (a - b) * (a <= b).type(dtype) 115 | 116 | return max_val + torch.log1p(torch.exp(tmp)) 117 | 118 | 119 | def one_hot(size, index): 120 | ''' 121 | voc = size[-1] 122 | ret = index.expand(*size) == \ 123 | torch.arange(0, voc).type(torch.LongTensor).unsqueeze(0).expand(*size) 124 | return ret 125 | ''' 126 | mask = torch.LongTensor(*size).fill_(0) 127 | ret = mask.scatter_(1, index, 1) 128 | return ret 129 | 130 | 131 | def _crf_forward(logits, P, S=None, E=None): 132 | batch_size, seq_len, voc = logits.size() 133 | log_alpha = logits.new(batch_size, seq_len, voc).fill_(0) 134 | 135 | if S is not None: 136 | log_alpha[:, 0, :] = S.unsqueeze(0).expand(batch_size, voc) 137 | else: 138 | log_alpha[:, 0, :] = 0 139 | log_alpha[:, 0, :] += logits[:, 0, :] 140 | 141 | for t in range(1, seq_len): 142 | trans = P.unsqueeze(0).expand(batch_size, voc, voc) # transmit score 143 | emit = logits[:, t, :].unsqueeze(1).expand( 144 | batch_size, voc, voc) # emit score 145 | 146 | log_alpha_tm1 = log_alpha[:, t - 1, 147 | :].unsqueeze(2).expand(batch_size, voc, voc) 148 | log_alpha[:, t, :] = reduce_logsumexp( 149 | trans + emit + log_alpha_tm1, dim=1) 150 | 151 | return log_alpha 152 | 153 | 154 | def _crf_backward(logits, P, S=None, E=None): 155 
| batch_size, seq_len, voc = logits.size() 156 | log_beta = logits.new(batch_size, seq_len, voc) 157 | 158 | if E is not None: 159 | log_beta[:, -1, :] = E.unsqueeze(0).expand(batch_size, voc) 160 | else: 161 | log_beta[:, -1, :] = 0 162 | 163 | for t in range(seq_len - 2, -1, -1): 164 | trans = P.unsqueeze(0).expand( 165 | batch_size, voc, voc) # transmit score 166 | emit = logits[:, t + 1, :].unsqueeze(1).expand( 167 | batch_size, voc, voc) # emit score 168 | log_beta_tp1 = log_beta[:, t + 1, 169 | :].unsqueeze(1).expand(batch_size, voc, voc) 170 | 171 | log_beta[:, t, :] = reduce_logsumexp( 172 | trans + emit + log_beta_tp1, dim=2) 173 | 174 | return log_beta 175 | 176 | 177 | class CRFF(Function): 178 | ''' 179 | ''' 180 | @staticmethod 181 | def forward(ctx, logits, labels, P, S=None, E=None, size_average=True): 182 | batch_size, seq_len, voc = logits.size() 183 | 184 | ctx.size_average = size_average 185 | ctx.log_alpha = log_alpha = _crf_forward(logits, P, S, E) 186 | ctx.S = S 187 | ctx.E = E 188 | 189 | # norm 190 | if E is not None: 191 | ctx.log_Z = log_Z = reduce_logsumexp( 192 | log_alpha[:, -1, :] + 193 | E.unsqueeze(0).expand(batch_size, voc), dim=1) 194 | else: 195 | ctx.log_Z = log_Z = reduce_logsumexp(log_alpha[:, -1, :], dim=1) 196 | 197 | # score for y 198 | labels_l = labels[:, :-1] 199 | expanded_size = labels_l.size() + (voc,) 200 | labels_l = labels_l.unsqueeze(-1).expand(expanded_size) 201 | 202 | labels_r = labels[:, 1:] 203 | labels_r = labels_r.unsqueeze(-1) 204 | 205 | P_row = P.unsqueeze(0).expand(batch_size, voc, voc).gather(1, labels_l) 206 | y_trans = P_row.gather(2, labels_r).squeeze(-1) 207 | y_emit = logits.gather(2, labels.unsqueeze(2)).squeeze(-1) 208 | 209 | log_M = torch.sum(y_emit, dim=1) + torch.sum(y_trans, dim=1) 210 | 211 | if S is not None: 212 | log_M += S.gather(0, labels[:, 0]) 213 | 214 | if E is not None: 215 | log_M += E.gather(0, labels[:, -1]) 216 | 217 | # negative likelihood 218 | nll = log_Z - log_M 219 | nll = nll.sum(0).view(1) 220 | 221 | if size_average: 222 | nll.div_(batch_size) 223 | 224 | ctx.save_for_backward(logits, labels, P) 225 | 226 | return nll 227 | 228 | @staticmethod 229 | def backward(ctx, output_grad): 230 | logits, labels, P = ctx.saved_variables 231 | logits = logits.data 232 | labels = labels.data 233 | P = P.data 234 | S, E = ctx.S, ctx.E 235 | 236 | batch_size, seq_len, voc = logits.size() 237 | dtype = output_grad.data.type() 238 | 239 | log_alpha = ctx.log_alpha 240 | log_beta = _crf_backward(logits, P, S, E) 241 | log_Z = ctx.log_Z.unsqueeze(-1) 242 | 243 | # storage for gradients 244 | logits_grad = Variable(logits.new(logits.size()).fill_(0)) 245 | P_grad = Variable(P.new(P.size()).fill_(0)) 246 | if S is not None: 247 | S_grad = Variable(S.new(S.size()).fill_(0)) 248 | else: 249 | S_grad = None 250 | if E is not None: 251 | E_grad = Variable(E.new(E.size()).fill_(0)) 252 | else: 253 | E_grad = None 254 | 255 | # end boundary 256 | if E_grad is not None: 257 | log_psi = E.unsqueeze(0).expand(batch_size, voc) 258 | delta = one_hot([batch_size, voc], labels[:, -1:]).type(dtype) 259 | delta_log_psi = torch.exp( 260 | log_alpha[:, -1, :] - log_Z + log_psi) - delta 261 | 262 | E_grad.data += delta_log_psi.sum(0) 263 | 264 | # normal cases 265 | for t in range(1, seq_len): 266 | for i in range(voc): 267 | emit = logits[:, t, :] 268 | trans = P[i, :].unsqueeze(0) 269 | log_psi = emit + trans 270 | 271 | left = (labels[:, t - 1] == 272 | i).unsqueeze(-1).expand(batch_size, voc) 273 | right = labels[:, 
t].unsqueeze(-1).expand(batch_size, voc) ==\ 274 | torch.arange(0, voc).type(torch.LongTensor).\ 275 | unsqueeze(0).expand(batch_size, voc) 276 | delta = (left * right).type(dtype) 277 | 278 | delta_log_psi = torch.exp(log_alpha[:, t - 1, i:i + 1] + 279 | log_beta[:, t, :] - 280 | log_Z + log_psi) - delta 281 | 282 | logits_grad.data[:, t, :] += delta_log_psi 283 | P_grad.data[i, :] += delta_log_psi.sum(0) 284 | 285 | # start boundary 286 | log_psi = logits[:, 0, :] + \ 287 | (S.unsqueeze(0) if S is not None else 0) 288 | delta = one_hot([batch_size, voc], labels[:, :1]).type(dtype) 289 | delta_log_psi = torch.exp(log_beta[:, 0, :] - log_Z + log_psi) - delta 290 | logits_grad.data[:, 0, :] += delta_log_psi 291 | if S_grad is not None: 292 | S_grad.data += delta_log_psi.sum(0) 293 | 294 | if ctx.size_average: 295 | logits_grad.data.div_(batch_size) 296 | P_grad.data.div_(batch_size) 297 | if S_grad is not None: 298 | S_grad.data.div_(batch_size) 299 | if E_grad is not None: 300 | E_grad.data.div_(batch_size) 301 | 302 | return logits_grad, None, P_grad, S_grad, E_grad, None 303 | 304 | 305 | crf = CRFF.apply 306 | 307 | 308 | def test_crf_forward(): 309 | import numpy as np 310 | 311 | logits = torch.Tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]) 312 | P = torch.Tensor(np.ones([3, 3])) 313 | S = torch.Tensor([1, 2, 3]) 314 | E = None 315 | 316 | log_alpha = _crf_forward(logits, P, S, E) 317 | print(log_alpha) 318 | 319 | 320 | def test_crf_backward(): 321 | import numpy as np 322 | 323 | logits = torch.Tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]) 324 | P = torch.Tensor(np.ones([3, 3])) 325 | S = None 326 | E = torch.Tensor([1, 2, 3]) 327 | E = None 328 | 329 | log_beta = _crf_backward(logits, P, S, E) 330 | print(log_beta) 331 | 332 | 333 | def test_forward(): 334 | batch_size = 3 335 | seq_len = 10 336 | voc = 7 337 | 338 | torch.manual_seed(1111) 339 | 340 | labels = torch.multinomial(torch.rand( 341 | batch_size, voc), seq_len, replacement=True) 342 | logits = torch.rand(batch_size, seq_len, voc) 343 | P = torch.randn(voc, voc) * 0.1 344 | S = torch.randn(voc) * 0.1 345 | E = torch.randn(voc) * 0.1 346 | 347 | logits = Variable(logits, requires_grad=True) 348 | labels = Variable(labels, requires_grad=True) 349 | P = Variable(P, requires_grad=True) 350 | S = Variable(S, requires_grad=True) 351 | E = Variable(E, requires_grad=True) 352 | 353 | nll = CRFF.apply(logits, labels, P, S, E) 354 | print(nll) 355 | 356 | 357 | def test_backward(): 358 | batch_size = 3 359 | seq_len = 10 360 | voc = 7 361 | 362 | torch.manual_seed(1111) 363 | 364 | labels = torch.multinomial(torch.rand( 365 | batch_size, voc), seq_len, replacement=True) 366 | logits = torch.rand(batch_size, seq_len, voc) 367 | P = torch.randn(voc, voc) * 0.1 368 | S = torch.randn(voc) * 0.1 369 | E = torch.randn(voc) * 0.1 370 | 371 | logits = Variable(logits, requires_grad=True) 372 | labels = Variable(labels, requires_grad=True) 373 | P = Variable(P, requires_grad=True) 374 | S = Variable(S, requires_grad=True) 375 | E = Variable(E, requires_grad=True) 376 | 377 | nll = CRFF.apply(logits, labels, P, S, E) 378 | nll.backward() 379 | 380 | print(logits.grad.data) 381 | print(P.grad.data) 382 | print(S.grad.data) 383 | print(E.grad.data) 384 | 385 | 386 | def test_grad(): 387 | import numpy as np 388 | 389 | batch_size = 3 390 | seq_len = 10 391 | voc = 7 392 | 393 | torch.manual_seed(1111) 394 | 395 | labels = torch.multinomial(torch.rand( 396 | batch_size, voc), seq_len, replacement=True) 397 | logits = torch.rand(batch_size, seq_len, 
voc) 398 | emit_score = torch.rand(batch_size, seq_len, voc) 399 | P = torch.randn(voc, voc) * 0.1 400 | S = torch.randn(voc) * 0.1 401 | E = torch.randn(voc) * 0.1 402 | 403 | logits = Variable(logits, requires_grad=True) 404 | labels = Variable(labels, requires_grad=True) 405 | P = Variable(P, requires_grad=True) 406 | S = Variable(S, requires_grad=True) 407 | E = Variable(E, requires_grad=True) 408 | 409 | nll = CRFF.apply(logits, labels, P, S, E) 410 | nll.backward() 411 | 412 | delta = 1e-3 413 | toleration = 5e-3 414 | 415 | # P 416 | print('P') 417 | for i in range(voc): 418 | for j in range(voc): 419 | V = P.data.numpy() 420 | 421 | o = V[i, j] 422 | grad_1 = P.grad.data.numpy()[i, j] 423 | 424 | V[i, j] = o + delta 425 | l1 = CRFF.apply(logits, labels, P, S, E).data.numpy()[0] 426 | V[i, j] = o - delta 427 | l2 = CRFF.apply(logits, labels, P, S, E).data.numpy()[0] 428 | V[i, j] = o 429 | 430 | grad_2 = (l1 - l2) / (2 * delta) 431 | 432 | diff = np.abs((grad_1 - grad_1)) 433 | if diff > toleration: 434 | print("%.2e, %.2e, %.2e" % (grad_1, grad_2, diff)) 435 | 436 | # logits 437 | print('logits') 438 | for i in range(batch_size): 439 | for j in range(seq_len): 440 | for k in range(voc): 441 | V = logits.data.numpy() 442 | 443 | o = V[i, j, k] 444 | grad_1 = logits.grad.data.numpy()[i, j, k] 445 | 446 | V[i, j, k] = o + delta 447 | l1 = CRFF.apply(logits, labels, P, S, E).data.numpy()[0] 448 | V[i, j, k] = o - delta 449 | l2 = CRFF.apply(logits, labels, P, S, E).data.numpy()[0] 450 | V[i, j, k] = o 451 | 452 | grad_2 = (l1 - l2) / (2 * delta) 453 | 454 | diff = np.abs((grad_1 - grad_2)) 455 | if diff > toleration: 456 | print("%.2e, %.2e, %.2e" % (grad_1, grad_2, diff)) 457 | 458 | 459 | def test_module(): 460 | import numpy as np 461 | torch.manual_seed(1111) 462 | 463 | batch_size = 3 464 | seq_len = 10 465 | voc = 7 466 | 467 | labels = torch.multinomial(torch.rand( 468 | batch_size, voc), seq_len, replacement=True) 469 | logits = torch.rand(batch_size, seq_len, voc) 470 | 471 | logits = Variable(logits, requires_grad=True) 472 | labels = Variable(labels) 473 | 474 | crf_model = CRFLoss(vocab_size=voc, start=True, end=True) 475 | crf_model.reset_parameters() 476 | 477 | l1 = crf_model(logits, labels, auto_grad=False) 478 | l1.backward() 479 | grads_1 = [logits.grad.data.clone()] +\ 480 | [param.grad.data.clone() for param in crf_model.parameters()] 481 | 482 | crf_model.zero_grad() 483 | logits.grad.data.fill_(0) 484 | l2 = crf_model(logits, labels, auto_grad=True) 485 | l2.backward() 486 | grads_2 = [logits.grad.data.clone()] +\ 487 | [param.grad.data.clone() for param in crf_model.parameters()] 488 | 489 | toleration = 5e-7 490 | delta = 1e-4 491 | for g1, g2 in zip(grads_1, grads_2): 492 | g1 = g1.view(-1) 493 | g2 = g2.view(-1) 494 | for i in range(g1.size(0)): 495 | # print('%.2e, %.2e, %.2e' % (g1[i], g2[i], g1[i] - g2[i])) 496 | pass 497 | 498 | print('\nOverview:') 499 | print('%.5e, %.5e, %.15e, %.5e' % (l1.data.sum(), l2.data.sum(), 500 | (l1 - l2).data.sum(), 501 | torch.sum(torch.abs(g1 - g2)))) 502 | 503 | # logits 504 | print('logits') 505 | grad_1 = grads_1[0] 506 | grad_2 = grads_2[0] 507 | for i in range(batch_size): 508 | for j in range(seq_len): 509 | for k in range(voc): 510 | V = logits.data.numpy() 511 | 512 | o = V[i, j, k] 513 | g1 = grad_1.numpy()[i, j, k] 514 | g2 = grad_2.numpy()[i, j, k] 515 | 516 | V[i, j, k] = o + delta 517 | l1 = crf_model(logits, labels).data.sum() 518 | V[i, j, k] = o - delta 519 | l2 = crf_model(logits, labels).data.sum() 520 | 
V[i, j, k] = o 521 | 522 | g3 = (l1 - l2) / (2 * delta) 523 | 524 | diff = np.abs((g2 - g1)) 525 | if diff > toleration: 526 | print("%.2e, %.2e, %.2e, %.2e" % (g1, g2, g3, diff)) 527 | 528 | # P 529 | print('P') 530 | grad_1 = grads_1[1] 531 | grad_2 = grads_2[1] 532 | for i in range(voc): 533 | for j in range(voc): 534 | V = crf_model.P.data.numpy() 535 | 536 | o = V[i, j] 537 | g1 = grad_1.numpy()[i, j] 538 | g2 = grad_2.numpy()[i, j] 539 | 540 | V[i, j] = o + delta 541 | l1 = crf_model(logits, labels).data.sum() 542 | V[i, j] = o - delta 543 | l2 = crf_model(logits, labels).data.sum() 544 | V[i, j] = o 545 | 546 | g3 = (l1 - l2) / (2 * delta) 547 | 548 | diff = np.abs((g2 - g1)) 549 | if diff > toleration: 550 | print("%.2e, %.2e, %.2e, %.2e" % (g1, g2, g3, diff)) 551 | 552 | # S 553 | print('S') 554 | grad_1 = grads_1[2] 555 | grad_2 = grads_2[2] 556 | for i in range(voc): 557 | V = crf_model.S.data.numpy() 558 | 559 | o = V[i] 560 | g1 = grad_1.numpy()[i] 561 | g2 = grad_2.numpy()[i] 562 | 563 | V[i] = o + delta 564 | l1 = crf_model(logits, labels).data.sum() 565 | V[i] = o - delta 566 | l2 = crf_model(logits, labels).data.sum() 567 | V[i] = o 568 | 569 | g3 = (l1 - l2) / (2 * delta) 570 | 571 | diff = np.abs((g2 - g1)) 572 | if diff > toleration: 573 | print("%.2e, %.2e, %.2e, %.2e" % (g1, g2, g3, diff)) 574 | 575 | # E 576 | print('E') 577 | grad_1 = grads_1[3] 578 | grad_2 = grads_2[3] 579 | for i in range(voc): 580 | V = crf_model.E.data.numpy() 581 | 582 | o = V[i] 583 | g1 = grad_1.numpy()[i] 584 | g2 = grad_2.numpy()[i] 585 | 586 | V[i] = o + delta 587 | l1 = crf_model(logits, labels).data.sum() 588 | V[i] = o - delta 589 | l2 = crf_model(logits, labels).data.sum() 590 | V[i] = o 591 | 592 | g3 = (l1 - l2) / (2 * delta) 593 | 594 | diff = np.abs((g2 - g1)) 595 | if diff > toleration: 596 | print("%.2e, %.2e, %.2e, %.2e" % (g1, g2, g3, diff)) 597 | 598 | 599 | if __name__ == '__main__': 600 | test_crf_forward() 601 | test_crf_backward() 602 | test_forward() 603 | test_backward() 604 | test_grad() 605 | test_module() 606 | -------------------------------------------------------------------------------- /ctc/CTC.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "CTC( Connectionist Temporal Classification,连接时序分类)是一种用于序列建模的工具,其核心是定义了特殊的**目标函数/优化准则**[1]。\n", 8 | "\n", 9 | "> jupyter notebook 版见 [repo](https://github.com/DingKe/ml-tutorial/tree/master/ctc)." 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "\n", 17 | "# 1. 算法\n", 18 | "这里大体根据 Alex Grave 的开山之作[1],讨论 CTC 的算法原理,并基于 numpy 从零实现 CTC 的推理及训练算法。\n", 19 | "\n", 20 | "## 1.1 序列问题形式化。\n", 21 | "序列问题可以形式化为如下函数:\n", 22 | "\n", 23 | "$$N_w: (R^m)^T \\rightarrow (R^n)^T$$\n", 24 | "其中,序列目标为字符串(词表大小为 $n$),即 $N_w$ 输出为 $n$ 维多项概率分布(e.g. 经过 softmax 处理)。\n", 25 | "\n", 26 | "网络输出为:$y = N_w(x)$,其中,$y_k^t$ $t$ 表示时刻第 $k$ 项的概率。\n", 27 | "\n", 28 | "![](https://distill.pub/2017/ctc/assets/full_collapse_from_audio.svg)\n", 29 | "**图1. 序列建模【[src](https://distill.pub/2017/ctc/)】**\n", 30 | "\n", 31 | "\n", 32 | "虽然并没为限定 $N_w$ 具体形式,下面为假设其了某种神经网络(e.g. 
RNN)。\n", 33 | "下面代码示例 toy $N_w$:" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 1, 39 | "metadata": { 40 | "collapsed": false 41 | }, 42 | "outputs": [ 43 | { 44 | "name": "stdout", 45 | "output_type": "stream", 46 | "text": [ 47 | "[[ 0.24654511 0.18837589 0.16937668 0.16757465 0.22812766]\n", 48 | " [ 0.25443629 0.14992236 0.22945293 0.17240658 0.19378184]\n", 49 | " [ 0.24134404 0.17179604 0.23572466 0.12994237 0.22119288]\n", 50 | " [ 0.27216255 0.13054313 0.2679252 0.14184499 0.18752413]\n", 51 | " [ 0.32558002 0.13485564 0.25228604 0.09743785 0.18984045]\n", 52 | " [ 0.23855586 0.14800386 0.23100255 0.17158135 0.21085638]\n", 53 | " [ 0.38534786 0.11524603 0.18220093 0.14617864 0.17102655]\n", 54 | " [ 0.21867406 0.18511892 0.21305488 0.16472572 0.21842642]\n", 55 | " [ 0.29856607 0.13646801 0.27196606 0.11562552 0.17737434]\n", 56 | " [ 0.242347 0.14102063 0.21716951 0.2355229 0.16393996]\n", 57 | " [ 0.26597326 0.10009752 0.23362892 0.24560198 0.15469832]\n", 58 | " [ 0.23337289 0.11918746 0.28540761 0.20197928 0.16005275]]\n", 59 | "[[ 1.]\n", 60 | " [ 1.]\n", 61 | " [ 1.]\n", 62 | " [ 1.]\n", 63 | " [ 1.]\n", 64 | " [ 1.]\n", 65 | " [ 1.]\n", 66 | " [ 1.]\n", 67 | " [ 1.]\n", 68 | " [ 1.]\n", 69 | " [ 1.]\n", 70 | " [ 1.]]\n" 71 | ] 72 | } 73 | ], 74 | "source": [ 75 | "import numpy as np\n", 76 | "\n", 77 | "np.random.seed(1111)\n", 78 | "\n", 79 | "T, V = 12, 5\n", 80 | "m, n = 6, V\n", 81 | "\n", 82 | "x = np.random.random([T, m]) # T x m\n", 83 | "w = np.random.random([m, n]) # weights, m x n\n", 84 | "\n", 85 | "def softmax(logits):\n", 86 | " max_value = np.max(logits, axis=1, keepdims=True)\n", 87 | " exp = np.exp(logits - max_value)\n", 88 | " exp_sum = np.sum(exp, axis=1, keepdims=True)\n", 89 | " dist = exp / exp_sum\n", 90 | " return dist\n", 91 | "\n", 92 | "def toy_nw(x):\n", 93 | " y = np.matmul(x, w) # T x n \n", 94 | " y = softmax(y)\n", 95 | " return y\n", 96 | "\n", 97 | "y = toy_nw(x)\n", 98 | "print(y)\n", 99 | "print(y.sum(1, keepdims=True))" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "## 1.2 align-free 变长映射\n", 107 | "上面的形式是输入和输出的一对一的映射。序列学习任务一般而言是多对多的映射关系(如语音识别中,上百帧输出可能仅对应若干音节或字符,并且每个输入和输出之间,也没有清楚的对应关系)。CTC 通过引入一个特殊的 blank 字符(用 % 表示),解决多对一映射问题。\n", 108 | "\n", 109 | "扩展原始词表 $L$ 为 $L^\\prime = L \\cup \\{\\text{blank}\\}$。对输出字符串,定义操作 $B$:1)合并连续的相同符号;2)去掉 blank 字符。\n", 110 | "\n", 111 | "例如,对于 “aa%bb%%cc”,应用 $B$,则实际上代表的是字符串 \"abc\"。同理“%a%b%cc%” 也同样代表 \"abc\"。\n", 112 | "$$\n", 113 | "B(aa\\%bb\\%\\%cc) = B(\\%a\\%b\\%cc\\%) = abc\n", 114 | "$$\n", 115 | "\n", 116 | "通过引入blank 及 $B$,可以实现了变长的映射。\n", 117 | "$$\n", 118 | "L^{\\prime T} \\rightarrow L^{\\le T}\n", 119 | "$$\n", 120 | "\n", 121 | "\t因为这个原因,CTC 只能建模输出长度小于输入长度的序列问题。" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "## 1.3 似然计算\n", 129 | "和大多数有监督学习一样,CTC 使用最大似然标准进行训练。\n", 130 | "\n", 131 | "给定输入 $x$,输出 $l$ 的条件概率为:\n", 132 | "$$\n", 133 | "p(l|x) = \\sum_{\\pi \\in B^{-1}(l)} p(\\pi|x)\n", 134 | "$$\n", 135 | "\n", 136 | "其中,$B^{-1}(l)$ 表示了长度为 $T$ 且示经过 $B$ 结果为 $l$ 字符串的集合。\n", 137 | "\n", 138 | "**CTC 假设输出的概率是(相对于输入)条件独立的**,因此有:\n", 139 | "$$p(\\pi|x) = \\prod y^t_{\\pi_t}, \\forall \\pi \\in L^{\\prime T}$$\n", 140 | "\n", 141 | "\n", 142 | "然而,直接按上式我们没有办理有效的计算似然值。下面用动态规划解决似然的计算及梯度计算, 涉及前向算法和后向算法。" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "## 1.4 前向算法\n", 150 | "\n", 151 | "在前向及后向计算中,CTC 
需要将输出字符串进行扩展。具体地,在 $(a_1,\\cdots,a_m)$ 每个字符之间及首尾分别插入 blank,即扩展为 $(\\%, a_1,\\%,a_2, \\%,\\cdots,\\%, a_m,\\%)$。下面的 $l$ 为原始字符串,$l^\\prime$ 为扩展后的字符串。\n", 152 | "\n", 153 | "定义\n", 154 | "$$\n", 155 | "\\alpha_t(s) \\stackrel{def}{=} \\sum_{\\pi \\in N^T: B(\\pi_{1:t}) = l_{1:s}} \\prod_{t^\\prime=1}^t y^{t^\\prime}_{\\pi_{t^\\prime}}\n", 156 | "$$\n", 157 | "\n", 158 | "显然有,\n", 159 | "\n", 160 | "$$\n", 161 | "\\begin{align}\n", 162 | "\\alpha_1(1) = y_b^1,\\\\\n", 163 | "\\alpha_1(2) = y_{l_1}^1,\\\\\n", 164 | "\\alpha_1(s) = 0, \\forall s > 2\n", 165 | "\\end{align}\n", 166 | "$$\n", 167 | "根据 $\\alpha$ 的定义,有如下递归关系:\n", 168 | "$$\n", 169 | "\\alpha_t(s) = \\{ \\begin{array}{l}\n", 170 | "(\\alpha_{t-1}(s)+\\alpha_{t-1}(s-1)) y^t_{l^\\prime_s},\\ \\ \\ if\\ l^\\prime_s = b \\ or\\ l_{s-2}^\\prime = l_s^{\\prime} \\\\\n", 171 | "(\\alpha_{t-1}(s)+\\alpha_{t-1}(s-1) + \\alpha_{t-1}(s-2)) y^t_{l^\\prime_s} \\ \\ otherwise\n", 172 | "\\end{array}\n", 173 | "$$" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [ 180 | "### 1.4.1 Case 2\n", 181 | "递归公式中 case 2 是一般的情形。如图2所示,当 $t$ 时刻对应的字符 $l^\\prime_s$ 为非 blank(例如 a),且与 $l^\\prime_{s-2}$ 不同时,它可能由三种情况扩展而来:1)重复上一字符,即上一时刻也是 a;2)由前一个 blank 转移而来;3)跳过 blank,直接由前一个非 blank 字符转移而来($B$ 操作中 blank 最终会被去掉,因此两个不同字符之间的 blank 并不是必须的)。\n", 182 | "![](https://distill.pub/2017/ctc/assets/cost_regular.svg)\n", 183 | "**图2. 前向算法 Case 2 示例【[src](https://distill.pub/2017/ctc/)】**\n", 184 | "\n", 185 | "### 1.4.2 Case 1\n", 186 | "递归公式中 case 1 是特殊的情形。\n", 187 | "如图3所示,当 $t$ 时刻对应的字符 $l^\\prime_s$ 为 blank 时,它只能由两种情况扩展而来:1)重复上一字符,即上一时刻也是 blank;2)由前一个非 blank 字符转移而来。当 $l^\\prime_s$ 为非 blank 且 $l^\\prime_s = l^\\prime_{s-2}$ 时,类似于 case 2,但这时两个相同字符之间的 blank 不能省略(否则无法区分\"aa\"和\"a\"),因此也只有两种跳转情况。\n", 188 | "\n", 189 | "![](https://distill.pub/2017/ctc/assets/cost_no_skip.svg)\n", 190 | "**图3. 
前向算法 Case 1 【[src](https://distill.pub/2017/ctc/)】**\n", 191 | "\n", 192 | "我们可以利用动态规划计算所有 $\\alpha$ 的值,算法时间和空间复杂度为 $O(T * L)$。\n", 193 | "\n", 194 | "似然的计算只涉及乘加运算,因此,CTC 的似然是可导的,可以尝试 tensorflow 或 pytorch 等具有自动求导功能的工具自动进行梯度计算。下面介绍如何手动高效的计算梯度。" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 2, 200 | "metadata": { 201 | "collapsed": false 202 | }, 203 | "outputs": [ 204 | { 205 | "name": "stdout", 206 | "output_type": "stream", 207 | "text": [ 208 | "[[ 2.46545113e-01 1.67574654e-01 0.00000000e+00 0.00000000e+00\n", 209 | " 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", 210 | " [ 6.27300235e-02 7.13969720e-02 4.26370730e-02 0.00000000e+00\n", 211 | " 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", 212 | " [ 1.51395174e-02 1.74287803e-02 2.75214373e-02 5.54036251e-03\n", 213 | " 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", 214 | " [ 4.12040964e-03 4.61964998e-03 1.22337658e-02 4.68965079e-03\n", 215 | " 1.50787918e-03 1.03895167e-03 0.00000000e+00]\n", 216 | " [ 1.34152305e-03 8.51612635e-04 5.48713543e-03 1.64898136e-03\n", 217 | " 2.01779193e-03 1.37377693e-03 3.38261905e-04]\n", 218 | " [ 3.20028190e-04 3.76301179e-04 1.51214552e-03 1.22442454e-03\n", 219 | " 8.74730268e-04 1.06283215e-03 4.08416903e-04]\n", 220 | " [ 1.23322177e-04 1.01788478e-04 7.27708889e-04 4.00028082e-04\n", 221 | " 8.08904808e-04 5.40783712e-04 5.66942671e-04]\n", 222 | " [ 2.69673617e-05 3.70815141e-05 1.81389560e-04 1.85767281e-04\n", 223 | " 2.64362267e-04 3.82184328e-04 2.42231029e-04]\n", 224 | " [ 8.05153930e-06 7.40568461e-06 6.52280509e-05 4.24527009e-05\n", 225 | " 1.34393412e-04 1.47631121e-04 1.86429242e-04]\n", 226 | " [ 1.95126637e-06 3.64053019e-06 1.76025677e-05 2.53612828e-05\n", 227 | " 4.28581244e-05 5.31947855e-05 8.09585256e-05]\n", 228 | " [ 5.18984675e-07 1.37335633e-06 5.65009596e-06 1.05520069e-05\n", 229 | " 1.81445380e-05 1.87825719e-05 3.56811933e-05]\n", 230 | " [ 1.21116956e-07 3.82213679e-07 1.63908339e-06 3.27248912e-06\n", 231 | " 6.69699576e-06 7.59916314e-06 1.27103665e-05]]\n" 232 | ] 233 | } 234 | ], 235 | "source": [ 236 | "def forward(y, labels):\n", 237 | " T, V = y.shape\n", 238 | " L = len(labels)\n", 239 | " alpha = np.zeros([T, L])\n", 240 | "\n", 241 | " # init\n", 242 | " alpha[0, 0] = y[0, labels[0]]\n", 243 | " alpha[0, 1] = y[0, labels[1]]\n", 244 | "\n", 245 | " for t in range(1, T):\n", 246 | " for i in range(L):\n", 247 | " s = labels[i]\n", 248 | " \n", 249 | " a = alpha[t - 1, i] \n", 250 | " if i - 1 >= 0:\n", 251 | " a += alpha[t - 1, i - 1]\n", 252 | " if i - 2 >= 0 and s != 0 and s != labels[i - 2]:\n", 253 | " a += alpha[t - 1, i - 2]\n", 254 | " \n", 255 | " alpha[t, i] = a * y[t, s]\n", 256 | " \n", 257 | " return alpha\n", 258 | "\n", 259 | "labels = [0, 3, 0, 3, 0, 4, 0] # 0 for blank\n", 260 | "alpha = forward(y, labels)\n", 261 | "print(alpha)" 262 | ] 263 | }, 264 | { 265 | "cell_type": "markdown", 266 | "metadata": {}, 267 | "source": [ 268 | "最后可以得到似然 $p(l|x) = \\alpha_T(|l^\\prime|) + \\alpha_T(|l^\\prime|-1)$。" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 3, 274 | "metadata": { 275 | "collapsed": false 276 | }, 277 | "outputs": [ 278 | { 279 | "name": "stdout", 280 | "output_type": "stream", 281 | "text": [ 282 | "6.81811271177e-06\n" 283 | ] 284 | } 285 | ], 286 | "source": [ 287 | "p = alpha[-1, labels[-1]] + alpha[-1, labels[-2]]\n", 288 | "print(p)" 289 | ] 290 | }, 291 | { 292 | "cell_type": "markdown", 293 | "metadata": {}, 294 | "source": [ 295 | "## 1.5 后向计算\n", 296 | 
"类似于前向计算,我们定义后向计算。\n", 297 | "首先定义\n", 298 | "$$\n", 299 | "\\beta_t(s) \\stackrel{def}{=} \\sum_{\\pi \\in N^T: B(\\pi_{t:T}) = l_{s:|l|}} \\prod_{t^\\prime=t}^T y^t_{\\pi^\\prime}\n", 300 | "$$\n", 301 | "\n", 302 | "显然,\n", 303 | "$$\n", 304 | "\\begin{align}\n", 305 | "\\beta_T(|l^\\prime|) = y_b^T,\\\\\n", 306 | "\\beta_T(|l^\\prime|-1) = y_{l_{|l|}}^T,\\\\\n", 307 | "\\beta_T(s) = 0, \\forall s < |l^\\prime| - 1\n", 308 | "\\end{align}\n", 309 | "$$\n", 310 | "\n", 311 | "易得如下递归关系:\n", 312 | "$$\n", 313 | "\\beta_t(s) = \\{ \\begin{array}{l}\n", 314 | "(\\beta_{t+1}(s)+\\beta_{t+1}(s+1)) y^t_{l^\\prime_s},\\ \\ \\ if\\ l^\\prime_s = b \\ or\\ l_{s+2}^\\prime = l_s^{\\prime} \\\\\n", 315 | "(\\beta_{t+1}(s)+\\beta_{t+1}(s+1) + \\beta_{t+1}(s+2)) y^t_{l^\\prime_s} \n", 316 | "\\end{array}\n", 317 | "$$" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": 4, 323 | "metadata": { 324 | "collapsed": false 325 | }, 326 | "outputs": [ 327 | { 328 | "name": "stdout", 329 | "output_type": "stream", 330 | "text": [ 331 | "[[ 1.25636660e-05 7.74586366e-06 8.69559539e-06 3.30990037e-06\n", 332 | " 2.41325357e-06 4.30516936e-07 1.21116956e-07]\n", 333 | " [ 3.00418145e-05 2.09170784e-05 2.53062822e-05 9.96351200e-06\n", 334 | " 8.39236521e-06 1.39591874e-06 4.91256769e-07]\n", 335 | " [ 7.14014755e-05 4.66705755e-05 7.46535563e-05 2.48066359e-05\n", 336 | " 2.77113594e-05 5.27279259e-06 1.93076535e-06]\n", 337 | " [ 1.69926001e-04 1.25923340e-04 2.33240296e-04 7.60839197e-05\n", 338 | " 9.89830489e-05 1.58379311e-05 8.00005392e-06]\n", 339 | " [ 4.20893778e-04 2.03461048e-04 6.84292101e-04 1.72696845e-04\n", 340 | " 3.08627225e-04 5.50636993e-05 2.93943967e-05]\n", 341 | " [ 4.81953899e-04 8.10796738e-04 1.27731424e-03 8.24448952e-04\n", 342 | " 7.48161143e-04 1.99769340e-04 9.02831714e-05]\n", 343 | " [ 9.80428697e-04 1.03986915e-03 3.68556718e-03 1.66879393e-03\n", 344 | " 2.56724754e-03 5.68961868e-04 3.78457146e-04]\n", 345 | " [ 2.40870506e-04 2.30339872e-03 4.81028886e-03 4.75397134e-03\n", 346 | " 4.31752827e-03 2.34462771e-03 9.82118206e-04]\n", 347 | " [ 0.00000000e+00 1.10150469e-03 1.28817322e-02 9.11579592e-03\n", 348 | " 1.35011919e-02 6.24293419e-03 4.49124231e-03]\n", 349 | " [ 0.00000000e+00 0.00000000e+00 9.52648414e-03 3.36188472e-02\n", 350 | " 2.50664437e-02 2.01536701e-02 1.50427081e-02]\n", 351 | " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 3.93092725e-02\n", 352 | " 4.25697510e-02 6.08622868e-02 6.20709492e-02]\n", 353 | " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00\n", 354 | " 0.00000000e+00 1.60052748e-01 2.33372894e-01]]\n" 355 | ] 356 | } 357 | ], 358 | "source": [ 359 | "def backward(y, labels):\n", 360 | " T, V = y.shape\n", 361 | " L = len(labels)\n", 362 | " beta = np.zeros([T, L])\n", 363 | "\n", 364 | " # init\n", 365 | " beta[-1, -1] = y[-1, labels[-1]]\n", 366 | " beta[-1, -2] = y[-1, labels[-2]]\n", 367 | "\n", 368 | " for t in range(T - 2, -1, -1):\n", 369 | " for i in range(L):\n", 370 | " s = labels[i]\n", 371 | " \n", 372 | " a = beta[t + 1, i] \n", 373 | " if i + 1 < L:\n", 374 | " a += beta[t + 1, i + 1]\n", 375 | " if i + 2 < L and s != 0 and s != labels[i + 2]:\n", 376 | " a += beta[t + 1, i + 2]\n", 377 | " \n", 378 | " beta[t, i] = a * y[t, s]\n", 379 | " \n", 380 | " return beta\n", 381 | "\n", 382 | "beta = backward(y, labels)\n", 383 | "print(beta)" 384 | ] 385 | }, 386 | { 387 | "cell_type": "markdown", 388 | "metadata": {}, 389 | "source": [ 390 | "## 1.6 梯度计算\n", 391 | "下面,我们利用前向、后者计算的 $\\alpha$ 和 
$\\beta$ 来计算梯度。\n", 392 | "\n", 393 | "根据 $\\alpha$、$\\beta$ 的定义,我们有:\n", 394 | "$$\\alpha_t(s)\\beta_t(s) = \\sum_{\\pi \\in B^{-1}(l):\\pi_t=l_s^\\prime} y^t_{l_s^\\prime} \\prod_{t=1}^T y^t_{\\pi_t} = y^t_{l_s^\\prime} \\cdot \\sum_{\\pi \\in B^{-1}(l):\\pi_t=l_s^\\prime} \\prod_{t=1}^T y^t_{\\pi_t}$$\n", 395 | "则\n", 396 | "$$\n", 397 | "\\frac{\\alpha_t(s)\\beta_t(s)}{ y^t_{l_s^\\prime}} = \\sum_{\\pi \\in B^{-1}(l):\\pi_t=l_s^\\prime} \\prod_{t=1}^T y^t_{\\pi_t} = \\sum_{\\pi \\in B^{-1}(l):\\pi_t=l_s^\\prime} p(\\pi|x) \n", 398 | "$$\n", 399 | "于是,可得似然\n", 400 | "$$\n", 401 | "p(l|x) = \\sum_{s=1}^{|l^\\prime|} \\sum_{\\pi \\in B^{-1}(l):\\pi_t=l_s^\\prime} p(\\pi|x) = \\sum_{s=1}^{|l^\\prime|} \\frac{\\alpha_t(s)\\beta_t(s)}{ y^t_{l_s^\\prime}} \n", 402 | "$$\n", 403 | "\n", 404 | "\n", 405 | "为计算 $\\frac{\\partial p(l|x)}{\\partial y^t_k}$,观察上式右端求各项,仅有 $s=k$ 的项包含 $y^t_k$,因此,其他项的偏导都为零,不用考虑。于是有:\n", 406 | "\n", 407 | "$$\n", 408 | "\\frac{\\partial p(l|x)}{\\partial y^t_k} = \\frac{\\partial \\frac{\\alpha_t(k)\\beta_t(k)}{ y^t_{k}} }{\\partial y^t_k} \n", 409 | "$$\n", 410 | "\n", 411 | "利用除法的求导准则有:\n", 412 | "$$\n", 413 | "\\frac{\\partial p(l|x)}{\\partial y^t_k} = \\frac{\\frac{2 \\cdot \\alpha_t(k)\\beta_t(k)}{ y^t_{k}} \\cdot y^t_{k} - \\alpha_t(k)\\beta_t(k) \\cdot 1}{{y^t_k}^2} = \\frac{\\alpha_t(k)\\beta_t(k)}{{y^t_k}^2}\n", 414 | "$$\n", 415 | "\n", 416 | "> 求导中,分子第一项是因为 $\\alpha(k)\\beta(k)$ 中包含为两个 $y^t_k$ 乘积项(即 ${y^t_k}^2$),其他均为与 $y^t_k$ 无关的常数。\n", 417 | "\n", 418 | "$l$ 中可能包含多个 $k$ 字符,它们计算的梯度要进行累加,因此,最后的梯度计算结果为:\n", 419 | "$$\n", 420 | "\\frac{\\partial p(l|x)}{\\partial y^t_k} = \\frac{1}{{y^t_k}^2} \\sum_{s \\in lab(l, k)} \\alpha_t(s)\\beta_t(s)\n", 421 | "$$\n", 422 | "其中,$lab(s)=\\{s: l_s^\\prime = k\\}$。\n", 423 | "\n", 424 | "一般我们优化似然函数的对数,因此,梯度计算如下:\n", 425 | "$$\n", 426 | "\\frac{\\partial \\ln(p(l|x))}{\\partial y^t_k} =\\frac{1}{p(l|x)} \\frac{\\partial p(l|x)}{\\partial y^t_k}\n", 427 | "$$\n", 428 | "其中,似然值在前向计算中已经求得: $p(l|x) = \\alpha_T(|l^\\prime|) + \\alpha_T(|l^\\prime|-1)$。\n", 429 | "\n", 430 | "对于给定训练集 $D$,待优化的目标函数为:\n", 431 | "$$\n", 432 | "O(D, N_w) = -\\sum_{(x,z)\\in D} \\ln(p(z|x))\n", 433 | "$$\n", 434 | "\n", 435 | "得到梯度后,我们可以利用任意优化方法(e.g. SGD, Adam)进行训练。" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": 5, 441 | "metadata": { 442 | "collapsed": false 443 | }, 444 | "outputs": [ 445 | { 446 | "name": "stdout", 447 | "output_type": "stream", 448 | "text": [ 449 | "[[ 2.50911241 0. 0. 2.27594441 0. ]\n", 450 | " [ 2.25397118 0. 0. 2.47384957 0. ]\n", 451 | " [ 2.65058465 0. 0. 2.77274592 0. ]\n", 452 | " [ 2.46136916 0. 0. 2.29678159 0.02303985]\n", 453 | " [ 2.300259 0. 0. 2.37548238 0.10334851]\n", 454 | " [ 2.40271071 0. 0. 2.19860276 0.23513657]\n", 455 | " [ 1.68914157 0. 0. 1.78214377 0.51794046]\n", 456 | " [ 2.32536762 0. 0. 1.75750877 0.92477606]\n", 457 | " [ 1.92883907 0. 0. 1.45529832 1.44239844]\n", 458 | " [ 2.06219335 0. 0. 0.7568118 1.96405515]\n", 459 | " [ 2.07914466 0. 0. 0.33858403 2.35197258]\n", 460 | " [ 2.6816852 0. 0. 0. 
2.3377753 ]]\n" 461 | " ] 462 | } 463 | ], 464 | "source": [ 465 | " "def gradient(y, labels):\n", 466 | " T, V = y.shape\n", 467 | " L = len(labels)\n", 468 | " \n", 469 | " alpha = forward(y, labels)\n", 470 | " beta = backward(y, labels)\n", 471 | " p = alpha[-1, -1] + alpha[-1, -2]\n", 472 | " \n", 473 | " grad = np.zeros([T, V])\n", 474 | " for t in range(T):\n", 475 | " for s in range(V):\n", 476 | " lab = [i for i, c in enumerate(labels) if c == s]\n", 477 | " for i in lab:\n", 478 | " grad[t, s] += alpha[t, i] * beta[t, i] \n", 479 | " grad[t, s] /= y[t, s] ** 2\n", 480 | " \n", 481 | " grad /= p\n", 482 | " return grad\n", 483 | " \n", 484 | "grad = gradient(y, labels)\n", 485 | "print(grad)" 486 | ] 487 | }, 488 | { 489 | "cell_type": "markdown", 490 | "metadata": {}, 491 | "source": [ 492 | "将基于前向-后向算法得到的梯度与基于数值的梯度比较,以验证实现的正确性。" 493 | ] 494 | }, 495 | { 496 | "cell_type": "code", 497 | "execution_count": 6, 498 | "metadata": { 499 | "collapsed": false 500 | }, 501 | "outputs": [ 502 | { 503 | "name": "stdout", 504 | "output_type": "stream", 505 | "text": [ 506 | "1e-05\n", 507 | "1e-06\n", 508 | "[0, 3]:3.91e-06\n", 509 | "[1, 0]:3.61e-06\n", 510 | "[1, 3]:2.66e-06\n", 511 | "[2, 0]:2.67e-06\n", 512 | "[2, 3]:3.88e-06\n", 513 | "[3, 0]:4.71e-06\n", 514 | "[3, 3]:3.39e-06\n", 515 | "[4, 0]:1.24e-06\n", 516 | "[4, 3]:4.79e-06\n", 517 | "[5, 0]:1.57e-06\n", 518 | "[5, 3]:2.98e-06\n", 519 | "[6, 0]:5.03e-06\n", 520 | "[6, 3]:4.89e-06\n", 521 | "[7, 0]:1.05e-06\n", 522 | "[7, 4]:4.19e-06\n", 523 | "[8, 4]:5.57e-06\n", 524 | "[9, 0]:5.95e-06\n", 525 | "[9, 3]:3.85e-06\n", 526 | "[10, 0]:1.09e-06\n", 527 | "[10, 3]:1.53e-06\n", 528 | "[10, 4]:3.82e-06\n" 529 | ] 530 | } 531 | ], 532 | "source": [ 533 | "def check_grad(y, labels, w=-1, v=-1, toleration=1e-3):\n", 534 | " grad_1 = gradient(y, labels)[w, v]\n", 535 | " \n", 536 | " delta = 1e-10\n", 537 | " original = y[w, v]\n", 538 | " \n", 539 | " y[w, v] = original + delta\n", 540 | " alpha = forward(y, labels)\n", 541 | " log_p1 = np.log(alpha[-1, -1] + alpha[-1, -2])\n", 542 | " \n", 543 | " y[w, v] = original - delta\n", 544 | " alpha = forward(y, labels)\n", 545 | " log_p2 = np.log(alpha[-1, -1] + alpha[-1, -2])\n", 546 | " \n", 547 | " y[w, v] = original\n", 548 | " \n", 549 | " grad_2 = (log_p1 - log_p2) / (2 * delta)\n", 550 | " if np.abs(grad_1 - grad_2) > toleration:\n", 551 | " print('[%d, %d]:%.2e' % (w, v, np.abs(grad_1 - grad_2)))\n", 552 | "\n", 553 | "for toleration in [1e-5, 1e-6]:\n", 554 | " print('%.e' % toleration)\n", 555 | " for w in range(y.shape[0]):\n", 556 | " for v in range(y.shape[1]):\n", 557 | " check_grad(y, labels, w, v, toleration)\n", 558 | " " 559 | ] 560 | }, 561 | { 562 | "cell_type": "markdown", 563 | "metadata": {}, 564 | "source": [ 565 | "可以看到,前向-后向及数值梯度两种方法计算的梯度差异都在 1e-5 以下,误差最多在 1e-6 的量级。这初步验证了前向-后向梯度计算方法原理和实现的正确性。" 566 | ] 567 | }, 568 | { 569 | "cell_type": "markdown", 570 | "metadata": {}, 571 | "source": [ 572 | "## 1.7 logits 梯度\n", 573 | "在实际训练中,为了计算方便,可以将 CTC 和 softmax 的梯度计算合并,公式如下:\n", 574 | "$$\n", 575 | "\\frac{\\partial \\ln(p(l|x))}{\\partial u^t_k} = \\frac{1}{y^t_k \\cdot p(l|x)} \\sum_{s \\in lab(l, k)} \\alpha_t(s)\\beta_t(s) - y^t_k\n", 576 | "$$\n", 577 | "\n", 578 | "这是因为,softmax 的梯度反传公式为:\n", 579 | "\n", 580 | "$$\n", 581 | "\\frac{\\partial \\ln(p(l|x))}{\\partial u^t_k} = y^t_k (\\frac{\\partial \\ln(p(l|x))}{\\partial y^t_k} - \\sum_{j=1}^{V} \\frac{\\partial \\ln(p(l|x))}{\\partial y^t_j} y^t_j)\n", 582 | "$$\n", 583 | "\n", 584 | "结合上面两式,有:\n", 585 | "$$\n", 586 | 
"\\frac{\\partial \\ln(p(l|x))}{\\partial u^t_k} = \\frac{1}{y^t_k p(l|x)} \\sum_{s \\in lab(l, k)} \\alpha_t(s)\\beta_t(s) - y^t_k\n", 587 | "$$" 588 | ] 589 | }, 590 | { 591 | "cell_type": "code", 592 | "execution_count": 7, 593 | "metadata": { 594 | "collapsed": false 595 | }, 596 | "outputs": [ 597 | { 598 | "name": "stdout", 599 | "output_type": "stream", 600 | "text": [ 601 | "1.59941504485e-15\n" 602 | ] 603 | } 604 | ], 605 | "source": [ 606 | "def gradient_logits_naive(y, labels):\n", 607 | " '''\n", 608 | " gradient by back propagation\n", 609 | " '''\n", 610 | " y_grad = gradient(y, labels)\n", 611 | " \n", 612 | " sum_y_grad = np.sum(y_grad * y, axis=1, keepdims=True)\n", 613 | " u_grad = y * (y_grad - sum_y_grad) \n", 614 | " \n", 615 | " return u_grad\n", 616 | "\n", 617 | "def gradient_logits(y, labels):\n", 618 | " '''\n", 619 | " '''\n", 620 | " T, V = y.shape\n", 621 | " L = len(labels)\n", 622 | " \n", 623 | " alpha = forward(y, labels)\n", 624 | " beta = backward(y, labels)\n", 625 | " p = alpha[-1, -1] + alpha[-1, -2]\n", 626 | " \n", 627 | " u_grad = np.zeros([T, V])\n", 628 | " for t in range(T):\n", 629 | " for s in range(V):\n", 630 | " lab = [i for i, c in enumerate(labels) if c == s]\n", 631 | " for i in lab:\n", 632 | " u_grad[t, s] += alpha[t, i] * beta[t, i] \n", 633 | " u_grad[t, s] /= y[t, s] * p\n", 634 | " \n", 635 | " u_grad -= y\n", 636 | " return u_grad\n", 637 | " \n", 638 | "grad_l = gradient_logits_naive(y, labels)\n", 639 | "grad_2 = gradient_logits(y, labels)\n", 640 | "\n", 641 | "print(np.sum(np.abs(grad_l - grad_2)))" 642 | ] 643 | }, 644 | { 645 | "cell_type": "markdown", 646 | "metadata": {}, 647 | "source": [ 648 | "同上,我们利用数值梯度来初步检验梯度计算的正确性:" 649 | ] 650 | }, 651 | { 652 | "cell_type": "code", 653 | "execution_count": 8, 654 | "metadata": { 655 | "collapsed": false 656 | }, 657 | "outputs": [ 658 | { 659 | "name": "stdout", 660 | "output_type": "stream", 661 | "text": [ 662 | "1e-05\n", 663 | "[0, 6]:-8.00e-02, -8.00e-02, 1.03e-05\n", 664 | "[4, 0]:4.29e-01, 4.29e-01, 1.10e-05\n", 665 | "[5, 6]:-7.59e-02, -7.59e-02, 1.22e-05\n", 666 | "[6, 2]:-1.38e-01, -1.38e-01, 1.23e-05\n", 667 | "[6, 3]:3.33e-01, 3.33e-01, 1.02e-05\n", 668 | "1e-06\n", 669 | "[0, 0]:3.88e-01, 3.88e-01, 7.03e-06\n", 670 | "[0, 1]:-1.59e-01, -1.59e-01, 2.78e-06\n", 671 | "[0, 2]:-8.89e-02, -8.89e-02, 3.47e-06\n", 672 | "[0, 3]:4.57e-01, 4.57e-01, 1.64e-06\n", 673 | "[0, 4]:-6.32e-02, -6.32e-02, 7.19e-06\n", 674 | "[0, 5]:-7.98e-02, -7.98e-02, 5.46e-06\n", 675 | "[0, 6]:-8.00e-02, -8.00e-02, 1.03e-05\n", 676 | "[0, 7]:-1.32e-01, -1.32e-01, 2.21e-06\n", 677 | "[0, 8]:-1.04e-01, -1.04e-01, 7.75e-06\n", 678 | "[0, 9]:-1.38e-01, -1.38e-01, 5.95e-06\n", 679 | "[1, 0]:3.41e-01, 3.41e-01, 2.79e-06\n", 680 | "[1, 1]:-1.18e-01, -1.18e-01, 6.08e-06\n", 681 | "[1, 3]:5.04e-01, 5.04e-01, 4.06e-06\n", 682 | "[1, 4]:-9.96e-02, -9.96e-02, 5.77e-06\n", 683 | "[1, 5]:-8.22e-02, -8.22e-02, 4.03e-06\n", 684 | "[1, 6]:-9.46e-02, -9.46e-02, 4.49e-06\n", 685 | "[1, 7]:-1.49e-01, -1.49e-01, 3.96e-06\n", 686 | "[1, 8]:-1.24e-01, -1.24e-01, 4.96e-06\n", 687 | "[1, 9]:-7.48e-02, -7.47e-02, 5.94e-06\n", 688 | "[2, 0]:3.29e-01, 3.29e-01, 3.47e-06\n", 689 | "[2, 1]:-9.42e-02, -9.42e-02, 1.63e-06\n", 690 | "[2, 2]:-9.17e-02, -9.17e-02, 4.47e-06\n", 691 | "[2, 3]:4.50e-01, 4.50e-01, 2.14e-06\n", 692 | "[2, 5]:-1.07e-01, -1.07e-01, 6.33e-06\n", 693 | "[2, 6]:-5.42e-02, -5.42e-02, 1.71e-06\n", 694 | "[2, 7]:-9.68e-02, -9.68e-02, 7.69e-06\n", 695 | "[2, 9]:-1.21e-01, -1.21e-01, 9.06e-06\n", 696 | "[3, 
0]:4.42e-01, 4.42e-01, 9.21e-06\n", 697 | "[3, 1]:-6.71e-02, -6.71e-02, 5.75e-06\n", 698 | "[3, 2]:-1.16e-01, -1.16e-01, 5.26e-06\n", 699 | "[3, 3]:4.03e-01, 4.03e-01, 6.39e-06\n", 700 | "[3, 4]:-1.07e-01, -1.07e-01, 2.42e-06\n", 701 | "[3, 5]:-1.25e-01, -1.25e-01, 8.90e-06\n", 702 | "[3, 6]:-1.17e-01, -1.17e-01, 2.08e-06\n", 703 | "[3, 7]:-1.32e-01, -1.32e-01, 2.21e-06\n", 704 | "[3, 8]:-6.90e-02, -6.90e-02, 1.72e-06\n", 705 | "[3, 9]:-1.13e-01, -1.13e-01, 6.68e-06\n", 706 | "[4, 0]:4.29e-01, 4.29e-01, 1.10e-05\n", 707 | "[4, 1]:-7.17e-02, -7.17e-02, 7.76e-06\n", 708 | "[4, 3]:3.25e-01, 3.25e-01, 2.93e-06\n", 709 | "[4, 4]:-5.91e-02, -5.91e-02, 1.88e-06\n", 710 | "[4, 5]:-7.78e-02, -7.78e-02, 5.82e-06\n", 711 | "[4, 6]:-9.08e-02, -9.08e-02, 8.04e-06\n", 712 | "[4, 7]:-1.12e-01, -1.12e-01, 2.11e-06\n", 713 | "[4, 8]:-1.26e-01, -1.26e-01, 5.92e-06\n", 714 | "[4, 9]:-1.40e-01, -1.40e-01, 4.89e-06\n", 715 | "[5, 0]:1.86e-01, 1.86e-01, 6.60e-06\n", 716 | "[5, 2]:-9.52e-02, -9.52e-02, 5.01e-06\n", 717 | "[5, 3]:4.97e-01, 4.97e-01, 6.92e-06\n", 718 | "[5, 4]:2.50e-02, 2.50e-02, 5.20e-06\n", 719 | "[5, 5]:-7.93e-02, -7.93e-02, 2.78e-06\n", 720 | "[5, 6]:-7.59e-02, -7.59e-02, 1.22e-05\n", 721 | "[5, 7]:-1.05e-01, -1.05e-01, 9.53e-06\n", 722 | "[5, 8]:-1.25e-01, -1.25e-01, 9.07e-06\n", 723 | "[5, 9]:-6.74e-02, -6.74e-02, 3.08e-06\n", 724 | "[6, 0]:1.49e-01, 1.49e-01, 5.17e-06\n", 725 | "[6, 1]:-6.67e-02, -6.66e-02, 4.40e-06\n", 726 | "[6, 2]:-1.38e-01, -1.38e-01, 1.23e-05\n", 727 | "[6, 3]:3.33e-01, 3.33e-01, 1.02e-05\n", 728 | "[6, 4]:1.72e-01, 1.72e-01, 7.79e-06\n", 729 | "[6, 6]:-6.73e-02, -6.73e-02, 7.13e-06\n", 730 | "[6, 7]:-6.08e-02, -6.08e-02, 8.58e-06\n", 731 | "[6, 8]:-9.98e-02, -9.98e-02, 6.14e-06\n", 732 | "[6, 9]:-1.31e-01, -1.31e-01, 8.89e-06\n", 733 | "[7, 0]:2.39e-01, 2.39e-01, 2.53e-06\n", 734 | "[7, 2]:-1.61e-01, -1.61e-01, 1.78e-06\n", 735 | "[7, 3]:1.02e-01, 1.02e-01, 4.73e-06\n", 736 | "[7, 4]:3.63e-01, 3.63e-01, 7.51e-06\n", 737 | "[7, 5]:-6.02e-02, -6.02e-02, 4.63e-06\n", 738 | "[7, 6]:-1.05e-01, -1.05e-01, 3.29e-06\n", 739 | "[7, 8]:-8.83e-02, -8.83e-02, 4.16e-06\n", 740 | "[7, 9]:-6.05e-02, -6.05e-02, 2.89e-06\n", 741 | "[8, 0]:2.92e-01, 2.92e-01, 1.38e-06\n", 742 | "[8, 1]:-9.70e-02, -9.70e-02, 2.27e-06\n", 743 | "[8, 2]:-9.87e-02, -9.87e-02, 4.36e-06\n", 744 | "[8, 4]:5.07e-01, 5.07e-01, 2.30e-06\n", 745 | "[8, 5]:-1.28e-01, -1.28e-01, 7.41e-06\n", 746 | "[8, 6]:-1.09e-01, -1.09e-01, 6.32e-06\n", 747 | "[8, 7]:-1.20e-01, -1.20e-01, 6.01e-06\n", 748 | "[8, 8]:-7.00e-02, -7.00e-02, 1.44e-06\n", 749 | "[8, 9]:-1.55e-01, -1.55e-01, 2.47e-06\n", 750 | "[9, 0]:4.90e-01, 4.90e-01, 2.04e-06\n", 751 | "[9, 1]:-7.69e-02, -7.69e-02, 1.54e-06\n", 752 | "[9, 2]:-9.72e-02, -9.72e-02, 9.59e-06\n", 753 | "[9, 4]:2.61e-01, 2.61e-01, 2.16e-06\n", 754 | "[9, 5]:-9.27e-02, -9.27e-02, 5.07e-06\n", 755 | "[9, 6]:-7.70e-02, -7.70e-02, 1.03e-06\n", 756 | "[9, 7]:-8.30e-02, -8.30e-02, 5.42e-06\n", 757 | "[9, 8]:-1.17e-01, -1.17e-01, 4.26e-06\n", 758 | "[9, 9]:-1.39e-01, -1.39e-01, 3.53e-06\n" 759 | ] 760 | } 761 | ], 762 | "source": [ 763 | "def check_grad_logits(x, labels, w=-1, v=-1, toleration=1e-3):\n", 764 | " grad_1 = gradient_logits(softmax(x), labels)[w, v]\n", 765 | " \n", 766 | " delta = 1e-10\n", 767 | " original = x[w, v]\n", 768 | " \n", 769 | " x[w, v] = original + delta\n", 770 | " y = softmax(x)\n", 771 | " alpha = forward(y, labels)\n", 772 | " log_p1 = np.log(alpha[-1, -1] + alpha[-1, -2])\n", 773 | " \n", 774 | " x[w, v] = original - delta\n", 775 | " y = softmax(x)\n", 776 | " 
alpha = forward(y, labels)\n", 777 | " log_p2 = np.log(alpha[-1, -1] + alpha[-1, -2])\n", 778 | " \n", 779 | " x[w, v] = original\n", 780 | " \n", 781 | " grad_2 = (log_p1 - log_p2) / (2 * delta)\n", 782 | " if np.abs(grad_1 - grad_2) > toleration:\n", 783 | " print('[%d, %d]:%.2e, %.2e, %.2e' % (w, v, grad_1, grad_2, np.abs(grad_1 - grad_2)))\n", 784 | "\n", 785 | "np.random.seed(1111)\n", 786 | "x = np.random.random([10, 10])\n", 787 | "for toleration in [1e-5, 1e-6]:\n", 788 | " print('%.e' % toleration)\n", 789 | " for w in range(x.shape[0]):\n", 790 | " for v in range(x.shape[1]):\n", 791 | " check_grad_logits(x, labels, w, v, toleration)" 792 | ] 793 | }, 794 | { 795 | "cell_type": "markdown", 796 | "metadata": {}, 797 | "source": [ 798 | "# 2. 数值稳定性\n", 799 | "\n", 800 | "CTC 的训练过程面临数值下溢的风险,特别是序列较大的情况下。下面介绍两种数值上稳定的工程优化方法:1)log 域(许多 CRF 实现的常用方法);2)scale 技巧(原始论文 [1] 使用的方法)。\n", 801 | "\n" 802 | ] 803 | }, 804 | { 805 | "cell_type": "markdown", 806 | "metadata": {}, 807 | "source": [ 808 | "## 2.1 log 域计算\n", 809 | "\n", 810 | "log 计算涉及 logsumexp 操作。\n", 811 | "[经验表明](https://github.com/baidu-research/warp-ctc),在 log 域计算,即使使用单精度,也表现出良好的数值稳定性,可以有效避免下溢的风险。稳定性的代价是增加了运算的复杂性——原始实现只涉及乘加运算,log 域实现则需要对数和指数运算。" 812 | ] 813 | }, 814 | { 815 | "cell_type": "code", 816 | "execution_count": 9, 817 | "metadata": { 818 | "collapsed": false 819 | }, 820 | "outputs": [], 821 | "source": [ 822 | "ninf = -np.float('inf')\n", 823 | "\n", 824 | "def _logsumexp(a, b):\n", 825 | " '''\n", 826 | " np.log(np.exp(a) + np.exp(b))\n", 827 | "\n", 828 | " '''\n", 829 | " \n", 830 | " if a < b:\n", 831 | " a, b = b, a\n", 832 | " \n", 833 | " if b == ninf:\n", 834 | " return a\n", 835 | " else:\n", 836 | " return a + np.log(1 + np.exp(b - a)) \n", 837 | " \n", 838 | "def logsumexp(*args):\n", 839 | " '''\n", 840 | " from scipy.special import logsumexp\n", 841 | " logsumexp(args)\n", 842 | " '''\n", 843 | " res = args[0]\n", 844 | " for e in args[1:]:\n", 845 | " res = _logsumexp(res, e)\n", 846 | " return res" 847 | ] 848 | }, 849 | { 850 | "cell_type": "markdown", 851 | "metadata": {}, 852 | "source": [ 853 | "### 2.1.1 log 域前向算法\n", 854 | "基于 log 的前向算法实现如下:" 855 | ] 856 | }, 857 | { 858 | "cell_type": "code", 859 | "execution_count": 10, 860 | "metadata": { 861 | "collapsed": false 862 | }, 863 | "outputs": [ 864 | { 865 | "name": "stdout", 866 | "output_type": "stream", 867 | "text": [ 868 | "8.60881935942e-17\n" 869 | ] 870 | } 871 | ], 872 | "source": [ 873 | "def forward_log(log_y, labels):\n", 874 | " T, V = log_y.shape\n", 875 | " L = len(labels)\n", 876 | " log_alpha = np.ones([T, L]) * ninf\n", 877 | "\n", 878 | " # init\n", 879 | " log_alpha[0, 0] = log_y[0, labels[0]]\n", 880 | " log_alpha[0, 1] = log_y[0, labels[1]]\n", 881 | "\n", 882 | " for t in range(1, T):\n", 883 | " for i in range(L):\n", 884 | " s = labels[i]\n", 885 | " \n", 886 | " a = log_alpha[t - 1, i]\n", 887 | " if i - 1 >= 0:\n", 888 | " a = logsumexp(a, log_alpha[t - 1, i - 1])\n", 889 | " if i - 2 >= 0 and s != 0 and s != labels[i - 2]:\n", 890 | " a = logsumexp(a, log_alpha[t - 1, i - 2])\n", 891 | " \n", 892 | " log_alpha[t, i] = a + log_y[t, s]\n", 893 | " \n", 894 | " return log_alpha\n", 895 | "\n", 896 | "log_alpha = forward_log(np.log(y), labels)\n", 897 | "alpha = forward(y, labels)\n", 898 | "print(np.sum(np.abs(np.exp(log_alpha) - alpha)))" 899 | ] 900 | }, 901 | { 902 | "cell_type": "markdown", 903 | "metadata": {}, 904 | "source": [ 905 | "### 2.1.2 log 域后向算法\n", 906 | "基于 log 的后向算法实现如下:" 907 | ] 908 | }, 909 | { 910 | 
"cell_type": "code", 911 | "execution_count": 11, 912 | "metadata": { 913 | "collapsed": false 914 | }, 915 | "outputs": [ 916 | { 917 | "name": "stdout", 918 | "output_type": "stream", 919 | "text": [ 920 | "1.10399945005e-16\n" 921 | ] 922 | } 923 | ], 924 | "source": [ 925 | "def backward_log(log_y, labels):\n", 926 | " T, V = log_y.shape\n", 927 | " L = len(labels)\n", 928 | " log_beta = np.ones([T, L]) * ninf\n", 929 | "\n", 930 | " # init\n", 931 | " log_beta[-1, -1] = log_y[-1, labels[-1]]\n", 932 | " log_beta[-1, -2] = log_y[-1, labels[-2]]\n", 933 | "\n", 934 | " for t in range(T - 2, -1, -1):\n", 935 | " for i in range(L):\n", 936 | " s = labels[i]\n", 937 | " \n", 938 | " a = log_beta[t + 1, i] \n", 939 | " if i + 1 < L:\n", 940 | " a = logsumexp(a, log_beta[t + 1, i + 1])\n", 941 | " if i + 2 < L and s != 0 and s != labels[i + 2]:\n", 942 | " a = logsumexp(a, log_beta[t + 1, i + 2])\n", 943 | " \n", 944 | " log_beta[t, i] = a + log_y[t, s]\n", 945 | " \n", 946 | " return log_beta\n", 947 | "\n", 948 | "log_beta = backward_log(np.log(y), labels)\n", 949 | "beta = backward(y, labels)\n", 950 | "print(np.sum(np.abs(np.exp(log_beta) - beta)))" 951 | ] 952 | }, 953 | { 954 | "cell_type": "markdown", 955 | "metadata": {}, 956 | "source": [ 957 | "### 2.1.3 log 域梯度计算\n", 958 | "在前向、后向基础上,也可以在 log 域上计算梯度。" 959 | ] 960 | }, 961 | { 962 | "cell_type": "code", 963 | "execution_count": 12, 964 | "metadata": { 965 | "collapsed": false 966 | }, 967 | "outputs": [ 968 | { 969 | "name": "stdout", 970 | "output_type": "stream", 971 | "text": [ 972 | "4.97588081849e-14\n" 973 | ] 974 | } 975 | ], 976 | "source": [ 977 | "def gradient_log(log_y, labels):\n", 978 | " T, V = log_y.shape\n", 979 | " L = len(labels)\n", 980 | " \n", 981 | " log_alpha = forward_log(log_y, labels)\n", 982 | " log_beta = backward_log(log_y, labels)\n", 983 | " log_p = logsumexp(log_alpha[-1, -1], log_alpha[-1, -2])\n", 984 | " \n", 985 | " log_grad = np.ones([T, V]) * ninf\n", 986 | " for t in range(T):\n", 987 | " for s in range(V):\n", 988 | " lab = [i for i, c in enumerate(labels) if c == s]\n", 989 | " for i in lab:\n", 990 | " log_grad[t, s] = logsumexp(log_grad[t, s], log_alpha[t, i] + log_beta[t, i]) \n", 991 | " log_grad[t, s] -= 2 * log_y[t, s]\n", 992 | " \n", 993 | " log_grad -= log_p\n", 994 | " return log_grad\n", 995 | " \n", 996 | "log_grad = gradient_log(np.log(y), labels)\n", 997 | "grad = gradient(y, labels)\n", 998 | "#print(log_grad)\n", 999 | "#print(grad)\n", 1000 | "print(np.sum(np.abs(np.exp(log_grad) - grad)))" 1001 | ] 1002 | }, 1003 | { 1004 | "cell_type": "markdown", 1005 | "metadata": {}, 1006 | "source": [ 1007 | "## 2.2 scale\n", 1008 | "\n", 1009 | "### 2.2.1 前向算法\n", 1010 | "\n", 1011 | "为了下溢,在前向算法的每个时刻,都对计算出的 $\\alpha$ 的范围进行缩放:\n", 1012 | "$$\n", 1013 | "C_t \\stackrel{def}{=} \\sum_s\\alpha_t(s)\n", 1014 | "$$\n", 1015 | "\n", 1016 | "$$\n", 1017 | "\\hat{\\alpha}_t = \\frac{\\alpha_t(s)}{C_t}\n", 1018 | "$$\n", 1019 | "\n", 1020 | "缩放后的 $\\alpha$,不会随着时刻的积累变得太小。$\\hat{\\alpha}$ 替代 $\\alpha$,进行下一时刻的迭代。\n" 1021 | ] 1022 | }, 1023 | { 1024 | "cell_type": "code", 1025 | "execution_count": 13, 1026 | "metadata": { 1027 | "collapsed": false 1028 | }, 1029 | "outputs": [], 1030 | "source": [ 1031 | "def forward_scale(y, labels):\n", 1032 | " T, V = y.shape\n", 1033 | " L = len(labels)\n", 1034 | " alpha_scale = np.zeros([T, L])\n", 1035 | "\n", 1036 | " # init\n", 1037 | " alpha_scale[0, 0] = y[0, labels[0]]\n", 1038 | " alpha_scale[0, 1] = y[0, labels[1]]\n", 1039 | " Cs = []\n", 1040 | " 
\n", 1041 | " C = np.sum(alpha_scale[0])\n", 1042 | " alpha_scale[0] /= C\n", 1043 | " Cs.append(C)\n", 1044 | "\n", 1045 | " for t in range(1, T):\n", 1046 | " for i in range(L):\n", 1047 | " s = labels[i]\n", 1048 | " \n", 1049 | " a = alpha_scale[t - 1, i] \n", 1050 | " if i - 1 >= 0:\n", 1051 | " a += alpha_scale[t - 1, i - 1]\n", 1052 | " if i - 2 >= 0 and s != 0 and s != labels[i - 2]:\n", 1053 | " a += alpha_scale[t - 1, i - 2]\n", 1054 | " \n", 1055 | " alpha_scale[t, i] = a * y[t, s]\n", 1056 | " \n", 1057 | " C = np.sum(alpha_scale[t])\n", 1058 | " alpha_scale[t] /= C\n", 1059 | " Cs.append(C)\n", 1060 | " \n", 1061 | " return alpha_scale, Cs" 1062 | ] 1063 | }, 1064 | { 1065 | "cell_type": "markdown", 1066 | "metadata": {}, 1067 | "source": [ 1068 | "由于进行了缩放,最后计算概率时要时行补偿:\n", 1069 | "$$\n", 1070 | "p(l|x) = \\alpha_T(|l^\\prime|) + \\alpha_T(|l^\\prime|-1) = (\\hat\\alpha_T(|l^\\prime|) + \\hat\\alpha_T(|l^\\prime|-1) * \\prod_{t=1}^T C_t\n", 1071 | "$$\n", 1072 | "\n", 1073 | "$$\n", 1074 | "\\ln(p(l|x)) = \\sum_t^T\\ln(C_t) + \\ln(\\hat\\alpha_T(|l^\\prime|) + \\hat\\alpha_T(|l^\\prime|-1))\n", 1075 | "$$" 1076 | ] 1077 | }, 1078 | { 1079 | "cell_type": "code", 1080 | "execution_count": 14, 1081 | "metadata": { 1082 | "collapsed": false 1083 | }, 1084 | "outputs": [ 1085 | { 1086 | "name": "stdout", 1087 | "output_type": "stream", 1088 | "text": [ 1089 | "(-13.202925982240107, -13.202925982240107, 0.0)\n" 1090 | ] 1091 | } 1092 | ], 1093 | "source": [ 1094 | "labels = [0, 1, 2, 0] # 0 for blank\n", 1095 | "\n", 1096 | "alpha_scale, Cs = forward_scale(y, labels)\n", 1097 | "log_p = np.sum(np.log(Cs)) + np.log(alpha_scale[-1][labels[-1]] + alpha_scale[-1][labels[-2]])\n", 1098 | "\n", 1099 | "alpha = forward(y, labels)\n", 1100 | "p = alpha[-1, labels[-1]] + alpha[-1, labels[-2]]\n", 1101 | "\n", 1102 | "print(np.log(p), log_p, np.log(p) - log_p)" 1103 | ] 1104 | }, 1105 | { 1106 | "cell_type": "markdown", 1107 | "metadata": {}, 1108 | "source": [ 1109 | "### 2.2.2 后向算法\n", 1110 | "后向算法缩放类似于前向算法,公式如下:\n", 1111 | "\n", 1112 | "$$\n", 1113 | "D_t \\stackrel{def}{=} \\sum_s\\beta_t(s)\n", 1114 | "$$\n", 1115 | "\n", 1116 | "$$\n", 1117 | "\\hat{\\beta}_t = \\frac{\\beta_t(s)}{D_t}\n", 1118 | "$$\n" 1119 | ] 1120 | }, 1121 | { 1122 | "cell_type": "code", 1123 | "execution_count": 15, 1124 | "metadata": { 1125 | "collapsed": false 1126 | }, 1127 | "outputs": [ 1128 | { 1129 | "name": "stdout", 1130 | "output_type": "stream", 1131 | "text": [ 1132 | "[[ 0.71362347 0.18910147 0.07964328 0.01763178]\n", 1133 | " [ 0.70165268 0.15859852 0.11849423 0.02125457]\n", 1134 | " [ 0.67689676 0.165374 0.13221504 0.02551419]\n", 1135 | " [ 0.71398181 0.11936432 0.13524265 0.03141122]\n", 1136 | " [ 0.70769657 0.13093688 0.12447135 0.0368952 ]\n", 1137 | " [ 0.63594568 0.1790638 0.14250065 0.04248987]\n", 1138 | " [ 0.63144322 0.1806382 0.13366043 0.05425815]\n", 1139 | " [ 0.33926289 0.35149591 0.24988622 0.05935497]\n", 1140 | " [ 0.30303623 0.26644554 0.33088584 0.0996324 ]\n", 1141 | " [ 0.12510056 0.3297143 0.3956509 0.14953425]\n", 1142 | " [ 0. 0.22078343 0.5153114 0.26390517]\n", 1143 | " [ 0. 0. 
0.550151 0.449849 ]]\n" 1144 | ] 1145 | } 1146 | ], 1147 | "source": [ 1148 | "def backward_scale(y, labels):\n", 1149 | " T, V = y.shape\n", 1150 | " L = len(labels)\n", 1151 | " beta_scale = np.zeros([T, L])\n", 1152 | "\n", 1153 | " # init\n", 1154 | " beta_scale[-1, -1] = y[-1, labels[-1]]\n", 1155 | " beta_scale[-1, -2] = y[-1, labels[-2]]\n", 1156 | " \n", 1157 | " Ds = []\n", 1158 | " \n", 1159 | " D = np.sum(beta_scale[-1,:])\n", 1160 | " beta_scale[-1] /= D\n", 1161 | " Ds.append(D)\n", 1162 | "\n", 1163 | " for t in range(T - 2, -1, -1):\n", 1164 | " for i in range(L):\n", 1165 | " s = labels[i]\n", 1166 | " \n", 1167 | " a = beta_scale[t + 1, i] \n", 1168 | " if i + 1 < L:\n", 1169 | " a += beta_scale[t + 1, i + 1]\n", 1170 | " if i + 2 < L and s != 0 and s != labels[i + 2]:\n", 1171 | " a += beta_scale[t + 1, i + 2]\n", 1172 | " \n", 1173 | " beta_scale[t, i] = a * y[t, s]\n", 1174 | " \n", 1175 | " D = np.sum(beta_scale[t])\n", 1176 | " beta_scale[t] /= D\n", 1177 | " Ds.append(D)\n", 1178 | " \n", 1179 | " return beta_scale, Ds[::-1]\n", 1180 | "\n", 1181 | "beta_scale, Ds = backward_scale(y, labels)\n", 1182 | "print(beta_scale)" 1183 | ] 1184 | }, 1185 | { 1186 | "cell_type": "markdown", 1187 | "metadata": {}, 1188 | "source": [ 1189 | "### 2.2.3 梯度计算\n", 1190 | "\n", 1191 | "$$\n", 1192 | "\\frac{\\partial \\ln(p(l|x))}{\\partial y^t_k} = \\frac{1}{p(l|x)} \\frac{\\partial p(l|x)}{\\partial y^t_k} = \\frac{1}{p(l|x)} \\frac{1}{{y^t_k}^2} \\sum_{s \\in lab(l, k)} \\alpha_t(s)\\beta_t(s) \n", 1193 | "$$\n", 1194 | "\n", 1195 | "考虑到\n", 1196 | "$$\n", 1197 | "p(l|x) = \\sum_{s=1}^{|l^\\prime|} \\frac{\\alpha_t(s)\\beta_t(s)}{ y^t_{l_s^\\prime}} \n", 1198 | "$$\n", 1199 | "以及\n", 1200 | "$$\n", 1201 | "\\alpha_t(s) = \\hat\\alpha_t(s) \\cdot \\prod_{k=1}^t C_k\n", 1202 | "$$\n", 1203 | "\n", 1204 | "$$\n", 1205 | "\\beta_t(s) = \\hat\\beta_t(s) \\cdot \\prod_{k=t}^T D_k\n", 1206 | "$$\n", 1207 | "\n", 1208 | "$$\n", 1209 | "\\frac{\\partial \\ln(p(l|x))}{\\partial y^t_k} = \\frac{1}{\\sum_{s=1}^{|l^\\prime|} \\frac{\\hat\\alpha_t(s) \\hat\\beta_t(s)}{y^t_{l^\\prime_s}}} \\frac{1}{{y^t_k}^2} \\sum_{s \\in lab(l, k)} \\hat\\alpha_t(s) \\hat\\beta_t(s)\n", 1210 | "$$\n", 1211 | "\n", 1212 | "式中最右项中的各个部分我们都已经求得。梯度计算实现如下:" 1213 | ] 1214 | }, 1215 | { 1216 | "cell_type": "code", 1217 | "execution_count": 16, 1218 | "metadata": { 1219 | "collapsed": false 1220 | }, 1221 | "outputs": [ 1222 | { 1223 | "name": "stdout", 1224 | "output_type": "stream", 1225 | "text": [ 1226 | "6.86256607096e-15\n" 1227 | ] 1228 | } 1229 | ], 1230 | "source": [ 1231 | "def gradient_scale(y, labels):\n", 1232 | " T, V = y.shape\n", 1233 | " L = len(labels)\n", 1234 | " \n", 1235 | " alpha_scale, _ = forward_scale(y, labels)\n", 1236 | " beta_scale, _ = backward_scale(y, labels)\n", 1237 | " \n", 1238 | " grad = np.zeros([T, V])\n", 1239 | " for t in range(T):\n", 1240 | " for s in range(V):\n", 1241 | " lab = [i for i, c in enumerate(labels) if c == s]\n", 1242 | " for i in lab:\n", 1243 | " grad[t, s] += alpha_scale[t, i] * beta_scale[t, i]\n", 1244 | " grad[t, s] /= y[t, s] ** 2\n", 1245 | " \n", 1246 | " # normalize factor\n", 1247 | " z = 0\n", 1248 | " for i in range(L):\n", 1249 | " z += alpha_scale[t, i] * beta_scale[t, i] / y[t, labels[i]]\n", 1250 | " grad[t] /= z\n", 1251 | " \n", 1252 | " return grad\n", 1253 | " \n", 1254 | "labels = [0, 3, 0, 3, 0, 4, 0] # 0 for blank\n", 1255 | "grad_1 = gradient_scale(y, labels)\n", 1256 | "grad_2 = gradient(y, labels)\n", 1257 | "print(np.sum(np.abs(grad_1 
- grad_2)))" 1258 | ] 1259 | }, 1260 | { 1261 | "cell_type": "markdown", 1262 | "metadata": {}, 1263 | "source": [ 1264 | "### 2.2.4 logits 梯度\n", 1265 | "类似于 y 梯度的推导,logits 梯度计算公式如下:\n", 1266 | "$$\n", 1267 | "\\frac{\\partial \\ln(p(l|x))}{\\partial u^t_k} = \\frac{1}{y^t_k Z_t} \\sum_{s \\in lab(l, k)} \\hat\\alpha_t(s)\\hat\\beta_t(s) - y^t_k\n", 1268 | "$$\n", 1269 | "其中,\n", 1270 | "$$\n", 1271 | "Z_t \\stackrel{def}{=} \\sum_{s=1}^{|l^\\prime|} \\frac{\\hat\\alpha_t(s)\\hat\\beta_t(s)}{y^t_{l^\\prime_s}}\n", 1272 | "$$" 1273 | ] 1274 | }, 1275 | { 1276 | "cell_type": "markdown", 1277 | "metadata": {}, 1278 | "source": [ 1279 | "# 3. 解码\n", 1280 | "训练和的 $N_w$ 可以用来预测新的样本输入对应的输出字符串,这涉及到解码。\n", 1281 | "按照最大似然准则,最优的解码结果为:\n", 1282 | "$$\n", 1283 | "h(x) = \\underset{l \\in L^{\\le T}}{\\mathrm{argmax}}\\ p(l|x)\n", 1284 | "$$\n", 1285 | "\n", 1286 | "然而,上式不存在已知的高效解法。下面介绍几种实用的近似破解码方法。\n", 1287 | "\n", 1288 | "## 3.1 贪心搜索 (greedy search)\n", 1289 | "虽然 $p(l|x)$ 难以有效的计算,但是由于 CTC 的独立性假设,对于某个具体的字符串 $\\pi$(去 blank 前),确容易计算:\n", 1290 | "$$\n", 1291 | "p(\\pi|x) = \\prod_{k=1}^T p(\\pi_k|x)\n", 1292 | "$$\n", 1293 | "\n", 1294 | "因此,我们放弃寻找使 $p(l|x)$ 最大的字符串,退而寻找一个使 $p(\\pi|x)$ 最大的字符串,即:\n", 1295 | "\n", 1296 | "$$\n", 1297 | "h(x) \\approx B(\\pi^\\star)\n", 1298 | "$$\n", 1299 | "其中,\n", 1300 | "$$\n", 1301 | "\\pi^\\star = \\underset{\\pi \\in N^T}{\\mathrm{argmax}}\\ p(\\pi|x)\n", 1302 | "$$\n", 1303 | "\n", 1304 | "简化后,解码过程(构造 $\\pi^\\star$)变得非常简单(基于独立性假设): 在每个时刻输出概率最大的字符:\n", 1305 | "$$\n", 1306 | "\\pi^\\star = cat_{t=1}^T(\\underset{s \\in L^\\prime}{\\mathrm{argmax}}\\ y^t_s)\n", 1307 | "$$\n", 1308 | "\n", 1309 | "\n" 1310 | ] 1311 | }, 1312 | { 1313 | "cell_type": "code", 1314 | "execution_count": 17, 1315 | "metadata": { 1316 | "collapsed": false 1317 | }, 1318 | "outputs": [ 1319 | { 1320 | "name": "stdout", 1321 | "output_type": "stream", 1322 | "text": [ 1323 | "[1 3 5 5 5 5 1 5 3 4 4 3 0 4 5 0 3 1 3 3]\n", 1324 | "[1, 3, 5, 1, 5, 3, 4, 3, 4, 5, 3, 1, 3]\n" 1325 | ] 1326 | } 1327 | ], 1328 | "source": [ 1329 | "def remove_blank(labels, blank=0):\n", 1330 | " new_labels = []\n", 1331 | " \n", 1332 | " # combine duplicate\n", 1333 | " previous = None\n", 1334 | " for l in labels:\n", 1335 | " if l != previous:\n", 1336 | " new_labels.append(l)\n", 1337 | " previous = l\n", 1338 | " \n", 1339 | " # remove blank \n", 1340 | " new_labels = [l for l in new_labels if l != blank]\n", 1341 | " \n", 1342 | " return new_labels\n", 1343 | "\n", 1344 | "def insert_blank(labels, blank=0):\n", 1345 | " new_labels = [blank]\n", 1346 | " for l in labels:\n", 1347 | " new_labels += [l, blank]\n", 1348 | " return new_labels\n", 1349 | "\n", 1350 | "def greedy_decode(y, blank=0):\n", 1351 | " raw_rs = np.argmax(y, axis=1)\n", 1352 | " rs = remove_blank(raw_rs, blank)\n", 1353 | " return raw_rs, rs\n", 1354 | "\n", 1355 | "np.random.seed(1111)\n", 1356 | "y = softmax(np.random.random([20, 6]))\n", 1357 | "rr, rs = greedy_decode(y)\n", 1358 | "print(rr)\n", 1359 | "print(rs)" 1360 | ] 1361 | }, 1362 | { 1363 | "cell_type": "markdown", 1364 | "metadata": {}, 1365 | "source": [ 1366 | "## 3.2 束搜索(Beam Search)\n", 1367 | "显然,贪心搜索的性能非常受限。例如,它不能给出除最优路径之外的其他其优路径。很多时候,如果我们能拿到 nbest 的路径,后续可以利用其他信息来进一步优化搜索的结果。束搜索能近似找出 top 最优的若干条路径。" 1368 | ] 1369 | }, 1370 | { 1371 | "cell_type": "code", 1372 | "execution_count": 18, 1373 | "metadata": { 1374 | "collapsed": false 1375 | }, 1376 | "outputs": [ 1377 | { 1378 | "name": "stdout", 1379 | "output_type": "stream", 1380 | "text": [ 1381 | "([1, 3, 5, 1, 5, 3, 4, 3, 4, 5, 
3, 1, 3], -29.261797539205567)\n", 1382 | "([1, 3, 5, 1, 5, 3, 4, 3, 3, 5, 3, 1, 3], -29.279020152518033)\n", 1383 | "([1, 3, 5, 1, 5, 3, 4, 2, 3, 4, 5, 3, 1, 3], -29.300726142201842)\n", 1384 | "([1, 5, 1, 5, 3, 4, 3, 4, 5, 3, 1, 3], -29.310307014773972)\n", 1385 | "([1, 3, 5, 1, 5, 3, 4, 2, 3, 3, 5, 3, 1, 3], -29.317948755514308)\n", 1386 | "([1, 5, 1, 5, 3, 4, 3, 3, 5, 3, 1, 3], -29.327529628086438)\n", 1387 | "([1, 3, 5, 1, 5, 4, 3, 4, 5, 3, 1, 3], -29.331572723457334)\n", 1388 | "([1, 3, 5, 5, 1, 5, 3, 4, 3, 4, 5, 3, 1, 3], -29.332631809924511)\n", 1389 | "([1, 3, 5, 4, 1, 5, 3, 4, 3, 4, 5, 3, 1, 3], -29.334649090836038)\n", 1390 | "([1, 3, 5, 1, 5, 3, 4, 3, 4, 5, 3, 1, 3], -29.33969505198154)\n", 1391 | "([1, 3, 5, 2, 1, 5, 3, 4, 3, 4, 5, 3, 1, 3], -29.339823066915415)\n", 1392 | "([1, 3, 5, 1, 5, 4, 3, 3, 5, 3, 1, 3], -29.3487953367698)\n", 1393 | "([1, 5, 1, 5, 3, 4, 2, 3, 4, 5, 3, 1, 3], -29.349235617770248)\n", 1394 | "([1, 3, 5, 5, 1, 5, 3, 4, 3, 3, 5, 3, 1, 3], -29.349854423236977)\n", 1395 | "([1, 3, 5, 1, 5, 3, 4, 3, 4, 5, 3, 3], -29.350803198551016)\n", 1396 | "([1, 3, 5, 4, 1, 5, 3, 4, 3, 3, 5, 3, 1, 3], -29.351871704148504)\n", 1397 | "([1, 3, 5, 1, 5, 3, 4, 3, 3, 5, 3, 1, 3], -29.356917665294006)\n", 1398 | "([1, 3, 5, 2, 1, 5, 3, 4, 3, 3, 5, 3, 1, 3], -29.357045680227881)\n", 1399 | "([1, 3, 5, 1, 5, 3, 4, 5, 4, 5, 3, 1, 3], -29.363802591012263)\n", 1400 | "([1, 5, 1, 5, 3, 4, 2, 3, 3, 5, 3, 1, 3], -29.366458231082714)\n" 1401 | ] 1402 | } 1403 | ], 1404 | "source": [ 1405 | "def beam_decode(y, beam_size=10):\n", 1406 | " T, V = y.shape\n", 1407 | " log_y = np.log(y)\n", 1408 | " \n", 1409 | " beam = [([], 0)]\n", 1410 | " for t in range(T): # for every timestep\n", 1411 | " new_beam = []\n", 1412 | " for prefix, score in beam:\n", 1413 | " for i in range(V): # for every state\n", 1414 | " new_prefix = prefix + [i]\n", 1415 | " new_score = score + log_y[t, i]\n", 1416 | " \n", 1417 | " new_beam.append((new_prefix, new_score))\n", 1418 | " \n", 1419 | " # top beam_size\n", 1420 | " new_beam.sort(key=lambda x: x[1], reverse=True)\n", 1421 | " beam = new_beam[:beam_size]\n", 1422 | " \n", 1423 | " return beam\n", 1424 | " \n", 1425 | "np.random.seed(1111)\n", 1426 | "y = softmax(np.random.random([20, 6]))\n", 1427 | "beam = beam_decode(y, beam_size=100)\n", 1428 | "for string, score in beam[:20]:\n", 1429 | " print(remove_blank(string), score)" 1430 | ] 1431 | }, 1432 | { 1433 | "cell_type": "markdown", 1434 | "metadata": {}, 1435 | "source": [ 1436 | "## 3.3 前缀束搜索(Prefix Beam Search)\n", 1437 | "直接的束搜索的一个问题是,在保存的 top N 条路径中,可能存在多条实际上是同一结果(经过去重复、去 blank 操作)。这减少了搜索结果的多样性。下面介绍的前缀搜索方法,在搜索过程中不断的合并相同的前缀[2]。参考 [gist](https://gist.github.com/awni/56369a90d03953e370f3964c826ed4b0),前缀束搜索实现如下:" 1438 | ] 1439 | }, 1440 | { 1441 | "cell_type": "code", 1442 | "execution_count": 19, 1443 | "metadata": { 1444 | "collapsed": false 1445 | }, 1446 | "outputs": [ 1447 | { 1448 | "name": "stdout", 1449 | "output_type": "stream", 1450 | "text": [ 1451 | "([1, 5, 4, 1, 3, 4, 5, 2, 3], (-18.189863809114193, -17.613677981426175))\n", 1452 | "([1, 5, 4, 5, 3, 4, 5, 2, 3], (-18.19636512622969, -17.621013424585406))\n", 1453 | "([1, 5, 4, 1, 3, 4, 5, 1, 3], (-18.317018960331531, -17.666629973270073))\n", 1454 | "([1, 5, 4, 5, 3, 4, 5, 1, 3], (-18.323388267369936, -17.674125139073176))\n", 1455 | "([1, 5, 4, 1, 3, 4, 3, 2, 3], (-18.415808498759556, -17.862744326248826))\n", 1456 | "([1, 5, 4, 1, 3, 4, 3, 5, 3], (-18.366422766638632, -17.898463479112884))\n", 1457 | "([1, 5, 4, 5, 3, 4, 3, 2, 3], 
(-18.42224294936932, -17.870025672291458))\n", 1458 | "([1, 5, 4, 5, 3, 4, 3, 5, 3], (-18.372199113900191, -17.905130493229173))\n", 1459 | "([1, 5, 4, 1, 3, 4, 5, 4, 3], (-18.457066311773847, -17.880630315602037))\n", 1460 | "([1, 5, 4, 5, 3, 4, 5, 4, 3], (-18.462614293487096, -17.88759583852546))\n", 1461 | "([1, 5, 4, 1, 3, 4, 5, 3, 2], (-18.458941701567706, -17.951422824358747))\n", 1462 | "([1, 5, 4, 5, 3, 4, 5, 3, 2], (-18.464527031120184, -17.958629487208658))\n", 1463 | "([1, 5, 4, 1, 3, 4, 3, 1, 3], (-18.540857550725587, -17.920589910093689))\n", 1464 | "([1, 5, 4, 5, 3, 4, 3, 1, 3], (-18.547146092248852, -17.928030266681613))\n", 1465 | "([1, 5, 4, 1, 3, 4, 5, 3, 2, 3], (-19.325467801462263, -17.689203224408899))\n", 1466 | "([1, 5, 4, 5, 3, 4, 5, 3, 2, 3], (-19.328748799764973, -17.694105969982637))\n", 1467 | "([1, 5, 4, 1, 3, 4, 5, 3, 4], (-18.79699026165903, -17.945090229238392))\n", 1468 | "([1, 5, 4, 5, 3, 4, 5, 3, 4], (-18.803585534273239, -17.95258394264377))\n", 1469 | "([1, 5, 4, 3, 4, 3, 5, 2, 3], (-19.181531846082809, -17.859420073785095))\n", 1470 | "([1, 5, 4, 1, 3, 4, 5, 2, 3, 2], (-19.439349296385199, -17.884502168470895))\n" 1471 | ] 1472 | } 1473 | ], 1474 | "source": [ 1475 | "from collections import defaultdict\n", 1476 | "\n", 1477 | "def prefix_beam_decode(y, beam_size=10, blank=0):\n", 1478 | " T, V = y.shape\n", 1479 | " log_y = np.log(y)\n", 1480 | " \n", 1481 | " beam = [(tuple(), (0, ninf))] # blank, non-blank\n", 1482 | " for t in range(T): # for every timestep\n", 1483 | " new_beam = defaultdict(lambda : (ninf, ninf))\n", 1484 | " \n", 1485 | " for prefix, (p_b, p_nb) in beam:\n", 1486 | " for i in range(V): # for every state\n", 1487 | " p = log_y[t, i]\n", 1488 | " \n", 1489 | " if i == blank: # propose a blank\n", 1490 | " new_p_b, new_p_nb = new_beam[prefix]\n", 1491 | " new_p_b = logsumexp(new_p_b, p_b + p, p_nb + p)\n", 1492 | " new_beam[prefix] = (new_p_b, new_p_nb)\n", 1493 | " continue\n", 1494 | " else: # extend with non-blank\n", 1495 | " end_t = prefix[-1] if prefix else None\n", 1496 | " \n", 1497 | " # exntend current prefix\n", 1498 | " new_prefix = prefix + (i,)\n", 1499 | " new_p_b, new_p_nb = new_beam[new_prefix]\n", 1500 | " if i != end_t:\n", 1501 | " new_p_nb = logsumexp(new_p_nb, p_b + p, p_nb + p)\n", 1502 | " else:\n", 1503 | " new_p_nb = logsumexp(new_p_nb, p_b + p)\n", 1504 | " new_beam[new_prefix] = (new_p_b, new_p_nb)\n", 1505 | " \n", 1506 | " # keep current prefix\n", 1507 | " if i == end_t:\n", 1508 | " new_p_b, new_p_nb = new_beam[prefix]\n", 1509 | " new_p_nb = logsumexp(new_p_nb, p_nb + p)\n", 1510 | " new_beam[prefix] = (new_p_b, new_p_nb)\n", 1511 | " \n", 1512 | " # top beam_size\n", 1513 | " beam = sorted(new_beam.items(), key=lambda x : logsumexp(*x[1]), reverse=True)\n", 1514 | " beam = beam[:beam_size]\n", 1515 | " \n", 1516 | " return beam\n", 1517 | "\n", 1518 | "np.random.seed(1111)\n", 1519 | "y = softmax(np.random.random([20, 6]))\n", 1520 | "beam = prefix_beam_decode(y, beam_size=100)\n", 1521 | "for string, score in beam[:20]:\n", 1522 | " print(remove_blank(string), score)" 1523 | ] 1524 | }, 1525 | { 1526 | "cell_type": "markdown", 1527 | "metadata": {}, 1528 | "source": [ 1529 | "## 3.4 其他解码方法\n", 1530 | "上述介绍了基本解码方法。实际中,搜索过程可以接合额外的信息,提高搜索的准确度(例如在语音识别任务中,加入语言模型得分信息, [前缀搜索+语言模型](https://github.com/PaddlePaddle/DeepSpeech/blob/develop/decoders/decoders_deprecated.py\n", 1531 | "))。\n", 1532 | "\n", 1533 | "本质上,CTC 只是一个训练准则。训练完成后,$N_w$ 输出一系列概率分布,这点和常规基于交叉熵准则训练的模型完全一致。因此,特定应用领域常规的解码也可以经过一定修改用来 CTC 
的解码。例如在语音识别任务中,利用 CTC 训练的声学模型可以无缝融入原来的 WFST 的解码器中[5](e.g. 参见 [EESEN](https://github.com/srvk/eesen))。\n", 1534 | "\n", 1535 | "此外,[1] 给出了一种利用 CTC 顶峰特点的启发式搜索方法。" 1536 | ] 1537 | }, 1538 | { 1539 | "cell_type": "markdown", 1540 | "metadata": {}, 1541 | "source": [ 1542 | "# 4. 工具\n", 1543 | "\n", 1544 | "[warp-ctc](https://github.com/baidu-research/warp-ctc) 是百度开源的基于 CPU 和 GPU 的高效并行实现。warp-ctc 自身提供 C 语言接口,对于流行的机器学习工具([torch](https://github.com/baidu-research/warp-ctc/tree/master/torch_binding)、 [pytorch](https://github.com/SeanNaren/deepspeech.pytorch) 和 [tensorflow](https://github.com/baidu-research/warp-ctc/tree/master/tensorflow_binding)、[chainer](https://github.com/jheymann85/chainer_ctc))都有相应的接口绑定。\n", 1545 | "\n", 1546 | "[cudnn 7](https://developer.nvidia.com/cudnn) 以后开始提供 CTC 支持。\n", 1547 | "\n", 1548 | "Tensorflow 也原生支持 [CTC loss](https://www.tensorflow.org/api_docs/python/tf/nn/ctc_loss),及 greedy 和 beam search 解码器。" 1549 | ] 1550 | }, 1551 | { 1552 | "cell_type": "markdown", 1553 | "metadata": {}, 1554 | "source": [ 1555 | "# 小结\n", 1556 | "1. CTC 可以建模无对齐信息的多对多序列问题(输入长度不小于输出),如语音识别、连续字符识别 [3,4]。\n", 1557 | "2. CTC 不需要输入与输出的对齐信息,可以实现端到端的训练。\n", 1558 | "3. CTC 在 loss 的计算上,利用了整个 labels 序列的全局信息,某种意义上相对逐帧计算损失的方法,\"更加区分性\"。" 1559 | ] 1560 | }, 1561 | { 1562 | "cell_type": "markdown", 1563 | "metadata": {}, 1564 | "source": [ 1565 | "# References\n", 1566 | "1. Graves et al. [Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks](ftp://ftp.idsia.ch/pub/juergen/icml2006.pdf).\n", 1567 | "2. Hannun et al. [First-Pass Large Vocabulary Continuous Speech Recognition using Bi-Directional Recurrent DNNs](https://arxiv.org/abs/1408.2873).\n", 1568 | "3. Graves et al. [Towards End-To-End Speech Recognition with Recurrent Neural Networks](http://jmlr.org/proceedings/papers/v32/graves14.pdf).\n", 1569 | "4. Liwicki et al. [A novel approach to on-line handwriting recognition based on bidirectional long short-term memory networks](https://www.cs.toronto.edu/~graves/icdar_2007.pdf).\n", 1570 | "5. Zenkel et al. [Comparison of Decoding Strategies for CTC Acoustic Models](https://arxiv.org/abs/1708.04469).\n", 1571 | "6. Hannun. [Sequence Modeling with CTC](https://distill.pub/2017/ctc/)." 1572 | ] 1573 | } 1574 | ], 1575 | "metadata": { 1576 | "anaconda-cloud": {}, 1577 | "kernelspec": { 1578 | "display_name": "Python [Root]", 1579 | "language": "python", 1580 | "name": "Python [Root]" 1581 | }, 1582 | "language_info": { 1583 | "codemirror_mode": { 1584 | "name": "ipython", 1585 | "version": 2 1586 | }, 1587 | "file_extension": ".py", 1588 | "mimetype": "text/x-python", 1589 | "name": "python", 1590 | "nbconvert_exporter": "python", 1591 | "pygments_lexer": "ipython2", 1592 | "version": "2.7.12" 1593 | } 1594 | }, 1595 | "nbformat": 4, 1596 | "nbformat_minor": 0 1597 | } 1598 | -------------------------------------------------------------------------------- /gmm/README.md: -------------------------------------------------------------------------------- 1 | K-Means and Gaussian Mixture Models (GMM) with EM Training. 
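A minimal usage sketch of `gmm.py` is shown below; the toy data, random seed and parameter values are illustrative assumptions only (they are not taken from the code in this directory), and it is meant to be run from inside this directory:

```python
import numpy as np
from gmm import Gauss, GMM, train_gmm

np.random.seed(1111)
# toy data: two Gaussian blobs in 2-D
x = np.concatenate([np.random.randn(100, 2),
                    np.random.randn(100, 2) + 3.0], axis=0)

# initialize two components at randomly chosen data points, with unit covariances
dim, k = 2, 2
means = x[np.random.choice(len(x), k, replace=False)]
gmm = GMM([Gauss(dim, mean=m, cov=np.eye(dim)) for m in means])

train_gmm(gmm, x, max_iter=50, threshold=1e-4)  # EM training, updates gmm in place
print('avg log-likelihood:', gmm.llk(x))
print('component weights:', gmm.weight)
print('component means:', [gmm[i].mean for i in range(gmm.k)])
```

The `demo()` functions in `kmeans.py` and `gmm.py` show the same flow with a k-means-based initialization and matplotlib plots.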
2 | 3 | -------------------------------------------------------------------------------- /gmm/gmm.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | from scipy.stats import multivariate_normal 5 | 6 | import pytest 7 | from numpy.testing import assert_allclose 8 | 9 | 10 | EPS = 1e-8 11 | 12 | 13 | class Gauss(object): 14 | ''' 15 | ''' 16 | 17 | def __init__(self, dim, mean=None, cov=None): 18 | self.dim = dim 19 | 20 | if mean is None: 21 | self.mean = np.zeros(dim) 22 | else: 23 | assert len(mean) == dim, "Dim not match" 24 | self.mean = mean 25 | 26 | if cov is None: 27 | self.cov = np.eye(dim) 28 | else: 29 | self.cov = cov 30 | 31 | self.rv = multivariate_normal(self.mean, self.cov) 32 | 33 | def update(self, mean, cov): 34 | self.mean, self.cov = mean, cov 35 | self.rv = multivariate_normal(self.mean, self.cov) 36 | 37 | def pdf(self, x): 38 | return self.rv.pdf(x) 39 | 40 | def __call__(self, x): 41 | return self.pdf(x) 42 | 43 | 44 | class GMM(object): 45 | ''' 46 | ''' 47 | 48 | def __init__(self, gauss, weight=[]): 49 | self.gauss = gauss 50 | self.weight = weight or np.ones(len(gauss)) / len(gauss) 51 | 52 | @property 53 | def k(self): 54 | return len(self.gauss) 55 | 56 | def pdf(self, x): 57 | return sum([self.weight[i] * g(x) for i, g in enumerate(self.gauss)]) 58 | 59 | def __call__(self, x, i=None): 60 | if i is None: 61 | return self.pdf(x) 62 | else: 63 | return self.weight[i] * self.gauss[i](x) 64 | 65 | def __getitem__(self, i): 66 | assert i < self.k, 'Out of Index' 67 | return self.gauss[i] 68 | 69 | def llk(self, x): 70 | return np.mean([np.log(self.pdf(e)) for e in x]) 71 | 72 | 73 | def em_step(gmm, x): 74 | num = len(x) 75 | dim = x.shape[-1] 76 | k = gmm.k 77 | 78 | gamma = np.zeros((k, num)) 79 | 80 | # E 81 | for i in range(k): 82 | for j in range(num): 83 | gamma[i][j] = gmm(x[j], i) 84 | gamma /= np.sum(gamma, 0) 85 | 86 | # M 87 | gmm.weight = np.sum(gamma, 1) / num 88 | for i in range(k): 89 | mean = np.average(x, axis=0, weights=gamma[i]) 90 | cov = np.zeros((dim, dim)) 91 | for j in range(num): 92 | delta = x[j] - mean 93 | cov[:] += gamma[i][j] * np.outer(delta, delta) 94 | cov /= np.sum(gamma[i]) 95 | cov += np.eye(dim) * EPS # avoid singular 96 | gmm[i].update(mean, cov) 97 | 98 | return gmm 99 | 100 | 101 | def prune_gmm(gmm, min_k=1): 102 | '''TODO: prune GMM components 103 | ''' 104 | return gmm 105 | 106 | 107 | def train_gmm(gmm, x, max_iter=100, threshold=1e-3, min_k=1): 108 | cur_llk = -np.float('inf') 109 | for i in range(max_iter): 110 | gmm = em_step(gmm, x) 111 | cur_llk, last_llk = gmm.llk(x), cur_llk 112 | 113 | print("Iter {}, log likelihood {}.".format(i, cur_llk)) 114 | 115 | if cur_llk - last_llk < threshold: # early stop 116 | break 117 | 118 | gmm = prune_gmm(gmm, min_k) 119 | 120 | 121 | def test_gauss(): 122 | dims = range(1, 5) 123 | 124 | # default 125 | for dim in dims: 126 | g = Gauss(dim) 127 | assert_allclose(g.mean, np.zeros(dim)) 128 | assert_allclose(g.cov, np.eye(dim)) 129 | 130 | x = np.random.random(dim) 131 | print(dim, g.pdf(x)) 132 | 133 | # pass 134 | for dim in dims: 135 | mean = np.random.random(dim) 136 | cov = np.random.random([dim, dim]) 137 | cov = np.matmul(cov, cov.T) 138 | 139 | g = Gauss(dim, mean, cov) 140 | assert_allclose(mean, g.mean) 141 | assert_allclose(cov, g.cov) 142 | 143 | x = np.random.random(dim) 144 | print(dim, g(x)) 145 | 146 | 147 | def test_gmm(): 148 | dims = range(1, 5) 149 | ks = 
range(1, 5) 150 | 151 | for dim in dims: 152 | for k in ks: 153 | print('Dim {}, K {}'.format(dim, k)) 154 | gs = [] 155 | for i in range(k): 156 | mean = np.random.random(dim) 157 | cov = np.random.random([dim, dim]) 158 | cov = np.matmul(cov, cov.T) 159 | gs.append(Gauss(dim, mean, cov)) 160 | gmm = GMM(gs) 161 | 162 | assert k == gmm.k 163 | 164 | x = np.random.random(dim) 165 | assert gmm.pdf(x) == gmm(x) 166 | for i in range(k): 167 | print('Component {}, {}'.format(i, gmm(x, i))) 168 | print('log likelihood: %.2f' % gmm.llk(np.expand_dims(x, 0))) 169 | 170 | 171 | def test_em_step(): 172 | np.random.seed(1111) 173 | dims = range(1, 5) 174 | ks = range(1, 5) 175 | 176 | for dim in dims: 177 | for k in ks: 178 | print('Dim {}, K {}'.format(dim, k)) 179 | gs = [] 180 | for i in range(k): 181 | mean = np.random.random(dim) 182 | cov = np.random.random([dim, dim]) 183 | cov = np.matmul(cov, cov.T) 184 | gs.append(Gauss(dim, mean, cov)) 185 | gmm = GMM(gs) 186 | 187 | x = np.random.random([1000, dim]) 188 | em_step(gmm, x) 189 | 190 | 191 | def test_train_gmm(): 192 | np.random.seed(1111) 193 | dims = range(1, 5) 194 | ks = range(1, 5) 195 | 196 | for dim in dims: 197 | for k in ks: 198 | print('Dim {}, K {}'.format(dim, k)) 199 | gs = [] 200 | for i in range(k): 201 | mean = np.random.random(dim) 202 | cov = np.random.random([dim, dim]) 203 | cov = np.matmul(cov, cov.T) 204 | gs.append(Gauss(dim, mean, cov)) 205 | gmm = GMM(gs) 206 | 207 | x = np.random.random([100, dim]) 208 | train_gmm(gmm, x, threshold=1e-2) 209 | 210 | 211 | def demo(): 212 | import matplotlib.pyplot as plt 213 | from kmeans import kmeans_cluster 214 | 215 | np.random.seed(1111) 216 | dim, k = 2, 2 217 | 218 | # generate data 219 | num = 50 220 | mean1 = np.zeros(dim) 221 | mean2 = np.ones(dim) * 2 222 | cov1 = np.eye(dim) 223 | cov2 = np.eye(dim) * 0.5 224 | 225 | x1 = np.random.multivariate_normal(mean1, cov1, [num, ]) 226 | x2 = np.random.multivariate_normal(mean2, cov2, [num, ]) 227 | x = np.concatenate([x1, x2], 0) 228 | 229 | plt.scatter(x1[:, 0], x1[:, 1], c='r') 230 | plt.scatter(x2[:, 0], x2[:, 1], c='g') 231 | 232 | # init GMM with kmeans 233 | gs = [] 234 | centers, assignment = kmeans_cluster(x, k) 235 | weight = [] 236 | for i in range(k): 237 | # mean 238 | mean = centers[i] 239 | 240 | # covariate 241 | cov = np.eye(dim) * 1e-6 242 | count = 0. 
243 | for j in range(num * 2): 244 | if i == assignment[j]: 245 | cov += np.outer(mean - x[j], mean - x[j]) 246 | count += 1 247 | cov /= count 248 | weight.append(count / len(x)) 249 | 250 | gs.append(Gauss(dim, mean, cov)) 251 | gmm = GMM(gs, weight) 252 | centers = np.stack([gmm[i].mean for i in range(gmm.k)]) 253 | plt.scatter(centers[:, 0], centers[:, 1], c='b', s=50, marker='v') 254 | 255 | train_gmm(gmm, x, threshold=1e-4) 256 | centers = np.stack([gmm[i].mean for i in range(gmm.k)]) 257 | plt.scatter(centers[:, 0], centers[:, 1], c='y', s=500, marker='^') 258 | 259 | 260 | if __name__ == '__main__': 261 | pytest.main([__file__, '-s']) 262 | -------------------------------------------------------------------------------- /gmm/kmeans.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | 5 | import pytest 6 | 7 | 8 | EPS = 1e-8 9 | 10 | 11 | def kmeans_cluster(x, k, max_iter=10, threshold=1e-3, verbose=False): 12 | # init 13 | centers = np.zeros([k, x.shape[-1]]) 14 | for i in range(k): 15 | total_num = len(x) 16 | chosen_num = max(1, total_num / k) 17 | random_ids = np.random.choice(total_num, chosen_num, replace=False) 18 | centers[i, :] = np.mean(x[random_ids]) 19 | 20 | cur_total_dist = np.float('inf') 21 | dist = np.zeros([k, len(x)]) 22 | for i in range(max_iter): 23 | for j in range(k): 24 | for m, p in enumerate(x): 25 | dist[j, m] = np.mean((p - centers[j]) ** 2) 26 | 27 | min_idx = np.argmin(dist, 0) 28 | for j in range(k): 29 | ele = [x[m] for m, idx in enumerate(min_idx) if idx == j] 30 | centers[j, :] = np.mean(ele) 31 | 32 | total_dist = 0 33 | for j in range(k): 34 | dist_j = [np.mean((x[m] - centers[j]) ** 2) 35 | for m, idx in enumerate(min_idx) if idx == j] 36 | total_dist += sum(dist_j) 37 | 38 | cur_total_dist, last_total_dist = total_dist, cur_total_dist 39 | if verbose: 40 | print('Iter: {}, total dist: {}'.format(i, total_dist)) 41 | 42 | if last_total_dist - cur_total_dist < threshold: 43 | break 44 | 45 | for j in range(k): 46 | for m, p in enumerate(x): 47 | dist[j, m] = np.mean((p - centers[j]) ** 2) 48 | min_idx = np.argmin(dist, 0) 49 | 50 | return centers, min_idx 51 | 52 | 53 | def test_kmeans_cluster(): 54 | np.random.seed(1111) 55 | dims = range(1, 5) 56 | ks = range(2, 5) 57 | 58 | for dim in dims: 59 | for k in ks: 60 | print('Dim {}, K {}'.format(dim, k)) 61 | x = np.random.random([100, dim]) 62 | kmeans_cluster(x, k, max_iter=10, threshold=1e-2, verbose=True) 63 | 64 | 65 | def demo(): 66 | import matplotlib.pyplot as plt 67 | 68 | np.random.seed(1111) 69 | dim, k = 2, 2 70 | 71 | # generate data 72 | num = 50 73 | mean1 = np.zeros(dim) 74 | mean2 = np.ones(dim) * 2 75 | cov1 = np.eye(dim) 76 | cov2 = np.eye(dim) * 0.5 77 | 78 | x1 = np.random.multivariate_normal(mean1, cov1, [num, ]) 79 | x2 = np.random.multivariate_normal(mean2, cov2, [num, ]) 80 | x = np.concatenate([x1, x2], 0) 81 | 82 | plt.scatter(x1[:, 0], x1[:, 1], c='r') 83 | plt.scatter(x2[:, 0], x2[:, 1], c='g') 84 | 85 | centers, _ = kmeans_cluster(x, k, max_iter=10, threshold=1e-4) 86 | plt.scatter(centers[:, 0], centers[:, 1], c='y', s=500, marker='^') 87 | 88 | 89 | if __name__ == '__main__': 90 | pytest.main([__file__, '-s']) 91 | -------------------------------------------------------------------------------- /gmm/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | pytest 4 | 
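A minimal end-to-end usage sketch for gmm.py and kmeans.py above (this example is not part of the repository; the toy data, seed and component count are illustrative, and it assumes the Python 2 / older-NumPy environment the modules target, e.g. the integer division inside kmeans_cluster and np.float):

import numpy as np

from kmeans import kmeans_cluster
from gmm import Gauss, GMM, train_gmm

np.random.seed(0)

# toy data: two well-separated 2-D clusters
x = np.concatenate([
    np.random.multivariate_normal(np.zeros(2), np.eye(2), 50),
    np.random.multivariate_normal(np.ones(2) * 3, np.eye(2) * 0.5, 50),
])

k = 2
centers, assignment = kmeans_cluster(x, k)        # k-means provides the initial means
gs = [Gauss(2, centers[i]) for i in range(k)]     # start from unit covariances
gmm = GMM(gs)

train_gmm(gmm, x, max_iter=50, threshold=1e-4)    # EM refinement; gmm is updated in place
print([gmm[i].mean for i in range(gmm.k)])        # refined component means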
-------------------------------------------------------------------------------- /gumbel/gumbel-distribution.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "import matplotlib\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "%matplotlib inline" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "# 1. 分布的形式化\n", 22 | "\n", 23 | "## 物理意义\n", 24 | "\n", 25 | "[Gumbel](https://en.wikipedia.org/wiki/Gumbel_distribution) 分布是一种极值型分布。举例而言,假设每次测量心率值为一个随机变量(服从某种[指数族分布](https://en.wikipedia.org/wiki/Exponential_family),如正态分布),每天测量10次心率并取最大的一个心率值作为当天的心率测量值。显然,每天纪录的心率值也是一个随机变量,并且它的概率分布即为 Gumbel 分布。\n", 26 | "\n", 27 | "\n", 28 | "## 概率密度函数(PDF)\n", 29 | "\n", 30 | "Gumbel 分布的 PDF 如下:\n", 31 | "\n", 32 | "$$f(x;\\mu,\\beta) = e^{-z-e^{-z}},\\ z= \\frac{x - \\mu}{\\beta}$$\n", 33 | "\n", 34 | "公式中,$\\mu$ 是位置系数(Gumbel 分布的众数是 $\\mu$),$\\beta$ 是尺度系数(Gumbel 分布的方差是 $\\frac{\\pi^2}{6}\\beta^2$)。\n", 35 | "\n", 36 | "![PDF](https://upload.wikimedia.org/wikipedia/commons/thumb/3/32/Gumbel-Density.svg/488px-Gumbel-Density.svg.png)\n", 37 | "**Gumble PDF 示例图【[src](https://en.wikipedia.org/wiki/Gumbel_distribution)】**\n", 38 | "\n" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 2, 44 | "metadata": { 45 | "collapsed": false 46 | }, 47 | "outputs": [ 48 | { 49 | "name": "stdout", 50 | "output_type": "stream", 51 | "text": [ 52 | "0.183939720586\n" 53 | ] 54 | } 55 | ], 56 | "source": [ 57 | "def gumbel_pdf(x, mu=0, beta=1):\n", 58 | " z = (x - mu) / beta\n", 59 | " return np.exp(-z - np.exp(-z)) / beta\n", 60 | "\n", 61 | "print(gumbel_pdf(0.5, 0.5, 2))" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "## 累计密度函数(CDF)\n", 69 | "相应的,gumbel 分布的 CDF 的公式如下:\n", 70 | "\n", 71 | "$$F(z;\\mu,\\beta) = e^{-e^{-(x-\\mu)/\\beta}}$$\n", 72 | "\n" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 3, 78 | "metadata": { 79 | "collapsed": false 80 | }, 81 | "outputs": [ 82 | { 83 | "name": "stdout", 84 | "output_type": "stream", 85 | "text": [ 86 | "0.899965162661\n" 87 | ] 88 | } 89 | ], 90 | "source": [ 91 | "def gumbel_cdf(x, mu=0, beta=1):\n", 92 | " z = (x - mu) / beta\n", 93 | " return np.exp(-np.exp(-z))\n", 94 | " \n", 95 | "print(gumbel_cdf(5, 0.5, 2)) " 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "## CDF 的反函数\n", 103 | "根据 CDF 容易得到其反函数:\n", 104 | "\n", 105 | "$$F^{-1}(y;\\mu,\\beta) = \\mu - \\beta \\ln(-\\ln(y))$$\n", 106 | "\n", 107 | "我们可以利用反函数法和生成服从 Gumbel 分布的随机数。\n" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 4, 113 | "metadata": { 114 | "collapsed": false 115 | }, 116 | "outputs": [ 117 | { 118 | "name": "stdout", 119 | "output_type": "stream", 120 | "text": [ 121 | "5.0\n", 122 | "[[ 0.29569995 0.02471482 -1.7583011 ]\n", 123 | " [-0.25833806 0.10539504 2.66052767]]\n" 124 | ] 125 | } 126 | ], 127 | "source": [ 128 | "def inv_gumbel_cdf(y, mu=0, beta=1, eps=1e-20):\n", 129 | " return mu - beta * np.log(-np.log(y + eps))\n", 130 | "\n", 131 | "print(inv_gumbel_cdf(gumbel_cdf(5, 0.5, 2), 0.5, 2))\n", 132 | "\n", 133 | "def sample_gumbel(shape):\n", 134 | " p = np.random.random(shape)\n", 135 | " return inv_gumbel_cdf(p)\n", 136 | "\n", 137 | 
"print(sample_gumbel([2,3]))" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "# 2. Gumbel-Max\n", 145 | "\n", 146 | "Gumbel 随机数可以用来对多项分布进行采样。\n", 147 | "\n", 148 | "## 2.1 基于 softmax 的采样\n", 149 | "\n", 150 | "首先来看常规的采样方法。\n", 151 | "\n", 152 | "对于 $logits = (x_1, \\dots, x_K)$,首先利用 softmax 运算得到规一化的概率分布(多项分布)。\n", 153 | "\n", 154 | "$$\\pi_k = \\frac{e^{x_k}}{\\sum_{k^\\prime=1}^{K} e^{x_{k^\\prime}}}$$\n", 155 | "\n", 156 | "然后,利用轮盘赌的方式采样。下面的代码,直接使用 numpy 的 [choice](https://docs.scipy.org/doc/numpy/reference/generated/numpy.random.choice.html) 函数实现。\n", 157 | "\n" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 5, 163 | "metadata": { 164 | "collapsed": false 165 | }, 166 | "outputs": [], 167 | "source": [ 168 | "def softmax(logits):\n", 169 | " max_value = np.max(logits)\n", 170 | " exp = np.exp(logits - max_value)\n", 171 | " exp_sum = np.sum(exp)\n", 172 | " dist = exp / exp_sum\n", 173 | " return dist\n", 174 | "\n", 175 | "def roulette(p):\n", 176 | " p = np.asarray(p)\n", 177 | " cdf = p.cumsum()\n", 178 | " r = np.random.random()\n", 179 | " for i in range(len(cdf)):\n", 180 | " if r <= cdf[i]: break\n", 181 | " return i\n", 182 | "\n", 183 | "def sample_with_softmax(logits, size):\n", 184 | " '''\n", 185 | " pros = softmax(logits)\n", 186 | "\n", 187 | " ret = np.empty(np.product(size)).astype('int')\n", 188 | " for i in range(len(ret)):\n", 189 | " ret[i] = roulette(pros)\n", 190 | " \n", 191 | " return ret.reshape(size)\n", 192 | " '''\n", 193 | " \n", 194 | " pros = softmax(logits)\n", 195 | " return np.random.choice(len(logits), size, p=pros)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "## 2.2 基于 gumbel 的采样(gumbel-max)\n", 203 | "对于某组 logits,生成相同数量的 gumbel 随机数,并加到 logits 上。 然后选择数值最大的元素的编号作为采样值。\n", 204 | "示例代码如下:" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 6, 210 | "metadata": { 211 | "collapsed": false 212 | }, 213 | "outputs": [], 214 | "source": [ 215 | "def sample_with_gumbel_noise(logits, size):\n", 216 | " noise = sample_gumbel((size, len(logits)))\n", 217 | " return np.argmax(logits + noise, axis=1)" 218 | ] 219 | }, 220 | { 221 | "cell_type": "markdown", 222 | "metadata": {}, 223 | "source": [ 224 | "可以[证明](https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/),gumbel-max 方法的采样效果等效于基于 softmax 的方式。下面的实验直观地展示两种方法的采样效果。\n" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 7, 230 | "metadata": { 231 | "collapsed": false 232 | }, 233 | "outputs": [ 234 | { 235 | "data": { 236 | "text/plain": [ 237 | "(array([ 4358., 22962., 7143., 6761., 3638., 5848., 5946.,\n", 238 | " 15969., 9951., 17424.]),\n", 239 | " array([ 0. , 0.9, 1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1, 9. 
]),\n", 240 | " )" 241 | ] 242 | }, 243 | "execution_count": 7, 244 | "metadata": {}, 245 | "output_type": "execute_result" 246 | }, 247 | { 248 | "data": { 249 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYcAAAD8CAYAAACcjGjIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAEXtJREFUeJzt3W2MXOV5h/Hrrp20ecHCDmvL9ZoaKqsxGMVxLHBLZIW6\nAQOVTUMTYaXBgCVXCNqERGo3+WJIGuRIkAINRaKJG5OmpogkslUcE8uNFTUKFDu4vNhJ7RAXL976\nJUsIEKmB5O6HOUsnfsa769ndmd2d6yetZuaeM+d5xrqH/zxnzgyRmUiSVO832j0BSdL4YzhIkgqG\ngySpYDhIkgqGgySpYDhIkgqGgySpYDhIkgqGgySpMLXdE2jWWWedlfPmzWv3NDRJ7dmz50RmdrV6\nXPtaY2nPnj0/A76XmSuG2nbChsO8efPYvXt3u6ehSSoi/rsd49rXGksRcWA4wQAeVpIkNWA4SJIK\nhoMkqWA4SJIKhoMkqWA4SJIKhoMkqWA4SJIKhoMkqTBhvyE92ub1PNLU4w5tuHKUZyKNLntbzXDl\nIEkqGA6SpILhIEkqGA6SpILhIEkqGA6SpILhIEkqGA6SpILhIEkqGA6SpILhIEkqGA6SpILhIEkq\nGA6SpILhIEkqGA6SpILhIEkqGA6SpILhIEkqGA6SpILhIEkqGA6SpMKQ4RARcyPi2xGxPyKejYiP\nVvUZEbEjIg5Ul9OrekTEPRFxMCKeiojFdftaU21/ICLW1NXfExFPV4+5JyJiLJ6sVO/w4cNccskl\nLFiwgPPPP5+7774bgP7+foD59rY62XBWDq8Dn8jMBcBS4KaIOA/oAXZm5nxgZ3Ub4HJgfvW3DrgP\namECrAcuAi4E1g+86Kpt1tU9bsXIn5o0uKlTp3LnnXeyf/9+HnvsMe6991727dvHhg0bAF62t9XJ\nhgyHzOzLzO9X118G9gNzgFXApmqzTcBV1fVVwANZ8xhwZkTMBi4DdmRmf2a+COwAVlT3TcvM72Vm\nAg/U7UsaM7Nnz2bx4tqb/zPOOIMFCxbwwgsvsGXLFoCfVJvZ2+pIp/WZQ0TMA94NPA7Mysw+qAUI\nMLPabA5wuO5hvVVtsHpvg7rUMocOHeLJJ5/koosu4ujRowCvQWt6OyLWRcTuiNh9/PjxUXpG0sgM\nOxwi4u3A14CPZebPBtu0QS2bqDeagy8ijbpXXnmFq6++mrvuuotp06YNtumY9HZm3p+ZSzJzSVdX\n17DmLI21YYVDRLyJWjB8NTO/XpWPVstmqstjVb0XmFv38G7gyBD17gb1gi8ijbbXXnuNq6++mg9/\n+MN84AMfAGDWrFkAb4LW9bY03gznbKUAvgTsz8zP1921FRg4K2MNsKWufm11ZsdS4KVqaf4ocGlE\nTK8+rLsUeLS67+WIWFqNdW3dvqQxk5msXbuWBQsW8PGPf/yN+sqVKwHeUd20t9WRpg5jm4uBjwBP\nR8TeqvYpYAPwUESsBZ4HPljdtw24AjgI/By4HiAz+yPiM8AT1Xafzsz+6vqNwJeBtwDfrP6kMfXd\n736Xr3zlK1xwwQUsWrQIgNtvv52enh7uuOOOaRFxAHtbHWrIcMjMf6fxsVOA5Q22T+CmU+xrI7Cx\nQX03sHCouUij6b3vfS+1dm3ovzJzSX3B3lYn8RvSkqSC4SBJKhgOkqSC4SBJKhgOkqSC4SBJKhgO\nkqSC4SBJKhgOkqSC4SBJKhgOkqSC4SBJKhgOkqSC4SBJKhgOkqSC4SBJKhgOkqSC4SBJKgzn/yEt\nSWqDeT2PNPW4QxuuHPHYrhwkSQVXDprw2vnuSpqsXDlIkgqGgySpYDhIkgqGgySpYDhIkgqGgySp\nYDhIkgqGgySpYDhIkgp+Q1qjzm8sSxOf4SBJw9RJb3w8rCRJKhgOkqSC4SBJKhgOkqTCkOEQERsj\n4lhEPFNXuzUiXoiIvdXfFXX3fTIiDkbEDyPisrr6iqp2MCJ66urnRMTjEXEgIv4lIt48mk9QOpUb\nbriBmTNnsnDhwjdqt956K3PmzAE4z95WJxvOyuHLwIoG9b/NzEXV3zaAiDgPuAY4v3rM30fElIiY\nAtwLXA6cB6yutgX4XLWv+cCLwNqRPCFpuK677jq2b99e1G+55RaAffa2OtmQ4ZCZ3wH6h7m/VcCD\nmfm/mflj4CBwYfV3MDOfy8xfAA8CqyIigD8EHq4evwm46jSfg9SUZcuWMWPGjOFubm+ro4zkM4eb\nI+Kp6rDT9Ko2Bzhct01vVTtV/R3ATzPz9ZPqUtt84QtfgNphJXtbHavZcLgP+F1gEdAH3FnVo8G2\n2US9oYhYFxG7I2L38ePHT2/G0jDceOON/OhHPwLYR4t6277WeNRUOGTm0cz8ZWb+CvgHaktrqL07\nmlu3aTdwZJD6CeDMiJh6Uv1U496fmUsyc0lXV1czU5cGNWvWLKZMmTJwsyW9bV9rPGoqHCJidt3N\nPwEGzmTaClwTEb8ZEecA84H/AJ4A5ldnb7yZ2gd7WzMzgW8Df1o9fg2wpZk5SaOhr6+v/qa9rY41\n5G8rRcRm4H3AWRHRC6wH3hcRi6gtkw8Bfw6Qmc9GxEPUluSvAzdl5i+r/dwMPApMATZm5rPVEH8N\nPBgRfwM8CXxp1J6dNIjVq1eza9cuTpw4QXd3N7fddhu7du1i7969UDvz6BLsbXWoIcMhM1c3KJ+y\nyTPzs8BnG9S3Adsa1J/j/5fuUsts3ry5qK1dWzvbNCL2ZebK+vvsbXUSvyEtSSoYDpKkguEgSSoY\nDpKkguEgSSoYDpKkguEgSSoYDpKkguEgSSoYDpKkguEgSSoYDpKkguEgSSoYDpKkguEgSSoYDpKk\nguEgSSoYDpKkguEgSSoYDpKkguEgSSoYDpKkguEgSSoYDpKkguEgSSoYDpKkguEgSSoYDpKkguEg\nSSoYDpKkguEgSSoYDpKkguEgSSoYDpKkguEgSSoYDpKkwpDhEBEbI+JYRDxTV5sRETsi4kB1Ob2q\nR0TcExEHI+KpiFhc95g11fYHImJNXf09EfF09Zh7IiJG+0lKjdxwww3MnDmThQsXvlHr7+/n/e9/\nP8BCe1udbDgrhy8DK06q9QA7M3M+sLO6DXA5ML/6WwfcB7UwAdYDFwEXAusHXnTVNuvqHnfyWNKY\nuO6669i+ffuv1TZs2MDy5csBnsHeVgcbMhwy8ztA/0nlVcCm6vom4Kq6+gNZ8xhwZkTMBi4DdmRm\nf2a+COwAVlT3TcvM72VmAg/U7UsaU8uWLWPGjBm/VtuyZQtr1rzx5t/eVsdq9jOHWZnZB1Bdzqzq\nc4DDddv1VrXB6r0N6lJbHD16
lNmzZwP2tjrb1FHeX6NjqtlEvfHOI9ZRW6Zz9tlnNzM/qVlj1tv2\n9eQ3r+eRdk/htDW7cjhaLZupLo9V9V5gbt123cCRIerdDeoNZeb9mbkkM5d0dXU1OXXp1GbNmkVf\nXx/Qut62rzUeNbty2AqsATZUl1vq6jdHxIPUPqB7KTP7IuJR4Pa6D+ouBT6Zmf0R8XJELAUeB64F\n/q7JOWmCGw/vrlauXMmmTQMfp9nb6lxDhkNEbAbeB5wVEb3UzszYADwUEWuB54EPVptvA64ADgI/\nB64HqF4onwGeqLb7dGYOfMh9I7Uzot4CfLP6k8bc6tWr2bVrFydOnKC7u5vbbruNnp4ePvShDwEs\nBF7C3laHGjIcMnP1Ke5a3mDbBG46xX42Ahsb1HdTeyFKLbV58+aG9Z07dxIRz2TmGz1ub6vT+A1p\nSVLBcJAkFQwHSVLBcJAkFQwHSVLBcJAkFQwHSVLBcJAkFQwHSVLBcJAkFQwHSVLBcJAkFQwHSVLB\ncJAkFQwHSVLBcJAkFQwHSVLBcJAkFYb834RqcPN6HmnqcYc2XDnKM5Gk0ePKQZJUcOUgqSFXxZ3N\nlYMkqWA4SJIKhoMkqeBnDm3i8VxNVvb25ODKQZJUMBwkSQXDQZJUMBwkSQXDQZJUMBwkSQVPZZXU\ncZo93baTGA4dwheDJiP7eux4WEmSVDAcJEkFw0GSVBhROETEoYh4OiL2RsTuqjYjInZExIHqcnpV\nj4i4JyIORsRTEbG4bj9rqu0PRMSakT0laVRcYG+rk43GyuGSzFyUmUuq2z3AzsycD+ysbgNcDsyv\n/tYB90HtBQesBy4CLgTWD7zopDazt9WxxuKw0ipgU3V9E3BVXf2BrHkMODMiZgOXATsysz8zXwR2\nACvGYF7SSNnb6hgjDYcEvhUReyJiXVWblZl9ANXlzKo+Bzhc99jeqnaqutRu9rY61ki/53BxZh6J\niJnAjoj4wSDbRoNaDlIvd1B7ka4DOPvss093rtLp+EFmLm5Fb9vXGo9GtHLIzCPV5THgG9SOqx6t\nltRUl8eqzXuBuXUP7waODFJvNN79mbkkM5d0dXWNZOrSUF6D1vS2fa3xqOlwiIi3RcQZA9eBS4Fn\ngK3AwFkZa4At1fWtwLXVmR1LgZeqpfmjwKURMb36sO7Sqia1xauvvgrVa8PeVqcayWGlWcA3ImJg\nP/+cmdsj4gngoYhYCzwPfLDafhtwBXAQ+DlwPUBm9kfEZ4Anqu0+nZn9I5iXNCJHjx4FeGdE/Cf2\ntjpU0+GQmc8B72pQ/wmwvEE9gZtOsa+NwMZm5yKNpnPPPRdgX90prIC9rc7iD+9JE4Q/MqdWmpTh\n4ItIkkbG31aSJBUMB0lSwXCQJBUMB0lSwXCQJBUm5dlKk5lnYmmysrfHF1cOkqSC4SBJKhgOkqSC\n4SBJKhgOkqSC4SBJKhgOkqSC4SBJKhgOkqSC4SBJKhgOkqSC4SBJKhgOkqSC4SBJKhgOkqSC4SBJ\nKhgOkqSC4SBJKhgOkqSC4SBJKhgOkqSC4SBJKhgOkqSC4SBJKhgOkqSC4SBJKhgOkqSC4SBJKoyb\ncIiIFRHxw4g4GBE97Z6PNFrsbU1E4yIcImIKcC9wOXAesDoizmvvrKSRs7c1UY2LcAAuBA5m5nOZ\n+QvgQWBVm+ckjQZ7WxPSeAmHOcDhutu9VU2a6OxtTUhT2z2BSjSoZbFRxDpgXXXzlYj44Sn2dxZw\nYpTmdjraNa5jNyE+N+jdv9PMPhsN06D2a709Afq6U8eesM95kN6eHxHbM3PFUPsYL+HQC8ytu90N\nHDl5o8y8H7h/qJ1FxO7MXDJ60xuedo3r2O0Ze5iG7O3x3tedOnYnPud64+Ww0hPUEu2ciHgzcA2w\ntc1zkkaDva0JaVysHDLz9Yi4GXgUmAJszMxn2zwtacTsbU1U4yIcADJzG7BtlHY35BJ9jLRrXMce\nx0axtzv139nXcxtEZvG5rySpw42XzxwkSePIpAqHdv1MQUTMjYhvR8T+iHg2Ij7aqrGr8adExJMR\n8a8tHvfMiHg4In5QPfffb+HYt1T/1s9ExOaI+K1Wjd0O9ra93WqTJhza/DMFrwOfyMwFwFLgphb/\nRMJHgf0tHG/A3cD2zHwn8K5WzSEi5gB/CSzJzIXUPui9phVjt4O9bW+3YuyTTZpwoI0/U5CZfZn5\n/er6y9QaqSXfgo2IbuBK4IutGK9u3GnAMuBLAJn5i8z8aQunMBV4S0RMBd5Kg+/FTCL2dgvZ2zWT\nKRzGxc8URMQ84N3A4y0a8i7gr4BftWi8AecCx4F/rJb9X4yIt7Vi4Mx8AbgDeB7oA17KzG+1Yuw2\nsbdby95mcoXDsH6CY0wnEPF24GvAxzLzZy0Y74+BY5m5Z6zHamAqsBi4LzPfDbwKtORYeERMp/bO\n+Rzgt4G3RcSftWLsNrG3W8veZnKFw7B+gmOsRMSbqL14vpqZX2/RsBcDKyPiELVDDX8YEf/UorF7\ngd7MHHgX+TC1F1Qr/BHw48w8npmvAV8H/qBFY7eDvW1vt9xkCoe2/UxBRAS145P7M/PzrRgTIDM/\nmZndmTmP2vP9t8xsybuMzPwf4HBE/F5VWg7sa8XY1JbcSyPirdW//XLa86Flq9jb9nbLjZtvSI9U\nm3+m4GLgI8DTEbG3qn2q+mbsZPYXwFer/2A9B1zfikEz8/GIeBj4PrWzaZ5kHHyjdKzY223R8b3t\nN6QlSYXJdFhJkjRKDAdJUsFwkCQVDAdJUsFwkCQVDAdJUsFwkCQVDAdJUuH/ANBIdjt1h6P8AAAA\nAElFTkSuQmCC\n", 250 | "text/plain": [ 251 | "" 252 | ] 253 | }, 254 | "metadata": {}, 255 | "output_type": "display_data" 256 | } 257 | ], 258 | "source": [ 259 | "np.random.seed(1111)\n", 260 | "logits = (np.random.random(10) - 0.5) * 2 # (-1, 1)\n", 261 | " \n", 262 | "pop = 100000\n", 263 | "softmax_samples = sample_with_softmax(logits, pop)\n", 264 | "gumbel_samples = sample_with_gumbel_noise(logits, pop)\n", 265 | " \n", 266 | "plt.subplot(1, 2, 1)\n", 267 | "plt.hist(softmax_samples)\n", 268 | " \n", 269 | "plt.subplot(1, 2, 2)\n", 270 | "plt.hist(gumbel_samples)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "metadata": {}, 276 | "source": [ 277 | "由于 Gumbel 
随机数可以预先计算好,采样过程也不需要计算 softmax,因此,某些情况下,gumbel-max 方法相比于 softmax,在采样速度上会有优势。" 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "metadata": {}, 283 | "source": [ 284 | "# 3. Gumbel-Softmax\n", 285 | "\n", 286 | "如果仅仅是提供一种常规 softmax 采样的替代方案, gumbel 分布似乎应用价值并不大。幸运的是,我们可以利用 gumbel 实现多项分布采样的 reparameterization(再参数化)。\n", 287 | "\n", 288 | "在介绍 [VAE](http://blog.csdn.net/jackytintin/article/details/53641885) 的时候讨论过,为了实现端到端的训练,VAE 采用的一个再参数化的技巧对高斯分布进行采样:" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": 8, 294 | "metadata": { 295 | "collapsed": false 296 | }, 297 | "outputs": [], 298 | "source": [ 299 | "def guassian_sample(shape):\n", 300 | " epsilon = K.random_normal(shape, mean=0.,std=1) # 标准高斯分布\n", 301 | " z = z_mean + exp(z_log_var / 2) * epsilon" 302 | ] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "metadata": {}, 307 | "source": [ 308 | "在介绍的 VAE 里,假设隐变量(latent variables)服从标准正态分布。下面将会看到,利用 gumbel-softmax 技巧,我们可以将隐变量建模为服从离散的多项分布。\n", 309 | "\n", 310 | "在上面的示例中,sample_with_softmax 使用了 choise 操作,而这个操作是不可导的。同样,观察 sample_with_gumbel_noise, armmax 操作同样不可导。然而,argmax 有一个 soft 版本,即 **softmax**。\n", 311 | "\n", 312 | "我们首先扩展上面定义的 softmax 函数,添加一个 temperature 参数。" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": 9, 318 | "metadata": { 319 | "collapsed": true 320 | }, 321 | "outputs": [], 322 | "source": [ 323 | "def generalized_softmax(logits, temperature=1):\n", 324 | " logits = logits / temperature\n", 325 | " return softmax(logits)" 326 | ] 327 | }, 328 | { 329 | "cell_type": "markdown", 330 | "metadata": {}, 331 | "source": [ 332 | "temperature 是在大于零的参数,它控制着 softmax 的 soft 程度。温度越高,生成的分布越平滑;温度越低,生成的分布越接近离散的 one-hot 分布。下面示例对比了不同温度下,softmax 的结果。" 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": 10, 338 | "metadata": { 339 | "collapsed": false 340 | }, 341 | "outputs": [ 342 | { 343 | "data": { 344 | "text/plain": [ 345 | "" 346 | ] 347 | }, 348 | "execution_count": 10, 349 | "metadata": {}, 350 | "output_type": "execute_result" 351 | }, 352 | { 353 | "data": { 354 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAGSJJREFUeJzt3X1wXfV95/H3BwvlgSeDbRrkK8cYgcdWwySbK9qdbLfZ\nEpAjJqIzxV7Rya4zBBxmTJuZhN01QytAyewqaWf5x5DFmdJlksGCkkntWfxQD4TuTjaxsUMCRRRs\nbFpLJoOxzUMWGmH5u3/cI/v6+sq6kq6u7v3l85rR+J5zfufcr/STP/foPPyOIgIzM0vLObNdgJmZ\nVZ/D3cwsQQ53M7MEOdzNzBLkcDczS5DD3cwsQQ53M7MEOdzNzBLkcDczS1DTbL3x/PnzY/HixbP1\n9lZkz549b0bEgmpsy/1aP9yvaaq0X2ct3BcvXszu3btn6+2tiKR/qta23K/1w/2apkr71YdlzMwS\n5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwSNGt3qE7V4nVPjrvstf4b\naliJVZP7dfK2bdvGV7/6VUZHR7n11ltZt25daZPzJf0MuBroiYgnxhZIWg38WTb5zYh4ZCZqdL/O\nnoYLdzOD0dFR1q5dy44dO8jlcnR0dNDd3c3y5cuLm40AXwLuLJ4p6RLgHiAPBLBH0uaIOFaj8pNQ\nrQ+umfoA9GEZswa0a9cu2traWLJkCc3NzfT09LBp06bSZiMR8TxwomR+J7AjIo5mgb4DWFGDsq2G\nvOdu1oCGh4dpbW09OZ3L5di5c2elqy8EDhZND2XzrMpm87CUw92sAUXEGfMkVbp6uYZnbFDSGmAN\nwKJFiyZR3eT4uPzMcLibNaBcLsfBg6d2voeGhmhpaal09SHgs8WbA54pbRQRG4ANAPl8/sxPkxo6\n2wcAVP8YdwofOA53swbU0dHB3r17OXDgAAsXLmRgYIBHH3200tW3A/9V0sXZ9PXAXZOtIYUATJnD\n3awBNTU1sX79ejo7OxkdHeWWW26hvb2d3t5e8vk83d3dAB+VNARcDHxB0n0R0R4RRyV9A3g221xf\nRBydre+lWqq5d58Ch7tZg+rq6qKrq+u0eX19fcWT70XEaddGjomIh4GHZ646m22+FNLMLEEOdzOz\nBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93M\nLEEOdzOzBDnczcwSVFG4S1oh6WVJ+yStK7N8kaQfSXpO0vOSusptx2bHtm3bWLp0KW1tbfT395dr\n8luSBrO+e0rSx8cWSFotaW/2tbp2VZvZdEwY7pLmAA8AnweWAzdLKn0AwJ8Bj0fEp4Ae4MFqF2pT\nMzo6ytq1a9m6dSuDg4Ns3LiRwcHB0mbvAfmIuBp4Avg2gKRLgHuA3wGuAe4pejSbmdWxSvbcrwH2\nRcT+iBgBBoAbS9oEcGH2+iLgUPVKtOnYtWsXbW1tLFmyhObmZnp6eti0aVNps3cj4r3s9U8pPDAZ\noBPYERFHI+IYsANYUZvKzWw6Kgn3hcDBoumhbF6xe4EvZs9r3AL8SbkNSVojabek3YcPH55CuTZZ\nw8PDtLa2npzO5XIMDw+fbZUvA1uz15X0vZnVoUrCXWXmRcn0zcD/jIgc0AV8T9IZ246IDRGRj4j8\nggULJl+tTVpEaVeBVK5LQdIXgTzwF2Ozym2yzHr+0DarM5WE+xDQWjSd48zDLl8GHgeIiJ8AHwbm\nV6NAm55cLsfBg6d2voeGhmhpaTmjnaTPAXcD3RHx67HmTNz3/tA2q0OVhPuzwJWSLpfUTOGE6eaS\nNv8MXAsgaRmFcPcuXB3o6Ohg7969HDhwgJGREQYGBuju7i5t9hHgIQrB/kbR/O3A9ZIuzk6kXp/N\nM7M6N2G4R8Rx4A4K/6lfonBVzIuS+iSNpcTXgdsk/QLYCHwpyh0PsJprampi/fr1dHZ2smzZMlat\nWkV7ezu9vb1s3nzyM7oVOB/4G0k/l7QZICKOAt+g8AH/LNCXzTOzOtdUSaOI2ELhRGnxvN6i14PA\nZ6pbmlVLV1cXXV2n33rQ19dXPPlKROTLrRsRDwMPz1x1ZjYTfIeqmVmCHO5mZglyuJuZJcjhbmaW\nIIe7mVmCHO5mZglyuJuZJcjhbmaWIIe7WYOq4CEskvRY9pCdnZIWZzPPlfSIpBckvSTprlrWbbXh\ncDdrQBU+hGU+cCwi2oD7gW9l81cCH4qITwCfBr4yFvyWDoe7WQOq8CEsc4FHstdPANeqMN5zAOdJ\naqIwaNwI8E6tarfacLibNaAKH8LSTPawlWwAwLeBeRSC/v8Br1MY0fUvyw0I53H6G5vD3awBTeYh\nLKWrUnh05ijQAlwOfF3SkjLv4XH6G5jD3awBVfgQlhGyh61kh2AuAo4Cfwxsi4gPsvH7f0zhCVyW\nEIe7WQOq8CEsbwGrs9c3AU9nz1n4Z+APVHAe8LvAP9aseKuJisZzN7P6UvwQltHRUW655ZaTD2HJ\n5/NjQf8mME/SPgp77D3Z6g8Afw38A4Xn5P51RDw/G9+HzRyHu1mDquAhLBERK0vXi4hfUbgc0hLm\ncDdrEIvXPTnustf6b6hhJdYIfMzdzCxBDnczswQ53M3MEuRwNzNLkMPdzCxBDnczswQ53M3MEuRw\nNzNLkMPdzCxBDnczswQ53M3MEuRwNzNLkMPdzCxBDnczswR5yF+zGeahem02eM/dzCxBDnczswQ5\n3M3MElRRuEtaIellSfskrRunzSpJg5JelPRodcu06di2bRtLly6lra2N/v7+ck3Ol/QzSccl3VS8\nQNKopJ9nX5trU7GZTdeEJ1QlzaHwtPTrgCHgWUmbI2KwqM2VwF3AZyLimKRLZ6pgm5zR0VHWrl3L\njh07yOVydHR00N3dzfLly4ubjQBfAu4ss4n3I+KTtajVzKqnkj33a4B9EbE/IkaAAeDGkja3AQ9E\nxDGAiHijumXaVO3atYu2tjaWLFlCc3MzPT09bNq0qbTZSEQ8D5yYhRLNbAZUEu4LgYNF00PZvGJX\nAVdJ+rGkn0paUW5DktZI2i1p9+HDh6dWsU3K8PAwra2tJ6dzuRzDw8OT2cSHsz77qaQ/rHqBZjYj\nKrnOXWXmRZntXAl8FsgB/0fSb0fEW6etFLEB2ACQz+dLt2EzIOLMH7NUrkvHtSgiDklaAjwt6YWI\neLVke2uANQCLFi2aRrVmVi2V7LkPAa1F0zngUJk2myLig4g4ALxMIextluVyOQ4ePPWH19DQEC0t\nLRWvHxGHsn/3A88AnyrTZkNE5CMiv2DBgmnXbGbTV0m4PwtcKelySc1AD1B61cTfAv8OQNJ8Codp\n9lezUJuajo4O9u7dy4EDBxgZGWFgYIDu7u6K1pV0saQPZa/nA58BBs++lpnVgwnDPSKOA3cA24GX\ngMcj4kVJfZLGUmI7cETSIPAj4D9FxJGZKtoq19TUxPr1
6+ns7GTZsmWsWrWK9vZ2ent72bz55Gf0\nRyUNASuBhyS9mM1fBuyW9AsK/dpffJWUmdWvisaWiYgtwJaSeb1FrwP4WvZldaarq4uurq7T5vX1\n9RVPvhcRp10bCRAR/xf4xMxWZ2YzwXeompklyOFuZpYgh7uZWYIc7mYN6v39exj+7lcYfui28cYM\nkqTHsjGhdkpaXLTgakk/ycaCekHSh2tVt9WGw92sAcWJUY7u+A6XrryPllsfZOPGjQwOnnEh03zg\nWES0AfcD3wKQ1AR8H7g9Itop3Hz4Qe2qt1pwuJs1oJHXX6Fp7mWcO/djaM65440ZNBd4JHv9BHCt\nCrcnXw88HxG/AIiIIxExWrPirSYc7mYN6Pi7R2i68NTdwOOMGdRMNi5Udr/K28A8CjcZhqTt2VDP\n/7k2VVst+RmqZomocMygoPD//t8AHcB7wFOS9kTEUyXb85hBDcx77mYNqOmCeRx/59TIquOMGTRC\nNi5Udpz9IuAohbGg/j4i3oyI9yjcoPivSlf2mEGNzeFu1oCaL7uK48cO8cFbvyRGPxhvzKC3gNXZ\n65uAp7O7ybcDV0v6aBb6v4/HDEqOD8uYNSCdM4dLrrudNx7vhTjBmjvvODlmUD6fHwv6N4F5kvZR\n2GPvAcielvbfKQwKGMCWiHhytr4XmxkOd7MG9ZErOlh4RQcAd999A3DGmEERESvLrRsR36dwOaQl\nyodlzMwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93M\nLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3\nM0uQw93MLEEVhbukFZJelrRP0rqztLtJUkjKV69EMzObrAnDXdIc4AHg88By4GZJy8u0uwD4U2Bn\ntYu06dm2bRtLly6lra2N/v7+ck3Ol/QzSccl3VS8QNJqSXuzr9W1qdjMpquSPfdrgH0RsT8iRoAB\n4MYy7b4BfBv4lyrWZ9M0OjrK2rVr2bp1K4ODg2zcuJHBwcHSZiPAl4BHi2dKugS4B/gdCr8H90i6\nuAZlm9k0VRLuC4GDRdND2byTJH0KaI2I/1XF2qwKdu3aRVtbG0uWLKG5uZmenh42bdpU2mwkIp4H\nTpTM7wR2RMTRiDgG7ABW1KBsM5umSsJdZebFyYXSOcD9wNcn3JC0RtJuSbsPHz5ceZU2ZcPDw7S2\ntp6czuVyDA8PV7r6hB/s4H41q0eVhPsQ0Fo0nQMOFU1fAPw28Iyk14DfBTaXO6kaERsiIh8R+QUL\nFky9aqtYRJwxTyr3eV3WWT/Yi97D/WpWZ5oqaPMscKWky4FhoAf447GFEfE2MH9sWtIzwJ0Rsbu6\npdpU5HI5Dh48tfM9NDRES0tLpasPAZ8t3hzwTLVqs1MWr3ty3GWv9d9Qw0osFRPuuUfEceAOYDvw\nEvB4RLwoqU9S90wXaNPT0dHB3r17OXDgACMjIwwMDNDdXXG3bQeul3RxdiL1+myemdW5SvbciYgt\nwJaSeb3jtP3s9MuyamlqamL9+vV0dnYyOjrKLbfcQnt7O729veTz+bGg/6ikIeBi4AuS7ouI9og4\nKukbFP56A+iLiKOz9b2YWeUqCndrbF1dXXR1dZ02r6+vr3jyvYg4494FgIh4GHh45qqzqXp//x6O\nPrUBTpygf+6fsm7dGfcXStJjwKeBI8C/j4jXihYuAgaBeyPiL2tVt9WGhx8wa0BxYpSjO77DpSvv\no+XWB8e7f2E+cCwi2ihc0fatkuX3A1trUK7NAoe7WQMaef0VmuZexrlzP4bmnDve/QtzgUey108A\n1yq7VErSHwL7gRdrVrTVlMPdrAEdf/cITReeuux0nPsXmsnuU8gujHgbmCfpPOC/APfVplqbDQ53\ns0RUeP9CUAj1+yPiVxNszzenNTCfUDVrQE0XzOP4O6cCd5z7F0Yo3IA4JKkJuAg4SmGsoJskfZvC\noZsTkv4lItYXrxwRG4ANAPl8/sy74ayuOdzNGlDzZVdx/NghPnjrlzRdMI+BgQEeffTR0mZvAauB\nnwA3AU9H4Zbl3xtrIOle4FelwW6Nz+Fu1oB0zhwuue523ni8F+IEa+68o9z9C29SOMa+j8Iee8+s\nFm015XA3a1AfuaKDhVd0AHD33YUhCkruX4iIWHm2bUTEvTNUns0yn1A1M0uQw93MLEEOdzOzBDnc\nzcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3M0uQw93MLEEOdzOzBHlsGWsYi9c9Oe6y1/pvqGEl\nZvXPe+5mZglyuJuZJcjhbmaWIB9zt6T4uLxZgffczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3\nM0uQw93MLEG+zt1sGnxdvdUrh7tZGWcLbXBwW/3zYRkzswQ53M3MElRRuEtaIellSfskrSuz/GuS\nBiU9L+kpSR+vfqk2Vdu2bWPp0qW0tbXR399frokkPZb1705Ji7OZiyW9L+nn2df/qGXdZjZ1E4a7\npDnAA8DngeXAzZKWlzR7DshHxNXAE8C3q12oTc3o6Chr165l69atDA4OsnHjRgYHB0ubzQeORUQb\ncD/wraJlr0bEJ7Ov22tVt5lNTyV77tcA+yJif0SMAAPAjcUNIuJHEfFeNvlTIFfdMm2qdu3aRVtb\nG0uWLKG5uZmenh42bdpU2mwu8Ej2+gngWkmqaaFmVlWVhPtC4GDR9FA2bzxfBrZOpyirnuHhYVpb\nW09O53I5hoeHS5s1k/VxRBwH3gbmZcsul/ScpL+X9Hs1KNnMqqCSSyHL7cFF2YbSF4E88PvjLF8D\nrAFYtGhRhSXadESc2VUV7pQH8DqwKCKOSPo08LeS2iPinZLtuV/N6kwle+5DQGvRdA44VNpI0ueA\nu4HuiPh1uQ1FxIaIyEdEfsGCBVOp1yYpl8tx8OCpP7yGhoZoaWkpbTZC1seSmoCLgKMR8euIOAIQ\nEXuAV4GrSld2v86O9/fvYfi7X2H4odsme6L8Okl7JL2Q/fsHtazbaqOScH8WuFLS5ZKagR5gc3ED\nSZ8CHqIQ7G9Uv0ybqo6ODvbu3cuBAwcYGRlhYGCA7u7u0mZvAauz1zcBT0dESFqQnVBH0hLgSmB/\nzYq3ccWJUY7u+A6XrryPllsfnOyJ8jeBL0TEJyj0+/dqVrjVzIThnh2DvQPYDrwEPB4RL0rqkzSW\nEn8BnA/8TXbJ3OZxNmc11tTUxPr16+ns7GTZsmWsWrWK9vZ2ent72bz5ZDe9CcyTtA/4GjB2ueu/\nBZ6X9AsKJ1pvj4ijNf8m7Awjr79C09zLOHfux9Cccyd1ojwinouIsb++XwQ+LOlDNSrdaqSi4Qci\nYguwpWReb9Hrz1W5Lquirq4uurq6TpvX19dXPBkRsbJ0vYj4AfCDma3OpuL4u0douvDUIbBcLsfO\nnTtLm512olz
S2InyN4va/BHwXLlDqT6X0tg8tozNOA+uVRuTOFE+1r6dwqGa68s2jNgAbADI5/Nl\nL6Kw+uXhB8waUNMF8zj+zuGT05M5UZ5N54AfAv8xIl6tRc1WWw53swbUfNlVHD92iA/e+iUx+sFk\nT5TPBZ4E7oqIH9ewbKshH5Yxa0A6Zw6XXHc7bzzeC3GCNXfecfJEeT6fHwv64hPlRylc6QaFCyTa\ngD+X9OfZvOt9pVtaHO5mDeojV3Sw8IoOAO6+u3DuosIT5d8EvlmDEm0W+bCMmVmCHO5mZgnyYRn7\njeNH6NlvAu+5m5klyOFuZpYgh7uZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYIc7mZmCXK4m5klyOFu\nZpYgh7uZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYL8DNUEne0Z\noX4+qNlvBu+5m5klyOFuZpYgh7uZWYIc7mZmCXK4m5klyOFuZpYgh7uZWYJ8nXudONu16eDr081s\ncirac5e0QtLLkvZJWldm+YckPZYt3ylpcbULtal7f/8ehr/7FYYfuo3+/v5yTTRe/0m6K5v/sqTO\nWtVsE3O/2tlMGO6S5gAPAJ8HlgM3S1pe0uzLwLGIaAPuB75V7UJtauLEKEd3fIdLV95Hy60PsnHj\nRgYHB0ubzadM/2X93AO0AyuAB7PfB5tl7lebSCV77tcA+yJif0SMAAPAjSVtbgQeyV4/AVwrSdUr\n06Zq5PVXaJp7GefO/Riacy49PT1s2rSptNlcyvffjcBARPw6Ig4A+yj8Ptgsc7/aRCoJ94XAwaLp\noWxe2TYRcRx4G5hXjQJteo6/e4SmCxecnM7lcgwPD5c2a6Z8/1XS9zYL3K82EUXE2RtIK4HOiLg1\nm/4PwDUR8SdFbV7M2gxl069mbY6UbGsNsCabXAq8fJa3ng+8Oblvpy7UW90XAxcC/5RNXwKcx+n/\nuT8JLC7tP6AP+ElEfD+b/1fAloj4QfEbuF9nRb31K9Tfz6gSjVjzxyNiwUSNKrlaZghoLZrOAYfG\naTMkqQm4CDhauqGI2ABsqOA9kbQ7IvKVtK0n9Va3pH8N3BsRndn0XQAR8d+K2mynfP9V0vfu11lQ\nb/2avV9d/Ywq0Yg1V6qSwzLPAldKulxSM4UTMZtL2mwGVmevbwKejon+JLBamU7/bQZ6squhLgeu\nBHbVqG47O/erndWEe+4RcVzSHcB2YA7wcES8KKkP2B0Rm4G/Ar4naR+FPYOemSzaKjed/svaPQ4M\nAseBtRExOivfiJ3G/WoTmfCY+2yRtCb7s7ChNGrdtdKoP59GrbuWGvFn1Ig1V6puw93MzKbOY8uY\nmSWoLsN9ouEO6pWk1yS9IOnnknbPdj31xv2aJvdrfaq7wzLZbdCvANdRuGTrWeDmiDjj3up6I+k1\nIB8RjXbd7Ixzv6bJ/Vq/6nHPvZLhDqzxuF/T5H6tU/UY7o18a3QAfydpT3Z3n53ifk2T+7VO1eN4\n7uUGHKuvY0fj+0xEHJJ0KbBD0j9GxP+e7aLqhPs1Te7XOlWPe+4V3RpdjyLiUPbvG8AP8Uh7xdyv\naXK/1ql6DPdKbquuO5LOk3TB2GvgeuAfZrequuJ+TZP7tU7V3WGZ8W6rnuWyKvFbwA+zYeybgEcj\nYtvsllQ/3K9pcr/Wr7q7FNLMzKavHg/LmJnZNDnczcwS5HA3M0uQw93MLEEOdzOzBDnczcwS5HA3\nM0uQw93MLEH/Hy9HeZRuzKXVAAAAAElFTkSuQmCC\n", 355 | "text/plain": [ 356 | "" 357 | ] 358 | }, 359 | "metadata": {}, 360 | "output_type": "display_data" 361 | } 362 | ], 363 | "source": [ 364 | "np.random.seed(1111)\n", 365 | "n = 10\n", 366 | "logits = (np.random.random(n) - 0.5) * 2 # (-1, 1)\n", 367 | "x = range(n)\n", 368 | "\n", 369 | "plt.subplot(1, 3, 1)\n", 370 | "t = .1\n", 371 | "plt.bar(x, generalized_softmax(logits, t))\n", 372 | "\n", 373 | "plt.subplot(1, 3, 2)\n", 374 | "t = 1\n", 375 | "plt.bar(x, generalized_softmax(logits, t))\n", 376 | "\n", 377 | "plt.subplot(1, 3, 3)\n", 378 | "t = 50\n", 379 | "plt.bar(x, generalized_softmax(logits, t))\n" 380 | ] 381 | }, 382 | { 383 | "cell_type": "markdown", 384 | "metadata": {}, 385 | "source": [ 386 | "将 gumbel-max 中的 argmax 操作,替换为 softmax,便实现了对于原来不可导的采样操作的软化版的近似。训练中,可以通过逐渐降低温度,以逐步逼近真实的离散分布。利用 gumbel-softmax,VAE 的隐变量可以用多项分布进行建模,一个示例见 [repo](https://github.com/DingKe/ml-tutorial/tree/master/gumbel)。这里,仅展示一个 toy 示例(代码[来自](http://amid.fish/humble-gumbel))。" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 11, 392 | "metadata": { 393 | "collapsed": false 394 | }, 395 | "outputs": [ 396 | { 397 | "name": "stdout", 398 | "output_type": "stream", 399 | "text": [ 400 | "Logits: \n", 401 | "0.02 0.37 1.00 0.37 0.02\n", 402 | "Distribution mean: 2.00\n", 403 | "Distribution mean: 2.13\n", 404 | "Distribution mean: 2.23\n", 405 | "Distribution mean: 2.60\n", 406 | "Distribution mean: 2.75\n", 407 | "Distribution mean: 3.23\n" 408 | ] 409 | } 410 | ], 411 | "source": [ 412 | "import tensorflow as tf\n", 413 | "sess = tf.Session()\n", 414 | "\n", 415 | "def differentiable_sample(logits, temperature=1):\n", 416 | " noise = tf.random_uniform(tf.shape(logits), 
seed=11)\n", 417 | " logits_with_noise = logits - tf.log(-tf.log(noise))\n", 418 | " return tf.nn.softmax(logits_with_noise / temperature)\n", 419 | "\n", 420 | "mean = tf.Variable(2.)\n", 421 | "idxs = tf.Variable([0., 1., 2., 3., 4.])\n", 422 | "# An unnormalised approximately-normal distribution\n", 423 | "logits = tf.exp(-(idxs - mean) ** 2)\n", 424 | "sess.run(tf.global_variables_initializer())\n", 425 | "\n", 426 | "def print_logit_vals():\n", 427 | " logit_vals = sess.run(logits)\n", 428 | " print(\" \".join([\"{:.2f}\"] * len(logit_vals)).format(*logit_vals))\n", 429 | " \n", 430 | "print(\"Logits: \")\n", 431 | "print_logit_vals()\n", 432 | "\n", 433 | "sample = differentiable_sample(logits)\n", 434 | "sample_weights = tf.Variable([1., 2., 3., 4., 5.], trainable=False)\n", 435 | "result = tf.reduce_sum(sample * sample_weights)\n", 436 | "\n", 437 | "sess.run(tf.global_variables_initializer())\n", 438 | "train_op = tf.train.GradientDescentOptimizer(learning_rate=1).minimize(-result)\n", 439 | "\n", 440 | "print(\"Distribution mean: {:.2f}\".format(sess.run(mean)))\n", 441 | "for i in range(5):\n", 442 | " sess.run(train_op)\n", 443 | " print(\"Distribution mean: {:.2f}\".format(sess.run(mean)))" 444 | ] 445 | }, 446 | { 447 | "cell_type": "markdown", 448 | "metadata": {}, 449 | "source": [ 450 | "可以看到,利用 gumbel-softmax 训练参数向着预期的方向改变。" 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "execution_count": 12, 456 | "metadata": { 457 | "collapsed": false 458 | }, 459 | "outputs": [ 460 | { 461 | "name": "stdout", 462 | "output_type": "stream", 463 | "text": [ 464 | "Logits: \n", 465 | "0.02 0.37 1.00 0.37 0.02\n", 466 | "Distribution mean: 2.00\n", 467 | "Distribution mean: 2.32\n", 468 | "Distribution mean: 2.65\n", 469 | "Distribution mean: 2.87\n", 470 | "Distribution mean: 3.10\n", 471 | "Distribution mean: 3.36\n" 472 | ] 473 | } 474 | ], 475 | "source": [ 476 | "import tensorflow as tf\n", 477 | "sess = tf.Session()\n", 478 | "\n", 479 | "mean = tf.Variable(2.)\n", 480 | "idxs = tf.Variable([0., 1., 2., 3., 4.])\n", 481 | "# An unnormalised approximately-normal distribution\n", 482 | "logits = tf.exp(-(idxs - mean) ** 2)\n", 483 | "sess.run(tf.global_variables_initializer())\n", 484 | "\n", 485 | "def print_logit_vals():\n", 486 | " logit_vals = sess.run(logits)\n", 487 | " print(\" \".join([\"{:.2f}\"] * len(logit_vals)).format(*logit_vals))\n", 488 | " \n", 489 | "print(\"Logits: \")\n", 490 | "print_logit_vals()\n", 491 | "\n", 492 | "sample = tf.nn.softmax(logits)\n", 493 | "sample_weights = tf.Variable([1., 2., 3., 4., 5.], trainable=False)\n", 494 | "result = tf.reduce_sum(sample * sample_weights)\n", 495 | "\n", 496 | "sess.run(tf.global_variables_initializer())\n", 497 | "train_op = tf.train.GradientDescentOptimizer(learning_rate=1).minimize(-result)\n", 498 | "\n", 499 | "print(\"Distribution mean: {:.2f}\".format(sess.run(mean)))\n", 500 | "for i in range(5):\n", 501 | " sess.run(train_op)\n", 502 | " print(\"Distribution mean: {:.2f}\".format(sess.run(mean)))" 503 | ] 504 | }, 505 | { 506 | "cell_type": "markdown", 507 | "metadata": {}, 508 | "source": [ 509 | "# 讨论\n", 510 | "乍看起来,gumbel-softmax 的用处令人费解。比如上面的代码示例,直接使用 softmax,也可以达到类似的参数训练效果。但两者有着根本的区别。\n", 511 | "原理上,常规的 softmax 直接建模了一个概率分布(多项分布),基于交叉熵的训练准则使分布尽可能靠近目标分布;而 gumbel-softmax 则是对多项分布采样的一个近似。使用上,常规的有监督学习任务(分类器训练)中,直接学习输出的概率分布是自然的选择;而对于涉及采样的学习任务(VAE 隐变量采样、强化学习中对actions 集合进行采样以确定下一步的操作),gumbel-softmax 提供了一种再参数化的方法,使得模型可以以端到端的方式进行训练。" 512 | ] 513 | }, 514 | { 515 | "cell_type": "markdown", 516 | 
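To make the contrast drawn in the discussion above concrete, here is a small NumPy check (not part of the original notebook; the logits, temperature and sample count are illustrative): a single low-temperature Gumbel-softmax draw is close to a one-hot sample, while the average over many draws approaches the categorical probabilities that the plain softmax defines. In other words, softmax models the distribution, while Gumbel-softmax approximates the act of sampling from it.

import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

np.random.seed(0)
logits = np.array([1.0, 0.5, -0.5])
tau = 0.1                                           # low temperature -> nearly discrete draws

# one Gumbel-softmax draw per row: softmax((logits + Gumbel noise) / tau)
gumbel = -np.log(-np.log(np.random.random((100000, 3)) + 1e-20))
samples = softmax((logits + gumbel) / tau)

print(samples[0])             # a single draw is typically close to one-hot
print(samples.mean(axis=0))   # the empirical mean approaches ...
print(softmax(logits))        # ... the softmax probabilities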
"metadata": {}, 517 | "source": [ 518 | "# References\n", 519 | "1. http://amid.fish/humble-gumbel\n", 520 | "2. [proof of Gumbel based sampling](https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/)\n", 521 | "3. https://blog.evjang.com/2016/11/tutorial-categorical-variational.html\n", 522 | "4. Jang et al. [CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX](https://arxiv.org/abs/1611.01144).\n" 523 | ] 524 | } 525 | ], 526 | "metadata": { 527 | "anaconda-cloud": {}, 528 | "kernelspec": { 529 | "display_name": "Python [Root]", 530 | "language": "python", 531 | "name": "Python [Root]" 532 | }, 533 | "language_info": { 534 | "codemirror_mode": { 535 | "name": "ipython", 536 | "version": 2 537 | }, 538 | "file_extension": ".py", 539 | "mimetype": "text/x-python", 540 | "name": "python", 541 | "nbconvert_exporter": "python", 542 | "pygments_lexer": "ipython2", 543 | "version": "2.7.12" 544 | } 545 | }, 546 | "nbformat": 4, 547 | "nbformat_minor": 0 548 | } 549 | -------------------------------------------------------------------------------- /gumbel/gumbel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from keras import backend as K 4 | 5 | def random_gumbel(shape, eps=1e-20): 6 | U = K.random_uniform(shape) 7 | return -K.log(-K.log(U + eps) + eps) 8 | 9 | 10 | def gumbel_softmax_sample(logits, temperature): 11 | """ Draw a sample from the Gumbel-Softmax distribution""" 12 | y = logits + random_gumbel(K.shape(logits)) 13 | return K.softmax(y / temperature) 14 | 15 | 16 | def gumbel_softmax(logits, temperature, hard=False): 17 | y = gumbel_softmax_sample(logits, temperature) 18 | if hard: 19 | k = K.shape(logits)[-1] 20 | y_hard = K.cast(K.one_hot(K.argmax(y, 1), k), K.floatx) 21 | y = K.stop_gradient(y_hard - y) + y 22 | return y 23 | -------------------------------------------------------------------------------- /gumbel/img/i_xent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DingKe/ml-tutorial/14b9fdf19119fad1bf97b9eb442c4473beb04be8/gumbel/img/i_xent.png -------------------------------------------------------------------------------- /gumbel/img/x_xent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DingKe/ml-tutorial/14b9fdf19119fad1bf97b9eb442c4473beb04be8/gumbel/img/x_xent.png -------------------------------------------------------------------------------- /gumbel/variational_autoencoder_gumbel.py: -------------------------------------------------------------------------------- 1 | '''This script demonstrates how to build a variational autoencoder with Keras. 
2 | 3 | Reference: 4 | "Auto-Encoding Variational Bayes" https://arxiv.org/abs/1312.6114 5 | "CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX" https://arxiv.org/abs/1611.01144 6 | ''' 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | from scipy.stats import norm 10 | 11 | from keras.layers import Input, Dense, Lambda, Reshape, Softmax, Flatten 12 | from keras.models import Model 13 | from keras.regularizers import l2 14 | from keras import backend as K 15 | from keras import objectives 16 | from keras.datasets import mnist 17 | 18 | from gumbel import gumbel_softmax 19 | 20 | np.random.seed(1111) # for reproducibility 21 | 22 | batch_size = 100 23 | nb_classes = 10 24 | n = 784 25 | m = 1 26 | hidden_dim = 256 27 | epochs = 50 28 | epsilon_std = 1.0 29 | use_loss = 'xent' # 'mse' or 'xent' 30 | 31 | decay = 1e-4 # weight decay, a.k. l2 regularization 32 | use_bias = True 33 | 34 | ## Encoder 35 | def build_encoder(temperature, hard): 36 | x = Input(batch_shape=(batch_size, n)) 37 | h_encoded = Dense(hidden_dim, kernel_regularizer=l2(decay), bias_regularizer=l2(decay), use_bias=use_bias, activation='tanh')(x) 38 | z = Dense(m * nb_classes, kernel_regularizer=l2(decay), bias_regularizer=l2(decay), use_bias=use_bias)(h_encoded) 39 | 40 | logits_z = Reshape((m, nb_classes))(z) # batch x m * nb_classes -> batch x m x nb_classes 41 | q_z = Softmax()(logits_z) 42 | log_q_z = Lambda(lambda x: K.log(x + K.epsilon()))(q_z) 43 | 44 | z = Lambda(lambda x: gumbel_softmax(x, temperature, hard))(logits_z) 45 | 46 | z = Flatten()(z) 47 | q_z = Flatten()(q_z) 48 | log_q_z = Flatten()(log_q_z) 49 | 50 | return x, z, q_z, log_q_z 51 | 52 | 53 | def build_decoder(z): 54 | # we instantiate these layers separately so as to reuse them later 55 | decoder_h = Dense(hidden_dim, kernel_regularizer=l2(decay), bias_regularizer=l2(decay), use_bias=use_bias, activation='tanh') 56 | decoder_mean = Dense(n, kernel_regularizer=l2(decay), bias_regularizer=l2(decay), use_bias=use_bias, activation='sigmoid') 57 | 58 | h_decoded = decoder_h(z) 59 | x_hat = decoder_mean(h_decoded) 60 | 61 | return x_hat, decoder_h, decoder_mean 62 | 63 | 64 | #temperature = K.variable(np.asarray([1])) 65 | temperature = 1 66 | hard = False 67 | x, z, q_z, log_q_z = build_encoder(temperature, hard) 68 | x_hat, decoder_h, decoder_mean = build_decoder(z) 69 | 70 | 71 | ## loss 72 | def vae_loss(x, x_hat): 73 | kl_loss = 0.01 + K.mean(q_z * (log_q_z - K.log(1.0 / nb_classes))) 74 | xent_loss = n * objectives.binary_crossentropy(x, x_hat) 75 | mse_loss = n * objectives.mse(x, x_hat) 76 | if use_loss == 'xent': 77 | return xent_loss - kl_loss 78 | elif use_loss == 'mse': 79 | return mse_loss - kl_loss 80 | else: 81 | raise Expception, 'Nonknow loss!' 82 | 83 | vae = Model(x, x_hat) 84 | vae.compile(optimizer='rmsprop', loss=vae_loss) 85 | 86 | # train the VAE on MNIST digits 87 | (x_train, y_train), (x_test, y_test) = mnist.load_data() 88 | 89 | x_train = x_train.astype('float32') / 255. 90 | x_test = x_test.astype('float32') / 255. 
91 | x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:]))) 92 | x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:]))) 93 | 94 | vae.fit(x_train, x_train, 95 | shuffle=True, 96 | epochs=epochs, 97 | batch_size=batch_size, 98 | validation_data=(x_test, x_test)) 99 | 100 | 101 | # build a digit generator that can sample from the learned distribution 102 | decoder_input = Input(shape=(m * nb_classes,)) 103 | _h_decoded = decoder_h(decoder_input) 104 | _x_hat = decoder_mean(_h_decoded) 105 | generator = Model(decoder_input, _x_hat) 106 | 107 | n = nb_classes 108 | digit_size = 28 109 | figure = np.zeros((digit_size, digit_size * n)) 110 | for i in range(nb_classes): 111 | z_sample = np.zeros([1, nb_classes]) 112 | z_sample[0, i] = 1 113 | x_decoded = generator.predict(z_sample) 114 | digit = x_decoded[0].reshape(digit_size, digit_size) 115 | figure[:, i * digit_size: (i + 1) * digit_size] = digit 116 | 117 | fig = plt.figure(figsize=(10, 10)) 118 | plt.imshow(figure, cmap='Greys_r') 119 | plt.show() 120 | fig.savefig('x_{}.png'.format(use_loss)) 121 | 122 | 123 | # data imputation 124 | n = 15 # figure with 15x15 digits 125 | figure = np.zeros((digit_size * 3, digit_size * n)) 126 | x = x_test[:batch_size,:] 127 | x_corupted = np.copy(x) 128 | x_corupted[:, 300:400] = 0 129 | x_encoded = vae.predict(x_corupted, batch_size=batch_size).reshape((-1, digit_size, digit_size)) 130 | x = x.reshape((-1, digit_size, digit_size)) 131 | x_corupted = x_corupted.reshape((-1, digit_size, digit_size)) 132 | for i in range(n): 133 | xi = x[i] 134 | xi_c = x_corupted[i] 135 | xi_e = x_encoded[i] 136 | figure[:digit_size, i * digit_size:(i+1)*digit_size] = xi 137 | figure[digit_size:2 * digit_size, i * digit_size:(i+1)*digit_size] = xi_c 138 | figure[2 * digit_size:, i * digit_size:(i+1)*digit_size] = xi_e 139 | 140 | fig = plt.figure(figsize=(10, 10)) 141 | plt.imshow(figure, cmap='Greys_r') 142 | plt.show() 143 | fig.savefig('i_{}.png'.format(use_loss)) 144 | -------------------------------------------------------------------------------- /plda/PLDA.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "PLDA 是一个概率生成模型,最初是为解决人脸识别和验证问题而被提出[3,5],之后也被广泛应用到声纹识别等模式识别任务中。不同学者从不同的动机出发,提出了不尽相同的 PLDA 算法,文献[2] 在统一的框架下比较了三种 PLDA 算法变种(standard[3,6], simplified[4], two-covariance[5,8]),并在说话从识别任务上比较了它们的性能差异。\n", 8 | "\n", 9 | "本文讨论的 PLDA 主要是基于文献 [5] 中提出的 PLDA(即 two-covariance PLDA),这也是 Kaldi 中采用的算法。\n", 10 | "\n", 11 | "本文 1 节简单介绍 LDA,只对 PLDA 算法感兴趣的读者可以跳过。" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "# 1. 
LDA\n", 19 | "\n", 20 | "\n", 21 | "## 1.1 基本 LDA\n", 22 | "\n", 23 | "线性判别分析(Linear Discriminant Analysis, LDA)[1] 是一种线性分类技术。LDA 假设数据服从高斯分布,并且各类的协方差相同。\n", 24 | "如果各类的先验概率为 $\\pi_k$($\\sum_k \\pi_k = 1$),则各类数据的概率分布为:\n", 25 | "$$\n", 26 | "p(x|k) \\sim N(\\mu_k, \\Sigma)\n", 27 | "$$\n", 28 | "\n", 29 | "观测到数据后,类别的后验为:\n", 30 | "$$\n", 31 | "p(k|x) = \\frac{p(x|k)\\pi_k}{p(x)}\n", 32 | "$$\n", 33 | "为了对数据进行分类,比较后验的似然比:\n", 34 | "$$\n", 35 | "\\ln \\frac{p(k|x)}{p(l|x)} = \\ln \\frac{p(x|k)}{p(x|l)} + \\ln \\frac{\\pi_k}{\\pi_l} = \\ln \\frac{\\pi_k}{\\pi_l} - \\frac{1}{2} (\\mu_k - \\mu_l)^T\\Sigma^{-1} (\\mu_k + \\mu_l) + x^T\\Sigma^{-1}(\\mu_k - \\mu_l)\n", 36 | "$$\n", 37 | "\n", 38 | "由于假设协方差相同,因此似然比是关于输入 $x$ 的线性函数。LDA 用一系列超平面划分数据空间,进而完成分类。图1 示意了LDA 构建的分类决策平面。\n", 39 | "\n", 40 | "![](http://www.ucl.ac.uk/~ucfbpve/papers/VermeeschGCubed2006/figures/discriminant.jpg)\n", 41 | "**图1. LDA 分类示意【[src](http://www.ucl.ac.uk/~ucfbpve/papers/VermeeschGCubed2006/figures/discriminant.jpg)】**\n", 42 | "\n", 43 | "> 如果允许各类的协方差不同,则决策面是二次的,称为二次判别分析(QDA(Quadratic Discriminant Analysis,QDA)。\n", 44 | "\n", 45 | "这种基本的 LDA 的参数估计非常简单直接:\n", 46 | "1. $\\hat \\pi_k = \\frac{n_k}{N}$,其中 $N = \\sum_{i=1}^K n_i$ \n", 47 | "2. $\\hat \\mu_k = \\frac{1}{n_k}\\sum_{i=1}^{n_k} x_{ki}$ \n", 48 | "3. $\\hat \\Sigma = \\frac{1}{N-K} \\sum_{k=1}^K \\sum_{i=1}^{n_k} (x_{ki} - \\hat \\mu_k)(x_{ki} - \\hat \\mu_k)^T$\n" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "## 1.2 降维 LDA\n", 56 | "\n", 57 | "基本的 LDA 可能并不吸引人。LDA 得到广泛应用的一个重要原因是,它可以用来对数据进行降维。\n", 58 | "\n", 59 | "![](http://uc-r.github.io/public/images/analytics/discriminant_analysis/LDA.jpg)\n", 60 | "**图2. LDA 投影【[src](http://uc-r.github.io/public/images/analytics/discriminant_analysis/LDA.jpg)】**\n", 61 | "\n", 62 | "以两分类为例(图2),如果将数据沿着决策线(超平面)的方向投影,投影(降维)后不影响数据的分类结果,因为,非投影方向的各类的均值和(协)方差相同,并不能为分类提供有效信息。\n", 63 | "\n", 64 | "一般地,对于 $K$ 个类别的分类问题,对于分类为效的信息集中在 $K-1$ 维的子空间上。即基于 LDA 的假设,数据可以从原来的 $d$ 维降为 $K-1$ 维(假设$K-1 <= d$)。投影方向的计算可以参见[1]第4章。\n", 65 | "\n", 66 | "## 1.3 类内方差与类间方差\n", 67 | "\n", 68 | "对于 LDA 降维还有另一种解释来自 Fisher:**对于原数据 X,寻找一个线性组合 $Z = w^T X$(低维投影),使得 $Z$ 的类间方差与类内方差的比值最大。**\n", 69 | "\n", 70 | "显然,这样的投影方式使得数据在低维空间最容易区分。为此,我们需要优化瑞利熵(Rayleigh Quotient)[1]。\n", 71 | "\n", 72 | "$$\n", 73 | "\\mathbf{w} = \\underset{\\mathbf{w}}{\\mathrm{argmax}}\\ \\frac{\\mathbf{w}^T S_b \\mathbf{w}}{\\mathbf{w}^T S_w \\mathbf{w}}\n", 74 | "$$\n", 75 | "等价的\n", 76 | "$$\n", 77 | "\\mathbf{w} = \\underset{\\mathbf{w}}{\\mathrm{argmax}}\\ \\mathbf{w}^T S_b \\mathbf{w}\n", 78 | "$$\n", 79 | "$$\n", 80 | "s.t. \\mathbf{w}^T S_w \\mathbf{w} = 1\n", 81 | "$$\n", 82 | "这是一个广义特征值问题。具体地,上式可以用拉格朗日乘数求解,即寻找下式的极值:\n", 83 | "$$\n", 84 | "\\mathbf{w}^T S_b \\mathbf{w} - \\lambda (\\mathbf{w}^T S_w \\mathbf{w} - 1)\n", 85 | "$$\n", 86 | "求导后得到:\n", 87 | "$$\n", 88 | "S_b\\mathbf{w} = \\lambda S_w\\mathbf{w}\n", 89 | "$$\n", 90 | "\n", 91 | "如果 $S_w$ 可逆,则有:\n", 92 | "$$\n", 93 | "S_w^{-1} S_b\\mathbf{w} = \\lambda \\mathbf{w}\n", 94 | "$$\n", 95 | "即我们只需求解矩阵 $S_w^{-1} S_b$ 的特征向量,保留特征值最大前 $k$(不一定是 $K - 1$)个特征向量(列向量)即可以得到 $W$($d$ 行,$k$ 列)。 \n", 96 | "\n", 97 | "类内-类间方差的视角更加通过,首先,它并不需要假设数据分布服从高斯且协方差相同;其次,降维的维度不再依赖类别数目。" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "# 2. 
PLDA\n", 105 | "\n", 106 | "同 LDA 一样,各种 PLDA 都将数据之间的差异分解为类内差异和类间差异,但是从概率的角度进行建模。这里,我们按照 [5] 的思路,介绍所谓 two-covariance PLDA。\n", 107 | "\n", 108 | "假设数据 $\\mathbf{x}$ 满足如下概率关系:\n", 109 | "\n", 110 | "$$\n", 111 | "p(\\mathbf{x}|\\mathbf{y}) = \\mathcal{N}(\\mathbf{x}|\\mathbf{y}, \\Phi_w)\n", 112 | "$$\n", 113 | "\n", 114 | "$$\n", 115 | "p(\\mathbf{y}) = \\mathcal{N}(\\mathbf{y}|\\mathbf{m}, \\Phi_b)\n", 116 | "$$\n", 117 | "\n", 118 | "> LDA 假设各类中心服从离散分布,离散中心的个数固定,是训练数据中已知的类别数;PLDA 假设各类中心服从一个连续分布(高斯分布)。因此,PLDA 能够扩展到未知类别,从而用于未知类别的识别与认证。\n", 119 | "\n", 120 | "这里要求协方差矩阵 $\\Phi_w$ 是**正定**的对称方阵,反映了类间(within-class)的差异; $\\Phi_b$ 是**半正定** 的对称方阵,反映了类间(between-class)的差异。因此,PLDA 建模了数据生成的过程,并且同时 LDA 一样,显式地考虑了类内和类间方差。\n", 121 | "\n", 122 | "为了推断时方便,下面推导 PLDA 的一种等价表示。\n", 123 | "根据线性代数的基础知识, $\\Phi_w$ 和 $\\Phi_b$ 可以同时**合同对角化**(simultaneous diagonalization by congruence),即存在可逆矩阵 $V$,使得 $V^T\\Phi_bV=\\Psi$($\\Psi$ 为对角阵)且 $V^T\\Phi_wV = I$($I$ 是单位矩阵)。对角化方法见第 4 节。\n", 124 | "\n", 125 | "基于上述说明,PLDA 的等价表述为:\n", 126 | "$$\n", 127 | "\\mathbf{x} = \\mathbf{m} + A\\mathbf{u}\n", 128 | "$$\n", 129 | "其中,\n", 130 | "$$\\mathbf{u} \\sim \\mathcal{N}(\\cdot|\\mathbf{v}, I)$$\n", 131 | "$$\\mathbf{v} \\sim \\mathcal{N}(\\cdot|0, \\Psi)$$\n", 132 | "\n", 133 | "$$\n", 134 | "A = V^{-1}\n", 135 | "$$\n", 136 | "\n", 137 | "$\\mathbf{u}$ 是数据空间在投影空间的对应投影点。$\\Psi$ 反映了类间(within-class)的差异; $I$ 反映了类间(between-class)的差异,这里被规一化。" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "# 3. 基于 PLDA 的推断\n", 145 | "对于每一个观测数据 $\\mathbf{x}$ 我们都可以计算对应的 $\\mathbf{u} = A^{-1}(\\mathbf{x} - \\mathbf{m})$。 PLDA 的推断都在投影空间中进行。\n", 146 | "\n", 147 | "给定观一组同类的测数据 $\\mathbf{u}_{1,\\dots,n}$,$\\mathbf{v}$ 的后验概率分布为(参见 4.2.1):\n", 148 | "$$\n", 149 | "p(\\mathbf{v}|\\mathbf{u}_{1,\\dots,n}) = \\mathcal{N}(\\mathbf{v}|\\frac{n\\Psi}{n\\Psi + I} \\mathbf{\\bar u}, \\frac{\\Psi}{n\\Psi + I})\n", 150 | "$$\n", 151 | "其中,$\\mathbf{\\bar u} = \\frac{1}{n}\\sum_{i=1}^n\\mathbf{u}_i$。\n", 152 | "\n", 153 | "因此,对于未知数据点 $\\mathbf{u}^p$ 以及某类的若干数据点 $\\mathbf{u}^g_{1,\\dots,n}$(i.i.d.),$\\mathbf{u}^p$ 属于该类的似然值:\n", 154 | "$$\n", 155 | "p(\\mathbf{u}^p|\\mathbf{u}^g_{1,\\dots,n}) = \\mathcal{N}(\\frac{n\\Psi}{n\\Psi + I} \\mathbf{\\bar u}^g, \\frac{\\Psi}{n\\Psi + I} + I)\n", 156 | "$$\n", 157 | "\n", 158 | "$$\n", 159 | "\\ln p(\\mathbf{u}^p|\\mathbf{u}^g_{1,\\dots,n}) = C - \\frac{1}{2} (\\mathbf{u}^p - \\frac{n\\Psi}{n\\Psi + I} \\mathbf{\\bar u}^g)^T (\\frac{\\Psi}{n\\Psi + I} + I)^{-1}(\\mathbf{u}^p - \\frac{n\\Psi}{n\\Psi + I} \\mathbf{\\bar u}^g) -\\frac{1}{2}\\ln |\\frac{\\Psi}{n\\Psi + I} + I|\n", 160 | "$$\n", 161 | "\n", 162 | "其中,$C = -\\frac{1}{2}d\\ln 2\\pi$ 是与数据无关的常量,$d$ 是数据的维度。特殊的,$\\mathbf{u}^p$ 不属于任何已知类的概率为:\n", 163 | "$$\n", 164 | "p(\\mathbf{u}^p|\\emptyset) = \\mathcal{N}(\\mathbf{u}^p|0, \\Psi + I)\n", 165 | "$$\n", 166 | "\n", 167 | "$$\n", 168 | "\\ln p(\\mathbf{u}^p|\\emptyset) = C - \\frac{1}{2} \\mathbf{u}^{pT} (\\Psi + I)^{-1} \\mathbf{u}^p -\\frac{1}{2}\\ln |\\Psi + I|\n", 169 | "$$\n", 170 | "\n", 171 | "由于 $\\Psi$ 是对角阵,因此上式中各个协方差也都是对角阵,因此,似然和对数似然都很容易求得。\n", 172 | "\n", 173 | "利用 PLDA 进行识别(recognition)方法如下::\n", 174 | "$$\n", 175 | "i = \\underset{i}{\\mathrm{argmax}}\\ \\ln p(\\mathbf{u}^p|\\mathbf{u}^{g_i}_{1,\\dots,n}) \n", 176 | "$$\n", 177 | "\n", 178 | "\n", 179 | "对于认证问题(verification),可以计算其似然比:\n", 180 | "$$\n", 181 | "R = \\frac{p(\\mathbf{u}^p|\\mathbf{u}^g_{1,\\dots,n})}{p(\\mathbf{u}^p|\\emptyset)}\n", 182 | "$$\n", 183 | "\n", 184 | "或似然比对数(log likelihood ratio):\n", 
185 | "\n", 186 | "$$\n", 187 | "\\ln R = \\ln p(\\mathbf{u}^p|\\mathbf{u}^g_{1,\\dots,n}) - \\ln p(\\mathbf{u}^p|\\emptyset)\n", 188 | "$$\n", 189 | "适当的选定阈值 $T$,当 $R > T$ 判定 $\\mathbf{u}$ 与已知数据属于同一个类别,反之则不是。\n", 190 | "\n", 191 | "\n", 192 | "> 这里介绍的 two-covariance PLDA 并学习一个低维空间投影[2],这一点不同 PLDA 及 standard PLDA 和 simplified PLDA。作为近似手段,可以在投影空间中,丢弃 $\\Psi$ 中对角元素较小的若干维度对应的值。" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": {}, 198 | "source": [ 199 | "# 4. PLDA 参数估计\n", 200 | "\n", 201 | "PLDA 中,我们需要估计的参数包括 $A$、$\\Psi$ 和 $\\mathbf{m}$。\n", 202 | "\n", 203 | "## 4.1 直接求解\n", 204 | "对于 $K$ 类共 $N$ 个训练数据 $(x_1,\\dots,x_N)$。如果每类样本数量相等,都为 $n_k = N / K$,则参数有解析的估计形式 [5]。\n", 205 | "\n", 206 | "1. 计算类内 (scatter matrix) $S_w = \\frac{1}{N} \\sum_k\\sum_{i \\in \\mathcal{C}_k} (\\mathbf{x}_i - \\mathbf{m}_k)(\\mathbf{x}_i - \\mathbf{m}_k)^T$ 和 类间 $S_b = \\frac{1}{N}\\sum_kn_k(\\mathbf{m}_k - \\mathbf{m})(\\mathbf{m}_k - \\mathbf{m})^T$。其中,$m_k = \\frac{1}{n_k} \\sum_{i \\in \\mathcal{C}_k} \\mathbf{x}_i$ 为第 $k$ 类样本均值,$m = \\frac{1}{N} \\sum_k\\sum_{i \\in \\mathcal{C}_k} \\mathbf{x}_i$ 为全部样本均值。\n", 207 | "2. 计算 $S_w^{-1}S_b$ 的特征向量 $w_{1,\\dots,d}$,每个特征向量为一列,组成矩阵 $W$。计算对角阵 $\\Lambda_b = W^TS_bW$,$\\Lambda_w = W^TS_wW$。\n", 208 | "3. 计算 $A = W^{-T} (\\frac{n}{n-1}\\Lambda_w)^{1/2}$,$\\Psi = \\max(0, \\frac{n-1}{n}(\\Lambda_b/\\Lambda_w) - \\frac{1}{n})$。\n", 209 | "4. 如果将维度从原数据的 $d$ 维低到 $d^\\prime$ 维,则保留 $\\Psi$ 的前 $d^\\prime$ 大的对角元素,将其余置为零。在进行推断时,仅使用 $\\mathbf{u} = A^{-1}(\\mathbf{x}-\\mathbf{m})$ 非零对角无素对应的 $d^\\prime$ 个元素。\n", 210 | "\n", 211 | "如果各类数据的数量不一致,则上述算法只能求得近似的参数估计。" 212 | ] 213 | }, 214 | { 215 | "cell_type": "markdown", 216 | "metadata": {}, 217 | "source": [ 218 | "## 4.2 期望最大化方法(EM)\n", 219 | "\n", 220 | "这里先列出算法流程,具体推导见下文:\n", 221 | "\n", 222 | "输入:$K$ 类 $d$ 维数据,第 $k$ 个类别包含 $n_k$ 个样本,记 $x_{ki} \\in R^{d}, 1 \\le k \\le K$ 为 第 $k$ 个类别的第 $i$ 个样本点。\n", 223 | "\n", 224 | "输出:$d \\times d$ 对称矩阵 $\\Phi_w$,$d \\times d$ 对称矩阵 $\\Phi_b$,$d$ 维向量 $m$。\n", 225 | "\n", 226 | "1. 计算统计量,$N = \\sum_{k=1}^K n_k$, $f_k = \\sum_{i=1}^{n_k} x_{ki}$,$m = \\frac{1}{N}\\sum_k f_{k}$,$S = \\sum_k \\sum_i x_{ki}x_{ki}^T$\n", 227 | "2. 随机初始化 $\\Phi_w$,$\\Phi_b$,$m$\n", 228 | "3. 
重复如下步骤直到满足终止条件:\n",
 229 |     " * 3.1 对每一个类别,计算: $\hat \Phi = (n\Phi^{-1}_w + \Phi_b^{-1})^{-1}$, $y = \hat \Phi (\Phi_b^{-1} m + \Phi_w^{-1} f)$,$yyt = \hat \Phi + y y^T$\n",
 230 |     " * 3.2 聚合计算结果:$R = \sum_k n_k \cdot yyt_k$,$T = \sum_k y_k f_k^T$,$P = \sum_k\hat \Phi_k$,$E =\sum_k (y_k - m)(y_k - m)^T$\n",
 231 |     " * 3.3 更新:$\Phi_w = \frac{1}{N} (S + R - (T + T^T))$,$\Phi_b = \frac{1}{K}(P + E)$\n",
 232 |     "\n",
 233 |     "> 上述算法基本与[2]所列的 two-covariance 算法相同,不过这里我们直接使用全局均值作为 $m$ 的估计。如果各类别的样本量相同,则两种方法是等价的,实现比较见[代码](./plda.py)。具体公式的差异见下文。\n",
 234 |     "\n",
 235 |     "基于 $\Phi_w$ 和 $\Phi_b$ 可以计算出 $\Psi$ 和 $A^{-1}$。对角化的方法在 4.1 节算法的 2、3 步已经给出[5]。\n",
 236 |     "\n",
 237 |     "首先,计算 $\Phi_w^{-1}\Phi_b$ 的特征向量 $w_{1,\dots,d}$,每个特征向量为一列,组成矩阵 $W$。则有:\n",
 238 |     "$$\n",
 239 |     "\Lambda_b = W^T \Phi_b W\n",
 240 |     "$$\n",
 241 |     "$$\n",
 242 |     "\Lambda_w = W^T \Phi_w W\n",
 243 |     "$$\n",
 244 |     "\n",
 245 |     "显然\n",
 246 |     "$$\n",
 247 |     "I = \Lambda_w^{-1/2}\Lambda_w \Lambda_w^{-1/2} = \Lambda_w^{-1/2} W^T \Phi_w W \Lambda_w^{-1/2}\n",
 248 |     "$$\n",
 249 |     "\n",
 250 |     "则:\n",
 251 |     "$$\n",
 252 |     "\Psi = \Lambda_b \Lambda_w^{-1}\n",
 253 |     "$$\n",
 254 |     "$$\n",
 255 |     "V = W \Lambda_w^{-1/2} \n",
 256 |     "$$\n",
 257 |     "$$\n",
 258 |     "A^{-1} = V^T = \Lambda_w^{-1/2} W^T\n",
 259 |     "$$"
 260 |    ]
 261 |   },
 262 |   {
 263 |    "cell_type": "markdown",
 264 |    "metadata": {},
 265 |    "source": [
 266 |     "这里我们首先通过 EM 算法估计 $m$、$\Phi_b$ 和 $\Phi_w$。在此基础上,我们可以根据上面的公式计算 $A^{-1}$ 及 $\Psi$,从而完成 PLDA 模型的训练。\n",
 267 |     "\n",
 268 |     "Kaldi 中基于期望最大化方法(EM)[实现](https://github.com/kaldi-asr/kaldi/blob/master/src/ivector/plda.cc)了 PLDA 的参数($\Phi_b$ 和 $\Phi_w$)估计。\n",
 269 |     "文献 [2](算法 2 及附录 B)给出了估计 $\Phi_w$ 和 $\Phi_b$ 的算法流程,并给出实现[代码](https://sites.google.com/site/fastplda/ )。"
 270 |    ]
 271 |   },
 272 |   {
 273 |    "cell_type": "markdown",
 274 |    "metadata": {},
 275 |    "source": [
 276 |     "EM 算法的[优化目标](https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm#Description)为:\n",
 277 |     "$$\n",
 278 |     "Q(\theta|\theta^{(t-1)}) = \mathbb{E}_{y|x,\theta^{(t-1)}} \ln p(x, y|\theta)\n",
 279 |     "$$\n",
 280 |     "因此我们需要知道 $y$ 的后验分布。\n",
 281 |     "\n",
 282 |     "### 4.2.1 隐变量 $y$ 的后验分布\n",
 283 |     "\n",
 284 |     "回顾 PLDA 的假设:\n",
 285 |     "$$\n",
 286 |     "p(\mathbf{x}|\mathbf{y}) = \mathcal{N}(\mathbf{x}|\mathbf{y}, \Phi_w)\n",
 287 |     "$$\n",
 288 |     "\n",
 289 |     "$$\n",
 290 |     "p(\mathbf{y}) = \mathcal{N}(\mathbf{y}|\mathbf{m}, \Phi_b)\n",
 291 |     "$$\n",
 292 |     "\n",
 293 |     "\n",
 294 |     "给定某类的 $n$ 个数据 $x_{1,\dots,n}$,则 $y$ 的后验分布可以写为:\n",
 295 |     "$$\n",
 296 |     "p(y|x_{1,\dots,n}) = p(x_{1,\dots,n}|y)p(y) / p(x_{1,\dots,n}) \propto p(x_{1,\dots,n}|y)p(y)\n",
 297 |     "$$\n",
 298 |     "\n",
 299 |     "后验正比于两个高斯分布的乘积,因此也服从高斯分布。因此,我们只需要计算均值向量和协方差矩阵,即可确定后验分布。\n",
 300 |     "\n",
 301 |     "$$\n",
 302 |     "\ln p(y|x_{1,\dots,n}) = \ln p(x_{1,\dots,n}|y) + \ln p(y) + \mathrm{const} = \sum_i \ln p(x_i|y) + \ln p(y) + \mathrm{const} = C - 0.5 \sum_i y^T \Phi_w^{-1} y + \sum_i x_i^T \Phi_w^{-1} y - 0.5 y^T \Phi_b^{-1} y + m^T \Phi_b^{-1} y\n",
 303 |     "$$\n",
 304 |     "\n",
 305 |     "整理 $y$ 的二次项为 $-0.5 y^T (n\Phi^{-1}_w + \Phi_b^{-1}) y$,对比高斯分布的二次项系数,后验的协方差矩阵为:\n",
 306 |     "$$\n",
 307 |     "\hat \Phi = (n\Phi^{-1}_w + \Phi_b^{-1})^{-1} \n",
 308 |     "$$\n",
 309 |     "\n",
 310 |     "记均值为 $\hat m$,则高斯分布的一次项为:\n",
 311 |     "$$\n",
 312 |     "y^T \hat \Phi^{-1} \hat m\n",
 313 |     "$$\n",
 314 |     "\n",
 315 |     "整理上面的一次式有:\n",
 316 |     "$$\n",
 317 |     "y^T (\Phi_b^{-1} m + \Phi_w^{-1} \sum_i x_i)\n",
 318 |     "$$\n",
 319 |     "对比两式,令 $f = \sum_i x_i$,有:\n",
 320 |     "$$\n",
 321 |     "\hat m = \hat \Phi (\Phi_b^{-1} m + \Phi_w^{-1} f)\n",
 322 |     "$$\n",
 323 |     "\n",
 324 |     "综上,后验分布为:\n",
 325 |     "$$\n",
 326 |     "y|x_{1,\dots,n} \sim \mathcal{N}(\hat m, \hat \Phi)\n",
 327 |     "$$\n",
 328 |     "\n"
 329 |    ]
 330 |   },
 331 |   {
 332 |    "cell_type": "markdown",
 333 |    "metadata": {},
 334 |    "source": [
 335 |     "### 4.2.2 E step\n",
 336 |     "根据上面的推导有:\n",
 337 |     "$$\n",
 338 |     "\hat \Phi = (n\Phi^{-1}_w + \Phi_b^{-1})^{-1} \n",
 339 |     "$$\n",
 340 |     "\n",
 341 |     "$$\n",
 342 |     "\hat m = \hat \Phi (\Phi_b^{-1} m + \Phi_w^{-1} f)\n",
 343 |     "$$\n",
 344 |     "\n",
 345 |     "根据后验概率,易得如下期望:\n",
 346 |     "$$\n",
 347 |     "\mathbb{E}[y] = \hat m = \hat \Phi (\Phi_b^{-1} m + \Phi_w^{-1} f)\n",
 348 |     "$$\n",
 349 |     "$$\n",
 350 |     "\mathbb{E}[yy^T] = \hat \Phi + \mathbb{E}[y] \mathbb{E}[y]^T\n",
 351 |     "$$\n",
 352 |     "\n"
 353 |    ]
 354 |   },
 355 |   {
 356 |    "cell_type": "markdown",
 357 |    "metadata": {},
 358 |    "source": [
 359 |     "EM 优化目标可以改写为:\n",
 360 |     "$$\n",
 361 |     "Q(\theta|\theta^{(t-1)}) = \mathbb{E}_{y|x,\theta^{(t-1)}} \ln p(x, y|\theta) = \mathbb{E}_{y|x,\theta^{(t-1)}} \ln p(x|y, \theta) + \mathbb{E}_{y|x,\theta^{(t-1)}} \ln p(y|\theta) \n",
 362 |     "$$\n",
 363 |     "显然,右式第一项只包含参数 $\Phi_w$,第二项只包含 $\Phi_b$ 及 $m$。因此,我们可以将优化目标分为独立的两部分。\n",
 364 |     "\n",
 365 |     "\n",
 366 |     "### 4.2.3 M step for $\Phi_w$\n",
 367 |     "\n",
 368 |     "\n",
 369 |     "对于参数 $\Phi_w$,EM 的最大化目标函数为:\n",
 370 |     "$$\n",
 371 |     "Q = \mathbb{E}_{y}[\ln p(x|y)] = \mathbb{E}_{y}[\sum_i \ln p(x_i|y)]\n",
 372 |     "$$\n",
 373 |     "\n",
 374 |     "对 $x_i$ 有:\n",
 375 |     "$$\n",
 376 |     "\ln p(x_i|y) = C - 0.5 \ln|\Phi_w| - 0.5 (x_i - y)^T \Phi_w^{-1} (x_i - y) = C - 0.5 \ln|\Phi_w| -0.5 x_i^{T} \Phi_w^{-1} x_i -0.5 y^{T} \Phi_w^{-1} y + 0.5 x_i^T \Phi_w^{-1} y + 0.5 y^T \Phi_w^{-1} x_i\n",
 377 |     "$$\n",
 378 |     "\n",
 379 |     "$$\n",
 380 |     "\mathbb{E}_{y}[\ln p(x_i|y)] = C - 0.5 \ln|\Phi_w| -0.5 x_i^{T} \Phi_w^{-1} x_i -0.5 \mathrm{tr} (\mathbb{E}[yy^T] \Phi_w^{-1}) + 0.5 x_i^T \Phi_w^{-1} \mathbb{E}[y] + 0.5 \mathbb{E}[y]^T \Phi_w^{-1} x_i\n",
 381 |     "$$\n",
 382 |     "\n",
 383 |     "利用矩阵求导的知识[11],并注意到 $\Phi_w$ 和 $\mathbb{E}[yy^T]$ 都是对称矩阵,于是有:\n",
 384 |     "$$\n",
 385 |     "\frac{\partial}{\partial \Phi_w} \mathbb{E}_{y}[\ln p(x|y)] = -0.5 n \Phi_w^{-1} + 0.5 \Phi_w^{-1} S \Phi_w^{-1} + 0.5n \Phi_w^{-1}\mathbb{E}[yy^T]\Phi_w^{-1} - 0.5\Phi_w^{-1} f \mathbb{E}[y]^T\Phi_w^{-1} - 0.5\Phi_w^{-1} \mathbb{E}[y] f^T\Phi_w^{-1}\n",
 386 |     "$$\n",
 387 |     "\n",
 388 |     "其中,$S = \sum_i x_i x_i^T, f = \sum_i x_i$。令上式为零,求得:\n",
 389 |     "\n",
 390 |     "$$\n",
 391 |     "n\Phi_w = S + n\mathbb{E}[yy^T] - (T + T^T)\n",
 392 |     "$$\n",
 393 |     "其中,$T = \mathbb{E}[y] f^T$。\n",
 394 |     "\n",
 395 |     "根据这个类别的数据,$\Phi_w$ 的估计为:\n",
 396 |     "$$\n",
 397 |     "\Phi_w =\frac{1}{n}[S + n\mathbb{E}[yy^T] - (T + T^T)]\n",
 398 |     "$$\n",
 399 |     "\n",
 400 |     "上面的推导是基于一个类别的数据,如果我们有 $K$ 类,对依赖类别的变量加上相应的下标,则有:\n",
 401 |     "$$\n",
 402 |     "\sum_{k=1}^K n_k \Phi_w = \sum_{k=1}^K \left[ S_k + n_k \mathbb{E}[y_ky_k^T] - (T_k + T_k^T) \right]\n",
 403 |     "$$\n",
 404 |     "\n",
 405 |     "根据所有类别数据的信息,$\Phi_w$ 的估计公式为:\n",
 406 |     "$$\n",
 407 |     "\Phi_w = \frac{1}{N} (S_g + R_g - (T_g + T_g^T))\n",
 408 |     "$$\n",
 409 |     "其中,\n",
 410 |     "$$\n",
 411 |     "S_g = \sum_{k=1}^K S_k = \sum_{k=1}^K \sum_{j=1}^{n_k} x_{kj}x_{kj}^T \n",
 412 |     "$$\n",
 413 |     "$$\n",
 414 |     "T_g = \sum_{k=1}^K T_k = \sum_{k=1}^K \mathbb{E}[y_k] f_k^T\n",
 415 |     "$$\n",
 416 |     "$$\n",
 417 |     "R_g = \sum_{k=1}^K n_k \mathbb{E}[y_ky_k^T]\n",
 418 |     "$$"
 419 |    ]
 420 |   },
 421 |   {
 422 |    "cell_type": "markdown",
 423 |    "metadata": {},
 424 |    "source": [
 425 |     "### 4.2.4 M step for $\Phi_b$ 及 $m$\n",
 426 |     "\n",
 427 |     "对于参数 $\Phi_b$ 和 $m$,EM 的最大化目标函数为:\n",
 428 |     "$$\n",
 429 |     "Q = \mathbb{E}_{y}[\ln p(y)] = \sum_{k=1}^K \mathbb{E}_{y_k}[\ln p(y_k)]\n",
 430 |     "$$\n",
 431 |     "\n",
 432 |     "对 $y_k$ 有:\n",
 433 |     "$$\n",
 434 |     "\ln p(y_k) = C - 0.5 \ln|\Phi_b| - 0.5 (y_k - m)^T \Phi_b^{-1} (y_k - m) = C - 0.5 \ln|\Phi_b| -0.5 y_k^{T} \Phi_b^{-1} y_k -0.5 m^{T} \Phi_b^{-1} m + 0.5 y_k^T \Phi_b^{-1} m + 0.5 m^T \Phi_b^{-1} y_k\n",
 435 |     "$$\n",
 436 |     "\n",
 437 |     "$$\n",
 438 |     "\mathbb{E}_{y_k}[\ln p(y_k)] = C - 0.5 \ln|\Phi_b| -0.5 \mathrm{tr} (\mathbb{E}[y_k y_k^T] \Phi_b^{-1}) -0.5 m^{T} \Phi_b^{-1} m + 0.5 \mathbb{E}[y_k]^T \Phi_b^{-1} m + 0.5 m^T \Phi_b^{-1} \mathbb{E}[y_k]\n",
 439 |     "$$\n",
 440 |     "\n",
 441 |     "\n",
 442 |     "**a) $m$ 的估计**\n",
 443 |     "\n",
 444 |     "类似于 4.2.3,我们有:\n",
 445 |     "$$\n",
 446 |     "\frac{\partial}{\partial m} \mathbb{E}_{y}[\ln p(y)] = \sum_k \left( -m^T \Phi_b^{-1} + \mathbb{E}[y_k]^T\Phi_b^{-1} \right) \n",
 447 |     "$$\n",
 448 |     "\n",
 449 |     "令上式为零,求得:\n",
 450 |     "$$\n",
 451 |     "m = \frac{1}{K} \sum_k \mathbb{E}[y_k]\n",
 452 |     "$$\n",
 453 |     "\n",
 454 |     "由于各类数量可能不均衡,可以加权[2]:\n",
 455 |     "$$\n",
 456 |     "m = \frac{1}{N} \sum_k n_k\mathbb{E}[y_k]\n",
 457 |     "$$\n",
 458 |     "\n",
 459 |     "此时,如果各类别数量相同,将 $\mathbb{E}[y] = \hat \Phi (\Phi_b^{-1} m + \Phi_w^{-1} f)$ 代入上式有:\n",
 460 |     "$$\n",
 461 |     "m = \frac{1}{N} \sum_{k=1}^K f_k = \frac{1}{N} \sum_{k=1}^K \sum_{i=1}^{n_k} x_{ki}\n",
 462 |     "$$\n",
 463 |     "\n",
 464 |     "即 $m$ 是已知数据的全局均值,不需要迭代。在 Kaldi 的实现中,不论各类数量是否相同,都直接使用全局均值作为 $m$ 的估计。\n",
 465 |     "\n",
 466 |     "**b) $\Phi_b$ 的估计**\n",
 467 |     "\n",
 468 |     "对 $\Phi_b$ 求导得:\n",
 469 |     "$$\n",
 470 |     "\frac{\partial}{\partial \Phi_b} \mathbb{E}_{y_k}[\ln p(y_k)] = -0.5 \Phi_b^{-1} + 0.5 \Phi_b^{-1} \mathbb{E}[y_k y_k^T] \Phi_b^{-1} + 0.5\Phi_b^{-1}mm^T\Phi_b^{-1} - 0.5 \Phi_b^{-1} \mathbb{E}[y_k]m^T \Phi_b^{-1} - 0.5 \Phi_b^{-1} m \mathbb{E}[y_k]^T \Phi_b^{-1}\n",
 471 |     "$$\n",
 472 |     "\n",
 473 |     "令上式为零,求得更新公式:\n",
 474 |     "$$\n",
 475 |     "\Phi_b = \mathbb{E}[y_k y_k^T] + mm^T - \mathbb{E}[y_k]m^T - m \mathbb{E}[y_k]^T\n",
 476 |     "$$\n",
 477 |     "注意到 $\mathbb{E}[y_ky_k^T] = \hat \Phi_k + \mathbb{E}[y_k] \mathbb{E}[y_k]^T$,再对 $K$ 个类别取平均,则有:\n",
 478 |     "$$\n",
 479 |     "\Phi_b = \frac{1}{K} \sum_k(\hat \Phi_k +(\mathbb{E}[y_k] - m)(\mathbb{E}[y_k] - m)^T)\n",
 480 |     "$$\n",
 481 |     "\n",
 482 |     "> 如果考虑各类之间的权重,则有:\n",
 483 |     "$$\n",
 484 |     "\Phi_b = \frac{1}{\sum_{k}w_k} \sum_k w_k (\hat \Phi_k +(\mathbb{E}[{y_k}] - m)(\mathbb{E}[{y_k}] - m)^T)\n",
 485 |     "$$\n",
 486 |     "\n",
 487 |     "而如果我们使用 $m = \frac{1}{N} \sum_k n_k\mathbb{E}[y_k]$ 估计 $m$,则综合所有类别数据,并按照数据量加权,有[2]:\n",
 488 |     "$$\n",
 489 |     "\Phi_b = \frac{1}{N} R_g - mm^T\n",
 490 |     "$$"
 491 |    ]
 492 |   },
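结合 4.2.1–4.2.4 的推导,两协方差 PLDA 的一轮 EM 迭代可以用如下 NumPy 草图示意(仅为说明上述公式;函数名 `em_iteration` 及变量组织方式为本文假设,与 plda.py 中的实现细节并不相同):

```python
import numpy as np

def em_iteration(data, m, Phi_w, Phi_b):
    """One EM iteration of two-covariance PLDA.

    data: list of (d, n_k) arrays, one per class; m: (d, 1) global mean (kept fixed).
    """
    d = m.shape[0]
    N = sum(X.shape[1] for X in data)
    K = len(data)
    inv_w = np.linalg.inv(Phi_w)
    inv_b = np.linalg.inv(Phi_b)

    S = np.zeros((d, d))
    T = np.zeros((d, d))
    R = np.zeros((d, d))
    P = np.zeros((d, d))
    E = np.zeros((d, d))
    for X in data:                                     # E step, one class at a time
        n = X.shape[1]
        f = X.sum(axis=1, keepdims=True)               # f_k = sum_i x_i
        Phi_hat = np.linalg.inv(n * inv_w + inv_b)     # posterior covariance of y_k
        Ey = np.dot(Phi_hat, np.dot(inv_b, m) + np.dot(inv_w, f))  # E[y_k]
        Eyyt = Phi_hat + np.dot(Ey, Ey.T)              # E[y_k y_k^T]

        S += np.dot(X, X.T)                            # S_g += sum_i x_i x_i^T
        T += np.dot(Ey, f.T)                           # T_g += E[y_k] f_k^T
        R += n * Eyyt                                  # R_g += n_k E[y_k y_k^T]
        P += Phi_hat
        E += np.dot(Ey - m, (Ey - m).T)

    Phi_w_new = (S + R - (T + T.T)) / N                # M step (4.2.3)
    Phi_b_new = (P + E) / K                            # M step (4.2.4)
    return Phi_w_new, Phi_b_new
```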
 493 |   {
 494 |    "cell_type": "markdown",
 495 |    "metadata": {
 496 |     "collapsed": true
 497 |    },
 498 |    "source": [
 499 |     "# References\n",
 500 |     "1. Hastie et al. [The Elements of Statistical Learning](https://web.stanford.edu/~hastie/ElemStatLearn/).\n",
 501 |     "2. Sizov et al. [Unifying Probabilistic Linear Discriminant Analysis Variants in Biometric Authentication](http://cs.uef.fi/~sizov/pdf/unifying_PLDA_ssspr2014.pdf).\n",
 502 |     "3. Prince & Elder. [Probabilistic Linear Discriminant Analysis for Inferences About Identity](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.97.6491&rep=rep1&type=pdf).\n",
 503 |     "4. Kenny et al. [Bayesian Speaker Verification with Heavy Tailed Priors](http://www.crim.ca/perso/patrick.kenny/kenny_Odyssey2010.pdf).\n",
 504 |     "5. Ioffe. [Probabilistic Linear Discriminant Analysis](https://ravisoji.com/assets/papers/Ioffe2006PLDA.pdf).\n",
 505 |     "6. Shafey et al. [A Scalable Formulation of Probabilistic Linear Discriminant Analysis: Applied to Face Recognition](http://ieeexplore.ieee.org/document/6461886/).\n",
 506 |     "7. Hastie & Tibshirani. [Discriminant Analysis by Gaussian Mixtures](http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.36.203).\n",
 507 |     "8. Brummer & De Villiers. [The Speaker Partitioning Problem](https://pdfs.semanticscholar.org/3e49/e2d4b026e6bfe4def3586d8cd9b2a90ee7ed.pdf).\n",
 508 |     "9. Jiang et al. [PLDA Modeling in I-vector and Supervector Space for Speaker Verification](https://pdfs.semanticscholar.org/bccb/205ca4069505aefd29fca5b5cdf3db02e3d4.pdf).\n",
 509 |     "10. Brummer et al. [EM for Probabilistic LDA](https://ce23b256-a-62cb3a1a-s-sites.googlegroups.com/site/nikobrummer/EMforPLDA.pdf?attachauth=ANoY7crcscfaCA3IqQl-SOmO-MU41YCYXsPkXgoI3yS1ND6EewKJI62_YbtfycbClTO7y49zyO-s8d038nPwwrL0DlTjd5kQPDFDIoAWvQoWnSUNQUxauB78WqO70sbBK73GS0_LXtFFHxyysqoB70Rz70Y5ipRzDyfhqgAxdclS2t5xGHhK6pJoOKc_gIqGZNzt7uAK_Oi6fhmfGm4Vek-3AsJka5F0mQ%3D%3D&attredirects=0).\n",
 510 |     "11. Petersen & Pedersen. [The Matrix Cookbook](http://www2.imm.dtu.dk/pubdb/views/edoc_download.php/3274/pdf/imm3274.pdf)."
 511 |    ]
 512 |   }
 513 |  ],
 514 |  "metadata": {
 515 |   "anaconda-cloud": {},
 516 |   "kernelspec": {
 517 |    "display_name": "Python [Root]",
 518 |    "language": "python",
 519 |    "name": "Python [Root]"
 520 |   },
 521 |   "language_info": {
 522 |    "codemirror_mode": {
 523 |     "name": "ipython",
 524 |     "version": 2
 525 |    },
 526 |    "file_extension": ".py",
 527 |    "mimetype": "text/x-python",
 528 |    "name": "python",
 529 |    "nbconvert_exporter": "python",
 530 |    "pygments_lexer": "ipython2",
 531 |    "version": "2.7.12"
 532 |   }
 533 |  },
 534 |  "nbformat": 4,
 535 |  "nbformat_minor": 0
 536 | }
 537 | 
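下面的 plda.py 给出了上述 EM 训练的参考实现。一个最小的调用示意如下(仅为草图,并不属于仓库代码,只调用 plda.py 中已有的函数):

```python
import numpy as np
from plda import PLDA, fake_data, train

np.random.seed(0)
data, m_true, B_true, W_true = fake_data(D=4, K=5, n=20)  # list of (D, n_k) class matrices

plda = PLDA('full')               # two-covariance parameterization
train(plda, data, iterations=20)  # EM updates of plda.B (Phi_b) and plda.W (Phi_w)

print(plda.mu)  # global mean estimate
print(plda.B)   # between-class covariance Phi_b
print(plda.W)   # within-class covariance Phi_w
```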
--------------------------------------------------------------------------------
/plda/plda.py:
--------------------------------------------------------------------------------
  1 | # -*- coding: utf-8 -*-
  2 | """
  3 | Created on Sat Mar 31 23:29:47 2018
  4 | 
  5 | @author: keding
  6 | """
  7 | import numpy as np
  8 | 
  9 | 
 10 | class PLDA(object):
 11 |     def __init__(self, type='inv'):
 12 |         '''Two Covariance PLDA.
 13 |         Args:
 14 |             type: 'full', 'diag' or 'inv'
 15 |         '''
 16 |         self.type = type
 17 | 
 18 |         if self.type == 'full':
 19 |             self.B = None  # between-class covariance
 20 |             self.W = None  # within-class covariance
 21 |             self.mu = None  # between-class center
 22 |         elif self.type == 'diag':
 23 |             self.V = None  # transform matrix
 24 |             self.psi = None  # diagonalized between-class covariance
 25 |             self.mu = None  # between-class center
 26 |         elif self.type == 'inv':
 27 |             self.invB = None
 28 |             self.invW = None
 29 |             self.mu = None
 30 | 
 31 |     def convert(self, target):
 32 |         '''Convert between types
 33 |         '''
 34 |         if target == self.type:
 35 |             return
 36 | 
 37 |         if self.type == 'full' and target == 'diag':
 38 |             raise RuntimeError('Not Implemented yet!')
 39 |         elif self.type == 'diag' and target == 'full':
 40 |             raise RuntimeError('Not Implemented yet!')
 41 |         else:
 42 |             raise RuntimeError('Invalid type conversion!')
 43 | 
 44 |         self.type = target
 45 | 
 46 |     def compute_log_likelihood(self, data):
 47 |         """Compute the log likelihood for the whole dataset.
 48 | 
 49 |         Args:
 50 |             data: An array of the shape (number_of_features, number_of_samples).
 51 |         """
 52 |         if self.type == 'full':
 53 |             return self._compute_llk_full(data, self.mu, self.B, self.W)
 54 |         elif self.type == 'diag':
 55 |             return self._compute_llk_diag(data, self.mu, self.psi, self.V)
 56 |         elif self.type == 'inv':
 57 |             return self._compute_llk_full(data, self.mu, self.invB, self.invW)
 58 | 
 59 |     def _compute_llk_full(self, data, mu, B, W):
 60 |         d, n = data.shape
 61 | 
 62 |         centered_data = data - mu
 63 | 
 64 |         # Total covariance matrix for the model with integrated out latent
 65 |         # variables
 66 |         Sigma_tot = B + W
 67 | 
 68 |         # Compute log-determinant of the (symmetric) Sigma_tot matrix
 69 |         E = np.linalg.eigvalsh(Sigma_tot)
 70 |         log_det = np.sum(np.log(E))
 71 | 
 72 |         return -0.5 * (n * d * np.log(2 * np.pi) + n * log_det +
 73 |                        np.sum(np.dot(centered_data.T, np.linalg.inv(Sigma_tot)) *
 74 |                               centered_data.T))
 75 | 
 76 |     def _compute_llk_diag(self, data, mu, psi, V):
 77 |         d, n = data.shape
 78 | 
 79 |         u = np.dot(V, data - mu)
 80 | 
 81 |         # Total covariance matrix for the model with integrated out latent
 82 |         # variables
 83 |         Sigma_tot = psi + 1
 84 | 
 85 |         # Compute log-determinant of the Sigma_tot matrix
 86 |         log_det = np.sum(np.log(Sigma_tot))
 87 | 
 88 |         return -0.5 * (n * d * np.log(2 * np.pi) + n * log_det +
 89 |                        np.sum(u ** 2 / Sigma_tot[:, np.newaxis]))
 90 | 
 91 | 
 92 | def preprocessing(data):
 93 |     '''Sort the classes by size, pool the data, and compute the global
 94 |     statistics: N, K, per-class first moments f, second moment S and mean mu.'''
 95 |     # Sort the speakers by the number of utterances for the faster E-step
 96 |     data.sort(key=lambda x: x.shape[1])
 97 | 
 98 |     # Pool all the data for the more efficient M-step
 99 |     pooled_data = np.hstack(data)
100 | 
101 |     N = pooled_data.shape[1]  # total number of files
102 |     K = len(data)  # number of classes
103 | 
104 |     mu = pooled_data.mean(axis=1, keepdims=True)
105 | 
106 |     # Calc first and second moments
107 |     f = [spk_data.sum(axis=1) for spk_data in data]
108 |     f = np.asarray(f).T
109 |     S = np.dot(pooled_data, pooled_data.T)
110 | 
111 |     return pooled_data, N, K, f, S, mu
112 | 
113 | 
114 | def initialize(plda, N, S, mu):
115 |     cov = S / N - np.dot(mu, mu.T)  # total covariance of the pooled data
116 | 
117 |     if plda.type == 'full':
118 |         plda.mu = mu
119 |         plda.B = cov
120 |         plda.W = cov
121 |     elif plda.type == 'inv':
122 |         plda.mu = mu
123 |         plda.invB = plda.invW = cov
124 | 
125 | 
126 | def inv_e_step(plda, data, N, f, S):
127 |     dim_d = data[0].shape[0]
128 | 
129 |     B = np.linalg.inv(plda.invB)
130 |     W = np.linalg.inv(plda.invW)
131 |     mu = plda.mu
132 | 
133 |     # Initialize output matrices
134 |     T = np.zeros((dim_d, dim_d))
135 |     R = np.zeros((dim_d, dim_d))
136 |     Y = np.zeros((dim_d, 1))
137 | 
138 |     # Set auxiliary matrix
139 |     Bmu = np.dot(B, mu)
140 | 
141 |     n_previous = 0  # number of utterances for a previous person
142 |     for i in range(len(data)):
143 |         n = data[i].shape[1]  # number of utterances for a particular person
144 |         if n != n_previous:
145 |             # Update matrix that is dependent on the number of utterances
146 |             invL_i = np.linalg.inv(B + n * W)
147 |             n_previous = n
148 | 
149 |         gamma_i = Bmu + np.dot(W, f[:, [i]])
150 |         Ey_i = np.dot(invL_i, gamma_i)
151 | 
152 |         T += np.dot(Ey_i, f[:, [i]].T)
153 |         R += n * (invL_i + np.dot(Ey_i, Ey_i.T))
154 |         Y += n * Ey_i
155 | 
156 |     return T, R, Y
157 | 
158 | 
159 | def inv_m_step(plda, T, R, Y, N, S):
160 |     plda.mu = Y / N
161 |     plda.invB = (R - np.dot(Y, Y.T) / N) / N
162 |     plda.invW = (S - (T + T.T) + R) / N
163 | 
164 | 
165 | def full_e_step(plda, data, N, f, S):
166 |     dim_d = data[0].shape[0]
167 | 
168 |     invB = np.linalg.inv(plda.B)
169 |     invW = np.linalg.inv(plda.W)
170 |     mu = plda.mu
171 | 
172 |     # Initialize output matrices
173 |     T = np.zeros((dim_d, dim_d))
174 |     R = np.zeros((dim_d, dim_d))
175 |     P = np.zeros((dim_d, dim_d))
176 |     E = np.zeros((dim_d, dim_d))
177 | 
178 |     # Set auxiliary matrix
179 |     invBmu = np.dot(invB, mu)
180 | 
181 |     n_previous = 0  # number of utterances for a previous person
182 |     for i in range(len(data)):
183 |         n = data[i].shape[1]  # number of utterances for a particular person
184 |         if n != n_previous:
185 |             # Update matrix that is dependent on the number of utterances
186 |             Phi = np.linalg.inv(invB + n * invW)
187 |             n_previous = n
188 | 
189 |         gamma_i = invBmu + np.dot(invW, f[:, [i]])
190 |         Ey_i = np.dot(Phi, gamma_i)
191 |         Eyyt_i = Phi + np.dot(Ey_i, Ey_i.T)
192 |         Ey_immu = Ey_i - mu
193 | 
194 |         T += np.dot(Ey_i, f[:, [i]].T)
195 |         R += n * Eyyt_i
196 |         P += Phi
197 |         E += np.dot(Ey_immu, Ey_immu.T)
198 | 
199 |     return T, R, P, E
200 | 
201 | 
202 | def full_m_step(plda, T, R, P, E, N, K, S):
203 |     plda.B = (P + E) / K
204 |     plda.W = (S - (T + T.T) + R) / N
205 | 
206 | 
207 | def print_progress(plda, pooled_data, cur_iter, total_iters):
208 |     progress_message = '%d-th\titeration out of %d.' % (cur_iter+1,
209 |                                                         total_iters)
210 |     progress_message += (' Log-likelihood is %f' %
211 |                          plda.compute_log_likelihood(pooled_data))
212 |     print(progress_message)
213 | 
214 | 
215 | def train(plda, data, iterations):
216 |     pooled_data, N, K, f, S, mu = preprocessing(data)
217 |     initialize(plda, N, S, mu)
218 | 
219 |     for i in range(iterations):
220 |         if plda.type == 'inv':
221 |             T, R, Y = inv_e_step(plda, data, N, f, S)
222 |             inv_m_step(plda, T, R, Y, N, S)
223 |         elif plda.type == 'full':
224 |             T, R, P, E = full_e_step(plda, data, N, f, S)
225 |             full_m_step(plda, T, R, P, E, N, K, S)
226 | 
227 |         # Print current progress
228 |         print_progress(plda, pooled_data, i, iterations)
229 | 
230 | 
231 | def test_plda():
232 |     types = ['full', 'inv', 'diag']
233 |     for t in types:
234 |         plda = PLDA(t)
235 | 
236 | 
237 | def fake_data(D=2, K=3, n=10):
238 |     sqrtB = np.diag(np.random.rand(D) + 0.1)
239 |     B = np.dot(sqrtB, sqrtB.T)
240 | 
241 |     sqrtW = np.diag(np.random.rand(D) * 0.1 + 0.01)
242 |     W = np.dot(sqrtW, sqrtW.T)
243 | 
244 |     m = np.random.randn(D, 1)
245 | 
246 |     data = []
247 |     for i in range(K):
248 |         ni = n + i
249 |         X = m + np.dot(sqrtB.T, np.random.randn(D, 1)) + np.dot(sqrtW.T, np.random.randn(D, ni))  # one shared class center + per-sample noise
250 |         data.append(X)
251 | 
252 |     return data, m, B, W
253 | 
254 | 
255 | def test_train():
256 | 
257 |     D, K, n, iters = 2, 3, 10, 10
258 | 
259 |     np.random.seed(1111)
260 |     data, m, B, W = fake_data(D, K, n)
261 |     print(m)
262 |     print(B)
263 |     print(W)
264 | 
265 |     np.random.seed(1111)
266 |     plda = PLDA('inv')
267 |     train(plda, data, iterations=iters)
268 |     print(plda.mu)
269 |     print(plda.invB)
270 |     print(plda.invW)
271 | 
272 |     np.random.seed(1111)
273 |     plda = PLDA('full')
274 |     train(plda, data, iterations=iters)
275 |     print(plda.mu)
276 |     print(plda.B)
277 |     print(plda.W)
278 | 
279 | 
280 | if __name__ == '__main__':
281 |     test_plda()
282 |     test_train()
--------------------------------------------------------------------------------
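The `convert` method in the PLDA class above is still a stub. As a rough sketch based on the diagonalization formulas in the notebook (the helper name `diagonalize` is hypothetical and not part of the repository), the diagonalized parameters Psi and A^{-1} could be obtained from Phi_w and Phi_b like this:

```python
import numpy as np

def diagonalize(Phi_w, Phi_b):
    # Columns of W are eigenvectors of Phi_w^{-1} Phi_b; they jointly diagonalize Phi_w and Phi_b
    _, W = np.linalg.eig(np.dot(np.linalg.inv(Phi_w), Phi_b))
    W = np.real(W)
    lam_b = np.diag(np.dot(np.dot(W.T, Phi_b), W))  # diagonal of Lambda_b = W^T Phi_b W
    lam_w = np.diag(np.dot(np.dot(W.T, Phi_w), W))  # diagonal of Lambda_w = W^T Phi_w W
    psi = lam_b / lam_w                             # diagonal of Psi = Lambda_b Lambda_w^{-1}
    inv_A = np.dot(np.diag(lam_w ** -0.5), W.T)     # A^{-1} = Lambda_w^{-1/2} W^T
    return psi, inv_A
```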