├── Beam Decoding.ipynb ├── Data Augmentation for Quora Question Pairs.ipynb ├── Dropout in a minute.ipynb ├── How to get the last hidden vector of rnns properly.ipynb ├── Pos-tagging with Bert Fine-tuning.ipynb ├── PyTorch seq2seq template based on the g2p task.ipynb ├── README.md ├── Subword Segmentation Techniques.ipynb ├── Tensorflow seq2seq template based on g2p.ipynb ├── dropout.png └── no-dropout.png /Beam Decoding.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Beam decoding is essential for seq2seq tasks. But it's notoriously complicated to implement. Here's a relatively easy one, batchfying candidates." 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "__author__ = \"kyubyong\"\n", 17 | "__address__ = \"https://github.com/kyubyong/nlp_made_easy\"\n", 18 | "__email__ = \"kbpark.linguist@gmail.com\"" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "import numpy as np" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "## Parameters" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 262, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "V = 100 # output dimensionality. 
number of vocabulary\n", 44 | "H = 10 # hidden dimensionality \n", 45 | "K = 3 # beam width\n", 46 | "T = 10 # decoding timesteps\n", 47 | "initial_y = np.array([[3], [6]], dtype=np.int32)\n", 48 | "N = len(initial_y) # batch size\n", 49 | "xh = np.random.randn(V, H) # weights from input to hidden\n", 50 | "hh = np.random.randn(H, H) # weights from hidden to hidden\n", 51 | "ho = np.random.randn(H, V) # weights from hidden to outputs\n", 52 | "\n", 53 | "EOS_ID = 0" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "## Utility functions" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 263, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "def onehot(arry, size):\n", 70 | " '''\n", 71 | " arry: int array of any shape, e.g. (n,) or (n, t)\n", 72 | " size: output dimensions (vocabulary size)\n", 73 | " \n", 74 | " returns\n", 75 | " int32 array of shape arry.shape + (size,)\n", 76 | " '''\n", 77 | " labels_one_hot = (arry.ravel()[np.newaxis] == np.arange(size)[:, np.newaxis]).T\n", 78 | " labels_one_hot.shape = arry.shape + (size,)\n", 79 | " return labels_one_hot.astype('int32')" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 264, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "def softmax(x):\n", 89 | " \"\"\"Row-wise softmax of a 2-d score array x; each row is shifted by its own max for numerical stability.\"\"\"\n", 90 | " e_x = np.exp(x - np.max(x, axis=1, keepdims=True))\n", 91 | " return e_x / e_x.sum(axis=1,keepdims=True)" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 270, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "def decode(y):\n", 101 | " '''Decodes with a simple rnn.\n", 102 | " y: (N, t) int array of token ids. 
Returns (N, V) log-probs of the next token.\n", 103 | " '''\n", 104 | " # reads module-level V, H, xh, hh, ho (read-only access needs no `global`)\n", 105 | " \n", 106 | " prev_hidden = np.zeros((y.shape[0], H)) # initial hidden\n", 107 | " \n", 108 | " for t in range(y.shape[1]):\n", 109 | " token = y[:, t]\n", 110 | " x_to_h = np.matmul(onehot(token, V), xh) # (N, h)\n", 111 | " h_to_h = 
np.matmul(prev_hidden, hh) # (N, h)\n", 112 | " hidden = np.tanh(x_to_h + h_to_h)\n", 113 | " prev_hidden = hidden\n", 114 | " if t == y.shape[1]-1: # last step\n", 115 | " outputs = np.matmul(hidden, ho)\n", 116 | " probs = softmax(outputs)\n", 117 | " probs = np.log(probs)\n", 118 | " return probs # (N, V)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "## Beam decode" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 268, 131 | "metadata": { 132 | "scrolled": false 133 | }, 134 | "outputs": [ 135 | { 136 | "name": "stdout", 137 | "output_type": "stream", 138 | "text": [ 139 | "========== timesteps= 0 ==========\n", 140 | "batch num= 0\n", 141 | "[13]\tFalse\t-1.801408656139552\n", 142 | "[48]\tFalse\t-1.9670540723462346\n", 143 | "[33]\tFalse\t-2.198485320487616\n", 144 | "batch num= 1\n", 145 | "[30]\tFalse\t-0.7897988012958753\n", 146 | "[1]\tFalse\t-2.4629262078330942\n", 147 | "[76]\tFalse\t-2.942492396759654\n", 148 | "========== timesteps= 1 ==========\n", 149 | "Expansion...\n", 150 | "batch num= 0\n", 151 | "[13 53]\tFalse\t-1.7223380770456393\n", 152 | "[13 13]\tFalse\t-2.153169837293493\n", 153 | "[13 21]\tFalse\t-2.1963759327725914\n", 154 | "[48 7]\tFalse\t-1.5501885837535672\n", 155 | "[48 56]\tFalse\t-2.0774770760291426\n", 156 | "[48 71]\tFalse\t-2.2559148677511205\n", 157 | "[33 13]\tFalse\t-2.02571782070014\n", 158 | "[33 20]\tFalse\t-2.2622432967465382\n", 159 | "[33 99]\tFalse\t-2.411861897194325\n", 160 | "batch num= 1\n", 161 | "[30 8]\tFalse\t-1.0387280196474995\n", 162 | "[30 38]\tFalse\t-1.4677326159171868\n", 163 | "[30 67]\tFalse\t-1.963540864046755\n", 164 | "[ 1 46]\tFalse\t-1.8347386403083892\n", 165 | "[ 1 50]\tFalse\t-2.062073079883135\n", 166 | "[ 1 71]\tFalse\t-2.79060930087043\n", 167 | "[76 89]\tFalse\t-2.4038844011808673\n", 168 | "[76 44]\tFalse\t-2.440214472610669\n", 169 | "[76 3]\tFalse\t-2.4820893854715784\n", 170 | "Pruning ...\n", 171 | 
"batch num= 0\n", 172 | "[48 7]\tFalse\t-1.5501885837535672\n", 173 | "[13 53]\tFalse\t-1.7223380770456393\n", 174 | "[33 13]\tFalse\t-2.02571782070014\n", 175 | "batch num= 1\n", 176 | "[30 8]\tFalse\t-1.0387280196474995\n", 177 | "[30 38]\tFalse\t-1.4677326159171868\n", 178 | "[ 1 46]\tFalse\t-1.8347386403083892\n", 179 | "========== timesteps= 2 ==========\n", 180 | "Expansion...\n", 181 | "batch num= 0\n", 182 | "[48 7 45]\tFalse\t-1.4907133631014375\n", 183 | "[48 7 49]\tFalse\t-1.5125758679961583\n", 184 | "[48 7 46]\tFalse\t-1.68788305167278\n", 185 | "[13 53 53]\tFalse\t-1.4720320535925133\n", 186 | "[13 53 17]\tFalse\t-1.503669670852281\n", 187 | "[13 53 49]\tFalse\t-2.2055456936009072\n", 188 | "[33 13 53]\tFalse\t-1.868266739269855\n", 189 | "[33 13 17]\tFalse\t-1.9480947954999472\n", 190 | "[33 13 16]\tFalse\t-2.0038997227337942\n", 191 | "batch num= 1\n", 192 | "[30 8 7]\tFalse\t-0.981799161726694\n", 193 | "[30 8 50]\tFalse\t-1.2210058993765245\n", 194 | "[30 8 33]\tFalse\t-1.6862021730157182\n", 195 | "[30 38 7]\tFalse\t-1.0155674726535289\n", 196 | "[30 38 50]\tFalse\t-2.4361530449230084\n", 197 | "[30 38 56]\tFalse\t-2.4468657026464666\n", 198 | "[ 1 46 43]\tFalse\t-1.730098315254036\n", 199 | "[ 1 46 23]\tFalse\t-1.910274720612243\n", 200 | "[ 1 46 64]\tFalse\t-1.9826010216276544\n", 201 | "Pruning ...\n", 202 | "batch num= 0\n", 203 | "[13 53 53]\tFalse\t-1.4720320535925133\n", 204 | "[48 7 45]\tFalse\t-1.4907133631014375\n", 205 | "[13 53 17]\tFalse\t-1.503669670852281\n", 206 | "batch num= 1\n", 207 | "[30 8 7]\tFalse\t-0.981799161726694\n", 208 | "[30 38 7]\tFalse\t-1.0155674726535289\n", 209 | "[30 8 50]\tFalse\t-1.2210058993765245\n", 210 | "========== timesteps= 3 ==========\n", 211 | "Expansion...\n", 212 | "batch num= 0\n", 213 | "[13 53 53 41]\tFalse\t-1.5506704012318269\n", 214 | "[13 53 53 25]\tFalse\t-1.561838485804779\n", 215 | "[13 53 53 98]\tFalse\t-1.583598742175906\n", 216 | "[48 7 45 80]\tFalse\t-1.435019828867431\n", 217 | "[48 
7 45 76]\tFalse\t-1.44588109619326\n", 218 | "[48 7 45 1]\tFalse\t-1.7402676634547847\n", 219 | "[13 53 17 41]\tFalse\t-1.585756130334546\n", 220 | "[13 53 17 80]\tFalse\t-1.5910758874599538\n", 221 | "[13 53 17 25]\tFalse\t-1.6346056800788564\n", 222 | "batch num= 1\n", 223 | "[30 8 7 43]\tFalse\t-1.0350845938886057\n", 224 | "[30 8 7 69]\tFalse\t-1.487864532678162\n", 225 | "[30 8 7 14]\tFalse\t-1.4976892653945955\n", 226 | "[30 38 7 46]\tFalse\t-0.8816524694126934\n", 227 | "[30 38 7 45]\tFalse\t-1.361845258654082\n", 228 | "[30 38 7 50]\tFalse\t-1.3850981222623515\n", 229 | "[30 8 50 43]\tFalse\t-1.474636461781719\n", 230 | "[30 8 50 11]\tFalse\t-1.6120821984876703\n", 231 | "[30 8 50 14]\tFalse\t-1.6380255433340407\n", 232 | "Pruning ...\n", 233 | "batch num= 0\n", 234 | "[48 7 45 80]\tFalse\t-1.435019828867431\n", 235 | "[48 7 45 76]\tFalse\t-1.44588109619326\n", 236 | "[13 53 53 41]\tFalse\t-1.5506704012318269\n", 237 | "batch num= 1\n", 238 | "[30 38 7 46]\tFalse\t-0.8816524694126934\n", 239 | "[30 8 7 43]\tFalse\t-1.0350845938886057\n", 240 | "[30 38 7 45]\tFalse\t-1.361845258654082\n", 241 | "========== timesteps= 4 ==========\n", 242 | "Expansion...\n", 243 | "batch num= 0\n", 244 | "[48 7 45 80 1]\tFalse\t-1.4155848839072929\n", 245 | "[48 7 45 80 63]\tFalse\t-1.4809062680622915\n", 246 | "[48 7 45 80 94]\tFalse\t-1.732865605759868\n", 247 | "[48 7 45 76 67]\tFalse\t-1.581737929223578\n", 248 | "[48 7 45 76 92]\tFalse\t-1.5835056363202569\n", 249 | "[48 7 45 76 13]\tFalse\t-1.7110238710735544\n", 250 | "[13 53 53 41 62]\tFalse\t-1.417947578093456\n", 251 | "[13 53 53 41 66]\tFalse\t-1.5861790359814807\n", 252 | "[13 53 53 41 22]\tFalse\t-1.8198367051022502\n", 253 | "batch num= 1\n", 254 | "[30 38 7 46 15]\tFalse\t-0.8900240329071243\n", 255 | "[30 38 7 46 76]\tFalse\t-1.0029480860843567\n", 256 | "[30 38 7 46 40]\tFalse\t-1.2061234517379482\n", 257 | "[30 8 7 43 40]\tFalse\t-1.1621896355270631\n", 258 | "[30 8 7 43 74]\tFalse\t-1.1756940719777458\n", 
259 | "[30 8 7 43 76]\tFalse\t-1.2723109905838466\n", 260 | "[30 38 7 45 76]\tFalse\t-1.2788018546331241\n", 261 | "[30 38 7 45 15]\tFalse\t-1.4225672540053953\n", 262 | "[30 38 7 45 40]\tFalse\t-1.5114030374152683\n", 263 | "Pruning ...\n", 264 | "batch num= 0\n", 265 | "[48 7 45 80 1]\tFalse\t-1.4155848839072929\n", 266 | "[13 53 53 41 62]\tFalse\t-1.417947578093456\n", 267 | "[48 7 45 80 63]\tFalse\t-1.4809062680622915\n", 268 | "batch num= 1\n", 269 | "[30 38 7 46 15]\tFalse\t-0.8900240329071243\n", 270 | "[30 38 7 46 76]\tFalse\t-1.0029480860843567\n", 271 | "[30 8 7 43 40]\tFalse\t-1.1621896355270631\n", 272 | "========== timesteps= 5 ==========\n", 273 | "Expansion...\n", 274 | "batch num= 0\n", 275 | "[48 7 45 80 1 7]\tFalse\t-1.4794142075137813\n", 276 | "[48 7 45 80 1 41]\tFalse\t-1.4933938451514475\n", 277 | "[48 7 45 80 1 71]\tFalse\t-1.5103248587788836\n", 278 | "[13 53 53 41 62 76]\tFalse\t-1.3396238306377644\n", 279 | "[13 53 53 41 62 89]\tFalse\t-1.4990384205012208\n", 280 | "[13 53 53 41 62 65]\tFalse\t-1.5266813295887929\n", 281 | "[48 7 45 80 63 38]\tFalse\t-1.5578383956446686\n", 282 | "[48 7 45 80 63 71]\tFalse\t-1.5693606690867572\n", 283 | "[48 7 45 80 63 67]\tFalse\t-1.5881997528395313\n", 284 | "batch num= 1\n", 285 | "[30 38 7 46 15 1]\tFalse\t-1.0160237052072278\n", 286 | "[30 38 7 46 15 0]\tTrue\t-0.8900240329071243\n", 287 | "[30 38 7 46 15 95]\tFalse\t-1.1659301187826023\n", 288 | "[30 38 7 46 76 79]\tFalse\t-1.2324967288788231\n", 289 | "[30 38 7 46 76 92]\tFalse\t-1.242909393188552\n", 290 | "[30 38 7 46 76 75]\tFalse\t-1.3027834896100112\n", 291 | "[30 8 7 43 40 43]\tFalse\t-1.2051650221374317\n", 292 | "[30 8 7 43 40 14]\tFalse\t-1.2439917222130705\n", 293 | "[30 8 7 43 40 22]\tFalse\t-1.323827336732793\n", 294 | "Pruning ...\n", 295 | "batch num= 0\n", 296 | "[13 53 53 41 62 76]\tFalse\t-1.3396238306377644\n", 297 | "[48 7 45 80 1 7]\tFalse\t-1.4794142075137813\n", 298 | "[48 7 45 80 1 41]\tFalse\t-1.4933938451514475\n", 299 | 
"batch num= 1\n", 300 | "[30 38 7 46 15 0]\tTrue\t-0.8900240329071243\n", 301 | "[30 38 7 46 15 1]\tFalse\t-1.0160237052072278\n", 302 | "[30 38 7 46 15 95]\tFalse\t-1.1659301187826023\n", 303 | "========== timesteps= 6 ==========\n", 304 | "Expansion...\n", 305 | "batch num= 0\n", 306 | "[13 53 53 41 62 76 39]\tFalse\t-1.3496708801476172\n", 307 | "[13 53 53 41 62 76 8]\tFalse\t-1.4023681944692268\n", 308 | "[13 53 53 41 62 76 23]\tFalse\t-1.4492251873156643\n", 309 | "[48 7 45 80 1 7 82]\tFalse\t-1.5279394757013531\n", 310 | "[48 7 45 80 1 7 45]\tFalse\t-1.5425022838809215\n", 311 | "[48 7 45 80 1 7 46]\tFalse\t-1.5619046177353486\n", 312 | "[48 7 45 80 1 41 82]\tFalse\t-1.430451510587661\n", 313 | "[48 7 45 80 1 41 68]\tFalse\t-1.6258046549927043\n", 314 | "[48 7 45 80 1 41 23]\tFalse\t-1.6287973043777444\n", 315 | "batch num= 1\n", 316 | "[30 38 7 46 15 0 67]\tTrue\t-0.8900240329071243\n", 317 | "[30 38 7 46 15 0 20]\tTrue\t-0.8900240329071243\n", 318 | "[30 38 7 46 15 0 57]\tTrue\t-0.8900240329071243\n", 319 | "[30 38 7 46 15 1 11]\tFalse\t-1.0339379651820644\n", 320 | "[30 38 7 46 15 1 76]\tFalse\t-1.1711319401110587\n", 321 | "[30 38 7 46 15 1 41]\tFalse\t-1.2156144563727591\n", 322 | "[30 38 7 46 15 95 20]\tFalse\t-1.2206017964230695\n", 323 | "[30 38 7 46 15 95 11]\tFalse\t-1.2973599838246521\n", 324 | "[30 38 7 46 15 95 39]\tFalse\t-1.347143453699934\n", 325 | "Pruning ...\n", 326 | "batch num= 0\n", 327 | "[13 53 53 41 62 76 39]\tFalse\t-1.3496708801476172\n", 328 | "[13 53 53 41 62 76 8]\tFalse\t-1.4023681944692268\n", 329 | "[48 7 45 80 1 41 82]\tFalse\t-1.430451510587661\n", 330 | "batch num= 1\n", 331 | "[30 38 7 46 15 0 57]\tTrue\t-0.8900240329071243\n", 332 | "[30 38 7 46 15 0 20]\tTrue\t-0.8900240329071243\n", 333 | "[30 38 7 46 15 0 67]\tTrue\t-0.8900240329071243\n", 334 | "========== timesteps= 7 ==========\n", 335 | "Expansion...\n", 336 | "batch num= 0\n", 337 | "[13 53 53 41 62 76 39 78]\tFalse\t-1.4009193941899798\n", 338 | "[13 53 53 41 62 
76 39 38]\tFalse\t-1.4184446787212193\n", 339 | "[13 53 53 41 62 76 39 72]\tFalse\t-1.4859004214067744\n", 340 | "[13 53 53 41 62 76 8 20]\tFalse\t-1.4317120180607932\n", 341 | "[13 53 53 41 62 76 8 79]\tFalse\t-1.4555651089534953\n", 342 | "[13 53 53 41 62 76 8 88]\tFalse\t-1.5136947745882172\n", 343 | "[48 7 45 80 1 41 82 15]\tFalse\t-1.4773084124370452\n", 344 | "[48 7 45 80 1 41 82 84]\tFalse\t-1.5052793912158204\n", 345 | "[48 7 45 80 1 41 82 30]\tFalse\t-1.5128248163630005\n", 346 | "batch num= 1\n", 347 | "[30 38 7 46 15 0 57 8]\tTrue\t-0.8900240329071243\n", 348 | "[30 38 7 46 15 0 57 22]\tTrue\t-0.8900240329071243\n", 349 | "[30 38 7 46 15 0 57 23]\tTrue\t-0.8900240329071243\n", 350 | "[30 38 7 46 15 0 20 8]\tTrue\t-0.8900240329071243\n", 351 | "[30 38 7 46 15 0 20 22]\tTrue\t-0.8900240329071243\n", 352 | "[30 38 7 46 15 0 20 62]\tTrue\t-0.8900240329071243\n", 353 | "[30 38 7 46 15 0 67 82]\tTrue\t-0.8900240329071243\n", 354 | "[30 38 7 46 15 0 67 56]\tTrue\t-0.8900240329071243\n", 355 | "[30 38 7 46 15 0 67 8]\tTrue\t-0.8900240329071243\n", 356 | "Pruning ...\n", 357 | "batch num= 0\n", 358 | "[13 53 53 41 62 76 39 78]\tFalse\t-1.4009193941899798\n", 359 | "[13 53 53 41 62 76 39 38]\tFalse\t-1.4184446787212193\n", 360 | "[13 53 53 41 62 76 8 20]\tFalse\t-1.4317120180607932\n", 361 | "batch num= 1\n", 362 | "[30 38 7 46 15 0 67 8]\tTrue\t-0.8900240329071243\n", 363 | "[30 38 7 46 15 0 67 56]\tTrue\t-0.8900240329071243\n", 364 | "[30 38 7 46 15 0 67 82]\tTrue\t-0.8900240329071243\n", 365 | "========== timesteps= 8 ==========\n", 366 | "Expansion...\n", 367 | "batch num= 0\n", 368 | "[13 53 53 41 62 76 39 78 7]\tFalse\t-1.4068256900320346\n", 369 | "[13 53 53 41 62 76 39 78 49]\tFalse\t-1.4296420137168409\n", 370 | "[13 53 53 41 62 76 39 78 45]\tFalse\t-1.4415205750559386\n", 371 | "[13 53 53 41 62 76 39 38 7]\tFalse\t-1.3760183790382272\n", 372 | "[13 53 53 41 62 76 39 38 45]\tFalse\t-1.5019493281169982\n", 373 | "[13 53 53 41 62 76 39 38 
50]\tFalse\t-1.511051096358533\n", 374 | "[13 53 53 41 62 76 8 20 79]\tFalse\t-1.4413910156791327\n", 375 | "[13 53 53 41 62 76 8 20 62]\tFalse\t-1.5197912451454039\n", 376 | "[13 53 53 41 62 76 8 20 66]\tFalse\t-1.522338865864498\n", 377 | "batch num= 1\n", 378 | "[30 38 7 46 15 0 67 8 76]\tTrue\t-0.8900240329071243\n", 379 | "[30 38 7 46 15 0 67 8 15]\tTrue\t-0.8900240329071243\n", 380 | "[30 38 7 46 15 0 67 8 40]\tTrue\t-0.8900240329071243\n", 381 | "[30 38 7 46 15 0 67 56 56]\tTrue\t-0.8900240329071243\n", 382 | "[30 38 7 46 15 0 67 56 87]\tTrue\t-0.8900240329071243\n", 383 | "[30 38 7 46 15 0 67 56 69]\tTrue\t-0.8900240329071243\n", 384 | "[30 38 7 46 15 0 67 82 7]\tTrue\t-0.8900240329071243\n", 385 | "[30 38 7 46 15 0 67 82 56]\tTrue\t-0.8900240329071243\n", 386 | "[30 38 7 46 15 0 67 82 82]\tTrue\t-0.8900240329071243\n", 387 | "Pruning ...\n", 388 | "batch num= 0\n", 389 | "[13 53 53 41 62 76 39 38 7]\tFalse\t-1.3760183790382272\n", 390 | "[13 53 53 41 62 76 39 78 7]\tFalse\t-1.4068256900320346\n", 391 | "[13 53 53 41 62 76 39 78 49]\tFalse\t-1.4296420137168409\n", 392 | "batch num= 1\n", 393 | "[30 38 7 46 15 0 67 82 82]\tTrue\t-0.8900240329071243\n", 394 | "[30 38 7 46 15 0 67 82 56]\tTrue\t-0.8900240329071243\n", 395 | "[30 38 7 46 15 0 67 82 7]\tTrue\t-0.8900240329071243\n", 396 | "========== timesteps= 9 ==========\n", 397 | "Expansion...\n", 398 | "batch num= 0\n", 399 | "[13 53 53 41 62 76 39 38 7 46]\tFalse\t-1.2724431112719832\n", 400 | "[13 53 53 41 62 76 39 38 7 50]\tFalse\t-1.472136698673575\n", 401 | "[13 53 53 41 62 76 39 38 7 45]\tFalse\t-1.5468339153451782\n", 402 | "[13 53 53 41 62 76 39 78 7 46]\tFalse\t-1.3125356502306427\n", 403 | "[13 53 53 41 62 76 39 78 7 45]\tFalse\t-1.513518005083196\n", 404 | "[13 53 53 41 62 76 39 78 7 49]\tFalse\t-1.5177294947988433\n", 405 | "[13 53 53 41 62 76 39 78 49 49]\tFalse\t-1.4052002087742437\n", 406 | "[13 53 53 41 62 76 39 78 49 45]\tFalse\t-1.4535441308184314\n", 407 | "[13 53 53 41 62 76 39 78 49 
46]\tFalse\t-1.5499813329569079\n", 408 | "batch num= 1\n", 409 | "[30 38 7 46 15 0 67 82 82 49]\tTrue\t-0.8900240329071243\n", 410 | "[30 38 7 46 15 0 67 82 82 15]\tTrue\t-0.8900240329071243\n", 411 | "[30 38 7 46 15 0 67 82 82 45]\tTrue\t-0.8900240329071243\n", 412 | "[30 38 7 46 15 0 67 82 56 15]\tTrue\t-0.8900240329071243\n", 413 | "[30 38 7 46 15 0 67 82 56 46]\tTrue\t-0.8900240329071243\n", 414 | "[30 38 7 46 15 0 67 82 56 49]\tTrue\t-0.8900240329071243\n", 415 | "[30 38 7 46 15 0 67 82 7 46]\tTrue\t-0.8900240329071243\n", 416 | "[30 38 7 46 15 0 67 82 7 45]\tTrue\t-0.8900240329071243\n", 417 | "[30 38 7 46 15 0 67 82 7 49]\tTrue\t-0.8900240329071243\n", 418 | "Pruning ...\n", 419 | "batch num= 0\n", 420 | "[13 53 53 41 62 76 39 38 7 46]\tFalse\t-1.2724431112719832\n", 421 | "batch num= 1\n", 422 | "[30 38 7 46 15 0 67 82 7 49]\tTrue\t-0.8900240329071243\n" 423 | ] 424 | } 425 | ], 426 | "source": [ 427 | "for t in range(T): \n", 428 | " def _get_preds_and_probs(PREDS):\n", 429 | " probs = decode(y=PREDS) # (N, V)\n", 430 | " preds_k = np.argsort(probs)[:, ::-1][:, :K].flatten() # (K*N,)\n", 431 | " probs_k = np.sort(probs)[:, ::-1][:, :K].flatten() # (K*N,)\n", 432 | " return preds_k, probs_k\n", 433 | " \n", 434 | " def logging(PREDS_k, EOS_k, PROBS_k):\n", 435 | " for i, (PREDS_k_batch, EOS_k_batch, PROBS_k_batch) in \\\n", 436 | " enumerate(zip(np.split(PREDS_k, N), np.split(EOS_k, N), np.split(PROBS_k, N) )):\n", 437 | " print(\"batch num=\", i)\n", 438 | " for each_PREDS_k_batch, each_EOS_k_batch, each_PROBS_k_batch in zip(PREDS_k_batch, EOS_k_batch, PROBS_k_batch):\n", 439 | " print(\"{}\\t{}\\t{}\".format(each_PREDS_k_batch, each_EOS_k_batch, each_PROBS_k_batch))\n", 440 | " \n", 441 | " if t==0: # initial step\n", 442 | " print(\"=\"*10, \"timesteps=\", t, \"=\"*10)\n", 443 | " \n", 444 | " preds_k, probs_k = _get_preds_and_probs(initial_y) # (k*N), (k*N)\n", 445 | " PREDS_k = np.expand_dims(preds_k, -1) # PREDS_k: Final outputs, (k*N, 1)\n", 446 | " 
PROBS_k = probs_k\n", 447 | " EOS_k = preds_k==EOS_ID # (K*N,)\n", 448 | " \n", 449 | " # logging\n", 450 | " logging(PREDS_k, EOS_k, PROBS_k)\n", 451 | " \n", 452 | " else:\n", 453 | " print(\"=\"*10, \"timesteps=\", t, \"=\"*10)\n", 454 | " print(\"Expansion...\")\n", 455 | " \n", 456 | " preds_kk, probs_kk = _get_preds_and_probs(PREDS_k) # (K*K*N), (K*K*N) <- incremental(=local) values\n", 457 | " \n", 458 | " # preds for expanded beams (each of the K*N beams is repeated K times, matching the top-K in _get_preds_and_probs)\n", 459 | " PREDS_kk = np.repeat(PREDS_k, K, axis=0) # (K*K*N, t)\n", 460 | " PREDS_kk = np.append(PREDS_kk, np.expand_dims(preds_kk, -1), -1) # PREDS_kk: (K*K*N, t+1)\n", 461 | " \n", 462 | " # eos for expanded beams\n", 463 | " eos_kk = preds_kk==EOS_ID # (K*K*N) <- local\n", 464 | " EOS_kk = np.repeat(EOS_k, K, axis=0) # (K*K*N, )\n", 465 | " EOS_kk = np.logical_or(EOS_kk, eos_kk) # (K*K*N,)\n", 466 | " \n", 467 | " # probs for expanded beams (finished beams keep their score, so their K copies tie)\n", 468 | " PROBS_kk = np.repeat(PROBS_k, K, axis=0) # (K*K*N, )\n", 469 | " normalized_probs = ( PROBS_kk*t + probs_kk ) / (t+1)\n", 470 | " PROBS_kk = np.where(EOS_kk, PROBS_kk, normalized_probs) # (K*K*N, )\n", 471 | " \n", 472 | " # logging\n", 473 | " logging(PREDS_kk, EOS_kk, PROBS_kk)\n", 474 | " \n", 475 | " print(\"Pruning ...\")\n", 476 | " winners = [] # (K*N). 
K elements are selected out of K^2\n", 477 | " for j, prob_kk in enumerate(np.split(PROBS_kk, N)): # (K*K,) \n", 478 | " if t == T-1: # final step\n", 479 | " winner = np.argsort(prob_kk)[::-1][:1] # final 1 best\n", 480 | " winners.extend(list(winner + j*len(prob_kk)))\n", 481 | " else:\n", 482 | " winner = np.argsort(prob_kk)[::-1][:K]\n", 483 | " winners.extend(list(winner + j*len(prob_kk)))\n", 484 | " \n", 485 | " PREDS_k = PREDS_kk[winners] # (N, t+1) if final step, otherwise (K*N, t+1)\n", 486 | " PROBS_k = PROBS_kk[winners] # (N,) if final step, otherwise (K*N,)\n", 487 | " EOS_k = EOS_kk[winners]\n", 488 | " \n", 489 | " # logging\n", 490 | " logging(PREDS_k, EOS_k, PROBS_k)" 491 | ] 492 | }, 493 | { 494 | "cell_type": "markdown", 495 | "metadata": {}, 496 | "source": [ 497 | "Be aware that the tokens following `<EOS>` (id `EOS_ID`) are still present in the output above; they are meaningless and should be stripped." 498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "execution_count": null, 503 | "metadata": {}, 504 | "outputs": [], 505 | "source": [] 506 | } 507 | ], 508 | "metadata": { 509 | "kernelspec": { 510 | "display_name": "Python 3", 511 | "language": "python", 512 | "name": "python3" 513 | }, 514 | "language_info": { 515 | "codemirror_mode": { 516 | "name": "ipython", 517 | "version": 3 518 | }, 519 | "file_extension": ".py", 520 | "mimetype": "text/x-python", 521 | "name": "python", 522 | "nbconvert_exporter": "python", 523 | "pygments_lexer": "ipython3", 524 | "version": "3.7.1" 525 | } 526 | }, 527 | "nbformat": 4, 528 | "nbformat_minor": 2 529 | } 530 | -------------------------------------------------------------------------------- /Data Augmentation for Quora Question Pairs.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Let's see if it's effective to augment training data in the task of [quora question pairs](https://www.kaggle.com/c/quora-question-pairs)."
8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "### Download and extract QQP dataset." 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 17, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import os" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 25, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "assert os.system('wget -O QQP.zip \"https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5\"') == 0, 'download failed'\n", 33 | "assert os.system('unzip -o QQP.zip') == 0, 'unzip failed'" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 26, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stdout", 43 | "output_type": "stream", 44 | "text": [ 45 | "total 60949\r\n", 46 | "-rw-r--r-- 1 root root 5815716 May 2 2018 dev.tsv\r\n", 47 | "-rw-r--r-- 1 root root 52360463 May 2 2018 train.tsv\r\n", 48 | "drwxr-xr-x 1 root root 0 Aug 5 10:31 original\r\n", 49 | "-rw-r--r-- 1 root root 4259840 Aug 5 10:32 test.tsv\r\n" 50 | ] 51 | } 52 | ], 53 | "source": [ 54 | "!ls -ltr QQP" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "Let's check what the training data looks like." 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 3, 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "name": "stdout", 71 | "output_type": "stream", 72 | "text": [ 73 | "id\tqid1\tqid2\tquestion1\tquestion2\tis_duplicate\n", 74 | "133273\t213221\t213222\tHow is the life of a math student?
Could you describe your own experiences?\tWhich level of prepration is enough for the exam jlpt5?\t0\n" 75 | ] 76 | } 77 | ], 78 | "source": [ 79 | "train_data = \"QQP/train.tsv\"\n", 80 | "dev_data = \"QQP/dev.tsv\"\n", 81 | "print(\"\\n\".join(open(train_data, 'r').read().splitlines()[:2]))" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 4, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "import numpy as np\n", 91 | "import torch\n", 92 | "from torch import nn\n", 93 | "from torch.autograd import Variable\n", 94 | "import torch.nn.functional as F\n", 95 | "import torch.utils.data as Data\n", 96 | "import torch.optim as optim\n", 97 | "from torch.nn.utils.rnn import pack_padded_sequence\n", 98 | "from sklearn.metrics import f1_score, accuracy_score\n", 99 | "import random\n", 100 | "import copy\n", 101 | "from collections import Counter\n", 102 | "import re" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "metadata": {}, 108 | "source": [ 109 | "## Prepare datasets" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "### train" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 5, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "def normalize(sent):\n", 126 | " sent = sent.lower()\n", 127 | " sent = re.sub(\"[^a-z0-9' ]\", \"\", sent)\n", 128 | " return sent" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 6, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "def split_data(fin, maxlen):\n", 138 | " '''Divide data into things of label 0's and 1's'''\n", 139 | " data0, data1 = [], []\n", 140 | " for line in open(fin, 'r').read().strip().splitlines()[1:]:\n", 141 | " cols = line.split(\"\\t\")\n", 142 | " if len(cols)==6:\n", 143 | " _, _, _, sent1, sent2, label = cols\n", 144 | " sent1 = normalize(sent1)\n", 145 | " sent2 = normalize(sent2)\n", 146 | " if 
len(sent1.split()) < maxlen/2 and len(sent2.split()) < maxlen/2:\n", 147 | " pair = (sent1, sent2)\n", 148 | " if label==\"0\":\n", 149 | " data0.append(pair)\n", 150 | " else:\n", 151 | " data1.append(pair) \n", 152 | " return data0, data1" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 7, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "MAXLEN = 200 # Max encoded pair length in token ids; split_data keeps pairs whose sentences each have fewer than MAXLEN/2 words." 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 8, 167 | "metadata": { 168 | "scrolled": false 169 | }, 170 | "outputs": [ 171 | { 172 | "name": "stdout", 173 | "output_type": "stream", 174 | "text": [ 175 | "229442 134378\n" 176 | ] 177 | } 178 | ], 179 | "source": [ 180 | "train0, train1 = split_data(train_data, MAXLEN) \n", 181 | "print(len(train0), len(train1))" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 9, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "# all training sents\n", 191 | "train01 = []\n", 192 | "for t in (train0, train1):\n", 193 | " for sent1, sent2 in t:\n", 194 | " train01.append(sent1)\n", 195 | " train01.append(sent2)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "### dev" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 10, 208 | "metadata": {}, 209 | "outputs": [ 210 | { 211 | "name": "stdout", 212 | "output_type": "stream", 213 | "text": [ 214 | "25544 14885\n" 215 | ] 216 | } 217 | ], 218 | "source": [ 219 | "dev0, dev1 = split_data(dev_data, MAXLEN) \n", 220 | "print(len(dev0), len(dev1))" 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "metadata": {}, 226 | "source": [ 227 | "## Vocabulary" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 11, 233 | "metadata": {}, 234 | "outputs": [ 235 | { 236 | "data": { 
237 | "text/plain": [ 238 | "107030" 239 | ] 240 |
}, 241 | "execution_count": 11, 242 | "metadata": {}, 243 | "output_type": "execute_result" 244 | } 245 | ], 246 | "source": [ 247 | "# num_vocab\n", 248 | "words = [word for sent in train01 for word in sent.split()]\n", 249 | "word2cnt = Counter(words)\n", 250 | "len(word2cnt)" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 12, 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "MIN_CNT = 5 # We include words that occurred at least 5 times.\n", 260 | "vocab = [\"<pad>\", \"<unk>\", \"<sep>\"] # reserved ids 0/1/2 (distinct keys so token2idx/idx2token stay one-to-one)\n", 261 | "for word, cnt in word2cnt.most_common():\n", 262 | " if cnt < MIN_CNT:\n", 263 | " break\n", 264 | " vocab.append(word)" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 13, 270 | "metadata": {}, 271 | "outputs": [ 272 | { 273 | "data": { 274 | "text/plain": [ 275 | "30429" 276 | ] 277 | }, 278 | "execution_count": 13, 279 | "metadata": {}, 280 | "output_type": "execute_result" 281 | } 282 | ], 283 | "source": [ 284 | "VOCAB_SIZE = len(vocab)\n", 285 | "VOCAB_SIZE" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": 14, 291 | "metadata": {}, 292 | "outputs": [], 293 | "source": [ 294 | "token2idx = {token:idx for idx, token in enumerate(vocab)}\n", 295 | "idx2token = {idx:token for idx, token in enumerate(vocab)}" 296 | ] 297 | }, 298 | { 299 | "cell_type": "markdown", 300 | "metadata": {}, 301 | "source": [ 302 | "## Encode" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 15, 308 | "metadata": {}, 309 | "outputs": [], 310 | "source": [ 311 | "def encode_sents(sent1, sent2):\n", 312 | " tokens1 = [token2idx.get(token, 1) for token in sent1.split()] # 1: <unk>\n", 313 | " tokens2 = [token2idx.get(token, 1) for token in sent2.split()]\n", 314 | " \n", 315 | " tokens = tokens1 + [2] + tokens2 + [0]*MAXLEN # tokens1 <sep> tokens2 <pad> 
...\n", 316 | " tokens = tokens[:MAXLEN]\n", 317 | " return tokens#" 318 | ] 319 | }, 320 | { 321 | "cell_type": "markdown", 322 | "metadata": {}, 323 | "source": [ 324 | "### \\#1. baseline" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": 16, 330 | "metadata": { 331 | "scrolled": false 332 | }, 333 | "outputs": [], 334 | "source": [ 335 | "_X_train0 = [] # list of lists\n", 336 | "for sent1, sent2 in train0:\n", 337 | " tokens = encode_sents(sent1, sent2)\n", 338 | " _X_train0.append(tokens)\n", 339 | "\n", 340 | "_X_train1 = []\n", 341 | "for sent1, sent2 in train1:\n", 342 | " tokens = encode_sents(sent1, sent2)\n", 343 | " _X_train1.append(tokens)\n", 344 | "\n", 345 | "_X_train = _X_train0 + _X_train1\n", 346 | "_Y_train = [0]*len(_X_train0) + [1]*len(_X_train1)" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": 17, 352 | "metadata": {}, 353 | "outputs": [ 354 | { 355 | "name": "stdout", 356 | "output_type": "stream", 357 | "text": [ 358 | "229442 134378 363820\n" 359 | ] 360 | } 361 | ], 362 | "source": [ 363 | "print(len(_X_train0), len(_X_train1), len(_X_train0)+len(_X_train1))" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": 18, 369 | "metadata": {}, 370 | "outputs": [], 371 | "source": [ 372 | "NUM_EPOCHS = 10\n", 373 | "_X_train *= NUM_EPOCHS\n", 374 | "_Y_train *= NUM_EPOCHS" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": 19, 380 | "metadata": {}, 381 | "outputs": [ 382 | { 383 | "name": "stdout", 384 | "output_type": "stream", 385 | "text": [ 386 | "3638200 3638200\n" 387 | ] 388 | } 389 | ], 390 | "source": [ 391 | "print(len(_X_train), len(_Y_train))" 392 | ] 393 | }, 394 | { 395 | "cell_type": "markdown", 396 | "metadata": {}, 397 | "source": [ 398 | "### \\#2. label0 aug." 
399 | ] 400 | }, 401 | { 402 | "cell_type": "markdown", 403 | "metadata": {}, 404 | "source": [ 405 | "The non-duplicate pairs (train0) are augmented by pairing one of their two sentences with a randomly chosen training sentence." 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": 20, 411 | "metadata": {}, 412 | "outputs": [], 413 | "source": [ 414 | "_X_train0_aug = copy.copy(_X_train0)\n", 415 | "\n", 416 | "for sent1, sent2 in train0*(NUM_EPOCHS-1): # original pairs + (NUM_EPOCHS-1) augmented copies, same size as the baseline\n", 417 | " sent = sent1 if random.random() < 0.5 else sent2\n", 418 | " tokens = encode_sents(sent, random.choice(train01))\n", 419 | " _X_train0_aug.append(tokens)\n", 420 | "\n", 421 | "_X_train1_rep = _X_train1 * NUM_EPOCHS # new list, so re-running this cell does not keep growing _X_train1\n", 422 | "_X_train_aug = _X_train0_aug + _X_train1_rep\n", 423 | "_Y_train_aug = [0]*len(_X_train0_aug) + [1]*len(_X_train1_rep)" 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": 21, 429 | "metadata": {}, 430 | "outputs": [ 431 | { 432 | "name": "stdout", 433 | "output_type": "stream", 434 | "text": [ 435 | "3638200 3638200\n" 436 | ] 437 | } 438 | ], 439 | "source": [ 440 | "print(len(_X_train_aug), len(_Y_train_aug))" 441 | ] 442 | }, 443 | { 444 | "cell_type": "markdown", 445 | "metadata": {}, 446 | "source": [ 447 | "### dev" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": 22, 453 | "metadata": {}, 454 | "outputs": [], 455 | "source": [ 456 | "_X_dev0 = [] # list of lists\n", 457 | "_X_dev1 = []\n", 458 | "for sent1, sent2 in dev0:\n", 459 | " tokens = encode_sents(sent1, sent2)\n", 460 | " _X_dev0.append(tokens)\n", 461 | "for sent1, sent2 in dev1:\n", 462 | " tokens = encode_sents(sent1, sent2)\n", 463 | " _X_dev1.append(tokens)\n", 464 | "\n", 465 | "_X_dev = _X_dev0 + _X_dev1\n", 466 | "_Y_dev = [0]*len(_X_dev0) + [1]*len(_X_dev1)" 467 | ] 468 | }, 469 | { 470 | "cell_type": 
"markdown", 471 | "metadata": {}, 472 | "source": [ 473 | "### Convert to tensors" 474 | ] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "execution_count": 23, 479 |
"metadata": {}, 480 | "outputs": [], 481 | "source": [ 482 | "X_train = torch.LongTensor(_X_train)\n", 483 | "Y_train = torch.LongTensor(_Y_train)\n", 484 | "\n", 485 | "X_train_aug = torch.LongTensor(_X_train_aug)\n", 486 | "Y_train_aug = torch.LongTensor(_Y_train_aug)\n", 487 | "\n", 488 | "X_dev = torch.LongTensor(_X_dev)\n", 489 | "Y_dev = torch.LongTensor(_Y_dev)\n" 490 | ] 491 | }, 492 | { 493 | "cell_type": "markdown", 494 | "metadata": {}, 495 | "source": [ 496 | "## Data Loader" 497 | ] 498 | }, 499 | { 500 | "cell_type": "code", 501 | "execution_count": 24, 502 | "metadata": {}, 503 | "outputs": [], 504 | "source": [ 505 | "BATCH_SIZE=256" 506 | ] 507 | }, 508 | { 509 | "cell_type": "markdown", 510 | "metadata": {}, 511 | "source": [ 512 | "### \\#1. baseline" 513 | ] 514 | }, 515 | { 516 | "cell_type": "code", 517 | "execution_count": 25, 518 | "metadata": {}, 519 | "outputs": [ 520 | { 521 | "name": "stdout", 522 | "output_type": "stream", 523 | "text": [ 524 | "14212\n" 525 | ] 526 | } 527 | ], 528 | "source": [ 529 | "train_dataset = Data.TensorDataset(X_train, Y_train)\n", 530 | "train_loader = Data.DataLoader(dataset=train_dataset,\n", 531 | " batch_size=BATCH_SIZE,\n", 532 | " shuffle=True,\n", 533 | " num_workers=4)\n", 534 | "print(len(train_loader))" 535 | ] 536 | }, 537 | { 538 | "cell_type": "markdown", 539 | "metadata": {}, 540 | "source": [ 541 | "### \\#2. label0 aug." 
542 | ] 543 | }, 544 | { 545 | "cell_type": "code", 546 | "execution_count": 26, 547 | "metadata": {}, 548 | "outputs": [ 549 | { 550 | "name": "stdout", 551 | "output_type": "stream", 552 | "text": [ 553 | "14212\n" 554 | ] 555 | } 556 | ], 557 | "source": [ 558 | "train_aug_dataset = Data.TensorDataset(X_train_aug, Y_train_aug)\n", 559 | "train_aug_loader = Data.DataLoader(dataset=train_aug_dataset,\n", 560 | " batch_size=BATCH_SIZE,\n", 561 | " shuffle=True,\n", 562 | " num_workers=4)\n", 563 | "print(len(train_aug_loader))" 564 | ] 565 | }, 566 | { 567 | "cell_type": "markdown", 568 | "metadata": {}, 569 | "source": [ 570 | "### dev" 571 | ] 572 | }, 573 | { 574 | "cell_type": "code", 575 | "execution_count": 27, 576 | "metadata": {}, 577 | "outputs": [ 578 | { 579 | "name": "stdout", 580 | "output_type": "stream", 581 | "text": [ 582 | "158\n" 583 | ] 584 | } 585 | ], 586 | "source": [ 587 | "dev_dataset = Data.TensorDataset(X_dev, Y_dev)\n", 588 | "dev_loader = Data.DataLoader(dataset=dev_dataset,\n", 589 | " batch_size=BATCH_SIZE,\n", 590 | " shuffle=False,\n", 591 | " num_workers=4)\n", 592 | "print(len(dev_loader))" 593 | ] 594 | }, 595 | { 596 | "cell_type": "markdown", 597 | "metadata": {}, 598 | "source": [ 599 | "## Model" 600 | ] 601 | }, 602 | { 603 | "cell_type": "code", 604 | "execution_count": 65, 605 | "metadata": {}, 606 | "outputs": [], 607 | "source": [ 608 | "class Net(nn.Module):\n", 609 | " def __init__(self, embedding_dim=256, hidden_dim=256, vocab_size=VOCAB_SIZE):\n", 610 | " '''\n", 611 | " Fix the model architecture and its parameters for this purpose\n", 612 | " '''\n", 613 | " super(Net, self).__init__()\n", 614 | " \n", 615 | " self.embed = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)\n", 616 | " self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)\n", 617 | " self.dense = nn.Linear(hidden_dim*2, 2)\n", 618 | "\n", 619 | " def forward(self, x):\n", 620 | " x = x.to('cuda')\n", 621 | " 
seqlens = (x!=0).long().sum(1) # (N,)\n", 622 | " \n", 623 | " x = self.embed(x) \n", 624 | " \n", 625 | " packed_input = pack_padded_sequence(x, seqlens, batch_first=True, enforce_sorted=False)\n", 626 | " \n", 627 | " _, (last_hidden, c) = self.lstm(packed_input) # last_hidden: (num_layers * num_directions, batch, hidden_size)\n", 628 | " last_hidden = last_hidden.permute(1, 2, 0) # to (batch, hidden, num_directions)\n", 629 | " last_hidden = last_hidden.contiguous().view(last_hidden.size()[0], -1) # to (batch, hidden*num_directions)\n", 630 | " \n", 631 | " logits = self.dense(last_hidden)\n", 632 | " return logits\n" 633 | ] 634 | }, 635 | { 636 | "cell_type": "markdown", 637 | "metadata": {}, 638 | "source": [ 639 | "## Train & test functions" 640 | ] 641 | }, 642 | { 643 | "cell_type": "code", 644 | "execution_count": 66, 645 | "metadata": {}, 646 | "outputs": [], 647 | "source": [ 648 | "def eval(model, dev_loader):\n", 649 | " model.eval()\n", 650 | "\n", 651 | " y_pred, y_true = [], []\n", 652 | " with torch.no_grad():\n", 653 | " for inputs, targets in dev_loader:\n", 654 | " logits = model(inputs)\n", 655 | " _, preds = logits.max(1, keepdim=False)\n", 656 | " y_pred.extend(preds.tolist())\n", 657 | " y_true.extend(targets.tolist()) \n", 658 | " \n", 659 | " f1score = f1_score(y_true, y_pred)\n", 660 | " acc = accuracy_score(y_true, y_pred)\n", 661 | " \n", 662 | " print('F1_score: %0.3f, acc.: %0.3f\\n' %(f1score, acc))" 663 | ] 664 | }, 665 | { 666 | "cell_type": "code", 667 | "execution_count": 67, 668 | "metadata": {}, 669 | "outputs": [], 670 | "source": [ 671 | "def train(model, train_loader, optimizer, criterion, eval_interval, dev_loader):\n", 672 | " model.train()\n", 673 | " for gs, (inputs, targets) in enumerate(train_loader):\n", 674 | " optimizer.zero_grad()\n", 675 | " logits = model(inputs)\n", 676 | " targets = targets.to('cuda')\n", 677 | " loss = criterion(logits, targets)\n", 678 | " \n", 679 | " loss.backward()\n", 680 | " 
optimizer.step()\n", 681 | " \n", 682 | " if gs > 0 and gs % eval_interval == 0:\n", 683 | " print(\"global step =\", gs)\n", 684 | " print(\"loss =%.3f\" % loss )\n", 685 | " eval(model, dev_loader)\n", 686 | " model.train()" 687 | ] 688 | }, 689 | { 690 | "cell_type": "markdown", 691 | "metadata": {}, 692 | "source": [ 693 | "## Experiments" 694 | ] 695 | }, 696 | { 697 | "cell_type": "code", 698 | "execution_count": 74, 699 | "metadata": {}, 700 | "outputs": [], 701 | "source": [ 702 | "model = Net().cuda()\n", 703 | "optimizer = optim.Adam(model.parameters(), lr=.001)\n", 704 | "criterion = nn.CrossEntropyLoss()\n", 705 | "eval_interval = len(train_loader)//NUM_EPOCHS" 706 | ] 707 | }, 708 | { 709 | "cell_type": "markdown", 710 | "metadata": {}, 711 | "source": [ 712 | "### \\#1. baseline" 713 | ] 714 | }, 715 | { 716 | "cell_type": "code", 717 | "execution_count": 69, 718 | "metadata": {}, 719 | "outputs": [ 720 | { 721 | "name": "stdout", 722 | "output_type": "stream", 723 | "text": [ 724 | "global step = 1421\n", 725 | "loss =0.418\n", 726 | "F1_score: 0.675, acc.: 0.783\n", 727 | "\n", 728 | "global step = 2842\n", 729 | "loss =0.316\n", 730 | "F1_score: 0.729, acc.: 0.802\n", 731 | "\n", 732 | "global step = 4263\n", 733 | "loss =0.197\n", 734 | "F1_score: 0.740, acc.: 0.808\n", 735 | "\n", 736 | "global step = 5684\n", 737 | "loss =0.163\n", 738 | "F1_score: 0.746, acc.: 0.816\n", 739 | "\n", 740 | "global step = 7105\n", 741 | "loss =0.088\n", 742 | "F1_score: 0.744, acc.: 0.814\n", 743 | "\n", 744 | "global step = 8526\n", 745 | "loss =0.068\n", 746 | "F1_score: 0.745, acc.: 0.811\n", 747 | "\n", 748 | "global step = 9947\n", 749 | "loss =0.110\n", 750 | "F1_score: 0.746, acc.: 0.814\n", 751 | "\n", 752 | "global step = 11368\n", 753 | "loss =0.068\n", 754 | "F1_score: 0.748, acc.: 0.818\n", 755 | "\n", 756 | "global step = 12789\n", 757 | "loss =0.016\n", 758 | "F1_score: 0.746, acc.: 0.816\n", 759 | "\n", 760 | "global step = 14210\n", 761 | "loss 
=0.043\n", 762 | "F1_score: 0.745, acc.: 0.813\n", 763 | "\n" 764 | ] 765 | } 766 | ], 767 | "source": [ 768 | "train(model, train_loader, optimizer, criterion, eval_interval, dev_loader)" 769 | ] 770 | }, 771 | { 772 | "cell_type": "markdown", 773 | "metadata": {}, 774 | "source": [ 775 | "△ The best F1 score is .748, and accuracy is .818." 776 | ] 777 | }, 778 | { 779 | "cell_type": "markdown", 780 | "metadata": {}, 781 | "source": [ 782 | "### \\#2. aug." 783 | ] 784 | }, 785 | { 786 | "cell_type": "code", 787 | "execution_count": 75, 788 | "metadata": {}, 789 | "outputs": [ 790 | { 791 | "name": "stdout", 792 | "output_type": "stream", 793 | "text": [ 794 | "global step = 1421\n", 795 | "loss =0.298\n", 796 | "F1_score: 0.699, acc.: 0.727\n", 797 | "\n", 798 | "global step = 2842\n", 799 | "loss =0.193\n", 800 | "F1_score: 0.730, acc.: 0.766\n", 801 | "\n", 802 | "global step = 4263\n", 803 | "loss =0.134\n", 804 | "F1_score: 0.735, acc.: 0.765\n", 805 | "\n", 806 | "global step = 5684\n", 807 | "loss =0.122\n", 808 | "F1_score: 0.745, acc.: 0.776\n", 809 | "\n", 810 | "global step = 7105\n", 811 | "loss =0.110\n", 812 | "F1_score: 0.760, acc.: 0.799\n", 813 | "\n", 814 | "global step = 8526\n", 815 | "loss =0.120\n", 816 | "F1_score: 0.764, acc.: 0.801\n", 817 | "\n", 818 | "global step = 9947\n", 819 | "loss =0.081\n", 820 | "F1_score: 0.766, acc.: 0.804\n", 821 | "\n", 822 | "global step = 11368\n", 823 | "loss =0.086\n", 824 | "F1_score: 0.770, acc.: 0.809\n", 825 | "\n", 826 | "global step = 12789\n", 827 | "loss =0.053\n", 828 | "F1_score: 0.776, acc.: 0.820\n", 829 | "\n", 830 | "global step = 14210\n", 831 | "loss =0.080\n", 832 | "F1_score: 0.775, acc.: 0.821\n", 833 | "\n" 834 | ] 835 | } 836 | ], 837 | "source": [ 838 | "train(model, train_aug_loader, optimizer, criterion, eval_interval, dev_loader)" 839 | ] 840 | }, 841 | { 842 | "cell_type": "markdown", 843 | "metadata": {}, 844 | "source": [ 845 | "△ The best F1 score is .776, and accuracy is 
.820. Note: `train` does not re-initialize anything, so this experiment must start from a freshly created `model`/`optimizer` (re-run the model cell above first); otherwise it continues from the baseline's trained weights and Adam state, and the comparison with #1 is not fair." 846 | ] 847 | }, 848 | { 849 | "cell_type": "code", 850 | "execution_count": null, 851 | "metadata": {}, 852 | "outputs": [], 853 | "source": [] 854 | } 855 | ], 856 | "metadata": { 857 | "anaconda-cloud": {}, 858 | "kernelspec": { 859 | "display_name": "Python 3", 860 | "language": "python", 861 | "name": "python3" 862 | }, 863 | "language_info": { 864 | "codemirror_mode": { 865 | "name": "ipython", 866 | "version": 3 867 | }, 868 | "file_extension": ".py", 869 | "mimetype": "text/x-python", 870 | "name": "python", 871 | "nbconvert_exporter": "python", 872 | "pygments_lexer": "ipython3", 873 | "version": "3.6.7" 874 | } 875 | }, 876 | "nbformat": 4, 877 | "nbformat_minor": 2 878 | } 879 | -------------------------------------------------------------------------------- /Dropout in a minute.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Dropout is arguably the most popular regularization technique in deep learning. Let's take another look at how it works."
8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 2, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "__author__ = \"kyubyong\"\n", 17 | "__address__ = \"https://github.com/kyubyong/nlp_made_easy\"\n", 18 | "__email__ = \"kbpark.linguist@gmail.com\"" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 10, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "import numpy as np\n", 28 | "import tensorflow as tf" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 6, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "data": { 38 | "text/plain": [ 39 | "'1.5.0'" 40 | ] 41 | }, 42 | "execution_count": 6, 43 | "metadata": {}, 44 | "output_type": "execute_result" 45 | } 46 | ], 47 | "source": [ 48 | "tf.__version__" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 7, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "class Graph:\n", 58 | " def __init__(self, keep_prob=1.):\n", 59 | " # Inputs\n", 60 | " x = tf.expand_dims(tf.convert_to_tensor([1.], tf.float32), 1)\n", 61 | " y = tf.expand_dims(tf.convert_to_tensor([2.], tf.float32), 1)\n", 62 | "\n", 63 | " # Variables\n", 64 | " w1 = tf.Variable([[0.1, -0.1, 0.2]], dtype=tf.float32, name=\"weight1\")\n", 65 | "\n", 66 | " # fully connected layer (a.k.a. 
dense layer)\n", 67 | " h = tf.nn.relu(tf.matmul(x, w1))\n", 68 | " self.h = tf.nn.dropout(h, keep_prob=keep_prob)\n", 69 | "\n", 70 | " # Readout layer\n", 71 | " w2 = tf.Variable([[0.2], [0.1], [-0.1]], dtype=tf.float32, name=\"weight2\")\n", 72 | " self.pred = tf.matmul(self.h, w2)\n", 73 | "\n", 74 | " # Loss\n", 75 | " self.loss = tf.reduce_mean(tf.square(self.pred - y)) # L2 loss\n", 76 | "\n", 77 | " # Training scheme\n", 78 | " optimizer = tf.train.GradientDescentOptimizer(0.001)\n", 79 | " self.grads_and_vars = optimizer.compute_gradients(self.loss)\n", 80 | " self.train_op = optimizer.apply_gradients(self.grads_and_vars)\n", 81 | "\n" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 8, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "def run(keep_prob=1.):\n", 91 | " g = Graph(keep_prob=keep_prob)\n", 92 | " with tf.Session() as sess:\n", 93 | " sess.run(tf.global_variables_initializer())\n", 94 | " \n", 95 | " # feed-forward and back-prop for getting gradients\n", 96 | " loss, hidden_units, output, _, _grads_and_vars = sess.run([g.loss, g.h, g.pred, g.train_op, g.grads_and_vars])\n", 97 | " grad1 = _grads_and_vars[0][0]\n", 98 | " grad2 = _grads_and_vars[1][0]\n", 99 | " \n", 100 | " return loss, hidden_units, output, grad1, grad2\n" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 9, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "loss, hidden_units, output, grad1, grad2 = run()" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "### Results of no dropouts" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 18, 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "name": "stdout", 126 | "output_type": "stream", 127 | "text": [ 128 | "loss= 4.0\n", 129 | "hidden units= [[0.1 0. 0.2]]\n", 130 | "y_hat= [[0.]]\n", 131 | "grad1= [[-0.8 0. 0.4]]\n", 132 | "grad2= [[-0.4]\n", 133 | " [ 0. 
]\n", 134 | " [-0.8]]\n" 135 | ] 136 | } 137 | ], 138 | "source": [ 139 | "print(\"loss=\", loss)\n", 140 | "print(\"hidden units=\", hidden_units)\n", 141 | "print(\"y_hat=\", output)\n", 142 | "print(\"grad1=\", grad1)\n", 143 | "print(\"grad2=\", grad2)" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": {}, 149 | "source": [ 150 | "" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "Gradients flow back through the first and third units." 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "metadata": {}, 163 | "source": [ 164 | "### Results of dropouts (50:50 prob.)" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 24, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "tf.reset_default_graph()\n", 174 | "_loss, _hidden_units, _output, _grad1, _grad2 = run(keep_prob=.5)" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 25, 180 | "metadata": {}, 181 | "outputs": [ 182 | { 183 | "name": "stdout", 184 | "output_type": "stream", 185 | "text": [ 186 | "loss= 3.8416002\n", 187 | "hidden units= [[0.2 0. 0. ]]\n", 188 | "y_hat= [[0.04]]\n", 189 | "grad1= [[-1.57 0. 0. ]]\n", 190 | "grad2= [[-0.78]\n", 191 | " [ 0. ]\n", 192 | " [ 0. ]]\n" 193 | ] 194 | } 195 | ], 196 | "source": [ 197 | "print(\"loss=\", _loss)\n", 198 | "print(\"hidden units=\", _hidden_units)\n", 199 | "print(\"y_hat=\", _output)\n", 200 | "print(\"grad1=\", _grad1)\n", 201 | "print(\"grad2=\", _grad2)" 202 | ] 203 | }, 204 | { 205 | "cell_type": "markdown", 206 | "metadata": {}, 207 | "source": [ 208 | "" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "metadata": { 214 | "collapsed": true 215 | }, 216 | "source": [ 217 | "The third unit becomes zero, so the gradient flows back through only the first unit." 
218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": null, 223 | "metadata": {}, 224 | "outputs": [], 225 | "source": [] 226 | } 227 | ], 228 | "metadata": { 229 | "anaconda-cloud": {}, 230 | "kernelspec": { 231 | "display_name": "Python 3", 232 | "language": "python", 233 | "name": "python3" 234 | }, 235 | "language_info": { 236 | "codemirror_mode": { 237 | "name": "ipython", 238 | "version": 3 239 | }, 240 | "file_extension": ".py", 241 | "mimetype": "text/x-python", 242 | "name": "python", 243 | "nbconvert_exporter": "python", 244 | "pygments_lexer": "ipython3", 245 | "version": "3.7.1" 246 | } 247 | }, 248 | "nbformat": 4, 249 | "nbformat_minor": 1 250 | } 251 | -------------------------------------------------------------------------------- /How to get the last hidden vector of rnns properly.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "We'll see how to properly get the last hidden states of RNNs in TensorFlow and PyTorch."
8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "__author__ = \"kyubyong\"\n", 17 | "__address__ = \"https://github.com/kyubyong/nlp_made_easy\"\n", 18 | "__email__ = \"kbpark.linguist@gmail.com\"" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "import numpy as np\n", 28 | "import tensorflow as tf\n", 29 | "import torch\n", 30 | "from torch import nn\n", 31 | "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 3, 37 | "metadata": {}, 38 | "outputs": [ 39 | { 40 | "data": { 41 | "text/plain": [ 42 | "'1.14.0'" 43 | ] 44 | }, 45 | "execution_count": 3, 46 | "metadata": {}, 47 | "output_type": "execute_result" 48 | } 49 | ], 50 | "source": [ 51 | "tf.__version__" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 4, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "data": { 61 | "text/plain": [ 62 | "'1.2.0'" 63 | ] 64 | }, 65 | "execution_count": 4, 66 | "metadata": {}, 67 | "output_type": "execute_result" 68 | } 69 | ], 70 | "source": [ 71 | "torch.__version__" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "# Tensorflow" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "### Uni-directional" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 5, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "tf.reset_default_graph()" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 6, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "def rnn(x, bidirectional=False, seqlens=None, reuse=False):\n", 104 | " if not bidirectional:\n", 105 | " with tf.variable_scope(\"rnn\", reuse=reuse):\n", 106 | " cell = 
tf.contrib.rnn.GRUCell(1)\n", 107 | " outputs, last_hidden = tf.nn.dynamic_rnn(cell, x, sequence_length=seqlens, dtype=tf.float32)\n", 108 | " else: \n", 109 | " with tf.variable_scope(\"birnn\", reuse=reuse):\n", 110 | " cell = tf.contrib.rnn.GRUCell(1)\n", 111 | " cell_bw = tf.contrib.rnn.GRUCell(1)\n", 112 | " outputs, last_hidden = tf.nn.bidirectional_dynamic_rnn(cell, cell_bw, x, sequence_length=seqlens, dtype=tf.float32)\n", 113 | " outputs, last_hidden = tf.concat(outputs,-1), tf.concat(last_hidden, -1)\n", 114 | "\n", 115 | " return outputs, last_hidden" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 7, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "def onehot(arry, size):\n", 125 | " '''\n", 126 | " arry: 2-d array of n, t\n", 127 | " size: output dimensions\n", 128 | " \n", 129 | " returns\n", 130 | " 3-d array of (n, t, size)\n", 131 | " '''\n", 132 | " labels_one_hot = (arry.ravel()[np.newaxis] == np.arange(size)[:, np.newaxis]).T\n", 133 | " labels_one_hot.shape = arry.shape + (size,)\n", 134 | " return labels_one_hot.astype('float32')" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 8, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "_x1 = np.array([1, 2, 3], np.int32)\n", 144 | "_x1 = onehot(np.expand_dims(_x1, 0), 4)\n", 145 | "\n", 146 | "_x2 = np.array([1, 2, 3, 0], np.int32) # 0 means padding\n", 147 | "_x2 = onehot(np.expand_dims(_x2, 0), 4)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 9, 153 | "metadata": {}, 154 | "outputs": [ 155 | { 156 | "name": "stdout", 157 | "output_type": "stream", 158 | "text": [ 159 | "WARNING:tensorflow:\n", 160 | "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n", 161 | "For more information, please see:\n", 162 | " * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n", 163 | " * https://github.com/tensorflow/addons\n", 164 | " * 
https://github.com/tensorflow/io (for I/O related ops)\n", 165 | "If you depend on functionality not listed there, please file an issue.\n", 166 | "\n", 167 | "WARNING:tensorflow:From :4: GRUCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n", 168 | "Instructions for updating:\n", 169 | "This class is equivalent as tf.keras.layers.GRUCell, and will be replaced by that in Tensorflow 2.0.\n", 170 | "WARNING:tensorflow:From :5: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n", 171 | "Instructions for updating:\n", 172 | "Please use `keras.layers.RNN(cell)`, which is equivalent to this API\n", 173 | "WARNING:tensorflow:From C:\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\ops\\init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", 174 | "Instructions for updating:\n", 175 | "Call initializer instance with the dtype argument instead of passing it to the constructor\n", 176 | "WARNING:tensorflow:From C:\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\ops\\rnn_cell_impl.py:564: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", 177 | "Instructions for updating:\n", 178 | "Call initializer instance with the dtype argument instead of passing it to the constructor\n", 179 | "WARNING:tensorflow:From C:\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\ops\\rnn_cell_impl.py:574: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", 180 | "Instructions for updating:\n", 181 | "Call initializer instance with the dtype argument instead of passing it to the constructor\n", 182 | "WARNING:tensorflow:Entity > could not be transformed and will be executed as-is. 
Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting >: AssertionError: Bad argument number for Name: 3, expecting 4\n", 183 | "WARNING: Entity > could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting >: AssertionError: Bad argument number for Name: 3, expecting 4\n" 184 | ] 185 | } 186 | ], 187 | "source": [ 188 | "# 1. no padding\n", 189 | "x1 = tf.convert_to_tensor(_x1)\n", 190 | "outputs1, last_hidden1 = rnn(x1)" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 10, 196 | "metadata": {}, 197 | "outputs": [ 198 | { 199 | "name": "stdout", 200 | "output_type": "stream", 201 | "text": [ 202 | "WARNING:tensorflow:Entity > could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting >: AssertionError: Bad argument number for Name: 3, expecting 4\n", 203 | "WARNING: Entity > could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting >: AssertionError: Bad argument number for Name: 3, expecting 4\n" 204 | ] 205 | } 206 | ], 207 | "source": [ 208 | "# 2. zero padding, no seqlens\n", 209 | "x2 = tf.convert_to_tensor(_x2)\n", 210 | "outputs2, last_hidden2 = rnn(x2, reuse=True) # We want to sync the variables up to compare the results." 
211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 11, 216 | "metadata": {}, 217 | "outputs": [ 218 | { 219 | "name": "stdout", 220 | "output_type": "stream", 221 | "text": [ 222 | "WARNING:tensorflow:Entity > could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting >: AssertionError: Bad argument number for Name: 3, expecting 4\n", 223 | "WARNING: Entity > could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting >: AssertionError: Bad argument number for Name: 3, expecting 4\n", 224 | "WARNING:tensorflow:From C:\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\ops\\rnn.py:244: add_dispatch_support..wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n", 225 | "Instructions for updating:\n", 226 | "Use tf.where in 2.0, which has the same broadcast rule as np.where\n" 227 | ] 228 | } 229 | ], 230 | "source": [ 231 | "# 3. zero padding with explicit seqlens\n", 232 | "outputs3, last_hidden3 = rnn(x2, seqlens=[3,], reuse=True) # Real sequence length is 3 as the last 0 is a padding." 
233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 12, 238 | "metadata": {}, 239 | "outputs": [], 240 | "source": [ 241 | "# Session\n", 242 | "init_op = tf.global_variables_initializer()\n", 243 | "sess = tf.InteractiveSession()\n", 244 | "sess.run(init_op)" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 13, 250 | "metadata": {}, 251 | "outputs": [ 252 | { 253 | "data": { 254 | "text/plain": [ 255 | "array([[[-0.31044602],\n", 256 | " [-0.1496363 ],\n", 257 | " [ 0.0102744 ]]], dtype=float32)" 258 | ] 259 | }, 260 | "execution_count": 13, 261 | "metadata": {}, 262 | "output_type": "execute_result" 263 | } 264 | ], 265 | "source": [ 266 | "outputs1.eval()" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 14, 272 | "metadata": {}, 273 | "outputs": [ 274 | { 275 | "data": { 276 | "text/plain": [ 277 | "array([[[-0.31044602],\n", 278 | " [-0.1496363 ],\n", 279 | " [ 0.0102744 ],\n", 280 | " [-0.2648457 ]]], dtype=float32)" 281 | ] 282 | }, 283 | "execution_count": 14, 284 | "metadata": {}, 285 | "output_type": "execute_result" 286 | } 287 | ], 288 | "source": [ 289 | "outputs2.eval() # the padded last step still has a non-zero output. This is not what we usually want." 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 15, 295 | "metadata": {}, 296 | "outputs": [ 297 | { 298 | "data": { 299 | "text/plain": [ 300 | "array([[[-0.31044602],\n", 301 | " [-0.1496363 ],\n", 302 | " [ 0.0102744 ],\n", 303 | " [ 0. ]]], dtype=float32)" 304 | ] 305 | }, 306 | "execution_count": 15, 307 | "metadata": {}, 308 | "output_type": "execute_result" 309 | } 310 | ], 311 | "source": [ 312 | "outputs3.eval() # the last step is masked to zeros. This is usually correct."
313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": 16, 318 | "metadata": {}, 319 | "outputs": [ 320 | { 321 | "data": { 322 | "text/plain": [ 323 | "array([[0.0102744]], dtype=float32)" 324 | ] 325 | }, 326 | "execution_count": 16, 327 | "metadata": {}, 328 | "output_type": "execute_result" 329 | } 330 | ], 331 | "source": [ 332 | "last_hidden1.eval()" 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": 17, 338 | "metadata": {}, 339 | "outputs": [ 340 | { 341 | "data": { 342 | "text/plain": [ 343 | "array([[-0.2648457]], dtype=float32)" 344 | ] 345 | }, 346 | "execution_count": 17, 347 | "metadata": {}, 348 | "output_type": "execute_result" 349 | } 350 | ], 351 | "source": [ 352 | "last_hidden2.eval()" 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "execution_count": 18, 358 | "metadata": {}, 359 | "outputs": [ 360 | { 361 | "data": { 362 | "text/plain": [ 363 | "array([[0.0102744]], dtype=float32)" 364 | ] 365 | }, 366 | "execution_count": 18, 367 | "metadata": {}, 368 | "output_type": "execute_result" 369 | } 370 | ], 371 | "source": [ 372 | "last_hidden3.eval() # Now we have the same results as # 1." 373 | ] 374 | }, 375 | { 376 | "cell_type": "markdown", 377 | "metadata": {}, 378 | "source": [ 379 | "△ Comment: Padding is mostly added to build mini-batches from multiple samples of variable lengths. Therefore we typically want the same results as when each sample is processed individually, without padding. To that end, whenever you pad, you should also pass `seqlens`; the padded steps are then masked to zeros."
380 | ] 381 | }, 382 | { 383 | "cell_type": "markdown", 384 | "metadata": {}, 385 | "source": [ 386 | "### Bi-directional" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 19, 392 | "metadata": {}, 393 | "outputs": [], 394 | "source": [ 395 | "tf.reset_default_graph()" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": 20, 401 | "metadata": {}, 402 | "outputs": [ 403 | { 404 | "name": "stdout", 405 | "output_type": "stream", 406 | "text": [ 407 | "WARNING:tensorflow:From :10: bidirectional_dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n", 408 | "Instructions for updating:\n", 409 | "Please use `keras.layers.Bidirectional(keras.layers.RNN(cell))`, which is equivalent to this API\n", 410 | "WARNING:tensorflow:Entity > could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting >: AssertionError: Bad argument number for Name: 3, expecting 4\n", 411 | "WARNING: Entity > could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting >: AssertionError: Bad argument number for Name: 3, expecting 4\n", 412 | "WARNING:tensorflow:Entity > could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting >: AssertionError: Bad argument number for Name: 3, expecting 4\n", 413 | "WARNING: Entity > could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. 
When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting >: AssertionError: Bad argument number for Name: 3, expecting 4\n" 414 | ] 415 | } 416 | ], 417 | "source": [ 418 | "# 1. no padding\n", 419 | "x1 = tf.convert_to_tensor(_x1)\n", 420 | "outputs1, last_hidden1 = rnn(x1, bidirectional=True)" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": 21, 426 | "metadata": {}, 427 | "outputs": [ 428 | { 429 | "name": "stdout", 430 | "output_type": "stream", 431 | "text": [ 432 | "WARNING:tensorflow:Entity > could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting >: AssertionError: Bad argument number for Name: 3, expecting 4\n", 433 | "WARNING: Entity > could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting >: AssertionError: Bad argument number for Name: 3, expecting 4\n", 434 | "WARNING:tensorflow:Entity > could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting >: AssertionError: Bad argument number for Name: 3, expecting 4\n", 435 | "WARNING: Entity > could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting >: AssertionError: Bad argument number for Name: 3, expecting 4\n" 436 | ] 437 | } 438 | ], 439 | "source": [ 440 | "# 2. 
zero padding, no seqlens\n", 441 | "x2 = tf.convert_to_tensor(_x2)\n", 442 | "outputs2, last_hidden2 = rnn(x2, bidirectional=True, reuse=True)" 443 | ] 444 | }, 445 | { 446 | "cell_type": "code", 447 | "execution_count": 22, 448 | "metadata": {}, 449 | "outputs": [ 450 | { 451 | "name": "stdout", 452 | "output_type": "stream", 453 | "text": [ 454 | "WARNING:tensorflow:Entity > could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting >: AssertionError: Bad argument number for Name: 3, expecting 4\n", 455 | "WARNING: Entity > could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting >: AssertionError: Bad argument number for Name: 3, expecting 4\n", 456 | "WARNING:tensorflow:Entity > could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting >: AssertionError: Bad argument number for Name: 3, expecting 4\n", 457 | "WARNING: Entity > could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting >: AssertionError: Bad argument number for Name: 3, expecting 4\n" 458 | ] 459 | } 460 | ], 461 | "source": [ 462 | "# 3. 
zero padding with explicit seqlens\n", 463 | "outputs3, last_hidden3 = rnn(x2, bidirectional=True, seqlens=[3,], reuse=True)" 464 | ] 465 | }, 466 | { 467 | "cell_type": "code", 468 | "execution_count": 23, 469 | "metadata": {}, 470 | "outputs": [ 471 | { 472 | "name": "stderr", 473 | "output_type": "stream", 474 | "text": [ 475 | "C:\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\client\\session.py:1735: UserWarning: An interactive session is already active. This can cause out-of-memory errors in some cases. You must explicitly call `InteractiveSession.close()` to release resources held by the other session(s).\n", 476 | " warnings.warn('An interactive session is already active. This can '\n" 477 | ] 478 | } 479 | ], 480 | "source": [ 481 | "# Session\n", 482 | "init_op = tf.global_variables_initializer()\n", 483 | "sess = tf.InteractiveSession()\n", 484 | "sess.run(init_op)" 485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": 24, 490 | "metadata": {}, 491 | "outputs": [ 492 | { 493 | "data": { 494 | "text/plain": [ 495 | "array([[[ 0.03057988, -0.36812052],\n", 496 | " [ 0.11028282, -0.25675917],\n", 497 | " [ 0.03307126, 0.01558547]]], dtype=float32)" 498 | ] 499 | }, 500 | "execution_count": 24, 501 | "metadata": {}, 502 | "output_type": "execute_result" 503 | } 504 | ], 505 | "source": [ 506 | "outputs1.eval()" 507 | ] 508 | }, 509 | { 510 | "cell_type": "code", 511 | "execution_count": 25, 512 | "metadata": {}, 513 | "outputs": [ 514 | { 515 | "data": { 516 | "text/plain": [ 517 | "array([[[ 0.03057988, -0.36047477],\n", 518 | " [ 0.11028282, -0.24538502],\n", 519 | " [ 0.03307126, 0.03696197],\n", 520 | " [ 0.1797748 , 0.02438327]]], dtype=float32)" 521 | ] 522 | }, 523 | "execution_count": 25, 524 | "metadata": {}, 525 | "output_type": "execute_result" 526 | } 527 | ], 528 | "source": [ 529 | "outputs2.eval() # Again, this is not we want." 
530 | ] 531 | }, 532 | { 533 | "cell_type": "code", 534 | "execution_count": 26, 535 | "metadata": {}, 536 | "outputs": [ 537 | { 538 | "data": { 539 | "text/plain": [ 540 | "array([[[ 0.03057988, -0.36812052],\n", 541 | " [ 0.11028282, -0.25675917],\n", 542 | " [ 0.03307126, 0.01558547],\n", 543 | " [ 0. , 0. ]]], dtype=float32)" 544 | ] 545 | }, 546 | "execution_count": 26, 547 | "metadata": {}, 548 | "output_type": "execute_result" 549 | } 550 | ], 551 | "source": [ 552 | "outputs3.eval() # Again, note that the last step is masked to zeros." 553 | ] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "execution_count": 27, 558 | "metadata": {}, 559 | "outputs": [ 560 | { 561 | "data": { 562 | "text/plain": [ 563 | "array([[ 0.03307126, -0.36812052]], dtype=float32)" 564 | ] 565 | }, 566 | "execution_count": 27, 567 | "metadata": {}, 568 | "output_type": "execute_result" 569 | } 570 | ], 571 | "source": [ 572 | "last_hidden1.eval()" 573 | ] 574 | }, 575 | { 576 | "cell_type": "code", 577 | "execution_count": 28, 578 | "metadata": {}, 579 | "outputs": [ 580 | { 581 | "data": { 582 | "text/plain": [ 583 | "array([[ 0.1797748 , -0.36047477]], dtype=float32)" 584 | ] 585 | }, 586 | "execution_count": 28, 587 | "metadata": {}, 588 | "output_type": "execute_result" 589 | } 590 | ], 591 | "source": [ 592 | "last_hidden2.eval()" 593 | ] 594 | }, 595 | { 596 | "cell_type": "code", 597 | "execution_count": 29, 598 | "metadata": {}, 599 | "outputs": [ 600 | { 601 | "data": { 602 | "text/plain": [ 603 | "array([[ 0.03307126, -0.36812052]], dtype=float32)" 604 | ] 605 | }, 606 | "execution_count": 29, 607 | "metadata": {}, 608 | "output_type": "execute_result" 609 | } 610 | ], 611 | "source": [ 612 | "last_hidden3.eval()" 613 | ] 614 | }, 615 | { 616 | "cell_type": "markdown", 617 | "metadata": {}, 618 | "source": [ 619 | "△ Note that in bidirectional rnns, the last_hidden state of the forward rnn (=0.03307126) is from the rightmost step in the sequence, while the last hidden 
state of the backward rnn (=-0.36812052) is from the leftmost step." 620 | ] 621 | }, 622 | { 623 | "cell_type": "markdown", 624 | "metadata": {}, 625 | "source": [ 626 | "# PyTorch" 627 | ] 628 | }, 629 | { 630 | "cell_type": "markdown", 631 | "metadata": {}, 632 | "source": [ 633 | "## Uni-directional" 634 | ] 635 | }, 636 | { 637 | "cell_type": "code", 638 | "execution_count": 30, 639 | "metadata": {}, 640 | "outputs": [], 641 | "source": [ 642 | "class Rnn(torch.nn.Module):\n", 643 | " def __init__(self, bidirectional=False):\n", 644 | " super().__init__()\n", 645 | " self.rnn = nn.GRU(4, 1, batch_first=True, bidirectional=bidirectional)\n", 646 | " \n", 647 | " def forward(self, x, seqlens=None):\n", 648 | " if seqlens is not None:\n", 649 | " # packing -> rnn -> unpacking -> position recovery\n", 650 | " packed_input = pack_padded_sequence(x, seqlens, batch_first=True, enforce_sorted=False) \n", 651 | " outputs, last_hidden = self.rnn(packed_input)\n", 652 | " outputs, _ = pad_packed_sequence(outputs, batch_first=True, total_length=x.size()[1])\n", 653 | " else:\n", 654 | " outputs, last_hidden = self.rnn(x)\n", 655 | " last_hidden = last_hidden.permute(1, 2, 0) # to (batch, hidden, num_directions)\n", 656 | " last_hidden = last_hidden.view(last_hidden.size()[0], -1) # to (batch, hidden*num_directions)\n", 657 | " \n", 658 | " return outputs, last_hidden" 659 | ] 660 | }, 661 | { 662 | "cell_type": "code", 663 | "execution_count": 31, 664 | "metadata": {}, 665 | "outputs": [], 666 | "source": [ 667 | "# 1. no padding\n", 668 | "x1 = torch.from_numpy(_x1)\n", 669 | "model1 = Rnn()\n", 670 | "outputs1, last_hidden1 = model1(x1)" 671 | ] 672 | }, 673 | { 674 | "cell_type": "code", 675 | "execution_count": 32, 676 | "metadata": {}, 677 | "outputs": [], 678 | "source": [ 679 | "# 2. 
zero padding, no seqlens\n", 680 | "x2 = torch.from_numpy(_x2)\n", 681 | "model2 = Rnn()\n", 682 | "for p1, p2 in zip(model1.parameters(), model2.parameters()): # sync up the variables\n", 683 | " p2.data = p1.data\n", 684 | "outputs2, last_hidden2 = model2(x2)" 685 | ] 686 | }, 687 | { 688 | "cell_type": "code", 689 | "execution_count": 33, 690 | "metadata": {}, 691 | "outputs": [], 692 | "source": [ 693 | "# 3. zero padding with explicit seqlens\n", 694 | "model3 = Rnn()\n", 695 | "for p1, p3 in zip(model1.parameters(), model3.parameters()):\n", 696 | " p3.data = p1.data\n", 697 | "outputs3, last_hidden3 = model3(x2, seqlens=[3,])" 698 | ] 699 | }, 700 | { 701 | "cell_type": "code", 702 | "execution_count": 34, 703 | "metadata": {}, 704 | "outputs": [ 705 | { 706 | "data": { 707 | "text/plain": [ 708 | "tensor([[[0.0022],\n", 709 | " [0.1187],\n", 710 | " [0.2030]]], grad_fn=)" 711 | ] 712 | }, 713 | "execution_count": 34, 714 | "metadata": {}, 715 | "output_type": "execute_result" 716 | } 717 | ], 718 | "source": [ 719 | "outputs1" 720 | ] 721 | }, 722 | { 723 | "cell_type": "code", 724 | "execution_count": 35, 725 | "metadata": {}, 726 | "outputs": [ 727 | { 728 | "data": { 729 | "text/plain": [ 730 | "tensor([[[ 0.0022],\n", 731 | " [ 0.1187],\n", 732 | " [ 0.2030],\n", 733 | " [-0.0466]]], grad_fn=)" 734 | ] 735 | }, 736 | "execution_count": 35, 737 | "metadata": {}, 738 | "output_type": "execute_result" 739 | } 740 | ], 741 | "source": [ 742 | "outputs2" 743 | ] 744 | }, 745 | { 746 | "cell_type": "code", 747 | "execution_count": 36, 748 | "metadata": {}, 749 | "outputs": [ 750 | { 751 | "data": { 752 | "text/plain": [ 753 | "tensor([[[0.0022],\n", 754 | " [0.1187],\n", 755 | " [0.2030],\n", 756 | " [0.0000]]], grad_fn=)" 757 | ] 758 | }, 759 | "execution_count": 36, 760 | "metadata": {}, 761 | "output_type": "execute_result" 762 | } 763 | ], 764 | "source": [ 765 | "outputs3" 766 | ] 767 | }, 768 | { 769 | "cell_type": "code", 770 | "execution_count": 37, 
771 | "metadata": {}, 772 | "outputs": [ 773 | { 774 | "data": { 775 | "text/plain": [ 776 | "tensor([[0.2030]], grad_fn=)" 777 | ] 778 | }, 779 | "execution_count": 37, 780 | "metadata": {}, 781 | "output_type": "execute_result" 782 | } 783 | ], 784 | "source": [ 785 | "last_hidden1" 786 | ] 787 | }, 788 | { 789 | "cell_type": "code", 790 | "execution_count": 38, 791 | "metadata": {}, 792 | "outputs": [ 793 | { 794 | "data": { 795 | "text/plain": [ 796 | "tensor([[-0.0466]], grad_fn=)" 797 | ] 798 | }, 799 | "execution_count": 38, 800 | "metadata": {}, 801 | "output_type": "execute_result" 802 | } 803 | ], 804 | "source": [ 805 | "last_hidden2" 806 | ] 807 | }, 808 | { 809 | "cell_type": "code", 810 | "execution_count": 39, 811 | "metadata": {}, 812 | "outputs": [ 813 | { 814 | "data": { 815 | "text/plain": [ 816 | "tensor([[0.2030]], grad_fn=)" 817 | ] 818 | }, 819 | "execution_count": 39, 820 | "metadata": {}, 821 | "output_type": "execute_result" 822 | } 823 | ], 824 | "source": [ 825 | "last_hidden3" 826 | ] 827 | }, 828 | { 829 | "cell_type": "markdown", 830 | "metadata": {}, 831 | "source": [ 832 | "△ PyTorch RNNs have no `seqlens` argument, so the trick is to pack the padded batch with `pack_padded_sequence` before running the RNN and to restore the padded layout with `pad_packed_sequence` afterwards. " 833 | ] 834 | }, 835 | { 836 | "cell_type": "markdown", 837 | "metadata": {}, 838 | "source": [ 839 | "## Bi-directional" 840 | ] 841 | }, 842 | { 843 | "cell_type": "code", 844 | "execution_count": 40, 845 | "metadata": {}, 846 | "outputs": [], 847 | "source": [ 848 | "# 1. no padding\n", 849 | "model1 = Rnn(bidirectional=True)\n", 850 | "outputs1, last_hidden1 = model1(x1)" 851 | ] 852 | }, 853 | { 854 | "cell_type": "code", 855 | "execution_count": 41, 856 | "metadata": {}, 857 | "outputs": [], 858 | "source": [ 859 | "# 2. 
zero padding without seqlens\n", 860 | "model2 = Rnn(bidirectional=True)\n", 861 | "for p1, p2 in zip(model1.parameters(), model2.parameters()):\n", 862 | " p2.data = p1.data\n", 863 | "outputs2, last_hidden2 = model2(x2)" 864 | ] 865 | }, 866 | { 867 | "cell_type": "code", 868 | "execution_count": 42, 869 | "metadata": {}, 870 | "outputs": [], 871 | "source": [ 872 | "# 3. zero padding with explicit seqlens\n", 873 | "model3 = Rnn(bidirectional=True)\n", 874 | "for p1, p3 in zip(model1.parameters(), model3.parameters()):\n", 875 | " p3.data = p1.data\n", 876 | "outputs3, last_hidden3 = model3(x2, seqlens=[3,])" 877 | ] 878 | }, 879 | { 880 | "cell_type": "code", 881 | "execution_count": 43, 882 | "metadata": {}, 883 | "outputs": [ 884 | { 885 | "data": { 886 | "text/plain": [ 887 | "tensor([[[ 0.1541, 0.1542],\n", 888 | " [ 0.1018, 0.1953],\n", 889 | " [ 0.3250, -0.2467]]], grad_fn=)" 890 | ] 891 | }, 892 | "execution_count": 43, 893 | "metadata": {}, 894 | "output_type": "execute_result" 895 | } 896 | ], 897 | "source": [ 898 | "outputs1" 899 | ] 900 | }, 901 | { 902 | "cell_type": "code", 903 | "execution_count": 44, 904 | "metadata": {}, 905 | "outputs": [ 906 | { 907 | "data": { 908 | "text/plain": [ 909 | "tensor([[[ 0.1541, 0.1687],\n", 910 | " [ 0.1018, 0.2188],\n", 911 | " [ 0.3250, -0.1759],\n", 912 | " [ 0.2979, 0.0983]]], grad_fn=)" 913 | ] 914 | }, 915 | "execution_count": 44, 916 | "metadata": {}, 917 | "output_type": "execute_result" 918 | } 919 | ], 920 | "source": [ 921 | "outputs2" 922 | ] 923 | }, 924 | { 925 | "cell_type": "code", 926 | "execution_count": 45, 927 | "metadata": {}, 928 | "outputs": [ 929 | { 930 | "data": { 931 | "text/plain": [ 932 | "tensor([[[ 0.1541, 0.1542],\n", 933 | " [ 0.1018, 0.1953],\n", 934 | " [ 0.3250, -0.2467],\n", 935 | " [ 0.0000, 0.0000]]], grad_fn=)" 936 | ] 937 | }, 938 | "execution_count": 45, 939 | "metadata": {}, 940 | "output_type": "execute_result" 941 | } 942 | ], 943 | "source": [ 944 | "outputs3" 945 | 
] 946 | }, 947 | { 948 | "cell_type": "code", 949 | "execution_count": 46, 950 | "metadata": {}, 951 | "outputs": [ 952 | { 953 | "data": { 954 | "text/plain": [ 955 | "tensor([[0.3250, 0.1542]], grad_fn=)" 956 | ] 957 | }, 958 | "execution_count": 46, 959 | "metadata": {}, 960 | "output_type": "execute_result" 961 | } 962 | ], 963 | "source": [ 964 | "last_hidden1" 965 | ] 966 | }, 967 | { 968 | "cell_type": "code", 969 | "execution_count": 47, 970 | "metadata": {}, 971 | "outputs": [ 972 | { 973 | "data": { 974 | "text/plain": [ 975 | "tensor([[0.2979, 0.1687]], grad_fn=)" 976 | ] 977 | }, 978 | "execution_count": 47, 979 | "metadata": {}, 980 | "output_type": "execute_result" 981 | } 982 | ], 983 | "source": [ 984 | "last_hidden2" 985 | ] 986 | }, 987 | { 988 | "cell_type": "code", 989 | "execution_count": 48, 990 | "metadata": {}, 991 | "outputs": [ 992 | { 993 | "data": { 994 | "text/plain": [ 995 | "tensor([[0.3250, 0.1542]], grad_fn=)" 996 | ] 997 | }, 998 | "execution_count": 48, 999 | "metadata": {}, 1000 | "output_type": "execute_result" 1001 | } 1002 | ], 1003 | "source": [ 1004 | "last_hidden3" 1005 | ] 1006 | }, 1007 | { 1008 | "cell_type": "markdown", 1009 | "metadata": {}, 1010 | "source": [ 1011 | "△ Same here." 
1012 | ] 1013 | } 1014 | ], 1015 | "metadata": { 1016 | "kernelspec": { 1017 | "display_name": "Python 3", 1018 | "language": "python", 1019 | "name": "python3" 1020 | }, 1021 | "language_info": { 1022 | "codemirror_mode": { 1023 | "name": "ipython", 1024 | "version": 3 1025 | }, 1026 | "file_extension": ".py", 1027 | "mimetype": "text/x-python", 1028 | "name": "python", 1029 | "nbconvert_exporter": "python", 1030 | "pygments_lexer": "ipython3", 1031 | "version": "3.7.3" 1032 | } 1033 | }, 1034 | "nbformat": 4, 1035 | "nbformat_minor": 2 1036 | } 1037 | -------------------------------------------------------------------------------- /Pos-tagging with Bert Fine-tuning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "[BERT](https://arxiv.org/abs/1810.04805) is known to be good at Sequence tagging tasks like Named Entity Recognition. Let's see if it's true for POS-tagging." 
8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "__author__ = \"kyubyong\"\n", 17 | "__address__ = \"https://github.com/kyubyong/nlp_made_easy\"\n", 18 | "__email__ = \"kbpark.linguist@gmail.com\"" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "import os\n", 28 | "from tqdm import tqdm_notebook as tqdm\n", 29 | "import numpy as np\n", 30 | "import torch\n", 31 | "import torch.nn as nn\n", 32 | "from torch.utils import data\n", 33 | "import torch.optim as optim\n", 34 | "from pytorch_pretrained_bert import BertTokenizer" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 3, 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "data": { 44 | "text/plain": [ 45 | "'1.0.0'" 46 | ] 47 | }, 48 | "execution_count": 3, 49 | "metadata": {}, 50 | "output_type": "execute_result" 51 | } 52 | ], 53 | "source": [ 54 | "torch.__version__" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "# Data preparation" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "Thanks to NLTK, we don't have to worry about datasets: it ships with a portion of the Penn Treebank, which should serve our purpose." 
69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 4, 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "data": { 78 | "text/plain": [ 79 | "3914" 80 | ] 81 | }, 82 | "execution_count": 4, 83 | "metadata": {}, 84 | "output_type": "execute_result" 85 | } 86 | ], 87 | "source": [ 88 | "import nltk\n", 89 | "tagged_sents = nltk.corpus.treebank.tagged_sents()\n", 90 | "len(tagged_sents)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 5, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "data": { 100 | "text/plain": [ 101 | "[('Pierre', 'NNP'),\n", 102 | " ('Vinken', 'NNP'),\n", 103 | " (',', ','),\n", 104 | " ('61', 'CD'),\n", 105 | " ('years', 'NNS'),\n", 106 | " ('old', 'JJ'),\n", 107 | " (',', ','),\n", 108 | " ('will', 'MD'),\n", 109 | " ('join', 'VB'),\n", 110 | " ('the', 'DT'),\n", 111 | " ('board', 'NN'),\n", 112 | " ('as', 'IN'),\n", 113 | " ('a', 'DT'),\n", 114 | " ('nonexecutive', 'JJ'),\n", 115 | " ('director', 'NN'),\n", 116 | " ('Nov.', 'NNP'),\n", 117 | " ('29', 'CD'),\n", 118 | " ('.', '.')]" 119 | ] 120 | }, 121 | "execution_count": 5, 122 | "metadata": {}, 123 | "output_type": "execute_result" 124 | } 125 | ], 126 | "source": [ 127 | "tagged_sents[0]" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 6, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "tags = list(set(word_pos[1] for sent in tagged_sents for word_pos in sent))" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 7, 142 | "metadata": {}, 143 | "outputs": [ 144 | { 145 | "data": { 146 | "text/plain": [ 147 | "\"JJS,VBN,SYM,NNP,POS,'',-NONE-,WP,CD,.,VBZ,RBS,RB,-RRB-,NNPS,FW,WDT,DT,WRB,PRP$,:,MD,JJ,$,EX,RBR,VBD,VBP,NN,PRP,-LRB-,LS,NNS,RP,#,TO,,,``,IN,VBG,CC,JJR,PDT,UH,VB,WP$\"" 148 | ] 149 | }, 150 | "execution_count": 7, 151 | "metadata": {}, 152 | "output_type": "execute_result" 153 | } 154 | ], 155 | "source": [ 156 | "\",\".join(tags)" 157 | ] 158 | }, 159 | { 160 | 
"cell_type": "code", 161 | "execution_count": 8, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "# By convention, the 0'th slot is reserved for padding.\n", 166 | "tags = [\"\"] + tags" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 9, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "tag2idx = {tag:idx for idx, tag in enumerate(tags)}\n", 176 | "idx2tag = {idx:tag for idx, tag in enumerate(tags)}" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 10, 182 | "metadata": {}, 183 | "outputs": [ 184 | { 185 | "data": { 186 | "text/plain": [ 187 | "(3522, 392)" 188 | ] 189 | }, 190 | "execution_count": 10, 191 | "metadata": {}, 192 | "output_type": "execute_result" 193 | } 194 | ], 195 | "source": [ 196 | "# Let's split the data into train and test (or eval)\n", 197 | "from sklearn.model_selection import train_test_split\n", 198 | "train_data, test_data = train_test_split(tagged_sents, test_size=.1)\n", 199 | "len(train_data), len(test_data)" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 11, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "metadata": {}, 214 | "source": [ 215 | "# Data loader\n" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 12, 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 13, 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [ 233 | "class PosDataset(data.Dataset):\n", 234 | " def __init__(self, tagged_sents):\n", 235 | " sents, tags_li = [], [] # list of lists\n", 236 | " for sent in tagged_sents:\n", 237 | " words = [word_pos[0] for word_pos in sent]\n", 238 | " tags = 
[word_pos[1] for word_pos in sent]\n", 239 | " sents.append([\"[CLS]\"] + words + [\"[SEP]\"])\n", 240 | " tags_li.append([\"\"] + tags + [\"\"])\n", 241 | " self.sents, self.tags_li = sents, tags_li\n", 242 | "\n", 243 | " def __len__(self):\n", 244 | " return len(self.sents)\n", 245 | "\n", 246 | " def __getitem__(self, idx):\n", 247 | " words, tags = self.sents[idx], self.tags_li[idx] # words, tags: string list\n", 248 | "\n", 249 | " # We give credits only to the first piece.\n", 250 | " x, y = [], [] # list of ids\n", 251 | " is_heads = [] # list. 1: the token is the first piece of a word\n", 252 | " for w, t in zip(words, tags):\n", 253 | " tokens = tokenizer.tokenize(w) if w not in (\"[CLS]\", \"[SEP]\") else [w]\n", 254 | " xx = tokenizer.convert_tokens_to_ids(tokens)\n", 255 | "\n", 256 | " is_head = [1] + [0]*(len(tokens) - 1)\n", 257 | "\n", 258 | " t = [t] + [\"\"] * (len(tokens) - 1) # : no decision\n", 259 | " yy = [tag2idx[each] for each in t] # (T,)\n", 260 | "\n", 261 | " x.extend(xx)\n", 262 | " is_heads.extend(is_head)\n", 263 | " y.extend(yy)\n", 264 | "\n", 265 | " assert len(x)==len(y)==len(is_heads), \"len(x)={}, len(y)={}, len(is_heads)={}\".format(len(x), len(y), len(is_heads))\n", 266 | "\n", 267 | " # seqlen\n", 268 | " seqlen = len(y)\n", 269 | "\n", 270 | " # to string\n", 271 | " words = \" \".join(words)\n", 272 | " tags = \" \".join(tags)\n", 273 | " return words, x, is_heads, tags, y, seqlen\n" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 14, 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "def pad(batch):\n", 283 | " '''Pads to the longest sample'''\n", 284 | " f = lambda x: [sample[x] for sample in batch]\n", 285 | " words = f(0)\n", 286 | " is_heads = f(2)\n", 287 | " tags = f(3)\n", 288 | " seqlens = f(-1)\n", 289 | " maxlen = np.array(seqlens).max()\n", 290 | "\n", 291 | " f = lambda x, seqlen: [sample[x] + [0] * (seqlen - len(sample[x])) for sample in batch] # 0: \n", 292 | " 
x = f(1, maxlen)\n", 293 | " y = f(-2, maxlen)\n", 294 | "\n", 295 | "\n", 296 | " f = torch.LongTensor\n", 297 | "\n", 298 | " return words, f(x), is_heads, tags, f(y), seqlens" 299 | ] 300 | }, 301 | { 302 | "cell_type": "markdown", 303 | "metadata": {}, 304 | "source": [ 305 | "# Model" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": 15, 311 | "metadata": {}, 312 | "outputs": [], 313 | "source": [ 314 | "from pytorch_pretrained_bert import BertModel" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 16, 320 | "metadata": {}, 321 | "outputs": [], 322 | "source": [ 323 | "class Net(nn.Module):\n", 324 | " def __init__(self, vocab_size=None):\n", 325 | " super().__init__()\n", 326 | " self.bert = BertModel.from_pretrained('bert-base-cased')\n", 327 | "\n", 328 | " self.fc = nn.Linear(768, vocab_size)\n", 329 | " self.device = device\n", 330 | "\n", 331 | " def forward(self, x, y):\n", 332 | " '''\n", 333 | " x: (N, T). int64\n", 334 | " y: (N, T). 
int64\n", 335 | " '''\n", 336 | " x = x.to(device)\n", 337 | " y = y.to(device)\n", 338 | " \n", 339 | " if self.training:\n", 340 | " self.bert.train()\n", 341 | " encoded_layers, _ = self.bert(x)\n", 342 | " enc = encoded_layers[-1]\n", 343 | " else:\n", 344 | " self.bert.eval()\n", 345 | " with torch.no_grad():\n", 346 | " encoded_layers, _ = self.bert(x)\n", 347 | " enc = encoded_layers[-1]\n", 348 | " \n", 349 | " logits = self.fc(enc)\n", 350 | " y_hat = logits.argmax(-1)\n", 351 | " return logits, y, y_hat" 352 | ] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": {}, 357 | "source": [ 358 | "# Train and evaluate" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 17, 364 | "metadata": {}, 365 | "outputs": [], 366 | "source": [ 367 | "def train(model, iterator, optimizer, criterion):\n", 368 | " model.train()\n", 369 | " for i, batch in enumerate(iterator):\n", 370 | " words, x, is_heads, tags, y, seqlens = batch\n", 371 | " _y = y # for monitoring\n", 372 | " optimizer.zero_grad()\n", 373 | " logits, y, _ = model(x, y) # logits: (N, T, VOCAB), y: (N, T)\n", 374 | "\n", 375 | " logits = logits.view(-1, logits.shape[-1]) # (N*T, VOCAB)\n", 376 | " y = y.view(-1) # (N*T,)\n", 377 | "\n", 378 | " loss = criterion(logits, y)\n", 379 | " loss.backward()\n", 380 | "\n", 381 | " optimizer.step()\n", 382 | "\n", 383 | " if i%10==0: # monitoring\n", 384 | " print(\"step: {}, loss: {}\".format(i, loss.item()))" 385 | ] 386 | }, 387 | { 388 | "cell_type": "code", 389 | "execution_count": 18, 390 | "metadata": {}, 391 | "outputs": [], 392 | "source": [ 393 | "def eval(model, iterator):\n", 394 | " model.eval()\n", 395 | "\n", 396 | " Words, Is_heads, Tags, Y, Y_hat = [], [], [], [], []\n", 397 | " with torch.no_grad():\n", 398 | " for i, batch in enumerate(iterator):\n", 399 | " words, x, is_heads, tags, y, seqlens = batch\n", 400 | "\n", 401 | " _, _, y_hat = model(x, y) # y_hat: (N, T)\n", 402 | "\n", 403 | " 
Words.extend(words)\n", 404 | " Is_heads.extend(is_heads)\n", 405 | " Tags.extend(tags)\n", 406 | " Y.extend(y.numpy().tolist())\n", 407 | " Y_hat.extend(y_hat.cpu().numpy().tolist())\n", 408 | "\n", 409 | " ## gets results and save\n", 410 | " with open(\"result\", 'w') as fout:\n", 411 | " for words, is_heads, tags, y_hat in zip(Words, Is_heads, Tags, Y_hat):\n", 412 | " y_hat = [hat for head, hat in zip(is_heads, y_hat) if head == 1]\n", 413 | " preds = [idx2tag[hat] for hat in y_hat]\n", 414 | " assert len(preds)==len(words.split())==len(tags.split())\n", 415 | " for w, t, p in zip(words.split()[1:-1], tags.split()[1:-1], preds[1:-1]):\n", 416 | " fout.write(\"{} {} {}\\n\".format(w, t, p))\n", 417 | " fout.write(\"\\n\")\n", 418 | " \n", 419 | " ## calc metric\n", 420 | " y_true = np.array([tag2idx[line.split()[1]] for line in open('result', 'r').read().splitlines() if len(line) > 0])\n", 421 | " y_pred = np.array([tag2idx[line.split()[2]] for line in open('result', 'r').read().splitlines() if len(line) > 0])\n", 422 | "\n", 423 | " acc = (y_true==y_pred).astype(np.int32).sum() / len(y_true)\n", 424 | "\n", 425 | " print(\"acc=%.2f\"%acc)\n" 426 | ] 427 | }, 428 | { 429 | "cell_type": "markdown", 430 | "metadata": {}, 431 | "source": [ 432 | "## Load model and train" 433 | ] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "execution_count": 19, 438 | "metadata": { 439 | "scrolled": false 440 | }, 441 | "outputs": [], 442 | "source": [ 443 | "model = Net(vocab_size=len(tag2idx))\n", 444 | "model.to(device)\n", 445 | "model = nn.DataParallel(model)" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": 20, 451 | "metadata": {}, 452 | "outputs": [], 453 | "source": [ 454 | "train_dataset = PosDataset(train_data)\n", 455 | "eval_dataset = PosDataset(test_data)\n", 456 | "\n", 457 | "train_iter = data.DataLoader(dataset=train_dataset,\n", 458 | " batch_size=8,\n", 459 | " shuffle=True,\n", 460 | " num_workers=1,\n", 461 | " 
collate_fn=pad)\n", 462 | "test_iter = data.DataLoader(dataset=eval_dataset,\n", 463 | " batch_size=8,\n", 464 | " shuffle=False,\n", 465 | " num_workers=1,\n", 466 | " collate_fn=pad)\n", 467 | "\n", 468 | "optimizer = optim.Adam(model.parameters(), lr = 0.0001)\n", 469 | "\n", 470 | "criterion = nn.CrossEntropyLoss(ignore_index=0)" 471 | ] 472 | }, 473 | { 474 | "cell_type": "code", 475 | "execution_count": 23, 476 | "metadata": { 477 | "scrolled": false 478 | }, 479 | "outputs": [ 480 | { 481 | "name": "stdout", 482 | "output_type": "stream", 483 | "text": [ 484 | "step: 0, loss: 0.027864959090948105\n", 485 | "step: 10, loss: 0.03902581334114075\n", 486 | "step: 20, loss: 0.029155433177947998\n", 487 | "step: 30, loss: 0.036159448325634\n", 488 | "step: 40, loss: 0.04948236793279648\n", 489 | "step: 50, loss: 0.034221794456243515\n", 490 | "step: 60, loss: 0.017331236973404884\n", 491 | "step: 70, loss: 0.06194368004798889\n", 492 | "step: 80, loss: 0.01584777608513832\n", 493 | "step: 90, loss: 0.05200301483273506\n", 494 | "step: 100, loss: 0.042910996824502945\n", 495 | "step: 110, loss: 0.01104726456105709\n", 496 | "step: 120, loss: 0.09724321961402893\n", 497 | "step: 130, loss: 0.03911526873707771\n", 498 | "step: 140, loss: 0.01710551604628563\n", 499 | "step: 150, loss: 0.06321573257446289\n", 500 | "step: 160, loss: 0.03924640640616417\n", 501 | "step: 170, loss: 0.018429234623908997\n", 502 | "step: 180, loss: 0.08669907599687576\n", 503 | "step: 190, loss: 0.03778192400932312\n", 504 | "step: 200, loss: 0.06231529265642166\n", 505 | "step: 210, loss: 0.03337831050157547\n", 506 | "step: 220, loss: 0.02998737245798111\n", 507 | "step: 230, loss: 0.042920369654893875\n", 508 | "step: 240, loss: 0.03866969794034958\n", 509 | "step: 250, loss: 0.0313352607190609\n", 510 | "step: 260, loss: 0.07101577520370483\n", 511 | "step: 270, loss: 0.10232127457857132\n", 512 | "step: 280, loss: 0.02433393895626068\n", 513 | "step: 290, loss: 
0.04439501464366913\n", 514 | "step: 300, loss: 0.040027499198913574\n", 515 | "step: 310, loss: 0.027676744386553764\n", 516 | "step: 320, loss: 0.07696736603975296\n", 517 | "step: 330, loss: 0.0451495386660099\n", 518 | "step: 340, loss: 0.047669027000665665\n", 519 | "step: 350, loss: 0.03548085317015648\n", 520 | "step: 360, loss: 0.05764956399798393\n", 521 | "step: 370, loss: 0.018015533685684204\n", 522 | "step: 380, loss: 0.01151026040315628\n", 523 | "step: 390, loss: 0.05745118111371994\n", 524 | "step: 400, loss: 0.08683514595031738\n", 525 | "step: 410, loss: 0.020299362018704414\n", 526 | "step: 420, loss: 0.03219173103570938\n", 527 | "step: 430, loss: 0.0664878711104393\n", 528 | "step: 440, loss: 0.059365227818489075\n", 529 | "acc=0.98\n" 530 | ] 531 | } 532 | ], 533 | "source": [ 534 | "train(model, train_iter, optimizer, criterion)\n", 535 | "eval(model, test_iter)\n" 536 | ] 537 | }, 538 | { 539 | "cell_type": "markdown", 540 | "metadata": {}, 541 | "source": [ 542 | "Check the result." 543 | ] 544 | }, 545 | { 546 | "cell_type": "code", 547 | "execution_count": 24, 548 | "metadata": {}, 549 | "outputs": [ 550 | { 551 | "data": { 552 | "text/plain": [ 553 | "['Bonds NNS NNS',\n", 554 | " 'due JJ JJ',\n", 555 | " 'in IN IN',\n", 556 | " '2005 CD CD',\n", 557 | " 'have VBP VBP',\n", 558 | " 'a DT DT',\n", 559 | " '7 CD CD',\n", 560 | " '1\\\\/2 CD CD',\n", 561 | " '% NN NN',\n", 562 | " 'coupon NN NN',\n", 563 | " 'and CC CC',\n", 564 | " 'are VBP VBP',\n", 565 | " 'priced VBN VBN',\n", 566 | " '*-1 -NONE- -NONE-',\n", 567 | " 'at IN IN',\n", 568 | " 'par NN NN',\n", 569 | " '. . .',\n", 570 | " '',\n", 571 | " 'Mr. NNP NNP',\n", 572 | " 'Sidak NNP NNP',\n", 573 | " 'served VBD VBD',\n", 574 | " 'as IN IN',\n", 575 | " 'an DT DT',\n", 576 | " 'attorney NN NN',\n", 577 | " 'in IN IN',\n", 578 | " 'the DT DT',\n", 579 | " 'Reagan NNP NNP',\n", 580 | " 'administration NN NN',\n", 581 | " '. . 
.',\n", 582 | " '',\n", 583 | " 'Municipal NNP NNP',\n", 584 | " 'Issues NNPS NNPS',\n", 585 | " '',\n", 586 | " 'Viacom NNP NNP',\n", 587 | " 'denies VBZ VBZ',\n", 588 | " '0 -NONE- -NONE-',\n", 589 | " 'it PRP PRP',\n", 590 | " \"'s VBZ VBZ\",\n", 591 | " 'using VBG VBG',\n", 592 | " 'pressure NN NN',\n", 593 | " 'tactics NNS NNS',\n", 594 | " '. . .',\n", 595 | " '',\n", 596 | " 'Tokyo NNP NNP',\n", 597 | " \"'s POS POS\",\n", 598 | " 'leading VBG VBG',\n", 599 | " 'program NN NN',\n", 600 | " 'traders NNS NNS',\n", 601 | " 'are VBP VBP',\n", 602 | " 'the DT DT',\n", 603 | " 'big JJ JJ',\n", 604 | " 'U.S. NNP NNP',\n", 605 | " 'securities NNS NNS',\n", 606 | " 'houses NNS NNS',\n", 607 | " ', , ,',\n", 608 | " 'though IN IN',\n", 609 | " 'the DT DT',\n", 610 | " 'Japanese NNP NNS',\n", 611 | " 'are VBP VBP',\n", 612 | " 'playing VBG VBG',\n", 613 | " 'catch-up NN JJ',\n", 614 | " '. . .',\n", 615 | " '',\n", 616 | " 'That DT DT',\n", 617 | " \"'s VBZ VBZ\",\n", 618 | " 'why WRB WRB',\n", 619 | " 'Columbia NNP NNP',\n", 620 | " 'just RB RB',\n", 621 | " 'wrote VBD VBD',\n", 622 | " 'off RP RP',\n", 623 | " '$ $ $',\n", 624 | " '130 CD CD',\n", 625 | " 'million CD CD',\n", 626 | " '*U* -NONE- -NONE-',\n", 627 | " 'of IN IN',\n", 628 | " 'its PRP$ PRP$',\n", 629 | " 'junk NN NN',\n", 630 | " 'and CC CC',\n", 631 | " 'reserved VBD VBD',\n", 632 | " '$ $ $',\n", 633 | " '227 CD CD',\n", 634 | " 'million CD CD',\n", 635 | " '*U* -NONE- -NONE-',\n", 636 | " 'for IN IN',\n", 637 | " 'future JJ JJ',\n", 638 | " 'junk NN NN',\n", 639 | " 'losses NNS NNS',\n", 640 | " '*T*-1 -NONE- -NONE-',\n", 641 | " '. . .',\n", 642 | " '',\n", 643 | " 'Allergan NNP NNP',\n", 644 | " 'Inc. 
NNP NNP',\n", 645 | " 'said VBD VBD',\n", 646 | " '0 -NONE- -NONE-',\n", 647 | " 'it PRP PRP',\n", 648 | " 'received VBD VBD',\n", 649 | " 'Food NNP NNP',\n", 650 | " 'and CC CC',\n", 651 | " 'Drug NNP NNP',\n", 652 | " 'Administration NNP NNP']" 653 | ] 654 | }, 655 | "execution_count": 24, 656 | "metadata": {}, 657 | "output_type": "execute_result" 658 | } 659 | ], 660 | "source": [ 661 | "open('result', 'r').read().splitlines()[:100]" 662 | ] 663 | }, 664 | { 665 | "cell_type": "code", 666 | "execution_count": null, 667 | "metadata": {}, 668 | "outputs": [], 669 | "source": [] 670 | } 671 | ], 672 | "metadata": { 673 | "kernelspec": { 674 | "display_name": "Python 3", 675 | "language": "python", 676 | "name": "python3" 677 | }, 678 | "language_info": { 679 | "codemirror_mode": { 680 | "name": "ipython", 681 | "version": 3 682 | }, 683 | "file_extension": ".py", 684 | "mimetype": "text/x-python", 685 | "name": "python", 686 | "nbconvert_exporter": "python", 687 | "pygments_lexer": "ipython3", 688 | "version": "3.7.1" 689 | } 690 | }, 691 | "nbformat": 4, 692 | "nbformat_minor": 2 693 | } 694 | -------------------------------------------------------------------------------- /PyTorch seq2seq template based on the g2p task.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "We'll write a simple template for seq2seq using PyTorch. For demonstration, we attack the g2p task. G2p is a task of converting graphemes (spelling) to phonemes (pronunciation). It's a very good source for this purpose as it's simple enough for you to up and run. 
If you want to know more about g2p, see my [repo](https://github.com/kyubyong/g2p)." 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 109, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "__author__ = \"kyubyong\"\n", 17 | "__address__ = \"https://github.com/kyubyong/nlp_made_easy\"\n", 18 | "__email__ = \"kbpark.linguist@gmail.com\"" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 110, 24 | "metadata": { 25 | "colab": {}, 26 | "colab_type": "code", 27 | "id": "TEXqpZ_U738q" 28 | }, 29 | "outputs": [], 30 | "source": [ 31 | "import numpy as np\n", 32 | "from tqdm import tqdm_notebook as tqdm\n", 33 | "from distance import levenshtein\n", 34 | "import os\n", 35 | "import math\n", 36 | "import torch\n", 37 | "import torch.nn as nn\n", 38 | "import torch.nn.functional as F\n", 39 | "import torch.optim as optim\n", 40 | "from torch.utils import data\n", 41 | "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 111, 47 | "metadata": { 48 | "colab": { 49 | "base_uri": "https://localhost:8080/", 50 | "height": 34 51 | }, 52 | "colab_type": "code", 53 | "id": "_7vuctbU7381", 54 | "outputId": "f8ee2cbf-1f04-432f-ba42-d25fec61669b" 55 | }, 56 | "outputs": [ 57 | { 58 | "data": { 59 | "text/plain": [ 60 | "'1.2.0'" 61 | ] 62 | }, 63 | "execution_count": 111, 64 | "metadata": {}, 65 | "output_type": "execute_result" 66 | } 67 | ], 68 | "source": [ 69 | "torch.__version__" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": { 75 | "colab_type": "text", 76 | "id": "B6te4HKk738_" 77 | }, 78 | "source": [ 79 | "# Hyperparameters" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 112, 85 | "metadata": { 86 | "colab": {}, 87 | "colab_type": "code", 88 | "id": "CWS2hkce739C" 89 | }, 90 | "outputs": [], 91 | "source": [ 92 | "class Hparams:\n", 93 | " batch_size = 128\n", 94 | " enc_maxlen = 
20\n", 95 | " dec_maxlen = 20\n", 96 | " num_epochs = 10\n", 97 | " hidden_units = 128\n", 98 | " emb_units = 64\n", 99 | " graphemes = [\"\", \"\", \"\"] + list(\"abcdefghijklmnopqrstuvwxyz\")\n", 100 | " phonemes = [\"\", \"\", \"\", \"\"] + ['AA0', 'AA1', 'AA2', 'AE0', 'AE1', 'AE2', 'AH0', 'AH1', 'AH2', 'AO0',\n", 101 | " 'AO1', 'AO2', 'AW0', 'AW1', 'AW2', 'AY0', 'AY1', 'AY2', 'B', 'CH', 'D', 'DH',\n", 102 | " 'EH0', 'EH1', 'EH2', 'ER0', 'ER1', 'ER2', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH',\n", 103 | " 'IH0', 'IH1', 'IH2', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW0', 'OW1',\n", 104 | " 'OW2', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH0', 'UH1', 'UH2', 'UW',\n", 105 | " 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH']\n", 106 | " lr = 0.001\n", 107 | " logdir = \"log/01\"\n", 108 | "hp = Hparams()" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "metadata": { 114 | "colab_type": "text", 115 | "id": "nz-hD6dn739L" 116 | }, 117 | "source": [ 118 | "# Prepare Data" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 113, 124 | "metadata": { 125 | "colab": {}, 126 | "colab_type": "code", 127 | "id": "-as4PHs-739N" 128 | }, 129 | "outputs": [ 130 | { 131 | "data": { 132 | "text/plain": [ 133 | "[['R', 'AH0', 'F', 'Y', 'UW1', 'Z'],\n", 134 | " ['R', 'EH1', 'F', 'Y', 'UW2', 'Z'],\n", 135 | " ['R', 'IH0', 'F', 'Y', 'UW1', 'Z']]" 136 | ] 137 | }, 138 | "execution_count": 113, 139 | "metadata": {}, 140 | "output_type": "execute_result" 141 | } 142 | ], 143 | "source": [ 144 | "import nltk\n", 145 | "# nltk.download('cmudict')# <- if you haven't downloaded, do this.\n", 146 | "from nltk.corpus import cmudict\n", 147 | "cmu = cmudict.dict()\n", 148 | "cmu[\"refuse\"]" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 114, 154 | "metadata": { 155 | "colab": {}, 156 | "colab_type": "code", 157 | "id": "39gQ3vOi739S" 158 | }, 159 | "outputs": [], 160 | "source": [ 161 | "def 
load_vocab():\n", 162 | " g2idx = {g: idx for idx, g in enumerate(hp.graphemes)}\n", 163 | " idx2g = {idx: g for idx, g in enumerate(hp.graphemes)}\n", 164 | "\n", 165 | " p2idx = {p: idx for idx, p in enumerate(hp.phonemes)}\n", 166 | " idx2p = {idx: p for idx, p in enumerate(hp.phonemes)}\n", 167 | "\n", 168 | " return g2idx, idx2g, p2idx, idx2p # note that g and p mean grapheme and phoneme, respectively.\n", 169 | "\n", 170 | "g2idx, idx2g, p2idx, idx2p = load_vocab()" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 115, 176 | "metadata": { 177 | "colab": {}, 178 | "colab_type": "code", 179 | "id": "Zslytxn6739Z" 180 | }, 181 | "outputs": [], 182 | "source": [ 183 | "def prepare_data():\n", 184 | " words = [\" \".join(list(word)) for word, prons in cmu.items()]\n", 185 | " prons = [\" \".join(prons[0]) for word, prons in cmu.items()]\n", 186 | " indices = list(range(len(words)))\n", 187 | " from random import shuffle\n", 188 | " shuffle(indices)\n", 189 | " words = [words[idx] for idx in indices]\n", 190 | " prons = [prons[idx] for idx in indices]\n", 191 | " num_train, num_test = int(len(words)*.8), int(len(words)*.1)\n", 192 | " train_words, eval_words, test_words = words[:num_train], \\\n", 193 | " words[num_train:-num_test],\\\n", 194 | " words[-num_test:]\n", 195 | " train_prons, eval_prons, test_prons = prons[:num_train], \\\n", 196 | " prons[num_train:-num_test],\\\n", 197 | " prons[-num_test:] \n", 198 | " return train_words, eval_words, test_words, train_prons, eval_prons, test_prons" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 116, 204 | "metadata": { 205 | "colab": {}, 206 | "colab_type": "code", 207 | "id": "WHBXkAPG739j" 208 | }, 209 | "outputs": [ 210 | { 211 | "name": "stdout", 212 | "output_type": "stream", 213 | "text": [ 214 | "f l a p j a c k\n", 215 | "F L AE1 P JH AE2 K\n" 216 | ] 217 | } 218 | ], 219 | "source": [ 220 | "train_words, eval_words, test_words, train_prons, 
eval_prons, test_prons = prepare_data()\n", 221 | "print(train_words[0])\n", 222 | "print(train_prons[0])" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 117, 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "def drop_lengthy_samples(words, prons, enc_maxlen, dec_maxlen):\n", 232 | " \"\"\"We only include such samples less than maxlen.\"\"\"\n", 233 | " _words, _prons = [], []\n", 234 | " for w, p in zip(words, prons):\n", 235 | " if len(w.split()) + 1 > enc_maxlen: continue\n", 236 | " if len(p.split()) + 1 > dec_maxlen: continue # 1: \n", 237 | " _words.append(w)\n", 238 | " _prons.append(p)\n", 239 | " return _words, _prons " 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": 118, 245 | "metadata": {}, 246 | "outputs": [], 247 | "source": [ 248 | "train_words, train_prons = drop_lengthy_samples(train_words, train_prons, hp.enc_maxlen, hp.dec_maxlen)\n", 249 | "# We do NOT apply this constraint to eval and test datasets." 250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "metadata": {}, 255 | "source": [ 256 | "# Data Loader" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 119, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "def encode(inp, type, dict):\n", 266 | " '''convert string into ids\n", 267 | " type: \"x\" or \"y\"\n", 268 | " dict: g2idx for 'x', p2idx for 'y'\n", 269 | " '''\n", 270 | " if type==\"x\": tokens = inp.split() + [\"\"]\n", 271 | " else: tokens = [\"\"] + inp.split() + [\"\"]\n", 272 | "\n", 273 | " x = [dict.get(t, dict[\"\"]) for t in tokens]\n", 274 | " return x" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 120, 280 | "metadata": {}, 281 | "outputs": [], 282 | "source": [ 283 | "class G2pDataset(data.Dataset):\n", 284 | "\n", 285 | " def __init__(self, words, prons):\n", 286 | " \"\"\"\n", 287 | " words: list of words. e.g., [\"w o r d\", ]\n", 288 | " prons: list of prons. 
e.g., ['W ER1 D',]\n", 289 | " \"\"\"\n", 290 | " self.words = words\n", 291 | " self.prons = prons\n", 292 | "\n", 293 | " def __len__(self):\n", 294 | " return len(self.words)\n", 295 | "\n", 296 | " def __getitem__(self, idx):\n", 297 | " word, pron = self.words[idx], self.prons[idx]\n", 298 | " x = encode(word, \"x\", g2idx)\n", 299 | " y = encode(pron, \"y\", p2idx)\n", 300 | " decoder_input, y = y[:-1], y[1:]\n", 301 | "\n", 302 | " x_seqlen, y_seqlen = len(x), len(y)\n", 303 | " \n", 304 | " return x, x_seqlen, word, decoder_input, y, y_seqlen, pron" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 121, 310 | "metadata": {}, 311 | "outputs": [], 312 | "source": [ 313 | "def pad(batch):\n", 314 | " '''Pads zeros such that the length of all samples in a batch is the same.'''\n", 315 | " f = lambda x: [sample[x] for sample in batch]\n", 316 | " x_seqlens = f(1)\n", 317 | " y_seqlens = f(5)\n", 318 | " words = f(2)\n", 319 | " prons = f(-1)\n", 320 | " \n", 321 | " x_maxlen = np.array(x_seqlens).max()\n", 322 | " y_maxlen = np.array(y_seqlens).max()\n", 323 | " \n", 324 | " f = lambda x, maxlen, batch: [sample[x]+[0]*(maxlen-len(sample[x])) for sample in batch]\n", 325 | " x = f(0, x_maxlen, batch)\n", 326 | " decoder_inputs = f(3, y_maxlen, batch)\n", 327 | " y = f(4, y_maxlen, batch)\n", 328 | " \n", 329 | " f = torch.LongTensor\n", 330 | " return f(x), x_seqlens, words, f(decoder_inputs), f(y), y_seqlens, prons" 331 | ] 332 | }, 333 | { 334 | "cell_type": "markdown", 335 | "metadata": { 336 | "colab_type": "text", 337 | "id": "22mif4xf73-M" 338 | }, 339 | "source": [ 340 | "# Model" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": 234, 346 | "metadata": {}, 347 | "outputs": [], 348 | "source": [ 349 | "class Encoder(nn.Module):\n", 350 | " global g2idx, idx2g, p2idx, idx2p\n", 351 | " def __init__(self, emb_units, hidden_units):\n", 352 | " super().__init__()\n", 353 | " self.emb_units = emb_units\n", 
354 | " self.hidden_units = hidden_units\n", 355 | " self.emb = nn.Embedding(len(g2idx), emb_units)\n", 356 | " self.rnn = nn.GRU(emb_units, hidden_units, batch_first=True)\n", 357 | " \n", 358 | " def forward(self, x, seqlens):\n", 359 | " x = self.emb(x)\n", 360 | " \n", 361 | " # packing -> rnn -> unpacking -> position recovery: note that enforce_sorted is set to False.\n", 362 | " packed_input = pack_padded_sequence(x, seqlens, batch_first=True, enforce_sorted=False) \n", 363 | " outputs, last_hidden = self.rnn(packed_input)\n", 364 | "# outputs, _ = pad_packed_sequence(outputs, batch_first=True, total_length=x.size()[1])\n", 365 | "\n", 366 | " # last hidden\n", 367 | " last_hidden = last_hidden.permute(1, 2, 0)\n", 368 | " last_hidden = last_hidden.view(last_hidden.size()[0], -1)\n", 369 | " \n", 370 | " return last_hidden\n", 371 | "\n" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": 235, 377 | "metadata": {}, 378 | "outputs": [], 379 | "source": [ 380 | "class Decoder(nn.Module):\n", 381 | " global g2idx, idx2g, p2idx, idx2p\n", 382 | " def __init__(self, emb_units, hidden_units):\n", 383 | " super().__init__()\n", 384 | " \n", 385 | " self.emb_units = emb_units\n", 386 | " self.hidden_units = hidden_units\n", 387 | " self.emb = nn.Embedding(len(p2idx), emb_units)\n", 388 | " self.rnn = nn.GRU(emb_units, hidden_units, batch_first=True)\n", 389 | " self.fc = nn.Linear(hidden_units, len(p2idx))\n", 390 | " \n", 391 | " def forward(self, decoder_inputs, h0):\n", 392 | " decoder_inputs = self.emb(decoder_inputs)\n", 393 | " \n", 394 | " outputs, last_hidden = self.rnn(decoder_inputs, h0)\n", 395 | " logits = self.fc(outputs) # (N, T, V)\n", 396 | " y_hat = logits.argmax(-1)\n", 397 | " \n", 398 | " return logits, y_hat, last_hidden\n" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": 236, 404 | "metadata": { 405 | "colab": {}, 406 | "colab_type": "code", 407 | "id": "HA39FU4-73-O" 408 | }, 409 | "outputs": 
[], 410 | "source": [ 411 | "class Net(nn.Module):\n", 412 | " global g2idx, idx2g, p2idx, idx2p\n", 413 | " \n", 414 | " def __init__(self, encoder, decoder): \n", 415 | " super().__init__()\n", 416 | " self.encoder = encoder\n", 417 | " self.decoder = decoder\n", 418 | " \n", 419 | " def forward(self, x, seqlens, decoder_inputs, teacher_forcing=True, dec_maxlen=None): \n", 420 | " '''\n", 421 | " At training, decoder inputs (ground truth) and teacher forcing is applied. \n", 422 | " At evaluation, decoder inputs are ignored, and the decoding keeps for `dec_maxlen` steps.\n", 423 | " '''\n", 424 | " last_hidden = self.encoder(x, seqlens)\n", 425 | " h0 = last_hidden.unsqueeze(0)\n", 426 | " \n", 427 | " if teacher_forcing: # training\n", 428 | " logits, y_hat, h0 = self.decoder(decoder_inputs, h0)\n", 429 | " else: # evaluation\n", 430 | " decoder_inputs = decoder_inputs[:, :1] # \"\"\n", 431 | " logits, y_hat = [], []\n", 432 | " for t in range(dec_maxlen):\n", 433 | " _logits, _y_hat, h0 =self.decoder(decoder_inputs, h0) # _logits: (N, 1, V), _y_hat: (N, 1), h0: (1, N, N)\n", 434 | " logits.append(_logits)\n", 435 | " y_hat.append(_y_hat)\n", 436 | " decoder_inputs = _y_hat\n", 437 | " \n", 438 | " logits = torch.cat(logits, 1)\n", 439 | " y_hat = torch.cat(y_hat, 1)\n", 440 | " \n", 441 | " return logits, y_hat\n" 442 | ] 443 | }, 444 | { 445 | "cell_type": "markdown", 446 | "metadata": {}, 447 | "source": [ 448 | "# Train & Eval functions" 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": 237, 454 | "metadata": {}, 455 | "outputs": [], 456 | "source": [ 457 | "def train(model, iterator, optimizer, criterion, device):\n", 458 | " model.train()\n", 459 | " for i, batch in enumerate(iterator):\n", 460 | " x, x_seqlens, words, decoder_inputs, y, y_seqlens, prons = batch\n", 461 | " \n", 462 | " x, decoder_inputs = x.to(device), decoder_inputs.to(device) \n", 463 | " y = y.to(device)\n", 464 | " \n", 465 | " optimizer.zero_grad()\n", 466 
| " logits, y_hat = model(x, x_seqlens, decoder_inputs)\n", 467 | " \n", 468 | " # calc loss\n", 469 | " logits = logits.view(-1, logits.shape[-1]) # (N*T, VOCAB)\n", 470 | " y = y.view(-1) # (N*T,)\n", 471 | " loss = criterion(logits, y)\n", 472 | " loss.backward()\n", 473 | " \n", 474 | " optimizer.step()\n", 475 | " \n", 476 | " if i and i%100==0:\n", 477 | " print(f\"step: {i}, loss: {loss.item()}\")\n", 478 | " " 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 239, 484 | "metadata": {}, 485 | "outputs": [], 486 | "source": [ 487 | "def calc_per(Y_true, Y_pred):\n", 488 | " '''Calc phoneme error rate\n", 489 | " Y_true: list of predicted phoneme sequences. e.g., [[\"B\", \"L\", \"AA1\", \"K\", \"HH\", \"AW2\", \"S\"], ...]\n", 490 | " Y_pred: list of ground truth phoneme sequences. e.g., [[\"B\", \"L\", \"AA1\", \"K\", \"HH\", \"AW2\", \"S\"], ...]\n", 491 | " '''\n", 492 | " num_phonemes, num_erros = 0, 0\n", 493 | " for y_true, y_pred in zip(Y_true, Y_pred):\n", 494 | " num_phonemes += len(y_true)\n", 495 | " num_erros += levenshtein(y_true, y_pred)\n", 496 | "\n", 497 | " per = round(num_erros / num_phonemes, 2)\n", 498 | " return per, num_erros" 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": 240, 504 | "metadata": {}, 505 | "outputs": [], 506 | "source": [ 507 | "def convert_ids_to_phonemes(ids, idx2p):\n", 508 | " phonemes = []\n", 509 | " for idx in ids:\n", 510 | " if idx == 3: # 3: \n", 511 | " break\n", 512 | " p = idx2p[idx]\n", 513 | " phonemes.append(p)\n", 514 | " return phonemes\n", 515 | " \n", 516 | " \n", 517 | "\n", 518 | "def eval(model, iterator, device, dec_maxlen):\n", 519 | " model.eval()\n", 520 | "\n", 521 | " Y_true, Y_pred = [], []\n", 522 | " with torch.no_grad():\n", 523 | " for i, batch in enumerate(iterator):\n", 524 | " x, x_seqlens, words, decoder_inputs, y, y_seqlens, prons = batch\n", 525 | " x, decoder_inputs = x.to(device), decoder_inputs.to(device) \n", 526 | "\n", 
527 | " _, y_hat = model(x, x_seqlens, decoder_inputs, False, dec_maxlen) # <- teacher forcing is suppressed.\n", 528 | " \n", 529 | " y = y.to('cpu').numpy().tolist()\n", 530 | " y_hat = y_hat.to('cpu').numpy().tolist()\n", 531 | " for yy, yy_hat in zip(y, y_hat):\n", 532 | " y_true = convert_ids_to_phonemes(yy, idx2p)\n", 533 | " y_pred = convert_ids_to_phonemes(yy_hat, idx2p)\n", 534 | " Y_true.append(y_true)\n", 535 | " Y_pred.append(y_pred)\n", 536 | " \n", 537 | " # calc per.\n", 538 | " per, num_errors = calc_per(Y_true, Y_pred)\n", 539 | " print(\"per: %.2f\" % per, \"num errors: \", num_errors)\n", 540 | " \n", 541 | " with open(\"result\", \"w\") as fout:\n", 542 | " for y_true, y_pred in zip(Y_true, Y_pred):\n", 543 | " fout.write(\" \".join(y_true) + \"\\n\")\n", 544 | " fout.write(\" \".join(y_pred) + \"\\n\\n\")\n", 545 | " \n", 546 | " return per\n", 547 | " " 548 | ] 549 | }, 550 | { 551 | "cell_type": "markdown", 552 | "metadata": { 553 | "colab_type": "text", 554 | "id": "PKllLnfp73-V" 555 | }, 556 | "source": [ 557 | "# Train & Evaluate" 558 | ] 559 | }, 560 | { 561 | "cell_type": "code", 562 | "execution_count": 241, 563 | "metadata": {}, 564 | "outputs": [], 565 | "source": [ 566 | "train_dataset = G2pDataset(train_words, train_prons)\n", 567 | "eval_dataset = G2pDataset(eval_words, eval_prons)\n", 568 | "\n", 569 | "train_iter = data.DataLoader(train_dataset, batch_size=hp.batch_size, shuffle=True, collate_fn=pad)\n", 570 | "eval_iter = data.DataLoader(eval_dataset, batch_size=hp.batch_size, shuffle=False, collate_fn=pad)\n" 571 | ] 572 | }, 573 | { 574 | "cell_type": "code", 575 | "execution_count": 242, 576 | "metadata": {}, 577 | "outputs": [], 578 | "source": [ 579 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" 580 | ] 581 | }, 582 | { 583 | "cell_type": "code", 584 | "execution_count": 243, 585 | "metadata": { 586 | "colab": { 587 | "base_uri": "https://localhost:8080/", 588 | "height": 121 589 | }, 590 | 
"colab_type": "code", 591 | "id": "aF0cJceg73-m", 592 | "outputId": "ab78b80a-d6e3-4408-af21-acff58944a79", 593 | "scrolled": false 594 | }, 595 | "outputs": [ 596 | { 597 | "name": "stdout", 598 | "output_type": "stream", 599 | "text": [ 600 | "\n", 601 | "epoch: 1\n", 602 | "step: 100, loss: 2.729764461517334\n", 603 | "step: 200, loss: 2.06953763961792\n", 604 | "step: 300, loss: 1.6415880918502808\n", 605 | "step: 400, loss: 1.3378574848175049\n", 606 | "step: 500, loss: 1.2205088138580322\n", 607 | "step: 600, loss: 0.987713098526001\n", 608 | "step: 700, loss: 0.9300493597984314\n", 609 | "per: 0.39 num errors: 30495\n", 610 | "\n", 611 | "epoch: 2\n", 612 | "step: 100, loss: 0.862144947052002\n", 613 | "step: 200, loss: 0.8274500370025635\n", 614 | "step: 300, loss: 0.7277541160583496\n", 615 | "step: 400, loss: 0.8145508766174316\n", 616 | "step: 500, loss: 0.6055648922920227\n", 617 | "step: 600, loss: 0.6670782566070557\n", 618 | "step: 700, loss: 0.7308872938156128\n", 619 | "per: 0.30 num errors: 23436\n", 620 | "\n", 621 | "epoch: 3\n", 622 | "step: 100, loss: 0.6280044913291931\n", 623 | "step: 200, loss: 0.6292356848716736\n", 624 | "step: 300, loss: 0.6315799951553345\n", 625 | "step: 400, loss: 0.6083210110664368\n", 626 | "step: 500, loss: 0.5998032093048096\n", 627 | "step: 600, loss: 0.634599506855011\n", 628 | "step: 700, loss: 0.6139586567878723\n", 629 | "per: 0.25 num errors: 19908\n", 630 | "\n", 631 | "epoch: 4\n", 632 | "step: 100, loss: 0.5246961116790771\n", 633 | "step: 200, loss: 0.4894694685935974\n", 634 | "step: 300, loss: 0.5112945437431335\n", 635 | "step: 400, loss: 0.42986905574798584\n", 636 | "step: 500, loss: 0.598124623298645\n", 637 | "step: 600, loss: 0.4181916415691376\n", 638 | "step: 700, loss: 0.4648551046848297\n", 639 | "per: 0.23 num errors: 18315\n", 640 | "\n", 641 | "epoch: 5\n", 642 | "step: 100, loss: 0.5175663232803345\n", 643 | "step: 200, loss: 0.429765909910202\n", 644 | "step: 300, loss: 
0.40104925632476807\n", 645 | "step: 400, loss: 0.5041416883468628\n", 646 | "step: 500, loss: 0.4927402436733246\n", 647 | "step: 600, loss: 0.4579430818557739\n", 648 | "step: 700, loss: 0.4033070504665375\n", 649 | "per: 0.22 num errors: 17284\n", 650 | "\n", 651 | "epoch: 6\n", 652 | "step: 100, loss: 0.40969985723495483\n", 653 | "step: 200, loss: 0.4264414310455322\n", 654 | "step: 300, loss: 0.4089752435684204\n", 655 | "step: 400, loss: 0.4005317986011505\n", 656 | "step: 500, loss: 0.46420779824256897\n", 657 | "step: 600, loss: 0.3897724449634552\n", 658 | "step: 700, loss: 0.3946235775947571\n", 659 | "per: 0.21 num errors: 16358\n", 660 | "\n", 661 | "epoch: 7\n", 662 | "step: 100, loss: 0.36400753259658813\n", 663 | "step: 200, loss: 0.43446823954582214\n", 664 | "step: 300, loss: 0.419318288564682\n", 665 | "step: 400, loss: 0.3300051987171173\n", 666 | "step: 500, loss: 0.4172106385231018\n", 667 | "step: 600, loss: 0.4165189862251282\n", 668 | "step: 700, loss: 0.37570977210998535\n", 669 | "per: 0.20 num errors: 15657\n", 670 | "\n", 671 | "epoch: 8\n", 672 | "step: 100, loss: 0.28185898065567017\n", 673 | "step: 200, loss: 0.3737757205963135\n", 674 | "step: 300, loss: 0.31888020038604736\n", 675 | "step: 400, loss: 0.31222569942474365\n", 676 | "step: 500, loss: 0.36826616525650024\n", 677 | "step: 600, loss: 0.32874614000320435\n", 678 | "step: 700, loss: 0.41009703278541565\n", 679 | "per: 0.19 num errors: 14824\n", 680 | "\n", 681 | "epoch: 9\n", 682 | "step: 100, loss: 0.3402358591556549\n", 683 | "step: 200, loss: 0.34929829835891724\n", 684 | "step: 300, loss: 0.3066072463989258\n", 685 | "step: 400, loss: 0.403970867395401\n", 686 | "step: 500, loss: 0.39007294178009033\n", 687 | "step: 600, loss: 0.36635467410087585\n", 688 | "step: 700, loss: 0.371225506067276\n", 689 | "per: 0.19 num errors: 14593\n", 690 | "\n", 691 | "epoch: 10\n", 692 | "step: 100, loss: 0.3137436509132385\n", 693 | "step: 200, loss: 0.3691956102848053\n", 694 | 
"step: 300, loss: 0.3513341546058655\n", 695 | "step: 400, loss: 0.32838159799575806\n", 696 | "step: 500, loss: 0.2848772406578064\n", 697 | "step: 600, loss: 0.3028918504714966\n", 698 | "step: 700, loss: 0.41611844301223755\n", 699 | "per: 0.18 num errors: 14374\n" 700 | ] 701 | } 702 | ], 703 | "source": [ 704 | "encoder = Encoder(hp.emb_units, hp.hidden_units)\n", 705 | "decoder = Decoder(hp.emb_units, hp.hidden_units)\n", 706 | "model = Net(encoder, decoder)\n", 707 | "model.to(device)\n", 708 | "\n", 709 | "optimizer = optim.Adam(model.parameters(), lr = hp.lr)\n", 710 | "criterion = nn.CrossEntropyLoss(ignore_index=0)\n", 711 | "\n", 712 | "for epoch in range(1, hp.num_epochs+1):\n", 713 | " print(f\"\\nepoch: {epoch}\")\n", 714 | " train(model, train_iter, optimizer, criterion, device)\n", 715 | " eval(model, eval_iter, device, hp.dec_maxlen)" 716 | ] 717 | }, 718 | { 719 | "cell_type": "markdown", 720 | "metadata": { 721 | "colab_type": "text", 722 | "id": "82t4Dmwp73--" 723 | }, 724 | "source": [ 725 | "# Inference" 726 | ] 727 | }, 728 | { 729 | "cell_type": "code", 730 | "execution_count": 244, 731 | "metadata": { 732 | "colab": {}, 733 | "colab_type": "code", 734 | "id": "jUyYlI4S73_O", 735 | "outputId": "ae0592d3-14b0-4f3b-94f8-ebc293c48304" 736 | }, 737 | "outputs": [], 738 | "source": [ 739 | "test_dataset = G2pDataset(test_words, test_prons)\n", 740 | "test_iter = data.DataLoader(test_dataset, batch_size=hp.batch_size, shuffle=False, collate_fn=pad)" 741 | ] 742 | }, 743 | { 744 | "cell_type": "code", 745 | "execution_count": 245, 746 | "metadata": { 747 | "scrolled": false 748 | }, 749 | "outputs": [ 750 | { 751 | "name": "stdout", 752 | "output_type": "stream", 753 | "text": [ 754 | "per: 0.18 num errors: 14045\n" 755 | ] 756 | }, 757 | { 758 | "data": { 759 | "text/plain": [ 760 | "0.18" 761 | ] 762 | }, 763 | "execution_count": 245, 764 | "metadata": {}, 765 | "output_type": "execute_result" 766 | } 767 | ], 768 | "source": [ 769 | 
"eval(model, test_iter, device, hp.dec_maxlen)" 770 | ] 771 | }, 772 | { 773 | "cell_type": "markdown", 774 | "metadata": {}, 775 | "source": [ 776 | "Check the results." 777 | ] 778 | }, 779 | { 780 | "cell_type": "code", 781 | "execution_count": 246, 782 | "metadata": {}, 783 | "outputs": [ 784 | { 785 | "data": { 786 | "text/plain": [ 787 | "['',\n", 788 | " 'G L IY1 M D',\n", 789 | " 'G L IY1 M D',\n", 790 | " '',\n", 791 | " 'P EY1 D AH0 N',\n", 792 | " 'P EY1 D AH0 N',\n", 793 | " '',\n", 794 | " 'B L UW1 N AH0 S',\n", 795 | " 'B L UW1 N AH0 S',\n", 796 | " '',\n", 797 | " 'HH OW1 L B R UH0 K S',\n", 798 | " 'HH OW1 L B R UH2 K S',\n", 799 | " '',\n", 800 | " 'B AE1 R IH0 S T ER0 Z',\n", 801 | " 'B AE1 R IH0 S T ER0 Z',\n", 802 | " '',\n", 803 | " 'P EH1 L T',\n", 804 | " 'P EH1 L T',\n", 805 | " '',\n", 806 | " 'M AA1 R K AH0 L',\n", 807 | " 'M AA1 R K AH0 L',\n", 808 | " '',\n", 809 | " 'F EY1 G ER0 S T R AH0 M',\n", 810 | " 'F EY1 G ER0 S T R AH0 M',\n", 811 | " '',\n", 812 | " 'P EH1 R AH0 SH UW2 T',\n", 813 | " 'P EH1 R AH0 CH UW0 T',\n", 814 | " '',\n", 815 | " 'B EH2 L W OW0 M IY1 N IY0',\n", 816 | " 'B EH0 L UW1 M IY0 N IY0',\n", 817 | " '',\n", 818 | " 'L AA1 HH AH0 V IH0 CH',\n", 819 | " 'L AH0 SH OW1 IH0 K S',\n", 820 | " '',\n", 821 | " 'F EY1 G AH0 N',\n", 822 | " 'F AE1 G AH0 N',\n", 823 | " '',\n", 824 | " 'P IH1 S ER0 EH0 K',\n", 825 | " 'P IH0 S AA1 R EH0 K',\n", 826 | " '',\n", 827 | " 'R IY1 D ER0 M AH0 N',\n", 828 | " 'R IY1 D ER0 M AH0 N',\n", 829 | " '',\n", 830 | " 'K AA1 K AH0 T UW2 Z',\n", 831 | " 'K AH0 K EY1 T OW0 Z',\n", 832 | " '',\n", 833 | " 'R IY0 B AH1 F IH0 NG',\n", 834 | " 'R IH0 F AH1 B IH0 NG',\n", 835 | " '',\n", 836 | " 'S AW1 TH D AW2 N',\n", 837 | " 'S AW1 TH D AW2 N',\n", 838 | " '',\n", 839 | " 'B AE1 L AH0 N T R EY2',\n", 840 | " 'B AE1 L AH0 N T R EY2',\n", 841 | " '',\n", 842 | " 'S L OW1 P S',\n", 843 | " 'S L OW1 P S',\n", 844 | " '',\n", 845 | " 'V AE1 N D ER0 V L IY2 T',\n", 846 | " 'V AE1 N D ER0 V L AY2 
T',\n", 847 | " '',\n", 848 | " 'F AY1 R B AA2 M D',\n", 849 | " 'F AY1 ER0 B OW2 M B AH0 JH',\n", 850 | " '',\n", 851 | " 'P AH0 L UW1 T ER0',\n", 852 | " 'P AA1 L AH0 T ER0',\n", 853 | " '',\n", 854 | " 'D AO1 F IH0 NG',\n", 855 | " 'D AO1 F IH0 NG',\n", 856 | " '',\n", 857 | " 'P AE1 L K OW0',\n", 858 | " 'P AE1 L K OW0',\n", 859 | " '',\n", 860 | " 'SH AH0 HH IH1 N IY0 AH0 N',\n", 861 | " 'SH AH0 HH IY1 N IY0 AH0 N',\n", 862 | " '',\n", 863 | " 'K L EH1 N CH',\n", 864 | " 'K L EH1 N CH',\n", 865 | " '',\n", 866 | " 'P AE1 S AH0 B L IY0',\n", 867 | " 'P AE1 S AH0 B L IY0',\n", 868 | " '',\n", 869 | " 'JH AA1 S AH0 L D',\n", 870 | " 'JH AA1 S T AH0 L',\n", 871 | " '',\n", 872 | " 'AA1 M N IH0 B UH2 K',\n", 873 | " 'AA1 M N IH0 B UH2 K',\n", 874 | " '',\n", 875 | " 'K L IH1 N IH0 K S',\n", 876 | " 'K L IH1 N IH0 K S',\n", 877 | " '',\n", 878 | " 'K AA1 N R IY0',\n", 879 | " 'K AA1 N R IY0',\n", 880 | " '',\n", 881 | " 'R AO1 F',\n", 882 | " 'R AO1 F',\n", 883 | " '',\n", 884 | " 'AH0 M AE1 N AH0',\n", 885 | " 'AH0 M AA1 N AH0',\n", 886 | " '']" 887 | ] 888 | }, 889 | "execution_count": 246, 890 | "metadata": {}, 891 | "output_type": "execute_result" 892 | } 893 | ], 894 | "source": [ 895 | "open('result', 'r').read().splitlines()[-100:]" 896 | ] 897 | }, 898 | { 899 | "cell_type": "code", 900 | "execution_count": null, 901 | "metadata": {}, 902 | "outputs": [], 903 | "source": [] 904 | } 905 | ], 906 | "metadata": { 907 | "colab": { 908 | "name": "Seq2seq tutorial with g2p.ipynb", 909 | "provenance": [], 910 | "version": "0.3.2" 911 | }, 912 | "kernelspec": { 913 | "display_name": "Python 3", 914 | "language": "python", 915 | "name": "python3" 916 | }, 917 | "language_info": { 918 | "codemirror_mode": { 919 | "name": "ipython", 920 | "version": 3 921 | }, 922 | "file_extension": ".py", 923 | "mimetype": "text/x-python", 924 | "name": "python", 925 | "nbconvert_exporter": "python", 926 | "pygments_lexer": "ipython3", 927 | "version": "3.7.3" 928 | } 929 | }, 930 | 
"nbformat": 4, 931 | "nbformat_minor": 1 932 | } 933 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NLP Made Easy 2 | 3 | Simple code notes for explaining NLP building blocks 4 | 5 | * [Subword Segmentation Techniques](Subword%20Segmentation%20Techniques.ipynb) 6 | * Let's compare various tokenizers, i.e., nltk, BPE, SentencePiece, and Bert tokenizer. 7 | * [Beam Decoding](Beam%20Decoding.ipynb) 8 | * Beam decoding is essential for seq2seq tasks. But it's notoriously complicated to implement. Here's a relatively easy one, batchfying candidates. 9 | * [How to get the last hidden vector of rnns properly](How%20to%20get%20the%20last%20hidden%20vector%20of%20rnns%20properly.ipynb) 10 | * We'll see how to get the last hidden states of Rnns in Tensorflow and PyTorch. 11 | * [Tensorflow seq2seq template based on the g2p task](Tensorflow%20seq2seq%20template%20based%20on%20g2p.ipynb) 12 | * We'll write a simple template for seq2seq using Tensorflow. For demonstration, we attack the g2p task. G2p is a task of converting graphemes (spelling) to phonemes (pronunciation). It's a very good source for this purpose as it's simple enough for you to up and run. 13 | * [PyTorch seq2seq template based on the g2p task](PyTorch%20seq2seq%20template%20based%20on%20the%20g2p%20task.ipynb) 14 | * We'll write a simple template for seq2seq using PyTorch. For demonstration, we attack the g2p task. G2p is a task of converting graphemes (spelling) to phonemes (pronunciation). It's a very good source for this purpose as it's simple enough for you to up and run. 15 | * [Attention mechanism](Work in progress) 16 | * [POS-tagging with BERT Fine-tuning](Pos-tagging%20with%20Bert%20Fine-tuning.ipynb) 17 | * BERT is known to be good at Sequence tagging tasks like Named Entity Recognition. Let's see if it's true for POS-tagging. 
18 | * [Dropout in a minute](Dropout%20in%20a%20minute.ipynb) 19 | * Dropout is arguably the most popular regularization technique in deep learning. Let's review how it works. 20 | * N-gram LM vs. RNNLM (WIP) 21 | * [Data Augmentation for Quora Question Pairs](Data%20Augmentation%20for%20Quora%20Question%20Pairs.ipynb) 22 | * Let's see whether augmenting the training data is effective for the Quora Question Pairs task. 23 | -------------------------------------------------------------------------------- /Subword Segmentation Techniques.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Let's compare various tokenizers, i.e., nltk, BPE, SentencePiece, and the BERT tokenizer." 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "__author__ = \"kyubyong\"\n", 17 | "__address__ = \"https://github.com/kyubyong/nlp_made_easy\"\n", 18 | "__email__ = \"kbpark.linguist@gmail.com\"" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 257, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "import os, re" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "# Word Tokenizer" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "Before we get to the subword tokenizers, let's first look at a word tokenizer."
42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 2, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "from nltk.tokenize import word_tokenize" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 7, 56 | "metadata": {}, 57 | "outputs": [ 58 | { 59 | "name": "stdout", 60 | "output_type": "stream", 61 | "text": [ 62 | "['There', \"'s\", 'a', 'son-in-law', ',', 'mother-in-law', ',', 'etc', '.']\n" 63 | ] 64 | } 65 | ], 66 | "source": [ 67 | "sent = \"There's a son-in-law, mother-in-law, etc.\"\n", 68 | "tokens = word_tokenize(sent)\n", 69 | "print(tokens)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "How to get the original sentence from the tokens? " 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 8, 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "data": { 86 | "text/plain": [ 87 | "\"There 's a son-in-law , mother-in-law , etc .\"" 88 | ] 89 | }, 90 | "execution_count": 8, 91 | "metadata": {}, 92 | "output_type": "execute_result" 93 | } 94 | ], 95 | "source": [ 96 | "\" \".join(tokens)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 9, 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "data": { 106 | "text/plain": [ 107 | "\"There'sason-in-law,mother-in-law,etc.\"" 108 | ] 109 | }, 110 | "execution_count": 9, 111 | "metadata": {}, 112 | "output_type": "execute_result" 113 | } 114 | ], 115 | "source": [ 116 | "\"\".join(tokens)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "Tricky!" 
124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "# BPE" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": {}, 136 | "source": [ 137 | "Rico Sennrich, Barry Haddow and Alexandra Birch (2016): [Neural Machine Translation of Rare Words with Subword Units](http://www.aclweb.org/anthology/P16-1162) Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (ACL 2016). Berlin, Germany.\n", 138 | "\n", 139 | "\"Byte Pair Encoding (BPE) (Gage, 1994) is a simple data compression technique that iteratively replaces the most frequent pair of bytes in a sequence with a single, unused byte. We adapt this algorithm for word segmentation. Instead of merging frequent pairs of bytes, we merge characters or\n", 140 | "character sequence\"" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "metadata": {}, 146 | "source": [ 147 | "https://github.com/rsennrich/subword-nmt" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 4, 153 | "metadata": {}, 154 | "outputs": [ 155 | { 156 | "data": { 157 | "text/plain": [ 158 | "0" 159 | ] 160 | }, 161 | "execution_count": 4, 162 | "metadata": {}, 163 | "output_type": "execute_result" 164 | } 165 | ], 166 | "source": [ 167 | "os.system(\"pip install subword-nmt\")" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "metadata": {}, 173 | "source": [ 174 | "### Let's play a little with a toy example" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 261, 180 | "metadata": {}, 181 | "outputs": [ 182 | { 183 | "name": "stdout", 184 | "output_type": "stream", 185 | "text": [ 186 | "low\n", 187 | "low\n", 188 | "low\n", 189 | "low\n", 190 | "low\n", 191 | "lower\n", 192 | "lower\n", 193 | "newest\n", 194 | "newest\n", 195 | "newest\n", 196 | "newest\n", 197 | "newest\n", 198 | "newest\n", 199 | "widest\n", 200 | "widest\n", 201 | "widest\n", 202 | "\n" 203 | ] 204 | } 205 
| ], 206 | "source": [ 207 | "# let's create a sample text\n", 208 | "# This is the same example as the one in the above paper.\n", 209 | "text =\"low\\n\"*5 + \"lower\\n\"*2 + \"newest\\n\"*6 + \"widest\\n\"*3\n", 210 | "print(text)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 262, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "with open('toy', 'w') as f:\n", 220 | " f.write(text)" 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "metadata": {}, 226 | "source": [ 227 | "step 1. Learn bpe. \n", 228 | "Process byte pair encoding and generate merge operations, i.e., codes." 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 263, 234 | "metadata": {}, 235 | "outputs": [ 236 | { 237 | "data": { 238 | "text/plain": [ 239 | "0" 240 | ] 241 | }, 242 | "execution_count": 263, 243 | "metadata": {}, 244 | "output_type": "execute_result" 245 | } 246 | ], 247 | "source": [ 248 | "# Note that -s means number of operations\n", 249 | "learn_bpe = \"subword-nmt learn-bpe -s 1 --min-frequency 2 < toy > codes\"\n", 250 | "os.system(learn_bpe)" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 264, 256 | "metadata": {}, 257 | "outputs": [ 258 | { 259 | "name": "stdout", 260 | "output_type": "stream", 261 | "text": [ 262 | "==codes==\n", 263 | "#version: 0.2\n", 264 | "s t\n", 265 | "====\n", 266 | "number of codes: 1\n" 267 | ] 268 | } 269 | ], 270 | "source": [ 271 | "codes = open('codes', 'r').read()\n", 272 | "print(\"==codes==\\n\" + codes + \"====\")\n", 273 | "print(\"number of codes: \", len(codes.splitlines())-1)" 274 | ] 275 | }, 276 | { 277 | "cell_type": "markdown", 278 | "metadata": {}, 279 | "source": [ 280 | "△ `</w>` marks the end of a word." 281 | ] 282 | }, 283 | { 284 | "cell_type": "markdown", 285 | "metadata": {}, 286 | "source": [ 287 | "△ Check the toy sample carefully. The last 9 words end in `st`, which is most frequent."
288 | ] 289 | }, 290 | { 291 | "cell_type": "markdown", 292 | "metadata": {}, 293 | "source": [ 294 | "step 2. Apply bpe. \n", 295 | "Apply codes to the designated file such that the original text is segmented.\n", 296 | "For demo, we apply the codes to the same toy file.\n" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 265, 302 | "metadata": {}, 303 | "outputs": [ 304 | { 305 | "data": { 306 | "text/plain": [ 307 | "0" 308 | ] 309 | }, 310 | "execution_count": 265, 311 | "metadata": {}, 312 | "output_type": "execute_result" 313 | } 314 | ], 315 | "source": [ 316 | "apply_bpe = \"subword-nmt apply-bpe -c codes < toy > bpe\"\n", 317 | "os.system(apply_bpe)" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": 266, 323 | "metadata": {}, 324 | "outputs": [ 325 | { 326 | "name": "stdout", 327 | "output_type": "stream", 328 | "text": [ 329 | "==segmented==\n", 330 | "l@@ o@@ w\n", 331 | "l@@ o@@ w\n", 332 | "l@@ o@@ w\n", 333 | "l@@ o@@ w\n", 334 | "l@@ o@@ w\n", 335 | "l@@ o@@ w@@ e@@ r\n", 336 | "l@@ o@@ w@@ e@@ r\n", 337 | "n@@ e@@ w@@ e@@ st\n", 338 | "n@@ e@@ w@@ e@@ st\n", 339 | "n@@ e@@ w@@ e@@ st\n", 340 | "n@@ e@@ w@@ e@@ st\n", 341 | "n@@ e@@ w@@ e@@ st\n", 342 | "n@@ e@@ w@@ e@@ st\n", 343 | "w@@ i@@ d@@ e@@ st\n", 344 | "w@@ i@@ d@@ e@@ st\n", 345 | "w@@ i@@ d@@ e@@ st\n", 346 | "====\n" 347 | ] 348 | } 349 | ], 350 | "source": [ 351 | "bpe = open('bpe', 'r').read()\n", 352 | "print(\"==segmented==\\n\" + bpe + \"====\")" 353 | ] 354 | }, 355 | { 356 | "cell_type": "markdown", 357 | "metadata": {}, 358 | "source": [ 359 | "△ Note that only `st` is glued." 360 | ] 361 | }, 362 | { 363 | "cell_type": "markdown", 364 | "metadata": {}, 365 | "source": [ 366 | "step 3. Get vocab. \n", 367 | "We get vocabulary from the segmented file." 
368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": 267, 373 | "metadata": {}, 374 | "outputs": [ 375 | { 376 | "data": { 377 | "text/plain": [ 378 | "0" 379 | ] 380 | }, 381 | "execution_count": 267, 382 | "metadata": {}, 383 | "output_type": "execute_result" 384 | } 385 | ], 386 | "source": [ 387 | "get_vocab = \"subword-nmt get-vocab < bpe > vocab\"\n", 388 | "os.system(get_vocab)" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": 268, 394 | "metadata": {}, 395 | "outputs": [ 396 | { 397 | "name": "stdout", 398 | "output_type": "stream", 399 | "text": [ 400 | "==vocab==\n", 401 | "e@@ 17\n", 402 | "w@@ 11\n", 403 | "st 9\n", 404 | "l@@ 7\n", 405 | "o@@ 7\n", 406 | "n@@ 6\n", 407 | "w 5\n", 408 | "i@@ 3\n", 409 | "d@@ 3\n", 410 | "r 2\n", 411 | "====\n", 412 | "number of vocab: 10\n" 413 | ] 414 | } 415 | ], 416 | "source": [ 417 | "vocab = open('vocab', 'r').read()\n", 418 | "print(\"==vocab==\\n\" + vocab + \"====\")\n", 419 | "print(\"number of vocab: \", len(vocab.splitlines()))" 420 | ] 421 | }, 422 | { 423 | "cell_type": "markdown", 424 | "metadata": {}, 425 | "source": [ 426 | "△ Note that the number of codes (=1) is not the same as the vocabulary size (=10)." 427 | ] 428 | }, 429 | { 430 | "cell_type": "markdown", 431 | "metadata": {}, 432 | "source": [ 433 | "### What if we increase the number of operations?"
434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": 269, 439 | "metadata": {}, 440 | "outputs": [ 441 | { 442 | "name": "stdout", 443 | "output_type": "stream", 444 | "text": [ 445 | "==codes==\n", 446 | "#version: 0.2\n", 447 | "s t\n", 448 | "e st\n", 449 | "l o\n", 450 | "w est\n", 451 | "n e\n", 452 | "ne west\n", 453 | "lo w\n", 454 | "w i\n", 455 | "wi d\n", 456 | "wid est\n", 457 | "====\n", 458 | "number of codes: 10\n", 459 | "\n", 460 | "==segmented==\n", 461 | "low\n", 462 | "low\n", 463 | "low\n", 464 | "low\n", 465 | "low\n", 466 | "lo@@ w@@ e@@ r\n", 467 | "lo@@ w@@ e@@ r\n", 468 | "newest\n", 469 | "newest\n", 470 | "newest\n", 471 | "newest\n", 472 | "newest\n", 473 | "newest\n", 474 | "widest\n", 475 | "widest\n", 476 | "widest\n", 477 | "====\n", 478 | "\n", 479 | "==vocab==\n", 480 | "newest 6\n", 481 | "low 5\n", 482 | "widest 3\n", 483 | "lo@@ 2\n", 484 | "w@@ 2\n", 485 | "e@@ 2\n", 486 | "r 2\n", 487 | "====\n", 488 | "number of vocab: 7\n" 489 | ] 490 | } 491 | ], 492 | "source": [ 493 | "learn_bpe = \"subword-nmt learn-bpe -s 10 --min-frequency 2 < toy > codes\"\n", 494 | "os.system(learn_bpe)\n", 495 | "codes = open('codes', 'r').read()\n", 496 | "print(\"==codes==\\n\" + codes + \"====\")\n", 497 | "print(\"number of codes: \", len(codes.splitlines())-1)\n", 498 | "\n", 499 | "apply_bpe = \"subword-nmt apply-bpe -c codes < toy > bpe\"\n", 500 | "os.system(apply_bpe)\n", 501 | "bpe = open('bpe', 'r').read()\n", 502 | "print(\"\\n==segmented==\\n\" + bpe + \"====\")\n", 503 | "\n", 504 | "get_vocab = \"subword-nmt get-vocab < bpe > vocab\"\n", 505 | "os.system(get_vocab)\n", 506 | "vocab = open('vocab', 'r').read()\n", 507 | "print(\"\\n==vocab==\\n\" + vocab + \"====\")\n", 508 | "print(\"number of vocab: \", len(vocab.splitlines()))" 509 | ] 510 | }, 511 | { 512 | "cell_type": "markdown", 513 | "metadata": {}, 514 | "source": [ 515 | "△ As you've seen, if you increase the number of operations, \n", 516 | "words 
should be less segmented,\n", 517 | "and the number of vocabulary should decrease. " 518 | ] 519 | }, 520 | { 521 | "cell_type": "markdown", 522 | "metadata": {}, 523 | "source": [ 524 | "### How to restore the original text from the segmented one?" 525 | ] 526 | }, 527 | { 528 | "cell_type": "code", 529 | "execution_count": 270, 530 | "metadata": {}, 531 | "outputs": [ 532 | { 533 | "name": "stdout", 534 | "output_type": "stream", 535 | "text": [ 536 | "low\n", 537 | "low\n", 538 | "low\n", 539 | "low\n", 540 | "low\n", 541 | "lower\n", 542 | "lower\n", 543 | "newest\n", 544 | "newest\n", 545 | "newest\n", 546 | "newest\n", 547 | "newest\n", 548 | "newest\n", 549 | "widest\n", 550 | "widest\n", 551 | "widest\n", 552 | "\n" 553 | ] 554 | } 555 | ], 556 | "source": [ 557 | "restored = re.sub(\"@@( |$)\", \"\", bpe)\n", 558 | "print(restored)" 559 | ] 560 | }, 561 | { 562 | "cell_type": "markdown", 563 | "metadata": {}, 564 | "source": [ 565 | "### How to restrict vocabulary?" 566 | ] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "execution_count": 271, 571 | "metadata": {}, 572 | "outputs": [ 573 | { 574 | "data": { 575 | "text/plain": [ 576 | "0" 577 | ] 578 | }, 579 | "execution_count": 271, 580 | "metadata": {}, 581 | "output_type": "execute_result" 582 | } 583 | ], 584 | "source": [ 585 | "reapply_bpe = \"subword-nmt apply-bpe -c codes --vocabulary vocab --vocabulary-threshold 5 < toy > bpe2\"\n", 586 | "os.system(reapply_bpe)" 587 | ] 588 | }, 589 | { 590 | "cell_type": "code", 591 | "execution_count": 272, 592 | "metadata": {}, 593 | "outputs": [ 594 | { 595 | "name": "stdout", 596 | "output_type": "stream", 597 | "text": [ 598 | "low\n", 599 | "low\n", 600 | "low\n", 601 | "low\n", 602 | "low\n", 603 | "l@@ o@@ w@@ e@@ r\n", 604 | "l@@ o@@ w@@ e@@ r\n", 605 | "newest\n", 606 | "newest\n", 607 | "newest\n", 608 | "newest\n", 609 | "newest\n", 610 | "newest\n", 611 | "w@@ i@@ d@@ e@@ s@@ t\n", 612 | "w@@ i@@ d@@ e@@ s@@ t\n", 613 | "w@@ i@@ d@@ e@@ s@@ 
t\n", 614 | "\n" 615 | ] 616 | } 617 | ], 618 | "source": [ 619 | "bpe2 = open('bpe2', 'r').read()\n", 620 | "print(bpe2)" 621 | ] 622 | }, 623 | { 624 | "cell_type": "code", 625 | "execution_count": 273, 626 | "metadata": {}, 627 | "outputs": [ 628 | { 629 | "name": "stdout", 630 | "output_type": "stream", 631 | "text": [ 632 | "low\n", 633 | "low\n", 634 | "low\n", 635 | "low\n", 636 | "low\n", 637 | "lo@@ w@@ e@@ r\n", 638 | "lo@@ w@@ e@@ r\n", 639 | "newest\n", 640 | "newest\n", 641 | "newest\n", 642 | "newest\n", 643 | "newest\n", 644 | "newest\n", 645 | "widest\n", 646 | "widest\n", 647 | "widest\n", 648 | "\n" 649 | ] 650 | } 651 | ], 652 | "source": [ 653 | "# To compare with the original bpe segmented result, print it again.\n", 654 | "print(bpe)" 655 | ] 656 | }, 657 | { 658 | "cell_type": "markdown", 659 | "metadata": {}, 660 | "source": [ 661 | "△ `widest`, which was not segmented, is segmented into `w@@ i@@ d@@ e@@ s@@ t` because the frequency of `widest` was less than 5." 662 | ] 663 | }, 664 | { 665 | "cell_type": "markdown", 666 | "metadata": {}, 667 | "source": [ 668 | "Be careful that the original vocabulary or thresholded one doesn't hold any more. We need to get the final vocabulary now." 
669 | ] 670 | }, 671 | { 672 | "cell_type": "code", 673 | "execution_count": 274, 674 | "metadata": {}, 675 | "outputs": [ 676 | { 677 | "data": { 678 | "text/plain": [ 679 | "0" 680 | ] 681 | }, 682 | "execution_count": 274, 683 | "metadata": {}, 684 | "output_type": "execute_result" 685 | } 686 | ], 687 | "source": [ 688 | "get_vocab = \"subword-nmt get-vocab < bpe2 > vocab2\"\n", 689 | "os.system(get_vocab)" 690 | ] 691 | }, 692 | { 693 | "cell_type": "code", 694 | "execution_count": 275, 695 | "metadata": {}, 696 | "outputs": [ 697 | { 698 | "name": "stdout", 699 | "output_type": "stream", 700 | "text": [ 701 | "newest 6\n", 702 | "low 5\n", 703 | "w@@ 5\n", 704 | "e@@ 5\n", 705 | "i@@ 3\n", 706 | "d@@ 3\n", 707 | "s@@ 3\n", 708 | "t 3\n", 709 | "l@@ 2\n", 710 | "o@@ 2\n", 711 | "r 2\n", 712 | "\n" 713 | ] 714 | } 715 | ], 716 | "source": [ 717 | "vocab2 = open(\"vocab2\", 'r').read()\n", 718 | "print(vocab2)" 719 | ] 720 | }, 721 | { 722 | "cell_type": "markdown", 723 | "metadata": {}, 724 | "source": [ 725 | "### Let's test with a bigger text." 726 | ] 727 | }, 728 | { 729 | "cell_type": "markdown", 730 | "metadata": {}, 731 | "source": [ 732 | "Download a sample file for demonstration from subword-nmt." 
733 | ] 734 | }, 735 | { 736 | "cell_type": "code", 737 | "execution_count": 276, 738 | "metadata": {}, 739 | "outputs": [ 740 | { 741 | "data": { 742 | "text/plain": [ 743 | "0" 744 | ] 745 | }, 746 | "execution_count": 276, 747 | "metadata": {}, 748 | "output_type": "execute_result" 749 | } 750 | ], 751 | "source": [ 752 | "download = \"wget https://github.com/rsennrich/subword-nmt/raw/master/subword_nmt/tests/data/corpus.en\"\n", 753 | "os.system(download)" 754 | ] 755 | }, 756 | { 757 | "cell_type": "code", 758 | "execution_count": 213, 759 | "metadata": {}, 760 | "outputs": [ 761 | { 762 | "name": "stdout", 763 | "output_type": "stream", 764 | "text": [ 765 | "iron cement is a ready for use paste which is laid as a fillet by putty knife or finger in the mould edges ( corners ) of the steel ingot mould .\n", 766 | "iron cement protects the ingot against the hot , abrasive steel casting process .\n", 767 | "a fire restant repair cement for fire places , ovens , open fireplaces etc .\n", 768 | "construction and repair of highways and ...\n", 769 | "an announcement must be commercial character .\n", 770 | "goods and services advancement through the P.O.Box system is NOT ALLOWED .\n", 771 | "deliveries ( spam ) and other improper information deleted .\n", 772 | "translator Internet is a Toolbar for MS Internet Explorer .\n", 773 | "it allows you to translate in real time any web pasge from one language to another .\n", 774 | "you only have to select languages and TI does all the work for you ! 
automatic dictionary updates ....\n", 775 | "this software is written in order to increase your English keyboard typing speed , through teaching the basics of how to put your hand on to the keyboard and give some training examples .\n", 776 | "each lesson teaches some extra k\n" 777 | ] 778 | } 779 | ], 780 | "source": [ 781 | "print(open('corpus.en', 'r').read()[:1000])" 782 | ] 783 | }, 784 | { 785 | "cell_type": "code", 786 | "execution_count": 217, 787 | "metadata": {}, 788 | "outputs": [ 789 | { 790 | "name": "stdout", 791 | "output_type": "stream", 792 | "text": [ 793 | "==codes==\n", 794 | "#version: 0.2\n", 795 | "t h\n", 796 | "th e\n", 797 | "i n\n", 798 | "a n\n", 799 | "e r\n", 800 | "r e\n", 801 | "o r\n", 802 | "a r\n", 803 | "t i\n", 804 | "an d\n", 805 | "o f\n", 806 | "e n\n", 807 | "o u\n", 808 | "o n\n", 809 | "t o\n", 810 | "o n\n", 811 | "====\n", 812 | "number of codes: 1000\n", 813 | "\n", 814 | "==segmented==\n", 815 | "ir@@ on c@@ ement is a read@@ y for use pa@@ st@@ e which is la@@ id as a fil@@ let by pu@@ t@@ ty k====\n", 816 | "\n", 817 | "==vocab==\n", 818 | "the 1358\n", 819 | ", 1291\n", 820 | ". 
968\n", 821 | "and 663\n", 822 | "of 651\n", 823 | "a 623\n", 824 | "in 506\n", 825 | "to 490\n", 826 | "is 351\n", 827 | "ed 279\n", 828 | "s@@ 258\n", 829 | "c@@ 254\n", 830 | "you 253\n", 831 | "for 2====\n", 832 | "number of vocab: 1120\n" 833 | ] 834 | } 835 | ], 836 | "source": [ 837 | "learn_bpe = \"subword-nmt learn-bpe -s 1000 --min-frequency 2 < corpus.en > codes\"\n", 838 | "os.system(learn_bpe)\n", 839 | "codes = open('codes', 'r').read()\n", 840 | "print(\"==codes==\\n\" + codes[:100] + \"====\")\n", 841 | "print(\"number of codes: \", len(codes.splitlines())-1)\n", 842 | "\n", 843 | "apply_bpe = \"subword-nmt apply-bpe -c codes < corpus.en > bpe\"\n", 844 | "os.system(apply_bpe)\n", 845 | "bpe = open('bpe', 'r').read()\n", 846 | "print(\"\\n==segmented==\\n\" + bpe[:100] + \"====\")\n", 847 | "\n", 848 | "get_vocab = \"subword-nmt get-vocab < bpe > vocab\"\n", 849 | "os.system(get_vocab)\n", 850 | "vocab = open('vocab', 'r').read()\n", 851 | "print(\"\\n==vocab==\\n\" + vocab[:100] + \"====\")\n", 852 | "print(\"number of vocab: \", len(vocab.splitlines()))" 853 | ] 854 | }, 855 | { 856 | "cell_type": "markdown", 857 | "metadata": {}, 858 | "source": [ 859 | "# (BPE in) SentencePiece" 860 | ] 861 | }, 862 | { 863 | "cell_type": "code", 864 | "execution_count": 104, 865 | "metadata": {}, 866 | "outputs": [ 867 | { 868 | "data": { 869 | "text/plain": [ 870 | "0" 871 | ] 872 | }, 873 | "execution_count": 104, 874 | "metadata": {}, 875 | "output_type": "execute_result" 876 | } 877 | ], 878 | "source": [ 879 | "os.system(\"pip install sentencepiece\")" 880 | ] 881 | }, 882 | { 883 | "cell_type": "code", 884 | "execution_count": 106, 885 | "metadata": {}, 886 | "outputs": [], 887 | "source": [ 888 | "import sentencepiece as spm" 889 | ] 890 | }, 891 | { 892 | "cell_type": "markdown", 893 | "metadata": {}, 894 | "source": [ 895 | "step 1. Train. \n", 896 | "This should generate `m.model` and `m.vocab`. This is analogous to the `learn bpe` in `subword-nmt`. 
However, unlike `subword-nmt`, vocabulary, not merge operations, is fixed." 897 | ] 898 | }, 899 | { 900 | "cell_type": "code", 901 | "execution_count": 238, 902 | "metadata": {}, 903 | "outputs": [ 904 | { 905 | "data": { 906 | "text/plain": [ 907 | "True" 908 | ] 909 | }, 910 | "execution_count": 238, 911 | "metadata": {}, 912 | "output_type": "execute_result" 913 | } 914 | ], 915 | "source": [ 916 | "train = '--input=corpus.en --model_prefix=m --vocab_size=1000 --model_type=bpe'\n", 917 | "spm.SentencePieceTrainer.Train(train)" 918 | ] 919 | }, 920 | { 921 | "cell_type": "markdown", 922 | "metadata": {}, 923 | "source": [ 924 | "Check the vocab file." 925 | ] 926 | }, 927 | { 928 | "cell_type": "code", 929 | "execution_count": 237, 930 | "metadata": {}, 931 | "outputs": [ 932 | { 933 | "name": "stdout", 934 | "output_type": "stream", 935 | "text": [ 936 | "\n", 937 | "==vocab==\n", 938 | "\t0\n", 939 | "\t0\n", 940 | "\t0\n", 941 | "▁t\t-0\n", 942 | "▁a\t-1\n", 943 | "▁th\t-2\n", 944 | "in\t-3\n", 945 | "▁the\t-4\n", 946 | "er\t-5\n", 947 | "▁o\t-6\n", 948 | "re\t-7\n", 949 | "▁,\t-8\n", 950 | "▁s\t-9\n", 951 | "at\t-10\n", 952 | "nd\t-11\n", 953 | "▁.\n", 954 | "====\n", 955 | "number of vocab: 1000\n" 956 | ] 957 | } 958 | ], 959 | "source": [ 960 | "vocab = open('m.vocab', 'r').read()\n", 961 | "print(\"\\n==vocab==\\n\" + vocab[:100] + \"\\n====\")\n", 962 | "print(\"number of vocab: \", len(vocab.splitlines()))" 963 | ] 964 | }, 965 | { 966 | "cell_type": "markdown", 967 | "metadata": {}, 968 | "source": [ 969 | "△ ▁, which means a space, precedes other characters." 970 | ] 971 | }, 972 | { 973 | "cell_type": "markdown", 974 | "metadata": {}, 975 | "source": [ 976 | "step 2. Encode. \n", 977 | "First load the trained model and segment the designated text file so that all the pieces in the vocabulary should be generated." 
978 | ] 979 | }, 980 | { 981 | "cell_type": "code", 982 | "execution_count": 285, 983 | "metadata": {}, 984 | "outputs": [ 985 | { 986 | "name": "stdout", 987 | "output_type": "stream", 988 | "text": [ 989 | "▁ ir on ▁c e ment ▁is ▁a ▁read y ▁for ▁use ▁p ast e ▁which ▁is ▁la id ▁as ▁a ▁f ill et ▁by ▁p ut t y ▁kn ife ▁or ▁f ing er ▁in ▁the ▁m ould ▁ ed g es ▁( ▁cor n ers ▁) ▁of ▁the ▁st e el ▁in g ot ▁m ould ▁. ▁ ir on ▁c e ment ▁pr ot ect s ▁the ▁in g ot ▁ag ain st ▁the ▁hot ▁, ▁ab r as ive ▁st e el ▁c ast ing ▁process ▁. ▁a ▁f ire ▁rest ant ▁rep a ir ▁c\n", 990 | "[923, 92, 20, 18, 924, 115, 55, 4, 596, 940, 59, 362, 28, 202, 924, 173, 55, 431, 112, 97, 4, 22, 126, 58, 158, 28, 61, 925, 940, 353, 654, 119, 22, 31, 8, 30, 7, 33, 204, 923, 27, 941, 26, 146, 888, 929, 102, 150, 29, 7, 124, 924, 60, 30, 941, 46, 33, 204, 15, 923, 92, 20, 18, 924, 115, 94, 46, 186, 930, 7, 30, 941, 46, 586, 138, 73, 7, 834, 11, 279, 931, 47, 196, 124, 924, 60, 18, 202, 31, 813, 15, 4, 22, 441, 621, 192, 449, 927, 92, 18]\n" 991 | ] 992 | } 993 | ], 994 | "source": [ 995 | "# Load model\n", 996 | "sp = spm.SentencePieceProcessor()\n", 997 | "sp.Load(\"m.model\")\n", 998 | "\n", 999 | "# Segment\n", 1000 | "input_text = open('corpus.en', 'r').read()\n", 1001 | "pieces = sp.EncodeAsPieces(input_text)\n", 1002 | "ids = sp.EncodeAsIds(input_text)\n", 1003 | "print(\" \".join(pieces[:100]))\n", 1004 | "print(ids[:100])" 1005 | ] 1006 | }, 1007 | { 1008 | "cell_type": "markdown", 1009 | "metadata": {}, 1010 | "source": [ 1011 | "### How to restore?" 1012 | ] 1013 | }, 1014 | { 1015 | "cell_type": "code", 1016 | "execution_count": 233, 1017 | "metadata": {}, 1018 | "outputs": [ 1019 | { 1020 | "data": { 1021 | "text/plain": [ 1022 | "'iron cement is a ready for use paste which is laid as a fillet by putty knife or finger in the mould edges ( corners ) of the steel ingot mould . iron cement protects the ingot against the hot , abrasive steel casting process . 
a fire restant repair c'" 1023 | ] 1024 | }, 1025 | "execution_count": 233, 1026 | "metadata": {}, 1027 | "output_type": "execute_result" 1028 | } 1029 | ], 1030 | "source": [ 1031 | "sp.DecodePieces(pieces[:100])" 1032 | ] 1033 | }, 1034 | { 1035 | "cell_type": "code", 1036 | "execution_count": 239, 1037 | "metadata": { 1038 | "scrolled": false 1039 | }, 1040 | "outputs": [ 1041 | { 1042 | "data": { 1043 | "text/plain": [ 1044 | "'iron cement is a ready for use paste which is laid as a fillet by putty knife or finger in the mould edges ( corners ) of the steel ingot mould . iron cement protects the ingot against the hot , abrasive steel casting process . a fire restant repair c'" 1045 | ] 1046 | }, 1047 | "execution_count": 239, 1048 | "metadata": {}, 1049 | "output_type": "execute_result" 1050 | } 1051 | ], 1052 | "source": [ 1053 | "sp.DecodeIds(ids[:100])" 1054 | ] 1055 | }, 1056 | { 1057 | "cell_type": "markdown", 1058 | "metadata": {}, 1059 | "source": [ 1060 | "# Bert Tokenizer" 1061 | ] 1062 | }, 1063 | { 1064 | "cell_type": "code", 1065 | "execution_count": 241, 1066 | "metadata": {}, 1067 | "outputs": [ 1068 | { 1069 | "data": { 1070 | "text/plain": [ 1071 | "0" 1072 | ] 1073 | }, 1074 | "execution_count": 241, 1075 | "metadata": {}, 1076 | "output_type": "execute_result" 1077 | } 1078 | ], 1079 | "source": [ 1080 | "os.system(\"pip install pytorch_pretrained_bert\")" 1081 | ] 1082 | }, 1083 | { 1084 | "cell_type": "code", 1085 | "execution_count": 242, 1086 | "metadata": {}, 1087 | "outputs": [ 1088 | { 1089 | "name": "stdout", 1090 | "output_type": "stream", 1091 | "text": [ 1092 | "Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.\n" 1093 | ] 1094 | } 1095 | ], 1096 | "source": [ 1097 | "from pytorch_pretrained_bert import BertTokenizer, BertModel\n", 1098 | "tokenizer = BertTokenizer.from_pretrained(\"bert-base-cased\")" 1099 | ] 1100 | }, 1101 | { 1102 | "cell_type": "code", 1103 | "execution_count": 287, 1104 
| "metadata": {}, 1105 | "outputs": [], 1106 | "source": [ 1107 | "input_text = open(\"corpus.en\", \"r\").read()\n", 1108 | "pieces = tokenizer.tokenize(input_text)" 1109 | ] 1110 | }, 1111 | { 1112 | "cell_type": "code", 1113 | "execution_count": 288, 1114 | "metadata": {}, 1115 | "outputs": [ 1116 | { 1117 | "data": { 1118 | "text/plain": [ 1119 | "'iron cement is a ready for use paste which is laid as a fill ##et by put ##ty knife or'" 1120 | ] 1121 | }, 1122 | "execution_count": 288, 1123 | "metadata": {}, 1124 | "output_type": "execute_result" 1125 | } 1126 | ], 1127 | "source": [ 1128 | "\" \".join(pieces[:20])" 1129 | ] 1130 | }, 1131 | { 1132 | "cell_type": "markdown", 1133 | "metadata": {}, 1134 | "source": [ 1135 | "Bert Tokenizer is composed of Basic Tokenizer, which splits punctuations, and WordPiece Tokenizer. That can be a problem if you want to restore the original text. https://github.com/huggingface/pytorch-pretrained-BERT/issues/36\n" 1136 | ] 1137 | }, 1138 | { 1139 | "cell_type": "markdown", 1140 | "metadata": {}, 1141 | "source": [ 1142 | "Bert Tokenizer uses ##. It is different from @@ in BPE or ▁ in SentencePiece. \n", 1143 | "@@ is attached to the end of subwords, while ## and ▁ is to the front. \n", 1144 | "`@@ + space` and `space + ##` are removed for restoration, while ▁ is replaced by a space." 
1145 | ] 1146 | }, 1147 | { 1148 | "cell_type": "code", 1149 | "execution_count": null, 1150 | "metadata": {}, 1151 | "outputs": [], 1152 | "source": [] 1153 | } 1154 | ], 1155 | "metadata": { 1156 | "kernelspec": { 1157 | "display_name": "Python 3", 1158 | "language": "python", 1159 | "name": "python3" 1160 | }, 1161 | "language_info": { 1162 | "codemirror_mode": { 1163 | "name": "ipython", 1164 | "version": 3 1165 | }, 1166 | "file_extension": ".py", 1167 | "mimetype": "text/x-python", 1168 | "name": "python", 1169 | "nbconvert_exporter": "python", 1170 | "pygments_lexer": "ipython3", 1171 | "version": "3.7.1" 1172 | } 1173 | }, 1174 | "nbformat": 4, 1175 | "nbformat_minor": 2 1176 | } 1177 | -------------------------------------------------------------------------------- /Tensorflow seq2seq template based on g2p.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "We'll write a simple template for seq2seq using Tensorflow. For demonstration, we attack the g2p task. G2p is a task of converting graphemes (spelling) to phonemes (pronunciation). It's a very good source for this purpose as it's simple enough for you to get up and running.
If you want to know more about g2p, see my [repo](https://github.com/kyubyong/g2p)" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "__author__ = \"kyubyong\"\n", 17 | "__address__ = \"https://github.com/kyubyong/nlp_made_easy\"\n", 18 | "__email__ = \"kbpark.linguist@gmail.com\"" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "metadata": { 25 | "colab": {}, 26 | "colab_type": "code", 27 | "id": "TEXqpZ_U738q" 28 | }, 29 | "outputs": [], 30 | "source": [ 31 | "import numpy as np\n", 32 | "import tensorflow as tf\n", 33 | "from tqdm import tqdm_notebook as tqdm\n", 34 | "from distance import levenshtein\n", 35 | "import os\n", 36 | "import math" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 3, 42 | "metadata": { 43 | "colab": { 44 | "base_uri": "https://localhost:8080/", 45 | "height": 34 46 | }, 47 | "colab_type": "code", 48 | "id": "_7vuctbU7381", 49 | "outputId": "f8ee2cbf-1f04-432f-ba42-d25fec61669b" 50 | }, 51 | "outputs": [ 52 | { 53 | "data": { 54 | "text/plain": [ 55 | "'1.12.0'" 56 | ] 57 | }, 58 | "execution_count": 3, 59 | "metadata": {}, 60 | "output_type": "execute_result" 61 | } 62 | ], 63 | "source": [ 64 | "tf.__version__" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": { 70 | "colab_type": "text", 71 | "id": "B6te4HKk738_" 72 | }, 73 | "source": [ 74 | "# Hyperparameters" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 65, 80 | "metadata": { 81 | "colab": {}, 82 | "colab_type": "code", 83 | "id": "CWS2hkce739C" 84 | }, 85 | "outputs": [], 86 | "source": [ 87 | "class Hparams:\n", 88 | " batch_size = 128\n", 89 | " enc_maxlen = 20\n", 90 | " dec_maxlen = 20\n", 91 | " num_epochs = 10\n", 92 | " hidden_units = 128\n", 93 | " graphemes = [\"\", \"\", \"\"] + list(\"abcdefghijklmnopqrstuvwxyz\")\n", 94 | " phonemes = [\"\", \"\", \"\", \"\"] + ['AA0', 'AA1', 'AA2', 
'AE0', 'AE1', 'AE2', 'AH0', 'AH1', 'AH2', 'AO0',\n", 95 | " 'AO1', 'AO2', 'AW0', 'AW1', 'AW2', 'AY0', 'AY1', 'AY2', 'B', 'CH', 'D', 'DH',\n", 96 | " 'EH0', 'EH1', 'EH2', 'ER0', 'ER1', 'ER2', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH',\n", 97 | " 'IH0', 'IH1', 'IH2', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW0', 'OW1',\n", 98 | " 'OW2', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH0', 'UH1', 'UH2', 'UW',\n", 99 | " 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH']\n", 100 | " lr = 0.001\n", 101 | " eval_steps = 500\n", 102 | " logdir = \"log/04\"\n", 103 | "hp = Hparams()" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": { 109 | "colab_type": "text", 110 | "id": "nz-hD6dn739L" 111 | }, 112 | "source": [ 113 | "# Prepare Data" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 66, 119 | "metadata": { 120 | "colab": {}, 121 | "colab_type": "code", 122 | "id": "-as4PHs-739N" 123 | }, 124 | "outputs": [], 125 | "source": [ 126 | "import nltk\n", 127 | "# nltk.download('cmudict')# <- if you haven't downloaded, do this.\n", 128 | "from nltk.corpus import cmudict\n", 129 | "cmu = cmudict.dict()" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 67, 135 | "metadata": { 136 | "colab": {}, 137 | "colab_type": "code", 138 | "id": "39gQ3vOi739S" 139 | }, 140 | "outputs": [], 141 | "source": [ 142 | "def load_vocab():\n", 143 | " g2idx = {g: idx for idx, g in enumerate(hp.graphemes)}\n", 144 | " idx2g = {idx: g for idx, g in enumerate(hp.graphemes)}\n", 145 | "\n", 146 | " p2idx = {p: idx for idx, p in enumerate(hp.phonemes)}\n", 147 | " idx2p = {idx: p for idx, p in enumerate(hp.phonemes)}\n", 148 | "\n", 149 | " return g2idx, idx2g, p2idx, idx2p # note that g and p mean grapheme and phoneme, respectively." 
150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 68, 155 | "metadata": { 156 | "colab": {}, 157 | "colab_type": "code", 158 | "id": "Zslytxn6739Z" 159 | }, 160 | "outputs": [], 161 | "source": [ 162 | "def prepare_data():\n", 163 | " words = [\" \".join(list(word)) for word, prons in cmu.items()]\n", 164 | " prons = [\" \".join(prons[0]) for word, prons in cmu.items()]\n", 165 | " indices = list(range(len(words)))\n", 166 | " from random import shuffle\n", 167 | " shuffle(indices)\n", 168 | " words = [words[idx] for idx in indices]\n", 169 | " prons = [prons[idx] for idx in indices]\n", 170 | " num_train, num_test = int(len(words)*.8), int(len(words)*.1)\n", 171 | " train_words, eval_words, test_words = words[:num_train], \\\n", 172 | " words[num_train:-num_test],\\\n", 173 | " words[-num_test:]\n", 174 | " train_prons, eval_prons, test_prons = prons[:num_train], \\\n", 175 | " prons[num_train:-num_test],\\\n", 176 | " prons[-num_test:] \n", 177 | " return train_words, eval_words, test_words, train_prons, eval_prons, test_prons" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 69, 183 | "metadata": { 184 | "colab": {}, 185 | "colab_type": "code", 186 | "id": "WHBXkAPG739j" 187 | }, 188 | "outputs": [ 189 | { 190 | "name": "stdout", 191 | "output_type": "stream", 192 | "text": [ 193 | "q u a l i t a t i v e\n", 194 | "K W AA1 L AH0 T EY2 T IH0 V\n" 195 | ] 196 | } 197 | ], 198 | "source": [ 199 | "train_words, eval_words, test_words, train_prons, eval_prons, test_prons = prepare_data()\n", 200 | "print(train_words[0])\n", 201 | "print(train_prons[0])" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 70, 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "def drop_lengthy_samples(words, prons, enc_maxlen, dec_maxlen):\n", 211 | " \"\"\"We only include such samples less than maxlen.\"\"\"\n", 212 | " _words, _prons = [], []\n", 213 | " for w, p in zip(words, prons):\n", 
214 | " if len(w.split()) + 1 > enc_maxlen: continue\n", 215 | " if len(p.split()) + 1 > dec_maxlen: continue # 1: \n", 216 | " _words.append(w)\n", 217 | " _prons.append(p)\n", 218 | " return _words, _prons" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 71, 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "train_words, train_prons = drop_lengthy_samples(train_words, train_prons, hp.enc_maxlen, hp.dec_maxlen)\n", 228 | "# We do NOT apply this constraint to eval and test datasets." 229 | ] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "metadata": { 234 | "colab_type": "text", 235 | "id": "KfHMTzeH7394" 236 | }, 237 | "source": [ 238 | "# Data Loader" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 72, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "def encode(inp, type, dict):\n", 248 | " '''type: \"x\" or \"y\"'''\n", 249 | " inp_str = inp.decode(\"utf-8\")\n", 250 | " if type==\"x\": tokens = inp_str.split() + [\"\"]\n", 251 | " else: tokens = [\"\"] + inp_str.split() + [\"\"]\n", 252 | "\n", 253 | " x = [dict.get(t, dict[\"\"]) for t in tokens]\n", 254 | " return x\n", 255 | " " 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 73, 261 | "metadata": { 262 | "colab": {}, 263 | "colab_type": "code", 264 | "id": "G6jSAgus7399" 265 | }, 266 | "outputs": [], 267 | "source": [ 268 | "def generator_fn(words, prons):\n", 269 | " '''\n", 270 | " words: 1d byte array. e.g., [b\"w o r d\", ]\n", 271 | " prons: 1d byte array. e.g., [b'W ER1 D', ]\n", 272 | " \n", 273 | " yields\n", 274 | " xs: tuple of\n", 275 | " x: list of encoded x. encoder input\n", 276 | " x_seqlen: scalar.\n", 277 | " word: string\n", 278 | " \n", 279 | " ys: tuple of\n", 280 | " decoder_input: list of decoder inputs\n", 281 | " y: list of encoded y. 
label.\n", 282 | " y_seqlen: scalar.\n", 283 | " pron: string\n", 284 | " '''\n", 285 | " g2idx, idx2g, p2idx, idx2p = load_vocab()\n", 286 | " for word, pron in zip(words, prons):\n", 287 | " x = encode(word, \"x\", g2idx)\n", 288 | " y = encode(pron, \"y\", p2idx)\n", 289 | " decoder_input, y = y[:-1], y[1:]\n", 290 | "\n", 291 | " x_seqlen, y_seqlen = len(x), len(y)\n", 292 | " yield (x, x_seqlen, word), (decoder_input, y, y_seqlen, pron)\n", 293 | " " 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 74, 299 | "metadata": { 300 | "colab": {}, 301 | "colab_type": "code", 302 | "id": "QEvz4YTR73-I" 303 | }, 304 | "outputs": [], 305 | "source": [ 306 | "def input_fn(words, prons, batch_size, shuffle=False):\n", 307 | " '''Batchify data\n", 308 | " words: list of words. e.g., [\"word\", ]\n", 309 | " prons: list of prons. e.g., ['W ER1 D',]\n", 310 | " batch_size: scalar.\n", 311 | " shuffle: boolean\n", 312 | " '''\n", 313 | " shapes = ( ([None], (), ()),\n", 314 | " ([None], [None], (), ()) )\n", 315 | " types = ( (tf.int32, tf.int32, tf.string),\n", 316 | " (tf.int32, tf.int32, tf.int32, tf.string) )\n", 317 | " paddings = ( (0, 0, ''),\n", 318 | " (0, 0, 0, '') )\n", 319 | "\n", 320 | " dataset = tf.data.Dataset.from_generator(\n", 321 | " generator_fn,\n", 322 | " output_shapes=shapes,\n", 323 | " output_types=types,\n", 324 | " args=(words, prons)) # <- converted to np string arrays\n", 325 | " \n", 326 | " if shuffle:\n", 327 | " dataset = dataset.shuffle(128*batch_size) \n", 328 | " dataset = dataset.repeat() # iterate forever\n", 329 | " dataset = dataset.padded_batch(batch_size, shapes, paddings).prefetch(1)\n", 330 | "\n", 331 | " return dataset" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 75, 337 | "metadata": {}, 338 | "outputs": [], 339 | "source": [ 340 | "def get_batch(words, prons, batch_size, shuffle=False):\n", 341 | " '''Gets training / evaluation mini-batches\n", 342 | " fpath1: source 
file path. string.\n", 343 | " fpath2: target file path. string.\n", 344 | " maxlen1: source sent maximum length. scalar.\n", 345 | " maxlen2: target sent maximum length. scalar.\n", 346 | " vocab_fpath: string. vocabulary file path.\n", 347 | " batch_size: scalar\n", 348 | " shuffle: boolean\n", 349 | "\n", 350 | " Returns\n", 351 | " batches\n", 352 | " num_batches: number of mini-batches\n", 353 | " num_samples\n", 354 | " '''\n", 355 | " batches = input_fn(words, prons, batch_size, shuffle=shuffle)\n", 356 | " num_batches = calc_num_batches(len(words), batch_size)\n", 357 | " return batches, num_batches, len(words)\n" 358 | ] 359 | }, 360 | { 361 | "cell_type": "markdown", 362 | "metadata": { 363 | "colab_type": "text", 364 | "id": "22mif4xf73-M" 365 | }, 366 | "source": [ 367 | "# Model" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": 76, 373 | "metadata": {}, 374 | "outputs": [], 375 | "source": [ 376 | "def convert_idx_to_token_tensor(inputs, idx2token):\n", 377 | " '''Converts int32 tensor to string tensor.\n", 378 | " inputs: 1d int32 tensor. indices.\n", 379 | " idx2token: dictionary\n", 380 | "\n", 381 | " Returns\n", 382 | " 1d string tensor.\n", 383 | " '''\n", 384 | " def my_func(inputs):\n", 385 | " return \" \".join(idx2token[elem] for elem in inputs)\n", 386 | "\n", 387 | " return tf.py_func(my_func, [inputs], tf.string)\n" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": 77, 393 | "metadata": { 394 | "colab": {}, 395 | "colab_type": "code", 396 | "id": "HA39FU4-73-O" 397 | }, 398 | "outputs": [], 399 | "source": [ 400 | "class Net:\n", 401 | " def __init__(self, hp):\n", 402 | " self.g2idx, self.idx2g, self.p2idx, self.idx2p = load_vocab()\n", 403 | " self.hp = hp\n", 404 | " \n", 405 | " def encode(self, xs):\n", 406 | " '''\n", 407 | " xs: tupple of \n", 408 | " x: (N, T). int32\n", 409 | " seqlens: (N,). int32\n", 410 | " words: (N,). 
string\n", 411 | " \n", 412 | " returns\n", 413 | " last hidden: (N, hidden_units). float32\n", 414 | " words: (N,). string\n", 415 | " '''\n", 416 | " with tf.variable_scope(\"encode\", reuse=tf.AUTO_REUSE):\n", 417 | " x, seqlens, words = xs\n", 418 | " x = tf.one_hot(x, len(self.g2idx))\n", 419 | " cell = tf.contrib.rnn.GRUCell(self.hp.hidden_units)\n", 420 | " _, last_hidden = tf.nn.dynamic_rnn(cell, x, seqlens, dtype=tf.float32)\n", 421 | " \n", 422 | " return last_hidden, words\n", 423 | " \n", 424 | " \n", 425 | " def decode(self, ys, h0=None):\n", 426 | " '''\n", 427 | " ys: tupple of \n", 428 | " decoder_inputs: (N, T). int32\n", 429 | " y: (N, T). int32\n", 430 | " seqlens: (N,). int32\n", 431 | " prons: (N,). string.\n", 432 | " h0: initial hidden state. (N, hidden_units)\n", 433 | " \n", 434 | " returns\n", 435 | " logits: (N, T, len(p2idx)). float32. before softmax\n", 436 | " y_hat: (N, T). int32.\n", 437 | " y: (N, T). int32. label.\n", 438 | " prons: (N,). string. ground truth phonemes \n", 439 | " last_hidden: (N, hidden_units). 
This is for autoregressive inference\n", 440 | " '''\n", 441 | " decoder_inputs, y, seqlens, prons = ys\n", 442 | " \n", 443 | " with tf.variable_scope(\"decode\", reuse=tf.AUTO_REUSE):\n", 444 | " inputs = tf.one_hot(decoder_inputs, len(self.p2idx))\n", 445 | " \n", 446 | " cell = tf.contrib.rnn.GRUCell(self.hp.hidden_units)\n", 447 | " outputs, last_hidden = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0, dtype=tf.float32)\n", 448 | "\n", 449 | " # projection\n", 450 | " logits = tf.layers.dense(outputs, len(self.p2idx))\n", 451 | " y_hat = tf.to_int32(tf.argmax(logits, axis=-1))\n", 452 | " \n", 453 | " return logits, y_hat, y, prons, last_hidden\n", 454 | " \n", 455 | " def train(self, xs, ys):\n", 456 | " # forward\n", 457 | " last_hidden, words = self.encode(xs)\n", 458 | " logits, y_hat, y, prons, last_hidden = self.decode(ys, h0=last_hidden)\n", 459 | " \n", 460 | " # train scheme\n", 461 | " ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y)\n", 462 | " nonpadding = tf.to_float(tf.not_equal(y, self.p2idx[\"\"])) # 0: \n", 463 | " loss = tf.reduce_sum(ce*nonpadding) / (tf.reduce_sum(nonpadding)+1e-7)\n", 464 | "\n", 465 | " global_step = tf.train.get_or_create_global_step()\n", 466 | " train_op = tf.train.AdamOptimizer(hp.lr).minimize(loss, global_step=global_step)\n", 467 | " \n", 468 | " return loss, train_op, global_step\n", 469 | "\n", 470 | " \n", 471 | " def eval(self, xs, ys):\n", 472 | " '''Predicts autoregressively\n", 473 | " At inference input ys is ignored.\n", 474 | " Returns\n", 475 | " y_hat: (N, T2)\n", 476 | " '''\n", 477 | " decoder_inputs, y, seqlens, prons = ys\n", 478 | " decoder_inputs = tf.ones((tf.shape(xs[0])[0], 1), tf.int32) * self.p2idx[\"\"]\n", 479 | " ys = (decoder_inputs, y, seqlens, prons)\n", 480 | "\n", 481 | " last_hidden, words = self.encode(xs)\n", 482 | "\n", 483 | " \n", 484 | " h0 = last_hidden\n", 485 | " y_hats = []\n", 486 | " print(\"Inference graph is being built. 
Please be patient.\")\n", 487 | " for t in tqdm(range(self.hp.dec_maxlen)):\n", 488 | " _, y_hat, _, _, h0 = self.decode(ys, h0)\n", 489 | " if tf.reduce_sum(y_hat, 1)==0: break\n", 490 | " \n", 491 | " ys = (y_hat, y, seqlens, prons)\n", 492 | " y_hats.append(tf.squeeze(y_hat))\n", 493 | " y_hats = tf.stack(y_hats, 1)\n", 494 | " \n", 495 | " # monitor a random sample\n", 496 | " n = tf.random_uniform((), 0, tf.shape(y_hats)[0]-1, tf.int32)\n", 497 | " word = words[n]\n", 498 | " pred = convert_idx_to_token_tensor(y_hats[n], self.idx2p)\n", 499 | " pron = prons[n]\n", 500 | " \n", 501 | " return y_hats, word, pred, pron\n", 502 | "\n", 503 | " " 504 | ] 505 | }, 506 | { 507 | "cell_type": "markdown", 508 | "metadata": { 509 | "colab_type": "text", 510 | "id": "PKllLnfp73-V" 511 | }, 512 | "source": [ 513 | "# Train & Evaluate" 514 | ] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "execution_count": 78, 519 | "metadata": {}, 520 | "outputs": [], 521 | "source": [ 522 | "def calc_num_batches(total_num, batch_size):\n", 523 | " return total_num // batch_size + int(total_num % batch_size != 0) " 524 | ] 525 | }, 526 | { 527 | "cell_type": "code", 528 | "execution_count": 79, 529 | "metadata": {}, 530 | "outputs": [], 531 | "source": [ 532 | "# evaluation metric\n", 533 | "def per(ref, hyp):\n", 534 | " '''Calc phoneme error rate\n", 535 | " hyp: list of predicted phoneme sequences. e.g., [[\"B\", \"L\", \"AA1\", \"K\", \"HH\", \"AW2\", \"S\"], ...]\n", 536 | " ref: list of ground truth phoneme sequences. 
e.g., [[\"B\", \"L\", \"AA1\", \"K\", \"HH\", \"AW2\", \"S\"], ...]\n", 537 | " '''\n", 538 | " num_phonemes, num_erros = 0, 0\n", 539 | " g2idx, idx2g, p2idx, idx2p = load_vocab()\n", 540 | " for r, h in zip(ref, hyp):\n", 541 | " r = r.split()\n", 542 | " h = \" \".join(idx2p[idx] for idx in h)\n", 543 | " h = h.split(\"\")[0].strip().split()\n", 544 | " \n", 545 | " num_phonemes += len(r)\n", 546 | " num_erros += levenshtein(h, r)\n", 547 | "# print(h, r)\n", 548 | " per = round(num_erros / num_phonemes, 2)\n", 549 | " return per" 550 | ] 551 | }, 552 | { 553 | "cell_type": "code", 554 | "execution_count": 80, 555 | "metadata": { 556 | "colab": {}, 557 | "colab_type": "code", 558 | "id": "64f3a-fb73-Y", 559 | "scrolled": false 560 | }, 561 | "outputs": [], 562 | "source": [ 563 | "tf.reset_default_graph()\n", 564 | "# prepare batches\n", 565 | "train_batches, num_train_batches, num_train_samples = get_batch(train_words, train_prons,\n", 566 | " hp.batch_size, shuffle=True)\n", 567 | "eval_batches, num_eval_batches, num_eval_samples = get_batch(eval_words, eval_prons,\n", 568 | " hp.batch_size, shuffle=False)" 569 | ] 570 | }, 571 | { 572 | "cell_type": "code", 573 | "execution_count": 81, 574 | "metadata": { 575 | "colab": {}, 576 | "colab_type": "code", 577 | "id": "iL3CK4NW73-g" 578 | }, 579 | "outputs": [], 580 | "source": [ 581 | "# create a iterator of the correct shape and type\n", 582 | "iter = tf.data.Iterator.from_structure(train_batches.output_types, train_batches.output_shapes)\n", 583 | "\n", 584 | "# create the initialisation operations\n", 585 | "train_init_op = iter.make_initializer(train_batches)\n", 586 | "eval_init_op = iter.make_initializer(eval_batches)" 587 | ] 588 | }, 589 | { 590 | "cell_type": "code", 591 | "execution_count": 82, 592 | "metadata": {}, 593 | "outputs": [], 594 | "source": [ 595 | "# variable specs\n", 596 | "def print_variable_specs(fpath):\n", 597 | " def get_size(shp):\n", 598 | " size = 1\n", 599 | " for d in 
range(len(shp)):\n", 600 | " size *=shp[d]\n", 601 | " return size\n", 602 | "\n", 603 | " params, num_params = [], 0\n", 604 | " for v in tf.global_variables():\n", 605 | " params.append(\"{}==={}\\n\".format(v.name, v.shape))\n", 606 | " num_params += get_size(v.shape)\n", 607 | " print(\"num_params:\", num_params)\n", 608 | "# with open(fpath, 'w') as fout:\n", 609 | "# fout.write(\"num_params: {}\\n\".format(num_params))\n", 610 | "# fout.write(\"\\n\".join(params))" 611 | ] 612 | }, 613 | { 614 | "cell_type": "code", 615 | "execution_count": 83, 616 | "metadata": { 617 | "scrolled": false 618 | }, 619 | "outputs": [ 620 | { 621 | "name": "stdout", 622 | "output_type": "stream", 623 | "text": [ 624 | "Inference graph is being built. Please be patient.\n" 625 | ] 626 | }, 627 | { 628 | "data": { 629 | "application/vnd.jupyter.widget-view+json": { 630 | "model_id": "7b3f85d3acbc43dc8ebf4eec2ea9335c", 631 | "version_major": 2, 632 | "version_minor": 0 633 | }, 634 | "text/plain": [ 635 | "HBox(children=(IntProgress(value=0, max=20), HTML(value='')))" 636 | ] 637 | }, 638 | "metadata": {}, 639 | "output_type": "display_data" 640 | } 641 | ], 642 | "source": [ 643 | "# Load model\n", 644 | "net = Net(hp)\n", 645 | "xs, ys = iter.get_next()\n", 646 | "loss, train_op, global_step = net.train(xs, ys)\n", 647 | "y_hat, word, pred, pron = net.eval(xs, ys)" 648 | ] 649 | }, 650 | { 651 | "cell_type": "code", 652 | "execution_count": 84, 653 | "metadata": { 654 | "colab": { 655 | "base_uri": "https://localhost:8080/", 656 | "height": 14303 657 | }, 658 | "colab_type": "code", 659 | "id": "frKAWTc873-q", 660 | "outputId": "2d464429-3e88-4f3f-9a9d-d64094d995f9", 661 | "scrolled": false 662 | }, 663 | "outputs": [ 664 | { 665 | "name": "stdout", 666 | "output_type": "stream", 667 | "text": [ 668 | "Variables initialized\n", 669 | "num_params: 444513\n" 670 | ] 671 | }, 672 | { 673 | "data": { 674 | "application/vnd.jupyter.widget-view+json": { 675 | "model_id": 
"cc6815075e624866b77282ec647338ec", 676 | "version_major": 2, 677 | "version_minor": 0 678 | }, 679 | "text/plain": [ 680 | "HBox(children=(IntProgress(value=0, max=7721), HTML(value='')))" 681 | ] 682 | }, 683 | "metadata": {}, 684 | "output_type": "display_data" 685 | }, 686 | { 687 | "name": "stdout", 688 | "output_type": "stream", 689 | "text": [ 690 | "epoch= 1 is done!\n", 691 | "wrd: m a p e l\n", 692 | "exp: M AE1 P AH0 L\n", 693 | "got: M AE1 P L \n", 694 | "per=0.56\n", 695 | "\n", 696 | "epoch= 2 is done!\n", 697 | "wrd: s t e a r i c\n", 698 | "exp: S T IY1 R IH0 K\n", 699 | "got: S T EH1 R IY0 Z N Z N \n", 700 | "per=0.41\n", 701 | "\n", 702 | "epoch= 3 is done!\n", 703 | "wrd: s c o l d e d\n", 704 | "exp: S K OW1 L D AH0 D\n", 705 | "got: S K OW1 D L D L T V UW0 T UW0 V UW0\n", 706 | "per=0.34\n", 707 | "\n", 708 | "epoch= 4 is done!\n", 709 | "wrd: s c o l d e d\n", 710 | "exp: S K OW1 L D AH0 D\n", 711 | "got: S K OW1 L D D L T L T V EY0 T W EH1\n", 712 | "per=0.30\n", 713 | "\n", 714 | "epoch= 5 is done!\n", 715 | "wrd: n o a\n", 716 | "exp: N OW1 AH0\n", 717 | "got: N OW1 EH1 L EH1 N T EH1 T EH1 T \n", 718 | "per=0.26\n", 719 | "\n", 720 | "epoch= 6 is done!\n", 721 | "wrd: c o n f i d e n t i a l l y\n", 722 | "exp: K AA2 N F AH0 D EH1 N SH AH0 L IY0\n", 723 | "got: K AH0 N F EH1 D AH0 N T AH0 L IY0 AW1 \n", 724 | "per=0.24\n", 725 | "\n", 726 | "epoch= 7 is done!\n", 727 | "wrd: d o l i n g\n", 728 | "exp: D OW1 L IH0 NG\n", 729 | "got: D OW1 L IH0 NG G T AH1 V AH1 T\n", 730 | "per=0.22\n", 731 | "\n", 732 | "epoch= 8 is done!\n", 733 | "wrd: e q u a t o r\n", 734 | "exp: IH0 K W EY1 T ER0\n", 735 | "got: EH1 K W AA2 T ER0 TH AO1 TH AO1 AH1 T AH2\n", 736 | "per=0.22\n", 737 | "\n", 738 | "epoch= 9 is done!\n", 739 | "wrd: c a r t e r s v i l l e\n", 740 | "exp: K AA1 R T ER0 Z V IH2 L\n", 741 | "got: K AA1 R T ER0 S IH0 V L IY0 G TH G M AH1 TH\n", 742 | "per=0.20\n", 743 | "\n", 744 | "epoch= 10 is done!\n", 745 | "wrd: s e e f e l d t\n", 746 
| "exp: S IY1 F IH0 L T\n", 747 | "got: S IY1 F EH2 L T D T \n", 748 | "per=0.19\n", 749 | "\n", 750 | "Training Done!\n" 751 | ] 752 | } 753 | ], 754 | "source": [ 755 | "# Session\n", 756 | "saver = tf.train.Saver()\n", 757 | "with tf.Session() as sess:\n", 758 | " ckpt = tf.train.latest_checkpoint(hp.logdir)\n", 759 | " if ckpt is None:\n", 760 | " sess.run(tf.global_variables_initializer())\n", 761 | " print(\"Variables initialized\")\n", 762 | " else:\n", 763 | " saver.restore(sess, ckpt)\n", 764 | " print(\"Restored from file: \", ckpt)\n", 765 | "\n", 766 | " print_variable_specs('specs')\n", 767 | "\n", 768 | " sess.run(train_init_op)\n", 769 | " total_steps = hp.num_epochs*num_train_batches\n", 770 | " _gs = sess.run(global_step)\n", 771 | " for _ in tqdm(range(_gs, total_steps+1)):\n", 772 | " # training\n", 773 | " _, _gs, _loss = sess.run([train_op, global_step,loss]) \n", 774 | "\n", 775 | " epoch = math.ceil(_gs / num_train_batches)\n", 776 | " \n", 777 | " if _gs and _gs % num_train_batches == 0: # Be careful that you should evaluate at every epoch due to train_init_op\n", 778 | " print(\"epoch=\", epoch, \"is done!\")\n", 779 | " sess.run(eval_init_op)\n", 780 | " _y_hats = []\n", 781 | " for _ in range(num_eval_batches):\n", 782 | " _y_hat, _word, _pred, _pron = sess.run([y_hat, word, pred, pron])\n", 783 | " _y_hats.extend(_y_hat.tolist())\n", 784 | " \n", 785 | " # sample monitor\n", 786 | " print(\"wrd:\", _word.decode(\"utf-8\"))\n", 787 | " print(\"exp:\", _pron.decode(\"utf-8\"))\n", 788 | " print(\"got:\", _pred.decode(\"utf-8\"))\n", 789 | " \n", 790 | " \n", 791 | " _per = per(eval_prons, _y_hats)\n", 792 | " print(\"per=%.2f\"%_per)\n", 793 | " print()\n", 794 | " \n", 795 | " sess.run(train_init_op)\n", 796 | " \n", 797 | " # save\n", 798 | " if not os.path.exists(hp.logdir): os.makedirs(hp.logdir)\n", 799 | " fname = os.path.join(hp.logdir, \"my_model_loss_%.2f_per_%.2f\" % (_loss, _per))\n", 800 | " saver.save(sess, fname, 
global_step=_gs)\n", 801 | " \n", 802 | " print(\"Training Done!\")" 803 | ] 804 | }, 805 | { 806 | "cell_type": "markdown", 807 | "metadata": { 808 | "colab_type": "text", 809 | "id": "82t4Dmwp73--" 810 | }, 811 | "source": [ 812 | "# Inference" 813 | ] 814 | }, 815 | { 816 | "cell_type": "code", 817 | "execution_count": 85, 818 | "metadata": {}, 819 | "outputs": [], 820 | "source": [ 821 | "tf.reset_default_graph()\n", 822 | "test_batches, num_test_batches, num_test_samples = get_batch(test_words, test_prons,\n", 823 | " hp.batch_size,\n", 824 | " shuffle=False)\n", 825 | "iter = tf.data.Iterator.from_structure(test_batches.output_types, test_batches.output_shapes)\n", 826 | "\n", 827 | "# create the initialisation operations\n", 828 | "test_init_op = iter.make_initializer(test_batches)" 829 | ] 830 | }, 831 | { 832 | "cell_type": "code", 833 | "execution_count": 88, 834 | "metadata": {}, 835 | "outputs": [ 836 | { 837 | "name": "stdout", 838 | "output_type": "stream", 839 | "text": [ 840 | "Inference graph is being built. 
Please be patient.\n" 841 | ] 842 | }, 843 | { 844 | "data": { 845 | "application/vnd.jupyter.widget-view+json": { 846 | "model_id": "b109bf9c46024b1e81b4033a4c5ced26", 847 | "version_major": 2, 848 | "version_minor": 0 849 | }, 850 | "text/plain": [ 851 | "HBox(children=(IntProgress(value=0, max=20), HTML(value='')))" 852 | ] 853 | }, 854 | "metadata": {}, 855 | "output_type": "display_data" 856 | } 857 | ], 858 | "source": [ 859 | "# Load model\n", 860 | "xs, ys = iter.get_next()\n", 861 | "net = Net(hp)\n", 862 | "y_hat, _, _, _ = net.eval(xs, ys)" 863 | ] 864 | }, 865 | { 866 | "cell_type": "code", 867 | "execution_count": 89, 868 | "metadata": { 869 | "scrolled": false 870 | }, 871 | "outputs": [ 872 | { 873 | "name": "stdout", 874 | "output_type": "stream", 875 | "text": [ 876 | "log/04/my_model_loss_0.40_per_0.19-7720\n", 877 | "INFO:tensorflow:Restoring parameters from log/04/my_model_loss_0.40_per_0.19-7720\n", 878 | "checkpoint restored\n", 879 | "per=0.20\n", 880 | "Done!\n" 881 | ] 882 | } 883 | ], 884 | "source": [ 885 | "# saver for restoration\n", 886 | "ckpt = tf.train.latest_checkpoint(hp.logdir)\n", 887 | "print(ckpt)\n", 888 | "# saver = tf.train.import_meta_graph(ckpt + \".meta\")# <- Do NOT use this as we'll use a distinct graph.\n", 889 | "saver = tf.train.Saver()\n", 890 | " \n", 891 | "with tf.Session() as sess:\n", 892 | " \n", 893 | " saver.restore(sess, ckpt); print(\"checkpoint restored\") \n", 894 | " sess.run(test_init_op)\n", 895 | "\n", 896 | " _y_hats = []\n", 897 | " for _ in range(num_test_batches):\n", 898 | " _y_hat = sess.run(y_hat)\n", 899 | " _y_hats.extend(_y_hat.tolist())\n", 900 | " \n", 901 | " _per = per(test_prons, _y_hats)\n", 902 | " \n", 903 | " print(\"per=%.2f\"%_per)\n", 904 | " \n", 905 | " # save\n", 906 | " g2idx, idx2g, p2idx, idx2p = load_vocab()\n", 907 | " \n", 908 | " with open(\"result\", 'w') as fout:\n", 909 | " fout.write(\"per: %.2f\\n\" % _per)\n", 910 | " for w, r, h in zip(test_words, test_prons, 
_y_hats):\n", 911 | " w = w.replace(\" \", \"\")\n", 912 | " h = \" \".join(idx2p[idx] for idx in h)\n", 913 | " h = h.split(\"\")[0].strip()\n", 914 | " fout.write(\"wrd: {}\\nexp: {}\\ngot: {}\\n\\n\".format(w, r, h))\n", 915 | " \n", 916 | " print(\"Done!\")" 917 | ] 918 | }, 919 | { 920 | "cell_type": "markdown", 921 | "metadata": {}, 922 | "source": [ 923 | "Let's see some results." 924 | ] 925 | }, 926 | { 927 | "cell_type": "code", 928 | "execution_count": 90, 929 | "metadata": { 930 | "scrolled": false 931 | }, 932 | "outputs": [ 933 | { 934 | "data": { 935 | "text/plain": [ 936 | "['wrd: campau',\n", 937 | " 'exp: K AA1 M P AW0',\n", 938 | " 'got: K AE1 M P OW2',\n", 939 | " '',\n", 940 | " 'wrd: tension',\n", 941 | " 'exp: T EH1 N SH AH0 N',\n", 942 | " 'got: T EH1 N S IY0 AH0 N',\n", 943 | " '',\n", 944 | " 'wrd: pithy',\n", 945 | " 'exp: P IH1 TH IY0',\n", 946 | " 'got: P IH1 TH IY0',\n", 947 | " '',\n", 948 | " 'wrd: blaisdell',\n", 949 | " 'exp: B L EY1 S D AH0 L',\n", 950 | " 'got: B L EY1 S D AH0 L',\n", 951 | " '',\n", 952 | " 'wrd: reflectone',\n", 953 | " 'exp: R IY0 F L EH1 K T OW2 N',\n", 954 | " 'got: R IY0 F L EH1 K T AH0 N',\n", 955 | " '',\n", 956 | " 'wrd: cherishing',\n", 957 | " 'exp: CH EH1 R IH0 SH IH0 NG',\n", 958 | " 'got: CH EH1 R IH0 SH IH0 NG',\n", 959 | " '',\n", 960 | " 'wrd: necessitate',\n", 961 | " 'exp: N AH0 S EH1 S AH0 T EY2 T',\n", 962 | " 'got: N EH2 S AH0 S EH1 T IH0 T',\n", 963 | " '',\n", 964 | " 'wrd: swiatkowski',\n", 965 | " 'exp: S V IY0 AH0 T K AO1 F S K IY0',\n", 966 | " 'got: S W IH0 T AO1 K S W IH0 K',\n", 967 | " '',\n", 968 | " 'wrd: tendons',\n", 969 | " 'exp: T EH1 N D AH0 N Z',\n", 970 | " 'got: T EH1 N D AH0 N Z',\n", 971 | " '',\n", 972 | " 'wrd: nucleonic',\n", 973 | " 'exp: N UW2 K L IY0 AA1 N IH0 K',\n", 974 | " 'got: N AH0 K L EH1 N IH0 K',\n", 975 | " '',\n", 976 | " 'wrd: nutone',\n", 977 | " 'exp: N UW1 T OW2 N',\n", 978 | " 'got: N UW1 T OW2 N',\n", 979 | " '',\n", 980 | " 'wrd: demaree',\n", 
981 | " 'exp: D EH0 M ER0 IY1',\n", 982 | " 'got: D IH0 M AA1 R IY0',\n", 983 | " '',\n", 984 | " 'wrd: soltau',\n", 985 | " 'exp: S OW1 L T AW0',\n", 986 | " 'got: S OW1 L T OW0',\n", 987 | " '',\n", 988 | " 'wrd: methodically',\n", 989 | " 'exp: M AH0 TH AA1 D IH0 K AH0 L IY0',\n", 990 | " 'got: M EH2 TH AH0 D AA1 K AH0 L IY0',\n", 991 | " '',\n", 992 | " 'wrd: ahoskie',\n", 993 | " 'exp: AH0 HH AO1 S K IY0',\n", 994 | " 'got: AH0 HH AO1 S K IY0',\n", 995 | " '',\n", 996 | " 'wrd: mcivor',\n", 997 | " 'exp: M AH0 K IH1 V ER0',\n", 998 | " 'got: M AH0 K V EH1 R',\n", 999 | " '',\n", 1000 | " 'wrd: generalissimo',\n", 1001 | " 'exp: JH EH2 N EH0 R AH0 L IH1 S IH0 M OW2',\n", 1002 | " 'got: JH EH2 N ER0 AH0 L IH1 Z AH0 M IY0',\n", 1003 | " '',\n", 1004 | " \"wrd: kasinga's\",\n", 1005 | " 'exp: K AH0 S IH1 NG G AH0 Z',\n", 1006 | " 'got: K AH0 S IH1 N JH AH0 Z',\n", 1007 | " '',\n", 1008 | " 'wrd: currin',\n", 1009 | " 'exp: K AO1 R IH0 N',\n", 1010 | " 'got: K ER1 IH0 N',\n", 1011 | " '',\n", 1012 | " 'wrd: deregulatory',\n", 1013 | " 'exp: D IY0 R EH1 G Y AH0 L AH0 T AO2 R IY0',\n", 1014 | " 'got: D EH2 R AH0 G L Y AA1 T ER0 IY0',\n", 1015 | " '',\n", 1016 | " 'wrd: calbos',\n", 1017 | " 'exp: K AA1 L B OW0 S',\n", 1018 | " 'got: K AE1 L B OW0 Z',\n", 1019 | " '',\n", 1020 | " 'wrd: kreg',\n", 1021 | " 'exp: K R EH1 G',\n", 1022 | " 'got: K R EH1 G',\n", 1023 | " '',\n", 1024 | " 'wrd: dezarn',\n", 1025 | " 'exp: D EY0 Z AA1 R N',\n", 1026 | " 'got: D IH0 Z AA1 R N',\n", 1027 | " '',\n", 1028 | " 'wrd: rapprochement',\n", 1029 | " 'exp: R AE2 P R OW2 SH M AA1 N',\n", 1030 | " 'got: R AE2 P R AH0 K M EY1 N T',\n", 1031 | " '',\n", 1032 | " 'wrd: rosenshine',\n", 1033 | " 'exp: R OW1 Z AH0 N SH AY2 N',\n", 1034 | " 'got: R OW0 Z EH1 N SH IY0 N',\n", 1035 | " '']" 1036 | ] 1037 | }, 1038 | "execution_count": 90, 1039 | "metadata": {}, 1040 | "output_type": "execute_result" 1041 | } 1042 | ], 1043 | "source": [ 1044 | "open('result', 'r').read().splitlines()[-100:]" 
1045 | ] 1046 | }, 1047 | { 1048 | "cell_type": "code", 1049 | "execution_count": null, 1050 | "metadata": {}, 1051 | "outputs": [], 1052 | "source": [] 1053 | } 1054 | ], 1055 | "metadata": { 1056 | "colab": { 1057 | "name": "Seq2seq tutorial with g2p.ipynb", 1058 | "provenance": [], 1059 | "version": "0.3.2" 1060 | }, 1061 | "kernelspec": { 1062 | "display_name": "Python 3", 1063 | "language": "python", 1064 | "name": "python3" 1065 | }, 1066 | "language_info": { 1067 | "codemirror_mode": { 1068 | "name": "ipython", 1069 | "version": 3 1070 | }, 1071 | "file_extension": ".py", 1072 | "mimetype": "text/x-python", 1073 | "name": "python", 1074 | "nbconvert_exporter": "python", 1075 | "pygments_lexer": "ipython3", 1076 | "version": "3.7.1" 1077 | } 1078 | }, 1079 | "nbformat": 4, 1080 | "nbformat_minor": 1 1081 | } 1082 | -------------------------------------------------------------------------------- /dropout.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/nlp_made_easy/c481472a9e0a6e922e92722f382874e67453898c/dropout.png -------------------------------------------------------------------------------- /no-dropout.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kyubyong/nlp_made_easy/c481472a9e0a6e922e92722f382874e67453898c/no-dropout.png --------------------------------------------------------------------------------